Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/oumi/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
from oumi.cli.synth import synth
from oumi.cli.train import train
from oumi.cli.tune import tune
from oumi.exceptions import OumiConfigError
from oumi.utils.logging import should_use_rich_logging

_ASCII_LOGO = r"""
Expand Down Expand Up @@ -365,6 +366,9 @@ def run():
telemetry = TelemetryManager.get_instance()
with telemetry.capture_operation(event_name, event_properties):
return app()
except OumiConfigError as e:
CONSOLE.print(f"[red]Error:[/red] {e}")
sys.exit(1)
except Exception as e:
Comment thread
oelachqar marked this conversation as resolved.
tb_str = traceback.format_exc()
CONSOLE.print(tb_str)
Expand Down
2 changes: 2 additions & 0 deletions src/oumi/core/configs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@
from oumi.core.configs.synthesis_config import SynthesisConfig
from oumi.core.configs.training_config import TrainingConfig
from oumi.core.configs.tuning_config import TuningConfig
from oumi.exceptions import OumiConfigError

__all__ = [
"AsyncEvaluationConfig",
Expand Down Expand Up @@ -192,6 +193,7 @@
"MixedPrecisionDtype",
"MixtureStrategy",
"ModelParams",
"OumiConfigError",
"PeftParams",
"PeftSaveMode",
"ProfilerParams",
Expand Down
29 changes: 23 additions & 6 deletions src/oumi/core/configs/base_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from omegaconf import OmegaConf

from oumi.core.configs.params.base_params import BaseParams
from oumi.exceptions import OumiConfigError

T = TypeVar("T", bound="BaseConfig")

Expand Down Expand Up @@ -128,10 +129,15 @@ def _read_config_without_interpolation(config_path: str) -> str:
Returns:
str: The stringified configuration.
"""
with open(config_path) as f:
stringified_config = f.read()
pattern = r"(?<!\\)\$\{" # Matches "${" but not "\${"
stringified_config = re.sub(pattern, "\\${", stringified_config)
try:
with open(config_path) as f:
stringified_config = f.read()
except (FileNotFoundError, IsADirectoryError, NotADirectoryError) as e:
raise OumiConfigError(
f"Config file not found or path is not a file: {config_path}"
) from e
pattern = r"(?<!\\)\$\{" # Matches "${" but not "\${"
stringified_config = re.sub(pattern, "\\${", stringified_config)
return stringified_config


Expand Down Expand Up @@ -165,7 +171,13 @@ def to_yaml(self, config_path: str | Path | StringIO) -> None:
+ "\n".join(f"- {path}" for path in sorted(removed_paths))
)

OmegaConf.save(config=processed_config, f=config_path)
try:
OmegaConf.save(config=processed_config, f=config_path)
except (FileNotFoundError, NotADirectoryError, IsADirectoryError) as e:
raise OumiConfigError(
f"Cannot save config to {config_path}: "
f"parent directory does not exist or is not a directory: {e}"
) from e

