# Source code for pycif.utils.classes.transforms

import copy

from types import MethodType
import numpy as np
from ...utils.check.errclass import PluginError
from .setup import Setup


class Transform(Setup):
    """Plugin class for the ``transform`` plugin type.

    Provides default implementations of the functions a transform plugin
    is expected to expose (``ini_mapper``, ``mapper2batch``,
    ``perturb_transform``, ``flushrun``) together with helpers to register
    and load transform plugins.
    """

    # Can the transform start a pipeline on its own
    start_pipe = False

    # Can the transform end a pipeline on its own
    end_pipe = False

    def initiate_template(self):
        """Initialise the plugin template for the ``transform`` type.

        Delegates to the parent ``initiate_template`` with
        ``plg_type="transform"`` and the list of default functions
        ``ini_mapper``, ``forward``, ``adjoint``, ``mapper2batch``,
        ``perturb_transform`` and ``flushrun``, all flagged ``True``.
        NOTE(review): the exact meaning of the ``True`` flag is defined by
        the parent class, not visible here.
        """
        super().initiate_template(
            plg_type="transform",
            default_functions={
                "ini_mapper": True,
                "forward": True,
                "adjoint": True,
                "mapper2batch": True,
                "perturb_transform": True,
                "flushrun": True,
            },
        )

    @classmethod
    def register_plugin(cls, name, version, module, subtype="", **kwargs):
        """Register a module as a ``transform`` plugin.

        Args:
            name (str): name of the plugin
            version (str): version of the plugin
            module (types.ModuleType): module defining the interface
                between pyCIF and the plugin
            subtype (str): sub-type of the plugin, forwarded unchanged to
                the parent ``register_plugin``
            **kwargs (dictionary): accepted for call compatibility.
                NOTE(review): these options are currently *not* forwarded
                to the parent ``register_plugin``; confirm whether that
                is intended.

        """
        super().register_plugin(
            name, version, module, plugin_type="transform", subtype=subtype
        )

    @classmethod
    def get_transform(cls, plg):
        """Load the registered transform corresponding to a plugin.

        Args:
            plg: plugin whose ``orig_name`` attribute identifies the
                registered transform; it is also handed over as
                ``plg_orig``

        Returns:
            The object returned by ``cls.load_registered`` for the name
            ``plg.orig_name``, version ``"std"`` and plugin type
            ``"transform"``

        """
        return cls.load_registered(
            plg.orig_name, "std", "transform", plg_orig=plg
        )

    def perturb_transform(self, *args, **kwargs):
        """Default empty perturb_transform method.

        Do nothing by default. If specified in the corresponding plugin,
        the function should perturb its own behaviour beyond the mapper if
        necessary.

        For instance, for models, it is necessary to update the chemical
        scheme to know which species should be accounted for

        """
        return

    def ini_mapper(self, *args, **kwargs):
        """Default empty ini_mapper method.

        Raises:
            PluginError: always; plugins are expected to provide their
                own ``ini_mapper``

        """
        raise PluginError("This is the default empty ini_mapper method")

    def flushrun(self, *args, **kwargs):
        """Default empty flushrun method for transforms. Do nothing."""
        return

    def mapper2batch(self, nsamples, transf_mapper, all_mapper,
                     all_transforms, dir_samples, **kwargs):
        """Default mapper2batch method.

        Perturb the mapper to simulate a batch of samples at once.
        A perturbed entry ``(component, parameter)`` is replaced by
        ``nsamples`` entries ``(component, "<parameter>__sample#NNN")``
        which all reference the very same object as the original entry
        (no copy is made).

        The method works in four steps:

        1. Propagate perturbations from successors to outputs: an output
           is perturbed when none of its successors lists it directly in
           its own inputs in ``all_mapper``.
        2. Propagate perturbations from outputs to inputs, with a special
           treatment for the ``toobsvect`` transform; nothing is done for
           transforms with no inputs.
        3. Update the ``subsimus`` inputs/outputs accordingly.
        4. Call ``perturb_transform`` once per instance.

        Args:
            nsamples (int): number of samples in the batch
            transf_mapper (dict): mapper of the present transform; the
                keys ``"inputs"``, ``"outputs"``, ``"precursors"``,
                ``"successors"``, ``"subsimus"`` and, optionally,
                ``"outputs2inputs"`` are used. Modified in place.
            all_mapper (dict): mappers of all transforms, indexed by the
                identifiers found in ``transf_mapper["successors"]``;
                only their ``"inputs"`` are read
            all_transforms: not used by this default implementation
            dir_samples: forwarded to ``perturb_transform``
            **kwargs: ignored

        Returns:
            dict: ``transf_mapper``, i.e. the object received as argument

        Raises:
            Exception: if the successors of an output are not perturbed
                in a consistent way, or if an output has neither a
                matching input nor an entry in ``outputs2inputs``

        """

        def is_sample(trid):
            # Sampled entries carry the sample tag in their second element
            return "__sample#" in trid[1]

        def sample_of(trid, isample):
            # Identifier of sample number `isample` of `trid`
            return (trid[0], "{}__sample#{:03d}".format(trid[1], isample))

        def duplicate(trid, *mappings):
            # Make every sample of `trid` point to the entry of `trid`
            # in each mapping; the original entry is left in place
            for isample in range(nsamples):
                trid_sample = sample_of(trid, isample)
                for mapping in mappings:
                    mapping[trid_sample] = mapping[trid]

        inputs = transf_mapper["inputs"]
        outputs = transf_mapper["outputs"]

        # Snapshots taken before any modification: the tests below must
        # refer to the original content of the mapper
        list_outputs = list(outputs)
        list_inputs = list(inputs)

        # Propagate perturbations from successors to outputs
        perturbed_output = {}
        for trid in list_outputs:
            # Skip if already perturbed
            if is_sample(trid):
                continue

            perturbed_output[trid] = False
            perturbed = []
            for successor in transf_mapper["successors"][trid]:
                # If trid directly in successor's inputs, means that
                # no perturbation is necessary
                if trid in all_mapper[successor]["inputs"]:
                    perturbed.append(False)
                    continue

                # Otherwise, copy perturbed samples
                duplicate(trid, outputs, transf_mapper["successors"])
                perturbed.append(True)

            # Check consistency of perturbation from successors:
            # either all of them are perturbed or none
            if any(perturbed) and not all(perturbed):
                raise Exception(
                    "Successors are not perturbed in a consistent way"
                )

            if perturbed and all(perturbed):
                del outputs[trid]
                del transf_mapper["successors"][trid]
                perturbed_output[trid] = True

        # Now propagate perturbations from outputs to inputs
        # Special treatment for toobsvect
        if self.plugin.name == "toobsvect":
            for trid in list_inputs:
                # Do not perturb components as a whole
                if trid[1] == "":
                    continue

                # Skip if already perturbed
                if is_sample(trid):
                    continue

                duplicate(
                    trid,
                    inputs,
                    outputs,
                    transf_mapper["precursors"],
                    transf_mapper["successors"],
                )
                del outputs[trid]
                del inputs[trid]
                del transf_mapper["precursors"][trid]
                del transf_mapper["successors"][trid]

        # Transforms with no inputs have nothing to propagate
        elif inputs:
            outputs2inputs = transf_mapper.get("outputs2inputs", {})

            # Remove trids that should not be propagated.
            # Duplicates are dropped and the original order is kept so
            # that the resulting mapper does not depend on hash ordering
            dont_propagate = {tuple(dp) for dp in self.dont_propagate}
            for trid_out, trids_in in outputs2inputs.items():
                outputs2inputs[trid_out] = [
                    trid_in
                    for trid_in in dict.fromkeys(trids_in)
                    if trid_in not in dont_propagate
                ]

            for trid in list_outputs:
                # Skip if already perturbed
                if is_sample(trid):
                    continue

                # The input matching the output may already be sampled
                first_sample = sample_of(trid, 0)
                trid_tmp = (
                    first_sample if first_sample in list_inputs else trid
                )
                same_as_input = trid_tmp in list_inputs

                if not same_as_input and trid not in outputs2inputs:
                    raise Exception(
                        "Can't propagate information when inputs "
                        "are not same as outputs. \n"
                        "Please update 'outputs2inputs' "
                        "accordingly for {}".format(trid)
                    )

                if same_as_input:
                    # Nothing to propagate from a non-perturbed output
                    if not perturbed_output[trid]:
                        continue
                    list_trid_in = [trid]
                else:
                    list_trid_in = outputs2inputs[trid]

                for trid_in in list_trid_in:
                    # Skip if trid_in already removed from inputs
                    if trid_in not in inputs:
                        continue

                    duplicate(trid_in, inputs, transf_mapper["precursors"])
                    del inputs[trid_in]
                    del transf_mapper["precursors"][trid_in]

        # Update subsimus: entries no longer present in the mapper have
        # been replaced by samples and must be replaced here as well
        subsimus = transf_mapper["subsimus"]
        for ddi in subsimus:
            for side in ("outputs", "inputs"):
                sub_mapper = subsimus[ddi][side]

                # Loop over a snapshot of the keys as entries are
                # added and removed on the fly
                for trid in list(sub_mapper):
                    if trid in transf_mapper[side]:
                        continue

                    # Skip if already perturbed
                    if is_sample(trid):
                        continue

                    duplicate(trid, sub_mapper)
                    del sub_mapper[trid]

        # Now perturb the behaviour of the transform itself;
        # the flag makes sure it is done only once per instance
        if not getattr(self, "__perturbed_transform__", False):
            self.perturb_transform(nsamples, dir_samples, transf_mapper)
            self.__perturbed_transform__ = True

        return transf_mapper