# Source code for pycif.plugins.modes.ensrf.segment_cropping

import copy
import numpy as np
import pandas as pd
import os
from logging import debug, warning

try:
    import cPickle as pickle
except ImportError:
    import pickle

def select_obs2assimilate(obsvect, hx_samples_dev, mask_obs_assimilated):
    """Select the observations to assimilate in the current segment.

    An observation is retained when all of the following hold:

    1) it is flagged in ``obsvect.obsvect_mask``;
    2) its row of ``hx_samples_dev`` contains no NaN (a row with at least
       one NaN is considered to be outside the current segment);
    3) it is not flagged in ``mask_obs_assimilated``.

    None of the inputs is modified.

    Args:
        obsvect (Plugin): obsvect Plugin; only its ``obsvect_mask``
            attribute is read
        hx_samples_dev (np.array): 2-D array of deviations of the simulated
            values from the mean; NaNs are searched along axis 1
            (presumably one row per observation and one column per
            member -- TODO confirm against the caller)
        mask_obs_assimilated (np.array): boolean mask indicating the
            observations already assimilated

    Returns:
        tuple: ``(mask_obs_segment, mask_obs2assim)`` where

        - mask_obs_segment (np.array): True where the row of
          ``hx_samples_dev`` is free of NaNs
        - mask_obs2assim (np.array): True for the observations that must
          be assimilated (conditions 1 to 3 above)
    """
    # Mask to select observations within the segment
    mask_obs_segment = ~np.any(np.isnan(hx_samples_dev), axis=1)

    # TODO: detect Nans between good values, means there is a problem with
    #  the model
    # TODO: detect the last good value and if there is a nan before,
    #  then problem

    # Combine the user selection, the segment and the assimilation history
    mask_obs2assim = (
        obsvect.obsvect_mask & mask_obs_segment & ~mask_obs_assimilated
    )

    return mask_obs_segment, mask_obs2assim
def get_cntrlv_trdates(controlvect):
    """Collect, per component and per controlled tracer, the date of every
    element of the tracer.

    Components without a ``parameters`` attribute get an empty entry, and
    tracers whose ``iscontrol`` flag is false are left out.

    Args:
        controlvect (Plugin): controlvect Plugin

    Returns:
        dict: ``{component name: {tracer name: np.array of dates}}``; each
        array has ``ndates * vresoldim * hresoldim`` elements.
    """
    components = controlvect.datavect.components

    dates_by_component = {}
    for comp_name in components.attributes:
        component = getattr(components, comp_name)
        tracer_dates = {}
        dates_by_component[comp_name] = tracer_dates

        # Nothing to fetch for components that carry no parameters
        if not hasattr(component, "parameters"):
            continue

        parameters = component.parameters
        for trcr_name in parameters.attributes:
            tracer = getattr(parameters, trcr_name)
            if not tracer.iscontrol:
                continue

            # Only the date index grid is needed here.
            # NOTE(review): np.meshgrid defaults to indexing='xy', so the
            # grid has shape (vresoldim, ndates, hresoldim) -- confirm this
            # matches the control vector layout when vresoldim > 1.
            date_idx = np.meshgrid(
                range(tracer.ndates),
                range(tracer.vresoldim),
                range(tracer.hresoldim))[0]
            tracer_dates[trcr_name] = tracer.dates[date_idx].flatten()

    return dates_by_component
def get_cntrlv_idx_segment(controlvect, ddi, ddf):
    """Return the control vector indexes whose date lies in the segment.

    The segment is the half-open window ``ddi <= date < ddf``.

    Args:
        controlvect (Plugin): controlvect Plugin
        ddi (datetime): initial date of the segment (included)
        ddf (datetime): end date of the segment (excluded)

    Returns:
        np.array: indexes, in the full control vector of size
        ``controlvect.dim``, of the elements dated within the segment
    """
    # Start with nothing selected, then switch on each tracer's share
    in_segment = np.zeros(controlvect.dim, dtype=bool)

    dates_by_component = get_cntrlv_trdates(controlvect)
    components = controlvect.datavect.components
    for comp_name in components.attributes:
        component = getattr(components, comp_name)
        for trcr_name, dates_array in dates_by_component[comp_name].items():
            tracer = getattr(component.parameters, trcr_name)

            # The tracer occupies [xpointer, xpointer + dim) in the vector
            first = tracer.xpointer
            last = first + tracer.dim
            in_segment[first:last] |= (ddi <= dates_array) & (dates_array < ddf)

    return np.nonzero(in_segment)[0]
