# Source code for pycif.plugins.transforms.basic.regrid.forward

import os
import itertools
import numpy as np
import xarray as xr
import copy
import pandas as pd
from logging import debug
from scipy.interpolate import RectBivariateSpline

from .utils.get_weights import get_weights
from .....utils.classes.domains import Domain
from .....utils.parallel import thread
from .....utils.datastores.empty import init_empty

try:
    import cPickle as pickle
except ImportError:
    import pickle


def _domain_from_metadata(ds):
    """Build a point-wise ``Domain`` from the ``lon``/``lat`` columns of ``ds``.

    Each point becomes one row of an ``(npoints, 1)`` coordinate array.
    Centre (``zlon``/``zlat``) and corner (``zlonc``/``zlatc``) coordinates
    are set to the same values, and ``nlon``/``nlat`` are both the number of
    points.
    """
    domain = Domain()
    domain.zlon = np.asarray(ds["lon"])[:, np.newaxis]
    domain.zlonc = np.asarray(ds["lon"])[:, np.newaxis]
    domain.zlat = np.asarray(ds["lat"])[:, np.newaxis]
    domain.zlatc = np.asarray(ds["lat"])[:, np.newaxis]
    domain.nlon = len(ds)
    domain.nlat = len(ds)
    return domain


def forward(
        transf,
        inout_datastore,
        controlvect,
        obsvect,
        mapper,
        ddi,
        ddf,
        mode,
        runsubdir,
        workdir,
        onlyinit=False,
        save_debug=False,
        **kwargs
):
    """Regrid every input tracer from its input domain to its output domain.

    Results are written in place into ``inout_datastore["outputs"][trid][ddi]``
    for every tracer listed in ``mapper["inputs"]``.

    Sparse/sampled flags and domains are taken from the first tracer of
    ``inout_datastore["inputs"]`` and applied to all tracers.

    Args:
        transf: transform plugin. Provides ``nthreads``, ``threading_mode``
            (``"trid"`` or ``"levels"``), ``min_weight`` and ``transform_id``.
        inout_datastore: dict with ``"inputs"`` and ``"outputs"`` entries,
            each indexed as ``[trid][ddi]``.
        mapper: transform mapper with per-tracer ``"inputs"``/``"outputs"``
            descriptions (``domain``, ``sparse_data``, ``sampled``,
            ``is_lbc``).
        ddi: current period key into the datastores.
        mode: run mode. With ``"tl"``, increments (``"incr"``) are regridded
            too for gridded data.
        onlyinit: initialization-only run. Sparse configurations return early.
        save_debug: for sparse outputs, append regridding indexes and weights
            as extra debug columns.
        controlvect, obsvect, ddf, runsubdir, workdir, kwargs: unused here;
            kept for the common transform interface.

    Returns:
        None. Returns early, without touching outputs, in some ``onlyinit``
        cases.
    """
    xmod = inout_datastore["inputs"]
    trid_ref = list(xmod)[0]

    # Differentiate full matrices and sparse data
    is_sparse_out = mapper["outputs"][trid_ref].get("sparse_data", False)
    is_sparse_in = mapper["inputs"][trid_ref].get("sparse_data", False)
    is_sampled_out = mapper["outputs"][trid_ref]["sampled"]

    # Inputs domain from the present tracer
    if is_sparse_in:
        if onlyinit:
            return

        # NOTE(review): the input coordinates are read from the *outputs*
        # datastore, not from ``xmod``; confirm this is intended.
        domain_in = _domain_from_metadata(
            inout_datastore["outputs"][trid_ref][ddi]["metadata"])
        nlat_in, nlon_in = domain_in.nlat, domain_in.nlon
    else:
        domain_in = mapper["inputs"][trid_ref]["domain"]
        nlat_in, nlon_in = domain_in.zlon.shape

    # Outputs domain from the mapper
    is_lbc = mapper["outputs"][trid_ref].get("is_lbc", False)

    # Create domain from metadata if sparse data
    if is_sparse_out or is_sampled_out:
        # Do nothing if onlyinit and propagating empty matrix
        if onlyinit and len(xmod[trid_ref][ddi]) == 0:
            return

        domain_out = _domain_from_metadata(xmod[trid_ref][ddi]["metadata"])
    else:
        domain_out = mapper["outputs"][trid_ref]["domain"]

    # NOTE: for gridded data there is deliberately no early return in
    # ``onlyinit`` mode (a former ``if onlyinit: return`` here was disabled),
    # so weights are computed and data regridded during initialization too.

    # Differentiate dimensions if LBC
    if is_lbc:
        nlon_out = domain_out.nlon_side
        nlat_out = domain_out.nlat_side
    else:
        nlat_out, nlon_out = domain_out.zlat.shape

    # Getting weights
    weights = get_weights(
        transf, trid_ref, mapper, domain_in, domain_out, is_lbc, ddi,
        is_sparse_in=is_sparse_in,
        is_sparse_out=is_sparse_out or is_sampled_out
    )

    # Threads run either over tracers or over (time, level) slices
    nthreads = transf.nthreads if transf.threading_mode == "trid" else 1
    nthreads_levels = \
        transf.nthreads if transf.threading_mode == "levels" else 1
    list_trids = list(mapper["inputs"])
    thread_intervals = np.linspace(
        0, len(list_trids), nthreads + 1).astype(int)

    # Initialize output dictionary before running threads
    # This is to avoid overwriting data
    for trid in list_trids:
        inout_datastore["outputs"][trid] = {ddi: {}}

    is_gridded = not (is_sparse_in or is_sparse_out or is_sampled_out)

    @thread
    def thread_function(ithread):
        start, stop = thread_intervals[ithread], thread_intervals[ithread + 1]
        for trid in list_trids[start:stop]:
            out = inout_datastore["outputs"][trid][ddi]

            if not is_gridded:
                inout_datastore["outputs"][trid][ddi] = do_regridding(
                    out, xmod[trid][ddi],
                    nlat_in, nlon_in, nlat_out, nlon_out, weights,
                    min_weight=transf.min_weight,
                    is_sparse_in=is_sparse_in,
                    is_sparse_out=is_sparse_out or is_sampled_out,
                    save_debug=save_debug,
                    transf=transf
                )
                continue

            # Gridded data: regrid a copy of the concentrations
            out["spec"] = do_regridding(
                out, copy.deepcopy(xmod[trid][ddi]["spec"]),
                nlat_in, nlon_in, nlat_out, nlon_out, weights,
                min_weight=transf.min_weight,
                nthreads=nthreads_levels
            )

            if mode != "tl":
                continue

            # Tangent-linear: regrid increments, or zero them if absent
            if "incr" in xmod[trid][ddi]:
                out["incr"] = do_regridding(
                    out, copy.deepcopy(xmod[trid][ddi]["incr"]),
                    nlat_in, nlon_in, nlat_out, nlon_out, weights,
                    min_weight=transf.min_weight,
                    nthreads=nthreads_levels
                )
            else:
                out["incr"] = 0. * out["spec"]

    thread_function(range(nthreads))
