From 4e59bcb0e70ecad84ba6ff4be2764764925a34f7 Mon Sep 17 00:00:00 2001 From: blurskye Date: Sun, 10 May 2026 01:14:59 +0600 Subject: [PATCH 1/4] Fix snapshot conflict updates with restored constraints --- .../snapshot_pg_dump_restore_generator.go | 61 ++++++++++---- ...snapshot_pg_dump_restore_generator_test.go | 82 +++++++++++++++++++ pkg/stream/snapshot_config_test.go | 76 +++++++++++++++++ pkg/stream/stream.go | 14 ++++ pkg/stream/stream_run.go | 1 + pkg/stream/stream_snapshot.go | 1 + 6 files changed, 218 insertions(+), 17 deletions(-) create mode 100644 pkg/stream/snapshot_config_test.go diff --git a/pkg/snapshot/generator/postgres/schema/pgdumprestore/snapshot_pg_dump_restore_generator.go b/pkg/snapshot/generator/postgres/schema/pgdumprestore/snapshot_pg_dump_restore_generator.go index 74443c55..4f681a17 100644 --- a/pkg/snapshot/generator/postgres/schema/pgdumprestore/snapshot_pg_dump_restore_generator.go +++ b/pkg/snapshot/generator/postgres/schema/pgdumprestore/snapshot_pg_dump_restore_generator.go @@ -39,6 +39,12 @@ type SnapshotGenerator struct { roleSQLParser *roleSQLParser optionGenerator *optionGenerator snapshotTracker snapshotProgressTracker + // restoreIndicesAndConstraintsBeforeData restores primary keys, indexes and + // constraints before the wrapped data snapshot generator runs. This is + // needed when the data writer uses INSERT ... ON CONFLICT DO UPDATE, since + // Postgres requires the conflict target to have a matching unique or primary + // key constraint before data is inserted. + restoreIndicesAndConstraintsBeforeData bool } type snapshotProgressTracker interface { @@ -67,6 +73,11 @@ type Config struct { DumpDebugFile string // if set, security label providers that will be excluded from the dump ExcludedSecurityLabels []string + // Restore primary keys, indexes and constraints before data is snapshotted. + // This is required when the data snapshot writer emits INSERT ... ON + // CONFLICT DO UPDATE, because the target table must already have a matching + // unique or primary key constraint. + RestoreIndicesAndConstraintsBeforeData bool } type Option func(s *SnapshotGenerator) @@ -96,17 +107,18 @@ func NewSnapshotGenerator(ctx context.Context, c *Config, opts ...Option) (*Snap } sg := &SnapshotGenerator{ - sourceURL: c.SourcePGURL, - targetURL: c.TargetPGURL, - pgDumpFn: pglib.RunPGDump, - pgDumpAllFn: pglib.RunPGDumpAll, - pgRestoreFn: pglib.RunPGRestore, - logger: loglib.NewNoopLogger(), - dumpDebugFile: c.DumpDebugFile, - excludedSecurityLabels: c.ExcludedSecurityLabels, - roleSQLParser: &roleSQLParser{}, - sourceQuerier: sourceConnPool, - optionGenerator: newOptionGenerator(sourceConnPool, c), + sourceURL: c.SourcePGURL, + targetURL: c.TargetPGURL, + pgDumpFn: pglib.RunPGDump, + pgDumpAllFn: pglib.RunPGDumpAll, + pgRestoreFn: pglib.RunPGRestore, + logger: loglib.NewNoopLogger(), + dumpDebugFile: c.DumpDebugFile, + excludedSecurityLabels: c.ExcludedSecurityLabels, + restoreIndicesAndConstraintsBeforeData: c.RestoreIndicesAndConstraintsBeforeData, + roleSQLParser: &roleSQLParser{}, + sourceQuerier: sourceConnPool, + optionGenerator: newOptionGenerator(sourceConnPool, c), } for _, opt := range opts { @@ -225,8 +237,18 @@ func (s *SnapshotGenerator) CreateSnapshot(ctx context.Context, ss *snapshot.Sna return err } + indicesAndConstraintsRestored := false + if s.generator != nil && s.restoreIndicesAndConstraintsBeforeData { + if err := s.restoreIndicesAndConstraints(ctx, dump.indicesAndConstraints, ss); err != nil { + return err + } + indicesAndConstraintsRestored = true + } + // call the wrapped snapshot generator if any before restoring sequences, - // indices and constraints to improve performance. + // indices and constraints to improve performance. When the data snapshot + // writer emits INSERT ... ON CONFLICT DO UPDATE, constraints are restored + // before data above so conflict targets are valid. if s.generator != nil { if err := s.generator.CreateSnapshot(ctx, ss); err != nil { return err @@ -239,19 +261,24 @@ func (s *SnapshotGenerator) CreateSnapshot(ctx context.Context, ss *snapshot.Sna return err } - s.logger.Info("restoring schema indices and constraints", loglib.Fields{"schemaTables": ss.SchemaTables}) - if s.snapshotTracker != nil { - if err := s.restoreIndicesWithTracking(ctx, dump.indicesAndConstraints); err != nil { + if !indicesAndConstraintsRestored { + if err := s.restoreIndicesAndConstraints(ctx, dump.indicesAndConstraints, ss); err != nil { return err } - } else if err := s.restoreDump(ctx, dump.indicesAndConstraints); err != nil { - return err } s.logger.Info("restoring views") return s.restoreDump(ctx, dump.views) } +func (s *SnapshotGenerator) restoreIndicesAndConstraints(ctx context.Context, dump []byte, ss *snapshot.Snapshot) error { + s.logger.Info("restoring schema indices and constraints", loglib.Fields{"schemaTables": ss.SchemaTables}) + if s.snapshotTracker != nil { + return s.restoreIndicesWithTracking(ctx, dump) + } + return s.restoreDump(ctx, dump) +} + func (s *SnapshotGenerator) Close() error { if s.generator != nil { if err := s.generator.Close(); err != nil { diff --git a/pkg/snapshot/generator/postgres/schema/pgdumprestore/snapshot_pg_dump_restore_generator_test.go b/pkg/snapshot/generator/postgres/schema/pgdumprestore/snapshot_pg_dump_restore_generator_test.go index 66bc68c6..b84701f7 100644 --- a/pkg/snapshot/generator/postgres/schema/pgdumprestore/snapshot_pg_dump_restore_generator_test.go +++ b/pkg/snapshot/generator/postgres/schema/pgdumprestore/snapshot_pg_dump_restore_generator_test.go @@ -1271,6 +1271,88 @@ func TestSnapshotGenerator_CreateSnapshot(t *testing.T) { } } +func TestSnapshotGenerator_RestoresConstraintsBeforeDataWhenConfigured(t *testing.T) { + t.Parallel() + + schemaDump := []byte(`CREATE TABLE public.test_table ( + id integer NOT NULL, + value text NOT NULL +); + +ALTER TABLE ONLY public.test_table + ADD CONSTRAINT test_table_pkey PRIMARY KEY (id); +`) + filteredDump := []byte(`CREATE TABLE public.test_table ( + id integer NOT NULL, + value text NOT NULL +); + +`) + constraintDump := []byte(`ALTER TABLE ONLY public.test_table + ADD CONSTRAINT test_table_pkey PRIMARY KEY (id); + +`) + + calls := []string{} + conn := &mocks.Querier{ + ExecFn: func(ctx context.Context, i uint, query string, args ...any) (pglib.CommandTag, error) { + return pglib.CommandTag{}, nil + }, + QueryFn: func(ctx context.Context, i uint, query string, args ...any) (pglib.Rows, error) { + return &mocks.Rows{ + CloseFn: func() {}, + NextFn: func(i uint) bool { return false }, + ErrFn: func() error { return nil }, + }, nil + }, + } + + sg := SnapshotGenerator{ + sourceURL: "source-url", + targetURL: "target-url", + sourceQuerier: conn, + pgDumpFn: newMockPgdump(func(_ context.Context, i uint, po pglib.PGDumpOptions) ([]byte, error) { + require.Equal(t, uint(1), i) + return schemaDump, nil + }), + pgRestoreFn: newMockPgrestore(func(_ context.Context, i uint, po pglib.PGRestoreOptions, dump []byte) (string, error) { + switch string(dump) { + case string(filteredDump): + calls = append(calls, "schema") + case string(constraintDump): + calls = append(calls, "constraints") + default: + require.Failf(t, "unexpected dump", "%q", string(dump)) + } + return "", nil + }), + logger: log.NewNoopLogger(), + generator: &generatormocks.Generator{ + CreateSnapshotFn: func(ctx context.Context, snapshot *snapshot.Snapshot) error { + calls = append(calls, "data") + return nil + }, + }, + roleSQLParser: &roleSQLParser{}, + optionGenerator: &optionGenerator{ + sourceURL: "source-url", + targetURL: "target-url", + noOwner: true, + rolesSnapshotMode: roleSnapshotDisabled, + querier: conn, + }, + restoreIndicesAndConstraintsBeforeData: true, + } + + err := sg.CreateSnapshot(context.Background(), &snapshot.Snapshot{ + SchemaTables: map[string][]string{ + publicSchema: {"test_table"}, + }, + }) + require.NoError(t, err) + require.Equal(t, []string{"schema", "constraints", "data"}, calls) +} + func TestSnapshotGenerator_parseDump(t *testing.T) { t.Parallel() diff --git a/pkg/stream/snapshot_config_test.go b/pkg/stream/snapshot_config_test.go new file mode 100644 index 00000000..43a06e25 --- /dev/null +++ b/pkg/stream/snapshot_config_test.go @@ -0,0 +1,76 @@ +// SPDX-License-Identifier: Apache-2.0 + +package stream + +import ( + "testing" + + "github.com/stretchr/testify/require" + pgsnapshotgenerator "github.com/xataio/pgstream/pkg/snapshot/generator/postgres/data" + "github.com/xataio/pgstream/pkg/snapshot/generator/postgres/schema/pgdumprestore" + snapshotbuilder "github.com/xataio/pgstream/pkg/wal/listener/snapshot/builder" + pgwriter "github.com/xataio/pgstream/pkg/wal/processor/postgres" +) + +func TestPrepareSnapshotSchemaRestore(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + onConflictAction string + bulkIngest bool + + wantRestoreConstraintsBeforeData bool + }{ + { + name: "update with batch writer restores constraints before data", + onConflictAction: "update", + wantRestoreConstraintsBeforeData: true, + }, + { + name: "update with bulk ingest keeps default order", + onConflictAction: "update", + bulkIngest: true, + }, + { + name: "do nothing keeps default order", + onConflictAction: "nothing", + }, + { + name: "default error behavior keeps default order", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + config := &Config{ + Listener: ListenerConfig{ + Postgres: &PostgresListenerConfig{ + Snapshot: &snapshotbuilder.SnapshotListenerConfig{ + Data: &pgsnapshotgenerator.Config{}, + Schema: &snapshotbuilder.SchemaSnapshotConfig{ + DumpRestore: &pgdumprestore.Config{}, + }, + }, + }, + }, + Processor: ProcessorConfig{ + Postgres: &PostgresProcessorConfig{ + BatchWriter: pgwriter.Config{ + OnConflictAction: tc.onConflictAction, + BulkIngestEnabled: tc.bulkIngest, + }, + }, + }, + } + + prepareSnapshotSchemaRestore(config) + + require.Equal(t, + tc.wantRestoreConstraintsBeforeData, + config.Listener.Postgres.Snapshot.Schema.DumpRestore.RestoreIndicesAndConstraintsBeforeData) + }) + } +} diff --git a/pkg/stream/stream.go b/pkg/stream/stream.go index 357ed064..22a654e4 100644 --- a/pkg/stream/stream.go +++ b/pkg/stream/stream.go @@ -6,6 +6,7 @@ import ( "context" "errors" "fmt" + "strings" loglib "github.com/xataio/pgstream/pkg/log" "github.com/xataio/pgstream/pkg/otel" @@ -38,6 +39,19 @@ const ( processorTypeSnapshot ) +func prepareSnapshotSchemaRestore(config *Config) { + if config.Listener.Postgres == nil || config.Listener.Postgres.Snapshot == nil || + config.Listener.Postgres.Snapshot.Data == nil || config.Listener.Postgres.Snapshot.Schema == nil || + config.Listener.Postgres.Snapshot.Schema.DumpRestore == nil || config.Processor.Postgres == nil { + return + } + + batchWriterConfig := config.Processor.Postgres.BatchWriter + if !batchWriterConfig.BulkIngestEnabled && strings.EqualFold(batchWriterConfig.OnConflictAction, "update") { + config.Listener.Postgres.Snapshot.Schema.DumpRestore.RestoreIndicesAndConstraintsBeforeData = true + } +} + func buildProcessor(ctx context.Context, logger loglib.Logger, config *ProcessorConfig, checkpoint checkpointer.Checkpoint, processorType processorType, instrumentation *otel.Instrumentation) (processor.Processor, error) { var processor processor.Processor switch { diff --git a/pkg/stream/stream_run.go b/pkg/stream/stream_run.go index 0869bf04..149eb2f2 100644 --- a/pkg/stream/stream_run.go +++ b/pkg/stream/stream_run.go @@ -32,6 +32,7 @@ func Run(ctx context.Context, logger loglib.Logger, config *Config, init bool, i if err := config.IsValid(); err != nil { return fmt.Errorf("incompatible configuration: %w", err) } + prepareSnapshotSchemaRestore(config) if init { if err := Init(ctx, config.GetInitConfig()); err != nil { diff --git a/pkg/stream/stream_snapshot.go b/pkg/stream/stream_snapshot.go index 75604399..38173558 100644 --- a/pkg/stream/stream_snapshot.go +++ b/pkg/stream/stream_snapshot.go @@ -22,6 +22,7 @@ func Snapshot(ctx context.Context, logger loglib.Logger, config *Config, instrum if err := config.IsValid(); err != nil { return fmt.Errorf("incompatible configuration: %w", err) } + prepareSnapshotSchemaRestore(config) eg, ctx := errgroup.WithContext(ctx) From 5e791506d115a1f9e55e204aaa531490b027b24c Mon Sep 17 00:00:00 2001 From: blurskye Date: Sun, 10 May 2026 03:24:23 +0600 Subject: [PATCH 2/4] Refine snapshot constraint restore ordering --- .../snapshot_pg_dump_restore_generator.go | 77 ++++++++++++++----- ...snapshot_pg_dump_restore_generator_test.go | 27 +++++-- 2 files changed, 80 insertions(+), 24 deletions(-) diff --git a/pkg/snapshot/generator/postgres/schema/pgdumprestore/snapshot_pg_dump_restore_generator.go b/pkg/snapshot/generator/postgres/schema/pgdumprestore/snapshot_pg_dump_restore_generator.go index 4f681a17..47177c3f 100644 --- a/pkg/snapshot/generator/postgres/schema/pgdumprestore/snapshot_pg_dump_restore_generator.go +++ b/pkg/snapshot/generator/postgres/schema/pgdumprestore/snapshot_pg_dump_restore_generator.go @@ -39,11 +39,10 @@ type SnapshotGenerator struct { roleSQLParser *roleSQLParser optionGenerator *optionGenerator snapshotTracker snapshotProgressTracker - // restoreIndicesAndConstraintsBeforeData restores primary keys, indexes and - // constraints before the wrapped data snapshot generator runs. This is - // needed when the data writer uses INSERT ... ON CONFLICT DO UPDATE, since - // Postgres requires the conflict target to have a matching unique or primary - // key constraint before data is inserted. + // restoreIndicesAndConstraintsBeforeData restores constraints/indexes that can + // be used as INSERT ... ON CONFLICT targets before the wrapped data snapshot + // generator runs. Other indexes and constraints, such as foreign keys, are + // still restored after data is inserted. restoreIndicesAndConstraintsBeforeData bool } @@ -73,10 +72,10 @@ type Config struct { DumpDebugFile string // if set, security label providers that will be excluded from the dump ExcludedSecurityLabels []string - // Restore primary keys, indexes and constraints before data is snapshotted. - // This is required when the data snapshot writer emits INSERT ... ON - // CONFLICT DO UPDATE, because the target table must already have a matching - // unique or primary key constraint. + // Restore constraints/indexes that can be used as INSERT ... ON CONFLICT + // targets before data is snapshotted. This is required when the data snapshot + // writer emits INSERT ... ON CONFLICT DO UPDATE, because the target table must + // already have a matching unique or primary key constraint. RestoreIndicesAndConstraintsBeforeData bool } @@ -237,18 +236,19 @@ func (s *SnapshotGenerator) CreateSnapshot(ctx context.Context, ss *snapshot.Sna return err } - indicesAndConstraintsRestored := false + indicesAndConstraintsDump := dump.indicesAndConstraints if s.generator != nil && s.restoreIndicesAndConstraintsBeforeData { - if err := s.restoreIndicesAndConstraints(ctx, dump.indicesAndConstraints, ss); err != nil { + conflictTargets, remaining := splitConflictTargetConstraints(dump.indicesAndConstraints) + if err := s.restoreIndicesAndConstraints(ctx, conflictTargets, ss); err != nil { return err } - indicesAndConstraintsRestored = true + indicesAndConstraintsDump = remaining } // call the wrapped snapshot generator if any before restoring sequences, // indices and constraints to improve performance. When the data snapshot - // writer emits INSERT ... ON CONFLICT DO UPDATE, constraints are restored - // before data above so conflict targets are valid. + // writer emits INSERT ... ON CONFLICT DO UPDATE, the subset of constraints + // needed as conflict targets is restored before data above. if s.generator != nil { if err := s.generator.CreateSnapshot(ctx, ss); err != nil { return err @@ -261,16 +261,57 @@ func (s *SnapshotGenerator) CreateSnapshot(ctx context.Context, ss *snapshot.Sna return err } - if !indicesAndConstraintsRestored { - if err := s.restoreIndicesAndConstraints(ctx, dump.indicesAndConstraints, ss); err != nil { - return err - } + if err := s.restoreIndicesAndConstraints(ctx, indicesAndConstraintsDump, ss); err != nil { + return err } s.logger.Info("restoring views") return s.restoreDump(ctx, dump.views) } +func splitConflictTargetConstraints(d []byte) ([]byte, []byte) { + blocks := strings.Split(string(d), "\n\n") + connectBlocks := []string{} + conflictTargetBlocks := []string{} + remainingBlocks := []string{} + + for _, block := range blocks { + block = strings.TrimSpace(block) + if block == "" { + continue + } + switch { + case strings.Contains(block, `\connect`): + connectBlocks = append(connectBlocks, block) + case isConflictTargetConstraint(block): + conflictTargetBlocks = append(conflictTargetBlocks, block) + default: + remainingBlocks = append(remainingBlocks, block) + } + } + + return joinDumpBlocks(connectBlocks, conflictTargetBlocks), joinDumpBlocks(connectBlocks, remainingBlocks) +} + +func joinDumpBlocks(connectBlocks, blocks []string) []byte { + if len(blocks) == 0 { + return nil + } + allBlocks := make([]string, 0, len(connectBlocks)+len(blocks)) + allBlocks = append(allBlocks, connectBlocks...) + allBlocks = append(allBlocks, blocks...) + return []byte(strings.Join(allBlocks, "\n\n") + "\n\n") +} + +func isConflictTargetConstraint(block string) bool { + upperBlock := strings.ToUpper(block) + if strings.HasPrefix(upperBlock, "CREATE UNIQUE INDEX") { + return true + } + return strings.Contains(upperBlock, "ADD CONSTRAINT") && + (strings.Contains(upperBlock, " PRIMARY KEY ") || strings.Contains(upperBlock, " UNIQUE ")) +} + func (s *SnapshotGenerator) restoreIndicesAndConstraints(ctx context.Context, dump []byte, ss *snapshot.Snapshot) error { s.logger.Info("restoring schema indices and constraints", loglib.Fields{"schemaTables": ss.SchemaTables}) if s.snapshotTracker != nil { diff --git a/pkg/snapshot/generator/postgres/schema/pgdumprestore/snapshot_pg_dump_restore_generator_test.go b/pkg/snapshot/generator/postgres/schema/pgdumprestore/snapshot_pg_dump_restore_generator_test.go index b84701f7..a5a43128 100644 --- a/pkg/snapshot/generator/postgres/schema/pgdumprestore/snapshot_pg_dump_restore_generator_test.go +++ b/pkg/snapshot/generator/postgres/schema/pgdumprestore/snapshot_pg_dump_restore_generator_test.go @@ -1276,21 +1276,34 @@ func TestSnapshotGenerator_RestoresConstraintsBeforeDataWhenConfigured(t *testin schemaDump := []byte(`CREATE TABLE public.test_table ( id integer NOT NULL, + parent_id integer, value text NOT NULL ); ALTER TABLE ONLY public.test_table ADD CONSTRAINT test_table_pkey PRIMARY KEY (id); + +ALTER TABLE ONLY public.test_table + ADD CONSTRAINT test_table_parent_id_fkey FOREIGN KEY (parent_id) REFERENCES public.test_table(id); + +CREATE INDEX test_table_value_idx ON public.test_table USING btree (value); `) filteredDump := []byte(`CREATE TABLE public.test_table ( id integer NOT NULL, + parent_id integer, value text NOT NULL ); `) - constraintDump := []byte(`ALTER TABLE ONLY public.test_table + conflictTargetDump := []byte(`ALTER TABLE ONLY public.test_table ADD CONSTRAINT test_table_pkey PRIMARY KEY (id); +`) + remainingConstraintsDump := []byte(`ALTER TABLE ONLY public.test_table + ADD CONSTRAINT test_table_parent_id_fkey FOREIGN KEY (parent_id) REFERENCES public.test_table(id); + +CREATE INDEX test_table_value_idx ON public.test_table USING btree (value); + `) calls := []string{} @@ -1316,11 +1329,13 @@ ALTER TABLE ONLY public.test_table return schemaDump, nil }), pgRestoreFn: newMockPgrestore(func(_ context.Context, i uint, po pglib.PGRestoreOptions, dump []byte) (string, error) { - switch string(dump) { - case string(filteredDump): + switch strings.TrimSpace(string(dump)) { + case strings.TrimSpace(string(filteredDump)): calls = append(calls, "schema") - case string(constraintDump): - calls = append(calls, "constraints") + case strings.TrimSpace(string(conflictTargetDump)): + calls = append(calls, "conflict targets") + case strings.TrimSpace(string(remainingConstraintsDump)): + calls = append(calls, "remaining constraints") default: require.Failf(t, "unexpected dump", "%q", string(dump)) } @@ -1350,7 +1365,7 @@ ALTER TABLE ONLY public.test_table }, }) require.NoError(t, err) - require.Equal(t, []string{"schema", "constraints", "data"}, calls) + require.Equal(t, []string{"schema", "conflict targets", "data", "remaining constraints"}, calls) } func TestSnapshotGenerator_parseDump(t *testing.T) { From 720d7b44b9165313fbf88d0cc2f588f1552bd486 Mon Sep 17 00:00:00 2001 From: blurskye Date: Sun, 10 May 2026 04:11:45 +0600 Subject: [PATCH 3/4] Tighten conflict target constraint detection --- .../snapshot_pg_dump_restore_generator.go | 10 ++++++++-- .../snapshot_pg_dump_restore_generator_test.go | 10 ++++++++-- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/pkg/snapshot/generator/postgres/schema/pgdumprestore/snapshot_pg_dump_restore_generator.go b/pkg/snapshot/generator/postgres/schema/pgdumprestore/snapshot_pg_dump_restore_generator.go index 47177c3f..46e128a4 100644 --- a/pkg/snapshot/generator/postgres/schema/pgdumprestore/snapshot_pg_dump_restore_generator.go +++ b/pkg/snapshot/generator/postgres/schema/pgdumprestore/snapshot_pg_dump_restore_generator.go @@ -308,8 +308,14 @@ func isConflictTargetConstraint(block string) bool { if strings.HasPrefix(upperBlock, "CREATE UNIQUE INDEX") { return true } - return strings.Contains(upperBlock, "ADD CONSTRAINT") && - (strings.Contains(upperBlock, " PRIMARY KEY ") || strings.Contains(upperBlock, " UNIQUE ")) + if !strings.Contains(upperBlock, "ADD CONSTRAINT") { + return false + } + return strings.Contains(upperBlock, " PRIMARY KEY (") || + strings.Contains(upperBlock, " PRIMARY KEY USING INDEX ") || + strings.Contains(upperBlock, " UNIQUE (") || + strings.Contains(upperBlock, " UNIQUE NULLS NOT DISTINCT (") || + strings.Contains(upperBlock, " UNIQUE USING INDEX ") } func (s *SnapshotGenerator) restoreIndicesAndConstraints(ctx context.Context, dump []byte, ss *snapshot.Snapshot) error { diff --git a/pkg/snapshot/generator/postgres/schema/pgdumprestore/snapshot_pg_dump_restore_generator_test.go b/pkg/snapshot/generator/postgres/schema/pgdumprestore/snapshot_pg_dump_restore_generator_test.go index a5a43128..391deede 100644 --- a/pkg/snapshot/generator/postgres/schema/pgdumprestore/snapshot_pg_dump_restore_generator_test.go +++ b/pkg/snapshot/generator/postgres/schema/pgdumprestore/snapshot_pg_dump_restore_generator_test.go @@ -1284,7 +1284,10 @@ ALTER TABLE ONLY public.test_table ADD CONSTRAINT test_table_pkey PRIMARY KEY (id); ALTER TABLE ONLY public.test_table - ADD CONSTRAINT test_table_parent_id_fkey FOREIGN KEY (parent_id) REFERENCES public.test_table(id); + ADD CONSTRAINT test_table_value_key UNIQUE (value); + +ALTER TABLE ONLY public.test_table + ADD CONSTRAINT "test_table UNIQUE named foreign key" FOREIGN KEY (parent_id) REFERENCES public.test_table(id); CREATE INDEX test_table_value_idx ON public.test_table USING btree (value); `) @@ -1298,9 +1301,12 @@ CREATE INDEX test_table_value_idx ON public.test_table USING btree (value); conflictTargetDump := []byte(`ALTER TABLE ONLY public.test_table ADD CONSTRAINT test_table_pkey PRIMARY KEY (id); +ALTER TABLE ONLY public.test_table + ADD CONSTRAINT test_table_value_key UNIQUE (value); + `) remainingConstraintsDump := []byte(`ALTER TABLE ONLY public.test_table - ADD CONSTRAINT test_table_parent_id_fkey FOREIGN KEY (parent_id) REFERENCES public.test_table(id); + ADD CONSTRAINT "test_table UNIQUE named foreign key" FOREIGN KEY (parent_id) REFERENCES public.test_table(id); CREATE INDEX test_table_value_idx ON public.test_table USING btree (value); From 27fab73724fc3b10d47027296fcf5b81d18dcc0e Mon Sep 17 00:00:00 2001 From: blurskye Date: Sun, 10 May 2026 04:15:22 +0600 Subject: [PATCH 4/4] Propagate snapshot processor close errors --- .../postgres/data/pg_snapshot_generator.go | 8 +- .../data/pg_snapshot_generator_test.go | 74 +++++++++++++++++-- pkg/wal/processor/mocks/mock_processor.go | 7 +- 3 files changed, 80 insertions(+), 9 deletions(-) diff --git a/pkg/snapshot/generator/postgres/data/pg_snapshot_generator.go b/pkg/snapshot/generator/postgres/data/pg_snapshot_generator.go index 0a4791c6..43e6cf92 100644 --- a/pkg/snapshot/generator/postgres/data/pg_snapshot_generator.go +++ b/pkg/snapshot/generator/postgres/data/pg_snapshot_generator.go @@ -136,7 +136,13 @@ func (sg *SnapshotGenerator) CreateSnapshot(ctx context.Context, ss *snapshot.Sn defer func() { // make sure we close the processor once the snapshot is completed. // It will wait until all rows are processed before returning. - sg.processor.Close() + if closeErr := sg.processor.Close(); closeErr != nil { + if err == nil { + err = closeErr + } else { + err = errors.Join(err, closeErr) + } + } }() // parallelise the snapshot creation for each schema as configured by the snapshot workers. diff --git a/pkg/snapshot/generator/postgres/data/pg_snapshot_generator_test.go b/pkg/snapshot/generator/postgres/data/pg_snapshot_generator_test.go index ed887348..24714971 100644 --- a/pkg/snapshot/generator/postgres/data/pg_snapshot_generator_test.go +++ b/pkg/snapshot/generator/postgres/data/pg_snapshot_generator_test.go @@ -108,11 +108,12 @@ func TestSnapshotGenerator_CreateSnapshot(t *testing.T) { errTest := errors.New("oh noes") tests := []struct { - name string - querier pglib.Querier - snapshot *snapshot.Snapshot - schemaWorkers uint - progressBar *progressmocks.Bar + name string + querier pglib.Querier + snapshot *snapshot.Snapshot + schemaWorkers uint + progressBar *progressmocks.Bar + processorCloseErr error wantEvents []*wal.Event wantErr error @@ -181,6 +182,66 @@ func TestSnapshotGenerator_CreateSnapshot(t *testing.T) { wantErr: nil, wantEvents: []*wal.Event{testEvent(testTable1, testColumns)}, }, + { + name: "error - closing processor", + querier: &pgmocks.Querier{ + ExecInTxWithOptionsFn: func(_ context.Context, i uint, f func(tx pglib.Tx) error, to pglib.TxOptions) error { + require.Equal(t, txOptions, to) + switch i { + case 1: + mockTx := pgmocks.Tx{ + QueryRowFn: func(_ context.Context, dest []any, query string, args ...any) error { + require.Equal(t, exportSnapshotQuery, query) + snapshotID, ok := dest[0].(*string) + require.True(t, ok) + *snapshotID = testSnapshotID + return nil + }, + } + return f(&mockTx) + case 2: + mockTx := pgmocks.Tx{ + ExecFn: func(ctx context.Context, _ uint, query string, args ...any) (pglib.CommandTag, error) { + require.Equal(t, fmt.Sprintf("SET TRANSACTION SNAPSHOT '%s'", testSnapshotID), query) + return pglib.CommandTag{}, nil + }, + QueryRowFn: validTableInfoQueryRowFn, + } + return f(&mockTx) + case 3: + mockTx := pgmocks.Tx{ + ExecFn: func(ctx context.Context, _ uint, query string, args ...any) (pglib.CommandTag, error) { + require.Equal(t, fmt.Sprintf("SET TRANSACTION SNAPSHOT '%s'", testSnapshotID), query) + return pglib.CommandTag{}, nil + }, + QueryFn: func(ctx context.Context, query string, args ...any) (pglib.Rows, error) { + require.Equal(t, fmt.Sprintf(pageRangeQuery, quotedSchemaTable1, 0, 1), query) + return &pgmocks.Rows{ + CloseFn: func() {}, + NextFn: func(i uint) bool { return i == 1 }, + FieldDescriptionsFn: func() []pgconn.FieldDescription { + return []pgconn.FieldDescription{ + {Name: "id", DataTypeOID: pgtype.UUIDOID}, + {Name: "name", DataTypeOID: pgtype.TextOID}, + } + }, + ValuesFn: func() ([]any, error) { + return []any{testUUID, "alice"}, nil + }, + ErrFn: func() error { return nil }, + }, nil + }, + } + return f(&mockTx) + default: + return fmt.Errorf("unexpected call to ExecInTxWithOptions: %d", i) + } + }, + }, + processorCloseErr: errTest, + wantErr: errTest, + wantEvents: []*wal.Event{testEvent(testTable1, testColumns)}, + }, { name: "ok - quoted identifiers in schema and table names", querier: &pgmocks.Querier{ @@ -1082,6 +1143,9 @@ func TestSnapshotGenerator_CreateSnapshot(t *testing.T) { eventChan <- e return nil }, + CloseFn: func() error { + return tc.processorCloseErr + }, }, schemaWorkers: 1, tableWorkers: 1, diff --git a/pkg/wal/processor/mocks/mock_processor.go b/pkg/wal/processor/mocks/mock_processor.go index 5013bb96..9ad530aa 100644 --- a/pkg/wal/processor/mocks/mock_processor.go +++ b/pkg/wal/processor/mocks/mock_processor.go @@ -4,6 +4,7 @@ package mocks import ( "context" + "sync/atomic" "github.com/xataio/pgstream/pkg/wal" ) @@ -11,16 +12,16 @@ import ( type Processor struct { ProcessWALEventFn func(ctx context.Context, walEvent *wal.Event) error CloseFn func() error - processCalls uint + processCalls atomic.Uint64 } func (m *Processor) ProcessWALEvent(ctx context.Context, walEvent *wal.Event) error { - m.processCalls++ + m.processCalls.Add(1) return m.ProcessWALEventFn(ctx, walEvent) } func (m *Processor) GetProcessCalls() uint { - return m.processCalls + return uint(m.processCalls.Load()) } func (m *Processor) Close() error {