# Source code for pycif.plugins.transforms.system.fromcontrol.forward

import numpy as np
import xarray as xr
import pandas as pd
import copy

from logging import warning, info
from .utils.dates import dateslice
from .....utils.dataarrays.reindex import reindex
from .utils.scalemaps import scale2map
from .forward_perturb import forward_perturb


def _rows_to_pydatetime(dates):
    """Turn each row of a date table into a list of ``datetime`` objects.

    Dates are converted for consistency in ``tracer.read``.
    TODO: This should be standardized in the future.

    ``dates`` must provide ``itertuples`` (e.g. a pandas DataFrame) and its
    cells must provide ``to_pydatetime`` (e.g. pandas Timestamps).
    """
    return [
        [stamp.to_pydatetime() for stamp in row]
        for row in dates.itertuples(index=False, name=None)
    ]


def _merged_outdates(mapper, trid, di):
    """Return the sorted unique dates of ``input_dates[di]`` for ``trid``."""
    return pd.to_datetime(
        np.unique(
            np.array(mapper["outputs"][trid]["input_dates"][di]).flatten()
        )
    )


def forward(
        transform,
        inout_datastore,
        controlvect,
        obsvect,
        mapper,
        di,
        df,
        mode,
        runsubdir,
        workdir,
        **kwargs
):
    """Fill the transform outputs for one sub-simulation from the control vector.

    For every tracer id ``trid`` (a ``(component, tracer_name)`` key of
    ``mapper["outputs"]``), a result is stored in
    ``inout_datastore["outputs"][trid][ddi]`` with ``ddi = min(di, df)``:

    * tracers that are not control variables are skipped, unless
      ``force_loadin`` is set. In that case they are read with
      ``tracer.read`` and stored as ``xr.Dataset({"spec": inputs})``, or as
      the raw read result for sparse data;
    * ``"scalar"`` control tracers: native inputs read from file are
      multiplied by the control-vector scaling factors (``spec``);
    * ``"physical"`` control tracers: control-vector values are used
      directly as ``spec``.

    In tangent-linear mode (``mode == "tl"``), an ``incr`` variable is
    stored as well.

    If ``transform.perturb_xb`` is set and a tracer name contains
    ``"__sample#"``, all the work is delegated to ``forward_perturb``.

    Args:
        transform: transform plugin; ``perturb_xb`` and ``model`` are read.
        inout_datastore: dict whose ``"outputs"`` entry is filled in place.
        controlvect: control vector; ``x`` (and ``dx`` in "tl" mode) is used.
        obsvect: only forwarded to ``forward_perturb``.
        mapper: transform mapper. ``mapper["outputs"][trid]`` holds
            ``"tracer"``, ``"input_files"``, ``"input_dates"``, optionally
            ``"force_loadin"``, and ``"domain"`` for physical tracers.
        di, df: bounds of the sub-simulation period; the earlier one
            (``ddi``) keys the outputs.
        mode: running mode; ``"tl"`` also computes increments.
        runsubdir, workdir: only forwarded to ``forward_perturb``.
        **kwargs: forwarded to ``tracer.read`` and ``forward_perturb``.

    Raises:
        AttributeError: data must be read for a tracer without ``read``.
        Exception: a scalar control tracer reads sparse (DataFrame) inputs,
            or ``is_lbc`` is set while the inputs span the full domain.
        ValueError: re-raised, after a warning about LBC alignment, when
            reading/aligning a scalar control tracer fails.
    """
    out_mapper = mapper["outputs"]

    # If perturb_xb and samples, speed up the process
    if getattr(transform, "perturb_xb", False) and any(
            "__sample#" in key[1] for key in out_mapper):
        forward_perturb(
            transform, inout_datastore, controlvect, obsvect, mapper,
            di, df, mode, runsubdir, workdir, **kwargs
        )
        return

    ddi = min(di, df)
    xmod = inout_datastore["outputs"]

    for trid in out_mapper:
        comp = trid[0]
        trcr = trid[1]

        # Force reading inputs. This can be a dictionary based on dates,
        # originating from time_interpolation behaviour
        force_loadin = out_mapper[trid].get("force_loadin", False)
        if isinstance(force_loadin, dict):
            force_loadin = force_loadin.get(ddi, False)
        in_files = out_mapper[trid]["input_files"].get(ddi, [])
        in_dates = out_mapper[trid]["input_dates"].get(ddi, [])

        tracer = out_mapper[trid]["tracer"]

        # Parameters not in the control space: only read them if forced
        if not getattr(tracer, "iscontrol", False):
            if force_loadin:
                if not hasattr(tracer, "read"):
                    raise AttributeError(
                        f"Needing to read data for {trid}, "
                        f"but no read function is implemented for {tracer}"
                    )

                dates2read = _rows_to_pydatetime(in_dates)
                inputs = tracer.read(
                    trcr, tracer.varname, dates2read, in_files,
                    comp_type=comp,
                    tracer=tracer,
                    model=transform.model,
                    ddi=ddi,
                    **kwargs
                )

                if getattr(tracer, "sparse_data", False):
                    # Directly return what was read for sparse data
                    inputs[("maindata", "incr")] = 0.
                    xmod[trid][ddi] = inputs
                else:
                    xmod[trid][ddi] = xr.Dataset({"spec": inputs})
                    # Force increments to zero if tangent linear
                    if mode == "tl":
                        xmod[trid][ddi]["incr"] = \
                            0 * xmod[trid][ddi]["spec"]
            continue

        # Otherwise, reformat from x
        # Translates control vector, and increments if tangent-linear
        variables = {"scale": copy.deepcopy(controlvect.x)}
        if mode == "tl":
            variables["incr"] = copy.deepcopy(controlvect.dx)

        # Deal with control vector dates and cut if controlvect period spans
        # outside the sub-simulation period
        if len(in_dates) == 0:
            dslice = []
        else:
            dslice = dateslice(
                tracer,
                np.min(in_dates["start_date"]),
                np.max(in_dates["end_date"]),
            )
        cdates = tracer.dates[dslice]
        # Clip the first control date to ddi when it starts earlier
        if len(dslice) == 1 and cdates[0] < ddi:
            cdates[0] = ddi
        elif len(dslice) > 1 and cdates[0] < ddi < cdates[1]:
            cdates[0] = ddi

        # Translates only control variables corresponding to the
        # simulation period
        xmod_out = {}
        for key, val in variables.items():
            tmp = np.reshape(
                val[tracer.xpointer: tracer.xpointer + tracer.dim],
                (tracer.ndates, tracer.vresoldim, -1),
            )[dslice]
            # Deals with different resolutions
            xmod_out[key] = scale2map(tmp, tracer, cdates, tracer.domain)

        # NOTE(review): merged output dates are looked up with ``di`` while
        # the rest of the loop uses ``ddi``; both coincide when di <= df.
        tracer_type = getattr(tracer, "type", "scalar")
        if tracer_type == "scalar":
            dates2read = _rows_to_pydatetime(tracer.input_dates[ddi])

            # Read the tracer array and apply the present control vector
            # scaling factor
            try:
                inputs = tracer.read(
                    trcr, tracer.varname, dates2read,
                    tracer.input_files[ddi],
                    comp_type=comp,
                    tracer=tracer,
                    model=transform.model,
                    ddi=ddi,
                    **kwargs
                )

                # If input is a pd.DataFrame, not working yet
                if isinstance(inputs, pd.DataFrame):
                    raise Exception(
                        f"The inputs for {trcr} are DataFrames, hence sparse data. \n"
                        "This is not covered by CIF yet! \n"
                        "Input files: \n"
                        f"{tracer.input_files[ddi][0]}"
                    )

                # Reprojecting inputs and scaling factors to the same merged
                # time steps
                outdates = _merged_outdates(mapper, trid, di)
                if len(outdates) == 1:
                    outdates = outdates.append(outdates)

                # NaNs in scale come from regions not included in the
                # control vector. Set to one to multiply by native inputs
                # in scalar mode
                if tracer.hresol == "regions":
                    xmod_out["scale"].values[
                        np.isnan(xmod_out["scale"]).values] = 1.
                    if mode == "tl":
                        xmod_out["incr"].values[
                            np.isnan(xmod_out["incr"]).values] = 0

                # Reindexing errors propagate; a ValueError is reported below
                inputs = reindex(
                    inputs,
                    levels={"time": outdates[:-1].values},
                )

                # Without control dates in this period, the native inputs
                # are kept unscaled (mirrors ``incr = 0`` below); this also
                # avoids reusing the previous tracer's ``scale``
                scale = 1.
                if len(cdates) != 0:
                    scale = reindex(
                        xmod_out["scale"],
                        levels={"time": outdates[:-1], "lev": inputs.lev},
                    )

                # Check that dimensions are aligned
                if getattr(tracer, "is_lbc", False) and inputs.shape[2] != 1:
                    raise Exception(
                        f"WARNING! The tracer {trid} has option `is_lbc` "
                        f"whereas the corresponding data sets are full domain"
                        f" in {tracer.input_files[ddi]}. "
                        f"Either provide flat lateral boundary condition files "
                        f"or remove the option `is_lbc` in your yml"
                    )

                xmod_out["spec"] = inputs * scale

                if mode == "tl":
                    incr = 0
                    if len(cdates) != 0:
                        incr = reindex(
                            xmod_out["incr"],
                            levels={"time": outdates[:-1], "lev": inputs.lev},
                        )
                    xmod_out["incr"] = incr * inputs

            except ValueError:
                warning("I may have problem in alignment to LBC in the "
                        "target vector. If components of the target vector "
                        "correspond to LBC, please make sure that the "
                        "attribute 'is_lbc' is set as True in the "
                        "configuration yaml.")
                raise

        # Data already contains the correct info for physical control
        # variables
        # WARNING: so far, assumes that the vertical resolution is already
        # correct
        elif tracer_type == "physical":
            # Reprojecting inputs and scaling factors to the same merged
            # time steps
            outdates = _merged_outdates(mapper, trid, di)
            if len(outdates) == 1:
                outdates = np.append(outdates, outdates)

            # NaNs in scale come from regions not included in the control
            # vector. Set to zero in physical mode
            if tracer.hresol == "regions":
                xmod_out["scale"].values[
                    np.isnan(xmod_out["scale"]).values] = 0
                if mode == "tl":
                    xmod_out["incr"].values[
                        np.isnan(xmod_out["incr"]).values] = 0

            levels = {
                "time": outdates[:-1],
                "lev": range(out_mapper[trid]["domain"].nlev),
            }
            xmod_out["spec"] = reindex(xmod_out["scale"], levels=levels)
            if mode == "tl":
                xmod_out["incr"] = reindex(xmod_out["incr"], levels=levels)

        # Removing the scaling factor as all information is stored in
        # 'spec' and 'incr' now
        xmod_out.pop("scale")

        xmod[trid][ddi] = xr.Dataset(xmod_out)