Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -61,5 +61,8 @@ jobs:
- name: Run Merkle Tree --until filter tests
run: go test -count=1 -v ./tests/integration -run 'TestMerkleTreeUntilFilter'

- name: Run native PG (no spock) tests
run: go test -count=1 -v ./tests/integration -run 'TestNativePG'

- name: Run timestamp comparison tests
run: go test -count=1 -v ./tests/integration -run 'TestCompareTimestampsExact|TestPostgreSQLMicrosecondPrecision|TestOldVsNewComparison'
110 changes: 110 additions & 0 deletions db/queries/queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -939,6 +939,116 @@ func GetSpockSlotLSNForNode(ctx context.Context, db DBQuerier, failedNode string
return lsn, nil
}

func GetNativeOriginLSNForNode(ctx context.Context, db DBQuerier, originNodeName string) (*string, error) {
sql, err := RenderSQL(SQLTemplates.GetNativeOriginLSNForNode, nil)
if err != nil {
return nil, err
}
var lsn *string
if err := db.QueryRow(ctx, sql, originNodeName).Scan(&lsn); err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return nil, nil
}
return nil, fmt.Errorf("failed to fetch native origin lsn: %w", err)
}
return lsn, nil
}

func GetNativeSlotLSNForNode(ctx context.Context, db DBQuerier, failedNode string) (*string, error) {
sql, err := RenderSQL(SQLTemplates.GetNativeSlotLSNForNode, nil)
if err != nil {
return nil, err
}
var lsn *string
if err := db.QueryRow(ctx, sql, failedNode).Scan(&lsn); err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return nil, nil
}
return nil, fmt.Errorf("failed to fetch native slot lsn: %w", err)
}
return lsn, nil
}

func GetReplicationOriginNames(ctx context.Context, db DBQuerier) (map[string]string, error) {
sql, err := RenderSQL(SQLTemplates.GetReplicationOriginNames, nil)
if err != nil {
return nil, err
}

rows, err := db.Query(ctx, sql)
if err != nil {
return nil, err
}
defer rows.Close()

names := make(map[string]string)
for rows.Next() {
var id, name string
if err := rows.Scan(&id, &name); err != nil {
return nil, err
}
names[id] = name
}

if err := rows.Err(); err != nil {
return nil, err
}

return names, nil
}

// GetNativeNodeOriginNames maps replication origin IDs to subscription names
// for native PG logical replication (no spock). This is the native PG
// equivalent of GetSpockNodeNames.
func GetNativeNodeOriginNames(ctx context.Context, db DBQuerier) (map[string]string, error) {
sql, err := RenderSQL(SQLTemplates.GetNativeNodeOriginNames, nil)
if err != nil {
return nil, err
}

rows, err := db.Query(ctx, sql)
if err != nil {
return nil, err
}
defer rows.Close()

names := make(map[string]string)
for rows.Next() {
var id, name string
if err := rows.Scan(&id, &name); err != nil {
return nil, err
}
names[id] = name
}

if err := rows.Err(); err != nil {
return nil, err
}

return names, nil
}

