import pandas as pd
import copy
import xarray as xr
import numpy as np
from logging import info, debug
from .....utils.datastores.empty import init_empty
from .....utils.classes.domains import Domain
from .....utils.datastores.crop_monitor import crop_monitor
from .....utils.dates import date_range
def _obs_parameter(datavect, mod_input, trcr):
    """Look up the observation-vector parameter matching a simulated tracer.

    Args:
        datavect: ``obsvect.datavect``; its ``components`` attribute is
            searched for ``mod_input``.
        mod_input (str): component name (first element of the tracer id).
        trcr (str): parameter name (second element of the tracer id).

    Returns:
        The parameter object, or None (after a debug message) when the
        component is absent, the parameter is absent from the component,
        or the parameter is not flagged ``isobs``.
    """
    component = getattr(getattr(datavect, "components", None), mod_input, None)
    parameters = getattr(component, "parameters", None)
    if parameters is None:
        debug(
            f"The observation operator simulates {mod_input} but your "
            "observation vector doesn't include it as a component"
        )
        return None

    param = getattr(parameters, trcr, None)
    if param is None:
        debug(
            f"The observation operator simulates {trcr} as a {mod_input} "
            "but your observation vector doesn't include it as a parameter"
        )
        return None

    if not getattr(param, "isobs", False):
        debug(f"{trcr}/{mod_input} was not in the observation vector. Passing")
        return None

    return param


def _split_window_end(transform, ddi):
    """Return the end of the ``split_freq`` window that starts at ``ddi``.

    Raises:
        IndexError: if ``ddi`` is not one of the boundaries produced by
            ``date_range(transform.datei, transform.datef,
            period=transform.split_freq)``, or is the last of them (so the
            window has no end).
    """
    boundaries = date_range(
        transform.datei, transform.datef, period=transform.split_freq
    )
    hits = np.flatnonzero(boundaries == ddi)
    if hits.size == 0 or hits[0] + 1 >= len(boundaries):
        raise IndexError(
            f"{ddi} does not start a split_freq window between "
            f"{transform.datei} and {transform.datef}"
        )
    return boundaries[hits[0] + 1]


def _crop_rows(ds, ddi, split_ddf):
    """Select the rows of ``ds`` to process and the weights for ``adj_out``.

    Without a split window (``split_ddf is None``) every row is kept with a
    scalar weight of 1.0. Otherwise the ``(index, weight)`` pair comes from
    ``crop_monitor`` over ``[ddi, split_ddf]`` with ``keep_partial=True``.
    """
    if split_ddf is None:
        return np.arange(len(ds), dtype=int), 1.0
    ds_dates = ds[
        [("metadata", "date"),
         ("metadata", "duration"),
         ("metadata", "enddate")]
    ]
    return crop_monitor(
        ds_dates, ddi, split_ddf,
        return_index=True,
        keep_partial=True,
        return_weight=True,
    )


def adjoint(
    transform,
    inout_datastore,
    controlvect,
    obsvect,
    mapper,
    di,
    df,
    mode,
    runsubdir,
    workdir,
    onlyinit=False,
    **kwargs
):
    """Read observation-vector sensitivities and populate the input datastore.

    The adjoint of :func:`forward`. It reads ``obsvect.dy``, the
    observation-space adjoint increment (for example set by the simulator to
    :math:`\\mathbf{R}^{-1}(\\mathbf{H}\\mathbf{x} - \\mathbf{y})`). It writes
    that increment into the ``'adj_out'`` column of each tracer's sparse
    datastore entry.

    It also copies ``obsvect.ysim`` into ``'spec'``, so that downstream
    adjoint transforms can access the reference forward simulation. It sets
    ``'incr'`` to ``0 * ysim``. At locations outside the active observation
    mask (``obsvect.obsvect_mask``), ``'incr'`` and ``'adj_out'`` are set
    to zero.

    Tracers are skipped, with a debug message, in three cases:

    - their component is missing from the observation vector;
    - their parameter is missing from the observation vector;
    - their parameter is not flagged ``isobs``.

    When ``split_freq`` is configured, three extra steps apply:

    - only the observations selected by ``crop_monitor`` for the current
      sub-window are processed;
    - their ``'adj_out'`` is multiplied by the weights it returns;
    - their dates are clipped to the window boundaries.

    Args:
        transform (Plugin): toobsvect transform instance.
        inout_datastore (dict): datastore; ``'inputs'`` entries are
            populated with sparse DataFrames carrying ``spec``, ``incr``,
            and ``adj_out`` columns, keyed by the sub-simulation start date.
        controlvect: control vector plugin (unused).
        obsvect: observation vector plugin; ``datavect``, ``ysim``, ``dy``
            and ``obsvect_mask`` are read.
        mapper (dict): transform mapper (unused; tracers are taken from
            ``inout_datastore['inputs']``).
        di (datetime): sub-simulation start date.
        df (datetime): sub-simulation end date.
        mode (str): ``'adj'``.
        runsubdir (str): sub-simulation run directory (unused).
        workdir (str): root working directory (unused).
        onlyinit (bool): unused; kept for interface consistency.
        **kwargs: unused.

    Raises:
        IndexError: with ``split_freq``, if ``min(di, df)`` is not the start
            of a split window.
    """
    # In adjoint mode di/df may come in reverse chronological order
    ddi = min(di, df)
    datastore = inout_datastore["inputs"]
    datavect = obsvect.datavect

    # End of the current sub-window when the observation vector is split
    split_ddf = (
        _split_window_end(transform, ddi)
        if hasattr(transform, "split_freq")
        else None
    )

    for tracer_id in datastore:
        # If this type of input is not considered in the observation vector,
        # ignore the model sensitivity for it
        param = _obs_parameter(datavect, tracer_id[0], tracer_id[1])
        if param is None:
            continue

        ds = param.datastore
        crop_index, weight = _crop_rows(ds, ddi, split_ddf)

        # Fresh positional copy: the 0..n-1 index lines up with the numpy
        # slices taken below, and the shared datastore is left untouched
        data = ds.iloc[crop_index].copy().reset_index(drop=True)
        if len(data) == 0:
            continue

        if split_ddf is not None:
            # Clip observation dates to the sub-window boundaries
            start = data[("metadata", "date")]
            data[("metadata", "date")] = start.mask(start < ddi, ddi)
            end = data[("metadata", "enddate")]
            data[("metadata", "enddate")] = end.mask(end > split_ddf, split_ddf)

        # This parameter's slice of the flat observation vector
        window = slice(param.ypointer, param.ypointer + param.dim)
        spec = obsvect.ysim[window][crop_index]
        adj_out = obsvect.dy[window][crop_index] * weight
        in_obs = obsvect.obsvect_mask[window][crop_index]

        data[("maindata", "spec")] = spec
        # 0 * ysim: zero increments (NaNs in ysim are carried through)
        data[("maindata", "incr")] = np.where(in_obs, 0.0 * spec, 0.0)
        # Sensitivities outside the active observation mask are dropped
        data[("maindata", "adj_out")] = np.where(in_obs, adj_out, 0.0)

        datastore[tracer_id][ddi] = data