import logging
import os
import sys
import networkx as nx
import numpy as np
import scipy.sparse
import scipy.sparse.linalg
import xarray as xr
from mpas_tools.io import write_netcdf
from mpas_tools.ocean.streamfunction.velocity import (
compute_vertically_integrated_velocity,
)
from mpas_tools.ocean.streamfunction.vorticity import (
compute_vertically_integrated_vorticity,
)
def compute_barotropic_streamfunction(
    ds_mesh,
    ds,
    logger=None,
    min_depth=None,
    max_depth=None,
    prefix='timeMonthly_avg_',
    time_index=None,
    include_bolus=False,
    include_submesoscale=False,
    quiet=False,
    horiz_chunk=10000,
    tmp_dir=None,
):
    """
    Compute barotropic streamfunction

    Parameters
    ----------
    ds_mesh : xarray.Dataset
        A dataset containing MPAS mesh variables

    ds : xarray.Dataset
        A dataset containing MPAS output variables ``normalVelocity`` and
        ``layerThickness`` (possibly with a ``prefix``)

    logger : logging.Logger, optional
        A logger for the output if not stdout

    min_depth : float, optional
        The minimum depth (positive up) to compute BSF over

    max_depth : float, optional
        The maximum depth (positive up) to compute BSF over

    prefix : str, optional
        The prefix on the ``normalVelocity`` and ``layerThickness`` variables

    time_index : int, optional
        The time at which to index ``ds`` (if it has ``Time`` as a dimension)

    include_bolus : bool, optional
        Whether to include the GM bolus velocity in the computation

    include_submesoscale : bool, optional
        Whether to include the submesoscale velocity in the computation

    quiet : bool, optional
        If True, suppress all logging output
        If False, log all output to the logger

    horiz_chunk : int, optional
        The number of edges to chunk the dataset by when computing
        the vertically integrated velocity. This is useful for
        large datasets to avoid memory issues. Set this to ``None`` to disable
        chunking.

    tmp_dir : str, optional
        A temporary directory to use for intermediate files. This is useful
        for large datasets to avoid memory issues. If None, no temporary
        directory is used.

    Returns
    -------
    bsf_vertex : xarray.DataArray
        The barotropic streamfunction in Sv on vertices
    """
    if quiet:
        # quiet takes precedence over any logger the caller supplied
        logger = None
    elif logger is None:
        # default to logging to stdout
        logger = logging.getLogger()
        logger.addHandler(logging.StreamHandler(sys.stdout))
        logger.setLevel(logging.INFO)

    if time_index is None:
        # without a time index, ds must already be time-independent
        if 'Time' in ds.dims:
            raise ValueError(
                'time_index must be provided if "Time" is a dimension of ds'
            )
    else:
        ds = ds.isel(Time=time_index)

    if logger:
        logger.info('Computing barotropic streamfunction.')
    bsf_vertex = _compute_barotropic_streamfunction_vertex(
        ds_mesh,
        ds,
        prefix,
        include_bolus,
        include_submesoscale,
        min_depth,
        max_depth,
        logger,
        horiz_chunk,
        tmp_dir,
    )
    if logger:
        logger.info(' Done.')
    return bsf_vertex
