diff --git a/docs/server/docs.go b/docs/server/docs.go index aef6944614..68a44607af 100644 --- a/docs/server/docs.go +++ b/docs/server/docs.go @@ -415,7 +415,7 @@ const docTemplate = `{ "description": "DCRConfig enables RFC 7591 Dynamic Client Registration against the\nupstream authorization server. When set, the client credentials are\nobtained at runtime rather than being pre-provisioned via ClientID /\nClientSecretFile / ClientSecretEnvVar, and ClientID must be left empty.\nMutually exclusive with ClientID.", "properties": { "discovery_url": { - "description": "DiscoveryURL is the RFC 8414 / OIDC Discovery URL from which the\nregistration_endpoint is resolved at runtime. Mutually exclusive with\nRegistrationEndpoint.", + "description": "DiscoveryURL is the exact RFC 8414 / OIDC Discovery document URL to\nfetch at runtime. The resolver issues a single GET against this URL\n(no well-known-path fallback) and reads registration_endpoint,\nauthorization_endpoint, token_endpoint,\ntoken_endpoint_auth_methods_supported, and scopes_supported from the\nresponse. Per RFC 8414 §3.3, the document's \"issuer\" field must\nexactly match the upstream issuer configured on the parent\nrun-config.\n\nUse this field when the upstream publishes discovery metadata at a\npath that differs from the issuer-derived well-known paths — for\nexample a multi-tenant IdP whose metadata lives at\nhttps://idp.example.com/tenants/acme/.well-known/openid-configuration.\n\nMutually exclusive with RegistrationEndpoint.", "type": "string" }, "initial_access_token_env_var": { @@ -427,7 +427,7 @@ const docTemplate = `{ "type": "string" }, "registration_endpoint": { - "description": "RegistrationEndpoint is the RFC 7591 registration endpoint URL used\ndirectly, bypassing discovery. Mutually exclusive with DiscoveryURL.", + "description": "RegistrationEndpoint is the RFC 7591 registration endpoint URL used\ndirectly, bypassing discovery. 
Because no discovery is performed,\nserver-capability fields (token_endpoint_auth_methods_supported,\nscopes_supported) are unavailable on this code path; the caller is\nexpected to also supply AuthorizationEndpoint, TokenEndpoint, and an\nexplicit Scopes list on the parent OAuth2UpstreamRunConfig. Auth\nmethod falls back to the resolver's default (client_secret_basic).\n\nMutually exclusive with DiscoveryURL.", "type": "string" }, "software_id": { diff --git a/docs/server/swagger.json b/docs/server/swagger.json index 082a1611d4..1130276df5 100644 --- a/docs/server/swagger.json +++ b/docs/server/swagger.json @@ -408,7 +408,7 @@ "description": "DCRConfig enables RFC 7591 Dynamic Client Registration against the\nupstream authorization server. When set, the client credentials are\nobtained at runtime rather than being pre-provisioned via ClientID /\nClientSecretFile / ClientSecretEnvVar, and ClientID must be left empty.\nMutually exclusive with ClientID.", "properties": { "discovery_url": { - "description": "DiscoveryURL is the RFC 8414 / OIDC Discovery URL from which the\nregistration_endpoint is resolved at runtime. Mutually exclusive with\nRegistrationEndpoint.", + "description": "DiscoveryURL is the exact RFC 8414 / OIDC Discovery document URL to\nfetch at runtime. The resolver issues a single GET against this URL\n(no well-known-path fallback) and reads registration_endpoint,\nauthorization_endpoint, token_endpoint,\ntoken_endpoint_auth_methods_supported, and scopes_supported from the\nresponse. 
Per RFC 8414 §3.3, the document's \"issuer\" field must\nexactly match the upstream issuer configured on the parent\nrun-config.\n\nUse this field when the upstream publishes discovery metadata at a\npath that differs from the issuer-derived well-known paths — for\nexample a multi-tenant IdP whose metadata lives at\nhttps://idp.example.com/tenants/acme/.well-known/openid-configuration.\n\nMutually exclusive with RegistrationEndpoint.", "type": "string" }, "initial_access_token_env_var": { @@ -420,7 +420,7 @@ "type": "string" }, "registration_endpoint": { - "description": "RegistrationEndpoint is the RFC 7591 registration endpoint URL used\ndirectly, bypassing discovery. Mutually exclusive with DiscoveryURL.", + "description": "RegistrationEndpoint is the RFC 7591 registration endpoint URL used\ndirectly, bypassing discovery. Because no discovery is performed,\nserver-capability fields (token_endpoint_auth_methods_supported,\nscopes_supported) are unavailable on this code path; the caller is\nexpected to also supply AuthorizationEndpoint, TokenEndpoint, and an\nexplicit Scopes list on the parent OAuth2UpstreamRunConfig. Auth\nmethod falls back to the resolver's default (client_secret_basic).\n\nMutually exclusive with DiscoveryURL.", "type": "string" }, "software_id": { diff --git a/docs/server/swagger.yaml b/docs/server/swagger.yaml index 6d973f5cea..b3962194ae 100644 --- a/docs/server/swagger.yaml +++ b/docs/server/swagger.yaml @@ -415,9 +415,21 @@ components: properties: discovery_url: description: |- - DiscoveryURL is the RFC 8414 / OIDC Discovery URL from which the - registration_endpoint is resolved at runtime. Mutually exclusive with - RegistrationEndpoint. + DiscoveryURL is the exact RFC 8414 / OIDC Discovery document URL to + fetch at runtime. 
The resolver issues a single GET against this URL + (no well-known-path fallback) and reads registration_endpoint, + authorization_endpoint, token_endpoint, + token_endpoint_auth_methods_supported, and scopes_supported from the + response. Per RFC 8414 §3.3, the document's "issuer" field must + exactly match the upstream issuer configured on the parent + run-config. + + Use this field when the upstream publishes discovery metadata at a + path that differs from the issuer-derived well-known paths — for + example a multi-tenant IdP whose metadata lives at + https://idp.example.com/tenants/acme/.well-known/openid-configuration. + + Mutually exclusive with RegistrationEndpoint. type: string initial_access_token_env_var: description: |- @@ -435,7 +447,14 @@ components: registration_endpoint: description: |- RegistrationEndpoint is the RFC 7591 registration endpoint URL used - directly, bypassing discovery. Mutually exclusive with DiscoveryURL. + directly, bypassing discovery. Because no discovery is performed, + server-capability fields (token_endpoint_auth_methods_supported, + scopes_supported) are unavailable on this code path; the caller is + expected to also supply AuthorizationEndpoint, TokenEndpoint, and an + explicit Scopes list on the parent OAuth2UpstreamRunConfig. Auth + method falls back to the resolver's default (client_secret_basic). + + Mutually exclusive with DiscoveryURL. type: string software_id: description: |- diff --git a/pkg/authserver/config.go b/pkg/authserver/config.go index bf59c88e03..e82f4bdab8 100644 --- a/pkg/authserver/config.go +++ b/pkg/authserver/config.go @@ -266,14 +266,44 @@ type OAuth2UpstreamRunConfig struct { // points at RFC 8414 / OIDC Discovery metadata from which the registration // endpoint is resolved; RegistrationEndpoint is used directly when the upstream // does not publish discovery metadata. 
+// +// Trust assumption: DiscoveryURL and RegistrationEndpoint are operator-supplied +// URLs validated only for HTTPS-or-loopback. The DCR resolver will issue +// outbound HTTP requests — possibly carrying the RFC 7591 initial access token +// as a bearer header — to whatever address those URLs resolve to. There is +// currently no allowlist or RFC1918 / link-local / cloud-metadata-service +// guard, because the operator role is fully trusted today. If the trust +// boundary ever changes (e.g. a multi-tenant operator deployment, or a less- +// privileged role gains write access to this struct via a CRD or YAML +// surface), this field becomes a confused-deputy SSRF vector. Hardening is +// tracked in https://github.com/stacklok/toolhive/issues/5135. type DCRUpstreamConfig struct { - // DiscoveryURL is the RFC 8414 / OIDC Discovery URL from which the - // registration_endpoint is resolved at runtime. Mutually exclusive with - // RegistrationEndpoint. + // DiscoveryURL is the exact RFC 8414 / OIDC Discovery document URL to + // fetch at runtime. The resolver issues a single GET against this URL + // (no well-known-path fallback) and reads registration_endpoint, + // authorization_endpoint, token_endpoint, + // token_endpoint_auth_methods_supported, and scopes_supported from the + // response. Per RFC 8414 §3.3, the document's "issuer" field must + // exactly match the upstream issuer configured on the parent + // run-config. + // + // Use this field when the upstream publishes discovery metadata at a + // path that differs from the issuer-derived well-known paths — for + // example a multi-tenant IdP whose metadata lives at + // https://idp.example.com/tenants/acme/.well-known/openid-configuration. + // + // Mutually exclusive with RegistrationEndpoint. DiscoveryURL string `json:"discovery_url,omitempty" yaml:"discovery_url,omitempty"` // RegistrationEndpoint is the RFC 7591 registration endpoint URL used - // directly, bypassing discovery. 
Mutually exclusive with DiscoveryURL. + // directly, bypassing discovery. Because no discovery is performed, + // server-capability fields (token_endpoint_auth_methods_supported, + // scopes_supported) are unavailable on this code path; the caller is + // expected to also supply AuthorizationEndpoint, TokenEndpoint, and an + // explicit Scopes list on the parent OAuth2UpstreamRunConfig. Auth + // method falls back to the resolver's default (client_secret_basic). + // + // Mutually exclusive with DiscoveryURL. RegistrationEndpoint string `json:"registration_endpoint,omitempty" yaml:"registration_endpoint,omitempty"` // InitialAccessTokenFile is the path to a file containing the RFC 7591 @@ -507,6 +537,22 @@ func (c *OAuth2UpstreamRunConfig) Validate() error { if err := c.DCRConfig.Validate(); err != nil { return fmt.Errorf("oauth2 upstream: invalid dcr_config: %w", err) } + + // When the operator configures DCRConfig.RegistrationEndpoint, the + // resolver bypasses discovery and therefore cannot populate + // AuthorizationEndpoint or TokenEndpoint from server metadata. The + // run-config must supply both explicitly or the upstream is + // unusable: registration would succeed and the first authorize or + // token-exchange call would silently fail with empty endpoints. + // Discovery flow (DCRConfig.DiscoveryURL) is unaffected — those + // fields populate from metadata. 
+ if c.DCRConfig.RegistrationEndpoint != "" { + if c.AuthorizationEndpoint == "" || c.TokenEndpoint == "" { + return fmt.Errorf( + "oauth2 upstream: authorization_endpoint and token_endpoint are required " + + "when dcr_config.registration_endpoint is set (no discovery to populate them)") + } + } } return nil diff --git a/pkg/authserver/config_test.go b/pkg/authserver/config_test.go index dd6cb8b689..49245529b5 100644 --- a/pkg/authserver/config_test.go +++ b/pkg/authserver/config_test.go @@ -296,13 +296,49 @@ func TestOAuth2UpstreamRunConfigValidate(t *testing.T) { errMsg: "either discovery_url or registration_endpoint is required", }, { - name: "DCRConfig with only registration_endpoint is valid", + name: "DCRConfig with only registration_endpoint is valid when authorization_endpoint and token_endpoint are also set", config: OAuth2UpstreamRunConfig{ + AuthorizationEndpoint: "https://idp.example.com/authorize", + TokenEndpoint: "https://idp.example.com/token", DCRConfig: &DCRUpstreamConfig{ RegistrationEndpoint: "https://idp.example.com/register", }, }, }, + + // registration_endpoint requires explicit authorize/token endpoints. + // Discovery would have populated them; bypassing discovery means the + // run-config must supply them or the upstream is unusable. 
+ { + name: "DCRConfig.registration_endpoint without authorization_endpoint rejects", + config: OAuth2UpstreamRunConfig{ + TokenEndpoint: "https://idp.example.com/token", + DCRConfig: &DCRUpstreamConfig{ + RegistrationEndpoint: "https://idp.example.com/register", + }, + }, + wantErr: true, + errMsg: "authorization_endpoint and token_endpoint are required", + }, + { + name: "DCRConfig.registration_endpoint without token_endpoint rejects", + config: OAuth2UpstreamRunConfig{ + AuthorizationEndpoint: "https://idp.example.com/authorize", + DCRConfig: &DCRUpstreamConfig{ + RegistrationEndpoint: "https://idp.example.com/register", + }, + }, + wantErr: true, + errMsg: "authorization_endpoint and token_endpoint are required", + }, + { + name: "DCRConfig.discovery_url is valid without explicit endpoints (discovery populates them)", + config: OAuth2UpstreamRunConfig{ + DCRConfig: &DCRUpstreamConfig{ + DiscoveryURL: "https://idp.example.com/.well-known/oauth-authorization-server", + }, + }, + }, } for _, tt := range tests { diff --git a/pkg/authserver/runner/dcr.go b/pkg/authserver/runner/dcr.go new file mode 100644 index 0000000000..04d4016c19 --- /dev/null +++ b/pkg/authserver/runner/dcr.go @@ -0,0 +1,923 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package runner + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "log/slog" + "net/http" + "net/url" + "runtime/debug" + "slices" + "sort" + "strings" + "time" + + "golang.org/x/sync/singleflight" + + "github.com/stacklok/toolhive/pkg/authserver" + "github.com/stacklok/toolhive/pkg/networking" + "github.com/stacklok/toolhive/pkg/oauthproto" +) + +// dcrFlight coalesces concurrent resolveDCRCredentials calls that share the +// same DCRKey. 
Two goroutines hitting the resolver for the same upstream and +// scope set will both miss the cache, so without coalescing they would both +// call RegisterClientDynamically and the loser's registration would become +// orphaned at the upstream IdP — an operator-visible cleanup task and +// possibly a transient startup failure if the upstream rate-limits +// concurrent registrations. Followers wait for the leader's result and +// observe the same DCRResolution. +// +// Package-level rather than per-store because the deduplication concern is +// the resolver's, not the cache's: a future Redis-backed store would still +// want this in-process gate so a single replica does not double-register. +var dcrFlight singleflight.Group + +// defaultUpstreamRedirectPath is the redirect path derived from the issuer +// origin when the caller's run-config does not supply an explicit RedirectURI. +// Matches the authserver's public callback route. +const defaultUpstreamRedirectPath = "/oauth/callback" + +// authMethodPreference is the preferred order of token_endpoint_auth_methods, +// most preferred first. The resolver intersects this list with the server's +// advertised methods and picks the first match. +// +// Rationale: private_key_jwt is cryptographically strongest (asymmetric, no +// shared secret on the wire). client_secret_basic and client_secret_post are +// equally secure in transit but basic is marginally preferred because the +// credentials do not appear in request-body logs. "none" is the fallback for +// public PKCE clients. +var authMethodPreference = []string{ + "private_key_jwt", + "client_secret_basic", + "client_secret_post", + "none", +} + +// DCRResolution captures the full RFC 7591 + RFC 7592 response for a +// successful Dynamic Client Registration, together with the endpoints the +// upstream advertises so the caller need not re-discover them. 
+// +// The struct is the unit of storage in DCRCredentialStore and the unit of +// application via applyResolution. +type DCRResolution struct { + // ClientID is the RFC 7591 "client_id" returned by the authorization + // server. + ClientID string + + // ClientSecret is the RFC 7591 "client_secret" returned by the + // authorization server. Empty for public PKCE clients. + ClientSecret string + + // AuthorizationEndpoint is the discovered (or configured) authorization + // endpoint for this upstream. + AuthorizationEndpoint string + + // TokenEndpoint is the discovered (or configured) token endpoint for this + // upstream. + TokenEndpoint string + + // RegistrationAccessToken is the RFC 7592 "registration_access_token" + // required for subsequent registration management operations (update, + // read, delete). + RegistrationAccessToken string + + // RegistrationClientURI is the RFC 7592 "registration_client_uri" for + // registration management operations. + RegistrationClientURI string + + // TokenEndpointAuthMethod is the authentication method negotiated at the + // token endpoint for this client. + TokenEndpointAuthMethod string + + // ClientIDIssuedAt is the RFC 7591 §3.2.1 "client_id_issued_at" value + // converted to a Go time.Time. Zero when the upstream omitted the field + // (the field is OPTIONAL per RFC 7591). Informational; not used to + // invalidate the cache. + ClientIDIssuedAt time.Time + + // ClientSecretExpiresAt is the RFC 7591 §3.2.1 "client_secret_expires_at" + // value converted to a Go time.Time. The wire convention is that 0 means + // "the secret does not expire"; in this struct that is represented by + // the zero time.Time so callers can use IsZero() rather than special- + // casing 0. + // + // When non-zero, this field is the authoritative signal that + // lookupCachedResolution uses to refetch credentials before the upstream + // rejects them at the token endpoint. 
The 90-day dcrStaleAgeThreshold + // is a heuristic for "consider rotating"; this is a hard expiry asserted + // by the upstream itself. + ClientSecretExpiresAt time.Time + + // CreatedAt is the wall-clock time at which the resolution was completed. + // Used by Step 2g observability to compute staleness against + // dcrStaleAgeThreshold. + CreatedAt time.Time +} + +// needsDCR reports whether rc requires runtime Dynamic Client Registration. +// A run-config needs DCR exactly when ClientID is empty and DCRConfig is +// non-nil (the mutually-exclusive constraint is enforced by +// OAuth2UpstreamRunConfig.Validate; this helper is a convenience check). +func needsDCR(rc *authserver.OAuth2UpstreamRunConfig) bool { + if rc == nil { + return false + } + return rc.ClientID == "" && rc.DCRConfig != nil +} + +// applyResolution copies resolved credentials and endpoints from res into rc. +// Callers must pass a COPY of the upstream run-config (per the +// copy-before-mutate rule in .claude/rules/go-style.md); applyResolution does +// not clone rc internally. +// +// All three fields (ClientID, AuthorizationEndpoint, TokenEndpoint) are +// written only when rc leaves them empty — explicit caller configuration +// always wins. resolveDCRCredentials enforces ClientID == "" up front via +// validateResolveInputs, so the conditional write here is defence-in-depth +// against future call sites that bypass the resolver and invoke +// applyResolution directly: an unconditional overwrite would silently +// clobber a pre-provisioned ClientID with no error. +// +// Note: the resolved ClientSecret is NOT copied onto rc because +// OAuth2UpstreamRunConfig models secrets as file-or-env references, not +// inline values. Callers that need the resolved secret must read it from +// the DCRResolution directly. 
// scopesHash returns the hex-encoded SHA-256 digest of the canonical form of
// scopes. It is used as the scope component of the DCR cache key.
//
// The canonical form is built in three steps:
//  1. Sort ascending, so the digest does not depend on the order in which
//     the caller listed the scopes.
//  2. Drop duplicates, so a scope that was accidentally repeated does not
//     produce a different digest (and therefore a cache miss and a redundant
//     RFC 7591 registration). An OAuth scope set is a set, not a multiset
//     (RFC 6749 §3.3).
//  3. Join with a newline, a character that is not valid inside an OAuth
//     scope token (RFC 6749 §3.3), so that ["ab", "c"] and ["a", "bc"]
//     cannot collide.
//
// The input slice is never modified: all work happens on a clone. A nil or
// empty input hashes the empty byte string.
func scopesHash(scopes []string) string {
	canonical := slices.Clone(scopes)
	sort.Strings(canonical)
	// Compact only removes adjacent duplicates, which is sufficient because
	// the slice has just been sorted.
	canonical = slices.Compact(canonical)

	digest := sha256.Sum256([]byte(strings.Join(canonical, "\n")))
	return hex.EncodeToString(digest[:])
}
+// It is used to key the cache and to default the redirect URI to +// {localIssuer}/oauth/callback when rc.RedirectURI is empty. The upstream's +// issuer is recovered separately from rc.DCRConfig.DiscoveryURL inside the +// resolver and is used solely for RFC 8414 §3.3 metadata verification. +// Passing the upstream's issuer here would produce a wrong-origin default +// redirect and a cache key that does not identify the auth-server context. +// +// The caller is responsible for applying the returned resolution onto a COPY +// of rc via applyResolution (per the copy-before-mutate rule). This function +// neither mutates rc nor the cache on failure. +func resolveDCRCredentials( + ctx context.Context, + rc *authserver.OAuth2UpstreamRunConfig, + localIssuer string, + cache DCRCredentialStore, +) (*DCRResolution, error) { + if err := validateResolveInputs(rc, localIssuer, cache); err != nil { + return nil, err + } + + redirectURI, err := resolveUpstreamRedirectURI(rc.RedirectURI, localIssuer) + if err != nil { + return nil, fmt.Errorf("dcr: resolve redirect uri: %w", err) + } + + scopes := slices.Clone(rc.Scopes) + key := DCRKey{ + Issuer: localIssuer, + RedirectURI: redirectURI, + ScopesHash: scopesHash(scopes), + } + + // Cache lookup short-circuits before any network I/O. + if cached, hit, err := lookupCachedResolution(ctx, cache, key, localIssuer, redirectURI); err != nil { + return nil, err + } else if hit { + return cached, nil + } + + // Coalesce concurrent registrations for the same DCRKey — see dcrFlight + // doc comment. The leader runs the registerOnce closure; followers + // receive the leader's *DCRResolution result. The flight key embeds the + // DCRKey fields with a separator that cannot appear in any of them + // (newline is not valid in OAuth scope tokens, URLs, or hex digests). + // + // A defer/recover inside the closure converts a panic in registerAndCache + // (or anything it calls) into a normal error. 
Without this, singleflight + // re-panics the leader's panic in every follower — N concurrent callers + // for the same DCRKey would all crash with the same value. The panic is + // still surfaced: it is logged at Error with a stack trace, and the + // returned error wraps the recovered value so callers can react to it as + // a normal failure. + flightKey := key.Issuer + "\n" + key.RedirectURI + "\n" + key.ScopesHash + resolutionAny, err, _ := dcrFlight.Do(flightKey, func() (res any, err error) { + defer func() { + if r := recover(); r != nil { + slog.Error("dcr: registration panicked", + "panic", fmt.Sprintf("%v", r), + "stack", string(debug.Stack()), + ) + err = fmt.Errorf("dcr: registration panicked: %v", r) + res = nil + } + }() + return registerAndCache(ctx, rc, localIssuer, redirectURI, scopes, key, cache) + }) + if err != nil { + return nil, err + } + return resolutionAny.(*DCRResolution), nil +} + +// registerAndCache is the leader-only body of resolveDCRCredentials wrapped +// by the singleflight. It rechecks the cache before any network I/O so +// followers that arrive after the leader's Put returns immediately see the +// fresh entry on a subsequent call. Endpoint resolution, registration, and +// the durable Put live here. +func registerAndCache( + ctx context.Context, + rc *authserver.OAuth2UpstreamRunConfig, + localIssuer, redirectURI string, + scopes []string, + key DCRKey, + cache DCRCredentialStore, +) (*DCRResolution, error) { + // Recheck cache: another flight that just finished may have populated + // it between our initial lookup and our singleflight entry. + if cached, hit, err := lookupCachedResolution(ctx, cache, key, localIssuer, redirectURI); err != nil { + return nil, err + } else if hit { + return cached, nil + } + + // Endpoint resolution: discover metadata when configured, otherwise use + // the caller-supplied RegistrationEndpoint directly. The upstream's + // expected issuer is recovered from cfg.DiscoveryURL inside the helper. 
+ // localIssuer here is *this* auth server's issuer — correct for cache + // keying and redirect URI defaulting, but it must not be used for + // RFC 8414 §3.3 metadata verification (which is the upstream's concern). + endpoints, err := resolveDCREndpoints(ctx, rc.DCRConfig) + if err != nil { + return nil, err + } + applyExplicitEndpointOverrides(endpoints, rc) + + // Token-endpoint auth method: intersect server support with our + // preference order; default to client_secret_basic if the server does + // not advertise the field at all. + authMethod, err := selectTokenEndpointAuthMethod( + endpoints.tokenEndpointAuthMethodsSupported, + endpoints.codeChallengeMethodsSupported, + ) + if err != nil { + return nil, fmt.Errorf("dcr: %w", err) + } + + registrationScopes := chooseRegistrationScopes(scopes, endpoints.scopesSupported, localIssuer) + + response, err := performRegistration(ctx, rc.DCRConfig, endpoints.registrationEndpoint, + redirectURI, authMethod, registrationScopes) + if err != nil { + return nil, err + } + + resolution := buildResolution(response, endpoints, authMethod) + + // Write to durable storage before updating caller state (per + // .claude/rules/go-style.md "write to durable storage before in-memory"). + if err := cache.Put(ctx, key, resolution); err != nil { + return nil, fmt.Errorf("dcr: cache put: %w", err) + } + + //nolint:gosec // G706: client_id is public metadata per RFC 7591. + slog.Debug("dcr: registered new client", + "local_issuer", localIssuer, + "redirect_uri", redirectURI, + "client_id", resolution.ClientID, + ) + return resolution, nil +} + +// ----------------------------------------------------------------------------- +// Private helpers +// ----------------------------------------------------------------------------- + +// validateResolveInputs performs the defensive re-check of resolver +// preconditions. 
Validate() enforces most of these at config-load time, but +// resolveDCRCredentials is an entry point that programmatic callers can +// reach with partially-constructed run-configs. +func validateResolveInputs( + rc *authserver.OAuth2UpstreamRunConfig, + localIssuer string, + cache DCRCredentialStore, +) error { + if rc == nil { + return fmt.Errorf("oauth2 upstream run-config is required") + } + if rc.ClientID != "" { + return fmt.Errorf("dcr: oauth2 upstream has a pre-provisioned client_id") + } + if rc.DCRConfig == nil { + return fmt.Errorf("dcr: oauth2 upstream has no dcr_config") + } + if localIssuer == "" { + return fmt.Errorf("dcr: issuer is required") + } + if cache == nil { + return fmt.Errorf("dcr: credential store is required") + } + return nil +} + +// lookupCachedResolution checks the cache and logs the hit. On hit it +// returns (resolution, true, nil). On miss it returns (nil, false, nil). An +// error is returned only on backend failure. +// +// Entries whose RFC 7591 §3.2.1 client_secret_expires_at has already passed +// are treated as misses so the singleflight body (registerAndCache) re-runs +// the registration and overwrites the stale entry via cache.Put. Without +// this check the cache would serve an expired secret indefinitely; the +// upstream's token endpoint would 401 on every use and the resolver would +// have no signal to refetch. The check is skipped when the field is zero, +// per the RFC 7591 convention "0 means the secret does not expire". +func lookupCachedResolution( + ctx context.Context, + cache DCRCredentialStore, + key DCRKey, + localIssuer, redirectURI string, +) (*DCRResolution, bool, error) { + cached, ok, err := cache.Get(ctx, key) + if err != nil { + return nil, false, fmt.Errorf("dcr: cache lookup: %w", err) + } + if !ok { + return nil, false, nil + } + if !cached.ClientSecretExpiresAt.IsZero() && time.Now().After(cached.ClientSecretExpiresAt) { + //nolint:gosec // G706: client_id is public metadata per RFC 7591. 
// chooseRegistrationScopes picks the scope list to send in the registration
// request.
//
// Precedence: a non-empty explicit list wins; otherwise a non-empty
// discovered list is used; otherwise nil is returned and a warning naming
// localIssuer is logged, since the client will be registered with an empty
// scope. The chosen slice is returned as-is, not copied.
func chooseRegistrationScopes(explicit, discovered []string, localIssuer string) []string {
	switch {
	case len(explicit) != 0:
		return explicit
	case len(discovered) != 0:
		return discovered
	default:
		slog.Warn("dcr: no scopes configured or discovered; registering with empty scope",
			"local_issuer", localIssuer,
		)
		return nil
	}
}
+func performRegistration( + ctx context.Context, + dcrCfg *authserver.DCRUpstreamConfig, + registrationEndpoint, redirectURI, authMethod string, + scopes []string, +) (*oauthproto.DynamicClientRegistrationResponse, error) { + // Initial access token is optional; resolveSecret returns ("", nil) + // when neither file nor env var is configured. + initialAccessToken, err := resolveSecret(dcrCfg.InitialAccessTokenFile, dcrCfg.InitialAccessTokenEnvVar) + if err != nil { + return nil, fmt.Errorf("dcr: resolve initial access token: %w", err) + } + + httpClient := newDCRHTTPClient(initialAccessToken) + + request := &oauthproto.DynamicClientRegistrationRequest{ + RedirectURIs: []string{redirectURI}, + ClientName: oauthproto.ToolHiveMCPClientName, + TokenEndpointAuthMethod: authMethod, + GrantTypes: []string{oauthproto.GrantTypeAuthorizationCode, oauthproto.GrantTypeRefreshToken}, + ResponseTypes: []string{oauthproto.ResponseTypeCode}, + Scopes: scopes, + } + + // Call exactly once — no retry loop. Step 2g will add retry/backoff at a + // higher layer if needed. + response, err := oauthproto.RegisterClientDynamically(ctx, registrationEndpoint, request, httpClient) + if err != nil { + return nil, fmt.Errorf("dcr: register client: %w", err) + } + return response, nil +} + +// buildResolution assembles the DCRResolution from the RFC 7591 response and +// the resolved endpoints. If the server did not echo a +// token_endpoint_auth_method in the response, the method actually sent is +// recorded so downstream consumers see a definite value. +// +// RFC 7591 §3.2.1 client_id_issued_at and client_secret_expires_at are +// converted from int64 epoch seconds to time.Time. The wire value 0 means +// "field absent" or "secret does not expire"; both map to the zero time.Time +// so callers can use IsZero() uniformly. 
+func buildResolution( + response *oauthproto.DynamicClientRegistrationResponse, + endpoints *dcrEndpoints, + sentAuthMethod string, +) *DCRResolution { + authMethod := response.TokenEndpointAuthMethod + if authMethod == "" { + authMethod = sentAuthMethod + } + return &DCRResolution{ + ClientID: response.ClientID, + ClientSecret: response.ClientSecret, + AuthorizationEndpoint: endpoints.authorizationEndpoint, + TokenEndpoint: endpoints.tokenEndpoint, + RegistrationAccessToken: response.RegistrationAccessToken, + RegistrationClientURI: response.RegistrationClientURI, + TokenEndpointAuthMethod: authMethod, + ClientIDIssuedAt: epochSecondsToTime(response.ClientIDIssuedAt), + ClientSecretExpiresAt: epochSecondsToTime(response.ClientSecretExpiresAt), + CreatedAt: time.Now(), + } +} + +// epochSecondsToTime converts the int64 epoch-seconds form used by RFC 7591 +// into a time.Time. Zero passes through to the zero time.Time so callers can +// rely on IsZero() to mean "field absent" / "does not expire". +func epochSecondsToTime(epoch int64) time.Time { + if epoch == 0 { + return time.Time{} + } + return time.Unix(epoch, 0).UTC() +} + +// dcrEndpoints is the internal bundle of endpoints produced by endpoint +// resolution. The separation from DCRResolution lets the resolver reason +// about discovered vs. overridden values before committing to a resolution. +type dcrEndpoints struct { + authorizationEndpoint string + tokenEndpoint string + registrationEndpoint string + tokenEndpointAuthMethodsSupported []string + scopesSupported []string + // codeChallengeMethodsSupported is consumed by + // selectTokenEndpointAuthMethod to gate the public-client (none) auth + // method on S256 PKCE being advertised. RFC 7636 / OAuth 2.1 require + // PKCE-with-S256 for public clients; registering as none against an + // upstream that advertises only plain (or omits the field) would be a + // compliance gap. 
+ codeChallengeMethodsSupported []string +} + +// resolveDCREndpoints produces the endpoint bundle from the DCRUpstreamConfig. +// +// Three branches, in priority order: +// +// 1. cfg.RegistrationEndpoint set — use it directly and skip discovery +// entirely. Server-capability fields (token_endpoint_auth_methods_supported, +// scopes_supported) are unavailable on this branch; the caller is +// expected to also supply AuthorizationEndpoint, TokenEndpoint, and an +// explicit Scopes list. Auth method falls back to the +// selectTokenEndpointAuthMethod default. +// 2. cfg.DiscoveryURL set — fetch the exact document the operator +// configured (bypassing the well-known path fallback). RFC 8414 §3.3 +// requires the metadata's "issuer" field to match the authorization +// server's issuer identifier; that identifier is the upstream's, not +// this auth server's, so it is recovered from the discovery URL via +// deriveExpectedIssuerFromDiscoveryURL rather than reusing the +// caller-supplied issuer (which names this auth server and is used +// elsewhere in resolveDCRCredentials for redirect URI defaulting and +// cache keying). +// 3. Neither set — defensive; Validate() rejects this configuration, but +// as a programmatic entry point the resolver returns an error rather +// than falling back to an unexpected strategy. +// +// When metadata is returned but omits registration_endpoint, the resolver +// synthesises {origin}/register — a convention used by nanobot and Hydra +// for providers that ship DCR without advertising it in discovery. Origin +// is taken from the upstream issuer, not this auth server's issuer, so the +// synthesised endpoint lands at the upstream. +func resolveDCREndpoints( + ctx context.Context, + cfg *authserver.DCRUpstreamConfig, +) (*dcrEndpoints, error) { + if cfg.RegistrationEndpoint != "" { + // Validate locally so a non-HTTPS or malformed URL fails before + // performRegistration constructs a bearer-token transport for it. 
+ if err := validateUpstreamEndpointURL(cfg.RegistrationEndpoint, "registration_endpoint"); err != nil { + return nil, fmt.Errorf("dcr: %w", err) + } + return &dcrEndpoints{ + registrationEndpoint: cfg.RegistrationEndpoint, + }, nil + } + + if cfg.DiscoveryURL == "" { + return nil, fmt.Errorf( + "dcr: dcr_config must set either discovery_url or registration_endpoint") + } + + upstreamIssuer, err := deriveExpectedIssuerFromDiscoveryURL(cfg.DiscoveryURL) + if err != nil { + return nil, err + } + + metadata, err := oauthproto.FetchAuthorizationServerMetadataFromURL(ctx, cfg.DiscoveryURL, upstreamIssuer, nil) + return endpointsFromMetadata(metadata, err, upstreamIssuer) +} + +// deriveExpectedIssuerFromDiscoveryURL recovers the issuer identifier the +// upstream is expected to claim in its RFC 8414 / OIDC Discovery document, +// given an operator-configured DiscoveryURL. +// +// Two recognised conventions: +// +// 1. Well-known suffix: the URL ends with /.well-known/oauth-authorization-server +// or /.well-known/openid-configuration. The suffix is stripped to recover +// the issuer; this covers single-tenant providers (e.g. +// https://mcp.atlassian.com/.well-known/oauth-authorization-server → +// https://mcp.atlassian.com) and the issuer-suffix multi-tenant style +// (e.g. https://idp.example.com/tenants/acme/.well-known/openid-configuration +// → https://idp.example.com/tenants/acme). +// 2. Non-well-known path: the URL points at a custom metadata endpoint that +// does not end in either suffix. Origin (scheme://host) is used as a +// best-effort fallback; this matches the common shape where the upstream +// issuer is the host root. +// +// RFC 8414 §3.1's path-aware form (well-known path inserted between host and +// tenant path, e.g. https://example.com/.well-known/oauth-authorization-server/tenant) +// is not auto-detected here — operators on that pattern can switch to +// dcr_config.registration_endpoint to bypass discovery. 
func deriveExpectedIssuerFromDiscoveryURL(discoveryURL string) (string, error) {
	const (
		oauthSuffix = "/.well-known/oauth-authorization-server"
		oidcSuffix  = "/.well-known/openid-configuration"
	)

	u, err := url.Parse(discoveryURL)
	if err != nil {
		return "", fmt.Errorf("parse discovery url %q: %w", discoveryURL, err)
	}
	if u.Scheme == "" || u.Host == "" {
		return "", fmt.Errorf("discovery url missing scheme or host: %q", discoveryURL)
	}

	// Path and RawPath are trimmed together. If only Path were trimmed,
	// RawPath would stop being a valid encoding of Path, url.URL.String
	// would fall back to re-escaping Path, and a percent-encoded tenant
	// segment (e.g. "a%2Fb") would be silently decoded in the issuer.
	switch {
	case strings.HasSuffix(u.Path, oauthSuffix):
		u.Path = strings.TrimSuffix(u.Path, oauthSuffix)
		u.RawPath = strings.TrimSuffix(u.RawPath, oauthSuffix)
	case strings.HasSuffix(u.Path, oidcSuffix):
		u.Path = strings.TrimSuffix(u.Path, oidcSuffix)
		u.RawPath = strings.TrimSuffix(u.RawPath, oidcSuffix)
	default:
		// Custom (non-well-known) discovery URL — fall back to origin.
		u.Path = ""
		u.RawPath = ""
	}

	// An issuer identifier carries no query or fragment. ForceQuery must be
	// cleared as well as RawQuery: a discovery URL ending in a bare "?"
	// parses with ForceQuery set, and String() would otherwise emit a
	// trailing "?" that can never match the upstream's "issuer" value.
	u.RawQuery = ""
	u.ForceQuery = false
	u.Fragment = ""
	u.RawFragment = ""
	return u.String(), nil
}

