Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 35 additions & 41 deletions pkg/auth/tokenexchange/exchange.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ package tokenexchange
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
Expand Down Expand Up @@ -67,34 +66,6 @@ func NormalizeTokenType(tokenType string) (string, error) {
}
}

// oAuthError represents an OAuth 2.0 error response as defined in RFC 6749 Section 5.2.
type oAuthError struct {
Error string `json:"error"`
ErrorDescription string `json:"error_description,omitempty"`
ErrorURI string `json:"error_uri,omitempty"`
StatusCode int `json:"-"`
}

func (e *oAuthError) String() string {
if e.ErrorURI != "" {
return fmt.Sprintf("OAuth error %q (status %d): see %s", e.Error, e.StatusCode, e.ErrorURI)
}
return fmt.Sprintf("OAuth error %q (status %d)", e.Error, e.StatusCode)
}

// parseOAuthError attempts to parse an OAuth error response from the given response body.
func parseOAuthError(statusCode int, body []byte) *oAuthError {
var oauthErr oAuthError
if err := json.Unmarshal(body, &oauthErr); err != nil {
return nil
}
if oauthErr.Error == "" {
return nil
}
oauthErr.StatusCode = statusCode
return &oauthErr
}

// defaultHTTPClient is the default HTTP client used for token exchange requests.
var defaultHTTPClient = &http.Client{
Timeout: defaultHTTPTimeout,
Expand Down Expand Up @@ -502,37 +473,60 @@ func executeTokenExchangeRequest(client *http.Client, req *http.Request) ([]byte
return nil, fmt.Errorf("failed to read token exchange response: %w", err)
}

if err := validateResponseStatus(resp.StatusCode, body); err != nil {
if err := validateResponseStatus(resp, body); err != nil {
return nil, err
}

return body, nil
}

// validateResponseStatus checks the HTTP status code and returns an error if not successful.
func validateResponseStatus(statusCode int, body []byte) error {
if statusCode >= 200 && statusCode <= 299 {
// On non-2xx responses it extracts RFC 6749 §5.2 fields (error, error_description, error_uri)
// onto the structured fields of the returned *oauth2.RetrieveError. Body is always cleared so
// callers cannot interpolate raw upstream content into error strings — matching the pattern used
// by Ory Hydra, which never surfaces raw error bodies through its public error type.
func validateResponseStatus(resp *http.Response, body []byte) error {
if resp.StatusCode >= 200 && resp.StatusCode <= 299 {
return nil
}

// Try to parse as OAuth error first
if oauthErr := parseOAuthError(statusCode, body); oauthErr != nil {
//nolint:gosec // G706: OAuth error codes are standard protocol values, not user input
slog.Debug("Token exchange OAuth error", "oauth_error_code", oauthErr.Error, "description", oauthErr.ErrorDescription)
return errors.New(oauthErr.String())
retrieveErr := &oauth2.RetrieveError{
Response: resp,
Body: body,
}

// Best-effort parse of the RFC 6749 Section 5.2 error response. Non-JSON or
// non-error-shaped bodies leave ErrorCode/ErrorDescription/ErrorURI empty.
var oauthErr struct {
Error string `json:"error"`
ErrorDescription string `json:"error_description,omitempty"`
ErrorURI string `json:"error_uri,omitempty"`
}
if err := json.Unmarshal(body, &oauthErr); err == nil {
retrieveErr.ErrorCode = oauthErr.Error
retrieveErr.ErrorDescription = oauthErr.ErrorDescription
retrieveErr.ErrorURI = oauthErr.ErrorURI
}

if retrieveErr.ErrorCode != "" {
slog.Debug("Token exchange OAuth error",
"oauth_error_code", retrieveErr.ErrorCode,
"description", retrieveErr.ErrorDescription)
} else {
slog.Debug("Token exchange failed", "status", resp.StatusCode, "body_length", len(body), "body", string(body))
}

retrieveErr.Body = nil

//nolint:gosec // G706: status code and body length are safe diagnostic values
slog.Debug("Token exchange failed", "status", statusCode, "body_length", len(body))
return fmt.Errorf("token exchange failed with status %d", statusCode)
return retrieveErr
}

// parseTokenExchangeResponse parses the token exchange response body.
func parseTokenExchangeResponse(body []byte) (*response, error) {
var tokenResp response
if err := json.Unmarshal(body, &tokenResp); err != nil {
slog.Debug("Failed to parse token exchange response", "error", err)
return nil, errors.New("failed to parse token exchange response")
return nil, fmt.Errorf("failed to parse token exchange response: %w", err)
}

return &tokenResp, nil
Expand Down
59 changes: 34 additions & 25 deletions pkg/auth/tokenexchange/exchange_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,40 +371,42 @@ func TestExchangeToken_HTTPErrorResponses(t *testing.T) {
t.Parallel()

tests := []struct {
name string
statusCode int
responseBody string
expectedError string
name string
statusCode int
responseBody string
expectedErrorCode string
expectedDescription string
}{
{
name: "400 Bad Request",
statusCode: http.StatusBadRequest,
responseBody: `{"error":"invalid_request","error_description":"Missing required parameter"}`,
expectedError: "OAuth error \"invalid_request\" (status 400)",
name: "400 Bad Request",
statusCode: http.StatusBadRequest,
responseBody: `{"error":"invalid_request","error_description":"Missing required parameter"}`,
expectedErrorCode: "invalid_request",
expectedDescription: "Missing required parameter",
},
{
name: "401 Unauthorized",
statusCode: http.StatusUnauthorized,
responseBody: `{"error":"invalid_client"}`,
expectedError: "OAuth error \"invalid_client\" (status 401)",
name: "401 Unauthorized",
statusCode: http.StatusUnauthorized,
responseBody: `{"error":"invalid_client"}`,
expectedErrorCode: "invalid_client",
},
{
name: "403 Forbidden",
statusCode: http.StatusForbidden,
responseBody: `{"error":"access_denied"}`,
expectedError: "OAuth error \"access_denied\" (status 403)",
name: "403 Forbidden",
statusCode: http.StatusForbidden,
responseBody: `{"error":"access_denied"}`,
expectedErrorCode: "access_denied",
},
{
name: "500 Internal Server Error",
statusCode: http.StatusInternalServerError,
responseBody: `{"error":"server_error"}`,
expectedError: "OAuth error \"server_error\" (status 500)",
name: "500 Internal Server Error",
statusCode: http.StatusInternalServerError,
responseBody: `{"error":"server_error"}`,
expectedErrorCode: "server_error",
},
{
name: "503 Service Unavailable",
statusCode: http.StatusServiceUnavailable,
responseBody: "Service temporarily unavailable",
expectedError: "token exchange failed with status 503",
name: "503 Service Unavailable",
statusCode: http.StatusServiceUnavailable,
responseBody: "Service temporarily unavailable",
// Non-JSON body: ErrorCode stays empty, body cleared to prevent info leak.
},
}

Expand Down Expand Up @@ -434,7 +436,14 @@ func TestExchangeToken_HTTPErrorResponses(t *testing.T) {

require.Error(t, err)
assert.Nil(t, resp)
assert.Contains(t, err.Error(), tt.expectedError)

var retrieveErr *oauth2.RetrieveError
require.ErrorAs(t, err, &retrieveErr)
require.NotNil(t, retrieveErr.Response)
assert.Equal(t, tt.statusCode, retrieveErr.Response.StatusCode)
assert.Equal(t, tt.expectedErrorCode, retrieveErr.ErrorCode)
assert.Equal(t, tt.expectedDescription, retrieveErr.ErrorDescription)
assert.Nil(t, retrieveErr.Body)
})
}
}
Expand Down
Loading