import numpy as np
import pandas as pd
import copy
import xarray as xr
import os
from .....utils.path import init_dir
def _split_samples(trids_in, trids_out):
    """Group the ``(type, name)`` input and output pairs by sample.

    Pairs whose name contains ``__sample#NNN`` are split by sample index
    ``NNN``; pairs without that tag are shared and appended at the end of
    every group. When no input name carries the tag, a single group holding
    all pairs (in their original order) is returned.

    Returns:
        list of ``(group_in, group_out)`` tuples, each a list of pairs.
    """
    if not any("__sample#" in trid[1] for trid in trids_in):
        return [(list(trids_in), list(trids_out))]

    in_pert = [trid for trid in trids_in if "__sample#" in trid[1]]
    out_pert = [trid for trid in trids_out if "__sample#" in trid[1]]
    in_shared = [trid for trid in trids_in if "__sample#" not in trid[1]]
    out_shared = [trid for trid in trids_out if "__sample#" not in trid[1]]

    # Sample indices are taken to run from 0 up to the largest one found.
    nsamples = max(
        int(trid[1].split("__sample#")[1]) for trid in in_pert) + 1

    groups = []
    for isample in range(nsamples):
        sample_id_str = "__sample#{:03d}".format(isample)
        group_in = [trid for trid in in_pert if sample_id_str in trid[1]]
        group_out = [trid for trid in out_pert if sample_id_str in trid[1]]
        groups.append((group_in + in_shared, group_out + out_shared))
    return groups


