diff --git a/src/rai_core/pyproject.toml b/src/rai_core/pyproject.toml index 4cfb7ac77..638241bf2 100644 --- a/src/rai_core/pyproject.toml +++ b/src/rai_core/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "rai_core" -version = "2.10.1" +version = "2.11.0" description = "Core functionality for RAI framework" readme = "README.md" requires-python = ">=3.10,<3.13" diff --git a/src/rai_core/rai/initialization/model_initialization.py b/src/rai_core/rai/initialization/model_initialization.py index e16eede95..3be3901e0 100644 --- a/src/rai_core/rai/initialization/model_initialization.py +++ b/src/rai_core/rai/initialization/model_initialization.py @@ -96,6 +96,24 @@ class RAIConfig: tracing: TracingConfig +# Default placeholder configs for vendors not present in the TOML file. +_DEFAULT_AWS = AWSConfig( + simple_model="", complex_model="", embeddings_model="", region_name="" +) +_DEFAULT_OPENAI = OpenAIConfig( + simple_model="", complex_model="", embeddings_model="", base_url="" +) +_DEFAULT_OLLAMA = OllamaConfig( + simple_model="", complex_model="", embeddings_model="", base_url="" +) +_DEFAULT_GOOGLE = GoogleConfig(simple_model="", complex_model="", embeddings_model="") +_DEFAULT_TRACING = TracingConfig( + project="", + langfuse=LangfuseConfig(use_langfuse=False, host=""), + langsmith=LangsmithConfig(use_langsmith=False, host=""), +) + + def load_config(config_path: Optional[str] = None) -> RAIConfig: if config_path is None: with open("config.toml", "rb") as f: @@ -103,17 +121,42 @@ def load_config(config_path: Optional[str] = None) -> RAIConfig: else: with open(config_path, "rb") as f: config_dict = tomli.load(f) - return RAIConfig( - vendor=VendorConfig(**config_dict["vendor"]), - aws=AWSConfig(**config_dict["aws"]), - openai=OpenAIConfig(**config_dict["openai"]), - ollama=OllamaConfig(**config_dict["ollama"]), - google=GoogleConfig(**config_dict["google"]), - tracing=TracingConfig( + + # Only require config sections for vendors actually referenced in [vendor]. + # Missing sections get safe defaults so single-vendor setups work. + aws = AWSConfig(**config_dict["aws"]) if "aws" in config_dict else _DEFAULT_AWS + openai = ( + OpenAIConfig(**config_dict["openai"]) + if "openai" in config_dict + else _DEFAULT_OPENAI + ) + ollama = ( + OllamaConfig(**config_dict["ollama"]) + if "ollama" in config_dict + else _DEFAULT_OLLAMA + ) + google = ( + GoogleConfig(**config_dict["google"]) + if "google" in config_dict + else _DEFAULT_GOOGLE + ) + + if "tracing" in config_dict: + tracing = TracingConfig( project=config_dict["tracing"]["project"], langfuse=LangfuseConfig(**config_dict["tracing"]["langfuse"]), langsmith=LangsmithConfig(**config_dict["tracing"]["langsmith"]), - ), + ) + else: + tracing = _DEFAULT_TRACING + + return RAIConfig( + vendor=VendorConfig(**config_dict["vendor"]), + aws=aws, + openai=openai, + ollama=ollama, + google=google, + tracing=tracing, ) diff --git a/tests/initialization/test_model_initialization.py b/tests/initialization/test_model_initialization.py index 920050213..eafdad88e 100644 --- a/tests/initialization/test_model_initialization.py +++ b/tests/initialization/test_model_initialization.py @@ -195,3 +195,54 @@ def test_get_embeddings_model_return_kwargs_openai(monkeypatch, tmp_path): assert kwargs["base_url"] == "https://openai.example/v1/" assert kwargs["vendor"] == "openai" assert "class" in kwargs + + +def test_load_config_allows_missing_unused_vendor_sections(tmp_path): + sparse_config = """ +[vendor] +simple_model = "openai" +complex_model = "openai" +embeddings_model = "openai" + +[openai] +simple_model = "gpt-4o-mini" +complex_model = "gpt-4o" +embeddings_model = "text-embedding-3-small" +base_url = "https://openai.example/v1/" +""" + config_path = write_config(tmp_path / "config.toml", sparse_config) + + config = model_initialization.load_config(str(config_path)) + + assert config.openai.simple_model == "gpt-4o-mini" + assert config.openai.complex_model == "gpt-4o" + assert config.openai.embeddings_model == "text-embedding-3-small" + assert config.openai.base_url == "https://openai.example/v1/" + assert config.aws.simple_model == "" + assert config.aws.region_name == "" + assert config.ollama.base_url == "" + assert config.google.embeddings_model == "" + + +def test_load_config_uses_default_tracing_when_section_missing(tmp_path): + sparse_config = """ +[vendor] +simple_model = "openai" +complex_model = "openai" +embeddings_model = "openai" + +[openai] +simple_model = "gpt-4o-mini" +complex_model = "gpt-4o" +embeddings_model = "text-embedding-3-small" +base_url = "https://openai.example/v1/" +""" + config_path = write_config(tmp_path / "config.toml", sparse_config) + + config = model_initialization.load_config(str(config_path)) + + assert config.tracing.project == "" + assert config.tracing.langfuse.use_langfuse is False + assert config.tracing.langfuse.host == "" + assert config.tracing.langsmith.use_langsmith is False + assert config.tracing.langsmith.host == ""