Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 130 additions & 59 deletions environments/agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,19 @@ def resize_tool_pool(max_workers: int):
old_executor.shutdown(wait=False)
logger.info("Tool thread pool resized to %d workers", max_workers)


logger = logging.getLogger(__name__)


@dataclass
class ToolError:
"""Record of a tool execution error during the agent loop."""

turn: int # Which turn the error occurred on
tool_name: str # Which tool was called
arguments: str # The arguments passed (truncated)
error: str # The error message
tool_result: str # The raw result returned to the model
turn: int # Which turn the error occurred on
tool_name: str # Which tool was called
arguments: str # The arguments passed (truncated)
error: str # The error message
tool_result: str # The raw result returned to the model


@dataclass
Expand Down Expand Up @@ -179,7 +180,9 @@ async def run(self, messages: List[Dict[str, Any]]) -> AgentResult:
tool_errors: List[ToolError] = []

# Per-loop TodoStore for the todo tool (ephemeral, dies with the loop)
from tools.todo_tool import TodoStore, todo_tool as _todo_tool
from tools.todo_tool import TodoStore
from tools.todo_tool import todo_tool as _todo_tool

_todo_store = TodoStore()

# Extract user task from first user message for browser_snapshot context
Expand Down Expand Up @@ -214,6 +217,12 @@ async def run(self, messages: List[Dict[str, Any]]) -> AgentResult:
# Inject extra_body for provider-specific params (e.g., OpenRouter
# provider preferences like banned/preferred providers, transforms)
if self.extra_body:
# Custom atropos extension: inhibit sending tools= to the server
# (Used to bypass buggy server-side tool parsing in vLLM 0.6.5)
if self.extra_body.get("atropos_inhibit_tools"):
chat_kwargs.pop("tools", None)
chat_kwargs["tool_choice"] = "none"

chat_kwargs["extra_body"] = self.extra_body

