"""pycif.plugins.controlvects.standard.dump: dump/load a control vector to/from a pickle file (and NetCDF files)."""

import copy
from logging import debug, warning
from pathlib import Path

import numpy as np
import pandas as pd
import xarray as xr

from ....plugins.transforms.system.fromcontrol.utils.scalemaps import scale2map
from ....utils.dataarrays.reindex import reindex
from ....utils.iterators import iter_tracers

try:
    import cPickle as pickle
except ImportError:
    import pickle


def dump(
    self, cntrl_file, to_netcdf=False, dir_netcdf=None, ensemble=False, **kwargs
):
    """Dumps a control vector into a pickle file.

    Does not save large correlations: attributes matching the ``exclude``
    patterns below are left out of the pickle. For every controlled tracer,
    the values of ``x``, ``xb``, ``dx``, ``std`` and ``pa`` that exist on the
    control vector are stored (matrices reduced to their diagonal) in
    ``tosave["datastore"][component][tracer]``. The tracer pointers are stored
    in ``tosave["dim_infos"][component][tracer]``.

    Args:
        self (pycif.utils.classes.controlvects.ControlVect): the Control
            Vector to dump
        cntrl_file (str): path to the file to dump as pickle
        to_netcdf (bool): also save one NetCDF file per controlled tracer
            if True
        dir_netcdf (str): root path for the netcdf directory; NetCDF files
            are written only if ``to_netcdf`` is True and this is not None
        ensemble (bool): also save ensemble samples taken from ``x_ens`` /
            ``dx_ens`` under ``<tracer>__sample#NNN`` keys
        **kwargs: forwarded to ``tracer.read`` when computing the physical
            projection of scalar tracers in the NetCDF files

    Raises:
        ValueError: if ``ensemble`` is True and neither ``x_ens`` nor
            ``dx_ens`` is available, or if the tracer inputs and the
            control vector have incompatible horizontal dimensions.
    """
    debug(f"dumping controlvect to file '{cntrl_file}'")

    # Saving recursive attributes from the Yaml
    # fmt: off
    exclude = [
        "transform", "domain", "datastore", "input_dates", "obsvect",
        "tracer", "input_files", "tstep_dates", "tstep_all", "dataflx",
        "logfile", "datei", "datef", "workdir", "verbose", "subsimu_dates",
        "tcorrelations", "hcorrelations", "databos"
    ]
    # fmt: on
    tosave = self.to_dict(self, exclude_patterns=exclude)

    # Variables with ensemble data ('x' ensemble data is named 'x_ens')
    ensemble_variables = ["x", "dx"]

    # Save the control vector as a pandas datastore
    controlvect_ds = {}
    diminfos_ds = {}
    for (comp_name, tracer_name), tracer in iter_tracers(self.datavect):
        # Do nothing if not in control vector
        if not tracer.iscontrol:
            continue

        xslice = slice(tracer.xpointer, tracer.xpointer + tracer.dim)

        comp_ds = controlvect_ds.setdefault(comp_name, {})
        comp_ds[tracer_name] = _tracer_datastore(self, tracer, xslice)

        # Save pointers
        diminfos_ds.setdefault(comp_name, {})[tracer_name] = {
            "xpointer": tracer.xpointer,
            "dim": tracer.dim,
        }

        # Saving ensemble data
        if ensemble:
            _add_ensemble_samples(
                self, comp_ds, comp_name, tracer_name, xslice, ensemble_variables
            )

        # Don't go further if no need to dump as netcdf
        if not to_netcdf or dir_netcdf is None:
            continue

        _dump_tracer_netcdf(
            self, tracer, comp_name, tracer_name, xslice, dir_netcdf, kwargs
        )

    # Dumping the dictionary to a pickle
    tosave["datastore"] = controlvect_ds
    tosave["dim_infos"] = diminfos_ds

    with open(cntrl_file, "wb") as f:
        pickle.dump(tosave, f, pickle.HIGHEST_PROTOCOL)


