# Source code for roger.models.svat_oxygen18.svat_oxygen18

from pathlib import Path
import h5netcdf
import numpy as onp
from roger import RogerSetup, roger_routine, roger_kernel, KernelOutput
from roger.variables import allocate
from roger.core.operators import numpy as npx, update, at, for_loop
from roger.core.transport import delta_to_conc, conc_to_delta


class SVATTRANSPORTSetup(RogerSetup):
    """A SVAT oxygen-18 transport model.

    Hydrological states and fluxes are read offline from ``states_hm.nc``
    (located in the directory of this module), the oxygen-18 input signal
    (variable ``d18O``) from ``input/forcing_tracer.nc``.
    """

    # custom attributes required by helper functions
    _base_path = Path(__file__).parent
    _tm_structure = "complete-mixing"
    _input_dir = _base_path / "input"
    _identifier = "SVATOXYGEN18_complete-mixing"
    _sas_solver = "deterministic"

    # custom helper functions
    def _read_var_from_nc(self, var, path_dir, file):
        """Return the complete variable ``var`` of the netCDF file ``path_dir / file`` as an array."""
        nc_file = path_dir / file
        with h5netcdf.File(nc_file, "r", decode_vlen_strings=False) as infile:
            var_obj = infile.variables[var]
            return npx.array(var_obj)

    def _get_nitt(self, path_dir, file):
        """Return the number of entries of the ``Time`` variable plus one."""
        nc_file = path_dir / file
        with h5netcdf.File(nc_file, "r", decode_vlen_strings=False) as infile:
            var_obj = infile.variables["Time"]
            return len(onp.array(var_obj)) + 1

    def _get_runlen(self, path_dir, file):
        """Return the simulation length in seconds (one day per ``Time`` entry)."""
        nc_file = path_dir / file
        with h5netcdf.File(nc_file, "r", decode_vlen_strings=False) as infile:
            var_obj = infile.variables["Time"]
            return len(onp.array(var_obj)) * 60 * 60 * 24

    def _get_time_origin(self, path_dir, file):
        """Return the date part of the ``time_origin`` attribute of ``Time`` with a midnight time stamp."""
        nc_file = path_dir / file
        with h5netcdf.File(nc_file, "r", decode_vlen_strings=False) as infile:
            date = infile.variables["Time"].attrs["time_origin"].split(" ")[0]
            return f"{date} 00:00:00"

    def _set_identifier(self, identifier):
        self._identifier = identifier

    def _set_tm_structure(self, tm_structure):
        self._tm_structure = tm_structure

    def _set_sas_solver(self, sas_solver):
        self._sas_solver = sas_solver

    def _bfill_3d(self, state, arr):
        """Backward-fill non-finite values of ``arr`` along its last (time) axis.

        Each non-finite entry of the interior cells ``[2:-2, 2:-2]`` is replaced
        by the next finite entry further along the last axis (the filling itself
        is done by the module-level ``_bfill`` kernel on an index array).
        Entries without any later finite value remain non-finite.

        Returns a new ``("x", "y", "t")`` array; ``arr`` is not modified.
        """
        idx = allocate(state.dimensions, ("x", "y", "t"), dtype=int)
        arr1 = allocate(state.dimensions, ("x", 1, 1), dtype=int)
        arr2 = allocate(state.dimensions, (1, "y", 1), dtype=int)
        arr3 = allocate(state.dimensions, ("x", "y", "t"), dtype=int)
        arr_fill = allocate(state.dimensions, ("x", "y", "t"))

        # Store (time index + 1) for finite entries so that 0 unambiguously marks
        # a missing value for ``_bfill``. Storing the plain time index would make
        # a finite value at t=0 indistinguishable from a missing one.
        time_idx = npx.arange(npx.shape(arr)[2])[npx.newaxis, npx.newaxis, :] + 1
        idx = update(
            idx,
            at[2:-2, 2:-2, :],
            npx.where(npx.isfinite(arr), time_idx, 0)[2:-2, 2:-2, :],
        )
        idx = update(
            idx,
            at[2:-2, 2:-2, :],
            _bfill(idx)[2:-2, 2:-2, :],
        )
        arr1 = update(
            arr1,
            at[:, 0, 0],
            npx.arange(npx.shape(arr)[0]),
        )
        arr2 = update(
            arr2,
            at[0, :, 0],
            npx.arange(npx.shape(arr)[1]),
        )
        # undo the +1 offset; entries that could not be filled become -1, which
        # selects the (non-finite) last element of the time axis
        arr3 = update(
            arr3,
            at[2:-2, 2:-2, :],
            idx[2:-2, 2:-2, :] - 1,
        )
        # gather arr[x, y, arr3[x, y, t]] via broadcast integer indexing
        arr_fill = update(
            arr_fill,
            at[:, :, :],
            arr[arr1, arr2, arr3],
        )
        return arr_fill

    @roger_routine
    def set_settings(self, state):
        """Configure solver, grid size, time span and transport switches."""
        settings = state.settings
        settings.identifier = self._identifier
        # set the solver scheme
        settings.sas_solver = self._sas_solver
        # number of substeps
        settings.sas_solver_substeps = 6
        if settings.sas_solver in ["RK4", "Euler"]:
            # time increment of substep (in days)
            settings.h = 1 / settings.sas_solver_substeps
        # output frequency (in seconds)
        settings.output_frequency = 86400
        # total grid numbers in x- and y-direction
        settings.nx, settings.ny = 1, 1
        # number of iterations (i.e. number of days)
        settings.nitt = self._get_nitt(self._input_dir, "forcing_tracer.nc")
        # maximum water age (in days)
        settings.ages = settings.nitt
        settings.nages = settings.ages + 1
        # length of simulation (in seconds)
        settings.runlen = self._get_runlen(self._input_dir, "forcing_tracer.nc")

        # spatial discretization (in meters)
        settings.dx = 1
        settings.dy = 1

        # origin of spatial grid
        settings.x_origin = 0.0
        settings.y_origin = 0.0
        # origin of time steps (e.g. 01-01-2023)
        settings.time_origin = self._get_time_origin(self._input_dir, "forcing_tracer.nc")

        # enable transport
        settings.enable_offline_transport = True
        # enable oxygen-18
        settings.enable_oxygen18 = True
        # set model structure
        settings.tm_structure = self._tm_structure
        # enable calculation of age statistic
        settings.enable_age_statistics = True

    @roger_routine
    def set_grid(self, state):
        """Set the daily time step, the age axes and the x/y coordinates."""
        vs = state.variables
        settings = state.settings

        # temporal grid
        vs.dt_secs = 60 * 60 * 24
        vs.dt = 60 * 60 * 24 / (60 * 60)
        vs.ages = update(vs.ages, at[:], npx.arange(1, settings.nages))
        vs.nages = update(vs.nages, at[:], npx.arange(settings.nages))
        # spatial grid
        dx = allocate(state.dimensions, ("x"))
        dx = update(dx, at[:], settings.dx)
        dy = allocate(state.dimensions, ("y"))
        dy = update(dy, at[:], settings.dy)
        # distance from origin
        vs.x = update(vs.x, at[3:-2], settings.x_origin + npx.cumsum(dx[3:-2]))
        vs.y = update(vs.y, at[3:-2], settings.y_origin + npx.cumsum(dy[3:-2]))

    @roger_routine
    def set_look_up_tables(self, state):
        pass

    @roger_routine
    def set_topography(self, state):
        pass

    @roger_routine(
        dist_safe=False,
        local_variables=[
            "S_pwp_rz",
            "S_pwp_ss",
            "S_sat_rz",
            "S_sat_ss",
            "sas_params_evap_soil",
            "sas_params_cpr_rz",
            "sas_params_transp",
            "sas_params_q_rz",
            "sas_params_q_ss",
            "itt",
        ],
    )
    def set_parameters_setup(self, state):
        """Read ``S_pwp_*``/``S_sat_*`` at iteration ``itt`` and set the SAS parameters.

        Entry 0 of every ``sas_params_*`` array is set to 6 and entry 1 to a
        flux-specific value.
        """
        vs = state.variables

        vs.S_pwp_rz = update(
            vs.S_pwp_rz,
            at[2:-2, 2:-2],
            self._read_var_from_nc("S_pwp_rz", self._base_path, "states_hm.nc")[:, :, vs.itt],
        )
        vs.S_pwp_ss = update(
            vs.S_pwp_ss,
            at[2:-2, 2:-2],
            self._read_var_from_nc("S_pwp_ss", self._base_path, "states_hm.nc")[:, :, vs.itt],
        )
        vs.S_sat_rz = update(
            vs.S_sat_rz,
            at[2:-2, 2:-2],
            self._read_var_from_nc("S_sat_rz", self._base_path, "states_hm.nc")[:, :, vs.itt],
        )
        vs.S_sat_ss = update(
            vs.S_sat_ss,
            at[2:-2, 2:-2],
            self._read_var_from_nc("S_sat_ss", self._base_path, "states_hm.nc")[:, :, vs.itt],
        )

        # SAS parameterization
        vs.sas_params_evap_soil = update(vs.sas_params_evap_soil, at[2:-2, 2:-2, 0], 6)
        vs.sas_params_evap_soil = update(vs.sas_params_evap_soil, at[2:-2, 2:-2, 1], 0.1)
        vs.sas_params_cpr_rz = update(vs.sas_params_cpr_rz, at[2:-2, 2:-2, 0], 6)
        vs.sas_params_cpr_rz = update(vs.sas_params_cpr_rz, at[2:-2, 2:-2, 1], 0.1)
        vs.sas_params_transp = update(vs.sas_params_transp, at[2:-2, 2:-2, 0], 6)
        vs.sas_params_transp = update(vs.sas_params_transp, at[2:-2, 2:-2, 1], 0.3)
        vs.sas_params_q_rz = update(vs.sas_params_q_rz, at[2:-2, 2:-2, 0], 6)
        vs.sas_params_q_rz = update(vs.sas_params_q_rz, at[2:-2, 2:-2, 1], 2)
        vs.sas_params_q_ss = update(vs.sas_params_q_ss, at[2:-2, 2:-2, 0], 6)
        vs.sas_params_q_ss = update(vs.sas_params_q_ss, at[2:-2, 2:-2, 1], 3)

    @roger_routine
    def set_parameters(self, state):
        pass

    @roger_routine(
        dist_safe=False,
        local_variables=["S_snow", "S_rz", "S_rz_init", "S_ss", "S_ss_init", "S_s", "itt", "taup1"],
    )
    def set_initial_conditions_setup(self, state):
        """Initialise the storages from ``states_hm.nc`` at iteration ``itt``."""
        vs = state.variables

        vs.S_snow = update(
            vs.S_snow,
            at[2:-2, 2:-2, : vs.taup1],
            self._read_var_from_nc("S_snow", self._base_path, "states_hm.nc")[:, :, vs.itt],
        )
        vs.S_rz = update(
            vs.S_rz,
            at[2:-2, 2:-2, : vs.taup1],
            self._read_var_from_nc("S_rz", self._base_path, "states_hm.nc")[:, :, vs.itt],
        )
        vs.S_ss = update(
            vs.S_ss,
            at[2:-2, 2:-2, : vs.taup1],
            self._read_var_from_nc("S_ss", self._base_path, "states_hm.nc")[:, :, vs.itt],
        )
        vs.S_s = update(
            vs.S_s,
            at[2:-2, 2:-2, : vs.taup1],
            vs.S_rz[2:-2, 2:-2, : vs.taup1] + vs.S_ss[2:-2, 2:-2, : vs.taup1],
        )
        vs.S_rz_init = update(vs.S_rz_init, at[2:-2, 2:-2], vs.S_rz[2:-2, 2:-2, 0])
        vs.S_ss_init = update(vs.S_ss_init, at[2:-2, 2:-2], vs.S_ss[2:-2, 2:-2, 0])

    @roger_routine
    def set_initial_conditions(self, state):
        """Initialise the StorAge distributions and the initial isotope concentrations.

        Storage is spread uniformly over the age classes (consecutive
        differences of a linear ramp from 0 to the storage at ``tau``).
        """
        vs = state.variables
        settings = state.settings

        arr0 = allocate(state.dimensions, ("x", "y"))
        vs.sa_rz = update(
            vs.sa_rz,
            at[2:-2, 2:-2, : vs.taup1, 1:],
            npx.diff(npx.linspace(arr0[2:-2, 2:-2], vs.S_rz[2:-2, 2:-2, vs.tau], settings.ages, axis=-1), axis=-1)[
                :, :, npx.newaxis, :
            ],
        )
        vs.sa_ss = update(
            vs.sa_ss,
            at[2:-2, 2:-2, : vs.taup1, 1:],
            npx.diff(npx.linspace(arr0[2:-2, 2:-2], vs.S_ss[2:-2, 2:-2, vs.tau], settings.ages, axis=-1), axis=-1)[
                :, :, npx.newaxis, :
            ],
        )

        vs.SA_rz = update(
            vs.SA_rz,
            at[2:-2, 2:-2, :, 1:],
            npx.cumsum(vs.sa_rz[2:-2, 2:-2, :, :], axis=-1),
        )

        # cumulative sum of sa_ss (previously summed sa_rz, which made SA_ss a
        # copy of SA_rz)
        vs.SA_ss = update(
            vs.SA_ss,
            at[2:-2, 2:-2, :, 1:],
            npx.cumsum(vs.sa_ss[2:-2, 2:-2, :, :], axis=-1),
        )

        vs.sa_s = update(
            vs.sa_s,
            at[2:-2, 2:-2, :, :],
            vs.sa_rz[2:-2, 2:-2, :, :] + vs.sa_ss[2:-2, 2:-2, :, :],
        )
        vs.SA_s = update(
            vs.SA_s,
            at[2:-2, 2:-2, :, 1:],
            npx.cumsum(vs.sa_s[2:-2, 2:-2, :, :], axis=-1),
        )

        if settings.enable_oxygen18:
            # NOTE(review): only C_iso_snow is set to NaN here; C_snow itself is
            # not initialised in this routine, while set_forcing tests
            # npx.isnan(C_snow) -- confirm the variable's default value.
            vs.C_iso_snow = update(vs.C_iso_snow, at[2:-2, 2:-2, : vs.taup1], npx.nan)
            vs.C_iso_rz = update(vs.C_iso_rz, at[2:-2, 2:-2, : vs.taup1], -13)
            vs.C_iso_ss = update(vs.C_iso_ss, at[2:-2, 2:-2, : vs.taup1], -7)
            vs.C_rz = update(
                vs.C_rz,
                at[2:-2, 2:-2, : vs.taup1],
                delta_to_conc(state, vs.C_iso_rz[2:-2, 2:-2, vs.tau, npx.newaxis]),
            )
            vs.msa_rz = update(
                vs.msa_rz,
                at[2:-2, 2:-2, : vs.taup1, :],
                vs.C_rz[2:-2, 2:-2, : vs.taup1, npx.newaxis],
            )
            vs.msa_rz = update(
                vs.msa_rz,
                at[2:-2, 2:-2, : vs.taup1, 0],
                0,
            )
            vs.C_ss = update(
                vs.C_ss,
                at[2:-2, 2:-2, : vs.taup1],
                delta_to_conc(state, vs.C_iso_ss[2:-2, 2:-2, vs.tau, npx.newaxis]),
            )
            vs.msa_ss = update(
                vs.msa_ss,
                at[2:-2, 2:-2, : vs.taup1, :],
                vs.C_ss[2:-2, 2:-2, : vs.taup1, npx.newaxis],
            )
            vs.msa_ss = update(
                vs.msa_ss,
                at[2:-2, 2:-2, : vs.taup1, 0],
                0,
            )
            # storage-weighted mixture of the rz and ss age-class concentrations
            vs.msa_s = update(
                vs.msa_s,
                at[2:-2, 2:-2, :, :],
                npx.where(
                    vs.sa_rz[2:-2, 2:-2, :, :] + vs.sa_ss[2:-2, 2:-2, :, :] > 0,
                    vs.msa_rz[2:-2, 2:-2, :, :]
                    * (vs.sa_rz[2:-2, 2:-2, :, :] / (vs.sa_rz[2:-2, 2:-2, :, :] + vs.sa_ss[2:-2, 2:-2, :, :]))
                    + vs.msa_ss[2:-2, 2:-2, :, :]
                    * (vs.sa_ss[2:-2, 2:-2, :, :] / (vs.sa_rz[2:-2, 2:-2, :, :] + vs.sa_ss[2:-2, 2:-2, :, :])),
                    0,
                ),
            )
            vs.msa_s = update(
                vs.msa_s,
                at[2:-2, 2:-2, : vs.taup1, 0],
                0,
            )
            # bulk concentration: age-class concentrations weighted by their storage share
            vs.C_s = update(
                vs.C_s,
                at[2:-2, 2:-2, vs.tau],
                npx.sum(
                    npx.where(
                        vs.sa_s[2:-2, 2:-2, vs.tau, :] > 0,
                        vs.msa_s[2:-2, 2:-2, vs.tau, :]
                        * (
                            vs.sa_s[2:-2, 2:-2, vs.tau, :]
                            / npx.sum(vs.sa_s[2:-2, 2:-2, vs.tau, :], axis=-1)[:, :, npx.newaxis]
                        ),
                        0,
                    ),
                    axis=-1,
                ),
            )
            vs.C_s = update(
                vs.C_s,
                at[2:-2, 2:-2, vs.taum1],
                vs.C_s[2:-2, 2:-2, vs.tau] * vs.maskCatch[2:-2, 2:-2],
            )
            vs.C_iso_s = update(
                vs.C_iso_s,
                at[2:-2, 2:-2, vs.taum1],
                conc_to_delta(state, vs.C_s[2:-2, 2:-2, vs.tau]) * vs.maskCatch[2:-2, 2:-2],
            )
            vs.C_iso_s = update(
                vs.C_iso_s,
                at[2:-2, 2:-2, vs.tau],
                conc_to_delta(state, vs.C_s[2:-2, 2:-2, vs.tau]) * vs.maskCatch[2:-2, 2:-2],
            )
            vs.csa_rz = update(
                vs.csa_rz,
                at[2:-2, 2:-2, vs.tau, :],
                conc_to_delta(state, vs.msa_rz[2:-2, 2:-2, vs.tau, :]),
            )
            vs.csa_ss = update(
                vs.csa_ss,
                at[2:-2, 2:-2, vs.tau, :],
                conc_to_delta(state, vs.msa_ss[2:-2, 2:-2, vs.tau, :]),
            )
            vs.csa_s = update(
                vs.csa_s,
                at[2:-2, 2:-2, vs.tau, :],
                conc_to_delta(state, vs.msa_s[2:-2, 2:-2, vs.tau, :]),
            )

    @roger_routine
    def set_boundary_conditions_setup(self, state):
        pass

    @roger_routine
    def set_boundary_conditions(self, state):
        pass

    @roger_routine(
        dist_safe=False,
        local_variables=[
            "C_ISO_IN",
            "C_IN",
        ],
    )
    def set_forcing_setup(self, state):
        """Load the d18O input series, backward-fill its gaps and convert it to concentrations."""
        vs = state.variables
        settings = state.settings

        if settings.enable_oxygen18:
            # slot 0 has no observation; the series from the file starts at slot 1
            vs.C_ISO_IN = update(vs.C_ISO_IN, at[2:-2, 2:-2, 0], npx.nan)
            vs.C_ISO_IN = update(
                vs.C_ISO_IN, at[2:-2, 2:-2, 1:], self._read_var_from_nc("d18O", self._input_dir, "forcing_tracer.nc")
            )
            vs.C_ISO_IN = update(vs.C_ISO_IN, at[2:-2, 2:-2, :], self._bfill_3d(state, vs.C_ISO_IN)[2:-2, 2:-2, :])
            vs.C_IN = update(vs.C_IN, at[2:-2, 2:-2, :], delta_to_conc(state, vs.C_ISO_IN)[2:-2, 2:-2, :])

    @roger_routine(
        dist_safe=False,
        local_variables=[
            "ta",
            "prec",
            "inf_mat_rz",
            "inf_pf_rz",
            "inf_pf_ss",
            "transp",
            "evap_soil",
            "cpr_rz",
            "q_rz",
            "q_ss",
            "S_rz",
            "S_ss",
            "S_s",
            "S_snow",
            "tau",
            "taum1",
            "itt",
            "C_in",
            "C_iso_in",
            "C_IN",
            "C_snow",
            "C_iso_snow",
        ],
    )
    def set_forcing(self, state):
        """Read the hydrological fluxes/states of iteration ``itt`` and set the isotope input.

        The input concentration of the current step is the snow-pack
        concentration of the previous step when that is finite, otherwise the
        precipitation concentration if it rained, otherwise 0.
        """
        vs = state.variables

        vs.ta = update(
            vs.ta, at[2:-2, 2:-2], self._read_var_from_nc("ta", self._base_path, "states_hm.nc")[:, :, vs.itt]
        )
        vs.prec = update(
            vs.prec,
            at[2:-2, 2:-2, vs.tau],
            self._read_var_from_nc("prec", self._base_path, "states_hm.nc")[:, :, vs.itt],
        )
        vs.inf_mat_rz = update(
            vs.inf_mat_rz,
            at[2:-2, 2:-2],
            self._read_var_from_nc("inf_mat_rz", self._base_path, "states_hm.nc")[:, :, vs.itt],
        )
        vs.inf_pf_rz = update(
            vs.inf_pf_rz,
            at[2:-2, 2:-2],
            self._read_var_from_nc("inf_mp_rz", self._base_path, "states_hm.nc")[:, :, vs.itt]
            + self._read_var_from_nc("inf_sc_rz", self._base_path, "states_hm.nc")[:, :, vs.itt],
        )
        vs.inf_pf_ss = update(
            vs.inf_pf_ss,
            at[2:-2, 2:-2],
            self._read_var_from_nc("inf_ss", self._base_path, "states_hm.nc")[:, :, vs.itt],
        )
        vs.transp = update(
            vs.transp, at[2:-2, 2:-2], self._read_var_from_nc("transp", self._base_path, "states_hm.nc")[:, :, vs.itt]
        )
        vs.evap_soil = update(
            vs.evap_soil,
            at[2:-2, 2:-2],
            self._read_var_from_nc("evap_soil", self._base_path, "states_hm.nc")[:, :, vs.itt],
        )
        vs.cpr_rz = update(
            vs.cpr_rz, at[2:-2, 2:-2], self._read_var_from_nc("cpr_rz", self._base_path, "states_hm.nc")[:, :, vs.itt]
        )
        vs.q_rz = update(
            vs.q_rz, at[2:-2, 2:-2], self._read_var_from_nc("q_rz", self._base_path, "states_hm.nc")[:, :, vs.itt]
        )
        vs.q_ss = update(
            vs.q_ss, at[2:-2, 2:-2], self._read_var_from_nc("q_ss", self._base_path, "states_hm.nc")[:, :, vs.itt]
        )
        vs.S_rz = update(
            vs.S_rz,
            at[2:-2, 2:-2, vs.tau],
            self._read_var_from_nc("S_rz", self._base_path, "states_hm.nc")[:, :, vs.itt],
        )
        vs.S_ss = update(
            vs.S_ss,
            at[2:-2, 2:-2, vs.tau],
            self._read_var_from_nc("S_ss", self._base_path, "states_hm.nc")[:, :, vs.itt],
        )
        vs.S_s = update(vs.S_s, at[2:-2, 2:-2, vs.tau], vs.S_rz[2:-2, 2:-2, vs.tau] + vs.S_ss[2:-2, 2:-2, vs.tau])
        vs.S_snow = update(
            vs.S_snow,
            at[2:-2, 2:-2, vs.tau],
            self._read_var_from_nc("S_snow", self._base_path, "states_hm.nc")[:, :, vs.itt],
        )
        vs.C_in = update(vs.C_in, at[2:-2, 2:-2], vs.C_IN[2:-2, 2:-2, vs.itt])

        # mixing of isotopes while snow accumulation
        vs.C_snow = update(
            vs.C_snow,
            at[2:-2, 2:-2, vs.tau],
            npx.where(
                vs.S_snow[2:-2, 2:-2, vs.tau] > 0,
                npx.where(
                    npx.isnan(vs.C_snow[2:-2, 2:-2, vs.tau]),
                    vs.C_in[2:-2, 2:-2],
                    (vs.prec[2:-2, 2:-2, vs.tau] / (vs.prec[2:-2, 2:-2, vs.tau] + vs.S_snow[2:-2, 2:-2, vs.tau]))
                    * vs.C_in[2:-2, 2:-2]
                    + (vs.S_snow[2:-2, 2:-2, vs.tau] / (vs.prec[2:-2, 2:-2, vs.tau] + vs.S_snow[2:-2, 2:-2, vs.tau]))
                    * vs.C_snow[2:-2, 2:-2, vs.taum1],
                ),
                npx.nan,
            ),
        )
        vs.C_snow = update(
            vs.C_snow,
            at[2:-2, 2:-2, vs.tau],
            npx.where(vs.S_snow[2:-2, 2:-2, vs.tau] <= 0, npx.nan, vs.C_snow[2:-2, 2:-2, vs.tau]),
        )
        vs.C_iso_snow = update(
            vs.C_iso_snow,
            at[2:-2, 2:-2, vs.tau],
            conc_to_delta(state, vs.C_snow[2:-2, 2:-2, vs.tau]),
        )

        # mix isotopes from snow melt and rainfall
        vs.C_in = update(
            vs.C_in,
            at[2:-2, 2:-2],
            npx.where(
                npx.isfinite(vs.C_snow[2:-2, 2:-2, vs.taum1]),
                vs.C_snow[2:-2, 2:-2, vs.taum1],
                npx.where(vs.prec[2:-2, 2:-2, vs.tau] > 0, vs.C_IN[2:-2, 2:-2, vs.itt], 0),
            ),
        )
        vs.C_iso_in = update(vs.C_iso_in, at[2:-2, 2:-2], conc_to_delta(state, vs.C_in[2:-2, 2:-2]))

    @roger_routine
    def set_diagnostics(self, state):
        pass

    @roger_routine
    def after_timestep(self, state):
        """Apply ``after_timestep_kernel`` and store its outputs in the variable state."""
        vs = state.variables
        vs.update(after_timestep_kernel(state))
