Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 68 additions & 21 deletions pkg/auth/monitored_token_source.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"fmt"
"log/slog"
"net"
"net/http"
"os"
"strconv"
"strings"
Expand Down Expand Up @@ -249,11 +250,11 @@ func (mts *MonitoredTokenSource) Stopped() <-chan struct{} {
return mts.stopped
}

// Token retrieves a token, retrying with exponential backoff on transient errors
// (see isTransientNetworkError for the full list). On non-transient errors
// (OAuth 4xx, TLS failures) it marks the workload as unauthenticated and returns
// immediately. Context cancellation (workload removal) stops the retry without
// marking the workload as unauthenticated.
// Token retrieves a token, retrying with exponential backoff on transient
// errors and marking the workload as unauthenticated on non-transient errors.
// See isTransientNetworkError for the classification rule. Context
// cancellation (workload removal) stops the retry without marking the
// workload as unauthenticated.
//
// Concurrent callers are deduplicated via singleflight so that only one retry
// loop runs at a time during transient failures.
Expand Down Expand Up @@ -342,30 +343,28 @@ func (mts *MonitoredTokenSource) onTick() (bool, time.Duration) {
return false, wait
}

// isTransientNetworkError reports whether err represents a transient condition
// (DNS failure, TCP transport error, timeout, OAuth server 5xx, unparsable
// token response) that is likely to resolve on its own.
// isTransientNetworkError reports whether err represents a transient
// condition that is likely to resolve on its own. The categories are:
//
// OAuth2 client-level auth failures (invalid_grant, 401, 400) and TLS errors
// (certificate verification, handshake failure) are NOT considered transient and
// return false so the workload is marked unauthenticated immediately.
// - Network-level failures: DNS lookup errors, TCP transport errors,
// timeouts.
// - OAuth token-endpoint responses classified as transient by
// classifyOAuthRetrieveError.
// - Unparsable token responses on a 2xx status (typically an HTML page
// from a load balancer or CDN).
//
// All other errors return false, causing the workload to be marked
// unauthenticated. TLS failures (certificate verification, handshake
// failure) are intentionally non-transient even though they surface
// through net.OpError like transport-level errors.
func isTransientNetworkError(err error) bool {
if err == nil ||
errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return false
}

// OAuth HTTP-level errors: 5xx (Bad Gateway, Service Unavailable, Gateway
// Timeout) are transient server-side issues that typically resolve on their
// own. 4xx errors (invalid_grant, invalid_client) are permanent auth failures.
if retrieveErr, ok := errors.AsType[*oauth2.RetrieveError](err); ok {
if retrieveErr.Response != nil && retrieveErr.Response.StatusCode >= 500 {
slog.Debug("treating OAuth server error as transient",
"status_code", retrieveErr.Response.StatusCode,
)
return true
}
return false
return classifyOAuthRetrieveError(retrieveErr)
}

// Non-JSON responses from the OAuth server (e.g. load balancer HTML pages).
Expand Down Expand Up @@ -399,6 +398,54 @@ func isTransientNetworkError(err error) bool {
return false
}

// classifyOAuthRetrieveError reports whether an *oauth2.RetrieveError should
// be treated as transient. The classification rules are:
//
// - nil Response: non-transient. There is no signal to act on, so we fall
// through to the unauthenticated path rather than retry blindly.
// - 5xx status: transient (server-side issue, likely to resolve).
// - 429 Too Many Requests: transient regardless of body (HTTP standard).
// - 4xx with an empty ErrorCode: transient. The oauth2 library populates
// ErrorCode from the RFC 6749 'error' field in a JSON response body. An
// empty ErrorCode means the response was not a parseable OAuth error —
// typically an HTML page from a WAF, CDN, or reverse proxy that
// intercepted the request before it reached the OAuth server. These
// infrastructure errors (Cloudflare blocks, residential-IP allowlist
// misses, transient bad-config deploys) commonly resolve on their own.
// - 4xx with a populated ErrorCode: permanent. The OAuth server returned
// a structured error code (invalid_grant, invalid_client, etc.) telling
// us specifically what's wrong; retrying won't help.
func classifyOAuthRetrieveError(retrieveErr *oauth2.RetrieveError) bool {
if retrieveErr.Response == nil {
return false
}
statusCode := retrieveErr.Response.StatusCode

if statusCode >= 500 {
slog.Debug("treating OAuth server error as transient",
"status_code", statusCode,
)
return true
}

if statusCode == http.StatusTooManyRequests {
slog.Debug("treating OAuth rate-limit response as transient",
"status_code", statusCode,
"error_code", retrieveErr.ErrorCode,
)
return true
}

if retrieveErr.ErrorCode == "" {
slog.Debug("treating OAuth 4xx without error code as transient",
"status_code", statusCode,
)
return true
}

return false
}

