diff --git a/acp_adapter/server.py b/acp_adapter/server.py index 119a08685a3..a18c43c6f22 100644 --- a/acp_adapter/server.py +++ b/acp_adapter/server.py @@ -496,7 +496,6 @@ async def prompt( tool_call_ids: dict[str, Deque[str]] = defaultdict(deque) tool_call_meta: dict[str, dict[str, Any]] = {} - previous_approval_cb = None if conn: tool_progress_cb = make_tool_progress_cb(conn, session_id, loop, tool_call_ids, tool_call_meta) @@ -517,10 +516,16 @@ async def prompt( agent.step_callback = step_cb agent.message_callback = message_cb + # Install the per-session approval callback into the current asyncio + # task's context. Because ``terminal_tool._approval_callback_var`` is + # a ``ContextVar`` and the executor call below runs ``_run_agent`` in a + # copy of this task's context, concurrent ACP sessions in the same + # process each see their own callback without stomping on each other. + # No save/restore is needed: when this coroutine returns, the context + # snapshot holding the set is discarded. if approval_cb: try: from tools import terminal_tool as _terminal_tool - previous_approval_cb = getattr(_terminal_tool, "_approval_callback", None) _terminal_tool.set_approval_callback(approval_cb) except Exception: logger.debug("Could not set ACP approval callback", exc_info=True) @@ -536,16 +541,16 @@ def _run_agent() -> dict: except Exception as e: logger.exception("Agent error in session %s", session_id) return {"final_response": f"Error: {e}", "messages": state.history} - finally: - if approval_cb: - try: - from tools import terminal_tool as _terminal_tool - _terminal_tool.set_approval_callback(previous_approval_cb) - except Exception: - logger.debug("Could not restore approval callback", exc_info=True) try: - result = await loop.run_in_executor(_executor, _run_agent) + # Copy the current asyncio task's context and run the agent inside + # it so per-session ContextVar state (e.g. 
the approval callback + # installed above via set_approval_callback) is visible to tool code + # executing on the worker thread. ``loop.run_in_executor`` does NOT + # propagate contextvars on its own. + import contextvars as _ctxvars + _ctx = _ctxvars.copy_context() + result = await loop.run_in_executor(_executor, lambda: _ctx.run(_run_agent)) except Exception: logger.exception("Executor error for session %s", session_id) return PromptResponse(stop_reason="end_turn") diff --git a/tests/acp/test_concurrent_approval_isolation.py b/tests/acp/test_concurrent_approval_isolation.py new file mode 100644 index 00000000000..687b9324401 --- /dev/null +++ b/tests/acp/test_concurrent_approval_isolation.py @@ -0,0 +1,154 @@ +"""Regression tests for GHSA-qg5c-hvr5-hjgr. + +Before the fix, ``tools.terminal_tool._approval_callback`` was a module-global. +When two ACP sessions overlapped in the same process, session B's +``set_approval_callback`` overwrote session A's — so session A's +dangerous-command approval could be routed through session B's callback +(and vice versa). + +The fix stores the callback in a ``ContextVar`` that each asyncio task +gets its own copy of, and ACP's ``prompt`` handler wraps the executor call +with ``contextvars.copy_context().run(...)`` so the per-session callback +survives the hop into the worker thread. + +These tests exercise the primitive directly without spinning up a full +``HermesACPAgent`` — they verify that: + +1. Two concurrent asyncio tasks can each set ``_approval_callback_var`` to + a distinct session-specific callback and each see their own value. +2. The value is still visible from inside a ``run_in_executor`` worker + thread when the caller uses ``copy_context().run``. +3. The raw ``run_in_executor`` path without ``copy_context`` does NOT + propagate contextvars — this is the asyncio contract we rely on the + ACP adapter to bridge. 
+""" + +import asyncio +import contextvars + +import pytest + +from tools import terminal_tool as tt + + +async def _session(session_id: str, overlap_delay: float, observed: dict): + """Simulate an ACP session. + + 1. Registers a session-specific approval callback via the public + ``set_approval_callback`` API. + 2. Yields control so sibling tasks can install their own callbacks + and create a realistic overlap window. + 3. Runs a synchronous worker in a thread executor using + ``copy_context().run`` (mirrors the ACP adapter's pattern) and + records which callback identity the worker observes. + """ + def approval_cb(command, description, **_): + return f"approval-from-{session_id}" + + tt.set_approval_callback(approval_cb) + await asyncio.sleep(overlap_delay) + + loop = asyncio.get_running_loop() + ctx = contextvars.copy_context() + + def _in_worker(): + cb = tt._approval_callback_var.get() + return cb("rm -rf /", "dangerous") if cb else None + + observed[session_id] = await loop.run_in_executor( + None, lambda: ctx.run(_in_worker) + ) + + +class TestConcurrentACPApprovalIsolation: + """Regression guard for cross-session approval callback confusion.""" + + def test_concurrent_sessions_see_their_own_callback(self): + """Two overlapping ACP sessions each observe their own callback. + + Session A starts first but sleeps longer, so by the time it reads + its callback, session B has already registered its own. Before + the ContextVar fix, both sessions would observe whichever callback + was set most recently in the module-global slot. 
+ """ + observed: dict = {} + + async def main(): + await asyncio.gather( + _session("A-cd0fa01e", 0.05, observed), + _session("B-cc2f5ce8", 0.02, observed), + ) + + asyncio.run(main()) + + assert observed["A-cd0fa01e"] == "approval-from-A-cd0fa01e" + assert observed["B-cc2f5ce8"] == "approval-from-B-cc2f5ce8" + + def test_callback_visible_through_run_in_executor_with_copy_context(self): + """``copy_context().run`` propagates the callback into the worker thread.""" + async def runner(): + def cb(cmd, desc, **_): + return "approved" + + tt.set_approval_callback(cb) + + loop = asyncio.get_running_loop() + ctx = contextvars.copy_context() + + def _worker(): + got = tt._approval_callback_var.get() + return got("x", "y") if got else None + + return await loop.run_in_executor(None, lambda: ctx.run(_worker)) + + assert asyncio.run(runner()) == "approved" + + def test_set_approval_callback_is_context_scoped(self): + """A direct ``set_approval_callback`` call does not leak into the caller's context. + + This is the asyncio-level guarantee the ACP fix relies on: a child + task's ``ContextVar.set`` mutates only the child's context copy. 
+ """ + observed: dict = {} + + async def child(): + def cb(cmd, desc, **_): + return "child" + tt.set_approval_callback(cb) + observed["child"] = tt._approval_callback_var.get()("x", "y") + + async def main(): + # Parent sees no callback + observed["parent_before"] = tt._approval_callback_var.get() + await asyncio.create_task(child()) + # Parent still sees no callback after child completes + observed["parent_after"] = tt._approval_callback_var.get() + + asyncio.run(main()) + + assert observed["parent_before"] is None + assert observed["child"] == "child" + assert observed["parent_after"] is None + + +class TestRunInExecutorContextContract: + """Document the asyncio contract the ACP adapter relies on.""" + + def test_run_in_executor_without_copy_context_does_not_propagate(self): + """Without ``copy_context().run``, contextvars do NOT cross into the worker. + + This is the asyncio standard-library behavior, and it is why the ACP + adapter must keep the ``copy_context().run`` wrapper around + ``_run_agent``. The isolation tests above apply the same wrapper in + their own helpers instead of calling the adapter, so they do not + detect the adapter dropping it; this test only documents the contract. + """ + probe: contextvars.ContextVar = contextvars.ContextVar("probe", default="unset") + + async def runner(): + probe.set("set-in-task") + loop = asyncio.get_running_loop() + return await loop.run_in_executor(None, probe.get) + + # Worker thread does not inherit the task's context + assert asyncio.run(runner()) == "unset" diff --git a/tools/terminal_tool.py b/tools/terminal_tool.py index 7a7dc9c1a6b..1d1aaee71d4 100644 --- a/tools/terminal_tool.py +++ b/tools/terminal_tool.py @@ -41,6 +41,7 @@ import atexit import shutil import subprocess +from contextvars import ContextVar from pathlib import Path from typing import Optional, Dict, Any, List @@ -116,20 +117,42 @@ def _check_disk_usage_warning(): # so prompts route through prompt_toolkit's event loop. 
# _sudo_password_callback() -> str (return password or "" to skip) # _approval_callback(command, description) -> str ("once"/"session"/"always"/"deny") -_sudo_password_callback = None -_approval_callback = None +# +# These callbacks are stored in ``ContextVar`` so that concurrent sessions +# hosted in a single process (e.g. multiple ACP sessions in one Hermes ACP +# adapter) each see their own callback. A module-global would let one +# session's callback overwrite another's while both are still running, routing +# dangerous-command approval prompts to the wrong editor/session context. +# ``asyncio`` tasks each receive a copy of the caller's context, so per-task +# isolation is automatic. Threads launched via ``loop.run_in_executor`` do +# NOT inherit it; wrap the worker in ``contextvars.copy_context().run``. CLI +# callers that set and read the callback in the same context still work. +_sudo_password_callback_var: ContextVar = ContextVar( + "_sudo_password_callback", default=None +) +_approval_callback_var: ContextVar = ContextVar( + "_approval_callback", default=None +) def set_sudo_password_callback(cb): """Register a callback for sudo password prompts (used by CLI).""" - global _sudo_password_callback - _sudo_password_callback = cb + _sudo_password_callback_var.set(cb) def set_approval_callback(cb): """Register a callback for dangerous command approval prompts (used by CLI).""" - global _approval_callback - _approval_callback = cb + _approval_callback_var.set(cb) + + +def _get_sudo_password_callback(): + """Return the sudo password callback for the current context.""" + return _sudo_password_callback_var.get() + + +def _get_approval_callback(): + """Return the approval callback for the current context.""" + return _approval_callback_var.get() # ============================================================================= # Dangerous Command Approval System @@ -144,7 +167,7 @@ def set_approval_callback(cb): def _check_all_guards(command: str, env_type: str) -> dict: """Delegate to 
consolidated guard (tirith + dangerous cmd) with CLI callback.""" return _check_all_guards_impl(command, env_type, - approval_callback=_approval_callback) + approval_callback=_approval_callback_var.get()) # Allowlist: characters that can legitimately appear in directory paths. @@ -219,9 +242,10 @@ def _prompt_for_sudo_password(timeout_seconds: int = 45) -> str: import sys # Use the registered callback when available (prompt_toolkit-compatible) - if _sudo_password_callback is not None: + _cb = _sudo_password_callback_var.get() + if _cb is not None: try: - return _sudo_password_callback() or "" + return _cb() or "" except Exception: return ""