# Source code for pycif.plugins.transforms.system.fromcontrol.adjoint

import numpy as np
import pandas as pd

from logging import info, debug, warning
from .utils.dates import dateslice
from .utils.scalemaps import map2scale, vmap2vaggreg


def adjoint(
    transform,
    inout_datastore,
    controlvect,
    obsvect,
    mapper,
    di,
    df,
    mode,
    runsubdir,
    workdir,
    onlyinit=False,
    **kwargs
):
    """Accumulate model sensitivities into the control-vector gradient ``dx``.

    The adjoint of the ``fromcontrol`` forward transform. For each control
    tracer, the model sensitivity field ``adj_out`` (shape:
    time x level x lat x lon) is taken from the output datastore and
    aggregated back into the flattened control-vector gradient
    ``controlvect.dx``. The aggregation applies:

    1. **Temporal aggregation**: sensitivity dates are mapped to
       control-vector periods through a date-indexed lookup. Dates that
       precede the first control date are attributed to the first period.
    2. **Vertical aggregation**: :func:`vmap2vaggreg` reduces the 3-D
       sensitivity to the vertical resolution of the control vector.
    3. **Horizontal projection**: :func:`map2scale` maps the spatial
       sensitivity to the control-vector horizontal resolution.
    4. **Prior scaling**: for ``"scalar"`` control variables (the default
       ``type``), the sensitivity is weighted by the prior physical field
       read with ``tracer.read``. NaNs in that field are replaced by 0.

    The following cases are skipped rather than treated as errors:

    * Tracers whose component is absent from ``controlvect.datavect``, and
      tracers that are not flagged ``iscontrol``. These are logged at
      debug level.
    * Tracers without ``adj_out``. These are logged at info level and
      treated as having zero sensitivity.
    * Periods that cannot be located in the control vector. These are
      logged as a warning.

    Args:
        transform (Plugin): fromcontrol transform instance; only
            ``transform.model`` is used, for prior scaling.
        inout_datastore (dict): datastore. ``'outputs'`` holds, per tracer
            id and start date, the model-space sensitivities (``adj_out``).
        controlvect: control vector plugin. ``dx`` must already be
            allocated; it is incremented in place.
        obsvect: observation vector plugin (unused).
        mapper (dict): transform mapper. ``mapper["outputs"]`` maps tracer
            ids ``(component, tracer, ...)`` to metadata including
            ``input_dates``.
        di (datetime): sub-simulation start date.
        df (datetime): sub-simulation end date.
        mode (str): run mode, expected to be ``'adj'`` (unused here).
        runsubdir (str): sub-simulation run directory (unused here).
        workdir (str): root working directory (unused here).
        onlyinit (bool): if ``True``, return immediately (dry-run pass).
        **kwargs: forwarded to ``tracer.read`` for prior scaling.

    Raises:
        ValueError: if a scalar tracer needs its prior field for a
            sensitivity date that precedes all of its input files.
    """
    if onlyinit:
        return

    # Data is stored under the earliest of the two bounds, whatever the
    # direction of time integration
    ddi = min(di, df)

    datastore = inout_datastore["outputs"]
    datavect = controlvect.datavect

    # Loop over model sensitivities
    for tracer_id in mapper["outputs"]:
        mod_input = tracer_id[0]
        trcr = tracer_id[1]

        # If this type of input is not considered in the control vector,
        # ignore the model sensitivity
        component = getattr(
            getattr(datavect, "components", None), mod_input, None
        )
        parameters = getattr(component, "parameters", None)
        if parameters is None:
            debug(
                f"The observation operator is sensitive to {mod_input} "
                "but your input vector doesn't include it as a component. "
                "Skipping it!"
            )
            continue

        # Skip tracers not in the control space (missing, None, or not
        # flagged as controlled)
        tracer = getattr(parameters, trcr, None)
        if not getattr(tracer, "iscontrol", False):
            debug(
                f"The observation operator is sensitive to {trcr} "
                f"as a {mod_input} but your control vector doesn't "
                "include it as a component. Skipping it!"
            )
            continue

        # Check that the model provides information about sensitivities
        if "adj_out" not in datastore[tracer_id][ddi]:
            info(
                "Couldn't get any model sensitivity. "
                "Assuming zero sensitivity"
            )
            continue

        # NOTE(review): period_dates is looked up with ``di`` while every
        # other per-period lookup uses ``ddi``; identical when di <= df.
        period_dates = mapper["outputs"][tracer_id]["input_dates"][di]

        sensit = datastore[tracer_id][ddi]["adj_out"]
        sensit_data = sensit.data
        data_dates = pd.DatetimeIndex(sensit.time.data).to_pydatetime()

        # For each sensitivity date, the position of the last control date
        # at or before it; leading gaps are back-filled to the first period.
        # NaN only survives when no control date precedes any data date.
        control_indexes = (
            pd.Series(range(tracer.ndates), index=tracer.dates)
            .reindex(data_dates, method="ffill")
            .bfill()
        )

        # For each sensitivity date, the position of the last input file
        # (keyed by its start date) at or before it
        tracer_input_dates = tracer.input_dates[ddi]
        input_indexes = pd.Series(
            range(len(tracer_input_dates)),
            index=np.array(tracer_input_dates)[:, 0],
        ).reindex(data_dates, method="ffill")

        # Number of control-vector elements per control period
        chunk = tracer.hresoldim * tracer.vresoldim

        # Rows of period_dates are matched to the lookups above by position
        for k, (_, period) in enumerate(period_dates.iterrows()):
            mask = _period_mask(data_dates, period)
            control_slice = control_indexes.iloc[k]

            # NaN means the sensitivity date is not in the control vector
            if np.isnan(control_slice):
                warning(
                    f"Skipping date {period} for {tracer_id} "
                    "as not in the control vector"
                )
                continue

            # For variables stored as a scaling factor, scale by the
            # original value; otherwise the weights are neutral
            if getattr(tracer, "type", "scalar") == "scalar":
                phys = _read_prior_field(
                    transform, tracer, trcr, mod_input, ddi,
                    input_indexes.iloc[k], kwargs,
                )
            else:
                phys = np.ones((mask.sum(), 1, 1, 1))

            # Temporal sum over the period, then vertical aggregation
            vdata = np.sum(sensit_data[mask] * phys, axis=0)
            vaggreg = vmap2vaggreg(
                vdata[np.newaxis], tracer, tracer.domain, tracer_id
            )

            # 2d maps to control vector slices (in-place on the dx view)
            first = int(control_slice) * chunk
            controlvect.dx[tracer.xpointer: tracer.xpointer + tracer.dim][
                first: first + chunk
            ] += map2scale(vaggreg, tracer, tracer.domain).flatten()


