Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 21 additions & 15 deletions internal/sources/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ package postgres
import (
"context"
"fmt"
"net"
"net/url"
"strings"

"github.com/goccy/go-yaml"
"github.com/googleapis/mcp-toolbox/internal/sources"
Expand Down Expand Up @@ -144,15 +144,7 @@ func initPostgresConnectionPool(ctx context.Context, tracer trace.Tracer, name,
queryParams["application_name"] = userAgent
}

// urlExample := "postgres:dd//username:password@localhost:5432/database_name"
url := &url.URL{
Scheme: "postgres",
User: url.UserPassword(user, pass),
Host: fmt.Sprintf("%s:%s", host, port),
Path: dbname,
RawQuery: ConvertParamMapToRawQuery(queryParams),
}
config, err := pgxpool.ParseConfig(url.String())
config, err := pgxpool.ParseConfig(BuildPostgresURL(host, port, user, pass, dbname, queryParams))
if err != nil {
return nil, fmt.Errorf("unable to parse connection uri: %w", err)
}
Expand All @@ -171,12 +163,26 @@ func initPostgresConnectionPool(ctx context.Context, tracer trace.Tracer, name,
return pool, nil
}

func ConvertParamMapToRawQuery(queryParams map[string]string) string {
queryArray := []string{}
for k, v := range queryParams {
queryArray = append(queryArray, fmt.Sprintf("%s=%s", k, v))
// BuildPostgresURL assembles a postgres connection URL from its components.
// It uses net.JoinHostPort so IPv6 host literals are wrapped in brackets as
// required by RFC 3986 (e.g. "[::1]:5432"); IPv4 addresses and hostnames are
// left unchanged. Query parameters are encoded with url.Values so special
// characters are escaped correctly and the output is deterministic.
func BuildPostgresURL(host, port, user, pass, dbname string, queryParams map[string]string) string {
u := &url.URL{
Scheme: "postgres",
User: url.UserPassword(user, pass),
Host: net.JoinHostPort(host, port),
Path: dbname,
}
if len(queryParams) > 0 {
q := url.Values{}
for k, v := range queryParams {
q.Set(k, v)
}
u.RawQuery = q.Encode()
}
return strings.Join(queryArray, "&")
return u.String()
}
Comment thread
kaldown marked this conversation as resolved.

func ParseQueryExecMode(queryExecMode string) (pgx.QueryExecMode, error) {
Expand Down
76 changes: 49 additions & 27 deletions internal/sources/postgres/postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ package postgres_test

import (
"context"
"sort"
"strings"
"testing"

"github.com/google/go-cmp/cmp"
Expand Down Expand Up @@ -194,43 +192,67 @@ func TestFailParseFromYaml(t *testing.T) {
}
}

func TestConvertParamMapToRawQuery(t *testing.T) {
func TestBuildPostgresURL(t *testing.T) {
tcs := []struct {
desc string
in map[string]string
want string
desc string
host string
port string
queryParams map[string]string
want string
}{
{
desc: "nil param",
in: nil,
want: "",
desc: "hostname",
host: "db.example.com",
port: "5432",
want: "postgres://u:p@db.example.com:5432/mydb",
},
{
desc: "single query param",
in: map[string]string{
"foo": "bar",
},
want: "foo=bar",
desc: "ipv4",
host: "127.0.0.1",
port: "5432",
want: "postgres://u:p@127.0.0.1:5432/mydb",
},
{
desc: "more than one query param",
in: map[string]string{
"foo": "bar",
"hello": "world",
},
want: "foo=bar&hello=world",
desc: "ipv6 loopback",
host: "::1",
port: "5432",
want: "postgres://u:p@[::1]:5432/mydb",
},
{
desc: "ipv6 documentation",
host: "2001:db8::1",
port: "5432",
want: "postgres://u:p@[2001:db8::1]:5432/mydb",
},
{
desc: "ipv6 link-local with zone id",
host: "fe80::1%eth0",
port: "5432",
want: "postgres://u:p@[fe80::1%25eth0]:5432/mydb",
},
{
desc: "query params sorted and encoded",
host: "db.example.com",
port: "5432",
queryParams: map[string]string{"sslmode": "verify-full", "application_name": "my app"},
want: "postgres://u:p@db.example.com:5432/mydb?application_name=my+app&sslmode=verify-full",
},
{
desc: "query param value with special characters",
host: "db.example.com",
port: "5432",
queryParams: map[string]string{"options": "-c statement_timeout=5s&key=val"},
want: "postgres://u:p@db.example.com:5432/mydb?options=-c+statement_timeout%3D5s%26key%3Dval",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := postgres.ConvertParamMapToRawQuery(tc.in)
if strings.Contains(got, "&") {
splitGot := strings.Split(got, "&")
sort.Strings(splitGot)
got = strings.Join(splitGot, "&")
}
got := postgres.BuildPostgresURL(tc.host, tc.port, "u", "p", "mydb", tc.queryParams)
if got != tc.want {
t.Fatalf("incorrect conversion: got %s want %s", got, tc.want)
t.Fatalf("BuildPostgresURL(%q, %q, ...) = %q, want %q", tc.host, tc.port, got, tc.want)
}
if _, err := pgx.ParseConfig(got); err != nil {
t.Fatalf("pgx.ParseConfig(%q) returned error: %v", got, err)
}
})
}
Expand Down