import pandas as pd
import copy
import xarray as xr
import numpy as np
from logging import info, debug
from .....utils.datastores.empty import init_empty
from .....utils.classes.domains import Domain
from .....utils.datastores.crop_monitor import crop_monitor
from .....utils.dates import date_range
[docs]
def adjoint(
    transform,
    inout_datastore,
    controlvect,
    obsvect,
    mapper,
    di,
    df,
    mode,
    runsubdir,
    workdir,
    onlyinit=False,
    **kwargs
):
    """Fill the transform inputs with adjoint forcings from the observation
    vector.

    For every tracer listed in ``inout_datastore["inputs"]``, the rows of the
    tracer's observation datastore are copied (optionally cropped to the
    current split sub-window), completed with the simulated values
    (``obsvect.ysim``), a zero increment and the adjoint forcing
    (``obsvect.dy`` scaled by the cropping weight), and stored under
    ``inout_datastore["inputs"][tracer_id][min(di, df)]``.

    The function returns early, with a debug message and without touching
    the datastore, when the reference tracer (first key of
    ``mapper["inputs"]``) is not an observed parameter of the observation
    vector.

    Args:
        transform: the transform being processed; when it has a
            ``split_freq`` attribute, ``datei`` and ``datef`` are read too
            and only the matching sub-window is kept
        inout_datastore (dict): its ``"inputs"`` entry maps
            ``(component, tracer)`` ids to dictionaries indexed by date;
            updated in place
        controlvect: not used by this routine
        obsvect: the observation vector; ``datavect``, ``ysim``, ``dy`` and
            ``obsvect_mask`` are read
        mapper (dict): its ``"inputs"`` entry lists the tracer ids, the
            first one being used as reference
        di (datetime): starting date of the simulation window
        df (datetime): ending date of the simulation window
        mode: not used by this routine
        runsubdir: not used by this routine
        workdir: not used by this routine
        onlyinit (bool): not used by this routine
        **kwargs: ignored

    Returns:
        None
    """
    # In adjoint mode di/df may come in reverse order: always index the
    # output with the earliest of the two dates
    ddi = min(di, df)

    datastore = inout_datastore["inputs"]
    components = getattr(obsvect.datavect, "components", None)

    # Fetch reference information
    tracer_ref = list(mapper["inputs"])[0]
    mod_input = tracer_ref[0]
    trcr = tracer_ref[1]

    # If this type of input is not considered in the observation vector,
    # ignoring the model sensitivity.
    # The guards go from the coarsest to the finest level so that each
    # diagnostic message can actually be emitted.
    component = getattr(components, mod_input, None)
    if component is None:
        debug(
            "The observation operator simulates {} as a {} "
            "but your observation vector doesn't "
            "include it as a component".format(
                trcr, mod_input
            )
        )
        return

    parameters = getattr(component, "parameters", None)
    if parameters is None:
        debug(
            "The observation operator simulates "
            "{} but your observation vector doesn't "
            "include it as a parameter".format(mod_input)
        )
        return

    # Skip tracers not in the obs space
    param = getattr(parameters, trcr, None)
    if param is None or not getattr(param, "isobs", False):
        debug("{}/{} was not in the observation vector. Passing"
              .format(trcr, mod_input))
        return

    # Now fetch info from data structure to common structure
    ds_meta = param.datastore
    ds_dates = ds_meta[
        [("metadata", "date"),
         ("metadata", "duration"),
         ("metadata", "enddate")]
    ]

    # By default keep every observation, with full weight
    crop_index = np.arange(len(ds_meta), dtype=int)
    weight = 1.0

    # Keep only sub-period if `split_freq` is specified
    split_window = hasattr(transform, "split_freq")
    split_ddf = None
    if split_window:
        obsvect_dates = date_range(
            transform.datei, transform.datef, period=transform.split_freq)

        # NOTE(review): raises IndexError if ddi is not one of the split
        # dates or is the last one -- confirm callers guarantee this
        i0 = np.argwhere(obsvect_dates == ddi)[0][0]
        split_ddf = obsvect_dates[i0 + 1]

        # presumably `weight` is the share of each observation falling
        # inside [ddi, split_ddf] -- verify against crop_monitor
        crop_index, weight = crop_monitor(
            ds_dates, ddi, split_ddf,
            return_index=True,
            keep_partial=True,
            return_weight=True
        )

    # Now loop over tracers to get values from the observation vector
    for tracer_id in datastore:
        component = getattr(components, tracer_id[0], None)
        parameters = getattr(component, "parameters", None)
        param = getattr(parameters, tracer_id[1])

        # Crop according to dates to process; the index computed on the
        # reference tracer is applied to every tracer
        data = copy.copy(
            param.datastore.iloc[crop_index]).reset_index(drop=True)

        # Stop here if empty
        if len(data) == 0:
            continue

        # Adjust dates if split_freq: clip observation periods to the
        # current sub-window
        if split_window:
            # NOTE(review): recent pandas versions deprecate the ndarray
            # return of Series.dt.to_pydatetime -- confirm before upgrading
            data.loc[:, ("metadata", "date")] = \
                np.maximum(
                    data.loc[:, ("metadata", "date")].dt.to_pydatetime(),
                    ddi
                )
            data.loc[:, ("metadata", "enddate")] = \
                np.minimum(
                    data.loc[:, ("metadata", "enddate")].dt.to_pydatetime(),
                    split_ddf
                )

        # Portion of the observation vector belonging to this tracer
        y_slice = slice(param.ypointer, param.ypointer + param.dim)
        ysim = obsvect.ysim[y_slice][crop_index]

        # Fetching values from increments and forward if needed later
        data.loc[:, ("maindata", "spec")] = ysim
        data.loc[:, ("maindata", "incr")] = 0. * ysim
        data.loc[:, ("maindata", "adj_out")] = \
            obsvect.dy[y_slice][crop_index] * weight

        # Set incr and adj_out to zero when not in observation vector
        obsvect_mask = obsvect.obsvect_mask[y_slice][crop_index]
        data.loc[~obsvect_mask, ("maindata", "incr")] = 0
        data.loc[~obsvect_mask, ("maindata", "adj_out")] = 0

        datastore[tracer_id][ddi] = data