# Source code for pycif.plugins.transforms.system.array2sampled.adjoint

import numpy as np
import pandas as pd
import xarray as xr
from logging import info, debug


def adjoint(
    transform,
    inout_datastore,
    controlvect,
    obsvect,
    mapper,
    di,
    df,
    mode,
    runsubdir,
    workdir,
    onlyinit=False,
    **kwargs
):
    """Scatter observation-space adjoint sensitivities back onto the model grid.

    The adjoint of :func:`forward`: for each observation, reads the
    ``'adj_out'`` value from the sparse output DataFrame and adds it to the
    corresponding ``(tstep, lev, i, j)`` cell of the gridded ``'adj_out'``
    array via ``np.add.at``. ``np.add.at`` is unbuffered, so several
    observations that map to the same cell all contribute to it.

    Initialises ``xmod_in["adj_out"]`` as ``0 * xmod_in["spec"]`` if not
    already present. That is an array shaped like ``spec`` and zero wherever
    ``spec`` is finite.

    Args:
        transform (Plugin): array2sampled instance (unused here).
        inout_datastore (dict): mutable datastore; ``'inputs'`` has the
            gridded arrays, ``'outputs'`` has the observation DataFrame with
            ``'metadata'`` (``tstep``, ``i``, ``j``, ``level``) and
            ``('maindata', 'adj_out')`` columns.
        controlvect: unused.
        obsvect: unused.
        mapper (dict): transform mapper; ``mapper["inputs"]`` and
            ``mapper["outputs"]`` are iterated in lockstep.
        di (datetime): sub-simulation start date.
        df (datetime): sub-simulation end date.
        mode (str): ``'adj'`` (unused here).
        runsubdir (str): unused.
        workdir (str): unused.
        onlyinit (bool): if ``True``, return immediately.
        **kwargs: unused.

    Raises:
        KeyError: if no gridded input exists for an input tracer at the
            sub-simulation date.
    """
    if onlyinit:
        return

    # Data are stored under the earlier of the two dates, whichever way the
    # (possibly backward-running) adjoint passes them in.
    ddi = min(di, df)

    for trid_in, trid_out in zip(mapper["inputs"], mapper["outputs"]):
        try:
            xmod_in = inout_datastore["inputs"][trid_in][ddi]
        except KeyError as err:
            # Fail loudly instead of opening an interactive debugger, which
            # would block non-interactive runs.
            raise KeyError(
                f"array2sampled adjoint: no gridded input for {trid_in!r} "
                f"at {ddi!r}"
            ) from err

        xmod_out = inout_datastore["outputs"][trid_out][ddi]
        metadata = xmod_out["metadata"]

        t = metadata["tstep"].astype(int).values
        i = metadata["i"].astype(int).values
        j = metadata["j"].astype(int).values

        # Single-level grid: every observation maps to level 0. Otherwise use
        # the per-observation level index from the metadata.
        if xmod_in["spec"].shape[1] == 1:
            lev = np.zeros_like(i)
        else:
            lev = metadata["level"].astype(int).values

        if "adj_out" not in xmod_in:
            xmod_in["adj_out"] = 0 * xmod_in["spec"]

        data_out = xmod_out[("maindata", "adj_out")].values
        data_in = xmod_in["adj_out"]

        # A plain fancy-indexed ``+=`` would keep only one contribution per
        # duplicated cell. np.add.at accumulates all of them in place.
        np.add.at(data_in.data, (t, lev, i, j), data_out)