def shift_barotropic_streamfunction(
    bsf_vertex, lat_range, cells_on_vertex, lat_vertex, logger=None
):
    """
    Shift the barotropic streamfunction to be zero on average at the boundary
    over the given latitude range

    Parameters
    ----------
    bsf_vertex : xarray.DataArray
        The barotropic streamfunction in Sv on vertices

    lat_range : list of float
        The latitude range in degrees over which to set the mean boundary value
        of the BSF to zero

    cells_on_vertex : xarray.DataArray
        The zero-based cell indices on each vertex

    lat_vertex : xarray.DataArray
        The latitude of each vertex in radians

    logger : logging.Logger, optional
        A logger for the output

    Returns
    -------
    bsf_shifted : xarray.DataArray
        The shifted barotropic streamfunction in Sv on vertices
    """
    # a vertex is on the boundary if any adjacent cell index is -1 (missing)
    is_boundary_cov = cells_on_vertex == -1
    # boundary vertices whose latitude (in radians) falls within lat_range
    # (given in degrees, hence the conversion)
    boundary_vertices = np.logical_and(
        is_boundary_cov.sum(dim='vertexDegree') > 0,
        np.logical_and(
            lat_vertex >= np.deg2rad(lat_range[0]),
            lat_vertex <= np.deg2rad(lat_range[1]),
        ),
    )
    # convert from boolean mask to indices
    boundary_vertices = np.flatnonzero(boundary_vertices.values)
    mean_boundary_bsf = bsf_vertex.isel(nVertices=boundary_vertices).mean()
    if logger:
        logger.info(
            f' Mean BSF on boundary vertices in range {lat_range} '
            f'is {mean_boundary_bsf.values:.4f} Sv'
        )
    # subtract the mean boundary value so the boundary averages to zero
    bsf_shifted = bsf_vertex - mean_boundary_bsf
    return bsf_shifted
def _build_minimal_boundary_constraints(
    boundary_vertex0, boundary_vertex1, logger
):
    """
    Construct a minimal set of boundary constraints that tie each connected
    loop together without introducing redundancy.

    Parameters
    ----------
    boundary_vertex0 : xarray.DataArray
        The first vertex in each unique pair of boundary vertices.

    boundary_vertex1 : xarray.DataArray
        The second vertex in each unique pair of boundary vertices.

    logger : logging.Logger, optional
        Logger for logging messages.

    Returns
    -------
    minimal_constraints : list of tuple
        A minimal set of edges (vertex pairs) that constrain the boundary.
    """
    # build an undirected graph whose edges are the boundary vertex pairs
    loop_graph = nx.Graph()
    loop_graph.add_edges_from(
        zip(boundary_vertex0.values, boundary_vertex1.values)
    )

    # each connected component is an independent boundary loop
    components = list(nx.connected_components(loop_graph))
    if logger:
        logger.info(
            f' Found {len(components)} independent boundary '
            f'loops.'
        )

    # a spanning tree of each loop constrains all its vertices together
    # while leaving out one edge per cycle, avoiding redundant constraints
    minimal_constraints = []
    for component in components:
        tree = nx.minimum_spanning_tree(loop_graph.subgraph(component))
        minimal_constraints.extend(tree.edges)
    return minimal_constraints
