# Source code for pycif.plugins.transforms.complex.satellites.adjoint

import numpy as np
import xarray as xr
import itertools
import pandas as pd
from logging import info, debug
from .....utils.datastores.dump import dump_datastore, read_datastore
from .....utils.datastores.empty import init_empty
from .....utils.datastores.crop_monitor import crop_monitor
from .apply_AK import apply_ak_ad
from .vinterp import vertical_interp
import copy


def _expand_to_levels(obs, nlev):
    """Repeat every row of ``obs`` ``nlev`` times, one copy per vertical level.

    Returns a tuple ``(expanded, first_rows, indorig)`` where ``expanded`` is
    a deep copy of the repeated rows with ``("metadata", "level")`` set to
    ``0 .. nlev - 1``, ``first_rows`` is the position in ``expanded`` of the
    first level of each observation, and ``indorig`` gives, for each row of
    ``expanded``, the position of the row of ``obs`` it was copied from.
    """
    nobs = len(obs)
    indorig = np.repeat(np.arange(nobs), nlev)
    first_rows = np.arange(nobs) * nlev
    expanded = copy.deepcopy(obs.iloc[indorig])
    expanded[("metadata", "level")] = np.tile(np.arange(nlev), nobs)
    return expanded, first_rows, indorig


def _init_input_datastores(out_datastore, components, parameter, date,
                           template):
    """Store a deep copy of ``template`` for each component at ``date``.

    The ``("maindata", "adj_out")`` column is reset to 0 for every component
    other than ``"concs"``.
    """
    for incomp in components:
        out_datastore[(incomp, parameter)][date] = copy.deepcopy(template)
        if incomp != "concs":
            out_datastore[(incomp, parameter)][date][
                ("maindata", "adj_out")] = 0


