import numpy as np
import xarray as xr
import itertools
import pandas as pd
from logging import info, debug
from .....utils.datastores.dump import dump_datastore, read_datastore
from .....utils.datastores.empty import init_empty
from .....utils.datastores.crop_monitor import crop_monitor
from .apply_AK import apply_ak_ad
from .vinterp import vertical_interp
import copy
# Adjoint operator of the satellite column transform
def _replicate_over_levels(datastore, nlev):
    """Repeat every row of ``datastore`` once per vertical level.

    Args:
        datastore: pandas DataFrame with two-level columns
            (e.g. ``("metadata", "station")``).
        nlev: number of consecutive copies to make of every row.

    Returns:
        A tuple ``(stacked, idx, first_rows)`` where ``stacked`` is a deep
        copy of the input in which every row appears ``nlev`` consecutive
        times, with ``("metadata", "level")`` set to ``0 .. nlev - 1`` and
        ``("metadata", "indorig")`` set to the position of the source row;
        ``idx`` is that source position for every stacked row; and
        ``first_rows`` is, for every source row, the position of its first
        copy in ``stacked``.
    """
    nrows = len(datastore)
    idx = np.repeat(np.arange(nrows), nlev)
    first_rows = np.arange(nrows) * nlev

    stacked = copy.deepcopy(datastore.iloc[idx])
    stacked[("metadata", "level")] = np.tile(np.arange(nlev), nrows)
    # Saving original positions for later re-aggregation
    stacked.loc[:, ("metadata", "indorig")] = idx

    return stacked, idx, first_rows


