Source code for pycif.plugins.transforms.complex.ratio2conc.forward

import numpy as np
import pandas as pd
import copy
import xarray as xr
import os
from .....utils.path import init_dir


[docs] def forward( transform, inout_datastore, controlvect, obsvect, mapper, di, df, mode, runsubdir, workdir, onlyinit=False, **kwargs ): ddi = min(di, df) xmod_out = inout_datastore["outputs"] xmod_in = inout_datastore["inputs"] trids_in = list(mapper["inputs"].keys()) trids_out = list(mapper["outputs"].keys()) inputs_all = [trid[1] for trid in trids_in] outputs_all = [trid[1] for trid in trids_out] in_types_all = [trid[0] for trid in trids_in] out_types_all = [trid[0] for trid in trids_out] iso_mass = transform.parameters_out.iso_mass spec_mass = transform.parameters_out.spec_mass unit = transform.unit # Find the standards for every kind of ratios r_stds = {} for sign in inputs_all: ref_sign = sign.split("__sample#")[0] \ if "__sample#" in sign else sign try: r_stds[sign] = getattr(transform.parameters_out, ref_sign).standard except AttributeError: pass # Disentangle samples if any if any("__sample#" in inp for inp in inputs_all): trids_in_pert = [trid for trid in trids_in if "__sample#" in trid[1]] trids_out_pert = [trid for trid in trids_out if "__sample#" in trid[1]] trids_in_nopert = [ trid for trid in trids_in if not "__sample#" in trid[1]] trids_out_nopert = [ trid for trid in trids_out if not "__sample#" in trid[1]] nsamples = max([int(trid[1].split("__sample#")[1]) for trid in trids_in_pert]) + 1 list_idxs_input = [] list_idxs_output = [] for isample in range(nsamples): sample_id_str = "__sample#{:03d}".format(isample) list_idxs_input.append([idx for idx, trid in enumerate(trids_in_pert) if sample_id_str in trid[1]]) list_idxs_output.append([idx for idx, trid in enumerate(trids_out_pert) if sample_id_str in trid[1]]) inputs_sorted = [[trids_in_pert[idx][1] for idx in idxs] + [trid[1] for trid in trids_in_nopert] for idxs in list_idxs_input] outputs_sorted = [[trids_out_pert[idx][1] for idx in idxs] + [trid[1] for trid in trids_out_nopert] for idxs in list_idxs_output] in_types_sorted = [[trids_in_pert[idx][0] for idx in idxs] + [trid[0] for trid in trids_in_nopert] for idxs in list_idxs_input] out_types_sorted = [[trids_out_pert[idx][0] for idx in idxs] + [trid[0] for trid in trids_out_nopert] for idxs in list_idxs_output] else: inputs_sorted = [inputs_all] outputs_sorted = [outputs_all] in_types_sorted = [in_types_all] out_types_sorted = [out_types_all] # Run the main code for all groups of samples for inputs, outputs, in_types, out_types in \ zip(inputs_sorted, outputs_sorted, in_types_sorted, out_types_sorted): noutputs = len(outputs) keys = ['spec'] if mode == 'tl': keys.append('incr') trid = (in_types[-1], inputs[-1]) for out_type, output in zip(out_types, outputs): trid_out = (out_type, output) xmod_out[trid_out][di] = { k: xmod_in[trid][di][k] for k in xmod_in[trid][di]} xmod_out[trid_out][di]["spec"] = \ copy.deepcopy(xmod_in[trid][di]["spec"]) if mode == 'tl': xmod_out[trid_out][di]["incr"] = \ copy.deepcopy(xmod_in[trid][di]["incr"]) signatures = {sign: xmod_in[(in_types[0], sign)][di]['spec'] for sign in inputs[:-1]} spec_data = xmod_in[(in_types[-1], inputs[-1])][di]['spec'] # Check that standards are defined missing_rstds = [sign for sign in inputs[:-1] if sign not in r_stds] if missing_rstds != []: raise Exception( "Warning! There are missing standards in your Yaml to compute the isotopic transform.\n" f"Missing species in the `parameters_out` paragraph: {missing_rstds}" ) a_factors = { sign: (1 + signatures[sign] / 1000) * r_stds[sign] for sign in inputs[:-1]} # Save signature and data for later use by adjoint dump_dir = "{}/../chain/ratio2conc/".format( runsubdir, transform.orig_name) if not os.path.isdir(dump_dir): init_dir(dump_dir) file_fwd_dataset = ddi.strftime( "{}/../chain/ratio2conc/{}_fwd_%Y%m%d%H%M.nc" .format(runsubdir, transform.orig_name)) fwd_dataset = spec_data.to_dataset(name="spec") for sign in inputs[:-1]: fwd_dataset[sign] = signatures[sign] fwd_dataset.to_netcdf(file_fwd_dataset) # Compute fwd spec for ioutput, (out_type, output) in enumerate(zip(out_types, outputs)): for isign, sign in enumerate(inputs[:-1]): t = noutputs // (2 ** (isign + 1)) num_group = ioutput // t if num_group % 2: xmod_out[(out_type, output)][di]['spec'] *= \ a_factors[sign] / (1 + a_factors[sign]) else: xmod_out[(out_type, output)][di]['spec'] *= \ 1 / (1 + a_factors[sign]) # Applying tangent linear if mode == 'tl': signatures_tl = { sign: xmod_in[(in_types[0], sign)][di]['incr'] for sign in inputs[:-1]} spec_data_tl = xmod_in[(in_types[-1], inputs[-1])][di]['incr'] for ioutput, (out_type, output) in enumerate(zip(out_types, outputs)): total_sum = spec_data_tl.copy() # spec_data increment for isign, sign in enumerate(inputs[:-1]): t = noutputs // (2 ** (isign + 1)) num_group = ioutput // t if num_group % 2: total_sum *= a_factors[sign] / (1 + a_factors[sign]) else: total_sum *= 1 / (1 + a_factors[sign]) # signatures increment for isign_plus, sign_plus in enumerate(inputs[:-1]): total_multip = signatures_tl[sign_plus].copy() # multiplication terms for each signature increment for isign_times, sign_times in enumerate(inputs[:-1]): t = noutputs // (2 ** (isign_times + 1)) num_group = ioutput // t if isign_times == isign_plus: if num_group % 2: total_multip *= r_stds[sign_times] / 1000\ / (1 + a_factors[sign_times]) ** 2 else: total_multip *= - r_stds[sign_times] / 1000\ / (1 + a_factors[sign_times]) ** 2 else: if num_group % 2: total_multip *= a_factors[sign_times] \ / (1 + a_factors[sign_times]) else: total_multip *= 1 / (1 + a_factors[sign_times]) total_sum += total_multip * spec_data xmod_out[(out_type, output) ][di]['incr'] = copy.deepcopy(total_sum) # Mass correction if units are in mass if unit == 'mass': for i, (out_type, output) in enumerate(zip(out_types, outputs)): xmod_out[(out_type, output)][di]['spec'] *= \ (iso_mass[i] / spec_mass) if mode == 'tl': for i, (out_type, output) in enumerate(zip(out_types, outputs)): xmod_out[(out_type, output)][di]['incr'] *= \ (iso_mass[i] / spec_mass)