diff --git a/pkg/auth/monitored_token_source.go b/pkg/auth/monitored_token_source.go index 75e0d0f395..333a16dfa1 100644 --- a/pkg/auth/monitored_token_source.go +++ b/pkg/auth/monitored_token_source.go @@ -9,6 +9,7 @@ import ( "fmt" "log/slog" "net" + "net/http" "os" "strconv" "strings" @@ -249,11 +250,11 @@ func (mts *MonitoredTokenSource) Stopped() <-chan struct{} { return mts.stopped } -// Token retrieves a token, retrying with exponential backoff on transient errors -// (see isTransientNetworkError for the full list). On non-transient errors -// (OAuth 4xx, TLS failures) it marks the workload as unauthenticated and returns -// immediately. Context cancellation (workload removal) stops the retry without -// marking the workload as unauthenticated. +// Token retrieves a token, retrying with exponential backoff on transient +// errors and marking the workload as unauthenticated on non-transient errors. +// See isTransientNetworkError for the classification rule. Context +// cancellation (workload removal) stops the retry without marking the +// workload as unauthenticated. // // Concurrent callers are deduplicated via singleflight so that only one retry // loop runs at a time during transient failures. @@ -342,30 +343,28 @@ func (mts *MonitoredTokenSource) onTick() (bool, time.Duration) { return false, wait } -// isTransientNetworkError reports whether err represents a transient condition -// (DNS failure, TCP transport error, timeout, OAuth server 5xx, unparsable -// token response) that is likely to resolve on its own. +// isTransientNetworkError reports whether err represents a transient +// condition that is likely to resolve on its own. The categories are: // -// OAuth2 client-level auth failures (invalid_grant, 401, 400) and TLS errors -// (certificate verification, handshake failure) are NOT considered transient and -// return false so the workload is marked unauthenticated immediately. +// - Network-level failures: DNS lookup errors, TCP transport errors, +// timeouts. +// - OAuth token-endpoint responses classified as transient by +// classifyOAuthRetrieveError. +// - Unparsable token responses on a 2xx status (typically an HTML page +// from a load balancer or CDN). +// +// All other errors return false, causing the workload to be marked +// unauthenticated. TLS failures (certificate verification, handshake +// failure) are intentionally non-transient even though they surface +// through net.OpError like transport-level errors. func isTransientNetworkError(err error) bool { if err == nil || errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { return false } - // OAuth HTTP-level errors: 5xx (Bad Gateway, Service Unavailable, Gateway - // Timeout) are transient server-side issues that typically resolve on their - // own. 4xx errors (invalid_grant, invalid_client) are permanent auth failures. if retrieveErr, ok := errors.AsType[*oauth2.RetrieveError](err); ok { - if retrieveErr.Response != nil && retrieveErr.Response.StatusCode >= 500 { - slog.Debug("treating OAuth server error as transient", - "status_code", retrieveErr.Response.StatusCode, - ) - return true - } - return false + return classifyOAuthRetrieveError(retrieveErr) } // Non-JSON responses from the OAuth server (e.g. load balancer HTML pages). @@ -399,6 +398,54 @@ func isTransientNetworkError(err error) bool { return false } +// classifyOAuthRetrieveError reports whether an *oauth2.RetrieveError should +// be treated as transient. The classification rules are: +// +// - nil Response: non-transient. There is no signal to act on, so we fall +// through to the unauthenticated path rather than retry blindly. +// - 5xx status: transient (server-side issue, likely to resolve). +// - 429 Too Many Requests: transient regardless of body (HTTP standard). +// - 4xx with an empty ErrorCode: transient. The oauth2 library populates +// ErrorCode from the RFC 6749 'error' field in a JSON response body. An +// empty ErrorCode means the response was not a parseable OAuth error — +// typically an HTML page from a WAF, CDN, or reverse proxy that +// intercepted the request before it reached the OAuth server. These +// infrastructure errors (Cloudflare blocks, residential-IP allowlist +// misses, transient bad-config deploys) commonly resolve on their own. +// - 4xx with a populated ErrorCode: permanent. The OAuth server returned +// a structured error code (invalid_grant, invalid_client, etc.) telling +// us specifically what's wrong; retrying won't help. +func classifyOAuthRetrieveError(retrieveErr *oauth2.RetrieveError) bool { + if retrieveErr.Response == nil { + return false + } + statusCode := retrieveErr.Response.StatusCode + + if statusCode >= 500 { + slog.Debug("treating OAuth server error as transient", + "status_code", statusCode, + ) + return true + } + + if statusCode == http.StatusTooManyRequests { + slog.Debug("treating OAuth rate-limit response as transient", + "status_code", statusCode, + "error_code", retrieveErr.ErrorCode, + ) + return true + } + + if retrieveErr.ErrorCode == "" { + slog.Debug("treating OAuth 4xx without error code as transient", + "status_code", statusCode, + ) + return true + } + + return false +} + // isOAuthParseError detects errors from the oauth2 library that indicate the // token endpoint returned an unparsable response body on a 2xx status. This // typically happens when a load balancer, CDN, or reverse proxy intercepts the diff --git a/pkg/auth/monitored_token_source_test.go b/pkg/auth/monitored_token_source_test.go index 6c08b527ad..af45e52ee7 100644 --- a/pkg/auth/monitored_token_source_test.go +++ b/pkg/auth/monitored_token_source_test.go @@ -9,6 +9,7 @@ import ( "fmt" "net" "net/http" + "net/http/httptest" "net/url" "os" "strings" @@ -86,7 +87,10 @@ func (m *mockTokenSource) Token() (*oauth2.Token, error) { return tok, err } -// createRetrieveError creates an error for testing token failures +// createRetrieveError creates an error for testing token failures. ErrorCode +// is left unset, mirroring what golang.org/x/oauth2 produces when the response +// body is not a parseable RFC 6749 error response (e.g. an HTML page from a +// WAF or load balancer). func createRetrieveError(statusCode int, body string) *oauth2.RetrieveError { response := &http.Response{ StatusCode: statusCode, @@ -98,6 +102,16 @@ func createRetrieveError(statusCode int, body string) *oauth2.RetrieveError { } } +// createRetrieveErrorWithCode is like createRetrieveError but also sets the +// ErrorCode field, mirroring what golang.org/x/oauth2 populates when the +// server responds with a parseable JSON error body containing an "error" +// field. +func createRetrieveErrorWithCode(statusCode int, errorCode, body string) *oauth2.RetrieveError { + err := createRetrieveError(statusCode, body) + err.ErrorCode = errorCode + return err +} + func TestMonitoredTokenSource_SuccessfulTokenRetrieval(t *testing.T) { t.Parallel() ctrl := gomock.NewController(t) @@ -143,7 +157,7 @@ func TestMonitoredTokenSource_AuthenticationErrorMarksUnauthenticated(t *testing tokenSource := newMockTokenSource() // Create an error that simulates token retrieval failure - retrieveErr := createRetrieveError(http.StatusBadRequest, `{"error":"invalid_grant","error_description":"refresh token expired"}`) + retrieveErr := createRetrieveErrorWithCode(http.StatusBadRequest, "invalid_grant", `{"error":"invalid_grant","error_description":"refresh token expired"}`) tokenSource.setTokenFn(func() (*oauth2.Token, error) { return nil, retrieveErr }) @@ -238,7 +252,7 @@ func TestMonitoredTokenSource_BackgroundMonitoring(t *testing.T) { }, nil } // Subsequent calls: return authentication error - retrieveErr := createRetrieveError(http.StatusUnauthorized, `{"error":"invalid_token"}`) + retrieveErr := createRetrieveErrorWithCode(http.StatusUnauthorized, "invalid_token", `{"error":"invalid_token"}`) return nil, retrieveErr }) @@ -399,7 +413,7 @@ func TestMonitoredTokenSource_MultipleCallsToToken(t *testing.T) { statusUpdater, statusManager := newMockStatusUpdater(ctrl) tokenSource := newMockTokenSource() - retrieveErr := createRetrieveError(http.StatusUnauthorized, `{"error":"invalid_token"}`) + retrieveErr := createRetrieveErrorWithCode(http.StatusUnauthorized, "invalid_token", `{"error":"invalid_token"}`) tokenSource.setTokenFn(func() (*oauth2.Token, error) { return nil, retrieveErr }) @@ -578,12 +592,14 @@ func TestMonitoredTokenSource_BackgroundMonitor_ErrorClassification(t *testing.T err error isTransient bool // true → monitor retries; false → monitor marks unauthenticated }{ - // Non-transient: plain and auth-level errors must fail fast. + // Non-transient: plain errors and OAuth protocol failures (4xx with a + // populated RFC 6749 error code) must fail fast. {name: "plain error", err: errors.New("some error"), isTransient: false}, {name: "context.Canceled", err: context.Canceled, isTransient: false}, {name: "context.DeadlineExceeded", err: context.DeadlineExceeded, isTransient: false}, - {name: "oauth2.RetrieveError 401", err: createRetrieveError(http.StatusUnauthorized, "unauthorized"), isTransient: false}, - {name: "oauth2.RetrieveError 400 invalid_grant", err: createRetrieveError(http.StatusBadRequest, "invalid_grant"), isTransient: false}, + {name: "oauth2.RetrieveError 400 invalid_grant", err: createRetrieveErrorWithCode(http.StatusBadRequest, "invalid_grant", `{"error":"invalid_grant"}`), isTransient: false}, + {name: "oauth2.RetrieveError 401 invalid_client", err: createRetrieveErrorWithCode(http.StatusUnauthorized, "invalid_client", `{"error":"invalid_client"}`), isTransient: false}, + {name: "oauth2.RetrieveError 403 unauthorized_client", err: createRetrieveErrorWithCode(http.StatusForbidden, "unauthorized_client", `{"error":"unauthorized_client"}`), isTransient: false}, {name: "oauth2.RetrieveError nil response", err: &oauth2.RetrieveError{}, isTransient: false}, // Transient: network-level errors must be retried. {name: "*net.DNSError timeout", err: &net.DNSError{Err: "i/o timeout", Name: "example.com", IsTimeout: true}, isTransient: true}, @@ -595,6 +611,17 @@ func TestMonitoredTokenSource_BackgroundMonitor_ErrorClassification(t *testing.T {name: "oauth2.RetrieveError 502", err: createRetrieveError(http.StatusBadGateway, "Bad Gateway"), isTransient: true}, {name: "oauth2.RetrieveError 503", err: createRetrieveError(http.StatusServiceUnavailable, "Service Unavailable"), isTransient: true}, {name: "oauth2.RetrieveError 504", err: createRetrieveError(http.StatusGatewayTimeout, "Gateway Timeout"), isTransient: true}, + // Transient: 4xx without an RFC 6749 error code in the body. + // These are infrastructure-level errors (WAF, CDN, proxy) that + // commonly resolve on their own, not OAuth protocol failures. + {name: "oauth2.RetrieveError 401 with HTML body", err: createRetrieveError(http.StatusUnauthorized, "Unauthorized"), isTransient: true}, + {name: "oauth2.RetrieveError 403 WAF block", err: createRetrieveError(http.StatusForbidden, "Cloudflare Firewall Block"), isTransient: true}, + {name: "oauth2.RetrieveError 400 with empty body", err: createRetrieveError(http.StatusBadRequest, ""), isTransient: true}, + {name: "oauth2.RetrieveError 408 request timeout", err: createRetrieveError(http.StatusRequestTimeout, ""), isTransient: true}, + // Transient: 429 Too Many Requests is retryable per HTTP standard + // regardless of body content. + {name: "oauth2.RetrieveError 429 empty body", err: createRetrieveError(http.StatusTooManyRequests, ""), isTransient: true}, + {name: "oauth2.RetrieveError 429 with rate-limit error code", err: createRetrieveErrorWithCode(http.StatusTooManyRequests, "rate_limit_exceeded", `{"error":"rate_limit_exceeded"}`), isTransient: true}, // Transient: unparsable OAuth responses (HTML from load balancer on 200). {name: "oauth2 cannot parse json", err: fmt.Errorf("oauth2: cannot parse json: invalid character '<'"), isTransient: true}, {name: "wrapped oauth2 parse error", err: fmt.Errorf("refresh failed: %w", fmt.Errorf("oauth2: cannot parse json: invalid character '<'")), isTransient: true}, @@ -657,6 +684,109 @@ func TestMonitoredTokenSource_BackgroundMonitor_ErrorClassification(t *testing.T } } +// TestIsTransientNetworkError_AgainstRealOAuth2Library exercises the +// classification function against errors actually produced by +// golang.org/x/oauth2 when refreshing a token against a real HTTP server. +// +// The classification rule depends on whether the oauth2 library populates +// RetrieveError.ErrorCode for a given response shape. This test pins that +// assumption so a future oauth2 upgrade that changes ErrorCode population +// would surface here, not in production. +func TestIsTransientNetworkError_AgainstRealOAuth2Library(t *testing.T) { + t.Parallel() + + const refreshToken = "test-refresh-token" + + tests := []struct { + name string + handler http.HandlerFunc + isTransient bool + }{ + { + name: "403 with HTML body (Cloudflare WAF)", + handler: func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/html") + w.WriteHeader(http.StatusForbidden) + _, _ = w.Write([]byte("Cloudflare Firewall Block")) + }, + isTransient: true, + }, + { + name: "401 with HTML body", + handler: func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/html") + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte("Unauthorized")) + }, + isTransient: true, + }, + { + name: "400 with invalid_grant JSON", + handler: func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"error":"invalid_grant","error_description":"refresh token expired"}`)) + }, + isTransient: false, + }, + { + name: "401 with invalid_client JSON", + handler: func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"error":"invalid_client"}`)) + }, + isTransient: false, + }, + { + name: "429 Too Many Requests with empty body", + handler: func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusTooManyRequests) + }, + isTransient: true, + }, + { + name: "503 Service Unavailable", + handler: func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + }, + isTransient: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(tt.handler) + t.Cleanup(server.Close) + + cfg := &oauth2.Config{ + ClientID: "test-client", + ClientSecret: "test-secret", + Endpoint: oauth2.Endpoint{TokenURL: server.URL}, + } + + expired := &oauth2.Token{ + AccessToken: "expired-access-token", + RefreshToken: refreshToken, + Expiry: time.Now().Add(-time.Hour), + } + + _, err := cfg.TokenSource(context.Background(), expired).Token() + if err == nil { + t.Fatalf("expected refresh to fail, got nil error") + } + + got := isTransientNetworkError(err) + if got != tt.isTransient { + t.Errorf("isTransientNetworkError(%v) = %v, want %v", + err, got, tt.isTransient) + } + }) + } +} + // --- background monitor transient-error behaviour --- // TestMonitoredTokenSource_TransientErrorRetriesAndSucceeds verifies that when the @@ -771,7 +901,7 @@ func TestMonitoredTokenSource_TransientThenNonTransientMarksUnauthenticated(t *t Times(1) transientErr := &net.OpError{Op: "dial", Net: "tcp", Err: &os.SyscallError{Syscall: "connect", Err: syscall.ECONNREFUSED}} - nonTransientErr := createRetrieveError(http.StatusUnauthorized, `{"error":"invalid_token"}`) + nonTransientErr := createRetrieveErrorWithCode(http.StatusUnauthorized, "invalid_token", `{"error":"invalid_token"}`) tokenSource.setTokenFn(func() (*oauth2.Token, error) { switch tokenSource.callCount {