import argparse
import cmocean # noqa: F401
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from geometric_features import FeatureCollection, read_feature_collection
from mpas_tools.ocean.viz.inset import add_inset
from mpas_tools.ocean.viz.transect.vert import (
compute_transect,
interp_mpas_to_transect_cells,
interp_mpas_to_transect_nodes,
)
from mpas_tools.viz.colormaps import register_sci_viz_colormaps
# [docs]
def plot_transect(
    ds_transect,
    mpas_field=None,
    out_filename=None,
    ax=None,
    title=None,
    vmin=None,
    vmax=None,
    colorbar_label=None,
    cmap=None,
    figsize=(12, 6),
    dpi=200,
    method='flat',
    outline_color='black',
    ssh_color=None,
    seafloor_color=None,
    interface_color=None,
    cell_boundary_color=None,
    linewidth=1.0,
    color_start_and_end=False,
    start_color='red',
    end_color='green',
):
    """
    Plot a transect showing the field on the MPAS-Ocean mesh, either into
    existing axes or into a new figure that is saved to a file

    Parameters
    ----------
    ds_transect : xarray.Dataset
        A transect dataset from
        :py:func:`mpas_tools.ocean.viz.transect.vert.compute_transect()`

    mpas_field : xarray.DataArray, optional
        The MPAS-Ocean 3D field to plot.  If ``None``, only the outline,
        interfaces, etc. are drawn

    out_filename : str, optional
        The png file to write out to.  Only used when ``ax`` is not
        supplied (a new figure is then created, saved and closed)

    ax : matplotlib.axes.Axes, optional
        Axes to plot to if making a multi-panel figure.  When supplied,
        nothing is saved or closed by this function

    title : str, optional
        The title of the plot

    vmin : float, optional
        The minimum values for the colorbar

    vmax : float, optional
        The maximum values for the colorbar

    colorbar_label : str, optional
        The colorbar label, or ``None`` if no colorbar is to be included.
        Use an empty string to display a colorbar without a label.

    cmap : str, optional
        The name of a colormap to use

    figsize : tuple, optional
        The size of the figure in inches

    dpi : int, optional
        The dots per inch of the image

    method : {'flat', 'bilinear'}, optional
        The type of interpolation to use in plots.  ``flat`` means constant
        values over each MPAS cell.  ``bilinear`` means smooth interpolation
        horizontally between cell centers and vertically between the
        middle of layers.

    outline_color : str or None, optional
        The color to use to outline the transect or ``None`` for no outline

    ssh_color : str or None, optional
        The color to use to plot the SSH (sea surface height) or ``None`` if
        not plotting the SSH (except perhaps as part of the outline)

    seafloor_color : str or None, optional
        The color to use to plot the seafloor depth or ``None`` if not plotting
        the seafloor depth (except perhaps as part of the outline)

    interface_color : str or None, optional
        The color to use to plot interfaces between layers or ``None`` if
        not plotting the layer interfaces

    cell_boundary_color : str or None, optional
        The color to use to plot vertical boundaries between cells or ``None``
        if not plotting cell boundaries.  Typically, ``cell_boundary_color``
        will be used along with ``interface_color`` to outline cells both
        horizontally and vertically.

    linewidth : float, optional
        The width of outlines, interfaces and cell boundaries

    color_start_and_end : bool, optional
        Whether to color the left and right axes of the transect, which is
        useful if the transect is also being plotted in an inset or on top of
        a horizontal field

    start_color : str, optional
        The color of left axis marking the start of the transect if
        ``color_start_and_end == True``

    end_color : str, optional
        The color of right axis marking the end of the transect if
        ``color_start_and_end == True``

    Raises
    ------
    ValueError
        If neither ``ax`` nor ``out_filename`` is supplied, or if
        ``mpas_field`` is given and ``method`` is not ``'flat'`` or
        ``'bilinear'``
    """
    if ax is None and out_filename is None:
        raise ValueError('One of ax or out_filename must be supplied')

    # validate before creating a figure so a bad argument cannot leave an
    # orphaned figure behind
    if mpas_field is not None and method not in ('flat', 'bilinear'):
        raise ValueError(f'Unsupported method: {method}')

    create_fig = ax is None
    if create_fig:
        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(111)

    z = ds_transect.zTransectNode
    # transect distance is converted from m to km for plotting
    x = 1e-3 * ds_transect.dNode.broadcast_like(z)

    if mpas_field is not None:
        if method == 'flat':
            transect_field = interp_mpas_to_transect_cells(
                ds_transect, mpas_field
            )
            shading = 'flat'
        else:
            transect_field = interp_mpas_to_transect_nodes(
                ds_transect, mpas_field
            )
            shading = 'gouraud'

        pc = ax.pcolormesh(
            x.values,
            z.values,
            transect_field.values,
            shading=shading,
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
            zorder=0,
        )
        ax.autoscale(tight=True)
        if colorbar_label is not None:
            # attach the colorbar to the figure that owns ``ax`` rather than
            # whatever pyplot's "current" figure happens to be
            ax.figure.colorbar(
                pc, extend='both', shrink=0.7, ax=ax, label=colorbar_label
            )

    _plot_interfaces(
        ds_transect,
        ax,
        interface_color,
        cell_boundary_color,
        ssh_color,
        seafloor_color,
        color_start_and_end,
        start_color,
        end_color,
        linewidth,
    )

    _plot_outline(x, z, ds_transect.validCells, ax, outline_color, linewidth)

    ax.set_xlabel('transect distance (km)')
    ax.set_ylabel('z (m)')

    # the title applies whether or not the axes were supplied by the caller
    if title is not None:
        ax.set_title(title)

    if create_fig:
        fig.savefig(out_filename, dpi=dpi, bbox_inches='tight', pad_inches=0.2)
        plt.close(fig)
