diff --git a/FEATURE_PARITY.md b/FEATURE_PARITY.md index 85348de539..711eaccf14 100644 --- a/FEATURE_PARITY.md +++ b/FEATURE_PARITY.md @@ -224,6 +224,7 @@ This document tracks feature parity between IronClaw (Rust implementation) and O | suppressToolErrors config | ✅ | ❌ | Hide tool errors from user | | Intent-first tool display | ✅ | ❌ | Details and exec summaries | | Transcript file size in status | ✅ | ❌ | Show size in session status | +| A2A (Agent-to-Agent) bridge | ❌ | ✅ | Google A2A protocol (JSON-RPC 2.0 + SSE), configurable tool name/endpoint | ### Owner: _Unassigned_ diff --git a/scripts/test-a2a-bridge.sh b/scripts/test-a2a-bridge.sh new file mode 100755 index 0000000000..efd2ac9ec5 --- /dev/null +++ b/scripts/test-a2a-bridge.sh @@ -0,0 +1,89 @@ +#!/usr/bin/env bash +# Test script for the A2A bridge tool. +# +# Usage: +# # Run unit + integration tests (no external server needed) +# ./scripts/test-a2a-bridge.sh +# +# # Run live E2E test against a real A2A agent +# A2A_AGENT_URL=http://your-agent:5085 \ +# A2A_ASSISTANT_ID=your-assistant-id \ +# ./scripts/test-a2a-bridge.sh --live +# +# Environment variables (for --live mode): +# A2A_AGENT_URL Base URL of the A2A-compatible agent server (required) +# A2A_ASSISTANT_ID Assistant/graph ID to query (required) + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)" +cd "$PROJECT_ROOT" + +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[0;33m' +NC='\033[0m' + +pass() { echo -e "${GREEN}✓ $1${NC}"; } +fail() { echo -e "${RED}✗ $1${NC}"; exit 1; } +info() { echo -e "${YELLOW}► $1${NC}"; } + +LIVE=false +for arg in "$@"; do + case "$arg" in + --live) LIVE=true ;; + esac +done + +# ── Step 1: Format check ──────────────────────────────────────────── +info "Checking formatting..." 
+cargo fmt --check -- src/tools/builtin/a2a/*.rs src/config/a2a.rs \ + 2>/dev/null && pass "cargo fmt" || fail "cargo fmt" + +# ── Step 2: Clippy ────────────────────────────────────────────────── +info "Running clippy on A2A modules..." +cargo clippy -p ironclaw --all-features -- -D warnings \ + 2>&1 | tail -3 +pass "cargo clippy" + +# ── Step 3: Unit tests ────────────────────────────────────────────── +info "Running A2A unit tests..." +cargo test --lib -- a2a 2>&1 | tail -5 +pass "unit tests" + +# ── Step 4: Integration tests (construction only) ─────────────────── +info "Running A2A integration tests (construction)..." +cargo test --test a2a_bridge_integration 2>&1 | tail -5 +pass "integration tests" + +# ── Step 5: Feature-flag compilation ──────────────────────────────── +info "Checking libsql feature compilation..." +cargo check --no-default-features --features libsql 2>&1 | tail -3 +pass "libsql feature check" + +# ── Step 6 (optional): Live E2E test ──────────────────────────────── +if [ "$LIVE" = true ]; then + if [ -z "${A2A_AGENT_URL:-}" ] || [ -z "${A2A_ASSISTANT_ID:-}" ]; then + fail "Live test requires A2A_AGENT_URL and A2A_ASSISTANT_ID env vars" + fi + + info "Running live A2A test against $A2A_AGENT_URL ..." 
+ + # Quick connectivity check + HTTP_CODE=$(curl -s -o /dev/null -w "%{http_code}" \ + --connect-timeout 5 "$A2A_AGENT_URL/info" 2>/dev/null || echo "000") + if [ "$HTTP_CODE" = "000" ]; then + fail "Cannot reach $A2A_AGENT_URL (connection refused or timeout)" + fi + pass "server reachable (HTTP $HTTP_CODE)" + + # Run the ignored live test + A2A_AGENT_URL="$A2A_AGENT_URL" \ + A2A_ASSISTANT_ID="$A2A_ASSISTANT_ID" \ + cargo test --test a2a_bridge_integration -- --ignored 2>&1 | tail -5 + pass "live E2E test" +fi + +echo "" +echo -e "${GREEN}All A2A bridge tests passed.${NC}" diff --git a/src/config/a2a.rs b/src/config/a2a.rs new file mode 100644 index 0000000000..3c58ffa661 --- /dev/null +++ b/src/config/a2a.rs @@ -0,0 +1,135 @@ +use std::time::Duration; + +use crate::config::helpers::{parse_bool_env, parse_optional_env, parse_string_env}; +use crate::error::ConfigError; + +/// Configuration for the A2A (Agent-to-Agent) protocol bridge. +/// +/// Connects to a remote agent via the Google A2A protocol (JSON-RPC 2.0 + SSE +/// streaming). All agent-specific values (URL, assistant ID) must be set +/// explicitly — no hardcoded defaults. +#[derive(Debug, Clone)] +pub struct A2aConfig { + /// Whether the A2A bridge is enabled. + pub enabled: bool, + /// Base URL of the remote agent (required when enabled). + pub agent_url: String, + /// Assistant ID for the remote agent (required when enabled). + pub assistant_id: String, + /// Tool name exposed to the LLM (default: `"a2a_query"`). + pub tool_name: String, + /// Tool description exposed to the LLM. + pub tool_description: String, + /// Prefix for push-notification messages from the background SSE consumer. + pub message_prefix: String, + /// Timeout for reading the first SSE event after connection. + pub request_timeout: Duration, + /// Timeout for the entire background SSE stream consumption. + pub task_timeout: Duration, + /// Secret name in the secrets store for the API key. 
+ pub api_key_secret: String, +} + +impl A2aConfig { + pub(crate) fn resolve() -> Result, ConfigError> { + let enabled = parse_bool_env("A2A_ENABLED", false)?; + if !enabled { + return Ok(None); + } + + let agent_url = parse_string_env("A2A_AGENT_URL", "")?; + if agent_url.is_empty() { + return Err(ConfigError::InvalidValue { + key: "A2A_AGENT_URL".to_string(), + message: "must be set when A2A_ENABLED=true".to_string(), + }); + } + + let assistant_id = parse_string_env("A2A_ASSISTANT_ID", "")?; + if assistant_id.is_empty() { + return Err(ConfigError::InvalidValue { + key: "A2A_ASSISTANT_ID".to_string(), + message: "must be set when A2A_ENABLED=true".to_string(), + }); + } + + let tool_name = parse_string_env("A2A_TOOL_NAME", "a2a_query")?; + let tool_description = parse_string_env( + "A2A_TOOL_DESCRIPTION", + "Query a remote AI agent via the A2A (Agent-to-Agent) protocol. \ + Supports multi-turn conversations with thread_id for context continuity.", + )?; + let message_prefix = parse_string_env("A2A_MESSAGE_PREFIX", "[a2a]")?; + let request_timeout_ms: u64 = parse_optional_env("A2A_REQUEST_TIMEOUT_MS", 60_000)?; + let task_timeout_ms: u64 = parse_optional_env("A2A_TASK_TIMEOUT_MS", 1_200_000)?; + let api_key_secret = parse_string_env("A2A_API_KEY_SECRET", "a2a_api_key")?; + + Ok(Some(Self { + enabled, + agent_url, + assistant_id, + tool_name, + tool_description, + message_prefix, + request_timeout: Duration::from_millis(request_timeout_ms), + task_timeout: Duration::from_millis(task_timeout_ms), + api_key_secret, + })) + } + + /// Whether the API key secret name is configured (non-empty). 
+ pub fn has_api_key_configured(&self) -> bool { + !self.api_key_secret.is_empty() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn disabled_by_default() { + let _guard = crate::config::helpers::ENV_MUTEX.lock(); + unsafe { + std::env::remove_var("A2A_ENABLED"); + } + let result = A2aConfig::resolve().unwrap(); + assert!(result.is_none()); + } + + #[test] + fn requires_agent_url_when_enabled() { + let _guard = crate::config::helpers::ENV_MUTEX.lock(); + unsafe { + std::env::set_var("A2A_ENABLED", "true"); + std::env::remove_var("A2A_AGENT_URL"); + } + let result = A2aConfig::resolve(); + assert!(result.is_err()); + unsafe { + std::env::remove_var("A2A_ENABLED"); + } + } + + #[test] + fn has_api_key_configured_checks_non_empty() { + let config = A2aConfig { + enabled: true, + agent_url: "https://example.com".to_string(), + assistant_id: "test-id".to_string(), + tool_name: "a2a_query".to_string(), + tool_description: "test".to_string(), + message_prefix: "[a2a]".to_string(), + request_timeout: Duration::from_secs(60), + task_timeout: Duration::from_secs(1200), + api_key_secret: "my_key".to_string(), + }; + assert!(config.has_api_key_configured()); + + let config_empty = A2aConfig { + api_key_secret: String::new(), + ..config + }; + assert!(!config_empty.has_api_key_configured()); + } +} diff --git a/src/config/mod.rs b/src/config/mod.rs index 38c8088050..f91117f19a 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -5,6 +5,7 @@ //! in startup). Everything else comes from env vars, the DB settings //! table, or auto-detection. +mod a2a; mod agent; mod builder; mod channels; @@ -32,6 +33,7 @@ use crate::error::ConfigError; use crate::settings::Settings; // Re-export all public types so `crate::config::FooConfig` continues to work. 
+pub use self::a2a::A2aConfig; pub use self::agent::AgentConfig; pub use self::builder::BuilderModeConfig; pub use self::channels::{ @@ -102,6 +104,9 @@ pub struct Config { /// Channel-relay integration (Slack via external relay service). /// Present only when both `CHANNEL_RELAY_URL` and `CHANNEL_RELAY_API_KEY` are set. pub relay: Option, + /// A2A bridge configuration for connecting to remote agents. + /// Present only when `A2A_ENABLED=true`. + pub a2a: Option, } impl Config { @@ -177,6 +182,7 @@ impl Config { search: WorkspaceSearchConfig::default(), observability: crate::observability::ObservabilityConfig::default(), relay: None, + a2a: None, } } @@ -329,6 +335,7 @@ impl Config { backend: std::env::var("OBSERVABILITY_BACKEND").unwrap_or_else(|_| "none".into()), }, relay: RelayConfig::from_env(), + a2a: A2aConfig::resolve()?, }) } } diff --git a/src/main.rs b/src/main.rs index 745cae09b4..1cc11f937a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -23,6 +23,7 @@ use ironclaw::{ llm::create_session_manager, orchestrator::{ReaperConfig, SandboxReaper}, pairing::PairingStore, + tools::Tool, tracing_fmt::{init_cli_tracing, init_worker_tracing}, webhooks::{self, ToolWebhookState}, }; @@ -466,6 +467,34 @@ async fn async_main() -> anyhow::Result<()> { components.secrets_store.clone(), ); + // ── A2A bridge tool ──────────────────────────────────────────────── + if let Some(ref a2a_config) = config.a2a { + if let Some(ref ss) = components.secrets_store { + match ironclaw::tools::builtin::A2aBridgeTool::new( + a2a_config.clone(), + Arc::clone(ss), + channels.inject_sender(), + ) + .await + { + Ok(tool) => { + let tool_name = tool.name().to_string(); + components.tools.register_sync(Arc::new(tool)); + tracing::info!( + tool = %tool_name, + url = %a2a_config.agent_url, + "A2A bridge enabled" + ); + } + Err(e) => { + tracing::error!("A2A bridge initialization failed: {}", e); + } + } + } else { + tracing::warn!("A2A bridge enabled but no secrets store available — skipping"); 
+ } + } + // ── Gateway channel ──────────────────────────────────────────────── let mut gateway_url: Option = None; diff --git a/src/tools/builtin/a2a/bridge.rs b/src/tools/builtin/a2a/bridge.rs new file mode 100644 index 0000000000..7973abc4e7 --- /dev/null +++ b/src/tools/builtin/a2a/bridge.rs @@ -0,0 +1,675 @@ +//! A2A bridge tool — connects to a remote agent via the A2A protocol. + +use std::net::IpAddr; +use std::sync::Arc; +use std::time::Duration; + +use async_trait::async_trait; +use futures::StreamExt; +use reqwest::Client; +use tokio::sync::mpsc; + +use crate::channels::IncomingMessage; +use crate::config::A2aConfig; +use crate::context::JobContext; +use crate::safety::LeakDetector; +use crate::secrets::SecretsStore; +use crate::tools::tool::{ApprovalRequirement, Tool, ToolError, ToolOutput, require_str}; + +use super::protocol::{ + EventKind, build_jsonrpc_request, classify_event, extract_text_from_result, + has_message_content, parse_sse_events, result_has_text_parts, truncate_str, +}; + +/// Maximum SSE buffer size (10 MB) — same cap as MCP HTTP transport. +const MAX_SSE_BUFFER: usize = 10 * 1024 * 1024; + +/// Maximum summary length for push notifications. +const MAX_SUMMARY_LEN: usize = 2000; + +/// A2A bridge tool that delegates queries to a remote agent. +pub struct A2aBridgeTool { + client: Client, + config: A2aConfig, + secrets_store: Arc, + inject_tx: mpsc::Sender, + leak_detector: LeakDetector, +} + +impl A2aBridgeTool { + /// Create a new A2A bridge tool. + /// + /// The agent URL is validated for SSRF at construction time. Returns an error + /// if the URL points to a private/local address. 
+ pub async fn new( + config: A2aConfig, + secrets_store: Arc, + inject_tx: mpsc::Sender, + ) -> Result { + // Validate agent URL at construction time (defense in depth) + validate_agent_url(&config.agent_url).await?; + + // H5: No-redirect policy to prevent SSRF via redirect + let client = Client::builder() + .connect_timeout(Duration::from_secs(30)) + .redirect(reqwest::redirect::Policy::none()) + .build() + .map_err(|e| { + ToolError::ExternalService(format!("failed to build HTTP client: {}", e)) + })?; + + Ok(Self { + client, + config, + secrets_store, + inject_tx, + leak_detector: LeakDetector::new(), + }) + } + + /// Build the full A2A endpoint URL. + fn endpoint_url(&self) -> String { + let base = self.config.agent_url.trim_end_matches('/'); + format!("{}/a2a/{}", base, self.config.assistant_id) + } +} + +#[async_trait] +impl Tool for A2aBridgeTool { + fn name(&self) -> &str { + &self.config.tool_name + } + + fn description(&self) -> &str { + &self.config.tool_description + } + + fn parameters_schema(&self) -> serde_json::Value { + serde_json::json!({ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Natural language query for the remote agent" + }, + "context": { + "type": "object", + "description": "Optional structured context passed alongside the query" + }, + "thread_id": { + "type": "string", + "description": "Thread ID for multi-turn conversations. Reuse to continue a previous session." 
+ } + }, + "required": ["query"] + }) + } + + async fn execute( + &self, + params: serde_json::Value, + ctx: &JobContext, + ) -> Result { + let start = std::time::Instant::now(); + + let query = require_str(¶ms, "query")?; + let context = params.get("context"); + let thread_id = params.get("thread_id").and_then(|v| v.as_str()); + + // C2: Scan outgoing content for secret leaks + let query_bytes = query.as_bytes(); + self.leak_detector + .scan_http_request(&self.endpoint_url(), &[], Some(query_bytes)) + .map_err(|e| { + ToolError::NotAuthorized(format!("leak detection blocked request: {}", e)) + })?; + + if let Some(ctx_val) = context { + let ctx_str = serde_json::to_string(ctx_val).unwrap_or_default(); + self.leak_detector + .scan_http_request(&self.endpoint_url(), &[], Some(ctx_str.as_bytes())) + .map_err(|e| { + ToolError::NotAuthorized(format!("leak detection blocked context: {}", e)) + })?; + } + + // Try to get API key from secrets store (optional — agent may not require auth) + let api_key = self + .secrets_store + .get_decrypted(&ctx.user_id, &self.config.api_key_secret) + .await + .ok(); + + // Build JSON-RPC request + let body = build_jsonrpc_request(query, context, thread_id); + let url = self.endpoint_url(); + + // Send POST — accept both SSE and JSON so the server can pick. + // LangGraph requires application/json to be present in the Accept + // header; pure text/event-stream is rejected. + let mut request = self + .client + .post(&url) + .header("Accept", "text/event-stream, application/json") + .header("Content-Type", "application/json") + .header("A2A-Version", "1.0.0"); + + if let Some(ref key) = api_key { + request = request.bearer_auth(key.expose()); + } + + // M2: Use configured request_timeout for the initial connection + let response = + tokio::time::timeout(self.config.request_timeout, request.json(&body).send()) + .await + .map_err(|_| ToolError::Timeout(self.config.request_timeout))? 
+ .map_err(|e| { + if e.is_timeout() { + ToolError::Timeout(self.config.request_timeout) + } else { + ToolError::ExternalService(format!("A2A request failed: {}", e)) + } + })?; + + let status = response.status(); + if !status.is_success() { + let error_body = response.text().await.unwrap_or_default(); + return Err(ToolError::ExternalService(format!( + "A2A agent returned HTTP {}: {}", + status, error_body + ))); + } + + // Check content-type: if the server returned JSON instead of SSE, + // parse the full body directly (LangGraph returns application/json). + let is_json_response = response + .headers() + .get("content-type") + .and_then(|v| v.to_str().ok()) + .map(|ct| ct.contains("application/json")) + .unwrap_or(false); + + if is_json_response { + let body = response.text().await.map_err(|e| { + ToolError::ExternalService(format!("failed to read JSON response: {}", e)) + })?; + let parsed: serde_json::Value = serde_json::from_str(&body).map_err(|e| { + ToolError::ExternalService(format!("invalid JSON from A2A agent: {}", e)) + })?; + + // Check for JSON-RPC error + if let Some(err) = parsed.get("error") { + let msg = err + .get("message") + .and_then(|m| m.as_str()) + .unwrap_or("unknown error"); + return Err(ToolError::ExternalService(format!( + "A2A agent error: {}", + msg + ))); + } + + let result = parsed + .get("result") + .cloned() + .unwrap_or(serde_json::Value::Null); + let summary = extract_text_from_result(&result, MAX_SUMMARY_LEN); + let result_json = serde_json::json!({ + "status": "completed", + "result": summary, + }); + return Ok(ToolOutput::success(result_json, start.elapsed())); + } + + // SSE path: read first event to determine sync vs async + let mut stream = response.bytes_stream(); + let mut buffer = String::new(); + + let first_event = tokio::time::timeout(self.config.request_timeout, async { + while let Some(chunk) = stream.next().await { + let chunk = chunk.map_err(|e| { + ToolError::ExternalService(format!("failed to read SSE chunk: {}", 
e)) + })?; + buffer.push_str(&String::from_utf8_lossy(&chunk)); + + if buffer.len() > MAX_SSE_BUFFER { + return Err(ToolError::ExternalService( + "SSE buffer exceeded 10 MB limit".to_string(), + )); + } + + // Try to parse complete SSE events from buffer + let events = parse_sse_events(&mut buffer); + if let Some(event) = events.into_iter().next() { + return Ok(event); + } + } + Err(ToolError::ExternalService( + "SSE stream ended without any events".to_string(), + )) + }) + .await + .map_err(|_| ToolError::Timeout(self.config.request_timeout))??; + + // Handle first event + match classify_event(&first_event) { + EventKind::Error(msg) => Err(ToolError::ExternalService(format!( + "A2A agent error: {}", + msg + ))), + EventKind::Final(result) => { + let summary = extract_text_from_result(&result, MAX_SUMMARY_LEN); + let result_json = serde_json::json!({ + "status": "completed", + "result": summary, + }); + Ok(ToolOutput::success(result_json, start.elapsed())) + } + EventKind::InProgress { + task_id, + context_id, + } => { + // Spawn background consumer with cancellation via inject_tx closure + let inject_tx = self.inject_tx.clone(); + let task_timeout = self.config.task_timeout; + let query_owned = query.to_string(); + let task_id_for_spawn = task_id.clone(); + let message_prefix = self.config.message_prefix.clone(); + + tokio::spawn(async move { + spawn_stream_consumer( + stream, + buffer, + inject_tx, + task_timeout, + query_owned, + task_id_for_spawn, + message_prefix, + ) + .await; + }); + + let short_id = &task_id[..8.min(task_id.len())]; + let mut result_json = serde_json::json!({ + "status": "submitted", + "task_id": task_id, + "message": format!( + "Query submitted (task: {}). 
Results will be pushed when ready.", + short_id + ), + }); + + // H3: Include context ID for multi-turn support + if let Some(cid) = context_id { + result_json["context_id"] = serde_json::Value::String(cid); + } + + Ok(ToolOutput::success(result_json, start.elapsed())) + } + } + } + + fn estimated_duration(&self, _params: &serde_json::Value) -> Option { + Some(Duration::from_secs(10)) + } + + fn requires_sanitization(&self) -> bool { + true // External data always needs sanitization + } + + // M3: Always require approval — sends user content to an external service + fn requires_approval(&self, _params: &serde_json::Value) -> ApprovalRequirement { + ApprovalRequirement::Always + } + + fn execution_timeout(&self) -> Duration { + // Controls the initial request phase (reading first SSE event). + // The background consumer has its own timeout via task_timeout. + Duration::from_secs(600) + } + + fn rate_limit_config(&self) -> Option { + Some(crate::tools::tool::ToolRateLimitConfig::new(10, 100)) + } +} + +// ── SSRF validation ───────────────────────────────────────────────── + +/// Validate an A2A agent URL for SSRF protection. +/// +/// Unlike `HttpTool::validate_url()`, this allows both HTTP and HTTPS schemes +/// (the operator chooses the protocol), but still blocks localhost and private IPs +/// to prevent SSRF. 
+async fn validate_agent_url(url: &str) -> Result<(), ToolError> { + let parsed = reqwest::Url::parse(url) + .map_err(|e| ToolError::InvalidParameters(format!("invalid agent URL: {}", e)))?; + + let scheme = parsed.scheme(); + if scheme != "http" && scheme != "https" { + return Err(ToolError::InvalidParameters(format!( + "A2A agent URL must use http or https scheme, got '{}'", + scheme + ))); + } + + let host = parsed + .host_str() + .ok_or_else(|| ToolError::InvalidParameters("agent URL missing host".to_string()))?; + + let host_lower = host.to_lowercase(); + if host_lower == "localhost" || host_lower.ends_with(".localhost") { + return Err(ToolError::NotAuthorized( + "A2A agent URL must not point to localhost".to_string(), + )); + } + + // Block literal private/local IPs + if let Ok(ip) = host.parse::() + && is_disallowed_ip(&ip) + { + return Err(ToolError::NotAuthorized( + "A2A agent URL must not point to a private or local IP".to_string(), + )); + } + + // DNS resolution check — prevent rebinding to private IPs + let port = parsed.port_or_known_default().unwrap_or(443); + if let Ok(addrs) = tokio::net::lookup_host((host, port)).await { + for addr in addrs { + if is_disallowed_ip(&addr.ip()) { + return Err(ToolError::NotAuthorized(format!( + "A2A agent hostname '{}' resolves to disallowed IP {}", + host, + addr.ip() + ))); + } + } + } + + Ok(()) +} + +/// Check if an IP address is private, loopback, link-local, or otherwise +/// disallowed for outbound requests. 
/// Check if an IP address is private, loopback, link-local, or otherwise
/// disallowed for outbound requests.
///
/// BUG FIX: IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`) previously passed
/// every IPv6 check — `::ffff:127.0.0.1` is not an IPv6 loopback/ULA/
/// link-local address — so they bypassed the SSRF filter while reaching the
/// same hosts as their embedded IPv4 address. They are now unwrapped and
/// checked under the IPv4 rules. Also blocks the IPv4 broadcast address.
fn is_disallowed_ip(ip: &IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => is_disallowed_ipv4(v4),
        IpAddr::V6(v6) => {
            // Unwrap ::ffff:a.b.c.d and apply the IPv4 rules to the
            // embedded address.
            if let Some(mapped) = v6.to_ipv4_mapped() {
                return is_disallowed_ipv4(&mapped);
            }
            v6.is_loopback()
                || v6.is_unique_local()
                || v6.is_unicast_link_local()
                || v6.is_multicast()
                || v6.is_unspecified()
        }
    }
}

/// IPv4-specific disallow rules, shared by the V4 arm and the
/// IPv4-mapped-IPv6 path above.
fn is_disallowed_ipv4(v4: &std::net::Ipv4Addr) -> bool {
    v4.is_private()
        || v4.is_loopback()
        || v4.is_link_local()
        || v4.is_multicast()
        || v4.is_broadcast()
        || v4.is_unspecified()
        || *v4 == std::net::Ipv4Addr::new(169, 254, 169, 254) // AWS metadata
}
final event often has only status metadata (no text). + // Prefer last_content_event if the final result lacks text parts. + if !result_has_text_parts(&result) + && let Some(prev) = last_content_event + { + return Ok(prev); + } + return Ok(result); + } + EventKind::InProgress { .. } => { + // Track events that carry message text + if has_message_content(&event.raw) { + last_content_event = Some( + event + .raw + .get("result") + .cloned() + .unwrap_or(event.raw.clone()), + ); + } + } + } + } + } + + // Stream ended — try remaining buffer + for event in parse_sse_events(&mut buffer) { + if let EventKind::Final(result) = classify_event(&event) { + return Ok(result); + } + } + + // Return last content event if we have one + if let Some(last) = last_content_event { + return Ok(last); + } + + Err("SSE stream ended without final result".to_string()) + }) + .await; + + let query_preview = truncate_str(&query, 60); + let msg = match result { + Ok(Ok(result)) => { + let summary = extract_text_from_result(&result, MAX_SUMMARY_LEN); + IncomingMessage::new( + "a2a_bridge", + "system", + format!( + "{} Query completed — \"{}\"\n\n{}", + message_prefix, query_preview, summary + ), + ) + } + Ok(Err(e)) => IncomingMessage::new( + "a2a_bridge", + "system", + format!( + "{} Query failed (task: {}) — {}", + message_prefix, short_id, e + ), + ), + Err(_) => IncomingMessage::new( + "a2a_bridge", + "system", + format!( + "{} Query timed out (task: {}) — waited {}s", + message_prefix, + short_id, + task_timeout.as_secs() + ), + ), + }; + + if inject_tx.send(msg).await.is_err() { + tracing::debug!( + task_id = %short_id, + "A2A inject channel closed, result dropped" + ); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn validate_url_rejects_localhost() { + assert!(validate_agent_url("http://localhost:5085").await.is_err()); + assert!( + validate_agent_url("https://app.localhost:5085/a2a") + .await + .is_err() + ); + } + + #[tokio::test] + async fn 
validate_url_rejects_private_ips() { + assert!(validate_agent_url("http://192.168.1.1:5085").await.is_err()); + assert!(validate_agent_url("http://10.0.0.1:5085").await.is_err()); + assert!(validate_agent_url("http://172.16.0.1:5085").await.is_err()); + assert!(validate_agent_url("http://127.0.0.1:5085").await.is_err()); + } + + #[tokio::test] + async fn validate_url_rejects_aws_metadata() { + assert!( + validate_agent_url("http://169.254.169.254/latest/meta-data") + .await + .is_err() + ); + } + + #[tokio::test] + async fn validate_url_rejects_bad_scheme() { + assert!(validate_agent_url("ftp://example.com/a2a").await.is_err()); + assert!(validate_agent_url("file:///etc/passwd").await.is_err()); + } + + #[tokio::test] + async fn validate_url_accepts_public_https() { + assert!( + validate_agent_url("https://api.example.com:5085") + .await + .is_ok() + ); + } + + #[tokio::test] + async fn validate_url_accepts_public_http() { + // HTTP is allowed (operator's choice), unlike HttpTool which requires HTTPS + assert!( + validate_agent_url("http://api.example.com:5085") + .await + .is_ok() + ); + } + + #[test] + fn is_disallowed_ip_checks() { + assert!(is_disallowed_ip(&"127.0.0.1".parse().unwrap())); + assert!(is_disallowed_ip(&"10.0.0.1".parse().unwrap())); + assert!(is_disallowed_ip(&"192.168.0.1".parse().unwrap())); + assert!(is_disallowed_ip(&"172.16.0.1".parse().unwrap())); + assert!(is_disallowed_ip(&"169.254.169.254".parse().unwrap())); + assert!(is_disallowed_ip(&"::1".parse().unwrap())); + assert!(!is_disallowed_ip(&"8.8.8.8".parse().unwrap())); + } + + #[tokio::test] + async fn schema_has_required_query() { + let config = A2aConfig { + enabled: true, + agent_url: "https://example.com".to_string(), + assistant_id: "test".to_string(), + tool_name: "a2a_query".to_string(), + tool_description: "test tool".to_string(), + message_prefix: "[a2a]".to_string(), + request_timeout: Duration::from_secs(60), + task_timeout: Duration::from_secs(1200), + api_key_secret: 
"key".to_string(), + }; + let (tx, _rx) = mpsc::channel(1); + let store: Arc = + Arc::new(crate::testing::credentials::test_secrets_store()); + let tool = A2aBridgeTool::new(config, store, tx).await.unwrap(); + let schema = tool.parameters_schema(); + let required = schema["required"].as_array().unwrap(); + assert!(required.contains(&serde_json::json!("query"))); + } + + #[tokio::test] + async fn tool_name_from_config() { + let config = A2aConfig { + enabled: true, + agent_url: "https://example.com".to_string(), + assistant_id: "test".to_string(), + tool_name: "my_custom_tool".to_string(), + tool_description: "custom description".to_string(), + message_prefix: "[custom]".to_string(), + request_timeout: Duration::from_secs(60), + task_timeout: Duration::from_secs(1200), + api_key_secret: "".to_string(), + }; + let (tx, _rx) = mpsc::channel(1); + let store: Arc = + Arc::new(crate::testing::credentials::test_secrets_store()); + let tool = A2aBridgeTool::new(config, store, tx).await.unwrap(); + assert_eq!(tool.name(), "my_custom_tool"); + assert_eq!(tool.description(), "custom description"); + } + + #[tokio::test] + async fn requires_always_approval() { + let config = A2aConfig { + enabled: true, + agent_url: "https://example.com".to_string(), + assistant_id: "test".to_string(), + tool_name: "a2a_query".to_string(), + tool_description: "test".to_string(), + message_prefix: "[a2a]".to_string(), + request_timeout: Duration::from_secs(60), + task_timeout: Duration::from_secs(1200), + api_key_secret: "".to_string(), + }; + let (tx, _rx) = mpsc::channel(1); + let store: Arc = + Arc::new(crate::testing::credentials::test_secrets_store()); + let tool = A2aBridgeTool::new(config, store, tx).await.unwrap(); + assert_eq!( + tool.requires_approval(&serde_json::json!({})), + ApprovalRequirement::Always + ); + } +} diff --git a/src/tools/builtin/a2a/mod.rs b/src/tools/builtin/a2a/mod.rs new file mode 100644 index 0000000000..98536fe753 --- /dev/null +++ 
b/src/tools/builtin/a2a/mod.rs @@ -0,0 +1,11 @@ +//! A2A (Agent-to-Agent) protocol bridge tool. +//! +//! Connects to remote agents via the Google A2A protocol (JSON-RPC 2.0 + SSE +//! streaming). The tool sends a query, reads the first SSE event to determine +//! if the result is immediate or async, and spawns a background consumer for +//! long-running tasks that pushes results back via `inject_tx`. + +mod bridge; +pub(crate) mod protocol; + +pub use bridge::A2aBridgeTool; diff --git a/src/tools/builtin/a2a/protocol.rs b/src/tools/builtin/a2a/protocol.rs new file mode 100644 index 0000000000..3120f68cff --- /dev/null +++ b/src/tools/builtin/a2a/protocol.rs @@ -0,0 +1,628 @@ +//! A2A (Agent-to-Agent) protocol parsing: SSE stream events and JSON-RPC 2.0. +//! +//! This module is intentionally generic — it handles only protocol-level +//! concerns (SSE framing, JSON-RPC envelope, event classification) and can +//! be reused by any A2A integration. + +/// A parsed SSE event from an A2A stream. +#[derive(Debug, Clone)] +pub(crate) struct A2aStreamEvent { + /// The SSE `event:` field (e.g. `"message"`). `None` for unnamed events. + /// Retained for future event-type filtering (e.g. skipping non-message events). + #[allow(dead_code)] + pub event_type: Option, + /// Parsed JSON from the `data:` field. + pub raw: serde_json::Value, +} + +/// Classification of an A2A stream event. +#[derive(Debug)] +pub(crate) enum EventKind { + /// JSON-RPC-level error (top-level `error` field) or result-level error. + Error(String), + /// Final result available (task completed synchronously or stream finished). + Final(serde_json::Value), + /// Task is in progress; contains the task ID and optional context ID. + InProgress { + task_id: String, + context_id: Option, + }, +} + +/// Build an A2A JSON-RPC 2.0 request body for `message/send`. +/// +/// Part format follows A2A v1.0: `{"text": "..."}` without `kind` discriminator. 
+pub(crate) fn build_jsonrpc_request( + query: &str, + context: Option<&serde_json::Value>, + thread_id: Option<&str>, +) -> serde_json::Value { + // v1.0: Parts use flat type fields, no "kind" discriminator + let mut parts = vec![serde_json::json!({ "text": query })]; + + if let Some(ctx) = context { + parts.push(serde_json::json!({ "data": ctx })); + } + + let msg_id = format!("msg-{}", uuid::Uuid::new_v4()); + let mut params = serde_json::json!({ + "message": { + "role": "user", + "parts": parts, + "messageId": msg_id, + } + }); + + if let Some(tid) = thread_id { + params["thread"] = serde_json::json!({ "threadId": tid }); + } + + serde_json::json!({ + "jsonrpc": "2.0", + "id": uuid::Uuid::new_v4().to_string(), + "method": "message/send", + "params": params, + }) +} + +/// Parse complete SSE events from a buffer, consuming processed bytes. +/// +/// Follows the SSE specification: +/// - Events are delimited by blank lines (`\n\n`) +/// - Multi-line `data:` fields are concatenated with `\n` +/// - `event:` lines set the event type +/// - Lines starting with `:` are comments (ignored) +/// - `\r\n` and `\r` line endings are normalized +pub(crate) fn parse_sse_events(buffer: &mut String) -> Vec { + // Normalize line endings: \r\n → \n, bare \r → \n + if buffer.contains('\r') { + *buffer = buffer.replace("\r\n", "\n").replace('\r', "\n"); + } + + let mut events = Vec::new(); + + // Process complete event blocks (terminated by \n\n) + while let Some(boundary) = buffer.find("\n\n") { + let block = buffer[..boundary].to_string(); + // Remove the block + both newlines from the buffer + *buffer = buffer[boundary + 2..].to_string(); + + if block.is_empty() { + continue; + } + + let mut event_type: Option = None; + let mut data_lines: Vec<&str> = Vec::new(); + + for line in block.lines() { + if line.starts_with(':') { + // Comment line — skip + continue; + } + + if let Some(value) = line.strip_prefix("event:") { + event_type = Some(value.trim().to_string()); + } else if 
let Some(value) = line.strip_prefix("data:") { + let trimmed = value.strip_prefix(' ').unwrap_or(value); + data_lines.push(trimmed); + } + // Ignore `id:`, `retry:`, and unknown fields per SSE spec + } + + if data_lines.is_empty() { + continue; + } + + let data_str = data_lines.join("\n"); + if data_str.is_empty() { + continue; + } + + if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&data_str) { + events.push(A2aStreamEvent { + event_type, + raw: parsed, + }); + } else { + tracing::debug!( + data = %data_str, + "A2A SSE: skipping non-JSON data block" + ); + } + } + + events +} + +/// Classify an A2A stream event into an actionable kind. +/// +/// Checks for JSON-RPC-level errors first (top-level `error` field), then +/// result-level errors, then final/in-progress status. +pub(crate) fn classify_event(event: &A2aStreamEvent) -> EventKind { + // C3: Check top-level JSON-RPC error field first + if let Some(error) = event.raw.get("error") { + let msg = error + .get("message") + .and_then(|m| m.as_str()) + .unwrap_or("unknown JSON-RPC error"); + let code = error + .get("code") + .and_then(|c| c.as_i64()) + .map(|c| format!(" (code: {})", c)) + .unwrap_or_default(); + return EventKind::Error(format!("{}{}", msg, code)); + } + + let result = match event.raw.get("result") { + Some(r) => r, + None => return EventKind::Error("A2A event missing 'result' field".to_string()), + }; + + // Check for result-level error (v0.x used `kind: "error"`, check both) + if result.get("kind").and_then(|k| k.as_str()) == Some("error") { + let msg = result + .get("error") + .and_then(|e| e.get("message")) + .and_then(|m| m.as_str()) + .or_else(|| result.get("message").and_then(|m| m.as_str())) + .unwrap_or("unknown error"); + return EventKind::Error(msg.to_string()); + } + + // v1.0: check status.state for terminal states (replaces deprecated `final` field) + let state = result + .get("status") + .and_then(|s| s.get("state")) + .and_then(|s| s.as_str()); + + match state { + Some("completed") => 
return EventKind::Final(result.clone()), + Some("failed") | Some("canceled") | Some("rejected") => { + let msg = result + .get("status") + .and_then(|s| s.get("message")) + .and_then(|m| m.get("parts")) + .and_then(|p| p.as_array()) + .and_then(|parts| parts.first()) + .and_then(|part| part.get("text")) + .and_then(|t| t.as_str()) + .unwrap_or("task failed"); + return EventKind::Error(format!("A2A task {}: {}", state.unwrap_or("failed"), msg)); + } + _ => {} + } + + // Backward compat: v0.x `final` field (remove once all servers upgrade) + if result.get("final").and_then(|f| f.as_bool()) == Some(true) { + return EventKind::Final(result.clone()); + } + + // Extract task ID and context ID for in-progress events + let task_id = result + .get("id") + .and_then(|id| id.as_str()) + .unwrap_or("unknown") + .to_string(); + + let context_id = result + .get("contextId") + .and_then(|c| c.as_str()) + .map(|s| s.to_string()); + + EventKind::InProgress { + task_id, + context_id, + } +} + +/// Extract text content from an A2A result's message parts. +/// +/// Tries `status.message.parts[].text`, then `message.parts[].text`, +/// then falls back to pretty-printing the entire result. 
+pub(crate) fn extract_text_from_result(result: &serde_json::Value, max_len: usize) -> String { + let text = extract_text_parts(result.get("status").and_then(|s| s.get("message"))) + .or_else(|| extract_text_parts(result.get("message"))) + // LangGraph: extract from the last agent message in history[] + .or_else(|| { + result + .get("history") + .and_then(|h| h.as_array()) + .and_then(|arr| { + arr.iter() + .rev() + .find(|msg| msg.get("role").and_then(|r| r.as_str()) == Some("agent")) + }) + .and_then(|msg| extract_text_parts(Some(msg))) + }) + // LangGraph: extract from artifacts[].parts[].text + .or_else(|| { + result + .get("artifacts") + .and_then(|a| a.as_array()) + .and_then(|arr| arr.first()) + .and_then(|artifact| extract_text_parts(Some(artifact))) + }) + .unwrap_or_else(|| serde_json::to_string_pretty(result).unwrap_or_default()); + + truncate_str(&text, max_len) +} + +/// Check if an A2A result has non-empty text parts in `status.message.parts`. +pub(crate) fn result_has_text_parts(result: &serde_json::Value) -> bool { + result + .get("status") + .and_then(|s| s.get("message")) + .and_then(|m| m.get("parts")) + .and_then(|p| p.as_array()) + .is_some_and(|parts| { + parts.iter().any(|part| { + part.get("text") + .and_then(|t| t.as_str()) + .is_some_and(|s| !s.is_empty()) + }) + }) +} + +/// Check if a raw event JSON contains meaningful message text at `result.status.message.parts`. +pub(crate) fn has_message_content(raw: &serde_json::Value) -> bool { + raw.get("result").is_some_and(result_has_text_parts) +} + +/// Extract and join text parts from a message object (`{ "parts": [{"text": ...}] }`). 
+fn extract_text_parts(message: Option<&serde_json::Value>) -> Option<String> { + message + .and_then(|m| m.get("parts")) + .and_then(|p| p.as_array()) + .map(|parts| { + parts + .iter() + .filter_map(|part| part.get("text").and_then(|t| t.as_str())) + .collect::<Vec<_>>() + .join("\n") + }) + .filter(|s| !s.is_empty()) +} + +/// Truncate a string to `max_len` bytes, respecting UTF-8 char boundaries. +pub(crate) fn truncate_str(s: &str, max_len: usize) -> String { + if s.len() <= max_len { + return s.to_string(); + } + let mut end = max_len; + while end > 0 && !s.is_char_boundary(end) { + end -= 1; + } + format!("{}...", &s[..end]) +} + +#[cfg(test)] +mod tests { + use super::*; + + // ── build_jsonrpc_request ─────────────────────────────────────── + + #[test] + fn build_request_basic() { + let req = build_jsonrpc_request("hello", None, None); + assert_eq!(req["method"], "message/send"); + assert_eq!(req["params"]["message"]["role"], "user"); + let parts = req["params"]["message"]["parts"].as_array().unwrap(); + assert_eq!(parts.len(), 1); + assert_eq!(parts[0]["text"], "hello"); + } + + #[test] + fn build_request_with_context_and_thread() { + let ctx = serde_json::json!({"key": "value"}); + let req = build_jsonrpc_request("query", Some(&ctx), Some("thread-42")); + let parts = req["params"]["message"]["parts"].as_array().unwrap(); + assert_eq!(parts.len(), 2); + // v1.0: no "kind" discriminator, data part has direct "data" field + assert_eq!(parts[1]["data"]["key"], "value"); + assert_eq!(req["params"]["thread"]["threadId"], "thread-42"); + } + + // ── parse_sse_events ──────────────────────────────────────────── + + #[test] + fn parse_single_event() { + let mut buf = "data: {\"result\":{\"id\":\"t1\"}}\n\n".to_string(); + let events = parse_sse_events(&mut buf); + assert_eq!(events.len(), 1); + assert_eq!(events[0].raw["result"]["id"], "t1"); + assert!(buf.is_empty()); + } + + #[test] + fn parse_multiple_events() { + let mut buf = "data: {\"a\":1}\n\ndata: {\"b\":2}\n\ndata: 
{\"c\":3}\n\n".to_string(); + let events = parse_sse_events(&mut buf); + assert_eq!(events.len(), 3); + } + + #[test] + fn parse_incomplete_event_stays_in_buffer() { + let mut buf = "data: {\"partial\":true}".to_string(); + let events = parse_sse_events(&mut buf); + assert!(events.is_empty()); + assert!(!buf.is_empty()); // data remains + } + + #[test] + fn parse_multiline_data() { + let mut buf = "data: {\"multi\":\n\ndata: true}\n\n".to_string(); + // First block: "data: {\"multi\":" — incomplete JSON, will be skipped + // Second block: "data: true}" — also not valid JSON + let events = parse_sse_events(&mut buf); + // Both blocks produce invalid JSON, so no events parsed + assert_eq!(events.len(), 0); + } + + #[test] + fn parse_multiline_data_concatenation() { + // Two data: lines in the same event block should be concatenated + let mut buf = "data: {\"key\":\ndata: \"value\"}\n\n".to_string(); + let events = parse_sse_events(&mut buf); + assert_eq!(events.len(), 1); + assert_eq!(events[0].raw["key"], "value"); + } + + #[test] + fn parse_event_with_type() { + let mut buf = "event: message\ndata: {\"ok\":true}\n\n".to_string(); + let events = parse_sse_events(&mut buf); + assert_eq!(events.len(), 1); + assert_eq!(events[0].event_type.as_deref(), Some("message")); + } + + #[test] + fn parse_comment_lines_ignored() { + let mut buf = ": keep-alive\ndata: {\"ok\":true}\n\n".to_string(); + let events = parse_sse_events(&mut buf); + assert_eq!(events.len(), 1); + } + + #[test] + fn parse_crlf_line_endings() { + let mut buf = "data: {\"ok\":true}\r\n\r\n".to_string(); + let events = parse_sse_events(&mut buf); + assert_eq!(events.len(), 1); + } + + #[test] + fn parse_bare_cr_line_endings() { + let mut buf = "data: {\"ok\":true}\r\r".to_string(); + let events = parse_sse_events(&mut buf); + assert_eq!(events.len(), 1); + } + + #[test] + fn parse_empty_data_skipped() { + let mut buf = "data: \n\n".to_string(); + let events = parse_sse_events(&mut buf); + 
assert!(events.is_empty()); + } + + // ── classify_event ────────────────────────────────────────────── + + #[test] + fn classify_jsonrpc_error() { + let event = A2aStreamEvent { + event_type: None, + raw: serde_json::json!({ + "jsonrpc": "2.0", + "error": {"code": -32600, "message": "Invalid Request"} + }), + }; + match classify_event(&event) { + EventKind::Error(msg) => assert!(msg.contains("Invalid Request")), + _ => panic!("expected Error"), + } + } + + #[test] + fn classify_result_level_error() { + let event = A2aStreamEvent { + event_type: None, + raw: serde_json::json!({ + "result": { + "kind": "error", + "error": {"message": "rate limited"} + } + }), + }; + match classify_event(&event) { + EventKind::Error(msg) => assert_eq!(msg, "rate limited"), + _ => panic!("expected Error"), + } + } + + #[test] + fn classify_final_event_v1_status_state() { + // v1.0: uses status.state = "completed", no `final` field + let event = A2aStreamEvent { + event_type: None, + raw: serde_json::json!({ + "result": { + "status": { + "state": "completed", + "message": {"parts": [{"text": "done"}]} + } + } + }), + }; + match classify_event(&event) { + EventKind::Final(result) => assert_eq!(result["status"]["state"], "completed"), + _ => panic!("expected Final"), + } + } + + #[test] + fn classify_final_event_legacy_final_field() { + // Backward compat: v0.x `final: true` + let event = A2aStreamEvent { + event_type: None, + raw: serde_json::json!({ + "result": { + "final": true, + "status": {"state": "completed"} + } + }), + }; + match classify_event(&event) { + EventKind::Final(_) => {} + _ => panic!("expected Final"), + } + } + + #[test] + fn classify_failed_task() { + let event = A2aStreamEvent { + event_type: None, + raw: serde_json::json!({ + "result": { + "status": { + "state": "failed", + "message": {"parts": [{"text": "out of memory"}]} + } + } + }), + }; + match classify_event(&event) { + EventKind::Error(msg) => { + assert!(msg.contains("failed")); + 
assert!(msg.contains("out of memory")); + } + _ => panic!("expected Error"), + } + } + + #[test] + fn classify_canceled_task() { + let event = A2aStreamEvent { + event_type: None, + raw: serde_json::json!({ + "result": { + "status": {"state": "canceled"} + } + }), + }; + match classify_event(&event) { + EventKind::Error(msg) => assert!(msg.contains("canceled")), + _ => panic!("expected Error"), + } + } + + #[test] + fn classify_in_progress_with_context_id() { + let event = A2aStreamEvent { + event_type: None, + raw: serde_json::json!({ + "result": { + "id": "task-abc", + "contextId": "ctx-123", + "status": {"state": "working"} + } + }), + }; + match classify_event(&event) { + EventKind::InProgress { + task_id, + context_id, + } => { + assert_eq!(task_id, "task-abc"); + assert_eq!(context_id.as_deref(), Some("ctx-123")); + } + _ => panic!("expected InProgress"), + } + } + + #[test] + fn classify_missing_result_is_error() { + let event = A2aStreamEvent { + event_type: None, + raw: serde_json::json!({"jsonrpc": "2.0", "id": "1"}), + }; + match classify_event(&event) { + EventKind::Error(msg) => assert!(msg.contains("missing 'result'")), + _ => panic!("expected Error"), + } + } + + // ── extract_text_from_result ──────────────────────────────────── + + #[test] + fn extract_text_from_status_message() { + let result = serde_json::json!({ + "status": { + "message": { + "parts": [ + {"text": "line one"}, + {"text": "line two"} + ] + } + } + }); + let text = extract_text_from_result(&result, 2000); + assert_eq!(text, "line one\nline two"); + } + + #[test] + fn extract_text_truncates_at_char_boundary() { + let result = serde_json::json!({ + "status": { + "message": { + "parts": [{"text": "你好世界abcdefghij"}] + } + } + }); + // "你好世界" is 12 bytes in UTF-8 (3 bytes each) + let text = extract_text_from_result(&result, 10); + assert!(text.ends_with("...")); + assert!(text.len() <= 13); // 9 (3 chars) + "..." 
+ } + + #[test] + fn extract_text_from_langgraph_history() { + // LangGraph returns agent messages in `history[]`, not `status.message` + let result = serde_json::json!({ + "id": "task-1", + "contextId": "ctx-1", + "history": [ + {"role": "user", "parts": [{"text": "What is 2+2?"}]}, + {"role": "agent", "parts": [{"text": "4"}]} + ], + "status": {"state": "completed"}, + "artifacts": [{"parts": [{"text": "4"}]}] + }); + let text = extract_text_from_result(&result, 2000); + assert_eq!(text, "4"); + } + + #[test] + fn extract_text_from_langgraph_artifacts() { + // When history has no agent messages, fall back to artifacts + let result = serde_json::json!({ + "artifacts": [ + {"artifactId": "a1", "parts": [{"text": "result content"}]} + ], + "status": {"state": "completed"} + }); + let text = extract_text_from_result(&result, 2000); + assert_eq!(text, "result content"); + } + + // ── truncate_str ──────────────────────────────────────────────── + + #[test] + fn truncate_short_string_unchanged() { + assert_eq!(truncate_str("hello", 10), "hello"); + } + + #[test] + fn truncate_respects_char_boundaries() { + let s = "分析茅台的估值"; + let t = truncate_str(s, 6); // 6 bytes = 2 Chinese chars + assert_eq!(t, "分析..."); + } +} diff --git a/src/tools/builtin/mod.rs b/src/tools/builtin/mod.rs index 8ba8e57b0b..63990c0e95 100644 --- a/src/tools/builtin/mod.rs +++ b/src/tools/builtin/mod.rs @@ -1,5 +1,6 @@ //! Built-in tools that come with the agent. +pub mod a2a; mod echo; pub mod extension_tools; mod file; @@ -17,6 +18,7 @@ pub mod skill_tools; mod time; mod tool_info; +pub use a2a::A2aBridgeTool; pub use echo::EchoTool; pub use extension_tools::{ ExtensionInfoTool, ToolActivateTool, ToolAuthTool, ToolInstallTool, ToolListTool, diff --git a/tests/a2a_bridge_integration.rs b/tests/a2a_bridge_integration.rs new file mode 100644 index 0000000000..866aaab43a --- /dev/null +++ b/tests/a2a_bridge_integration.rs @@ -0,0 +1,128 @@ +//! Integration tests for the A2A bridge tool. +//! +//! 
Construction tests run always. Live tests require a running A2A-compatible +//! agent and are marked `#[ignore]`. +//! Run with: `cargo test --test a2a_bridge_integration -- --ignored` + +use std::sync::Arc; +use std::time::Duration; + +use ironclaw::config::A2aConfig; +use ironclaw::secrets::{InMemorySecretsStore, SecretsCrypto, SecretsStore}; +use ironclaw::tools::builtin::A2aBridgeTool; +use ironclaw::tools::{ApprovalRequirement, Tool, ToolOutput}; +use secrecy::SecretString; +use tokio::sync::mpsc; + +fn test_secrets_store() -> Arc<dyn SecretsStore> { + let key = SecretString::from("test-key-32-bytes-long-enough!!!".to_string()); + let crypto = Arc::new(SecretsCrypto::new(key).expect("test crypto")); + Arc::new(InMemorySecretsStore::new(crypto)) +} + +fn test_config() -> A2aConfig { + A2aConfig { + enabled: true, + agent_url: std::env::var("A2A_AGENT_URL") + .unwrap_or_else(|_| "https://a2a-test.example.com".to_string()), + assistant_id: std::env::var("A2A_ASSISTANT_ID") + .unwrap_or_else(|_| "test-assistant".to_string()), + tool_name: "a2a_test".to_string(), + tool_description: "Test A2A bridge".to_string(), + message_prefix: "[test]".to_string(), + request_timeout: Duration::from_secs(30), + task_timeout: Duration::from_secs(120), + api_key_secret: "a2a_test_key".to_string(), + } +} + +async fn create_tool(config: A2aConfig) -> Result<A2aBridgeTool, anyhow::Error> { + let (tx, _rx) = mpsc::channel(10); + A2aBridgeTool::new(config, test_secrets_store(), tx).await +} + +// ── Construction tests (run always) ──────────────────────────────── + +#[tokio::test] +async fn construction_rejects_localhost() { + let mut config = test_config(); + config.agent_url = "http://localhost:5085".to_string(); + assert!(create_tool(config).await.is_err()); +} + +#[tokio::test] +async fn construction_rejects_private_ip() { + let mut config = test_config(); + config.agent_url = "http://192.168.1.100:5085".to_string(); + assert!(create_tool(config).await.is_err()); +} + +#[tokio::test] +async fn construction_rejects_link_local() 
{ + let mut config = test_config(); + config.agent_url = "http://169.254.169.254/latest".to_string(); + assert!(create_tool(config).await.is_err()); +} + +#[tokio::test] +async fn construction_accepts_public_url() { + let config = test_config(); + assert!(create_tool(config).await.is_ok()); +} + +#[tokio::test] +async fn tool_uses_configured_name() { + let mut config = test_config(); + config.tool_name = "custom_a2a".to_string(); + let tool = create_tool(config).await.unwrap(); + assert_eq!(tool.name(), "custom_a2a"); +} + +#[tokio::test] +async fn tool_requires_always_approval() { + let config = test_config(); + let tool = create_tool(config).await.unwrap(); + assert_eq!( + tool.requires_approval(&serde_json::json!({})), + ApprovalRequirement::Always, + ); +} + +// ── Live agent tests (require A2A_AGENT_URL) ─────────────────────── + +#[tokio::test] +#[ignore = "requires running A2A agent (set A2A_AGENT_URL)"] +async fn live_query_returns_result() { + let config = test_config(); + let (tx, mut rx) = mpsc::channel(10); + let tool = A2aBridgeTool::new(config, test_secrets_store(), tx) + .await + .unwrap(); + + let ctx = ironclaw::context::JobContext::default(); + let params = serde_json::json!({ "query": "What is 2+2?" }); + + let result = tool.execute(params, &ctx).await; + assert!(result.is_ok(), "execute failed: {:?}", result.err()); + + let output: ToolOutput = result.unwrap(); + let status = output.result["status"].as_str().unwrap(); + assert!( + status == "completed" || status == "submitted", + "unexpected status: {}", + status + ); + + // If submitted, wait for the background consumer to push a result + if status == "submitted" { + let msg = tokio::time::timeout(Duration::from_secs(120), rx.recv()) + .await + .expect("timed out waiting for background result") + .expect("channel closed"); + assert!( + msg.content.contains("[test]"), + "expected message_prefix in: {}", + msg.content + ); + } +}