Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions graph_weather/models/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from huggingface_hub import PyTorchModelHubMixin

from graph_weather.models import AssimilatorDecoder, AssimilatorEncoder, Processor
from graph_weather.models.validation import validate_features_shape


@dataclass
Expand Down Expand Up @@ -138,12 +139,20 @@ def forward(self, features: torch.Tensor, obs_lat_lon_heights: torch.Tensor) ->
Compute the analysis output

Args:
features: The input features, aligned with the order of lat_lons_heights
obs_lat_lon_heights: Observation lat/lon/heights in same order as features
features: The input features, aligned with the order of lat_lons_heights.
Expected shape: [batch, nodes, features]
obs_lat_lon_heights: Observation lat/lon/heights in same order as features.
Expected shape: [batch, nodes, 3]

Returns:
The next state in the forecast

Raises:
ValueError: If features tensor is not 3D with shape [batch, nodes, features]
"""
# Validate input shape at API boundary for clear error messages
validate_features_shape(features)

x, edge_idx, edge_attr = self.encoder(features, obs_lat_lon_heights)
x = self.processor(x, edge_idx, edge_attr)
x = self.decoder(x, features.shape[0])
Expand Down
10 changes: 9 additions & 1 deletion graph_weather/models/forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from graph_weather.models import Decoder, Encoder, Processor
from graph_weather.models.layers.constraint_layer import PhysicalConstraintLayer
from graph_weather.models.validation import validate_features_shape


@dataclass
Expand Down Expand Up @@ -217,12 +218,19 @@ def forward(self, features: torch.Tensor, t: int = 0) -> torch.Tensor:
Compute the new state of the forecast

Args:
features: The input features, aligned with the order of lat_lons_heights
features: The input features, aligned with the order of lat_lons_heights.
Expected shape: [batch, nodes, features]
t: Timestep for the thermalizer

Returns:
The next state in the forecast

Raises:
ValueError: If features tensor is not 3D with shape [batch, nodes, features]
"""
# Validate input shape at API boundary for clear error messages
validate_features_shape(features)

x, edge_idx, edge_attr = self.encoder(features)
x = self.processor(x, edge_idx, edge_attr, t)
x = self.decoder(x, features[..., : self.feature_dim])
Expand Down
48 changes: 48 additions & 0 deletions graph_weather/models/validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""Input validation utilities for graph_weather models.

Provides centralized validation functions to ensure input tensors have correct
shapes at the API boundary, enabling fail-fast behavior with clear error messages.
"""

import torch


def validate_input_shape(
tensor: torch.Tensor,
expected_ndim: int,
name: str = "input",
expected_shape_desc: str = "[batch, nodes, features]",
) -> None:
"""Validate that input tensor has the expected number of dimensions.

Args:
tensor: Input tensor to validate.
expected_ndim: Expected number of dimensions.
name: Name of the input parameter for error messages.
expected_shape_desc: Human-readable description of expected shape.

Raises:
ValueError: If tensor does not have expected number of dimensions.
"""
if tensor.ndim != expected_ndim:
raise ValueError(
f"Invalid {name} shape: expected {expected_ndim}D tensor with shape "
f"{expected_shape_desc}, got {tensor.ndim}D tensor with shape {tuple(tensor.shape)}"
)


def validate_features_shape(features: torch.Tensor) -> None:
"""Validate that features tensor has shape [batch, nodes, features].

Args:
features: Input features tensor.

Raises:
ValueError: If features tensor is not 3D.
"""
validate_input_shape(
tensor=features,
expected_ndim=3,
name="features",
expected_shape_desc="[batch, nodes, features]",
)
102 changes: 102 additions & 0 deletions tests/test_input_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
"""Unit tests for input validation in graph_weather models.

These tests verify that model forward() methods correctly validate input tensor
shapes and fail fast with clear error messages at the API boundary.
"""

import pytest
import torch

from graph_weather.models.validation import validate_features_shape, validate_input_shape


class TestValidateInputShape:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you move this out of the class? Pytest doesn't need it, and so would prefer it being just the test methods.

"""Tests for the generic validate_input_shape function."""

def test_valid_3d_tensor(self):
"""Test that valid 3D tensor passes validation."""
tensor = torch.randn(4, 100, 64)
# Should not raise
validate_input_shape(tensor, expected_ndim=3)

def test_invalid_2d_tensor_raises_valueerror(self):
"""Test that 2D tensor raises ValueError when 3D expected."""
tensor = torch.randn(100, 64)
with pytest.raises(ValueError, match=r"Invalid input shape"):
validate_input_shape(tensor, expected_ndim=3)

def test_invalid_1d_tensor_raises_valueerror(self):
"""Test that 1D tensor raises ValueError when 3D expected."""
tensor = torch.randn(100)
with pytest.raises(ValueError, match=r"Invalid input shape"):
validate_input_shape(tensor, expected_ndim=3)

def test_invalid_4d_tensor_raises_valueerror(self):
"""Test that 4D tensor raises ValueError when 3D expected."""
tensor = torch.randn(4, 100, 64, 32)
with pytest.raises(ValueError, match=r"Invalid input shape"):
validate_input_shape(tensor, expected_ndim=3)

def test_error_message_includes_expected_shape(self):
"""Test that error message includes expected shape description."""
tensor = torch.randn(100, 64)
with pytest.raises(ValueError, match=r"\[batch, nodes, features\]"):
validate_input_shape(
tensor,
expected_ndim=3,
expected_shape_desc="[batch, nodes, features]",
)

def test_error_message_includes_actual_shape(self):
"""Test that error message includes actual tensor shape."""
tensor = torch.randn(100, 64)
with pytest.raises(ValueError, match=r"\(100, 64\)"):
validate_input_shape(tensor, expected_ndim=3)


class TestValidateFeaturesShape:
"""Tests for the features-specific validation function."""

def test_valid_features_shape(self):
"""Test that valid features tensor passes validation."""
features = torch.randn(8, 256, 78)
# Should not raise
validate_features_shape(features)

def test_missing_batch_dimension(self):
"""Test that tensor missing batch dimension raises ValueError."""
features = torch.randn(256, 78) # Missing batch dimension
with pytest.raises(ValueError, match=r"Invalid features shape"):
validate_features_shape(features)

def test_missing_features_dimension(self):
"""Test that tensor missing features dimension raises ValueError."""
features = torch.randn(8, 256) # Missing features dimension
with pytest.raises(ValueError, match=r"Invalid features shape"):
validate_features_shape(features)

def test_empty_batch(self):
"""Test that empty batch still passes shape validation."""
features = torch.randn(0, 256, 78) # Empty batch
# Should not raise - shape is still valid even if batch is empty
validate_features_shape(features)


class TestModelInputValidation:
"""Integration tests for model forward() input validation.

Note: These tests focus on the validation behavior, not full model execution.
The models require heavy dependencies (h3, etc.) so we test validation in isolation.
"""

def test_forecaster_import(self):
"""Test that forecaster can be imported with validation."""
from graph_weather.models.forecast import GraphWeatherForecaster

assert hasattr(GraphWeatherForecaster, "forward")

def test_assimilator_import(self):
"""Test that assimilator can be imported with validation."""
from graph_weather.models.analysis import GraphWeatherAssimilator

assert hasattr(GraphWeatherAssimilator, "forward")