# Source code for pycif.plugins.transforms.system.fromcontrol.forward

import numpy as np
import xarray as xr
import pandas as pd
import copy

from logging import warning, info
from .utils.dates import dateslice
from .....utils.dataarrays.reindex import reindex
from .utils.scalemaps import scale2map
from .forward_perturb import forward_perturb


def forward(
        transform,
        inout_datastore,
        controlvect,
        obsvect,
        mapper,
        di,
        df,
        mode,
        runsubdir,
        workdir,
        **kwargs
):
    """Project the control vector onto physical model-input fields.

    For each output tracer declared in the mapper:

    * **Non-control tracers** (``iscontrol = False``) — when loading is
      forced (``force_loadin``), input files are read via ``tracer.read``
      and stored verbatim (wrapped as ``'spec'`` unless the tracer holds
      sparse data, in which case a null ``('maindata', 'incr')`` entry is
      added).
    * **Scalar control tracers** (``type = 'scalar'``, default) — the
      control-vector slice is unpacked to a map (:func:`scale2map`),
      re-indexed to the merged output date grid, and multiplied
      element-wise by the prior physical input field read from disk.
      If no control-vector period overlaps the sub-simulation, the prior
      field is kept unscaled (scale of 1, increment of 0).
    * **Physical control tracers** (``type = 'physical'``) — the
      control-vector slice is projected directly to physical space at the
      model domain resolution without reading any prior files.

    In tangent-linear mode (``mode = 'tl'``), the same operations apply to
    ``controlvect.dx`` in parallel, producing an ``'incr'`` field.

    For ensemble runs (``perturb_xb = True`` with ``'__sample#'`` tracers),
    delegates to :func:`forward_perturb` to handle all members in one I/O
    pass.

    Results are written in place to
    ``inout_datastore["outputs"][trid][min(di, df)]``; nothing is returned.

    Args:
        transform (Plugin): fromcontrol transform instance.
        inout_datastore (dict): mutable datastore (``'inputs'``,
            ``'outputs'``).
        controlvect: control vector plugin (``x``, ``dx``).
        obsvect: observation vector plugin (unused here, only forwarded
            to :func:`forward_perturb`).
        mapper (dict): transform mapper with output tracer metadata.
        di (datetime): sub-simulation start date.
        df (datetime): sub-simulation end date.
        mode (str): ``'fwd'``, ``'tl'``, or ``'adj'``.
        runsubdir (str): sub-simulation run directory.
        workdir (str): root working directory.
        **kwargs: forwarded to ``tracer.read``.

    Raises:
        AttributeError: a non-control tracer must be read but has no
            ``read`` method.
        NotImplementedError: the inputs of a scalar control tracer are read
            as a :class:`pandas.DataFrame` (sparse data).
        Exception: a tracer flagged ``is_lbc`` is read from full-domain
            files.
        ValueError: propagated from the scalar projection, after logging a
            hint about LBC alignment.
    """
    out_mapper = mapper["outputs"]

    # If perturb_xb and samples, speed up the process
    if getattr(transform, "perturb_xb", False) and \
            any('__sample#' in tr[1] for tr in out_mapper):
        forward_perturb(
            transform, inout_datastore, controlvect, obsvect, mapper,
            di, df, mode, runsubdir, workdir, **kwargs
        )
        return

    # All per-date dictionaries are keyed by the earliest bound
    ddi = min(di, df)
    xmod = inout_datastore["outputs"]

    for trid in out_mapper:
        comp = trid[0]
        trcr = trid[1]
        trid_mapper = out_mapper[trid]

        # Force reading inputs
        # This can be a dictionary based on dates,
        # originating from time_interpolation behaviour
        force_loadin = trid_mapper.get("force_loadin", False)
        if isinstance(force_loadin, dict):
            force_loadin = force_loadin.get(ddi, False)

        in_files = trid_mapper["input_files"].get(ddi, [])
        in_dates = trid_mapper["input_dates"].get(ddi, [])
        tracer = trid_mapper["tracer"]

        # Parameters not in the control space: read them only if required,
        # and never go through the control-vector projection below
        if not getattr(tracer, "iscontrol", False):
            if force_loadin:
                if not hasattr(tracer, "read"):
                    raise AttributeError(
                        f"Needing to read data for {trid}, "
                        f"but no read function is implemented for {tracer}"
                    )

                inputs = tracer.read(
                    trcr,
                    tracer.varname,
                    _dates_to_pydatetime(in_dates),
                    in_files,
                    comp_type=comp,
                    tracer=tracer,
                    model=transform.model,
                    ddi=ddi,
                    **kwargs
                )

                if getattr(tracer, "sparse_data", False):
                    # Directly store what was read for sparse data
                    inputs[("maindata", "incr")] = 0.
                    xmod[trid][ddi] = inputs
                else:
                    xmod[trid][ddi] = xr.Dataset({"spec": inputs})
                    # Force increments to zero if tangent linear
                    if mode == "tl":
                        xmod[trid][ddi]["incr"] = \
                            0 * xmod[trid][ddi]["spec"]
            continue

        # Otherwise, reformat from x (and from dx if tangent-linear)
        variables = {"scale": controlvect.x}
        if mode == "tl":
            variables["incr"] = controlvect.dx

        # Control-vector periods overlapping the sub-simulation; the first
        # period is moved to start at ddi when it begins before it
        if len(in_dates) == 0:
            dslice = []
        else:
            dslice = dateslice(
                tracer,
                np.min(in_dates["start_date"]),
                np.max(in_dates["end_date"]),
            )
        cdates = tracer.dates[dslice]
        if len(dslice) == 1 and cdates[0] < ddi:
            cdates[0] = ddi
        elif len(dslice) > 1 and cdates[0] < ddi < cdates[1]:
            cdates[0] = ddi

        # Translates only control variables corresponding to the
        # simulation period. Only this tracer's chunk is copied (np.array
        # copies) so that the in-place edits below can never alias the
        # control vector, without copying the whole vector per tracer.
        xmod_out = {}
        for key, val in variables.items():
            chunk = np.array(val[tracer.xpointer: tracer.xpointer + tracer.dim])
            tmp = np.reshape(
                chunk, (tracer.ndates, tracer.vresoldim, -1)
            )[dslice]
            # Deals with different resolutions
            xmod_out[key] = scale2map(tmp, tracer, cdates, tracer.domain)

        tracer_type = getattr(tracer, "type", "scalar")
        if tracer_type == "scalar":
            dates2read = _dates_to_pydatetime(tracer.input_dates[ddi])

            # Read the tracer array and apply the present control vector
            # scaling factor
            try:
                inputs = tracer.read(
                    trcr,
                    tracer.varname,
                    dates2read,
                    tracer.input_files[ddi],
                    comp_type=comp,
                    tracer=tracer,
                    model=transform.model,
                    ddi=ddi,
                    **kwargs
                )

                # Sparse (DataFrame) inputs cannot be scaled yet
                if isinstance(inputs, pd.DataFrame):
                    raise NotImplementedError(
                        f"The inputs for {trcr} are DataFrames, hence sparse data. \n"
                        "This is not covered by CIF yet! \n"
                        "Input files: \n"
                        f"{tracer.input_files[ddi][0]}"
                    )

                # Reprojecting inputs and scaling factors to the same merged
                # time steps
                outdates = _merged_output_dates(
                    trid_mapper["input_dates"][ddi])

                # NaNs in scale come from regions not included in the
                # control vector: set to one to keep native inputs
                _fill_region_nans(xmod_out, tracer, mode, 1.)

                inputs = reindex(
                    inputs,
                    levels={"time": outdates[:-1].values},
                )

                # Without any control period in the window, the prior is
                # kept unscaled (and its increment is zero, see below)
                scale = 1
                if len(cdates) != 0:
                    scale = reindex(
                        xmod_out["scale"],
                        levels={"time": outdates[:-1], "lev": inputs.lev},
                    )

                # Check that dimensions are aligned
                if getattr(tracer, "is_lbc", False) and inputs.shape[2] != 1:
                    raise Exception(
                        f"WARNING! The tracer {trid} has option `is_lbc` "
                        f"whereas the corresponding data sets are full domain"
                        f" in {tracer.input_files[ddi]}. "
                        f"Either provide flat lateral boundary condition files "
                        f"or remove the option `is_lbc` in your yml"
                    )

                xmod_out["spec"] = inputs * scale

                if mode == "tl":
                    incr = 0
                    if len(cdates) != 0:
                        incr = reindex(
                            xmod_out["incr"],
                            levels={"time": outdates[:-1], "lev": inputs.lev},
                        )
                    xmod_out["incr"] = incr * inputs

            except ValueError:
                warning("I may have problem in alignment to LBC in the "
                        "target vector. If components of the target vector "
                        "correspond to LBC, please make sure that the "
                        "attribute 'is_lbc' is set as True in the "
                        "configuration yaml.")
                raise

        # Data already contains the correct info for physical control variables
        # WARNING: so far, assumes that the vertical resolution is already
        # correct
        elif tracer_type == "physical":
            # Reprojecting inputs and scaling factors to the same merged
            # time steps
            outdates = _merged_output_dates(trid_mapper["input_dates"][ddi])

            # NaNs in scale come from regions not included in the control
            # vector: set to zero in physical mode
            _fill_region_nans(xmod_out, tracer, mode, 0)

            time_steps = outdates[:-1]
            vert_levels = range(trid_mapper["domain"].nlev)

            xmod_out["spec"] = reindex(
                xmod_out["scale"],
                levels={"time": time_steps, "lev": vert_levels},
            )
            if mode == "tl":
                xmod_out["incr"] = reindex(
                    xmod_out["incr"],
                    levels={"time": time_steps, "lev": vert_levels},
                )

        # Removing the scaling factor as all information is stored in
        # 'spec' and 'incr' now
        xmod_out.pop("scale")
        xmod[trid][ddi] = xr.Dataset(xmod_out)


