import numpy as np
import xarray as xr
from scipy.spatial import KDTree
from shapely.geometry import LineString, Point
from mpas_tools.transects import (
cartesian_to_lon_lat,
lon_lat_to_cartesian,
subdivide_great_circle,
subdivide_planar,
)
from mpas_tools.vector import Vector
[docs]
def mesh_to_triangles(ds_mesh):
    """
    Split each MPAS cell into triangles that join the cell center to each
    pair of adjacent vertices.

    Parameters
    ----------
    ds_mesh : xarray.Dataset
        An MPAS mesh.  The variables ``nEdgesOnCell``, ``verticesOnCell``,
        ``cellsOnVertex``, ``kiteAreasOnVertex`` and the cell and vertex
        coordinates (``x``, ``y``, ``z``, ``lat`` and ``lon``) are read, as
        is the ``on_a_sphere`` attribute and, for planar meshes, the
        ``is_periodic``, ``x_period`` and ``y_period`` attributes.

    Returns
    -------
    ds_tris : xarray.Dataset
        A dataset of triangles (dimension ``nTriangles``), each with 3 nodes
        (dimension ``nNodes``): the cell center first, then a vertex of the
        cell, then the following vertex in ``verticesOnCell`` (wrapping
        around to the first one).  The original documentation states that
        this gives the nodes a counterclockwise winding.  The variables are:

        * ``triCellIndices``: the cell that each triangle is part of
        * ``nodeCellIndices`` and ``nodeCellWeights``: for each node, up to
          3 (dimension ``nInterp``) cell indices and normalized weights for
          interpolating from cell centers to that node.  The cell-center
          node uses only its own cell (weight 1, other indices -1); the
          vertex nodes use ``cellsOnVertex`` weighted by
          ``kiteAreasOnVertex``.
        * ``xNode``, ``yNode``, ``zNode``, ``latNode`` and ``lonNode``: node
          coordinates copied from the mesh.

        On a sphere, ``lonNode`` is treated as periodic with period
        ``2 * pi`` and adjusted to be within half a period (180 degrees) of
        the cell-center node; on periodic planar meshes the same is done for
        ``xNode`` and ``yNode``.  Triangles that needed such an adjustment
        also get shifted copies appended along ``nTriangles``.

    Raises
    ------
    ValueError
        If the ``vertexDegree`` dimension of ``ds_mesh`` is not 3.
    """
    n_vertices_on_cell = ds_mesh.nEdgesOnCell.values
    # connectivity is shifted down by one (presumably from 1-based to 0-based
    # indices -- confirm), so that "missing" entries become negative
    vertices_on_cell = ds_mesh.verticesOnCell.values - 1
    cells_on_vertex = ds_mesh.cellsOnVertex.values - 1

    on_a_sphere = ds_mesh.attrs['on_a_sphere'].strip() == 'YES'
    # the periodicity attributes are only consulted for planar meshes
    is_periodic = (
        not on_a_sphere and ds_mesh.attrs['is_periodic'].strip() == 'YES'
    )
    x_period = ds_mesh.attrs['x_period'] if is_periodic else None
    y_period = ds_mesh.attrs['y_period'] if is_periodic else None

    kite_areas_on_vertex = ds_mesh.kiteAreasOnVertex.values

    # one triangle per (cell, vertex) pair
    n_triangles = np.sum(n_vertices_on_cell)

    max_edges = ds_mesh.sizes['maxEdges']
    n_cells = ds_mesh.sizes['nCells']
    if ds_mesh.sizes['vertexDegree'] != 3:
        raise ValueError(
            'mesh_to_triangles only supports meshes with vertexDegree = 3'
        )

    # for each vertex slot on each cell, find the vertex that follows it
    # (this is the third node of the triangle); unused slots are set to -1
    all_cells = np.arange(0, n_cells)
    next_vertex = np.full(vertices_on_cell.shape, -1, dtype=int)
    for i_vertex in range(max_edges):
        in_cell = i_vertex < n_vertices_on_cell
        vertices_on_cell[~in_cell, i_vertex] = -1
        n_verts = n_vertices_on_cell[in_cell]
        # the vertex after the last one on a cell is the first one
        i_next = np.where(i_vertex < n_verts - 1, i_vertex + 1, 0)
        next_vertex[in_cell, i_vertex] = vertices_on_cell[
            all_cells[in_cell], i_next
        ]

    # flatten (cell, slot) pairs into a list of triangles
    is_triangle = vertices_on_cell >= 0
    vertices_on_cell = vertices_on_cell[is_triangle]
    next_vertex = next_vertex[is_triangle]

    # the cell that each triangle came from
    cell_of_slot = np.meshgrid(
        np.arange(0, n_cells), np.arange(0, max_edges), indexing='ij'
    )[0]
    tri_cell_indices = cell_of_slot[is_triangle]

    # cells and weights used to interpolate to each triangle node
    node_cell_indices = np.full((n_triangles, 3, 3), -1, dtype=int)
    node_cell_weights = np.zeros((n_triangles, 3, 3))

    # node 0 is the cell center, so it just takes the value from that cell
    node_cell_indices[:, 0, 0] = tri_cell_indices
    node_cell_weights[:, 0, 0] = 1.0

    # nodes 1 and 2 are vertices, which take a kite-area-weighted average of
    # the cells around the vertex
    for i_node, vertex_ids in ((1, vertices_on_cell), (2, next_vertex)):
        node_cell_indices[:, i_node, :] = cells_on_vertex[vertex_ids, :]
        node_cell_weights[:, i_node, :] = kite_areas_on_vertex[vertex_ids, :]

    # normalize so the weights for each node sum to 1
    weight_sum = np.sum(node_cell_weights, axis=2)
    node_cell_weights /= weight_sum[:, :, np.newaxis]

    ds_tris = xr.Dataset()
    ds_tris['triCellIndices'] = ('nTriangles', tri_cell_indices)
    interp_dims = ('nTriangles', 'nNodes', 'nInterp')
    ds_tris['nodeCellIndices'] = (interp_dims, node_cell_indices)
    ds_tris['nodeCellWeights'] = (interp_dims, node_cell_weights)

    # Cartesian and lon/lat coordinates of each node
    for prefix in ['x', 'y', 'z', 'lat', 'lon']:
        cell_coord = ds_mesh[f'{prefix}Cell'].values
        vertex_coord = ds_mesh[f'{prefix}Vertex'].values
        coord = np.zeros((n_triangles, 3))
        coord[:, 0] = cell_coord[tri_cell_indices]
        coord[:, 1] = vertex_coord[vertices_on_cell]
        coord[:, 2] = vertex_coord[next_vertex]
        ds_tris[f'{prefix}Node'] = (('nTriangles', 'nNodes'), coord)

    # nothing obvious we can do about triangles containing the poles
    if on_a_sphere:
        ds_tris = _fix_periodic_tris(
            ds_tris, periodic_var='lonNode', period=2 * np.pi
        )
    elif is_periodic:
        for periodic_var, period in (('xNode', x_period), ('yNode', y_period)):
            ds_tris = _fix_periodic_tris(
                ds_tris, periodic_var=periodic_var, period=period
            )

    return ds_tris
