Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions test/io/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import numpy as np
import pytest
import xarray as xr

from uxarray.io.utils import _parse_grid_type


@pytest.mark.parametrize(
("path_args", "expected_spec"),
[
(("exodus", "outCSne8", "outCSne8.g"), "Exodus"),
(("scrip", "outCSne8", "outCSne8.nc"), "Scrip"),
(("ugrid", "outCSne30", "outCSne30.ug"), "UGRID"),
(("mpas", "QU", "mesh.QU.1920km.151026.nc"), "MPAS"),
(("esmf", "ne30", "ne30pg3.grid.nc"), "ESMF"),
(("geos-cs", "c12", "test-c12.native.nc4"), "GEOS-CS"),
(("icon", "R02B04", "icon_grid_0010_R02B04_G.nc"), "ICON"),
(("fesom", "soufflet-netcdf", "grid.nc"), "FESOM2"),
],
)
def test_parse_grid_type_detects_supported_formats(gridpath, path_args, expected_spec):
with xr.open_dataset(gridpath(*path_args)) as ds:
source_grid_spec, lon_name, lat_name = _parse_grid_type(ds)

assert source_grid_spec == expected_spec
assert lon_name is None
assert lat_name is None


def test_parse_grid_type_detects_structured_grid():
lon = xr.DataArray(
np.array([0.0, 1.0, 2.0]),
dims=["lon"],
attrs={"standard_name": "longitude"},
)
lat = xr.DataArray(
np.array([-1.0, 0.0, 1.0]),
dims=["lat"],
attrs={"standard_name": "latitude"},
)
ds = xr.Dataset(coords={"lon": lon, "lat": lat})

source_grid_spec, lon_name, lat_name = _parse_grid_type(ds)

assert source_grid_spec == "Structured"
assert lon_name == "lon"
assert lat_name == "lat"


@pytest.mark.parametrize(
"dataset",
[
xr.Dataset({"grid_center_lon": xr.DataArray([0.0], dims=["grid_size"])}),
xr.Dataset(
{
"coordx": xr.DataArray([0.0, 1.0], dims=["num_nodes"]),
"coordy": xr.DataArray([0.0, 1.0], dims=["num_nodes"]),
}
),
xr.Dataset({"verticesOnCell": xr.DataArray([[1, 2, 3]], dims=["nCells", "nVert"])}),
],
)
def test_parse_grid_type_rejects_incomplete_format_signals(dataset):
with pytest.raises(RuntimeError, match="Could not recognize dataset format"):
_parse_grid_type(dataset)
88 changes: 75 additions & 13 deletions uxarray/io/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,74 @@
from uxarray.io._ugrid import _is_ugrid, _read_ugrid


def _is_exodus(dataset: xr.Dataset) -> bool:
"""Check whether a dataset looks like an Exodus mesh."""
has_packed_coords = "coord" in dataset
has_split_coords = {"coordx", "coordy"}.issubset(dataset.variables)
has_connectivity = any(
name.startswith("connect") for name in dataset.variables
) or any("num_nod_per_el" in dim for dim in dataset.dims)

return has_connectivity and (has_packed_coords or has_split_coords)


def _is_scrip(dataset: xr.Dataset) -> bool:
"""Check whether a dataset looks like an unstructured SCRIP grid."""
required_vars = {
"grid_center_lon",
"grid_center_lat",
"grid_corner_lon",
"grid_corner_lat",
}
unstructured_markers = {"grid_imask", "grid_rank", "grid_area"}

return required_vars.issubset(dataset.variables) and any(
marker in dataset for marker in unstructured_markers
)


def _is_mpas(dataset: xr.Dataset) -> bool:
"""Check whether a dataset looks like an MPAS grid."""
if "verticesOnCell" not in dataset:
return False

companion_groups = (
{"nEdgesOnCell"},
{"latCell", "lonCell"},
{"latVertex", "lonVertex"},
{"xCell", "yCell", "zCell"},
{"xVertex", "yVertex", "zVertex"},
)

return any(group.issubset(dataset.variables) for group in companion_groups)


def _is_esmf(dataset: xr.Dataset) -> bool:
"""Check whether a dataset looks like an ESMF mesh."""
return "maxNodePElement" in dataset.dims and "elementConn" in dataset


def _is_geos_cs(dataset: xr.Dataset) -> bool:
"""Check whether a dataset looks like a GEOS cube-sphere grid."""
required_dims = {"nf", "YCdim", "XCdim"}
required_vars = {"corner_lons", "corner_lats"}

return required_dims.issubset(dataset.sizes) and required_vars.issubset(
dataset.variables
)


def _is_icon(dataset: xr.Dataset) -> bool:
"""Check whether a dataset looks like an ICON grid."""
required_vars = {"vertex_of_cell", "clon", "clat", "vlon", "vlat"}
return required_vars.issubset(dataset.variables)


def _is_fesom2(dataset: xr.Dataset) -> bool:
"""Check whether a dataset looks like a FESOM2 grid."""
return "triag_nodes" in dataset


def _parse_grid_type(dataset):
"""Checks input and contents to determine grid type. Supports detection of
UGrid, SCRIP, Exodus, ESMF, and shape file.
Expand All @@ -31,27 +99,21 @@ def _parse_grid_type(dataset):

_structured, lon_name, lat_name = _is_structured(dataset)

if "coord" in dataset:
# exodus with coord or coordx
mesh_type = "Exodus"
elif "coordx" in dataset:
if _is_exodus(dataset):
mesh_type = "Exodus"
elif "grid_center_lon" in dataset:
# scrip with grid_center_lon
elif _is_scrip(dataset):
mesh_type = "Scrip"
elif _is_ugrid(dataset):
# ugrid topology is present
mesh_type = "UGRID"
elif "verticesOnCell" in dataset:
elif _is_mpas(dataset):
mesh_type = "MPAS"
elif "maxNodePElement" in dataset.dims:
elif _is_esmf(dataset):
mesh_type = "ESMF"
elif all(key in dataset.sizes for key in ["nf", "YCdim", "XCdim"]):
# expected dimensions for a GEOS cube sphere grid
elif _is_geos_cs(dataset):
mesh_type = "GEOS-CS"
elif "vertex_of_cell" in dataset:
elif _is_icon(dataset):
mesh_type = "ICON"
elif "triag_nodes" in dataset:
elif _is_fesom2(dataset):
mesh_type = "FESOM2"
elif _structured:
mesh_type = "Structured"
Expand Down
Loading