def _dates_to_pydatetime(date_table):
    """Convert each row of a date table to a list of ``datetime`` objects.

    ``date_table`` is iterated with ``itertuples(index=False, name=None)``
    and each cell's ``to_pydatetime()`` is called, so it is assumed to be a
    DataFrame of pandas Timestamps (TODO confirm for all tracers).
    """
    # Turn dates to datetime for consistency in read
    # TODO: This should be standardized in the future
    return [
        [x.to_pydatetime() for x in row]
        for row in date_table.itertuples(index=False, name=None)
    ]


def _merged_output_dates(date_table):
    """Return every date in ``date_table`` as a sorted, unique DatetimeIndex.

    A single date is duplicated so that ``outdates[:-1]`` (the time steps
    actually used downstream) still contains one element.
    """
    outdates = pd.to_datetime(np.unique(np.array(date_table).flatten()))
    if len(outdates) == 1:
        outdates = outdates.append(outdates)
    return outdates


def _fill_region_nans(xmod_out, tracer, mode, scale_fill):
    """Replace NaNs coming from regions outside the control vector.

    Only applies when ``tracer.hresol == 'regions'``. NaNs of
    ``xmod_out['scale']`` become ``scale_fill``; in tangent-linear mode the
    NaNs of ``xmod_out['incr']`` become 0. Arrays are modified in place.
    """
    if tracer.hresol != "regions":
        return
    xmod_out["scale"].values[np.isnan(xmod_out["scale"]).values] = scale_fill
    if mode == "tl":
        xmod_out["incr"].values[np.isnan(xmod_out["incr"]).values] = 0