from __future__ import annotations
import datetime
import itertools
import logging
import os
import re
from os import PathLike
from pathlib import Path
from typing import Any, Literal, overload
import cfgrib
import numpy as np
[docs]
class GribDataset:
    """Wrapper around cfgrib to read a GRIB file and access its variables and
    attributes.

    Parameters
    ----------
    path : str | path-like
        Path to the GRIB file
    read_keys : list of str, optional
        Additional keys to read, by default None
    filter_by_keys : dict of str, optional
        Passed to the cfgrib.open_file function, by default None

    Raises
    ------
    FileNotFoundError
        If the file does not exist

    Examples
    --------
    >>> ds = GribDataset(
    ...     "path/to/file.grib",
    ...     read_keys=["isOctahedral"],
    ...     filter_by_keys={"edition": 1},
    ... )
    >>> ds["latitude"]  # get a variable
    array([...])
    >>> ds.get_attr("isOctahedral")  # get an attribute
    1
    """

    def __init__(
        self,
        path: str | PathLike[str],
        read_keys: list[str] | None = None,
        filter_by_keys: dict[str, Any] | None = None,
    ) -> None:
        self.path = Path(path)
        read_keys = [] if read_keys is None else read_keys
        filter_by_keys = {} if filter_by_keys is None else filter_by_keys

        if not self.path.is_file():
            raise FileNotFoundError(f"{self.path} was not found")

        # Temporarily raise the root logger threshold to avoid warning messages
        # while the file is opened. The threshold is forced above CRITICAL so
        # that it is effective even when the current level is NOTSET (0), and
        # the previous level is always restored, including when opening fails.
        logger = logging.getLogger()
        log_level = logger.level
        logger.setLevel(max(10 * log_level, logging.CRITICAL + 1))
        try:
            self._ds = cfgrib.open_file(
                self.path,
                indexpath="",
                read_keys=read_keys,
                filter_by_keys=filter_by_keys,
            )
        finally:
            logger.setLevel(log_level)

    @property
    def variables(self) -> dict[str, cfgrib.dataset.Variable]:
        """Returns the variables mapping of the GRIB file"""
        return self._ds.variables

    @overload
    def __getitem__(self, key: str) -> np.ndarray: ...

    @overload
    def __getitem__(self, key: list[str]) -> list[np.ndarray]: ...

    def __getitem__(self, key: str | list[str]) -> np.ndarray | list[np.ndarray]:
        """Get one or several variables from the GRIB file

        Parameters
        ----------
        key : str or list of str
            Variable name or list of variable names to get

        Returns
        -------
        array or list of arrays
            Variable data as Numpy array(s)

        Raises
        ------
        KeyError
            If one of the requested variables is not found in the GRIB file
        """
        if isinstance(key, str):
            return self.get_var(key)
        return [self.get_var(varname) for varname in key]

    def get_var(self, varname: str) -> np.ndarray:
        """Get one variable from the GRIB file as a Numpy array

        Parameters
        ----------
        varname : str
            Variable name

        Returns
        -------
        np.ndarray
            Variable data as Numpy array

        Raises
        ------
        KeyError
            If the variable is not found in the GRIB file
        """
        if varname not in self.variables:
            raise KeyError(
                f"The variable {varname} is not available in the file {self.path}. "
                f"Available variables: {list(self.variables.keys())}"
            )
        data = self.variables[varname].data
        # Some variables expose their data through an object providing
        # `build_array`; materialise it when that method is present.
        if hasattr(data, "build_array"):
            data = data.build_array()  # type: ignore
        return data

    def get_attr(self, attr: str, varname: str | None = None) -> Any:
        """Get an attribute from the GRIB file. If varname is provided, looks for the
        attribute in the variable attributes, otherwise looks for it in all variables
        and returns the first one found.

        Parameters
        ----------
        attr : str
            Attribute name to get (will be prefixed by "GRIB_" to match cfgrib attributes)
        varname : str, optional
            variable name to restrict the search, by default None

        Returns
        -------
        Any
            Attribute value

        Raises
        ------
        KeyError
            If the attribute is not found, or if `varname` is given but is not
            a variable of the file.
        """
        grib_attr = f"GRIB_{attr}"

        if varname is not None:
            if varname not in self.variables:
                raise KeyError(
                    f"The variable {varname} is not available in the file {self.path}. "
                    f"Available variables: {list(self.variables.keys())}"
                )
            attributes = self.variables[varname].attributes
            if grib_attr not in attributes:
                raise KeyError(
                    f"Could not find attribute {attr} in variable {varname} of file {self.path}"
                )
            return attributes[grib_attr]

        # No variable requested: return the first match in iteration order.
        for variable in self.variables.values():
            if grib_attr in variable.attributes:
                return variable.attributes[grib_attr]
        raise KeyError(f"Could not find attribute {attr} in file {self.path}")