def _identify_boundary_vertices(ds_mesh, logger, all_vertices):
    """
    Identify boundary vertices and edges in the mesh, and remove redundant
    vertices that lead to overdetermined constraints.

    Parameters
    ----------
    ds_mesh : xarray.Dataset
        A dataset containing MPAS mesh variables (``cellsOnVertex``,
        ``cellsOnEdge``, ``verticesOnEdge``)

    logger : logging.Logger, optional
        Logger for logging messages

    all_vertices : xarray.DataArray
        The zero-based index of every vertex, with dimension ``nVertices``

    Returns
    -------
    boundary_vertex0 : xarray.DataArray
        The first vertex in each retained boundary pair

    boundary_vertex1 : xarray.DataArray
        The second vertex in each retained boundary pair

    nboundary : int
        The number of retained boundary constraints (pairs)
    """
    if logger:
        logger.info(' Identifying boundary vertices.')
    # Identify boundary vertices
    nedges = ds_mesh.sizes['nEdges']
    # convert 1-based MPAS connectivity to 0-based; missing neighbors become -1
    cells_on_vertex = ds_mesh.cellsOnVertex - 1
    cells_on_edge = ds_mesh.cellsOnEdge - 1
    vertices_on_edge = ds_mesh.verticesOnEdge - 1
    # a vertex is on the boundary if any adjacent cell is missing (-1)
    boundary_mask = (cells_on_vertex == -1).any(dim='vertexDegree')
    boundary_vertices = all_vertices.where(boundary_mask, drop=True).astype(
        int
    )
    # an edge is on the boundary if either adjacent cell is missing (-1)
    boundary_edge_mask = (cells_on_edge == -1).any(dim='TWO')
    all_edges = xr.DataArray(np.arange(nedges, dtype=int), dims=('nEdges',))
    boundary_edges = all_edges.where(boundary_edge_mask, drop=True).astype(int)
    # the vertex pair at the two ends of each boundary edge
    boundary_vertex0 = vertices_on_edge.isel(nEdges=boundary_edges, TWO=0)
    boundary_vertex1 = vertices_on_edge.isel(nEdges=boundary_edges, TWO=1)
    if logger:
        logger.info(' Detect boundary loops and remove redundant pairs.')
    # find each independent loop of boundary vertices and determine a set of
    # pairs that does not close the loop (i.e. a spanning tree), avoiding
    # overdetermined constraints
    minimal_constraints = _build_minimal_boundary_constraints(
        boundary_vertex0, boundary_vertex1, logger
    )
    # Unpack minimal_constraints into boundary_vertex0 and boundary_vertex1
    boundary_vertex0, boundary_vertex1 = zip(*minimal_constraints)
    boundary_vertex0 = xr.DataArray(
        np.array(boundary_vertex0), dims=('nVertices',)
    )
    boundary_vertex1 = xr.DataArray(
        np.array(boundary_vertex1), dims=('nVertices',)
    )
    nboundary = boundary_vertex0.sizes['nVertices']
    if logger:
        # difference between the boundary vertex count and the retained
        # constraint count (one constraint is dropped per closed loop)
        nboundary_removed = boundary_vertices.sizes['nVertices'] - nboundary
        logger.info(
            f' Removed {nboundary_removed} redundant boundary vertices.'
        )
    return boundary_vertex0, boundary_vertex1, nboundary