// endpointsFromMetadata converts a FetchAuthorizationServerMetadata* result
// into a dcrEndpoints bundle. Handles the ErrRegistrationEndpointMissing
// sentinel by synthesising {origin}/register.
//
// authorization_endpoint and token_endpoint are validated for HTTPS / well-
// formedness before being copied into the bundle. A self-consistent metadata
// document — possible if TLS to the metadata host is compromised, or if the
// upstream is misconfigured — could otherwise plant http:// URLs that flow
// through to the authorization-code and token-exchange call paths.
+func endpointsFromMetadata( + metadata *oauthproto.AuthorizationServerMetadata, + fetchErr error, + upstreamIssuer string, +) (*dcrEndpoints, error) { + if fetchErr != nil && !errors.Is(fetchErr, oauthproto.ErrRegistrationEndpointMissing) { + return nil, fmt.Errorf("discover authorization server metadata: %w", fetchErr) + } + + if err := validateUpstreamEndpointURL(metadata.AuthorizationEndpoint, "authorization_endpoint"); err != nil { + return nil, fmt.Errorf("dcr: discovered %w", err) + } + if err := validateUpstreamEndpointURL(metadata.TokenEndpoint, "token_endpoint"); err != nil { + return nil, fmt.Errorf("dcr: discovered %w", err) + } + + registrationEndpoint := metadata.RegistrationEndpoint + if errors.Is(fetchErr, oauthproto.ErrRegistrationEndpointMissing) { + // Metadata is otherwise valid — synthesise the registration + // endpoint from the upstream issuer's origin. + // FetchAuthorizationServerMetadata* deliberately returns + // ErrRegistrationEndpointMissing alongside a non-nil metadata + // document, so we still use the returned endpoints/scopes. + synth, err := synthesiseRegistrationEndpoint(upstreamIssuer) + if err != nil { + return nil, fmt.Errorf("synthesise registration endpoint: %w", err) + } + registrationEndpoint = synth + } + + return &dcrEndpoints{ + authorizationEndpoint: metadata.AuthorizationEndpoint, + tokenEndpoint: metadata.TokenEndpoint, + registrationEndpoint: registrationEndpoint, + tokenEndpointAuthMethodsSupported: metadata.TokenEndpointAuthMethodsSupported, + scopesSupported: metadata.ScopesSupported, + codeChallengeMethodsSupported: metadata.CodeChallengeMethodsSupported, + }, nil +} + +// synthesiseRegistrationEndpoint builds {upstreamIssuer}/register, used when +// discovery succeeds but omits registration_endpoint. The argument is the +// upstream's issuer (recovered from the discovery URL), not this auth +// server's local issuer. 
//
// Only scheme, host and path of the issuer are carried over. The path is
// kept (with trailing slashes removed) so that a tenant-scoped issuer such
// as https://idp.example.com/tenants/acme produces a registration URL under
// the same tenant prefix, rather than a global /register that would not line
// up with the tenant-aware authorize/token URLs taken from metadata.
func synthesiseRegistrationEndpoint(upstreamIssuer string) (string, error) {
	issuer, parseErr := url.Parse(upstreamIssuer)
	switch {
	case parseErr != nil:
		return "", fmt.Errorf("parse issuer: %w", parseErr)
	case issuer.Scheme == "" || issuer.Host == "":
		return "", fmt.Errorf("issuer missing scheme or host: %q", upstreamIssuer)
	}

	registration := url.URL{Scheme: issuer.Scheme, Host: issuer.Host}
	registration.Path = strings.TrimRight(issuer.Path, "/") + "/register"
	return registration.String(), nil
}

// resolveUpstreamRedirectURI returns the redirect URI to present to the
// upstream. The caller-supplied value wins; otherwise a default is derived
// from {localIssuer}/oauth/callback. HTTPS is required except for loopback
// hosts (development).
//
// localIssuer here is *this* auth server's issuer — the redirect URI is
// where the upstream sends the user back to us, so it must live on our
// origin, not the upstream's.
//
// The issuer's path is preserved when defaulting: an issuer with a tenant
// prefix produces a redirect URI under that prefix, not at the host root.
// url.URL.ResolveReference would replace the path entirely because
// defaultUpstreamRedirectPath starts with "/", so we explicitly concatenate
// instead.
+func resolveUpstreamRedirectURI(configured, localIssuer string) (string, error) { + if configured != "" { + u, err := url.Parse(configured) + if err != nil { + return "", fmt.Errorf("invalid redirect uri %q: %w", configured, err) + } + if err := validateRedirectURL(u); err != nil { + return "", err + } + return configured, nil + } + + issuerURL, err := url.Parse(localIssuer) + if err != nil { + return "", fmt.Errorf("invalid issuer %q: %w", localIssuer, err) + } + resolved := &url.URL{ + Scheme: issuerURL.Scheme, + Host: issuerURL.Host, + Path: strings.TrimRight(issuerURL.Path, "/") + defaultUpstreamRedirectPath, + } + if err := validateRedirectURL(resolved); err != nil { + return "", err + } + return resolved.String(), nil +} + +// validateRedirectURL enforces the HTTPS-except-loopback rule shared across +// OAuth URLs. +func validateRedirectURL(u *url.URL) error { + if u.Scheme == "" || u.Host == "" { + return fmt.Errorf("redirect uri missing scheme or host: %q", u.String()) + } + if u.Scheme != "https" && !networking.IsLocalhost(u.Host) { + return fmt.Errorf("redirect uri must use https (got %q) unless host is loopback", u.Scheme) + } + return nil +} + +// validateUpstreamEndpointURL enforces well-formedness and the +// HTTPS-except-loopback rule for an upstream-supplied OAuth endpoint URL. +// +// Used at every point where an endpoint URL enters the resolver from outside +// — operator-configured RegistrationEndpoint, or authorization_endpoint / +// token_endpoint copied out of an upstream's metadata document. The +// downstream oauthproto.validateRegistrationEndpoint enforces HTTPS for the +// registration URL too, but only after a bearer-token transport has already +// been constructed, so a local fail-fast check keeps the +// "no bearer-token transport for a non-HTTPS endpoint" invariant local. 
+// +// label is included in the error message ("registration_endpoint", +// "authorization_endpoint", "token_endpoint", …) so failures can be tied +// back to the specific field without an additional wrapper. +func validateUpstreamEndpointURL(rawURL, label string) error { + if rawURL == "" { + return fmt.Errorf("%s is required", label) + } + u, err := url.Parse(rawURL) + if err != nil { + return fmt.Errorf("%s %q is not a valid URL: %w", label, rawURL, err) + } + if u.Scheme == "" || u.Host == "" { + return fmt.Errorf("%s %q missing scheme or host", label, rawURL) + } + if u.Scheme != "https" && !networking.IsLocalhost(u.Host) { + return fmt.Errorf("%s %q must use https unless host is loopback (got scheme %q)", + label, rawURL, u.Scheme) + } + return nil +} + +// selectTokenEndpointAuthMethod returns the preferred token endpoint auth +// method given the server's advertised set, intersected with our preference +// order. When the server does not advertise any methods the caller's default +// of client_secret_basic is used (RFC 6749 §2.3.1 baseline). +// +// PKCE coupling for "none": the public-client method "none" is selected only +// when the upstream also advertises S256 in code_challenge_methods_supported. +// RFC 7636 §4.2 / OAuth 2.1 require S256 PKCE for public clients; registering +// as none against an upstream that advertises only "plain" — or omits the +// field entirely — would be a compliance gap. When S256 is missing, "none" +// is skipped (the iteration continues to the next less-preferred method), +// and if no other method is mutually supported the function returns an error +// so the operator sees a clear failure at boot rather than a silent +// downgrade at runtime. 
+func selectTokenEndpointAuthMethod(serverSupported, codeChallengeMethodsSupported []string) (string, error) { + if len(serverSupported) == 0 { + return "client_secret_basic", nil + } + + supported := make(map[string]struct{}, len(serverSupported)) + for _, m := range serverSupported { + supported[m] = struct{}{} + } + + pkceS256Advertised := slices.Contains(codeChallengeMethodsSupported, oauthproto.PKCEMethodS256) + + for _, m := range authMethodPreference { + if _, ok := supported[m]; !ok { + continue + } + if m == "none" && !pkceS256Advertised { + // Public-client registration without S256 PKCE is non-compliant + // per RFC 7636 / OAuth 2.1. Try the next less-preferred method. + continue + } + return m, nil + } + if _, noneOnly := supported["none"]; noneOnly && !pkceS256Advertised { + return "", fmt.Errorf( + "upstream advertises only token_endpoint_auth_method=none but does not advertise "+ + "S256 in code_challenge_methods_supported (got %v); refusing to register a public "+ + "client without S256 PKCE per RFC 7636 / OAuth 2.1", codeChallengeMethodsSupported) + } + return "", fmt.Errorf( + "no supported token_endpoint_auth_method in server advertisement %v; "+ + "client supports %v", serverSupported, authMethodPreference) +} + +// bearerTokenTransport is an http.RoundTripper that adds +// Authorization: Bearer {token} to each outgoing request. Used to supply the +// RFC 7591 initial access token to oauthproto.RegisterClientDynamically +// without leaking the abstraction up into that package. +// +// The wrapping http.Client (see newDCRHTTPClient) refuses to follow HTTP +// redirects via CheckRedirect, so this transport is only ever invoked for +// the original registration request — never for a redirected request whose +// URL the upstream chose. That is what prevents this token from being +// forwarded to a foreign origin. +type bearerTokenTransport struct { + token string + next http.RoundTripper +} + +// RoundTrip implements http.RoundTripper. 
+func (t *bearerTokenTransport) RoundTrip(req *http.Request) (*http.Response, error) { + // Clone per http.RoundTripper contract: implementations must not modify + // the input request's headers. + cp := req.Clone(req.Context()) + cp.Header.Set("Authorization", "Bearer "+t.token) + return t.next.RoundTrip(cp) +} + +// errDCRRedirectRefused is returned when a DCR registration endpoint +// responds with a 30x. Net/http surfaces it via *url.Error so callers +// observe a clear failure mode instead of a confusing JSON decode error. +var errDCRRedirectRefused = errors.New( + "dcr: registration endpoint returned a redirect; refusing to follow " + + "to avoid forwarding the RFC 7591 initial access token to a foreign origin") + +// newDCRHTTPClient returns the http.Client to pass to +// oauthproto.RegisterClientDynamically. The client always blocks HTTP +// redirects so that an upstream cannot use a 30x to coerce us into +// re-issuing the registration request (and any attached +// Authorization: Bearer header) against a different origin. RFC 7591 §3 +// does not require redirect support, so refusing them is safe. +// +// When initialAccessToken is non-empty the client also wraps the canonical +// DCR client's transport with a bearerTokenTransport that injects the +// Authorization header. The combination of the bearer transport plus the +// redirect block is what prevents the token-leak class of bug. +// +// The timeout policy is sourced from oauthproto.NewDefaultDCRClient so +// future tightening of those bounds propagates automatically. 
+func newDCRHTTPClient(initialAccessToken string) *http.Client { + client := oauthproto.NewDefaultDCRClient() + client.CheckRedirect = func(_ *http.Request, _ []*http.Request) error { + return errDCRRedirectRefused + } + + if initialAccessToken == "" { + return client + } + + next := client.Transport + if next == nil { + next = http.DefaultTransport + } + client.Transport = &bearerTokenTransport{ + token: initialAccessToken, + next: next, + } + return client +} diff --git a/pkg/authserver/runner/dcr_store.go b/pkg/authserver/runner/dcr_store.go new file mode 100644 index 0000000000..32468c4153 --- /dev/null +++ b/pkg/authserver/runner/dcr_store.go @@ -0,0 +1,149 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package runner + +import ( + "context" + "fmt" + "sync" + "time" +) + +// dcrStaleAgeThreshold is the age beyond which a cached DCR resolution is +// considered stale and logged as such by higher-level wiring. The store itself +// does not expire or evict entries — RFC 7591 client registrations are +// long-lived and are only purged by explicit RFC 7592 deregistration. This +// threshold is consumed by Step 2g observability logs introduced in the next +// PR in the DCR stack (sub-issue C, #5039); 5042 only defines the constant +// so the consumer can land without a cross-PR cycle. +// +//nolint:unused // consumed by lookupCachedResolution in #5039 +const dcrStaleAgeThreshold = 90 * 24 * time.Hour + +// DCRKey is the canonical lookup key for a DCR resolution. The tuple is +// designed so a future Redis-backed store can serialise it into a single key +// segment (Phase 3) without redefining the canonical form. ScopesHash rather +// than the raw scope slice is used so the key is comparable and order- +// insensitive. +type DCRKey struct { + // Issuer is *this* auth server's issuer identifier — the local issuer + // of the embedded authorization server that performed the registration, + // NOT the upstream's. 
The cache is keyed by this value because two + // different local issuers registering against the same upstream are + // distinct OAuth clients and must not share credentials. The upstream's + // issuer is used only for RFC 8414 §3.3 metadata verification inside + // the resolver and is not part of the cache key. + Issuer string + + // RedirectURI is the redirect URI registered with the upstream + // authorization server. Lives on the local issuer's origin since it is + // where the upstream sends the user back to us after authentication. + RedirectURI string + + // ScopesHash is the SHA-256 hex digest of the sorted scope list. + // See scopesHash in dcr.go for the canonical form. + ScopesHash string +} + +// DCRCredentialStore caches RFC 7591 Dynamic Client Registration resolutions +// keyed by the (Issuer, RedirectURI, ScopesHash) tuple. Implementations must +// be safe for concurrent use. +// +// The store is an in-memory cache of long-lived registrations — it is not a +// durable store, and entries are never expired or evicted by the store +// itself. Callers are responsible for invalidating entries when the +// underlying registration is revoked (e.g., via RFC 7592 deregistration). +type DCRCredentialStore interface { + // Get returns the cached resolution for key, or (nil, false, nil) if the + // key is not present. An error is returned only on backend failure. + Get(ctx context.Context, key DCRKey) (*DCRResolution, bool, error) + + // Put stores the resolution for key, overwriting any existing entry. + // Implementations must reject a nil resolution with an error rather + // than silently succeeding — a no-op would leave callers with no + // debug trail for the subsequent Get miss. + Put(ctx context.Context, key DCRKey, resolution *DCRResolution) error +} + +// NewInMemoryDCRCredentialStore returns a thread-safe in-memory +// DCRCredentialStore intended for tests and single-replica development +// deployments. 
Production deployments should use the Redis-backed store +// introduced in Phase 3, which addresses the cross-replica sharing, +// durability, and cross-process coordination gaps documented below. +// +// Entries are retained for the process lifetime; there is no TTL and no +// background cleanup goroutine. The unbounded-cache footgun called out in +// .claude/rules/go-style.md "Resource Leaks" does not bite here because the +// key space is bounded by the operator-configured upstream count, and this +// implementation is not the production answer. +// +// What this enables: serialises Get/Put against a single in-process map so +// concurrent callers within one authserver process see a consistent view of +// the cache without redundant RFC 7591 registrations. +// +// What this does NOT solve: +// - Cross-replica sharing: each replica holds its own independent map, so a +// registration performed on replica A is not visible to replica B. In a +// multi-replica deployment every replica will register its own DCR client +// against the upstream on first boot. Phase 3 introduces a Redis-backed +// store that addresses this. +// - Durability across restarts: process exit drops every entry; the next +// boot re-registers. Operators relying on stable client_ids must use a +// persistent backend. +// - Cross-process write coordination: two processes (or replicas) calling +// Put for the same DCRKey concurrently will both succeed against their +// local maps; whichever registration the upstream accepts last wins on +// that side, the loser becomes orphaned. The +// resolveDCRCredentials-level singleflight in dcr.go only deduplicates +// within one process. +func NewInMemoryDCRCredentialStore() DCRCredentialStore { + return &inMemoryDCRCredentialStore{ + entries: make(map[DCRKey]*DCRResolution), + } +} + +// inMemoryDCRCredentialStore is the default DCRCredentialStore backed by a +// plain map guarded by sync.RWMutex. 
Modelled on +// pkg/authserver/storage/memory.go but stripped of TTL bookkeeping — DCR +// resolutions are long-lived. +type inMemoryDCRCredentialStore struct { + mu sync.RWMutex + entries map[DCRKey]*DCRResolution +} + +// Get implements DCRCredentialStore. +func (s *inMemoryDCRCredentialStore) Get(_ context.Context, key DCRKey) (*DCRResolution, bool, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + res, ok := s.entries[key] + if !ok { + return nil, false, nil + } + // Return a defensive copy so mutations by the caller never reach the + // cache entry. This mirrors the copy-before-mutate rule in + // .claude/rules/go-style.md. + cp := *res + return &cp, true, nil +} + +// Put implements DCRCredentialStore. +// +// A nil resolution is rejected rather than silently no-oped: a caller +// passing nil would otherwise get a successful return, observe a miss on +// the next Get, and have no error trail to debug from. Per the constructor- +// validation rule in .claude/rules/go-style.md, fail loudly at the boundary. +func (s *inMemoryDCRCredentialStore) Put(_ context.Context, key DCRKey, resolution *DCRResolution) error { + if resolution == nil { + return fmt.Errorf("dcr: resolution must not be nil") + } + s.mu.Lock() + defer s.mu.Unlock() + + // Defensive copy so the caller's subsequent mutations do not reach the + // cache entry. + cp := *resolution + s.entries[key] = &cp + return nil +} diff --git a/pkg/authserver/runner/dcr_store_test.go b/pkg/authserver/runner/dcr_store_test.go new file mode 100644 index 0000000000..afffba4b1f --- /dev/null +++ b/pkg/authserver/runner/dcr_store_test.go @@ -0,0 +1,321 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. 
+// SPDX-License-Identifier: Apache-2.0 + +package runner + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestInMemoryDCRCredentialStore_PutGet_RoundTrip(t *testing.T) { + t.Parallel() + + store := NewInMemoryDCRCredentialStore() + ctx := context.Background() + + key := DCRKey{ + Issuer: "https://idp.example.com", + RedirectURI: "https://toolhive.example.com/oauth/callback", + ScopesHash: scopesHash([]string{"openid", "profile"}), + } + resolution := &DCRResolution{ + ClientID: "client-abc", + ClientSecret: "secret-xyz", + AuthorizationEndpoint: "https://idp.example.com/authorize", + TokenEndpoint: "https://idp.example.com/token", + RegistrationAccessToken: "reg-tok", + RegistrationClientURI: "https://idp.example.com/register/client-abc", + TokenEndpointAuthMethod: "client_secret_basic", + CreatedAt: time.Now(), + } + + require.NoError(t, store.Put(ctx, key, resolution)) + + got, ok, err := store.Get(ctx, key) + require.NoError(t, err) + require.True(t, ok) + assert.Equal(t, resolution.ClientID, got.ClientID) + assert.Equal(t, resolution.ClientSecret, got.ClientSecret) + assert.Equal(t, resolution.AuthorizationEndpoint, got.AuthorizationEndpoint) + assert.Equal(t, resolution.TokenEndpoint, got.TokenEndpoint) + assert.Equal(t, resolution.RegistrationAccessToken, got.RegistrationAccessToken) + assert.Equal(t, resolution.RegistrationClientURI, got.RegistrationClientURI) + assert.Equal(t, resolution.TokenEndpointAuthMethod, got.TokenEndpointAuthMethod) +} + +func TestInMemoryDCRCredentialStore_Get_MissingKey(t *testing.T) { + t.Parallel() + + store := NewInMemoryDCRCredentialStore() + ctx := context.Background() + + got, ok, err := store.Get(ctx, DCRKey{Issuer: "https://unknown.example.com"}) + require.NoError(t, err) + assert.False(t, ok) + assert.Nil(t, got) +} + +func TestInMemoryDCRCredentialStore_DistinctKeysDoNotCollide(t *testing.T) { + 
t.Parallel() + + store := NewInMemoryDCRCredentialStore() + ctx := context.Background() + + keyA := DCRKey{ + Issuer: "https://idp-a.example.com", + RedirectURI: "https://toolhive.example.com/oauth/callback", + ScopesHash: scopesHash([]string{"openid"}), + } + keyB := DCRKey{ + Issuer: "https://idp-b.example.com", + RedirectURI: "https://toolhive.example.com/oauth/callback", + ScopesHash: scopesHash([]string{"openid"}), + } + keyC := DCRKey{ + Issuer: "https://idp-a.example.com", + RedirectURI: "https://other.example.com/callback", + ScopesHash: scopesHash([]string{"openid"}), + } + keyD := DCRKey{ + Issuer: "https://idp-a.example.com", + RedirectURI: "https://toolhive.example.com/oauth/callback", + ScopesHash: scopesHash([]string{"openid", "email"}), + } + + require.NoError(t, store.Put(ctx, keyA, &DCRResolution{ClientID: "a"})) + require.NoError(t, store.Put(ctx, keyB, &DCRResolution{ClientID: "b"})) + require.NoError(t, store.Put(ctx, keyC, &DCRResolution{ClientID: "c"})) + require.NoError(t, store.Put(ctx, keyD, &DCRResolution{ClientID: "d"})) + + for _, tc := range []struct { + key DCRKey + expected string + }{ + {keyA, "a"}, + {keyB, "b"}, + {keyC, "c"}, + {keyD, "d"}, + } { + got, ok, err := store.Get(ctx, tc.key) + require.NoError(t, err) + require.True(t, ok, "key %+v should be present", tc.key) + assert.Equal(t, tc.expected, got.ClientID) + } +} + +func TestInMemoryDCRCredentialStore_Put_OverwritesExisting(t *testing.T) { + t.Parallel() + + store := NewInMemoryDCRCredentialStore() + ctx := context.Background() + + key := DCRKey{Issuer: "https://idp.example.com", RedirectURI: "https://x.example.com/cb"} + require.NoError(t, store.Put(ctx, key, &DCRResolution{ClientID: "first"})) + require.NoError(t, store.Put(ctx, key, &DCRResolution{ClientID: "second"})) + + got, ok, err := store.Get(ctx, key) + require.NoError(t, err) + require.True(t, ok) + assert.Equal(t, "second", got.ClientID) +} + +// TestInMemoryDCRCredentialStore_Put_RejectsNilResolution pins the 
+// fail-loud-on-invalid-input contract: passing nil must error rather than +// silently no-op. A silent no-op would leave the caller with a successful +// Put followed by a Get miss and no debug trail to explain it. +func TestInMemoryDCRCredentialStore_Put_RejectsNilResolution(t *testing.T) { + t.Parallel() + + store := NewInMemoryDCRCredentialStore() + ctx := context.Background() + key := DCRKey{Issuer: "https://idp.example.com", RedirectURI: "https://x.example.com/cb"} + + err := store.Put(ctx, key, nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "must not be nil") + + // And confirm the rejection did not partially populate the store. + _, ok, getErr := store.Get(ctx, key) + require.NoError(t, getErr) + assert.False(t, ok, "rejected Put must not leave any entry behind") +} + +func TestInMemoryDCRCredentialStore_GetReturnsDefensiveCopy(t *testing.T) { + t.Parallel() + + store := NewInMemoryDCRCredentialStore() + ctx := context.Background() + + key := DCRKey{Issuer: "https://idp.example.com"} + require.NoError(t, store.Put(ctx, key, &DCRResolution{ClientID: "orig"})) + + got, ok, err := store.Get(ctx, key) + require.NoError(t, err) + require.True(t, ok) + got.ClientID = "mutated" + + refetched, ok, err := store.Get(ctx, key) + require.NoError(t, err) + require.True(t, ok) + assert.Equal(t, "orig", refetched.ClientID) +} + +func TestScopesHash_StableAcrossPermutation(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + a, b []string + }{ + { + name: "two-element permutation", + a: []string{"openid", "profile"}, + b: []string{"profile", "openid"}, + }, + { + name: "three-element permutation", + a: []string{"openid", "profile", "email"}, + b: []string{"email", "openid", "profile"}, + }, + { + // OAuth scope sets are sets, not multisets (RFC 6749 §3.3). + // scopesHash deduplicates before hashing so a caller who + // accidentally repeats a scope still hits the cache entry + // keyed under the canonical set. 
+ name: "single element equals double element duplicate", + a: []string{"openid"}, + b: []string{"openid", "openid"}, + }, + { + name: "three-element with duplicate equals two-element unique", + a: []string{"openid", "profile", "openid"}, + b: []string{"openid", "profile"}, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + assert.Equal(t, scopesHash(tc.a), scopesHash(tc.b)) + }) + } +} + +func TestScopesHash_DistinctForDistinctScopes(t *testing.T) { + t.Parallel() + + a := scopesHash([]string{"openid"}) + b := scopesHash([]string{"openid", "profile"}) + c := scopesHash([]string{"profile"}) + d := scopesHash(nil) + e := scopesHash([]string{}) + + // Non-empty distinct sets produce distinct hashes. + assert.NotEqual(t, a, b) + assert.NotEqual(t, a, c) + assert.NotEqual(t, b, c) + assert.NotEqual(t, a, d) + // nil and empty slice canonicalise to the same hash (both sort-then-join + // to the empty canonical form). + assert.Equal(t, d, e) +} + +func TestScopesHash_NoCollisionFromBoundaryJoin(t *testing.T) { + t.Parallel() + + // Without a delimiter that cannot appear inside a scope value, + // ["ab", "c"] and ["a", "bc"] would collide. This test exists to + // prevent a regression if the canonical form is ever simplified. + h1 := scopesHash([]string{"ab", "c"}) + h2 := scopesHash([]string{"a", "bc"}) + assert.NotEqual(t, h1, h2) +} + +// TestInMemoryDCRCredentialStore_ConcurrentAccess fans out N goroutines +// performing alternating Put / Get against overlapping and disjoint keys, +// exercising the sync.RWMutex guard advertised in the DCRCredentialStore +// interface doc. With go test -race this catches any future change that +// drops the lock or introduces a data race in the map access. +// +// The test is bounded by a fail-fast deadline per .claude/rules/testing.md +// "Concurrent Tests: Always Add Timeouts to Blocking Barriers" — a +// regression that deadlocks would otherwise hang until the global Go test +// timeout. 
+func TestInMemoryDCRCredentialStore_ConcurrentAccess(t *testing.T) { + t.Parallel() + + store := NewInMemoryDCRCredentialStore() + + const ( + workers = 16 + opsPerWorker = 200 + ) + + // Two key spaces: overlapping (every worker writes the same keys, so the + // lock must serialise their writes) and disjoint (each worker has its own + // key space, so reads never see another worker's writes). + overlappingKey := func(i int) DCRKey { + return DCRKey{ + Issuer: "https://idp.example.com", + RedirectURI: "https://thv.example.com/oauth/callback", + ScopesHash: fmt.Sprintf("overlap-%d", i%4), + } + } + disjointKey := func(worker, i int) DCRKey { + return DCRKey{ + Issuer: fmt.Sprintf("https://idp-%d.example.com", worker), + RedirectURI: "https://thv.example.com/oauth/callback", + ScopesHash: fmt.Sprintf("disjoint-%d", i), + } + } + + var errCount int32 + var wg sync.WaitGroup + wg.Add(workers) + for w := 0; w < workers; w++ { + go func(worker int) { + defer wg.Done() + ctx := context.Background() + for i := 0; i < opsPerWorker; i++ { + resolution := &DCRResolution{ + ClientID: fmt.Sprintf("worker-%d-op-%d", worker, i), + CreatedAt: time.Now(), + } + if i%2 == 0 { + if err := store.Put(ctx, overlappingKey(i), resolution); err != nil { + atomic.AddInt32(&errCount, 1) + } + if _, _, err := store.Get(ctx, overlappingKey(i)); err != nil { + atomic.AddInt32(&errCount, 1) + } + } else { + if err := store.Put(ctx, disjointKey(worker, i), resolution); err != nil { + atomic.AddInt32(&errCount, 1) + } + if _, _, err := store.Get(ctx, disjointKey(worker, i)); err != nil { + atomic.AddInt32(&errCount, 1) + } + } + } + }(w) + } + + done := make(chan struct{}) + go func() { wg.Wait(); close(done) }() + + select { + case <-done: + case <-time.After(10 * time.Second): + t.Fatal("timeout waiting for concurrent store operations to finish; possible deadlock") + } + + assert.Zero(t, atomic.LoadInt32(&errCount), + "no Get/Put should have errored under concurrent access") +} diff --git 
a/pkg/authserver/runner/dcr_test.go b/pkg/authserver/runner/dcr_test.go new file mode 100644 index 0000000000..ef234dbe1c --- /dev/null +++ b/pkg/authserver/runner/dcr_test.go @@ -0,0 +1,1598 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package runner + +import ( + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/pkg/authserver" + "github.com/stacklok/toolhive/pkg/oauthproto" +) + +// dcrTestHandlerConfig controls the behaviour of newDCRTestServer. +type dcrTestHandlerConfig struct { + // omitRegistrationEndpoint causes discovery metadata to omit the + // registration_endpoint field, triggering the synthesised /register + // path. + omitRegistrationEndpoint bool + + // registrationEndpointPath overrides the path served as + // registration_endpoint. Defaults to "/register". + registrationEndpointPath string + + // tokenEndpointAuthMethodsSupported is advertised in metadata. + tokenEndpointAuthMethodsSupported []string + + // scopesSupported is advertised in metadata. + scopesSupported []string + + // codeChallengeMethodsSupported is advertised in metadata. Tests that + // exercise public-client (none) registration must include "S256" here, + // since selectTokenEndpointAuthMethod refuses to select none without + // it (RFC 7636 / OAuth 2.1). + codeChallengeMethodsSupported []string + + // observeRegistration is called for each request hitting the + // registration endpoint. Safe for concurrent use. + observeRegistration func(r *http.Request, body []byte) + + // clientIDIssuedAt and clientSecretExpiresAt are echoed back in the + // RFC 7591 §3.2.1 response. 
Both are int64 epoch seconds; 0 is the wire + // convention for "field absent" and (for ClientSecretExpiresAt) "secret + // does not expire". + clientIDIssuedAt int64 + clientSecretExpiresAt int64 +} + +// newDCRTestServer mounts RFC 8414 metadata and a DCR endpoint on a single +// httptest.NewServer. The returned server's URL is the issuer. The helper +// registers t.Cleanup(server.Close) itself, so callers must not close it. +func newDCRTestServer(t *testing.T, cfg dcrTestHandlerConfig) *httptest.Server { + t.Helper() + + mux := http.NewServeMux() + var server *httptest.Server + + registrationPath := cfg.registrationEndpointPath + if registrationPath == "" { + registrationPath = "/register" + } + + mux.HandleFunc("/.well-known/oauth-authorization-server", func(w http.ResponseWriter, _ *http.Request) { + md := oauthproto.AuthorizationServerMetadata{ + Issuer: server.URL, + AuthorizationEndpoint: server.URL + "/authorize", + TokenEndpoint: server.URL + "/token", + JWKSURI: server.URL + "/jwks", + TokenEndpointAuthMethodsSupported: cfg.tokenEndpointAuthMethodsSupported, + ScopesSupported: cfg.scopesSupported, + CodeChallengeMethodsSupported: cfg.codeChallengeMethodsSupported, + } + if !cfg.omitRegistrationEndpoint { + md.RegistrationEndpoint = server.URL + registrationPath + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(md) + }) + + mux.HandleFunc(registrationPath, func(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + _ = r.Body.Close() + if cfg.observeRegistration != nil { + cfg.observeRegistration(r, body) + } + // Decode the request to echo the auth method back in the response.
+ var req oauthproto.DynamicClientRegistrationRequest + if err := json.Unmarshal(body, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + resp := oauthproto.DynamicClientRegistrationResponse{ + ClientID: "test-client-id", + ClientSecret: "test-client-secret", + RegistrationAccessToken: "test-reg-token", + RegistrationClientURI: server.URL + "/register/test-client-id", + TokenEndpointAuthMethod: req.TokenEndpointAuthMethod, + ClientIDIssuedAt: cfg.clientIDIssuedAt, + ClientSecretExpiresAt: cfg.clientSecretExpiresAt, + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + _ = json.NewEncoder(w).Encode(resp) + }) + + server = httptest.NewServer(mux) + t.Cleanup(server.Close) + return server +} + +func TestResolveDCRCredentials_CacheHitShortCircuits(t *testing.T) { + t.Parallel() + + // Count every request to every path — discovery, registration, + // anything. The acceptance criterion is that a cache hit issues zero + // network I/O, so the cache-hit path must never reach this server. + var totalRequests int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + atomic.AddInt32(&totalRequests, 1) + w.WriteHeader(http.StatusTeapot) + })) + t.Cleanup(server.Close) + + cache := NewInMemoryDCRCredentialStore() + issuer := server.URL + + // Pre-populate the cache with a resolution matching the key we will + // look up. 
+ redirectURI := issuer + "/oauth/callback" + key := DCRKey{ + Issuer: issuer, + RedirectURI: redirectURI, + ScopesHash: scopesHash([]string{"openid", "profile"}), + } + preloaded := &DCRResolution{ + ClientID: "preloaded-id", + ClientSecret: "preloaded-secret", + AuthorizationEndpoint: "https://preloaded/authorize", + TokenEndpoint: "https://preloaded/token", + } + require.NoError(t, cache.Put(context.Background(), key, preloaded)) + + rc := &authserver.OAuth2UpstreamRunConfig{ + Scopes: []string{"openid", "profile"}, + DCRConfig: &authserver.DCRUpstreamConfig{ + DiscoveryURL: issuer + "/.well-known/openid-configuration", + }, + } + + got, err := resolveDCRCredentials(context.Background(), rc, issuer, cache) + require.NoError(t, err) + assert.Equal(t, "preloaded-id", got.ClientID) + assert.Equal(t, "preloaded-secret", got.ClientSecret) + assert.Equal(t, int32(0), atomic.LoadInt32(&totalRequests), + "cache hit must not issue any network I/O (discovery or registration)") +} + +func TestResolveDCRCredentials_RegistersOnCacheMiss(t *testing.T) { + t.Parallel() + + var gotAuthHeader string + var gotBody []byte + server := newDCRTestServer(t, dcrTestHandlerConfig{ + tokenEndpointAuthMethodsSupported: []string{"client_secret_basic"}, + scopesSupported: []string{"openid", "profile"}, + observeRegistration: func(r *http.Request, body []byte) { + gotAuthHeader = r.Header.Get("Authorization") + gotBody = body + }, + }) + + cache := NewInMemoryDCRCredentialStore() + issuer := server.URL + rc := &authserver.OAuth2UpstreamRunConfig{ + Scopes: []string{"openid", "profile"}, + DCRConfig: &authserver.DCRUpstreamConfig{ + DiscoveryURL: issuer + "/.well-known/oauth-authorization-server", + }, + } + + res, err := resolveDCRCredentials(context.Background(), rc, issuer, cache) + require.NoError(t, err) + assert.Equal(t, "test-client-id", res.ClientID) + assert.Equal(t, "test-client-secret", res.ClientSecret) + assert.Equal(t, "test-reg-token", res.RegistrationAccessToken) + 
assert.Equal(t, issuer+"/register/test-client-id", res.RegistrationClientURI) + assert.Equal(t, issuer+"/authorize", res.AuthorizationEndpoint) + assert.Equal(t, issuer+"/token", res.TokenEndpoint) + assert.Equal(t, "client_secret_basic", res.TokenEndpointAuthMethod) + assert.False(t, res.CreatedAt.IsZero(), "CreatedAt should be populated") + // No initial access token configured -> no Authorization header. + assert.Empty(t, gotAuthHeader) + + // Verify the request body carried the expected fields. + var req oauthproto.DynamicClientRegistrationRequest + require.NoError(t, json.Unmarshal(gotBody, &req)) + assert.Equal(t, []string{issuer + "/oauth/callback"}, req.RedirectURIs) + assert.ElementsMatch(t, []string{"openid", "profile"}, []string(req.Scopes)) + + // Cache was populated. + cached, ok, err := cache.Get(context.Background(), + DCRKey{Issuer: issuer, RedirectURI: issuer + "/oauth/callback", ScopesHash: scopesHash([]string{"openid", "profile"})}) + require.NoError(t, err) + require.True(t, ok) + assert.Equal(t, "test-client-id", cached.ClientID) +} + +func TestResolveDCRCredentials_ExplicitEndpointsOverride(t *testing.T) { + t.Parallel() + + server := newDCRTestServer(t, dcrTestHandlerConfig{}) + cache := NewInMemoryDCRCredentialStore() + issuer := server.URL + + rc := &authserver.OAuth2UpstreamRunConfig{ + AuthorizationEndpoint: "https://explicit.example.com/authorize", + TokenEndpoint: "https://explicit.example.com/token", + Scopes: []string{"openid"}, + DCRConfig: &authserver.DCRUpstreamConfig{ + DiscoveryURL: issuer + "/.well-known/oauth-authorization-server", + }, + } + + res, err := resolveDCRCredentials(context.Background(), rc, issuer, cache) + require.NoError(t, err) + assert.Equal(t, "https://explicit.example.com/authorize", res.AuthorizationEndpoint) + assert.Equal(t, "https://explicit.example.com/token", res.TokenEndpoint) +} + +func TestResolveDCRCredentials_InitialAccessTokenAsBearer(t *testing.T) { + t.Parallel() + + var gotAuthHeader string + 
server := newDCRTestServer(t, dcrTestHandlerConfig{ + observeRegistration: func(r *http.Request, _ []byte) { + gotAuthHeader = r.Header.Get("Authorization") + }, + }) + + // Use a file-based initial access token so the test can remain parallel + // (t.Setenv and t.Parallel are mutually exclusive). tokenPath is scoped + // to t.TempDir(), so concurrent subtests cannot clobber each other's + // token values even if the test is later subdivided. + tokenPath := filepath.Join(t.TempDir(), "iat") + require.NoError(t, os.WriteFile(tokenPath, []byte("iat-secret-value\n"), 0o600)) + + cache := NewInMemoryDCRCredentialStore() + issuer := server.URL + rc := &authserver.OAuth2UpstreamRunConfig{ + Scopes: []string{"openid"}, + DCRConfig: &authserver.DCRUpstreamConfig{ + DiscoveryURL: issuer + "/.well-known/oauth-authorization-server", + InitialAccessTokenFile: tokenPath, + }, + } + + _, err := resolveDCRCredentials(context.Background(), rc, issuer, cache) + require.NoError(t, err) + assert.Equal(t, "Bearer iat-secret-value", gotAuthHeader) +} + +// TestResolveDCRCredentials_DoesNotForwardBearerOnRedirect pins the +// security property that an upstream cannot use a 30x redirect from the +// registration endpoint to coerce the resolver into re-issuing the +// registration request — and the attached RFC 7591 initial access token — +// against a different origin. The resolver must refuse the redirect; the +// foreign origin must observe zero traffic. +func TestResolveDCRCredentials_DoesNotForwardBearerOnRedirect(t *testing.T) { + t.Parallel() + + // Foreign origin: a separate httptest server that records every request + // it receives. After the test we assert it received exactly zero + // requests, which proves the bearer token never crossed origins. 
+ var foreignHits int32 + var foreignAuthHeaders []string + var foreignMu sync.Mutex + foreign := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + foreignMu.Lock() + atomic.AddInt32(&foreignHits, 1) + foreignAuthHeaders = append(foreignAuthHeaders, r.Header.Get("Authorization")) + foreignMu.Unlock() + w.WriteHeader(http.StatusOK) + })) + t.Cleanup(foreign.Close) + + // Upstream: serves discovery normally, but its /register handler 302s + // to the foreign origin. A non-defended client would re-issue the + // registration request (with the Authorization header) against + // foreign.URL/stolen. + mux := http.NewServeMux() + var upstream *httptest.Server + mux.HandleFunc("/.well-known/oauth-authorization-server", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(oauthproto.AuthorizationServerMetadata{ + Issuer: upstream.URL, + AuthorizationEndpoint: upstream.URL + "/authorize", + TokenEndpoint: upstream.URL + "/token", + JWKSURI: upstream.URL + "/jwks", + RegistrationEndpoint: upstream.URL + "/register", + }) + }) + mux.HandleFunc("/register", func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, foreign.URL+"/stolen", http.StatusFound) + }) + upstream = httptest.NewServer(mux) + t.Cleanup(upstream.Close) + + tokenPath := filepath.Join(t.TempDir(), "iat") + require.NoError(t, os.WriteFile(tokenPath, []byte("iat-secret-value\n"), 0o600)) + + cache := NewInMemoryDCRCredentialStore() + issuer := upstream.URL + rc := &authserver.OAuth2UpstreamRunConfig{ + Scopes: []string{"openid"}, + DCRConfig: &authserver.DCRUpstreamConfig{ + DiscoveryURL: issuer + "/.well-known/oauth-authorization-server", + InitialAccessTokenFile: tokenPath, + }, + } + + _, err := resolveDCRCredentials(context.Background(), rc, issuer, cache) + require.Error(t, err, "registration must fail when the upstream returns a redirect") + assert.ErrorIs(t, err, 
errDCRRedirectRefused, + "the resolver must refuse to follow registration-endpoint redirects") + + foreignMu.Lock() + defer foreignMu.Unlock() + assert.EqualValues(t, 0, atomic.LoadInt32(&foreignHits), + "foreign origin must receive zero requests; got %v Authorization headers: %v", + atomic.LoadInt32(&foreignHits), foreignAuthHeaders) +} + +func TestResolveDCRCredentials_AuthMethodPreference(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + supported []string + // codeChallenge is the upstream's advertised + // code_challenge_methods_supported. Required by the gating in + // selectTokenEndpointAuthMethod whenever the test expects "none". + codeChallenge []string + expected string + }{ + { + name: "prefers client_secret_basic over none", + supported: []string{"none", "client_secret_basic"}, + expected: "client_secret_basic", + }, + { + name: "prefers private_key_jwt over others", + supported: []string{"client_secret_basic", "private_key_jwt", "none"}, + expected: "private_key_jwt", + }, + { + name: "falls back to none when only none supported and S256 advertised", + supported: []string{"none"}, + codeChallenge: []string{"S256"}, + expected: "none", + }, + { + name: "defaults to client_secret_basic when metadata omits the field", + supported: nil, + expected: "client_secret_basic", + }, + { + name: "prefers client_secret_basic over client_secret_post", + supported: []string{"client_secret_post", "client_secret_basic"}, + expected: "client_secret_basic", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + server := newDCRTestServer(t, dcrTestHandlerConfig{ + tokenEndpointAuthMethodsSupported: tc.supported, + codeChallengeMethodsSupported: tc.codeChallenge, + }) + cache := NewInMemoryDCRCredentialStore() + issuer := server.URL + rc := &authserver.OAuth2UpstreamRunConfig{ + Scopes: []string{"openid"}, + DCRConfig: &authserver.DCRUpstreamConfig{ + DiscoveryURL: issuer + 
"/.well-known/oauth-authorization-server", + }, + } + + res, err := resolveDCRCredentials(context.Background(), rc, issuer, cache) + require.NoError(t, err) + assert.Equal(t, tc.expected, res.TokenEndpointAuthMethod) + }) + } +} + +// TestResolveDCRCredentials_RefusesNoneWithoutS256 pins the compliance gate +// added for the "none" auth method: an upstream that advertises only "none" +// for token_endpoint_auth_methods but does not advertise S256 in +// code_challenge_methods_supported must be rejected at boot rather than +// quietly registering a public client without RFC 7636 / OAuth 2.1 PKCE. +func TestResolveDCRCredentials_RefusesNoneWithoutS256(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + codeChallenge []string + }{ + {name: "code_challenge_methods_supported omitted", codeChallenge: nil}, + {name: "code_challenge_methods_supported lists only plain", codeChallenge: []string{"plain"}}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + server := newDCRTestServer(t, dcrTestHandlerConfig{ + tokenEndpointAuthMethodsSupported: []string{"none"}, + codeChallengeMethodsSupported: tc.codeChallenge, + }) + cache := NewInMemoryDCRCredentialStore() + issuer := server.URL + rc := &authserver.OAuth2UpstreamRunConfig{ + Scopes: []string{"openid"}, + DCRConfig: &authserver.DCRUpstreamConfig{ + DiscoveryURL: issuer + "/.well-known/oauth-authorization-server", + }, + } + + _, err := resolveDCRCredentials(context.Background(), rc, issuer, cache) + require.Error(t, err) + assert.Contains(t, err.Error(), "S256", + "error must mention the missing S256 advertisement so operators can correlate") + assert.Contains(t, err.Error(), "RFC 7636", + "error must cite the spec being enforced") + }) + } +} + +func TestResolveDCRCredentials_EmptyAuthMethodIntersectionErrors(t *testing.T) { + t.Parallel() + + // Configure the server to advertise an unknown method so intersection is empty. 
+ server := newDCRTestServer(t, dcrTestHandlerConfig{ + tokenEndpointAuthMethodsSupported: []string{"tls_client_auth"}, + }) + cache := NewInMemoryDCRCredentialStore() + issuer := server.URL + rc := &authserver.OAuth2UpstreamRunConfig{ + Scopes: []string{"openid"}, + DCRConfig: &authserver.DCRUpstreamConfig{ + DiscoveryURL: issuer + "/.well-known/oauth-authorization-server", + }, + } + _, err := resolveDCRCredentials(context.Background(), rc, issuer, cache) + require.Error(t, err) + assert.Contains(t, err.Error(), "no supported token_endpoint_auth_method") +} + +func TestResolveDCRCredentials_SynthesisedRegistrationEndpoint(t *testing.T) { + t.Parallel() + + // registrationEndpointPath="/register" is the synthesised path the + // resolver will construct when metadata omits registration_endpoint. + var gotPath string + server := newDCRTestServer(t, dcrTestHandlerConfig{ + omitRegistrationEndpoint: true, + registrationEndpointPath: "/register", + observeRegistration: func(r *http.Request, _ []byte) { + gotPath = r.URL.Path + }, + }) + cache := NewInMemoryDCRCredentialStore() + issuer := server.URL + rc := &authserver.OAuth2UpstreamRunConfig{ + Scopes: []string{"openid"}, + DCRConfig: &authserver.DCRUpstreamConfig{ + DiscoveryURL: issuer + "/.well-known/oauth-authorization-server", + }, + } + + res, err := resolveDCRCredentials(context.Background(), rc, issuer, cache) + require.NoError(t, err) + assert.Equal(t, "test-client-id", res.ClientID) + assert.Equal(t, "/register", gotPath) +} + +func TestResolveDCRCredentials_RegistrationEndpointDirectBypassesDiscovery(t *testing.T) { + t.Parallel() + + var registrationHits int32 + var discoveryHits int32 + mux := http.NewServeMux() + mux.HandleFunc("/.well-known/oauth-authorization-server", func(_ http.ResponseWriter, _ *http.Request) { + atomic.AddInt32(&discoveryHits, 1) + }) + mux.HandleFunc("/custom/register", func(w http.ResponseWriter, _ *http.Request) { + atomic.AddInt32(&registrationHits, 1) +
w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + _, _ = w.Write([]byte(`{"client_id":"direct-id"}`)) + }) + server := httptest.NewServer(mux) + t.Cleanup(server.Close) + + cache := NewInMemoryDCRCredentialStore() + issuer := server.URL + rc := &authserver.OAuth2UpstreamRunConfig{ + AuthorizationEndpoint: issuer + "/authorize", + TokenEndpoint: issuer + "/token", + Scopes: []string{"openid"}, + DCRConfig: &authserver.DCRUpstreamConfig{ + RegistrationEndpoint: issuer + "/custom/register", + }, + } + + res, err := resolveDCRCredentials(context.Background(), rc, issuer, cache) + require.NoError(t, err) + assert.Equal(t, "direct-id", res.ClientID) + assert.Equal(t, int32(0), atomic.LoadInt32(&discoveryHits), + "discovery endpoint must not be contacted when RegistrationEndpoint is set") + assert.Equal(t, int32(1), atomic.LoadInt32(&registrationHits)) +} + +// TestResolveDCRCredentials_RejectsInvalidInputs covers every branch of +// validateResolveInputs in one place: nil run-config, pre-provisioned +// ClientID, missing DCRConfig, empty issuer, and nil credential store. The +// previous split into two single-branch tests left three branches uncovered.
+func TestResolveDCRCredentials_RejectsInvalidInputs(t *testing.T) { + t.Parallel() + + validCfg := &authserver.DCRUpstreamConfig{ + RegistrationEndpoint: "https://example.com/register", + } + + tests := []struct { + name string + rc *authserver.OAuth2UpstreamRunConfig + issuer string + cache DCRCredentialStore + wantErrSub string + }{ + { + name: "nil run-config", + rc: nil, + issuer: "https://example.com", + cache: NewInMemoryDCRCredentialStore(), + wantErrSub: "oauth2 upstream run-config is required", + }, + { + name: "pre-provisioned client_id", + rc: &authserver.OAuth2UpstreamRunConfig{ClientID: "preprovisioned", DCRConfig: validCfg}, + issuer: "https://example.com", + cache: NewInMemoryDCRCredentialStore(), + wantErrSub: "pre-provisioned", + }, + { + name: "missing dcr_config", + rc: &authserver.OAuth2UpstreamRunConfig{}, + issuer: "https://example.com", + cache: NewInMemoryDCRCredentialStore(), + wantErrSub: "no dcr_config", + }, + { + name: "empty issuer", + rc: &authserver.OAuth2UpstreamRunConfig{DCRConfig: validCfg}, + issuer: "", + cache: NewInMemoryDCRCredentialStore(), + wantErrSub: "issuer is required", + }, + { + name: "nil cache", + rc: &authserver.OAuth2UpstreamRunConfig{DCRConfig: validCfg}, + issuer: "https://example.com", + cache: nil, + wantErrSub: "credential store is required", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + _, err := resolveDCRCredentials(context.Background(), tc.rc, tc.issuer, tc.cache) + require.Error(t, err) + assert.Contains(t, err.Error(), tc.wantErrSub) + }) + } +} + +func TestNeedsDCR(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + rc *authserver.OAuth2UpstreamRunConfig + expected bool + }{ + {name: "nil", rc: nil, expected: false}, + {name: "empty client_id and dcr_config", rc: &authserver.OAuth2UpstreamRunConfig{ + DCRConfig: &authserver.DCRUpstreamConfig{}, + }, expected: true}, + {name: "client_id without dcr", rc: 
&authserver.OAuth2UpstreamRunConfig{ + ClientID: "x", + }, expected: false}, + {name: "client_id wins over dcr_config (defensive AND semantic)", rc: &authserver.OAuth2UpstreamRunConfig{ + ClientID: "x", + DCRConfig: &authserver.DCRUpstreamConfig{}, + }, expected: false}, + {name: "both empty", rc: &authserver.OAuth2UpstreamRunConfig{}, expected: false}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tc.expected, needsDCR(tc.rc)) + }) + } +} + +func TestApplyResolution_RespectsExplicitEndpoints(t *testing.T) { + t.Parallel() + + rc := &authserver.OAuth2UpstreamRunConfig{ + AuthorizationEndpoint: "https://explicit/authorize", + TokenEndpoint: "https://explicit/token", + } + res := &DCRResolution{ + ClientID: "got-client", + AuthorizationEndpoint: "https://discovered/authorize", + TokenEndpoint: "https://discovered/token", + } + applyResolution(rc, res) + assert.Equal(t, "got-client", rc.ClientID) + assert.Equal(t, "https://explicit/authorize", rc.AuthorizationEndpoint) + assert.Equal(t, "https://explicit/token", rc.TokenEndpoint) +} + +func TestApplyResolution_FillsMissingEndpoints(t *testing.T) { + t.Parallel() + + rc := &authserver.OAuth2UpstreamRunConfig{} + res := &DCRResolution{ + ClientID: "got-client", + AuthorizationEndpoint: "https://discovered/authorize", + TokenEndpoint: "https://discovered/token", + } + applyResolution(rc, res) + assert.Equal(t, "got-client", rc.ClientID) + assert.Equal(t, "https://discovered/authorize", rc.AuthorizationEndpoint) + assert.Equal(t, "https://discovered/token", rc.TokenEndpoint) +} + +func TestResolveUpstreamRedirectURI(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + configured string + issuer string + expect string + wantErr bool + }{ + { + name: "defaults from issuer", + configured: "", + issuer: "https://idp.example.com", + expect: "https://idp.example.com/oauth/callback", + }, + { + name: "explicit https accepted", + configured: 
"https://app.example.com/cb", + issuer: "https://idp.example.com", + expect: "https://app.example.com/cb", + }, + { + name: "explicit loopback http accepted", + configured: "http://localhost:8080/cb", + issuer: "https://idp.example.com", + expect: "http://localhost:8080/cb", + }, + { + name: "explicit http non-loopback rejected", + configured: "http://evil.example.com/cb", + issuer: "https://idp.example.com", + wantErr: true, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + got, err := resolveUpstreamRedirectURI(tc.configured, tc.issuer) + if tc.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tc.expect, got) + }) + } +} + +// TestResolveDCRCredentials_DiscoveryURLHonoured verifies that the resolver +// fetches the operator-configured discovery URL exactly, rather than +// deriving well-known paths from the issuer. This is the behaviour that +// matters for multi-tenant IdPs where the configured URL and the +// issuer-derived paths disagree. +func TestResolveDCRCredentials_DiscoveryURLHonoured(t *testing.T) { + t.Parallel() + + var discoveryPath string + var discoveryHits int32 + var wellKnownHits int32 + mux := http.NewServeMux() + // Mount well-known endpoints as tripwires — they must NOT be contacted + // when DiscoveryURL points elsewhere. + mux.HandleFunc("/.well-known/oauth-authorization-server", func(_ http.ResponseWriter, _ *http.Request) { + atomic.AddInt32(&wellKnownHits, 1) + }) + mux.HandleFunc("/.well-known/openid-configuration", func(_ http.ResponseWriter, _ *http.Request) { + atomic.AddInt32(&wellKnownHits, 1) + }) + // Mount the operator-configured discovery URL at a tenant-aware path + // that the well-known fallback would never derive from the issuer. 
+ var server *httptest.Server + mux.HandleFunc("/tenants/acme/metadata", func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&discoveryHits, 1) + discoveryPath = r.URL.Path + md := oauthproto.AuthorizationServerMetadata{ + Issuer: server.URL, + AuthorizationEndpoint: server.URL + "/authorize", + TokenEndpoint: server.URL + "/token", + JWKSURI: server.URL + "/jwks", + RegistrationEndpoint: server.URL + "/register", + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(md) + }) + mux.HandleFunc("/register", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + _, _ = w.Write([]byte(`{"client_id":"tenant-client"}`)) + }) + server = httptest.NewServer(mux) + t.Cleanup(server.Close) + + cache := NewInMemoryDCRCredentialStore() + issuer := server.URL + rc := &authserver.OAuth2UpstreamRunConfig{ + Scopes: []string{"openid"}, + DCRConfig: &authserver.DCRUpstreamConfig{ + DiscoveryURL: issuer + "/tenants/acme/metadata", + }, + } + + res, err := resolveDCRCredentials(context.Background(), rc, issuer, cache) + require.NoError(t, err) + assert.Equal(t, "tenant-client", res.ClientID) + assert.Equal(t, int32(1), atomic.LoadInt32(&discoveryHits), + "DiscoveryURL must be fetched exactly once") + assert.Equal(t, "/tenants/acme/metadata", discoveryPath, + "resolver must fetch the operator-configured DiscoveryURL, not a derived well-known path") + assert.Equal(t, int32(0), atomic.LoadInt32(&wellKnownHits), + "well-known discovery fallback must NOT be contacted when DiscoveryURL is set") +} + +// TestResolveDCRCredentials_DiscoveryURLIssuerMismatchRejected verifies that +// the resolver enforces RFC 8414 §3.3 issuer equality even when the caller +// pins the discovery URL — a document that advertises a different issuer is +// rejected. 
+func TestResolveDCRCredentials_DiscoveryURLIssuerMismatchRejected(t *testing.T) { + t.Parallel() + + mux := http.NewServeMux() + mux.HandleFunc("/metadata", func(w http.ResponseWriter, _ *http.Request) { + // Advertise a different issuer than the caller's. + md := oauthproto.AuthorizationServerMetadata{ + Issuer: "https://different.example.com", + TokenEndpoint: "https://different.example.com/token", + RegistrationEndpoint: "https://different.example.com/register", + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(md) + }) + server := httptest.NewServer(mux) + t.Cleanup(server.Close) + + cache := NewInMemoryDCRCredentialStore() + issuer := server.URL + rc := &authserver.OAuth2UpstreamRunConfig{ + Scopes: []string{"openid"}, + DCRConfig: &authserver.DCRUpstreamConfig{ + DiscoveryURL: issuer + "/metadata", + }, + } + + _, err := resolveDCRCredentials(context.Background(), rc, issuer, cache) + require.Error(t, err) + assert.Contains(t, err.Error(), "issuer mismatch") +} + +// TestResolveDCRCredentials_DiscoveredScopesFallback verifies that when the +// caller leaves rc.Scopes empty, the resolver sends the scopes advertised +// by the upstream in scopes_supported. +func TestResolveDCRCredentials_DiscoveredScopesFallback(t *testing.T) { + t.Parallel() + + var gotBody []byte + server := newDCRTestServer(t, dcrTestHandlerConfig{ + scopesSupported: []string{"openid", "profile", "email"}, + observeRegistration: func(_ *http.Request, body []byte) { + gotBody = body + }, + }) + cache := NewInMemoryDCRCredentialStore() + issuer := server.URL + rc := &authserver.OAuth2UpstreamRunConfig{ + // Scopes intentionally left empty so the resolver falls back to + // the discovered scopes_supported. 
+ DCRConfig: &authserver.DCRUpstreamConfig{ + DiscoveryURL: issuer + "/.well-known/oauth-authorization-server", + }, + } + + _, err := resolveDCRCredentials(context.Background(), rc, issuer, cache) + require.NoError(t, err) + + var req oauthproto.DynamicClientRegistrationRequest + require.NoError(t, json.Unmarshal(gotBody, &req)) + assert.ElementsMatch(t, []string{"openid", "profile", "email"}, []string(req.Scopes), + "registration request must carry the discovered scopes_supported") +} + +// TestResolveDCRCredentials_EmptyScopesOmitted verifies that when neither +// rc.Scopes nor metadata.ScopesSupported provides any scopes, the +// registration succeeds and the request body omits the scope field. +func TestResolveDCRCredentials_EmptyScopesOmitted(t *testing.T) { + t.Parallel() + + var gotBody []byte + server := newDCRTestServer(t, dcrTestHandlerConfig{ + // Neither scopesSupported nor rc.Scopes — the "empty scope" branch. + observeRegistration: func(_ *http.Request, body []byte) { + gotBody = body + }, + }) + cache := NewInMemoryDCRCredentialStore() + issuer := server.URL + rc := &authserver.OAuth2UpstreamRunConfig{ + DCRConfig: &authserver.DCRUpstreamConfig{ + DiscoveryURL: issuer + "/.well-known/oauth-authorization-server", + }, + } + + res, err := resolveDCRCredentials(context.Background(), rc, issuer, cache) + require.NoError(t, err) + assert.Equal(t, "test-client-id", res.ClientID) + + // The scope field must be omitted (omitempty) rather than sent as an + // empty string — an empty string would violate RFC 7591 §2, and + // ScopeList's MarshalJSON correctly relies on omitempty. 
+ var raw map[string]any + require.NoError(t, json.Unmarshal(gotBody, &raw)) + _, present := raw["scope"] + assert.False(t, present, "registration request must omit the scope field when no scopes are configured") +} + +// TestResolveDCRCredentials_UpstreamIssuerDerivedFromDiscoveryURL verifies +// the production case: the function-param `issuer` (this auth server's +// issuer) differs from the upstream's issuer, and the resolver still +// completes DCR by deriving the upstream's expected issuer from the +// DiscoveryURL itself rather than reusing the caller-supplied issuer for +// RFC 8414 §3.3 verification. +// +// Pre-fix this test would have failed with `issuer mismatch (RFC 8414 §3.3): +// expected "https://our-auth.example.com", got "{server.URL}"`, because the +// resolver used the caller's issuer as expectedIssuer. +func TestResolveDCRCredentials_UpstreamIssuerDerivedFromDiscoveryURL(t *testing.T) { + t.Parallel() + + server := newDCRTestServer(t, dcrTestHandlerConfig{ + tokenEndpointAuthMethodsSupported: []string{"client_secret_basic"}, + }) + cache := NewInMemoryDCRCredentialStore() + + // Caller-supplied issuer names this auth server, NOT the upstream. + // Production wiring always passes its own issuer here (see + // embeddedauthserver.go: buildUpstreamConfigs(... cfg.Issuer ...)). + ourIssuer := "https://our-auth.example.com" + + rc := &authserver.OAuth2UpstreamRunConfig{ + // Explicit redirect URI so the resolver does not try to default + // it from ourIssuer (which would still work, but isolating the + // concern under test keeps the failure mode crisp).
+ RedirectURI: "https://our-auth.example.com/oauth/callback", + Scopes: []string{"openid"}, + DCRConfig: &authserver.DCRUpstreamConfig{ + DiscoveryURL: server.URL + "/.well-known/oauth-authorization-server", + }, + } + + res, err := resolveDCRCredentials(context.Background(), rc, ourIssuer, cache) + require.NoError(t, err, + "resolver must derive expectedIssuer from DiscoveryURL, not from the caller's issuer") + assert.Equal(t, "test-client-id", res.ClientID) + assert.Equal(t, server.URL+"/authorize", res.AuthorizationEndpoint) + assert.Equal(t, server.URL+"/token", res.TokenEndpoint) +} + +func TestDeriveExpectedIssuerFromDiscoveryURL(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + discoveryURL string + want string + wantErr bool + }{ + { + name: "oauth well-known suffix at host root", + discoveryURL: "https://mcp.atlassian.com/.well-known/oauth-authorization-server", + want: "https://mcp.atlassian.com", + }, + { + name: "oidc well-known suffix at host root", + discoveryURL: "https://accounts.example.com/.well-known/openid-configuration", + want: "https://accounts.example.com", + }, + { + name: "oauth well-known suffix with tenant path prefix", + discoveryURL: "https://idp.example.com/tenants/acme/.well-known/oauth-authorization-server", + want: "https://idp.example.com/tenants/acme", + }, + { + name: "oidc well-known suffix with tenant path prefix", + discoveryURL: "https://idp.example.com/tenants/acme/.well-known/openid-configuration", + want: "https://idp.example.com/tenants/acme", + }, + { + name: "non-well-known path falls back to origin", + discoveryURL: "https://idp.example.com/tenants/acme/metadata", + want: "https://idp.example.com", + }, + { + name: "query and fragment are stripped", + discoveryURL: "https://idp.example.com/.well-known/oauth-authorization-server?x=1#frag", + want: "https://idp.example.com", + }, + { + name: "empty url is rejected", + discoveryURL: "", + wantErr: true, + }, + { + name: "missing scheme is rejected", 
+ discoveryURL: "idp.example.com/.well-known/oauth-authorization-server", + wantErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + got, err := deriveExpectedIssuerFromDiscoveryURL(tc.discoveryURL) + if tc.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tc.want, got) + }) + } +} + +// countingStore is a DCRCredentialStore decorator that counts the number of +// Get calls that returned a hit. The singleflight coalescing test uses it +// to assert that no concurrent caller observed a cache hit during the run: +// a hit during the test would mean a goroutine raced past the gate, took +// the cache-lookup short-circuit instead of joining the singleflight, and +// silently weakened the test's coverage. +type countingStore struct { + inner DCRCredentialStore + hits atomic.Int32 +} + +func (c *countingStore) Get(ctx context.Context, key DCRKey) (*DCRResolution, bool, error) { + res, ok, err := c.inner.Get(ctx, key) + if ok { + c.hits.Add(1) + } + return res, ok, err +} + +func (c *countingStore) Put(ctx context.Context, key DCRKey, res *DCRResolution) error { + return c.inner.Put(ctx, key, res) +} + +// TestResolveDCRCredentials_SingleflightCoalescesConcurrentCallers pins the +// behaviour that N concurrent callers for the same DCRKey result in exactly +// one RegisterClientDynamically call against the upstream — preventing the +// orphaned-registration class of bug raised in PR #5042 review. +// +// "Exactly one registration" is necessary but not sufficient to prove the +// singleflight coalescing path actually fired: a late-arriving goroutine +// that reached resolveDCRCredentials after the leader's cache.Put would +// short-circuit through lookupCachedResolution, take the cache hit, and +// still leave registrationCalls == 1. 
A countingStore wrapper makes that +// regression loud — we assert no caller observed a cache hit, so any timing +// slip fails the test instead of silently weakening coverage. +func TestResolveDCRCredentials_SingleflightCoalescesConcurrentCallers(t *testing.T) { + t.Parallel() + + // gate blocks the registration handler until the test releases it, + // guaranteeing all goroutines pile up at the singleflight before any + // has a chance to finish and populate the cache. + gate := make(chan struct{}) + + var registrationCalls int32 + server := newDCRTestServer(t, dcrTestHandlerConfig{ + observeRegistration: func(_ *http.Request, _ []byte) { + <-gate + atomic.AddInt32(&registrationCalls, 1) + }, + }) + + cache := &countingStore{inner: NewInMemoryDCRCredentialStore()} + issuer := server.URL + rc := &authserver.OAuth2UpstreamRunConfig{ + Scopes: []string{"openid", "profile"}, + DCRConfig: &authserver.DCRUpstreamConfig{ + DiscoveryURL: issuer + "/.well-known/oauth-authorization-server", + }, + } + + const N = 8 + results := make([]*DCRResolution, N) + errs := make([]error, N) + var wg sync.WaitGroup + wg.Add(N) + for i := 0; i < N; i++ { + go func(idx int) { + defer wg.Done() + res, err := resolveDCRCredentials(context.Background(), rc, issuer, cache) + results[idx] = res + errs[idx] = err + }(i) + } + + // Release the gate so every blocked handler can proceed. Even if Go + // scheduled the leader's handler concurrently with the followers' + // arrival, only the leader actually invokes the handler — the followers + // wait inside singleflight.Do. + // + // 250 ms gives every goroutine slack to reach singleflight.Do under CI + // load before the gate releases. If this still races, the countingStore + // assertion below fails loudly rather than silently weakening coverage.
+ time.Sleep(250 * time.Millisecond) + close(gate) + + done := make(chan struct{}) + go func() { wg.Wait(); close(done) }() + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for concurrent resolveDCRCredentials goroutines") + } + + for i := 0; i < N; i++ { + require.NoError(t, errs[i], "goroutine %d errored", i) + require.NotNil(t, results[i], "goroutine %d got nil resolution", i) + assert.Equal(t, "test-client-id", results[i].ClientID) + } + assert.EqualValues(t, 1, atomic.LoadInt32(&registrationCalls), + "expected exactly one registration despite %d concurrent callers; got %d", + N, atomic.LoadInt32(&registrationCalls)) + assert.EqualValues(t, 0, cache.hits.Load(), + "no goroutine should have observed a cache hit; if any did, the gate window "+ + "was too short and a late-arriver took the lookupCachedResolution "+ + "short-circuit instead of exercising the singleflight coalescing path") +} + +// TestSynthesiseRegistrationEndpoint_PreservesIssuerPath guards the fix for +// PR #5042 review comment #2: an issuer with a tenant prefix must surface +// in the synthesised registration URL so DCR-on-multi-tenant providers +// register at the correct tenant-aware path.
+func TestSynthesiseRegistrationEndpoint_PreservesIssuerPath(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + issuer string + want string + }{ + { + name: "host-only issuer", + issuer: "https://idp.example.com", + want: "https://idp.example.com/register", + }, + { + name: "trailing slash on host-only issuer is normalised", + issuer: "https://idp.example.com/", + want: "https://idp.example.com/register", + }, + { + name: "tenant prefix preserved", + issuer: "https://idp.example.com/tenants/acme", + want: "https://idp.example.com/tenants/acme/register", + }, + { + name: "tenant prefix with trailing slash normalised", + issuer: "https://idp.example.com/tenants/acme/", + want: "https://idp.example.com/tenants/acme/register", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + got, err := synthesiseRegistrationEndpoint(tc.issuer) + require.NoError(t, err) + assert.Equal(t, tc.want, got) + }) + } +} + +// TestResolveUpstreamRedirectURI_PreservesIssuerPath is the companion to +// TestSynthesiseRegistrationEndpoint_PreservesIssuerPath for the redirect +// URI defaulting path: a tenant-prefixed issuer must not get its path +// stripped when /oauth/callback is appended. 
+func TestResolveUpstreamRedirectURI_PreservesIssuerPath(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + issuer string + want string + }{ + { + name: "host-only issuer", + issuer: "https://thv.example.com", + want: "https://thv.example.com/oauth/callback", + }, + { + name: "tenant prefix preserved", + issuer: "https://thv.example.com/tenants/acme", + want: "https://thv.example.com/tenants/acme/oauth/callback", + }, + { + name: "trailing slash normalised", + issuer: "https://thv.example.com/tenants/acme/", + want: "https://thv.example.com/tenants/acme/oauth/callback", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + got, err := resolveUpstreamRedirectURI("", tc.issuer) + require.NoError(t, err) + assert.Equal(t, tc.want, got) + }) + } +} + +// TestApplyResolution_DoesNotOverwritePreProvisionedClientID verifies the +// defence-in-depth in applyResolution: a caller that bypasses +// validateResolveInputs and invokes applyResolution directly with a +// pre-provisioned ClientID does not have it silently clobbered. +func TestApplyResolution_DoesNotOverwritePreProvisionedClientID(t *testing.T) { + t.Parallel() + + rc := &authserver.OAuth2UpstreamRunConfig{ + ClientID: "pre-provisioned", + } + res := &DCRResolution{ + ClientID: "would-be-overwrite", + } + applyResolution(rc, res) + assert.Equal(t, "pre-provisioned", rc.ClientID, + "applyResolution must not overwrite a non-empty ClientID") +} + +// TestResolveDCREndpoints_DirectRegistrationEndpointValidated covers +// PR #5042 review comment #10: the cfg.RegistrationEndpoint short-circuit +// branch validates the URL locally before performRegistration constructs a +// bearer-token transport for it. Non-HTTPS or malformed values must be +// rejected up front, not deep inside oauthproto. 
+func TestResolveDCREndpoints_DirectRegistrationEndpointValidated(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + registrationEndpoint string + wantErrSub string + }{ + { + name: "http non-loopback rejected", + registrationEndpoint: "http://idp.example.com/register", + wantErrSub: "must use https", + }, + { + name: "missing scheme rejected", + registrationEndpoint: "idp.example.com/register", + wantErrSub: "missing scheme or host", + }, + { + name: "loopback http accepted", + registrationEndpoint: "http://127.0.0.1:8080/register", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + cfg := &authserver.DCRUpstreamConfig{RegistrationEndpoint: tc.registrationEndpoint} + _, err := resolveDCREndpoints(context.Background(), cfg) + if tc.wantErrSub == "" { + require.NoError(t, err) + return + } + require.Error(t, err) + assert.Contains(t, err.Error(), tc.wantErrSub) + }) + } +} + +// TestEndpointsFromMetadata_RejectsInsecureDiscoveredEndpoints covers +// PR #5042 review comment #13: a self-consistent metadata document that +// advertises an http:// authorization or token endpoint must be rejected +// rather than silently flowing through to the auth-code/token-exchange +// path. A compromised TLS connection to the metadata host is the threat +// model. 
+func TestEndpointsFromMetadata_RejectsInsecureDiscoveredEndpoints(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + metadata *oauthproto.AuthorizationServerMetadata + wantErrSub string + }{ + { + name: "http authorization_endpoint rejected", + metadata: &oauthproto.AuthorizationServerMetadata{ + Issuer: "https://idp.example.com", + AuthorizationEndpoint: "http://idp.example.com/authorize", + TokenEndpoint: "https://idp.example.com/token", + RegistrationEndpoint: "https://idp.example.com/register", + }, + wantErrSub: "authorization_endpoint", + }, + { + name: "http token_endpoint rejected", + metadata: &oauthproto.AuthorizationServerMetadata{ + Issuer: "https://idp.example.com", + AuthorizationEndpoint: "https://idp.example.com/authorize", + TokenEndpoint: "http://idp.example.com/token", + RegistrationEndpoint: "https://idp.example.com/register", + }, + wantErrSub: "token_endpoint", + }, + { + name: "missing authorization_endpoint rejected", + metadata: &oauthproto.AuthorizationServerMetadata{ + Issuer: "https://idp.example.com", + TokenEndpoint: "https://idp.example.com/token", + RegistrationEndpoint: "https://idp.example.com/register", + }, + wantErrSub: "authorization_endpoint is required", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + _, err := endpointsFromMetadata(tc.metadata, nil, "https://idp.example.com") + require.Error(t, err) + assert.Contains(t, err.Error(), tc.wantErrSub) + }) + } +} + +// failingDCRStore is a test double whose Get and Put always fail. Used by +// TestResolveDCRCredentials_CacheFailureWraps* below to exercise the wrap +// messages that operators see when the store backend errors at runtime. 
+type failingDCRStore struct { + getErr error + putErr error +} + +func (f failingDCRStore) Get(_ context.Context, _ DCRKey) (*DCRResolution, bool, error) { + if f.getErr != nil { + return nil, false, f.getErr + } + return nil, false, nil +} + +func (f failingDCRStore) Put(_ context.Context, _ DCRKey, _ *DCRResolution) error { + return f.putErr +} + +// TestResolveDCRCredentials_CacheGetFailureWrapped covers PR #5042 review +// comment #12 for the cache.Get error path. When the store backend fails +// (e.g. a Redis network error in Phase 3), the resolver wraps the error +// with the operator-debugging contract message "dcr: cache lookup". +func TestResolveDCRCredentials_CacheGetFailureWrapped(t *testing.T) { + t.Parallel() + + storeErr := errors.New("simulated backend failure") + store := failingDCRStore{getErr: storeErr} + + rc := &authserver.OAuth2UpstreamRunConfig{ + DCRConfig: &authserver.DCRUpstreamConfig{ + RegistrationEndpoint: "https://idp.example.com/register", + }, + } + + _, err := resolveDCRCredentials(context.Background(), rc, "https://idp.example.com", store) + require.Error(t, err) + assert.ErrorIs(t, err, storeErr, + "cache.Get error must be wrapped with %%w so callers can inspect the cause") + assert.Contains(t, err.Error(), "dcr: cache lookup", + "the wrap message is part of the operator-debugging contract") +} + +// TestResolveDCRCredentials_CachePutFailureWrapped covers PR #5042 review +// comment #12 for the cache.Put error path. The path runs after a +// successful registration, so we route the test through a real upstream +// httptest server and only make Put fail. 
+func TestResolveDCRCredentials_CachePutFailureWrapped(t *testing.T) { + t.Parallel() + + server := newDCRTestServer(t, dcrTestHandlerConfig{}) + + storeErr := errors.New("simulated put backend failure") + store := failingDCRStore{putErr: storeErr} + + rc := &authserver.OAuth2UpstreamRunConfig{ + Scopes: []string{"openid"}, + DCRConfig: &authserver.DCRUpstreamConfig{ + DiscoveryURL: server.URL + "/.well-known/oauth-authorization-server", + }, + } + + _, err := resolveDCRCredentials(context.Background(), rc, server.URL, store) + require.Error(t, err) + assert.ErrorIs(t, err, storeErr, + "cache.Put error must be wrapped with %%w so callers can inspect the cause") + assert.Contains(t, err.Error(), "dcr: cache put", + "the wrap message is part of the operator-debugging contract") +} + +// TestBuildResolution_PopulatesRFC7591ExpiryFields covers the conversion of +// the int64 epoch fields client_id_issued_at and client_secret_expires_at +// into time.Time on DCRResolution. The wire convention "0 means absent / +// does not expire" is preserved as the zero time.Time. 
+func TestBuildResolution_PopulatesRFC7591ExpiryFields(t *testing.T) { + t.Parallel() + + const ( + issuedEpoch int64 = 1_700_000_000 // 2023-11-14T22:13:20Z + expiresEpoch int64 = 1_800_000_000 // 2027-01-15T08:00:00Z + ) + + tests := []struct { + name string + issuedAt int64 + expiresAt int64 + wantIssuedAt time.Time + wantExpiresAt time.Time + }{ + { + name: "both fields populated", + issuedAt: issuedEpoch, + expiresAt: expiresEpoch, + wantIssuedAt: time.Unix(issuedEpoch, 0).UTC(), + wantExpiresAt: time.Unix(expiresEpoch, 0).UTC(), + }, + { + name: "client_secret_expires_at zero means does-not-expire", + issuedAt: issuedEpoch, + expiresAt: 0, + wantIssuedAt: time.Unix(issuedEpoch, 0).UTC(), + wantExpiresAt: time.Time{}, + }, + { + name: "both fields omitted by upstream", + issuedAt: 0, + expiresAt: 0, + wantIssuedAt: time.Time{}, + wantExpiresAt: time.Time{}, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + resolution := buildResolution( + &oauthproto.DynamicClientRegistrationResponse{ + ClientID: "id", + ClientSecret: "secret", + ClientIDIssuedAt: tc.issuedAt, + ClientSecretExpiresAt: tc.expiresAt, + }, + &dcrEndpoints{ + authorizationEndpoint: "https://idp.example.com/authorize", + tokenEndpoint: "https://idp.example.com/token", + }, + "client_secret_basic", + ) + assert.Equal(t, tc.wantIssuedAt, resolution.ClientIDIssuedAt) + assert.Equal(t, tc.wantExpiresAt, resolution.ClientSecretExpiresAt) + }) + } +} + +// TestResolveDCRCredentials_RefetchesOnExpiredCachedSecret pins the fix for +// the cache-serves-expired-secrets bug: when an entry's +// ClientSecretExpiresAt has passed, lookupCachedResolution treats it as a +// miss so registerAndCache re-runs and overwrites the stale entry. Without +// this, the cached secret would be served indefinitely past the upstream- +// asserted expiry and every token-endpoint call would 401 with no signal +// back to the resolver. 
+func TestResolveDCRCredentials_RefetchesOnExpiredCachedSecret(t *testing.T) { + t.Parallel() + + var registrationCalls int32 + server := newDCRTestServer(t, dcrTestHandlerConfig{ + // Issue a secret that expired one minute ago. Every fresh + // registration call will produce an already-expired entry; the + // resolver will refetch on every Resolve as a result. + clientSecretExpiresAt: time.Now().Add(-time.Minute).Unix(), + observeRegistration: func(_ *http.Request, _ []byte) { + atomic.AddInt32(&registrationCalls, 1) + }, + }) + + cache := NewInMemoryDCRCredentialStore() + issuer := server.URL + rc := &authserver.OAuth2UpstreamRunConfig{ + Scopes: []string{"openid"}, + DCRConfig: &authserver.DCRUpstreamConfig{ + DiscoveryURL: issuer + "/.well-known/oauth-authorization-server", + }, + } + + // First call: registers, populates cache with already-expired entry. + res1, err := resolveDCRCredentials(context.Background(), rc, issuer, cache) + require.NoError(t, err) + require.NotNil(t, res1) + require.False(t, res1.ClientSecretExpiresAt.IsZero(), + "upstream advertised an expiry — the resolution must echo it") + require.True(t, time.Now().After(res1.ClientSecretExpiresAt), + "test setup should have produced an already-expired secret") + require.EqualValues(t, 1, atomic.LoadInt32(&registrationCalls)) + + // Second call: the cached entry is expired, so the resolver must refetch.
+ res2, err := resolveDCRCredentials(context.Background(), rc, issuer, cache) + require.NoError(t, err) + require.NotNil(t, res2) + assert.EqualValues(t, 2, atomic.LoadInt32(&registrationCalls), + "expired cache entry must trigger a re-registration; got %d total calls", + atomic.LoadInt32(&registrationCalls)) +} + +// TestResolveDCRCredentials_HonoursFutureExpiryAndZero pins that +// lookupCachedResolution does NOT refetch when the cached secret is still +// valid — either because the upstream-asserted expiry is in the future, or +// because the upstream omitted client_secret_expires_at (zero ⇒ "does not +// expire" per RFC 7591 §3.2.1). The cache hit path is the hot path and a +// regression here would silently increase upstream load. +func TestResolveDCRCredentials_HonoursFutureExpiryAndZero(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + expiresAt int64 + }{ + {name: "future expiry served from cache", expiresAt: time.Now().Add(time.Hour).Unix()}, + {name: "zero (does not expire) served from cache", expiresAt: 0}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + var registrationCalls int32 + server := newDCRTestServer(t, dcrTestHandlerConfig{ + clientSecretExpiresAt: tc.expiresAt, + observeRegistration: func(_ *http.Request, _ []byte) { + atomic.AddInt32(&registrationCalls, 1) + }, + }) + cache := NewInMemoryDCRCredentialStore() + issuer := server.URL + rc := &authserver.OAuth2UpstreamRunConfig{ + Scopes: []string{"openid"}, + DCRConfig: &authserver.DCRUpstreamConfig{ + DiscoveryURL: issuer + "/.well-known/oauth-authorization-server", + }, + } + + _, err := resolveDCRCredentials(context.Background(), rc, issuer, cache) + require.NoError(t, err) + _, err = resolveDCRCredentials(context.Background(), rc, issuer, cache) + require.NoError(t, err) + + assert.EqualValues(t, 1, atomic.LoadInt32(&registrationCalls), + "second call must hit the cache; got %d total registrations", + atomic.LoadInt32(&registrationCalls)) + })
+ } +} + +// panickingPutDCRStore is a test double whose Put panics with a fixed +// value. Get is a normal cache miss so callers reach the singleflight +// closure and trigger the panic via cache.Put inside registerAndCache. +type panickingPutDCRStore struct { + panicValue any +} + +func (panickingPutDCRStore) Get(_ context.Context, _ DCRKey) (*DCRResolution, bool, error) { + return nil, false, nil +} + +func (s panickingPutDCRStore) Put(_ context.Context, _ DCRKey, _ *DCRResolution) error { + panic(s.panicValue) +} + +// TestResolveDCRCredentials_RecoversPanicInsideSingleflight pins the +// behaviour that a panic inside the singleflight closure does not propagate +// up as a panic to either the leader goroutine or any of the followers. +// singleflight.Group re-panics the leader's panic in every follower, so +// without the recover N concurrent callers for the same DCRKey would all +// crash with the same value. The defer/recover converts the panic to a +// normal error, the panic is logged at Error with a stack, and every +// caller gets the same wrapped error. +func TestResolveDCRCredentials_RecoversPanicInsideSingleflight(t *testing.T) { + t.Parallel() + + server := newDCRTestServer(t, dcrTestHandlerConfig{}) + store := panickingPutDCRStore{panicValue: "boom"} + + issuer := server.URL + rc := &authserver.OAuth2UpstreamRunConfig{ + Scopes: []string{"openid"}, + DCRConfig: &authserver.DCRUpstreamConfig{ + DiscoveryURL: issuer + "/.well-known/oauth-authorization-server", + }, + } + + const N = 6 + var wg sync.WaitGroup + wg.Add(N) + errs := make([]error, N) + panicked := make([]bool, N) + + for i := 0; i < N; i++ { + go func(idx int) { + defer wg.Done() + defer func() { + // If the recover inside the singleflight closure is + // missing, the panic re-propagates here. Capture it so + // the assertion below produces a clear failure message + // rather than a runtime crash that taints other tests. 
+ if r := recover(); r != nil { + panicked[idx] = true + } + }() + _, errs[idx] = resolveDCRCredentials(context.Background(), rc, issuer, store) + }(i) + } + + done := make(chan struct{}) + go func() { wg.Wait(); close(done) }() + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for concurrent callers") + } + + for i := 0; i < N; i++ { + require.False(t, panicked[i], + "goroutine %d observed an un-recovered panic from the singleflight closure", i) + require.Error(t, errs[i], + "goroutine %d should have received an error converted from the panic", i) + assert.Contains(t, errs[i].Error(), "panicked", + "goroutine %d's error must mention the panic so operators can correlate; got %q", + i, errs[i].Error()) + assert.Contains(t, errs[i].Error(), "boom", + "goroutine %d's error must include the panic value so the cause is recoverable", i) + } +} diff --git a/pkg/oauthproto/constants.go b/pkg/oauthproto/constants.go index f51caa1048..1312ddb2b1 100644 --- a/pkg/oauthproto/constants.go +++ b/pkg/oauthproto/constants.go @@ -85,3 +85,11 @@ const ( defaultHTTPTimeout = 30 * time.Second maxResponseBodySize = 1 << 20 // 1 MiB — matches x/oauth2/internal/token.go. ) + +// URL scheme constants. +const ( + // schemeHTTPS is the URL scheme required for all OAuth / OIDC endpoints, + // except when the host is a loopback address (development). Unexported + // so the check stays internally consistent within this package. 
+ schemeHTTPS = "https" +) diff --git a/pkg/oauthproto/dcr.go b/pkg/oauthproto/dcr.go index b9a454ff82..0757658f7f 100644 --- a/pkg/oauthproto/dcr.go +++ b/pkg/oauthproto/dcr.go @@ -199,7 +199,7 @@ func validateRegistrationEndpoint(registrationEndpoint string) (*url.URL, error) } // Ensure HTTPS for security (except loopback addresses for development) - if registrationURL.Scheme != "https" && !IsLoopbackHost(registrationURL.Host) { + if registrationURL.Scheme != schemeHTTPS && !IsLoopbackHost(registrationURL.Host) { return nil, fmt.Errorf("registration endpoint must use HTTPS: %s", registrationEndpoint) } @@ -263,12 +263,19 @@ func createHTTPRequest( return req, nil } -// buildHTTPClient returns the caller-supplied client, or a default client if nil. -func buildHTTPClient(client *http.Client) *http.Client { - if client != nil { - return client - } - +// NewDefaultDCRClient returns the canonical bounded *http.Client used by +// RegisterClientDynamically when its caller does not supply one. It is +// exported so callers that need to wrap the transport (for example, to +// inject an RFC 7591 initial access token as an Authorization header) can +// reuse the same timeout policy and benefit automatically from any future +// tightening of these bounds. +// +// Timeouts: +// +// - Overall request timeout: 30 s +// - TLS handshake timeout: 10 s +// - Response-header timeout: 10 s +func NewDefaultDCRClient() *http.Client { return &http.Client{ Timeout: 30 * time.Second, Transport: &http.Transport{ @@ -278,6 +285,14 @@ func buildHTTPClient(client *http.Client) *http.Client { } } +// buildHTTPClient returns the caller-supplied client, or a default client if nil. +func buildHTTPClient(client *http.Client) *http.Client { + if client != nil { + return client + } + return NewDefaultDCRClient() +} + // handleHTTPResponse handles the HTTP response and validates it. 
func handleHTTPResponse(resp *http.Response) (*DynamicClientRegistrationResponse, error) { defer func() { diff --git a/pkg/oauthproto/discovery.go b/pkg/oauthproto/discovery.go index 042f7291f7..481edb1e59 100644 --- a/pkg/oauthproto/discovery.go +++ b/pkg/oauthproto/discovery.go @@ -144,7 +144,7 @@ func buildDiscoveryURLs(issuer string) ([]string, error) { if parsed.Scheme == "" || parsed.Host == "" { return nil, fmt.Errorf("invalid issuer URL: scheme and host are required") } - if parsed.Scheme != "https" && !IsLoopbackHost(parsed.Host) { + if parsed.Scheme != schemeHTTPS && !IsLoopbackHost(parsed.Host) { return nil, fmt.Errorf("issuer must use https (got %q)", parsed.Scheme) } @@ -201,6 +201,69 @@ func buildDiscoveryHTTPClient(client *http.Client) *http.Client { } } +// FetchAuthorizationServerMetadataFromURL fetches RFC 8414 authorization +// server metadata from a single caller-supplied URL, bypassing the +// well-known-path fallback used by FetchAuthorizationServerMetadata. +// +// It is intended for cases where the operator has configured an explicit +// discovery document URL (for example a multi-tenant IdP that does not +// advertise the tenant-aware path at {issuer}/.well-known/...). The same +// RFC 8414 §3.3 issuer-equality check is enforced: the returned metadata's +// issuer field must exactly match the caller-supplied expectedIssuer. +// +// If client is nil, the same bounded default client used by +// FetchAuthorizationServerMetadata is constructed. A 10 s per-call timeout +// is applied via context.WithTimeout regardless of the caller's context +// deadline. +// +// Return contract mirrors FetchAuthorizationServerMetadata: +// +// - On full success, returns (metadata, nil) with a non-empty +// RegistrationEndpoint. +// - When the document is otherwise valid but omits +// registration_endpoint, returns (metadata, ErrRegistrationEndpointMissing). +// The metadata is non-nil so callers can reuse the other fields. 
+// - On any other failure (transport/decode error, issuer mismatch), +// returns (nil, err). +func FetchAuthorizationServerMetadataFromURL( + ctx context.Context, + discoveryURL string, + expectedIssuer string, + client *http.Client, +) (*AuthorizationServerMetadata, error) { + if discoveryURL == "" { + return nil, fmt.Errorf("discovery URL is required") + } + if expectedIssuer == "" { + return nil, fmt.Errorf("expected issuer is required") + } + + parsed, err := url.Parse(discoveryURL) + if err != nil { + return nil, fmt.Errorf("invalid discovery URL: %w", err) + } + if parsed.Scheme == "" || parsed.Host == "" { + return nil, fmt.Errorf("invalid discovery URL: scheme and host are required") + } + if parsed.Scheme != schemeHTTPS && !IsLoopbackHost(parsed.Host) { + return nil, fmt.Errorf("discovery URL must use https (got %q)", parsed.Scheme) + } + + httpClient := buildDiscoveryHTTPClient(client) + + fetchCtx, cancel := context.WithTimeout(ctx, discoveryTimeout) + defer cancel() + + metadata, err := fetchDiscoveryDocument(fetchCtx, httpClient, discoveryURL, expectedIssuer) + if err != nil { + return nil, fmt.Errorf("fetch discovery document from %q: %w", discoveryURL, err) + } + if metadata.RegistrationEndpoint == "" { + return metadata, ErrRegistrationEndpointMissing + } + return metadata, nil +} + // fetchDiscoveryDocument performs a single GET against a discovery URL and // returns the parsed AuthorizationServerMetadata, enforcing RFC 8414 §3.3 // issuer equality. 
diff --git a/pkg/oauthproto/discovery_test.go b/pkg/oauthproto/discovery_test.go index 085cc422cf..c09efac835 100644 --- a/pkg/oauthproto/discovery_test.go +++ b/pkg/oauthproto/discovery_test.go @@ -11,6 +11,7 @@ import ( "net/http/httptest" "strings" "sync" + "sync/atomic" "testing" "time" @@ -493,6 +494,175 @@ func TestFetchAuthorizationServerMetadata_TenantWithEscapedChars(t *testing.T) { assert.Equal(t, issuer, metadata.Issuer) } +func TestFetchAuthorizationServerMetadataFromURL(t *testing.T) { + t.Parallel() + + var issuer string + var wellKnownHits int32 + var customHits int32 + mux := http.NewServeMux() + mux.HandleFunc("/.well-known/oauth-authorization-server", func(_ http.ResponseWriter, _ *http.Request) { + // Tripwire — must not be contacted when caller pins an exact URL. + atomic.AddInt32(&wellKnownHits, 1) + }) + mux.HandleFunc("/tenants/acme/metadata", func(w http.ResponseWriter, _ *http.Request) { + atomic.AddInt32(&customHits, 1) + md := AuthorizationServerMetadata{ + Issuer: issuer, + AuthorizationEndpoint: issuer + "/authorize", + TokenEndpoint: issuer + "/token", + JWKSURI: issuer + "/jwks", + RegistrationEndpoint: issuer + "/register", + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(md) + }) + server := httptest.NewServer(mux) + t.Cleanup(server.Close) + issuer = server.URL + + metadata, err := FetchAuthorizationServerMetadataFromURL( + context.Background(), + issuer+"/tenants/acme/metadata", + issuer, + server.Client(), + ) + require.NoError(t, err) + require.NotNil(t, metadata) + assert.Equal(t, issuer, metadata.Issuer) + assert.Equal(t, issuer+"/register", metadata.RegistrationEndpoint) + assert.EqualValues(t, 1, atomic.LoadInt32(&customHits), + "caller-supplied discovery URL must be fetched exactly once") + assert.EqualValues(t, 0, atomic.LoadInt32(&wellKnownHits), + "well-known fallback must not be contacted") +} + +func TestFetchAuthorizationServerMetadataFromURL_IssuerMismatchRejected(t *testing.T) 
{ + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + md := AuthorizationServerMetadata{ + Issuer: "https://different.example.com", + TokenEndpoint: "https://different.example.com/token", + RegistrationEndpoint: "https://different.example.com/register", + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(md) + })) + t.Cleanup(server.Close) + + _, err := FetchAuthorizationServerMetadataFromURL( + context.Background(), + server.URL+"/metadata", + server.URL, // expected issuer disagrees with server-advertised issuer + server.Client(), + ) + require.Error(t, err) + assert.Contains(t, err.Error(), "issuer mismatch") +} + +func TestFetchAuthorizationServerMetadataFromURL_MissingRegistrationEndpoint(t *testing.T) { + t.Parallel() + + var issuer string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + md := AuthorizationServerMetadata{ + Issuer: issuer, + TokenEndpoint: issuer + "/token", + // RegistrationEndpoint intentionally omitted. 
+ } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(md) + })) + t.Cleanup(server.Close) + issuer = server.URL + + metadata, err := FetchAuthorizationServerMetadataFromURL( + context.Background(), + issuer+"/metadata", + issuer, + server.Client(), + ) + require.ErrorIs(t, err, ErrRegistrationEndpointMissing) + require.NotNil(t, metadata) + assert.Equal(t, issuer, metadata.Issuer) + assert.Equal(t, issuer+"/token", metadata.TokenEndpoint) + assert.Empty(t, metadata.RegistrationEndpoint) +} + +func TestFetchAuthorizationServerMetadataFromURL_InputValidation(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + discoveryURL string + issuer string + wantErrMsg string + }{ + { + name: "empty discovery URL", + discoveryURL: "", + issuer: "https://example.com", + wantErrMsg: "discovery URL is required", + }, + { + name: "empty issuer", + discoveryURL: "https://example.com/metadata", + issuer: "", + wantErrMsg: "expected issuer is required", + }, + { + name: "http non-loopback discovery URL rejected", + discoveryURL: "http://example.com/metadata", + issuer: "http://example.com", + wantErrMsg: "discovery URL must use https", + }, + { + name: "missing scheme", + discoveryURL: "example.com/metadata", + issuer: "https://example.com", + wantErrMsg: "scheme and host are required", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + _, err := FetchAuthorizationServerMetadataFromURL( + context.Background(), tc.discoveryURL, tc.issuer, nil, + ) + require.Error(t, err) + assert.Contains(t, err.Error(), tc.wantErrMsg) + }) + } +} + +func TestFetchAuthorizationServerMetadataFromURL_AllowsLoopbackHTTP(t *testing.T) { + t.Parallel() + + var issuer string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + md := AuthorizationServerMetadata{ + Issuer: issuer, + TokenEndpoint: issuer + "/token", + RegistrationEndpoint: issuer + "/register", + } + 
w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(md) + })) + t.Cleanup(server.Close) + issuer = server.URL + + metadata, err := FetchAuthorizationServerMetadataFromURL( + context.Background(), + issuer+"/metadata", + issuer, + server.Client(), + ) + require.NoError(t, err) + require.NotNil(t, metadata) + assert.Equal(t, issuer, metadata.Issuer) +} + // TestFetchAuthorizationServerMetadata_DedupesPathInsertionAndBare locks in // the documented behavior that, for a tenant-less issuer, the path-insertion // (1) and bare RFC 8414 (3) URLs collapse to the same request, so only two