Source code for pycif.plugins.transforms.basic.product.forward

import os
from logging import debug

import numpy as np
import xarray as xr

from .....utils.path import init_dir


def _elementwise_product(arrays):
    """Return the element-wise product of same-shape numpy arrays.

    An empty list yields the neutral element ``1.`` so that, in the
    tangent-linear product rule, a single-input product gets the factor 1.
    """
    if not arrays:
        return 1.
    return np.prod(
        np.vstack([array[np.newaxis, ...] for array in arrays]), axis=0)


def _build_sampled_mapping(list_trid_in, list_trid_out, sample_sep):
    """Map each output tracer id to the list of input tracer ids to multiply.

    Used in batch-sampling mode, where tracer parameters may carry a
    ``<sample_sep><sample_id>`` suffix.

    Args:
        list_trid_in (list): input tracer ids ``(component, parameter)``.
        list_trid_out (list): output tracer ids ``(component, parameter)``.
        sample_sep (str): separator between a parameter name and its sample id.

    Returns:
        dict: output tracer id -> list of input tracer ids.
    """
    # Reference (unsampled) tracer ids. Sorted (not a set) so the mapping
    # values are indexable and their order is deterministic.
    list_trid_in_ref = sorted(
        {(comp, param.split(sample_sep)[0]) for comp, param in list_trid_in})
    list_trid_out_ref = {
        (comp, param.split(sample_sep)[0]) for comp, param in list_trid_out}
    available_in = set(list_trid_in)

    out_in_mapping = {}
    for trid_out in list_trid_out:
        if trid_out in list_trid_out_ref:
            # tracer 'trid_out' is unperturbed: product of reference inputs
            out_in_mapping[trid_out] = list(list_trid_in_ref)
            continue

        sample_id = trid_out[1].split(sample_sep)[1]
        trids_in = []
        for comp, param in list_trid_in_ref:
            sampled = (comp, f"{param}{sample_sep}{sample_id}")
            # Use the perturbed sample of this input when it exists,
            # otherwise the unperturbed reference tracer
            trids_in.append(
                sampled if sampled in available_in else (comp, param))
        out_in_mapping[trid_out] = trids_in

    return out_in_mapping


def forward(
        transform,
        inout_datastore,
        controlvect,
        obsvect,
        mapper,
        di,
        df,
        mode,
        runsubdir,
        workdir,
        onlyinit=False,
        **kwargs
):
    """Compute the element-wise product of the input tracers.

    Each output tracer receives the product of the ``"spec"`` arrays of its
    input tracers. In ``"tl"`` mode, the tangent-linear increment is also
    propagated under ``"incr"`` with the product rule
    ``d(x1 * ... * xn) = sum_i dxi * prod_{j != i} xj``; inputs without an
    ``"incr"`` entry contribute nothing, and if none has one the output
    ``"incr"`` is left unset.

    When tracer parameters carry a ``"__sample#<id>"`` suffix (batch
    sampling), each sampled output is the product of the matching sampled
    inputs, falling back to the reference (unsampled) inputs.

    Args:
        transform: transform plugin; its ``orig_name`` names the dump file.
        inout_datastore (dict): ``"inputs"`` / ``"outputs"`` datastores,
            indexed as ``store[tracer_id][di]["spec"]``.
        controlvect, obsvect, workdir, onlyinit, kwargs: not used by this
            transform; kept for the common transform interface.
        mapper (dict): ``"inputs"`` / ``"outputs"`` iterables of tracer ids
            ``(component, parameter)``.
        di, df: bounds of the current period; ``di`` is the datastore key
            and ``min(di, df)`` dates the dump file.
        mode (str): ``"tl"`` also propagates increments.
        runsubdir (str): without batch sampling, forward inputs are dumped
            to ``<runsubdir>/../chain/product`` for the adjoint.

    Raises:
        ValueError: if several outputs are requested without batch sampling.
    """
    sample_sep = "__sample#"
    ddi = min(di, df)

    xmod_in = inout_datastore["inputs"]
    xmod_out = inout_datastore["outputs"]

    list_trid_in = list(mapper["inputs"])
    list_trid_out = list(mapper["outputs"])

    batch_sampling = any(
        sample_sep in param for _, param in list_trid_in + list_trid_out)
    if batch_sampling:
        out_in_mapping = _build_sampled_mapping(
            list_trid_in, list_trid_out, sample_sep)
    else:
        if len(list_trid_out) > 1:
            raise ValueError("Multiple outputs for product")
        out_in_mapping = {list_trid_out[0]: list_trid_in}

    for trid_out, trids_in in out_in_mapping.items():
        debug(f"Computing product: {'/'.join(trid_out)} = "
              f"{' * '.join(['/'.join(trid) for trid in trids_in])}")

        spec_in_arrays = {trid: xmod_in[trid][di]["spec"].values
                          for trid in trids_in}
        # Adding this zero-valued copy of the first input presumably wraps
        # the raw numpy result back into the input's labelled container
        # (e.g. xarray coordinates) -- TODO confirm input type
        template = 0. * xmod_in[trids_in[0]][di]["spec"]

        # Forward
        xmod_out[trid_out][di]["spec"] = (
            _elementwise_product(list(spec_in_arrays.values())) + template)

        # Tangent-linear (product rule)
        if mode == "tl":
            incr_out = None
            for ref_trid in trids_in:
                if "incr" not in xmod_in[ref_trid][di]:
                    continue

                others = _elementwise_product([
                    array for trid, array in spec_in_arrays.items()
                    if trid != ref_trid
                ])
                tmp = xmod_in[ref_trid][di]["incr"] * others
                incr_out = tmp if incr_out is None else incr_out + tmp

            # Stored under "incr" so the forward "spec" is preserved
            if incr_out is not None:
                xmod_out[trid_out][di]["incr"] = incr_out + template

    if not batch_sampling:
        # Save forward information for later adjoint
        dump_dir = os.path.join(runsubdir, "../chain/product")
        if not os.path.isdir(dump_dir):
            init_dir(dump_dir)

        file_fwd_dataset = os.path.join(dump_dir, ddi.strftime(
            f"{transform.orig_name}_fwd_%Y%m%d%H%M.nc"))
        fwd_dataset = xr.Dataset({
            "___".join(trid): xmod_in[trid][di]["spec"]
            for trid in list_trid_in
        })
        fwd_dataset.to_netcdf(file_fwd_dataset)