Source code for pycif.utils.classes.transforms

import copy
import pandas as pd
from types import MethodType
import numpy as np
from ...utils.check.errclass import PluginError
from .setup import Setup
from logging import info, debug


class Transform(Setup):
    """Plugin type for data transformations in the inversion pipeline.

    Transforms operate on tracers and fluxes, participate in the forward and
    adjoint mapper chain, and can propagate incompatible metadata (dates,
    domains, tracers) across the pipeline. Ensemble/MC batch perturbations
    are handled via :meth:`mapper2batch`.

    Concrete implementations live in ``pycif/plugins/transforms/``.
    """

    # Can the transform start a pipeline on its own
    start_pipe = False

    # Can the transform end a pipeline on its own
    end_pipe = False

    def initiate_template(self):
        """Initialise the Transform plugin template.

        Delegates to the parent ``initiate_template`` with
        ``plg_type="transform"`` and flags the standard transform functions
        (``ini_mapper``, ``forward``, ``adjoint``, ``mapper2batch``,
        ``perturb_transform``, ``flushrun``,
        ``propagate_incompatible_input_dates``,
        ``propagate_incompatible_domain``, ``propagate_incompatible_tracer``,
        ``propagate_incompatible_files``) in ``default_functions``. How those
        flags are resolved into callables is decided by the parent class.
        """
        # NOTE(review): this class defines ``propagate_incompatible_input_files``
        # (and ``propagate_incompatible_dates``), but the flag below is named
        # ``propagate_incompatible_files``; confirm how the parent class
        # resolves these names.
        super(Transform, self).initiate_template(
            plg_type="transform",
            default_functions={
                "ini_mapper": True,
                "forward": True,
                "adjoint": True,
                "mapper2batch": True,
                "perturb_transform": True,
                "flushrun": True,
                "propagate_incompatible_input_dates": True,
                "propagate_incompatible_domain": True,
                "propagate_incompatible_tracer": True,
                "propagate_incompatible_files": True,
            }
        )

    @classmethod
    def register_plugin(cls, name, version, module, subtype="", **kwargs):
        """Register a module as a ``"transform"`` plugin for a name/version.

        Args:
            name (str): name of the plugin
            version (str): version of the plugin
            module (types.ModuleType): module defining the interface
                between pyCIF and the plugin
            subtype (str): plugin subtype, forwarded to the parent class
            **kwargs (dictionary): accepted for signature compatibility;
                currently NOT forwarded to the parent registration
        """
        # NOTE(review): **kwargs ("default options for module") are dropped
        # here; confirm whether the parent register_plugin should get them.
        super(Transform, cls).register_plugin(
            name, version, module, plugin_type="transform", subtype=subtype
        )

    @classmethod
    def get_transform(cls, plg):
        """Load the Transform plugin registered under the given plugin's name.

        Args:
            plg (Plugin): Plugin whose ``orig_name`` identifies the transform
                to load; it is also passed on as ``plg_orig``.

        Returns:
            Transform: Loaded transform plugin (version ``"std"``).
        """
        return cls.load_registered(
            plg.orig_name, "std", "transform", plg_orig=plg
        )

    def perturb_transform(self, *args, **kwargs):
        """Default perturb_transform method: does nothing.

        If specified in the corresponding plugin, the function should perturb
        its own behaviour beyond the mapper if necessary. For instance, for
        models, it is necessary to update the chemical scheme to know which
        species should be accounted for.
        """
        return

    def ini_mapper(self, *args, **kwargs):
        """Default ini_mapper method.

        There is no usable default: concrete transforms must provide their
        own ``ini_mapper``.

        Raises:
            PluginError: always.
        """
        raise PluginError("This is the default empty ini_mapper method")

    def flushrun(self, *args, **kwargs):
        """Default empty flushrun method for transforms"""
        return

    def mapper2batch(self, nsamples, transf_mapper, all_mapper,
                     all_transforms, dir_samples, file_samples, **kwargs):
        """Expand a transform's mapper to handle a batch of ensemble/MC samples.

        For each output tracer ID that feeds a perturbed successor,
        ``nsamples`` copies are created (suffixed ``__sample#NNN``). The
        corresponding input tracer IDs are duplicated in the same way so that
        the full pipeline can process all samples in a single forward/adjoint
        pass.

        Special handling is applied for the ``toobsvect`` plugin (observation
        vector aggregation) and for transforms whose inputs differ from their
        outputs (controlled via the ``outputs2inputs`` mapper sub-dict).

        After expanding inputs and outputs, ``outputs2inputs`` and
        ``inputs2outputs`` are updated, ``subsimus`` entries are duplicated,
        and :meth:`perturb_transform` is called once per instance (guarded by
        the ``__perturbed_transform__`` attribute).

        Args:
            nsamples (int): Number of ensemble / Monte Carlo samples.
            transf_mapper (dict): Mapper dict for this transform, containing
                ``"inputs"``, ``"outputs"``, ``"precursors"``,
                ``"successors"``, ``"outputs2inputs"``, and ``"subsimus"``
                sub-dicts. Modified in-place.
            all_mapper (dict): Full pipeline mapper keyed by transform index,
                used to check whether a tracer ID is directly consumed by a
                successor without perturbation.
            all_transforms (list): Ordered list of all transform instances in
                the pipeline (unused by this default implementation).
            dir_samples (str): Passed to :meth:`perturb_transform`.
            file_samples (str): Passed to :meth:`perturb_transform`.
            **kwargs: Accepted but unused; :meth:`perturb_transform` is called
                as ``(nsamples, dir_samples, file_samples, transf_mapper)``.

        Returns:
            dict: The updated ``transf_mapper`` with all sample expansions
            applied.

        Raises:
            Exception: if an output needs perturbation but has neither a
                control ancestor nor an obsvect successor.
            ValueError: if an output cannot be linked to inputs because
                ``outputs2inputs`` is missing an entry for it.
        """
        list_outputs = list(transf_mapper["outputs"].keys())
        list_inputs = list(transf_mapper["inputs"].keys())
        debug(list(transf_mapper["inputs"].keys()))
        debug(list(transf_mapper["outputs"].keys()))
        debug('-----------------------------------------')

        # Propagate perturbations from successors to outputs
        perturbed_output = {}
        for trid in list_outputs:
            # Skip if already perturbed
            if "__sample#" in trid[1]:
                continue

            perturbed_output[trid] = False
            perturbed = []
            successors = transf_mapper["successors"][trid]
            for success in successors:
                # If trid directly in successor's inputs, means that
                # no perturbation is necessary
                if trid in all_mapper[success]["inputs"] \
                        and not transf_mapper["outputs"][trid].get(
                            "has_control_ancestor", False):
                    perturbed.append(False)
                    continue

                # Otherwise, copy perturbed samples
                # (sample entries share the reference entry's dict objects)
                for i in range(nsamples):
                    trid_sample = (trid[0], f"{trid[1]}__sample#{i:03d}")
                    transf_mapper["outputs"][trid_sample] = \
                        transf_mapper["outputs"][trid]
                    transf_mapper["successors"][trid_sample] = \
                        transf_mapper["successors"][trid]
                perturbed.append(True)

            # If perturbed is empty, just continue
            if perturbed == []:
                continue

            # If all not perturbed, continue
            perturbed = np.array(perturbed)
            if np.all(~perturbed):
                continue

            # From here, we know there is some perturbation
            perturbed_output[trid] = True

            # Raise exception if inconsistent pipeline
            if not (
                transf_mapper["outputs"][trid].get(
                    "has_control_ancestor", False)
                or transf_mapper["outputs"][trid].get(
                    "has_obsvect_successor", False)
            ):
                raise Exception(
                    "This exception should not happen. There is a coding issue. Contact developpers..."
                )

            # At least one successor is perturbed: drop the unperturbed
            # reference output, only the sample copies remain.
            # NOTE(review): in the mixed case (some successors unperturbed)
            # the reference is dropped too -- confirm this is intended.
            del transf_mapper["outputs"][trid]
            del transf_mapper["successors"][trid]

        # Now propagate perturbations from outputs to inputs
        # Special treatment for toobsvect
        if self.plugin.name == "toobsvect":
            dont_propagate_obsvect = [
                tuple(dp) for dp in self.dont_propagate_obsvect]
            for trid in list_inputs:
                # Do not perturb components as a whole
                if trid[1] == "":
                    continue

                if trid in dont_propagate_obsvect \
                        and not transf_mapper["inputs"][trid].get(
                            "has_control_ancestor", False):
                    continue

                # Skip if already perturbed
                if "__sample#" in trid[1]:
                    continue

                for i in range(nsamples):
                    trid_sample = (trid[0], f"{trid[1]}__sample#{i:03d}")
                    transf_mapper["inputs"][trid_sample] = \
                        transf_mapper["inputs"][trid]
                    transf_mapper["precursors"][trid_sample] = \
                        transf_mapper["precursors"][trid]

                del transf_mapper["inputs"][trid]
                del transf_mapper["precursors"][trid]

        elif len(transf_mapper["inputs"].keys()) == 0:
            pass

        else:
            outputs2inputs = transf_mapper.get("outputs2inputs", {})
            for trid in list_outputs:
                # Skip if already perturbed
                # (output already is a sample from a previous expansion)
                if "__sample#" in trid[1]:
                    continue

                trid_sample = (trid[0], f"{trid[1]}__sample#000")
                trid_tmp = trid_sample if trid_sample in list_inputs else trid
                if (trid_tmp not in list_inputs
                        and trid not in outputs2inputs):
                    plg = self.plugin
                    plugin_str = f"{plg.type} / {plg.name} / {plg.version}"
                    raise ValueError(
                        "Can not propagate information when the inputs "
                        "are not the same as the outputs. \n"
                        "Please update the 'outputs2inputs' mapper dict key in "
                        f"the 'ini_mapper' method of plugin '{plugin_str}' "
                        f"accordingly for tracer {trid}"
                    )

                if trid_tmp in list_inputs and not perturbed_output[trid]:
                    continue

                # Loop on inputs linked to the given output
                list_trid_in = outputs2inputs[trid]
                for trid_in in list_trid_in:
                    # Skip if trid_in is already removed from inputs
                    if trid_in not in transf_mapper["inputs"]:
                        continue

                    # Skip if input not related to any control variable
                    if not transf_mapper["inputs"][trid_in].get(
                            "has_control_ancestor", False):
                        continue

                    # Now sample the input
                    for i in range(nsamples):
                        trid_sample = (
                            trid_in[0], f"{trid_in[1]}__sample#{i:03d}"
                        )
                        transf_mapper["inputs"][trid_sample] = \
                            transf_mapper["inputs"][trid_in]
                        transf_mapper["precursors"][trid_sample] = \
                            transf_mapper["precursors"][trid_in]

                    del transf_mapper["inputs"][trid_in]
                    del transf_mapper["precursors"][trid_in]

        # Update outputs2inputs: every removed reference output is replaced
        # by one entry per sample, pointing at the matching sampled inputs
        outputs2inputs = transf_mapper.get("outputs2inputs", {})
        ref_keys = copy.deepcopy(list(outputs2inputs.keys()))
        for trid in ref_keys:
            if trid in transf_mapper['outputs']:
                continue

            # Duplicate for each member
            for trid_out in transf_mapper['outputs']:
                if trid_out[0] != trid[0]:
                    continue

                if trid_out[1].split("__sample#")[0] != trid[1]:
                    continue

                outputs2inputs[trid_out] = copy.deepcopy(outputs2inputs[trid])

                # Update reference list with members
                nsample = int(trid_out[1].split("__sample#")[1])
                for k, trid_in in enumerate(outputs2inputs[trid_out]):
                    if trid_in in transf_mapper['inputs']:
                        continue

                    trid_in_sample = (
                        trid_in[0], f"{trid_in[1]}__sample#{nsample:03d}"
                    )
                    if trid_in_sample not in transf_mapper['inputs']:
                        raise Exception(
                            f"Could not link {trid_out} to any input corresponding to {trid_in}"
                        )

                    outputs2inputs[trid_out][k] = trid_in_sample

            # Remove reference list
            del outputs2inputs[trid]

        # Update inputs2outputs
        transf_mapper['inputs2outputs'] = self.generate_inputs2outputs(
            transf_mapper)

        # Update subsimus
        subsimus = transf_mapper["subsimus"]
        for ddi in subsimus:
            outputs = subsimus[ddi]["outputs"]
            list_out = copy.deepcopy(list(outputs.keys()))
            for trid in list_out:
                if trid in transf_mapper["outputs"]:
                    continue

                # Skip if already perturbed
                if "__sample#" in trid[1]:
                    continue

                for i in range(nsamples):
                    trid_sample = (
                        trid[0], f"{trid[1]}__sample#{i:03d}"
                    )
                    outputs[trid_sample] = outputs[trid]

                del outputs[trid]

            inputs = subsimus[ddi]["inputs"]
            list_in = copy.deepcopy(list(inputs.keys()))
            for trid in list_in:
                if trid in transf_mapper["inputs"]:
                    continue

                # Skip if already perturbed
                if "__sample#" in trid[1]:
                    continue

                for i in range(nsamples):
                    trid_sample = (
                        trid[0], f"{trid[1]}__sample#{i:03d}"
                    )
                    inputs[trid_sample] = inputs[trid]

                del inputs[trid]

        # Now perturb the behaviour of the transform itself (once per
        # instance)
        if not getattr(self, "__perturbed_transform__", False):
            self.perturb_transform(
                nsamples, dir_samples, file_samples, transf_mapper
            )
            self.__perturbed_transform__ = True

        debug(list(transf_mapper["inputs"].keys()))
        debug(list(transf_mapper["outputs"].keys()))

        return transf_mapper

    def propagate_incompatible_dates(self, *args, **kwargs):
        """Default empty propagate_incompatible_dates method for transforms"""
        return

    def propagate_incompatible_input_dates(self, *args, **kwargs):
        """Default empty propagate_incompatible_input_dates method for transforms"""
        return

    def propagate_incompatible_domain(self, *args, **kwargs):
        """Default empty propagate_incompatible_domain method for transforms"""
        return

    def propagate_incompatible_tracer(self, *args, **kwargs):
        """Default empty propagate_incompatible_tracer method for transforms"""
        return

    def propagate_incompatible_input_files(self, *args, **kwargs):
        """Default empty propagate_incompatible_input_files method for transforms"""
        return

    def propagate_incompatible_all_successors_initialized(
            self, transf_mapper, trid, mode, **kwargs):
        """Mark input ``trid`` as initialised when it feeds no known output.

        The outputs linked to ``trid`` are read from
        ``transf_mapper["inputs2outputs"][trid]``. Only those still present
        in ``transf_mapper["outputs"]`` are considered. If there are none,
        ``transf_mapper["inputs"][trid]["all_successors_initialized"]`` is
        set to True. Otherwise the mapper is left unchanged.

        Args:
            transf_mapper (dict): Mapper dict for this transform (modified
                in place).
            trid (tuple): Input tracer identifier.
            mode (str): Propagation direction (unused here).
            **kwargs: Unused.

        Raises:
            KeyError: if ``trid`` has no ``inputs2outputs`` entry.
        """
        inputs2outputs = transf_mapper.get("inputs2outputs", {})[trid]
        output_initialized = [
            transf_mapper["outputs"][trid_out].get(
                "all_successors_initialized", False)
            for trid_out in inputs2outputs
            if trid_out in transf_mapper["outputs"]
        ]

        # NOTE(review): the per-output flags are only tested for emptiness;
        # if the intent is "all linked outputs initialised", this should be
        # ``np.all(output_initialized)`` -- confirm with the pipeline owner.
        if output_initialized == []:
            transf_mapper["inputs"][trid]["all_successors_initialized"] = True

    def propagate_incompatible(
        self, transf_mapper, attributes, trid, anyNone, mode="backward",
        **kwargs
    ):
        """Propagate incompatible metadata attributes for input ``trid``.

        For each attribute in ``attributes``, calls the corresponding
        ``propagate_incompatible_<attribute>`` method on this instance (e.g.
        ``propagate_incompatible_domain``) as
        ``method(transf_mapper, trid, mode)``. Attributes with no matching
        method are silently skipped.

        Propagation is skipped entirely when ``anyNone`` is True and at least
        one linked output (from ``inputs2outputs``) present in
        ``transf_mapper["outputs"]`` lacks ``all_successors_initialized``.

        Args:
            transf_mapper (dict): Mapper dict for this transform (modified
                in-place by the per-attribute delegate methods).
            attributes (list[str]): Names of the metadata attributes to
                propagate (e.g. ``["domain", "tracer", "input_dates"]``).
            trid (tuple): Tracer identifier ``(component, name)`` for the
                input whose metadata is being propagated.
            anyNone (bool): If True, propagation is deferred until all linked
                outputs report ``all_successors_initialized=True``.
            mode (str): Propagation direction, passed to the delegates
                (default ``"backward"``).
            **kwargs: Accepted but NOT forwarded to the delegates.

        Raises:
            KeyError: if ``trid`` has no ``inputs2outputs`` entry.
        """
        inputs2outputs = transf_mapper.get("inputs2outputs", {})[trid]
        output_initialized = [
            transf_mapper["outputs"][trid_out].get(
                "all_successors_initialized", False)
            for trid_out in inputs2outputs
            if trid_out in transf_mapper["outputs"]
        ]
        if anyNone and not np.all(output_initialized):
            return

        for attribute in attributes:
            method_name = f"propagate_incompatible_{attribute}"
            if not hasattr(self, method_name):
                continue

            getattr(self, method_name)(transf_mapper, trid, mode)

    @classmethod
    def generate_inputs2outputs(cls, transf_mapper):
        """Build the inverse of the outputs->inputs mapping.

        Inverts ``transf_mapper["outputs2inputs"]`` so that each input tracer
        ID is mapped to the list of output tracer IDs that depend on it. A
        missing ``"outputs2inputs"`` key is treated as an empty mapping (as
        :meth:`mapper2batch` does), so every input then maps to ``[]``.

        Args:
            transf_mapper (dict): Mapper dict containing at least an
                ``"inputs"`` key and, optionally, ``"outputs2inputs"``.

        Returns:
            dict: Mapping ``{input_trid: [output_trid, ...]}`` for every input
            tracer ID currently in ``transf_mapper["inputs"]``. Outputs are
            listed in ``outputs2inputs`` iteration order.
        """
        outputs2inputs = transf_mapper.get("outputs2inputs", {})
        return {
            trin: [
                trout for trout in outputs2inputs
                if trin in outputs2inputs[trout]
            ]
            for trin in transf_mapper["inputs"]
        }

    def clean_input_dates(self, mapper):
        """Normalise the ``input_dates`` entries of a mapper in place.

        For every tracer in ``mapper["inputs"]`` and ``mapper["outputs"]``
        that has an ``"input_dates"`` entry, each key is converted to a
        ``datetime.datetime``. Each value, a sequence of (start, end) pairs,
        is converted to a DataFrame with ``start_date``/``end_date`` columns
        parsed by ``pd.to_datetime``.

        Args:
            mapper (dict): dictionary of the mapper.

        Returns:
            dict: the same ``mapper`` object, with ``input_dates`` cleaned.

        Raises:
            Exception: if an ``input_dates`` entry is not a dictionary.
        """
        # Cleaning inputs / outputs
        for inout in ["inputs", "outputs"]:
            for trid in mapper[inout]:
                if "input_dates" not in mapper[inout][trid]:
                    continue

                input_dates = mapper[inout][trid]["input_dates"]
                if not isinstance(input_dates, dict):
                    raise Exception(
                        f"input_dates for tracer {trid} should be a dictionary with keys corresponding to the different input periods"
                    )

                input_dates = {
                    pd.DatetimeIndex([ddi]).to_pydatetime()[0]:
                        pd.DataFrame(
                            np.array(input_dates[ddi]),
                            columns=["start_date", "end_date"]
                        ).apply(pd.to_datetime)
                    for ddi in input_dates
                }
                mapper[inout][trid]["input_dates"] = input_dates

        return mapper