import os
import copy
import numpy as np
import xarray as xr
import itertools
from .utils.get_weights import get_weights
from .....utils.classes.domains import Domain
[docs]
def adjoint(
    transf,
    inout_datastore,
    controlvect,
    obsvect,
    mapper,
    ddi,
    ddf,
    mode,
    runsubdir,
    workdir,
    onlyinit=False,
    save_debug=False,
    **kwargs
):
    """Propagate adjoint sensitivities back through the regridding transform.

    The adjoint data found in ``inout_datastore["outputs"]`` (output side of
    the transform) is redistributed onto the input grid with the weights
    returned by ``get_weights``. The result is written into
    ``inout_datastore["inputs"][trid][ddi]`` for every tracer ``trid`` of
    ``mapper["inputs"]``.

    Args:
        transf: the regrid transform; ``transf.min_weight`` is forwarded to
            :func:`do_regridding_adj`.
        inout_datastore (dict): must contain ``"outputs"`` and ``"inputs"``;
            ``"inputs"`` is filled in place.
        controlvect, obsvect, ddf, mode, runsubdir, workdir, kwargs: not used
            here (presumably kept for a common transform interface).
        mapper (dict): transform mapper. Per tracer it provides the flags
            ``sparse_data`` (default False), ``sampled`` (outputs, required),
            ``is_lbc`` (outputs, default False) and, for gridded data,
            ``domain``.
        ddi: key of the sub-period being processed in the datastores.
        onlyinit (bool): initialisation mode. Returns early for sparse inputs,
            for full-matrix outputs, and for sparse/sampled outputs whose data
            is empty; otherwise the computation proceeds.
        save_debug (bool): forwarded to :func:`do_regridding_adj` for
            sparse/sampled data.

    Returns:
        None. Results are stored in ``inout_datastore["inputs"]``.
    """
    xmod = inout_datastore["outputs"]
    trid_ref = list(xmod)[0]
    mapper_in = mapper["inputs"][trid_ref]
    mapper_out = mapper["outputs"][trid_ref]

    # Differentiate full matrices and sparse data
    is_sparse_out = mapper_out.get("sparse_data", False)
    is_sparse_in = mapper_in.get("sparse_data", False)
    is_sampled_out = mapper_out["sampled"]
    # Sampled outputs are processed exactly like sparse outputs below
    sparse_like_out = is_sparse_out or is_sampled_out

    # Inputs domain from the present tracer
    if is_sparse_in:
        if onlyinit:
            return
        # NOTE(review): the input-side domain is built from the metadata of
        # the *outputs* datastore; confirm this is intended.
        domain_in = _domain_from_metadata(xmod[trid_ref][ddi]["metadata"])
        nlat_in, nlon_in = domain_in.nlat, domain_in.nlon
    else:
        domain_in = mapper_in["domain"]
        nlat_in, nlon_in = domain_in.zlon.shape

    # Outputs domain from the mapper
    is_lbc = mapper_out.get("is_lbc", False)
    if sparse_like_out:
        # Do nothing if onlyinit and propagating an empty matrix
        if onlyinit and len(xmod[trid_ref][ddi]) == 0:
            return
        # Create domain from metadata for sparse/sampled data
        domain_out = _domain_from_metadata(xmod[trid_ref][ddi]["metadata"])
    else:
        domain_out = mapper_out["domain"]
        # For non-sparse data, in initialization mode, no further
        # operation is needed
        if onlyinit:
            return

    # Lateral boundary conditions only cover the sides of the domain
    if is_lbc:
        nlon_out = domain_out.nlon_side
        nlat_out = domain_out.nlat_side
    else:
        nlat_out, nlon_out = domain_out.zlat.shape

    weights = get_weights(
        transf,
        trid_ref,
        mapper,
        domain_in,
        domain_out,
        is_lbc,
        ddi,
        is_sparse_in=is_sparse_in,
        is_sparse_out=sparse_like_out,
    )

    is_full_matrix = not (is_sparse_in or sparse_like_out)
    for trid in mapper["inputs"]:
        inout_datastore["inputs"][trid] = {}
        if is_full_matrix:
            # Gridded data: only the "adj_out" array is regridded and it is
            # wrapped back into a dict
            adj_in = do_regridding_adj(
                xmod[trid][ddi]["adj_out"],
                nlat_in, nlon_in,
                nlat_out, nlon_out,
                weights,
                min_weight=transf.min_weight,
                is_sparse_in=is_sparse_in,
                is_sparse_out=sparse_like_out,
            )
            inout_datastore["inputs"][trid][ddi] = {"adj_out": adj_in}
        else:
            # Sparse/sampled data: the whole table is passed and returned
            inout_datastore["inputs"][trid][ddi] = do_regridding_adj(
                xmod[trid][ddi],
                nlat_in, nlon_in,
                nlat_out, nlon_out,
                weights,
                min_weight=transf.min_weight,
                is_sparse_in=is_sparse_in,
                is_sparse_out=sparse_like_out,
                save_debug=save_debug,
                transf=transf,
            )


