Source code for mpas_tools.ocean.viz.transects

#!/usr/bin/env python
import argparse

import cmocean
import numpy as np
import xarray as xr
from geometric_features import read_feature_collection, FeatureCollection
import matplotlib.pyplot as plt
from matplotlib.tri import Triangulation
from mpas_tools.viz import mesh_to_triangles
from mpas_tools.viz.transects import find_transect_cells_and_weights, \
    make_triangle_tree
from mpas_tools.ocean.transects import find_transect_levels_and_weights, \
    interp_mpas_to_transect_triangles, get_outline_segments

from mpas_tools.ocean.viz.inset import add_inset
from mpas_tools.viz.colormaps import register_sci_viz_colormaps


def plot_ocean_transects(fc, ds, ds_mesh=None, variable_list=None, cmap=None,
                         flip=False):
    """
    Plot images of the given variables on the given transects.

    One image named ``<transect_name>_<variable_name>.png`` will be produced
    in the current directory for each transect and variable

    Parameters
    ----------
    fc : geometric_features.FeatureCollection
        The transects to plot

    ds : xarray.Dataset
        The MPAS-Ocean dataset to plot

    ds_mesh : xarray.Dataset, optional
        The MPAS-Ocean mesh to use for plotting, the same as ``ds`` by default

    variable_list : list of str, optional
        The variables to plot.  By default, every data variable in ``ds``
        with both ``nCells`` and ``nVertLevels`` dimensions is plotted.

    cmap : str, optional
        The name of a colormap to use.  By default, the colormap depends on
        the variable name.

    flip : bool, optional
        Whether to flip the x axes of all transect plots
    """
    if ds_mesh is None:
        # as documented: fall back to ``ds``, which must then contain the mesh
        ds_mesh = ds

    # only the first time slice is plotted
    if 'Time' in ds.dims:
        ds = ds.isel(Time=0)
    if 'Time' in ds_mesh.dims:
        ds_mesh = ds_mesh.isel(Time=0)

    print('\nBuilding transect geometry...')
    transects = _compute_transects(fc, ds_mesh, flip)

    # one single-feature collection per transect, passed on for the inset map
    fc_transects = dict()
    for transect in fc.features:
        transect_name = transect['properties']['name']
        print(f' {transect_name}')
        fc_transects[transect_name] = FeatureCollection(features=[transect])

    register_sci_viz_colormaps()

    if variable_list is None:
        variable_list = [
            var_name for var_name, var in ds.data_vars.items()
            if 'nCells' in var.dims and 'nVertLevels' in var.dims]

    print('\nPlotting...')
    for var_name in variable_list:
        var = ds[var_name]
        assert 'nCells' in var.dims and 'nVertLevels' in var.dims, \
            f'{var_name} must have both nCells and nVertLevels dimensions'
        for transect_name, ds_transect in transects.items():
            print(f' {transect_name} {var_name}')
            _plot_transect(ds_transect, var, var_name, transect_name, cmap,
                           fc_transects[transect_name])
def main():
    """
    Command-line entry point: parse arguments, open the data (and optional
    mesh) files and plot the requested transects.
    """
    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('-g', '--geojson', dest='geojson_filename',
                        required=True,
                        help='A geojson file with transects to plot')
    parser.add_argument('-m', '--mesh', dest='mesh_filename',
                        help='An MPAS-Ocean mesh file. If not specified, the '
                             'MPAS-Ocean data file must contain the mesh.')
    parser.add_argument('-f', '--file', dest='filename', required=True,
                        help='An MPAS-Ocean data file')
    parser.add_argument("-v", "--variable_list", dest="variable_list",
                        nargs='+',
                        help="List of variables to plot. All variables on "
                             "cells in the data file is the default.")
    parser.add_argument('-c', '--colormap', dest='colormap',
                        help='A colormap to use for the plots, default '
                             'depends on the field name.')
    parser.add_argument('--flip', dest='flip', action='store_true',
                        help='Flip the x axis for all transects')
    args = parser.parse_args()

    fc = read_feature_collection(args.geojson_filename)

    ds = xr.open_dataset(args.filename)
    if args.mesh_filename is not None:
        ds_mesh = xr.open_dataset(args.mesh_filename)
    else:
        ds_mesh = ds

    try:
        plot_ocean_transects(fc=fc, ds=ds, ds_mesh=ds_mesh,
                             variable_list=args.variable_list,
                             cmap=args.colormap, flip=args.flip)
    finally:
        # release the file handles opened above (only once if shared)
        ds.close()
        if ds_mesh is not ds:
            ds_mesh.close()


