import numpy as np
import xarray as xr
import pandas as pd
import copy
from logging import warning, info
from .utils.dates import dateslice
from .....utils.dataarrays.reindex import reindex
from .utils.scalemaps import scale2map
from .forward_perturb import forward_perturb
# [docs]  (Sphinx "view source" link residue from the HTML export; not code)
def forward(
    transform,
    inout_datastore,
    controlvect,
    obsvect,
    mapper,
    di,
    df,
    mode,
    runsubdir,
    workdir,
    **kwargs
):
    """Project the control vector onto physical model-input fields.

    For each output tracer declared in ``mapper["outputs"]``:

    * **Non-control tracers** (``iscontrol`` false or missing) -- nothing is
      done unless ``force_loadin`` is set for the sub-simulation start date;
      in that case the data are read via ``tracer.read`` and stored as the
      ``'spec'`` variable of a dataset (or stored as returned when the tracer
      has ``sparse_data`` set).
    * **Scalar control tracers** (``type = 'scalar'``, default) -- the
      control-vector slice is unpacked to a map (:func:`scale2map`),
      re-indexed to the merged output dates, and multiplied element-wise by
      the field read with ``tracer.read``. When no control date falls in the
      sub-simulation period the scaling factor defaults to ``1`` (and the
      tangent-linear increment to ``0``).
    * **Physical control tracers** (``type = 'physical'``) -- the
      control-vector slice is re-indexed directly to the output dates and to
      the levels of the mapper domain, without calling ``tracer.read``.

    In tangent-linear mode (``mode = 'tl'``), the same operations apply
    to ``controlvect.dx`` in parallel, producing an ``'incr'`` field.

    When ``transform.perturb_xb`` is true and at least one output tracer name
    contains ``'__sample#'``, the whole call is delegated to
    :func:`forward_perturb`.

    Args:
        transform (Plugin): fromcontrol transform instance (``model`` and,
            optionally, ``perturb_xb`` are read from it).
        inout_datastore (dict): mutable datastore; results are written to
            ``inout_datastore["outputs"][trid][ddi]``.
        controlvect: control vector plugin (``x``, and ``dx`` in ``'tl'``).
        obsvect: observation vector plugin (only forwarded to
            :func:`forward_perturb`).
        mapper (dict): transform mapper with output tracer metadata.
        di (datetime): sub-simulation start date.
        df (datetime): sub-simulation end date.
        mode (str): ``'fwd'``, ``'tl'``, or ``'adj'``.
        runsubdir (str): sub-simulation run directory.
        workdir (str): root working directory.
        **kwargs: forwarded to ``tracer.read``.

    Returns:
        None. The datastore is updated in place.

    Raises:
        AttributeError: a non-control tracer must be loaded but has no
            ``read`` method.
        Exception: the inputs of a scalar tracer are a ``pandas.DataFrame``,
            or an ``is_lbc`` tracer is given full-domain data.
        ValueError: re-raised after logging a hint about ``is_lbc`` when the
            scalar branch fails with a ``ValueError``.
    """
    out_mapper = mapper["outputs"]

    # Ensemble shortcut: let forward_perturb handle every sample at once
    if getattr(transform, "perturb_xb", False) and any(
        "__sample#" in tr[1] for tr in out_mapper
    ):
        forward_perturb(
            transform,
            inout_datastore,
            controlvect,
            obsvect,
            mapper,
            di,
            df,
            mode,
            runsubdir,
            workdir,
            **kwargs
        )
        return

    # All per-period dictionaries are keyed by the earliest of the two dates
    ddi = min(di, df)
    xmod = inout_datastore["outputs"]

    for trid in out_mapper:
        comp = trid[0]
        trcr = trid[1]
        tracer_map = out_mapper[trid]

        # force_loadin can be a dictionary based on dates,
        # originating from time_interpolation behaviour
        force_loadin = tracer_map.get("force_loadin", False)
        if isinstance(force_loadin, dict):
            force_loadin = force_loadin.get(ddi, False)

        in_files = tracer_map["input_files"].get(ddi, [])
        in_dates = tracer_map["input_dates"].get(ddi, [])
        tracer = tracer_map["tracer"]

        # ------------------------------------------------------------------
        # Parameters not in the control space: read them only if requested
        # ------------------------------------------------------------------
        if not getattr(tracer, "iscontrol", False):
            if force_loadin:
                if not hasattr(tracer, "read"):
                    raise AttributeError(
                        f"Needing to read data for {trid}, "
                        f"but no read function is implemented for {tracer}"
                    )

                # TODO: date types should be standardized in the future
                dates2read = _dates_to_pydatetime(in_dates)

                inputs = tracer.read(
                    trcr,
                    tracer.varname,
                    dates2read,
                    in_files,
                    comp_type=comp,
                    tracer=tracer,
                    model=transform.model,
                    ddi=ddi,
                    **kwargs
                )

                # Sparse data are stored as returned, with a zero increment
                if getattr(tracer, "sparse_data", False):
                    inputs[("maindata", "incr")] = 0.
                    xmod[trid][ddi] = inputs
                    continue

                xmod[trid][ddi] = xr.Dataset({"spec": inputs})

                # Force increments to zero if tangent linear
                if mode == "tl":
                    xmod[trid][ddi]["incr"] = 0 * xmod[trid][ddi]["spec"]

            continue

        # ------------------------------------------------------------------
        # Control tracers: rebuild the fields from the control vector
        # ------------------------------------------------------------------
        # Copy only this tracer's slice of the control vector (and of the
        # increments in tangent-linear mode) instead of the whole vector.
        ctl_slice = slice(tracer.xpointer, tracer.xpointer + tracer.dim)
        variables = {"scale": copy.deepcopy(controlvect.x[ctl_slice])}
        if mode == "tl":
            variables["incr"] = copy.deepcopy(controlvect.dx[ctl_slice])

        # Deal with control vector dates and cut if controlvect period spans
        # outside the sub-simulation period
        if len(in_dates) == 0:
            dslice = []
        else:
            dslice = dateslice(
                tracer,
                np.min(in_dates["start_date"]),
                np.max(in_dates["end_date"]),
            )

        cdates = tracer.dates[dslice]
        if len(dslice) == 0:
            pass
        elif len(dslice) == 1 and cdates[0] < ddi:
            cdates[0] = ddi
        elif len(dslice) > 1 and cdates[0] < ddi < cdates[1]:
            cdates[0] = ddi

        # Translates only control variables corresponding to the
        # simulation period
        xmod_out = {}
        for key, val in variables.items():
            tmp = np.reshape(
                val, (tracer.ndates, tracer.vresoldim, -1)
            )[dslice]

            # Deals with different resolutions
            xmod_out[key] = scale2map(tmp, tracer, cdates, tracer.domain)

        tracer_type = getattr(tracer, "type", "scalar")

        if tracer_type == "scalar":
            # TODO: date types should be standardized in the future
            dates2read = _dates_to_pydatetime(tracer.input_dates[ddi])

            # Read the tracer array and apply the present control vector
            # scaling factor
            try:
                inputs = tracer.read(
                    trcr,
                    tracer.varname,
                    dates2read,
                    tracer.input_files[ddi],
                    comp_type=comp,
                    tracer=tracer,
                    model=transform.model,
                    ddi=ddi,
                    **kwargs
                )

                # If input is a pd.DataFrame, not working yet
                if isinstance(inputs, pd.DataFrame):
                    raise Exception(
                        f"The inputs for {trcr} are DataFrames, hence sparse data. \n"
                        "This is not covered by CIF yet! \n"
                        "Input files: \n"
                        f"{tracer.input_files[ddi][0]}"
                    )

                # Reprojecting inputs and scaling factors to the same merged
                # time steps
                outdates = pd.to_datetime(
                    np.unique(
                        np.array(tracer_map["input_dates"][ddi]).flatten()
                    )
                )
                if len(outdates) == 1:
                    outdates = outdates.append(outdates)

                # NaNs in scale come from regions not included
                # in the control vector
                # Set to one to multiply by native inputs in scalar mode
                if tracer.hresol == "regions":
                    _fill_uncontrolled_regions(xmod_out, mode, 1.)

                # Any failure here must surface: no interactive debugging
                # session in production code.
                inputs = reindex(
                    inputs,
                    levels={"time": outdates[:-1].values},
                )

                # Without any control date in the period, the inputs are
                # left unscaled (same convention as the zero default used
                # for the tangent-linear increment below).
                scale = 1.
                if len(cdates) != 0:
                    scale = reindex(
                        xmod_out["scale"],
                        levels={"time": outdates[:-1], "lev": inputs.lev},
                    )

                # Check that dimensions are aligned
                if getattr(tracer, "is_lbc", False) and inputs.shape[2] != 1:
                    raise Exception(
                        f"WARNING! The tracer {trid} has option `is_lbc` "
                        f"whereas the corresponding data sets are full domain"
                        f" in {tracer.input_files[ddi]}. "
                        f"Either provide flat lateral boundary condition files "
                        f"or remove the option `is_lbc` in your yml"
                    )

                xmod_out["spec"] = inputs * scale

                if mode == "tl":
                    incr = 0
                    if len(cdates) != 0:
                        incr = reindex(
                            xmod_out["incr"],
                            levels={"time": outdates[:-1], "lev": inputs.lev},
                        )
                    xmod_out["incr"] = incr * inputs

            except ValueError:
                warning("I may have problem in alignment to LBC in the "
                        "target vector. If components of the target vector "
                        "correspond to LBC, please make sure that the "
                        "attribute 'is_lbc' is set as True in the "
                        "configuration yaml.")
                # Bare raise keeps the original traceback
                raise

        # Data already contains the correct info for physical control
        # variables
        # WARNING: so far, assumes that the vertical resolution is already
        # correct
        elif tracer_type == "physical":
            # Reprojecting inputs and scaling factors to the same merged
            # time steps
            outdates = pd.to_datetime(
                np.unique(
                    np.array(tracer_map["input_dates"][ddi]).flatten()
                )
            )
            if len(outdates) == 1:
                outdates = np.append(outdates, outdates)

            # NaNs in scale come from regions not included
            # in the control vector
            # Set to zero in physical mode
            if tracer.hresol == "regions":
                _fill_uncontrolled_regions(xmod_out, mode, 0)

            target_levels = {
                "time": outdates[:-1],
                "lev": range(tracer_map["domain"].nlev),
            }
            xmod_out["spec"] = reindex(xmod_out["scale"], levels=target_levels)

            if mode == "tl":
                xmod_out["incr"] = reindex(
                    xmod_out["incr"], levels=target_levels
                )

        # Removing the scaling factor as all information is stored in
        # 'spec' and 'incr' now
        xmod_out.pop("scale")
        xmod[trid][ddi] = xr.Dataset(xmod_out)


def _dates_to_pydatetime(dates_table):
    """Return the rows of ``dates_table`` as lists of ``datetime`` objects.

    ``dates_table`` must provide ``itertuples`` (e.g. a pandas DataFrame) and
    its cells must provide ``to_pydatetime`` (e.g. pandas Timestamps).
    """
    return [
        [stamp.to_pydatetime() for stamp in row]
        for row in dates_table.itertuples(index=False, name=None)
    ]


def _fill_uncontrolled_regions(xmod_out, mode, scale_fill):
    """Replace NaNs of ``xmod_out['scale']`` by ``scale_fill``, in place.

    In tangent-linear mode (``mode == 'tl'``), NaNs of ``xmod_out['incr']``
    are set to zero as well.
    """
    xmod_out["scale"].values[np.isnan(xmod_out["scale"]).values] = scale_fill
    if mode == "tl":
        xmod_out["incr"].values[np.isnan(xmod_out["incr"]).values] = 0