def adjoint(
    transf,
    inout_datastore,
    controlvect,
    obsvect,
    mapper,
    di,
    df,
    mode,
    runsubdir,
    workdir,
    onlyinit=False,
    **kwargs
):
    """De-aggregate total columns to the model level.

    The observation datastore found in
    ``inout_datastore["outputs"][(component, parameter)][min(di, df)]`` is
    expanded to one row per model level and written to
    ``inout_datastore["inputs"]`` for the components ``concs``, ``pressure``
    and ``dpressure`` (plus ``airm`` and ``hlay`` when
    ``transf.product == "column"``, and ``stratosphere`` when
    ``transf.fill_strato`` is set). Unless ``onlyinit`` is true, the
    ``adj_out`` increments are then propagated through the adjoint of the
    averaging kernels (``apply_ak_ad``) and of the vertical interpolation
    (``vertical_interp``) and stored in the ``concs`` datastore.

    Args:
        transf: transform object; its attributes (``parameter``,
            ``component``, ``product``, ``fill_strato``, ``adj_refdir``,
            ``transform_id``, ...) configure the computation.
            ``transf.metadata[min(di, df)]`` is filled as a side effect.
        inout_datastore: dictionary with ``"inputs"`` and ``"outputs"``
            entries, indexed by ``(component, parameter)`` then by date.
        controlvect: not used by this function.
        obsvect: not used by this function.
        mapper: dictionary with ``"inputs"`` and ``"outputs"`` entries giving
            the domain and tracer of each ``(component, parameter)``.
        di, df: bounds of the period; the smaller one is the datastore key.
        mode: not used by this function.
        runsubdir: not used by this function.
        workdir: not used by this function.
        onlyinit: if true, stop once the input datastores are initialized.
        **kwargs: ignored.

    Returns:
        None. Results are written in place into ``inout_datastore``.

    Raises:
        Exception: if several satellite IDs are found in the datastore.
        IOError: if reading the averaging-kernel files fails.
        ValueError: if no averaging-kernel file overlaps the period.
    """
    ddi = min(di, df)
    ref_parameter = transf.parameter[0]
    ref_component = transf.component[0]
    ref_ds = inout_datastore["outputs"][(ref_component, ref_parameter)][ddi]
    out_datastore = inout_datastore["inputs"]

    # Stop here if no data
    if len(ref_ds) == 0:
        return

    # Copy data to output datastore
    input_components = ["concs", "pressure", "dpressure"]
    if transf.product == "column":
        input_components += ["airm", "hlay"]
    _init_input_datastores(
        out_datastore, input_components, ref_parameter, ddi, ref_ds)

    y0 = out_datastore[("concs", ref_parameter)][ddi]

    # Exit if empty observations
    if len(y0) == 0:
        return

    # Level-extended dataframe: one row per observation and model level
    nlev_model = mapper["inputs"][("concs", ref_parameter)]["domain"].nlev
    df_main, native_inds_main, idx = _expand_to_levels(y0, nlev_model)

    # Check that only one satellite was specified in the datastore
    iq1 = y0["metadata"]["station"]
    list_satIDs = iq1.unique()
    if len(list_satIDs) > 1:
        raise Exception(
            f"Warning! Several satellite IDs were specified in the monitor file: "
            f"{list_satIDs} \n "
            "This corresponds to the column 'station'. Please check your monitor file. "
            "It is possible to manually exclude IDs from the datastore by using the "
            "parameter 'exclude_stations' (list[str]) in the corresponding satellite "
            "paragraph of your data vector in the Yaml configuration file."
        )

    # Deal with case when no ID was specified: assume that all data are
    # concerned.
    # NOTE(review): the previous test was `elif ~pd.notnull(satID)`. Bitwise
    # NOT on a Python bool yields -1 or -2, which is always truthy, so every
    # non-string ID (missing or not) has always been reset to 0. That
    # behaviour is kept here but written explicitly, because `~bool` is
    # deprecated since Python 3.12. If only missing IDs should be reset, use
    # `pd.isnull(satID)` instead -- presumably together with the forward
    # transform, since satID is part of the sim_ak file name read below.
    satID = list_satIDs[0]
    if not isinstance(satID, str):
        satID = 0
        iq1[:] = satID

    # Saving original values for later re-aggregation
    df_main.loc[:, ("metadata", "indorig")] = idx
    df_main.loc[:, ("metadata", "iq1")] = iq1.iloc[idx].values

    transf.metadata = getattr(transf, "metadata", {})
    transf.metadata[ddi] = df_main.loc[
        :,
        [("metadata", "indorig"), ("metadata", "iq1"),
         ("metadata", "ind_file"), ("metadata", "station"),
         ("metadata", "date")]]

    # Initializing datastores for various inputs
    _init_input_datastores(
        out_datastore, input_components, ref_parameter, ddi, df_main)

    # Un-stack columns into dataframe for stratosphere
    if transf.fill_strato:
        nlev_strato = \
            mapper["inputs"][("stratosphere", ref_parameter)]["domain"].nlev
        df_strato, native_inds_strato, idx_strato = \
            _expand_to_levels(y0, nlev_strato)

        # Saving original values for later re-aggregation
        df_strato.loc[:, ("metadata", "indorig")] = idx_strato
        df_strato.loc[:, ("metadata", "iq1")] = iq1.iloc[idx_strato].values

        out_datastore[("stratosphere", ref_parameter)][ddi] = df_strato

    # Stop here if no adjoint to be fully computed
    # Just forward datastore to precursor transforms
    if onlyinit:
        return

    # Load pressure coordinates from previous run
    file_monit = ddi.strftime(
        "{}/chain/satellites/{}/monit_%Y%m%d%H%M.nc".format(
            transf.adj_refdir, transf.transform_id)
    )
    fwd_pressure = read_datastore(
        file_monit,
        col2dump=["pressure", "dp", "indorig", "hlay",
                  "airm", "sim", "pthick", "exclude_zeros"]
    )
    ref_indexes = ~fwd_pressure["metadata"].duplicated(subset=["indorig"])

    satmask = iq1 == satID
    nobs = np.sum(satmask)

    # Getting the vector of increments
    obs_incr = y0.loc[satmask, ("maindata", "adj_out")]

    # If no increment differs from zero, there is nothing to propagate
    if not np.any(obs_incr != 0.0):
        out_datastore[("concs", ref_parameter)][ddi] = df_main
        return

    # Get target pressure
    native_ind_stack = (
        native_inds_main[satmask]
        + np.arange(nlev_model)[:, np.newaxis]
    )
    datasim = xr.Dataset(
        {
            "pressure": (
                ["level", "index"],
                np.log(fwd_pressure["metadata"]
                       ["pressure"].values[native_ind_stack]),
            ),
            "dp": (
                ["level", "index"],
                fwd_pressure["metadata"]["dp"].values[native_ind_stack],
            ),
            "sim": (
                ["level", "index"],
                fwd_pressure["maindata"]["sim"].values[native_ind_stack],
            ),
        },
        coords={
            "index": np.arange(len(y0)),
            "level": np.arange(nlev_model),
        },
    )

    if transf.product == "column":
        datasim = datasim.assign(
            {
                "airm": (
                    ["level", "index"],
                    fwd_pressure["metadata"]["airm"].values[native_ind_stack],
                ),
                "hlay": (
                    ["level", "index"],
                    fwd_pressure["metadata"]["hlay"].values[native_ind_stack],
                ),
            }
        )

    # Crop aks depending on split_freq and original file
    ref_mapper = mapper["outputs"][(ref_component, ref_parameter)]
    ref_tracer = ref_mapper["tracer"]
    ddf = ref_tracer.datef
    if hasattr(ref_tracer, "split_freq"):
        obsvect_dates = pd.date_range(
            ref_tracer.datei, ref_tracer.datef,
            freq=ref_tracer.split_freq)
        i0 = np.argwhere(obsvect_dates == ddi)[0][0]
        ddf = obsvect_dates[i0 + 1]

    # Getting averaging kernels
    files_aks = sorted(set(itertools.chain(
        *itertools.chain(*ref_tracer.input_files.values()))))

    # Use only files relevant for the present dataset
    files_aks = np.array(files_aks)[
        y0["metadata"]["ind_file"].drop_duplicates().values.astype(int)
    ]

    info("Fetching satellite infos from files: {}".format(files_aks))

    try:
        colsat = ["ak", "pavg0", "date", "index", "station", "duration"]
        if transf.use_prior:
            colsat += ["qa0"]

        if transf.use_drycols:
            colsat += ["dryair"]

        if transf.precomputed_pwgt:
            colsat += ["pw"]

        coord2dump = ["index"]
        all_sat_aks = None
        for file_aks in files_aks:
            # Crop the monitor to replicate what is done
            # in pycif/plugins/obsvects/standard/utils/__init__.py
            # The file is closed as soon as the two arrays are loaded.
            with xr.open_dataset(file_aks) as ds_dates:
                date_values = ds_dates["date"].values
                duration_values = ds_dates["duration"].values

            crop_date = init_empty()
            crop_date[("metadata", "date")] = date_values
            crop_date[("metadata", "duration")] = duration_values
            crop_index = crop_monitor(
                crop_date, ref_tracer.datei, ref_tracer.datef,
                return_index=True
            )

            if crop_index.size == 0:
                continue

            # Now extract only data relevant to present sub-period
            # This happens only with split_freq
            crop_date = crop_date.iloc[crop_index]
            if ref_tracer.datei != ddi or ref_tracer.datef != ddf:
                crop_index = crop_monitor(
                    crop_date, ddi, ddf,
                    return_index=True,
                    keep_partial=True
                )
                # Map positions in the cropped monitor back to its labels
                crop_index = crop_date.index[crop_index].values

            # Now read the data itself
            sat_aks = read_datastore(
                file_aks,
                col2dump=colsat,
                coord2dump=coord2dump,
                keep_default=False,
                to_pandas=False
            )
            if 'index' in sat_aks:
                sat_aks = sat_aks.drop('index')

            # Check that all variables are float64 to avoid precision issues
            for v in sat_aks:
                if sat_aks.dtypes[v] == np.dtype('float32'):
                    sat_aks[v] = sat_aks[v].astype('float64')

            # Crop to sub-date
            sat_aks["index"] = np.arange(sat_aks.sizes["index"])
            sat_aks = sat_aks.isel(index=crop_index)

            if all_sat_aks is None:
                all_sat_aks = sat_aks
            else:
                all_sat_aks = xr.concat([all_sat_aks, sat_aks], "index")

        if all_sat_aks is None:
            raise ValueError(
                "No satellite data overlapping the simulation period "
                "was found in {}".format(files_aks))

        all_sat_aks["index"] = np.arange(all_sat_aks.sizes["index"])

        # If no station was specified in the original file,
        # just fill with the satellite ID
        if "station" not in all_sat_aks:
            all_sat_aks["station"] = xr.full_like(all_sat_aks.index, satID)

        # Crop first observations for debugging purposes if specified in yml
        if getattr(ref_tracer, "crop_datastore", False):
            crop_istart = getattr(ref_tracer, "crop_istart", 0)
            all_sat_aks = all_sat_aks.isel(
                index=slice(crop_istart, crop_istart + ref_tracer.nobs2crop))

        # Selecting only lines used in simulation
        mask = all_sat_aks["date"].isin(ref_ds["metadata"]["date"]) \
            & (all_sat_aks["station"] == satID)
        sat_aks = all_sat_aks.loc[{"index": mask}]

    except IOError as err:
        raise IOError("Could not fetch "
                      "satellite info from {}".format(file_aks)) from err

    # Defining ak info
    aks = sat_aks["ak"][:, ::-1].T
    pavgs = sat_aks["pavg0"][:, ::-1].T

    if transf.pressure == "hPa":
        pavgs *= 100

    # -- For level-based, given pressures are averages over the partial column (n)
    # -- For layer-based, given pressures are the ones at inter-levels (n + 1)
    level0_size = aks.level.size + 1
    level1_size = aks.level.size
    if transf.level_based:
        level0_size -= 1

    coords0 = {"index": np.arange(nobs), "level": np.arange(level0_size)}
    coords1 = {"index": np.arange(nobs), "level": np.arange(level1_size)}
    dims = ("level", "index")

    if transf.use_prior:
        qa0 = sat_aks["qa0"][:, ::-1].T
    else:
        qa0 = 0. * aks

    if transf.level_based:
        pavgs_mid = xr.DataArray(
            np.log(pavgs.values), coords0, dims).bfill("level")

        # Assuming goes from surface to the top of the atmosphere
        flip_pressure = False
        if not np.all(np.diff(pavgs.values, axis=0) > 0):
            pavgs.values = np.flip(pavgs.values, axis=0)
            flip_pressure = True

        # Computes dpavgs from top to bottom
        dpavgs = np.diff(pavgs.values, axis=0)
        dpavgs = np.concatenate([
            np.maximum(0, pavgs.values[[0]] - dpavgs[[0]] / 2),
            pavgs.values[1:] - dpavgs / 2,
            pavgs.values[[-1]] + dpavgs[[-1]] / 2
        ], axis=0)
        dpavgs = xr.DataArray(np.diff(dpavgs, axis=0), coords1, dims)

        # Flip back pressures
        if flip_pressure:
            pavgs.values = np.flip(pavgs.values, axis=0)
            dpavgs.values = np.flip(dpavgs.values, axis=0)

    else:
        pavgs = xr.DataArray(pavgs, coords0, dims).bfill("level")
        dpavgs = np.abs(
            xr.DataArray(np.diff(-pavgs, axis=0), coords1, dims))
        pavgs_mid = xr.DataArray(
            np.log(0.5 * (pavgs[:-1].values + pavgs[1:].values)),
            coords1, dims)

    # Exclude observations where there are all zeros in the simulation
    exclude_zeros = np.where(
        ~fwd_pressure["metadata"].groupby(['indorig'])
        .min()["exclude_zeros"]
    )[0]
    datasim = datasim.isel(index=exclude_zeros)
    sat_aks = sat_aks.isel(index=exclude_zeros)
    aks = aks.isel(index=exclude_zeros)
    pavgs_mid = pavgs_mid.isel(index=exclude_zeros)
    dpavgs = dpavgs.isel(index=exclude_zeros)
    qa0 = qa0.isel(index=exclude_zeros)

    # Adding dry air mole fraction if formula 5
    if transf.use_drycols:
        drycols = sat_aks["dryair"][:, ::-1].T
    else:
        drycols = qa0 * 0.0 + 1

    # Pressure weight to apply on columns
    pwgt = dpavgs / dpavgs.sum(axis=0)
    if not transf.scale_dpressure:
        pwgt = 0 * pwgt + 1
    elif transf.precomputed_pwgt:
        pwgt = sat_aks["pw"][:, ::-1].T

    # Applying aks
    nbformula = transf.formula
    chosenlevel = getattr(transf, "chosenlev", 0)

    debug("nbformula: {}".format(nbformula))
    debug("chosenlev: {}".format(chosenlevel))

    # In log space, load sim_ak from forward
    if transf.log_space:
        file_dump = ddi.strftime(
            "{}/chain/satellites/{}/sim_ak_{}_%Y%m%d%H%M.nc".format(
                transf.adj_refdir, transf.transform_id, satID)
        )
        # Load the values and release the file handle right away
        with xr.open_dataarray(file_dump) as sim_ak_file:
            sim_ak = sim_ak_file.values
    else:
        sim_ak = 0

    obs_incr = apply_ak_ad(
        obs_incr.values[exclude_zeros],
        sim_ak,
        aks.values,
        pwgt.values,
        drycols.values,
        qa0.values,
        use_drycols=transf.use_drycols,
        chosen_level=chosenlevel,
        scale_factor=transf.unit_scaling,
        normalize_columns=transf.normalize_columns,
        log_space=transf.log_space
    )

    obs_incr[~pd.notnull(obs_incr)] = 0.

    # Correction with the pressure thickness
    # WARNING: there is an inconsistency in the number of levels
    if transf.correct_pthick:
        scale_pthick = fwd_pressure["metadata"]["pthick"].iloc[
            np.flatnonzero(ref_indexes)[satmask]].iloc[exclude_zeros]
        obs_incr *= scale_pthick.values

    # Adjoint of the log-pressure interpolation
    obs_incr_interp = 0.0 * datasim["pressure"].values

    # Fetch missing values from stratosphere
    sim_pressure = datasim["pressure"].values
    sim_dpressure = datasim["dp"].values
    if transf.fill_strato:
        strato_mapper = mapper["inputs"][("stratosphere", ref_parameter)]
        strato_sigma_a = strato_mapper["domain"].sigma_a_mid
        strato_sigma_b = strato_mapper["domain"].sigma_b_mid
        strato_sigma_a_interface = strato_mapper["domain"].sigma_a
        strato_sigma_b_interface = strato_mapper["domain"].sigma_b
        strato_nlev = len(strato_sigma_a)

        psurf_strato = np.exp(sim_pressure[0])
        pstrato = np.log(
            strato_sigma_b * psurf_strato[:, np.newaxis] + strato_sigma_a
        ).T
        pstrato_interface = (
            strato_sigma_b_interface * psurf_strato[:, np.newaxis]
            + strato_sigma_a_interface
        ).T
        dpstrato = np.abs(np.diff(pstrato_interface, axis=0))

        # Here merge sim_pressure
        # and pstrato properly, then apply vertical_interp
        nblevbasic = np.shape(sim_pressure)[0]
        nblevstrato = pstrato.shape[0]

        # First stratospheric level above the model top, per observation
        # (nblevstrato when no stratospheric level is above the model top)
        cond = pstrato < sim_pressure[-1]
        missing_levels = np.argmax(cond, axis=0)
        missing_levels[~np.any(cond, axis=0)] = nblevstrato

        nblevremainingstrato = nblevstrato - missing_levels.min()

        sim_pressure_varying = np.full(
            (nblevbasic + nblevremainingstrato, len(datasim.index)), 0.)
        sim_pressure_varying[:nblevbasic] = sim_pressure[:]

        sim_dpressure_varying = np.full(
            (nblevbasic + nblevremainingstrato, len(datasim.index)), 0.)
        sim_dpressure_varying[:nblevbasic] = sim_dpressure[:]

        lev_index_target = [
            i for m in missing_levels
            for i in range(nblevbasic, nblevbasic + nblevstrato - m)
        ]
        lev_index_orig = [
            i for m in missing_levels
            for i in range(m, nblevstrato)
        ]
        obs_index_target = [
            j for j, m in enumerate(missing_levels)
            for _ in range(nblevbasic, nblevbasic + nblevstrato - m)
        ]

        np.add.at(
            sim_pressure_varying,
            (lev_index_target, obs_index_target),
            pstrato[lev_index_orig, obs_index_target]
        )
        sim_pressure = copy.deepcopy(sim_pressure_varying)

        np.add.at(
            sim_dpressure_varying,
            (lev_index_target, obs_index_target),
            dpstrato[lev_index_orig, obs_index_target]
        )
        sim_dpressure = copy.deepcopy(sim_dpressure_varying)

    # Vertical interpolation
    xlow, xhigh, alphalow, alphahigh = vertical_interp(
        sim_pressure,
        sim_dpressure,
        pavgs_mid.values,
        dpavgs.values,
        transf.cropstrato,
        transf.vinterp_type,
        transf.weights_nsubsteps
    )

    # Applying coefficients
    # WARNING: There might be repeated indexes in a given column
    # To deal with repeated index, np.add.at is recommended
    meshout_wgt = (
        np.arange(pavgs_mid.shape[0])[:, np.newaxis]
        * np.ones((1, len(datasim.index)))
    ).astype(int)
    meshout_wgt_iobs = (
        np.arange(len(datasim.index))[np.newaxis]
        * np.ones((pavgs_mid.shape[0], 1))
    ).astype(int)

    if transf.vinterp_type == "weight":
        meshout_wgt = np.floor(
            np.linspace(
                0, pavgs_mid.shape[0],
                pavgs_mid.shape[0] * transf.weights_nsubsteps,
                endpoint=False
            )[:, np.newaxis]
            * np.ones((1, len(datasim.index)))
        ).astype(int)
        meshout_wgt_iobs = (
            np.arange(len(datasim.index))[np.newaxis]
            * np.ones((pavgs_mid.shape[0] * transf.weights_nsubsteps, 1))
        ).astype(int)

    tmp_obs_incr = 0.0 * sim_pressure
    np.add.at(
        tmp_obs_incr,
        (xlow, meshout_wgt_iobs),
        obs_incr[meshout_wgt, meshout_wgt_iobs] * alphalow,
    )
    np.add.at(
        tmp_obs_incr,
        (xhigh, meshout_wgt_iobs),
        obs_incr[meshout_wgt, meshout_wgt_iobs] * alphahigh,
    )

    # Deal with the stratosphere
    if transf.fill_strato:
        nstrato = strato_mapper["domain"].nlev

        incr_strato = np.zeros((strato_nlev, len(datasim.index)))
        np.add.at(
            incr_strato,
            (lev_index_orig, obs_index_target),
            tmp_obs_incr[lev_index_target, obs_index_target]
        )

        # ppb to molec/cm2 if column product
        if transf.product == "column":
            dpstrato = \
                np.diff(np.concatenate(
                    [psurf_strato[np.newaxis, :], np.exp(pstrato)],
                    axis=0), axis=0) / 100  # hPa
            G = 9.81
            dmass = np.abs(dpstrato / G)  # kg/m2
            column = dmass * 1e3  # g/m2
            column /= 28.96 * 1e4  # mol/cm2
            column *= 6.02214076e23  # molec / cm2
            column /= 1e9  # scaling from ppb

            incr_strato *= column

        # Fill adj_out
        native_ind_stack_strato = (
            native_inds_strato[satmask]
            + np.arange(nstrato)[:, np.newaxis]
        )
        out_strato = out_datastore[("stratosphere", ref_parameter)][ddi]
        strato_index = native_ind_stack_strato[:, exclude_zeros].flatten()
        out_strato.iloc[
            strato_index,
            out_strato.columns.get_loc(("maindata", "adj_out"))
        ] = incr_strato.flatten()

        # Keep only the "bottom" part for later
        obs_incr_interp[:, :] = tmp_obs_incr[:nblevbasic, :]

    else:
        obs_incr_interp[:, :] = tmp_obs_incr[:]

    # Convert CHIMERE fields to the correct unit
    # from ppb to molec.cm-2 if the satellite product is a column
    if transf.product == "column":
        factor = (datasim["hlay"].values * 100) \
            / (1e9 / datasim["airm"].values)
        obs_incr_interp *= factor

    # Applying increments to the flattened datastore
    out_index = native_ind_stack[:, exclude_zeros].flatten()
    df_main.iloc[
        out_index,
        df_main.columns.get_loc(("maindata", "adj_out"))
    ] = obs_incr_interp.flatten()

    # Pushing adjoint to the general datastore
    out_datastore[("concs", ref_parameter)][ddi] = df_main