diff --git a/internal/envconfig/envconfig.go b/internal/envconfig/envconfig.go index 6414ee4bbe27..61d4e66e6220 100644 --- a/internal/envconfig/envconfig.go +++ b/internal/envconfig/envconfig.go @@ -82,6 +82,11 @@ var ( // This feature is defined in gRFC A81 and is enabled by setting the // environment variable GRPC_EXPERIMENTAL_XDS_AUTHORITY_REWRITE to "true". XDSAuthorityRewrite = boolFromEnv("GRPC_EXPERIMENTAL_XDS_AUTHORITY_REWRITE", false) + + // XDSLiteralAuthorityRewrite indicates whether xDS host_rewrite_literal field + // is honored. This feature is defined in gRFC A111 and is enabled by setting the + // environment variable GRPC_EXPERIMENTAL_XDS_LITERAL_AUTHORITY_REWRITE to "true". + XDSLiteralAuthorityRewrite = boolFromEnv("GRPC_EXPERIMENTAL_XDS_LITERAL_AUTHORITY_REWRITE", false) ) func boolFromEnv(envVar string, def bool) bool { diff --git a/internal/testutils/balancer.go b/internal/testutils/balancer.go index 369939b9187d..f7f96b4216ea 100644 --- a/internal/testutils/balancer.go +++ b/internal/testutils/balancer.go @@ -226,7 +226,7 @@ func (tcc *BalancerClientConn) WaitForErrPicker(ctx context.Context) error { case <-ctx.Done(): return errors.New("timeout when waiting for an error picker") case picker := <-tcc.NewPickerCh: - if _, perr := picker.Pick(balancer.PickInfo{}); perr == nil { + if _, perr := picker.Pick(balancer.PickInfo{Ctx: context.Background()}); perr == nil { return fmt.Errorf("balancer returned a picker which is not an error picker") } } @@ -244,7 +244,8 @@ func (tcc *BalancerClientConn) WaitForPickerWithErr(ctx context.Context, want er case <-ctx.Done(): return fmt.Errorf("timeout when waiting for an error picker with %v; last picker error: %v", want, lastErr) case picker := <-tcc.NewPickerCh: - if _, lastErr = picker.Pick(balancer.PickInfo{}); lastErr != nil && lastErr.Error() == want.Error() { + pi := balancer.PickInfo{Ctx: context.Background()} + if _, lastErr = picker.Pick(pi); lastErr != nil && lastErr.Error() == want.Error() { return nil } } @@ -292,7 +293,7 @@ func (tcc *BalancerClientConn) WaitForRoundRobinPicker(ctx context.Context, want } var pickerErr error if err := IsRoundRobin(want, func() balancer.SubConn { - sc, err := p.Pick(balancer.PickInfo{}) + sc, err := p.Pick(balancer.PickInfo{Ctx: context.Background()}) if err != nil { pickerErr = err } else if sc.Done != nil { @@ -390,7 +391,7 @@ func IsRoundRobin(want []balancer.SubConn, f func() balancer.SubConn) error { // Every invocation of the returned function results in a new pick. func SubConnFromPicker(p balancer.Picker) func() balancer.SubConn { return func() balancer.SubConn { - scst, _ := p.Pick(balancer.PickInfo{}) + scst, _ := p.Pick(balancer.PickInfo{Ctx: context.Background()}) return scst.SubConn } } diff --git a/internal/xds/balancer/clusterimpl/picker.go b/internal/xds/balancer/clusterimpl/picker.go index ddead1375d06..039079138e67 100644 --- a/internal/xds/balancer/clusterimpl/picker.go +++ b/internal/xds/balancer/clusterimpl/picker.go @@ -147,11 +147,18 @@ func (d *picker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { // be used. lID = scw.localityID - if scw.hostname != "" && autoHostRewriteEnabled(info.Ctx) { + authorityOverride := "" + if hostRewriteLiteral(info.Ctx) != "" { + authorityOverride = hostRewriteLiteral(info.Ctx) + } else if scw.hostname != "" && autoHostRewriteEnabled(info.Ctx) { + authorityOverride = scw.hostname + } + + if authorityOverride != "" { if pr.Metadata == nil { - pr.Metadata = metadata.Pairs(":authority", scw.hostname) + pr.Metadata = metadata.Pairs(":authority", authorityOverride) } else { - pr.Metadata.Set(":authority", scw.hostname) + pr.Metadata.Set(":authority", authorityOverride) } } } @@ -225,3 +232,19 @@ func AutoHostRewriteEnabledForTesting(ctx context.Context) bool { func EnableAutoHostRewrite(ctx context.Context) context.Context { return context.WithValue(ctx, autoHostRewriteKey{}, true) } + +// hostRewriteLiteralKey is the context key used to store the value of +// route's hostRewriteLiteral in the RPC context. +type hostRewriteLiteralKey struct{} + +// SetHostRewriteLiteral sets a hostRewriteLiteral value to the context for the +// xds_cluster_impl LB policy to pick. +func SetHostRewriteLiteral(ctx context.Context, hostRewriteLiteral string) context.Context { + return context.WithValue(ctx, hostRewriteLiteralKey{}, hostRewriteLiteral) +} + +// hostRewriteLiteral returns the value of hostRewriteLiteral set in the ctx. +func hostRewriteLiteral(ctx context.Context) string { + v, _ := ctx.Value(hostRewriteLiteralKey{}).(string) + return v +} diff --git a/internal/xds/balancer/clusterimpl/tests/balancer_test.go b/internal/xds/balancer/clusterimpl/tests/balancer_test.go index f0418aba3030..2a38cfd6dc2b 100644 --- a/internal/xds/balancer/clusterimpl/tests/balancer_test.go +++ b/internal/xds/balancer/clusterimpl/tests/balancer_test.go @@ -48,6 +48,7 @@ import ( "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/internal/testutils/xds/e2e" "google.golang.org/grpc/internal/testutils/xds/fakeserver" + "google.golang.org/grpc/internal/xds/bootstrap" "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" "google.golang.org/grpc/resolver" @@ -1468,3 +1469,150 @@ func (s) TestAuthorityOverridingWithTLS(t *testing.T) { }) } } + +// Tests that if a client receives its configuration via xDS and it has the host +// rewrite literal value configured, the authority pseudo-header in the RPC call +// is set appropriately. Also verifies that CallAuthority call option takes +// precedence. +// +// Per gRFC A111, the host rewrite literal feature should only be active if two +// conditions are met: +// 1. The environment variable (XDSAuthorityLiteralRewrite) is enabled. +// 2. The xDS server is marked as "trusted_xds_server" in the bootstrap config. +func (s) TestHostRewriteLiteral(t *testing.T) { + const ( + dialTargetName = "original.target.example.com" // Used in xds:/// URI and matched by Listener & VirtualHost + hostRewriteValue = "rewritten.host.example.com" + userAuthorityOverride = "user-override.com" + ) + + for _, tt := range []struct { + name string + xdsAuthorityRewriteEnv bool + trustedXdsServer bool + wantAuthority string + }{ + { + name: "EnvDisabled_NonTrustedServer", + xdsAuthorityRewriteEnv: false, + trustedXdsServer: false, + wantAuthority: dialTargetName, + }, + { + name: "EnvDisabled_TrustedServer", + xdsAuthorityRewriteEnv: false, + trustedXdsServer: true, + wantAuthority: dialTargetName, + }, + { + name: "EnvEnabled_NonTrustedServer", + xdsAuthorityRewriteEnv: true, + trustedXdsServer: false, + wantAuthority: dialTargetName, + }, + { + name: "EnvEnabled_TrustedServer", + xdsAuthorityRewriteEnv: true, + trustedXdsServer: true, + wantAuthority: hostRewriteValue, + }, + } { + t.Run(tt.name, func(t *testing.T) { + testutils.SetEnvConfig(t, &envconfig.XDSLiteralAuthorityRewrite, tt.xdsAuthorityRewriteEnv) + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + var gotAuthority string + stubServer := &stubserver.StubServer{ + EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) { + if md, ok := metadata.FromIncomingContext(ctx); ok { + if authVals := md.Get(":authority"); len(authVals) > 0 { + gotAuthority = authVals[0] + } + } + return &testpb.Empty{}, nil + }, + } + if err := stubServer.StartServer(); err != nil { + t.Fatalf("Failed to start stub server: %v", err) + } + defer stubServer.Stop() + t.Logf("Stub server listening at: %s", stubServer.Address) + + mgmtServer := e2e.StartManagementServer(t, e2e.ManagementServerOptions{AllowResourceSubset: true}) + defer mgmtServer.Stop() + + nodeID := uuid.New().String() + + // Build bootstrap configuration with or without trusted_xds_server + serverFeatures := "[]" + if tt.trustedXdsServer { + serverFeatures = `["trusted_xds_server"]` + } + bootstrapContents, err := bootstrap.NewContentsForTesting(bootstrap.ConfigOptionsForTesting{ + Servers: []byte(fmt.Sprintf(`[{ + "server_uri": "passthrough:///%s", + "channel_creds": [{"type": "insecure"}], + "server_features": %s + }]`, mgmtServer.Address, serverFeatures)), + Node: []byte(fmt.Sprintf(`{"id": "%s"}`, nodeID)), + }) + if err != nil { + t.Fatalf("Failed to create bootstrap configuration: %v", err) + } + + resolverBuilder, err := internal.NewXDSResolverWithConfigForTesting.(func([]byte) (resolver.Builder, error))(bootstrapContents) + if err != nil { + t.Fatalf("Failed to create xDS resolver for testing: %v", err) + } + + resources := e2e.DefaultClientResources(e2e.ResourceParams{ + DialTarget: dialTargetName, + NodeID: nodeID, + Host: "localhost", + Port: testutils.ParsePort(t, stubServer.Address), + }) + + // Modify the route to enable HostRewriteLiteral. + resources.Routes[0].VirtualHosts[0].Routes[0].GetRoute().HostRewriteSpecifier = &v3routepb.RouteAction_HostRewriteLiteral{ + HostRewriteLiteral: hostRewriteValue, + } + + if err := mgmtServer.Update(ctx, resources); err != nil { + t.Fatalf("xDS server update failed: %v", err) + } + + cc, err := grpc.NewClient(fmt.Sprintf("xds:///%s", dialTargetName), + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithResolvers(resolverBuilder), + ) + if err != nil { + t.Fatalf("Failed to create gRPC client: %v", err) + } + defer cc.Close() + + client := testpb.NewTestServiceClient(cc) + + // Test RPC without CallAuthority override + _, err = client.EmptyCall(ctx, &testpb.Empty{}) + if err != nil { + t.Fatalf("EmptyCall failed: %v", err) + } + + if gotAuthority != tt.wantAuthority { + t.Errorf("Server received :authority header %q, want %q", gotAuthority, tt.wantAuthority) + } + + // Test RPC with CallAuthority override - should always use the override + _, err = client.EmptyCall(ctx, &testpb.Empty{}, grpc.CallAuthority(userAuthorityOverride)) + if err != nil { + t.Fatalf("EmptyCall with CallAuthority failed: %v", err) + } + + if gotAuthority != userAuthorityOverride { + t.Errorf("Server received :authority header %q, want %q", gotAuthority, userAuthorityOverride) + } + }) + } +} diff --git a/internal/xds/resolver/serviceconfig.go b/internal/xds/resolver/serviceconfig.go index 40a423f1f1e2..7d367115af9e 100644 --- a/internal/xds/resolver/serviceconfig.go +++ b/internal/xds/resolver/serviceconfig.go @@ -109,13 +109,14 @@ type routeCluster struct { } type route struct { - m *xdsresource.CompositeMatcher // converted from route matchers - actionType xdsresource.RouteActionType // holds route action type - clusters wrr.WRR // holds *routeCluster entries - maxStreamDuration time.Duration - retryConfig *xdsresource.RetryConfig - hashPolicies []*xdsresource.HashPolicy - autoHostRewrite bool + m *xdsresource.CompositeMatcher // converted from route matchers + actionType xdsresource.RouteActionType // holds route action type + clusters wrr.WRR // holds *routeCluster entries + maxStreamDuration time.Duration + retryConfig *xdsresource.RetryConfig + hashPolicies []*xdsresource.HashPolicy + autoHostRewrite bool + hostRewriteLiteral string } func (r route) String() string { @@ -202,6 +203,9 @@ func (cs *configSelector) SelectConfig(rpcInfo iresolver.RPCInfo) (*iresolver.RP if rt.autoHostRewrite { lbCtx = clusterimpl.EnableAutoHostRewrite(lbCtx) } + if rt.hostRewriteLiteral != "" { + lbCtx = clusterimpl.SetHostRewriteLiteral(lbCtx, rt.hostRewriteLiteral) + } config := &iresolver.RPCConfig{ // Communicate to the LB policy the chosen cluster and request hash, if Ring Hash LB policy. diff --git a/internal/xds/resolver/xds_resolver.go b/internal/xds/resolver/xds_resolver.go index a38b292b006d..45cc1bb1035c 100644 --- a/internal/xds/resolver/xds_resolver.go +++ b/internal/xds/resolver/xds_resolver.go @@ -397,6 +397,7 @@ func (r *xdsResolver) newConfigSelector() (*configSelector, error) { cs.routes[i].retryConfig = rt.RetryConfig cs.routes[i].hashPolicies = rt.HashPolicies cs.routes[i].autoHostRewrite = rt.AutoHostRewrite + cs.routes[i].hostRewriteLiteral = rt.HostRewriteLiteral } // Account for this config selector's clusters. Do this after no further diff --git a/internal/xds/xdsclient/xdsresource/type_rds.go b/internal/xds/xdsclient/xdsresource/type_rds.go index 48e8051b3222..9da7d04567d3 100644 --- a/internal/xds/xdsclient/xdsresource/type_rds.go +++ b/internal/xds/xdsclient/xdsresource/type_rds.go @@ -150,9 +150,14 @@ type Route struct { // ClusterSpecifierPlugin is the name of the Cluster Specifier Plugin that // this Route is linked to, if specified by xDS. ClusterSpecifierPlugin string + // AutoHostRewrite indicates that the ":authority" header can be rewritten // to the hostname of the upstream endpoint. AutoHostRewrite bool + + // HostRewriteLiteral contains the Host override that is set by the Route + // Action host_rewrite_literal. + HostRewriteLiteral string } // WeightedCluster contains settings for an xds ActionType.WeightedCluster. diff --git a/internal/xds/xdsclient/xdsresource/unmarshal_rds.go b/internal/xds/xdsclient/xdsresource/unmarshal_rds.go index d988b4e77f9a..60eb5f799985 100644 --- a/internal/xds/xdsclient/xdsresource/unmarshal_rds.go +++ b/internal/xds/xdsclient/xdsresource/unmarshal_rds.go @@ -304,10 +304,13 @@ func routesProtoToSlice(routes []*v3routepb.Route, csps map[string]clusterspecif case *v3routepb.Route_Route: action := r.GetRoute() - if envconfig.XDSAuthorityRewrite { - if opts != nil && opts.ServerConfig != nil && opts.ServerConfig.SupportsServerFeature(xdsclient.ServerFeatureTrustedXDSServer) { + if opts != nil && opts.ServerConfig != nil && opts.ServerConfig.SupportsServerFeature(xdsclient.ServerFeatureTrustedXDSServer) { + if envconfig.XDSAuthorityRewrite { route.AutoHostRewrite = action.GetAutoHostRewrite().GetValue() } + if envconfig.XDSLiteralAuthorityRewrite { + route.HostRewriteLiteral = action.GetHostRewriteLiteral() + } } // Hash Policies are only applicable for a Ring Hash LB.