# Source code for pycif.plugins.transforms.complex.conc2ratio.adjoint

from __future__ import annotations

from pathlib import Path

import pandas as pd

from .....utils.datastores.dump import read_datastore
from .....utils.datastores.empty import init_empty


def adjoint(
    transform,
    inout_datastore,
    controlvect,
    obsvect,
    mapper,
    di,
    df,
    mode,
    runsubdir,
    workdir,
    onlyinit=False,
    **kwargs,
):
    """Adjoint of the ``conc2ratio`` transform for one simulation period.

    The sensitivities found in the output datastores (column
    ``("maindata", "adj_out")``) are propagated back to every input of the
    transform.  For each group of inputs/outputs (one group per
    ``__sample#NNN`` tag, or a single group when no input is tagged):

    1. the output dataframes for ``di`` are concatenated, each row being
       labelled with the lower-cased name of the output it comes from in
       ``("metadata", "parameter_ref")``;
    2. every input datastore receives its own copy of that dataframe for
       ``di`` and the matching ``parameter`` / ``parameter_ref`` columns are
       stored in ``transform.metadata[min(di, df)]``;
    3. unless ``onlyinit`` is true, the forward data of every input are read
       back from ``<transform.model.adj_refdir>/chain/conc2ratio/`` and, for
       every output but the last one, ``adj_out`` is rescaled in place:

       * inputs listed in ``refs``:
         ``-iso / ref**2 * adj_out * 1000 / standard``
       * inputs listed in ``isotopologues``:
         ``1 / ref * adj_out * 1000 / standard``

       where ``ref`` and ``iso`` are the sums of the forward ``spec`` columns
       over ``refs`` and ``isotopologues`` respectively.
       NOTE(review): these factors look like the derivatives of a forward
       ratio of the form ``1000 * (iso / ref / standard - 1)``; the forward
       code is not visible from here, to be confirmed.

    Args:
        transform: transform plugin.  ``metadata`` is created/updated on it;
            ``model.adj_refdir``, ``orig_name`` and ``parameters_out`` are
            read when the full adjoint is computed.
        inout_datastore: mapping with the keys ``"inputs"`` and ``"outputs"``,
            each a mapping ``(component, tracer) -> {date: dataframe}``.
            The ``"inputs"`` part is modified in place.
        controlvect: unused.
        obsvect: unused.
        mapper: mapping whose ``"inputs"`` and ``"outputs"`` entries are keyed
            by ``(component, tracer)``.
        di: date used as key in the datastores.
        df: other bound of the period; ``min(di, df)`` is used to key
            ``transform.metadata`` and to build the forward file names.
        mode: unused.
        runsubdir: unused.
        workdir: unused.
        onlyinit: if true, stop after initialising the input datastores and
            the metadata (no forward file is read).
        **kwargs: ignored.

    Returns:
        None.  Results are written to ``inout_datastore["inputs"]`` and
        ``transform.metadata``.
    """
    sample_tag = "__sample#"
    ddi = min(di, df)

    trids_in = list(mapper["inputs"].keys())
    trids_out = list(mapper["outputs"].keys())
    datastore_in = inout_datastore["inputs"]
    datastore_out = inout_datastore["outputs"]

    # Disentangle samples if any: build one (inputs, outputs) group per
    # sample; untagged tracers are shared by (appended to) every group.
    if any(sample_tag in trid[1] for trid in trids_in):
        trids_in_pert = [trid for trid in trids_in if sample_tag in trid[1]]
        trids_out_pert = [trid for trid in trids_out if sample_tag in trid[1]]
        trids_in_nopert = [
            trid for trid in trids_in if sample_tag not in trid[1]
        ]
        trids_out_nopert = [
            trid for trid in trids_out if sample_tag not in trid[1]
        ]

        nsamples = 1 + max(
            int(trid[1].split(sample_tag)[1]) for trid in trids_in_pert
        )

        groups = []
        for isample in range(nsamples):
            # NOTE(review): substring match; beyond 999 samples a tag such as
            # "__sample#100" would also match "__sample#1000".
            sample_id_str = f"{sample_tag}{isample:03d}"
            groups.append(
                (
                    [t for t in trids_in_pert if sample_id_str in t[1]]
                    + trids_in_nopert,
                    [t for t in trids_out_pert if sample_id_str in t[1]]
                    + trids_out_nopert,
                )
            )
    else:
        groups = [(trids_in, trids_out)]

    # Initialize metadata if not already done; the entry for the current
    # period is always reset
    transform.metadata = getattr(transform, "metadata", {})
    transform.metadata[ddi] = {}

    # Columns read back from the forward run
    fwd_columns = (
        ("metadata", "parameter"),
        ("metadata", "parameter_ref"),
        ("maindata", "spec"),
        ("maindata", "incr"),
    )

    # Run the main code for all groups of samples
    for group_in, group_out in groups:
        in_types = [trid[0] for trid in group_in]
        inputs = [trid[1] for trid in group_in]
        out_types = [trid[0] for trid in group_out]
        outputs = [trid[1] for trid in group_out]

        # Concatenate output datastores, tagging each row with the
        # (lower-cased) output it belongs to
        to_concat: list[pd.DataFrame] = []
        for comp, trcr in zip(out_types, outputs):
            # Only build an empty datastore when the output is really missing
            if (comp, trcr) in datastore_out:
                tmp = datastore_out[(comp, trcr)][di]
            else:
                tmp = init_empty()
            if isinstance(tmp, dict):
                tmp = init_empty()
            to_concat.append(tmp.assign(parameter_ref=trcr.lower()))

        dfout = pd.concat(to_concat)
        # `assign` adds the new column at the top level of the column index;
        # move it under "metadata"
        dfout.columns = pd.MultiIndex.from_tuples(
            [
                ("metadata", "parameter_ref")
                if col[0] == "parameter_ref"
                else col
                for col in dfout.columns
            ]
        )
        dfout = dfout.reset_index(drop=True)

        # Initializes input dataframe from available total concentrations
        # Implicitly includes the adjoint of the sum of isotopologues
        for comp, trcr in zip(in_types, inputs):
            datastore_in[(comp, trcr)][di] = dfout.copy()
            datastore_in[(comp, trcr)][di][("metadata", "parameter")] = (
                trcr.lower()
            )

        # Save parameter_ref for later use
        metadata_ref = dfout["metadata"][["parameter", "parameter_ref"]]
        for comp, trcr in zip(in_types, inputs):
            param_metadata = metadata_ref.copy()
            param_metadata["parameter"] = trcr.lower()
            transform.metadata[ddi][(comp, trcr)] = param_metadata.astype(
                "category"
            )

        # Stop here if do not need to compute the full adjoint
        if onlyinit:
            continue

        # Retrieve fwd data (i.e. concentrations of individual isotopologues)
        for comp, spec in zip(in_types, inputs):
            file_fwd_data = Path(
                transform.model.adj_refdir,
                "chain",
                "conc2ratio",
                ddi.strftime(
                    f"{transform.orig_name}_fwd_obsvect_{comp}_{spec}_%Y%m%d%H%M.nc"
                ),
            )
            ds = read_datastore(
                str(file_fwd_data),
                reorder=False,
                col2dump=list(fwd_columns),
            )

            df_spec = datastore_in[(comp, spec)][di]
            for col in fwd_columns:
                df_spec.loc[:, col] = ds[col].values

        # Compute
        df_all_isos_fwd = {
            spec: datastore_in[(comp, spec)][di]
            for comp, spec in zip(in_types, inputs)
        }
        df0_fwd = df_all_isos_fwd[inputs[0]]

        # Every output but the last one is looked up in
        # `transform.parameters_out`; the last output needs no rescaling
        for sign_id in outputs[:-1]:
            sign_attr = getattr(transform.parameters_out, sign_id)
            r_std = sign_attr.standard
            refs = sign_attr.refs
            isotopologues = sign_attr.isotopologues
            sign_ref = sign_id.lower()

            # Retrieve specific forward data
            ref_sims_fwd = sum(
                df_all_isos_fwd[ref]["maindata"]["spec"] for ref in refs
            )
            isotopologue_sims_fwd = sum(
                df_all_isos_fwd[iso]["maindata"]["spec"]
                for iso in isotopologues
            )

            mask_sign = df0_fwd["metadata"]["parameter_ref"] == sign_ref
            ref_fwd = ref_sims_fwd.loc[mask_sign]
            isotopologue_fwd = isotopologue_sims_fwd.loc[mask_sign]

            # Compute adj_out (in place, only on the rows of this output)
            for ref in refs:
                comp_ref = in_types[inputs.index(ref)]
                df_ref = datastore_in[(comp_ref, ref)][di]
                mask_sign_ref = df_ref["metadata"]["parameter_ref"] == sign_ref
                incr_ref = df_ref["maindata"].loc[mask_sign_ref, "adj_out"]
                adj_out = (
                    -isotopologue_fwd / ref_fwd**2 * incr_ref * 1000 / r_std
                )
                df_ref.loc[mask_sign_ref, ("maindata", "adj_out")] = (
                    adj_out.astype("float64")
                )

            for iso in isotopologues:
                comp_iso = in_types[inputs.index(iso)]
                df_iso = datastore_in[(comp_iso, iso)][di]
                mask_sign_iso = df_iso["metadata"]["parameter_ref"] == sign_ref
                incr_iso = df_iso["maindata"].loc[mask_sign_iso, "adj_out"]
                adj_out = 1 / ref_fwd * incr_iso * 1000 / r_std
                df_iso.loc[mask_sign_iso, ("maindata", "adj_out")] = (
                    adj_out.astype("float64")
                )