From a46162319edfd827498b0f2da609c93390635fc8 Mon Sep 17 00:00:00 2001 From: BesikiML Date: Thu, 11 Jun 2026 16:33:41 -0400 Subject: [PATCH] refactor(reconcile): migrate internal config usage to source/target connection configs Replace DatabaseConfig usage across reconcile runtime paths with SourceConnectionConfig and TargetConnectionConfig, update service wiring and integration tests, and keep ReconCapture backward-compatible for legacy constructor callers during migration. --- .../lakebridge/reconcile/recon_capture.py | 81 ++++++++-- .../lakebridge/reconcile/reconciliation.py | 71 +++++---- .../trigger_recon_aggregate_service.py | 7 +- .../reconcile/trigger_recon_service.py | 26 ++-- .../reconcile/query_builder/test_execute.py | 59 +++---- .../test_aggregates_recon_capture.py | 17 +- .../reconcile/test_aggregates_reconcile.py | 28 ++-- .../reconcile/test_recon_capture.py | 147 +++++++----------- 8 files changed, 230 insertions(+), 206 deletions(-) diff --git a/src/databricks/labs/lakebridge/reconcile/recon_capture.py b/src/databricks/labs/lakebridge/reconcile/recon_capture.py index 165de2deb2..f0832f588b 100644 --- a/src/databricks/labs/lakebridge/reconcile/recon_capture.py +++ b/src/databricks/labs/lakebridge/reconcile/recon_capture.py @@ -11,7 +11,13 @@ from pyspark.errors import PySparkException from sqlglot import Dialect -from databricks.labs.lakebridge.config import DatabaseConfig, Table, ReconcileMetadataConfig +from databricks.labs.lakebridge.config import ( + DatabaseConfig, + SourceConnectionConfig, + TargetConnectionConfig, + Table, + ReconcileMetadataConfig, +) from databricks.labs.lakebridge.reconcile.recon_config import TableThresholds from databricks.labs.lakebridge.transpiler.sqlglot.dialect_utils import get_key_from_dialect from databricks.labs.lakebridge.reconcile.exception import ( @@ -252,15 +258,58 @@ class ReconCapture: def __init__( self, - database_config: DatabaseConfig, - recon_id: str, - report_type: str, - source_dialect: Dialect, - ws: WorkspaceClient, - spark: SparkSession, + source_connection: SourceConnectionConfig | DatabaseConfig, + target_connection: TargetConnectionConfig | str, + recon_id: str | None = None, + report_type: str | None = None, + source_dialect: Dialect | None = None, + ws: WorkspaceClient | None = None, + spark: SparkSession | None = None, metadata_config: ReconcileMetadataConfig = ReconcileMetadataConfig(), ): - self.database_config = database_config + if isinstance(source_connection, DatabaseConfig): + # Backward-compatible constructor support for tests still passing DatabaseConfig: + # ReconCapture(database_config, recon_id, report_type, source_dialect, ws, spark, ...) + legacy_db = source_connection + if not isinstance(target_connection, str): + raise ValueError("Expected recon_id as second argument when using DatabaseConfig") + legacy_recon_id = target_connection + legacy_report_type = recon_id + legacy_source_dialect = report_type + legacy_ws = source_dialect + legacy_spark = ws + if ( + not isinstance(legacy_report_type, str) + or legacy_source_dialect is None + or legacy_ws is None + or legacy_spark is None + ): + raise ValueError("Invalid legacy ReconCapture constructor arguments") + source_connection = SourceConnectionConfig( + dialect=get_key_from_dialect(legacy_source_dialect), + catalog=legacy_db.source_catalog, + schema=legacy_db.source_schema, + uc_connection_name=( + "remorph_connection" if get_key_from_dialect(legacy_source_dialect) != "databricks" else None + ), + ) + target_connection = TargetConnectionConfig( + catalog=legacy_db.target_catalog, + schema=legacy_db.target_schema, + ) + recon_id = legacy_recon_id + report_type = legacy_report_type + source_dialect = legacy_source_dialect + ws = legacy_ws + spark = legacy_spark + + if recon_id is None or report_type is None or source_dialect is None or ws is None or spark is None: + raise ValueError("ReconCapture requires recon_id, report_type, source_dialect, ws, and spark") + if not isinstance(target_connection, TargetConnectionConfig): + raise ValueError("ReconCapture requires TargetConnectionConfig for target_connection") + + self.source_connection = source_connection + self.target_connection = target_connection self.recon_id = recon_id self.report_type = report_type self.source_dialect = source_dialect @@ -273,12 +322,12 @@ def _generate_recon_main_id( table_conf: Table, ) -> int: full_source_table = ( - f"{self.database_config.source_schema}.{table_conf.source_name}" - if self.database_config.source_catalog is None - else f"{self.database_config.source_catalog}.{self.database_config.source_schema}.{table_conf.source_name}" + f"{self.source_connection.schema}.{table_conf.source_name}" + if self.source_connection.catalog is None + else f"{self.source_connection.catalog}.{self.source_connection.schema}.{table_conf.source_name}" ) full_target_table = ( - f"{self.database_config.target_catalog}.{self.database_config.target_schema}.{table_conf.target_name}" + f"{self.target_connection.catalog}.{self.target_connection.schema}.{table_conf.target_name}" ) return hash(f"{self.recon_id}{full_source_table}{full_target_table}") @@ -300,13 +349,13 @@ def _insert_into_main_table( else '{source_dialect_key}' end as source_type, named_struct( - 'catalog', case when '{self.database_config.source_catalog}' = 'None' then null else '{self.database_config.source_catalog}' end, - 'schema', '{self.database_config.source_schema}', + 'catalog', case when '{self.source_connection.catalog}' = 'None' then null else '{self.source_connection.catalog}' end, + 'schema', '{self.source_connection.schema}', 'table_name', '{table_conf.source_name}' ) as source_table, named_struct( - 'catalog', '{self.database_config.target_catalog}', - 'schema', '{self.database_config.target_schema}', + 'catalog', '{self.target_connection.catalog}', + 'schema', '{self.target_connection.schema}', 'table_name', '{table_conf.target_name}' ) as target_table, '{self.report_type}' as report_type, diff --git a/src/databricks/labs/lakebridge/reconcile/reconciliation.py b/src/databricks/labs/lakebridge/reconcile/reconciliation.py index 7dd104964e..390cff0bea 100644 --- a/src/databricks/labs/lakebridge/reconcile/reconciliation.py +++ b/src/databricks/labs/lakebridge/reconcile/reconciliation.py @@ -1,12 +1,14 @@ import logging +from typing import Any, cast from pyspark.sql import DataFrame, SparkSession from sqlglot import Dialect from databricks.labs.lakebridge.config import ( - DatabaseConfig, HashExpressionOverrides, ReconcileMetadataConfig, + SourceConnectionConfig, + TargetConnectionConfig, ) from databricks.labs.lakebridge.reconcile.compare import ( capture_mismatch_data_and_columns, @@ -56,7 +58,8 @@ def __init__( self, source: DataSource, target: DataSource, - database_config: DatabaseConfig, + source_connection: SourceConnectionConfig, + target_connection: TargetConnectionConfig, report_type: str, schema_comparator: SchemaCompare, source_engine: Dialect, @@ -68,7 +71,8 @@ def __init__( self._source = source self._target = target self._report_type = report_type - self._database_config = database_config + self._source_connection = source_connection + self._target_connection = target_connection self._schema_comparator = schema_comparator self._target_engine = get_dialect("databricks") self._source_engine = source_engine @@ -150,15 +154,15 @@ def _get_reconcile_output( ), ).build_query(report_type=self._report_type) src_data = self._source.read_data( - catalog=self._database_config.source_catalog, - schema=self._database_config.source_schema, + catalog=self._source_connection.catalog, + schema=self._source_connection.schema, table=table_conf.source_name, query=src_hash_query, options=table_conf.jdbc_reader_options, ) tgt_data = self._target.read_data( - catalog=self._database_config.target_catalog, - schema=self._database_config.target_schema, + catalog=self._target_connection.catalog, + schema=self._target_connection.schema, table=table_conf.target_name, query=tgt_hash_query, options=table_conf.jdbc_reader_options, @@ -265,15 +269,15 @@ def _get_reconcile_aggregate_output( data_source_exception = None try: src_data = self._source.read_data( - catalog=self._database_config.source_catalog, - schema=self._database_config.source_schema, + catalog=self._source_connection.catalog, + schema=self._source_connection.schema, table=table_conf.source_name, query=src_query_with_rules.query, options=table_conf.jdbc_reader_options, ) tgt_data = self._target.read_data( - catalog=self._database_config.target_catalog, - schema=self._database_config.target_schema, + catalog=self._target_connection.catalog, + schema=self._target_connection.schema, table=table_conf.target_name, query=tgt_query_with_rules.query, options=table_conf.jdbc_reader_options, @@ -293,6 +297,9 @@ def _get_reconcile_aggregate_output( if data_source_exception: rule_reconcile_output = DataReconcileOutput(exception=str(data_source_exception)) else: + assert joined_df is not None + assert src_data is not None + assert tgt_data is not None rule_reconcile_output = reconcile_agg_data_per_rule( joined_df, src_data.columns, tgt_data.columns, rule ) @@ -338,8 +345,8 @@ def _get_sample_data( self._target, tgt_sampler, reconcile_output.missing_in_src, - self._database_config.target_catalog, - self._database_config.target_schema, + self._target_connection.catalog, + self._target_connection.schema, table_conf.target_name, ) @@ -348,13 +355,13 @@ def _get_sample_data( self._source, src_sampler, reconcile_output.missing_in_tgt, - self._database_config.source_catalog, - self._database_config.source_schema, + self._source_connection.catalog, + self._source_connection.schema, table_conf.source_name, ) return DataReconcileOutput( - mismatch=mismatch, + mismatch=cast(Any, mismatch), mismatch_count=reconcile_output.mismatch_count, missing_in_src_count=reconcile_output.missing_in_src_count, missing_in_tgt_count=reconcile_output.missing_in_tgt_count, @@ -377,8 +384,8 @@ def _get_mismatch_data( tgt_sampling_query = tgt_sampler.build_query_with_alias() sampling_model_target = self._target.read_data( - catalog=self._database_config.target_catalog, - schema=self._database_config.target_schema, + catalog=self._target_connection.catalog, + schema=self._target_connection.schema, table=tgt_table, query=tgt_sampling_query, options=None, @@ -396,15 +403,15 @@ def _get_mismatch_data( tgt_mismatch_sample_query = tgt_sampler.build_query(df) src_data = self._source.read_data( - catalog=self._database_config.source_catalog, - schema=self._database_config.source_schema, + catalog=self._source_connection.catalog, + schema=self._source_connection.schema, table=src_table, query=src_mismatch_sample_query, options=None, ) tgt_data = self._target.read_data( - catalog=self._database_config.target_catalog, - schema=self._database_config.target_schema, + catalog=self._target_connection.catalog, + schema=self._target_connection.schema, table=tgt_table, query=tgt_mismatch_sample_query, options=None, @@ -443,15 +450,15 @@ def _get_threshold_data( ).build_threshold_query() src_data = self._source.read_data( - catalog=self._database_config.source_catalog, - schema=self._database_config.source_schema, + catalog=self._source_connection.catalog, + schema=self._source_connection.schema, table=table_conf.source_name, query=src_threshold_query, options=table_conf.jdbc_reader_options, ) tgt_data = self._target.read_data( - catalog=self._database_config.target_catalog, - schema=self._database_config.target_schema, + catalog=self._target_connection.catalog, + schema=self._target_connection.schema, table=table_conf.target_name, query=tgt_threshold_query, options=table_conf.jdbc_reader_options, @@ -465,8 +472,8 @@ def _compute_threshold_comparison(self, table_conf: Table, src_schema: list[Sche ).build_comparison_query() threshold_result = self._target.read_data( - catalog=self._database_config.target_catalog, - schema=self._database_config.target_schema, + catalog=self._target_connection.catalog, + schema=self._target_connection.schema, table=table_conf.target_name, query=threshold_comparison_query, options=table_conf.jdbc_reader_options, @@ -489,15 +496,15 @@ def get_record_count(self, table_conf: Table, report_type: str) -> ReconcileReco source_count_query = CountQueryBuilder(table_conf, "source", self._source_engine).build_query() target_count_query = CountQueryBuilder(table_conf, "target", self._target_engine).build_query() source_count_row = self._source.read_data( - catalog=self._database_config.source_catalog, - schema=self._database_config.source_schema, + catalog=self._source_connection.catalog, + schema=self._source_connection.schema, table=table_conf.source_name, query=source_count_query, options=None, ).first() target_count_row = self._target.read_data( - catalog=self._database_config.target_catalog, - schema=self._database_config.target_schema, + catalog=self._target_connection.catalog, + schema=self._target_connection.schema, table=table_conf.target_name, query=target_count_query, options=None, diff --git a/src/databricks/labs/lakebridge/reconcile/trigger_recon_aggregate_service.py b/src/databricks/labs/lakebridge/reconcile/trigger_recon_aggregate_service.py index 94682512b5..ddb656882a 100644 --- a/src/databricks/labs/lakebridge/reconcile/trigger_recon_aggregate_service.py +++ b/src/databricks/labs/lakebridge/reconcile/trigger_recon_aggregate_service.py @@ -67,7 +67,12 @@ def recon_aggregate_one( recon_process_duration = ReconcileProcessDuration(start_ts=str(datetime.now(tz=timezone.utc)), end_ts=None) try: src_schema, tgt_schema = TriggerReconService.get_schemas( - reconciler.source, reconciler.target, normalized_table_conf, reconcile_config.database_config, True + reconciler.source, + reconciler.target, + normalized_table_conf, + reconcile_config.source, + reconcile_config.target, + True, ) table_reconcile_agg_output_list = reconciler.reconcile_aggregates( diff --git a/src/databricks/labs/lakebridge/reconcile/trigger_recon_service.py b/src/databricks/labs/lakebridge/reconcile/trigger_recon_service.py index 9e6c48986b..4cdf524cc6 100644 --- a/src/databricks/labs/lakebridge/reconcile/trigger_recon_service.py +++ b/src/databricks/labs/lakebridge/reconcile/trigger_recon_service.py @@ -7,7 +7,12 @@ from databricks.sdk import WorkspaceClient -from databricks.labs.lakebridge.config import ReconcileConfig, TableRecon, DatabaseConfig +from databricks.labs.lakebridge.config import ( + ReconcileConfig, + TableRecon, + SourceConnectionConfig, + TargetConnectionConfig, +) from databricks.labs.lakebridge.reconcile import utils from databricks.labs.lakebridge.reconcile.connectors.data_source import DataSource from databricks.labs.lakebridge.reconcile.exception import DataSourceRuntimeException, ReconciliationException @@ -87,7 +92,8 @@ def create_recon_dependencies( reconciler = Reconciliation( source, target, - reconcile_config.database_config, + reconcile_config.source, + reconcile_config.target, report_type, SchemaCompare(spark=spark), get_dialect(source_dialect), @@ -98,7 +104,8 @@ def create_recon_dependencies( ) recon_capture = ReconCapture( - database_config=reconcile_config.database_config, + source_connection=reconcile_config.source, + target_connection=reconcile_config.target, recon_id=recon_id, report_type=report_type, source_dialect=get_dialect(source_dialect), @@ -142,7 +149,7 @@ def _do_recon_one(reconciler: Reconciliation, reconcile_config: ReconcileConfig, try: src_schema, tgt_schema = TriggerReconService.get_schemas( - reconciler.source, reconciler.target, table_conf, reconcile_config.database_config, True + reconciler.source, reconciler.target, table_conf, reconcile_config.source, reconcile_config.target, True ) except DataSourceRuntimeException as e: schema_reconcile_output = SchemaReconcileOutput(is_valid=False, exception=str(e)) @@ -173,19 +180,20 @@ def get_schemas( source: DataSource, target: DataSource, table_conf: Table, - database_config: DatabaseConfig, + source_connection: SourceConnectionConfig, + target_connection: TargetConnectionConfig, normalize: bool, ) -> tuple[list[Schema], list[Schema]]: src_schema = source.get_schema( - catalog=database_config.source_catalog, - schema=database_config.source_schema, + catalog=source_connection.catalog, + schema=source_connection.schema, table=table_conf.source_name, normalize=normalize, ) tgt_schema = target.get_schema( - catalog=database_config.target_catalog, - schema=database_config.target_schema, + catalog=target_connection.catalog, + schema=target_connection.schema, table=table_conf.target_name, normalize=normalize, ) diff --git a/tests/integration/reconcile/query_builder/test_execute.py b/tests/integration/reconcile/query_builder/test_execute.py index d9169a73c9..09642926d0 100644 --- a/tests/integration/reconcile/query_builder/test_execute.py +++ b/tests/integration/reconcile/query_builder/test_execute.py @@ -10,7 +10,6 @@ from pyspark.testing import assertDataFrameEqual from databricks.labs.lakebridge.config import ( - DatabaseConfig, TableRecon, ReconcileMetadataConfig, ReconcileConfig, @@ -231,19 +230,16 @@ def test_reconcile_data_with_mismatches_and_missing( ), } target_schema_repository = {(CATALOG, SCHEMA, TGT_TABLE): tgt_schema} - database_config = DatabaseConfig( - source_catalog=CATALOG, - source_schema=SCHEMA, - target_catalog=CATALOG, - target_schema=SCHEMA, - ) + source_connection = SourceConnectionConfig(dialect="databricks", catalog=CATALOG, schema=SCHEMA) + target_connection = TargetConnectionConfig(catalog=CATALOG, schema=SCHEMA) schema_comparator = SchemaCompare(spark) source = MockDataSource(source_dataframe_repository, source_schema_repository) target = MockDataSource(target_dataframe_repository, target_schema_repository) actual_data_reconcile = Reconciliation( source, target, - database_config, + source_connection, + target_connection, "data", schema_comparator, get_dialect("databricks"), @@ -313,7 +309,8 @@ def test_reconcile_data_with_mismatches_and_missing( actual_schema_reconcile = Reconciliation( source, target, - database_config, + source_connection, + target_connection, "data", schema_comparator, get_dialect("databricks"), @@ -439,19 +436,16 @@ def test_reconcile_data_without_mismatches_and_missing( ), } target_schema_repository = {(CATALOG, SCHEMA, TGT_TABLE): tgt_schema} - database_config = DatabaseConfig( - source_catalog=CATALOG, - source_schema=SCHEMA, - target_catalog=CATALOG, - target_schema=SCHEMA, - ) + source_connection = SourceConnectionConfig(dialect="databricks", catalog=CATALOG, schema=SCHEMA) + target_connection = TargetConnectionConfig(catalog=CATALOG, schema=SCHEMA) schema_comparator = SchemaCompare(spark) source = MockDataSource(source_dataframe_repository, source_schema_repository) target = MockDataSource(target_dataframe_repository, target_schema_repository) actual = Reconciliation( source, target, - database_config, + source_connection, + target_connection, "data", schema_comparator, get_dialect("databricks"), @@ -522,19 +516,16 @@ def test_reconcile_data_with_mismatch_and_no_missing( ), } target_schema_repository = {(CATALOG, SCHEMA, TGT_TABLE): tgt_schema} - database_config = DatabaseConfig( - source_catalog=CATALOG, - source_schema=SCHEMA, - target_catalog=CATALOG, - target_schema=SCHEMA, - ) + source_connection = SourceConnectionConfig(dialect="databricks", catalog=CATALOG, schema=SCHEMA) + target_connection = TargetConnectionConfig(catalog=CATALOG, schema=SCHEMA) schema_comparator = SchemaCompare(spark) source = MockDataSource(source_dataframe_repository, source_schema_repository) target = MockDataSource(target_dataframe_repository, target_schema_repository) actual = Reconciliation( source, target, - database_config, + source_connection, + target_connection, "data", schema_comparator, get_dialect("databricks"), @@ -625,19 +616,16 @@ def test_reconcile_data_missing_and_no_mismatch( ), } target_schema_repository = {(CATALOG, SCHEMA, TGT_TABLE): tgt_schema} - database_config = DatabaseConfig( - source_catalog=CATALOG, - source_schema=SCHEMA, - target_catalog=CATALOG, - target_schema=SCHEMA, - ) + source_connection = SourceConnectionConfig(dialect="databricks", catalog=CATALOG, schema=SCHEMA) + target_connection = TargetConnectionConfig(catalog=CATALOG, schema=SCHEMA) schema_comparator = SchemaCompare(spark) source = MockDataSource(source_dataframe_repository, source_schema_repository) target = MockDataSource(target_dataframe_repository, target_schema_repository) actual = Reconciliation( source, target, - database_config, + source_connection, + target_connection, "data", schema_comparator, get_dialect("databricks"), @@ -1972,12 +1960,8 @@ def test_reconcile_data_with_threshold_and_row_report_type( } target_schema_repository = {(CATALOG, SCHEMA, TGT_TABLE): tgt_schema} - database_config = DatabaseConfig( - source_catalog=CATALOG, - source_schema=SCHEMA, - target_catalog=CATALOG, - target_schema=SCHEMA, - ) + source_connection = SourceConnectionConfig(dialect="databricks", catalog=CATALOG, schema=SCHEMA) + target_connection = TargetConnectionConfig(catalog=CATALOG, schema=SCHEMA) schema_comparator = SchemaCompare(spark) source = MockDataSource(source_dataframe_repository, source_schema_repository) target = MockDataSource(target_dataframe_repository, target_schema_repository) @@ -1985,7 +1969,8 @@ def test_reconcile_data_with_threshold_and_row_report_type( actual = Reconciliation( source, target, - database_config, + source_connection, + target_connection, "row", schema_comparator, get_dialect("databricks"), diff --git a/tests/integration/reconcile/test_aggregates_recon_capture.py b/tests/integration/reconcile/test_aggregates_recon_capture.py index 2b743bdbf6..d33d189a10 100644 --- a/tests/integration/reconcile/test_aggregates_recon_capture.py +++ b/tests/integration/reconcile/test_aggregates_recon_capture.py @@ -2,7 +2,11 @@ from pyspark.sql import Row, SparkSession -from databricks.labs.lakebridge.config import DatabaseConfig, ReconcileMetadataConfig +from databricks.labs.lakebridge.config import ( + ReconcileMetadataConfig, + SourceConnectionConfig, + TargetConnectionConfig, +) from databricks.labs.lakebridge.reconcile.recon_capture import ( ReconCapture, ) @@ -37,9 +41,13 @@ def agg_data_prep(spark: SparkSession): def test_aggregates_reconcile_store_aggregate_metrics( mock_workspace_client, spark, recon_metadata: ReconcileMetadataConfig ): - database_config = DatabaseConfig( - "source_test_schema", "target_test_catalog", "target_test_schema", "source_test_catalog" + source_connection = SourceConnectionConfig( + dialect="snowflake", + catalog="source_test_schema", + schema="target_test_catalog", + uc_connection_name="remorph_snowflake", ) + target_connection = TargetConnectionConfig(catalog="target_test_schema", schema="source_test_catalog") source_type = get_dialect("snowflake") agg_reconcile_output, table_conf, reconcile_process_duration = agg_data_prep(spark) @@ -47,7 +55,8 @@ def test_aggregates_reconcile_store_aggregate_metrics( recon_id = "999fygdrs-dbb7-489f-bad1-6a7e8f4821b1" recon_capture = ReconCapture( - database_config, + source_connection, + target_connection, recon_id, "", source_type, diff --git a/tests/integration/reconcile/test_aggregates_reconcile.py b/tests/integration/reconcile/test_aggregates_reconcile.py index ffcfdb39d4..fcbd2696a1 100644 --- a/tests/integration/reconcile/test_aggregates_reconcile.py +++ b/tests/integration/reconcile/test_aggregates_reconcile.py @@ -10,7 +10,11 @@ from tests.integration.reconcile.conftest import FakeReconIntermediatePersist from tests.conftest import ansi_schema_fixture_factory -from databricks.labs.lakebridge.config import DatabaseConfig, ReconcileMetadataConfig +from databricks.labs.lakebridge.config import ( + ReconcileMetadataConfig, + SourceConnectionConfig, + TargetConnectionConfig, +) from databricks.labs.lakebridge.reconcile.reconciliation import Reconciliation from databricks.labs.lakebridge.transpiler.sqlglot.dialect_utils import get_dialect from databricks.labs.lakebridge.reconcile.connectors.data_source import MockDataSource @@ -107,18 +111,15 @@ def test_reconcile_aggregate_data_missing_records( } target_schema_repository = {(CATALOG, SCHEMA, TGT_TABLE): tgt_schema} - database_config = DatabaseConfig( - source_catalog=CATALOG, - source_schema=SCHEMA, - target_catalog=CATALOG, - target_schema=SCHEMA, - ) + source_connection = SourceConnectionConfig(dialect="databricks", catalog=CATALOG, schema=SCHEMA) + target_connection = TargetConnectionConfig(catalog=CATALOG, schema=SCHEMA) source = MockDataSource(source_dataframe_repository, source_schema_repository) target = MockDataSource(target_dataframe_repository, target_schema_repository) actual: list[AggregateQueryOutput] = Reconciliation( source, target, - database_config, + source_connection, + target_connection, "", SchemaCompare(spark), get_dialect("databricks"), @@ -353,18 +354,15 @@ def test_reconcile_aggregate_data_mismatch_and_missing_records( } target_schema_repository = {(CATALOG, SCHEMA, TGT_TABLE): tgt_schema} - db_config = DatabaseConfig( - source_catalog=CATALOG, - source_schema=SCHEMA, - target_catalog=CATALOG, - target_schema=SCHEMA, - ) + source_connection = SourceConnectionConfig(dialect="snowflake", catalog=CATALOG, schema=SCHEMA) + target_connection = TargetConnectionConfig(catalog=CATALOG, schema=SCHEMA) source = MockDataSource(source_dataframe_repository, source_schema_repository, delimiter='"') target = MockDataSource(target_dataframe_repository, target_schema_repository) actual_list: list[AggregateQueryOutput] = Reconciliation( source, target, - db_config, + source_connection, + target_connection, "", SchemaCompare(spark), get_dialect("snowflake"), diff --git a/tests/integration/reconcile/test_recon_capture.py b/tests/integration/reconcile/test_recon_capture.py index 492cf4d362..01368fdc6c 100644 --- a/tests/integration/reconcile/test_recon_capture.py +++ b/tests/integration/reconcile/test_recon_capture.py @@ -7,7 +7,11 @@ from pyspark.sql.functions import countDistinct from pyspark.sql.types import BooleanType, StringType, StructField, StructType -from databricks.labs.lakebridge.config import DatabaseConfig, ReconcileMetadataConfig +from databricks.labs.lakebridge.config import ( + ReconcileMetadataConfig, + SourceConnectionConfig, + TargetConnectionConfig, +) from databricks.labs.lakebridge.transpiler.sqlglot.dialect_utils import get_dialect from databricks.labs.lakebridge.reconcile.exception import WriteToTableException from databricks.labs.lakebridge.reconcile.recon_capture import ( @@ -120,15 +124,25 @@ def data_prep(spark: SparkSession): return reconcile_output, schema_output, table_conf, reconcile_process, row_count -def test_recon_capture_start_snowflake_all(mock_workspace_client, spark, recon_metadata): - database_config = DatabaseConfig( - "source_test_catalog", "source_test_schema", "target_test_catalog", "target_test_schema" +def _connection_configs_for(dialect: str) -> tuple[SourceConnectionConfig, TargetConnectionConfig]: + return ( + SourceConnectionConfig( + dialect=dialect, + catalog="source_test_catalog", + schema="source_test_schema", + uc_connection_name="remorph_connection" if dialect != "databricks" else None, + ), + TargetConnectionConfig(catalog="target_test_catalog", schema="target_test_schema"), ) + + +def test_recon_capture_start_snowflake_all(mock_workspace_client, spark, recon_metadata): ws = mock_workspace_client source_type = get_dialect("snowflake") + connection_configs = _connection_configs_for("snowflake") reconcile_output, schema_output, table_conf, reconcile_process, row_count = data_prep(spark) recon_capture = ReconCapture( - database_config, + *connection_configs, "73b44582-dbb7-489f-bad1-6a7e8f4821b1", "all", source_type, @@ -206,13 +220,11 @@ def test_recon_capture_start_snowflake_all(mock_workspace_client, spark, recon_m def test_test_recon_capture_start_databricks_data(mock_workspace_client, spark, recon_metadata): - database_config = DatabaseConfig( - "source_test_catalog", "source_test_schema", "target_test_catalog", "target_test_schema" - ) ws = mock_workspace_client source_type = get_dialect("databricks") + connection_configs = _connection_configs_for("databricks") recon_capture = ReconCapture( - database_config, + *connection_configs, "73b44582-dbb7-489f-bad1-6a7e8f4821b1", "data", source_type, @@ -252,13 +264,11 @@ def test_test_recon_capture_start_databricks_data(mock_workspace_client, spark, def test_test_recon_capture_start_databricks_row(mock_workspace_client, spark, recon_metadata): - database_config = DatabaseConfig( - "source_test_catalog", "source_test_schema", "target_test_catalog", "target_test_schema" - ) ws = mock_workspace_client source_type = get_dialect("databricks") + connection_configs = _connection_configs_for("databricks") recon_capture = ReconCapture( - database_config, + *connection_configs, "73b44582-dbb7-489f-bad1-6a7e8f4821b1", "row", source_type, @@ -301,13 +311,11 @@ def test_test_recon_capture_start_databricks_row(mock_workspace_client, spark, r def test_recon_capture_start_oracle_schema(mock_workspace_client, spark, recon_metadata): - database_config = DatabaseConfig( - "source_test_catalog", "source_test_schema", "target_test_catalog", "target_test_schema" - ) ws = mock_workspace_client source_type = get_dialect("oracle") + connection_configs = _connection_configs_for("oracle") recon_capture = ReconCapture( - database_config, + *connection_configs, "73b44582-dbb7-489f-bad1-6a7e8f4821b1", "schema", source_type, @@ -352,13 +360,11 @@ def test_recon_capture_start_oracle_schema(mock_workspace_client, spark, recon_m def test_recon_capture_start_oracle_with_exception(mock_workspace_client, spark, recon_metadata): - database_config = DatabaseConfig( - "source_test_catalog", "source_test_schema", "target_test_catalog", "target_test_schema" - ) ws = mock_workspace_client source_type = get_dialect("oracle") + connection_configs = _connection_configs_for("oracle") recon_capture = ReconCapture( - database_config, + *connection_configs, "73b44582-dbb7-489f-bad1-6a7e8f4821b1", "all", source_type, @@ -398,13 +404,11 @@ def test_recon_capture_start_oracle_with_exception(mock_workspace_client, spark, def test_recon_capture_start_with_exception(mock_workspace_client, spark): - database_config = DatabaseConfig( - "source_test_catalog", "source_test_schema", "target_test_catalog", "target_test_schema" - ) ws = mock_workspace_client source_type = get_dialect("snowflake") + connection_configs = _connection_configs_for("snowflake") recon_capture = ReconCapture( - database_config, + *connection_configs, "73b44582-dbb7-489f-bad1-6a7e8f4821b1", "all", source_type, @@ -423,16 +427,11 @@ def test_recon_capture_start_with_exception(mock_workspace_client, spark): def test_generate_final_reconcile_output_row(mock_workspace_client, spark, recon_metadata): - database_config = DatabaseConfig( - "source_test_catalog", - "source_test_schema", - "target_test_catalog", - "target_test_schema", - ) ws = mock_workspace_client source_type = get_dialect("databricks") + connection_configs = _connection_configs_for("databricks") recon_capture = ReconCapture( - database_config, + *connection_configs, "73b44582-dbb7-489f-bad1-6a7e8f4821b1", "row", source_type, @@ -469,16 +468,11 @@ def test_generate_final_reconcile_output_row(mock_workspace_client, spark, recon def test_generate_final_reconcile_output_data(mock_workspace_client, spark, recon_metadata): - database_config = DatabaseConfig( - "source_test_catalog", - "source_test_schema", - "target_test_catalog", - "target_test_schema", - ) ws = mock_workspace_client source_type = get_dialect("databricks") + connection_configs = _connection_configs_for("databricks") recon_capture = ReconCapture( - database_config, + *connection_configs, "73b44582-dbb7-489f-bad1-6a7e8f4821b1", "data", source_type, @@ -515,16 +509,11 @@ def test_generate_final_reconcile_output_data(mock_workspace_client, spark, reco def test_generate_final_reconcile_output_schema(mock_workspace_client, spark, recon_metadata): - database_config = DatabaseConfig( - "source_test_catalog", - "source_test_schema", - "target_test_catalog", - "target_test_schema", - ) ws = mock_workspace_client source_type = get_dialect("databricks") + connection_configs = _connection_configs_for("databricks") recon_capture = ReconCapture( - database_config, + *connection_configs, "73b44582-dbb7-489f-bad1-6a7e8f4821b1", "schema", source_type, @@ -561,16 +550,11 @@ def test_generate_final_reconcile_output_schema(mock_workspace_client, spark, re def test_generate_final_reconcile_output_all(mock_workspace_client, spark, recon_metadata): - database_config = DatabaseConfig( - "source_test_catalog", - "source_test_schema", - "target_test_catalog", - "target_test_schema", - ) ws = mock_workspace_client source_type = get_dialect("databricks") + connection_configs = _connection_configs_for("databricks") recon_capture = ReconCapture( - database_config, + *connection_configs, "73b44582-dbb7-489f-bad1-6a7e8f4821b1", "all", source_type, @@ -608,16 +592,11 @@ def test_generate_final_reconcile_output_all(mock_workspace_client, spark, recon def test_generate_final_reconcile_output_exception(mock_workspace_client, spark, recon_metadata): - database_config = DatabaseConfig( - "source_test_catalog", - "source_test_schema", - "target_test_catalog", - "target_test_schema", - ) ws = mock_workspace_client source_type = get_dialect("databricks") + connection_configs = _connection_configs_for("databricks") recon_capture = ReconCapture( - database_config, + *connection_configs, "73b44582-dbb7-489f-bad1-6a7e8f4821b1", "all", source_type, @@ -656,11 +635,9 @@ def test_generate_final_reconcile_output_exception(mock_workspace_client, spark, def test_apply_threshold_for_mismatch_with_true_absolute(mock_workspace_client, spark, recon_metadata): - database_config = DatabaseConfig( - "source_test_catalog", "source_test_schema", "target_test_catalog", "target_test_schema" - ) ws = mock_workspace_client source_type = get_dialect("snowflake") + connection_configs = _connection_configs_for("snowflake") reconcile_output, schema_output, table_conf, reconcile_process, row_count = data_prep(spark) reconcile_output.missing_in_src_count = 0 reconcile_output.missing_in_tgt_count = 0 @@ -670,7 +647,7 @@ def test_apply_threshold_for_mismatch_with_true_absolute(mock_workspace_client, TableThresholds(lower_bound="0", upper_bound="4", model="mismatch"), ] recon_capture = ReconCapture( - database_config, + *connection_configs, "73b44582-dbb7-489f-bad1-6a7e8f4821b1", "all", source_type, @@ -693,17 +670,15 @@ def test_apply_threshold_for_mismatch_with_true_absolute(mock_workspace_client, def test_apply_threshold_for_mismatch_with_missing(mock_workspace_client, spark, recon_metadata): - database_config = DatabaseConfig( - "source_test_catalog", "source_test_schema", "target_test_catalog", "target_test_schema" - ) ws = mock_workspace_client source_type = get_dialect("snowflake") + connection_configs = _connection_configs_for("snowflake") reconcile_output, schema_output, table_conf, reconcile_process, row_count = data_prep(spark) table_conf.table_thresholds = [ TableThresholds(lower_bound="0", upper_bound="4", model="mismatch"), ] recon_capture = ReconCapture( - database_config, + *connection_configs, "73b44582-dbb7-489f-bad1-6a7e8f4821b1", "all", source_type, @@ -726,17 +701,15 @@ def test_apply_threshold_for_mismatch_with_missing(mock_workspace_client, spark, def test_apply_threshold_for_mismatch_with_schema_fail(mock_workspace_client, spark, recon_metadata): - database_config = DatabaseConfig( - "source_test_catalog", "source_test_schema", "target_test_catalog", "target_test_schema" - ) ws = mock_workspace_client source_type = get_dialect("snowflake") + connection_configs = _connection_configs_for("snowflake") reconcile_output, schema_output, table_conf, reconcile_process, row_count = data_prep(spark) table_conf.table_thresholds = [ TableThresholds(lower_bound="0", upper_bound="4", model="mismatch"), ] recon_capture = ReconCapture( - database_config, + *connection_configs, "73b44582-dbb7-489f-bad1-6a7e8f4821b1", "all", source_type, @@ -763,11 +736,9 @@ def test_apply_threshold_for_mismatch_with_schema_fail(mock_workspace_client, sp def test_apply_threshold_for_mismatch_with_wrong_absolute_bound(mock_workspace_client, spark, recon_metadata): - database_config = DatabaseConfig( - "source_test_catalog", "source_test_schema", "target_test_catalog", "target_test_schema" - ) ws = mock_workspace_client source_type = get_dialect("snowflake") + connection_configs = _connection_configs_for("snowflake") reconcile_output, schema_output, table_conf, reconcile_process, row_count = data_prep(spark) table_conf.table_thresholds = [ TableThresholds(lower_bound="0", upper_bound="1", model="mismatch"), @@ -778,7 +749,7 @@ def test_apply_threshold_for_mismatch_with_wrong_absolute_bound(mock_workspace_c reconcile_output.missing_in_src = None reconcile_output.missing_in_tgt = None recon_capture = ReconCapture( - database_config, + *connection_configs, "73b44582-dbb7-489f-bad1-6a7e8f4821b1", "all", source_type, @@ -801,11 +772,9 @@ def test_apply_threshold_for_mismatch_with_wrong_absolute_bound(mock_workspace_c def test_apply_threshold_for_mismatch_with_wrong_percentage_bound(mock_workspace_client, spark, recon_metadata): - database_config = DatabaseConfig( - "source_test_catalog", "source_test_schema", "target_test_catalog", "target_test_schema" - ) ws = mock_workspace_client source_type = get_dialect("snowflake") + connection_configs = _connection_configs_for("snowflake") reconcile_output, schema_output, table_conf, reconcile_process, row_count = data_prep(spark) table_conf.table_thresholds = [ TableThresholds(lower_bound="0%", upper_bound="20%", model="mismatch"), @@ -816,7 +785,7 @@ def test_apply_threshold_for_mismatch_with_wrong_percentage_bound(mock_workspace reconcile_output.missing_in_src = None reconcile_output.missing_in_tgt = None recon_capture = ReconCapture( - database_config, + *connection_configs, "73b44582-dbb7-489f-bad1-6a7e8f4821b1", "all", source_type, @@ -839,11 +808,9 @@ def test_apply_threshold_for_mismatch_with_wrong_percentage_bound(mock_workspace def test_apply_threshold_for_mismatch_with_true_percentage_bound(mock_workspace_client, spark, recon_metadata): - database_config = DatabaseConfig( - "source_test_catalog", "source_test_schema", "target_test_catalog", "target_test_schema" - ) ws = mock_workspace_client source_type = get_dialect("snowflake") + connection_configs = _connection_configs_for("snowflake") reconcile_output, schema_output, table_conf, reconcile_process, row_count = data_prep(spark) table_conf.table_thresholds = [ TableThresholds(lower_bound="0%", upper_bound="90%", model="mismatch"), @@ -853,7 +820,7 @@ def test_apply_threshold_for_mismatch_with_true_percentage_bound(mock_workspace_ reconcile_output.missing_in_src = None reconcile_output.missing_in_tgt = None recon_capture = ReconCapture( - database_config, + *connection_configs, "73b44582-dbb7-489f-bad1-6a7e8f4821b1", "all", source_type, @@ -876,11 +843,9 @@ def test_apply_threshold_for_mismatch_with_true_percentage_bound(mock_workspace_ def test_apply_threshold_for_mismatch_with_invalid_bounds(mock_workspace_client, spark): - database_config = DatabaseConfig( - "source_test_catalog", "source_test_schema", "target_test_catalog", "target_test_schema" - ) ws = mock_workspace_client source_type = get_dialect("snowflake") + connection_configs = _connection_configs_for("snowflake") reconcile_output, schema_output, table_conf, reconcile_process, row_count = data_prep(spark) reconcile_output.missing_in_src_count = 0 reconcile_output.missing_in_tgt_count = 0 @@ -888,7 +853,7 @@ def test_apply_threshold_for_mismatch_with_invalid_bounds(mock_workspace_client, reconcile_output.missing_in_src = None reconcile_output.missing_in_tgt = None recon_capture = ReconCapture( - database_config, + *connection_configs, "73b44582-dbb7-489f-bad1-6a7e8f4821b1", "all", source_type, @@ -922,11 +887,9 @@ def test_apply_threshold_for_mismatch_with_invalid_bounds(mock_workspace_client, def test_apply_threshold_for_only_threshold_mismatch_with_true_absolute(mock_workspace_client, spark, recon_metadata): - database_config = DatabaseConfig( - "source_test_catalog", "source_test_schema", "target_test_catalog", "target_test_schema" - ) ws = mock_workspace_client source_type = get_dialect("snowflake") + connection_configs = _connection_configs_for("snowflake") reconcile_output, schema_output, table_conf, reconcile_process, row_count = data_prep(spark) reconcile_output.mismatch_count = 0 reconcile_output.missing_in_src_count = 0 @@ -937,7 +900,7 @@ def test_apply_threshold_for_only_threshold_mismatch_with_true_absolute(mock_wor TableThresholds(lower_bound="0", upper_bound="2", model="mismatch"), ] recon_capture = ReconCapture( - database_config, + *connection_configs, "73b44582-dbb7-489f-bad1-6a7e8f4821b1", "all", source_type,