[docs]
def make_triangle_tree(ds_tris):
    """
    Make a KD-Tree for finding triangle edges that are near enough to transect
    segments that they might intersect

    Parameters
    ----------
    ds_tris : xarray.Dataset
        A dataset that defines triangles, the results of calling
        :py:func:`mpas_tools.viz.transect.horiz.mesh_to_triangles()`.  The
        ``nTriangles`` and ``nNodes`` sizes and the ``xNode``, ``yNode`` and
        ``zNode`` variables are used.

    Returns
    -------
    tree : scipy.spatial.KDTree
        A tree of edge centers from triangles making up an MPAS mesh.  The
        point with flat index ``nNodes * i_tri + i_node`` is the center of the
        edge from node ``i_node`` to node ``(i_node + 1) % nNodes`` of
        triangle ``i_tri``.
    """
    n_triangles = ds_tris.sizes['nTriangles']
    n_nodes = ds_tris.sizes['nNodes']

    node_coords = np.zeros((n_triangles * n_nodes, 3))
    node_coords[:, 0] = ds_tris.xNode.values.ravel()
    node_coords[:, 1] = ds_tris.yNode.values.ravel()
    node_coords[:, 2] = ds_tris.zNode.values.ravel()

    # group the nodes by triangle so that the "next" node on each triangle
    # (wrapping from the last node back to the first) is a roll by one along
    # the node axis.  The wrap uses n_nodes rather than a hard-coded 3 so that
    # it stays consistent with the flat indexing above.
    nodes_on_tri = node_coords.reshape((n_triangles, n_nodes, 3))
    next_nodes_on_tri = np.roll(nodes_on_tri, -1, axis=1)

    # edge centers are half way between adjacent nodes (ignoring great-circle
    # distance)
    edge_coords = (0.5 * (nodes_on_tri + next_nodes_on_tri)).reshape((-1, 3))

    return KDTree(data=edge_coords, copy_data=True)
