diff --git a/pkg/vmcp/server/datastorage_injection_test.go b/pkg/vmcp/server/datastorage_injection_test.go
new file mode 100644
index 0000000000..f50b57c9d9
--- /dev/null
+++ b/pkg/vmcp/server/datastorage_injection_test.go
@@ -0,0 +1,162 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package server_test
+
+import (
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"go.uber.org/mock/gomock"
+
+	transportsession "github.com/stacklok/toolhive/pkg/transport/session"
+	"github.com/stacklok/toolhive/pkg/vmcp"
+	vmcpconfig "github.com/stacklok/toolhive/pkg/vmcp/config"
+	discoveryMocks "github.com/stacklok/toolhive/pkg/vmcp/discovery/mocks"
+	"github.com/stacklok/toolhive/pkg/vmcp/mocks"
+	routerMocks "github.com/stacklok/toolhive/pkg/vmcp/router/mocks"
+	"github.com/stacklok/toolhive/pkg/vmcp/server"
+)
+
+// countingDataStorage wraps a real LocalSessionDataStorage and counts how
+// many times Close has been invoked. Used to assert that Server.Stop does
+// not close a caller-supplied DataStorage.
+type countingDataStorage struct {
+	transportsession.DataStorage
+	closeCalls atomic.Int32
+}
+
+func (c *countingDataStorage) Close() error {
+	c.closeCalls.Add(1)
+	return c.DataStorage.Close()
+}
+
+func newCountingDataStorage(t *testing.T) *countingDataStorage {
+	t.Helper()
+	inner, err := transportsession.NewLocalSessionDataStorage(5 * time.Minute)
+	require.NoError(t, err)
+	return &countingDataStorage{DataStorage: inner}
+}
+
+func TestNew_CallerOwnedDataStorageNotClosedOnStop(t *testing.T) {
+	t.Parallel()
+
+	spy := newCountingDataStorage(t)
+	// The spy is caller-owned; close the inner LocalSessionDataStorage
+	// directly at the end of the test so the counter is not ticked by
+	// cleanup — the post-Stop assertion below must reflect only the server's
+	// behaviour. Err ignored: closing an already-closed local store is a
+	// no-op in this implementation.
+	t.Cleanup(func() {
+		_ = spy.DataStorage.Close()
+	})
+
+	ctrl := gomock.NewController(t)
+	t.Cleanup(ctrl.Finish)
+
+	mockRouter := routerMocks.NewMockRouter(ctrl)
+	mockBackendClient := mocks.NewMockBackendClient(ctrl)
+	mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl)
+	mockDiscoveryMgr.EXPECT().Stop().Times(1)
+
+	srv, err := server.New(
+		t.Context(),
+		&server.Config{
+			Host:           "127.0.0.1",
+			Port:           0,
+			SessionFactory: newNoopMockFactory(t),
+			DataStorage:    spy,
+		},
+		mockRouter,
+		mockBackendClient,
+		mockDiscoveryMgr,
+		vmcp.NewImmutableRegistry([]vmcp.Backend{}),
+		nil,
+	)
+	require.NoError(t, err)
+
+	err = srv.Stop(t.Context())
+	require.NoError(t, err)
+
+	assert.Equal(t, int32(0), spy.closeCalls.Load(),
+		"server must not close a caller-supplied DataStorage")
+}
+
+func TestNew_BothSessionStorageAndDataStorageErrors(t *testing.T) {
+	t.Parallel()
+
+	spy := newCountingDataStorage(t)
+	// Err ignored: closing an already-closed local store is a no-op.
+	t.Cleanup(func() {
+		_ = spy.DataStorage.Close()
+	})
+
+	ctrl := gomock.NewController(t)
+	t.Cleanup(ctrl.Finish)
+
+	mockRouter := routerMocks.NewMockRouter(ctrl)
+	mockBackendClient := mocks.NewMockBackendClient(ctrl)
+	mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl)
+
+	_, err := server.New(
+		t.Context(),
+		&server.Config{
+			Host:           "127.0.0.1",
+			Port:           0,
+			SessionFactory: newNoopMockFactory(t),
+			SessionStorage: &vmcpconfig.SessionStorageConfig{
+				Provider: "redis",
+				Address:  "127.0.0.1:6379",
+			},
+			DataStorage: spy,
+		},
+		mockRouter,
+		mockBackendClient,
+		mockDiscoveryMgr,
+		vmcp.NewImmutableRegistry([]vmcp.Backend{}),
+		nil,
+	)
+	require.Error(t, err)
+	assert.Contains(t, err.Error(), "DataStorage")
+	assert.Contains(t, err.Error(), "SessionStorage")
+	assert.Equal(t, int32(0), spy.closeCalls.Load(),
+		"server must not close a caller-supplied DataStorage on misconfiguration")
+}
+
+func TestNew_ServerBuiltDataStorageStopSucceeds(t *testing.T) {
+	// Guards against accidental regression of the server-owned close path
+	// when Close moved from an inline Stop() block onto sessionDataStorageCloser.
+	// Stop() must still complete without error when the server built the store.
+	// This is a smoke test — it cannot observe Close on the internal
+	// LocalSessionDataStorage because that type is constructed inside New().
+	t.Parallel()
+
+	ctrl := gomock.NewController(t)
+	t.Cleanup(ctrl.Finish)
+
+	mockRouter := routerMocks.NewMockRouter(ctrl)
+	mockBackendClient := mocks.NewMockBackendClient(ctrl)
+	mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl)
+	mockDiscoveryMgr.EXPECT().Stop().Times(1)
+
+	srv, err := server.New(
+		t.Context(),
+		&server.Config{
+			Host:           "127.0.0.1",
+			Port:           0,
+			SessionFactory: newNoopMockFactory(t),
+			SessionStorage: &vmcpconfig.SessionStorageConfig{Provider: "memory"},
+		},
+		mockRouter,
+		mockBackendClient,
+		mockDiscoveryMgr,
+		vmcp.NewImmutableRegistry([]vmcp.Backend{}),
+		nil,
+	)
+	require.NoError(t, err)
+
+	require.NoError(t, srv.Stop(t.Context()))
+}
diff --git a/pkg/vmcp/server/server.go b/pkg/vmcp/server/server.go
index 23c415a551..4d462755aa 100644
--- a/pkg/vmcp/server/server.go
+++ b/pkg/vmcp/server/server.go
@@ -180,7 +180,38 @@ type Config struct {
 	// When provider is "redis", a Redis-backed store is created for cross-pod
 	// session persistence; the Redis password is read from the
 	// THV_SESSION_REDIS_PASSWORD environment variable.
+	//
+	// Mutually exclusive with DataStorage: setting both is rejected at New().
 	SessionStorage *vmcpconfig.SessionStorageConfig
+
+	// DataStorage optionally injects a caller-supplied session metadata store,
+	// bypassing the built-in memory/redis providers. When non-nil, the server
+	// uses this store as-is and SessionStorage is ignored in its entirety (no
+	// field of SessionStorage is consulted). Setting both DataStorage and a
+	// non-empty SessionStorage.Provider is rejected at New() as ambiguous
+	// configuration.
+	//
+	// Lifecycle: the caller owns it. The server does NOT call Close() on a
+	// caller-supplied DataStorage, even on error paths in New() or during
+	// Stop(). The caller is responsible for invoking Close() exactly once
+	// after Server.Stop() returns (not before — the session manager may issue
+	// final Update calls during Stop). The caller is likewise responsible for
+	// configuring the store's TTL; cfg.SessionTTL applies only to the
+	// transport-level session manager, not to the caller-supplied DataStorage.
+	//
+	// Sensitive material: the store holds HMAC-hashed token material and
+	// other session metadata. Embedders should treat the backing datastore as
+	// sensitive (dedicated credentials, encryption at rest, restricted read
+	// access). Implementations must not include credentials in Close() error
+	// messages — those errors are surfaced through Server.Stop().
+	//
+	// This seam lets embedders satisfy transportsession.DataStorage against
+	// datastores other than the built-in providers (e.g. Postgres, DynamoDB)
+	// without forking the server. It enables cross-replica session metadata
+	// sharing when backed by a shared store; it does NOT solve cross-replica
+	// message delivery — callers still need session affinity at the load
+	// balancer for streaming responses.
+	DataStorage transportsession.DataStorage
 }
 
 // Server is the Virtual MCP Server that aggregates multiple backends.
