import numpy as np
import xarray as xr
import pandas as pd
import copy
from logging import warning, info
from .utils.dates import dateslice
from .....utils.dataarrays.reindex import reindex
from .utils.scalemaps import scale2map
from .forward_perturb import forward_perturb
def _rows_to_pydatetime(dates):
    """Convert every row of a date table into a list of ``datetime`` objects.

    ``dates`` is iterated with ``itertuples(index=False, name=None)`` and each
    cell is converted with ``to_pydatetime()``, so it is expected to be a
    ``pandas.DataFrame`` of ``pd.Timestamp`` values.
    """
    return [
        [stamp.to_pydatetime() for stamp in row]
        for row in dates.itertuples(index=False, name=None)
    ]


def _merged_outdates(out_mapper, trid, ddi):
    """Return the sorted, unique dates listed for ``trid`` at ``ddi``.

    All date cells of ``out_mapper[trid]["input_dates"][ddi]`` are merged
    into a single ``DatetimeIndex``. A lone date is duplicated so that
    ``outdates[:-1]`` always holds at least one time step.
    """
    outdates = pd.to_datetime(
        np.unique(np.array(out_mapper[trid]["input_dates"][ddi]).flatten())
    )
    if len(outdates) == 1:
        outdates = outdates.append(outdates)
    return outdates