def _regrid_gridded(data, nlat_out, nlon_out, weights, min_weight, nthreads):
    """Apply regridding weights to a gridded ``(time, lev, lat, lon)`` array.

    Returns an ``xr.DataArray`` with dims ``("time", "lev", "lat", "lon")``
    and the input ``time`` coordinate.
    """
    nlev = len(data.lev)
    ntimes = len(data.time)
    var_out = np.zeros(
        (ntimes, nlev, nlat_out, nlon_out), dtype=data.dtype)
    var_in = data.values

    # Zero out negligible weights.
    # NOTE: this modifies ``weights["wgt"]`` in place, so the caller's
    # (possibly shared) weights are changed too. The operation is idempotent
    # for a given ``min_weight``.
    mask = weights["wgt"] > min_weight
    weights["wgt"][~mask] = 0

    # (lat, lon) index of every output cell, in Fortran order
    iout, jout = np.unravel_index(
        range(nlat_out * nlon_out), (nlat_out, nlon_out), order="F")
    if "filtered" in weights:
        iout = iout[weights["filtered"]]
        jout = jout[weights["filtered"]]

    i = weights["i"].astype(int)
    j = weights["j"].astype(int)
    wgt = weights["wgt"]

    # Split the (time, level) slices between threads
    list_time_levels = list(itertools.product(range(ntimes), range(nlev)))
    thread_intervals = np.linspace(
        0, len(list_time_levels), nthreads + 1).astype(int)

    @thread
    def thread_function(ithread):
        start, stop = thread_intervals[ithread], thread_intervals[ithread + 1]
        for time, level in list_time_levels[start:stop]:
            # Each output cell is the weighted sum of its source cells
            var_out[time, level, iout, jout] = np.nansum(
                wgt * var_in[time, level, i, j], axis=-1)

    thread_function(range(nthreads))

    return xr.DataArray(
        var_out,
        coords={"time": data.time.values},
        dims=("time", "lev", "lat", "lon")
    )