def get_cntrlv_idx_comp(controlvect, ddi, ddf):
    """Map each controlled tracer to its positions in the cropped vector.

    Tracers are visited in component/tracer order, and each one with at
    least one date in ``ddi <= date < ddf`` receives the next run of
    consecutive indexes, starting from 0. Tracers with no date in the window
    are omitted; components without ``parameters`` get an empty entry.

    Args:
        controlvect (Plugin): controlvect Plugin
        ddi (datetime): initial date of the segment (included)
        ddf (datetime): end date of the segment (excluded)

    Returns:
        dict: ``{component name: {tracer name: np.array of indexes}}``
    """
    # NOTE(review): the upper bound is open here whereas crop_dump crops
    # with a closed one (dates_array <= ddf_crop) -- confirm this is wanted.
    components = controlvect.datavect.components

    idx_by_component = {}
    offset = 0  # size of the cropped control vector built so far
    for comp_name in components.attributes:
        component = getattr(components, comp_name)
        tracer_idx = {}
        idx_by_component[comp_name] = tracer_idx

        if not hasattr(component, "parameters"):
            continue

        parameters = component.parameters
        for trcr_name in parameters.attributes:
            tracer = getattr(parameters, trcr_name)
            if not tracer.iscontrol:
                continue

            date_idx = np.meshgrid(
                range(tracer.ndates),
                range(tracer.vresoldim),
                range(tracer.hresoldim))[0]
            dates_array = tracer.dates[date_idx].flatten()

            in_window = (ddi <= dates_array) & (dates_array < ddf)
            n_kept = dates_array[in_window].size
            if n_kept == 0:
                continue

            tracer_idx[trcr_name] = np.arange(offset, offset + n_kept)
            offset += n_kept

    return idx_by_component
