import os
import itertools
import numpy as np
import xarray as xr
import copy
import pandas as pd
from logging import debug
from scipy.interpolate import RectBivariateSpline
from .utils.get_weights import get_weights
from .....utils.classes.domains import Domain
from .....utils.parallel import thread
from .....utils.datastores.empty import init_empty
try:
import cPickle as pickle
except ImportError:
import pickle
[docs]
def _domain_from_metadata(metadata):
    """Build a point-wise ``Domain`` from the metadata of sparse data.

    Longitudes and latitudes are read from ``metadata["lon"]`` and
    ``metadata["lat"]`` and reshaped into ``(npoints, 1)`` columns.
    ``zlonc``/``zlatc`` (presumably the cell corners) are given the same
    values as ``zlon``/``zlat``. Both ``nlon`` and ``nlat`` are set to
    ``len(metadata)``.
    """
    domain = Domain()
    # Separate arrays per attribute, as downstream code may treat centres
    # and corners independently
    domain.zlon = np.asarray(metadata["lon"])[:, np.newaxis]
    domain.zlonc = np.asarray(metadata["lon"])[:, np.newaxis]
    domain.zlat = np.asarray(metadata["lat"])[:, np.newaxis]
    domain.zlatc = np.asarray(metadata["lat"])[:, np.newaxis]
    domain.nlon = len(metadata)
    domain.nlat = len(metadata)
    return domain


def forward(
    transf,
    inout_datastore,
    controlvect,
    obsvect,
    mapper,
    ddi,
    ddf,
    mode,
    runsubdir,
    workdir,
    onlyinit=False,
    save_debug=False,
    **kwargs
):
    """Regrid every input tracer onto the output domain.

    The first tracer of ``inout_datastore["inputs"]`` is used as a reference.
    It decides whether inputs/outputs are sparse (point-wise) or gridded, and
    it is used to build the input and output domains. Regridding weights are
    computed once with ``get_weights``. They are then applied to every tracer
    listed in ``mapper["inputs"]``, with tracers split between threads by the
    project ``thread`` helper.

    Args:
        transf: transformation object. ``nthreads``, ``threading_mode``,
            ``min_weight`` (and, for sparse debug output, ``transform_id``)
            are read from it.
        inout_datastore: dict with ``"inputs"`` and ``"outputs"`` entries,
            each indexed by tracer then by ``ddi``. ``"outputs"`` is reset to
            ``{ddi: {}}`` for every input tracer and then filled in place.
        controlvect, obsvect, ddf, runsubdir, workdir, kwargs: not used here.
        mapper: transformation mapper. Reads ``sparse_data``, ``sampled``,
            ``is_lbc`` and ``domain`` of the reference tracer.
        ddi: key of the period being processed.
        mode: running mode. In ``"tl"`` mode, increments are regridded too.
        onlyinit: initialization only. Returns early for sparse inputs, and
            for sparse/sampled outputs with an empty input.
        save_debug: forwarded to ``do_regridding`` for sparse data.

    Returns:
        None. Results are written to ``inout_datastore["outputs"]``.
    """
    xmod = inout_datastore["inputs"]
    trid_ref = list(xmod)[0]
    mapper_in = mapper["inputs"][trid_ref]
    mapper_out = mapper["outputs"][trid_ref]

    # Differentiate full matrices and sparse data
    is_sparse_out = mapper_out.get("sparse_data", False)
    is_sparse_in = mapper_in.get("sparse_data", False)
    is_sampled_out = mapper_out["sampled"]
    # Sampled outputs are handled like sparse outputs
    sparse_out = is_sparse_out or is_sampled_out

    # Inputs domain
    if is_sparse_in:
        if onlyinit:
            return
        domain_in = _domain_from_metadata(
            inout_datastore["outputs"][trid_ref][ddi]["metadata"])
        nlat_in, nlon_in = domain_in.nlat, domain_in.nlon
    else:
        domain_in = mapper_in["domain"]
        nlat_in, nlon_in = domain_in.zlon.shape

    # Outputs domain: from the mapper, or from metadata if sparse/sampled
    is_lbc = mapper_out.get("is_lbc", False)
    if sparse_out:
        # Do nothing if onlyinit and propagating empty matrix
        if onlyinit and len(xmod[trid_ref][ddi]) == 0:
            return
        domain_out = _domain_from_metadata(xmod[trid_ref][ddi]["metadata"])
    else:
        domain_out = mapper_out["domain"]

    # Differentiate dimensions if LBC
    if is_lbc:
        nlon_out = domain_out.nlon_side
        nlat_out = domain_out.nlat_side
    else:
        nlat_out, nlon_out = domain_out.zlat.shape

    weights = get_weights(
        transf,
        trid_ref,
        mapper,
        domain_in,
        domain_out,
        is_lbc,
        ddi,
        is_sparse_in=is_sparse_in,
        is_sparse_out=sparse_out,
    )

    # "trid" mode splits tracers between threads here; "levels" mode passes
    # the thread count down to do_regridding instead
    nthreads = transf.nthreads if transf.threading_mode == "trid" else 1
    nthreads_levels = (
        transf.nthreads if transf.threading_mode == "levels" else 1)
    list_trids = list(mapper["inputs"])
    thread_intervals = np.linspace(
        0, len(list_trids), nthreads + 1).astype(int)

    # Initialize output dictionary before running threads.
    # This is to avoid overwriting data
    for trid in list_trids:
        inout_datastore["outputs"][trid] = {ddi: {}}

    gridded = not (is_sparse_in or sparse_out)

    def regrid_gridded(datastore_out, array):
        # do_regridding only reads ``array``, so no defensive copy is needed
        return do_regridding(
            datastore_out,
            array,
            nlat_in, nlon_in,
            nlat_out, nlon_out,
            weights,
            min_weight=transf.min_weight,
            nthreads=nthreads_levels,
        )

    @thread
    def thread_function(ithread):
        for itrid in range(thread_intervals[ithread],
                           thread_intervals[ithread + 1]):
            trid = list_trids[itrid]
            inputs = xmod[trid][ddi]
            outputs = inout_datastore["outputs"][trid][ddi]

            if not gridded:
                inout_datastore["outputs"][trid][ddi] = do_regridding(
                    outputs,
                    inputs,
                    nlat_in, nlon_in,
                    nlat_out, nlon_out,
                    weights,
                    min_weight=transf.min_weight,
                    is_sparse_in=is_sparse_in,
                    is_sparse_out=sparse_out,
                    save_debug=save_debug, transf=transf,
                )
                continue

            outputs["spec"] = regrid_gridded(outputs, inputs["spec"])
            if mode == "tl":
                # In "tl" mode increments are regridded as well; zero
                # increments are propagated when the input has none
                if "incr" in inputs:
                    outputs["incr"] = regrid_gridded(outputs, inputs["incr"])
                else:
                    outputs["incr"] = 0. * outputs["spec"]

    thread_function(range(nthreads))
