Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,13 @@ func (sg *SnapshotGenerator) CreateSnapshot(ctx context.Context, ss *snapshot.Sn
defer func() {
// make sure we close the processor once the snapshot is completed.
// It will wait until all rows are processed before returning.
sg.processor.Close()
if closeErr := sg.processor.Close(); closeErr != nil {
if err == nil {
err = closeErr
} else {
err = errors.Join(err, closeErr)
}
}
}()

// parallelise the snapshot creation for each schema as configured by the snapshot workers.
Expand Down
74 changes: 69 additions & 5 deletions pkg/snapshot/generator/postgres/data/pg_snapshot_generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,12 @@ func TestSnapshotGenerator_CreateSnapshot(t *testing.T) {
errTest := errors.New("oh noes")

tests := []struct {
name string
querier pglib.Querier
snapshot *snapshot.Snapshot
schemaWorkers uint
progressBar *progressmocks.Bar
name string
querier pglib.Querier
snapshot *snapshot.Snapshot
schemaWorkers uint
progressBar *progressmocks.Bar
processorCloseErr error

wantEvents []*wal.Event
wantErr error
Expand Down Expand Up @@ -181,6 +182,66 @@ func TestSnapshotGenerator_CreateSnapshot(t *testing.T) {
wantErr: nil,
wantEvents: []*wal.Event{testEvent(testTable1, testColumns)},
},
{
name: "error - closing processor",
querier: &pgmocks.Querier{
ExecInTxWithOptionsFn: func(_ context.Context, i uint, f func(tx pglib.Tx) error, to pglib.TxOptions) error {
require.Equal(t, txOptions, to)
switch i {
case 1:
mockTx := pgmocks.Tx{
QueryRowFn: func(_ context.Context, dest []any, query string, args ...any) error {
require.Equal(t, exportSnapshotQuery, query)
snapshotID, ok := dest[0].(*string)
require.True(t, ok)
*snapshotID = testSnapshotID
return nil
},
}
return f(&mockTx)
case 2:
mockTx := pgmocks.Tx{
ExecFn: func(ctx context.Context, _ uint, query string, args ...any) (pglib.CommandTag, error) {
require.Equal(t, fmt.Sprintf("SET TRANSACTION SNAPSHOT '%s'", testSnapshotID), query)
return pglib.CommandTag{}, nil
},
QueryRowFn: validTableInfoQueryRowFn,
}
return f(&mockTx)
case 3:
mockTx := pgmocks.Tx{
ExecFn: func(ctx context.Context, _ uint, query string, args ...any) (pglib.CommandTag, error) {
require.Equal(t, fmt.Sprintf("SET TRANSACTION SNAPSHOT '%s'", testSnapshotID), query)
return pglib.CommandTag{}, nil
},
QueryFn: func(ctx context.Context, query string, args ...any) (pglib.Rows, error) {
require.Equal(t, fmt.Sprintf(pageRangeQuery, quotedSchemaTable1, 0, 1), query)
return &pgmocks.Rows{
CloseFn: func() {},
NextFn: func(i uint) bool { return i == 1 },
FieldDescriptionsFn: func() []pgconn.FieldDescription {
return []pgconn.FieldDescription{
{Name: "id", DataTypeOID: pgtype.UUIDOID},
{Name: "name", DataTypeOID: pgtype.TextOID},
}
},
ValuesFn: func() ([]any, error) {
return []any{testUUID, "alice"}, nil
},
ErrFn: func() error { return nil },
}, nil
},
}
return f(&mockTx)
default:
return fmt.Errorf("unexpected call to ExecInTxWithOptions: %d", i)
}
},
},
processorCloseErr: errTest,
wantErr: errTest,
wantEvents: []*wal.Event{testEvent(testTable1, testColumns)},
},
{
name: "ok - quoted identifiers in schema and table names",
querier: &pgmocks.Querier{
Expand Down Expand Up @@ -1082,6 +1143,9 @@ func TestSnapshotGenerator_CreateSnapshot(t *testing.T) {
eventChan <- e
return nil
},
CloseFn: func() error {
return tc.processorCloseErr
},
},
schemaWorkers: 1,
tableWorkers: 1,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ type SnapshotGenerator struct {
roleSQLParser *roleSQLParser
optionGenerator *optionGenerator
snapshotTracker snapshotProgressTracker
// restoreIndicesAndConstraintsBeforeData restores constraints/indexes that can
// be used as INSERT ... ON CONFLICT targets before the wrapped data snapshot
// generator runs. Other indexes and constraints, such as foreign keys, are
// still restored after data is inserted.
restoreIndicesAndConstraintsBeforeData bool
}

type snapshotProgressTracker interface {
Expand Down Expand Up @@ -67,6 +72,11 @@ type Config struct {
DumpDebugFile string
// if set, security label providers that will be excluded from the dump
ExcludedSecurityLabels []string
// Restore constraints/indexes that can be used as INSERT ... ON CONFLICT
// targets before data is snapshotted. This is required when the data snapshot
// writer emits INSERT ... ON CONFLICT DO UPDATE, because the target table must
// already have a matching unique or primary key constraint.
RestoreIndicesAndConstraintsBeforeData bool
}

type Option func(s *SnapshotGenerator)
Expand Down Expand Up @@ -96,17 +106,18 @@ func NewSnapshotGenerator(ctx context.Context, c *Config, opts ...Option) (*Snap
}

sg := &SnapshotGenerator{
sourceURL: c.SourcePGURL,
targetURL: c.TargetPGURL,
pgDumpFn: pglib.RunPGDump,
pgDumpAllFn: pglib.RunPGDumpAll,
pgRestoreFn: pglib.RunPGRestore,
logger: loglib.NewNoopLogger(),
dumpDebugFile: c.DumpDebugFile,
excludedSecurityLabels: c.ExcludedSecurityLabels,
roleSQLParser: &roleSQLParser{},
sourceQuerier: sourceConnPool,
optionGenerator: newOptionGenerator(sourceConnPool, c),
sourceURL: c.SourcePGURL,
targetURL: c.TargetPGURL,
pgDumpFn: pglib.RunPGDump,
pgDumpAllFn: pglib.RunPGDumpAll,
pgRestoreFn: pglib.RunPGRestore,
logger: loglib.NewNoopLogger(),
dumpDebugFile: c.DumpDebugFile,
excludedSecurityLabels: c.ExcludedSecurityLabels,
restoreIndicesAndConstraintsBeforeData: c.RestoreIndicesAndConstraintsBeforeData,
roleSQLParser: &roleSQLParser{},
sourceQuerier: sourceConnPool,
optionGenerator: newOptionGenerator(sourceConnPool, c),
}

for _, opt := range opts {
Expand Down Expand Up @@ -225,8 +236,19 @@ func (s *SnapshotGenerator) CreateSnapshot(ctx context.Context, ss *snapshot.Sna
return err
}

indicesAndConstraintsDump := dump.indicesAndConstraints
if s.generator != nil && s.restoreIndicesAndConstraintsBeforeData {
conflictTargets, remaining := splitConflictTargetConstraints(dump.indicesAndConstraints)
if err := s.restoreIndicesAndConstraints(ctx, conflictTargets, ss); err != nil {
return err
}
indicesAndConstraintsDump = remaining
}

// call the wrapped snapshot generator if any before restoring sequences,
// indices and constraints to improve performance.
// indices and constraints to improve performance. When the data snapshot
// writer emits INSERT ... ON CONFLICT DO UPDATE, the subset of constraints
// needed as conflict targets is restored before data above.
if s.generator != nil {
if err := s.generator.CreateSnapshot(ctx, ss); err != nil {
return err
Expand All @@ -239,19 +261,71 @@ func (s *SnapshotGenerator) CreateSnapshot(ctx context.Context, ss *snapshot.Sna
return err
}

s.logger.Info("restoring schema indices and constraints", loglib.Fields{"schemaTables": ss.SchemaTables})
if s.snapshotTracker != nil {
if err := s.restoreIndicesWithTracking(ctx, dump.indicesAndConstraints); err != nil {
return err
}
} else if err := s.restoreDump(ctx, dump.indicesAndConstraints); err != nil {
if err := s.restoreIndicesAndConstraints(ctx, indicesAndConstraintsDump, ss); err != nil {
return err
}

s.logger.Info("restoring views")
return s.restoreDump(ctx, dump.views)
}

func splitConflictTargetConstraints(d []byte) ([]byte, []byte) {
blocks := strings.Split(string(d), "\n\n")
connectBlocks := []string{}
conflictTargetBlocks := []string{}
remainingBlocks := []string{}

for _, block := range blocks {
block = strings.TrimSpace(block)
if block == "" {
continue
}
switch {
case strings.Contains(block, `\connect`):
connectBlocks = append(connectBlocks, block)
case isConflictTargetConstraint(block):
conflictTargetBlocks = append(conflictTargetBlocks, block)
default:
remainingBlocks = append(remainingBlocks, block)
}
}

return joinDumpBlocks(connectBlocks, conflictTargetBlocks), joinDumpBlocks(connectBlocks, remainingBlocks)
}

func joinDumpBlocks(connectBlocks, blocks []string) []byte {
if len(blocks) == 0 {
return nil
}
allBlocks := make([]string, 0, len(connectBlocks)+len(blocks))
allBlocks = append(allBlocks, connectBlocks...)
allBlocks = append(allBlocks, blocks...)
return []byte(strings.Join(allBlocks, "\n\n") + "\n\n")
}

func isConflictTargetConstraint(block string) bool {
upperBlock := strings.ToUpper(block)
if strings.HasPrefix(upperBlock, "CREATE UNIQUE INDEX") {
return true
}
if !strings.Contains(upperBlock, "ADD CONSTRAINT") {
return false
}
return strings.Contains(upperBlock, " PRIMARY KEY (") ||
strings.Contains(upperBlock, " PRIMARY KEY USING INDEX ") ||
strings.Contains(upperBlock, " UNIQUE (") ||
strings.Contains(upperBlock, " UNIQUE NULLS NOT DISTINCT (") ||
strings.Contains(upperBlock, " UNIQUE USING INDEX ")
}

func (s *SnapshotGenerator) restoreIndicesAndConstraints(ctx context.Context, dump []byte, ss *snapshot.Snapshot) error {
s.logger.Info("restoring schema indices and constraints", loglib.Fields{"schemaTables": ss.SchemaTables})
if s.snapshotTracker != nil {
return s.restoreIndicesWithTracking(ctx, dump)
}
return s.restoreDump(ctx, dump)
}

func (s *SnapshotGenerator) Close() error {
if s.generator != nil {
if err := s.generator.Close(); err != nil {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1271,6 +1271,109 @@ func TestSnapshotGenerator_CreateSnapshot(t *testing.T) {
}
}

func TestSnapshotGenerator_RestoresConstraintsBeforeDataWhenConfigured(t *testing.T) {
t.Parallel()

schemaDump := []byte(`CREATE TABLE public.test_table (
id integer NOT NULL,
parent_id integer,
value text NOT NULL
);

ALTER TABLE ONLY public.test_table
ADD CONSTRAINT test_table_pkey PRIMARY KEY (id);

ALTER TABLE ONLY public.test_table
ADD CONSTRAINT test_table_value_key UNIQUE (value);

ALTER TABLE ONLY public.test_table
ADD CONSTRAINT "test_table UNIQUE named foreign key" FOREIGN KEY (parent_id) REFERENCES public.test_table(id);

CREATE INDEX test_table_value_idx ON public.test_table USING btree (value);
`)
filteredDump := []byte(`CREATE TABLE public.test_table (
id integer NOT NULL,
parent_id integer,
value text NOT NULL
);

`)
conflictTargetDump := []byte(`ALTER TABLE ONLY public.test_table
ADD CONSTRAINT test_table_pkey PRIMARY KEY (id);

ALTER TABLE ONLY public.test_table
ADD CONSTRAINT test_table_value_key UNIQUE (value);

`)
remainingConstraintsDump := []byte(`ALTER TABLE ONLY public.test_table
ADD CONSTRAINT "test_table UNIQUE named foreign key" FOREIGN KEY (parent_id) REFERENCES public.test_table(id);

CREATE INDEX test_table_value_idx ON public.test_table USING btree (value);

`)

calls := []string{}
conn := &mocks.Querier{
ExecFn: func(ctx context.Context, i uint, query string, args ...any) (pglib.CommandTag, error) {
return pglib.CommandTag{}, nil
},
QueryFn: func(ctx context.Context, i uint, query string, args ...any) (pglib.Rows, error) {
return &mocks.Rows{
CloseFn: func() {},
NextFn: func(i uint) bool { return false },
ErrFn: func() error { return nil },
}, nil
},
}

sg := SnapshotGenerator{
sourceURL: "source-url",
targetURL: "target-url",
sourceQuerier: conn,
pgDumpFn: newMockPgdump(func(_ context.Context, i uint, po pglib.PGDumpOptions) ([]byte, error) {
require.Equal(t, uint(1), i)
return schemaDump, nil
}),
pgRestoreFn: newMockPgrestore(func(_ context.Context, i uint, po pglib.PGRestoreOptions, dump []byte) (string, error) {
switch strings.TrimSpace(string(dump)) {
case strings.TrimSpace(string(filteredDump)):
calls = append(calls, "schema")
case strings.TrimSpace(string(conflictTargetDump)):
calls = append(calls, "conflict targets")
case strings.TrimSpace(string(remainingConstraintsDump)):
calls = append(calls, "remaining constraints")
default:
require.Failf(t, "unexpected dump", "%q", string(dump))
}
return "", nil
}),
logger: log.NewNoopLogger(),
generator: &generatormocks.Generator{
CreateSnapshotFn: func(ctx context.Context, snapshot *snapshot.Snapshot) error {
calls = append(calls, "data")
return nil
},
},
roleSQLParser: &roleSQLParser{},
optionGenerator: &optionGenerator{
sourceURL: "source-url",
targetURL: "target-url",
noOwner: true,
rolesSnapshotMode: roleSnapshotDisabled,
querier: conn,
},
restoreIndicesAndConstraintsBeforeData: true,
}

err := sg.CreateSnapshot(context.Background(), &snapshot.Snapshot{
SchemaTables: map[string][]string{
publicSchema: {"test_table"},
},
})
require.NoError(t, err)
require.Equal(t, []string{"schema", "conflict targets", "data", "remaining constraints"}, calls)
}

func TestSnapshotGenerator_parseDump(t *testing.T) {
t.Parallel()

Expand Down
Loading
Loading