diff --git a/cmd/internal/flags.go b/cmd/internal/flags.go index 0ccc14ae28e6..82e713fc7e74 100644 --- a/cmd/internal/flags.go +++ b/cmd/internal/flags.go @@ -34,6 +34,7 @@ func PersistentFlags(parentCmd *cobra.Command, opts *ToolboxOptions) { persistentFlags.BoolVar(&opts.Cfg.TelemetryGCP, "telemetry-gcp", false, "Enable exporting directly to Google Cloud Monitoring.") persistentFlags.StringVar(&opts.Cfg.TelemetryOTLP, "telemetry-otlp", "", "Enable exporting using OpenTelemetry Protocol (OTLP) to the specified endpoint (e.g. 'http://127.0.0.1:4318')") persistentFlags.StringVar(&opts.Cfg.TelemetryServiceName, "telemetry-service-name", "toolbox", "Sets the value of the service.name resource attribute for telemetry data.") + persistentFlags.BoolVar(&opts.Cfg.SQLCommenter, "sql-commenter", false, "Enable appending SQLCommenter-format comments to SQL statements.") persistentFlags.StringSliceVar(&opts.Cfg.UserAgentMetadata, "user-agent-metadata", []string{}, "Appends additional metadata to the User-Agent.") } diff --git a/docs/en/reference/cli.md b/docs/en/reference/cli.md index 4a761bc634b7..010857278b15 100644 --- a/docs/en/reference/cli.md +++ b/docs/en/reference/cli.md @@ -22,6 +22,7 @@ description: > | | `--telemetry-gcp` | Enable exporting directly to Google Cloud Monitoring. | | | | `--telemetry-otlp` | Enable exporting using OpenTelemetry Protocol (OTLP) to the specified endpoint (e.g. 'http://127.0.0.1:4318') | | | | `--telemetry-service-name` | Sets the value of the service.name resource attribute for telemetry data. | `toolbox` | +| | `--sql-commenter` | Enable appending SQLCommenter-format comments to SQL statements. | `false` | | | `--config` | File path specifying the tool configuration. Cannot be used with --configs or --config-folder. | | | | `--configs` | Multiple file paths specifying tool configurations. Files will be merged. Cannot be used with --config or --config-folder. | | | | `--config-folder` | Directory path containing YAML tool configuration files. All .yaml and .yml files in the directory will be loaded and merged. Cannot be used with --config or --configs. | | diff --git a/internal/server/config.go b/internal/server/config.go index efa947cefeac..0af257bb95d3 100644 --- a/internal/server/config.go +++ b/internal/server/config.go @@ -64,6 +64,8 @@ type ServerConfig struct { TelemetryOTLP string // TelemetryServiceName defines the value of service.name resource attribute. TelemetryServiceName string + // SQLCommenter enables appending SQLCommenter-format comments to SQL statements. + SQLCommenter bool // Stdio indicates if Toolbox is listening via MCP stdio. Stdio bool // DisableReload indicates if the user has disabled dynamic reloading for Toolbox. diff --git a/internal/server/mcp.go b/internal/server/mcp.go index 24130cab4f9d..4c154319c05a 100644 --- a/internal/server/mcp.go +++ b/internal/server/mcp.go @@ -145,14 +145,15 @@ func (c traceContextCarrier) Keys() []string { return keys } -// extractTraceContext extracts W3C Trace Context from params._meta -func extractTraceContext(ctx context.Context, body []byte) context.Context { - // Try to parse the request to extract _meta +// extractMeta parses params._meta from the request body in a single pass, +// extracting both W3C Trace Context and client telemetry attributes. +func extractMeta(ctx context.Context, body []byte) context.Context { var req struct { Params struct { Meta struct { - Traceparent string `json:"traceparent,omitempty"` - Tracestate string `json:"tracestate,omitempty"` + Traceparent string `json:"traceparent,omitempty"` + Tracestate string `json:"tracestate,omitempty"` + TelemetryAttrs map[string]string `json:"dev.mcp-toolbox/telemetry,omitempty"` } `json:"_meta,omitempty"` } `json:"params,omitempty"` } @@ -161,7 +162,7 @@ func extractTraceContext(ctx context.Context, body []byte) context.Context { return ctx } - // If traceparent is present, extract the context + // Extract W3C Trace Context if req.Params.Meta.Traceparent != "" { carrier := traceContextCarrier{ "traceparent": req.Params.Meta.Traceparent, @@ -169,7 +170,19 @@ func extractTraceContext(ctx context.Context, body []byte) context.Context { if req.Params.Meta.Tracestate != "" { carrier["tracestate"] = req.Params.Meta.Tracestate } - return otel.GetTextMapPropagator().Extract(ctx, carrier) + ctx = otel.GetTextMapPropagator().Extract(ctx, carrier) + } + + // Extract client telemetry attributes + if attrs := req.Params.Meta.TelemetryAttrs; len(attrs) > 0 { + ta := &util.TelemetryAttributes{ + ClientName: attrs["client.name"], + ClientVersion: attrs["client.version"], + ClientModel: attrs["client.model"], + ClientUserID: attrs["client.user.id"], + ClientAgentID: attrs["client.agent.id"], + } + ctx = util.WithTelemetryAttributes(ctx, ta) } return ctx @@ -191,6 +204,8 @@ func (s *stdioSession) Start(ctx context.Context) error { // readInputStream reads requests/notifications from MCP clients through stdin func (s *stdioSession) readInputStream(ctx context.Context) error { sessionStart := time.Now() + ctx = util.WithUserAgent(ctx, s.server.version) + ctx = util.WithSQLCommenterEnabled(ctx, s.server.sqlCommenterEnabled) // Define attributes for session metrics // Note: mcp.protocol.version is added dynamically after protocol negotiation @@ -238,7 +253,7 @@ func (s *stdioSession) readInputStream(ctx context.Context) error { if err := func() error { // This ensures the transport span becomes a child of the client span - msgCtx := extractTraceContext(ctx, []byte(line)) + msgCtx := extractMeta(ctx, []byte(line)) // Create span for STDIO transport msgCtx, span := s.server.instrumentation.Tracer.Start(msgCtx, "toolbox/server/mcp/stdio", @@ -463,6 +478,8 @@ func httpHandler(s *Server, w http.ResponseWriter, r *http.Request) { ctx := r.Context() ctx = util.WithLogger(ctx, s.logger) + ctx = util.WithUserAgent(ctx, s.version) + ctx = util.WithSQLCommenterEnabled(ctx, s.sqlCommenterEnabled) // Read body first so we can extract trace context body, err := io.ReadAll(r.Body) @@ -475,7 +492,7 @@ func httpHandler(s *Server, w http.ResponseWriter, r *http.Request) { } // This ensures the transport span becomes a child of the client span - ctx = extractTraceContext(ctx, body) + ctx = extractMeta(ctx, body) // Create span for HTTP transport ctx, span := s.instrumentation.Tracer.Start(ctx, "toolbox/server/mcp/http", diff --git a/internal/server/server.go b/internal/server/server.go index d42c25b9cfcd..5eff4cb034e6 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -49,16 +49,17 @@ import ( // Server contains info for running an instance of Toolbox. Should be instantiated with NewServer(). type Server struct { - version string - toolboxUrl string - srv *http.Server - listener net.Listener - root chi.Router - logger log.Logger - instrumentation *telemetry.Instrumentation - sseManager *sseManager - ResourceMgr *resources.ResourceManager - mcpPrmFile string + version string + sqlCommenterEnabled bool + toolboxUrl string + srv *http.Server + listener net.Listener + root chi.Router + logger log.Logger + instrumentation *telemetry.Instrumentation + sseManager *sseManager + ResourceMgr *resources.ResourceManager + mcpPrmFile string } func InitializeConfigs(ctx context.Context, cfg ServerConfig) ( @@ -378,15 +379,16 @@ func NewServer(ctx context.Context, cfg ServerConfig) (*Server, error) { resourceManager := resources.NewResourceManager(sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap) s := &Server{ - version: cfg.Version, - srv: srv, - root: r, - logger: l, - instrumentation: instrumentation, - sseManager: sseManager, - ResourceMgr: resourceManager, - toolboxUrl: cfg.ToolboxUrl, - mcpPrmFile: cfg.McpPrmFile, + version: cfg.Version, + sqlCommenterEnabled: cfg.SQLCommenter, + srv: srv, + root: r, + logger: l, + instrumentation: instrumentation, + sseManager: sseManager, + ResourceMgr: resourceManager, + toolboxUrl: cfg.ToolboxUrl, + mcpPrmFile: cfg.McpPrmFile, } // cors diff --git a/internal/sources/alloydbpg/alloydb_pg.go b/internal/sources/alloydbpg/alloydb_pg.go index c7617ae824d7..de3a7f77e1f7 100644 --- a/internal/sources/alloydbpg/alloydb_pg.go +++ b/internal/sources/alloydbpg/alloydb_pg.go @@ -23,6 +23,7 @@ import ( "cloud.google.com/go/alloydbconn" "github.com/goccy/go-yaml" "github.com/googleapis/mcp-toolbox/internal/sources" + "github.com/googleapis/mcp-toolbox/internal/sources/sqlcommenter" "github.com/googleapis/mcp-toolbox/internal/util" "github.com/googleapis/mcp-toolbox/internal/util/orderedmap" "github.com/jackc/pgx/v5/pgxpool" @@ -103,6 +104,7 @@ func (s *Source) PostgresPool() *pgxpool.Pool { } func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (any, error) { + statement = sqlcommenter.AppendComment(ctx, statement, SourceType) results, err := s.Pool.Query(ctx, statement, params...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) diff --git a/internal/sources/cloudsqlmssql/cloud_sql_mssql.go b/internal/sources/cloudsqlmssql/cloud_sql_mssql.go index 8a3605265acf..718b77083f3a 100644 --- a/internal/sources/cloudsqlmssql/cloud_sql_mssql.go +++ b/internal/sources/cloudsqlmssql/cloud_sql_mssql.go @@ -24,6 +24,7 @@ import ( "cloud.google.com/go/cloudsqlconn/sqlserver/mssql" "github.com/goccy/go-yaml" "github.com/googleapis/mcp-toolbox/internal/sources" + "github.com/googleapis/mcp-toolbox/internal/sources/sqlcommenter" "github.com/googleapis/mcp-toolbox/internal/util" "github.com/googleapis/mcp-toolbox/internal/util/orderedmap" "go.opentelemetry.io/otel/trace" @@ -108,6 +109,7 @@ func (s *Source) MSSQLDB() *sql.DB { } func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (any, error) { + statement = sqlcommenter.AppendComment(ctx, statement, SourceType) results, err := s.MSSQLDB().QueryContext(ctx, statement, params...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) diff --git a/internal/sources/cloudsqlmysql/cloud_sql_mysql.go b/internal/sources/cloudsqlmysql/cloud_sql_mysql.go index cce65db37497..b7724563c20c 100644 --- a/internal/sources/cloudsqlmysql/cloud_sql_mysql.go +++ b/internal/sources/cloudsqlmysql/cloud_sql_mysql.go @@ -24,6 +24,7 @@ import ( "cloud.google.com/go/cloudsqlconn/mysql/mysql" "github.com/goccy/go-yaml" "github.com/googleapis/mcp-toolbox/internal/sources" + "github.com/googleapis/mcp-toolbox/internal/sources/sqlcommenter" "github.com/googleapis/mcp-toolbox/internal/tools/mysql/mysqlcommon" "github.com/googleapis/mcp-toolbox/internal/util" "github.com/googleapis/mcp-toolbox/internal/util/orderedmap" @@ -107,6 +108,7 @@ func (s *Source) MySQLDatabase() string { } func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (any, error) { + statement = sqlcommenter.AppendComment(ctx, statement, SourceType) results, err := s.MySQLPool().QueryContext(ctx, statement, params...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) diff --git a/internal/sources/cloudsqlpg/cloud_sql_pg.go b/internal/sources/cloudsqlpg/cloud_sql_pg.go index 818e2271cb29..433a03206568 100644 --- a/internal/sources/cloudsqlpg/cloud_sql_pg.go +++ b/internal/sources/cloudsqlpg/cloud_sql_pg.go @@ -22,6 +22,7 @@ import ( "cloud.google.com/go/cloudsqlconn" "github.com/goccy/go-yaml" "github.com/googleapis/mcp-toolbox/internal/sources" + "github.com/googleapis/mcp-toolbox/internal/sources/sqlcommenter" "github.com/googleapis/mcp-toolbox/internal/util" "github.com/googleapis/mcp-toolbox/internal/util/orderedmap" "github.com/jackc/pgx/v5/pgxpool" @@ -109,6 +110,7 @@ func (s *Source) PostgresPool() *pgxpool.Pool { } func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (any, error) { + statement = sqlcommenter.AppendComment(ctx, statement, SourceType) results, err := s.PostgresPool().Query(ctx, statement, params...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) diff --git a/internal/sources/mssql/mssql.go b/internal/sources/mssql/mssql.go index 06c33af691fb..6952c04a9521 100644 --- a/internal/sources/mssql/mssql.go +++ b/internal/sources/mssql/mssql.go @@ -22,6 +22,7 @@ import ( "github.com/goccy/go-yaml" "github.com/googleapis/mcp-toolbox/internal/sources" + "github.com/googleapis/mcp-toolbox/internal/sources/sqlcommenter" "github.com/googleapis/mcp-toolbox/internal/util" "github.com/googleapis/mcp-toolbox/internal/util/orderedmap" _ "github.com/microsoft/go-mssqldb" @@ -106,6 +107,7 @@ func (s *Source) MSSQLDB() *sql.DB { } func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (any, error) { + statement = sqlcommenter.AppendComment(ctx, statement, SourceType) results, err := s.MSSQLDB().QueryContext(ctx, statement, params...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) diff --git a/internal/sources/mysql/mysql.go b/internal/sources/mysql/mysql.go index 477c9a982826..3c13a01ecc26 100644 --- a/internal/sources/mysql/mysql.go +++ b/internal/sources/mysql/mysql.go @@ -23,6 +23,7 @@ import ( driver "github.com/go-sql-driver/mysql" "github.com/goccy/go-yaml" "github.com/googleapis/mcp-toolbox/internal/sources" + "github.com/googleapis/mcp-toolbox/internal/sources/sqlcommenter" "github.com/googleapis/mcp-toolbox/internal/tools/mysql/mysqlcommon" "github.com/googleapis/mcp-toolbox/internal/util" "github.com/googleapis/mcp-toolbox/internal/util/orderedmap" @@ -106,6 +107,7 @@ func (s *Source) MySQLDatabase() string { } func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (any, error) { + statement = sqlcommenter.AppendComment(ctx, statement, SourceType) results, err := s.MySQLPool().QueryContext(ctx, statement, params...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) diff --git a/internal/sources/postgres/postgres.go b/internal/sources/postgres/postgres.go index 3383d482f4f9..2c721d50715f 100644 --- a/internal/sources/postgres/postgres.go +++ b/internal/sources/postgres/postgres.go @@ -21,6 +21,7 @@ import ( "github.com/goccy/go-yaml" "github.com/googleapis/mcp-toolbox/internal/sources" + "github.com/googleapis/mcp-toolbox/internal/sources/sqlcommenter" "github.com/googleapis/mcp-toolbox/internal/util" "github.com/googleapis/mcp-toolbox/internal/util/orderedmap" "github.com/jackc/pgx/v5" @@ -101,6 +102,7 @@ func (s *Source) PostgresPool() *pgxpool.Pool { } func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (any, error) { + statement = sqlcommenter.AppendComment(ctx, statement, SourceType) results, err := s.PostgresPool().Query(ctx, statement, params...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) diff --git a/internal/sources/sqlcommenter/sqlcommenter.go b/internal/sources/sqlcommenter/sqlcommenter.go new file mode 100644 index 000000000000..d281a3600fb5 --- /dev/null +++ b/internal/sources/sqlcommenter/sqlcommenter.go @@ -0,0 +1,117 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlcommenter + +import ( + "context" + "fmt" + "net/url" + "sort" + "strings" + + "github.com/googleapis/mcp-toolbox/internal/util" + "go.opentelemetry.io/otel/trace" +) + +// AppendComment appends a SQLCommenter-format comment to the given SQL statement. +// It gathers attributes from the context (trace, server, client, tool metadata) +// and the provided dbSystemName, then appends them as key='value' pairs sorted +// alphabetically. +func AppendComment(ctx context.Context, statement string, dbSystemName string) string { + // Only append SQL comments when sql-commenter is enabled + if !util.SQLCommenterEnabledFromContext(ctx) { + return statement + } + + pairs := collectAttributes(ctx, dbSystemName) + if len(pairs) == 0 { + return statement + } + + // Sort keys alphabetically + keys := make([]string, 0, len(pairs)) + for k := range pairs { + keys = append(keys, k) + } + sort.Strings(keys) + + // Build comment in SQLCommenter format: key='url_encoded_value' + parts := make([]string, 0, len(keys)) + for _, k := range keys { + encodedKey := url.QueryEscape(k) + encodedVal := url.QueryEscape(pairs[k]) + parts = append(parts, fmt.Sprintf("%s='%s'", encodedKey, encodedVal)) + } + + comment := strings.Join(parts, ",") + return "/*" + comment + "*/ " + statement +} + +// collectAttributes gathers all available SQLCommenter attributes from context. +func collectAttributes(ctx context.Context, dbSystemName string) map[string]string { + attrs := make(map[string]string) + + // traceparent from OTel span context + spanCtx := trace.SpanFromContext(ctx).SpanContext() + if spanCtx.IsValid() { + traceparent := fmt.Sprintf("00-%s-%s-%s", + spanCtx.TraceID().String(), + spanCtx.SpanID().String(), + spanCtx.TraceFlags().String(), + ) + attrs["traceparent"] = traceparent + } + + // server from UserAgent context + if ua, err := util.UserAgentFromContext(ctx); err == nil && ua != "" { + attrs["server"] = ua + } + + // db.system.name from parameter + if dbSystemName != "" { + attrs["db.system.name"] = dbSystemName + } + + // tool.name from GenAIMetricAttrs + if genAI := util.GenAIMetricAttrsFromContext(ctx); genAI != nil { + if genAI.ToolName != "" { + attrs["tool.name"] = genAI.ToolName + } + } + + // Client attributes from TelemetryAttributes + if ta := util.TelemetryAttributesFromContext(ctx); ta != nil { + // Combined client = name/version + if ta.ClientName != "" && ta.ClientVersion != "" { + attrs["client"] = ta.ClientName + "/" + ta.ClientVersion + } else if ta.ClientName != "" { + attrs["client"] = ta.ClientName + } else if ta.ClientVersion != "" { + attrs["client"] = ta.ClientVersion + } + + if ta.ClientModel != "" { + attrs["client.model"] = ta.ClientModel + } + if ta.ClientUserID != "" { + attrs["client.user.id"] = ta.ClientUserID + } + if ta.ClientAgentID != "" { + attrs["client.agent.id"] = ta.ClientAgentID + } + } + + return attrs +} diff --git a/internal/sources/sqlcommenter/sqlcommenter_test.go b/internal/sources/sqlcommenter/sqlcommenter_test.go new file mode 100644 index 000000000000..4d239a2c63f9 --- /dev/null +++ b/internal/sources/sqlcommenter/sqlcommenter_test.go @@ -0,0 +1,208 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlcommenter + +import ( + "context" + "net/url" + "strings" + "testing" + + "github.com/googleapis/mcp-toolbox/internal/util" +) + +// sqlCommenterCtx returns a context with sql-commenter enabled. +func sqlCommenterCtx() context.Context { + return util.WithSQLCommenterEnabled(context.Background(), true) +} + +func TestAppendComment_SQLCommenterDisabled(t *testing.T) { + // SQL commenter not enabled in context — statement should be unchanged + ctx := context.Background() + ctx = util.WithUserAgent(ctx, "1.1.0") + ctx = util.WithGenAIMetricAttrs(ctx, &util.GenAIMetricAttrs{ + ToolName: "search_hotels", + }) + + stmt := "SELECT * FROM users" + result := AppendComment(ctx, stmt, "postgresql") + + if result != stmt { + t.Errorf("expected unchanged statement when sql-commenter disabled, got: %s", result) + } +} + +func TestAppendComment_EmptyContext(t *testing.T) { + ctx := sqlCommenterCtx() + stmt := "SELECT * FROM users" + result := AppendComment(ctx, stmt, "") + + // No attributes available, statement should be unchanged + if result != stmt { + t.Errorf("expected unchanged statement, got: %s", result) + } +} + +func TestAppendComment_OnlyDbSystemName(t *testing.T) { + ctx := sqlCommenterCtx() + stmt := "SELECT * FROM users" + result := AppendComment(ctx, stmt, "postgresql") + + expected := "/*db.system.name='postgresql'*/ SELECT * FROM users" + if result != expected { + t.Errorf("expected %s, got: %s", expected, result) + } +} + +func TestAppendComment_ServerSideAttributes(t *testing.T) { + ctx := sqlCommenterCtx() + ctx = util.WithUserAgent(ctx, "1.1.0") + ctx = util.WithGenAIMetricAttrs(ctx, &util.GenAIMetricAttrs{ + ToolName: "search_hotels", + }) + + stmt := "SELECT * FROM hotels" + result := AppendComment(ctx, stmt, "postgresql") + + // Should contain server, tool.name, db.system.name + if !strings.Contains(result, "/*") || !strings.Contains(result, "*/") { + t.Errorf("expected SQL comment, got: %s", result) + } + if !strings.Contains(result, "db.system.name='postgresql'") { + t.Errorf("missing db.system.name, got: %s", result) + } + if !strings.Contains(result, "server='"+url.QueryEscape("genai-toolbox/1.1.0")+"'") { + t.Errorf("missing server, got: %s", result) + } + if !strings.Contains(result, "tool.name='search_hotels'") { + t.Errorf("missing tool.name, got: %s", result) + } + // Comment should be prepended + if !strings.HasPrefix(result, "/*") { + t.Errorf("expected comment prepended to statement, got: %s", result) + } +} + +func TestAppendComment_FullAttributes(t *testing.T) { + ctx := sqlCommenterCtx() + ctx = util.WithUserAgent(ctx, "1.1.0") + ctx = util.WithGenAIMetricAttrs(ctx, &util.GenAIMetricAttrs{ + ToolName: "search_user", + }) + ctx = util.WithTelemetryAttributes(ctx, &util.TelemetryAttributes{ + ClientName: "toolbox-langchain-python", + ClientVersion: "v0.1.0", + ClientModel: "gemini-2.5-flash", + ClientUserID: "user-123", + ClientAgentID: "agent-456", + }) + + stmt := "SELECT * FROM users" + result := AppendComment(ctx, stmt, "postgresql") + + // Verify all expected key='value' pairs are present + expectedPairs := []string{ + "client='" + url.QueryEscape("toolbox-langchain-python/v0.1.0") + "'", + "client.agent.id='agent-456'", + "client.model='gemini-2.5-flash'", + "client.user.id='user-123'", + "db.system.name='postgresql'", + "server='" + url.QueryEscape("genai-toolbox/1.1.0") + "'", + "tool.name='search_user'", + } + for _, pair := range expectedPairs { + if !strings.Contains(result, pair) { + t.Errorf("missing pair %q in: %s", pair, result) + } + } +} + +func TestAppendComment_AlphabeticalOrder(t *testing.T) { + ctx := sqlCommenterCtx() + ctx = util.WithUserAgent(ctx, "1.0.0") + ctx = util.WithGenAIMetricAttrs(ctx, &util.GenAIMetricAttrs{ + ToolName: "my_tool", + }) + ctx = util.WithTelemetryAttributes(ctx, &util.TelemetryAttributes{ + ClientName: "test-client", + ClientVersion: "v1", + ClientModel: "model-x", + }) + + stmt := "SELECT 1" + result := AppendComment(ctx, stmt, "postgresql") + + // Extract the comment part + commentStart := strings.Index(result, "/*") + commentEnd := strings.Index(result, "*/") + if commentStart == -1 || commentEnd == -1 { + t.Fatalf("no comment found in: %s", result) + } + comment := result[commentStart+2 : commentEnd] + parts := strings.Split(comment, ",") + + // Verify keys are sorted + for i := 1; i < len(parts); i++ { + prevKey := strings.SplitN(parts[i-1], "=", 2)[0] + currKey := strings.SplitN(parts[i], "=", 2)[0] + if prevKey > currKey { + t.Errorf("keys not sorted: %s comes before %s", prevKey, currKey) + } + } +} + +func TestAppendComment_URLEncoding(t *testing.T) { + ctx := sqlCommenterCtx() + ctx = util.WithTelemetryAttributes(ctx, &util.TelemetryAttributes{ + ClientName: "my client/special", + ClientVersion: "v1.0", + }) + + stmt := "SELECT 1" + result := AppendComment(ctx, stmt, "") + + // The client value "my client/special/v1.0" should be URL-encoded + if !strings.Contains(result, "client='"+url.QueryEscape("my client/special/v1.0")+"'") { + t.Errorf("expected URL-encoded client, got: %s", result) + } +} + +func TestAppendComment_PartialClientAttributes(t *testing.T) { + ctx := sqlCommenterCtx() + ctx = util.WithTelemetryAttributes(ctx, &util.TelemetryAttributes{ + ClientName: "test-client", + // No version + }) + + stmt := "SELECT 1" + result := AppendComment(ctx, stmt, "") + + if !strings.Contains(result, "client='test-client'") { + t.Errorf("expected client with name only, got: %s", result) + } +} + +func TestAppendComment_EmptyTelemetryAttributes(t *testing.T) { + ctx := sqlCommenterCtx() + ctx = util.WithTelemetryAttributes(ctx, &util.TelemetryAttributes{}) + + stmt := "SELECT 1" + result := AppendComment(ctx, stmt, "postgresql") + + // Should only have db.system.name since all telemetry attrs are empty + if !strings.Contains(result, "db.system.name='postgresql'") { + t.Errorf("expected db.system.name, got: %s", result) + } +} diff --git a/internal/util/util.go b/internal/util/util.go index d651aea0b2da..73d4618c44c1 100644 --- a/internal/util/util.go +++ b/internal/util/util.go @@ -214,3 +214,42 @@ func GenAIMetricAttrsFromContext(ctx context.Context) *GenAIMetricAttrs { } return nil } + +// TelemetryAttributes holds client-provided telemetry metadata from _meta["dev.mcp-toolbox/telemetry"]. +type TelemetryAttributes struct { + ClientName string + ClientVersion string + ClientModel string + ClientUserID string + ClientAgentID string +} + +const telemetryAttrsKey contextKey = "telemetryAttrs" + +// WithTelemetryAttributes adds TelemetryAttributes to the context +func WithTelemetryAttributes(ctx context.Context, attrs *TelemetryAttributes) context.Context { + return context.WithValue(ctx, telemetryAttrsKey, attrs) +} + +// TelemetryAttributesFromContext retrieves TelemetryAttributes from context +func TelemetryAttributesFromContext(ctx context.Context) *TelemetryAttributes { + if attrs, ok := ctx.Value(telemetryAttrsKey).(*TelemetryAttributes); ok { + return attrs + } + return nil +} + +const sqlCommenterEnabledKey contextKey = "sqlCommenterEnabled" + +// WithSQLCommenterEnabled adds the sql-commenter-enabled flag to the context +func WithSQLCommenterEnabled(ctx context.Context, enabled bool) context.Context { + return context.WithValue(ctx, sqlCommenterEnabledKey, enabled) +} + +// SQLCommenterEnabledFromContext retrieves the sql-commenter-enabled flag from context +func SQLCommenterEnabledFromContext(ctx context.Context) bool { + if enabled, ok := ctx.Value(sqlCommenterEnabledKey).(bool); ok { + return enabled + } + return false +}