import copy
import shutil
import os
import pandas as pd
import numpy as np
import pickle
from textwrap import dedent
from pandas.tseries.frequencies import to_offset
from logging import info, warn, debug
from .metrics import Metrics_ensrf
from .optimization import serial_optimization, bulk_optimization
from .build_H import build_Hx
from .sampling import generate_ensemble_full_window, update_ensemble_new_segment
from .sampling import run_ensemble
from .localization import fill_loc_matrices, complement_obsvect
from .segment_cropping import select_obs2assimilate, get_cntrlv_idx_segment, get_cntrlv_idx_comp
from .utils import check_tresol
from ....utils.dates import date_range
from ....utils import path
def execute(self, **kwargs):
    """Perform an EnSRF inversion based on Peters et al. (2005).

    The period ``[self.datei, self.datef]`` is split into windows of length
    ``self.window_length`` (default: the whole period). Assimilation segment
    ``i`` runs from window ``i`` to window ``i + self.nlag``. For every
    segment:

    1. the ensemble is run through the observation operator;
    2. the observations not assimilated by a previous segment are assimilated
       (serially or in bulk, depending on ``self.serial_optimization``);
    3. a posterior ensemble simulation is run up to the start of the next
       segment, and its restart file(s) are chained for that segment;
    4. prior/posterior simulated observations are dumped.

    Intermediate results (prior/posterior samples, simulated observations,
    indices of assimilated observations) are pickled under
    ``{workdir}/ensemble/``. When they are found, the corresponding steps are
    skipped, so an interrupted run can be resumed.

    Args:
        self (Plugin): mode Plugin
        **kwargs: unused

    Returns:
        controlvect: optimized controlvect Plugin
        obsvect: optimized obsvect Plugin

    Raises:
        FileNotFoundError: if a posterior restart file needed to chain two
            consecutive segments cannot be found (see ``restart_format``).
    """
    # Working directory
    workdir = self.workdir
    # Control vector
    controlvect = self.controlvect
    # Observation operator
    obsoper = self.obsoperator
    # Observation vector
    obsvect = self.obsvect
    # Number of samples
    nsample = self.nsample
    # List of directories to flush
    self.direc2flush = []

    # Dump initial controlvect
    _dump_controlvect(self, controlvect, workdir)

    # Get additional information about observations
    obsvect = complement_obsvect(obsvect)

    # Define assimilation cycles/segments: segment i starts at window i and
    # ends at window i + nlag
    total_freq = to_offset(pd.Timedelta(self.datef - self.datei))
    window_length = getattr(self, "window_length", total_freq.freqstr)
    list_windows_datei = date_range(self.datei, self.datef, window_length)
    list_segments_datei = list_windows_datei[:-self.nlag]
    list_segments_datef = list_windows_datei[self.nlag:]
    nsegments = len(list_segments_datei)

    # Initialize metrics
    # metrics = Metrics_ensrf(controlvect, obsvect, window_length, list_windows_datei, self.nlag)

    # # Some verbose
    # towrite = dedent("""
    #     =====================================================================================
    #     Running an Ensemble Square Root Filter with the following modules:
    #         Observation operator: {}
    #         Model: {}
    #     Degrees of freedom for one window: {:.0f}
    #     Dimension of the controlvector for one window: {}
    #     Assimilation segments:
    #     """).format(obsoper.plugin.name,
    #                 obsoper.model.plugin.name,
    #                 metrics.data['dof_b_hresol'], metrics.data['hresol'].values) \
    #     + "\n".join([" - {} / {}".format(di, df)
    #                  for di, df in zip(list_segments_datei, list_segments_datef)]) \
    #     + "\n=====================================================================================\n"
    # info(towrite)

    # Define mask to store which observations have already been assimilated
    mask_obsvect = obsvect.obsvect_mask
    mask_obs_assimilated = False * mask_obsvect

    # The variable x will contain the background state for each segment.
    # This state might have already been optimized if nlag > 1
    controlvect.x = copy.deepcopy(controlvect.xb)
    # The variable xa will contain the optimized state
    controlvect.xa = copy.deepcopy(controlvect.xb)
    # The variable pa will contain the diagonal elements
    # of the posterior uncertainty matrix
    controlvect.pa = np.zeros_like(controlvect.xb)

    # Initialize prior and posterior simulated values
    sim_prior = np.zeros_like(obsvect.ysim)
    sim_post = np.zeros_like(obsvect.ysim)

    # Initialize hx_samples: one row per observation, nsample + 3 columns
    # (column 2 is read as the ensemble mean, columns 3: as the members)
    hx_samples = np.full((obsvect.dim, nsample + 3), np.nan)
    hx_samples_final = np.full((obsvect.dim, nsample + 3), np.nan)

    # Check consistency between tresol and window_length
    # check_tresol(controlvect, window_length)

    # Either load a pre-generated ensemble or create and store a new one
    samples_prior_full_file = "{}/ensemble/x_samples_prior.pickle".format(
        workdir)
    if os.path.exists(samples_prior_full_file):
        info("Found a pre-generated ensemble.")
        with open(samples_prior_full_file, "rb") as f:
            x_samples = pickle.load(f)
    else:
        # Generate new x_samples for the full window
        info("Generate and dump a new ensemble.")
        x_samples = generate_ensemble_full_window(
            self, controlvect, nsample,
            set_dev_equal=self.set_deviations_equal)
        # Dump prior x_samples for the full window
        with open(samples_prior_full_file, "wb") as f:
            pickle.dump(x_samples, f, pickle.HIGHEST_PROTOCOL)

    # ---------------------------------------------------------
    # Loop over segments
    # ---------------------------------------------------------
    for iseg, (ddi, ddf) in enumerate(
            zip(list_segments_datei, list_segments_datef)):
        is_last_segment = iseg == nsegments - 1

        # Load Hx or build it from members if not already computed
        ddi_str = ddi.strftime("%Y%m%d-%H:%M")
        ddf_str = ddf.strftime("%Y%m%d-%H:%M")
        info("Computing segment: {} to {}".format(ddi_str, ddf_str))

        # Create the segment, x_samples and hx_samples directories
        segment_dir = "{}/ensemble/{}_to_{}".format(workdir, ddi_str, ddf_str)
        x_samples_dir = "{}/x_samples/".format(segment_dir)
        hx_samples_dir = "{}/hx_samples/".format(segment_dir)
        path.init_dir(segment_dir)
        path.init_dir(x_samples_dir)
        path.init_dir(hx_samples_dir)

        # Find the elements of the controlvect that are in this segment
        list_cntrlv_idx = get_cntrlv_idx_segment(controlvect, ddi, ddf)

        # ---------------------------------------------------------
        # Look for pre-computed segments or run them
        # ---------------------------------------------------------
        # Declare the files that we must check before continuing
        x_samples_prior_file = os.path.join(
            x_samples_dir, "x_samples_prior.pickle")
        x_samples_post_file = os.path.join(
            x_samples_dir, "x_samples_post.pickle")
        hx_samples_prior_file = os.path.join(
            hx_samples_dir, "hx_samples_prior.pickle")
        hx_samples_post_file = os.path.join(
            hx_samples_dir, "hx_samples_post.pickle")
        list_obs2assim_idx_file = os.path.join(
            hx_samples_dir, "list_obs2assim_idx.pickle")

        has_prior_data = (os.path.exists(x_samples_prior_file)
                          and os.path.exists(hx_samples_prior_file)
                          and os.path.exists(list_obs2assim_idx_file))
        has_post_data = (has_prior_data
                         and os.path.exists(x_samples_post_file)
                         and os.path.exists(hx_samples_post_file))

        # Reset optimization requirement
        do_optimization = True

        if has_post_data:
            # Posterior data exist: no need to perform assimilation
            # in this segment
            info("Found posterior data from previous run."
                 " Skip the optimization part...")
            do_optimization = False
            with open(x_samples_prior_file, "rb") as f:
                controlvect.x[list_cntrlv_idx] = pickle.load(f)[2]
            with open(list_obs2assim_idx_file, "rb") as f:
                list_obs2assim_idx = pickle.load(f)
            # TODO: maybe set hx_samples to NaN everytime we start a new segment ?
            with open(hx_samples_prior_file, "rb") as f:
                hx_samples[list_obs2assim_idx] = pickle.load(f)
            hx_mean = copy.deepcopy(hx_samples[:, 2])

        elif has_prior_data:
            # Prior data exist but not posterior:
            # get existing samples and matrix
            info("Found prior data from a previous run."
                 " Proceed to the optimization part...")
            with open(x_samples_prior_file, "rb") as f:
                x_samples[:, list_cntrlv_idx] = pickle.load(f)
            controlvect.x = copy.deepcopy(x_samples[2])
            with open(list_obs2assim_idx_file, "rb") as f:
                list_obs2assim_idx = pickle.load(f)
            mask_obs2assim = False * obsvect.obsvect_mask
            mask_obs2assim[list_obs2assim_idx] = True
            mask_obs_assimilated = mask_obs2assim | mask_obs_assimilated
            with open(hx_samples_prior_file, "rb") as f:
                hx_samples[list_obs2assim_idx] = pickle.load(f)
            # Create deviations
            hx_mean = copy.deepcopy(hx_samples[:, 2])
            hx_samples_dev = copy.deepcopy(
                hx_samples[:, 3:] - hx_mean[:, np.newaxis])
            x_samples_dev = copy.deepcopy(x_samples[3:] - x_samples[2])

        else:
            # Run from scratch
            info("Either pre-generated prior samples or observation"
                 " operator is missing. Run the segment...")
            # Update the samples (forecast) in the new window
            # (no impact for first segment)
            x_samples = update_ensemble_new_segment(
                self, controlvect, x_samples, ddi, ddf, list_cntrlv_idx)
            # Dump corresponding samples (for this segment) before optimization
            with open(x_samples_prior_file, "wb") as f:
                pickle.dump(x_samples[:, list_cntrlv_idx],
                            f, pickle.HIGHEST_PROTOCOL)
            # Compute Hx for each member of the ensemble
            run_ensemble(self, controlvect, nsample,
                         x_samples, segment_dir, ddi, ddf)
            # Build Hx matrix from output simulations of members
            hx_samples_run = build_Hx(obsvect, segment_dir, nsample, ddi, ddf)
            # Select observations within the segment (no NaN in any sample)
            # that have not been assimilated by a previous segment
            mask_obs_segment = ~np.any(np.isnan(hx_samples_run), axis=1)
            mask_obs2assim = mask_obsvect & mask_obs_segment & ~mask_obs_assimilated
            list_obs2assim_idx = np.nonzero(mask_obs2assim)[0]
            # Fill hx_samples with the new assimilated window
            hx_samples[list_obs2assim_idx] = hx_samples_run[list_obs2assim_idx]
            # TODO: check if there are Os or Nans in hx_samples
            # Create deviations
            hx_mean = copy.deepcopy(hx_samples[:, 2])
            hx_samples_dev = copy.deepcopy(
                hx_samples[:, 3:] - hx_mean[:, np.newaxis])
            x_samples_dev = copy.deepcopy(x_samples[3:] - x_samples[2])
            # Update mask to assimilate observations only once
            mask_obs_assimilated = mask_obs2assim | mask_obs_assimilated
            # Dump Hx prior samples and list_obs2assim_idx
            with open(hx_samples_prior_file, "wb") as f:
                pickle.dump(hx_samples[list_obs2assim_idx],
                            f, pickle.HIGHEST_PROTOCOL)
            with open(list_obs2assim_idx_file, "wb") as f:
                pickle.dump(list_obs2assim_idx, f, pickle.HIGHEST_PROTOCOL)

        # ---------------------------------------------------------
        # Optimization
        # ---------------------------------------------------------
        if do_optimization:
            # Localization matrices default to all ones so the optimizer never
            # receives uninitialized memory when localization is disabled
            # (presumably applied element-wise, where ones is a no-op —
            # TODO confirm in serial/bulk_optimization)
            loc_matrix_state = np.ones(
                (len(list_obs2assim_idx), len(list_cntrlv_idx)))
            loc_matrix_obs = np.ones(
                (len(list_obs2assim_idx), len(list_obs2assim_idx)))
            if hasattr(self, "localization"):
                debug("Fill the localization matrix...")
                lons_obs2assim = np.asarray(
                    obsvect.metadata['lon'])[list_obs2assim_idx]
                lats_obs2assim = np.asarray(
                    obsvect.metadata['lat'])[list_obs2assim_idx]
                loc_matrix_state, loc_matrix_obs = fill_loc_matrices(
                    self, controlvect, lons_obs2assim, lats_obs2assim, list_cntrlv_idx
                )

            # Perform the optimization for the current segment
            args_optim = (
                self,
                controlvect,
                nsample,
                list_obs2assim_idx,
                list_cntrlv_idx,
                obsvect.yobs[list_obs2assim_idx],
                obsvect.yobs_err[list_obs2assim_idx],
                x_samples_dev[:, list_cntrlv_idx],
                hx_mean[list_obs2assim_idx],
                hx_samples_dev[list_obs2assim_idx],
                loc_matrix_state,
                loc_matrix_obs,
            )
            if self.serial_optimization:
                info(
                    f"Assimilate {len(list_obs2assim_idx)} observations sequentially...")
                xa, xs_dev, hxm, hxs_dev = serial_optimization(*args_optim)
            else:
                info(
                    f"Assimilate {len(list_obs2assim_idx)} observations in bulk...")
                xa, xs_dev, hxm, hxs_dev = bulk_optimization(*args_optim)

            # Compute the diagonal of the posterior covariance matrix
            pa_diag = _posterior_variance(xs_dev, nsample)
            controlvect.pa[list_cntrlv_idx] = copy.deepcopy(pa_diag)

            # Update the samples (mean and members)
            x_samples[2, list_cntrlv_idx] = copy.deepcopy(xa)
            x_samples[3:, list_cntrlv_idx] = copy.deepcopy(xa + xs_dev)

            # Update the deviations
            x_samples_dev[:, list_cntrlv_idx] = copy.deepcopy(xs_dev)
            hx_samples_dev[list_obs2assim_idx] = copy.deepcopy(hxs_dev)
            hx_samples[list_obs2assim_idx, 2] = copy.deepcopy(hxm)
            hx_samples[list_obs2assim_idx, 3:] = copy.deepcopy(
                hxm[:, np.newaxis] + hxs_dev)

            # Store the optimized state in the controlvect
            controlvect.xa[list_cntrlv_idx] = copy.deepcopy(xa)

            # Dump posterior samples
            with open(x_samples_post_file, "wb") as f:
                pickle.dump(x_samples[:, list_cntrlv_idx],
                            f, pickle.HIGHEST_PROTOCOL)
            # Dump Hx posterior samples
            with open(hx_samples_post_file, "wb") as f:
                pickle.dump(hx_samples[list_obs2assim_idx],
                            f, pickle.HIGHEST_PROTOCOL)

        # ---------------------------------------------------------
        # If optimization already done, load posterior data
        # ---------------------------------------------------------
        else:
            with open(x_samples_post_file, "rb") as f:
                x_samples[:, list_cntrlv_idx] = pickle.load(f)
            controlvect.xa = copy.deepcopy(x_samples[2])
            x_samples_dev = copy.deepcopy(x_samples[3:] - x_samples[2])
            xs_dev = x_samples_dev[:, list_cntrlv_idx]
            pa_diag = _posterior_variance(xs_dev, nsample)
            controlvect.pa[list_cntrlv_idx] = copy.deepcopy(pa_diag)
            with open(list_obs2assim_idx_file, "rb") as f:
                list_obs2assim_idx = pickle.load(f)
            mask_obs2assim = False * obsvect.obsvect_mask
            mask_obs2assim[list_obs2assim_idx] = True
            mask_obs_assimilated = mask_obs2assim | mask_obs_assimilated
            with open(hx_samples_post_file, "rb") as f:
                hx_samples[list_obs2assim_idx] = pickle.load(f)
            hx_samples_dev = copy.deepcopy(
                hx_samples[:, 3:] - hx_samples[:, 2:3])

        # ---------------------------------------------------------
        # Posterior simulation
        # ---------------------------------------------------------
        # Run a posterior simulation with all members until the beginning of
        # the next segment (until the end of the segment for the last one)
        ddf_post = ddf if is_last_segment else list_segments_datei[iseg + 1]
        ddf_post_str = ddf_post.strftime("%Y%m%d-%H:%M")
        dir_post = "{}/posterior_fwd/".format(segment_dir)
        dir_post_obsvect = "{}/H_matrix/obsvect_0000".format(dir_post)
        if not os.path.exists(dir_post_obsvect):
            info("Computing posterior simulation for segment: {} to {}".format(
                ddi_str, ddf_post_str))
            run_ensemble(self, controlvect, nsample,
                         x_samples, dir_post, ddi, ddf_post, same_restart=False)

        # Chain the posterior restart for the next assimilation segment
        if not is_last_segment:
            post_restart_date = ddf_post - pd.Timedelta(obsoper.model.periods)
            _chain_posterior_restart(
                self, workdir, segment_dir, dir_post, post_restart_date, nsample)

        # ---------------------------------------------------------
        # Save the posterior values for the segment
        # ---------------------------------------------------------
        info("Dump prior and posterior obsvect for this segment...")
        # Read posterior simulations
        hx_samples_post = build_Hx(obsvect, dir_post, nsample, ddi, ddf_post)
        mask_obs_post = ~np.any(np.isnan(hx_samples_post), axis=1)

        # Dump prior obsvect for each segment (column 1 of the posterior run)
        hx_prior = hx_samples_post[:, 1].flatten()
        sim_prior += np.nan_to_num(hx_prior)
        obsvect.ysim = copy.deepcopy(hx_prior)
        dir_dump_prior = "{}/obsvect_prior/".format(segment_dir)
        obsvect.dump(dir_dump_prior)

        # Dump posterior obsvect for each segment (column 2 of the posterior run)
        hx_post = hx_samples_post[:, 2].flatten()
        sim_post += np.nan_to_num(hx_post)
        obsvect.ysim = copy.deepcopy(hx_post)
        dir_dump_post = "{}/obsvect_posterior/".format(segment_dir)
        obsvect.dump(dir_dump_post)

        # Fill the simulated values for each sample
        # hx_samples_final[mask_obs_post] = hx_samples_post[mask_obs_post]

        # ---------------------------------------------------------
        # Update the metrics
        # ---------------------------------------------------------
        # Load x before it is optimized in this segment
        # (only consumed by the metrics update below, currently disabled)
        with open(x_samples_prior_file, "rb") as f:
            x_samples_segprior = np.zeros_like(x_samples)
            x_samples_segprior[:, list_cntrlv_idx] = pickle.load(f)

        # # Update the metrics
        # info("Updating the metrics...")
        # metrics.update(
        #     self.level_metrics,
        #     list_segments_datei,
        #     iseg,
        #     ddi,
        #     ddf,
        #     controlvect,
        #     obsvect,
        #     mask_obs2assim,
        #     mask_obs_assimilated,
        #     x_samples,           # samples (for posterior)
        #     x_samples_segprior,  # segment-prior samples
        #     hx_samples,          # sim samples (for posterior)
        #     hx_mean,             # segment-prior sim
        #     sim_post,            # final posterior sim
        # )

        # The new background state is set equal to the optimized state
        controlvect.x = copy.deepcopy(controlvect.xa)

        # Dump transitory controlvect
        _dump_controlvect(self, controlvect, workdir)

    # Flush the unnecessary directories (deletion is currently disabled:
    # the list is only logged)
    if self.flushrun:
        info("Flushing unnecessary directories and files")
        info(self.direc2flush)
        # for d in self.direc2flush:
        #     os.remove(d)

    # ---------------------------------------------------------
    # Post-process for the full assimilation window
    # ---------------------------------------------------------
    info("Dump final prior and posterior obsvect...")

    # Dump prior obsvect for the full window
    # TODO: problem with obs overlapping two windows (leads to 0)
    sim_prior[sim_prior == 0] = np.nan
    obsvect.ysim = copy.deepcopy(sim_prior)
    dir_dump_prior = "{}/obsvect_prior".format(workdir)
    obsvect.dump(dir_dump_prior)

    # Dump posterior obsvect for the full window
    # TODO: problem with obs overlapping two windows (leads to 0)
    sim_post[sim_post == 0] = np.nan
    obsvect.ysim = copy.deepcopy(sim_post)
    dir_dump_post = "{}/obsvect_posterior".format(workdir)
    obsvect.dump(dir_dump_post)

    # Dump posterior x_samples for the full window
    info("Dump final posterior controlvect and samples...")
    samples_post_full_file = "{}/ensemble/x_samples_post.pickle".format(
        workdir)
    with open(samples_post_full_file, "wb") as f:
        pickle.dump(x_samples, f, pickle.HIGHEST_PROTOCOL)

    # Dump controlvect indexes for each component/tracer
    dict_components_idx = get_cntrlv_idx_comp(
        controlvect, self.datei, self.datef)
    comp_idx_file = "{}/ensemble/comp_indexes.pickle".format(workdir)
    with open(comp_idx_file, "wb") as f:
        pickle.dump(dict_components_idx, f, pickle.HIGHEST_PROTOCOL)

    # Dump posterior control vector
    _dump_controlvect(self, controlvect, workdir)

    # # Check final cost function and RMSE
    # metrics.final_update(obsvect, controlvect, sim_prior, sim_post, hx_samples)

    # # Dump the metrics
    # metrics_file = os.path.join(workdir, "metrics.nc")
    # metrics.dump(metrics_file)

    # Return optimized control vector and observation vector
    return controlvect, obsvect