func GetNodeOriginNames(ctx context.Context, db DBQuerier) (map[string]string, error) {
var spockAvailable bool
err := db.QueryRow(ctx, "SELECT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'spock')").Scan(&spockAvailable)
if err != nil {
return nil, fmt.Errorf("detecting spock extension: %w", err)
}
if spockAvailable {
return GetSpockNodeNames(ctx, db)
}
return GetNativeNodeOriginNames(ctx, db)
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.

func CheckSpockInstalled(ctx context.Context, db DBQuerier) (bool, error) {
var exists bool
err := db.QueryRow(ctx, "SELECT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'spock')").Scan(&exists)
if err != nil {
return false, fmt.Errorf("detecting spock extension: %w", err)
}
return exists, nil
}

func GetSpockRepSetInfo(ctx context.Context, db DBQuerier) ([]types.SpockRepSetInfo, error) {
sql, err := RenderSQL(SQLTemplates.SpockRepSetInfo, nil)
if err != nil {
Expand Down
122 changes: 88 additions & 34 deletions db/queries/templates.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,18 @@

package queries

import "text/template"
import (
"text/template"

"github.com/jackc/pgx/v5"
"github.com/pgedge/ace/pkg/config"
)

// aceTemplateFuncs provides the {{aceSchema}} function to SQL templates.
// The function is evaluated at render time (after config is loaded), not at parse time.
var aceTemplateFuncs = template.FuncMap{
"aceSchema": func() string { return pgx.Identifier{config.Get().MTree.Schema}.Sanitize() },
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.

type Templates struct {
EstimateRowCount *template.Template
Expand Down Expand Up @@ -78,6 +89,7 @@ type Templates struct {
GetBlockCountSimple *template.Template
GetBlockSizeFromMetadata *template.Template
GetMaxNodeLevel *template.Template
CompareBlocksSQL *template.Template

DropXORFunction *template.Template
DropMetadataTable *template.Template
Expand Down Expand Up @@ -119,6 +131,10 @@ type Templates struct {
RemoveTableFromCDCMetadata *template.Template
GetSpockOriginLSNForNode *template.Template
GetSpockSlotLSNForNode *template.Template
GetNativeOriginLSNForNode *template.Template
GetNativeSlotLSNForNode *template.Template
GetReplicationOriginNames *template.Template
GetNativeNodeOriginNames *template.Template
EnsureHashVersionColumn *template.Template
GetHashVersion *template.Template
MarkAllLeavesDirty *template.Template
Expand All @@ -133,8 +149,8 @@ type Templates struct {

var SQLTemplates = Templates{
// A template isn't needed for this query; just keeping the struct uniform
CreateMetadataTable: template.Must(template.New("createMetadataTable").Parse(`
CREATE TABLE IF NOT EXISTS spock.ace_mtree_metadata (
CreateMetadataTable: template.Must(template.New("createMetadataTable").Funcs(aceTemplateFuncs).Parse(`
CREATE TABLE IF NOT EXISTS {{aceSchema}}.ace_mtree_metadata (
schema_name text,
table_name text,
total_rows bigint,
Expand All @@ -161,8 +177,8 @@ var SQLTemplates = Templates{
ALTER PUBLICATION {{.PublicationName}} DROP TABLE {{.TableName}}
`)),

RemoveTableFromCDCMetadata: template.Must(template.New("removeTableFromCDCMetadata").Parse(`
UPDATE spock.ace_cdc_metadata
RemoveTableFromCDCMetadata: template.Must(template.New("removeTableFromCDCMetadata").Funcs(aceTemplateFuncs).Parse(`
UPDATE {{aceSchema}}.ace_cdc_metadata
SET tables = array_remove(tables, $1)
WHERE publication_name = $2
`)),
Expand All @@ -179,9 +195,9 @@ var SQLTemplates = Templates{
)
`)),

UpdateCDCMetadata: template.Must(template.New("updateCdcMetadata").Parse(`
UpdateCDCMetadata: template.Must(template.New("updateCdcMetadata").Funcs(aceTemplateFuncs).Parse(`
INSERT INTO
spock.ace_cdc_metadata (
{{aceSchema}}.ace_cdc_metadata (
publication_name,
slot_name,
start_lsn,
Expand Down Expand Up @@ -220,17 +236,17 @@ var SQLTemplates = Templates{
CheckPIDExists: template.Must(template.New("checkPIDExists").Parse(`
SELECT pid FROM pg_stat_activity WHERE pid = $1
`)),
DropCDCMetadataTable: template.Must(template.New("dropCDCMetadataTable").Parse(`
DROP TABLE IF EXISTS spock.ace_cdc_metadata
DropCDCMetadataTable: template.Must(template.New("dropCDCMetadataTable").Funcs(aceTemplateFuncs).Parse(`
DROP TABLE IF EXISTS {{aceSchema}}.ace_cdc_metadata
`)),

GetCDCMetadata: template.Must(template.New("getCDCMetadata").Parse(`
GetCDCMetadata: template.Must(template.New("getCDCMetadata").Funcs(aceTemplateFuncs).Parse(`
SELECT
slot_name,
start_lsn,
tables
FROM
spock.ace_cdc_metadata
{{aceSchema}}.ace_cdc_metadata
WHERE
publication_name = $1
`)),
Expand Down Expand Up @@ -317,8 +333,8 @@ var SQLTemplates = Templates{
AND mt.node_position = b.node_position;
`)),

CreateCDCMetadataTable: template.Must(template.New("createCDCMetadataTable").Parse(`
CREATE TABLE IF NOT EXISTS spock.ace_cdc_metadata (
CreateCDCMetadataTable: template.Must(template.New("createCDCMetadataTable").Funcs(aceTemplateFuncs).Parse(`
CREATE TABLE IF NOT EXISTS {{aceSchema}}.ace_cdc_metadata (
publication_name text PRIMARY KEY,
slot_name text,
start_lsn text,
Expand Down Expand Up @@ -772,9 +788,9 @@ var SQLTemplates = Templates{
VALUES
(0, $1, {{.StartExpr}}, {{.EndExpr}});
`)),
CreateXORFunction: template.Must(template.New("createXORFunction").Parse(`
CreateXORFunction: template.Must(template.New("createXORFunction").Funcs(aceTemplateFuncs).Parse(`
CREATE
OR REPLACE FUNCTION spock.bytea_xor(a bytea, b bytea) RETURNS bytea AS $$
OR REPLACE FUNCTION {{aceSchema}}.bytea_xor(a bytea, b bytea) RETURNS bytea AS $$
DECLARE
result bytea;
len int;
Expand Down Expand Up @@ -805,7 +821,7 @@ var SQLTemplates = Templates{
CREATE OPERATOR # (
LEFTARG = bytea,
RIGHTARG = bytea,
PROCEDURE = spock.bytea_xor
PROCEDURE = {{aceSchema}}.bytea_xor
);
END IF;
END $$;
Expand Down Expand Up @@ -840,9 +856,9 @@ var SQLTemplates = Templates{
AND c.relname = $2
AND a.attname = $3
`)),
UpdateMetadata: template.Must(template.New("updateMetadata").Parse(`
UpdateMetadata: template.Must(template.New("updateMetadata").Funcs(aceTemplateFuncs).Parse(`
INSERT INTO
spock.ace_mtree_metadata (
{{aceSchema}}.ace_mtree_metadata (
schema_name,
table_name,
total_rows,
Expand Down Expand Up @@ -873,8 +889,8 @@ var SQLTemplates = Templates{
hash_version = EXCLUDED.hash_version,
last_updated = EXCLUDED.last_updated
`)),
DeleteMetadata: template.Must(template.New("deleteMetadata").Parse(`
DELETE FROM spock.ace_mtree_metadata WHERE schema_name = $1 AND table_name = $2
DeleteMetadata: template.Must(template.New("deleteMetadata").Funcs(aceTemplateFuncs).Parse(`
DELETE FROM {{aceSchema}}.ace_mtree_metadata WHERE schema_name = $1 AND table_name = $2
`)),
InsertBlockRanges: template.Must(template.New("insertBlockRanges").Parse(`
INSERT INTO
Expand Down Expand Up @@ -1055,11 +1071,11 @@ var SQLTemplates = Templates{
ORDER BY
node_position
`)),
GetRowCountEstimate: template.Must(template.New("getRowCountEstimate").Parse(`
GetRowCountEstimate: template.Must(template.New("getRowCountEstimate").Funcs(aceTemplateFuncs).Parse(`
SELECT
total_rows
FROM
spock.ace_mtree_metadata
{{aceSchema}}.ace_mtree_metadata
WHERE
schema_name = $1
AND table_name = $2
Expand Down Expand Up @@ -1294,11 +1310,11 @@ var SQLTemplates = Templates{
mt.range_start,
mt.range_end
`)),
GetBlockSizeFromMetadata: template.Must(template.New("getBlockSizeFromMetadata").Parse(`
GetBlockSizeFromMetadata: template.Must(template.New("getBlockSizeFromMetadata").Funcs(aceTemplateFuncs).Parse(`
SELECT
block_size
FROM
spock.ace_mtree_metadata
{{aceSchema}}.ace_mtree_metadata
WHERE
schema_name = $1
AND table_name = $2
Expand All @@ -1309,11 +1325,19 @@ var SQLTemplates = Templates{
FROM
{{.MtreeTable}}
`)),
DropXORFunction: template.Must(template.New("dropXORFunction").Parse(`
DROP FUNCTION IF EXISTS spock.bytea_xor(bytea, bytea) CASCADE
CompareBlocksSQL: template.Must(template.New("compareBlocksSQL").Parse(`
SELECT
*
FROM
{{.TableName}}
WHERE
{{.WhereClause}}
`)),
Comment thread
coderabbitai[bot] marked this conversation as resolved.
DropXORFunction: template.Must(template.New("dropXORFunction").Funcs(aceTemplateFuncs).Parse(`
DROP FUNCTION IF EXISTS {{aceSchema}}.bytea_xor(bytea, bytea) CASCADE
`)),
DropMetadataTable: template.Must(template.New("dropMetadataTable").Parse(`
DROP TABLE IF EXISTS spock.ace_mtree_metadata CASCADE
DropMetadataTable: template.Must(template.New("dropMetadataTable").Funcs(aceTemplateFuncs).Parse(`
DROP TABLE IF EXISTS {{aceSchema}}.ace_mtree_metadata CASCADE
`)),
DropMtreeTable: template.Must(template.New("dropMtreeTable").Parse(`
DROP TABLE IF EXISTS {{.MtreeTable}} CASCADE
Expand Down Expand Up @@ -1506,13 +1530,13 @@ var SQLTemplates = Templates{
ORDER BY rs.confirmed_flush_lsn DESC
LIMIT 1
`)),
EnsureHashVersionColumn: template.Must(template.New("ensureHashVersionColumn").Parse(`
ALTER TABLE spock.ace_mtree_metadata
EnsureHashVersionColumn: template.Must(template.New("ensureHashVersionColumn").Funcs(aceTemplateFuncs).Parse(`
ALTER TABLE {{aceSchema}}.ace_mtree_metadata
ADD COLUMN IF NOT EXISTS hash_version int NOT NULL DEFAULT 1
`)),
GetHashVersion: template.Must(template.New("getHashVersion").Parse(`
GetHashVersion: template.Must(template.New("getHashVersion").Funcs(aceTemplateFuncs).Parse(`
SELECT COALESCE(
(SELECT hash_version FROM spock.ace_mtree_metadata
(SELECT hash_version FROM {{aceSchema}}.ace_mtree_metadata
WHERE schema_name = $1 AND table_name = $2),
1
)
Expand All @@ -1522,11 +1546,41 @@ var SQLTemplates = Templates{
SET dirty = true
WHERE node_level = 0
`)),
UpdateHashVersion: template.Must(template.New("updateHashVersion").Parse(`
UPDATE spock.ace_mtree_metadata
UpdateHashVersion: template.Must(template.New("updateHashVersion").Funcs(aceTemplateFuncs).Parse(`
UPDATE {{aceSchema}}.ace_mtree_metadata
SET hash_version = $1, last_updated = current_timestamp
WHERE schema_name = $2 AND table_name = $3
`)),
GetNativeOriginLSNForNode: template.Must(template.New("getNativeOriginLSNForNode").Parse(`
SELECT ros.remote_lsn::text
FROM pg_catalog.pg_replication_origin_status ros
JOIN pg_catalog.pg_replication_origin ro ON ro.roident = ros.local_id
JOIN pg_catalog.pg_subscription s ON ro.roname LIKE 'pg_%' || s.oid::text
WHERE s.subname ~ ('\m' || $1 || '\M')
AND ros.remote_lsn IS NOT NULL
Comment on lines +1559 to +1560
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

$1 (a user-supplied node name) is interpolated raw into a regex fragment. A node name containing |, ., or * silently
matches wrong subscriptions. The equivalent Spock templates use = $1 exact match. Fix: use = $1 here too.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

technically valid but low-risk in practice: node
names are simple identifiers (n1, postgres-n1). And the code already handles this gracefully at
table_repair.go:2940-2942

LIMIT 1
`)),
GetNativeSlotLSNForNode: template.Must(template.New("getNativeSlotLSNForNode").Parse(`
SELECT rs.confirmed_flush_lsn::text
FROM pg_catalog.pg_replication_slots rs
JOIN pg_catalog.pg_subscription s ON rs.slot_name = s.subslotname
WHERE s.subname ~ ('\m' || $1 || '\M')
AND rs.confirmed_flush_lsn IS NOT NULL
ORDER BY rs.confirmed_flush_lsn DESC
LIMIT 1
`)),
Comment thread
coderabbitai[bot] marked this conversation as resolved.
GetReplicationOriginNames: template.Must(template.New("getReplicationOriginNames").Parse(`
SELECT roident::text, roname FROM pg_replication_origin;
`)),
// GetNativeNodeOriginNames maps pg_replication_origin entries to their
// corresponding pg_subscription names. This provides the native PG
// equivalent of GetSpockNodeNames — mapping origin IDs (used by
// pg_xact_commit_timestamp_origin) to human-readable node identifiers.
GetNativeNodeOriginNames: template.Must(template.New("getNativeNodeOriginNames").Parse(`
SELECT ro.roident::text, s.subname
FROM pg_catalog.pg_replication_origin ro
JOIN pg_catalog.pg_subscription s ON ro.roname = 'pg_' || s.oid::text
`)),
GetReplicationOriginByName: template.Must(template.New("getReplicationOriginByName").Parse(`
SELECT roident FROM pg_replication_origin WHERE roname = $1
`)),
Expand Down
Loading
Loading