[docs]
def find_valid_file(
    file_format: str,
    dd: datetime.datetime,
    time_freq: datetime.timedelta,
    ref_dir: str | PathLike[str],
    ref_dir_next: str | PathLike[str],
    ref_dir_previous: str | PathLike[str],
    delta_tolerance: float = 1,
    cumul_variable: bool = False,
    cumul_length: float = 12,
) -> tuple[list[str], list[datetime.datetime]]:
    """Find the file(s) whose name encodes the date closest to a requested date.

    File names are matched against `file_format`, in which "%Y", "%m", "%d",
    "%H" and "%M" capture date fields, "*" matches one or two digits, and
    "{di}" captures a one or two digit number of hours that is added to the
    date read from the name. The hour field is added as an offset, so a "24"
    time stamp is accepted and rolls over to the next day.

    Parameters
    ----------
    file_format : str
        File name pattern, as described above
    dd : datetime.datetime
        Requested date
    time_freq : datetime.timedelta
        Time frequency of the files, used to bound the accepted distance
        between the requested date and the date of the selected files
    ref_dir : str | path-like
        Directory scanned for candidate files
    ref_dir_next : str | path-like
        Directory also scanned when the requested date plus `cumul_length`
        hours falls in a later month
    ref_dir_previous : str | path-like
        Directory also scanned when the requested date minus `cumul_length`
        hours falls in an earlier month
    delta_tolerance : float, optional
        Multiplier of `time_freq` giving the maximum accepted distance,
        by default 1
    cumul_variable : bool, optional
        If True, also look for the closest file strictly after the requested
        date, by default False
    cumul_length : float, optional
        Number of hours used to decide whether the neighbouring directories
        are scanned, by default 12

    Returns
    -------
    list of str, list of datetime.datetime
        Paths of the selected files and their corresponding dates. Both lists
        contain the closest file at or before the requested date and, for a
        cumulated variable with a later file available, the closest file after it.

    Raises
    ------
    FileNotFoundError
        If no file matches the format, if no file is at or before the
        requested date, or if the closest files are too far from it
    ValueError
        If, for a cumulated variable, the closest later date has no file with
        the same "{di}" value as the file selected before the requested date
    """
    ref_dir = Path(ref_dir)
    ref_dir_next = Path(ref_dir_next)
    ref_dir_previous = Path(ref_dir_previous)

    # Round the requested date to the resolution encoded in the file names.
    ref_date = datetime.datetime.strptime(dd.strftime(file_format), file_format)
    ref_month = (ref_date.year, ref_date.month)

    # Neighbouring directories are only needed when the cumulation window
    # crosses a month boundary. Comparing (year, month) pairs also handles the
    # December/January rollover in both directions.
    candidates = list(ref_dir.iterdir())
    previous_date = ref_date - datetime.timedelta(hours=cumul_length)
    if (previous_date.year, previous_date.month) < ref_month:
        candidates += list(ref_dir_previous.iterdir())
    next_date = ref_date + datetime.timedelta(hours=cumul_length)
    if (next_date.year, next_date.month) > ref_month:
        candidates += list(ref_dir_next.iterdir())

    # Dots are swapped for "/" on both the pattern and the file names so that
    # they match literally ("/" cannot appear in a file name).
    name_regex = re.compile(
        file_format.replace(".", "/")
        .replace("%Y", r"(\d{4})")
        .replace("%m", r"(\d{2})")
        .replace("%d", r"(\d{2})")
        .replace("%H", r"(\d{1,2})")
        .replace("%M", r"(\d{2})")
        .replace("*", r"\d{1,2}")
        .replace("{di}", r"(\d{1,2})")
    )
    directives = ("%Y", "%m", "%d", "%H", "%M", "{di}")
    present = [directive for directive in directives if directive in file_format]
    # Regex groups come out in the order the directives appear in the format,
    # whereas the date is rebuilt in the canonical order of `directives`.
    capture_order = sorted(present, key=file_format.find)
    date_directives = [directive for directive in present if directive != "{di}"]
    date_format = "".join(date_directives)

    # One (date, run hour, {di} hours, path) tuple per matching file.
    records = []
    for path in candidates:
        name = path.name
        # Ignore index files generated by xarray and cfgrib
        if "idx" in name:
            continue
        match = name_regex.fullmatch(name.replace(".", "/"))
        if match is None:
            continue
        captured = dict(zip(capture_order, match.groups()))
        lead_hours = int(captured.pop("{di}", 0))
        # Deal with time stamps at 24h: parse the date at 00h and add the
        # hours afterwards.
        hour_shift = 0
        if "%H" in captured:
            hour_shift = int(captured["%H"])
            captured["%H"] = "00"
        file_date = datetime.datetime.strptime(
            "".join(captured[directive] for directive in date_directives),
            date_format,
        )
        file_date += datetime.timedelta(hours=lead_hours + hour_shift)
        run_hour = (file_date - datetime.timedelta(hours=lead_hours)).hour
        records.append((file_date, run_hour, lead_hours, str(path)))

    if not records:
        raise FileNotFoundError(
            f"Did not find any valid GRIB files in {ref_dir} "
            f"with format {file_format}. Please check your yml file"
        )
    records.sort(key=lambda record: record[0])

    # Find nearest previous date
    past_dates = [record[0] for record in records if record[0] <= ref_date]
    if not past_dates:
        raise FileNotFoundError(
            f"No file has valid date for {ref_date} in {ref_dir} "
            f"with format {file_format}. Please check your yml file. \n"
            f"The range of dates covered by files is: {records[0][0]} / {records[-1][0]}"
        )
    closest_past = max(past_dates)
    # Among the files sharing that date, keep the one with the smallest run hour.
    date_ref1, _, lead_hours_ref1, file_ref1 = min(
        (record for record in records if record[0] == closest_past),
        key=lambda record: record[1],
    )

    # Check that date_ref1 is not too far away from ref_date
    delta = date_ref1 - ref_date
    if abs(delta) > delta_tolerance * time_freq:
        raise FileNotFoundError(
            f"Could not find files close enough to the expected date (date_ref1):\n"
            f"\t- requested date: {dd.isoformat()}\n"
            f"\t- date using file formating: {ref_date.isoformat()}\n"
            f"\t- closest valid date: {date_ref1.isoformat()}\n"
            f"\t- timedelta from expected date: {delta!s}\n"
            f"\t- timedelta from file frequency: {time_freq!s}\n"
        )

    # Reconvert to original date
    dd1 = dd + (date_ref1 - ref_date)

    # If not cumulative variable, just return instantaneous snapshot
    if not cumul_variable:
        return [file_ref1], [dd1]

    # Now find nearest next date
    future_dates = [record[0] for record in records if record[0] > ref_date]
    if not future_dates:
        return [file_ref1], [dd1]
    closest_future = min(future_dates)
    # For a cumulated variable, both sides of the interval must come from
    # files with the same {di} value.
    following = [
        record
        for record in records
        if record[0] == closest_future and record[2] == lead_hours_ref1
    ]
    if not following:
        raise ValueError(
            f"Could not decumulate variable in files {file_format} "
            f"for date {dd} because could not guarantee that the same "
            "forecast hour is available for the two sides of the interval"
        )
    date_ref2, _, _, file_ref2 = following[0]

    # Check that date_ref2 is not too far away from ref_date
    delta = date_ref2 - ref_date
    if abs(delta) > delta_tolerance * time_freq:
        raise FileNotFoundError(
            f"Could not find files close enough to the expected date (date_ref2):\n"
            f"\t- requested date: {dd.isoformat()}\n"
            f"\t- date using file formating: {ref_date.isoformat()}\n"
            f"\t- closest valid date: {date_ref2.isoformat()}\n"
            f"\t- timedelta from expected date: {delta!s}\n"
            f"\t- timedelta from file frequency: {time_freq!s}\n"
            f"\t- file format: {file_format}\n"
        )

    dd2 = dd + (date_ref2 - ref_date)
    return [file_ref1, file_ref2], [dd1, dd2]
