diff --git a/processor.go b/processor.go index 28d1b5c..aa7f435 100644 --- a/processor.go +++ b/processor.go @@ -375,6 +375,7 @@ func (c *OAuthBodyProcessorConfig) StripHazmat() ProcessorConfig { type Sigv4ProcessorConfig struct { AccessKey string `json:"access_key"` SecretKey string `json:"secret_key"` + NoSwap bool `json:"no_swap"` // Bug compat, when false region/service get swapped. } var _ ProcessorConfig = (*Sigv4ProcessorConfig)(nil) @@ -424,12 +425,12 @@ func (c *Sigv4ProcessorConfig) Processor(params map[string]string) (RequestProce return err } - // FIXME: These are swapped. SigV4 credential format is - // AKID/date/region/service/aws4_request, so credParts[2] is region - // and credParts[3] is service. Left as-is to avoid breaking existing - // clients that may depend on this behavior. - service = credParts[2] - region = credParts[3] + region = credParts[2] + service = credParts[3] + if !c.NoSwap { + // Bug compat: swap service and region when NoSwap is not set. + region, service = service, region + } break } } diff --git a/processor_test.go b/processor_test.go index 48cb03d..bfcac5d 100644 --- a/processor_test.go +++ b/processor_test.go @@ -2,6 +2,7 @@ package tokenizer import ( "bytes" + "context" "io" "net/http" "strings" @@ -388,3 +389,68 @@ func TestOAuthBodyProcessorConfig(t *testing.T) { func stringToReadCloser(s string) io.ReadCloser { return io.NopCloser(strings.NewReader(s)) } + +func TestSigv4Processor(t *testing.T) { + parseAuthHeader := func(s string) (string, string, string, string) { + // AWS4-HMAC-SHA256 Credential=AccessKey/20260304/service/region/aws4_request, SignedHeaders=host;x-amz-date, Signature=674b77f7c09adf9becb2eb1c70183bb4e828330063b8b1577dfc64001074ea3d + var accessKey, dateStr, region, service string + words := strings.Split(strings.TrimPrefix(s, "AWS4-HMAC-SHA256 "), ", ") + for _, word := range words { + kv := strings.SplitN(word, "=", 2) + if len(kv) == 2 { + if kv[0] == "Credential" { + parts := strings.Split(kv[1], "/") + if len(parts) == 5 { + accessKey = parts[0] + dateStr = parts[1] + region = parts[2] + service = parts[3] + } + } + } + } + return accessKey, dateStr, region, service + } + + // Build our base request and Verify that we're parsing out the fields in the correct order... + r, err := http.NewRequest(http.MethodGet, "https://www.test.com/path", http.NoBody) + assert.NoError(t, err) + r.Header.Set("Authorization", "AWS4-HMAC-SHA256 Credential=AccessKey/20260304/region/service/aws4_request, SignedHeaders=host;x-amz-date, Signature=792ff6fc18a10db6bc1ae6a8f9ffb0caf1432da8d2008e33203a9a0434cd9127") + akey, d, reg, svc := parseAuthHeader(r.Header.Get("Authorization")) + assert.Equal(t, akey, "AccessKey") + assert.Equal(t, d, "20260304") + assert.Equal(t, reg, "region") + assert.Equal(t, svc, "service") + + // The processor will swap region/service when NoSwap is false for bug compat. + cfgSwap := &Sigv4ProcessorConfig{"AccessKey", "SecretKey", false} + processSwap, err := cfgSwap.Processor(nil) + assert.NoError(t, err) + + req := r.Clone(context.Background()) + //fmt.Printf("Before: (swap) %v\n", req.Header.Get("Authorization")) + err = processSwap(req) + assert.NoError(t, err) + //fmt.Printf("After: (swap) %v\n", req.Header.Get("Authorization")) + akey, d, reg, svc = parseAuthHeader(req.Header.Get("Authorization")) + assert.Equal(t, akey, "AccessKey") + assert.Equal(t, d, "20260304") + assert.Equal(t, reg, "service") // swapped! + assert.Equal(t, svc, "region") // swapped! + + // The processor will not swap region/service when NoSwap is true. + cfg := &Sigv4ProcessorConfig{"AccessKey", "SecretKey", true} + process, err := cfg.Processor(nil) + assert.NoError(t, err) + + req = r.Clone(context.Background()) + //fmt.Printf("Before: (noswap) %v\n", req.Header.Get("Authorization")) + err = process(req) + assert.NoError(t, err) + //fmt.Printf("After: (noswap) %v\n", req.Header.Get("Authorization")) + akey, d, reg, svc = parseAuthHeader(req.Header.Get("Authorization")) + assert.Equal(t, akey, "AccessKey") + assert.Equal(t, d, "20260304") + assert.Equal(t, reg, "region") + assert.Equal(t, svc, "service") +}