diff --git a/tests/alloydbpg/alloydb_pg_mcp_test.go b/tests/alloydbpg/alloydb_pg_mcp_test.go index 8e6c3778f9eb..acdbdfec8181 100644 --- a/tests/alloydbpg/alloydb_pg_mcp_test.go +++ b/tests/alloydbpg/alloydb_pg_mcp_test.go @@ -14,9 +14,6 @@ package alloydbpg -// TODO: We may want to add tests for custom tools defined in alloydb-postgres.yaml -// in the future, rather than just testing the prebuilt tools. - import ( "context" "fmt" @@ -149,7 +146,7 @@ func TestAlloyDBPgListTools(t *testing.T) { } func TestAlloyDBPgCallTool(t *testing.T) { - getAlloyDBPgVars(t) + sourceConfig := getAlloyDBPgVars(t) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) defer cancel() @@ -162,9 +159,40 @@ func TestAlloyDBPgCallTool(t *testing.T) { uniqueID := strings.ReplaceAll(uuid.New().String(), "-", "") - args := []string{"--prebuilt", "alloydb-postgres"} + t.Cleanup(func() { + tests.CleanupPostgresTables(t, context.Background(), pool, uniqueID) + }) - cmd, cleanup, err := tests.StartCmd(ctx, map[string]any{}, args...) + tableNameParam := "param_table_" + uniqueID + tableNameAuth := "auth_table_" + uniqueID + + // set up data for param tool + createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := tests.GetPostgresSQLParamToolInfo(tableNameParam) + teardownTable1 := tests.SetupPostgresSQLTable(t, ctx, pool, createParamTableStmt, insertParamTableStmt, tableNameParam, paramTestParams) + defer teardownTable1(t) + + // set up data for auth tool + createAuthTableStmt, insertAuthTableStmt, authToolStmt, authTestParams := tests.GetPostgresSQLAuthToolInfo(tableNameAuth) + teardownTable2 := tests.SetupPostgresSQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams) + defer teardownTable2(t) + + // Set up table for semantic search + vectorTableName, tearDownVectorTable := tests.SetupPostgresVectorTable(t, ctx, pool) + defer tearDownVectorTable(t) + + // Write config into a file and pass it to command + toolsFile := tests.GetToolsConfig(sourceConfig, AlloyDBPostgresToolType, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt) + toolsFile = tests.AddExecuteSqlConfig(t, toolsFile, "postgres-execute-sql") + tmplSelectCombined, tmplSelectFilterCombined := tests.GetPostgresSQLTmplToolStatement() + toolsFile = tests.AddTemplateParamConfig(t, toolsFile, AlloyDBPostgresToolType, tmplSelectCombined, tmplSelectFilterCombined, "") + + // Add semantic search tool config + insertStmt, searchStmt := tests.GetPostgresVectorSearchStmts(vectorTableName) + toolsFile = tests.AddSemanticSearchConfig(t, toolsFile, AlloyDBPostgresToolType, insertStmt, searchStmt) + + toolsFile = tests.AddPostgresPrebuiltConfig(t, toolsFile) + + cmd, cleanup, err := tests.StartCmd(ctx, toolsFile) if err != nil { t.Fatalf("command initialization returned an error: %v", err) } @@ -178,6 +206,19 @@ func TestAlloyDBPgCallTool(t *testing.T) { t.Fatalf("toolbox didn't start successfully: %v", err) } + // Get configs for tests + select1Want, _, createTableStatement, _ := tests.GetPostgresWants() + + // Run custom tool tests via MCP + tests.RunMCPToolInvokeTest(t, ctx, select1Want) + + // Run execute-sql tool tests via MCP + tests.RunMCPExecuteSqlToolInvokeTest(t, ctx, createTableStatement, select1Want) + + // Run template parameters tool tests via MCP + tableNameTemplateParam := "template_param_table_" + uniqueID + tests.RunMCPToolInvokeWithTemplateParameters(t, ctx, tableNameTemplateParam) + // Run shared Postgres tests tests.RunMCPPostgresListViewsTest(t, ctx, pool) tests.RunMCPPostgresListSchemasTest(t, ctx, pool, AlloyDBPostgresUser, uniqueID) diff --git a/tests/cloudsqlpg/cloud_sql_pg_mcp_test.go b/tests/cloudsqlpg/cloud_sql_pg_mcp_test.go index 8c93fda8fa9a..c853c21174df 100644 --- a/tests/cloudsqlpg/cloud_sql_pg_mcp_test.go +++ b/tests/cloudsqlpg/cloud_sql_pg_mcp_test.go @@ -14,9 +14,6 @@ package cloudsqlpg -// TODO: We may want to add tests for custom tools defined in cloud-sql-postgres.yaml -// in the future, rather than just testing the prebuilt tools. - import ( "context" "fmt" @@ -136,7 +133,7 @@ func TestCloudSQLPgListTools(t *testing.T) { } func TestCloudSQLPgCallTool(t *testing.T) { - getCloudSQLPgVars(t) + sourceConfig := getCloudSQLPgVars(t) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) defer cancel() @@ -149,9 +146,40 @@ func TestCloudSQLPgCallTool(t *testing.T) { uniqueID := strings.ReplaceAll(uuid.New().String(), "-", "") - args := []string{"--prebuilt", "cloud-sql-postgres"} + t.Cleanup(func() { + tests.CleanupPostgresTables(t, context.Background(), pool, uniqueID) + }) - cmd, cleanup, err := tests.StartCmd(ctx, map[string]any{}, args...) + tableNameParam := "param_table_" + uniqueID + tableNameAuth := "auth_table_" + uniqueID + + // set up data for param tool + createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := tests.GetPostgresSQLParamToolInfo(tableNameParam) + teardownTable1 := tests.SetupPostgresSQLTable(t, ctx, pool, createParamTableStmt, insertParamTableStmt, tableNameParam, paramTestParams) + defer teardownTable1(t) + + // set up data for auth tool + createAuthTableStmt, insertAuthTableStmt, authToolStmt, authTestParams := tests.GetPostgresSQLAuthToolInfo(tableNameAuth) + teardownTable2 := tests.SetupPostgresSQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams) + defer teardownTable2(t) + + // Set up table for semantic search + vectorTableName, tearDownVectorTable := tests.SetupPostgresVectorTable(t, ctx, pool) + defer tearDownVectorTable(t) + + // Write config into a file and pass it to command + toolsFile := tests.GetToolsConfig(sourceConfig, CloudSQLPostgresToolType, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt) + toolsFile = tests.AddExecuteSqlConfig(t, toolsFile, "postgres-execute-sql") + tmplSelectCombined, tmplSelectFilterCombined := tests.GetPostgresSQLTmplToolStatement() + toolsFile = tests.AddTemplateParamConfig(t, toolsFile, CloudSQLPostgresToolType, tmplSelectCombined, tmplSelectFilterCombined, "") + + // Add semantic search tool config + insertStmt, searchStmt := tests.GetPostgresVectorSearchStmts(vectorTableName) + toolsFile = tests.AddSemanticSearchConfig(t, toolsFile, CloudSQLPostgresToolType, insertStmt, searchStmt) + + toolsFile = tests.AddPostgresPrebuiltConfig(t, toolsFile) + + cmd, cleanup, err := tests.StartCmd(ctx, toolsFile) if err != nil { t.Fatalf("command initialization returned an error: %v", err) } @@ -165,6 +193,19 @@ func TestCloudSQLPgCallTool(t *testing.T) { t.Fatalf("toolbox didn't start successfully: %v", err) } + // Get configs for tests + select1Want, _, createTableStatement, _ := tests.GetPostgresWants() + + // Run custom tool tests via MCP + tests.RunMCPToolInvokeTest(t, ctx, select1Want) + + // Run execute-sql tool tests via MCP + tests.RunMCPExecuteSqlToolInvokeTest(t, ctx, createTableStatement, select1Want) + + // Run template parameters tool tests via MCP + tableNameTemplateParam := "template_param_table_" + uniqueID + tests.RunMCPToolInvokeWithTemplateParameters(t, ctx, tableNameTemplateParam) + // Run shared Postgres tests tests.RunMCPPostgresListViewsTest(t, ctx, pool) tests.RunMCPPostgresListSchemasTest(t, ctx, pool, CloudSQLPostgresUser, uniqueID) diff --git a/tests/mcp_tool.go b/tests/mcp_tool.go index 4affde143f14..1dcab3d05bad 100644 --- a/tests/mcp_tool.go +++ b/tests/mcp_tool.go @@ -172,7 +172,7 @@ func GetMCPResultText(t *testing.T, resp *MCPCallToolResponse) []any { } else { if slice, ok := item.([]any); ok { res = append(res, slice...) - } else { + } else if item != nil { res = append(res, item) } } @@ -316,7 +316,7 @@ func RunMCPToolInvokeTest(t *testing.T, ctx context.Context, select1Want string, myToolId3NameAliceWant: "[{\"id\":1,\"name\":\"Alice\"},{\"id\":3,\"name\":\"Sid\"}]", myToolById4Want: "[{\"id\":4,\"name\":null}]", myArrayToolWant: "[{\"id\":1,\"name\":\"Alice\"},{\"id\":3,\"name\":\"Sid\"}]", - nullWant: "null", + nullWant: "[]", supportOptionalNullParam: true, supportArrayParam: true, supportClientAuth: false, @@ -411,6 +411,283 @@ func RunMCPToolInvokeTest(t *testing.T, ctx context.Context, select1Want string, } } +// RunMCPExecuteSqlToolInvokeTest runs execute-sql tool invoke test cases via MCP. +func RunMCPExecuteSqlToolInvokeTest(t *testing.T, ctx context.Context, createTableStatement, select1Want string, options ...ExecuteSqlOption) { + configs := &ExecuteSqlTestConfig{ + select1Statement: `"SELECT 1"`, + createWant: "[]", + dropWant: "[]", + selectEmptyWant: "[]", + } + + for _, option := range options { + option(configs) + } + + idToken, err := GetGoogleIdToken(t) + if err != nil { + t.Fatalf("error getting Google ID token: %s", err) + } + + invokeTcs := []struct { + name string + toolName string + args map[string]any + headers map[string]string + want string + isErr bool + isAgentErr bool + }{ + { + name: "invoke my-exec-sql-tool", + toolName: "my-exec-sql-tool", + args: map[string]any{"sql": strings.Trim(configs.select1Statement, `"`)}, + want: select1Want, + }, + { + name: "invoke my-exec-sql-tool create table", + toolName: "my-exec-sql-tool", + args: map[string]any{"sql": strings.Trim(createTableStatement, `"`)}, + want: configs.createWant, + }, + { + name: "invoke my-exec-sql-tool select table", + toolName: "my-exec-sql-tool", + args: map[string]any{"sql": "SELECT * FROM t"}, + want: configs.selectEmptyWant, + }, + { + name: "invoke my-exec-sql-tool drop table", + toolName: "my-exec-sql-tool", + args: map[string]any{"sql": "DROP TABLE t"}, + want: configs.dropWant, + }, + { + name: "invoke my-exec-sql-tool without body", + toolName: "my-exec-sql-tool", + args: map[string]any{}, + isAgentErr: true, + }, + { + name: "Invoke my-auth-exec-sql-tool with auth token", + toolName: "my-auth-exec-sql-tool", + args: map[string]any{"sql": strings.Trim(configs.select1Statement, `"`)}, + headers: map[string]string{"my-google-auth_token": idToken}, + want: select1Want, + }, + { + name: "Invoke my-auth-exec-sql-tool with invalid auth token", + toolName: "my-auth-exec-sql-tool", + args: map[string]any{"sql": strings.Trim(configs.select1Statement, `"`)}, + headers: map[string]string{"my-google-auth_token": "INVALID_TOKEN"}, + isErr: true, + }, + { + name: "Invoke my-auth-exec-sql-tool without auth token", + toolName: "my-auth-exec-sql-tool", + args: map[string]any{"sql": strings.Trim(configs.select1Statement, `"`)}, + isErr: true, + }, + { + name: "invoke my-exec-sql-tool with invalid SELECT SQL", + toolName: "my-exec-sql-tool", + args: map[string]any{"sql": "SELECT * FROM non_existent_table"}, + isAgentErr: true, + }, + { + name: "invoke my-exec-sql-tool with invalid ALTER SQL", + toolName: "my-exec-sql-tool", + args: map[string]any{"sql": "ALTER TALE t ALTER COLUMN id DROP NOT NULL"}, + isAgentErr: true, + }, + } + + for _, tc := range invokeTcs { + t.Run(tc.name, func(t *testing.T) { + statusCode, mcpResp, err := InvokeMCPTool(t, ctx, tc.toolName, tc.args, tc.headers) + + if tc.isErr { + if err == nil && mcpResp.Error == nil && !mcpResp.Result.IsError { + t.Fatalf("expected error but got none") + } + return + } + + if err != nil { + if tc.isAgentErr { + return + } + t.Fatalf("native error executing %s: %s", tc.toolName, err) + } + + if mcpResp.Result.IsError { + if tc.isAgentErr { + return + } + t.Fatalf("%s returned error result: %v", tc.toolName, mcpResp.Result) + } + + if statusCode != http.StatusOK { + t.Fatalf("wrong status code: got %d, want %d", statusCode, http.StatusOK) + } + + if tc.want == "" { + return + } + + got := GetMCPResultText(t, mcpResp) + gotBytes, _ := json.Marshal(got) + gotStr := string(gotBytes) + if !strings.Contains(gotStr, tc.want) { + t.Fatalf(`expected %q to contain %q`, gotStr, tc.want) + } + }) + } +} + +// RunMCPToolInvokeWithTemplateParameters runs tool invoke test cases with template parameters via MCP. +func RunMCPToolInvokeWithTemplateParameters(t *testing.T, ctx context.Context, tableName string, options ...TemplateParamOption) { + configs := &TemplateParameterTestConfig{ + ddlWant: "[]", + selectAllWant: "[{\"age\":21,\"id\":1,\"name\":\"Alex\"},{\"age\":100,\"id\":2,\"name\":\"Alice\"}]", + selectId1Want: "[{\"age\":21,\"id\":1,\"name\":\"Alex\"}]", + selectNameWant: "[{\"age\":21,\"id\":1,\"name\":\"Alex\"}]", + selectEmptyWant: "[]", + insert1Want: "[]", + + nameFieldArray: `["name"]`, + nameColFilter: "name", + createColArray: `["id INT","name VARCHAR(20)","age INT"]`, + + supportDdl: true, + supportInsert: true, + } + + for _, option := range options { + option(configs) + } + + selectOnlyNamesWant := "[{\"name\":\"Alex\"},{\"name\":\"Alice\"}]" + + invokeTcs := []struct { + name string + toolName string + enabled bool + ddl bool + insert bool + args map[string]any + want string + isErr bool + }{ + { + name: "invoke create-table-templateParams-tool", + toolName: "create-table-templateParams-tool", + ddl: true, + enabled: configs.supportDdl, + args: map[string]any{"tableName": tableName, "columns": []string{"id INT", "name VARCHAR(20)", "age INT"}}, // Pass as slice, not string! + want: configs.ddlWant, + }, + { + name: "invoke insert-table-templateParams-tool", + toolName: "insert-table-templateParams-tool", + insert: true, + enabled: configs.supportInsert, + args: map[string]any{"tableName": tableName, "columns": []string{"id", "name", "age"}, "values": "1, 'Alex', 21"}, + want: configs.insert1Want, + }, + { + name: "invoke insert-table-templateParams-tool 2", + toolName: "insert-table-templateParams-tool", + insert: true, + enabled: configs.supportInsert, + args: map[string]any{"tableName": tableName, "columns": []string{"id", "name", "age"}, "values": "2, 'Alice', 100"}, + want: configs.insert1Want, + }, + { + name: "invoke select-templateParams-tool", + toolName: "select-templateParams-tool", + enabled: true, + args: map[string]any{"tableName": tableName}, + want: configs.selectAllWant, + }, + { + name: "invoke select-templateParams-combined-tool", + toolName: "select-templateParams-combined-tool", + enabled: true, + args: map[string]any{"id": 1, "tableName": tableName}, + want: configs.selectId1Want, + }, + { + name: "invoke select-templateParams-combined-tool with no results", + toolName: "select-templateParams-combined-tool", + enabled: true, + args: map[string]any{"id": 999, "tableName": tableName}, + want: configs.selectEmptyWant, + }, + { + name: "invoke select-fields-templateParams-tool", + toolName: "select-fields-templateParams-tool", + enabled: configs.supportSelectFields, + args: map[string]any{"tableName": tableName, "fields": []string{"name"}}, + want: selectOnlyNamesWant, + }, + { + name: "invoke select-filter-templateParams-combined-tool", + toolName: "select-filter-templateParams-combined-tool", + enabled: true, + args: map[string]any{"name": "Alex", "tableName": tableName, "columnFilter": configs.nameColFilter}, + want: configs.selectNameWant, + }, + { + name: "invoke drop-table-templateParams-tool", + toolName: "drop-table-templateParams-tool", + ddl: true, + enabled: configs.supportDdl, + args: map[string]any{"tableName": tableName}, + want: configs.ddlWant, + }, + } + + for _, tc := range invokeTcs { + t.Run(tc.name, func(t *testing.T) { + if !tc.enabled { + return + } + statusCode, mcpResp, err := InvokeMCPTool(t, ctx, tc.toolName, tc.args, nil) + + if tc.isErr { + if err == nil && mcpResp.Error == nil && !mcpResp.Result.IsError { + t.Fatalf("expected error but got none") + } + return + } + + if err != nil { + t.Fatalf("native error executing %s: %s", tc.toolName, err) + } + + if mcpResp.Result.IsError { + t.Fatalf("%s returned error result: %v", tc.toolName, mcpResp.Result) + } + + if statusCode != http.StatusOK { + t.Fatalf("wrong status code: got %d, want %d", statusCode, http.StatusOK) + } + + if tc.want == "" { + return + } + + got := GetMCPResultText(t, mcpResp) + gotBytes, _ := json.Marshal(got) + gotStr := string(gotBytes) + if !strings.Contains(gotStr, tc.want) { + t.Fatalf(`expected %q to contain %q`, gotStr, tc.want) + } + }) + } +} + // setUpPostgresViews creates a test view and returns a cleanup function. func setUpMCPPostgresViews(t *testing.T, ctx context.Context, pool *pgxpool.Pool, viewName string) func() { createView := fmt.Sprintf("CREATE VIEW %s AS SELECT 1 AS col", viewName)