Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ endif
############################################################################

PLATFORMS ?= linux/amd64,linux/arm64
BUILDX_BUILDER ?= container-builder

binaries := spire-server spire-agent oidc-discovery-provider

Expand Down Expand Up @@ -332,13 +333,17 @@ spire-buildx-tls:

.PHONY: container-builder
container-builder: $(dockertls)
$(E)docker buildx create $(dockertls) --platform $(PLATFORMS) --name container-builder --node container-builder0 --use
$(E)if [ "$(BUILDX_BUILDER)" = "container-builder" ]; then \
docker buildx inspect container-builder > /dev/null 2>&1 || \
docker buildx create $(dockertls) --platform $(PLATFORMS) --name container-builder --node container-builder0 --use; \
fi

define image_rule
.PHONY: $1
$1: $3 container-builder
@echo Building docker image $2 $(PLATFORM)…
$(E)docker buildx build \
--builder $(BUILDX_BUILDER) \
--platform $(PLATFORMS) \
--build-arg goversion=$(go_version) \
--build-arg TAG=$(TAG) \
Expand Down
2 changes: 2 additions & 0 deletions cmd/spire-agent/cli/run/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ type agentConfig struct {
ProfilingPort int `hcl:"profiling_port"`
ProfilingFreq int `hcl:"profiling_freq"`
ProfilingNames []string `hcl:"profiling_names"`
WorkloadAPIRateLimit int `hcl:"workload_api_rate_limit"`
Experimental experimentalConfig `hcl:"experimental"`

UnusedKeyPositions map[string][]token.Pos `hcl:",unusedKeyPositions"`
Expand Down Expand Up @@ -559,6 +560,7 @@ func NewAgentConfig(c *Config, logOptions []log.Option, allowUnknownConfig bool)
ac.ProfilingNames = c.Agent.ProfilingNames

ac.AllowedForeignJWTClaims = c.Agent.AllowedForeignJWTClaims
ac.WorkloadAPIRateLimit = c.Agent.WorkloadAPIRateLimit

ac.PluginConfigs, err = catalog.PluginConfigsFromHCLNode(c.Plugins)
if err != nil {
Expand Down
2 changes: 2 additions & 0 deletions pkg/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,7 @@ func (a *Agent) newEndpoints(metrics telemetry.Metrics, mgr manager.Manager, att
AllowUnauthenticatedVerifiers: a.c.AllowUnauthenticatedVerifiers,
AllowedForeignJWTClaims: a.c.AllowedForeignJWTClaims,
TrustDomain: a.c.TrustDomain,
WorkloadAPIRateLimit: a.c.WorkloadAPIRateLimit,
})
}

Expand All @@ -481,6 +482,7 @@ func (a *Agent) newAdminEndpoints(metrics telemetry.Metrics, mgr manager.Manager
Uptime: uptime.Uptime,
Attestor: attestor,
AuthorizedDelegates: authorizedDelegates,
WorkloadAPIRateLimit: a.c.WorkloadAPIRateLimit,
}

