import copy
import numpy as np
import pandas as pd
import os
from logging import debug, warning
try:
import cPickle as pickle
except ImportError:
import pickle
[docs]
def select_obs2assimilate(obsvect, hx_samples_dev, mask_obs_assimilated):
    """Build the masks selecting the observations of the current segment.

    Two boolean masks are derived:

    1) the observations that belong to the current segment, i.e. the rows
       of ``hx_samples_dev`` that contain no NaN for any member;
    2) the observations that must be assimilated now, i.e. those that are
       selected by ``obsvect.obsvect_mask``, belong to the segment and have
       not been assimilated yet.

    Args:
        obsvect (Plugin): obsvect Plugin; only its ``obsvect_mask``
            attribute is read
        hx_samples_dev (np.array): 2D array (one row per observation,
            reduced along axis 1) with, for each member, the deviations of
            the simulated values from the mean
        mask_obs_assimilated (np.array): boolean mask flagging the
            observations already assimilated

    Returns:
        mask_obs_segment (np.array): True for rows of ``hx_samples_dev``
            without any NaN
        mask_obs2assim (np.array): True for observations to assimilate in
            this segment
    """
    # Mask to select observations to assimilate
    mask_obsvect = obsvect.obsvect_mask

    # Mask to select observations within the segment: a row is kept only
    # if none of its values is NaN
    # TODO: detect Nans between good values, means there is a problem with the model
    # TODO: detect the last good value and if there is a nan before, then problem
    mask_obs_segment = ~np.any(np.isnan(hx_samples_dev), axis=1)

    # Combine the masks: selected, inside the segment, not yet assimilated
    mask_obs2assim = mask_obsvect & mask_obs_segment & ~mask_obs_assimilated

    return mask_obs_segment, mask_obs2assim
[docs]
def get_cntrlv_trdates(controlvect):
    """Collect, per component and per controlled tracer, the date of each element.

    Every component listed in ``controlvect.datavect.components.attributes``
    gets an entry in the result (possibly an empty dict). Tracers whose
    ``iscontrol`` flag is false, and components without ``parameters``, do
    not contribute any dates.

    Args:
        controlvect (Plugin): controlvect Plugin

    Returns:
        dict: ``{component: {tracer: dates}}`` where ``dates`` is the
        flattened array of ``tracer.dates`` expanded over the
        (dates, vertical, horizontal) index grid of the tracer.
    """
    all_comps = controlvect.datavect.components
    dates_by_comp = {}

    for comp_name in all_comps.attributes:
        comp_obj = getattr(all_comps, comp_name)
        tracer_dates = {}
        dates_by_comp[comp_name] = tracer_dates

        if hasattr(comp_obj, "parameters"):
            for trcr_name in comp_obj.parameters.attributes:
                trcr_obj = getattr(comp_obj.parameters, trcr_name)
                if not trcr_obj.iscontrol:
                    continue

                # Only the date index grid is needed here; it is the first
                # array returned by meshgrid
                index_grids = np.meshgrid(
                    range(trcr_obj.ndates),
                    range(trcr_obj.vresoldim),
                    range(trcr_obj.hresoldim))
                date_index = index_grids[0]
                tracer_dates[trcr_name] = trcr_obj.dates[date_index].flatten()

    return dates_by_comp
[docs]
def get_cntrlv_idx_segment(controlvect, ddi, ddf):
    """Find the controlvect indexes whose date falls in the current segment.

    An element is in the segment when ``ddi <= date < ddf``. Dates are
    obtained from :func:`get_cntrlv_trdates`; each tracer is mapped onto the
    control vector through ``tracer.xpointer`` and ``tracer.dim``.

    Args:
        controlvect (Plugin): controlvect Plugin
        ddi (datetime): initial date of the segment (included)
        ddf (datetime): end date of the segment (excluded)

    Returns:
        np.array: sorted integer indexes of the controlvect elements that
        are in the segment
    """
    # One flag per element of the control vector, all initially outside
    in_segment = np.zeros(controlvect.dim, dtype=bool)

    dates_by_comp = get_cntrlv_trdates(controlvect)
    all_comps = controlvect.datavect.components

    for comp_name in all_comps.attributes:
        comp_obj = getattr(all_comps, comp_name)
        for trcr_name, trcr_dates in dates_by_comp[comp_name].items():
            trcr_obj = getattr(comp_obj.parameters, trcr_name)

            first = trcr_obj.xpointer
            last = trcr_obj.xpointer + trcr_obj.dim

            # Flag the tracer's elements lying inside [ddi, ddf)
            in_segment[first:last] |= (ddi <= trcr_dates) & (trcr_dates < ddf)

    return np.nonzero(in_segment)[0]