# [docs]
def plot_feature_transects(
    fc,
    ds,
    ds_mesh=None,
    variable_list=None,
    cmap=None,
    flip=False,
    write_netcdf=False,
    method='flat',
    add_z=False,
):
    """
    Plot images of the given variables on the given transects.  One image
    named ``<transect_name>_<variable_name>.png`` will be produced in the
    current directory for each transect and variable

    Parameters
    ----------
    fc : geometric_features.FeatureCollection
        The transects to plot

    ds : xarray.Dataset
        The MPAS-Ocean dataset to plot

    ds_mesh : xarray.Dataset, optional
        The MPAS-Ocean mesh to use for plotting, the same as ``ds`` by default

    variable_list : list of str, optional
        The variables to plot.  By default, every data variable in ``ds``
        with an ``nCells`` dimension and either an ``nVertLevels`` or an
        ``nVertLevelsP1`` dimension

    cmap : str, optional
        The name of a colormap to use

    flip : bool, optional
        Whether to flip the x axes of all transect plots

    write_netcdf : bool, optional
        Whether to write a NetCDF file for the transect in addition to the
        image

    method : {'flat', 'bilinear'}, optional
        The type of interpolation to use in plots.  ``flat`` means constant
        values over each MPAS cell.  ``bilinear`` means smooth interpolation
        horizontally between cell centers and vertically between the
        middle of layers.

    add_z : bool, optional
        Whether to add zMid and zInterface to the mesh dataset

    Raises
    ------
    ValueError
        If a requested variable is found in neither ``ds`` nor ``ds_mesh``
    """
    # honor the documented default: the data file doubles as the mesh
    if ds_mesh is None:
        ds_mesh = ds

    # only the first time slice is plotted
    if 'Time' in ds.dims:
        ds = ds.isel(Time=0)

    if 'Time' in ds_mesh.dims:
        ds_mesh = ds_mesh.isel(Time=0)

    if add_z:
        _add_z(ds_mesh)

    print('\nBuilding transect geometry...')
    transects = _compute_feature_transects(fc, ds_mesh, flip)

    # one single-feature collection per transect, used for the inset map
    fc_transects = {
        transect['properties']['name']: FeatureCollection(features=[transect])
        for transect in fc.features
    }

    register_sci_viz_colormaps()

    if variable_list is None:
        variable_list = [
            var_name
            for var_name in ds.data_vars
            if 'nCells' in ds[var_name].dims
            and (
                'nVertLevels' in ds[var_name].dims
                or 'nVertLevelsP1' in ds[var_name].dims
            )
        ]

    print('\nPlotting...')
    for var_name in variable_list:
        # the main dataset takes precedence over the mesh dataset
        if var_name in ds:
            var = ds[var_name]
        elif var_name in ds_mesh:
            var = ds_mesh[var_name]
        else:
            raise ValueError(
                f'{var_name} not found in either the main or the '
                f'mesh dataset (if any)'
            )

        assert 'nCells' in var.dims and (
            'nVertLevels' in var.dims or 'nVertLevelsP1' in var.dims
        )

        for transect_name, ds_transect in transects.items():
            print(f' {transect_name} {var_name}')
            _plot_feature_transect(
                ds_transect,
                var,
                var_name,
                transect_name,
                cmap,
                fc_transects[transect_name],
                write_netcdf,
                method,
            )