def _tracer_datastore(self, tracer, xslice):
    """Build the dict of columns stored in the pickle for one tracer."""
    tmp_ds = {}

    # Fetch information for tmp ds
    for varname in ["x", "xb", "dx", "std", "pa"]:
        if not hasattr(self, varname):
            continue
        var = getattr(self, varname)
        if var.ndim == 1:
            tmp_ds[varname] = var[xslice]
        else:
            # Matrices (e.g. a full covariance) are reduced to their diagonal
            tmp_ds[varname] = np.diag(var)[xslice].copy()

    # NOTE(review): np.meshgrid defaults to indexing='xy', so the flattened
    # ids below run vert-major, then date, then horiz, while x[xslice] is
    # reshaped elsewhere as (ndates, vresoldim, -1). The labels may not match
    # the values when ndates > 1 and vresoldim > 1. load() uses the same
    # meshgrid, so dump/load round-trips stay consistent. Confirm before
    # changing (it would also affect existing pickle files).
    dates_id, vert_id, horiz_id = np.meshgrid(
        range(tracer.ndates),
        range(tracer.vresoldim),
        range(tracer.hresoldim),
    )
    # Free the id grids as soon as they are flattened to limit peak memory
    tmp_ds["horiz_id"] = horiz_id.flatten()
    del horiz_id
    tmp_ds["vert_id"] = vert_id.flatten()
    del vert_id
    tmp_ds["date"] = tracer.dates[dates_id].flatten()
    del dates_id

    # Reducing memory usage with panda's categorical dtype
    # (this improves pickle's read/write times)
    for col in ["date", "horiz_id", "vert_id"]:
        tmp_ds[col] = pd.Series(tmp_ds[col]).astype("category")

    return tmp_ds


def _add_ensemble_samples(
    self, comp_ds, comp_name, tracer_name, xslice, ensemble_variables
):
    """Add one ``<tracer>__sample#NNN`` entry per ensemble sample to comp_ds."""
    # Getting number of samples (the last available ensemble variable wins)
    n_samples = 0
    for var_name in ensemble_variables:
        ens_var_name = f"{var_name}_ens"
        if hasattr(self, ens_var_name):
            n_samples = getattr(self, ens_var_name).shape[0]

    if n_samples == 0:
        raise ValueError(
            "No ensemble variables in the control vector for "
            f"variable {comp_name}/{tracer_name}"
        )

    # TODO: is it really usefull use separate variables for each
    # samples? Why not saving the samples matrix as is ?
    for sample_index in range(n_samples):
        trcr_sample = f"{tracer_name}__sample#{sample_index:03d}"

        # TODO: may not be usefull, controlvect.load get those
        # columns from 'ds[comp][trcr]' and not from
        # ds[comp][trcr_sample]
        #
        # Copying tracer data in tracer sample
        sample_ds = {
            col: copy.deepcopy(comp_ds[tracer_name][col])
            for col in ["date", "horiz_id", "vert_id"]
        }

        for var_name in ensemble_variables:
            ens_var_name = f"{var_name}_ens"
            if not hasattr(self, ens_var_name):
                continue
            # Extracting sample data from ensemble data
            sample_ds[var_name] = getattr(self, ens_var_name)[sample_index, xslice]

        comp_ds[trcr_sample] = sample_ds


