# Source code for pycif.plugins.transforms.basic.exp.adjoint

import copy
import xarray as xr
import numpy as np


def adjoint(
    transform,
    inout_datastore,
    controlvect,
    obsvect,
    mapper,
    di,
    df,
    mode,
    runsubdir,
    workdir,
    onlyinit=False,
    **kwargs
):
    r"""Propagate sensitivities through the exponential (adjoint).

    Reloads the forward input saved as
    ``{adj_refdir}/chain/exp/{orig_name}_fwd_%Y%m%d%H%M.nc`` and applies:

    :math:`s_{in} = s_{out} \cdot e^{x_{fwd}}`.

    Args:
        transform (Plugin): exp instance (carries ``model.adj_refdir``,
            ``orig_name``, ``component`` and ``parameter``).
        inout_datastore (dict): mutable datastore; its ``"outputs"`` entry
            provides the incoming sensitivity and its ``"inputs"`` entry
            receives the propagated one, both indexed by
            ``(component, parameter)`` then by date.
        controlvect: unused.
        obsvect: unused.
        mapper (dict): transform mapper (not read here).
        di (datetime): sub-simulation start date.
        df (datetime): sub-simulation end date; the earlier of ``di`` and
            ``df`` is used as the date key and to name the forward file.
        mode (str): ``'adj'`` (not read here).
        runsubdir (str): unused.
        workdir (str): unused.
        onlyinit (bool): if ``True``, return immediately.
        **kwargs: unused.

    Raises:
        KeyError: if the datastore has no entry for the transform's
            ``(component, parameter)`` / date, or if the forward file does
            not contain the ``component___parameter`` variable.
    """
    if onlyinit:
        return

    # Date key shared by the forward file name and the datastore entries.
    ddi = min(di, df)

    xmod_out = inout_datastore["outputs"]
    xmod_in = inout_datastore["inputs"]

    trid_out = (transform.component, transform.parameter)

    # Only the date goes through strftime; formatting the whole path would
    # mangle any literal '%' in adj_refdir or orig_name.
    file_fwd_dataset = (
        f"{transform.model.adj_refdir}/chain/exp/"
        f"{transform.orig_name}_fwd_{ddi.strftime('%Y%m%d%H%M')}.nc"
    )

    # Reload forward information. The context manager releases the file
    # handle; the variable is loaded explicitly and the product computed
    # while the file is still open so nothing lazy outlives the handle.
    with xr.open_dataset(file_fwd_dataset) as fwd_dataset:
        fwd_values = fwd_dataset["___".join(trid_out)].load()

        # Initialize inputs if cleared with memory
        xmod_in[trid_out].setdefault(ddi, {})

        # Applying adjoint: d(e^x)/dx = e^x, evaluated at the forward input.
        xmod_in[trid_out][ddi]["adj_out"] = (
            copy.deepcopy(xmod_out[trid_out][ddi]["adj_out"])
            * np.exp(fwd_values)
        )