// isOAuthParseError detects errors from the oauth2 library that indicate the
// token endpoint returned an unparsable response body on a 2xx status. This
// typically happens when a load balancer, CDN, or reverse proxy intercepts the
Expand Down
146 changes: 138 additions & 8 deletions pkg/auth/monitored_token_source_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"fmt"
"net"
"net/http"
"net/http/httptest"
"net/url"
"os"
"strings"
Expand Down Expand Up @@ -86,7 +87,10 @@ func (m *mockTokenSource) Token() (*oauth2.Token, error) {
return tok, err
}

// createRetrieveError creates an error for testing token failures
// createRetrieveError creates an error for testing token failures. ErrorCode
// is left unset, mirroring what golang.org/x/oauth2 produces when the response
// body is not a parseable RFC 6749 error response (e.g. an HTML page from a
// WAF or load balancer).
func createRetrieveError(statusCode int, body string) *oauth2.RetrieveError {
response := &http.Response{
StatusCode: statusCode,
Expand All @@ -98,6 +102,16 @@ func createRetrieveError(statusCode int, body string) *oauth2.RetrieveError {
}
}

// createRetrieveErrorWithCode is like createRetrieveError but also sets the
// ErrorCode field, mirroring what golang.org/x/oauth2 populates when the
// server responds with a parseable JSON error body containing an "error"
// field.
func createRetrieveErrorWithCode(statusCode int, errorCode, body string) *oauth2.RetrieveError {
err := createRetrieveError(statusCode, body)
err.ErrorCode = errorCode
return err
}

func TestMonitoredTokenSource_SuccessfulTokenRetrieval(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
Expand Down Expand Up @@ -143,7 +157,7 @@ func TestMonitoredTokenSource_AuthenticationErrorMarksUnauthenticated(t *testing
tokenSource := newMockTokenSource()

// Create an error that simulates token retrieval failure
retrieveErr := createRetrieveError(http.StatusBadRequest, `{"error":"invalid_grant","error_description":"refresh token expired"}`)
retrieveErr := createRetrieveErrorWithCode(http.StatusBadRequest, "invalid_grant", `{"error":"invalid_grant","error_description":"refresh token expired"}`)
tokenSource.setTokenFn(func() (*oauth2.Token, error) {
return nil, retrieveErr
})
Expand Down Expand Up @@ -238,7 +252,7 @@ func TestMonitoredTokenSource_BackgroundMonitoring(t *testing.T) {
}, nil
}
// Subsequent calls: return authentication error
retrieveErr := createRetrieveError(http.StatusUnauthorized, `{"error":"invalid_token"}`)
retrieveErr := createRetrieveErrorWithCode(http.StatusUnauthorized, "invalid_token", `{"error":"invalid_token"}`)
return nil, retrieveErr
})

