Source code for pycif.plugins.transforms.basic.exp.forward

import xarray as xr
import numpy as np
from logging import debug
import os
from .....utils.path import init_dir


def forward(
    transform,
    inout_datastore,
    controlvect,
    obsvect,
    mapper,
    di,
    df,
    mode,
    runsubdir,
    workdir,
    onlyinit=False,
    **kwargs
):
    r"""Compute the element-wise exponential of a tracer field.

    Forward: :math:`y = e^x`.
    TL: :math:`\delta y = e^x \cdot \delta x`; NaN values in the resulting
    increment are replaced by 0.

    The ``"spec"`` fields of every input listed in ``mapper["inputs"]`` are
    dumped to a NetCDF file in ``<runsubdir>/../chain/exp/`` for use by the
    adjoint.

    Args:
        transform (Plugin): exp instance (carries ``component``,
            ``parameter``, and ``orig_name``).
        inout_datastore (dict): mutable datastore with ``"inputs"`` and
            ``"outputs"`` entries; ``outputs[(component, parameter)][di]``
            is overwritten with a dict holding ``"spec"`` (and ``"incr"``
            in TL mode).
        controlvect: unused.
        obsvect: unused.
        mapper (dict): transform mapper; the keys of ``mapper["inputs"]``
            are the ``(component, parameter)`` tuples dumped for the adjoint.
        di (datetime): sub-simulation start date.
        df (datetime): sub-simulation end date.
        mode (str): ``'fwd'`` or ``'tl'``.
        runsubdir (str): sub-simulation run directory (used to locate the
            ``chain/`` directory).
        workdir (str): unused.
        onlyinit (bool): unused.
        **kwargs: unused.
    """
    # The earlier of the two dates is used to time-stamp the dump file
    ddi = min(di, df)

    xmod_in = inout_datastore["inputs"]
    xmod_out = inout_datastore["outputs"]

    trid_out = (transform.component, transform.parameter)
    x_in = xmod_in[trid_out][di]

    # Forward: y = exp(x); computed once and re-used by the TL below
    exp_x = np.exp(x_in["spec"])
    out = {"spec": exp_x}

    # TL: dy = exp(x) * dx
    if mode == "tl":
        # fillna works on any xarray backing (including lazy/dask arrays,
        # where assigning into `.values` would only modify a temporary copy).
        # Assumes "incr" is an xarray object — TODO confirm against callers.
        out["incr"] = (exp_x * x_in["incr"]).fillna(0.)

    xmod_out[trid_out][di] = out

    # Save forward information for later adjoint
    chain_dir = f"{runsubdir}/../chain/exp"
    dump_dir = f"{chain_dir}/{transform.orig_name}/"
    if not os.path.isdir(dump_dir):
        init_dir(dump_dir)

    # NOTE(review): the dump file lands in `chain_dir`, next to `dump_dir`
    # rather than inside it. The path is kept as-is because the adjoint
    # (not visible here) presumably reads it from this exact location.
    # Only the timestamp goes through strftime, so '%' characters in
    # runsubdir/orig_name are not misinterpreted as format directives.
    file_fwd_dataset = (
        f"{chain_dir}/{transform.orig_name}_fwd_"
        f"{ddi.strftime('%Y%m%d%H%M')}.nc"
    )
    fwd_dataset = xr.Dataset({
        "___".join(trid): xmod_in[trid][di]["spec"]
        for trid in mapper["inputs"]
    })
    fwd_dataset.to_netcdf(file_fwd_dataset)