# Source code for compass.ocean.tests.ice_shelf_2d.initial_state

import numpy as np
import xarray
from mpas_tools.cime.constants import constants
from mpas_tools.io import write_netcdf
from mpas_tools.mesh.conversion import convert, cull
from mpas_tools.planar_hex import make_planar_hex_mesh

from compass.ocean.iceshelf import compute_land_ice_pressure_and_draft
from compass.ocean.vertical import init_vertical_coord
from compass.step import Step


class InitialState(Step):
    """
    A step for creating a mesh and initial condition for ice-shelf 2D test
    cases

    Attributes
    ----------
    resolution : str
        The resolution of the test case
    """

    def __init__(self, test_case, resolution):
        """
        Create the step

        Parameters
        ----------
        test_case : compass.TestCase
            The test case this step belongs to

        resolution : str
            The resolution of the test case
        """
        super().__init__(test_case=test_case, name='initial_state')
        self.resolution = resolution

        # Declare every file that run() writes, including the tidal forcing
        # file, so the declared outputs match what the step produces.
        output_files = (
            'base_mesh.nc',
            'culled_mesh.nc',
            'culled_graph.info',
            'initial_state.nc',
            'init_mode_forcing_data.nc',
        )
        for filename in output_files:
            self.add_output_file(filename)

    def run(self):
        """
        Run this step of the test case

        Writes ``base_mesh.nc``, ``culled_mesh.nc``, ``culled_graph.info``,
        ``initial_state.nc`` and ``init_mode_forcing_data.nc`` in the current
        working directory, using the ``ice_shelf_2d`` and ``vertical_grid``
        config sections.
        """
        config = self.config
        logger = self.logger

        section = config['ice_shelf_2d']
        nx = section.getint('nx')
        ny = section.getint('ny')
        dc = section.getfloat('dc')

        # planar hex mesh, periodic in x and non-periodic in y
        dsMesh = make_planar_hex_mesh(nx=nx, ny=ny, dc=dc,
                                      nonperiodic_x=False,
                                      nonperiodic_y=True)
        write_netcdf(dsMesh, 'base_mesh.nc')

        dsMesh = cull(dsMesh, logger=logger)
        dsMesh = convert(dsMesh, graphInfoFileName='culled_graph.info',
                         logger=logger)
        write_netcdf(dsMesh, 'culled_mesh.nc')

        bottom_depth = config.getfloat('vertical_grid', 'bottom_depth')

        temperature = section.getfloat('temperature')
        surface_salinity = section.getfloat('surface_salinity')
        bottom_salinity = section.getfloat('bottom_salinity')

        # points 1 and 2 are where angles on ice shelf are located.
        # point 3 is at the surface.
        # d variables are total water-column thickness below ice shelf
        y1 = section.getfloat('y1')
        y2 = section.getfloat('y2')
        y3 = section.getfloat('y3')
        d1 = section.getfloat('y1_water_column_thickness')
        d2 = section.getfloat('y2_water_column_thickness')
        d3 = bottom_depth

        ds = dsMesh.copy()

        # uniform bottom depth at every cell
        ds['bottomDepth'] = bottom_depth * xarray.ones_like(ds.xCell)

        yCell = ds.yCell

        # Piecewise-linear column thickness as a function of y:
        #   y < y1        : d1
        #   y1 <= y < y2  : linear from d1 to d2
        #   y2 <= y < y3  : linear from d2 to d3
        #   y >= y3       : d3 (the full bottom depth)
        column_thickness = xarray.where(
            yCell < y1, d1, d1 + (d2 - d1) * (yCell - y1) / (y2 - y1))
        column_thickness = xarray.where(
            yCell < y2, column_thickness,
            d2 + (d3 - d2) * (yCell - y2) / (y3 - y2))
        column_thickness = xarray.where(yCell < y3, column_thickness, d3)

        # top of the water column, measured up from the bottom at
        # z = -bottom_depth
        ds['ssh'] = -bottom_depth + column_thickness

        # set up the vertical coordinate
        # NOTE(review): zMid, layerThickness and refBottomDepth used below are
        # presumably added to ds by this call -- confirm
        init_vertical_coord(config, ds)

        # 1 where y < y3 and 0 elsewhere, with a leading Time dimension
        modify_mask = xarray.where(yCell < y3, 1, 0).expand_dims(
            dim='Time', axis=0)
        landIceFraction = modify_mask.astype(float)
        landIceMask = modify_mask.copy()
        landIceFloatingFraction = landIceFraction.copy()
        landIceFloatingMask = landIceMask.copy()

        ref_density = constants['SHR_CONST_RHOSW']
        landIcePressure, landIceDraft = compute_land_ice_pressure_and_draft(
            ssh=ds.ssh, modify_mask=modify_mask, ref_density=ref_density)

        # salinity varies linearly with zMid: surface_salinity at z = 0 and
        # bottom_salinity at z = -bottom_depth
        salinity = surface_salinity + (
            (bottom_salinity - surface_salinity) *
            (ds.zMid / (-bottom_depth)))
        salinity, _ = xarray.broadcast(salinity, ds.layerThickness)
        salinity = salinity.transpose('Time', 'nCells', 'nVertLevels')

        # zero normal velocity with dimensions (Time, nEdges, nVertLevels)
        normalVelocity = xarray.zeros_like(ds.xEdge)
        normalVelocity, _ = xarray.broadcast(normalVelocity,
                                             ds.refBottomDepth)
        normalVelocity = normalVelocity.transpose('nEdges', 'nVertLevels')
        normalVelocity = normalVelocity.expand_dims(dim='Time', axis=0)

        ds['temperature'] = temperature * xarray.ones_like(ds.layerThickness)
        ds['salinity'] = salinity
        ds['normalVelocity'] = normalVelocity

        # Coriolis parameter is zero at cells, edges and vertices
        ds['fCell'] = xarray.zeros_like(ds.xCell)
        ds['fEdge'] = xarray.zeros_like(ds.xEdge)
        ds['fVertex'] = xarray.zeros_like(ds.xVertex)

        ds['sshAdjustmentMask'] = modify_mask
        ds['landIceFraction'] = landIceFraction
        ds['landIceFloatingFraction'] = landIceFloatingFraction
        ds['landIceMask'] = landIceMask
        ds['landIceFloatingMask'] = landIceFloatingMask
        ds['landIcePressure'] = landIcePressure
        ds['landIceDraft'] = landIceDraft

        write_netcdf(ds, 'initial_state.nc')

        # Generate the tidal forcing dataset whether it is used or not
        ds_forcing = xarray.Dataset()
        y_max = np.max(ds.yCell.values)
        # mask is 1.0 for cells with y within 0.6 * 5.0e3 of the largest y
        # NOTE(review): 5.0e3 looks like a hard-coded 5 km cell width rather
        # than the configured dc -- confirm before changing
        ds_forcing['tidalInputMask'] = xarray.where(
            ds.yCell > (y_max - 0.6 * 5.0e3), 1.0, 0.0)
        write_netcdf(ds_forcing, 'init_mode_forcing_data.nc')