return admin_api.New(config)
Expand Down
2 changes: 1 addition & 1 deletion pkg/agent/api/endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ type Endpoints struct {

func (e *Endpoints) ListenAndServe(ctx context.Context) error {
unaryInterceptor, streamInterceptor := middleware.Interceptors(
endpoints.Middleware(e.c.Log, e.c.Metrics),
endpoints.Middleware(e.c.Log, e.c.Metrics, e.c.WorkloadAPIRateLimit),
)

server := grpc.NewServer(
Expand Down
3 changes: 3 additions & 0 deletions pkg/agent/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,9 @@ type Config struct {

// TLSPolicy determines the post-quantum-safe TLS policy to apply to all TLS connections.
TLSPolicy tlspolicy.Policy

// WorkloadAPIRateLimit is the rate limit for Workload API calls
WorkloadAPIRateLimit int
}

func New(c *Config) *Agent {
Expand Down
4 changes: 2 additions & 2 deletions pkg/agent/endpoints/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ type Config struct {
AllowUnauthenticatedVerifiers bool

AllowedForeignJWTClaims []string

TrustDomain spiffeid.TrustDomain
TrustDomain spiffeid.TrustDomain
WorkloadAPIRateLimit int

// Hooks used by the unit tests to assert that the configuration provided
// to each handler is correct and return fake handlers.
Expand Down
4 changes: 3 additions & 1 deletion pkg/agent/endpoints/endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ type Endpoints struct {
workloadAPIServer workload_pb.SpiffeWorkloadAPIServer
sdsv3Server secret_v3.SecretDiscoveryServiceServer
healthServer grpc_health_v1.HealthServer
rateLimit int

hooks struct {
listening chan struct{} // Hook to signal when the server starts listening
Expand Down Expand Up @@ -90,6 +91,7 @@ func New(c Config) *Endpoints {
workloadAPIServer: workloadAPIServer,
sdsv3Server: sdsv3Server,
healthServer: healthServer,
rateLimit: c.WorkloadAPIRateLimit,
hooks: struct {
listening chan struct{}
}{
Expand All @@ -100,7 +102,7 @@ func New(c Config) *Endpoints {

func (e *Endpoints) ListenAndServe(ctx context.Context) error {
unaryInterceptor, streamInterceptor := middleware.Interceptors(
Middleware(e.log, e.metrics),
Middleware(e.log, e.metrics, e.rateLimit),
)

server := grpc.NewServer(
Expand Down
8 changes: 7 additions & 1 deletion pkg/agent/endpoints/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,19 @@ const (
workloadAPIMethodPrefix = "/SpiffeWorkloadAPI/"
)

func Middleware(log logrus.FieldLogger, metrics telemetry.Metrics) middleware.Middleware {
// Middleware assembles the middleware chain shared by the agent endpoints:
// logging, metrics, per-service connection metrics, watcher PID injection,
// security header verification and, last, per-method rate limiting.
//
// rateLimit is applied to each of the Workload API methods listed in the
// table below.
func Middleware(log logrus.FieldLogger, metrics telemetry.Metrics, rateLimit int) middleware.Middleware {
	// Every rate-limited Workload API method shares the same configured limit.
	methodLimits := map[string]int{
		workloadAPIMethodPrefix + "FetchX509SVID":   rateLimit,
		workloadAPIMethodPrefix + "FetchJWTSVID":    rateLimit,
		workloadAPIMethodPrefix + "FetchJWTBundles": rateLimit,
		workloadAPIMethodPrefix + "ValidateJWTSVID": rateLimit,
	}

	return middleware.Chain(
		middleware.WithLogger(log),
		middleware.WithMetrics(metrics),
		withPerServiceConnectionMetrics(metrics),
		middleware.Preprocess(addWatcherPID),
		middleware.Preprocess(verifySecurityHeader),
		withRateLimit(metrics, methodLimits),
	)
}

Expand Down
64 changes: 64 additions & 0 deletions pkg/agent/endpoints/ratelimit.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package endpoints

import (
"context"

"github.com/andres-erbsen/clock"
"github.com/spiffe/spire/pkg/common/api/middleware"
"github.com/spiffe/spire/pkg/common/ratelimit"
"github.com/spiffe/spire/pkg/common/telemetry"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

var (
// clk is the package-wide clock. It is handed to every per-method limiter
// map and supplies the timestamp for each rate limit decision. It is a
// variable so that unit tests can manipulate time.
clk = clock.New()
)

// ratelimiter enforces per-caller rate limits on a configured set of gRPC
// methods and counts the calls it rejects.
type ratelimiter struct {
// metrics receives a counter increment for every rejected call.
metrics telemetry.Metrics
// limiters maps a full gRPC method name to the per-caller limiter map for
// that method. Methods without an entry are not rate limited.
limiters map[string]*ratelimit.Map
}

// newRateLimiter builds a ratelimiter from a table of full method name to
// limit. Only methods with a strictly positive limit get a limiter map;
// zero or negative limits leave the method unrestricted.
func newRateLimiter(metrics telemetry.Metrics, limits map[string]int) *ratelimiter {
	r := &ratelimiter{
		metrics:  metrics,
		limiters: make(map[string]*ratelimit.Map),
	}

	for fullMethod, perCallerLimit := range limits {
		if perCallerLimit <= 0 {
			continue
		}
		r.limiters[fullMethod] = ratelimit.NewMap(perCallerLimit, ratelimit.DefaultGCInterval, clk)
	}

	return r
}

// Preprocess applies the rate limit configured for fullMethod, if any.
//
// Methods with no limiter pass through untouched. Otherwise the caller is
// identified with getCallerKey and one token is requested from that caller's
// limiter. When the token is denied, a "rate_limit_exceeded" counter is
// incremented and a ResourceExhausted status error is returned together with
// a nil context.
func (r *ratelimiter) Preprocess(ctx context.Context, fullMethod string, _ any) (context.Context, error) {
	perCaller, limited := r.limiters[fullMethod]
	if !limited {
		return ctx, nil
	}

	caller := getCallerKey(ctx)
	limiter := perCaller.Get(caller)
	if limiter.AllowN(clk.Now(), 1) {
		return ctx, nil
	}

	labels := []telemetry.Label{
		{Name: telemetry.Method, Value: fullMethod},
		{Name: telemetry.CallerID, Value: caller},
	}
	r.metrics.IncrCounterWithLabels([]string{telemetry.WorkloadAPI, "rate_limit_exceeded"}, 1, labels)

	return nil, status.Errorf(codes.ResourceExhausted, "method %q rate limit exceeded for %q", fullMethod, caller)
}

// withRateLimit returns a middleware that enforces the given per-method
// limits (full gRPC method name to limit).
//
// When no method ends up with an active limiter, an empty middleware chain
// is returned so that no rate limiting hook runs on the request path. This
// covers an empty table as well as a table whose limits are all zero or
// negative. The check is made against the constructed limiter set rather
// than len(limits), because callers may pass a fully populated table whose
// values are all zero when rate limiting is disabled.
func withRateLimit(metrics telemetry.Metrics, limits map[string]int) middleware.Middleware {
	r := newRateLimiter(metrics, limits)
	if len(r.limiters) == 0 {
		return middleware.Chain()
	}
	return middleware.Preprocess(r.Preprocess)
}
32 changes: 32 additions & 0 deletions pkg/agent/endpoints/ratelimit_posix.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
//go:build !windows

package endpoints

import (
"context"
"fmt"

"github.com/hashicorp/go-hclog"
"github.com/spiffe/spire/pkg/common/containerinfo"
"github.com/spiffe/spire/pkg/common/peertracker"
)

var (
// extractor is used by getCallerKey to look up the pod UID of a calling
// process from its PID.
// NOTE(review): RootDir "/" presumably is the filesystem root the
// extractor reads process information from; confirm against containerinfo.
extractor = &containerinfo.Extractor{
RootDir: "/",
}
)

// getCallerKey derives the rate limiting key for the caller found in ctx.
//
// The key is "pod-uid:<uid>" when a pod UID can be resolved for the caller's
// PID, "pid:<pid>" when it cannot (lookup error or empty UID), and "unknown"
// when ctx carries no peertracker watcher.
func getCallerKey(ctx context.Context) string {
	watcher, ok := peertracker.WatcherFromContext(ctx)
	if !ok {
		return "unknown"
	}

	pid := watcher.PID()

	// The lookup is best effort: any failure falls back to the PID key.
	podUID, _, err := extractor.GetPodUIDAndContainerID(pid, hclog.NewNullLogger())
	switch {
	case err != nil, podUID == "":
		return fmt.Sprintf("pid:%d", pid)
	default:
		return fmt.Sprintf("pod-uid:%s", podUID)
	}
}
19 changes: 19 additions & 0 deletions pkg/agent/endpoints/ratelimit_windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
//go:build windows

package endpoints

import (
"context"
"fmt"

"github.com/spiffe/spire/pkg/common/peertracker"
)

// getCallerKey derives the rate limiting key for the caller found in ctx:
// "pid:<pid>" when ctx carries a peertracker watcher, "unknown" otherwise.
func getCallerKey(ctx context.Context) string {
	if watcher, found := peertracker.WatcherFromContext(ctx); found {
		return fmt.Sprintf("pid:%d", watcher.PID())
	}
	return "unknown"
}
112 changes: 112 additions & 0 deletions pkg/common/ratelimit/ratelimit.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
package ratelimit

import (
"context"
"sync"
"time"

"github.com/andres-erbsen/clock"
"golang.org/x/time/rate"
)

// DefaultGCInterval is the default interval at which inactive limiters are
// garbage collected. A Map rotates its generations at most once per interval,
// and a limiter that stays unused through two rotations is dropped.
const DefaultGCInterval = time.Minute

// Limiter represents the rate limiter functionality.
// It is the type a Map stores and hands out; a custom implementation can be
// plugged in through NewMapWithCreator.
type Limiter interface {
// WaitN takes a context and an event count and returns an error. The
// default implementation forwards to rate.Limiter.WaitN.
WaitN(ctx context.Context, n int) error
// AllowN reports whether n events are allowed at time now. The default
// implementation forwards to rate.Limiter.AllowN.
AllowN(now time.Time, n int) bool
// Limit returns the limiter's rate.
Limit() rate.Limit
// Burst returns the limiter's burst size.
Burst() int
}

// rateLimiter wraps rate.Limiter to implement the Limiter interface.
// AllowN and WaitN are forwarded explicitly; the remaining interface methods
// are provided by the embedded *rate.Limiter.
type rateLimiter struct {
*rate.Limiter
}

// AllowN forwards to the wrapped rate.Limiter and returns its verdict on
// whether n events may happen at time now.
func (l *rateLimiter) AllowN(now time.Time, n int) bool {
	inner := l.Limiter
	return inner.AllowN(now, n)
}

// WaitN forwards to the wrapped rate.Limiter's WaitN and returns its error
// unchanged.
func (l *rateLimiter) WaitN(ctx context.Context, n int) error {
	inner := l.Limiter
	return inner.WaitN(ctx, n)
}

// NewLimiter returns a Limiter backed by a freshly created rate.Limiter
// configured with the given limit and burst.
func NewLimiter(limit rate.Limit, burst int) Limiter {
	inner := rate.NewLimiter(limit, burst)
	return &rateLimiter{Limiter: inner}
}

// Map is a thread-safe map of rate limiters keyed by string.
// It uses a two-generation garbage collection pattern to evict inactive limiters:
// limiters are created in current, demoted to previous when the generations
// rotate, and discarded at the following rotation unless a Get call has
// promoted them back to current in the meantime.
type Map struct {
// limit is used both as the rate and as the burst size of every limiter
// the map creates.
limit int
// gcInterval is the minimum time between two generation rotations.
gcInterval time.Duration
// clock supplies the current time for the rotation check.
clock clock.Clock
// creator builds the limiter for a key that has none yet.
creator func(limit rate.Limit, burst int) Limiter

// mtx guards previous, current and lastGC.
mtx sync.RWMutex
// previous holds the older generation; it is nil until the first rotation.
previous map[string]Limiter
// current holds the limiters created or used since the last rotation.
current map[string]Limiter
// lastGC is the time of the last rotation (initially the creation time).
lastGC time.Time
}

// NewMap returns a Map whose limiters are built by NewLimiter, using the
// given limit, GC interval and clock.
func NewMap(limit int, gcInterval time.Duration, clock clock.Clock) *Map {
	defaultCreator := NewLimiter
	return NewMapWithCreator(limit, gcInterval, clock, defaultCreator)
}

// NewMapWithCreator returns a Map that builds its limiters with creator.
//
// limit is passed to creator as both rate and burst, gcInterval is the
// minimum time between generation rotations, and clock provides the time.
// The map starts with an empty current generation, no previous generation,
// and its last-GC timestamp set to the clock's current time.
func NewMapWithCreator(limit int, gcInterval time.Duration, clock clock.Clock, creator func(limit rate.Limit, burst int) Limiter) *Map {
	m := new(Map)
	m.limit = limit
	m.gcInterval = gcInterval
	m.clock = clock
	m.creator = creator
	m.current = map[string]Limiter{}
	m.lastGC = clock.Now()
	return m
}

// Get returns the limiter associated with key, creating it on first use.
//
// The common case, a key already in the current generation, is served under
// the read lock. Everything else runs under the write lock: a key found in
// the previous generation is promoted back to current, and a brand-new key
// gets a fresh limiter. The generations rotate only on that last path, and
// only once gcInterval has elapsed since the last rotation.
func (m *Map) Get(key string) Limiter {
	// Fast path: shared lock, current generation only.
	m.mtx.RLock()
	cached, hit := m.current[key]
	m.mtx.RUnlock()
	if hit {
		return cached
	}

	m.mtx.Lock()
	defer m.mtx.Unlock()

	// Another goroutine may have inserted the key between the two locks.
	if existing, ok := m.current[key]; ok {
		return existing
	}

	// A recent rotation may have demoted the key; bring it back to current.
	if demoted, ok := m.previous[key]; ok {
		delete(m.previous, key)
		m.current[key] = demoted
		return demoted
	}

	// Rotate generations when due: current becomes previous and the old
	// previous generation is dropped.
	if now := m.clock.Now(); now.Sub(m.lastGC) >= m.gcInterval {
		m.previous, m.current = m.current, make(map[string]Limiter)
		m.lastGC = now
	}

	created := m.creator(rate.Limit(m.limit), m.limit)
	m.current[key] = created
	return created
}
Loading