Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
252 changes: 252 additions & 0 deletions tests/tools/test_mcp_circuit_breaker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,252 @@
"""Tests for MCP tool-handler circuit-breaker recovery.

The circuit breaker in ``tools/mcp_tool.py`` is intended to short-circuit
calls to an MCP server that has failed ``_CIRCUIT_BREAKER_THRESHOLD``
consecutive times, then *transition back to a usable state* once the
server has had time to recover (or an explicit reconnect succeeds).

The original implementation only had two states β€” closed and open β€” with
no mechanism to transition back to closed, so a tripped breaker stayed
tripped for the lifetime of the process. These tests lock in the
half-open / cooldown / reconnect-resets-breaker behavior that fixes
that.
"""
import json
from unittest.mock import MagicMock

import pytest


pytest.importorskip("mcp.client.auth.oauth2")


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _install_stub_server(mcp_tool_module, name: str, call_tool_impl):
"""Install a fake MCP server in the module's registry.

``call_tool_impl`` is an async function stored at ``session.call_tool``
(it's what the tool handler invokes).
"""
server = MagicMock()
server.name = name
session = MagicMock()
session.call_tool = call_tool_impl
server.session = session
server._reconnect_event = MagicMock()
server._ready = MagicMock()
server._ready.is_set.return_value = True

mcp_tool_module._servers[name] = server
mcp_tool_module._server_error_counts.pop(name, None)
if hasattr(mcp_tool_module, "_server_breaker_opened_at"):
mcp_tool_module._server_breaker_opened_at.pop(name, None)
return server


def _cleanup(mcp_tool_module, name: str) -> None:
mcp_tool_module._servers.pop(name, None)
mcp_tool_module._server_error_counts.pop(name, None)
if hasattr(mcp_tool_module, "_server_breaker_opened_at"):
mcp_tool_module._server_breaker_opened_at.pop(name, None)


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------


def test_circuit_breaker_half_opens_after_cooldown(monkeypatch, tmp_path):
"""After a tripped breaker's cooldown elapses, the *next* call must
actually execute against the session (half-open probe). When the
probe succeeds, the breaker resets to fully closed.
"""
monkeypatch.setenv("HERMES_HOME", str(tmp_path))

from tools import mcp_tool
from tools.mcp_tool import _make_tool_handler

call_count = {"n": 0}

async def _call_tool_success(*a, **kw):
call_count["n"] += 1
result = MagicMock()
result.isError = False
block = MagicMock()
block.text = "ok"
result.content = [block]
result.structuredContent = None
return result

_install_stub_server(mcp_tool, "srv", _call_tool_success)
mcp_tool._ensure_mcp_loop()

try:
# Trip the breaker by setting the count at/above threshold and
# stamping the open-time to "now".
mcp_tool._server_error_counts["srv"] = mcp_tool._CIRCUIT_BREAKER_THRESHOLD
fake_now = [1000.0]

def _fake_monotonic():
return fake_now[0]

monkeypatch.setattr(mcp_tool.time, "monotonic", _fake_monotonic)
# The breaker-open timestamp dict is introduced by the fix; on
# a pre-fix build it won't exist, which will cause the test to
# fail at the .get() inside the gate (correct β€” the fix is
# required for this state to be tracked at all).
if hasattr(mcp_tool, "_server_breaker_opened_at"):
mcp_tool._server_breaker_opened_at["srv"] = fake_now[0]
cooldown = getattr(mcp_tool, "_CIRCUIT_BREAKER_COOLDOWN_SEC", 60.0)

handler = _make_tool_handler("srv", "tool1", 10.0)

# Before cooldown: must short-circuit (no session call).
result = handler({})
parsed = json.loads(result)
assert "error" in parsed, parsed
assert "unreachable" in parsed["error"].lower()
assert call_count["n"] == 0, (
"breaker should short-circuit before cooldown elapses"
)

# Advance past cooldown β†’ next call is a half-open probe that
# actually hits the session.
fake_now[0] += cooldown + 1.0

result = handler({})
parsed = json.loads(result)
assert parsed.get("result") == "ok", parsed
assert call_count["n"] == 1, "half-open probe should invoke session"

# On probe success the breaker must close (count reset to 0).
assert mcp_tool._server_error_counts.get("srv", 0) == 0
finally:
_cleanup(mcp_tool, "srv")


def test_circuit_breaker_reopens_on_probe_failure(monkeypatch, tmp_path):
"""If the half-open probe fails, the breaker must re-arm the
cooldown (not let every subsequent call through).
"""
monkeypatch.setenv("HERMES_HOME", str(tmp_path))

