Source code for pycif.plugins.transforms.system.fromcontrol.adjoint

import numpy as np
import pandas as pd

from logging import info, debug, warning
from .utils.dates import dateslice
from .utils.scalemaps import map2scale, vmap2vaggreg


def _period_mask(period, data_dates):
    """Return a boolean mask selecting the entries of ``data_dates`` that
    fall in ``period``.

    ``period`` is normally a ``(start, end)`` pair: dates in ``[start, end)``
    are selected, or exactly ``start`` when ``start == end`` (the control
    variable is then defined on a time stamp). If ``period`` cannot be
    unpacked/compared as a pair (``TypeError``), it is treated as a single
    time stamp and matched exactly.
    """
    try:
        dd0, dd1 = period
        mask = (data_dates >= dd0) & (data_dates < dd1)
        # A zero-length period is a time stamp: match it exactly
        if dd0 == dd1:
            mask = data_dates == dd0
    except TypeError:
        mask = data_dates == period
    return mask


def adjoint(
    transform,
    inout_datastore,
    controlvect,
    obsvect,
    mapper,
    di,
    df,
    mode,
    runsubdir,
    workdir,
    onlyinit=False,
    **kwargs
):
    """Translates real-size data as extracted from the model outputs to the
    control space. This includes mainly temporal and spatial aggregation.

    This routine is used in adjoint mode, thus computes operations on
    increments: the model sensitivities are aggregated and added in place to
    ``controlvect.dx``.

    Args:
        transform (Plugin): the transform; its ``model`` attribute is
            forwarded to the tracers' ``read`` method
        inout_datastore (dict): data exchanged with the model; sensitivities
            are read from ``inout_datastore["outputs"][tracer_id][ddi]``
            under the ``"adj_out"`` key
        controlvect (Plugin): the control vector; ``controlvect.datavect``
            describes the components/parameters and ``controlvect.dx`` is
            incremented in place
        obsvect: not used here
        mapper (dict): ``mapper["outputs"]`` maps each
            ``(component, parameter)`` pair the model is sensitive to, to a
            dict whose ``"input_dates"`` entry is keyed by ``di``
        di (datetime): starting date of the simulation window
        df (datetime): ending date of the simulation window
        mode, runsubdir, workdir: not used here
        onlyinit (bool): if True, return immediately without doing anything
        **kwargs: forwarded to the tracers' ``read`` method

    Returns:
        None

    Raises:
        ValueError: if a scaling-factor tracer has no input date at or
            before a sensitivity date
    """
    if onlyinit:
        return

    ddi = min(di, df)
    datastore = inout_datastore["outputs"]
    datavect = controlvect.datavect
    tracer_ids = mapper["outputs"]

    # Loop over model sensitivities
    for tracer_id in tracer_ids:
        mod_input = tracer_id[0]
        trcr = tracer_id[1]

        # If this type of input is not considered in the control vector,
        # ignore the model sensitivity
        component = getattr(
            getattr(datavect, "components", None), mod_input, None
        )
        parameters = getattr(component, "parameters", None)
        if parameters is None:
            debug(
                "The observation operator is sensitive to {} "
                "but your input vector doesn't "
                "include it as a component. Skipping it!".format(mod_input)
            )
            continue

        # Skip tracers not in the control space (missing or not iscontrol)
        tracer = getattr(parameters, trcr, None)
        if not getattr(tracer, "iscontrol", False):
            debug(
                "The observation operator is sensitive to {} as a {} "
                "but your control vector doesn't "
                "include it as a component. Skipping it!".format(
                    trcr, mod_input
                )
            )
            continue

        # Check that the model provides information about sensitivities
        if "adj_out" not in datastore[tracer_id][ddi]:
            info(
                "Couldn't get any model sensitivity. "
                "Assuming zero sensitivity"
            )
            continue

        # Process other input types:
        # - re-project map sensitivities to control space
        # - sum date slices in the sensitivities to control space periods
        period_dates = mapper["outputs"][tracer_id]["input_dates"][di]

        sensit = datastore[tracer_id][ddi]["adj_out"]
        sensit_data = sensit.data
        data_dates = pd.DatetimeIndex(sensit.time.data).to_pydatetime()

        # Position in the control vector of the latest control date at or
        # before each sensitivity date; leading gaps are back-filled.
        # (``fillna(method=...)`` is deprecated in pandas >= 2.1.)
        control_indexes = (
            pd.Series(range(tracer.ndates), index=tracer.dates)
            .reindex(data_dates, method="ffill")
            .bfill()
        )
        # Row of tracer.input_dates[ddi] (keyed on its first column) at or
        # before each sensitivity date; NaN for dates before the first row
        input_indexes = pd.Series(
            range(len(tracer.input_dates[ddi])),
            index=np.array(tracer.input_dates[ddi])[:, 0],
        ).reindex(data_dates, method="ffill")

        # Loop over control space periods for temporal aggregation
        # NOTE(review): ``k`` is a row label of ``period_dates`` but is used
        # as a position in the series above; this assumes a default
        # RangeIndex aligned with the sensitivity dates — TODO confirm
        for k, period in period_dates.iterrows():
            mask = _period_mask(period, data_dates)
            control_slice = control_indexes.iloc[k]

            # A NaN position means the date in the sensitivity is not in
            # the control vector, hence skipping
            if np.isnan(control_slice):
                warning(
                    f"Skipping date {period} for {tracer_id} "
                    "as not in the control vector"
                )
                continue

            # For variables stored as a scaling factor,
            # scale by the original value
            phys = np.ones((mask.sum(), 1, 1, 1))
            if getattr(tracer, "type", "scalar") == "scalar":
                data_slice = input_indexes.iloc[k]
                # input_indexes becomes float as soon as one date has no
                # input row; positional indexing below needs a real int
                if np.isnan(data_slice):
                    raise ValueError(
                        f"No input date covers period {period} "
                        f"for {tracer_id}"
                    )
                data_slice = int(data_slice)

                # Turn dates to datetime for consistency in read
                # TODO: This should be standardized in the future
                dates2read = [
                    [x.to_pydatetime() for x in row]
                    for row in tracer.input_dates[ddi]
                    .iloc[[data_slice]]
                    .itertuples(index=False, name=None)
                ]
                phys = tracer.read(
                    trcr,
                    tracer.varname,
                    dates2read,
                    [tracer.input_files[ddi][data_slice]],
                    comp_type=mod_input,
                    tracer=tracer,
                    model=transform.model,
                    ddi=ddi,
                    debug_read=True,
                    **kwargs
                ).data
                phys = np.where(np.isnan(phys), 0, phys)

            # Vertical aggregation
            vdata = np.sum(sensit_data[mask] * phys, axis=0)
            vaggreg = vmap2vaggreg(
                vdata[np.newaxis], tracer, tracer.domain, tracer_id
            )

            # 2d maps to control vector slices (in-place on a view of dx)
            slice_size = tracer.hresoldim * tracer.vresoldim
            start = int(control_slice) * slice_size
            stop = int(control_slice + 1) * slice_size
            controlvect.dx[tracer.xpointer: tracer.xpointer + tracer.dim][
                start:stop
            ] += map2scale(vaggreg, tracer, tracer.domain).flatten()