Expand Down Expand Up @@ -399,7 +413,7 @@ func TestMonitoredTokenSource_MultipleCallsToToken(t *testing.T) {
statusUpdater, statusManager := newMockStatusUpdater(ctrl)
tokenSource := newMockTokenSource()

retrieveErr := createRetrieveError(http.StatusUnauthorized, `{"error":"invalid_token"}`)
retrieveErr := createRetrieveErrorWithCode(http.StatusUnauthorized, "invalid_token", `{"error":"invalid_token"}`)
tokenSource.setTokenFn(func() (*oauth2.Token, error) {
return nil, retrieveErr
})
Expand Down Expand Up @@ -578,12 +592,14 @@ func TestMonitoredTokenSource_BackgroundMonitor_ErrorClassification(t *testing.T
err error
isTransient bool // true → monitor retries; false → monitor marks unauthenticated
}{
// Non-transient: plain and auth-level errors must fail fast.
// Non-transient: plain errors and OAuth protocol failures (4xx with a
// populated RFC 6749 error code) must fail fast.
{name: "plain error", err: errors.New("some error"), isTransient: false},
{name: "context.Canceled", err: context.Canceled, isTransient: false},
{name: "context.DeadlineExceeded", err: context.DeadlineExceeded, isTransient: false},
{name: "oauth2.RetrieveError 401", err: createRetrieveError(http.StatusUnauthorized, "unauthorized"), isTransient: false},
{name: "oauth2.RetrieveError 400 invalid_grant", err: createRetrieveError(http.StatusBadRequest, "invalid_grant"), isTransient: false},
{name: "oauth2.RetrieveError 400 invalid_grant", err: createRetrieveErrorWithCode(http.StatusBadRequest, "invalid_grant", `{"error":"invalid_grant"}`), isTransient: false},
{name: "oauth2.RetrieveError 401 invalid_client", err: createRetrieveErrorWithCode(http.StatusUnauthorized, "invalid_client", `{"error":"invalid_client"}`), isTransient: false},
{name: "oauth2.RetrieveError 403 unauthorized_client", err: createRetrieveErrorWithCode(http.StatusForbidden, "unauthorized_client", `{"error":"unauthorized_client"}`), isTransient: false},
{name: "oauth2.RetrieveError nil response", err: &oauth2.RetrieveError{}, isTransient: false},
// Transient: network-level errors must be retried.
{name: "*net.DNSError timeout", err: &net.DNSError{Err: "i/o timeout", Name: "example.com", IsTimeout: true}, isTransient: true},
Expand All @@ -595,6 +611,17 @@ func TestMonitoredTokenSource_BackgroundMonitor_ErrorClassification(t *testing.T
{name: "oauth2.RetrieveError 502", err: createRetrieveError(http.StatusBadGateway, "Bad Gateway"), isTransient: true},
{name: "oauth2.RetrieveError 503", err: createRetrieveError(http.StatusServiceUnavailable, "Service Unavailable"), isTransient: true},
{name: "oauth2.RetrieveError 504", err: createRetrieveError(http.StatusGatewayTimeout, "Gateway Timeout"), isTransient: true},
// Transient: 4xx without an RFC 6749 error code in the body.
// These are infrastructure-level errors (WAF, CDN, proxy) that
// commonly resolve on their own, not OAuth protocol failures.
{name: "oauth2.RetrieveError 401 with HTML body", err: createRetrieveError(http.StatusUnauthorized, "<html><body>Unauthorized</body></html>"), isTransient: true},
{name: "oauth2.RetrieveError 403 WAF block", err: createRetrieveError(http.StatusForbidden, "<html><body>Cloudflare Firewall Block</body></html>"), isTransient: true},
{name: "oauth2.RetrieveError 400 with empty body", err: createRetrieveError(http.StatusBadRequest, ""), isTransient: true},
{name: "oauth2.RetrieveError 408 request timeout", err: createRetrieveError(http.StatusRequestTimeout, ""), isTransient: true},
// Transient: 429 Too Many Requests is retryable per HTTP standard
// regardless of body content.
{name: "oauth2.RetrieveError 429 empty body", err: createRetrieveError(http.StatusTooManyRequests, ""), isTransient: true},
{name: "oauth2.RetrieveError 429 with rate-limit error code", err: createRetrieveErrorWithCode(http.StatusTooManyRequests, "rate_limit_exceeded", `{"error":"rate_limit_exceeded"}`), isTransient: true},
// Transient: unparsable OAuth responses (HTML from load balancer on 200).
{name: "oauth2 cannot parse json", err: fmt.Errorf("oauth2: cannot parse json: invalid character '<'"), isTransient: true},
{name: "wrapped oauth2 parse error", err: fmt.Errorf("refresh failed: %w", fmt.Errorf("oauth2: cannot parse json: invalid character '<'")), isTransient: true},
Expand Down Expand Up @@ -657,6 +684,109 @@ func TestMonitoredTokenSource_BackgroundMonitor_ErrorClassification(t *testing.T
}
}

// TestIsTransientNetworkError_AgainstRealOAuth2Library exercises the
// classification function against errors actually produced by
// golang.org/x/oauth2 when refreshing a token against a real HTTP server.
//
// The classification rule depends on whether the oauth2 library populates
// RetrieveError.ErrorCode for a given response shape. This test pins that
// assumption so a future oauth2 upgrade that changes ErrorCode population
// would surface here, not in production.
func TestIsTransientNetworkError_AgainstRealOAuth2Library(t *testing.T) {
t.Parallel()

const refreshToken = "test-refresh-token"

tests := []struct {
name string
handler http.HandlerFunc
isTransient bool
}{
{
name: "403 with HTML body (Cloudflare WAF)",
handler: func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "text/html")
w.WriteHeader(http.StatusForbidden)
_, _ = w.Write([]byte("<html><body>Cloudflare Firewall Block</body></html>"))
},
isTransient: true,
},
{
name: "401 with HTML body",
handler: func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "text/html")
w.WriteHeader(http.StatusUnauthorized)
_, _ = w.Write([]byte("<html><body>Unauthorized</body></html>"))
},
isTransient: true,
},
{
name: "400 with invalid_grant JSON",
handler: func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
_, _ = w.Write([]byte(`{"error":"invalid_grant","error_description":"refresh token expired"}`))
},
isTransient: false,
},
{
name: "401 with invalid_client JSON",
handler: func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)
_, _ = w.Write([]byte(`{"error":"invalid_client"}`))
},
isTransient: false,
},
{
name: "429 Too Many Requests with empty body",
handler: func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusTooManyRequests)
},
isTransient: true,
},
{
name: "503 Service Unavailable",
handler: func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusServiceUnavailable)
},
isTransient: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

