Source code for pycif.plugins.transforms.system.toobsvect.perturb_transform
import copy
from logging import debug
import numpy as np
[docs]
def perturb_transform(self, nsamples, dir_samples, file_samples, transf_mapper):
    """Extend the observation vector to accommodate all ensemble members.

    For each ``__sample#N`` tracer found in the mapper inputs, creates a
    copy of the reference tracer with a shifted ``ypointer`` (offset by
    ``member_ID × obsvect.dim_ref``) and a deep copy of the datastore
    with the parameter name updated to the sample-specific name.

    The global ``obsvect.dim`` is extended to
    ``(nmembers + 1) × dim_ref``, and the data arrays ``dy``,
    ``ysim``, ``yobs``, ``yobs_err``, and ``obsvect_mask`` are tiled
    accordingly. ``nmembers`` is the largest sample ID found in the
    mapper inputs, including samples whose reference tracer is missing
    from the datavect.

    Reference (non-sampled) tracers are removed from the datavect to
    prevent double-counting.

    Everything is done in place on ``self.obsvect``; nothing is returned.

    Args:
        self (Plugin): toobsvect transform instance.
        nsamples (int): total number of ensemble members (not read by
            this function; the member count is deduced from the
            ``__sample#N`` suffixes).
        dir_samples (str): directory containing the ensemble files
            (unused here; kept for API consistency).
        file_samples (str): ensemble control-vector file name (unused).
        transf_mapper (dict): transform mapper; ``'inputs'`` is scanned
            for ``__sample#N`` tracer IDs. Each ID is indexed as
            ``(component, tracer)``.
    """
    sample_tag = "__sample#"

    obsvect = self.obsvect
    datavect = obsvect.datavect
    components = datavect.components

    # Save original dimension (only once, so that later calls still see the
    # size of a single, non-extended block)
    if not hasattr(obsvect, "dim_ref"):
        obsvect.dim_ref = obsvect.dim

    # Update data vector to fit samples
    trid2clean = []
    nmembers = 0
    for trid in transf_mapper["inputs"]:
        comp = trid[0]
        trcr = trid[1]

        if sample_tag not in trcr:
            continue

        name_parts = trcr.split(sample_tag)
        trcr_ref = name_parts[0]
        member_ID = int(name_parts[1])
        nmembers = max(nmembers, member_ID)

        component = getattr(components, comp)
        params = component.parameters

        if not hasattr(params, trcr_ref):
            debug(
                f"Skipping {trid} for 'toobsvect' as {trcr_ref} "
                "is not in datavect"
            )
            continue

        tracer = getattr(params, trcr_ref)
        params.attributes.append(trcr)
        trid2clean.append((comp, trcr_ref))

        # Copy tracer
        tracer_out = tracer.__class__(plg_orig=tracer)
        setattr(params, trcr, tracer_out)

        # Update tracer pointer: each member owns one block of ``dim_ref``
        # entries. ``dim_ref`` is used rather than ``dim`` because ``dim``
        # is the extended size once the observation vector has been
        # enlarged, which would push the pointer past the member's block.
        tracer_out.ypointer = tracer.ypointer + member_ID * obsvect.dim_ref

        # Update tracer datastore; deep copy so that the renaming below
        # does not touch the datastore of the reference tracer
        tracer_out.datastore = copy.deepcopy(tracer_out.datastore)

        # A categorical column rejects values outside its categories, so
        # the new name has to be registered before it can be assigned
        if tracer_out.datastore["metadata"]["parameter"].dtype == 'category':
            tracer_out.datastore.loc[:, ("metadata", "parameter")] = \
                tracer_out.datastore["metadata"]["parameter"] \
                .cat.add_categories(trcr)

        tracer_out.datastore.loc[
            tracer_out.datastore["metadata"]["parameter"] == trcr_ref,
            ("metadata", "parameter")] = trcr

    # Update obsvect dimension and data
    obsvect.dim = (nmembers + 1) * obsvect.dim_ref

    # Extend data according to new dimension
    # NOTE(review): this tiles whatever the arrays currently hold; it looks
    # like it assumes they still contain a single reference block, i.e.
    # that this function runs once per obsvect -- confirm against callers.
    for data_id in ["dy", "ysim", "yobs", "yobs_err", "obsvect_mask"]:
        setattr(obsvect, data_id,
                np.tile(getattr(obsvect, data_id),
                        nmembers + 1))

    # Removing non-perturbed tracers; ``set`` because one reference tracer
    # is recorded once per sample derived from it
    for trid in set(trid2clean):
        comp = trid[0]
        trcr = trid[1]

        component = getattr(components, comp)
        params = component.parameters
        params.attributes.remove(trcr)
        delattr(params, trcr)