Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 56 additions & 4 deletions pkg/auth/monitored_token_source.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,15 @@ type transientRefresher struct {
source oauth2.TokenSource
workload string

// upstream identifies the upstream authorization server that issued the
// token, and clientID is the OAuth 2.0 client_id used by this workload.
// Both are optional and are only surfaced in structured logs (notably the
// DCR remediation warning emitted from isTransientNetworkError when a
// permanent 4xx indicates a stale cached DCR client). Empty strings are
// acceptable; the slog handler renders them as empty attribute values.
upstream string
clientID string

// newBackOff is a factory for the backoff used during retries.
// Nil in production; overridable in tests for fast execution.
newBackOff func() backoff.BackOff
Expand Down Expand Up @@ -172,7 +181,7 @@ func (r *transientRefresher) retry(ctx context.Context, origErr error) (*oauth2.
if tokenErr == nil {
return t, nil
}
if !isTransientNetworkError(tokenErr) {
if !isTransientNetworkError(tokenErr, r.workload, r.upstream, r.clientID) {
return nil, backoff.Permanent(tokenErr)
}
return nil, tokenErr
Expand Down Expand Up @@ -212,6 +221,8 @@ func (r *transientRefresher) getBackOff() backoff.BackOff {
type MonitoredTokenSource struct {
tokenSource oauth2.TokenSource
workloadName string
upstream string
clientID string
statusUpdater StatusUpdater
monitoringCtx context.Context
stopMonitoring chan struct{}
Expand All @@ -226,20 +237,41 @@ type MonitoredTokenSource struct {

// NewMonitoredTokenSource creates a new MonitoredTokenSource that wraps the provided
// oauth2.TokenSource and monitors it for authentication failures.
//
// upstream and clientID annotate structured logs emitted by the token source,
// most importantly the DCR remediation warning fired from
// isTransientNetworkError when the token endpoint returns a permanent 4xx
// (which frequently indicates a stale cached RFC 7591 registration). Pass
// empty strings when the workload does not use DCR or the upstream issuer
// is unknown; the slog handler will render the attributes as empty.
//
// The fields are fixed at construction time rather than exposed via a setter
// so there is no data race between a late writer and the readers in Token()
// / transientRefresher.retry() — both of which may run on the background
// monitor goroutine started by StartBackgroundMonitoring.
func NewMonitoredTokenSource(
ctx context.Context,
tokenSource oauth2.TokenSource,
workloadName string,
upstream string,
clientID string,
statusUpdater StatusUpdater,
) *MonitoredTokenSource {
return &MonitoredTokenSource{
tokenSource: tokenSource,
workloadName: workloadName,
upstream: upstream,
clientID: clientID,
statusUpdater: statusUpdater,
monitoringCtx: ctx,
stopMonitoring: make(chan struct{}),
stopped: make(chan struct{}),
refresher: &transientRefresher{source: tokenSource, workload: workloadName},
refresher: &transientRefresher{
source: tokenSource,
workload: workloadName,
upstream: upstream,
clientID: clientID,
},
}
}

Expand All @@ -263,7 +295,7 @@ func (mts *MonitoredTokenSource) Token() (*oauth2.Token, error) {
return tok, nil
}

if !isTransientNetworkError(err) {
if !isTransientNetworkError(err, mts.workloadName, mts.upstream, mts.clientID) {
mts.markAsUnauthenticated(fmt.Sprintf("Token retrieval failed: %v", err))
return nil, err
}
Expand Down Expand Up @@ -349,7 +381,13 @@ func (mts *MonitoredTokenSource) onTick() (bool, time.Duration) {
// OAuth2 client-level auth failures (invalid_grant, 401, 400) and TLS errors
// (certificate verification, handshake failure) are NOT considered transient and
// return false so the workload is marked unauthenticated immediately.
func isTransientNetworkError(err error) bool {
//
// workload, upstream, and clientID are context strings used only in
// structured logs — notably the DCR remediation warning emitted in the
// permanent-4xx branch, which suggests the cached RFC 7591 registration is
// no longer recognised by the authorization server. All three are optional;
// pass empty strings when the context is unknown.
func isTransientNetworkError(err error, workload, upstream, clientID string) bool {
if err == nil ||
errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return false
Expand All @@ -365,6 +403,20 @@ func isTransientNetworkError(err error) bool {
)
return true
}
// Permanent 4xx from the token endpoint frequently indicates that the
// cached Dynamic Client Registration has been revoked, rotated, or is
// no longer recognised by the authorization server. Emit a remediation
// hint before returning false so operators can correlate the workload
// being marked unauthenticated with a stale DCR registration rather than
// a user-action issue.
// The returned boolean is unchanged.
//nolint:gosec // G706: client_id is public metadata per RFC 7591.
slog.Warn(
Comment thread
tgrunnagle marked this conversation as resolved.
Outdated
"cached DCR client is no longer recognized by the authorization server; "+
"delete the cached credentials and restart to re-register.",
"workload", workload,
"upstream", upstream,
"client_id", clientID,
)
return false
}

Expand Down
26 changes: 13 additions & 13 deletions pkg/auth/monitored_token_source_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ func TestMonitoredTokenSource_SuccessfulTokenRetrieval(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", statusUpdater)
ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", "", "", statusUpdater)

// Test successful token retrieval
token, err := ats.Token()
Expand Down Expand Up @@ -151,7 +151,7 @@ func TestMonitoredTokenSource_AuthenticationErrorMarksUnauthenticated(t *testing
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", statusUpdater)
ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", "", "", statusUpdater)

// Expect SetWorkloadStatus to be called with unauthenticated status
statusManager.EXPECT().
Expand Down Expand Up @@ -195,7 +195,7 @@ func TestMonitoredTokenSource_ErrorMarksUnauthenticated(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", statusUpdater)
ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", "", "", statusUpdater)

// Expect SetWorkloadStatus to be called for any error
statusManager.EXPECT().
Expand Down Expand Up @@ -245,7 +245,7 @@ func TestMonitoredTokenSource_BackgroundMonitoring(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", statusUpdater)
ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", "", "", statusUpdater)

// Expect SetWorkloadStatus to be called when auth error occurs
statusManager.EXPECT().
Expand Down Expand Up @@ -297,7 +297,7 @@ func TestMonitoredTokenSource_BackgroundMonitoringStopsOnAnyError(t *testing.T)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", statusUpdater)
ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", "", "", statusUpdater)

// Expect SetWorkloadStatus to be called when any error occurs
statusManager.EXPECT().
Expand Down Expand Up @@ -341,7 +341,7 @@ func TestMonitoredTokenSource_ExpiredTokenHandling(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", statusUpdater)
ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", "", "", statusUpdater)

// Should not mark as unauthenticated just for expired token
// (oauth2 library should handle refresh; we only mark on actual auth errors)
Expand Down Expand Up @@ -374,7 +374,7 @@ func TestMonitoredTokenSource_StopMonitoring(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", statusUpdater)
ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", "", "", statusUpdater)
ats.StartBackgroundMonitoring()

// Wait a bit to ensure monitoring started
Expand Down Expand Up @@ -407,7 +407,7 @@ func TestMonitoredTokenSource_MultipleCallsToToken(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", statusUpdater)
ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", "", "", statusUpdater)

statusManager.EXPECT().
SetWorkloadStatus(
Expand Down Expand Up @@ -627,7 +627,7 @@ func TestMonitoredTokenSource_BackgroundMonitor_ErrorClassification(t *testing.T
statusUpdater, _ := newMockStatusUpdater(ctrl)
retrying := tokenSource.notifyOnCall(2)

ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", statusUpdater)
ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", "", "", statusUpdater)
ats.refresher.newBackOff = fastBackOff
ats.StartBackgroundMonitoring()

Expand All @@ -647,7 +647,7 @@ func TestMonitoredTokenSource_BackgroundMonitor_ErrorClassification(t *testing.T
Return(nil).
Times(1)

ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", statusUpdater)
ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", "", "", statusUpdater)
ats.refresher.newBackOff = fastBackOff
ats.StartBackgroundMonitoring()

Expand Down Expand Up @@ -698,7 +698,7 @@ func TestMonitoredTokenSource_TransientErrorRetriesAndSucceeds(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", statusUpdater)
ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", "", "", statusUpdater)
ats.refresher.newBackOff = fastBackOff
ats.StartBackgroundMonitoring()

Expand Down Expand Up @@ -738,7 +738,7 @@ func TestMonitoredTokenSource_TransientErrorContextCancellation(t *testing.T) {

ctx, cancel := context.WithCancel(context.Background())

ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", statusUpdater)
ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", "", "", statusUpdater)
ats.refresher.newBackOff = fastBackOff
ats.StartBackgroundMonitoring()

Expand Down Expand Up @@ -793,7 +793,7 @@ func TestMonitoredTokenSource_TransientThenNonTransientMarksUnauthenticated(t *t
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", statusUpdater)
ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", "", "", statusUpdater)
ats.refresher.newBackOff = fastBackOff
ats.StartBackgroundMonitoring()

Expand Down
Loading
Loading