[docs]
def do_regridding(
    datastore_out, data,
    nlat_in, nlon_in,
    nlat_out, nlon_out,
    weights,
    min_weight=1e-10,
    is_sparse_in=False, is_sparse_out=False,
    save_debug=False, transf=None,
    nthreads=1
):
    """Apply pre-computed regridding weights to gridded or sparse data.

    ``weights`` holds the 2-D arrays ``"i"``, ``"j"`` and ``"wgt"`` (one row
    per weights entry, one column per neighbour). It can optionally hold
    ``"filtered"`` and ``"non_filtered"`` index arrays. ``weights`` is never
    modified, so the same dict can be shared by concurrent calls.

    Three cases are handled:

    * Gridded (neither ``is_sparse_in`` nor ``is_sparse_out``):
      - *data* is indexed as ``(time, lev, lat, lon)`` and must expose
        ``time``, ``lev``, ``dtype`` and ``values``.
      - Each weights row gives the input cells ``(i, j)`` and weights of one
        output cell. Output cells are enumerated in Fortran order over
        ``(nlat_out, nlon_out)``, restricted to ``weights["filtered"]`` when
        present.
      - Weights not greater than *min_weight* (NaN included) are ignored.
        Input NaNs are skipped by ``np.nansum``.
      - (time, level) pairs are split between *nthreads* threads.
      - Returns a new ``xr.DataArray`` with dims ``("time", "lev", "lat",
        "lon")``.
    * ``is_sparse_out``:
      - Returns a new table from ``init_empty`` with
        ``len(mask) + len(non_filtered)`` rows.
      - Each flattened weight multiplies the matching row of *data*. The
        product is accumulated (``np.add.at``) into the ``mask`` row of its
        weights row, for both ``("maindata", "incr")`` and
        ``("maindata", "spec")``.
      - Rows ``weights["non_filtered"]`` are set to NaN.
      - With *save_debug*, per-neighbour ``i_XX``/``j_XX``/``weight_XX``
        columns under ``transf.transform_id`` are appended.
    * ``is_sparse_in`` only:
      - Each kept row of *data* is duplicated once per neighbour and its
        ``maindata`` values are multiplied by the weight.
      - ``("metadata", "i")`` and ``("metadata", "j")`` are set from the
        weights.
      - Rows are grouped by ``(i, j, alt, date)``: ``maindata`` columns are
        summed, and every other column takes its group maximum.

    In both sparse cases, the processed rows (``mask``) are
    ``weights["filtered"]`` when present, all rows of *data* otherwise. When
    there are none, *datastore_out* is returned unchanged.

    ``nlat_in`` and ``nlon_in`` are not used by this function.
    """
    if not (is_sparse_in or is_sparse_out):
        ntimes = len(data.time)
        nlev = len(data.lev)
        var_out = np.zeros(
            (ntimes, nlev, nlat_out, nlon_out), dtype=data.dtype)
        var_in = data.values

        # Filter weights on a private copy: the same weights dict may be
        # shared by concurrent calls and must not be modified in place.
        # ``~(w > min_weight)`` also zeroes NaN weights.
        wgt = np.array(weights["wgt"], copy=True)
        wgt[~(wgt > min_weight)] = 0

        # Output grid indices of each weights row (Fortran order)
        iout, jout = np.unravel_index(
            np.arange(nlat_out * nlon_out), (nlat_out, nlon_out), order="F")
        if "filtered" in weights:
            iout = iout[weights["filtered"]]
            jout = jout[weights["filtered"]]
        i = weights["i"].astype(int)
        j = weights["j"].astype(int)

        # Split (time, level) pairs between threads
        list_time_levels = list(
            itertools.product(range(ntimes), range(nlev)))
        thread_intervals = np.linspace(
            0, len(list_time_levels), nthreads + 1).astype(int)

        @thread
        def thread_function(ithread):
            for itime_lev in range(thread_intervals[ithread],
                                   thread_intervals[ithread + 1]):
                itime, ilev = list_time_levels[itime_lev]
                var_out[itime, ilev, iout, jout] = np.nansum(
                    wgt * var_in[itime, ilev, i, j], axis=-1)

        thread_function(range(nthreads))

        return xr.DataArray(
            var_out, coords={"time": data.time.values},
            dims=("time", "lev", "lat", "lon")
        )

    # Sparse input and/or output
    i = weights["i"].astype(int)
    j = weights["j"].astype(int)
    wgts = weights["wgt"]

    # Rows of data to process
    mask = np.arange(len(data))
    if "filtered" in weights:
        mask = weights["filtered"]
    non_filtered = []
    if "non_filtered" in weights:
        non_filtered = weights["non_filtered"]

    # Return if mask is empty
    if mask.size == 0:
        debug(
            "Skipping regridding for sparse data as no data is provided."
        )
        return datastore_out

    column2add = ["incr", "spec"]
    maincols = [("maindata", c) for c in column2add]
    nweights = np.shape(i)[1]
    flat_wgts = wgts.flatten()
    # Position (in mask) of the weights row each flattened weight belongs
    # to; raises if the number of weights rows differs from len(mask)
    out_indexes = np.repeat(np.arange(len(mask)), np.full(len(i), nweights))

    if is_sparse_out:
        datastore_out = init_empty(nlines=len(mask) + len(non_filtered))
        for col in maincols:
            datastore_out[col] = 0.
            icol = datastore_out.columns.get_loc(col)
            data_out = datastore_out[col].iloc[mask].to_numpy()
            np.add.at(data_out, out_indexes, data[col].values * flat_wgts)
            datastore_out.iloc[mask, icol] = data_out
            # Putting NaNs in data outside the domain
            datastore_out.iloc[non_filtered, icol] = np.nan

        if save_debug:
            tid = transf.transform_id
            prefixes = ("i", "j", "weight")
            labels = [
                "{}_{:02d}".format(prefix, k)
                for prefix in prefixes for k in range(nweights)
            ]
            df_debug = pd.DataFrame(
                index=datastore_out.index,
                columns=pd.MultiIndex.from_product([[tid], labels]))
            for k in range(nweights):
                for prefix, values in zip(prefixes, (i, j, wgts)):
                    icol = df_debug.columns.get_loc(
                        (tid, "{}_{:02d}".format(prefix, k)))
                    df_debug.iloc[mask, icol] = values[:, k]
            datastore_out = pd.concat([datastore_out, df_debug], axis=1)

        return datastore_out

    # Sparse input onto a grid: one row per (input point, neighbour)
    kept = data.iloc[mask]
    # Explicit copy so that column assignments below act on an independent
    # frame rather than on a slice of ``data``
    datastore_out = kept.iloc[out_indexes].copy()
    for col in maincols:
        datastore_out[col] = 0.
        datastore_out.loc[:, col] = kept[col].to_numpy()[out_indexes] * flat_wgts
    datastore_out.loc[:, ("metadata", "i")] = i.flatten()
    datastore_out.loc[:, ("metadata", "j")] = j.flatten()

    # Group by i/j/alt/dates. Only the maindata columns are summed, once:
    # other columns (e.g. extra dates or strings) may not support summation
    group = datastore_out.groupby(
        [("metadata", "i"), ("metadata", "j"),
         ("metadata", "alt"), ("metadata", "date")]
    )
    summed = group[maincols].sum().reset_index()
    datastore_out = group.max().reset_index()
    for col in maincols:
        datastore_out[col] = summed[col]

    return datastore_out