# Source code for pycif.plugins.transforms.basic.regrid.adjoint

import os
import copy
import numpy as np
import xarray as xr
import itertools
from .utils.get_weights import get_weights
from .....utils.classes.domains import Domain


def _domain_from_metadata(metadata):
    """Build a point-wise ``Domain`` from a sparse ``metadata`` table.

    Each row of ``metadata`` becomes one point; its ``"lon"``/``"lat"``
    columns are used both as centre and corner coordinates, as
    ``(npoints, 1)`` arrays. ``nlon`` and ``nlat`` are both set to the
    number of rows.
    """

    def column(name):
        # A fresh (npoints, 1) array per attribute, as in the original code
        return metadata.loc[:, name].values[:, np.newaxis]

    domain = Domain()
    domain.zlon = column("lon")
    domain.zlonc = column("lon")
    domain.zlat = column("lat")
    domain.zlatc = column("lat")
    domain.nlon = len(metadata)
    domain.nlat = len(metadata)
    return domain


def adjoint(
    transf,
    inout_datastore,
    controlvect,
    obsvect,
    mapper,
    ddi,
    ddf,
    mode,
    runsubdir,
    workdir,
    onlyinit=False,
    save_debug=False,
    **kwargs
):
    """Apply the adjoint of the regridding transform for period ``ddi``.

    Sensitivities stored in ``inout_datastore["outputs"]`` (on the output
    domain) are mapped back onto the input domain with the weights returned
    by ``get_weights`` and written to ``inout_datastore["inputs"]``.

    Args:
        transf: the regrid transform plugin; passed to ``get_weights`` and
            its ``min_weight`` attribute is forwarded to
            :func:`do_regridding_adj`.
        inout_datastore (dict): its ``"outputs"`` entry is read and its
            ``"inputs"`` entry is written, both keyed by tracer ID then by
            period.
        controlvect, obsvect, ddf, mode, runsubdir, workdir: not used by
            this function.
        mapper (dict): ``mapper["inputs"]`` / ``mapper["outputs"]`` map a
            tracer ID to a description with ``"sampled"``, optional
            ``"sparse_data"`` / ``"is_lbc"`` flags and, for gridded data, a
            ``"domain"``.
        ddi: key of the period being processed.
        onlyinit (bool): if True, return before computing any weights or
            regridding anything.
        save_debug (bool): forwarded to :func:`do_regridding_adj` for sparse
            data.
        **kwargs: ignored.

    Returns:
        None. For every input tracer ``trid``,
        ``inout_datastore["inputs"][trid]`` is replaced by a new dict holding
        the adjoint result for ``ddi``.
    """
    xmod = inout_datastore["outputs"]
    # The first output tracer sets the data layout used for all tracers
    trid_ref = list(xmod)[0]

    # Differentiate full matrices and sparse data
    is_sparse_out = mapper["outputs"][trid_ref].get("sparse_data", False)
    is_sparse_in = mapper["inputs"][trid_ref].get("sparse_data", False)
    is_sampled_out = mapper["outputs"][trid_ref]["sampled"]
    # Sampled outputs are treated exactly like sparse outputs
    sparse_out = is_sparse_out or is_sampled_out

    # Input domain
    if is_sparse_in:
        if onlyinit:
            return
        # NOTE(review): the input points are read from the *outputs*
        # metadata (the inputs are only filled further down) -- confirm.
        domain_in = _domain_from_metadata(xmod[trid_ref][ddi]["metadata"])
        nlat_in, nlon_in = domain_in.nlat, domain_in.nlon
    else:
        domain_in = mapper["inputs"][trid_ref]["domain"]
        nlat_in, nlon_in = domain_in.zlon.shape

    # Output domain
    is_lbc = mapper["outputs"][trid_ref].get("is_lbc", False)
    if sparse_out:
        # Do nothing if onlyinit and propagating an empty matrix
        if onlyinit and len(xmod[trid_ref][ddi]) == 0:
            return
        domain_out = _domain_from_metadata(xmod[trid_ref][ddi]["metadata"])
    else:
        domain_out = mapper["outputs"][trid_ref]["domain"]

    # In initialization mode, no further operation is needed
    if onlyinit:
        return

    # Lateral boundary conditions use the side dimensions of the domain
    if is_lbc:
        nlat_out, nlon_out = domain_out.nlat_side, domain_out.nlon_side
    else:
        nlat_out, nlon_out = domain_out.zlat.shape

    weights = get_weights(
        transf, trid_ref, mapper, domain_in, domain_out, is_lbc, ddi,
        is_sparse_in=is_sparse_in, is_sparse_out=sparse_out
    )

    for trid in mapper["inputs"]:
        inout_datastore["inputs"][trid] = {}
        if is_sparse_in or sparse_out:
            # Sparse data: the whole per-period table is regridded
            inout_datastore["inputs"][trid][ddi] = do_regridding_adj(
                xmod[trid][ddi], nlat_in, nlon_in, nlat_out, nlon_out,
                weights,
                min_weight=transf.min_weight,
                is_sparse_in=is_sparse_in,
                is_sparse_out=sparse_out,
                save_debug=save_debug,
                transf=transf,
            )
        else:
            # Gridded data: only the "adj_out" field is regridded
            inout_datastore["inputs"][trid][ddi] = {
                "adj_out": do_regridding_adj(
                    xmod[trid][ddi]["adj_out"], nlat_in, nlon_in,
                    nlat_out, nlon_out, weights,
                    min_weight=transf.min_weight,
                    is_sparse_in=is_sparse_in,
                    is_sparse_out=sparse_out,
                )
            }
