diff --git a/docs/index.md b/docs/index.md index e2d64a7b41..fa0ae12736 100644 --- a/docs/index.md +++ b/docs/index.md @@ -106,6 +106,7 @@ faq/oom development/dev_setup development/contributing +development/agentic_synthesis_environments development/code_of_conduct development/style_guide development/docs_guide diff --git a/docs/user_guides/synth.md b/docs/user_guides/synth.md index 97437fe20f..df276b58b9 100644 --- a/docs/user_guides/synth.md +++ b/docs/user_guides/synth.md @@ -164,6 +164,104 @@ Ready to dive deeper? The sections below cover all available options in detail. --- +## Environment-First Tool Synthesis + +Agentic synthesis now follows an environment-first model. Tools do not declare an output strategy directly. Instead, each tool is bound to an environment, and the environment type defines how tool calls are executed via its `step()` method. + +- **`synthetic` environments** are backed by an LLM that simulates tool execution. They can be stateless (no persistent state) or stateful (mutable JSON state across turns). Statefulness is controlled by the optional `state_params` field — when provided, the environment tracks and mutates state across calls; when absent, each call is independent. +- **`deterministic` environments** behave like lookup tables. Each tool defines a set of input-to-output mappings, and `step()` resolves tool calls by matching arguments against those mappings. No LLM is involved. + +At the config level: + +- Environments own their tool definitions. +- Reusable environment catalogs live in top-level `environment_config` or `environment_config_path`. +- Tools do not declare an `environment` field. The parent environment owns the binding. +- `deterministic_outputs` is only used for tools in `deterministic` environments. +- `read_only` is only meaningful for tools in stateful `synthetic` environments. +- Multiturn attributes reference environments (not individual tools) to select which tools are available. 
+ +Example: + +```yaml +environment_config: + environments: + - id: support_backend + name: Support Backend + description: Simulated support system with tickets and users + type: synthetic + system_prompt: You manage a customer support system with tickets and users. + state_params: + state_schema: + type: object + properties: + tickets: { type: array } + users: { type: array } + initial_state: + tickets: [] + users: [] + tools: + - id: get_ticket + name: GetTicket + description: Read a ticket from the support backend. + read_only: true + parameters: + type: object + properties: + ticket_id: { type: string } + - id: create_ticket + name: CreateTicket + description: Create a new support ticket. + read_only: false + parameters: + type: object + properties: + subject: { type: string } + priority: { type: string, enum: [low, medium, high] } + + - id: faq_lookup + name: FAQ Lookup + description: Cached LLM-backed FAQ answers + type: synthetic + system_prompt: Generate concise FAQ answers grounded in the tool contract. + cache_by_input: true + tools: + - id: answer_faq + name: AnswerFAQ + description: Answer common support questions. + parameters: + type: object + properties: + question: { type: string } + + - id: policy_table + name: Policy Table + description: Predefined policy responses + type: deterministic + tools: + - id: get_refund_policy + name: GetRefundPolicy + description: Return the matching refund policy. + parameters: + type: object + properties: + policy_type: { type: string } + deterministic_outputs: + - input: + policy_type: standard + output: + policy: Standard 30-day refund policy + +strategy_params: + multiturn_attributes: + - id: support_chat + min_turns: 2 + max_turns: 4 + role_instruction_messages: + USER: You are a customer contacting support. + ASSISTANT: You are a helpful support agent. 
+ available_environments: [support_backend, faq_lookup, policy_table] +``` + ## Complete Configuration Reference ### Top-Level Parameters diff --git a/pyproject.toml b/pyproject.toml index 531b5cd87d..30256df801 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,7 @@ dependencies = [ "hdrhistogram>=0.10,<0.11", "httpx>=0.27,<1.0", # Used by deploy module (async HTTP client) "jsonlines", + "jsonpatch>=1.33,<2.0", "lm_eval[wandb]>=0.4,<0.5.0", "mlflow>=3.1", # >=3.1.4 requires Python3.10>= "numpy>=1.26,<2.4", # verl==0.5.0 depends on numpy<2.0.0 diff --git a/src/oumi/core/configs/__init__.py b/src/oumi/core/configs/__init__.py index 8ed6769f5d..068483ce01 100644 --- a/src/oumi/core/configs/__init__.py +++ b/src/oumi/core/configs/__init__.py @@ -83,6 +83,7 @@ ) from oumi.core.configs.async_evaluation_config import AsyncEvaluationConfig from oumi.core.configs.base_config import BaseConfig +from oumi.core.configs.environment_config import EnvironmentConfig from oumi.core.configs.evaluation_config import EvaluationConfig from oumi.core.configs.inference_config import InferenceConfig from oumi.core.configs.inference_engine_type import InferenceEngineType @@ -123,7 +124,9 @@ from oumi.core.configs.params.profiler_params import ProfilerParams from oumi.core.configs.params.remote_params import RemoteParams from oumi.core.configs.params.synthesis_params import ( - AttributeCombination, + DatasetSource as DatasetSourceParam, +) +from oumi.core.configs.params.synthesis_params import ( DocumentSegmentationParams, DocumentSource, ExampleSource, @@ -140,9 +143,6 @@ TransformationType, TransformedAttribute, ) -from oumi.core.configs.params.synthesis_params import ( - DatasetSource as DatasetSourceParam, -) from oumi.core.configs.params.telemetry_params import TelemetryParams from oumi.core.configs.params.training_params import ( MixedPrecisionDtype, @@ -175,6 +175,7 @@ "EvaluationConfig", "EvaluationBackend", "EvaluationConfig", + "EnvironmentConfig", 
"EvaluationTaskParams", "FSDPParams", "GenerationParams", @@ -209,7 +210,6 @@ "TunerType", "TuningConfig", "TuningParams", - "AttributeCombination", "DatasetSourceParam", "DocumentSegmentationParams", "DocumentSource", diff --git a/src/oumi/core/configs/environment_config.py b/src/oumi/core/configs/environment_config.py new file mode 100644 index 0000000000..a832f5a4ec --- /dev/null +++ b/src/oumi/core/configs/environment_config.py @@ -0,0 +1,139 @@ +# Copyright 2025 - Oumi +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Configuration for agentic environments.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +from oumi.core.configs.base_config import BaseConfig + +if TYPE_CHECKING: + from oumi.environments.base_environment import BaseEnvironment + from oumi.environments.base_tool import Tool + + +@dataclass +class EnvironmentConfig(BaseConfig): + """Top-level config for environment-first tool definitions.""" + + environments: list[Any] = field(default_factory=list) + """Reusable environments and their owned tools.""" + + def __post_init__(self): + """Verifies/populates params.""" + self.environments = [ + self._coerce_environment(environment) for environment in self.environments + ] + + env_ids: set[str] = set() + tool_ids: set[str] = set() + + for environment in self.environments: + if environment.id in env_ids: + raise ValueError( + f"EnvironmentConfig.environments contains duplicate " + f"environment id '{environment.id}'." + ) + env_ids.add(environment.id) + + for tool in environment.tools: + if tool.id in tool_ids: + raise ValueError( + f"EnvironmentConfig.environments contains duplicate " + f"tool id '{tool.id}'." 
+ ) + tool_ids.add(tool.id) + + @property + def all_tools(self) -> list[Tool]: + """Flatten all tools across environments.""" + return [tool for environment in self.environments for tool in environment.tools] + + @property + def tool_environment_map(self) -> dict[str, str]: + """Map each tool id to the environment that owns it.""" + return { + tool.id: environment.id + for environment in self.environments + for tool in environment.tools + } + + def get_environment(self, environment_id: str) -> BaseEnvironment | None: + """Look up an environment by id.""" + for environment in self.environments: + if environment.id == environment_id: + return environment + return None + + def get_tool(self, tool_id: str) -> Tool | None: + """Look up a tool by id.""" + for tool in self.all_tools: + if tool.id == tool_id: + return tool + return None + + def resolve_tools( + self, + environment_ids: list[str] | None = None, + tool_ids: list[str] | None = None, + ) -> list[Tool]: + """Resolve tools from selected environments and optional tool ids. + + Raises: + ValueError: If any environment_id or tool_id is not found. + """ + all_env_ids = {env.id for env in self.environments} + + if environment_ids: + unknown_envs = set(environment_ids) - all_env_ids + if unknown_envs: + raise ValueError( + f"Unknown environment id(s): {sorted(unknown_envs)}. " + f"Defined: {sorted(all_env_ids)}" + ) + selected_environment_ids = environment_ids + else: + selected_environment_ids = list(all_env_ids) + + selected_environments = [ + environment + for environment in self.environments + if environment.id in set(selected_environment_ids) + ] + tools = [ + tool for environment in selected_environments for tool in environment.tools + ] + + if tool_ids: + available_tool_ids = {tool.id for tool in tools} + unknown_tools = set(tool_ids) - available_tool_ids + if unknown_tools: + raise ValueError( + f"Unknown tool id(s): {sorted(unknown_tools)}. 
" + f"Available in selected environments: " + f"{sorted(available_tool_ids)}" + ) + allowed_tool_ids = set(tool_ids) + tools = [tool for tool in tools if tool.id in allowed_tool_ids] + + return tools + + def _coerce_environment(self, environment: Any) -> BaseEnvironment: + """Coerce a raw dict or environment instance into a concrete environment.""" + from oumi.environments.base_environment import BaseEnvironment + + return BaseEnvironment.create(environment) diff --git a/src/oumi/core/configs/params/synthesis_params.py b/src/oumi/core/configs/params/synthesis_params.py index 8de36f75ca..6bb8eecbbe 100644 --- a/src/oumi/core/configs/params/synthesis_params.py +++ b/src/oumi/core/configs/params/synthesis_params.py @@ -474,6 +474,16 @@ class MultiTurnAttribute: Allows user to specify custom instructions for the planner while planning out the conversation.""" + available_environments: list[str] = field(default_factory=list) + """List of environment ids available in this conversation.""" + + available_tools: list[str] = field(default_factory=list) + """List of tool ids available in this conversation.""" + + max_tool_calls_per_turn: int = 50 + """Safety ceiling for tool calls per ASSISTANT turn. The agent naturally stops + when it decides no more tools are needed. This only prevents runaway loops.""" + def __post_init__(self): """Verifies/populates params.""" if not self.id: @@ -543,6 +553,29 @@ def __post_init__(self): "string." ) + if self.available_tools is not None: + if not isinstance(self.available_tools, list): + raise ValueError( + "MultiTurnAttribute.available_tools must be a list of tool names." + ) + for tool in self.available_tools: + if not isinstance(tool, str): + raise ValueError( + "MultiTurnAttribute.available_tools must be a list of strings." + ) + if self.available_environments is not None: + if not isinstance(self.available_environments, list): + raise ValueError( + "MultiTurnAttribute.available_environments must be a list of " + "environment ids." 
+ ) + for environment in self.available_environments: + if not isinstance(environment, str): + raise ValueError( + "MultiTurnAttribute.available_environments must be a list " + "of strings." + ) + class TransformationType(str, Enum): """Types of transformation strategies.""" diff --git a/src/oumi/core/configs/synthesis_config.py b/src/oumi/core/configs/synthesis_config.py index 2ec7646d9b..5da1acd6f8 100644 --- a/src/oumi/core/configs/synthesis_config.py +++ b/src/oumi/core/configs/synthesis_config.py @@ -12,12 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from dataclasses import dataclass, field from enum import Enum +from pathlib import Path +from typing import TYPE_CHECKING, Any from oumi.core.configs.base_config import BaseConfig from oumi.core.configs.inference_config import InferenceConfig -from oumi.core.configs.params.synthesis_params import GeneralSynthesisParams +from oumi.core.configs.params.synthesis_params import ( + GeneralSynthesisParams, + MultiTurnAttribute, +) + +if TYPE_CHECKING: + from oumi.core.configs.environment_config import EnvironmentConfig + from oumi.environments.base_environment import BaseEnvironment + from oumi.environments.base_tool import Tool class SynthesisStrategy(str, Enum): @@ -45,6 +57,12 @@ class SynthesisConfig(BaseConfig): ) """The synthesis strategy parameters to use.""" + environment_config: Any | None = None + """Reusable environment-first tool configuration (EnvironmentConfig).""" + + environment_config_path: str | None = None + """Optional path to an EnvironmentConfig YAML file.""" + inference_config: InferenceConfig = field(default_factory=InferenceConfig) """The inference configuration to use.""" @@ -74,3 +92,117 @@ def __post_init__(self): if not self.output_path.endswith(".jsonl"): raise ValueError("Output path must end with .jsonl.") + + self.environment_config = self._resolve_environment_config() + 
self._validate_available_tooling() + + def _resolve_environment_config(self) -> EnvironmentConfig | None: + """Resolve top-level environment configuration.""" + if ( + self.environment_config is not None + and self.environment_config_path is not None + ): + raise ValueError( + "SynthesisConfig.environment_config and " + "SynthesisConfig.environment_config_path cannot both be set." + ) + + if self.environment_config is not None: + return self.environment_config + + if self.environment_config_path is not None: + if self.environment_config_path == "": + raise ValueError( + "SynthesisConfig.environment_config_path cannot be empty." + ) + + config_path = Path(self.environment_config_path) + if not config_path.exists(): + raise ValueError( + f"Environment config path does not exist: " + f"{self.environment_config_path}" + ) + from oumi.core.configs.environment_config import EnvironmentConfig + + return EnvironmentConfig.from_yaml(config_path) + + return None + + def resolve_multiturn_environments( + self, multiturn_attribute: MultiTurnAttribute + ) -> list[BaseEnvironment]: + """Resolve the environments available to a multiturn attribute.""" + if self.environment_config is None: + return [] + + if not multiturn_attribute.available_environments: + return list(self.environment_config.environments) + + resolved_environments: list[BaseEnvironment] = [] + for environment_id in multiturn_attribute.available_environments: + environment = self.environment_config.get_environment(environment_id) + if environment is None: + raise ValueError( + f"MultiTurnAttribute '{multiturn_attribute.id}' references unknown " + f"environment '{environment_id}'. 
Defined environment ids: " + f"{sorted(env.id for env in self.environment_config.environments)}" + ) + resolved_environments.append(environment) + return resolved_environments + + def resolve_multiturn_tools( + self, multiturn_attribute: MultiTurnAttribute + ) -> list[Tool]: + """Resolve the tools available to a multiturn attribute.""" + if self.environment_config is None: + return [] + + environments = self.resolve_multiturn_environments(multiturn_attribute) + return self.environment_config.resolve_tools( + environment_ids=[environment.id for environment in environments], + tool_ids=multiturn_attribute.available_tools or None, + ) + + def _validate_available_tooling(self) -> None: + """Validate multiturn environment/tool selections against the catalog.""" + if not self.strategy_params.multiturn_attributes: + return + + all_referenced_tools = [ + tool_id + for mt_attr in self.strategy_params.multiturn_attributes + for tool_id in mt_attr.available_tools + ] + all_referenced_environments = [ + environment_id + for mt_attr in self.strategy_params.multiturn_attributes + for environment_id in mt_attr.available_environments + ] + if not all_referenced_tools and not all_referenced_environments: + return + + if self.environment_config is None: + raise ValueError( + "Environment or tool references require " + "SynthesisConfig.environment_config, or " + "SynthesisConfig.environment_config_path." 
+ ) + + for mt_attr in self.strategy_params.multiturn_attributes: + selected_environments = self.resolve_multiturn_environments(mt_attr) + selected_environment_ids = { + environment.id for environment in selected_environments + } + selected_tools = self.environment_config.resolve_tools( + environment_ids=list(selected_environment_ids) + ) + selected_tool_ids = {tool.id for tool in selected_tools} + + for tool_id in mt_attr.available_tools: + if tool_id not in selected_tool_ids: + raise ValueError( + f"MultiTurnAttribute '{mt_attr.id}' references unknown " + f"tool '{tool_id}' for environments " + f"{sorted(selected_environment_ids)}. Defined tool ids: " + f"{sorted(selected_tool_ids)}" + ) diff --git a/src/oumi/core/synthesis/conversation_synthesizer.py b/src/oumi/core/synthesis/conversation_synthesizer.py index 59f09082a8..f0b230ceef 100644 --- a/src/oumi/core/synthesis/conversation_synthesizer.py +++ b/src/oumi/core/synthesis/conversation_synthesizer.py @@ -15,6 +15,7 @@ import random from oumi.builders.inference_engines import build_inference_engine +from oumi.core.configs.environment_config import EnvironmentConfig from oumi.core.configs.inference_config import InferenceConfig from oumi.core.configs.inference_engine_type import InferenceEngineType from oumi.core.configs.params.synthesis_params import ( @@ -23,6 +24,7 @@ ) from oumi.core.synthesis.attribute_formatter import AttributeFormatter from oumi.core.types.conversation import Conversation, Message, Role +from oumi.environments import Tool from oumi.utils.logging import logger from oumi.utils.str_utils import extract_json @@ -39,9 +41,11 @@ def __init__( self, params: GeneralSynthesisParams, inference_config: InferenceConfig, + environment_config: EnvironmentConfig | None = None, ): """Initialize the synthesizer.""" self._params = params + self._environment_config = environment_config self._formatter = AttributeFormatter(params) self._inference_engine = build_inference_engine( @@ -52,6 +56,17 @@ def 
__init__( self._inference_config = inference_config self._default_turn_order = [Role.USER, Role.ASSISTANT] + def _resolve_available_tools( + self, multiturn_attribute: MultiTurnAttribute + ) -> list[Tool]: + """Resolve tools for a multiturn attribute from selected environments.""" + if self._environment_config is None: + return [] + return self._environment_config.resolve_tools( + environment_ids=multiturn_attribute.available_environments or None, + tool_ids=multiturn_attribute.available_tools or None, + ) + def _validate_roles(self, multiturn_attribute: MultiTurnAttribute) -> None: """Validate that required roles have corresponding personas. @@ -98,6 +113,13 @@ def synthesize( f"Synthesizing {len(samples)} conversations for " f"attribute '{multiturn_attributes.id}'" ) + available_tools = self._resolve_available_tools(multiturn_attributes) + if available_tools: + logger.debug( + "Resolved tools for '%s': %s", + multiturn_attributes.id, + [tool.id for tool in available_tools], + ) samples = self._plan_samples(samples, multiturn_attributes) conversations = self._synthesize_all_samples(samples, multiturn_attributes) diff --git a/src/oumi/core/synthesis/synthesis_pipeline.py b/src/oumi/core/synthesis/synthesis_pipeline.py index 7f68969393..1950b7813d 100644 --- a/src/oumi/core/synthesis/synthesis_pipeline.py +++ b/src/oumi/core/synthesis/synthesis_pipeline.py @@ -46,7 +46,11 @@ def __init__(self, config: SynthesisConfig): else None ) self._conversation_synthesizer = ( - ConversationSynthesizer(config.strategy_params, config.inference_config) + ConversationSynthesizer( + config.strategy_params, + config.inference_config, + environment_config=config.environment_config, + ) if config.strategy_params.multiturn_attributes else None ) @@ -89,6 +93,22 @@ def synthesize(self) -> list[dict[str, Any]]: continue sample.update(result) + required_multiturn_ids = { + attr.id for attr in self._config.strategy_params.multiturn_attributes + } + before_count = len(dataset) + dataset = [ 
+ sample + for sample in dataset + if all(attr_id in sample for attr_id in required_multiturn_ids) + ] + dropped_count = before_count - len(dataset) + if dropped_count > 0: + logger.warning( + f"Dropped {dropped_count} sample(s) missing required " + f"multiturn outputs: {sorted(required_multiturn_ids)}" + ) + # Add the transformed attributes to the dataset logger.info("Adding transformed attributes") if self._config.strategy_params.transformed_attributes: diff --git a/src/oumi/environments/__init__.py b/src/oumi/environments/__init__.py new file mode 100644 index 0000000000..f84a0785e7 --- /dev/null +++ b/src/oumi/environments/__init__.py @@ -0,0 +1,39 @@ +# Copyright 2025 - Oumi +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Environments for agentic tool interactions.""" + +from oumi.environments.base_environment import BaseEnvironment +from oumi.environments.base_tool import ( + DeterministicToolOutput, + Tool, + ToolResult, + ToolSchema, +) +from oumi.environments.deterministic_environment import DeterministicEnvironment +from oumi.environments.synthetic_environment import ( + SyntheticEnvironment, + SyntheticStateParams, +) + +__all__ = [ + "BaseEnvironment", + "Tool", + "ToolSchema", + "ToolResult", + "SyntheticEnvironment", + "SyntheticStateParams", + "DeterministicEnvironment", + "DeterministicToolOutput", +] diff --git a/src/oumi/environments/base_environment.py b/src/oumi/environments/base_environment.py new file mode 100644 index 0000000000..1a15f9d3db --- /dev/null +++ b/src/oumi/environments/base_environment.py @@ -0,0 +1,122 @@ +# Copyright 2025 - Oumi +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Abstract base class for tool environments.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections.abc import Mapping +from dataclasses import dataclass, field, fields +from typing import Any, ClassVar + +from oumi.core.configs.params.base_params import BaseParams +from oumi.environments.base_tool import Tool, ToolResult + + +@dataclass +class BaseEnvironment(BaseParams, ABC): + """Abstract base class for tool environments.""" + + _registry: ClassVar[dict[str, type[BaseEnvironment]]] = {} + + id: str + name: str + description: str + tools: list[Tool] = field(default_factory=list) + + def __init_subclass__(cls, **kwargs): + """Register subclass in the environment type registry.""" + super().__init_subclass__(**kwargs) + environment_type = getattr(cls, "ENVIRONMENT_TYPE", None) + if environment_type is not None: + cls._registry[environment_type] = cls + + def __post_init__(self): + """Validate common fields and coerce tools.""" + if not self.id: + raise ValueError(f"{type(self).__name__}.id cannot be empty.") + if not self.name: + raise ValueError(f"{type(self).__name__}.name cannot be empty.") + if not self.description: + raise ValueError(f"{type(self).__name__}.description cannot be empty.") + self.tools = self._coerce_tools(self.tools) + self._validate_unique_tool_ids() + + def _validate_unique_tool_ids(self) -> None: + tool_ids: set[str] = set() + for tool in self.tools: + if tool.id in tool_ids: + raise ValueError( + f"{type(self).__name__} '{self.id}' contains duplicate " + f"tool id '{tool.id}'." 
+ ) + tool_ids.add(tool.id) + + def _coerce_tools(self, tools: list[Any]) -> list[Tool]: + """Coerce raw tool definitions into Tool instances.""" + return [Tool.create(tool) for tool in tools] + + @abstractmethod + def step(self, tool_id: str, arguments: dict[str, Any]) -> ToolResult: + """Execute a tool call within this environment.""" + + def _get_tool(self, tool_id: str) -> Tool | None: + """Look up a tool owned by this environment.""" + for tool in self.tools: + if tool.id == tool_id: + return tool + return None + + def _get_tool_or_raise(self, tool_id: str) -> Tool: + tool = self._get_tool(tool_id) + if tool is None: + raise ValueError( + f"Tool '{tool_id}' not found in environment '{self.id}'. " + f"Available tools: {[tool.id for tool in self.tools]}" + ) + return tool + + @classmethod + def create(cls, raw: Any) -> BaseEnvironment: + """Create a concrete environment from raw config data. + + Raises: + TypeError: If raw is not a mapping or BaseEnvironment. + ValueError: If type is missing or unsupported. + """ + if isinstance(raw, BaseEnvironment): + return raw + if not isinstance(raw, Mapping): + raise TypeError( + "Environment definitions must be environment objects or mappings, " + f"got {type(raw)}" + ) + environment_type = raw.get("type") + if environment_type is None: + raise ValueError( + "Environment definition must include a 'type' field. 
" + f"Supported types: {sorted(cls._registry)}" + ) + if not isinstance(environment_type, str): + environment_type = str(environment_type) + environment_cls = cls._registry.get(environment_type) + if environment_cls is None: + raise ValueError(f"Unsupported environment type: {environment_type}") + init_fields = { + field_def.name for field_def in fields(environment_cls) if field_def.init + } + return environment_cls( + **{key: value for key, value in raw.items() if key in init_fields} + ) diff --git a/src/oumi/environments/base_tool.py b/src/oumi/environments/base_tool.py new file mode 100644 index 0000000000..7becfcc080 --- /dev/null +++ b/src/oumi/environments/base_tool.py @@ -0,0 +1,178 @@ +# Copyright 2025 - Oumi +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tool definitions and execution results shared by all environment types.""" + +from __future__ import annotations + +import json +from collections.abc import Mapping +from dataclasses import dataclass, field +from typing import Any + +from oumi.core.configs.params.base_params import BaseParams + + +@dataclass +class DeterministicToolOutput(BaseParams): + """An input-to-output mapping for a deterministic tool.""" + + input: dict[str, Any] = field(default_factory=dict) + output: dict[str, Any] = field(default_factory=dict) + + def matches(self, arguments: dict[str, Any]) -> bool: + """Check if the input matches the given arguments.""" + return json.dumps(self.input, sort_keys=True) == json.dumps( + arguments, sort_keys=True + ) + + +@dataclass +class ToolSchema(BaseParams): + """JSON schema for tool inputs or outputs.""" + + type: str = "object" + properties: dict[str, ToolSchema] = field(default_factory=dict) + description: str | None = None + required: list[str] = field(default_factory=list) + + @classmethod + def create(cls, raw: Any) -> ToolSchema: + """Create a schema from raw config data.""" + if isinstance(raw, ToolSchema): + return raw + if not isinstance(raw, Mapping): + raise TypeError( + f"Tool schema definitions must be schema objects or mappings, got " + f"{type(raw)}" + ) + return cls( + type=raw.get("type", "object"), + properties={ + key: cls.create(value) + for key, value in raw.get("properties", {}).items() + }, + description=raw.get("description"), + required=raw.get("required", []), + ) + + def __post_init__(self): + """Validate schema fields.""" + if not self.type: + raise ValueError(f"{type(self).__name__}.type cannot be empty.") + if not isinstance(self.properties, dict): + raise ValueError(f"{type(self).__name__}.properties must be a dict.") + if not all(isinstance(value, ToolSchema) for value in self.properties.values()): + raise ValueError( + f"{type(self).__name__}.properties values must be ToolSchema." 
+ ) + if self.description is not None and not isinstance(self.description, str): + raise ValueError( + f"{type(self).__name__}.description must be a string when specified." + ) + if not isinstance(self.required, list) or not all( + isinstance(param, str) for param in self.required + ): + raise ValueError( + f"{type(self).__name__}.required must be a list of strings." + ) + missing_required = set(self.required) - set(self.properties) + if missing_required: + raise ValueError( + f"{type(self).__name__}.required contains unknown properties: " + f"{sorted(missing_required)}" + ) + + def to_dict(self) -> dict[str, Any]: + """Convert to a JSON-schema-shaped dict.""" + schema: dict[str, Any] = {"type": self.type} + if self.properties: + schema["properties"] = { + key: value.to_dict() for key, value in self.properties.items() + } + if self.description is not None: + schema["description"] = self.description + if self.required: + schema["required"] = list(self.required) + return schema + + +@dataclass +class Tool(BaseParams): + """Tool schema owned by an environment.""" + + id: str + name: str + description: str + parameters: ToolSchema = field(default_factory=ToolSchema) + output_schema: ToolSchema | None = None + read_only: bool = True + deterministic_outputs: list[DeterministicToolOutput] = field(default_factory=list) + + @classmethod + def create(cls, raw: Any) -> Tool: + """Create a tool from raw config data.""" + if isinstance(raw, Tool): + return raw + if not isinstance(raw, Mapping): + raise TypeError( + f"Tool definitions must be tool objects or mappings, got {type(raw)}" + ) + return cls( + id=raw["id"], + name=raw["name"], + description=raw["description"], + parameters=ToolSchema.create(raw.get("parameters", {})), + output_schema=( + ToolSchema.create(raw["output_schema"]) + if raw.get("output_schema") is not None + else None + ), + read_only=raw.get("read_only", True), + deterministic_outputs=raw.get("deterministic_outputs", []), + ) + + def __post_init__(self): 
+ """Validate common tool fields.""" + if not self.id: + raise ValueError(f"{type(self).__name__}.id cannot be empty.") + if not self.name: + raise ValueError(f"{type(self).__name__}.name cannot be empty.") + if not self.description: + raise ValueError(f"{type(self).__name__}.description cannot be empty.") + if not isinstance(self.parameters, ToolSchema): + self.parameters = ToolSchema.create(self.parameters) + if self.output_schema is not None and not isinstance( + self.output_schema, ToolSchema + ): + self.output_schema = ToolSchema.create(self.output_schema) + + def to_llm_schema(self) -> dict[str, Any]: + """Export a provider-agnostic schema for LLM tool registration.""" + schema: dict[str, Any] = { + "name": self.name, + "description": self.description, + "parameters": self.parameters.to_dict(), + } + if self.output_schema is not None: + schema["output_schema"] = self.output_schema.to_dict() + return schema + + +@dataclass +class ToolResult(BaseParams): + """Result returned by an environment step.""" + + output: str | dict[str, Any] | None + updated_state: dict[str, Any] | None = None diff --git a/src/oumi/environments/deterministic_environment.py b/src/oumi/environments/deterministic_environment.py new file mode 100644 index 0000000000..5c4c1dea93 --- /dev/null +++ b/src/oumi/environments/deterministic_environment.py @@ -0,0 +1,73 @@ +# Copyright 2025 - Oumi +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
@dataclass
class DeterministicEnvironment(BaseEnvironment):
    """Lookup-table environment whose tools answer from fixed entries.

    Every tool must carry at least one ``deterministic_outputs`` entry, and no
    two entries of the same tool may share the same input.
    """

    ENVIRONMENT_TYPE: ClassVar[str] = "deterministic"
    type: str = field(init=False, default=ENVIRONMENT_TYPE)

    def _coerce_tools(self, tools: list[Any]) -> list[Tool]:
        """Turn raw tool definitions into ``Tool`` objects with typed outputs."""

        def as_entry(candidate: Any) -> DeterministicToolOutput:
            # Already-typed entries pass through; anything else is expanded as
            # keyword arguments.
            if isinstance(candidate, DeterministicToolOutput):
                return candidate
            return DeterministicToolOutput(**candidate)

        typed_tools: list[Tool] = []
        for item in tools:
            tool = Tool.create(item)
            tool.deterministic_outputs = list(map(as_entry, tool.deterministic_outputs))
            typed_tools.append(tool)
        return typed_tools

    def __post_init__(self):
        """Check that each tool has outputs and that their inputs are unique.

        Raises:
            ValueError: If a tool has no deterministic outputs, or two of its
                entries serialize to the same input.
        """
        super().__post_init__()
        for tool in self.tools:
            entries = tool.deterministic_outputs
            if not entries:
                raise ValueError(
                    f"Deterministic tool '{tool.id}' must have at least one "
                    "deterministic_output entry."
                )
            # Key-sorted JSON gives an order-independent fingerprint per input.
            fingerprints: set[str] = set()
            for entry in entries:
                fingerprint = json.dumps(entry.input, sort_keys=True)
                if fingerprint in fingerprints:
                    raise ValueError(
                        f"Deterministic tool '{tool.id}' has duplicate "
                        f"deterministic input entry: {entry.input}"
                    )
                fingerprints.add(fingerprint)

    def step(self, tool_id: str, arguments: dict[str, Any]) -> ToolResult:
        """Return the output of the first entry matching ``arguments``.

        A call that matches no entry yields ``ToolResult(output=None)``.
        """
        tool = self._get_tool_or_raise(tool_id)
        hit = next(
            (entry for entry in tool.deterministic_outputs if entry.matches(arguments)),
            None,
        )
        if hit is None:
            return ToolResult(output=None)
        return ToolResult(output=hit.output)
def _json_values_equal(left: Any, right: Any) -> bool:
    """Compare two JSON-like values using JSON rather than Python equality.

    Python treats ``True == 1`` and ``False == 0``; JSON does not. Booleans are
    therefore only equal to booleans, and containers are compared recursively
    so the rule also holds for nested values.
    """
    if isinstance(left, bool) or isinstance(right, bool):
        return isinstance(left, bool) and isinstance(right, bool) and left == right
    if isinstance(left, dict) and isinstance(right, dict):
        return left.keys() == right.keys() and all(
            _json_values_equal(left[key], right[key]) for key in left
        )
    if isinstance(left, list) and isinstance(right, list):
        return len(left) == len(right) and all(
            _json_values_equal(a, b) for a, b in zip(left, right)
        )
    return left == right


def _validate_json_schema_value(
    value: Any,
    schema: dict[str, Any],
    path: str = "$",
) -> None:
    """Validate a JSON-like value against a minimal schema subset.

    Supported keywords: ``type`` (object, array, string, integer, number,
    boolean, null), ``properties``, ``required``, ``items`` and ``enum``.
    Any other ``type`` value, or a missing one, skips the type check.

    Args:
        value: The value to validate.
        schema: The schema (or sub-schema) to validate against.
        path: JSONPath-like location of ``value``, used in error messages.

    Raises:
        ValueError: If ``value`` violates ``schema``.
    """
    expected_type = schema.get("type")
    if expected_type == "object":
        if not isinstance(value, dict):
            raise ValueError(f"{path} must be an object.")
        required = schema.get("required", [])
        for key in required:
            if key not in value:
                raise ValueError(f"{path}.{key} is required.")
        for key, child_value in value.items():
            # Keys without a declared property schema are accepted as-is.
            child_schema = schema.get("properties", {}).get(key)
            if child_schema is not None:
                _validate_json_schema_value(child_value, child_schema, f"{path}.{key}")
    elif expected_type == "array":
        if not isinstance(value, list):
            raise ValueError(f"{path} must be an array.")
        item_schema = schema.get("items")
        if item_schema is not None:
            for idx, item in enumerate(value):
                _validate_json_schema_value(item, item_schema, f"{path}[{idx}]")
    elif expected_type == "string":
        if not isinstance(value, str):
            raise ValueError(f"{path} must be a string.")
    elif expected_type == "integer":
        # bool is a subclass of int in Python but is not a JSON integer.
        if isinstance(value, bool) or not isinstance(value, int):
            raise ValueError(f"{path} must be an integer.")
    elif expected_type == "number":
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            raise ValueError(f"{path} must be a number.")
    elif expected_type == "boolean":
        if not isinstance(value, bool):
            raise ValueError(f"{path} must be a boolean.")
    elif expected_type == "null":
        if value is not None:
            raise ValueError(f"{path} must be null.")

    enum_values = schema.get("enum")
    # Plain ``value in enum_values`` would accept True for an enum of [1]
    # (and 1 for [True]), contradicting the bool/int separation above.
    if enum_values is not None and not any(
        _json_values_equal(value, candidate) for candidate in enum_values
    ):
        raise ValueError(f"{path} must be one of {enum_values}.")


@dataclass
class SyntheticStateParams(BaseParams):
    """Optional state configuration for a synthetic environment."""

    # Schema the state must satisfy; only checked when initial_state is also set.
    state_schema: dict[str, Any] | None = None
    # Starting state, deep-copied into the environment at construction.
    initial_state: dict[str, Any] | None = None

    def __post_init__(self):
        """Validate ``initial_state`` against ``state_schema`` when both are set.

        Raises:
            ValueError: If the initial state violates the schema.
        """
        if self.state_schema is not None and self.initial_state is not None:
            _validate_json_schema_value(self.initial_state, self.state_schema)


@dataclass
class SyntheticEnvironment(BaseEnvironment):
    """LLM-simulated environment with optional mutable state.

    ``cache_by_input`` may be left unset: it then resolves to ``True`` for a
    stateless environment and ``False`` when ``state_params`` is provided, so
    a stateful config does not have to disable caching explicitly. Explicitly
    enabling the cache together with ``state_params`` is still rejected.
    """

    ENVIRONMENT_TYPE: ClassVar[str] = "synthetic"
    type: str = field(init=False, default=ENVIRONMENT_TYPE)
    system_prompt: str = ""
    state_params: SyntheticStateParams | dict[str, Any] | None = None
    # None means "decide from statefulness"; always a bool after __post_init__.
    cache_by_input: bool | None = None

    def __post_init__(self):
        """Validate synthetic-only fields after common environment validation.

        Raises:
            ValueError: If ``system_prompt`` is empty, if caching is explicitly
                enabled for a stateful environment, or if any tool defines
                ``deterministic_outputs``.
        """
        if isinstance(self.state_params, dict):
            self.state_params = SyntheticStateParams(**self.state_params)
        if self.cache_by_input is None:
            # The previous unconditional default of True made every stateful
            # environment fail the check below unless the caller remembered
            # to pass cache_by_input=False.
            self.cache_by_input = self.state_params is None
        super().__post_init__()
        self._cache: dict[str, ToolResult] = {}
        self._state: dict[str, Any] | None = (
            copy.deepcopy(self.state_params.initial_state)
            if self.state_params is not None
            and self.state_params.initial_state is not None
            else None
        )
        if not self.system_prompt:
            raise ValueError("SyntheticEnvironment.system_prompt cannot be empty.")
        if self.state_params is not None and self.cache_by_input:
            raise ValueError(
                "SyntheticEnvironment.cache_by_input must be False when "
                "state_params is provided."
            )
        for tool in self.tools:
            if tool.deterministic_outputs:
                raise ValueError(
                    f"Synthetic tool '{tool.id}' cannot define deterministic_outputs."
                )

    @property
    def current_state(self) -> dict[str, Any] | None:
        """Return a deep copy of the in-memory state, or None when stateless."""
        if self._state is None:
            return None
        return copy.deepcopy(self._state)

    @staticmethod
    def _cache_key(tool_id: str, arguments: dict[str, Any]) -> str:
        """Build a stable cache key from tool id and arguments.

        Arguments are serialized with sorted keys so key order is irrelevant.
        """
        return f"{tool_id}::{json.dumps(arguments, sort_keys=True)}"

    def _resolve_cached(
        self, tool_id: str, arguments: dict[str, Any]
    ) -> ToolResult | None:
        """Look up a cached result for the given tool call.

        Returns:
            A deep copy of the cached result, or None when caching is disabled
            or nothing is cached for this call.
        """
        if not self.cache_by_input:
            return None
        result = self._cache.get(self._cache_key(tool_id, arguments))
        if result is None:
            return None
        # Hand out copies so callers cannot mutate the cached entry.
        return ToolResult(
            output=copy.deepcopy(result.output),
            updated_state=copy.deepcopy(result.updated_state),
        )

    def _cache_result(
        self, tool_id: str, arguments: dict[str, Any], result: ToolResult
    ) -> None:
        """Store a deep copy of a generated result; no-op when caching is off."""
        if not self.cache_by_input:
            return
        self._cache[self._cache_key(tool_id, arguments)] = ToolResult(
            output=copy.deepcopy(result.output),
            updated_state=copy.deepcopy(result.updated_state),
        )

    def step(self, tool_id: str, arguments: dict[str, Any]) -> ToolResult:
        """Execute a synthetic tool call.

        Currently a stub: the tool id is resolved via ``_get_tool_or_raise``
        and then NotImplementedError is raised.

        Raises:
            NotImplementedError: Always, once the tool id has been resolved.
        """
        self._get_tool_or_raise(tool_id)
        raise NotImplementedError("SyntheticEnvironment.step() is not implemented yet.")
# NOTE(review): JsonPatchValidationError derives from Exception, not from
# JsonPatchError, so `except JsonPatchError` will not catch validation
# failures -- confirm this split is intended.
class JsonPatchError(Exception):
    """Raised when a JSON Patch is malformed or cannot be applied."""


class JsonPatchValidationError(Exception):
    """Raised when the patched document fails JSON Schema validation."""


def apply_json_patch(
    document: dict[str, Any],
    patch: list[dict[str, Any]],
    schema: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Apply an RFC 6902 JSON Patch to a document.

    Placeholder: the body only raises NotImplementedError.

    Args:
        document: The document to patch.
        patch: The list of patch operations.
        schema: Optional schema; presumably used to validate the patched
            document (see JsonPatchValidationError) -- TODO confirm once
            implemented.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError


def parse_patch_response(text: str) -> list[dict[str, Any]] | None:
    """Extract a JSON Patch array from LLM-generated text.

    Placeholder: the body only raises NotImplementedError. The return
    annotation allows None; presumably for text without a usable patch --
    TODO confirm once implemented.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError
+ +from typing import Any + +import pytest + +from oumi.core.configs.environment_config import EnvironmentConfig +from oumi.environments import ( + BaseEnvironment, + DeterministicEnvironment, + DeterministicToolOutput, + SyntheticEnvironment, + SyntheticStateParams, + Tool, + ToolResult, + ToolSchema, +) + + +def _make_deterministic_tool(**overrides: Any) -> Tool: + defaults: dict[str, Any] = dict( + id="tool1", + name="MyTool", + description="A tool", + deterministic_outputs=[ + DeterministicToolOutput(input={"id": "01"}, output={"msg": "ok"}), + ], + ) + defaults.update(overrides) + return Tool(**defaults) + + +def _make_synthetic_tool(**overrides: Any) -> Tool: + defaults: dict[str, Any] = dict( + id="tool2", + name="GenTool", + description="A generated tool", + ) + defaults.update(overrides) + return Tool(**defaults) + + +def _make_state_schema() -> dict[str, Any]: + return { + "type": "object", + "properties": { + "files": { + "type": "object", + "properties": {"count": {"type": "integer"}}, + "required": ["count"], + } + }, + "required": ["files"], + } + + +def test_deterministic_tool_output_allows_empty_input(): + entry = DeterministicToolOutput(input={}, output={"msg": "ok"}) + assert entry.input == {} + + +def test_deterministic_tool_output_allows_empty_output(): + entry = DeterministicToolOutput(input={"id": "1"}, output={}) + assert entry.output == {} + + +def test_deterministic_tool_output_matches_exact(): + entry = DeterministicToolOutput( + input={"id": "01", "status": "pending"}, + output={"message": "Order is pending"}, + ) + assert entry.matches({"id": "01", "status": "pending"}) is True + assert entry.matches({"status": "pending", "id": "01"}) is True + + +def test_deterministic_tool_output_no_match(): + entry = DeterministicToolOutput( + input={"id": "01"}, + output={"message": "ok"}, + ) + assert entry.matches({"id": "02"}) is False + assert entry.matches({"id": "01", "extra": "arg"}) is False + + +@pytest.mark.parametrize("field,value", 
[("id", ""), ("name", ""), ("description", "")]) +def test_tool_empty_field_raises(field, value): + with pytest.raises(ValueError, match=f"{field} cannot be empty"): + Tool(**{"id": "t", "name": "T", "description": "d", **{field: value}}) + + +def test_tool_to_llm_schema(): + tool = Tool( + id="search", + name="Search", + description="Search the catalog.", + parameters=ToolSchema( + type="object", + properties={"query": ToolSchema(type="string")}, + required=["query"], + ), + ) + assert tool.to_llm_schema() == { + "name": "Search", + "description": "Search the catalog.", + "parameters": { + "type": "object", + "properties": {"query": {"type": "string"}}, + "required": ["query"], + }, + } + + +def test_tool_to_llm_schema_includes_output_schema(): + tool = Tool( + id="search", + name="Search", + description="Search the catalog.", + parameters=ToolSchema(type="object", properties={}), + output_schema=ToolSchema( + type="object", + properties={"result": ToolSchema(type="string")}, + ), + ) + assert tool.to_llm_schema() == { + "name": "Search", + "description": "Search the catalog.", + "parameters": {"type": "object"}, + "output_schema": { + "type": "object", + "properties": {"result": {"type": "string"}}, + }, + } + + +def test_tool_create_coerces_deterministic_outputs(): + env = BaseEnvironment.create( + { + "id": "lookup", + "type": "deterministic", + "name": "Lookup", + "description": "Lookup tools", + "tools": [ + { + "id": "policy", + "name": "Policy", + "description": "Look up policy.", + "deterministic_outputs": [ + {"input": {"id": "1"}, "output": {"result": "ok"}} + ], + } + ], + } + ) + assert isinstance(env, DeterministicEnvironment) + assert isinstance(env.tools[0].deterministic_outputs[0], DeterministicToolOutput) + + +def test_tool_create_reads_extended_tool_fields(): + tool = Tool.create( + { + "id": "policy", + "name": "Policy", + "description": "Look up policy.", + "parameters": { + "type": "object", + "properties": { + "policy_id": {"type": "string"}, 
+ }, + "required": ["policy_id"], + }, + "output_schema": {"type": "object", "properties": {}}, + } + ) + assert tool.parameters.required == ["policy_id"] + assert tool.output_schema == ToolSchema(type="object", properties={}) + + +def test_tool_schema_to_dict(): + schema = ToolSchema( + type="object", + properties={ + "query": ToolSchema(type="string"), + }, + description="Tool input schema.", + required=["query"], + ) + assert schema.to_dict() == { + "type": "object", + "properties": {"query": {"type": "string"}}, + "description": "Tool input schema.", + "required": ["query"], + } + + +def test_tool_schema_create_recursively_coerces_nested_properties(): + schema = ToolSchema.create( + { + "type": "object", + "properties": { + "customer": { + "type": "object", + "properties": { + "email": {"type": "string", "description": "Customer email."} + }, + "required": ["email"], + } + }, + "required": ["customer"], + } + ) + assert isinstance(schema.properties["customer"], ToolSchema) + assert isinstance(schema.properties["customer"].properties["email"], ToolSchema) + assert schema.to_dict() == { + "type": "object", + "properties": { + "customer": { + "type": "object", + "properties": { + "email": { + "type": "string", + "description": "Customer email.", + } + }, + "required": ["email"], + } + }, + "required": ["customer"], + } + + +def test_tool_schema_required_must_exist_in_properties(): + with pytest.raises(ValueError, match="required contains unknown properties"): + ToolSchema( + type="object", + properties={"query": ToolSchema(type="string")}, + required=["missing"], + ) + + +def test_synthetic_state_params_validates_initial_state_against_schema(): + with pytest.raises(ValueError, match=r"\$\.files\.count must be an integer"): + SyntheticStateParams( + state_schema=_make_state_schema(), + initial_state={"files": {"count": "bad"}}, + ) + + +def test_synthetic_state_params_accepts_partial_inputs(): + assert ( + 
SyntheticStateParams(state_schema=_make_state_schema()).state_schema is not None + ) + assert SyntheticStateParams( + initial_state={"files": {"count": 1}} + ).initial_state == {"files": {"count": 1}} + + +def test_synthetic_environment_valid_stateless(): + env = SyntheticEnvironment( + id="faq", + name="FAQ", + description="FAQ tools", + system_prompt="Answer FAQs.", + tools=[Tool(id="answer", name="Answer", description="Answer a FAQ.")], + ) + assert env.type == "synthetic" + assert env.state_params is None + assert env.current_state is None + assert isinstance(env.tools[0], Tool) + + +def test_synthetic_environment_valid_stateful(): + env = SyntheticEnvironment( + id="filesystem", + name="Filesystem", + description="A simple filesystem", + system_prompt="You manage a filesystem.", + state_params=SyntheticStateParams( + state_schema=_make_state_schema(), + initial_state={"files": {"count": 1}}, + ), + cache_by_input=False, + tools=[Tool(id="read", name="Read", description="Read files.")], + ) + assert env.current_state == {"files": {"count": 1}} + + +def test_synthetic_environment_coerces_dict_tools(): + env = SyntheticEnvironment( + id="fs", + name="FS", + description="d", + system_prompt="p", + tools=[{"id": "read", "name": "Read", "description": "Read files."}], # type: ignore[list-item] + ) + assert isinstance(env.tools[0], Tool) + + +def test_synthetic_environment_empty_system_prompt_raises(): + with pytest.raises(ValueError, match="system_prompt cannot be empty"): + SyntheticEnvironment(id="x", name="n", description="d", system_prompt="") + + +def test_synthetic_environment_rejects_deterministic_outputs(): + with pytest.raises(ValueError, match="cannot define deterministic_outputs"): + SyntheticEnvironment( + id="x", + name="n", + description="d", + system_prompt="p", + tools=[_make_deterministic_tool()], + ) + + +def test_synthetic_environment_rejects_cache_when_stateful(): + with pytest.raises(ValueError, match="cache_by_input must be False"): + 
SyntheticEnvironment( + id="x", + name="n", + description="d", + system_prompt="p", + state_params=SyntheticStateParams(), + cache_by_input=True, + ) + + +def test_synthetic_environment_cache_round_trip(): + env = SyntheticEnvironment( + id="weather", + name="Weather", + description="Weather API", + system_prompt="Simulate weather.", + cache_by_input=True, + tools=[_make_synthetic_tool(id="get_weather")], + ) + result = ToolResult(output={"temp": 72}) + env._cache_result("get_weather", {"city": "SF"}, result) + cached = env._resolve_cached("get_weather", {"city": "SF"}) + assert cached == result + assert cached is not result + + +def test_synthetic_environment_step_unknown_tool_raises(): + env = SyntheticEnvironment( + id="faq", + name="FAQ", + description="FAQ tools", + system_prompt="Answer FAQs.", + tools=[_make_synthetic_tool(id="answer")], + ) + with pytest.raises(ValueError, match="Tool 'missing' not found"): + env.step("missing", {}) + + +def test_synthetic_environment_step_known_tool_is_stub(): + env = SyntheticEnvironment( + id="faq", + name="FAQ", + description="FAQ tools", + system_prompt="Answer FAQs.", + tools=[_make_synthetic_tool(id="answer")], + ) + with pytest.raises(NotImplementedError, match="not implemented yet"): + env.step("answer", {}) + + +def test_deterministic_environment_valid(): + env = DeterministicEnvironment( + id="lookup", + name="Lookup", + description="A deterministic lookup environment", + tools=[ + Tool( + id="policy", + name="Policy", + description="Look up policy.", + deterministic_outputs=[ + DeterministicToolOutput( + input={"id": "1"}, + output={"result": "ok"}, + ) + ], + ) + ], + ) + assert env.type == "deterministic" + assert isinstance(env.tools[0], Tool) + + +def test_deterministic_environment_requires_outputs_on_tool(): + with pytest.raises(ValueError, match="must have at least one"): + DeterministicEnvironment( + id="det_env", + name="Deterministic", + description="d", + 
tools=[_make_deterministic_tool(deterministic_outputs=[])], + ) + + +def test_deterministic_environment_duplicate_inputs_raises(): + outputs = [ + DeterministicToolOutput(input={"id": "01"}, output={"msg": "a"}), + DeterministicToolOutput(input={"id": "01"}, output={"msg": "b"}), + ] + with pytest.raises(ValueError, match="duplicate"): + DeterministicEnvironment( + id="det_env", + name="Deterministic", + description="d", + tools=[_make_deterministic_tool(deterministic_outputs=outputs)], + ) + + +def test_deterministic_environment_step_match(): + env = DeterministicEnvironment( + id="lookup", + name="Lookup", + description="A deterministic lookup environment", + tools=[ + _make_deterministic_tool( + deterministic_outputs=[ + DeterministicToolOutput( + input={"id": "01"}, output={"msg": "pending"} + ), + DeterministicToolOutput( + input={"id": "02"}, output={"msg": "delivered"} + ), + ] + ) + ], + ) + assert env.step("tool1", {"id": "01"}) == ToolResult(output={"msg": "pending"}) + assert env.step("tool1", {"id": "02"}) == ToolResult(output={"msg": "delivered"}) + + +def test_deterministic_environment_step_no_match(): + env = DeterministicEnvironment( + id="lookup", + name="Lookup", + description="A deterministic lookup environment", + tools=[_make_deterministic_tool()], + ) + assert env.step("tool1", {"id": "99"}) == ToolResult(output=None) + + +def test_deterministic_environment_supports_empty_argument_match(): + env = DeterministicEnvironment( + id="lookup", + name="Lookup", + description="A deterministic lookup environment", + tools=[ + Tool( + id="ping", + name="Ping", + description="Zero-arg tool.", + deterministic_outputs=[ + DeterministicToolOutput(input={}, output={}), + ], + ) + ], + ) + assert env.step("ping", {}) == ToolResult(output={}) + + +def test_environment_empty_id_raises(): + with pytest.raises(ValueError, match="id cannot be empty"): + SyntheticEnvironment(id="", name="n", description="d", system_prompt="p") + + +def 
test_environment_empty_name_raises(): + with pytest.raises(ValueError, match="name cannot be empty"): + SyntheticEnvironment(id="x", name="", description="d", system_prompt="p") + + +def test_environment_empty_description_raises(): + with pytest.raises(ValueError, match="description cannot be empty"): + SyntheticEnvironment(id="x", name="n", description="", system_prompt="p") + + +def test_environment_duplicate_tool_ids_raises(): + with pytest.raises(ValueError, match="duplicate tool id 'dup'"): + SyntheticEnvironment( + id="env2", + name="Env 2", + description="d", + system_prompt="p", + tools=[ + Tool(id="dup", name="Read", description="Read files."), + Tool(id="dup", name="Write", description="Write files."), + ], + ) + + +def test_environment_config_duplicate_tool_ids_across_envs_raises(): + env1 = SyntheticEnvironment( + id="env1", + name="Env 1", + description="d", + system_prompt="p", + tools=[Tool(id="dup", name="Read", description="Read files.")], + ) + env2 = SyntheticEnvironment( + id="env2", + name="Env 2", + description="d", + system_prompt="p", + tools=[Tool(id="dup", name="Write", description="Write files.")], + ) + with pytest.raises(ValueError, match="duplicate tool id 'dup'"): + EnvironmentConfig(environments=[env1, env2]) + + +def test_environment_config_tool_environment_map(): + env = SyntheticEnvironment( + id="faq", + name="FAQ", + description="FAQ tools", + system_prompt="Answer FAQs.", + tools=[_make_synthetic_tool(id="answer_faq")], + ) + config = EnvironmentConfig(environments=[env]) + assert config.tool_environment_map == {"answer_faq": "faq"} + + +def test_base_environment_create_routes_synthetic(): + env = BaseEnvironment.create( + { + "id": "faq", + "type": "synthetic", + "name": "FAQ", + "description": "FAQ tools", + "system_prompt": "Answer FAQs.", + "tools": [{"id": "answer", "name": "Answer", "description": "Answer."}], + } + ) + assert isinstance(env, SyntheticEnvironment) + + +def 
test_base_environment_create_routes_deterministic(): + env = BaseEnvironment.create( + { + "id": "lookup", + "type": "deterministic", + "name": "Lookup", + "description": "Lookup tools", + "tools": [ + { + "id": "policy", + "name": "Policy", + "description": "Look up policy.", + "deterministic_outputs": [ + {"input": {"id": "1"}, "output": {"result": "ok"}} + ], + } + ], + } + ) + assert isinstance(env, DeterministicEnvironment) + + +def test_base_environment_create_missing_type_raises(): + with pytest.raises(ValueError, match="must include a 'type' field"): + BaseEnvironment.create({"id": "faq"}) + + +def test_base_environment_create_unsupported_type_raises(): + with pytest.raises(ValueError, match="Unsupported environment type"): + BaseEnvironment.create({"id": "faq", "type": "unknown"}) diff --git a/tests/unit/core/configs/test_synthesis_config.py b/tests/unit/core/configs/test_synthesis_config.py index 43fd0b0d79..f6ff5cc1d9 100644 --- a/tests/unit/core/configs/test_synthesis_config.py +++ b/tests/unit/core/configs/test_synthesis_config.py @@ -14,9 +14,19 @@ import pytest +from oumi.core.configs.environment_config import EnvironmentConfig from oumi.core.configs.inference_config import InferenceConfig -from oumi.core.configs.params.synthesis_params import GeneralSynthesisParams +from oumi.core.configs.params.synthesis_params import ( + GeneralSynthesisParams, + MultiTurnAttribute, +) from oumi.core.configs.synthesis_config import SynthesisConfig, SynthesisStrategy +from oumi.core.types.conversation import Role +from oumi.environments import ( + SyntheticEnvironment, + SyntheticStateParams, + Tool, +) def test_default_synthesis_config(): @@ -70,3 +80,246 @@ def test_invalid_output_path(): with pytest.raises(ValueError, match="Output path is not supported"): SynthesisConfig(inference_config=inference_config) + + +def _make_faq_tool() -> Tool: + return Tool( + id="answer_faq", + name="AnswerFAQ", + description="Answer a FAQ question.", + ) + + +def 
test_synthesis_config_with_top_level_environment_config(): + env_config = EnvironmentConfig( + environments=[ + SyntheticEnvironment( + id="faq", + name="FAQ", + description="FAQ tools", + system_prompt="Answer FAQs.", + tools=[_make_faq_tool()], + ) + ] + ) + params = GeneralSynthesisParams() + params.multiturn_attributes = [] + + config = SynthesisConfig( + strategy_params=params, + environment_config=env_config, + ) + + assert config.environment_config is not None + assert config.environment_config == env_config + assert config.environment_config.tool_environment_map == {"answer_faq": "faq"} + + +def test_synthesis_config_loads_environment_config_from_path(tmp_path): + env_config_path = tmp_path / "environments.yaml" + env_config = EnvironmentConfig( + environments=[ + SyntheticEnvironment( + id="faq", + name="FAQ", + description="FAQ tools", + system_prompt="Answer FAQs.", + tools=[_make_faq_tool()], + ) + ] + ) + env_config.to_yaml(env_config_path) + + config = SynthesisConfig(environment_config_path=str(env_config_path)) + + assert config.environment_config is not None + assert config.environment_config.all_tools[0].id == "answer_faq" + + +def test_synthesis_config_validates_available_tools(): + env_config = EnvironmentConfig( + environments=[ + SyntheticEnvironment( + id="faq", + name="FAQ", + description="FAQ tools", + system_prompt="Answer FAQs.", + tools=[_make_faq_tool()], + ) + ] + ) + params = GeneralSynthesisParams( + multiturn_attributes=[ + MultiTurnAttribute( + id="chat", + min_turns=1, + max_turns=2, + role_instruction_messages={ + Role.USER: "You are a user.", + Role.ASSISTANT: "You are an assistant.", + }, + available_tools=["answer_faq"], + ) + ] + ) + + config = SynthesisConfig( + strategy_params=params, + environment_config=env_config, + ) + + assert config.environment_config is not None + assert config.environment_config.all_tools[0].id == "answer_faq" + assert params.multiturn_attributes is not None + mt_attr = 
params.multiturn_attributes[0] + assert [t.id for t in config.resolve_multiturn_tools(mt_attr)] == ["answer_faq"] + + +def test_synthesis_config_requires_environment_config_for_available_tools(): + params = GeneralSynthesisParams( + multiturn_attributes=[ + MultiTurnAttribute( + id="chat", + min_turns=1, + max_turns=2, + role_instruction_messages={ + Role.USER: "You are a user.", + Role.ASSISTANT: "You are an assistant.", + }, + available_tools=["answer_faq"], + ) + ] + ) + + with pytest.raises(ValueError, match="Environment or tool references require"): + SynthesisConfig(strategy_params=params) + + +def test_synthesis_config_validates_available_environments(): + env_config = EnvironmentConfig( + environments=[ + SyntheticEnvironment( + id="faq", + name="FAQ", + description="FAQ tools", + system_prompt="Answer FAQs.", + tools=[_make_faq_tool()], + ) + ] + ) + params = GeneralSynthesisParams( + multiturn_attributes=[ + MultiTurnAttribute( + id="chat", + min_turns=1, + max_turns=2, + role_instruction_messages={ + Role.USER: "You are a user.", + Role.ASSISTANT: "You are an assistant.", + }, + available_environments=["missing_env"], + ) + ] + ) + + with pytest.raises(ValueError, match="references unknown environment"): + SynthesisConfig(strategy_params=params, environment_config=env_config) + + +def test_synthesis_config_restricts_tools_to_selected_environments(): + env_config = EnvironmentConfig( + environments=[ + SyntheticEnvironment( + id="faq", + name="FAQ", + description="FAQ tools", + system_prompt="Answer FAQs.", + tools=[_make_faq_tool()], + ), + SyntheticEnvironment( + id="files", + name="Files", + description="File tools", + system_prompt="Manage files.", + state_params=SyntheticStateParams(), + cache_by_input=False, + tools=[ + Tool( + id="read_file", + name="ReadFile", + description="Read a file.", + ) + ], + ), + ] + ) + params = GeneralSynthesisParams( + multiturn_attributes=[ + MultiTurnAttribute( + id="chat", + min_turns=1, + max_turns=2, + 
role_instruction_messages={ + Role.USER: "You are a user.", + Role.ASSISTANT: "You are an assistant.", + }, + available_environments=["faq"], + available_tools=["read_file"], + ) + ] + ) + + with pytest.raises(ValueError, match="references unknown tool"): + SynthesisConfig(strategy_params=params, environment_config=env_config) + + +def test_synthesis_config_resolves_all_tools_from_selected_environments(): + env_config = EnvironmentConfig( + environments=[ + SyntheticEnvironment( + id="faq", + name="FAQ", + description="FAQ tools", + system_prompt="Answer FAQs.", + tools=[_make_faq_tool()], + ), + SyntheticEnvironment( + id="files", + name="Files", + description="File tools", + system_prompt="Manage files.", + state_params=SyntheticStateParams(), + cache_by_input=False, + tools=[ + Tool( + id="read_file", + name="ReadFile", + description="Read a file.", + ) + ], + ), + ] + ) + mt_attr = MultiTurnAttribute( + id="chat", + min_turns=1, + max_turns=2, + role_instruction_messages={ + Role.USER: "You are a user.", + Role.ASSISTANT: "You are an assistant.", + }, + available_environments=["faq", "files"], + ) + config = SynthesisConfig( + strategy_params=GeneralSynthesisParams(multiturn_attributes=[mt_attr]), + environment_config=env_config, + ) + + assert [env.id for env in config.resolve_multiturn_environments(mt_attr)] == [ + "faq", + "files", + ] + assert [tool.id for tool in config.resolve_multiturn_tools(mt_attr)] == [ + "answer_faq", + "read_file", + ] diff --git a/tests/unit/core/synthesis/test_synthesis_pipeline.py b/tests/unit/core/synthesis/test_synthesis_pipeline.py index 38e560beb0..e85c08364b 100644 --- a/tests/unit/core/synthesis/test_synthesis_pipeline.py +++ b/tests/unit/core/synthesis/test_synthesis_pipeline.py @@ -314,6 +314,7 @@ def test_synthesize_with_multiturn_attributes( mock_conv_synth_class.assert_called_once_with( synthesis_config_with_multiturn_attributes.strategy_params, synthesis_config_with_multiturn_attributes.inference_config, + 
environment_config=None, ) mock_conv_synth.synthesize.assert_called_once_with(sample_dataset, multiturn_attr) assert all(multiturn_attr.id in item for item in result) @@ -324,7 +325,7 @@ def test_synthesize_with_multiturn_attributes( @patch("oumi.core.synthesis.synthesis_pipeline.DatasetPlanner") @patch("oumi.core.synthesis.synthesis_pipeline.AttributeTransformer") @patch("oumi.core.synthesis.synthesis_pipeline.AttributeSynthesizer") -def test_synthesize_with_multiturn_filtered_result_keeps_samples( +def test_synthesize_with_multiturn_filtered_result_drops_failed_samples( mock_attr_synth, mock_attr_transformer_class, mock_dataset_planner_class, @@ -333,7 +334,7 @@ def test_synthesize_with_multiturn_filtered_result_keeps_samples( mock_dataset_planner, mock_attribute_transformer, ): - """Test that filtered conversation results do not crash or drop samples.""" + """Test that failed multiturn rows are removed from final dataset.""" sample_dataset = [ {"id": "s1", "base": "v1"}, {"id": "s2", "base": "v2"}, @@ -360,13 +361,11 @@ def test_synthesize_with_multiturn_filtered_result_keeps_samples( pipeline = SynthesisPipeline(synthesis_config_with_multiturn_attributes) result = pipeline.synthesize() - assert len(result) == 2 + assert len(result) == 1 assert multiturn_attr.id in result[0] assert plan_key in result[0] - assert multiturn_attr.id not in result[1] - assert plan_key not in result[1] - assert result[1]["id"] == "s2" - assert result[1]["base"] == "v2" + assert result[0]["id"] == "s1" + assert result[0]["base"] == "v1" @patch("oumi.core.synthesis.synthesis_pipeline.DatasetPlanner")