diff --git a/src/oumi/cli/main.py b/src/oumi/cli/main.py index de3b519f63..6cafeb6fe1 100644 --- a/src/oumi/cli/main.py +++ b/src/oumi/cli/main.py @@ -57,6 +57,7 @@ from oumi.cli.synth import synth from oumi.cli.train import train from oumi.cli.tune import tune +from oumi.exceptions import OumiConfigError from oumi.utils.logging import should_use_rich_logging _ASCII_LOGO = r""" @@ -365,6 +366,9 @@ def run(): telemetry = TelemetryManager.get_instance() with telemetry.capture_operation(event_name, event_properties): return app() + except OumiConfigError as e: + CONSOLE.print(f"[red]Error: {e}[/red]") + sys.exit(1) except Exception as e: tb_str = traceback.format_exc() CONSOLE.print(tb_str) diff --git a/src/oumi/core/configs/__init__.py b/src/oumi/core/configs/__init__.py index 8ed6769f5d..5b814a5245 100644 --- a/src/oumi/core/configs/__init__.py +++ b/src/oumi/core/configs/__init__.py @@ -158,12 +158,14 @@ from oumi.core.configs.synthesis_config import SynthesisConfig from oumi.core.configs.training_config import TrainingConfig from oumi.core.configs.tuning_config import TuningConfig +from oumi.exceptions import OumiConfigError, OumiConfigFileNotFoundError __all__ = [ "AsyncEvaluationConfig", "AutoWrapPolicy", "BackwardPrefetch", "BaseConfig", + "OumiConfigFileNotFoundError", "DataParams", "DatasetParams", "DatasetSplit", @@ -192,6 +194,7 @@ "MixedPrecisionDtype", "MixtureStrategy", "ModelParams", + "OumiConfigError", "PeftParams", "PeftSaveMode", "ProfilerParams", diff --git a/src/oumi/core/configs/base_config.py b/src/oumi/core/configs/base_config.py index e7b0ec4d40..874759048c 100644 --- a/src/oumi/core/configs/base_config.py +++ b/src/oumi/core/configs/base_config.py @@ -23,8 +23,15 @@ from typing import Any, TypeVar, cast from omegaconf import OmegaConf +from omegaconf.errors import OmegaConfBaseException from oumi.core.configs.params.base_params import BaseParams +from oumi.exceptions import ( + OumiConfigError, + OumiConfigFileNotFoundError, + OumiConfigParsingError, + OumiConfigTypeError, +) T = TypeVar("T", bound="BaseConfig") @@ -128,6 +135,10 @@ def _read_config_without_interpolation(config_path: str) -> str: Returns: str: The stringified configuration. """ + if not Path(config_path).is_file(): + raise OumiConfigFileNotFoundError( + f"Config file not found or path is not a file: {config_path}" + ) with open(config_path) as f: stringified_config = f.read() pattern = r"(? None: + "\n".join(f"- {path}" for path in sorted(removed_paths)) ) - OmegaConf.save(config=processed_config, f=config_path) + try: + OmegaConf.save(config=processed_config, f=config_path) + except OSError as e: + # handle missing parent folder + raise OumiConfigError(f"Failed to save config to {config_path}: {e}") from e @classmethod def from_yaml( @@ -181,15 +196,24 @@ def from_yaml( Returns: BaseConfig: The merged configuration object. """ - schema = OmegaConf.structured(cls) - if ignore_interpolation: - stringified_config = _read_config_without_interpolation(str(config_path)) - file_config = OmegaConf.create(stringified_config) - else: - file_config = OmegaConf.load(config_path) - config = OmegaConf.to_object(OmegaConf.merge(schema, file_config)) + if not Path(config_path).is_file(): + raise OumiConfigFileNotFoundError( + f"Config file not found or path is not a file: {config_path}" + ) + try: + schema = OmegaConf.structured(cls) + if ignore_interpolation: + stringified_config = _read_config_without_interpolation( + str(config_path) + ) + file_config = OmegaConf.create(stringified_config) + else: + file_config = OmegaConf.load(config_path) + config = OmegaConf.to_object(OmegaConf.merge(schema, file_config)) + except OmegaConfBaseException as e: + raise OumiConfigParsingError(e) from e if not isinstance(config, cls): - raise TypeError(f"config is not {cls}") + raise OumiConfigTypeError(config_type=cls, config_value=config) return cast(T, config) @classmethod @@ -202,11 +226,14 @@ def from_str(cls: type[T], config_str: str) -> T: Returns: BaseConfig: The configuration object. """ - schema = OmegaConf.structured(cls) - file_config = OmegaConf.create(config_str) - config = OmegaConf.to_object(OmegaConf.merge(schema, file_config)) + try: + schema = OmegaConf.structured(cls) + file_config = OmegaConf.create(config_str) + config = OmegaConf.to_object(OmegaConf.merge(schema, file_config)) + except OmegaConfBaseException as e: + raise OumiConfigParsingError(e) from e if not isinstance(config, cls): - raise TypeError(f"config is not {cls}") + raise OumiConfigTypeError(config_type=cls, config_value=config) return cast(T, config) @classmethod @@ -234,23 +261,22 @@ def from_yaml_and_arg_list( """ # Start with an empty typed config. This forces OmegaConf to validate # that all other configs are of this structured type as well. - all_configs = [OmegaConf.structured(cls)] - # Override with configuration file if provided. - if config_path is not None: - if ignore_interpolation: - stringified_config = _read_config_without_interpolation(config_path) - all_configs.append(OmegaConf.create(stringified_config)) - else: - all_configs.append(cls.from_yaml(config_path)) - # Merge base config and config from yaml. try: - # Merge and validate configs + all_configs = [OmegaConf.structured(cls)] + if config_path is not None: + if ignore_interpolation: + stringified_config = _read_config_without_interpolation(config_path) + all_configs.append(OmegaConf.create(stringified_config)) + else: + all_configs.append(cls.from_yaml(config_path)) config = OmegaConf.merge(*all_configs) + except OmegaConfBaseException as e: + raise OumiConfigParsingError(e) from e except Exception: if logger: - configs_str = "\n\n".join([f"{config}" for config in all_configs]) + configs_str = "\n\n".join([f"{c}" for c in all_configs]) logger.exception( f"Failed to merge {len(all_configs)} Omega configs:\n{configs_str}" ) @@ -267,16 +293,18 @@ def from_yaml_and_arg_list( arg_list = _filter_ignored_args(arg_list) # Override with CLI arguments. config.merge_with_dotlist(arg_list) + config = OmegaConf.to_object(config) + except OmegaConfBaseException as e: + raise OumiConfigParsingError(e) from e except Exception: if logger: logger.exception( - f"Failed to merge arglist {arg_list} with Omega config:\n{config}" + f"Failed to apply CLI args {arg_list} to Omega config:\n{config}" ) raise - config = OmegaConf.to_object(config) if not isinstance(config, cls): - raise TypeError(f"config {type(config)} is not {type(cls)}") + raise OumiConfigTypeError(config_type=cls, config_value=config) return cast(T, config) diff --git a/src/oumi/core/types/exceptions.py b/src/oumi/core/configs/exceptions.py similarity index 72% rename from src/oumi/core/types/exceptions.py rename to src/oumi/core/configs/exceptions.py index d12c0b3ec4..19360730a1 100644 --- a/src/oumi/core/types/exceptions.py +++ b/src/oumi/core/configs/exceptions.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Backward-compatible re-exports. Canonical definitions live in oumi.exceptions.""" -class HardwareException(Exception): - """An exception thrown for invalid hardware configurations.""" +from oumi.exceptions import OumiConfigError, OumiConfigFileNotFoundError + +__all__ = ["OumiConfigFileNotFoundError", "OumiConfigError"] diff --git a/src/oumi/core/configs/params/model_params.py b/src/oumi/core/configs/params/model_params.py index 6dfb18b699..b81ccb5683 100644 --- a/src/oumi/core/configs/params/model_params.py +++ b/src/oumi/core/configs/params/model_params.py @@ -22,7 +22,11 @@ from transformers.utils import find_adapter_config_file, is_flash_attn_2_available from oumi.core.configs.params.base_params import BaseParams -from oumi.core.types.exceptions import HardwareException +from oumi.exceptions import ( + HardwareException, + OumiConfigError, + OumiConfigFileNotFoundError, +) from oumi.utils.logging import logger from oumi.utils.torch_utils import get_torch_dtype @@ -268,6 +272,11 @@ def __finalize_and_validate__(self): adapter_config_file = None # If this check fails, it means this is not a LoRA model. if adapter_config_file: + if not Path(adapter_config_file).is_file(): + raise OumiConfigFileNotFoundError( + f"Adapter config file not found or path is not a file: " + f"{adapter_config_file}" + ) # If `model_name` is a local dir, this should be the same. # If it's a HF Hub repo, this should be the path to the cached repo. adapter_dir = Path(adapter_config_file).parent @@ -280,8 +289,19 @@ def __finalize_and_validate__(self): # present, set it to the base model name found in the adapter config, # if present. Error otherwise. if len(list(adapter_dir.glob("config.json"))) == 0: - with open(adapter_config_file) as f: - adapter_config = json.load(f) + try: + with open(adapter_config_file) as f: + adapter_config = json.load(f) + except OSError as e: + raise OumiConfigError( + f"Failed to read adapter config at " + f"{adapter_config_file}: {e}" + ) from e + except json.JSONDecodeError as e: + raise OumiConfigError( + f"Adapter config at {adapter_config_file} contains invalid " + f"JSON: (line {e.lineno}, col {e.colno}): {e.msg}" + ) from e model_name = adapter_config.get("base_model_name_or_path") if not model_name: raise ValueError( diff --git a/src/oumi/core/types/__init__.py b/src/oumi/core/types/__init__.py index 74a07172d8..e3ed75f949 100644 --- a/src/oumi/core/types/__init__.py +++ b/src/oumi/core/types/__init__.py @@ -42,7 +42,7 @@ TemplatedMessage, Type, ) -from oumi.core.types.exceptions import HardwareException +from oumi.exceptions import HardwareException __all__ = [ "HardwareException", diff --git a/src/oumi/exceptions.py b/src/oumi/exceptions.py new file mode 100644 index 0000000000..393c0b36a8 --- /dev/null +++ b/src/oumi/exceptions.py @@ -0,0 +1,62 @@ +# Copyright 2025 - Oumi +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Oumi exception hierarchy. + +This module is intentionally free of heavy dependencies (torch, transformers, etc.) +so that it can be imported cheaply in lightweight entry-points such as the CLI. +""" + +from typing import Any + + +class OumiConfigError(Exception): + """Base class for all configuration-related errors.""" + + +class OumiConfigFileNotFoundError(OumiConfigError, FileNotFoundError): + """A configuration file path does not exist.""" + + +class OumiConfigTypeError(OumiConfigError): + """A configuration type error.""" + + def __init__(self, config_type: type, config_value: Any): + """Initialize with the expected type and actual value.""" + self.config_type = config_type + self.config_value = config_value + super().__init__( + f"Expected config of type {config_type.__name__}, " + f"got {type(config_value).__name__}" + ) + + +class OumiConfigParsingError(OumiConfigError): + """Wraps any OmegaConf exception into a user-friendly config error. + + Covers all subclasses of ``OmegaConfBaseException``, suppressing the internal + call stack from the CLI. + """ + + def __init__(self, cause: Exception): + """Extract a user-friendly message from an OmegaConf exception.""" + full_key = getattr(cause, "full_key", None) + candidate = full_key if full_key else getattr(cause, "key", None) + self.config_key: str | None = str(candidate) if candidate is not None else None + self.msg: str = getattr(cause, "msg", None) or str(cause) + super().__init__(f"Config error: {self.msg}") + + +class HardwareException(Exception): + """An exception thrown for invalid hardware configurations.""" diff --git a/tests/unit/cli/test_cli_launch.py b/tests/unit/cli/test_cli_launch.py index 78896e8fe1..ae2f56f5ae 100644 --- a/tests/unit/cli/test_cli_launch.py +++ b/tests/unit/cli/test_cli_launch.py @@ -710,7 +710,9 @@ def test_launch_up_job_not_found( ) if res.exception: raise res.exception - assert "No such file or directory" in str(exception_info.value) + assert "Config file not found or path is not a file" in str( + exception_info.value + ) def test_launch_run_job( diff --git a/tests/unit/core/configs/params/test_model_params.py b/tests/unit/core/configs/params/test_model_params.py index 44108d4d3c..83ea2329ec 100644 --- a/tests/unit/core/configs/params/test_model_params.py +++ b/tests/unit/core/configs/params/test_model_params.py @@ -5,6 +5,7 @@ import pytest from oumi.core.configs.params.model_params import ModelParams +from oumi.exceptions import OumiConfigError, OumiConfigFileNotFoundError def test_post_init_adapter_model_present(): @@ -130,3 +131,44 @@ def test_chat_template_kwargs_custom_assignment(): model_params = ModelParams(chat_template_kwargs={"enable_thinking": False}) assert model_params.chat_template_kwargs is not None assert model_params.chat_template_kwargs["enable_thinking"] is False + + +@patch("oumi.core.configs.params.model_params.find_adapter_config_file") +def test_adapter_config_file_path_not_a_file(mock_find, tmp_path: Path): + """Raises OumiConfigFileNotFoundError for a missing adapter config path.""" + mock_find.return_value = str(tmp_path / "ghost_adapter_config.json") + + params = ModelParams(model_name=str(tmp_path)) + with pytest.raises( + OumiConfigFileNotFoundError, + match="Adapter config file not found or path is not a file", + ): + params.finalize_and_validate() + + +@patch("oumi.core.configs.params.model_params.find_adapter_config_file") +def test_adapter_config_read_oserror(mock_find, tmp_path: Path): + """Test OSError reading adapter_config.json is re-raised as OumiConfigError.""" + adapter_path = tmp_path / "adapter_config.json" + adapter_path.write_text('{"base_model_name_or_path": "base_model"}') + mock_find.return_value = str(adapter_path) + + params = ModelParams(model_name=str(tmp_path)) + with patch("builtins.open", side_effect=OSError("Permission denied")): + with pytest.raises( + OumiConfigError, + match="Failed to read adapter config", + ): + params.finalize_and_validate() + + +def test_adapter_config_invalid_json(tmp_path: Path): + """Test malformed JSON raises OumiConfigError with location info.""" + (tmp_path / "adapter_config.json").write_text("{not: valid json!!}") + + params = ModelParams(model_name=str(tmp_path)) + with pytest.raises(OumiConfigError, match="contains invalid JSON") as exc_info: + params.finalize_and_validate() + + assert "line" in str(exc_info.value) + assert "col" in str(exc_info.value) diff --git a/tests/unit/core/configs/test_base_config.py b/tests/unit/core/configs/test_base_config.py index fb92e846c3..5bbe9188f0 100644 --- a/tests/unit/core/configs/test_base_config.py +++ b/tests/unit/core/configs/test_base_config.py @@ -5,9 +5,20 @@ from pathlib import Path from typing import Any +import pytest from omegaconf import OmegaConf +from omegaconf.errors import ConfigKeyError, GrammarParseError -from oumi.core.configs.base_config import BaseConfig, _handle_non_primitives +from oumi.core.configs.base_config import ( + BaseConfig, + _handle_non_primitives, + _read_config_without_interpolation, +) +from oumi.exceptions import ( + OumiConfigError, + OumiConfigFileNotFoundError, + OumiConfigParsingError, +) class TestEnum(Enum): @@ -34,6 +45,25 @@ class TestConfig(BaseConfig): func_value: Any | None = None +# Full valid TestConfig YAML (shared by from_str tests). +_TEST_CONFIG_YAML = """ + str_value: "test" + int_value: 42 + float_value: 3.14 + bool_value: true + none_value: null + bytes_value: !!binary dGVzdA== + path_value: "test/path" + enum_value: "VALUE1" + list_value: ["primitive", [1, 2, 3]] + dict_value: + primitive: "value" + nested: + list: [1, 2, 3] + func_value: "def test_func(x): return x * 2" + """ + + def test_primitive_types(): """Test that primitive types are preserved.""" config = { @@ -168,24 +198,7 @@ def test_config_serialization(): def test_config_loading_from_str(): """Test loading config from YAML string.""" - yaml_str = """ - str_value: "test" - int_value: 42 - float_value: 3.14 - bool_value: true - none_value: null - bytes_value: !!binary dGVzdA== - path_value: "test/path" - enum_value: "VALUE1" - list_value: ["primitive", [1, 2, 3]] - dict_value: - primitive: "value" - nested: - list: [1, 2, 3] - func_value: "def test_func(x): return x * 2" - """ - - config = TestConfig.from_str(yaml_str) + config = TestConfig.from_str(_TEST_CONFIG_YAML) assert config.str_value == "test" assert config.int_value == 42 assert config.float_value == 3.14 @@ -198,6 +211,25 @@ def test_config_loading_from_str(): assert config.dict_value == {"primitive": "value", "nested": {"list": [1, 2, 3]}} +def test_from_str_unknown_field_raises_config_parsing_error(): + """Unknown YAML key raises OumiConfigParsingError with chained cause.""" + yaml_str = _TEST_CONFIG_YAML + "\n unknown_key: 1\n" + with pytest.raises(OumiConfigParsingError) as exc_info: + TestConfig.from_str(yaml_str) + assert exc_info.value.config_key == "unknown_key" + assert isinstance(exc_info.value.__cause__, ConfigKeyError) + + +def test_from_str_malformed_interpolation_raises_config_parsing_error(): + """Malformed interpolation raises OumiConfigParsingError with chained cause.""" + # Standalone YAML overriding str_value with an unclosed interpolation. + # GrammarParseError is raised by OmegaConf during OmegaConf.to_object(). + yaml_str = 'str_value: "${bad"' + with pytest.raises(OumiConfigParsingError) as exc_info: + TestConfig.from_str(yaml_str) + assert isinstance(exc_info.value.__cause__, GrammarParseError) + + def test_config_equality(): """Test config equality comparison.""" config_a = TestConfig( @@ -331,3 +363,116 @@ def test_config_from_yaml_and_arg_list(): assert new_config.bool_value is False assert new_config.list_value[0] == "override" assert new_config.dict_value["key"] == "override" + + +@pytest.mark.parametrize( + ("subclass", "superclass"), + [ + (OumiConfigError, Exception), + (OumiConfigFileNotFoundError, OumiConfigError), + (OumiConfigFileNotFoundError, FileNotFoundError), + ], +) +def test_exception_class_hierarchy(subclass, superclass): + """Test that config exception classes form the expected inheritance hierarchy.""" + assert issubclass(subclass, superclass) + + +@pytest.mark.parametrize("exc_type", [OumiConfigError, FileNotFoundError]) +def test_config_file_not_found_error_caught_as(exc_type): + """OumiConfigFileNotFoundError is catchable as each parent exception type.""" + with pytest.raises(exc_type): + raise OumiConfigFileNotFoundError("test message") + + +def test_read_config_without_interpolation_file_not_found(): + """Test that a non-existent path raises OumiConfigFileNotFoundError.""" + with pytest.raises( + OumiConfigFileNotFoundError, + match="Config file not found or path is not a file", + ): + _read_config_without_interpolation("/nonexistent/path/config.yaml") + + +def test_read_config_without_interpolation_directory_path(tmp_path: Path): + """Test that a directory path raises OumiConfigFileNotFoundError.""" + with pytest.raises( + OumiConfigFileNotFoundError, + match="Config file not found or path is not a file", + ): + _read_config_without_interpolation(str(tmp_path)) + + +def test_from_yaml_file_not_found(): + """from_yaml raises OumiConfigFileNotFoundError for a missing config path.""" + with pytest.raises( + OumiConfigFileNotFoundError, + match="Config file not found or path is not a file", + ): + TestConfig.from_yaml("/nonexistent/path/config.yaml") + + +def test_from_yaml_path_is_directory(tmp_path: Path): + """Test that from_yaml raises OumiConfigFileNotFoundError when given a directory.""" + with pytest.raises( + OumiConfigFileNotFoundError, + match="Config file not found or path is not a file", + ): + TestConfig.from_yaml(tmp_path) + + +def test_to_yaml_missing_parent_directory(): + """Test that to_yaml raises OumiConfigError when the output directory is missing.""" + config = TestConfig( + str_value="test", + int_value=1, + float_value=1.0, + bool_value=True, + none_value=None, + bytes_value=b"test", + path_value=Path("test/path"), + enum_value=TestEnum.VALUE1, + list_value=[], + dict_value={}, + ) + with pytest.raises( + OumiConfigError, + match="Failed to save config to", + ): + config.to_yaml("/nonexistent/subdir/out.yaml") + + +def test_from_yaml_file_not_found_no_interpolation(): + """from_yaml with ignore_interpolation=False raises on missing file.""" + with pytest.raises( + OumiConfigFileNotFoundError, + match="Config file not found or path is not a file", + ): + TestConfig.from_yaml( + "/nonexistent/path/config.yaml", ignore_interpolation=False + ) + + +def test_from_yaml_and_arg_list_nonexistent_config(): + """from_yaml_and_arg_list raises OumiConfigFileNotFoundError if file missing.""" + with pytest.raises( + OumiConfigFileNotFoundError, + match="Config file not found or path is not a file", + ): + TestConfig.from_yaml_and_arg_list( + config_path="/nonexistent/path/config.yaml", + arg_list=[], + ) + + +def test_from_yaml_and_arg_list_nonexistent_config_no_interpolation(tmp_path: Path): + """Test OumiConfigFileNotFoundError via the ignore_interpolation=False branch.""" + with pytest.raises( + OumiConfigFileNotFoundError, + match="Config file not found or path is not a file", + ): + TestConfig.from_yaml_and_arg_list( + config_path=str(tmp_path / "nonexistent.yaml"), + arg_list=[], + ignore_interpolation=False, + ) diff --git a/tests/unit/core/configs/test_config.py b/tests/unit/core/configs/test_config.py index 36a99cf6e0..cf211f9240 100644 --- a/tests/unit/core/configs/test_config.py +++ b/tests/unit/core/configs/test_config.py @@ -12,6 +12,7 @@ TrainingConfig, ) from oumi.core.configs.params.evaluation_params import EvaluationTaskParams +from oumi.exceptions import OumiConfigParsingError def test_config_serialization(): @@ -187,13 +188,14 @@ def test_config_from_yaml_and_arg_list_failure_nonexistent_index(tmp_path): config_path = tmp_path / "eval.yaml" config.to_yaml(config_path) - with pytest.raises(omegaconf.errors.ValidationError): + with pytest.raises(OumiConfigParsingError) as exc_info: EvaluationConfig.from_yaml_and_arg_list( config_path, [ "tasks[2].num_samples=1", # index doesn't exist ], ) + assert isinstance(exc_info.value.__cause__, omegaconf.errors.ValidationError) def test_config_from_yaml_and_arg_list_failure_empty_list(tmp_path): @@ -202,10 +204,11 @@ def test_config_from_yaml_and_arg_list_failure_empty_list(tmp_path): config_path = tmp_path / "eval.yaml" config.to_yaml(config_path) - with pytest.raises(omegaconf.errors.ValidationError): + with pytest.raises(OumiConfigParsingError) as exc_info: EvaluationConfig.from_yaml_and_arg_list( config_path, [ "tasks[0].num_samples=1", # index doesn't exist ], ) + assert isinstance(exc_info.value.__cause__, omegaconf.errors.ValidationError)