[docs]
def find_spherical_transect_cells_and_weights(
    lon_transect,
    lat_transect,
    ds_tris,
    ds_mesh,
    tree,
    degrees=True,
    earth_radius=None,
    subdivision_res=10e3,
):
    """
    Find "nodes" where the transect intersects the edges of the triangles
    that make up MPAS cells.

    Parameters
    ----------
    lon_transect : xarray.DataArray
        The longitude of segments making up the transect

    lat_transect : xarray.DataArray
        The latitude of segments making up the transect

    ds_tris : xarray.Dataset
        A dataset that defines triangles, the results of calling
        :py:func:`mpas_tools.viz.transect.horiz.mesh_to_triangles()`

    ds_mesh : xarray.Dataset
        A data set with the full MPAS mesh.

    tree : scipy.spatial.KDTree
        A tree of edge centers from triangles making up an MPAS mesh, the
        return value from
        :py:func:`mpas_tools.viz.transect.horiz.make_triangle_tree()`

    degrees : bool, optional
        Whether ``lon_transect`` and ``lat_transect`` are in degrees (as
        opposed to radians).

    earth_radius : float, optional
        The radius of the Earth in meters, taken from the ``sphere_radius``
        global attribute of ``ds_mesh`` if not provided

    subdivision_res : float, optional
        Resolution in m to use to subdivide the transect when looking for
        intersection candidates.  Should be small enough that curvature is
        small.

    Returns
    -------
    ds_out : xarray.Dataset
        A dataset that contains "nodes" where the transect intersects the
        edges of the triangles in ``ds_tris``, sorted by distance along the
        transect.  Intersections closer together than
        ``1e-6 * subdivision_res`` are treated as a single node.

        NOTE(review): earlier documentation stated that the two end points
        of the transect are included as nodes and that internal nodes are
        repeated twice; nothing in this function adds end points or
        duplicates nodes explicitly -- confirm the intended contract.

        The Cartesian and lon/lat coordinates of the nodes are ``xCartNode``,
        ``yCartNode``, ``zCartNode``, ``lonNode`` and ``latNode``.  The
        distance along the transect of each node is ``dNode``.  ``validNodes``
        is a mask that is ``False`` for nodes inserted between two
        consecutive intersections that share no triangle.

        For each segment between consecutive nodes (dimension ``nSegments``),
        ``horizTriangleIndices`` is the triangle in ``ds_tris`` containing the
        segment (-1 if there is none), ``horizCellIndices`` is the MPAS cell
        of that triangle (-1 if there is none) and
        ``horizTriangleNodeIndices`` holds, for each end of the segment, the
        first triangle node of the intersected edge.  The second node of that
        edge is ``numpy.mod(horizTriangleNodeIndices + 1, 3)``.

        Each node also has an associated set of 6 ``interpHorizCellIndices``
        and ``interpHorizCellWeights`` that can be used to interpolate from
        MPAS cell centers to nodes, first with area-weighted averaging to
        the nodes at the ends of the intersected triangle edge and then
        linear interpolation along that edge.  Some of the indices may be -1;
        the associated weights are expected to be zero.

        Finally, ``lonTransect`` and ``latTransect`` are included in the
        dataset, along with Cartesian coordinates ``xCartTransect``,
        ``yCartTransect``, ``zCartTransect``, and ``dTransect``, the
        great-circle distance along the transect of each original transect
        point.  In order to interpolate values (e.g. observations) from the
        original transect points to the intersection nodes, linear
        interpolation indices ``transectIndicesOnHorizNode`` and weights
        ``transectWeightsOnHorizNode`` are provided.  The values at nodes are
        found by::

            nodeValues = ((transectValues[transectIndicesOnHorizNode] *
                           transectWeightsOnHorizNode)
                          + (transectValues[transectIndicesOnHorizNode+1] *
                             (1.0 - transectWeightsOnHorizNode)))

    Raises
    ------
    ValueError
        If the transect does not intersect any triangle edge.
    """
    if earth_radius is None:
        earth_radius = ds_mesh.attrs['sphere_radius']
    buffer = np.maximum(
        np.amax(ds_mesh.dvEdge.values), np.amax(ds_mesh.dcEdge.values)
    )
    # search radius around each subdivided transect point when looking for
    # candidate edge centers in the tree
    radius = buffer + subdivision_res

    x, y, z = lon_lat_to_cartesian(
        lon_transect, lat_transect, earth_radius, degrees
    )

    n_nodes = ds_tris.sizes['nNodes']
    node_cell_weights = ds_tris.nodeCellWeights.values
    node_cell_indices = ds_tris.nodeCellIndices.values

    x_node = ds_tris.xNode.values.ravel()
    y_node = ds_tris.yNode.values.ravel()
    z_node = ds_tris.zNode.values.ravel()

    d_transect = np.zeros(lon_transect.shape)
    n_horiz_weights = 6

    # per-segment results, concatenated once after the loop
    x_parts = []
    y_parts = []
    z_parts = []
    d_parts = []
    tri_parts = []
    node_parts = []
    interp_cell_parts = []
    cell_weight_parts = []

    d_start = 0.0
    for seg_index in range(len(x) - 1):
        transectv0 = Vector(
            x[seg_index].values, y[seg_index].values, z[seg_index].values
        )
        transectv1 = Vector(
            x[seg_index + 1].values,
            y[seg_index + 1].values,
            z[seg_index + 1].values,
        )

        # accumulate the distance along the transect *before* any early exit
        # so that segments without candidate edges still contribute their
        # length to later nodes and to dTransect
        seg_length = earth_radius * transectv0.angular_distance(transectv1)
        d_seg_start = d_start
        d_start = d_start + seg_length
        d_transect[seg_index + 1] = d_start

        sub_slice = slice(seg_index, seg_index + 2)
        x_sub, y_sub, z_sub, _, _ = subdivide_great_circle(
            x[sub_slice].values,
            y[sub_slice].values,
            z[sub_slice].values,
            subdivision_res,
            earth_radius,
        )

        coords = np.zeros((len(x_sub), 3))
        coords[:, 0] = x_sub
        coords[:, 1] = y_sub
        coords[:, 2] = z_sub

        index_list = tree.query_ball_point(x=coords, r=radius)

        unique_indices = set()
        for indices in index_list:
            unique_indices.update(indices)

        n0_indices_cand = np.array(list(unique_indices))

        if len(n0_indices_cand) == 0:
            continue

        # each tree index is n_nodes * triangle + first node of the edge
        tris_cand = n0_indices_cand // n_nodes
        next_node_index = np.mod(n0_indices_cand + 1, n_nodes)
        n1_indices_cand = n_nodes * tris_cand + next_node_index

        n0_cand = Vector(
            x_node[n0_indices_cand],
            y_node[n0_indices_cand],
            z_node[n0_indices_cand],
        )
        n1_cand = Vector(
            x_node[n1_indices_cand],
            y_node[n1_indices_cand],
            z_node[n1_indices_cand],
        )

        intersect = Vector.intersects(n0_cand, n1_cand, transectv0, transectv1)

        n0_inter = Vector(
            n0_cand.x[intersect], n0_cand.y[intersect], n0_cand.z[intersect]
        )
        n1_inter = Vector(
            n1_cand.x[intersect], n1_cand.y[intersect], n1_cand.z[intersect]
        )

        tris_inter = tris_cand[intersect]
        n0_indices_inter = n0_indices_cand[intersect]
        n1_indices_inter = n1_indices_cand[intersect]

        intersections = Vector.intersection(
            n0_inter, n1_inter, transectv0, transectv1
        )
        intersections = Vector(
            earth_radius * intersections.x,
            earth_radius * intersections.y,
            earth_radius * intersections.z,
        )

        angular_distance = transectv0.angular_distance(intersections)
        d_node_local = d_seg_start + earth_radius * angular_distance

        node0_inter = np.mod(n0_indices_inter, n_nodes)
        node1_inter = np.mod(n1_indices_inter, n_nodes)

        # fractional weight of the first edge node at the intersection
        node_weights = intersections.angular_distance(
            n1_inter
        ) / n0_inter.angular_distance(n1_inter)

        # the first 3 weights belong to the first edge node, the last 3 to
        # the second edge node
        weights = np.zeros((len(tris_inter), n_horiz_weights))
        cell_indices = np.zeros((len(tris_inter), n_horiz_weights), int)
        for index in range(3):
            weights[:, index] = (
                node_weights
                * node_cell_weights[tris_inter, node0_inter, index]
            )
            cell_indices[:, index] = node_cell_indices[
                tris_inter, node0_inter, index
            ]
            weights[:, index + 3] = (1.0 - node_weights) * node_cell_weights[
                tris_inter, node1_inter, index
            ]
            cell_indices[:, index + 3] = node_cell_indices[
                tris_inter, node1_inter, index
            ]

        x_parts.append(intersections.x)
        y_parts.append(intersections.y)
        z_parts.append(intersections.z)
        d_parts.append(d_node_local)
        tri_parts.append(tris_inter)
        node_parts.append(node0_inter)
        interp_cell_parts.append(cell_indices)
        cell_weight_parts.append(weights)

    if sum(len(part) for part in tri_parts) == 0:
        raise ValueError(
            'The transect does not intersect any triangle edges of the mesh.'
        )

    x_out = np.concatenate(x_parts)
    y_out = np.concatenate(y_parts)
    z_out = np.concatenate(z_parts)
    d_node = np.concatenate(d_parts)
    tris = np.concatenate(tri_parts)
    nodes = np.concatenate(node_parts)
    interp_cells = np.concatenate(interp_cell_parts, axis=0)
    cell_weights = np.concatenate(cell_weight_parts, axis=0)

    epsilon = 1e-6 * subdivision_res
    (
        d_node,
        x_out,
        y_out,
        z_out,
        seg_tris,
        seg_nodes,
        interp_cells,
        cell_weights,
        valid_nodes,
    ) = _sort_intersections(
        d_node,
        tris,
        nodes,
        x_out,
        y_out,
        z_out,
        interp_cells,
        cell_weights,
        epsilon,
    )

    lon_out, lat_out = cartesian_to_lon_lat(
        x_out, y_out, z_out, earth_radius, degrees
    )

    # segments without a triangle get a cell index of -1
    valid_segs = seg_tris >= 0
    horiz_cell_indices = -1 * np.ones(seg_tris.shape, dtype=int)
    horiz_cell_indices[valid_segs] = ds_tris.triCellIndices.values[
        seg_tris[valid_segs]
    ]

    ds_out = xr.Dataset()
    ds_out['xCartNode'] = (('nNodes',), x_out)
    ds_out['yCartNode'] = (('nNodes',), y_out)
    ds_out['zCartNode'] = (('nNodes',), z_out)
    ds_out['dNode'] = (('nNodes',), d_node)
    ds_out['lonNode'] = (('nNodes',), lon_out)
    ds_out['latNode'] = (('nNodes',), lat_out)

    ds_out['horizTriangleIndices'] = ('nSegments', seg_tris)
    ds_out['horizCellIndices'] = ('nSegments', horiz_cell_indices)
    ds_out['horizTriangleNodeIndices'] = (
        ('nSegments', 'nHorizBounds'),
        seg_nodes,
    )
    ds_out['interpHorizCellIndices'] = (
        ('nNodes', 'nHorizWeights'),
        interp_cells,
    )
    ds_out['interpHorizCellWeights'] = (
        ('nNodes', 'nHorizWeights'),
        cell_weights,
    )
    ds_out['validNodes'] = (('nNodes',), valid_nodes)

    # indices and weights for interpolating from the original transect points
    # to the nodes
    transect_indices_on_horiz_node = np.zeros(d_node.shape, dtype=int)
    transect_weights_on_horiz_node = np.zeros(d_node.shape)
    for trans_index in range(len(d_transect) - 1):
        d0 = d_transect[trans_index]
        d1 = d_transect[trans_index + 1]
        mask = np.logical_and(d_node >= d0, d_node < d1)
        transect_indices_on_horiz_node[mask] = trans_index
        transect_weights_on_horiz_node[mask] = (d1 - d_node[mask]) / (d1 - d0)
    # nodes at the very end of the transect are missed by the half-open
    # intervals above and take the value of the last transect point; nodes
    # before the end keep the interpolation weights found above
    at_end = d_node >= d_transect[-1]
    transect_indices_on_horiz_node[at_end] = len(d_transect) - 2
    transect_weights_on_horiz_node[at_end] = 0.0

    ds_out['lonTransect'] = lon_transect
    ds_out['latTransect'] = lat_transect
    ds_out['xCartTransect'] = x
    ds_out['yCartTransect'] = y
    ds_out['zCartTransect'] = z
    ds_out['dTransect'] = (lon_transect.dims, d_transect)
    ds_out['transectIndicesOnHorizNode'] = (
        ('nNodes',),
        transect_indices_on_horiz_node,
    )
    ds_out['transectWeightsOnHorizNode'] = (
        ('nNodes',),
        transect_weights_on_horiz_node,
    )

    return ds_out
