Source code for pycif.plugins.transforms.basic.exp.forward
import xarray as xr
import numpy as np
from logging import debug
import os
from .....utils.path import init_dir
[docs]
def forward(
    transform,
    inout_datastore,
    controlvect,
    obsvect,
    mapper,
    di,
    df,
    mode,
    runsubdir,
    workdir,
    onlyinit=False,
    **kwargs
):
    """Apply the exponential transform in forward (and tangent-linear) mode.

    For the ``(transform.component, transform.parameter)`` entry of the
    datastore at date ``di``, the output ``"spec"`` is ``exp`` of the input
    ``"spec"``. When ``mode == "tl"``, the output ``"incr"`` is
    ``exp(spec_in) * incr_in`` with NaNs replaced by zeros.

    The input ``"spec"`` of every entry listed in ``mapper["inputs"]`` is
    then dumped to a NetCDF file so that it can be re-used by the adjoint.

    Args:
        transform: object providing ``parameter``, ``component`` and
            ``orig_name`` attributes.
        inout_datastore: mapping with ``"inputs"`` and ``"outputs"`` entries,
            each indexed as ``[trid][date][key]``.
        controlvect: not used by this function.
        obsvect: not used by this function.
        mapper: mapping whose ``"inputs"`` entry iterates over the input
            identifiers (sequences of strings) to dump.
        di: date at which data is read from and written to the datastore.
        df: second date bound; only used to compute ``min(di, df)`` for the
            time stamp of the dump file.
        mode: running mode; ``"tl"`` additionally propagates increments.
        runsubdir: run sub-directory; the dump goes to
            ``<runsubdir>/../chain/product/``.
        workdir: not used by this function.
        onlyinit: not used by this function.
        **kwargs: ignored.

    Returns:
        None. ``inout_datastore["outputs"]`` is updated in place and a NetCDF
        file is written to disk.
    """
    ddi = min(di, df)

    xmod_in = inout_datastore["inputs"]
    xmod_out = inout_datastore["outputs"]

    # Initialize outputs
    keys = ["spec"]
    if mode == "tl":
        keys.append("incr")

    parameter = transform.parameter
    component = transform.component
    trid_out = (component, parameter)
    xmod_out[trid_out][di] = {k: 0 for k in keys}

    data_in = xmod_in[trid_out][di]
    data_out = xmod_out[trid_out][di]

    # Forward: the exponential is computed once and shared with the TL below
    exp_spec = np.exp(data_in["spec"])
    data_out["spec"] = exp_spec

    # TL: d(exp(x)) = exp(x) * dx
    if mode == "tl":
        incr = exp_spec * data_in["incr"]

        # Replace NaNs by zeros (in place, through the underlying array)
        mask = np.isnan(incr).values
        incr.values[mask] = 0.

        data_out["incr"] = incr

    # Save forward information for later adjoint
    dump_dir = "{}/../chain/product/".format(runsubdir)
    if not os.path.isdir(dump_dir):
        init_dir(dump_dir)

    # Only the time stamp goes through strftime: running it on the full path
    # would interpret any '%' present in runsubdir or in the transform name
    # as a date directive and point outside of dump_dir.
    file_fwd_dataset = "{}{}_fwd_{}.nc".format(
        dump_dir, transform.orig_name, ddi.strftime("%Y%m%d%H%M")
    )

    fwd_dataset = xr.Dataset({
        "___".join(trid): xmod_in[trid][di]["spec"]
        for trid in mapper["inputs"]
    })
    fwd_dataset.to_netcdf(file_fwd_dataset)