# Source code for pycif.plugins.transforms.basic.product.adjoint

import copy
import itertools
import numpy as np
import xarray as xr


def adjoint(
        transform, inout_datastore,
        controlvect, obsvect, mapper, di, df, mode,
        runsubdir, workdir, onlyinit=False,
        **kwargs):
    r"""Propagate sensitivities through the product (product-rule adjoint).

    Reloads the forward inputs from ``chain/product/``, then for each input
    :math:`x_j`:

    .. math::
        s_j = s_{out} \cdot \prod_{i \neq j} x_i^{fwd}

    Every input receives its term exactly once. With a single input the
    product over the other inputs is empty (equal to 1), so the output
    sensitivity is passed through unchanged.

    Args:
        transform (Plugin): product instance (carries ``model.adj_refdir``
            and ``orig_name``).
        inout_datastore (dict): mutable datastore. ``"outputs"`` is read;
            ``"inputs"`` is replaced by a fresh dict holding, for every
            input tracer, zeroed copies of the output entries with the
            propagated sensitivity accumulated into ``"adj_out"``.
        controlvect: unused.
        obsvect: unused.
        mapper (dict): transform mapper. The first key of
            ``mapper["outputs"]`` is the output tracer; every key of
            ``mapper["inputs"]`` is an input tracer.
        di (datetime): sub-simulation start date; also the date key used in
            the datastore.
        df (datetime): sub-simulation end date; ``min(di, df)`` selects the
            forward file to reload.
        mode (str): ``'adj'``; not read by this function.
        runsubdir (str): unused.
        workdir (str): unused.
        onlyinit (bool): if ``True``, return immediately without touching
            the datastore.
        **kwargs: unused.
    """
    if onlyinit:
        return

    # The forward file is stamped with the earlier of the two dates
    # (presumably di/df can arrive reversed in adjoint mode -- TODO confirm).
    ddi = min(di, df)
    xmod_out = inout_datastore["outputs"]
    trid_out = list(mapper["outputs"].keys())[0]
    out_data = xmod_out[trid_out][di][di]
    input_trids = list(mapper["inputs"])

    # Reload forward information; the context manager closes the file once
    # the needed values have been read eagerly via ``.values``.
    file_fwd_dataset = ddi.strftime(
        f"{transform.model.adj_refdir}/chain/product/"
        f"{transform.orig_name}_fwd_%Y%m%d%H%M.nc")
    with xr.open_dataset(file_fwd_dataset) as fwd_dataset:
        # Fresh zeroed containers for each input, mirroring the output entry
        inout_datastore["inputs"] = {
            trid: {di: {k: 0 * out_data[k] for k in out_data}}
            for trid in input_trids}
        xmod_in = inout_datastore["inputs"]

        # Visit each input exactly once. Enumerating permutations instead
        # would add each term (n-1)! times for n >= 3 inputs.
        for j, trid in enumerate(input_trids):
            others = input_trids[:j] + input_trids[j + 1:]
            # Forward variables are named by joining the tracer key parts
            # with "___" (assumes trids are tuples of strings -- confirm).
            # np.prod over an empty list along axis 0 yields 1.0, which
            # covers the single-input case.
            factor = np.prod(
                [fwd_dataset["___".join(tr)].values for tr in others],
                axis=0)
            xmod_in[trid][di]["adj_out"] += out_data["adj_out"] * factor