Source code for pycif.plugins.transforms.basic.time_interpolation.adjoint
import copy
import numpy as np
import pandas as pd
import xarray as xr
from logging import warning
import datetime
from .....utils.parallel import thread
try:
import cPickle as pickle
except ImportError:
import pickle
from .utils.sparse.adjoint import adjoint as sparse_adjoint
from .utils.array.adjoint import adjoint as array_adjoint
[docs]
def adjoint(
    transf,
    inout_datastore,
    controlvect,
    obsvect,
    mapper,
    di,
    df,
    mode,
    runsubdir,
    workdir,
    onlyinit=False,
    **kwargs
):
    """Run the adjoint of the time interpolation for one sub-period.

    The function only selects the backend: the sparse implementation is used
    when the reference tracer's outputs are flagged as ``"sampled"`` or
    ``"sparse_data"`` in ``mapper``; the array implementation is used
    otherwise. All the actual work is delegated to that backend.

    Args:
        transf: transform object; only its ``nthreads`` attribute is read
            here, and the object itself is forwarded to the sparse backend.
        inout_datastore: mapping holding at least an ``"outputs"`` entry;
            both the mapping and its ``"outputs"`` entry are forwarded to
            the backend.
        controlvect: not used in this function.
        obsvect: not used in this function.
        mapper: mapping with ``"inputs"`` and ``"outputs"`` entries, each
            keyed by tracer id. The first key of ``mapper["inputs"]`` is
            taken as the reference tracer and must also be present in
            ``mapper["outputs"]``, with ``"sparse_data"`` and ``"sampled"``
            flags.
        di: first bound of the sub-period.
        df: second bound of the sub-period; the smaller of ``di`` and ``df``
            is the date handed to the backend.
        mode: not used in this function.
        runsubdir: not used in this function.
        workdir: not used in this function.
        onlyinit: forwarded unchanged to the backend.
        **kwargs: accepted and ignored.

    Returns:
        None. Any result is produced by the backend.

    Raises:
        IndexError: if ``mapper["inputs"]`` is empty.
        KeyError: if an expected key is missing from ``mapper`` or
            ``inout_datastore``.
    """
    # The early exit on `onlyinit` is disabled; the flag is forwarded to the
    # backend instead.
    # if onlyinit:
    #     return

    ddi = min(di, df)

    # Fetch information from reference trid
    # Valid for both ensemble and single simulations
    trid_ref = list(mapper["inputs"].keys())[0]
    ref_outputs = mapper["outputs"][trid_ref]
    sparse_out = ref_outputs["sparse_data"]
    sampled_out = ref_outputs["sampled"]

    # Deal differently sparse and array data
    # Array data is parallelized along trid externally, while sparse share
    # indexing
    if sampled_out or sparse_out:
        # NOTE: the two backends take their arguments in a different order.
        sparse_adjoint(
            transf,
            ddi,
            mapper,
            inout_datastore["outputs"],
            inout_datastore,
            onlyinit,
            nthreads=transf.nthreads,
        )
    else:
        array_adjoint(
            ddi,
            mapper,
            inout_datastore,
            inout_datastore["outputs"],
            onlyinit,
            nthreads=transf.nthreads,
        )
# # Threading the application of the scaling factor for ensembles
# nthreads = transf.nthreads
# thread_intervals = np.linspace(
# 0, len(mapper["inputs"]), nthreads + 1
# ).astype(int)
# list_trids = copy.deepcopy(list(mapper["inputs"].keys()))
# @thread
# def thread_function(ithread):
# for itrid in range(thread_intervals[ithread], thread_intervals[ithread + 1]):
# trid = list_trids[itrid]
# # Fetch outputs depending on date
# outputs = None
# for data_id in inout_datastore["outputs"]:
# if data_id == trid:
# outputs = inout_datastore["outputs"][data_id][ddi]
# continue
# if outputs is None:
# continue
# sparse_in = mapper["inputs"][trid]["sparse_data"]
# sparse_out = mapper["outputs"][trid]["sparse_data"]
# sampled_in = mapper["inputs"][trid]["sampled"]
# sampled_out = mapper["outputs"][trid]["sampled"]
# # Interpolation for sparse data
# # Here needs to merge interpol_indexes from multiple output dates
# # If common input dates
# if sampled_out or sparse_out:
# sparse_adjoint(
# transf, ddi, mapper, outputs, trid,
# inout_datastore, onlyinit
# )
# else:
# array_adjoint(
# ddi, mapper, inout_datastore, trid, outputs, onlyinit
# )
# thread_function(range(nthreads))