from __future__ import annotations
from pathlib import Path
import pandas as pd
from .....utils.datastores.dump import read_datastore
from .....utils.datastores.empty import init_empty
def _group_by_sample(trids_in, trids_out):
    """Split tracer ids into one group per perturbation sample.

    Tracer ids are ``(component, parameter)`` tuples. If any input parameter
    name contains ``"__sample#"``, one group is built per sample index
    ``0 .. max_index``. Each group holds the ids tagged with that sample,
    followed by all untagged ids. Otherwise a single group with all ids, in
    their original order, is returned.

    Returns:
        A list of ``(inputs, outputs, in_types, out_types)`` tuples.
        ``inputs``/``outputs`` are parameter names and ``in_types``/
        ``out_types`` the matching component names, position by position.
    """
    tag = "__sample#"

    def _split(group_in, group_out):
        # Turn lists of (component, parameter) ids into parallel name lists
        return (
            [trid[1] for trid in group_in],
            [trid[1] for trid in group_out],
            [trid[0] for trid in group_in],
            [trid[0] for trid in group_out],
        )

    if not any(tag in trid[1] for trid in trids_in):
        return [_split(trids_in, trids_out)]

    trids_in_pert = [trid for trid in trids_in if tag in trid[1]]
    trids_out_pert = [trid for trid in trids_out if tag in trid[1]]
    trids_in_nopert = [trid for trid in trids_in if tag not in trid[1]]
    trids_out_nopert = [trid for trid in trids_out if tag not in trid[1]]

    # Everything after the tag must parse as an integer sample index
    nsamples = max(int(trid[1].split(tag)[1]) for trid in trids_in_pert) + 1

    groups = []
    for isample in range(nsamples):
        # Sample ids are matched with 3-digit zero padding, e.g. "__sample#007"
        sample_id_str = f"{tag}{isample:03d}"
        group_in = [trid for trid in trids_in_pert if sample_id_str in trid[1]]
        group_out = [trid for trid in trids_out_pert if sample_id_str in trid[1]]
        groups.append(
            _split(group_in + trids_in_nopert, group_out + trids_out_nopert)
        )
    return groups


def _concat_outputs(datastore_out, out_types, outputs, di):
    """Stack the output frames of one group, tagging each row with its output.

    Each output frame for date ``di`` gets a ``("metadata", "parameter_ref")``
    column holding the lower-cased output name. Outputs missing from
    ``datastore_out``, or stored as a dict, are replaced by ``init_empty()``.
    The stacked frame gets a fresh 0..n-1 index.
    """
    to_concat: list[pd.DataFrame] = []
    for comp, trcr in zip(out_types, outputs):
        # Only build an empty frame when it is actually needed
        if (comp, trcr) in datastore_out:
            tmp = datastore_out[(comp, trcr)][di]
        else:
            tmp = init_empty()
        if isinstance(tmp, dict):
            tmp = init_empty()
        to_concat.append(tmp.assign(parameter_ref=trcr.lower()))
    dfout = pd.concat(to_concat)
    # assign() adds a top-level "parameter_ref" column; move it under "metadata"
    dfout.columns = pd.MultiIndex.from_tuples(
        [
            ("metadata", "parameter_ref") if col[0] == "parameter_ref" else col
            for col in dfout.columns
        ]
    )
    return dfout.reset_index(drop=True)


def _read_fwd_data(transform, datastore_in, in_types, inputs, di, ddi):
    """Copy the forward-run columns dumped by conc2ratio into the input frames.

    The copy is positional (``.values``), so the dumped rows are assumed to be
    in the same order as the input frames (the dump is read with
    ``reorder=False``).
    """
    fwd_columns = [
        ("metadata", "parameter"),
        ("metadata", "parameter_ref"),
        ("maindata", "spec"),
        ("maindata", "incr"),
    ]
    for comp, spec in zip(in_types, inputs):
        file_fwd_data = Path(
            transform.model.adj_refdir,
            "chain",
            "conc2ratio",
            ddi.strftime(
                f"{transform.orig_name}_fwd_obsvect_{comp}_{spec}_%Y%m%d%H%M.nc"
            ),
        )
        ds = read_datastore(
            str(file_fwd_data),
            reorder=False,
            col2dump=list(fwd_columns),
        )
        df_in = datastore_in[(comp, spec)][di]
        for col in fwd_columns:
            df_in.loc[:, col] = ds[col].values


