import pickle
import copy
import subprocess
import os
import time
import numpy as np
import pandas as pd
from logging import info, warning
from .segment_cropping import get_cntrlv_trdates, crop_dump
from .utils import get_path_in_dict, get_trids_dontpropagate, get_trids_dontpropagate_obsvect
from ....utils import path
from ....utils.yml import ordered_dump
from ....utils.dates import date_range
def draw_normal_samples(self, controlvect, nsample, set_dev_equal=True):
    """Draw a new ensemble of nsample members for the full assimilation window
    based on the prior mean (xb), the B matrix and a normal distribution.

    Standard-normal deviations are drawn, optionally de-biased (when
    ``self.unbias_ensemble`` is set), then passed through
    ``controlvect.sqrtbprod`` to apply ``s = sqrt(B) * s + xb``.

    Args:
        self (Plugin): mode Plugin; reads ``unbias_ensemble`` and, when
            ``seed`` is truthy, ``seed_id``
        controlvect (Plugin): controlvect Plugin
        nsample (int): number of members (besides prior and mean)
        set_dev_equal (bool): if True, all windows of a given controlled
            component/tracer share the same deviations (one draw of
            ``tracer.dim // tracer.ndates`` values tiled ``tracer.ndates``
            times); otherwise every control-vector element is drawn
            independently

    Returns:
        x_samples (np.array): created members, one member per row

    Raises:
        ValueError: if ``set_dev_equal`` is True and the dimension of a
            controlled tracer is not a multiple of its number of dates
    """
    # TODO: samples should be transposed (makes more sense)
    if set_dev_equal:
        info("Deviations are set equal.")
        samples = []
        components = controlvect.datavect.components
        # Seed offset incremented once per controlled tracer, so every tracer
        # gets its own seed. The former ``seed_id + icomp + itr`` collided
        # (e.g. icomp=0/itr=1 vs icomp=1/itr=0) and gave distinct tracers
        # identical deviations.
        seed_offset = 0
        for comp in components.attributes:
            component = getattr(components, comp)
            # Skip if component does not have parameters
            if not hasattr(component, "parameters"):
                continue
            for trcr in component.parameters.attributes:
                tracer = getattr(component.parameters, trcr)
                # Do nothing if not in control vector
                if not tracer.iscontrol:
                    continue
                # All windows share the same deviations for a comp/tracer,
                # which needs an integer number of elements per window
                dim_wind, remainder = divmod(tracer.dim, tracer.ndates)
                if remainder:
                    raise ValueError(
                        f"Cannot set equal deviations for {comp}/{trcr}: its "
                        f"dimension ({tracer.dim}) is not a multiple of its "
                        f"number of dates ({tracer.ndates})."
                    )
                if getattr(self, "seed", False):
                    seed = self.seed_id + seed_offset
                    info(f"A seed of {seed} for {comp}/{trcr} is used.")
                    np.random.seed(seed)
                seed_offset += 1
                samples_trcr_wind = np.random.normal(size=(nsample, dim_wind))
                # Repeat the window deviations for every date/window
                samples.append(np.tile(samples_trcr_wind, tracer.ndates))
        samples = np.concatenate(samples, axis=1)
    else:
        # Set different deviations for each window
        info("Deviations are set different.")
        if getattr(self, "seed", False):
            info(f"A seed of {self.seed_id} is used.")
            np.random.seed(self.seed_id)
        samples = np.random.normal(size=(nsample, controlvect.dim))

    # Adjust the mean and the std (std with ddof=1 is undefined below 2 members)
    if self.unbias_ensemble:
        if nsample < 2:
            warning("Cannot remove sampling bias with fewer than 2 members; "
                    "the ensemble is left as drawn.")
        else:
            info("Remove sampling bias")
            samples = (samples - samples.mean(axis=0)) / samples.std(axis=0, ddof=1)

    # Apply s = sqrt(B) * s + xb to set mean=xb and cov=B
    samples = controlvect.sqrtbprod(samples.T, ensemble=True).T
    # TODO: when correlations, the std of each sample is not necessarily 1.0, why ?
    return samples