[docs]
def get_cntrlv_idx_comp(controlvect, ddi, ddf):
    """Find, for each controlled tracer, its indexes in the cropped controlvect.

    Tracers are visited in the order of the ``attributes`` lists. Each
    tracer with at least one element such that ``ddi <= date < ddf``
    receives a contiguous range of indexes, numbered consecutively from 0
    across all components and tracers.

    Args:
        controlvect (Plugin): controlvect Plugin
        ddi (datetime): initial date of the segment (included)
        ddf (datetime): end date of the segment (excluded)

    Returns:
        dict: ``{component: {tracer: indexes}}``; every component has an
        entry, tracers without any element in the segment are omitted
    """
    all_comps = controlvect.datavect.components
    idx_by_comp = {}
    next_free_idx = 0

    for comp_name in all_comps.attributes:
        comp_obj = getattr(all_comps, comp_name)
        tracer_idx = {}
        idx_by_comp[comp_name] = tracer_idx

        if hasattr(comp_obj, "parameters"):
            for trcr_name in comp_obj.parameters.attributes:
                trcr_obj = getattr(comp_obj.parameters, trcr_name)
                if not trcr_obj.iscontrol:
                    continue

                # Date of every element of the tracer, in flattened order
                index_grids = np.meshgrid(
                    range(trcr_obj.ndates),
                    range(trcr_obj.vresoldim),
                    range(trcr_obj.hresoldim))
                trcr_dates = trcr_obj.dates[index_grids[0]].flatten()

                # Number of elements of this tracer inside [ddi, ddf)
                inside = (ddi <= trcr_dates) & (trcr_dates < ddf)
                n_inside = np.count_nonzero(inside)
                if n_inside > 0:
                    upper = next_free_idx + n_inside
                    tracer_idx[trcr_name] = np.arange(next_free_idx, upper)
                    next_free_idx = upper

    return idx_by_comp
