Source code for pycif.plugins.modes.response_functions.init_base_functions

from logging import debug, info
from typing import Any, List, Union

import numpy as np
import pandas as pd

from ....utils.iterators import iter_tracers
from .base_function import BaseFunction, BaseFunctionSamplingBatch
from .periods import get_controlvect_periods, groupby_period

# Aliases for type hints
# ``Mode`` annotates the ``self`` argument of the functions below (the mode
# plugin); it is deliberately left as ``Any``, so no attribute checking is
# performed on it by type checkers.
Mode = Any


def split_by_parameter(
    self: Mode,
    batch_list: List[BaseFunctionSamplingBatch],
) -> List[BaseFunctionSamplingBatch]:
    """Regroup base function sampling batches per observation parameter

    Within each input batch, base functions are sorted into one bucket per
    observed parameter, plus a catch-all ``'_others'`` bucket for base
    functions that do not map to exactly one observed parameter. Every
    non-empty bucket becomes a new batch; new batches are numbered
    consecutively from 0 across the whole output list and inherit the dates
    of the batch they were extracted from.

    Args:
        self (Mode): the mode plugin
        batch_list (list of BaseFunctionSamplingBatch): base function
            sampling batches

    Returns:
        list of BaseFunctionSamplingBatch: splitted base function sampling
        batches
    """
    # Parameters of every tracer flagged as an observation in the datavect
    observed = [
        param
        for (_, param), tracer in iter_tracers(self.datavect)
        if tracer.isobs
    ]

    regrouped = []
    next_index = 0

    for batch in batch_list:
        # One bucket per observed parameter, then the catch-all bucket
        buckets = {param: [] for param in observed}
        buckets['_others'] = []

        for trid, func in zip(batch.tracer_list, batch.iter_all()):
            if trid in batch.in_out_dict:
                # trid goes through some transform of the controlvect
                # transform pipe: keep the observed parameters it outputs to
                outputs = [
                    param_out for (_, param_out) in batch.in_out_dict[trid]
                ]
                reached = [param for param in observed if param in outputs]
            else:
                # trid is not part of any transform: use its own parameter
                reached = [trid[1]]

            # A dedicated bucket is used only when the base function leads
            # to exactly one observed parameter
            single_obs = len(reached) == 1 and reached[0] in observed
            key = reached[0] if single_obs else '_others'
            buckets[key].append(func)

        for funcs in buckets.values():
            if not funcs:
                continue
            regrouped.append(BaseFunctionSamplingBatch(
                self,
                next_index,
                [func.index for func in funcs],
                [func.component for func in funcs],
                [func.parameter for func in funcs],
                batch.date_start,
                batch.date_end
            ))
            next_index += 1

    return regrouped
def split_sampling_batches(
    self: Mode,
    batch_list: List[BaseFunctionSamplingBatch],
    n: int
) -> List[BaseFunctionSamplingBatch]:
    """Splits the base function sampling batches into chunks of at most n

    Each input batch is cut into consecutive chunks of ``n`` base functions
    (the last chunk may be shorter). The base functions reported as ignored
    by ``is_ignored()`` are appended, exactly once, to the last chunk of the
    batch they come from. New batches are numbered consecutively from 0
    across the whole output list and keep the dates of their source batch.

    Args:
        self (Mode): the mode plugin
        batch_list (list of BaseFunctionSamplingBatch): base function
            sampling batches
        n (int): maximum batch size

    Returns:
        list of BaseFunctionSamplingBatch: splitted base function sampling
        batches

    Raises:
        ValueError: if ``n`` is zero or negative
    """
    if n <= 0:
        raise ValueError("negative split size")

    # New list of base functions batches
    base_func_list = []
    batch_index = 0

    for batch in batch_list:
        # Ignored response functions
        ignored = [base_function for base_function in batch.iter_all()
                   if base_function.is_ignored()]

        # Ignored response function information
        ignored_indices = [base_func.index for base_func in ignored]
        ignored_components = [base_func.component for base_func in ignored]
        ignored_parameters = [base_func.parameter for base_func in ignored]

        # Not ignored response function information
        # NOTE(review): presumably iterating the batch yields only the
        # non-ignored base functions and batch.n_samples is their count --
        # confirm against BaseFunctionSamplingBatch
        indices = [base_func.index for base_func in batch]
        components = [base_func.component for base_func in batch]
        parameters = [base_func.parameter for base_func in batch]

        # NOTE(review): a batch with n_samples == 0 produces no chunk, so
        # its ignored functions are not carried over -- confirm this is
        # intended
        for i in range(0, batch.n_samples, n):
            stop = i + n if i + n <= batch.n_samples else None

            chunk_indices = indices[i:stop]
            chunk_components = components[i:stop]
            chunk_parameters = parameters[i:stop]

            # Last chunk of this batch: the next start (i + n) would fall
            # at or beyond n_samples. This also covers the case where n
            # divides n_samples exactly, and never matches an earlier chunk.
            if i + n >= batch.n_samples:
                # Putting all ignored response functions in the last batch
                chunk_indices = chunk_indices + ignored_indices
                chunk_components = chunk_components + ignored_components
                chunk_parameters = chunk_parameters + ignored_parameters

            base_func_list.append(BaseFunctionSamplingBatch(
                self, batch_index,
                chunk_indices,
                chunk_components,
                chunk_parameters,
                batch.date_start, batch.date_end
            ))
            batch_index += 1

    return base_func_list
