# Source code for compass.ocean.tests.internal_wave.viz

import matplotlib.pyplot as plt
import numpy
import xarray

from compass.step import Step

class Viz(Step):
    """
    A step for visualizing a cross-section through the internal wave
    """

    # Default matplotlib figure size (inches) shared by every figure.
    _FIGSIZE = (6.4, 4.8)

    def __init__(self, test_case):
        """
        Create the step

        Parameters
        ----------
        test_case : compass.TestCase
            The test case this step belongs to
        """
        super().__init__(test_case=test_case, name='viz')

        # NOTE(review): the file name was blank in the copy of this source
        # that was reviewed (filename='', target='../forward/');
        # 'output.nc' is assumed here -- confirm against the forward step.
        self.add_input_file(filename='output.nc',
                            target='../forward/output.nc')
        self.add_output_file('uNormal_depth_section_t0.png')
        self.add_output_file('pt_depth_section_t0.png')
        self.add_output_file('sa_depth_section_t0.png')
        self.add_output_file('layerThickness_depth_section_t0.png')

    def run(self):
        """
        Run this step of the test case

        Figures are written for the first and the last time index of the
        dataset.  If the dataset has no ``Time`` dimension a message is
        printed and nothing is plotted.
        """
        # NOTE(review): file name assumed to be 'output.nc' (see __init__).
        # The context manager makes sure the file handle is released.
        with xarray.open_dataset('output.nc') as ds:
            if 'Time' not in ds.dims:
                print('Dataset missing time dimension')
                return

            tend = ds.sizes['Time'] - 1
            for j in [0, tend]:
                self._plot_time_slice(ds.isel(Time=j), j)

    @classmethod
    def _plot_time_slice(cls, ds1, j):
        """
        Write all figures for one time slice.

        Parameters
        ----------
        ds1 : xarray.Dataset
            The dataset at a single time index

        j : int
            The time index, used only in the output file names
        """
        # Sort by yEdge so that the edges of the section are ordered in y.
        ds1 = ds1.sortby('yEdge')

        nCells = ds1.sizes['nCells']
        nVertLevels = ds1.sizes['nVertLevels']

        # The section is made of the edges whose x coordinate equals the
        # median edge x coordinate.
        xEdge = ds1.xEdge.values
        edge_mask = numpy.equal(xEdge, numpy.median(xEdge))

        # Interface depths: start at the sea surface and subtract each layer
        # thickness in turn (missing thicknesses count as zero).
        zInterface = numpy.zeros((nCells, nVertLevels + 1))
        zInterface[:, 0] = ds1.ssh.values
        for k in range(nVertLevels):
            thickness = ds1.layerThickness.isel(nVertLevels=k).fillna(0.)
            zInterface[:, k + 1] = zInterface[:, k] - thickness.values

        # Layer mid-depths, halfway between consecutive interfaces.
        zMid = (zInterface[:, :-1] +
                numpy.divide(zInterface[:, 1:] - zInterface[:, :-1], 2.))

        # ------------------------------------------------------------------
        # Plot the edge-centered normal velocity
        # ------------------------------------------------------------------
        cellsOnEdge_x = ds1.cellsOnEdge.values[edge_mask, :]
        yEdges = cls._section_y_bounds(cellsOnEdge_x, ds1.yCell.values)

        zInterfaces_mesh, yEdges_mesh = numpy.meshgrid(zInterface[0, :],
                                                       yEdges)
        normalVelocity_x = ds1.normalVelocity.values[edge_mask, :]

        # Symmetric colour limits so that zero velocity sits mid-colormap.
        cmax = numpy.max(numpy.abs(normalVelocity_x))
        cls._plot_section(yEdges_mesh, zInterfaces_mesh, normalVelocity_x,
                          title='uNormal (m/s)',
                          filename=f'uNormal_depth_section_t{j}.png',
                          cmap='RdBu', vmin=-1. * cmax, vmax=cmax)

        # ------------------------------------------------------------------
        # Plot cell-centered variables
        # ------------------------------------------------------------------
        # First cell of every section edge but the first, converted from the
        # 1-based indices stored in cellsOnEdge to 0-based indices.
        cellIndex = numpy.subtract(cellsOnEdge_x[1:, 0], 1)
        yEdge_x = ds1.yEdge.values[edge_mask]
        zInterfaces_mesh, yCells_mesh = numpy.meshgrid(zInterface[0, :],
                                                       yEdge_x)

        temperature = ds1.temperature
        layerThickness_x = ds1.layerThickness.values[cellIndex, :]
        temperature_x = temperature.values[cellIndex, :]
        salinity_x = ds1.salinity.values[cellIndex, :]
        w_x = ds1.vertVelocityTop.values[cellIndex, :]

        # Profile of the cell-mean temperature against the mean mid-depth.
        temperature_z = temperature.mean(dim='nCells')
        zMid_z = zMid.mean(axis=0)
        plt.figure(figsize=cls._FIGSIZE, dpi=100)
        plt.plot(temperature_z.values, zMid_z)
        plt.xlabel('PT (C)')
        plt.ylabel('z (m)')
        plt.savefig(f'pt_depth_t{j}.png', bbox_inches='tight', dpi=200)
        plt.close()

        cls._plot_section(yCells_mesh, zInterfaces_mesh, temperature_x,
                          title='PT (C)',
                          filename=f'pt_depth_section_t{j}.png')
        cls._plot_section(yCells_mesh, zInterfaces_mesh, salinity_x,
                          title='SA (g/kg)',
                          filename=f'sa_depth_section_t{j}.png')
        # The last column of vertVelocityTop is dropped so the field has one
        # column fewer than the interface mesh, like the other fields.
        # NOTE(review): this figure used to be labelled 'h (m)', copied from
        # the layer-thickness plot; unit m/s assumed -- confirm.
        cls._plot_section(yCells_mesh, zInterfaces_mesh, w_x[:, :-1],
                          title='w (m/s)',
                          filename=f'w_depth_section_t{j}.png')
        cls._plot_section(yCells_mesh, zInterfaces_mesh, layerThickness_x,
                          title='h (m)',
                          filename=f'layerThickness_depth_section_t{j}.png')

    @staticmethod
    def _section_y_bounds(cells_on_edge, y_cell):
        """
        Compute the y coordinates bounding each edge of the section.

        Parameters
        ----------
        cells_on_edge : numpy.ndarray
            Array of shape ``(n, 2)`` with the 1-based indices of the two
            cells sharing each section edge; 0 is treated as "no cell"

        y_cell : numpy.ndarray
            y coordinate of every cell

        Returns
        -------
        y_bounds : numpy.ndarray
            Array of length ``n + 1``.  Entry ``i`` is the lower bound of
            edge ``i``; the last entry is the upper bound of the last edge.
        """
        y_bounds = numpy.zeros(len(cells_on_edge) + 1)
        for i, (cell0, cell1) in enumerate(cells_on_edge):
            if cell1 == 0:
                # Only the first cell exists: use it for both bounds.
                lower = upper = y_cell[cell0 - 1]
            elif cell0 == 0:
                # Only the second cell exists.  (This branch previously
                # repeated the test on the second cell and so never ran.)
                lower = upper = y_cell[cell1 - 1]
            else:
                y0 = y_cell[cell0 - 1]
                y1 = y_cell[cell1 - 1]
                lower, upper = min(y0, y1), max(y0, y1)
            # The upper bound is overwritten by the next edge's lower bound,
            # so only the final edge's upper bound survives.
            y_bounds[i] = lower
            y_bounds[i + 1] = upper
        return y_bounds

    @classmethod
    def _plot_section(cls, y_mesh, z_mesh, values, title, filename,
                      cmap='viridis', vmin=None, vmax=None):
        """
        Draw one pseudo-colour depth section and save it to a file.

        Parameters
        ----------
        y_mesh, z_mesh : numpy.ndarray
            Mesh of y (m) and z (m) coordinates; y is plotted in km

        values : numpy.ndarray
            The field to plot

        title : str
            Title placed above the colour bar

        filename : str
            Name of the image file to write

        cmap : str, optional
            Matplotlib colormap name

        vmin, vmax : float, optional
            Colour limits; ``None`` lets matplotlib choose them
        """
        plt.figure(figsize=cls._FIGSIZE, dpi=100)
        plt.pcolormesh(numpy.divide(y_mesh, 1e3), z_mesh, values,
                       cmap=cmap, vmin=vmin, vmax=vmax)
        plt.xlabel('y (km)')
        plt.ylabel('z (m)')
        cbar = plt.colorbar()
        cbar.ax.set_title(title)
        plt.savefig(filename, bbox_inches='tight', dpi=200)
        plt.close()