# Source code for pycif.plugins.controlvects.standard.dump

import copy
import os
import tracemalloc
from logging import debug, warning

import numpy as np
import pandas as pd
import xarray as xr

from ....plugins.transforms.system.fromcontrol.utils.scalemaps \
    import scale2map
from ....utils.dataarrays.reindex import reindex
from ....utils.path import init_dir

try:
    import cPickle as pickle
except ImportError:
    import pickle

def _control_chunk_to_map(chunk, tracer):
    """Reshape a tracer's control-vector slice and pass it to ``scale2map``.

    ``chunk`` is the already-sliced 1D portion of a control-vector variable
    belonging to ``tracer``; it is reshaped to
    ``(tracer.ndates, tracer.vresoldim, -1)`` before the ``scale2map`` call.
    """
    reshaped = np.reshape(chunk, (tracer.ndates, tracer.vresoldim, -1))
    return scale2map(reshaped, tracer, tracer.dates, tracer.domain)


def dump(self, cntrl_file, to_netcdf=False, dir_netcdf=None, ensemble=False,
         **kwargs):
    """Dumps a control vector into a pickle file.
    Does not save large correlations.

    Args:
        self (pycif.utils.classes.controlvects.ControlVect):
                    the Control Vector to dump
        cntrl_file (str): path to the file to dump as pickle
        to_netcdf (bool): save to netcdf files if True
        dir_netcdf (str): root path for the netcdf directory; NetCDF files
                    are only written when ``to_netcdf`` is true AND this
                    is not None
        ensemble (bool): if True, also store each row of ``self.x_ens`` /
                    ``self.dx_ens`` under ``<tracer>__sample#NNN`` entries
        **kwargs: forwarded to the call that reads the reference inputs
                    for the "physical" projection

    Raises:
        ValueError: if ``ensemble`` is True but the control vector has
                    neither ``x_ens`` nor ``dx_ens``
        Exception: if the reference inputs and ``xb`` maps do not share
                    the same trailing (horizontal) dimensions

    """
    debug("Dumping the control vector to {}".format(cntrl_file))

    # Saving recursive attributes from the Yaml
    exclude = ["transform", "domain", "datastore",
               "input_dates", "obsvect", "tracer",
               "input_files", "tstep_dates", "tstep_all", "dataflx",
               "logfile", "datei", "datef", "workdir", "verbose",
               "subsimu_dates", "tcorrelations", "hcorrelations", "databos"]
    tosave = self.to_dict(self, exclude_patterns=exclude)

    # NetCDF output needs both the flag and a target directory
    write_netcdf = bool(to_netcdf) and dir_netcdf is not None

    # Variables with ensemble data ('x' ensemble data is named 'x_ens')
    ensemble_variables = ['x', 'dx']
    index_columns = ['date', 'horiz_id', 'vert_id']

    # Save the control vector as a pandas datastore
    controlvect_ds = {}
    diminfos_ds = {}

    components = self.datavect.components
    for comp in components.attributes:
        component = getattr(components, comp)

        # Skip if component does not have parameters
        if not hasattr(component, "parameters"):
            continue

        dir_comp = "{}/{}".format(dir_netcdf, comp)
        # Only create the output directory when NetCDF files will actually
        # be written; otherwise a stray "None/<comp>" folder would appear
        if write_netcdf:
            init_dir(dir_comp)

        for trcr in component.parameters.attributes:
            tracer = getattr(component.parameters, trcr)

            # Do nothing if not in control vector
            if not tracer.iscontrol:
                continue

            # Slice of the control vector owned by this tracer
            window = slice(tracer.xpointer, tracer.xpointer + tracer.dim)

            # Update controlvect_ds dictionary
            controlvect_ds.setdefault(comp, {})
            diminfos_ds.setdefault(comp, {})

            # Fetch information for tmp ds
            var2read = ["x", "xb", "dx", "std", "pa"]
            tmp_ds = {}
            for var in var2read:
                if not hasattr(self, var):
                    continue
                values = getattr(self, var)
                if values.ndim == 1:
                    tmp_ds[var] = values[window]
                else:
                    # Only the diagonal of non-1D variables is saved
                    tmp_ds[var] = copy.deepcopy(np.diag(values)[window])

            dates_id, vert_id, horiz_id = np.meshgrid(
                range(tracer.ndates),
                range(tracer.vresoldim),
                range(tracer.hresoldim))
            tmp_ds["horiz_id"] = horiz_id.flatten()
            del horiz_id
            tmp_ds["vert_id"] = vert_id.flatten()
            del vert_id
            tmp_ds["date"] = tracer.dates[dates_id].flatten()
            del dates_id

            # Reducing memory usage with panda's categorical dtype
            # (this improves pickle's read/write times)
            for col in index_columns:
                tmp_ds[col] = pd.Series(tmp_ds[col]).astype('category')

            controlvect_ds[comp][trcr] = tmp_ds

            # Save pointers
            diminfos_ds[comp][trcr] = {
                "xpointer": tracer.xpointer,
                "dim": tracer.dim
            }

            # Saving ensemble data
            if ensemble:
                # Getting number of samples
                n_samples = 0
                for var_name in ensemble_variables:
                    ens_var_name = f'{var_name}_ens'
                    if hasattr(self, ens_var_name):
                        n_samples = getattr(self, ens_var_name).shape[0]

                if n_samples == 0:
                    raise ValueError(
                        "No ensemble variables in the control vector for "
                        f"variable {comp}/{trcr}"
                    )

                # TODO: is it really useful to use separate variables for
                #   each sample? Why not save the samples matrix as is?
                for sample_index in range(n_samples):
                    trcr_sample = f"{trcr}__sample#{sample_index:03d}"

                    # TODO: may not be useful, controlvect.load gets those
                    #   columns from 'ds[comp][trcr]' and not from
                    #   ds[comp][trcr_sample]
                    #
                    # Copying tracer data in tracer sample
                    sample_entry = {
                        col: copy.deepcopy(controlvect_ds[comp][trcr][col])
                        for col in index_columns
                    }

                    for var_name in ensemble_variables:
                        ens_var_name = f'{var_name}_ens'
                        if not hasattr(self, ens_var_name):
                            continue

                        # Extracting sample data from ensemble data
                        ens_values = getattr(self, ens_var_name)
                        sample_entry[var_name] = \
                            ens_values[sample_index, window]

                    controlvect_ds[comp][trcr_sample] = sample_entry

            # Don't go further if no need to dump as netcdf
            if not write_netcdf:
                continue

            debug("Dumping control vector as NetCDF for {}/{}"
                  .format(comp, trcr))

            # Translating x and xb to maps
            x = _control_chunk_to_map(self.x[window], tracer)
            xb = _control_chunk_to_map(self.xb[window], tracer)
            std = _control_chunk_to_map(self.std[window], tracer)
            dx = _control_chunk_to_map(self.dx[window], tracer)

            # Adding the diagonal of posterior uncertainties if available
            if hasattr(self, "pa"):
                # NOTE(review): the operands of this test were lost in the
                # extracted source; reconstructed as "take the diagonal when
                # pa is a matrix", matching how 'pa' is handled for the
                # pickle above - confirm against the upstream file.
                if self.pa.ndim == 2:
                    pa = np.diag(self.pa)
                else:
                    pa = self.pa

                pa_std = _control_chunk_to_map(np.sqrt(pa[window]), tracer)
                ds = xr.Dataset({"x": x, "xb": xb, "dx": dx,
                                 "b_std": std, "pa_std": pa_std})
            else:
                ds = xr.Dataset({"x": x, "xb": xb,
                                 "dx": dx, "b_std": std})

            # If tracer is scalar, also include the "physical" projection
            if getattr(tracer, "type", "scalar") == "scalar" \
                    and getattr(tracer, "dump_physical", True):
                # Read the tracer array and apply the present control vector
                # scaling factor
                # Apply same protocol as ini_mapper from transform
                # "fromcontrol" to find correct dates
                # (merged input_dates and tracer dates)
                ds_phys = None
                for di in tracer.input_dates:
                    period_dates = tracer.input_dates[di]

                    # Skip if input dates are empty for some reason for that
                    # period
                    if len(period_dates) == 0:
                        continue

                    outdates = pd.DatetimeIndex(np.sort(np.unique(np.append(
                        period_dates, tracer.dates
                    )))).to_pydatetime()
                    if len(outdates) == 1:
                        outdates = np.append(outdates, outdates)

                    # Keep only dates within the span of this period's
                    # inputs (the period is non-empty at this point, so no
                    # empty-list special case is needed)
                    mask_min = outdates >= np.min(period_dates)
                    mask_max = outdates <= np.max(period_dates)
                    outdates = outdates[mask_min & mask_max]
                    outdates = pd.to_datetime(outdates)

                    # Read reference inputs
                    # NOTE(review): the callee name was lost in the
                    # extracted source; 'tracer.read' is a reconstruction -
                    # confirm against the upstream file.
                    inputs = tracer.read(
                        trcr,
                        tracer.varname,
                        period_dates,
                        tracer.input_files[di],
                        comp_type=comp,
                        tracer=tracer,
                        ddi=di,
                        model=self.model,
                        **kwargs
                    )

                    # Check that horizontal dimensions are compatible
                    input_dims = inputs.shape[2:]
                    xb_dims = xb.shape[2:]
                    if input_dims != xb_dims:
                        raise Exception(
                            "Dimensions for inputs and xb are not compatible. \n"
                            "This can arise if `is_lbc` has erroneously been set to True in your yaml.\n"
                            f"\t- Input dimension: {input_dims}\n"
                            f"\t- Xb dimension: {xb_dims}"
                        )

                    # Reindex xb, x and inputs to common outdates
                    inputs = reindex(
                        inputs,
                        levels={"time": outdates[:-1]},
                    )
                    common_levels = {"time": outdates[:-1],
                                     "lev": inputs.lev}
                    xb_phys = inputs * reindex(xb, levels=common_levels)
                    x_phys = inputs * reindex(x, levels=common_levels)
                    dx_phys = inputs * reindex(dx, levels=common_levels)
                    b_phys = inputs * reindex(std, levels=common_levels)

                    ds_tmp = xr.Dataset({
                        "x_phys": x_phys,
                        "dx_phys": dx_phys,
                        "xb_phys": xb_phys,
                        "b_phys": b_phys})
                    if hasattr(self, "pa"):
                        pa_phys = \
                            inputs * reindex(pa_std, levels=common_levels)
                        ds_tmp = ds_tmp.assign(pa_phys=pa_phys)

                    if ds_phys is None:
                        ds_phys = ds_tmp
                    else:
                        ds_phys = xr.concat([ds_phys, ds_tmp],
                                            dim="time", join="inner")

                # ds_phys stays None when every period was skipped above;
                # there is then nothing to merge
                if ds_phys is not None:
                    # Drop duplicated times
                    index_unique = np.unique(
                        ds_phys["time"], return_index=True)[1]
                    ds_phys = ds_phys.isel({"time": index_unique})

                    # Merge with non-physical values
                    ds_phys = ds_phys.rename({"time": "time_phys"})
                    ds = ds.merge(ds_phys)

            # Adding longitudes and latitudes
            if not getattr(tracer, "is_lbc", False):
                ds = ds.assign(
                    latitudes=(("lat", "lon"), tracer.domain.zlat),
                    longitudes=(("lat", "lon"), tracer.domain.zlon),
                    latitudes_corner=(("latc", "lonc"),
                                      tracer.domain.zlatc),
                    longitudes_corner=(("latc", "lonc"),
                                       tracer.domain.zlonc))
            else:
                ds = ds.assign(
                    latitudes=(("lat", "lon"), tracer.domain.zlat_side),
                    longitudes=(("lat", "lon"), tracer.domain.zlon_side),
                    latitudes_corner=(("latc", "lonc"),
                                      tracer.domain.zlatc_side),
                    longitudes_corner=(("latc", "lonc"),
                                       tracer.domain.zlonc_side))

            # Adding areas
            if not getattr(tracer, "is_lbc", False):
                if not hasattr(tracer.domain, "areas"):
                    tracer.domain.calc_areas()
                ds = ds.assign(areas=(("lat", "lon"), tracer.domain.areas))

            # Dumping (an existing file is removed first, then rewritten)
            controlvect_file = \
                "{}/controlvect_{}_{}.nc".format(dir_comp, comp, trcr)
            if os.path.exists(controlvect_file):
                os.remove(controlvect_file)
            ds.to_netcdf(controlvect_file)

    # Dumping the dictionary to a pickle
    tosave["datastore"] = controlvect_ds
    tosave["dim_infos"] = diminfos_ds
    with open(cntrl_file, "wb") as f:
        pickle.dump(tosave, f, pickle.HIGHEST_PROTOCOL)
