Source code for pycif.plugins.transforms.system.fromcontrol.forward_perturb

import os
import numpy as np
import xarray as xr
import pandas as pd

from logging import warning, info
from .utils.dates import dateslice
from .....utils.dataarrays.reindex import reindex
from .utils.scalemaps import scale2map


def forward_perturb(
    transform,
    inout_datastore,
    controlvect,
    obsvect,
    mapper,
    di,
    df,
    mode,
    runsubdir,
    workdir,
    **kwargs
):
    """Scale the native inputs by every ensemble sample of the control vector.

    The ensemble of control vectors is loaded once from
    ``transform.dir_samples/transform.file_samples``. The part belonging to
    the current tracer is mapped with ``scale2map``, the native inputs are
    read through ``tracer.read``, and their product is stored in
    ``inout_datastore["outputs"][trid][ddi]`` for every id ``trid`` found in
    ``mapper["outputs"]``.

    The number of ids in ``mapper["outputs"]`` is taken as the number of
    samples, and the first id provides the component, the tracer and the
    mapper options shared by all of them.

    Args:
        transform: object providing ``dir_samples``, ``file_samples``,
            ``model`` and ``nthreads``.
        inout_datastore: mapping whose ``"outputs"`` entry is filled in
            place.
        controlvect: control vector; its ``load`` method is called with
            ``ensemble=True``, then ``x_ens`` (and ``dx_ens`` when ``mode``
            is ``"tl"``) is read.
        obsvect: not used by this function.
        mapper: mapping whose ``"outputs"`` entry associates each
            ``(component, tracer)`` id with a dictionary holding at least
            ``"input_dates"`` and ``"tracer"``.
        di: first bound of the sub-simulation period.
        df: second bound of the sub-simulation period.
        mode: running mode; ``"tl"`` also maps the increments.
        runsubdir: not used by this function.
        workdir: not used by this function.
        **kwargs: forwarded to ``tracer.read``.

    Raises:
        NotImplementedError: if ``tracer.type`` is ``"physical"``.
    """
    # The earliest of the two period bounds is used as key everywhere below
    ddi = min(di, df)

    xmod = inout_datastore["outputs"]

    # Basic objects: the first output id carries the options shared by
    # all the samples
    out_mapper = mapper["outputs"]
    trids_all = list(out_mapper.keys())
    trid_first = trids_all[0]
    first_map = out_mapper[trid_first]
    nsamples = len(trids_all)

    comp, trcr = trid_first[0], trid_first[1]
    # Tracer name stripped from its "__sample#" suffix
    trcr_ref = trcr.split("__sample#")[0]

    in_dates = first_map["input_dates"].get(ddi, [])
    tracer = first_map["tracer"]

    # Load once for all the variable x_ens with all samples
    controlvect_file = os.path.join(
        transform.dir_samples, transform.file_samples
    )
    controlvect.load(
        controlvect_file,
        component2load=comp,
        tracer2load=trcr_ref,
        target_tracer=trcr,
        ensemble=True,
    )

    # Translates control vector, and increments if tangent-linear
    variables = {"scale": controlvect.x_ens}
    if mode == "tl":
        variables["incr"] = controlvect.dx_ens

    # Deal with control vector dates and cut if controlvect period spans
    # outside the sub-simulation period.
    # len() is used on purpose: in_dates is either the empty-list default
    # or an object indexed by column name, whose truth value may be
    # ambiguous.
    if len(in_dates) == 0:
        dslice = []
    else:
        dslice = dateslice(
            tracer,
            np.min(in_dates["start_date"]),
            np.max(in_dates["end_date"]),
        )

    cdates = tracer.dates[dslice]
    nslice = len(dslice)
    # Clip the first control date to the start of the sub-simulation when
    # the first control period begins before it
    if (nslice == 1 and cdates[0] < ddi) or (
        nslice > 1 and cdates[0] < ddi < cdates[1]
    ):
        cdates[0] = ddi

    # Translates only control variables corresponding to the
    # simulation period
    xmod_out = {}
    for key, val in variables.items():
        tmp = np.reshape(
            val[:, tracer.xpointer: tracer.xpointer + tracer.dim],
            (nsamples, tracer.ndates, tracer.vresoldim, -1),
        )[:, dslice]

        # Deals with different resolutions
        xmod_out[key] = scale2map(
            tmp, tracer, cdates, tracer.domain, ensemble=True
        )

    # TODO: deal also with physical variables !
    if getattr(tracer, "type", "scalar") == "physical":
        raise NotImplementedError(
            "physical variables not implemented with perturbations")

    # Turn dates to datetime for consistency in read
    # TODO: This should be standardized in the future
    dates2read = [
        [x.to_pydatetime() for x in row]
        for row in tracer.input_dates[ddi].itertuples(index=False, name=None)
    ]

    # If inputs do not already exist, read is called
    inputs = tracer.read(
        trcr,
        tracer.varname,
        dates2read,
        tracer.input_files[ddi],
        comp_type=comp,
        tracer=tracer,
        model=transform.model,
        ddi=ddi,
        **kwargs
    )

    # Reprojecting inputs and scaling factors to the same merged
    # time steps
    # NOTE(review): this lookup is keyed by ``di`` whereas every other one
    # uses ``ddi``; both are equal when di <= df -- confirm it is intended.
    outdates = pd.to_datetime(
        np.unique(np.array(first_map["input_dates"][di]).flatten())
    )
    if len(outdates) == 1:
        # Duplicate the single date so that outdates[:-1] is not empty
        outdates = np.append(outdates, outdates)

    # NaNs in scale come from regions not included
    # in the control vector
    # Set to one to multiply by native inputs in scalar mode
    if tracer.hresol == "regions":
        xmod_out["scale"].values[np.isnan(xmod_out["scale"]).values] = 1.
        if mode == "tl":
            xmod_out["incr"].values[np.isnan(xmod_out["incr"]).values] = 0

    inputs = reindex(
        inputs,
        levels={"time": outdates[:-1]},
    )

    # NOTE(review): in both branches below, when no control date matches,
    # the scaling factor stays the scalar 0; the product is then presumably
    # a plain array without an ``ens`` dimension name -- confirm this case
    # cannot be reached or is handled upstream.

    # Load all the samples over the full vertical domain
    if first_map.get("loadin_perturb_full_vertical", True):
        scale = 0
        if len(cdates) != 0:
            scale = reindex(
                xmod_out["scale"],
                levels={"time": outdates[:-1], "lev": inputs.lev},
                nthreads=transform.nthreads
            )

        ens_data = xr.Dataset(
            {'spec': np.multiply(inputs.values[np.newaxis], scale)}
        )
        for iens, trid in enumerate(trids_all):
            xmod[trid][ddi] = ens_data.isel(ens=iens)

    # Load only the first sample over the full vertical domain
    # and takes the surface for the others
    else:
        # Bound before the conditional: lev_surf is needed below even when
        # there is no control date to re-index
        ilev_surf = first_map.get("surface_level", 0)
        lev_surf = inputs.lev.values[ilev_surf]

        scale_full = 0
        scale_reduced = 0
        if len(cdates) != 0:
            # Full re-indexing for the first sample
            scale_full = reindex(
                xmod_out["scale"][0:1],
                levels={"time": outdates[:-1], "lev": inputs.lev},
            )
            # Surface-only re-indexing for the remaining samples
            scale_reduced = reindex(
                xmod_out["scale"][1:],
                levels={"time": outdates[:-1], "lev": [lev_surf]},
            )

        spec_full = inputs.values[np.newaxis] * scale_full
        spec_reduced = (
            inputs.sel(lev=slice(lev_surf, lev_surf + 1)).values[np.newaxis]
            * scale_reduced
        )

        full_data = xr.Dataset({'spec': spec_full})
        reduced_data = xr.Dataset({'spec': spec_reduced})

        for iens, trid in enumerate(trids_all):
            if iens == 0:
                xmod[trid][ddi] = full_data.isel(ens=iens)
            else:
                # The reduced dataset starts at the second sample
                xmod[trid][ddi] = reduced_data.isel(ens=iens - 1)