diff --git a/test/core/test_api.py b/test/core/test_api.py
index 150c2877c..090be9ebf 100644
--- a/test/core/test_api.py
+++ b/test/core/test_api.py
@@ -39,6 +39,53 @@ def test_open_dataset(gridpath, datasetpath, mesh_constants):
     nt.assert_equal(len(uxds_var2_ne30.uxgrid._ds.data_vars), mesh_constants['DATAVARS_outCSne30'])
     nt.assert_equal(uxds_var2_ne30.source_datasets, str(data_path))
 
+
+def test_open_dataset_single_combined_mpas_file(gridpath):
+    """Loads a combined MPAS grid-and-data file with a single argument."""
+
+    # Use a known combined grid-and-data MPAS file.
+    file_path = gridpath("mpas", "QU", "oQU480.231010.nc")
+
+    uxds_single = ux.open_dataset(file_path)
+    uxds_pair = ux.open_dataset(file_path, file_path)
+
+    # Ensure that the single-argument path actually loads data variables
+    assert len(uxds_single.data_vars) > 0
+    nt.assert_equal(uxds_single.uxgrid.source_grid_spec, "MPAS")
+    nt.assert_equal(uxds_single.source_datasets, str(file_path))
+    nt.assert_equal(uxds_single.sizes["n_face"], uxds_pair.sizes["n_face"])
+    nt.assert_equal(set(uxds_single.data_vars), set(uxds_pair.data_vars))
+    assert "ssh" in uxds_single.data_vars
+
+
+def test_open_dataset_single_combined_xarray_dataset(gridpath):
+    """Loads a combined MPAS grid-and-data xarray.Dataset with a single argument."""
+
+    file_path = gridpath("mpas", "QU", "oQU480.231010.nc")
+
+    with xr.open_dataset(file_path) as ds:
+        uxds = ux.open_dataset(ds)
+
+    nt.assert_equal(uxds.uxgrid.source_grid_spec, "MPAS")
+    nt.assert_equal(uxds.source_datasets, None)
+    assert "ssh" in uxds.data_vars
+
+
+def test_open_dataset_single_argument_rejects_directory_grid(tmp_path):
+    """Requires a separate data file for directory-based grids."""
+
+    with pytest.raises(ValueError, match="Directory-based grids require a separate data file"):
+        ux.open_dataset(tmp_path)
+
+
+def test_open_dataset_single_argument_rejects_invalid_combined_file(datasetpath):
+    """Rejects one-file inputs that do not contain recognizable grid metadata."""
+
+    data_path = datasetpath("ugrid", "outCSne30", "outCSne30_var2.nc")
+
+    with pytest.raises(RuntimeError, match="Could not recognize dataset format"):
+        ux.open_dataset(data_path)
+
 
 def test_open_mf_dataset(gridpath, datasetpath, mesh_constants):
     """Loads multiple datasets with their grid topology file using uxarray's open_dataset call."""