def load(self, cntrl_file, component2load=None, tracer2load=None,
         target_tracer=None, ensemble=False, **kwargs):
    """Loads a control vector from a pickle file written by ``dump``.

    The ``x``, ``xb``, ``dx``, ``std`` and ``pa`` columns found in the pickle
    are written into the matching attributes of ``self`` (created as zero
    arrays of size ``self.dim`` when missing), at the position given by each
    tracer's ``xpointer`` / ``dim``.

    Args:
        self (pycif.utils.classes.controlvects.ControlVect):
                    the Control Vector to fill
        cntrl_file (str): path to the pickle file to read
        component2load (str): if given, only this component is loaded;
                    otherwise all components of ``self.datavect`` are
        tracer2load (str): if given, only this tracer is read from the
                    pickle; otherwise all parameters of the component are
        target_tracer (str): if given, the tracer of ``self`` with this name
                    receives the data read under the pickled tracer name
        ensemble (bool): if True, also load ``<tracer>__sample#NNN``
                    entries into ``self.x_ens`` / ``self.dx_ens``
        **kwargs: unused here

    Returns:
        the result of ``self.from_dict`` applied to the pickled dictionary
        (without its "datastore" entry)

    Raises:
        ValueError: with ``ensemble=True``, if no sample or more than 1000
                    samples are found for a tracer, or if an ensemble
                    variable is missing from some of the samples

    """
    debug("Loading control vector from {}".format(cntrl_file))

    if ensemble and component2load is None and tracer2load is None:
        warning("Trying to load control vector ensemble data without "
                "specifying some specific component and tracer to load")

    # NOTE(security): unpickling executes arbitrary code; only load control
    # vector files coming from a trusted source.
    with open(cntrl_file, "rb") as f:
        toread = pickle.load(f)

    out_ds = toread.pop("datastore")
    out = self.from_dict(toread)

    # Variables with ensemble data ('x' ensemble data is named 'x_ens')
    ensemble_variables = ['x', 'dx']
    index_columns = ['date', 'horiz_id', 'vert_id']

    # Loop over components and tracers
    components = self.datavect.components
    list_components = components.attributes if component2load is None \
        else [component2load]
    for comp in list_components:
        component = getattr(components, comp)

        # Skip if component does not have parameters
        if not hasattr(component, "parameters"):
            continue

        # Skip if component not in pickle
        if comp not in out_ds:
            warning("Could not read component '{}' for pickle {}. "
                    .format(comp, cntrl_file))
            continue

        comp_ds = out_ds[comp]
        list_tracers = component.parameters.attributes \
            if tracer2load is None else [tracer2load]
        for trcr in list_tracers:
            # The receiving tracer may differ from the pickled one
            tracer = getattr(
                component.parameters,
                trcr if target_tracer is None else target_tracer)

            # Do nothing if not in control vector
            if not tracer.iscontrol:
                continue

            # Skip if tracer not in pickle
            if trcr not in comp_ds:
                warning("Could not read tracer '{}/{}' for pickle {}. "
                        .format(comp, trcr, cntrl_file))
                continue

            debug("Loading variable {}/{}".format(comp, trcr))

            # Slice of the control vector owned by this tracer
            window = slice(tracer.xpointer, tracer.xpointer + tracer.dim)

            # Fill the correct chunk with corresponding values
            dates_id, vert_id, horiz_id = np.meshgrid(
                range(tracer.ndates),
                range(tracer.vresoldim),
                range(tracer.hresoldim))
            dates_id = tracer.dates[dates_id]
            target_index = pd.MultiIndex.from_arrays([
                dates_id.flatten(),
                horiz_id.flatten(),
                vert_id.flatten()
            ])

            tmp_ds = pd.DataFrame(comp_ds[trcr])

            # Removing categorical dtypes (speeds up the reindexing)
            # NOTE(review): the right-hand side of the 'date' conversion was
            # truncated in the extracted source; reconstructed to mirror the
            # two conversions below - confirm against the upstream file.
            tmp_ds['date'] = tmp_ds.date.astype('datetime64[ns]')
            tmp_ds['horiz_id'] = tmp_ds.horiz_id.astype('int')
            tmp_ds['vert_id'] = tmp_ds.vert_id.astype('int')

            if ensemble:
                # Keeping index data for ensemble variables
                index_data = tmp_ds[index_columns].copy()

            # Reindexing
            tmp_ds.set_index(index_columns, inplace=True)
            tmp_ds = tmp_ds.reindex(target_index, copy=False)

            # Loop over variables to initialize
            var2read = ["x", "xb", "dx", "std", "pa"]
            for var in var2read:
                if var not in tmp_ds:
                    continue

                # Initialize variable to zero in controlvect if not
                # already here
                array = getattr(self, var, None)
                if array is None:
                    array = np.zeros(self.dim)
                array[window] = tmp_ds[var].values
                setattr(self, var, array)

            # Loading ensemble data
            if not ensemble:
                continue

            # Getting variable names and number of samples: samples are
            # expected under consecutive indexes starting from 0
            sample_names = []
            max_number_of_samples = 1000
            for sample_index in range(max_number_of_samples + 1):
                # Sample variable name
                trcr_sample = f"{trcr}__sample#{sample_index:03d}"
                if trcr_sample in comp_ds:
                    sample_names.append(trcr_sample)
                else:
                    n_samples = sample_index
                    break
            else:
                # Branch here when reaching the end of the for loop without
                # breaking, i.e. when max number of samples is reached.
                raise ValueError(
                    f"Too many sample in ensemble for variable {comp}/{trcr}")

            if n_samples == 0:
                raise ValueError(
                    f"No samples detected in the control vector for variable {comp}/{trcr}")

            # Collecting, for each ensemble variable, one column per sample
            ens_data = {
                var_name: [
                    comp_ds[trcr_sample][var_name]
                    for trcr_sample in sample_names
                    if var_name in comp_ds[trcr_sample]
                ]
                for var_name in ensemble_variables
            }

            # Iterating over ensemble variables to format and save data
            for var_name in ensemble_variables:
                if not ens_data[var_name]:
                    continue

                # Concatenating samples data
                var = np.concatenate(
                    [col[:, np.newaxis] for col in ens_data[var_name]],
                    axis=1
                )
                # Explicit check (not an assert, which "python -O" removes)
                if var.shape[1] != n_samples:
                    raise ValueError(
                        f"Ensemble variable '{var_name}' is only available "
                        f"in {var.shape[1]} of the {n_samples} samples for "
                        f"variable {comp}/{trcr}")

                # Formatting ensemble data
                tmp_sample_ds = pd.DataFrame(
                    data=var, columns=sample_names)
                tmp_sample_ds = pd.concat(
                    [index_data, tmp_sample_ds], axis='columns')
                tmp_sample_ds.set_index(index_columns, inplace=True)
                tmp_sample_ds = tmp_sample_ds.reindex(
                    target_index, copy=False)

                # Initialize variable to zero in controlvect if not
                # already here
                ens_var_name = f'{var_name}_ens'
                ens_var = getattr(self, ens_var_name, None)
                if ens_var is None:
                    ens_var = np.zeros((n_samples, self.dim))

                # Saving formatted data
                ens_var[:, window] = tmp_sample_ds.values.T
                setattr(self, ens_var_name, ens_var)

    debug("Successfully loaded control vector from {}".format(cntrl_file))

    return out