Source code for pycif.plugins.transforms.system.fromcontrol.perturb_transform
from logging import debug
[docs]
def perturb_transform(self, nsamples, dir_samples, file_samples, transf_mapper):
    """Switch the fromcontrol transform to ensemble (perturbation) mode.

    The ensemble settings are stored on the transform itself, then every
    output of the mapper whose tracer name carries the ``__sample#``
    marker gets its own tracer object, built from the reference tracer
    (the part of the name before the marker). Each new tracer is
    registered in the component parameters and attached to the matching
    mapper entry under the ``'tracer'`` key. Once all sample tracers
    exist, the reference tracers they were built from are dropped from
    the component parameters.

    Outputs without the marker are ignored. Sampled outputs whose
    reference tracer is not present in the component parameters are
    skipped with a debug message.

    Args:
        self (Plugin): fromcontrol transform instance; gets the
            attributes ``perturb_xb`` (set to ``True``), ``nsamples``,
            ``dir_samples`` and ``file_samples``. Its
            ``controlvect.datavect.components`` is modified in place.
        nsamples (int): total number of ensemble members.
        dir_samples (str): directory containing the ensemble
            control-vector file.
        file_samples (str): file name of the ensemble control vector.
        transf_mapper (dict): the transform mapper; the keys of its
            ``'outputs'`` entry are ``(component, tracer)`` pairs and
            the matching values receive a ``'tracer'`` item.

    Raises:
        KeyError: if a component referenced by a sampled output has no
            ``parameters`` attribute.
    """
    marker = "__sample#"

    # Flag the transform as perturbed and remember where the ensemble is
    self.perturb_xb = True
    self.nsamples = nsamples
    self.dir_samples = dir_samples
    self.file_samples = file_samples

    # Sample tracers are added to the data vector of the control vector
    components = self.controlvect.datavect.components
    outputs = transf_mapper["outputs"]

    # (component, reference tracer) pairs to drop once all copies exist
    stale_refs = set()

    for trid in outputs:
        comp, sample_name = trid
        if marker not in sample_name:
            continue

        # Name of the reference tracer: everything before the marker
        ref_name = sample_name.partition(marker)[0]

        component = getattr(components, comp)
        if not hasattr(component, "parameters"):
            raise KeyError(
                f"Component '{comp}' has no 'parameters' attribute. "
                "Please check your yaml."
            )

        params = component.parameters
        if not hasattr(params, ref_name):
            debug(
                f"Skipping tracer {trid} for 'fromcontrol' transfrom as "
                f"{ref_name} is not in the datavect"
            )
            continue

        reference = getattr(params, ref_name)
        params.attributes.append(sample_name)
        stale_refs.add((comp, ref_name))

        # Build the sample tracer from the reference one and register it
        sample_tracer = reference.__class__(plg_orig=reference)
        setattr(params, sample_name, sample_tracer)

        # An empty varname falls back to the reference tracer name
        if sample_tracer.varname == "":
            sample_tracer.varname = ref_name

        outputs[trid]["tracer"] = sample_tracer

    # The reference tracers are removed only now, because several samples
    # may have been built from the same reference
    for comp, ref_name in stale_refs:
        params = getattr(components, comp).parameters
        params.attributes.remove(ref_name)
        delattr(params, ref_name)