Source code for pycif.plugins.transforms.system.sparse2sample.forward
import numpy as np
import xarray as xr
from logging import warning
[docs]
def forward(
    transform,
    inout_datastore,
    controlvect,
    obsvect,
    mapper,
    di,
    df,
    mode,
    runsubdir,
    workdir,
    onlyinit=False,
    **kwargs
):
    """Accumulate sparse per-observation values into dense gridded arrays.

    For every tracer id listed in ``mapper["inputs"]``, the sparse table
    stored under ``inout_datastore["inputs"][trid][min(di, df)]`` is read.
    Its ``("metadata", "tstep" | "level" | "i" | "j")`` columns are cast to
    integers and used as indices into a zero-filled array of shape
    ``(ntimes, nlev, nlat, nlon)``; the values of each requested
    ``("maindata", <column>)`` are summed into the addressed cells with
    ``np.add.at``, so several observations hitting the same cell add up.

    The ``i`` index addresses the third axis (labelled ``lat``) and the
    ``j`` index the fourth axis (labelled ``lon``).

    Requested columns are ``spec`` only when ``mode == "fwd"``; for any
    other mode both ``spec`` and ``incr`` are requested. A requested column
    that is absent from the ``maindata`` group is silently skipped.

    Nothing is returned: each gridded array is wrapped in an
    ``xarray.DataArray`` with dimensions ``("time", "lev", "lat", "lon")``
    and stored in ``inout_datastore["outputs"][trid][min(di, df)]`` under
    the column name.

    Args:
        transform: unused here.
        inout_datastore (dict): holds the sparse tables under ``"inputs"``
            and receives the gridded arrays under ``"outputs"``; both are
            indexed by tracer id, then by date.
        controlvect: unused here.
        obsvect: unused here.
        mapper (dict): ``mapper["inputs"]`` is iterated for tracer ids;
            ``mapper["outputs"][trid]`` supplies ``"domain"`` (read for
            ``nlon``, ``nlat``, ``nlev``) and ``"input_dates"`` (indexed by
            date; its length gives the number of time steps and its first
            column the ``time`` coordinate).
        di: one bound of the sub-simulation period.
        df: the other bound of the sub-simulation period; the smaller of
            ``di`` and ``df`` is the key used in the datastore and mapper.
        mode (str): ``"fwd"`` restricts the scatter to ``spec``; any other
            value also scatters ``incr``.
        runsubdir: unused here.
        workdir: unused here.
        onlyinit (bool): unused here; the scatter is always performed.
        **kwargs: unused here.
    """
    # Datastore and mapper entries are keyed by the earlier of the two dates.
    period_key = min(di, df)
    requested = ["spec"] if mode == "fwd" else ["spec", "incr"]

    for tracer in mapper["inputs"]:
        sparse = inout_datastore["inputs"][tracer][period_key]
        gridded = inout_datastore["outputs"][tracer][period_key]

        # Integer (time step, level, i, j) position of every observation.
        cell_index = tuple(
            sparse["metadata"][key].astype(int).values
            for key in ("tstep", "level", "i", "j")
        )

        # Dimensions of the dense target array.
        target = mapper["outputs"][tracer]
        grid = target["domain"]
        n_lon = grid.nlon
        n_lat = grid.nlat
        n_lev = grid.nlev
        n_times = len(target["input_dates"][period_key])

        for name in requested:
            if name in sparse["maindata"]:
                # A fresh array per column; np.add.at accumulates repeated
                # indices instead of overwriting them.
                dense = np.zeros((n_times, n_lev, n_lat, n_lon))
                np.add.at(dense, cell_index, sparse[("maindata", name)].values)

                gridded[name] = xr.DataArray(
                    dense,
                    coords={"time": target["input_dates"][period_key][:, 0]},
                    dims=("time", "lev", "lat", "lon"),
                )