from tools import mcp_tool
from tools.mcp_tool import _make_tool_handler

call_count = {"n": 0}

async def _call_tool_fails(*a, **kw):
call_count["n"] += 1
raise RuntimeError("still broken")

_install_stub_server(mcp_tool, "srv", _call_tool_fails)
mcp_tool._ensure_mcp_loop()

try:
mcp_tool._server_error_counts["srv"] = mcp_tool._CIRCUIT_BREAKER_THRESHOLD
fake_now = [1000.0]

def _fake_monotonic():
return fake_now[0]

monkeypatch.setattr(mcp_tool.time, "monotonic", _fake_monotonic)
if hasattr(mcp_tool, "_server_breaker_opened_at"):
mcp_tool._server_breaker_opened_at["srv"] = fake_now[0]
cooldown = getattr(mcp_tool, "_CIRCUIT_BREAKER_COOLDOWN_SEC", 60.0)

handler = _make_tool_handler("srv", "tool1", 10.0)

# Advance past cooldown, run probe, expect failure.
fake_now[0] += cooldown + 1.0
result = handler({})
parsed = json.loads(result)
assert "error" in parsed
assert call_count["n"] == 1, "probe should invoke session once"

# The probe failure must have re-armed the cooldown β€” another
# immediate call should short-circuit, not invoke session again.
result = handler({})
parsed = json.loads(result)
assert "unreachable" in parsed.get("error", "").lower()
assert call_count["n"] == 1, (
"breaker should re-open and block further calls after probe failure"
)
finally:
_cleanup(mcp_tool, "srv")


def test_circuit_breaker_cleared_on_reconnect(monkeypatch, tmp_path):
"""When the auth-recovery path successfully reconnects the server,
the breaker should be cleared so subsequent calls aren't gated on a
stale failure count β€” even if the post-reconnect retry itself fails.

This locks in the fix-#2 contract: a successful reconnect is
sufficient evidence that the server is viable again. Under the old
implementation, reset only happened on retry *success*, so a
reconnect+retry-failure left the counter pinned above threshold
forever.
"""
monkeypatch.setenv("HERMES_HOME", str(tmp_path))

from tools import mcp_tool
from tools.mcp_oauth_manager import get_manager, reset_manager_for_tests
from mcp.client.auth import OAuthFlowError

reset_manager_for_tests()

async def _call_tool_unused(*a, **kw): # pragma: no cover
raise AssertionError("session.call_tool should not be reached in this test")

_install_stub_server(mcp_tool, "srv", _call_tool_unused)
mcp_tool._ensure_mcp_loop()

# Open the breaker well above threshold, with a recent open-time so
# it would short-circuit everything without a reset.
mcp_tool._server_error_counts["srv"] = mcp_tool._CIRCUIT_BREAKER_THRESHOLD + 2
if hasattr(mcp_tool, "_server_breaker_opened_at"):
import time as _time
mcp_tool._server_breaker_opened_at["srv"] = _time.monotonic()

# Force handle_401 to claim recovery succeeded.
mgr = get_manager()

async def _h401(name, token=None):
return True

monkeypatch.setattr(mgr, "handle_401", _h401)

try:
# Retry fails *after* the successful reconnect. Under the old
# implementation this bumps an already-tripped counter even
# higher. Under fix #2 the reset happens on successful
# reconnect, and the post-retry bump only raises the fresh
# count to 1 β€” still below threshold.
def _retry_call():
raise OAuthFlowError("still failing post-reconnect")

result = mcp_tool._handle_auth_error_and_retry(
"srv",
OAuthFlowError("initial"),
_retry_call,
"tools/call test",
)
# The call as a whole still surfaces needs_reauth because the
# retry itself didn't succeed, but the breaker state must
# reflect the successful reconnect.
assert result is not None
parsed = json.loads(result)
assert parsed.get("needs_reauth") is True, parsed

# Post-reconnect count was reset to 0, then the failing retry
# bumped it to exactly 1 β€” well below threshold.
count = mcp_tool._server_error_counts.get("srv", 0)
assert count < mcp_tool._CIRCUIT_BREAKER_THRESHOLD, (
f"successful reconnect must reset the breaker below threshold; "
f"got count={count}, threshold={mcp_tool._CIRCUIT_BREAKER_THRESHOLD}"
)
finally:
_cleanup(mcp_tool, "srv")
Loading
Loading