Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 47 additions & 3 deletions internal/schema/tables.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ type Table struct {
OID int
Schema string
Name string
Type string
Parent sql.NullString
PartitionDef sql.NullString
Comment sql.NullString
Columns []*Column
Dependencies []string
Expand All @@ -37,6 +40,9 @@ func (t Table) DependsOn() []string {
out = append(out, trig.ProcName)
}
}
if t.Parent.Valid {
out = append(out, t.Parent.String)
}
return out
}

Expand Down Expand Up @@ -79,12 +85,28 @@ func (t Table) String() string {
followUps += f.String() + "\n\n"
}
}
// This is the definition for a "regular" table.
tableDef := fmt.Sprintf(query(`--sql
CREATE TABLE %s (
%s
);
)
`), pgtools.Identifier(t.Schema, t.Name), strings.Join(colDefs, ",\n "))
constraintsByName := asMap[string](t.Constraints)
if !t.PartitionDef.Valid {
tableDef += ";"
} else {
if t.Parent.Valid {
// This is a "child" partitioned table: CREATE TABLE ... PARTITION OF ... FOR VALUE FROM ...;
tableDef = fmt.Sprintf(
"CREATE TABLE %s PARTITION OF %s %s;",
pgtools.Identifier(t.Schema, t.Name),
t.Parent.String,
t.PartitionDef.String,
)
} else {
// This is a "regular" partitioned table definition: CREATE TABLE ... PARTITION BY ...;
tableDef += " PARTITION BY " + t.PartitionDef.String + ";"
}
}

if t.Comment.Valid {
tableDef += "\n\n" + fmt.Sprintf(
Expand All @@ -104,6 +126,7 @@ CREATE TABLE %s (
}
}

constraintsByName := asMap[string](t.Constraints)
for _, index := range t.Indexes {
if pkIndexes[index.SortKey()] {
continue
Expand Down Expand Up @@ -177,6 +200,9 @@ func LoadTables(config Config, db *sql.DB) ([]*Table, error) {
&table.OID,
&table.Schema,
&table.Name,
&table.Type,
&table.Parent,
&table.PartitionDef,
&table.Comment,
&column.Number,
&column.Name,
Expand Down Expand Up @@ -210,7 +236,22 @@ with r as (
c.oid as oid,
c.relname as name,
n.nspname as schema,
c.relkind as relationtype
c.relkind as relationtype,
(SELECT
nmsp_parent.nspname || '.' || parent.relname as parent
FROM pg_inherits
JOIN pg_class parent ON pg_inherits.inhparent = parent.oid
JOIN pg_class child ON pg_inherits.inhrelid = child.oid
JOIN pg_namespace nmsp_parent ON nmsp_parent.oid = parent.relnamespace
JOIN pg_namespace nmsp_child ON nmsp_child.oid = child.relnamespace
where child.oid = c.oid)
as parent_table,
case when c.relpartbound is not null then
pg_get_expr(c.relpartbound, c.oid, true)
when c.relhassubclass is not null then
pg_catalog.pg_get_partkeydef(c.oid)
end
as partition_def
from
pg_catalog.pg_class c
inner join pg_catalog.pg_namespace n
Expand All @@ -222,6 +263,9 @@ select
r.oid as "table_oid",
r.schema as "table_schema",
r.name as "table_name",
r.relationtype as "table_type",
r.parent_table as "table_parent",
r.partition_def as "table_partition_def",
obj_description(r.oid) as "table_comment",
a.attnum as "column_number",
a.attname as "name",
Expand Down
98 changes: 98 additions & 0 deletions internal/schema/tables_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package schema_test

import (
"context"
"database/sql"
"testing"

"github.com/peterldowns/testy/assert"

"github.com/peterldowns/pgmigrate/internal/schema"
"github.com/peterldowns/pgmigrate/internal/withdb"
)

func TestDumpingTablesWithPartitions(t *testing.T) {
t.Parallel()

config := schema.Config{Schema: "public"}
ctx := context.Background()
original := query(`--sql
create table events (
created timestamp with time zone not null default now(),
event text
) partition by range (created);

create table events_p20250101 partition of events for values from ('2025-01-01 00:00:00Z') to ('2025-02-01 00:00:00Z');

`)
expected := query(`--sql
CREATE TABLE public.events (
created timestamp with time zone NOT NULL DEFAULT now(),
event text
) PARTITION BY RANGE (created);

CREATE TABLE public.events_p20250101 PARTITION OF public.events FOR VALUES FROM ('2025-01-01 00:00:00+00') TO ('2025-02-01 00:00:00+00');
`)

var result *schema.Schema
// Check that the "original" parses correctly and results in the "expected" SQL.
err := withdb.WithDB(ctx, "pgx", func(db *sql.DB) error {
var err error
if _, err = db.ExecContext(ctx, original); err != nil {
return err
}
result, err = schema.Parse(config, db)
return err
})
assert.Nil(t, err)
assert.NotEqual(t, nil, result)
assert.Equal(t, expected, result.String())
// Check that the "expected" result perfectly roundtrips and results in itself.
err = withdb.WithDB(ctx, "pgx", func(db *sql.DB) error {
var err error
if _, err = db.ExecContext(ctx, expected); err != nil {
return err
}
result, err = schema.Parse(config, db)
return err
})
assert.Nil(t, err)
assert.NotEqual(t, nil, result)
assert.Equal(t, expected, result.String())
}

func TestPartitionedTablesDependOnEachOther(t *testing.T) {
t.Parallel()

config := schema.Config{Schema: "public"}
ctx := context.Background()
original := query(`--sql
create table events (
created timestamp with time zone not null default now(),
event text
) partition by range (created);

create table events_p20250101 partition of events for values from ('2025-01-01 00:00:00Z') to ('2025-02-01 00:00:00Z');

`)
var result *schema.Schema
// Check that the "original" parses correctly and results in the "expected" SQL.
err := withdb.WithDB(ctx, "pgx", func(db *sql.DB) error {
var err error
if _, err = db.ExecContext(ctx, original); err != nil {
return err
}
result, err = schema.Parse(config, db)
return err
})
assert.Nil(t, err)
assert.NotEqual(t, nil, result)

tables := asMap(result.Tables)
parent, ok := tables["public.events"]
assert.True(t, ok)
child, ok := tables["public.events_p20250101"]
assert.True(t, ok)
assert.Equal(t, []string{"public.events"}, child.DependsOn())
assert.Equal(t, nil, parent.DependsOn())
}
Loading