# Source code for compass.ocean.tests.utility.extrap_woa.extrap_step
import os
from functools import partial
from multiprocessing import Pool
import numpy as np
import progressbar
import xarray as xr
from scipy.signal import convolve2d
from compass.step import Step
class ExtrapStep(Step):
    """
    Extrapolate WOA 2023 data into missing ocean regions, then land and
    grounded ice

    Attributes
    ----------
    woa_filename : str
        The name of the output file name after extrapolation
    """

    def __init__(self, test_case):
        """
        Create a new test case

        Parameters
        ----------
        test_case : compass.ocean.tests.utility.extrap_woa.ExtrapWoa
            The test case this step belongs to
        """
        super().__init__(test_case=test_case, name='extrap', cpus_per_task=64,
                         min_cpus_per_task=1, openmp_threads=1)

        self.add_input_file(
            filename='woa.nc',
            target='../combine/woa_combined.nc')

        self.add_input_file(
            filename='topography.nc',
            target='../remap_topography/topography_remapped.nc')

        # determined in setup()
        self.woa_filename = None

    def setup(self):
        """
        Determine the output filename
        """
        self.woa_filename = 'woa23_decav_0.25_extrap.nc'
        self.add_output_file(self.woa_filename)

    def run(self):
        """
        Extrapolate WOA 2023 model temperature and salinity into ice-shelf
        cavities.
        """
        pool = Pool(self.cpus_per_task)
        try:
            self._make_3d_ocean_mask()

            # extrapolate horizontally using the ocean mask
            self._extrap_horiz(use_ocean_mask=True, pool=pool)
            # extrapolate vertically using the ocean mask
            self._extrap_vert(use_ocean_mask=True)

            # extrapolate horizontally into land and grounded ice
            self._extrap_horiz(use_ocean_mask=False, pool=pool)
            # extrapolate vertically into land and grounded ice
            self._extrap_vert(use_ocean_mask=False)
        finally:
            # make sure worker processes don't linger if a step raises
            pool.terminate()

    @staticmethod
    def _make_3d_ocean_mask():
        """
        Combine the remapped bathymetry and ocean fraction with the WOA
        depth bounds to build a 3D (depth, lat, lon) ocean mask, written
        to ``ocean_mask.nc``
        """
        grid_filename = 'woa.nc'
        topo_filename = 'topography.nc'
        out_filename = 'ocean_mask.nc'

        with xr.open_dataset(topo_filename) as ds_topo:
            bathymetry = ds_topo.bathymetry
            ocean_mask = ds_topo.ocean_mask

        ds_out = xr.Dataset()
        with xr.open_dataset(grid_filename) as ds_grid:
            for var in ['lon', 'lat', 'depth']:
                ds_out[var] = ds_grid[var]
                ds_out[f'{var}_bnds'] = ds_grid[f'{var}_bnds']
            # elevation of the top of each WOA depth layer
            z_top = -ds_grid.depth_bnds.isel(nbounds=0)

            # a cell is ocean if the seafloor is at or below the top of
            # the layer and at least half the cell area is ocean
            ocean_mask_3d = np.logical_and(
                bathymetry <= z_top,
                ocean_mask >= 0.5).astype(int)
            ocean_mask_3d = \
                ocean_mask_3d.transpose('depth', 'lat', 'lon')

            ds_out['ocean_mask'] = ocean_mask_3d
            ds_out.to_netcdf(out_filename)

    def _extrap_horiz(self, use_ocean_mask, pool):
        """
        Horizontally extrapolate ``pt_an`` and ``s_an``, processing depth
        levels in parallel and combining the per-level results

        Parameters
        ----------
        use_ocean_mask : bool
            Whether to restrict extrapolation to the 3D ocean mask (first
            pass) or extrapolate everywhere (into land and grounded ice)

        pool : multiprocessing.Pool
            The pool of workers used to process depth levels in parallel
        """
        logger = self.logger
        if use_ocean_mask:
            in_filename = 'woa.nc'
            out_filename = 'extrap_ocean/woa_extrap_horiz.nc'
            progress_dir = 'extrap_ocean/extrap_horiz'
        else:
            in_filename = 'extrap_ocean/woa_extrap.nc'
            out_filename = 'extrap_land/woa_extrap_horiz.nc'
            progress_dir = 'extrap_land/extrap_horiz'

        # idiomatic replacement for try/except FileExistsError
        os.makedirs(progress_dir, exist_ok=True)

        with xr.open_dataset(in_filename) as ds_woa:
            ndepth = ds_woa.sizes['depth']
            dims = ds_woa.pt_an.dims

        logger.info(' Horizontally extrapolating WOA data...')

        # only show a progress bar when logging to the console
        progress = self.log_filename is None

        if progress:
            widgets = [' ', progressbar.Percentage(), ' ',
                       progressbar.Bar(), ' ', progressbar.ETA()]
            bar = progressbar.ProgressBar(widgets=widgets,
                                          maxval=ndepth).start()
        else:
            bar = None

        partial_func = partial(_extrap_level, use_ocean_mask, in_filename,
                               progress_dir)
        depth_indices = range(ndepth)

        files = list()
        for depth_index, tmp_filename in enumerate(
                pool.imap(partial_func, depth_indices)):
            files.append(tmp_filename)
            if progress:
                bar.update(depth_index + 1)

        if progress:
            bar.finish()

        # stitch the per-level files back together along the depth dim
        ds_out = xr.open_mfdataset(files, combine='nested', concat_dim='depth')
        for field_name in ['pt_an', 's_an']:
            # restore the original dimension order
            ds_out[field_name] = ds_out[field_name].transpose(*dims)
        ds_out.to_netcdf(out_filename)

    def _extrap_vert(self, use_ocean_mask):
        """
        Fill remaining invalid cells by copying values down from the level
        above, one level at a time

        Parameters
        ----------
        use_ocean_mask : bool
            Whether to restrict filling to the 3D ocean mask (first pass)
            or fill everywhere (into land and grounded ice)
        """
        logger = self.logger
        if use_ocean_mask:
            in_filename = 'extrap_ocean/woa_extrap_horiz.nc'
            out_filename = 'extrap_ocean/woa_extrap.nc'
        else:
            in_filename = 'extrap_land/woa_extrap_horiz.nc'
            out_filename = 'woa23_decav_0.25_extrap.nc'

        ds = xr.open_dataset(in_filename)

        if use_ocean_mask:
            ds_mask = xr.open_dataset('ocean_mask.nc')
            ocean_mask = ds_mask.ocean_mask
        else:
            ocean_mask = None

        ndepth = ds.sizes['depth']

        logger.info(' Vertically extrapolating WOA data...')

        # only show a progress bar when logging to the console
        progress = self.log_filename is None
        if progress:
            widgets = [f' pt_an z=1/{ndepth}: ',
                       progressbar.Percentage(), ' ',
                       progressbar.Bar(), ' ', progressbar.ETA()]
            # 2 fields (pt_an and s_an) times ndepth levels
            bar = progressbar.ProgressBar(widgets=widgets,
                                          maxval=2 * ndepth).start()
        else:
            bar = None

        count = 0
        for field_name in ['pt_an', 's_an']:
            # the top level is left as-is
            slices = [ds[field_name].isel(depth=0).drop_vars(['depth'])]
            for depth_index in range(1, ndepth):
                field = ds[field_name]
                field_above = field.isel(depth=depth_index - 1)
                field_local = field.isel(depth=depth_index)
                mask = field_local.isnull()
                if ocean_mask is not None:
                    # only fill cells that are ocean at this level
                    mask = np.logical_and(
                        mask, ocean_mask.isel(depth=depth_index))
                # copy the level above into invalid cells at this level
                field_local = xr.where(mask, field_above, field_local)
                slices.append(field_local)
                count += 1
                if progress:
                    bar.widgets[0] = \
                        f' {field_name} z={depth_index + 1}/{ndepth}: '
                    bar.update(count)

            field = xr.concat(slices, dim='depth')
            attrs = ds[field_name].attrs
            dims = ds[field_name].dims
            ds[field_name] = field.transpose(*dims)
            # xr.where() and concat drop attributes, so restore them
            ds[field_name].attrs = attrs

        if progress:
            bar.finish()

        ds.to_netcdf(out_filename)
