import os
from logging import debug, info
from typing import Any, Iterable, List, Literal
import numpy as np
import xarray as xr
from ....utils.datastores.crop_monitor import crop_monitor
from ....utils.datastores.dump import read_datastore
from ....utils.iterators import iter_tracers
from ....utils.sparse_array import to_dense_dataset, to_sparse_dataset
from .base_function import BaseFunction
# Aliases for type hints
Mode = Any
ControlVect = Any
ObsVect = Any
[docs]
def get_obsvect_var_name(
    run_mode: Literal["fwd", "tl"]
) -> Literal["sim", "sim_tl"]:
    """Return the observation-vector column name matching a run mode.

    Args:
        run_mode: ``'fwd'`` for a full forward run, ``'tl'`` for the
            tangent-linear operator.

    Returns:
        ``'sim'`` when *run_mode* is ``'fwd'``, ``'sim_tl'`` when it is
        ``'tl'``.

    Raises:
        ValueError: if *run_mode* is neither ``'fwd'`` nor ``'tl'``.
    """
    # Guard clauses: each supported mode returns immediately, so reaching
    # the end of the function means the mode is not supported.
    if run_mode == "fwd":
        return "sim"
    if run_mode == "tl":
        return "sim_tl"
    raise ValueError(f"unexpected run_mode '{run_mode}'")
[docs]
def init_h(controlvect: ControlVect, obsvect: ObsVect) -> np.ndarray:
    """Create an all-zero H matrix sized from the two vectors.

    Args:
        controlvect (ControlVect): control vector plugin; its ``dim``
            attribute gives the number of columns.
        obsvect (ObsVect): observation vector plugin; its ``dim``
            attribute gives the number of rows.

    Returns:
        np.ndarray: zero array of shape ``(obsvect.dim, controlvect.dim)``.
    """
    n_controls = controlvect.dim
    n_observations = obsvect.dim
    # Rows follow the observation space, columns the control space.
    shape = (n_observations, n_controls)
    return np.zeros(shape)
[docs]
def build_h(
    self: Mode,
    base_function_list: Iterable[BaseFunction]
) -> np.ndarray:
    """Assemble the H matrix from the response-function monitor files.

    For every base function and every observed tracer of ``self.datavect``,
    one column segment of the matrix is written: rows ``tracer.ypointer`` to
    ``tracer.ypointer + tracer.dim``, column ``base_func.index``. The values
    come from the ``('maindata', <var_name>)`` column of the tracer's
    ``monitor.nc`` file found under
    ``<base_func.obsdir>/<component>/<tracer>/``, where ``<var_name>`` is
    derived from ``self.run_mode``. Base functions reporting
    ``is_ignored()`` contribute zeros.

    Args:
        self (Mode): the mode plugin; ``run_mode``, ``controlvect``,
            ``obsvect``, ``datavect``, ``full_period`` and
            ``clamp_h_matrix_to_zero`` are read.
        base_function_list (iterable of BaseFunction): the base functions
            whose monitor files are read.

    Returns:
        np.ndarray: the filled 2D H matrix, with negative entries set to
        zero when ``self.clamp_h_matrix_to_zero`` is true.

    Raises:
        AssertionError: if the assembled matrix holds a non-finite value.
    """
    sim_column = get_obsvect_var_name(self.run_mode)
    matrix = init_h(self.controlvect, self.obsvect)

    # One matrix column per response function ...
    for base_func in base_function_list:
        # ... filled tracer by tracer along the observation axis
        for (component_name, tracer_name), tracer in iter_tracers(self.datavect):
            if not tracer.isobs:
                continue

            # The tracer's own datastore fixes how many rows it occupies
            reference_ds = tracer.datastore
            n_obs, _ = reference_ds.shape
            column = np.zeros(n_obs)

            if not base_func.is_ignored():
                # Load what this response function simulated for the tracer
                monitor_path = os.path.join(
                    base_func.obsdir, component_name, tracer_name,
                    "monitor.nc")
                simulated_ds = read_datastore(monitor_path)
                simulated = simulated_ds.loc[
                    :, ('maindata', sim_column)].values

                # Either every row is covered, or only those that
                # crop_monitor selects for the base function's date window
                if self.full_period:
                    selector = slice(None)
                else:
                    selector = crop_monitor(
                        reference_ds, base_func.date_start,
                        base_func.date_end, return_index=True)
                column[selector] = simulated

            rows = slice(tracer.ypointer, tracer.ypointer + tracer.dim)
            matrix[rows, base_func.index] = column

    assert np.all(np.isfinite(matrix))

    if self.clamp_h_matrix_to_zero:
        lowest = np.min(matrix)
        debug(f"clamping H matrix to zero, min value was {lowest:.2e}")
        negative = matrix < 0.0
        matrix[negative] = 0.0

    return matrix