[docs]
def find_planar_transect_cells_and_weights(
    x_transect, y_transect, ds_tris, ds_mesh, tree, subdivision_res=10e3
):
    """
    Find "nodes" where the transect intersects the edges of the triangles
    that make up MPAS cells.

    Parameters
    ----------
    x_transect : xarray.DataArray
        The x points defining segments making up the transect

    y_transect : xarray.DataArray
        The y points defining segments making up the transect

    ds_tris : xarray.Dataset
        A dataset that defines triangles, the results of calling
        :py:func:`mpas_tools.viz.transect.horiz.mesh_to_triangles()`

    ds_mesh : xarray.Dataset
        A data set with the full MPAS mesh.

    tree : scipy.spatial.KDTree
        A tree of edge centers from triangles making up an MPAS mesh, the
        return value from
        :py:func:`mpas_tools.viz.transect.horiz.make_triangle_tree()`

    subdivision_res : float, optional
        Resolution in m to use to subdivide the transect when looking for
        intersection candidates.

    Returns
    -------
    ds_out : xarray.Dataset
        A dataset that contains "nodes" where the transect intersects the
        edges of the triangles in ``ds_tris``, sorted by distance along the
        transect.  Intersections closer together than
        ``1e-6 * subdivision_res`` are treated as a single node.

        NOTE(review): earlier documentation stated that the two end points
        of the transect are included as nodes and that internal nodes are
        repeated twice; nothing in this function adds end points or
        duplicates nodes explicitly -- confirm the intended contract.

        The planar coordinates of the nodes are ``xNode`` and ``yNode``.  The
        distance along the transect of each node is ``dNode``.
        ``validNodes`` is a mask that is ``False`` for nodes inserted between
        two consecutive intersections that share no triangle.

        For each segment between consecutive nodes (dimension ``nSegments``),
        ``horizTriangleIndices`` is the triangle in ``ds_tris`` containing the
        segment (-1 if there is none), ``horizCellIndices`` is the MPAS cell
        of that triangle (-1 if there is none) and
        ``horizTriangleNodeIndices`` holds, for each end of the segment, the
        first triangle node of the intersected edge.  The second node of that
        edge is ``numpy.mod(horizTriangleNodeIndices + 1, 3)``.

        Each node also has an associated set of 6 ``interpHorizCellIndices``
        and ``interpHorizCellWeights`` that can be used to interpolate from
        MPAS cell centers to nodes, first with area-weighted averaging to
        the nodes at the ends of the intersected triangle edge and then
        linear interpolation along that edge.  Some of the indices may be -1;
        the associated weights are expected to be zero.

        Finally, ``xTransect`` and ``yTransect`` are included in the
        dataset, along with ``dTransect``, the distance along the transect of
        each original transect point.  In order to interpolate values (e.g.
        observations) from the original transect points to the intersection
        nodes, linear interpolation indices ``transectIndicesOnHorizNode`` and
        weights ``transectWeightsOnHorizNode`` are provided.  The values at
        nodes are found by::

            nodeValues = ((transectValues[transectIndicesOnHorizNode] *
                           transectWeightsOnHorizNode)
                          + (transectValues[transectIndicesOnHorizNode+1] *
                             (1.0 - transectWeightsOnHorizNode)))

    Raises
    ------
    ValueError
        If a triangle edge exactly coincides with a transect segment, if an
        intersection is neither a point nor a line, or if the transect does
        not intersect any triangle edge.
    """
    buffer = np.maximum(
        np.amax(ds_mesh.dvEdge.values), np.amax(ds_mesh.dcEdge.values)
    )
    # search radius around each subdivided transect point when looking for
    # candidate edge centers in the tree
    radius = buffer + subdivision_res

    n_nodes = ds_tris.sizes['nNodes']
    node_cell_weights = ds_tris.nodeCellWeights.values
    node_cell_indices = ds_tris.nodeCellIndices.values

    x_node = ds_tris.xNode.values.ravel()
    y_node = ds_tris.yNode.values.ravel()

    d_transect = np.zeros(x_transect.shape)
    n_horiz_weights = 6

    # per-segment results, concatenated once after the loop
    x_parts = []
    y_parts = []
    d_parts = []
    tri_parts = []
    node_parts = []
    interp_cell_parts = []
    cell_weight_parts = []

    d_start = 0.0
    for seg_index in range(len(x_transect) - 1):
        start_point = Point(
            x_transect[seg_index].values, y_transect[seg_index].values
        )
        end_point = Point(
            x_transect[seg_index + 1].values, y_transect[seg_index + 1].values
        )
        segment = LineString([start_point, end_point])

        # accumulate the distance along the transect *before* any early exit
        # so that segments without candidate edges still contribute their
        # length to later nodes and to dTransect
        d_seg_start = d_start
        d_start = d_start + segment.length
        d_transect[seg_index + 1] = d_start

        sub_slice = slice(seg_index, seg_index + 2)
        x_sub, y_sub, _, _ = subdivide_planar(
            x_transect[sub_slice].values,
            y_transect[sub_slice].values,
            subdivision_res,
        )

        # the tree holds 3D points, so the third coordinate is left at zero
        coords = np.zeros((len(x_sub), 3))
        coords[:, 0] = x_sub
        coords[:, 1] = y_sub

        index_list = tree.query_ball_point(x=coords, r=radius)

        unique_indices = set()
        for indices in index_list:
            unique_indices.update(indices)

        start_indices = np.array(list(unique_indices))

        if len(start_indices) == 0:
            continue

        # each tree index is n_nodes * triangle + first node of the edge
        tris_cand = start_indices // n_nodes
        next_node_index = np.mod(start_indices + 1, n_nodes)
        end_indices = n_nodes * tris_cand + next_node_index

        tris_inter_list = []
        x_intersection_list = []
        y_intersection_list = []
        node_weights_list = []
        node0_inter_list = []
        node1_inter_list = []
        distances_list = []

        for start, end, tri in zip(start_indices, end_indices, tris_cand):
            node0 = Point(x_node[start], y_node[start])
            node1 = Point(x_node[end], y_node[end])
            edge = LineString([node0, node1])
            if not segment.intersects(edge):
                continue

            point = segment.intersection(edge)
            if isinstance(point, LineString):
                raise ValueError(
                    'A triangle edge exactly coincides with '
                    "a transect segment and I can't handle "
                    'that case. Try moving the transect a '
                    'tiny bit.'
                )
            elif not isinstance(point, Point):
                raise ValueError(f'Unexpected intersection type {point}')

            x_intersection_list.append(point.x)
            y_intersection_list.append(point.y)

            # fractional weight of the first edge node at the intersection
            weight = LineString([point, node1]).length / edge.length
            node_weights_list.append(weight)
            node0_inter_list.append(np.mod(start, n_nodes))
            node1_inter_list.append(np.mod(end, n_nodes))
            distances_list.append(LineString([start_point, point]).length)
            tris_inter_list.append(tri)

        distances = np.array(distances_list)
        x_intersection = np.array(x_intersection_list)
        y_intersection = np.array(y_intersection_list)
        node_weights = np.array(node_weights_list)
        node0_inter = np.array(node0_inter_list, dtype=int)
        node1_inter = np.array(node1_inter_list, dtype=int)
        tris_inter = np.array(tris_inter_list, dtype=int)

        d_node_local = d_seg_start + distances

        # the first 3 weights belong to the first edge node, the last 3 to
        # the second edge node
        weights = np.zeros((len(tris_inter), n_horiz_weights))
        cell_indices = np.zeros((len(tris_inter), n_horiz_weights), int)
        for index in range(3):
            weights[:, index] = (
                node_weights
                * node_cell_weights[tris_inter, node0_inter, index]
            )
            cell_indices[:, index] = node_cell_indices[
                tris_inter, node0_inter, index
            ]
            weights[:, index + 3] = (1.0 - node_weights) * node_cell_weights[
                tris_inter, node1_inter, index
            ]
            cell_indices[:, index + 3] = node_cell_indices[
                tris_inter, node1_inter, index
            ]

        x_parts.append(x_intersection)
        y_parts.append(y_intersection)
        d_parts.append(d_node_local)
        tri_parts.append(tris_inter)
        node_parts.append(node0_inter)
        interp_cell_parts.append(cell_indices)
        cell_weight_parts.append(weights)

    if sum(len(part) for part in tri_parts) == 0:
        raise ValueError(
            'The transect does not intersect any triangle edges of the mesh.'
        )

    x_out = np.concatenate(x_parts)
    y_out = np.concatenate(y_parts)
    d_node = np.concatenate(d_parts)
    tris = np.concatenate(tri_parts)
    nodes = np.concatenate(node_parts)
    interp_cells = np.concatenate(interp_cell_parts, axis=0)
    cell_weights = np.concatenate(cell_weight_parts, axis=0)

    # the sorting helper expects a z coordinate, which is zero on a plane
    z_out = np.zeros(x_out.shape)

    epsilon = 1e-6 * subdivision_res
    (
        d_node,
        x_out,
        y_out,
        z_out,
        seg_tris,
        seg_nodes,
        interp_cells,
        cell_weights,
        valid_nodes,
    ) = _sort_intersections(
        d_node,
        tris,
        nodes,
        x_out,
        y_out,
        z_out,
        interp_cells,
        cell_weights,
        epsilon,
    )

    # segments without a triangle get a cell index of -1
    valid_segs = seg_tris >= 0
    horiz_cell_indices = -1 * np.ones(seg_tris.shape, dtype=int)
    horiz_cell_indices[valid_segs] = ds_tris.triCellIndices.values[
        seg_tris[valid_segs]
    ]

    ds_out = xr.Dataset()
    ds_out['xNode'] = (('nNodes',), x_out)
    ds_out['yNode'] = (('nNodes',), y_out)
    ds_out['dNode'] = (('nNodes',), d_node)

    ds_out['horizTriangleIndices'] = ('nSegments', seg_tris)
    ds_out['horizCellIndices'] = ('nSegments', horiz_cell_indices)
    ds_out['horizTriangleNodeIndices'] = (
        ('nSegments', 'nHorizBounds'),
        seg_nodes,
    )
    ds_out['interpHorizCellIndices'] = (
        ('nNodes', 'nHorizWeights'),
        interp_cells,
    )
    ds_out['interpHorizCellWeights'] = (
        ('nNodes', 'nHorizWeights'),
        cell_weights,
    )
    ds_out['validNodes'] = (('nNodes',), valid_nodes)

    # indices and weights for interpolating from the original transect points
    # to the nodes
    transect_indices_on_horiz_node = np.zeros(d_node.shape, int)
    transect_weights_on_horiz_node = np.zeros(d_node.shape)
    for trans_index in range(len(d_transect) - 1):
        d0 = d_transect[trans_index]
        d1 = d_transect[trans_index + 1]
        mask = np.logical_and(d_node >= d0, d_node < d1)
        transect_indices_on_horiz_node[mask] = trans_index
        transect_weights_on_horiz_node[mask] = (d1 - d_node[mask]) / (d1 - d0)
    # nodes at the very end of the transect are missed by the half-open
    # intervals above and take the value of the last transect point; nodes
    # before the end keep the interpolation weights found above
    at_end = d_node >= d_transect[-1]
    transect_indices_on_horiz_node[at_end] = len(d_transect) - 2
    transect_weights_on_horiz_node[at_end] = 0.0

    ds_out['xTransect'] = x_transect
    ds_out['yTransect'] = y_transect
    ds_out['dTransect'] = (x_transect.dims, d_transect)
    ds_out['transectIndicesOnHorizNode'] = (
        ('nNodes',),
        transect_indices_on_horiz_node,
    )
    ds_out['transectWeightsOnHorizNode'] = (
        ('nNodes',),
        transect_weights_on_horiz_node,
    )

    return ds_out
