Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion agent/anthropic_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ def _common_betas_for_base_url(base_url: str | None) -> list[str]:
return _COMMON_BETAS


def build_anthropic_client(api_key: str, base_url: str = None, timeout: float = None):
def build_anthropic_client(api_key: str, base_url: str = None, timeout: Optional[float] = None):
"""Create an Anthropic client, auto-detecting setup-tokens vs API keys.

If *timeout* is provided it overrides the default 900s read timeout. The
Expand Down
11 changes: 9 additions & 2 deletions agent/auxiliary_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,13 @@
import time
from pathlib import Path # noqa: F401 — used by test mocks
from types import SimpleNamespace
from typing import Any, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

from openai import OpenAI

if TYPE_CHECKING:
from agent.gemini_native_adapter import GeminiNativeClient

from agent.credential_pool import load_pool
from hermes_cli.config import get_hermes_home
from hermes_constants import OPENROUTER_BASE_URL
Expand Down Expand Up @@ -810,7 +813,11 @@ def _read_codex_access_token() -> Optional[str]:
return None


def _resolve_api_key_provider() -> Tuple[Optional[OpenAI], Optional[str]]:
# TODO(refactor): This function has messy types and duplicated logic (pool vs direct creds).
# Ideal fix: (1) define an AuxiliaryClient Protocol both OpenAI/GeminiNativeClient satisfy,
# (2) return a NamedTuple or dataclass instead of raw tuple, (3) extract the repeated
# Gemini/Kimi/Copilot client-building into a helper.
def _resolve_api_key_provider() -> Tuple[Optional[Union[OpenAI, "GeminiNativeClient"]], Optional[str]]:
"""Try each API-key provider in PROVIDER_REGISTRY order.

Returns (client, model) for the first provider with usable runtime
Expand Down
19 changes: 17 additions & 2 deletions agent/credential_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
_save_auth_store,
_save_provider_state,
read_credential_pool,
read_provider_credentials,
write_credential_pool,
)

Expand Down Expand Up @@ -321,7 +322,7 @@ def get_custom_provider_pool_key(base_url: str) -> Optional[str]:

