Source code for pycif.plugins.transforms.complex.ratio2conc.adjoint

import copy
import xarray as xr
import numpy as np


_SAMPLE_TAG = "__sample#"


def _group_by_sample(trids, nsamples):
    """Return one list of tracer ids per sample index.

    Each group holds the ids whose name contains ``__sample#NNN`` for that
    index (in their original order), followed by every id that carries no
    sample tag at all, so untagged ids are shared by all groups.
    """
    tagged = [trid for trid in trids if _SAMPLE_TAG in trid[1]]
    shared = [trid for trid in trids if _SAMPLE_TAG not in trid[1]]

    groups = []
    for isample in range(nsamples):
        sample_id_str = _SAMPLE_TAG + "{:03d}".format(isample)
        groups.append(
            [trid for trid in tagged if sample_id_str in trid[1]] + shared
        )
    return groups


def _is_heavy(ioutput, isign, noutputs):
    """Tell whether output ``ioutput`` uses the ``a / (1 + a)`` branch.

    For signature number ``isign`` the outputs are split in consecutive
    blocks of ``noutputs // 2 ** (isign + 1)`` elements; odd-numbered blocks
    take the ``a / (1 + a)`` factor, even-numbered ones ``1 / (1 + a)``.
    """
    block_size = noutputs // (2 ** (isign + 1))
    return bool((ioutput // block_size) % 2)


def adjoint(
    transform,
    inout_datastore,
    controlvect,
    obsvect,
    mapper,
    di,
    df,
    mode,
    runsubdir,
    workdir,
    onlyinit=False,
    **kwargs
):
    """Adjoint of the ``ratio2conc`` transform.

    Propagates the adjoint sensitivities stored under ``"adj_out"`` for the
    outputs (isotopologues) back to the inputs: the signatures
    (``inputs[:-1]``) and the total (``inputs[-1]``). With
    ``a = (1 + signature / 1000) * standard``, every output adjoint is
    multiplied, per signature, by ``a / (1 + a)`` or ``1 / (1 + a)``; the
    sensitivity to a signature uses the derivative of that factor,
    ``+/- standard / 1000 / (1 + a) ** 2``, times the forward ``spec`` field.

    Args:
        transform: transform object; ``parameters_out`` (``iso_mass``,
            ``spec_mass`` and one attribute with a ``standard`` per
            signature), ``unit``, ``model.adj_refdir`` and ``orig_name``
            are read.
        inout_datastore: mapping with ``"inputs"`` and ``"outputs"`` entries,
            each indexed by tracer id then by date.
        controlvect: unused in this function.
        obsvect: unused in this function.
        mapper: mapping whose ``"inputs"`` and ``"outputs"`` keys are the
            ``(type, name)`` tracer ids to process.
        di: date used to index the datastore.
        df: second date; ``min(di, df)`` names the forward file to read.
        mode: unused in this function.
        runsubdir: unused in this function.
        workdir: unused in this function.
        onlyinit: if True, only initialise the input entries from the
            outputs; no forward file is read and nothing is computed.
        **kwargs: ignored.

    Returns:
        None. ``inout_datastore["inputs"]`` is updated in place.
    """
    ddi = min(di, df)

    xmod_out = inout_datastore["outputs"]
    xmod_in = inout_datastore["inputs"]

    trids_in = list(mapper["inputs"].keys())
    trids_out = list(mapper["outputs"].keys())

    inputs_all = [trid[1] for trid in trids_in]

    iso_mass = transform.parameters_out.iso_mass
    spec_mass = transform.parameters_out.spec_mass
    unit = transform.unit

    # Find the standards for every kind of ratios
    r_stds = {}
    for sign in inputs_all:
        # Sampled inputs share the standard of their un-tagged reference
        ref_sign = sign.split(_SAMPLE_TAG)[0]
        try:
            r_stds[sign] = getattr(transform.parameters_out, ref_sign).standard
        except AttributeError:
            # Inputs without a declared standard are skipped on purpose:
            # only inputs[:-1] are looked up in r_stds further down.
            pass

    # Disentangle samples if any
    if any(_SAMPLE_TAG in inp for inp in inputs_all):
        nsamples = max(
            int(trid[1].split(_SAMPLE_TAG)[1])
            for trid in trids_in
            if _SAMPLE_TAG in trid[1]
        ) + 1
        groups_in = _group_by_sample(trids_in, nsamples)
        groups_out = _group_by_sample(trids_out, nsamples)
    else:
        groups_in = [trids_in]
        groups_out = [trids_out]

    # Inputs that already received an entry for `di`
    initialised = set()

    # Run the main code for all group of samples
    for group_in, group_out in zip(groups_in, groups_out):
        in_types = [trid[0] for trid in group_in]
        inputs = [trid[1] for trid in group_in]
        out_types = [trid[0] for trid in group_out]
        outputs = [trid[1] for trid in group_out]

        noutputs = len(outputs)
        trid_out = (out_types[0], outputs[0])
        group_trids = set(group_in)

        for trid in trids_in:
            # Do not re-initialise inputs owned by a previously processed
            # sample group: that would overwrite the sensitivities already
            # computed for them with a copy of this group's output adjoint.
            if trid in initialised and trid not in group_trids:
                continue

            xmod_in[trid][di] = {
                k: xmod_out[trid_out][di][k] for k in xmod_out[trid_out][di]
            }
            if not onlyinit:
                xmod_in[trid][di]["adj_out"] = copy.deepcopy(
                    xmod_out[trid_out][di]["adj_out"]
                )
            initialised.add(trid)

        if onlyinit:
            continue

        signs = inputs[:-1]
        total_trid = (in_types[-1], inputs[-1])

        # Fetch fwd data; `.values` are read inside the `with` block so
        # the file can be closed right after.
        file_fwd_dataset = ddi.strftime(
            "{}/chain/ratio2conc/{}_fwd_%Y%m%d%H%M.nc".format(
                transform.model.adj_refdir, transform.orig_name
            )
        )
        with xr.open_dataset(file_fwd_dataset) as fwd_dataset:
            spec_data = fwd_dataset["spec"].values
            signatures = {sign: fwd_dataset[sign].values for sign in signs}

        # Per-signature factors, computed once instead of in every loop
        light_frac = {}
        heavy_frac = {}
        dfrac = {}
        for sign in signs:
            a_factor = (1 + signatures[sign] / 1000) * r_stds[sign]
            light_frac[sign] = 1 / (1 + a_factor)
            heavy_frac[sign] = a_factor / (1 + a_factor)
            # d/d(signature) of a / (1 + a); the 1 / (1 + a) branch has
            # the opposite sign
            dfrac[sign] = r_stds[sign] / 1000 / (1 + a_factor) ** 2

        isotopologue_adj = {
            output: copy.deepcopy(xmod_out[(out_type, output)][di]["adj_out"])
            for out_type, output in zip(out_types, outputs)
        }

        # Mass correction if units are in mass
        if unit == "mass":
            for i, iso in enumerate(isotopologue_adj):
                isotopologue_adj[iso] *= iso_mass[i] / spec_mass

        # Sensitivity to total
        # NOTE(review): when the total carries no sample tag it is shared by
        # all groups and is overwritten (not accumulated) by each of them,
        # as in the original code -- confirm this is intended.
        total_sensitivity = 0 * xmod_in[total_trid][di]["adj_out"]
        for ioutput, output in enumerate(outputs):
            total_multip = isotopologue_adj[output].copy()

            # multiplication terms for each signature
            for isign, sign in enumerate(signs):
                if _is_heavy(ioutput, isign, noutputs):
                    total_multip *= heavy_frac[sign]
                else:
                    total_multip *= light_frac[sign]

            total_sensitivity += total_multip

        xmod_in[total_trid][di]["adj_out"] = total_sensitivity.copy()

        # Sensitivity to signatures
        for isign, sign in enumerate(signs):
            sign_trid = (in_types[0], sign)
            signature_sensitivity = 0 * xmod_in[sign_trid][di]["adj_out"]

            for ioutput, output in enumerate(outputs):
                total_multip = isotopologue_adj[output].copy()

                # multiplication terms for each signature increment
                for isign_times, sign_times in enumerate(signs):
                    heavy = _is_heavy(ioutput, isign_times, noutputs)
                    if isign_times == isign:
                        # Differentiated factor for the current signature
                        if heavy:
                            total_multip *= dfrac[sign_times]
                        else:
                            total_multip *= -dfrac[sign_times]
                    elif heavy:
                        total_multip *= heavy_frac[sign_times]
                    else:
                        total_multip *= light_frac[sign_times]

                signature_sensitivity += total_multip * spec_data

            xmod_in[sign_trid][di]["adj_out"] = signature_sensitivity.copy()