Source code for pycif.plugins.transforms.complex.ratio2conc.adjoint
import copy
import xarray as xr
import numpy as np
[docs]
def adjoint(
    transform,
    inout_datastore,
    controlvect,
    obsvect,
    mapper,
    di,
    df,
    mode,
    runsubdir,
    workdir,
    onlyinit=False,
    **kwargs
):
    """Adjoint of the ratio2conc transform.

    The sensitivities ("adj_out") of the isotopologue outputs are propagated
    back onto the inputs. Within each group, the last input is treated as
    the total and every preceding input as an isotopic signature (presumably
    a delta value in per mil, given the ``/ 1000``) relative to the
    ``standard`` of the matching entry of ``transform.parameters_out``.

    With ``a = (1 + delta / 1000) * r_std``, the products below are the
    adjoint of a forward model of the form
    ``iso_k = total * prod_s (a_s / (1 + a_s) if bit_s(k) else 1 / (1 + a_s))``
    where ``bit_s(k)`` selects, for signature ``s``, whether output ``k``
    falls in an odd block of size ``len(outputs) // 2 ** (s + 1)``
    (assumes ``len(outputs) == 2 ** n_signatures`` -- TODO confirm).

    Inputs carrying a ``__sample#NNN`` suffix are processed per sample: each
    group holds that sample's inputs/outputs plus all non-sampled ones.

    Args:
        transform: transform plugin; reads ``parameters_out`` (``iso_mass``,
            ``spec_mass`` and per-signature ``standard``), ``unit``,
            ``model.adj_refdir`` and ``orig_name``.
        inout_datastore: dict with "inputs" and "outputs" datastores, both
            indexed by trid then by ``di``; "inputs" is filled in place.
        mapper: dict whose "inputs"/"outputs" keys list the trids
            ``(component, parameter)`` handled by the transform.
        di, df: period bounds; the forward dump is looked up at
            ``min(di, df)``.
        onlyinit: if True, only copy the output entries onto the inputs
            (without deep-copying "adj_out") and skip the computation.
        controlvect, obsvect, mode, runsubdir, workdir, kwargs: not used here.
    """
    ddi = min(di, df)
    xmod_out = inout_datastore["outputs"]
    xmod_in = inout_datastore["inputs"]
    trids_in = list(mapper["inputs"].keys())
    trids_out = list(mapper["outputs"].keys())

    iso_mass = transform.parameters_out.iso_mass
    spec_mass = transform.parameters_out.spec_mass
    unit = transform.unit

    r_stds = _reference_standards(transform.parameters_out, trids_in)
    groups = _sample_groups(trids_in, trids_out)

    # Each group only (re)initialises its own inputs: initialising every
    # input at every group would wipe the sensitivities already computed
    # for the previous sample groups.
    if onlyinit:
        for group_in, group_out in groups:
            _init_group_inputs(
                xmod_in, xmod_out, group_in, group_out[0], di, True)
        return

    # The forward dump does not depend on the sample group: open it once
    # and release the file handle when done.
    file_fwd_dataset = ddi.strftime(
        "{}/chain/ratio2conc/{}_fwd_%Y%m%d%H%M.nc".format(
            transform.model.adj_refdir, transform.orig_name))
    with xr.open_dataset(file_fwd_dataset) as fwd_dataset:
        spec_data = fwd_dataset["spec"].values
        for group_in, group_out in groups:
            _init_group_inputs(
                xmod_in, xmod_out, group_in, group_out[0], di, False)
            signatures = {trid[1]: fwd_dataset[trid[1]].values
                          for trid in group_in[:-1]}
            _adjoint_group(
                xmod_in, xmod_out, group_in, group_out, di, signatures,
                spec_data, r_stds, iso_mass, spec_mass, unit)


def _reference_standards(parameters_out, trids_in):
    """Map each input name to the ``standard`` of its parameter entry.

    The ``__sample#`` suffix is stripped before the lookup. Inputs without a
    matching entry or without a ``standard`` attribute are skipped on
    purpose (e.g. the total, which has no reference ratio).
    """
    r_stds = {}
    for trid in trids_in:
        sign = trid[1]
        ref_sign = sign.split("__sample#")[0]
        try:
            r_stds[sign] = getattr(parameters_out, ref_sign).standard
        except AttributeError:
            pass
    return r_stds