def generate_ensemble_full_window(self, controlvect, nsample, set_dev_equal=True):
    """Build the complete ensemble (nsample + 3 members) for the full
    assimilation window from the prior mean (xb), the B matrix and a normal
    distribution.

    Row layout of the returned array:
        * row 0: only 1.0 (for propagating unperturbed flux)
        * row 1: the prior ``xb`` (for calculating the prior simulation)
        * row 2: initialised to ``xb``; holds the mean / analysis (for
          calculating the posterior simulation)
        * rows 3 onwards: the ``nsample`` random members returned by
          ``draw_normal_samples``

    Args:
        self (Plugin): mode Plugin
        controlvect (Plugin): controlvect Plugin
        nsample (int): number of members (besides prior and mean)
        set_dev_equal (bool): forwarded to ``draw_normal_samples``

    Returns:
        x_samples (np.array): created members, one member per row
    """
    random_members = draw_normal_samples(
        self, controlvect, nsample, set_dev_equal=set_dev_equal
    )
    prior_row = controlvect.xb[np.newaxis, :]
    unit_row = np.ones((1, controlvect.dim))
    # np.concatenate always allocates a fresh array, so the xb rows are
    # independent copies of controlvect.xb
    stacked = [unit_row, prior_row, prior_row, random_members]
    return np.concatenate(stacked, axis=0)
def _component_propagation_weight(prop_weights, comp):
    """Return the mean-propagation weight to apply to component ``comp``.

    ``prop_weights`` is either a single number used for every component, or an
    object exposing one weight per component as an attribute. A component with
    no prescribed weight gets 0, and a warning is logged.
    """
    # Only one value for every component
    if isinstance(prop_weights, (int, float, np.number)):
        return prop_weights
    # Different values for each component
    if hasattr(prop_weights, comp):
        return getattr(prop_weights, comp)
    warning(f"No propagation weight has been prescribed for this component: {comp}.\n"
            "The weight is set to 0.")
    return 0


def update_ensemble_new_segment(self, controlvect, x_samples, ddi, ddf,
                                list_cntrlv_idx, set_dev_equal=False):
    """When a new segment/cycle is started, update the mean of the new window
    based on the posterior mean of the window before it.

    For every controlled tracer, the elements of the new (last) window of the
    segment are set to
    ``w * x[previous window] + (1 - w) * xb[previous window]``, where ``w`` is
    the component's weight from ``self.mean_propagwgt``. This is a very simple
    forecast step. The result is written into ``controlvect.x`` and into
    member 2 (the mean) of ``x_samples``. The random members (rows 3 onwards)
    of the new window are shifted by ``x - xb`` over that window.

    Nothing is done if ``self`` has no ``window_length`` or if
    ``self.nlag < 2``.

    Args:
        self (Plugin): mode Plugin
        controlvect (Plugin): controlvect Plugin (``x`` is modified in place)
        x_samples (np.array): existing members (modified in place)
        ddi (datetime): initial date of the segment
        ddf (datetime): end date of the segment
        list_cntrlv_idx: currently unused; kept for interface compatibility
        set_dev_equal: currently unused; kept for interface compatibility.
            It was meant to force the deviations of scaling factors for
            different windows in the same segment to be equal, to limit
            spurious anticorrelations between the same control variables in
            two different windows.

    Returns:
        x_samples (np.array): updated members
    """
    # Nothing to propagate without sub-windows or with fewer than two lags
    if not hasattr(self, "window_length"):
        return x_samples
    if self.nlag < 2:
        return x_samples

    info("Propagate posterior mean from the previous window to the new window...")

    # The last two windows of the segment: previous window
    # [list_windows[-3], list_windows[-2]) and new window
    # [list_windows[-2], list_windows[-1])
    list_windows = date_range(ddi, ddf, self.window_length)
    ddi_previous_wind, ddf_previous_wind = list_windows[-3], list_windows[-2]
    ddi_new_wind, ddf_new_wind = list_windows[-2], list_windows[-1]

    prop_weights = self.mean_propagwgt
    dict_components_dates = get_cntrlv_trdates(controlvect)
    components = controlvect.datavect.components
    for comp in components.attributes:
        component = getattr(components, comp)
        # Resolved on first use, so the "missing weight" warning is emitted
        # at most once per component
        prop_weight_comp = None
        for trcr, dates_array in dict_components_dates[comp].items():
            tracer = getattr(component.parameters, trcr)
            # Control-vector indices of this tracer, selected per window
            inds_tracer = np.arange(tracer.xpointer, tracer.xpointer + tracer.dim)
            inds_new_wind = inds_tracer[
                (ddi_new_wind <= dates_array) & (dates_array < ddf_new_wind)]
            if inds_new_wind.size == 0:
                continue
            inds_previous_wind = inds_tracer[
                (ddi_previous_wind <= dates_array) & (dates_array < ddf_previous_wind)]

            if prop_weight_comp is None:
                prop_weight_comp = _component_propagation_weight(prop_weights, comp)

            # New mean: blend of posterior and prior of the previous window.
            # NOTE(review): the prior term uses xb of the *previous* window,
            # not of the new one; these coincide only when xb is the same in
            # both windows. Confirm this is intended.
            # TODO: generalize that using an averaging kernel of all windows
            new_mean = prop_weight_comp * controlvect.x[inds_previous_wind] + \
                (1. - prop_weight_comp) * controlvect.xb[inds_previous_wind]
            controlvect.x[inds_new_wind] = new_mean
            x_samples[2, inds_new_wind] = new_mean

            # Shift the random members of the new window by how far the mean
            # moved away from the prior. Only this slice is computed, instead
            # of a full ensemble-sized temporary per tracer.
            mean_shift = controlvect.x[inds_new_wind] - controlvect.xb[inds_new_wind]
            x_samples[3:, inds_new_wind] = x_samples[3:, inds_new_wind] + mean_shift
    return x_samples