def _dump_tracer_netcdf(
    self, tracer, comp_name, tracer_name, xslice, dir_netcdf, read_kwargs
):
    """Write ``<dir_netcdf>/<comp>/controlvect_<comp>_<tracer>.nc``."""
    tracer_type = getattr(tracer, "type", "scalar")
    dump_physical = getattr(tracer, "dump_physical", True)
    is_lbc = getattr(tracer, "is_lbc", False)
    xshape = (tracer.ndates, tracer.vresoldim, -1)

    dir_comp = Path(dir_netcdf, comp_name)
    dir_comp.mkdir(parents=True, exist_ok=True)
    ctrlvect_file = Path(dir_comp, f"controlvect_{comp_name}_{tracer_name}.nc")
    debug(
        f"Dumping controlvect for {comp_name}/{tracer_name} to NetCDF file "
        f"'{ctrlvect_file}'"
    )

    # NetCDF variable name -> (control vector attribute, long name)
    variables = {
        "x": ("x", "posterior"),
        "xb": ("xb", "prior"),
        "dx": ("dx", "increment"),
        "b_std": ("std", "prior_uncertainty_diag"),
        "pa_std": ("pa", "posterior_uncertainty_diag"),
    }

    ds = xr.Dataset()
    for varname, (attrname, fullname) in variables.items():
        # Skip variables missing from the control vector, as is done for the
        # pickle datastore
        if not hasattr(self, attrname):
            continue
        values = getattr(self, attrname)
        # Matrices are reduced to their diagonal, as for the pickle datastore
        if values.ndim == 2:
            values = np.diag(values)
        values = values[xslice]
        if varname == "pa_std":
            # Square root of the diagonal of pa
            values = np.sqrt(values)

        var = scale2map(np.reshape(values, xshape), tracer, tracer.dates, tracer.domain)
        var.attrs["standard_name"] = f"{tracer_type}_{fullname}"
        var.attrs["component"] = comp_name
        var.attrs["tracer"] = tracer_name
        var.attrs["type"] = tracer_type
        ds[varname] = var

    # If tracer is scalar, also include the "physical" projection
    has_phys = False
    if tracer_type == "scalar" and dump_physical:
        ds_phys = _physical_projection(
            self, tracer, comp_name, tracer_name, tracer_type, ds, variables,
            read_kwargs,
        )
        if ds_phys is None:
            # Every input period was empty: nothing to project
            warning(
                f"No input dates available for {comp_name}/{tracer_name}; "
                "the physical projection is not dumped"
            )
        else:
            # Merge with non-physical values
            ds_phys = ds_phys.rename({"time": "time_phys", "lev": "lev_phys"})
            ds = ds.merge(ds_phys)
            has_phys = True

    ds = _assign_horizontal_coords(ds, tracer, is_lbc)

    # Squeeze vertical dimension if it is 1
    if tracer.vresoldim == 1:
        ds = ds.squeeze("lev")
        ds = ds.drop_vars("lev")

    _set_coordinate_attrs(ds, has_phys, is_lbc)

    for varname in ds:
        ds[varname].attrs.pop("coordinates", None)

    # Global attributes
    now = pd.Timestamp.now().round("s").isoformat()
    ds.attrs = {
        "title": f"CIF control vector for {comp_name}/{tracer_name}",
        "history": f"{now}: pycif, file created;",
    }

    # Dumping
    if ctrlvect_file.exists():
        ctrlvect_file.unlink()
    ds.to_netcdf(ctrlvect_file)


def _physical_projection(
    self, tracer, comp_name, tracer_name, tracer_type, ds, variables, read_kwargs
):
    """Multiply the tracer reference inputs by each dumped control variable.

    Follows the date protocol of ini_mapper in transform "fromcontrol"
    (merged input_dates and tracer dates). Returns a Dataset with one
    ``<var>_phys`` variable per entry of ``variables`` present in ``ds``, or
    None if every input period is empty.
    """
    ds_phys = None
    for di in tracer.input_dates:
        input_dates = tracer.input_dates[di]

        # Skip if input dates are empty for some reason for that period
        if len(input_dates) == 0:
            continue

        outdates = (
            pd.DatetimeIndex(input_dates.stack())
            .append(pd.DatetimeIndex(tracer.dates))
            .drop_duplicates()
            .sort_values()
        )
        if len(outdates) == 1:
            outdates = np.append(outdates, outdates)

        # Keep only the dates covered by this (non-empty) input period
        mask_min = outdates >= np.min(input_dates["start_date"])
        mask_max = outdates <= np.max(input_dates["end_date"])
        outdates = pd.to_datetime(outdates[mask_min & mask_max])

        # Turn dates to datetime for consistency in read
        # TODO: This should be standardized in the future
        dates2read = [
            [x.to_pydatetime() for x in row]
            for row in input_dates.itertuples(index=False, name=None)
        ]

        # Read reference inputs
        inputs = tracer.read(
            tracer_name,
            tracer.varname,
            dates2read,
            tracer.input_files[di],
            comp_type=comp_name,
            tracer=tracer,
            ddi=di,
            model=self.model,
            **read_kwargs,
        )

        # Check that horizontal dimensions are compatible
        if "xb" in ds:
            input_dims = inputs.shape[2:]
            xb_dims = ds["xb"].shape[2:]
            if input_dims != xb_dims:
                raise ValueError(
                    "Dimensions for inputs and xb are not compatible. This can "
                    "arise if `is_lbc` has erroneously been set to True in your "
                    "yaml.\n"
                    f"\t- Input dimension: {input_dims}\n"
                    f"\t- Xb dimension: {xb_dims}"
                )

        # Reindex xb, x and inputs to common outdates
        inputs = reindex(inputs, levels={"time": outdates[:-1]})

        ds_tmp = xr.Dataset()
        for varname, (_, fullname) in variables.items():
            # Only project variables that were actually dumped
            if varname not in ds:
                continue
            var = inputs * reindex(
                ds[varname], levels={"time": outdates[:-1], "lev": inputs.lev}
            )
            var.attrs["standard_name"] = f"physical_{fullname}"
            var.attrs["component"] = comp_name
            var.attrs["tracer"] = tracer_name
            var.attrs["type"] = tracer_type
            ds_tmp[f"{varname}_phys"] = var

        if ds_phys is None:
            ds_phys = ds_tmp
        else:
            ds_phys = xr.concat([ds_phys, ds_tmp], dim="time", join="inner")

    if ds_phys is None:
        return None

    # Drop duplicated times (np.unique also sorts them)
    index_unique = np.unique(ds_phys["time"], return_index=True)[1]
    return ds_phys.isel({"time": index_unique})