def crop_dump(cntrlv, cntrlv_file, ddi_crop, ddf_crop,
              ensemble=False, dump_indexes=False, **kwargs):
    """Dump the control vector into a pickle file.

    Only the elements dated inside the crop window
    ``ddi_crop <= date <= ddf_crop`` are dumped. For every controlled tracer
    with at least one element in the window, the dump holds the ``date``,
    ``horiz_id`` and ``vert_id`` of the kept elements together with the
    matching values of ``x``, ``xb``, ``dx``, ``std`` and ``pa`` (whichever
    exist on ``cntrlv``; for arrays that are not 1-D the diagonal is used).

    Values are picked with the positions of the crop mask itself, so they
    stay aligned with the dumped dates/ids even when the selected elements
    are not contiguous in the control vector.

    Args:
        cntrlv (pycif.utils.classes.controlvects.ControlVect): the Control
            Vector to dump
        cntrlv_file (str): path to the file to dump as pickle
        ddi_crop (datetime): cropping start date (included)
        ddf_crop (datetime): cropping end date (included)
        ensemble (bool): if True and ``cntrlv`` has an ``x_ens`` attribute,
            also dump each row of ``x_ens`` under the tracer name suffixed
            with ``__sample#NNN``
        dump_indexes (bool): if True, also dump the result of
            ``get_cntrlv_idx_comp`` to ``comp_indexes.pickle`` in the same
            directory as ``cntrlv_file``
        **kwargs: accepted for compatibility and ignored
    """
    debug("Cropping and dumping the control vector to {}".format(cntrlv_file))

    # Saving recursive attributes from the Yaml
    exclude = ["transform", "domain", "datastore",
               "input_dates", "obsvect", "tracer",
               "input_files", "tstep_dates", "tstep_all",
               "dataflx", "logfile", "datei", "datef",
               "workdir", "verbose", "subsimu_dates",
               "tcorrelations", "hcorrelations", "databos"]
    tosave = cntrlv.to_dict(cntrlv, exclude_patterns=exclude)

    # Loop invariants
    id_columns = ("date", "horiz_id", "vert_id")
    var2read = ["x", "xb", "dx", "std", "pa"]
    dump_ensemble = ensemble and hasattr(cntrlv, "x_ens")

    # Save the control vector as a pandas datastore
    controlvect_ds = {}
    diminfos_ds = {}
    components = cntrlv.datavect.components
    new_cntrlv_dim = 0
    for comp in components.attributes:
        component = getattr(components, comp)

        # Skip if component does not have parameters
        if not hasattr(component, "parameters"):
            continue

        for trcr in component.parameters.attributes:
            tracer = getattr(component.parameters, trcr)

            # Do nothing if not in control vector
            if not tracer.iscontrol:
                continue

            # The component entry exists as soon as it has one controlled
            # tracer, even if nothing of that tracer falls in the window
            controlvect_ds.setdefault(comp, {})
            diminfos_ds.setdefault(comp, {})

            dates_id, vert_id, horiz_id = np.meshgrid(
                range(tracer.ndates),
                range(tracer.vresoldim),
                range(tracer.hresoldim))
            dates_array = tracer.dates[dates_id].flatten()

            # TODO: open end for ddf but close end for the controlvect idxs
            mask_crop = (ddi_crop <= dates_array) & (dates_array <= ddf_crop)
            crop_idx = np.nonzero(mask_crop)[0]
            new_tracer_dim = crop_idx.size
            if new_tracer_dim == 0:
                continue

            tmp_ds = {
                "horiz_id": horiz_id.flatten()[mask_crop],
                "vert_id": vert_id.flatten()[mask_crop],
                "date": dates_array[mask_crop],
            }
            del dates_id, vert_id, horiz_id

            # Reducing memory usage with panda's categorical dtype
            # (this improves pickle's read/write times)
            for col in id_columns:
                tmp_ds[col] = pd.Series(tmp_ds[col]).astype('category')

            # Positions of the kept elements in the full control vector.
            # Indexing with them (rather than slicing from the first one)
            # keeps values aligned with the ids for non-contiguous masks;
            # it also returns fresh arrays, so no extra copy is needed.
            x_idx = tracer.xpointer + crop_idx
            for var in var2read:
                if not hasattr(cntrlv, var):
                    continue
                values = getattr(cntrlv, var)
                if values.ndim == 1:
                    tmp_ds[var] = values[x_idx]
                else:
                    tmp_ds[var] = np.diag(values)[x_idx]

            controlvect_ds[comp][trcr] = tmp_ds

            # Save pointers (position and size in the cropped vector)
            diminfos_ds[comp][trcr] = {
                "xpointer": new_cntrlv_dim,
                "dim": new_tracer_dim
            }

            # Save ensemble
            if dump_ensemble:
                for isample in range(cntrlv.x_ens.shape[0]):
                    trcr_sample = trcr + "__sample#{:03d}".format(isample)
                    sample_ds = {
                        col: copy.deepcopy(tmp_ds[col]) for col in id_columns
                    }
                    sample_ds["x"] = cntrlv.x_ens[isample, x_idx]
                    controlvect_ds[comp][trcr_sample] = sample_ds

            new_cntrlv_dim += new_tracer_dim

    # Dump the dictionary to a pickle
    tosave["datastore"] = controlvect_ds
    tosave["dim_infos"] = diminfos_ds
    with open(cntrlv_file, "wb") as f:
        pickle.dump(tosave, f, pickle.HIGHEST_PROTOCOL)

    # Dump the controlvect indexes
    # NOTE(review): get_cntrlv_idx_comp uses an open upper bound
    # (date < ddf) while the crop above is closed -- confirm this is wanted.
    if dump_indexes:
        dict_components_idx = get_cntrlv_idx_comp(cntrlv, ddi_crop, ddf_crop)
        comp_idx_file = os.path.join(os.path.dirname(cntrlv_file),
                                     "comp_indexes.pickle")
        with open(comp_idx_file, "wb") as f:
            pickle.dump(dict_components_idx, f, pickle.HIGHEST_PROTOCOL)