[docs]
def interp_mpas_horiz_to_transect_nodes(ds_transect, da):
    """
    Interpolate an MPAS DataArray defined on ``nCells`` to transect nodes as
    a weighted sum over the cells listed for each node

    Parameters
    ----------
    ds_transect : xr.Dataset
        A dataset that defines an MPAS transect, the results of calling
        ``find_spherical_transect_cells_and_weights()`` or
        ``find_planar_transect_cells_and_weights()``.  The variables
        ``interpHorizCellIndices``, ``interpHorizCellWeights`` and
        ``validNodes`` are used.

    da : xr.DataArray
        An MPAS 2D field with dimensions ``nCells`` (possibly among others)

    Returns
    -------
    da_nodes : xr.DataArray
        The data array interpolated to transect nodes with dimensions
        ``nNodes`` (in addition to whatever dimensions were in ``da`` besides
        ``nCells``).  Nodes where ``validNodes`` is false are masked out with
        ``where()``.
    """
    # pick out the cells contributing to each node, replacing nCells with the
    # dimensions of the index array
    da_on_interp_cells = da.isel(nCells=ds_transect.interpHorizCellIndices)
    weighted = da_on_interp_cells * ds_transect.interpHorizCellWeights
    return weighted.sum(dim='nHorizWeights').where(ds_transect.validNodes)