def _domain_from_metadata(metadata):
    """Build a one-column ``Domain`` from the ``lon``/``lat`` of sparse metadata.

    ``zlon``/``zlonc`` and ``zlat``/``zlatc`` get shape ``(len(metadata), 1)``;
    ``nlon`` and ``nlat`` are both set to ``len(metadata)``.
    """
    domain = Domain()
    # Each array is extracted separately so zlon/zlonc (and zlat/zlatc)
    # are distinct array objects.
    domain.zlon = metadata.loc[:, "lon"].values[:, np.newaxis]
    domain.zlonc = metadata.loc[:, "lon"].values[:, np.newaxis]
    domain.zlat = metadata.loc[:, "lat"].values[:, np.newaxis]
    domain.zlatc = metadata.loc[:, "lat"].values[:, np.newaxis]
    domain.nlon = len(metadata)
    domain.nlat = len(metadata)
    return domain
[docs]
def do_regridding_adj(
    data, nlat_in, nlon_in, nlat_out, nlon_out,
    weights, min_weight=1e-10,
    is_sparse_in=False, is_sparse_out=False,
    save_debug=False, transf=None
):
    """Apply the adjoint of the regridding to one tracer's data.

    Full-matrix case (neither ``is_sparse_in`` nor ``is_sparse_out``): each
    output cell value is multiplied by its weights and scatter-added into the
    input-grid cells ``(weights["i"], weights["j"])``. Weights that are not
    strictly greater than ``min_weight`` (including NaN) are treated as 0.

    Sparse-output case: the rows of ``data`` (restricted to
    ``weights["filtered"]`` when present) are each repeated once per
    neighbour. They are tagged with the input-grid indices
    ``("metadata", "i")`` / ``("metadata", "j")``, and their
    ``("maindata", "adj_out")`` value is multiplied by the matching weight.

    Args:
        data: full-matrix case: xarray-like object with ``time``, ``lev``,
            ``lat`` and ``lon`` coordinates whose ``values`` are indexed as
            (time, lev, lat, lon); sparse case: DataFrame with MultiIndex
            columns.
        nlat_in, nlon_in (int): size of the input grid.
        nlat_out, nlon_out (int): size of the output grid; in the full-matrix
            case they are re-derived from ``data`` and the arguments ignored.
        weights (dict-like): ``"i"``, ``"j"``, ``"wgt"`` arrays (2-D,
            presumably (n_output_points, n_neighbours)) and optionally
            ``"filtered"``, the output points to keep. Not modified.
        min_weight (float): full-matrix case only; filtering threshold.
        is_sparse_in, is_sparse_out (bool): data layout flags.
        save_debug, transf: not used here.

    Returns:
        Full-matrix case: ``xr.DataArray`` of shape
        (time, lev, nlat_in, nlon_in) carrying only the ``time`` coordinate.
        Sparse case: the expanded DataFrame.

    Raises:
        NotImplementedError: sparse inputs combined with non-sparse outputs.
    """
    if not (is_sparse_in or is_sparse_out):
        ntimes = len(data.time)
        nlev = len(data.lev)
        nlat_out = len(data.lat)
        nlon_out = len(data.lon)
        var_in = np.zeros((ntimes, nlev, nlat_in, nlon_in), data.dtype)
        var_out = data.values

        # Filter weights on a private copy: the caller's weights are reused
        # for every tracer (and possibly cached) and must not be mutated.
        wgt = np.array(weights["wgt"])
        wgt[~(wgt > min_weight)] = 0
        i_in = weights["i"].astype(int)
        j_in = weights["j"].astype(int)

        # Output cells are enumerated in Fortran (column-major) order
        iout, jout = np.unravel_index(
            np.arange(nlat_out * nlon_out), (nlat_out, nlon_out), order="F"
        )
        if "filtered" in weights:
            iout = iout[weights["filtered"]]
            jout = jout[weights["filtered"]]

        # Several output cells can contribute to the same input cell, hence
        # the unbuffered np.add.at instead of a fancy-indexed +=.
        for time, level in itertools.product(
            range(var_out.shape[0]), range(var_out.shape[1])
        ):
            np.add.at(
                var_in[time, level],
                (i_in, j_in),
                var_out[time, level, iout, jout, np.newaxis] * wgt,
            )

        return xr.DataArray(
            var_in, coords={"time": data.time.values},
            dims=("time", "lev", "lat", "lon"),
        )

    if not is_sparse_out:
        raise NotImplementedError(
            "Adjoint of regrid with sparse inputs not yet implemented"
        )

    i = weights["i"].astype(int)
    j = weights["j"].astype(int)
    wgts = weights["wgt"]

    # Keep only the filtered output points if a filter exists
    if "filtered" in weights:
        rows = weights["filtered"]
    else:
        rows = np.arange(len(data))
    data_out = data.iloc[rows].copy()

    # The output table does not keep any "index" column
    if "index" in data_out:
        del data_out["index"]

    # One row per (output point, neighbour), aligned with i/j/wgts.flatten()
    nneigh = np.shape(i)[1]
    data_out = data_out.iloc[np.repeat(np.arange(len(data_out)), nneigh)]
    data_out[("metadata", "i")] = i.flatten()
    data_out[("metadata", "j")] = j.flatten()
    data_out[("maindata", "adj_out")] *= wgts.flatten()
    return data_out