import numpy as np
import pandas as pd
from logging import info, debug, warning
from .utils.dates import dateslice
from .utils.scalemaps import map2scale, vmap2vaggreg
def _period_mask(period, data_dates):
    """Return a boolean mask selecting the ``data_dates`` in one control period.

    Args:
        period: one row of the control-period table. Either a pair
            ``(start, end)`` -- selecting dates in the half-open interval
            ``[start, end)``, or exactly ``start`` when ``start == end`` --
            or a single time stamp (a scalar or a one-element row).
        data_dates (np.ndarray): dates of the sensitivity time steps.

    Returns:
        np.ndarray: boolean array aligned with ``data_dates``.

    Raises:
        ValueError: if ``period`` is iterable but does not hold exactly one
            or two elements.
    """
    try:
        bounds = list(period)
    except TypeError:
        # Not iterable: the control variable is defined on a time stamp
        return data_dates == period

    if len(bounds) == 1:
        return data_dates == bounds[0]
    if len(bounds) != 2:
        raise ValueError(
            "Expected a (start, end) pair or a single time stamp as "
            f"control period, got {period!r}"
        )

    dd0, dd1 = bounds
    if dd0 == dd1:
        # Degenerate period: the control variable sits on a time stamp
        return data_dates == dd0
    return (data_dates >= dd0) & (data_dates < dd1)


