from logging import debug, info
from typing import Any, List, Union
import numpy as np
import pandas as pd
from ....utils.iterators import iter_tracers
from .base_function import BaseFunction, BaseFunctionSamplingBatch
from .periods import get_controlvect_periods, groupby_period
# Aliases for type hints
Mode = Any
[docs]
def split_by_parameter(
    self: Mode,
    batch_list: List[BaseFunctionSamplingBatch],
) -> List[BaseFunctionSamplingBatch]:
    """Split the base function sampling batches by observation parameter.

    Within each input batch, the base functions are grouped by the single
    observation parameter they lead to. Base functions that lead to zero or
    to several observation parameters are gathered in a common remainder
    group. One new batch is created per non-empty group; it inherits the
    start and end dates of the batch it comes from. New batches are numbered
    consecutively, starting at 0, across all input batches.

    Args:
        self (Mode): the mode plugin
        batch_list (list of BaseFunctionSamplingBatch): base function
            sampling batches

    Returns:
        list of BaseFunctionSamplingBatch: split base function sampling
            batches
    """
    # Parameters of the observation vector. Duplicates are dropped (keeping
    # first-seen order): if the same parameter were listed twice, a base
    # function leading to that one parameter would be matched twice below
    # and wrongly treated as leading to several observation parameters.
    obs_parameters = list(dict.fromkeys(
        parameter
        for (_, parameter), tracer in iter_tracers(self.datavect)
        if tracer.isobs
    ))

    # New list of base function batches
    new_batch_list = []
    batch_index = 0

    # Loop over all base function batches
    for batch in batch_list:
        # One group per observation parameter, plus a remainder group
        param_dict = {parameter: [] for parameter in obs_parameters}
        param_dict['_others'] = []

        # Loop over all base functions in the batch
        for trid, base_func in zip(batch.tracer_list, batch.iter_all()):
            if trid in batch.in_out_dict:
                # trid is part of some transform in the controlvect
                # transform pipe: collect its output parameters once
                out_parameters = [
                    param_out for (_, param_out) in batch.in_out_dict[trid]
                ]
                base_func_obs = [
                    parameter for parameter in obs_parameters
                    if parameter in out_parameters
                ]
            else:
                # trid is not part of any transform in the controlvect
                # transform pipe: it can only lead to its own parameter
                base_func_obs = [trid[1]]

            if len(base_func_obs) == 1 and base_func_obs[0] in obs_parameters:
                # Base function tracer leads to exactly one observation
                # parameter
                param_dict[base_func_obs[0]].append(base_func)
            else:
                param_dict['_others'].append(base_func)

        # Build one new batch per non-empty group
        for base_func_list in param_dict.values():
            if not base_func_list:
                continue
            new_batch_list.append(BaseFunctionSamplingBatch(
                self,
                batch_index,
                [base_func.index for base_func in base_func_list],
                [base_func.component for base_func in base_func_list],
                [base_func.parameter for base_func in base_func_list],
                batch.date_start,
                batch.date_end
            ))
            batch_index += 1

    return new_batch_list
[docs]
def split_sampling_batches(
    self: Mode,
    batch_list: List[BaseFunctionSamplingBatch],
    n: int
) -> List[BaseFunctionSamplingBatch]:
    """Split the base function sampling batches into smaller batches.

    Each input batch is cut into consecutive chunks of at most ``n`` base
    functions. The base functions flagged as ignored (``is_ignored()``) are
    appended, once, to the last chunk of the batch they come from. New
    batches inherit the start and end dates of their source batch and are
    numbered consecutively, starting at 0, across all input batches.

    Note:
        A batch whose ``n_samples`` is 0 produces no chunk at all, so its
        ignored base functions are not carried over.

    Args:
        self (Mode): the mode plugin
        batch_list (list of BaseFunctionSamplingBatch): base function
            sampling batches
        n (int): maximum batch size

    Returns:
        list of BaseFunctionSamplingBatch: split base function sampling
            batches

    Raises:
        ValueError: if ``n`` is not strictly positive
    """
    if n <= 0:
        raise ValueError("negative split size")

    # New list of base function batches
    base_func_list = []
    batch_index = 0

    for batch in batch_list:
        # Ignored response functions and their information
        ignored = [base_function for base_function in batch.iter_all()
                   if base_function.is_ignored()]
        ignored_indices = [base_func.index for base_func in ignored]
        ignored_components = [base_func.component for base_func in ignored]
        ignored_parameters = [base_func.parameter for base_func in ignored]

        # Information of the response functions yielded by iterating the
        # batch itself (presumably the non-ignored ones, and presumably
        # ``n_samples`` of them -- TODO confirm against the batch class)
        indices = [base_func.index for base_func in batch]
        components = [base_func.component for base_func in batch]
        parameters = [base_func.parameter for base_func in batch]

        for start in range(0, batch.n_samples, n):
            stop = start + n
            chunk_indices = indices[start:stop]
            chunk_components = components[start:stop]
            chunk_parameters = parameters[start:stop]

            # The last chunk is the one reaching (or overshooting) the end
            # of the batch. Testing ``stop >= n_samples`` covers the case
            # where n_samples is a multiple of n, and never matches the
            # second-to-last chunk, so the ignored response functions are
            # added exactly once.
            if stop >= batch.n_samples:
                chunk_indices = chunk_indices + ignored_indices
                chunk_components = chunk_components + ignored_components
                chunk_parameters = chunk_parameters + ignored_parameters

            base_func_list.append(BaseFunctionSamplingBatch(
                self, batch_index,
                chunk_indices,
                chunk_components,
                chunk_parameters,
                batch.date_start, batch.date_end
            ))
            batch_index += 1

    return base_func_list
