Source code for pycif.plugins.transforms.basic.product.adjoint
import copy
import itertools
import numpy as np
import xarray as xr
[docs]
def adjoint(
    transform,
    inout_datastore,
    controlvect,
    obsvect,
    mapper,
    di,
    df,
    mode,
    runsubdir,
    workdir,
    onlyinit=False,
    **kwargs
):
    r"""Propagate sensitivities through the product (product-rule adjoint).

    Reloads the forward inputs from ``chain/product/``, then for each input
    :math:`x_j`:

    .. math::

        s_j = s_{out} \cdot \prod_{i \neq j} x_i^{fwd}

    Each input receives exactly one such contribution. With a single input
    the product over the other inputs is empty (i.e. 1), so the output
    sensitivity is passed through unchanged.

    Args:
        transform (Plugin): product instance (carries ``model.adj_refdir``
            and ``orig_name``).
        inout_datastore (dict): mutable datastore; its ``"inputs"`` entry is
            replaced by the propagated input sensitivities.
        controlvect: unused.
        obsvect: unused.
        mapper (dict): transform mapper; its ``"inputs"`` keys identify the
            inputs (each key is joined with ``'___'`` to name its variable
            in the forward file) and its first ``"outputs"`` key identifies
            the output.
        di (datetime): sub-simulation start date; also used as the
            datastore key.
        df (datetime): compared with ``di``; the earlier of the two dates
            timestamps the forward file name.
        mode (str): ``'adj'``; unused.
        runsubdir (str): unused.
        workdir (str): unused.
        onlyinit (bool): if ``True``, return immediately.
        **kwargs: unused.
    """
    if onlyinit:
        return

    ddi = min(di, df)
    xmod_out = inout_datastore["outputs"]
    trid_out = list(mapper["outputs"].keys())[0]
    out_data = xmod_out[trid_out][di][di]
    input_trids = list(mapper["inputs"].keys())

    # Reload forward information. The ``with`` block closes the file once
    # the needed arrays have been loaded into memory (``.values``).
    file_fwd_dataset = ddi.strftime(
        f"{transform.model.adj_refdir}/chain/product/"
        f"{transform.orig_name}_fwd_%Y%m%d%H%M.nc")
    with xr.open_dataset(file_fwd_dataset) as fwd_dataset:
        # With a single input no forward value is needed (empty product).
        fwd_values = (
            {trid: fwd_dataset["___".join(trid)].values
             for trid in input_trids}
            if len(input_trids) > 1 else {}
        )

    # Reset input sensitivities to zeros with the structure of the output
    inout_datastore["inputs"] = {
        trid: {di: {k: 0 * out_data[k] for k in out_data}}
        for trid in input_trids
    }
    xmod_in = inout_datastore["inputs"]

    # Product rule: one term per input, multiplying the forward values of
    # all the *other* inputs. (Looping over permutations would add this
    # same term (n - 1)! times for n inputs.)
    for trid in input_trids:
        others = [fwd_values[tr] for tr in input_trids if tr != trid]
        if others:
            factor = np.prod(
                np.vstack([val[np.newaxis] for val in others]), axis=0)
        else:
            factor = 1
        xmod_in[trid][di]["adj_out"] += out_data["adj_out"] * factor