# Source code for pycif.plugins.transforms.system.fromcontrol.forward_perturb

import os
import numpy as np
import xarray as xr
import pandas as pd

from logging import warning, info
from .utils.dates import dateslice
from .....utils.dataarrays.reindex import reindex
from .utils.scalemaps import scale2map


def forward_perturb(
    transform,
    inout_datastore,
    controlvect,
    obsvect,
    mapper,
    di,
    df,
    mode,
    runsubdir,
    workdir,
    **kwargs
):
    """Ensemble-optimised forward projection for perturbation runs.

    Loads all ensemble members from a pre-saved control-vector file in a
    single call (``controlvect.load(..., ensemble=True)``), then applies
    :func:`scale2map` to the whole ensemble at once and multiplies the
    resulting scaling factors by the native inputs returned by
    ``tracer.read``.

    Two memory strategies are supported, selected with the mapper flag
    ``loadin_perturb_full_vertical`` of the first output tracer:

    * **Full vertical** (default) -- the scaling factors of every member
      are re-indexed on all levels of the inputs.
    * **Reduced vertical** -- only the first member is re-indexed on all
      levels; the other members are re-indexed on the single level at
      index ``surface_level`` (default 0) and multiplied by the inputs at
      that level only.

    In both cases the result for each sample tracer ID is stored as a
    dataset with a ``spec`` variable in
    ``inout_datastore["outputs"][trid][ddi]`` with ``ddi = min(di, df)``.

    .. note::

        Tangent-linear increments (``mode == "tl"``) are mapped but are
        not written to the outputs by this function.

    Args:
        transform (Plugin): transform instance; ``dir_samples``,
            ``file_samples``, ``model`` and (full-vertical branch only)
            ``nthreads`` are read from it.
        inout_datastore (dict): mutable datastore; its ``"outputs"`` entry
            is updated in place.
        controlvect: control vector plugin; ``load`` is called on it and
            ``x_ens`` (plus ``dx_ens`` in tangent-linear mode) is read.
        obsvect: observation vector plugin (unused).
        mapper (dict): transform mapper; ``mapper["outputs"]`` holds one
            entry per sample tracer ID (``(component, "<name>__sample#N")``).
            All settings are taken from the first entry.
        di (datetime): sub-simulation start date.
        df (datetime): sub-simulation end date.
        mode (str): ``'fwd'`` or ``'tl'``.
        runsubdir (str): sub-simulation run directory (unused).
        workdir (str): root working directory (unused).
        **kwargs: forwarded to ``tracer.read``.

    Raises:
        NotImplementedError: if the tracer ``type`` is ``'physical'``.
    """
    ddi = min(di, df)
    xmod = inout_datastore["outputs"]

    # Basic objects: every sample shares the settings of the first entry
    out_mapper = mapper["outputs"]
    trids_all = list(out_mapper.keys())
    nsamples = len(trids_all)
    trid_first = trids_all[0]
    first_entry = out_mapper[trid_first]
    comp = trid_first[0]
    trcr = trid_first[1]
    trcr_ref = trcr.split("__sample#")[0]

    in_dates = first_entry["input_dates"].get(ddi, [])
    tracer = first_entry["tracer"]

    # Load once the ensemble control vector with all samples
    controlvect_file = os.path.join(
        transform.dir_samples, transform.file_samples
    )
    controlvect.load(
        controlvect_file,
        component2load=comp,
        tracer2load=trcr_ref,
        target_tracer=trcr,
        ensemble=True,
    )

    # Translates control vector, and increments if tangent-linear
    variables = {"scale": controlvect.x_ens}
    if mode == "tl":
        variables["incr"] = controlvect.dx_ens

    # Deal with control vector dates and cut if controlvect period spans
    # outside the sub-simulation period.
    # ``len`` is used on purpose: ``in_dates`` is indexed by column name
    # below, so its truth value may be ambiguous.
    if len(in_dates) == 0:
        dslice = []
    else:
        dslice = dateslice(
            tracer,
            np.min(in_dates["start_date"]),
            np.max(in_dates["end_date"]),
        )

    # Clip the first control date to the start of the sub-simulation
    cdates = tracer.dates[dslice]
    if len(dslice) == 1 and cdates[0] < ddi:
        cdates[0] = ddi
    elif len(dslice) > 1 and cdates[0] < ddi < cdates[1]:
        cdates[0] = ddi

    # Translates only control variables corresponding to the
    # simulation period
    xmod_out = {}
    for key, val in variables.items():
        tmp = np.reshape(
            val[:, tracer.xpointer: tracer.xpointer + tracer.dim],
            (nsamples, tracer.ndates, tracer.vresoldim, -1),
        )[:, dslice]

        # Deals with different resolutions
        xmod_out[key] = scale2map(
            tmp, tracer, cdates, tracer.domain, ensemble=True
        )

    # TODO: deal also with physical variables !
    if getattr(tracer, "type", "scalar") == "physical":
        raise NotImplementedError(
            "physical variables not implemented with perturbations")

    # Turn dates to datetime for consistency in read
    # TODO: This should be standardized in the future
    dates2read = [
        [x.to_pydatetime() for x in row]
        for row in tracer.input_dates[ddi].itertuples(index=False, name=None)
    ]

    # Native inputs to be scaled
    inputs = tracer.read(
        trcr,
        tracer.varname,
        dates2read,
        tracer.input_files[ddi],
        comp_type=comp,
        tracer=tracer,
        model=transform.model,
        ddi=ddi,
        **kwargs
    )

    # Reprojecting inputs and scaling factors to the same merged
    # time steps (same period key as every other lookup above)
    outdates = pd.to_datetime(
        np.unique(np.array(first_entry["input_dates"][ddi]).flatten())
    )
    if len(outdates) == 1:
        # Duplicate the single date so that ``outdates[:-1]`` is not empty
        outdates = np.append(outdates, outdates)

    # NaNs in scale come from regions not included
    # in the control vector
    # Set to one to multiply by native inputs in scalar mode
    if tracer.hresol == "regions":
        xmod_out["scale"].values[
            np.isnan(xmod_out["scale"]).values] = 1.
        if mode == "tl":
            xmod_out["incr"].values[np.isnan(xmod_out["incr"]).values] = 0

    # NOTE(review): ``xmod_out["incr"]`` is prepared in tangent-linear mode
    # but never propagated to the outputs below -- confirm this is intended.

    inputs = reindex(
        inputs,
        levels={"time": outdates[:-1]},
    )

    # NOTE(review): when no control date overlaps the period, the scaling
    # factors below stay the scalar 0 and the product presumably has no
    # ``ens`` dimension to select on -- confirm whether this path is reachable.

    if first_entry.get("loadin_perturb_full_vertical", True):
        # Load all the samples over the full vertical domain
        scale = 0
        if len(cdates) != 0:
            scale = reindex(
                xmod_out["scale"],
                levels={"time": outdates[:-1], "lev": inputs.lev},
                nthreads=transform.nthreads,
            )

        spec_all = xr.Dataset(
            {"spec": np.multiply(inputs.values[np.newaxis], scale)}
        )
        for iens, trid in enumerate(trids_all):
            xmod[trid][ddi] = spec_all.isel(ens=iens)

    else:
        # Load only the first sample over the full vertical domain
        # and take the surface level for the others.
        # The surface level is resolved before the date test because it is
        # needed below whatever the number of control dates.
        ilev_surf = first_entry.get("surface_level", 0)
        lev_surf = inputs.lev.values[ilev_surf]

        scale_full = 0
        scale_reduced = 0
        if len(cdates) != 0:
            # Full re-indexing for the first sample
            scale_full = reindex(
                xmod_out["scale"][0:1],
                levels={"time": outdates[:-1], "lev": inputs.lev},
            )
            # Surface-only re-indexing for the remaining samples
            scale_reduced = reindex(
                xmod_out["scale"][1:],
                levels={"time": outdates[:-1], "lev": [lev_surf]},
            )

        # Positional selection keeps exactly one level; a label-based
        # ``slice(lev_surf, lev_surf + 1)`` is inclusive on both ends and
        # may return more levels than ``scale_reduced`` holds.
        inputs_surf = inputs.isel(lev=[ilev_surf])

        spec_full = xr.Dataset(
            {"spec": inputs.values[np.newaxis] * scale_full}
        )
        spec_reduced = xr.Dataset(
            {"spec": inputs_surf.values[np.newaxis] * scale_reduced}
        )
        for iens, trid in enumerate(trids_all):
            if iens == 0:
                xmod[trid][ddi] = spec_full.isel(ens=0)
            else:
                # ``spec_reduced`` starts at the second ensemble member
                xmod[trid][ddi] = spec_reduced.isel(ens=iens - 1)