def _assemble_matrix(
    ds_mesh, logger, edge_sign_on_vertex, vert_integ_vorticity
):
    """
    Assemble the sparse matrix for the Poisson equation.

    Parameters
    ----------
    ds_mesh : xarray.Dataset
        A dataset containing MPAS mesh variables

    logger : logging.Logger, optional
        Logger for logging messages

    edge_sign_on_vertex : xarray.DataArray
        The sign of each edge as seen from each vertex, with dimensions
        ``nVertices`` and ``vertexDegree``

    vert_integ_vorticity : xarray.DataArray
        The vertically integrated vorticity on vertices

    Returns
    -------
    indices : numpy.ndarray
        A 2 x ndata array of (row, column) indices of the matrix entries

    data : numpy.ndarray
        The matrix entry at each index pair

    rhs : numpy.ndarray
        The right-hand side of the linear system

    nmatrix : int
        The size of the (square) matrix
    """
    var_list = [
        'edgesOnVertex',
        'verticesOnEdge',
        'areaTriangle',
        'dcEdge',
        'dvEdge',
        'cellsOnVertex',
        'cellsOnEdge',
    ]
    # load only the needed mesh variables eagerly as numpy-backed arrays
    ds_mesh = ds_mesh[var_list].as_numpy()
    edge_sign_on_vertex = edge_sign_on_vertex.as_numpy()
    vert_integ_vorticity = vert_integ_vorticity.as_numpy()
    nvertices = ds_mesh.sizes['nVertices']
    vertex_degree = ds_mesh.sizes['vertexDegree']
    # convert 1-based MPAS connectivity to 0-based indices
    edges_on_vertex = ds_mesh.edgesOnVertex - 1
    vertices_on_edge = ds_mesh.verticesOnEdge - 1
    area_vertex = ds_mesh.areaTriangle
    dc_edge = ds_mesh.dcEdge
    dv_edge = ds_mesh.dvEdge
    all_vertices = xr.DataArray(
        np.arange(nvertices, dtype=int), dims=('nVertices',)
    )
    boundary_vertex0, boundary_vertex1, nboundary = (
        _identify_boundary_vertices(ds_mesh, logger, all_vertices)
    )
    if logger:
        logger.info(' Assembling sparse matrix for the Poisson equation.')
    # Assemble the sparse matrix for solving the Poisson equation:
    # * the Poisson equation at each vertex involves vertex degree + 1 terms
    # * the boundary conditions involve 2 vertices and are duplicated in
    #   the form of a Lagrange multiplier
    # * the unique solution is ensured by adding a constraint on the
    #   streamfunction at the first vertex, again as a Lagrange multiplier
    ndata = (vertex_degree + 1) * nvertices + 4 * nboundary + 2
    nmatrix = nvertices + nboundary + 1
    indices = np.zeros((2, ndata), dtype=int)
    data = np.zeros(ndata, dtype=float)
    # Fill the Poisson equation for the BSF at each vertex will be equal to the
    # vertically integrated vorticity
    # each vertex owns a contiguous slot of vertex_degree + 1 entries; the
    # first entry is the diagonal term for the vertex itself
    idata = (vertex_degree + 1) * all_vertices.values
    rows = all_vertices.values
    indices[0, idata] = rows
    indices[1, idata] = all_vertices.values
    for iedge in range(vertex_degree):
        eov = edges_on_vertex.isel(vertexDegree=iedge)
        dc = dc_edge.isel(nEdges=eov)
        dv = dv_edge.isel(nEdges=eov)
        v0 = vertices_on_edge.isel(nEdges=eov, TWO=0)
        v1 = vertices_on_edge.isel(nEdges=eov, TWO=1)
        edge_sign = edge_sign_on_vertex.isel(vertexDegree=iedge)
        # pick out the vertex at the other end of this edge; the sign flips
        # depending on whether the current vertex is the edge's first vertex
        mask = v0 == all_vertices
        this_vert_sign = xr.where(mask, -1.0, 1.0)
        other_vert_index = xr.where(mask, v1, v0)
        # clamp invalid (-1) neighbors to 0; presumably their coefficients
        # are zeroed by edge_sign in that case -- TODO confirm
        other_vert_index = xr.where(other_vert_index >= 0, other_vert_index, 0)
        idata_other = idata + iedge + 1
        indices[0, idata_other] = rows
        indices[1, idata_other] = other_vert_index.values
        this_data = this_vert_sign * edge_sign * dc / (dv * area_vertex)
        # diagonal accumulates over edges; off-diagonal gets the negated term
        data[idata] += this_data.values
        data[idata_other] = -this_data.values
    # Add boundary conditions to the matrix
    # The difference in the BSF between adjacent boundary vertices is
    # zero
    idata = (vertex_degree + 1) * nvertices + 2 * np.arange(nboundary)
    rows = nvertices + np.arange(nboundary)
    indices[0, idata] = rows
    indices[1, idata] = boundary_vertex0.values
    data[idata] = -1.0
    idata += 1
    indices[0, idata] = rows
    indices[1, idata] = boundary_vertex1.values
    data[idata] = 1.0
    # Now the transpose
    idata = (
        (vertex_degree + 1) * nvertices
        + 2 * nboundary
        + 2 * np.arange(nboundary)
    )
    col = nvertices + np.arange(nboundary)
    indices[0, idata] = boundary_vertex0.values
    indices[1, idata] = col
    data[idata] = -1.0
    idata += 1
    indices[0, idata] = boundary_vertex1.values
    indices[1, idata] = col
    data[idata] = 1.0
    # Add gauge constraints to ensure a unique solution
    # The BSF at vertex 0 will be zero
    idata = ndata - 2
    indices[0, idata] = nmatrix - 1
    indices[1, idata] = 0
    data[idata] = 1.0
    # And the transpose
    idata = ndata - 1
    indices[0, idata] = 0
    indices[1, idata] = nmatrix - 1
    data[idata] = 1.0
    # Assemble the right-hand side of the equation:
    # * the vertically integrated vorticity at each vertex
    # * the boundary conditions are zero
    # * the gauge constraints are zero
    rhs = np.zeros(nmatrix, dtype=float)
    rhs[0:nvertices] = vert_integ_vorticity.values
    return indices, data, rhs, nmatrix