def _regrid_to_sparse(data, mask, non_filtered, i, j, wgts, save_debug,
                      transf):
    """Accumulate weighted ``incr``/``spec`` values onto sparse output points.

    Rows listed in ``non_filtered`` (outside the domain) are set to NaN.
    With ``save_debug``, the indexes and weights used for each point are
    appended as extra columns under ``transf.transform_id``.
    """
    datastore_out = init_empty(nlines=len(mask) + len(non_filtered))

    # Every output point receives ``ncorners`` weighted contributions
    ncorners = np.shape(i)[1]
    out_indexes = np.repeat(np.arange(len(i)), ncorners)
    flat_wgts = wgts.flatten()

    for c in ["incr", "spec"]:
        col = ("maindata", c)
        datastore_out[col] = 0.
        icol = datastore_out.columns.get_loc(col)

        data_out = datastore_out[col].iloc[mask].to_numpy()
        # Unbuffered add so that repeated output indexes accumulate
        np.add.at(data_out, out_indexes, data[col].values * flat_wgts)
        datastore_out.iloc[mask, icol] = data_out

        # Putting NaNs in data outside the domain
        datastore_out.iloc[non_filtered, icol] = np.nan

    if save_debug:
        tid = transf.transform_id
        prefixes = (("i", i), ("j", j), ("weight", wgts))
        labels = ["{}_{:02d}".format(prefix, k)
                  for prefix, _ in prefixes for k in range(ncorners)]
        debug_cols = pd.MultiIndex.from_product([[tid], labels])
        df_debug = pd.DataFrame(index=datastore_out.index, columns=debug_cols)

        for prefix, values in prefixes:
            for k in range(ncorners):
                icol = df_debug.columns.get_loc(
                    (tid, "{}_{:02d}".format(prefix, k)))
                df_debug.iloc[mask, icol] = values[:, k]

        datastore_out = pd.concat([datastore_out, df_debug], axis=1)

    return datastore_out


def _regrid_from_sparse(data, mask, i, j, wgts):
    """Spread sparse points onto grid cells and aggregate per cell.

    Each selected row of ``data`` is duplicated once per target cell, and its
    ``incr``/``spec`` values are scaled by the corresponding weight. Rows are
    then grouped by ``(i, j, alt, date)``. Data columns are summed; every
    other column keeps its group maximum.
    """
    ncorners = np.shape(i)[1]
    out_indexes = np.repeat(np.arange(len(i)), ncorners)
    flat_wgts = wgts.flatten()
    data_cols = [("maindata", "incr"), ("maindata", "spec")]

    selected = data.iloc[mask]
    datastore_out = copy.deepcopy(selected).iloc[out_indexes]
    for col in data_cols:
        datastore_out[col] = 0.
        data_in = selected[col].to_numpy()
        datastore_out.loc[:, col] = data_in[out_indexes] * flat_wgts

    datastore_out.loc[:, ("metadata", "i")] = i.flatten()
    datastore_out.loc[:, ("metadata", "j")] = j.flatten()

    # Group by i/j/alt/dates
    group = datastore_out.groupby(
        [("metadata", "i"), ("metadata", "j"),
         ("metadata", "alt"), ("metadata", "date")]
    )
    datastore_out = group.max().reset_index()

    # Sum only the data columns, once. Summing every column would fail on
    # non-summable metadata (e.g. datetimes) under pandas >= 2.
    # Group order matches ``group.max()`` above.
    summed = group[data_cols].sum()
    for col in data_cols:
        datastore_out[col] = summed[col].to_numpy()

    return datastore_out


def do_regridding(
        datastore_out,
        data,
        nlat_in,
        nlon_in,
        nlat_out,
        nlon_out,
        weights,
        min_weight=1e-10,
        is_sparse_in=False,
        is_sparse_out=False,
        save_debug=False,
        transf=None,
        nthreads=1
):
    """Apply precomputed regridding weights to gridded or sparse data.

    Args:
        datastore_out: current output entry. It is returned unchanged when
            sparse regridding has no points to process.
        data: input data. For gridded data, an ``xr.DataArray`` with ``time``
            and ``lev`` dimensions. For sparse data, a DataFrame with
            ``("maindata", ...)`` and ``("metadata", ...)`` columns.
        nlat_in, nlon_in: input grid shape (unused here; kept for interface).
        nlat_out, nlon_out: output grid shape (gridded data only).
        weights: mapping with ``"i"``, ``"j"``, ``"wgt"`` arrays, and
            optionally ``"filtered"`` / ``"non_filtered"`` index arrays.
        min_weight: gridded data only; weights not above this are zeroed
            (in place in ``weights["wgt"]``).
        is_sparse_in, is_sparse_out: select the sparse code paths.
        save_debug: sparse outputs only; append debug index/weight columns.
        transf: transform plugin; its ``transform_id`` labels debug columns.
        nthreads: gridded data only; threads over (time, level) slices.

    Returns:
        The regridded ``xr.DataArray`` (gridded) or DataFrame (sparse).
    """
    if not (is_sparse_in or is_sparse_out):
        return _regrid_gridded(
            data, nlat_out, nlon_out, weights, min_weight, nthreads)

    i = weights["i"].astype(int)
    j = weights["j"].astype(int)
    wgts = weights["wgt"]

    # Mask filtered if exists
    if "filtered" in weights:
        mask = np.asarray(weights["filtered"])
    else:
        mask = np.arange(len(data))

    non_filtered = []
    if "non_filtered" in weights:
        non_filtered = weights["non_filtered"]

    # Return if mask is empty
    if mask.size == 0:
        debug(
            "Skipping regridding for sparse data as no data is provided."
        )
        return datastore_out

    if is_sparse_out:
        return _regrid_to_sparse(
            data, mask, non_filtered, i, j, wgts, save_debug, transf)

    return _regrid_from_sparse(data, mask, i, j, wgts)