def _sample_groups(trids_in, trids_out):
    """Split trids into ``(inputs, outputs)`` groups, one per sample.

    Without any ``__sample#`` input, a single group holds everything.
    Otherwise group ``k`` holds the trids tagged ``__sample#{k:03d}``
    followed by all untagged trids, in mapper order.
    """
    if not any("__sample#" in trid[1] for trid in trids_in):
        return [(list(trids_in), list(trids_out))]

    in_pert = [trid for trid in trids_in if "__sample#" in trid[1]]
    in_nopert = [trid for trid in trids_in if "__sample#" not in trid[1]]
    out_pert = [trid for trid in trids_out if "__sample#" in trid[1]]
    out_nopert = [trid for trid in trids_out if "__sample#" not in trid[1]]

    nsamples = max(int(trid[1].split("__sample#")[1])
                   for trid in in_pert) + 1

    groups = []
    for isample in range(nsamples):
        sample_id_str = "__sample#{:03d}".format(isample)
        groups.append((
            [trid for trid in in_pert if sample_id_str in trid[1]]
            + in_nopert,
            [trid for trid in out_pert if sample_id_str in trid[1]]
            + out_nopert,
        ))
    return groups


def _init_group_inputs(xmod_in, xmod_out, group_in, trid_out, di, onlyinit):
    """Seed every input of a group with a shallow copy of ``trid_out``'s entry.

    Unless ``onlyinit``, "adj_out" is replaced by a deep copy of the output's
    "adj_out".
    """
    for trid in group_in:
        xmod_in[trid][di] = {
            k: xmod_out[trid_out][di][k] for k in xmod_out[trid_out][di]}
        if not onlyinit:
            xmod_in[trid][di]["adj_out"] = \
                copy.deepcopy(xmod_out[trid_out][di]["adj_out"])


def _adjoint_group(xmod_in, xmod_out, group_in, group_out, di, signatures,
                   spec_data, r_stds, iso_mass, spec_mass, unit):
    """Write the total and signature sensitivities of one group to xmod_in."""
    sign_trids = group_in[:-1]
    total_trid = group_in[-1]
    noutputs = len(group_out)

    # Ratio factor associated with each signature
    a_factors = {trid[1]: (1 + signatures[trid[1]] / 1000) * r_stds[trid[1]]
                 for trid in sign_trids}

    isotopologue_adj = {
        trid[1]: copy.deepcopy(xmod_out[(trid[0], trid[1])][di]["adj_out"])
        for trid in group_out}

    # Mass correction if units are in mass (iso_mass is matched by position)
    if unit == "mass":
        for i, iso in enumerate(isotopologue_adj.keys()):
            isotopologue_adj[iso] *= iso_mass[i] / spec_mass

    def in_odd_block(ioutput, isign):
        # Whether output ``ioutput`` takes the a / (1 + a) branch of sign
        return (ioutput // (noutputs // (2 ** (isign + 1)))) % 2

    # Sensitivity to total
    total_sensitivity = 0 * xmod_in[total_trid][di]["adj_out"]
    for ioutput, out_trid in enumerate(group_out):
        total_multip = isotopologue_adj[out_trid[1]].copy()
        for isign, trid in enumerate(sign_trids):
            a = a_factors[trid[1]]
            if in_odd_block(ioutput, isign):
                total_multip *= a / (1 + a)
            else:
                total_multip *= 1 / (1 + a)
        total_sensitivity += total_multip
    xmod_in[total_trid][di]["adj_out"] = total_sensitivity.copy()

    # Sensitivity to signatures; each signature is stored under its own trid
    for isign, sign_trid in enumerate(sign_trids):
        signature_sensitivity = 0 * xmod_in[sign_trid][di]["adj_out"]
        for ioutput, out_trid in enumerate(group_out):
            total_multip = isotopologue_adj[out_trid[1]].copy()
            for isign_times, trid in enumerate(sign_trids):
                sign_times = trid[1]
                a = a_factors[sign_times]
                odd = in_odd_block(ioutput, isign_times)
                if isign_times == isign:
                    # Derivatives w.r.t. delta of a/(1+a) and 1/(1+a)
                    if odd:
                        total_multip *= (
                            r_stds[sign_times] / 1000 / (1 + a) ** 2)
                    else:
                        total_multip *= (
                            - r_stds[sign_times] / 1000 / (1 + a) ** 2)
                elif odd:
                    total_multip *= a / (1 + a)
                else:
                    total_multip *= 1 / (1 + a)
            signature_sensitivity += total_multip * spec_data
        xmod_in[sign_trid][di]["adj_out"] = signature_sensitivity.copy()