from typing import Any, Generator, Tuple
import numpy as np
import pandas as pd
from ....utils.iterators import iter_tracers
# Aliases for type hints
# ``Mode`` annotates the mode plugin object passed as ``self`` to the
# functions of this module (see ``get_controlvect_periods``). It is a plain
# alias of ``Any``, so it documents intent only and is not checked.
Mode = Any
[docs]
def apply_spin_down(series: pd.Series, spin_down: str) -> pd.DatetimeIndex:
"""Applies spin down in place to the 'date_end' column of DataFrame df
Args:
series (pd.Series): Series of Timestamps, use a view of a DataFrame
column, ex: `apply_spin_down(df.loc[a:b, colname], spin_down)`
spin_down (str): valid pandas period alias (1D, 1M, ...)
"""
if spin_down[-1] == 'S':
raise ValueError(
f"The spin_down '{spin_down}' is not accepted by 'to_period'. "
"The 'S' anchor at the end is not accepted for pandas periods. "
"Please remove the 'S' anchor from your YAML configuration."
)
date_end = series.dt.to_period(spin_down)
date_end = pd.PeriodIndex(date_end + 1).to_timestamp()
return date_end
[docs]
def _tracer_element_dates(tracer: Any) -> Tuple[np.ndarray, np.ndarray]:
    """Return the start and end dates of each control vector element of
    ``tracer``, in the element order produced by the meshgrid below."""
    # NOTE(review): meshgrid's 'xy' indexing (its default, kept here) yields
    # arrays of shape (vresoldim, ndates, hresoldim). Flattened, the element
    # order is vertical level, then date, then horizontal. This only matters
    # when vresoldim > 1; confirm it matches the control vector layout.
    date_indices, _, _ = np.meshgrid(
        np.arange(tracer.ndates),
        np.arange(tracer.vresoldim),
        np.arange(tracer.hresoldim),
        indexing='xy',
    )
    date_indices = date_indices.ravel()

    dates = tracer.dates.astype('datetime64[ns]')
    if tracer.ndates > 1:
        # Padding so that the elements of the last date get a NaT end date
        dates = np.concatenate([dates, [np.datetime64('NaT')]])
    else:
        # Single date: elements get date_end == date_start
        dates = np.concatenate([dates, dates])
    return dates[date_indices], dates[date_indices + 1]


