Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 15 additions & 10 deletions acp_adapter/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,6 @@ async def prompt(

tool_call_ids: dict[str, Deque[str]] = defaultdict(deque)
tool_call_meta: dict[str, dict[str, Any]] = {}
previous_approval_cb = None

if conn:
tool_progress_cb = make_tool_progress_cb(conn, session_id, loop, tool_call_ids, tool_call_meta)
Expand All @@ -517,10 +516,16 @@ async def prompt(
agent.step_callback = step_cb
agent.message_callback = message_cb

# Install the per-session approval callback into the current asyncio
# task's context. ``terminal_tool._approval_callback_var`` is a
# ``ContextVar``, so concurrent ACP sessions in the same process each
# see their own callback without stomping on each other. Note that
# ``loop.run_in_executor`` does NOT carry the caller's context into the
# worker thread by itself; the executor call below bridges it explicitly
# with ``contextvars.copy_context().run``.
# No save/restore is needed as long as each ``prompt`` call runs in its
# own task context: the context copy holding this ``set`` is discarded
# when that task finishes.
if approval_cb:
try:
from tools import terminal_tool as _terminal_tool
previous_approval_cb = getattr(_terminal_tool, "_approval_callback", None)
_terminal_tool.set_approval_callback(approval_cb)
except Exception:
logger.debug("Could not set ACP approval callback", exc_info=True)
Expand All @@ -536,16 +541,16 @@ def _run_agent() -> dict:
except Exception as e:
logger.exception("Agent error in session %s", session_id)
return {"final_response": f"Error: {e}", "messages": state.history}
finally:
if approval_cb:
try:
from tools import terminal_tool as _terminal_tool
_terminal_tool.set_approval_callback(previous_approval_cb)
except Exception:
logger.debug("Could not restore approval callback", exc_info=True)

try:
result = await loop.run_in_executor(_executor, _run_agent)
# Copy the current asyncio task's context and run the agent inside
# it so per-session ContextVar state (e.g. the approval callback
# installed above via set_approval_callback) is visible to tool code
# executing on the worker thread. ``loop.run_in_executor`` does NOT
# propagate contextvars on its own.
import contextvars as _ctxvars
_ctx = _ctxvars.copy_context()
result = await loop.run_in_executor(_executor, lambda: _ctx.run(_run_agent))
except Exception:
logger.exception("Executor error for session %s", session_id)
return PromptResponse(stop_reason="end_turn")
Expand Down
154 changes: 154 additions & 0 deletions tests/acp/test_concurrent_approval_isolation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
"""Regression tests for GHSA-qg5c-hvr5-hjgr.

Before the fix, ``tools.terminal_tool._approval_callback`` was a module-global.
When two ACP sessions overlapped in the same process, session B's
``set_approval_callback`` overwrote session A's — so session A's
dangerous-command approval could be routed through session B's callback
(and vice versa).

The fix stores the callback in a ``ContextVar`` that each asyncio task
gets its own copy of, and ACP's ``prompt`` handler wraps the executor call
with ``contextvars.copy_context().run(...)`` so the per-session callback
survives the hop into the worker thread.

These tests exercise the primitive directly without spinning up a full
``HermesACPAgent`` — they verify that:

1. Two concurrent asyncio tasks can each set ``_approval_callback_var`` to
a distinct session-specific callback and each see their own value.
2. The value is still visible from inside a ``run_in_executor`` worker
thread when the caller uses ``copy_context().run``.
3. The raw ``run_in_executor`` path without ``copy_context`` does NOT
propagate contextvars — this is the asyncio contract we rely on the
ACP adapter to bridge.
"""

import asyncio
import contextvars

import pytest

from tools import terminal_tool as tt


async def _session(session_id: str, overlap_delay: float, observed: dict):
    """Play the role of one ACP session and record what its worker sees.

    The coroutine registers a callback unique to ``session_id`` through the
    public ``set_approval_callback`` API, sleeps for ``overlap_delay`` seconds
    so sibling sessions get a chance to register theirs, and then invokes the
    callback from an executor thread entered via ``copy_context().run`` (the
    same bridging pattern the ACP adapter uses).

    The worker's result is stored in ``observed[session_id]``: the callback's
    return value, or ``None`` when no callback is visible in the worker.
    """
    def approval_cb(command, description, **_):
        return f"approval-from-{session_id}"

    def _probe_callback():
        # Runs on the executor thread, inside the snapshot taken below.
        registered = tt._approval_callback_var.get()
        if not registered:
            return None
        return registered("rm -rf /", "dangerous")

    tt.set_approval_callback(approval_cb)
    await asyncio.sleep(overlap_delay)

    # Snapshot after the sleep so the copy reflects this task's context at
    # the moment the work is handed to the thread.
    snapshot = contextvars.copy_context()
    verdict = await asyncio.get_running_loop().run_in_executor(
        None, snapshot.run, _probe_callback
    )
    observed[session_id] = verdict


class TestConcurrentACPApprovalIsolation:
    """Regression guard for cross-session approval callback confusion."""

    def test_concurrent_sessions_see_their_own_callback(self):
        """Two overlapping ACP sessions each observe their own callback.

        Session A starts first but sleeps longer, so by the time it reads
        its callback, session B has already registered its own. Before
        the ContextVar fix, both sessions would observe whichever callback
        was set most recently in the module-global slot.
        """
        observed: dict = {}

        async def main():
            # ``gather`` wraps each coroutine in its own task, and each task
            # runs in its own copy of the context.
            await asyncio.gather(
                _session("A-cd0fa01e", 0.05, observed),
                _session("B-cc2f5ce8", 0.02, observed),
            )

        asyncio.run(main())

        assert observed["A-cd0fa01e"] == "approval-from-A-cd0fa01e"
        assert observed["B-cc2f5ce8"] == "approval-from-B-cc2f5ce8"

    def test_callback_visible_through_run_in_executor_with_copy_context(self):
        """``copy_context().run`` propagates the callback into the worker thread."""
        async def runner():
            def cb(cmd, desc, **_):
                return "approved"

            tt.set_approval_callback(cb)

            loop = asyncio.get_running_loop()
            # Snapshot taken after the set, so the copy carries ``cb``.
            ctx = contextvars.copy_context()

            def _worker():
                got = tt._approval_callback_var.get()
                return got("x", "y") if got else None

            return await loop.run_in_executor(None, lambda: ctx.run(_worker))

        assert asyncio.run(runner()) == "approved"

    def test_set_approval_callback_is_context_scoped(self):
        """A child task's ``set_approval_callback`` does not leak into its parent.

        This is the asyncio-level guarantee the ACP fix relies on: a child
        task's ``ContextVar.set`` mutates only the child's context copy.
        """
        observed: dict = {}

        async def child():
            def cb(cmd, desc, **_):
                return "child"
            tt.set_approval_callback(cb)
            observed["child"] = tt._approval_callback_var.get()("x", "y")

        async def main():
            # ``asyncio.run`` starts this task from a copy of the ambient
            # context, which may already hold a callback registered earlier
            # in this thread. Clear the slot in this task's own copy so the
            # assertions below do not depend on test ordering; the reset
            # does not escape this task.
            tt.set_approval_callback(None)
            # Parent sees no callback
            observed["parent_before"] = tt._approval_callback_var.get()
            await asyncio.create_task(child())
            # Parent still sees no callback after child completes
            observed["parent_after"] = tt._approval_callback_var.get()

        asyncio.run(main())

        assert observed["parent_before"] is None
        assert observed["child"] == "child"
        assert observed["parent_after"] is None


class TestRunInExecutorContextContract:
    """Document the asyncio contract the ACP adapter relies on."""

    def test_run_in_executor_without_copy_context_does_not_propagate(self):
        """A bare ``run_in_executor`` call does not carry the task's contextvars.

        The worker is handed ``probe.get`` directly, with no
        ``copy_context().run`` wrapper, and is expected to observe the
        variable's default while the task itself still sees the value it set.
        This pins the standard-library behaviour that makes an explicit
        ``copy_context().run`` wrapper necessary when handing work to an
        executor.

        This is a contract check only: it never calls the ACP adapter, so it
        cannot by itself detect the wrapper being removed from the adapter.
        """
        probe: contextvars.ContextVar[str] = contextvars.ContextVar(
            "probe", default="unset"
        )

        async def runner():
            probe.set("set-in-task")
            loop = asyncio.get_running_loop()
            in_worker = await loop.run_in_executor(None, probe.get)
            # Read the task-side value as well, so "not propagated" can be
            # told apart from "never set".
            return probe.get(), in_worker

        in_task, in_worker = asyncio.run(runner())

        assert in_task == "set-in-task"
        # Worker thread does not inherit the task's context
        assert in_worker == "unset"
42 changes: 33 additions & 9 deletions tools/terminal_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import atexit
import shutil
import subprocess
from contextvars import ContextVar
from pathlib import Path
from typing import Optional, Dict, Any, List

Expand Down Expand Up @@ -116,20 +117,42 @@ def _check_disk_usage_warning():
# so prompts route through prompt_toolkit's event loop.
# _sudo_password_callback() -> str (return password or "" to skip)
# _approval_callback(command, description) -> str ("once"/"session"/"always"/"deny")
_sudo_password_callback = None
_approval_callback = None
#
# These callbacks are stored in ``ContextVar`` so that concurrent sessions
# hosted in a single process (e.g. multiple ACP sessions in one Hermes ACP
# adapter) each see their own callback. A module-global would let one
# session's callback overwrite another's while both are still running, routing
# dangerous-command approval prompts to the wrong editor/session context.
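# Each ``asyncio`` task runs in its own copy of the context it was created
# from, so per-task isolation is automatic. Worker threads are different:
# ``loop.run_in_executor`` does not carry the caller's context across, so
# callers must enter it explicitly (``contextvars.copy_context().run``) for
# the callback to be visible there. CLI callers that set the callback once
# at startup keep working as long as the tool code runs in that same
# context; a thread started without the caller's context sees the default
# (``None``).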
_sudo_password_callback_var: ContextVar = ContextVar(
"_sudo_password_callback", default=None
)
_approval_callback_var: ContextVar = ContextVar(
"_approval_callback", default=None
)


def set_sudo_password_callback(cb):
"""Register a callback for sudo password prompts (used by CLI)."""
global _sudo_password_callback
_sudo_password_callback = cb
_sudo_password_callback_var.set(cb)


def set_approval_callback(cb):
"""Register a callback for dangerous command approval prompts (used by CLI)."""
global _approval_callback
_approval_callback = cb
_approval_callback_var.set(cb)


def _get_sudo_password_callback():
    """Look up the sudo password callback registered in the active context.

    Reads ``_sudo_password_callback_var`` and returns whatever it currently
    holds for the calling context.
    """
    current = _sudo_password_callback_var.get()
    return current


def _get_approval_callback():
    """Look up the approval callback registered in the active context.

    Reads ``_approval_callback_var`` and returns whatever it currently holds
    for the calling context.
    """
    current = _approval_callback_var.get()
    return current

# =============================================================================
# Dangerous Command Approval System
Expand All @@ -144,7 +167,7 @@ def set_approval_callback(cb):
def _check_all_guards(command: str, env_type: str) -> dict:
"""Delegate to consolidated guard (tirith + dangerous cmd) with CLI callback."""
return _check_all_guards_impl(command, env_type,
approval_callback=_approval_callback)
approval_callback=_approval_callback_var.get())


# Allowlist: characters that can legitimately appear in directory paths.
Expand Down Expand Up @@ -219,9 +242,10 @@ def _prompt_for_sudo_password(timeout_seconds: int = 45) -> str:
import sys

# Use the registered callback when available (prompt_toolkit-compatible)
if _sudo_password_callback is not None:
_cb = _sudo_password_callback_var.get()
if _cb is not None:
try:
return _sudo_password_callback() or ""
return _cb() or ""
except Exception:
return ""

Expand Down
Loading