def _split_sample_chunks(nsample, max_nsamples_per_run, include_system_samples):
    """Split ensemble member indices 0 .. nsample + 2 into per-run chunks.

    Members 0, 1 and 2 are the system samples (unit, prior, mean). When
    ``include_system_samples`` is truthy, they are prepended to every chunk
    and only members 3 .. nsample + 2 are split. Every chunk then holds at
    most ``max_nsamples_per_run`` members in total.

    Raises:
        ValueError: if a run could not hold at least one non-system member.
    """
    nsystem = 3 if include_system_samples else 0
    step = max_nsamples_per_run - nsystem
    if step < 1:
        raise ValueError(
            "max_nsamples_per_run ({}) must exceed the number of system "
            "samples included in each run ({})".format(max_nsamples_per_run, nsystem))
    last = nsample + 3
    chunks = [list(range(start, min(start + step, last)))
              for start in range(nsystem, last, step)]
    if include_system_samples:
        # Always run the system samples, even without any random member
        chunks = [[0, 1, 2] + chunk for chunk in chunks] if chunks else [[0, 1, 2]]
    return chunks


def run_ensemble(self, controlvect, nsample, x_samples, segment_dir, ddi, ddf, same_restart=True):
    """Dump the controlvect and run the ensemble with a batch-sampling method
    or a separate-sampling method.

    The whole ensemble is always dumped to ``{segment_dir}/x_samples/``.
    With ``self.batch_sampling`` all members go through ``batch_sampling``.
    Otherwise the members are split into chunks of at most
    ``self.max_nsamples_per_run`` members. When
    ``self.include_system_samples`` is set, the three system members are
    repeated in every chunk. Each chunk is dumped to
    ``x_samples/x_samples_XXXX/`` and the chunks are run by
    ``separate_sampling``.

    Args:
        self (Plugin): mode Plugin
        controlvect (Plugin): controlvect Plugin
        nsample (int): number of members (besides prior and mean)
        x_samples (np.array): existing members
        segment_dir (str): path to the segment directory
        ddi (datetime): initial date of the segment
        ddf (datetime): end date of the segment
        same_restart (boolean): whether samples in restart_inicond should be set equal to the mean or not

    Returns:
        None

    Raises:
        ValueError: if ``max_nsamples_per_run`` leaves no room for any
            non-system member in a run (separate sampling only)
    """
    # Create the sample dir for the segment
    samples_dir = "{}/x_samples/".format(segment_dir)
    path.init_dir(samples_dir)

    # Crop outside the segment and dump the samples
    controlvect.x_ens = x_samples
    cntrlv_file = "{}/controlvect_ensemble.pickle".format(samples_dir)
    crop_dump(controlvect, cntrlv_file, ddi, ddf,
              ensemble=True, dump_indexes=True)
    self.direc2flush.append(cntrlv_file)

    if self.batch_sampling:
        batch_sampling(
            self, controlvect, segment_dir, nsample, ddi, ddf, same_restart=same_restart)
        return

    # Separate sampling: get the list of samples to transport within each run
    sample_chunks = _split_sample_chunks(
        nsample, self.max_nsamples_per_run, self.include_system_samples)
    if self.include_system_samples:
        info("The three system-bound samples will be included in each run.")

    # Crop outside the segment and dump the controlvect for each sample chunk
    info("Sampling will be split into {} runs".format(len(sample_chunks)))
    for ichunk, sch in enumerate(sample_chunks):
        sample_chunk_dir = "{}/x_samples_{:04d}/".format(samples_dir, ichunk)
        path.init_dir(sample_chunk_dir)
        # NOTE(review): controlvect.x_ens is left pointing to the last chunk
        # after this loop. Confirm that no later step expects the full ensemble.
        controlvect.x_ens = x_samples[sch]
        chunk_cntrlv_file = "{}/controlvect_ensemble.pickle".format(sample_chunk_dir)
        crop_dump(controlvect, chunk_cntrlv_file, ddi, ddf,
                  ensemble=True, dump_indexes=True)
        self.direc2flush.append(sample_chunk_dir)

    separate_sampling(
        self, controlvect, segment_dir, nsample, sample_chunks, ddi, ddf, same_restart=same_restart)
    return
