Source code for pycif.plugins.transforms.basic.exp.forward
import xarray as xr
import numpy as np
from logging import debug
import os
from .....utils.path import init_dir
[docs]
def forward(
transform,
inout_datastore,
controlvect,
obsvect,
mapper,
di,
df,
mode,
runsubdir,
workdir,
onlyinit=False,
**kwargs
):
r"""Compute the element-wise exponential of a tracer field.
Forward: :math:`y = e^x`.
TL: :math:`\delta y = e^x \cdot \delta x` (NaN increments set to 0).
The input field is saved to ``chain/product/`` for use by the adjoint.
Args:
transform (Plugin): exp instance (carries ``component``,
``parameter``, and ``orig_name``).
inout_datastore (dict): mutable datastore.
controlvect: unused.
obsvect: unused.
mapper (dict): transform mapper.
di (datetime): sub-simulation start date.
df (datetime): sub-simulation end date.
mode (str): ``'fwd'`` or ``'tl'``.
runsubdir (str): sub-simulation run directory (used to locate the
``chain/`` directory).
workdir (str): unused.
onlyinit (bool): unused.
**kwargs: unused.
"""
ddi = min(di, df)
xmod_in = inout_datastore["inputs"]
xmod_out = inout_datastore["outputs"]
# Initialize outputs
keys = ["spec"]
if mode == "tl":
keys.append("incr")
parameter = transform.parameter
component = transform.component
trid_out = (component, parameter)
xmod_out[trid_out][di] = {k: 0 for k in keys}
# Forward
xmod_out[trid_out][di]["spec"] = np.exp(xmod_in[trid_out][di]["spec"])
# TL
if mode == "tl":
xmod_out[trid_out][di]["incr"] = \
np.exp(xmod_in[trid_out][di]["spec"]) * \
xmod_in[trid_out][di]["incr"]
# Replace NaNs by zeros
mask = np.isnan(xmod_out[trid_out][di]["incr"]).values
xmod_out[trid_out][di]["incr"].values[mask] = 0.
# Save forward information for later adjoint
dump_dir = f"{runsubdir}/../chain/exp/{transform.orig_name}/"
if not os.path.isdir(dump_dir):
init_dir(dump_dir)
file_fwd_dataset = ddi.strftime(
f"{runsubdir}/../chain/exp/{transform.orig_name}_fwd_%Y%m%d%H%M.nc")
fwd_dataset = xr.Dataset({
"___".join(trid): xmod_in[trid][di]["spec"] for trid in mapper["inputs"]
})
fwd_dataset.to_netcdf(file_fwd_dataset)