server := httptest.NewServer(tt.handler)
t.Cleanup(server.Close)

cfg := &oauth2.Config{
ClientID: "test-client",
ClientSecret: "test-secret",
Endpoint: oauth2.Endpoint{TokenURL: server.URL},
}

expired := &oauth2.Token{
AccessToken: "expired-access-token",
RefreshToken: refreshToken,
Expiry: time.Now().Add(-time.Hour),
}

_, err := cfg.TokenSource(context.Background(), expired).Token()
if err == nil {
t.Fatalf("expected refresh to fail, got nil error")
}

got := isTransientNetworkError(err)
if got != tt.isTransient {
t.Errorf("isTransientNetworkError(%v) = %v, want %v",
err, got, tt.isTransient)
}
})
}
}

// --- background monitor transient-error behaviour ---

// TestMonitoredTokenSource_TransientErrorRetriesAndSucceeds verifies that when the
Expand Down Expand Up @@ -771,7 +901,7 @@ func TestMonitoredTokenSource_TransientThenNonTransientMarksUnauthenticated(t *t
Times(1)

transientErr := &net.OpError{Op: "dial", Net: "tcp", Err: &os.SyscallError{Syscall: "connect", Err: syscall.ECONNREFUSED}}
nonTransientErr := createRetrieveError(http.StatusUnauthorized, `{"error":"invalid_token"}`)
nonTransientErr := createRetrieveErrorWithCode(http.StatusUnauthorized, "invalid_token", `{"error":"invalid_token"}`)

tokenSource.setTokenFn(func() (*oauth2.Token, error) {
switch tokenSource.callCount {
Expand Down
Loading