diff --git a/internal/schema/tables.go b/internal/schema/tables.go index 38e60eb..0ed0bd2 100644 --- a/internal/schema/tables.go +++ b/internal/schema/tables.go @@ -12,6 +12,9 @@ type Table struct { OID int Schema string Name string + Type string + Parent sql.NullString + PartitionDef sql.NullString Comment sql.NullString Columns []*Column Dependencies []string @@ -37,6 +40,9 @@ func (t Table) DependsOn() []string { out = append(out, trig.ProcName) } } + if t.Parent.Valid { + out = append(out, t.Parent.String) + } return out } @@ -79,12 +85,28 @@ func (t Table) String() string { followUps += f.String() + "\n\n" } } + // This is the definition for a "regular" table. tableDef := fmt.Sprintf(query(`--sql CREATE TABLE %s ( %s -); +) `), pgtools.Identifier(t.Schema, t.Name), strings.Join(colDefs, ",\n ")) - constraintsByName := asMap[string](t.Constraints) + if !t.PartitionDef.Valid { + tableDef += ";" + } else { + if t.Parent.Valid { + // This is a "child" partitioned table: CREATE TABLE ... PARTITION OF ... FOR VALUE FROM ...; + tableDef = fmt.Sprintf( + "CREATE TABLE %s PARTITION OF %s %s;", + pgtools.Identifier(t.Schema, t.Name), + t.Parent.String, + t.PartitionDef.String, + ) + } else { + // This is a "regular" partitioned table definition: CREATE TABLE ... PARTITION BY ...; + tableDef += " PARTITION BY " + t.PartitionDef.String + ";" + } + } if t.Comment.Valid { tableDef += "\n\n" + fmt.Sprintf( @@ -104,6 +126,7 @@ CREATE TABLE %s ( } } + constraintsByName := asMap[string](t.Constraints) for _, index := range t.Indexes { if pkIndexes[index.SortKey()] { continue @@ -177,6 +200,9 @@ func LoadTables(config Config, db *sql.DB) ([]*Table, error) { &table.OID, &table.Schema, &table.Name, + &table.Type, + &table.Parent, + &table.PartitionDef, &table.Comment, &column.Number, &column.Name, @@ -210,7 +236,22 @@ with r as ( c.oid as oid, c.relname as name, n.nspname as schema, - c.relkind as relationtype + c.relkind as relationtype, + (SELECT + nmsp_parent.nspname || '.' || parent.relname as parent + FROM pg_inherits + JOIN pg_class parent ON pg_inherits.inhparent = parent.oid + JOIN pg_class child ON pg_inherits.inhrelid = child.oid + JOIN pg_namespace nmsp_parent ON nmsp_parent.oid = parent.relnamespace + JOIN pg_namespace nmsp_child ON nmsp_child.oid = child.relnamespace + where child.oid = c.oid) + as parent_table, + case when c.relpartbound is not null then + pg_get_expr(c.relpartbound, c.oid, true) + when c.relhassubclass is not null then + pg_catalog.pg_get_partkeydef(c.oid) + end + as partition_def from pg_catalog.pg_class c inner join pg_catalog.pg_namespace n @@ -222,6 +263,9 @@ select r.oid as "table_oid", r.schema as "table_schema", r.name as "table_name", + r.relationtype as "table_type", + r.parent_table as "table_parent", + r.partition_def as "table_partition_def", obj_description(r.oid) as "table_comment", a.attnum as "column_number", a.attname as "name", diff --git a/internal/schema/tables_test.go b/internal/schema/tables_test.go new file mode 100644 index 0000000..6197bc1 --- /dev/null +++ b/internal/schema/tables_test.go @@ -0,0 +1,98 @@ +package schema_test + +import ( + "context" + "database/sql" + "testing" + + "github.com/peterldowns/testy/assert" + + "github.com/peterldowns/pgmigrate/internal/schema" + "github.com/peterldowns/pgmigrate/internal/withdb" +) + +func TestDumpingTablesWithPartitions(t *testing.T) { + t.Parallel() + + config := schema.Config{Schema: "public"} + ctx := context.Background() + original := query(`--sql +create table events ( + created timestamp with time zone not null default now(), + event text +) partition by range (created); + +create table events_p20250101 partition of events for values from ('2025-01-01 00:00:00Z') to ('2025-02-01 00:00:00Z'); + + `) + expected := query(`--sql +CREATE TABLE public.events ( + created timestamp with time zone NOT NULL DEFAULT now(), + event text +) PARTITION BY RANGE (created); + +CREATE TABLE public.events_p20250101 PARTITION OF public.events FOR VALUES FROM ('2025-01-01 00:00:00+00') TO ('2025-02-01 00:00:00+00'); + `) + + var result *schema.Schema + // Check that the "original" parses correctly and results in the "expected" SQL. + err := withdb.WithDB(ctx, "pgx", func(db *sql.DB) error { + var err error + if _, err = db.ExecContext(ctx, original); err != nil { + return err + } + result, err = schema.Parse(config, db) + return err + }) + assert.Nil(t, err) + assert.NotEqual(t, nil, result) + assert.Equal(t, expected, result.String()) + // Check that the "expected" result perfectly roundtrips and results in itself. + err = withdb.WithDB(ctx, "pgx", func(db *sql.DB) error { + var err error + if _, err = db.ExecContext(ctx, expected); err != nil { + return err + } + result, err = schema.Parse(config, db) + return err + }) + assert.Nil(t, err) + assert.NotEqual(t, nil, result) + assert.Equal(t, expected, result.String()) +} + +func TestPartitionedTablesDependOnEachOther(t *testing.T) { + t.Parallel() + + config := schema.Config{Schema: "public"} + ctx := context.Background() + original := query(`--sql +create table events ( + created timestamp with time zone not null default now(), + event text +) partition by range (created); + +create table events_p20250101 partition of events for values from ('2025-01-01 00:00:00Z') to ('2025-02-01 00:00:00Z'); + + `) + var result *schema.Schema + // Check that the "original" parses correctly and results in the "expected" SQL. + err := withdb.WithDB(ctx, "pgx", func(db *sql.DB) error { + var err error + if _, err = db.ExecContext(ctx, original); err != nil { + return err + } + result, err = schema.Parse(config, db) + return err + }) + assert.Nil(t, err) + assert.NotEqual(t, nil, result) + + tables := asMap(result.Tables) + parent, ok := tables["public.events"] + assert.True(t, ok) + child, ok := tables["public.events_p20250101"] + assert.True(t, ok) + assert.Equal(t, []string{"public.events"}, child.DependsOn()) + assert.Equal(t, nil, parent.DependsOn()) +}