"""Execution routine of the EnSRF mode (pycif.plugins.modes.ensrf.execute)."""

import copy
import shutil
import os
import pandas as pd
import numpy as np
import pickle
from textwrap import dedent
from pandas.tseries.frequencies import to_offset
from logging import info, warn, debug
from .metrics import Metrics_ensrf
from .optimization import serial_optimization, bulk_optimization
from .build_H import build_Hx
from .sampling import generate_ensemble_full_window, update_ensemble_new_segment
from .sampling import run_ensemble
from .localization import fill_loc_matrices, complement_obsvect
from .segment_cropping import select_obs2assimilate, get_cntrlv_idx_segment, get_cntrlv_idx_comp
from .utils import check_tresol
from ....utils.dates import date_range
from ....utils import path


def _load_pickle(filename):
    """Return the object stored in the pickle file ``filename``.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only use it on
    trusted files such as those written by ``_dump_pickle``.
    """
    with open(filename, "rb") as f:
        return pickle.load(f)


def _dump_pickle(obj, filename):
    """Write ``obj`` to ``filename`` using the highest pickle protocol."""
    with open(filename, "wb") as f:
        pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)


def _dump_controlvect(mode, controlvect, workdir):
    """Dump ``controlvect`` to ``{workdir}/controlvect/`` (pickle, plus netCDF if requested)."""
    dir_cntrlv = "{}/controlvect/".format(workdir)
    cntrlv_file = "{}/controlvect.pickle".format(dir_cntrlv)
    controlvect.dump(
        cntrlv_file,
        to_netcdf=controlvect.save_out_netcdf or mode.save_out_netcdf,
        dir_netcdf=dir_cntrlv,
    )


def _posterior_variance(xs_dev, nsample):
    """Return the ensemble variance per state element.

    ``xs_dev`` holds member deviations from the mean, one member per row.
    """
    return (xs_dev.T * xs_dev.T).sum(axis=-1) / (nsample - 1)


def _chain_posterior_restart(restart_file, chain_file):
    """Copy a posterior restart file to the location used to chain segments.

    If ``restart_file`` is missing but ``chain_file`` already exists (e.g. it
    was chained by a previous, interrupted run), the existing copy is kept.

    Raises:
        FileNotFoundError: if neither ``restart_file`` nor ``chain_file`` exists.
    """
    if os.path.exists(restart_file):
        chain_dir = os.path.dirname(chain_file)
        if chain_dir:
            os.makedirs(chain_dir, exist_ok=True)
        shutil.copy(restart_file, chain_file)
    elif not os.path.exists(chain_file):
        raise FileNotFoundError(
            f"Posterior restart file {restart_file} not found.\n"
            f"Check `restart_format` parameter."
        )