def _in_odd_block(ioutput, isign, noutputs):
    """Tell whether output ``ioutput`` lies in an odd block for ``isign``.

    The outputs are cut into consecutive blocks of
    ``noutputs // 2 ** (isign + 1)`` entries. When ``noutputs == 2 ** nsign``
    this is bit ``nsign - 1 - isign`` of ``ioutput``, so signature 0 drives
    the most significant bit.
    """
    block_len = noutputs // 2 ** (isign + 1)
    return (ioutput // block_len) % 2 == 1


def _share(ioutput, isign, noutputs, a_factor):
    """Fraction of the total given to output ``ioutput`` by signature ``isign``.

    Odd blocks take ``a / (1 + a)`` and even blocks ``1 / (1 + a)``; the two
    shares sum to one.
    """
    if _in_odd_block(ioutput, isign, noutputs):
        return a_factor / (1 + a_factor)
    return 1 / (1 + a_factor)


def _share_derivative(ioutput, isign, noutputs, a_factor, r_std):
    """Derivative of :func:`_share` with respect to the signature value.

    With ``a = (1 + sig / 1000) * r_std``: ``da / dsig = r_std / 1000`` and
    ``d(a / (1 + a)) / da = -d(1 / (1 + a)) / da = 1 / (1 + a) ** 2``.
    """
    if _in_odd_block(ioutput, isign, noutputs):
        return r_std / 1000 / (1 + a_factor) ** 2
    return - r_std / 1000 / (1 + a_factor) ** 2


def forward(
        transform,
        inout_datastore,
        controlvect,
        obsvect,
        mapper,
        di,
        df,
        mode,
        runsubdir,
        workdir,
        onlyinit=False,
        **kwargs
):
    """Split a total species into isotopologues from isotopic signatures.

    Inputs are processed in groups: one per ``__sample#NNN`` tag, or a single
    group when no input is tagged. Within a group, the last input is the
    total species and every preceding input is a signature. Each signature
    is turned into a ratio ``a = (1 + signature / 1000) * standard`` (delta
    notation against the standard declared in ``transform.parameters_out``).
    Each output receives the total multiplied, for every signature, by either
    ``1 / (1 + a)`` or ``a / (1 + a)``, depending on the output's position in
    the group.

    Args:
        transform: transform object. Reads ``parameters_out.iso_mass``,
            ``parameters_out.spec_mass``,
            ``parameters_out.<signature>.standard``, ``unit`` and
            ``orig_name``.
        inout_datastore: dict with ``"inputs"`` and ``"outputs"`` mappings
            ``(type, name) -> {date: {"spec": ..., "incr": ...}}``.
        controlvect, obsvect, workdir, onlyinit, kwargs: not used here.
        mapper: dict whose ``"inputs"`` / ``"outputs"`` keys are
            ``(type, name)`` pairs. Their order defines the role of each
            input and the share combination assigned to each output.
        di, df: period bounds. Outputs are stored under ``di``, and the dump
            file is time-stamped with ``min(di, df)``.
        mode: ``"tl"`` also propagates the ``"incr"`` increments. Any other
            value only computes ``"spec"``.
        runsubdir: run directory. The total and the signatures are dumped to
            ``{runsubdir}/../chain/ratio2conc/`` for later use by the
            adjoint.

    Raises:
        ValueError: if a signature has no standard in ``parameters_out``, or
            if a group has fewer than ``2 ** n_signatures`` outputs.
    """
    ddi = min(di, df)

    xmod_out = inout_datastore["outputs"]
    xmod_in = inout_datastore["inputs"]

    trids_in = list(mapper["inputs"].keys())
    trids_out = list(mapper["outputs"].keys())

    iso_mass = transform.parameters_out.iso_mass
    spec_mass = transform.parameters_out.spec_mass
    unit = transform.unit

    # Find the standards for every kind of ratio. Sampled inputs use the
    # standard of their un-tagged parent name. Inputs without a standard are
    # skipped here; missing signature standards are reported below with an
    # explicit message.
    r_stds = {}
    for _, sign in trids_in:
        ref_sign = sign.split("__sample#")[0]
        try:
            r_stds[sign] = getattr(transform.parameters_out, ref_sign).standard
        except AttributeError:
            pass

    # Run the main code for all groups of samples
    for group_in, group_out in _split_samples(trids_in, trids_out):
        spec_trid = group_in[-1]
        sign_trids = group_in[:-1]
        signs = [trid[1] for trid in sign_trids]
        noutputs = len(group_out)

        # Every output starts as a copy of the total species. "spec" (and
        # "incr" in TL mode) are deep-copied because they are scaled in place
        # below.
        for trid_out in group_out:
            xmod_out[trid_out][di] = dict(xmod_in[spec_trid][di])
            xmod_out[trid_out][di]["spec"] = \
                copy.deepcopy(xmod_in[spec_trid][di]["spec"])
            if mode == "tl":
                xmod_out[trid_out][di]["incr"] = \
                    copy.deepcopy(xmod_in[spec_trid][di]["incr"])

        # Each signature is fetched with its own (type, name) key.
        signatures = {
            trid[1]: xmod_in[trid][di]["spec"] for trid in sign_trids}
        spec_data = xmod_in[spec_trid][di]["spec"]

        # Check that standards are defined
        missing_rstds = [sign for sign in signs if sign not in r_stds]
        if missing_rstds:
            raise ValueError(
                "Warning! There are missing standards in your Yaml to compute the isotopic transform.\n"
                f"Missing species in the `parameters_out` paragraph: {missing_rstds}"
            )

        # Share selection uses blocks of noutputs // 2 ** (isign + 1)
        # outputs; with fewer than 2 ** nsign outputs that length drops to 0.
        nsign = len(signs)
        if noutputs and noutputs < 2 ** nsign:
            raise ValueError(
                f"The isotopic transform '{transform.orig_name}' needs at "
                f"least {2 ** nsign} outputs for {nsign} signature(s), "
                f"got {noutputs}."
            )

        a_factors = {
            sign: (1 + signatures[sign] / 1000) * r_stds[sign]
            for sign in signs}

        # Save signature and data for later use by adjoint
        dump_dir = "{}/../chain/ratio2conc/".format(runsubdir)
        if not os.path.isdir(dump_dir):
            init_dir(dump_dir)
        # NOTE(review): the file name does not depend on the sample group, so
        # with several "__sample#" groups each iteration rewrites the same
        # file and only the last group's data remain. Confirm this is what
        # the adjoint expects.
        file_fwd_dataset = ddi.strftime(
            "{}/../chain/ratio2conc/{}_fwd_%Y%m%d%H%M.nc"
            .format(runsubdir, transform.orig_name))
        fwd_dataset = spec_data.to_dataset(name="spec")
        for sign in signs:
            fwd_dataset[sign] = signatures[sign]
        fwd_dataset.to_netcdf(file_fwd_dataset)

        # Compute fwd spec: total * product of the selected shares
        for ioutput, trid_out in enumerate(group_out):
            for isign, sign in enumerate(signs):
                xmod_out[trid_out][di]["spec"] *= \
                    _share(ioutput, isign, noutputs, a_factors[sign])

        # Applying tangent linear (product rule):
        #   d(out) = d(total) * prod_k share_k
        #          + total * sum_k d(sig_k) * dshare_k * prod_{j != k} share_j
        if mode == "tl":
            signatures_tl = {
                trid[1]: xmod_in[trid][di]["incr"] for trid in sign_trids}
            spec_data_tl = xmod_in[spec_trid][di]["incr"]
            for ioutput, trid_out in enumerate(group_out):
                # spec_data increment
                total_sum = spec_data_tl.copy()
                for isign, sign in enumerate(signs):
                    total_sum *= \
                        _share(ioutput, isign, noutputs, a_factors[sign])

                # signatures increment
                for isign_plus, sign_plus in enumerate(signs):
                    total_multip = signatures_tl[sign_plus].copy()
                    for isign_times, sign_times in enumerate(signs):
                        if isign_times == isign_plus:
                            total_multip *= _share_derivative(
                                ioutput, isign_times, noutputs,
                                a_factors[sign_times], r_stds[sign_times])
                        else:
                            total_multip *= _share(
                                ioutput, isign_times, noutputs,
                                a_factors[sign_times])
                    total_sum += total_multip * spec_data

                xmod_out[trid_out][di]["incr"] = copy.deepcopy(total_sum)

        # Mass correction if units are in mass: rescale output i by
        # iso_mass[i] / spec_mass
        if unit == "mass":
            for i, trid_out in enumerate(group_out):
                xmod_out[trid_out][di]["spec"] *= (iso_mass[i] / spec_mass)
                if mode == "tl":
                    xmod_out[trid_out][di]["incr"] *= \
                        (iso_mass[i] / spec_mass)