import copy
from logging import debug, warning
from pathlib import Path
import numpy as np
import pandas as pd
import xarray as xr
from ....plugins.transforms.system.fromcontrol.utils.scalemaps import scale2map
from ....utils.dataarrays.reindex import reindex
from ....utils.iterators import iter_tracers
try:
import cPickle as pickle
except ImportError:
import pickle
[docs]
def dump(self, cntrl_file, to_netcdf=False, dir_netcdf=None, ensemble=False, **kwargs):
    """Dumps a control vector into a pickle file.

    Does not save large correlations.

    For every tracer of ``self.datavect`` flagged as ``iscontrol``, the
    tracer's slice of the control-vector attributes ``x``, ``xb``, ``dx``,
    ``std`` and ``pa`` (those present on ``self``; 2-D arrays are reduced to
    their diagonal) is stored with its date, vertical and horizontal indices.
    Optionally, one NetCDF file per tracer is written as well.

    Args:
        self (pycif.utils.classes.controlvects.ControlVect):
            the Control Vector to dump
        cntrl_file (str): path to the file to dump as pickle
        to_netcdf (bool): save to netcdf files if True
        dir_netcdf (str): root path for the netcdf directory; NetCDF files
            are written only when ``to_netcdf`` is True and this is not None
        ensemble (bool): also store the ensemble samples found in the
            ``x_ens`` / ``dx_ens`` attributes, one datastore entry per
            sample, named ``<tracer>__sample#NNN``
        **kwargs: forwarded to ``tracer.read`` when reading the reference
            inputs used for the "physical" projection of scalar tracers

    Raises:
        ValueError: if ``ensemble`` is True but no ``*_ens`` attribute is
            available, or if the reference inputs and ``xb`` have
            incompatible horizontal dimensions.
    """
    debug(f"dumping controlvect to file '{cntrl_file}'")

    # Saving recursive attributes from the Yaml
    # fmt: off
    exclude = [
        "transform", "domain", "datastore", "input_dates", "obsvect", "tracer",
        "input_files", "tstep_dates", "tstep_all", "dataflx", "logfile", "datei",
        "datef", "workdir", "verbose", "subsimu_dates", "tcorrelations",
        "hcorrelations", "databos"
    ]
    # fmt: on
    tosave = self.to_dict(self, exclude_patterns=exclude)

    # Save the control vector as a pandas datastore
    controlvect_ds = {}
    diminfos_ds = {}
    for (comp_name, tracer_name), tracer in iter_tracers(self.datavect):
        # Do nothing if not in control vector
        if not tracer.iscontrol:
            continue

        tracer_type = getattr(tracer, "type", "scalar")
        dump_physical = getattr(tracer, "dump_physical", True)
        is_lbc = getattr(tracer, "is_lbc", False)

        # Chunk of the control vector owned by this tracer and the layout
        # used to reshape it before mapping (dates, levels, horizontal)
        xslice = slice(tracer.xpointer, tracer.xpointer + tracer.dim)
        xshape = (tracer.ndates, tracer.vresoldim, -1)

        # Update controlvect_ds dictionary
        if comp_name not in controlvect_ds:
            controlvect_ds[comp_name] = {}
        if comp_name not in diminfos_ds:
            diminfos_ds[comp_name] = {}

        tmp_ds = {}

        # Fetch information for tmp ds
        var2read = ["x", "xb", "dx", "std", "pa"]
        for varname in var2read:
            if hasattr(self, varname):
                var = getattr(self, varname)
                if var.ndim == 1:
                    tmp_ds[varname] = var[xslice]
                else:
                    # 2-D arrays: only their diagonal is stored
                    tmp_ds[varname] = np.diag(var)[xslice].copy()

        # NOTE(review): np.meshgrid defaults to indexing='xy', so these arrays
        # have shape (vresoldim, ndates, hresoldim) and, once flattened, levels
        # vary slowest, whereas x is reshaped as (ndates, vresoldim, -1). When
        # both ndates and vresoldim exceed 1, the date/vert labels may not match
        # the values next to them. `load` builds its target index the same way,
        # so dump/load round trips stay consistent; confirm before changing.
        dates_id, vert_id, horiz_id = np.meshgrid(
            range(tracer.ndates),
            range(tracer.vresoldim),
            range(tracer.hresoldim),
        )
        tmp_ds["horiz_id"] = horiz_id.flatten()
        del horiz_id
        tmp_ds["vert_id"] = vert_id.flatten()
        del vert_id
        tmp_ds["date"] = tracer.dates[dates_id].flatten()
        del dates_id

        # Reducing memory usage with panda's categorical dtype
        # (this improves pickle's read/write times)
        for col in ["date", "horiz_id", "vert_id"]:
            tmp_ds[col] = pd.Series(tmp_ds[col]).astype("category")

        controlvect_ds[comp_name][tracer_name] = tmp_ds

        # Save pointers
        diminfos_ds[comp_name][tracer_name] = {
            "xpointer": tracer.xpointer,
            "dim": tracer.dim,
        }

        # Variables with ensemble data ('x' ensemble data is named 'x_ens')
        ensemble_variables = ["x", "dx"]

        # Saving ensemble data
        if ensemble:
            # Getting number of samples (first axis of the *_ens arrays)
            n_samples = 0
            for var_name in ensemble_variables:
                ens_var_name = f"{var_name}_ens"
                if hasattr(self, ens_var_name):
                    n_samples = getattr(self, ens_var_name).shape[0]

            if n_samples == 0:
                raise ValueError(
                    "No ensemble variables in the control vector for "
                    f"variable {comp_name}/{tracer_name}"
                )

            # TODO: is it really usefull use separate variables for each
            #  samples? Why not saving the samples matrix as is ?
            for sample_index in range(n_samples):
                trcr_sample = f"{tracer_name}__sample#{sample_index:03d}"

                # TODO: may not be usefull, controlvect.load get those
                #  columns from 'ds[comp][trcr]' and not from
                #  ds[comp][trcr_sample]
                #
                # Copying tracer data in tracer sample
                controlvect_ds[comp_name][trcr_sample] = {
                    col: copy.deepcopy(controlvect_ds[comp_name][tracer_name][col])
                    for col in ["date", "horiz_id", "vert_id"]
                }

                for var_name in ensemble_variables:
                    ens_var_name = f"{var_name}_ens"
                    if not hasattr(self, ens_var_name):
                        continue

                    # Extracting sample data from ensemble data
                    var = getattr(self, ens_var_name)
                    var = var[sample_index, xslice]
                    controlvect_ds[comp_name][trcr_sample][var_name] = var

        # Don't go further if no need to dump as netcdf
        if not to_netcdf or dir_netcdf is None:
            continue

        dir_comp = Path(dir_netcdf, comp_name)
        dir_comp.mkdir(parents=True, exist_ok=True)
        ctrlvect_file = Path(dir_comp, f"controlvect_{comp_name}_{tracer_name}.nc")
        debug(
            f"Dumping controlvect for {comp_name}/{tracer_name} to NetCDF file '{ctrlvect_file}'"
        )

        ds = xr.Dataset()
        # NetCDF variable name -> (controlvect attribute, descriptive name)
        variables = {
            "x": ("x", "posterior"),
            "xb": ("xb", "prior"),
            "dx": ("dx", "increment"),
            "b_std": ("std", "prior_uncertainty_diag"),
            "pa_std": ("pa", "posterior_uncertainty_diag"),
        }
        for varname, (attrname, fullname) in variables.items():
            if varname == "pa_std":
                # Skip the diagonal of posterior uncertainties if not available
                if not hasattr(self, attrname):
                    continue
                pa = np.diag(self.pa) if self.pa.ndim == 2 else self.pa
                var = np.sqrt(pa)
            else:
                var = getattr(self, attrname)

            var = np.reshape(var[xslice], xshape)
            var = scale2map(var, tracer, tracer.dates, tracer.domain)
            var.attrs["standard_name"] = f"{tracer_type}_{fullname}"
            var.attrs["component"] = comp_name
            var.attrs["tracer"] = tracer_name
            var.attrs["type"] = tracer_type
            ds[varname] = var

        # If tracer is scalar, also include the "physical" projection
        if tracer_type == "scalar" and dump_physical:
            # Read the tracer array and apply the present control vector scaling factor
            # Apply same protocol as ini_mapper from transform "fromcontrol"
            # to find correct dates (merged input_dates and tracer dates)
            ds_phys = None
            for di in tracer.input_dates:
                # Skip if input dates are empty for some reason for that
                # period
                if len(tracer.input_dates[di]) == 0:
                    continue

                outdates = (
                    pd.DatetimeIndex(tracer.input_dates[di].stack())
                    .append(pd.DatetimeIndex(tracer.dates))
                    .drop_duplicates()
                    .sort_values()
                )
                if len(outdates) == 1:
                    outdates = np.append(outdates, outdates)

                # input_dates[di] is non-empty here (checked above), so the
                # bounds of the period can always be computed
                mask_min = outdates >= np.min(tracer.input_dates[di]["start_date"])
                mask_max = outdates <= np.max(tracer.input_dates[di]["end_date"])
                outdates = outdates[mask_min & mask_max]
                outdates = pd.to_datetime(outdates)

                # Turn dates to datetime for consistency in read
                # TODO: This should be standardized in the future
                dates2read = [
                    [x.to_pydatetime() for x in row]
                    for row in tracer.input_dates[di].itertuples(index=False, name=None)
                ]

                # Read reference inputs
                inputs = tracer.read(
                    tracer_name,
                    tracer.varname,
                    dates2read,
                    tracer.input_files[di],
                    comp_type=comp_name,
                    tracer=tracer,
                    ddi=di,
                    model=self.model,
                    **kwargs,
                )

                # Check that horizontal dimensions are compatible
                input_dims = inputs.shape[2:]
                xb_dims = ds["xb"].shape[2:]
                if input_dims != xb_dims:
                    raise ValueError(
                        "Dimensions for inputs and xb are not compatible. This can "
                        "arise if `is_lbc` has erroneously been set to True in your "
                        "yaml.\n"
                        f"\t- Input dimension: {input_dims}\n"
                        f"\t- Xb dimension: {xb_dims}"
                    )

                # Reindex xb, x and inputs to common outdates
                inputs = reindex(inputs, levels={"time": outdates[:-1]})

                ds_tmp = xr.Dataset()
                for varname, (attrname, fullname) in variables.items():
                    if varname == "pa_std":
                        # Skip the diagonal of posterior uncertainties if not available
                        if not hasattr(self, attrname):
                            continue

                    var = inputs * reindex(
                        ds[varname], levels={"time": outdates[:-1], "lev": inputs.lev}
                    )
                    var.attrs["standard_name"] = f"physical_{fullname}"
                    var.attrs["component"] = comp_name
                    var.attrs["tracer"] = tracer_name
                    var.attrs["type"] = tracer_type
                    ds_tmp[f"{varname}_phys"] = var

                if ds_phys is None:
                    ds_phys = ds_tmp
                else:
                    ds_phys = xr.concat([ds_phys, ds_tmp], dim="time", join="inner")

            if ds_phys is None:
                # Every input period was empty: nothing to project. Previously
                # this crashed on `ds_phys.rename` / `ds["time_phys"]`.
                warning(
                    f"No reference inputs available for {comp_name}/{tracer_name}; "
                    "the physical projection is not included in the NetCDF file"
                )
            else:
                # Drop duplicated times
                index_unique = np.unique(ds_phys["time"], return_index=True)[1]
                ds_phys = ds_phys.isel({"time": index_unique})

                # Merge with non-physical values
                ds_phys = ds_phys.rename({
                    "time": "time_phys", "lev": "lev_phys"
                })
                ds = ds.merge(ds_phys)

        # Adding longitudes and latitudes
        if is_lbc:
            lat = tracer.domain.zlat_side
            lon = tracer.domain.zlon_side
            latc = tracer.domain.zlatc_side
            lonc = tracer.domain.zlonc_side
        else:
            lat = tracer.domain.zlat
            lon = tracer.domain.zlon
            latc = tracer.domain.zlatc
            lonc = tracer.domain.zlonc
            # Areas are only written for non-LBC tracers
            if not hasattr(tracer.domain, "areas"):
                tracer.domain.calc_areas()

        if tracer.domain.unstructured_domain:
            ds = ds.squeeze("lat")
            ds = ds.rename_dims(lon="cell")
            ds = ds.assign_coords(
                lat=(["cell"], lat[0, :]),
                lon=(["cell"], lon[0, :]),
                lat_bnds=(["cell", "vertex"], latc.T),
                lon_bnds=(["cell", "vertex"], lonc.T),
            )
            if not is_lbc:
                ds["areas"] = (["cell"], tracer.domain.areas)
        else:
            # Squeeze lat lon coords if they are regular
            if (
                np.all(lon == lon[0, :])
                and np.all(lat == lat[:, [0]])
                and np.all(lonc == lonc[0, :])
                and np.all(latc == latc[:, [0]])
            ):
                # Build (n, 2) bounds from consecutive corner values
                lon_bnds = np.concatenate(
                    [lonc[0, :-1, np.newaxis], lonc[0, 1:, np.newaxis]], axis=1
                )
                lat_bnds = np.concatenate(
                    [latc[:-1, 0, np.newaxis], latc[1:, 0, np.newaxis]], axis=1
                )
                ds = ds.assign_coords(
                    lat=(["lat"], lat[:, 0]),
                    lon=(["lon"], lon[0, :]),
                    lat_bnds=(["lat", "bnds"], lat_bnds),
                    lon_bnds=(["lon", "bnds"], lon_bnds),
                )
            # Otherwise keep 2D lat and lon
            else:
                # NOTE(review): these 2-D coordinates share their name with one
                # of their dimensions (lat/lat_bnds); some xarray versions reject
                # that. Confirm this branch against the xarray version in use.
                ds = ds.assign_coords(
                    lat=(["lat", "lon"], lat),
                    lon=(["lat", "lon"], lon),
                    lat_bnds=(["lat_bnds", "lon_bnds"], latc),
                    lon_bnds=(["lat_bnds", "lon_bnds"], lonc),
                )
            if not is_lbc:
                ds["areas"] = (["lat", "lon"], tracer.domain.areas)

        # Squeeze vertical dimension if it is 1
        if tracer.vresoldim == 1:
            ds = ds.squeeze("lev")
            ds = ds.drop_vars("lev")

        # Coordinates attributes
        ds["time"].attrs = {
            "standard_name": "time",
            "long_name": "time",
        }
        # Only present when the physical projection was actually merged
        if "time_phys" in ds:
            ds["time_phys"].attrs = {
                "standard_name": "time_physical",
                "long_name": "time_physical",
            }
        ds["lat"].attrs = {
            "standard_name": "latitude",
            "long_name": "latitude",
            "units": "degrees_north",
            "bounds": "lat_bnds",
        }
        ds["lon"].attrs = {
            "standard_name": "longitude",
            "long_name": "longitude",
            "units": "degrees_east",
            "bounds": "lon_bnds",
        }
        ds["lat_bnds"].attrs = {
            "standard_name": "latitude_bounds",
            "long_name": "latitude bounds",
            "units": "degrees_north",
        }
        ds["lon_bnds"].attrs = {
            "standard_name": "longitude_bounds",
            "long_name": "longitude bounds",
            "units": "degrees_east",
        }
        if not is_lbc:
            ds["areas"].attrs = {
                "standard_name": "area",
                "long_name": "area",
                "units": "m2",
            }

        # Drop any explicit "coordinates" attribute from data variables
        for varname in ds:
            if "coordinates" in ds[varname].attrs:
                ds[varname].attrs.pop("coordinates")

        # Global attributes
        now = pd.Timestamp.now().round("s").isoformat()
        ds.attrs = {
            "title": f"CIF control vector for {comp_name}/{tracer_name}",
            "history": f"{now}: pycif, file created;",
        }

        # Dumping (remove any previous file first)
        if ctrlvect_file.exists():
            ctrlvect_file.unlink()
        ds.to_netcdf(ctrlvect_file)

    # Dumping the dictionary to a pickle
    tosave["datastore"] = controlvect_ds
    tosave["dim_infos"] = diminfos_ds
    with open(cntrl_file, "wb") as f:
        pickle.dump(tosave, f, pickle.HIGHEST_PROTOCOL)
