Source code for pycif.plugins.controlvects.standard.dump

import copy

import numpy as np
import xarray as xr
import pandas as pd
from logging import debug, warning

from ....utils.path import init_dir
from ....plugins.transforms.system.fromcontrol.utils.scalemaps \
    import scale2map
from ....utils.dataarrays.reindex import reindex

try:
    import cPickle as pickle
except ImportError:
    import pickle


def dump(self, cntrl_file, to_netcdf=False, dir_netcdf=None, **kwargs):
    """Dumps a control vector into a pickle file.
    Does not save large correlations.

    Args:
        self (pycif.utils.classes.controlvects.ControlVect):
            the Control Vector to dump
        cntrl_file (str): path to the file to dump as pickle
        to_netcdf (bool): save to netcdf files if True
        dir_netcdf (str): root path for the netcdf directory

    """

    debug("Dumping the control vector to {}".format(cntrl_file))

    # Saving recursive attributes from the Yaml
    exclude = ["transform", "domain", "datastore", "input_dates",
               "obsvect", "tracer", "input_files", "tstep_dates",
               "tstep_all", "dataflx", "logfile", "datei", "datef",
               "workdir", "verbose", "subsimu_dates",
               "tcorrelations", "hcorrelations", "databos"]
    tosave = self.to_dict(self, exclude_patterns=exclude)

    # Save the control vector as a pandas datastore
    controlvect_ds = {}
    components = self.datavect.components
    for comp in components.attributes:
        component = getattr(components, comp)

        # Skip if component does not have parameters
        if not hasattr(component, "parameters"):
            continue

        dir_comp = "{}/{}".format(dir_netcdf, comp)
        init_dir(dir_comp)

        for trcr in component.parameters.attributes:
            tracer = getattr(component.parameters, trcr)

            # Do nothing if not in control vector
            if not tracer.iscontrol:
                continue

            # Update controlvect_ds dictionary
            if comp not in controlvect_ds:
                controlvect_ds[comp] = {}

            # Fetch information for tmp ds
            tmp_ds = pd.DataFrame(
                columns=["x", "xb", "dx", "std", "pa",
                         "component", "tracer",
                         "date", "horiz_id", "vert_id"])

            var2read = ["x", "xb", "dx", "std", "pa"]
            for var in var2read:
                if hasattr(self, var):
                    if getattr(self, var).ndim == 1:
                        tmp_ds.loc[:, var] = \
                            getattr(self, var)[
                                tracer.xpointer:
                                tracer.xpointer + tracer.dim]
                    else:
                        tmp_ds.loc[:, var] = copy.deepcopy(
                            np.diag(getattr(self, var))[
                                tracer.xpointer:
                                tracer.xpointer + tracer.dim])
                else:
                    del tmp_ds[var]

            tmp_ds.loc[:, "component"] = comp
            tmp_ds.loc[:, "tracer"] = trcr

            dates_id, vert_id, horiz_id = np.meshgrid(
                range(tracer.ndates), range(tracer.vresoldim),
                range(tracer.hresoldim))
            tmp_ds.loc[:, "horiz_id"] = horiz_id.flatten()
            tmp_ds.loc[:, "vert_id"] = vert_id.flatten()
            tmp_ds.loc[:, "date"] = tracer.dates[dates_id].flatten()

            controlvect_ds[comp][trcr] = tmp_ds

            # Don't go further if no need to dump as netcdf
            if not to_netcdf or dir_netcdf is None:
                continue

            debug("Dumping control vector as NetCDF for {}/{}"
                  .format(comp, trcr))

            # Translating x and xb to maps
            x = np.reshape(
                self.x[tracer.xpointer: tracer.xpointer + tracer.dim],
                (tracer.ndates, tracer.vresoldim, -1),
            )
            x = scale2map(x, tracer, tracer.dates, tracer.domain)

            xb = np.reshape(
                self.xb[tracer.xpointer: tracer.xpointer + tracer.dim],
                (tracer.ndates, tracer.vresoldim, -1),
            )
            xb = scale2map(xb, tracer, tracer.dates, tracer.domain)

            std = np.reshape(
                self.std[tracer.xpointer: tracer.xpointer + tracer.dim],
                (tracer.ndates, tracer.vresoldim, -1),
            )
            std = scale2map(std, tracer, tracer.dates, tracer.domain)

            # Adding the diagonal of posterior uncertainties if available
            if hasattr(self, "pa"):
                if self.pa.ndim == 2:
                    pa = np.diag(self.pa)
                else:
                    pa = self.pa

                pa_std = np.reshape(
                    np.sqrt(
                        pa[tracer.xpointer: tracer.xpointer + tracer.dim]),
                    (tracer.ndates, tracer.vresoldim, -1),
                )
                pa_std = scale2map(pa_std, tracer, tracer.dates,
                                   tracer.domain)

                ds = xr.Dataset({"x": x, "xb": xb,
                                 "b_std": std, "pa_std": pa_std})

            else:
                ds = xr.Dataset({"x": x, "xb": xb, "b_std": std})

            # If tracer is scalar, also include the "physical" projection
            if getattr(tracer, "type", "scalar") == "scalar":
                # Read the tracer array and apply the present control
                # vector scaling factor
                # Apply same protocol as ini_mapper from transform
                # "fromcontrol" to find correct dates
                # (merged input_dates and tracer dates)
                ds_phys = None
                for di in tracer.input_dates:
                    # Skip if input dates are empty for some reason
                    # for that period
                    if len(tracer.input_dates[di]) == 0:
                        continue

                    outdates = pd.DatetimeIndex(np.sort(np.unique(np.append(
                        tracer.input_dates[di], tracer.dates
                    )))).to_pydatetime()
                    if len(outdates) == 1:
                        outdates = np.append(outdates, outdates)

                    mask_min = np.zeros(len(outdates), dtype=bool) \
                        if tracer.input_dates[di] == [] \
                        else outdates >= np.min(tracer.input_dates[di])
                    mask_max = np.zeros(len(outdates), dtype=bool) \
                        if tracer.input_dates[di] == [] \
                        else outdates <= np.max(tracer.input_dates[di])
                    outdates = outdates[mask_min & mask_max]
                    outdates = pd.to_datetime(outdates)

                    # Read reference inputs
                    inputs = tracer.read(
                        trcr, tracer.varname, tracer.input_dates[di],
                        tracer.input_files[di],
                        comp_type=comp,
                        tracer=tracer,
                        ddi=di,
                        # model=self.model,
                        **kwargs
                    )

                    # Reindex xb, x and inputs to common outdates
                    inputs = reindex(
                        inputs,
                        levels={"time": outdates[:-1]},
                    )
                    xb_phys = inputs * reindex(
                        xb,
                        levels={"time": outdates[:-1], "lev": inputs.lev},
                    )
                    x_phys = inputs * reindex(
                        x,
                        levels={"time": outdates[:-1], "lev": inputs.lev},
                    )
                    b_phys = inputs * reindex(
                        std,
                        levels={"time": outdates[:-1], "lev": inputs.lev},
                    )

                    ds_tmp = xr.Dataset({
                        "x_phys": x_phys, "xb_phys": xb_phys,
                        "b_phys": b_phys})

                    if hasattr(self, "pa"):
                        pa_phys = inputs * reindex(
                            pa_std,
                            levels={"time": outdates[:-1],
                                    "lev": inputs.lev},
                        )
                        ds_tmp = ds_tmp.assign(pa_phys=pa_phys)

                    if ds_phys is None:
                        ds_phys = ds_tmp
                    else:
                        ds_phys = xr.concat([ds_phys, ds_tmp],
                                            dim="time", join="inner")

                # Drop duplicated times
                index_unique = np.unique(ds_phys["time"],
                                         return_index=True)[1]
                ds_phys = ds_phys.isel({"time": index_unique})

                # Merge with non-physical values
                ds_phys = ds_phys.rename({"time": "time_phys"})
                ds = ds.merge(ds_phys)

            # Adding longitudes and latitudes
            if not getattr(tracer, "is_lbc", False):
                ds = ds.assign(
                    latitudes=(("lat", "lon"), tracer.domain.zlat),
                    longitudes=(("lat", "lon"), tracer.domain.zlon),
                    latitudes_corner=(("latc", "lonc"),
                                      tracer.domain.zlatc),
                    longitudes_corner=(("latc", "lonc"),
                                       tracer.domain.zlonc))
            else:
                ds = ds.assign(
                    latitudes=(("lat", "lon"), tracer.domain.zlat_side),
                    longitudes=(("lat", "lon"), tracer.domain.zlon_side),
                    latitudes_corner=(("latc", "lonc"),
                                      tracer.domain.zlatc_side),
                    longitudes_corner=(("latc", "lonc"),
                                       tracer.domain.zlonc_side))

            # Dumping
            ds.to_netcdf(
                "{}/controlvect_{}_{}.nc".format(dir_comp, comp, trcr))

    # Dumping the dictionary to a pickle
    tosave["datastore"] = controlvect_ds
    with open(cntrl_file, "wb") as f:
        pickle.dump(tosave, f, pickle.HIGHEST_PROTOCOL)
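
# Usage sketch (illustrative only, not part of this module): dumping a
# control vector both as a pickle and as per-tracer NetCDF files. The
# object name "controlvect" and the paths below are hypothetical.
#
#     controlvect.dump(
#         "workdir/controlvect/controlvect.pickle",
#         to_netcdf=True,
#         dir_netcdf="workdir/controlvect",
#     )
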
def load(self, cntrl_file, **kwargs):
    """Loads a control vector from a pickle file written by dump and
    fills the corresponding chunks of the control vector arrays.

    Args:
        self (pycif.utils.classes.controlvects.ControlVect):
            the Control Vector to fill
        cntrl_file (str): path to the pickle file to load

    """

    debug("Loading control vector from {}".format(cntrl_file))

    with open(cntrl_file, "rb") as f:
        toread = pickle.load(f)

    out = self.from_dict(toread)
    out_ds = out.datastore

    # Loop over components and tracers
    components = self.datavect.components
    for comp in components.attributes:
        component = getattr(components, comp)

        # Skip if component does not have parameters
        if not hasattr(component, "parameters"):
            continue

        # Skip if component not in pickle
        if not hasattr(out_ds, comp):
            warning("Could not read component '{}' from pickle {}."
                    .format(comp, cntrl_file))
            continue

        comp_ds = getattr(out_ds, comp)

        for trcr in component.parameters.attributes:
            tracer = getattr(component.parameters, trcr)

            # Do nothing if not in control vector
            if not tracer.iscontrol:
                continue

            # Skip if tracer not in pickle
            if not hasattr(comp_ds, trcr):
                warning("Could not read tracer '{}/{}' from pickle {}."
                        .format(comp, trcr, cntrl_file))
                continue

            debug("Loading variable {}/{}".format(comp, trcr))

            # Fill the correct chunk with corresponding values
            dates_id, vert_id, horiz_id = np.meshgrid(
                range(tracer.ndates), range(tracer.vresoldim),
                range(tracer.hresoldim))
            dates_id = tracer.dates[dates_id]

            tmp_ds = getattr(comp_ds, trcr).set_index(
                ["date", "horiz_id", "vert_id"])
            tmp_ds = tmp_ds.reindex(
                list(zip(dates_id.flatten(), horiz_id.flatten(),
                         vert_id.flatten())))

            # Loop over variables to initialize
            var2read = ["x", "xb", "dx", "std", "pa"]
            for var in var2read:
                if var not in tmp_ds:
                    continue

                # Initialize variable to zero in controlvect
                # if not already there
                if not hasattr(self, var):
                    setattr(self, var, np.zeros(self.dim))

                getattr(self, var)[
                    tracer.xpointer: tracer.xpointer + tracer.dim
                ] = tmp_ds[var].values

    debug("Successfully loaded control vector from {}".format(cntrl_file))

    return out
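
# Usage sketch (illustrative only, not part of this module): reloading a
# control vector previously written by dump(). "controlvect" is a
# hypothetical ControlVect instance built with the same datavect definition.
#
#     out = controlvect.load("workdir/controlvect/controlvect.pickle")
#     # The x, xb, dx, std and pa chunks of every tracer that belongs to
#     # the control vector are now filled.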