# Make the API call -- standard OpenAI spec
Expand All @@ -222,7 +231,9 @@ async def run(self, messages: List[Dict[str, Any]]) -> AgentResult:
response = await self.server.chat_completion(**chat_kwargs)
except Exception as e:
api_elapsed = _time.monotonic() - api_start
logger.error("API call failed on turn %d (%.1fs): %s", turn + 1, api_elapsed, e)
logger.error(
"API call failed on turn %d (%.1fs): %s", turn + 1, api_elapsed, e
)
return AgentResult(
messages=messages,
managed_state=self._get_managed_state(),
Expand All @@ -235,7 +246,9 @@ async def run(self, messages: List[Dict[str, Any]]) -> AgentResult:
api_elapsed = _time.monotonic() - api_start

if not response or not response.choices:
logger.warning("Empty response on turn %d (api=%.1fs)", turn + 1, api_elapsed)
logger.warning(
"Empty response on turn %d (api=%.1fs)", turn + 1, api_elapsed
)
return AgentResult(
messages=messages,
managed_state=self._get_managed_state(),
Expand All @@ -246,6 +259,11 @@ async def run(self, messages: List[Dict[str, Any]]) -> AgentResult:
)

assistant_msg = response.choices[0].message
logger.info(
"\n--- [TURN %d] Assistant Response ---\n%s\n----------------------------------",
turn + 1,
assistant_msg.content,
)

# Extract reasoning content from the response (all provider formats)
reasoning = _extract_reasoning_from_message(assistant_msg)
Expand All @@ -261,36 +279,44 @@ async def run(self, messages: List[Dict[str, Any]]) -> AgentResult:
not assistant_msg.tool_calls
and assistant_msg.content
and self.tool_schemas
and "<tool_call>" in (assistant_msg.content or "")
and (
"<tool_code>" in assistant_msg.content
or "<tool_call>" in assistant_msg.content
)
):
try:
from environments.tool_call_parsers import get_parser
fallback_parser = get_parser("hermes")
parsed_content, parsed_calls = fallback_parser.parse(
from model_tools import parse_tool_calls_from_text

parsed_calls_dicts, parsed_content = parse_tool_calls_from_text(
assistant_msg.content
)
if parsed_calls:
assistant_msg.tool_calls = parsed_calls
if parsed_calls_dicts:
assistant_msg.tool_calls = parsed_calls_dicts
if parsed_content is not None:
assistant_msg.content = parsed_content
logger.debug(
"Fallback parser extracted %d tool calls from raw content",
len(parsed_calls),
len(parsed_calls_dicts),
)
except Exception:
except Exception as e:
logger.error("Fallback parser error: %s", e)
pass # Fall through to no tool calls

if assistant_msg.tool_calls:
# Normalize tool calls to dicts β€” they may come as objects
# Normalize tool calls to dicts - they may come as objects
# (OpenAI API) or dicts (vLLM ToolCallTranslator).
def _tc_to_dict(tc):
if isinstance(tc, dict):
return {
"id": tc.get("id", f"call_{uuid.uuid4().hex[:8]}"),
"type": "function",
"function": {
"name": tc.get("function", {}).get("name", tc.get("name", "")),
"arguments": tc.get("function", {}).get("arguments", tc.get("arguments", "{}")),
"name": tc.get("function", {}).get(
"name", tc.get("name", "")
),
"arguments": tc.get("function", {}).get(
"arguments", tc.get("arguments", "{}")
),
},
}
return {
Expand Down Expand Up @@ -321,8 +347,12 @@ def _tc_to_dict(tc):
for tc in assistant_msg.tool_calls:
# Handle both object (OpenAI) and dict (vLLM) formats
if isinstance(tc, dict):
tool_name = tc.get("function", {}).get("name", tc.get("name", ""))
tool_args_raw = tc.get("function", {}).get("arguments", tc.get("arguments", "{}"))
tool_name = tc.get("function", {}).get(
"name", tc.get("name", "")
)
tool_args_raw = tc.get("function", {}).get(
"arguments", tc.get("arguments", "{}")
)
else:
tool_name = tc.function.name
tool_args_raw = tc.function.arguments
Expand All @@ -335,15 +365,19 @@ def _tc_to_dict(tc):
f"Available tools: {sorted(self.valid_tool_names)}"
}
)
tool_errors.append(ToolError(
turn=turn + 1, tool_name=tool_name,
arguments=tool_args_raw[:200],
error=f"Unknown tool '{tool_name}'",
tool_result=tool_result,
))
tool_errors.append(
ToolError(
turn=turn + 1,
tool_name=tool_name,
arguments=tool_args_raw[:200],
error=f"Unknown tool '{tool_name}'",
tool_result=tool_result,
)
)
logger.warning(
"Model called unknown tool '%s' on turn %d",
tool_name, turn + 1,
tool_name,
turn + 1,
)
else:
# Parse arguments
Expand All @@ -352,17 +386,23 @@ def _tc_to_dict(tc):
except json.JSONDecodeError as e:
args = None
tool_result = json.dumps(
{"error": f"Invalid JSON in tool arguments: {e}. Please retry with valid JSON."}
{
"error": f"Invalid JSON in tool arguments: {e}. Please retry with valid JSON."
}
)
tool_errors.append(
ToolError(
turn=turn + 1,
tool_name=tool_name,
arguments=tool_args_raw[:200],
error=f"Invalid JSON: {e}",
tool_result=tool_result,
)
)
tool_errors.append(ToolError(
turn=turn + 1, tool_name=tool_name,
arguments=tool_args_raw[:200],
error=f"Invalid JSON: {e}",
tool_result=tool_result,
))
logger.warning(
"Invalid JSON in tool call arguments for '%s': %s",
tool_name, tool_args_raw[:200],
tool_name,
tool_args_raw[:200],
)

# Dispatch tool only if arguments parsed successfully
Expand All @@ -372,7 +412,9 @@ def _tc_to_dict(tc):
backend = os.getenv("TERMINAL_ENV", "local")
cmd_preview = args.get("command", "")[:80]
logger.info(
"[%s] $ %s", self.task_id[:8], cmd_preview,
"[%s] $ %s",
self.task_id[:8],
cmd_preview,
)

