From be38f5db6dc51091795897a084e320619ec14f8d Mon Sep 17 00:00:00 2001 From: aniruddh-alt Date: Wed, 1 Apr 2026 20:34:39 -0700 Subject: [PATCH 01/18] feat: add agentic tool synthesis skeleton Establishes the public API surface for tool-based conversation synthesis: - ToolAttribute, ToolEnvironmentAttribute config dataclasses with validation - ToolExecutor stub (parsing, execution, output formatting) - GeneratedToolEnvironment + EnvironmentRegistry stubs (stateful tools) - json_patch utility stub (RFC 6902) - MultiTurnAttribute extended with available_tools/max_tool_calls_per_turn - jsonpatch dependency added to pyproject.toml All new modules contain signatures only (raise NotImplementedError). Implementations follow in subsequent PRs. --- pyproject.toml | 56 +-- .../core/configs/params/synthesis_params.py | 139 ++++++++ src/oumi/core/configs/params/tool_params.py | 275 ++++++++++++++ src/oumi/core/synthesis/environment.py | 176 +++++++++ src/oumi/core/synthesis/synthesis_pipeline.py | 16 + src/oumi/core/synthesis/tool_executor.py | 185 ++++++++++ src/oumi/utils/json_patch.py | 39 ++ .../core/configs/params/test_tool_params.py | 337 ++++++++++++++++++ .../core/synthesis/test_synthesis_pipeline.py | 12 +- 9 files changed, 1188 insertions(+), 47 deletions(-) create mode 100644 src/oumi/core/configs/params/tool_params.py create mode 100644 src/oumi/core/synthesis/environment.py create mode 100644 src/oumi/core/synthesis/tool_executor.py create mode 100644 src/oumi/utils/json_patch.py create mode 100644 tests/unit/core/configs/params/test_tool_params.py diff --git a/pyproject.toml b/pyproject.toml index 7101fa57b6..c4786d14de 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,11 +44,12 @@ dependencies = [ "aioresponses>=0.7,<0.8", # User by inference engine tests "backoff>=2.2.1,<2.3", "click<8.4.0", # Used by CLI. 8.2.0 is currently unsupported by Typer. 
- "datasets>=3.2,<4.8.5", + "datasets>=3.2,<5", "greenlet", # Required by skypilot 0.11+ (sqlalchemy asyncio) "hdrhistogram>=0.10,<0.11", "httpx>=0.27,<1.0", # Used by deploy module (async HTTP client) "jsonlines", + "jsonpatch>=1.33,<2.0", "lm_eval[wandb]>=0.4,<0.5.0", "mlflow>=3.1", # >=3.1.4 requires Python3.10>= "numpy>=1.26,<2.4", # verl==0.5.0 depends on numpy<2.0.0 @@ -76,11 +77,13 @@ dependencies = [ "torchao>=0.12,<0.16", # Used by transformers "torchvision>=0.21,<0.26", # Used by some VLM-s (multimodal) "tqdm", - "transformers>=4.57,<5.5", - "trl>=0.24,<0.30", + # Llama Vision attention is broken as late as 4.48.2 if gradient checkpointing is + # enabled. See OPE-875 and https://github.com/huggingface/transformers/issues/36040. + "transformers>=4.57,<4.58", + "trl>=0.24,<0.27", "typer<0.24.2", # Used by CLI. Pinned due to sphinxcontrib-typer compatibility. OPE-1806 "typing_extensions", # Backports of typing updates - "uvicorn<0.43.0", # TODO: Remove on resolution of https://github.com/skypilot-org/skypilot/issues/7303 + "uvicorn<0.42.0", # TODO: Remove on resolution of https://github.com/skypilot-org/skypilot/issues/7303 "wandb>=0.21,<0.26", # Logging to Weights and Biases. ] @@ -120,12 +123,12 @@ docs = [ # Useful dependencies when running on GPU gpu = [ "liger-kernel>=0.6,<0.8", - "nvidia-ml-py>=13.580,<13.596", + "nvidia-ml-py>=13.580,<13.591", "bitsandbytes>=0.47,<0.50", # Used for QLora, and PagedAdam implementation # When updating verl version, make sure to also update the default config: # src/oumi/core/trainers/verl_trainer_config.yaml. - "verl>=0.5,<0.8; python_version >= '3.10'", # Used for the VERL_GRPO trainer - "vllm>=0.10,<0.19", # For VLLMInferenceEngine, and vLLM-powered GRPO training. + "verl>=0.5,<0.6; python_version >= '3.10'", # Used for the VERL_GRPO trainer + "vllm>=0.10,<0.11", # For VLLMInferenceEngine, and vLLM-powered GRPO training. 
"kernels>=0.11,<0.13", # For compute kernel implementations (e.g., attn_implementation in gold trainer) ] @@ -164,6 +167,7 @@ flash_attn = [ ] evaluation = [ + "alpaca-eval>=0.6,<0.7; python_version >= '3.10'", "langdetect", # leaderboard_ifeval "immutabledict", # leaderboard_ifeval "nltk>=3.9", # leaderboard_ifeval @@ -178,13 +182,7 @@ tune = ["optuna>=4.0.0,<5.0"] bitnet = ["onebitllms>=0.0.3"] -mcp = [ - "fastmcp>=3.0.0,<4", - "httpx>=0.28.1,<1", - "mcp>=1.25.0,<2", -] - -file_formats = ["pdf2image>=1.17,<1.18"] +file_formats = ["pdf2image>=1.17,<1.18", "python-poppler>=0.4,<0.5"] llama_cpp = ["llama-cpp-python>=0.3.5,<0.4"] @@ -196,9 +194,8 @@ torchdata = ["torchdata>=0.9,<0.10.0"] # CI targets ci_cpu = [ - "oumi[dev,docs,gcp,mcp,synthesis,tune,torchdata,file_formats]", - "vllm>=0.10,<0.19", # For VLLMInferenceEngine - "boto3", # For bedrock inference engine tests + "oumi[dev,docs,gcp,synthesis,tune,torchdata]", + "vllm>=0.10,<0.11", # For VLLMInferenceEngine # This may fail to install. As a temporary workaround, run: # CMAKE_ARGS="-DLLAVA_BUILD=OFF" pip install -U llama-cpp-python "llama-cpp-python>=0.3,<0.4", # For LlamaCppInferenceEngine @@ -206,13 +203,12 @@ ci_cpu = [ # llama-cpp-python is not compatible with the github # gpu actions runner, so we skip it for now ci_gpu = [ - "oumi[dev,docs,gcp,gpu,deepspeed,synthesis,tune,torchdata,file_formats]", - "boto3", # For bedrock inference engine tests + "oumi[dev,docs,gcp,gpu,deepspeed,synthesis,tune,torchdata]", + "alpaca-eval>=0.6,<0.7; python_version >= '3.10'", ] [project.scripts] oumi = "oumi.cli.main:run" -oumi-mcp = "oumi.mcp.server:main" [tool.ruff] extend-include = [ @@ -252,7 +248,6 @@ ban-relative-imports = "all" # Disallow all relative imports. 
[tool.ruff.lint.per-file-ignores] "tests/*" = ["D", "PTH"] # Ignore docstring checks in tests -"src/oumi/mcp/*" = ["ASYNC210", "ASYNC220", "ASYNC230", "ASYNC240"] # fastmcp uses sync-over-async patterns "src/oumi/utils/verl_model_merger.*" = [ "E501", # Line too long "D", # Ignore docstring checks @@ -303,7 +298,6 @@ unsupported-operator = "warn" # Type narrowing limitations with isinstance too-many-positional-arguments = "warn" # Loose typing (Callable) doesn't capture signatures parameter-already-assigned = "warn" # False positives with *args/**kwargs patterns - [tool.pytest.ini_options] asyncio_default_fixture_loop_scope = "function" testpaths = ["tests"] @@ -312,24 +306,6 @@ filterwarnings = [ # Warnings from mlflow dependency that we can't fix "ignore:Support for class-based `config` is deprecated.*:DeprecationWarning", "ignore:builtin type .* has no __module__ attribute:DeprecationWarning", - # torchao deprecation warnings - upstream issue, will be resolved in future torchao releases - "ignore:Importing from torchao.dtypes:DeprecationWarning", - "ignore:Importing BlockSparseLayout from torchao.dtypes:DeprecationWarning", - # torchdata datapipes deprecation - known limitation, pinned to <0.10 - "ignore:The 'datapipes', 'dataloader2' modules are deprecated:UserWarning", - # Ray state API deprecation - upstream issue in vllm/ray - "ignore:Ray state API is no longer experimental:DeprecationWarning", - # transformers deprecation warnings - upstream, will be resolved when upgrading to v5 - "ignore:The image_processor_class argument is deprecated:FutureWarning", - "ignore:The class `AutoModelForVision2Seq` is deprecated:FutureWarning", - # TRL warnings - version-aware code already handles these, warnings are from TRL internals - "ignore:To copy construct from a tensor, it is recommended to use:UserWarning", - "ignore:The `KTOConfig` is now located in `trl.experimental`:FutureWarning", - "ignore:The `KTOTrainer` is now located in `trl.experimental`:FutureWarning", - 
"ignore:The `GKDConfig` is now located in `trl.experimental`:FutureWarning", - "ignore:The `GKDTrainer` is now located in `trl.experimental`:FutureWarning", - # torchao quant_api invalid escape sequence - upstream issue - "ignore:invalid escape sequence:DeprecationWarning", ] markers = [ "e2e: Slow e2e integration tests", diff --git a/src/oumi/core/configs/params/synthesis_params.py b/src/oumi/core/configs/params/synthesis_params.py index 8de36f75ca..d15bcbac97 100644 --- a/src/oumi/core/configs/params/synthesis_params.py +++ b/src/oumi/core/configs/params/synthesis_params.py @@ -22,6 +22,10 @@ from typing import Any from oumi.core.configs.params.base_params import BaseParams +from oumi.core.configs.params.tool_params import ( + ToolAttribute, + ToolEnvironmentAttribute, +) from oumi.core.types.conversation import Conversation, Message, Role _SUPPORTED_DATASET_FILE_TYPES = {".jsonl", ".json", ".csv", ".parquet", ".tsv", ".xlsx"} @@ -474,6 +478,14 @@ class MultiTurnAttribute: Allows user to specify custom instructions for the planner while planning out the conversation.""" + available_tools: list[str] = field(default_factory=list) + """List of tool ids (from GeneralSynthesisParams.tools) available in this + conversation.""" + + max_tool_calls_per_turn: int = 50 + """Safety ceiling for tool calls per ASSISTANT turn. The agent naturally stops + when it decides no more tools are needed. This only prevents runaway loops.""" + def __post_init__(self): """Verifies/populates params.""" if not self.id: @@ -543,6 +555,17 @@ def __post_init__(self): "string." ) + if self.available_tools is not None: + if not isinstance(self.available_tools, list): + raise ValueError( + "MultiTurnAttribute.available_tools must be a list of tool names." + ) + for tool in self.available_tools: + if not isinstance(tool, str): + raise ValueError( + "MultiTurnAttribute.available_tools must be a list of strings." 
+ ) + class TransformationType(str, Enum): """Types of transformation strategies.""" @@ -770,6 +793,45 @@ class GeneralSynthesisParams(BaseParams): ] """ + tools: list[ToolAttribute] | None = None + """Tool definitions for agentic synthesis. + + Tools are defined here and referenced by id from + MultiTurnAttribute.available_tools. Each tool specifies its parameters + (input schema), output_schema, and output strategy (DETERMINISTIC or + GENERATED). + + Example:: + + tools = [ + ToolAttribute( + id="search", + name="WebSearch", + description="Searches the web for the given query.", + parameters={ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The search query.", + } + }, + "required": ["query"], + }, + output_strategy=ToolOutputStrategy.GENERATED, + generated_output=GeneratedToolOutput( + instruction="Return relevant search results.", + ), + ), + ] + """ + + environments: list[ToolEnvironmentAttribute] | None = None + """Environment definitions for stateful tool synthesis. + + Environments are stateful containers that tools operate on. Tools + reference environments by id via ToolAttribute.environment.""" + transformed_attributes: list[TransformedAttribute] | None = None """Transformation of existing attributes. 
@@ -912,6 +974,80 @@ def _check_passthrough_attribute_ids(self) -> None: self.passthrough_attributes = None return + def _check_environment_ids(self) -> None: + """Validate environment ids are unique.""" + if not self.environments: + self.environments = None + return + + env_ids = [env.id for env in self.environments] + if len(env_ids) != len(set(env_ids)): + seen: set[str] = set() + dupes = [e for e in env_ids if e in seen or seen.add(e)] # type: ignore[func-returns-value] + raise ValueError( + f"GeneralSynthesisParams.environments contains " + f"duplicate environment ids: {dupes}" + ) + + def _check_tool_environment_references(self) -> None: + """Validate that tool.environment references valid environment ids.""" + if not self.tools: + return + + env_ids: set[str] = set() + if self.environments: + env_ids = {env.id for env in self.environments} + + for tool in self.tools: + if tool.environment and tool.environment not in env_ids: + raise ValueError( + f"ToolAttribute '{tool.id}' references unknown " + f"environment '{tool.environment}'. " + f"Defined environment ids: {sorted(env_ids)}" + ) + + def _check_available_tools(self) -> None: + """Validate that available_tools ids reference defined tools.""" + if not self.multiturn_attributes: + self.multiturn_attributes = None + return + + # Collect all tool ids referenced by any multiturn attribute + all_referenced = [ + tool_id + for mt_attr in self.multiturn_attributes + for tool_id in mt_attr.available_tools + ] + + if not all_referenced: + if not self.tools: + self.tools = None + return + + if not self.tools: + raise ValueError( + "GeneralSynthesisParams.tools must be defined when " + "MultiTurnAttribute.available_tools is non-empty. 
" + f"Referenced tool ids: {sorted(set(all_referenced))}" + ) + + tool_id_list = [tool.id for tool in self.tools] + tool_ids = set(tool_id_list) + if len(tool_id_list) != len(tool_ids): + seen: set[str] = set() + dupes = [t for t in tool_id_list if t in seen or seen.add(t)] # type: ignore[func-returns-value] + raise ValueError( + f"GeneralSynthesisParams.tools contains duplicate tool ids: {dupes}" + ) + + for mt_attr in self.multiturn_attributes: + for tool_id in mt_attr.available_tools: + if tool_id not in tool_ids: + raise ValueError( + f"MultiTurnAttribute '{mt_attr.id}' references unknown " + f"tool '{tool_id}'. Defined tool ids: {sorted(tool_ids)}" + ) + def __post_init__(self): """Verifies/populates params.""" self._reserved_attribute_ids = self._get_reserved_attribute_ids() @@ -925,3 +1061,6 @@ def __post_init__(self): self._check_transformed_attribute_ids(all_attribute_ids) self._check_passthrough_attribute_ids() self._check_combination_sampling_sample_rates() + self._check_available_tools() + self._check_environment_ids() + self._check_tool_environment_references() diff --git a/src/oumi/core/configs/params/tool_params.py b/src/oumi/core/configs/params/tool_params.py new file mode 100644 index 0000000000..73b0aae838 --- /dev/null +++ b/src/oumi/core/configs/params/tool_params.py @@ -0,0 +1,275 @@ +# Copyright 2025 - Oumi +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Params for tool configuration in agentic synthesis.""" + +import math +from dataclasses import dataclass, field +from enum import Enum +from typing import Any + + +class ToolOutputStrategy(str, Enum): + """Strategy for how a tool produces its output.""" + + DETERMINISTIC = "deterministic" + """Output is selected from user-defined values. One value is selected + per conversation and reused for all calls to this tool within that + conversation.""" + + GENERATED = "generated" + """Output is generated by an LLM simulator conditioned on the tool schema, + arguments, conversation history, and optional environment state.""" + + ENVIRONMENT = "environment" + """Output comes from a stateful environment's step() method.""" + + +@dataclass +class DeterministicToolOutput: + """A possible canned output for a tool. + + Used with ToolOutputStrategy.DETERMINISTIC. One value is randomly selected + per conversation, weighted by sample_rate. Matches the sampling pattern + used by SampledAttributeValue in synthesis_params. + """ + + values: dict[str, Any] = field(default_factory=dict) + """Structured output values. Serialized to JSON by the system at output time.""" + + sample_rate: float | None = None + """Selection weight. If not specified, assumes uniform sampling + among all possible outputs. Must be between 0 and 1.""" + + def __post_init__(self): + """Verifies/populates params.""" + if not self.values: + raise ValueError("DeterministicToolOutput.values cannot be empty.") + if self.sample_rate is not None and ( + self.sample_rate < 0 or self.sample_rate > 1 + ): + raise ValueError( + "DeterministicToolOutput.sample_rate must be between 0 and 1." + ) + + +@dataclass +class GeneratedToolOutput: + """Configuration for LLM-simulated tool output. + + Used with ToolOutputStrategy.GENERATED. The LLM simulator receives + the tool definition, call arguments, conversation history, and this + config to produce a realistic output. 
+ """ + + instruction: str + """Prompt hint for the LLM simulator. Guides what kind of output to produce. + Example: "Return eligibility based on order status and 30-day return window." + """ + + def __post_init__(self): + """Verifies/populates params.""" + if not self.instruction: + raise ValueError("GeneratedToolOutput.instruction cannot be empty.") + + +@dataclass +class ToolEnvironmentAttribute: + """Defines a stateful environment for tool synthesis. + + Environments are stateful containers that tools operate on. + Each environment maintains a JSON state document that evolves + as tools read from and write to it. + """ + + id: str + """Unique identifier. Referenced by ToolAttribute.environment.""" + + name: str + """Display name (e.g. 'Filesystem', 'PropertyManagementSystem').""" + + description: str + """What this environment represents. Used in LLM prompts for + state generation and tool result generation.""" + + system_prompt: str + """Instructions for the LLM managing this environment's state. + Included in all _generate_result and _update_state prompts.""" + + state_schema: dict[str, Any] | None = None + """JSON Schema the state must conform to. If None, auto-generated + from bound tools at initialization time.""" + + initial_state: dict[str, Any] | None = None + """Starting state document. If None, auto-generated from schema. + Validated against state_schema when both are provided.""" + + def __post_init__(self): + """Verifies/populates params.""" + if not self.id: + raise ValueError("ToolEnvironmentAttribute.id cannot be empty.") + if not self.name: + raise ValueError("ToolEnvironmentAttribute.name cannot be empty.") + if not self.description: + raise ValueError("ToolEnvironmentAttribute.description cannot be empty.") + if not self.system_prompt: + raise ValueError("ToolEnvironmentAttribute.system_prompt cannot be empty.") + + +@dataclass +class ToolAttribute: + """Defines a tool available for agentic synthesis. 
+ + Tools are defined at the GeneralSynthesisParams level and referenced + by id from MultiTurnAttribute.available_tools. + + The output_strategy field determines which output config is used: + - DETERMINISTIC: uses deterministic_outputs (list, one sampled per conversation) + - GENERATED: uses generated_output (singular config for LLM simulator) + """ + + id: str + """Unique identifier for the tool. Referenced by available_tools.""" + + name: str + """Display name of the tool (e.g., "SearchOrders"). + Used in the tool catalog and output tool definitions.""" + + description: str + """What the tool does. Shown in the tool catalog, used in simulation + prompts, and included in output tool definitions.""" + + output_strategy: ToolOutputStrategy = ToolOutputStrategy.GENERATED + """How this tool produces its output.""" + + parameters: dict[str, Any] = field(default_factory=dict) + """JSON Schema for tool parameters. Matches the OpenAI/HuggingFace standard. + + Example:: + + { + "type": "object", + "properties": { + "customer_id": { + "type": "string", + "description": "The customer's ID" + } + }, + "required": ["customer_id"] + } + """ + + output_schema: dict[str, Any] = field(default_factory=dict) + """JSON Schema describing the tool's output structure. Shared across strategies. + + For DETERMINISTIC: used to validate that canned outputs conform to the schema. + For GENERATED: passed to the LLM simulator as structure guidance. + + Example:: + + { + "type": "object", + "properties": { + "eligible": {"type": "boolean"}, + "reason": {"type": "string"} + } + } + """ + + deterministic_outputs: list[DeterministicToolOutput] = field(default_factory=list) + """Possible canned outputs for DETERMINISTIC strategy. + + One is selected per conversation, weighted by sample_rate (uniform if unset). 
+ Selection happens once at conversation initialization, not per-call.""" + + generated_output: GeneratedToolOutput | None = None + """Configuration for LLM-simulated output for GENERATED strategy. + + The LLM simulator uses this config along with the tool call arguments + and conversation context to produce a realistic output.""" + + environment: str | None = None + """References a ToolEnvironmentAttribute.id. When set, output_strategy + must be ENVIRONMENT and tool results come from env.step().""" + + read_only: bool = True + """Whether this tool only reads (does not modify) its environment's state. + Only relevant when environment is set.""" + + def __post_init__(self): + """Verifies/populates params.""" + if not self.id: + raise ValueError("ToolAttribute.id cannot be empty.") + if not self.name: + raise ValueError("ToolAttribute.name cannot be empty.") + if not self.description: + raise ValueError("ToolAttribute.description cannot be empty.") + if self.output_strategy == ToolOutputStrategy.ENVIRONMENT: + if not self.environment: + raise ValueError( + "ToolAttribute.environment must be set " + "when output_strategy is ENVIRONMENT." + ) + return + + if self.environment: + raise ValueError( + "ToolAttribute.output_strategy must be ENVIRONMENT " + "when environment is set." + ) + + if self.output_strategy == ToolOutputStrategy.DETERMINISTIC: + if not self.deterministic_outputs: + raise ValueError( + "ToolAttribute.deterministic_outputs cannot be empty " + "when output_strategy is DETERMINISTIC." + ) + self._normalize_sample_rates() + + elif self.output_strategy == ToolOutputStrategy.GENERATED: + if not self.generated_output: + raise ValueError( + "ToolAttribute.generated_output must be provided " + "when output_strategy is GENERATED." + ) + + def _normalize_sample_rates(self) -> None: + """Normalize sample rates for deterministic outputs. 
+ + Matches the pattern used by SampledAttribute for consistency: + - Defined rates are summed + - Remaining probability is split uniformly among undefined rates + """ + sample_rates = [o.sample_rate for o in self.deterministic_outputs] + + defined_rate = 0.0 + undefined_count = 0 + for rate in sample_rates: + if rate is not None: + defined_rate += rate + else: + undefined_count += 1 + + if defined_rate > 1.0 and not math.isclose(defined_rate, 1.0): + raise ValueError( + "ToolAttribute.deterministic_outputs sample rates " + "must sum to at most 1.0." + ) + + if undefined_count > 0: + remaining = max(0.0, 1.0 - defined_rate) + per_undefined = remaining / undefined_count + for output_value in self.deterministic_outputs: + if output_value.sample_rate is None: + output_value.sample_rate = per_undefined diff --git a/src/oumi/core/synthesis/environment.py b/src/oumi/core/synthesis/environment.py new file mode 100644 index 0000000000..d243da5a5f --- /dev/null +++ b/src/oumi/core/synthesis/environment.py @@ -0,0 +1,176 @@ +# Copyright 2025 - Oumi +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Stateful environment and registry for agentic tool synthesis.""" + +from collections.abc import Callable +from typing import Any + +from oumi.core.configs.params.tool_params import ( + ToolAttribute, + ToolEnvironmentAttribute, +) +from oumi.core.types.conversation import Conversation, Message + +_MAX_STATE_UPDATE_RETRIES = 2 +_MAX_RESULT_RETRIES = 2 + + +class GeneratedToolEnvironment: + """Stateful environment for tool synthesis. + + Maintains a JSON state document that evolves as tools read/write to it. + Builds prompts and applies responses — does not call inference itself. + """ + + def __init__(self, config: ToolEnvironmentAttribute): + """Initialize the environment with config.""" + raise NotImplementedError + + @property + def state(self) -> dict[str, Any]: + """Current state of the environment.""" + raise NotImplementedError + + def summarize_for_planner(self) -> dict[str, Any]: + """Return a compact state view for planner grounding.""" + raise NotImplementedError + + def set_state(self, state: dict[str, Any], validate: bool = True) -> bool: + """Set state, optionally skipping schema validation.""" + raise NotImplementedError + + def set_schema(self, schema: dict[str, Any]) -> None: + """Set the state schema.""" + raise NotImplementedError + + def build_result_prompt( + self, + tool: ToolAttribute, + arguments: dict[str, Any], + retry: bool = False, + ) -> Conversation: + """Build the prompt for generating a tool result.""" + raise NotImplementedError + + def build_write_state_update_prompt( + self, + tool: ToolAttribute, + arguments: dict[str, Any], + retry: bool = False, + retry_error: str | None = None, + ) -> Conversation: + """Build prompt for state-first write.""" + raise NotImplementedError + + def build_write_result_prompt( + self, + tool: ToolAttribute, + arguments: dict[str, Any], + patch_ops: list[dict[str, Any]], + patch_succeeded: bool, + pre_patch_state: dict[str, Any] | None = None, + retry: bool = False, + retry_error: str | None = 
None, + ) -> Conversation: + """Build prompt for generating a write tool result after state update.""" + raise NotImplementedError + + def apply_result(self, response: Conversation) -> str: + """Extract the tool result text from an inference response.""" + raise NotImplementedError + + def apply_state_update_returning_patch( + self, response: Conversation + ) -> tuple[bool, list[dict[str, Any]], str | None]: + """Parse and apply a write-state update. Returns (succeeded, ops, error).""" + raise NotImplementedError + + def apply_state_update(self, response: Conversation) -> bool: + """Convenience wrapper: apply patch and return success bool.""" + raise NotImplementedError + + +class EnvironmentRegistry: + """Builds environments once, then copies N times for parallel samples.""" + + def __init__(self): + raise NotImplementedError + + def register_static(self, config: ToolEnvironmentAttribute) -> None: + """Register an environment that needs no LLM generation.""" + raise NotImplementedError + + def build( + self, + config: ToolEnvironmentAttribute, + tools: list[ToolAttribute], + inference_engine: Any, + inference_config: Any, + scenario_context: str | None = None, + ) -> None: + """Build a fully populated environment.""" + raise NotImplementedError + + def create_copies( + self, env_id: str, n: int + ) -> list[GeneratedToolEnvironment]: + """Return n independent deepcopies of a built environment.""" + raise NotImplementedError + + +def resolve_env_tool( + tool_executor: Any, + tool_call: dict, + idx_envs: dict[str, GeneratedToolEnvironment] | None, +) -> tuple[GeneratedToolEnvironment | None, ToolAttribute | None]: + """Look up (env, tool) for an environment-bound tool call.""" + raise NotImplementedError + + +def is_env_tool_missing_env(tool_executor: Any, tool_call: dict) -> bool: + """Return True if tool_call targets an ENVIRONMENT tool whose env is missing.""" + raise NotImplementedError + + +def serialize_env_states(envs: dict[str, GeneratedToolEnvironment]) -> 
str: + """Serialize planner-oriented environment summaries as stable JSON.""" + raise NotImplementedError + + +def init_sample_environments( + samples: list[dict], + tools: list[ToolAttribute], + env_configs: dict[str, ToolEnvironmentAttribute], + formatter: Any, + inference_engine: Any, + inference_config: Any, +) -> list[dict[str, GeneratedToolEnvironment] | None]: + """Create per-sample environments, reusing builds when config is identical.""" + raise NotImplementedError + + +def process_env_tool_calls( + env_items: list[ + tuple[int, str, dict, str, GeneratedToolEnvironment, ToolAttribute] + ], + env_result_prompts: list[Conversation], + turn_tool_msgs: dict[int, list[Message]], + output_messages: list[list[dict]], + inference_engine: Any, + inference_config: Any, + record_fn: Callable[..., None], +) -> list[int]: + """Execute batched env tool calls.""" + raise NotImplementedError diff --git a/src/oumi/core/synthesis/synthesis_pipeline.py b/src/oumi/core/synthesis/synthesis_pipeline.py index 7f68969393..098df6509b 100644 --- a/src/oumi/core/synthesis/synthesis_pipeline.py +++ b/src/oumi/core/synthesis/synthesis_pipeline.py @@ -89,6 +89,22 @@ def synthesize(self) -> list[dict[str, Any]]: continue sample.update(result) + required_multiturn_ids = { + attr.id for attr in self._config.strategy_params.multiturn_attributes + } + before_count = len(dataset) + dataset = [ + sample + for sample in dataset + if all(attr_id in sample for attr_id in required_multiturn_ids) + ] + dropped_count = before_count - len(dataset) + if dropped_count > 0: + logger.warning( + f"Dropped {dropped_count} sample(s) missing required " + f"multiturn outputs: {sorted(required_multiturn_ids)}" + ) + # Add the transformed attributes to the dataset logger.info("Adding transformed attributes") if self._config.strategy_params.transformed_attributes: diff --git a/src/oumi/core/synthesis/tool_executor.py b/src/oumi/core/synthesis/tool_executor.py new file mode 100644 index 0000000000..d9832012b9 
--- /dev/null +++ b/src/oumi/core/synthesis/tool_executor.py @@ -0,0 +1,185 @@ +# Copyright 2025 - Oumi +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tool executor for agentic synthesis.""" + +from dataclasses import dataclass +from typing import Any + +from oumi.core.configs.params.tool_params import ToolAttribute +from oumi.core.types.conversation import Conversation, Message + + +def clean_json_output(text: str) -> str: + """Strip markdown fences and extract clean JSON from LLM output.""" + raise NotImplementedError + + +def is_valid_json(text: str) -> bool: + """Return True if text is parseable as a JSON object or array.""" + raise NotImplementedError + + +def build_example_result(tool: ToolAttribute) -> str: + """Build a realistic example result from a tool's output_schema.""" + raise NotImplementedError + + +def _example_value(schema: dict[str, Any]) -> Any: + """Generate a placeholder example value from a JSON Schema property.""" + raise NotImplementedError + + +@dataclass +class ToolCallParsed: + """Successfully parsed and validated tool call.""" + + tool_call: dict[str, Any] + + +@dataclass +class ToolCallError: + """Structured error from parsing or validation.""" + + error_json: str + tool_name: str | None + + +ToolCallResult = ToolCallParsed | ToolCallError | None + + +class ToolExecutor: + """Parses tool calls from LLM responses and resolves tool outputs.""" + + def __init__(self, tools: list[ToolAttribute]): + """Initialize the tool executor with 
available tools.""" + raise NotImplementedError + + def get_tool_by_name(self, name: str) -> ToolAttribute | None: + """Look up a tool by its display name.""" + raise NotImplementedError + + def parse_and_validate_tool_call(self, response: str) -> ToolCallResult: + """Parse tags from response and validate arguments.""" + raise NotImplementedError + + def sample_deterministic_outputs( + self, tools: list[ToolAttribute] + ) -> dict[str, str]: + """Sample one deterministic output per DETERMINISTIC tool.""" + raise NotImplementedError + + def resolve_output( + self, + tool_call: dict[str, Any], + deterministic_selections: dict[str, str], + ) -> str | None: + """Resolve a tool call to its output. None for GENERATED tools.""" + raise NotImplementedError + + def build_generated_simulator_prompt( + self, + tool_call: dict[str, Any], + conversation_history: list[Message] | None = None, + ) -> Conversation: + """Build LLM prompt for simulating a GENERATED tool's output.""" + raise NotImplementedError + + @staticmethod + def build_capability_summary(tools: list[ToolAttribute]) -> str: + """Build a planner-facing capability summary.""" + raise NotImplementedError + + @staticmethod + def build_tool_catalog(tools: list[ToolAttribute]) -> str: + """Build a formatted tool catalog with schemas and usage examples.""" + raise NotImplementedError + + @staticmethod + def build_tool_definitions( + tools: list[ToolAttribute], + ) -> list[dict[str, Any]]: + """Convert ToolAttributes to standard tool definitions for output.""" + raise NotImplementedError + + @staticmethod + def format_tool_call_message( + tool_call: dict[str, Any], + call_id: str, + ) -> dict[str, Any]: + """Format a parsed tool call as a standard OpenAI assistant message.""" + raise NotImplementedError + + @staticmethod + def format_tool_result_message( + call_id: str, + content: str, + name: str, + ) -> dict[str, Any]: + """Format a tool result as a standard OpenAI tool message.""" + raise NotImplementedError + + 
@staticmethod + def strip_tool_tags(text: str) -> str: + """Remove any residual or tags.""" + raise NotImplementedError + + @staticmethod + def strip_bare_tool_json(text: str) -> str: + """Remove bare JSON objects that look like tool calls.""" + raise NotImplementedError + + @staticmethod + def sanitize_assistant_content(text: str) -> str: + """Strip tool-call artifacts from assistant prose before export.""" + raise NotImplementedError + + @staticmethod + def build_tool_few_shot(tools: list[ToolAttribute]) -> list[Message]: + """Build few-shot messages demonstrating a correct tool-call exchange.""" + raise NotImplementedError + + @staticmethod + def build_tool_turn_info( + current_turn: int, + target_turns: int, + turn_instruction: str, + max_calls_reached: bool, + ) -> str: + """Build the turn-level user message for assistant tool turns.""" + raise NotImplementedError + + @staticmethod + def build_prose_turn_info( + current_turn: int, + target_turns: int, + role: str, + turn_instruction: str, + ) -> str: + """Build turn-level user message for non-tool turns.""" + raise NotImplementedError + + @staticmethod + def record_tool_result( + idx: int, + raw_text: str, + tool_call: dict, + call_id: str, + result: str, + turn_tool_msgs: dict[int, list[Message]], + output_messages: list[list[dict]], + env_state: dict | None = None, + ) -> None: + """Append a tool call + result to conversation history and output.""" + raise NotImplementedError diff --git a/src/oumi/utils/json_patch.py b/src/oumi/utils/json_patch.py new file mode 100644 index 0000000000..7aa7641fa1 --- /dev/null +++ b/src/oumi/utils/json_patch.py @@ -0,0 +1,39 @@ +# Copyright 2025 - Oumi +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""RFC 6902 JSON Patch utilities.""" + +from typing import Any + + +class JsonPatchError(Exception): + """Raised when a JSON Patch is malformed or cannot be applied.""" + + +class JsonPatchValidationError(Exception): + """Raised when the patched document fails JSON Schema validation.""" + + +def apply_json_patch( + document: dict[str, Any], + patch: list[dict[str, Any]], + schema: dict[str, Any] | None = None, +) -> dict[str, Any]: + """Apply an RFC 6902 JSON Patch to a document.""" + raise NotImplementedError + + +def parse_patch_response(text: str) -> list[dict[str, Any]] | None: + """Extract a JSON Patch array from LLM-generated text.""" + raise NotImplementedError diff --git a/tests/unit/core/configs/params/test_tool_params.py b/tests/unit/core/configs/params/test_tool_params.py new file mode 100644 index 0000000000..8e139378e0 --- /dev/null +++ b/tests/unit/core/configs/params/test_tool_params.py @@ -0,0 +1,337 @@ +# Copyright 2025 - Oumi +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Any + +import pytest + +from oumi.core.configs.params.synthesis_params import ( + GeneralSynthesisParams, + MultiTurnAttribute, +) +from oumi.core.configs.params.tool_params import ( + DeterministicToolOutput, + GeneratedToolOutput, + ToolAttribute, + ToolEnvironmentAttribute, + ToolOutputStrategy, +) +from oumi.core.types.conversation import Role + + +def test_deterministic_tool_output_empty_values_raises(): + with pytest.raises(ValueError, match="values cannot be empty"): + DeterministicToolOutput(values={}) + + +@pytest.mark.parametrize("rate", [-0.1, 1.1]) +def test_deterministic_tool_output_invalid_sample_rate_raises(rate): + with pytest.raises(ValueError, match="sample_rate must be between 0 and 1"): + DeterministicToolOutput(values={"x": 1}, sample_rate=rate) + + +def _make_deterministic_tool(**overrides) -> ToolAttribute: + defaults = dict( + id="tool1", + name="MyTool", + description="A tool", + output_strategy=ToolOutputStrategy.DETERMINISTIC, + deterministic_outputs=[ + DeterministicToolOutput(values={"a": 1}), + ], + ) + defaults.update(overrides) + return ToolAttribute(**defaults) # type: ignore[arg-type] + + +def _make_generated_tool(**overrides) -> ToolAttribute: + defaults = dict( + id="tool2", + name="GenTool", + description="A generated tool", + output_strategy=ToolOutputStrategy.GENERATED, + generated_output=GeneratedToolOutput(instruction="Do something."), + ) + defaults.update(overrides) + return ToolAttribute(**defaults) # type: ignore[arg-type] + + +def test_tool_attribute_deterministic_without_outputs_raises(): + with pytest.raises(ValueError, match="deterministic_outputs cannot be empty"): + ToolAttribute( + id="t", + name="T", + description="d", + output_strategy=ToolOutputStrategy.DETERMINISTIC, + deterministic_outputs=[], + ) + + +def test_tool_attribute_generated_without_output_raises(): + with pytest.raises(ValueError, match="generated_output must be provided"): + ToolAttribute( + id="t", + name="T", + 
description="d", + output_strategy=ToolOutputStrategy.GENERATED, + generated_output=None, + ) + + +@pytest.mark.parametrize( + "field,value", + [("id", ""), ("name", ""), ("description", "")], +) +def test_tool_attribute_empty_field_raises(field, value): + with pytest.raises(ValueError, match=f"{field} cannot be empty"): + _make_generated_tool(**{field: value}) + + +def test_tool_attribute_normalizes_undefined_sample_rates(): + outputs = [ + DeterministicToolOutput(values={"a": 1}), + DeterministicToolOutput(values={"b": 2}), + ] + tool = _make_deterministic_tool(deterministic_outputs=outputs) + assert tool.deterministic_outputs[0].sample_rate == pytest.approx(0.5) + assert tool.deterministic_outputs[1].sample_rate == pytest.approx(0.5) + + +def test_tool_attribute_normalizes_mixed_sample_rates(): + outputs = [ + DeterministicToolOutput(values={"a": 1}, sample_rate=0.7), + DeterministicToolOutput(values={"b": 2}), + ] + tool = _make_deterministic_tool(deterministic_outputs=outputs) + assert tool.deterministic_outputs[0].sample_rate == pytest.approx(0.7) + assert tool.deterministic_outputs[1].sample_rate == pytest.approx(0.3) + + +def test_tool_attribute_sample_rates_exceeding_one_raises(): + outputs = [ + DeterministicToolOutput(values={"a": 1}, sample_rate=0.6), + DeterministicToolOutput(values={"b": 2}, sample_rate=0.6), + ] + with pytest.raises(ValueError, match="sample rates must sum to at most 1.0"): + _make_deterministic_tool(deterministic_outputs=outputs) + + +def _make_multiturn_attr(**overrides) -> MultiTurnAttribute: + defaults = dict( + id="chat", + min_turns=1, + max_turns=3, + role_instruction_messages={ + Role.USER: "You are a user.", + Role.ASSISTANT: "You are an assistant.", + }, + available_tools=[], + ) + defaults.update(overrides) + return MultiTurnAttribute(**defaults) # type: ignore[arg-type] + + +def test_synthesis_params_valid_tool_references(): + tool = _make_generated_tool(id="search") + mt = _make_multiturn_attr(available_tools=["search"]) 
+ params = GeneralSynthesisParams( + tools=[tool], + multiturn_attributes=[mt], + ) + assert params.tools is not None + assert len(params.tools) == 1 + + +def test_synthesis_params_undefined_tool_reference_raises(): + tool = _make_generated_tool(id="search") + mt = _make_multiturn_attr(available_tools=["nonexistent"]) + with pytest.raises(ValueError, match="references unknown tool 'nonexistent'"): + GeneralSynthesisParams( + tools=[tool], + multiturn_attributes=[mt], + ) + + +def test_synthesis_params_available_tools_without_tools_defined_raises(): + mt = _make_multiturn_attr(available_tools=["search"]) + with pytest.raises(ValueError, match="tools must be defined"): + GeneralSynthesisParams( + tools=None, + multiturn_attributes=[mt], + ) + + +def test_synthesis_params_duplicate_tool_ids_raises(): + t1 = _make_generated_tool(id="dup") + t2 = _make_generated_tool(id="dup") + mt = _make_multiturn_attr(available_tools=["dup"]) + with pytest.raises(ValueError, match="duplicate tool ids"): + GeneralSynthesisParams( + tools=[t1, t2], + multiturn_attributes=[mt], + ) + + +# --- ToolOutputStrategy.ENVIRONMENT --- + + +def test_tool_output_strategy_environment_exists(): + assert ToolOutputStrategy.ENVIRONMENT == "environment" + + +def _make_environment_tool(**overrides: Any) -> ToolAttribute: + defaults: dict[str, Any] = dict( + id="tool_env", + name="EnvTool", + description="An environment tool", + output_strategy=ToolOutputStrategy.ENVIRONMENT, + environment="my_env", + read_only=True, + ) + defaults.update(overrides) + return ToolAttribute(**defaults) + + +def test_tool_attribute_environment_valid(): + tool = _make_environment_tool() + assert tool.environment == "my_env" + assert tool.read_only is True + assert tool.output_strategy == ToolOutputStrategy.ENVIRONMENT + + +def test_tool_attribute_environment_read_only_false(): + tool = _make_environment_tool(read_only=False) + assert tool.read_only is False + + +def 
test_tool_attribute_environment_strategy_without_env_raises(): + """ENVIRONMENT strategy requires environment field.""" + with pytest.raises(ValueError, match="environment must be set"): + ToolAttribute( + id="t", + name="T", + description="d", + output_strategy=ToolOutputStrategy.ENVIRONMENT, + ) + + +def test_tool_attribute_env_set_without_environment_strategy_raises(): + """Setting environment requires ENVIRONMENT strategy.""" + with pytest.raises(ValueError, match="output_strategy must be ENVIRONMENT"): + ToolAttribute( + id="t", + name="T", + description="d", + output_strategy=ToolOutputStrategy.GENERATED, + environment="some_env", + generated_output=GeneratedToolOutput(instruction="x"), + ) + + +def test_tool_attribute_environment_ignores_generated_output(): + """ENVIRONMENT tools don't need generated_output or deterministic_outputs.""" + tool = _make_environment_tool() + assert tool.generated_output is None + assert tool.deterministic_outputs == [] + + +# --- ToolEnvironmentAttribute --- + + +def test_environment_attribute_valid(): + env = ToolEnvironmentAttribute( + id="filesystem", + name="Filesystem", + description="A simple filesystem", + system_prompt="You manage a filesystem.", + ) + assert env.id == "filesystem" + assert env.state_schema is None + assert env.initial_state is None + + +def test_environment_attribute_with_schema_and_state(): + schema = { + "type": "object", + "properties": {"files": {"type": "object"}}, + "required": ["files"], + } + state = {"files": {}} + env = ToolEnvironmentAttribute( + id="fs", + name="FS", + description="d", + system_prompt="p", + state_schema=schema, + initial_state=state, + ) + assert env.state_schema == schema + assert env.initial_state == state + + +def test_environment_attribute_empty_id_raises(): + with pytest.raises(ValueError, match="id cannot be empty"): + ToolEnvironmentAttribute(id="", name="n", description="d", system_prompt="p") + + +def test_environment_attribute_empty_name_raises(): + with 
pytest.raises(ValueError, match="name cannot be empty"): + ToolEnvironmentAttribute(id="x", name="", description="d", system_prompt="p") + + +def test_environment_attribute_empty_description_raises(): + with pytest.raises(ValueError, match="description cannot be empty"): + ToolEnvironmentAttribute(id="x", name="n", description="", system_prompt="p") + + +def test_environment_attribute_empty_system_prompt_raises(): + with pytest.raises(ValueError, match="system_prompt cannot be empty"): + ToolEnvironmentAttribute(id="x", name="n", description="d", system_prompt="") + + +# --- GeneralSynthesisParams with environments --- + + +def test_general_synthesis_params_with_environments(): + env = ToolEnvironmentAttribute( + id="fs", name="FS", description="d", system_prompt="p" + ) + tool = _make_environment_tool(environment="fs") + params = GeneralSynthesisParams( + environments=[env], + tools=[tool], + multiturn_attributes=[_make_multiturn_attr(available_tools=["tool_env"])], + ) + assert params.environments is not None + assert len(params.environments) == 1 + + +def test_general_synthesis_params_tool_references_unknown_env_raises(): + tool = _make_environment_tool(environment="nonexistent") + with pytest.raises(ValueError, match="references unknown environment"): + GeneralSynthesisParams( + tools=[tool], + multiturn_attributes=[_make_multiturn_attr(available_tools=["tool_env"])], + ) + + +def test_general_synthesis_params_duplicate_env_ids_raises(): + env1 = ToolEnvironmentAttribute( + id="fs", name="FS1", description="d1", system_prompt="p1" + ) + env2 = ToolEnvironmentAttribute( + id="fs", name="FS2", description="d2", system_prompt="p2" + ) + with pytest.raises(ValueError, match="duplicate environment"): + GeneralSynthesisParams(environments=[env1, env2]) diff --git a/tests/unit/core/synthesis/test_synthesis_pipeline.py b/tests/unit/core/synthesis/test_synthesis_pipeline.py index 38e560beb0..f65862381a 100644 --- a/tests/unit/core/synthesis/test_synthesis_pipeline.py 
+++ b/tests/unit/core/synthesis/test_synthesis_pipeline.py @@ -324,7 +324,7 @@ def test_synthesize_with_multiturn_attributes( @patch("oumi.core.synthesis.synthesis_pipeline.DatasetPlanner") @patch("oumi.core.synthesis.synthesis_pipeline.AttributeTransformer") @patch("oumi.core.synthesis.synthesis_pipeline.AttributeSynthesizer") -def test_synthesize_with_multiturn_filtered_result_keeps_samples( +def test_synthesize_with_multiturn_filtered_result_drops_failed_samples( mock_attr_synth, mock_attr_transformer_class, mock_dataset_planner_class, @@ -333,7 +333,7 @@ def test_synthesize_with_multiturn_filtered_result_keeps_samples( mock_dataset_planner, mock_attribute_transformer, ): - """Test that filtered conversation results do not crash or drop samples.""" + """Test that failed multiturn rows are removed from final dataset.""" sample_dataset = [ {"id": "s1", "base": "v1"}, {"id": "s2", "base": "v2"}, @@ -360,13 +360,11 @@ def test_synthesize_with_multiturn_filtered_result_keeps_samples( pipeline = SynthesisPipeline(synthesis_config_with_multiturn_attributes) result = pipeline.synthesize() - assert len(result) == 2 + assert len(result) == 1 assert multiturn_attr.id in result[0] assert plan_key in result[0] - assert multiturn_attr.id not in result[1] - assert plan_key not in result[1] - assert result[1]["id"] == "s2" - assert result[1]["base"] == "v2" + assert result[0]["id"] == "s1" + assert result[0]["base"] == "v1" @patch("oumi.core.synthesis.synthesis_pipeline.DatasetPlanner") From 2df6f03cd6b27beb896eee19766f72babbe51280 Mon Sep 17 00:00:00 2001 From: aniruddh-alt Date: Wed, 1 Apr 2026 20:35:07 -0700 Subject: [PATCH 02/18] fix: add missing docstring and apply ruff format --- src/oumi/core/synthesis/environment.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/oumi/core/synthesis/environment.py b/src/oumi/core/synthesis/environment.py index d243da5a5f..436a762f1a 100644 --- a/src/oumi/core/synthesis/environment.py +++ 
b/src/oumi/core/synthesis/environment.py @@ -106,6 +106,7 @@ class EnvironmentRegistry: """Builds environments once, then copies N times for parallel samples.""" def __init__(self): + """Initialize an empty registry.""" raise NotImplementedError def register_static(self, config: ToolEnvironmentAttribute) -> None: @@ -123,9 +124,7 @@ def build( """Build a fully populated environment.""" raise NotImplementedError - def create_copies( - self, env_id: str, n: int - ) -> list[GeneratedToolEnvironment]: + def create_copies(self, env_id: str, n: int) -> list[GeneratedToolEnvironment]: """Return n independent deepcopies of a built environment.""" raise NotImplementedError From 6505cbc586351801bf11a81cf6bbd2e4256d2711 Mon Sep 17 00:00:00 2001 From: Aniruddhan Ramesh <77236089+aniruddh-alt@users.noreply.github.com> Date: Wed, 1 Apr 2026 20:58:12 -0700 Subject: [PATCH 03/18] Potential fix for pull request finding 'Unused global variable' Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com> --- src/oumi/core/synthesis/environment.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/oumi/core/synthesis/environment.py b/src/oumi/core/synthesis/environment.py index 436a762f1a..fa92720745 100644 --- a/src/oumi/core/synthesis/environment.py +++ b/src/oumi/core/synthesis/environment.py @@ -23,7 +23,6 @@ ) from oumi.core.types.conversation import Conversation, Message -_MAX_STATE_UPDATE_RETRIES = 2 _MAX_RESULT_RETRIES = 2 From a628ea7606d7ea98252582736c5a786f3afe6f7a Mon Sep 17 00:00:00 2001 From: aniruddh-alt Date: Wed, 1 Apr 2026 21:11:31 -0700 Subject: [PATCH 04/18] fix: suppress pre-existing ASYNC240/ASYNC230 ruff errors in MCP files --- src/oumi/mcp/job_launcher.py | 4 ++-- src/oumi/mcp/job_logs.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/oumi/mcp/job_launcher.py b/src/oumi/mcp/job_launcher.py index 10f01f156f..852a1f1102 100644 --- a/src/oumi/mcp/job_launcher.py +++ 
b/src/oumi/mcp/job_launcher.py @@ -186,13 +186,13 @@ async def _launch_cloud( await evict_runtime(record.job_id) return "" - config_parent = str(Path(record.config_path).expanduser().resolve().parent) + config_parent = str(Path(record.config_path).expanduser().resolve().parent) # noqa: ASYNC240 _stage_cloud_config(record, rt, working_dir=config_parent) job_config = launcher.JobConfig.from_yaml(rt.staged_config_path) if not job_config.name: job_config.name = record.job_id if client_cwd and job_config.working_dir: - wd = Path(job_config.working_dir).expanduser() + wd = Path(job_config.working_dir).expanduser() # noqa: ASYNC240 if not wd.is_absolute(): job_config.working_dir = str((Path(client_cwd) / wd).resolve()) elif client_cwd and not job_config.working_dir: diff --git a/src/oumi/mcp/job_logs.py b/src/oumi/mcp/job_logs.py index 03df5e9d63..e336ecea36 100644 --- a/src/oumi/mcp/job_logs.py +++ b/src/oumi/mcp/job_logs.py @@ -74,7 +74,7 @@ async def tail_log_file( If the file does not exist yet, waits up to ``poll_interval`` between checks until it appears or *done_event* fires. 
""" - while not path.exists(): + while not path.exists(): # noqa: ASYNC240 if done_event.is_set(): return await asyncio.sleep(poll_interval) @@ -84,13 +84,13 @@ async def tail_log_file( while True: try: - size = path.stat().st_size + size = path.stat().st_size # noqa: ASYNC240 except OSError: size = 0 if size > position: try: - with open(path, encoding="utf-8", errors="replace") as f: + with open(path, encoding="utf-8", errors="replace") as f: # noqa: ASYNC230 f.seek(position) chunk = f.read() position = f.tell() @@ -105,7 +105,7 @@ async def tail_log_file( if done_event.is_set(): try: - with open(path, encoding="utf-8", errors="replace") as f: + with open(path, encoding="utf-8", errors="replace") as f: # noqa: ASYNC230 f.seek(position) remaining = f.read() except OSError: From 1b7060bf7c113c32b19c4f3d9f099ea60df4af0a Mon Sep 17 00:00:00 2001 From: aniruddh-alt Date: Tue, 7 Apr 2026 15:00:22 -0700 Subject: [PATCH 05/18] Update tool_executor.py --- src/oumi/core/synthesis/tool_executor.py | 25 ------------------------ 1 file changed, 25 deletions(-) diff --git a/src/oumi/core/synthesis/tool_executor.py b/src/oumi/core/synthesis/tool_executor.py index d9832012b9..6b6698c6fe 100644 --- a/src/oumi/core/synthesis/tool_executor.py +++ b/src/oumi/core/synthesis/tool_executor.py @@ -21,26 +21,6 @@ from oumi.core.types.conversation import Conversation, Message -def clean_json_output(text: str) -> str: - """Strip markdown fences and extract clean JSON from LLM output.""" - raise NotImplementedError - - -def is_valid_json(text: str) -> bool: - """Return True if text is parseable as a JSON object or array.""" - raise NotImplementedError - - -def build_example_result(tool: ToolAttribute) -> str: - """Build a realistic example result from a tool's output_schema.""" - raise NotImplementedError - - -def _example_value(schema: dict[str, Any]) -> Any: - """Generate a placeholder example value from a JSON Schema property.""" - raise NotImplementedError - - @dataclass class 
ToolCallParsed: """Successfully parsed and validated tool call.""" @@ -135,11 +115,6 @@ def strip_tool_tags(text: str) -> str: """Remove any residual or tags.""" raise NotImplementedError - @staticmethod - def strip_bare_tool_json(text: str) -> str: - """Remove bare JSON objects that look like tool calls.""" - raise NotImplementedError - @staticmethod def sanitize_assistant_content(text: str) -> str: """Strip tool-call artifacts from assistant prose before export.""" From 088cd5a4266a4dbb690dc6937e741be4209d9c67 Mon Sep 17 00:00:00 2001 From: aniruddh-alt Date: Wed, 8 Apr 2026 21:36:10 -0700 Subject: [PATCH 06/18] refactor: move environments to top-level package with typed tool hierarchy Restructure the agentic environment skeleton to follow an environment-first ownership model: - Create `oumi/environments/` as a first-class top-level package (alongside `oumi/inference/`) for use by synthesis, evaluation, and RL - Each environment file co-locates its typed tool subclass: DeterministicEnvironment + DeterministicTool, StatefulEnvironment + StatefulTool, StatelessEnvironment + StatelessTool - Replace the monolithic ToolAttribute with BaseTool + typed subclasses that carry only fields relevant to their environment type - Remove ToolExecutor from environments (synthesis-specific concern) - Environments own tool resolution: DeterministicEnvironment.resolve(), StatelessEnvironment.resolve_cached()/cache_result() - BaseEnvironment.create() now requires 'type' explicitly (no silent default to stateful) - EnvironmentConfig.resolve_tools() raises on unknown ids instead of silently returning empty lists - Config layer no longer imports from synthesis runtime --- docs/index.md | 1 + docs/user_guides/synth.md | 70 +++ src/oumi/core/configs/__init__.py | 20 + src/oumi/core/configs/environment_config.py | 132 ++++++ .../core/configs/params/synthesis_params.py | 138 +----- src/oumi/core/configs/params/tool_params.py | 275 ----------- src/oumi/core/configs/synthesis_config.py | 126 
++++- .../synthesis/conversation_synthesizer.py | 22 + src/oumi/core/synthesis/environment.py | 174 ------- src/oumi/core/synthesis/synthesis_pipeline.py | 6 +- src/oumi/core/synthesis/tool_executor.py | 160 ------- src/oumi/environments/__init__.py | 58 +++ src/oumi/environments/base_environment.py | 126 +++++ src/oumi/environments/base_tool.py | 64 +++ .../environments/deterministic_environment.py | 152 +++++++ src/oumi/environments/stateful_environment.py | 85 ++++ .../environments/stateless_environment.py | 141 ++++++ src/oumi/environments/types.py | 25 + .../core/configs/params/test_tool_params.py | 430 +++++++++--------- .../core/configs/test_synthesis_config.py | 264 ++++++++++- 20 files changed, 1526 insertions(+), 943 deletions(-) create mode 100644 src/oumi/core/configs/environment_config.py delete mode 100644 src/oumi/core/configs/params/tool_params.py delete mode 100644 src/oumi/core/synthesis/environment.py delete mode 100644 src/oumi/core/synthesis/tool_executor.py create mode 100644 src/oumi/environments/__init__.py create mode 100644 src/oumi/environments/base_environment.py create mode 100644 src/oumi/environments/base_tool.py create mode 100644 src/oumi/environments/deterministic_environment.py create mode 100644 src/oumi/environments/stateful_environment.py create mode 100644 src/oumi/environments/stateless_environment.py create mode 100644 src/oumi/environments/types.py diff --git a/docs/index.md b/docs/index.md index e2d64a7b41..fa0ae12736 100644 --- a/docs/index.md +++ b/docs/index.md @@ -106,6 +106,7 @@ faq/oom development/dev_setup development/contributing +development/agentic_synthesis_environments development/code_of_conduct development/style_guide development/docs_guide diff --git a/docs/user_guides/synth.md b/docs/user_guides/synth.md index 97437fe20f..71af6658ab 100644 --- a/docs/user_guides/synth.md +++ b/docs/user_guides/synth.md @@ -164,6 +164,76 @@ Ready to dive deeper? The sections below cover all available options in detail. 
--- +## Environment-First Tool Synthesis + +Agentic synthesis now follows an environment-first model. Tools do not declare an output strategy directly. Instead, each tool is bound to an environment, and the environment type defines the execution model. + +- **`stateful` environments** maintain shared JSON state. Tool calls read from or update that state, which is how consistency is preserved across turns. +- **`stateless` environments** generate tool results with an LLM. Responses are cached by input, so the same tool input can reuse the same generated output. +- **`deterministic` environments** behave like lookup tables. Matching inputs return responses from a predefined set without LLM generation. + +At the config level: + +- Environments own their tool definitions. +- Reusable environment catalogs live in top-level `environment_config` or `environment_config_path`. +- Tools do not declare an `environment` field. The parent environment owns the binding. +- `generated_output` is only used for tools in `stateless` environments. +- `deterministic_outputs` is only used for tools in `deterministic` environments. +- `read_only` is only meaningful for tools in `stateful` environments. + +Example: + +```yaml +environment_config: + environments: + - id: support_backend + name: Support Backend + description: Simulated support system state + type: stateful + system_prompt: You manage support system state. + tools: + - id: get_ticket + name: GetTicket + description: Read a ticket from the support backend. + read_only: true + + - id: faq_lookup + name: FAQ Lookup + description: Cached LLM-backed FAQ answers + type: stateless + system_prompt: Generate concise FAQ answers grounded in the tool contract. + tools: + - id: answer_faq + name: AnswerFAQ + description: Answer common support questions. + generated_output: + instruction: Return the FAQ answer for the given question. 
+ + - id: policy_table + name: Policy Table + description: Predefined policy responses + type: deterministic + tools: + - id: get_refund_policy + name: GetRefundPolicy + description: Return the matching refund policy. + deterministic_outputs: + - input: + policy_type: standard + output: + policy: Standard 30-day refund policy + +strategy_params: + multiturn_attributes: + - id: support_chat + min_turns: 2 + max_turns: 4 + role_instruction_messages: + USER: You are a customer contacting support. + ASSISTANT: You are a helpful support agent. + available_tools: [get_ticket, answer_faq, get_refund_policy] +``` + ## Complete Configuration Reference ### Top-Level Parameters diff --git a/src/oumi/core/configs/__init__.py b/src/oumi/core/configs/__init__.py index 8ed6769f5d..bec87c8a5b 100644 --- a/src/oumi/core/configs/__init__.py +++ b/src/oumi/core/configs/__init__.py @@ -83,6 +83,7 @@ ) from oumi.core.configs.async_evaluation_config import AsyncEvaluationConfig from oumi.core.configs.base_config import BaseConfig +from oumi.core.configs.environment_config import EnvironmentConfig from oumi.core.configs.evaluation_config import EvaluationConfig from oumi.core.configs.inference_config import InferenceConfig from oumi.core.configs.inference_engine_type import InferenceEngineType @@ -158,12 +159,23 @@ from oumi.core.configs.synthesis_config import SynthesisConfig from oumi.core.configs.training_config import TrainingConfig from oumi.core.configs.tuning_config import TuningConfig +from oumi.environments import ( + BaseEnvironment, + BaseTool, + DeterministicEnvironment, + DeterministicToolOutput, + GeneratedToolOutput, + StatefulEnvironment, + StatelessEnvironment, + ToolEnvironmentType, +) __all__ = [ "AsyncEvaluationConfig", "AutoWrapPolicy", "BackwardPrefetch", "BaseConfig", + "BaseEnvironment", "DataParams", "DatasetParams", "DatasetSplit", @@ -176,6 +188,7 @@ "EvaluationBackend", "EvaluationConfig", "EvaluationTaskParams", + "EnvironmentConfig", "FSDPParams", 
"GenerationParams", "GrpoParams", @@ -211,17 +224,24 @@ "TuningParams", "AttributeCombination", "DatasetSourceParam", + "DeterministicToolOutput", + "DeterministicEnvironment", "DocumentSegmentationParams", "DocumentSource", "ExampleSource", + "GeneratedToolOutput", "GeneratedAttributePostprocessingParams", "GeneralSynthesisParams", "GeneratedAttribute", "SampledAttribute", "SampledAttributeValue", "SegmentationStrategy", + "StatefulEnvironment", + "StatelessEnvironment", "TextConversation", "TextMessage", + "BaseTool", + "ToolEnvironmentType", "TransformationStrategy", "TransformationType", "TransformedAttribute", diff --git a/src/oumi/core/configs/environment_config.py b/src/oumi/core/configs/environment_config.py new file mode 100644 index 0000000000..e9c6712379 --- /dev/null +++ b/src/oumi/core/configs/environment_config.py @@ -0,0 +1,132 @@ +# Copyright 2025 - Oumi +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Configuration for agentic environments.""" + +from dataclasses import dataclass, field +from typing import Any + +from oumi.core.configs.base_config import BaseConfig +from oumi.environments import BaseEnvironment, BaseTool + + +@dataclass +class EnvironmentConfig(BaseConfig): + """Top-level config for environment-first tool definitions.""" + + environments: list[Any] = field(default_factory=list) + """Reusable environments and their owned tools.""" + + def __post_init__(self): + """Verifies/populates params.""" + self.environments = [ + self._coerce_environment(environment) for environment in self.environments + ] + + env_ids: set[str] = set() + tool_ids: set[str] = set() + + for environment in self.environments: + if environment.id in env_ids: + raise ValueError( + f"EnvironmentConfig.environments contains duplicate " + f"environment id '{environment.id}'." + ) + env_ids.add(environment.id) + + for tool in environment.tools: + if tool.id in tool_ids: + raise ValueError( + f"EnvironmentConfig.environments contains duplicate " + f"tool id '{tool.id}'." 
+ ) + tool_ids.add(tool.id) + + @property + def all_tools(self) -> list[BaseTool]: + """Flatten all tools across environments.""" + return [tool for environment in self.environments for tool in environment.tools] + + @property + def tool_environment_map(self) -> dict[str, str]: + """Map each tool id to the environment that owns it.""" + return { + tool.id: environment.id + for environment in self.environments + for tool in environment.tools + } + + def get_environment(self, environment_id: str) -> BaseEnvironment | None: + """Look up an environment by id.""" + for environment in self.environments: + if environment.id == environment_id: + return environment + return None + + def get_tool(self, tool_id: str) -> BaseTool | None: + """Look up a tool by id.""" + for tool in self.all_tools: + if tool.id == tool_id: + return tool + return None + + def resolve_tools( + self, + environment_ids: list[str] | None = None, + tool_ids: list[str] | None = None, + ) -> list[BaseTool]: + """Resolve tools from selected environments and optional tool ids. + + Raises: + ValueError: If any environment_id or tool_id is not found. + """ + all_env_ids = {env.id for env in self.environments} + + if environment_ids: + unknown_envs = set(environment_ids) - all_env_ids + if unknown_envs: + raise ValueError( + f"Unknown environment id(s): {sorted(unknown_envs)}. " + f"Defined: {sorted(all_env_ids)}" + ) + selected_environment_ids = environment_ids + else: + selected_environment_ids = list(all_env_ids) + + selected_environments = [ + environment + for environment in self.environments + if environment.id in set(selected_environment_ids) + ] + tools = [ + tool for environment in selected_environments for tool in environment.tools + ] + + if tool_ids: + available_tool_ids = {tool.id for tool in tools} + unknown_tools = set(tool_ids) - available_tool_ids + if unknown_tools: + raise ValueError( + f"Unknown tool id(s): {sorted(unknown_tools)}. 
" + f"Available in selected environments: " + f"{sorted(available_tool_ids)}" + ) + allowed_tool_ids = set(tool_ids) + tools = [tool for tool in tools if tool.id in allowed_tool_ids] + + return tools + + def _coerce_environment(self, environment: Any) -> BaseEnvironment: + """Coerce a raw dict or environment instance into a concrete environment.""" + return BaseEnvironment.create(environment) diff --git a/src/oumi/core/configs/params/synthesis_params.py b/src/oumi/core/configs/params/synthesis_params.py index d15bcbac97..6bb8eecbbe 100644 --- a/src/oumi/core/configs/params/synthesis_params.py +++ b/src/oumi/core/configs/params/synthesis_params.py @@ -22,10 +22,6 @@ from typing import Any from oumi.core.configs.params.base_params import BaseParams -from oumi.core.configs.params.tool_params import ( - ToolAttribute, - ToolEnvironmentAttribute, -) from oumi.core.types.conversation import Conversation, Message, Role _SUPPORTED_DATASET_FILE_TYPES = {".jsonl", ".json", ".csv", ".parquet", ".tsv", ".xlsx"} @@ -478,9 +474,11 @@ class MultiTurnAttribute: Allows user to specify custom instructions for the planner while planning out the conversation.""" + available_environments: list[str] = field(default_factory=list) + """List of environment ids available in this conversation.""" + available_tools: list[str] = field(default_factory=list) - """List of tool ids (from GeneralSynthesisParams.tools) available in this - conversation.""" + """List of tool ids available in this conversation.""" max_tool_calls_per_turn: int = 50 """Safety ceiling for tool calls per ASSISTANT turn. The agent naturally stops @@ -565,6 +563,18 @@ def __post_init__(self): raise ValueError( "MultiTurnAttribute.available_tools must be a list of strings." ) + if self.available_environments is not None: + if not isinstance(self.available_environments, list): + raise ValueError( + "MultiTurnAttribute.available_environments must be a list of " + "environment ids." 
+ ) + for environment in self.available_environments: + if not isinstance(environment, str): + raise ValueError( + "MultiTurnAttribute.available_environments must be a list " + "of strings." + ) class TransformationType(str, Enum): @@ -793,45 +803,6 @@ class GeneralSynthesisParams(BaseParams): ] """ - tools: list[ToolAttribute] | None = None - """Tool definitions for agentic synthesis. - - Tools are defined here and referenced by id from - MultiTurnAttribute.available_tools. Each tool specifies its parameters - (input schema), output_schema, and output strategy (DETERMINISTIC or - GENERATED). - - Example:: - - tools = [ - ToolAttribute( - id="search", - name="WebSearch", - description="Searches the web for the given query.", - parameters={ - "type": "object", - "properties": { - "query": { - "type": "string", - "description": "The search query.", - } - }, - "required": ["query"], - }, - output_strategy=ToolOutputStrategy.GENERATED, - generated_output=GeneratedToolOutput( - instruction="Return relevant search results.", - ), - ), - ] - """ - - environments: list[ToolEnvironmentAttribute] | None = None - """Environment definitions for stateful tool synthesis. - - Environments are stateful containers that tools operate on. Tools - reference environments by id via ToolAttribute.environment.""" - transformed_attributes: list[TransformedAttribute] | None = None """Transformation of existing attributes. 
@@ -974,80 +945,6 @@ def _check_passthrough_attribute_ids(self) -> None: self.passthrough_attributes = None return - def _check_environment_ids(self) -> None: - """Validate environment ids are unique.""" - if not self.environments: - self.environments = None - return - - env_ids = [env.id for env in self.environments] - if len(env_ids) != len(set(env_ids)): - seen: set[str] = set() - dupes = [e for e in env_ids if e in seen or seen.add(e)] # type: ignore[func-returns-value] - raise ValueError( - f"GeneralSynthesisParams.environments contains " - f"duplicate environment ids: {dupes}" - ) - - def _check_tool_environment_references(self) -> None: - """Validate that tool.environment references valid environment ids.""" - if not self.tools: - return - - env_ids: set[str] = set() - if self.environments: - env_ids = {env.id for env in self.environments} - - for tool in self.tools: - if tool.environment and tool.environment not in env_ids: - raise ValueError( - f"ToolAttribute '{tool.id}' references unknown " - f"environment '{tool.environment}'. " - f"Defined environment ids: {sorted(env_ids)}" - ) - - def _check_available_tools(self) -> None: - """Validate that available_tools ids reference defined tools.""" - if not self.multiturn_attributes: - self.multiturn_attributes = None - return - - # Collect all tool ids referenced by any multiturn attribute - all_referenced = [ - tool_id - for mt_attr in self.multiturn_attributes - for tool_id in mt_attr.available_tools - ] - - if not all_referenced: - if not self.tools: - self.tools = None - return - - if not self.tools: - raise ValueError( - "GeneralSynthesisParams.tools must be defined when " - "MultiTurnAttribute.available_tools is non-empty. 
" - f"Referenced tool ids: {sorted(set(all_referenced))}" - ) - - tool_id_list = [tool.id for tool in self.tools] - tool_ids = set(tool_id_list) - if len(tool_id_list) != len(tool_ids): - seen: set[str] = set() - dupes = [t for t in tool_id_list if t in seen or seen.add(t)] # type: ignore[func-returns-value] - raise ValueError( - f"GeneralSynthesisParams.tools contains duplicate tool ids: {dupes}" - ) - - for mt_attr in self.multiturn_attributes: - for tool_id in mt_attr.available_tools: - if tool_id not in tool_ids: - raise ValueError( - f"MultiTurnAttribute '{mt_attr.id}' references unknown " - f"tool '{tool_id}'. Defined tool ids: {sorted(tool_ids)}" - ) - def __post_init__(self): """Verifies/populates params.""" self._reserved_attribute_ids = self._get_reserved_attribute_ids() @@ -1061,6 +958,3 @@ def __post_init__(self): self._check_transformed_attribute_ids(all_attribute_ids) self._check_passthrough_attribute_ids() self._check_combination_sampling_sample_rates() - self._check_available_tools() - self._check_environment_ids() - self._check_tool_environment_references() diff --git a/src/oumi/core/configs/params/tool_params.py b/src/oumi/core/configs/params/tool_params.py deleted file mode 100644 index 73b0aae838..0000000000 --- a/src/oumi/core/configs/params/tool_params.py +++ /dev/null @@ -1,275 +0,0 @@ -# Copyright 2025 - Oumi -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Params for tool configuration in agentic synthesis.""" - -import math -from dataclasses import dataclass, field -from enum import Enum -from typing import Any - - -class ToolOutputStrategy(str, Enum): - """Strategy for how a tool produces its output.""" - - DETERMINISTIC = "deterministic" - """Output is selected from user-defined values. One value is selected - per conversation and reused for all calls to this tool within that - conversation.""" - - GENERATED = "generated" - """Output is generated by an LLM simulator conditioned on the tool schema, - arguments, conversation history, and optional environment state.""" - - ENVIRONMENT = "environment" - """Output comes from a stateful environment's step() method.""" - - -@dataclass -class DeterministicToolOutput: - """A possible canned output for a tool. - - Used with ToolOutputStrategy.DETERMINISTIC. One value is randomly selected - per conversation, weighted by sample_rate. Matches the sampling pattern - used by SampledAttributeValue in synthesis_params. - """ - - values: dict[str, Any] = field(default_factory=dict) - """Structured output values. Serialized to JSON by the system at output time.""" - - sample_rate: float | None = None - """Selection weight. If not specified, assumes uniform sampling - among all possible outputs. Must be between 0 and 1.""" - - def __post_init__(self): - """Verifies/populates params.""" - if not self.values: - raise ValueError("DeterministicToolOutput.values cannot be empty.") - if self.sample_rate is not None and ( - self.sample_rate < 0 or self.sample_rate > 1 - ): - raise ValueError( - "DeterministicToolOutput.sample_rate must be between 0 and 1." - ) - - -@dataclass -class GeneratedToolOutput: - """Configuration for LLM-simulated tool output. - - Used with ToolOutputStrategy.GENERATED. The LLM simulator receives - the tool definition, call arguments, conversation history, and this - config to produce a realistic output. 
- """ - - instruction: str - """Prompt hint for the LLM simulator. Guides what kind of output to produce. - Example: "Return eligibility based on order status and 30-day return window." - """ - - def __post_init__(self): - """Verifies/populates params.""" - if not self.instruction: - raise ValueError("GeneratedToolOutput.instruction cannot be empty.") - - -@dataclass -class ToolEnvironmentAttribute: - """Defines a stateful environment for tool synthesis. - - Environments are stateful containers that tools operate on. - Each environment maintains a JSON state document that evolves - as tools read from and write to it. - """ - - id: str - """Unique identifier. Referenced by ToolAttribute.environment.""" - - name: str - """Display name (e.g. 'Filesystem', 'PropertyManagementSystem').""" - - description: str - """What this environment represents. Used in LLM prompts for - state generation and tool result generation.""" - - system_prompt: str - """Instructions for the LLM managing this environment's state. - Included in all _generate_result and _update_state prompts.""" - - state_schema: dict[str, Any] | None = None - """JSON Schema the state must conform to. If None, auto-generated - from bound tools at initialization time.""" - - initial_state: dict[str, Any] | None = None - """Starting state document. If None, auto-generated from schema. - Validated against state_schema when both are provided.""" - - def __post_init__(self): - """Verifies/populates params.""" - if not self.id: - raise ValueError("ToolEnvironmentAttribute.id cannot be empty.") - if not self.name: - raise ValueError("ToolEnvironmentAttribute.name cannot be empty.") - if not self.description: - raise ValueError("ToolEnvironmentAttribute.description cannot be empty.") - if not self.system_prompt: - raise ValueError("ToolEnvironmentAttribute.system_prompt cannot be empty.") - - -@dataclass -class ToolAttribute: - """Defines a tool available for agentic synthesis. 
- - Tools are defined at the GeneralSynthesisParams level and referenced - by id from MultiTurnAttribute.available_tools. - - The output_strategy field determines which output config is used: - - DETERMINISTIC: uses deterministic_outputs (list, one sampled per conversation) - - GENERATED: uses generated_output (singular config for LLM simulator) - """ - - id: str - """Unique identifier for the tool. Referenced by available_tools.""" - - name: str - """Display name of the tool (e.g., "SearchOrders"). - Used in the tool catalog and output tool definitions.""" - - description: str - """What the tool does. Shown in the tool catalog, used in simulation - prompts, and included in output tool definitions.""" - - output_strategy: ToolOutputStrategy = ToolOutputStrategy.GENERATED - """How this tool produces its output.""" - - parameters: dict[str, Any] = field(default_factory=dict) - """JSON Schema for tool parameters. Matches the OpenAI/HuggingFace standard. - - Example:: - - { - "type": "object", - "properties": { - "customer_id": { - "type": "string", - "description": "The customer's ID" - } - }, - "required": ["customer_id"] - } - """ - - output_schema: dict[str, Any] = field(default_factory=dict) - """JSON Schema describing the tool's output structure. Shared across strategies. - - For DETERMINISTIC: used to validate that canned outputs conform to the schema. - For GENERATED: passed to the LLM simulator as structure guidance. - - Example:: - - { - "type": "object", - "properties": { - "eligible": {"type": "boolean"}, - "reason": {"type": "string"} - } - } - """ - - deterministic_outputs: list[DeterministicToolOutput] = field(default_factory=list) - """Possible canned outputs for DETERMINISTIC strategy. - - One is selected per conversation, weighted by sample_rate (uniform if unset). 
- Selection happens once at conversation initialization, not per-call.""" - - generated_output: GeneratedToolOutput | None = None - """Configuration for LLM-simulated output for GENERATED strategy. - - The LLM simulator uses this config along with the tool call arguments - and conversation context to produce a realistic output.""" - - environment: str | None = None - """References an ToolEnvironmentAttribute.id. When set, output_strategy - must be ENVIRONMENT and tool results come from env.step().""" - - read_only: bool = True - """Whether this tool only reads (does not modify) its environment's state. - Only relevant when environment is set.""" - - def __post_init__(self): - """Verifies/populates params.""" - if not self.id: - raise ValueError("ToolAttribute.id cannot be empty.") - if not self.name: - raise ValueError("ToolAttribute.name cannot be empty.") - if not self.description: - raise ValueError("ToolAttribute.description cannot be empty.") - if self.output_strategy == ToolOutputStrategy.ENVIRONMENT: - if not self.environment: - raise ValueError( - "ToolAttribute.environment must be set " - "when output_strategy is ENVIRONMENT." - ) - return - - if self.environment: - raise ValueError( - "ToolAttribute.output_strategy must be ENVIRONMENT " - "when environment is set." - ) - - if self.output_strategy == ToolOutputStrategy.DETERMINISTIC: - if not self.deterministic_outputs: - raise ValueError( - "ToolAttribute.deterministic_outputs cannot be empty " - "when output_strategy is DETERMINISTIC." - ) - self._normalize_sample_rates() - - elif self.output_strategy == ToolOutputStrategy.GENERATED: - if not self.generated_output: - raise ValueError( - "ToolAttribute.generated_output must be provided " - "when output_strategy is GENERATED." - ) - - def _normalize_sample_rates(self) -> None: - """Normalize sample rates for deterministic outputs. 
- - Matches the pattern used by SampledAttribute for consistency: - - Defined rates are summed - - Remaining probability is split uniformly among undefined rates - """ - sample_rates = [o.sample_rate for o in self.deterministic_outputs] - - defined_rate = 0.0 - undefined_count = 0 - for rate in sample_rates: - if rate is not None: - defined_rate += rate - else: - undefined_count += 1 - - if defined_rate > 1.0 and not math.isclose(defined_rate, 1.0): - raise ValueError( - "ToolAttribute.deterministic_outputs sample rates " - "must sum to at most 1.0." - ) - - if undefined_count > 0: - remaining = max(0.0, 1.0 - defined_rate) - per_undefined = remaining / undefined_count - for output_value in self.deterministic_outputs: - if output_value.sample_rate is None: - output_value.sample_rate = per_undefined diff --git a/src/oumi/core/configs/synthesis_config.py b/src/oumi/core/configs/synthesis_config.py index 2ec7646d9b..10a22a67a6 100644 --- a/src/oumi/core/configs/synthesis_config.py +++ b/src/oumi/core/configs/synthesis_config.py @@ -14,10 +14,16 @@ from dataclasses import dataclass, field from enum import Enum +from pathlib import Path from oumi.core.configs.base_config import BaseConfig +from oumi.core.configs.environment_config import EnvironmentConfig from oumi.core.configs.inference_config import InferenceConfig -from oumi.core.configs.params.synthesis_params import GeneralSynthesisParams +from oumi.core.configs.params.synthesis_params import ( + GeneralSynthesisParams, + MultiTurnAttribute, +) +from oumi.environments import BaseEnvironment, BaseTool class SynthesisStrategy(str, Enum): @@ -45,6 +51,12 @@ class SynthesisConfig(BaseConfig): ) """The synthesis strategy parameters to use.""" + environment_config: EnvironmentConfig | None = None + """Reusable environment-first tool configuration.""" + + environment_config_path: str | None = None + """Optional path to an EnvironmentConfig YAML file.""" + inference_config: InferenceConfig = 
field(default_factory=InferenceConfig) """The inference configuration to use.""" @@ -74,3 +86,115 @@ def __post_init__(self): if not self.output_path.endswith(".jsonl"): raise ValueError("Output path must end with .jsonl.") + + self.environment_config = self._resolve_environment_config() + self._validate_available_tooling() + + def _resolve_environment_config(self) -> EnvironmentConfig | None: + """Resolve top-level environment configuration.""" + if ( + self.environment_config is not None + and self.environment_config_path is not None + ): + raise ValueError( + "SynthesisConfig.environment_config and " + "SynthesisConfig.environment_config_path cannot both be set." + ) + + if self.environment_config is not None: + return self.environment_config + + if self.environment_config_path is not None: + if self.environment_config_path == "": + raise ValueError( + "SynthesisConfig.environment_config_path cannot be empty." + ) + + config_path = Path(self.environment_config_path) + if not config_path.exists(): + raise ValueError( + f"Environment config path does not exist: " + f"{self.environment_config_path}" + ) + return EnvironmentConfig.from_yaml(config_path) + + return None + + def resolve_multiturn_environments( + self, multiturn_attribute: MultiTurnAttribute + ) -> list[BaseEnvironment]: + """Resolve the environments available to a multiturn attribute.""" + if self.environment_config is None: + return [] + + if not multiturn_attribute.available_environments: + return list(self.environment_config.environments) + + resolved_environments: list[BaseEnvironment] = [] + for environment_id in multiturn_attribute.available_environments: + environment = self.environment_config.get_environment(environment_id) + if environment is None: + raise ValueError( + f"MultiTurnAttribute '{multiturn_attribute.id}' references unknown " + f"environment '{environment_id}'. 
Defined environment ids: " + f"{sorted(env.id for env in self.environment_config.environments)}" + ) + resolved_environments.append(environment) + return resolved_environments + + def resolve_multiturn_tools( + self, multiturn_attribute: MultiTurnAttribute + ) -> list[BaseTool]: + """Resolve the tools available to a multiturn attribute.""" + if self.environment_config is None: + return [] + + environments = self.resolve_multiturn_environments(multiturn_attribute) + return self.environment_config.resolve_tools( + environment_ids=[environment.id for environment in environments], + tool_ids=multiturn_attribute.available_tools or None, + ) + + def _validate_available_tooling(self) -> None: + """Validate multiturn environment/tool selections against the catalog.""" + if not self.strategy_params.multiturn_attributes: + return + + all_referenced_tools = [ + tool_id + for mt_attr in self.strategy_params.multiturn_attributes + for tool_id in mt_attr.available_tools + ] + all_referenced_environments = [ + environment_id + for mt_attr in self.strategy_params.multiturn_attributes + for environment_id in mt_attr.available_environments + ] + if not all_referenced_tools and not all_referenced_environments: + return + + if self.environment_config is None: + raise ValueError( + "Environment or tool references require " + "SynthesisConfig.environment_config, or " + "SynthesisConfig.environment_config_path." 
+ ) + + for mt_attr in self.strategy_params.multiturn_attributes: + selected_environments = self.resolve_multiturn_environments(mt_attr) + selected_environment_ids = { + environment.id for environment in selected_environments + } + selected_tools = self.environment_config.resolve_tools( + environment_ids=list(selected_environment_ids) + ) + selected_tool_ids = {tool.id for tool in selected_tools} + + for tool_id in mt_attr.available_tools: + if tool_id not in selected_tool_ids: + raise ValueError( + f"MultiTurnAttribute '{mt_attr.id}' references unknown " + f"tool '{tool_id}' for environments " + f"{sorted(selected_environment_ids)}. Defined tool ids: " + f"{sorted(selected_tool_ids)}" + ) diff --git a/src/oumi/core/synthesis/conversation_synthesizer.py b/src/oumi/core/synthesis/conversation_synthesizer.py index 59f09082a8..7717fefc01 100644 --- a/src/oumi/core/synthesis/conversation_synthesizer.py +++ b/src/oumi/core/synthesis/conversation_synthesizer.py @@ -15,6 +15,7 @@ import random from oumi.builders.inference_engines import build_inference_engine +from oumi.core.configs.environment_config import EnvironmentConfig from oumi.core.configs.inference_config import InferenceConfig from oumi.core.configs.inference_engine_type import InferenceEngineType from oumi.core.configs.params.synthesis_params import ( @@ -23,6 +24,7 @@ ) from oumi.core.synthesis.attribute_formatter import AttributeFormatter from oumi.core.types.conversation import Conversation, Message, Role +from oumi.environments import BaseTool from oumi.utils.logging import logger from oumi.utils.str_utils import extract_json @@ -39,9 +41,11 @@ def __init__( self, params: GeneralSynthesisParams, inference_config: InferenceConfig, + environment_config: EnvironmentConfig | None = None, ): """Initialize the synthesizer.""" self._params = params + self._environment_config = environment_config self._formatter = AttributeFormatter(params) self._inference_engine = build_inference_engine( @@ -52,6 +56,17 @@ def 
__init__( self._inference_config = inference_config self._default_turn_order = [Role.USER, Role.ASSISTANT] + def _resolve_available_tools( + self, multiturn_attribute: MultiTurnAttribute + ) -> list[BaseTool]: + """Resolve tools for a multiturn attribute from selected environments.""" + if self._environment_config is None: + return [] + return self._environment_config.resolve_tools( + environment_ids=multiturn_attribute.available_environments or None, + tool_ids=multiturn_attribute.available_tools or None, + ) + def _validate_roles(self, multiturn_attribute: MultiTurnAttribute) -> None: """Validate that required roles have corresponding personas. @@ -98,6 +113,13 @@ def synthesize( f"Synthesizing {len(samples)} conversations for " f"attribute '{multiturn_attributes.id}'" ) + available_tools = self._resolve_available_tools(multiturn_attributes) + if available_tools: + logger.debug( + "Resolved tools for '%s': %s", + multiturn_attributes.id, + [tool.id for tool in available_tools], + ) samples = self._plan_samples(samples, multiturn_attributes) conversations = self._synthesize_all_samples(samples, multiturn_attributes) diff --git a/src/oumi/core/synthesis/environment.py b/src/oumi/core/synthesis/environment.py deleted file mode 100644 index fa92720745..0000000000 --- a/src/oumi/core/synthesis/environment.py +++ /dev/null @@ -1,174 +0,0 @@ -# Copyright 2025 - Oumi -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Stateful environment and registry for agentic tool synthesis.""" - -from collections.abc import Callable -from typing import Any - -from oumi.core.configs.params.tool_params import ( - ToolAttribute, - ToolEnvironmentAttribute, -) -from oumi.core.types.conversation import Conversation, Message - -_MAX_RESULT_RETRIES = 2 - - -class GeneratedToolEnvironment: - """Stateful environment for tool synthesis. - - Maintains a JSON state document that evolves as tools read/write to it. - Builds prompts and applies responses — does not call inference itself. - """ - - def __init__(self, config: ToolEnvironmentAttribute): - """Initialize the environment with config.""" - raise NotImplementedError - - @property - def state(self) -> dict[str, Any]: - """Current state of the environment.""" - raise NotImplementedError - - def summarize_for_planner(self) -> dict[str, Any]: - """Return a compact state view for planner grounding.""" - raise NotImplementedError - - def set_state(self, state: dict[str, Any], validate: bool = True) -> bool: - """Set state, optionally skipping schema validation.""" - raise NotImplementedError - - def set_schema(self, schema: dict[str, Any]) -> None: - """Set the state schema.""" - raise NotImplementedError - - def build_result_prompt( - self, - tool: ToolAttribute, - arguments: dict[str, Any], - retry: bool = False, - ) -> Conversation: - """Build the prompt for generating a tool result.""" - raise NotImplementedError - - def build_write_state_update_prompt( - self, - tool: ToolAttribute, - arguments: dict[str, Any], - retry: bool = False, - retry_error: str | None = None, - ) -> Conversation: - """Build prompt for state-first write.""" - raise NotImplementedError - - def build_write_result_prompt( - self, - tool: ToolAttribute, - arguments: dict[str, Any], - patch_ops: list[dict[str, Any]], - patch_succeeded: bool, - pre_patch_state: dict[str, Any] | None = None, - retry: bool = False, - retry_error: str | None = None, - ) -> Conversation: - 
"""Build prompt for generating a write tool result after state update.""" - raise NotImplementedError - - def apply_result(self, response: Conversation) -> str: - """Extract the tool result text from an inference response.""" - raise NotImplementedError - - def apply_state_update_returning_patch( - self, response: Conversation - ) -> tuple[bool, list[dict[str, Any]], str | None]: - """Parse and apply a write-state update. Returns (succeeded, ops, error).""" - raise NotImplementedError - - def apply_state_update(self, response: Conversation) -> bool: - """Convenience wrapper: apply patch and return success bool.""" - raise NotImplementedError - - -class EnvironmentRegistry: - """Builds environments once, then copies N times for parallel samples.""" - - def __init__(self): - """Initialize an empty registry.""" - raise NotImplementedError - - def register_static(self, config: ToolEnvironmentAttribute) -> None: - """Register an environment that needs no LLM generation.""" - raise NotImplementedError - - def build( - self, - config: ToolEnvironmentAttribute, - tools: list[ToolAttribute], - inference_engine: Any, - inference_config: Any, - scenario_context: str | None = None, - ) -> None: - """Build a fully populated environment.""" - raise NotImplementedError - - def create_copies(self, env_id: str, n: int) -> list[GeneratedToolEnvironment]: - """Return n independent deepcopies of a built environment.""" - raise NotImplementedError - - -def resolve_env_tool( - tool_executor: Any, - tool_call: dict, - idx_envs: dict[str, GeneratedToolEnvironment] | None, -) -> tuple[GeneratedToolEnvironment | None, ToolAttribute | None]: - """Look up (env, tool) for an environment-bound tool call.""" - raise NotImplementedError - - -def is_env_tool_missing_env(tool_executor: Any, tool_call: dict) -> bool: - """Return True if tool_call targets an ENVIRONMENT tool whose env is missing.""" - raise NotImplementedError - - -def serialize_env_states(envs: dict[str, GeneratedToolEnvironment]) 
-> str: - """Serialize planner-oriented environment summaries as stable JSON.""" - raise NotImplementedError - - -def init_sample_environments( - samples: list[dict], - tools: list[ToolAttribute], - env_configs: dict[str, ToolEnvironmentAttribute], - formatter: Any, - inference_engine: Any, - inference_config: Any, -) -> list[dict[str, GeneratedToolEnvironment] | None]: - """Create per-sample environments, reusing builds when config is identical.""" - raise NotImplementedError - - -def process_env_tool_calls( - env_items: list[ - tuple[int, str, dict, str, GeneratedToolEnvironment, ToolAttribute] - ], - env_result_prompts: list[Conversation], - turn_tool_msgs: dict[int, list[Message]], - output_messages: list[list[dict]], - inference_engine: Any, - inference_config: Any, - record_fn: Callable[..., None], -) -> list[int]: - """Execute batched env tool calls.""" - raise NotImplementedError diff --git a/src/oumi/core/synthesis/synthesis_pipeline.py b/src/oumi/core/synthesis/synthesis_pipeline.py index 098df6509b..1950b7813d 100644 --- a/src/oumi/core/synthesis/synthesis_pipeline.py +++ b/src/oumi/core/synthesis/synthesis_pipeline.py @@ -46,7 +46,11 @@ def __init__(self, config: SynthesisConfig): else None ) self._conversation_synthesizer = ( - ConversationSynthesizer(config.strategy_params, config.inference_config) + ConversationSynthesizer( + config.strategy_params, + config.inference_config, + environment_config=config.environment_config, + ) if config.strategy_params.multiturn_attributes else None ) diff --git a/src/oumi/core/synthesis/tool_executor.py b/src/oumi/core/synthesis/tool_executor.py deleted file mode 100644 index 6b6698c6fe..0000000000 --- a/src/oumi/core/synthesis/tool_executor.py +++ /dev/null @@ -1,160 +0,0 @@ -# Copyright 2025 - Oumi -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tool executor for agentic synthesis.""" - -from dataclasses import dataclass -from typing import Any - -from oumi.core.configs.params.tool_params import ToolAttribute -from oumi.core.types.conversation import Conversation, Message - - -@dataclass -class ToolCallParsed: - """Successfully parsed and validated tool call.""" - - tool_call: dict[str, Any] - - -@dataclass -class ToolCallError: - """Structured error from parsing or validation.""" - - error_json: str - tool_name: str | None - - -ToolCallResult = ToolCallParsed | ToolCallError | None - - -class ToolExecutor: - """Parses tool calls from LLM responses and resolves tool outputs.""" - - def __init__(self, tools: list[ToolAttribute]): - """Initialize the tool executor with available tools.""" - raise NotImplementedError - - def get_tool_by_name(self, name: str) -> ToolAttribute | None: - """Look up a tool by its display name.""" - raise NotImplementedError - - def parse_and_validate_tool_call(self, response: str) -> ToolCallResult: - """Parse tags from response and validate arguments.""" - raise NotImplementedError - - def sample_deterministic_outputs( - self, tools: list[ToolAttribute] - ) -> dict[str, str]: - """Sample one deterministic output per DETERMINISTIC tool.""" - raise NotImplementedError - - def resolve_output( - self, - tool_call: dict[str, Any], - deterministic_selections: dict[str, str], - ) -> str | None: - """Resolve a tool call to its output. 
None for GENERATED tools.""" - raise NotImplementedError - - def build_generated_simulator_prompt( - self, - tool_call: dict[str, Any], - conversation_history: list[Message] | None = None, - ) -> Conversation: - """Build LLM prompt for simulating a GENERATED tool's output.""" - raise NotImplementedError - - @staticmethod - def build_capability_summary(tools: list[ToolAttribute]) -> str: - """Build a planner-facing capability summary.""" - raise NotImplementedError - - @staticmethod - def build_tool_catalog(tools: list[ToolAttribute]) -> str: - """Build a formatted tool catalog with schemas and usage examples.""" - raise NotImplementedError - - @staticmethod - def build_tool_definitions( - tools: list[ToolAttribute], - ) -> list[dict[str, Any]]: - """Convert ToolAttributes to standard tool definitions for output.""" - raise NotImplementedError - - @staticmethod - def format_tool_call_message( - tool_call: dict[str, Any], - call_id: str, - ) -> dict[str, Any]: - """Format a parsed tool call as a standard OpenAI assistant message.""" - raise NotImplementedError - - @staticmethod - def format_tool_result_message( - call_id: str, - content: str, - name: str, - ) -> dict[str, Any]: - """Format a tool result as a standard OpenAI tool message.""" - raise NotImplementedError - - @staticmethod - def strip_tool_tags(text: str) -> str: - """Remove any residual or tags.""" - raise NotImplementedError - - @staticmethod - def sanitize_assistant_content(text: str) -> str: - """Strip tool-call artifacts from assistant prose before export.""" - raise NotImplementedError - - @staticmethod - def build_tool_few_shot(tools: list[ToolAttribute]) -> list[Message]: - """Build few-shot messages demonstrating a correct tool-call exchange.""" - raise NotImplementedError - - @staticmethod - def build_tool_turn_info( - current_turn: int, - target_turns: int, - turn_instruction: str, - max_calls_reached: bool, - ) -> str: - """Build the turn-level user message for assistant tool turns.""" - 
raise NotImplementedError - - @staticmethod - def build_prose_turn_info( - current_turn: int, - target_turns: int, - role: str, - turn_instruction: str, - ) -> str: - """Build turn-level user message for non-tool turns.""" - raise NotImplementedError - - @staticmethod - def record_tool_result( - idx: int, - raw_text: str, - tool_call: dict, - call_id: str, - result: str, - turn_tool_msgs: dict[int, list[Message]], - output_messages: list[list[dict]], - env_state: dict | None = None, - ) -> None: - """Append a tool call + result to conversation history and output.""" - raise NotImplementedError diff --git a/src/oumi/environments/__init__.py b/src/oumi/environments/__init__.py new file mode 100644 index 0000000000..2322d897c6 --- /dev/null +++ b/src/oumi/environments/__init__.py @@ -0,0 +1,58 @@ +# Copyright 2025 - Oumi +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Environments for agentic tool interactions. + +Environments are simulated worlds that agents interact with via tools. +Consumers include synthesis (training data generation), evaluation +(agent testing), and RL (reward-driven training). + +Each environment type defines how tool calls are resolved: + +- **StatefulEnvironment**: mutable JSON state across calls. +- **StatelessEnvironment**: LLM-generated outputs with optional caching. +- **DeterministicEnvironment**: fixed input-to-output lookup tables. 
+""" + +from oumi.environments.base_environment import BaseEnvironment +from oumi.environments.base_tool import BaseTool +from oumi.environments.deterministic_environment import ( + DeterministicEnvironment, + DeterministicTool, + DeterministicToolOutput, +) +from oumi.environments.stateful_environment import ( + StatefulEnvironment, + StatefulTool, +) +from oumi.environments.stateless_environment import ( + GeneratedToolOutput, + StatelessEnvironment, + StatelessTool, +) +from oumi.environments.types import ToolEnvironmentType + +__all__ = [ + "BaseEnvironment", + "BaseTool", + "DeterministicEnvironment", + "DeterministicTool", + "DeterministicToolOutput", + "GeneratedToolOutput", + "StatefulEnvironment", + "StatefulTool", + "StatelessEnvironment", + "StatelessTool", + "ToolEnvironmentType", +] diff --git a/src/oumi/environments/base_environment.py b/src/oumi/environments/base_environment.py new file mode 100644 index 0000000000..879d3bf2c1 --- /dev/null +++ b/src/oumi/environments/base_environment.py @@ -0,0 +1,126 @@ +# Copyright 2025 - Oumi +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Abstract base class for tool environments. + +Environments are simulated worlds that agents interact with via tools. +They are used by synthesis (to generate training data), evaluation +(to test agent behaviour), and RL (to provide reward signals). 
+""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections.abc import Mapping +from dataclasses import dataclass, field, fields +from typing import Any, ClassVar + +from oumi.core.configs.params.base_params import BaseParams +from oumi.environments.base_tool import BaseTool +from oumi.environments.types import ToolEnvironmentType + + +@dataclass +class BaseEnvironment(BaseParams, ABC): + """Abstract base class for tool environments. + + Each environment owns a set of tools and defines how tool calls are + resolved. Subclasses implement the concrete execution model and + coerce raw tool definitions into their typed tool subclass. + """ + + _registry: ClassVar[dict[ToolEnvironmentType, type[BaseEnvironment]]] = {} + + id: str + name: str + description: str + tools: list[BaseTool] = field(default_factory=list) + type: ToolEnvironmentType = field(init=False) + + def __init_subclass__(cls, **kwargs): + """Register subclass in the environment type registry.""" + super().__init_subclass__(**kwargs) + environment_type = getattr(cls, "ENVIRONMENT_TYPE", None) + if environment_type is not None: + cls._registry[environment_type] = cls + + def __post_init__(self): + """Validate common fields and coerce tools.""" + if not self.id: + raise ValueError(f"{type(self).__name__}.id cannot be empty.") + if not self.name: + raise ValueError(f"{type(self).__name__}.name cannot be empty.") + if not self.description: + raise ValueError(f"{type(self).__name__}.description cannot be empty.") + self.tools = self._coerce_tools(self.tools) + self._validate_unique_tool_ids() + self._validate_type_specific() + + def _validate_unique_tool_ids(self) -> None: + tool_ids: set[str] = set() + for tool in self.tools: + if tool.id in tool_ids: + raise ValueError( + f"{type(self).__name__} '{self.id}' contains duplicate " + f"tool id '{tool.id}'." 
+ ) + tool_ids.add(tool.id) + + @abstractmethod + def _coerce_tools(self, tools: list[Any]) -> list[BaseTool]: + """Coerce raw tool definitions into this environment's typed tool class.""" + + @abstractmethod + def _validate_type_specific(self) -> None: + """Validate fields specific to the environment subtype.""" + + @classmethod + def create(cls, raw: Mapping[str, Any] | BaseEnvironment) -> BaseEnvironment: + """Create a concrete environment from raw config data. + + Raises: + TypeError: If raw is not a mapping or BaseEnvironment. + ValueError: If type is missing or unsupported. + """ + if isinstance(raw, BaseEnvironment): + return raw + if not isinstance(raw, Mapping): + raise TypeError( + "Environment definitions must be environment objects or mappings, " + f"got {type(raw)}" + ) + raw_type = raw.get("type") + if raw_type is None: + raise ValueError( + "Environment definition must include a 'type' field. " + f"Supported types: {[t.value for t in ToolEnvironmentType]}" + ) + if isinstance(raw_type, ToolEnvironmentType): + environment_type = raw_type + elif isinstance(raw_type, str): + try: + environment_type = ToolEnvironmentType(raw_type) + except ValueError: + environment_type = ToolEnvironmentType[raw_type] + else: + environment_type = ToolEnvironmentType(raw_type) + environment_cls = cls._registry.get(environment_type) + if environment_cls is None: + raise ValueError(f"Unsupported environment type: {environment_type}") + init_fields = { + field_def.name for field_def in fields(environment_cls) if field_def.init + } + return environment_cls( + **{key: value for key, value in raw.items() if key in init_fields} + ) diff --git a/src/oumi/environments/base_tool.py b/src/oumi/environments/base_tool.py new file mode 100644 index 0000000000..71e575e369 --- /dev/null +++ b/src/oumi/environments/base_tool.py @@ -0,0 +1,64 @@ +# Copyright 2025 - Oumi +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance 
with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Base tool class shared by all environment types.""" + +from __future__ import annotations + +from collections.abc import Mapping +from dataclasses import dataclass, field +from typing import Any + +from oumi.core.configs.params.base_params import BaseParams + + +@dataclass +class BaseTool(BaseParams): + """Common fields for all tools exposed by an environment.""" + + id: str + name: str + description: str + parameters: dict[str, Any] = field(default_factory=dict) + + @classmethod + def create(cls, raw: Mapping[str, Any] | BaseTool) -> BaseTool: + """Create a tool from raw config data. + + Returns a ``BaseTool`` with only the common fields. Environment + subclasses call their own typed factory (e.g. + ``DeterministicTool.create``) to get the full subclass. 
+ """ + if isinstance(raw, BaseTool): + return raw + if not isinstance(raw, Mapping): + raise TypeError( + f"Tool definitions must be tool objects or mappings, " + f"got {type(raw)}" + ) + return cls( + id=raw["id"], + name=raw["name"], + description=raw["description"], + parameters=raw.get("parameters", {}), + ) + + def __post_init__(self): + """Validate common tool fields.""" + if not self.id: + raise ValueError(f"{type(self).__name__}.id cannot be empty.") + if not self.name: + raise ValueError(f"{type(self).__name__}.name cannot be empty.") + if not self.description: + raise ValueError(f"{type(self).__name__}.description cannot be empty.") diff --git a/src/oumi/environments/deterministic_environment.py b/src/oumi/environments/deterministic_environment.py new file mode 100644 index 0000000000..0a04c2fc40 --- /dev/null +++ b/src/oumi/environments/deterministic_environment.py @@ -0,0 +1,152 @@ +# Copyright 2025 - Oumi +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Deterministic environment with fixed lookup responses.""" + +from __future__ import annotations + +import json +from collections.abc import Mapping +from dataclasses import dataclass, field +from typing import Any, ClassVar + +from oumi.core.configs.params.base_params import BaseParams +from oumi.environments.base_environment import BaseEnvironment +from oumi.environments.base_tool import BaseTool +from oumi.environments.types import ToolEnvironmentType + + +@dataclass +class DeterministicToolOutput(BaseParams): + """An input-to-output mapping for a deterministic tool.""" + + input: dict[str, Any] = field(default_factory=dict) + output: dict[str, Any] = field(default_factory=dict) + + def __post_init__(self): + """Validate the input and output fields are not empty.""" + if not self.input: + raise ValueError("DeterministicToolOutput.input cannot be empty.") + if not self.output: + raise ValueError("DeterministicToolOutput.output cannot be empty.") + + def matches(self, arguments: dict[str, Any]) -> bool: + """Check if the input matches the given arguments.""" + return json.dumps(self.input, sort_keys=True) == json.dumps( + arguments, sort_keys=True + ) + + +@dataclass +class DeterministicTool(BaseTool): + """Tool with fixed input-to-output lookup responses.""" + + deterministic_outputs: list[DeterministicToolOutput] = field(default_factory=list) + + def __post_init__(self): + """Validate deterministic tool fields.""" + super().__post_init__() + if not self.deterministic_outputs: + raise ValueError( + f"DeterministicTool '{self.id}' must have at least one " + f"deterministic_output entry." 
+ ) + self._check_deterministic_duplicates() + + def _check_deterministic_duplicates(self) -> None: + seen: set[str] = set() + for entry in self.deterministic_outputs: + key = json.dumps(entry.input, sort_keys=True) + if key in seen: + raise ValueError( + f"DeterministicTool '{self.id}' has duplicate " + f"deterministic input entry: {entry.input}" + ) + seen.add(key) + + def resolve_deterministic( + self, arguments: dict[str, Any] + ) -> dict[str, Any] | None: + """Resolve a deterministic output for the given arguments.""" + for entry in self.deterministic_outputs: + if entry.matches(arguments): + return entry.output + return None + + @classmethod + def create(cls, raw: Mapping[str, Any] | BaseTool) -> DeterministicTool: + """Create a DeterministicTool from raw config data.""" + if isinstance(raw, DeterministicTool): + return raw + if isinstance(raw, BaseTool): + raise TypeError( + f"Cannot coerce {type(raw).__name__} to DeterministicTool. " + f"Use a mapping with 'deterministic_outputs'." + ) + if not isinstance(raw, Mapping): + raise TypeError( + f"Tool definitions must be tool objects or mappings, " + f"got {type(raw)}" + ) + deterministic_outputs = [ + entry + if isinstance(entry, DeterministicToolOutput) + else DeterministicToolOutput(**entry) + for entry in raw.get("deterministic_outputs", []) + ] + return cls( + id=raw["id"], + name=raw["name"], + description=raw["description"], + parameters=raw.get("parameters", {}), + deterministic_outputs=deterministic_outputs, + ) + + +@dataclass +class DeterministicEnvironment(BaseEnvironment): + """Environment that resolves tools from fixed lookups. + + Each tool is a ``DeterministicTool`` with a list of input-to-output + mappings. The environment owns the resolution logic. 
+ """ + + ENVIRONMENT_TYPE: ClassVar[ToolEnvironmentType] = ToolEnvironmentType.DETERMINISTIC + type: ToolEnvironmentType = field( + init=False, default=ToolEnvironmentType.DETERMINISTIC + ) + tools: list[DeterministicTool] = field(default_factory=list) # type: ignore[assignment] + + def _coerce_tools(self, tools: list[Any]) -> list[DeterministicTool]: + """Coerce raw tool definitions into DeterministicTool instances.""" + return [DeterministicTool.create(t) for t in tools] + + def _validate_type_specific(self) -> None: + return + + def resolve( + self, tool_id: str, arguments: dict[str, Any] + ) -> dict[str, Any] | None: + """Resolve a deterministic tool call to its output. + + Raises: + ValueError: If tool_id is not found in this environment. + """ + for tool in self.tools: + if tool.id == tool_id: + return tool.resolve_deterministic(arguments) + raise ValueError( + f"Tool '{tool_id}' not found in environment '{self.id}'. " + f"Available tools: {[t.id for t in self.tools]}" + ) diff --git a/src/oumi/environments/stateful_environment.py b/src/oumi/environments/stateful_environment.py new file mode 100644 index 0000000000..93688f1c7e --- /dev/null +++ b/src/oumi/environments/stateful_environment.py @@ -0,0 +1,85 @@ +# Copyright 2025 - Oumi +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Stateful environment with mutable shared state.""" + +from __future__ import annotations + +from collections.abc import Mapping +from dataclasses import dataclass, field +from typing import Any, ClassVar + +from oumi.environments.base_environment import BaseEnvironment +from oumi.environments.base_tool import BaseTool +from oumi.environments.types import ToolEnvironmentType + + +@dataclass +class StatefulTool(BaseTool): + """Tool bound to a stateful environment.""" + + output_schema: dict[str, Any] = field(default_factory=dict) + read_only: bool = True + + @classmethod + def create(cls, raw: Mapping[str, Any] | BaseTool) -> StatefulTool: + """Create a StatefulTool from raw config data.""" + if isinstance(raw, StatefulTool): + return raw + if isinstance(raw, BaseTool): + return cls( + id=raw.id, + name=raw.name, + description=raw.description, + parameters=raw.parameters, + ) + if not isinstance(raw, Mapping): + raise TypeError( + f"Tool definitions must be tool objects or mappings, " + f"got {type(raw)}" + ) + return cls( + id=raw["id"], + name=raw["name"], + description=raw["description"], + parameters=raw.get("parameters", {}), + output_schema=raw.get("output_schema", {}), + read_only=raw.get("read_only", True), + ) + + +@dataclass +class StatefulEnvironment(BaseEnvironment): + """Environment with mutable shared state. + + Maintains a JSON state dict that tools can read and modify. The + ``system_prompt`` instructs the LLM how to simulate this environment. + Each tool is a ``StatefulTool`` with an output schema and read/write flag. 
+ """ + + ENVIRONMENT_TYPE: ClassVar[ToolEnvironmentType] = ToolEnvironmentType.STATEFUL + type: ToolEnvironmentType = field(init=False, default=ToolEnvironmentType.STATEFUL) + system_prompt: str = "" + state_schema: dict[str, Any] | None = None + initial_state: dict[str, Any] | None = None + tools: list[StatefulTool] = field(default_factory=list) # type: ignore[assignment] + + def _coerce_tools(self, tools: list[Any]) -> list[StatefulTool]: + """Coerce raw tool definitions into StatefulTool instances.""" + return [StatefulTool.create(t) for t in tools] + + def _validate_type_specific(self) -> None: + """Validate stateful-specific fields.""" + if not self.system_prompt: + raise ValueError("StatefulEnvironment.system_prompt cannot be empty.") diff --git a/src/oumi/environments/stateless_environment.py b/src/oumi/environments/stateless_environment.py new file mode 100644 index 0000000000..6cb011e2f1 --- /dev/null +++ b/src/oumi/environments/stateless_environment.py @@ -0,0 +1,141 @@ +# Copyright 2025 - Oumi +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Stateless environment with optional response caching.""" + +from __future__ import annotations + +import json +from collections.abc import Mapping +from dataclasses import dataclass, field +from typing import Any, ClassVar + +from oumi.core.configs.params.base_params import BaseParams +from oumi.environments.base_environment import BaseEnvironment +from oumi.environments.base_tool import BaseTool +from oumi.environments.types import ToolEnvironmentType + + +@dataclass +class GeneratedToolOutput(BaseParams): + """Configuration for tool output in a stateless environment.""" + + instruction: str + + def __post_init__(self): + """Validate the instruction field is not empty.""" + if not self.instruction: + raise ValueError("GeneratedToolOutput.instruction cannot be empty.") + + +@dataclass +class StatelessTool(BaseTool): + """Tool bound to a stateless environment.""" + + generated_output: GeneratedToolOutput | None = None + + def __post_init__(self): + """Validate stateless tool fields.""" + super().__post_init__() + if self.generated_output is None: + raise ValueError( + f"StatelessTool '{self.id}' must have a generated_output." + ) + + @classmethod + def create(cls, raw: Mapping[str, Any] | BaseTool) -> StatelessTool: + """Create a StatelessTool from raw config data.""" + if isinstance(raw, StatelessTool): + return raw + if isinstance(raw, BaseTool): + raise TypeError( + f"Cannot coerce {type(raw).__name__} to StatelessTool. " + f"Use a mapping with 'generated_output'." 
+ ) + if not isinstance(raw, Mapping): + raise TypeError( + f"Tool definitions must be tool objects or mappings, " + f"got {type(raw)}" + ) + generated_output = raw.get("generated_output") + if isinstance(generated_output, Mapping): + generated_output = GeneratedToolOutput(**generated_output) + return cls( + id=raw["id"], + name=raw["name"], + description=raw["description"], + parameters=raw.get("parameters", {}), + generated_output=generated_output, + ) + + +@dataclass +class StatelessEnvironment(BaseEnvironment): + """Environment that simulates outputs and optionally caches by input. + + Each tool is a ``StatelessTool`` with a ``generated_output`` instruction + for the LLM. When ``cache_by_input`` is True, the environment caches + results keyed by (tool_id, arguments) so repeated calls with the same + input return consistent results. + """ + + ENVIRONMENT_TYPE: ClassVar[ToolEnvironmentType] = ToolEnvironmentType.STATELESS + type: ToolEnvironmentType = field(init=False, default=ToolEnvironmentType.STATELESS) + system_prompt: str = "" + cache_by_input: bool = True + tools: list[StatelessTool] = field(default_factory=list) # type: ignore[assignment] + + def __post_init__(self): + """Validate and initialize the response cache.""" + super().__post_init__() + self._cache: dict[str, str] = {} + self._frozen_context: str | None = None + + def _coerce_tools(self, tools: list[Any]) -> list[StatelessTool]: + """Coerce raw tool definitions into StatelessTool instances.""" + return [StatelessTool.create(t) for t in tools] + + def _validate_type_specific(self) -> None: + """Validate stateless-specific fields.""" + if not self.system_prompt: + raise ValueError("StatelessEnvironment.system_prompt cannot be empty.") + + @property + def frozen_context(self) -> str | None: + """Frozen context generated once at build time.""" + return self._frozen_context + + def set_frozen_context(self, context: str) -> None: + """Set the frozen context (called once during environment build).""" + 
self._frozen_context = context + + @staticmethod + def _cache_key(tool_id: str, arguments: dict[str, Any]) -> str: + """Build a stable cache key from tool id and arguments.""" + return f"{tool_id}::{json.dumps(arguments, sort_keys=True)}" + + def resolve_cached( + self, tool_id: str, arguments: dict[str, Any] + ) -> str | None: + """Look up a cached result for the given tool call.""" + if not self.cache_by_input: + return None + return self._cache.get(self._cache_key(tool_id, arguments)) + + def cache_result( + self, tool_id: str, arguments: dict[str, Any], result: str + ) -> None: + """Store a generated result in the cache. No-op if caching is disabled.""" + if self.cache_by_input: + self._cache[self._cache_key(tool_id, arguments)] = result diff --git a/src/oumi/environments/types.py b/src/oumi/environments/types.py new file mode 100644 index 0000000000..f3654ab0cf --- /dev/null +++ b/src/oumi/environments/types.py @@ -0,0 +1,25 @@ +# Copyright 2025 - Oumi +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Shared types for the environments package.""" + +from enum import Enum + + +class ToolEnvironmentType(str, Enum): + """Execution model for an environment-bound tool.""" + + STATEFUL = "stateful" + STATELESS = "stateless" + DETERMINISTIC = "deterministic" diff --git a/tests/unit/core/configs/params/test_tool_params.py b/tests/unit/core/configs/params/test_tool_params.py index 8e139378e0..cbca2bffc8 100644 --- a/tests/unit/core/configs/params/test_tool_params.py +++ b/tests/unit/core/configs/params/test_tool_params.py @@ -12,263 +12,216 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any - import pytest -from oumi.core.configs.params.synthesis_params import ( - GeneralSynthesisParams, - MultiTurnAttribute, -) -from oumi.core.configs.params.tool_params import ( +from oumi.core.configs.environment_config import EnvironmentConfig +from oumi.environments import ( + BaseTool, + DeterministicEnvironment, + DeterministicTool, DeterministicToolOutput, GeneratedToolOutput, - ToolAttribute, - ToolEnvironmentAttribute, - ToolOutputStrategy, + StatefulEnvironment, + StatefulTool, + StatelessEnvironment, + StatelessTool, + ToolEnvironmentType, ) -from oumi.core.types.conversation import Role - - -def test_deterministic_tool_output_empty_values_raises(): - with pytest.raises(ValueError, match="values cannot be empty"): - DeterministicToolOutput(values={}) - -@pytest.mark.parametrize("rate", [-0.1, 1.1]) -def test_deterministic_tool_output_invalid_sample_rate_raises(rate): - with pytest.raises(ValueError, match="sample_rate must be between 0 and 1"): - DeterministicToolOutput(values={"x": 1}, sample_rate=rate) - -def _make_deterministic_tool(**overrides) -> ToolAttribute: +def _make_deterministic_tool(**overrides) -> DeterministicTool: defaults = dict( id="tool1", name="MyTool", description="A tool", - output_strategy=ToolOutputStrategy.DETERMINISTIC, deterministic_outputs=[ - 
DeterministicToolOutput(values={"a": 1}), + DeterministicToolOutput(input={"id": "01"}, output={"msg": "ok"}), ], ) defaults.update(overrides) - return ToolAttribute(**defaults) # type: ignore[arg-type] + return DeterministicTool(**defaults) -def _make_generated_tool(**overrides) -> ToolAttribute: +def _make_stateless_tool(**overrides) -> StatelessTool: defaults = dict( id="tool2", name="GenTool", description="A generated tool", - output_strategy=ToolOutputStrategy.GENERATED, generated_output=GeneratedToolOutput(instruction="Do something."), ) defaults.update(overrides) - return ToolAttribute(**defaults) # type: ignore[arg-type] + return StatelessTool(**defaults) -def test_tool_attribute_deterministic_without_outputs_raises(): - with pytest.raises(ValueError, match="deterministic_outputs cannot be empty"): - ToolAttribute( - id="t", - name="T", - description="d", - output_strategy=ToolOutputStrategy.DETERMINISTIC, - deterministic_outputs=[], - ) +def _make_stateful_tool(**overrides) -> StatefulTool: + defaults = dict( + id="tool3", + name="StatefulTool", + description="A stateful tool", + ) + defaults.update(overrides) + return StatefulTool(**defaults) -def test_tool_attribute_generated_without_output_raises(): - with pytest.raises(ValueError, match="generated_output must be provided"): - ToolAttribute( - id="t", - name="T", - description="d", - output_strategy=ToolOutputStrategy.GENERATED, - generated_output=None, - ) +# --- DeterministicToolOutput tests --- -@pytest.mark.parametrize( - "field,value", - [("id", ""), ("name", ""), ("description", "")], -) -def test_tool_attribute_empty_field_raises(field, value): - with pytest.raises(ValueError, match=f"{field} cannot be empty"): - _make_generated_tool(**{field: value}) +def test_deterministic_tool_output_empty_input_raises(): + with pytest.raises(ValueError, match="input cannot be empty"): + DeterministicToolOutput(input={}, output={"msg": "ok"}) -def test_tool_attribute_normalizes_undefined_sample_rates(): - 
outputs = [ - DeterministicToolOutput(values={"a": 1}), - DeterministicToolOutput(values={"b": 2}), - ] - tool = _make_deterministic_tool(deterministic_outputs=outputs) - assert tool.deterministic_outputs[0].sample_rate == pytest.approx(0.5) - assert tool.deterministic_outputs[1].sample_rate == pytest.approx(0.5) +def test_deterministic_tool_output_empty_output_raises(): + with pytest.raises(ValueError, match="output cannot be empty"): + DeterministicToolOutput(input={"id": "1"}, output={}) -def test_tool_attribute_normalizes_mixed_sample_rates(): - outputs = [ - DeterministicToolOutput(values={"a": 1}, sample_rate=0.7), - DeterministicToolOutput(values={"b": 2}), - ] - tool = _make_deterministic_tool(deterministic_outputs=outputs) - assert tool.deterministic_outputs[0].sample_rate == pytest.approx(0.7) - assert tool.deterministic_outputs[1].sample_rate == pytest.approx(0.3) - - -def test_tool_attribute_sample_rates_exceeding_one_raises(): - outputs = [ - DeterministicToolOutput(values={"a": 1}, sample_rate=0.6), - DeterministicToolOutput(values={"b": 2}, sample_rate=0.6), - ] - with pytest.raises(ValueError, match="sample rates must sum to at most 1.0"): - _make_deterministic_tool(deterministic_outputs=outputs) - - -def _make_multiturn_attr(**overrides) -> MultiTurnAttribute: - defaults = dict( - id="chat", - min_turns=1, - max_turns=3, - role_instruction_messages={ - Role.USER: "You are a user.", - Role.ASSISTANT: "You are an assistant.", - }, - available_tools=[], +def test_deterministic_tool_output_matches_exact(): + entry = DeterministicToolOutput( + input={"id": "01", "status": "pending"}, + output={"message": "Order is pending"}, ) - defaults.update(overrides) - return MultiTurnAttribute(**defaults) # type: ignore[arg-type] + assert entry.matches({"id": "01", "status": "pending"}) is True + assert entry.matches({"status": "pending", "id": "01"}) is True -def test_synthesis_params_valid_tool_references(): - tool = _make_generated_tool(id="search") - mt = 
_make_multiturn_attr(available_tools=["search"]) - params = GeneralSynthesisParams( - tools=[tool], - multiturn_attributes=[mt], +def test_deterministic_tool_output_no_match(): + entry = DeterministicToolOutput( + input={"id": "01"}, + output={"message": "ok"}, ) - assert params.tools is not None - assert len(params.tools) == 1 + assert entry.matches({"id": "02"}) is False + assert entry.matches({"id": "01", "extra": "arg"}) is False -def test_synthesis_params_undefined_tool_reference_raises(): - tool = _make_generated_tool(id="search") - mt = _make_multiturn_attr(available_tools=["nonexistent"]) - with pytest.raises(ValueError, match="references unknown tool 'nonexistent'"): - GeneralSynthesisParams( - tools=[tool], - multiturn_attributes=[mt], - ) +# --- BaseTool tests --- -def test_synthesis_params_available_tools_without_tools_defined_raises(): - mt = _make_multiturn_attr(available_tools=["search"]) - with pytest.raises(ValueError, match="tools must be defined"): - GeneralSynthesisParams( - tools=None, - multiturn_attributes=[mt], - ) +@pytest.mark.parametrize( + "field,value", [("id", ""), ("name", ""), ("description", "")] +) +def test_base_tool_empty_field_raises(field, value): + with pytest.raises(ValueError, match=f"{field} cannot be empty"): + BaseTool(**{"id": "t", "name": "T", "description": "d", **{field: value}}) -def test_synthesis_params_duplicate_tool_ids_raises(): - t1 = _make_generated_tool(id="dup") - t2 = _make_generated_tool(id="dup") - mt = _make_multiturn_attr(available_tools=["dup"]) - with pytest.raises(ValueError, match="duplicate tool ids"): - GeneralSynthesisParams( - tools=[t1, t2], - multiturn_attributes=[mt], - ) +# --- DeterministicTool tests --- -# --- ToolOutputStrategy.ENVIRONMENT --- +def test_deterministic_tool_requires_outputs(): + with pytest.raises(ValueError, match="must have at least one"): + DeterministicTool( + id="t", name="T", description="d", deterministic_outputs=[] + ) -def 
test_tool_output_strategy_environment_exists(): - assert ToolOutputStrategy.ENVIRONMENT == "environment" +def test_deterministic_tool_duplicate_inputs_raises(): + outputs = [ + DeterministicToolOutput(input={"id": "01"}, output={"msg": "a"}), + DeterministicToolOutput(input={"id": "01"}, output={"msg": "b"}), + ] + with pytest.raises(ValueError, match="duplicate"): + _make_deterministic_tool(deterministic_outputs=outputs) -def _make_environment_tool(**overrides: Any) -> ToolAttribute: - defaults: dict[str, Any] = dict( - id="tool_env", - name="EnvTool", - description="An environment tool", - output_strategy=ToolOutputStrategy.ENVIRONMENT, - environment="my_env", - read_only=True, +def test_deterministic_tool_resolve_match(): + tool = _make_deterministic_tool( + deterministic_outputs=[ + DeterministicToolOutput( + input={"id": "01"}, output={"msg": "pending"} + ), + DeterministicToolOutput( + input={"id": "02"}, output={"msg": "delivered"} + ), + ] ) - defaults.update(overrides) - return ToolAttribute(**defaults) + assert tool.resolve_deterministic({"id": "01"}) == {"msg": "pending"} + assert tool.resolve_deterministic({"id": "02"}) == {"msg": "delivered"} -def test_tool_attribute_environment_valid(): - tool = _make_environment_tool() - assert tool.environment == "my_env" - assert tool.read_only is True - assert tool.output_strategy == ToolOutputStrategy.ENVIRONMENT +def test_deterministic_tool_resolve_no_match(): + tool = _make_deterministic_tool() + assert tool.resolve_deterministic({"id": "99"}) is None -def test_tool_attribute_environment_read_only_false(): - tool = _make_environment_tool(read_only=False) - assert tool.read_only is False +# --- StatelessTool tests --- -def test_tool_attribute_environment_strategy_without_env_raises(): - """ENVIRONMENT strategy requires environment field.""" - with pytest.raises(ValueError, match="environment must be set"): - ToolAttribute( - id="t", - name="T", - description="d", - output_strategy=ToolOutputStrategy.ENVIRONMENT, 
+def test_stateless_tool_requires_generated_output(): + with pytest.raises(ValueError, match="must have a generated_output"): + StatelessTool( + id="t", name="T", description="d", generated_output=None ) -def test_tool_attribute_env_set_without_environment_strategy_raises(): - """Setting environment requires ENVIRONMENT strategy.""" - with pytest.raises(ValueError, match="output_strategy must be ENVIRONMENT"): - ToolAttribute( - id="t", - name="T", - description="d", - output_strategy=ToolOutputStrategy.GENERATED, - environment="some_env", - generated_output=GeneratedToolOutput(instruction="x"), - ) +# --- ToolEnvironmentType tests --- -def test_tool_attribute_environment_ignores_generated_output(): - """ENVIRONMENT tools don't need generated_output or deterministic_outputs.""" - tool = _make_environment_tool() - assert tool.generated_output is None - assert tool.deterministic_outputs == [] +def test_tool_environment_type_values_exist(): + assert ToolEnvironmentType.STATEFUL == "stateful" + assert ToolEnvironmentType.STATELESS == "stateless" + assert ToolEnvironmentType.DETERMINISTIC == "deterministic" -# --- ToolEnvironmentAttribute --- +# --- Environment + typed tool integration tests --- -def test_environment_attribute_valid(): - env = ToolEnvironmentAttribute( +def test_stateful_environment_valid(): + env = StatefulEnvironment( id="filesystem", name="Filesystem", description="A simple filesystem", system_prompt="You manage a filesystem.", + tools=[StatefulTool(id="read", name="Read", description="Read files.")], ) assert env.id == "filesystem" assert env.state_schema is None assert env.initial_state is None + assert isinstance(env.tools[0], StatefulTool) + + +def test_stateful_environment_coerces_dict_tools(): + env = StatefulEnvironment( + id="fs", + name="FS", + description="d", + system_prompt="p", + tools=[{"id": "read", "name": "Read", "description": "Read files."}], + ) + assert isinstance(env.tools[0], StatefulTool) + + +def 
test_deterministic_environment_valid(): + env = DeterministicEnvironment( + id="lookup", + name="Lookup", + description="A deterministic lookup environment", + tools=[ + DeterministicTool( + id="policy", + name="Policy", + description="Look up policy.", + deterministic_outputs=[ + DeterministicToolOutput( + input={"id": "1"}, + output={"result": "ok"}, + ) + ], + ) + ], + ) + assert not hasattr(env, "system_prompt") + assert isinstance(env.tools[0], DeterministicTool) -def test_environment_attribute_with_schema_and_state(): +def test_environment_with_schema_and_state(): schema = { "type": "object", "properties": {"files": {"type": "object"}}, "required": ["files"], } state = {"files": {}} - env = ToolEnvironmentAttribute( + env = StatefulEnvironment( id="fs", name="FS", description="d", @@ -280,58 +233,117 @@ def test_environment_attribute_with_schema_and_state(): assert env.initial_state == state -def test_environment_attribute_empty_id_raises(): +def test_environment_empty_id_raises(): with pytest.raises(ValueError, match="id cannot be empty"): - ToolEnvironmentAttribute(id="", name="n", description="d", system_prompt="p") + StatefulEnvironment(id="", name="n", description="d", system_prompt="p") -def test_environment_attribute_empty_name_raises(): +def test_environment_empty_name_raises(): with pytest.raises(ValueError, match="name cannot be empty"): - ToolEnvironmentAttribute(id="x", name="", description="d", system_prompt="p") + StatefulEnvironment(id="x", name="", description="d", system_prompt="p") -def test_environment_attribute_empty_description_raises(): +def test_environment_empty_description_raises(): with pytest.raises(ValueError, match="description cannot be empty"): - ToolEnvironmentAttribute(id="x", name="n", description="", system_prompt="p") + StatefulEnvironment( + id="x", name="n", description="", system_prompt="p" + ) -def test_environment_attribute_empty_system_prompt_raises(): +def test_stateful_environment_empty_system_prompt_raises(): with 
pytest.raises(ValueError, match="system_prompt cannot be empty"): - ToolEnvironmentAttribute(id="x", name="n", description="d", system_prompt="") + StatefulEnvironment( + id="x", name="n", description="d", system_prompt="" + ) -# --- GeneralSynthesisParams with environments --- +def test_stateless_environment_empty_system_prompt_raises(): + with pytest.raises(ValueError, match="system_prompt cannot be empty"): + StatelessEnvironment( + id="x", name="n", description="d", system_prompt="" + ) -def test_general_synthesis_params_with_environments(): - env = ToolEnvironmentAttribute( - id="fs", name="FS", description="d", system_prompt="p" - ) - tool = _make_environment_tool(environment="fs") - params = GeneralSynthesisParams( - environments=[env], - tools=[tool], - multiturn_attributes=[_make_multiturn_attr(available_tools=["tool_env"])], - ) - assert params.environments is not None - assert len(params.environments) == 1 +def test_stateless_environment_with_state_schema_raises(): + with pytest.raises( + TypeError, match="unexpected keyword argument 'state_schema'" + ): + StatelessEnvironment( + id="x", + name="n", + description="d", + system_prompt="p", + state_schema={"type": "object"}, + ) + + +def test_deterministic_environment_with_initial_state_raises(): + with pytest.raises( + TypeError, match="unexpected keyword argument 'initial_state'" + ): + DeterministicEnvironment( + id="x", name="n", description="d", initial_state={} + ) -def test_general_synthesis_params_tool_references_unknown_env_raises(): - tool = _make_environment_tool(environment="nonexistent") - with pytest.raises(ValueError, match="references unknown environment"): - GeneralSynthesisParams( - tools=[tool], - multiturn_attributes=[_make_multiturn_attr(available_tools=["tool_env"])], +def test_environment_duplicate_tool_ids_raises(): + with pytest.raises(ValueError, match="duplicate tool id 'dup'"): + StatefulEnvironment( + id="env2", + name="Env 2", + description="d", + system_prompt="p", + tools=[ 
+ StatefulTool(id="dup", name="Read", description="Read files."), + StatefulTool( + id="dup", name="Write", description="Write files." + ), + ], ) -def test_general_synthesis_params_duplicate_env_ids_raises(): - env1 = ToolEnvironmentAttribute( - id="fs", name="FS1", description="d1", system_prompt="p1" +def test_environment_config_duplicate_tool_ids_across_envs_raises(): + env1 = StatefulEnvironment( + id="env1", + name="Env 1", + description="d", + system_prompt="p", + tools=[ + StatefulTool(id="dup", name="Read", description="Read files.") + ], ) - env2 = ToolEnvironmentAttribute( - id="fs", name="FS2", description="d2", system_prompt="p2" + env2 = StatefulEnvironment( + id="env2", + name="Env 2", + description="d", + system_prompt="p", + tools=[ + StatefulTool(id="dup", name="Write", description="Write files.") + ], + ) + with pytest.raises(ValueError, match="duplicate tool id 'dup'"): + EnvironmentConfig(environments=[env1, env2]) + + +def test_environment_config_tool_environment_map(): + env = StatelessEnvironment( + id="faq", + name="FAQ", + description="FAQ tools", + system_prompt="Answer FAQs.", + tools=[_make_stateless_tool(id="answer_faq")], ) - with pytest.raises(ValueError, match="duplicate environment"): - GeneralSynthesisParams(environments=[env1, env2]) + config = EnvironmentConfig(environments=[env]) + assert config.tool_environment_map == {"answer_faq": "faq"} + + +def test_deterministic_environment_requires_outputs_on_tool(): + with pytest.raises(ValueError, match="must have at least one"): + DeterministicEnvironment( + id="det_env", + name="Deterministic", + description="d", + tools=[ + _make_deterministic_tool(deterministic_outputs=[]) + ], + ) diff --git a/tests/unit/core/configs/test_synthesis_config.py b/tests/unit/core/configs/test_synthesis_config.py index 43fd0b0d79..a5a47c5aa2 100644 --- a/tests/unit/core/configs/test_synthesis_config.py +++ b/tests/unit/core/configs/test_synthesis_config.py @@ -14,9 +14,21 @@ import pytest +from 
oumi.core.configs.environment_config import EnvironmentConfig from oumi.core.configs.inference_config import InferenceConfig -from oumi.core.configs.params.synthesis_params import GeneralSynthesisParams +from oumi.core.configs.params.synthesis_params import ( + GeneralSynthesisParams, + MultiTurnAttribute, +) from oumi.core.configs.synthesis_config import SynthesisConfig, SynthesisStrategy +from oumi.core.types.conversation import Role +from oumi.environments import ( + GeneratedToolOutput, + StatefulEnvironment, + StatefulTool, + StatelessEnvironment, + StatelessTool, +) def test_default_synthesis_config(): @@ -70,3 +82,253 @@ def test_invalid_output_path(): with pytest.raises(ValueError, match="Output path is not supported"): SynthesisConfig(inference_config=inference_config) + + +def _make_faq_tool() -> StatelessTool: + return StatelessTool( + id="answer_faq", + name="AnswerFAQ", + description="Answer a FAQ question.", + generated_output=GeneratedToolOutput( + instruction="Answer the given FAQ question." 
+ ), + ) + + +def test_synthesis_config_with_top_level_environment_config(): + env_config = EnvironmentConfig( + environments=[ + StatelessEnvironment( + id="faq", + name="FAQ", + description="FAQ tools", + system_prompt="Answer FAQs.", + tools=[_make_faq_tool()], + ) + ] + ) + params = GeneralSynthesisParams() + params.multiturn_attributes = [] + + config = SynthesisConfig( + strategy_params=params, + environment_config=env_config, + ) + + assert config.environment_config == env_config + assert config.environment_config.tool_environment_map == { + "answer_faq": "faq" + } + + +def test_synthesis_config_loads_environment_config_from_path(tmp_path): + env_config_path = tmp_path / "environments.yaml" + env_config = EnvironmentConfig( + environments=[ + StatelessEnvironment( + id="faq", + name="FAQ", + description="FAQ tools", + system_prompt="Answer FAQs.", + tools=[_make_faq_tool()], + ) + ] + ) + env_config.to_yaml(env_config_path) + + config = SynthesisConfig(environment_config_path=str(env_config_path)) + + assert config.environment_config is not None + assert config.environment_config.all_tools[0].id == "answer_faq" + + +def test_synthesis_config_validates_available_tools(): + env_config = EnvironmentConfig( + environments=[ + StatelessEnvironment( + id="faq", + name="FAQ", + description="FAQ tools", + system_prompt="Answer FAQs.", + tools=[_make_faq_tool()], + ) + ] + ) + params = GeneralSynthesisParams( + multiturn_attributes=[ + MultiTurnAttribute( + id="chat", + min_turns=1, + max_turns=2, + role_instruction_messages={ + Role.USER: "You are a user.", + Role.ASSISTANT: "You are an assistant.", + }, + available_tools=["answer_faq"], + ) + ] + ) + + config = SynthesisConfig( + strategy_params=params, + environment_config=env_config, + ) + + assert config.environment_config is not None + assert config.environment_config.all_tools[0].id == "answer_faq" + mt_attr = params.multiturn_attributes[0] + assert [t.id for t in config.resolve_multiturn_tools(mt_attr)] == [ 
+ "answer_faq" + ] + + +def test_synthesis_config_requires_environment_config_for_available_tools(): + params = GeneralSynthesisParams( + multiturn_attributes=[ + MultiTurnAttribute( + id="chat", + min_turns=1, + max_turns=2, + role_instruction_messages={ + Role.USER: "You are a user.", + Role.ASSISTANT: "You are an assistant.", + }, + available_tools=["answer_faq"], + ) + ] + ) + + with pytest.raises( + ValueError, match="Environment or tool references require" + ): + SynthesisConfig(strategy_params=params) + + +def test_synthesis_config_validates_available_environments(): + env_config = EnvironmentConfig( + environments=[ + StatelessEnvironment( + id="faq", + name="FAQ", + description="FAQ tools", + system_prompt="Answer FAQs.", + tools=[_make_faq_tool()], + ) + ] + ) + params = GeneralSynthesisParams( + multiturn_attributes=[ + MultiTurnAttribute( + id="chat", + min_turns=1, + max_turns=2, + role_instruction_messages={ + Role.USER: "You are a user.", + Role.ASSISTANT: "You are an assistant.", + }, + available_environments=["missing_env"], + ) + ] + ) + + with pytest.raises(ValueError, match="references unknown environment"): + SynthesisConfig( + strategy_params=params, environment_config=env_config + ) + + +def test_synthesis_config_restricts_tools_to_selected_environments(): + env_config = EnvironmentConfig( + environments=[ + StatelessEnvironment( + id="faq", + name="FAQ", + description="FAQ tools", + system_prompt="Answer FAQs.", + tools=[_make_faq_tool()], + ), + StatefulEnvironment( + id="files", + name="Files", + description="File tools", + system_prompt="Manage files.", + tools=[ + StatefulTool( + id="read_file", + name="ReadFile", + description="Read a file.", + ) + ], + ), + ] + ) + params = GeneralSynthesisParams( + multiturn_attributes=[ + MultiTurnAttribute( + id="chat", + min_turns=1, + max_turns=2, + role_instruction_messages={ + Role.USER: "You are a user.", + Role.ASSISTANT: "You are an assistant.", + }, + available_environments=["faq"], + 
available_tools=["read_file"], + ) + ] + ) + + with pytest.raises(ValueError, match="references unknown tool"): + SynthesisConfig( + strategy_params=params, environment_config=env_config + ) + + +def test_synthesis_config_resolves_all_tools_from_selected_environments(): + env_config = EnvironmentConfig( + environments=[ + StatelessEnvironment( + id="faq", + name="FAQ", + description="FAQ tools", + system_prompt="Answer FAQs.", + tools=[_make_faq_tool()], + ), + StatefulEnvironment( + id="files", + name="Files", + description="File tools", + system_prompt="Manage files.", + tools=[ + StatefulTool( + id="read_file", + name="ReadFile", + description="Read a file.", + ) + ], + ), + ] + ) + mt_attr = MultiTurnAttribute( + id="chat", + min_turns=1, + max_turns=2, + role_instruction_messages={ + Role.USER: "You are a user.", + Role.ASSISTANT: "You are an assistant.", + }, + available_environments=["faq", "files"], + ) + config = SynthesisConfig( + strategy_params=GeneralSynthesisParams( + multiturn_attributes=[mt_attr] + ), + environment_config=env_config, + ) + + assert [ + env.id for env in config.resolve_multiturn_environments(mt_attr) + ] == ["faq", "files"] + assert [ + tool.id for tool in config.resolve_multiturn_tools(mt_attr) + ] == ["answer_faq", "read_file"] From 01ab23dfef53419bc12f1768c7a1f0b00541f75f Mon Sep 17 00:00:00 2001 From: aniruddh-alt Date: Wed, 8 Apr 2026 21:54:20 -0700 Subject: [PATCH 07/18] style: apply ruff format to environment and test files --- src/oumi/environments/base_tool.py | 3 +- .../environments/deterministic_environment.py | 11 +--- src/oumi/environments/stateful_environment.py | 3 +- .../environments/stateless_environment.py | 11 +--- .../core/configs/params/test_tool_params.py | 60 +++++-------------- .../core/configs/test_synthesis_config.py | 38 +++++------- 6 files changed, 37 insertions(+), 89 deletions(-) diff --git a/src/oumi/environments/base_tool.py b/src/oumi/environments/base_tool.py index 71e575e369..be5e0226b2 100644 --- 
a/src/oumi/environments/base_tool.py +++ b/src/oumi/environments/base_tool.py @@ -44,8 +44,7 @@ def create(cls, raw: Mapping[str, Any] | BaseTool) -> BaseTool: return raw if not isinstance(raw, Mapping): raise TypeError( - f"Tool definitions must be tool objects or mappings, " - f"got {type(raw)}" + f"Tool definitions must be tool objects or mappings, got {type(raw)}" ) return cls( id=raw["id"], diff --git a/src/oumi/environments/deterministic_environment.py b/src/oumi/environments/deterministic_environment.py index 0a04c2fc40..f4424634b6 100644 --- a/src/oumi/environments/deterministic_environment.py +++ b/src/oumi/environments/deterministic_environment.py @@ -75,9 +75,7 @@ def _check_deterministic_duplicates(self) -> None: ) seen.add(key) - def resolve_deterministic( - self, arguments: dict[str, Any] - ) -> dict[str, Any] | None: + def resolve_deterministic(self, arguments: dict[str, Any]) -> dict[str, Any] | None: """Resolve a deterministic output for the given arguments.""" for entry in self.deterministic_outputs: if entry.matches(arguments): @@ -96,8 +94,7 @@ def create(cls, raw: Mapping[str, Any] | BaseTool) -> DeterministicTool: ) if not isinstance(raw, Mapping): raise TypeError( - f"Tool definitions must be tool objects or mappings, " - f"got {type(raw)}" + f"Tool definitions must be tool objects or mappings, got {type(raw)}" ) deterministic_outputs = [ entry @@ -135,9 +132,7 @@ def _coerce_tools(self, tools: list[Any]) -> list[DeterministicTool]: def _validate_type_specific(self) -> None: return - def resolve( - self, tool_id: str, arguments: dict[str, Any] - ) -> dict[str, Any] | None: + def resolve(self, tool_id: str, arguments: dict[str, Any]) -> dict[str, Any] | None: """Resolve a deterministic tool call to its output. 
Raises: diff --git a/src/oumi/environments/stateful_environment.py b/src/oumi/environments/stateful_environment.py index 93688f1c7e..e619ff7bf9 100644 --- a/src/oumi/environments/stateful_environment.py +++ b/src/oumi/environments/stateful_environment.py @@ -46,8 +46,7 @@ def create(cls, raw: Mapping[str, Any] | BaseTool) -> StatefulTool: ) if not isinstance(raw, Mapping): raise TypeError( - f"Tool definitions must be tool objects or mappings, " - f"got {type(raw)}" + f"Tool definitions must be tool objects or mappings, got {type(raw)}" ) return cls( id=raw["id"], diff --git a/src/oumi/environments/stateless_environment.py b/src/oumi/environments/stateless_environment.py index 6cb011e2f1..65ece814fd 100644 --- a/src/oumi/environments/stateless_environment.py +++ b/src/oumi/environments/stateless_environment.py @@ -49,9 +49,7 @@ def __post_init__(self): """Validate stateless tool fields.""" super().__post_init__() if self.generated_output is None: - raise ValueError( - f"StatelessTool '{self.id}' must have a generated_output." 
- ) + raise ValueError(f"StatelessTool '{self.id}' must have a generated_output.") @classmethod def create(cls, raw: Mapping[str, Any] | BaseTool) -> StatelessTool: @@ -65,8 +63,7 @@ def create(cls, raw: Mapping[str, Any] | BaseTool) -> StatelessTool: ) if not isinstance(raw, Mapping): raise TypeError( - f"Tool definitions must be tool objects or mappings, " - f"got {type(raw)}" + f"Tool definitions must be tool objects or mappings, got {type(raw)}" ) generated_output = raw.get("generated_output") if isinstance(generated_output, Mapping): @@ -125,9 +122,7 @@ def _cache_key(tool_id: str, arguments: dict[str, Any]) -> str: """Build a stable cache key from tool id and arguments.""" return f"{tool_id}::{json.dumps(arguments, sort_keys=True)}" - def resolve_cached( - self, tool_id: str, arguments: dict[str, Any] - ) -> str | None: + def resolve_cached(self, tool_id: str, arguments: dict[str, Any]) -> str | None: """Look up a cached result for the given tool call.""" if not self.cache_by_input: return None diff --git a/tests/unit/core/configs/params/test_tool_params.py b/tests/unit/core/configs/params/test_tool_params.py index cbca2bffc8..c734d15b87 100644 --- a/tests/unit/core/configs/params/test_tool_params.py +++ b/tests/unit/core/configs/params/test_tool_params.py @@ -97,9 +97,7 @@ def test_deterministic_tool_output_no_match(): # --- BaseTool tests --- -@pytest.mark.parametrize( - "field,value", [("id", ""), ("name", ""), ("description", "")] -) +@pytest.mark.parametrize("field,value", [("id", ""), ("name", ""), ("description", "")]) def test_base_tool_empty_field_raises(field, value): with pytest.raises(ValueError, match=f"{field} cannot be empty"): BaseTool(**{"id": "t", "name": "T", "description": "d", **{field: value}}) @@ -110,9 +108,7 @@ def test_base_tool_empty_field_raises(field, value): def test_deterministic_tool_requires_outputs(): with pytest.raises(ValueError, match="must have at least one"): - DeterministicTool( - id="t", name="T", description="d", 
deterministic_outputs=[] - ) + DeterministicTool(id="t", name="T", description="d", deterministic_outputs=[]) def test_deterministic_tool_duplicate_inputs_raises(): @@ -127,12 +123,8 @@ def test_deterministic_tool_duplicate_inputs_raises(): def test_deterministic_tool_resolve_match(): tool = _make_deterministic_tool( deterministic_outputs=[ - DeterministicToolOutput( - input={"id": "01"}, output={"msg": "pending"} - ), - DeterministicToolOutput( - input={"id": "02"}, output={"msg": "delivered"} - ), + DeterministicToolOutput(input={"id": "01"}, output={"msg": "pending"}), + DeterministicToolOutput(input={"id": "02"}, output={"msg": "delivered"}), ] ) assert tool.resolve_deterministic({"id": "01"}) == {"msg": "pending"} @@ -149,9 +141,7 @@ def test_deterministic_tool_resolve_no_match(): def test_stateless_tool_requires_generated_output(): with pytest.raises(ValueError, match="must have a generated_output"): - StatelessTool( - id="t", name="T", description="d", generated_output=None - ) + StatelessTool(id="t", name="T", description="d", generated_output=None) # --- ToolEnvironmentType tests --- @@ -245,29 +235,21 @@ def test_environment_empty_name_raises(): def test_environment_empty_description_raises(): with pytest.raises(ValueError, match="description cannot be empty"): - StatefulEnvironment( - id="x", name="n", description="", system_prompt="p" - ) + StatefulEnvironment(id="x", name="n", description="", system_prompt="p") def test_stateful_environment_empty_system_prompt_raises(): with pytest.raises(ValueError, match="system_prompt cannot be empty"): - StatefulEnvironment( - id="x", name="n", description="d", system_prompt="" - ) + StatefulEnvironment(id="x", name="n", description="d", system_prompt="") def test_stateless_environment_empty_system_prompt_raises(): with pytest.raises(ValueError, match="system_prompt cannot be empty"): - StatelessEnvironment( - id="x", name="n", description="d", system_prompt="" - ) + StatelessEnvironment(id="x", name="n", 
description="d", system_prompt="") def test_stateless_environment_with_state_schema_raises(): - with pytest.raises( - TypeError, match="unexpected keyword argument 'state_schema'" - ): + with pytest.raises(TypeError, match="unexpected keyword argument 'state_schema'"): StatelessEnvironment( id="x", name="n", @@ -278,12 +260,8 @@ def test_stateless_environment_with_state_schema_raises(): def test_deterministic_environment_with_initial_state_raises(): - with pytest.raises( - TypeError, match="unexpected keyword argument 'initial_state'" - ): - DeterministicEnvironment( - id="x", name="n", description="d", initial_state={} - ) + with pytest.raises(TypeError, match="unexpected keyword argument 'initial_state'"): + DeterministicEnvironment(id="x", name="n", description="d", initial_state={}) def test_environment_duplicate_tool_ids_raises(): @@ -295,9 +273,7 @@ def test_environment_duplicate_tool_ids_raises(): system_prompt="p", tools=[ StatefulTool(id="dup", name="Read", description="Read files."), - StatefulTool( - id="dup", name="Write", description="Write files." 
- ), + StatefulTool(id="dup", name="Write", description="Write files."), ], ) @@ -308,18 +284,14 @@ def test_environment_config_duplicate_tool_ids_across_envs_raises(): name="Env 1", description="d", system_prompt="p", - tools=[ - StatefulTool(id="dup", name="Read", description="Read files.") - ], + tools=[StatefulTool(id="dup", name="Read", description="Read files.")], ) env2 = StatefulEnvironment( id="env2", name="Env 2", description="d", system_prompt="p", - tools=[ - StatefulTool(id="dup", name="Write", description="Write files.") - ], + tools=[StatefulTool(id="dup", name="Write", description="Write files.")], ) with pytest.raises(ValueError, match="duplicate tool id 'dup'"): EnvironmentConfig(environments=[env1, env2]) @@ -343,7 +315,5 @@ def test_deterministic_environment_requires_outputs_on_tool(): id="det_env", name="Deterministic", description="d", - tools=[ - _make_deterministic_tool(deterministic_outputs=[]) - ], + tools=[_make_deterministic_tool(deterministic_outputs=[])], ) diff --git a/tests/unit/core/configs/test_synthesis_config.py b/tests/unit/core/configs/test_synthesis_config.py index a5a47c5aa2..a780472fb6 100644 --- a/tests/unit/core/configs/test_synthesis_config.py +++ b/tests/unit/core/configs/test_synthesis_config.py @@ -116,9 +116,7 @@ def test_synthesis_config_with_top_level_environment_config(): ) assert config.environment_config == env_config - assert config.environment_config.tool_environment_map == { - "answer_faq": "faq" - } + assert config.environment_config.tool_environment_map == {"answer_faq": "faq"} def test_synthesis_config_loads_environment_config_from_path(tmp_path): @@ -177,9 +175,7 @@ def test_synthesis_config_validates_available_tools(): assert config.environment_config is not None assert config.environment_config.all_tools[0].id == "answer_faq" mt_attr = params.multiturn_attributes[0] - assert [t.id for t in config.resolve_multiturn_tools(mt_attr)] == [ - "answer_faq" - ] + assert [t.id for t in 
config.resolve_multiturn_tools(mt_attr)] == ["answer_faq"] def test_synthesis_config_requires_environment_config_for_available_tools(): @@ -198,9 +194,7 @@ def test_synthesis_config_requires_environment_config_for_available_tools(): ] ) - with pytest.raises( - ValueError, match="Environment or tool references require" - ): + with pytest.raises(ValueError, match="Environment or tool references require"): SynthesisConfig(strategy_params=params) @@ -232,9 +226,7 @@ def test_synthesis_config_validates_available_environments(): ) with pytest.raises(ValueError, match="references unknown environment"): - SynthesisConfig( - strategy_params=params, environment_config=env_config - ) + SynthesisConfig(strategy_params=params, environment_config=env_config) def test_synthesis_config_restricts_tools_to_selected_environments(): @@ -279,9 +271,7 @@ def test_synthesis_config_restricts_tools_to_selected_environments(): ) with pytest.raises(ValueError, match="references unknown tool"): - SynthesisConfig( - strategy_params=params, environment_config=env_config - ) + SynthesisConfig(strategy_params=params, environment_config=env_config) def test_synthesis_config_resolves_all_tools_from_selected_environments(): @@ -320,15 +310,15 @@ def test_synthesis_config_resolves_all_tools_from_selected_environments(): available_environments=["faq", "files"], ) config = SynthesisConfig( - strategy_params=GeneralSynthesisParams( - multiturn_attributes=[mt_attr] - ), + strategy_params=GeneralSynthesisParams(multiturn_attributes=[mt_attr]), environment_config=env_config, ) - assert [ - env.id for env in config.resolve_multiturn_environments(mt_attr) - ] == ["faq", "files"] - assert [ - tool.id for tool in config.resolve_multiturn_tools(mt_attr) - ] == ["answer_faq", "read_file"] + assert [env.id for env in config.resolve_multiturn_environments(mt_attr)] == [ + "faq", + "files", + ] + assert [tool.id for tool in config.resolve_multiturn_tools(mt_attr)] == [ + "answer_faq", + "read_file", + ] From 
56efd03411b661435e4a2f34773816034b396d3f Mon Sep 17 00:00:00 2001 From: aniruddh-alt Date: Wed, 8 Apr 2026 21:58:25 -0700 Subject: [PATCH 08/18] fix: resolve pyright type errors in environment tests Add type annotations to test helper functions, type-ignore comments for intentionally invalid kwargs in negative tests, and None guards for optional fields. --- .../core/configs/params/test_tool_params.py | 20 ++++++++++--------- .../core/configs/test_synthesis_config.py | 2 ++ 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/tests/unit/core/configs/params/test_tool_params.py b/tests/unit/core/configs/params/test_tool_params.py index c734d15b87..8ea9b8e54e 100644 --- a/tests/unit/core/configs/params/test_tool_params.py +++ b/tests/unit/core/configs/params/test_tool_params.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any + import pytest from oumi.core.configs.environment_config import EnvironmentConfig @@ -29,8 +31,8 @@ ) -def _make_deterministic_tool(**overrides) -> DeterministicTool: - defaults = dict( +def _make_deterministic_tool(**overrides: Any) -> DeterministicTool: + defaults: dict[str, Any] = dict( id="tool1", name="MyTool", description="A tool", @@ -42,8 +44,8 @@ def _make_deterministic_tool(**overrides) -> DeterministicTool: return DeterministicTool(**defaults) -def _make_stateless_tool(**overrides) -> StatelessTool: - defaults = dict( +def _make_stateless_tool(**overrides: Any) -> StatelessTool: + defaults: dict[str, Any] = dict( id="tool2", name="GenTool", description="A generated tool", @@ -53,8 +55,8 @@ def _make_stateless_tool(**overrides) -> StatelessTool: return StatelessTool(**defaults) -def _make_stateful_tool(**overrides) -> StatefulTool: - defaults = dict( +def _make_stateful_tool(**overrides: Any) -> StatefulTool: + defaults: dict[str, Any] = dict( id="tool3", name="StatefulTool", description="A stateful tool", @@ -176,7 +178,7 @@ def 
test_stateful_environment_coerces_dict_tools(): name="FS", description="d", system_prompt="p", - tools=[{"id": "read", "name": "Read", "description": "Read files."}], + tools=[{"id": "read", "name": "Read", "description": "Read files."}], # type: ignore[arg-type] ) assert isinstance(env.tools[0], StatefulTool) @@ -255,13 +257,13 @@ def test_stateless_environment_with_state_schema_raises(): name="n", description="d", system_prompt="p", - state_schema={"type": "object"}, + state_schema={"type": "object"}, # type: ignore[call-arg] ) def test_deterministic_environment_with_initial_state_raises(): with pytest.raises(TypeError, match="unexpected keyword argument 'initial_state'"): - DeterministicEnvironment(id="x", name="n", description="d", initial_state={}) + DeterministicEnvironment(id="x", name="n", description="d", initial_state={}) # type: ignore[call-arg] def test_environment_duplicate_tool_ids_raises(): diff --git a/tests/unit/core/configs/test_synthesis_config.py b/tests/unit/core/configs/test_synthesis_config.py index a780472fb6..11dbc67381 100644 --- a/tests/unit/core/configs/test_synthesis_config.py +++ b/tests/unit/core/configs/test_synthesis_config.py @@ -115,6 +115,7 @@ def test_synthesis_config_with_top_level_environment_config(): environment_config=env_config, ) + assert config.environment_config is not None assert config.environment_config == env_config assert config.environment_config.tool_environment_map == {"answer_faq": "faq"} @@ -174,6 +175,7 @@ def test_synthesis_config_validates_available_tools(): assert config.environment_config is not None assert config.environment_config.all_tools[0].id == "answer_faq" + assert params.multiturn_attributes is not None mt_attr = params.multiturn_attributes[0] assert [t.id for t in config.resolve_multiturn_tools(mt_attr)] == ["answer_faq"] From ace8c8eafe4609b0685a4ffd09c31a3c952e8632 Mon Sep 17 00:00:00 2001 From: aniruddh-alt Date: Thu, 9 Apr 2026 08:57:37 -0700 Subject: [PATCH 09/18] fix: update test assertion 
to include environment_config parameter The test_synthesize_with_multiturn_attributes test was failing because the ConversationSynthesizer now accepts environment_config as a kwarg. --- tests/unit/core/synthesis/test_synthesis_pipeline.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unit/core/synthesis/test_synthesis_pipeline.py b/tests/unit/core/synthesis/test_synthesis_pipeline.py index f65862381a..e85c08364b 100644 --- a/tests/unit/core/synthesis/test_synthesis_pipeline.py +++ b/tests/unit/core/synthesis/test_synthesis_pipeline.py @@ -314,6 +314,7 @@ def test_synthesize_with_multiturn_attributes( mock_conv_synth_class.assert_called_once_with( synthesis_config_with_multiturn_attributes.strategy_params, synthesis_config_with_multiturn_attributes.inference_config, + environment_config=None, ) mock_conv_synth.synthesize.assert_called_once_with(sample_dataset, multiturn_attr) assert all(multiturn_attr.id in item for item in result) From 21ceb87ecc7b4f43b8559474bc54730aa666b99c Mon Sep 17 00:00:00 2001 From: aniruddh-alt Date: Thu, 9 Apr 2026 10:34:02 -0700 Subject: [PATCH 10/18] revert: remove unrelated MCP file changes from branch --- src/oumi/mcp/job_launcher.py | 4 ++-- src/oumi/mcp/job_logs.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/oumi/mcp/job_launcher.py b/src/oumi/mcp/job_launcher.py index 852a1f1102..10f01f156f 100644 --- a/src/oumi/mcp/job_launcher.py +++ b/src/oumi/mcp/job_launcher.py @@ -186,13 +186,13 @@ async def _launch_cloud( await evict_runtime(record.job_id) return "" - config_parent = str(Path(record.config_path).expanduser().resolve().parent) # noqa: ASYNC240 + config_parent = str(Path(record.config_path).expanduser().resolve().parent) _stage_cloud_config(record, rt, working_dir=config_parent) job_config = launcher.JobConfig.from_yaml(rt.staged_config_path) if not job_config.name: job_config.name = record.job_id if client_cwd and job_config.working_dir: - wd = 
Path(job_config.working_dir).expanduser() # noqa: ASYNC240 + wd = Path(job_config.working_dir).expanduser() if not wd.is_absolute(): job_config.working_dir = str((Path(client_cwd) / wd).resolve()) elif client_cwd and not job_config.working_dir: diff --git a/src/oumi/mcp/job_logs.py b/src/oumi/mcp/job_logs.py index e336ecea36..03df5e9d63 100644 --- a/src/oumi/mcp/job_logs.py +++ b/src/oumi/mcp/job_logs.py @@ -74,7 +74,7 @@ async def tail_log_file( If the file does not exist yet, waits up to ``poll_interval`` between checks until it appears or *done_event* fires. """ - while not path.exists(): # noqa: ASYNC240 + while not path.exists(): if done_event.is_set(): return await asyncio.sleep(poll_interval) @@ -84,13 +84,13 @@ async def tail_log_file( while True: try: - size = path.stat().st_size # noqa: ASYNC240 + size = path.stat().st_size except OSError: size = 0 if size > position: try: - with open(path, encoding="utf-8", errors="replace") as f: # noqa: ASYNC230 + with open(path, encoding="utf-8", errors="replace") as f: f.seek(position) chunk = f.read() position = f.tell() @@ -105,7 +105,7 @@ async def tail_log_file( if done_event.is_set(): try: - with open(path, encoding="utf-8", errors="replace") as f: # noqa: ASYNC230 + with open(path, encoding="utf-8", errors="replace") as f: f.seek(position) remaining = f.read() except OSError: From ec78eb4935aee5ea6d96be3cf05a685187aba408 Mon Sep 17 00:00:00 2001 From: aniruddh-alt Date: Thu, 9 Apr 2026 10:35:47 -0700 Subject: [PATCH 11/18] revert: remove unrelated datasets version bump from pyproject.toml --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 0aecf3a201..14dbc232fd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,7 +45,7 @@ dependencies = [ "aioresponses>=0.7,<0.8", # User by inference engine tests "backoff>=2.2.1,<2.3", "click<8.4.0", # Used by CLI. 8.2.0 is currently unsupported by Typer. 
- "datasets>=3.2,<5", + "datasets>=3.2,<4.8.5", "greenlet", # Required by skypilot 0.11+ (sqlalchemy asyncio) "hdrhistogram>=0.10,<0.11", "httpx>=0.27,<1.0", # Used by deploy module (async HTTP client) @@ -305,6 +305,7 @@ unsupported-operator = "warn" # Type narrowing limitations with isinstance too-many-positional-arguments = "warn" # Loose typing (Callable) doesn't capture signatures parameter-already-assigned = "warn" # False positives with *args/**kwargs patterns + [tool.pytest.ini_options] asyncio_default_fixture_loop_scope = "function" testpaths = ["tests"] From 4148d557c892caafd9137347baddd0d34ed8f9c6 Mon Sep 17 00:00:00 2001 From: aniruddh-alt Date: Tue, 14 Apr 2026 13:34:15 -0700 Subject: [PATCH 12/18] feat: refactor environments, consolidate synthetic environments --- docs/user_guides/synth.md | 54 ++- src/oumi/core/configs/__init__.py | 28 +- src/oumi/core/configs/environment_config.py | 9 +- src/oumi/core/configs/synthesis_config.py | 5 +- .../synthesis/conversation_synthesizer.py | 4 +- src/oumi/environments/__init__.py | 46 +- src/oumi/environments/base_environment.py | 69 ++- src/oumi/environments/base_tool.py | 44 +- .../environments/deterministic_environment.py | 157 ++---- src/oumi/environments/stateful_environment.py | 84 ---- .../environments/stateless_environment.py | 136 ------ .../environments/synthetic_environment.py | 169 +++++++ .../environments/{types.py => tool_result.py} | 17 +- .../core/configs/params/test_tool_params.py | 447 ++++++++++++------ .../core/configs/test_synthesis_config.py | 39 +- 15 files changed, 669 insertions(+), 639 deletions(-) delete mode 100644 src/oumi/environments/stateful_environment.py delete mode 100644 src/oumi/environments/stateless_environment.py create mode 100644 src/oumi/environments/synthetic_environment.py rename src/oumi/environments/{types.py => tool_result.py} (63%) diff --git a/docs/user_guides/synth.md b/docs/user_guides/synth.md index 71af6658ab..df276b58b9 100644 --- a/docs/user_guides/synth.md 
+++ b/docs/user_guides/synth.md @@ -166,20 +166,19 @@ Ready to dive deeper? The sections below cover all available options in detail. ## Environment-First Tool Synthesis -Agentic synthesis now follows an environment-first model. Tools do not declare an output strategy directly. Instead, each tool is bound to an environment, and the environment type defines the execution model. +Agentic synthesis now follows an environment-first model. Tools do not declare an output strategy directly. Instead, each tool is bound to an environment, and the environment type defines how tool calls are executed via its `step()` method. -- **`stateful` environments** maintain shared JSON state. Tool calls read from or update that state, which is how consistency is preserved across turns. -- **`stateless` environments** generate tool results with an LLM. Responses are cached by input, so the same tool input can reuse the same generated output. -- **`deterministic` environments** behave like lookup tables. Matching inputs return responses from a predefined set without LLM generation. +- **`synthetic` environments** are backed by an LLM that simulates tool execution. They can be stateless (no persistent state) or stateful (mutable JSON state across turns). Statefulness is controlled by the optional `state_params` field — when provided, the environment tracks and mutates state across calls; when absent, each call is independent. +- **`deterministic` environments** behave like lookup tables. Each tool defines a set of input-to-output mappings, and `step()` resolves tool calls by matching arguments against those mappings. No LLM is involved. At the config level: - Environments own their tool definitions. - Reusable environment catalogs live in top-level `environment_config` or `environment_config_path`. - Tools do not declare an `environment` field. The parent environment owns the binding. -- `generated_output` is only used for tools in `stateless` environments. 
- `deterministic_outputs` is only used for tools in `deterministic` environments. -- `read_only` is only meaningful for tools in `stateful` environments. +- `read_only` is only meaningful for tools in stateful `synthetic` environments. +- Multiturn attributes reference environments (not individual tools) to select which tools are available. Example: @@ -188,26 +187,51 @@ environment_config: environments: - id: support_backend name: Support Backend - description: Simulated support system state - type: stateful - system_prompt: You manage support system state. + description: Simulated support system with tickets and users + type: synthetic + system_prompt: You manage a customer support system with tickets and users. + state_params: + state_schema: + type: object + properties: + tickets: { type: array } + users: { type: array } + initial_state: + tickets: [] + users: [] tools: - id: get_ticket name: GetTicket description: Read a ticket from the support backend. read_only: true + parameters: + type: object + properties: + ticket_id: { type: string } + - id: create_ticket + name: CreateTicket + description: Create a new support ticket. + read_only: false + parameters: + type: object + properties: + subject: { type: string } + priority: { type: string, enum: [low, medium, high] } - id: faq_lookup name: FAQ Lookup description: Cached LLM-backed FAQ answers - type: stateless + type: synthetic system_prompt: Generate concise FAQ answers grounded in the tool contract. + cache_by_input: true tools: - id: answer_faq name: AnswerFAQ description: Answer common support questions. - generated_output: - instruction: Return the FAQ answer for the given question. + parameters: + type: object + properties: + question: { type: string } - id: policy_table name: Policy Table @@ -217,6 +241,10 @@ environment_config: - id: get_refund_policy name: GetRefundPolicy description: Return the matching refund policy. 
+ parameters: + type: object + properties: + policy_type: { type: string } deterministic_outputs: - input: policy_type: standard @@ -231,7 +259,7 @@ strategy_params: role_instruction_messages: USER: You are a customer contacting support. ASSISTANT: You are a helpful support agent. - available_tools: [get_ticket, answer_faq, get_refund_policy] + available_environments: [support_backend, faq_lookup, policy_table] ``` ## Complete Configuration Reference diff --git a/src/oumi/core/configs/__init__.py b/src/oumi/core/configs/__init__.py index bec87c8a5b..068483ce01 100644 --- a/src/oumi/core/configs/__init__.py +++ b/src/oumi/core/configs/__init__.py @@ -124,7 +124,9 @@ from oumi.core.configs.params.profiler_params import ProfilerParams from oumi.core.configs.params.remote_params import RemoteParams from oumi.core.configs.params.synthesis_params import ( - AttributeCombination, + DatasetSource as DatasetSourceParam, +) +from oumi.core.configs.params.synthesis_params import ( DocumentSegmentationParams, DocumentSource, ExampleSource, @@ -141,9 +143,6 @@ TransformationType, TransformedAttribute, ) -from oumi.core.configs.params.synthesis_params import ( - DatasetSource as DatasetSourceParam, -) from oumi.core.configs.params.telemetry_params import TelemetryParams from oumi.core.configs.params.training_params import ( MixedPrecisionDtype, @@ -159,23 +158,12 @@ from oumi.core.configs.synthesis_config import SynthesisConfig from oumi.core.configs.training_config import TrainingConfig from oumi.core.configs.tuning_config import TuningConfig -from oumi.environments import ( - BaseEnvironment, - BaseTool, - DeterministicEnvironment, - DeterministicToolOutput, - GeneratedToolOutput, - StatefulEnvironment, - StatelessEnvironment, - ToolEnvironmentType, -) __all__ = [ "AsyncEvaluationConfig", "AutoWrapPolicy", "BackwardPrefetch", "BaseConfig", - "BaseEnvironment", "DataParams", "DatasetParams", "DatasetSplit", @@ -187,8 +175,8 @@ "EvaluationConfig", "EvaluationBackend", 
"EvaluationConfig", - "EvaluationTaskParams", "EnvironmentConfig", + "EvaluationTaskParams", "FSDPParams", "GenerationParams", "GrpoParams", @@ -222,26 +210,18 @@ "TunerType", "TuningConfig", "TuningParams", - "AttributeCombination", "DatasetSourceParam", - "DeterministicToolOutput", - "DeterministicEnvironment", "DocumentSegmentationParams", "DocumentSource", "ExampleSource", - "GeneratedToolOutput", "GeneratedAttributePostprocessingParams", "GeneralSynthesisParams", "GeneratedAttribute", "SampledAttribute", "SampledAttributeValue", "SegmentationStrategy", - "StatefulEnvironment", - "StatelessEnvironment", "TextConversation", "TextMessage", - "BaseTool", - "ToolEnvironmentType", "TransformationStrategy", "TransformationType", "TransformedAttribute", diff --git a/src/oumi/core/configs/environment_config.py b/src/oumi/core/configs/environment_config.py index e9c6712379..1a8a985397 100644 --- a/src/oumi/core/configs/environment_config.py +++ b/src/oumi/core/configs/environment_config.py @@ -18,7 +18,8 @@ from typing import Any from oumi.core.configs.base_config import BaseConfig -from oumi.environments import BaseEnvironment, BaseTool +from oumi.environments.base_environment import BaseEnvironment +from oumi.environments.base_tool import Tool @dataclass @@ -54,7 +55,7 @@ def __post_init__(self): tool_ids.add(tool.id) @property - def all_tools(self) -> list[BaseTool]: + def all_tools(self) -> list[Tool]: """Flatten all tools across environments.""" return [tool for environment in self.environments for tool in environment.tools] @@ -74,7 +75,7 @@ def get_environment(self, environment_id: str) -> BaseEnvironment | None: return environment return None - def get_tool(self, tool_id: str) -> BaseTool | None: + def get_tool(self, tool_id: str) -> Tool | None: """Look up a tool by id.""" for tool in self.all_tools: if tool.id == tool_id: @@ -85,7 +86,7 @@ def resolve_tools( self, environment_ids: list[str] | None = None, tool_ids: list[str] | None = None, - ) -> 
list[BaseTool]: + ) -> list[Tool]: """Resolve tools from selected environments and optional tool ids. Raises: diff --git a/src/oumi/core/configs/synthesis_config.py b/src/oumi/core/configs/synthesis_config.py index 10a22a67a6..8d63f34658 100644 --- a/src/oumi/core/configs/synthesis_config.py +++ b/src/oumi/core/configs/synthesis_config.py @@ -23,7 +23,8 @@ GeneralSynthesisParams, MultiTurnAttribute, ) -from oumi.environments import BaseEnvironment, BaseTool +from oumi.environments.base_environment import BaseEnvironment +from oumi.environments.base_tool import Tool class SynthesisStrategy(str, Enum): @@ -144,7 +145,7 @@ def resolve_multiturn_environments( def resolve_multiturn_tools( self, multiturn_attribute: MultiTurnAttribute - ) -> list[BaseTool]: + ) -> list[Tool]: """Resolve the tools available to a multiturn attribute.""" if self.environment_config is None: return [] diff --git a/src/oumi/core/synthesis/conversation_synthesizer.py b/src/oumi/core/synthesis/conversation_synthesizer.py index 7717fefc01..f0b230ceef 100644 --- a/src/oumi/core/synthesis/conversation_synthesizer.py +++ b/src/oumi/core/synthesis/conversation_synthesizer.py @@ -24,7 +24,7 @@ ) from oumi.core.synthesis.attribute_formatter import AttributeFormatter from oumi.core.types.conversation import Conversation, Message, Role -from oumi.environments import BaseTool +from oumi.environments import Tool from oumi.utils.logging import logger from oumi.utils.str_utils import extract_json @@ -58,7 +58,7 @@ def __init__( def _resolve_available_tools( self, multiturn_attribute: MultiTurnAttribute - ) -> list[BaseTool]: + ) -> list[Tool]: """Resolve tools for a multiturn attribute from selected environments.""" if self._environment_config is None: return [] diff --git a/src/oumi/environments/__init__.py b/src/oumi/environments/__init__.py index 2322d897c6..dafe7bdab7 100644 --- a/src/oumi/environments/__init__.py +++ b/src/oumi/environments/__init__.py @@ -12,47 +12,23 @@ # See the License for the 
specific language governing permissions and # limitations under the License. -"""Environments for agentic tool interactions. - -Environments are simulated worlds that agents interact with via tools. -Consumers include synthesis (training data generation), evaluation -(agent testing), and RL (reward-driven training). - -Each environment type defines how tool calls are resolved: - -- **StatefulEnvironment**: mutable JSON state across calls. -- **StatelessEnvironment**: LLM-generated outputs with optional caching. -- **DeterministicEnvironment**: fixed input-to-output lookup tables. -""" +"""Environments for agentic tool interactions.""" from oumi.environments.base_environment import BaseEnvironment -from oumi.environments.base_tool import BaseTool -from oumi.environments.deterministic_environment import ( - DeterministicEnvironment, - DeterministicTool, - DeterministicToolOutput, -) -from oumi.environments.stateful_environment import ( - StatefulEnvironment, - StatefulTool, -) -from oumi.environments.stateless_environment import ( - GeneratedToolOutput, - StatelessEnvironment, - StatelessTool, +from oumi.environments.base_tool import DeterministicToolOutput, Tool +from oumi.environments.deterministic_environment import DeterministicEnvironment +from oumi.environments.synthetic_environment import ( + SyntheticEnvironment, + SyntheticStateParams, ) -from oumi.environments.types import ToolEnvironmentType +from oumi.environments.tool_result import ToolResult __all__ = [ "BaseEnvironment", - "BaseTool", + "Tool", + "ToolResult", + "SyntheticEnvironment", + "SyntheticStateParams", "DeterministicEnvironment", - "DeterministicTool", "DeterministicToolOutput", - "GeneratedToolOutput", - "StatefulEnvironment", - "StatefulTool", - "StatelessEnvironment", - "StatelessTool", - "ToolEnvironmentType", ] diff --git a/src/oumi/environments/base_environment.py b/src/oumi/environments/base_environment.py index 879d3bf2c1..1ff63bff12 100644 --- 
a/src/oumi/environments/base_environment.py +++ b/src/oumi/environments/base_environment.py @@ -12,12 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Abstract base class for tool environments. - -Environments are simulated worlds that agents interact with via tools. -They are used by synthesis (to generate training data), evaluation -(to test agent behaviour), and RL (to provide reward signals). -""" +"""Abstract base class for tool environments.""" from __future__ import annotations @@ -27,26 +22,20 @@ from typing import Any, ClassVar from oumi.core.configs.params.base_params import BaseParams -from oumi.environments.base_tool import BaseTool -from oumi.environments.types import ToolEnvironmentType +from oumi.environments.base_tool import Tool +from oumi.environments.tool_result import ToolResult @dataclass class BaseEnvironment(BaseParams, ABC): - """Abstract base class for tool environments. - - Each environment owns a set of tools and defines how tool calls are - resolved. Subclasses implement the concrete execution model and - coerce raw tool definitions into their typed tool subclass. 
- """ + """Abstract base class for tool environments.""" - _registry: ClassVar[dict[ToolEnvironmentType, type[BaseEnvironment]]] = {} + _registry: ClassVar[dict[str, type[BaseEnvironment]]] = {} id: str name: str description: str - tools: list[BaseTool] = field(default_factory=list) - type: ToolEnvironmentType = field(init=False) + tools: list[Tool] = field(default_factory=list) def __init_subclass__(cls, **kwargs): """Register subclass in the environment type registry.""" @@ -65,7 +54,6 @@ def __post_init__(self): raise ValueError(f"{type(self).__name__}.description cannot be empty.") self.tools = self._coerce_tools(self.tools) self._validate_unique_tool_ids() - self._validate_type_specific() def _validate_unique_tool_ids(self) -> None: tool_ids: set[str] = set() @@ -77,16 +65,32 @@ def _validate_unique_tool_ids(self) -> None: ) tool_ids.add(tool.id) - @abstractmethod - def _coerce_tools(self, tools: list[Any]) -> list[BaseTool]: - """Coerce raw tool definitions into this environment's typed tool class.""" + def _coerce_tools(self, tools: list[Any]) -> list[Tool]: + """Coerce raw tool definitions into Tool instances.""" + return [Tool.create(tool) for tool in tools] @abstractmethod - def _validate_type_specific(self) -> None: - """Validate fields specific to the environment subtype.""" + def step(self, tool_id: str, arguments: dict[str, Any]) -> ToolResult: + """Execute a tool call within this environment.""" + + def _get_tool(self, tool_id: str) -> Tool | None: + """Look up a tool owned by this environment.""" + for tool in self.tools: + if tool.id == tool_id: + return tool + return None + + def _get_tool_or_raise(self, tool_id: str) -> Tool: + tool = self._get_tool(tool_id) + if tool is None: + raise ValueError( + f"Tool '{tool_id}' not found in environment '{self.id}'. 
" + f"Available tools: {[tool.id for tool in self.tools]}" + ) + return tool @classmethod - def create(cls, raw: Mapping[str, Any] | BaseEnvironment) -> BaseEnvironment: + def create(cls, raw: Any) -> BaseEnvironment: """Create a concrete environment from raw config data. Raises: @@ -100,21 +104,14 @@ def create(cls, raw: Mapping[str, Any] | BaseEnvironment) -> BaseEnvironment: "Environment definitions must be environment objects or mappings, " f"got {type(raw)}" ) - raw_type = raw.get("type") - if raw_type is None: + environment_type = raw.get("type") + if environment_type is None: raise ValueError( "Environment definition must include a 'type' field. " - f"Supported types: {[t.value for t in ToolEnvironmentType]}" + f"Supported types: {sorted(cls._registry)}" ) - if isinstance(raw_type, ToolEnvironmentType): - environment_type = raw_type - elif isinstance(raw_type, str): - try: - environment_type = ToolEnvironmentType(raw_type) - except ValueError: - environment_type = ToolEnvironmentType[raw_type] - else: - environment_type = ToolEnvironmentType(raw_type) + if not isinstance(environment_type, str): + environment_type = str(environment_type) environment_cls = cls._registry.get(environment_type) if environment_cls is None: raise ValueError(f"Unsupported environment type: {environment_type}") diff --git a/src/oumi/environments/base_tool.py b/src/oumi/environments/base_tool.py index be5e0226b2..d6cbf22824 100644 --- a/src/oumi/environments/base_tool.py +++ b/src/oumi/environments/base_tool.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-"""Base tool class shared by all environment types.""" +"""Tool definitions shared by all environment types.""" from __future__ import annotations +import json from collections.abc import Mapping from dataclasses import dataclass, field from typing import Any @@ -24,23 +25,34 @@ @dataclass -class BaseTool(BaseParams): - """Common fields for all tools exposed by an environment.""" +class DeterministicToolOutput(BaseParams): + """An input-to-output mapping for a deterministic tool.""" + + input: dict[str, Any] = field(default_factory=dict) + output: dict[str, Any] = field(default_factory=dict) + + def matches(self, arguments: dict[str, Any]) -> bool: + """Check if the input matches the given arguments.""" + return json.dumps(self.input, sort_keys=True) == json.dumps( + arguments, sort_keys=True + ) + + +@dataclass +class Tool(BaseParams): + """Tool schema owned by an environment.""" id: str name: str description: str parameters: dict[str, Any] = field(default_factory=dict) + read_only: bool = True + deterministic_outputs: list[DeterministicToolOutput] = field(default_factory=list) @classmethod - def create(cls, raw: Mapping[str, Any] | BaseTool) -> BaseTool: - """Create a tool from raw config data. - - Returns a ``BaseTool`` with only the common fields. Environment - subclasses call their own typed factory (e.g. - ``DeterministicTool.create``) to get the full subclass. 
- """ - if isinstance(raw, BaseTool): + def create(cls, raw: Any) -> Tool: + """Create a tool from raw config data.""" + if isinstance(raw, Tool): return raw if not isinstance(raw, Mapping): raise TypeError( @@ -51,6 +63,8 @@ def create(cls, raw: Mapping[str, Any] | BaseTool) -> BaseTool: name=raw["name"], description=raw["description"], parameters=raw.get("parameters", {}), + read_only=raw.get("read_only", True), + deterministic_outputs=raw.get("deterministic_outputs", []), ) def __post_init__(self): @@ -61,3 +75,11 @@ def __post_init__(self): raise ValueError(f"{type(self).__name__}.name cannot be empty.") if not self.description: raise ValueError(f"{type(self).__name__}.description cannot be empty.") + + def to_llm_schema(self) -> dict[str, Any]: + """Export a provider-agnostic schema for LLM tool registration.""" + return { + "name": self.name, + "description": self.description, + "parameters": self.parameters, + } diff --git a/src/oumi/environments/deterministic_environment.py b/src/oumi/environments/deterministic_environment.py index f4424634b6..768b9df86b 100644 --- a/src/oumi/environments/deterministic_environment.py +++ b/src/oumi/environments/deterministic_environment.py @@ -17,131 +17,58 @@ from __future__ import annotations import json -from collections.abc import Mapping from dataclasses import dataclass, field from typing import Any, ClassVar -from oumi.core.configs.params.base_params import BaseParams from oumi.environments.base_environment import BaseEnvironment -from oumi.environments.base_tool import BaseTool -from oumi.environments.types import ToolEnvironmentType +from oumi.environments.base_tool import DeterministicToolOutput, Tool +from oumi.environments.tool_result import ToolResult @dataclass -class DeterministicToolOutput(BaseParams): - """An input-to-output mapping for a deterministic tool.""" - - input: dict[str, Any] = field(default_factory=dict) - output: dict[str, Any] = field(default_factory=dict) - - def __post_init__(self): - 
"""Validate the input and output fields are not empty.""" - if not self.input: - raise ValueError("DeterministicToolOutput.input cannot be empty.") - if not self.output: - raise ValueError("DeterministicToolOutput.output cannot be empty.") - - def matches(self, arguments: dict[str, Any]) -> bool: - """Check if the input matches the given arguments.""" - return json.dumps(self.input, sort_keys=True) == json.dumps( - arguments, sort_keys=True - ) - - -@dataclass -class DeterministicTool(BaseTool): - """Tool with fixed input-to-output lookup responses.""" - - deterministic_outputs: list[DeterministicToolOutput] = field(default_factory=list) +class DeterministicEnvironment(BaseEnvironment): + """Environment that resolves tools from fixed lookups.""" + + ENVIRONMENT_TYPE: ClassVar[str] = "deterministic" + type: str = field(init=False, default=ENVIRONMENT_TYPE) + + def _coerce_tools(self, tools: list[Any]) -> list[Tool]: + """Coerce tools and deterministic outputs into typed objects.""" + coerced_tools: list[Tool] = [] + for raw_tool in tools: + tool = Tool.create(raw_tool) + tool.deterministic_outputs = [ + entry + if isinstance(entry, DeterministicToolOutput) + else DeterministicToolOutput(**entry) + for entry in tool.deterministic_outputs + ] + coerced_tools.append(tool) + return coerced_tools def __post_init__(self): - """Validate deterministic tool fields.""" + """Validate that deterministic tools have deterministic output entry.""" super().__post_init__() - if not self.deterministic_outputs: - raise ValueError( - f"DeterministicTool '{self.id}' must have at least one " - f"deterministic_output entry." 
- ) - self._check_deterministic_duplicates() - - def _check_deterministic_duplicates(self) -> None: - seen: set[str] = set() - for entry in self.deterministic_outputs: - key = json.dumps(entry.input, sort_keys=True) - if key in seen: + for tool in self.tools: + if not tool.deterministic_outputs: raise ValueError( - f"DeterministicTool '{self.id}' has duplicate " - f"deterministic input entry: {entry.input}" + f"Deterministic tool '{tool.id}' must have at least one " + "deterministic_output entry." ) - seen.add(key) - - def resolve_deterministic(self, arguments: dict[str, Any]) -> dict[str, Any] | None: - """Resolve a deterministic output for the given arguments.""" - for entry in self.deterministic_outputs: + seen: set[str] = set() + for entry in tool.deterministic_outputs: + key = json.dumps(entry.input, sort_keys=True) + if key in seen: + raise ValueError( + f"Deterministic tool '{tool.id}' has duplicate " + f"deterministic input entry: {entry.input}" + ) + seen.add(key) + + def step(self, tool_id: str, arguments: dict[str, Any]) -> ToolResult: + """Resolve a deterministic tool call to its output.""" + tool = self._get_tool_or_raise(tool_id) + for entry in tool.deterministic_outputs: if entry.matches(arguments): - return entry.output - return None - - @classmethod - def create(cls, raw: Mapping[str, Any] | BaseTool) -> DeterministicTool: - """Create a DeterministicTool from raw config data.""" - if isinstance(raw, DeterministicTool): - return raw - if isinstance(raw, BaseTool): - raise TypeError( - f"Cannot coerce {type(raw).__name__} to DeterministicTool. " - f"Use a mapping with 'deterministic_outputs'." 
- ) - if not isinstance(raw, Mapping): - raise TypeError( - f"Tool definitions must be tool objects or mappings, got {type(raw)}" - ) - deterministic_outputs = [ - entry - if isinstance(entry, DeterministicToolOutput) - else DeterministicToolOutput(**entry) - for entry in raw.get("deterministic_outputs", []) - ] - return cls( - id=raw["id"], - name=raw["name"], - description=raw["description"], - parameters=raw.get("parameters", {}), - deterministic_outputs=deterministic_outputs, - ) - - -@dataclass -class DeterministicEnvironment(BaseEnvironment): - """Environment that resolves tools from fixed lookups. - - Each tool is a ``DeterministicTool`` with a list of input-to-output - mappings. The environment owns the resolution logic. - """ - - ENVIRONMENT_TYPE: ClassVar[ToolEnvironmentType] = ToolEnvironmentType.DETERMINISTIC - type: ToolEnvironmentType = field( - init=False, default=ToolEnvironmentType.DETERMINISTIC - ) - tools: list[DeterministicTool] = field(default_factory=list) # type: ignore[assignment] - - def _coerce_tools(self, tools: list[Any]) -> list[DeterministicTool]: - """Coerce raw tool definitions into DeterministicTool instances.""" - return [DeterministicTool.create(t) for t in tools] - - def _validate_type_specific(self) -> None: - return - - def resolve(self, tool_id: str, arguments: dict[str, Any]) -> dict[str, Any] | None: - """Resolve a deterministic tool call to its output. - - Raises: - ValueError: If tool_id is not found in this environment. - """ - for tool in self.tools: - if tool.id == tool_id: - return tool.resolve_deterministic(arguments) - raise ValueError( - f"Tool '{tool_id}' not found in environment '{self.id}'. 
" - f"Available tools: {[t.id for t in self.tools]}" - ) + return ToolResult(output=entry.output) + return ToolResult(output=None) diff --git a/src/oumi/environments/stateful_environment.py b/src/oumi/environments/stateful_environment.py deleted file mode 100644 index e619ff7bf9..0000000000 --- a/src/oumi/environments/stateful_environment.py +++ /dev/null @@ -1,84 +0,0 @@ -# Copyright 2025 - Oumi -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Stateful environment with mutable shared state.""" - -from __future__ import annotations - -from collections.abc import Mapping -from dataclasses import dataclass, field -from typing import Any, ClassVar - -from oumi.environments.base_environment import BaseEnvironment -from oumi.environments.base_tool import BaseTool -from oumi.environments.types import ToolEnvironmentType - - -@dataclass -class StatefulTool(BaseTool): - """Tool bound to a stateful environment.""" - - output_schema: dict[str, Any] = field(default_factory=dict) - read_only: bool = True - - @classmethod - def create(cls, raw: Mapping[str, Any] | BaseTool) -> StatefulTool: - """Create a StatefulTool from raw config data.""" - if isinstance(raw, StatefulTool): - return raw - if isinstance(raw, BaseTool): - return cls( - id=raw.id, - name=raw.name, - description=raw.description, - parameters=raw.parameters, - ) - if not isinstance(raw, Mapping): - raise TypeError( - f"Tool definitions must be tool objects or mappings, got {type(raw)}" - ) - return cls( - 
id=raw["id"], - name=raw["name"], - description=raw["description"], - parameters=raw.get("parameters", {}), - output_schema=raw.get("output_schema", {}), - read_only=raw.get("read_only", True), - ) - - -@dataclass -class StatefulEnvironment(BaseEnvironment): - """Environment with mutable shared state. - - Maintains a JSON state dict that tools can read and modify. The - ``system_prompt`` instructs the LLM how to simulate this environment. - Each tool is a ``StatefulTool`` with an output schema and read/write flag. - """ - - ENVIRONMENT_TYPE: ClassVar[ToolEnvironmentType] = ToolEnvironmentType.STATEFUL - type: ToolEnvironmentType = field(init=False, default=ToolEnvironmentType.STATEFUL) - system_prompt: str = "" - state_schema: dict[str, Any] | None = None - initial_state: dict[str, Any] | None = None - tools: list[StatefulTool] = field(default_factory=list) # type: ignore[assignment] - - def _coerce_tools(self, tools: list[Any]) -> list[StatefulTool]: - """Coerce raw tool definitions into StatefulTool instances.""" - return [StatefulTool.create(t) for t in tools] - - def _validate_type_specific(self) -> None: - """Validate stateful-specific fields.""" - if not self.system_prompt: - raise ValueError("StatefulEnvironment.system_prompt cannot be empty.") diff --git a/src/oumi/environments/stateless_environment.py b/src/oumi/environments/stateless_environment.py deleted file mode 100644 index 65ece814fd..0000000000 --- a/src/oumi/environments/stateless_environment.py +++ /dev/null @@ -1,136 +0,0 @@ -# Copyright 2025 - Oumi -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -"""Stateless environment with optional response caching.""" - -from __future__ import annotations - -import json -from collections.abc import Mapping -from dataclasses import dataclass, field -from typing import Any, ClassVar - -from oumi.core.configs.params.base_params import BaseParams -from oumi.environments.base_environment import BaseEnvironment -from oumi.environments.base_tool import BaseTool -from oumi.environments.types import ToolEnvironmentType - - -@dataclass -class GeneratedToolOutput(BaseParams): - """Configuration for tool output in a stateless environment.""" - - instruction: str - - def __post_init__(self): - """Validate the instruction field is not empty.""" - if not self.instruction: - raise ValueError("GeneratedToolOutput.instruction cannot be empty.") - - -@dataclass -class StatelessTool(BaseTool): - """Tool bound to a stateless environment.""" - - generated_output: GeneratedToolOutput | None = None - - def __post_init__(self): - """Validate stateless tool fields.""" - super().__post_init__() - if self.generated_output is None: - raise ValueError(f"StatelessTool '{self.id}' must have a generated_output.") - - @classmethod - def create(cls, raw: Mapping[str, Any] | BaseTool) -> StatelessTool: - """Create a StatelessTool from raw config data.""" - if isinstance(raw, StatelessTool): - return raw - if isinstance(raw, BaseTool): - raise TypeError( - f"Cannot coerce {type(raw).__name__} to StatelessTool. " - f"Use a mapping with 'generated_output'." 
- ) - if not isinstance(raw, Mapping): - raise TypeError( - f"Tool definitions must be tool objects or mappings, got {type(raw)}" - ) - generated_output = raw.get("generated_output") - if isinstance(generated_output, Mapping): - generated_output = GeneratedToolOutput(**generated_output) - return cls( - id=raw["id"], - name=raw["name"], - description=raw["description"], - parameters=raw.get("parameters", {}), - generated_output=generated_output, - ) - - -@dataclass -class StatelessEnvironment(BaseEnvironment): - """Environment that simulates outputs and optionally caches by input. - - Each tool is a ``StatelessTool`` with a ``generated_output`` instruction - for the LLM. When ``cache_by_input`` is True, the environment caches - results keyed by (tool_id, arguments) so repeated calls with the same - input return consistent results. - """ - - ENVIRONMENT_TYPE: ClassVar[ToolEnvironmentType] = ToolEnvironmentType.STATELESS - type: ToolEnvironmentType = field(init=False, default=ToolEnvironmentType.STATELESS) - system_prompt: str = "" - cache_by_input: bool = True - tools: list[StatelessTool] = field(default_factory=list) # type: ignore[assignment] - - def __post_init__(self): - """Validate and initialize the response cache.""" - super().__post_init__() - self._cache: dict[str, str] = {} - self._frozen_context: str | None = None - - def _coerce_tools(self, tools: list[Any]) -> list[StatelessTool]: - """Coerce raw tool definitions into StatelessTool instances.""" - return [StatelessTool.create(t) for t in tools] - - def _validate_type_specific(self) -> None: - """Validate stateless-specific fields.""" - if not self.system_prompt: - raise ValueError("StatelessEnvironment.system_prompt cannot be empty.") - - @property - def frozen_context(self) -> str | None: - """Frozen context generated once at build time.""" - return self._frozen_context - - def set_frozen_context(self, context: str) -> None: - """Set the frozen context (called once during environment build).""" - 
self._frozen_context = context - - @staticmethod - def _cache_key(tool_id: str, arguments: dict[str, Any]) -> str: - """Build a stable cache key from tool id and arguments.""" - return f"{tool_id}::{json.dumps(arguments, sort_keys=True)}" - - def resolve_cached(self, tool_id: str, arguments: dict[str, Any]) -> str | None: - """Look up a cached result for the given tool call.""" - if not self.cache_by_input: - return None - return self._cache.get(self._cache_key(tool_id, arguments)) - - def cache_result( - self, tool_id: str, arguments: dict[str, Any], result: str - ) -> None: - """Store a generated result in the cache. No-op if caching is disabled.""" - if self.cache_by_input: - self._cache[self._cache_key(tool_id, arguments)] = result diff --git a/src/oumi/environments/synthetic_environment.py b/src/oumi/environments/synthetic_environment.py new file mode 100644 index 0000000000..243303c5b1 --- /dev/null +++ b/src/oumi/environments/synthetic_environment.py @@ -0,0 +1,169 @@ +# Copyright 2025 - Oumi +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Synthetic environment backed by LLM-simulated tool execution.""" + +from __future__ import annotations + +import copy +import json +from dataclasses import dataclass, field +from typing import Any, ClassVar + +from oumi.core.configs.params.base_params import BaseParams +from oumi.environments.base_environment import BaseEnvironment +from oumi.environments.tool_result import ToolResult + + +def _validate_json_schema_value( + value: Any, + schema: dict[str, Any], + path: str = "$", +) -> None: + """Validate a JSON-like value against a minimal schema subset.""" + expected_type = schema.get("type") + if expected_type == "object": + if not isinstance(value, dict): + raise ValueError(f"{path} must be an object.") + required = schema.get("required", []) + for key in required: + if key not in value: + raise ValueError(f"{path}.{key} is required.") + for key, child_value in value.items(): + child_schema = schema.get("properties", {}).get(key) + if child_schema is not None: + _validate_json_schema_value(child_value, child_schema, f"{path}.{key}") + elif expected_type == "array": + if not isinstance(value, list): + raise ValueError(f"{path} must be an array.") + item_schema = schema.get("items") + if item_schema is not None: + for idx, item in enumerate(value): + _validate_json_schema_value(item, item_schema, f"{path}[{idx}]") + elif expected_type == "string": + if not isinstance(value, str): + raise ValueError(f"{path} must be a string.") + elif expected_type == "integer": + if isinstance(value, bool) or not isinstance(value, int): + raise ValueError(f"{path} must be an integer.") + elif expected_type == "number": + if isinstance(value, bool) or not isinstance(value, (int, float)): + raise ValueError(f"{path} must be a number.") + elif expected_type == "boolean": + if not isinstance(value, bool): + raise ValueError(f"{path} must be a boolean.") + elif expected_type == "null": + if value is not None: + raise ValueError(f"{path} must be null.") + + enum_values = 
schema.get("enum") + if enum_values is not None and value not in enum_values: + raise ValueError(f"{path} must be one of {enum_values}.") + + +@dataclass +class SyntheticStateParams(BaseParams): + """Optional state configuration for a synthetic environment.""" + + state_schema: dict[str, Any] | None = None + initial_state: dict[str, Any] | None = None + + def __post_init__(self): + """Validate state config consistency.""" + if self.state_schema is not None and self.initial_state is not None: + _validate_json_schema_value(self.initial_state, self.state_schema) + + +@dataclass +class SyntheticEnvironment(BaseEnvironment): + """LLM-simulated environment with optional mutable state.""" + + ENVIRONMENT_TYPE: ClassVar[str] = "synthetic" + type: str = field(init=False, default=ENVIRONMENT_TYPE) + system_prompt: str = "" + state_params: SyntheticStateParams | dict[str, Any] | None = None + cache_by_input: bool = True + + def __post_init__(self): + """Validate synthetic-only fields after common environment validation.""" + if isinstance(self.state_params, dict): + self.state_params = SyntheticStateParams(**self.state_params) + super().__post_init__() + self._cache: dict[str, ToolResult] = {} + self._state: dict[str, Any] | None = ( + copy.deepcopy(self.state_params.initial_state) + if self.state_params is not None + and self.state_params.initial_state is not None + else None + ) + + # Validate synthetic-only fields after common environment validation/coercion. + if not self.system_prompt: + raise ValueError("SyntheticEnvironment.system_prompt cannot be empty.") + if self.state_params is not None and self.cache_by_input: + raise ValueError( + "SyntheticEnvironment.cache_by_input must be False when " + "state_params is provided." + ) + for tool in self.tools: + if tool.deterministic_outputs: + raise ValueError( + f"Synthetic tool '{tool.id}' cannot define deterministic_outputs." 
+ ) + + @property + def current_state(self) -> dict[str, Any] | None: + """Return the current in-memory state snapshot.""" + if self._state is None: + return None + return copy.deepcopy(self._state) + + @staticmethod + def _cache_key(tool_id: str, arguments: dict[str, Any]) -> str: + """Build a stable cache key from tool id and arguments.""" + return f"{tool_id}::{json.dumps(arguments, sort_keys=True)}" + + def _resolve_cached( + self, tool_id: str, arguments: dict[str, Any] + ) -> ToolResult | None: + """Look up a cached result for the given tool call.""" + if not self.cache_by_input: + return None + result = self._cache.get(self._cache_key(tool_id, arguments)) + if result is None: + return None + return ToolResult( + output=copy.deepcopy(result.output), + updated_state=copy.deepcopy(result.updated_state), + ) + + def _cache_result( + self, tool_id: str, arguments: dict[str, Any], result: ToolResult + ) -> None: + """Store a generated result in the cache.""" + if not self.cache_by_input: + return + self._cache[self._cache_key(tool_id, arguments)] = ToolResult( + output=copy.deepcopy(result.output), + updated_state=copy.deepcopy(result.updated_state), + ) + + def step(self, tool_id: str, arguments: dict[str, Any]) -> ToolResult: + """Execute a synthetic tool call. + + The environment interface is defined now; actual LLM-backed execution + will be implemented in follow-up changes. + """ + self._get_tool_or_raise(tool_id) + raise NotImplementedError("SyntheticEnvironment.step() is not implemented yet.") diff --git a/src/oumi/environments/types.py b/src/oumi/environments/tool_result.py similarity index 63% rename from src/oumi/environments/types.py rename to src/oumi/environments/tool_result.py index f3654ab0cf..d0af0ee6b7 100644 --- a/src/oumi/environments/types.py +++ b/src/oumi/environments/tool_result.py @@ -12,14 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-"""Shared types for the environments package.""" +"""Tool execution results.""" -from enum import Enum +from dataclasses import dataclass +from typing import Any +from oumi.core.configs.params.base_params import BaseParams -class ToolEnvironmentType(str, Enum): - """Execution model for an environment-bound tool.""" - STATEFUL = "stateful" - STATELESS = "stateless" - DETERMINISTIC = "deterministic" +@dataclass +class ToolResult(BaseParams): + """Result returned by an environment step.""" + + output: str | dict[str, Any] | None + updated_state: dict[str, Any] | None = None diff --git a/tests/unit/core/configs/params/test_tool_params.py b/tests/unit/core/configs/params/test_tool_params.py index 8ea9b8e54e..8069a63f33 100644 --- a/tests/unit/core/configs/params/test_tool_params.py +++ b/tests/unit/core/configs/params/test_tool_params.py @@ -18,20 +18,17 @@ from oumi.core.configs.environment_config import EnvironmentConfig from oumi.environments import ( - BaseTool, + BaseEnvironment, DeterministicEnvironment, - DeterministicTool, DeterministicToolOutput, - GeneratedToolOutput, - StatefulEnvironment, - StatefulTool, - StatelessEnvironment, - StatelessTool, - ToolEnvironmentType, + SyntheticEnvironment, + SyntheticStateParams, + Tool, + ToolResult, ) -def _make_deterministic_tool(**overrides: Any) -> DeterministicTool: +def _make_deterministic_tool(**overrides: Any) -> Tool: defaults: dict[str, Any] = dict( id="tool1", name="MyTool", @@ -41,41 +38,41 @@ def _make_deterministic_tool(**overrides: Any) -> DeterministicTool: ], ) defaults.update(overrides) - return DeterministicTool(**defaults) + return Tool(**defaults) -def _make_stateless_tool(**overrides: Any) -> StatelessTool: +def _make_synthetic_tool(**overrides: Any) -> Tool: defaults: dict[str, Any] = dict( id="tool2", name="GenTool", description="A generated tool", - generated_output=GeneratedToolOutput(instruction="Do something."), ) defaults.update(overrides) - return StatelessTool(**defaults) + return 
Tool(**defaults) -def _make_stateful_tool(**overrides: Any) -> StatefulTool: - defaults: dict[str, Any] = dict( - id="tool3", - name="StatefulTool", - description="A stateful tool", - ) - defaults.update(overrides) - return StatefulTool(**defaults) - - -# --- DeterministicToolOutput tests --- +def _make_state_schema() -> dict[str, Any]: + return { + "type": "object", + "properties": { + "files": { + "type": "object", + "properties": {"count": {"type": "integer"}}, + "required": ["count"], + } + }, + "required": ["files"], + } -def test_deterministic_tool_output_empty_input_raises(): - with pytest.raises(ValueError, match="input cannot be empty"): - DeterministicToolOutput(input={}, output={"msg": "ok"}) +def test_deterministic_tool_output_allows_empty_input(): + entry = DeterministicToolOutput(input={}, output={"msg": "ok"}) + assert entry.input == {} -def test_deterministic_tool_output_empty_output_raises(): - with pytest.raises(ValueError, match="output cannot be empty"): - DeterministicToolOutput(input={"id": "1"}, output={}) +def test_deterministic_tool_output_allows_empty_output(): + entry = DeterministicToolOutput(input={"id": "1"}, output={}) + assert entry.output == {} def test_deterministic_tool_output_matches_exact(): @@ -96,91 +93,175 @@ def test_deterministic_tool_output_no_match(): assert entry.matches({"id": "01", "extra": "arg"}) is False -# --- BaseTool tests --- - - @pytest.mark.parametrize("field,value", [("id", ""), ("name", ""), ("description", "")]) -def test_base_tool_empty_field_raises(field, value): +def test_tool_empty_field_raises(field, value): with pytest.raises(ValueError, match=f"{field} cannot be empty"): - BaseTool(**{"id": "t", "name": "T", "description": "d", **{field: value}}) + Tool(**{"id": "t", "name": "T", "description": "d", **{field: value}}) -# --- DeterministicTool tests --- +def test_tool_to_llm_schema(): + tool = Tool( + id="search", + name="Search", + description="Search the catalog.", + parameters={"type": "object", 
"properties": {"query": {"type": "string"}}}, + ) + assert tool.to_llm_schema() == { + "name": "Search", + "description": "Search the catalog.", + "parameters": { + "type": "object", + "properties": {"query": {"type": "string"}}, + }, + } -def test_deterministic_tool_requires_outputs(): - with pytest.raises(ValueError, match="must have at least one"): - DeterministicTool(id="t", name="T", description="d", deterministic_outputs=[]) +def test_tool_create_coerces_deterministic_outputs(): + env = BaseEnvironment.create( + { + "id": "lookup", + "type": "deterministic", + "name": "Lookup", + "description": "Lookup tools", + "tools": [ + { + "id": "policy", + "name": "Policy", + "description": "Look up policy.", + "deterministic_outputs": [ + {"input": {"id": "1"}, "output": {"result": "ok"}} + ], + } + ], + } + ) + assert isinstance(env, DeterministicEnvironment) + assert isinstance(env.tools[0].deterministic_outputs[0], DeterministicToolOutput) -def test_deterministic_tool_duplicate_inputs_raises(): - outputs = [ - DeterministicToolOutput(input={"id": "01"}, output={"msg": "a"}), - DeterministicToolOutput(input={"id": "01"}, output={"msg": "b"}), - ] - with pytest.raises(ValueError, match="duplicate"): - _make_deterministic_tool(deterministic_outputs=outputs) +def test_synthetic_state_params_validates_initial_state_against_schema(): + with pytest.raises(ValueError, match=r"\$\.files\.count must be an integer"): + SyntheticStateParams( + state_schema=_make_state_schema(), + initial_state={"files": {"count": "bad"}}, + ) -def test_deterministic_tool_resolve_match(): - tool = _make_deterministic_tool( - deterministic_outputs=[ - DeterministicToolOutput(input={"id": "01"}, output={"msg": "pending"}), - DeterministicToolOutput(input={"id": "02"}, output={"msg": "delivered"}), - ] +def test_synthetic_state_params_accepts_partial_inputs(): + assert SyntheticStateParams(state_schema=_make_state_schema()).state_schema is not None + assert ( + 
SyntheticStateParams(initial_state={"files": {"count": 1}}).initial_state + == {"files": {"count": 1}} ) - assert tool.resolve_deterministic({"id": "01"}) == {"msg": "pending"} - assert tool.resolve_deterministic({"id": "02"}) == {"msg": "delivered"} -def test_deterministic_tool_resolve_no_match(): - tool = _make_deterministic_tool() - assert tool.resolve_deterministic({"id": "99"}) is None +def test_synthetic_environment_valid_stateless(): + env = SyntheticEnvironment( + id="faq", + name="FAQ", + description="FAQ tools", + system_prompt="Answer FAQs.", + tools=[Tool(id="answer", name="Answer", description="Answer a FAQ.")], + ) + assert env.type == "synthetic" + assert env.state_params is None + assert env.current_state is None + assert isinstance(env.tools[0], Tool) -# --- StatelessTool tests --- +def test_synthetic_environment_valid_stateful(): + env = SyntheticEnvironment( + id="filesystem", + name="Filesystem", + description="A simple filesystem", + system_prompt="You manage a filesystem.", + state_params=SyntheticStateParams( + state_schema=_make_state_schema(), + initial_state={"files": {"count": 1}}, + ), + cache_by_input=False, + tools=[Tool(id="read", name="Read", description="Read files.")], + ) + assert env.current_state == {"files": {"count": 1}} -def test_stateless_tool_requires_generated_output(): - with pytest.raises(ValueError, match="must have a generated_output"): - StatelessTool(id="t", name="T", description="d", generated_output=None) +def test_synthetic_environment_coerces_dict_tools(): + env = SyntheticEnvironment( + id="fs", + name="FS", + description="d", + system_prompt="p", + tools=[{"id": "read", "name": "Read", "description": "Read files."}], # type: ignore[list-item] + ) + assert isinstance(env.tools[0], Tool) -# --- ToolEnvironmentType tests --- +def test_synthetic_environment_empty_system_prompt_raises(): + with pytest.raises(ValueError, match="system_prompt cannot be empty"): + SyntheticEnvironment(id="x", name="n", description="d", 
system_prompt="") -def test_tool_environment_type_values_exist(): - assert ToolEnvironmentType.STATEFUL == "stateful" - assert ToolEnvironmentType.STATELESS == "stateless" - assert ToolEnvironmentType.DETERMINISTIC == "deterministic" +def test_synthetic_environment_rejects_deterministic_outputs(): + with pytest.raises(ValueError, match="cannot define deterministic_outputs"): + SyntheticEnvironment( + id="x", + name="n", + description="d", + system_prompt="p", + tools=[_make_deterministic_tool()], + ) -# --- Environment + typed tool integration tests --- +def test_synthetic_environment_rejects_cache_when_stateful(): + with pytest.raises(ValueError, match="cache_by_input must be False"): + SyntheticEnvironment( + id="x", + name="n", + description="d", + system_prompt="p", + state_params=SyntheticStateParams(), + cache_by_input=True, + ) -def test_stateful_environment_valid(): - env = StatefulEnvironment( - id="filesystem", - name="Filesystem", - description="A simple filesystem", - system_prompt="You manage a filesystem.", - tools=[StatefulTool(id="read", name="Read", description="Read files.")], +def test_synthetic_environment_cache_round_trip(): + env = SyntheticEnvironment( + id="weather", + name="Weather", + description="Weather API", + system_prompt="Simulate weather.", + cache_by_input=True, + tools=[_make_synthetic_tool(id="get_weather")], ) - assert env.id == "filesystem" - assert env.state_schema is None - assert env.initial_state is None - assert isinstance(env.tools[0], StatefulTool) + result = ToolResult(output={"temp": 72}) + env._cache_result("get_weather", {"city": "SF"}, result) + cached = env._resolve_cached("get_weather", {"city": "SF"}) + assert cached == result + assert cached is not result -def test_stateful_environment_coerces_dict_tools(): - env = StatefulEnvironment( - id="fs", - name="FS", - description="d", - system_prompt="p", - tools=[{"id": "read", "name": "Read", "description": "Read files."}], # type: ignore[arg-type] +def 
test_synthetic_environment_step_unknown_tool_raises(): + env = SyntheticEnvironment( + id="faq", + name="FAQ", + description="FAQ tools", + system_prompt="Answer FAQs.", + tools=[_make_synthetic_tool(id="answer")], + ) + with pytest.raises(ValueError, match="Tool 'missing' not found"): + env.step("missing", {}) + + +def test_synthetic_environment_step_known_tool_is_stub(): + env = SyntheticEnvironment( + id="faq", + name="FAQ", + description="FAQ tools", + system_prompt="Answer FAQs.", + tools=[_make_synthetic_tool(id="answer")], ) - assert isinstance(env.tools[0], StatefulTool) + with pytest.raises(NotImplementedError, match="not implemented yet"): + env.step("answer", {}) def test_deterministic_environment_valid(): @@ -189,7 +270,7 @@ def test_deterministic_environment_valid(): name="Lookup", description="A deterministic lookup environment", tools=[ - DeterministicTool( + Tool( id="policy", name="Policy", description="Look up policy.", @@ -202,120 +283,186 @@ def test_deterministic_environment_valid(): ) ], ) - assert not hasattr(env, "system_prompt") - assert isinstance(env.tools[0], DeterministicTool) + assert env.type == "deterministic" + assert isinstance(env.tools[0], Tool) -def test_environment_with_schema_and_state(): - schema = { - "type": "object", - "properties": {"files": {"type": "object"}}, - "required": ["files"], - } - state = {"files": {}} - env = StatefulEnvironment( - id="fs", - name="FS", - description="d", - system_prompt="p", - state_schema=schema, - initial_state=state, - ) - assert env.state_schema == schema - assert env.initial_state == state +def test_deterministic_environment_requires_outputs_on_tool(): + with pytest.raises(ValueError, match="must have at least one"): + DeterministicEnvironment( + id="det_env", + name="Deterministic", + description="d", + tools=[_make_deterministic_tool(deterministic_outputs=[])], + ) -def test_environment_empty_id_raises(): - with pytest.raises(ValueError, match="id cannot be empty"): - 
StatefulEnvironment(id="", name="n", description="d", system_prompt="p") +def test_deterministic_environment_duplicate_inputs_raises(): + outputs = [ + DeterministicToolOutput(input={"id": "01"}, output={"msg": "a"}), + DeterministicToolOutput(input={"id": "01"}, output={"msg": "b"}), + ] + with pytest.raises(ValueError, match="duplicate"): + DeterministicEnvironment( + id="det_env", + name="Deterministic", + description="d", + tools=[_make_deterministic_tool(deterministic_outputs=outputs)], + ) -def test_environment_empty_name_raises(): - with pytest.raises(ValueError, match="name cannot be empty"): - StatefulEnvironment(id="x", name="", description="d", system_prompt="p") +def test_deterministic_environment_step_match(): + env = DeterministicEnvironment( + id="lookup", + name="Lookup", + description="A deterministic lookup environment", + tools=[ + _make_deterministic_tool( + deterministic_outputs=[ + DeterministicToolOutput( + input={"id": "01"}, output={"msg": "pending"} + ), + DeterministicToolOutput( + input={"id": "02"}, output={"msg": "delivered"} + ), + ] + ) + ], + ) + assert env.step("tool1", {"id": "01"}) == ToolResult(output={"msg": "pending"}) + assert env.step("tool1", {"id": "02"}) == ToolResult(output={"msg": "delivered"}) -def test_environment_empty_description_raises(): - with pytest.raises(ValueError, match="description cannot be empty"): - StatefulEnvironment(id="x", name="n", description="", system_prompt="p") +def test_deterministic_environment_step_no_match(): + env = DeterministicEnvironment( + id="lookup", + name="Lookup", + description="A deterministic lookup environment", + tools=[_make_deterministic_tool()], + ) + assert env.step("tool1", {"id": "99"}) == ToolResult(output=None) -def test_stateful_environment_empty_system_prompt_raises(): - with pytest.raises(ValueError, match="system_prompt cannot be empty"): - StatefulEnvironment(id="x", name="n", description="d", system_prompt="") +def 
test_deterministic_environment_supports_empty_argument_match(): + env = DeterministicEnvironment( + id="lookup", + name="Lookup", + description="A deterministic lookup environment", + tools=[ + Tool( + id="ping", + name="Ping", + description="Zero-arg tool.", + deterministic_outputs=[ + DeterministicToolOutput(input={}, output={}), + ], + ) + ], + ) + assert env.step("ping", {}) == ToolResult(output={}) -def test_stateless_environment_empty_system_prompt_raises(): - with pytest.raises(ValueError, match="system_prompt cannot be empty"): - StatelessEnvironment(id="x", name="n", description="d", system_prompt="") +def test_environment_empty_id_raises(): + with pytest.raises(ValueError, match="id cannot be empty"): + SyntheticEnvironment(id="", name="n", description="d", system_prompt="p") -def test_stateless_environment_with_state_schema_raises(): - with pytest.raises(TypeError, match="unexpected keyword argument 'state_schema'"): - StatelessEnvironment( - id="x", - name="n", - description="d", - system_prompt="p", - state_schema={"type": "object"}, # type: ignore[call-arg] - ) +def test_environment_empty_name_raises(): + with pytest.raises(ValueError, match="name cannot be empty"): + SyntheticEnvironment(id="x", name="", description="d", system_prompt="p") -def test_deterministic_environment_with_initial_state_raises(): - with pytest.raises(TypeError, match="unexpected keyword argument 'initial_state'"): - DeterministicEnvironment(id="x", name="n", description="d", initial_state={}) # type: ignore[call-arg] +def test_environment_empty_description_raises(): + with pytest.raises(ValueError, match="description cannot be empty"): + SyntheticEnvironment(id="x", name="n", description="", system_prompt="p") def test_environment_duplicate_tool_ids_raises(): with pytest.raises(ValueError, match="duplicate tool id 'dup'"): - StatefulEnvironment( + SyntheticEnvironment( id="env2", name="Env 2", description="d", system_prompt="p", tools=[ - StatefulTool(id="dup", name="Read", 
description="Read files."), - StatefulTool(id="dup", name="Write", description="Write files."), + Tool(id="dup", name="Read", description="Read files."), + Tool(id="dup", name="Write", description="Write files."), ], ) def test_environment_config_duplicate_tool_ids_across_envs_raises(): - env1 = StatefulEnvironment( + env1 = SyntheticEnvironment( id="env1", name="Env 1", description="d", system_prompt="p", - tools=[StatefulTool(id="dup", name="Read", description="Read files.")], + tools=[Tool(id="dup", name="Read", description="Read files.")], ) - env2 = StatefulEnvironment( + env2 = SyntheticEnvironment( id="env2", name="Env 2", description="d", system_prompt="p", - tools=[StatefulTool(id="dup", name="Write", description="Write files.")], + tools=[Tool(id="dup", name="Write", description="Write files.")], ) with pytest.raises(ValueError, match="duplicate tool id 'dup'"): EnvironmentConfig(environments=[env1, env2]) def test_environment_config_tool_environment_map(): - env = StatelessEnvironment( + env = SyntheticEnvironment( id="faq", name="FAQ", description="FAQ tools", system_prompt="Answer FAQs.", - tools=[_make_stateless_tool(id="answer_faq")], + tools=[_make_synthetic_tool(id="answer_faq")], ) config = EnvironmentConfig(environments=[env]) assert config.tool_environment_map == {"answer_faq": "faq"} -def test_deterministic_environment_requires_outputs_on_tool(): - with pytest.raises(ValueError, match="must have at least one"): - DeterministicEnvironment( - id="det_env", - name="Deterministic", - description="d", - tools=[_make_deterministic_tool(deterministic_outputs=[])], - ) +def test_base_environment_create_routes_synthetic(): + env = BaseEnvironment.create( + { + "id": "faq", + "type": "synthetic", + "name": "FAQ", + "description": "FAQ tools", + "system_prompt": "Answer FAQs.", + "tools": [{"id": "answer", "name": "Answer", "description": "Answer."}], + } + ) + assert isinstance(env, SyntheticEnvironment) + + +def 
test_base_environment_create_routes_deterministic(): + env = BaseEnvironment.create( + { + "id": "lookup", + "type": "deterministic", + "name": "Lookup", + "description": "Lookup tools", + "tools": [ + { + "id": "policy", + "name": "Policy", + "description": "Look up policy.", + "deterministic_outputs": [ + {"input": {"id": "1"}, "output": {"result": "ok"}} + ], + } + ], + } + ) + assert isinstance(env, DeterministicEnvironment) + + +def test_base_environment_create_missing_type_raises(): + with pytest.raises(ValueError, match="must include a 'type' field"): + BaseEnvironment.create({"id": "faq"}) + + +def test_base_environment_create_unsupported_type_raises(): + with pytest.raises(ValueError, match="Unsupported environment type"): + BaseEnvironment.create({"id": "faq", "type": "unknown"}) diff --git a/tests/unit/core/configs/test_synthesis_config.py b/tests/unit/core/configs/test_synthesis_config.py index 11dbc67381..f6ff5cc1d9 100644 --- a/tests/unit/core/configs/test_synthesis_config.py +++ b/tests/unit/core/configs/test_synthesis_config.py @@ -23,11 +23,9 @@ from oumi.core.configs.synthesis_config import SynthesisConfig, SynthesisStrategy from oumi.core.types.conversation import Role from oumi.environments import ( - GeneratedToolOutput, - StatefulEnvironment, - StatefulTool, - StatelessEnvironment, - StatelessTool, + SyntheticEnvironment, + SyntheticStateParams, + Tool, ) @@ -84,21 +82,18 @@ def test_invalid_output_path(): SynthesisConfig(inference_config=inference_config) -def _make_faq_tool() -> StatelessTool: - return StatelessTool( +def _make_faq_tool() -> Tool: + return Tool( id="answer_faq", name="AnswerFAQ", description="Answer a FAQ question.", - generated_output=GeneratedToolOutput( - instruction="Answer the given FAQ question." 
- ), ) def test_synthesis_config_with_top_level_environment_config(): env_config = EnvironmentConfig( environments=[ - StatelessEnvironment( + SyntheticEnvironment( id="faq", name="FAQ", description="FAQ tools", @@ -124,7 +119,7 @@ def test_synthesis_config_loads_environment_config_from_path(tmp_path): env_config_path = tmp_path / "environments.yaml" env_config = EnvironmentConfig( environments=[ - StatelessEnvironment( + SyntheticEnvironment( id="faq", name="FAQ", description="FAQ tools", @@ -144,7 +139,7 @@ def test_synthesis_config_loads_environment_config_from_path(tmp_path): def test_synthesis_config_validates_available_tools(): env_config = EnvironmentConfig( environments=[ - StatelessEnvironment( + SyntheticEnvironment( id="faq", name="FAQ", description="FAQ tools", @@ -203,7 +198,7 @@ def test_synthesis_config_requires_environment_config_for_available_tools(): def test_synthesis_config_validates_available_environments(): env_config = EnvironmentConfig( environments=[ - StatelessEnvironment( + SyntheticEnvironment( id="faq", name="FAQ", description="FAQ tools", @@ -234,20 +229,22 @@ def test_synthesis_config_validates_available_environments(): def test_synthesis_config_restricts_tools_to_selected_environments(): env_config = EnvironmentConfig( environments=[ - StatelessEnvironment( + SyntheticEnvironment( id="faq", name="FAQ", description="FAQ tools", system_prompt="Answer FAQs.", tools=[_make_faq_tool()], ), - StatefulEnvironment( + SyntheticEnvironment( id="files", name="Files", description="File tools", system_prompt="Manage files.", + state_params=SyntheticStateParams(), + cache_by_input=False, tools=[ - StatefulTool( + Tool( id="read_file", name="ReadFile", description="Read a file.", @@ -279,20 +276,22 @@ def test_synthesis_config_restricts_tools_to_selected_environments(): def test_synthesis_config_resolves_all_tools_from_selected_environments(): env_config = EnvironmentConfig( environments=[ - StatelessEnvironment( + SyntheticEnvironment( id="faq", 
name="FAQ", description="FAQ tools", system_prompt="Answer FAQs.", tools=[_make_faq_tool()], ), - StatefulEnvironment( + SyntheticEnvironment( id="files", name="Files", description="File tools", system_prompt="Manage files.", + state_params=SyntheticStateParams(), + cache_by_input=False, tools=[ - StatefulTool( + Tool( id="read_file", name="ReadFile", description="Read a file.", From f8a4d520b762faa285d5512ec796066a4f58b54c Mon Sep 17 00:00:00 2001 From: aniruddh-alt Date: Tue, 14 Apr 2026 13:35:37 -0700 Subject: [PATCH 13/18] Update test_tool_params.py --- tests/unit/core/configs/params/test_tool_params.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/unit/core/configs/params/test_tool_params.py b/tests/unit/core/configs/params/test_tool_params.py index 8069a63f33..4d14b8765a 100644 --- a/tests/unit/core/configs/params/test_tool_params.py +++ b/tests/unit/core/configs/params/test_tool_params.py @@ -148,11 +148,12 @@ def test_synthetic_state_params_validates_initial_state_against_schema(): def test_synthetic_state_params_accepts_partial_inputs(): - assert SyntheticStateParams(state_schema=_make_state_schema()).state_schema is not None assert ( - SyntheticStateParams(initial_state={"files": {"count": 1}}).initial_state - == {"files": {"count": 1}} + SyntheticStateParams(state_schema=_make_state_schema()).state_schema is not None ) + assert SyntheticStateParams( + initial_state={"files": {"count": 1}} + ).initial_state == {"files": {"count": 1}} def test_synthetic_environment_valid_stateless(): From 6f48a8ccc1054ee06211a3b82958ca0dae6d6c59 Mon Sep 17 00:00:00 2001 From: aniruddh-alt Date: Tue, 14 Apr 2026 13:49:00 -0700 Subject: [PATCH 14/18] refactor: move ToolResult into base_tool.py and fix circular imports MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Move ToolResult from its own file into base_tool.py alongside Tool and DeterministicToolOutput — all tool-related data types now live in 
one place. Restore lazy loading of EnvironmentConfig and SynthesisConfig in oumi.core.configs.__init__ to prevent circular imports when oumi.environments is imported first. --- src/oumi/core/configs/__init__.py | 15 ++++++++-- src/oumi/environments/__init__.py | 3 +- src/oumi/environments/base_environment.py | 3 +- src/oumi/environments/base_tool.py | 10 ++++++- .../environments/deterministic_environment.py | 3 +- .../environments/synthetic_environment.py | 2 +- src/oumi/environments/tool_result.py | 28 ------------------- 7 files changed, 26 insertions(+), 38 deletions(-) delete mode 100644 src/oumi/environments/tool_result.py diff --git a/src/oumi/core/configs/__init__.py b/src/oumi/core/configs/__init__.py index 068483ce01..67d1aaf03d 100644 --- a/src/oumi/core/configs/__init__.py +++ b/src/oumi/core/configs/__init__.py @@ -83,7 +83,6 @@ ) from oumi.core.configs.async_evaluation_config import AsyncEvaluationConfig from oumi.core.configs.base_config import BaseConfig -from oumi.core.configs.environment_config import EnvironmentConfig from oumi.core.configs.evaluation_config import EvaluationConfig from oumi.core.configs.inference_config import InferenceConfig from oumi.core.configs.inference_engine_type import InferenceEngineType @@ -155,10 +154,22 @@ TuningParams, ) from oumi.core.configs.quantization_config import QuantizationConfig -from oumi.core.configs.synthesis_config import SynthesisConfig from oumi.core.configs.training_config import TrainingConfig from oumi.core.configs.tuning_config import TuningConfig +def __getattr__(name: str): # noqa: E302 + """Lazily import configs that depend on oumi.environments to avoid circular imports.""" + if name == "EnvironmentConfig": + from oumi.core.configs.environment_config import EnvironmentConfig + + return EnvironmentConfig + if name == "SynthesisConfig": + from oumi.core.configs.synthesis_config import SynthesisConfig + + return SynthesisConfig + raise AttributeError(f"module {__name__!r} has no attribute 
{name!r}") + + __all__ = [ "AsyncEvaluationConfig", "AutoWrapPolicy", diff --git a/src/oumi/environments/__init__.py b/src/oumi/environments/__init__.py index dafe7bdab7..c6b3268e21 100644 --- a/src/oumi/environments/__init__.py +++ b/src/oumi/environments/__init__.py @@ -15,13 +15,12 @@ """Environments for agentic tool interactions.""" from oumi.environments.base_environment import BaseEnvironment -from oumi.environments.base_tool import DeterministicToolOutput, Tool +from oumi.environments.base_tool import DeterministicToolOutput, Tool, ToolResult from oumi.environments.deterministic_environment import DeterministicEnvironment from oumi.environments.synthetic_environment import ( SyntheticEnvironment, SyntheticStateParams, ) -from oumi.environments.tool_result import ToolResult __all__ = [ "BaseEnvironment", diff --git a/src/oumi/environments/base_environment.py b/src/oumi/environments/base_environment.py index 1ff63bff12..1a15f9d3db 100644 --- a/src/oumi/environments/base_environment.py +++ b/src/oumi/environments/base_environment.py @@ -22,8 +22,7 @@ from typing import Any, ClassVar from oumi.core.configs.params.base_params import BaseParams -from oumi.environments.base_tool import Tool -from oumi.environments.tool_result import ToolResult +from oumi.environments.base_tool import Tool, ToolResult @dataclass diff --git a/src/oumi/environments/base_tool.py b/src/oumi/environments/base_tool.py index d6cbf22824..2de38ea847 100644 --- a/src/oumi/environments/base_tool.py +++ b/src/oumi/environments/base_tool.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-"""Tool definitions shared by all environment types.""" +"""Tool definitions and execution results shared by all environment types.""" from __future__ import annotations @@ -83,3 +83,11 @@ def to_llm_schema(self) -> dict[str, Any]: "description": self.description, "parameters": self.parameters, } + + +@dataclass +class ToolResult(BaseParams): + """Result returned by an environment step.""" + + output: str | dict[str, Any] | None + updated_state: dict[str, Any] | None = None diff --git a/src/oumi/environments/deterministic_environment.py b/src/oumi/environments/deterministic_environment.py index 768b9df86b..5c4c1dea93 100644 --- a/src/oumi/environments/deterministic_environment.py +++ b/src/oumi/environments/deterministic_environment.py @@ -21,8 +21,7 @@ from typing import Any, ClassVar from oumi.environments.base_environment import BaseEnvironment -from oumi.environments.base_tool import DeterministicToolOutput, Tool -from oumi.environments.tool_result import ToolResult +from oumi.environments.base_tool import DeterministicToolOutput, Tool, ToolResult @dataclass diff --git a/src/oumi/environments/synthetic_environment.py b/src/oumi/environments/synthetic_environment.py index 243303c5b1..60357864f9 100644 --- a/src/oumi/environments/synthetic_environment.py +++ b/src/oumi/environments/synthetic_environment.py @@ -23,7 +23,7 @@ from oumi.core.configs.params.base_params import BaseParams from oumi.environments.base_environment import BaseEnvironment -from oumi.environments.tool_result import ToolResult +from oumi.environments.base_tool import ToolResult def _validate_json_schema_value( diff --git a/src/oumi/environments/tool_result.py b/src/oumi/environments/tool_result.py deleted file mode 100644 index d0af0ee6b7..0000000000 --- a/src/oumi/environments/tool_result.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright 2025 - Oumi -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tool execution results.""" - -from dataclasses import dataclass -from typing import Any - -from oumi.core.configs.params.base_params import BaseParams - - -@dataclass -class ToolResult(BaseParams): - """Result returned by an environment step.""" - - output: str | dict[str, Any] | None - updated_state: dict[str, Any] | None = None From 1a223d40426d871d43f955be99dfa919314ce4ad Mon Sep 17 00:00:00 2001 From: aniruddh-alt Date: Tue, 14 Apr 2026 14:03:17 -0700 Subject: [PATCH 15/18] fix: resolve circular import without lazy loading Move EnvironmentConfig, BaseEnvironment, and Tool imports behind TYPE_CHECKING in synthesis_config.py since they are only used for type annotations. The single runtime usage of EnvironmentConfig (from_yaml call) uses a local import. Remove the __getattr__ lazy-loading hack from oumi.core.configs.__init__ and restore SynthesisConfig as a normal re-export. 
--- src/oumi/core/configs/__init__.py | 15 +-------------- src/oumi/core/configs/synthesis_config.py | 13 ++++++++++--- 2 files changed, 11 insertions(+), 17 deletions(-) diff --git a/src/oumi/core/configs/__init__.py b/src/oumi/core/configs/__init__.py index 67d1aaf03d..8b44df535f 100644 --- a/src/oumi/core/configs/__init__.py +++ b/src/oumi/core/configs/__init__.py @@ -154,22 +154,10 @@ TuningParams, ) from oumi.core.configs.quantization_config import QuantizationConfig +from oumi.core.configs.synthesis_config import SynthesisConfig from oumi.core.configs.training_config import TrainingConfig from oumi.core.configs.tuning_config import TuningConfig -def __getattr__(name: str): # noqa: E302 - """Lazily import configs that depend on oumi.environments to avoid circular imports.""" - if name == "EnvironmentConfig": - from oumi.core.configs.environment_config import EnvironmentConfig - - return EnvironmentConfig - if name == "SynthesisConfig": - from oumi.core.configs.synthesis_config import SynthesisConfig - - return SynthesisConfig - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") - - __all__ = [ "AsyncEvaluationConfig", "AutoWrapPolicy", @@ -186,7 +174,6 @@ def __getattr__(name: str): # noqa: E302 "EvaluationConfig", "EvaluationBackend", "EvaluationConfig", - "EnvironmentConfig", "EvaluationTaskParams", "FSDPParams", "GenerationParams", diff --git a/src/oumi/core/configs/synthesis_config.py b/src/oumi/core/configs/synthesis_config.py index 8d63f34658..7879b4c5c4 100644 --- a/src/oumi/core/configs/synthesis_config.py +++ b/src/oumi/core/configs/synthesis_config.py @@ -12,19 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + from dataclasses import dataclass, field from enum import Enum from pathlib import Path +from typing import TYPE_CHECKING from oumi.core.configs.base_config import BaseConfig -from oumi.core.configs.environment_config import EnvironmentConfig from oumi.core.configs.inference_config import InferenceConfig from oumi.core.configs.params.synthesis_params import ( GeneralSynthesisParams, MultiTurnAttribute, ) -from oumi.environments.base_environment import BaseEnvironment -from oumi.environments.base_tool import Tool + +if TYPE_CHECKING: + from oumi.core.configs.environment_config import EnvironmentConfig + from oumi.environments.base_environment import BaseEnvironment + from oumi.environments.base_tool import Tool class SynthesisStrategy(str, Enum): @@ -117,6 +122,8 @@ def _resolve_environment_config(self) -> EnvironmentConfig | None: f"Environment config path does not exist: " f"{self.environment_config_path}" ) + from oumi.core.configs.environment_config import EnvironmentConfig + return EnvironmentConfig.from_yaml(config_path) return None From c5cd36d9abe3cff1ff2032230cf8d7958714ff7c Mon Sep 17 00:00:00 2001 From: aniruddh-alt Date: Tue, 14 Apr 2026 14:37:09 -0700 Subject: [PATCH 16/18] feat: fix circular imports --- src/oumi/core/configs/__init__.py | 2 ++ src/oumi/core/configs/environment_config.py | 12 +++++++++--- src/oumi/environments/synthetic_environment.py | 8 +------- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/src/oumi/core/configs/__init__.py b/src/oumi/core/configs/__init__.py index 8b44df535f..068483ce01 100644 --- a/src/oumi/core/configs/__init__.py +++ b/src/oumi/core/configs/__init__.py @@ -83,6 +83,7 @@ ) from oumi.core.configs.async_evaluation_config import AsyncEvaluationConfig from oumi.core.configs.base_config import BaseConfig +from oumi.core.configs.environment_config import EnvironmentConfig from oumi.core.configs.evaluation_config import EvaluationConfig from 
oumi.core.configs.inference_config import InferenceConfig from oumi.core.configs.inference_engine_type import InferenceEngineType @@ -174,6 +175,7 @@ "EvaluationConfig", "EvaluationBackend", "EvaluationConfig", + "EnvironmentConfig", "EvaluationTaskParams", "FSDPParams", "GenerationParams", diff --git a/src/oumi/core/configs/environment_config.py b/src/oumi/core/configs/environment_config.py index 1a8a985397..a832f5a4ec 100644 --- a/src/oumi/core/configs/environment_config.py +++ b/src/oumi/core/configs/environment_config.py @@ -14,12 +14,16 @@ """Configuration for agentic environments.""" +from __future__ import annotations + from dataclasses import dataclass, field -from typing import Any +from typing import TYPE_CHECKING, Any from oumi.core.configs.base_config import BaseConfig -from oumi.environments.base_environment import BaseEnvironment -from oumi.environments.base_tool import Tool + +if TYPE_CHECKING: + from oumi.environments.base_environment import BaseEnvironment + from oumi.environments.base_tool import Tool @dataclass @@ -130,4 +134,6 @@ def resolve_tools( def _coerce_environment(self, environment: Any) -> BaseEnvironment: """Coerce a raw dict or environment instance into a concrete environment.""" + from oumi.environments.base_environment import BaseEnvironment + return BaseEnvironment.create(environment) diff --git a/src/oumi/environments/synthetic_environment.py b/src/oumi/environments/synthetic_environment.py index 60357864f9..2b2c63decb 100644 --- a/src/oumi/environments/synthetic_environment.py +++ b/src/oumi/environments/synthetic_environment.py @@ -107,8 +107,6 @@ def __post_init__(self): and self.state_params.initial_state is not None else None ) - - # Validate synthetic-only fields after common environment validation/coercion. 
if not self.system_prompt: raise ValueError("SyntheticEnvironment.system_prompt cannot be empty.") if self.state_params is not None and self.cache_by_input: @@ -160,10 +158,6 @@ def _cache_result( ) def step(self, tool_id: str, arguments: dict[str, Any]) -> ToolResult: - """Execute a synthetic tool call. - - The environment interface is defined now; actual LLM-backed execution - will be implemented in follow-up changes. - """ + """Execute a synthetic tool call.""" self._get_tool_or_raise(tool_id) raise NotImplementedError("SyntheticEnvironment.step() is not implemented yet.") From 1b467623d1a31c874ca637d49b4521249d57403a Mon Sep 17 00:00:00 2001 From: aniruddh-alt Date: Tue, 14 Apr 2026 14:47:22 -0700 Subject: [PATCH 17/18] fix: use Any annotation for environment_config field in SynthesisConfig OmegaConf resolves dataclass field annotations at runtime via get_type_hints(). With EnvironmentConfig behind TYPE_CHECKING, the annotation string cannot be resolved, breaking YAML parsing. Use Any for the runtime annotation while keeping the TYPE_CHECKING import for static analysis. 
--- src/oumi/core/configs/synthesis_config.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/oumi/core/configs/synthesis_config.py b/src/oumi/core/configs/synthesis_config.py index 7879b4c5c4..5da1acd6f8 100644 --- a/src/oumi/core/configs/synthesis_config.py +++ b/src/oumi/core/configs/synthesis_config.py @@ -17,7 +17,7 @@ from dataclasses import dataclass, field from enum import Enum from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from oumi.core.configs.base_config import BaseConfig from oumi.core.configs.inference_config import InferenceConfig @@ -57,8 +57,8 @@ class SynthesisConfig(BaseConfig): ) """The synthesis strategy parameters to use.""" - environment_config: EnvironmentConfig | None = None - """Reusable environment-first tool configuration.""" + environment_config: Any | None = None + """Reusable environment-first tool configuration (EnvironmentConfig).""" environment_config_path: str | None = None """Optional path to an EnvironmentConfig YAML file.""" From 4ef5f1bd3af6979a2cd06da6f2b4cffc82d83abd Mon Sep 17 00:00:00 2001 From: aniruddh-alt Date: Wed, 15 Apr 2026 16:13:14 -0700 Subject: [PATCH 18/18] Add ToolSchema class for structured tool I/O definitions Extract tool parameter and output schemas into a dedicated ToolSchema class with validation and serialization. Update Tool to use ToolSchema for parameters and add optional output_schema field. 
--- src/oumi/environments/__init__.py | 8 +- src/oumi/environments/base_tool.py | 93 ++++++++++++++- .../core/configs/params/test_tool_params.py | 112 +++++++++++++++++- 3 files changed, 207 insertions(+), 6 deletions(-) diff --git a/src/oumi/environments/__init__.py b/src/oumi/environments/__init__.py index c6b3268e21..f84a0785e7 100644 --- a/src/oumi/environments/__init__.py +++ b/src/oumi/environments/__init__.py @@ -15,7 +15,12 @@ """Environments for agentic tool interactions.""" from oumi.environments.base_environment import BaseEnvironment -from oumi.environments.base_tool import DeterministicToolOutput, Tool, ToolResult +from oumi.environments.base_tool import ( + DeterministicToolOutput, + Tool, + ToolResult, + ToolSchema, +) from oumi.environments.deterministic_environment import DeterministicEnvironment from oumi.environments.synthetic_environment import ( SyntheticEnvironment, @@ -25,6 +30,7 @@ __all__ = [ "BaseEnvironment", "Tool", + "ToolSchema", "ToolResult", "SyntheticEnvironment", "SyntheticStateParams", diff --git a/src/oumi/environments/base_tool.py b/src/oumi/environments/base_tool.py index 2de38ea847..7becfcc080 100644 --- a/src/oumi/environments/base_tool.py +++ b/src/oumi/environments/base_tool.py @@ -38,6 +38,76 @@ def matches(self, arguments: dict[str, Any]) -> bool: ) +@dataclass +class ToolSchema(BaseParams): + """JSON schema for tool inputs or outputs.""" + + type: str = "object" + properties: dict[str, ToolSchema] = field(default_factory=dict) + description: str | None = None + required: list[str] = field(default_factory=list) + + @classmethod + def create(cls, raw: Any) -> ToolSchema: + """Create a schema from raw config data.""" + if isinstance(raw, ToolSchema): + return raw + if not isinstance(raw, Mapping): + raise TypeError( + f"Tool schema definitions must be schema objects or mappings, got " + f"{type(raw)}" + ) + return cls( + type=raw.get("type", "object"), + properties={ + key: cls.create(value) + for key, value in 
raw.get("properties", {}).items() + }, + description=raw.get("description"), + required=raw.get("required", []), + ) + + def __post_init__(self): + """Validate schema fields.""" + if not self.type: + raise ValueError(f"{type(self).__name__}.type cannot be empty.") + if not isinstance(self.properties, dict): + raise ValueError(f"{type(self).__name__}.properties must be a dict.") + if not all(isinstance(value, ToolSchema) for value in self.properties.values()): + raise ValueError( + f"{type(self).__name__}.properties values must be ToolSchema." + ) + if self.description is not None and not isinstance(self.description, str): + raise ValueError( + f"{type(self).__name__}.description must be a string when specified." + ) + if not isinstance(self.required, list) or not all( + isinstance(param, str) for param in self.required + ): + raise ValueError( + f"{type(self).__name__}.required must be a list of strings." + ) + missing_required = set(self.required) - set(self.properties) + if missing_required: + raise ValueError( + f"{type(self).__name__}.required contains unknown properties: " + f"{sorted(missing_required)}" + ) + + def to_dict(self) -> dict[str, Any]: + """Convert to a JSON-schema-shaped dict.""" + schema: dict[str, Any] = {"type": self.type} + if self.properties: + schema["properties"] = { + key: value.to_dict() for key, value in self.properties.items() + } + if self.description is not None: + schema["description"] = self.description + if self.required: + schema["required"] = list(self.required) + return schema + + @dataclass class Tool(BaseParams): """Tool schema owned by an environment.""" @@ -45,7 +115,8 @@ class Tool(BaseParams): id: str name: str description: str - parameters: dict[str, Any] = field(default_factory=dict) + parameters: ToolSchema = field(default_factory=ToolSchema) + output_schema: ToolSchema | None = None read_only: bool = True deterministic_outputs: list[DeterministicToolOutput] = field(default_factory=list) @@ -62,7 +133,12 @@ def 
create(cls, raw: Any) -> Tool: id=raw["id"], name=raw["name"], description=raw["description"], - parameters=raw.get("parameters", {}), + parameters=ToolSchema.create(raw.get("parameters", {})), + output_schema=( + ToolSchema.create(raw["output_schema"]) + if raw.get("output_schema") is not None + else None + ), read_only=raw.get("read_only", True), deterministic_outputs=raw.get("deterministic_outputs", []), ) @@ -75,14 +151,23 @@ def __post_init__(self): raise ValueError(f"{type(self).__name__}.name cannot be empty.") if not self.description: raise ValueError(f"{type(self).__name__}.description cannot be empty.") + if not isinstance(self.parameters, ToolSchema): + self.parameters = ToolSchema.create(self.parameters) + if self.output_schema is not None and not isinstance( + self.output_schema, ToolSchema + ): + self.output_schema = ToolSchema.create(self.output_schema) def to_llm_schema(self) -> dict[str, Any]: """Export a provider-agnostic schema for LLM tool registration.""" - return { + schema: dict[str, Any] = { "name": self.name, "description": self.description, - "parameters": self.parameters, + "parameters": self.parameters.to_dict(), } + if self.output_schema is not None: + schema["output_schema"] = self.output_schema.to_dict() + return schema @dataclass diff --git a/tests/unit/core/configs/params/test_tool_params.py b/tests/unit/core/configs/params/test_tool_params.py index 4d14b8765a..6590c81c85 100644 --- a/tests/unit/core/configs/params/test_tool_params.py +++ b/tests/unit/core/configs/params/test_tool_params.py @@ -25,6 +25,7 @@ SyntheticStateParams, Tool, ToolResult, + ToolSchema, ) @@ -104,7 +105,11 @@ def test_tool_to_llm_schema(): id="search", name="Search", description="Search the catalog.", - parameters={"type": "object", "properties": {"query": {"type": "string"}}}, + parameters=ToolSchema( + type="object", + properties={"query": ToolSchema(type="string")}, + required=["query"], + ), ) assert tool.to_llm_schema() == { "name": "Search", @@ -112,6 
+117,29 @@ def test_tool_to_llm_schema(): "parameters": { "type": "object", "properties": {"query": {"type": "string"}}, + "required": ["query"], + }, + } + + +def test_tool_to_llm_schema_includes_output_schema(): + tool = Tool( + id="search", + name="Search", + description="Search the catalog.", + parameters=ToolSchema(type="object", properties={}), + output_schema=ToolSchema( + type="object", + properties={"result": ToolSchema(type="string")}, + ), + ) + assert tool.to_llm_schema() == { + "name": "Search", + "description": "Search the catalog.", + "parameters": {"type": "object"}, + "output_schema": { + "type": "object", + "properties": {"result": {"type": "string"}}, }, } @@ -139,6 +167,88 @@ def test_tool_create_coerces_deterministic_outputs(): assert isinstance(env.tools[0].deterministic_outputs[0], DeterministicToolOutput) +def test_tool_create_reads_extended_tool_fields(): + tool = Tool.create( + { + "id": "policy", + "name": "Policy", + "description": "Look up policy.", + "parameters": { + "type": "object", + "properties": { + "policy_id": {"type": "string"}, + }, + "required": ["policy_id"], + }, + "output_schema": {"type": "object", "properties": {}}, + } + ) + assert tool.parameters.required == ["policy_id"] + assert tool.output_schema == ToolSchema(type="object", properties={}) + + +def test_tool_schema_to_dict(): + schema = ToolSchema( + type="object", + properties={ + "query": ToolSchema(type="string"), + }, + description="Tool input schema.", + required=["query"], + ) + assert schema.to_dict() == { + "type": "object", + "properties": {"query": {"type": "string"}}, + "description": "Tool input schema.", + "required": ["query"], + } + + +def test_tool_schema_create_recursively_coerces_nested_properties(): + schema = ToolSchema.create( + { + "type": "object", + "properties": { + "customer": { + "type": "object", + "properties": { + "email": {"type": "string", "description": "Customer email."} + }, + "required": ["email"], + } + }, + "required": 
["customer"], + } + ) + assert isinstance(schema.properties["customer"], ToolSchema) + assert isinstance(schema.properties["customer"].properties["email"], ToolSchema) + assert schema.to_dict() == { + "type": "object", + "properties": { + "customer": { + "type": "object", + "properties": { + "email": { + "type": "string", + "description": "Customer email.", + } + }, + "required": ["email"], + } + }, + "required": ["customer"], + } + + +def test_tool_schema_required_must_exist_in_properties(): + with pytest.raises(ValueError, match="required contains unknown properties"): + ToolSchema( + type="object", + properties={"query": ToolSchema(type="string")}, + required=["missing"], + ) + + def test_synthetic_state_params_validates_initial_state_against_schema(): with pytest.raises(ValueError, match=r"\$\.files\.count must be an integer"): SyntheticStateParams(