[docs]
def get_grid_type(
    domain_file: str | PathLike[str],
    filter_by_keys_dict: dict[str, Any] | None = None,
) -> Literal["regular", "octahedral", "reduced_gaussian"]:
    """Determine the kind of grid stored in a GRIB file.

    Parameters
    ----------
    domain_file : str | path-like
        Path to the GRIB file
    filter_by_keys_dict : dict of str, optional
        Passed as `filter_by_keys` to GribDataset, by default None

    Returns
    -------
    str
        "octahedral" if the "isOctahedral" attribute is found and equals 1.
        Otherwise "regular" if no two consecutive values of the "latitude"
        variable are equal, else "reduced_gaussian".
    """
    dataset = GribDataset(
        domain_file,
        read_keys=["isOctahedral"],
        filter_by_keys=filter_by_keys_dict,
    )

    # A missing attribute is treated the same way as a non-octahedral grid.
    try:
        octahedral = dataset.get_attr("isOctahedral") == 1
    except KeyError:
        octahedral = False
    if octahedral:
        return "octahedral"

    # Repeated consecutive latitudes mean the grid is not a regular one.
    latitudes = dataset["latitude"]
    if np.any(np.diff(latitudes) == 0):
        return "reduced_gaussian"
    return "regular"
[docs]
def get_jscan(
    domain_file: str | PathLike[str],
    filter_by_keys_dict: dict[str, Any] | None = None,
) -> int:
    """Read the "jScansPositively" attribute of a GRIB file.

    Parameters
    ----------
    domain_file : str | path-like
        Path to the GRIB file
    filter_by_keys_dict : dict of str, optional
        Passed as `filter_by_keys` to GribDataset, by default None

    Returns
    -------
    int
        Value of the "jScansPositively" attribute

    Raises
    ------
    ValueError
        If the attribute is not found in the file
    """
    dataset = GribDataset(
        domain_file,
        read_keys=["jScansPositively"],
        filter_by_keys=filter_by_keys_dict,
    )
    try:
        return dataset.get_attr("jScansPositively")
    except KeyError as err:
        # Point the user to the configuration override rather than the raw lookup error.
        message = (
            "When 'jScansPositively' is not defined, it is possible to force it "
            "with the argument 'jScansPositively' of your Yml"
        )
        raise ValueError(message) from err