[docs]
def load(
    self,
    cntrl_file,
    component2load=None,
    tracer2load=None,
    target_tracer=None,
    ensemble=False,
    **kwargs,
):
    """Loads a control vector from a pickle file written by ``dump``.

    The pickle's entries other than ``"datastore"`` are passed to
    ``self.from_dict`` and its result is returned. For each control tracer
    found in the datastore, the stored ``x``, ``xb``, ``dx``, ``std`` and
    ``pa`` columns are reindexed on (date, horiz_id, vert_id) and written
    into the tracer's slice of the matching attribute of ``self``. A missing
    attribute is first created as zeros of length ``self.dim``.

    Args:
        self (pycif.utils.classes.controlvects.ControlVect):
            the Control Vector to fill in place
        cntrl_file (str): path to the pickle file
        component2load (str): only load this component (default: all
            components of ``self.datavect``)
        tracer2load (str): only load this tracer, as named in the pickle
            (default: all parameters of each component)
        target_tracer (str): name of the parameter of ``self`` that receives
            the data read for each tracer (default: the same name)
        ensemble (bool): also load the ``<tracer>__sample#NNN`` entries into
            ``x_ens`` / ``dx_ens`` arrays of shape (n_samples, self.dim)
        **kwargs: accepted for interface compatibility; not used here

    Returns:
        The value returned by ``self.from_dict``.

    Raises:
        ValueError: when ``ensemble`` is True and no sample is found, more
            than 1000 samples are found, or samples do not all provide the
            same variables.
    """
    debug(f"Loading control vector from {cntrl_file}")

    if ensemble and component2load is None and tracer2load is None:
        warning(
            "Trying to load control vector ensemble data without "
            "specifying some specific component and tracer to load"
        )

    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(cntrl_file, "rb") as f:
        toread = pickle.load(f)

    out_ds = toread.pop("datastore")
    out = self.from_dict(toread)

    # Loop over components and tracers
    components = self.datavect.components
    list_components = (
        components.attributes if component2load is None else [component2load]
    )
    for comp in list_components:
        component = getattr(components, comp)

        # Skip if component does not have parameters
        if not hasattr(component, "parameters"):
            continue

        # Skip if component not in pickle
        if comp not in out_ds:
            warning(f"Could not read component '{comp}' for pickle {cntrl_file}.")
            continue

        comp_ds = out_ds[comp]
        list_tracers = (
            component.parameters.attributes if tracer2load is None else [tracer2load]
        )
        for trcr in list_tracers:
            tracer = getattr(
                component.parameters, trcr if target_tracer is None else target_tracer
            )

            # Do nothing if not in control vector
            if not tracer.iscontrol:
                continue

            xslice = slice(tracer.xpointer, tracer.xpointer + tracer.dim)

            # Skip if tracer not in pickle
            if trcr not in comp_ds:
                warning(
                    f"Could not read tracer '{comp}/{trcr}' for pickle {cntrl_file}."
                )
                continue

            debug(f"Loading variable {comp}/{trcr}")

            # Fill the correct chunk with corresponding values.
            # NOTE(review): np.meshgrid uses indexing='xy' here, exactly as in
            # `dump`, so the labels match those written there; keep both in sync.
            dates_id, vert_id, horiz_id = np.meshgrid(
                range(tracer.ndates), range(tracer.vresoldim), range(tracer.hresoldim)
            )
            dates_id = tracer.dates[dates_id]
            target_index = pd.MultiIndex.from_arrays(
                [dates_id.flatten(), horiz_id.flatten(), vert_id.flatten()]
            )

            tmp_ds = pd.DataFrame(comp_ds[trcr])

            # Removing categorical dtypes (speeds up the reindexing)
            tmp_ds["date"] = tmp_ds.date.astype("datetime64[ns]")
            tmp_ds["horiz_id"] = tmp_ds.horiz_id.astype("int")
            tmp_ds["vert_id"] = tmp_ds.vert_id.astype("int")

            if ensemble:
                # Keeping index data for ensemble variables (sample entries
                # are stored in the same row order as the tracer entry)
                index_data = tmp_ds[["date", "horiz_id", "vert_id"]].copy()

            # Reindexing
            tmp_ds.set_index(["date", "horiz_id", "vert_id"], inplace=True)
            tmp_ds = tmp_ds.reindex(target_index, copy=False)

            # Loop over variables to initialize
            var2read = ["x", "xb", "dx", "std", "pa"]
            for var in var2read:
                if var not in tmp_ds:
                    continue

                # Initialize variable to zero in controlvect if not already
                # here (allocate only when actually missing)
                if hasattr(self, var):
                    array = getattr(self, var)
                else:
                    array = np.zeros(self.dim)
                array[xslice] = tmp_ds[var].values
                setattr(self, var, array)

            # Variables with ensemble data ('x' ensemble data is named 'x_ens')
            ensemble_variables = ["x", "dx"]

            # Loading ensemble data
            if ensemble:
                # Getting variable names and number of samples
                sample_names = []
                max_number_of_samples = 1000
                for sample_index in range(max_number_of_samples + 1):
                    # Sample variable name
                    trcr_sample = f"{trcr}__sample#{sample_index:03d}"
                    if trcr_sample in comp_ds:
                        sample_names.append(trcr_sample)
                    else:
                        n_samples = sample_index
                        break
                else:
                    # Branch here when reaching the end of the for loop without
                    # breaking, i.e. when max number of samples is reached.
                    raise ValueError(
                        f"Too many sample in ensemble for variable {comp}/{trcr}"
                    )

                if n_samples == 0:
                    raise ValueError(
                        f"No samples detected in the control vector for variable {comp}/{trcr}"
                    )

                # Initializing empty ensemble data
                ens_data = {var_name: [] for var_name in ensemble_variables}

                # Iterating over samples to extract ensemble data
                for trcr_sample in sample_names:
                    # Sample datastore
                    sample_ds = comp_ds[trcr_sample]

                    # Getting sample data
                    for var_name in ensemble_variables:
                        if var_name in sample_ds:
                            ens_data[var_name].append(sample_ds[var_name])

                # Iterating over ensemble variables to format and save data
                for var_name in ensemble_variables:
                    if not ens_data[var_name]:
                        continue

                    # Concatenating samples data: one column per sample
                    var = np.concatenate(
                        [col[:, np.newaxis] for col in ens_data[var_name]], axis=1
                    )
                    # Explicit check (an assert would be stripped under -O)
                    if var.shape[1] != n_samples:
                        raise ValueError(
                            f"Only {var.shape[1]} of {n_samples} samples provide "
                            f"variable '{var_name}' for {comp}/{trcr}"
                        )

                    # Formatting ensemble data
                    tmp_sample_ds = pd.DataFrame(data=var, columns=sample_names)
                    tmp_sample_ds = pd.concat(
                        [index_data, tmp_sample_ds], axis="columns"
                    )
                    tmp_sample_ds.set_index(
                        ["date", "horiz_id", "vert_id"], inplace=True
                    )
                    tmp_sample_ds = tmp_sample_ds.reindex(target_index, copy=False)

                    # Initialize variable to zero in controlvect if not already here
                    ens_var_name = f"{var_name}_ens"
                    if hasattr(self, ens_var_name):
                        ens_var = getattr(self, ens_var_name)
                    else:
                        ens_var = np.zeros((n_samples, self.dim))

                    # Saving formatted data (samples along the first axis)
                    ens_var[:, xslice] = tmp_sample_ds.values.T
                    setattr(self, ens_var_name, ens_var)

    debug(f"Successfully loaded control vector from {cntrl_file}")

    return out