import os
import copy
import numpy as np
import xarray as xr
import pandas as pd
from logging import debug, warning
import tracemalloc
from ....utils.path import init_dir
from ....plugins.transforms.system.fromcontrol.utils.scalemaps \
import scale2map
from ....utils.dataarrays.reindex import reindex
try:
import cPickle as pickle
except ImportError:
import pickle
def dump(self, cntrl_file, to_netcdf=False, dir_netcdf=None, ensemble=False, **kwargs):
"""Dumps a control vector into a pickle file.
Does not save large correlations.
Args:
self (pycif.utils.classes.controlvects.ControlVect):
the Control Vector to dump
cntrl_file (str): path to the file to dump as pickle
to_netcdf (bool): save to netcdf files if True
dir_netcdf (str): root path for the netcdf directory
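
Example:
A minimal usage sketch (assuming ``cv`` is an already initialised
control vector and ``workdir`` an existing run directory; both
names and paths below are purely illustrative)::

    # illustrative paths; adapt to your own run directory layout
    cv.dump(
        "{}/controlvect/controlvect.pickle".format(workdir),
        to_netcdf=True,
        dir_netcdf="{}/controlvect".format(workdir),
    )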
"""
debug("Dumping the control vector to {}".format(cntrl_file))
# Saving recursive attributes from the Yaml
exclude = ["transform", "domain", "datastore",
"input_dates", "obsvect", "tracer", "input_files",
"tstep_dates", "tstep_all", "dataflx",
"logfile", "datei", "datef", "workdir",
"verbose", "subsimu_dates", "tcorrelations",
"hcorrelations", "databos"]
tosave = self.to_dict(self, exclude_patterns=exclude)
# Save the control vector as a pandas datastore
controlvect_ds = {}
diminfos_ds = {}
components = self.datavect.components
for comp in components.attributes:
component = getattr(components, comp)
# Skip if component does not have parameters
if not hasattr(component, "parameters"):
continue
dir_comp = "{}/{}".format(dir_netcdf, comp)
init_dir(dir_comp)
for trcr in component.parameters.attributes:
tracer = getattr(component.parameters, trcr)
# Do nothing if not in control vector
if not tracer.iscontrol:
continue
# Update controlvect_ds dictionary
if comp not in controlvect_ds:
controlvect_ds[comp] = {}
if comp not in diminfos_ds:
diminfos_ds[comp] = {}
# Fetch information for tmp ds
var2read = ["x", "xb", "dx", "std", "pa"]
tmp_ds = {}
for var in var2read:
if hasattr(self, var):
if getattr(self, var).ndim == 1:
tmp_ds[var] = \
getattr(self, var)[
tracer.xpointer: tracer.xpointer + tracer.dim]
else:
tmp_ds[var] = copy.deepcopy(
np.diag(getattr(self, var))[
tracer.xpointer: tracer.xpointer + tracer.dim])
dates_id, vert_id, horiz_id = np.meshgrid(
range(tracer.ndates),
range(tracer.vresoldim),
range(tracer.hresoldim))
tmp_ds["horiz_id"] = horiz_id.flatten()
del horiz_id
tmp_ds["vert_id"] = vert_id.flatten()
del vert_id
tmp_ds["date"] = tracer.dates[dates_id].flatten()
del dates_id
# Reducing memory usage with pandas' categorical dtype
# (this improves pickle's read/write times)
for col in ['date', 'horiz_id', 'vert_id']:
tmp_ds[col] = pd.Series(tmp_ds[col]).astype('category')
controlvect_ds[comp][trcr] = tmp_ds
# Save pointers
diminfos_ds[comp][trcr] = {
"xpointer": tracer.xpointer,
"dim": tracer.dim
}
# Variables with ensemble data ('x' ensemble data is named 'x_ens')
ensemble_variables = ['x', 'dx']
# Saving ensemble data
if ensemble:
# Getting number of samples
n_samples = 0
for var_name in ensemble_variables:
ens_var_name = f'{var_name}_ens'
if hasattr(self, ens_var_name):
n_samples = getattr(self, ens_var_name).shape[0]
if n_samples == 0:
raise ValueError(
"No ensemble variables in the control vector for "
f"variable {comp}/{trcr}"
)
# TODO: is it really useful to use separate variables for each
# sample? Why not save the samples matrix as is?
for sample_index in range(n_samples):
trcr_sample = f"{trcr}__sample#{sample_index:03d}"
# TODO: may not be useful; controlvect.load reads those
# columns from 'ds[comp][trcr]' and not from
# 'ds[comp][trcr_sample]'
#
# Copying the tracer index columns into the tracer sample
controlvect_ds[comp][trcr_sample] = {
col: copy.deepcopy(controlvect_ds[comp][trcr][col])
for col in ['date', 'horiz_id', 'vert_id']
}
for var_name in ensemble_variables:
ens_var_name = f'{var_name}_ens'
if not hasattr(self, ens_var_name):
continue
# Extracting sample data from ensemble data
var = getattr(self, ens_var_name)
controlvect_ds[comp][trcr_sample][var_name] = \
var[sample_index,
tracer.xpointer: tracer.xpointer + tracer.dim]
# Don't go further if there is no need to dump as netcdf
if not to_netcdf or dir_netcdf is None:
continue
debug("Dumping control vector as NetCDF for {}/{}"
.format(comp, trcr))
# Translating x and xb to maps
x = np.reshape(
self.x[tracer.xpointer: tracer.xpointer + tracer.dim],
(tracer.ndates, tracer.vresoldim, -1),
)
x = scale2map(x, tracer, tracer.dates, tracer.domain)
xb = np.reshape(
self.xb[tracer.xpointer: tracer.xpointer + tracer.dim],
(tracer.ndates, tracer.vresoldim, -1),
)
xb = scale2map(xb, tracer, tracer.dates, tracer.domain)
std = np.reshape(
self.std[tracer.xpointer: tracer.xpointer + tracer.dim],
(tracer.ndates, tracer.vresoldim, -1),
)
std = scale2map(std, tracer, tracer.dates, tracer.domain)
dx = np.reshape(
self.dx[tracer.xpointer: tracer.xpointer + tracer.dim],
(tracer.ndates, tracer.vresoldim, -1),
)
dx = scale2map(dx, tracer, tracer.dates, tracer.domain)
# Adding the diagonal of posterior uncertainties if available
if hasattr(self, "pa"):
if self.pa.ndim == 2:
pa = np.diag(self.pa)
else:
pa = self.pa
pa_std = np.reshape(np.sqrt(
pa[tracer.xpointer: tracer.xpointer + tracer.dim]),
(tracer.ndates, tracer.vresoldim, -1),
)
pa_std = scale2map(pa_std, tracer, tracer.dates, tracer.domain)
ds = xr.Dataset({"x": x, "xb": xb, "dx": dx,
"b_std": std, "pa_std": pa_std})
else:
ds = xr.Dataset({"x": x, "xb": xb, "dx": dx,
"b_std": std})
# If tracer is scalar, also include the "physical" projection
if getattr(tracer, "type", "scalar") == "scalar" \
and getattr(tracer, "dump_physical", True):
# Read the tracer array and apply the scaling factors of the
# present control vector.
# Apply the same protocol as ini_mapper from the "fromcontrol"
# transform to find the correct dates
# (merged input_dates and tracer dates)
ds_phys = None
for di in tracer.input_dates:
# Skip this period if its input dates are empty
if len(tracer.input_dates[di]) == 0:
continue
outdates = pd.DatetimeIndex(np.sort(np.unique(np.append(
tracer.input_dates[di],
tracer.dates
)))).to_pydatetime()
if len(outdates) == 1:
outdates = np.append(outdates, outdates)
mask_min = np.zeros(len(outdates), dtype=bool) \
if tracer.input_dates[di] == [] \
else outdates >= np.min(tracer.input_dates[di])
mask_max = np.zeros(len(outdates), dtype=bool) \
if tracer.input_dates[di] == [] \
else outdates <= np.max(tracer.input_dates[di])
outdates = outdates[mask_min & mask_max]
outdates = pd.to_datetime(outdates)
# Read reference inputs
inputs = tracer.read(
trcr,
tracer.varname,
tracer.input_dates[di],
tracer.input_files[di],
comp_type=comp,
tracer=tracer,
ddi=di,
model=self.model,
**kwargs
)
# Check that horizontal dimensions are compatible
input_dims = inputs.shape[2:]
xb_dims = xb.shape[2:]
if input_dims != xb_dims:
raise Exception(
"Dimensions for inputs and xb are not compatible. \n"
"This can arise if `is_lbc` has erroneously been set to True in your yaml.\n"
f"\t- Input dimension: {input_dims}\n"
f"\t- Xb dimension: {xb_dims}"
)
# Reindex xb, x and inputs to common outdates
inputs = reindex(
inputs,
levels={"time": outdates[:-1]},
)
xb_phys = inputs * reindex(
xb,
levels={"time": outdates[:-1], "lev": inputs.lev},
)
x_phys = inputs * reindex(
x,
levels={"time": outdates[:-1], "lev": inputs.lev},
)
dx_phys = inputs * reindex(
dx,
levels={"time": outdates[:-1], "lev": inputs.lev},
)
b_phys = inputs * reindex(
std, levels={"time": outdates[:-1], "lev": inputs.lev},
)
ds_tmp = xr.Dataset({
"x_phys": x_phys,
"dx_phys": dx_phys,
"xb_phys": xb_phys,
"b_phys": b_phys})
if hasattr(self, "pa"):
pa_phys = \
inputs * reindex(
pa_std, levels={"time": outdates[:-1],
"lev": inputs.lev},
)
ds_tmp = ds_tmp.assign(pa_phys=pa_phys)
if ds_phys is None:
ds_phys = ds_tmp
else:
ds_phys = xr.concat([ds_phys, ds_tmp],
dim="time",
join="inner")
# Drop duplicated times
index_unique = np.unique(ds_phys["time"], return_index=True)[1]
ds_phys = ds_phys.isel({"time": index_unique})
# Merge with non-physical values
ds_phys = ds_phys.rename({"time": "time_phys"})
ds = ds.merge(ds_phys)
# Adding longitudes and latitudes
if not getattr(tracer, "is_lbc", False):
ds = ds.assign(
latitudes=(("lat", "lon"), tracer.domain.zlat),
longitudes=(("lat", "lon"), tracer.domain.zlon),
latitudes_corner=(("latc", "lonc"), tracer.domain.zlatc),
longitudes_corner=(("latc", "lonc"), tracer.domain.zlonc))
else:
ds = ds.assign(
latitudes=(("lat", "lon"), tracer.domain.zlat_side),
longitudes=(("lat", "lon"), tracer.domain.zlon_side),
latitudes_corner=(("latc", "lonc"),
tracer.domain.zlatc_side),
longitudes_corner=(("latc", "lonc"),
tracer.domain.zlonc_side))
# Adding areas
if not getattr(tracer, "is_lbc", False):
if not hasattr(tracer.domain, "areas"):
tracer.domain.calc_areas()
ds = ds.assign(areas=(("lat", "lon"), tracer.domain.areas))
# Dumping
controlvect_file = "{}/controlvect_{}_{}.nc".format(
dir_comp, comp, trcr)
if os.path.exists(controlvect_file):
os.remove(controlvect_file)
ds.to_netcdf(controlvect_file)
# Dumping the dictionary to a pickle
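# The pickle contains the recursive attributes from the Yaml plus:
# - 'datastore': {component: {tracer: columns}}, where the columns are
#   the available variables among 'x', 'xb', 'dx', 'std' and 'pa',
#   plus the 'date', 'horiz_id' and 'vert_id' index columns
# - 'dim_infos': {component: {tracer: {'xpointer', 'dim'}}}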
tosave["datastore"] = controlvect_ds
tosave["dim_infos"] = diminfos_ds
with open(cntrl_file, "wb") as f:
pickle.dump(tosave, f, pickle.HIGHEST_PROTOCOL)
def load(self, cntrl_file,
component2load=None, tracer2load=None, target_tracer=None, ensemble=False,
**kwargs):
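"""Loads a control vector from a pickle file written by dump.

Only the components and tracers present both in the pickle and in the
current control vector are filled; the available chunks ('x', 'xb',
'dx', 'std', 'pa') are written in place into the corresponding slices
of self.

Args:
self (pycif.utils.classes.controlvects.ControlVect):
the Control Vector to fill
cntrl_file (str): path to the pickle file to read
component2load (str): if given, only load this component
tracer2load (str): if given, only load this tracer
target_tracer (str): name of the tracer in the current control vector
receiving the loaded values; defaults to the loaded tracer name
ensemble (bool): also load the ensemble samples ('x_ens', 'dx_ens') if True

Returns:
the control vector attributes rebuilt from the pickled dictionary
"""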
debug("Loading control vector from {}".format(cntrl_file))
if ensemble and component2load is None and tracer2load is None:
warning("Trying to load control vector ensemble data without "
"specifying some specific component and tracer to load")
with open(cntrl_file, "rb") as f:
toread = pickle.load(f)
out_ds = toread["datastore"]
del toread["datastore"]
out = self.from_dict(toread)
# Loop over components and tracers
components = self.datavect.components
list_components = components.attributes if component2load is None \
else [component2load]
for comp in list_components:
component = getattr(components, comp)
# Skip if component does not have parameters
if not hasattr(component, "parameters"):
continue
# Skip if component not in pickle
if comp not in out_ds:
warning("Could not read component '{}' for pickle {}. "
.format(comp, cntrl_file))
continue
comp_ds = out_ds[comp]
list_tracers = component.parameters.attributes if tracer2load is None \
else [tracer2load]
for trcr in list_tracers:
tracer = getattr(component.parameters,
trcr if target_tracer is None else target_tracer)
# Do nothing if not in control vector
if not tracer.iscontrol:
continue
# Skip if tracer not in pickle
if trcr not in comp_ds:
warning("Could not read tracer '{}/{}' from pickle {}"
.format(comp, trcr, cntrl_file))
continue
debug("Loading variable {}/{}".format(comp, trcr))
# Fill the correct chunk with corresponding values
dates_id, vert_id, horiz_id = np.meshgrid(
range(tracer.ndates),
range(tracer.vresoldim),
range(tracer.hresoldim))
dates_id = tracer.dates[dates_id]
target_index = pd.MultiIndex.from_arrays([
dates_id.flatten(), horiz_id.flatten(), vert_id.flatten()
])
tmp_ds = pd.DataFrame(comp_ds[trcr])
# Removing categorical dtypes (speeds up the reindexing)
tmp_ds['date'] = tmp_ds.date.astype('datetime64[ns]')
tmp_ds['horiz_id'] = tmp_ds.horiz_id.astype('int')
tmp_ds['vert_id'] = tmp_ds.vert_id.astype('int')
if ensemble:
# Keeping index data for ensemble variables
index_data = tmp_ds[['date', 'horiz_id', 'vert_id']].copy()
# Reindexing
tmp_ds.set_index(["date", "horiz_id", "vert_id"], inplace=True)
tmp_ds = tmp_ds.reindex(target_index, copy=False)
# Loop over variables to initialize
var2read = ["x", "xb", "dx", "std", "pa"]
for var in var2read:
if var not in tmp_ds:
continue
# Initialize the variable to zeros in the controlvect if not already present
array = getattr(self, var, np.zeros(self.dim))
array[tracer.xpointer: tracer.xpointer + tracer.dim] = \
tmp_ds[var].values
setattr(self, var, array)
# Variables with ensemble data ('x' ensemble data is named 'x_ens')
ensemble_variables = ['x', 'dx']
# Loading ensemble data
if ensemble:
# Getting variable names and number of samples
sample_names = []
max_number_of_samples = 1000
for sample_index in range(max_number_of_samples + 1):
# Sample variable name
trcr_sample = f"{trcr}__sample#{sample_index:03d}"
if trcr_sample in comp_ds:
sample_names.append(trcr_sample)
else:
n_samples = sample_index
break
else:
# Reached only when the for loop ends without breaking,
# i.e. when the maximum number of samples is exceeded
raise ValueError(
f"Too many samples in ensemble for variable {comp}/{trcr}")
if n_samples == 0:
raise ValueError(
f"No samples detected in the control vector for variable {comp}/{trcr}")
# Initializing empty ensemble data
ens_data = {var_name: [] for var_name in ensemble_variables}
# Iterating over samples to extract ensemble data
for trcr_sample in sample_names:
# Sample datastore
sample_ds = comp_ds[trcr_sample]
# Getting sample data
for var_name in ensemble_variables:
if var_name in sample_ds:
ens_data[var_name].append(sample_ds[var_name])
# Iterating over ensemble variables to format and save data
for var_name in ensemble_variables:
if not ens_data[var_name]:
continue
# Concatenating samples data
var = np.concatenate(
[col[:, np.newaxis] for col in ens_data[var_name]],
axis=1
)
assert var.shape[1] == n_samples
# Formatting ensemble data
tmp_sample_ds = pd.DataFrame(
data=var, columns=sample_names)
tmp_sample_ds = pd.concat(
[index_data, tmp_sample_ds], axis='columns')
tmp_sample_ds.set_index(
['date', 'horiz_id', 'vert_id'], inplace=True)
tmp_sample_ds = tmp_sample_ds.reindex(
target_index, copy=False)
# Initialize the variable to zeros in the controlvect if not already present
ens_var_name = f'{var_name}_ens'
ens_var = getattr(self, ens_var_name,
np.zeros((n_samples, self.dim)))
# Saving formatted data
ens_var[:, tracer.xpointer: tracer.xpointer + tracer.dim] = \
tmp_sample_ds.values.T
setattr(self, ens_var_name, ens_var)
debug("Successfully loaded control vector from {}".format(cntrl_file))
return out