def init_base_functions(
    self: Mode
) -> Union[List[BaseFunction], List[BaseFunctionSamplingBatch]]:
    """Get the time window of each of the controlvect element and initialize
    the corresponding base functions or batch sampling of base functions

    With ``self.use_batch_sampling`` set, control vector elements are
    grouped by period into sampling batches, optionally split by observation
    parameter (``self.separate_parameters``) and by maximum size
    (``self.batch_sampling_size``, when the attribute exists). Otherwise one
    base function is created per control vector element. Elements whose end
    date is missing (NaN / NaT) are dropped in both cases.

    With ``self.first_period_only`` set, only the simulations whose
    ``date_start`` equals ``self.datei`` are returned.

    Args:
        self (Mode): the mode plugin

    Returns:
        list of BaseFunction or list of BaseFunctionSamplingBatch:
        simulations to run (possibly empty)
    """
    df_periods = get_controlvect_periods(self)

    # Using batch sampling
    if self.use_batch_sampling:
        date_start = df_periods.loc[:, 'date_start'].to_numpy()
        date_end = df_periods.loc[:, 'date_end'].to_numpy()
        batches = groupby_period(date_start, date_end)

        base_func_list = [
            BaseFunctionSamplingBatch(
                self, batch_index, indices,
                df_periods.loc[indices, 'component'].to_list(),
                df_periods.loc[indices, 'parameter'].to_list(),
                di, df
            )
            for batch_index, (indices, di, df) in enumerate(batches)
            if not np.isnan(df)  # Drop base funcs outside of the sim window
        ]

        if self.separate_parameters:
            base_func_list = split_by_parameter(self, base_func_list)

        if hasattr(self, 'batch_sampling_size'):
            base_func_list = split_sampling_batches(
                self, base_func_list, self.batch_sampling_size)

        # Assuming that all batches have an equal size. The list may be
        # empty when every batch falls outside of the simulation window;
        # report a size of 0 instead of failing on base_func_list[0].
        batch_size = base_func_list[0].n_samples if base_func_list else 0
        info(f"Running {len(base_func_list)} sampling batch of {batch_size} "
             f"response functions with '{self.model.plugin.name}' model in "
             f"'{self.run_mode}' mode")
        debug(f"All batches size: {[b.n_samples for b in base_func_list]}")

    # Not using batch sampling
    else:
        base_func_list = [
            BaseFunction(self, index, comp, param, di, df)  # type: ignore
            for index, (comp, param, di, df) in df_periods.iterrows()
            if df is not pd.NaT  # Drop base funcs outside of the sim window
        ]

        info(f"Running {len(base_func_list)} response functions with "
             f"'{self.model.plugin.name}' model in '{self.run_mode}' mode")

    # NOTE(review): applied to both branches here; the original indentation
    # was lost in extraction -- confirm it is not meant for the non-batch
    # branch only
    if self.first_period_only:
        first_period_list = [base_func for base_func in base_func_list
                             if base_func.date_start == self.datei]
        info("Retaining only the first period, for a count of "
             f"{len(first_period_list)} response functions")
        return first_period_list

    return base_func_list