Source code for mpas_tools.ocean.viz.transect.plot

import argparse

import cmocean  # noqa: F401
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from geometric_features import FeatureCollection, read_feature_collection

from mpas_tools.ocean.viz.inset import add_inset
from mpas_tools.ocean.viz.transect.vert import (
    compute_transect,
    interp_mpas_to_transect_cells,
    interp_mpas_to_transect_nodes,
)
from mpas_tools.viz.colormaps import register_sci_viz_colormaps


[docs] def plot_transect( ds_transect, mpas_field=None, out_filename=None, ax=None, title=None, vmin=None, vmax=None, colorbar_label=None, cmap=None, figsize=(12, 6), dpi=200, method='flat', outline_color='black', ssh_color=None, seafloor_color=None, interface_color=None, cell_boundary_color=None, linewidth=1.0, color_start_and_end=False, start_color='red', end_color='green', ): """ plot a transect showing the field on the MPAS-Ocean mesh and save to a file Parameters ---------- ds_transect : xarray.Dataset A transect dataset from :py:func:`mpas_tools.ocean.viz.transect.vert.compute_transect()` mpas_field : xarray.DataArray The MPAS-Ocean 3D field to plot out_filename : str, optional The png file to write out to ax : matplotlib.axes.Axes Axes to plot to if making a multi-panel figure title : str The title of the plot vmin : float, optional The minimum values for the colorbar vmax : float, optional The maximum values for the colorbar colorbar_label : str, optional The colorbar label, or ``None`` if no colorbar is to be included. Use an empty string to display a colorbar without a label. cmap : str, optional The name of a colormap to use figsize : tuple, optional The size of the figure in inches dpi : int, optional The dots per inch of the image method : {'flat', 'bilinear'}, optional The type of interpolation to use in plots. ``flat`` means constant values over each MPAS cell. ``bilinear`` means smooth interpolation between horizontally between cell centers and vertical between the middle of layers. outline_color : str or None, optional The color to use to outline the transect or ``None`` for no outline ssh_color : str or None, optional The color to use to plot the SSH (sea surface height) or ``None`` if not plotting the SSH (except perhaps as part of the outline) seafloor_color : str or None, optional The color to use to plot the seafloor depth or ``None`` if not plotting the seafloor depth (except perhaps as part of the outline) interface_color : str or None, optional The color to use to plot interfaces between layers or ``None`` if not plotting the layer interfaces cell_boundary_color : str or None, optional The color to use to plot vertical boundaries between cells or ``None`` if not plotting cell boundaries. Typically, ``cell_boundary_color`` will be used along with ``interface_color`` to outline cells both horizontally and vertically. linewidth : float, optional The width of outlines, interfaces and cell boundaries color_start_and_end : bool, optional Whether to color the left and right axes of the transect, which is useful if the transect is also being plotted in an inset or on top of a horizontal field start_color : str, optional The color of left axis marking the start of the transect if ``plot_start_end == True`` end_color : str, optional The color of right axis marking the end of the transect if ``plot_start_end == True`` """ if ax is None and out_filename is None: raise ValueError('One of ax or out_filename must be supplied') create_fig = ax is None if create_fig: plt.figure(figsize=figsize) ax = plt.subplot(111) z = ds_transect.zTransectNode x = 1e-3 * ds_transect.dNode.broadcast_like(z) if mpas_field is not None: if method == 'flat': transect_field = interp_mpas_to_transect_cells( ds_transect, mpas_field ) shading = 'flat' elif method == 'bilinear': transect_field = interp_mpas_to_transect_nodes( ds_transect, mpas_field ) shading = 'gouraud' else: raise ValueError(f'Unsupported method: {method}') pc = ax.pcolormesh( x.values, z.values, transect_field.values, shading=shading, cmap=cmap, vmin=vmin, vmax=vmax, zorder=0, ) ax.autoscale(tight=True) if colorbar_label is not None: plt.colorbar( pc, extend='both', shrink=0.7, ax=ax, label=colorbar_label ) _plot_interfaces( ds_transect, ax, interface_color, cell_boundary_color, ssh_color, seafloor_color, color_start_and_end, start_color, end_color, linewidth, ) _plot_outline(x, z, ds_transect.validCells, ax, outline_color, linewidth) ax.set_xlabel('transect distance (km)') ax.set_ylabel('z (m)') if create_fig: if title is not None: plt.title(title) plt.savefig(out_filename, dpi=dpi, bbox_inches='tight', pad_inches=0.2) plt.close()
[docs] def plot_feature_transects( fc, ds, ds_mesh=None, variable_list=None, cmap=None, flip=False, write_netcdf=False, method='flat', add_z=False, ): """ Plot images of the given variables on the given transects. One image named ``<transect_name>_<variable_name>.png`` will be produced in the current directory for each transect and variable Parameters ---------- fc : geometric_features.FeatureCollection The transects to plot ds : xarray.Dataset The MPAS-Ocean dataset to plot ds_mesh : xarray.Dataset, optional The MPAS-Ocean mesh to use for plotting, the same as ``ds`` by default variable_list : list of str, optional The variables to plot cmap : str, optional The name of a colormap to use flip : book, optional Whether to flip the x axes of all transect plot write_netcdf : bool, optional Whether to write a NetCDF file for the transect in addition to the image method : {'flat', 'bilinear'}, optional The type of interpolation to use in plots. ``flat`` means constant values over each MPAS cell. ``bilinear`` means smooth interpolation between horizontally between cell centers and vertical between the middle of layers. add_z : bool, optional Whether to add zMid and zInterface to the mesh dataset """ if 'Time' in ds.dims: ds = ds.isel(Time=0) if 'Time' in ds_mesh.dims: ds_mesh = ds_mesh.isel(Time=0) if add_z: _add_z(ds_mesh) print('\nBuilding transect geometry...') transects = _compute_feature_transects(fc, ds_mesh, flip) fc_transects = dict() for transect in fc.features: transect_name = transect['properties']['name'] fc_transects[transect_name] = FeatureCollection(features=[transect]) register_sci_viz_colormaps() if variable_list is None: variable_list = list() for var_name in ds.data_vars: var = ds[var_name] if 'nCells' in var.dims and ( 'nVertLevels' in var.dims or 'nVertLevelsP1' in var.dims ): variable_list.append(var_name) print('\nPlotting...') for var_name in variable_list: if var_name in ds: var = ds[var_name] elif var_name in ds_mesh: var = ds_mesh[var_name] else: raise ValueError( f'{var_name} not found in either the main or the ' f'mesh dataset (if any)' ) assert 'nCells' in var.dims and ( 'nVertLevels' in var.dims or 'nVertLevelsP1' in var.dims ) for transect_name, ds_transect in transects.items(): print(f' {transect_name} {var_name}') _plot_feature_transect( ds_transect, var, var_name, transect_name, cmap, fc_transects[transect_name], write_netcdf, method, )
def plot_feature_transects_main(): parser = argparse.ArgumentParser( description=__doc__, formatter_class=argparse.RawTextHelpFormatter ) parser.add_argument( '-g', '--geojson', dest='geojson_filename', required=True, help='A geojson file with transects to plot', ) parser.add_argument( '-m', '--mesh', dest='mesh_filename', help='An MPAS-Ocean mesh file. If not specified, the ' 'MPAS-Ocean data file must contain the mesh.', ) parser.add_argument( '-f', '--file', dest='filename', required=True, help='An MPAS-Ocean data file', ) parser.add_argument( '-v', '--variable_list', dest='variable_list', nargs='+', help='List of variables to plot. All variables on ' 'cells in the data file is the default.', ) parser.add_argument( '-c', '--colormap', dest='colormap', help='A colormap to use for the plots, default ' 'depends on the field name.', ) parser.add_argument( '--flip', dest='flip', action='store_true', help='Flip the x axis for all transects', ) parser.add_argument( '--write_netcdf', dest='write_netcdf', action='store_true', help='Whether to write a NetCDF file for the transect ' 'in addition to the image', ) parser.add_argument( '--method', dest='method', default='flat', help='The type of interpolation to use in plots. ' 'Options are "flat" and "bilinear"', ) parser.add_argument( '--add_z', dest='add_z', action='store_true', help='Whether to add zMid and zInterface to the mesh', ) args = parser.parse_args() fc = read_feature_collection(args.geojson_filename) ds = xr.open_dataset(args.filename) if args.mesh_filename is not None: ds_mesh = xr.open_dataset(args.mesh_filename) else: ds_mesh = ds variable_list = args.variable_list if 'Time' in ds.dims: ds = ds.isel(Time=0) if 'Time' in ds_mesh.dims: ds_mesh = ds_mesh.isel(Time=0) plot_feature_transects( fc=fc, ds=ds, ds_mesh=ds_mesh, variable_list=variable_list, cmap=args.colormap, flip=args.flip, write_netcdf=args.write_netcdf, method=args.method, add_z=args.add_z, ) def _plot_interfaces( ds_transect, ax, interface_color, cell_boundary_color, ssh_color, seafloor_color, color_start_and_end, start_color, end_color, linewidth, ): if cell_boundary_color is not None: x_bnd = 1e-3 * ds_transect.dCellBoundary.values.T z_bnd = ds_transect.zCellBoundary.values.T ax.plot( x_bnd, z_bnd, color=cell_boundary_color, linewidth=linewidth, zorder=1, ) if interface_color is not None: x_int = 1e-3 * ds_transect.dInterfaceSegment.values.T z_int = ds_transect.zInterfaceSegment.values.T ax.plot( x_int, z_int, color=interface_color, linewidth=linewidth, zorder=2 ) if ssh_color is not None: valid = ds_transect.validNodes.any(dim='nVertNodes') x_ssh = 1e-3 * ds_transect.dNode.values z_ssh = ds_transect.ssh.where(valid).values ax.plot(x_ssh, z_ssh, color=ssh_color, linewidth=linewidth, zorder=4) if seafloor_color is not None: valid = ds_transect.validNodes.any(dim='nVertNodes') x_floor = 1e-3 * ds_transect.dNode.values z_floor = ds_transect.zSeafloor.where(valid).values ax.plot( x_floor, z_floor, color=seafloor_color, linewidth=linewidth, zorder=5, ) if color_start_and_end: ax.spines['left'].set_color(start_color) ax.spines['left'].set_linewidth(4 * linewidth) ax.spines['right'].set_color(end_color) ax.spines['right'].set_linewidth(4 * linewidth) def _plot_outline( x, z, valid_cells, ax, outline_color, linewidth, epsilon=1e-6 ): if outline_color is not None: # add a buffer of invalid values around the edge of the domain # and make copies of each node. The validity of each copy of the node # corresponds to the validity of the adjacent cell valid = np.zeros( (2 * valid_cells.shape[0] + 2, 2 * valid_cells.shape[1] + 2), dtype=float, ) z_buf = np.zeros(valid.shape, dtype=float) x_buf = np.zeros(valid.shape, dtype=float) valid_cells = valid_cells.astype(float) # each interior node get the value from its cell valid[1:-2:2, 1:-2:2] = valid_cells valid[2:-1:2, 1:-2:2] = valid_cells valid[1:-2:2, 2:-1:2] = valid_cells valid[2:-1:2, 2:-1:2] = valid_cells z_buf[:-1:2, :-1:2] = z.values z_buf[1::2, :-1:2] = z.values z_buf[:-1:2, 1::2] = z.values z_buf[1::2, 1::2] = z.values x_buf[:-1:2, :-1:2] = x.values x_buf[1::2, :-1:2] = x.values x_buf[:-1:2, 1::2] = x.values x_buf[1::2, 1::2] = x.values ax.contour( x_buf, z_buf, valid, levels=[1.0 - epsilon], colors=outline_color, linewidths=linewidth, zorder=3, ) def _compute_feature_transects(fc, ds_mesh, flip): """ build a sequence of triangles showing the transect intersecting mpas cells """ transects = dict() layer_thickness = ds_mesh.layerThickness bottom_depth = ds_mesh.bottomDepth max_level_cell = ds_mesh.maxLevelCell - 1 if 'minLevelCell' in ds_mesh: min_level_cell = ds_mesh.minLevelCell - 1 else: min_level_cell = xr.zeros_like(max_level_cell) spherical = ds_mesh.attrs['on_a_sphere'] == 'YES' for transect in fc.features: transect_name = transect['properties']['name'] print(f' {transect_name}') assert transect['geometry']['type'] == 'LineString' coordinates = transect['geometry']['coordinates'] transect_lon, transect_lat = zip(*coordinates) transect_lon = np.array(transect_lon) transect_lat = np.array(transect_lat) if flip: transect_lon = transect_lon[::-1] transect_lat = transect_lat[::-1] transect_lon = xr.DataArray(data=transect_lon, dims=('nPoints',)) transect_lat = xr.DataArray(data=transect_lat, dims=('nPoints',)) ds_mpas_transect = compute_transect( x=transect_lon, y=transect_lat, ds_horiz_mesh=ds_mesh, layer_thickness=layer_thickness, bottom_depth=bottom_depth, min_level_cell=min_level_cell, max_level_cell=max_level_cell, spherical=spherical, ) ds_mpas_transect.compute() transects[transect_name] = ds_mpas_transect return transects def _plot_feature_transect( ds_transect, mpas_field, var_name, transect_name, cmap, fc, write_netcdf, method, ): """ plot a transect showing the field on the MPAS-Ocean mesh and save to a file """ transect_prefix = transect_name.replace(' ', '_') units = None if 'units' in mpas_field.attrs: units = mpas_field.attrs['units'] colormaps = dict( temperature='cmo.thermal', salinity='cmo.haline', density='cmo.dense', ) if cmap is None: for contains, map_name in colormaps.items(): if contains in var_name.lower(): cmap = map_name if units is not None: colorbar_label = f'{var_name} ({units})' else: colorbar_label = f'{var_name}' fig = plt.figure(figsize=(12, 6)) ax = plt.gca() plot_transect( ds_transect=ds_transect, mpas_field=mpas_field, ax=ax, title=f'{var_name} through {transect_name}', colorbar_label=colorbar_label, cmap=cmap, color_start_and_end=True, method=method, ) plt.tight_layout(pad=0.5, h_pad=0.5, rect=[0.0, 0.0, 1.0, 1.0]) add_inset(fig, fc) plt.savefig(f'{transect_prefix}_{var_name}.png', dpi=200) plt.close() if write_netcdf: ds_transect.to_netcdf(f'{transect_prefix}_{var_name}.nc') def _add_z(ds_mesh): """ Add zMid and zInterface to ``ds_mesh``, useful for debugging """ layer_thickness = ds_mesh.layerThickness bottom_depth = ds_mesh.bottomDepth max_level_cell = ds_mesh.maxLevelCell - 1 if 'minLevelCell' in ds_mesh: min_level_cell = ds_mesh.minLevelCell - 1 else: min_level_cell = xr.zeros_like(max_level_cell) n_vert_levels = layer_thickness.sizes['nVertLevels'] vert_index = xr.DataArray.from_dict( {'dims': ('nVertLevels',), 'data': np.arange(n_vert_levels)} ) cell_mask = np.logical_and( vert_index >= min_level_cell, vert_index <= max_level_cell ) layer_thickness = layer_thickness.where(cell_mask) thickness_sum = layer_thickness.sum(dim='nVertLevels') thickness_cum_sum = layer_thickness.cumsum(dim='nVertLevels') z_surface = -bottom_depth + thickness_sum z_layer_bot = z_surface - thickness_cum_sum z_interface_list = [z_surface] for z_index in range(n_vert_levels): z_interface_list.append(z_layer_bot.isel(nVertLevels=z_index)) z_interface = xr.concat(z_interface_list, dim='nVertLevelsP1') vert_index = xr.DataArray.from_dict( {'dims': ('nVertLevelsP1',), 'data': np.arange(n_vert_levels + 1)} ) interface_mask = np.logical_and( vert_index >= min_level_cell, vert_index <= max_level_cell + 1 ) z_interface = z_interface.where(interface_mask).transpose( 'nCells', 'nVertLevelsP1' ) z_mid = z_layer_bot + 0.5 * layer_thickness z_mid = z_mid.where(cell_mask).transpose('nCells', 'nVertLevels') ds_mesh.coords['zMid'] = z_mid ds_mesh.coords['zInterface'] = z_interface