@classmethod
def from_yaml(
Expand All @@ -186,7 +198,12 @@ def from_yaml(
stringified_config = _read_config_without_interpolation(str(config_path))
file_config = OmegaConf.create(stringified_config)
else:
file_config = OmegaConf.load(config_path)
try:
file_config = OmegaConf.load(config_path)
except (FileNotFoundError, IsADirectoryError, NotADirectoryError) as e:
raise OumiConfigError(
f"Config file not found or path is not a file: {config_path}"
) from e
config = OmegaConf.to_object(OmegaConf.merge(schema, file_config))
if not isinstance(config, cls):
raise TypeError(f"config is not {cls}")
Expand Down
20 changes: 17 additions & 3 deletions src/oumi/core/configs/params/model_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
from transformers.utils import find_adapter_config_file, is_flash_attn_2_available

from oumi.core.configs.params.base_params import BaseParams
from oumi.core.types.exceptions import HardwareException
from oumi.exceptions import (
HardwareException,
OumiConfigError,
)
from oumi.utils.logging import logger
from oumi.utils.torch_utils import get_torch_dtype

Expand Down Expand Up @@ -280,8 +283,19 @@ def __finalize_and_validate__(self):
# present, set it to the base model name found in the adapter config,
# if present. Error otherwise.
if len(list(adapter_dir.glob("config.json"))) == 0:
with open(adapter_config_file) as f:
adapter_config = json.load(f)
try:
with open(adapter_config_file) as f:
adapter_config = json.load(f)
except OSError as e:
raise OumiConfigError(
f"Failed to read adapter config at "
f"{adapter_config_file}: {e}"
) from e
except json.JSONDecodeError as e:
raise OumiConfigError(
f"Adapter config at {adapter_config_file} contains invalid "
f"JSON: (line {e.lineno}, col {e.colno}): {e.msg}"
) from e
model_name = adapter_config.get("base_model_name_or_path")
if not model_name:
raise ValueError(
Expand Down
2 changes: 1 addition & 1 deletion src/oumi/core/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
TemplatedMessage,
Type,
)
from oumi.core.types.exceptions import HardwareException
from oumi.exceptions import HardwareException

Comment thread
idoudali marked this conversation as resolved.
__all__ = [
"HardwareException",
Expand Down
13 changes: 11 additions & 2 deletions src/oumi/core/types/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Backward-compatibility shim.

class HardwareException(Exception):
"""An exception thrown for invalid hardware configurations."""
The canonical location for Oumi exceptions is :mod:`oumi.exceptions`. This
module is kept only so that existing callers importing
``HardwareException`` from ``oumi.core.types.exceptions`` continue to work.
New code should import from :mod:`oumi.exceptions` (or, for convenience,
from :mod:`oumi.core.types`) instead.
"""

from oumi.exceptions import HardwareException

__all__ = ["HardwareException"]
27 changes: 27 additions & 0 deletions src/oumi/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Oumi exception hierarchy.

This module is intentionally free of heavy dependencies (torch, transformers, etc.)
so that it can be imported cheaply in lightweight entry-points such as the CLI.
"""


class OumiConfigError(Exception):
"""Base class for all configuration-related errors."""


class HardwareException(Exception):
"""An exception thrown for invalid hardware configurations."""
7 changes: 5 additions & 2 deletions tests/unit/cli/test_cli_launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
DatasetParams,
DatasetSplitParams,
ModelParams,
OumiConfigError,
TrainerType,
TrainingConfig,
TrainingParams,
Expand Down Expand Up @@ -699,7 +700,7 @@ def test_launch_up_job_not_found(
done=True,
state=JobState.SUCCEEDED,
)
with pytest.raises(FileNotFoundError) as exception_info:
with pytest.raises(OumiConfigError) as exception_info:
res = runner.invoke(
app,
[
Expand All @@ -710,7 +711,9 @@ def test_launch_up_job_not_found(
)
if res.exception:
raise res.exception
assert "No such file or directory" in str(exception_info.value)
assert "Config file not found or path is not a file" in str(
exception_info.value
)


def test_launch_run_job(
Expand Down
48 changes: 48 additions & 0 deletions tests/unit/core/configs/params/test_model_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest

from oumi.core.configs.params.model_params import ModelParams
from oumi.exceptions import OumiConfigError


def test_post_init_adapter_model_present():
Expand Down Expand Up @@ -130,3 +131,50 @@ def test_chat_template_kwargs_custom_assignment():
model_params = ModelParams(chat_template_kwargs={"enable_thinking": False})
assert model_params.chat_template_kwargs is not None
assert model_params.chat_template_kwargs["enable_thinking"] is False


@patch("oumi.core.configs.params.model_params.find_adapter_config_file")
def test_adapter_config_file_path_not_a_file(mock_find, tmp_path: Path):
"""Raises OumiConfigError when the reported adapter config path is missing.

When `find_adapter_config_file` returns a path that does not exist on disk, the
subsequent `open()` call in the adapter-config reading branch raises
`FileNotFoundError`, which is wrapped as `OumiConfigError` with a
"Failed to read adapter config" message.
"""
mock_find.return_value = str(tmp_path / "ghost_adapter_config.json")

params = ModelParams(model_name=str(tmp_path))
with pytest.raises(
OumiConfigError,
match="Failed to read adapter config",
):
params.finalize_and_validate()


@patch("oumi.core.configs.params.model_params.find_adapter_config_file")
def test_adapter_config_read_oserror(mock_find, tmp_path: Path):
"""Test OSError reading adapter_config.json is re-raised as OumiConfigError."""
adapter_path = tmp_path / "adapter_config.json"
adapter_path.write_text('{"base_model_name_or_path": "base_model"}')
mock_find.return_value = str(adapter_path)

params = ModelParams(model_name=str(tmp_path))
with patch("builtins.open", side_effect=OSError("Permission denied")):
with pytest.raises(
OumiConfigError,
match="Failed to read adapter config",
):
params.finalize_and_validate()


def test_adapter_config_invalid_json(tmp_path: Path):
"""Test malformed JSON raises OumiConfigError with location info."""
(tmp_path / "adapter_config.json").write_text("{not: valid json!!}")

params = ModelParams(model_name=str(tmp_path))
with pytest.raises(OumiConfigError, match="contains invalid JSON") as exc_info:
params.finalize_and_validate()

assert "line" in str(exc_info.value)
assert "col" in str(exc_info.value)
106 changes: 105 additions & 1 deletion tests/unit/core/configs/test_base_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,15 @@
from pathlib import Path
from typing import Any

import pytest
from omegaconf import OmegaConf

from oumi.core.configs.base_config import BaseConfig, _handle_non_primitives
from oumi.core.configs.base_config import (
BaseConfig,
_handle_non_primitives,
_read_config_without_interpolation,
)
from oumi.exceptions import OumiConfigError


class TestEnum(Enum):
Expand Down Expand Up @@ -331,3 +337,101 @@ def test_config_from_yaml_and_arg_list():
assert new_config.bool_value is False
assert new_config.list_value[0] == "override"
assert new_config.dict_value["key"] == "override"


def test_exception_class_hierarchy():
"""Test that OumiConfigError forms the expected inheritance hierarchy."""
assert issubclass(OumiConfigError, Exception)


def test_read_config_without_interpolation_file_not_found():
"""Test that a non-existent path raises OumiConfigError."""
with pytest.raises(
OumiConfigError,
match="Config file not found or path is not a file",
):
_read_config_without_interpolation("/nonexistent/path/config.yaml")


def test_read_config_without_interpolation_directory_path(tmp_path: Path):
"""Test that a directory path raises OumiConfigError."""
with pytest.raises(
OumiConfigError,
match="Config file not found or path is not a file",
):
_read_config_without_interpolation(str(tmp_path))


def test_from_yaml_file_not_found():
"""from_yaml raises OumiConfigError for a missing config path."""
with pytest.raises(
OumiConfigError,
match="Config file not found or path is not a file",
):
TestConfig.from_yaml("/nonexistent/path/config.yaml")


def test_from_yaml_path_is_directory(tmp_path: Path):
"""Test that from_yaml raises OumiConfigError when given a directory."""
with pytest.raises(
OumiConfigError,
match="Config file not found or path is not a file",
):
TestConfig.from_yaml(tmp_path)


def test_to_yaml_missing_parent_directory():
"""Test that to_yaml raises OumiConfigError when the output directory is missing."""
config = TestConfig(
str_value="test",
int_value=1,
float_value=1.0,
bool_value=True,
none_value=None,
bytes_value=b"test",
path_value=Path("test/path"),
enum_value=TestEnum.VALUE1,
list_value=[],
dict_value={},
)
with pytest.raises(
OumiConfigError,
match="parent directory does not exist or is not a directory",
):
config.to_yaml("/nonexistent/subdir/out.yaml")


def test_from_yaml_file_not_found_no_interpolation():
"""from_yaml with ignore_interpolation=False raises on missing file."""
with pytest.raises(
OumiConfigError,
match="Config file not found or path is not a file",
):
TestConfig.from_yaml(
"/nonexistent/path/config.yaml", ignore_interpolation=False
)


def test_from_yaml_and_arg_list_nonexistent_config():
"""from_yaml_and_arg_list raises OumiConfigError if file missing."""
with pytest.raises(
OumiConfigError,
match="Config file not found or path is not a file",
):
TestConfig.from_yaml_and_arg_list(
config_path="/nonexistent/path/config.yaml",
arg_list=[],
)


def test_from_yaml_and_arg_list_nonexistent_config_no_interpolation(tmp_path: Path):
"""Test OumiConfigError via the ignore_interpolation=False branch."""
with pytest.raises(
OumiConfigError,
match="Config file not found or path is not a file",
):
TestConfig.from_yaml_and_arg_list(
config_path=str(tmp_path / "nonexistent.yaml"),
arg_list=[],
ignore_interpolation=False,
)
Loading