from __future__ import annotations
import numpy as np
import xarray as xr
from ....utils.classes.setup import Setup
[docs]
def squeeze_vertical_dim(self) -> object:
    """Build a new domain plugin mirroring this one on a single vertical level

    A fresh "domain" plugin is loaded through ``Setup.load_registered`` with
    this plugin's name and version, passing ``self`` as ``plg_orig``. The
    horizontal grid attributes and the pressure unit are then assigned from
    ``self`` (same objects, no copies), and the vertical description is
    overwritten with fixed single-level values.

    Returns
    -------
    object
        Newly loaded domain plugin with ``nlev`` set to 1
    """
    squeezed = Setup.load_registered(
        self.plugin.name, self.plugin.version, "domain", plg_orig=self
    )

    # Attributes taken over as-is from the current domain; the order below
    # is the order in which they are read and assigned.
    carried_over = (
        "unstructured_domain",
        "nlon",
        "nlat",
        "zlon",
        "zlat",
        "zlonc",
        "zlatc",
        "nlon_side",
        "nlat_side",
        "zlonc_side",
        "zlatc_side",
        "zlon_side",
        "zlat_side",
        "pressure_unit",
    )
    for attr_name in carried_over:
        setattr(squeezed, attr_name, getattr(self, attr_name))

    # Vertical description of the squeezed domain: two entries for
    # sigma_a / sigma_b and a single entry for the *_mid arrays.
    squeezed.nlev = 1
    squeezed.sigma_a = np.array([0, 0])
    squeezed.sigma_b = np.array([1, 1])
    squeezed.sigma_a_mid = np.array([0])
    squeezed.sigma_b_mid = np.array([1])

    return squeezed
[docs]
def get_time_splits(self) -> tuple[int, int]:
    """Look up the time step splits matching the horizontal grid size

    The pair ``(self.nlon, self.nlat)`` is compared against a fixed list of
    supported grid sizes; the splits of the first matching entry are returned.

    Returns
    -------
    tuple[int, int]
        Dynamics time step split, physics time step split

    Raises
    ------
    ValueError
        If ``(nlon, nlat)`` matches none of the supported grid sizes
    """
    grid_size = (self.nlon, self.nlat)

    # Each entry: ((nlon, nlat), (dynamics split, physics split))
    supported_grids = (
        ((16, 10), (12, 6)),  # Academic grid used in tests
        ((96, 96), (12, 6)),
        ((144, 143), (18, 9)),
        ((256, 257), (30, 15)),
    )
    for candidate, splits in supported_grids:
        if grid_size == candidate:
            return splits

    raise ValueError(f"Invalid shape of domain: {grid_size}")
[docs]
def get_domain_coords(self, vertical: bool = False) -> dict[str, xr.DataArray]:
    """Assemble the domain coordinates as DataArrays for writing NetCDF files

    Latitudes are read from the first column of ``self.zlat`` and longitudes
    from the first row of ``self.zlon``. The matching bounds are built by
    pairing consecutive entries of the first column of ``self.zlatc`` and of
    the first row of ``self.zlonc``.

    Parameters
    ----------
    vertical : bool, optional
        Also return the ``ap`` and ``bp`` entries built from ``self.sigma_a``
        and ``self.sigma_b``, by default False

    Returns
    -------
    dict[str, DataArray]
        Coordinates keyed by ``lat``, ``lon``, ``lat_bnds``, ``lon_bnds``
        and, when ``vertical`` is set, ``ap`` and ``bp``
    """

    def _pair_consecutive(edges):
        # Turn a 1-D sequence of edges into a 2-column array whose row i
        # holds (edges[i], edges[i + 1]).
        return np.concatenate(
            [edges[:-1, np.newaxis], edges[1:, np.newaxis]], axis=1
        )

    latitudes = self.zlat[:, 0]
    latitude_bounds = _pair_consecutive(self.zlatc[:, 0])
    longitudes = self.zlon[0, :]
    longitude_bounds = _pair_consecutive(self.zlonc[0, :])

    lat_attrs = {
        "standard_name": "latitude",
        "long_name": "latitude",
        "units": "degrees_north",
        "axis": "Y",
        "bounds": "lat_bnds",
    }
    lon_attrs = {
        "standard_name": "longitude",
        "long_name": "longitude",
        "units": "degrees_east",
        "axis": "X",
        "bounds": "lon_bnds",
    }
    lat_bnds_attrs = {
        "standard_name": "latitude_bounds",
        "long_name": "latitude bounds",
        "units": "degrees_north",
    }
    lon_bnds_attrs = {
        "standard_name": "longitude_bounds",
        "long_name": "longitude bounds",
        "units": "degrees_east",
    }

    # Entries are inserted in the order lat, lon, lat_bnds, lon_bnds
    coords = {}
    coords["lat"] = xr.DataArray(latitudes, dims=["lat"], attrs=lat_attrs)
    coords["lon"] = xr.DataArray(longitudes, dims=["lon"], attrs=lon_attrs)
    coords["lat_bnds"] = xr.DataArray(
        latitude_bounds, dims=["lat", "bnds"], attrs=lat_bnds_attrs
    )
    coords["lon_bnds"] = xr.DataArray(
        longitude_bounds, dims=["lon", "bnds"], attrs=lon_bnds_attrs
    )

    if not vertical:
        return coords

    # Vertical coefficients, both placed on the "lev" dimension
    ap_attrs = {
        "standard_name": "atmosphere_hybrid_sigma_pressure_coordinate_ap",
        "long_name": "Ap coefficient at layer interface",
        "units": "Pa",
    }
    bp_attrs = {
        "standard_name": "atmosphere_hybrid_sigma_pressure_coordinate_bp",
        "long_name": "B coefficient at layer interface",
        "units": "1",
    }
    coords["ap"] = xr.DataArray(self.sigma_a, dims=["lev"], attrs=ap_attrs)
    coords["bp"] = xr.DataArray(self.sigma_b, dims=["lev"], attrs=bp_attrs)
    return coords