# Source code for compass.landice.tests.mesh_modifications.subdomain_extractor.extract_region
import os
import shutil
import sys
import mpas_tools
import numpy as np
import xarray
from mpas_tools.io import write_netcdf
from mpas_tools.logging import check_call
from mpas_tools.mesh.conversion import convert, cull
from pyremap import MpasCellMeshDescriptor, Remapper
from compass.landice.mesh import mpas_flood_fill
from compass.model import make_graph_file
from compass.step import Step
class ExtractRegion(Step):
    """
    A step for extracting a regional domain from a larger domain.

    The region to keep is defined either by a geojson outline or by a
    region number within a region mask file; in the latter case the mask
    can optionally be grown into adjacent floating ice and open ocean.
    """

    def __init__(self, test_case):
        """
        Create the step

        Parameters
        ----------
        test_case : compass.TestCase
            The test case this step belongs to
        """
        super().__init__(test_case=test_case, name='extract_region')

    def setup(self):
        """
        Set up the step: request the resources used by the parallel
        region-mask computation and the ESMF mapping-file generation.
        """
        self.ntasks = 128
        self.mintasks = 1

    def run(self):
        """
        Run this step of the test case
        """
        logger = self.logger

        # Get info from config file
        config = self.config
        section = config['subdomain']
        source_file_path = section.get('source_file')
        source_file_name = os.path.basename(source_file_path)
        # root name without the trailing '.nc', used to name intermediates
        source_file_rootname = source_file_name.rsplit('.nc', 1)[0]
        region_definition = section.get('region_definition')
        geojson_file = section.get('geojson_file')
        region_mask_file = section.get('region_mask_file')
        region_number = section.getint('region_number')
        dest_file_name = section.get('dest_file_name')
        dest_file_rootname = dest_file_name.rsplit('.nc', 1)[0]
        mesh_projection = section.get('mesh_projection')
        extend_mesh = section.getboolean('extend_ocean_buffer')
        grow_iters = section.getint('grow_iters')
        interp_method = section.get('interp_method')
        extra_file1 = section.get('extra_file1')
        extra_file2 = section.get('extra_file2')
        extra_file3 = section.get('extra_file3')
        extra_file4 = section.get('extra_file4')
        extra_file5 = section.get('extra_file5')

        # create a tmp dir for intermediate files
        tmpdir = os.path.join(self.work_dir, 'tmp')
        os.makedirs(tmpdir, exist_ok=True)

        # get needed dims from source mesh
        ds_src = xarray.open_dataset(source_file_path)
        nCells = ds_src.sizes['nCells']
        levels = ds_src.sizes['nVertLevels']

        # create cull mask
        if region_definition == 'geojson':
            # compute the cell mask inside the geojson outline in parallel
            # using the mpas_tools command-line tool
            args = ['compute_mpas_region_masks',
                    '-m', source_file_path,
                    '-o', 'cull_mask.nc',
                    '-g', geojson_file,
                    '--process_count', f'{self.ntasks}',
                    '--format', mpas_tools.io.default_format,
                    '--engine', mpas_tools.io.default_engine]
            check_call(args, logger=logger)
            dsMaskOut = xarray.open_dataset('cull_mask.nc')
        elif region_definition == 'region_mask_file':
            logger.info('creating cull mask file')
            dsMask = xarray.open_dataset(region_mask_file)
            regionCellMasks = dsMask['regionCellMasks'][:].values
            # get region mask for the requested region
            # (region_number is 1-based in the config)
            keepMask = regionCellMasks[:, region_number - 1]
            if extend_mesh:
                # Grow the mask into the ocean, because the standard regions
                # may end at the ice terminus.
                thickness = ds_src['thickness'][:].values
                bed = ds_src['bedTopography'][:].values
                # ice-free cells whose bed is at or below sea level
                oceanMask = np.squeeze((thickness[0, :] == 0.0) * (bed[0, :] <=
                                                                   0.0))
                # flotation criterion (rho_ice=910, rho_seawater=1028):
                # ice-covered cells thin enough to be afloat
                floatMask = np.squeeze(((thickness[0, :] * 910.0 / 1028.0 +
                                         bed[0, :]) < 0.0) *
                                       (thickness[0, :] > 0))
                conc = ds_src['cellsOnCell'][:].values
                neonc = ds_src['nEdgesOnCell'][:].values

                # First grow forward to capture any adjacent ice shelf
                logger.info('Starting floating ice fill')
                keepMask = mpas_flood_fill(keepMask, floatMask, conc, neonc)

                # Don't grow into other regions.
                # The area to grow into is region adjacent to the domain that
                # either has no region assigned to it OR is open ocean.
                # We also want to fill into any *adjacent* floating ice, due to
                # some funky region boundaries near ice-shelf fronts.
                logger.info('Starting ocean grow fill')
                noRegionMask = (np.squeeze(regionCellMasks.sum(axis=1)) == 0)
                growMask = np.logical_or(noRegionMask, oceanMask)
                keepMask = mpas_flood_fill(keepMask, growMask, conc, neonc,
                                           grow_iters=grow_iters)

            # To call 'cull' with an inverse mask, we need a dataset with the
            # mask saved to the field regionCellMasks
            outdata = {'regionCellMasks': (('nCells', 'nRegions'),
                                           keepMask.reshape(nCells, 1))}
            dsMaskOut = xarray.Dataset(data_vars=outdata)
            # For troubleshooting, one may want to inspect the mask,
            # so write out (otherwise not necessary to save to disk)
            write_netcdf(dsMaskOut, os.path.join(tmpdir, 'cull_mask.nc'))
        else:
            sys.exit('ERROR: unknown value for region_definition='
                     f'{region_definition}')

        # cull the mesh
        logger.info('culling and converting mesh')
        ds_out = cull(ds_src, dsInverse=dsMaskOut, logger=logger)
        # convert mesh
        ds_out = convert(ds_out, logger=logger)
        write_netcdf(ds_out,
                     os.path.join(tmpdir,
                                  f'{source_file_rootname}_culled.nc'))

        # mark horns for culling
        logger.info('Marking horns for culling')
        args = ['mark_horns_for_culling.py', '-f',
                os.path.join(tmpdir, f'{source_file_rootname}_culled.nc')]
        check_call(args, logger=logger)

        # cull again (removes the cells marked as horns above)
        logger.info('culling and converting mesh')
        ds_out = xarray.open_dataset(
            os.path.join(tmpdir, f'{source_file_rootname}_culled.nc'))
        ds_out = cull(ds_out, logger=logger)
        ds_out = convert(ds_out, logger=logger)
        dest_mesh_only_name = os.path.join(tmpdir, 'dest_mesh_only.nc')
        write_netcdf(ds_out, dest_mesh_only_name)

        # set lat/lon
        logger.info('calling set_lat_lon_fields_in_planar_grid.py')
        args = ['set_lat_lon_fields_in_planar_grid.py',
                '-f', dest_mesh_only_name, '-p', mesh_projection]
        check_call(args, logger=logger)

        if interp_method == 'ncremap':
            # remap data from the original file to the culled mesh using
            # pyremap with nearest neighbor interpolation
            in_descriptor = MpasCellMeshDescriptor(source_file_path,
                                                   'src_mesh')
            out_descriptor = MpasCellMeshDescriptor(dest_mesh_only_name,
                                                    'dst_mesh')
            mapping_filename = os.path.join(tmpdir, 'map_src_to_dst_nstd.nc')
            logger.info(f'Creating the mapping file {mapping_filename}...')
            remapper = Remapper(in_descriptor, out_descriptor,
                                mapping_filename)
            parallel_executable = config.get('parallel',
                                             'parallel_executable')
            remapper.build_mapping_file(method='neareststod',
                                        mpiTasks=self.ntasks,
                                        tempdir=self.work_dir, logger=logger,
                                        esmf_parallel_exec=parallel_executable)  # noqa
            logger.info('done.')
            logger.info('Remapping mesh file...')
            _remap_with_ncremap(
                source_file_path,
                os.path.join(tmpdir, f'{dest_file_rootname}_vars_only.nc'),
                mapping_filename, logger, tmpdir)

            # now combine the remapped variables with the mesh fields
            # that don't get remapped
            shutil.copyfile(dest_mesh_only_name, dest_file_name)
            args = [
                'ncks', '-A',
                os.path.join(tmpdir, f'{dest_file_rootname}_vars_only.nc'),
                dest_file_name]
            check_call(args, logger=logger)
            logger.info('done.')
            logger.info(f'Created {dest_file_name}')
        elif interp_method == 'mali_interp':
            # create landice mesh
            logger.info('calling create_landice_grid_from_generic_MPAS_grid.py')  # noqa
            args = ['create_landice_grid_from_generic_MPAS_grid.py',
                    '-i', os.path.join(tmpdir, 'dest_mesh_only.nc'),
                    '-o', dest_file_name,
                    '-l', f'{levels}', '-v', 'glimmer',
                    '--beta', '--thermal', '--obs', '--diri']
            check_call(args, logger=logger)

            # interpolate to new mesh using nearest neighbor to ensure we get
            # identical values
            logger.info('calling interpolate_to_mpasli_grid.py')
            args = ['interpolate_to_mpasli_grid.py',
                    '-s', source_file_path,
                    '-d', dest_file_name, '-m', 'n']
            check_call(args, logger=logger)
        else:
            sys.exit(f"Error: Unknown interp_method of {interp_method}")

        # mark Dirichlet boundaries
        logger.info('Marking domain boundaries dirichlet')
        args = ['mark_domain_boundaries_dirichlet.py',
                '-f', dest_file_name]
        check_call(args, logger=logger)

        logger.info('creating graph.info')
        make_graph_file(mesh_filename=dest_file_name,
                        graph_filename='graph.info')

        # remap any ancillary files onto the new subdomain; only supported
        # for the ncremap path because it reuses mapping_filename from above
        for extra_file in [extra_file1, extra_file2, extra_file3,
                           extra_file4, extra_file5]:
            if extra_file != "None":
                if interp_method != "ncremap":
                    sys.exit("Error: interpolating ancillary files is only "
                             "supported when interp_method=ncremap")
                dst_file = \
                    f'{dest_file_rootname}_{os.path.basename(extra_file)}'
                _remap_with_ncremap(extra_file,
                                    dst_file,
                                    mapping_filename,
                                    logger, tmpdir)
                logger.info(f'Created {dst_file}')