def _assign_horizontal_coords(ds, tracer, is_lbc):
    """Attach lat/lon coordinates, their bounds and (non-LBC) cell areas."""
    domain = tracer.domain
    if is_lbc:
        lat, lon = domain.zlat_side, domain.zlon_side
        latc, lonc = domain.zlatc_side, domain.zlonc_side
    else:
        lat, lon = domain.zlat, domain.zlon
        latc, lonc = domain.zlatc, domain.zlonc
        # Areas are only written for non-LBC tracers
        if not hasattr(domain, "areas"):
            domain.calc_areas()

    if domain.unstructured_domain:
        ds = ds.squeeze("lat")
        ds = ds.rename_dims(lon="cell")
        ds = ds.assign_coords(
            lat=(["cell"], lat[0, :]),
            lon=(["cell"], lon[0, :]),
            lat_bnds=(["cell", "vertex"], latc.T),
            lon_bnds=(["cell", "vertex"], lonc.T),
        )
        if not is_lbc:
            ds["areas"] = (["cell"], domain.areas)
        return ds

    # Squeeze lat lon coords if they are regular
    if (
        np.all(lon == lon[0, :])
        and np.all(lat == lat[:, [0]])
        and np.all(lonc == lonc[0, :])
        and np.all(latc == latc[:, [0]])
    ):
        lon_bnds = np.concatenate(
            [lonc[0, :-1, np.newaxis], lonc[0, 1:, np.newaxis]], axis=1
        )
        lat_bnds = np.concatenate(
            [latc[:-1, 0, np.newaxis], latc[1:, 0, np.newaxis]], axis=1
        )
        ds = ds.assign_coords(
            lat=(["lat"], lat[:, 0]),
            lon=(["lon"], lon[0, :]),
            lat_bnds=(["lat", "bnds"], lat_bnds),
            lon_bnds=(["lon", "bnds"], lon_bnds),
        )
    # Otherwise keep 2D lat and lon
    else:
        ds = ds.assign_coords(
            lat=(["lat", "lon"], lat),
            lon=(["lat", "lon"], lon),
            lat_bnds=(["lat_bnds", "lon_bnds"], latc),
            lon_bnds=(["lat_bnds", "lon_bnds"], lonc),
        )
    if not is_lbc:
        ds["areas"] = (["lat", "lon"], domain.areas)
    return ds


