#!/usr/bin/env python
import argparse
import cmocean
import numpy as np
import xarray as xr
from geometric_features import read_feature_collection, FeatureCollection
import matplotlib.pyplot as plt
from matplotlib.tri import Triangulation
from mpas_tools.viz import mesh_to_triangles
from mpas_tools.viz.transects import find_transect_cells_and_weights, \
    make_triangle_tree
from mpas_tools.ocean.transects import find_transect_levels_and_weights, \
    interp_mpas_to_transect_triangles, get_outline_segments
from mpas_tools.ocean.viz.inset import add_inset
from mpas_tools.viz.colormaps import register_sci_viz_colormaps
[docs]
def plot_ocean_transects(fc, ds, ds_mesh=None, variable_list=None, cmap=None,
                         flip=False):
    """
    Plot images of the given variables on the given transects.  One image
    named ``<transect_name>_<variable_name>.png`` will be produced in the
    current directory for each transect and variable
    Parameters
    ----------
    fc : geometric_features.FeatureCollection
        The transects to plot
    ds : xarray.Dataset
        The MPAS-Ocean dataset to plot
    ds_mesh : xarray.Dataset, optional
        The MPAS-Ocean mesh to use for plotting, the same as ``ds`` by default
    variable_list : list of str, optional
        The variables to plot
    cmap : str, optional
        The name of a colormap to use
    flip : book, optional
        Whether to flip the x axes of all transect plot
    """
    if 'Time' in ds.dims:
        ds = ds.isel(Time=0)
    if 'Time' in ds_mesh.dims:
        ds_mesh = ds_mesh.isel(Time=0)
    transects = _compute_transects(fc, ds_mesh, flip)
    print('\nBuilding transect geometry...')
    fc_transects = dict()
    for transect in fc.features:
        transect_name = transect['properties']['name']
        print(f'  {transect_name}')
        fc_transects[transect_name] = FeatureCollection(features=[transect])
    register_sci_viz_colormaps()
    if variable_list is None:
        variable_list = list()
        for var_name in ds.data_vars:
            var = ds[var_name]
            if 'nCells' in var.dims and 'nVertLevels' in var.dims:
                variable_list.append(var_name)
    print('\nPlotting...')
    for var_name in variable_list:
        var = ds[var_name]
        assert 'nCells' in var.dims and 'nVertLevels' in var.dims
        for transect_name, ds_transect in transects.items():
            print(f'  {transect_name} {var_name}')
            _plot_transect(ds_transect, var, var_name, transect_name, cmap,
                           fc_transects[transect_name]) 
def main():
    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('-g', '--geojson', dest='geojson_filename',
                        required=True,
                        help='A geojson file with transects to plot')
    parser.add_argument('-m', '--mesh', dest='mesh_filename',
                        help='An MPAS-Ocean mesh file.  If not specified, the '
                             'MPAS-Ocean data file must contain the mesh.')
    parser.add_argument('-f', '--file', dest='filename', required=True,
                        help='An MPAS-Ocean data file')
    parser.add_argument("-v", "--variable_list", dest="variable_list",
                        nargs='+',
                        help="List of variables to plot.  All variables on "
                             "cells in the data file is the default.")
    parser.add_argument('-c', '--colormap', dest='colormap',
                        help='A colormap to use for the plots, default '
                             'depends on the field name.')
    parser.add_argument('--flip', dest='flip', action='store_true',
                        help='Flip the x axis for all transects')
    args = parser.parse_args()
    fc = read_feature_collection(args.geojson_filename)
    ds = xr.open_dataset(args.filename)
    if args.mesh_filename is not None:
        ds_mesh = xr.open_dataset(args.mesh_filename)
    else:
        ds_mesh = ds
    variable_list = args.variable_list
    plot_ocean_transects(fc=fc, ds=ds, ds_mesh=ds_mesh,
                         variable_list=variable_list, cmap=args.colormap,
                         flip=args.flip)