def plot_feature_transects_main():
    """
    Command-line entry point: parse arguments, open the geojson, data and
    (optional) mesh files, and call :py:func:`plot_feature_transects()`
    """
    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument(
        '-g',
        '--geojson',
        dest='geojson_filename',
        required=True,
        help='A geojson file with transects to plot',
    )
    parser.add_argument(
        '-m',
        '--mesh',
        dest='mesh_filename',
        help='An MPAS-Ocean mesh file. If not specified, the '
        'MPAS-Ocean data file must contain the mesh.',
    )
    parser.add_argument(
        '-f',
        '--file',
        dest='filename',
        required=True,
        help='An MPAS-Ocean data file',
    )
    parser.add_argument(
        '-v',
        '--variable_list',
        dest='variable_list',
        nargs='+',
        help='List of variables to plot. All variables on '
        'cells in the data file is the default.',
    )
    parser.add_argument(
        '-c',
        '--colormap',
        dest='colormap',
        help='A colormap to use for the plots, default '
        'depends on the field name.',
    )
    parser.add_argument(
        '--flip',
        dest='flip',
        action='store_true',
        help='Flip the x axis for all transects',
    )
    parser.add_argument(
        '--write_netcdf',
        dest='write_netcdf',
        action='store_true',
        help='Whether to write a NetCDF file for the transect '
        'in addition to the image',
    )
    # ``choices`` rejects a misspelled method at parse time instead of after
    # the (expensive) transect geometry has been built
    parser.add_argument(
        '--method',
        dest='method',
        default='flat',
        choices=('flat', 'bilinear'),
        help='The type of interpolation to use in plots. '
        'Options are "flat" and "bilinear"',
    )
    parser.add_argument(
        '--add_z',
        dest='add_z',
        action='store_true',
        help='Whether to add zMid and zInterface to the mesh',
    )
    args = parser.parse_args()

    fc = read_feature_collection(args.geojson_filename)

    ds = xr.open_dataset(args.filename)
    # without a separate mesh file, the data file must contain the mesh
    ds_mesh = ds
    try:
        if args.mesh_filename is not None:
            ds_mesh = xr.open_dataset(args.mesh_filename)

        # plot_feature_transects() itself selects the first time slice of
        # both datasets, so the originally opened objects are passed on and
        # kept here so they can be closed afterwards
        plot_feature_transects(
            fc=fc,
            ds=ds,
            ds_mesh=ds_mesh,
            variable_list=args.variable_list,
            cmap=args.colormap,
            flip=args.flip,
            write_netcdf=args.write_netcdf,
            method=args.method,
            add_z=args.add_z,
        )
    finally:
        if ds_mesh is not ds:
            ds_mesh.close()
        ds.close()
def _plot_interfaces(
    ds_transect,
    ax,
    interface_color,
    cell_boundary_color,
    ssh_color,
    seafloor_color,
    color_start_and_end,
    start_color,
    end_color,
    linewidth,
):
    """
    Draw the optional line work of a transect plot on ``ax``: cell
    boundaries, layer interfaces, the sea surface, the seafloor and colored
    left/right spines.  Each element is skipped when its color is ``None``
    (or, for the spines, when ``color_start_and_end`` is false).
    """
    m_to_km = 1e-3

    def _draw_line(x_km, z_m, color, zorder):
        # all lines share the same width and differ only in color/stacking
        ax.plot(x_km, z_m, color=color, linewidth=linewidth, zorder=zorder)

    def _masked_along_transect(field):
        # keep values only at nodes with at least one valid vertical node
        has_valid = ds_transect.validNodes.any(dim='nVertNodes')
        x_km = m_to_km * ds_transect.dNode.values
        return x_km, field().where(has_valid).values

    if cell_boundary_color is not None:
        boundary_x = m_to_km * ds_transect.dCellBoundary.values.T
        boundary_z = ds_transect.zCellBoundary.values.T
        _draw_line(boundary_x, boundary_z, cell_boundary_color, 1)

    if interface_color is not None:
        segment_x = m_to_km * ds_transect.dInterfaceSegment.values.T
        segment_z = ds_transect.zInterfaceSegment.values.T
        _draw_line(segment_x, segment_z, interface_color, 2)

    if ssh_color is not None:
        surface_x, surface_z = _masked_along_transect(
            lambda: ds_transect.ssh
        )
        _draw_line(surface_x, surface_z, ssh_color, 4)

    if seafloor_color is not None:
        floor_x, floor_z = _masked_along_transect(
            lambda: ds_transect.zSeafloor
        )
        _draw_line(floor_x, floor_z, seafloor_color, 5)

    if color_start_and_end:
        # left spine marks the start of the transect, right spine the end
        for side, spine_color in (('left', start_color), ('right', end_color)):
            ax.spines[side].set_color(spine_color)
            ax.spines[side].set_linewidth(4 * linewidth)
