Source code for pycif.utils.classes.transforms
import copy
from types import MethodType
import numpy as np
from ...utils.check.errclass import PluginError
from .setup import Setup
from logging import info, debug
[docs]
class Transform(Setup):
    # Can the transform start a pipeline on its own
    start_pipe = False
    # Can the transform end a pipeline on its own
    end_pipe = False
[docs]
    def initiate_template(self):
        """Initialise the plugin template for the ``transform`` plugin type.

        Delegates to the parent ``initiate_template`` with
        ``plg_type="transform"`` and a ``default_functions`` dictionary in
        which every listed function name is mapped to ``True``.
        """
        # NOTE(review): "propagate_incompatible_files" has no method of that
        # exact name in this class (the method defined here is
        # "propagate_incompatible_input_files") -- confirm which spelling the
        # parent class expects before changing anything.
        default_names = (
            "ini_mapper",
            "forward",
            "adjoint",
            "mapper2batch",
            "perturb_transform",
            "flushrun",
            "propagate_incompatible_input_dates",
            "propagate_incompatible_domain",
            "propagate_incompatible_tracer",
            "propagate_incompatible_files",
        )
        parent = super(Transform, self)
        parent.initiate_template(
            plg_type="transform",
            default_functions=dict.fromkeys(default_names, True),
        )
[docs]
    @classmethod
    def register_plugin(cls, name, version, module, subtype="", **kwargs):
        """Register a module as a ``transform`` plugin.

        The plugin type is always forced to ``"transform"``; the call is
        otherwise delegated to the parent ``register_plugin``.

        Args:
            name (str): name of the plugin
            version (str): version of the plugin
            module (types.ModuleType): module defining the interface
                between pyCIF and the plugin
            subtype (str): sub-type of the plugin, forwarded as is
            **kwargs (dictionary): accepted for signature compatibility;
                not forwarded to the parent class by this method
        """
        parent = super(Transform, cls)
        parent.register_plugin(
            name,
            version,
            module,
            plugin_type="transform",
            subtype=subtype,
        )
[docs]
    @classmethod
    def get_transform(cls, plg):
        """Load the registered transform matching a plugin instance.

        Args:
            plg: plugin object; its ``orig_name`` attribute is used as the
                name to look up, and the object itself is passed on as
                ``plg_orig``

        Returns:
            Whatever ``cls.load_registered`` returns for the name
            ``plg.orig_name``, version ``"std"`` and type ``"transform"``.
        """
        plugin_name = plg.orig_name
        plugin_version = "std"
        plugin_type = "transform"
        return cls.load_registered(
            plugin_name, plugin_version, plugin_type, plg_orig=plg
        )
[docs]
    def perturb_transform(self, *args, **kwargs):
        """Default no-op hook called when the mapper is turned into a batch.

        A plugin may override this method to adapt its own behaviour, beyond
        what is stored in the mapper, when several samples are simulated at
        once. For instance, a model needs to update its chemical scheme to
        know which species must be accounted for.

        The default implementation accepts any arguments, ignores them and
        returns ``None``.
        """
        return None
[docs]
    def ini_mapper(self, *args, **kwargs):
        """Placeholder ``ini_mapper``: always fails.

        Raises:
            PluginError: unconditionally, to signal that the default
                implementation was reached.
        """
        message = "This is the default empty ini_mapper method"
        raise PluginError(message)
[docs]
    def flushrun(self, *args, **kwargs):
        """Default ``flushrun`` for transforms: do nothing.

        Any arguments are accepted and ignored; ``None`` is returned.
        """
        return None
[docs]
    def propagate_incompatible_dates(self, *args, **kwargs):
        """Default ``propagate_incompatible_dates`` for transforms: no-op.

        Any arguments are accepted and ignored; ``None`` is returned.
        """
        return None
[docs]
    def propagate_incompatible_input_dates(self, *args, **kwargs):
        """Default ``propagate_incompatible_input_dates`` for transforms: no-op.

        Any arguments are accepted and ignored; ``None`` is returned.
        """
        return None
[docs]
    def propagate_incompatible_domain(self, *args, **kwargs):
        """Default ``propagate_incompatible_domain`` for transforms: no-op.

        Any arguments are accepted and ignored; ``None`` is returned.
        """
        return None
[docs]
    def propagate_incompatible_tracer(self, *args, **kwargs):
        """Default ``propagate_incompatible_tracer`` for transforms: no-op.

        Any arguments are accepted and ignored; ``None`` is returned.
        """
        return None
