diff --git a/.env.example b/.env.example index 7c57ad9..0e3e131 100644 --- a/.env.example +++ b/.env.example @@ -37,3 +37,13 @@ # Observability (optional) # AGENT_VAULT_LOG_LEVEL=info # info (default) | debug — debug emits one line per proxied request (no secret values) + +# Rate limiting (optional) — tiered limits with sensible defaults. +# Profile: default | strict (≈0.5×) | loose (≈2×) | off (disable all limits). +# AGENT_VAULT_RATELIMIT_PROFILE=default +# When true, the owner UI cannot override rate-limit settings (operator pin). +# AGENT_VAULT_RATELIMIT_LOCK=false +# Fine-grained overrides (rare): AGENT_VAULT_RATELIMIT__ +# where TIER ∈ AUTH | PROXY | AUTHED | GLOBAL +# and KNOB ∈ RATE | BURST | WINDOW | MAX | CONCURRENCY. Example: +# AGENT_VAULT_RATELIMIT_PROXY_BURST=50 diff --git a/README.md b/README.md index eabdabf..f934697 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,7 @@ See the [installation guide](https://docs.agent-vault.dev/installation) for full ### Script (macOS / Linux) ```bash -curl -fsSL https://raw.githubusercontent.com/Infisical/agent-vault/main/install.sh | sh +curl -fsSL https://get.agent-vault.dev | sh agent-vault server -d ``` @@ -117,6 +117,12 @@ const caCert = session.containerConfig!.caCertificate; See the [TypeScript SDK README](sdks/sdk-typescript/README.md) for full documentation. +## Rate limiting + +Agent Vault ships with a **tiered, in-memory rate limiter** keyed on the principal appropriate for each endpoint (client IP for anonymous auth, hashed token for invite/approval redemption, `(actor, vault)` scope for the proxy path, global in-flight ceiling for the server). Defaults are tuned for normal use — agents doing realistic bursts of proxy calls don't trip anything — and 429 responses carry a `Retry-After` header so clients can back off politely. + +Pick a preset via `AGENT_VAULT_RATELIMIT_PROFILE={default,strict,loose,off}`, or fine-tune per tier in **Manage Instance → Settings → Rate Limiting** (owner-only). 
Set `AGENT_VAULT_RATELIMIT_LOCK=true` on PaaS to pin limits to env vars and disable the UI. See [docs/self-hosting/environment-variables.mdx](docs/self-hosting/environment-variables.mdx) for the full knob list. + ## Development ```bash diff --git a/cmd/server.go b/cmd/server.go index 7335898..1ba1232 100644 --- a/cmd/server.go +++ b/cmd/server.go @@ -171,11 +171,14 @@ func attachMITMIfEnabled(srv *server.Server, host string, mitmPort int, masterKe } srv.AttachMITM(mitm.New( net.JoinHostPort(host, strconv.Itoa(mitmPort)), - caProv, - srv.SessionResolver(), - srv.CredentialProvider(), - srv.BaseURL(), - srv.Logger(), + mitm.Options{ + CA: caProv, + Sessions: srv.SessionResolver(), + Credentials: srv.CredentialProvider(), + BaseURL: srv.BaseURL(), + Logger: srv.Logger(), + RateLimit: srv.RateLimit(), + }, )) return nil } diff --git a/cmd/skill_cli.md b/cmd/skill_cli.md index 7022661..19150c1 100644 --- a/cmd/skill_cli.md +++ b/cmd/skill_cli.md @@ -192,7 +192,7 @@ Prints the raw value to stdout (pipe-friendly). Useful for configuration tasks w - 401: Invalid or expired token -- check `AGENT_VAULT_SESSION_TOKEN` - 403 `forbidden`: Host not allowed -- create a proposal - 403 `service_disabled`: Host is configured but currently disabled by an operator. Don't create a new proposal; surface the error to the user so they can re-enable it (UI toggle, or `agent-vault vault service enable `) -- 429: Too many pending proposals -- wait for review +- 429: Rate limited. The response carries a `Retry-After` header (seconds) and a JSON body `{"error":"too_many_requests", ...}`. Respect `Retry-After` — wait that many seconds before retrying. Don't tight-loop or switch to a different Agent Vault ingress to bypass it (MITM + explicit `/proxy/` share one budget). If this trips on normal work, ask the instance owner to raise the limit in **Manage Instance → Settings → Rate Limiting**. 
- 502: Missing credential or upstream unreachable, tell user a credential may need to be added ## Rules diff --git a/cmd/skill_http.md b/cmd/skill_http.md index 90dc486..d1d44e9 100644 --- a/cmd/skill_http.md +++ b/cmd/skill_http.md @@ -181,7 +181,7 @@ Content-Type: application/json - 401: Invalid or expired token -- check `AGENT_VAULT_SESSION_TOKEN` - 403 `forbidden`: Host not allowed -- create a proposal - 403 `service_disabled`: Host is configured but currently disabled by an operator. Don't create a new proposal; surface the error to the user so they can re-enable it -- 429: Too many pending proposals -- wait for review +- 429: Rate limited. The response carries a `Retry-After` header (seconds) and a JSON body `{"error":"too_many_requests", ...}`. Respect `Retry-After` — wait that many seconds before retrying. Do **not** tight-loop or switch to a different Agent Vault ingress to bypass the limit; the MITM and explicit `/proxy/` paths share one budget. If the limit trips repeatedly on normal work, ask the instance owner to raise the limit in **Manage Instance → Settings → Rate Limiting**. - 502: Missing credential or upstream unreachable, tell user a credential may need to be added ## Rules diff --git a/docs/reference/cli.mdx b/docs/reference/cli.mdx index 29ecb73..1f12b55 100644 --- a/docs/reference/cli.mdx +++ b/docs/reference/cli.mdx @@ -43,6 +43,9 @@ description: "Complete reference for all Agent Vault CLI commands." | `AGENT_VAULT_SMTP_FROM_NAME` | Sender display name (default `Agent Vault`) | | `AGENT_VAULT_SMTP_TLS_MODE` | TLS mode: `opportunistic` (default), `required`, or `none` | | `AGENT_VAULT_SMTP_TLS_SKIP_VERIFY` | Skip TLS certificate verification (default `false`) | + | `AGENT_VAULT_RATELIMIT_PROFILE` | Rate-limit profile: `default`, `strict`, `loose`, or `off`. Affects anonymous auth, token-redeem, proxy, authenticated CRUD, and the global in-flight / RPS ceilings. 
| + | `AGENT_VAULT_RATELIMIT_LOCK` | When `true`, the rate-limit section in the Manage Instance UI is read-only and UI overrides are ignored. Use when you want limits pinned to env vars on PaaS. | + | `AGENT_VAULT_RATELIMIT__` | Fine-grained per-tier overrides. `TIER` ∈ `AUTH`, `PROXY`, `AUTHED`, `GLOBAL`. `KNOB` ∈ `RATE`, `BURST`, `WINDOW`, `MAX`, `CONCURRENCY`. Env-set knobs always beat UI overrides. | diff --git a/docs/self-hosting/environment-variables.mdx b/docs/self-hosting/environment-variables.mdx index 92b518e..94ac61e 100644 --- a/docs/self-hosting/environment-variables.mdx +++ b/docs/self-hosting/environment-variables.mdx @@ -14,6 +14,9 @@ description: "All environment variables used to configure Agent Vault" | `AGENT_VAULT_NETWORK_MODE` | `public` | Proxy network restriction mode. `public` blocks connections to private/reserved IP ranges (RFC-1918, link-local, cloud metadata). `private` allows all outbound connections including private ranges — use this for local/private deployments where the proxy needs to reach internal services. | | `AGENT_VAULT_TRUSTED_PROXIES` | (unset) | Comma-separated CIDR ranges of trusted reverse proxies (e.g. `10.0.0.0/8,172.16.0.0/12`). When set, `X-Forwarded-For` is only trusted if the direct connection comes from a listed proxy. Used for rate limiting and audit logging behind a load balancer. | | `AGENT_VAULT_LOG_LEVEL` | `info` | Log level for the server. `info` (default) keeps startup banners and warnings only. `debug` adds one structured line per proxied request (ingress path, method, host, path, matched service, injected credential **key names**, upstream status, duration). Credential values are never logged. The `--log-level` flag takes precedence when set. | +| `AGENT_VAULT_RATELIMIT_PROFILE` | `default` | Rate-limit profile: `default`, `strict` (≈0.5× the defaults), `loose` (≈2×), or `off` (disable all limits). Affects every tier — anonymous auth, token-redeem, proxy, authenticated CRUD, global in-flight. 
Owners can override per-tier in **Manage Instance → Settings → Rate Limiting** unless `AGENT_VAULT_RATELIMIT_LOCK=true`. | +| `AGENT_VAULT_RATELIMIT_LOCK` | `false` | When `true`, the rate-limit UI in **Manage Instance** is read-only and UI overrides are ignored. Use on PaaS deployments (Fly.io, Cloud Run) when the operator wants limits pinned to env vars. | +| `AGENT_VAULT_RATELIMIT__` | — | Fine-grained per-tier overrides. `TIER` is one of `AUTH` (unauthenticated endpoints), `PROXY` (proxy + MITM), `AUTHED` (everything behind requireAuth), `GLOBAL` (server-wide backstop). `KNOB` is one of `RATE` (tokens/sec), `BURST` (bucket depth), `WINDOW` (duration like `5m`), `MAX` (sliding-window event cap), `CONCURRENCY` (semaphore slots). Env-set knobs always take precedence over UI overrides. | Master password resolution order: diff --git a/internal/brokercore/session.go b/internal/brokercore/session.go index 2c1f74a..1ceca65 100644 --- a/internal/brokercore/session.go +++ b/internal/brokercore/session.go @@ -7,6 +7,13 @@ import ( "github.com/Infisical/agent-vault/internal/store" ) +// MaxProxyBodyBytes caps forwarded request bodies on both proxy +// ingresses. Distinct from the generic 1 MB limitBody wrapper used +// on control-plane endpoints: proxy bodies are legitimately larger +// (file uploads, bulk API payloads) but must still be bounded to +// protect RAM under the proxy concurrency semaphore. +const MaxProxyBodyBytes = 64 << 20 + // ProxyScope is the resolved identity + vault context for a proxy request. // It is produced once per ingress (per request for /proxy, per CONNECT for // MITM) and carried through to credential injection. @@ -18,6 +25,16 @@ type ProxyScope struct { VaultRole string } +// ActorID returns the non-empty principal ID — UserID for user +// sessions, AgentID for agent tokens. Used as the actor dimension in +// per-scope rate-limit keys. 
+func (s *ProxyScope) ActorID() string { + if s.UserID != "" { + return s.UserID + } + return s.AgentID +} + // SessionResolver collapses bearer-token validation and vault selection into // one call. Both ingresses use the same resolver; MITM passes a vault hint // parsed from Proxy-Authorization, /proxy passes r.Header.Get("X-Vault"). diff --git a/internal/mitm/connect.go b/internal/mitm/connect.go index 90d7f27..bc594f1 100644 --- a/internal/mitm/connect.go +++ b/internal/mitm/connect.go @@ -9,13 +9,34 @@ import ( "time" "github.com/Infisical/agent-vault/internal/brokercore" + "github.com/Infisical/agent-vault/internal/ratelimit" ) +// mitmConnectIPKey is the rate-limit key for the CONNECT-flood +// limiter. X-Forwarded-For doesn't exist at this layer (the HTTP +// request is tunnelled); only the direct peer IP is meaningful. +func mitmConnectIPKey(r *http.Request) string { + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil || host == "" { + host = r.RemoteAddr + } + return "mitm:" + host +} + // handleConnect terminates a CONNECT tunnel and serves HTTP/1.1 off the // resulting TLS connection. The upstream target is taken from the // CONNECT request line (r.Host) and captured in a closure so subsequent // Host-header rewrites by the client cannot redirect the tunnel. func (p *Proxy) handleConnect(w http.ResponseWriter, r *http.Request) { + // Gate before ParseProxyAuth + session lookup so a bad-auth flood + // can't burn CPU. Per-IP on the raw TCP peer. + if p.rateLimit != nil { + if d := p.rateLimit.Allow(ratelimit.TierAuth, mitmConnectIPKey(r)); !d.Allow { + ratelimit.WriteDenial(w, d, "Too many CONNECT attempts") + return + } + } + target := r.Host host, _, err := net.SplitHostPort(target) if err != nil { @@ -87,8 +108,13 @@ func (p *Proxy) handleConnect(w http.ResponseWriter, r *http.Request) { // closes the listener so Serve returns. 
listener := newOneShotListener(tlsConn) srv := &http.Server{ - Handler: p.forwardHandler(target, host, scope), + Handler: p.forwardHandler(target, host, scope), + // Slow-loris defense: without these the tunnel can drip bytes + // forever and pin a proxy concurrency slot. ReadHeaderTimeout: 10 * time.Second, + ReadTimeout: 60 * time.Second, + WriteTimeout: 5 * time.Minute, // upstream streaming can be legit + IdleTimeout: 2 * time.Minute, ConnState: func(c net.Conn, state http.ConnState) { if state == http.StateClosed || state == http.StateHijacked { _ = listener.Close() diff --git a/internal/mitm/forward.go b/internal/mitm/forward.go index 1aa563a..9daeb50 100644 --- a/internal/mitm/forward.go +++ b/internal/mitm/forward.go @@ -8,6 +8,7 @@ import ( "time" "github.com/Infisical/agent-vault/internal/brokercore" + "github.com/Infisical/agent-vault/internal/ratelimit" ) // forwardHandler returns an http.Handler that forwards each request to @@ -28,6 +29,17 @@ func (p *Proxy) forwardHandler(target, host string, scope *brokercore.ProxyScope event.Emit(p.logger, start, status, errCode) } + // Shares one budget with /proxy so switching ingress can't bypass. + enf := p.rateLimit.EnforceProxy(r.Context(), scope.ActorID(), scope.VaultID) + if !enf.Allowed { + ratelimit.WriteDenial(w, enf.Decision, enf.Message) + emit(http.StatusTooManyRequests, enf.ErrCode) + return + } + defer enf.Release() + + r.Body = http.MaxBytesReader(w, r.Body, brokercore.MaxProxyBodyBytes) + outURL := &url.URL{ Scheme: "https", Host: target, diff --git a/internal/mitm/proxy.go b/internal/mitm/proxy.go index 056dc91..1759c0d 100644 --- a/internal/mitm/proxy.go +++ b/internal/mitm/proxy.go @@ -27,6 +27,7 @@ import ( "github.com/Infisical/agent-vault/internal/brokercore" "github.com/Infisical/agent-vault/internal/ca" "github.com/Infisical/agent-vault/internal/netguard" + "github.com/Infisical/agent-vault/internal/ratelimit" ) // Proxy is a transparent MITM proxy. 
It is safe to start at most once; @@ -41,15 +42,27 @@ type Proxy struct { isListening atomic.Bool baseURL string // externally-reachable control-plane URL for help links logger *slog.Logger + rateLimit *ratelimit.Registry // shared with the HTTP server; nil = no-op } -// New builds a Proxy bound to addr using caProv for leaf certificates and -// the brokercore sessions/creds for authentication and credential injection. -// baseURL is the externally-reachable control-plane URL (e.g. -// "http://127.0.0.1:14321") used to build help links in error responses. -// The returned Proxy does not begin listening until ListenAndServe is -// called. logger must be non-nil; tests can pass slog.New(slog.DiscardHandler). -func New(addr string, caProv ca.Provider, sessions brokercore.SessionResolver, creds brokercore.CredentialProvider, baseURL string, logger *slog.Logger) *Proxy { +// Options carries the dependencies a Proxy needs. BaseURL is the +// externally-reachable control-plane URL used in help-link error +// responses. Logger must be non-nil; tests can pass +// slog.New(slog.DiscardHandler). RateLimit is shared with the HTTP +// server so proxy limits apply uniformly across both ingresses; nil +// disables rate limiting on the MITM path. +type Options struct { + CA ca.Provider + Sessions brokercore.SessionResolver + Credentials brokercore.CredentialProvider + BaseURL string + Logger *slog.Logger + RateLimit *ratelimit.Registry +} + +// New builds a Proxy bound to addr. The returned Proxy does not begin +// listening until ListenAndServe is called. 
+func New(addr string, opts Options) *Proxy { upstream := &http.Transport{ DialContext: netguard.SafeDialContext(netguard.ModeFromEnv()), TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12}, @@ -61,12 +74,13 @@ func New(addr string, caProv ca.Provider, sessions brokercore.SessionResolver, c } p := &Proxy{ - ca: caProv, - sessions: sessions, - creds: creds, - upstream: upstream, - baseURL: baseURL, - logger: logger, + ca: opts.CA, + sessions: opts.Sessions, + creds: opts.Credentials, + upstream: upstream, + baseURL: opts.BaseURL, + logger: opts.Logger, + rateLimit: opts.RateLimit, } p.tlsConfig = &tls.Config{ @@ -75,16 +89,15 @@ func New(addr string, caProv ca.Provider, sessions brokercore.SessionResolver, c sni := hello.ServerName if sni == "" { // No SNI (IP-literal connection per RFC 6066). Use the - // actual local address the client connected to so the - // cert SAN matches regardless of IPv4/IPv6 or which - // interface was used on a wildcard bind. + // local address the client connected to so the cert + // SAN matches regardless of IPv4/IPv6 or wildcard bind. 
if host, _, err := net.SplitHostPort(hello.Conn.LocalAddr().String()); err == nil && host != "" { sni = host } else { sni = "127.0.0.1" } } - return caProv.MintLeaf(sni) + return opts.CA.MintLeaf(sni) }, } diff --git a/internal/mitm/proxy_test.go b/internal/mitm/proxy_test.go index 2e58de8..8f5ed47 100644 --- a/internal/mitm/proxy_test.go +++ b/internal/mitm/proxy_test.go @@ -97,7 +97,13 @@ func setupProxy(t *testing.T, sr brokercore.SessionResolver, cp brokercore.Crede t.Fatal("failed to load CA root PEM into pool") } - p = New("127.0.0.1:0", caProv, sr, cp, "http://127.0.0.1:14321", slog.New(slog.DiscardHandler)) + p = New("127.0.0.1:0", Options{ + CA: caProv, + Sessions: sr, + Credentials: cp, + BaseURL: "http://127.0.0.1:14321", + Logger: slog.New(slog.DiscardHandler), + }) l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { diff --git a/internal/ratelimit/bucket.go b/internal/ratelimit/bucket.go new file mode 100644 index 0000000..f808f40 --- /dev/null +++ b/internal/ratelimit/bucket.go @@ -0,0 +1,138 @@ +package ratelimit + +import ( + "math" + "sync" + "time" +) + +// tokenBucketMap is a keyed collection of token buckets with LRU-style +// eviction when the map exceeds maxKeys. Each bucket refills at rate +// tokens/sec and can accumulate up to burst tokens. Thread-safe. +type tokenBucketMap struct { + mu sync.Mutex + buckets map[string]*bucket + rate float64 + burst float64 + maxKeys int + now func() time.Time // injectable for tests +} + +type bucket struct { + tokens float64 + lastRefill time.Time +} + +func newTokenBucketMap(cfg TierConfig) *tokenBucketMap { + r := cfg.Rate + if r <= 0 { + r = 1.0 + } + b := float64(cfg.Burst) + if b < 1 { + b = 1 + } + mk := cfg.MaxKeys + if mk <= 0 { + mk = 10000 + } + return &tokenBucketMap{ + buckets: make(map[string]*bucket), + rate: r, + burst: b, + maxKeys: mk, + now: time.Now, + } +} + +// reconfigure updates rate/burst/maxKeys in place. 
Existing buckets +// carry their current token count forward so hot keys don't get a +// surprise refill on reload. +func (m *tokenBucketMap) reconfigure(cfg TierConfig) { + m.mu.Lock() + defer m.mu.Unlock() + if cfg.Rate > 0 { + m.rate = cfg.Rate + } + if cfg.Burst > 0 { + m.burst = float64(cfg.Burst) + // Clamp any over-full buckets to the new cap. + for _, b := range m.buckets { + if b.tokens > m.burst { + b.tokens = m.burst + } + } + } + if cfg.MaxKeys > 0 { + m.maxKeys = cfg.MaxKeys + } +} + +// allow refills before deducting so burst is honored after idle periods. +func (m *tokenBucketMap) allow(key string) Decision { + m.mu.Lock() + defer m.mu.Unlock() + + now := m.now() + b, ok := m.buckets[key] + if !ok { + b = &bucket{tokens: m.burst, lastRefill: now} + m.buckets[key] = b + m.evictIfNeededLocked(now) + } else { + elapsed := now.Sub(b.lastRefill).Seconds() + if elapsed > 0 { + b.tokens = math.Min(m.burst, b.tokens+elapsed*m.rate) + b.lastRefill = now + } + } + + if b.tokens >= 1 { + b.tokens-- + return AllowOK(int(b.tokens), int(m.burst)) + } + + // Wait for one token to refill. + need := 1 - b.tokens + wait := time.Duration(need / m.rate * float64(time.Second)) + if wait < time.Second { + wait = time.Second + } + return Deny("rate", wait, int(m.burst)) +} + +// evictIfNeededLocked is called under m.mu. Prefers to drop buckets +// whose tokens are near full (idle keys — zero fairness impact). If +// that isn't enough (every bucket is hot), falls back to evicting the +// oldest-by-lastRefill entry so the map stays bounded even under +// adversarial distinct-key traffic. +func (m *tokenBucketMap) evictIfNeededLocked(_ time.Time) { + if m.maxKeys <= 0 || len(m.buckets) <= m.maxKeys { + return + } + for k, b := range m.buckets { + if b.tokens >= m.burst-0.0001 { + delete(m.buckets, k) + } + if len(m.buckets) <= m.maxKeys { + return + } + } + // Fallback: evict oldest by lastRefill until within cap. 
+ for len(m.buckets) > m.maxKeys { + var oldestKey string + var oldestTime time.Time + first := true + for k, b := range m.buckets { + if first || b.lastRefill.Before(oldestTime) { + oldestKey = k + oldestTime = b.lastRefill + first = false + } + } + if oldestKey == "" { + return + } + delete(m.buckets, oldestKey) + } +} diff --git a/internal/ratelimit/config.go b/internal/ratelimit/config.go new file mode 100644 index 0000000..ee8e91a --- /dev/null +++ b/internal/ratelimit/config.go @@ -0,0 +1,229 @@ +package ratelimit + +import ( + "fmt" + "os" + "strconv" + "strings" + "time" +) + +// Profile is a named bundle of tier defaults. +type Profile string + +const ( + ProfileDefault Profile = "default" + ProfileStrict Profile = "strict" + ProfileLoose Profile = "loose" + ProfileOff Profile = "off" +) + +// TierConfig holds the tunables for a single tier. Zero values mean +// "unset"; resolution layers fill them in. +type TierConfig struct { + Algorithm Algorithm + Rate float64 // tokens/sec (bucket) or events/window (sliding) + Burst int // bucket depth + Window time.Duration // sliding window + Max int // sliding max events per window + Concurrency int // semaphore cap + MaxKeys int // per-limiter map cap (LRU evict) +} + +// Config is the fully-resolved per-tier configuration for a Registry. +// Off means "allow everything" — used when profile=off or for operators +// fronting with their own edge rate limiter. +type Config struct { + Profile Profile + Off bool + Locked bool // true when AGENT_VAULT_RATELIMIT_LOCK=true + Tiers [tierCount]TierConfig +} + +// DefaultsFor returns a Config populated with the requested profile's +// baselines. +func DefaultsFor(profile Profile) Config { + var c Config + c.Profile = profile + if profile == ProfileOff { + c.Off = true + return c + } + mul := profileMultiplier(profile) + // Unauthenticated surface: login/register/forgot/reset/verify, + // OAuth, invite/approval-token redemption. 
+ c.Tiers[TierAuth] = TierConfig{ + Algorithm: AlgSliding, Window: 5 * time.Minute, Max: scaleMax(10, mul), MaxKeys: 10000, + } + // /proxy/* + MITM: token bucket smooths traffic; Concurrency caps + // in-flight slow upstream calls per (actor, vault). + c.Tiers[TierProxy] = TierConfig{ + Algorithm: AlgTokenBucket, Rate: scaleRate(2.0, mul), Burst: scaleMax(30, mul), + Concurrency: scaleMax(16, mul), MaxKeys: 10000, + } + // Everything behind requireAuth — generous; the heaviest legitimate + // agent workload is 50+ discover+CRUD calls/minute. + c.Tiers[TierAuthed] = TierConfig{ + Algorithm: AlgTokenBucket, Rate: scaleRate(5.0, mul), Burst: scaleMax(120, mul), MaxKeys: 10000, + } + // Server-wide backstop. Rate/Burst drive the RPS bucket; Concurrency + // drives the in-flight semaphore. + c.Tiers[TierGlobal] = TierConfig{ + Rate: float64(scaleMax(2000, mul)), Burst: scaleMax(4000, mul), + Concurrency: scaleMax(512, mul), + } + // Internal: failure counter for verification codes. + c.Tiers[TierVerifyFailure] = TierConfig{ + Algorithm: AlgFailureCounter, Max: scaleMax(10, mul), MaxKeys: 10000, + } + return c +} + +func profileMultiplier(p Profile) float64 { + switch p { + case ProfileStrict: + return 0.5 + case ProfileLoose: + return 2.0 + default: + return 1.0 + } +} + +func scaleMax(base int, mul float64) int { + v := int(float64(base) * mul) + if v < 1 { + v = 1 + } + return v +} + +func scaleRate(base, mul float64) float64 { + v := base * mul + if v <= 0 { + v = 0.1 + } + return v +} + +// EnvSetMask records which per-tier env knobs were explicitly set. +// Callers that need to preserve env precedence when merging a setting +// payload use this to avoid rescanning os.Getenv. +type EnvSetMask struct { + Rate, Burst, Window, Max, Concurrency bool +} + +// EnvMasks is the per-tier mask array returned by LoadFromEnv. Named +// so callers can hold it in a variable without referencing tierCount. 
+type EnvMasks [tierCount]EnvSetMask + +// LoadFromEnv returns a Config initialized from the environment and a +// per-tier mask marking which knobs came from env (as opposed to +// profile defaults). Precedence: AGENT_VAULT_RATELIMIT_PROFILE sets +// the baseline; per-tier AGENT_VAULT_RATELIMIT__ vars +// override individual fields. AGENT_VAULT_RATELIMIT_LOCK=true marks +// the config as operator-pinned; UI overrides are ignored by callers +// that honor the Locked flag. +func LoadFromEnv() (Config, EnvMasks) { + profile := Profile(strings.ToLower(os.Getenv("AGENT_VAULT_RATELIMIT_PROFILE"))) + if profile == "" { + profile = ProfileDefault + } + cfg := DefaultsFor(profile) + cfg.Locked = strings.EqualFold(os.Getenv("AGENT_VAULT_RATELIMIT_LOCK"), "true") + var mask EnvMasks + if cfg.Off { + return cfg, mask + } + for t := Tier(0); t < tierCount; t++ { + prefix := "AGENT_VAULT_RATELIMIT_" + t.String() + "_" + if v := os.Getenv(prefix + "RATE"); v != "" { + if f, err := strconv.ParseFloat(v, 64); err == nil { + cfg.Tiers[t].Rate = f + mask[t].Rate = true + } + } + if v := os.Getenv(prefix + "BURST"); v != "" { + if n, err := strconv.Atoi(v); err == nil { + cfg.Tiers[t].Burst = n + mask[t].Burst = true + } + } + if v := os.Getenv(prefix + "WINDOW"); v != "" { + if d, err := time.ParseDuration(v); err == nil { + cfg.Tiers[t].Window = d + mask[t].Window = true + } + } + if v := os.Getenv(prefix + "MAX"); v != "" { + if n, err := strconv.Atoi(v); err == nil { + cfg.Tiers[t].Max = n + mask[t].Max = true + } + } + if v := os.Getenv(prefix + "CONCURRENCY"); v != "" { + if n, err := strconv.Atoi(v); err == nil { + cfg.Tiers[t].Concurrency = n + mask[t].Concurrency = true + } + } + } + return cfg, mask +} + +// EnvSet reports whether any of a tier's knobs were set via env. +func (m EnvSetMask) Any() bool { + return m.Rate || m.Burst || m.Window || m.Max || m.Concurrency +} + +// ApplyOverrides merges a partial set of overrides (e.g. from the +// instance settings pane) into c. 
Env-pinned values (when c.Locked) are +// not touched; callers should gate on Locked before calling this. +func (c *Config) ApplyOverrides(overrides map[Tier]TierConfig) { + for t, ov := range overrides { + if t < 0 || t >= tierCount { + continue + } + if ov.Rate != 0 { + c.Tiers[t].Rate = ov.Rate + } + if ov.Burst != 0 { + c.Tiers[t].Burst = ov.Burst + } + if ov.Window != 0 { + c.Tiers[t].Window = ov.Window + } + if ov.Max != 0 { + c.Tiers[t].Max = ov.Max + } + if ov.Concurrency != 0 { + c.Tiers[t].Concurrency = ov.Concurrency + } + } +} + +// Validate returns a non-nil error if c would produce a server the +// operator cannot recover from (e.g. auth limits set to 0). +func (c *Config) Validate() error { + if c.Off { + return nil + } + if c.Tiers[TierAuth].Max < minAuthFloor { + return fmt.Errorf("AUTH.max below floor (%d): owner lockout risk", minAuthFloor) + } + if c.Tiers[TierGlobal].Concurrency < minGlobalInflight { + return fmt.Errorf("GLOBAL.concurrency below floor (%d)", minGlobalInflight) + } + if c.Tiers[TierGlobal].Rate < float64(minGlobalRPS) { + return fmt.Errorf("GLOBAL.rate below floor (%d)", minGlobalRPS) + } + return nil +} + +// Floors enforced by Validate. Env vars can go below them (for +// deliberate testing); UI overrides cannot. +const ( + minAuthFloor = 5 + minGlobalInflight = 32 + minGlobalRPS = 100 +) diff --git a/internal/ratelimit/failure.go b/internal/ratelimit/failure.go new file mode 100644 index 0000000..ca6c4b1 --- /dev/null +++ b/internal/ratelimit/failure.go @@ -0,0 +1,77 @@ +package ratelimit + +import "sync" + +// failureCounter tracks failed attempts per key (typically email). +// Semantics differ from a rate limiter: the counter only increments on +// recordFailure; it does not decay with time. A successful operation +// calls reset. 
Callers use check() to gate retries — when a caller +// hits the cap, the underlying credential (verification code, password +// reset code) is considered burned and the caller must request a new +// one. Thread-safe. +type failureCounter struct { + mu sync.Mutex + attempts map[string]int + max int + maxKeys int +} + +func newFailureCounter(cfg TierConfig) *failureCounter { + m := cfg.Max + if m < 1 { + m = 10 + } + mk := cfg.MaxKeys + if mk <= 0 { + mk = 10000 + } + return &failureCounter{ + attempts: make(map[string]int), + max: m, + maxKeys: mk, + } +} + +func (f *failureCounter) reconfigure(cfg TierConfig) { + f.mu.Lock() + defer f.mu.Unlock() + if cfg.Max > 0 { + f.max = cfg.Max + } + if cfg.MaxKeys > 0 { + f.maxKeys = cfg.MaxKeys + } +} + +// check returns true if key is still under the failure cap. +func (f *failureCounter) check(key string) bool { + f.mu.Lock() + defer f.mu.Unlock() + return f.attempts[key] < f.max +} + +// record increments the failure counter for key. On overflow, evicts +// an arbitrary other entry — Go's map iteration is randomized so the +// victim is unpredictable, but under a mass-enumeration attack the +// caller's ~10000-key cap means we're throttling anyway; losing an +// occasional legitimate counter just lets that email try again. +func (f *failureCounter) record(key string) { + f.mu.Lock() + defer f.mu.Unlock() + f.attempts[key]++ + if f.maxKeys > 0 && len(f.attempts) > f.maxKeys { + for k := range f.attempts { + if k != key { + delete(f.attempts, k) + break + } + } + } +} + +// reset clears the counter for key. 
+func (f *failureCounter) reset(key string) { + f.mu.Lock() + defer f.mu.Unlock() + delete(f.attempts, key) +} diff --git a/internal/ratelimit/key.go b/internal/ratelimit/key.go new file mode 100644 index 0000000..7173da1 --- /dev/null +++ b/internal/ratelimit/key.go @@ -0,0 +1,56 @@ +package ratelimit + +import ( + "crypto/sha256" + "encoding/hex" + "net/http" +) + +// Keyer extracts a stable string key from a request for a rate-limit +// bucket. Returning "" tells the middleware to skip the check — +// useful when a keyer depends on context that isn't yet populated +// (e.g., session post-auth). +type Keyer func(*http.Request) string + +// IPKey wraps a clientIP function into a Keyer. The server package +// owns the clientIP logic (AGENT_VAULT_TRUSTED_PROXIES handling) and +// passes it in, so this package has no http-header policy. +func IPKey(clientIP func(*http.Request) string) Keyer { + return func(r *http.Request) string { + return "ip:" + clientIP(r) + } +} + +// HashToken returns a base16 SHA-256 prefix of s. Used for token +// keys so raw secrets never sit in the limiter's memory or logs. +func HashToken(s string) string { + sum := sha256.Sum256([]byte(s)) + return hex.EncodeToString(sum[:8]) +} + +// IPTokenKey combines clientIP + hashed token/state from the request. +// tokenFromRequest picks the token out (path value, query param, +// body — handler's choice). Returns "" when the token is empty so +// the middleware skips the check. +func IPTokenKey(clientIP func(*http.Request) string, tokenFromRequest func(*http.Request) string) Keyer { + return func(r *http.Request) string { + tok := tokenFromRequest(r) + if tok == "" { + return "" + } + return "ipt:" + clientIP(r) + ":" + HashToken(tok) + } +} + +// ActorKey keys on a caller-provided actor identifier. The server +// package resolves session→actor once per request post-auth and +// passes a closure that reads from the request context. 
+func ActorKey(actorFromRequest func(*http.Request) string) Keyer { + return func(r *http.Request) string { + id := actorFromRequest(r) + if id == "" { + return "" + } + return "actor:" + id + } +} diff --git a/internal/ratelimit/middleware.go b/internal/ratelimit/middleware.go new file mode 100644 index 0000000..bb8fa28 --- /dev/null +++ b/internal/ratelimit/middleware.go @@ -0,0 +1,107 @@ +package ratelimit + +import ( + "encoding/json" + "log/slog" + "net/http" + "strconv" +) + +// HandlerFunc returns a per-route wrapper that applies tier to each +// request. Empty-key requests pass through (fail-open). On denial: +// 429 with Retry-After + X-RateLimit-* headers; logs WARN with the +// already-scoped key. +// +// For the proxy path, handlers call Registry.EnforceProxy directly — +// the scope and target host aren't known until after vault resolution. +func (r *Registry) HandlerFunc(tier Tier, keyer Keyer, logger *slog.Logger) func(http.HandlerFunc) http.HandlerFunc { + return func(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, req *http.Request) { + if r.cfg.Load().Off { + next(w, req) + return + } + key := "" + if keyer != nil { + key = keyer(req) + } + if key == "" { + next(w, req) + return + } + d := r.Allow(tier, key) + if d.Allow { + writeRateLimitHeaders(w, d) + next(w, req) + return + } + if logger != nil { + logger.Warn("ratelimit deny", + "tier", tier.String(), + "key", key, + "path", req.URL.Path, + "method", req.Method, + "reason", d.Reason, + "retry_after_sec", int(d.RetryAfter.Seconds()), + ) + } + WriteDenial(w, d, "Too many requests, try again later") + } + } +} + +// GlobalMiddleware is the outermost wrapper: server-wide RPS ceiling +// + in-flight semaphore. Off short-circuits immediately so operators +// fronting with their own edge limiter pay no overhead. 
+func (r *Registry) GlobalMiddleware(logger *slog.Logger) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if r.cfg.Load().Off { + next.ServeHTTP(w, req) + return + } + if d := r.AllowGlobalRPS(); !d.Allow { + if logger != nil { + logger.Warn("ratelimit deny (global_rps)", "path", req.URL.Path, "method", req.Method) + } + WriteDenial(w, d, "Server busy, try again shortly") + return + } + release, d := r.AcquireGlobal(req.Context()) + if !d.Allow { + if logger != nil { + logger.Warn("ratelimit deny (global_inflight)", "path", req.URL.Path, "method", req.Method) + } + WriteDenial(w, d, "Server at capacity, try again shortly") + return + } + defer release() + next.ServeHTTP(w, req) + }) + } +} + +// WriteDenial emits a 429 with standard rate-limit headers. Exported +// so handlers that do in-handler limit checks (e.g. login's email +// bucket) share the same response shape as the middleware. +func WriteDenial(w http.ResponseWriter, d Decision, message string) { + writeRateLimitHeaders(w, d) + if d.RetryAfter > 0 { + w.Header().Set("Retry-After", strconv.Itoa(int(d.RetryAfter.Seconds()))) + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusTooManyRequests) + body, _ := json.Marshal(map[string]string{ + "error": "too_many_requests", + "message": message, + }) + _, _ = w.Write(body) +} + +func writeRateLimitHeaders(w http.ResponseWriter, d Decision) { + if d.Limit > 0 { + w.Header().Set("X-RateLimit-Limit", strconv.Itoa(d.Limit)) + w.Header().Set("X-RateLimit-Remaining", strconv.Itoa(d.Remaining)) + } +} + diff --git a/internal/ratelimit/ratelimit.go b/internal/ratelimit/ratelimit.go new file mode 100644 index 0000000..da02c22 --- /dev/null +++ b/internal/ratelimit/ratelimit.go @@ -0,0 +1,140 @@ +// Package ratelimit is Agent Vault's tiered rate limiter. 
It groups +// endpoints into tiers by attack surface and cost, keys each tier on +// the principal being defended (IP, hashed token, actor, scope-host), +// and exposes one Registry that owns all in-memory state. +package ratelimit + +import ( + "errors" + "fmt" + "time" +) + +// Tier identifies a rate-limiting bucket. Each tier carries its own +// Config and covers a distinct attack surface. +type Tier int + +const ( + // TierAuth covers every unauthenticated endpoint: login, register, + // forgot/reset password, email verification, OAuth login/callback, + // invite redemption, approval-token lookups. Sliding window. The + // caller picks the keyer — IPKey for IP-flood defense, IPTokenKey + // for token-enumeration, and the login handler uses both an IP + // and email key against this tier (reject if either is exhausted). + TierAuth Tier = iota + + // TierProxy is the /proxy and MITM rate limit keyed on + // (actor, vault). Token bucket for smooth traffic + a per-scope + // concurrency semaphore for slow upstream calls. + TierProxy + + // TierAuthed is the catch-all for authenticated endpoints (CRUD, + // reads, admin, expensive fan-out ops). Token bucket keyed on + // actor. Defaults accommodate the heaviest workload in the tier; + // if a specific endpoint needs tighter protection, add it inline. + TierAuthed + + // TierGlobal is the server-wide backstop: requests-per-second + // ceiling + total in-flight cap. Rate/Burst drive the RPS token + // bucket; Concurrency drives the in-flight semaphore. + TierGlobal + + // TierVerifyFailure is an internal failure counter — not a rate + // limit — for email-verification and password-reset codes. The + // counter increments on bad codes per email and resets on success; + // hitting the cap invalidates the outstanding code. Not exposed in + // the operator UI because there is nothing useful to tune. 
+ TierVerifyFailure + + tierCount +) + +// tierNames is the authoritative mapping between the stable wire +// name (also the env-var suffix) and the internal Tier constant. +// Exposed via String and TierByName so callers never duplicate it. +var tierNames = [tierCount]string{ + TierAuth: "AUTH", + TierProxy: "PROXY", + TierAuthed: "AUTHED", + TierGlobal: "GLOBAL", + TierVerifyFailure: "VERIFY_FAIL", +} + +// String returns the stable env-suffix form of a Tier, used for +// AGENT_VAULT_RATELIMIT__ variable names. +func (t Tier) String() string { + if t < 0 || int(t) >= len(tierNames) { + return fmt.Sprintf("TIER_%d", int(t)) + } + return tierNames[t] +} + +// TierByName returns the Tier matching name, or (0, false) if name is +// not recognized. Case-sensitive — callers that accept user input +// should upper-case the name first. +func TierByName(name string) (Tier, bool) { + for i, n := range tierNames { + if n == name { + return Tier(i), true + } + } + return 0, false +} + +// AllTiers returns the list of valid Tier values in declaration order. +// Use in UI enumeration and config serialization; do not rely on slice +// index matching Tier value (that's only guaranteed today because +// Tier is a simple iota). +func AllTiers() []Tier { + out := make([]Tier, 0, tierCount) + for i := 0; i < int(tierCount); i++ { + out = append(out, Tier(i)) + } + return out +} + +// Algorithm selects the backing implementation for a tier. +type Algorithm int + +const ( + // AlgSliding is the sliding-window limiter: at most N events within + // a rolling Window. Best for strict attempt caps (login floods). + AlgSliding Algorithm = iota + + // AlgTokenBucket is the smooth token-bucket: Rate refill per second + // with Burst as bucket depth. Best for sustained traffic (proxy). + AlgTokenBucket + + // AlgSemaphore is a counting semaphore: at most Concurrency + // in-flight acquirers. Used alongside another algorithm. 
+	AlgSemaphore
+
+	// AlgFailureCounter counts failures (not rate): increments on
+	// recordFailure, resets on reset. Used for verify-code invalidation.
+	AlgFailureCounter
+)
+
+// Decision is the outcome of an Allow/Acquire check. A Decision with
+// Allow=false carries a RetryAfter and Reason so the caller can emit a
+// well-formed 429.
+type Decision struct {
+	Allow      bool
+	RetryAfter time.Duration
+	Remaining  int
+	Limit      int
+	Reason     string // "rate", "concurrency", "global", or ""
+}
+
+// Deny builds a refusal decision.
+func Deny(reason string, retryAfter time.Duration, limit int) Decision {
+	return Decision{Allow: false, RetryAfter: retryAfter, Limit: limit, Reason: reason}
+}
+
+// AllowOK builds an allowance decision.
+func AllowOK(remaining, limit int) Decision {
+	return Decision{Allow: true, Remaining: remaining, Limit: limit}
+}
+
+// ErrLocked is returned by settings mutations when
+// AGENT_VAULT_RATELIMIT_LOCK pins the config.
+var ErrLocked = errors.New("rate-limit config is pinned by operator env var")
diff --git a/internal/ratelimit/ratelimit_test.go b/internal/ratelimit/ratelimit_test.go
new file mode 100644
index 0000000..4174819
--- /dev/null
+++ b/internal/ratelimit/ratelimit_test.go
@@ -0,0 +1,273 @@
+package ratelimit
+
+import (
+	"context"
+	"log/slog"
+	"net/http"
+	"net/http/httptest"
+	"sync"
+	"testing"
+	"time"
+)
+
+func testRegistry(t *testing.T) *Registry {
+	t.Helper()
+	return New(DefaultsFor(ProfileDefault))
+}
+
+func TestSlidingWindowAllowThenDeny(t *testing.T) {
+	r := testRegistry(t)
+	// Drive TierAuth to its ceiling (10 at default).
+ var lastAllowed bool + for i := 0; i < 10; i++ { + d := r.Allow(TierAuth, "ip:1.2.3.4") + lastAllowed = d.Allow + if !d.Allow { + t.Fatalf("attempt %d unexpectedly denied (remaining=%d)", i+1, d.Remaining) + } + } + if !lastAllowed { + t.Fatalf("final allowed attempt had Allow=false") + } + d := r.Allow(TierAuth, "ip:1.2.3.4") + if d.Allow { + t.Fatalf("11th attempt should be denied") + } + if d.RetryAfter <= 0 { + t.Fatalf("retry-after should be positive on denial, got %v", d.RetryAfter) + } + if d.Reason != "rate" { + t.Fatalf("reason=%q want %q", d.Reason, "rate") + } + // Different key is independent. + if d := r.Allow(TierAuth, "ip:5.6.7.8"); !d.Allow { + t.Fatalf("different key should be allowed independently") + } +} + +func TestSlidingWindowEvictionCap(t *testing.T) { + // Tight map cap + tight window so eviction fires on an old entry. + cfg := DefaultsFor(ProfileDefault) + cfg.Tiers[TierAuth].MaxKeys = 4 + cfg.Tiers[TierAuth].Window = 50 * time.Millisecond + r := New(cfg) + // Seed 3 keys, then let them go cold. + for i := 0; i < 3; i++ { + r.Allow(TierAuth, string(rune('a'+i))) + } + time.Sleep(60 * time.Millisecond) + // Seed 5 more (past cap); eviction should purge cold entries. + for i := 0; i < 5; i++ { + r.Allow(TierAuth, string(rune('A'+i))) + } + if sz := r.sliding[TierAuth].size(); sz > cfg.Tiers[TierAuth].MaxKeys+2 { + t.Fatalf("map grew past eviction threshold: size=%d cap=%d", sz, cfg.Tiers[TierAuth].MaxKeys) + } +} + +func TestTokenBucketBurstAndRefill(t *testing.T) { + cfg := DefaultsFor(ProfileDefault) + // Force a tiny, predictable bucket. 
+ cfg.Tiers[TierAuthed] = TierConfig{ + Algorithm: AlgTokenBucket, Rate: 1000, Burst: 3, MaxKeys: 100, + } + r := New(cfg) + for i := 0; i < 3; i++ { + if d := r.Allow(TierAuthed, "actor:x"); !d.Allow { + t.Fatalf("burst slot %d denied", i+1) + } + } + if d := r.Allow(TierAuthed, "actor:x"); d.Allow { + t.Fatalf("over-burst should deny") + } + time.Sleep(20 * time.Millisecond) // 1000/sec × 20ms = 20 tokens refilled + if d := r.Allow(TierAuthed, "actor:x"); !d.Allow { + t.Fatalf("refill did not restore bucket") + } +} + +func TestSemaphoreCapContextCancel(t *testing.T) { + cfg := DefaultsFor(ProfileDefault) + cfg.Tiers[TierProxy].Concurrency = 2 + r := New(cfg) + rel1, d1 := r.acquireKeyed(context.Background(), TierProxy, "scope:a:b") + rel2, d2 := r.acquireKeyed(context.Background(), TierProxy, "scope:a:b") + if !d1.Allow || !d2.Allow { + t.Fatalf("first two acquires should succeed") + } + // Third acquire on the same key blocks then denies when ctx is done. + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) + defer cancel() + _, d3 := r.acquireKeyed(ctx, TierProxy, "scope:a:b") + if d3.Allow { + t.Fatalf("over-cap acquire should deny") + } + // Release one, a fresh acquire should now succeed. 
+ rel1() + rel4, d4 := r.acquireKeyed(context.Background(), TierProxy, "scope:a:b") + if !d4.Allow { + t.Fatalf("acquire after release should succeed") + } + rel4() + rel2() +} + +func TestFailureCounterCheckRecordReset(t *testing.T) { + cfg := DefaultsFor(ProfileDefault) + cfg.Tiers[TierVerifyFailure].Max = 3 + r := New(cfg) + key := "email:alice@example.com" + for i := 0; i < 3; i++ { + if !r.FailureCheck(TierVerifyFailure, key) { + t.Fatalf("check should return true before exhaustion (i=%d)", i) + } + r.FailureRecord(TierVerifyFailure, key) + } + if r.FailureCheck(TierVerifyFailure, key) { + t.Fatalf("check should return false after %d failures", 3) + } + r.FailureReset(TierVerifyFailure, key) + if !r.FailureCheck(TierVerifyFailure, key) { + t.Fatalf("check should return true after reset") + } +} + +func TestOffConfigShortCircuits(t *testing.T) { + cfg := DefaultsFor(ProfileOff) + if !cfg.Off { + t.Fatalf("ProfileOff should set Off=true") + } + r := New(cfg) + // 1000 calls with the same key must all be allowed. 
+ for i := 0; i < 1000; i++ { + if d := r.Allow(TierAuth, "k"); !d.Allow { + t.Fatalf("off config denied at attempt %d", i) + } + } +} + +func TestMiddlewareEmptyKeySkips(t *testing.T) { + r := testRegistry(t) + logger := slog.New(slog.DiscardHandler) + wrap := r.HandlerFunc(TierAuth, func(*http.Request) string { return "" }, logger) + served := 0 + h := wrap(func(w http.ResponseWriter, _ *http.Request) { served++ }) + for i := 0; i < 100; i++ { + req := httptest.NewRequest(http.MethodGet, "/", nil) + w := httptest.NewRecorder() + h(w, req) + if w.Code != 200 { + t.Fatalf("empty key should fail open, got %d", w.Code) + } + } + if served != 100 { + t.Fatalf("expected 100 passes, got %d", served) + } +} + +func TestMiddlewareDeniesAfterCap(t *testing.T) { + cfg := DefaultsFor(ProfileDefault) + cfg.Tiers[TierAuth].Max = 3 + r := New(cfg) + logger := slog.New(slog.DiscardHandler) + wrap := r.HandlerFunc(TierAuth, func(*http.Request) string { return "ip:1" }, logger) + h := wrap(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(200) }) + var denials int + for i := 0; i < 10; i++ { + req := httptest.NewRequest(http.MethodGet, "/", nil) + w := httptest.NewRecorder() + h(w, req) + if w.Code == http.StatusTooManyRequests { + denials++ + if w.Header().Get("Retry-After") == "" { + t.Fatalf("denied response must set Retry-After") + } + if w.Header().Get("X-RateLimit-Limit") == "" { + t.Fatalf("denied response must set X-RateLimit-Limit") + } + } + } + if denials < 1 { + t.Fatalf("expected at least one 429, got 0") + } +} + +func TestReloadUpdatesInPlace(t *testing.T) { + cfg := DefaultsFor(ProfileDefault) + cfg.Tiers[TierAuth].Max = 2 + r := New(cfg) + for i := 0; i < 2; i++ { + if d := r.Allow(TierAuth, "k"); !d.Allow { + t.Fatalf("attempt %d denied before reload", i+1) + } + } + if d := r.Allow(TierAuth, "k"); d.Allow { + t.Fatalf("over-cap attempt allowed before reload") + } + // Raise the cap — same key should now get more allowance after + // the previous 
history ages out (we keep history, so this tests + // that window matters more than cap raise). + cfg2 := cfg + cfg2.Tiers[TierAuth].Max = 10 + r.Reload(cfg2) + // Use a fresh key to confirm the new cap applies. + for i := 0; i < 10; i++ { + if d := r.Allow(TierAuth, "fresh"); !d.Allow { + t.Fatalf("post-reload attempt %d denied (new cap should allow 10)", i+1) + } + } +} + +func TestGlobalMiddlewareShortCircuitsWhenOff(t *testing.T) { + r := New(DefaultsFor(ProfileOff)) + logger := slog.New(slog.DiscardHandler) + gm := r.GlobalMiddleware(logger) + h := gm(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(204) })) + for i := 0; i < 50; i++ { + req := httptest.NewRequest(http.MethodGet, "/x", nil) + w := httptest.NewRecorder() + h.ServeHTTP(w, req) + if w.Code != 204 { + t.Fatalf("off config denied a request: %d", w.Code) + } + } +} + +func TestValidateFloors(t *testing.T) { + cfg := DefaultsFor(ProfileDefault) + cfg.Tiers[TierAuth].Max = 1 // below floor 5 + if err := cfg.Validate(); err == nil { + t.Fatalf("expected validation error for AUTH below floor") + } + cfg = DefaultsFor(ProfileDefault) + cfg.Tiers[TierGlobal].Concurrency = 10 // below 32 + if err := cfg.Validate(); err == nil { + t.Fatalf("expected validation error for global-inflight below floor") + } +} + +func TestConcurrentUsageRace(t *testing.T) { + // Smoke test: concurrent callers shouldn't panic or corrupt the map. 
+ r := testRegistry(t) + var wg sync.WaitGroup + for g := 0; g < 16; g++ { + wg.Add(1) + go func(g int) { + defer wg.Done() + for i := 0; i < 200; i++ { + _ = r.Allow(TierAuthed, "actor:x") + _, rel := pairRelease(r.acquireKeyed(context.Background(), TierProxy, "s")) + rel() + _ = g + } + }(g) + } + wg.Wait() +} + +func pairRelease(rel func(), d Decision) (Decision, func()) { + if rel == nil { + rel = func() {} + } + return d, rel +} diff --git a/internal/ratelimit/registry.go b/internal/ratelimit/registry.go new file mode 100644 index 0000000..960dcf5 --- /dev/null +++ b/internal/ratelimit/registry.go @@ -0,0 +1,252 @@ +package ratelimit + +import ( + "context" + "sync" + "sync/atomic" + "time" +) + +// Registry owns all per-tier limiters in a single process. It is safe +// for concurrent use. Call New() once at server construction; callers +// hold on to it and pass it into middleware and handlers that need to +// perform explicit allow/acquire calls (like login's email bucket). +type Registry struct { + cfg atomic.Pointer[Config] + + mu sync.RWMutex + // Per-tier state. Only the field matching the tier's algorithm is + // non-nil. Reload swaps fields atomically by taking the write lock. + sliding [tierCount]*slidingWindow + buckets [tierCount]*tokenBucketMap + kSemaphores [tierCount]*keyedSemaphore + failures [tierCount]*failureCounter + globalSem *semaphore + globalRate *tokenBucketMap +} + +// New constructs a Registry from cfg. Call Reload later to pick up +// updated settings without rebuilding the registry. +func New(cfg Config) *Registry { + r := &Registry{} + r.cfg.Store(&cfg) + r.build(cfg) + return r +} + +// Config returns a copy of the current effective config. +func (r *Registry) Config() Config { return *r.cfg.Load() } + +// Reload replaces the registry's config. Existing per-key state is +// preserved where possible (sliding histories, bucket token counts); +// tier capacities and rates update in place. 
The global semaphore is +// rebuilt if the capacity changes — in-flight holders will release to +// the old semaphore and new requests acquire on the new one. +func (r *Registry) Reload(newCfg Config) { + old := r.cfg.Load() + r.cfg.Store(&newCfg) + r.mu.Lock() + defer r.mu.Unlock() + for t := Tier(0); t < tierCount; t++ { + if r.sliding[t] != nil { + r.sliding[t].reconfigure(newCfg.Tiers[t]) + } + if r.buckets[t] != nil { + r.buckets[t].reconfigure(newCfg.Tiers[t]) + } + if r.kSemaphores[t] != nil { + r.kSemaphores[t].reconfigure(newCfg.Tiers[t]) + } else if t == TierProxy && newCfg.Tiers[t].Concurrency > 0 { + // Boot config had Concurrency=0 so no sem was allocated; + // a UI / env change re-enabled it — build it now. + r.kSemaphores[t] = newKeyedSemaphore(newCfg.Tiers[t], 2*time.Second) + } + if r.failures[t] != nil { + r.failures[t].reconfigure(newCfg.Tiers[t]) + } + } + // Rebuild global sem if capacity changed. + if old == nil || newCfg.Tiers[TierGlobal].Concurrency != old.Tiers[TierGlobal].Concurrency { + r.globalSem = newSemaphore(newCfg.Tiers[TierGlobal].Concurrency, 500*time.Millisecond) + } + if r.globalRate != nil { + r.globalRate.reconfigure(TierConfig{ + Rate: newCfg.Tiers[TierGlobal].Rate, + Burst: newCfg.Tiers[TierGlobal].Burst, + }) + } +} + +// build initializes the per-tier structures for cfg. Called once from +// New; Reload updates in place. +func (r *Registry) build(cfg Config) { + if cfg.Off { + return + } + for t := Tier(0); t < tierCount; t++ { + // TierGlobal is handled below — it carries both a bucket and + // a semaphore, neither of which matches a single Algorithm. 
+ if t == TierGlobal { + continue + } + tc := cfg.Tiers[t] + switch tc.Algorithm { + case AlgSliding: + r.sliding[t] = newSlidingWindow(tc) + case AlgTokenBucket: + r.buckets[t] = newTokenBucketMap(tc) + case AlgFailureCounter: + r.failures[t] = newFailureCounter(tc) + } + } + // TierProxy layers a per-key concurrency semaphore on top of its + // token bucket; the bucket smooths sustained traffic, the sem + // bounds in-flight slow upstream calls. + if cfg.Tiers[TierProxy].Concurrency > 0 { + r.kSemaphores[TierProxy] = newKeyedSemaphore(cfg.Tiers[TierProxy], 2*time.Second) + } + // TierGlobal is two primitives in one tier: an RPS token bucket + // keyed on a constant, and an in-flight semaphore. + r.globalSem = newSemaphore(cfg.Tiers[TierGlobal].Concurrency, 500*time.Millisecond) + r.globalRate = newTokenBucketMap(TierConfig{ + Rate: cfg.Tiers[TierGlobal].Rate, + Burst: cfg.Tiers[TierGlobal].Burst, + MaxKeys: 1, + }) +} + +// Allow records one event against tier for key and returns a Decision. +// Off configs return AllowOK. Unknown/unset tiers fail open (AllowOK). +// Empty key fails open too (the middleware treats "" as "skip"). +func (r *Registry) Allow(tier Tier, key string) Decision { + if r.cfg.Load().Off || key == "" { + return AllowOK(0, 0) + } + r.mu.RLock() + defer r.mu.RUnlock() + if sw := r.sliding[tier]; sw != nil { + return sw.allow(key) + } + if tb := r.buckets[tier]; tb != nil { + return tb.allow(key) + } + return AllowOK(0, 0) +} + +// AllowGlobalRPS is the single-bucket rate check applied to every +// request. Keyed on a constant ("global") inside the bucket map so the +// tokenBucketMap eviction machinery is reused without special-casing. +func (r *Registry) AllowGlobalRPS() Decision { + if r.cfg.Load().Off { + return AllowOK(0, 0) + } + r.mu.RLock() + defer r.mu.RUnlock() + if r.globalRate == nil { + return AllowOK(0, 0) + } + return r.globalRate.allow("global") +} + +// AcquireGlobal takes a slot from the server-wide in-flight semaphore. 
+// Returns nil release + Deny when the cap is hit. +func (r *Registry) AcquireGlobal(ctx context.Context) (func(), Decision) { + if r.cfg.Load().Off { + return func() {}, AllowOK(0, 0) + } + r.mu.RLock() + sem := r.globalSem + r.mu.RUnlock() + if sem == nil { + return func() {}, AllowOK(0, 0) + } + return sem.acquire(ctx) +} + +// ProxyEnforcement is the outcome of Registry.EnforceProxy. On denial, +// Release is nil and the caller emits a 429 using Decision + ErrCode +// + Message. On allow, defer Release() to free the concurrency slot. +type ProxyEnforcement struct { + Allowed bool + Release func() + Decision Decision + ErrCode string // "rate_limit_scope" | "concurrency_scope" + Message string +} + +// EnforceProxy runs the two TierProxy checks (token bucket first, +// concurrency semaphore second) for one proxy request. Shared between +// /proxy/* and the MITM forward handler so limits apply uniformly +// regardless of ingress. +func (r *Registry) EnforceProxy(ctx context.Context, actorID, vaultID string) ProxyEnforcement { + if r == nil || r.cfg.Load().Off || actorID == "" || vaultID == "" { + return ProxyEnforcement{Allowed: true, Release: func() {}} + } + scopeKey := "scope:" + actorID + ":" + vaultID + if d := r.Allow(TierProxy, scopeKey); !d.Allow { + return ProxyEnforcement{Decision: d, ErrCode: "rate_limit_scope", Message: "Proxy rate limit exceeded for this scope"} + } + release, d := r.acquireKeyed(ctx, TierProxy, scopeKey) + if !d.Allow { + return ProxyEnforcement{Decision: d, ErrCode: "concurrency_scope", Message: "Proxy at concurrency limit for this scope"} + } + return ProxyEnforcement{Allowed: true, Release: release} +} + +// acquireKeyed takes a slot on a per-key semaphore for tier. 
+func (r *Registry) acquireKeyed(ctx context.Context, tier Tier, key string) (func(), Decision) {
+	if r.cfg.Load().Off || key == "" {
+		return func() {}, AllowOK(0, 0)
+	}
+	r.mu.RLock()
+	k := r.kSemaphores[tier]
+	r.mu.RUnlock()
+	if k == nil {
+		return func() {}, AllowOK(0, 0)
+	}
+	return k.acquire(ctx, key)
+}
+
+// FailureCheck reports whether key is still allowed to attempt. Used
+// for verification-code invalidation (not a rate limit — a failure
+// budget).
+func (r *Registry) FailureCheck(tier Tier, key string) bool {
+	if r.cfg.Load().Off || key == "" {
+		return true
+	}
+	r.mu.RLock()
+	defer r.mu.RUnlock()
+	fc := r.failures[tier]
+	if fc == nil {
+		return true
+	}
+	return fc.check(key)
+}
+
+// FailureRecord increments the failure counter for key.
+func (r *Registry) FailureRecord(tier Tier, key string) {
+	if r.cfg.Load().Off || key == "" {
+		return
+	}
+	r.mu.RLock()
+	fc := r.failures[tier]
+	r.mu.RUnlock()
+	if fc != nil {
+		fc.record(key)
+	}
+}
+
+// FailureReset clears the counter for key (called on successful
+// verification).
+func (r *Registry) FailureReset(tier Tier, key string) {
+	if key == "" {
+		return
+	}
+	r.mu.RLock()
+	fc := r.failures[tier]
+	r.mu.RUnlock()
+	if fc != nil {
+		fc.reset(key)
+	}
+}
+
diff --git a/internal/ratelimit/semaphore.go b/internal/ratelimit/semaphore.go
new file mode 100644
index 0000000..056fa9d
--- /dev/null
+++ b/internal/ratelimit/semaphore.go
@@ -0,0 +1,160 @@
+package ratelimit
+
+import (
+	"context"
+	"sync"
+	"sync/atomic"
+	"time"
+)
+
+// semaphore is a counting semaphore with context-aware acquire and a
+// short wait budget. On a granted acquire, callers must invoke the
+// returned release exactly once when they no longer hold the slot; on
+// a denial the returned release is nil and must not be called.
+type semaphore struct { + slots chan struct{} + held atomic.Int64 + timeout time.Duration +} + +func newSemaphore(capacity int, waitBudget time.Duration) *semaphore { + if capacity < 1 { + capacity = 1 + } + if waitBudget <= 0 { + waitBudget = 2 * time.Second + } + return &semaphore{ + slots: make(chan struct{}, capacity), + timeout: waitBudget, + } +} + +// acquire attempts to take a slot. It tries non-blocking first, then +// waits up to the configured budget (bounded further by ctx). Returns +// (release, Decision). release is non-nil only when Allow is true. +func (s *semaphore) acquire(ctx context.Context) (func(), Decision) { + // Fast path: slot available immediately. + select { + case s.slots <- struct{}{}: + s.held.Add(1) + return s.releaseOnce(), AllowOK(cap(s.slots)-len(s.slots), cap(s.slots)) + default: + } + + // Wait up to the budget or until ctx is done. + budget := s.timeout + timer := time.NewTimer(budget) + defer timer.Stop() + select { + case s.slots <- struct{}{}: + s.held.Add(1) + return s.releaseOnce(), AllowOK(cap(s.slots)-len(s.slots), cap(s.slots)) + case <-ctx.Done(): + return nil, Deny("concurrency", time.Second, cap(s.slots)) + case <-timer.C: + return nil, Deny("concurrency", time.Second, cap(s.slots)) + } +} + +func (s *semaphore) releaseOnce() func() { + var once sync.Once + return func() { + once.Do(func() { + <-s.slots + s.held.Add(-1) + }) + } +} + +// held returns the current number of in-flight holders (for gauges). +func (s *semaphore) holders() int64 { return s.held.Load() } + +// keyedSemaphore is a map of per-key semaphores, each with the same +// capacity. Used for TierProxy per-scope concurrency caps. maxKeys evicts +// idle (fully-released) semaphores to bound memory. 
+type keyedSemaphore struct { + mu sync.Mutex + sems map[string]*semaphore + capacity int + waitBudget time.Duration + maxKeys int +} + +func newKeyedSemaphore(cfg TierConfig, waitBudget time.Duration) *keyedSemaphore { + c := cfg.Concurrency + if c < 1 { + c = 1 + } + mk := cfg.MaxKeys + if mk <= 0 { + mk = 10000 + } + return &keyedSemaphore{ + sems: make(map[string]*semaphore), + capacity: c, + waitBudget: waitBudget, + maxKeys: mk, + } +} + +// reconfigure updates the per-key capacity. Existing semaphores keep +// their old capacity until the last holder releases (at which point +// they may be evicted); new keys get the new cap. +func (k *keyedSemaphore) reconfigure(cfg TierConfig) { + k.mu.Lock() + defer k.mu.Unlock() + if cfg.Concurrency > 0 { + k.capacity = cfg.Concurrency + } + if cfg.MaxKeys > 0 { + k.maxKeys = cfg.MaxKeys + } +} + +// acquire returns a release function and decision for a specific key. +func (k *keyedSemaphore) acquire(ctx context.Context, key string) (func(), Decision) { + k.mu.Lock() + sem, ok := k.sems[key] + if !ok { + sem = newSemaphore(k.capacity, k.waitBudget) + k.sems[key] = sem + k.evictLocked(key) + } + k.mu.Unlock() + return sem.acquire(ctx) +} + +// evictLocked keeps the sems map bounded. Preferred: drop entries +// with no holders (cheap, fair). If every entry has holders, abandon +// the weakest-claim semaphore — under adversarial new-key traffic +// this loses one queued slot but bounds memory. Called under k.mu. +func (k *keyedSemaphore) evictLocked(skipKey string) { + if k.maxKeys <= 0 || len(k.sems) <= k.maxKeys { + return + } + for kk, ss := range k.sems { + if kk == skipKey { + continue + } + if ss.holders() == 0 { + delete(k.sems, kk) + if len(k.sems) <= k.maxKeys { + return + } + } + } + // Hard fallback: drop an arbitrary non-skip entry. 
New requests + // on that key rebuild a fresh semaphore on next acquire; in-flight + // holders still release on the (now-orphaned) sem — safe because + // each holder captured the sem pointer at acquire time. + for kk := range k.sems { + if kk == skipKey { + continue + } + delete(k.sems, kk) + if len(k.sems) <= k.maxKeys { + return + } + } +} diff --git a/internal/ratelimit/sliding.go b/internal/ratelimit/sliding.go new file mode 100644 index 0000000..6a9692e --- /dev/null +++ b/internal/ratelimit/sliding.go @@ -0,0 +1,114 @@ +package ratelimit + +import ( + "sync" + "time" +) + +// slidingWindow is a generic sliding-window limiter keyed by string. +// It tracks timestamps of recent events per key, drops entries older +// than window on every check, and evicts cold keys when the map grows +// past maxKeys. Thread-safe. +// +// The struct fields are mutable at runtime so Registry.Reload can +// change window/max without rebuilding the bucket map; in-flight keys +// pick up the new limits on their next call. +type slidingWindow struct { + mu sync.Mutex + attempts map[string][]time.Time + window time.Duration + max int + maxKeys int + now func() time.Time // injectable for tests +} + +func newSlidingWindow(cfg TierConfig) *slidingWindow { + w := cfg.Window + if w <= 0 { + w = 5 * time.Minute + } + m := cfg.Max + if m < 1 { + m = 1 + } + mk := cfg.MaxKeys + if mk <= 0 { + mk = 10000 + } + return &slidingWindow{ + attempts: make(map[string][]time.Time), + window: w, + max: m, + maxKeys: mk, + now: time.Now, + } +} + +// reconfigure updates the window/max/maxKeys in place. Existing per-key +// history is preserved; future checks use the new parameters. +func (l *slidingWindow) reconfigure(cfg TierConfig) { + l.mu.Lock() + defer l.mu.Unlock() + if cfg.Window > 0 { + l.window = cfg.Window + } + if cfg.Max > 0 { + l.max = cfg.Max + } + if cfg.MaxKeys > 0 { + l.maxKeys = cfg.MaxKeys + } +} + +// allow records an attempt for key and returns a Decision. 
The earliest +// event in the window is used to compute RetryAfter on denial so the +// response's Retry-After header is accurate. +func (l *slidingWindow) allow(key string) Decision { + l.mu.Lock() + defer l.mu.Unlock() + + now := l.now() + cutoff := now.Add(-l.window) + + recent := l.attempts[key][:0] + for _, t := range l.attempts[key] { + if t.After(cutoff) { + recent = append(recent, t) + } + } + l.attempts[key] = recent + + if len(recent) >= l.max { + // retry-after = time until earliest event falls out of the window. + var wait time.Duration + if len(recent) > 0 { + wait = recent[0].Add(l.window).Sub(now) + } + if wait < time.Second { + wait = time.Second + } + return Deny("rate", wait, l.max) + } + + l.attempts[key] = append(l.attempts[key], now) + + // Evict cold keys to prevent unbounded growth. Walk once looking + // for entries whose most-recent attempt is already outside the + // window — cheap and correct, matches the legacy limiter. + if l.maxKeys > 0 && len(l.attempts) > l.maxKeys { + for k, v := range l.attempts { + if len(v) == 0 || v[len(v)-1].Before(cutoff) { + delete(l.attempts, k) + } + } + } + + return AllowOK(l.max-len(recent)-1, l.max) +} + +// size returns the number of tracked keys (for gauges/tests). +func (l *slidingWindow) size() int { + l.mu.Lock() + defer l.mu.Unlock() + return len(l.attempts) +} diff --git a/internal/server/handle_auth.go b/internal/server/handle_auth.go index 2b463dc..26e961f 100644 --- a/internal/server/handle_auth.go +++ b/internal/server/handle_auth.go @@ -12,11 +12,11 @@ import ( "net/http" "os" "strings" - "sync" "time" "github.com/Infisical/agent-vault/internal/auth" "github.com/Infisical/agent-vault/internal/crypto" + "github.com/Infisical/agent-vault/internal/ratelimit" "github.com/Infisical/agent-vault/internal/store" ) @@ -24,46 +24,6 @@ const emailVerificationTTL = 15 * time.Minute const maxPendingVerifications = 3 -// verifyRateLimiter tracks failed verification attempts per email. 
-type verifyRateLimiter struct { - mu sync.Mutex - attempts map[string]int - maxKeys int -} - -const maxVerifyAttempts = 10 // max failed attempts per email before code is invalidated - -const maxVerifyKeys = 10000 // max tracked emails to prevent unbounded map growth - -var verifyLimiter = &verifyRateLimiter{attempts: make(map[string]int), maxKeys: maxVerifyKeys} - -func (l *verifyRateLimiter) check(email string) bool { - l.mu.Lock() - defer l.mu.Unlock() - return l.attempts[email] < maxVerifyAttempts -} - -func (l *verifyRateLimiter) recordFailure(email string) { - l.mu.Lock() - defer l.mu.Unlock() - l.attempts[email]++ - // Evict entries if map grows too large (DoS protection). - if len(l.attempts) > l.maxKeys { - for k := range l.attempts { - if k != email { - delete(l.attempts, k) - break - } - } - } -} - -func (l *verifyRateLimiter) reset(email string) { - l.mu.Lock() - defer l.mu.Unlock() - delete(l.attempts, email) -} - // generateAndSendVerificationCode creates a new 6-digit verification code for // the given email and sends it via email (or logs to stderr if SMTP is not configured). func (s *Server) generateAndSendVerificationCode(ctx context.Context, email string) (bool, error) { @@ -120,13 +80,6 @@ func (s *Server) handleRegister(w http.ResponseWriter, r *http.Request) { return } - // Rate limit registrations by IP to prevent account creation floods. - ip := clientIP(r) - if !registerLimiter.allow(ip) { - jsonError(w, http.StatusTooManyRequests, "Too many registration attempts, try again later") - return - } - ctx := r.Context() // Check domain and invite-only restrictions (skip for first user — owner can set any email). @@ -280,14 +233,15 @@ func (s *Server) handleVerify(w http.ResponseWriter, r *http.Request) { ctx := r.Context() // Rate limit verification attempts per email. 
- if !verifyLimiter.check(req.Email) { + verifyKey := "v:" + req.Email + if !s.rateLimit.FailureCheck(ratelimit.TierVerifyFailure, verifyKey) { jsonError(w, http.StatusTooManyRequests, "Too many failed verification attempts; request a new code") return } ev, err := s.store.GetPendingEmailVerification(ctx, req.Email, req.Code) if err != nil || ev == nil { - verifyLimiter.recordFailure(req.Email) + s.rateLimit.FailureRecord(ratelimit.TierVerifyFailure, verifyKey) jsonError(w, http.StatusBadRequest, "Invalid or expired verification code") return } @@ -314,7 +268,7 @@ func (s *Server) handleVerify(w http.ResponseWriter, r *http.Request) { } // Reset rate limit on successful verification. - verifyLimiter.reset(req.Email) + s.rateLimit.FailureReset(ratelimit.TierVerifyFailure, verifyKey) // Auto-login: create session and set cookie. session, err := s.store.CreateSession(ctx, user.ID, time.Now().Add(sessionTTL)) @@ -340,13 +294,6 @@ func (s *Server) handleResendVerification(w http.ResponseWriter, r *http.Request return } - // Rate limit by IP. - ip := clientIP(r) - if !resendVerifyLimiter.allow(ip) { - jsonError(w, http.StatusTooManyRequests, "Too many requests, try again later") - return - } - ctx := r.Context() // Uniform response to prevent email enumeration. @@ -382,13 +329,6 @@ func (s *Server) handleForgotPassword(w http.ResponseWriter, r *http.Request) { return } - // Rate limit by IP. - ip := clientIP(r) - if !forgotPasswordLimiter.allow(ip) { - jsonError(w, http.StatusTooManyRequests, "Too many requests, try again later") - return - } - ctx := r.Context() // Lazy expiration of old password resets. @@ -475,14 +415,15 @@ func (s *Server) handleResetPassword(w http.ResponseWriter, r *http.Request) { ctx := r.Context() // Rate limit verification attempts per email. 
- if !resetVerifyLimiter.check(req.Email) { + resetKey := "rp:" + req.Email + if !s.rateLimit.FailureCheck(ratelimit.TierVerifyFailure, resetKey) { jsonError(w, http.StatusTooManyRequests, "Too many failed reset attempts; request a new code") return } pr, err := s.store.GetPendingPasswordReset(ctx, req.Email, req.Code) if err != nil || pr == nil { - resetVerifyLimiter.recordFailure(req.Email) + s.rateLimit.FailureRecord(ratelimit.TierVerifyFailure, resetKey) jsonError(w, http.StatusBadRequest, "Invalid or expired reset code") return } @@ -520,7 +461,7 @@ func (s *Server) handleResetPassword(w http.ResponseWriter, r *http.Request) { _ = s.store.DeleteUserSessions(ctx, user.ID) // Reset rate limit on successful reset. - resetVerifyLimiter.reset(req.Email) + s.rateLimit.FailureReset(ratelimit.TierVerifyFailure, resetKey) // Create new session and auto-login. session, err := s.store.CreateSession(ctx, user.ID, time.Now().Add(sessionTTL)) @@ -636,15 +577,6 @@ func clientIP(r *http.Request) string { return remoteIP } -func newSlidingWindowLimiter(window time.Duration, max, maxKeys int) *slidingWindowLimiter { - return &slidingWindowLimiter{ - attempts: make(map[string][]time.Time), - window: window, - max: max, - maxKeys: maxKeys, - } -} - func init() { dummyPasswordHash, dummyPasswordSalt, dummyKDFParams, _ = auth.HashUserPassword([]byte("sb-dummy-timing-equalization")) } @@ -656,10 +588,18 @@ func (s *Server) handleLogin(w http.ResponseWriter, r *http.Request) { return } - // Rate limit by IP and by email. + // Rate limit by IP and by email (reject if either bucket is full). + // Normalize the email key so case/whitespace variations don't + // let an attacker bypass the email bucket for a legitimate user. 
ip := clientIP(r) - if !loginIPLimiter.allow(ip) || !loginEmailLimiter.allow(req.Email) { - jsonError(w, http.StatusTooManyRequests, "Too many login attempts, try again later") + ipDecision := s.rateLimit.Allow(ratelimit.TierAuth, "ip:"+ip) + emailDecision := s.rateLimit.Allow(ratelimit.TierAuth, "email:"+strings.ToLower(strings.TrimSpace(req.Email))) + if !ipDecision.Allow || !emailDecision.Allow { + d := ipDecision + if !emailDecision.Allow { + d = emailDecision + } + ratelimit.WriteDenial(w, d, "Too many login attempts, try again later") return } diff --git a/internal/server/handle_mitm_test.go b/internal/server/handle_mitm_test.go index f82769a..62fd343 100644 --- a/internal/server/handle_mitm_test.go +++ b/internal/server/handle_mitm_test.go @@ -29,7 +29,7 @@ func TestHandleMITMCA(t *testing.T) { if err != nil { t.Fatalf("ca.New: %v", err) } - p := mitm.New("127.0.0.1:0", caProv, srv.SessionResolver(), srv.CredentialProvider(), srv.BaseURL(), srv.Logger()) + p := mitm.New("127.0.0.1:0", mitm.Options{CA: caProv, Sessions: srv.SessionResolver(), Credentials: srv.CredentialProvider(), BaseURL: srv.BaseURL(), Logger: srv.Logger(), RateLimit: srv.RateLimit()}) srv.AttachMITM(p) // Start the proxy so IsListening() reports true. The handler gates @@ -100,7 +100,7 @@ func TestHandleMITMCA(t *testing.T) { } // Bind to an explicit non-default port so we can assert the // header reflects the configured Addr rather than any constant. 
- p := mitm.New("127.0.0.1:19322", caProv, srv.SessionResolver(), srv.CredentialProvider(), srv.BaseURL(), srv.Logger()) + p := mitm.New("127.0.0.1:19322", mitm.Options{CA: caProv, Sessions: srv.SessionResolver(), Credentials: srv.CredentialProvider(), BaseURL: srv.BaseURL(), Logger: srv.Logger(), RateLimit: srv.RateLimit()}) srv.AttachMITM(p) l, err := net.Listen("tcp", "127.0.0.1:0") @@ -164,7 +164,7 @@ func TestHandleMITMCA(t *testing.T) { if err != nil { t.Fatalf("ca.New: %v", err) } - p := mitm.New("127.0.0.1:0", caProv, srv.SessionResolver(), srv.CredentialProvider(), srv.BaseURL(), srv.Logger()) + p := mitm.New("127.0.0.1:0", mitm.Options{CA: caProv, Sessions: srv.SessionResolver(), Credentials: srv.CredentialProvider(), BaseURL: srv.BaseURL(), Logger: srv.Logger(), RateLimit: srv.RateLimit()}) srv.AttachMITM(p) // Intentionally do not call Serve — simulates a bind failure. diff --git a/internal/server/handle_oauth.go b/internal/server/handle_oauth.go index b8d4ca7..33314a9 100644 --- a/internal/server/handle_oauth.go +++ b/internal/server/handle_oauth.go @@ -17,8 +17,6 @@ import ( const oauthStateTTL = 10 * time.Minute -var oauthLoginLimiter = newSlidingWindowLimiter(5*time.Minute, 20, 10000) // 20 OAuth initiations per IP per 5 min - // handleOAuthProviders returns the list of enabled OAuth providers. func (s *Server) handleOAuthProviders(w http.ResponseWriter, r *http.Request) { type providerInfo struct { @@ -48,11 +46,6 @@ func (s *Server) handleOAuthLogin(w http.ResponseWriter, r *http.Request) { return } - ip := clientIP(r) - if !oauthLoginLimiter.allow(ip) { - jsonError(w, http.StatusTooManyRequests, "Too many login attempts, try again later") - return - } // Generate CSRF state, PKCE code verifier, and OIDC nonce. 
state, err := generateRandomHex(32) diff --git a/internal/server/handle_proxy.go b/internal/server/handle_proxy.go index 8f8a3cf..178820f 100644 --- a/internal/server/handle_proxy.go +++ b/internal/server/handle_proxy.go @@ -10,6 +10,7 @@ import ( "github.com/Infisical/agent-vault/internal/brokercore" "github.com/Infisical/agent-vault/internal/netguard" + "github.com/Infisical/agent-vault/internal/ratelimit" "github.com/Infisical/agent-vault/internal/store" ) @@ -156,6 +157,18 @@ func (s *Server) handleProxy(w http.ResponseWriter, r *http.Request) { return // error already written } + // Enforced post-vault-resolution; scope isn't known until here. + scope := &brokercore.ProxyScope{UserID: sess.UserID, AgentID: sess.AgentID} + enf := s.rateLimit.EnforceProxy(ctx, scope.ActorID(), ns.ID) + if !enf.Allowed { + ratelimit.WriteDenial(w, enf.Decision, enf.Message) + emit(http.StatusTooManyRequests, enf.ErrCode) + return + } + defer enf.Release() + + r.Body = http.MaxBytesReader(w, r.Body, brokercore.MaxProxyBodyBytes) + // Resolve broker service + inject credentials. inject, err := s.CredentialProvider().Inject(ctx, ns.ID, targetHost) if inject != nil { diff --git a/internal/server/handle_settings.go b/internal/server/handle_settings.go index 0e38ead..f888399 100644 --- a/internal/server/handle_settings.go +++ b/internal/server/handle_settings.go @@ -9,6 +9,8 @@ import ( "net/http" "os" "strings" + + "github.com/Infisical/agent-vault/internal/ratelimit" ) // handleEmailTest sends a test email to verify SMTP configuration. @@ -88,6 +90,12 @@ func (s *Server) writeSettingsResponse(w http.ResponseWriter, ctx context.Contex resp["invite_only"] = raw == "true" } + // Rate-limit settings: include the effective config, its source per + // tier, and the operator-pin flag so the UI can disable fields. 
+ if rl := s.buildRateLimitSettingResponse(ctx); rl != nil { + resp["rate_limit"] = rl + } + jsonOK(w, resp) } @@ -97,8 +105,9 @@ func (s *Server) handleUpdateSettings(w http.ResponseWriter, r *http.Request) { } var req struct { - AllowedEmailDomains *[]string `json:"allowed_email_domains"` - InviteOnly *bool `json:"invite_only"` + AllowedEmailDomains *[]string `json:"allowed_email_domains"` + InviteOnly *bool `json:"invite_only"` + RateLimit *rateLimitSettingPayload `json:"rate_limit"` } if err := json.NewDecoder(r.Body).Decode(&req); err != nil { jsonError(w, http.StatusBadRequest, "Invalid request body") @@ -153,6 +162,17 @@ func (s *Server) handleUpdateSettings(w http.ResponseWriter, r *http.Request) { } } + if req.RateLimit != nil { + if err := s.handleUpdateRateLimitSetting(ctx, req.RateLimit); err != nil { + status := http.StatusBadRequest + if errors.Is(err, ratelimit.ErrLocked) { + status = http.StatusConflict + } + jsonError(w, status, err.Error()) + return + } + } + // Return the updated settings. s.writeSettingsResponse(w, r.Context()) } diff --git a/internal/server/handle_users.go b/internal/server/handle_users.go index d4473b8..e0aa9a2 100644 --- a/internal/server/handle_users.go +++ b/internal/server/handle_users.go @@ -420,12 +420,6 @@ func (s *Server) handleUserInviteCreate(w http.ResponseWriter, r *http.Request) } func (s *Server) handleUserInviteAccept(w http.ResponseWriter, r *http.Request) { - ip := clientIP(r) - if !userInviteAcceptLimiter.allow(ip) { - jsonError(w, http.StatusTooManyRequests, "Too many requests. 
	Window      string   `json:"window,omitempty"` // Go duration string (time.ParseDuration), e.g. "5m", "1h"
+func loadRateLimitSetting(ctx context.Context, s Store) (rateLimitSettingPayload, bool, error) { + raw, err := s.GetSetting(ctx, settingRateLimitConfig) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return rateLimitSettingPayload{}, false, nil + } + return rateLimitSettingPayload{}, false, err + } + if raw == "" { + return rateLimitSettingPayload{}, false, nil + } + var p rateLimitSettingPayload + if err := json.Unmarshal([]byte(raw), &p); err != nil { + return rateLimitSettingPayload{}, false, fmt.Errorf("parse ratelimit_config: %w", err) + } + return p, true, nil +} + +// resolveRateLimitConfig computes the effective Config applied to the +// registry. Precedence: env > instance setting > built-in default. +// When AGENT_VAULT_RATELIMIT_LOCK=true, the instance setting is +// ignored entirely (operator pin). Returns the env-set mask so +// callers can render per-tier "source" without re-scanning os.Getenv. +func resolveRateLimitConfig(ctx context.Context, s Store) (ratelimit.Config, rateLimitSettingPayload, bool, ratelimit.EnvMasks, error) { + envCfg, envMask := ratelimit.LoadFromEnv() + if envCfg.Locked { + return envCfg, rateLimitSettingPayload{}, false, envMask, nil + } + payload, present, err := loadRateLimitSetting(ctx, s) + if err != nil { + return envCfg, rateLimitSettingPayload{}, false, envMask, err + } + if !present { + return envCfg, payload, false, envMask, nil + } + return applyPayload(envCfg, envMask, payload), payload, true, envMask, nil +} + +// applyPayload layers a setting payload on top of an env-derived base. +// envMask tells us which env knobs were explicitly set so we can +// re-assert them after switching to the payload's profile. Pure — no +// DB reads. 
+func applyPayload(envCfg ratelimit.Config, envMask ratelimit.EnvMasks, payload rateLimitSettingPayload) ratelimit.Config { + base := envCfg + if payload.Profile != "" { + base = ratelimit.DefaultsFor(ratelimit.Profile(payload.Profile)) + base.Locked = envCfg.Locked + // DefaultsFor wiped the env knobs; copy them back from envCfg. + for _, t := range ratelimit.AllTiers() { + if envMask[t].Rate { + base.Tiers[t].Rate = envCfg.Tiers[t].Rate + } + if envMask[t].Burst { + base.Tiers[t].Burst = envCfg.Tiers[t].Burst + } + if envMask[t].Window { + base.Tiers[t].Window = envCfg.Tiers[t].Window + } + if envMask[t].Max { + base.Tiers[t].Max = envCfg.Tiers[t].Max + } + if envMask[t].Concurrency { + base.Tiers[t].Concurrency = envCfg.Tiers[t].Concurrency + } + } + } + + overrides := map[ratelimit.Tier]ratelimit.TierConfig{} + for name, ov := range payload.Overrides { + tier, ok := ratelimit.TierByName(name) + if !ok { + continue + } + tc := ratelimit.TierConfig{} + if ov.Rate != nil { + tc.Rate = *ov.Rate + } + if ov.Burst != nil { + tc.Burst = *ov.Burst + } + if ov.Window != "" { + if d, err := time.ParseDuration(ov.Window); err == nil { + tc.Window = d + } + } + if ov.Max != nil { + tc.Max = *ov.Max + } + if ov.Concurrency != nil { + tc.Concurrency = *ov.Concurrency + } + overrides[tier] = tc + } + base.ApplyOverrides(overrides) + return base +} + +// rateLimitSourceForTier returns "env", "override", or "default" so +// the UI can show which layer supplied each tier's current values. +func rateLimitSourceForTier(t ratelimit.Tier, payload rateLimitSettingPayload, hasPayload bool, envMask ratelimit.EnvMasks) string { + if envMask[t].Any() { + return "env" + } + if hasPayload { + if _, ok := payload.Overrides[t.String()]; ok { + return "override" + } + } + return "default" +} + +// applyRateLimitSettingToRegistry reads the current setting + env and +// reloads the registry. Called once at server startup and again after +// every write to the settings pane. 
Returns the effective config. +func (s *Server) applyRateLimitSettingToRegistry(ctx context.Context) (ratelimit.Config, error) { + cfg, _, _, _, err := resolveRateLimitConfig(ctx, s.store) + if err != nil { + return cfg, err + } + s.rateLimit.Reload(cfg) + return cfg, nil +} + +// tierJSON is the wire representation of one tier's effective config +// returned by GET /v1/admin/settings. Source is "env" | "override" | +// "default"; Window is a duration string ("5m", "1h") for readability. +type tierJSON struct { + Rate float64 `json:"rate,omitempty"` + Burst int `json:"burst,omitempty"` + Window string `json:"window,omitempty"` + Max int `json:"max,omitempty"` + Concurrency int `json:"concurrency,omitempty"` + Source string `json:"source"` +} + +// buildRateLimitSettingResponse assembles the "rate_limit" field of +// the GET /v1/admin/settings response: profile + per-tier effective +// values + per-tier source + locked flag. Returns nil on a hard +// error; the rest of the settings response degrades without the +// rate-limit block. +func (s *Server) buildRateLimitSettingResponse(ctx context.Context) map[string]interface{} { + cfg, payload, hasPayload, envMask, err := resolveRateLimitConfig(ctx, s.store) + if err != nil { + return nil + } + return buildRateLimitResponse(cfg, payload, hasPayload, envMask) +} + +// buildRateLimitResponse serializes the effective config + per-tier +// source into the wire shape. Separate from buildRateLimitSettingResponse +// so the preview handler can reuse it without a DB read. 
+func buildRateLimitResponse(cfg ratelimit.Config, payload rateLimitSettingPayload, hasPayload bool, envMask ratelimit.EnvMasks) map[string]interface{} { + all := ratelimit.AllTiers() + tiers := make(map[string]tierJSON, len(all)) + for _, t := range all { + tc := cfg.Tiers[t] + tiers[t.String()] = tierJSON{ + Rate: tc.Rate, + Burst: tc.Burst, + Window: formatDuration(tc.Window), + Max: tc.Max, + Concurrency: tc.Concurrency, + Source: rateLimitSourceForTier(t, payload, hasPayload, envMask), + } + } + profile := string(cfg.Profile) + if payload.Profile != "" { + profile = payload.Profile + } + return map[string]interface{}{ + "profile": profile, + "locked": cfg.Locked, + "off": cfg.Off, + "tiers": tiers, + } +} + +// handleRateLimitPreview computes the effective config for a proposed +// payload without persisting it. Used by the Manage Instance UI to +// update the table live as the owner changes the profile dropdown or +// edits override fields. +func (s *Server) handleRateLimitPreview(w http.ResponseWriter, r *http.Request) { + if _, err := s.requireOwnerActor(w, r); err != nil { + return + } + envCfg, envMask := ratelimit.LoadFromEnv() + if envCfg.Locked { + jsonError(w, http.StatusConflict, "Rate-limit config is pinned by operator env var") + return + } + var p rateLimitSettingPayload + if err := json.NewDecoder(r.Body).Decode(&p); err != nil { + jsonError(w, http.StatusBadRequest, "Invalid request body") + return + } + cfg := applyPayload(envCfg, envMask, p) + jsonOK(w, buildRateLimitResponse(cfg, p, true, envMask)) +} + +func formatDuration(d time.Duration) string { + if d == 0 { + return "" + } + return d.String() +} + +// handleUpdateRateLimitSetting validates and persists a rate-limit +// setting payload. Invoked from handleUpdateSettings when the request +// includes a rate_limit field. On success, the registry is reloaded +// and the new effective config is the one returned by the subsequent +// GET response. 
+func (s *Server) handleUpdateRateLimitSetting(ctx context.Context, p *rateLimitSettingPayload) error { + if s.rateLimit.Config().Locked { + return ratelimit.ErrLocked + } + if p == nil { + return fmt.Errorf("rate_limit payload is required") + } + + // Validate profile if present. + if p.Profile != "" { + switch ratelimit.Profile(p.Profile) { + case ratelimit.ProfileDefault, ratelimit.ProfileStrict, ratelimit.ProfileLoose, ratelimit.ProfileOff: + default: + return fmt.Errorf("invalid profile %q (default|strict|loose|off)", p.Profile) + } + } + + // Validate the overrides against fixed clamps before layering them + // onto a candidate config through the shared applyPayload helper. + for name, ov := range p.Overrides { + if _, ok := ratelimit.TierByName(name); !ok { + return fmt.Errorf("unknown tier %q", name) + } + if ov.Rate != nil && (*ov.Rate < 0 || *ov.Rate > 100000) { + return fmt.Errorf("tier %s: rate %.1f out of range (0-100000)", name, *ov.Rate) + } + if ov.Burst != nil && (*ov.Burst < 0 || *ov.Burst > 100000) { + return fmt.Errorf("tier %s: burst %d out of range (0-100000)", name, *ov.Burst) + } + if ov.Max != nil && (*ov.Max < 0 || *ov.Max > 100000) { + return fmt.Errorf("tier %s: max %d out of range (0-100000)", name, *ov.Max) + } + if ov.Concurrency != nil && (*ov.Concurrency < 0 || *ov.Concurrency > 8192) { + return fmt.Errorf("tier %s: concurrency %d out of range (0-8192)", name, *ov.Concurrency) + } + if ov.Window != "" { + if _, err := time.ParseDuration(ov.Window); err != nil { + return fmt.Errorf("tier %s: window %q: %w", name, ov.Window, err) + } + } + } + envCfg, envMask := ratelimit.LoadFromEnv() + candidate := applyPayload(envCfg, envMask, *p) + if err := candidate.Validate(); err != nil { + return err + } + + // Persist and reload. We always store the exact caller-supplied + // payload so the UI can round-trip values even when a future + // profile change would otherwise clobber them. 
+ encoded, err := json.Marshal(p) + if err != nil { + return fmt.Errorf("encode ratelimit_config: %w", err) + } + if err := s.store.SetSetting(ctx, settingRateLimitConfig, string(encoded)); err != nil { + return fmt.Errorf("save ratelimit_config: %w", err) + } + if _, err := s.applyRateLimitSettingToRegistry(ctx); err != nil { + return fmt.Errorf("reload: %w", err) + } + return nil +} diff --git a/internal/server/ratelimit_settings_test.go b/internal/server/ratelimit_settings_test.go new file mode 100644 index 0000000..6d12af2 --- /dev/null +++ b/internal/server/ratelimit_settings_test.go @@ -0,0 +1,128 @@ +package server + +import ( + "context" + "database/sql" + "encoding/json" + "testing" + + "github.com/Infisical/agent-vault/internal/ratelimit" +) + +// fakeSettingsStore exercises the GetSetting/SetSetting path only. +type fakeSettingsStore struct { + Store + settings map[string]string +} + +func newFakeSettingsStore() *fakeSettingsStore { + return &fakeSettingsStore{settings: make(map[string]string)} +} + +func (f *fakeSettingsStore) GetSetting(_ context.Context, key string) (string, error) { + v, ok := f.settings[key] + if !ok { + return "", sql.ErrNoRows + } + return v, nil +} +func (f *fakeSettingsStore) SetSetting(_ context.Context, key, value string) error { + f.settings[key] = value + return nil +} + +func TestResolveRateLimitConfigDefault(t *testing.T) { + t.Setenv("AGENT_VAULT_RATELIMIT_PROFILE", "") + t.Setenv("AGENT_VAULT_RATELIMIT_LOCK", "") + store := newFakeSettingsStore() + cfg, _, hasPayload, _, err := resolveRateLimitConfig(context.Background(), store) + if err != nil { + t.Fatalf("resolve: %v", err) + } + if hasPayload { + t.Fatalf("expected no payload when setting is absent") + } + if cfg.Profile != ratelimit.ProfileDefault { + t.Fatalf("want default profile, got %q", cfg.Profile) + } + if cfg.Tiers[ratelimit.TierAuth].Max < 5 { + t.Fatalf("AUTH max below floor, cfg misbuilt: %+v", cfg.Tiers[ratelimit.TierAuth]) + } +} + +func 
TestResolveRateLimitConfigWithOverride(t *testing.T) { + t.Setenv("AGENT_VAULT_RATELIMIT_PROFILE", "") + t.Setenv("AGENT_VAULT_RATELIMIT_LOCK", "") + store := newFakeSettingsStore() + payload := rateLimitSettingPayload{ + Profile: "default", + Overrides: map[string]rateLimitTierOverride{ + "AUTHED": {Rate: float64Ptr(3.5), Burst: intPtr(25)}, + }, + } + b, _ := json.Marshal(payload) + store.settings[settingRateLimitConfig] = string(b) + + cfg, _, hasPayload, _, err := resolveRateLimitConfig(context.Background(), store) + if err != nil { + t.Fatalf("resolve: %v", err) + } + if !hasPayload { + t.Fatalf("expected payload to be loaded") + } + if cfg.Tiers[ratelimit.TierAuthed].Rate != 3.5 { + t.Fatalf("override rate not applied: %v", cfg.Tiers[ratelimit.TierAuthed].Rate) + } + if cfg.Tiers[ratelimit.TierAuthed].Burst != 25 { + t.Fatalf("override burst not applied: %v", cfg.Tiers[ratelimit.TierAuthed].Burst) + } +} + +func TestResolveRateLimitConfigHonorsEnvLock(t *testing.T) { + t.Setenv("AGENT_VAULT_RATELIMIT_PROFILE", "strict") + t.Setenv("AGENT_VAULT_RATELIMIT_LOCK", "true") + store := newFakeSettingsStore() + payload := rateLimitSettingPayload{ + Profile: "loose", + Overrides: map[string]rateLimitTierOverride{ + "AUTHED": {Rate: float64Ptr(999)}, + }, + } + b, _ := json.Marshal(payload) + store.settings[settingRateLimitConfig] = string(b) + + cfg, _, hasPayload, _, err := resolveRateLimitConfig(context.Background(), store) + if err != nil { + t.Fatalf("resolve: %v", err) + } + if hasPayload { + t.Fatalf("locked config should ignore stored payload") + } + // Strict profile halves the default rate (1.0 → 0.5); should NOT be 999. 
+ if cfg.Tiers[ratelimit.TierAuthed].Rate == 999 { + t.Fatalf("stored override leaked through env lock") + } + if cfg.Profile != ratelimit.ProfileStrict { + t.Fatalf("want strict from env, got %q", cfg.Profile) + } +} + +func TestRateLimitSourceForTier(t *testing.T) { + t.Setenv("AGENT_VAULT_RATELIMIT_PROXY_BURST", "42") + _, envMask := ratelimit.LoadFromEnv() + payload := rateLimitSettingPayload{ + Overrides: map[string]rateLimitTierOverride{"AUTHED": {Rate: float64Ptr(1)}}, + } + if got := rateLimitSourceForTier(ratelimit.TierProxy, payload, true, envMask); got != "env" { + t.Fatalf("env knob should win: got %q", got) + } + if got := rateLimitSourceForTier(ratelimit.TierAuthed, payload, true, envMask); got != "override" { + t.Fatalf("override should win when no env knob: got %q", got) + } + if got := rateLimitSourceForTier(ratelimit.TierAuth, payload, true, envMask); got != "default" { + t.Fatalf("want default when no env/override: got %q", got) + } +} + +func float64Ptr(f float64) *float64 { return &f } +func intPtr(i int) *int { return &i } diff --git a/internal/server/server.go b/internal/server/server.go index 7cd86d2..6136fee 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -15,7 +15,6 @@ import ( "time" "net" - "sync" "github.com/Infisical/agent-vault/internal/brokercore" "github.com/Infisical/agent-vault/internal/crypto" @@ -23,6 +22,7 @@ import ( "github.com/Infisical/agent-vault/internal/notify" "github.com/Infisical/agent-vault/internal/oauth" "github.com/Infisical/agent-vault/internal/pidfile" + "github.com/Infisical/agent-vault/internal/ratelimit" "github.com/Infisical/agent-vault/internal/store" ) @@ -61,10 +61,15 @@ type Server struct { oauthProviders map[string]oauth.Provider skillCLI []byte // embedded CLI skill content (served at GET /v1/skills/cli) skillHTTP []byte // embedded HTTP skill content (served at GET /v1/skills/http) - mitm *mitm.Proxy // transparent MITM proxy; nil only when --mitm-port 0 - logger *slog.Logger 
// structured logger for per-request observability + mitm *mitm.Proxy // transparent MITM proxy; nil only when --mitm-port 0 + logger *slog.Logger // structured logger for per-request observability + rateLimit *ratelimit.Registry // tiered rate limiter; shared with the MITM ingress } +// RateLimit returns the server's rate-limit registry. Exported so the +// MITM ingress can share the same tier state (see cmd/server.go). +func (s *Server) RateLimit() *ratelimit.Registry { return s.rateLimit } + // AttachMITM registers an optional transparent MITM proxy whose lifecycle // is bound to this Server: Start launches it, and SIGINT/SIGTERM/Shutdown // stops it alongside the HTTP server. @@ -524,10 +529,13 @@ func New(addr string, store Store, encKey []byte, notifier *notify.Notifier, ini proxyClient = newProxyClient() } + rlCfg, _ := ratelimit.LoadFromEnv() + rl := ratelimit.New(rlCfg) + s := &Server{ httpServer: &http.Server{ Addr: addr, - Handler: securityHeaders(mux), + Handler: securityHeaders(rl.GlobalMiddleware(logger)(mux)), ReadHeaderTimeout: 10 * time.Second, ReadTimeout: 30 * time.Second, WriteTimeout: 60 * time.Second, @@ -540,128 +548,148 @@ func New(addr string, store Store, encKey []byte, notifier *notify.Notifier, ini baseURL: strings.TrimRight(baseURL, "/"), oauthProviders: oauthProviders, logger: logger, + rateLimit: rl, } - // Always available (no initialization required) + ipAuth := s.tier(ratelimit.TierAuth, s.ipKeyer()) + ipInviteToken := s.tier(ratelimit.TierAuth, s.tokenKeyer("token")) + + // /health, /v1/status, and other public static routes rely on the + // server-wide TierGlobal backstop; no per-route limit is useful. 
mux.HandleFunc("GET /health", s.handleHealth) mux.HandleFunc("GET /v1/status", s.handleStatus) - mux.HandleFunc("POST /v1/auth/register", limitBody(s.handleRegister)) - mux.HandleFunc("POST /v1/auth/verify", limitBody(s.handleVerify)) - mux.HandleFunc("POST /v1/auth/resend-verification", limitBody(s.handleResendVerification)) - mux.HandleFunc("POST /v1/auth/forgot-password", limitBody(s.handleForgotPassword)) - mux.HandleFunc("POST /v1/auth/reset-password", limitBody(s.handleResetPassword)) + mux.HandleFunc("POST /v1/auth/register", ipAuth(limitBody(s.handleRegister))) + mux.HandleFunc("POST /v1/auth/verify", ipAuth(limitBody(s.handleVerify))) + mux.HandleFunc("POST /v1/auth/resend-verification", ipAuth(limitBody(s.handleResendVerification))) + mux.HandleFunc("POST /v1/auth/forgot-password", ipAuth(limitBody(s.handleForgotPassword))) + mux.HandleFunc("POST /v1/auth/reset-password", ipAuth(limitBody(s.handleResetPassword))) + + actorAuthed := s.tier(ratelimit.TierAuthed, s.actorKeyer()) // Require initialization - mux.HandleFunc("GET /v1/auth/me", s.requireInitialized(s.requireAuth(s.handleAuthMe))) - mux.HandleFunc("POST /v1/auth/login", s.requireInitialized(limitBody(s.handleLogin))) - mux.HandleFunc("POST /v1/auth/change-password", s.requireInitialized(s.requireAuth(limitBody(s.handleChangePassword)))) - mux.HandleFunc("DELETE /v1/auth/account", s.requireInitialized(s.requireAuth(s.handleDeleteAccount))) - mux.HandleFunc("POST /v1/sessions", s.requireInitialized(s.requireAuth(limitBody(s.handleScopedSession)))) - mux.HandleFunc("GET /v1/credentials", s.requireInitialized(s.requireAuth(s.handleCredentialsList))) - mux.HandleFunc("POST /v1/credentials", s.requireInitialized(s.requireAuth(limitBody(s.handleCredentialsSet)))) - mux.HandleFunc("DELETE /v1/credentials", s.requireInitialized(s.requireAuth(limitBody(s.handleCredentialsDelete)))) - mux.HandleFunc("GET /discover", s.requireInitialized(s.requireAuth(s.handleDiscover))) - mux.HandleFunc("POST /v1/proposals", 
s.requireInitialized(s.requireAuth(limitBody(s.handleProposalCreate)))) - mux.HandleFunc("GET /v1/proposals/{id}", s.requireInitialized(s.requireAuth(s.handleProposalGet))) - mux.HandleFunc("GET /v1/proposals", s.requireInitialized(s.requireAuth(s.handleProposalList))) - mux.HandleFunc("POST /v1/admin/proposals/{id}/approve", s.requireInitialized(s.requireAuth(limitBody(s.handleAdminProposalApprove)))) - mux.HandleFunc("POST /v1/admin/proposals/{id}/reject", s.requireInitialized(s.requireAuth(limitBody(s.handleAdminProposalReject)))) + mux.HandleFunc("GET /v1/auth/me", s.requireInitialized(s.requireAuth(actorAuthed(s.handleAuthMe)))) + mux.HandleFunc("POST /v1/auth/login", s.requireInitialized(ipAuth(limitBody(s.handleLogin)))) + mux.HandleFunc("POST /v1/auth/change-password", s.requireInitialized(s.requireAuth(actorAuthed(limitBody(s.handleChangePassword))))) + mux.HandleFunc("DELETE /v1/auth/account", s.requireInitialized(s.requireAuth(actorAuthed(s.handleDeleteAccount)))) + mux.HandleFunc("POST /v1/sessions", s.requireInitialized(s.requireAuth(actorAuthed(limitBody(s.handleScopedSession))))) + mux.HandleFunc("GET /v1/credentials", s.requireInitialized(s.requireAuth(actorAuthed(s.handleCredentialsList)))) + mux.HandleFunc("POST /v1/credentials", s.requireInitialized(s.requireAuth(actorAuthed(limitBody(s.handleCredentialsSet))))) + mux.HandleFunc("DELETE /v1/credentials", s.requireInitialized(s.requireAuth(actorAuthed(limitBody(s.handleCredentialsDelete))))) + mux.HandleFunc("GET /discover", s.requireInitialized(s.requireAuth(actorAuthed(s.handleDiscover)))) + mux.HandleFunc("POST /v1/proposals", s.requireInitialized(s.requireAuth(actorAuthed(limitBody(s.handleProposalCreate))))) + mux.HandleFunc("GET /v1/proposals/{id}", s.requireInitialized(s.requireAuth(actorAuthed(s.handleProposalGet)))) + mux.HandleFunc("GET /v1/proposals", s.requireInitialized(s.requireAuth(actorAuthed(s.handleProposalList)))) + mux.HandleFunc("POST /v1/admin/proposals/{id}/approve", 
s.requireInitialized(s.requireAuth(actorAuthed(limitBody(s.handleAdminProposalApprove))))) + mux.HandleFunc("POST /v1/admin/proposals/{id}/reject", s.requireInitialized(s.requireAuth(actorAuthed(limitBody(s.handleAdminProposalReject))))) + // /proxy/ enforces its rate limit inside the handler (needs the resolved vault). mux.HandleFunc("/proxy/", s.requireInitialized(s.requireAuth(s.handleProxy))) // Agent invite redemption (no auth — token is the credential) - mux.HandleFunc("GET /invite/{token}", s.requireInitialized(s.handleInviteRedeem)) - mux.HandleFunc("POST /invite/{token}", s.requireInitialized(limitBody(s.handlePersistentInviteRedeem))) + mux.HandleFunc("GET /invite/{token}", s.requireInitialized(ipInviteToken(s.handleInviteRedeem))) + mux.HandleFunc("POST /invite/{token}", s.requireInitialized(ipInviteToken(limitBody(s.handlePersistentInviteRedeem)))) + + ipUserInviteToken := s.tier(ratelimit.TierAuth, ratelimit.IPTokenKey(clientIP, func(r *http.Request) string { + return r.PathValue("token") + })) + ipApprovalToken := s.tier(ratelimit.TierAuth, ratelimit.IPTokenKey(clientIP, func(r *http.Request) string { + return r.URL.Query().Get("token") + })) + // OAuth callback: keyed on the hashed state query param. 
+ ipOAuthCallback := s.tier(ratelimit.TierAuth, ratelimit.IPTokenKey(clientIP, func(r *http.Request) string { + return r.URL.Query().Get("state") + })) // Agent invites (instance-level, requires auth) - mux.HandleFunc("POST /v1/agents/invites", s.requireInitialized(s.requireAuth(limitBody(s.handleAgentInviteCreate)))) - mux.HandleFunc("GET /v1/agents/invites", s.requireInitialized(s.requireAuth(s.handleAgentInviteList))) - mux.HandleFunc("DELETE /v1/agents/invites/{token}", s.requireInitialized(s.requireAuth(s.handleAgentInviteRevoke))) - mux.HandleFunc("DELETE /v1/agents/invites/by-id/{id}", s.requireInitialized(s.requireAuth(s.handleAgentInviteRevokeByID))) + mux.HandleFunc("POST /v1/agents/invites", s.requireInitialized(s.requireAuth(actorAuthed(limitBody(s.handleAgentInviteCreate))))) + mux.HandleFunc("GET /v1/agents/invites", s.requireInitialized(s.requireAuth(actorAuthed(s.handleAgentInviteList)))) + mux.HandleFunc("DELETE /v1/agents/invites/{token}", s.requireInitialized(s.requireAuth(actorAuthed(s.handleAgentInviteRevoke)))) + mux.HandleFunc("DELETE /v1/agents/invites/by-id/{id}", s.requireInitialized(s.requireAuth(actorAuthed(s.handleAgentInviteRevokeByID)))) // Agent management (instance-level) - mux.HandleFunc("GET /v1/agents", s.requireInitialized(s.requireAuth(s.handleAgentList))) - mux.HandleFunc("GET /v1/agents/{name}", s.requireInitialized(s.requireAuth(s.handleAgentGet))) - mux.HandleFunc("DELETE /v1/agents/{name}", s.requireInitialized(s.requireAuth(s.handleAgentRevoke))) - mux.HandleFunc("POST /v1/agents/{name}/rotate", s.requireInitialized(s.requireAuth(limitBody(s.handleAgentRotate)))) - mux.HandleFunc("POST /v1/agents/{name}/rename", s.requireInitialized(s.requireAuth(limitBody(s.handleAgentRename)))) - mux.HandleFunc("POST /v1/agents/{name}/role", s.requireInitialized(s.requireAuth(limitBody(s.handleAgentSetRole)))) + mux.HandleFunc("GET /v1/agents", s.requireInitialized(s.requireAuth(actorAuthed(s.handleAgentList)))) + mux.HandleFunc("GET 
/v1/agents/{name}", s.requireInitialized(s.requireAuth(actorAuthed(s.handleAgentGet)))) + mux.HandleFunc("DELETE /v1/agents/{name}", s.requireInitialized(s.requireAuth(actorAuthed(s.handleAgentRevoke)))) + mux.HandleFunc("POST /v1/agents/{name}/rotate", s.requireInitialized(s.requireAuth(actorAuthed(limitBody(s.handleAgentRotate))))) + mux.HandleFunc("POST /v1/agents/{name}/rename", s.requireInitialized(s.requireAuth(actorAuthed(limitBody(s.handleAgentRename))))) + mux.HandleFunc("POST /v1/agents/{name}/role", s.requireInitialized(s.requireAuth(actorAuthed(limitBody(s.handleAgentSetRole))))) // Vault-level agent management - mux.HandleFunc("GET /v1/vaults/{name}/agents", s.requireInitialized(s.requireAuth(s.handleVaultAgentList))) - mux.HandleFunc("POST /v1/vaults/{name}/agents", s.requireInitialized(s.requireAuth(limitBody(s.handleVaultAgentAdd)))) - mux.HandleFunc("DELETE /v1/vaults/{name}/agents/{agentName}", s.requireInitialized(s.requireAuth(s.handleVaultAgentRemove))) - mux.HandleFunc("POST /v1/vaults/{name}/agents/{agentName}/role", s.requireInitialized(s.requireAuth(limitBody(s.handleVaultAgentSetRole)))) + mux.HandleFunc("GET /v1/vaults/{name}/agents", s.requireInitialized(s.requireAuth(actorAuthed(s.handleVaultAgentList)))) + mux.HandleFunc("POST /v1/vaults/{name}/agents", s.requireInitialized(s.requireAuth(actorAuthed(limitBody(s.handleVaultAgentAdd))))) + mux.HandleFunc("DELETE /v1/vaults/{name}/agents/{agentName}", s.requireInitialized(s.requireAuth(actorAuthed(s.handleVaultAgentRemove)))) + mux.HandleFunc("POST /v1/vaults/{name}/agents/{agentName}/role", s.requireInitialized(s.requireAuth(actorAuthed(limitBody(s.handleVaultAgentSetRole))))) // Instance settings (owner-only) - mux.HandleFunc("GET /v1/admin/settings", s.requireInitialized(s.requireAuth(s.handleGetSettings))) - mux.HandleFunc("PUT /v1/admin/settings", s.requireInitialized(s.requireAuth(limitBody(s.handleUpdateSettings)))) + mux.HandleFunc("GET /v1/admin/settings", 
s.requireInitialized(s.requireAuth(actorAuthed(s.handleGetSettings)))) + mux.HandleFunc("PUT /v1/admin/settings", s.requireInitialized(s.requireAuth(actorAuthed(limitBody(s.handleUpdateSettings))))) + mux.HandleFunc("POST /v1/admin/settings/rate-limit/preview", s.requireInitialized(s.requireAuth(actorAuthed(limitBody(s.handleRateLimitPreview))))) // Public user list (any authenticated user) - mux.HandleFunc("GET /v1/users", s.requireInitialized(s.requireAuth(s.handlePublicUserList))) + mux.HandleFunc("GET /v1/users", s.requireInitialized(s.requireAuth(actorAuthed(s.handlePublicUserList)))) // User management (owner-only, except GET self) - mux.HandleFunc("GET /v1/admin/users/{email}", s.requireInitialized(s.requireAuth(s.handleUserGet))) - mux.HandleFunc("DELETE /v1/admin/users/{email}", s.requireInitialized(s.requireAuth(s.handleUserDelete))) - mux.HandleFunc("POST /v1/admin/users/{email}/role", s.requireInitialized(s.requireAuth(limitBody(s.handleUserSetRole)))) + mux.HandleFunc("GET /v1/admin/users/{email}", s.requireInitialized(s.requireAuth(actorAuthed(s.handleUserGet)))) + mux.HandleFunc("DELETE /v1/admin/users/{email}", s.requireInitialized(s.requireAuth(actorAuthed(s.handleUserDelete)))) + mux.HandleFunc("POST /v1/admin/users/{email}/role", s.requireInitialized(s.requireAuth(actorAuthed(limitBody(s.handleUserSetRole))))) // Vault management (any auth'd user) - mux.HandleFunc("GET /v1/vaults/{name}/context", s.requireInitialized(s.requireAuth(s.handleVaultContext))) - mux.HandleFunc("POST /v1/vaults", s.requireInitialized(s.requireAuth(limitBody(s.handleVaultCreate)))) - mux.HandleFunc("GET /v1/vaults", s.requireInitialized(s.requireAuth(s.handleVaultList))) - mux.HandleFunc("DELETE /v1/vaults/{name}", s.requireInitialized(s.requireAuth(s.handleVaultDelete))) - mux.HandleFunc("POST /v1/vaults/{name}/rename", s.requireInitialized(s.requireAuth(limitBody(s.handleVaultRename)))) - mux.HandleFunc("POST /v1/vaults/{name}/join", 
s.requireInitialized(s.requireAuth(limitBody(s.handleVaultJoin)))) + mux.HandleFunc("GET /v1/vaults/{name}/context", s.requireInitialized(s.requireAuth(actorAuthed(s.handleVaultContext)))) + mux.HandleFunc("POST /v1/vaults", s.requireInitialized(s.requireAuth(actorAuthed(limitBody(s.handleVaultCreate))))) + mux.HandleFunc("GET /v1/vaults", s.requireInitialized(s.requireAuth(actorAuthed(s.handleVaultList)))) + mux.HandleFunc("DELETE /v1/vaults/{name}", s.requireInitialized(s.requireAuth(actorAuthed(s.handleVaultDelete)))) + mux.HandleFunc("POST /v1/vaults/{name}/rename", s.requireInitialized(s.requireAuth(actorAuthed(limitBody(s.handleVaultRename))))) + mux.HandleFunc("POST /v1/vaults/{name}/join", s.requireInitialized(s.requireAuth(actorAuthed(limitBody(s.handleVaultJoin))))) // Vault admin (owner-only) - mux.HandleFunc("GET /v1/admin/vaults", s.requireInitialized(s.requireAuth(s.handleAdminVaultList))) - mux.HandleFunc("GET /v1/vaults/{name}/services", s.requireInitialized(s.requireAuth(s.handleServicesGet))) - mux.HandleFunc("POST /v1/vaults/{name}/services", s.requireInitialized(s.requireAuth(limitBody(s.handleServicesUpsert)))) - mux.HandleFunc("PUT /v1/vaults/{name}/services", s.requireInitialized(s.requireAuth(limitBody(s.handleServicesSet)))) - mux.HandleFunc("PATCH /v1/vaults/{name}/services/{host}", s.requireInitialized(s.requireAuth(limitBody(s.handleServicePatch)))) - mux.HandleFunc("DELETE /v1/vaults/{name}/services/{host}", s.requireInitialized(s.requireAuth(s.handleServiceRemove))) - mux.HandleFunc("DELETE /v1/vaults/{name}/services", s.requireInitialized(s.requireAuth(s.handleServicesClear))) - mux.HandleFunc("GET /v1/vaults/{name}/services/credential-usage", s.requireInitialized(s.requireAuth(s.handleServicesCredentialUsage))) - mux.HandleFunc("GET /v1/service-catalog", s.requireInitialized(s.handleServiceCatalog)) - mux.HandleFunc("GET /v1/skills/cli", s.requireInitialized(s.handleSkillCLI)) - mux.HandleFunc("GET /v1/skills/http", 
s.requireInitialized(s.handleSkillHTTP)) + mux.HandleFunc("GET /v1/admin/vaults", s.requireInitialized(s.requireAuth(actorAuthed(s.handleAdminVaultList)))) + mux.HandleFunc("GET /v1/vaults/{name}/services", s.requireInitialized(s.requireAuth(actorAuthed(s.handleServicesGet)))) + mux.HandleFunc("POST /v1/vaults/{name}/services", s.requireInitialized(s.requireAuth(actorAuthed(limitBody(s.handleServicesUpsert))))) + mux.HandleFunc("PUT /v1/vaults/{name}/services", s.requireInitialized(s.requireAuth(actorAuthed(limitBody(s.handleServicesSet))))) + mux.HandleFunc("PATCH /v1/vaults/{name}/services/{host}", s.requireInitialized(s.requireAuth(actorAuthed(limitBody(s.handleServicePatch))))) + mux.HandleFunc("DELETE /v1/vaults/{name}/services/{host}", s.requireInitialized(s.requireAuth(actorAuthed(s.handleServiceRemove)))) + mux.HandleFunc("DELETE /v1/vaults/{name}/services", s.requireInitialized(s.requireAuth(actorAuthed(s.handleServicesClear)))) + mux.HandleFunc("GET /v1/vaults/{name}/services/credential-usage", s.requireInitialized(s.requireAuth(actorAuthed(s.handleServicesCredentialUsage)))) + mux.HandleFunc("GET /v1/service-catalog", s.requireInitialized(ipAuth(s.handleServiceCatalog))) + mux.HandleFunc("GET /v1/skills/cli", s.requireInitialized(ipAuth(s.handleSkillCLI))) + mux.HandleFunc("GET /v1/skills/http", s.requireInitialized(ipAuth(s.handleSkillHTTP))) // Public: transparent-proxy root CA. Safe to expose; clients need it to // trust the minted leaves. Not wrapped in requireInitialized — the CA // lifecycle is tied to --mitm-port, not owner registration. 
- mux.HandleFunc("GET /v1/mitm/ca.pem", s.handleMITMCA) + mux.HandleFunc("GET /v1/mitm/ca.pem", ipAuth(s.handleMITMCA)) // Instance-level user invites - mux.HandleFunc("POST /v1/users/invites", s.requireInitialized(s.requireAuth(limitBody(s.handleUserInviteCreate)))) - mux.HandleFunc("GET /v1/users/invites", s.requireInitialized(s.requireAuth(s.handleUserInviteList))) - mux.HandleFunc("DELETE /v1/users/invites/{token}", s.requireInitialized(s.requireAuth(s.handleUserInviteRevoke))) - mux.HandleFunc("POST /v1/users/invites/{token}/reinvite", s.requireInitialized(s.requireAuth(limitBody(s.handleUserInviteReinvite)))) - mux.HandleFunc("GET /v1/users/invites/{token}/details", s.requireInitialized(s.handleUserInviteDetails)) - mux.HandleFunc("POST /v1/users/invites/{token}/accept", s.requireInitialized(limitBody(s.handleUserInviteAccept))) + mux.HandleFunc("POST /v1/users/invites", s.requireInitialized(s.requireAuth(actorAuthed(limitBody(s.handleUserInviteCreate))))) + mux.HandleFunc("GET /v1/users/invites", s.requireInitialized(s.requireAuth(actorAuthed(s.handleUserInviteList)))) + mux.HandleFunc("DELETE /v1/users/invites/{token}", s.requireInitialized(s.requireAuth(actorAuthed(s.handleUserInviteRevoke)))) + mux.HandleFunc("POST /v1/users/invites/{token}/reinvite", s.requireInitialized(s.requireAuth(actorAuthed(limitBody(s.handleUserInviteReinvite))))) + mux.HandleFunc("GET /v1/users/invites/{token}/details", s.requireInitialized(ipUserInviteToken(s.handleUserInviteDetails))) + mux.HandleFunc("POST /v1/users/invites/{token}/accept", s.requireInitialized(ipUserInviteToken(limitBody(s.handleUserInviteAccept)))) // Vault user management (vault admin) - mux.HandleFunc("GET /v1/vaults/{name}/users", s.requireInitialized(s.requireAuth(s.handleVaultUserList))) - mux.HandleFunc("POST /v1/vaults/{name}/users", s.requireInitialized(s.requireAuth(limitBody(s.handleVaultUserAdd)))) - mux.HandleFunc("DELETE /v1/vaults/{name}/users/{email}", 
s.requireInitialized(s.requireAuth(s.handleVaultUserRemove))) - mux.HandleFunc("POST /v1/vaults/{name}/users/{email}/role", s.requireInitialized(s.requireAuth(limitBody(s.handleVaultUserSetRole)))) + mux.HandleFunc("GET /v1/vaults/{name}/users", s.requireInitialized(s.requireAuth(actorAuthed(s.handleVaultUserList)))) + mux.HandleFunc("POST /v1/vaults/{name}/users", s.requireInitialized(s.requireAuth(actorAuthed(limitBody(s.handleVaultUserAdd))))) + mux.HandleFunc("DELETE /v1/vaults/{name}/users/{email}", s.requireInitialized(s.requireAuth(actorAuthed(s.handleVaultUserRemove)))) + mux.HandleFunc("POST /v1/vaults/{name}/users/{email}/role", s.requireInitialized(s.requireAuth(actorAuthed(limitBody(s.handleVaultUserSetRole))))) // Proposal approval details (token-based, no auth required) - mux.HandleFunc("GET /v1/proposals/approve-details", s.requireInitialized(s.handleProposalApproveDetails)) + mux.HandleFunc("GET /v1/proposals/approve-details", s.requireInitialized(ipApprovalToken(s.handleProposalApproveDetails))) // Admin proposal management - mux.HandleFunc("GET /v1/admin/proposals", s.requireInitialized(s.requireAuth(s.handleAdminProposalList))) - mux.HandleFunc("GET /v1/admin/proposals/{id}", s.requireInitialized(s.requireAuth(s.handleAdminProposalGet))) + mux.HandleFunc("GET /v1/admin/proposals", s.requireInitialized(s.requireAuth(actorAuthed(s.handleAdminProposalList)))) + mux.HandleFunc("GET /v1/admin/proposals/{id}", s.requireInitialized(s.requireAuth(actorAuthed(s.handleAdminProposalGet)))) // Email - mux.HandleFunc("POST /v1/admin/email/test", s.requireInitialized(s.requireAuth(limitBody(s.handleEmailTest)))) + mux.HandleFunc("POST /v1/admin/email/test", s.requireInitialized(s.requireAuth(actorAuthed(limitBody(s.handleEmailTest))))) - mux.HandleFunc("POST /v1/auth/logout", s.requireInitialized(s.handleLogout)) + mux.HandleFunc("POST /v1/auth/logout", s.requireInitialized(ipAuth(s.handleLogout))) // OAuth - mux.HandleFunc("GET /v1/auth/oauth/providers", 
s.handleOAuthProviders) - mux.HandleFunc("GET /v1/auth/oauth/{provider}/login", s.requireInitialized(s.optionalAuth(s.handleOAuthLogin))) - mux.HandleFunc("GET /v1/auth/oauth/{provider}/callback", s.requireInitialized(s.handleOAuthCallback)) - mux.HandleFunc("POST /v1/auth/oauth/{provider}/connect", s.requireInitialized(s.requireAuth(s.handleOAuthConnect))) - mux.HandleFunc("DELETE /v1/auth/oauth/{provider}", s.requireInitialized(s.requireAuth(s.handleOAuthDisconnect))) + mux.HandleFunc("GET /v1/auth/oauth/providers", ipAuth(s.handleOAuthProviders)) + mux.HandleFunc("GET /v1/auth/oauth/{provider}/login", s.requireInitialized(ipAuth(s.optionalAuth(s.handleOAuthLogin)))) + mux.HandleFunc("GET /v1/auth/oauth/{provider}/callback", s.requireInitialized(ipOAuthCallback(s.handleOAuthCallback))) + mux.HandleFunc("POST /v1/auth/oauth/{provider}/connect", s.requireInitialized(s.requireAuth(actorAuthed(s.handleOAuthConnect)))) + mux.HandleFunc("DELETE /v1/auth/oauth/{provider}", s.requireInitialized(s.requireAuth(actorAuthed(s.handleOAuthDisconnect)))) // React app static assets (Vite outputs to /assets/ with base "/") webFS, _ := fs.Sub(webDistFS, "webdist") @@ -701,6 +729,13 @@ func (s *Server) requireInitialized(next http.HandlerFunc) http.HandlerFunc { // Start starts the server and blocks until shutdown. // It listens for SIGINT/SIGTERM to shut down gracefully. func (s *Server) Start() error { + // Non-fatal: registry already holds env-based config from New(). + if s.initialized { + if _, err := s.applyRateLimitSettingToRegistry(context.Background()); err != nil { + s.logger.Warn("ratelimit setting load failed", "err", err.Error()) + } + } + stop := make(chan os.Signal, 1) signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM) @@ -803,64 +838,45 @@ func init() { } } -// slidingWindowLimiter is a generic sliding-window rate limiter keyed by string. 
-type slidingWindowLimiter struct { - mu sync.Mutex - attempts map[string][]time.Time - window time.Duration - max int - maxKeys int // max map keys before eviction (0 = unlimited) +// ipKeyer returns a ratelimit.Keyer that keys on the request's client IP +// (honoring AGENT_VAULT_TRUSTED_PROXIES via clientIP). +func (s *Server) ipKeyer() ratelimit.Keyer { + return ratelimit.IPKey(clientIP) } -// allow checks whether an action for the given key should be allowed. -func (l *slidingWindowLimiter) allow(key string) bool { - l.mu.Lock() - defer l.mu.Unlock() - - now := time.Now() - cutoff := now.Add(-l.window) - - // Filter to recent attempts only. - recent := l.attempts[key][:0] - for _, t := range l.attempts[key] { - if t.After(cutoff) { - recent = append(recent, t) +// actorKeyer returns a ratelimit.Keyer that keys on the authenticated +// actor (user or agent). Returns "" if no session is on the context; +// the middleware then skips the check, which is safe because actor +// tiers are wrapped *after* requireAuth. +func (s *Server) actorKeyer() ratelimit.Keyer { + return ratelimit.ActorKey(func(r *http.Request) string { + sess := sessionFromContext(r.Context()) + if sess == nil { + return "" } - } - l.attempts[key] = recent - - if len(recent) >= l.max { - return false - } - - l.attempts[key] = append(l.attempts[key], now) - - // Evict oldest keys if map grows too large (prevents unbounded growth). - if l.maxKeys > 0 && len(l.attempts) > l.maxKeys { - for k, v := range l.attempts { - if len(v) == 0 || v[len(v)-1].Before(cutoff) { - delete(l.attempts, k) - } + if sess.UserID != "" { + return "u:" + sess.UserID } - } - - return true + return "a:" + sess.AgentID + }) } -const ( - loginRateWindow = 5 * time.Minute - loginRateMax = 10 // max attempts per key per window -) +// tokenKeyer returns a ratelimit.Keyer that combines clientIP with a +// hashed URL path value so token-enumeration attempts are bounded by +// both the caller's IP and the token being probed. 
+func (s *Server) tokenKeyer(pathValue string) ratelimit.Keyer { + return ratelimit.IPTokenKey(clientIP, func(r *http.Request) string { + return r.PathValue(pathValue) + }) +} -var ( - loginIPLimiter = newSlidingWindowLimiter(loginRateWindow, loginRateMax, 10000) - loginEmailLimiter = newSlidingWindowLimiter(loginRateWindow, loginRateMax, 10000) - registerLimiter = newSlidingWindowLimiter(loginRateWindow, 5, 10000) // 5 registrations per IP per 5 min - forgotPasswordLimiter = newSlidingWindowLimiter(loginRateWindow, 5, 10000) // 5 forgot-password requests per IP per 5 min - resendVerifyLimiter = newSlidingWindowLimiter(loginRateWindow, 5, 10000) // 5 resend-verification requests per IP per 5 min - userInviteAcceptLimiter = newSlidingWindowLimiter(loginRateWindow, 10, 10000) // 10 invite accepts per IP per 5 min - resetVerifyLimiter = &verifyRateLimiter{attempts: make(map[string]int), maxKeys: maxVerifyKeys} -) +// tier wraps handler with a rate-limit check for tier keyed by keyer. +// On denial the middleware writes a 429 with standard headers; on +// allow, it calls handler. This is the canonical way new routes in +// server.go register tier enforcement. 
+func (s *Server) tier(t ratelimit.Tier, keyer ratelimit.Keyer) func(http.HandlerFunc) http.HandlerFunc { + return s.rateLimit.HandlerFunc(t, keyer, s.logger) +} var ( dummyPasswordHash []byte @@ -962,3 +978,5 @@ func sessionExpired(s *store.Session) bool { const settingAllowedDomains = "allowed_email_domains" const settingInviteOnly = "invite_only" + +const settingRateLimitConfig = "ratelimit_config" diff --git a/web/src/pages/instance/SettingsTab.tsx b/web/src/pages/instance/SettingsTab.tsx index 4049473..352c2b6 100644 --- a/web/src/pages/instance/SettingsTab.tsx +++ b/web/src/pages/instance/SettingsTab.tsx @@ -1,10 +1,83 @@ -import { useState, useEffect, type FormEvent } from "react"; +import { useState, useEffect, useRef, type FormEvent } from "react"; import { useRouteContext } from "@tanstack/react-router"; import { apiFetch } from "../../lib/api"; import Button from "../../components/Button"; import Input from "../../components/Input"; import type { AuthContext } from "../../router"; +type RateLimitTier = { + rate?: number; + burst?: number; + window?: string; + max?: number; + concurrency?: number; + source: "env" | "override" | "default"; +}; + +type RateLimitState = { + profile: string; + locked: boolean; + off: boolean; + tiers: Record; +}; + +const TIER_ORDER = ["GLOBAL", "AUTH", "PROXY", "AUTHED"]; + +const TIER_LABELS: Record = { + AUTH: "Auth (unauthenticated)", + PROXY: "Proxy ingress", + AUTHED: "Authenticated endpoints", + GLOBAL: "Server-wide (global)", +}; + +const TIER_TOOLTIPS: Record = { + AUTH: "Every unauthenticated endpoint: login, register, forgot/reset password, email verification, OAuth login/callback, invite redemption, approval-token lookups. Keyed on client IP (and additionally on email for login; rejected when either bucket is exhausted).", + PROXY: "/proxy/* and the MITM forward path, keyed on (agent, vault). Token bucket smooths sustained traffic; a per-scope concurrency semaphore bounds in-flight upstream calls. 
Both ingresses share one budget so switching doesn't bypass the limit.", + AUTHED: "Everything behind requireAuth — CRUD, reads, admin, proposals, /discover. One bucket per actor. Defaults accommodate the heaviest legitimate agent workload; tighten only if abuse is observed.", + GLOBAL: "Server-wide backstop: Rate + Burst drive a requests-per-second ceiling; Concurrency caps total in-flight requests. Outermost safety net — sheds load before per-tier limits engage.", +}; + +// toCleanedOverrides drops empty/NaN fields from the override map so +// the wire payload only carries values the owner actually set. +function toCleanedOverrides( + overrides: Record>, +): Record> { + const out: Record> = {}; + for (const [tier, ov] of Object.entries(overrides)) { + const trimmed: Partial = {}; + if (ov.rate !== undefined && ov.rate !== null && !Number.isNaN(ov.rate)) trimmed.rate = Number(ov.rate); + if (ov.burst !== undefined && ov.burst !== null && !Number.isNaN(ov.burst)) trimmed.burst = Number(ov.burst); + if (ov.max !== undefined && ov.max !== null && !Number.isNaN(ov.max)) trimmed.max = Number(ov.max); + if (ov.concurrency !== undefined && ov.concurrency !== null && !Number.isNaN(ov.concurrency)) trimmed.concurrency = Number(ov.concurrency); + if (ov.window) trimmed.window = ov.window; + if (Object.keys(trimmed).length > 0) out[tier] = trimmed; + } + return out; +} + +function TierLabel({ name, label }: { name: string; label: string }) { + const description = TIER_TOOLTIPS[name]; + return ( + + + {label} + + {description && ( + + {description} + + )} + + ); +} + export default function InstanceSettingsTab() { const { auth } = useRouteContext({ from: "/_auth" }) as { auth: AuthContext }; @@ -22,6 +95,17 @@ export default function InstanceSettingsTab() { const [testEmailError, setTestEmailError] = useState(""); const [testEmailSuccess, setTestEmailSuccess] = useState(""); + const [rateLimit, setRateLimit] = useState(null); + const [rlProfile, setRlProfile] = 
useState("default"); + const [rlAdvanced, setRlAdvanced] = useState(false); + const [rlOverrides, setRlOverrides] = useState>>({}); + const [rlSaving, setRlSaving] = useState(false); + const [rlError, setRlError] = useState(""); + // Set by save/reset so the preview effect's next run skips its POST — + // the save response already carries the fresh effective config. + const skipNextPreview = useRef(false); + const [rlSuccess, setRlSuccess] = useState(""); + useEffect(() => { apiFetch("/v1/admin/settings") .then((r) => r.json()) @@ -29,11 +113,48 @@ export default function InstanceSettingsTab() { setInviteOnly(data.invite_only ?? false); setDomains(data.allowed_email_domains || []); setSmtpConfigured(data.smtp_configured ?? false); + if (data.rate_limit) { + setRateLimit(data.rate_limit as RateLimitState); + setRlProfile(data.rate_limit.profile || "default"); + } setLoading(false); }) .catch(() => setLoading(false)); }, []); + // Live preview: whenever profile or overrides change, ask the server + // what the effective config would be and reflect it in the table. + // Debounced so typing into a numeric field doesn't spam. + useEffect(() => { + if (!rateLimit || rateLimit.locked) return; + if (skipNextPreview.current) { + skipNextPreview.current = false; + return; + } + const cleaned = toCleanedOverrides(rlOverrides); + const controller = new AbortController(); + const id = setTimeout(async () => { + try { + const resp = await apiFetch("/v1/admin/settings/rate-limit/preview", { + method: "POST", + body: JSON.stringify({ profile: rlProfile, overrides: cleaned }), + signal: controller.signal, + }); + if (!resp.ok) return; + const data = (await resp.json()) as RateLimitState; + setRateLimit((prev) => (prev ? { ...data, locked: prev.locked } : data)); + } catch { + // AbortController cancellation or network — ignore. + } + }, 150); + return () => { + controller.abort(); + clearTimeout(id); + }; + // rateLimit excluded: preview *updates* rateLimit; re-running would loop. 
+ // eslint-disable-next-line react-hooks/exhaustive-deps + }, [rlProfile, rlOverrides]); + function addDomain(e: FormEvent) { e.preventDefault(); const domain = inputValue.trim().toLowerCase(); @@ -106,6 +227,79 @@ export default function InstanceSettingsTab() { } } + async function handleSaveRateLimit() { + setRlSaving(true); + setRlError(""); + setRlSuccess(""); + try { + const resp = await apiFetch("/v1/admin/settings", { + method: "PUT", + body: JSON.stringify({ + rate_limit: { profile: rlProfile, overrides: toCleanedOverrides(rlOverrides) }, + }), + }); + const data = await resp.json(); + if (resp.ok) { + skipNextPreview.current = true; + setRateLimit(data.rate_limit as RateLimitState); + setRlOverrides({}); + setRlSuccess("Rate-limit settings saved."); + } else { + setRlError(data.error || "Failed to save rate-limit settings."); + } + } catch { + setRlError("Network error."); + } finally { + setRlSaving(false); + } + } + + async function handleResetRateLimit() { + setRlSaving(true); + setRlError(""); + setRlSuccess(""); + try { + const resp = await apiFetch("/v1/admin/settings", { + method: "PUT", + body: JSON.stringify({ rate_limit: { profile: "default" } }), + }); + const data = await resp.json(); + if (resp.ok) { + skipNextPreview.current = true; + setRateLimit(data.rate_limit as RateLimitState); + setRlProfile("default"); + setRlOverrides({}); + setRlSuccess("Rate-limit settings reset to defaults."); + } else { + setRlError(data.error || "Failed to reset rate-limit settings."); + } + } catch { + setRlError("Network error."); + } finally { + setRlSaving(false); + } + } + + function setOverride(tier: string, field: keyof RateLimitTier, value: string) { + setRlOverrides((prev) => { + const next = { ...prev }; + const cur: Partial = { ...(next[tier] || {}) }; + if (value === "") { + delete (cur as Record)[field]; + } else if (field === "window") { + cur.window = value; + } else { + (cur as Record)[field] = Number(value); + } + if (Object.keys(cur).length === 
0) { + delete next[tier]; + } else { + next[tier] = cur; + } + return next; + }); + } + if (loading) { return (
@@ -236,6 +430,156 @@ export default function InstanceSettingsTab() {
+ {rateLimit && ( +
+
+
+

Rate Limiting

+ {rateLimit.locked && ( + + Pinned by env + + )} +
+

+ Tiered limits protect auth, proxy, authed CRUD, and global in-flight. + Pick a profile; expand Advanced for per-tier overrides. + {rateLimit.locked && ( + <> Fields are read-only because AGENT_VAULT_RATELIMIT_LOCK=true is set. + )} +

+ +
+ + +
+ + + + {rlAdvanced && ( +
+ + + + + + + + + + + + + + {TIER_ORDER.map((name) => { + const tier = rateLimit.tiers[name]; + if (!tier) return null; + const ov = rlOverrides[name] || {}; + const locked = rateLimit.locked || tier.source === "env"; + return ( + + + + + + + + + + ); + })} + +
TierRateBurstWindowMaxConcurrencySource
+ + + setOverride(name, "rate", e.target.value)} + /> + + setOverride(name, "burst", e.target.value)} + /> + + setOverride(name, "window", e.target.value)} + /> + + setOverride(name, "max", e.target.value)} + /> + + setOverride(name, "concurrency", e.target.value)} + /> + + {tier.source} +
+
+ )} + + {rlError && ( +
+ {rlError} +
+ )} + {rlSuccess && ( +
+ {rlSuccess} +
+ )} + +
+ + +
+
+
+ )} +