def _sort_intersections(
d_node,
tris,
nodes,
x_out,
y_out,
z_out,
interp_cells,
cell_weights,
epsilon,
):
"""sort nodes by distance and define segment between them"""
sort_indices = np.argsort(d_node)
d_sorted = d_node[sort_indices]
# make a list of indices for each unique value of d
d = d_sorted[0]
unique_d_indices = [sort_indices[0]]
unique_d_all_indices = [[sort_indices[0]]]
for (
index,
next_d,
) in zip(sort_indices[1:], d_sorted[1:]):
if next_d - d < epsilon:
# this d value is effectively the same as the last, so we'll treat
# it as the same
unique_d_all_indices[-1].append(index)
else:
# this is a new d, so we'll add to a new list
d = next_d
unique_d_indices.append(index)
unique_d_all_indices.append([index])
# there is a segment between each unique d, though some are invalid (do
# not correspond to a triangle)
seg_tris_list = list()
seg_nodes_list = list()
index0 = unique_d_indices[0]
indices0 = unique_d_all_indices[0]
d0 = d_node[index0]
indices = [index0]
ds = [d0]
for seg_index in range(len(unique_d_all_indices) - 1):
indices1 = unique_d_all_indices[seg_index + 1]
index1 = unique_d_indices[seg_index + 1]
d1 = d_node[index1]
# are there any triangles in common between this d value and the next?
tris0 = tris[indices0]
tris1 = tris[indices1]
both = set(tris0).intersection(set(tris1))
if len(both) > 0:
tri = both.pop()
seg_tris_list.append(tri)
indices.append(index1)
ds.append(d1)
# the triangle nodes are the 2 corresponding to the same triangle
# in the original list
index0 = indices0[np.where(tris0 == tri)[0][0]]
index1 = indices1[np.where(tris1 == tri)[0][0]]
seg_nodes_list.append([nodes[index0], nodes[index1]])
else:
# this is an invalid segment so we need to insert and extra invalid
# node to allow for proper masking
seg_tris_list.extend([-1, -1])
seg_nodes_list.extend([[-1, -1], [-1, -1]])
indices.extend([index0, index1])
ds.extend([0.5 * (d0 + d1), d1])
index0 = index1
indices0 = indices1
d0 = d1
indices = np.array(indices, dtype=int)
d_node = np.array(ds, dtype=float)
seg_tris = np.array(seg_tris_list, dtype=int)
seg_nodes = np.array(seg_nodes_list, dtype=int)
valid_nodes = np.ones(len(indices), dtype=bool)
valid_nodes[1:-1] = np.logical_or(seg_tris[0:-1] >= 0, seg_tris[1:] > 0)
x_out = x_out[indices]
y_out = y_out[indices]
z_out = z_out[indices]
interp_cells = interp_cells[indices, :]
cell_weights = cell_weights[indices, :]
return (
d_node,
x_out,
y_out,
z_out,
seg_tris,
seg_nodes,
interp_cells,
cell_weights,
valid_nodes,
)
def _fix_periodic_tris(ds_tris, periodic_var, period):
    """
    Shift the given node coordinate of each triangle so that the vertex nodes
    are within half a period of the cell-center node, and append shifted
    copies of the triangles that needed it

    Parameters
    ----------
    ds_tris : xarray.Dataset
        A dataset of triangles with dimensions ``nTriangles`` and ``nNodes``.
        The values of ``periodic_var`` are modified in place.

    periodic_var : str
        The name of the periodic node coordinate in ``ds_tris``

    period : float
        The period of the coordinate

    Returns
    -------
    ds_new : xarray.Dataset
        The original triangles with adjusted ``periodic_var``, followed by a
        copy (all 3 nodes shifted up by one period) of each triangle where a
        vertex node was shifted down, followed by a copy (shifted down by one
        period) of each triangle where a vertex node was shifted up.
    """
    n_triangles = ds_tris.sizes['nTriangles']
    coord_node = ds_tris[periodic_var].values
    # node 0 of each triangle is the cell center
    coord_cell = coord_node[:, 0]
    half_period = 0.5 * period

    shifted_down = np.zeros(coord_cell.shape, dtype=bool)
    shifted_up = np.zeros(coord_cell.shape, dtype=bool)
    for i_node in (1, 2):
        # a view, so assignments below modify coord_node
        column = coord_node[:, i_node]

        too_high = column - coord_cell > half_period
        shifted_down |= too_high
        column[too_high] = column[too_high] - period

        too_low = column - coord_cell < -half_period
        shifted_up |= too_low
        column[too_low] = column[too_low] + period

    pos_indices = np.nonzero(shifted_down)[0]
    neg_indices = np.nonzero(shifted_up)[0]
    n_pos = len(pos_indices)
    n_neg = len(neg_indices)

    # all the original triangles, then the ones to copy
    tri_indices = np.concatenate(
        (np.arange(0, n_triangles), pos_indices, neg_indices)
    )

    ds_new = xr.Dataset(ds_tris)
    ds_new[periodic_var] = (('nTriangles', 'nNodes'), coord_node)
    ds_new = ds_new.isel(nTriangles=tri_indices)
    coord_node = ds_new[periodic_var].values

    # the copies sit after the original triangles: first those to be moved up
    # by a period, then those to be moved down by a period
    pos_start = n_triangles
    neg_start = pos_start + n_pos
    neg_end = neg_start + n_neg
    coord_node[pos_start:neg_start, :] = (
        coord_node[pos_start:neg_start, :] + period
    )
    coord_node[neg_start:neg_end, :] = (
        coord_node[neg_start:neg_end, :] - period
    )
    ds_new[periodic_var] = (('nTriangles', 'nNodes'), coord_node)
    return ds_new