def _set_coordinate_attrs(ds, has_phys, is_lbc):
    """Set CF-style attributes on the coordinate variables of ds in place."""
    ds["time"].attrs = {
        "standard_name": "time",
        "long_name": "time",
    }
    if has_phys:
        ds["time_phys"].attrs = {
            "standard_name": "time_physical",
            "long_name": "time_physical",
        }
    ds["lat"].attrs = {
        "standard_name": "latitude",
        "long_name": "latitude",
        "units": "degrees_north",
        "bounds": "lat_bnds",
    }
    ds["lon"].attrs = {
        "standard_name": "longitude",
        "long_name": "longitude",
        "units": "degrees_east",
        "bounds": "lon_bnds",
    }
    ds["lat_bnds"].attrs = {
        "standard_name": "latitude_bounds",
        "long_name": "latitude bounds",
        "units": "degrees_north",
    }
    ds["lon_bnds"].attrs = {
        "standard_name": "longitude_bounds",
        "long_name": "longitude bounds",
        "units": "degrees_east",
    }
    if not is_lbc:
        ds["areas"].attrs = {
            "standard_name": "area",
            "long_name": "area",
            "units": "m2",
        }
def load(
    self,
    cntrl_file,
    component2load=None,
    tracer2load=None,
    target_tracer=None,
    ensemble=False,
    **kwargs,
):
    """Load a control vector previously written by ``dump``.

    Values of ``x``, ``xb``, ``dx``, ``std`` and ``pa`` found in the pickle
    datastore are written into ``self`` at each controlled tracer's slice.
    Missing attributes are created as zero vectors of size ``self.dim``, and
    2-D matrices only get their diagonal updated, matching what ``dump``
    stored. With ``ensemble=True``, ``x_ens``/``dx_ens`` are filled from the
    ``<tracer>__sample#NNN`` entries.

    Args:
        cntrl_file (str): path to the pickle file. Unpickling can execute
            arbitrary code: only load trusted files.
        component2load (str): restrict loading to this component
        tracer2load (str): restrict loading to this tracer (name in the pickle)
        target_tracer (str): name of the tracer of ``self`` that receives the
            data, if different from the name in the pickle
        ensemble (bool): also load ensemble samples

    Returns:
        The object built by ``self.from_dict`` from the non-datastore part of
        the pickle.

    Raises:
        ValueError: on inconsistent or missing ensemble samples.
    """
    debug(f"Loading control vector from {cntrl_file}")

    if ensemble and component2load is None and tracer2load is None:
        warning(
            "Trying to load control vector ensemble data without "
            "specifying some specific component and tracer to load"
        )

    with open(cntrl_file, "rb") as f:
        toread = pickle.load(f)

    out_ds = toread.pop("datastore")
    out = self.from_dict(toread)

    index_cols = ["date", "horiz_id", "vert_id"]

    # Loop over components and tracers
    components = self.datavect.components
    list_components = (
        components.attributes if component2load is None else [component2load]
    )
    for comp in list_components:
        component = getattr(components, comp)

        # Skip if component does not have parameters
        if not hasattr(component, "parameters"):
            continue

        # Skip if component not in pickle
        if comp not in out_ds:
            warning(f"Could not read component '{comp}' for pickle {cntrl_file}.")
            continue

        comp_ds = out_ds[comp]
        list_tracers = (
            component.parameters.attributes if tracer2load is None else [tracer2load]
        )
        for trcr in list_tracers:
            tracer = getattr(
                component.parameters, trcr if target_tracer is None else target_tracer
            )

            # Do nothing if not in control vector
            if not tracer.iscontrol:
                continue

            xslice = slice(tracer.xpointer, tracer.xpointer + tracer.dim)

            # Skip if tracer not in pickle
            if trcr not in comp_ds:
                warning(
                    f"Could not read tracer '{comp}/{trcr}' for pickle {cntrl_file}."
                )
                continue

            debug(f"Loading variable {comp}/{trcr}")

            # Target labels, built exactly like in dump so that values are
            # put back in the same order
            dates_id, vert_id, horiz_id = np.meshgrid(
                range(tracer.ndates), range(tracer.vresoldim), range(tracer.hresoldim)
            )
            target_index = pd.MultiIndex.from_arrays(
                [
                    tracer.dates[dates_id].flatten(),
                    horiz_id.flatten(),
                    vert_id.flatten(),
                ]
            )

            tmp_ds = pd.DataFrame(comp_ds[trcr])

            # Removing categorical dtypes (speeds up the reindexing)
            tmp_ds["date"] = tmp_ds.date.astype("datetime64[ns]")
            tmp_ds["horiz_id"] = tmp_ds.horiz_id.astype("int")
            tmp_ds["vert_id"] = tmp_ds.vert_id.astype("int")

            if ensemble:
                # Keeping index data for ensemble variables
                index_data = tmp_ds[index_cols].copy()

            # Reindexing
            tmp_ds.set_index(index_cols, inplace=True)
            tmp_ds = tmp_ds.reindex(target_index, copy=False)

            # Fill the correct chunk with corresponding values
            for var in ["x", "xb", "dx", "std", "pa"]:
                if var not in tmp_ds:
                    continue
                values = tmp_ds[var].values
                array = getattr(self, var, None)
                if array is None:
                    # Allocated only when needed (not eagerly as a default)
                    array = np.zeros(self.dim)
                if array.ndim == 2:
                    # dump stored the diagonal of matrices: write it back
                    # onto the diagonal instead of broadcasting over rows
                    diag_idx = np.arange(array.shape[0])[xslice]
                    array[diag_idx, diag_idx] = values
                else:
                    array[xslice] = values
                setattr(self, var, array)

            # Loading ensemble data
            if ensemble:
                _load_ensemble(
                    self, comp_ds, comp, trcr, xslice, index_data, target_index
                )

    debug(f"Successfully loaded control vector from {cntrl_file}")

    return out


