Source code for pycif.plugins.transforms.basic.product.forward
import os
from logging import debug
import numpy as np
import xarray as xr
from .....utils.path import init_dir
[docs]
_SAMPLE_TAG = "__sample#"  # separator between a parameter name and its sample id


def forward(
    transform,
    inout_datastore,
    controlvect,
    obsvect,
    mapper,
    di,
    df,
    mode,
    runsubdir,
    workdir,
    onlyinit=False,
    **kwargs
):
    """Multiply input tracers element-wise into output tracer(s) for one period.

    Each output tracer receives the product of the ``"spec"`` values of its
    input tracers at period ``di``. When any tracer parameter contains
    ``"__sample#"`` (batch sampling), every sampled output is paired with the
    inputs of the same sample; otherwise exactly one output is expected and
    it receives the product of all inputs.

    In tangent-linear mode (``mode == "tl"``) the output ``"incr"`` is also
    computed with the product rule::

        d(x1 * ... * xn) = sum_i d(xi) * prod_{j != i} xj

    Inputs without an ``"incr"`` entry contribute a zero increment; the
    forward ``"spec"`` of the output is left intact.

    Outside batch sampling, the input ``"spec"`` values are saved to a netCDF
    file in ``<runsubdir>/../chain/product`` for later use by the adjoint.

    Args:
        transform: transform plugin; only ``transform.orig_name`` is read
            (to name the dump file).
        inout_datastore: mapping with ``"inputs"`` and ``"outputs"`` entries,
            each indexed as ``[trid][di][key]``; outputs are written in place.
        controlvect, obsvect, workdir, onlyinit, kwargs: not used here.
        mapper: mapping whose ``"inputs"`` / ``"outputs"`` iterate over
            ``(component, parameter)`` tracer ids.
        di, df: period bounds; ``min(di, df)`` is formatted with ``strftime``
            to name the dump file (presumably datetimes).
        mode: run mode; ``"tl"`` enables the tangent-linear computation.
        runsubdir: directory relative to which the dump file is written.

    Raises:
        ValueError: if there is no input tracer, or if batch sampling is off
            and there is not exactly one output tracer.
    """
    ddi = min(di, df)
    xmod_in = inout_datastore["inputs"]
    xmod_out = inout_datastore["outputs"]

    list_trid_in = list(mapper["inputs"])
    list_trid_out = list(mapper["outputs"])
    if not list_trid_in:
        raise ValueError("No input for product")

    batch_sampling = any(
        _SAMPLE_TAG in param for _, param in list_trid_in + list_trid_out
    )
    if batch_sampling:
        out_in_mapping = _map_sampled_outputs(list_trid_in, list_trid_out)
    else:
        if len(list_trid_out) > 1:
            raise ValueError("Multiple outputs for product")
        if not list_trid_out:
            raise ValueError("No output for product")
        out_in_mapping = {list_trid_out[0]: list_trid_in}

    for trid_out, trids_in in out_in_mapping.items():
        debug(f"Computing product: {'/'.join(trid_out)} = "
              f"{' * '.join(['/'.join(trid) for trid in trids_in])}")

        spec_in_arrays = {
            trid: xmod_in[trid][di]["spec"].values for trid in trids_in
        }
        # Adding "0. * <spec>" re-attaches the container type/coordinates of
        # the first input to the raw numpy products computed below.
        template = 0. * xmod_in[trids_in[0]][di]["spec"]

        # Forward
        xmod_out[trid_out][di]["spec"] = (
            _multiply_all(spec_in_arrays.values()) + template
        )

        # Tangent-linear (product rule). Stored under "incr" so that the
        # forward "spec" computed just above is preserved.
        if mode == "tl":
            incr_out = 0.
            for ref_trid in trids_in:
                if "incr" not in xmod_in[ref_trid][di]:
                    continue
                others = [
                    array for trid, array in spec_in_arrays.items()
                    if trid != ref_trid
                ]
                incr_out = incr_out + (
                    xmod_in[ref_trid][di]["incr"] * _multiply_all(others)
                )
            xmod_out[trid_out][di]["incr"] = incr_out + template

    if not batch_sampling:
        # Save forward information for later adjoint
        dump_dir = os.path.join(runsubdir, "../chain/product")
        if not os.path.isdir(dump_dir):
            init_dir(dump_dir)
        # Only the timestamp goes through strftime, so a '%' in orig_name is
        # not misread as a format directive.
        file_fwd_dataset = os.path.join(
            dump_dir,
            f"{transform.orig_name}_fwd_{ddi.strftime('%Y%m%d%H%M')}.nc",
        )
        fwd_dataset = xr.Dataset({
            "___".join(trid): xmod_in[trid][di]["spec"]
            for trid in list_trid_in
        })
        fwd_dataset.to_netcdf(file_fwd_dataset)


def _map_sampled_outputs(list_trid_in, list_trid_out):
    """Pair each output tracer with the input tracers of the same sample.

    A tracer ``(comp, "param__sample#<id>")`` is a sample of the reference
    tracer ``(comp, "param")``. A reference (unperturbed) output maps to all
    reference inputs. A sampled output maps, for each reference input, to the
    reference itself when it is listed among the inputs (unperturbed input),
    and otherwise to the input carrying the output's sample id.

    Returns:
        dict from output tracer id to an ordered list of input tracer ids;
        the order follows first appearance in ``list_trid_in``.
    """
    def reference(trid):
        comp, param = trid
        return comp, param.split(_SAMPLE_TAG)[0]

    # dict.fromkeys de-duplicates while keeping a deterministic order; the
    # result must also be indexable by the caller.
    ref_in = list(dict.fromkeys(reference(trid) for trid in list_trid_in))
    ref_out = {reference(trid) for trid in list_trid_out}
    present_in = set(list_trid_in)

    mapping = {}
    for trid_out in list_trid_out:
        if trid_out in ref_out:
            # tracer 'trid_out' is unperturbed
            mapping[trid_out] = list(ref_in)
            continue
        sample_id = trid_out[1].split(_SAMPLE_TAG)[1]
        mapping[trid_out] = [
            trid if trid in present_in
            else (trid[0], f"{trid[1]}{_SAMPLE_TAG}{sample_id}")
            for trid in ref_in
        ]
    return mapping


def _multiply_all(arrays):
    """Return the element-wise product of ``arrays``, or ``1.`` when empty.

    Factors are multiplied left to right; shapes only need to be mutually
    broadcastable.
    """
    product = None
    for array in arrays:
        product = array if product is None else product * array
    return 1. if product is None else product