# Source code for pycif.plugins.obsoperators.standard.transforms.utils.connect_pipes

import copy
import itertools
from logging import debug
from . import add_default


def connect_pipes(all_transforms, mapper, transform):
    """Connect transforms based on their inputs and outputs.

    Fill ``mapper[transform]["precursors"]`` and
    ``mapper[transform]["successors"]`` from the other transforms of the
    pipeline, then update the links of the neighbouring transforms so that
    they point back to ``transform``.

    Precursors of an input ``trid`` are searched among the transforms placed
    *before* ``transform`` in ``all_transforms.attributes`` that list
    ``trid`` in their outputs. Successors of an output ``trid`` are searched
    among the transforms placed *after* it that list ``trid`` in their
    inputs. A ``trid`` whose precursor (resp. successor) list is already
    non-empty when this function is called is treated as pre-defined and is
    not recomputed.

    Args:
        all_transforms: pipeline object whose ``attributes`` list holds the
            transform names in pipeline order.
        mapper (dict): maps each transform name to a dict with ``"inputs"``
            and ``"outputs"`` (containers of ``trid``) and optional
            ``"precursors"`` / ``"successors"`` dicts mapping a ``trid`` to
            a list of transform names. Modified in place.
        transform: name of the transform to connect.

    Raises:
        AttributeError: if ``transform`` is not in
            ``all_transforms.attributes``.
    """
    debug(f"Connecting pipeline around {transform}")

    # Check whether transform is already in the pipeline,
    # otherwise raise exception
    if transform not in all_transforms.attributes:
        raise AttributeError(
            f"Cannot connect {transform} to the rest of the pipeline. "
            "Include it to the pipeline before calling connect_pipe"
        )

    itransf = all_transforms.attributes.index(transform)
    transf_mapper = mapper[transform]

    # The ref_* dicts are snapshots of the pre-defined links. A trid absent
    # from a snapshot has no pre-defined link, hence the ``.get(trid, [])``
    # lookups below: plain indexing would raise KeyError for every trid
    # that is only added after the snapshot.

    # For each input trid, look for transforms placed before this one
    # that produce it as an output
    precursors = transf_mapper.setdefault("precursors", {})
    ref_precursors = copy.deepcopy(precursors)
    for inpt in transf_mapper["inputs"]:
        precursors.setdefault(inpt, [])

        # If already some pre-defined precursors, keep them
        if ref_precursors.get(inpt, []) != []:
            continue

        # Otherwise, fetch all possible precursors
        for tr in all_transforms.attributes[:itransf]:
            if inpt in mapper[tr]["outputs"] and tr not in precursors[inpt]:
                precursors[inpt].append(tr)

    # For each output trid, look for transforms placed after this one
    # that consume it as an input
    successors = transf_mapper.setdefault("successors", {})
    ref_successors = copy.deepcopy(successors)
    for outpt in transf_mapper["outputs"]:
        successors.setdefault(outpt, [])

        # If already some pre-defined successors, keep them
        if ref_successors.get(outpt, []) != []:
            continue

        # Otherwise, fetch all possible successors
        for tr in all_transforms.attributes[itransf + 1:]:
            if outpt in mapper[tr]["inputs"] and tr not in successors[outpt]:
                successors[outpt].append(tr)

    # Clean successors and precursors by removing direct pipes.
    # Successive transforms may have the same inputs/outputs and only carry
    # on computations on them, i.e., without changing the shape of the
    # datastore: a transform already reached through another successor
    # (resp. precursor) is dropped from the direct links.
    for trid in successors:
        redundancy = set(itertools.chain.from_iterable(
            mapper[tr_ref].get("successors", {}).get(trid, [])
            for tr_ref in successors[trid]
        ))
        for trtmp in redundancy:
            # Skip if trtmp was already in pre-defined successors
            if trtmp in ref_successors.get(trid, []):
                continue

            # Otherwise exclude it
            if trtmp != transform and trtmp in successors[trid]:
                successors[trid].remove(trtmp)

    for trid in precursors:
        redundancy = set(itertools.chain.from_iterable(
            mapper[tr_ref].get("precursors", {}).get(trid, [])
            for tr_ref in precursors[trid]
        ))
        for trtmp in redundancy:
            # Skip if trtmp was already in pre-defined precursors
            if trtmp in ref_precursors.get(trid, []):
                continue

            # Otherwise exclude it
            if trtmp != transform and trtmp in precursors[trid]:
                precursors[trid].remove(trtmp)

    # Now update precursors (resp. successors)
    # of transforms after (resp. before) this one
    for trid in successors:
        # Iterate over a copy: when this transform is not the closest
        # precursor of trtmp, the pruning below removes trtmp from this
        # very list, which would otherwise skip the next successor.
        for trtmp in list(successors[trid]):
            tmp_precursors = mapper[trtmp].setdefault("precursors", {})
            tmp_precursors[trid] = [
                t for t in tmp_precursors.get(trid, [])
                if t not in precursors.get(trid, [])
            ]
            if transform not in tmp_precursors[trid]:
                tmp_precursors[trid].append(transform)

            # If several precursors for a given trid, take last
            if len(tmp_precursors[trid]) > 1:
                list_indexes = [
                    all_transforms.attributes.index(tr)
                    for tr in tmp_precursors[trid]
                ]
                # NOTE(review): this restores *this* transform's
                # pre-defined precursors on trtmp, not trtmp's own;
                # confirm this is intended.
                if ref_precursors.get(trid, []) != []:
                    # Copy so that transforms do not share one list object
                    tmp_precursors[trid] = list(ref_precursors[trid])
                    continue

                last_index = max(list_indexes)
                tmp_precursors[trid] = [all_transforms.attributes[last_index]]

                # Unlink trtmp from the discarded (earlier) precursors
                for tr_index in list_indexes:
                    if tr_index != last_index:
                        mapper[
                            all_transforms.attributes[tr_index]
                        ]["successors"][trid].remove(trtmp)

    for trid in precursors:
        for trtmp in precursors[trid]:
            tmp_successors = mapper[trtmp].setdefault("successors", {})
            tmp_successors[trid] = [
                t for t in tmp_successors.get(trid, [])
                if t not in successors.get(trid, [])
            ]
            if transform not in tmp_successors[trid]:
                tmp_successors[trid].append(transform)
