Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions mcp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ func (c *Client) capabilities(protocolVersion string) *ClientCapabilities {
// server, calls or notifications will return an error wrapping
// [ErrConnectionClosed].
func (c *Client) Connect(ctx context.Context, t Transport, opts *ClientSessionOptions) (cs *ClientSession, err error) {
cs, err = connect(ctx, t, c, (*clientSessionState)(nil), nil, c.opts.Logger)
cs, err = connect(ctx, t, c, (*clientSessionState)(nil), nil, nil, c.opts.Logger)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -405,7 +405,7 @@ func (cs *ClientSession) registerElicitationWaiter(elicitationID string) (await

// startKeepalive starts the keepalive mechanism for this client session.
func (cs *ClientSession) startKeepalive(interval time.Duration) {
startKeepalive(cs, interval, &cs.keepaliveCancel, cs.client.opts.Logger)
startKeepalive(cs, interval, &cs.keepaliveCancel, nil, cs.client.opts.Logger)
}

// AddRoots adds the given roots to the client,
Expand Down Expand Up @@ -442,7 +442,7 @@ func changeAndNotify[P Params](c *Client, notification string, params P, change
}
}
c.mu.Unlock()
notifySessions(sessions, notification, params, c.opts.Logger)
notifySessions(sessions, notification, params, c.opts.Logger, nil)
}

// shouldSendListChangedNotification checks if the client's capabilities allow
Expand Down
26 changes: 22 additions & 4 deletions mcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,15 @@ type ServerOptions struct {
// trade-offs and usage guidance.
SchemaCache *SchemaCache

// ErrorHandler, if non-nil, is called with out-of-band errors that occur
// during server operation but are not associated with a specific request.
// Examples include keepalive ping failures, notification delivery errors,
// and internal JSON-RPC protocol errors.
//
// If nil, these errors are logged using [ServerOptions.Logger] at the
// appropriate level.
ErrorHandler func(error)

// GetSessionID provides the next session ID to use for an incoming request.
// If nil, a default randomly generated ID will be used.
//
Expand Down Expand Up @@ -198,6 +207,15 @@ func NewServer(impl *Implementation, options *ServerOptions) *Server {
}
}

// reportError delivers an out-of-band error to the server's configured
// ErrorHandler. When no ErrorHandler is set, the error is instead written to
// the server's Logger at Error level under the "error" key.
func (s *Server) reportError(err error) {
	handler := s.opts.ErrorHandler
	if handler == nil {
		// No custom handler: fall back to structured logging.
		s.opts.Logger.Error("out-of-band error", "error", err)
		return
	}
	handler(err)
}

// AddPrompt adds a [Prompt] to the server, or replaces one with the same name.
func (s *Server) AddPrompt(p *Prompt, h PromptHandler) {
// Assume there was a change, since add replaces existing items.
Expand Down Expand Up @@ -655,7 +673,7 @@ func (s *Server) notifySessions(n string) {
sessions := slices.Clone(s.sessions)
s.pendingNotifications[n] = nil
s.mu.Unlock() // Don't hold the lock during notification: it causes deadlock.
notifySessions(sessions, n, changeNotificationParams[n], s.opts.Logger)
notifySessions(sessions, n, changeNotificationParams[n], s.opts.Logger, s.opts.ErrorHandler)
}

// shouldSendListChangedNotification checks if the server's capabilities allow
Expand Down Expand Up @@ -884,7 +902,7 @@ func (s *Server) ResourceUpdated(ctx context.Context, params *ResourceUpdatedNot
subscribedSessions := s.resourceSubscriptions[params.URI]
sessions := slices.Collect(maps.Keys(subscribedSessions))
s.mu.Unlock()
notifySessions(sessions, notificationResourceUpdated, params, s.opts.Logger)
notifySessions(sessions, notificationResourceUpdated, params, s.opts.Logger, s.opts.ErrorHandler)
s.opts.Logger.Info("resource updated notification sent", "uri", params.URI, "subscriber_count", len(sessions))
return nil
}
Expand Down Expand Up @@ -1026,7 +1044,7 @@ func (s *Server) Connect(ctx context.Context, t Transport, opts *ServerSessionOp
}

s.opts.Logger.Info("server connecting")
ss, err := connect(ctx, t, s, state, onClose, s.opts.Logger)
ss, err := connect(ctx, t, s, state, onClose, s.opts.ErrorHandler, s.opts.Logger)
if err != nil {
s.opts.Logger.Error("server connect error", "error", err)
return nil, err
Expand Down Expand Up @@ -1531,7 +1549,7 @@ func (ss *ServerSession) Wait() error {

// startKeepalive starts the keepalive mechanism for this server session.
func (ss *ServerSession) startKeepalive(interval time.Duration) {
startKeepalive(ss, interval, &ss.keepaliveCancel, ss.server.opts.Logger)
startKeepalive(ss, interval, &ss.keepaliveCancel, ss.server.opts.ErrorHandler, ss.server.opts.Logger)
}

// pageToken is the internal structure for the opaque pagination cursor.
Expand Down
25 changes: 25 additions & 0 deletions mcp/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"log"
"log/slog"
"slices"
Expand Down Expand Up @@ -960,3 +961,27 @@ func TestServerCapabilitiesOverWire(t *testing.T) {
})
}
}

func TestErrorHandler(t *testing.T) {
t.Run("reportError calls ErrorHandler", func(t *testing.T) {
var got error
s := NewServer(testImpl, &ServerOptions{
ErrorHandler: func(err error) { got = err },
})
s.reportError(errors.New("test error"))
if got == nil || got.Error() != "test error" {
t.Errorf("ErrorHandler got %v, want 'test error'", got)
}
})

t.Run("reportError falls back to logger", func(t *testing.T) {
var buf bytes.Buffer
s := NewServer(testImpl, &ServerOptions{
Logger: slog.New(slog.NewTextHandler(&buf, nil)),
})
s.reportError(errors.New("logged error"))
if !strings.Contains(buf.String(), "logged error") {
t.Errorf("log output = %q, want containing 'logged error'", buf.String())
}
})
}
24 changes: 16 additions & 8 deletions mcp/shared.go
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,8 @@ const (
// notifySessions calls Notify on all the sessions.
// Should be called on a copy of the peer sessions.
// The logger must be non-nil.
func notifySessions[S Session, P Params](sessions []S, method string, params P, logger *slog.Logger) {
// If onError is non-nil, it is called for each notification error instead of logging.
func notifySessions[S Session, P Params](sessions []S, method string, params P, logger *slog.Logger, onError func(error)) {
if sessions == nil {
return
}
Expand All @@ -406,7 +407,11 @@ func notifySessions[S Session, P Params](sessions []S, method string, params P,
for _, s := range sessions {
req := newRequest(s, params)
if err := handleNotify(ctx, method, req); err != nil {
logger.Warn(fmt.Sprintf("calling %s: %v", method, err))
if onError != nil {
onError(fmt.Errorf("calling %s: %w", method, err))
} else {
logger.Warn(fmt.Sprintf("calling %s: %v", method, err))
}
}
}
}
Expand Down Expand Up @@ -583,9 +588,10 @@ type keepaliveSession interface {
// It assigns the cancel function to the provided cancelPtr and starts a goroutine
// that sends ping messages at the specified interval.
//
// logger must be non-nil; ping failures (which terminate the keepalive loop and
// close the session) are reported via logger so they are not silently dropped.
func startKeepalive(session keepaliveSession, interval time.Duration, cancelPtr *context.CancelFunc, logger *slog.Logger) {
// If onError is non-nil, it is called with the wrapped ping error when a ping
// fails, before the session is closed. Otherwise the failure is reported via
// logger (which must be non-nil) so it is not silently dropped. See #218.
func startKeepalive(session keepaliveSession, interval time.Duration, cancelPtr *context.CancelFunc, onError func(error), logger *slog.Logger) {
ctx, cancel := context.WithCancel(context.Background())
// Assign cancel function before starting goroutine to avoid race condition.
// We cannot return it because the caller may need to cancel during the
Expand All @@ -605,9 +611,11 @@ func startKeepalive(session keepaliveSession, interval time.Duration, cancelPtr
err := session.Ping(pingCtx, nil)
pingCancel()
if err != nil {
// Ping failed; log it before closing the session so the
// failure is observable to operators. See #218.
logger.Error("keepalive ping failed; closing session", "error", err)
if onError != nil {
onError(fmt.Errorf("keepalive ping failed: %w", err))
} else {
logger.Error("keepalive ping failed; closing session", "error", err)
}
_ = session.Close()
return
}
Expand Down
16 changes: 10 additions & 6 deletions mcp/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,11 @@ type handler interface {
handle(ctx context.Context, req *jsonrpc.Request) (any, error)
}

// connect wires a transport to a binder. logger must be non-nil; it receives
// jsonrpc2 internal errors that would otherwise be dropped (see #218).
func connect[H handler, State any](ctx context.Context, t Transport, b binder[H, State], s State, onClose func(), logger *slog.Logger) (H, error) {
// connect wires a transport to a binder.
//
// If onError is non-nil, it receives jsonrpc2 internal errors; otherwise they
// are reported via logger (which must be non-nil). See #218.
func connect[H handler, State any](ctx context.Context, t Transport, b binder[H, State], s State, onClose func(), onError func(error), logger *slog.Logger) (H, error) {
var zero H
mcpConn, err := t.Connect(ctx)
if err != nil {
Expand All @@ -171,6 +173,10 @@ func connect[H handler, State any](ctx context.Context, t Transport, b binder[H,
preempter.conn = conn
return jsonrpc2.HandlerFunc(h.handle)
}
onInternalError := func(err error) { logger.Error("jsonrpc2 internal error", "error", err) }
if onError != nil {
onInternalError = func(err error) { onError(fmt.Errorf("jsonrpc2: %w", err)) }
}
_ = jsonrpc2.NewConnection(ctx, jsonrpc2.ConnectionConfig{
Reader: reader,
Writer: writer,
Expand All @@ -180,9 +186,7 @@ func connect[H handler, State any](ctx context.Context, t Transport, b binder[H,
OnDone: func() {
b.disconnect(h)
},
OnInternalError: func(err error) {
logger.Error("jsonrpc2 internal error", "error", err)
},
OnInternalError: onInternalError,
})
assert(preempter.conn != nil, "unbound preempter")
return h, nil
Expand Down
Loading