def do_regridding_adj(
    data,
    nlat_in,
    nlon_in,
    nlat_out,
    nlon_out,
    weights,
    min_weight=1e-10,
    is_sparse_in=False,
    is_sparse_out=False,
    save_debug=False,
    transf=None,
):
    """Apply the transposed regridding weights to ``data``.

    Args:
        data: sensitivities on the output side.
            For gridded data, an xarray object with ``time``, ``lev``,
            ``lat`` and ``lon`` coordinates, whose values are indexed as
            (time, lev, lat, lon).
            For sparse outputs, a pandas table with a
            ``("maindata", "adj_out")`` column and one row per output point.
        nlat_in, nlon_in (int): shape of the input grid receiving the
            adjoint.
        nlat_out, nlon_out (int): shape of the output grid. For gridded data
            they are ignored and re-read from ``data``.
        weights (dict): ``"i"`` / ``"j"`` (input-grid indices) and ``"wgt"``
            arrays, one row per output point (presumably
            ``(n_points, n_weights)``). An optional ``"filtered"`` entry
            selects the output points that are kept. ``weights`` is not
            modified.
        min_weight (float): gridded data only; weights not strictly above
            this threshold (and NaN weights) are ignored.
        is_sparse_in, is_sparse_out (bool): layout of the input/output data.
        save_debug, transf: not used by this function.

    Returns:
        For gridded data, an ``xr.DataArray`` of shape
        ``(ntimes, nlev, nlat_in, nlon_in)`` with dims
        ``("time", "lev", "lat", "lon")``.

        For sparse outputs, a copy of the selected rows of ``data``. Each row
        is repeated once per weight column, new ``("metadata", "i")`` and
        ``("metadata", "j")`` columns are added, and ``adj_out`` is
        multiplied by the matching weight.

    Raises:
        NotImplementedError: for sparse inputs with non-sparse outputs.
    """
    if not (is_sparse_in or is_sparse_out):
        nlev = len(data.lev)
        ntimes = len(data.time)
        # The output grid shape is taken from the data itself
        nlat_out = len(data.lat)
        nlon_out = len(data.lon)

        var_in = np.zeros((ntimes, nlev, nlat_in, nlon_in), data.dtype)
        var_out = data.values

        # Zero the weights not strictly above min_weight (NaN included) on a
        # copy, so that the caller's (possibly cached) weights are untouched
        wgt = weights["wgt"].copy()
        wgt[~(wgt > min_weight)] = 0

        # Loop-invariant integer indices of the receiving input cells
        i_in = weights["i"].astype(int)
        j_in = weights["j"].astype(int)

        # Output points are enumerated in Fortran (column-major) order,
        # presumably matching the row order of the weights
        iout, jout = np.unravel_index(
            range(nlat_out * nlon_out), (nlat_out, nlon_out), order="F")
        if "filtered" in weights:
            iout = iout[weights["filtered"]]
            jout = jout[weights["filtered"]]

        # Scatter-add every output value, times its weights, onto the input
        # cells; np.add.at accumulates correctly over repeated (i, j) pairs
        for time, level in itertools.product(
                range(var_out.shape[0]), range(var_out.shape[1])):
            np.add.at(
                var_in[time, level],
                (i_in, j_in),
                var_out[time, level, iout, jout, np.newaxis] * wgt,
            )

        return xr.DataArray(
            var_in,
            coords={"time": data.time.values},
            dims=("time", "lev", "lat", "lon"),
        )

    if not is_sparse_out:
        raise NotImplementedError("Adjoint of regrid with sparse inputs not yet "
                                  "implemented")

    i = weights["i"].astype(int)
    j = weights["j"].astype(int)
    wgts = weights["wgt"]

    # Keep only the output points selected when computing the weights
    if "filtered" in weights:
        rows = weights["filtered"]
    else:
        rows = np.arange(len(data))
    data_out = copy.deepcopy(data.iloc[rows])

    # An existing "index" column is not carried over to the result
    if "index" in data_out:
        del data_out["index"]

    # Repeat each row once per weight column, so that every
    # (point, weight) pair gets its own row. np.repeat raises if the number
    # of weight rows does not match the number of data rows.
    repeats = np.full(len(i), np.shape(i)[1], dtype=int)
    data_out = data_out.iloc[np.repeat(np.arange(len(data_out)), repeats)]

    data_out[("metadata", "i")] = i.flatten()
    data_out[("metadata", "j")] = j.flatten()
    data_out[("maindata", "adj_out")] *= wgts.flatten()
    return data_out