diff --git a/tests/tools/test_mcp_circuit_breaker.py b/tests/tools/test_mcp_circuit_breaker.py new file mode 100644 index 00000000000..0173fa52afe --- /dev/null +++ b/tests/tools/test_mcp_circuit_breaker.py @@ -0,0 +1,252 @@ +"""Tests for MCP tool-handler circuit-breaker recovery. + +The circuit breaker in ``tools/mcp_tool.py`` is intended to short-circuit +calls to an MCP server that has failed ``_CIRCUIT_BREAKER_THRESHOLD`` +consecutive times, then *transition back to a usable state* once the +server has had time to recover (or an explicit reconnect succeeds). + +The original implementation only had two states — closed and open — with +no mechanism to transition back to closed, so a tripped breaker stayed +tripped for the lifetime of the process. These tests lock in the +half-open / cooldown / reconnect-resets-breaker behavior that fixes +that. +""" +import json +from unittest.mock import MagicMock + +import pytest + + +pytest.importorskip("mcp.client.auth.oauth2") + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _install_stub_server(mcp_tool_module, name: str, call_tool_impl): + """Install a fake MCP server in the module's registry. + + ``call_tool_impl`` is an async function stored at ``session.call_tool`` + (it's what the tool handler invokes). 
+ """ + server = MagicMock() + server.name = name + session = MagicMock() + session.call_tool = call_tool_impl + server.session = session + server._reconnect_event = MagicMock() + server._ready = MagicMock() + server._ready.is_set.return_value = True + + mcp_tool_module._servers[name] = server + mcp_tool_module._server_error_counts.pop(name, None) + if hasattr(mcp_tool_module, "_server_breaker_opened_at"): + mcp_tool_module._server_breaker_opened_at.pop(name, None) + return server + + +def _cleanup(mcp_tool_module, name: str) -> None: + mcp_tool_module._servers.pop(name, None) + mcp_tool_module._server_error_counts.pop(name, None) + if hasattr(mcp_tool_module, "_server_breaker_opened_at"): + mcp_tool_module._server_breaker_opened_at.pop(name, None) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +def test_circuit_breaker_half_opens_after_cooldown(monkeypatch, tmp_path): + """After a tripped breaker's cooldown elapses, the *next* call must + actually execute against the session (half-open probe). When the + probe succeeds, the breaker resets to fully closed. + """ + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + + from tools import mcp_tool + from tools.mcp_tool import _make_tool_handler + + call_count = {"n": 0} + + async def _call_tool_success(*a, **kw): + call_count["n"] += 1 + result = MagicMock() + result.isError = False + block = MagicMock() + block.text = "ok" + result.content = [block] + result.structuredContent = None + return result + + _install_stub_server(mcp_tool, "srv", _call_tool_success) + mcp_tool._ensure_mcp_loop() + + try: + # Trip the breaker by setting the count at/above threshold and + # stamping the open-time to "now". 
+ mcp_tool._server_error_counts["srv"] = mcp_tool._CIRCUIT_BREAKER_THRESHOLD + fake_now = [1000.0] + + def _fake_monotonic(): + return fake_now[0] + + monkeypatch.setattr(mcp_tool.time, "monotonic", _fake_monotonic) + # The breaker-open timestamp dict is introduced by the fix; on + # a pre-fix build it won't exist and the gate never reopens, so + # the post-cooldown probe assertion fails (correct — the fix is + # required for this state to be tracked at all). + if hasattr(mcp_tool, "_server_breaker_opened_at"): + mcp_tool._server_breaker_opened_at["srv"] = fake_now[0] + cooldown = getattr(mcp_tool, "_CIRCUIT_BREAKER_COOLDOWN_SEC", 60.0) + + handler = _make_tool_handler("srv", "tool1", 10.0) + + # Before cooldown: must short-circuit (no session call). + result = handler({}) + parsed = json.loads(result) + assert "error" in parsed, parsed + assert "unreachable" in parsed["error"].lower() + assert call_count["n"] == 0, ( + "breaker should short-circuit before cooldown elapses" + ) + + # Advance past cooldown → next call is a half-open probe that + # actually hits the session. + fake_now[0] += cooldown + 1.0 + + result = handler({}) + parsed = json.loads(result) + assert parsed.get("result") == "ok", parsed + assert call_count["n"] == 1, "half-open probe should invoke session" + + # On probe success the breaker must close (count reset to 0). + assert mcp_tool._server_error_counts.get("srv", 0) == 0 + finally: + _cleanup(mcp_tool, "srv") + + +def test_circuit_breaker_reopens_on_probe_failure(monkeypatch, tmp_path): + """If the half-open probe fails, the breaker must re-arm the + cooldown (not let every subsequent call through). 
+ """ + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + + from tools import mcp_tool + from tools.mcp_tool import _make_tool_handler + + call_count = {"n": 0} + + async def _call_tool_fails(*a, **kw): + call_count["n"] += 1 + raise RuntimeError("still broken") + + _install_stub_server(mcp_tool, "srv", _call_tool_fails) + mcp_tool._ensure_mcp_loop() + + try: + mcp_tool._server_error_counts["srv"] = mcp_tool._CIRCUIT_BREAKER_THRESHOLD + fake_now = [1000.0] + + def _fake_monotonic(): + return fake_now[0] + + monkeypatch.setattr(mcp_tool.time, "monotonic", _fake_monotonic) + if hasattr(mcp_tool, "_server_breaker_opened_at"): + mcp_tool._server_breaker_opened_at["srv"] = fake_now[0] + cooldown = getattr(mcp_tool, "_CIRCUIT_BREAKER_COOLDOWN_SEC", 60.0) + + handler = _make_tool_handler("srv", "tool1", 10.0) + + # Advance past cooldown, run probe, expect failure. + fake_now[0] += cooldown + 1.0 + result = handler({}) + parsed = json.loads(result) + assert "error" in parsed + assert call_count["n"] == 1, "probe should invoke session once" + + # The probe failure must have re-armed the cooldown — another + # immediate call should short-circuit, not invoke session again. + result = handler({}) + parsed = json.loads(result) + assert "unreachable" in parsed.get("error", "").lower() + assert call_count["n"] == 1, ( + "breaker should re-open and block further calls after probe failure" + ) + finally: + _cleanup(mcp_tool, "srv") + + +def test_circuit_breaker_cleared_on_reconnect(monkeypatch, tmp_path): + """When the auth-recovery path successfully reconnects the server, + the breaker should be cleared so subsequent calls aren't gated on a + stale failure count — even if the post-reconnect retry itself fails. + + This locks in the fix-#2 contract: a successful reconnect is + sufficient evidence that the server is viable again. Under the old + implementation, reset only happened on retry *success*, so a + reconnect+retry-failure left the counter pinned above threshold + forever. 
+ """ + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + + from tools import mcp_tool + from tools.mcp_oauth_manager import get_manager, reset_manager_for_tests + from mcp.client.auth import OAuthFlowError + + reset_manager_for_tests() + + async def _call_tool_unused(*a, **kw): # pragma: no cover + raise AssertionError("session.call_tool should not be reached in this test") + + _install_stub_server(mcp_tool, "srv", _call_tool_unused) + mcp_tool._ensure_mcp_loop() + + # Open the breaker well above threshold, with a recent open-time so + # it would short-circuit everything without a reset. + mcp_tool._server_error_counts["srv"] = mcp_tool._CIRCUIT_BREAKER_THRESHOLD + 2 + if hasattr(mcp_tool, "_server_breaker_opened_at"): + import time as _time + mcp_tool._server_breaker_opened_at["srv"] = _time.monotonic() + + # Force handle_401 to claim recovery succeeded. + mgr = get_manager() + + async def _h401(name, token=None): + return True + + monkeypatch.setattr(mgr, "handle_401", _h401) + + try: + # Retry fails *after* the successful reconnect. Under the old + # implementation this bumps an already-tripped counter even + # higher. Under fix #2 the reset happens on successful + # reconnect, and the post-retry bump only raises the fresh + # count to 1 — still below threshold. + def _retry_call(): + raise OAuthFlowError("still failing post-reconnect") + + result = mcp_tool._handle_auth_error_and_retry( + "srv", + OAuthFlowError("initial"), + _retry_call, + "tools/call test", + ) + # The call as a whole still surfaces needs_reauth because the + # retry itself didn't succeed, but the breaker state must + # reflect the successful reconnect. + assert result is not None + parsed = json.loads(result) + assert parsed.get("needs_reauth") is True, parsed + + # Post-reconnect count was reset to 0, then the failing retry + # bumped it to exactly 1 — well below threshold. 
+ count = mcp_tool._server_error_counts.get("srv", 0) + assert count < mcp_tool._CIRCUIT_BREAKER_THRESHOLD, ( + f"successful reconnect must reset the breaker below threshold; " + f"got count={count}, threshold={mcp_tool._CIRCUIT_BREAKER_THRESHOLD}" + ) + finally: + _cleanup(mcp_tool, "srv") diff --git a/tools/mcp_tool.py b/tools/mcp_tool.py index e5e856d0bb5..5c4c0ab368e 100644 --- a/tools/mcp_tool.py +++ b/tools/mcp_tool.py @@ -1249,9 +1249,47 @@ async def shutdown(self): # _CIRCUIT_BREAKER_THRESHOLD consecutive failures, the handler returns # a "server unreachable" message that tells the model to stop retrying, # preventing the 90-iteration burn loop described in #10447. -# Reset to 0 on any successful call. +# +# State machine: +# closed — error count below threshold; all calls go through. +# open — threshold reached; calls short-circuit until the +# cooldown elapses. +# half-open — cooldown elapsed; the next call is a probe that +# actually hits the session. Probe success → closed. +# Probe failure → reopens (cooldown re-armed). +# +# ``_server_breaker_opened_at`` records the monotonic timestamp when +# the breaker most recently transitioned into the open state. Use the +# ``_bump_server_error`` / ``_reset_server_error`` helpers to mutate +# this state — they keep the count and timestamp in sync. _server_error_counts: Dict[str, int] = {} +_server_breaker_opened_at: Dict[str, float] = {} _CIRCUIT_BREAKER_THRESHOLD = 3 +_CIRCUIT_BREAKER_COOLDOWN_SEC = 60.0 + + +def _bump_server_error(server_name: str) -> None: + """Increment the consecutive-failure count for ``server_name``. + + When the count crosses :data:`_CIRCUIT_BREAKER_THRESHOLD`, stamp the + breaker-open timestamp so the cooldown clock starts (or re-starts, + for probe failures in the half-open state). 
+ """ + n = _server_error_counts.get(server_name, 0) + 1 + _server_error_counts[server_name] = n + if n >= _CIRCUIT_BREAKER_THRESHOLD: + _server_breaker_opened_at[server_name] = time.monotonic() + + +def _reset_server_error(server_name: str) -> None: + """Fully close the breaker for ``server_name``. + + Clears both the failure count and the breaker-open timestamp. Call + this on any unambiguous success signal (successful tool call, + successful reconnect, manual /mcp refresh). + """ + _server_error_counts[server_name] = 0 + _server_breaker_opened_at.pop(server_name, None) # --------------------------------------------------------------------------- # Auth-failure detection helpers (Task 6 of MCP OAuth consolidation) @@ -1391,15 +1429,25 @@ async def _recover(): break time.sleep(0.25) + # A successful OAuth recovery is independent evidence that the + # server is viable again, so close the circuit breaker here — + # not only on retry success. Without this, a reconnect + # followed by a failing retry would leave the breaker pinned + # above threshold forever (the retry-exception branch below + # bumps the count again). The post-reset retry still goes + # through _bump_server_error on failure, so a genuinely broken + # server will re-trip the breaker as normal. + _reset_server_error(server_name) + try: result = retry_call() try: parsed = json.loads(result) if "error" not in parsed: - _server_error_counts[server_name] = 0 + _reset_server_error(server_name) return result except (json.JSONDecodeError, TypeError): - _server_error_counts[server_name] = 0 + _reset_server_error(server_name) return result except Exception as retry_exc: logger.warning( @@ -1410,7 +1458,7 @@ async def _recover(): # No recovery available, or retry also failed: surface a structured # needs_reauth error. Bumps the circuit breaker so the model stops # retrying the tool. 
- _server_error_counts[server_name] = _server_error_counts.get(server_name, 0) + 1 + _bump_server_error(server_name) return json.dumps({ "error": ( f"MCP server '{server_name}' requires re-authentication. " @@ -1615,20 +1663,33 @@ def _handler(args: dict, **kwargs) -> str: # Circuit breaker: if this server has failed too many times # consecutively, short-circuit with a clear message so the model # stops retrying and uses alternative approaches (#10447). + # + # Once the cooldown elapses, the breaker transitions to + # half-open: we let the *next* call through as a probe. On + # success the success-path below resets the breaker; on + # failure the error paths below bump the count again, which + # re-stamps the open-time via _bump_server_error (re-arming + # the cooldown). if _server_error_counts.get(server_name, 0) >= _CIRCUIT_BREAKER_THRESHOLD: - return json.dumps({ - "error": ( - f"MCP server '{server_name}' is unreachable after " - f"{_CIRCUIT_BREAKER_THRESHOLD} consecutive failures. " - f"Do NOT retry this tool — use alternative approaches " - f"or ask the user to check the MCP server." - ) - }, ensure_ascii=False) + opened_at = _server_breaker_opened_at.get(server_name, 0.0) + age = time.monotonic() - opened_at + if age < _CIRCUIT_BREAKER_COOLDOWN_SEC: + remaining = max(1, int(_CIRCUIT_BREAKER_COOLDOWN_SEC - age)) + return json.dumps({ + "error": ( + f"MCP server '{server_name}' is unreachable after " + f"{_server_error_counts[server_name]} consecutive " + f"failures. Auto-retry available in ~{remaining}s. " + f"Do NOT retry this tool yet — use alternative " + f"approaches or ask the user to check the MCP server." + ) + }, ensure_ascii=False) + # Cooldown elapsed → fall through as a half-open probe. 
with _lock: server = _servers.get(server_name) if not server or not server.session: - _server_error_counts[server_name] = _server_error_counts.get(server_name, 0) + 1 + _bump_server_error(server_name) return json.dumps({ "error": f"MCP server '{server_name}' is not connected" }, ensure_ascii=False) @@ -1677,11 +1738,11 @@ def _call_once(): try: parsed = json.loads(result) if "error" in parsed: - _server_error_counts[server_name] = _server_error_counts.get(server_name, 0) + 1 + _bump_server_error(server_name) else: - _server_error_counts[server_name] = 0 # success — reset + _reset_server_error(server_name) # success — reset except (json.JSONDecodeError, TypeError): - _server_error_counts[server_name] = 0 # non-JSON = success + _reset_server_error(server_name) # non-JSON = success return result except InterruptedError: return _interrupted_call_result() @@ -1696,7 +1757,7 @@ def _call_once(): if recovered is not None: return recovered - _server_error_counts[server_name] = _server_error_counts.get(server_name, 0) + 1 + _bump_server_error(server_name) logger.error( "MCP tool %s/%s call failed: %s", server_name, tool_name, exc,