def adjoint(
    transform,
    inout_datastore,
    controlvect,
    obsvect,
    mapper,
    di,
    df,
    mode,
    runsubdir,
    workdir,
    onlyinit=False,
    **kwargs
):
    """Accumulate model sensitivities into the control-vector gradient ``dx``.

    Adjoint counterpart of the ``fromcontrol`` forward transform. For each
    output tracer listed in ``mapper["outputs"]`` that belongs to the control
    space, reads the sensitivity field ``adj_out`` from the output datastore.
    It then adds the aggregated result into ``controlvect.dx`` in place.

    Per control-vector period, the aggregation:

    1. **Temporal aggregation**: sums the sensitivity time steps (first
       axis of ``adj_out``) that fall inside the period.
    2. **Prior scaling**: for ``scalar`` control variables (the default
       when ``tracer.type`` is unset), weights the sensitivity by the prior
       field returned by ``tracer.read``. NaNs in that field are set to 0.
    3. **Vertical aggregation**: applies :func:`vmap2vaggreg`.
    4. **Horizontal projection**: applies :func:`map2scale`. The flattened
       result is added to the matching slice of the tracer's segment of
       ``controlvect.dx``.

    Tracers whose input type is not a component of ``controlvect.datavect``
    are skipped with a debug message, as are tracers that are not flagged
    ``iscontrol``. Tracers with no ``adj_out`` are skipped with an info
    message, meaning zero sensitivity is assumed.

    Args:
        transform (Plugin): fromcontrol transform instance; its ``model``
            attribute is forwarded to ``tracer.read``.
        inout_datastore (dict): datastore; ``inout_datastore["outputs"]``
            maps ``tracer_id -> date -> dict`` holding the ``adj_out`` field.
        controlvect: control vector plugin; ``dx`` is incremented in place.
        obsvect: observation vector plugin (unused here).
        mapper (dict): transform mapper; ``mapper["outputs"]`` maps
            ``(input_type, tracer_name)`` ids to metadata that include the
            ``input_dates`` period table.
        di (datetime): sub-simulation start date.
        df (datetime): sub-simulation end date.
        mode (str): run mode (unused here).
        runsubdir (str): sub-simulation run directory (unused here).
        workdir (str): root working directory (unused here).
        onlyinit (bool): if ``True``, return immediately.
        **kwargs: forwarded to ``tracer.read`` for prior scaling.

    Raises:
        ValueError: if a scalar control variable's period has no matching
            input date to read the prior value from, or if a control period
            row is malformed.
    """
    if onlyinit:
        return

    # Per-simulation data are keyed by the earlier of the two dates; di and
    # df are not assumed to be ordered
    ddi = min(di, df)
    datastore = inout_datastore["outputs"]
    tracer_ids = mapper["outputs"]
    components = getattr(controlvect.datavect, "components", None)

    for tracer_id in tracer_ids:
        mod_input = tracer_id[0]
        trcr = tracer_id[1]

        # Input types absent from the control vector: sensitivity ignored
        component = getattr(components, mod_input, None)
        parameters = getattr(component, "parameters", None)
        if parameters is None:
            debug(
                f"The observation operator is sensitive to {mod_input} but your input vector doesn't include it as a component. Skipping it!"
            )
            continue

        # Tracers outside the control space: sensitivity ignored
        tracer = getattr(parameters, trcr, None)
        if not getattr(tracer, "iscontrol", False):
            debug(
                f"The observation operator is sensitive to {trcr} as a {mod_input} but your control vector doesn't include it as a component. Skipping it!"
            )
            continue

        if "adj_out" not in datastore[tracer_id][ddi]:
            info(
                "Couldn't get any model sensitivity. "
                "Assuming zero sensitivity"
            )
            continue

        # NOTE(review): the period table is keyed by ``di`` while every other
        # per-simulation lookup uses ``ddi``; confirm this is intended when
        # di > df.
        period_dates = tracer_ids[tracer_id]["input_dates"][di]

        sensit = datastore[tracer_id][ddi]["adj_out"]
        sensit_data = sensit.data
        data_dates = pd.DatetimeIndex(sensit.time.data).to_pydatetime()

        # Position of each sensitivity date among the control-vector dates.
        # Dates before the first control date are attached to the first one.
        control_indexes = (
            pd.Series(range(tracer.ndates), index=tracer.dates)
            .reindex(data_dates, method="ffill")
            .bfill()
        )

        # Position of each sensitivity date among the prior input dates
        # (first column of the input-date table); NaN when none precedes it
        input_dates = tracer.input_dates[ddi]
        input_indexes = pd.Series(
            range(len(input_dates)),
            index=np.array(input_dates)[:, 0],
        ).reindex(data_dates, method="ffill")

        is_scalar = getattr(tracer, "type", "scalar") == "scalar"
        chunk = tracer.hresoldim * tracer.vresoldim

        for k, period in period_dates.iterrows():
            mask = _period_mask(period, data_dates)
            control_slice = control_indexes.iloc[k]
            data_slice = input_indexes.iloc[k]

            # NaN means the date is not covered by the control vector
            if np.isnan(control_slice):
                warning(
                    f"Skipping date {period} for {tracer_id} "
                    "as not in the control vector"
                )
                continue

            # Neutral weight unless the variable is a scaling factor
            phys = np.ones((mask.sum(), 1, 1, 1))
            if is_scalar:
                # reindex may upcast to float (NaN-able); positional indexing
                # below needs a genuine int
                if pd.isna(data_slice):
                    raise ValueError(
                        f"No input date found for period {period} of "
                        f"{tracer_id}; cannot scale the sensitivity by the "
                        "prior value"
                    )
                data_slice = int(data_slice)

                # Turn dates to datetime for consistency in read
                # TODO: This should be standardized in the future
                dates2read = [
                    [x.to_pydatetime() for x in row]
                    for row in input_dates.iloc[[data_slice]].itertuples(
                        index=False, name=None
                    )
                ]
                phys = tracer.read(
                    trcr,
                    tracer.varname,
                    dates2read,
                    [tracer.input_files[ddi][data_slice]],
                    comp_type=mod_input,
                    tracer=tracer,
                    model=transform.model,
                    ddi=ddi,
                    debug_read=True,
                    **kwargs
                ).data
                phys = np.where(np.isnan(phys), 0, phys)

            # Sum the period's time steps, then reduce to the control
            # vector's vertical resolution
            vdata = np.sum(sensit_data[mask] * phys, axis=0)
            vaggreg = vmap2vaggreg(
                vdata[np.newaxis], tracer, tracer.domain, tracer_id
            )

            # Project to the control horizontal resolution and accumulate
            # into this period's slice of the tracer's dx segment. This
            # assumes dx is a NumPy array, so the slices are views and += is
            # in place.
            start = int(control_slice) * chunk
            segment = controlvect.dx[
                tracer.xpointer: tracer.xpointer + tracer.dim
            ]
            segment[start: start + chunk] += map2scale(
                vaggreg, tracer, tracer.domain
            ).flatten()