diff --git a/cmd/access/node_builder/access_node_builder.go b/cmd/access/node_builder/access_node_builder.go index e3329d8e9ae..5c3a42fafaa 100644 --- a/cmd/access/node_builder/access_node_builder.go +++ b/cmd/access/node_builder/access_node_builder.go @@ -1763,6 +1763,10 @@ func (builder *FlowAccessNodeBuilder) extraFlags() { return errors.New("execution-data-indexing-enabled must be set if store-tx-result-error-messages is enabled") } + if builder.stateStreamConf.MaxGlobalStreams == 0 { + return errors.New("state-stream-global-max-streams must be greater than 0") + } + return nil }) } diff --git a/cmd/observer/node_builder/observer_builder.go b/cmd/observer/node_builder/observer_builder.go index 069222211f4..62a797ee62a 100644 --- a/cmd/observer/node_builder/observer_builder.go +++ b/cmd/observer/node_builder/observer_builder.go @@ -945,6 +945,10 @@ func (builder *ObserverServiceBuilder) extraFlags() { return errors.New("rest-max-request-size must be greater than 0") } + if builder.stateStreamConf.MaxGlobalStreams == 0 { + return errors.New("state-stream-global-max-streams must be greater than 0") + } + return nil }) } diff --git a/engine/access/rest/router/router.go b/engine/access/rest/router/router.go index 23159704d61..53b9dc01832 100644 --- a/engine/access/rest/router/router.go +++ b/engine/access/rest/router/router.go @@ -108,15 +108,14 @@ func (b *RouterBuilder) AddWebsocketsRoute( maxRequestSize int64, maxResponseSize int64, dataProviderFactory dp.DataProviderFactory, - limiter *limiters.ConcurrencyLimiter, + streamLimiter *limiters.ConcurrencyLimiter, ) *RouterBuilder { - h := websockets.NewWebSocketHandler(ctx, b.logger, config, chain, maxRequestSize, maxResponseSize, dataProviderFactory) - handler := websockets.NewConnectionLimitedHandler(b.logger, h.HttpHandler, h, limiter) + h := websockets.NewWebSocketHandler(ctx, b.logger, config, chain, maxRequestSize, maxResponseSize, dataProviderFactory, streamLimiter) b.v1SubRouter. Methods(http.MethodGet). 
Path("/ws"). Name("ws"). - Handler(handler) + Handler(h) return b } diff --git a/engine/access/rest/server.go b/engine/access/rest/server.go index 797111b63fc..2dd3bf3520b 100644 --- a/engine/access/rest/server.go +++ b/engine/access/rest/server.go @@ -1,6 +1,7 @@ package rest import ( + "errors" "net/http" "time" @@ -55,6 +56,10 @@ func NewServer( extendedBackend extended.API, limiter *limiters.ConcurrencyLimiter, ) (*http.Server, error) { + if limiter == nil && (stateStreamApi != nil || enableNewWebsocketsStreamAPI) { + return nil, errors.New("stream limiter is required when websocket routes are enabled") + } + builder := router.NewRouterBuilder(logger, restCollector).AddRestRoutes(serverAPI, chain, config.MaxRequestSize, config.MaxResponseSize) if stateStreamApi != nil { builder.AddLegacyWebsocketsRoutes(stateStreamApi, chain, stateStreamConfig, config.MaxRequestSize, config.MaxResponseSize, limiter) diff --git a/engine/access/rest/websockets/controller.go b/engine/access/rest/websockets/controller.go index bf201afe95f..02401f77056 100644 --- a/engine/access/rest/websockets/controller.go +++ b/engine/access/rest/websockets/controller.go @@ -89,6 +89,7 @@ import ( dp "github.com/onflow/flow-go/engine/access/rest/websockets/data_providers" "github.com/onflow/flow-go/engine/access/rest/websockets/models" + "github.com/onflow/flow-go/module/limiters" "github.com/onflow/flow-go/utils/concurrentmap" ) @@ -135,7 +136,8 @@ type Controller struct { dataProviders *concurrentmap.Map[SubscriptionID, dp.DataProvider] dataProviderFactory dp.DataProviderFactory dataProvidersGroup *sync.WaitGroup - limiter *rate.Limiter + rateLimiter *rate.Limiter + streamLimiter *limiters.ConcurrencyLimiter keepaliveConfig KeepaliveConfig } @@ -145,10 +147,15 @@ func NewWebSocketController( config Config, conn WebsocketConnection, dataProviderFactory dp.DataProviderFactory, -) *Controller { - var limiter *rate.Limiter + streamLimiter *limiters.ConcurrencyLimiter, +) (*Controller, error) { + if 
streamLimiter == nil { + return nil, errors.New("stream limiter is required") + } + + var rateLimiter *rate.Limiter if config.MaxResponsesPerSecond > 0 { - limiter = rate.NewLimiter(rate.Limit(config.MaxResponsesPerSecond), 1) + rateLimiter = rate.NewLimiter(rate.Limit(config.MaxResponsesPerSecond), 1) } return &Controller{ @@ -159,9 +166,10 @@ func NewWebSocketController( dataProviders: concurrentmap.New[SubscriptionID, dp.DataProvider](), dataProviderFactory: dataProviderFactory, dataProvidersGroup: &sync.WaitGroup{}, - limiter: limiter, + rateLimiter: rateLimiter, + streamLimiter: streamLimiter, keepaliveConfig: DefaultKeepaliveConfig(), - } + }, nil } // HandleConnection manages the lifecycle of a WebSocket connection, @@ -442,8 +450,20 @@ func (c *Controller) handleSubscribe(ctx context.Context, msg models.SubscribeMe return } + // Check if the global stream limit has been reached. + if !c.streamLimiter.Acquire() { + err := fmt.Errorf("error creating new subscription: maximum number of streams reached") + c.writeErrorResponse( + ctx, + err, + wrapErrorMessage(http.StatusTooManyRequests, err.Error(), models.SubscribeAction, msg.SubscriptionID), + ) + return + } + subscriptionID, err := c.parseOrCreateSubscriptionID(msg.SubscriptionID) if err != nil { + c.streamLimiter.Release() err = fmt.Errorf("error parsing subscription id: %w", err) c.writeErrorResponse( ctx, @@ -456,6 +476,7 @@ func (c *Controller) handleSubscribe(ctx context.Context, msg models.SubscribeMe // register new provider provider, err := c.dataProviderFactory.NewDataProvider(ctx, subscriptionID.String(), msg.Topic, msg.Arguments, c.multiplexedStream) if err != nil { + c.streamLimiter.Release() err = fmt.Errorf("error creating data provider: %w", err) c.writeErrorResponse( ctx, @@ -478,6 +499,8 @@ func (c *Controller) handleSubscribe(ctx context.Context, msg models.SubscribeMe // run provider c.dataProvidersGroup.Add(1) go func() { + defer c.streamLimiter.Release() + err = provider.Run() if err != 
nil { err = fmt.Errorf("internal error: %w", err) @@ -604,9 +627,9 @@ func (c *Controller) parseOrCreateSubscriptionID(id string) (SubscriptionID, err // An error is returned if the context is canceled or the expected wait time exceeds the context's // deadline. func (c *Controller) checkRateLimit(ctx context.Context) error { - if c.limiter == nil { + if c.rateLimiter == nil { return nil } - return c.limiter.WaitN(ctx, 1) + return c.rateLimiter.WaitN(ctx, 1) } diff --git a/engine/access/rest/websockets/controller_test.go b/engine/access/rest/websockets/controller_test.go index 44df8a0746b..14e90613613 100644 --- a/engine/access/rest/websockets/controller_test.go +++ b/engine/access/rest/websockets/controller_test.go @@ -23,6 +23,7 @@ import ( connmock "github.com/onflow/flow-go/engine/access/rest/websockets/mock" "github.com/onflow/flow-go/engine/access/rest/websockets/models" "github.com/onflow/flow-go/model/flow" + "github.com/onflow/flow-go/module/limiters" "github.com/onflow/flow-go/utils/unittest" ) @@ -30,8 +31,9 @@ import ( type WsControllerSuite struct { suite.Suite - logger zerolog.Logger - wsConfig Config + logger zerolog.Logger + wsConfig Config + streamLimiter *limiters.ConcurrencyLimiter } func TestControllerSuite(t *testing.T) { @@ -42,6 +44,10 @@ func TestControllerSuite(t *testing.T) { func (s *WsControllerSuite) SetupTest() { s.logger = unittest.Logger() s.wsConfig = NewDefaultWebsocketConfig() + + var err error + s.streamLimiter, err = limiters.NewConcurrencyLimiter(1000) + s.Require().NoError(err) } // TestSubscribeRequest tests the subscribe to topic flow. @@ -51,7 +57,8 @@ func (s *WsControllerSuite) TestSubscribeRequest() { t.Parallel() conn, dataProviderFactory, dataProvider := newControllerMocks(t) - controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) + controller, err := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory, s.streamLimiter) + require.NoError(t, err) dataProviderFactory. 
On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). @@ -115,7 +122,8 @@ func (s *WsControllerSuite) TestSubscribeRequest() { t.Parallel() conn, dataProviderFactory, _ := newControllerMocks(t) - controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) + controller, err := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory, s.streamLimiter) + require.NoError(t, err) type Request struct { Action string `json:"action"` @@ -164,7 +172,8 @@ func (s *WsControllerSuite) TestSubscribeRequest() { t.Parallel() conn, dataProviderFactory, _ := newControllerMocks(t) - controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) + controller, err := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory, s.streamLimiter) + require.NoError(t, err) dataProviderFactory. On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). @@ -201,7 +210,8 @@ func (s *WsControllerSuite) TestSubscribeRequest() { t.Parallel() conn, dataProviderFactory, dataProvider := newControllerMocks(t) - controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) + controller, err := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory, s.streamLimiter) + require.NoError(t, err) // data provider might finish on its own or controller will close it via Close() dataProvider.On("Close").Return(nil).Maybe() @@ -245,12 +255,156 @@ func (s *WsControllerSuite) TestSubscribeRequest() { }) } +// TestGlobalStreamLimiter verifies that the global stream limiter correctly +// controls subscription creation and releases slots on completion or error. +func (s *WsControllerSuite) TestGlobalStreamLimiter() { + s.T().Run("Rejects subscription when global limit reached", func(t *testing.T) { + t.Parallel() + + // Create a limiter with capacity 1 and exhaust it. 
+ streamLimiter, err := limiters.NewConcurrencyLimiter(1) + require.NoError(t, err) + require.True(t, streamLimiter.Acquire()) // exhaust the single slot + + conn, dataProviderFactory, _ := newControllerMocks(t) + controller, err := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory, streamLimiter) + require.NoError(t, err) + + done := make(chan struct{}) + subscriptionID := "dummy-id" + s.expectSubscribeRequest(t, conn, subscriptionID) + + conn. + On("WriteJSON", mock.Anything). + Return(func(msg interface{}) error { + defer close(done) + + response, ok := msg.(models.BaseMessageResponse) + require.True(t, ok) + require.NotEmpty(t, response.Error) + require.Equal(t, http.StatusTooManyRequests, response.Error.Code) + require.Contains(t, response.Error.Message, "maximum number of streams reached") + + return &websocket.CloseError{Code: websocket.CloseNormalClosure} + }) + + s.expectCloseConnection(conn, done) + + controller.HandleConnection(context.Background()) + + conn.AssertExpectations(t) + // Factory should never be called — rejected before provider creation. + dataProviderFactory.AssertExpectations(t) + + // The externally acquired slot must still be held. + require.False(t, streamLimiter.Acquire(), "externally acquired slot should still be held") + streamLimiter.Release() + }) + + s.T().Run("Releases slot when provider creation fails", func(t *testing.T) { + t.Parallel() + + streamLimiter, err := limiters.NewConcurrencyLimiter(1) + require.NoError(t, err) + + conn, dataProviderFactory, _ := newControllerMocks(t) + controller, err := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory, streamLimiter) + require.NoError(t, err) + + dataProviderFactory. + On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(nil, fmt.Errorf("invalid topic")). 
+ Once() + + done := make(chan struct{}) + subscriptionID := "dummy-id" + s.expectSubscribeRequest(t, conn, subscriptionID) + + conn. + On("WriteJSON", mock.Anything). + Return(func(msg interface{}) error { + defer close(done) + + response, ok := msg.(models.BaseMessageResponse) + require.True(t, ok) + require.NotEmpty(t, response.Error) + require.Equal(t, http.StatusBadRequest, response.Error.Code) + + return &websocket.CloseError{Code: websocket.CloseNormalClosure} + }) + + s.expectCloseConnection(conn, done) + + controller.HandleConnection(context.Background()) + + // Slot must have been released despite the error. + require.True(t, streamLimiter.Acquire(), "slot should be released after provider creation failure") + streamLimiter.Release() + + conn.AssertExpectations(t) + dataProviderFactory.AssertExpectations(t) + }) + + s.T().Run("Releases slot when provider completes", func(t *testing.T) { + t.Parallel() + + streamLimiter, err := limiters.NewConcurrencyLimiter(1) + require.NoError(t, err) + + conn, dataProviderFactory, dataProvider := newControllerMocks(t) + controller, err := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory, streamLimiter) + require.NoError(t, err) + + dataProvider.On("Close").Return(nil).Maybe() + dataProvider. + On("Run"). + Return(nil). + Once() + + dataProviderFactory. + On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(dataProvider, nil). + Once() + + done := make(chan struct{}) + subscriptionID := "dummy-id" + s.expectSubscribeRequest(t, conn, subscriptionID) + + // When the subscribe OK response is written, close done to trigger + // connection shutdown. The provider runs and completes immediately + // (Run returns nil), so no further writes are expected. + conn. + On("WriteJSON", mock.Anything). 
+ Run(func(args mock.Arguments) { + response, ok := args.Get(0).(models.SubscribeMessageResponse) + require.True(t, ok) + require.Equal(t, subscriptionID, response.SubscriptionID) + close(done) + }). + Return(nil). + Once() + + s.expectCloseConnection(conn, done) + + controller.HandleConnection(context.Background()) + + // Slot must have been released after provider completed. + require.True(t, streamLimiter.Acquire(), "slot should be released after provider completes") + streamLimiter.Release() + + conn.AssertExpectations(t) + dataProviderFactory.AssertExpectations(t) + dataProvider.AssertExpectations(t) + }) +} + func (s *WsControllerSuite) TestUnsubscribeRequest() { s.T().Run("Happy path", func(t *testing.T) { t.Parallel() conn, dataProviderFactory, dataProvider := newControllerMocks(t) - controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) + controller, err := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory, s.streamLimiter) + require.NoError(t, err) dataProviderFactory. On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). @@ -318,7 +472,8 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { t.Parallel() conn, dataProviderFactory, dataProvider := newControllerMocks(t) - controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) + controller, err := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory, s.streamLimiter) + require.NoError(t, err) dataProviderFactory. On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). 
@@ -388,7 +543,8 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { t.Parallel() conn, dataProviderFactory, dataProvider := newControllerMocks(t) - controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) + controller, err := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory, s.streamLimiter) + require.NoError(t, err) dataProviderFactory. On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). @@ -461,7 +617,8 @@ func (s *WsControllerSuite) TestListSubscriptions() { s.T().Run("Happy path", func(t *testing.T) { conn, dataProviderFactory, dataProvider := newControllerMocks(t) - controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) + controller, err := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory, s.streamLimiter) + require.NoError(t, err) dataProviderFactory. On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). @@ -543,7 +700,8 @@ func (s *WsControllerSuite) TestSubscribeBlocks() { t.Parallel() conn, dataProviderFactory, dataProvider := newControllerMocks(t) - controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) + controller, err := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory, s.streamLimiter) + require.NoError(t, err) dataProviderFactory. On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). @@ -597,7 +755,8 @@ func (s *WsControllerSuite) TestSubscribeBlocks() { t.Parallel() conn, dataProviderFactory, dataProvider := newControllerMocks(t) - controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) + controller, err := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory, s.streamLimiter) + require.NoError(t, err) dataProviderFactory. 
On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). @@ -682,7 +841,8 @@ func (s *WsControllerSuite) TestRateLimiter() { config := NewDefaultWebsocketConfig() config.MaxResponsesPerSecond = 2 - controller := NewWebSocketController(s.logger, config, conn, nil) + controller, err := NewWebSocketController(s.logger, config, conn, nil, s.streamLimiter) + require.NoError(s.T(), err) // Step 3: Simulate sending messages to the controller's `multiplexedStream`. go func() { @@ -733,9 +893,10 @@ func (s *WsControllerSuite) TestConfigureKeepaliveConnection() { conn.On("SetReadDeadline", mock.Anything).Return(nil) factory := dpmock.NewDataProviderFactory(t) - controller := NewWebSocketController(s.logger, s.wsConfig, conn, factory) + controller, err := NewWebSocketController(s.logger, s.wsConfig, conn, factory, s.streamLimiter) + require.NoError(t, err) - err := controller.configureKeepalive() + err = controller.configureKeepalive() s.Require().NoError(err, "configureKeepalive should not return an error") conn.AssertExpectations(t) @@ -752,7 +913,8 @@ func (s *WsControllerSuite) TestControllerShutdown() { conn.On("SetPongHandler", mock.AnythingOfType("func(string) error")).Return(nil).Once() factory := dpmock.NewDataProviderFactory(t) - controller := NewWebSocketController(s.logger, s.wsConfig, conn, factory) + controller, err := NewWebSocketController(s.logger, s.wsConfig, conn, factory, s.streamLimiter) + require.NoError(t, err) // Mock keepalive to return an error done := make(chan struct{}, 1) @@ -786,7 +948,8 @@ func (s *WsControllerSuite) TestControllerShutdown() { conn.On("SetPongHandler", mock.AnythingOfType("func(string) error")).Return(nil).Once() factory := dpmock.NewDataProviderFactory(t) - controller := NewWebSocketController(s.logger, s.wsConfig, conn, factory) + controller, err := NewWebSocketController(s.logger, s.wsConfig, conn, factory, s.streamLimiter) + require.NoError(t, err) conn. On("ReadJSON", mock.Anything). 
@@ -803,7 +966,8 @@ func (s *WsControllerSuite) TestControllerShutdown() { t.Parallel() conn, dataProviderFactory, dataProvider := newControllerMocks(t) - controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) + controller, err := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory, s.streamLimiter) + require.NoError(t, err) dataProviderFactory. On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). @@ -852,7 +1016,8 @@ func (s *WsControllerSuite) TestControllerShutdown() { conn.On("SetPongHandler", mock.AnythingOfType("func(string) error")).Return(nil).Once() factory := dpmock.NewDataProviderFactory(t) - controller := NewWebSocketController(s.logger, s.wsConfig, conn, factory) + controller, err := NewWebSocketController(s.logger, s.wsConfig, conn, factory, s.streamLimiter) + require.NoError(t, err) ctx, cancel := context.WithCancel(context.Background()) @@ -875,7 +1040,8 @@ func (s *WsControllerSuite) TestControllerShutdown() { wsConfig := s.wsConfig wsConfig.InactivityTimeout = 50 * time.Millisecond - controller := NewWebSocketController(s.logger, wsConfig, conn, factory) + controller, err := NewWebSocketController(s.logger, wsConfig, conn, factory, s.streamLimiter) + require.NoError(t, err) conn. On("ReadJSON", mock.Anything). 
@@ -927,7 +1093,8 @@ func (s *WsControllerSuite) TestKeepaliveRoutine() { }) factory := dpmock.NewDataProviderFactory(t) - controller := NewWebSocketController(s.logger, s.wsConfig, conn, factory) + controller, err := NewWebSocketController(s.logger, s.wsConfig, conn, factory, s.streamLimiter) + require.NoError(t, err) controller.keepaliveConfig = keepaliveConfig controller.HandleConnection(context.Background()) @@ -942,13 +1109,14 @@ func (s *WsControllerSuite) TestKeepaliveRoutine() { Once() factory := dpmock.NewDataProviderFactory(t) - controller := NewWebSocketController(s.logger, s.wsConfig, conn, factory) + controller, err := NewWebSocketController(s.logger, s.wsConfig, conn, factory, s.streamLimiter) + require.NoError(t, err) controller.keepaliveConfig = keepaliveConfig ctx, cancel := context.WithCancel(context.Background()) defer cancel() - err := controller.keepalive(ctx) + err = controller.keepalive(ctx) s.Require().Error(err) s.Require().ErrorIs(expectedError, err) }) @@ -961,13 +1129,14 @@ func (s *WsControllerSuite) TestKeepaliveRoutine() { Once() factory := dpmock.NewDataProviderFactory(t) - controller := NewWebSocketController(s.logger, s.wsConfig, conn, factory) + controller, err := NewWebSocketController(s.logger, s.wsConfig, conn, factory, s.streamLimiter) + require.NoError(t, err) controller.keepaliveConfig = keepaliveConfig ctx, cancel := context.WithCancel(context.Background()) defer cancel() - err := controller.keepalive(ctx) + err = controller.keepalive(ctx) s.Require().Error(err) s.Require().ErrorContains(err, "error sending ping") }) @@ -975,14 +1144,15 @@ func (s *WsControllerSuite) TestKeepaliveRoutine() { s.T().Run("Context cancelled", func(t *testing.T) { conn := connmock.NewWebsocketConnection(t) factory := dpmock.NewDataProviderFactory(t) - controller := NewWebSocketController(s.logger, s.wsConfig, conn, factory) + controller, err := NewWebSocketController(s.logger, s.wsConfig, conn, factory, s.streamLimiter) + require.NoError(t, err) 
controller.keepaliveConfig = keepaliveConfig ctx, cancel := context.WithCancel(context.Background()) cancel() // Immediately cancel the context // Start the keepalive process with the context canceled - err := controller.keepalive(ctx) + err = controller.keepalive(ctx) s.Require().NoError(err) }) } diff --git a/engine/access/rest/websockets/handler.go b/engine/access/rest/websockets/handler.go index 951e56e7896..8ee100aa07d 100644 --- a/engine/access/rest/websockets/handler.go +++ b/engine/access/rest/websockets/handler.go @@ -10,6 +10,7 @@ import ( dp "github.com/onflow/flow-go/engine/access/rest/websockets/data_providers" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/module/irrecoverable" + "github.com/onflow/flow-go/module/limiters" ) type Handler struct { @@ -24,6 +25,7 @@ type Handler struct { logger zerolog.Logger websocketConfig Config dataProviderFactory dp.DataProviderFactory + streamLimiter *limiters.ConcurrencyLimiter } var _ http.Handler = (*Handler)(nil) @@ -36,6 +38,7 @@ func NewWebSocketHandler( maxRequestSize int64, maxResponseSize int64, dataProviderFactory dp.DataProviderFactory, + streamLimiter *limiters.ConcurrencyLimiter, ) *Handler { return &Handler{ ctx: ctx, @@ -43,6 +46,7 @@ func NewWebSocketHandler( websocketConfig: config, logger: logger, dataProviderFactory: dataProviderFactory, + streamLimiter: streamLimiter, } } @@ -69,6 +73,14 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - controller := NewWebSocketController(logger, h.websocketConfig, NewWebsocketConnection(conn), h.dataProviderFactory) + controller, err := NewWebSocketController(logger, h.websocketConfig, NewWebsocketConnection(conn), h.dataProviderFactory, h.streamLimiter) + if err != nil { + // The connection has already been upgraded to a websocket, so the HTTP + // ResponseWriter is hijacked and can no longer carry an error response; + // log the failure and close the websocket connection instead. + logger.Error().Err(err).Msg("could not create websocket controller") + _ = conn.Close() + return + } controller.HandleConnection(h.ctx) } diff --git 
a/module/limiters/concurrency_limiter.go b/module/limiters/concurrency_limiter.go index 39654387ef9..539647a1a89 100644 --- a/module/limiters/concurrency_limiter.go +++ b/module/limiters/concurrency_limiter.go @@ -26,23 +26,45 @@ func NewConcurrencyLimiter(maxConcurrent uint32) (*ConcurrencyLimiter, error) { }, nil } -// Allow executes fn if the number of concurrent operations is below the configured limit. -// Returns true if fn was executed, false if the limit was reached and fn was not called. -// The concurrency counter is decremented when fn returns, including on panic. -func (h *ConcurrencyLimiter) Allow(fn func()) bool { +// Acquire atomically increments the concurrency counter if it is below the configured limit. +// Returns true if the slot was acquired, false if the limit was reached. +// The caller MUST call [Release] exactly once after a successful Acquire. +func (h *ConcurrencyLimiter) Acquire() bool { for { current := h.totalConcurrent.Load() if current >= h.maxConcurrent { return false } if h.totalConcurrent.CompareAndSwap(current, current+1) { - break + return true + } + } +} + +// Release decrements the concurrency counter, freeing a slot previously obtained via [Acquire]. +// Must be called exactly once for every successful [Acquire] call. +// Panics if called without a matching [Acquire] (counter underflow). +func (h *ConcurrencyLimiter) Release() { + for { + current := h.totalConcurrent.Load() + if current == 0 { + panic("concurrency limiter release without matching acquire") + } + if h.totalConcurrent.CompareAndSwap(current, current-1) { + return } } +} + +// Allow executes fn if the number of concurrent operations is below the configured limit. +// Returns true if fn was executed, false if the limit was reached and fn was not called. +// The concurrency counter is decremented when fn returns, including on panic. 
+func (h *ConcurrencyLimiter) Allow(fn func()) bool { + if !h.Acquire() { + return false + } // decrement within a defer to support usecases where panics are handled gracefully by the caller - defer func() { - h.totalConcurrent.Sub(1) - }() + defer h.Release() fn() return true diff --git a/module/limiters/concurrency_limiter_test.go b/module/limiters/concurrency_limiter_test.go index 49acd8b4c3b..4af4f29208b 100644 --- a/module/limiters/concurrency_limiter_test.go +++ b/module/limiters/concurrency_limiter_test.go @@ -90,6 +90,89 @@ func TestConcurrencyLimiter_NewZeroLimit(t *testing.T) { assert.Error(t, err) } +// TestConcurrencyLimiter_Acquire_WithinLimit verifies that Acquire returns true +// when below the concurrency limit. +func TestConcurrencyLimiter_Acquire_WithinLimit(t *testing.T) { + limiter, err := NewConcurrencyLimiter(2) + require.NoError(t, err) + + assert.True(t, limiter.Acquire()) + assert.True(t, limiter.Acquire()) + + limiter.Release() + limiter.Release() +} + +// TestConcurrencyLimiter_Acquire_AtLimit verifies that Acquire returns false +// when the concurrency limit is reached, and succeeds again after Release. +func TestConcurrencyLimiter_Acquire_AtLimit(t *testing.T) { + limiter, err := NewConcurrencyLimiter(1) + require.NoError(t, err) + + assert.True(t, limiter.Acquire()) + assert.False(t, limiter.Acquire(), "second Acquire must fail at limit") + + limiter.Release() + assert.True(t, limiter.Acquire(), "Acquire must succeed after Release") + + limiter.Release() +} + +// TestConcurrencyLimiter_Acquire_ConcurrentCalls verifies that at most maxConcurrent +// slots can be acquired simultaneously across concurrent goroutines. 
+func TestConcurrencyLimiter_Acquire_ConcurrentCalls(t *testing.T) { + const maxConcurrent = 5 + const totalGoroutines = 50 + + limiter, err := NewConcurrencyLimiter(maxConcurrent) + require.NoError(t, err) + + var ( + peak atomic.Int32 + current atomic.Int32 + wg sync.WaitGroup + ) + + start := make(chan struct{}) + + for i := 0; i < totalGoroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + <-start + if limiter.Acquire() { + n := current.Add(1) + for { + old := peak.Load() + if n <= old || peak.CompareAndSwap(old, n) { + break + } + } + time.Sleep(time.Millisecond) + current.Add(-1) + limiter.Release() + } + }() + } + + close(start) + wg.Wait() + + assert.LessOrEqual(t, peak.Load(), int32(maxConcurrent), + "peak concurrent acquisitions must not exceed maxConcurrent") +} + +// TestConcurrencyLimiter_Release_Underflow verifies that Release panics when called +// without a matching Acquire (counter at zero). +func TestConcurrencyLimiter_Release_Underflow(t *testing.T) { + limiter, err := NewConcurrencyLimiter(1) + require.NoError(t, err) + + assert.PanicsWithValue(t, "concurrency limiter release without matching acquire", func() { + limiter.Release() + }) +} + // TestConcurrencyLimiter_Allow_ConcurrentCalls verifies that at most maxConcurrent // goroutines execute fn simultaneously across a burst of concurrent callers. func TestConcurrencyLimiter_Allow_ConcurrentCalls(t *testing.T) {