diff --git a/pkg/auth/tokenexchange/exchange.go b/pkg/auth/tokenexchange/exchange.go index 430a4dff9b..b6a8e3fb68 100644 --- a/pkg/auth/tokenexchange/exchange.go +++ b/pkg/auth/tokenexchange/exchange.go @@ -7,7 +7,6 @@ package tokenexchange import ( "context" "encoding/json" - "errors" "fmt" "io" "log/slog" @@ -67,34 +66,6 @@ func NormalizeTokenType(tokenType string) (string, error) { } } -// oAuthError represents an OAuth 2.0 error response as defined in RFC 6749 Section 5.2. -type oAuthError struct { - Error string `json:"error"` - ErrorDescription string `json:"error_description,omitempty"` - ErrorURI string `json:"error_uri,omitempty"` - StatusCode int `json:"-"` -} - -func (e *oAuthError) String() string { - if e.ErrorURI != "" { - return fmt.Sprintf("OAuth error %q (status %d): see %s", e.Error, e.StatusCode, e.ErrorURI) - } - return fmt.Sprintf("OAuth error %q (status %d)", e.Error, e.StatusCode) -} - -// parseOAuthError attempts to parse an OAuth error response from the given response body. -func parseOAuthError(statusCode int, body []byte) *oAuthError { - var oauthErr oAuthError - if err := json.Unmarshal(body, &oauthErr); err != nil { - return nil - } - if oauthErr.Error == "" { - return nil - } - oauthErr.StatusCode = statusCode - return &oauthErr -} - // defaultHTTPClient is the default HTTP client used for token exchange requests. var defaultHTTPClient = &http.Client{ Timeout: defaultHTTPTimeout, @@ -502,7 +473,7 @@ func executeTokenExchangeRequest(client *http.Client, req *http.Request) ([]byte return nil, fmt.Errorf("failed to read token exchange response: %w", err) } - if err := validateResponseStatus(resp.StatusCode, body); err != nil { + if err := validateResponseStatus(resp, body); err != nil { return nil, err } @@ -510,21 +481,44 @@ func executeTokenExchangeRequest(client *http.Client, req *http.Request) ([]byte } // validateResponseStatus checks the HTTP status code and returns an error if not successful. -func validateResponseStatus(statusCode int, body []byte) error { - if statusCode >= 200 && statusCode <= 299 { +// On non-2xx responses it extracts RFC 6749 §5.2 fields (error, error_description, error_uri) +// onto the structured fields of the returned *oauth2.RetrieveError. Body is always cleared so +// callers cannot interpolate raw upstream content into error strings — matching the pattern used +// by Ory Hydra, which never surfaces raw error bodies through its public error type. +func validateResponseStatus(resp *http.Response, body []byte) error { + if resp.StatusCode >= 200 && resp.StatusCode <= 299 { return nil } - // Try to parse as OAuth error first - if oauthErr := parseOAuthError(statusCode, body); oauthErr != nil { - //nolint:gosec // G706: OAuth error codes are standard protocol values, not user input - slog.Debug("Token exchange OAuth error", "oauth_error_code", oauthErr.Error, "description", oauthErr.ErrorDescription) - return errors.New(oauthErr.String()) + retrieveErr := &oauth2.RetrieveError{ + Response: resp, + Body: body, + } + + // Best-effort parse of the RFC 6749 Section 5.2 error response. Non-JSON or + // non-error-shaped bodies leave ErrorCode/ErrorDescription/ErrorURI empty. + var oauthErr struct { + Error string `json:"error"` + ErrorDescription string `json:"error_description,omitempty"` + ErrorURI string `json:"error_uri,omitempty"` } + if err := json.Unmarshal(body, &oauthErr); err == nil { + retrieveErr.ErrorCode = oauthErr.Error + retrieveErr.ErrorDescription = oauthErr.ErrorDescription + retrieveErr.ErrorURI = oauthErr.ErrorURI + } + + if retrieveErr.ErrorCode != "" { + slog.Debug("Token exchange OAuth error", + "oauth_error_code", retrieveErr.ErrorCode, + "description", retrieveErr.ErrorDescription) + } else { + slog.Debug("Token exchange failed", "status", resp.StatusCode, "body_length", len(body), "body", string(body)) + } + + retrieveErr.Body = nil - //nolint:gosec // G706: status code and body length are safe diagnostic values - slog.Debug("Token exchange failed", "status", statusCode, "body_length", len(body)) - return fmt.Errorf("token exchange failed with status %d", statusCode) + return retrieveErr } // parseTokenExchangeResponse parses the token exchange response body. @@ -532,7 +526,7 @@ func parseTokenExchangeResponse(body []byte) (*response, error) { var tokenResp response if err := json.Unmarshal(body, &tokenResp); err != nil { slog.Debug("Failed to parse token exchange response", "error", err) - return nil, errors.New("failed to parse token exchange response") + return nil, fmt.Errorf("failed to parse token exchange response: %w", err) } return &tokenResp, nil diff --git a/pkg/auth/tokenexchange/exchange_test.go b/pkg/auth/tokenexchange/exchange_test.go index 419731aba7..abc5a0fe01 100644 --- a/pkg/auth/tokenexchange/exchange_test.go +++ b/pkg/auth/tokenexchange/exchange_test.go @@ -371,40 +371,42 @@ func TestExchangeToken_HTTPErrorResponses(t *testing.T) { t.Parallel() tests := []struct { - name string - statusCode int - responseBody string - expectedError string + name string + statusCode int + responseBody string + expectedErrorCode string + expectedDescription string }{ { - name: "400 Bad Request", - statusCode: http.StatusBadRequest, - responseBody: `{"error":"invalid_request","error_description":"Missing required parameter"}`, - expectedError: "OAuth error \"invalid_request\" (status 400)", + name: "400 Bad Request", + statusCode: http.StatusBadRequest, + responseBody: `{"error":"invalid_request","error_description":"Missing required parameter"}`, + expectedErrorCode: "invalid_request", + expectedDescription: "Missing required parameter", }, { - name: "401 Unauthorized", - statusCode: http.StatusUnauthorized, - responseBody: `{"error":"invalid_client"}`, - expectedError: "OAuth error \"invalid_client\" (status 401)", + name: "401 Unauthorized", + statusCode: http.StatusUnauthorized, + responseBody: `{"error":"invalid_client"}`, + expectedErrorCode: "invalid_client", }, { - name: "403 Forbidden", - statusCode: http.StatusForbidden, - responseBody: `{"error":"access_denied"}`, - expectedError: "OAuth error \"access_denied\" (status 403)", + name: "403 Forbidden", + statusCode: http.StatusForbidden, + responseBody: `{"error":"access_denied"}`, + expectedErrorCode: "access_denied", }, { - name: "500 Internal Server Error", - statusCode: http.StatusInternalServerError, - responseBody: `{"error":"server_error"}`, - expectedError: "OAuth error \"server_error\" (status 500)", + name: "500 Internal Server Error", + statusCode: http.StatusInternalServerError, + responseBody: `{"error":"server_error"}`, + expectedErrorCode: "server_error", }, { - name: "503 Service Unavailable", - statusCode: http.StatusServiceUnavailable, - responseBody: "Service temporarily unavailable", - expectedError: "token exchange failed with status 503", + name: "503 Service Unavailable", + statusCode: http.StatusServiceUnavailable, + responseBody: "Service temporarily unavailable", + // Non-JSON body: ErrorCode stays empty, body cleared to prevent info leak. }, } @@ -434,7 +436,14 @@ func TestExchangeToken_HTTPErrorResponses(t *testing.T) { require.Error(t, err) assert.Nil(t, resp) - assert.Contains(t, err.Error(), tt.expectedError) + + var retrieveErr *oauth2.RetrieveError + require.ErrorAs(t, err, &retrieveErr) + require.NotNil(t, retrieveErr.Response) + assert.Equal(t, tt.statusCode, retrieveErr.Response.StatusCode) + assert.Equal(t, tt.expectedErrorCode, retrieveErr.ErrorCode) + assert.Equal(t, tt.expectedDescription, retrieveErr.ErrorDescription) + assert.Nil(t, retrieveErr.Body) }) } }