def get_controlvect_periods(self: Mode) -> pd.DataFrame:
    """Compute the time period covered by each element of the control vector.

    For every tracer of ``self.datavect`` with ``iscontrol`` set, the rows
    ``tracer.xpointer`` to ``tracer.xpointer + tracer.dim - 1`` of the output
    are filled as follows:

    * ``component`` and ``parameter`` get the tracer's component and
      parameter names.
    * If ``self.full_period`` is set, the period is the whole
      ``[self.datei, self.datef]`` window.
    * Otherwise, the period is the one of the tracer date each element
      belongs to. Its end date is then extended by the optional spin down
      (``apply_spin_down``). The spin down is taken from the tracer when it
      defines one, else from ``self``.

    Finally, end dates later than ``self.datef`` are cropped to it.

    Args:
        self (Mode): the mode plugin

    Returns:
        pd.DataFrame: one row per control vector element (index 0 to
        ``self.controlvect.dim - 1``) with columns ``component``,
        ``parameter``, ``date_start`` and ``date_end``. Outside full-period
        mode, ``date_end`` is NaT for the elements of the last date of a
        tracer with several dates.

    Raises:
        ValueError: if some control vector elements are not covered by any
            control tracer.
    """
    xdim = self.controlvect.dim
    df = pd.DataFrame(
        data={
            'component': xdim * ["_None_"],  # placeholders
            'parameter': xdim * ["_None_"],  # placeholders
            'date_start': np.zeros(xdim, dtype='datetime64[ns]'),
            'date_end': np.zeros(xdim, dtype='datetime64[ns]'),
        },
        index=range(xdim),
    )

    for (component_name, tracer_name), tracer in iter_tracers(self.datavect):
        if not tracer.iscontrol:
            continue

        # "- 1" on 'stop' because ".loc" slices include both the start and
        # the stop labels
        start = tracer.xpointer
        stop = tracer.xpointer + tracer.dim - 1
        df.loc[start:stop, 'component'] = component_name
        df.loc[start:stop, 'parameter'] = tracer_name

        if self.full_period:
            # All response functions are run over the full inversion period
            df.loc[start:stop, 'date_start'] = pd.to_datetime(self.datei)
            df.loc[start:stop, 'date_end'] = pd.to_datetime(self.datef)
        else:
            # Response functions are run over the period their control
            # vector elements cover, with an optional spin down
            date_start, date_end = _tracer_element_dates(tracer)
            df.loc[start:stop, 'date_start'] = date_start
            df.loc[start:stop, 'date_end'] = date_end

            # Spin down: priority is given to values in the datavect
            if 'spin_down' in tracer.attributes:
                spin_down = tracer.spin_down
            else:
                spin_down = getattr(self, 'spin_down', None)

            if spin_down is not None:
                df.loc[start:stop, 'date_end'] = apply_spin_down(  # type: ignore
                    df.loc[start:stop, 'date_end'], spin_down)

    # Cropping period end times to the simulation window (NaT is kept)
    bound = pd.to_datetime(self.datef)
    df.loc[df.date_end > bound, 'date_end'] = bound

    # Every control vector element must belong to some control tracer.
    # Explicit check rather than assert, which is stripped under -O.
    missing = df.index[
        (df.component == "_None_") | (df.parameter == "_None_")
    ]
    if len(missing) > 0:
        raise ValueError(
            f"{len(missing)} control vector element(s) are not covered by "
            f"any control tracer (first missing index: {missing[0]})."
        )
    return df
[docs]
def groupby_period(
    date_start: np.ndarray,
    date_end: np.ndarray,
) -> Generator[Tuple[np.ndarray, np.datetime64, np.datetime64], None, None]:
    """Group control vector elements sharing the same time period.

    Elements are grouped on their (start, end) pair. NaT dates are grouped
    together: all elements with the same start date and a NaT end date form
    one group. Groups are yielded in increasing (start, end) order, with NaT
    sorting after every valid date. Empty input yields nothing.

    Args:
        date_start (array of datetime64): start date of each element
        date_end (array of datetime64): end date of each element

    Yields:
        array of int, datetime64, datetime64: control vector indices within
        the period (increasing), start date of the period, end date of the
        period. Missing dates are yielded as ``np.datetime64('NaT')``.
    """
    periods = np.column_stack((np.asarray(date_start), np.asarray(date_end)))
    if periods.shape[0] == 0:
        return

    # numpy.unique does not group NaT values (NaT != NaT), so NaT is
    # temporarily replaced with a placeholder absent from the valid dates
    nat_mask = np.isnat(periods)
    valid = periods[~nat_mask]
    if valid.size:
        # Strictly later than every valid date: no collision, NaT sorts last
        nat_placeholder = valid.max() + 1
    else:
        # Only NaT values: any valid date is a collision-free placeholder
        nat_placeholder = np.zeros((), dtype=periods.dtype)[()]
    periods = np.where(nat_mask, nat_placeholder, periods)

    unique_periods, unique_inverse, counts = np.unique(
        periods, axis=0, return_inverse=True, return_counts=True)
    # ravel: NumPy 2.0.0 returned a 2-D inverse when 'axis' is given.
    # A stable sort keeps element indices increasing within each group.
    order = np.argsort(np.ravel(unique_inverse), kind='stable')
    members = np.split(order, np.cumsum(counts)[:-1])

    nat = np.datetime64('NaT')
    for (period_start, period_end), x_indices in zip(unique_periods, members):
        yield (
            x_indices,
            nat if period_start == nat_placeholder else period_start,
            nat if period_end == nat_placeholder else period_end,
        )