def _compute_transects(fc, ds_mesh, flip):
    """
    Build the sequence of triangles showing where each transect intersects
    the MPAS cells.

    Parameters
    ----------
    fc : geometric_features.FeatureCollection
        The transects; each must be a ``LineString`` whose coordinates are
        (lon, lat) positions in degrees

    ds_mesh : xarray.Dataset
        The MPAS-Ocean mesh, providing ``layerThickness``, ``bottomDepth``
        and ``maxLevelCell`` (and, optionally, ``landIceFraction``)

    flip : bool
        Whether to reverse the order of the points along each transect

    Returns
    -------
    transects : dict
        A dataset describing each transect's geometry, keyed by transect
        name, with node coordinates ``x`` (from ``dNode``) and ``z`` (from
        ``zTransectNode``) added
    """
    ds_tris = mesh_to_triangles(ds_mesh)
    triangle_tree = make_triangle_tree(ds_tris)

    transects = dict()
    for transect in fc.features:
        transect_name = transect['properties']['name']
        geometry = transect['geometry']
        assert geometry['type'] == 'LineString', \
            f'Transect {transect_name} must be a LineString'

        coordinates = geometry['coordinates']
        # GeoJSON positions may carry an optional third (altitude) value, so
        # only the first two entries (lon, lat) of each position are used
        transect_lon = np.array([point[0] for point in coordinates])
        transect_lat = np.array([point[1] for point in coordinates])
        if flip:
            transect_lon = transect_lon[::-1]
            transect_lat = transect_lat[::-1]
        transect_lon = xr.DataArray(data=transect_lon, dims=('nPoints',))
        transect_lat = xr.DataArray(data=transect_lat, dims=('nPoints',))

        ds_mpas_transect = find_transect_cells_and_weights(
            transect_lon, transect_lat, ds_tris, ds_mesh,
            triangle_tree, degrees=True)

        # NOTE(review): the ``- 1`` presumably converts the 1-based
        # maxLevelCell into a 0-based level index -- confirm
        ds_mpas_transect = find_transect_levels_and_weights(
            ds_mpas_transect, ds_mesh.layerThickness,
            ds_mesh.bottomDepth, ds_mesh.maxLevelCell - 1)

        if 'landIceFraction' in ds_mesh:
            # weighted sum of the neighboring cells' land-ice fraction
            interp_cell_indices = ds_mpas_transect.interpHorizCellIndices
            interp_cell_weights = ds_mpas_transect.interpHorizCellWeights
            land_ice_fraction = ds_mesh.landIceFraction.isel(
                nCells=interp_cell_indices)
            land_ice_fraction = (land_ice_fraction * interp_cell_weights).sum(
                dim='nHorizWeights')
            ds_mpas_transect['landIceFraction'] = land_ice_fraction

        ds_mpas_transect['x'] = ds_mpas_transect.dNode.isel(
            nSegments=ds_mpas_transect.segmentIndices,
            nHorizBounds=ds_mpas_transect.nodeHorizBoundsIndices)

        ds_mpas_transect['z'] = ds_mpas_transect.zTransectNode

        # Dataset.compute() returns a new (loaded) dataset instead of loading
        # in place, so its result has to be kept
        ds_mpas_transect = ds_mpas_transect.compute()

        transects[transect_name] = ds_mpas_transect

    return transects


