Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion internal/tools/bigquery/bigquerycommon/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,10 @@ func BQTypeStringFromToolType(toolType string) (string, error) {
return "INT64", nil
case "float":
return "FLOAT64", nil
case "boolean":
case parameters.TypeBool:
return "BOOL", nil
case parameters.TypeMap:
return "STRUCT", nil
default:
return "", fmt.Errorf("unsupported tool parameter type for BigQuery: %s", toolType)
}
Expand Down
202 changes: 124 additions & 78 deletions internal/tools/bigquery/bigquerysql/bigquerysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ import (
"fmt"
"net/http"
"reflect"
"regexp"
"strconv"
"strings"

bigqueryapi "cloud.google.com/go/bigquery"
yaml "github.com/goccy/go-yaml"
Expand Down Expand Up @@ -117,43 +117,115 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
}

highLevelParams := make([]bigqueryapi.QueryParameter, 0, len(t.Parameters))
lowLevelParams := make([]*bigqueryrestapi.QueryParameter, 0, len(t.Parameters))

paramsMap := params.AsMap()
newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
if err != nil {
return nil, util.NewAgentError("unable to extract template params", err)
}

for _, p := range t.Parameters {
highLevelParams, lowLevelParams, err := buildQueryParameters(t.Parameters, paramsMap, newStatement)
if err != nil {
return nil, util.NewAgentError("unable to build query parameters", err)
}

connProps := []*bigqueryapi.ConnectionProperty{}
if source.BigQuerySession() != nil {
session, err := source.BigQuerySession()(ctx)
if err != nil {
return nil, util.NewClientServerError("failed to get BigQuery session", http.StatusInternalServerError, err)
}
if session != nil {
// Add session ID to the connection properties for subsequent calls.
connProps = append(connProps, &bigqueryapi.ConnectionProperty{Key: "session_id", Value: session.ID})
}
}

bqClient, restService, err := source.RetrieveClientAndService(accessToken)
if err != nil {
return nil, util.NewClientServerError("failed to retrieve BigQuery client", http.StatusInternalServerError, err)
}

dryRunJob, err := bqutil.DryRunQuery(ctx, restService, bqClient.Project(), bqClient.Location, newStatement, lowLevelParams, connProps, source.GetMaximumBytesBilled())
if err != nil {
return nil, util.ProcessGcpError(err)
}

statementType := dryRunJob.Statistics.Query.StatementType
resp, err := source.RunSQL(ctx, bqClient, newStatement, statementType, highLevelParams, connProps)
if err != nil {
return nil, util.ProcessGcpError(err)
}
return resp, nil
}

