"""Analysis step for the planar-convergence horizontal advection test case.

Module: compass.ocean.tests.planar_convergence.horizontal_advection.analysis
"""

import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import warnings

from compass.ocean.tests.planar_convergence.conv_analysis import ConvAnalysis


class Analysis(ConvAnalysis):
    """
    A step for visualizing the output from the advection convergence test
    case
    """

    def __init__(self, test_case, resolutions):
        """
        Create the step

        Parameters
        ----------
        test_case : compass.TestCase
            The test case this step belongs to

        resolutions : list of int
            The resolutions of the meshes that have been run
        """
        super().__init__(test_case=test_case, resolutions=resolutions)
        self.resolutions = resolutions
        self.add_output_file('convergence.png')

    def run(self):
        """
        Run this step of the test case

        Computes the RMS error of ``tracer1`` for each resolution, fits a
        line to log10(error) vs. log10(number of cells), plots the data and
        the fit to ``convergence.png``, and checks the resulting order of
        convergence against the ``conv_thresh`` and ``conv_max`` options of
        the ``horizontal_advection`` config section.

        Raises
        ------
        ValueError
            If fewer than two resolutions are available (no slope can be
            fit), or if the order of convergence is below ``conv_thresh``
        """
        plt.switch_backend('Agg')
        resolutions = self.resolutions
        if len(resolutions) < 2:
            raise ValueError(
                f'at least two resolutions are needed to compute an order '
                f'of convergence, got {len(resolutions)}')

        results = [self.rmse(res, variable='tracer1') for res in resolutions]
        errors = np.array([rms_error for rms_error, _ in results])
        ncells = np.array([n for _, n in results])

        # Linear fit in log-log space: log10(error) = p[0]*log10(ncells) + p[1]
        p = np.polyfit(np.log10(ncells), np.log10(errors), 1)
        # Assumes a planar mesh where ncells scales like 1/dx**2, so an error
        # ~ dx**k gives a slope of -k/2 in ncells; hence the factor of 2.
        conv = abs(p[0]) * 2.0

        error_fit = ncells**p[0] * 10**p[1]

        # Use a dedicated figure (and close it) so repeated calls do not draw
        # on top of a stale current figure or leak figures.
        fig = plt.figure()
        plt.loglog(ncells, error_fit, 'k')
        plt.loglog(ncells, errors, 'or')
        plt.annotate(f'Order of Convergence = {np.round(conv, 3)}',
                     xycoords='axes fraction', xy=(0.3, 0.95), fontsize=14)
        plt.xlabel('Number of Grid Cells', fontsize=14)
        plt.ylabel('L2 Norm', fontsize=14)
        plt.savefig('convergence.png', bbox_inches='tight', pad_inches=0.1)
        plt.close(fig)

        section = self.config['horizontal_advection']
        conv_thresh = section.getfloat('conv_thresh')
        conv_max = section.getfloat('conv_max')

        if conv < conv_thresh:
            raise ValueError(f'order of convergence {conv} < min tolerance '
                             f'{conv_thresh}')

        if conv > conv_max:
            warnings.warn(f'order of convergence {conv} > max tolerance '
                          f'{conv_max}')

    def rmse(self, resolution, variable):
        """
        Compute the RMSE for a given resolution

        The error is the root-mean-square of the difference between the last
        and first time slices of ``variable``, averaged over all of the
        variable's remaining dimensions.

        Parameters
        ----------
        resolution : int
            The resolution of the (uniform) mesh in km

        variable : str
            The name of a variable in the output file to analyze.

        Returns
        -------
        rms_error : float
            The root-mean-squared error

        ncells : int
            The number of cells in the mesh
        """
        res_tag = f'{resolution}km'
        # Context manager ensures the NetCDF file handle is released.
        with xr.open_dataset(f'{res_tag}_output.nc') as ds:
            init = ds[variable].isel(Time=0)
            final = ds[variable].isel(Time=-1)
            diff = final - init
            rms_error = float(np.sqrt((diff**2).mean()))
            ncells = ds.sizes['nCells']
        return rms_error, ncells