diff --git a/pkg/snapshot/generator/postgres/data/pg_snapshot_generator.go b/pkg/snapshot/generator/postgres/data/pg_snapshot_generator.go index 0a4791c6..43e6cf92 100644 --- a/pkg/snapshot/generator/postgres/data/pg_snapshot_generator.go +++ b/pkg/snapshot/generator/postgres/data/pg_snapshot_generator.go @@ -136,7 +136,13 @@ func (sg *SnapshotGenerator) CreateSnapshot(ctx context.Context, ss *snapshot.Sn defer func() { // make sure we close the processor once the snapshot is completed. // It will wait until all rows are processed before returning. - sg.processor.Close() + if closeErr := sg.processor.Close(); closeErr != nil { + if err == nil { + err = closeErr + } else { + err = errors.Join(err, closeErr) + } + } }() // parallelise the snapshot creation for each schema as configured by the snapshot workers. diff --git a/pkg/snapshot/generator/postgres/data/pg_snapshot_generator_test.go b/pkg/snapshot/generator/postgres/data/pg_snapshot_generator_test.go index ed887348..24714971 100644 --- a/pkg/snapshot/generator/postgres/data/pg_snapshot_generator_test.go +++ b/pkg/snapshot/generator/postgres/data/pg_snapshot_generator_test.go @@ -108,11 +108,12 @@ func TestSnapshotGenerator_CreateSnapshot(t *testing.T) { errTest := errors.New("oh noes") tests := []struct { - name string - querier pglib.Querier - snapshot *snapshot.Snapshot - schemaWorkers uint - progressBar *progressmocks.Bar + name string + querier pglib.Querier + snapshot *snapshot.Snapshot + schemaWorkers uint + progressBar *progressmocks.Bar + processorCloseErr error wantEvents []*wal.Event wantErr error @@ -181,6 +182,66 @@ func TestSnapshotGenerator_CreateSnapshot(t *testing.T) { wantErr: nil, wantEvents: []*wal.Event{testEvent(testTable1, testColumns)}, }, + { + name: "error - closing processor", + querier: &pgmocks.Querier{ + ExecInTxWithOptionsFn: func(_ context.Context, i uint, f func(tx pglib.Tx) error, to pglib.TxOptions) error { + require.Equal(t, txOptions, to) + switch i { + case 1: + mockTx := pgmocks.Tx{ + QueryRowFn: func(_ context.Context, dest []any, query string, args ...any) error { + require.Equal(t, exportSnapshotQuery, query) + snapshotID, ok := dest[0].(*string) + require.True(t, ok) + *snapshotID = testSnapshotID + return nil + }, + } + return f(&mockTx) + case 2: + mockTx := pgmocks.Tx{ + ExecFn: func(ctx context.Context, _ uint, query string, args ...any) (pglib.CommandTag, error) { + require.Equal(t, fmt.Sprintf("SET TRANSACTION SNAPSHOT '%s'", testSnapshotID), query) + return pglib.CommandTag{}, nil + }, + QueryRowFn: validTableInfoQueryRowFn, + } + return f(&mockTx) + case 3: + mockTx := pgmocks.Tx{ + ExecFn: func(ctx context.Context, _ uint, query string, args ...any) (pglib.CommandTag, error) { + require.Equal(t, fmt.Sprintf("SET TRANSACTION SNAPSHOT '%s'", testSnapshotID), query) + return pglib.CommandTag{}, nil + }, + QueryFn: func(ctx context.Context, query string, args ...any) (pglib.Rows, error) { + require.Equal(t, fmt.Sprintf(pageRangeQuery, quotedSchemaTable1, 0, 1), query) + return &pgmocks.Rows{ + CloseFn: func() {}, + NextFn: func(i uint) bool { return i == 1 }, + FieldDescriptionsFn: func() []pgconn.FieldDescription { + return []pgconn.FieldDescription{ + {Name: "id", DataTypeOID: pgtype.UUIDOID}, + {Name: "name", DataTypeOID: pgtype.TextOID}, + } + }, + ValuesFn: func() ([]any, error) { + return []any{testUUID, "alice"}, nil + }, + ErrFn: func() error { return nil }, + }, nil + }, + } + return f(&mockTx) + default: + return fmt.Errorf("unexpected call to ExecInTxWithOptions: %d", i) + } + }, + }, + processorCloseErr: errTest, + wantErr: errTest, + wantEvents: []*wal.Event{testEvent(testTable1, testColumns)}, + }, { name: "ok - quoted identifiers in schema and table names", querier: &pgmocks.Querier{ @@ -1082,6 +1143,9 @@ func TestSnapshotGenerator_CreateSnapshot(t *testing.T) { eventChan <- e return nil }, + CloseFn: func() error { + return tc.processorCloseErr + }, }, schemaWorkers: 1, tableWorkers: 1, diff --git a/pkg/snapshot/generator/postgres/schema/pgdumprestore/snapshot_pg_dump_restore_generator.go b/pkg/snapshot/generator/postgres/schema/pgdumprestore/snapshot_pg_dump_restore_generator.go index 74443c55..46e128a4 100644 --- a/pkg/snapshot/generator/postgres/schema/pgdumprestore/snapshot_pg_dump_restore_generator.go +++ b/pkg/snapshot/generator/postgres/schema/pgdumprestore/snapshot_pg_dump_restore_generator.go @@ -39,6 +39,11 @@ type SnapshotGenerator struct { roleSQLParser *roleSQLParser optionGenerator *optionGenerator snapshotTracker snapshotProgressTracker + // restoreIndicesAndConstraintsBeforeData restores constraints/indexes that can + // be used as INSERT ... ON CONFLICT targets before the wrapped data snapshot + // generator runs. Other indexes and constraints, such as foreign keys, are + // still restored after data is inserted. + restoreIndicesAndConstraintsBeforeData bool } type snapshotProgressTracker interface { @@ -67,6 +72,11 @@ type Config struct { DumpDebugFile string // if set, security label providers that will be excluded from the dump ExcludedSecurityLabels []string + // Restore constraints/indexes that can be used as INSERT ... ON CONFLICT + // targets before data is snapshotted. This is required when the data snapshot + // writer emits INSERT ... ON CONFLICT DO UPDATE, because the target table must + // already have a matching unique or primary key constraint. + RestoreIndicesAndConstraintsBeforeData bool } type Option func(s *SnapshotGenerator) @@ -96,17 +106,18 @@ func NewSnapshotGenerator(ctx context.Context, c *Config, opts ...Option) (*Snap } sg := &SnapshotGenerator{ - sourceURL: c.SourcePGURL, - targetURL: c.TargetPGURL, - pgDumpFn: pglib.RunPGDump, - pgDumpAllFn: pglib.RunPGDumpAll, - pgRestoreFn: pglib.RunPGRestore, - logger: loglib.NewNoopLogger(), - dumpDebugFile: c.DumpDebugFile, - excludedSecurityLabels: c.ExcludedSecurityLabels, - roleSQLParser: &roleSQLParser{}, - sourceQuerier: sourceConnPool, - optionGenerator: newOptionGenerator(sourceConnPool, c), + sourceURL: c.SourcePGURL, + targetURL: c.TargetPGURL, + pgDumpFn: pglib.RunPGDump, + pgDumpAllFn: pglib.RunPGDumpAll, + pgRestoreFn: pglib.RunPGRestore, + logger: loglib.NewNoopLogger(), + dumpDebugFile: c.DumpDebugFile, + excludedSecurityLabels: c.ExcludedSecurityLabels, + restoreIndicesAndConstraintsBeforeData: c.RestoreIndicesAndConstraintsBeforeData, + roleSQLParser: &roleSQLParser{}, + sourceQuerier: sourceConnPool, + optionGenerator: newOptionGenerator(sourceConnPool, c), } for _, opt := range opts { @@ -225,8 +236,19 @@ func (s *SnapshotGenerator) CreateSnapshot(ctx context.Context, ss *snapshot.Sna return err } + indicesAndConstraintsDump := dump.indicesAndConstraints + if s.generator != nil && s.restoreIndicesAndConstraintsBeforeData { + conflictTargets, remaining := splitConflictTargetConstraints(dump.indicesAndConstraints) + if err := s.restoreIndicesAndConstraints(ctx, conflictTargets, ss); err != nil { + return err + } + indicesAndConstraintsDump = remaining + } + // call the wrapped snapshot generator if any before restoring sequences, - // indices and constraints to improve performance. + // indices and constraints to improve performance. When the data snapshot + // writer emits INSERT ... ON CONFLICT DO UPDATE, the subset of constraints + // needed as conflict targets is restored before data above. if s.generator != nil { if err := s.generator.CreateSnapshot(ctx, ss); err != nil { return err @@ -239,12 +261,7 @@ func (s *SnapshotGenerator) CreateSnapshot(ctx context.Context, ss *snapshot.Sna return err } - s.logger.Info("restoring schema indices and constraints", loglib.Fields{"schemaTables": ss.SchemaTables}) - if s.snapshotTracker != nil { - if err := s.restoreIndicesWithTracking(ctx, dump.indicesAndConstraints); err != nil { - return err - } - } else if err := s.restoreDump(ctx, dump.indicesAndConstraints); err != nil { + if err := s.restoreIndicesAndConstraints(ctx, indicesAndConstraintsDump, ss); err != nil { return err } @@ -252,6 +269,63 @@ func (s *SnapshotGenerator) CreateSnapshot(ctx context.Context, ss *snapshot.Sna return s.restoreDump(ctx, dump.views) } +func splitConflictTargetConstraints(d []byte) ([]byte, []byte) { + blocks := strings.Split(string(d), "\n\n") + connectBlocks := []string{} + conflictTargetBlocks := []string{} + remainingBlocks := []string{} + + for _, block := range blocks { + block = strings.TrimSpace(block) + if block == "" { + continue + } + switch { + case strings.Contains(block, `\connect`): + connectBlocks = append(connectBlocks, block) + case isConflictTargetConstraint(block): + conflictTargetBlocks = append(conflictTargetBlocks, block) + default: + remainingBlocks = append(remainingBlocks, block) + } + } + + return joinDumpBlocks(connectBlocks, conflictTargetBlocks), joinDumpBlocks(connectBlocks, remainingBlocks) +} + +func joinDumpBlocks(connectBlocks, blocks []string) []byte { + if len(blocks) == 0 { + return nil + } + allBlocks := make([]string, 0, len(connectBlocks)+len(blocks)) + allBlocks = append(allBlocks, connectBlocks...) + allBlocks = append(allBlocks, blocks...) + return []byte(strings.Join(allBlocks, "\n\n") + "\n\n") +} + +func isConflictTargetConstraint(block string) bool { + upperBlock := strings.ToUpper(block) + if strings.HasPrefix(upperBlock, "CREATE UNIQUE INDEX") { + return true + } + if !strings.Contains(upperBlock, "ADD CONSTRAINT") { + return false + } + return strings.Contains(upperBlock, " PRIMARY KEY (") || + strings.Contains(upperBlock, " PRIMARY KEY USING INDEX ") || + strings.Contains(upperBlock, " UNIQUE (") || + strings.Contains(upperBlock, " UNIQUE NULLS NOT DISTINCT (") || + strings.Contains(upperBlock, " UNIQUE USING INDEX ") +} + +func (s *SnapshotGenerator) restoreIndicesAndConstraints(ctx context.Context, dump []byte, ss *snapshot.Snapshot) error { + s.logger.Info("restoring schema indices and constraints", loglib.Fields{"schemaTables": ss.SchemaTables}) + if s.snapshotTracker != nil { + return s.restoreIndicesWithTracking(ctx, dump) + } + return s.restoreDump(ctx, dump) +} + func (s *SnapshotGenerator) Close() error { if s.generator != nil { if err := s.generator.Close(); err != nil { diff --git a/pkg/snapshot/generator/postgres/schema/pgdumprestore/snapshot_pg_dump_restore_generator_test.go b/pkg/snapshot/generator/postgres/schema/pgdumprestore/snapshot_pg_dump_restore_generator_test.go index 66bc68c6..391deede 100644 --- a/pkg/snapshot/generator/postgres/schema/pgdumprestore/snapshot_pg_dump_restore_generator_test.go +++ b/pkg/snapshot/generator/postgres/schema/pgdumprestore/snapshot_pg_dump_restore_generator_test.go @@ -1271,6 +1271,109 @@ func TestSnapshotGenerator_CreateSnapshot(t *testing.T) { } } +func TestSnapshotGenerator_RestoresConstraintsBeforeDataWhenConfigured(t *testing.T) { + t.Parallel() + + schemaDump := []byte(`CREATE TABLE public.test_table ( + id integer NOT NULL, + parent_id integer, + value text NOT NULL +); + +ALTER TABLE ONLY public.test_table + ADD CONSTRAINT test_table_pkey PRIMARY KEY (id); + +ALTER TABLE ONLY public.test_table + ADD CONSTRAINT test_table_value_key UNIQUE (value); + +ALTER TABLE ONLY public.test_table + ADD CONSTRAINT "test_table UNIQUE named foreign key" FOREIGN KEY (parent_id) REFERENCES public.test_table(id); + +CREATE INDEX test_table_value_idx ON public.test_table USING btree (value); +`) + filteredDump := []byte(`CREATE TABLE public.test_table ( + id integer NOT NULL, + parent_id integer, + value text NOT NULL +); + +`) + conflictTargetDump := []byte(`ALTER TABLE ONLY public.test_table + ADD CONSTRAINT test_table_pkey PRIMARY KEY (id); + +ALTER TABLE ONLY public.test_table + ADD CONSTRAINT test_table_value_key UNIQUE (value); + +`) + remainingConstraintsDump := []byte(`ALTER TABLE ONLY public.test_table + ADD CONSTRAINT "test_table UNIQUE named foreign key" FOREIGN KEY (parent_id) REFERENCES public.test_table(id); + +CREATE INDEX test_table_value_idx ON public.test_table USING btree (value); + +`) + + calls := []string{} + conn := &mocks.Querier{ + ExecFn: func(ctx context.Context, i uint, query string, args ...any) (pglib.CommandTag, error) { + return pglib.CommandTag{}, nil + }, + QueryFn: func(ctx context.Context, i uint, query string, args ...any) (pglib.Rows, error) { + return &mocks.Rows{ + CloseFn: func() {}, + NextFn: func(i uint) bool { return false }, + ErrFn: func() error { return nil }, + }, nil + }, + } + + sg := SnapshotGenerator{ + sourceURL: "source-url", + targetURL: "target-url", + sourceQuerier: conn, + pgDumpFn: newMockPgdump(func(_ context.Context, i uint, po pglib.PGDumpOptions) ([]byte, error) { + require.Equal(t, uint(1), i) + return schemaDump, nil + }), + pgRestoreFn: newMockPgrestore(func(_ context.Context, i uint, po pglib.PGRestoreOptions, dump []byte) (string, error) { + switch strings.TrimSpace(string(dump)) { + case strings.TrimSpace(string(filteredDump)): + calls = append(calls, "schema") + case strings.TrimSpace(string(conflictTargetDump)): + calls = append(calls, "conflict targets") + case strings.TrimSpace(string(remainingConstraintsDump)): + calls = append(calls, "remaining constraints") + default: + require.Failf(t, "unexpected dump", "%q", string(dump)) + } + return "", nil + }), + logger: log.NewNoopLogger(), + generator: &generatormocks.Generator{ + CreateSnapshotFn: func(ctx context.Context, snapshot *snapshot.Snapshot) error { + calls = append(calls, "data") + return nil + }, + }, + roleSQLParser: &roleSQLParser{}, + optionGenerator: &optionGenerator{ + sourceURL: "source-url", + targetURL: "target-url", + noOwner: true, + rolesSnapshotMode: roleSnapshotDisabled, + querier: conn, + }, + restoreIndicesAndConstraintsBeforeData: true, + } + + err := sg.CreateSnapshot(context.Background(), &snapshot.Snapshot{ + SchemaTables: map[string][]string{ + publicSchema: {"test_table"}, + }, + }) + require.NoError(t, err) + require.Equal(t, []string{"schema", "conflict targets", "data", "remaining constraints"}, calls) +} + func TestSnapshotGenerator_parseDump(t *testing.T) { t.Parallel() diff --git a/pkg/stream/snapshot_config_test.go b/pkg/stream/snapshot_config_test.go new file mode 100644 index 00000000..43a06e25 --- /dev/null +++ b/pkg/stream/snapshot_config_test.go @@ -0,0 +1,76 @@ +// SPDX-License-Identifier: Apache-2.0 + +package stream + +import ( + "testing" + + "github.com/stretchr/testify/require" + pgsnapshotgenerator "github.com/xataio/pgstream/pkg/snapshot/generator/postgres/data" + "github.com/xataio/pgstream/pkg/snapshot/generator/postgres/schema/pgdumprestore" + snapshotbuilder "github.com/xataio/pgstream/pkg/wal/listener/snapshot/builder" + pgwriter "github.com/xataio/pgstream/pkg/wal/processor/postgres" +) + +func TestPrepareSnapshotSchemaRestore(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + onConflictAction string + bulkIngest bool + + wantRestoreConstraintsBeforeData bool + }{ + { + name: "update with batch writer restores constraints before data", + onConflictAction: "update", + wantRestoreConstraintsBeforeData: true, + }, + { + name: "update with bulk ingest keeps default order", + onConflictAction: "update", + bulkIngest: true, + }, + { + name: "do nothing keeps default order", + onConflictAction: "nothing", + }, + { + name: "default error behavior keeps default order", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + config := &Config{ + Listener: ListenerConfig{ + Postgres: &PostgresListenerConfig{ + Snapshot: &snapshotbuilder.SnapshotListenerConfig{ + Data: &pgsnapshotgenerator.Config{}, + Schema: &snapshotbuilder.SchemaSnapshotConfig{ + DumpRestore: &pgdumprestore.Config{}, + }, + }, + }, + }, + Processor: ProcessorConfig{ + Postgres: &PostgresProcessorConfig{ + BatchWriter: pgwriter.Config{ + OnConflictAction: tc.onConflictAction, + BulkIngestEnabled: tc.bulkIngest, + }, + }, + }, + } + + prepareSnapshotSchemaRestore(config) + + require.Equal(t, + tc.wantRestoreConstraintsBeforeData, + config.Listener.Postgres.Snapshot.Schema.DumpRestore.RestoreIndicesAndConstraintsBeforeData) + }) + } +} diff --git a/pkg/stream/stream.go b/pkg/stream/stream.go index 357ed064..22a654e4 100644 --- a/pkg/stream/stream.go +++ b/pkg/stream/stream.go @@ -6,6 +6,7 @@ import ( "context" "errors" "fmt" + "strings" loglib "github.com/xataio/pgstream/pkg/log" "github.com/xataio/pgstream/pkg/otel" @@ -38,6 +39,19 @@ const ( processorTypeSnapshot ) +func prepareSnapshotSchemaRestore(config *Config) { + if config.Listener.Postgres == nil || config.Listener.Postgres.Snapshot == nil || + config.Listener.Postgres.Snapshot.Data == nil || config.Listener.Postgres.Snapshot.Schema == nil || + config.Listener.Postgres.Snapshot.Schema.DumpRestore == nil || config.Processor.Postgres == nil { + return + } + + batchWriterConfig := config.Processor.Postgres.BatchWriter + if !batchWriterConfig.BulkIngestEnabled && strings.EqualFold(batchWriterConfig.OnConflictAction, "update") { + config.Listener.Postgres.Snapshot.Schema.DumpRestore.RestoreIndicesAndConstraintsBeforeData = true + } +} + func buildProcessor(ctx context.Context, logger loglib.Logger, config *ProcessorConfig, checkpoint checkpointer.Checkpoint, processorType processorType, instrumentation *otel.Instrumentation) (processor.Processor, error) { var processor processor.Processor switch { diff --git a/pkg/stream/stream_run.go b/pkg/stream/stream_run.go index 0869bf04..149eb2f2 100644 --- a/pkg/stream/stream_run.go +++ b/pkg/stream/stream_run.go @@ -32,6 +32,7 @@ func Run(ctx context.Context, logger loglib.Logger, config *Config, init bool, i if err := config.IsValid(); err != nil { return fmt.Errorf("incompatible configuration: %w", err) } + prepareSnapshotSchemaRestore(config) if init { if err := Init(ctx, config.GetInitConfig()); err != nil { diff --git a/pkg/stream/stream_snapshot.go b/pkg/stream/stream_snapshot.go index 75604399..38173558 100644 --- a/pkg/stream/stream_snapshot.go +++ b/pkg/stream/stream_snapshot.go @@ -22,6 +22,7 @@ func Snapshot(ctx context.Context, logger loglib.Logger, config *Config, instrum if err := config.IsValid(); err != nil { return fmt.Errorf("incompatible configuration: %w", err) } + prepareSnapshotSchemaRestore(config) eg, ctx := errgroup.WithContext(ctx) diff --git a/pkg/wal/processor/mocks/mock_processor.go b/pkg/wal/processor/mocks/mock_processor.go index 5013bb96..9ad530aa 100644 --- a/pkg/wal/processor/mocks/mock_processor.go +++ b/pkg/wal/processor/mocks/mock_processor.go @@ -4,6 +4,7 @@ package mocks import ( "context" + "sync/atomic" "github.com/xataio/pgstream/pkg/wal" ) @@ -11,16 +12,16 @@ import ( type Processor struct { ProcessWALEventFn func(ctx context.Context, walEvent *wal.Event) error CloseFn func() error - processCalls uint + processCalls atomic.Uint64 } func (m *Processor) ProcessWALEvent(ctx context.Context, walEvent *wal.Event) error { - m.processCalls++ + m.processCalls.Add(1) return m.ProcessWALEventFn(ctx, walEvent) } func (m *Processor) GetProcessCalls() uint { - return m.processCalls + return uint(m.processCalls.Load()) } func (m *Processor) Close() error {