# Source code for pycif.plugins.transforms.basic.exp.forward

import xarray as xr
import numpy as np
from logging import debug
import os
from .....utils.path import init_dir


def forward(
        transform,
        inout_datastore,
        controlvect,
        obsvect,
        mapper,
        di,
        df,
        mode,
        runsubdir,
        workdir,
        onlyinit=False,
        **kwargs
):
    """Apply the element-wise exponential transform for one sub-period.

    Stores ``exp(spec)`` of the input into the output datastore under the
    transform's ``(component, parameter)`` key. In tangent-linear mode the
    increment is propagated as ``exp(spec) * incr``. The forward input
    state is then dumped to a netCDF file, intended to be read back by the
    adjoint.

    Args:
        transform: transform plugin; its ``component``, ``parameter`` and
            ``orig_name`` attributes are read.
        inout_datastore: mapping with ``"inputs"`` and ``"outputs"``
            entries, each indexed as ``[trid][di][key]``.
        controlvect, obsvect, workdir, onlyinit, kwargs: not used here;
            presumably required by the shared transform call signature.
        mapper: its ``"inputs"`` entry lists the ``trid`` tuples whose
            ``"spec"`` is dumped; each tuple is joined with ``"___"`` to form
            the netCDF variable name, so its elements must be strings.
        di: start of the sub-period; used as the datastore index.
        df: end of the sub-period; ``min(di, df)`` timestamps the dump file.
        mode: ``"tl"`` additionally computes the ``"incr"`` field; any other
            value only computes ``"spec"``.
        runsubdir: run directory; the dump is written under
            ``<runsubdir>/../chain/product/``.

    Returns:
        None. Results are written into ``inout_datastore["outputs"]`` and
        to disk.
    """
    ddi = min(di, df)

    xmod_in = inout_datastore["inputs"]
    xmod_out = inout_datastore["outputs"]

    # Initialize outputs
    keys = ["spec"]
    if mode == "tl":
        keys.append("incr")

    trid_out = (transform.component, transform.parameter)
    xmod_out[trid_out][di] = {k: 0 for k in keys}

    # Forward: y = exp(x). Kept in a local so the TL branch can reuse it,
    # since d/dx exp(x) = exp(x).
    exp_spec = np.exp(xmod_in[trid_out][di]["spec"])
    xmod_out[trid_out][di]["spec"] = exp_spec

    # Tangent linear: dy = exp(x) * dx
    if mode == "tl":
        incr_out = exp_spec * xmod_in[trid_out][di]["incr"]

        # Replace NaNs by zeros (in place on the underlying array; the
        # product above is a new object, so exp_spec is not affected)
        mask = np.isnan(incr_out).values
        incr_out.values[mask] = 0.
        xmod_out[trid_out][di]["incr"] = incr_out

    # Save forward information for later adjoint
    dump_dir = "{}/../chain/product/".format(runsubdir)
    if not os.path.isdir(dump_dir):
        init_dir(dump_dir)

    # NOTE(review): the whole path goes through strftime, so a '%' inside
    # runsubdir or orig_name would be read as a date directive. Kept this
    # way so the file name stays byte-identical to what the adjoint
    # presumably rebuilds — confirm against the adjoint before changing.
    file_fwd_dataset = ddi.strftime(
        "{}{}_fwd_%Y%m%d%H%M.nc".format(dump_dir, transform.orig_name)
    )

    fwd_dataset = xr.Dataset({
        "___".join(trid): xmod_in[trid][di]["spec"]
        for trid in mapper["inputs"]
    })
    fwd_dataset.to_netcdf(file_fwd_dataset)