Source code for pycif.plugins.controlvects.standard.dump

import copy
import numpy as np
import xarray as xr
import pandas as pd
from logging import debug, warning
import tracemalloc
from ....utils.path import init_dir
from ....plugins.transforms.system.fromcontrol.utils.scalemaps \
    import scale2map
from ....utils.dataarrays.reindex import reindex

try:
    import cPickle as pickle
except ImportError:
    import pickle


def dump(self, cntrl_file, to_netcdf=False, dir_netcdf=None,
         ensemble=False, **kwargs):
    """Dumps a control vector into a pickle file.
    Does not save large correlations.

    Args:
        self (pycif.utils.classes.controlvects.ControlVect): the Control
            Vector to dump
        cntrl_file (str): path to the file to dump as pickle
        to_netcdf (bool): save to netcdf files if True
        dir_netcdf (str): root path for the netcdf directory
        ensemble (bool): if True, also dump the ensemble members stored
            in self.x_ens
    """
    debug("Dumping the control vector to {}".format(cntrl_file))

    # Saving recursive attributes from the Yaml
    exclude = ["transform", "domain", "datastore", "input_dates",
               "obsvect", "tracer", "input_files",
               "tstep_dates", "tstep_all", "dataflx", "logfile",
               "datei", "datef", "workdir", "verbose",
               "subsimu_dates", "tcorrelations", "hcorrelations", "databos"]
    tosave = self.to_dict(self, exclude_patterns=exclude)

    # Save the control vector as a pandas datastore
    controlvect_ds = {}
    diminfos_ds = {}
    components = self.datavect.components
    for comp in components.attributes:
        component = getattr(components, comp)

        # Skip if component does not have parameters
        if not hasattr(component, "parameters"):
            continue

        dir_comp = "{}/{}".format(dir_netcdf, comp)
        init_dir(dir_comp)

        for trcr in component.parameters.attributes:
            tracer = getattr(component.parameters, trcr)

            # Do nothing if not in control vector
            if not tracer.iscontrol:
                continue

            # Update controlvect_ds dictionary
            if comp not in controlvect_ds:
                controlvect_ds[comp] = {}
            if comp not in diminfos_ds:
                diminfos_ds[comp] = {}

            # Fetch information for tmp ds
            var2read = ["x", "xb", "dx", "std", "pa"]
            tmp_ds = {}
            for var in var2read:
                if hasattr(self, var):
                    if getattr(self, var).ndim == 1:
                        tmp_ds[var] = \
                            getattr(self, var)[
                                tracer.xpointer:
                                tracer.xpointer + tracer.dim]
                    else:
                        tmp_ds[var] = copy.deepcopy(
                            np.diag(getattr(self, var))[
                                tracer.xpointer:
                                tracer.xpointer + tracer.dim])

            dates_id, vert_id, horiz_id = np.meshgrid(
                range(tracer.ndates),
                range(tracer.vresoldim),
                range(tracer.hresoldim))

            tmp_ds["horiz_id"] = horiz_id.flatten()
            del horiz_id
            tmp_ds["vert_id"] = vert_id.flatten()
            del vert_id
            tmp_ds["date"] = tracer.dates[dates_id].flatten()
            del dates_id

            # Reducing memory usage with pandas' categorical dtype
            # (this improves pickle's read/write times)
            for col in ['date', 'horiz_id', 'vert_id']:
                tmp_ds[col] = pd.Series(tmp_ds[col]).astype('category')

            controlvect_ds[comp][trcr] = tmp_ds

            # Save pointers
            diminfos_ds[comp][trcr] = {
                "xpointer": tracer.xpointer,
                "dim": tracer.dim
            }

            # Save ensemble
            if ensemble and hasattr(self, 'x_ens'):
                nsamples = self.x_ens.shape[0]
                for isample in range(nsamples):
                    trcr_sample = trcr + "__sample#{:03d}".format(isample)
                    controlvect_ds[comp][trcr_sample] = {}
                    for col in ['date', 'horiz_id', 'vert_id']:
                        controlvect_ds[comp][trcr_sample][col] = \
                            copy.deepcopy(controlvect_ds[comp][trcr][col])

                    xsample = self.x_ens[
                        isample,
                        tracer.xpointer: tracer.xpointer + tracer.dim]
                    controlvect_ds[comp][trcr_sample]["x"] = xsample

            # Don't go further if no need to dump as netcdf
            if not to_netcdf or dir_netcdf is None:
                continue

            debug("Dumping control vector as NetCDF for {}/{}"
                  .format(comp, trcr))

            # Translating x and xb to maps
            x = np.reshape(
                self.x[tracer.xpointer: tracer.xpointer + tracer.dim],
                (tracer.ndates, tracer.vresoldim, -1),
            )
            x = scale2map(x, tracer, tracer.dates, tracer.domain)

            xb = np.reshape(
                self.xb[tracer.xpointer: tracer.xpointer + tracer.dim],
                (tracer.ndates, tracer.vresoldim, -1),
            )
            xb = scale2map(xb, tracer, tracer.dates, tracer.domain)

            std = np.reshape(
                self.std[tracer.xpointer: tracer.xpointer + tracer.dim],
                (tracer.ndates, tracer.vresoldim, -1),
            )
            std = scale2map(std, tracer, tracer.dates, tracer.domain)

            dx = np.reshape(
                self.dx[tracer.xpointer: tracer.xpointer + tracer.dim],
                (tracer.ndates, tracer.vresoldim, -1),
            )
            dx = scale2map(dx, tracer, tracer.dates, tracer.domain)

            # Adding the diagonal of posterior uncertainties if available
            if hasattr(self, "pa"):
                if self.pa.ndim == 2:
                    pa = np.diag(self.pa)
                else:
                    pa = self.pa

                pa_std = np.reshape(
                    np.sqrt(pa[tracer.xpointer:
                               tracer.xpointer + tracer.dim]),
                    (tracer.ndates, tracer.vresoldim, -1),
                )
                pa_std = scale2map(pa_std, tracer, tracer.dates,
                                   tracer.domain)

                ds = xr.Dataset({"x": x, "xb": xb, "dx": dx,
                                 "b_std": std, "pa_std": pa_std})
            else:
                ds = xr.Dataset({"x": x, "xb": xb, "dx": dx, "b_std": std})

            # If tracer is scalar, also include the "physical" projection
            if getattr(tracer, "type", "scalar") == "scalar" \
                    and getattr(tracer, "dump_physical", True):
                # Read the tracer array and apply the present control
                # vector scaling factor.
                # Apply same protocol as ini_mapper from transform
                # "fromcontrol" to find correct dates
                # (merged input_dates and tracer dates)
                ds_phys = None
                for di in tracer.input_dates:
                    # Skip if input dates are empty for some reason
                    # for that period
                    if len(tracer.input_dates[di]) == 0:
                        continue

                    outdates = pd.DatetimeIndex(np.sort(np.unique(np.append(
                        tracer.input_dates[di], tracer.dates
                    )))).to_pydatetime()

                    if len(outdates) == 1:
                        outdates = np.append(outdates, outdates)

                    mask_min = np.zeros(len(outdates), dtype=bool) \
                        if tracer.input_dates[di] == [] \
                        else outdates >= np.min(tracer.input_dates[di])
                    mask_max = np.zeros(len(outdates), dtype=bool) \
                        if tracer.input_dates[di] == [] \
                        else outdates <= np.max(tracer.input_dates[di])
                    outdates = outdates[mask_min & mask_max]
                    outdates = pd.to_datetime(outdates)

                    # Read reference inputs
                    inputs = tracer.read(
                        trcr, tracer.varname,
                        tracer.input_dates[di],
                        tracer.input_files[di],
                        comp_type=comp,
                        tracer=tracer,
                        ddi=di,
                        model=self.model,
                        **kwargs
                    )

                    # Reindex xb, x and inputs to common outdates
                    inputs = reindex(
                        inputs,
                        levels={"time": outdates[:-1]},
                    )
                    xb_phys = inputs * reindex(
                        xb,
                        levels={"time": outdates[:-1], "lev": inputs.lev},
                    )
                    x_phys = inputs * reindex(
                        x,
                        levels={"time": outdates[:-1], "lev": inputs.lev},
                    )
                    dx_phys = inputs * reindex(
                        dx,
                        levels={"time": outdates[:-1], "lev": inputs.lev},
                    )
                    b_phys = inputs * reindex(
                        std,
                        levels={"time": outdates[:-1], "lev": inputs.lev},
                    )

                    ds_tmp = xr.Dataset({
                        "x_phys": x_phys, "dx_phys": dx_phys,
                        "xb_phys": xb_phys, "b_phys": b_phys})

                    if hasattr(self, "pa"):
                        pa_phys = inputs * reindex(
                            pa_std,
                            levels={"time": outdates[:-1],
                                    "lev": inputs.lev},
                        )
                        ds_tmp = ds_tmp.assign(pa_phys=pa_phys)

                    if ds_phys is None:
                        ds_phys = ds_tmp
                    else:
                        ds_phys = xr.concat([ds_phys, ds_tmp],
                                            dim="time", join="inner")

                # Drop duplicated times
                index_unique = np.unique(
                    ds_phys["time"], return_index=True)[1]
                ds_phys = ds_phys.isel({"time": index_unique})

                # Merge with non-physical values
                ds_phys = ds_phys.rename({"time": "time_phys"})
                ds = ds.merge(ds_phys)

            # Adding longitudes and latitudes
            if not getattr(tracer, "is_lbc", False):
                ds = ds.assign(
                    latitudes=(("lat", "lon"), tracer.domain.zlat),
                    longitudes=(("lat", "lon"), tracer.domain.zlon),
                    latitudes_corner=(("latc", "lonc"),
                                      tracer.domain.zlatc),
                    longitudes_corner=(("latc", "lonc"),
                                       tracer.domain.zlonc))
            else:
                ds = ds.assign(
                    latitudes=(("lat", "lon"), tracer.domain.zlat_side),
                    longitudes=(("lat", "lon"), tracer.domain.zlon_side),
                    latitudes_corner=(("latc", "lonc"),
                                      tracer.domain.zlatc_side),
                    longitudes_corner=(("latc", "lonc"),
                                       tracer.domain.zlonc_side))

            # Adding areas
            if not getattr(tracer, "is_lbc", False):
                if not hasattr(tracer.domain, "areas"):
                    tracer.domain.calc_areas()
                ds = ds.assign(areas=(("lat", "lon"), tracer.domain.areas))

            # Dumping
            ds.to_netcdf(
                "{}/controlvect_{}_{}.nc".format(dir_comp, comp, trcr))

    # Dumping the dictionary to a pickle
    tosave["datastore"] = controlvect_ds
    tosave["dim_infos"] = diminfos_ds
    with open(cntrl_file, "wb") as f:
        pickle.dump(tosave, f, pickle.HIGHEST_PROTOCOL)
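
# A minimal usage sketch (not part of pyCIF): dump an initialised control
# vector to a pickle plus per-tracer NetCDF files. The names `controlvect`
# and `workdir` are hypothetical placeholders for an existing ControlVect
# plugin instance and an existing working directory.
def _example_dump(controlvect, workdir):
    cntrl_file = "{}/controlvect.pickle".format(workdir)
    # Also writes one controlvect_<component>_<tracer>.nc file per tracer
    dump(controlvect, cntrl_file,
         to_netcdf=True, dir_netcdf="{}/controlvect".format(workdir))
    return cntrl_file
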

def load(self, cntrl_file, component2load=None, tracer2load=None,
         target_tracer=None, ensemble=False, **kwargs):
    """Loads a control vector from a pickle file written by dump.

    Args:
        self (pycif.utils.classes.controlvects.ControlVect): the Control
            Vector to fill
        cntrl_file (str): path to the pickle file to read
        component2load (str): if given, only load this component
        tracer2load (str): if given, only load this tracer
        target_tracer (str): name of the tracer in the current control
            vector that receives the loaded values
        ensemble (bool): if True, also load ensemble members into self.x_ens
    """
    debug("Loading control vector from {}".format(cntrl_file))

    with open(cntrl_file, "rb") as f:
        toread = pickle.load(f)

    out_ds = toread["datastore"]
    del toread["datastore"]

    out = self.from_dict(toread)

    # Loop over components and tracers
    components = self.datavect.components
    list_components = components.attributes if component2load is None \
        else [component2load]
    for comp in list_components:
        component = getattr(components, comp)

        # Skip if component does not have parameters
        if not hasattr(component, "parameters"):
            continue

        # Skip if component not in pickle
        if comp not in out_ds:
            warning("Could not read component '{}' for pickle {}. "
                    .format(comp, cntrl_file))
            continue

        comp_ds = out_ds[comp]

        list_tracers = component.parameters.attributes \
            if tracer2load is None else [tracer2load]
        for trcr in list_tracers:
            tracer = getattr(
                component.parameters,
                trcr if target_tracer is None else target_tracer)

            # Do nothing if not in control vector
            if not tracer.iscontrol:
                continue

            # Skip if tracer not in pickle
            if trcr not in comp_ds:
                warning("Could not read tracer '{}/{}' for pickle {}. "
                        .format(comp, trcr, cntrl_file))
                continue

            debug("Loading variable {}/{}".format(comp, trcr))

            # Fill the correct chunk with corresponding values
            dates_id, vert_id, horiz_id = np.meshgrid(
                range(tracer.ndates),
                range(tracer.vresoldim),
                range(tracer.hresoldim))
            dates_id = tracer.dates[dates_id]

            tmp_ds = pd.DataFrame(comp_ds[trcr])

            # Removing categorical dtypes (speeds up the reindexing)
            tmp_ds['date'] = tmp_ds.date.astype('datetime64[ns]')
            tmp_ds['horiz_id'] = tmp_ds.horiz_id.astype('int')
            tmp_ds['vert_id'] = tmp_ds.vert_id.astype('int')

            tmp_ds.set_index(["date", "horiz_id", "vert_id"], inplace=True)
            tmp_ds = tmp_ds.reindex(
                pd.MultiIndex.from_arrays([
                    dates_id.flatten(),
                    horiz_id.flatten(),
                    vert_id.flatten()
                ]),
                copy=False
            )

            # Loop over variables to initialize
            var2read = ["x", "xb", "dx", "std", "pa"]
            for var in var2read:
                if var not in tmp_ds:
                    continue

                # Initialize variable to zero in controlvect
                # if not already here
                if not hasattr(self, var):
                    setattr(self, var, np.zeros(self.dim))

                getattr(self, var)[
                    tracer.xpointer: tracer.xpointer + tracer.dim
                ] = tmp_ds[var].values[:]

            if ensemble:
                var = "x_ens"
                list_samples = sorted(
                    [tr for tr in comp_ds if f"{trcr}__sample#" in tr])
                nsamples = len(list_samples)
                if nsamples == 0:
                    warning("No samples detected in the controlvect "
                            "parameters.")
                    continue

                if not hasattr(self, var):
                    setattr(self, var, np.zeros((nsamples, self.dim)))

                tmp_ds_samples = pd.concat(
                    [pd.DataFrame(comp_ds[sample])["x"]
                     for sample in list_samples], axis=1)
                tmp_ds_samples.columns = list_samples
                tmp_ds_samples = pd.concat(
                    [tmp_ds_samples,
                     pd.DataFrame(comp_ds[list_samples[0]])[
                         ["date", "horiz_id", "vert_id"]
                     ]], axis=1)

                # Removing categorical dtypes (speeds up the reindexing)
                tmp_ds_samples['date'] = \
                    tmp_ds_samples.date.astype('datetime64[ns]')
                tmp_ds_samples['horiz_id'] = \
                    tmp_ds_samples.horiz_id.astype('int')
                tmp_ds_samples['vert_id'] = \
                    tmp_ds_samples.vert_id.astype('int')

                tmp_ds_samples.set_index(
                    ["date", "horiz_id", "vert_id"], inplace=True)
                tmp_ds_samples = tmp_ds_samples.reindex(
                    pd.MultiIndex.from_arrays([
                        dates_id.flatten(),
                        horiz_id.flatten(),
                        vert_id.flatten()
                    ]),
                    copy=False
                )

                getattr(self, var)[
                    :, tracer.xpointer: tracer.xpointer + tracer.dim
                ] = tmp_ds_samples.values.T

    debug("Successfully loaded control vector from {}".format(cntrl_file))

    return out
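
# A minimal usage sketch (not part of pyCIF): re-load a previously dumped
# control vector, restricting the read to one component/tracer pair. The
# names `controlvect`, "fluxes" and "CH4" are hypothetical placeholders and
# must match what was actually dumped.
def _example_load(controlvect, cntrl_file):
    # Fills controlvect.x, xb, dx, std (and pa if present) in place for the
    # selected tracer and returns the object rebuilt from the pickle
    return load(controlvect, cntrl_file,
                component2load="fluxes", tracer2load="CH4")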