Source code for pycif.plugins.transforms.system.array2sampled.adjoint
import numpy as np
import pandas as pd
import xarray as xr
from logging import info, debug
[docs]
def adjoint(
    transform,
    inout_datastore,
    controlvect,
    obsvect,
    mapper,
    di,
    df,
    mode,
    runsubdir,
    workdir,
    onlyinit=False,
    **kwargs
):
    """Scatter observation-space adjoint sensitivities back onto the model grid.

    The adjoint of :func:`forward`: for each observation, reads the
    ``'adj_out'`` value from the sparse output DataFrame and adds it to
    the corresponding ``(tstep, lev, i, j)`` cell of the gridded
    ``'adj_out'`` array via ``np.add.at``. ``np.add.at`` is unbuffered,
    so several observations mapping to the same cell all contribute
    (they are summed rather than the last one winning).

    Initialises ``xmod_in["adj_out"]`` as an explicit zero array shaped
    like ``xmod_in["spec"]`` if not already present; otherwise the new
    sensitivities are added on top of the existing values.

    Args:
        transform (Plugin): array2sampled instance (unused here).
        inout_datastore (dict): mutable datastore; ``'inputs'`` has the
            gridded arrays, ``'outputs'`` has the observation DataFrame
            with ``'adj_out'`` column.
        controlvect: unused.
        obsvect: unused.
        mapper (dict): transform mapper; ``'inputs'`` and ``'outputs'``
            are paired element-wise.
        di (datetime): sub-simulation start date.
        df (datetime): sub-simulation end date.
        mode (str): ``'adj'``.
        runsubdir (str): unused.
        workdir (str): unused.
        onlyinit (bool): if ``True``, return immediately.
        **kwargs: unused.

    Raises:
        KeyError: if no gridded input exists for a mapped input tracer at
            the sub-simulation date.
    """
    if onlyinit:
        return

    # Datastore entries are looked up under the earlier of the two dates.
    ddi = min(di, df)

    for trid_in, trid_out in zip(mapper["inputs"], mapper["outputs"]):
        try:
            xmod_in = inout_datastore["inputs"][trid_in][ddi]
        except KeyError as err:
            # Fail loudly with context instead of opening an interactive
            # shell, which would hang non-interactive runs.
            raise KeyError(
                f"array2sampled adjoint: no gridded input for tracer "
                f"{trid_in!r} at {ddi}"
            ) from err

        xmod_out = inout_datastore["outputs"][trid_out][ddi]
        metadata = xmod_out["metadata"]

        t = metadata["tstep"].astype(int).values
        i = metadata["i"].astype(int).values
        j = metadata["j"].astype(int).values

        spec = xmod_in["spec"]
        # Single-level fields: every observation maps to level 0,
        # regardless of any "level" column in the metadata.
        if spec.shape[1] == 1:
            lev = np.zeros_like(i)
        else:
            lev = metadata["level"].astype(int).values

        if "adj_out" not in xmod_in:
            # Explicit zeros: ``0 * spec`` would carry NaN/inf from the
            # forward field into the adjoint accumulator.
            xmod_in["adj_out"] = spec.copy(data=np.zeros_like(spec.data))

        data_out = xmod_out[("maindata", "adj_out")].values
        # Unbuffered in-place add: repeated (t, lev, i, j) indices all
        # accumulate, unlike ``a[idx] += v`` where duplicates collapse.
        np.add.at(xmod_in["adj_out"].data, (t, lev, i, j), data_out)