def _period_mask(data_dates, period):
    """Select the sensitivity dates that belong to one input period.

    Args:
        data_dates (np.ndarray): dates of the sensitivity time axis.
        period: a ``(start, end)`` pair (e.g. a two-column row from
            ``DataFrame.iterrows``), a one-element row, or a bare time
            stamp.

    Returns:
        np.ndarray: boolean mask over ``data_dates``. It selects
        ``start <= date < end`` for a proper period, and exact equality
        when the period is a time stamp (``start == end`` or a single
        value).

    Raises:
        ValueError: if ``period`` holds neither one nor two values.
    """
    if isinstance(period, (pd.Series, tuple, list)):
        bounds = tuple(period)
    else:
        bounds = (period,)

    if len(bounds) == 1:
        return data_dates == bounds[0]

    dd0, dd1 = bounds
    if dd0 == dd1:
        return data_dates == dd0
    return (data_dates >= dd0) & (data_dates < dd1)


def _read_prior_field(
    transform, tracer, trcr, mod_input, ddi, data_slice, read_kwargs
):
    """Read the prior physical field that weights a scaling-factor tracer.

    Args:
        transform (Plugin): fromcontrol transform (provides ``model``).
        tracer: control tracer parameters (provides ``read``, ``varname``,
            ``input_dates`` and ``input_files``).
        trcr (str): tracer name.
        mod_input (str): component name, passed as ``comp_type``.
        ddi (datetime): sub-simulation start date.
        data_slice: position of the input file in ``tracer.input_files[ddi]``
            and of its dates in ``tracer.input_dates[ddi]``. It may be a
            float (or NaN) because the reindexed lookup is upcast to float
            when it contains NaN.
        read_kwargs (dict): extra keyword arguments for ``tracer.read``.

    Returns:
        The ``.data`` of the field read, with NaNs replaced by 0.

    Raises:
        ValueError: if ``data_slice`` is NaN, i.e. no input file starts at
            or before the sensitivity date.
    """
    if np.isnan(data_slice):
        raise ValueError(
            f"No input file of {trcr} ({mod_input}) starts at or before "
            "the sensitivity date; cannot read the prior field used to "
            "scale the sensitivity"
        )
    # Floats cannot index Python lists: use an integer position
    file_pos = int(data_slice)

    # Turn dates to datetime for consistency in read
    # TODO: This should be standardized in the future
    dates2read = [
        [x.to_pydatetime() for x in row]
        for row in tracer.input_dates[ddi].iloc[[file_pos]].itertuples(
            index=False, name=None
        )
    ]
    phys = tracer.read(
        trcr,
        tracer.varname,
        dates2read,
        [tracer.input_files[ddi][file_pos]],
        comp_type=mod_input,
        tracer=tracer,
        model=transform.model,
        ddi=ddi,
        debug_read=True,
        **read_kwargs
    ).data
    return np.where(np.isnan(phys), 0, phys)