From b667ac0b87c759e5fd1ac17aad5be3cf777c1ac9 Mon Sep 17 00:00:00 2001 From: Deeven Seru Date: Tue, 21 Apr 2026 16:24:19 +0000 Subject: [PATCH] fix(tool/bigquery): handle omitted optional parameters with typed NULLs --- .../tools/bigquery/bigquerycommon/util.go | 4 +- .../tools/bigquery/bigquerysql/bigquerysql.go | 202 ++++++++++------ .../bigquerysql/bigquerysql_invoke_test.go | 225 ++++++++++++++++++ 3 files changed, 352 insertions(+), 79 deletions(-) create mode 100644 internal/tools/bigquery/bigquerysql/bigquerysql_invoke_test.go diff --git a/internal/tools/bigquery/bigquerycommon/util.go b/internal/tools/bigquery/bigquerycommon/util.go index 7e3a3baa772a..2e84a941a7bb 100644 --- a/internal/tools/bigquery/bigquerycommon/util.go +++ b/internal/tools/bigquery/bigquerycommon/util.go @@ -67,8 +67,10 @@ func BQTypeStringFromToolType(toolType string) (string, error) { return "INT64", nil case "float": return "FLOAT64", nil - case "boolean": + case parameters.TypeBool: return "BOOL", nil + case parameters.TypeMap: + return "STRUCT", nil default: return "", fmt.Errorf("unsupported tool parameter type for BigQuery: %s", toolType) } diff --git a/internal/tools/bigquery/bigquerysql/bigquerysql.go b/internal/tools/bigquery/bigquerysql/bigquerysql.go index add7a351da37..7c8300010d7a 100644 --- a/internal/tools/bigquery/bigquerysql/bigquerysql.go +++ b/internal/tools/bigquery/bigquerysql/bigquerysql.go @@ -19,8 +19,8 @@ import ( "fmt" "net/http" "reflect" + "regexp" "strconv" - "strings" bigqueryapi "cloud.google.com/go/bigquery" yaml "github.com/goccy/go-yaml" @@ -117,43 +117,115 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } - highLevelParams := make([]bigqueryapi.QueryParameter, 0, len(t.Parameters)) - lowLevelParams := make([]*bigqueryrestapi.QueryParameter, 0, len(t.Parameters)) - paramsMap := params.AsMap() 
newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { return nil, util.NewAgentError("unable to extract template params", err) } - for _, p := range t.Parameters { + highLevelParams, lowLevelParams, err := buildQueryParameters(t.Parameters, paramsMap, newStatement) + if err != nil { + return nil, util.NewAgentError("unable to build query parameters", err) + } + + connProps := []*bigqueryapi.ConnectionProperty{} + if source.BigQuerySession() != nil { + session, err := source.BigQuerySession()(ctx) + if err != nil { + return nil, util.NewClientServerError("failed to get BigQuery session", http.StatusInternalServerError, err) + } + if session != nil { + // Add session ID to the connection properties for subsequent calls. + connProps = append(connProps, &bigqueryapi.ConnectionProperty{Key: "session_id", Value: session.ID}) + } + } + + bqClient, restService, err := source.RetrieveClientAndService(accessToken) + if err != nil { + return nil, util.NewClientServerError("failed to retrieve BigQuery client", http.StatusInternalServerError, err) + } + + dryRunJob, err := bqutil.DryRunQuery(ctx, restService, bqClient.Project(), bqClient.Location, newStatement, lowLevelParams, connProps, source.GetMaximumBytesBilled()) + if err != nil { + return nil, util.ProcessGcpError(err) + } + + statementType := dryRunJob.Statistics.Query.StatementType + resp, err := source.RunSQL(ctx, bqClient, newStatement, statementType, highLevelParams, connProps) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil +} + +func buildQueryParameters(paramsMetadata parameters.Parameters, paramsMap map[string]any, statement string) ([]bigqueryapi.QueryParameter, []*bigqueryrestapi.QueryParameter, error) { + highLevelParams := make([]bigqueryapi.QueryParameter, 0, len(paramsMetadata)) + lowLevelParams := make([]*bigqueryrestapi.QueryParameter, 0, len(paramsMetadata)) + + for _, p := range paramsMetadata { name := 
p.GetName() value := paramsMap[name] - // This block for converting []any to typed slices is still necessary and correct. - if arrayParam, ok := p.(*parameters.ArrayParameter); ok { - arrayParamValue, ok := value.([]any) - if !ok { - return nil, util.NewAgentError(fmt.Sprintf("unable to convert parameter `%s` to []any", name), nil) - } - itemType := arrayParam.GetItems().GetType() - var err error - value, err = parameters.ConvertAnySliceToTyped(arrayParamValue, itemType) - if err != nil { - return nil, util.NewAgentError(fmt.Sprintf("unable to convert parameter `%s` from []any to typed slice", name), err) + // Handle array types: convert []any to typed slices if necessary. + if arrayParam, ok := p.(*parameters.ArrayParameter); ok && value != nil { + if arrayParamValue, ok := value.([]any); ok { + itemType := arrayParam.GetItems().GetType() + var err error + value, err = parameters.ConvertAnySliceToTyped(arrayParamValue, itemType) + if err != nil { + return nil, nil, fmt.Errorf("unable to convert parameter `%s` from []any to typed slice: %w", name, err) + } } } // Determine if the parameter is named or positional for the high-level client. var paramNameForHighLevel string - if strings.Contains(newStatement, "@"+name) { + isNamed, _ := regexp.MatchString("@"+name+"\\b", statement) + if isNamed { paramNameForHighLevel = name } + // Handle nil values for optional parameters by providing typed NULLs. + // BigQuery high-level client requires objects like NullString for NULLs. + // BigQuery low-level REST client requires setting the Null fields. 
+ finalValue := value + isNull := value == nil + + if isNull { + switch p.GetType() { + case parameters.TypeString: + finalValue = bigqueryapi.NullString{Valid: false} + case parameters.TypeInt: + finalValue = bigqueryapi.NullInt64{Valid: false} + case parameters.TypeFloat: + finalValue = bigqueryapi.NullFloat64{Valid: false} + case parameters.TypeBool: + finalValue = bigqueryapi.NullBool{Valid: false} + case parameters.TypeArray: + // For arrays, provide a typed nil slice based on items type. + if arrayParam, ok := p.(*parameters.ArrayParameter); ok { + switch arrayParam.GetItems().GetType() { + case parameters.TypeString: + finalValue = []string(nil) + case parameters.TypeInt: + finalValue = []int64(nil) + case parameters.TypeFloat: + finalValue = []float64(nil) + case parameters.TypeBool: + finalValue = []bool(nil) + default: + finalValue = []any(nil) + } + } + case parameters.TypeMap: + finalValue = map[string]any(nil) + } + } + // 1. Create the high-level parameter for the final query execution. highLevelParams = append(highLevelParams, bigqueryapi.QueryParameter{ Name: paramNameForHighLevel, - Value: value, + Value: finalValue, }) // 2. Create the low-level parameter for the dry run. @@ -163,80 +235,54 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para ParameterValue: &bigqueryrestapi.QueryParameterValue{}, } - rv := reflect.ValueOf(value) - if rv.Kind() == reflect.Slice && rv.Type().Elem().Kind() != reflect.Uint8 { - lowLevelParam.ParameterType.Type = "ARRAY" + if isNull { + lowLevelParam.ParameterValue.NullFields = []string{"Value"} + } - // Default item type to FLOAT64 for embeddings, or use config if available. 
- itemType := "FLOAT64" - if arrayParam, ok := p.(*parameters.ArrayParameter); ok { - if bqType, err := bqutil.BQTypeStringFromToolType(arrayParam.GetItems().GetType()); err == nil { - itemType = bqType - } + if arrayParam, ok := p.(*parameters.ArrayParameter); ok { + lowLevelParam.ParameterType.Type = "ARRAY" + itemType, err := bqutil.BQTypeStringFromToolType(arrayParam.GetItems().GetType()) + if err != nil { + return nil, nil, fmt.Errorf("unable to get BigQuery type for parameter %q: %w", name, err) } lowLevelParam.ParameterType.ArrayType = &bigqueryrestapi.QueryParameterType{Type: itemType} - // Build the array values. - arrayValues := make([]*bigqueryrestapi.QueryParameterValue, rv.Len()) - for i := 0; i < rv.Len(); i++ { - val := rv.Index(i).Interface() - - // Prevent precision loss and scientific notation issues - var valStr string - switch v := val.(type) { - case float64: - valStr = strconv.FormatFloat(v, 'f', -1, 64) - case float32: - valStr = strconv.FormatFloat(float64(v), 'f', -1, 32) - default: - valStr = fmt.Sprintf("%v", val) - } - - arrayValues[i] = &bigqueryrestapi.QueryParameterValue{ - Value: valStr, + if !isNull { + sliceVal := reflect.ValueOf(value) + arrayValues := make([]*bigqueryrestapi.QueryParameterValue, sliceVal.Len()) + for i := 0; i < sliceVal.Len(); i++ { + val := sliceVal.Index(i).Interface() + + // Prevent precision loss and scientific notation issues + var valStr string + switch v := val.(type) { + case float64: + valStr = strconv.FormatFloat(v, 'f', -1, 64) + case float32: + valStr = strconv.FormatFloat(float64(v), 'f', -1, 32) + default: + valStr = fmt.Sprintf("%v", val) + } + + arrayValues[i] = &bigqueryrestapi.QueryParameterValue{ + Value: valStr, + } } + lowLevelParam.ParameterValue.ArrayValues = arrayValues } - lowLevelParam.ParameterValue.ArrayValues = arrayValues } else { - // Handle scalar types based on their defined type. 
bqType, err := bqutil.BQTypeStringFromToolType(p.GetType()) if err != nil { - return nil, util.NewAgentError("unable to get BigQuery type from tool parameter type", err) + return nil, nil, fmt.Errorf("unable to get BigQuery type for parameter %q: %w", name, err) } lowLevelParam.ParameterType.Type = bqType - lowLevelParam.ParameterValue.Value = fmt.Sprintf("%v", value) + if !isNull { + lowLevelParam.ParameterValue.Value = fmt.Sprintf("%v", value) + } } lowLevelParams = append(lowLevelParams, lowLevelParam) } - - connProps := []*bigqueryapi.ConnectionProperty{} - if source.BigQuerySession() != nil { - session, err := source.BigQuerySession()(ctx) - if err != nil { - return nil, util.NewClientServerError("failed to get BigQuery session", http.StatusInternalServerError, err) - } - if session != nil { - // Add session ID to the connection properties for subsequent calls. - connProps = append(connProps, &bigqueryapi.ConnectionProperty{Key: "session_id", Value: session.ID}) - } - } - - bqClient, restService, err := source.RetrieveClientAndService(accessToken) - if err != nil { - return nil, util.NewClientServerError("failed to retrieve BigQuery client", http.StatusInternalServerError, err) - } - - dryRunJob, err := bqutil.DryRunQuery(ctx, restService, bqClient.Project(), bqClient.Location, newStatement, lowLevelParams, connProps, source.GetMaximumBytesBilled()) - if err != nil { - return nil, util.ProcessGcpError(err) - } - - statementType := dryRunJob.Statistics.Query.StatementType - resp, err := source.RunSQL(ctx, bqClient, newStatement, statementType, highLevelParams, connProps) - if err != nil { - return nil, util.ProcessGcpError(err) - } - return resp, nil + return highLevelParams, lowLevelParams, nil } func formatVectorForBigQuery(vectorFloats []float32) any { diff --git a/internal/tools/bigquery/bigquerysql/bigquerysql_invoke_test.go b/internal/tools/bigquery/bigquerysql/bigquerysql_invoke_test.go new file mode 100644 index 000000000000..6169588e2321 --- /dev/null 
+++ b/internal/tools/bigquery/bigquerysql/bigquerysql_invoke_test.go @@ -0,0 +1,225 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bigquerysql + +import ( + "reflect" + "testing" + + bigqueryapi "cloud.google.com/go/bigquery" + "github.com/google/go-cmp/cmp" + "github.com/googleapis/mcp-toolbox/internal/util/parameters" +) + +func TestBuildQueryParameters(t *testing.T) { + required := false + paramsMetadata := parameters.Parameters{ + &parameters.StringParameter{ + CommonParameter: parameters.CommonParameter{ + Name: "opt_string", + Type: parameters.TypeString, + Required: &required, + }, + }, + &parameters.IntParameter{ + CommonParameter: parameters.CommonParameter{ + Name: "opt_int", + Type: parameters.TypeInt, + Required: &required, + }, + }, + &parameters.FloatParameter{ + CommonParameter: parameters.CommonParameter{ + Name: "opt_float", + Type: parameters.TypeFloat, + Required: &required, + }, + }, + &parameters.BooleanParameter{ + CommonParameter: parameters.CommonParameter{ + Name: "opt_bool", + Type: parameters.TypeBool, + Required: &required, + }, + }, + &parameters.ArrayParameter{ + CommonParameter: parameters.CommonParameter{ + Name: "opt_array", + Type: parameters.TypeArray, + Required: &required, + }, + Items: parameters.NewStringParameter("item", ""), + }, + } + + paramsMap := map[string]any{ + // All are omitted + } + statement := "SELECT @opt_string, @opt_int, @opt_float, @opt_bool, @opt_array" + + gotHigh, gotLow, 
err := buildQueryParameters(paramsMetadata, paramsMap, statement) + if err != nil { + t.Fatalf("buildQueryParameters failed: %v", err) + } + + wantHigh := []bigqueryapi.QueryParameter{ + {Name: "opt_string", Value: bigqueryapi.NullString{Valid: false}}, + {Name: "opt_int", Value: bigqueryapi.NullInt64{Valid: false}}, + {Name: "opt_float", Value: bigqueryapi.NullFloat64{Valid: false}}, + {Name: "opt_bool", Value: bigqueryapi.NullBool{Valid: false}}, + {Name: "opt_array", Value: []string(nil)}, + } + + if diff := cmp.Diff(wantHigh, gotHigh); diff != "" { + t.Errorf("High-level parameters mismatch (-want +got):\n%s", diff) + } + + // For low-level, we check the NullFields slice + for i, p := range gotLow { + foundNull := false + for _, field := range p.ParameterValue.NullFields { + if field == "Value" { + foundNull = true + break + } + } + if !foundNull { + t.Errorf("Low-level parameter %d (%s) NullFields does not contain 'Value', want true", i, p.Name) + } + } + + // Verify one non-null case + paramsMapFull := map[string]any{ + "opt_string": "hello", + } + gotHighFull, gotLowFull, _ := buildQueryParameters(paramsMetadata, paramsMapFull, statement) + + if gotHighFull[0].Value != "hello" { + t.Errorf("Expected string value 'hello', got %v", gotHighFull[0].Value) + } + if len(gotLowFull[0].ParameterValue.NullFields) > 0 { + t.Error("Expected low-level NullFields to be empty for non-null value") + } + if gotLowFull[0].ParameterValue.Value != "hello" { + t.Errorf("Expected low-level string value 'hello', got %s", gotLowFull[0].ParameterValue.Value) + } +} + +func TestBuildQueryParameters_Types(t *testing.T) { + // Mixed cases + required := false + paramsMetadata := parameters.Parameters{ + &parameters.StringParameter{CommonParameter: parameters.CommonParameter{Name: "s", Type: "string", Required: &required}}, + &parameters.IntParameter{CommonParameter: parameters.CommonParameter{Name: "i", Type: "integer", Required: &required}}, + } + paramsMap := map[string]any{ + "s": "val", + 
// i is omitted + } + statement := "SELECT @s, @i" + + gotHigh, gotLow, _ := buildQueryParameters(paramsMetadata, paramsMap, statement) + + expectedHigh := []bigqueryapi.QueryParameter{ + {Name: "s", Value: "val"}, + {Name: "i", Value: bigqueryapi.NullInt64{Valid: false}}, + } + + if diff := cmp.Diff(expectedHigh, gotHigh, cmp.AllowUnexported(bigqueryapi.NullInt64{})); diff != "" { + t.Errorf("High-level parameters mismatch (-want +got):\n%s", diff) + } + + if len(gotLow[0].ParameterValue.NullFields) > 0 { + t.Error("Expected low-level NullFields to be empty for 's'") + } + foundNull := false + for _, field := range gotLow[1].ParameterValue.NullFields { + if field == "Value" { + foundNull = true + break + } + } + if !foundNull { + t.Error("Expected low-level NullFields to contain 'Value' for 'i'") + } +} + +func TestBuildQueryParameters_EdgeCases(t *testing.T) { + required := false + paramsMetadata := parameters.Parameters{ + &parameters.StringParameter{ + CommonParameter: parameters.CommonParameter{ + Name: "user", + Type: parameters.TypeString, + Required: &required, + }, + }, + &parameters.StringParameter{ + CommonParameter: parameters.CommonParameter{ + Name: "user_id", + Type: parameters.TypeString, + Required: &required, + }, + }, + &parameters.MapParameter{ + CommonParameter: parameters.CommonParameter{ + Name: "opt_map", + Type: parameters.TypeMap, + Required: &required, + }, + }, + } + + paramsMap := map[string]any{ + "user_id": "123", + // user is omitted, and opt_map is omitted + } + // "user" should NOT be identified as named because it's only a prefix of "user_id". + statement := "SELECT @user_id, @opt_map" + + gotHigh, gotLow, err := buildQueryParameters(paramsMetadata, paramsMap, statement) + if err != nil { + t.Fatalf("buildQueryParameters failed: %v", err) + } + + // 1. 
Check named parameter isolation + // gotHigh[0] is "user" + if gotHigh[0].Name != "" { + t.Errorf("Expected 'user' to be positional (empty name), got %q", gotHigh[0].Name) + } + // gotHigh[1] is "user_id" + if gotHigh[1].Name != "user_id" { + t.Errorf("Expected 'user_id' to be named, got %q", gotHigh[1].Name) + } + + // 2. Check TypeMap NULL handling + // gotHigh[2] is "opt_map" + if gotHigh[2].Value == nil || !reflect.ValueOf(gotHigh[2].Value).IsNil() { + t.Errorf("Expected 'opt_map' Value to be a nil map, got %v", gotHigh[2].Value) + } + if gotLow[2].ParameterType.Type != "STRUCT" { + t.Errorf("Expected low-level 'opt_map' type to be STRUCT, got %q", gotLow[2].ParameterType.Type) + } + foundNull := false + for _, field := range gotLow[2].ParameterValue.NullFields { + if field == "Value" { + foundNull = true + break + } + } + if !foundNull { + t.Error("Expected low-level 'opt_map' NullFields to contain 'Value'") + } +}