def _dump_controlvect(self, controlvect, workdir):
    """Dump ``controlvect`` to ``{workdir}/controlvect/controlvect.pickle``.

    ``to_netcdf`` is requested when either the control vector or the mode
    plugin ``self`` has ``save_out_netcdf`` set; the NetCDF directory is
    ``{workdir}/controlvect/``.
    """
    dir_cntrlv = "{}/controlvect/".format(workdir)
    cntrlv_file = "{}/controlvect.pickle".format(dir_cntrlv)
    controlvect.dump(
        cntrlv_file,
        to_netcdf=controlvect.save_out_netcdf or self.save_out_netcdf,
        dir_netcdf=dir_cntrlv,
    )


def _posterior_variance(xs_dev, nsample):
    """Return the diagonal of the ensemble posterior covariance matrix.

    Args:
        xs_dev: member deviations from the ensemble mean, one row per member
            and one column per control-vector element.
        nsample: number of members; the unbiased ``nsample - 1``
            normalization is used.
    """
    return (xs_dev.T * xs_dev.T).sum(axis=-1) / (nsample - 1)


def _chain_posterior_restart(self, workdir, segment_dir, dir_post,
                             post_restart_date, nsample):
    """Copy a segment's posterior restart file(s) to the chain directory.

    Source files are looked up under ``dir_post`` using
    ``self.restart_format`` (a strftime pattern evaluated at
    ``post_restart_date``). They are copied to ``{workdir}/ensemble/chain/``
    with ``_post`` inserted before the file extension. The sampling
    directories of the segment are appended to ``self.direc2flush``.

    Raises:
        FileNotFoundError: if a posterior restart file is missing. In batch
            mode this is only raised when no chained copy exists either.
    """
    rf = self.restart_format
    rf_stem, rf_ext = os.path.splitext(rf)
    # Only the file-name pattern goes through strftime: directory names may
    # legitimately contain '%' characters.
    restart_name = post_restart_date.strftime(rf)
    restart_name_chain = post_restart_date.strftime(
        "{}_post{}".format(rf_stem, rf_ext))

    if self.batch_sampling:
        post_restart_file = "{}/batch_sampling/obsoperator/fwd_0000/chain/{}".format(
            dir_post, restart_name)
        post_restart_file_chain = "{}/ensemble/chain/{}".format(
            workdir, restart_name_chain)
        if os.path.exists(post_restart_file):
            path.init_dir(os.path.dirname(post_restart_file_chain))
            shutil.copy(post_restart_file, post_restart_file_chain)
        elif not os.path.exists(post_restart_file_chain):
            raise FileNotFoundError(
                f"Posterior restart file {post_restart_file} not found.\n"
                f"Check `restart_format` parameter."
            )
        # Add directories to the flushing list
        self.direc2flush.append("{}/batch_sampling/".format(segment_dir))
        self.direc2flush.append("{}/batch_sampling/".format(dir_post))
        return

    # One restart per sampling chunk. The per-chunk size is
    # max_nsamples_per_run, reduced by 3 when the system samples are included
    # (presumably to leave room for them in each run — TODO confirm against
    # run_ensemble).
    max_nsamples_per_run = (self.max_nsamples_per_run
                            - self.include_system_samples * 3)
    offset = 3 if self.include_system_samples else 0
    list_sample_chunks = list(
        range(offset, nsample + 3, max_nsamples_per_run)) + [nsample + 3]
    for ichunk in range(len(list_sample_chunks) - 1):
        post_restart_file = "{}/sampling_{:04d}/obsoperator/fwd_0000/chain/{}".format(
            dir_post, ichunk, restart_name)
        post_restart_file_chain = "{}/ensemble/chain/sampling_{:04d}/{}".format(
            workdir, ichunk, restart_name_chain)
        if os.path.exists(post_restart_file):
            path.init_dir(os.path.dirname(post_restart_file_chain))
            shutil.copy(post_restart_file, post_restart_file_chain)
        else:
            raise FileNotFoundError(
                f"Posterior restart file {post_restart_file} not found.\n"
                f"Check `restart_format` parameter."
            )
        # Add sampling directories to the flushing list
        self.direc2flush.append(
            "{}/sampling_{:04d}/".format(segment_dir, ichunk))
        self.direc2flush.append(
            "{}/sampling_{:04d}/".format(dir_post, ichunk))