def _remap_with_ncremap(src_path, dst_file, mapping_filename, logger,
                        tmpdir='.'):
    """
    Remap a file onto the destination mesh using ncremap.

    Parameters
    ----------
    src_path : str
        path to source file

    dst_file : str
        name of the destination file that should be created

    mapping_filename : str
        name of already generated mapping file

    logger : logging.Logger
        logger object

    tmpdir : str, optional
        temp dir to write intermediate files
    """
    src_file_rootname = os.path.basename(src_path).rsplit('.nc', 1)[0]
    # ncremap requires the spatial dimension to be the last one,
    # which MALI does not exclusively follow. So we have to
    # permute dimensions before calling ncremap, and then permute back
    args = ['ncpdq', '-O', '-a',
            'Time,nVertInterfaces,nVertLevels,nRegions,nISMIP6OceanLayers,nEdges,nCells',  # noqa
            src_path, os.path.join(tmpdir, f'{src_file_rootname}_permuted.nc')]
    check_call(args, logger=logger)
    # remap the permuted file with the pre-built mapping file
    args = ['ncremap', '-m', mapping_filename,
            os.path.join(tmpdir, f'{src_file_rootname}_permuted.nc'),
            os.path.join(tmpdir, f'{dst_file}_permuted.nc')]
    check_call(args, logger=logger)
    # permute dimensions back to MALI's expected ordering
    args = ['ncpdq', '-O', '-a',
            'Time,nCells,nEdges,nVertInterfaces,nVertLevels,nRegions,nISMIP6OceanLayers',  # noqa
            os.path.join(tmpdir, f'{dst_file}_permuted.nc'),
            os.path.join(tmpdir, f'{dst_file}_extra_var.nc')]
    check_call(args, logger=logger)
    # drop some extra vars that ncremap adds; errors='ignore' so that we
    # don't crash if a particular ncremap version did not add one of them
    ds_out = xarray.open_dataset(os.path.join(tmpdir,
                                              f'{dst_file}_extra_var.nc'))
    ds_out = ds_out.drop_vars(['lat', 'lon', 'lat_vertices',
                               'lon_vertices', 'area'], errors='ignore')
    # drop variables on vertices or edges, which will not have been
    # remapped properly (the mapping file is cell-based)
    drop_list = []
    for varname, da in ds_out.data_vars.items():
        if 'nVertices' in da.dims or 'nEdges' in da.dims:
            drop_list.append(varname)
    ds_out = ds_out.drop_vars(drop_list)
    write_netcdf(ds_out, dst_file)