def adjoint(
    transf,
    inout_datastore,
    controlvect,
    obsvect,
    mapper,
    di,
    df,
    mode,
    runsubdir,
    workdir,
    onlyinit=False,
    **kwargs
):
    """De-aggregate total columns to the model level.

    The sensitivities stored in ``("maindata", "adj_out")`` of the output
    datastore (one row per satellite observation) are propagated backwards
    through the averaging kernels and the vertical interpolation, and
    written level by level into the input datastores.

    Args:
        transf: transform object. Read attributes include ``parameter``,
            ``component``, ``product``, ``fill_strato``, ``adj_refdir``,
            ``transform_id``, ``use_prior``, ``use_drycols``,
            ``precomputed_pwgt``, ``pressure``, ``level_based``,
            ``scale_dpressure``, ``formula``, ``log_space``,
            ``unit_scaling``, ``normalize_columns``, ``correct_pthick``,
            ``cropstrato``, ``vinterp_type`` and ``weights_nsubsteps``.
            ``transf.metadata[ddi]`` is (re)written by this function.
        inout_datastore: dict with an ``"outputs"`` entry (read) and an
            ``"inputs"`` entry (written), both indexed by
            ``(component, parameter)`` and then by date.
        controlvect: not used here.
        obsvect: not used here.
        mapper: dict with ``"inputs"`` and ``"outputs"`` entries giving,
            per ``(component, parameter)``, the ``"domain"`` and
            ``"tracer"`` objects.
        di: first date of the period.
        df: last date of the period; ``min(di, df)`` is used as the key of
            the sub-period in the datastores.
        mode: not used here.
        runsubdir: not used here.
        workdir: not used here.
        onlyinit: if True, only initialise the level-extended input
            datastores and return without computing the adjoint.
        **kwargs: ignored.

    Returns:
        None. Results are stored in ``inout_datastore["inputs"]``.

    Raises:
        Exception: if more than one satellite ID is found in the
            ``station`` column of the datastore.
        IOError: if the satellite files cannot be read.
    """
    ddi = min(di, df)
    ref_parameter = transf.parameter[0]
    ref_component = transf.component[0]
    ref_ds = inout_datastore["outputs"][(ref_component, ref_parameter)][ddi]
    out_datastore = inout_datastore["inputs"]

    # Stop here if no data
    if len(ref_ds) == 0:
        return

    # Copy data to output datastore
    input_components = ["concs", "pressure", "dpressure"]
    if transf.product == "column":
        input_components += ["airm", "hlay"]

    for incomp in input_components:
        out_datastore[(incomp, ref_parameter)][ddi] = copy.deepcopy(ref_ds)
        if incomp != "concs":
            out_datastore[(incomp, ref_parameter)][ddi][
                ("maindata", "adj_out")] = 0

    y0 = out_datastore[("concs", ref_parameter)][ddi]

    # Exit if empty observations
    if len(y0) == 0:
        return

    # Level-extended dataframe: one row per (observation, model level)
    nlev_model = mapper["inputs"][("concs", ref_parameter)]["domain"].nlev
    df_main, idx, native_inds_main = _replicate_over_levels(y0, nlev_model)

    # Check that only one satellite was specified in the datastore
    iq1 = y0["metadata"]["station"]
    list_satIDs = iq1.unique()
    if len(list_satIDs) > 1:
        raise Exception(
            f"Warning! Several satellite IDs were specified in the monitor file: "
            f"{list_satIDs} \n "
            "This corresponds to the column 'station'. Please check your monitor file. "
            "It is possible to manually exclude IDs from the datastore by using the "
            "parameter 'exclude_stations' (list[str]) in the corresponding satellite "
            "paragraph of your data vector in the Yaml configuration file."
        )

    # Deal with case when no ID was specified: assume that all data are concerned
    # NOTE: pd.notnull returns a plain bool for scalars and "~" on a bool is
    # always truthy, hence the explicit pd.isnull test
    satID = list_satIDs[0]
    if not isinstance(satID, str) and pd.isnull(satID):
        satID = 0
        iq1[:] = satID

    # Saving original values for later re-aggregation
    df_main.loc[:, ("metadata", "iq1")] = iq1.iloc[idx].values
    transf.metadata = getattr(transf, "metadata", {})
    transf.metadata[ddi] = df_main.loc[
        :, [("metadata", "indorig"),
            ("metadata", "iq1"),
            ("metadata", "ind_file"),
            ("metadata", "station"),
            ("metadata", "date")]]

    # Initializing datastores for various inputs
    for incomp in input_components:
        out_datastore[(incomp, ref_parameter)][ddi] = copy.deepcopy(df_main)
        if incomp != "concs":
            out_datastore[(incomp, ref_parameter)][ddi][
                ("maindata", "adj_out")] = 0

    # Un-stack columns into dataframe for stratosphere
    if transf.fill_strato:
        nlev_strato = \
            mapper["inputs"][("stratosphere", ref_parameter)]["domain"].nlev
        df_strato, idx_strato, native_inds_strato = \
            _replicate_over_levels(y0, nlev_strato)
        df_strato.loc[:, ("metadata", "iq1")] = iq1.iloc[idx_strato].values
        out_datastore[("stratosphere", ref_parameter)][ddi] = df_strato

    # Stop here if no adjoint to be fully computed
    # Just forward datastore to precursor transforms
    if onlyinit:
        return

    # Load pressure coordinates from previous run
    file_monit = ddi.strftime(
        "{}/chain/satellites/{}/monit_%Y%m%d%H%M.nc".format(
            transf.adj_refdir, transf.transform_id)
    )
    fwd_pressure = read_datastore(
        file_monit,
        col2dump=["pressure", "dp", "indorig",
                  "hlay", "airm", "sim", "pthick", "exclude_zeros"]
    )

    # First row of every original observation in the forward dump
    ref_indexes = ~fwd_pressure["metadata"].duplicated(subset=["indorig"])

    satmask = iq1 == satID
    nobs = np.sum(satmask)

    # Getting the vector of increments
    obs_incr = y0.loc[satmask, ("maindata", "adj_out")]

    # If all increments are zero, there is nothing to propagate
    if not np.any(obs_incr != 0.0):
        out_datastore[("concs", transf.parameter[0])][ddi] = df_main
        return

    # Get target pressure
    # Positions in the level-extended dataframe, shaped (level, observation)
    native_ind_stack = (
        native_inds_main[satmask]
        + np.arange(nlev_model)[:, np.newaxis]
    )

    datasim = xr.Dataset(
        {
            "pressure": (
                ["level", "index"],
                np.log(fwd_pressure["metadata"]
                       ["pressure"].values[native_ind_stack]),
            ),
            "dp": (
                ["level", "index"],
                fwd_pressure["metadata"]["dp"].values[native_ind_stack],
            ),
            "sim": (
                ["level", "index"],
                fwd_pressure["maindata"]["sim"].values[native_ind_stack],
            ),
        },
        coords={
            "index": np.arange(len(y0)),
            "level": np.arange(nlev_model),
        },
    )

    if transf.product == "column":
        datasim = datasim.assign(
            {
                "airm": (
                    ["level", "index"],
                    fwd_pressure["metadata"]["airm"].values[native_ind_stack],
                ),
                "hlay": (
                    ["level", "index"],
                    fwd_pressure["metadata"]["hlay"].values[native_ind_stack],
                ),
            }
        )

    # Crop aks depending on split_freq and original file
    ref_mapper = mapper["outputs"][(ref_component, ref_parameter)]
    ref_tracer = ref_mapper["tracer"]
    ddf = ref_tracer.datef
    if hasattr(ref_tracer, "split_freq"):
        obsvect_dates = pd.date_range(
            ref_tracer.datei, ref_tracer.datef,
            freq=ref_tracer.split_freq)
        i0 = np.argwhere(obsvect_dates == ddi)[0][0]
        ddf = obsvect_dates[i0 + 1]

    # Getting averaging kernels
    # input_files values are nested two levels deep: flatten twice
    files_aks = sorted(set(
        itertools.chain.from_iterable(
            itertools.chain.from_iterable(ref_tracer.input_files.values()))
    ))

    # Use only files relevant for the present dataset
    files_aks = np.array(files_aks)[
        y0["metadata"]["ind_file"].drop_duplicates().values.astype(int)
    ]

    info("Fetching satellite infos from files: {}".format(files_aks))

    try:
        colsat = ["ak", "pavg0", "date",
                  "index", "station", "duration"]
        if transf.use_prior:
            colsat += ["qa0"]
        if transf.use_drycols:
            colsat += ["dryair"]
        if transf.precomputed_pwgt:
            colsat += ["pw"]
        coord2dump = ["index"]

        all_sat_aks = []
        for file_aks in files_aks:
            # Crop the monitor to replicate what is done
            # in pycif/plugins/obsvects/standard/utils/__init__.py
            # The file is only needed for its dates: close it right away
            with xr.open_dataset(file_aks) as ds_dates:
                crop_date = init_empty()
                crop_date[("metadata", "date")] = ds_dates["date"].values
                crop_date[("metadata", "duration")] = \
                    ds_dates["duration"].values

            crop_index = crop_monitor(
                crop_date, ref_tracer.datei, ref_tracer.datef,
                return_index=True
            )
            if crop_index.size == 0:
                continue

            # Now extract only data relevant to present sub-period
            # This happens only with split_freq
            crop_date = crop_date.iloc[crop_index]
            if ref_tracer.datei != ddi or ref_tracer.datef != ddf:
                crop_index = crop_monitor(
                    crop_date, ddi, ddf,
                    return_index=True,
                    keep_partial=True
                )
                # Back to positions in the un-cropped file
                crop_index = crop_date.index[crop_index].values

            # Now read the data itself
            sat_aks = read_datastore(
                file_aks,
                col2dump=colsat,
                coord2dump=coord2dump,
                keep_default=False,
                to_pandas=False
            )
            if 'index' in sat_aks:
                sat_aks = sat_aks.drop('index')

            # Check that all variables are float64 to avoid precision issues
            for v in sat_aks:
                if sat_aks.dtypes[v] == np.dtype('float32'):
                    sat_aks[v] = sat_aks[v].astype('float64')

            # Crop to sub-date
            sat_aks["index"] = np.arange(sat_aks.sizes["index"])
            sat_aks = sat_aks.isel(index=crop_index)

            if len(all_sat_aks) == 0:
                all_sat_aks = sat_aks
            else:
                all_sat_aks = xr.concat([all_sat_aks, sat_aks], "index")

        # Re-number observations after concatenation
        all_sat_aks["index"] = np.arange(all_sat_aks.sizes["index"])

        # If no station was specified in the original file, just fill with 0
        if "station" not in all_sat_aks:
            all_sat_aks["station"] = xr.full_like(all_sat_aks.index, satID)

        # Crop first observations for debugging purposes if specified in yml
        if getattr(ref_tracer, "crop_datastore", False):
            crop_istart = getattr(ref_tracer, "crop_istart", 0)
            all_sat_aks = all_sat_aks.isel(
                index=slice(crop_istart, crop_istart + ref_tracer.nobs2crop))

        # Selecting only lines used in simulation
        mask = all_sat_aks["date"].isin(ref_ds["metadata"]["date"]) \
            & (all_sat_aks["station"] == satID)
        sat_aks = all_sat_aks.loc[{"index": mask}]

    except IOError as err:
        raise IOError("Could not fetch "
                      "satellite info from {}".format(file_aks)) from err

    # Defining ak info
    # Second axis is reversed, then transposed to (level, index)
    aks = sat_aks["ak"][:, ::-1].T
    pavgs = sat_aks["pavg0"][:, ::-1].T
    if transf.pressure == "hPa":
        pavgs *= 100

    # -- For level-based, given pressures are averages over the partial column (n)
    # -- For layer-based, given pressures are the ones at inter-levels (n + 1)
    level0_size = aks.level.size + 1
    level1_size = aks.level.size
    if transf.level_based:
        level0_size -= 1

    coords0 = {"index": np.arange(nobs),
               "level": np.arange(level0_size)}
    coords1 = {"index": np.arange(nobs),
               "level": np.arange(level1_size)}
    dims = ("level", "index")

    if transf.use_prior:
        qa0 = sat_aks["qa0"][:, ::-1].T
    else:
        qa0 = 0. * aks

    if transf.level_based:
        pavgs_mid = xr.DataArray(
            np.log(pavgs.values), coords0, dims).bfill("level")

        # Assuming goes from surface to the top of the atmosphere
        flip_pressure = False
        if not np.all(np.diff(pavgs.values, axis=0) > 0):
            pavgs.values = np.flip(pavgs.values, axis=0)
            flip_pressure = True

        # Computes dpavgs from top to bottom
        # Interfaces are taken half-way between consecutive levels
        dpavgs = np.diff(pavgs.values, axis=0)
        dpavgs = np.concatenate([
            np.maximum(0, pavgs.values[[0]] - dpavgs[[0]] / 2),
            pavgs.values[1:] - dpavgs / 2,
            pavgs.values[[-1]] + dpavgs[[-1]] / 2
        ], axis=0)
        dpavgs = xr.DataArray(np.diff(dpavgs, axis=0), coords1, dims)

        # Flip back pressures
        if flip_pressure:
            pavgs.values = np.flip(pavgs.values, axis=0)
            dpavgs.values = np.flip(dpavgs.values, axis=0)

    else:
        pavgs = xr.DataArray(pavgs, coords0, dims).bfill("level")
        dpavgs = np.abs(xr.DataArray(np.diff(-pavgs, axis=0), coords1, dims))
        pavgs_mid = xr.DataArray(
            np.log(0.5 * (pavgs[:-1].values + pavgs[1:].values)),
            coords1, dims)

    # Exclude observations where there are all zeros in the simulation
    # NOTE: despite its name, exclude_zeros holds the positions that are KEPT
    exclude_zeros = np.where(
        ~fwd_pressure["metadata"].groupby(['indorig'])
        .min()["exclude_zeros"]
    )[0]
    datasim = datasim.isel(index=exclude_zeros)
    sat_aks = sat_aks.isel(index=exclude_zeros)
    aks = aks.isel(index=exclude_zeros)
    pavgs_mid = pavgs_mid.isel(index=exclude_zeros)
    dpavgs = dpavgs.isel(index=exclude_zeros)
    qa0 = qa0.isel(index=exclude_zeros)
    nobs = datasim.sizes["index"]

    # Adding dry air mole fraction if formula 5
    if transf.use_drycols:
        drycols = sat_aks["dryair"][:, ::-1].T
    else:
        drycols = qa0 * 0.0 + 1

    # Pressure weight to apply on columns
    pwgt = dpavgs / dpavgs.sum(axis=0)
    if not transf.scale_dpressure:
        pwgt = 0 * pwgt + 1
    elif transf.precomputed_pwgt:
        pwgt = sat_aks["pw"][:, ::-1].T

    # Applying aks
    nbformula = transf.formula
    chosenlevel = getattr(transf, "chosenlev", 0)

    debug("nbformula: {}".format(nbformula))
    debug("chosenlev: {}".format(chosenlevel))

    # In log space, load sim_ak dumped by the forward
    if transf.log_space:
        file_dump = ddi.strftime(
            "{}/chain/satellites/{}/sim_ak_{}_%Y%m%d%H%M.nc".format(
                transf.adj_refdir, transf.transform_id, satID)
        )
        with xr.open_dataarray(file_dump) as sim_ak_dump:
            sim_ak = sim_ak_dump.values
    else:
        sim_ak = 0

    obs_incr = apply_ak_ad(
        obs_incr.values[exclude_zeros], sim_ak, aks.values,
        pwgt.values, drycols.values, qa0.values,
        use_drycols=transf.use_drycols, chosen_level=chosenlevel,
        scale_factor=transf.unit_scaling,
        normalize_columns=transf.normalize_columns,
        log_space=transf.log_space
    )
    obs_incr[pd.isnull(obs_incr)] = 0.

    # Correction with the pressure thickness
    # WARNING: there is an inconsistency in the number of levels
    if transf.correct_pthick:
        scale_pthick = fwd_pressure["metadata"]["pthick"].iloc[
            np.flatnonzero(ref_indexes)[satmask]].iloc[exclude_zeros]
        obs_incr *= scale_pthick.values

    # Adjoint of the log-pressure interpolation
    obs_incr_interp = 0.0 * datasim["pressure"].values

    # Fetch missing values from stratosphere
    sim_pressure = datasim["pressure"].values
    sim_dpressure = datasim["dp"].values
    if transf.fill_strato:
        strato_mapper = mapper["inputs"][("stratosphere",
                                          ref_parameter)]
        strato_sigma_a = strato_mapper["domain"].sigma_a_mid
        strato_sigma_b = strato_mapper["domain"].sigma_b_mid
        strato_sigma_a_interface = strato_mapper["domain"].sigma_a
        strato_sigma_b_interface = strato_mapper["domain"].sigma_b
        strato_nlev = len(strato_sigma_a)

        psurf_strato = np.exp(sim_pressure[0])
        pstrato = np.log(
            strato_sigma_b * psurf_strato[:, np.newaxis] + strato_sigma_a
        ).T
        pstrato_interface = (
            strato_sigma_b_interface * psurf_strato[:, np.newaxis]
            + strato_sigma_a_interface
        ).T
        dpstrato = np.abs(np.diff(pstrato_interface, axis=0))

        # Here merge sim_pressure
        # and pstrato properly, then apply vertical_interp
        nblevbasic = np.shape(sim_pressure)[0]

        # First stratospheric level above the model top, per observation
        # (pstrato.shape[0] when there is none)
        cond = pstrato < sim_pressure[-1]
        missing_levels = np.argmax(cond, axis=0)
        missing_levels[~np.any(cond, axis=0)] = pstrato.shape[0]
        nblevremainingstrato = pstrato.shape[0] - missing_levels.min()

        sim_pressure_varying = np.full(
            (nblevbasic + nblevremainingstrato, len(datasim.index)), 0.)
        sim_pressure_varying[:nblevbasic] = sim_pressure[:]
        sim_dpressure_varying = np.full(
            (nblevbasic + nblevremainingstrato, len(datasim.index)), 0.)
        sim_dpressure_varying[:nblevbasic] = sim_dpressure[:]

        # Matching (level, observation) positions between the merged
        # profile (target) and the stratospheric profile (orig)
        lev_index_target = [
            i for m in missing_levels
            for i in range(nblevbasic, nblevbasic + pstrato.shape[0] - m)
        ]
        lev_index_orig = [
            i for m in missing_levels
            for i in range(m, pstrato.shape[0])
        ]
        obs_index_target = [
            j for j, m in enumerate(missing_levels)
            for _ in range(pstrato.shape[0] - m)
        ]

        np.add.at(sim_pressure_varying,
                  (lev_index_target, obs_index_target),
                  pstrato[lev_index_orig, obs_index_target]
                  )
        sim_pressure = copy.deepcopy(sim_pressure_varying)

        np.add.at(sim_dpressure_varying,
                  (lev_index_target, obs_index_target),
                  dpstrato[lev_index_orig, obs_index_target]
                  )
        sim_dpressure = copy.deepcopy(sim_dpressure_varying)

    # Vertical interpolation
    xlow, xhigh, alphalow, alphahigh = vertical_interp(
        sim_pressure,
        sim_dpressure,
        pavgs_mid.values,
        dpavgs.values,
        transf.cropstrato,
        transf.vinterp_type,
        transf.weights_nsubsteps
    )

    # Applying coefficients
    # WARNING: There might be repeated indexes in a given column
    # To deal with repeated index, np.add.at is recommended
    meshout_wgt = (
        np.arange(pavgs_mid.shape[0])[:, np.newaxis]
        * np.ones((1, len(datasim.index)))
    ).astype(int)
    meshout_wgt_iobs = (
        np.arange(len(datasim.index))[np.newaxis]
        * np.ones((pavgs_mid.shape[0], 1))
    ).astype(int)

    if transf.vinterp_type == "weight":
        # Every satellite level is repeated weights_nsubsteps times
        meshout_wgt = np.floor(
            np.linspace(0, pavgs_mid.shape[0],
                        pavgs_mid.shape[0] * transf.weights_nsubsteps,
                        endpoint=False
                        )[:, np.newaxis] * np.ones((1, len(datasim.index)))
        ).astype(int)
        meshout_wgt_iobs = (
            np.arange(len(datasim.index))[np.newaxis]
            * np.ones((pavgs_mid.shape[0] * transf.weights_nsubsteps, 1))
        ).astype(int)

    tmp_obs_incr = 0.0 * sim_pressure
    np.add.at(
        tmp_obs_incr,
        (xlow, meshout_wgt_iobs),
        obs_incr[meshout_wgt, meshout_wgt_iobs] * alphalow,
    )
    np.add.at(
        tmp_obs_incr,
        (xhigh, meshout_wgt_iobs),
        obs_incr[meshout_wgt, meshout_wgt_iobs] * alphahigh,
    )

    # Deal with the stratosphere
    if transf.fill_strato:
        nstrato = strato_mapper["domain"].nlev

        incr_strato = np.zeros((strato_nlev, len(datasim.index)))
        np.add.at(incr_strato,
                  (lev_index_orig, obs_index_target),
                  tmp_obs_incr[lev_index_target, obs_index_target])

        # ppb to molec/cm2 if column product
        if transf.product == "column":
            dpstrato = \
                np.diff(np.concatenate(
                    [psurf_strato[np.newaxis, :], np.exp(pstrato)],
                    axis=0), axis=0) / 100  # hPa
            G = 9.81
            dmass = np.abs(dpstrato / G)  # kg/m2
            column = dmass * 1e3  # g/m2
            column /= 28.96 * 1e4  # mol/cm2
            column *= 6.02214076e23  # molec / cm2
            column /= 1e9  # scaling from ppb
            incr_strato *= column

        # Fill adj_out
        native_ind_stack_strato = (
            native_inds_strato[satmask]
            + np.arange(nstrato)[:, np.newaxis]
        )
        out_strato = out_datastore[
            ("stratosphere", transf.parameter[0])][ddi]
        strato_index = native_ind_stack_strato[:, exclude_zeros].flatten()
        out_strato.iloc[
            strato_index,
            out_strato.columns.get_loc(("maindata", "adj_out"))
        ] = incr_strato.flatten()

        # Keep only the "bottom" part for later
        obs_incr_interp[:, :] = tmp_obs_incr[:nblevbasic, :]

    else:
        obs_incr_interp[:, :] = tmp_obs_incr[:]

    # Convert CHIMERE fields to the correct unit
    # from ppb to molec.cm-2 if the satellite product is a column
    if transf.product == "column":
        factor = (datasim["hlay"].values * 100) \
            / (1e9 / datasim["airm"].values)
        obs_incr_interp *= factor

    # Applying increments to the flattened datastore
    out_index = native_ind_stack[:, exclude_zeros].flatten()
    df_main.iloc[
        out_index, df_main.columns.get_loc(("maindata", "adj_out"))
    ] = obs_incr_interp.flatten()

    # Pushing adjoint to the general datastore
    out_datastore[("concs", transf.parameter[0])][ddi] = df_main