def forward(
    transform,
    inout_datastore,
    controlvect,
    obsvect,
    mapper,
    di,
    df,
    mode,
    runsubdir,
    workdir,
    **kwargs
):
    """Map the control vector onto the inputs expected downstream.

    For every output ``(component, tracer)`` pair in ``mapper["outputs"]``:

    * tracers without ``iscontrol`` are skipped, unless ``force_loadin``
      (a bool, or a dict keyed by date) asks to read them with
      ``tracer.read``;
    * control tracers are extracted from ``controlvect.x`` (and
      ``controlvect.dx`` in tangent-linear mode) for the control dates
      overlapping the sub-simulation, projected with ``scale2map``, then
        - used as multiplicative scaling factors of the inputs read by
          ``tracer.read`` when ``tracer.type == "scalar"`` (the default), or
        - used directly as the values when ``tracer.type == "physical"``.

    Results are written in place to
    ``inout_datastore["outputs"][trid][ddi]`` with ``ddi = min(di, df)``:
    an ``xr.Dataset`` holding ``spec`` (and ``incr`` when ``mode == "tl"``),
    or, for non-control sparse data, whatever ``tracer.read`` returned with
    an extra zero ``("maindata", "incr")`` column.

    When ``transform.perturb_xb`` is truthy and some output tracer name
    contains ``'__sample#'``, the whole job is delegated to
    ``forward_perturb`` instead.

    Args:
        transform: transform object; ``transform.model`` is forwarded to
            ``tracer.read``.
        inout_datastore: dict whose ``"outputs"`` entry is filled in place.
        controlvect: provides ``x`` (and ``dx`` in tangent-linear mode).
        obsvect: not used directly; only forwarded to ``forward_perturb``.
        mapper: transform mapper; ``mapper["outputs"][trid]`` gives the
            tracer, its input files/dates per date and ``force_loadin``.
        di, df: bounds of the sub-simulation period.
        mode: running mode; ``"tl"`` also produces increments.
        runsubdir, workdir: only forwarded to ``forward_perturb``.
        **kwargs: forwarded to ``tracer.read`` and ``forward_perturb``.

    Raises:
        AttributeError: inputs must be read for a tracer without ``read``.
        Exception: scalar-mode inputs are a ``pd.DataFrame`` (sparse data),
            or a tracer flagged ``is_lbc`` is read from full-domain data.
        ValueError: re-raised after a warning when scalar-mode alignment
            fails.
    """
    out_mapper = mapper["outputs"]

    # If perturb_xb and samples, speed up the process
    if getattr(transform, "perturb_xb", False) and any(
        "__sample#" in tr[1] for tr in out_mapper
    ):
        forward_perturb(
            transform,
            inout_datastore,
            controlvect,
            obsvect,
            mapper,
            di,
            df,
            mode,
            runsubdir,
            workdir,
            **kwargs
        )
        return

    ddi = min(di, df)
    xmod = inout_datastore["outputs"]

    for trid in out_mapper:
        comp = trid[0]
        trcr = trid[1]

        # Force reading inputs.
        # This can be a dictionary based on dates,
        # originating from time_interpolation behaviour
        force_loadin = out_mapper[trid].get("force_loadin", False)
        if isinstance(force_loadin, dict):
            force_loadin = force_loadin.get(ddi, False)

        in_files = out_mapper[trid]["input_files"].get(ddi, [])
        in_dates = out_mapper[trid]["input_dates"].get(ddi, [])
        tracer = out_mapper[trid]["tracer"]

        # Skip parameters not in the control space
        if not getattr(tracer, "iscontrol", False):
            if force_loadin:
                if not hasattr(tracer, "read"):
                    raise AttributeError(
                        f"Needing to read data for {trid}, "
                        f"but no read function is implemented for {tracer}"
                    )

                # Turn dates to datetime for consistency in read
                # TODO: This should be standardized in the future
                dates2read = _rows_to_pydatetime(in_dates)

                # If inputs do not already exist, read is called
                inputs = tracer.read(
                    trcr,
                    tracer.varname,
                    dates2read,
                    in_files,
                    comp_type=comp,
                    tracer=tracer,
                    model=transform.model,
                    ddi=ddi,
                    **kwargs
                )

                # Directly return what was read for sparse data
                if getattr(tracer, "sparse_data", False):
                    inputs[("maindata", "incr")] = 0.
                    xmod[trid][ddi] = inputs
                    continue

                xmod[trid][ddi] = xr.Dataset({"spec": inputs})

                # Force increments to zero if tangent linear
                if mode == "tl":
                    xmod[trid][ddi]["incr"] = 0 * xmod[trid][ddi]["spec"]
            continue

        # Otherwise, reformat from x.
        # Translates control vector, and increments if tangent-linear
        variables = {"scale": copy.deepcopy(controlvect.x)}
        if mode == "tl":
            variables["incr"] = copy.deepcopy(controlvect.dx)

        # Deal with control vector dates and cut if controlvect period spans
        # outside the sub-simulation period
        if len(in_dates) == 0:
            dslice = []
        else:
            dslice = dateslice(
                tracer,
                np.min(in_dates["start_date"]),
                np.max(in_dates["end_date"]),
            )
        cdates = tracer.dates[dslice]
        # A first control period starting before the sub-simulation is
        # clipped so that it starts at ddi
        if len(dslice) == 1 and cdates[0] < ddi:
            cdates[0] = ddi
        elif len(dslice) > 1 and cdates[0] < ddi < cdates[1]:
            cdates[0] = ddi

        # Translates only control variables corresponding to the
        # simulation period
        xmod_out = {}
        for key, val in variables.items():
            tmp = np.reshape(
                val[tracer.xpointer: tracer.xpointer + tracer.dim],
                (tracer.ndates, tracer.vresoldim, -1),
            )[dslice]
            # Deals with different resolutions
            xmod_out[key] = scale2map(tmp, tracer, cdates, tracer.domain)

        tracer_type = getattr(tracer, "type", "scalar")

        if tracer_type == "scalar":
            # Turn dates to datetime for consistency in read
            # TODO: This should be standardized in the future
            dates2read = _rows_to_pydatetime(tracer.input_dates[ddi])

            # Read the tracer array and apply the present control vector
            # scaling factor
            try:
                inputs = tracer.read(
                    trcr,
                    tracer.varname,
                    dates2read,
                    tracer.input_files[ddi],
                    comp_type=comp,
                    tracer=tracer,
                    model=transform.model,
                    ddi=ddi,
                    **kwargs
                )

                # If input is a pd.DataFrame, not working yet
                if isinstance(inputs, pd.DataFrame):
                    raise Exception(
                        f"The inputs for {trcr} are DataFrames, hence sparse data. \n"
                        "This is not covered by CIF yet! \n"
                        "Input files: \n"
                        f"{tracer.input_files[ddi][0]}"
                    )

                # Reprojecting inputs and scaling factors to the same merged
                # time steps
                outdates = _merged_outdates(out_mapper, trid, ddi)

                # NaNs in scale come from regions not included
                # in the control vector.
                # Set to one to multiply by native inputs in scalar mode
                if tracer.hresol == "regions":
                    xmod_out["scale"].values[
                        np.isnan(xmod_out["scale"]).values] = 1.
                    if mode == "tl":
                        xmod_out["incr"].values[
                            np.isnan(xmod_out["incr"]).values] = 0

                # Errors propagate: a ValueError reaches the LBC warning below
                inputs = reindex(
                    inputs,
                    levels={"time": outdates[:-1].values},
                )

                # Without control dates in the period, keep the native inputs
                # (neutral factor). This also prevents reusing the `scale`
                # of a previously processed tracer.
                scale = 1.
                if len(cdates) != 0:
                    scale = reindex(
                        xmod_out["scale"],
                        levels={"time": outdates[:-1], "lev": inputs.lev},
                    )

                # Check that dimensions are aligned
                if getattr(tracer, "is_lbc", False) and inputs.shape[2] != 1:
                    raise Exception(
                        f"WARNING! The tracer {trid} has option `is_lbc` "
                        f"whereas the corresponding data sets are full domain"
                        f" in {tracer.input_files[ddi]}. "
                        f"Either provide flat lateral boundary condition files "
                        f"or remove the option `is_lbc` in your yml"
                    )

                xmod_out["spec"] = inputs * scale

                if mode == "tl":
                    incr = 0
                    if len(cdates) != 0:
                        incr = reindex(
                            xmod_out["incr"],
                            levels={"time": outdates[:-1], "lev": inputs.lev},
                        )
                    xmod_out["incr"] = incr * inputs

            except ValueError:
                warning("I may have problem in alignment to LBC in the "
                        "target vector. If components of the target vector "
                        "correspond to LBC, please make sure that the "
                        "attribute 'is_lbc' is set as True in the "
                        "configuration yaml.")
                raise

        # Data already contains the correct info for physical control
        # variables.
        # WARNING: so far, assumes that the vertical resolution is already
        # correct
        elif tracer_type == "physical":
            # Reprojecting inputs and scaling factors to the same merged
            # time steps
            outdates = _merged_outdates(out_mapper, trid, ddi)

            # NaNs in scale come from regions not included
            # in the control vector.
            # Set to zero in physical mode
            if tracer.hresol == "regions":
                xmod_out["scale"].values[
                    np.isnan(xmod_out["scale"]).values] = 0
                if mode == "tl":
                    xmod_out["incr"].values[
                        np.isnan(xmod_out["incr"]).values] = 0

            xmod_out["spec"] = reindex(
                xmod_out["scale"],
                levels={"time": outdates[:-1],
                        "lev": range(out_mapper[trid]["domain"].nlev)},
            )
            if mode == "tl":
                xmod_out["incr"] = reindex(
                    xmod_out["incr"],
                    levels={"time": outdates[:-1],
                            "lev": range(out_mapper[trid]["domain"].nlev)},
                )

        # Removing the scaling factor as all information is stored in
        # 'spec' and 'incr' now
        xmod_out.pop("scale")
        xmod[trid][ddi] = xr.Dataset(xmod_out)