Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions docs/mcpgodebug.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ Options listed below will be removed in the 1.6.0 version of the SDK.

- `disablecrossoriginprotection` added. If set to `1`, newly added cross-origin
protection will be disabled. The default behavior was changed to enable
cross-origin protection.
cross-origin protection. **Removal of this option was postponed until 1.7.0.**

### 1.4.0

Expand All @@ -37,5 +37,6 @@ Options listed below will be removed in the 1.6.0 version of the SDK.
- `disablelocalhostprotection` added. If set to `1`, newly added DNS rebinding
protection will be disabled. The default behavior was changed to enable DNS rebinding
protection. The protection can also be disabled by setting the
`DisableLocalhostProtection` field in the `StreamableHTTPOptions` struct to
`true`, which is the recommended way to disable the protection long term.
`DisableLocalhostProtection` field in the `StreamableHTTPOptions` or
`SSEOptions` struct to `true`, which is the recommended way to disable
the protection long term. **Removal of this option was postponed until 1.7.0.**
7 changes: 4 additions & 3 deletions internal/docs/mcpgodebug.src.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ Options listed below will be removed in the 1.6.0 version of the SDK.

- `disablecrossoriginprotection` added. If set to `1`, newly added cross-origin
protection will be disabled. The default behavior was changed to enable
cross-origin protection.
cross-origin protection. **Removal of this option was postponed until 1.7.0.**

### 1.4.0

Expand All @@ -36,5 +36,6 @@ Options listed below will be removed in the 1.6.0 version of the SDK.
- `disablelocalhostprotection` added. If set to `1`, newly added DNS rebinding
protection will be disabled. The default behavior was changed to enable DNS rebinding
protection. The protection can also be disabled by setting the
`DisableLocalhostProtection` field in the `StreamableHTTPOptions` struct to
`true`, which is the recommended way to disable the protection long term.
`DisableLocalhostProtection` field in the `StreamableHTTPOptions` or
`SSEOptions` struct to `true`, which is the recommended way to disable
the protection long term. **Removal of this option was postponed until 1.7.0.**
57 changes: 52 additions & 5 deletions mcp/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@ import (
"crypto/rand"
"fmt"
"io"
"net"
"net/http"
"net/url"
"sync"

"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
"github.com/modelcontextprotocol/go-sdk/internal/util"
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
)

Expand Down Expand Up @@ -52,9 +54,25 @@ type SSEHandler struct {
}

// SSEOptions specifies options for an [SSEHandler].
// for now, it is empty, but may be extended in future.
// https://github.com/modelcontextprotocol/go-sdk/issues/507
type SSEOptions struct{}
type SSEOptions struct {
// DisableLocalhostProtection disables automatic DNS rebinding protection.
// By default, requests arriving via a localhost address (127.0.0.1, [::1])
// that have a non-localhost Host header are rejected with 403 Forbidden.
// This protects against DNS rebinding attacks regardless of whether the
// server is listening on localhost specifically or on 0.0.0.0.
//
// Only disable this if you understand the security implications.
// See: https://modelcontextprotocol.io/specification/2025-11-25/basic/security_best_practices#local-mcp-server-compromise
DisableLocalhostProtection bool

// CrossOriginProtection allows to customize cross-origin protection.
// The deny handler set in the CrossOriginProtection through SetDenyHandler
// is ignored.
// If nil, default (zero-value) cross-origin protection will be used.
// Use `disablecrossoriginprotection` MCPGODEBUG compatibility parameter
// to disable the default protection until v1.7.0.
CrossOriginProtection *http.CrossOriginProtection
}

// NewSSEHandler returns a new [SSEHandler] that creates and manages MCP
// sessions created via incoming HTTP requests.
Expand All @@ -79,6 +97,10 @@ func NewSSEHandler(getServer func(request *http.Request) *Server, opts *SSEOptio
s.opts = *opts
}

if s.opts.CrossOriginProtection == nil {
s.opts.CrossOriginProtection = &http.CrossOriginProtection{}
}