func buildQueryParameters(paramsMetadata parameters.Parameters, paramsMap map[string]any, statement string) ([]bigqueryapi.QueryParameter, []*bigqueryrestapi.QueryParameter, error) {
highLevelParams := make([]bigqueryapi.QueryParameter, 0, len(paramsMetadata))
lowLevelParams := make([]*bigqueryrestapi.QueryParameter, 0, len(paramsMetadata))

for _, p := range paramsMetadata {
name := p.GetName()
value := paramsMap[name]

// This block for converting []any to typed slices is still necessary and correct.
if arrayParam, ok := p.(*parameters.ArrayParameter); ok {
arrayParamValue, ok := value.([]any)
if !ok {
return nil, util.NewAgentError(fmt.Sprintf("unable to convert parameter `%s` to []any", name), nil)
}
itemType := arrayParam.GetItems().GetType()
var err error
value, err = parameters.ConvertAnySliceToTyped(arrayParamValue, itemType)
if err != nil {
return nil, util.NewAgentError(fmt.Sprintf("unable to convert parameter `%s` from []any to typed slice", name), err)
// Handle array types: convert []any to typed slices if necessary.
if arrayParam, ok := p.(*parameters.ArrayParameter); ok && value != nil {
if arrayParamValue, ok := value.([]any); ok {
itemType := arrayParam.GetItems().GetType()
var err error
value, err = parameters.ConvertAnySliceToTyped(arrayParamValue, itemType)
if err != nil {
return nil, nil, fmt.Errorf("unable to convert parameter `%s` from []any to typed slice: %w", name, err)
}
}
}

// Determine if the parameter is named or positional for the high-level client.
var paramNameForHighLevel string
if strings.Contains(newStatement, "@"+name) {
isNamed, _ := regexp.MatchString("@"+name+"\\b", statement)
if isNamed {
paramNameForHighLevel = name
}

// Handle nil values for optional parameters by providing typed NULLs.
// BigQuery high-level client requires objects like NullString for NULLs.
// BigQuery low-level REST client requires setting the Null fields.
finalValue := value
isNull := value == nil

if isNull {
switch p.GetType() {
case parameters.TypeString:
finalValue = bigqueryapi.NullString{Valid: false}
case parameters.TypeInt:
finalValue = bigqueryapi.NullInt64{Valid: false}
case parameters.TypeFloat:
finalValue = bigqueryapi.NullFloat64{Valid: false}
case parameters.TypeBool:
finalValue = bigqueryapi.NullBool{Valid: false}
case parameters.TypeArray:
// For arrays, provide a typed nil slice based on items type.
if arrayParam, ok := p.(*parameters.ArrayParameter); ok {
switch arrayParam.GetItems().GetType() {
case parameters.TypeString:
finalValue = []string(nil)
case parameters.TypeInt:
finalValue = []int64(nil)
case parameters.TypeFloat:
finalValue = []float64(nil)
case parameters.TypeBool:
finalValue = []bool(nil)
default:
finalValue = []any(nil)
}
}
case parameters.TypeMap:
finalValue = map[string]any(nil)
}
Comment thread
Deeven-Seru marked this conversation as resolved.
}

// 1. Create the high-level parameter for the final query execution.
highLevelParams = append(highLevelParams, bigqueryapi.QueryParameter{
Name: paramNameForHighLevel,
Value: value,
Value: finalValue,
})

// 2. Create the low-level parameter for the dry run.
Expand All @@ -163,80 +235,54 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
ParameterValue: &bigqueryrestapi.QueryParameterValue{},
}

rv := reflect.ValueOf(value)
if rv.Kind() == reflect.Slice && rv.Type().Elem().Kind() != reflect.Uint8 {
lowLevelParam.ParameterType.Type = "ARRAY"
if isNull {
lowLevelParam.ParameterValue.NullFields = []string{"Value"}
}

// Default item type to FLOAT64 for embeddings, or use config if available.
itemType := "FLOAT64"
if arrayParam, ok := p.(*parameters.ArrayParameter); ok {
if bqType, err := bqutil.BQTypeStringFromToolType(arrayParam.GetItems().GetType()); err == nil {
itemType = bqType
}
if arrayParam, ok := p.(*parameters.ArrayParameter); ok {
lowLevelParam.ParameterType.Type = "ARRAY"
itemType, err := bqutil.BQTypeStringFromToolType(arrayParam.GetItems().GetType())
if err != nil {
return nil, nil, fmt.Errorf("unable to get BigQuery type for parameter %q: %w", name, err)
}
lowLevelParam.ParameterType.ArrayType = &bigqueryrestapi.QueryParameterType{Type: itemType}

// Build the array values.
arrayValues := make([]*bigqueryrestapi.QueryParameterValue, rv.Len())
for i := 0; i < rv.Len(); i++ {
val := rv.Index(i).Interface()

// Prevent precision loss and scientific notation issues
var valStr string
switch v := val.(type) {
case float64:
valStr = strconv.FormatFloat(v, 'f', -1, 64)
case float32:
valStr = strconv.FormatFloat(float64(v), 'f', -1, 32)
default:
valStr = fmt.Sprintf("%v", val)
}

arrayValues[i] = &bigqueryrestapi.QueryParameterValue{
Value: valStr,
if !isNull {
sliceVal := reflect.ValueOf(value)
arrayValues := make([]*bigqueryrestapi.QueryParameterValue, sliceVal.Len())
for i := 0; i < sliceVal.Len(); i++ {
val := sliceVal.Index(i).Interface()

// Prevent precision loss and scientific notation issues
var valStr string
switch v := val.(type) {
case float64:
valStr = strconv.FormatFloat(v, 'f', -1, 64)
case float32:
valStr = strconv.FormatFloat(float64(v), 'f', -1, 32)
default:
valStr = fmt.Sprintf("%v", val)
}

arrayValues[i] = &bigqueryrestapi.QueryParameterValue{
Value: valStr,
}
}
lowLevelParam.ParameterValue.ArrayValues = arrayValues
}
lowLevelParam.ParameterValue.ArrayValues = arrayValues
} else {
// Handle scalar types based on their defined type.
bqType, err := bqutil.BQTypeStringFromToolType(p.GetType())
if err != nil {
return nil, util.NewAgentError("unable to get BigQuery type from tool parameter type", err)
return nil, nil, fmt.Errorf("unable to get BigQuery type for parameter %q: %w", name, err)
}
lowLevelParam.ParameterType.Type = bqType
lowLevelParam.ParameterValue.Value = fmt.Sprintf("%v", value)
if !isNull {
lowLevelParam.ParameterValue.Value = fmt.Sprintf("%v", value)
}
}
lowLevelParams = append(lowLevelParams, lowLevelParam)
}

connProps := []*bigqueryapi.ConnectionProperty{}
if source.BigQuerySession() != nil {
session, err := source.BigQuerySession()(ctx)
if err != nil {
return nil, util.NewClientServerError("failed to get BigQuery session", http.StatusInternalServerError, err)
}
if session != nil {
// Add session ID to the connection properties for subsequent calls.
connProps = append(connProps, &bigqueryapi.ConnectionProperty{Key: "session_id", Value: session.ID})
}
}

bqClient, restService, err := source.RetrieveClientAndService(accessToken)
if err != nil {
return nil, util.NewClientServerError("failed to retrieve BigQuery client", http.StatusInternalServerError, err)
}

dryRunJob, err := bqutil.DryRunQuery(ctx, restService, bqClient.Project(), bqClient.Location, newStatement, lowLevelParams, connProps, source.GetMaximumBytesBilled())
if err != nil {
return nil, util.ProcessGcpError(err)
}

statementType := dryRunJob.Statistics.Query.StatementType
resp, err := source.RunSQL(ctx, bqClient, newStatement, statementType, highLevelParams, connProps)
if err != nil {
return nil, util.ProcessGcpError(err)
}
return resp, nil
return highLevelParams, lowLevelParams, nil
}

func formatVectorForBigQuery(vectorFloats []float32) any {
Expand Down
Loading