def _plot_transect(ds_transect, mpas_field, var_name, transect_name, cmap,
                   fc):
    """
    Plot a transect showing the field on the MPAS-Ocean mesh and save it to
    ``<transect_name>_<var_name>.png`` (spaces in the transect name are
    replaced by underscores).

    Parameters
    ----------
    ds_transect : xarray.Dataset
        The transect geometry, as produced by ``_compute_transects()``

    mpas_field : xarray.DataArray
        The field to plot, on ``nCells`` and ``nVertLevels``

    var_name : str
        The name of the field, used in the title, colorbar and file name

    transect_name : str
        The name of the transect, used in the title and file name

    cmap : str or None
        The colormap; if ``None``, one is chosen from the variable name when
        it contains ``temperature``, ``salinity`` or ``density``

    fc : geometric_features.FeatureCollection
        The transect, drawn in the inset map
    """
    transect_prefix = transect_name.replace(' ', '_')
    transect_field = interp_mpas_to_transect_triangles(ds_transect,
                                                       mpas_field)

    units = mpas_field.attrs.get('units')

    x_outline, z_outline = get_outline_segments(ds_transect)
    # scale x by 1e-3 to match the triangulation and the 'x (km)' label
    x_outline = 1e-3 * x_outline

    colormaps = dict(
        temperature='cmo.thermal',
        salinity='cmo.haline',
        density='cmo.dense',
    )
    if cmap is None:
        for contains, map_name in colormaps.items():
            if contains in var_name.lower():
                cmap = map_name

    # mask out entries where the interpolated field is null (e.g. NaN)
    tri_mask = np.logical_not(transect_field.notnull().values)

    triangulation_args = _get_ds_triangulation_args(ds_transect)
    triangulation_args['mask'] = tri_mask

    tris = Triangulation(**triangulation_args)

    fig = plt.figure(figsize=(12, 6))
    ax = plt.gca()

    plt.tripcolor(tris, facecolors=transect_field.values, shading='flat',
                  cmap=cmap)
    plt.plot(x_outline, z_outline, 'k')
    if units is not None:
        colorbar_label = f'{var_name} ({units})'
    else:
        colorbar_label = f'{var_name}'
    plt.colorbar(label=colorbar_label)
    plt.title(f'{var_name} through {transect_name}')
    plt.xlabel('x (km)')
    plt.ylabel('z (m)')

    # make a red start axis and green end axis to correspond to the dots
    # in the inset
    ax.spines['left'].set_color('red')
    ax.spines['right'].set_color('green')
    ax.spines['left'].set_linewidth(4)
    ax.spines['right'].set_linewidth(4)

    plt.tight_layout(pad=0.5, h_pad=0.5, rect=[0.0, 0.0, 1.0, 1.0])

    add_inset(fig, fc)

    plt.savefig(f'{transect_prefix}_{var_name}.png', dpi=200)
    plt.close(fig)


def _get_ds_triangulation_args(ds_transect):
    """
    Get arguments for a matplotlib ``Triangulation`` from a transect dataset.

    Triangle ``i`` is made of flattened nodes ``3*i``, ``3*i + 1`` and
    ``3*i + 2``, i.e. nodes are not shared between triangles.  This assumes
    the indexed ``dNode`` and ``zTransectNode`` flatten to three consecutive
    nodes per transect triangle -- TODO confirm.  ``x`` is scaled by 1e-3
    (plotted as km); ``y`` is ``zTransectNode`` unchanged.
    """
    n_transect_triangles = ds_transect.sizes['nTransectTriangles']
    d_node = ds_transect.dNode.isel(
        nSegments=ds_transect.segmentIndices,
        nHorizBounds=ds_transect.nodeHorizBoundsIndices)
    x = 1e-3 * d_node.values.ravel()

    z_transect_node = ds_transect.zTransectNode
    y = z_transect_node.values.ravel()

    tris = np.arange(3 * n_transect_triangles).reshape(
        (n_transect_triangles, 3))

    triangulation_args = dict(x=x, y=y, triangles=tris)
    return triangulation_args


if __name__ == '__main__':
    main()