diff --git a/uxarray/core/api.py b/uxarray/core/api.py
index 5edd612b8..affd739c5 100644
--- a/uxarray/core/api.py
+++ b/uxarray/core/api.py
@@ -353,7 +353,7 @@ def list_grid_names(
 
 def open_dataset(
     grid_filename_or_obj: str | os.PathLike[Any] | dict | Dataset,
-    filename_or_obj: str | os.PathLike[Any],
+    filename_or_obj: str | os.PathLike[Any] | Dataset | None = None,
     chunks=None,
     chunk_grid: bool = True,
     use_dual: bool | None = False,
@@ -364,14 +364,17 @@ def open_dataset(
 
     Parameters
     ----------
-    grid_filename_or_obj : str | os.PathLike[Any] | dict | xr.dataset
+    grid_filename_or_obj : str | os.PathLike[Any] | dict | xr.Dataset
         Strings and Path objects are interpreted as a path to a grid file. Xarray Datasets assume that
         each member variable is in the UGRID conventions and will be used to create a ``ux.Grid``. Similarly, a dictionary
         containing UGRID variables can be used to create a ``ux.Grid``
-    filename_or_obj : str | os.PathLike[Any]
+    filename_or_obj : str | os.PathLike[Any] | xr.Dataset, optional
         String or Path object as a path to a netCDF file or an OpenDAP URL that
-        stores the actual data set. It is the same ``filename_or_obj`` in
-        ``xarray.open_dataset``.
+        stores the actual data set, or an already-open ``xarray.Dataset``. It
+        is the same ``filename_or_obj`` in ``xarray.open_dataset``. If omitted,
+        ``grid_filename_or_obj`` is also used as the data source, allowing
+        combined grid-and-data files or ``xarray.Dataset`` objects to be
+        opened with a single argument.
     chunks : int, dict, 'auto' or None, default: None
         If provided, used to load the grid into dask arrays.
 
@@ -406,23 +409,68 @@ def open_dataset(
 
     >>> import uxarray as ux
     >>> ux_ds = ux.open_dataset("grid_file.nc", "data_file.nc")
+
+    Open a dataset stored in a single combined grid-and-data file
+
+    >>> ux_ds = ux.open_dataset("combined_file.nc")
     """
+    import xarray as xr
+
     if grid_kwargs is None:
         grid_kwargs = {}
 
-    # Construct a Grid, validate parameters, and correct chunks
-    uxgrid, corrected_chunks = _get_grid(
-        grid_filename_or_obj, chunks, chunk_grid, use_dual, grid_kwargs, **kwargs
-    )
+    if filename_or_obj is None:
+        # Single-argument form: the same path or Dataset supplies both the
+        # grid topology and the data variables.
+        if isinstance(grid_filename_or_obj, (str, os.PathLike)):
+            if os.path.isdir(grid_filename_or_obj):
+                raise ValueError(
+                    "Directory-based grids require a separate data file when calling ux.open_dataset()."
+                )
 
-    # Load the data as a Xarray Dataset
-    ds = _open_dataset_with_fallback(filename_or_obj, chunks=corrected_chunks, **kwargs)
+            ds = _open_dataset_with_fallback(
+                grid_filename_or_obj,
+                chunks=match_chunks_to_ugrid(grid_filename_or_obj, chunks),
+                **kwargs,
+            )
+        elif isinstance(grid_filename_or_obj, xr.Dataset):
+            # An already-open Dataset is used as-is; this branch neither
+            # re-opens nor re-chunks it.
+            ds = grid_filename_or_obj
+        else:
+            raise ValueError(
+                "If filename_or_obj is omitted, grid_filename_or_obj must be a file path or xarray.Dataset."
+            )
+
+        # Build the grid from the dataset loaded above; the corrected chunks
+        # returned alongside it are discarded because the data is already open.
+        uxgrid, _ = _get_grid(ds, chunks, chunk_grid, use_dual, grid_kwargs, **kwargs)
+        # Lets the shared source_datasets logic below see the original argument.
+        filename_or_obj = grid_filename_or_obj
+    else:
+        # Construct a Grid, validate parameters, and correct chunks
+        uxgrid, corrected_chunks = _get_grid(
+            grid_filename_or_obj, chunks, chunk_grid, use_dual, grid_kwargs, **kwargs
+        )
+
+        # Load the data as a Xarray Dataset
+        if isinstance(filename_or_obj, xr.Dataset):
+            # Already-open Dataset: use it directly instead of opening a file.
+            ds = filename_or_obj
+        else:
+            ds = _open_dataset_with_fallback(
+                filename_or_obj, chunks=corrected_chunks, **kwargs
+            )
 
     # Map original dimensions to the UGRID conventions
     ds = _map_dims_to_ugrid(ds, uxgrid._source_dims_dict, uxgrid)
 
     # Create a UXarray Dataset by linking the Xarray Dataset with a UXarray Grid
-    return UxDataset(ds, uxgrid=uxgrid, source_datasets=str(filename_or_obj))
+    source_datasets = (
+        None if isinstance(filename_or_obj, xr.Dataset) else str(filename_or_obj)
+    )
+
+    return UxDataset(ds, uxgrid=uxgrid, source_datasets=source_datasets)
 
 
 def open_mfdataset(
diff --git a/uxarray/core/utils.py b/uxarray/core/utils.py
index 8092022dd..1a05926aa 100644
--- a/uxarray/core/utils.py
+++ b/uxarray/core/utils.py
@@ -104,7 +104,12 @@ def match_chunks_to_ugrid(grid_filename_or_obj, chunks):
         # No need to rename
         return chunks
 
-    ds = _open_dataset_with_fallback(grid_filename_or_obj, chunks=chunks)
+    if isinstance(grid_filename_or_obj, xr.Dataset):
+        # An already-open Dataset is inspected directly; anything else is opened first.
+        ds = grid_filename_or_obj
+    else:
+        ds = _open_dataset_with_fallback(grid_filename_or_obj, chunks=chunks)
+
     grid_spec, _, _ = _parse_grid_type(ds)
 
     source_dims_dict = _get_source_dims_dict(ds, grid_spec)