# Now prune dead branches
# Disabled: connect_pipes does not run the following call.
# prune_dead_branches(all_transforms, mapper, transform)


def prune_dead_branches(all_transforms, mapper, skip_transform):
    """Remove transforms whose outputs are consumed by no successor.

    A transform is removed when every list in its ``"successors"`` mapping
    is empty, or when the mapping is missing. It is then unlinked from the
    successor lists of its precursors and deleted from ``mapper`` and from
    ``all_transforms.attributes``. The attribute of the same name on
    ``all_transforms`` is set to ``None``.

    Transforms whose name contains ``"toobsvect"``, and ``skip_transform``
    itself, are never removed.

    NOTE(review): the pipeline is scanned once, in order. When precursors
    sit earlier in the list than their successors (as connect_pipes sets
    them), a precursor left without successors by a removal made in this
    call is not removed until a later call. Confirm this is intended.

    Args:
        all_transforms: pipeline object whose ``attributes`` list holds the
            transform names in pipeline order. Modified in place.
        mapper (dict): transform name -> dict with optional
            ``"precursors"`` / ``"successors"`` mappings (``trid`` -> list
            of transform names). Modified in place.
        skip_transform: name of a transform that must be kept.
    """
    # Iterate over a snapshot: the attributes list shrinks inside the loop.
    # The names are strings (they are used with setattr), so a shallow copy
    # is enough.
    for transform in list(all_transforms.attributes):
        # Skipping toobsvect
        if "toobsvect" in transform:
            continue

        # Skip transform being initialized
        if transform == skip_transform:
            continue

        successors = mapper[transform].get("successors", {})

        # Keep the transform if any of its outputs has a successor
        if any(len(trs) for trs in successors.values()):
            continue

        debug(f"Removing {transform}")

        # Discard it from the successor lists of its precursors
        for trid, trs in mapper[transform].get("precursors", {}).items():
            for tr in trs:
                mapper[tr]["successors"][trid].remove(transform)

        # Now delete the transform itself
        del mapper[transform]
        all_transforms.attributes.remove(transform)
        setattr(all_transforms, transform, None)