From 7b4aa70b1c5a6551188e4ba168d0be09df04676d Mon Sep 17 00:00:00 2001 From: Ioannis Doudalis Date: Sat, 28 Mar 2026 10:26:02 +0000 Subject: [PATCH] Replace config ValueError with OumiConfigValueError Use OumiConfigValueError for validation failures under core/configs so the CLI treats them as OumiConfigError and prints a short message instead of a full traceback. OumiConfigValueError subclasses both OumiConfigError and ValueError for backward compatibility with existing except ValueError tests. Update BaseConfig/BaseParams docstrings to recommend OumiConfigValueError. Fixes OPE-1853 --- src/oumi/core/configs/analyze_config.py | 9 +- .../core/configs/async_evaluation_config.py | 5 +- src/oumi/core/configs/base_config.py | 4 +- src/oumi/core/configs/judge_config.py | 5 +- src/oumi/core/configs/params/base_params.py | 4 +- src/oumi/core/configs/params/data_params.py | 33 +++-- .../core/configs/params/deepspeed_params.py | 5 +- .../core/configs/params/evaluation_params.py | 15 +- src/oumi/core/configs/params/fsdp_params.py | 7 +- .../core/configs/params/generation_params.py | 15 +- src/oumi/core/configs/params/gkd_params.py | 11 +- src/oumi/core/configs/params/gold_params.py | 31 ++-- src/oumi/core/configs/params/grpo_params.py | 9 +- .../configs/params/guided_decoding_params.py | 3 +- src/oumi/core/configs/params/judge_params.py | 15 +- src/oumi/core/configs/params/model_params.py | 9 +- src/oumi/core/configs/params/peft_params.py | 3 +- .../core/configs/params/profiler_params.py | 5 +- src/oumi/core/configs/params/remote_params.py | 59 +++++--- .../core/configs/params/rule_judge_params.py | 17 ++- .../core/configs/params/synthesis_params.py | 133 +++++++++++------- src/oumi/core/configs/params/test_params.py | 17 +-- .../core/configs/params/training_params.py | 25 ++-- src/oumi/core/configs/params/tuning_params.py | 31 ++-- src/oumi/core/configs/quantization_config.py | 5 +- src/oumi/core/configs/synthesis_config.py | 13 +- src/oumi/core/configs/training_config.py | 19 +-- src/oumi/exceptions.py | 3 + 28 files changed, 303 insertions(+), 207 deletions(-) diff --git a/src/oumi/core/configs/analyze_config.py b/src/oumi/core/configs/analyze_config.py index fa9788a8fb..1ac8d64e35 100644 --- a/src/oumi/core/configs/analyze_config.py +++ b/src/oumi/core/configs/analyze_config.py @@ -21,6 +21,7 @@ from oumi.core.configs.base_config import BaseConfig from oumi.core.configs.params.base_params import BaseParams +from oumi.exceptions import OumiConfigValueError class DatasetSource(Enum): @@ -167,14 +168,16 @@ def __post_init__(self): # Validate sample_count if self.sample_count is not None and self.sample_count <= 0: - raise ValueError("`sample_count` must be greater than 0.") + raise OumiConfigValueError("`sample_count` must be greater than 0.") # Validate analyzer configurations analyzer_ids = set() for analyzer in self.analyzers: # Validate analyzer ID if not analyzer.id: - raise ValueError("Analyzer 'id' must be provided") + raise OumiConfigValueError("Analyzer 'id' must be provided") if analyzer.id in analyzer_ids: - raise ValueError(f"Duplicate analyzer ID found: '{analyzer.id}'") + raise OumiConfigValueError( + f"Duplicate analyzer ID found: '{analyzer.id}'" + ) analyzer_ids.add(analyzer.id) diff --git a/src/oumi/core/configs/async_evaluation_config.py b/src/oumi/core/configs/async_evaluation_config.py index fcc3b97501..9826ea024d 100644 --- a/src/oumi/core/configs/async_evaluation_config.py +++ b/src/oumi/core/configs/async_evaluation_config.py @@ -18,6 +18,7 @@ from oumi.core.configs.base_config import BaseConfig from oumi.core.configs.evaluation_config import EvaluationConfig +from oumi.exceptions import OumiConfigValueError @dataclass @@ -48,6 +49,6 @@ class AsyncEvaluationConfig(BaseConfig): def __post_init__(self): """Verifies/populates params.""" if self.polling_interval < 0: - raise ValueError("`polling_interval` must be non-negative.") + raise OumiConfigValueError("`polling_interval` must be non-negative.") if self.num_retries < 0: - raise ValueError("`num_retries` must be non-negative.") + raise OumiConfigValueError("`num_retries` must be non-negative.") diff --git a/src/oumi/core/configs/base_config.py b/src/oumi/core/configs/base_config.py index e97b90810e..0cda034f3f 100644 --- a/src/oumi/core/configs/base_config.py +++ b/src/oumi/core/configs/base_config.py @@ -334,8 +334,8 @@ def __finalize_and_validate__(self) -> None: This method can be overridden by subclasses to implement custom validation logic. - In case of validation errors, this method should raise a `ValueError` - or other appropriate exception. + In case of validation errors, this method should raise + `OumiConfigValueError` or another appropriate exception. """ def __iter__(self) -> Iterator[tuple[str, Any]]: diff --git a/src/oumi/core/configs/judge_config.py b/src/oumi/core/configs/judge_config.py index 6beda4ee19..885fc6b813 100644 --- a/src/oumi/core/configs/judge_config.py +++ b/src/oumi/core/configs/judge_config.py @@ -25,6 +25,7 @@ from oumi.core.configs.inference_config import InferenceConfig from oumi.core.configs.params.judge_params import JudgeParams from oumi.core.configs.params.rule_judge_params import RuleJudgeParams +from oumi.exceptions import OumiConfigValueError JUDGE_CONFIG_REPO_PATH_TEMPLATE = "oumi://configs/projects/judges/{path}.yaml" @@ -109,14 +110,14 @@ def _resolve_path(unresolved_path: str) -> str | None: try: return cls.from_yaml_and_arg_list(resolved_path, extra_args) except Exception as e: - raise ValueError( + raise OumiConfigValueError( f"Failed to parse {resolved_path} as JudgeConfig. " f"Please ensure the YAML file contains both 'judge_params' and " f"'inference_config' sections with valid fields. " f"Original error: {e}" ) from e else: - raise ValueError( + raise OumiConfigValueError( f"Could not resolve JudgeConfig from path: {path}. " "Please provide a valid local or GitHub repo path." ) diff --git a/src/oumi/core/configs/params/base_params.py b/src/oumi/core/configs/params/base_params.py index 3dc8e1ee16..5dea0551a0 100644 --- a/src/oumi/core/configs/params/base_params.py +++ b/src/oumi/core/configs/params/base_params.py @@ -42,8 +42,8 @@ def __finalize_and_validate__(self) -> None: This method can be overridden by subclasses to implement custom validation logic. - In case of validation errors, this method should raise a `ValueError` - or other appropriate exception. + In case of validation errors, this method should raise + `OumiConfigValueError` or another appropriate exception. """ def __iter__(self) -> Iterator[tuple[str, Any]]: diff --git a/src/oumi/core/configs/params/data_params.py b/src/oumi/core/configs/params/data_params.py index 30f11a39c2..d25fc6c4c0 100644 --- a/src/oumi/core/configs/params/data_params.py +++ b/src/oumi/core/configs/params/data_params.py @@ -21,6 +21,7 @@ from omegaconf import MISSING from oumi.core.configs.params.base_params import BaseParams +from oumi.exceptions import OumiConfigValueError # Training Params @@ -49,7 +50,7 @@ def get_literal_value(self) -> Literal["first_exhausted", "all_exhausted"]: elif self.value == MixtureStrategy.ALL_EXHAUSTED: return "all_exhausted" else: - raise ValueError("Unsupported value for MixtureStrategy") + raise OumiConfigValueError("Unsupported value for MixtureStrategy") @dataclass @@ -149,24 +150,28 @@ def __post_init__(self): """Verifies params.""" if self.sample_count is not None: if self.sample_count < 0: - raise ValueError("`sample_count` must be greater than 0.") + raise OumiConfigValueError("`sample_count` must be greater than 0.") if self.mixture_proportion is not None: if self.mixture_proportion < 0: - raise ValueError("`mixture_proportion` must be greater than 0.") + raise OumiConfigValueError( + "`mixture_proportion` must be greater than 0." + ) if self.mixture_proportion > 1: - raise ValueError("`mixture_proportion` must not be greater than 1.0 .") + raise OumiConfigValueError( + "`mixture_proportion` must not be greater than 1.0 ." + ) if self.transform_num_workers is not None: if isinstance(self.transform_num_workers, str): if not (self.transform_num_workers == "auto"): - raise ValueError( + raise OumiConfigValueError( "Unknown value of transform_num_workers: " f"{self.transform_num_workers}. Must be 'auto' if string." ) elif (not isinstance(self.transform_num_workers, int)) or ( self.transform_num_workers <= 0 ): - raise ValueError( + raise OumiConfigValueError( "Non-positive value of transform_num_workers: " f"{self.transform_num_workers}." ) @@ -176,7 +181,7 @@ def __post_init__(self): self.dataset_kwargs.keys() ) if len(conflicting_keys) > 0: - raise ValueError( + raise OumiConfigValueError( "dataset_kwargs attempts to override the following " f"reserved fields: {conflicting_keys}. " "Use properties of DatasetParams instead." @@ -270,7 +275,7 @@ def __post_init__(self): if not all( [dataset.mixture_proportion is not None for dataset in self.datasets] ): - raise ValueError( + raise OumiConfigValueError( "If `mixture_proportion` is specified it must be " " specified for all datasets" ) @@ -278,7 +283,7 @@ def __post_init__(self): filter(None, [dataset.mixture_proportion for dataset in self.datasets]) ) if not self._is_sum_normalized(mix_sum): - raise ValueError( + raise OumiConfigValueError( "The sum of `mixture_proportion` must be 1.0. " f"The current sum is {mix_sum} ." ) @@ -286,7 +291,7 @@ def __post_init__(self): self.mixture_strategy != MixtureStrategy.ALL_EXHAUSTED and self.mixture_strategy != MixtureStrategy.FIRST_EXHAUSTED ): - raise ValueError( + raise OumiConfigValueError( "`mixture_strategy` must be one of " f'["{MixtureStrategy.FIRST_EXHAUSTED.value}", ' f'"{MixtureStrategy.ALL_EXHAUSTED.value}"].' @@ -324,12 +329,12 @@ def get_split(self, split: DatasetSplit) -> DatasetSplitParams: elif split == DatasetSplit.VALIDATION: return self.validation else: - raise ValueError(f"Received invalid split: {split}.") + raise OumiConfigValueError(f"Received invalid split: {split}.") def __finalize_and_validate__(self): """Verifies params.""" if len(self.train.datasets) == 0: - raise ValueError("At least one training dataset is required.") + raise OumiConfigValueError("At least one training dataset is required.") all_collators = set() if self.train.collator_name: @@ -339,11 +344,11 @@ def __finalize_and_validate__(self): if self.test.collator_name: all_collators.add(self.test.collator_name) if len(all_collators) >= 2: - raise ValueError( + raise OumiConfigValueError( f"Different data collators are not supported yet: {all_collators}" ) elif len(all_collators) == 1 and not self.train.collator_name: - raise ValueError( + raise OumiConfigValueError( "Data collator must be also specified " f"on the `train` split: {all_collators}" ) diff --git a/src/oumi/core/configs/params/deepspeed_params.py b/src/oumi/core/configs/params/deepspeed_params.py index 809ac5899c..9348abdc75 100644 --- a/src/oumi/core/configs/params/deepspeed_params.py +++ b/src/oumi/core/configs/params/deepspeed_params.py @@ -18,6 +18,7 @@ from typing import Any from oumi.core.configs.params.base_params import BaseParams +from oumi.exceptions import OumiConfigValueError class ZeRORuntimeStage(str, Enum): @@ -287,7 +288,7 @@ def __post_init__(self) -> None: self.offload_param is not None and self.zero_stage != ZeRORuntimeStage.ZERO_3 ): - raise ValueError( + raise OumiConfigValueError( "Parameter offloading is only supported with ZeRO stage 3. " f"Current stage: {self.zero_stage}" ) @@ -297,7 +298,7 @@ def __post_init__(self) -> None: ZeRORuntimeStage.ZERO_2, ZeRORuntimeStage.ZERO_3, ]: - raise ValueError( + raise OumiConfigValueError( "Optimizer offloading requires ZeRO stage 1, 2, or 3. " f"Current stage: {self.zero_stage}" ) diff --git a/src/oumi/core/configs/params/evaluation_params.py b/src/oumi/core/configs/params/evaluation_params.py index c4626d2b82..202caf6ce2 100644 --- a/src/oumi/core/configs/params/evaluation_params.py +++ b/src/oumi/core/configs/params/evaluation_params.py @@ -17,6 +17,7 @@ from typing import Any from oumi.core.configs.params.base_params import BaseParams +from oumi.exceptions import OumiConfigValueError class EvaluationBackend(Enum): @@ -108,7 +109,7 @@ def my_evaluation(task_params, config): def get_evaluation_backend(self) -> EvaluationBackend: """Returns the evaluation backend as an Enum.""" if not self.evaluation_backend: - raise ValueError( + raise OumiConfigValueError( "Missing `evaluation_backend`. When running evaluations, it is " "necessary to specify the evaluation backend to use for EACH task. " "The available backends can be found in the following enum: " @@ -120,7 +121,9 @@ def get_evaluation_backend(self) -> EvaluationBackend: elif self.evaluation_backend == EvaluationBackend.CUSTOM.value: return EvaluationBackend.CUSTOM else: - raise ValueError(f"Unknown evaluation backend: {self.evaluation_backend}") + raise OumiConfigValueError( + f"Unknown evaluation backend: {self.evaluation_backend}" + ) @staticmethod def list_evaluation_backends() -> str: @@ -130,7 +133,9 @@ def list_evaluation_backends() -> str: def __post_init__(self): """Verifies params.""" if self.num_samples is not None and self.num_samples <= 0: - raise ValueError("`num_samples` must be None or a positive integer.") + raise OumiConfigValueError( + "`num_samples` must be None or a positive integer." + ) @dataclass @@ -152,6 +157,6 @@ class LMHarnessTaskParams(EvaluationTaskParams): def __post_init__(self): """Verifies params.""" if not self.task_name: - raise ValueError("`task_name` must be a valid LM Harness task.") + raise OumiConfigValueError("`task_name` must be a valid LM Harness task.") if self.num_fewshot and self.num_fewshot < 0: - raise ValueError("`num_fewshot` must be non-negative.") + raise OumiConfigValueError("`num_fewshot` must be non-negative.") diff --git a/src/oumi/core/configs/params/fsdp_params.py b/src/oumi/core/configs/params/fsdp_params.py index 91159bc56f..b0f50063d1 100644 --- a/src/oumi/core/configs/params/fsdp_params.py +++ b/src/oumi/core/configs/params/fsdp_params.py @@ -18,6 +18,7 @@ import torch.distributed.fsdp as torch_fsdp from oumi.core.configs.params.base_params import BaseParams +from oumi.exceptions import OumiConfigValueError class ShardingStrategy(str, Enum): @@ -60,7 +61,7 @@ def to_torch(self) -> torch_fsdp.ShardingStrategy: } if self not in strategy_map: - raise ValueError(f"Unsupported sharding strategy: {self}") + raise OumiConfigValueError(f"Unsupported sharding strategy: {self}") return strategy_map[self] @@ -101,7 +102,7 @@ def to_torch(self) -> torch_fsdp.StateDictType: } if self not in state_dict_map: - raise ValueError(f"Unsupported state dict type: {self}") + raise OumiConfigValueError(f"Unsupported state dict type: {self}") return state_dict_map[self] @@ -127,7 +128,7 @@ def to_torch(self) -> torch_fsdp.BackwardPrefetch | None: } if self not in map: - raise ValueError(f"Unsupported backward prefetch option: {self}") + raise OumiConfigValueError(f"Unsupported backward prefetch option: {self}") return map[self] diff --git a/src/oumi/core/configs/params/generation_params.py b/src/oumi/core/configs/params/generation_params.py index caaf5b5fc7..bbfbdf2628 100644 --- a/src/oumi/core/configs/params/generation_params.py +++ b/src/oumi/core/configs/params/generation_params.py @@ -17,6 +17,7 @@ from oumi.core.configs.params.base_params import BaseParams from oumi.core.configs.params.guided_decoding_params import GuidedDecodingParams +from oumi.exceptions import OumiConfigValueError @dataclass @@ -133,27 +134,27 @@ class GenerationParams(BaseParams): def __post_init__(self): """Validates generation-specific parameters.""" if self.batch_size is not None and self.batch_size < 1: - raise ValueError("Batch size must be positive.") + raise OumiConfigValueError("Batch size must be positive.") if self.num_beams < 1: - raise ValueError("num_beams must be strictly larger than 0.") + raise OumiConfigValueError("num_beams must be strictly larger than 0.") if self.temperature < 0: - raise ValueError("Temperature must be non-negative.") + raise OumiConfigValueError("Temperature must be non-negative.") if self.top_p is not None and not 0 <= self.top_p <= 1: - raise ValueError("top_p must be between 0 and 1.") + raise OumiConfigValueError("top_p must be between 0 and 1.") for token_id, bias in self.logit_bias.items(): if not isinstance(token_id, str | int): - raise ValueError( + raise OumiConfigValueError( f"Logit bias token ID {token_id} must be an integer or a string." ) if not -100 <= bias <= 100: - raise ValueError( + raise OumiConfigValueError( f"Logit bias for token {token_id} must be between -100 and 100." ) if not 0 <= self.min_p <= 1: - raise ValueError("min_p must be between 0 and 1.") + raise OumiConfigValueError("min_p must be between 0 and 1.") diff --git a/src/oumi/core/configs/params/gkd_params.py b/src/oumi/core/configs/params/gkd_params.py index 82c7b90e4c..0c13585f56 100644 --- a/src/oumi/core/configs/params/gkd_params.py +++ b/src/oumi/core/configs/params/gkd_params.py @@ -17,6 +17,7 @@ from typing import Any from oumi.core.configs.params.base_params import BaseParams +from oumi.exceptions import OumiConfigValueError @dataclass @@ -106,7 +107,7 @@ def __post_init__(self): f"Actual type: {type(self.teacher_model_name_or_path)}" ) if not self.teacher_model_name_or_path.strip(): - raise ValueError( + raise OumiConfigValueError( "GkdParams.teacher_model_name_or_path cannot be empty." ) @@ -115,23 +116,23 @@ def __post_init__(self): and self.temperature > 0.0 and self.temperature <= 1.0 ): - raise ValueError( + raise OumiConfigValueError( "GkdParams.temperature must be in range (0.0, 1.0]. " f"Actual: {self.temperature}" ) if not (math.isfinite(self.lmbda) and 0.0 <= self.lmbda <= 1.0): - raise ValueError( + raise OumiConfigValueError( f"GkdParams.lmbda must be in range [0.0, 1.0]. Actual: {self.lmbda}" ) if not (math.isfinite(self.beta) and 0.0 <= self.beta <= 1.0): - raise ValueError( + raise OumiConfigValueError( f"GkdParams.beta must be in range [0.0, 1.0]. Actual: {self.beta}" ) if self.max_new_tokens <= 0: - raise ValueError( + raise OumiConfigValueError( "GkdParams.max_new_tokens must be positive. " f"Actual: {self.max_new_tokens}" ) diff --git a/src/oumi/core/configs/params/gold_params.py b/src/oumi/core/configs/params/gold_params.py index 266930a56f..64b92c716f 100644 --- a/src/oumi/core/configs/params/gold_params.py +++ b/src/oumi/core/configs/params/gold_params.py @@ -17,6 +17,7 @@ from typing import Any from oumi.core.configs.params.base_params import BaseParams +from oumi.exceptions import OumiConfigValueError @dataclass @@ -302,7 +303,7 @@ def __post_init__(self): f"Actual type: {type(self.teacher_model_name_or_path)}" ) if not self.teacher_model_name_or_path.strip(): - raise ValueError( + raise OumiConfigValueError( "GoldParams.teacher_model_name_or_path cannot be empty." ) @@ -311,33 +312,33 @@ def __post_init__(self): and self.temperature > 0.0 and self.temperature <= 1.0 ): - raise ValueError( + raise OumiConfigValueError( "GoldParams.temperature must be in range (0.0, 1.0]. " f"Actual: {self.temperature}" ) if not (math.isfinite(self.top_p) and 0.0 < self.top_p <= 1.0): - raise ValueError( + raise OumiConfigValueError( f"GoldParams.top_p must be in range (0.0, 1.0]. Actual: {self.top_p}" ) if self.top_k < 0: - raise ValueError( + raise OumiConfigValueError( f"GoldParams.top_k must be non-negative. Actual: {self.top_k}" ) if not (math.isfinite(self.lmbda) and 0.0 <= self.lmbda <= 1.0): - raise ValueError( + raise OumiConfigValueError( f"GoldParams.lmbda must be in range [0.0, 1.0]. Actual: {self.lmbda}" ) if not (math.isfinite(self.beta) and 0.0 <= self.beta <= 1.0): - raise ValueError( + raise OumiConfigValueError( f"GoldParams.beta must be in range [0.0, 1.0]. Actual: {self.beta}" ) if self.max_completion_length <= 0: - raise ValueError( + raise OumiConfigValueError( "GoldParams.max_completion_length must be positive. " f"Actual: {self.max_completion_length}" ) @@ -345,25 +346,25 @@ def __post_init__(self): # Validate ULD parameters if self.use_uld_loss: if self.uld_crossentropy_weight < 0.0: - raise ValueError( + raise OumiConfigValueError( "GoldParams.uld_crossentropy_weight must be non-negative. " f"Actual: {self.uld_crossentropy_weight}" ) if self.uld_distillation_weight < 0.0: - raise ValueError( + raise OumiConfigValueError( "GoldParams.uld_distillation_weight must be non-negative. " f"Actual: {self.uld_distillation_weight}" ) if self.uld_student_temperature <= 0.0: - raise ValueError( + raise OumiConfigValueError( "GoldParams.uld_student_temperature must be positive. " f"Actual: {self.uld_student_temperature}" ) if self.uld_teacher_temperature <= 0.0: - raise ValueError( + raise OumiConfigValueError( "GoldParams.uld_teacher_temperature must be positive. " f"Actual: {self.uld_teacher_temperature}" ) @@ -373,7 +374,7 @@ def __post_init__(self): if (self.uld_hybrid_matched_weight is None) != ( self.uld_hybrid_unmatched_weight is None ): - raise ValueError( + raise OumiConfigValueError( "GoldParams.uld_hybrid_matched_weight and " "uld_hybrid_unmatched_weight must both be None (for adaptive " "weighting) or both be set to numeric values. " @@ -388,18 +389,18 @@ def __post_init__(self): and self.uld_hybrid_unmatched_weight is not None ): if self.uld_hybrid_matched_weight < 0.0: - raise ValueError( + raise OumiConfigValueError( "GoldParams.uld_hybrid_matched_weight must be " f"non-negative. Actual: {self.uld_hybrid_matched_weight}" ) if self.uld_hybrid_unmatched_weight < 0.0: - raise ValueError( + raise OumiConfigValueError( "GoldParams.uld_hybrid_unmatched_weight must be " f"non-negative. Actual: {self.uld_hybrid_unmatched_weight}" ) if self.vllm_mode not in ("server", "colocate"): - raise ValueError( + raise OumiConfigValueError( f"GoldParams.vllm_mode must be 'server' or 'colocate'. " f"Actual: {self.vllm_mode}" ) diff --git a/src/oumi/core/configs/params/grpo_params.py b/src/oumi/core/configs/params/grpo_params.py index 806a609537..f050e9a7de 100644 --- a/src/oumi/core/configs/params/grpo_params.py +++ b/src/oumi/core/configs/params/grpo_params.py @@ -17,6 +17,7 @@ from typing import Any from oumi.core.configs.params.base_params import BaseParams +from oumi.exceptions import OumiConfigValueError @dataclass @@ -112,17 +113,17 @@ class GrpoParams(BaseParams): def __post_init__(self): """Verifies params.""" if self.max_prompt_length is not None and self.max_prompt_length <= 0: - raise ValueError( + raise OumiConfigValueError( "GrpoParams.max_prompt_length must be positive. " f"Actual: {self.max_prompt_length}" ) if self.max_completion_length is not None and self.max_completion_length <= 0: - raise ValueError( + raise OumiConfigValueError( "GrpoParams.max_completion_length must be positive. " f"Actual: {self.max_completion_length}" ) if self.num_generations is not None and self.num_generations <= 0: - raise ValueError( + raise OumiConfigValueError( "GrpoParams.num_generations must be positive. " f"Actual: {self.num_generations}" ) @@ -131,7 +132,7 @@ def __post_init__(self): and self.temperature >= 0.0 and self.temperature <= 1.0 ): - raise ValueError( + raise OumiConfigValueError( "GrpoParams.temperature must be within [0.0, 1.0] range. " f"Actual: {self.temperature}" ) diff --git a/src/oumi/core/configs/params/guided_decoding_params.py b/src/oumi/core/configs/params/guided_decoding_params.py index 4973d30079..19ecbcedc1 100644 --- a/src/oumi/core/configs/params/guided_decoding_params.py +++ b/src/oumi/core/configs/params/guided_decoding_params.py @@ -16,6 +16,7 @@ from typing import Any from oumi.core.configs.params.base_params import BaseParams +from oumi.exceptions import OumiConfigValueError @dataclasses.dataclass @@ -52,6 +53,6 @@ def __post_init__(self) -> None: """Validate parameters.""" provided = sum(x is not None for x in [self.json, self.regex, self.choice]) if provided > 1: - raise ValueError( + raise OumiConfigValueError( "Only one of 'json', 'regex', or 'choice' can be specified" ) diff --git a/src/oumi/core/configs/params/judge_params.py b/src/oumi/core/configs/params/judge_params.py index cedc60104d..f7bd35f9d1 100644 --- a/src/oumi/core/configs/params/judge_params.py +++ b/src/oumi/core/configs/params/judge_params.py @@ -16,6 +16,7 @@ from enum import Enum from oumi.core.configs.params.base_params import BaseParams +from oumi.exceptions import OumiConfigValueError from oumi.utils.placeholders import get_placeholders, resolve_placeholders @@ -155,11 +156,13 @@ def _validate_params(self): """ # Validate prompt template is not empty if not self.prompt_template.strip(): - raise ValueError("prompt_template cannot be empty") + raise OumiConfigValueError("prompt_template cannot be empty") # Validate judgment scores for ENUM judgment type if self.judgment_type == JudgeOutputType.ENUM and not self.judgment_scores: - raise ValueError("judgment_scores must be provided for ENUM judgment_type") + raise OumiConfigValueError( + "judgment_scores must be provided for ENUM judgment_type" + ) # Validate judgment scores are numeric if provided if self.judgment_scores: @@ -167,16 +170,18 @@ def _validate_params(self): isinstance(score, int | float) for score in self.judgment_scores.values() ): - raise ValueError("All judgment_scores values must be numeric") + raise OumiConfigValueError("All judgment_scores values must be numeric") if not self.judgment_scores: - raise ValueError("judgment_scores cannot be empty when provided") + raise OumiConfigValueError( + "judgment_scores cannot be empty when provided" + ) # Validate prompt_template_placeholders if self.prompt_template_placeholders: actual_placeholders = self.get_placeholders() declared_placeholders = set(self.prompt_template_placeholders) if declared_placeholders != actual_placeholders: - raise ValueError( + raise OumiConfigValueError( f"prompt_template_placeholders ({declared_placeholders}) are " "inconsistent with placeholders found in the prompt_template " f"({actual_placeholders})" diff --git a/src/oumi/core/configs/params/model_params.py b/src/oumi/core/configs/params/model_params.py index 350846824d..7e3e9e5cd6 100644 --- a/src/oumi/core/configs/params/model_params.py +++ b/src/oumi/core/configs/params/model_params.py @@ -25,6 +25,7 @@ from oumi.exceptions import ( HardwareException, OumiConfigError, + OumiConfigValueError, ) from oumi.utils.logging import logger from oumi.utils.torch_utils import get_torch_dtype @@ -237,7 +238,7 @@ def __post_init__(self): self.processor_kwargs.keys() ) if len(conflicting_keys) > 0: - raise ValueError( + raise OumiConfigValueError( "processor_kwargs attempts to override the following " f"reserved fields: {conflicting_keys}. " "Use properties of ModelParams instead." @@ -298,7 +299,7 @@ def __finalize_and_validate__(self): ) from e model_name = adapter_config.get("base_model_name_or_path") if not model_name: - raise ValueError( + raise OumiConfigValueError( "`model_name` specifies an adapter model only," " but the base model could not be found!" ) @@ -318,4 +319,6 @@ def __finalize_and_validate__(self): ) if self.model_max_length is not None and self.model_max_length <= 0: - raise ValueError("model_max_length must be a positive integer or None.") + raise OumiConfigValueError( + "model_max_length must be a positive integer or None." + ) diff --git a/src/oumi/core/configs/params/peft_params.py b/src/oumi/core/configs/params/peft_params.py index fdb7d3e18f..a409a33787 100644 --- a/src/oumi/core/configs/params/peft_params.py +++ b/src/oumi/core/configs/params/peft_params.py @@ -21,6 +21,7 @@ from transformers import BitsAndBytesConfig from oumi.core.configs.params.base_params import BaseParams +from oumi.exceptions import OumiConfigValueError class PeftSaveMode(Enum): @@ -89,7 +90,7 @@ def get_literal_value( "loftq", "olora", }: - raise ValueError(f"Invalid enum value: {self.value}") + raise OumiConfigValueError(f"Invalid enum value: {self.value}") return self.value diff --git a/src/oumi/core/configs/params/profiler_params.py b/src/oumi/core/configs/params/profiler_params.py index 170378e8f3..f3fb0b3a61 100644 --- a/src/oumi/core/configs/params/profiler_params.py +++ b/src/oumi/core/configs/params/profiler_params.py @@ -15,6 +15,7 @@ from dataclasses import dataclass, field from oumi.core.configs.params.base_params import BaseParams +from oumi.exceptions import OumiConfigValueError @dataclass @@ -73,13 +74,13 @@ def __post_init__(self): and self.repeat >= 0 and self.skip_first >= 0 ): - raise ValueError( + raise OumiConfigValueError( "Invalid profiler schedule arguments. The parameters " "wait: {self.wait}, warmup: {self.warmup}, repeat: {self.repeat}" "skip_first: {self.skip_first} must be non-negative." ) if not (self.active > 0): - raise ValueError( + raise OumiConfigValueError( "Invalid profiler schedule arguments. The parameter " "active: {self.active} must be positive." ) diff --git a/src/oumi/core/configs/params/remote_params.py b/src/oumi/core/configs/params/remote_params.py index ef029027d3..96f81539a7 100644 --- a/src/oumi/core/configs/params/remote_params.py +++ b/src/oumi/core/configs/params/remote_params.py @@ -17,6 +17,7 @@ import numpy as np from oumi.core.configs.params.base_params import BaseParams +from oumi.exceptions import OumiConfigValueError @dataclass @@ -126,37 +127,45 @@ class RemoteParams(BaseParams): def __post_init__(self): """Validate the remote parameters.""" if self.num_workers < 1: - raise ValueError( + raise OumiConfigValueError( "Number of num_workers must be greater than or equal to 1." ) if self.requests_per_minute is not None and self.requests_per_minute < 1: - raise ValueError("requests_per_minute must be greater than or equal to 1.") + raise OumiConfigValueError( + "requests_per_minute must be greater than or equal to 1." + ) if ( self.input_tokens_per_minute is not None and self.input_tokens_per_minute < 1 ): - raise ValueError( + raise OumiConfigValueError( "input_tokens_per_minute must be greater than or equal to 1." ) if ( self.output_tokens_per_minute is not None and self.output_tokens_per_minute < 1 ): - raise ValueError( + raise OumiConfigValueError( "output_tokens_per_minute must be greater than or equal to 1." ) if self.politeness_policy < 0: - raise ValueError("Politeness policy must be greater than or equal to 0.") + raise OumiConfigValueError( + "Politeness policy must be greater than or equal to 0." + ) if self.connection_timeout < 0: - raise ValueError("Connection timeout must be greater than or equal to 0.") + raise OumiConfigValueError( + "Connection timeout must be greater than or equal to 0." + ) if not np.isfinite(self.politeness_policy): - raise ValueError("Politeness policy must be finite.") + raise OumiConfigValueError("Politeness policy must be finite.") if self.max_retries < 0: - raise ValueError("Max retries must be greater than or equal to 0.") + raise OumiConfigValueError( + "Max retries must be greater than or equal to 0." + ) if self.retry_backoff_base <= 0: - raise ValueError("Retry backoff base must be greater than 0.") + raise OumiConfigValueError("Retry backoff base must be greater than 0.") if self.retry_backoff_max < self.retry_backoff_base: - raise ValueError( + raise OumiConfigValueError( "Retry backoff max must be greater than or equal to retry backoff base." ) @@ -237,24 +246,34 @@ class AdaptiveConcurrencyParams(BaseParams): def __post_init__(self): """Validate the adaptive concurrency parameters.""" if self.min_concurrency < 1: - raise ValueError("Min concurrency must be greater than or equal to 1.") + raise OumiConfigValueError( + "Min concurrency must be greater than or equal to 1." + ) if self.max_concurrency < self.min_concurrency: - raise ValueError( + raise OumiConfigValueError( "Max concurrency must be greater than or equal to min concurrency." ) if self.initial_concurrency_factor < 0 or self.initial_concurrency_factor > 1: - raise ValueError("Initial concurrency factor must be between 0 and 1.") + raise OumiConfigValueError( + "Initial concurrency factor must be between 0 and 1." + ) if self.concurrency_step < 1: - raise ValueError("Concurrency step must be greater than or equal to 1.") + raise OumiConfigValueError( + "Concurrency step must be greater than or equal to 1." + ) if self.min_update_time <= 0: - raise ValueError("Min update time must be greater than 0.") + raise OumiConfigValueError("Min update time must be greater than 0.") if self.error_threshold < 0 or self.error_threshold > 1: - raise ValueError("Error threshold must be between 0 and 1.") + raise OumiConfigValueError("Error threshold must be between 0 and 1.") if self.backoff_factor <= 0: - raise ValueError("Backoff factor must be greater than 0.") + raise OumiConfigValueError("Backoff factor must be greater than 0.") if self.recovery_threshold < 0 or self.recovery_threshold > 1: - raise ValueError("Recovery threshold must be between 0 and 1.") + raise OumiConfigValueError("Recovery threshold must be between 0 and 1.") if self.recovery_threshold >= self.error_threshold: - raise ValueError("Recovery threshold must be less than error threshold.") + raise OumiConfigValueError( + "Recovery threshold must be less than error threshold." + ) if self.min_window_size < 1: - raise ValueError("Min window size must be greater than or equal to 1.") + raise OumiConfigValueError( + "Min window size must be greater than or equal to 1." + ) diff --git a/src/oumi/core/configs/params/rule_judge_params.py b/src/oumi/core/configs/params/rule_judge_params.py index 868efeb9d1..b349c98334 100644 --- a/src/oumi/core/configs/params/rule_judge_params.py +++ b/src/oumi/core/configs/params/rule_judge_params.py @@ -20,6 +20,7 @@ JudgeOutputType, JudgeResponseFormat, ) +from oumi.exceptions import OumiConfigValueError @dataclass @@ -82,24 +83,28 @@ def _validate_params(self): ValueError: If parameters are invalid """ if not self.rule_type or not self.rule_type.strip(): - raise ValueError("rule_type cannot be empty") + raise OumiConfigValueError("rule_type cannot be empty") if not self.input_fields: - raise ValueError("input_fields cannot be empty") + raise OumiConfigValueError("input_fields cannot be empty") if not all( isinstance(field, str) and field.strip() for field in self.input_fields ): - raise ValueError("All input_fields must be non-empty strings") + raise OumiConfigValueError("All input_fields must be non-empty strings") if self.judgment_type == JudgeOutputType.ENUM and not self.judgment_scores: - raise ValueError("judgment_scores must be provided for ENUM judgment_type") + raise OumiConfigValueError( + "judgment_scores must be provided for ENUM judgment_type" + ) if self.judgment_scores: if not all( isinstance(score, int | float) for score in self.judgment_scores.values() ): - raise ValueError("All judgment_scores values must be numeric") + raise OumiConfigValueError("All judgment_scores values must be numeric") if len(self.judgment_scores) == 0: - raise ValueError("judgment_scores cannot be empty when provided") + raise OumiConfigValueError( + "judgment_scores cannot be empty when provided" + ) diff --git a/src/oumi/core/configs/params/synthesis_params.py b/src/oumi/core/configs/params/synthesis_params.py index 8de36f75ca..744d604835 100644 --- a/src/oumi/core/configs/params/synthesis_params.py +++ b/src/oumi/core/configs/params/synthesis_params.py @@ -23,6 +23,7 @@ from oumi.core.configs.params.base_params import BaseParams from oumi.core.types.conversation import Conversation, Message, Role +from oumi.exceptions import OumiConfigValueError _SUPPORTED_DATASET_FILE_TYPES = {".jsonl", ".json", ".csv", ".parquet", ".tsv", ".xlsx"} @@ -90,14 +91,14 @@ class DatasetSource: def __post_init__(self): """Verifies/populates params.""" if not self.path: - raise ValueError("DatasetSource.path cannot be empty.") + raise OumiConfigValueError("DatasetSource.path cannot be empty.") file_path = Path(self.path) prefix = self.path.split(":")[0] if prefix == "hf" or prefix == "oumi": return if file_path.suffix.lower() not in _SUPPORTED_DATASET_FILE_TYPES: - raise ValueError( + raise OumiConfigValueError( f"Unsupported dataset file type: {self.path}\n" f"Supported file types: {_SUPPORTED_DATASET_FILE_TYPES}" ) @@ -105,7 +106,7 @@ def __post_init__(self): # Validate dynamic sampling configuration if self.num_shots is not None and self.num_shots > 1: if not self.id: - raise ValueError( + raise OumiConfigValueError( "DatasetSource.id must be set when num_shots > 1 " "for dynamic sampling." ) @@ -146,14 +147,16 @@ class DocumentSegmentationParams: def __post_init__(self): """Verifies/populates params.""" if self.segment_length <= 0: - raise ValueError("Segment length must be positive.") + raise OumiConfigValueError("Segment length must be positive.") if self.segment_overlap < 0: - raise ValueError("Segment overlap must be non-negative.") + raise OumiConfigValueError("Segment overlap must be non-negative.") if self.segment_overlap >= self.segment_length: - raise ValueError("Segment overlap must be less than segment length.") + raise OumiConfigValueError( + "Segment overlap must be less than segment length." + ) if self.segmentation_strategy == SegmentationStrategy.TOKENS: if not self.tokenizer: - raise ValueError( + raise OumiConfigValueError( "DocumentSegmentationParams.tokenizer cannot be empty when " "segmentation_strategy is TOKENS." ) @@ -186,9 +189,9 @@ class DocumentSource: def __post_init__(self): """Verifies/populates params.""" if not self.path: - raise ValueError("DocumentSource.path cannot be empty.") + raise OumiConfigValueError("DocumentSource.path cannot be empty.") if not self.id: - raise ValueError("DocumentSource.id cannot be empty.") + raise OumiConfigValueError("DocumentSource.id cannot be empty.") @dataclass @@ -214,17 +217,17 @@ class ExampleSource: def __post_init__(self): """Verifies/populates params.""" if not self.examples: - raise ValueError("ExampleSource.examples cannot be empty.") + raise OumiConfigValueError("ExampleSource.examples cannot be empty.") keys = self.examples[0].keys() for example in self.examples: if example.keys() != keys: - raise ValueError("All examples must have the same keys.") + raise OumiConfigValueError("All examples must have the same keys.") # Validate dynamic sampling configuration if self.num_shots is not None and self.num_shots > 1: if not self.id: - raise ValueError( + raise OumiConfigValueError( "ExampleSource.id must be set when num_shots > 1 " "for dynamic sampling." ) @@ -252,15 +255,17 @@ class SampledAttributeValue: def __post_init__(self): """Verifies/populates params.""" if not self.id: - raise ValueError("SampledAttributeValue.id cannot be empty.") + raise OumiConfigValueError("SampledAttributeValue.id cannot be empty.") if not self.name: - raise ValueError("SampledAttributeValue.name cannot be empty.") + raise OumiConfigValueError("SampledAttributeValue.name cannot be empty.") if not self.description: - raise ValueError("SampledAttributeValue.description cannot be empty.") + raise OumiConfigValueError( + "SampledAttributeValue.description cannot be empty." + ) if self.sample_rate is not None and ( self.sample_rate < 0 or self.sample_rate > 1 ): - raise ValueError( + raise OumiConfigValueError( "SampledAttributeValue.sample_rate must be between 0 and 1." ) @@ -291,13 +296,15 @@ def get_value_distribution(self) -> dict[str, float]: def __post_init__(self): """Verifies/populates params.""" if not self.id: - raise ValueError("SampledAttribute.id cannot be empty.") + raise OumiConfigValueError("SampledAttribute.id cannot be empty.") if not self.name: - raise ValueError("SampledAttribute.name cannot be empty.") + raise OumiConfigValueError("SampledAttribute.name cannot be empty.") if not self.description: - raise ValueError("SampledAttribute.description cannot be empty.") + raise OumiConfigValueError("SampledAttribute.description cannot be empty.") if not self.possible_values: - raise ValueError("SampledAttribute.possible_values cannot be empty.") + raise OumiConfigValueError( + "SampledAttribute.possible_values cannot be empty." + ) value_ids = [] sample_rates = [] @@ -307,7 +314,9 @@ def __post_init__(self): value_ids_set = set(value_ids) if len(value_ids) != len(value_ids_set): - raise ValueError("SampledAttribute.possible_values must have unique IDs.") + raise OumiConfigValueError( + "SampledAttribute.possible_values must have unique IDs." + ) # Normalize sample rates normalized_sample_rates = [] @@ -321,7 +330,7 @@ def __post_init__(self): undefined_sample_rate_count += 1 if defined_sample_rate > 1.0 and not math.isclose(defined_sample_rate, 1.0): - raise ValueError( + raise OumiConfigValueError( "SampledAttribute.possible_values must sum to at most 1.0." ) @@ -353,24 +362,26 @@ class AttributeCombination: def __post_init__(self): """Verifies/populates params.""" if self.sample_rate < 0 or self.sample_rate > 1: - raise ValueError( + raise OumiConfigValueError( "AttributeCombination.sample_rate must be between 0 and 1." ) if not self.combination: - raise ValueError("AttributeCombination.combination cannot be empty.") + raise OumiConfigValueError( + "AttributeCombination.combination cannot be empty." + ) for key, value in self.combination.items(): if not key: - raise ValueError( + raise OumiConfigValueError( "AttributeCombination.combination key cannot be empty." ) if not value: - raise ValueError( + raise OumiConfigValueError( "AttributeCombination.combination value cannot be empty." ) if len(self.combination.keys()) <= 1: - raise ValueError( + raise OumiConfigValueError( "AttributeCombination.combination must have at least two keys." ) @@ -408,7 +419,7 @@ class GeneratedAttributePostprocessingParams: def __post_init__(self): """Verifies/populates params.""" if not self.id: - raise ValueError( + raise OumiConfigValueError( "GeneratedAttributePostprocessingParams.id cannot be empty." ) @@ -416,7 +427,7 @@ def __post_init__(self): try: re.compile(self.regex) except Exception as e: - raise ValueError( + raise OumiConfigValueError( f"Error compiling GeneratedAttributePostprocessingParams.regex: {e}" ) @@ -437,12 +448,14 @@ class GeneratedAttribute: def __post_init__(self): """Verifies/populates params.""" if not self.id: - raise ValueError("GeneratedAttribute.id cannot be empty.") + raise OumiConfigValueError("GeneratedAttribute.id cannot be empty.") if not self.instruction_messages: - raise ValueError("GeneratedAttribute.instruction_messages cannot be empty.") + raise OumiConfigValueError( + "GeneratedAttribute.instruction_messages cannot be empty." + ) if self.postprocessing_params: if self.id == self.postprocessing_params.id: - raise ValueError( + raise OumiConfigValueError( "GeneratedAttribute.id and " "GeneratedAttributePostprocessingParams.id " "cannot be the same." @@ -477,7 +490,7 @@ class MultiTurnAttribute: def __post_init__(self): """Verifies/populates params.""" if not self.id: - raise ValueError("MultiTurnAttribute.id cannot be empty.") + raise OumiConfigValueError("MultiTurnAttribute.id cannot be empty.") if self.role_instruction_messages: normalized_role_messages: dict[Role, str] = {} for role_key, persona in self.role_instruction_messages.items(): @@ -490,18 +503,18 @@ def __post_init__(self): try: normalized_role = Role(role_key) except ValueError as exc: - raise ValueError( + raise OumiConfigValueError( "MultiTurnAttribute.role_instruction_messages contains " f"unknown role: {role_key}" ) from exc else: - raise ValueError( + raise OumiConfigValueError( "MultiTurnAttribute.role_instruction_messages keys must be " "Role or str values." ) if not isinstance(persona, str): - raise ValueError( + raise OumiConfigValueError( "MultiTurnAttribute.role_instruction_messages values must " "be strings." ) @@ -510,26 +523,28 @@ def __post_init__(self): self.role_instruction_messages = normalized_role_messages if self.min_turns < 1: - raise ValueError("MultiTurnAttribute.min_turns must be at least 1.") + raise OumiConfigValueError( + "MultiTurnAttribute.min_turns must be at least 1." + ) if self.max_turns is not None and self.max_turns < self.min_turns: - raise ValueError( + raise OumiConfigValueError( "MultiTurnAttribute.max_turns must be greater than or equal to " "min_turns." ) if not self.role_instruction_messages: - raise ValueError( + raise OumiConfigValueError( "MultiTurnAttribute.role_instruction_messages cannot be empty." ) required_roles = [Role.USER, Role.ASSISTANT] for role in required_roles: if role not in self.role_instruction_messages: - raise ValueError( + raise OumiConfigValueError( "MultiTurnAttribute.role_instruction_messages must define " f"instructions for role: {role}" ) if not self.role_instruction_messages[role]: - raise ValueError( + raise OumiConfigValueError( "MultiTurnAttribute.role_instruction_messages must include " f"a non-empty persona for role: {role}" ) @@ -538,7 +553,7 @@ def __post_init__(self): not isinstance(self.output_system_prompt, str) or not self.output_system_prompt ): - raise ValueError( + raise OumiConfigValueError( "MultiTurnAttribute.output_system_prompt must be a non-empty " "string." ) @@ -580,7 +595,9 @@ def __post_init__(self): """Verifies/populates params based on the type.""" if self.type == TransformationType.STRING: if self.string_transform is None or self.string_transform == "": - raise ValueError("string_transform cannot be empty when type=STRING") + raise OumiConfigValueError( + "string_transform cannot be empty when type=STRING" + ) # Clear other fields self.list_transform = None self.dict_transform = None @@ -588,7 +605,9 @@ def __post_init__(self): elif self.type == TransformationType.LIST: if not self.list_transform or len(self.list_transform) == 0: - raise ValueError("list_transform cannot be empty when type=LIST") + raise OumiConfigValueError( + "list_transform cannot be empty when type=LIST" + ) # Clear other fields self.string_transform = None self.dict_transform = None @@ -596,7 +615,9 @@ def __post_init__(self): elif self.type == TransformationType.DICT: if not self.dict_transform or len(self.dict_transform) == 0: - raise ValueError("dict_transform cannot be empty when type=DICT") + raise OumiConfigValueError( + "dict_transform cannot be empty when type=DICT" + ) # Clear other fields self.string_transform = None self.list_transform = None @@ -604,15 +625,21 @@ def __post_init__(self): elif self.type == TransformationType.CHAT: if not self.chat_transform or len(self.chat_transform.messages) == 0: - raise ValueError("chat_transform cannot be empty when type=CHAT") + raise OumiConfigValueError( + "chat_transform cannot be empty when type=CHAT" + ) messages = self.chat_transform.messages for message in messages: content = message.content if not isinstance(content, str): - raise ValueError("chat_transform message content must be a string") + raise OumiConfigValueError( + "chat_transform message content must be a string" + ) if not content: - raise ValueError("chat_transform message content cannot be empty") + raise OumiConfigValueError( + "chat_transform message content cannot be empty" + ) # Clear other fields self.string_transform = None @@ -633,10 +660,10 @@ class TransformedAttribute: def __post_init__(self): """Verifies/populates params.""" if not self.id: - raise ValueError("TransformedAttribute.id cannot be empty.") + raise OumiConfigValueError("TransformedAttribute.id cannot be empty.") if not isinstance(self.transformation_strategy, TransformationStrategy): - raise ValueError( + raise OumiConfigValueError( "TransformedAttribute.transformation_strategy must be a " f"TransformationStrategy, got {type(self.transformation_strategy)}" ) @@ -803,12 +830,12 @@ def _get_reserved_attribute_ids(self) -> set[str]: def _check_attribute_ids(self, attribute_ids: set[str], id: str): """Check if the attribute ID is already in the set.""" if id in self._reserved_attribute_ids: - raise ValueError( + raise OumiConfigValueError( f"GeneralSynthesisParams does not allow '{id}' " "as an attribute ID because it is reserved for multiturn synthesis." ) if id in attribute_ids: - raise ValueError( + raise OumiConfigValueError( f"GeneralSynthesisParams contains duplicate attribute IDs: {id}" ) attribute_ids.add(id) @@ -901,7 +928,7 @@ def _check_combination_sampling_sample_rates(self) -> None: combination.sample_rate for combination in self.combination_sampling ] if sum(sample_rates) > 1.0: - raise ValueError( + raise OumiConfigValueError( "GeneralSynthesisParams.combination_sampling sample rates must be " "less than or equal to 1.0." ) diff --git a/src/oumi/core/configs/params/test_params.py b/src/oumi/core/configs/params/test_params.py index 393022ef4d..15f7a11760 100644 --- a/src/oumi/core/configs/params/test_params.py +++ b/src/oumi/core/configs/params/test_params.py @@ -35,6 +35,7 @@ from typing import Any from oumi.core.configs.params.base_params import BaseParams +from oumi.exceptions import OumiConfigValueError class TestType(str, Enum): @@ -212,10 +213,10 @@ class TestParams(BaseParams): def __finalize_and_validate__(self) -> None: """Validate test configuration based on test type.""" if not self.id: - raise ValueError("Test 'id' is required.") + raise OumiConfigValueError("Test 'id' is required.") if not self.type: - raise ValueError(f"Test 'type' is required for test '{self.id}'.") + raise OumiConfigValueError(f"Test 'type' is required for test '{self.id}'.") self._validate_enum_field("type", TestType, "test type") self._validate_enum_field("severity", TestSeverity, "severity") @@ -236,7 +237,7 @@ def _validate_enum_field( value = getattr(self, field_name) valid_values = [e.value for e in enum_class] if value not in valid_values: - raise ValueError( + raise OumiConfigValueError( f"Invalid {label} '{value}' for test '{self.id}'. " f"Valid values: {valid_values}" ) @@ -251,7 +252,7 @@ def _validate_by_type(self) -> None: for field_name in validation_rules.get("required", []): value = getattr(self, field_name) if value is None or (isinstance(value, str) and not value): - raise ValueError( + raise OumiConfigValueError( f"Test '{self.id}': '{field_name}' is required for " f"{self.type} tests." ) @@ -260,7 +261,7 @@ def _validate_by_type(self) -> None: for field_group in validation_rules.get("either_required", []): if not any(getattr(self, f) is not None for f in field_group): fields_str = "' or '".join(field_group) - raise ValueError( + raise OumiConfigValueError( f"Test '{self.id}': Either '{fields_str}' " f"is required for {self.type} tests." ) @@ -271,7 +272,7 @@ def _validate_by_type(self) -> None: ).items(): value = getattr(self, field_name) if value and value not in valid_values: - raise ValueError( + raise OumiConfigValueError( f"Test '{self.id}': Invalid {field_name} '{value}'. " f"Valid values: {valid_values}" ) @@ -283,7 +284,7 @@ def _validate_by_type(self) -> None: enum_class = globals()[enum_name] valid_values = [e.value for e in enum_class] if value not in valid_values: - raise ValueError( + raise OumiConfigValueError( f"Test '{self.id}': Invalid {field_name} '{value}'. " f"Valid values: {valid_values}" ) @@ -293,7 +294,7 @@ def _validate_by_type(self) -> None: if custom_validator: result = custom_validator(self) if isinstance(result, str): - raise ValueError(f"Test '{self.id}': {result}") + raise OumiConfigValueError(f"Test '{self.id}': {result}") def get_title(self) -> str: """Get the display title for this test.""" diff --git a/src/oumi/core/configs/params/training_params.py b/src/oumi/core/configs/params/training_params.py index 464bf48097..a00dc126b0 100644 --- a/src/oumi/core/configs/params/training_params.py +++ b/src/oumi/core/configs/params/training_params.py @@ -31,6 +31,7 @@ from oumi.core.configs.params.grpo_params import GrpoParams from oumi.core.configs.params.profiler_params import ProfilerParams from oumi.core.configs.params.telemetry_params import TelemetryParams +from oumi.exceptions import OumiConfigValueError from oumi.utils.str_utils import sanitize_run_name @@ -813,7 +814,7 @@ def to_hf(self, training_config: Optional["TrainingConfig"] = None): if isinstance(self.dataloader_num_workers, int): dataloader_num_workers = self.dataloader_num_workers else: - raise ValueError( + raise OumiConfigValueError( "Unexpected type of dataloader_num_workers: " f"{type(self.dataloader_num_workers)} " f"({self.dataloader_num_workers}). Must be `int`." @@ -893,7 +894,7 @@ def to_hf(self, training_config: Optional["TrainingConfig"] = None): grpo_kwargs.keys() ) if len(conflicting_keys) > 0: - raise ValueError( + raise OumiConfigValueError( "trainer_kwargs attempt to override the following " f"GRPO kwargs: {conflicting_keys}. " "Use properties of GrpoParams instead." @@ -906,7 +907,7 @@ def to_hf(self, training_config: Optional["TrainingConfig"] = None): gkd_kwargs.keys() ) if len(conflicting_keys) > 0: - raise ValueError( + raise OumiConfigValueError( "trainer_kwargs attempt to override the following " f"GKD kwargs: {conflicting_keys}. " "Use properties of GkdParams instead." @@ -919,7 +920,7 @@ def to_hf(self, training_config: Optional["TrainingConfig"] = None): gold_kwargs.keys() ) if len(conflicting_keys) > 0: - raise ValueError( + raise OumiConfigValueError( "trainer_kwargs attempt to override the following " f"GOLD kwargs: {conflicting_keys}. " "Use properties of GoldParams instead." @@ -1015,19 +1016,19 @@ def __post_init__(self): if isinstance(self.dataloader_num_workers, str) and not ( self.dataloader_num_workers == "auto" ): - raise ValueError( + raise OumiConfigValueError( "Unknown value of " f"dataloader_num_workers: {self.dataloader_num_workers}" ) if self.gradient_accumulation_steps < 1: - raise ValueError("gradient_accumulation_steps must be >= 1.") + raise OumiConfigValueError("gradient_accumulation_steps must be >= 1.") if self.max_grad_norm is not None and self.max_grad_norm < 0: - raise ValueError("max_grad_norm must be >= 0.") + raise OumiConfigValueError("max_grad_norm must be >= 0.") if not (self.max_steps > 0 or self.num_train_epochs > 0): - raise ValueError( + raise OumiConfigValueError( "At least one of max_steps and num_train_epochs must be positive. " f"Actual: max_steps: {self.max_steps}, " f"num_train_epochs: {self.num_train_epochs}." @@ -1039,12 +1040,12 @@ def __post_init__(self): self.trainer_type not in (TrainerType.TRL_GRPO, TrainerType.VERL_GRPO) and len(function_names) > 0 ): - raise ValueError( + raise OumiConfigValueError( "reward_functions may only be defined for the TRL_GRPO or VERL_GRPO" f"trainers. Actual: {self.trainer_type}" ) if self.trainer_type == TrainerType.VERL_GRPO and len(function_names) > 1: - raise ValueError( + raise OumiConfigValueError( "VERL_GRPO only supports a single reward function. " f"Actual: {function_names}" ) @@ -1052,7 +1053,7 @@ def __post_init__(self): TrainerType.TRL_GRPO, TrainerType.VERL_GRPO, ): - raise ValueError( + raise OumiConfigValueError( "reward_function_kwargs is only supported for the TRL_GRPO or " "VERL_GRPO trainers. Either remove reward_function_kwargs or set " f"trainer_type accordingly. Actual: {self.trainer_type}" @@ -1063,7 +1064,7 @@ def __post_init__(self): self.trainer_type == TrainerType.TRL_GRPO and self.include_performance_metrics ): - raise ValueError( + raise OumiConfigValueError( "`include_performance_metrics` is not supported for TRL_GRPO trainer." ) diff --git a/src/oumi/core/configs/params/tuning_params.py b/src/oumi/core/configs/params/tuning_params.py index 1bd3299d39..0dca3b4b61 100644 --- a/src/oumi/core/configs/params/tuning_params.py +++ b/src/oumi/core/configs/params/tuning_params.py @@ -24,6 +24,7 @@ TrainingParams, ) from oumi.core.registry import REGISTRY, RegistryType +from oumi.exceptions import OumiConfigValueError from oumi.utils.logging import logger from oumi.utils.str_utils import sanitize_run_name @@ -52,14 +53,14 @@ def verify_param_spec( ) -> None: """Verifies that a parameter specification is valid.""" if param_name not in valid_training_params: - raise ValueError( + raise OumiConfigValueError( f"Invalid tunable parameter: {param_name}. " f"Must be a valid `{type_name}` field." ) elif isinstance(param_spec, dict): # Validate required keys if "type" not in param_spec: - raise ValueError( + raise OumiConfigValueError( f"Tunable parameter '{param_name}' must have 'type' key" ) @@ -70,21 +71,21 @@ def verify_param_spec( param_type = ParamType(param_type_str) except ValueError: valid_types = [t.value for t in ParamType] - raise ValueError( + raise OumiConfigValueError( f"Invalid type '{param_type_str}' for parameter" f" '{param_name}'. Must be one of: {valid_types}" ) # Validate based on parameter type if param_type == ParamType.CATEGORICAL: if "choices" not in param_spec: - raise ValueError( + raise OumiConfigValueError( f"Categorical parameter '{param_name}' must have 'choices' key" ) if ( not isinstance(param_spec["choices"], list) or len(param_spec["choices"]) == 0 ): - raise ValueError( + raise OumiConfigValueError( f"Categorical parameter '{param_name}' must have" " non-empty choices list" ) @@ -92,11 +93,13 @@ def verify_param_spec( # All other types need low and high required_keys = {"low", "high"} if not required_keys.issubset(param_spec.keys()): - raise ValueError( + raise OumiConfigValueError( f"Parameter '{param_name}' must have 'low' and 'high' keys" ) else: - raise ValueError(f"Tunable parameter '{param_name}' must be a dict") + raise OumiConfigValueError( + f"Tunable parameter '{param_name}' must be a dict" + ) @dataclass @@ -284,7 +287,7 @@ def __post_init__(self): # Validate logging strategy valid_logging_strategies = {"trials", "epoch", "no"} if self.logging_strategy not in valid_logging_strategies: - raise ValueError( + raise OumiConfigValueError( f"Invalid logging_strategy: {self.logging_strategy}. " f"Choose from {valid_logging_strategies}." ) @@ -298,7 +301,7 @@ def __post_init__(self): "Applying it to all evaluation_metrics." ) else: - raise ValueError( + raise OumiConfigValueError( "Length of evaluation_metrics must match length of " "evaluation_direction, or evaluation_direction must be of length 1." ) @@ -306,7 +309,7 @@ def __post_init__(self): # Validate each evaluation direction for direction in self.evaluation_direction: if direction not in {"minimize", "maximize"}: - raise ValueError( + raise OumiConfigValueError( f"Invalid evaluation_direction: {direction}. " 'Choose either "minimize" or "maximize".' ) @@ -321,7 +324,7 @@ def __post_init__(self): # Validate trainer type # TODO: Add more options in the future. if self.trainer_type != TrainerType.TRL_SFT: - raise ValueError( + raise OumiConfigValueError( f"Invalid trainer_type: {self.trainer_type}. " f"Choose from {[t.value for t in [TrainerType.TRL_SFT]]}." ) @@ -334,7 +337,7 @@ def __post_init__(self): # Verify fixed training params keys are valid TrainingParams fields for param_name in self.fixed_training_params.keys(): if param_name not in valid_training_params: - raise ValueError( + raise OumiConfigValueError( f"Invalid fixed parameter: {param_name}. " f"Must be a valid `TrainingParams` field." ) @@ -352,7 +355,7 @@ def __post_init__(self): # Verify fixed training params keys are valid PEFT fields for param_name in self.fixed_peft_params.keys(): if param_name not in valid_training_params: - raise ValueError( + raise OumiConfigValueError( f"Invalid fixed parameter: {param_name}. " f"Must be a valid `PeftParams` field." ) @@ -385,7 +388,7 @@ def __post_init__(self): available = sorted( REGISTRY.get_all(RegistryType.EVALUATION_FUNCTION).keys() ) - raise ValueError( + raise OumiConfigValueError( "Unregistered custom_eval_metrics detected: " f"{unknown}. Available evaluation functions: {available}" ) diff --git a/src/oumi/core/configs/quantization_config.py b/src/oumi/core/configs/quantization_config.py index abf866361e..051d8c8838 100644 --- a/src/oumi/core/configs/quantization_config.py +++ b/src/oumi/core/configs/quantization_config.py @@ -16,6 +16,7 @@ from oumi.core.configs.base_config import BaseConfig from oumi.core.configs.params.model_params import ModelParams +from oumi.exceptions import OumiConfigValueError @dataclass @@ -76,14 +77,14 @@ def __post_init__(self): # Validate output format if self.output_format not in SUPPORTED_OUTPUT_FORMATS: - raise ValueError( + raise OumiConfigValueError( f"Unsupported output format: {self.output_format}. " f"Must be one of: {SUPPORTED_OUTPUT_FORMATS}." ) # Validate quantization method if self.method not in SUPPORTED_METHODS: - raise ValueError( + raise OumiConfigValueError( f"Unsupported quantization method: {self.method}. " f"Must be one of: {SUPPORTED_METHODS}." ) diff --git a/src/oumi/core/configs/synthesis_config.py b/src/oumi/core/configs/synthesis_config.py index 2ec7646d9b..dfd41b5df2 100644 --- a/src/oumi/core/configs/synthesis_config.py +++ b/src/oumi/core/configs/synthesis_config.py @@ -18,6 +18,7 @@ from oumi.core.configs.base_config import BaseConfig from oumi.core.configs.inference_config import InferenceConfig from oumi.core.configs.params.synthesis_params import GeneralSynthesisParams +from oumi.exceptions import OumiConfigValueError class SynthesisStrategy(str, Enum): @@ -56,21 +57,23 @@ def __post_init__(self): if self.strategy == SynthesisStrategy.GENERAL: pass else: - raise ValueError(f"Unsupported synthesis strategy: {self.strategy}") + raise OumiConfigValueError( + f"Unsupported synthesis strategy: {self.strategy}" + ) if self.inference_config.input_path is not None: - raise ValueError( + raise OumiConfigValueError( "Input path is not supported for general synthesis strategy." ) if self.inference_config.output_path is not None: - raise ValueError( + raise OumiConfigValueError( "Output path is not supported for general synthesis strategy." ) if self.output_path is not None: if self.output_path == "": - raise ValueError("Output path cannot be empty.") + raise OumiConfigValueError("Output path cannot be empty.") if not self.output_path.endswith(".jsonl"): - raise ValueError("Output path must end with .jsonl.") + raise OumiConfigValueError("Output path must end with .jsonl.") diff --git a/src/oumi/core/configs/training_config.py b/src/oumi/core/configs/training_config.py index 4b7b77421a..3903a13946 100644 --- a/src/oumi/core/configs/training_config.py +++ b/src/oumi/core/configs/training_config.py @@ -29,6 +29,7 @@ TrainerType, TrainingParams, ) +from oumi.exceptions import OumiConfigValueError from oumi.utils.logging import logger @@ -89,20 +90,20 @@ class TrainingConfig(BaseConfig): def __post_init__(self): """Verifies/populates params.""" if self.model.compile: - raise ValueError( + raise OumiConfigValueError( "Use `training.compile` instead of `model.compile` to " "enable model compilation during training." ) if self.training.compile and ( self.fsdp.use_orig_params is not None and not self.fsdp.use_orig_params ): - raise ValueError( + raise OumiConfigValueError( "`fsdp.use_orig_params` must be True for model compilation." ) # Validate distributed training configurations if self.fsdp.enable_fsdp and self.deepspeed.enable_deepspeed: - raise ValueError( + raise OumiConfigValueError( "Cannot enable both FSDP and DeepSpeed simultaneously. " "Please enable only one distributed training method." ) @@ -122,7 +123,7 @@ def __post_init__(self): ) and self.deepspeed.train_batch_size != "auto" ): - raise ValueError( + raise OumiConfigValueError( f"When using TRL trainer ({trainer_type}) with DeepSpeed, " "train_batch_size must be set to 'auto' to allow proper batch size " "management. " @@ -135,7 +136,7 @@ def __post_init__(self): MixedPrecisionDtype.BF16, ]: if self.model.torch_dtype != torch.float32: - raise ValueError( + raise OumiConfigValueError( "Model must be loaded in fp32 to enable mixed precision training." ) @@ -188,7 +189,7 @@ def __post_init__(self): # We need to Liger patch ourselves for our own training loop. pass else: - raise ValueError("Unrecognized trainer type!") + raise OumiConfigValueError("Unrecognized trainer type!") # Setup and validate params for "vision_language_sft" collator. # The collator expects VLM SFT dataset to only produce just @@ -201,7 +202,7 @@ def __post_init__(self): self.data.test.datasets, ): if not dataset_params.dataset_kwargs.get("return_conversations", True): - raise ValueError( + raise OumiConfigValueError( "`return_conversations` must be True " f"for the dataset '{dataset_params.dataset_name}' " f"when using '{collator_name}' collator!" @@ -210,7 +211,7 @@ def __post_init__(self): # Extra setup for TRL_SFT. if trainer_type == TrainerType.TRL_SFT: if self.training.trainer_kwargs.get("remove_unused_columns", False): - raise ValueError( + raise OumiConfigValueError( "`remove_unused_columns` must be False " f"when using '{collator_name}' collator! " 'The "unused" columns are consumed by the collator, ' @@ -252,6 +253,6 @@ def __post_init__(self): self.training.trainer_type == TrainerType.VERL_GRPO and not self.data.validation.datasets ): - raise ValueError( + raise OumiConfigValueError( "At least one validation dataset is required for VERL_GRPO training." ) diff --git a/src/oumi/exceptions.py b/src/oumi/exceptions.py index dd883dd864..4655a82772 100644 --- a/src/oumi/exceptions.py +++ b/src/oumi/exceptions.py @@ -23,5 +23,8 @@ class OumiConfigError(Exception): """Base class for all configuration-related errors.""" +class OumiConfigValueError(OumiConfigError, ValueError): + """A configuration value error.""" + class HardwareException(Exception): """An exception thrown for invalid hardware configurations."""