[docs]
def read_h_matrix(self: Mode, path_list: List[str]) -> np.ndarray:
    """Reload H matrices from NetCDF files and add them together.

    Each file is opened with xarray and its ``H_matrix`` variable is added
    to a zero-initialised matrix. Files holding a ``sparse_H_matrix``
    variable are first converted with ``to_dense_dataset``.

    Args:
        self (Mode): the mode plugin; ``controlvect``, ``obsvect``,
            ``workdir`` and ``clamp_h_matrix_to_zero`` are read.
        path_list (list of str): files to load. Relative paths are joined
            to ``self.workdir``; absolute paths are used unchanged.

    Returns:
        np.ndarray: the sum of all loaded H matrices, with negative entries
        set to zero when ``self.clamp_h_matrix_to_zero`` is true. With an
        empty *path_list* this is the all-zero matrix.
    """
    def _resolve(path: str) -> str:
        # Relative paths are interpreted from the working directory
        if os.path.isabs(path):
            return path
        return os.path.join(self.workdir, path)

    total = init_h(self.controlvect, self.obsvect)
    resolved_paths = [_resolve(path) for path in path_list]

    for path in resolved_paths:
        info(f"Reloading H matrix from '{path}'")
        with xr.open_dataset(path) as opened:
            dense = opened
            if 'sparse_H_matrix' in opened:
                # Only the H matrix and its sparse coordinate are made
                # dense, to avoid computing dense B and R matrices
                coord_name = opened['sparse_H_matrix'].attrs['sparse_coords']
                dense = to_dense_dataset(
                    opened[['sparse_H_matrix', coord_name]])
            # Values are read while the file is still open
            total += dense['H_matrix'].values

    if self.clamp_h_matrix_to_zero:
        negative = total < 0.0
        total[negative] = 0.0

    return total
[docs]
def fill_obsvect(self: Mode, h_matrix: np.ndarray, xb: np.ndarray) -> None:
    """Fill the observation vector in-place from the H matrix.

    The attribute that is written depends on ``self.run_mode``:

    * forward run (``'sim'``): ``self.obsvect.ysim`` receives
      ``h_matrix.dot(xb)``;
    * tangent-linear run (``'sim_tl'``): ``self.obsvect.dy`` receives the
      sum of ``h_matrix`` along its second axis; *xb* is not used.

    Args:
        self (Mode): the mode plugin
        h_matrix (2D array): the filled H matrix
        xb (1D array): control vector prior

    Raises:
        ValueError: if ``self.run_mode`` is not a supported run mode.
        KeyError: if the variable name derived from the run mode is neither
            ``'sim'`` nor ``'sim_tl'``.
    """
    target = get_obsvect_var_name(self.run_mode)

    if target == 'sim':
        self.obsvect.ysim = h_matrix.dot(xb)
        return

    if target == 'sim_tl':
        self.obsvect.dy = np.sum(h_matrix, axis=1)
        return

    # Defensive guard in case the helper ever returns another name
    raise KeyError(target)
[docs]
def dump_obsvect_decomp(self: Mode, h_matrix: np.ndarray, decompdir: str) -> None:
    """Dump, per control tracer, its slice of the H matrix to NetCDF.

    For every control tracer of ``self.datavect``, the columns
    ``tracer.xpointer`` to ``tracer.xpointer + tracer.dim`` of *h_matrix*
    are reshaped to ``(obs, date, vresol, hresol)`` and written to
    ``<decompdir>/<component>/<tracer>/obsvect_decomp.nc``. Directories are
    created as needed.

    Args:
        self (Mode): the mode plugin; ``obsvect``, ``run_mode``,
            ``datavect`` and ``dump_sparse_arrays`` are read.
        h_matrix (2D array): the filled H matrix
        decompdir (str): root directory receiving the per-tracer files
    """
    n_obs = self.obsvect.dim
    var_name = get_obsvect_var_name(self.run_mode)

    for (component_name, tracer_name), tracer in iter_tracers(self.datavect):
        if not tracer.iscontrol:
            continue

        # Columns owned by this tracer, unflattened to its native dimensions
        first_col = tracer.xpointer
        last_col = tracer.xpointer + tracer.dim
        unflattened_shape = (
            n_obs, tracer.ndates, tracer.vresoldim, tracer.hresoldim)
        tracer_block = h_matrix[:, first_col:last_col].reshape(
            unflattened_shape)  # type: ignore

        data_vars = {
            var_name: (['index', 'date', 'vresol', 'hresol'], tracer_block),
        }

        index_coord = (['index'], np.arange(n_obs), {
            'standard_name': "index",
            'long_name': "observation vector dimension"
        })
        date_coord = (['date'], tracer.dates.astype('datetime64[ns]'), {
            'standard_name': "date",
            'long_name': "control vector temporal dimension"
        })
        vresol_coord = (['vresol'], np.arange(tracer.vresoldim), {
            'standard_name': "vresol",
            'long_name': "control vector vertical dimension",
            'kind': tracer.vresol
        })
        hresol_coord = (['hresol'], np.arange(tracer.hresoldim), {
            'standard_name': "hresol",
            'long_name': "control vector horizontal dimension",
            'kind': tracer.hresol
        })
        coords = {
            'index': index_coord,
            'date': date_coord,
            'vresol': vresol_coord,
            'hresol': hresol_coord,
        }

        global_attrs = {
            'title': "Response functions results",
            'component': component_name,
            'parameter': tracer_name
        }

        ds = xr.Dataset(data_vars, coords=coords, attrs=global_attrs)

        # Region-based tracers also carry the ids used in the mask file
        if tracer.hresol == "regions":
            region_infos = tracer.regions_infos
            region_file = os.path.join(region_infos.dir, region_infos.file)
            region_attrs = {
                'standard_name': "region_id",
                'long_name': "region id within the region mask file",
                'region_file': region_file
            }
            ds['region_id'] = (['hresol'], tracer.region_ids, region_attrs)

        if self.dump_sparse_arrays:
            ds = to_sparse_dataset(ds, variable_names=[var_name])

        # Write the dataset under <decompdir>/<component>/<tracer>/
        subdir = os.path.join(decompdir, component_name, tracer_name)
        os.makedirs(subdir, exist_ok=True)
        dump_path = os.path.join(subdir, "obsvect_decomp.nc")
        info(f"Dumping {component_name, tracer_name} response function "
             f"contribution to '{dump_path}'")
        ds.to_netcdf(dump_path)