def _load_ensemble(
    self,
    comp_ds,
    comp,
    trcr,
    xslice,
    index_data,
    target_index,
    ensemble_variables=("x", "dx"),
    max_number_of_samples=1000,
):
    """Fill ``<var>_ens`` on self from the ``<trcr>__sample#NNN`` entries.

    ``index_data`` holds the (date, horiz_id, vert_id) columns of the main
    tracer entry; sample values are aligned on it and then reindexed onto
    ``target_index``. ('x' ensemble data is named 'x_ens'.)
    """
    # Getting variable names and number of samples (consecutive from #000)
    sample_names = []
    for sample_index in range(max_number_of_samples + 1):
        trcr_sample = f"{trcr}__sample#{sample_index:03d}"
        if trcr_sample not in comp_ds:
            break
        sample_names.append(trcr_sample)
    else:
        # Reached when the loop ends without breaking, i.e. when the max
        # number of samples is exceeded
        raise ValueError(f"Too many sample in ensemble for variable {comp}/{trcr}")
    n_samples = len(sample_names)

    if n_samples == 0:
        raise ValueError(
            f"No samples detected in the control vector for variable {comp}/{trcr}"
        )

    # Iterating over samples to extract ensemble data
    ens_data = {var_name: [] for var_name in ensemble_variables}
    for trcr_sample in sample_names:
        sample_ds = comp_ds[trcr_sample]
        for var_name in ensemble_variables:
            if var_name in sample_ds:
                ens_data[var_name].append(sample_ds[var_name])

    # Iterating over ensemble variables to format and save data
    for var_name in ensemble_variables:
        columns = ens_data[var_name]
        if not columns:
            continue
        if len(columns) != n_samples:
            raise ValueError(
                f"Variable '{var_name}' is missing in some ensemble samples "
                f"for variable {comp}/{trcr}"
            )

        # One column per sample
        var = np.stack(columns, axis=1)

        # Formatting ensemble data
        tmp_sample_ds = pd.DataFrame(data=var, columns=sample_names)
        tmp_sample_ds = pd.concat([index_data, tmp_sample_ds], axis="columns")
        tmp_sample_ds.set_index(["date", "horiz_id", "vert_id"], inplace=True)
        tmp_sample_ds = tmp_sample_ds.reindex(target_index, copy=False)

        # Initialize variable to zero in controlvect if not already here
        ens_var_name = f"{var_name}_ens"
        ens_var = getattr(self, ens_var_name, None)
        if ens_var is None:
            ens_var = np.zeros((n_samples, self.dim))
        elif ens_var.shape[0] != n_samples:
            raise ValueError(
                f"'{ens_var_name}' already has {ens_var.shape[0]} samples, "
                f"but {n_samples} were found for variable {comp}/{trcr}"
            )

        # Saving formatted data
        ens_var[:, xslice] = tmp_sample_ds.values.T
        setattr(self, ens_var_name, ens_var)