tool_submit_time = _time.monotonic()
Expand All @@ -386,10 +428,18 @@ def _tc_to_dict(tc):
)
tool_elapsed = _time.monotonic() - tool_submit_time
elif tool_name == "memory":
tool_result = json.dumps({"error": "Memory is not available in RL environments."})
tool_result = json.dumps(
{
"error": "Memory is not available in RL environments."
}
)
tool_elapsed = _time.monotonic() - tool_submit_time
elif tool_name == "session_search":
tool_result = json.dumps({"error": "Session search is not available in RL environments."})
tool_result = json.dumps(
{
"error": "Session search is not available in RL environments."
}
)
tool_elapsed = _time.monotonic() - tool_submit_time
else:
# Run tool calls in a thread pool so backends that
Expand All @@ -401,7 +451,9 @@ def _tc_to_dict(tc):
tool_result = await loop.run_in_executor(
_tool_executor,
lambda: handle_function_call(
_tn, _ta, task_id=_tid,
_tn,
_ta,
task_id=_tid,
user_task=_user_task,
),
)
Expand All @@ -412,22 +464,32 @@ def _tc_to_dict(tc):
if tool_elapsed > 30:
logger.warning(
"[%s] turn %d: %s took %.1fs (pool queue=%d)",
self.task_id[:8], turn + 1, tool_name,
tool_elapsed, pool_active,
self.task_id[:8],
turn + 1,
tool_name,
tool_elapsed,
pool_active,
)
except Exception as e:
tool_result = json.dumps(
{"error": f"Tool execution failed: {type(e).__name__}: {str(e)}"}
{
"error": f"Tool execution failed: {type(e).__name__}: {str(e)}"
}
)
tool_errors.append(
ToolError(
turn=turn + 1,
tool_name=tool_name,
arguments=tool_args_raw[:200],
error=f"{type(e).__name__}: {str(e)}",
tool_result=tool_result,
)
)
tool_errors.append(ToolError(
turn=turn + 1, tool_name=tool_name,
arguments=tool_args_raw[:200],
error=f"{type(e).__name__}: {str(e)}",
tool_result=tool_result,
))
logger.error(
"Tool '%s' execution failed on turn %d: %s",
tool_name, turn + 1, e,
tool_name,
turn + 1,
e,
)

# Also check if the tool returned an error in its JSON result
Expand All @@ -437,12 +499,15 @@ def _tc_to_dict(tc):
err = result_data.get("error")
exit_code = result_data.get("exit_code")
if err and exit_code and exit_code < 0:
tool_errors.append(ToolError(
turn=turn + 1, tool_name=tool_name,
arguments=tool_args_raw[:200],
error=str(err),
tool_result=tool_result[:500],
))
tool_errors.append(
ToolError(
turn=turn + 1,
tool_name=tool_name,
arguments=tool_args_raw[:200],
error=str(err),
tool_result=tool_result[:500],
)
)
except (json.JSONDecodeError, TypeError):
pass

Expand All @@ -459,8 +524,11 @@ def _tc_to_dict(tc):
turn_elapsed = _time.monotonic() - turn_start
logger.info(
"[%s] turn %d: api=%.1fs, %d tools, turn_total=%.1fs",
self.task_id[:8], turn + 1, api_elapsed,
len(assistant_msg.tool_calls), turn_elapsed,
self.task_id[:8],
turn + 1,
api_elapsed,
len(assistant_msg.tool_calls),
turn_elapsed,
)

else:
Expand All @@ -476,7 +544,10 @@ def _tc_to_dict(tc):
turn_elapsed = _time.monotonic() - turn_start
logger.info(
"[%s] turn %d: api=%.1fs, no tools (finished), turn_total=%.1fs",
self.task_id[:8], turn + 1, api_elapsed, turn_elapsed,
self.task_id[:8],
turn + 1,
api_elapsed,
turn_elapsed,
)

return AgentResult(
Expand Down
23 changes: 23 additions & 0 deletions environments/code_debug_env/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# CodeDebugEnv (hermes-agent implementation)

A multi-turn RL environment for training and evaluating Hermes agents on code debugging tasks (e.g., HumanEvalPack).

## Features
- **Real Tool Execution**: The agent has access to `terminal`, `read_file`, `write_file`, and `patch`.
- **Execution-based Scoring**: Rewards are calculated based on the actual test pass rate after the agent's fixes.
- **Universal Tool Strategy**: An architecture that stabilizes vLLM inference by bypassing server-side tool-call parsing in favor of robust client-side parsing of the model's raw text output.

## Configuration
The environment is configured via `default.yaml`. Key settings:
- `atropos_inhibit_tools: True`: Bypasses vLLM's internal tool parser to prevent 400/500 errors.
- `system_prompt`: Contains the manual tool documentation for the client-side parser.

## Usage
Run the environment in process mode:
```bash
/opt/conda/envs/hermes_conda/bin/python environments/code_debug_env/code_debug_env.py process \
--config environments/code_debug_env/default.yaml
```

## Stability Notes
We use a custom `parse_tool_calls_from_text` function in `model_tools.py` to extract tool calls from `<tool_code>` tags in the model's raw text. This tolerates varied model output formats and is more reliable than the server-side tool parsing in vLLM 0.6.5, which can return 400/500 errors on malformed calls.
Empty file.
Loading