[docs]
def init_base_functions(
    self: Mode
) -> Union[List[BaseFunction], List[BaseFunctionSamplingBatch]]:
    """Get the time window of each of the controlvect elements and initialize
    the corresponding base functions or sampling batches of base functions.

    When ``self.use_batch_sampling`` is set, the controlvect periods are
    grouped with ``groupby_period`` into sampling batches, which are then
    optionally split by observation parameter (``self.separate_parameters``)
    and by maximum size (``self.batch_sampling_size``, when the attribute
    exists). Otherwise one base function is created per controlvect period.
    In both cases, entries without an end date are dropped. When
    ``self.first_period_only`` is set, only the entries starting at
    ``self.datei`` are returned.

    Args:
        self (Mode): the mode plugin

    Returns:
        list of BaseFunction or list of BaseFunctionSamplingBatch:
            simulations to run
    """
    df_periods = get_controlvect_periods(self)

    if self.use_batch_sampling:
        # Using batch sampling
        date_start = df_periods.loc[:, 'date_start'].to_numpy()
        date_end = df_periods.loc[:, 'date_end'].to_numpy()
        batches = groupby_period(date_start, date_end)

        base_func_list = [
            BaseFunctionSamplingBatch(
                self, batch_index, indices,
                df_periods.loc[indices, 'component'].to_list(),
                df_periods.loc[indices, 'parameter'].to_list(),
                start, end
            )
            for batch_index, (indices, start, end) in enumerate(batches)
            # Drop base funcs outside of the sim window. pd.isna recognises
            # NaT as well as NaN, whatever scalar type carries the date.
            if not pd.isna(end)
        ]

        if self.separate_parameters:
            base_func_list = split_by_parameter(self, base_func_list)

        if hasattr(self, 'batch_sampling_size'):
            base_func_list = split_sampling_batches(
                self, base_func_list, self.batch_sampling_size)

        # The logged size is the one of the first batch (assuming that all
        # batches have an equal size); 0 when no batch is left, so that an
        # empty result is reported instead of raising an IndexError.
        batch_size = base_func_list[0].n_samples if base_func_list else 0
        info(f"Running {len(base_func_list)} sampling batch of {batch_size} "
             f"response functions with '{self.model.plugin.name}' model in "
             f"'{self.run_mode}' mode")
        debug(f"All batches size: {[b.n_samples for b in base_func_list]}")

    else:
        # Not using batch sampling: one base function per controlvect period
        base_func_list = [
            BaseFunction(self, index, comp, param, start, end)  # type: ignore
            for index, (comp, param, start, end) in df_periods.iterrows()
            # Drop base funcs outside of the sim window (same missing-date
            # test as in the batch sampling branch)
            if not pd.isna(end)
        ]
        info(f"Running {len(base_func_list)} response functions with "
             f"'{self.model.plugin.name}' model in '{self.run_mode}' mode")

    if self.first_period_only:
        first_period_list = [base_func for base_func in base_func_list
                             if base_func.date_start == self.datei]
        info("Retaining only the first period, for a count of "
             f"{len(first_period_list)} response functions")
        return first_period_list

    return base_func_list