def _compute_transects(fc, ds_mesh, flip):
    """
    build a sequence of triangles showing the transect intersecting mpas cells
    """
    ds_tris = mesh_to_triangles(ds_mesh)
    triangle_tree = make_triangle_tree(ds_tris)
    transects = dict()
    for transect in fc.features:
        transect_name = transect['properties']['name']
        assert transect['geometry']['type'] == 'LineString'
        coordinates = transect['geometry']['coordinates']
        transect_lon, transect_lat = zip(*coordinates)
        transect_lon = np.array(transect_lon)
        transect_lat = np.array(transect_lat)
        if flip:
            transect_lon = transect_lon[::-1]
            transect_lat = transect_lat[::-1]
        transect_lon = xr.DataArray(data=transect_lon,
                                    dims=('nPoints',))
        transect_lat = xr.DataArray(data=transect_lat,
                                    dims=('nPoints',))
        ds_mpas_transect = find_transect_cells_and_weights(
            transect_lon, transect_lat, ds_tris, ds_mesh,
            triangle_tree, degrees=True)
        ds_mpas_transect = find_transect_levels_and_weights(
            ds_mpas_transect, ds_mesh.layerThickness,
            ds_mesh.bottomDepth, ds_mesh.maxLevelCell - 1)
        if 'landIceFraction' in ds_mesh:
            interp_cell_indices = ds_mpas_transect.interpHorizCellIndices
            interp_cell_weights = ds_mpas_transect.interpHorizCellWeights
            land_ice_fraction = ds_mesh.landIceFraction.isel(
                nCells=interp_cell_indices)
            land_ice_fraction = (land_ice_fraction * interp_cell_weights).sum(
                dim='nHorizWeights')
            ds_mpas_transect['landIceFraction'] = land_ice_fraction
        ds_mpas_transect['x'] = ds_mpas_transect.dNode.isel(
            nSegments=ds_mpas_transect.segmentIndices,
            nHorizBounds=ds_mpas_transect.nodeHorizBoundsIndices)
        ds_mpas_transect['z'] = ds_mpas_transect.zTransectNode
        ds_mpas_transect.compute()
        transects[transect_name] = ds_mpas_transect
    return transects
def _plot_transect(ds_transect, mpas_field, var_name, transect_name, cmap, fc):
    """
    plot a transect showing the field on the MPAS-Ocean mesh and save to a file
    """
    transect_prefix = transect_name.replace(' ', '_')
    transect_field = interp_mpas_to_transect_triangles(ds_transect,
                                                       mpas_field)
    units = None
    if 'units' in mpas_field.attrs:
        units = mpas_field.attrs['units']
    x_outline, z_outline = get_outline_segments(ds_transect)
    x_outline = 1e-3 * x_outline
    colormaps = dict(
        temperature='cmo.thermal',
        salinity='cmo.haline',
        density='cmo.dense',
    )
    if cmap is None:
        for contains, map_name in colormaps.items():
            if contains in var_name.lower():
                cmap = map_name
    tri_mask = np.logical_not(transect_field.notnull().values)
    # if any node of a triangle is masked, the triangle is masked
    # tri_mask = np.amax(tri_mask, axis=1)
    triangulation_args = _get_ds_triangulation_args(ds_transect)
    triangulation_args['mask'] = tri_mask
    tris = Triangulation(**triangulation_args)
    fig = plt.figure(figsize=(12, 6))
    ax = plt.gca()
    plt.tripcolor(tris, facecolors=transect_field.values, shading='flat',
                  cmap=cmap)
    plt.plot(x_outline, z_outline, 'k')
    if units is not None:
        colorbar_label = f'{var_name} ({units})'
    else:
        colorbar_label = f'{var_name}'
    plt.colorbar(label=colorbar_label)
    plt.title(f'{var_name} through {transect_name}')
    plt.xlabel('x (km)')
    plt.ylabel('z (m)')
    # make a red start axis and green end axis to correspond to the dots
    # in the inset
    ax.spines['left'].set_color('red')
    ax.spines['right'].set_color('green')
    ax.spines['left'].set_linewidth(4)
    ax.spines['right'].set_linewidth(4)
    plt.tight_layout(pad=0.5, h_pad=0.5, rect=[0.0, 0.0, 1.0, 1.0])
    add_inset(fig, fc)
    plt.savefig(f'{transect_prefix}_{var_name}.png', dpi=200)
    plt.close()
def _get_ds_triangulation_args(ds_transect):
    """
    get arguments for matplotlib Triangulation from triangulation dataset
    """
    n_transect_triangles = ds_transect.sizes['nTransectTriangles']
    d_node = ds_transect.dNode.isel(
        nSegments=ds_transect.segmentIndices,
        nHorizBounds=ds_transect.nodeHorizBoundsIndices)
    x = 1e-3 * d_node.values.ravel()
    z_transect_node = ds_transect.zTransectNode
    y = z_transect_node.values.ravel()
    tris = np.arange(3 * n_transect_triangles).reshape(
        (n_transect_triangles, 3))
    triangulation_args = dict(x=x, y=y, triangles=tris)
    return triangulation_args
if __name__ == '__main__':
    main()