diff --git a/internal/envconfig/envconfig.go b/internal/envconfig/envconfig.go index 9bc13c01c2d9..cc5713fd9d5e 100644 --- a/internal/envconfig/envconfig.go +++ b/internal/envconfig/envconfig.go @@ -55,6 +55,20 @@ var ( // setting the environment variable "GRPC_EXPERIMENTAL_ENABLE_NEW_PICK_FIRST" // to "false". NewPickFirstEnabled = boolFromEnv("GRPC_EXPERIMENTAL_ENABLE_NEW_PICK_FIRST", true) + + // XDSEndpointHashKeyBackwardCompat controls the parsing of the endpoint hash + // key from EDS LbEndpoint metadata. Endpoint hash keys can be disabled by + // setting "GRPC_XDS_ENDPOINT_HASH_KEY_BACKWARD_COMPAT" to "true". When the + // implementation of A76 is stable, we will flip the default value to false + // in a subsequent release. A final release will remove this environment + // variable, enabling the new behavior unconditionally. + XDSEndpointHashKeyBackwardCompat = boolFromEnv("GRPC_XDS_ENDPOINT_HASH_KEY_BACKWARD_COMPAT", true) + + // RingHashSetRequestHashKey is set if the ring hash balancer can get the + // request hash header by setting the "requestHashHeader" field, according + // to gRFC A76. It can be enabled by setting the environment variable + // "GRPC_EXPERIMENTAL_RING_HASH_SET_REQUEST_HASH_KEY" to "true". + RingHashSetRequestHashKey = boolFromEnv("GRPC_EXPERIMENTAL_RING_HASH_SET_REQUEST_HASH_KEY", false) ) func boolFromEnv(envVar string, def bool) bool { diff --git a/internal/metadata/metadata.go b/internal/metadata/metadata.go index 900bfb716080..c4055bc00e51 100644 --- a/internal/metadata/metadata.go +++ b/internal/metadata/metadata.go @@ -97,13 +97,11 @@ func hasNotPrintable(msg string) bool { return false } -// ValidatePair validate a key-value pair with the following rules (the pseudo-header will be skipped) : -// -// - key must contain one or more characters. -// - the characters in the key must be contained in [0-9 a-z _ - .]. -// - if the key ends with a "-bin" suffix, no validation of the corresponding value is performed. 
-// - the characters in the every value must be printable (in [%x20-%x7E]). -func ValidatePair(key string, vals ...string) error { +// ValidateKey validates a key with the following rules (pseudo-headers are +// skipped): +// - the key must contain one or more characters. +// - the characters in the key must be in [0-9 a-z _ - .]. +func ValidateKey(key string) error { // key should not be empty if key == "" { return fmt.Errorf("there is an empty key in the header") @@ -119,6 +117,20 @@ func ValidatePair(key string, vals ...string) error { return fmt.Errorf("header key %q contains illegal characters not in [0-9a-z-_.]", key) } } + return nil +} + +// ValidatePair validates a key-value pair with the following rules +// (pseudo-headers are skipped): +// - the key must contain one or more characters. +// - the characters in the key must be in [0-9 a-z _ - .]. +// - if the key ends with a "-bin" suffix, no validation of the corresponding +// value is performed. +// - the characters in every value must be printable (in [%x20-%x7E]). +func ValidatePair(key string, vals ...string) error { + if err := ValidateKey(key); err != nil { + return err + } if strings.HasSuffix(key, "-bin") { return nil } diff --git a/internal/testutils/envconfig.go b/internal/testutils/envconfig.go new file mode 100644 index 000000000000..3b25aacf5ca7 --- /dev/null +++ b/internal/testutils/envconfig.go @@ -0,0 +1,33 @@ +package testutils + +/* + * + * Copyright 2025 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import ( + "testing" +) + +// SetEnvConfig sets the value of the given variable to the specified value, +// taking care of restoring the original value after the test completes. +func SetEnvConfig[T any](t *testing.T, variable *T, value T) { + t.Helper() + old := *variable + t.Cleanup(func() { + *variable = old + }) + *variable = value +} diff --git a/internal/testutils/xds/e2e/clientresources.go b/internal/testutils/xds/e2e/clientresources.go index e94d226f8663..b8ebf8ba5266 100644 --- a/internal/testutils/xds/e2e/clientresources.go +++ b/internal/testutils/xds/e2e/clientresources.go @@ -26,6 +26,7 @@ import ( "github.com/envoyproxy/go-control-plane/pkg/wellknown" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/types/known/structpb" "google.golang.org/protobuf/types/known/wrapperspb" v3clusterpb "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3" @@ -649,6 +650,9 @@ type BackendOptions struct { HealthStatus v3corepb.HealthStatus // Weight sets the backend weight. Defaults to 1. Weight uint32 + // Metadata sets the LB endpoint metadata (envoy.lb FilterMetadata field). 
+ // See https://www.envoyproxy.io/docs/envoy/latest/api-v3/config/core/v3/base.proto#envoy-v3-api-msg-config-core-v3-metadata + Metadata map[string]any } // EndpointOptions contains options to configure an Endpoint (or @@ -708,6 +712,10 @@ func EndpointResourceWithOptions(opts EndpointOptions) *v3endpointpb.ClusterLoad }, } } + metadata, err := structpb.NewStruct(b.Metadata) + if err != nil { + panic(err) + } lbEndpoints = append(lbEndpoints, &v3endpointpb.LbEndpoint{ HostIdentifier: &v3endpointpb.LbEndpoint_Endpoint{Endpoint: &v3endpointpb.Endpoint{ Address: &v3corepb.Address{Address: &v3corepb.Address_SocketAddress{ @@ -721,6 +729,11 @@ func EndpointResourceWithOptions(opts EndpointOptions) *v3endpointpb.ClusterLoad }}, HealthStatus: b.HealthStatus, LoadBalancingWeight: &wrapperspb.UInt32Value{Value: b.Weight}, + Metadata: &v3corepb.Metadata{ + FilterMetadata: map[string]*structpb.Struct{ + "envoy.lb": metadata, + }, + }, }) } diff --git a/resolver/ringhash/attr.go b/resolver/ringhash/attr.go new file mode 100644 index 000000000000..154f02307799 --- /dev/null +++ b/resolver/ringhash/attr.go @@ -0,0 +1,60 @@ +/* + * + * Copyright 2025 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package ringhash implements resolver related functions for the ring_hash +// load balancing policy. 
+package ringhash + +import ( + "google.golang.org/grpc/resolver" +) + +type hashKeyType string + +// hashKeyKey is the key to store the ring hash key attribute in +// a resolver.Endpoint attribute. +const hashKeyKey = hashKeyType("grpc.resolver.ringhash.hash_key") + +// SetHashKey sets the hash key for this endpoint. Combined with the ring_hash +// load balancing policy, it allows placing the endpoint on the ring based on an +// arbitrary string instead of the IP address. If hashKey is empty, the endpoint +// is returned unmodified. +// +// # Experimental +// +// Notice: This API is EXPERIMENTAL and may be changed or removed in a +// later release. +func SetHashKey(endpoint resolver.Endpoint, hashKey string) resolver.Endpoint { + if hashKey == "" { + return endpoint + } + endpoint.Attributes = endpoint.Attributes.WithValue(hashKeyKey, hashKey) + return endpoint +} + +// HashKey returns the hash key attribute of endpoint. If this attribute is +// not set, it returns the empty string. +// +// # Experimental +// +// Notice: This API is EXPERIMENTAL and may be changed or removed in a +// later release. 
+func HashKey(endpoint resolver.Endpoint) string { + hashKey, _ := endpoint.Attributes.Value(hashKeyKey).(string) + return hashKey +} diff --git a/xds/internal/balancer/clusterresolver/configbuilder.go b/xds/internal/balancer/clusterresolver/configbuilder.go index 28313a90cd3f..fb2bc7629185 100644 --- a/xds/internal/balancer/clusterresolver/configbuilder.go +++ b/xds/internal/balancer/clusterresolver/configbuilder.go @@ -27,6 +27,7 @@ import ( "google.golang.org/grpc/internal/hierarchy" internalserviceconfig "google.golang.org/grpc/internal/serviceconfig" "google.golang.org/grpc/resolver" + "google.golang.org/grpc/resolver/ringhash" "google.golang.org/grpc/xds/internal" "google.golang.org/grpc/xds/internal/balancer/clusterimpl" "google.golang.org/grpc/xds/internal/balancer/outlierdetection" @@ -284,6 +285,7 @@ func priorityLocalitiesToClusterImpl(localities []xdsresource.Locality, priority ew = endpoint.Weight } resolverEndpoint = weight.Set(resolverEndpoint, weight.EndpointInfo{Weight: lw * ew}) + resolverEndpoint = ringhash.SetHashKey(resolverEndpoint, endpoint.HashKey) retEndpoints = append(retEndpoints, resolverEndpoint) } } diff --git a/xds/internal/balancer/ringhash/config.go b/xds/internal/balancer/ringhash/config.go index b4afcf100132..eaa6ca233d45 100644 --- a/xds/internal/balancer/ringhash/config.go +++ b/xds/internal/balancer/ringhash/config.go @@ -21,8 +21,10 @@ package ringhash import ( "encoding/json" "fmt" + "strings" "google.golang.org/grpc/internal/envconfig" + "google.golang.org/grpc/internal/metadata" "google.golang.org/grpc/serviceconfig" ) @@ -30,8 +32,9 @@ import ( type LBConfig struct { serviceconfig.LoadBalancingConfig `json:"-"` - MinRingSize uint64 `json:"minRingSize,omitempty"` - MaxRingSize uint64 `json:"maxRingSize,omitempty"` + MinRingSize uint64 `json:"minRingSize,omitempty"` + MaxRingSize uint64 `json:"maxRingSize,omitempty"` + RequestHashHeader string `json:"requestHashHeader,omitempty"` } const ( @@ -66,5 +69,17 @@ func 
parseConfig(c json.RawMessage) (*LBConfig, error) { if cfg.MaxRingSize > envconfig.RingHashCap { cfg.MaxRingSize = envconfig.RingHashCap } + if !envconfig.RingHashSetRequestHashKey { + cfg.RequestHashHeader = "" + } + if cfg.RequestHashHeader != "" { + // See rules in https://github.com/grpc/proposal/blob/master/A76-ring-hash-improvements.md#explicitly-setting-the-request-hash-key + if err := metadata.ValidateKey(cfg.RequestHashHeader); err != nil { + return nil, fmt.Errorf("invalid requestHashHeader %q: %v", cfg.RequestHashHeader, err) + } + if strings.HasSuffix(cfg.RequestHashHeader, "-bin") { + return nil, fmt.Errorf("invalid requestHashHeader %q: key must not end with \"-bin\"", cfg.RequestHashHeader) + } + } return &cfg, nil } diff --git a/xds/internal/balancer/ringhash/config_test.go b/xds/internal/balancer/ringhash/config_test.go index 1077d3e7dafb..9588a8984c6e 100644 --- a/xds/internal/balancer/ringhash/config_test.go +++ b/xds/internal/balancer/ringhash/config_test.go @@ -19,90 +19,136 @@ package ringhash import ( + "encoding/json" "testing" "github.com/google/go-cmp/cmp" "google.golang.org/grpc/internal/envconfig" + "google.golang.org/grpc/internal/testutils" ) func (s) TestParseConfig(t *testing.T) { tests := []struct { - name string - js string - envConfigCap uint64 - want *LBConfig - wantErr bool + name string + js string + envConfigCap uint64 + requestHeaderEnvVar bool + want *LBConfig + wantErr bool }{ { - name: "OK", - js: `{"minRingSize": 1, "maxRingSize": 2}`, - want: &LBConfig{MinRingSize: 1, MaxRingSize: 2}, + name: "OK", + js: `{"minRingSize": 1, "maxRingSize": 2}`, + requestHeaderEnvVar: true, + want: &LBConfig{MinRingSize: 1, MaxRingSize: 2}, }, { - name: "OK with default min", - js: `{"maxRingSize": 2000}`, - want: &LBConfig{MinRingSize: defaultMinSize, MaxRingSize: 2000}, + name: "OK with default min", + js: `{"maxRingSize": 2000}`, + requestHeaderEnvVar: true, + want: &LBConfig{MinRingSize: defaultMinSize, MaxRingSize: 2000}, }, { - name: 
"OK with default max", - js: `{"minRingSize": 2000}`, - want: &LBConfig{MinRingSize: 2000, MaxRingSize: defaultMaxSize}, + name: "OK with default max", + js: `{"minRingSize": 2000}`, + requestHeaderEnvVar: true, + want: &LBConfig{MinRingSize: 2000, MaxRingSize: defaultMaxSize}, }, { - name: "min greater than max", - js: `{"minRingSize": 10, "maxRingSize": 2}`, - want: nil, - wantErr: true, + name: "min greater than max", + js: `{"minRingSize": 10, "maxRingSize": 2}`, + requestHeaderEnvVar: true, + want: nil, + wantErr: true, }, { - name: "min greater than max greater than global limit", - js: `{"minRingSize": 6000, "maxRingSize": 5000}`, - want: nil, - wantErr: true, + name: "min greater than max greater than global limit", + js: `{"minRingSize": 6000, "maxRingSize": 5000}`, + requestHeaderEnvVar: true, + want: nil, + wantErr: true, }, { - name: "max greater than global limit", - js: `{"minRingSize": 1, "maxRingSize": 6000}`, - want: &LBConfig{MinRingSize: 1, MaxRingSize: 4096}, + name: "max greater than global limit", + js: `{"minRingSize": 1, "maxRingSize": 6000}`, + requestHeaderEnvVar: true, + want: &LBConfig{MinRingSize: 1, MaxRingSize: 4096}, }, { - name: "min and max greater than global limit", - js: `{"minRingSize": 5000, "maxRingSize": 6000}`, - want: &LBConfig{MinRingSize: 4096, MaxRingSize: 4096}, + name: "min and max greater than global limit", + js: `{"minRingSize": 5000, "maxRingSize": 6000}`, + requestHeaderEnvVar: true, + want: &LBConfig{MinRingSize: 4096, MaxRingSize: 4096}, }, { - name: "min and max less than raised global limit", - js: `{"minRingSize": 5000, "maxRingSize": 6000}`, - envConfigCap: 8000, - want: &LBConfig{MinRingSize: 5000, MaxRingSize: 6000}, + name: "min and max less than raised global limit", + js: `{"minRingSize": 5000, "maxRingSize": 6000}`, + envConfigCap: 8000, + requestHeaderEnvVar: true, + want: &LBConfig{MinRingSize: 5000, MaxRingSize: 6000}, }, { - name: "min and max greater than raised global limit", - js: 
`{"minRingSize": 10000, "maxRingSize": 10000}`, - envConfigCap: 8000, - want: &LBConfig{MinRingSize: 8000, MaxRingSize: 8000}, + name: "min and max greater than raised global limit", + js: `{"minRingSize": 10000, "maxRingSize": 10000}`, + envConfigCap: 8000, + requestHeaderEnvVar: true, + want: &LBConfig{MinRingSize: 8000, MaxRingSize: 8000}, }, { - name: "min greater than upper bound", - js: `{"minRingSize": 8388610, "maxRingSize": 10}`, - want: nil, - wantErr: true, + name: "min greater than upper bound", + js: `{"minRingSize": 8388610, "maxRingSize": 10}`, + requestHeaderEnvVar: true, + want: nil, + wantErr: true, }, { - name: "max greater than upper bound", - js: `{"minRingSize": 10, "maxRingSize": 8388610}`, - want: nil, - wantErr: true, + name: "max greater than upper bound", + js: `{"minRingSize": 10, "maxRingSize": 8388610}`, + requestHeaderEnvVar: true, + want: nil, + wantErr: true, + }, + { + name: "request metadata key set", + js: `{"requestHashHeader": "x-foo"}`, + requestHeaderEnvVar: true, + want: &LBConfig{ + MinRingSize: defaultMinSize, + MaxRingSize: defaultMaxSize, + RequestHashHeader: "x-foo", + }, + }, + { + name: "invalid request hash header", + js: `{"requestHashHeader": "!invalid"}`, + requestHeaderEnvVar: true, + want: nil, + wantErr: true, + }, + { + name: "binary request hash header", + js: `{"requestHashHeader": "header-with-bin"}`, + requestHeaderEnvVar: true, + want: nil, + wantErr: true, + }, + { + name: "request hash header cleared when RingHashSetRequestHashKey env var is false", + js: `{"requestHashHeader": "x-foo"}`, + requestHeaderEnvVar: false, + want: &LBConfig{ + MinRingSize: defaultMinSize, + MaxRingSize: defaultMaxSize, + }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if tt.envConfigCap != 0 { - old := envconfig.RingHashCap - defer func() { envconfig.RingHashCap = old }() - envconfig.RingHashCap = tt.envConfigCap + testutils.SetEnvConfig(t, &envconfig.RingHashCap, tt.envConfigCap) } - got, err := 
parseConfig([]byte(tt.js)) + testutils.SetEnvConfig(t, &envconfig.RingHashSetRequestHashKey, tt.requestHeaderEnvVar) + got, err := parseConfig(json.RawMessage(tt.js)) if (err != nil) != tt.wantErr { t.Errorf("parseConfig() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/xds/internal/balancer/ringhash/e2e/ringhash_balancer_test.go b/xds/internal/balancer/ringhash/e2e/ringhash_balancer_test.go index 7641f9178179..502c355a759a 100644 --- a/xds/internal/balancer/ringhash/e2e/ringhash_balancer_test.go +++ b/xds/internal/balancer/ringhash/e2e/ringhash_balancer_test.go @@ -26,6 +26,8 @@ import ( rand "math/rand/v2" "net" "slices" + "strconv" + "sync" "testing" "time" @@ -48,6 +50,7 @@ import ( "google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver/manual" "google.golang.org/grpc/status" + "google.golang.org/grpc/xds/internal/balancer/ringhash" v3clusterpb "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3" v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" @@ -123,9 +126,10 @@ func (s) TestRingHash_ReconnectToMoveOutOfTransientFailure(t *testing.T) { defer cc.Close() // Push the address of the test backend through the manual resolver. 
- r.InitialState(resolver.State{Addresses: []resolver.Address{{Addr: lis.Addr().String()}}}) + r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: lis.Addr().String()}}}) ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + ctx = ringhash.SetXDSRequestHash(ctx, 0) defer cancel() client := testgrpc.NewTestServiceClient(cc) if _, err := client.EmptyCall(ctx, &testpb.Empty{}); err != nil { @@ -469,7 +473,7 @@ func (s) TestRingHash_AggregateClusterFallBackFromRingHashToLogicalDnsAtStartup( } dnsR := replaceDNSResolver(t) - dnsR.InitialState(resolver.State{Addresses: []resolver.Address{{Addr: backends[0]}}}) + dnsR.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: backends[0]}}}) if err := xdsServer.Update(ctx, updateOpts); err != nil { t.Fatalf("Failed to update xDS resources: %v", err) @@ -547,7 +551,7 @@ func (s) TestRingHash_AggregateClusterFallBackFromRingHashToLogicalDnsAtStartupN } dnsR := replaceDNSResolver(t) - dnsR.InitialState(resolver.State{Addresses: []resolver.Address{{Addr: backends[0]}}}) + dnsR.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: backends[0]}}}) if err := xdsServer.Update(ctx, updateOpts); err != nil { t.Fatalf("Failed to update xDS resources: %v", err) @@ -2542,3 +2546,373 @@ func (s) TestRingHash_RecoverWhenResolverRemovesEndpoint(t *testing.T) { // Wait for channel to become READY without any pending RPC. testutils.AwaitState(ctx, t, conn, connectivity.Ready) } + +// Tests that RPCs are routed according to endpoint hash key rather than +// endpoint first address if it is set in EDS endpoint metadata. 
+func (s) TestRingHash_EndpointHashKey(t *testing.T) { + testutils.SetEnvConfig(t, &envconfig.XDSEndpointHashKeyBackwardCompat, false) + + backends := backendAddrs(startTestServiceBackends(t, 4)) + + const clusterName = "cluster" + var backendOpts []e2e.BackendOptions + for i, addr := range backends { + var ports []uint32 + ports = append(ports, testutils.ParsePort(t, addr)) + backendOpts = append(backendOpts, e2e.BackendOptions{ + Ports: ports, + Metadata: map[string]any{"hash_key": strconv.Itoa(i)}, + }) + } + endpoints := e2e.EndpointResourceWithOptions(e2e.EndpointOptions{ + ClusterName: clusterName, + Host: "localhost", + Localities: []e2e.LocalityOptions{{ + Backends: backendOpts, + Weight: 1, + }}, + }) + cluster := e2e.ClusterResourceWithOptions(e2e.ClusterOptions{ + ClusterName: clusterName, + ServiceName: clusterName, + Policy: e2e.LoadBalancingPolicyRingHash, + }) + route := headerHashRoute("new_route", virtualHostName, clusterName, "address_hash") + listener := e2e.DefaultClientListener(virtualHostName, route.Name) + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + xdsServer, nodeID, xdsResolver := setupManagementServerAndResolver(t) + if err := xdsServer.Update(ctx, xdsUpdateOpts(nodeID, endpoints, cluster, route, listener)); err != nil { + t.Fatalf("Failed to update xDS resources: %v", err) + } + + opts := []grpc.DialOption{ + grpc.WithResolvers(xdsResolver), + grpc.WithTransportCredentials(insecure.NewCredentials()), + } + conn, err := grpc.NewClient("xds:///test.server", opts...) + if err != nil { + t.Fatalf("Failed to create client: %s", err) + } + defer conn.Close() + client := testgrpc.NewTestServiceClient(conn) + + // Make sure RPCs are routed to backends according to the endpoint metadata + // rather than their address. Note each type of RPC contains a header value + // that will always be hashed to a specific backend as the header value + // matches the endpoint metadata hash key. 
+ for i, backend := range backends { + ctx := metadata.NewOutgoingContext(ctx, metadata.Pairs("address_hash", strconv.Itoa(i)+"_0")) + numRPCs := 10 + reqPerBackend := checkRPCSendOK(ctx, t, client, numRPCs) + if reqPerBackend[backend] != numRPCs { + t.Errorf("Got RPC routed to addresses %v, want all RPCs routed to %v", reqPerBackend, backend) + } + } + + // Update the endpoints to swap the metadata hash key. + for i := range backendOpts { + backendOpts[i].Metadata = map[string]any{"hash_key": strconv.Itoa(len(backends) - i - 1)} + } + endpoints = e2e.EndpointResourceWithOptions(e2e.EndpointOptions{ + ClusterName: clusterName, + Host: "localhost", + Localities: []e2e.LocalityOptions{{ + Backends: backendOpts, + Weight: 1, + }}, + }) + if err := xdsServer.Update(ctx, xdsUpdateOpts(nodeID, endpoints, cluster, route, listener)); err != nil { + t.Fatalf("Failed to update xDS resources: %v", err) + } + + // Wait for the resolver update to make it to the balancer. This RPC should + // be routed to backend 3 with the reverse numbering of the hash_key + // attribute delivered above. + for { + ctx := metadata.NewOutgoingContext(ctx, metadata.Pairs("address_hash", "0_0")) + var remote peer.Peer + if _, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.Peer(&remote)); err != nil { + t.Fatalf("Unexpected RPC error waiting for EDS update propagation: %s", err) + } + if remote.Addr.String() == backends[3] { + break + } + } + + // Now that the balancer has the new endpoint attributes, make sure RPCs are + // routed to backends according to the new endpoint metadata. 
+ for i, backend := range backends { + ctx := metadata.NewOutgoingContext(ctx, metadata.Pairs("address_hash", strconv.Itoa(len(backends)-i-1)+"_0")) + numRPCs := 10 + reqPerBackend := checkRPCSendOK(ctx, t, client, numRPCs) + if reqPerBackend[backend] != numRPCs { + t.Errorf("Got RPC routed to addresses %v, want all RPCs routed to %v", reqPerBackend, backend) + } + } +} + +// Tests that when a request hash key is set in the balancer configuration via +// service config, this header is used to route to a specific backend. +func (s) TestRingHash_RequestHashKey(t *testing.T) { + testutils.SetEnvConfig(t, &envconfig.RingHashSetRequestHashKey, true) + + backends := backendAddrs(startTestServiceBackends(t, 4)) + + // Create a clientConn with a manual resolver (which is used to push the + // address of the test backend), and a default service config pointing to + // the use of the ring_hash_experimental LB policy with an explicit hash + // header. + const ringHashServiceConfig = `{"loadBalancingConfig": [{"ring_hash_experimental":{"requestHashHeader":"address_hash"}}]}` + r := manual.NewBuilderWithScheme("whatever") + dopts := []grpc.DialOption{ + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithResolvers(r), + grpc.WithDefaultServiceConfig(ringHashServiceConfig), + grpc.WithConnectParams(fastConnectParams), + } + cc, err := grpc.NewClient(r.Scheme()+":///test.server", dopts...) 
+ if err != nil { + t.Fatalf("Failed to dial local test server: %v", err) + } + defer cc.Close() + var endpoints []resolver.Endpoint + for _, backend := range backends { + endpoints = append(endpoints, resolver.Endpoint{ + Addresses: []resolver.Address{{Addr: backend}}, + }) + } + r.UpdateState(resolver.State{ + Endpoints: endpoints, + }) + client := testgrpc.NewTestServiceClient(cc) + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + // Note each type of RPC contains a header value that will always be hashed + // to a specific backend as the header value matches the value used to + // create the entry in the ring. + for _, backend := range backends { + ctx := metadata.NewOutgoingContext(ctx, metadata.Pairs("address_hash", backend+"_0")) + numRPCs := 10 + reqPerBackend := checkRPCSendOK(ctx, t, client, numRPCs) + if reqPerBackend[backend] != numRPCs { + t.Errorf("Got RPC routed to addresses %v, want all RPCs routed to %v", reqPerBackend, backend) + } + } + + const ringHashServiceConfigUpdate = `{"loadBalancingConfig": [{"ring_hash_experimental":{"requestHashHeader":"other_header"}}]}` + r.UpdateState(resolver.State{ + Endpoints: endpoints, + ServiceConfig: (&testutils.ResolverClientConn{}).ParseServiceConfig(ringHashServiceConfigUpdate), + }) + + // Make sure that requests with the new hash are sent to the right backend. + for _, backend := range backends { + ctx := metadata.NewOutgoingContext(ctx, metadata.Pairs("other_header", backend+"_0")) + numRPCs := 10 + reqPerBackend := checkRPCSendOK(ctx, t, client, numRPCs) + if reqPerBackend[backend] != numRPCs { + t.Errorf("Got RPC routed to addresses %v, want all RPCs routed to %v", reqPerBackend, backend) + } + } +} + +// Tests that when a request hash key is set in the balancer configuration via +// service config, and the header is not set in the outgoing request, then it +// is sent to a random backend. 
+func (s) TestRingHash_RequestHashKeyRandom(t *testing.T) { + testutils.SetEnvConfig(t, &envconfig.RingHashSetRequestHashKey, true) + + backends := backendAddrs(startTestServiceBackends(t, 4)) + + // Create a clientConn with a manual resolver (which is used to push the + // address of the test backend), and a default service config pointing to + // the use of the ring_hash_experimental LB policy with an explicit hash + // header. + const ringHashServiceConfig = `{"loadBalancingConfig": [{"ring_hash_experimental":{"requestHashHeader":"address_hash"}}]}` + r := manual.NewBuilderWithScheme("whatever") + dopts := []grpc.DialOption{ + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithResolvers(r), + grpc.WithDefaultServiceConfig(ringHashServiceConfig), + grpc.WithConnectParams(fastConnectParams), + } + cc, err := grpc.NewClient(r.Scheme()+":///test.server", dopts...) + if err != nil { + t.Fatalf("Failed to dial local test server: %v", err) + } + defer cc.Close() + var endpoints []resolver.Endpoint + for _, backend := range backends { + endpoints = append(endpoints, resolver.Endpoint{ + Addresses: []resolver.Address{{Addr: backend}}, + }) + } + r.UpdateState(resolver.State{ + Endpoints: endpoints, + }) + client := testgrpc.NewTestServiceClient(cc) + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + // Due to the way that ring hash lazily establishes connections when using a + // random hash, request distribution is skewed towards the order in which we + // connected. The test sends RPCs until we are connected to all backends, so + // we can later assert that the distribution is uniform. 
+ seen := make(map[string]bool) + for len(seen) != 4 { + var remote peer.Peer + if _, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.Peer(&remote)); err != nil { + t.Fatalf("rpc EmptyCall() failed: %v", err) + } + seen[remote.String()] = true + } + + // Make sure that requests without the hash header are sent to random backends. + numRPCs := computeIdealNumberOfRPCs(t, .25, errorTolerance) + gotPerBackend := checkRPCSendOK(ctx, t, client, numRPCs) + for _, backend := range backends { + got := float64(gotPerBackend[backend]) / float64(numRPCs) + want := .25 + if !cmp.Equal(got, want, cmpopts.EquateApprox(0, errorTolerance)) { + t.Errorf("Fraction of RPCs to backend %s: got %v, want %v (margin: +-%v)", backend, got, want, errorTolerance) + } + } +} + +// Tests that when a request hash key is set in the balancer configuration via +// service config, and the header is not set in the outgoing request (random +// behavior), then each RPC wakes up at most one SubChannel, and, if there are +// SubChannels in Ready state, RPCs are routed to them. +func (s) TestRingHash_RequestHashKeyConnecting(t *testing.T) { + testutils.SetEnvConfig(t, &envconfig.RingHashSetRequestHashKey, true) + + backends := backendAddrs(startTestServiceBackends(t, 20)) + + // Create a clientConn with a manual resolver (which is used to push the + // address of the test backend), and a default service config pointing to + // the use of the ring_hash_experimental LB policy with an explicit hash + // header. Use a blocking dialer to control connection attempts. 
+ const ringHashServiceConfig = `{"loadBalancingConfig": [ + {"ring_hash_experimental":{"requestHashHeader":"address_hash"}} + ]}` + r := manual.NewBuilderWithScheme("whatever") + blockingDialer := testutils.NewBlockingDialer() + dopts := []grpc.DialOption{ + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithResolvers(r), + grpc.WithDefaultServiceConfig(ringHashServiceConfig), + grpc.WithConnectParams(fastConnectParams), + grpc.WithContextDialer(blockingDialer.DialContext), + } + cc, err := grpc.NewClient(r.Scheme()+":///test.server", dopts...) + if err != nil { + t.Fatalf("Failed to dial local test server: %v", err) + } + defer cc.Close() + var endpoints []resolver.Endpoint + for _, backend := range backends { + endpoints = append(endpoints, resolver.Endpoint{ + Addresses: []resolver.Address{{Addr: backend}}, + }) + } + r.UpdateState(resolver.State{ + Endpoints: endpoints, + }) + client := testgrpc.NewTestServiceClient(cc) + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + // Intercept all connection attempts to the backends. + var holds []*testutils.Hold + for i := 0; i < len(backends); i++ { + holds = append(holds, blockingDialer.Hold(backends[i])) + } + + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + // Send 1 RPC and make sure this triggers at most 1 connection attempt. + _, err := client.EmptyCall(ctx, &testpb.Empty{}) + if err != nil { + t.Errorf("EmptyCall(): got %v, want success", err) + } + wg.Done() + }() + testutils.AwaitState(ctx, t, cc, connectivity.Connecting) + + // Check that only one connection attempt was started. + nConn := 0 + for _, hold := range holds { + if hold.IsStarted() { + nConn++ + } + } + if wantMaxConn := 1; nConn > wantMaxConn { + t.Fatalf("Got %d connection attempts, want at most %d", nConn, wantMaxConn) + } + + // Do a second RPC. Since there should already be a SubChannel in + // Connecting state, this should not trigger a connection attempt. 
+ wg.Add(1) + go func() { + _, err := client.EmptyCall(ctx, &testpb.Empty{}) + if err != nil { + t.Errorf("EmptyCall(): got %v, want success", err) + } + wg.Done() + }() + + // Give extra time for more connections to be attempted. + time.Sleep(defaultTestShortTimeout) + + var firstConnectedBackend string + nConn = 0 + for i, hold := range holds { + if hold.IsStarted() { + // Unblock the connection attempt. The SubChannel (and hence the + // channel) should transition to Ready. RPCs should succeed and + // be routed to this backend. + hold.Resume() + holds[i] = nil + firstConnectedBackend = backends[i] + nConn++ + } + } + if wantMaxConn := 1; nConn > wantMaxConn { + t.Fatalf("Got %d connection attempts, want at most %d", nConn, wantMaxConn) + } + testutils.AwaitState(ctx, t, cc, connectivity.Ready) + wg.Wait() // Make sure we're done with the 2 previous RPCs. + + // Now send RPCs until we have at least one more connection attempt, that + // is, the random hash did not land on the same backend on every pick (the + // chances are low, but we don't want this to be flaky). Make sure no RPC + // fails and that we route all of them to the only subchannel in ready + // state. 
+ nConn = 0 + for nConn == 0 { + p := peer.Peer{} + _, err = client.EmptyCall(ctx, &testpb.Empty{}, grpc.Peer(&p)) + if status.Code(err) == codes.DeadlineExceeded { + t.Fatal("EmptyCall(): test timed out while waiting for more connection attempts") + } + if err != nil { + t.Fatalf("EmptyCall(): got %v, want success", err) + } + if p.Addr.String() != firstConnectedBackend { + t.Errorf("RPC sent to backend %q, want %q", p.Addr.String(), firstConnectedBackend) + } + for _, hold := range holds { + if hold != nil && hold.IsStarted() { + nConn++ + } + } + } +} diff --git a/xds/internal/balancer/ringhash/picker.go b/xds/internal/balancer/ringhash/picker.go index fc6bf67558cd..1488d7806237 100644 --- a/xds/internal/balancer/ringhash/picker.go +++ b/xds/internal/balancer/ringhash/picker.go @@ -20,46 +20,103 @@ package ringhash import ( "fmt" + "strings" + xxhash "github.com/cespare/xxhash/v2" "google.golang.org/grpc/balancer" "google.golang.org/grpc/connectivity" - "google.golang.org/grpc/internal/grpclog" + "google.golang.org/grpc/metadata" ) type picker struct { - ring *ring - logger *grpclog.PrefixLogger - // endpointStates is a cache of endpoint connectivity states and pickers. + ring *ring + + // endpointStates is a cache of endpoint states. // The ringhash balancer stores endpoint states in a `resolver.EndpointMap`, // with access guarded by `ringhashBalancer.mu`. The `endpointStates` cache // in the picker helps avoid locking the ringhash balancer's mutex when // reading the latest state at RPC time. - endpointStates map[string]balancer.State // endpointState.firstAddr -> balancer.State + endpointStates map[string]endpointState // endpointState.hashKey -> endpointState + + // requestHashHeader is the header key to look for the request hash. If it's + // empty, the request hash is expected to be set in the context via xDS. + // See gRFC A76. + requestHashHeader string + + // hasEndpointInConnectingState is true if any of the endpoints is in + // CONNECTING. 
+ hasEndpointInConnectingState bool + + randUint64 func() uint64 } func (p *picker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { - e := p.ring.pick(getRequestHash(info.Ctx)) + usingRandomHash := false + var requestHash uint64 + if p.requestHashHeader == "" { + var ok bool + if requestHash, ok = XDSRequestHash(info.Ctx); !ok { + return balancer.PickResult{}, fmt.Errorf("ringhash: expected xDS config selector to set the request hash") + } + } else { + md, ok := metadata.FromOutgoingContext(info.Ctx) + if !ok || len(md.Get(p.requestHashHeader)) == 0 { + requestHash = p.randUint64() + usingRandomHash = true + } else { + values := strings.Join(md.Get(p.requestHashHeader), ",") + requestHash = xxhash.Sum64String(values) + } + } + + e := p.ring.pick(requestHash) ringSize := len(p.ring.items) - // Per gRFC A61, because of sticky-TF with PickFirst's auto reconnect on TF, - // we ignore all TF subchannels and find the first ring entry in READY, - // CONNECTING or IDLE. If that entry is in IDLE, we need to initiate a - // connection. The idlePicker returned by the LazyLB or the new Pickfirst - // should do this automatically. - for i := 0; i < ringSize; i++ { - index := (e.idx + i) % ringSize - balState := p.balancerState(p.ring.items[index]) - switch balState.ConnectivityState { - case connectivity.Ready, connectivity.Connecting, connectivity.Idle: - return balState.Picker.Pick(info) - case connectivity.TransientFailure: - default: - panic(fmt.Sprintf("Found child balancer in unknown state: %v", balState.ConnectivityState)) + if !usingRandomHash { + // Per gRFC A61, because of sticky-TF with PickFirst's auto reconnect on TF, + // we ignore all TF subchannels and find the first ring entry in READY, + // CONNECTING or IDLE. If that entry is in IDLE, we need to initiate a + // connection. The idlePicker returned by the LazyLB or the new Pickfirst + // should do this automatically. 
+ for i := 0; i < ringSize; i++ { + index := (e.idx + i) % ringSize + es := p.endpointState(p.ring.items[index]) + switch es.state.ConnectivityState { + case connectivity.Ready, connectivity.Connecting, connectivity.Idle: + return es.state.Picker.Pick(info) + case connectivity.TransientFailure: + default: + panic(fmt.Sprintf("Found child balancer in unknown state: %v", es.state.ConnectivityState)) + } + } + } else { + // If the picker has generated a random hash, it will walk the ring from + // this hash, and pick the first READY endpoint. If no endpoint is + // currently in CONNECTING state, it will trigger a connection attempt + // on at most one endpoint that is in IDLE state along the way. - A76 + requestedConnection := p.hasEndpointInConnectingState + for i := 0; i < ringSize; i++ { + index := (e.idx + i) % ringSize + es := p.endpointState(p.ring.items[index]) + if es.state.ConnectivityState == connectivity.Ready { + return es.state.Picker.Pick(info) + } + if !requestedConnection && es.state.ConnectivityState == connectivity.Idle { + requestedConnection = true + // If the SubChannel is in idle state, initiate a connection but + // continue to check other pickers to see if there is one in + // ready state. + es.balancer.ExitIdle() + } + } + if requestedConnection { + return balancer.PickResult{}, balancer.ErrNoSubConnAvailable } } + // All children are in transient failure. Return the first failure. 
- return p.balancerState(e).Picker.Pick(info) + return p.endpointState(e).state.Picker.Pick(info) } -func (p *picker) balancerState(e *ringEntry) balancer.State { - return p.endpointStates[e.firstAddr] +func (p *picker) endpointState(e *ringEntry) endpointState { + return p.endpointStates[e.hashKey] } diff --git a/xds/internal/balancer/ringhash/picker_test.go b/xds/internal/balancer/ringhash/picker_test.go index ff5b60589af9..2c079c19996e 100644 --- a/xds/internal/balancer/ringhash/picker_test.go +++ b/xds/internal/balancer/ringhash/picker_test.go @@ -20,18 +20,22 @@ package ringhash import ( "context" + "errors" "fmt" + "math" "testing" "time" "google.golang.org/grpc/balancer" "google.golang.org/grpc/connectivity" "google.golang.org/grpc/internal/testutils" - - internalgrpclog "google.golang.org/grpc/internal/grpclog" + "google.golang.org/grpc/metadata" ) -var testSubConns []*testutils.TestSubConn +var ( + testSubConns []*testutils.TestSubConn + errPicker = errors.New("picker in TransientFailure") +) func init() { for i := 0; i < 8; i++ { @@ -60,22 +64,35 @@ func (p *fakeChildPicker) Pick(balancer.PickInfo) (balancer.PickResult, error) { } } -func testRingAndEndpointStates(states []connectivity.State) (*ring, map[string]balancer.State) { +type fakeExitIdler struct { + sc *testutils.TestSubConn +} + +func (ei *fakeExitIdler) ExitIdle() { + ei.sc.Connect() +} + +func testRingAndEndpointStates(states []connectivity.State) (*ring, map[string]endpointState) { var items []*ringEntry - epStates := map[string]balancer.State{} + epStates := map[string]endpointState{} for i, st := range states { testSC := testSubConns[i] items = append(items, &ringEntry{ - idx: i, - hash: uint64((i + 1) * 10), - firstAddr: testSC.String(), + idx: i, + hash: math.MaxUint64 / uint64(len(states)) * uint64(i), + hashKey: testSC.String(), }) - epState := balancer.State{ - ConnectivityState: st, - Picker: &fakeChildPicker{ - connectivityState: st, - tfError: fmt.Errorf("%d", i), - subConn: 
testSC, + epState := endpointState{ + state: balancer.State{ + ConnectivityState: st, + Picker: &fakeChildPicker{ + connectivityState: st, + tfError: fmt.Errorf("%d: %w", i, errPicker), + subConn: testSC, + }, + }, + balancer: &fakeExitIdler{ + sc: testSC, }, } epStates[testSC.String()] = epState @@ -87,7 +104,6 @@ func (s) TestPickerPickFirstTwo(t *testing.T) { tests := []struct { name string connectivityStates []connectivity.State - hash uint64 wantSC balancer.SubConn wantErr error wantSCToConnect balancer.SubConn @@ -95,41 +111,40 @@ func (s) TestPickerPickFirstTwo(t *testing.T) { { name: "picked is Ready", connectivityStates: []connectivity.State{connectivity.Ready, connectivity.Idle}, - hash: 5, wantSC: testSubConns[0], }, { name: "picked is connecting, queue", connectivityStates: []connectivity.State{connectivity.Connecting, connectivity.Idle}, - hash: 5, wantErr: balancer.ErrNoSubConnAvailable, }, { name: "picked is Idle, connect and queue", connectivityStates: []connectivity.State{connectivity.Idle, connectivity.Idle}, - hash: 5, wantErr: balancer.ErrNoSubConnAvailable, wantSCToConnect: testSubConns[0], }, { name: "picked is TransientFailure, next is ready, return", connectivityStates: []connectivity.State{connectivity.TransientFailure, connectivity.Ready}, - hash: 5, wantSC: testSubConns[1], }, { name: "picked is TransientFailure, next is connecting, queue", connectivityStates: []connectivity.State{connectivity.TransientFailure, connectivity.Connecting}, - hash: 5, wantErr: balancer.ErrNoSubConnAvailable, }, { name: "picked is TransientFailure, next is Idle, connect and queue", connectivityStates: []connectivity.State{connectivity.TransientFailure, connectivity.Idle}, - hash: 5, wantErr: balancer.ErrNoSubConnAvailable, wantSCToConnect: testSubConns[1], }, + { + name: "all are in TransientFailure, return picked failure", + connectivityStates: []connectivity.State{connectivity.TransientFailure, connectivity.TransientFailure}, + wantErr: errPicker, + }, } 
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() @@ -138,13 +153,12 @@ func (s) TestPickerPickFirstTwo(t *testing.T) { ring, epStates := testRingAndEndpointStates(tt.connectivityStates) p := &picker{ ring: ring, - logger: internalgrpclog.NewPrefixLogger(logger, "test-ringhash-picker"), endpointStates: epStates, } got, err := p.Pick(balancer.PickInfo{ - Ctx: SetRequestHash(ctx, tt.hash), + Ctx: SetXDSRequestHash(ctx, 0), // always pick the first endpoint on the ring. }) - if err != tt.wantErr { + if (err != nil || tt.wantErr != nil) && !errors.Is(err, tt.wantErr) { t.Errorf("Pick() error = %v, wantErr %v", err, tt.wantErr) return } @@ -161,3 +175,136 @@ func (s) TestPickerPickFirstTwo(t *testing.T) { }) } } + +func (s) TestPickerNoRequestHash(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + ring, epStates := testRingAndEndpointStates([]connectivity.State{connectivity.Ready}) + p := &picker{ + ring: ring, + endpointStates: epStates, + } + if _, err := p.Pick(balancer.PickInfo{Ctx: ctx}); err == nil { + t.Errorf("Pick() should have failed with no request hash") + } +} + +func (s) TestPickerRequestHashKey(t *testing.T) { + tests := []struct { + name string + headerValues []string + expectedPick int + }{ + { + name: "header not set", + expectedPick: 0, // Random hash set to 0, which is within (MaxUint64 / 3 * 2, 0] + }, + { + name: "header empty", + headerValues: []string{""}, + expectedPick: 0, // xxhash.Sum64String("") is within (MaxUint64 / 3 * 2, 0] + }, + { + name: "header set to one value", + headerValues: []string{"some-value"}, + expectedPick: 1, // xxhash.Sum64String("some-value") is within (0, MaxUint64 / 3] + }, + { + name: "header set to multiple values", + headerValues: []string{"value1", "value2"}, + expectedPick: 2, // xxhash.Sum64String("value1,value2") is within (MaxUint64 / 3, MaxUint64 / 3 * 2] + }, + } + for _, tt := range tests 
{ + t.Run(tt.name, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + ring, epStates := testRingAndEndpointStates( + []connectivity.State{ + connectivity.Ready, + connectivity.Ready, + connectivity.Ready, + }) + headerName := "some-header" + p := &picker{ + ring: ring, + endpointStates: epStates, + requestHashHeader: headerName, + randUint64: func() uint64 { return 0 }, + } + for _, v := range tt.headerValues { + ctx = metadata.AppendToOutgoingContext(ctx, headerName, v) + } + if res, err := p.Pick(balancer.PickInfo{Ctx: ctx}); err != nil { + t.Errorf("Pick() failed: %v", err) + } else if res.SubConn != testSubConns[tt.expectedPick] { + t.Errorf("Pick() got = %v, want SubConn: %v", res.SubConn, testSubConns[tt.expectedPick]) + } + }) + } +} + +func (s) TestPickerRandomHash(t *testing.T) { + tests := []struct { + name string + hash uint64 + connectivityStates []connectivity.State + wantSC balancer.SubConn + wantErr error + wantSCToConnect balancer.SubConn + hasEndpointInConnectingState bool + }{ + { + name: "header not set, picked is Ready", + connectivityStates: []connectivity.State{connectivity.Ready, connectivity.Idle}, + wantSC: testSubConns[0], + }, + { + name: "header not set, picked is Idle, another is Ready. 
Connect and pick Ready", + connectivityStates: []connectivity.State{connectivity.Idle, connectivity.Ready}, + wantSC: testSubConns[1], + wantSCToConnect: testSubConns[0], + }, + { + name: "header not set, picked is Idle, there is at least one Connecting", + connectivityStates: []connectivity.State{connectivity.Connecting, connectivity.Idle}, + wantErr: balancer.ErrNoSubConnAvailable, + hasEndpointInConnectingState: true, + }, + { + name: "header not set, all Idle or TransientFailure, connect", + connectivityStates: []connectivity.State{connectivity.TransientFailure, connectivity.Idle}, + wantErr: balancer.ErrNoSubConnAvailable, + wantSCToConnect: testSubConns[1], + }, + } + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ring, epStates := testRingAndEndpointStates(tt.connectivityStates) + p := &picker{ + ring: ring, + endpointStates: epStates, + requestHashHeader: "some-header", + hasEndpointInConnectingState: tt.hasEndpointInConnectingState, + randUint64: func() uint64 { return 0 }, // always return the first endpoint on the ring. 
+ } + if got, err := p.Pick(balancer.PickInfo{Ctx: ctx}); err != tt.wantErr { + t.Errorf("Pick() error = %v, wantErr %v", err, tt.wantErr) + return + } else if got.SubConn != tt.wantSC { + t.Errorf("Pick() got = %v, want picked SubConn: %v", got, tt.wantSC) + } + if sc := tt.wantSCToConnect; sc != nil { + select { + case <-sc.(*testutils.TestSubConn).ConnectCh: + case <-time.After(defaultTestShortTimeout): + t.Errorf("timeout waiting for Connect() from SubConn %v", sc) + } + } + }) + } +} diff --git a/xds/internal/balancer/ringhash/ring.go b/xds/internal/balancer/ringhash/ring.go index 978facf14333..d2e003494a97 100644 --- a/xds/internal/balancer/ringhash/ring.go +++ b/xds/internal/balancer/ringhash/ring.go @@ -33,23 +33,23 @@ type ring struct { } type endpointInfo struct { - firstAddr string + hashKey string scaledWeight float64 originalWeight uint32 } type ringEntry struct { - idx int - hash uint64 - firstAddr string - weight uint32 + idx int + hash uint64 + hashKey string + weight uint32 } // newRing creates a ring from the endpoints stored in the EndpointMap. The ring // size is limited by the passed in max/min. // // ring entries will be created for each endpoint, and endpoints with high -// weight (specified by the address) may have multiple entries. +// weight (specified by the endpoint) may have multiple entries. // // For example, for endpoints with weights {a:3, b:3, c:4}, a generated ring of // size 10 could be: @@ -109,8 +109,8 @@ func newRing(endpoints *resolver.EndpointMap[*endpointState], minRingSize, maxRi // updates. 
idx := 0 for currentHashes < targetHashes { - h := xxhash.Sum64String(epInfo.firstAddr + "_" + strconv.Itoa(idx)) - items = append(items, &ringEntry{hash: h, firstAddr: epInfo.firstAddr, weight: epInfo.originalWeight}) + h := xxhash.Sum64String(epInfo.hashKey + "_" + strconv.Itoa(idx)) + items = append(items, &ringEntry{hash: h, hashKey: epInfo.hashKey, weight: epInfo.originalWeight}) idx++ currentHashes++ } @@ -153,7 +153,7 @@ func normalizeWeights(endpoints *resolver.EndpointMap[*endpointState]) ([]endpoi // non-zero. So, we need not worry about divide by zero error here. nw := float64(epState.weight) / float64(weightSum) ret = append(ret, endpointInfo{ - firstAddr: epState.firstAddr, + hashKey: epState.hashKey, scaledWeight: nw, originalWeight: epState.weight, }) @@ -166,7 +166,7 @@ func normalizeWeights(endpoints *resolver.EndpointMap[*endpointState]) ([]endpoi // where an endpoint is added and then removed, the RPCs will still pick the // same old endpoint. sort.Slice(ret, func(i, j int) bool { - return ret[i].firstAddr < ret[j].firstAddr + return ret[i].hashKey < ret[j].hashKey }) return ret, min } diff --git a/xds/internal/balancer/ringhash/ring_test.go b/xds/internal/balancer/ringhash/ring_test.go index 1d28bccc4bd8..c069a7d72731 100644 --- a/xds/internal/balancer/ringhash/ring_test.go +++ b/xds/internal/balancer/ringhash/ring_test.go @@ -39,9 +39,9 @@ func init() { testEndpoint("c", 4), } testEndpointStateMap = resolver.NewEndpointMap[*endpointState]() - testEndpointStateMap.Set(testEndpoints[0], &endpointState{firstAddr: "a", weight: 3}) - testEndpointStateMap.Set(testEndpoints[1], &endpointState{firstAddr: "b", weight: 3}) - testEndpointStateMap.Set(testEndpoints[2], &endpointState{firstAddr: "c", weight: 4}) + testEndpointStateMap.Set(testEndpoints[0], &endpointState{hashKey: "a", weight: 3}) + testEndpointStateMap.Set(testEndpoints[1], &endpointState{hashKey: "b", weight: 3}) + testEndpointStateMap.Set(testEndpoints[2], &endpointState{hashKey: "c", 
weight: 4}) } func testEndpoint(addr string, endpointWeight uint32) resolver.Endpoint { @@ -62,7 +62,7 @@ func (s) TestRingNew(t *testing.T) { for _, e := range testEndpoints { var count int for _, ii := range r.items { - if ii.firstAddr == e.Addresses[0].Addr { + if ii.hashKey == hashKey(e) { count++ } } diff --git a/xds/internal/balancer/ringhash/ringhash.go b/xds/internal/balancer/ringhash/ringhash.go index 26623378d4b9..87d26db3c98d 100644 --- a/xds/internal/balancer/ringhash/ringhash.go +++ b/xds/internal/balancer/ringhash/ringhash.go @@ -23,6 +23,7 @@ import ( "encoding/json" "errors" "fmt" + "math/rand/v2" "sort" "sync" @@ -36,6 +37,7 @@ import ( "google.golang.org/grpc/internal/grpclog" "google.golang.org/grpc/internal/pretty" "google.golang.org/grpc/resolver" + "google.golang.org/grpc/resolver/ringhash" "google.golang.org/grpc/serviceconfig" ) @@ -94,6 +96,18 @@ type ringhashBalancer struct { ring *ring } +// hashKey returns the hash key to use for an endpoint. Per gRFC A61, each entry +// in the ring is a hash of the endpoint's hash key concatenated with a +// per-entry unique suffix. +func hashKey(endpoint resolver.Endpoint) string { + if hk := ringhash.HashKey(endpoint); hk != "" { + return hk + } + // If no hash key is set, use the endpoint's first address as the hash key. + // This is the default behavior when no hash key is set. + return endpoint.Addresses[0].Addr +} + // UpdateState intercepts child balancer state updates. It updates the // per-endpoint state stored in the ring, and also the aggregated state based on // the child picker. It also reconciles the endpoint list. 
It sets @@ -114,31 +128,29 @@ func (b *ringhashBalancer) UpdateState(state balancer.State) { endpoint := childState.Endpoint endpointsSet.Set(endpoint, true) newWeight := getWeightAttribute(endpoint) + hk := hashKey(endpoint) es, ok := b.endpointStates.Get(endpoint) if !ok { - es = &endpointState{ - balancer: childState.Balancer, - weight: newWeight, - firstAddr: endpoint.Addresses[0].Addr, - state: childState.State, + es := &endpointState{ + balancer: childState.Balancer, + hashKey: hk, + weight: newWeight, + state: childState.State, } b.endpointStates.Set(endpoint, es) b.shouldRegenerateRing = true } else { // We have seen this endpoint before and created a `endpointState` - // object for it. If the weight or the first address of the endpoint - // has changed, update the endpoint state map with the new weight. - // This will be used when a new ring is created. + // object for it. If the weight or the hash key of the endpoint has + // changed, update the endpoint state map with the new weight or + // hash key. This will be used when a new ring is created. if oldWeight := es.weight; oldWeight != newWeight { b.shouldRegenerateRing = true es.weight = newWeight } - if es.firstAddr != endpoint.Addresses[0].Addr { - // If the order of the addresses for a given endpoint change, - // that will change the position of the endpoint in the ring. 
- // -A61 + if es.hashKey != hk { b.shouldRegenerateRing = true - es.firstAddr = endpoint.Addresses[0].Addr + es.hashKey = hk } es.state = childState.State } @@ -244,7 +256,7 @@ func (b *ringhashBalancer) updatePickerLocked() { endpointStates[i] = s } sort.Slice(endpointStates, func(i, j int) bool { - return endpointStates[i].firstAddr < endpointStates[j].firstAddr + return endpointStates[i].hashKey < endpointStates[j].hashKey }) var idleBalancer balancer.ExitIdler for _, es := range endpointStates { @@ -278,7 +290,6 @@ func (b *ringhashBalancer) updatePickerLocked() { } else { newPicker = b.newPickerLocked() } - b.logger.Infof("Pushing new state %v and picker %p", state, newPicker) b.ClientConn.UpdateState(balancer.State{ ConnectivityState: state, Picker: newPicker, @@ -299,11 +310,23 @@ func (b *ringhashBalancer) ExitIdle() { // over to avoid locking the mutex at RPC time. The picker should be // re-generated every time an endpoint state is updated. func (b *ringhashBalancer) newPickerLocked() *picker { - states := make(map[string]balancer.State) + states := make(map[string]endpointState) + hasEndpointConnecting := false for _, epState := range b.endpointStates.Values() { - states[epState.firstAddr] = epState.state + // Copy the endpoint state to avoid races, since ring hash + // mutates the state, weight and hash key in place. 
+ states[epState.hashKey] = *epState + if epState.state.ConnectivityState == connectivity.Connecting { + hasEndpointConnecting = true + } + } + return &picker{ + ring: b.ring, + endpointStates: states, + requestHashHeader: b.config.RequestHashHeader, + hasEndpointInConnectingState: hasEndpointConnecting, + randUint64: rand.Uint64, } - return &picker{ring: b.ring, logger: b.logger, endpointStates: states} } // aggregatedStateLocked returns the aggregated child balancers state @@ -346,8 +369,7 @@ func (b *ringhashBalancer) aggregatedStateLocked() connectivity.State { } // getWeightAttribute is a convenience function which returns the value of the -// weight attribute stored in the BalancerAttributes field of addr, using the -// weightedroundrobin package. +// weight endpoint Attribute. // // When used in the xDS context, the weight attribute is guaranteed to be // non-zero. But, when used in a non-xDS context, the weight attribute could be @@ -361,12 +383,13 @@ func getWeightAttribute(e resolver.Endpoint) uint32 { } type endpointState struct { - // firstAddr is the first address in the endpoint. Per gRFC A61, each entry - // in the ring is an endpoint, positioned based on the hash of the - // endpoint's first address. - firstAddr string - weight uint32 - balancer balancer.ExitIdler + // hashKey is the hash key of the endpoint. Per gRFC A61, each entry in the + // ring is an endpoint, positioned based on the hash of the endpoint's first + // address by default. Per gRFC A76, the hash key of an endpoint may be + // overridden, for example based on EDS endpoint metadata. + hashKey string + weight uint32 + balancer balancer.ExitIdler // state is updated by the balancer while receiving resolver updates from // the channel and picker updates from its children. 
Access to it is guarded diff --git a/xds/internal/balancer/ringhash/ringhash_test.go b/xds/internal/balancer/ringhash/ringhash_test.go index 5ee45018ca9e..8331c1b630d2 100644 --- a/xds/internal/balancer/ringhash/ringhash_test.go +++ b/xds/internal/balancer/ringhash/ringhash_test.go @@ -83,7 +83,7 @@ func setupTest(t *testing.T, endpoints []resolver.Endpoint) (*testutils.Balancer t.Errorf("Number of child balancers = %d, want = %d", got, want) } for firstAddr, bs := range ringHashPicker.endpointStates { - if got, want := bs.ConnectivityState, connectivity.Idle; got != want { + if got, want := bs.state.ConnectivityState, connectivity.Idle; got != want { t.Errorf("Child balancer connectivity state for address %q = %v, want = %v", firstAddr, got, want) } } @@ -144,7 +144,7 @@ func (s) TestOneEndpoint(t *testing.T) { // only Endpoint which has a single address. ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - if _, err := p0.Pick(balancer.PickInfo{Ctx: SetRequestHash(ctx, testHash)}); err != balancer.ErrNoSubConnAvailable { + if _, err := p0.Pick(balancer.PickInfo{Ctx: SetXDSRequestHash(ctx, testHash)}); err != balancer.ErrNoSubConnAvailable { t.Fatalf("first pick returned err %v, want %v", err, balancer.ErrNoSubConnAvailable) } var sc0 *testutils.TestSubConn @@ -172,7 +172,7 @@ func (s) TestOneEndpoint(t *testing.T) { // Test pick with one backend. p1 := <-cc.NewPickerCh for i := 0; i < 5; i++ { - gotSCSt, _ := p1.Pick(balancer.PickInfo{Ctx: SetRequestHash(ctx, testHash)}) + gotSCSt, _ := p1.Pick(balancer.PickInfo{Ctx: SetXDSRequestHash(ctx, testHash)}) if gotSCSt.SubConn != sc0 { t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc0) } @@ -205,7 +205,7 @@ func (s) TestThreeSubConnsAffinity(t *testing.T) { // SubConn. 
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - if _, err := p0.Pick(balancer.PickInfo{Ctx: SetRequestHash(ctx, testHash)}); err != balancer.ErrNoSubConnAvailable { + if _, err := p0.Pick(balancer.PickInfo{Ctx: SetXDSRequestHash(ctx, testHash)}); err != balancer.ErrNoSubConnAvailable { t.Fatalf("first pick returned err %v, want %v", err, balancer.ErrNoSubConnAvailable) } @@ -216,7 +216,7 @@ func (s) TestThreeSubConnsAffinity(t *testing.T) { t.Fatalf("Timed out waiting for SubConn creation.") case subConns[1] = <-cc.NewSubConnCh: } - if got, want := subConns[1].Addresses[0].Addr, ring.items[1].firstAddr; got != want { + if got, want := subConns[1].Addresses[0].Addr, ring.items[1].hashKey; got != want { t.Fatalf("SubConn.Address = %v, want = %v", got, want) } select { @@ -224,7 +224,7 @@ func (s) TestThreeSubConnsAffinity(t *testing.T) { case <-time.After(defaultTestTimeout): t.Errorf("timeout waiting for Connect() from SubConn %v", subConns[1]) } - delete(remainingAddrs, ring.items[1].firstAddr) + delete(remainingAddrs, ring.items[1].hashKey) // Turn down the subConn in use. 
subConns[1].UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Connecting}) @@ -248,9 +248,9 @@ func (s) TestThreeSubConnsAffinity(t *testing.T) { case <-time.After(defaultTestTimeout): t.Errorf("timeout waiting for Connect() from SubConn %v", subConns[1]) } - if scAddr == ring.items[0].firstAddr { + if scAddr == ring.items[0].hashKey { subConns[0] = sc - } else if scAddr == ring.items[2].firstAddr { + } else if scAddr == ring.items[2].hashKey { subConns[2] = sc } @@ -273,9 +273,9 @@ func (s) TestThreeSubConnsAffinity(t *testing.T) { case <-time.After(defaultTestTimeout): t.Errorf("timeout waiting for Connect() from SubConn %v", subConns[1]) } - if scAddr == ring.items[0].firstAddr { + if scAddr == ring.items[0].hashKey { subConns[0] = sc - } else if scAddr == ring.items[2].firstAddr { + } else if scAddr == ring.items[2].hashKey { subConns[2] = sc } sc.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Connecting}) @@ -292,7 +292,7 @@ func (s) TestThreeSubConnsAffinity(t *testing.T) { } p1 := <-cc.NewPickerCh for i := 0; i < 5; i++ { - gotSCSt, _ := p1.Pick(balancer.PickInfo{Ctx: SetRequestHash(ctx, testHash)}) + gotSCSt, _ := p1.Pick(balancer.PickInfo{Ctx: SetXDSRequestHash(ctx, testHash)}) if gotSCSt.SubConn != subConns[0] { t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, subConns[0]) } @@ -305,7 +305,7 @@ func (s) TestThreeSubConnsAffinity(t *testing.T) { subConns[2].UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Ready}) p2 := <-cc.NewPickerCh for i := 0; i < 5; i++ { - gotSCSt, _ := p2.Pick(balancer.PickInfo{Ctx: SetRequestHash(ctx, testHash)}) + gotSCSt, _ := p2.Pick(balancer.PickInfo{Ctx: SetXDSRequestHash(ctx, testHash)}) if gotSCSt.SubConn != subConns[2] { t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, subConns[2]) } @@ -318,7 +318,7 @@ func (s) TestThreeSubConnsAffinity(t *testing.T) { subConns[1].UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Ready}) p3 := 
<-cc.NewPickerCh for i := 0; i < 5; i++ { - gotSCSt, _ := p3.Pick(balancer.PickInfo{Ctx: SetRequestHash(ctx, testHash)}) + gotSCSt, _ := p3.Pick(balancer.PickInfo{Ctx: SetXDSRequestHash(ctx, testHash)}) if gotSCSt.SubConn != subConns[1] { t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, subConns[1]) } @@ -346,7 +346,7 @@ func (s) TestThreeBackendsAffinityMultiple(t *testing.T) { // SubConn. ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - if _, err := p0.Pick(balancer.PickInfo{Ctx: SetRequestHash(ctx, testHash)}); err != balancer.ErrNoSubConnAvailable { + if _, err := p0.Pick(balancer.PickInfo{Ctx: SetXDSRequestHash(ctx, testHash)}); err != balancer.ErrNoSubConnAvailable { t.Fatalf("first pick returned err %v, want %v", err, balancer.ErrNoSubConnAvailable) } // The picked SubConn should be the second in the ring. @@ -356,7 +356,7 @@ func (s) TestThreeBackendsAffinityMultiple(t *testing.T) { t.Fatalf("Timed out waiting for SubConn creation.") case sc0 = <-cc.NewSubConnCh: } - if got, want := sc0.Addresses[0].Addr, ring0.items[1].firstAddr; got != want { + if got, want := sc0.Addresses[0].Addr, ring0.items[1].hashKey; got != want { t.Fatalf("SubConn.Address = %v, want = %v", got, want) } select { @@ -375,7 +375,7 @@ func (s) TestThreeBackendsAffinityMultiple(t *testing.T) { // First hash should always pick sc0. p1 := <-cc.NewPickerCh for i := 0; i < 5; i++ { - gotSCSt, _ := p1.Pick(balancer.PickInfo{Ctx: SetRequestHash(ctx, testHash)}) + gotSCSt, _ := p1.Pick(balancer.PickInfo{Ctx: SetXDSRequestHash(ctx, testHash)}) if gotSCSt.SubConn != sc0 { t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc0) } @@ -384,7 +384,7 @@ func (s) TestThreeBackendsAffinityMultiple(t *testing.T) { secondHash := ring0.items[1].hash // secondHash+1 will pick the third SubConn from the ring. 
testHash2 := secondHash + 1 - if _, err := p0.Pick(balancer.PickInfo{Ctx: SetRequestHash(ctx, testHash2)}); err != balancer.ErrNoSubConnAvailable { + if _, err := p0.Pick(balancer.PickInfo{Ctx: SetXDSRequestHash(ctx, testHash2)}); err != balancer.ErrNoSubConnAvailable { t.Fatalf("first pick returned err %v, want %v", err, balancer.ErrNoSubConnAvailable) } var sc1 *testutils.TestSubConn @@ -393,7 +393,7 @@ func (s) TestThreeBackendsAffinityMultiple(t *testing.T) { t.Fatalf("Timed out waiting for SubConn creation.") case sc1 = <-cc.NewSubConnCh: } - if got, want := sc1.Addresses[0].Addr, ring0.items[2].firstAddr; got != want { + if got, want := sc1.Addresses[0].Addr, ring0.items[2].hashKey; got != want { t.Fatalf("SubConn.Address = %v, want = %v", got, want) } select { @@ -407,14 +407,14 @@ func (s) TestThreeBackendsAffinityMultiple(t *testing.T) { // With the new generated picker, hash2 always picks sc1. p2 := <-cc.NewPickerCh for i := 0; i < 5; i++ { - gotSCSt, _ := p2.Pick(balancer.PickInfo{Ctx: SetRequestHash(ctx, testHash2)}) + gotSCSt, _ := p2.Pick(balancer.PickInfo{Ctx: SetXDSRequestHash(ctx, testHash2)}) if gotSCSt.SubConn != sc1 { t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc1) } } // But the first hash still picks sc0. 
for i := 0; i < 5; i++ { - gotSCSt, _ := p2.Pick(balancer.PickInfo{Ctx: SetRequestHash(ctx, testHash)}) + gotSCSt, _ := p2.Pick(balancer.PickInfo{Ctx: SetXDSRequestHash(ctx, testHash)}) if gotSCSt.SubConn != sc0 { t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc0) } @@ -504,14 +504,14 @@ func (s) TestAddrWeightChange(t *testing.T) { t.Fatalf("new picker after changing address weight has %d entries, want 3", len(p3.(*picker).ring.items)) } for _, i := range p3.(*picker).ring.items { - if i.firstAddr == testBackendAddrStrs[0] { + if i.hashKey == testBackendAddrStrs[0] { if i.weight != 1 { - t.Fatalf("new picker after changing address weight has weight %d for %v, want 1", i.weight, i.firstAddr) + t.Fatalf("new picker after changing address weight has weight %d for %v, want 1", i.weight, i.hashKey) } } - if i.firstAddr == testBackendAddrStrs[1] { + if i.hashKey == testBackendAddrStrs[1] { if i.weight != 2 { - t.Fatalf("new picker after changing address weight has weight %d for %v, want 2", i.weight, i.firstAddr) + t.Fatalf("new picker after changing address weight has weight %d for %v, want 2", i.weight, i.hashKey) } } } @@ -532,6 +532,7 @@ func (s) TestAutoConnectEndpointOnTransientFailure(t *testing.T) { // ringhash won't tell SCs to connect until there is an RPC, so simulate // one now. ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + ctx = SetXDSRequestHash(ctx, 0) defer cancel() p0.Pick(balancer.PickInfo{Ctx: ctx}) @@ -690,7 +691,7 @@ func (s) TestAddrBalancerAttributesChange(t *testing.T) { // only Endpoint which has a single address. 
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - if _, err := p0.Pick(balancer.PickInfo{Ctx: SetRequestHash(ctx, firstHash)}); err != balancer.ErrNoSubConnAvailable { + if _, err := p0.Pick(balancer.PickInfo{Ctx: SetXDSRequestHash(ctx, firstHash)}); err != balancer.ErrNoSubConnAvailable { t.Fatalf("first pick returned err %v, want %v", err, balancer.ErrNoSubConnAvailable) } select { diff --git a/xds/internal/balancer/ringhash/util.go b/xds/internal/balancer/ringhash/util.go index 92bb3ae5b791..371c3c3e3558 100644 --- a/xds/internal/balancer/ringhash/util.go +++ b/xds/internal/balancer/ringhash/util.go @@ -18,23 +18,25 @@ package ringhash -import "context" +import ( + "context" +) -type clusterKey struct{} +type xdsHashKey struct{} -func getRequestHash(ctx context.Context) uint64 { - requestHash, _ := ctx.Value(clusterKey{}).(uint64) - return requestHash +// XDSRequestHash returns the request hash in the context and true if it was set +// from the xDS config selector. If the xDS config selector has not set the hash, +// it returns 0 and false. +func XDSRequestHash(ctx context.Context) (uint64, bool) { + requestHash := ctx.Value(xdsHashKey{}) + if requestHash == nil { + return 0, false + } + return requestHash.(uint64), true } -// GetRequestHashForTesting returns the request hash in the context; to be used -// for testing only. -func GetRequestHashForTesting(ctx context.Context) uint64 { - return getRequestHash(ctx) -} - -// SetRequestHash adds the request hash to the context for use in Ring Hash Load -// Balancing. -func SetRequestHash(ctx context.Context, requestHash uint64) context.Context { - return context.WithValue(ctx, clusterKey{}, requestHash) +// SetXDSRequestHash adds the request hash to the context for use in Ring Hash +// Load Balancing using xDS route hash_policy. 
+func SetXDSRequestHash(ctx context.Context, requestHash uint64) context.Context { + return context.WithValue(ctx, xdsHashKey{}, requestHash) } diff --git a/xds/internal/resolver/serviceconfig.go b/xds/internal/resolver/serviceconfig.go index 02e6a73eccc7..dbefe9801cfd 100644 --- a/xds/internal/resolver/serviceconfig.go +++ b/xds/internal/resolver/serviceconfig.go @@ -203,7 +203,7 @@ func (cs *configSelector) SelectConfig(rpcInfo iresolver.RPCInfo) (*iresolver.RP } lbCtx := clustermanager.SetPickedCluster(rpcInfo.Context, cluster.name) - lbCtx = ringhash.SetRequestHash(lbCtx, cs.generateHash(rpcInfo, rt.hashPolicies)) + lbCtx = ringhash.SetXDSRequestHash(lbCtx, cs.generateHash(rpcInfo, rt.hashPolicies)) config := &iresolver.RPCConfig{ // Communicate to the LB policy the chosen cluster and request hash, if Ring Hash LB policy. diff --git a/xds/internal/resolver/xds_resolver_test.go b/xds/internal/resolver/xds_resolver_test.go index 49ddf8140008..00107e58196e 100644 --- a/xds/internal/resolver/xds_resolver_test.go +++ b/xds/internal/resolver/xds_resolver_test.go @@ -495,8 +495,11 @@ func (s) TestResolverRequestHash(t *testing.T) { if err != nil { t.Fatalf("cs.SelectConfig(): %v", err) } - gotHash := ringhash.GetRequestHashForTesting(res.Context) wantHash := xxhash.Sum64String("/products") + gotHash, ok := ringhash.XDSRequestHash(res.Context) + if !ok { + t.Fatalf("Got no request hash, want: %v", wantHash) + } if gotHash != wantHash { t.Fatalf("Got request hash: %v, want: %v", gotHash, wantHash) } diff --git a/xds/internal/xdsclient/xdsresource/type_eds.go b/xds/internal/xdsclient/xdsresource/type_eds.go index f94a17e7c66a..a7eab2361d31 100644 --- a/xds/internal/xdsclient/xdsresource/type_eds.go +++ b/xds/internal/xdsclient/xdsresource/type_eds.go @@ -52,6 +52,7 @@ type Endpoint struct { Addresses []string HealthStatus EndpointHealthStatus Weight uint32 + HashKey string } // Locality contains information of a locality. 
diff --git a/xds/internal/xdsclient/xdsresource/unmarshal_eds.go b/xds/internal/xdsclient/xdsresource/unmarshal_eds.go index fd780d6632d2..26e16ce47abe 100644 --- a/xds/internal/xdsclient/xdsresource/unmarshal_eds.go +++ b/xds/internal/xdsclient/xdsresource/unmarshal_eds.go @@ -111,11 +111,31 @@ func parseEndpoints(lbEndpoints []*v3endpointpb.LbEndpoint, uniqueEndpointAddrs HealthStatus: EndpointHealthStatus(lbEndpoint.GetHealthStatus()), Addresses: addrs, Weight: weight, + HashKey: hashKey(lbEndpoint), }) } return endpoints, nil } +// hashKey returns the envoy.lb "hash_key" metadata string of the given +// LbEndpoint, or "" if absent, non-string, or backward compat mode is enabled. +func hashKey(lbEndpoint *v3endpointpb.LbEndpoint) string { + // "The xDS resolver, described in A74, will be changed to set the hash_key + // endpoint attribute to the value of LbEndpoint.Metadata envoy.lb hash_key + // field, as described in Envoy's documentation for the ring hash load + // balancer." - A76 + if envconfig.XDSEndpointHashKeyBackwardCompat { + return "" + } + envoyLB := lbEndpoint.GetMetadata().GetFilterMetadata()["envoy.lb"] + if envoyLB != nil { + if h := envoyLB.GetFields()["hash_key"]; h != nil { + return h.GetStringValue() + } + } + return "" +} + func parseEDSRespProto(m *v3endpointpb.ClusterLoadAssignment) (EndpointsUpdate, error) { ret := EndpointsUpdate{} for _, dropPolicy := range m.GetPolicy().GetDropOverloads() { diff --git a/xds/internal/xdsclient/xdsresource/unmarshal_eds_test.go b/xds/internal/xdsclient/xdsresource/unmarshal_eds_test.go index e8df0c3c3593..ae4f639d3ccc 100644 --- a/xds/internal/xdsclient/xdsresource/unmarshal_eds_test.go +++ b/xds/internal/xdsclient/xdsresource/unmarshal_eds_test.go @@ -34,7 +34,9 @@ import ( "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/xds/internal" "google.golang.org/grpc/xds/internal/xdsclient/xdsresource/version" + "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/anypb" +
"google.golang.org/protobuf/types/known/structpb" "google.golang.org/protobuf/types/known/wrapperspb" ) @@ -333,6 +335,135 @@ func (s) TestEDSParseRespProtoAdditionalAddrs(t *testing.T) { } } +func (s) TestUnmarshalEndpointHashKey(t *testing.T) { + baseCLA := &v3endpointpb.ClusterLoadAssignment{ + Endpoints: []*v3endpointpb.LocalityLbEndpoints{ + { + Locality: &v3corepb.Locality{Region: "r"}, + LbEndpoints: []*v3endpointpb.LbEndpoint{ + { + HostIdentifier: &v3endpointpb.LbEndpoint_Endpoint{ + Endpoint: &v3endpointpb.Endpoint{ + Address: &v3corepb.Address{ + Address: &v3corepb.Address_SocketAddress{ + SocketAddress: &v3corepb.SocketAddress{ + Address: "test-address", + PortSpecifier: &v3corepb.SocketAddress_PortValue{ + PortValue: 8080, + }, + }, + }, + }, + }, + }, + }, + }, + LoadBalancingWeight: &wrapperspb.UInt32Value{Value: 1}, + }, + }, + } + + tests := []struct { + name string + metadata *v3corepb.Metadata + wantHashKey string + compatEnvVar bool + }{ + { + name: "no metadata", + metadata: nil, + wantHashKey: "", + }, + { + name: "empty metadata", + metadata: &v3corepb.Metadata{}, + wantHashKey: "", + }, + { + name: "filter metadata without envoy.lb", + metadata: &v3corepb.Metadata{ + FilterMetadata: map[string]*structpb.Struct{ + "test-filter": {}, + }, + }, + wantHashKey: "", + }, + { + name: "nil envoy.lb", + metadata: &v3corepb.Metadata{ + FilterMetadata: map[string]*structpb.Struct{ + "envoy.lb": nil, + }, + }, + wantHashKey: "", + }, + { + name: "envoy.lb without hash key", + metadata: &v3corepb.Metadata{ + FilterMetadata: map[string]*structpb.Struct{ + "envoy.lb": { + Fields: map[string]*structpb.Value{ + "hash_key": { + Kind: &structpb.Value_NumberValue{NumberValue: 123.0}, + }, + }, + }, + }, + }, + wantHashKey: "", + }, + { + name: "envoy.lb with hash key, compat mode off", + metadata: &v3corepb.Metadata{ + FilterMetadata: map[string]*structpb.Struct{ + "envoy.lb": { + Fields: map[string]*structpb.Value{ + "hash_key": { + Kind: 
&structpb.Value_StringValue{StringValue: "test-hash-key"}, + }, + }, + }, + }, + }, + wantHashKey: "test-hash-key", + }, + { + name: "envoy.lb with hash key, compat mode on", + metadata: &v3corepb.Metadata{ + FilterMetadata: map[string]*structpb.Struct{ + "envoy.lb": { + Fields: map[string]*structpb.Value{ + "hash_key": { + Kind: &structpb.Value_StringValue{StringValue: "test-hash-key"}, + }, + }, + }, + }, + }, + wantHashKey: "", + compatEnvVar: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + testutils.SetEnvConfig(t, &envconfig.XDSEndpointHashKeyBackwardCompat, test.compatEnvVar) + + cla := proto.Clone(baseCLA).(*v3endpointpb.ClusterLoadAssignment) + cla.Endpoints[0].LbEndpoints[0].Metadata = test.metadata + marshalledCLA := testutils.MarshalAny(t, cla) + _, update, err := unmarshalEndpointsResource(marshalledCLA) + if err != nil { + t.Fatalf("unmarshalEndpointsResource() got error = %v, want success", err) + } + got := update.Localities[0].Endpoints[0].HashKey + if got != test.wantHashKey { + t.Errorf("unmarshalEndpointResource() endpoint hash key: got %s, want %s", got, test.wantHashKey) + } + }) + } +} + func (s) TestUnmarshalEndpoints(t *testing.T) { var v3EndpointsAny = testutils.MarshalAny(t, func() *v3endpointpb.ClusterLoadAssignment { clab0 := newClaBuilder("test", nil)