[docs]
def crop_dump(cntrlv, cntrlv_file, ddi_crop, ddf_crop, ensemble=False, dump_indexes=False, **kwargs):
    """Dump the control vector into a pickle file, cropped to a time window.

    Only the elements whose date satisfies ``ddi_crop <= date <= ddf_crop``
    are dumped. The pickled dictionary is the result of ``cntrlv.to_dict``
    extended with two entries:

    - ``"datastore"``: ``{component: {tracer: columns}}`` with the columns
      ``horiz_id``, ``vert_id``, ``date`` (categorical pandas Series) and
      any of ``x``, ``xb``, ``dx``, ``std``, ``pa`` available on ``cntrlv``
      (2D attributes are reduced to their diagonal);
    - ``"dim_infos"``: ``{component: {tracer: {"xpointer", "dim"}}}``
      giving the position and size of each tracer in the cropped vector.

    Args:
        cntrlv (pycif.utils.classes.controlvects.ControlVect):
            the Control Vector to dump
        cntrlv_file (str): path to the file to dump as pickle
        ddi_crop (datetime): cropping start date (included)
        ddf_crop (datetime): cropping end date (included)
        ensemble (bool): if True and ``cntrlv`` has an ``x_ens`` attribute,
            each row of ``x_ens`` is also dumped, under the tracer name
            suffixed with ``__sample#NNN``
        dump_indexes (bool): if True, the output of
            ``get_cntrlv_idx_comp`` is pickled to ``comp_indexes.pickle``
            in the directory of ``cntrlv_file``
        **kwargs: not used by this function
    """
    debug("Cropping and dumping the control vector to {}".format(cntrlv_file))

    # Saving recursive attributes from the Yaml
    exclude = ["transform", "domain", "datastore",
               "input_dates", "obsvect", "tracer", "input_files",
               "tstep_dates", "tstep_all", "dataflx",
               "logfile", "datei", "datef", "workdir",
               "verbose", "subsimu_dates", "tcorrelations",
               "hcorrelations", "databos"]
    tosave = cntrlv.to_dict(cntrlv, exclude_patterns=exclude)

    # Save the control vector as a pandas datastore
    controlvect_ds = {}
    diminfos_ds = {}
    var2read = ["x", "xb", "dx", "std", "pa"]
    save_ensemble = ensemble and hasattr(cntrlv, 'x_ens')

    components = cntrlv.datavect.components
    new_cntrlv_dim = 0
    for comp in components.attributes:
        component = getattr(components, comp)

        # Skip if component does not have parameters
        if not hasattr(component, "parameters"):
            continue

        for trcr in component.parameters.attributes:
            tracer = getattr(component.parameters, trcr)

            # Do nothing if not in control vector
            if not tracer.iscontrol:
                continue

            # Update controlvect_ds dictionary
            controlvect_ds.setdefault(comp, {})
            diminfos_ds.setdefault(comp, {})

            # Date / vertical / horizontal index of every tracer element
            dates_id, vert_id, horiz_id = np.meshgrid(
                range(tracer.ndates),
                range(tracer.vresoldim),
                range(tracer.hresoldim))
            dates_array = tracer.dates[dates_id].flatten()

            # TODO: open end for ddf but close end for the controlvect idxs
            mask_crop = (ddi_crop <= dates_array) & (dates_array <= ddf_crop)

            # Positions (relative to the tracer) of the elements to keep
            crop_idx = np.nonzero(mask_crop)[0]
            new_tracer_dim = crop_idx.size
            if new_tracer_dim == 0:
                continue

            # Reducing memory usage with panda's categorical dtype
            # (this improves pickle's read/write times)
            tmp_ds = {
                "horiz_id": pd.Series(
                    horiz_id.flatten()[mask_crop]).astype('category'),
                "vert_id": pd.Series(
                    vert_id.flatten()[mask_crop]).astype('category'),
                "date": pd.Series(dates_array[mask_crop]).astype('category'),
            }
            del dates_id, vert_id, horiz_id

            # Index the control vector with the exact masked positions
            # rather than a contiguous slice: the flattened meshgrid order
            # does not guarantee that the selected dates are adjacent, and
            # the values must stay aligned with the date/vert/horiz columns
            cntrlv_idx = tracer.xpointer + crop_idx

            for var in var2read:
                if hasattr(cntrlv, var):
                    values = getattr(cntrlv, var)
                    if values.ndim != 1:
                        # Matrices are reduced to their diagonal
                        values = np.diag(values)
                    # Integer-array indexing returns a copy
                    tmp_ds[var] = values[cntrlv_idx]

            controlvect_ds[comp][trcr] = tmp_ds

            # Save pointers
            diminfos_ds[comp][trcr] = {
                "xpointer": new_cntrlv_dim,
                "dim": new_tracer_dim
            }

            # Save ensemble
            if save_ensemble:
                nsamples = cntrlv.x_ens.shape[0]
                for isample in range(nsamples):
                    trcr_sample = trcr + "__sample#{:03d}".format(isample)
                    sample_ds = {
                        col: copy.deepcopy(tmp_ds[col])
                        for col in ['date', 'horiz_id', 'vert_id']
                    }
                    sample_ds["x"] = cntrlv.x_ens[isample, cntrlv_idx]
                    controlvect_ds[comp][trcr_sample] = sample_ds

            new_cntrlv_dim += new_tracer_dim

    # Dump the dictionary to a pickle
    tosave["datastore"] = controlvect_ds
    tosave["dim_infos"] = diminfos_ds
    with open(cntrlv_file, "wb") as f:
        pickle.dump(tosave, f, pickle.HIGHEST_PROTOCOL)

    # Dump the controlvect indexes
    if dump_indexes:
        # NOTE(review): get_cntrlv_idx_comp excludes ddf_crop (date < ddf)
        # whereas the crop above includes it (date <= ddf); sizes may
        # differ from "dim_infos" -- confirm this is intended (see TODO)
        dict_components_idx = get_cntrlv_idx_comp(cntrlv, ddi_crop, ddf_crop)
        comp_idx_file = os.path.join(os.path.dirname(cntrlv_file), "comp_indexes.pickle")
        with open(comp_idx_file, "wb") as f:
            pickle.dump(dict_components_idx, f, pickle.HIGHEST_PROTOCOL)