import os
import numpy as np
import xarray as xr
import pandas as pd
from logging import warning, info
from .utils.dates import dateslice
from .....utils.dataarrays.reindex import reindex
from .utils.scalemaps import scale2map
def forward_perturb(
        transform,
        inout_datastore,
        controlvect,
        obsvect,
        mapper,
        di,
        df,
        mode,
        runsubdir,
        workdir,
        **kwargs
):
    """Build perturbed native inputs for every ensemble sample of a tracer.

    The ensemble of control vectors (all samples at once) is loaded from
    ``transform.dir_samples`` / ``transform.file_samples``, translated to
    the tracer's native resolution with ``scale2map``, reindexed onto the
    time steps of the inputs and multiplied with the native inputs
    returned by ``tracer.read``. One dataset (variable ``spec``) per
    sample is stored in ``inout_datastore["outputs"][trid][ddi]``.

    Args:
        transform: the transform object; ``dir_samples``, ``file_samples``,
            ``model`` and ``nthreads`` are read from it.
        inout_datastore: dict whose ``"outputs"`` entry is filled in place.
        controlvect: control vector; ``load(..., ensemble=True)`` is called
            on it, then ``x_ens`` (and ``dx_ens`` in "tl" mode) is read.
        obsvect: not used by this function.
        mapper: transform mapper; only ``mapper["outputs"]`` is used. It has
            one key per sample (``nsamples = len(keys)``); each key is
            indexed as ``(component, tracer_name)`` and the first key is
            used as the reference for all samples.
        di, df: presumably the bounds of the sub-simulation period;
            ``ddi = min(di, df)`` is the key into per-period entries.
        mode: run mode; "tl" additionally translates ``controlvect.dx_ens``.
        runsubdir, workdir: not used by this function.
        **kwargs: forwarded to ``tracer.read``.

    Raises:
        NotImplementedError: if the tracer type is "physical".
    """
    ddi = min(di, df)
    xmod = inout_datastore["outputs"]

    # Basic objects: the first mapper key serves as reference for all samples
    out_mapper = mapper["outputs"]
    trids_all = list(out_mapper)
    trid_first = trids_all[0]
    comp, trcr = trid_first[0], trid_first[1]
    trcr_ref = trcr.split("__sample#")[0]
    nsamples = len(trids_all)
    first_mapper = out_mapper[trid_first]

    in_dates = first_mapper["input_dates"].get(ddi, [])
    tracer = first_mapper["tracer"]

    # Load once, for all samples, the ensemble of control vectors x_ens
    controlvect_file = os.path.join(
        transform.dir_samples, transform.file_samples)
    controlvect.load(
        controlvect_file, component2load=comp, tracer2load=trcr_ref,
        target_tracer=trcr, ensemble=True
    )

    # Translate the control vector, and the increments in tangent-linear mode
    variables = {"scale": controlvect.x_ens}
    if mode == "tl":
        variables["incr"] = controlvect.dx_ens

    # Deal with control vector dates and cut if the control vector period
    # spans outside the sub-simulation period: the first kept date is
    # clipped to ddi when it starts before it
    if len(in_dates) == 0:
        dslice = []
    else:
        dslice = dateslice(tracer, np.min(in_dates["start_date"]),
                           np.max(in_dates["end_date"]))
    cdates = tracer.dates[dslice]
    if len(dslice) == 1 and cdates[0] < ddi:
        cdates[0] = ddi
    elif len(dslice) > 1 and cdates[0] < ddi < cdates[1]:
        cdates[0] = ddi

    # Translate only the control variables of the simulation period
    xmod_out = {}
    for key, val in variables.items():
        tmp = np.reshape(
            val[:, tracer.xpointer: tracer.xpointer + tracer.dim],
            (nsamples, tracer.ndates, tracer.vresoldim, -1),
        )[:, dslice]
        # Deals with different resolutions
        xmod_out[key] = scale2map(tmp, tracer, cdates,
                                  tracer.domain, ensemble=True)

    # TODO: deal also with physical variables !
    if getattr(tracer, "type", "scalar") == "physical":
        raise NotImplementedError(
            "physical variables not implemented with perturbations")

    # Turn dates to datetime for consistency in read
    # TODO: This should be standardized in the future
    dates2read = [
        [x.to_pydatetime() for x in row]
        for row in tracer.input_dates[ddi].itertuples(index=False, name=None)
    ]

    # Read the native inputs (always called here)
    inputs = tracer.read(
        trcr,
        tracer.varname,
        dates2read,
        tracer.input_files[ddi],
        comp_type=comp,
        tracer=tracer,
        model=transform.model,
        ddi=ddi,
        **kwargs
    )

    # Reproject inputs and scaling factors onto the same merged time steps.
    # A single date is duplicated so that outdates[:-1] keeps one step.
    outdates = pd.to_datetime(np.unique(
        np.array(first_mapper["input_dates"][di]).flatten()))
    if len(outdates) == 1:
        outdates = np.append(outdates, outdates)
    time_levels = outdates[:-1]

    # NaNs in scale come from regions not included in the control vector.
    # Set them to one so that native inputs are kept as is in scalar mode
    if tracer.hresol == "regions":
        xmod_out["scale"].values[np.isnan(xmod_out["scale"]).values] = 1.
        if mode == "tl":
            xmod_out["incr"].values[np.isnan(xmod_out["incr"]).values] = 0

    inputs = reindex(inputs, levels={"time": time_levels})

    # NOTE(review): in "tl" mode xmod_out["incr"] is computed above but is not
    # used in the outputs below -- confirm this is intended.
    scale_map = xmod_out.pop("scale")

    if first_mapper.get("loadin_perturb_full_vertical", True):
        # Load all the samples over the full vertical domain
        scale = 0
        if len(cdates) != 0:
            scale = reindex(
                scale_map,
                levels={"time": time_levels, "lev": inputs.lev},
                nthreads=transform.nthreads
            )
        # Release the non-reindexed map before allocating the product
        del scale_map

        perturbed = xr.Dataset(
            {'spec': np.multiply(inputs.values[np.newaxis], scale)}
        )
        for iens, trid in enumerate(trids_all):
            xmod[trid][ddi] = perturbed.isel(ens=iens)

    else:
        # Load only the first sample over the full vertical domain and
        # keep only the surface level for the others.
        # The surface level is needed below regardless of cdates, so it is
        # resolved before the conditional reindexing.
        ilev_surf = first_mapper.get("surface_level", 0)
        lev_surf = inputs.lev.values[ilev_surf]

        scale_full = 0
        scale_reduced = 0
        if len(cdates) != 0:
            # Full re-indexing for the first sample
            scale_full = reindex(
                scale_map[0:1],
                levels={"time": time_levels, "lev": inputs.lev},
            )
            # Surface-only re-indexing for the other samples
            scale_reduced = reindex(
                scale_map[1:],
                levels={"time": time_levels, "lev": [lev_surf]},
            )
        del scale_map

        spec_full = inputs.values[np.newaxis] * scale_full
        # Positional selection keeps a length-1 "lev" axis holding exactly
        # the surface level, matching scale_reduced. A label slice
        # (lev_surf, lev_surf + 1) is inclusive at both ends and can select
        # a second level (or none on a decreasing index).
        spec_reduced = (
            inputs.isel(lev=[ilev_surf]).values[np.newaxis] * scale_reduced
        )

        perturbed_full = xr.Dataset({'spec': spec_full})
        perturbed_reduced = xr.Dataset({'spec': spec_reduced})
        for iens, trid in enumerate(trids_all):
            if iens == 0:
                xmod[trid][ddi] = perturbed_full.isel(ens=0)
            else:
                xmod[trid][ddi] = perturbed_reduced.isel(ens=iens - 1)