# Source code for pycif.plugins.transforms.system.toobsvect.adjoint

import pandas as pd
import copy
import xarray as xr
import numpy as np
from logging import info, debug

from .....utils.datastores.empty import init_empty
from .....utils.classes.domains import Domain
from .....utils.datastores.crop_monitor import crop_monitor
from .....utils.dates import date_range


def _fetch_obs_parameter(datavect, mod_input, trcr):
    """Return the observation-vector parameter for ``mod_input``/``trcr``.

    Looks up ``datavect.components.<mod_input>.parameters.<trcr>`` and
    returns it only if it exists and its ``isobs`` flag is truthy.
    Otherwise the reason is logged at debug level and ``None`` is returned.

    Args:
        datavect: data vector attached to the observation vector.
        mod_input (str): model input type (component name).
        trcr (str): tracer (parameter) name.

    Returns:
        The parameter object, or ``None`` if it is not part of the
        observation vector.
    """
    component = getattr(getattr(datavect, "components", None), mod_input, None)
    parameters = getattr(component, "parameters", None)

    # Check from the coarsest to the finest level so that each debug
    # message describes the actual reason for skipping
    if parameters is None:
        debug(
            f"The observation operator simulates {mod_input} but your "
            "observation vector doesn't include it as a parameter"
        )
        return None

    param = getattr(parameters, trcr, None)
    if param is None:
        debug(
            f"The observation operator simulates {trcr} as a {mod_input} "
            "but your observation vector doesn't include it as a component"
        )
        return None

    if not getattr(param, "isobs", False):
        debug(f"{trcr}/{mod_input} was not in the observation vector. Passing")
        return None

    return param


def adjoint(
    transform,
    inout_datastore,
    controlvect,
    obsvect,
    mapper,
    di,
    df,
    mode,
    runsubdir,
    workdir,
    onlyinit=False,
    **kwargs
):
    """Read observation-vector sensitivities and populate the input datastore.

    The adjoint of :func:`forward`: reads ``obsvect.dy`` (the
    observation-space adjoint increment, e.g. set by the simulator to
    :math:`\\mathbf{R}^{-1}(\\mathbf{H}\\mathbf{x} - \\mathbf{y})`) and writes
    it into the ``'adj_out'`` column of each tracer's sparse datastore entry.
    It also copies ``obsvect.ysim`` into ``'spec'`` so that downstream
    adjoint transforms can access the reference forward simulation, and
    sets ``'incr'`` to ``0. * ysim``.

    Locations outside the active observation mask
    (``obsvect.obsvect_mask``) have ``'incr'`` and ``'adj_out'`` zeroed out.
    When ``split_freq`` is configured, only the observations in the current
    sub-window are processed and their dates are clipped to the window
    boundaries.

    If the first (reference) input tracer is not part of the observation
    vector, nothing is done. Other input tracers that are not part of the
    observation vector are skipped.

    Args:
        transform (Plugin): toobsvect transform instance.
        inout_datastore (dict): datastore; ``'inputs'`` entries are populated
            with sparse DataFrames carrying ``spec``, ``incr``, and
            ``adj_out`` columns, keyed by the sub-simulation start date.
        controlvect: control vector plugin (unused).
        obsvect: observation vector plugin; ``ysim``, ``dy`` and
            ``obsvect_mask`` are read.
        mapper (dict): transform mapper.
        di (datetime): sub-simulation start date.
        df (datetime): sub-simulation end date.
        mode (str): ``'adj'``.
        runsubdir (str): sub-simulation run directory (unused).
        workdir (str): root working directory (unused).
        onlyinit (bool): unused; kept for interface consistency.
        **kwargs: unused.
    """
    # di/df may come in either order; the earliest one keys the sub-period
    ddi = min(di, df)

    datastore = inout_datastore["inputs"]
    datavect = obsvect.datavect

    # Use the first input tracer as reference. If this type of input is not
    # considered in the observation vector, ignore the model sensitivity.
    tracer_ref = list(mapper["inputs"])[0]
    ref_param = _fetch_obs_parameter(datavect, tracer_ref[0], tracer_ref[1])
    if ref_param is None:
        return

    # Keep only the sub-period if `split_freq` is specified.
    # The crop is computed from the reference tracer's datastore.
    ds_meta = ref_param.datastore
    crop_index = np.arange(len(ds_meta), dtype=int)
    weight = 1.0
    has_split = hasattr(transform, "split_freq")
    if has_split:
        ds_dates = ds_meta[
            [
                ("metadata", "date"),
                ("metadata", "duration"),
                ("metadata", "enddate"),
            ]
        ]
        obsvect_dates = date_range(
            transform.datei, transform.datef, period=transform.split_freq
        )
        # ddi is expected to be one of the split boundaries; the current
        # sub-window ends at the following boundary
        i0 = np.argwhere(obsvect_dates == ddi)[0][0]
        split_ddf = obsvect_dates[i0 + 1]
        crop_index, weight = crop_monitor(
            ds_dates,
            ddi,
            split_ddf,
            return_index=True,
            keep_partial=True,
            return_weight=True,
        )

    # Loop over tracers to get values from the observation vector
    for tracer_id in datastore:
        param = _fetch_obs_parameter(datavect, tracer_id[0], tracer_id[1])
        if param is None:
            continue

        # Crop according to dates to process
        data = copy.copy(param.datastore.iloc[crop_index]).reset_index(
            drop=True
        )

        # Stop here if empty
        if len(data) == 0:
            continue

        # Clip observation dates to the sub-window boundaries
        if has_split:
            data.loc[:, ("metadata", "date")] = np.maximum(
                data.loc[:, ("metadata", "date")].dt.to_pydatetime(), ddi
            )
            data.loc[:, ("metadata", "enddate")] = np.minimum(
                data.loc[:, ("metadata", "enddate")].dt.to_pydatetime(),
                split_ddf,
            )

        # This tracer's segment of the observation vector
        y_slice = slice(param.ypointer, param.ypointer + param.dim)

        # Fetch values from the forward simulation and the adjoint increment
        ysim = obsvect.ysim[y_slice][crop_index]
        data.loc[:, ("maindata", "spec")] = ysim
        # 0. * ysim (rather than zeros) is kept on purpose: it preserves the
        # original behaviour, including NaN propagation from ysim
        data.loc[:, ("maindata", "incr")] = 0.0 * ysim
        data.loc[:, ("maindata", "adj_out")] = (
            obsvect.dy[y_slice][crop_index] * weight
        )

        # Set incr and adj_out to zero when not in the observation vector
        obsvect_mask = obsvect.obsvect_mask[y_slice][crop_index]
        data.loc[~obsvect_mask, ("maindata", "incr")] = 0
        data.loc[~obsvect_mask, ("maindata", "adj_out")] = 0

        datastore[tracer_id][ddi] = data