From 16d3d20a214eb7e07f6d4d8cb7eac02448222886 Mon Sep 17 00:00:00 2001 From: Tudor Golubenco Date: Sun, 22 Feb 2026 11:06:42 +0100 Subject: [PATCH 1/2] Granular object type filtering for pg to pg snapshot and streaming --- cmd/config/config_env.go | 18 +- cmd/config/config_yaml.go | 34 +- config_template.yaml | 9 + .../pgdumprestore/object_type_filter.go | 193 ++++++ .../pgdumprestore/object_type_filter_test.go | 611 ++++++++++++++++++ .../snapshot_pg_dump_restore_generator.go | 66 +- pkg/stream/integration/helper_test.go | 14 + ..._pg_object_type_filter_integration_test.go | 255 ++++++++ .../testdata/object_types_fixture.sql | 414 ++++++++++++ pkg/wal/processor/postgres/config.go | 8 + .../postgres/ddl_object_type_filter.go | 175 +++++ .../postgres/ddl_object_type_filter_test.go | 209 ++++++ .../postgres/postgres_wal_adapter.go | 38 +- pkg/wal/processor/postgres/postgres_writer.go | 2 +- 14 files changed, 2001 insertions(+), 45 deletions(-) create mode 100644 pkg/snapshot/generator/postgres/schema/pgdumprestore/object_type_filter.go create mode 100644 pkg/snapshot/generator/postgres/schema/pgdumprestore/object_type_filter_test.go create mode 100644 pkg/stream/integration/pg_pg_object_type_filter_integration_test.go create mode 100644 pkg/stream/integration/testdata/object_types_fixture.sql create mode 100644 pkg/wal/processor/postgres/ddl_object_type_filter.go create mode 100644 pkg/wal/processor/postgres/ddl_object_type_filter_test.go diff --git a/cmd/config/config_env.go b/cmd/config/config_env.go index 6f3316c3..c1f868f8 100644 --- a/cmd/config/config_env.go +++ b/cmd/config/config_env.go @@ -64,6 +64,8 @@ func init() { viper.BindEnv("PGSTREAM_POSTGRES_SNAPSHOT_NO_OWNER") viper.BindEnv("PGSTREAM_POSTGRES_SNAPSHOT_NO_PRIVILEGES") viper.BindEnv("PGSTREAM_POSTGRES_SNAPSHOT_EXCLUDED_SECURITY_LABELS") + viper.BindEnv("PGSTREAM_POSTGRES_SNAPSHOT_INCLUDE_OBJECT_TYPES") + viper.BindEnv("PGSTREAM_POSTGRES_SNAPSHOT_EXCLUDE_OBJECT_TYPES") viper.BindEnv("PGSTREAM_POSTGRES_SNAPSHOT_DISABLE_PROGRESS_TRACKING") viper.BindEnv("PGSTREAM_POSTGRES_WRITER_TARGET_URL") @@ -86,6 +88,8 @@ func init() { viper.BindEnv("PGSTREAM_POSTGRES_WRITER_BACKOFF_MAX_RETRIES") viper.BindEnv("PGSTREAM_POSTGRES_WRITER_DISABLE_RETRIES") viper.BindEnv("PGSTREAM_POSTGRES_WRITER_IGNORE_DDL") + viper.BindEnv("PGSTREAM_POSTGRES_WRITER_INCLUDE_DDL_OBJECT_TYPES") + viper.BindEnv("PGSTREAM_POSTGRES_WRITER_EXCLUDE_DDL_OBJECT_TYPES") viper.BindEnv("PGSTREAM_KAFKA_READER_SERVERS") viper.BindEnv("PGSTREAM_KAFKA_WRITER_SERVERS") @@ -313,6 +317,8 @@ func parseSchemaSnapshotConfig(pgurl string) (*snapshotbuilder.SchemaSnapshotCon NoOwner: viper.GetBool("PGSTREAM_POSTGRES_SNAPSHOT_NO_OWNER"), NoPrivileges: viper.GetBool("PGSTREAM_POSTGRES_SNAPSHOT_NO_PRIVILEGES"), ExcludedSecurityLabels: viper.GetStringSlice("PGSTREAM_POSTGRES_SNAPSHOT_EXCLUDED_SECURITY_LABELS"), + IncludeObjectTypes: viper.GetStringSlice("PGSTREAM_POSTGRES_SNAPSHOT_INCLUDE_OBJECT_TYPES"), + ExcludeObjectTypes: viper.GetStringSlice("PGSTREAM_POSTGRES_SNAPSHOT_EXCLUDE_OBJECT_TYPES"), }, }, nil } @@ -479,11 +485,13 @@ func parsePostgresProcessorConfig() *stream.PostgresProcessorConfig { ConvergenceThreshold: viper.GetFloat64("PGSTREAM_POSTGRES_WRITER_BATCH_AUTO_TUNE_CONVERGENCE_THRESHOLD"), }, }, - DisableTriggers: viper.GetBool("PGSTREAM_POSTGRES_WRITER_DISABLE_TRIGGERS"), - OnConflictAction: viper.GetString("PGSTREAM_POSTGRES_WRITER_ON_CONFLICT_ACTION"), - BulkIngestEnabled: bulkIngestEnabled, - RetryPolicy: parseBackoffConfig("PGSTREAM_POSTGRES_WRITER"), - IgnoreDDL: viper.GetBool("PGSTREAM_POSTGRES_WRITER_IGNORE_DDL"), + DisableTriggers: viper.GetBool("PGSTREAM_POSTGRES_WRITER_DISABLE_TRIGGERS"), + OnConflictAction: viper.GetString("PGSTREAM_POSTGRES_WRITER_ON_CONFLICT_ACTION"), + BulkIngestEnabled: bulkIngestEnabled, + RetryPolicy: parseBackoffConfig("PGSTREAM_POSTGRES_WRITER"), + IgnoreDDL: viper.GetBool("PGSTREAM_POSTGRES_WRITER_IGNORE_DDL"), + IncludeDDLObjectTypes: viper.GetStringSlice("PGSTREAM_POSTGRES_WRITER_INCLUDE_DDL_OBJECT_TYPES"), + ExcludeDDLObjectTypes: viper.GetStringSlice("PGSTREAM_POSTGRES_WRITER_EXCLUDE_DDL_OBJECT_TYPES"), }, } diff --git a/cmd/config/config_yaml.go b/cmd/config/config_yaml.go index 6d2e709c..62683521 100644 --- a/cmd/config/config_yaml.go +++ b/cmd/config/config_yaml.go @@ -108,6 +108,8 @@ type PgDumpPgRestoreConfig struct { NoPrivileges bool `mapstructure:"no_privileges" yaml:"no_privileges"` DumpFile string `mapstructure:"dump_file" yaml:"dump_file"` ExcludedSecurityLabels []string `mapstructure:"excluded_security_labels" yaml:"excluded_security_labels"` + IncludeObjectTypes []string `mapstructure:"include_object_types" yaml:"include_object_types"` + ExcludeObjectTypes []string `mapstructure:"exclude_object_types" yaml:"exclude_object_types"` } type ReplicationConfig struct { @@ -163,13 +165,15 @@ type ConstantBackoffConfig struct { } type PostgresTargetConfig struct { - URL string `mapstructure:"url" yaml:"url"` - Batch *BatchConfig `mapstructure:"batch" yaml:"batch"` - BulkIngest *BulkIngestConfig `mapstructure:"bulk_ingest" yaml:"bulk_ingest"` - DisableTriggers bool `mapstructure:"disable_triggers" yaml:"disable_triggers"` - OnConflictAction string `mapstructure:"on_conflict_action" yaml:"on_conflict_action"` - RetryPolicy BackoffConfig `mapstructure:"retry_policy" yaml:"retry_policy"` - IgnoreDDL bool `mapstructure:"ignore_ddl" yaml:"ignore_ddl"` + URL string `mapstructure:"url" yaml:"url"` + Batch *BatchConfig `mapstructure:"batch" yaml:"batch"` + BulkIngest *BulkIngestConfig `mapstructure:"bulk_ingest" yaml:"bulk_ingest"` + DisableTriggers bool `mapstructure:"disable_triggers" yaml:"disable_triggers"` + OnConflictAction string `mapstructure:"on_conflict_action" yaml:"on_conflict_action"` + RetryPolicy BackoffConfig `mapstructure:"retry_policy" yaml:"retry_policy"` + IgnoreDDL bool `mapstructure:"ignore_ddl" yaml:"ignore_ddl"` + IncludeDDLObjectTypes []string `mapstructure:"include_ddl_object_types" yaml:"include_ddl_object_types"` + ExcludeDDLObjectTypes []string `mapstructure:"exclude_ddl_object_types" yaml:"exclude_ddl_object_types"` } type KafkaTargetConfig struct { @@ -551,6 +555,8 @@ func (c *YAMLConfig) parseSchemaSnapshotConfig() (*snapshotbuilder.SchemaSnapsho streamSchemaCfg.DumpRestore.NoPrivileges = schemaSnapshotCfg.PgDumpPgRestore.NoPrivileges streamSchemaCfg.DumpRestore.DumpDebugFile = schemaSnapshotCfg.PgDumpPgRestore.DumpFile streamSchemaCfg.DumpRestore.ExcludedSecurityLabels = schemaSnapshotCfg.PgDumpPgRestore.ExcludedSecurityLabels + streamSchemaCfg.DumpRestore.IncludeObjectTypes = schemaSnapshotCfg.PgDumpPgRestore.IncludeObjectTypes + streamSchemaCfg.DumpRestore.ExcludeObjectTypes = schemaSnapshotCfg.PgDumpPgRestore.ExcludeObjectTypes var err error streamSchemaCfg.DumpRestore.RolesSnapshotMode, err = getRolesSnapshotMode(schemaSnapshotCfg.PgDumpPgRestore.RolesSnapshotMode) @@ -609,12 +615,14 @@ func (c *YAMLConfig) parsePostgresProcessorConfig() *stream.PostgresProcessorCon cfg := &stream.PostgresProcessorConfig{ BatchWriter: postgres.Config{ - URL: c.Target.Postgres.URL, - BatchConfig: c.Target.Postgres.Batch.parseBatchConfig(), - DisableTriggers: c.Target.Postgres.DisableTriggers, - OnConflictAction: c.Target.Postgres.OnConflictAction, - RetryPolicy: c.Target.Postgres.RetryPolicy.parseBackoffConfig(), - IgnoreDDL: c.Target.Postgres.IgnoreDDL, + URL: c.Target.Postgres.URL, + BatchConfig: c.Target.Postgres.Batch.parseBatchConfig(), + DisableTriggers: c.Target.Postgres.DisableTriggers, + OnConflictAction: c.Target.Postgres.OnConflictAction, + RetryPolicy: c.Target.Postgres.RetryPolicy.parseBackoffConfig(), + IgnoreDDL: c.Target.Postgres.IgnoreDDL, + IncludeDDLObjectTypes: c.Target.Postgres.IncludeDDLObjectTypes, + ExcludeDDLObjectTypes: c.Target.Postgres.ExcludeDDLObjectTypes, }, } diff --git a/config_template.yaml b/config_template.yaml index 84445f78..a6b404fb 100644 --- a/config_template.yaml +++ b/config_template.yaml @@ -34,6 +34,10 @@ source: roles_snapshot_mode: # enabled by default. Can be set to disabled to disable roles snapshotting, or can be set to no_passwords to exclude role passwords exclude_security_labels: ["anon"] # list of providers whose security labels will be excluded from the snapshot. Wildcard supported. dump_file: pg_dump.sql # name of the file where the contents of the schema pg_dump command and output will be written for debugging purposes. + # Granular object type filtering for schema snapshots. Only one of include_object_types or exclude_object_types can be set. + # Available categories: tables, sequences, types, indexes, constraints, functions, views, materialized_views, triggers, event_triggers, policies, rules, comments, extensions, collations, text_search + # include_object_types: ["tables", "sequences", "types"] # only include these object types in the schema snapshot + # exclude_object_types: ["functions", "views", "triggers"] # exclude these object types from the schema snapshot replication: # when mode is replication or snapshot_and_replication replication_slot: "pgstream_mydatabase_slot" plugin: @@ -75,6 +79,11 @@ target: schema_log_store_url: "postgresql://user:password@localhost:5432/mydatabase" # url to the postgres database where the schema log is stored to be used when performing schema change diffs disable_triggers: false # whether to disable triggers on the target database. Defaults to false on_conflict_action: "nothing" # options are update, nothing or error. Defaults to error + ignore_ddl: false # whether to skip all DDL replication. Defaults to false + # Selective DDL object type filtering for replication. Only one of include_ddl_object_types or exclude_ddl_object_types can be set. Ignored if ignore_ddl is true. + # Available categories: tables, sequences, types, indexes, constraints, functions, views, materialized_views, triggers, event_triggers, policies, rules, extensions, collations, text_search + # include_ddl_object_types: ["tables", "sequences", "types"] # only replicate DDL for these object types + # exclude_ddl_object_types: ["functions", "views", "triggers"] # skip DDL replication for these object types bulk_ingest: enabled: true # whether to enable bulk ingest on the target postgres, using COPY FROM (supported for insert only workloads) kafka: diff --git a/pkg/snapshot/generator/postgres/schema/pgdumprestore/object_type_filter.go b/pkg/snapshot/generator/postgres/schema/pgdumprestore/object_type_filter.go new file mode 100644 index 00000000..f48f44e4 --- /dev/null +++ b/pkg/snapshot/generator/postgres/schema/pgdumprestore/object_type_filter.go @@ -0,0 +1,193 @@ +// SPDX-License-Identifier: Apache-2.0 + +package pgdumprestore + +import ( + "bufio" + "bytes" + "fmt" + "regexp" + "strings" +) + +// objectTypeCategory maps user-facing category names to pg_dump TOC Type values. +var objectTypeCategories = map[string][]string{ + "tables": {"TABLE", "DEFAULT"}, + "sequences": {"SEQUENCE", "SEQUENCE OWNED BY"}, + "types": {"TYPE", "DOMAIN"}, + "indexes": {"INDEX"}, + "constraints": {"CONSTRAINT", "FK CONSTRAINT"}, + "functions": {"FUNCTION", "AGGREGATE", "PROCEDURE"}, + "views": {"VIEW"}, + "materialized_views": {"MATERIALIZED VIEW"}, + "triggers": {"TRIGGER"}, + "event_triggers": {"EVENT TRIGGER"}, + "policies": {"POLICY", "ROW SECURITY"}, + "rules": {"RULE"}, + "comments": {"COMMENT"}, + "extensions": {"EXTENSION"}, + "collations": {"COLLATION"}, + "text_search": {"TEXT SEARCH CONFIGURATION", "TEXT SEARCH DICTIONARY", "TEXT SEARCH PARSER", "TEXT SEARCH TEMPLATE"}, +} + +// tocHeaderRegex matches pg_dump TOC comment headers like: +// -- Name: my_func; Type: FUNCTION; Schema: public; Owner: postgres +var tocHeaderRegex = regexp.MustCompile(`^--\s*Name:.*;\s*Type:\s*([^;]+)\s*;`) + +// parseTOCHeader extracts the object Type value from a pg_dump TOC comment header line. +// Returns the type string and true if the line is a TOC header, or ("", false) otherwise. +func parseTOCHeader(line string) (string, bool) { + matches := tocHeaderRegex.FindStringSubmatch(line) + if len(matches) < 2 { + return "", false + } + return strings.TrimSpace(matches[1]), true +} + +// objectTypeFilter determines which pg_dump object types should be excluded +// based on user-specified include or exclude category lists. +type objectTypeFilter struct { + excludedTypes map[string]struct{} + // categories tracks which user-facing categories are excluded for + // higher-level checks (e.g., skipping sequence dump step). + excludedCategories map[string]struct{} +} + +// newObjectTypeFilter creates an objectTypeFilter from include/exclude category lists. +// Only one of include or exclude can be set (not both). +// Returns nil if neither is set (no filtering). +func newObjectTypeFilter(include, exclude []string) (*objectTypeFilter, error) { + if len(include) > 0 && len(exclude) > 0 { + return nil, fmt.Errorf("include_object_types and exclude_object_types cannot both be set") + } + + if len(include) == 0 && len(exclude) == 0 { + return nil, nil + } + + f := &objectTypeFilter{ + excludedTypes: make(map[string]struct{}), + excludedCategories: make(map[string]struct{}), + } + + if len(include) > 0 { + // Validate all included categories + includedSet := make(map[string]struct{}, len(include)) + for _, cat := range include { + if _, ok := objectTypeCategories[cat]; !ok { + return nil, fmt.Errorf("unknown object type category: %q", cat) + } + includedSet[cat] = struct{}{} + } + // Exclude everything NOT in the include list + for cat, types := range objectTypeCategories { + if _, included := includedSet[cat]; !included { + f.excludedCategories[cat] = struct{}{} + for _, t := range types { + f.excludedTypes[t] = struct{}{} + } + } + } + } else { + // Validate and exclude the specified categories + for _, cat := range exclude { + types, ok := objectTypeCategories[cat] + if !ok { + return nil, fmt.Errorf("unknown object type category: %q", cat) + } + f.excludedCategories[cat] = struct{}{} + for _, t := range types { + f.excludedTypes[t] = struct{}{} + } + } + } + + return f, nil +} + +// isExcluded returns true if the given pg_dump Type value should be excluded. +// SCHEMA type is never excluded (required for namespace resolution). +func (f *objectTypeFilter) isExcluded(pgdumpType string) bool { + if f == nil { + return false + } + // SCHEMA is always included + if pgdumpType == "SCHEMA" { + return false + } + _, excluded := f.excludedTypes[pgdumpType] + return excluded +} + +// isCategoryExcluded returns true if the given user-facing category is excluded. +func (f *objectTypeFilter) isCategoryExcluded(category string) bool { + if f == nil { + return false + } + _, excluded := f.excludedCategories[category] + return excluded +} + +// cleanupStatementPrefixes maps SQL cleanup statement prefixes (from pg_dump +// --clean --if-exists output) to the object type category they belong to. +var cleanupStatementPrefixes = map[string]string{ + "DROP POLICY": "policies", + "DROP TRIGGER": "triggers", + "DROP RULE": "rules", + "DROP INDEX": "indexes", + "DROP FUNCTION": "functions", + "DROP AGGREGATE": "functions", + "DROP PROCEDURE": "functions", + "DROP VIEW": "views", + "DROP MATERIALIZED VIEW": "materialized_views", + "DROP TEXT SEARCH": "text_search", + "DROP COLLATION": "collations", + "DROP EXTENSION": "extensions", + "DROP EVENT TRIGGER": "event_triggers", + "DROP SEQUENCE": "sequences", + "DROP TABLE": "tables", + "DROP TYPE": "types", + "DROP DOMAIN": "types", + "DROP SCHEMA": "schemas", + "COMMENT ON": "comments", +} + +// filterCleanupDump removes lines from a cleanup dump that belong to excluded +// object type categories. This prevents errors like "relation does not exist" +// when DROP POLICY/TRIGGER/RULE statements reference tables that don't yet +// exist on the target. +func (f *objectTypeFilter) filterCleanupDump(cleanupDump []byte) []byte { + if f == nil { + return cleanupDump + } + + scanner := bufio.NewScanner(bytes.NewReader(cleanupDump)) + var filtered strings.Builder + for scanner.Scan() { + line := scanner.Text() + if f.shouldSkipCleanupLine(line) { + continue + } + filtered.WriteString(line) + filtered.WriteString("\n") + } + return []byte(filtered.String()) +} + +// shouldSkipCleanupLine returns true if a cleanup dump line should be skipped +// because it belongs to an excluded object type category. +func (f *objectTypeFilter) shouldSkipCleanupLine(line string) bool { + if f == nil { + return false + } + for prefix, cat := range cleanupStatementPrefixes { + if strings.HasPrefix(line, prefix) { + // SCHEMA is never excluded + if cat == "schemas" { + return false + } + return f.isCategoryExcluded(cat) + } + } + return false +} diff --git a/pkg/snapshot/generator/postgres/schema/pgdumprestore/object_type_filter_test.go b/pkg/snapshot/generator/postgres/schema/pgdumprestore/object_type_filter_test.go new file mode 100644 index 00000000..c668e01a --- /dev/null +++ b/pkg/snapshot/generator/postgres/schema/pgdumprestore/object_type_filter_test.go @@ -0,0 +1,611 @@ +// SPDX-License-Identifier: Apache-2.0 + +package pgdumprestore + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestParseTOCHeader(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + line string + wantType string + wantOK bool + }{ + { + name: "function header", + line: "-- Name: my_func; Type: FUNCTION; Schema: public; Owner: postgres", + wantType: "FUNCTION", + wantOK: true, + }, + { + name: "table header", + line: "-- Name: users; Type: TABLE; Schema: public; Owner: postgres", + wantType: "TABLE", + wantOK: true, + }, + { + name: "index header", + line: "-- Name: area_alias_idx_txt; Type: INDEX; Schema: musicbrainz; Owner: postgres", + wantType: "INDEX", + wantOK: true, + }, + { + name: "schema header", + line: "-- Name: musicbrainz; Type: SCHEMA; Schema: -; Owner: postgres", + wantType: "SCHEMA", + wantOK: true, + }, + { + name: "fk constraint header", + line: "-- Name: alternative_medium alternative_medium_fk_medium; Type: FK CONSTRAINT; Schema: musicbrainz; Owner: postgres", + wantType: "FK CONSTRAINT", + wantOK: true, + }, + { + name: "constraint header", + line: "-- Name: alternative_medium alternative_medium_pkey; Type: CONSTRAINT; Schema: musicbrainz; Owner: postgres", + wantType: "CONSTRAINT", + wantOK: true, + }, + { + name: "trigger header", + line: "-- Name: alternative_medium_track a_del_alternative_medium_track; Type: TRIGGER; Schema: musicbrainz; Owner: postgres", + wantType: "TRIGGER", + wantOK: true, + }, + { + name: "aggregate header", + line: "-- Name: median(integer); Type: AGGREGATE; Schema: musicbrainz; Owner: postgres", + wantType: "AGGREGATE", + wantOK: true, + }, + { + name: "text search configuration header", + line: "-- Name: mb_simple; Type: TEXT SEARCH CONFIGURATION; Schema: musicbrainz; Owner: postgres", + wantType: "TEXT SEARCH CONFIGURATION", + wantOK: true, + }, + { + name: "sequence header", + line: "-- Name: alternative_medium_id_seq; Type: SEQUENCE; Schema: musicbrainz; Owner: postgres", + wantType: "SEQUENCE", + wantOK: true, + }, + { + name: "default header", + line: "-- Name: alternative_medium id; Type: DEFAULT; Schema: musicbrainz; Owner: postgres", + wantType: "DEFAULT", + wantOK: true, + }, + { + name: "extension header", + line: "-- Name: cube; Type: EXTENSION; Schema: -; Owner: -", + wantType: "EXTENSION", + wantOK: true, + }, + { + name: "comment header", + line: "-- Name: EXTENSION cube; Type: COMMENT; Schema: -; Owner:", + wantType: "COMMENT", + wantOK: true, + }, + { + name: "collation header", + line: "-- Name: musicbrainz; Type: COLLATION; Schema: musicbrainz; Owner: postgres", + wantType: "COLLATION", + wantOK: true, + }, + { + name: "type header", + line: "-- Name: cover_art_presence; Type: TYPE; Schema: musicbrainz; Owner: postgres", + wantType: "TYPE", + wantOK: true, + }, + { + name: "regular comment line", + line: "-- PostgreSQL database dump", + wantOK: false, + }, + { + name: "empty line", + line: "", + wantOK: false, + }, + { + name: "SQL statement", + line: "CREATE TABLE users (id integer);", + wantOK: false, + }, + { + name: "separator comment", + line: "--", + wantOK: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + gotType, gotOK := parseTOCHeader(tc.line) + require.Equal(t, tc.wantOK, gotOK) + if gotOK { + require.Equal(t, tc.wantType, gotType) + } + }) + } +} + +func TestNewObjectTypeFilter(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + include []string + exclude []string + + wantNil bool + wantErr bool + }{ + { + name: "no filter configured", + wantNil: true, + }, + { + name: "include only", + include: []string{"tables", "sequences", "types"}, + }, + { + name: "exclude only", + exclude: []string{"functions", "views"}, + }, + { + name: "both set - error", + include: []string{"tables"}, + exclude: []string{"functions"}, + wantErr: true, + }, + { + name: "unknown include category", + include: []string{"tables", "unknown_category"}, + wantErr: true, + }, + { + name: "unknown exclude category", + exclude: []string{"unknown_category"}, + wantErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + f, err := newObjectTypeFilter(tc.include, tc.exclude) + if tc.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + + if tc.wantNil { + require.Nil(t, f) + return + } + + require.NotNil(t, f) + }) + } +} + +func TestObjectTypeFilter_IsExcluded(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + include []string + exclude []string + pgdumpType string + + wantExcluded bool + }{ + { + name: "nil filter - never excluded", + pgdumpType: "FUNCTION", + wantExcluded: false, + }, + { + name: "include tables only - function excluded", + include: []string{"tables", "sequences", "types"}, + pgdumpType: "FUNCTION", + wantExcluded: true, + }, + { + name: "include tables only - table included", + include: []string{"tables", "sequences", "types"}, + pgdumpType: "TABLE", + wantExcluded: false, + }, + { + name: "include tables only - DEFAULT included (part of tables)", + include: []string{"tables", "sequences", "types"}, + pgdumpType: "DEFAULT", + wantExcluded: false, + }, + { + name: "include tables only - SEQUENCE included", + include: []string{"tables", "sequences", "types"}, + pgdumpType: "SEQUENCE", + wantExcluded: false, + }, + { + name: "include tables only - INDEX excluded", + include: []string{"tables", "sequences", "types"}, + pgdumpType: "INDEX", + wantExcluded: true, + }, + { + name: "exclude functions - function excluded", + exclude: []string{"functions"}, + pgdumpType: "FUNCTION", + wantExcluded: true, + }, + { + name: "exclude functions - aggregate excluded", + exclude: []string{"functions"}, + pgdumpType: "AGGREGATE", + wantExcluded: true, + }, + { + name: "exclude functions - table not excluded", + exclude: []string{"functions"}, + pgdumpType: "TABLE", + wantExcluded: false, + }, + { + name: "SCHEMA is never excluded", + include: []string{"tables"}, + pgdumpType: "SCHEMA", + wantExcluded: false, + }, + { + name: "exclude views - VIEW excluded", + exclude: []string{"views"}, + pgdumpType: "VIEW", + wantExcluded: true, + }, + { + name: "exclude text_search - TEXT SEARCH CONFIGURATION excluded", + exclude: []string{"text_search"}, + pgdumpType: "TEXT SEARCH CONFIGURATION", + wantExcluded: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + var f *objectTypeFilter + if tc.include != nil || tc.exclude != nil { + var err error + f, err = newObjectTypeFilter(tc.include, tc.exclude) + require.NoError(t, err) + } + + got := f.isExcluded(tc.pgdumpType) + require.Equal(t, tc.wantExcluded, got) + }) + } +} + +func TestObjectTypeFilter_IsCategoryExcluded(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + include []string + exclude []string + category string + + wantExcluded bool + }{ + { + name: "nil filter", + category: "functions", + wantExcluded: false, + }, + { + name: "include mode - included category", + include: []string{"tables", "sequences"}, + category: "tables", + wantExcluded: false, + }, + { + name: "include mode - excluded category", + include: []string{"tables", "sequences"}, + category: "functions", + wantExcluded: true, + }, + { + name: "exclude mode - excluded category", + exclude: []string{"functions", "views"}, + category: "functions", + wantExcluded: true, + }, + { + name: "exclude mode - not excluded category", + exclude: []string{"functions", "views"}, + category: "tables", + wantExcluded: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + var f *objectTypeFilter + if tc.include != nil || tc.exclude != nil { + var err error + f, err = newObjectTypeFilter(tc.include, tc.exclude) + require.NoError(t, err) + } + + got := f.isCategoryExcluded(tc.category) + require.Equal(t, tc.wantExcluded, got) + }) + } +} + +func TestObjectTypeFilter_FilterCleanupDump(t *testing.T) { + t.Parallel() + + cleanupDump := `DROP POLICY IF EXISTS posts_select_published ON app.posts; +DROP POLICY IF EXISTS posts_modify_own ON app.posts; +ALTER TABLE IF EXISTS ONLY app.posts DROP CONSTRAINT IF EXISTS posts_category_id_fkey; +ALTER TABLE IF EXISTS ONLY app.posts DROP CONSTRAINT IF EXISTS posts_author_id_fkey; +DROP TRIGGER IF EXISTS trg_users_updated_at ON app.users; +DROP TRIGGER IF EXISTS trg_posts_updated_at ON app.posts; +DROP RULE IF EXISTS protect_audit_update ON app.audit_log; +DROP INDEX IF EXISTS app.idx_users_email; +DROP INDEX IF EXISTS app.idx_posts_slug; +ALTER TABLE IF EXISTS ONLY app.users DROP CONSTRAINT IF EXISTS users_pkey; +ALTER TABLE IF EXISTS app.posts ALTER COLUMN id DROP DEFAULT; +DROP SEQUENCE IF EXISTS app.users_id_seq; +DROP VIEW IF EXISTS app.user_stats; +DROP TABLE IF EXISTS app.users; +DROP TABLE IF EXISTS app.posts; +DROP FUNCTION IF EXISTS app.my_func(); +DROP AGGREGATE IF EXISTS app.median(integer); +DROP TYPE IF EXISTS app.status; +DROP DOMAIN IF EXISTS app.email; +DROP COLLATION IF EXISTS app.case_insensitive; +DROP MATERIALIZED VIEW IF EXISTS analytics.top_posts; +DROP TEXT SEARCH CONFIGURATION IF EXISTS app.english_unaccent; +DROP SCHEMA IF EXISTS app; +` + + tests := []struct { + name string + include []string + wantContains []string + wantNotContains []string + }{ + { + name: "include tables sequences types only", + include: []string{"tables", "sequences", "types"}, + wantContains: []string{ + "DROP TABLE IF EXISTS app.users", + "DROP TABLE IF EXISTS app.posts", + "DROP SEQUENCE IF EXISTS app.users_id_seq", + "DROP TYPE IF EXISTS app.status", + "DROP DOMAIN IF EXISTS app.email", + "DROP SCHEMA IF EXISTS app", + // ALTER TABLE lines are not categorized as any specific type, + // so they pass through + "ALTER TABLE IF EXISTS ONLY app.posts DROP CONSTRAINT", + "ALTER TABLE IF EXISTS app.posts ALTER COLUMN id DROP DEFAULT", + }, + wantNotContains: []string{ + "DROP POLICY", + "DROP TRIGGER", + "DROP RULE", + "DROP INDEX", + "DROP VIEW", + "DROP FUNCTION", + "DROP AGGREGATE", + "DROP COLLATION", + "DROP MATERIALIZED VIEW", + "DROP TEXT SEARCH", + }, + }, + { + name: "include everything", + include: []string{"tables", "sequences", "types", "indexes", "constraints", "functions", "views", "materialized_views", "triggers", "event_triggers", "policies", "rules", "comments", "extensions", "collations", "text_search"}, + wantContains: []string{ + "DROP TABLE", + "DROP POLICY", + "DROP TRIGGER", + "DROP RULE", + "DROP INDEX", + "DROP VIEW", + "DROP FUNCTION", + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + f, err := newObjectTypeFilter(tc.include, nil) + require.NoError(t, err) + + result := string(f.filterCleanupDump([]byte(cleanupDump))) + + for _, s := range tc.wantContains { + require.Contains(t, result, s, "expected cleanup to contain: %s", s) + } + for _, s := range tc.wantNotContains { + require.NotContains(t, result, s, "expected cleanup NOT to contain: %s", s) + } + }) + } +} + +func TestObjectTypeFilter_FilterCleanupDump_NilFilter(t *testing.T) { + t.Parallel() + + input := []byte("DROP FUNCTION IF EXISTS app.my_func();\nDROP TABLE IF EXISTS app.users;\n") + var f *objectTypeFilter + result := f.filterCleanupDump(input) + require.Equal(t, input, result) +} + +func TestParseDump_WithObjectTypeFilter(t *testing.T) { + t.Parallel() + + // Test that parseDump excludes functions and aggregates when filtered + dumpInput := `-- +-- PostgreSQL database dump +-- + +SET statement_timeout = 0; + +-- +-- Name: public; Type: SCHEMA; Schema: -; Owner: postgres +-- + +CREATE SCHEMA public; + +-- +-- Name: my_func(); Type: FUNCTION; Schema: public; Owner: postgres +-- + +CREATE FUNCTION public.my_func() RETURNS void + LANGUAGE sql + AS $$ SELECT 1; $$; + +ALTER FUNCTION public.my_func() OWNER TO postgres; + +-- +-- Name: my_agg(integer); Type: AGGREGATE; Schema: public; Owner: postgres +-- + +CREATE AGGREGATE public.my_agg(integer) ( + SFUNC = array_append, + STYPE = integer[] +); + +-- +-- Name: users; Type: TABLE; Schema: public; Owner: postgres +-- + +CREATE TABLE public.users ( + id integer NOT NULL, + name text +); + +-- +-- Name: users_id_seq; Type: SEQUENCE; Schema: public; Owner: postgres +-- + +CREATE SEQUENCE public.users_id_seq + START WITH 1 + INCREMENT BY 1; + +-- +-- Name: users id; Type: DEFAULT; Schema: public; Owner: postgres +-- + +ALTER TABLE ONLY public.users ALTER COLUMN id SET DEFAULT nextval('public.users_id_seq'::regclass); + +-- +-- Name: users users_pkey; Type: CONSTRAINT; Schema: public; Owner: postgres +-- + +ALTER TABLE ONLY public.users + ADD CONSTRAINT users_pkey PRIMARY KEY (id); + +-- +-- Name: users_name_idx; Type: INDEX; Schema: public; Owner: postgres +-- + +CREATE INDEX users_name_idx ON public.users USING btree (name); +` + + t.Run("exclude functions", func(t *testing.T) { + t.Parallel() + + filter, err := newObjectTypeFilter(nil, []string{"functions"}) + require.NoError(t, err) + + sg := &SnapshotGenerator{ + objectTypeFilter: filter, + } + + d := sg.parseDump([]byte(dumpInput)) + + // Functions and aggregates should be excluded from the filtered dump + filteredStr := string(d.filtered) + require.NotContains(t, filteredStr, "CREATE FUNCTION") + require.NotContains(t, filteredStr, "CREATE AGGREGATE") + require.NotContains(t, filteredStr, "ALTER FUNCTION") + + // Tables and sequences should still be present + require.Contains(t, filteredStr, "CREATE TABLE public.users") + require.Contains(t, filteredStr, "CREATE SCHEMA public") + }) + + t.Run("include only tables and sequences", func(t *testing.T) { + t.Parallel() + + filter, err := newObjectTypeFilter([]string{"tables", "sequences"}, nil) + require.NoError(t, err) + + sg := &SnapshotGenerator{ + objectTypeFilter: filter, + } + + d := sg.parseDump([]byte(dumpInput)) + + filteredStr := string(d.filtered) + + // Functions, aggregates, indexes, constraints should be excluded + require.NotContains(t, filteredStr, "CREATE FUNCTION") + require.NotContains(t, filteredStr, "CREATE AGGREGATE") + + // Tables and sequences should be present + require.Contains(t, filteredStr, "CREATE TABLE public.users") + require.Contains(t, filteredStr, "CREATE SEQUENCE public.users_id_seq") + + // SCHEMA is always included + require.Contains(t, filteredStr, "CREATE SCHEMA public") + + // Preamble (SET statements) should be included + require.Contains(t, filteredStr, "SET statement_timeout = 0") + }) + + t.Run("no filter - everything included", func(t *testing.T) { + t.Parallel() + + sg := &SnapshotGenerator{} + + d := sg.parseDump([]byte(dumpInput)) + + filteredStr := string(d.filtered) + + // Everything should be present + require.Contains(t, filteredStr, "CREATE FUNCTION") + require.Contains(t, filteredStr, "CREATE AGGREGATE") + require.Contains(t, filteredStr, "CREATE TABLE public.users") + require.Contains(t, filteredStr, "CREATE SEQUENCE public.users_id_seq") + }) +} diff --git a/pkg/snapshot/generator/postgres/schema/pgdumprestore/snapshot_pg_dump_restore_generator.go b/pkg/snapshot/generator/postgres/schema/pgdumprestore/snapshot_pg_dump_restore_generator.go index d574e70f..b1b46e1d 100644 --- a/pkg/snapshot/generator/postgres/schema/pgdumprestore/snapshot_pg_dump_restore_generator.go +++ b/pkg/snapshot/generator/postgres/schema/pgdumprestore/snapshot_pg_dump_restore_generator.go @@ -39,6 +39,7 @@ type SnapshotGenerator struct { roleSQLParser *roleSQLParser optionGenerator *optionGenerator snapshotTracker snapshotProgressTracker + objectTypeFilter *objectTypeFilter } type snapshotProgressTracker interface { @@ -67,6 +68,14 @@ type Config struct { DumpDebugFile string // if set, security label providers that will be excluded from the dump ExcludedSecurityLabels []string + // IncludeObjectTypes is a list of object type categories to include in the + // schema snapshot. Only one of IncludeObjectTypes or ExcludeObjectTypes + // can be set. + IncludeObjectTypes []string + // ExcludeObjectTypes is a list of object type categories to exclude from + // the schema snapshot. Only one of IncludeObjectTypes or + // ExcludeObjectTypes can be set. + ExcludeObjectTypes []string } type Option func(s *SnapshotGenerator) @@ -94,6 +103,11 @@ func NewSnapshotGenerator(ctx context.Context, c *Config, opts ...Option) (*Snap return nil, err } + objTypeFilter, err := newObjectTypeFilter(c.IncludeObjectTypes, c.ExcludeObjectTypes) + if err != nil { + return nil, fmt.Errorf("invalid object type filter config: %w", err) + } + sg := &SnapshotGenerator{ sourceURL: c.SourcePGURL, targetURL: c.TargetPGURL, @@ -106,6 +120,7 @@ func NewSnapshotGenerator(ctx context.Context, c *Config, opts ...Option) (*Snap roleSQLParser: &roleSQLParser{}, sourceQuerier: sourceConnPool, optionGenerator: newOptionGenerator(sourceConnPool, c), + objectTypeFilter: objTypeFilter, } for _, opt := range opts { @@ -187,9 +202,12 @@ func (s *SnapshotGenerator) CreateSnapshot(ctx context.Context, ss *snapshot.Sna // queries since that's considered data and it's a schema only dump. Produce // the data only dump for the sequences only and restore it along with the // schema. - sequenceDump, err := s.dumpSequenceValues(ctx, dump.sequences) - if err != nil { - return err + var sequenceDump []byte + if !s.objectTypeFilter.isCategoryExcluded("sequences") { + sequenceDump, err = s.dumpSequenceValues(ctx, dump.sequences) + if err != nil { + return err + } } // the schema dump will not include the roles, so we need to dump them @@ -233,16 +251,21 @@ func (s *SnapshotGenerator) CreateSnapshot(ctx context.Context, ss *snapshot.Sna } // apply the sequences, indices and constraints when the wrapped generator has finished - s.logger.Info("restoring sequence data", loglib.Fields{"schemaTables": ss.SchemaTables}) - if err := s.restoreDump(ctx, sequenceDump); err != nil { - return err + if !s.objectTypeFilter.isCategoryExcluded("sequences") { + s.logger.Info("restoring sequence data", loglib.Fields{"schemaTables": ss.SchemaTables}) + if err := s.restoreDump(ctx, sequenceDump); err != nil { + return err + } } - s.logger.Info("restoring schema indices and constraints", loglib.Fields{"schemaTables": ss.SchemaTables}) - if s.snapshotTracker != nil { - return s.restoreIndicesWithTracking(ctx, dump.indicesAndConstraints) + if len(dump.indicesAndConstraints) > 0 { + s.logger.Info("restoring schema indices and constraints", loglib.Fields{"schemaTables": ss.SchemaTables}) + if s.snapshotTracker != nil { + return s.restoreIndicesWithTracking(ctx, dump.indicesAndConstraints) + } + return s.restoreDump(ctx, dump.indicesAndConstraints) } - return s.restoreDump(ctx, dump.indicesAndConstraints) + return nil } func (s *SnapshotGenerator) Close() error { @@ -302,7 +325,7 @@ func (s *SnapshotGenerator) dumpSchema(ctx context.Context, schemaTables map[str s.logger.Error(err, "pg_dump for schema failed", loglib.Fields{"pgdumpOptions": pgdumpOpts.ToArgs()}) return nil, fmt.Errorf("dumping schema: %w", err) } - parsedDump.cleanupPart = getDumpsDiff(dumpWithCleanUp, d) + parsedDump.cleanupPart = s.objectTypeFilter.filterCleanupDump(getDumpsDiff(dumpWithCleanUp, d)) s.dumpToFile(s.getDumpFileName("-cleanup"), pgdumpOpts, parsedDump.cleanupPart) } @@ -410,8 +433,29 @@ func (s *SnapshotGenerator) parseDump(d []byte) *dump { dumpRoles := make(map[string]role) alterTable := "" createEventTrigger := "" + + // Object type filtering state: lines before the first TOC header (the + // preamble) are always included. Once a TOC header is encountered, the + // section is included/excluded based on the filter. + inPreamble := true + skipCurrentSection := false + for scanner.Scan() { line := scanner.Text() + + // Check for TOC header to track object type sections + if tocType, ok := parseTOCHeader(line); ok { + inPreamble = false + skipCurrentSection = s.objectTypeFilter.isExcluded(tocType) + } + + // If the current section is excluded, skip the line. We still + // need to extract role information from non-excluded sections + // for role dependency tracking. + if !inPreamble && skipCurrentSection { + continue + } + switch { case strings.HasPrefix(line, "SECURITY LABEL") && isSecurityLabelForExcludedProvider(line, s.excludedSecurityLabels): diff --git a/pkg/stream/integration/helper_test.go b/pkg/stream/integration/helper_test.go index 5c11f1f9..4ce0b92e 100644 --- a/pkg/stream/integration/helper_test.go +++ b/pkg/stream/integration/helper_test.go @@ -231,6 +231,20 @@ func withBulkIngestionEnabled() option { } } +func withDDLObjectTypeFilter(include []string) option { + return func(cfg *stream.ProcessorConfig) { + if cfg.Postgres != nil { + cfg.Postgres.BatchWriter.IncludeDDLObjectTypes = include + } + } +} + +func testPostgresListenerCfgWithSnapshotAndFilter(sourceURL, targetURL string, tables []string, includeObjectTypes []string) stream.ListenerConfig { + cfg := testPostgresListenerCfgWithSnapshot(sourceURL, targetURL, tables) + cfg.Postgres.Snapshot.Schema.DumpRestore.IncludeObjectTypes = includeObjectTypes + return cfg +} + func testPostgresProcessorCfg(opts ...option) stream.ProcessorConfig { cfg := stream.ProcessorConfig{ Postgres: &stream.PostgresProcessorConfig{ diff --git a/pkg/stream/integration/pg_pg_object_type_filter_integration_test.go b/pkg/stream/integration/pg_pg_object_type_filter_integration_test.go new file mode 100644 index 00000000..98893ce4 --- /dev/null +++ b/pkg/stream/integration/pg_pg_object_type_filter_integration_test.go @@ -0,0 +1,255 @@ +// SPDX-License-Identifier: Apache-2.0 + +package integration + +import ( + "context" + "os" + "testing" + "time" + + "github.com/stretchr/testify/require" + pglib "github.com/xataio/pgstream/internal/postgres" + "github.com/xataio/pgstream/internal/testcontainers" + "github.com/xataio/pgstream/pkg/stream" +) + +func Test_PostgresToPostgres_ObjectTypeFilter(t *testing.T) { + if os.Getenv("PGSTREAM_INTEGRATION_TESTS") == "" { + t.Skip("skipping integration test...") + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Spin up a separate source Postgres container + var sourcePGURL string + pgcleanup, err := testcontainers.SetupPostgresContainer(ctx, &sourcePGURL, testcontainers.Postgres14, "config/postgresql.conf") + require.NoError(t, err) + defer pgcleanup() + + // Load fixture SQL into the source database + fixtureSQL, err := os.ReadFile("testdata/object_types_fixture.sql") + require.NoError(t, err) + execQueryWithURL(t, ctx, sourcePGURL, string(fixtureSQL)) + + // Configure pgstream with object type filtering for both snapshot and DDL replication + includeTypes := []string{"tables", "sequences", "types"} + tables := []string{"app.*", "analytics.*"} + + cfg := &stream.Config{ + Listener: testPostgresListenerCfgWithSnapshotAndFilter(sourcePGURL, targetPGURL, tables, includeTypes), + Processor: testPostgresProcessorCfg(withDDLObjectTypeFilter(includeTypes)), + } + + initStream(t, ctx, sourcePGURL) + runStream(t, ctx, cfg) + + targetConn, err := pglib.NewConn(ctx, targetPGURL) + require.NoError(t, err) + defer targetConn.Close(ctx) + + sourceConn, err := pglib.NewConn(ctx, sourcePGURL) + require.NoError(t, err) + defer sourceConn.Close(ctx) + + // ========================================================================= + // Part 1: Snapshot assertions + // ========================================================================= + + // --- Should exist on target (poll with timeout) --- + + // Tables + require.Eventually(t, func() bool { + return tableExists(ctx, targetConn, "app", "users") && + tableExists(ctx, targetConn, "app", "posts") && + tableExists(ctx, targetConn, "app", "categories") && + tableExists(ctx, targetConn, "analytics", "page_views") && + tableExists(ctx, targetConn, "analytics", "daily_stats") + }, 30*time.Second, time.Second, "expected tables to exist on target after snapshot") + + // Types + require.Eventually(t, func() bool { + return pgTypeExists(ctx, targetConn, "app", "status") && + pgTypeExists(ctx, targetConn, "app", "address") && + pgTypeExists(ctx, targetConn, "app", "email") && + pgTypeExists(ctx, targetConn, "app", "positive_int") + }, 20*time.Second, time.Second, "expected types to exist on target after snapshot") + + // Sequences + require.Eventually(t, func() bool { + return sequenceExists(ctx, targetConn, "app", "invoice_number_seq") + }, 20*time.Second, time.Second, "expected sequence to exist on target after snapshot") + + // Data + require.Eventually(t, func() bool { + count, err := rowCount(ctx, targetConn, "app.users") + return err == nil && count == 5 + }, 20*time.Second, time.Second, "expected 5 rows in app.users") + + require.Eventually(t, func() bool { + count, err := rowCount(ctx, targetConn, "app.posts") + return err == nil && count == 6 + }, 20*time.Second, time.Second, "expected 6 rows in app.posts") + + // --- Should NOT exist on target --- + + // Functions + require.Never(t, func() bool { + return functionExists(ctx, targetConn, "app", "slugify") || + functionExists(ctx, targetConn, "app", "get_post_comment_count") + }, 5*time.Second, time.Second, "functions should not exist on target") + + // Views + require.Never(t, func() bool { + return viewExists(ctx, targetConn, "app", "published_posts") || + viewExists(ctx, targetConn, "app", "user_stats") + }, 5*time.Second, time.Second, "views should not exist on target") + + // Non-PK indexes + require.Never(t, func() bool { + return indexExists(ctx, targetConn, "app", "idx_users_email") || + indexExists(ctx, targetConn, "app", "idx_posts_slug") + }, 5*time.Second, time.Second, "non-PK indexes should not exist on target") + + // Materialized views + require.Never(t, func() bool { + return matviewExists(ctx, targetConn, "analytics", "top_posts") + }, 5*time.Second, time.Second, "materialized views should not exist on target") + + // ========================================================================= + // Part 2: DDL replication assertions + // ========================================================================= + + // --- Should replicate --- + + // CREATE TABLE + execQueryWithURL(t, ctx, sourcePGURL, "CREATE TABLE app.filter_test(id serial PRIMARY KEY, val text)") + require.Eventually(t, func() bool { + return tableExists(ctx, targetConn, "app", "filter_test") + }, 20*time.Second, time.Second, "expected app.filter_test table to replicate") + + // ALTER TABLE ADD COLUMN + execQueryWithURL(t, ctx, sourcePGURL, "ALTER TABLE app.filter_test ADD COLUMN extra int") + require.Eventually(t, func() bool { + return columnExists(ctx, targetConn, "app", "filter_test", "extra") + }, 20*time.Second, time.Second, "expected extra column to replicate") + + // INSERT DATA + execQueryWithURL(t, ctx, sourcePGURL, "INSERT INTO app.filter_test(val) VALUES('hello')") + require.Eventually(t, func() bool { + count, err := rowCount(ctx, targetConn, "app.filter_test") + return err == nil && count == 1 + }, 20*time.Second, time.Second, "expected data to replicate") + + // CREATE TYPE + execQueryWithURL(t, ctx, sourcePGURL, "CREATE TYPE app.priority AS ENUM ('low','medium','high')") + require.Eventually(t, func() bool { + return pgTypeExists(ctx, targetConn, "app", "priority") + }, 20*time.Second, time.Second, "expected app.priority type to replicate") + + // --- Should NOT replicate --- + + // CREATE FUNCTION + execQueryWithURL(t, ctx, sourcePGURL, `CREATE FUNCTION app.filter_test_fn() RETURNS text AS $$ SELECT 'test'::text; $$ LANGUAGE sql`) + require.Never(t, func() bool { + return functionExists(ctx, targetConn, "app", "filter_test_fn") + }, 5*time.Second, time.Second, "function should not replicate") + + // CREATE VIEW + execQueryWithURL(t, ctx, sourcePGURL, "CREATE VIEW app.filter_test_view AS SELECT id, val FROM app.filter_test") + require.Never(t, func() bool { + return viewExists(ctx, targetConn, "app", "filter_test_view") + }, 5*time.Second, time.Second, "view should not replicate") + + // CREATE INDEX + execQueryWithURL(t, ctx, sourcePGURL, "CREATE INDEX filter_test_idx ON app.filter_test(val)") + require.Never(t, func() bool { + return indexExists(ctx, targetConn, "app", "filter_test_idx") + }, 5*time.Second, time.Second, "index should not replicate") +} + +// --- Catalog query helpers --- + +func tableExists(ctx context.Context, conn pglib.Querier, schema, table string) bool { + var exists bool + err := conn.QueryRow(ctx, []any{&exists}, + `SELECT EXISTS(SELECT 1 FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2)`, + schema, table) + return err == nil && exists +} + +func pgTypeExists(ctx context.Context, conn pglib.Querier, schema, typeName string) bool { + var exists bool + err := conn.QueryRow(ctx, []any{&exists}, + `SELECT EXISTS( + SELECT 1 FROM pg_type t + JOIN pg_namespace n ON n.oid = t.typnamespace + WHERE n.nspname = $1 AND t.typname = $2 + )`, + schema, typeName) + return err == nil && exists +} + +func sequenceExists(ctx context.Context, conn pglib.Querier, schema, seqName string) bool { + var exists bool + err := conn.QueryRow(ctx, []any{&exists}, + `SELECT EXISTS(SELECT 1 FROM information_schema.sequences WHERE sequence_schema = $1 AND sequence_name = $2)`, + schema, seqName) + return err == nil && exists +} + +func functionExists(ctx context.Context, conn pglib.Querier, schema, funcName string) bool { + var exists bool + err := conn.QueryRow(ctx, []any{&exists}, + `SELECT EXISTS( + SELECT 1 FROM pg_proc p + JOIN pg_namespace n ON n.oid = p.pronamespace + WHERE n.nspname = $1 AND p.proname = $2 + )`, + schema, funcName) + return err == nil && exists +} + +func viewExists(ctx context.Context, conn pglib.Querier, schema, viewName string) bool { + var exists bool + err := conn.QueryRow(ctx, []any{&exists}, + `SELECT EXISTS(SELECT 1 FROM information_schema.views WHERE table_schema = $1 AND table_name = $2)`, + schema, viewName) + return err == nil && exists +} + +func indexExists(ctx context.Context, conn pglib.Querier, schema, indexName string) bool { + var exists bool + err := conn.QueryRow(ctx, []any{&exists}, + `SELECT EXISTS(SELECT 1 FROM pg_indexes WHERE schemaname = $1 AND indexname = $2)`, + schema, indexName) + return err == nil && exists +} + +func matviewExists(ctx context.Context, conn pglib.Querier, schema, matviewName string) bool { + var exists bool + err := conn.QueryRow(ctx, []any{&exists}, + `SELECT EXISTS( + SELECT 1 FROM pg_matviews WHERE schemaname = $1 AND matviewname = $2 + )`, + schema, matviewName) + return err == nil && exists +} + +func columnExists(ctx context.Context, conn pglib.Querier, schema, table, column string) bool { + var exists bool + err := conn.QueryRow(ctx, []any{&exists}, + `SELECT EXISTS( + SELECT 1 FROM information_schema.columns + WHERE table_schema = $1 AND table_name = $2 AND column_name = $3 + )`, + schema, table, column) + return err == nil && exists +} + +func rowCount(ctx context.Context, conn pglib.Querier, qualifiedTable string) (int, error) { + var count int + err := conn.QueryRow(ctx, []any{&count}, "SELECT count(*) FROM "+qualifiedTable) + return count, err +} diff --git a/pkg/stream/integration/testdata/object_types_fixture.sql b/pkg/stream/integration/testdata/object_types_fixture.sql new file mode 100644 index 00000000..77ffe75b --- /dev/null +++ b/pkg/stream/integration/testdata/object_types_fixture.sql @@ -0,0 +1,414 @@ +-- Fixture for object type filtering integration test +-- Creates objects across all filterable categories +-- Uses gen_random_uuid() instead of uuid_generate_v4() to avoid uuid-ossp dependency + +BEGIN; + +-- ============================================================================= +-- SCHEMA +-- ============================================================================= +CREATE SCHEMA IF NOT EXISTS app; +CREATE SCHEMA IF NOT EXISTS analytics; + +-- ============================================================================= +-- EXTENSIONS +-- ============================================================================= +CREATE EXTENSION IF NOT EXISTS pg_trgm WITH SCHEMA public; + +-- ============================================================================= +-- TYPES (TYPE, DOMAIN) +-- ============================================================================= +CREATE TYPE app.status AS ENUM ('active', 'inactive', 'suspended', 'deleted'); + +CREATE TYPE app.address AS ( + street TEXT, + city TEXT, + state TEXT, + zip TEXT +); + +CREATE DOMAIN app.email AS TEXT + CHECK (VALUE ~ '^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'); + +CREATE DOMAIN app.positive_int AS INTEGER + CHECK (VALUE > 0); + +-- ============================================================================= +-- TABLES (TABLE, DEFAULT) +-- ============================================================================= +CREATE TABLE app.users ( + id SERIAL PRIMARY KEY, + uuid UUID DEFAULT gen_random_uuid() NOT NULL, + username VARCHAR(100) NOT NULL, + email app.email NOT NULL, + status app.status DEFAULT 'active' NOT NULL, + address app.address, + score app.positive_int DEFAULT 1, + bio TEXT, + created_at TIMESTAMPTZ DEFAULT now() NOT NULL, + updated_at TIMESTAMPTZ DEFAULT now() NOT NULL, + CONSTRAINT users_username_length CHECK (char_length(username) >= 2) +); + +CREATE TABLE app.categories ( + id SERIAL PRIMARY KEY, + name VARCHAR(200) NOT NULL, + slug VARCHAR(200) NOT NULL, + parent_id INTEGER REFERENCES app.categories(id), + sort_order INTEGER DEFAULT 0, + created_at TIMESTAMPTZ DEFAULT now() NOT NULL +); + +CREATE TABLE app.posts ( + id SERIAL PRIMARY KEY, + author_id INTEGER NOT NULL REFERENCES app.users(id), + category_id INTEGER REFERENCES app.categories(id), + title VARCHAR(500) NOT NULL, + slug VARCHAR(500) NOT NULL, + body TEXT NOT NULL, + is_draft BOOLEAN DEFAULT true NOT NULL, + view_count INTEGER DEFAULT 0 NOT NULL, + published_at TIMESTAMPTZ, + created_at TIMESTAMPTZ DEFAULT now() NOT NULL, + updated_at TIMESTAMPTZ DEFAULT now() NOT NULL +); + +CREATE TABLE app.tags ( + id SERIAL PRIMARY KEY, + name VARCHAR(100) NOT NULL UNIQUE +); + +CREATE TABLE app.post_tags ( + post_id INTEGER NOT NULL REFERENCES app.posts(id) ON DELETE CASCADE, + tag_id INTEGER NOT NULL REFERENCES app.tags(id) ON DELETE CASCADE, + PRIMARY KEY (post_id, tag_id) +); + +CREATE TABLE app.comments ( + id SERIAL PRIMARY KEY, + post_id INTEGER NOT NULL REFERENCES app.posts(id) ON DELETE CASCADE, + user_id INTEGER NOT NULL REFERENCES app.users(id), + parent_id INTEGER REFERENCES app.comments(id), + body TEXT NOT NULL, + created_at TIMESTAMPTZ DEFAULT now() NOT NULL +); + +CREATE TABLE app.audit_log ( + id BIGSERIAL PRIMARY KEY, + table_name TEXT NOT NULL, + record_id INTEGER NOT NULL, + action TEXT NOT NULL, + old_data JSONB, + new_data JSONB, + changed_by INTEGER REFERENCES app.users(id), + changed_at TIMESTAMPTZ DEFAULT now() NOT NULL +); + +CREATE TABLE analytics.page_views ( + id BIGSERIAL PRIMARY KEY, + path TEXT NOT NULL, + user_id INTEGER, + referrer TEXT, + user_agent TEXT, + viewed_at TIMESTAMPTZ DEFAULT now() NOT NULL +); + +CREATE TABLE analytics.daily_stats ( + stat_date DATE NOT NULL, + total_views INTEGER DEFAULT 0, + total_users INTEGER DEFAULT 0, + total_posts INTEGER DEFAULT 0, + PRIMARY KEY (stat_date) +); + +-- ============================================================================= +-- SEQUENCES (beyond the implicit SERIAL ones) +-- ============================================================================= +CREATE SEQUENCE app.invoice_number_seq + START WITH 1000 + INCREMENT BY 1 + NO MINVALUE + NO MAXVALUE + CACHE 10; + +-- ============================================================================= +-- FUNCTIONS +-- ============================================================================= +CREATE OR REPLACE FUNCTION app.update_updated_at() +RETURNS TRIGGER AS $$ +BEGIN + NEW.updated_at = now(); + RETURN NEW; +END; +$$ LANGUAGE plpgsql; + +CREATE OR REPLACE FUNCTION app.audit_changes() +RETURNS TRIGGER AS $$ +BEGIN + IF TG_OP = 'INSERT' THEN + INSERT INTO app.audit_log (table_name, record_id, action, new_data, changed_by) + VALUES (TG_TABLE_NAME, NEW.id, 'INSERT', to_jsonb(NEW), NEW.id); + RETURN NEW; + ELSIF TG_OP = 'UPDATE' THEN + INSERT INTO app.audit_log (table_name, record_id, action, old_data, new_data, changed_by) + VALUES (TG_TABLE_NAME, NEW.id, 'UPDATE', to_jsonb(OLD), to_jsonb(NEW), NEW.id); + RETURN NEW; + ELSIF TG_OP = 'DELETE' THEN + INSERT INTO app.audit_log (table_name, record_id, action, old_data) + VALUES (TG_TABLE_NAME, OLD.id, 'DELETE', to_jsonb(OLD)); + RETURN OLD; + END IF; + RETURN NULL; +END; +$$ LANGUAGE plpgsql; + +CREATE OR REPLACE FUNCTION app.slugify(input TEXT) +RETURNS TEXT AS $$ +BEGIN + RETURN lower(regexp_replace(regexp_replace(input, '[^a-zA-Z0-9\s-]', '', 'g'), '\s+', '-', 'g')); +END; +$$ LANGUAGE plpgsql IMMUTABLE; + +CREATE OR REPLACE FUNCTION app.get_post_comment_count(p_post_id INTEGER) +RETURNS INTEGER AS $$ + SELECT count(*)::integer FROM app.comments WHERE post_id = p_post_id; +$$ LANGUAGE sql STABLE; + +CREATE OR REPLACE FUNCTION analytics.increment_view(p_path TEXT, p_user_id INTEGER DEFAULT NULL) +RETURNS void AS $$ +BEGIN + INSERT INTO analytics.page_views (path, user_id) VALUES (p_path, p_user_id); + INSERT INTO analytics.daily_stats (stat_date, total_views) + VALUES (current_date, 1) + ON CONFLICT (stat_date) DO UPDATE SET total_views = analytics.daily_stats.total_views + 1; +END; +$$ LANGUAGE plpgsql; + +-- ============================================================================= +-- VIEWS +-- ============================================================================= +CREATE VIEW app.published_posts AS + SELECT p.id, p.title, p.slug, p.body, p.view_count, p.published_at, + u.username AS author, c.name AS category + FROM app.posts p + JOIN app.users u ON u.id = p.author_id + LEFT JOIN app.categories c ON c.id = p.category_id + WHERE p.is_draft = false AND p.published_at IS NOT NULL; + +CREATE VIEW app.user_stats AS + SELECT u.id, u.username, + count(DISTINCT p.id) AS post_count, + count(DISTINCT cm.id) AS comment_count, + coalesce(sum(p.view_count), 0) AS total_views + FROM app.users u + LEFT JOIN app.posts p ON p.author_id = u.id + LEFT JOIN app.comments cm ON cm.user_id = u.id + GROUP BY u.id, u.username; + +-- ============================================================================= +-- MATERIALIZED VIEWS +-- ============================================================================= +CREATE MATERIALIZED VIEW analytics.top_posts AS + SELECT p.id, p.title, p.slug, u.username AS author, + p.view_count, count(c.id) AS comment_count, + p.published_at + FROM app.posts p + JOIN app.users u ON u.id = p.author_id + LEFT JOIN app.comments c ON c.post_id = p.id + WHERE p.is_draft = false + GROUP BY p.id, p.title, p.slug, u.username, p.view_count, p.published_at + ORDER BY p.view_count DESC +WITH DATA; + +-- ============================================================================= +-- INDEXES +-- ============================================================================= +CREATE INDEX idx_users_username ON app.users (username); +CREATE INDEX idx_users_email ON app.users (email); +CREATE INDEX idx_users_status ON app.users (status); +CREATE INDEX idx_users_username_trgm ON app.users USING gin (username gin_trgm_ops); + +CREATE UNIQUE INDEX idx_categories_slug ON app.categories (slug); +CREATE INDEX idx_categories_parent ON app.categories (parent_id); + +CREATE UNIQUE INDEX idx_posts_slug ON app.posts (slug); +CREATE INDEX idx_posts_author ON app.posts (author_id); +CREATE INDEX idx_posts_category ON app.posts (category_id); +CREATE INDEX idx_posts_published ON app.posts (published_at) WHERE is_draft = false; +CREATE INDEX idx_posts_body_search ON app.posts USING gin (to_tsvector('english', body)); + +CREATE INDEX idx_comments_post ON app.comments (post_id); +CREATE INDEX idx_comments_user ON app.comments (user_id); +CREATE INDEX idx_comments_parent ON app.comments (parent_id); + +CREATE INDEX idx_audit_table_record ON app.audit_log (table_name, record_id); +CREATE INDEX idx_audit_changed_at ON app.audit_log (changed_at); + +CREATE INDEX idx_page_views_path ON analytics.page_views (path); +CREATE INDEX idx_page_views_time ON analytics.page_views (viewed_at); + +CREATE UNIQUE INDEX idx_top_posts_id ON analytics.top_posts (id); + +-- ============================================================================= +-- TRIGGERS +-- ============================================================================= +CREATE TRIGGER trg_users_updated_at + BEFORE UPDATE ON app.users + FOR EACH ROW EXECUTE FUNCTION app.update_updated_at(); + +CREATE TRIGGER trg_posts_updated_at + BEFORE UPDATE ON app.posts + FOR EACH ROW EXECUTE FUNCTION app.update_updated_at(); + +CREATE TRIGGER trg_users_audit + AFTER INSERT OR UPDATE OR DELETE ON app.users + FOR EACH ROW EXECUTE FUNCTION app.audit_changes(); + +-- ============================================================================= +-- RULES +-- ============================================================================= +CREATE RULE protect_audit_delete AS + ON DELETE TO app.audit_log + DO INSTEAD NOTHING; + +CREATE RULE protect_audit_update AS + ON UPDATE TO app.audit_log + DO INSTEAD NOTHING; + +-- ============================================================================= +-- POLICIES (Row Level Security) +-- ============================================================================= +ALTER TABLE app.posts ENABLE ROW LEVEL SECURITY; + +CREATE POLICY posts_select_published ON app.posts + FOR SELECT + USING (is_draft = false OR author_id = current_setting('app.current_user_id', true)::integer); + +CREATE POLICY posts_modify_own ON app.posts + FOR ALL + USING (author_id = current_setting('app.current_user_id', true)::integer) + WITH CHECK (author_id = current_setting('app.current_user_id', true)::integer); + +-- ============================================================================= +-- TEXT SEARCH +-- ============================================================================= +CREATE TEXT SEARCH CONFIGURATION app.english_unaccent (COPY = pg_catalog.english); + +-- ============================================================================= +-- COMMENTS (on objects) +-- ============================================================================= +COMMENT ON TABLE app.users IS 'Application users'; +COMMENT ON TABLE app.posts IS 'Blog posts written by users'; +COMMENT ON COLUMN app.users.score IS 'User reputation score, must be positive'; +COMMENT ON FUNCTION app.slugify(TEXT) IS 'Convert text to URL-friendly slug'; +COMMENT ON INDEX app.idx_posts_published IS 'Partial index for fast published post lookups'; +COMMENT ON VIEW app.published_posts IS 'Published posts with author and category info'; + +-- ============================================================================= +-- DATA +-- ============================================================================= + +-- Users +INSERT INTO app.users (username, email, status, bio, score) VALUES + ('alice', 'alice@example.com', 'active', 'Software engineer and blogger', 42), + ('bob', 'bob@example.com', 'active', 'DevOps enthusiast', 28), + ('charlie', 'charlie@example.com', 'active', 'Full-stack developer', 35), + ('diana', 'diana@example.com', 'inactive', 'Data scientist on sabbatical', 19), + ('eve', 'eve@example.com', 'suspended', NULL, 5); + +-- Categories +INSERT INTO app.categories (name, slug, parent_id, sort_order) VALUES + ('Technology', 'technology', NULL, 1), + ('Programming', 'programming', 1, 1), + ('Databases', 'databases', 1, 2), + ('DevOps', 'devops', 1, 3), + ('Lifestyle', 'lifestyle', NULL, 2), + ('Productivity', 'productivity', 5, 1); + +-- Tags +INSERT INTO app.tags (name) VALUES + ('postgresql'), ('go'), ('docker'), ('kubernetes'), ('tutorial'), + ('beginner'), ('advanced'), ('opinion'), ('howto'), ('announcement'); + +-- Posts +INSERT INTO app.posts (author_id, category_id, title, slug, body, is_draft, view_count, published_at) VALUES + (1, 2, 'Getting Started with Go', + 'getting-started-with-go', + 'Go is a statically typed, compiled language designed at Google.', + false, 1523, now() - interval '30 days'), + + (1, 3, 'PostgreSQL Change Data Capture with pgstream', + 'postgresql-cdc-pgstream', + 'pgstream is a powerful CDC tool for PostgreSQL.', + false, 2847, now() - interval '14 days'), + + (2, 4, 'Docker Compose for Local Development', + 'docker-compose-local-dev', + 'Docker Compose simplifies multi-container development.', + false, 956, now() - interval '7 days'), + + (3, 2, 'Understanding Interfaces in Go', + 'understanding-interfaces-go', + 'Interfaces are one of the most powerful features of Go.', + false, 1102, now() - interval '3 days'), + + (3, 3, 'Advanced PostgreSQL Indexing Strategies', + 'advanced-pg-indexing', + 'Proper indexing is crucial for database performance.', + true, 0, NULL), + + (1, 6, 'My Productivity Setup in 2025', + 'productivity-setup-2025', + 'After years of experimenting, here is my current productivity stack.', + false, 445, now() - interval '1 day'); + +-- Post tags +INSERT INTO app.post_tags (post_id, tag_id) VALUES + (1, 2), (1, 5), (1, 6), + (2, 1), (2, 5), (2, 9), + (3, 3), (3, 4), (3, 5), + (4, 2), (4, 7), + (5, 1), (5, 7), + (6, 8); + +-- Comments +INSERT INTO app.comments (post_id, user_id, parent_id, body) VALUES + (1, 2, NULL, 'Great introduction!'), + (1, 3, 1, 'Same here.'), + (1, 4, NULL, 'Could you do a follow-up?'), + (2, 3, NULL, 'pgstream looks promising.'), + (2, 1, 4, 'Schema changes are captured via event triggers.'), + (2, 2, NULL, 'We have been using this in production.'), + (3, 1, NULL, 'Nice setup!'), + (4, 1, NULL, 'The implicit interface satisfaction is elegant.'), + (4, 2, 8, 'Agreed.'); + +-- Page views +INSERT INTO analytics.page_views (path, user_id, viewed_at) VALUES + ('/posts/getting-started-with-go', 1, now() - interval '2 days'), + ('/posts/getting-started-with-go', 2, now() - interval '2 days'), + ('/posts/getting-started-with-go', NULL, now() - interval '1 day'), + ('/posts/postgresql-cdc-pgstream', 3, now() - interval '1 day'), + ('/posts/postgresql-cdc-pgstream', NULL, now() - interval '6 hours'), + ('/posts/docker-compose-local-dev', 2, now() - interval '3 hours'), + ('/posts/understanding-interfaces-go', 1, now() - interval '1 hour'), + ('/posts/productivity-setup-2025', 4, now() - interval '30 minutes'); + +-- Daily stats +INSERT INTO analytics.daily_stats (stat_date, total_views, total_users, total_posts) VALUES + (current_date - 7, 142, 45, 1), + (current_date - 6, 198, 62, 0), + (current_date - 5, 167, 51, 0), + (current_date - 4, 203, 68, 1), + (current_date - 3, 189, 59, 0), + (current_date - 2, 256, 82, 1), + (current_date - 1, 312, 95, 1), + (current_date, 87, 31, 0); + +-- Set a sequence value +SELECT setval('app.invoice_number_seq', 1042); + +-- Refresh the materialized view with the inserted data +REFRESH MATERIALIZED VIEW analytics.top_posts; + +COMMIT; diff --git a/pkg/wal/processor/postgres/config.go b/pkg/wal/processor/postgres/config.go index 2c5fcd0e..e7aa2afa 100644 --- a/pkg/wal/processor/postgres/config.go +++ b/pkg/wal/processor/postgres/config.go @@ -17,6 +17,14 @@ type Config struct { BulkIngestEnabled bool RetryPolicy backoff.Config IgnoreDDL bool + // IncludeDDLObjectTypes is a list of object type categories for which + // DDL should be replicated. Only one of IncludeDDLObjectTypes or + // ExcludeDDLObjectTypes can be set. Ignored if IgnoreDDL is true. + IncludeDDLObjectTypes []string + // ExcludeDDLObjectTypes is a list of object type categories for which + // DDL should be skipped. Only one of IncludeDDLObjectTypes or + // ExcludeDDLObjectTypes can be set. Ignored if IgnoreDDL is true. + ExcludeDDLObjectTypes []string } const ( diff --git a/pkg/wal/processor/postgres/ddl_object_type_filter.go b/pkg/wal/processor/postgres/ddl_object_type_filter.go new file mode 100644 index 00000000..1f34a4c2 --- /dev/null +++ b/pkg/wal/processor/postgres/ddl_object_type_filter.go @@ -0,0 +1,175 @@ +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import ( + "fmt" + "strings" + + "github.com/xataio/pgstream/pkg/wal" +) + +// ddlObjectTypeCategories maps user-facing category names to DDL object type +// values as reported by pg_event_trigger_ddl_commands() and +// pg_event_trigger_dropped_objects(). +var ddlObjectTypeCategories = map[string][]string{ + "tables": {"table", "table column"}, + "sequences": {"sequence"}, + "types": {"type", "domain"}, + "indexes": {"index"}, + "constraints": {"table constraint"}, + "functions": {"function", "aggregate", "procedure"}, + "views": {"view"}, + "materialized_views": {"materialized view"}, + "triggers": {"trigger"}, + "event_triggers": {"event trigger"}, + "policies": {"policy"}, + "rules": {"rule"}, + "extensions": {"extension"}, + "collations": {"collation"}, + "text_search": {"text search configuration", "text search dictionary", "text search parser", "text search template"}, +} + +// commandTagCategories maps command tag prefixes to categories for fallback +// when DDL events have no objects. +var commandTagCategories = map[string]string{ + "CREATE TABLE": "tables", + "ALTER TABLE": "tables", + "DROP TABLE": "tables", + "CREATE SEQUENCE": "sequences", + "ALTER SEQUENCE": "sequences", + "DROP SEQUENCE": "sequences", + "CREATE TYPE": "types", + "ALTER TYPE": "types", + "DROP TYPE": "types", + "CREATE DOMAIN": "types", + "ALTER DOMAIN": "types", + "DROP DOMAIN": "types", + "CREATE INDEX": "indexes", + "ALTER INDEX": "indexes", + "DROP INDEX": "indexes", + "CREATE FUNCTION": "functions", + "ALTER FUNCTION": "functions", + "DROP FUNCTION": "functions", + "CREATE AGGREGATE": "functions", + "DROP AGGREGATE": "functions", + "CREATE PROCEDURE": "functions", + "ALTER PROCEDURE": "functions", + "DROP PROCEDURE": "functions", + "CREATE VIEW": "views", + "ALTER VIEW": "views", + "DROP VIEW": "views", + "CREATE MATERIALIZED VIEW": "materialized_views", + "ALTER MATERIALIZED VIEW": "materialized_views", + "DROP MATERIALIZED VIEW": "materialized_views", + "CREATE TRIGGER": "triggers", + "ALTER TRIGGER": "triggers", + "DROP TRIGGER": "triggers", + "CREATE EVENT TRIGGER": "event_triggers", + "ALTER EVENT TRIGGER": "event_triggers", + "DROP EVENT TRIGGER": "event_triggers", + "CREATE POLICY": "policies", + "ALTER POLICY": "policies", + "DROP POLICY": "policies", + "CREATE RULE": "rules", + "ALTER RULE": "rules", + "DROP RULE": "rules", + "CREATE EXTENSION": "extensions", + "ALTER EXTENSION": "extensions", + "DROP EXTENSION": "extensions", + "CREATE COLLATION": "collations", + "ALTER COLLATION": "collations", + "DROP COLLATION": "collations", + "CREATE TEXT SEARCH CONFIGURATION": "text_search", + "ALTER TEXT SEARCH CONFIGURATION": "text_search", + "DROP TEXT SEARCH CONFIGURATION": "text_search", +} + +// ddlObjectTypeFilter determines which DDL events should be skipped based on +// user-specified include or exclude category lists. +type ddlObjectTypeFilter struct { + excludedDDLTypes map[string]struct{} + excludedCategories map[string]struct{} +} + +// newDDLObjectTypeFilter creates a ddlObjectTypeFilter from include/exclude +// category lists. Only one of include or exclude can be set (not both). +// Returns nil if neither is set (no filtering). +func newDDLObjectTypeFilter(include, exclude []string) (*ddlObjectTypeFilter, error) { + if len(include) > 0 && len(exclude) > 0 { + return nil, fmt.Errorf("include_ddl_object_types and exclude_ddl_object_types cannot both be set") + } + + if len(include) == 0 && len(exclude) == 0 { + return nil, nil + } + + f := &ddlObjectTypeFilter{ + excludedDDLTypes: make(map[string]struct{}), + excludedCategories: make(map[string]struct{}), + } + + if len(include) > 0 { + includedSet := make(map[string]struct{}, len(include)) + for _, cat := range include { + if _, ok := ddlObjectTypeCategories[cat]; !ok { + return nil, fmt.Errorf("unknown DDL object type category: %q", cat) + } + includedSet[cat] = struct{}{} + } + for cat, types := range ddlObjectTypeCategories { + if _, included := includedSet[cat]; !included { + f.excludedCategories[cat] = struct{}{} + for _, t := range types { + f.excludedDDLTypes[t] = struct{}{} + } + } + } + } else { + for _, cat := range exclude { + types, ok := ddlObjectTypeCategories[cat] + if !ok { + return nil, fmt.Errorf("unknown DDL object type category: %q", cat) + } + f.excludedCategories[cat] = struct{}{} + for _, t := range types { + f.excludedDDLTypes[t] = struct{}{} + } + } + } + + return f, nil +} + +// shouldSkipDDL returns true if the DDL event should be skipped based on the +// object types involved. +func (f *ddlObjectTypeFilter) shouldSkipDDL(ddlEvent *wal.DDLEvent) bool { + if f == nil || ddlEvent == nil { + return false + } + + // If there are objects, check their types. A single DDL statement (e.g., + // CREATE TABLE with PRIMARY KEY) can produce multiple objects of different + // types (table + index). Only skip if ALL objects are of excluded types; + // if any object is of an included type, the event should be executed. + if len(ddlEvent.Objects) > 0 { + for _, obj := range ddlEvent.Objects { + objType := strings.ToLower(obj.Type) + if _, excluded := f.excludedDDLTypes[objType]; !excluded { + return false + } + } + return true + } + + // Fallback: parse the command tag + tag := ddlEvent.CommandTag + for prefix, cat := range commandTagCategories { + if strings.HasPrefix(tag, prefix) { + _, excluded := f.excludedCategories[cat] + return excluded + } + } + + return false +} diff --git a/pkg/wal/processor/postgres/ddl_object_type_filter_test.go b/pkg/wal/processor/postgres/ddl_object_type_filter_test.go new file mode 100644 index 00000000..7e43ec35 --- /dev/null +++ b/pkg/wal/processor/postgres/ddl_object_type_filter_test.go @@ -0,0 +1,209 @@ +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/xataio/pgstream/pkg/wal" +) + +func TestNewDDLObjectTypeFilter(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + include []string + exclude []string + + wantNil bool + wantErr bool + }{ + { + name: "no filter configured", + wantNil: true, + }, + { + name: "include only", + include: []string{"tables", "sequences", "types"}, + }, + { + name: "exclude only", + exclude: []string{"functions", "views"}, + }, + { + name: "both set - error", + include: []string{"tables"}, + exclude: []string{"functions"}, + wantErr: true, + }, + { + name: "unknown include category", + include: []string{"unknown_category"}, + wantErr: true, + }, + { + name: "unknown exclude category", + exclude: []string{"unknown_category"}, + wantErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + f, err := newDDLObjectTypeFilter(tc.include, tc.exclude) + if tc.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + + if tc.wantNil { + require.Nil(t, f) + return + } + + require.NotNil(t, f) + }) + } +} + +func TestDDLObjectTypeFilter_ShouldSkipDDL(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + include []string + exclude []string + ddlEvent *wal.DDLEvent + + wantSkip bool + }{ + { + name: "nil filter - no skip", + ddlEvent: &wal.DDLEvent{ + Objects: []wal.DDLObject{{Type: "function"}}, + }, + wantSkip: false, + }, + { + name: "nil ddl event - no skip", + exclude: []string{"functions"}, + wantSkip: false, + }, + { + name: "exclude functions - function object skipped", + exclude: []string{"functions"}, + ddlEvent: &wal.DDLEvent{ + CommandTag: "CREATE FUNCTION", + Objects: []wal.DDLObject{{Type: "function", Identity: "public.my_func", Schema: "public"}}, + }, + wantSkip: true, + }, + { + name: "exclude functions - aggregate object skipped", + exclude: []string{"functions"}, + ddlEvent: &wal.DDLEvent{ + CommandTag: "CREATE AGGREGATE", + Objects: []wal.DDLObject{{Type: "aggregate", Identity: "public.my_agg", Schema: "public"}}, + }, + wantSkip: true, + }, + { + name: "exclude functions - table object not skipped", + exclude: []string{"functions"}, + ddlEvent: &wal.DDLEvent{ + CommandTag: "CREATE TABLE", + Objects: []wal.DDLObject{{Type: "table", Identity: "public.users", Schema: "public"}}, + }, + wantSkip: false, + }, + { + name: "include tables only - function skipped", + include: []string{"tables", "sequences"}, + ddlEvent: &wal.DDLEvent{ + CommandTag: "CREATE FUNCTION", + Objects: []wal.DDLObject{{Type: "function", Identity: "public.my_func", Schema: "public"}}, + }, + wantSkip: true, + }, + { + name: "include tables only - table not skipped", + include: []string{"tables", "sequences"}, + ddlEvent: &wal.DDLEvent{ + CommandTag: "CREATE TABLE", + Objects: []wal.DDLObject{{Type: "table", Identity: "public.users", Schema: "public"}}, + }, + wantSkip: false, + }, + { + name: "include tables - mixed objects with table and index - not skipped", + include: []string{"tables", "sequences"}, + ddlEvent: &wal.DDLEvent{ + CommandTag: "CREATE TABLE", + Objects: []wal.DDLObject{ + {Type: "table", Identity: "public.users", Schema: "public"}, + {Type: "index", Identity: "public.users_pkey", Schema: "public"}, + }, + }, + wantSkip: false, + }, + { + name: "include tables - all objects excluded - skipped", + include: []string{"tables", "sequences"}, + ddlEvent: &wal.DDLEvent{ + CommandTag: "CREATE INDEX", + Objects: []wal.DDLObject{ + {Type: "index", Identity: "public.users_name_idx", Schema: "public"}, + }, + }, + wantSkip: true, + }, + { + name: "fallback to command tag - no objects - function skipped", + exclude: []string{"functions"}, + ddlEvent: &wal.DDLEvent{ + CommandTag: "CREATE FUNCTION", + Objects: nil, + }, + wantSkip: true, + }, + { + name: "fallback to command tag - no objects - table not skipped", + exclude: []string{"functions"}, + ddlEvent: &wal.DDLEvent{ + CommandTag: "CREATE TABLE", + Objects: nil, + }, + wantSkip: false, + }, + { + name: "fallback to command tag - unknown tag - not skipped", + exclude: []string{"functions"}, + ddlEvent: &wal.DDLEvent{ + CommandTag: "SOME_UNKNOWN_COMMAND", + Objects: nil, + }, + wantSkip: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + var f *ddlObjectTypeFilter + if tc.include != nil || tc.exclude != nil { + var err error + f, err = newDDLObjectTypeFilter(tc.include, tc.exclude) + require.NoError(t, err) + } + + got := f.shouldSkipDDL(tc.ddlEvent) + require.Equal(t, tc.wantSkip, got) + }) + } +} diff --git a/pkg/wal/processor/postgres/postgres_wal_adapter.go b/pkg/wal/processor/postgres/postgres_wal_adapter.go index 1dd75452..8afc2238 100644 --- a/pkg/wal/processor/postgres/postgres_wal_adapter.go +++ b/pkg/wal/processor/postgres/postgres_wal_adapter.go @@ -4,6 +4,7 @@ package postgres import ( "context" + "fmt" loglib "github.com/xataio/pgstream/pkg/log" "github.com/xataio/pgstream/pkg/wal" @@ -36,38 +37,44 @@ type schemaInfo struct { } type adapter struct { - dmlAdapter dmlQueryAdapter - ddlAdapter ddlQueryAdapter - ddlEventAdapter ddlEventAdapter - - schemaObserver schemaObserver + dmlAdapter dmlQueryAdapter + ddlAdapter ddlQueryAdapter + ddlEventAdapter ddlEventAdapter + ddlObjectTypeFilter *ddlObjectTypeFilter + schemaObserver schemaObserver } type ( ddlEventAdapter func(*wal.Data) (*wal.DDLEvent, error) ) -func newAdapter(ctx context.Context, logger loglib.Logger, ignoreDDL bool, pgURL string, onConflictAction string, forCopy bool) (*adapter, error) { - schemaObserver, err := newPGSchemaObserver(ctx, pgURL, logger) +func newAdapter(ctx context.Context, logger loglib.Logger, config *Config, forCopy bool) (*adapter, error) { + schemaObserver, err := newPGSchemaObserver(ctx, config.URL, logger) if err != nil { return nil, err } - dmlAdapter, err := newDMLAdapter(onConflictAction, forCopy, logger) + dmlAdapter, err := newDMLAdapter(config.OnConflictAction, forCopy, logger) if err != nil { return nil, err } var ddl ddlQueryAdapter - if !ignoreDDL { + var ddlFilter *ddlObjectTypeFilter + if !config.IgnoreDDL { ddl = newDDLAdapter() + ddlFilter, err = newDDLObjectTypeFilter(config.IncludeDDLObjectTypes, config.ExcludeDDLObjectTypes) + if err != nil { + return nil, fmt.Errorf("invalid DDL object type filter config: %w", err) + } } return &adapter{ - dmlAdapter: dmlAdapter, - ddlAdapter: ddl, - schemaObserver: schemaObserver, - ddlEventAdapter: wal.WalDataToDDLEvent, + dmlAdapter: dmlAdapter, + ddlAdapter: ddl, + ddlObjectTypeFilter: ddlFilter, + schemaObserver: schemaObserver, + ddlEventAdapter: wal.WalDataToDDLEvent, }, nil } @@ -83,10 +90,11 @@ func (a *adapter) walEventToQueries(ctx context.Context, e *wal.Event) ([]*query if err != nil { return nil, err } + // always update the schema observer to keep internal cache correct a.schemaObserver.update(ddlEvent) - // there's no ddl adapter, the ddl query will not be processed - if a.ddlAdapter == nil { + // skip DDL execution if no adapter or if the DDL object type is filtered out + if a.ddlAdapter == nil || a.ddlObjectTypeFilter.shouldSkipDDL(ddlEvent) { return []*query{{}}, nil } diff --git a/pkg/wal/processor/postgres/postgres_writer.go b/pkg/wal/processor/postgres/postgres_writer.go index 8ef70694..69aa1c04 100644 --- a/pkg/wal/processor/postgres/postgres_writer.go +++ b/pkg/wal/processor/postgres/postgres_writer.go @@ -57,7 +57,7 @@ func newWriter(ctx context.Context, config *Config, writerType string, opts ...W forCopy := writerType == bulkIngestWriter - w.adapter, err = newAdapter(ctx, w.logger, config.IgnoreDDL, config.URL, config.OnConflictAction, forCopy) + w.adapter, err = newAdapter(ctx, w.logger, config, forCopy) if err != nil { return nil, err } From ab6b29408fcc861c8de74387ad12614e5785dc55 Mon Sep 17 00:00:00 2001 From: Tudor Golubenco Date: Sun, 22 Feb 2026 18:49:19 +0100 Subject: [PATCH 2/2] Tests fix --- .../pg_pg_object_type_filter_integration_test.go | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/pkg/stream/integration/pg_pg_object_type_filter_integration_test.go b/pkg/stream/integration/pg_pg_object_type_filter_integration_test.go index 98893ce4..d0b9e048 100644 --- a/pkg/stream/integration/pg_pg_object_type_filter_integration_test.go +++ b/pkg/stream/integration/pg_pg_object_type_filter_integration_test.go @@ -19,15 +19,18 @@ func Test_PostgresToPostgres_ObjectTypeFilter(t *testing.T) { t.Skip("skipping integration test...") } - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - // Spin up a separate source Postgres container + // Spin up a separate source Postgres container. + // Container setup must be deferred before context cancel so that + // cancel() runs first (LIFO), allowing the stream to shut down + // gracefully before the container is terminated. var sourcePGURL string - pgcleanup, err := testcontainers.SetupPostgresContainer(ctx, &sourcePGURL, testcontainers.Postgres14, "config/postgresql.conf") + pgcleanup, err := testcontainers.SetupPostgresContainer(context.Background(), &sourcePGURL, testcontainers.Postgres14, "config/postgresql.conf") require.NoError(t, err) defer pgcleanup() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + // Load fixture SQL into the source database fixtureSQL, err := os.ReadFile("testdata/object_types_fixture.sql") require.NoError(t, err)