@roger_kernel
def after_timestep_kernel(state):
    """Copy the ``tau`` slice of precipitation and snow state into the ``taum1`` slice.

    Returns a ``KernelOutput`` carrying the updated ``prec``, ``C_snow`` and
    ``S_snow`` arrays.
    """
    vs = state.variables
    now, before = vs.tau, vs.taum1

    # the three copies touch different arrays, so their order is irrelevant
    vs.prec = update(vs.prec, at[2:-2, 2:-2, before], vs.prec[2:-2, 2:-2, now])
    vs.C_snow = update(vs.C_snow, at[2:-2, 2:-2, before], vs.C_snow[2:-2, 2:-2, now])
    vs.S_snow = update(vs.S_snow, at[2:-2, 2:-2, before], vs.S_snow[2:-2, 2:-2, now])

    return KernelOutput(
        prec=vs.prec,
        C_snow=vs.C_snow,
        S_snow=vs.S_snow,
    )


@roger_kernel
def _bfill(loop_arr):
    """Fill zero entries of ``loop_arr`` backward along its last axis.

    Zero acts as the "missing" marker. Walking from the end of the last axis
    towards its start, each zero entry takes the value of its successor; a
    zero in the final slot (which has no successor) then takes the value of
    the slot before it.
    """
    n_slots = loop_arr.shape[2]

    def _step(i, arr):
        k = n_slots - i  # k runs from n_slots - 1 down to 1
        return update(
            arr,
            at[:, :, k - 1],
            npx.where(arr[:, :, k - 1] != 0, arr[:, :, k - 1], arr[:, :, k]),
        )

    filled = for_loop(1, n_slots, _step, loop_arr)
    return update(
        filled,
        at[:, :, -1],
        npx.where(filled[:, :, -1] != 0, filled[:, :, -1], filled[:, :, -2]),
    )