def _plot_outline(
    x, z, valid_cells, ax, outline_color, linewidth, epsilon=1e-6
):
    """
    Contour the boundary between valid and invalid transect cells on ``ax``.
    Does nothing when ``outline_color`` is ``None``.
    """
    if outline_color is None:
        return

    # Every node is duplicated 2 x 2 so that each copy can carry the validity
    # of a different adjacent cell; the outermost ring of ``valid`` stays
    # zero, acting as a buffer of invalid values around the domain.
    n_rows, n_cols = valid_cells.shape[0], valid_cells.shape[1]
    buffered_shape = (2 * n_rows + 2, 2 * n_cols + 2)
    valid = np.zeros(buffered_shape, dtype=float)
    z_buf = np.zeros(buffered_shape, dtype=float)
    x_buf = np.zeros(buffered_shape, dtype=float)

    cell_validity = valid_cells.astype(float)
    z_nodes = z.values
    x_nodes = x.values

    # view of ``valid`` without the one-entry buffer ring
    interior = valid[1:-1, 1:-1]
    for row_offset in (0, 1):
        for col_offset in (0, 1):
            # each interior node copy gets the value from its cell
            interior[row_offset::2, col_offset::2] = cell_validity
            z_buf[row_offset::2, col_offset::2] = z_nodes
            x_buf[row_offset::2, col_offset::2] = x_nodes

    ax.contour(
        x_buf,
        z_buf,
        valid,
        levels=[1.0 - epsilon],
        colors=outline_color,
        linewidths=linewidth,
        zorder=3,
    )
def _compute_feature_transects(fc, ds_mesh, flip):
    """
    Compute a transect dataset through the MPAS mesh for each ``LineString``
    feature in ``fc``

    Parameters
    ----------
    fc : geometric_features.FeatureCollection
        The transect features; each must have a ``LineString`` geometry

    ds_mesh : xarray.Dataset
        The mesh dataset, which must provide ``layerThickness``,
        ``bottomDepth``, ``maxLevelCell`` and the ``on_a_sphere`` attribute
        (``minLevelCell`` is optional)

    flip : bool
        Whether to reverse the order of the points in each transect

    Returns
    -------
    transects : dict
        The result of ``compute_transect()`` for each feature, keyed by the
        feature's ``name`` property
    """
    transects = dict()

    layer_thickness = ds_mesh.layerThickness
    bottom_depth = ds_mesh.bottomDepth
    # level indices in the mesh are one-based; shift to zero-based
    max_level_cell = ds_mesh.maxLevelCell - 1
    if 'minLevelCell' in ds_mesh:
        min_level_cell = ds_mesh.minLevelCell - 1
    else:
        min_level_cell = xr.zeros_like(max_level_cell)

    spherical = ds_mesh.attrs['on_a_sphere'] == 'YES'

    for transect in fc.features:
        transect_name = transect['properties']['name']
        print(f' {transect_name}')
        assert transect['geometry']['type'] == 'LineString'

        coordinates = transect['geometry']['coordinates']
        transect_lon, transect_lat = zip(*coordinates)
        transect_lon = np.array(transect_lon)
        transect_lat = np.array(transect_lat)
        if flip:
            transect_lon = transect_lon[::-1]
            transect_lat = transect_lat[::-1]
        transect_lon = xr.DataArray(data=transect_lon, dims=('nPoints',))
        transect_lat = xr.DataArray(data=transect_lat, dims=('nPoints',))

        ds_mpas_transect = compute_transect(
            x=transect_lon,
            y=transect_lat,
            ds_horiz_mesh=ds_mesh,
            layer_thickness=layer_thickness,
            bottom_depth=bottom_depth,
            min_level_cell=min_level_cell,
            max_level_cell=max_level_cell,
            spherical=spherical,
        )
        # Dataset.compute() returns a new, loaded dataset and leaves the
        # original untouched, so the result must be kept; otherwise the work
        # is discarded and repeated for every variable that is plotted
        ds_mpas_transect = ds_mpas_transect.compute()
        transects[transect_name] = ds_mpas_transect

    return transects
