From f2aa43891c3c6cbd828473140643b3a92075a112 Mon Sep 17 00:00:00 2001 From: Juan Antonio Osorio Date: Fri, 17 Apr 2026 09:07:24 +0300 Subject: [PATCH] Add Origin header validation middleware MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ToolHive's proxy layer had no Origin-header validation, and the legacy HTTP+SSE transport sent `Access-Control-Allow-Origin: *`, leaving both modes open to DNS-rebinding attacks from browser clients. MCP 2025-11-25 §"Security Warning" requires servers to validate Origin on all connections and respond 403 when the value is invalid. This change introduces a dedicated middleware at pkg/transport/middleware/origin/ that rejects requests whose Origin header is present and not in an operator-configured allowlist, and wires it into both the factory-based chain (thv run / thv-proxyrunner / vMCP) and the inline chain (thv proxy). Behavior: - New --allowed-origins flag on `thv run` and `thv proxy` accepts a repeatable exact-match list. When empty and the bind host is loopback, a default loopback-only allowlist is derived automatically (http://localhost:PORT + 127.0.0.1 + [::1]). When empty and the bind is non-loopback, the middleware is skipped and a warning is logged — the bind-opt-in hardening lands in a follow-up. - Matching is byte-exact except that scheme and host are lowercased per RFC 6454 §4. Requests with multiple Origin headers are rejected outright. - 403 responses carry a JSON-RPC error body (id: null, code -32600, message "Origin not allowed"). - `Access-Control-Allow-Origin: *` removed from the httpsse SSE handler; the wildcard would have neutered any Origin enforcement via preflight response inheritance. Closes audit row 5 (Origin validation absent). Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Juan Antonio Osorio --- cmd/thv/app/proxy.go | 28 +- cmd/thv/app/run_flags.go | 11 + docs/cli/thv_proxy.md | 1 + docs/cli/thv_run.md | 1 + docs/server/docs.go | 8 + docs/server/swagger.json | 8 + docs/server/swagger.yaml | 12 + pkg/runner/config.go | 8 + pkg/runner/config_builder.go | 12 + pkg/runner/middleware.go | 46 ++- pkg/transport/middleware/origin/origin.go | 266 +++++++++++++ .../middleware/origin/origin_test.go | 352 ++++++++++++++++++ pkg/transport/proxy/httpsse/http_proxy.go | 5 +- 13 files changed, 751 insertions(+), 7 deletions(-) create mode 100644 pkg/transport/middleware/origin/origin.go create mode 100644 pkg/transport/middleware/origin/origin_test.go diff --git a/cmd/thv/app/proxy.go b/cmd/thv/app/proxy.go index de33227bcf..b291d6f588 100644 --- a/cmd/thv/app/proxy.go +++ b/cmd/thv/app/proxy.go @@ -24,6 +24,7 @@ import ( "github.com/stacklok/toolhive/pkg/networking" "github.com/stacklok/toolhive/pkg/transport" "github.com/stacklok/toolhive/pkg/transport/middleware" + "github.com/stacklok/toolhive/pkg/transport/middleware/origin" "github.com/stacklok/toolhive/pkg/transport/proxy/transparent" "github.com/stacklok/toolhive/pkg/transport/types" ) @@ -110,9 +111,10 @@ Dynamic client registration (automatic OAuth client setup): } var ( - proxyHost string - proxyPort int - proxyTargetURI string + proxyHost string + proxyPort int + proxyTargetURI string + proxyAllowedOrigins []string resourceURL string // Explicit resource URL for OAuth discovery endpoint (RFC 9728) @@ -133,6 +135,10 @@ const ( func init() { proxyCmd.Flags().StringVar(&proxyHost, "host", transport.LocalhostIPv4, "Host for the HTTP proxy to listen on (IP or hostname)") proxyCmd.Flags().IntVar(&proxyPort, "port", 0, "Port for the HTTP proxy to listen on (host port)") + proxyCmd.Flags().StringArrayVar(&proxyAllowedOrigins, "allowed-origins", nil, + "Exact-match allowlist for the HTTP Origin header (repeatable). Recommended when binding publicly; "+ + "loopback binds derive a default allowlist automatically, non-loopback binds log a warning when "+ + "no value is supplied. Example: https://my-mcp.example.com") proxyCmd.Flags().StringVar( &proxyTargetURI, "target-uri", @@ -226,6 +232,22 @@ func proxyCmdFunc(cmd *cobra.Command, args []string) error { // Create middlewares slice for incoming request authentication var middlewares []types.NamedMiddleware + // Origin-header validation (DNS-rebinding protection per MCP 2025-11-25 + // §"Security Warning"). Added first so disallowed Origins are rejected + // before authentication or any outbound token acquisition runs. + if allowed := origin.ResolveAllowedOrigins(proxyHost, port, proxyAllowedOrigins); len(allowed) > 0 { + middlewares = append(middlewares, types.NamedMiddleware{ + Name: origin.MiddlewareType, + Function: origin.CreateOriginMiddleware(allowed), + }) + } else { + slog.Warn("Origin validation disabled — no allowlist configured for non-loopback bind", + "host", proxyHost, + "port", port, + "hint", "pass --allowed-origins=https://your-client.example to enable DNS-rebind protection", + ) + } + // Get OIDC configuration if enabled (for protecting the proxy endpoint) oidcConfig := getProxyOIDCConfig(cmd) diff --git a/cmd/thv/app/run_flags.go b/cmd/thv/app/run_flags.go index 94c159af8a..f9e35275e6 100644 --- a/cmd/thv/app/run_flags.go +++ b/cmd/thv/app/run_flags.go @@ -137,6 +137,12 @@ type RunFlags struct { RemoteForwardHeaders []string RemoteForwardHeadersSecret []string + // AllowedOrigins is the HTTP Origin-header allowlist for DNS-rebinding protection + // (MCP 2025-11-25 §"Security Warning"). Empty with a loopback host auto-derives + // loopback-only defaults; empty with a non-loopback host disables the check + // (operator must supply explicit origins for public bind). + AllowedOrigins []string + // Runtime configuration RuntimeImage string RuntimeAddPackages []string @@ -156,6 +162,10 @@ func AddRunFlags(cmd *cobra.Command, config *RunFlags) { cmd.Flags().StringVar(&config.Name, "name", "", "Name of the MCP server (default to auto-generated from image)") cmd.Flags().StringVar(&config.Group, "group", "default", "Name of the group this workload should belong to") cmd.Flags().StringVar(&config.Host, "host", transport.LocalhostIPv4, "Host for the HTTP proxy to listen on (IP or hostname)") + cmd.Flags().StringArrayVar(&config.AllowedOrigins, "allowed-origins", nil, + "Exact-match allowlist for the HTTP Origin header (repeatable). Recommended when binding publicly; "+ + "loopback binds derive a default allowlist automatically, non-loopback binds log a warning when "+ + "no value is supplied. Example: https://my-mcp.example.com") cmd.Flags().IntVar(&config.ProxyPort, "proxy-port", 0, "Port for the HTTP proxy to listen on (host port)") cmd.Flags().IntVar(&config.TargetPort, "target-port", 0, "Port for the container to expose (only applicable to SSE or Streamable HTTP transport)") @@ -678,6 +688,7 @@ func buildRunnerConfig( PrintOverlays: runFlags.PrintOverlays, }), runner.WithPublish(runFlags.Publish), + runner.WithAllowedOrigins(runFlags.AllowedOrigins), } opts = append(opts, extraOpts...) diff --git a/docs/cli/thv_proxy.md b/docs/cli/thv_proxy.md index be2e8d92d2..6cbc09c22e 100644 --- a/docs/cli/thv_proxy.md +++ b/docs/cli/thv_proxy.md @@ -97,6 +97,7 @@ thv proxy [flags] SERVER_NAME ### Options ``` + --allowed-origins stringArray Exact-match allowlist for the HTTP Origin header (repeatable). Recommended when binding publicly; loopback binds derive a default allowlist automatically, non-loopback binds log a warning when no value is supplied. Example: https://my-mcp.example.com -h, --help help for proxy --host string Host for the HTTP proxy to listen on (IP or hostname) (default "127.0.0.1") --oidc-audience string Expected audience for the token diff --git a/docs/cli/thv_run.md b/docs/cli/thv_run.md index d3c38dc4cb..e540964498 100644 --- a/docs/cli/thv_run.md +++ b/docs/cli/thv_run.md @@ -112,6 +112,7 @@ thv run [flags] SERVER_OR_IMAGE_OR_PROTOCOL [-- ARGS...] ``` --allow-docker-gateway Allow outbound connections to Docker gateway addresses (host.docker.internal, gateway.docker.internal, 172.17.0.1). Only applies when --isolate-network is set. These are blocked by default even when insecure_allow_all is enabled. + --allowed-origins stringArray Exact-match allowlist for the HTTP Origin header (repeatable). Recommended when binding publicly; loopback binds derive a default allowlist automatically, non-loopback binds log a warning when no value is supplied. Example: https://my-mcp.example.com --audit-config string Path to the audit configuration file --authz-config string Path to the authorization configuration file --ca-cert string Path to a custom CA certificate file to use for container builds diff --git a/docs/server/docs.go b/docs/server/docs.go index 88e07fcd08..76f717aacb 100644 --- a/docs/server/docs.go +++ b/docs/server/docs.go @@ -1131,6 +1131,14 @@ const docTemplate = `{ "description": "AllowDockerGateway permits outbound connections to Docker gateway addresses\n(host.docker.internal, gateway.docker.internal, 172.17.0.1). These are\nblocked by default in the egress proxy even when InsecureAllowAll is set.\nOnly applicable to Docker deployments with network isolation enabled.", "type": "boolean" }, + "allowed_origins": { + "description": "AllowedOrigins is the allowlist of values accepted on the HTTP Origin header,\nused for DNS-rebinding protection per MCP 2025-11-25 §\"Security Warning\".\nWhen empty and Host is loopback (127.0.0.1 / localhost / [::1]), a default\nloopback-only allowlist is derived at middleware-wiring time.\nWhen empty and Host is non-loopback, the middleware is disabled — operators\nexposing the proxy publicly must configure an explicit allowlist.", + "items": { + "type": "string" + }, + "type": "array", + "uniqueItems": false + }, "audit_config": { "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_audit.Config" }, diff --git a/docs/server/swagger.json b/docs/server/swagger.json index be4022ac58..1b6d33ccf5 100644 --- a/docs/server/swagger.json +++ b/docs/server/swagger.json @@ -1124,6 +1124,14 @@ "description": "AllowDockerGateway permits outbound connections to Docker gateway addresses\n(host.docker.internal, gateway.docker.internal, 172.17.0.1). These are\nblocked by default in the egress proxy even when InsecureAllowAll is set.\nOnly applicable to Docker deployments with network isolation enabled.", "type": "boolean" }, + "allowed_origins": { + "description": "AllowedOrigins is the allowlist of values accepted on the HTTP Origin header,\nused for DNS-rebinding protection per MCP 2025-11-25 §\"Security Warning\".\nWhen empty and Host is loopback (127.0.0.1 / localhost / [::1]), a default\nloopback-only allowlist is derived at middleware-wiring time.\nWhen empty and Host is non-loopback, the middleware is disabled — operators\nexposing the proxy publicly must configure an explicit allowlist.", + "items": { + "type": "string" + }, + "type": "array", + "uniqueItems": false + }, "audit_config": { "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_audit.Config" }, diff --git a/docs/server/swagger.yaml b/docs/server/swagger.yaml index 5fa8ca2e2c..7a2539a020 100644 --- a/docs/server/swagger.yaml +++ b/docs/server/swagger.yaml @@ -1091,6 +1091,18 @@ components: blocked by default in the egress proxy even when InsecureAllowAll is set. Only applicable to Docker deployments with network isolation enabled. type: boolean + allowed_origins: + description: |- + AllowedOrigins is the allowlist of values accepted on the HTTP Origin header, + used for DNS-rebinding protection per MCP 2025-11-25 §"Security Warning". + When empty and Host is loopback (127.0.0.1 / localhost / [::1]), a default + loopback-only allowlist is derived at middleware-wiring time. + When empty and Host is non-loopback, the middleware is disabled — operators + exposing the proxy publicly must configure an explicit allowlist. + items: + type: string + type: array + uniqueItems: false audit_config: $ref: '#/components/schemas/github_com_stacklok_toolhive_pkg_audit.Config' audit_config_path: diff --git a/pkg/runner/config.go b/pkg/runner/config.go index 519f277abd..b95ede0059 100644 --- a/pkg/runner/config.go +++ b/pkg/runner/config.go @@ -97,6 +97,14 @@ type RunConfig struct { // TargetHost is the host to forward traffic to (only applicable to SSE transport) TargetHost string `json:"target_host,omitempty" yaml:"target_host,omitempty"` + // AllowedOrigins is the allowlist of values accepted on the HTTP Origin header, + // used for DNS-rebinding protection per MCP 2025-11-25 §"Security Warning". + // When empty and Host is loopback (127.0.0.1 / localhost / [::1]), a default + // loopback-only allowlist is derived at middleware-wiring time. + // When empty and Host is non-loopback, the middleware is disabled — operators + // exposing the proxy publicly must configure an explicit allowlist. + AllowedOrigins []string `json:"allowed_origins,omitempty" yaml:"allowed_origins,omitempty"` + // Publish lists ports to publish to the host in format "hostPort:containerPort" Publish []string `json:"publish,omitempty" yaml:"publish,omitempty"` diff --git a/pkg/runner/config_builder.go b/pkg/runner/config_builder.go index 89cdf2c318..86c81d555c 100644 --- a/pkg/runner/config_builder.go +++ b/pkg/runner/config_builder.go @@ -330,6 +330,18 @@ func WithAllowDockerGateway(allow bool) RunConfigBuilderOption { } } +// WithAllowedOrigins sets the HTTP Origin-header allowlist used for +// DNS-rebinding protection (MCP 2025-11-25 §"Security Warning"). +// An empty slice defers the choice to middleware wiring, which derives a +// loopback-only default when the bind host is loopback and otherwise leaves +// the middleware disabled. +func WithAllowedOrigins(origins []string) RunConfigBuilderOption { + return func(b *runConfigBuilder) error { + b.config.AllowedOrigins = origins + return nil + } +} + // WithTrustProxyHeaders sets whether to trust X-Forwarded-* headers from reverse proxies func WithTrustProxyHeaders(trust bool) RunConfigBuilderOption { return func(b *runConfigBuilder) error { diff --git a/pkg/runner/middleware.go b/pkg/runner/middleware.go index be9dd33506..bc56a62104 100644 --- a/pkg/runner/middleware.go +++ b/pkg/runner/middleware.go @@ -5,6 +5,7 @@ package runner import ( "fmt" + "log/slog" "github.com/stacklok/toolhive/pkg/audit" "github.com/stacklok/toolhive/pkg/auth" @@ -20,6 +21,7 @@ import ( "github.com/stacklok/toolhive/pkg/recovery" "github.com/stacklok/toolhive/pkg/telemetry" headerfwd "github.com/stacklok/toolhive/pkg/transport/middleware" + "github.com/stacklok/toolhive/pkg/transport/middleware/origin" "github.com/stacklok/toolhive/pkg/transport/types" "github.com/stacklok/toolhive/pkg/usagemetrics" "github.com/stacklok/toolhive/pkg/webhook/mutating" @@ -43,6 +45,7 @@ func GetSupportedMiddlewareFactories() map[string]types.MiddlewareFactory { audit.MiddlewareType: audit.CreateMiddleware, recovery.MiddlewareType: recovery.CreateMiddleware, headerfwd.HeaderForwardMiddlewareName: headerfwd.CreateMiddleware, + origin.MiddlewareType: origin.CreateMiddleware, validating.MiddlewareType: validating.CreateMiddleware, mutating.MiddlewareType: mutating.CreateMiddleware, } @@ -56,13 +59,21 @@ func PopulateMiddlewareConfigs(config *RunConfig) error { var middlewareConfigs []types.MiddlewareConfig // TODO: Consider extracting other middleware setup into helper functions like addUsageMetricsMiddleware + // Origin-validation middleware (DNS-rebinding protection per MCP 2025-11-25). + // Positioned first in the slice so it runs earliest in the chain — disallowed + // Origin values are rejected before any authentication or business logic. + middlewareConfigs, err := addOriginMiddleware(middlewareConfigs, config) + if err != nil { + return err + } + // Authentication middleware (always present) authParams := auth.MiddlewareParams{ OIDCConfig: config.OIDCConfig, } - authConfig, err := types.NewMiddlewareConfig(auth.MiddlewareType, authParams) - if err != nil { - return fmt.Errorf("failed to create auth middleware config: %w", err) + authConfig, authErr := types.NewMiddlewareConfig(auth.MiddlewareType, authParams) + if authErr != nil { + return fmt.Errorf("failed to create auth middleware config: %w", authErr) } middlewareConfigs = append(middlewareConfigs, *authConfig) @@ -419,6 +430,35 @@ func addAWSStsMiddleware(middlewares []types.MiddlewareConfig, config *RunConfig return append(middlewares, *awsStsMwConfig), nil } +// addOriginMiddleware adds Origin-header validation middleware for DNS-rebind +// protection per MCP 2025-11-25 §"Security Warning". Default-derivation logic +// lives in origin.ResolveAllowedOrigins so the standalone `thv proxy` command +// and the runner path agree on behavior. +// +// When the effective allowlist is empty — which happens when the operator +// binds to a non-loopback host without supplying --allowed-origins — the +// middleware is skipped entirely and a WARN is logged so the security-disabled +// state is visible in operator logs. A follow-up PR hardens the non-loopback +// path by requiring an explicit opt-in flag (see audit row 22). +func addOriginMiddleware(middlewares []types.MiddlewareConfig, config *RunConfig) ([]types.MiddlewareConfig, error) { + allowed := origin.ResolveAllowedOrigins(config.Host, config.Port, config.AllowedOrigins) + if len(allowed) == 0 { + slog.Warn("Origin validation disabled — no allowlist configured for non-loopback bind", + "host", config.Host, + "port", config.Port, + "hint", "pass --allowed-origins=https://your-client.example to enable DNS-rebind protection", + ) + return middlewares, nil + } + + params := origin.MiddlewareParams{AllowedOrigins: allowed} + mwCfg, err := types.NewMiddlewareConfig(origin.MiddlewareType, params) + if err != nil { + return nil, fmt.Errorf("failed to create origin middleware config: %w", err) + } + return append(middlewares, *mwCfg), nil +} + // addRateLimitMiddleware adds rate limit middleware if configured. func addRateLimitMiddleware(middlewares []types.MiddlewareConfig, config *RunConfig) ([]types.MiddlewareConfig, error) { if config.RateLimitConfig == nil { diff --git a/pkg/transport/middleware/origin/origin.go b/pkg/transport/middleware/origin/origin.go new file mode 100644 index 0000000000..d17fac4086 --- /dev/null +++ b/pkg/transport/middleware/origin/origin.go @@ -0,0 +1,266 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +// Package origin provides HTTP middleware that enforces MCP Origin header +// validation (DNS-rebinding protection) per MCP 2025-11-25 §"Security Warning" +// (https://modelcontextprotocol.io/specification/2025-11-25/basic/transports#security-warning). +// +// When the Origin header is present on an inbound request, it MUST exactly +// match one of the configured allowed origins. Otherwise the middleware +// responds with HTTP 403 and a JSON-RPC error body. Requests without an +// Origin header (typical for non-browser clients) are permitted through. +package origin + +import ( + "encoding/json" + "fmt" + "log/slog" + "maps" + "net" + "net/http" + "slices" + "strings" + + "github.com/stacklok/toolhive/pkg/transport/types" +) + +const ( + // MiddlewareType is the type identifier registered in the middleware factory map. + MiddlewareType = "origin" + + // jsonRPCCodeInvalidRequest is the JSON-RPC 2.0 error code for an invalid + // request. We reuse it for rejected Origin values because the request is + // not well-formed from the server's security policy perspective. + jsonRPCCodeInvalidRequest int64 = -32600 + + // forbiddenBodyFallback is returned if JSON marshalling of the error body + // fails (should never happen with simple map types). + forbiddenBodyFallback = `{"jsonrpc":"2.0","error":{"code":-32600,"message":"Origin not allowed"},"id":null}` +) + +// MiddlewareParams holds the parameters for the origin middleware factory. +type MiddlewareParams struct { + // AllowedOrigins is the exact-match allowlist of acceptable Origin values. + // An empty list disables the middleware (requests pass through unchanged). + AllowedOrigins []string `json:"allowed_origins"` +} + +// FactoryMiddleware wraps origin-validation as a factory-pattern middleware. +type FactoryMiddleware struct { + handler types.MiddlewareFunction +} + +// Handler returns the middleware function used by the proxy. +func (m *FactoryMiddleware) Handler() types.MiddlewareFunction { + return m.handler +} + +// Close releases any resources held by the middleware. +func (*FactoryMiddleware) Close() error { + return nil +} + +// CreateMiddleware is the factory function registered in +// runner.GetSupportedMiddlewareFactories. +// +// If params.AllowedOrigins is empty the factory still registers a pass-through +// handler so the middleware slot is occupied, but logs at Warn level to make +// the security-disabled state visible in operator logs. Callers that want to +// avoid registration entirely should skip calling this factory (see +// pkg/runner.addOriginMiddleware). +func CreateMiddleware(config *types.MiddlewareConfig, runner types.MiddlewareRunner) error { + var params MiddlewareParams + if err := json.Unmarshal(config.Parameters, ¶ms); err != nil { + return fmt.Errorf("failed to unmarshal origin middleware parameters: %w", err) + } + + if len(params.AllowedOrigins) == 0 { + slog.Warn("origin middleware registered with empty allowlist; Origin validation disabled") + } + + handler := createOriginHandler(params.AllowedOrigins) + runner.AddMiddleware(MiddlewareType, &FactoryMiddleware{handler: handler}) + return nil +} + +// CreateOriginMiddleware returns a middleware function that enforces Origin +// header validation against the provided allowlist. Intended for callers that +// build their middleware chain directly (e.g. `thv proxy`) and do not go +// through the factory registry. +// +// What this solves: DNS-rebinding protection per MCP 2025-11-25 §"Security +// Warning" — requests whose Origin header is present and not in allowedOrigins +// receive HTTP 403 with a JSON-RPC error body. +// +// What this does NOT solve: CORS, CSRF token validation, authentication, or +// Origin-header injection via trusted reverse proxies (the caller's reverse +// proxy must deduplicate Origin headers upstream). +// +// An empty allowedOrigins slice produces a pass-through handler — the caller +// is responsible for deciding whether that is acceptable (e.g. when bind is +// loopback-only and the caller derived an allowlist via ResolveAllowedOrigins). +// +// Matching rules: exact match on byte representation except that the scheme +// and host portions of the Origin value are lowercased (RFC 6454 §4: scheme +// and host are ASCII-case-insensitive). Configured allowlist entries are +// lowercased once at construction time. +func CreateOriginMiddleware(allowedOrigins []string) types.MiddlewareFunction { + return createOriginHandler(allowedOrigins) +} + +// createOriginHandler builds the actual middleware function. An empty +// allowlist short-circuits to a no-op so that callers can safely pass a +// possibly-empty slice. +func createOriginHandler(allowedOrigins []string) types.MiddlewareFunction { + if len(allowedOrigins) == 0 { + return func(next http.Handler) http.Handler { return next } + } + + // Build a set for O(1) lookups. Entries are canonicalized so that + // case-variant Origin values (RFC 6454 §4 makes scheme + host case- + // insensitive) match predictably. Preserve the sorted list for logging. + allowedSet := make(map[string]struct{}, len(allowedOrigins)) + for _, o := range allowedOrigins { + allowedSet[canonicalizeOrigin(o)] = struct{}{} + } + slog.Debug("origin middleware configured", + "allowed_origin_count", len(allowedSet), + "allowed_origins", slices.Sorted(maps.Keys(allowedSet)), + ) + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Reject requests with multiple Origin headers outright — the + // Fetch spec defines Origin as a single-value header and browsers + // never legitimately send more than one. Splitting / merging at an + // upstream proxy is the only way this fires. + if values := r.Header.Values("Origin"); len(values) > 1 { + slog.Warn("rejecting request with multiple Origin headers", + "count", len(values), + "method", r.Method, + "path", r.URL.Path, + "remote", r.RemoteAddr, + ) + writeForbidden(w) + return + } + + origin := r.Header.Get("Origin") + if origin == "" { + // MCP spec §"Security Warning" only mandates validation when + // the header is present. Non-browser clients (stdio bridges, + // SDK clients) typically omit Origin entirely. + next.ServeHTTP(w, r) + return + } + if _, ok := allowedSet[canonicalizeOrigin(origin)]; !ok { + slog.Warn("rejecting request with disallowed Origin", + "origin", origin, + "method", r.Method, + "path", r.URL.Path, + "remote", r.RemoteAddr, + ) + writeForbidden(w) + return + } + next.ServeHTTP(w, r) + }) + } +} + +// canonicalizeOrigin lowercases the scheme and host portions of an Origin +// value while preserving the port verbatim. RFC 6454 §4 makes the scheme and +// host ASCII-case-insensitive; the port is a decimal integer and has no case. +// Malformed inputs (no "://" separator) are returned lowercased in full on the +// assumption that they will simply not match any legitimate allowlist entry. +func canonicalizeOrigin(raw string) string { + if raw == "" { + return raw + } + schemeEnd := strings.Index(raw, "://") + if schemeEnd < 0 { + return strings.ToLower(raw) + } + scheme := strings.ToLower(raw[:schemeEnd]) + rest := raw[schemeEnd+3:] + // rest is "host[:port]"; port starts at the LAST ":" to correctly handle + // IPv6 literals that the spec requires wrapped in brackets (e.g. "[::1]:8080"). + if portIdx := strings.LastIndex(rest, ":"); portIdx > 0 && !strings.Contains(rest[portIdx+1:], "]") { + host := strings.ToLower(rest[:portIdx]) + return scheme + "://" + host + rest[portIdx:] + } + return scheme + "://" + strings.ToLower(rest) +} + +// ResolveAllowedOrigins picks the effective Origin allowlist for a proxy +// listener. Resolution order: +// 1. If explicit is non-empty, use it verbatim. +// 2. Otherwise, if host is a loopback IP or the string "localhost", and port +// is valid, return loopback-only defaults +// (http://localhost:PORT, http://127.0.0.1:PORT, http://[::1]:PORT). +// 3. Otherwise, return nil — operators exposing the proxy publicly must +// configure an explicit allowlist. +// +// Shared by the runner middleware-config helper (pkg/runner) and the +// standalone `thv proxy` command to keep the default-derivation logic in one +// place; exported because the `thv proxy` call site is outside the runner +// package and cannot reach an internal helper. +// +// What this does NOT solve: it does not validate that `explicit` entries are +// well-formed Origin values. Callers that pass operator-supplied slices must +// rely on the middleware's canonical matching to either accept or reject +// malformed entries at request time (they will simply fail to match). +func ResolveAllowedOrigins(host string, port int, explicit []string) []string { + if len(explicit) > 0 { + return explicit + } + if port <= 0 { + return nil + } + if !isLoopbackHost(host) { + return nil + } + return []string{ + fmt.Sprintf("http://localhost:%d", port), + fmt.Sprintf("http://127.0.0.1:%d", port), + fmt.Sprintf("http://[::1]:%d", port), + } +} + +// isLoopbackHost reports whether host refers to a loopback address. Accepts +// the literal string "localhost" plus any IP literal that net.ParseIP +// classifies as loopback (e.g. 127.0.0.0/8, ::1). IPv6 is currently rejected +// by cmd/thv/app/run.go:ValidateAndNormaliseHostFlag; this helper nevertheless +// handles it so future IPv6 support does not silently lose default Origin +// protection. +func isLoopbackHost(host string) bool { + if host == "localhost" { + return true + } + // Strip bracket form for IPv6 literals: "[::1]" → "::1". + trimmed := strings.TrimSuffix(strings.TrimPrefix(host, "["), "]") + if ip := net.ParseIP(trimmed); ip != nil { + return ip.IsLoopback() + } + return false +} + +// writeForbidden emits a 403 response with a JSON-RPC error body (id: null). +func writeForbidden(w http.ResponseWriter) { + body, err := json.Marshal(map[string]any{ + "jsonrpc": "2.0", + "error": map[string]any{ + "code": jsonRPCCodeInvalidRequest, + "message": "Origin not allowed", + }, + "id": nil, + }) + if err != nil { + // Marshal of a static map should never fail; fall back to a literal. + body = []byte(forbiddenBodyFallback) + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusForbidden) + //nolint:gosec // G104: writing a static JSON error response to an HTTP client + _, _ = w.Write(body) +} diff --git a/pkg/transport/middleware/origin/origin_test.go b/pkg/transport/middleware/origin/origin_test.go new file mode 100644 index 0000000000..3931053490 --- /dev/null +++ b/pkg/transport/middleware/origin/origin_test.go @@ -0,0 +1,352 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package origin + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + "github.com/stacklok/toolhive/pkg/transport/types" + typesmocks "github.com/stacklok/toolhive/pkg/transport/types/mocks" +) + +// runMiddleware applies the middleware to a stub handler, issues a request +// with the given Origin header (skipped when empty), and returns the response. +func runMiddleware( + t *testing.T, + allowedOrigins []string, + origin string, +) (*httptest.ResponseRecorder, bool) { + t.Helper() + var nextCalled bool + mw := createOriginHandler(allowedOrigins) + handler := mw(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + nextCalled = true + w.WriteHeader(http.StatusOK) + })) + req := httptest.NewRequest(http.MethodPost, "/mcp", nil) + if origin != "" { + req.Header.Set("Origin", origin) + } + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + return rec, nextCalled +} + +func TestOriginMiddleware_RequestPermitted(t *testing.T) { + t.Parallel() + tests := []struct { + name string + allowedOrigins []string + origin string + }{ + { + name: "empty allowlist disables middleware", + allowedOrigins: nil, + origin: "http://evil.example", + }, + { + name: "missing Origin header passes", + allowedOrigins: []string{"http://localhost:8080"}, + origin: "", + }, + { + name: "exact match passes", + allowedOrigins: []string{"http://localhost:8080"}, + origin: "http://localhost:8080", + }, + { + name: "match against second entry", + allowedOrigins: []string{"http://localhost:8080", "http://127.0.0.1:8080"}, + origin: "http://127.0.0.1:8080", + }, + { + name: "case-insensitive scheme match (RFC 6454)", + allowedOrigins: []string{"http://app.example.com"}, + origin: "HTTP://app.example.com", + }, + { + name: "case-insensitive host match (RFC 6454)", + allowedOrigins: []string{"https://App.Example.com"}, + origin: "https://app.example.com", + }, + { + name: "mixed-case allowlist entry matches lowercase Origin", + allowedOrigins: []string{"HTTPS://App.Example.com:443"}, + origin: "https://app.example.com:443", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + rec, nextCalled := runMiddleware(t, tc.allowedOrigins, tc.origin) + assert.True(t, nextCalled, "next handler must be invoked") + assert.Equal(t, http.StatusOK, rec.Code) + }) + } +} + +func TestOriginMiddleware_RequestRejected(t *testing.T) { + t.Parallel() + tests := []struct { + name string + allowedOrigins []string + origin string + }{ + { + name: "different host rejected", + allowedOrigins: []string{"http://localhost:8080"}, + origin: "http://evil.example", + }, + { + name: "different port rejected (exact match required)", + allowedOrigins: []string{"http://localhost:8080"}, + origin: "http://localhost:9090", + }, + { + name: "different scheme rejected", + allowedOrigins: []string{"https://app.example.com"}, + origin: "http://app.example.com", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + rec, nextCalled := runMiddleware(t, tc.allowedOrigins, tc.origin) + assertForbiddenJSONRPC(t, rec, nextCalled) + }) + } +} + +func TestOriginMiddleware_MultipleOriginHeadersRejected(t *testing.T) { + t.Parallel() + + var nextCalled bool + mw := createOriginHandler([]string{"http://localhost:8080"}) + handler := mw(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + nextCalled = true + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodPost, "/mcp", nil) + req.Header.Add("Origin", "http://localhost:8080") + req.Header.Add("Origin", "http://evil.example") + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assertForbiddenJSONRPC(t, rec, nextCalled) +} + +// assertForbiddenJSONRPC validates that rec carries a 403 with a canonical +// JSON-RPC error body and that the inner handler was never invoked. +func assertForbiddenJSONRPC(t *testing.T, rec *httptest.ResponseRecorder, nextCalled bool) { + t.Helper() + assert.False(t, nextCalled, "next handler must NOT be invoked") + assert.Equal(t, http.StatusForbidden, rec.Code) + assert.Equal(t, "application/json", rec.Header().Get("Content-Type")) + + body, err := io.ReadAll(rec.Body) + require.NoError(t, err) + var parsed struct { + JSONRPC string `json:"jsonrpc"` + Error struct { + Code int64 `json:"code"` + Message string `json:"message"` + } `json:"error"` + ID any `json:"id"` + } + require.NoError(t, json.Unmarshal(body, &parsed)) + assert.Equal(t, "2.0", parsed.JSONRPC) + assert.Equal(t, jsonRPCCodeInvalidRequest, parsed.Error.Code) + assert.Equal(t, "Origin not allowed", parsed.Error.Message) + assert.Nil(t, parsed.ID) +} + +func TestCreateOriginMiddleware_PublicAPI(t *testing.T) { + t.Parallel() + mw := CreateOriginMiddleware([]string{"http://localhost:8080"}) + require.NotNil(t, mw) + + // Sanity-check it behaves the same as the internal constructor. + handler := mw(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodPost, "/mcp", nil) + req.Header.Set("Origin", "http://evil.example") + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusForbidden, rec.Code) +} + +func TestCreateMiddleware_Factory(t *testing.T) { + t.Parallel() + + t.Run("valid parameters register middleware", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + runner := typesmocks.NewMockMiddlewareRunner(ctrl) + + params := MiddlewareParams{AllowedOrigins: []string{"http://localhost:8080"}} + cfg, err := types.NewMiddlewareConfig(MiddlewareType, params) + require.NoError(t, err) + + runner.EXPECT(). + AddMiddleware(MiddlewareType, gomock.AssignableToTypeOf(&FactoryMiddleware{})). + Times(1) + + require.NoError(t, CreateMiddleware(cfg, runner)) + }) + + t.Run("invalid JSON returns error", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + runner := typesmocks.NewMockMiddlewareRunner(ctrl) + + cfg := &types.MiddlewareConfig{ + Type: MiddlewareType, + Parameters: json.RawMessage(`{not json}`), + } + + err := CreateMiddleware(cfg, runner) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to unmarshal origin middleware parameters") + }) + + t.Run("empty allowlist still registers pass-through", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + runner := typesmocks.NewMockMiddlewareRunner(ctrl) + + cfg, err := types.NewMiddlewareConfig(MiddlewareType, MiddlewareParams{AllowedOrigins: nil}) + require.NoError(t, err) + + runner.EXPECT(). + AddMiddleware(MiddlewareType, gomock.AssignableToTypeOf(&FactoryMiddleware{})). + Times(1) + + require.NoError(t, CreateMiddleware(cfg, runner)) + }) +} + +func TestFactoryMiddleware_Lifecycle(t *testing.T) { + t.Parallel() + + mw := &FactoryMiddleware{handler: createOriginHandler([]string{"http://localhost:8080"})} + require.NotNil(t, mw.Handler()) + require.NoError(t, mw.Close()) +} + +func TestResolveAllowedOrigins(t *testing.T) { + t.Parallel() + tests := []struct { + name string + host string + port int + explicit []string + want []string + }{ + { + name: "explicit list wins over loopback derivation", + host: "127.0.0.1", + port: 8080, + explicit: []string{"https://app.example.com"}, + want: []string{"https://app.example.com"}, + }, + { + name: "loopback IPv4 auto-derives localhost defaults", + host: "127.0.0.1", + port: 8080, + want: []string{ + "http://localhost:8080", + "http://127.0.0.1:8080", + "http://[::1]:8080", + }, + }, + { + name: "non-standard loopback IPv4 auto-derives defaults", + host: "127.0.0.2", + port: 8080, + want: []string{ + "http://localhost:8080", + "http://127.0.0.1:8080", + "http://[::1]:8080", + }, + }, + { + name: "localhost string auto-derives defaults", + host: "localhost", + port: 8080, + want: []string{ + "http://localhost:8080", + "http://127.0.0.1:8080", + "http://[::1]:8080", + }, + }, + { + name: "IPv6 loopback ::1 auto-derives defaults", + host: "::1", + port: 9090, + want: []string{ + "http://localhost:9090", + "http://127.0.0.1:9090", + "http://[::1]:9090", + }, + }, + { + name: "IPv6 loopback in bracket form auto-derives defaults", + host: "[::1]", + port: 9090, + want: []string{ + "http://localhost:9090", + "http://127.0.0.1:9090", + "http://[::1]:9090", + }, + }, + { + name: "non-loopback host with empty explicit returns nil", + host: "0.0.0.0", + port: 8080, + want: nil, + }, + { + name: "public host with empty explicit returns nil", + host: "192.168.1.10", + port: 8080, + want: nil, + }, + { + name: "garbage host returns nil", + host: "not-a-host", + port: 8080, + want: nil, + }, + { + name: "zero port disables derivation", + host: "127.0.0.1", + port: 0, + want: nil, + }, + { + name: "negative port disables derivation", + host: "127.0.0.1", + port: -1, + want: nil, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + got := ResolveAllowedOrigins(tc.host, tc.port, tc.explicit) + assert.Equal(t, tc.want, got) + }) + } +} diff --git a/pkg/transport/proxy/httpsse/http_proxy.go b/pkg/transport/proxy/httpsse/http_proxy.go index 9890317fcc..1821a36ade 100644 --- a/pkg/transport/proxy/httpsse/http_proxy.go +++ b/pkg/transport/proxy/httpsse/http_proxy.go @@ -355,7 +355,10 @@ func (p *HTTPSSEProxy) handleSSEConnection(w http.ResponseWriter, r *http.Reques w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") - w.Header().Set("Access-Control-Allow-Origin", "*") + // CORS headers deliberately omitted: the origin middleware + // (pkg/transport/middleware/origin) enforces Origin validation per + // MCP 2025-11-25 §"Security Warning". Reflecting Origin or emitting + // `*` here would bypass that protection. // Create a unique client ID clientID := uuid.New().String()