import numpy as np
from collections import defaultdict
import copy
[docs]
def get_path_in_dict(d, value, prepath=()):
    """Find occurrences of a value in a dictionary (or a list) recursively
    and return a list of the paths leading to these occurrences.

    A path is a tuple built from the elements traversed to reach an
    occurrence:

    - when ``value`` matches a dictionary key, the path ends with that key;
    - when ``value`` matches a dictionary value or a list element, the
      matched value itself is appended as the last path element;
    - dictionaries and lists nested inside a list are searched as well, and
      the index of the nested container in that list is used as the
      corresponding path element.

    Args:
        d (dict or list): Dictionary or list in which value is searched for.
            Any other type yields no match.
        value (str): the string element that is searched for in the keys or
            values.
        prepath (tuple): path to prepend. Mainly here for recursion.

    Returns:
        list_paths (list): list of paths (tuples) where the value has been
        found, in traversal order.
    """
    list_paths = []
    if isinstance(d, list):
        for idx, elt in enumerate(d):
            if elt == value:
                list_paths.append(prepath + (elt,))
            # Containers nested in a list must be searched too; their index
            # is the only stable way to address them in the path.
            if isinstance(elt, (dict, list)):
                list_paths += get_path_in_dict(elt, value, prepath + (idx,))
    elif isinstance(d, dict):
        for k, v in d.items():
            path = prepath + (k,)
            if k == value:
                list_paths.append(path)
            if v == value:
                list_paths.append(path + (v,))
            if isinstance(v, (dict, list)):
                list_paths += get_path_in_dict(v, value, path)
    return list_paths
[docs]
def _last_attr_as_list(obj, names):
    """Return the value of the last name in ``names`` that is an attribute of
    ``obj``, wrapping a string into a one-element list; None if none exists.
    """
    # NOTE: later aliases override earlier ones, matching the historical
    # behaviour of scanning every alias without stopping at the first hit.
    value = None
    for name in names:
        if hasattr(obj, name):
            value = getattr(obj, name)
    return [value] if isinstance(value, str) else value


def _broadcast_components(components, parameters):
    """Repeat ``components`` when its length differs from ``parameters``.

    Always returns a new list when repetition is needed, so the transform's
    own attribute is never modified in place.
    """
    if len(components) != len(parameters):
        return list(components) * len(parameters)
    return components


def get_trids_dontpropagate(yml_dict, controlvect):
    """List the controlvect trids that should not be affected by the
    propagation of the perturbed samples.

    They include the trids with a False 'iscontrol' parameter and their
    successors through the transforms of ``controlvect.transform_pipe``.
    A transform output is added once all of its inputs are in the list. This
    is repeated until no new trid is added, so chains of transforms are
    followed.

    Args:
        yml_dict (dict): Not used by this function; kept for signature
            compatibility with existing callers.
        controlvect (Plugin): controlvect Plugin

    Returns:
        list_trids_dp (list): List of ``[component, parameter]`` pairs that
        should not be affected by the propagation of the perturbed samples.
    """
    list_trids_dp = []
    # Find the main trids in the datavect
    components = controlvect.datavect.components
    for comp in components.attributes:
        component = getattr(components, comp)
        if not hasattr(component, "parameters"):
            continue
        for trcr in component.parameters.attributes:
            tracer = getattr(component.parameters, trcr)
            if not tracer.iscontrol:
                list_trids_dp.append((comp, trcr))

    # Map every transform output trid to the list of its input trids
    dict_inout = {}
    transf_pipe = controlvect.transform_pipe
    for transf in transf_pipe.attributes:
        transform = getattr(transf_pipe, transf)
        parameters_in = _last_attr_as_list(
            transform, ["parameter", "parameter_in", "parameters_in"]
        )
        components_in = _last_attr_as_list(
            transform, ["component", "component_in", "components_in"]
        )
        if parameters_in is None or components_in is None:
            # Without declared inputs there is nothing to propagate from.
            continue
        components_in = _broadcast_components(components_in, parameters_in)

        # Resolve the output parameters first: the output components are
        # broadcast against them.
        parameters_out = _last_attr_as_list(
            transform, ["parameter_out", "parameters_out"]
        )
        if parameters_out is None:
            parameters_out = parameters_in
        components_out = _last_attr_as_list(
            transform, ["component_out", "components_out"]
        )
        if components_out is None:
            components_out = components_in
        else:
            components_out = _broadcast_components(components_out, parameters_out)

        trids_in = list(zip(components_in, parameters_in))
        for trid_out in zip(components_out, parameters_out):
            dict_inout[trid_out] = trids_in

    # Instead of creating a complex dependency graph to handle connected
    # transformations, go over the dictionary until nothing changes: the
    # number of connected transformations is generally not large.
    changed = True
    while changed:
        changed = False
        for trid_out, trids_in in dict_inout.items():
            if trid_out in list_trids_dp:
                continue
            if all(trid in list_trids_dp for trid in trids_in):
                list_trids_dp.append(trid_out)
                changed = True

    return [list(trid) for trid in list_trids_dp]
[docs]
def get_trids_dontpropagate_obsvect(yml_dict, controlvect):
    """List the controlvect trids flagged with a truthy
    ``dont_propagate_obsvect`` attribute.

    Only components of ``controlvect.datavect`` that have a ``parameters``
    attribute are inspected. A tracer without the flag counts as False.
    NOTE(review): judging by the flag name, these trids are presumably the
    ones that should NOT be affected by the propagation of the perturbed
    samples to the observation vector. Confirm with the caller.

    Args:
        yml_dict (dict): Not used by this function; kept for signature
            compatibility with existing callers.
        controlvect (Plugin): controlvect Plugin

    Returns:
        list_trids_dpo (list): ``[component, parameter]`` pairs of the
        flagged tracers, in datavect order.
    """
    datavect_components = controlvect.datavect.components
    flagged = []
    for comp_name in datavect_components.attributes:
        component = getattr(datavect_components, comp_name)
        if hasattr(component, "parameters"):
            flagged.extend(
                [comp_name, trcr_name]
                for trcr_name in component.parameters.attributes
                if getattr(
                    getattr(component.parameters, trcr_name),
                    "dont_propagate_obsvect",
                    False,
                )
            )
    return flagged
[docs]
def check_tresol(controlvect, window_length):
    """Ensure every optimized tracer uses ``window_length`` as its ``tresol``.

    Only tracers whose ``iscontrol`` flag is truthy are checked. Components
    without a ``parameters`` attribute are ignored. A tracer lacking a
    ``tresol`` attribute is treated as having the value "", so it fails the
    check unless ``window_length`` is "".

    Args:
        controlvect (Plugin): controlvect Plugin
        window_length (string): EnSRF window length, compared with ``!=``
            against each optimized tracer's ``tresol``.

    Raises:
        AttributeError: for the first component holding an optimized tracer
            whose ``tresol`` differs from ``window_length``.
    """
    datavect_components = controlvect.datavect.components
    for comp in datavect_components.attributes:
        component = getattr(datavect_components, comp)
        if hasattr(component, "parameters"):
            for trcr in component.parameters.attributes:
                tracer = getattr(component.parameters, trcr)
                # Non-optimized tracers are allowed any temporal resolution.
                mismatch = (
                    tracer.iscontrol
                    and getattr(tracer, "tresol", "") != window_length
                )
                if mismatch:
                    raise AttributeError(
                        f"'tresol' for {comp} must be equal to"
                        f" {window_length=} to prevent strange behaviour in EnSRF."
                    )