def _plot_feature_transect(
    ds_transect,
    mpas_field,
    var_name,
    transect_name,
    cmap,
    fc,
    write_netcdf,
    method,
):
    """
    Plot one variable along one transect with an inset of the transect
    feature, save it as ``<transect>_<variable>.png`` (spaces in the transect
    name become underscores) and optionally write the transect dataset to a
    NetCDF file with the same prefix
    """
    file_prefix = f"{transect_name.replace(' ', '_')}_{var_name}"

    if cmap is None:
        # default colormaps keyed on substrings of the variable name; when
        # several substrings match, the last one listed wins
        default_cmaps = {
            'temperature': 'cmo.thermal',
            'salinity': 'cmo.haline',
            'density': 'cmo.dense',
        }
        matching = [
            cmap_name
            for substring, cmap_name in default_cmaps.items()
            if substring in var_name.lower()
        ]
        if matching:
            cmap = matching[-1]

    units = mpas_field.attrs.get('units')
    colorbar_label = (
        f'{var_name} ({units})' if units is not None else f'{var_name}'
    )

    fig = plt.figure(figsize=(12, 6))
    ax = plt.gca()

    plot_transect(
        ds_transect=ds_transect,
        mpas_field=mpas_field,
        ax=ax,
        title=f'{var_name} through {transect_name}',
        colorbar_label=colorbar_label,
        cmap=cmap,
        color_start_and_end=True,
        method=method,
    )

    plt.tight_layout(pad=0.5, h_pad=0.5, rect=[0.0, 0.0, 1.0, 1.0])
    add_inset(fig, fc)

    plt.savefig(f'{file_prefix}.png', dpi=200)
    plt.close()

    if write_netcdf:
        ds_transect.to_netcdf(f'{file_prefix}.nc')
def _add_z(ds_mesh):
    """
    Compute ``zMid`` and ``zInterface`` from ``layerThickness``,
    ``bottomDepth`` and the min/max level indices of ``ds_mesh`` and attach
    them to ``ds_mesh`` as coordinates (in place); useful for debugging
    """
    thickness = ds_mesh.layerThickness
    depth = ds_mesh.bottomDepth
    # shift the one-based level indices from the mesh to zero-based
    last_level = ds_mesh.maxLevelCell - 1
    if 'minLevelCell' in ds_mesh:
        first_level = ds_mesh.minLevelCell - 1
    else:
        first_level = xr.zeros_like(last_level)

    n_levels = thickness.sizes['nVertLevels']

    # True for layers between the first and last level of each cell
    level_index = xr.DataArray(
        data=np.arange(n_levels), dims=('nVertLevels',)
    )
    in_column = np.logical_and(
        level_index >= first_level, level_index <= last_level
    )

    thickness = thickness.where(in_column)
    column_thickness = thickness.sum(dim='nVertLevels')
    running_thickness = thickness.cumsum(dim='nVertLevels')

    # the surface sits one full column thickness above the seafloor
    z_top = -depth + column_thickness
    z_bottom = z_top - running_thickness

    # interfaces are the surface followed by the bottom of every layer
    interfaces = [z_top] + [
        z_bottom.isel(nVertLevels=level) for level in range(n_levels)
    ]
    z_interface = xr.concat(interfaces, dim='nVertLevelsP1')

    interface_index = xr.DataArray(
        data=np.arange(n_levels + 1), dims=('nVertLevelsP1',)
    )
    on_column = np.logical_and(
        interface_index >= first_level, interface_index <= last_level + 1
    )
    z_interface = z_interface.where(on_column).transpose(
        'nCells', 'nVertLevelsP1'
    )

    # layer middles are half a layer thickness above each layer bottom
    z_mid = (z_bottom + 0.5 * thickness).where(in_column)
    z_mid = z_mid.transpose('nCells', 'nVertLevels')

    ds_mesh.coords['zMid'] = z_mid
    ds_mesh.coords['zInterface'] = z_interface