return s
}

Expand Down Expand Up @@ -179,9 +201,34 @@ func (t *SSEServerTransport) Connect(context.Context) (Connection, error) {
}

func (h *SSEHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
sessionID := req.URL.Query().Get("sessionid")
// DNS rebinding protection: auto-enabled for localhost servers.
// See: https://modelcontextprotocol.io/specification/2025-11-25/basic/security_best_practices#local-mcp-server-compromise
if !h.opts.DisableLocalhostProtection && disablelocalhostprotection != "1" {
if localAddr, ok := req.Context().Value(http.LocalAddrContextKey).(net.Addr); ok && localAddr != nil {
if util.IsLoopback(localAddr.String()) && !util.IsLoopback(req.Host) {
http.Error(w, fmt.Sprintf("Forbidden: invalid Host header %q", req.Host), http.StatusForbidden)
return
}
}
}

// TODO: consider checking Content-Type here. For now, we are lax.
if disablecrossoriginprotection != "1" {
// Verify the 'Origin' header to protect against CSRF attacks.
if err := h.opts.CrossOriginProtection.Check(req); err != nil {
http.Error(w, err.Error(), http.StatusForbidden)
return
}
// Validate 'Content-Type' header.
if req.Method == http.MethodPost {
contentType := req.Header.Get("Content-Type")
if contentType != "application/json" {
http.Error(w, "Content-Type must be 'application/json'", http.StatusUnsupportedMediaType)
return
}
}
}

sessionID := req.URL.Query().Get("sessionid")

// For POST requests, the message body is a message to send to a session.
if req.Method == http.MethodPost {
Expand Down
173 changes: 173 additions & 0 deletions mcp/sse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@ import (
"context"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"strings"
"sync/atomic"
"testing"

Expand Down Expand Up @@ -221,3 +223,174 @@ func TestSSE405AllowHeader(t *testing.T) {
})
}
}

// TestSSELocalhostProtection verifies that DNS rebinding protection
// is automatically enabled for localhost servers.
func TestSSELocalhostProtection(t *testing.T) {
server := NewServer(testImpl, nil)

tests := []struct {
name string
listenAddr string
hostHeader string
disableProtection bool
wantStatus int
}{
{
name: "127.0.0.1 accepts 127.0.0.1",
listenAddr: "127.0.0.1:0",
hostHeader: "127.0.0.1:1234",
wantStatus: http.StatusOK,
},
{
name: "127.0.0.1 accepts localhost",
listenAddr: "127.0.0.1:0",
hostHeader: "localhost:1234",
wantStatus: http.StatusOK,
},
{
name: "127.0.0.1 rejects evil.com",
listenAddr: "127.0.0.1:0",
hostHeader: "evil.com",
wantStatus: http.StatusForbidden,
},
{
name: "127.0.0.1 rejects evil.com:80",
listenAddr: "127.0.0.1:0",
hostHeader: "evil.com:80",
wantStatus: http.StatusForbidden,
},
{
name: "127.0.0.1 rejects localhost.evil.com",
listenAddr: "127.0.0.1:0",
hostHeader: "localhost.evil.com",
wantStatus: http.StatusForbidden,
},
{
name: "0.0.0.0 via localhost rejects evil.com",
listenAddr: "0.0.0.0:0",
hostHeader: "evil.com",
wantStatus: http.StatusForbidden,
},
{
name: "disabled accepts evil.com",
listenAddr: "127.0.0.1:0",
hostHeader: "evil.com",
disableProtection: true,
wantStatus: http.StatusOK,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
opts := &SSEOptions{
DisableLocalhostProtection: tt.disableProtection,
}
handler := NewSSEHandler(func(req *http.Request) *Server { return server }, opts)

listener, err := net.Listen("tcp", tt.listenAddr)
if err != nil {
t.Fatalf("Failed to listen on %s: %v", tt.listenAddr, err)
}
defer listener.Close()

srv := &http.Server{Handler: handler}
go srv.Serve(listener)
defer srv.Close()

// Use a GET request since it's the entry point for SSE sessions.
// For accepted requests, the response will be a hanging SSE stream,
// but we only need to check the initial status code.
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s", listener.Addr().String()), nil)
if err != nil {
t.Fatal(err)
}
req.Host = tt.hostHeader
req.Header.Set("Accept", "text/event-stream")

resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()

if got := resp.StatusCode; got != tt.wantStatus {
t.Errorf("Status code: got %d, want %d", got, tt.wantStatus)
}
})
}
}

