"""Source code for pycif.plugins.transforms.basic.product.adjoint."""

import copy
import itertools
import numpy as np
import xarray as xr


def _product_of(arrays):
    """Return the element-wise product of ``arrays``, or 1 if it is empty."""
    if not arrays:
        # Empty product: the single-input case must not crash in np.vstack
        return 1
    return np.prod(np.vstack([a[np.newaxis] for a in arrays]), axis=0)


def adjoint(
        transform, inout_datastore, controlvect, obsvect, mapper,
        di, df, mode, runsubdir, workdir, onlyinit=False, **kwargs):
    """Propagate the output sensitivity of the product transform to its inputs.

    The output is treated as the element-wise product of the inputs,
    ``y = x_1 * ... * x_n``. The adjoint of input ``x_j`` is therefore
    ``adj_y * prod_{i != j} x_i``. Each input receives exactly one such
    term. The ``x_i`` are the input values saved by the forward run.

    Args:
        transform: transform plugin; ``transform.model.adj_refdir`` and
            ``transform.orig_name`` locate the saved forward values.
        inout_datastore (dict): ``"outputs"`` is read. ``"inputs"`` is
            replaced by a new dict holding the adjoint of every input.
        controlvect, obsvect, mode, runsubdir, workdir, **kwargs: not used
            here (presumably required by the shared transform call
            signature).
        mapper (dict): the keys of ``mapper["inputs"]`` are the input IDs.
            The first key of ``mapper["outputs"]`` is the output ID. IDs are
            joined with ``"___"`` to name the forward-run variables, so they
            are assumed to be sequences of strings.
        di, df: period bounds. The forward file is time-stamped with
            ``min(di, df)``, and data are read and written under ``di``.
        onlyinit (bool): if True, return immediately and do nothing.

    Returns:
        None. The results are written into ``inout_datastore["inputs"]``.
    """
    if onlyinit:
        return

    ddi = min(di, df)

    xmod_out = inout_datastore["outputs"]
    trid_out = list(mapper["outputs"].keys())[0]
    out_fields = xmod_out[trid_out][di][di]
    input_trids = list(mapper["inputs"].keys())

    # Reload the input values saved by the forward run
    file_fwd_dataset = ddi.strftime(
        "{}/chain/product/{}_fwd_%Y%m%d%H%M.nc".format(
            transform.model.adj_refdir, transform.orig_name))
    # ``with`` closes the file. ``.values`` loads each variable into a numpy
    # array, so the arrays stay usable after the file is closed.
    with xr.open_dataset(file_fwd_dataset) as fwd_dataset:
        # Forward values are only needed (and only read) when there are
        # other factors to multiply by.
        fwd_values = (
            {trid: fwd_dataset["___".join(trid)].values
             for trid in input_trids}
            if len(input_trids) > 1 else {})

    # Replace the inputs: each input gets a zeroed copy of every output field
    inout_datastore["inputs"] = {
        trid: {di: {k: 0 * out_fields[k] for k in out_fields}}
        for trid in input_trids}
    xmod_in = inout_datastore["inputs"]

    # Product rule: one term per input (iterating over permutations would
    # add each term (n-1)! times).
    adj_out = out_fields["adj_out"]
    for i, trid in enumerate(input_trids):
        others = input_trids[:i] + input_trids[i + 1:]
        xmod_in[trid][di]["adj_out"] += adj_out * _product_of(
            [fwd_values[other] for other in others])