@@ -223,10 +254,16 @@ type Server struct {
 	sessionManager *transportsession.Manager
 
 	// sessionDataStorage is the pluggable key-value backend for session metadata.
-	// Currently always LocalSessionDataStorage (in-memory, single-process).
-	// Redis-backed storage for multi-pod deployments is not yet wired.
+	// It may be LocalSessionDataStorage (in-memory, single-process), a Redis-backed
+	// store, or a caller-supplied implementation injected via Config.DataStorage.
 	sessionDataStorage transportsession.DataStorage
 
+	// sessionDataStorageCloser closes sessionDataStorage on shutdown. It is
+	// set only when the server built the store itself (memory or redis
+	// providers). When Config.DataStorage was supplied by the caller, this is
+	// nil and the caller is responsible for closing the store.
+	sessionDataStorageCloser func(context.Context) error
+
 	// Capability adapter for converting aggregator types to SDK types
 	capabilityAdapter *adapter.CapabilityAdapter
 
@@ -256,21 +293,51 @@ type Server struct {
 }
 
 // buildSessionDataStorage constructs the DataStorage backend from cfg.
-// When cfg.SessionStorage is nil or provider is "memory" (or empty), local in-process
-// storage is used. When provider is "redis", a Redis-backed store is created
-// using the address, DB, and key prefix from cfg.SessionStorage; the password
-// is read from the THV_SESSION_REDIS_PASSWORD environment variable.
-// Any other provider value is a misconfiguration and returns an error.
-func buildSessionDataStorage(ctx context.Context, cfg *Config) (transportsession.DataStorage, error) {
+//
+// Resolution order:
+//
+//  1. cfg.DataStorage (caller-supplied) takes precedence. When non-nil, the
+//     store is returned as-is with a nil closer — the caller owns the
+//     lifecycle. Setting both cfg.DataStorage and a non-empty
+//     cfg.SessionStorage.Provider is rejected as ambiguous.
+//  2. cfg.SessionStorage.Provider "memory" (or empty, or nil SessionStorage):
+//     local in-process storage is created.
+//  3. cfg.SessionStorage.Provider "redis": a Redis-backed store is created
+//     using the address, DB, and key prefix from cfg.SessionStorage. The
+//     password is read from the THV_SESSION_REDIS_PASSWORD environment
+//     variable.
+//  4. Any other provider value is a misconfiguration and returns an error.
+//
+// For cases 2 and 3 (server-built stores), the returned closer wraps the
+// store's Close method. For case 1 (caller-supplied), the closer is nil.
+// New() routes the returned closer through Server.sessionDataStorageCloser
+// so Close is invoked on shutdown (and on New() error after this point) —
+// but only for server-built stores.
+func buildSessionDataStorage(
+	ctx context.Context,
+	cfg *Config,
+) (transportsession.DataStorage, func(context.Context) error, error) {
+	if cfg.DataStorage != nil {
+		if cfg.SessionStorage != nil && cfg.SessionStorage.Provider != "" {
+			return nil, nil, fmt.Errorf(
+				"cannot set both Config.DataStorage and Config.SessionStorage.Provider (%q); pick one",
+				cfg.SessionStorage.Provider)
+		}
+		return cfg.DataStorage, nil, nil
+	}
 	// Default to in-process storage when session storage is not configured,
 	// or when the provider is explicitly "memory" or left empty.
 	if cfg.SessionStorage == nil ||
 		cfg.SessionStorage.Provider == "" ||
 		strings.EqualFold(cfg.SessionStorage.Provider, "memory") {
-		return transportsession.NewLocalSessionDataStorage(cfg.SessionTTL)
+		store, err := transportsession.NewLocalSessionDataStorage(cfg.SessionTTL)
+		if err != nil {
+			return nil, nil, err
+		}
+		return store, closerFor(store), nil
 	}
 	if cfg.SessionStorage.Provider != "redis" {
-		return nil, fmt.Errorf("unsupported session storage provider %q (supported: \"memory\", \"redis\")",
+		return nil, nil, fmt.Errorf("unsupported session storage provider %q (supported: \"memory\", \"redis\")",
 			cfg.SessionStorage.Provider)
 	}
 	keyPrefix := cfg.SessionStorage.KeyPrefix
@@ -288,7 +355,19 @@ func buildSessionDataStorage(
 		"db", cfg.SessionStorage.DB,
 		"key_prefix", keyPrefix,
 	)
-	return transportsession.NewRedisSessionDataStorage(ctx, redisCfg, cfg.SessionTTL)
+	store, err := transportsession.NewRedisSessionDataStorage(ctx, redisCfg, cfg.SessionTTL)
+	if err != nil {
+		return nil, nil, err
+	}
+	return store, closerFor(store), nil
+}
+
+// closerFor adapts DataStorage.Close (no context) to the
+// func(context.Context) error signature used by Server.sessionDataStorageCloser.
+func closerFor(store transportsession.DataStorage) func(context.Context) error {
+	return func(context.Context) error {
+		return store.Close()
+	}
 }
 
 // New creates a new Virtual MCP Server instance.
@@ -412,16 +491,18 @@ func New(
 	// keyed by the same session ID.
 	sessionManager := transportsession.NewManager(cfg.SessionTTL, transportsession.NewStreamableSession)
 
-	sessionDataStorage, err := buildSessionDataStorage(ctx, cfg)
+	sessionDataStorage, sessionDataStorageCloser, err := buildSessionDataStorage(ctx, cfg)
 	if err != nil {
 		return nil, fmt.Errorf("failed to create session data storage: %w", err)
 	}
-	// Close sessionDataStorage if New() returns an error after this point so the
-	// background cleanup goroutine does not leak.
-	closeStorageOnErr := true
+	// If we built the store ourselves, close it when New() returns an error
+	// after this point so the background cleanup goroutine does not leak.
+	// For a caller-supplied store (sessionDataStorageCloser == nil), the
+	// caller owns the lifecycle and we leave it untouched on every path.
+	closeStorageOnErr := sessionDataStorageCloser != nil
 	defer func() {
 		if closeStorageOnErr {
-			_ = sessionDataStorage.Close()
+			_ = sessionDataStorageCloser(ctx)
 		}
 	}()
 
@@ -486,6 +567,12 @@ func New(
 		srv.shutdownFuncs = append(srv.shutdownFuncs, optimizerCleanup)
 	}
 
+	// Store the session data storage closer on the Server so Stop() can invoke
+	// it last (after session manager and discovery manager have stopped). For
+	// a caller-supplied store this is nil and Stop() leaves it alone — the
+	// caller owns the lifecycle.
+	srv.sessionDataStorageCloser = sessionDataStorageCloser
+
 	// Register OnRegisterSession hook to inject capabilities after SDK registers session.
 	// See handleSessionRegistration for implementation details.
 	hooks.AddOnRegisterSession(func(ctx context.Context, session server.ClientSession) {
@@ -848,8 +935,10 @@ func (s *Server) Stop(ctx context.Context) error {
 
 	// Close session data storage last: HTTP server is down (no new in-flight requests),
 	// all other components have stopped (no further restore or liveness checks).
-	if s.sessionDataStorage != nil {
-		if err := s.sessionDataStorage.Close(); err != nil {
+	// Only invoked when the server built the store itself; caller-supplied stores
+	// (Config.DataStorage) are left for the caller to close.
+	if s.sessionDataStorageCloser != nil {
+		if err := s.sessionDataStorageCloser(ctx); err != nil {
 			errs = append(errs, fmt.Errorf("failed to close session data storage: %w", err))
 		}
 	}