func TestSSEOriginProtection(t *testing.T) {
server := NewServer(testImpl, nil)

tests := []struct {
name string
protection *http.CrossOriginProtection
requestOrigin string
wantStatusCode int
}{
{
name: "default protection with Origin header",
protection: nil,
requestOrigin: "https://example.com",
wantStatusCode: http.StatusForbidden,
},
{
name: "custom protection with trusted origin and same Origin",
protection: func() *http.CrossOriginProtection {
p := http.NewCrossOriginProtection()
if err := p.AddTrustedOrigin("https://example.com"); err != nil {
t.Fatal(err)
}
return p
}(),
requestOrigin: "https://example.com",
wantStatusCode: http.StatusNotFound, // origin accepted; session not found
},
{
name: "custom protection with trusted origin and different Origin",
protection: func() *http.CrossOriginProtection {
p := http.NewCrossOriginProtection()
if err := p.AddTrustedOrigin("https://example.com"); err != nil {
t.Fatal(err)
}
return p
}(),
requestOrigin: "https://malicious.com",
wantStatusCode: http.StatusForbidden,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
opts := &SSEOptions{
CrossOriginProtection: tt.protection,
}
handler := NewSSEHandler(func(req *http.Request) *Server { return server }, opts)
httpServer := httptest.NewServer(handler)
defer httpServer.Close()

// Use POST with a valid session-like URL to test origin protection
// without creating a hanging GET connection.
reqReader := strings.NewReader(`{"jsonrpc":"2.0","id":1,"method":"ping"}`)
req, err := http.NewRequest(http.MethodPost, httpServer.URL+"?sessionid=nonexistent", reqReader)
if err != nil {
t.Fatal(err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Origin", tt.requestOrigin)

resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()

if got := resp.StatusCode; got != tt.wantStatusCode {
body, _ := io.ReadAll(resp.Body)
t.Errorf("Status code: got %d, want %d (body: %s)", got, tt.wantStatusCode, body)
}
})
}
}
6 changes: 3 additions & 3 deletions mcp/streamable.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ type StreamableHTTPOptions struct {
// is ignored.
// If nil, default (zero-value) cross-origin protection will be used.
// Use `disablecrossoriginprotection` MCPGODEBUG compatibility parameter
// to disable the default protection until v1.6.0.
// to disable the default protection until v1.7.0.
CrossOriginProtection *http.CrossOriginProtection
}

Expand Down Expand Up @@ -235,14 +235,14 @@ func (h *StreamableHTTPHandler) closeAll() {
// disablelocalhostprotection is a compatibility parameter that allows to disable
// DNS rebinding protection, which was added in the 1.4.0 version of the SDK.
// See the documentation for the mcpgodebug package for instructions how to enable it.
// The option will be removed in the 1.6.0 version of the SDK.
// The option will be removed in the 1.7.0 version of the SDK.
var disablelocalhostprotection = mcpgodebug.Value("disablelocalhostprotection")

// disablecrossoriginprotection is a compatibility parameter that allows to disable
// the verification of the 'Origin' and 'Content-Type' headers, which was added in
// the 1.4.1 version of the SDK. See the documentation for the mcpgodebug package
// for instructions how to enable it.
// The option will be removed in the 1.6.0 version of the SDK.
// The option will be removed in the 1.7.0 version of the SDK.
var disablecrossoriginprotection = mcpgodebug.Value("disablecrossoriginprotection")

func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
Expand Down
Loading