[docs]
    def propagate_incompatible_input_files(self, *args, **kwargs):
        """Default ``propagate_incompatible_input_files`` for transforms: no-op.

        Any arguments are accepted and ignored; ``None`` is returned.
        """
        return None
[docs]
    def mapper2batch(self, nsamples, transf_mapper,
                     all_mapper, all_transforms, dir_samples, file_samples,
                     **kwargs):
        """Default mapper2batch method.

        Perturb the mapper, in place, to simulate a batch of samples at once.
        Every perturbed id ``(a, b)`` is replaced by ``nsamples`` ids
        ``(a, "b__sample#000")``, ``(a, "b__sample#001")``, ... which all
        reference the very same value object as the id they replace.

        The work is done in four steps:

        1. outputs are sampled when at least one successor needs samples;
        2. the perturbation is propagated from outputs to inputs
           (with a dedicated rule for the ``toobsvect`` plugin);
        3. the ``subsimus`` entries are aligned with the updated mapper;
        4. ``self.perturb_transform`` is called, once per instance.

        Args:
            nsamples (int): number of samples in the batch
            transf_mapper (dict): mapper of this transform. The keys
                ``"inputs"``, ``"outputs"``, ``"successors"``,
                ``"precursors"`` and ``"subsimus"`` are read and updated;
                ``"outputs2inputs"`` is read when present. Ids are indexed
                with ``[0]`` and ``[1]``, the latter being a string.
            all_mapper (dict): mappers of the other transforms, indexed by
                the identifiers listed in ``transf_mapper["successors"]``;
                only their ``"inputs"`` entry is read
            all_transforms: not used by this default implementation
            dir_samples: forwarded untouched to ``perturb_transform``
            file_samples: forwarded untouched to ``perturb_transform``
            **kwargs: ignored

        Returns:
            dict: ``transf_mapper`` itself, after in-place modification.

        Raises:
            Exception: if an output must be perturbed although it has
                neither ``has_control_ancestor`` nor
                ``has_obsvect_successor`` set.
            ValueError: if an output is neither an input nor a key of
                ``outputs2inputs``, so that the perturbation cannot be
                propagated backward.
        """

        def sample_ids(trid):
            # All sampled variants of a given id, in sample order
            return [
                (trid[0], "{}__sample#{:03d}".format(trid[1], i))
                for i in range(nsamples)
            ]

        list_outputs = list(transf_mapper["outputs"].keys())
        list_inputs = list(transf_mapper["inputs"].keys())

        debug(list_inputs)
        debug(list_outputs)
        debug('-----------------------------------------')

        # Propagate perturbations from successors to outputs
        perturbed_output = {}
        for trid in list_outputs:
            # Skip if already perturbed
            if "__sample#" in trid[1]:
                continue

            perturbed_output[trid] = False
            output_info = transf_mapper["outputs"][trid]
            successors = transf_mapper["successors"][trid]
            has_control = output_info.get("has_control_ancestor", False)

            perturbed = []
            for success in successors:
                # If trid directly in successor's inputs, means that
                # no perturbation is necessary
                if trid in all_mapper[success]["inputs"] and not has_control:
                    perturbed.append(False)
                    continue

                # Otherwise, copy perturbed samples
                for trid_sample in sample_ids(trid):
                    transf_mapper["outputs"][trid_sample] = output_info
                    transf_mapper["successors"][trid_sample] = successors
                perturbed.append(True)

            # Nothing to do without successors or if none needs samples
            if not any(perturbed):
                continue

            # From here, we know there is some perturbation
            perturbed_output[trid] = True

            # Raise exception if inconsistent pipeline
            if not (
                has_control
                or output_info.get("has_obsvect_successor", False)
            ):
                raise Exception(
                    "This exception should not happen. There is a coding issue. Contact developpers..."
                )

            # The un-sampled id is now superseded by its samples
            del transf_mapper["outputs"][trid]
            del transf_mapper["successors"][trid]

        # Now propagate perturbations from outputs to inputs
        # Special treatment for toobsvect
        if self.plugin.name == "toobsvect":
            dont_propagate_obsvect = [tuple(dp)
                                      for dp in self.dont_propagate_obsvect]
            for trid in list_inputs:
                # Do not perturb components as a whole
                if trid[1] == "":
                    continue

                if trid in dont_propagate_obsvect \
                        and not transf_mapper["inputs"][trid].get(
                            "has_control_ancestor", False):
                    continue

                # Skip if already perturbed
                if "__sample#" in trid[1]:
                    continue

                for trid_sample in sample_ids(trid):
                    transf_mapper["inputs"][trid_sample] = \
                        transf_mapper["inputs"][trid]
                    transf_mapper["precursors"][trid_sample] = \
                        transf_mapper["precursors"][trid]

                del transf_mapper["inputs"][trid]
                del transf_mapper["precursors"][trid]

        # Transformations without any input have nothing to propagate
        elif transf_mapper["inputs"]:
            outputs2inputs = transf_mapper.get("outputs2inputs", {})
            for trid in list_outputs:
                # Skip if already perturbed
                if "__sample#" in trid[1]:
                    continue

                # The matching input may already have been sampled
                trid_sample = (trid[0], f"{trid[1]}__sample#000")
                trid_tmp = trid_sample if trid_sample in list_inputs else trid
                if (trid_tmp not in list_inputs and trid not in outputs2inputs):
                    plg = self.plugin
                    plugin_str = f"{plg.type} / {plg.name} / {plg.version}"
                    raise ValueError(
                        "Can not propagate information when the inputs "
                        "are not the same as the outputs. "
                        "Please update the 'outputs2inputs' mapper dict key in "
                        f"the 'ini_mapper' method of plugin '{plugin_str}' "
                        f"accordingly for tracer {trid}"
                    )

                if trid_tmp in list_inputs and not perturbed_output[trid]:
                    continue

                # Loop on inputs linked to the given output.
                # Without an explicit 'outputs2inputs' entry, the output is
                # necessarily also an input (checked above): fall back on the
                # input of the same id instead of failing with a KeyError.
                # If that input was already sampled, it is no longer in the
                # mapper inputs and is skipped below.
                list_trid_in = outputs2inputs.get(trid, [trid])
                for trid_in in list_trid_in:
                    # Skip if trid_in already removed from inputs
                    if trid_in not in transf_mapper["inputs"]:
                        continue

                    # Skip if input not related to any control variable
                    if not transf_mapper["inputs"][trid_in].get(
                            "has_control_ancestor", False):
                        continue

                    # Now sample the input
                    for trid_sample in sample_ids(trid_in):
                        transf_mapper["inputs"][trid_sample] = \
                            transf_mapper["inputs"][trid_in]
                        transf_mapper["precursors"][trid_sample] = \
                            transf_mapper["precursors"][trid_in]

                    del transf_mapper["inputs"][trid_in]
                    del transf_mapper["precursors"][trid_in]

        # Update subsimus: any id that is no longer in the mapper has been
        # replaced by its samples and must be replaced here as well
        subsimus = transf_mapper["subsimus"]
        for ddi in subsimus:
            for direction in ("outputs", "inputs"):
                entries = subsimus[ddi][direction]
                # Iterate on a snapshot of the keys as entries is modified
                for trid in list(entries):
                    if trid in transf_mapper[direction]:
                        continue

                    # Skip if already perturbed
                    if "__sample#" in trid[1]:
                        continue

                    for trid_sample in sample_ids(trid):
                        entries[trid_sample] = entries[trid]

                    del entries[trid]

        # Now perturb the behaviour of the transform itself (only once)
        if not getattr(self, "__perturbed_transform__", False):
            self.perturb_transform(
                nsamples, dir_samples, file_samples, transf_mapper
            )

        self.__perturbed_transform__ = True

        debug(list(transf_mapper["inputs"].keys()))
        debug(list(transf_mapper["outputs"].keys()))

        return transf_mapper
[docs]
    def propagate_incompatible(self, trans_mapper, attributes, trid, mode="backward", **kwargs):
        """Dispatch the propagation of incompatible attributes.

        For each name in ``attributes``, the method
        ``propagate_incompatible_<name>`` is looked up on the instance and,
        when it exists, called with ``(trans_mapper, trid, mode)``.
        Names with no matching method are silently skipped.

        Args:
            trans_mapper: mapper forwarded as is to each handler
            attributes: iterable of attribute names to propagate
            trid: id forwarded as is to each handler
            mode (str): forwarded as is to each handler;
                ``"backward"`` by default
            **kwargs: ignored
        """
        handler_names = (
            f"propagate_incompatible_{attribute}" for attribute in attributes
        )
        for handler_name in handler_names:
            if hasattr(self, handler_name):
                handler = getattr(self, handler_name)
                handler(trans_mapper, trid, mode)