def execute(self, **kwargs):
    """Perform an EnSRF inversion based on Peters et al. (2005).

    The period ``[self.datei, self.datef]`` is cut into windows of
    ``window_length``; assimilation segment ``i`` runs from the start of
    window ``i`` to the start of window ``i + self.nlag``. For each segment
    the ensemble is propagated, observations not yet assimilated are
    assimilated (serially or in bulk), and a posterior forward simulation is
    run whose restart file is chained to the next segment.

    Intermediate samples are pickled under ``{workdir}/ensemble/`` so that an
    interrupted run can resume: a segment with existing prior files skips the
    ensemble run, and one with existing posterior files also skips the
    optimization.

    Args:
        self (Plugin): mode Plugin
        **kwargs: unused here.

    Returns:
        controlvect: optimized controlvect Plugin (``xa`` and ``pa`` filled)
        obsvect: optimized obsvect Plugin
    """
    workdir = self.workdir
    controlvect = self.controlvect
    obsoper = self.obsoperator
    obsvect = self.obsvect
    nsample = self.nsample

    # List of directories to flush
    self.direc2flush = []

    # Dump initial controlvect
    _dump_controlvect(self, controlvect, workdir)

    # Get additional information about observations
    obsvect = complement_obsvect(obsvect)

    # Define assimilation cycles/segments
    total_freq = to_offset(pd.Timedelta(self.datef - self.datei))
    window_length = getattr(self, "window_length", total_freq.freqstr)
    list_windows_datei = date_range(self.datei, self.datef, window_length)
    list_segments_datei = list_windows_datei[:-self.nlag]
    list_segments_datef = list_windows_datei[self.nlag:]
    nsegments = len(list_segments_datei)

    # Metrics and the verbose summary are currently disabled. The summary
    # depends on ``metrics``, so ``info(towrite)`` must stay disabled with it.
    # metrics = Metrics_ensrf(controlvect, obsvect, window_length,
    #                         list_windows_datei, self.nlag)
    # towrite = dedent("""
    # =====================================================================================
    # Running an Ensemble Square Root Filter with the following modules:
    #     Observation operator: {}
    #     Model: {}
    #     Degrees of freedom for one window: {:.0f}
    #     Dimension of the controlvector for one window: {}
    #     Assimilation segments:
    # """).format(obsoper.plugin.name,
    #             obsoper.model.plugin.name,
    #             metrics.data['dof_b_hresol'], metrics.data['hresol'].values) \
    #     + "\n".join(["    - {} / {}".format(di, df)
    #                  for di, df in zip(list_segments_datei, list_segments_datef)]) \
    #     + "\n=====================================================================================\n"
    # info(towrite)

    # Track which observations have already been assimilated
    mask_obsvect = obsvect.obsvect_mask
    mask_obs_assimilated = False * mask_obsvect

    # x: background state for each segment (may already be optimized if nlag > 1)
    controlvect.x = copy.deepcopy(controlvect.xb)
    # xa: optimized state
    controlvect.xa = copy.deepcopy(controlvect.xb)
    # pa: diagonal of the posterior uncertainty matrix
    controlvect.pa = np.zeros_like(controlvect.xb)

    # Prior and posterior simulated values, accumulated over segments
    sim_prior = np.zeros_like(obsvect.ysim)
    sim_post = np.zeros_like(obsvect.ysim)

    # Simulated observations: the first 3 columns are system samples (column
    # 2 is the ensemble mean), columns 3: are the ``nsample`` members.
    hx_samples = np.full((obsvect.dim, nsample + 3), np.nan)

    # Either load a pre-generated ensemble or create and store a new one
    samples_prior_full_file = "{}/ensemble/x_samples_prior.pickle".format(
        workdir)
    if os.path.exists(samples_prior_full_file):
        info("Found a pre-generated ensemble.")
        x_samples = _load_pickle(samples_prior_full_file)
    else:
        info("Generate and dump a new ensemble.")
        x_samples = generate_ensemble_full_window(
            self, controlvect, nsample,
            set_dev_equal=self.set_deviations_equal)
        # The ensemble directory may not exist yet on a fresh run
        os.makedirs(os.path.dirname(samples_prior_full_file), exist_ok=True)
        _dump_pickle(x_samples, samples_prior_full_file)

    # ---------------------------------------------------------
    # Loop over segments
    # ---------------------------------------------------------
    for iseg, (ddi, ddf) in enumerate(
            zip(list_segments_datei, list_segments_datef)):
        is_last_segment = iseg == nsegments - 1
        ddi_str = ddi.strftime("%Y%m%d-%H:%M")
        ddf_str = ddf.strftime("%Y%m%d-%H:%M")
        info("Computing segment: {} to {}".format(ddi_str, ddf_str))

        # Create the segment, x_samples and hx_samples directories
        segment_dir = "{}/ensemble/{}_to_{}".format(workdir, ddi_str, ddf_str)
        x_samples_dir = "{}/x_samples/".format(segment_dir)
        hx_samples_dir = "{}/hx_samples/".format(segment_dir)
        path.init_dir(segment_dir)
        path.init_dir(x_samples_dir)
        path.init_dir(hx_samples_dir)

        # Elements of the controlvect that belong to this segment
        list_cntrlv_idx = get_cntrlv_idx_segment(controlvect, ddi, ddf)

        # ---------------------------------------------------------
        # Look for pre-computed segments or run them
        # ---------------------------------------------------------
        x_samples_prior_file = os.path.join(
            x_samples_dir, "x_samples_prior.pickle")
        x_samples_post_file = os.path.join(
            x_samples_dir, "x_samples_post.pickle")
        hx_samples_prior_file = os.path.join(
            hx_samples_dir, "hx_samples_prior.pickle")
        hx_samples_post_file = os.path.join(
            hx_samples_dir, "hx_samples_post.pickle")
        list_obs2assim_idx_file = os.path.join(
            hx_samples_dir, "list_obs2assim_idx.pickle")

        has_prior = all(map(os.path.exists, (
            x_samples_prior_file, hx_samples_prior_file,
            list_obs2assim_idx_file)))
        has_post = has_prior and all(map(os.path.exists, (
            x_samples_post_file, hx_samples_post_file)))
        do_optimization = not has_post

        if has_post:
            # Posterior data exist: no assimilation needed in this segment
            info("Found posterior data from previous run."
                 " Skip the optimization part...")
            controlvect.x[list_cntrlv_idx] = _load_pickle(
                x_samples_prior_file)[2]
            list_obs2assim_idx = _load_pickle(list_obs2assim_idx_file)
            # TODO: maybe set hx_samples to NaN everytime we start a new segment ?
            hx_samples[list_obs2assim_idx] = _load_pickle(hx_samples_prior_file)
            hx_mean = hx_samples[:, 2].copy()

        elif has_prior:
            # Prior data exist but not posterior: reuse samples and Hx
            info("Found prior data from a previous run."
                 " Proceed to the optimization part...")
            x_samples[:, list_cntrlv_idx] = _load_pickle(x_samples_prior_file)
            controlvect.x = copy.deepcopy(x_samples[2])
            list_obs2assim_idx = _load_pickle(list_obs2assim_idx_file)
            mask_obs2assim = False * obsvect.obsvect_mask
            mask_obs2assim[list_obs2assim_idx] = True
            mask_obs_assimilated = mask_obs2assim | mask_obs_assimilated
            hx_samples[list_obs2assim_idx] = _load_pickle(hx_samples_prior_file)

            # Create deviations
            hx_mean = hx_samples[:, 2].copy()
            hx_samples_dev = hx_samples[:, 3:] - hx_mean[:, np.newaxis]
            x_samples_dev = x_samples[3:] - x_samples[2]

        else:
            info("Either pre-generated prior samples or observation"
                 " operator is missing. Run the segment...")
            # Update the samples (forecast) in the new window
            # (no impact for first segment)
            x_samples = update_ensemble_new_segment(
                self, controlvect, x_samples, ddi, ddf, list_cntrlv_idx)
            # Dump corresponding samples (for this segment) before optimization
            _dump_pickle(x_samples[:, list_cntrlv_idx], x_samples_prior_file)

            # Compute Hx for each member of the ensemble
            run_ensemble(self, controlvect, nsample, x_samples, segment_dir,
                         ddi, ddf)
            hx_samples_run = build_Hx(obsvect, segment_dir, nsample, ddi, ddf)

            # Observations simulated in this segment and not yet assimilated
            mask_obs_segment = ~np.any(np.isnan(hx_samples_run), axis=1)
            mask_obs2assim = (mask_obsvect & mask_obs_segment
                              & ~mask_obs_assimilated)
            list_obs2assim_idx = np.nonzero(mask_obs2assim)[0]

            # Fill hx_samples with the new assimilated window
            hx_samples[list_obs2assim_idx] = hx_samples_run[list_obs2assim_idx]
            # TODO: check if there are Os or Nans in hx_samples

            # Create deviations
            hx_mean = hx_samples[:, 2].copy()
            hx_samples_dev = hx_samples[:, 3:] - hx_mean[:, np.newaxis]
            x_samples_dev = x_samples[3:] - x_samples[2]

            # Assimilate each observation only once
            mask_obs_assimilated = mask_obs2assim | mask_obs_assimilated

            _dump_pickle(hx_samples[list_obs2assim_idx], hx_samples_prior_file)
            _dump_pickle(list_obs2assim_idx, list_obs2assim_idx_file)

        # ---------------------------------------------------------
        # Optimization
        # ---------------------------------------------------------
        if do_optimization:
            nobs = len(list_obs2assim_idx)
            nstate = len(list_cntrlv_idx)
            if hasattr(self, "localization"):
                debug("Fill the localization matrix...")
                lons_obs2assim = np.asarray(
                    obsvect.metadata['lon'])[list_obs2assim_idx]
                lats_obs2assim = np.asarray(
                    obsvect.metadata['lat'])[list_obs2assim_idx]
                loc_matrix_state, loc_matrix_obs = fill_loc_matrices(
                    self, controlvect, lons_obs2assim, lats_obs2assim,
                    list_cntrlv_idx)
            else:
                # No localization: pass neutral all-ones weights instead of
                # the uninitialized memory np.empty would hand the optimizer.
                # NOTE(review): assumes the optimizers apply these as
                # element-wise weights — confirm in .optimization.
                loc_matrix_state = np.ones((nobs, nstate))
                loc_matrix_obs = np.ones((nobs, nobs))

            args_optim = (
                self,
                controlvect,
                nsample,
                list_obs2assim_idx,
                list_cntrlv_idx,
                obsvect.yobs[list_obs2assim_idx],
                obsvect.yobs_err[list_obs2assim_idx],
                x_samples_dev[:, list_cntrlv_idx],
                hx_mean[list_obs2assim_idx],
                hx_samples_dev[list_obs2assim_idx],
                loc_matrix_state,
                loc_matrix_obs,
            )
            if self.serial_optimization:
                info(f"Assimilate {nobs} observations sequentially...")
                xa, xs_dev, hxm, hxs_dev = serial_optimization(*args_optim)
            else:
                info(f"Assimilate {nobs} observations in bulk...")
                xa, xs_dev, hxm, hxs_dev = bulk_optimization(*args_optim)

            # Diagonal of the posterior covariance matrix
            controlvect.pa[list_cntrlv_idx] = _posterior_variance(
                xs_dev, nsample)

            # Update the samples (mean and members) and the deviations;
            # slice assignment copies the values into the target arrays.
            x_samples[2, list_cntrlv_idx] = xa
            x_samples[3:, list_cntrlv_idx] = xa + xs_dev
            x_samples_dev[:, list_cntrlv_idx] = xs_dev
            hx_samples_dev[list_obs2assim_idx] = hxs_dev
            hx_samples[list_obs2assim_idx, 2] = hxm
            hx_samples[list_obs2assim_idx, 3:] = hxm[:, np.newaxis] + hxs_dev

            # Store the optimized state in the controlvect
            controlvect.xa[list_cntrlv_idx] = xa

            _dump_pickle(x_samples[:, list_cntrlv_idx], x_samples_post_file)
            _dump_pickle(hx_samples[list_obs2assim_idx], hx_samples_post_file)

        else:
            # Optimization already done: load posterior data
            x_samples[:, list_cntrlv_idx] = _load_pickle(x_samples_post_file)
            controlvect.xa = copy.deepcopy(x_samples[2])
            x_samples_dev = x_samples[3:] - x_samples[2]
            xs_dev = x_samples_dev[:, list_cntrlv_idx]
            controlvect.pa[list_cntrlv_idx] = _posterior_variance(
                xs_dev, nsample)

            list_obs2assim_idx = _load_pickle(list_obs2assim_idx_file)
            mask_obs2assim = False * obsvect.obsvect_mask
            mask_obs2assim[list_obs2assim_idx] = True
            mask_obs_assimilated = mask_obs2assim | mask_obs_assimilated

            hx_samples[list_obs2assim_idx] = _load_pickle(hx_samples_post_file)
            hx_samples_dev = hx_samples[:, 3:] - hx_samples[:, 2:3]

        # ---------------------------------------------------------
        # Posterior simulation
        # ---------------------------------------------------------
        # Run all members until the beginning of the next segment
        ddf_post = ddf if is_last_segment else list_segments_datei[iseg + 1]
        ddf_post_str = ddf_post.strftime("%Y%m%d-%H:%M")
        dir_post = "{}/posterior_fwd/".format(segment_dir)
        dir_post_obsvect = "{}/H_matrix/obsvect_0000".format(dir_post)
        if not os.path.exists(dir_post_obsvect):
            info("Computing posterior simulation for segment: {} to {}".format(
                ddi_str, ddf_post_str))
            run_ensemble(self, controlvect, nsample, x_samples, dir_post,
                         ddi, ddf_post, same_restart=False)

        # Chain the posterior restart for the next assimilation segment
        if not is_last_segment:
            post_restart_date = ddf_post - pd.Timedelta(obsoper.model.periods)
            rf = self.restart_format
            rf_stem, rf_ext = os.path.splitext(rf)
            if self.batch_sampling:
                _chain_posterior_restart(
                    post_restart_date.strftime(
                        "{}/batch_sampling/obsoperator/fwd_0000/chain/{}".format(
                            dir_post, rf)),
                    post_restart_date.strftime(
                        "{}/ensemble/chain/{}_post{}".format(
                            workdir, rf_stem, rf_ext)),
                )
                # Add directories to the flushing list
                self.direc2flush.append(
                    "{}/batch_sampling/".format(segment_dir))
                self.direc2flush.append("{}/batch_sampling/".format(dir_post))
            else:
                max_nsamples_per_run = (self.max_nsamples_per_run
                                        - self.include_system_samples * 3)
                offset = 3 if self.include_system_samples else 0
                list_sample_chunks = list(
                    range(offset, nsample + 3, max_nsamples_per_run)
                ) + [nsample + 3]
                for ichunk in range(len(list_sample_chunks) - 1):
                    _chain_posterior_restart(
                        post_restart_date.strftime(
                            "{}/sampling_{:04d}/obsoperator/fwd_0000/chain/{}".format(
                                dir_post, ichunk, rf)),
                        post_restart_date.strftime(
                            "{}/ensemble/chain/sampling_{:04d}/{}_post{}".format(
                                workdir, ichunk, rf_stem, rf_ext)),
                    )
                    # Add sampling directories to the flushing list
                    self.direc2flush.append(
                        "{}/sampling_{:04d}/".format(segment_dir, ichunk))
                    self.direc2flush.append(
                        "{}/sampling_{:04d}/".format(dir_post, ichunk))

        # ---------------------------------------------------------
        # Save the posterior values for the segment
        # ---------------------------------------------------------
        info("Dump prior and posterior obsvect for this segment...")
        hx_samples_post = build_Hx(obsvect, dir_post, nsample, ddi, ddf_post)

        # Prior obsvect for this segment (flatten() already returns a copy)
        hx_prior = hx_samples_post[:, 1].flatten()
        sim_prior += np.nan_to_num(hx_prior)
        obsvect.ysim = hx_prior
        obsvect.dump("{}/obsvect_prior/".format(segment_dir))

        # Posterior obsvect for this segment
        hx_post = hx_samples_post[:, 2].flatten()
        sim_post += np.nan_to_num(hx_post)
        obsvect.ysim = hx_post
        obsvect.dump("{}/obsvect_posterior/".format(segment_dir))

        # ---------------------------------------------------------
        # Metrics (disabled together with ``metrics`` above)
        # ---------------------------------------------------------
        # mask_obs_post = ~np.any(np.isnan(hx_samples_post), axis=1)
        # x_samples_segprior = np.zeros_like(x_samples)
        # x_samples_segprior[:, list_cntrlv_idx] = _load_pickle(
        #     x_samples_prior_file)
        # info("Updating the metrics...")
        # metrics.update(
        #     self.level_metrics, list_segments_datei, iseg, ddi, ddf,
        #     controlvect, obsvect, mask_obs2assim, mask_obs_assimilated,
        #     x_samples,           # samples (for posterior)
        #     x_samples_segprior,  # segment-prior samples
        #     hx_samples,          # sim samples (for posterior)
        #     hx_mean,             # segment-prior sim
        #     sim_post,            # final posterior sim
        # )

        # The new background state is set equal to the optimized state
        controlvect.x = copy.deepcopy(controlvect.xa)

        # Dump transitory controlvect
        _dump_controlvect(self, controlvect, workdir)

        # Flush the unnecessary directories (only logged for now)
        if self.flushrun:
            info("Flushing unnecessary directories and files")
            info(self.direc2flush)
            # for d in self.direc2flush:
            #     os.remove(d)

    # ---------------------------------------------------------
    # Post-process for the full assimilation window
    # ---------------------------------------------------------
    info("Dump final prior and posterior obsvect...")

    # TODO: problem with obs overlapping two windows (leads to 0)
    sim_prior[sim_prior == 0] = np.nan
    obsvect.ysim = copy.deepcopy(sim_prior)
    obsvect.dump("{}/obsvect_prior".format(workdir))

    # TODO: problem with obs overlapping two windows (leads to 0)
    sim_post[sim_post == 0] = np.nan
    obsvect.ysim = copy.deepcopy(sim_post)
    obsvect.dump("{}/obsvect_posterior".format(workdir))

    info("Dump final posterior controlvect and samples...")
    _dump_pickle(x_samples,
                 "{}/ensemble/x_samples_post.pickle".format(workdir))

    # Dump controlvect indexes for each component/tracer
    dict_components_idx = get_cntrlv_idx_comp(
        controlvect, self.datei, self.datef)
    _dump_pickle(dict_components_idx,
                 "{}/ensemble/comp_indexes.pickle".format(workdir))

    # Dump posterior control vector
    _dump_controlvect(self, controlvect, workdir)

    # metrics.final_update(obsvect, controlvect, sim_prior, sim_post, hx_samples)
    # metrics.dump(os.path.join(workdir, "metrics.nc"))

    return controlvect, obsvect