Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions engine/access/rest/router/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,15 +108,14 @@ func (b *RouterBuilder) AddWebsocketsRoute(
maxRequestSize int64,
maxResponseSize int64,
dataProviderFactory dp.DataProviderFactory,
limiter *limiters.ConcurrencyLimiter,
streamLimiter *limiters.ConcurrencyLimiter,
) *RouterBuilder {
h := websockets.NewWebSocketHandler(ctx, b.logger, config, chain, maxRequestSize, maxResponseSize, dataProviderFactory)
handler := websockets.NewConnectionLimitedHandler(b.logger, h.HttpHandler, h, limiter)
h := websockets.NewWebSocketHandler(ctx, b.logger, config, chain, maxRequestSize, maxResponseSize, dataProviderFactory, streamLimiter)
b.v1SubRouter.
Methods(http.MethodGet).
Path("/ws").
Name("ws").
Handler(handler)
Handler(h)

return b
}
Expand Down
31 changes: 25 additions & 6 deletions engine/access/rest/websockets/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ import (

dp "github.com/onflow/flow-go/engine/access/rest/websockets/data_providers"
"github.com/onflow/flow-go/engine/access/rest/websockets/models"
"github.com/onflow/flow-go/module/limiters"
"github.com/onflow/flow-go/utils/concurrentmap"
)

Expand Down Expand Up @@ -135,7 +136,8 @@ type Controller struct {
dataProviders *concurrentmap.Map[SubscriptionID, dp.DataProvider]
dataProviderFactory dp.DataProviderFactory
dataProvidersGroup *sync.WaitGroup
limiter *rate.Limiter
rateLimiter *rate.Limiter
streamLimiter *limiters.ConcurrencyLimiter

keepaliveConfig KeepaliveConfig
}
Expand All @@ -145,10 +147,11 @@ func NewWebSocketController(
config Config,
conn WebsocketConnection,
dataProviderFactory dp.DataProviderFactory,
streamLimiter *limiters.ConcurrencyLimiter,
) *Controller {
var limiter *rate.Limiter
var rateLimiter *rate.Limiter
if config.MaxResponsesPerSecond > 0 {
limiter = rate.NewLimiter(rate.Limit(config.MaxResponsesPerSecond), 1)
rateLimiter = rate.NewLimiter(rate.Limit(config.MaxResponsesPerSecond), 1)
}

return &Controller{
Expand All @@ -159,7 +162,8 @@ func NewWebSocketController(
dataProviders: concurrentmap.New[SubscriptionID, dp.DataProvider](),
dataProviderFactory: dataProviderFactory,
dataProvidersGroup: &sync.WaitGroup{},
limiter: limiter,
rateLimiter: rateLimiter,
streamLimiter: streamLimiter,
keepaliveConfig: DefaultKeepaliveConfig(),
}
}
Expand Down Expand Up @@ -442,8 +446,20 @@ func (c *Controller) handleSubscribe(ctx context.Context, msg models.SubscribeMe
return
}

// Check if the global stream limit has been reached.
if !c.streamLimiter.Acquire() {
err := fmt.Errorf("error creating new subscription: maximum number of streams reached")
c.writeErrorResponse(
ctx,
err,
wrapErrorMessage(http.StatusTooManyRequests, err.Error(), models.SubscribeAction, msg.SubscriptionID),
)
return
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.

subscriptionID, err := c.parseOrCreateSubscriptionID(msg.SubscriptionID)
if err != nil {
c.streamLimiter.Release()
err = fmt.Errorf("error parsing subscription id: %w", err)
c.writeErrorResponse(
ctx,
Expand All @@ -456,6 +472,7 @@ func (c *Controller) handleSubscribe(ctx context.Context, msg models.SubscribeMe
// register new provider
provider, err := c.dataProviderFactory.NewDataProvider(ctx, subscriptionID.String(), msg.Topic, msg.Arguments, c.multiplexedStream)
if err != nil {
c.streamLimiter.Release()
err = fmt.Errorf("error creating data provider: %w", err)
c.writeErrorResponse(
ctx,
Expand All @@ -478,6 +495,8 @@ func (c *Controller) handleSubscribe(ctx context.Context, msg models.SubscribeMe
// run provider
c.dataProvidersGroup.Add(1)
go func() {
defer c.streamLimiter.Release()

err = provider.Run()
if err != nil {
err = fmt.Errorf("internal error: %w", err)
Expand Down Expand Up @@ -604,9 +623,9 @@ func (c *Controller) parseOrCreateSubscriptionID(id string) (SubscriptionID, err
// An error is returned if the context is canceled or the expected wait time exceeds the context's
// deadline.
func (c *Controller) checkRateLimit(ctx context.Context) error {
if c.limiter == nil {
if c.rateLimiter == nil {
return nil
}

return c.limiter.WaitN(ctx, 1)
return c.rateLimiter.WaitN(ctx, 1)
}
Loading
Loading