Source code for pycif.plugins.modes.ensrf.utils

import numpy as np
from collections import defaultdict
import copy

def get_path_in_dict(d, value, prepath=()):
    """Collect every path leading to ``value`` inside a nested dict/list.

    Dictionaries are walked recursively:

    - a key equal to ``value`` yields the path ending at that key;
    - a dict value equal to ``value`` yields the path to its key followed by
      the value itself;
    - every dict or list value is then searched in turn.

    For a list, only its direct elements are compared with ``value``. A match
    yields ``prepath`` followed by the element. Dicts or lists sitting directly
    inside a list are not descended into.

    Args:
        d (dict or list): container to search; any other type yields ``[]``.
        value: element looked for among keys, values and list items
            (typically a string).
        prepath (tuple): prefix prepended to every returned path; mainly used
            by the recursion.

    Returns:
        list of tuple: the matching paths, in traversal order.
    """
    found = []
    if isinstance(d, list):
        found.extend(prepath + (item,) for item in d if item == value)
    elif isinstance(d, dict):
        for key, sub in d.items():
            here = prepath + (key,)
            if key == value:
                found.append(here)
            if sub == value:
                found.append(here + (sub,))
            if isinstance(sub, (dict, list)):
                found.extend(get_path_in_dict(sub, value, here))
    return found
def _transform_attr_as_list(transform, names):
    """Return the value of the last attribute of ``transform`` listed in ``names``.

    A single string is wrapped into a one-element list. When several of the
    names are present, the last one in ``names`` wins. Returns None when
    ``transform`` has none of them.
    """
    value = None
    for name in names:
        if hasattr(transform, name):
            value = getattr(transform, name)
    return [value] if isinstance(value, str) else value


def _broadcast_components(components, parameters):
    """Repeat ``components`` ``len(parameters)`` times when the lengths differ.

    A new list is built so that the transform attribute ``components`` came
    from is never modified in place.
    """
    if len(components) != len(parameters):
        return list(components) * len(parameters)
    return components


def _transform_trid_links(transform):
    """Return ``(trids_out, trids_in)`` for one transform of the pipe.

    ``trids_out`` lists the (component, parameter) pairs produced by the
    transform and ``trids_in`` the pairs it consumes. Output parameters and
    output components each default to their input counterpart when the
    transform does not declare them. Returns None when the transform does not
    declare both input components and input parameters, because no trid can
    be built for it.
    """
    parameters_in = _transform_attr_as_list(
        transform, ("parameter", "parameter_in", "parameters_in")
    )
    components_in = _transform_attr_as_list(
        transform, ("component", "component_in", "components_in")
    )
    if parameters_in is None or components_in is None:
        return None
    components_in = _broadcast_components(components_in, parameters_in)

    parameters_out = _transform_attr_as_list(
        transform, ("parameter_out", "parameters_out")
    )
    components_out = _transform_attr_as_list(
        transform, ("component_out", "components_out")
    )
    # Resolve the default output parameters first so that output components
    # can be broadcast against them.
    if parameters_out is None:
        parameters_out = parameters_in
    if components_out is None:
        components_out = components_in
    else:
        components_out = _broadcast_components(components_out, parameters_out)

    trids_in = list(zip(components_in, parameters_in))
    trids_out = list(zip(components_out, parameters_out))
    return trids_out, trids_in


def get_trids_dontpropagate(yml_dict, controlvect):
    """List the controlvect trids that should not be affected by the
    propagation of the perturbed samples.

    The list starts with the trids whose tracer has a falsy ``iscontrol``
    attribute. It then adds every output of a transform in
    ``controlvect.transform_pipe`` whose inputs are all already in the list,
    repeating until no more outputs can be added, so chained transforms are
    followed transitively.

    Args:
        yml_dict: not used by this function; kept for interface compatibility.
        controlvect (Plugin): controlvect Plugin

    Returns:
        list_trids_dp (list): list of ``[component, parameter]`` pairs that
            should not be affected by the propagation of the perturbed
            samples.
    """
    # Seed with the tracers of the datavect that are not controlled.
    list_trids_dp = []
    components = controlvect.datavect.components
    for comp in components.attributes:
        component = getattr(components, comp)
        if not hasattr(component, "parameters"):
            continue
        for trcr in component.parameters.attributes:
            tracer = getattr(component.parameters, trcr)
            if not tracer.iscontrol:
                list_trids_dp.append((comp, trcr))

    # Map each transform output trid to the input trids it is computed from.
    # A later transform producing the same trid overwrites an earlier one.
    dict_inout = {}
    transf_pipe = controlvect.transform_pipe
    for transf in transf_pipe.attributes:
        links = _transform_trid_links(getattr(transf_pipe, transf))
        if links is None:
            continue
        trids_out, trids_in = links
        for trid_out in trids_out:
            dict_inout[trid_out] = trids_in

    # Rather than building a dependency graph, sweep the mapping until nothing
    # changes: chains of connected transforms are generally short.
    known = set(list_trids_dp)
    changed = True
    while changed:
        changed = False
        for trid_out, trids_in in dict_inout.items():
            if trid_out in known:
                continue
            if all(trid in known for trid in trids_in):
                list_trids_dp.append(trid_out)
                known.add(trid_out)
                changed = True

    return [list(trid) for trid in list_trids_dp]
def get_trids_dontpropagate_obsvect(yml_dict, controlvect):
    """Return the controlvect trids flagged with ``dont_propagate_obsvect``.

    A trid is selected when its tracer, inside a datavect component that has
    a ``parameters`` attribute, has a truthy ``dont_propagate_obsvect``
    attribute. A missing attribute counts as False.
    NOTE(review): presumably these trids should not have their perturbations
    propagated to the observation vector; confirm against the EnSRF mode that
    consumes this list.

    Args:
        yml_dict: not used by this function.
        controlvect (Plugin): controlvect Plugin

    Returns:
        list_trids_dpo (list): ``[component, parameter]`` pairs, in datavect
            order.
    """
    comps = controlvect.datavect.components
    flagged = []
    for comp_name in comps.attributes:
        comp_obj = getattr(comps, comp_name)
        if hasattr(comp_obj, "parameters"):
            params = comp_obj.parameters
            flagged.extend(
                [comp_name, par_name]
                for par_name in params.attributes
                if getattr(getattr(params, par_name), "dont_propagate_obsvect", False)
            )
    return flagged
def check_tresol(controlvect, window_length):
    """Check that every optimized tracer uses the EnSRF window as ``tresol``.

    Only tracers with a truthy ``iscontrol`` attribute, inside datavect
    components that have a ``parameters`` attribute, are checked.

    Args:
        controlvect (Plugin): controlvect Plugin
        window_length (string): EnSRF window length

    Raises:
        AttributeError: if a controlled tracer has a ``tresol`` (a missing
            attribute counts as ``""``) that differs from ``window_length``.
            The message names the component, the parameter and the offending
            value.
    """
    components = controlvect.datavect.components
    for comp in components.attributes:
        component = getattr(components, comp)
        if not hasattr(component, "parameters"):
            continue
        for trcr in component.parameters.attributes:
            tracer = getattr(component.parameters, trcr)
            if not tracer.iscontrol:
                continue
            tresol = getattr(tracer, "tresol", "")
            if tresol != window_length:
                # Keep the historical message as a prefix and report which
                # parameter failed and with what value.
                raise AttributeError(
                    f"'tresol' for {comp} must be equal to"
                    f" {window_length=} to prevent strange behaviour in EnSRF."
                    f" Parameter {trcr!r} has tresol={tresol!r}."
                )