import numpy as np
import pandas as pd
from logging import info, debug, warning
from .utils.dates import dateslice
from .utils.scalemaps import map2scale, vmap2vaggreg
def _period_mask(data_dates, period):
    """Return a boolean mask selecting the ``data_dates`` inside ``period``.

    ``period`` may be a ``(start, end)`` pair (tuple or DataFrame row), a
    one-element row holding a single time stamp, or a bare time stamp.
    A pair selects the half-open interval ``[start, end)``, except when
    ``start == end``, in which case the exact date is matched. A single
    time stamp is matched exactly.

    Raises:
        ValueError: if ``period`` holds neither one nor two values.
    """
    # Bare scalars (datetime / Timestamp) have ndim 0; rows and tuples have 1
    bounds = list(period) if np.ndim(period) else [period]
    if len(bounds) == 1:
        return data_dates == bounds[0]
    # Malformed periods fail here with ValueError, as plain unpacking did
    dd0, dd1 = bounds
    if dd0 == dd1:
        return data_dates == dd0
    return (data_dates >= dd0) & (data_dates < dd1)


def adjoint(
    transform,
    inout_datastore,
    controlvect,
    obsvect,
    mapper,
    di,
    df,
    mode,
    runsubdir,
    workdir,
    onlyinit=False,
    **kwargs
):
    """Translates real-size data as extracted from the model outputs to
    the control space. This includes mainly temporal and spatial aggregation.

    This routine is used in adjoint mode, thus computes operations on
    increments: the model sensitivities are aggregated and accumulated
    in place into ``controlvect.dx``.

    Args:
        transform: the transform; its ``model`` attribute is forwarded to
            ``tracer.read`` for tracers of type "scalar"
        inout_datastore (dict): data at the model resolution; sensitivities
            are read from
            ``inout_datastore["outputs"][tracer_id][ddi]["adj_out"]``
        controlvect: the control vector; ``controlvect.datavect`` describes
            the components and ``controlvect.dx`` is incremented in place
        obsvect: unused in this routine
        mapper (dict): ``mapper["outputs"]`` holds the
            ``(component, tracer)`` identifiers to process and, for each,
            the ``"input_dates"`` periods per simulation window
        di (datetime): starting date of the simulation window
        df (datetime): ending date of the simulation window
        mode: unused in this routine
        runsubdir: unused in this routine
        workdir (str): pycif working directory (unused in this routine)
        onlyinit (bool): if True, return immediately
        **kwargs: forwarded to ``tracer.read``

    Returns:
        None; ``controlvect.dx`` is updated in place.
    """
    if onlyinit:
        return

    # di and df may come in either order (presumably reversed in adjoint
    # mode); data are keyed by the earliest of the two
    ddi = min(di, df)
    datastore = inout_datastore["outputs"]
    datavect = controlvect.datavect
    tracer_ids = mapper["outputs"]

    # Loop over model sensitivities
    for tracer_id in tracer_ids:
        mod_input = tracer_id[0]
        trcr = tracer_id[1]

        # If this type of input is not considered in the control vector,
        # ignoring the model sensitivity
        component = getattr(
            getattr(datavect, "components", None), mod_input, None
        )
        parameters = getattr(component, "parameters", None)
        if parameters is None:
            debug(
                "The observation operator is sensitive to {} "
                "but your input vector doesn't "
                "include it as a component. Skipping it!".format(
                    mod_input
                )
            )
            continue

        # Skip tracers not in the control space
        tracer = getattr(parameters, trcr, None)
        if not getattr(tracer, "iscontrol", False):
            debug(
                "The observation operator is sensitive to {} as a {} "
                "but your control vector doesn't "
                "include it as a component. Skipping it!".format(
                    trcr, mod_input
                )
            )
            continue

        # Check that the model provides information about sensitivities
        if "adj_out" not in datastore[tracer_id][ddi]:
            info(
                "Couldn't get any model sensitivity. "
                "Assuming zero sensitivity"
            )
            continue

        # Process other input types:
        # - re-project map sensitivities to control space
        # - sum date slices in the sensitivities to control space periods
        # NOTE(review): the mapper is indexed with ``di`` while every other
        # lookup uses ``ddi``; both agree unless di > df -- confirm intended.
        period_dates = tracer_ids[tracer_id]["input_dates"][di]

        sensit = datastore[tracer_id][ddi]["adj_out"]
        sensit_data = sensit.data
        data_dates = pd.DatetimeIndex(sensit.time.data).to_pydatetime()

        # Position of each data date among the control vector dates;
        # data dates before the first control date map to the first one
        control_indexes = (
            pd.Series(range(tracer.ndates), index=tracer.dates)
            .reindex(data_dates, method="ffill")
            .bfill()
        )
        # Position of each data date among the tracer input files
        input_dates = tracer.input_dates[ddi]
        input_indexes = pd.Series(
            range(len(input_dates)),
            index=np.array(input_dates)[:, 0],
        ).reindex(data_dates, method="ffill")

        # The k-th period is matched with the k-th entry of the positional
        # series above, so iterate positionally rather than on row labels
        for k, (_, period) in enumerate(period_dates.iterrows()):
            # Either take the corresponding slice of time,
            # or take the exact date
            # if the control variable is on a time stamp
            mask = _period_mask(data_dates, period)
            control_slice = control_indexes.iloc[k]
            data_slice = input_indexes.iloc[k]

            # If control_slice is nan,
            # it means that the date in the sensitivity
            # is not in the control vector, hence skipping
            if np.isnan(control_slice):
                warning(
                    f"Skipping date {period} for {tracer_id} "
                    "as not in the control vector"
                )
                continue

            # For variables stored as a scaling factor,
            # scaling by the original value
            phys = np.ones((mask.sum(), 1, 1, 1))
            if getattr(tracer, "type", "scalar") == "scalar":
                # NaN padding during reindexing turns positions into
                # floats; .iloc and list indexing need a real integer
                data_slice = int(data_slice)

                # Turn dates to datetime for consistency in read
                # TODO: This should be standardized in the future
                dates2read = [
                    [x.to_pydatetime() for x in row]
                    for row in input_dates.iloc[[data_slice]].itertuples(
                        index=False, name=None
                    )
                ]
                phys = tracer.read(
                    trcr,
                    tracer.varname,
                    dates2read,
                    [tracer.input_files[ddi][data_slice]],
                    comp_type=mod_input,
                    tracer=tracer,
                    model=transform.model,
                    ddi=ddi,
                    debug_read=True,
                    **kwargs
                ).data
                phys = np.where(np.isnan(phys), 0, phys)

            # Vertical aggregation
            vdata = np.sum(sensit_data[mask] * phys, axis=0)
            vaggreg = vmap2vaggreg(vdata[np.newaxis], tracer,
                                   tracer.domain, tracer_id)

            # 2d maps to control vector slices: each control period owns a
            # contiguous chunk of hresoldim * vresoldim values inside the
            # tracer's part of dx (a view, so += updates dx in place)
            chunk = tracer.hresoldim * tracer.vresoldim
            islice = int(control_slice)
            tracer_dx = controlvect.dx[
                tracer.xpointer: tracer.xpointer + tracer.dim
            ]
            tracer_dx[islice * chunk: (islice + 1) * chunk] += map2scale(
                vaggreg, tracer, tracer.domain
            ).flatten()