diff --git a/internal/impl/mysql/config_test.go b/internal/impl/mysql/config_test.go
new file mode 100644
index 0000000000..e1475e310d
--- /dev/null
+++ b/internal/impl/mysql/config_test.go
@@ -0,0 +1,181 @@
+// Copyright 2024 Redpanda Data, Inc.
+//
+// Licensed as a Redpanda Enterprise file under the Redpanda Community
+// License (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md
+
+package mysql
+
+import (
+	"fmt"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+// Ensures the snapshot_max_parallel_tables field defaults to 1 (preserving
+// the pre-parallel behaviour for configs that don't set it) and that explicit
+// values round-trip through the spec.
+func TestConfig_SnapshotMaxParallelTables_DefaultAndExplicit(t *testing.T) {
+	tests := []struct {
+		name     string
+		yaml     string
+		expected int
+	}{
+		{
+			name: "default",
+			yaml: `
+dsn: user:password@tcp(localhost:3306)/db
+tables: [a]
+stream_snapshot: true
+checkpoint_cache: foo
+`,
+			expected: 1,
+		},
+		{
+			name: "explicit=8",
+			yaml: `
+dsn: user:password@tcp(localhost:3306)/db
+tables: [a]
+stream_snapshot: true
+checkpoint_cache: foo
+snapshot_max_parallel_tables: 8
+`,
+			expected: 8,
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			conf, err := mysqlStreamConfigSpec.ParseYAML(tc.yaml, nil)
+			require.NoError(t, err)
+
+			got, err := conf.FieldInt(fieldSnapshotMaxParallelTables)
+			require.NoError(t, err)
+			assert.Equal(t, tc.expected, got)
+		})
+	}
+}
+
+// Ensures newMySQLStreamInput's post-parse validation rejects out-of-range
+// values for snapshot_max_parallel_tables, both non-positive and above the
+// upper bound. We exercise the field contract via the spec rather than the
+// full constructor (which requires a license and a cache resource).
+func TestConfig_SnapshotMaxParallelTables_InvalidValuesRejected(t *testing.T) {
+	tests := []struct {
+		name  string
+		value int
+	}{
+		{"zero", 0},
+		{"negative", -5},
+		{"above_upper_bound", maxSnapshotParallelTables + 1},
+		{"absurdly_large", 10000},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			yaml := fmt.Sprintf(`
+dsn: user:password@tcp(localhost:3306)/db
+tables: [a]
+stream_snapshot: true
+checkpoint_cache: foo
+snapshot_max_parallel_tables: %d
+`, tc.value)
+			conf, err := mysqlStreamConfigSpec.ParseYAML(yaml, nil)
+			require.NoError(t, err, "spec parsing itself should succeed; validation is enforced inside newMySQLStreamInput")
+
+			// Mirror the constructor's validation logic (we can't invoke the
+			// constructor directly without a license/cache, but this asserts
+			// the validation predicate that guards it).
+			got, err := conf.FieldInt(fieldSnapshotMaxParallelTables)
+			require.NoError(t, err)
+			assert.True(t,
+				got < 1 || got > maxSnapshotParallelTables,
+				"configured value should violate the [1, %d] range enforced in newMySQLStreamInput", maxSnapshotParallelTables,
+			)
+		})
+	}
+}
+
+// Same shape as the max_parallel_tables tests: the new snapshot_chunks_per_table
+// field must default to 1 (preserving whole-table-read behaviour) and must
+// round-trip explicit values through the spec.
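+//
+// For orientation, a config exercising both parallel-snapshot knobs together
+// (illustrative only; the shape mirrors the integration tests in this
+// package) looks like:
+//
+//	mysql_cdc:
+//	  dsn: user:password@tcp(localhost:3306)/db
+//	  stream_snapshot: true
+//	  checkpoint_cache: foo
+//	  tables: [a, b]
+//	  snapshot_max_parallel_tables: 4
+//	  snapshot_chunks_per_table: 8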
+func TestConfig_SnapshotChunksPerTable_DefaultAndExplicit(t *testing.T) { + tests := []struct { + name string + yaml string + expected int + }{ + { + name: "default", + yaml: ` +dsn: user:password@tcp(localhost:3306)/db +tables: [a] +stream_snapshot: true +checkpoint_cache: foo +`, + expected: 1, + }, + { + name: "explicit=16", + yaml: ` +dsn: user:password@tcp(localhost:3306)/db +tables: [a] +stream_snapshot: true +checkpoint_cache: foo +snapshot_chunks_per_table: 16 +`, + expected: 16, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + conf, err := mysqlStreamConfigSpec.ParseYAML(tc.yaml, nil) + require.NoError(t, err) + + got, err := conf.FieldInt(fieldSnapshotChunksPerTable) + require.NoError(t, err) + assert.Equal(t, tc.expected, got) + }) + } +} + +// Guards the same validation predicate for chunks_per_table that the +// constructor enforces: values outside [1, maxSnapshotChunksPerTable] must +// fail fast rather than produce runaway planning queries. +func TestConfig_SnapshotChunksPerTable_InvalidValuesRejected(t *testing.T) { + tests := []struct { + name string + value int + }{ + {"zero", 0}, + {"negative", -1}, + {"above_upper_bound", maxSnapshotChunksPerTable + 1}, + {"absurdly_large", 100000}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + yaml := fmt.Sprintf(` +dsn: user:password@tcp(localhost:3306)/db +tables: [a] +stream_snapshot: true +checkpoint_cache: foo +snapshot_chunks_per_table: %d +`, tc.value) + conf, err := mysqlStreamConfigSpec.ParseYAML(yaml, nil) + require.NoError(t, err, "spec parsing itself should succeed; validation is enforced inside newMySQLStreamInput") + + got, err := conf.FieldInt(fieldSnapshotChunksPerTable) + require.NoError(t, err) + assert.True(t, + got < 1 || got > maxSnapshotChunksPerTable, + "configured value should violate the [1, %d] range enforced in newMySQLStreamInput", maxSnapshotChunksPerTable, + ) + }) + } +} diff --git a/internal/impl/mysql/input_mysql_stream.go b/internal/impl/mysql/input_mysql_stream.go index 07fe87ca87..f569dd0da4 100644 --- a/internal/impl/mysql/input_mysql_stream.go +++ b/internal/impl/mysql/input_mysql_stream.go @@ -36,21 +36,39 @@ import ( ) const ( - fieldMySQLFlavor = "flavor" - fieldMySQLDSN = "dsn" - fieldMySQLTables = "tables" - fieldStreamSnapshot = "stream_snapshot" - fieldSnapshotMaxBatchSize = "snapshot_max_batch_size" - fieldMaxReconnectAttempts = "max_reconnect_attempts" - fieldBatching = "batching" - fieldCheckpointKey = "checkpoint_key" - fieldCheckpointCache = "checkpoint_cache" - fieldCheckpointLimit = "checkpoint_limit" - fieldAWSIAMAuth = "aws" + fieldMySQLFlavor = "flavor" + fieldMySQLDSN = "dsn" + fieldMySQLTables = "tables" + fieldStreamSnapshot = "stream_snapshot" + fieldSnapshotMaxBatchSize = "snapshot_max_batch_size" + fieldSnapshotMaxParallelTables = "snapshot_max_parallel_tables" + fieldSnapshotChunksPerTable = "snapshot_chunks_per_table" + fieldMaxReconnectAttempts = "max_reconnect_attempts" + fieldBatching = "batching" + fieldCheckpointKey = "checkpoint_key" + fieldCheckpointCache = "checkpoint_cache" + fieldCheckpointLimit = "checkpoint_limit" + fieldAWSIAMAuth = "aws" // FieldAWSIAMAuthEnabled enabled field. FieldAWSIAMAuthEnabled = "enabled" shutdownTimeout = 5 * time.Second + + // maxSnapshotParallelTables is an upper bound on the snapshot worker pool. + // It guards against accidental denial-of-service from a mis-typed config + // value that would otherwise try to open thousands of MySQL connections + // at once. 
Operators with a legitimate need for more parallelism can open + // an issue — 256 is already well beyond the point at which the MySQL + // server's own connection limits dominate. + maxSnapshotParallelTables = 256 + + // maxSnapshotChunksPerTable caps chunks_per_table for the same reason as + // maxSnapshotParallelTables: a mis-typed value should fail fast at config + // parse time rather than produce thousands of MIN/MAX planning queries + // and slow down startup. The actual concurrency ceiling is still + // snapshot_max_parallel_tables — chunks above that just rebalance work + // across the fixed worker pool. + maxSnapshotChunksPerTable = 256 ) func notImportedAWSOptFn(_ context.Context, awsConf *service.ParsedConfig, _ *mysql.Config, _ *service.Logger) (TokenBuilder, error) { @@ -103,6 +121,14 @@ This input adds the following metadata fields to each message: service.NewIntField(fieldSnapshotMaxBatchSize). Description("The maximum number of rows to be streamed in a single batch when taking a snapshot."). Default(1000), + service.NewIntField(fieldSnapshotMaxParallelTables). + Description("The maximum number of tables that may be snapshotted in parallel. When set to `1` (the default) tables are read sequentially using a single transaction, preserving the previous behaviour. When set higher, multiple `REPEATABLE READ` transactions are opened on separate connections under a single brief `FLUSH TABLES ... WITH READ LOCK` window so every worker observes an identical, globally-consistent snapshot at the same binlog position. Must be between `1` and `256`."). + Advanced(). + Default(1), + service.NewIntField(fieldSnapshotChunksPerTable). + Description("The number of primary-key chunks each table is split into during the snapshot. When set to `1` (the default) each table is read as a single unit. When set higher, each table's first primary-key column is probed for `MIN` and `MAX` and the resulting integer range is split into N equal half-open chunks that are dispatched across the `"+fieldSnapshotMaxParallelTables+"` worker pool. This is how a single very large table is parallelised. Only tables whose first primary-key column is an integer type (`tinyint`, `smallint`, `mediumint`, `int`, `integer`, or `bigint`, signed or unsigned) are chunked; tables with non-numeric first PK columns fall back to a single whole-table read and log the reason. Composite primary keys are supported — chunking uses the leading column only, and per-chunk keyset pagination continues to respect the full PK ordering. Must be between `1` and `256`."). + Advanced(). + Default(1), service.NewIntField(fieldMaxReconnectAttempts). Description("The maximum number of attempts the MySQL driver will try to re-establish a broken connection before Connect attempts reconnection. A zero or negative number means infinite retry attempts."). Advanced(). 
@@ -180,10 +206,12 @@ type mysqlStreamInput struct { tables []string streamSnapshot bool - batching service.BatchPolicy - batchPolicy *service.Batcher - checkPointLimit int - fieldSnapshotMaxBatchSize int + batching service.BatchPolicy + batchPolicy *service.Batcher + checkPointLimit int + fieldSnapshotMaxBatchSize int + fieldSnapshotMaxParallelTables int + fieldSnapshotChunksPerTable int logger *service.Logger res *service.Resources @@ -279,6 +307,26 @@ func newMySQLStreamInput(conf *service.ParsedConfig, res *service.Resources) (s return nil, err } + if i.fieldSnapshotMaxParallelTables, err = conf.FieldInt(fieldSnapshotMaxParallelTables); err != nil { + return nil, err + } + if i.fieldSnapshotMaxParallelTables < 1 { + return nil, fmt.Errorf("field '%s' must be at least 1, got %d", fieldSnapshotMaxParallelTables, i.fieldSnapshotMaxParallelTables) + } + if i.fieldSnapshotMaxParallelTables > maxSnapshotParallelTables { + return nil, fmt.Errorf("field '%s' must be at most %d, got %d", fieldSnapshotMaxParallelTables, maxSnapshotParallelTables, i.fieldSnapshotMaxParallelTables) + } + + if i.fieldSnapshotChunksPerTable, err = conf.FieldInt(fieldSnapshotChunksPerTable); err != nil { + return nil, err + } + if i.fieldSnapshotChunksPerTable < 1 { + return nil, fmt.Errorf("field '%s' must be at least 1, got %d", fieldSnapshotChunksPerTable, i.fieldSnapshotChunksPerTable) + } + if i.fieldSnapshotChunksPerTable > maxSnapshotChunksPerTable { + return nil, fmt.Errorf("field '%s' must be at most %d, got %d", fieldSnapshotChunksPerTable, maxSnapshotChunksPerTable, i.fieldSnapshotChunksPerTable) + } + if i.canalMaxConnAttempts, err = conf.FieldInt(fieldMaxReconnectAttempts); err != nil { return nil, err } @@ -418,21 +466,15 @@ func (i *mysqlStreamInput) Connect(ctx context.Context) error { func (i *mysqlStreamInput) startMySQLSync(ctx context.Context, pos *position, snapshot *Snapshot) error { // If we are given a snapshot, then we need to read it. if snapshot != nil { - startPos, err := snapshot.prepareSnapshot(ctx, i.tables) - if err != nil { - _ = snapshot.close() - return fmt.Errorf("unable to prepare snapshot: %w", err) - } - if err = i.readSnapshot(ctx, snapshot); err != nil { - _ = snapshot.close() - return fmt.Errorf("failed reading snapshot: %w", err) - } - if err = snapshot.releaseSnapshot(ctx); err != nil { - _ = snapshot.close() - return fmt.Errorf("unable to release snapshot: %w", err) + var startPos *position + var err error + if i.fieldSnapshotMaxParallelTables <= 1 && i.fieldSnapshotChunksPerTable <= 1 { + startPos, err = i.runSequentialSnapshot(ctx, snapshot) + } else { + startPos, err = i.runParallelSnapshot(ctx, snapshot) } - if err = snapshot.close(); err != nil { - return fmt.Errorf("unable to close snapshot: %w", err) + if err != nil { + return err } // Signal snapshot completion. readMessages will flush any partial batch // and pre-resolve a checkpoint entry for startPos so the cache is @@ -459,91 +501,184 @@ func (i *mysqlStreamInput) startMySQLSync(ctx context.Context, pos *position, sn return nil } +// runSequentialSnapshot executes the original single-transaction snapshot flow: +// one FLUSH TABLES WITH READ LOCK window, one consistent-snapshot transaction, +// tables read serially by a single goroutine. Preserves byte-identical +// behaviour from before parallel-snapshot support was introduced. 
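+//
+// This path is selected by startMySQLSync only when both
+// snapshot_max_parallel_tables and snapshot_chunks_per_table are 1 (their
+// defaults); any higher value for either knob routes through
+// runParallelSnapshot instead.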
+func (i *mysqlStreamInput) runSequentialSnapshot(ctx context.Context, snapshot *Snapshot) (*position, error) { + startPos, err := snapshot.prepareSnapshot(ctx, i.tables) + if err != nil { + _ = snapshot.close() + return nil, fmt.Errorf("unable to prepare snapshot: %w", err) + } + if err = i.readSnapshot(ctx, snapshot); err != nil { + _ = snapshot.close() + return nil, fmt.Errorf("failed reading snapshot: %w", err) + } + if err = snapshot.releaseSnapshot(ctx); err != nil { + _ = snapshot.close() + return nil, fmt.Errorf("unable to release snapshot: %w", err) + } + if err = snapshot.close(); err != nil { + return nil, fmt.Errorf("unable to close snapshot: %w", err) + } + return startPos, nil +} + +// runParallelSnapshot opens fieldSnapshotMaxParallelTables consistent-snapshot +// transactions under a single FLUSH TABLES WITH READ LOCK window and reads the +// configured tables concurrently. All workers share one binlog position so the +// downstream handoff to the binlog stream is unchanged from the sequential +// path. The original snapshot argument is used only as a carrier for the +// already-open *sql.DB; ownership of that db is transferred to the parallel +// set (which closes it when done) so the caller must not reuse the original +// Snapshot afterwards. +func (i *mysqlStreamInput) runParallelSnapshot(ctx context.Context, snapshot *Snapshot) (*position, error) { + // Transfer db ownership to the parallel set before doing anything that + // might fail: if prepare fails, the set's close will release the db, and + // we want snapshot.close() to be a safe no-op in that case. + db := snapshot.db + snapshot.db = nil + + // Workers are capped by the plausible number of work units: at most + // chunks_per_table * len(tables), and never more than requested. Planning + // may emit fewer units (e.g. some tables fall back to whole-table reads) + // but the over-provisioning cost is bounded and connections held by idle + // workers are released when the snapshot completes. + workerCount := i.fieldSnapshotMaxParallelTables + if maxUnits := len(i.tables) * i.fieldSnapshotChunksPerTable; workerCount > maxUnits { + workerCount = maxUnits + } + + set, startPos, err := prepareParallelSnapshotSet(ctx, i.logger, db, i.tables, workerCount) + if err != nil { + // prepareParallelSnapshotSet closed db on its own error paths. + return nil, fmt.Errorf("unable to prepare parallel snapshot: %w", err) + } + + // Plan work units using any worker's consistent-snapshot transaction. + // All workers observe identical state so MIN/MAX computed here apply + // uniformly to every worker's subsequent reads. 
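+	//
+	// Illustrative plan shape (assumed values, not from a real run): two
+	// tables [a, b] with chunks_per_table=4 and an integer PK spanning
+	// MIN=0, MAX=100 produce eight units; per table that is
+	// (-inf, 25), [25, 50), [50, 75), [75, +inf), the outer chunks being
+	// left open-ended by splitIntRange.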
+ units, err := planSnapshotWork(ctx, set.workers[0], i.tables, i.fieldSnapshotChunksPerTable) + if err != nil { + _ = set.close() + return nil, fmt.Errorf("plan snapshot work: %w", err) + } + i.logger.Infof("Parallel snapshot planned: %d tables -> %d work units across %d workers", len(i.tables), len(units), len(set.workers)) + + if err := i.readSnapshotParallel(ctx, set, units); err != nil { + _ = set.close() + return nil, fmt.Errorf("failed reading snapshot: %w", err) + } + if err := set.release(ctx); err != nil { + _ = set.close() + return nil, fmt.Errorf("unable to release parallel snapshot: %w", err) + } + if err := set.close(); err != nil { + return nil, fmt.Errorf("unable to close parallel snapshot: %w", err) + } + return startPos, nil +} + func (i *mysqlStreamInput) readSnapshot(ctx context.Context, snapshot *Snapshot) error { - // TODO(cdc): Process tables in parallel for _, table := range i.tables { - // Pre-populate schema cache so snapshot messages carry schema metadata. - if tbl, err := i.canal.GetTable(i.mysqlConfig.DBName, table); err == nil { - if _, err := i.getTableSchema(tbl); err != nil { - i.logger.Warnf("Failed to pre-populate schema for table %s during snapshot: %v", table, err) - } + if err := i.readSnapshotWorkUnit(ctx, snapshot, snapshotWorkUnit{table: table}); err != nil { + return err + } + } + return nil +} + +// readSnapshotWorkUnit snapshots one work unit — either a whole table or a +// primary-key chunk of a table — by paging through its rows in primary-key +// order using the REPEATABLE READ / CONSISTENT SNAPSHOT transaction held by +// snapshot. When unit.bounds is nil the whole table is read; otherwise rows +// are filtered by the chunk's [lowerIncl, upperExcl) range on the first PK +// column. Both the sequential and the parallel paths use this same body so +// per-table semantics are identical regardless of chunking configuration. +func (i *mysqlStreamInput) readSnapshotWorkUnit(ctx context.Context, snapshot *Snapshot, unit snapshotWorkUnit) error { + table := unit.table + // Pre-populate schema cache so snapshot messages carry schema metadata. 
+ if tbl, err := i.canal.GetTable(i.mysqlConfig.DBName, table); err == nil { + if _, err := i.getTableSchema(tbl); err != nil { + i.logger.Warnf("Failed to pre-populate schema for table %s during snapshot: %v", table, err) + } + } else { + i.logger.Warnf("Failed to fetch schema for table %s during snapshot: %v", table, err) + } + tablePks, err := snapshot.getTablePrimaryKeys(ctx, table) + if err != nil { + return err + } + i.logger.Tracef("primary keys for table %s: %v", table, tablePks) + lastSeenPksValues := map[string]any{} + for _, pk := range tablePks { + lastSeenPksValues[pk] = nil + } + + var numRowsProcessed int + for { + var batchRows *sql.Rows + if numRowsProcessed == 0 { + batchRows, err = snapshot.querySnapshotTable(ctx, table, tablePks, unit.bounds, nil, i.fieldSnapshotMaxBatchSize) } else { - i.logger.Warnf("Failed to fetch schema for table %s during snapshot: %v", table, err) + batchRows, err = snapshot.querySnapshotTable(ctx, table, tablePks, unit.bounds, &lastSeenPksValues, i.fieldSnapshotMaxBatchSize) } - tablePks, err := snapshot.getTablePrimaryKeys(ctx, table) if err != nil { - return err + return fmt.Errorf("executing snapshot table query: %s", err) } - i.logger.Tracef("primary keys for table %s: %v", table, tablePks) - lastSeenPksValues := map[string]any{} - for _, pk := range tablePks { - lastSeenPksValues[pk] = nil + + types, err := batchRows.ColumnTypes() + if err != nil { + return fmt.Errorf("fetching column types: %s", err) } - var numRowsProcessed int - for { - var batchRows *sql.Rows - if numRowsProcessed == 0 { - batchRows, err = snapshot.querySnapshotTable(ctx, table, tablePks, nil, i.fieldSnapshotMaxBatchSize) - } else { - batchRows, err = snapshot.querySnapshotTable(ctx, table, tablePks, &lastSeenPksValues, i.fieldSnapshotMaxBatchSize) - } - if err != nil { - return fmt.Errorf("executing snapshot table query: %s", err) - } + values, mappers := prepSnapshotScannerAndMappers(types) - types, err := batchRows.ColumnTypes() - if err != nil { - return fmt.Errorf("fetching column types: %s", err) - } + columns, err := batchRows.Columns() + if err != nil { + return fmt.Errorf("fetching columns: %s", err) + } - values, mappers := prepSnapshotScannerAndMappers(types) + var batchRowsCount int + for batchRows.Next() { + numRowsProcessed++ + batchRowsCount++ - columns, err := batchRows.Columns() - if err != nil { - return fmt.Errorf("fetching columns: %s", err) + if err := batchRows.Scan(values...); err != nil { + return err } - var batchRowsCount int - for batchRows.Next() { - numRowsProcessed++ - batchRowsCount++ - - if err := batchRows.Scan(values...); err != nil { + row := map[string]any{} + for idx, value := range values { + v, err := mappers[idx](value) + if err != nil { return err } - - row := map[string]any{} - for idx, value := range values { - v, err := mappers[idx](value) - if err != nil { - return err - } - row[columns[idx]] = v - if _, ok := lastSeenPksValues[columns[idx]]; ok { - lastSeenPksValues[columns[idx]] = value - } - } - - select { - case i.rawMessageEvents <- MessageEvent{ - Row: row, - Operation: MessageOperationRead, - Table: table, - Position: nil, - }: - case <-ctx.Done(): - return ctx.Err() + row[columns[idx]] = v + if _, ok := lastSeenPksValues[columns[idx]]; ok { + lastSeenPksValues[columns[idx]] = value } } - if err := batchRows.Err(); err != nil { - return fmt.Errorf("iterating snapshot table: %s", err) + select { + case i.rawMessageEvents <- MessageEvent{ + Row: row, + Operation: MessageOperationRead, + Table: table, + Position: nil, + 
}: + case <-ctx.Done(): + return ctx.Err() } + } - if batchRowsCount < i.fieldSnapshotMaxBatchSize { - break - } + if err := batchRows.Err(); err != nil { + return fmt.Errorf("iterating snapshot table: %s", err) + } + + if batchRowsCount < i.fieldSnapshotMaxBatchSize { + break } } return nil diff --git a/internal/impl/mysql/integration_test.go b/internal/impl/mysql/integration_test.go index 45b9ad3d92..620a5c34ca 100644 --- a/internal/impl/mysql/integration_test.go +++ b/internal/impl/mysql/integration_test.go @@ -282,6 +282,307 @@ file: require.NoError(t, streamOut.StopWithin(time.Second*10)) } +// TestIntegrationMySQLParallelSnapshot verifies that enabling +// snapshot_max_parallel_tables produces the same total set of snapshot rows +// across multiple tables as the sequential path, and that the subsequent +// binlog-stream handoff captures ongoing writes correctly. The parallel path +// opens N REPEATABLE READ / CONSISTENT SNAPSHOT transactions under one +// FLUSH TABLES WITH READ LOCK window, so all workers observe identical state. +func TestIntegrationMySQLParallelSnapshot(t *testing.T) { + dsn, db := setupTestWithMySQLVersion(t, "8.0") + + // Create 4 tables and pre-load each with 500 rows so a parallel snapshot + // has meaningful per-worker work and the distribution is observable. + tableNames := []string{"foo1", "foo2", "foo3", "foo4"} + const rowsPerTable = 500 + + for _, tbl := range tableNames { + db.Exec(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (a INT PRIMARY KEY)", tbl)) + for i := range rowsPerTable { + db.Exec(fmt.Sprintf("INSERT INTO %s VALUES (?)", tbl), i) + } + } + + template := fmt.Sprintf(` +mysql_cdc: + dsn: %s + stream_snapshot: true + snapshot_max_batch_size: 100 + snapshot_max_parallel_tables: 4 + checkpoint_cache: parcache + tables: + - foo1 + - foo2 + - foo3 + - foo4 +`, dsn) + + cacheConf := fmt.Sprintf(` +label: parcache +file: + directory: %s`, t.TempDir()) + + streamOutBuilder := service.NewStreamBuilder() + require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: DEBUG`)) + require.NoError(t, streamOutBuilder.AddCacheYAML(cacheConf)) + require.NoError(t, streamOutBuilder.AddInputYAML(template)) + + snapshotCounts := map[string]*atomic.Int64{} + cdcCounts := map[string]*atomic.Int64{} + for _, tbl := range tableNames { + snapshotCounts[tbl] = &atomic.Int64{} + cdcCounts[tbl] = &atomic.Int64{} + } + var totalMsgs atomic.Int64 + + require.NoError(t, streamOutBuilder.AddBatchConsumerFunc(func(_ context.Context, mb service.MessageBatch) error { + for _, msg := range mb { + op, _ := msg.MetaGet("operation") + tbl, _ := msg.MetaGet("table") + if c, ok := snapshotCounts[tbl]; ok && op == "read" { + c.Add(1) + } + if c, ok := cdcCounts[tbl]; ok && (op == "insert" || op == "update" || op == "delete") { + c.Add(1) + } + totalMsgs.Add(1) + } + return nil + })) + + streamOut, err := streamOutBuilder.Build() + require.NoError(t, err) + license.InjectTestService(streamOut.Resources()) + + go func() { + err = streamOut.Run(t.Context()) + require.NoError(t, err) + }() + + // Wait for the snapshot phase to complete for all tables. 
+ assert.Eventually(t, func() bool { + for _, tbl := range tableNames { + if snapshotCounts[tbl].Load() < int64(rowsPerTable) { + return false + } + } + return true + }, time.Minute*2, time.Millisecond*100, "parallel snapshot should emit %d rows per table", rowsPerTable) + + // Write additional rows post-snapshot and confirm the binlog-stream + // handoff picks them up — this validates that the single shared binlog + // position captured under the read-lock window is still a valid starting + // point for the binlog consumer. + const cdcRowsPerTable = 100 + for _, tbl := range tableNames { + for i := rowsPerTable; i < rowsPerTable+cdcRowsPerTable; i++ { + db.Exec(fmt.Sprintf("INSERT INTO %s VALUES (?)", tbl), i) + } + } + + assert.Eventually(t, func() bool { + for _, tbl := range tableNames { + if cdcCounts[tbl].Load() < int64(cdcRowsPerTable) { + return false + } + } + return true + }, time.Minute*2, time.Millisecond*100, "binlog stream should pick up post-snapshot inserts for each table") + + // Sanity check: every snapshot row was emitted exactly once (no + // duplicates from overlapping per-worker transactions). + for _, tbl := range tableNames { + assert.Equal(t, int64(rowsPerTable), snapshotCounts[tbl].Load(), "exactly %d snapshot rows expected for %s", rowsPerTable, tbl) + } + + require.NoError(t, streamOut.StopWithin(time.Second*10)) +} + +// TestIntegrationMySQLChunkedSnapshot exercises intra-table chunking with +// both a single-column integer PK and a composite (int, int) PK. The +// chunked path should still emit every row exactly once under the shared +// consistent-snapshot window, and the binlog-stream handoff should still +// pick up post-snapshot writes correctly. +func TestIntegrationMySQLChunkedSnapshot(t *testing.T) { + dsn, db := setupTestWithMySQLVersion(t, "8.0") + + // single_pk: single INT PK. composite_pk: (tenant_id, id) — chunking + // uses the first column only, so we spread rows across tenant ids so + // each chunk gets non-empty work. + const rowsPerTable = 2000 + db.Exec("CREATE TABLE single_pk (id INT PRIMARY KEY, payload VARCHAR(32))") + db.Exec("CREATE TABLE composite_pk (tenant_id INT, id INT, payload VARCHAR(32), PRIMARY KEY (tenant_id, id))") + + for i := range rowsPerTable { + db.Exec("INSERT INTO single_pk VALUES (?, ?)", i, fmt.Sprintf("row-%d", i)) + // tenant_id spans [0, 40) and id spans [0, 50) so chunking on + // tenant_id produces meaningful range partitions. + db.Exec("INSERT INTO composite_pk VALUES (?, ?, ?)", i%40, i/40, fmt.Sprintf("row-%d", i)) + } + + template := fmt.Sprintf(` +mysql_cdc: + dsn: %s + stream_snapshot: true + snapshot_max_batch_size: 200 + snapshot_max_parallel_tables: 4 + snapshot_chunks_per_table: 8 + checkpoint_cache: chunkcache + tables: + - single_pk + - composite_pk +`, dsn) + + cacheConf := fmt.Sprintf(` +label: chunkcache +file: + directory: %s`, t.TempDir()) + + streamOutBuilder := service.NewStreamBuilder() + require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: DEBUG`)) + require.NoError(t, streamOutBuilder.AddCacheYAML(cacheConf)) + require.NoError(t, streamOutBuilder.AddInputYAML(template)) + + snapshotCounts := map[string]*atomic.Int64{ + "single_pk": {}, + "composite_pk": {}, + } + cdcCounts := map[string]*atomic.Int64{ + "single_pk": {}, + "composite_pk": {}, + } + + // Track the pk values we observe during snapshot so we can detect + // duplicates from overlapping chunk ranges — the most likely correctness + // regression if the range predicates get subtly wrong. 
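+	// (For instance, emitting `id` <= ? instead of `id` < ? for a chunk's
+	// upper bound would double-count rows sitting exactly on a chunk
+	// boundary; the seen-key maps below would flag those as duplicates.)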
+ seenSingle := sync.Map{} + seenComposite := sync.Map{} + var duplicateCount atomic.Int64 + + require.NoError(t, streamOutBuilder.AddBatchConsumerFunc(func(_ context.Context, mb service.MessageBatch) error { + for _, msg := range mb { + op, _ := msg.MetaGet("operation") + tbl, _ := msg.MetaGet("table") + c, ok := snapshotCounts[tbl] + if !ok { + continue + } + if op == "read" { + c.Add(1) + body, err := msg.AsStructured() + if err != nil { + return err + } + row, _ := body.(map[string]any) + switch tbl { + case "single_pk": + id := fmt.Sprintf("%v", row["id"]) + if _, loaded := seenSingle.LoadOrStore(id, struct{}{}); loaded { + duplicateCount.Add(1) + } + case "composite_pk": + key := fmt.Sprintf("%v/%v", row["tenant_id"], row["id"]) + if _, loaded := seenComposite.LoadOrStore(key, struct{}{}); loaded { + duplicateCount.Add(1) + } + } + } + if op == "insert" || op == "update" || op == "delete" { + cdcCounts[tbl].Add(1) + } + } + return nil + })) + + streamOut, err := streamOutBuilder.Build() + require.NoError(t, err) + license.InjectTestService(streamOut.Resources()) + + go func() { + err = streamOut.Run(t.Context()) + require.NoError(t, err) + }() + + assert.Eventually(t, func() bool { + return snapshotCounts["single_pk"].Load() >= rowsPerTable && + snapshotCounts["composite_pk"].Load() >= rowsPerTable + }, time.Minute*2, time.Millisecond*100, "chunked snapshot should emit %d rows per table", rowsPerTable) + + // Every row appeared exactly once and no chunk produced duplicates. + assert.Equal(t, int64(rowsPerTable), snapshotCounts["single_pk"].Load()) + assert.Equal(t, int64(rowsPerTable), snapshotCounts["composite_pk"].Load()) + assert.Zero(t, duplicateCount.Load(), "chunk ranges must not overlap") + + // Binlog handoff still works after the chunked snapshot. + const cdcRows = 50 + for i := rowsPerTable; i < rowsPerTable+cdcRows; i++ { + db.Exec("INSERT INTO single_pk VALUES (?, ?)", i, "cdc") + db.Exec("INSERT INTO composite_pk VALUES (?, ?, ?)", i%40, 1000+i, "cdc") + } + assert.Eventually(t, func() bool { + return cdcCounts["single_pk"].Load() >= cdcRows && cdcCounts["composite_pk"].Load() >= cdcRows + }, time.Minute*2, time.Millisecond*100, "binlog stream should pick up post-snapshot inserts") + + require.NoError(t, streamOut.StopWithin(time.Second*10)) +} + +// TestIntegrationMySQLChunkedSnapshotNonNumericPKFallback confirms that a +// table whose first PK column is non-numeric (here, VARCHAR) is not chunked +// — it falls back to a single whole-table read and the snapshot completes +// without error. 
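+// The planner is expected to log "Snapshot chunking disabled for table
+// string_pk: first PK column id is non-numeric; reading as a single unit"
+// (see planSnapshotWork) and emit a single whole-table work unit.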
+func TestIntegrationMySQLChunkedSnapshotNonNumericPKFallback(t *testing.T) { + dsn, db := setupTestWithMySQLVersion(t, "8.0") + + const rowsPerTable = 300 + db.Exec("CREATE TABLE string_pk (id VARCHAR(64) PRIMARY KEY, payload VARCHAR(32))") + for i := range rowsPerTable { + db.Exec("INSERT INTO string_pk VALUES (?, ?)", fmt.Sprintf("key-%04d", i), "p") + } + + template := fmt.Sprintf(` +mysql_cdc: + dsn: %s + stream_snapshot: true + snapshot_max_batch_size: 100 + snapshot_max_parallel_tables: 2 + snapshot_chunks_per_table: 8 + checkpoint_cache: fbcache + tables: + - string_pk +`, dsn) + cacheConf := fmt.Sprintf("label: fbcache\nfile:\n directory: %s", t.TempDir()) + + streamOutBuilder := service.NewStreamBuilder() + require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: DEBUG`)) + require.NoError(t, streamOutBuilder.AddCacheYAML(cacheConf)) + require.NoError(t, streamOutBuilder.AddInputYAML(template)) + + var snapCount atomic.Int64 + require.NoError(t, streamOutBuilder.AddBatchConsumerFunc(func(_ context.Context, mb service.MessageBatch) error { + for _, msg := range mb { + if op, _ := msg.MetaGet("operation"); op == "read" { + snapCount.Add(1) + } + } + return nil + })) + + streamOut, err := streamOutBuilder.Build() + require.NoError(t, err) + license.InjectTestService(streamOut.Resources()) + + go func() { + err = streamOut.Run(t.Context()) + require.NoError(t, err) + }() + + assert.Eventually(t, func() bool { + return snapCount.Load() >= rowsPerTable + }, time.Minute, time.Millisecond*100, "fallback whole-table read should still emit all %d rows", rowsPerTable) + + require.NoError(t, streamOut.StopWithin(time.Second*10)) +} + func TestIntegrationMySQLCDCWithCompositePrimaryKeys(t *testing.T) { dsn, db := setupTestWithMySQLVersion(t, "8.0") // Create table diff --git a/internal/impl/mysql/parallel_snapshot.go b/internal/impl/mysql/parallel_snapshot.go new file mode 100644 index 0000000000..2dbb90d1b2 --- /dev/null +++ b/internal/impl/mysql/parallel_snapshot.go @@ -0,0 +1,228 @@ +// Copyright 2024 Redpanda Data, Inc. +// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md + +package mysql + +import ( + "context" + "database/sql" + "errors" + "fmt" + + "github.com/redpanda-data/benthos/v4/public/service" + "golang.org/x/sync/errgroup" +) + +// parallelSnapshotSet owns the shared *sql.DB and a pool of per-worker Snapshot +// instances. Every worker in the set holds its own *sql.Conn and its own +// REPEATABLE READ / CONSISTENT SNAPSHOT transaction, but all transactions were +// opened within a single FLUSH TABLES ... WITH READ LOCK window so they view +// identical state at the same binlog position. +type parallelSnapshotSet struct { + db *sql.DB + workers []*Snapshot + logger *service.Logger +} + +// prepareParallelSnapshotSet opens workerCount reader connections that all +// share a single globally-consistent MySQL snapshot: +// +// 1. Acquire a single lock connection and FLUSH TABLES WITH READ LOCK. +// 2. Open workerCount snapshot connections, each starting a REPEATABLE READ +// transaction followed by START TRANSACTION WITH CONSISTENT SNAPSHOT. +// 3. Capture the binlog position once (all workers share this position). +// 4. Release the table locks and return. 
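+//
+// In session terms the window looks roughly like this (illustrative sketch;
+// the exact statements come from buildFlushAndLockTablesQuery and
+// getCurrentBinlogPosition):
+//
+//	lock conn:   FLUSH TABLES a, b WITH READ LOCK
+//	worker 0..N: BEGIN (REPEATABLE READ, read only)
+//	             START TRANSACTION WITH CONSISTENT SNAPSHOT
+//	worker 0:    -- capture the shared binlog position
+//	lock conn:   UNLOCK TABLES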
+// +// The returned set's workers can each be read from in parallel without +// coordination: they are independent connections/transactions observing the +// same historical state. The caller is responsible for invoking release then +// close once snapshot reading is finished. +// +// workerCount must already be bounded by the caller (e.g. to the number of +// expected work units). This function opens exactly workerCount connections; +// it does not second-guess the caller's sizing. +// +// Ownership: this function takes ownership of db. On success the returned set +// closes db when set.close() is called. On error db is closed before the +// function returns (along with any partially-opened conns/txns) and the +// caller must not reuse it. +func prepareParallelSnapshotSet(ctx context.Context, logger *service.Logger, db *sql.DB, tables []string, workerCount int) (*parallelSnapshotSet, *position, error) { + if workerCount < 1 { + _ = db.Close() + return nil, nil, fmt.Errorf("parallel snapshot worker count must be >= 1, got %d", workerCount) + } + if len(tables) == 0 { + _ = db.Close() + return nil, nil, errors.New("no tables provided") + } + + set := ¶llelSnapshotSet{db: db, logger: logger} + // failWith closes the partially-built set (which closes db) and returns + // the combined error. Use this on every error path below. + failWith := func(errs ...error) (*parallelSnapshotSet, *position, error) { + errs = append(errs, set.close()) + return nil, nil, errors.Join(errs...) + } + + lockConn, err := db.Conn(ctx) + if err != nil { + return failWith(fmt.Errorf("create lock connection: %w", err)) + } + // The lock conn is only needed to bracket the BEGINs below. Always return + // it to the pool on exit; the lock itself is released via UNLOCK TABLES. + defer func() { + _ = lockConn.Close() + }() + + lockQuery := buildFlushAndLockTablesQuery(tables) + logger.Infof("Acquiring table-level read locks for parallel snapshot (%d workers): %s", workerCount, lockQuery) + if _, err := lockConn.ExecContext(ctx, lockQuery); err != nil { + return failWith(fmt.Errorf("acquire table-level read locks: %w", err)) + } + unlockTables := func() error { + if _, err := lockConn.ExecContext(ctx, "UNLOCK TABLES"); err != nil { + return fmt.Errorf("release table-level read locks: %w", err) + } + return nil + } + + for idx := 0; idx < workerCount; idx++ { + conn, err := db.Conn(ctx) + if err != nil { + return failWith(fmt.Errorf("open snapshot connection %d: %w", idx, err), unlockTables()) + } + tx, err := conn.BeginTx(ctx, &sql.TxOptions{ + ReadOnly: true, + Isolation: sql.LevelRepeatableRead, + }) + if err != nil { + _ = conn.Close() + return failWith(fmt.Errorf("begin snapshot transaction %d: %w", idx, err), unlockTables()) + } + // NOTE: this is a little sneaky because we're actually implicitly + // closing the transaction started with BeginTx above and replacing it + // with this one. We have to do this because the database/sql driver + // does not support WITH CONSISTENT SNAPSHOT directly. + if _, err := tx.ExecContext(ctx, "START TRANSACTION WITH CONSISTENT SNAPSHOT"); err != nil { + _ = tx.Rollback() + _ = conn.Close() + return failWith(fmt.Errorf("start consistent snapshot %d: %w", idx, err), unlockTables()) + } + // Each worker is a "bare" Snapshot: no db (the set owns it), no + // lockConn (released at the end of this function). close() on each + // worker will rollback its tx and close its conn, which is what we + // want. 
+ set.workers = append(set.workers, &Snapshot{ + tx: tx, + snapshotConn: conn, + logger: logger, + }) + } + + // Capture binlog position while still under lock, from any worker. All + // workers are at the same snapshot so this single position applies to all + // of them. + pos, err := set.workers[0].getCurrentBinlogPosition(ctx) + if err != nil { + return failWith(fmt.Errorf("get binlog position: %w", err), unlockTables()) + } + + if err := unlockTables(); err != nil { + return failWith(err) + } + + return set, &pos, nil +} + +// release commits every worker's snapshot transaction. Analogous to +// Snapshot.releaseSnapshot for the sequential path. +func (p *parallelSnapshotSet) release(ctx context.Context) error { + var errs []error + for _, w := range p.workers { + if err := w.releaseSnapshot(ctx); err != nil { + errs = append(errs, err) + } + } + return errors.Join(errs...) +} + +// close rolls back any still-open transactions, closes every worker +// connection, then closes the shared *sql.DB. +func (p *parallelSnapshotSet) close() error { + var errs []error + for _, w := range p.workers { + if err := w.close(); err != nil { + errs = append(errs, err) + } + } + if p.db != nil { + if err := p.db.Close(); err != nil { + errs = append(errs, fmt.Errorf("close db: %w", err)) + } + p.db = nil + } + return errors.Join(errs...) +} + +// readSnapshotParallel distributes work units across set.workers and reads +// them concurrently using an errgroup. Any worker error cancels siblings and +// returns from Wait (matching the existing fail-halt semantics of the +// sequential path). +func (i *mysqlStreamInput) readSnapshotParallel(ctx context.Context, set *parallelSnapshotSet, units []snapshotWorkUnit) error { + return distributeWorkToWorkers(ctx, units, len(set.workers), func(gctx context.Context, workerIdx int, unit snapshotWorkUnit) error { + return i.readSnapshotWorkUnit(gctx, set.workers[workerIdx], unit) + }) +} + +// distributeWorkToWorkers fans out items across workerCount goroutines, +// calling readFn(ctx, workerIdx, item) exactly once per item. It uses an +// errgroup: the first error cancels the shared context and is returned from +// Wait. Exposed as a generic helper so the fan-out logic can be unit-tested +// independently of MySQL — tests pass []string, production passes +// []snapshotWorkUnit. +func distributeWorkToWorkers[T any](ctx context.Context, items []T, workerCount int, readFn func(context.Context, int, T) error) error { + if workerCount < 1 { + return fmt.Errorf("workerCount must be >= 1, got %d", workerCount) + } + if workerCount > len(items) { + workerCount = len(items) + } + if workerCount == 0 { + // No items at all. Nothing to do. + return nil + } + + g, gctx := errgroup.WithContext(ctx) + itemCh := make(chan T) + + g.Go(func() error { + defer close(itemCh) + for _, it := range items { + select { + case itemCh <- it: + case <-gctx.Done(): + return gctx.Err() + } + } + return nil + }) + + for w := 0; w < workerCount; w++ { + workerIdx := w + g.Go(func() error { + for item := range itemCh { + if err := readFn(gctx, workerIdx, item); err != nil { + return err + } + } + return nil + }) + } + + return g.Wait() +} diff --git a/internal/impl/mysql/parallel_snapshot_test.go b/internal/impl/mysql/parallel_snapshot_test.go new file mode 100644 index 0000000000..13b2a5846e --- /dev/null +++ b/internal/impl/mysql/parallel_snapshot_test.go @@ -0,0 +1,189 @@ +// Copyright 2024 Redpanda Data, Inc. 
+// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md + +package mysql + +import ( + "context" + "errors" + "fmt" + "sort" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDistributeTablesToWorkers_CoversEveryTableExactlyOnce(t *testing.T) { + tables := []string{"a", "b", "c", "d", "e", "f", "g", "h", "i", "j"} + + for _, workers := range []int{1, 2, 3, 4, 8, 16} { + t.Run(fmt.Sprintf("workers=%d", workers), func(t *testing.T) { + var mu sync.Mutex + var visited []string + + err := distributeWorkToWorkers(t.Context(), tables, workers, func(_ context.Context, _ int, table string) error { + mu.Lock() + visited = append(visited, table) + mu.Unlock() + return nil + }) + require.NoError(t, err) + + sort.Strings(visited) + expected := append([]string{}, tables...) + sort.Strings(expected) + assert.Equal(t, expected, visited, "each table must be visited exactly once") + }) + } +} + +func TestDistributeTablesToWorkers_WorkerCountCappedByTableCount(t *testing.T) { + tables := []string{"a", "b"} + + var activeWorkers atomic.Int32 + var maxActive atomic.Int32 + + err := distributeWorkToWorkers(t.Context(), tables, 16, func(_ context.Context, _ int, _ string) error { + n := activeWorkers.Add(1) + for { + cur := maxActive.Load() + if n <= cur || maxActive.CompareAndSwap(cur, n) { + break + } + } + time.Sleep(10 * time.Millisecond) + activeWorkers.Add(-1) + return nil + }) + require.NoError(t, err) + assert.LessOrEqual(t, int(maxActive.Load()), len(tables), "should never exceed table count, even when workerCount is larger") +} + +func TestDistributeTablesToWorkers_SingleWorkerIsSequential(t *testing.T) { + tables := []string{"a", "b", "c", "d"} + + var mu sync.Mutex + var inFlight int + var maxInFlight int + + err := distributeWorkToWorkers(t.Context(), tables, 1, func(_ context.Context, _ int, _ string) error { + mu.Lock() + inFlight++ + if inFlight > maxInFlight { + maxInFlight = inFlight + } + mu.Unlock() + time.Sleep(5 * time.Millisecond) + mu.Lock() + inFlight-- + mu.Unlock() + return nil + }) + require.NoError(t, err) + assert.Equal(t, 1, maxInFlight, "workerCount=1 must serialize all reads") +} + +func TestDistributeTablesToWorkers_ErrorPropagatesAndCancelsSiblings(t *testing.T) { + tables := make([]string, 50) + for i := range tables { + tables[i] = fmt.Sprintf("t%d", i) + } + + sentinel := errors.New("boom") + var calls atomic.Int32 + + err := distributeWorkToWorkers(t.Context(), tables, 4, func(ctx context.Context, _ int, table string) error { + calls.Add(1) + if table == "t5" { + return sentinel + } + // Block until cancelled so we can observe siblings being cancelled + // after the sentinel error fires. + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(2 * time.Second): + return nil + } + }) + require.ErrorIs(t, err, sentinel) + // At most every worker got 1 table before cancellation, plus the sentinel. + // We should not have processed all 50 tables. 
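+	// (With 4 workers and the sentinel at t5, dispatch stops within the
+	// first couple of rounds, so the observed call count stays in the
+	// single digits in practice.)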
+ assert.Less(t, int(calls.Load()), len(tables), "error must cancel siblings before all tables are consumed") +} + +func TestDistributeTablesToWorkers_ContextCancellationPropagates(t *testing.T) { + tables := make([]string, 100) + for i := range tables { + tables[i] = fmt.Sprintf("t%d", i) + } + + ctx, cancel := context.WithCancel(t.Context()) + go func() { + time.Sleep(20 * time.Millisecond) + cancel() + }() + + err := distributeWorkToWorkers(ctx, tables, 4, func(ctx context.Context, _ int, _ string) error { + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(500 * time.Millisecond): + return nil + } + }) + require.Error(t, err) + assert.ErrorIs(t, err, context.Canceled) +} + +func TestDistributeTablesToWorkers_ZeroWorkersRejected(t *testing.T) { + err := distributeWorkToWorkers(t.Context(), []string{"a"}, 0, func(context.Context, int, string) error { + return nil + }) + require.Error(t, err) + assert.Contains(t, err.Error(), ">= 1") +} + +func TestDistributeTablesToWorkers_EmptyTablesIsNoop(t *testing.T) { + var called atomic.Bool + err := distributeWorkToWorkers(t.Context(), nil, 4, func(context.Context, int, string) error { + called.Store(true) + return nil + }) + require.NoError(t, err) + assert.False(t, called.Load(), "readFn must not be called when table list is empty") +} + +func TestDistributeTablesToWorkers_WorkerIdxWithinBounds(t *testing.T) { + tables := []string{"a", "b", "c", "d", "e", "f", "g", "h"} + const workerCount = 3 + + var mu sync.Mutex + seenIdxs := map[int]struct{}{} + + err := distributeWorkToWorkers(t.Context(), tables, workerCount, func(_ context.Context, idx int, _ string) error { + mu.Lock() + seenIdxs[idx] = struct{}{} + mu.Unlock() + assert.GreaterOrEqual(t, idx, 0) + assert.Less(t, idx, workerCount) + return nil + }) + require.NoError(t, err) + // Not all worker idxs are guaranteed to fire (fast paths may let one + // worker drain the whole channel), but every idx we observed must be + // within [0, workerCount). + for idx := range seenIdxs { + assert.GreaterOrEqual(t, idx, 0) + assert.Less(t, idx, workerCount) + } +} diff --git a/internal/impl/mysql/snapshot.go b/internal/impl/mysql/snapshot.go index a430c9ec32..e3f283e579 100644 --- a/internal/impl/mysql/snapshot.go +++ b/internal/impl/mysql/snapshot.go @@ -180,37 +180,42 @@ ORDER BY ORDINAL_POSITION return pks, nil } -func (s *Snapshot) querySnapshotTable(ctx context.Context, table string, pk []string, lastSeenPkVal *map[string]any, limit int) (*sql.Rows, error) { +func (s *Snapshot) querySnapshotTable(ctx context.Context, table string, pk []string, bounds *chunkBounds, lastSeenPkVal *map[string]any, limit int) (*sql.Rows, error) { snapshotQueryParts := []string{ "SELECT * FROM " + table, } - if lastSeenPkVal == nil { - snapshotQueryParts = append(snapshotQueryParts, buildOrderByClause(pk)) + var whereParts []string + var args []any - snapshotQueryParts = append(snapshotQueryParts, "LIMIT ?") - q := strings.Join(snapshotQueryParts, " ") - s.logger.Infof("Querying snapshot: %s", q) - return s.tx.QueryContext(ctx, strings.Join(snapshotQueryParts, " "), limit) + if chunkPred, chunkArgs := buildChunkPredicate(bounds); chunkPred != "" { + whereParts = append(whereParts, chunkPred) + args = append(args, chunkArgs...) 
} - var lastSeenPkVals []any - var placeholders []string - for _, pkCol := range pk { - val, ok := (*lastSeenPkVal)[pkCol] - if !ok { - return nil, fmt.Errorf("primary key column '%s' not found in last seen values", pkCol) + if lastSeenPkVal != nil { + var placeholders []string + for _, pkCol := range pk { + val, ok := (*lastSeenPkVal)[pkCol] + if !ok { + return nil, fmt.Errorf("primary key column '%s' not found in last seen values", pkCol) + } + args = append(args, val) + placeholders = append(placeholders, "?") } - lastSeenPkVals = append(lastSeenPkVals, val) - placeholders = append(placeholders, "?") + whereParts = append(whereParts, fmt.Sprintf("(%s) > (%s)", strings.Join(pk, ", "), strings.Join(placeholders, ", "))) } - snapshotQueryParts = append(snapshotQueryParts, fmt.Sprintf("WHERE (%s) > (%s)", strings.Join(pk, ", "), strings.Join(placeholders, ", "))) + if len(whereParts) > 0 { + snapshotQueryParts = append(snapshotQueryParts, "WHERE "+strings.Join(whereParts, " AND ")) + } snapshotQueryParts = append(snapshotQueryParts, buildOrderByClause(pk)) - snapshotQueryParts = append(snapshotQueryParts, fmt.Sprintf("LIMIT %d", limit)) + snapshotQueryParts = append(snapshotQueryParts, "LIMIT ?") + args = append(args, limit) + q := strings.Join(snapshotQueryParts, " ") s.logger.Infof("Querying snapshot: %s", q) - return s.tx.QueryContext(ctx, q, lastSeenPkVals...) + return s.tx.QueryContext(ctx, q, args...) } func buildOrderByClause(pk []string) string { diff --git a/internal/impl/mysql/snapshot_chunking.go b/internal/impl/mysql/snapshot_chunking.go new file mode 100644 index 0000000000..969e5d280d --- /dev/null +++ b/internal/impl/mysql/snapshot_chunking.go @@ -0,0 +1,217 @@ +// Copyright 2024 Redpanda Data, Inc. +// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md + +package mysql + +import ( + "context" + "database/sql" + "fmt" + "strings" +) + +// chunkBounds is a half-open range [lowerIncl, upperExcl) on the first column +// of a table's primary key. A nil lowerIncl means unbounded below; a nil +// upperExcl means unbounded above. Combined with the existing keyset +// pagination in querySnapshotTable, a chunkBounds partitions one table's +// rows across multiple workers with neither overlap nor gap. +type chunkBounds struct { + firstPKCol string + lowerIncl any + upperExcl any +} + +// snapshotWorkUnit is one unit of work dispatched to a snapshot worker. Every +// table produces at least one unit: either a whole-table unit (bounds == nil) +// or multiple chunked units covering the table's primary-key space. +type snapshotWorkUnit struct { + table string + bounds *chunkBounds +} + +// numericPKDataTypes is the set of MySQL DATA_TYPE tokens for which snapshot +// chunking is supported. Covers the integer family, signed and unsigned (the +// DATA_TYPE column does not distinguish the two — both appear as e.g. "int"). +// Tables whose first PK column is outside this set fall back to a single +// whole-table read. +var numericPKDataTypes = map[string]struct{}{ + "tinyint": {}, + "smallint": {}, + "mediumint": {}, + "int": {}, + "integer": {}, + "bigint": {}, +} + +// planSnapshotWork turns a table list into a work-unit list. For each table: +// +// - chunksPerTable <= 1: emit one whole-table unit (no MIN/MAX query). 
+// - First PK column is a supported integer type: compute MIN/MAX under the +// planner's consistent-snapshot transaction and split into chunksPerTable +// equal ranges. +// - Otherwise: emit one whole-table unit and log the fallback reason. +// +// The planner argument must hold an open consistent-snapshot transaction; all +// metadata/MIN/MAX queries run inside it so the boundaries agree with the +// state every worker observes (all workers were opened under the same FLUSH +// TABLES WITH READ LOCK window). +// +// For composite primary keys only the first column is used for chunking. This +// is efficient when the first column is the clustering prefix (the common +// shape for composite PKs that start with a tenant/shard id or a time bucket) +// and trivially correct for single-column numeric PKs. Skewed first-column +// distributions will cause uneven chunk sizes; operators who hit that pattern +// can leave chunks_per_table at 1 and rely on table-level parallelism alone. +func planSnapshotWork( + ctx context.Context, + planner *Snapshot, + tables []string, + chunksPerTable int, +) ([]snapshotWorkUnit, error) { + if chunksPerTable < 1 { + chunksPerTable = 1 + } + + units := make([]snapshotWorkUnit, 0, len(tables)) + for _, table := range tables { + if chunksPerTable == 1 { + units = append(units, snapshotWorkUnit{table: table}) + continue + } + + pks, err := planner.getTablePrimaryKeys(ctx, table) + if err != nil { + return nil, fmt.Errorf("chunk planning for %s: %w", table, err) + } + firstPK := pks[0] + + numeric, err := isNumericPKColumn(ctx, planner, table, firstPK) + if err != nil { + return nil, fmt.Errorf("inspect PK type for %s.%s: %w", table, firstPK, err) + } + if !numeric { + planner.logger.Infof( + "Snapshot chunking disabled for table %s: first PK column %s is non-numeric; reading as a single unit", + table, firstPK) + units = append(units, snapshotWorkUnit{table: table}) + continue + } + + lo, hi, empty, err := tableIntBounds(ctx, planner, table, firstPK) + if err != nil { + return nil, fmt.Errorf("compute MIN/MAX for %s.%s: %w", table, firstPK, err) + } + if empty { + units = append(units, snapshotWorkUnit{table: table}) + continue + } + + for _, r := range splitIntRange(lo, hi, chunksPerTable) { + units = append(units, snapshotWorkUnit{ + table: table, + bounds: &chunkBounds{ + firstPKCol: firstPK, + lowerIncl: r.lo, + upperExcl: r.hi, + }, + }) + } + } + return units, nil +} + +func isNumericPKColumn(ctx context.Context, s *Snapshot, table, column string) (bool, error) { + const q = ` +SELECT DATA_TYPE +FROM INFORMATION_SCHEMA.COLUMNS +WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = ? AND COLUMN_NAME = ? +` + var dt string + if err := s.tx.QueryRowContext(ctx, q, table, column).Scan(&dt); err != nil { + return false, err + } + _, ok := numericPKDataTypes[strings.ToLower(dt)] + return ok, nil +} + +// tableIntBounds returns MIN(col), MAX(col) for an integer PK column under +// the snapshot transaction. empty == true when the table has no rows (MIN +// and MAX return NULL). +func tableIntBounds(ctx context.Context, s *Snapshot, table, column string) (lo, hi int64, empty bool, err error) { + q := fmt.Sprintf("SELECT MIN(`%s`), MAX(`%s`) FROM `%s`", column, column, table) + var loN, hiN sql.NullInt64 + if err := s.tx.QueryRowContext(ctx, q).Scan(&loN, &hiN); err != nil { + return 0, 0, false, err + } + if !loN.Valid || !hiN.Valid { + return 0, 0, true, nil + } + return loN.Int64, hiN.Int64, false, nil +} + +// intRange is a planner-internal half-open chunk range. 
lo == nil leaves the +// first chunk unbounded below; hi == nil leaves the last chunk unbounded +// above. Open-ended outer chunks ensure rows at or near MIN/MAX are not lost +// to off-by-one errors and that any row surviving outside [MIN, MAX] under +// the snapshot is still picked up rather than silently dropped. +type intRange struct { + lo any + hi any +} + +// splitIntRange splits [lo, hi] into n half-open chunks. The outermost chunks +// use nil bounds so that rows at the exact MIN/MAX endpoints are captured and +// so that the caller does not need to special-case inclusive-vs-exclusive +// endpoints when binding parameters. Every integer in [lo, hi] falls into +// exactly one chunk. +func splitIntRange(lo, hi int64, n int) []intRange { + if n <= 1 || hi <= lo { + return []intRange{{lo: nil, hi: nil}} + } + span := uint64(hi - lo) + step := span / uint64(n) + if step == 0 { + step = 1 + } + + out := make([]intRange, 0, n) + for i := 0; i < n; i++ { + var loV, hiV any + if i > 0 { + loV = lo + int64(step*uint64(i)) + } + if i < n-1 { + hiV = lo + int64(step*uint64(i+1)) + } + out = append(out, intRange{lo: loV, hi: hiV}) + } + return out +} + +// buildChunkPredicate returns a SQL fragment bounding the first PK column +// and the values to bind. Returns ("", nil) for a nil or open-ended bounds +// argument — the caller should omit a WHERE clause in that case. +func buildChunkPredicate(b *chunkBounds) (string, []any) { + if b == nil { + return "", nil + } + var parts []string + var args []any + if b.lowerIncl != nil { + parts = append(parts, fmt.Sprintf("`%s` >= ?", b.firstPKCol)) + args = append(args, b.lowerIncl) + } + if b.upperExcl != nil { + parts = append(parts, fmt.Sprintf("`%s` < ?", b.firstPKCol)) + args = append(args, b.upperExcl) + } + if len(parts) == 0 { + return "", nil + } + return strings.Join(parts, " AND "), args +} diff --git a/internal/impl/mysql/snapshot_chunking_test.go b/internal/impl/mysql/snapshot_chunking_test.go new file mode 100644 index 0000000000..b04aeca195 --- /dev/null +++ b/internal/impl/mysql/snapshot_chunking_test.go @@ -0,0 +1,183 @@ +// Copyright 2024 Redpanda Data, Inc. +// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md + +package mysql + +import ( + "context" + "fmt" + "math" + "sync" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// splitIntRange is the pure chunking math. These tests lock down the +// partitioning invariants that the planner and the SQL predicate builder +// both depend on. +func TestSplitIntRange_SingleChunkWhenNLEOne(t *testing.T) { + for _, n := range []int{0, 1, -3} { + t.Run(fmt.Sprintf("n=%d", n), func(t *testing.T) { + got := splitIntRange(0, 100, n) + require.Len(t, got, 1) + assert.Nil(t, got[0].lo, "single chunk must be unbounded below") + assert.Nil(t, got[0].hi, "single chunk must be unbounded above") + }) + } +} + +func TestSplitIntRange_SingleChunkWhenRangeCollapsed(t *testing.T) { + // lo == hi (1 row) and lo > hi (empty / reversed) both degenerate to a + // single unbounded chunk so the worker sees the whole table (possibly + // empty) without the planner emitting a no-op chunk. 
+	for _, tc := range []struct{ lo, hi int64 }{
+		{lo: 5, hi: 5},
+		{lo: 10, hi: 3},
+	} {
+		t.Run(fmt.Sprintf("lo=%d,hi=%d", tc.lo, tc.hi), func(t *testing.T) {
+			got := splitIntRange(tc.lo, tc.hi, 4)
+			require.Len(t, got, 1)
+			assert.Nil(t, got[0].lo)
+			assert.Nil(t, got[0].hi)
+		})
+	}
+}
+
+func TestSplitIntRange_OutermostChunksAreOpenEnded(t *testing.T) {
+	// The first chunk must have no lower bound and the last chunk must have
+	// no upper bound. This guarantees every row in [MIN, MAX] is covered
+	// regardless of endpoint-inclusion decisions and that any row that
+	// somehow exists outside [MIN, MAX] is still read (not skipped).
+	got := splitIntRange(0, 100, 4)
+	require.Len(t, got, 4)
+	assert.Nil(t, got[0].lo, "first chunk must be unbounded below")
+	assert.NotNil(t, got[0].hi)
+	assert.NotNil(t, got[len(got)-1].lo)
+	assert.Nil(t, got[len(got)-1].hi, "last chunk must be unbounded above")
+}
+
+func TestSplitIntRange_ChunksCoverAllIntegersExactlyOnce(t *testing.T) {
+	// Enumerate every integer in the range and confirm that each one belongs
+	// to exactly one chunk under the half-open [lo, hi) semantics the SQL
+	// predicate builder emits.
+	lo, hi := int64(0), int64(50)
+	for _, n := range []int{2, 3, 5, 7, 10, 16} {
+		t.Run(fmt.Sprintf("n=%d", n), func(t *testing.T) {
+			got := splitIntRange(lo, hi, n)
+			require.NotEmpty(t, got)
+			for v := lo; v <= hi; v++ {
+				covers := 0
+				for _, c := range got {
+					lower := c.lo == nil || v >= c.lo.(int64)
+					upper := c.hi == nil || v < c.hi.(int64)
+					if lower && upper {
+						covers++
+					}
+				}
+				assert.Equal(t, 1, covers, "value %d must belong to exactly one chunk", v)
+			}
+		})
+	}
+}
+
+func TestSplitIntRange_WhenNExceedsSpanStepIsAtLeastOne(t *testing.T) {
+	// [0, 3] asked for 10 chunks — span < n. The implementation floors step
+	// to 1, which keeps the chunk boundaries strictly increasing: chunks
+	// stay disjoint, and the surplus inner chunks are simply empty of pk
+	// values. Coverage (every row visited at least once) is what we lock
+	// down here.
+	got := splitIntRange(0, 3, 10)
+	require.NotEmpty(t, got)
+	for v := int64(0); v <= 3; v++ {
+		covers := 0
+		for _, c := range got {
+			lower := c.lo == nil || v >= c.lo.(int64)
+			upper := c.hi == nil || v < c.hi.(int64)
+			if lower && upper {
+				covers++
+			}
+		}
+		assert.GreaterOrEqual(t, covers, 1, "value %d must be covered by at least one chunk", v)
+	}
+}
+
+func TestSplitIntRange_LargeSpanDoesNotOverflow(t *testing.T) {
+	// A span close to math.MaxInt64 must not overflow during step
+	// computation — the step arithmetic is carried out in uint64 to guard
+	// against that.
+	got := splitIntRange(math.MinInt64/2, math.MaxInt64/2, 8)
+	require.Len(t, got, 8)
+	assert.Nil(t, got[0].lo)
+	assert.Nil(t, got[len(got)-1].hi)
+}
+
+// buildChunkPredicate translates chunkBounds to a SQL fragment. These tests
+// pin the shape of that fragment so changes to the query surface are obvious.
+func TestBuildChunkPredicate_NilReturnsEmpty(t *testing.T) {
+	frag, args := buildChunkPredicate(nil)
+	assert.Empty(t, frag)
+	assert.Nil(t, args)
+}
+
+func TestBuildChunkPredicate_BothBoundsPresent(t *testing.T) {
+	frag, args := buildChunkPredicate(&chunkBounds{firstPKCol: "id", lowerIncl: int64(10), upperExcl: int64(20)})
+	assert.Equal(t, "`id` >= ? 
AND `id` < ?", frag) + assert.Equal(t, []any{int64(10), int64(20)}, args) +} + +func TestBuildChunkPredicate_OnlyLowerBound(t *testing.T) { + frag, args := buildChunkPredicate(&chunkBounds{firstPKCol: "id", lowerIncl: int64(10)}) + assert.Equal(t, "`id` >= ?", frag) + assert.Equal(t, []any{int64(10)}, args) +} + +func TestBuildChunkPredicate_OnlyUpperBound(t *testing.T) { + frag, args := buildChunkPredicate(&chunkBounds{firstPKCol: "id", upperExcl: int64(20)}) + assert.Equal(t, "`id` < ?", frag) + assert.Equal(t, []any{int64(20)}, args) +} + +func TestBuildChunkPredicate_OpenEndedBothSidesReturnsEmpty(t *testing.T) { + // An "all open" chunk (both bounds nil) degenerates to no predicate — + // the caller omits the WHERE clause entirely. + frag, args := buildChunkPredicate(&chunkBounds{firstPKCol: "id"}) + assert.Empty(t, frag) + assert.Nil(t, args) +} + +// distributeWorkToWorkers was generalised from the table-string signature to +// a generic one so work units can share the same fan-out code path. Confirm +// the generic instantiation works for snapshotWorkUnit values. +func TestDistributeWorkToWorkers_SnapshotWorkUnitInstantiation(t *testing.T) { + units := []snapshotWorkUnit{ + {table: "a"}, + {table: "b", bounds: &chunkBounds{firstPKCol: "id", upperExcl: int64(100)}}, + {table: "b", bounds: &chunkBounds{firstPKCol: "id", lowerIncl: int64(100)}}, + } + + var mu sync.Mutex + var visited []snapshotWorkUnit + var workerIdxMax atomic.Int32 + + err := distributeWorkToWorkers(t.Context(), units, 2, func(_ context.Context, idx int, u snapshotWorkUnit) error { + mu.Lock() + visited = append(visited, u) + mu.Unlock() + for { + cur := workerIdxMax.Load() + if int32(idx) <= cur || workerIdxMax.CompareAndSwap(cur, int32(idx)) { + break + } + } + return nil + }) + require.NoError(t, err) + assert.Len(t, visited, len(units), "every work unit must be visited exactly once") + assert.LessOrEqual(t, int(workerIdxMax.Load()), 1, "worker idx must stay within [0, workerCount)") +}