diff --git a/internal/sources/postgres/postgres.go b/internal/sources/postgres/postgres.go index 86c7235ac53f..a27bb4c8f90b 100644 --- a/internal/sources/postgres/postgres.go +++ b/internal/sources/postgres/postgres.go @@ -17,8 +17,8 @@ package postgres import ( "context" "fmt" + "net" "net/url" - "strings" "github.com/goccy/go-yaml" "github.com/googleapis/mcp-toolbox/internal/sources" @@ -144,15 +144,7 @@ func initPostgresConnectionPool(ctx context.Context, tracer trace.Tracer, name, queryParams["application_name"] = userAgent } - // urlExample := "postgres:dd//username:password@localhost:5432/database_name" - url := &url.URL{ - Scheme: "postgres", - User: url.UserPassword(user, pass), - Host: fmt.Sprintf("%s:%s", host, port), - Path: dbname, - RawQuery: ConvertParamMapToRawQuery(queryParams), - } - config, err := pgxpool.ParseConfig(url.String()) + config, err := pgxpool.ParseConfig(BuildPostgresURL(host, port, user, pass, dbname, queryParams)) if err != nil { return nil, fmt.Errorf("unable to parse connection uri: %w", err) } @@ -171,12 +163,26 @@ func initPostgresConnectionPool(ctx context.Context, tracer trace.Tracer, name, return pool, nil } -func ConvertParamMapToRawQuery(queryParams map[string]string) string { - queryArray := []string{} - for k, v := range queryParams { - queryArray = append(queryArray, fmt.Sprintf("%s=%s", k, v)) +// BuildPostgresURL assembles a postgres connection URL from its components. +// It uses net.JoinHostPort so IPv6 host literals are wrapped in brackets as +// required by RFC 3986 (e.g. "[::1]:5432"); IPv4 addresses and hostnames are +// left unchanged. Query parameters are encoded with url.Values so special +// characters are escaped correctly and the output is deterministic. +func BuildPostgresURL(host, port, user, pass, dbname string, queryParams map[string]string) string { + u := &url.URL{ + Scheme: "postgres", + User: url.UserPassword(user, pass), + Host: net.JoinHostPort(host, port), + Path: dbname, + } + if len(queryParams) > 0 { + q := url.Values{} + for k, v := range queryParams { + q.Set(k, v) + } + u.RawQuery = q.Encode() } - return strings.Join(queryArray, "&") + return u.String() } func ParseQueryExecMode(queryExecMode string) (pgx.QueryExecMode, error) { diff --git a/internal/sources/postgres/postgres_test.go b/internal/sources/postgres/postgres_test.go index 210edfecc10e..2d6bd96fc5cc 100644 --- a/internal/sources/postgres/postgres_test.go +++ b/internal/sources/postgres/postgres_test.go @@ -16,8 +16,6 @@ package postgres_test import ( "context" - "sort" - "strings" "testing" "github.com/google/go-cmp/cmp" @@ -194,43 +192,67 @@ func TestFailParseFromYaml(t *testing.T) { } } -func TestConvertParamMapToRawQuery(t *testing.T) { +func TestBuildPostgresURL(t *testing.T) { tcs := []struct { - desc string - in map[string]string - want string + desc string + host string + port string + queryParams map[string]string + want string }{ { - desc: "nil param", - in: nil, - want: "", + desc: "hostname", + host: "db.example.com", + port: "5432", + want: "postgres://u:p@db.example.com:5432/mydb", }, { - desc: "single query param", - in: map[string]string{ - "foo": "bar", - }, - want: "foo=bar", + desc: "ipv4", + host: "127.0.0.1", + port: "5432", + want: "postgres://u:p@127.0.0.1:5432/mydb", }, { - desc: "more than one query param", - in: map[string]string{ - "foo": "bar", - "hello": "world", - }, - want: "foo=bar&hello=world", + desc: "ipv6 loopback", + host: "::1", + port: "5432", + want: "postgres://u:p@[::1]:5432/mydb", + }, + { + desc: "ipv6 documentation", + host: "2001:db8::1", + port: "5432", + want: "postgres://u:p@[2001:db8::1]:5432/mydb", + }, + { + desc: "ipv6 link-local with zone id", + host: "fe80::1%eth0", + port: "5432", + want: "postgres://u:p@[fe80::1%25eth0]:5432/mydb", + }, + { + desc: "query params sorted and encoded", + host: "db.example.com", + port: "5432", + queryParams: map[string]string{"sslmode": "verify-full", "application_name": "my app"}, + want: "postgres://u:p@db.example.com:5432/mydb?application_name=my+app&sslmode=verify-full", + }, + { + desc: "query param value with special characters", + host: "db.example.com", + port: "5432", + queryParams: map[string]string{"options": "-c statement_timeout=5s&key=val"}, + want: "postgres://u:p@db.example.com:5432/mydb?options=-c+statement_timeout%3D5s%26key%3Dval", }, } for _, tc := range tcs { t.Run(tc.desc, func(t *testing.T) { - got := postgres.ConvertParamMapToRawQuery(tc.in) - if strings.Contains(got, "&") { - splitGot := strings.Split(got, "&") - sort.Strings(splitGot) - got = strings.Join(splitGot, "&") - } + got := postgres.BuildPostgresURL(tc.host, tc.port, "u", "p", "mydb", tc.queryParams) if got != tc.want { - t.Fatalf("incorrect conversion: got %s want %s", got, tc.want) + t.Fatalf("BuildPostgresURL(%q, %q, ...) = %q, want %q", tc.host, tc.port, got, tc.want) + } + if _, err := pgx.ParseConfig(got); err != nil { + t.Fatalf("pgx.ParseConfig(%q) returned error: %v", got, err) } }) }