def _compute_barotropic_streamfunction_vertex(
    ds_mesh,
    ds,
    prefix,
    include_bolus,
    include_submesoscale,
    min_depth,
    max_depth,
    logger,
    horiz_chunk,
    tmp_dir,
):
    """
    Compute the barotropic streamfunction on vertices.

    This function solves a Poisson equation to compute the barotropic
    streamfunction, which integrates vertically integrated velocity
    divergence to obtain the streamfunction.
    """
    if logger:
        logger.info(' Computing vertically integrated velocity.')

    # vertically integrate the normal velocity on edges over the depth range
    vert_integ_velocity = compute_vertically_integrated_velocity(
        ds_mesh=ds_mesh,
        ds=ds,
        logger=logger,
        min_depth=min_depth,
        max_depth=max_depth,
        prefix=prefix,
        include_bolus=include_bolus,
        include_submesoscale=include_submesoscale,
        nedges_chunk=horiz_chunk,
    )

    if tmp_dir is not None:
        if logger:
            logger.info(
                ' Writing out and reading in vertically integrated velocity.'
            )
        # round-trip through a temporary file to release upstream memory
        vel_filename = os.path.join(tmp_dir, 'bsf_vert_integ_normal_vel.nc')
        write_netcdf(
            xr.Dataset({'vertIntegNormalVelocity': vert_integ_velocity}),
            vel_filename,
            logger=logger,
        )
        vert_integ_velocity = xr.open_dataset(vel_filename)[
            'vertIntegNormalVelocity'
        ]

    # curl of the vertically integrated velocity on vertices
    vert_integ_vorticity, edge_sign_on_vertex = (
        compute_vertically_integrated_vorticity(
            ds_mesh=ds_mesh,
            vert_integ_velocity=vert_integ_velocity,
            logger=logger,
        )
    )

    if tmp_dir is not None:
        if logger:
            logger.info(
                ' Writing out and reading in vertically integrated vorticity.'
            )
        # round-trip the vorticity and edge signs through a temporary file
        vort_filename = os.path.join(tmp_dir, 'bsf_vert_integ_vorticity.nc')
        write_netcdf(
            xr.Dataset(
                {
                    'vertIntegVorticity': vert_integ_vorticity,
                    'edgeSignOnVertex': edge_sign_on_vertex,
                }
            ),
            vort_filename,
            logger=logger,
        )
        ds_vort = xr.open_dataset(vort_filename)
        vert_integ_vorticity = ds_vort['vertIntegVorticity']
        edge_sign_on_vertex = ds_vort['edgeSignOnVertex']

    indices, data, rhs, nmatrix = _assemble_matrix(
        ds_mesh,
        logger,
        edge_sign_on_vertex,
        vert_integ_vorticity,
    )

    if logger:
        logger.info(' Solving the sparse linear system.')
    # solve the (square) sparse system from the Poisson equation plus
    # boundary and gauge constraints
    matrix = scipy.sparse.csr_matrix((data, indices), shape=(nmatrix, nmatrix))
    solution = scipy.sparse.linalg.spsolve(matrix, rhs)

    if logger:
        logger.info(' Finalizing the barotropic streamfunction.')
    # keep only the streamfunction unknowns (drop Lagrange multipliers) and
    # convert from m^3/s to Sv, flipping sign to the BSF convention
    nvertices = ds_mesh.sizes['nVertices']
    bsf_vertex = xr.DataArray(
        -1e-6 * solution[0:nvertices], dims=('nVertices',)
    )
    return bsf_vertex