def batch_sampling(self, controlvect, segment_dir, nsample, ddi, ddf, same_restart=True):
    """Create a new Yaml configuration file using the batch_computation option
    and run it, as a sub-process or a sub-job, to perform the batch sampling.

    The generated configuration runs a forward simulation over
    ``[ddi, ddf]`` in ``{segment_dir}/batch_sampling/``, with all
    ``nsample + 3`` members read from ``{segment_dir}/x_samples/``. When the
    segment does not start at ``self.datei``, restart files from
    ``{workdir}/ensemble/chain/`` replace the ``inicond`` component. The
    per-sample observation monitors are then linked into
    ``{segment_dir}/H_matrix/obsvect_XXXX/<component>/<parameter>/monitor.nc``.

    Args:
        self (Plugin): mode Plugin
        controlvect (Plugin): controlvect Plugin
        segment_dir (str): path to the segment directory
        nsample (int): number of members (besides prior and mean)
        ddi (datetime): initial date of the segment
        ddf (datetime): end date of the segment
        same_restart (boolean): whether samples in restart_inicond should be set equal to the mean or not

    Returns:
        None

    Raises:
        RuntimeError: if the pycif sub-process (``self.batch_subjob`` false)
            exits with a non-zero status
    """
    workdir = controlvect.workdir
    platform = self.platform

    # Create the batch directory
    samples_dir = "{}/x_samples/".format(segment_dir)
    batch_dir = "{}/batch_sampling/".format(segment_dir)
    path.init_dir(batch_dir)

    # Dump the reference control vector to the batch directory
    cntrlv_file = "{}/controlvect.pickle".format(batch_dir)
    crop_dump(controlvect, cntrlv_file, ddi, ddf)

    # Updating configuration dictionary
    yml_dict = \
        self.from_yaml(self.reference_instances["reference_setup"].def_file)
    yml_dict.update(
        {"workdir": batch_dir,
         "datei": ddi,
         "datef": ddf,
         "mode": {"plugin": {"name": "forward", "version": "std"}},
         }
    )
    yml_dict["obsvect"] = {
        **yml_dict.get("obsvect", {}),
        **{"plugin": {"name": "standard", "version": "std"},
           "dir_obsvect": "{}/obsvect/".format(workdir), "dump_type": "nc"}
    }
    yml_dict["controlvect"] = {
        **yml_dict.get("controlvect", {}),
        **{"plugin": {"name": "standard", "version": "std"},
           "reload_xb": True,
           "reload_file": "{}/controlvect.pickle".format(batch_dir)}}
    yml_dict["obsoperator"] = {
        **yml_dict.get("obsoperator", {}),
        **{"plugin": {"name": "standard", "version": "std"},
           "batch_computation": {
               "nsamples": nsample + 3,
               "dir_samples": samples_dir,
               "dont_propagate": get_trids_dontpropagate(yml_dict, controlvect),
               "dont_propagate_obsvect": get_trids_dontpropagate_obsvect(yml_dict, controlvect)
           }}}

    # Indicate that a restart file from a previous segment must be used
    if self.datei != ddi:
        # TODO: change ? some models might not have chemistry plugin ?
        acspecies = self.obsoperator.model.chemistry.acspecies.attributes
        dd_restart = ddi - pd.Timedelta(self.obsoperator.model.periods)
        rf_stem, rf_ext = os.path.splitext(self.restart_format)
        post_restart_dir = "{}/ensemble/chain/".format(workdir)
        post_restart_file = dd_restart.strftime("{}_post{}".format(rf_stem, rf_ext))
        yml_dict["model"] = {
            **yml_dict.get("model", {}),
            **{"ensrf_restart_file": True,
               "ensrf_datei": self.datei,
               "ensrf_same_restart": same_restart
               }
        }
        # Create datavect/components if the reference setup lacks them
        datavect_components = \
            yml_dict.setdefault("datavect", {}).setdefault("components", {})
        datavect_components["restart_inicond"] = {
            "parameters": {
                spec: {
                    "dir": post_restart_dir,
                    "file": post_restart_file,
                }
                for spec in acspecies
            }
        }
        # Remove inicond in datavect and transforms calling inicond
        datavect_components.pop("inicond", None)
        if "transform_pipe" in yml_dict["controlvect"]:
            config_trfs = yml_dict["controlvect"]["transform_pipe"]
            paths_trfs_cntrlv = get_path_in_dict(config_trfs, "inicond")
            for p in paths_trfs_cntrlv:
                if p[0] in config_trfs:
                    del config_trfs[p[0]]
        # Add restart_inicond to dont_propagate
        yml_dict["obsoperator"]["batch_computation"]["dont_propagate"] += \
            [['restart_inicond', spec] for spec in acspecies]

    # Dumps new yml file
    yml_file = "{}/config_batch.yml".format(batch_dir)
    with open(yml_file, "w") as outfile:
        ordered_dump(outfile, yml_dict)

    # Either submit a sub-job or run inside the same job
    if not self.batch_subjob:
        info("Running the batch sampling as a sub-process")
        # The interpreter string may carry options, so it is still split on
        # whitespace; the yml path is passed as a single argument
        command = str(platform.python).split() + ["-m", "pycif", yml_file]
        process = subprocess.run(
            command,
            cwd=batch_dir,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        # Fail loudly instead of post-processing missing or stale outputs
        if process.returncode != 0:
            stderr_tail = process.stderr.decode(errors="replace")[-2000:]
            raise RuntimeError(
                "Batch sampling sub-process failed (exit code {}) for {}:\n{}"
                .format(process.returncode, yml_file, stderr_tail))
    else:
        job_file = os.path.join(batch_dir, "job_pycif_batch")
        info("Submitting batch sampling for {} sample"
             .format(nsample))
        job_id = platform.submit_job(
            "{} -m pycif {}".format(platform.python, yml_file),
            job_file
        )
        # Check that jobs are over
        while not platform.check_jobs([job_id]):
            time.sleep(platform.sleep_time)

    # Create the H_matrix directory
    h_dir = "{}/H_matrix/".format(segment_dir)
    path.init_dir(h_dir)
    self.direc2flush.append(h_dir)

    # Post-process monitors and fetch perturbed xb
    fwd_dir = "{}/obsoperator/fwd_0000/obsvect/".format(batch_dir)
    for component in os.listdir(fwd_dir):
        comp_dir = "{}/{}/".format(fwd_dir, component)
        # Parameter name -> True if at least one per-sample ("__sample#")
        # directory exists. This result does not depend on listdir order.
        dict_param_pert = {}
        for entry in os.listdir(comp_dir):
            param = entry.split("__sample#")[0]
            dict_param_pert[param] = \
                dict_param_pert.get(param, False) or ("__sample#" in entry)
        for param, pert in dict_param_pert.items():
            for isample in range(nsample + 3):
                if pert:
                    param_monitor = "{}/{}__sample#{:03d}/monitor.nc".format(
                        comp_dir, param, isample)
                else:
                    # Not perturbed: every sample shares the same monitor
                    param_monitor = "{}/{}/monitor.nc".format(
                        comp_dir, param)
                param_hmatrix = "{}/obsvect_{:04d}/{}/{}".format(
                    h_dir, isample, component, param)
                path.init_dir(param_hmatrix)
                path.link(param_monitor, os.path.join(param_hmatrix, "monitor.nc"))
    return
def separate_sampling(self, controlvect, segment_dir, nsample, sample_chunks, ddi, ddf, same_restart=True):
    """Create one Yaml configuration file per chunk of members and launch one
    job per chunk to perform each sampling.

    Chunk ``i`` runs a forward simulation in ``{segment_dir}/sampling_i/``,
    with the members dumped in ``{segment_dir}/x_samples/x_samples_i/``. When
    the segment does not start at ``self.datei``, per-chunk restart files from
    ``{workdir}/ensemble/chain/sampling_i/`` replace the ``inicond``
    component. After all jobs are over, the per-sample observation monitors
    are linked into
    ``{segment_dir}/H_matrix/obsvect_XXXX/<component>/<parameter>/monitor.nc``,
    where ``XXXX`` is the index of the member in the full ensemble.

    Args:
        self (Plugin): mode Plugin
        controlvect (Plugin): controlvect Plugin
        segment_dir (str): path to the segment directory
        nsample (int): number of members (besides prior and mean); not used
            here, kept for interface compatibility
        sample_chunks (list): list of ensemble member indices in each chunk
        ddi (datetime): initial date of the segment
        ddf (datetime): end date of the segment
        same_restart (boolean): whether samples in restart_inicond should be set equal to the mean or not

    Returns:
        None
    """
    workdir = controlvect.workdir
    platform = self.platform
    samples_dir = "{}/x_samples/".format(segment_dir)

    # Prepare and submit one forward run per chunk of samples
    list_jobs = []
    for ichunk, sch in enumerate(sample_chunks):
        # Create the run directory
        sample_chunk_dir = "{}/x_samples_{:04d}/".format(samples_dir, ichunk)
        base_dir = "{}/sampling_{:04d}/".format(segment_dir, ichunk)
        path.init_dir(base_dir)

        # Dump the reference control vector to the sample directory
        cntrlv_file = "{}/controlvect.pickle".format(base_dir)
        crop_dump(controlvect, cntrlv_file, ddi, ddf)

        # Updating configuration dictionary
        yml_dict = \
            self.from_yaml(self.reference_instances["reference_setup"].def_file)
        yml_dict.update(
            {"workdir": base_dir,
             "datei": ddi,
             "datef": ddf,
             "mode": {"plugin": {"name": "forward", "version": "std"}},
             }
        )
        yml_dict["obsvect"] = {
            **yml_dict.get("obsvect", {}),
            **{"plugin": {"name": "standard", "version": "std"},
               "dir_obsvect": "{}/obsvect/".format(workdir), "dump_type": "nc"}
        }
        yml_dict["controlvect"] = {
            **yml_dict.get("controlvect", {}),
            **{"plugin": {"name": "standard", "version": "std"},
               "reload_xb": True,
               "reload_file": "{}/controlvect.pickle".format(base_dir)}}
        yml_dict["obsoperator"] = {
            **yml_dict.get("obsoperator", {}),
            **{"plugin": {"name": "standard", "version": "std"},
               "batch_computation": {
                   "nsamples": len(sch),
                   "dir_samples": sample_chunk_dir,
                   "dont_propagate": get_trids_dontpropagate(yml_dict, controlvect),
                   "dont_propagate_obsvect": get_trids_dontpropagate_obsvect(yml_dict, controlvect)
               }}}

        # Indicate that a restart file from a previous segment must be used
        if self.datei != ddi:
            # TODO: change ? some models might not have chemistry plugin ?
            acspecies = self.obsoperator.model.chemistry.acspecies.attributes
            dd_restart = ddi - pd.Timedelta(self.obsoperator.model.periods)
            rf_stem, rf_ext = os.path.splitext(self.restart_format)
            post_restart_dir = "{}/ensemble/chain/sampling_{:04d}/".format(workdir, ichunk)
            post_restart_file = dd_restart.strftime("{}_post{}".format(rf_stem, rf_ext))
            yml_dict["model"] = {
                **yml_dict.get("model", {}),
                **{"ensrf_restart_file": True,
                   "ensrf_datei": self.datei,
                   "ensrf_same_restart": same_restart
                   }
            }
            # Create datavect/components if the reference setup lacks them
            datavect_components = \
                yml_dict.setdefault("datavect", {}).setdefault("components", {})
            datavect_components["restart_inicond"] = {
                "parameters": {
                    spec: {
                        "dir": post_restart_dir,
                        "file": post_restart_file,
                    }
                    for spec in acspecies
                }
            }
            # Remove inicond in datavect and transforms calling inicond
            datavect_components.pop("inicond", None)
            if "transform_pipe" in yml_dict["controlvect"]:
                config_trfs = yml_dict["controlvect"]["transform_pipe"]
                paths_trfs_cntrlv = get_path_in_dict(config_trfs, "inicond")
                for p in paths_trfs_cntrlv:
                    if p[0] in config_trfs:
                        del config_trfs[p[0]]
            # Add restart_inicond to dont_propagate
            yml_dict["obsoperator"]["batch_computation"]["dont_propagate"] += \
                [['restart_inicond', spec] for spec in acspecies]

        # Dumps new yml file
        yml_file = "{}/config_base_{:04d}.yml".format(base_dir, ichunk)
        with open(yml_file, "w") as outfile:
            ordered_dump(outfile, yml_dict)

        # Run the base function as an independent process
        job_file = os.path.join(base_dir, "job_pycif_base_{:04d}".format(ichunk))
        info("Submitting base function {} from {}"
             .format(ichunk + 1, len(sample_chunks)))
        job_id = platform.submit_job(
            "{} -m pycif {}".format(platform.python, yml_file),
            job_file
        )
        list_jobs.append(job_id)

    # Check that jobs are over
    while not platform.check_jobs(list_jobs):
        time.sleep(platform.sleep_time)

    # Create the H_matrix directory
    h_dir = "{}/H_matrix/".format(segment_dir)
    path.init_dir(h_dir)
    self.direc2flush.append(h_dir)

    # Post-process monitors and fetch perturbed xb
    for ichunk, sch in enumerate(sample_chunks):
        # Declare the folder from which the data must be fetched
        base_dir = "{}/sampling_{:04d}/".format(segment_dir, ichunk)
        fwd_dir = "{}/obsoperator/fwd_0000/obsvect/".format(base_dir)

        # Map ensemble member ids to their position inside the chunk. When the
        # three system samples are repeated in every chunk, they are only
        # linked from the first chunk.
        skip = 3 if (self.include_system_samples and ichunk > 0) else 0
        range_samples_ens = sch[skip:]
        range_samples_chunk = range(skip, len(sch))

        # Loop over the components
        for component in os.listdir(fwd_dir):
            comp_dir = "{}/{}/".format(fwd_dir, component)
            # Parameter name -> True if at least one per-sample ("__sample#")
            # directory exists. This result does not depend on listdir order.
            dict_param_pert = {}
            for entry in os.listdir(comp_dir):
                param = entry.split("__sample#")[0]
                dict_param_pert[param] = \
                    dict_param_pert.get(param, False) or ("__sample#" in entry)
            # Loop over the parameters
            for param, pert in dict_param_pert.items():
                # Loop over the samples and link the monitors
                for isample_ens, isample_chunk in zip(range_samples_ens, range_samples_chunk):
                    if pert:
                        param_monitor = "{}/{}__sample#{:03d}/monitor.nc".format(
                            comp_dir, param, isample_chunk)
                    else:
                        # Not perturbed: every sample shares the same monitor
                        param_monitor = "{}/{}/monitor.nc".format(
                            comp_dir, param)
                    param_hmatrix = "{}/obsvect_{:04d}/{}/{}".format(
                        h_dir, isample_ens, component, param)
                    path.init_dir(param_hmatrix)
                    path.link(param_monitor, os.path.join(param_hmatrix, "monitor.nc"))
    return