Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 91 additions & 6 deletions pkg/auth/monitored_token_source.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,15 @@ type transientRefresher struct {
source oauth2.TokenSource
workload string

// upstream identifies the upstream authorization server that issued the
// token, and clientID is the OAuth 2.0 client_id used by this workload.
// Both are optional and are only surfaced in structured logs (notably the
// DCR/CIMD remediation warning emitted from MonitoredTokenSource on the
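// transition to Unauthenticated). Empty strings are acceptable.
// NOTE(review): the standard log/slog handlers still print an attribute
// whose value is "" (e.g. upstream=""); confirm the configured handler
// drops empty values before relying on them being omitted from log output.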
upstream string
clientID string

// newBackOff is a factory for the backoff used during retries.
// Nil in production; overridable in tests for fast execution.
newBackOff func() backoff.BackOff
Expand Down Expand Up @@ -212,6 +221,8 @@ func (r *transientRefresher) getBackOff() backoff.BackOff {
type MonitoredTokenSource struct {
tokenSource oauth2.TokenSource
workloadName string
upstream string
clientID string
statusUpdater StatusUpdater
monitoringCtx context.Context
stopMonitoring chan struct{}
Expand All @@ -226,20 +237,42 @@ type MonitoredTokenSource struct {

// NewMonitoredTokenSource creates a new MonitoredTokenSource that wraps the provided
// oauth2.TokenSource and monitors it for authentication failures.
//
// ctx is retained as the monitoring context; Token() hands it to the
// transient-error refresher when retrying. tokenSource is shared between the
// returned MonitoredTokenSource and its internal transientRefresher, so both
// observe the same underlying source.
//
// upstream and clientID annotate structured logs emitted by the token source,
// most importantly the DCR/CIMD remediation warning fired on the transition
// to Unauthenticated when the token endpoint returns a permanent 4xx (which
// frequently indicates stale cached credentials). Pass empty strings when
// the workload does not use DCR/CIMD or the upstream issuer is unknown; the
// remediation log will be suppressed when clientID is empty since its
// operator-correlation field would be blank.
//
// The fields are fixed at construction time rather than exposed via a setter
// so there is no data race between a late writer and the readers in Token()
// / transientRefresher.retry() — both of which may run on the background
// monitor goroutine started by StartBackgroundMonitoring.
func NewMonitoredTokenSource(
	ctx context.Context,
	tokenSource oauth2.TokenSource,
	workloadName string,
	upstream string,
	clientID string,
	statusUpdater StatusUpdater,
) *MonitoredTokenSource {
	return &MonitoredTokenSource{
		tokenSource:    tokenSource,
		workloadName:   workloadName,
		upstream:       upstream,
		clientID:       clientID,
		statusUpdater:  statusUpdater,
		monitoringCtx:  ctx,
		stopMonitoring: make(chan struct{}),
		stopped:        make(chan struct{}),
		// The refresher carries its own copy of the log-annotation fields so
		// retry logging does not need a back-pointer to the MonitoredTokenSource.
		refresher: &transientRefresher{
			source:   tokenSource,
			workload: workloadName,
			upstream: upstream,
			clientID: clientID,
		},
	}
}

Expand All @@ -264,7 +297,10 @@ func (mts *MonitoredTokenSource) Token() (*oauth2.Token, error) {
}

if !isTransientNetworkError(err) {
mts.markAsUnauthenticated(fmt.Sprintf("Token retrieval failed: %v", err))
mts.markAsUnauthenticated(
fmt.Sprintf("Token retrieval failed: %v", err),
isPermanentTokenEndpointError(err),
)
return nil, err
}

Expand All @@ -273,7 +309,10 @@ func (mts *MonitoredTokenSource) Token() (*oauth2.Token, error) {
tok, err = mts.refresher.Refresh(mts.monitoringCtx, err)
if err != nil {
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
mts.markAsUnauthenticated(fmt.Sprintf("Token refresh failed after retries: %v", err))
mts.markAsUnauthenticated(
fmt.Sprintf("Token refresh failed after retries: %v", err),
isPermanentTokenEndpointError(err),
)
}
return nil, err
}
Expand Down Expand Up @@ -349,6 +388,11 @@ func (mts *MonitoredTokenSource) onTick() (bool, time.Duration) {
// OAuth2 client-level auth failures (invalid_grant, 401, 400) and TLS errors
// (certificate verification, handshake failure) are NOT considered transient and
// return false so the workload is marked unauthenticated immediately.
//
// The function is side-effect free; callers that want to emit a DCR
// remediation hint on a permanent 4xx must do so themselves at the
// state-transition boundary using isPermanentTokenEndpointError to
// classify, so a tight Token() loop does not spam the same record.
func isTransientNetworkError(err error) bool {
if err == nil ||
errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
Expand Down Expand Up @@ -399,6 +443,21 @@ func isTransientNetworkError(err error) bool {
return false
}

// isPermanentTokenEndpointError reports whether err is an *oauth2.RetrieveError
// with a non-5xx HTTP status (i.e. a permanent token-endpoint response such as
// 400, 401, 403). Used at state-transition boundaries to decide whether to
// emit a DCR/CIMD remediation hint alongside the unauthentication.
func isPermanentTokenEndpointError(err error) bool {
retrieveErr, ok := errors.AsType[*oauth2.RetrieveError](err)
if !ok {
return false
}
if retrieveErr.Response == nil {
return true
}
return retrieveErr.Response.StatusCode < 500
Comment thread
tgrunnagle marked this conversation as resolved.
Outdated
}

// isOAuthParseError detects errors from the oauth2 library that indicate the
// token endpoint returned an unparsable response body on a 2xx status. This
// typically happens when a load balancer, CDN, or reverse proxy intercepts the
Expand All @@ -414,13 +473,39 @@ func isOAuthParseError(err error) bool {
strings.Contains(msg, "oauth2: cannot parse response")
}

// markAsUnauthenticated marks the workload as unauthenticated and stops background monitoring.
func (mts *MonitoredTokenSource) markAsUnauthenticated(reason string) {
// markAsUnauthenticated marks the workload as unauthenticated and stops
// background monitoring. If permanent4xx is true and the workload was
// constructed with a non-empty client_id, a one-shot DCR/CIMD remediation
// hint is emitted alongside the stop transition. The hint and the close
// of stopMonitoring share stopOnce, so a caller (e.g. a tight Token()
// loop) cannot spam the record on every call after the workload has
// already transitioned to Unauthenticated.
func (mts *MonitoredTokenSource) markAsUnauthenticated(reason string, permanent4xx bool) {
_ = mts.statusUpdater.SetWorkloadStatus(
context.Background(),
mts.workloadName,
runtime.WorkloadStatusUnauthenticated,
reason,
)
mts.stopOnce.Do(func() { close(mts.stopMonitoring) })
mts.stopOnce.Do(func() {
// A permanent 4xx from the token endpoint commonly indicates the
// cached client (DCR or CIMD) is no longer recognised — but the
// same branch fires for revoked consent, disabled accounts, and
// statically configured clients, so the message has to be honest
// about the variability. Gating on clientID != "" suppresses the
// log entirely for workloads where no client_id context is
// available; the operator-correlation it provides would be empty.
if permanent4xx && mts.clientID != "" {
//nolint:gosec // G706: client_id is public metadata per RFC 7591.
slog.Warn(
"token endpoint returned a permanent error; if this workload uses "+
"cached DCR or CIMD credentials they may be stale — delete the "+
"cached credentials and restart to re-register.",
"workload", mts.workloadName,
"upstream", mts.upstream,
"client_id", mts.clientID,
)
}
close(mts.stopMonitoring)
})
}
26 changes: 13 additions & 13 deletions pkg/auth/monitored_token_source_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ func TestMonitoredTokenSource_SuccessfulTokenRetrieval(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", statusUpdater)
ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", "", "", statusUpdater)

// Test successful token retrieval
token, err := ats.Token()
Expand Down Expand Up @@ -151,7 +151,7 @@ func TestMonitoredTokenSource_AuthenticationErrorMarksUnauthenticated(t *testing
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", statusUpdater)
ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", "", "", statusUpdater)

// Expect SetWorkloadStatus to be called with unauthenticated status
statusManager.EXPECT().
Expand Down Expand Up @@ -195,7 +195,7 @@ func TestMonitoredTokenSource_ErrorMarksUnauthenticated(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", statusUpdater)
ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", "", "", statusUpdater)

// Expect SetWorkloadStatus to be called for any error
statusManager.EXPECT().
Expand Down Expand Up @@ -245,7 +245,7 @@ func TestMonitoredTokenSource_BackgroundMonitoring(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", statusUpdater)
ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", "", "", statusUpdater)

// Expect SetWorkloadStatus to be called when auth error occurs
statusManager.EXPECT().
Expand Down Expand Up @@ -297,7 +297,7 @@ func TestMonitoredTokenSource_BackgroundMonitoringStopsOnAnyError(t *testing.T)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", statusUpdater)
ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", "", "", statusUpdater)

// Expect SetWorkloadStatus to be called when any error occurs
statusManager.EXPECT().
Expand Down Expand Up @@ -341,7 +341,7 @@ func TestMonitoredTokenSource_ExpiredTokenHandling(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", statusUpdater)
ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", "", "", statusUpdater)

// Should not mark as unauthenticated just for expired token
// (oauth2 library should handle refresh; we only mark on actual auth errors)
Expand Down Expand Up @@ -374,7 +374,7 @@ func TestMonitoredTokenSource_StopMonitoring(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", statusUpdater)
ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", "", "", statusUpdater)
ats.StartBackgroundMonitoring()

// Wait a bit to ensure monitoring started
Expand Down Expand Up @@ -407,7 +407,7 @@ func TestMonitoredTokenSource_MultipleCallsToToken(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", statusUpdater)
ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", "", "", statusUpdater)

statusManager.EXPECT().
SetWorkloadStatus(
Expand Down Expand Up @@ -627,7 +627,7 @@ func TestMonitoredTokenSource_BackgroundMonitor_ErrorClassification(t *testing.T
statusUpdater, _ := newMockStatusUpdater(ctrl)
retrying := tokenSource.notifyOnCall(2)

ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", statusUpdater)
ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", "", "", statusUpdater)
ats.refresher.newBackOff = fastBackOff
ats.StartBackgroundMonitoring()

Expand All @@ -647,7 +647,7 @@ func TestMonitoredTokenSource_BackgroundMonitor_ErrorClassification(t *testing.T
Return(nil).
Times(1)

ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", statusUpdater)
ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", "", "", statusUpdater)
ats.refresher.newBackOff = fastBackOff
ats.StartBackgroundMonitoring()

Expand Down Expand Up @@ -698,7 +698,7 @@ func TestMonitoredTokenSource_TransientErrorRetriesAndSucceeds(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", statusUpdater)
ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", "", "", statusUpdater)
ats.refresher.newBackOff = fastBackOff
ats.StartBackgroundMonitoring()

Expand Down Expand Up @@ -738,7 +738,7 @@ func TestMonitoredTokenSource_TransientErrorContextCancellation(t *testing.T) {

ctx, cancel := context.WithCancel(context.Background())

ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", statusUpdater)
ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", "", "", statusUpdater)
ats.refresher.newBackOff = fastBackOff
ats.StartBackgroundMonitoring()

Expand Down Expand Up @@ -793,7 +793,7 @@ func TestMonitoredTokenSource_TransientThenNonTransientMarksUnauthenticated(t *t
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", statusUpdater)
ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", "", "", statusUpdater)
ats.refresher.newBackOff = fastBackOff
ats.StartBackgroundMonitoring()

Expand Down
Loading
Loading