def list_custom_pool_providers() -> List[str]:
"""Return all 'custom:*' pool keys that have entries in auth.json."""
pool_data = read_credential_pool(None)
pool_data = read_credential_pool()
return sorted(
key for key in pool_data
if key.startswith(CUSTOM_POOL_PREFIX)
Expand Down Expand Up @@ -875,6 +876,20 @@ def remove_index(self, index: int) -> Optional[PooledCredential]:
self._current_id = None
return removed

def remove_entry(self, entry_id: str) -> Optional[PooledCredential]:
    """Remove the pooled credential whose id equals *entry_id*.

    After removal the remaining entries are re-numbered so priorities
    stay dense, the pool is persisted, and the current selection is
    cleared if it pointed at the removed entry.

    Returns the removed credential, or ``None`` when no entry matches.
    """
    match_idx = None
    for position, candidate in enumerate(self._entries):
        if candidate.id == entry_id:
            match_idx = position
            break
    if match_idx is None:
        return None

    removed = self._entries.pop(match_idx)
    # Re-number priorities so they remain contiguous after the removal.
    self._entries = [
        replace(item, priority=rank)
        for rank, item in enumerate(self._entries)
    ]
    self._persist()
    if self._current_id == removed.id:
        self._current_id = None
    return removed

def resolve_target(self, target: Any) -> Tuple[Optional[int], Optional[PooledCredential], Optional[str]]:
raw = str(target or "").strip()
if not raw:
Expand Down Expand Up @@ -1325,7 +1340,7 @@ def _is_suppressed(_p, _s): # type: ignore[misc]

def load_pool(provider: str) -> CredentialPool:
provider = (provider or "").strip().lower()
raw_entries = read_credential_pool(provider)
raw_entries = read_provider_credentials(provider)
entries = [PooledCredential.from_dict(provider, payload) for payload in raw_entries]

if provider.startswith(CUSTOM_POOL_PREFIX):
Expand Down
1 change: 1 addition & 0 deletions agent/display.py
Original file line number Diff line number Diff line change
Expand Up @@ -729,6 +729,7 @@ def _animate(self):
time.sleep(0.1)
continue
frame = self.spinner_frames[self.frame_idx % len(self.spinner_frames)]
assert self.start_time is not None # start() sets it before thread starts
elapsed = time.time() - self.start_time
if wings:
left, right = wings[self.frame_idx % len(wings)]
Expand Down
3 changes: 2 additions & 1 deletion agent/skill_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,8 @@ def parse_qualified_name(name: str) -> Tuple[Optional[str], str]:
"""
if ":" not in name:
return None, name
return tuple(name.split(":", 1)) # type: ignore[return-value]
ns, bare = name.split(":", 1)
return ns, bare


def is_valid_namespace(candidate: Optional[str]) -> bool:
Expand Down
40 changes: 34 additions & 6 deletions cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from contextlib import contextmanager
from pathlib import Path
from datetime import datetime
from typing import List, Dict, Any, Optional
from typing import List, Dict, Any, Optional, TypedDict

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -84,6 +84,34 @@
load_hermes_dotenv(hermes_home=_hermes_home, project_env=_project_env)


class _ModelPickerState(TypedDict, total=False):
    """Transient state for the interactive model-picker overlay.

    ``total=False``: keys are populated incrementally as the picker
    advances, so any key may be absent at a given moment.
    """

    stage: str
    providers: List[Dict[str, Any]]
    selected: int  # index of the currently highlighted entry
    current_model: str
    current_provider: str
    user_provs: Optional[Dict[str, Any]]
    custom_provs: Optional[Dict[str, Any]]
    provider_data: Dict[str, Any]
    model_list: List[str]


class _ApprovalState(TypedDict, total=False):
    """State of a pending command-approval prompt (``total=False``:
    keys may be filled in incrementally)."""

    command: str
    description: str
    choices: List[str]
    selected: int  # index of the currently highlighted choice
    response_queue: "queue.Queue[str]"  # carries the user's answer back to the waiting thread
    show_full: bool


class _ClarifyState(TypedDict, total=False):
    """State of a pending clarification question (``total=False``:
    keys may be filled in incrementally)."""

    question: str
    choices: List[str]
    selected: int  # index of the currently highlighted choice
    response_queue: "queue.Queue[str]"  # carries the user's answer back to the waiting thread


_REASONING_TAGS = (
"REASONING_SCRATCHPAD",
"think",
Expand Down Expand Up @@ -1728,7 +1756,7 @@ def _parse_skills_argument(skills: str | list[str] | tuple[str, ...] | None) ->
return parsed


def save_config_value(key_path: str, value: any) -> bool:
def save_config_value(key_path: str, value: Any) -> bool:
"""
Save a value to the active config file at the specified key path.

Expand Down Expand Up @@ -2065,16 +2093,16 @@ def __init__(
self._interrupt_queue = queue.Queue()
self._should_exit = False
self._last_ctrl_c_time = 0
self._clarify_state = None
self._clarify_state: Optional[_ClarifyState] = None
self._clarify_freetext = False
self._clarify_deadline = 0
self._sudo_state = None
self._sudo_deadline = 0
self._modal_input_snapshot = None
self._approval_state = None
self._approval_state: Optional[_ApprovalState] = None
self._approval_deadline = 0
self._approval_lock = threading.Lock()
self._model_picker_state = None
self._model_picker_state: Optional[_ModelPickerState] = None
self._secret_state = None
self._secret_deadline = 0
self._spinner_text: str = "" # thinking spinner text for TUI
Expand Down Expand Up @@ -7156,7 +7184,7 @@ def _show_usage(self):
logging.getLogger(noisy).setLevel(logging.WARNING)
else:
logging.getLogger().setLevel(logging.INFO)
for quiet_logger in ('tools', 'run_agent', 'trajectory_compressor', 'cron', 'hermes_cli'):
for quiet_logger in ('tools', 'run_agent', 'scripts.trajectory_compressor', 'cron', 'hermes_cli'):
logging.getLogger(quiet_logger).setLevel(logging.ERROR)

def _show_insights(self, command: str = "/insights"):
Expand Down
5 changes: 3 additions & 2 deletions cron/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,8 +439,9 @@ def _deliver_result(job: dict, content: str, adapters=None, loop=None) -> Option
delivery_errors.append(msg)
continue

if result and result.get("error"):
msg = f"delivery error: {result['error']}"
error = result.get("error") if result else None
if error:
msg = f"delivery error: {error}"
logger.error("Job '%s': %s", job["id"], msg)
delivery_errors.append(msg)
continue
Expand Down
2 changes: 1 addition & 1 deletion datagen-config-examples/run_browser_tasks.sh
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ echo "📝 Logging to: $LOG_FILE"
# Point to the example dataset in this directory
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"

python batch_runner.py \
python scripts/batch_runner.py \
--dataset_file="$SCRIPT_DIR/example_browser_tasks.jsonl" \
--batch_size=5 \
--run_name="browser_tasks_example" \
Expand Down
2 changes: 1 addition & 1 deletion datagen-config-examples/web_research.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# Generates tool-calling trajectories for multi-step web research tasks.
#
# Usage:
# python batch_runner.py \
# python scripts/batch_runner.py \
# --config datagen-config-examples/web_research.yaml \
# --run_name web_research_v1

Expand Down
5 changes: 4 additions & 1 deletion environments/agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@
import os
import uuid
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Set
from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING

if TYPE_CHECKING:
from tools.budget_config import BudgetConfig

from model_tools import handle_function_call
from tools.terminal_tool import get_active_env
Expand Down
109 changes: 41 additions & 68 deletions gateway/platforms/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,7 @@
import time
import uuid
from typing import Any, Dict, List, Optional

try:
from aiohttp import web
AIOHTTP_AVAILABLE = True
except ImportError:
AIOHTTP_AVAILABLE = False
web = None # type: ignore[assignment]

from aiohttp import web
from gateway.config import Platform, PlatformConfig
from gateway.platforms.base import (
BasePlatformAdapter,
Expand Down Expand Up @@ -270,12 +263,6 @@ def _multimodal_validation_error(exc: ValueError, *, param: str) -> "web.Respons
status=400,
)


def check_api_server_requirements() -> bool:
"""Check if API server dependencies are available."""
return AIOHTTP_AVAILABLE


class ResponseStore:
"""
SQLite-backed LRU store for Responses API state.
Expand Down Expand Up @@ -391,30 +378,26 @@ def __len__(self) -> int:
}


if AIOHTTP_AVAILABLE:
@web.middleware
async def cors_middleware(request, handler):
"""Add CORS headers for explicitly allowed origins; handle OPTIONS preflight."""
adapter = request.app.get("api_server_adapter")
origin = request.headers.get("Origin", "")
cors_headers = None
if adapter is not None:
if not adapter._origin_allowed(origin):
return web.Response(status=403)
cors_headers = adapter._cors_headers_for_origin(origin)

if request.method == "OPTIONS":
if cors_headers is None:
return web.Response(status=403)
return web.Response(status=200, headers=cors_headers)

response = await handler(request)
if cors_headers is not None:
response.headers.update(cors_headers)
return response
else:
cors_middleware = None # type: ignore[assignment]
@web.middleware
async def cors_middleware(request, handler):
    """Enforce the CORS policy: reject disallowed origins, answer
    OPTIONS preflight requests, and attach CORS headers to responses."""
    adapter = request.app.get("api_server_adapter")
    origin = request.headers.get("Origin", "")

    allowed_headers = None
    if adapter is not None:
        if not adapter._origin_allowed(origin):
            return web.Response(status=403)
        allowed_headers = adapter._cors_headers_for_origin(origin)

    if request.method == "OPTIONS":
        # Preflight succeeds only when headers were computed for this origin.
        if allowed_headers is None:
            return web.Response(status=403)
        return web.Response(status=200, headers=allowed_headers)

    response = await handler(request)
    if allowed_headers is not None:
        response.headers.update(allowed_headers)
    return response

def _openai_error(message: str, err_type: str = "invalid_request_error", param: str = None, code: str = None) -> Dict[str, Any]:
"""OpenAI-style error envelope."""
Expand All @@ -428,38 +411,32 @@ def _openai_error(message: str, err_type: str = "invalid_request_error", param:
}


if AIOHTTP_AVAILABLE:
@web.middleware
async def body_limit_middleware(request, handler):
"""Reject overly large request bodies early based on Content-Length."""
if request.method in ("POST", "PUT", "PATCH"):
cl = request.headers.get("Content-Length")
if cl is not None:
try:
if int(cl) > MAX_REQUEST_BYTES:
return web.json_response(_openai_error("Request body too large.", code="body_too_large"), status=413)
except ValueError:
return web.json_response(_openai_error("Invalid Content-Length header.", code="invalid_content_length"), status=400)
return await handler(request)
else:
body_limit_middleware = None # type: ignore[assignment]
@web.middleware
async def body_limit_middleware(request, handler):
    """Reject overly large request bodies early based on Content-Length."""
    # Only body-bearing methods are checked; everything else passes through.
    if request.method not in ("POST", "PUT", "PATCH"):
        return await handler(request)

    declared = request.headers.get("Content-Length")
    if declared is None:
        return await handler(request)

    try:
        length = int(declared)
    except ValueError:
        return web.json_response(
            _openai_error("Invalid Content-Length header.", code="invalid_content_length"),
            status=400,
        )
    if length > MAX_REQUEST_BYTES:
        return web.json_response(
            _openai_error("Request body too large.", code="body_too_large"),
            status=413,
        )
    return await handler(request)

# Baseline security headers added to every response via setdefault,
# so handlers that set these explicitly are not overridden.
_SECURITY_HEADERS = {
    "X-Content-Type-Options": "nosniff",
    "Referrer-Policy": "no-referrer",
}


if AIOHTTP_AVAILABLE:
@web.middleware
async def security_headers_middleware(request, handler):
"""Add security headers to all responses (including errors)."""
response = await handler(request)
for k, v in _SECURITY_HEADERS.items():
response.headers.setdefault(k, v)
return response
else:
security_headers_middleware = None # type: ignore[assignment]
@web.middleware
async def security_headers_middleware(request, handler):
    """Ensure baseline security headers on every response (including
    errors) without overriding values a handler already set."""
    response = await handler(request)
    for header_name, header_value in _SECURITY_HEADERS.items():
        response.headers.setdefault(header_name, header_value)
    return response


class _IdempotencyCache:
Expand Down Expand Up @@ -804,7 +781,7 @@ async def _handle_models(self, request: "web.Request") -> "web.Response":
],
})

async def _handle_chat_completions(self, request: "web.Request") -> "web.Response":
async def _handle_chat_completions(self, request: "web.Request") -> "web.StreamResponse":
"""POST /v1/chat/completions — OpenAI Chat Completions format."""
auth_err = self._check_auth(request)
if auth_err:
Expand Down Expand Up @@ -1588,7 +1565,7 @@ async def _dispatch(it) -> None:

return response

async def _handle_responses(self, request: "web.Request") -> "web.Response":
async def _handle_responses(self, request: "web.Request") -> "web.StreamResponse":
"""POST /v1/responses — OpenAI Responses API format."""
auth_err = self._check_auth(request)
if auth_err:
Expand Down Expand Up @@ -2482,10 +2459,6 @@ async def _sweep_orphaned_runs(self) -> None:

async def connect(self) -> bool:
"""Start the aiohttp web server."""
if not AIOHTTP_AVAILABLE:
logger.warning("[%s] aiohttp not installed", self.name)
return False

try:
mws = [mw for mw in (cors_middleware, body_limit_middleware, security_headers_middleware) if mw is not None]
self._app = web.Application(middlewares=mws)
Expand Down
Loading
Loading