diff --git a/src/oumi/cli/main.py b/src/oumi/cli/main.py index de3b519f63..3e567794d7 100644 --- a/src/oumi/cli/main.py +++ b/src/oumi/cli/main.py @@ -57,6 +57,7 @@ from oumi.cli.synth import synth from oumi.cli.train import train from oumi.cli.tune import tune +from oumi.exceptions import OumiConfigError from oumi.utils.logging import should_use_rich_logging _ASCII_LOGO = r""" @@ -365,6 +366,9 @@ def run(): telemetry = TelemetryManager.get_instance() with telemetry.capture_operation(event_name, event_properties): return app() + except OumiConfigError as e: + CONSOLE.print(f"[red]Error:[/red] {e}") + sys.exit(1) except Exception as e: tb_str = traceback.format_exc() CONSOLE.print(tb_str) diff --git a/src/oumi/core/configs/__init__.py b/src/oumi/core/configs/__init__.py index 8ed6769f5d..d52957c300 100644 --- a/src/oumi/core/configs/__init__.py +++ b/src/oumi/core/configs/__init__.py @@ -158,6 +158,7 @@ from oumi.core.configs.synthesis_config import SynthesisConfig from oumi.core.configs.training_config import TrainingConfig from oumi.core.configs.tuning_config import TuningConfig +from oumi.exceptions import OumiConfigError __all__ = [ "AsyncEvaluationConfig", @@ -192,6 +193,7 @@ "MixedPrecisionDtype", "MixtureStrategy", "ModelParams", + "OumiConfigError", "PeftParams", "PeftSaveMode", "ProfilerParams", diff --git a/src/oumi/core/configs/base_config.py b/src/oumi/core/configs/base_config.py index e7b0ec4d40..e97b90810e 100644 --- a/src/oumi/core/configs/base_config.py +++ b/src/oumi/core/configs/base_config.py @@ -25,6 +25,7 @@ from omegaconf import OmegaConf from oumi.core.configs.params.base_params import BaseParams +from oumi.exceptions import OumiConfigError T = TypeVar("T", bound="BaseConfig") @@ -128,10 +129,15 @@ def _read_config_without_interpolation(config_path: str) -> str: Returns: str: The stringified configuration. """ - with open(config_path) as f: - stringified_config = f.read() - pattern = r"(? None: + "\n".join(f"- {path}" for path in sorted(removed_paths)) ) - OmegaConf.save(config=processed_config, f=config_path) + try: + OmegaConf.save(config=processed_config, f=config_path) + except (FileNotFoundError, NotADirectoryError, IsADirectoryError) as e: + raise OumiConfigError( + f"Cannot save config to {config_path}: " + f"parent directory does not exist or is not a directory: {e}" + ) from e @classmethod def from_yaml( @@ -186,7 +198,12 @@ def from_yaml( stringified_config = _read_config_without_interpolation(str(config_path)) file_config = OmegaConf.create(stringified_config) else: - file_config = OmegaConf.load(config_path) + try: + file_config = OmegaConf.load(config_path) + except (FileNotFoundError, IsADirectoryError, NotADirectoryError) as e: + raise OumiConfigError( + f"Config file not found or path is not a file: {config_path}" + ) from e config = OmegaConf.to_object(OmegaConf.merge(schema, file_config)) if not isinstance(config, cls): raise TypeError(f"config is not {cls}") diff --git a/src/oumi/core/configs/params/model_params.py b/src/oumi/core/configs/params/model_params.py index 6dfb18b699..350846824d 100644 --- a/src/oumi/core/configs/params/model_params.py +++ b/src/oumi/core/configs/params/model_params.py @@ -22,7 +22,10 @@ from transformers.utils import find_adapter_config_file, is_flash_attn_2_available from oumi.core.configs.params.base_params import BaseParams -from oumi.core.types.exceptions import HardwareException +from oumi.exceptions import ( + HardwareException, + OumiConfigError, +) from oumi.utils.logging import logger from oumi.utils.torch_utils import get_torch_dtype @@ -280,8 +283,19 @@ def __finalize_and_validate__(self): # present, set it to the base model name found in the adapter config, # if present. Error otherwise. if len(list(adapter_dir.glob("config.json"))) == 0: - with open(adapter_config_file) as f: - adapter_config = json.load(f) + try: + with open(adapter_config_file) as f: + adapter_config = json.load(f) + except OSError as e: + raise OumiConfigError( + f"Failed to read adapter config at " + f"{adapter_config_file}: {e}" + ) from e + except json.JSONDecodeError as e: + raise OumiConfigError( + f"Adapter config at {adapter_config_file} contains invalid " + f"JSON: (line {e.lineno}, col {e.colno}): {e.msg}" + ) from e model_name = adapter_config.get("base_model_name_or_path") if not model_name: raise ValueError( diff --git a/src/oumi/core/types/__init__.py b/src/oumi/core/types/__init__.py index 74a07172d8..e3ed75f949 100644 --- a/src/oumi/core/types/__init__.py +++ b/src/oumi/core/types/__init__.py @@ -42,7 +42,7 @@ TemplatedMessage, Type, ) -from oumi.core.types.exceptions import HardwareException +from oumi.exceptions import HardwareException __all__ = [ "HardwareException", diff --git a/src/oumi/core/types/exceptions.py b/src/oumi/core/types/exceptions.py index d12c0b3ec4..c346053db0 100644 --- a/src/oumi/core/types/exceptions.py +++ b/src/oumi/core/types/exceptions.py @@ -12,6 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Backward-compatibility shim. -class HardwareException(Exception): - """An exception thrown for invalid hardware configurations.""" +The canonical location for Oumi exceptions is :mod:`oumi.exceptions`. This +module is kept only so that existing callers importing +``HardwareException`` from ``oumi.core.types.exceptions`` continue to work. +New code should import from :mod:`oumi.exceptions` (or, for convenience, +from :mod:`oumi.core.types`) instead. +""" + +from oumi.exceptions import HardwareException + +__all__ = ["HardwareException"] diff --git a/src/oumi/exceptions.py b/src/oumi/exceptions.py new file mode 100644 index 0000000000..dd883dd864 --- /dev/null +++ b/src/oumi/exceptions.py @@ -0,0 +1,27 @@ +# Copyright 2025 - Oumi +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Oumi exception hierarchy. + +This module is intentionally free of heavy dependencies (torch, transformers, etc.) +so that it can be imported cheaply in lightweight entry-points such as the CLI. +""" + + +class OumiConfigError(Exception): + """Base class for all configuration-related errors.""" + + +class HardwareException(Exception): + """An exception thrown for invalid hardware configurations.""" diff --git a/tests/unit/cli/test_cli_launch.py b/tests/unit/cli/test_cli_launch.py index 78896e8fe1..bef74baae5 100644 --- a/tests/unit/cli/test_cli_launch.py +++ b/tests/unit/cli/test_cli_launch.py @@ -17,6 +17,7 @@ DatasetParams, DatasetSplitParams, ModelParams, + OumiConfigError, TrainerType, TrainingConfig, TrainingParams, @@ -699,7 +700,7 @@ def test_launch_up_job_not_found( done=True, state=JobState.SUCCEEDED, ) - with pytest.raises(FileNotFoundError) as exception_info: + with pytest.raises(OumiConfigError) as exception_info: res = runner.invoke( app, [ @@ -710,7 +711,9 @@ def test_launch_up_job_not_found( ) if res.exception: raise res.exception - assert "No such file or directory" in str(exception_info.value) + assert "Config file not found or path is not a file" in str( + exception_info.value + ) def test_launch_run_job( diff --git a/tests/unit/core/configs/params/test_model_params.py b/tests/unit/core/configs/params/test_model_params.py index 44108d4d3c..6f70a22f63 100644 --- a/tests/unit/core/configs/params/test_model_params.py +++ b/tests/unit/core/configs/params/test_model_params.py @@ -5,6 +5,7 @@ import pytest from oumi.core.configs.params.model_params import ModelParams +from oumi.exceptions import OumiConfigError def test_post_init_adapter_model_present(): @@ -130,3 +131,50 @@ def test_chat_template_kwargs_custom_assignment(): model_params = ModelParams(chat_template_kwargs={"enable_thinking": False}) assert model_params.chat_template_kwargs is not None assert model_params.chat_template_kwargs["enable_thinking"] is False + + +@patch("oumi.core.configs.params.model_params.find_adapter_config_file") +def test_adapter_config_file_path_not_a_file(mock_find, tmp_path: Path): + """Raises OumiConfigError when the reported adapter config path is missing. + + When `find_adapter_config_file` returns a path that does not exist on disk, the + subsequent `open()` call in the adapter-config reading branch raises + `FileNotFoundError`, which is wrapped as `OumiConfigError` with a + "Failed to read adapter config" message. + """ + mock_find.return_value = str(tmp_path / "ghost_adapter_config.json") + + params = ModelParams(model_name=str(tmp_path)) + with pytest.raises( + OumiConfigError, + match="Failed to read adapter config", + ): + params.finalize_and_validate() + + +@patch("oumi.core.configs.params.model_params.find_adapter_config_file") +def test_adapter_config_read_oserror(mock_find, tmp_path: Path): + """Test OSError reading adapter_config.json is re-raised as OumiConfigError.""" + adapter_path = tmp_path / "adapter_config.json" + adapter_path.write_text('{"base_model_name_or_path": "base_model"}') + mock_find.return_value = str(adapter_path) + + params = ModelParams(model_name=str(tmp_path)) + with patch("builtins.open", side_effect=OSError("Permission denied")): + with pytest.raises( + OumiConfigError, + match="Failed to read adapter config", + ): + params.finalize_and_validate() + + +def test_adapter_config_invalid_json(tmp_path: Path): + """Test malformed JSON raises OumiConfigError with location info.""" + (tmp_path / "adapter_config.json").write_text("{not: valid json!!}") + + params = ModelParams(model_name=str(tmp_path)) + with pytest.raises(OumiConfigError, match="contains invalid JSON") as exc_info: + params.finalize_and_validate() + + assert "line" in str(exc_info.value) + assert "col" in str(exc_info.value) diff --git a/tests/unit/core/configs/test_base_config.py b/tests/unit/core/configs/test_base_config.py index fb92e846c3..5299af41e8 100644 --- a/tests/unit/core/configs/test_base_config.py +++ b/tests/unit/core/configs/test_base_config.py @@ -5,9 +5,15 @@ from pathlib import Path from typing import Any +import pytest from omegaconf import OmegaConf -from oumi.core.configs.base_config import BaseConfig, _handle_non_primitives +from oumi.core.configs.base_config import ( + BaseConfig, + _handle_non_primitives, + _read_config_without_interpolation, +) +from oumi.exceptions import OumiConfigError class TestEnum(Enum): @@ -331,3 +337,101 @@ def test_config_from_yaml_and_arg_list(): assert new_config.bool_value is False assert new_config.list_value[0] == "override" assert new_config.dict_value["key"] == "override" + + +def test_exception_class_hierarchy(): + """Test that OumiConfigError forms the expected inheritance hierarchy.""" + assert issubclass(OumiConfigError, Exception) + + +def test_read_config_without_interpolation_file_not_found(): + """Test that a non-existent path raises OumiConfigError.""" + with pytest.raises( + OumiConfigError, + match="Config file not found or path is not a file", + ): + _read_config_without_interpolation("/nonexistent/path/config.yaml") + + +def test_read_config_without_interpolation_directory_path(tmp_path: Path): + """Test that a directory path raises OumiConfigError.""" + with pytest.raises( + OumiConfigError, + match="Config file not found or path is not a file", + ): + _read_config_without_interpolation(str(tmp_path)) + + +def test_from_yaml_file_not_found(): + """from_yaml raises OumiConfigError for a missing config path.""" + with pytest.raises( + OumiConfigError, + match="Config file not found or path is not a file", + ): + TestConfig.from_yaml("/nonexistent/path/config.yaml") + + +def test_from_yaml_path_is_directory(tmp_path: Path): + """Test that from_yaml raises OumiConfigError when given a directory.""" + with pytest.raises( + OumiConfigError, + match="Config file not found or path is not a file", + ): + TestConfig.from_yaml(tmp_path) + + +def test_to_yaml_missing_parent_directory(): + """Test that to_yaml raises OumiConfigError when the output directory is missing.""" + config = TestConfig( + str_value="test", + int_value=1, + float_value=1.0, + bool_value=True, + none_value=None, + bytes_value=b"test", + path_value=Path("test/path"), + enum_value=TestEnum.VALUE1, + list_value=[], + dict_value={}, + ) + with pytest.raises( + OumiConfigError, + match="parent directory does not exist or is not a directory", + ): + config.to_yaml("/nonexistent/subdir/out.yaml") + + +def test_from_yaml_file_not_found_no_interpolation(): + """from_yaml with ignore_interpolation=False raises on missing file.""" + with pytest.raises( + OumiConfigError, + match="Config file not found or path is not a file", + ): + TestConfig.from_yaml( + "/nonexistent/path/config.yaml", ignore_interpolation=False + ) + + +def test_from_yaml_and_arg_list_nonexistent_config(): + """from_yaml_and_arg_list raises OumiConfigError if file missing.""" + with pytest.raises( + OumiConfigError, + match="Config file not found or path is not a file", + ): + TestConfig.from_yaml_and_arg_list( + config_path="/nonexistent/path/config.yaml", + arg_list=[], + ) + + +def test_from_yaml_and_arg_list_nonexistent_config_no_interpolation(tmp_path: Path): + """Test OumiConfigError via the ignore_interpolation=False branch.""" + with pytest.raises( + OumiConfigError, + match="Config file not found or path is not a file", + ): + TestConfig.from_yaml_and_arg_list( + config_path=str(tmp_path / "nonexistent.yaml"), + arg_list=[], + ignore_interpolation=False, + )