def _extrap_level(use_ocean_mask, in_filename, progress_dir, depth_index):
    """
    Horizontally extrapolate ``pt_an`` and ``s_an`` at a single depth level
    by repeated smoothing with a small Gaussian kernel, writing the result
    to a per-level file in ``progress_dir``.

    Parameters
    ----------
    use_ocean_mask : bool
        If ``True``, only extrapolate into cells flagged as ocean in
        ``ocean_mask.nc``; otherwise, extrapolate everywhere (into land
        and grounded ice)

    in_filename : str
        The file containing the fields to extrapolate

    progress_dir : str
        The directory where the per-level output file is written

    depth_index : int
        The index of the depth level to extrapolate

    Returns
    -------
    out_filename : str
        The name of the per-level output file
    """
    out_filename = os.path.join(progress_dir, f'woa_lev_{depth_index}.nc')
    ds = xr.open_dataset(in_filename).isel(depth=depth_index)
    if use_ocean_mask:
        ds_mask = xr.open_dataset('ocean_mask.nc').isel(depth=depth_index)
        ocean_mask = ds_mask.ocean_mask.values
    else:
        ocean_mask = None
    # validity is determined from pt_an alone; presumably s_an is valid in
    # the same cells -- TODO confirm
    field = ds.pt_an.values
    # a small averaging kernel
    x = np.arange(-1, 2)
    x, y = np.meshgrid(x, x)
    kernel = np.exp(-0.5 * (x**2 + y**2))
    # a threshold for extrapolation weights to be considered valid
    threshold = 0.01
    valid = np.isfinite(field)
    # keep the original validity mask so original data is never overwritten
    orig_mask = valid
    if ocean_mask is not None:
        # cells that are neither valid data nor ocean must never be filled
        invalid_after_fill = np.logical_not(np.logical_or(valid, ocean_mask))
    else:
        invalid_after_fill = None
    fields = dict(pt_an=ds.pt_an.values.copy(),
                  s_an=ds.s_an.values.copy())
    nlon = field.shape[1]
    # a 2-column periodic halo in longitude so smoothing wraps around the
    # globe; lon_no_halo recovers the original columns afterward
    lon_with_halo = np.array([nlon - 2, nlon - 1] + list(range(nlon)) + [0, 1])
    lon_no_halo = list(range(2, nlon + 2))
    prev_fill_count = 0
    while True:
        # sum of kernel weights contributed by valid neighbors of each cell
        valid_weight_sum = _extrap_with_halo(valid, kernel, valid,
                                             lon_with_halo, lon_no_halo)
        if invalid_after_fill is not None:
            valid_weight_sum[invalid_after_fill] = 0.
        new_valid = valid_weight_sum > threshold
        # don't want to overwrite original data but do want to smooth
        # extrapolated data
        fill_mask = np.logical_and(new_valid, np.logical_not(orig_mask))
        fill_count = np.count_nonzero(fill_mask)
        if fill_count == prev_fill_count:
            # no change so we're done
            break
        for field_name, field in fields.items():
            field_extrap = _extrap_with_halo(field, kernel, valid,
                                             lon_with_halo, lon_no_halo)
            # normalized weighted average of the valid neighbors
            field[fill_mask] = \
                field_extrap[fill_mask] / valid_weight_sum[fill_mask]
        valid = new_valid
        prev_fill_count = fill_count
    for field_name, field in fields.items():
        if invalid_after_fill is not None:
            # mask out cells that should remain unfilled
            field[invalid_after_fill] = np.nan
        attrs = ds[field_name].attrs
        dims = ds[field_name].dims
        ds[field_name] = (dims, field)
        # restore attributes lost when replacing the variable
        ds[field_name].attrs = attrs
    ds.to_netcdf(out_filename)
    return out_filename
def _extrap_with_halo(field, kernel, valid, lon_with_halo, lon_no_halo):
field = field.copy()
field[np.logical_not(valid)] = 0.
field_with_halo = field[:, lon_with_halo]
field_extrap = convolve2d(field_with_halo, kernel, mode='same')
field_extrap = field_extrap[:, lon_no_halo]
return field_extrap