import os
import numpy as np
import xarray as xr
import pandas as pd
from logging import warning, info
from .utils.dates import dateslice
from .....utils.dataarrays.reindex import reindex
from .utils.scalemaps import scale2map
[docs]
def forward_perturb(
    transform,
    inout_datastore,
    controlvect,
    obsvect,
    mapper,
    di,
    df,
    mode,
    runsubdir,
    workdir,
    **kwargs
):
    """Ensemble-optimised forward projection for perturbation runs.

    The whole ensemble is loaded from a pre-saved control-vector file with
    one ``controlvect.load(..., ensemble=True)`` call, and
    :func:`scale2map` is then applied to all members at once instead of
    running the projection once per member.

    Two memory strategies are supported, selected via the mapper flag
    ``loadin_perturb_full_vertical`` of the first sample tracer:

    * **Full vertical** (default) -- the scaling factors of every member
      are re-indexed on all input levels. Each sample tracer ID receives
      ``inout_datastore["outputs"][trid][ddi]`` holding a ``spec``
      variable.
    * **Reduced vertical** -- only the first member is re-indexed on all
      input levels; the other members are re-indexed on, and multiplied
      with, the single input level at index ``surface_level`` (mapper
      entry, default 0).

    .. note::
        Physical control variables (``type = 'physical'``) are not yet
        supported and raise ``NotImplementedError``.

    .. note::
        In ``'tl'`` mode the increments are projected (and NaN-filled for
        ``hresol == 'regions'``) but the visible code never writes them to
        the datastore; only ``spec`` is stored.

    Args:
        transform (Plugin): fromcontrol transform instance; its
            ``dir_samples``, ``file_samples``, ``model`` and ``nthreads``
            attributes are read here.
        inout_datastore (dict): mutable datastore;
            ``inout_datastore["outputs"][trid][ddi]`` is overwritten for
            every sample tracer ID.
        controlvect: control vector plugin; ``x_ens`` (and ``dx_ens`` in
            ``'tl'`` mode) are read after ``load``.
        obsvect: observation vector plugin (unused).
        mapper (dict): transform mapper; ``'outputs'`` contains one entry
            per sample tracer ID (``__sample#N``). Settings are taken from
            the first entry only.
        di (datetime): sub-simulation start date.
        df (datetime): sub-simulation end date.
        mode (str): ``'fwd'`` or ``'tl'``.
        runsubdir (str): sub-simulation run directory (unused).
        workdir (str): root working directory (unused).
        **kwargs: forwarded to ``tracer.read``.

    Returns:
        None. Results are written into ``inout_datastore`` in place.

    Raises:
        NotImplementedError: if the tracer's ``type`` is ``'physical'``.
    """
    # Every per-period lookup below is keyed on the earlier of the two dates
    ddi = min(di, df)
    xmod = inout_datastore["outputs"]

    # Basic objects: all settings come from the first sample tracer
    out_mapper = mapper["outputs"]
    trids_all = list(out_mapper.keys())
    trid_first = trids_all[0]
    first_map = out_mapper[trid_first]
    comp, trcr = trid_first[0], trid_first[1]
    trcr_ref = trcr.split("__sample#")[0]
    nsamples = len(trids_all)

    in_dates = first_map["input_dates"].get(ddi, [])
    tracer = first_map["tracer"]

    # Load once the variable x_ens with all samples
    controlvect_file = os.path.join(
        transform.dir_samples, transform.file_samples
    )
    controlvect.load(
        controlvect_file,
        component2load=comp,
        tracer2load=trcr_ref,
        target_tracer=trcr,
        ensemble=True,
    )

    # TODO: deal also with physical variables !
    # Checked before any projection so unsupported tracers fail fast.
    if getattr(tracer, "type", "scalar") == "physical":
        raise NotImplementedError(
            "physical variables not implemented with perturbations")

    # Control vector, plus increments if tangent-linear
    variables = {"scale": controlvect.x_ens}
    if mode == "tl":
        variables["incr"] = controlvect.dx_ens

    # Deal with control vector dates and cut if the controlvect period
    # spans outside the sub-simulation period
    if len(in_dates) == 0:
        dslice = []
    else:
        dslice = dateslice(
            tracer,
            np.min(in_dates["start_date"]),
            np.max(in_dates["end_date"]),
        )
    cdates = tracer.dates[dslice]

    # Clip the first control date to the period start when the period
    # begins inside the first control window
    if (
        len(dslice) >= 1
        and cdates[0] < ddi
        and (len(dslice) == 1 or ddi < cdates[1])
    ):
        cdates[0] = ddi

    # Translate only the control variables corresponding to the
    # simulation period
    xmod_out = {}
    for key, val in variables.items():
        tmp = np.reshape(
            val[:, tracer.xpointer: tracer.xpointer + tracer.dim],
            (nsamples, tracer.ndates, tracer.vresoldim, -1),
        )[:, dslice]

        # Deals with different resolutions
        xmod_out[key] = scale2map(
            tmp, tracer, cdates, tracer.domain, ensemble=True
        )

    # Turn dates to datetime for consistency in read
    # TODO: This should be standardized in the future
    dates2read = [
        [x.to_pydatetime() for x in row]
        for row in tracer.input_dates[ddi].itertuples(index=False, name=None)
    ]

    # Read the native inputs for this period
    inputs = tracer.read(
        trcr,
        tracer.varname,
        dates2read,
        tracer.input_files[ddi],
        comp_type=comp,
        tracer=tracer,
        model=transform.model,
        ddi=ddi,
        **kwargs
    )

    # Reprojecting inputs and scaling factors to the same merged time
    # steps. Keyed on ``ddi`` like every other per-period lookup here
    # (identical to ``di`` whenever di <= df).
    outdates = pd.to_datetime(
        np.unique(np.array(first_map["input_dates"][ddi]).flatten())
    )
    if len(outdates) == 1:
        # Duplicate the date so that ``outdates[:-1]`` is not empty
        outdates = np.append(outdates, outdates)

    # NaNs in scale come from regions not included in the control vector.
    # Set to one to multiply by native inputs in scalar mode
    if tracer.hresol == "regions":
        xmod_out["scale"].values[
            np.isnan(xmod_out["scale"]).values] = 1.
        if mode == "tl":
            xmod_out["incr"].values[np.isnan(xmod_out["incr"]).values] = 0

    inputs = reindex(
        inputs,
        levels={"time": outdates[:-1]},
    )

    # NOTE(review): in both branches below, when ``cdates`` is empty the
    # scale stays the scalar 0, so the product is presumably a plain
    # ndarray without an ``ens`` dimension -- confirm that this case is
    # either unreachable or handled upstream.
    if first_map.get("loadin_perturb_full_vertical", True):
        # Load all the samples over the full vertical domain
        scale = 0
        if len(cdates) != 0:
            scale = reindex(
                xmod_out["scale"],
                levels={"time": outdates[:-1], "lev": inputs.lev},
                nthreads=transform.nthreads
            )

        spec = xr.Dataset(
            {'spec': np.multiply(inputs.values[np.newaxis], scale)}
        )
        for iens, trid in enumerate(trids_all):
            xmod[trid][ddi] = spec.isel(ens=iens)

    else:
        # Load only the first sample over the full vertical domain and
        # take the surface level for the others.
        # The surface level is resolved unconditionally so that it is
        # defined even when there is no control date in the period.
        ilev_surf = first_map.get("surface_level", 0)
        lev_surf = inputs.lev.values[ilev_surf]

        # Positional selection keeps exactly one level; a label slice
        # ``slice(lev_surf, lev_surf + 1)`` is inclusive on both ends and
        # could return several levels (or none) depending on the values
        # of the ``lev`` coordinate.
        inputs_surf = inputs.isel(lev=[ilev_surf])

        scale_full = 0
        scale_reduced = 0
        if len(cdates) != 0:
            # Full re-indexing for the first sample
            scale_full = reindex(
                xmod_out["scale"][0:1],
                levels={"time": outdates[:-1], "lev": inputs.lev},
            )
            # Surface-only re-indexing for all the other samples
            scale_reduced = reindex(
                xmod_out["scale"][1:],
                levels={"time": outdates[:-1], "lev": [lev_surf]},
            )

        spec_full = xr.Dataset(
            {'spec': inputs.values[np.newaxis] * scale_full}
        )
        spec_reduced = xr.Dataset(
            {'spec': inputs_surf.values[np.newaxis] * scale_reduced}
        )

        for iens, trid in enumerate(trids_all):
            if iens == 0:
                xmod[trid][ddi] = spec_full.isel(ens=0)
            else:
                # ``spec_reduced`` starts at the second ensemble member
                xmod[trid][ddi] = spec_reduced.isel(ens=iens - 1)