diff --git a/cmd/internal/config_test.go b/cmd/internal/config_test.go index 9f9ac7438ccc..8797373220c8 100644 --- a/cmd/internal/config_test.go +++ b/cmd/internal/config_test.go @@ -1638,7 +1638,7 @@ func TestPrebuiltTools(t *testing.T) { wantToolset: server.ToolsetConfigs{ "cloud_sql_postgres_admin_tools": tools.ToolsetConfig{ Name: "cloud_sql_postgres_admin_tools", - ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "postgres_upgrade_precheck", "clone_instance", "create_backup", "restore_backup"}, + ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "postgres_upgrade_precheck", "clone_instance", "create_backup", "restore_backup", "execute_sql_many"}, }, }, }, diff --git a/cmd/internal/imports.go b/cmd/internal/imports.go index 698b74b567fd..dd440178040e 100644 --- a/cmd/internal/imports.go +++ b/cmd/internal/imports.go @@ -198,6 +198,7 @@ import ( _ "github.com/googleapis/mcp-toolbox/internal/tools/oracle/oraclesql" _ "github.com/googleapis/mcp-toolbox/internal/tools/postgres/postgresdatabaseoverview" _ "github.com/googleapis/mcp-toolbox/internal/tools/postgres/postgresexecutesql" + _ "github.com/googleapis/mcp-toolbox/internal/tools/postgres/postgresexecutesqlmany" _ "github.com/googleapis/mcp-toolbox/internal/tools/postgres/postgresgetcolumncardinality" _ "github.com/googleapis/mcp-toolbox/internal/tools/postgres/postgreslistactivequeries" _ "github.com/googleapis/mcp-toolbox/internal/tools/postgres/postgreslistavailableextensions" @@ -220,6 +221,7 @@ import ( _ "github.com/googleapis/mcp-toolbox/internal/tools/postgres/postgreslongrunningtransactions" _ "github.com/googleapis/mcp-toolbox/internal/tools/postgres/postgresreplicationstats" _ "github.com/googleapis/mcp-toolbox/internal/tools/postgres/postgressql" + _ "github.com/googleapis/mcp-toolbox/internal/tools/postgres/postgressqlmany" _ "github.com/googleapis/mcp-toolbox/internal/tools/redis" _ "github.com/googleapis/mcp-toolbox/internal/tools/serverlessspark/serverlesssparkcancelbatch" _ "github.com/googleapis/mcp-toolbox/internal/tools/serverlessspark/serverlesssparkcreatepysparkbatch" diff --git a/docs/en/integrations/postgres/tools/postgres-execute-sql-many.md b/docs/en/integrations/postgres/tools/postgres-execute-sql-many.md new file mode 100644 index 000000000000..5a20906387bb --- /dev/null +++ b/docs/en/integrations/postgres/tools/postgres-execute-sql-many.md @@ -0,0 +1,48 @@ +--- +title: "postgres-execute-sql-many" +type: docs +weight: 1 +description: > + A "postgres-execute-sql-many" tool executes a SQL statement against a specific Cloud SQL Postgres instance provided at runtime. +--- + +## About + +A `postgres-execute-sql-many` tool executes a SQL statement against a specific Cloud SQL Postgres instance identified by project, instance, and database parameters provided at runtime. + +This tool is useful for executing arbitrary SQL queries across multiple database instances without needing to configure a separate tool for each instance. + +> **Note:** This tool is intended for developer assistant workflows with human-in-the-loop and shouldn't be used for production agents. + +## Compatible Sources + +{{< compatible-sources others="integrations/cloud-sql-admin" >}} + +## Parameters + +The following parameters are required at runtime when invoking the tool: + +| **Parameter** | **Type** | **Description** | +| :------------ | :------- | :---------------------------- | +| `project` | string | The GCP project ID. | +| `instance` | string | The Cloud SQL instance ID. | +| `database` | string | The database name. | +| `sql` | string | The SQL statement to execute. | + +## Example + +```yaml +kind: tool +name: execute_sql_many_tool +type: postgres-execute-sql-many +source: my-cloud-sql-admin-source +description: Use this tool to execute sql statement on a specific instance. +``` + +## Reference + +| **field** | **type** | **required** | **description** | +| :---------- | :------- | :----------- | :------------------------------------------------- | +| type | string | true | Must be "postgres-execute-sql-many". | +| source | string | true | Name of the `cloud-sql-admin` source. | +| description | string | true | Description of the tool that is passed to the LLM. | diff --git a/docs/en/integrations/postgres/tools/postgres-sql-many.md b/docs/en/integrations/postgres/tools/postgres-sql-many.md new file mode 100644 index 000000000000..a4e1be933636 --- /dev/null +++ b/docs/en/integrations/postgres/tools/postgres-sql-many.md @@ -0,0 +1,56 @@ +--- +title: "postgres-sql-many" +type: docs +weight: 1 +description: > + A "postgres-sql-many" tool executes a predefined SQL statement against a specific Cloud SQL Postgres instance provided at runtime. +--- + +## About + +A `postgres-sql-many` tool executes a predefined SQL statement against a specific Cloud SQL Postgres instance identified by project, instance, and database parameters provided at runtime. + +It supports `templateParameters` to allow dynamic values to be injected into the query at runtime. + +> **Note:** This tool is intended for developer assistant workflows with human-in-the-loop and shouldn't be used for production agents. + +## Compatible Sources + +{{< compatible-sources others="integrations/cloud-sql-admin" >}} + +## Parameters + +The following parameters are required at runtime when invoking the tool: + +| **Parameter** | **Type** | **Description** | +| :------------ | :------- | :------------------------- | +| `project` | string | The GCP project ID. | +| `instance` | string | The Cloud SQL instance ID. | +| `database` | string | The database name. | + +Additional parameters may be required based on the `templateParameters` configured in the tool definition. + +## Example + +```yaml +kind: tool +name: get_user_many_tool +type: postgres-sql-many +source: my-cloud-sql-admin-source +description: Use this tool to get user details from a specific instance. +statement: SELECT * FROM users WHERE id = {{.user_id}} +templateParameters: + - name: user_id + type: string + description: The ID of the user. +``` + +## Reference + +| **field** | **type** | **required** | **description** | +| :----------------- | :------- | :----------- | :------------------------------------------------- | +| type | string | true | Must be "postgres-sql-many". | +| source | string | true | Name of the `cloud-sql-admin` source. | +| description | string | true | Description of the tool that is passed to the LLM. | +| statement | string | true | The SQL statement template to execute. | +| templateParameters | list | false | List of parameters used in the statement template. | diff --git a/internal/prebuiltconfigs/tools/cloud-sql-postgres-admin.yaml b/internal/prebuiltconfigs/tools/cloud-sql-postgres-admin.yaml index c64cb803da7d..338ff5fcab32 100644 --- a/internal/prebuiltconfigs/tools/cloud-sql-postgres-admin.yaml +++ b/internal/prebuiltconfigs/tools/cloud-sql-postgres-admin.yaml @@ -52,6 +52,9 @@ tools: restore_backup: kind: cloud-sql-restore-backup source: cloud-sql-admin-source + execute_sql_many: + kind: postgres-execute-sql-many + source: cloud-sql-admin-source toolsets: cloud_sql_postgres_admin_tools: @@ -66,3 +69,4 @@ toolsets: - clone_instance - create_backup - restore_backup + - execute_sql_many diff --git a/internal/sources/cloudsqladmin/cloud_sql_admin.go b/internal/sources/cloudsqladmin/cloud_sql_admin.go index 08bff1af38a2..68ff035706c7 100644 --- a/internal/sources/cloudsqladmin/cloud_sql_admin.go +++ b/internal/sources/cloudsqladmin/cloud_sql_admin.go @@ -126,6 +126,8 @@ func (s *Source) GetDefaultProject() string { return s.DefaultProject } +// GetService returns a new Cloud SQL Admin service for the given access token. + func (s *Source) GetService(ctx context.Context, accessToken string) (*sqladmin.Service, error) { if s.UseClientOAuth { token := &oauth2.Token{AccessToken: accessToken} @@ -294,6 +296,24 @@ func (s *Source) ListInstance(ctx context.Context, project, accessToken string) return instances, nil } +func (s *Source) ExecuteSql(ctx context.Context, project, instance, database, sql string, accessToken string) (any, error) { + service, err := s.GetService(ctx, accessToken) + if err != nil { + return nil, err + } + + req := &sqladmin.ExecuteSqlPayload{ + Database: database, + SqlStatement: sql, + } + + resp, err := service.Instances.ExecuteSql(project, instance, req).Do() + if err != nil { + return nil, fmt.Errorf("error executing sql: %w", err) + } + return resp, nil +} + func (s *Source) CreateInstance(ctx context.Context, project, name, dbVersion, rootPassword string, settings sqladmin.Settings, accessToken string) (any, error) { instance := sqladmin.DatabaseInstance{ Name: name, diff --git a/internal/tools/postgres/postgresexecutesqlmany/postgresexecutesqlmany.go b/internal/tools/postgres/postgresexecutesqlmany/postgresexecutesqlmany.go new file mode 100644 index 000000000000..5d72cf80c1d2 --- /dev/null +++ b/internal/tools/postgres/postgresexecutesqlmany/postgresexecutesqlmany.go @@ -0,0 +1,164 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgresexecutesqlmany + +import ( + "context" + "fmt" + "net/http" + + yaml "github.com/goccy/go-yaml" + "github.com/googleapis/mcp-toolbox/internal/embeddingmodels" + "github.com/googleapis/mcp-toolbox/internal/sources" + "github.com/googleapis/mcp-toolbox/internal/tools" + "github.com/googleapis/mcp-toolbox/internal/util" + "github.com/googleapis/mcp-toolbox/internal/util/parameters" +) + +const resourceType string = "postgres-execute-sql-many" + +func init() { + if !tools.Register(resourceType, newConfig) { + panic(fmt.Sprintf("tool type %q already registered", resourceType)) + } +} + +func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) { + actual := Config{Name: name} + if err := decoder.DecodeContext(ctx, &actual); err != nil { + return nil, err + } + return actual, nil +} + +type compatibleSource interface { + ExecuteSql(ctx context.Context, project, instance, database, sql string, accessToken string) (any, error) + UseClientAuthorization() bool +} + +type Config struct { + Name string `yaml:"name" validate:"required"` + Type string `yaml:"type" validate:"required"` + Source string `yaml:"source" validate:"required"` + Description string `yaml:"description"` + AuthRequired []string `yaml:"authRequired"` + Annotations *tools.ToolAnnotations `yaml:"annotations,omitempty"` +} + +var _ tools.ToolConfig = Config{} + +func (cfg Config) ToolConfigType() string { + return resourceType +} + +// Initialize creates a new Postgres ExecuteSqlMany tool. +func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { + params := parameters.Parameters{ + parameters.NewStringParameter("project", "The GCP project ID."), + parameters.NewStringParameter("instance", "The Cloud SQL instance ID."), + parameters.NewStringParameter("database", "The database name."), + parameters.NewStringParameter("sql", "The SQL statement to execute."), + } + + description := cfg.Description + if description == "" { + description = "Executes multiple SQL statements on a Postgres database." + } + + annotations := tools.GetAnnotationsOrDefault(cfg.Annotations, tools.NewDestructiveAnnotations) + mcpManifest := tools.GetMcpManifest(cfg.Name, description, cfg.AuthRequired, params, annotations) + + t := Tool{ + Config: cfg, + Parameters: params, + manifest: tools.Manifest{Description: description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, + } + return t, nil +} + +var _ tools.Tool = Tool{} + +type Tool struct { + Config + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest +} + +// Invoke executes the SQL statement on the given database. +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { + // Check source compatibility + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) + if err != nil { + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) + } + + // Extract parameters from the parameter values map. + paramsMap := params.AsMap() + project, _ := paramsMap["project"].(string) + instance, _ := paramsMap["instance"].(string) + database, _ := paramsMap["database"].(string) + sql, _ := paramsMap["sql"].(string) + + logger, err := util.LoggerFromContext(ctx) + if err != nil { + return nil, util.NewClientServerError("error getting logger", http.StatusInternalServerError, err) + } + logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query on %s/%s/%s", resourceType, project, instance, database)) + + // Execute the SQL statement on the given database. + resp, err := source.ExecuteSql(ctx, project, instance, database, sql, string(accessToken)) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil +} + +func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { + return parameters.EmbedParams(ctx, t.Parameters, paramValues, embeddingModelsMap, nil) +} + +func (t Tool) Manifest() tools.Manifest { + return t.manifest +} + +func (t Tool) McpManifest() tools.McpManifest { + return t.mcpManifest +} + +func (t Tool) Authorized(verifiedAuthServices []string) bool { + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) +} + +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + +func (t Tool) ToConfig() tools.ToolConfig { + return t.Config +} + +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil +} + +func (t Tool) GetParameters() parameters.Parameters { + return t.Parameters +} diff --git a/internal/tools/postgres/postgresexecutesqlmany/postgresexecutesqlmany_test.go b/internal/tools/postgres/postgresexecutesqlmany/postgresexecutesqlmany_test.go new file mode 100644 index 000000000000..c8b3ae8c9230 --- /dev/null +++ b/internal/tools/postgres/postgresexecutesqlmany/postgresexecutesqlmany_test.go @@ -0,0 +1,69 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgresexecutesqlmany_test + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/googleapis/mcp-toolbox/internal/server" + "github.com/googleapis/mcp-toolbox/internal/testutils" + "github.com/googleapis/mcp-toolbox/internal/tools/postgres/postgresexecutesqlmany" +) + +func TestParseFromYamlExecuteSqlMany(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + desc string + in string + want server.ToolConfigs + }{ + { + desc: "basic example", + in: ` + kind: tool + name: example_tool + type: postgres-execute-sql-many + source: my-instance + description: some description + authRequired: + - my-google-auth-service + `, + want: server.ToolConfigs{ + "example_tool": postgresexecutesqlmany.Config{ + Name: "example_tool", + Type: "postgres-execute-sql-many", + Source: "my-instance", + Description: "some description", + AuthRequired: []string{"my-google-auth-service"}, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + _, _, _, got, _, _, err := server.UnmarshalResourceConfig(ctx, testutils.FormatYaml(tc.in)) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if diff := cmp.Diff(tc.want, got); diff != "" { + t.Fatalf("incorrect parse: diff %v", diff) + } + }) + } +} diff --git a/internal/tools/postgres/postgressqlmany/postgressqlmany.go b/internal/tools/postgres/postgressqlmany/postgressqlmany.go new file mode 100644 index 000000000000..26f69af98e06 --- /dev/null +++ b/internal/tools/postgres/postgressqlmany/postgressqlmany.go @@ -0,0 +1,171 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgressqlmany + +import ( + "context" + "fmt" + "net/http" + + yaml "github.com/goccy/go-yaml" + "github.com/googleapis/mcp-toolbox/internal/embeddingmodels" + "github.com/googleapis/mcp-toolbox/internal/sources" + "github.com/googleapis/mcp-toolbox/internal/tools" + "github.com/googleapis/mcp-toolbox/internal/util" + "github.com/googleapis/mcp-toolbox/internal/util/parameters" +) + +const resourceType string = "postgres-sql-many" + +func init() { + if !tools.Register(resourceType, newConfig) { + panic(fmt.Sprintf("tool type %q already registered", resourceType)) + } +} + +func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) { + actual := Config{Name: name} + if err := decoder.DecodeContext(ctx, &actual); err != nil { + return nil, err + } + return actual, nil +} + +type compatibleSource interface { + ExecuteSql(ctx context.Context, project, instance, database, sql string, accessToken string) (any, error) + UseClientAuthorization() bool +} + +type Config struct { + Name string `yaml:"name" validate:"required"` + Type string `yaml:"type" validate:"required"` + Source string `yaml:"source" validate:"required"` + Description string `yaml:"description" validate:"required"` + Statement string `yaml:"statement" validate:"required"` + AuthRequired []string `yaml:"authRequired"` + Parameters parameters.Parameters `yaml:"parameters"` + TemplateParameters parameters.Parameters `yaml:"templateParameters"` + Annotations *tools.ToolAnnotations `yaml:"annotations,omitempty"` +} + +var _ tools.ToolConfig = Config{} + +func (cfg Config) ToolConfigType() string { + return resourceType +} + +func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { + infraParams := parameters.Parameters{ + parameters.NewStringParameter("project", "The GCP project ID."), + parameters.NewStringParameter("instance", "The Cloud SQL instance ID."), + parameters.NewStringParameter("database", "The database name."), + } + + allParams, _, err := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters) + if err != nil { + return nil, err + } + + finalParams := append(infraParams, allParams...) + paramManifest := finalParams.Manifest() + + annotations := tools.GetAnnotationsOrDefault(cfg.Annotations, tools.NewDestructiveAnnotations) + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, finalParams, annotations) + + t := Tool{ + Config: cfg, + allParams: finalParams, + manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, + } + return t, nil +} + +var _ tools.Tool = Tool{} + +type Tool struct { + Config + allParams parameters.Parameters `yaml:"allParams"` + manifest tools.Manifest + mcpManifest tools.McpManifest +} + +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { + // Check source compatibility + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) + if err != nil { + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) + } + + // Extract parameters from the parameter values map. + paramsMap := params.AsMap() + project, _ := paramsMap["project"].(string) + instance, _ := paramsMap["instance"].(string) + database, _ := paramsMap["database"].(string) + + newStatement, err := parameters.ResolveTemplateParams(t.allParams, t.Statement, paramsMap) + if err != nil { + return nil, util.NewAgentError("unable to extract template params", err) + } + + logger, err := util.LoggerFromContext(ctx) + if err != nil { + return nil, util.NewClientServerError("error getting logger", http.StatusInternalServerError, err) + } + logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query on %s/%s/%s", resourceType, project, instance, database)) + + // Execute the SQL statement on the given database. + resp, err := source.ExecuteSql(ctx, project, instance, database, newStatement, string(accessToken)) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil +} + +func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { + return parameters.EmbedParams(ctx, t.allParams, paramValues, embeddingModelsMap, nil) +} + +func (t Tool) Manifest() tools.Manifest { + return t.manifest +} + +func (t Tool) McpManifest() tools.McpManifest { + return t.mcpManifest +} + +func (t Tool) Authorized(verifiedAuthServices []string) bool { + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) +} + +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + +func (t Tool) ToConfig() tools.ToolConfig { + return t.Config +} + +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil +} + +func (t Tool) GetParameters() parameters.Parameters { + return t.allParams +} diff --git a/internal/tools/postgres/postgressqlmany/postgressqlmany_test.go b/internal/tools/postgres/postgressqlmany/postgressqlmany_test.go new file mode 100644 index 000000000000..73fb96eb55da --- /dev/null +++ b/internal/tools/postgres/postgressqlmany/postgressqlmany_test.go @@ -0,0 +1,107 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgressqlmany_test + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/googleapis/mcp-toolbox/internal/server" + "github.com/googleapis/mcp-toolbox/internal/testutils" + "github.com/googleapis/mcp-toolbox/internal/tools/postgres/postgressqlmany" + "github.com/googleapis/mcp-toolbox/internal/util/parameters" +) + +func TestParseFromYamlSqlMany(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + desc string + in string + want server.ToolConfigs + }{ + { + desc: "basic example", + in: ` + kind: tool + name: example_tool + type: postgres-sql-many + source: my-instance + description: some description + statement: "SELECT * FROM users WHERE id = {{.id}}" + authRequired: + - my-google-auth-service + `, + want: server.ToolConfigs{ + "example_tool": postgressqlmany.Config{ + Name: "example_tool", + Type: "postgres-sql-many", + Source: "my-instance", + Description: "some description", + Statement: "SELECT * FROM users WHERE id = {{.id}}", + AuthRequired: []string{"my-google-auth-service"}, + }, + }, + }, + { + desc: "with parameters and templateParameters", + in: ` + kind: tool + name: example_tool_params + type: postgres-sql-many + source: my-instance + description: some description + statement: "SELECT * FROM users WHERE id = {{.id}} AND status = {{.status}}" + parameters: + - name: status + type: string + description: User status + templateParameters: + - name: id + type: string + description: User ID + `, + want: server.ToolConfigs{ + "example_tool_params": postgressqlmany.Config{ + Name: "example_tool_params", + Type: "postgres-sql-many", + Source: "my-instance", + Description: "some description", + Statement: "SELECT * FROM users WHERE id = {{.id}} AND status = {{.status}}", + AuthRequired: []string{}, + Parameters: parameters.Parameters{ + parameters.NewStringParameter("status", "User status"), + }, + TemplateParameters: parameters.Parameters{ + parameters.NewStringParameter("id", "User ID"), + }, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + _, _, _, got, _, _, err := server.UnmarshalResourceConfig(ctx, testutils.FormatYaml(tc.in)) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if diff := cmp.Diff(tc.want, got); diff != "" { + t.Fatalf("incorrect parse: diff %v", diff) + } + }) + } +} diff --git a/tests/cloudsql/cloud_sql_execute_sql_test.go b/tests/cloudsql/cloud_sql_execute_sql_test.go new file mode 100644 index 000000000000..0d746788c53e --- /dev/null +++ b/tests/cloudsql/cloud_sql_execute_sql_test.go @@ -0,0 +1,267 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cloudsql + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "reflect" + "regexp" + "strings" + "testing" + "time" + + "github.com/googleapis/mcp-toolbox/internal/testutils" + "github.com/googleapis/mcp-toolbox/tests" +) + +var ( + executeSqlManyToolType = "postgres-execute-sql-many" + sqlManyToolType = "postgres-sql-many" +) + +type executeSqlTransport struct { + transport http.RoundTripper + url *url.URL +} + +func (t *executeSqlTransport) RoundTrip(req *http.Request) (*http.Response, error) { + if strings.HasPrefix(req.URL.String(), "https://sqladmin.googleapis.com") { + req.URL.Scheme = t.url.Scheme + req.URL.Host = t.url.Host + } + return t.transport.RoundTrip(req) +} + +type masterExecuteSqlHandler struct { + t *testing.T +} + +func (h *masterExecuteSqlHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if !strings.Contains(r.UserAgent(), "genai-toolbox/") { + h.t.Errorf("User-Agent header not found") + } + + // Verify it's an executeSql request + if !strings.Contains(r.URL.Path, "/executeSql") { + h.t.Errorf("unexpected URL path: %s", r.URL.Path) + } + + // Read request body to verify payload if needed + bodyBytes, _ := io.ReadAll(r.Body) + var payload map[string]any + if err := json.Unmarshal(bodyBytes, &payload); err != nil { + h.t.Errorf("failed to unmarshal request body: %v", err) + } + + // Mock response + response := map[string]any{ + "results": []map[string]any{ + { + "columns": []map[string]any{ + { + "name": "result", + "type": "STRING", + }, + }, + "rows": []map[string]any{ + { + "values": []map[string]any{ + { + "value": "success", + }, + }, + }, + }, + }, + }, + } + statusCode := http.StatusOK + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(statusCode) + if err := json.NewEncoder(w).Encode(response); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +func TestExecuteSqlManyToolEndpoints(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + handler := &masterExecuteSqlHandler{t: t} + server := httptest.NewServer(handler) + defer server.Close() + + serverURL, err := url.Parse(server.URL) + if err != nil { + t.Fatalf("failed to parse server URL: %v", err) + } + + originalTransport := http.DefaultClient.Transport + if originalTransport == nil { + originalTransport = http.DefaultTransport + } + http.DefaultClient.Transport = &executeSqlTransport{ + transport: originalTransport, + url: serverURL, + } + t.Cleanup(func() { + http.DefaultClient.Transport = originalTransport + }) + + args := []string{"--enable-api"} + toolsFile := getExecuteSqlToolsConfig() + cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...) + if err != nil { + t.Fatalf("command initialization returned an error: %s", err) + } + defer cleanup() + + waitCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out) + if err != nil { + t.Logf("toolbox command logs: \n%s", out) + t.Fatalf("toolbox didn't start successfully: %s", err) + } + + tcs := []struct { + name string + toolName string + body string + want string + expectError bool + errorStatus int + }{ + { + name: "successful execute-sql-many", + toolName: "execute-sql-many", + body: `{"project": "p1", "instance": "i1", "database": "db1", "sql": "SELECT 1"}`, + want: `{"results":[{"columns":[{"name":"result","type":"STRING"}],"rows":[{"values":[{"value":"success"}]}]}]}`, + }, + { + name: "successful sql-many", + toolName: "sql-many", + body: `{"project": "p1", "instance": "i1", "database": "db1", "user_id": "123"}`, + want: `{"results":[{"columns":[{"name":"result","type":"STRING"}],"rows":[{"values":[{"value":"success"}]}]}]}`, + }, + { + name: "missing required param in execute-sql-many", + toolName: "execute-sql-many", + body: `{"project": "p1", "instance": "i1", "database": "db1"}`, + want: `{"error":"parameter \"sql\" is required"}`, + }, + } + + for _, tc := range tcs { + tc := tc + t.Run(tc.name, func(t *testing.T) { + api := fmt.Sprintf("http://127.0.0.1:5000/api/tool/%s/invoke", tc.toolName) + req, err := http.NewRequest(http.MethodPost, api, bytes.NewBufferString(tc.body)) + if err != nil { + t.Fatalf("unable to create request: %s", err) + } + req.Header.Add("Content-type", "application/json") + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("unable to send request: %s", err) + } + defer resp.Body.Close() + + if tc.expectError { + if resp.StatusCode != tc.errorStatus { + bodyBytes, _ := io.ReadAll(resp.Body) + t.Fatalf("expected status %d but got %d: %s", tc.errorStatus, resp.StatusCode, string(bodyBytes)) + } + return + } + + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes)) + } + + var result struct { + Result string `json:"result"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + + if strings.Contains(result.Result, `"error":`) { + var gotMap, wantMap map[string]any + if err := json.Unmarshal([]byte(result.Result), &gotMap); err != nil { + t.Fatalf("failed to unmarshal result error object: %v", err) + } + if err := json.Unmarshal([]byte(tc.want), &wantMap); err != nil { + t.Fatalf("failed to unmarshal want error object: %v", err) + } + if !reflect.DeepEqual(gotMap, wantMap) { + t.Fatalf("unexpected error result: got %+v, want %+v", gotMap, wantMap) + } + return + } + + var got, want map[string]any + if err := json.Unmarshal([]byte(result.Result), &got); err != nil { + t.Fatalf("failed to unmarshal result object: %v. Result was: %s", err, result.Result) + } + if err := json.Unmarshal([]byte(tc.want), &want); err != nil { + t.Fatalf("failed to unmarshal want object: %v", err) + } + + if !reflect.DeepEqual(got, want) { + t.Fatalf("unexpected result: got %+v, want %+v", got, want) + } + }) + } +} + +func getExecuteSqlToolsConfig() map[string]any { + return map[string]any{ + "sources": map[string]any{ + "my-cloud-sql-source": map[string]any{ + "type": "cloud-sql-admin", + }, + }, + "tools": map[string]any{ + "execute-sql-many": map[string]any{ + "type": executeSqlManyToolType, + "source": "my-cloud-sql-source", + "description": "Use this tool to execute sql statement on a specific instance.", + }, + "sql-many": map[string]any{ + "type": sqlManyToolType, + "source": "my-cloud-sql-source", + "description": "Use this tool to get user details from a specific instance.", + "statement": "SELECT * FROM users WHERE id = {{.user_id}}", + "templateParameters": []map[string]any{ + { + "name": "user_id", + "type": "string", + "description": "The ID of the user.", + }, + }, + }, + }, + } +}