-
-
Notifications
You must be signed in to change notification settings - Fork 95
feat: Add explicit input validation with clear error messages for model inputs #210
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
Raakshass
wants to merge
2
commits into
openclimatefix:main
Choose a base branch
from
Raakshass:feat/input-validation
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+170
−3
Open
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,48 @@ | ||
| """Input validation utilities for graph_weather models. | ||
|
|
||
| Provides centralized validation functions to ensure input tensors have correct | ||
| shapes at the API boundary, enabling fail-fast behavior with clear error messages. | ||
| """ | ||
|
|
||
| import torch | ||
|
|
||
|
|
||
| def validate_input_shape( | ||
| tensor: torch.Tensor, | ||
| expected_ndim: int, | ||
| name: str = "input", | ||
| expected_shape_desc: str = "[batch, nodes, features]", | ||
| ) -> None: | ||
| """Validate that input tensor has the expected number of dimensions. | ||
|
|
||
| Args: | ||
| tensor: Input tensor to validate. | ||
| expected_ndim: Expected number of dimensions. | ||
| name: Name of the input parameter for error messages. | ||
| expected_shape_desc: Human-readable description of expected shape. | ||
|
|
||
| Raises: | ||
| ValueError: If tensor does not have expected number of dimensions. | ||
| """ | ||
| if tensor.ndim != expected_ndim: | ||
| raise ValueError( | ||
| f"Invalid {name} shape: expected {expected_ndim}D tensor with shape " | ||
| f"{expected_shape_desc}, got {tensor.ndim}D tensor with shape {tuple(tensor.shape)}" | ||
| ) | ||
|
|
||
|
|
||
| def validate_features_shape(features: torch.Tensor) -> None: | ||
| """Validate that features tensor has shape [batch, nodes, features]. | ||
|
|
||
| Args: | ||
| features: Input features tensor. | ||
|
|
||
| Raises: | ||
| ValueError: If features tensor is not 3D. | ||
| """ | ||
| validate_input_shape( | ||
| tensor=features, | ||
| expected_ndim=3, | ||
| name="features", | ||
| expected_shape_desc="[batch, nodes, features]", | ||
| ) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,102 @@ | ||
| """Unit tests for input validation in graph_weather models. | ||
|
|
||
| These tests verify that model forward() methods correctly validate input tensor | ||
| shapes and fail fast with clear error messages at the API boundary. | ||
| """ | ||
|
|
||
| import pytest | ||
| import torch | ||
|
|
||
| from graph_weather.models.validation import validate_features_shape, validate_input_shape | ||
|
|
||
|
|
||
| class TestValidateInputShape: | ||
| """Tests for the generic validate_input_shape function.""" | ||
|
|
||
| def test_valid_3d_tensor(self): | ||
| """Test that valid 3D tensor passes validation.""" | ||
| tensor = torch.randn(4, 100, 64) | ||
| # Should not raise | ||
| validate_input_shape(tensor, expected_ndim=3) | ||
|
|
||
| def test_invalid_2d_tensor_raises_valueerror(self): | ||
| """Test that 2D tensor raises ValueError when 3D expected.""" | ||
| tensor = torch.randn(100, 64) | ||
| with pytest.raises(ValueError, match=r"Invalid input shape"): | ||
| validate_input_shape(tensor, expected_ndim=3) | ||
|
|
||
| def test_invalid_1d_tensor_raises_valueerror(self): | ||
| """Test that 1D tensor raises ValueError when 3D expected.""" | ||
| tensor = torch.randn(100) | ||
| with pytest.raises(ValueError, match=r"Invalid input shape"): | ||
| validate_input_shape(tensor, expected_ndim=3) | ||
|
|
||
| def test_invalid_4d_tensor_raises_valueerror(self): | ||
| """Test that 4D tensor raises ValueError when 3D expected.""" | ||
| tensor = torch.randn(4, 100, 64, 32) | ||
| with pytest.raises(ValueError, match=r"Invalid input shape"): | ||
| validate_input_shape(tensor, expected_ndim=3) | ||
|
|
||
| def test_error_message_includes_expected_shape(self): | ||
| """Test that error message includes expected shape description.""" | ||
| tensor = torch.randn(100, 64) | ||
| with pytest.raises(ValueError, match=r"\[batch, nodes, features\]"): | ||
| validate_input_shape( | ||
| tensor, | ||
| expected_ndim=3, | ||
| expected_shape_desc="[batch, nodes, features]", | ||
| ) | ||
|
|
||
| def test_error_message_includes_actual_shape(self): | ||
| """Test that error message includes actual tensor shape.""" | ||
| tensor = torch.randn(100, 64) | ||
| with pytest.raises(ValueError, match=r"\(100, 64\)"): | ||
| validate_input_shape(tensor, expected_ndim=3) | ||
|
|
||
|
|
||
| class TestValidateFeaturesShape: | ||
| """Tests for the features-specific validation function.""" | ||
|
|
||
| def test_valid_features_shape(self): | ||
| """Test that valid features tensor passes validation.""" | ||
| features = torch.randn(8, 256, 78) | ||
| # Should not raise | ||
| validate_features_shape(features) | ||
|
|
||
| def test_missing_batch_dimension(self): | ||
| """Test that tensor missing batch dimension raises ValueError.""" | ||
| features = torch.randn(256, 78) # Missing batch dimension | ||
| with pytest.raises(ValueError, match=r"Invalid features shape"): | ||
| validate_features_shape(features) | ||
|
|
||
| def test_missing_features_dimension(self): | ||
| """Test that tensor missing features dimension raises ValueError.""" | ||
| features = torch.randn(8, 256) # Missing features dimension | ||
| with pytest.raises(ValueError, match=r"Invalid features shape"): | ||
| validate_features_shape(features) | ||
|
|
||
| def test_empty_batch(self): | ||
| """Test that empty batch still passes shape validation.""" | ||
| features = torch.randn(0, 256, 78) # Empty batch | ||
| # Should not raise - shape is still valid even if batch is empty | ||
| validate_features_shape(features) | ||
|
|
||
|
|
||
| class TestModelInputValidation: | ||
| """Integration tests for model forward() input validation. | ||
|
|
||
| Note: These tests focus on the validation behavior, not full model execution. | ||
| The models require heavy dependencies (h3, etc.) so we test validation in isolation. | ||
| """ | ||
|
|
||
| def test_forecaster_import(self): | ||
| """Test that forecaster can be imported with validation.""" | ||
| from graph_weather.models.forecast import GraphWeatherForecaster | ||
|
|
||
| assert hasattr(GraphWeatherForecaster, "forward") | ||
|
|
||
| def test_assimilator_import(self): | ||
| """Test that assimilator can be imported with validation.""" | ||
| from graph_weather.models.analysis import GraphWeatherAssimilator | ||
|
|
||
| assert hasattr(GraphWeatherAssimilator, "forward") | ||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you move this out of the class? Pytest doesn't need it, and so would prefer it being just the test methods.