"""Module pycif.plugins.transforms.basic.time_interpolation.adjoint.

Adjoint of the basic time-interpolation transform.
"""

import copy
import numpy as np
import pandas as pd
import xarray as xr
from logging import warning
import datetime
from .....utils.parallel import thread

try:
    import cPickle as pickle
except ImportError:
    import pickle

from .utils.sparse.adjoint import adjoint as sparse_adjoint
from .utils.array.adjoint import adjoint as array_adjoint


def adjoint(
        transf,
        inout_datastore,
        controlvect,
        obsvect,
        mapper,
        di,
        df,
        mode,
        runsubdir,
        workdir,
        onlyinit=False,
        **kwargs
):
    """Apply the adjoint of the time-interpolation transform for one period.

    The sparse/array nature of the data is read from the first key (trid)
    of ``mapper["inputs"]`` and is assumed to hold for every trid, which
    covers both single and ensemble simulations. The actual work is then
    delegated to ``sparse_adjoint`` (sampled or sparse outputs) or to
    ``array_adjoint`` (everything else).

    Args:
        transf: Transform plugin; only its ``nthreads`` attribute is read
            here and forwarded to the helper.
        inout_datastore: Dict-like store. Its ``"outputs"`` entry and the
            whole store are passed to the helper. NOTE(review): the helpers
            presumably update it in place — confirm against their code.
        controlvect, obsvect, mode, runsubdir, workdir, **kwargs: Unused
            here; presumably kept for the common transform ``adjoint``
            signature.
        mapper: Dict with ``"inputs"`` and ``"outputs"`` entries, each
            mapping trid -> dict. At least ``"sparse_data"`` and
            ``"sampled"`` are read from the reference trid's output entry.
        di, df: Period bounds. The smaller of the two (``min``) is passed
            to the helper as the period key.
        onlyinit: Forwarded unchanged to the helper.

    Returns:
        None.

    Raises:
        IndexError: if ``mapper["inputs"]`` is empty.
        KeyError: if the reference trid is missing from ``mapper["outputs"]``
            or its entry lacks ``"sparse_data"`` / ``"sampled"``, or if
            ``inout_datastore`` has no ``"outputs"`` entry.
    """
    ddi = min(di, df)

    # Reference trid: mapping properties are assumed identical across
    # trids, so the first one (dict insertion order) is representative.
    trid_ref = list(mapper["inputs"])[0]
    out_map = mapper["outputs"][trid_ref]
    sparse_out = out_map["sparse_data"]
    sampled_out = out_map["sampled"]

    outputs = inout_datastore["outputs"]
    nthreads = transf.nthreads

    # Sparse and array data are handled differently: array data is
    # parallelized along trids, whereas sparse data share their indexing
    # across trids. Threading is delegated to the helpers via ``nthreads``.
    if sampled_out or sparse_out:
        sparse_adjoint(
            transf, ddi, mapper, outputs, inout_datastore, onlyinit,
            nthreads=nthreads
        )
    else:
        array_adjoint(
            ddi, mapper, inout_datastore, outputs, onlyinit,
            nthreads=nthreads
        )