def _adjoint_signatures(transform, datastore_in, in_types, inputs, outputs, di):
    """Rescale the adjoint of every signature row by its local derivatives.

    All outputs but the last are treated as signatures described by
    ``transform.parameters_out.<name>`` (``standard``, ``refs``,
    ``isotopologues``). For rows tagged with signature ``s``:

    * each reference species gets ``-I / R**2 * adj * 1000 / r_std``;
    * each isotopologue gets ``1 / R * adj * 1000 / r_std``.

    Here ``R`` is the summed forward concentration of the refs and ``I`` that
    of the isotopologues. These are the partial derivatives of
    ``1000 * (I / R / r_std - 1)``; the forward code is not shown here, so
    confirm that form. The frames are modified in place.
    """
    df_all_isos_fwd = {
        spec: datastore_in[(comp, spec)][di] for comp, spec in zip(in_types, inputs)
    }
    df0_fwd = df_all_isos_fwd[inputs[0]]

    for sign_id in outputs[:-1]:
        sign_attr = getattr(transform.parameters_out, sign_id)
        r_std = sign_attr.standard
        refs = sign_attr.refs
        isotopologues = sign_attr.isotopologues
        sign_key = sign_id.lower()

        # Forward concentrations summed over refs / isotopologues
        ref_sims_fwd = sum(df_all_isos_fwd[ref]["maindata"]["spec"] for ref in refs)
        isotopologue_sims_fwd = sum(
            df_all_isos_fwd[iso]["maindata"]["spec"] for iso in isotopologues
        )
        mask_sign = df0_fwd["metadata"]["parameter_ref"] == sign_key
        ref_fwd = ref_sims_fwd.loc[mask_sign]
        isotopologue_fwd = isotopologue_sims_fwd.loc[mask_sign]

        # Series arithmetic aligns on the row index; every frame's mask is
        # expected to select the same rows as mask_sign.
        for ref in refs:
            df_ref = datastore_in[(in_types[inputs.index(ref)], ref)][di]
            mask_ref = df_ref["metadata"]["parameter_ref"] == sign_key
            incr_ref = df_ref["maindata"].loc[mask_ref, "adj_out"]
            adj_out = -isotopologue_fwd / ref_fwd**2 * incr_ref * 1000 / r_std
            df_ref.loc[mask_ref, ("maindata", "adj_out")] = adj_out.astype("float64")

        for iso in isotopologues:
            df_iso = datastore_in[(in_types[inputs.index(iso)], iso)][di]
            mask_iso = df_iso["metadata"]["parameter_ref"] == sign_key
            incr_iso = df_iso["maindata"].loc[mask_iso, "adj_out"]
            adj_out = 1 / ref_fwd * incr_iso * 1000 / r_std
            df_iso.loc[mask_iso, ("maindata", "adj_out")] = adj_out.astype("float64")


def adjoint(
    transform,
    inout_datastore,
    controlvect,
    obsvect,
    mapper,
    di,
    df,
    mode,
    runsubdir,
    workdir,
    onlyinit=False,
    **kwargs,
):
    """Adjoint of the ``conc2ratio`` transform.

    Propagates adjoint values from the outputs back to the inputs. The
    outputs are the signatures plus, as the last output, a non-signature
    reference output. The inputs are the individual species or
    isotopologues.

    For each sample group, all output frames are stacked and every input
    receives a copy. Rows of the reference output therefore pass their
    adjoint unchanged to each input (the adjoint of a sum). Unless
    ``onlyinit`` is set, the forward values dumped under
    ``transform.model.adj_refdir`` are then read back. Each signature's rows
    are then rescaled by the local derivatives.

    Args:
        transform: the transform plugin. ``parameters_out``,
            ``model.adj_refdir`` and ``orig_name`` are read; ``metadata`` is
            created if missing and its entry for ``min(di, df)`` is reset.
        inout_datastore: mapping with ``"inputs"`` and ``"outputs"``
            datastores, indexed by ``(component, parameter)`` and then by
            ``di``. Input frames are replaced and modified in place.
        controlvect, obsvect, mode, runsubdir, workdir, **kwargs: not used
            in this function.
        mapper: mapping whose ``"inputs"``/``"outputs"`` keys are the
            ``(component, parameter)`` tracer ids.
        di, df: period bounds (presumably start/end dates). The datastores
            are indexed by ``di``; metadata and forward-file names use
            ``min(di, df)``, which must support ``strftime``.
        onlyinit: if True, stop after initializing the input frames and
            ``transform.metadata``.

    Returns:
        None; results are written into ``inout_datastore["inputs"]``.
    """
    ddi = min(di, df)

    groups = _group_by_sample(
        list(mapper["inputs"].keys()), list(mapper["outputs"].keys())
    )
    datastore_in = inout_datastore["inputs"]
    datastore_out = inout_datastore["outputs"]

    # Initialize metadata if not already done; this period's entry is reset
    transform.metadata = getattr(transform, "metadata", {})
    transform.metadata[ddi] = {}

    for inputs, outputs, in_types, out_types in groups:
        dfout = _concat_outputs(datastore_out, out_types, outputs, di)

        # Every input starts from a copy of all output adjoints.
        # NOTE(review): rows of a signature in which an input takes no part
        # keep the copied adjoint unchanged; confirm later steps discard them.
        for comp, trcr in zip(in_types, inputs):
            datastore_in[(comp, trcr)][di] = dfout.copy()
            datastore_in[(comp, trcr)][di][("metadata", "parameter")] = trcr.lower()

            # Save parameter_ref for later use
            param_metadata = dfout["metadata"][["parameter", "parameter_ref"]].copy()
            param_metadata["parameter"] = trcr.lower()
            transform.metadata[ddi][(comp, trcr)] = param_metadata.astype("category")

        if onlyinit:
            continue

        _read_fwd_data(transform, datastore_in, in_types, inputs, di, ddi)
        _adjoint_signatures(transform, datastore_in, in_types, inputs, outputs, di)