From f65c20d888107164217ce6fa60122cd3cf9a5bb9 Mon Sep 17 00:00:00 2001 From: siddhant jain Date: Tue, 3 Feb 2026 01:48:45 +0530 Subject: [PATCH 1/2] feat: Add explicit input validation with clear error messages for model inputs Addresses issue #190 - adds fail-fast input validation at the API boundary for model forward() methods. Changes: - Add graph_weather/models/validation.py with centralized validation utilities - Update GraphWeatherForecaster.forward() to validate input shape - Update GraphWeatherAssimilator.forward() to validate input shape - Add comprehensive unit tests in tests/test_input_validation.py Fixes #190 --- graph_weather/models/analysis.py | 14 +++- graph_weather/models/forecast.py | 10 ++- graph_weather/models/validation.py | 48 ++++++++++++++ tests/test_input_validation.py | 102 +++++++++++++++++++++++++++++ 4 files changed, 171 insertions(+), 3 deletions(-) create mode 100644 graph_weather/models/validation.py create mode 100644 tests/test_input_validation.py diff --git a/graph_weather/models/analysis.py b/graph_weather/models/analysis.py index fd216efe..a974e63a 100755 --- a/graph_weather/models/analysis.py +++ b/graph_weather/models/analysis.py @@ -6,6 +6,7 @@ from huggingface_hub import PyTorchModelHubMixin from graph_weather.models import AssimilatorDecoder, AssimilatorEncoder, Processor +from graph_weather.models.validation import validate_features_shape @dataclass @@ -138,13 +139,22 @@ def forward(self, features: torch.Tensor, obs_lat_lon_heights: torch.Tensor) -> Compute the analysis output Args: - features: The input features, aligned with the order of lat_lons_heights - obs_lat_lon_heights: Observation lat/lon/heights in same order as features + features: The input features, aligned with the order of lat_lons_heights. + Expected shape: [batch, nodes, features] + obs_lat_lon_heights: Observation lat/lon/heights in same order as features. + Expected shape: [batch, nodes, 3] Returns: The next state in the forecast + + Raises: + ValueError: If features tensor is not 3D with shape [batch, nodes, features] """ + # Validate input shape at API boundary for clear error messages + validate_features_shape(features) + x, edge_idx, edge_attr = self.encoder(features, obs_lat_lon_heights) x = self.processor(x, edge_idx, edge_attr) x = self.decoder(x, features.shape[0]) return x + diff --git a/graph_weather/models/forecast.py b/graph_weather/models/forecast.py index 15d38b76..7965efbb 100755 --- a/graph_weather/models/forecast.py +++ b/graph_weather/models/forecast.py @@ -9,6 +9,7 @@ from graph_weather.models import Decoder, Encoder, Processor from graph_weather.models.layers.constraint_layer import PhysicalConstraintLayer +from graph_weather.models.validation import validate_features_shape @dataclass @@ -217,12 +218,19 @@ def forward(self, features: torch.Tensor, t: int = 0) -> torch.Tensor: Compute the new state of the forecast Args: - features: The input features, aligned with the order of lat_lons_heights + features: The input features, aligned with the order of lat_lons_heights. + Expected shape: [batch, nodes, features] t: Timestep for the thermalizer Returns: The next state in the forecast + + Raises: + ValueError: If features tensor is not 3D with shape [batch, nodes, features] """ + # Validate input shape at API boundary for clear error messages + validate_features_shape(features) + x, edge_idx, edge_attr = self.encoder(features) x = self.processor(x, edge_idx, edge_attr, t) x = self.decoder(x, features[..., : self.feature_dim]) diff --git a/graph_weather/models/validation.py b/graph_weather/models/validation.py new file mode 100644 index 00000000..32acae1d --- /dev/null +++ b/graph_weather/models/validation.py @@ -0,0 +1,48 @@ +"""Input validation utilities for graph_weather models. + +Provides centralized validation functions to ensure input tensors have correct +shapes at the API boundary, enabling fail-fast behavior with clear error messages. +""" + +import torch + + +def validate_input_shape( + tensor: torch.Tensor, + expected_ndim: int, + name: str = "input", + expected_shape_desc: str = "[batch, nodes, features]", +) -> None: + """Validate that input tensor has the expected number of dimensions. + + Args: + tensor: Input tensor to validate. + expected_ndim: Expected number of dimensions. + name: Name of the input parameter for error messages. + expected_shape_desc: Human-readable description of expected shape. + + Raises: + ValueError: If tensor does not have expected number of dimensions. + """ + if tensor.ndim != expected_ndim: + raise ValueError( + f"Invalid {name} shape: expected {expected_ndim}D tensor with shape " + f"{expected_shape_desc}, got {tensor.ndim}D tensor with shape {tuple(tensor.shape)}" + ) + + +def validate_features_shape(features: torch.Tensor) -> None: + """Validate that features tensor has shape [batch, nodes, features]. + + Args: + features: Input features tensor. + + Raises: + ValueError: If features tensor is not 3D. + """ + validate_input_shape( + tensor=features, + expected_ndim=3, + name="features", + expected_shape_desc="[batch, nodes, features]", + ) diff --git a/tests/test_input_validation.py b/tests/test_input_validation.py new file mode 100644 index 00000000..f42cb1e2 --- /dev/null +++ b/tests/test_input_validation.py @@ -0,0 +1,102 @@ +"""Unit tests for input validation in graph_weather models. + +These tests verify that model forward() methods correctly validate input tensor +shapes and fail fast with clear error messages at the API boundary. +""" + +import pytest +import torch + +from graph_weather.models.validation import validate_features_shape, validate_input_shape + + +class TestValidateInputShape: + """Tests for the generic validate_input_shape function.""" + + def test_valid_3d_tensor(self): + """Test that valid 3D tensor passes validation.""" + tensor = torch.randn(4, 100, 64) + # Should not raise + validate_input_shape(tensor, expected_ndim=3) + + def test_invalid_2d_tensor_raises_valueerror(self): + """Test that 2D tensor raises ValueError when 3D expected.""" + tensor = torch.randn(100, 64) + with pytest.raises(ValueError, match=r"Invalid input shape"): + validate_input_shape(tensor, expected_ndim=3) + + def test_invalid_1d_tensor_raises_valueerror(self): + """Test that 1D tensor raises ValueError when 3D expected.""" + tensor = torch.randn(100) + with pytest.raises(ValueError, match=r"Invalid input shape"): + validate_input_shape(tensor, expected_ndim=3) + + def test_invalid_4d_tensor_raises_valueerror(self): + """Test that 4D tensor raises ValueError when 3D expected.""" + tensor = torch.randn(4, 100, 64, 32) + with pytest.raises(ValueError, match=r"Invalid input shape"): + validate_input_shape(tensor, expected_ndim=3) + + def test_error_message_includes_expected_shape(self): + """Test that error message includes expected shape description.""" + tensor = torch.randn(100, 64) + with pytest.raises(ValueError, match=r"\[batch, nodes, features\]"): + validate_input_shape( + tensor, + expected_ndim=3, + expected_shape_desc="[batch, nodes, features]", + ) + + def test_error_message_includes_actual_shape(self): + """Test that error message includes actual tensor shape.""" + tensor = torch.randn(100, 64) + with pytest.raises(ValueError, match=r"\(100, 64\)"): + validate_input_shape(tensor, expected_ndim=3) + + +class TestValidateFeaturesShape: + """Tests for the features-specific validation function.""" + + def test_valid_features_shape(self): + """Test that valid features tensor passes validation.""" + features = torch.randn(8, 256, 78) + # Should not raise + validate_features_shape(features) + + def test_missing_batch_dimension(self): + """Test that tensor missing batch dimension raises ValueError.""" + features = torch.randn(256, 78) # Missing batch dimension + with pytest.raises(ValueError, match=r"Invalid features shape"): + validate_features_shape(features) + + def test_missing_features_dimension(self): + """Test that tensor missing features dimension raises ValueError.""" + features = torch.randn(8, 256) # Missing features dimension + with pytest.raises(ValueError, match=r"Invalid features shape"): + validate_features_shape(features) + + def test_empty_batch(self): + """Test that empty batch still passes shape validation.""" + features = torch.randn(0, 256, 78) # Empty batch + # Should not raise - shape is still valid even if batch is empty + validate_features_shape(features) + + +class TestModelInputValidation: + """Integration tests for model forward() input validation. + + Note: These tests focus on the validation behavior, not full model execution. + The models require heavy dependencies (h3, etc.) so we test validation in isolation. + """ + + def test_forecaster_import(self): + """Test that forecaster can be imported with validation.""" + from graph_weather.models.forecast import GraphWeatherForecaster + + assert hasattr(GraphWeatherForecaster, "forward") + + def test_assimilator_import(self): + """Test that assimilator can be imported with validation.""" + from graph_weather.models.analysis import GraphWeatherAssimilator + + assert hasattr(GraphWeatherAssimilator, "forward") From 10882dc1f1724385754ff83d2b823a3904ec4d3d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 2 Feb 2026 20:29:53 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- graph_weather/models/analysis.py | 1 - 1 file changed, 1 deletion(-) diff --git a/graph_weather/models/analysis.py b/graph_weather/models/analysis.py index a974e63a..2c11ef7f 100755 --- a/graph_weather/models/analysis.py +++ b/graph_weather/models/analysis.py @@ -157,4 +157,3 @@ def forward(self, features: torch.Tensor, obs_lat_lon_heights: torch.Tensor) -> x = self.processor(x, edge_idx, edge_attr) x = self.decoder(x, features.shape[0]) return x -