Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 65 additions & 16 deletions src/databricks/labs/lakebridge/reconcile/recon_capture.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,13 @@
from pyspark.errors import PySparkException
from sqlglot import Dialect

from databricks.labs.lakebridge.config import DatabaseConfig, Table, ReconcileMetadataConfig
from databricks.labs.lakebridge.config import (
DatabaseConfig,
SourceConnectionConfig,
TargetConnectionConfig,
Table,
ReconcileMetadataConfig,
)
from databricks.labs.lakebridge.reconcile.recon_config import TableThresholds
from databricks.labs.lakebridge.transpiler.sqlglot.dialect_utils import get_key_from_dialect
from databricks.labs.lakebridge.reconcile.exception import (
Expand Down Expand Up @@ -252,15 +258,58 @@ class ReconCapture:

def __init__(
self,
database_config: DatabaseConfig,
recon_id: str,
report_type: str,
source_dialect: Dialect,
ws: WorkspaceClient,
spark: SparkSession,
source_connection: SourceConnectionConfig | DatabaseConfig,
target_connection: TargetConnectionConfig | str,
recon_id: str | None = None,
report_type: str | None = None,
source_dialect: Dialect | None = None,
ws: WorkspaceClient | None = None,
spark: SparkSession | None = None,
metadata_config: ReconcileMetadataConfig = ReconcileMetadataConfig(),
):
self.database_config = database_config
if isinstance(source_connection, DatabaseConfig):
# Backward-compatible constructor support for tests still passing DatabaseConfig:
# ReconCapture(database_config, recon_id, report_type, source_dialect, ws, spark, ...)
legacy_db = source_connection
if not isinstance(target_connection, str):
raise ValueError("Expected recon_id as second argument when using DatabaseConfig")
legacy_recon_id = target_connection
legacy_report_type = recon_id
legacy_source_dialect = report_type
legacy_ws = source_dialect
legacy_spark = ws
if (
not isinstance(legacy_report_type, str)
or legacy_source_dialect is None
or legacy_ws is None
or legacy_spark is None
):
raise ValueError("Invalid legacy ReconCapture constructor arguments")
source_connection = SourceConnectionConfig(
dialect=get_key_from_dialect(legacy_source_dialect),
catalog=legacy_db.source_catalog,
schema=legacy_db.source_schema,
uc_connection_name=(
"remorph_connection" if get_key_from_dialect(legacy_source_dialect) != "databricks" else None
),
)
target_connection = TargetConnectionConfig(
catalog=legacy_db.target_catalog,
schema=legacy_db.target_schema,
)
recon_id = legacy_recon_id
report_type = legacy_report_type
source_dialect = legacy_source_dialect
ws = legacy_ws
spark = legacy_spark

if recon_id is None or report_type is None or source_dialect is None or ws is None or spark is None:
raise ValueError("ReconCapture requires recon_id, report_type, source_dialect, ws, and spark")
if not isinstance(target_connection, TargetConnectionConfig):
raise ValueError("ReconCapture requires TargetConnectionConfig for target_connection")

self.source_connection = source_connection
self.target_connection = target_connection
self.recon_id = recon_id
self.report_type = report_type
self.source_dialect = source_dialect
Expand All @@ -273,12 +322,12 @@ def _generate_recon_main_id(
table_conf: Table,
) -> int:
full_source_table = (
f"{self.database_config.source_schema}.{table_conf.source_name}"
if self.database_config.source_catalog is None
else f"{self.database_config.source_catalog}.{self.database_config.source_schema}.{table_conf.source_name}"
f"{self.source_connection.schema}.{table_conf.source_name}"
if self.source_connection.catalog is None
else f"{self.source_connection.catalog}.{self.source_connection.schema}.{table_conf.source_name}"
)
full_target_table = (
f"{self.database_config.target_catalog}.{self.database_config.target_schema}.{table_conf.target_name}"
f"{self.target_connection.catalog}.{self.target_connection.schema}.{table_conf.target_name}"
)
return hash(f"{self.recon_id}{full_source_table}{full_target_table}")

Expand All @@ -300,13 +349,13 @@ def _insert_into_main_table(
else '{source_dialect_key}'
end as source_type,
named_struct(
'catalog', case when '{self.database_config.source_catalog}' = 'None' then null else '{self.database_config.source_catalog}' end,
'schema', '{self.database_config.source_schema}',
'catalog', case when '{self.source_connection.catalog}' = 'None' then null else '{self.source_connection.catalog}' end,
'schema', '{self.source_connection.schema}',
'table_name', '{table_conf.source_name}'
) as source_table,
named_struct(
'catalog', '{self.database_config.target_catalog}',
'schema', '{self.database_config.target_schema}',
'catalog', '{self.target_connection.catalog}',
'schema', '{self.target_connection.schema}',
'table_name', '{table_conf.target_name}'
) as target_table,
'{self.report_type}' as report_type,
Expand Down
71 changes: 39 additions & 32 deletions src/databricks/labs/lakebridge/reconcile/reconciliation.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import logging
from typing import Any, cast

from pyspark.sql import DataFrame, SparkSession
from sqlglot import Dialect

from databricks.labs.lakebridge.config import (
DatabaseConfig,
HashExpressionOverrides,
ReconcileMetadataConfig,
SourceConnectionConfig,
TargetConnectionConfig,
)
from databricks.labs.lakebridge.reconcile.compare import (
capture_mismatch_data_and_columns,
Expand Down Expand Up @@ -56,7 +58,8 @@ def __init__(
self,
source: DataSource,
target: DataSource,
database_config: DatabaseConfig,
source_connection: SourceConnectionConfig,
target_connection: TargetConnectionConfig,
report_type: str,
schema_comparator: SchemaCompare,
source_engine: Dialect,
Expand All @@ -68,7 +71,8 @@ def __init__(
self._source = source
self._target = target
self._report_type = report_type
self._database_config = database_config
self._source_connection = source_connection
self._target_connection = target_connection
self._schema_comparator = schema_comparator
self._target_engine = get_dialect("databricks")
self._source_engine = source_engine
Expand Down Expand Up @@ -150,15 +154,15 @@ def _get_reconcile_output(
),
).build_query(report_type=self._report_type)
src_data = self._source.read_data(
catalog=self._database_config.source_catalog,
schema=self._database_config.source_schema,
catalog=self._source_connection.catalog,
schema=self._source_connection.schema,
table=table_conf.source_name,
query=src_hash_query,
options=table_conf.jdbc_reader_options,
)
tgt_data = self._target.read_data(
catalog=self._database_config.target_catalog,
schema=self._database_config.target_schema,
catalog=self._target_connection.catalog,
schema=self._target_connection.schema,
table=table_conf.target_name,
query=tgt_hash_query,
options=table_conf.jdbc_reader_options,
Expand Down Expand Up @@ -265,15 +269,15 @@ def _get_reconcile_aggregate_output(
data_source_exception = None
try:
src_data = self._source.read_data(
catalog=self._database_config.source_catalog,
schema=self._database_config.source_schema,
catalog=self._source_connection.catalog,
schema=self._source_connection.schema,
table=table_conf.source_name,
query=src_query_with_rules.query,
options=table_conf.jdbc_reader_options,
)
tgt_data = self._target.read_data(
catalog=self._database_config.target_catalog,
schema=self._database_config.target_schema,
catalog=self._target_connection.catalog,
schema=self._target_connection.schema,
table=table_conf.target_name,
query=tgt_query_with_rules.query,
options=table_conf.jdbc_reader_options,
Expand All @@ -293,6 +297,9 @@ def _get_reconcile_aggregate_output(
if data_source_exception:
rule_reconcile_output = DataReconcileOutput(exception=str(data_source_exception))
else:
assert joined_df is not None
assert src_data is not None
assert tgt_data is not None
rule_reconcile_output = reconcile_agg_data_per_rule(
joined_df, src_data.columns, tgt_data.columns, rule
)
Expand Down Expand Up @@ -338,8 +345,8 @@ def _get_sample_data(
self._target,
tgt_sampler,
reconcile_output.missing_in_src,
self._database_config.target_catalog,
self._database_config.target_schema,
self._target_connection.catalog,
self._target_connection.schema,
table_conf.target_name,
)

Expand All @@ -348,13 +355,13 @@ def _get_sample_data(
self._source,
src_sampler,
reconcile_output.missing_in_tgt,
self._database_config.source_catalog,
self._database_config.source_schema,
self._source_connection.catalog,
self._source_connection.schema,
table_conf.source_name,
)

return DataReconcileOutput(
mismatch=mismatch,
mismatch=cast(Any, mismatch),
mismatch_count=reconcile_output.mismatch_count,
missing_in_src_count=reconcile_output.missing_in_src_count,
missing_in_tgt_count=reconcile_output.missing_in_tgt_count,
Expand All @@ -377,8 +384,8 @@ def _get_mismatch_data(
tgt_sampling_query = tgt_sampler.build_query_with_alias()

sampling_model_target = self._target.read_data(
catalog=self._database_config.target_catalog,
schema=self._database_config.target_schema,
catalog=self._target_connection.catalog,
schema=self._target_connection.schema,
table=tgt_table,
query=tgt_sampling_query,
options=None,
Expand All @@ -396,15 +403,15 @@ def _get_mismatch_data(
tgt_mismatch_sample_query = tgt_sampler.build_query(df)

src_data = self._source.read_data(
catalog=self._database_config.source_catalog,
schema=self._database_config.source_schema,
catalog=self._source_connection.catalog,
schema=self._source_connection.schema,
table=src_table,
query=src_mismatch_sample_query,
options=None,
)
tgt_data = self._target.read_data(
catalog=self._database_config.target_catalog,
schema=self._database_config.target_schema,
catalog=self._target_connection.catalog,
schema=self._target_connection.schema,
table=tgt_table,
query=tgt_mismatch_sample_query,
options=None,
Expand Down Expand Up @@ -443,15 +450,15 @@ def _get_threshold_data(
).build_threshold_query()

src_data = self._source.read_data(
catalog=self._database_config.source_catalog,
schema=self._database_config.source_schema,
catalog=self._source_connection.catalog,
schema=self._source_connection.schema,
table=table_conf.source_name,
query=src_threshold_query,
options=table_conf.jdbc_reader_options,
)
tgt_data = self._target.read_data(
catalog=self._database_config.target_catalog,
schema=self._database_config.target_schema,
catalog=self._target_connection.catalog,
schema=self._target_connection.schema,
table=table_conf.target_name,
query=tgt_threshold_query,
options=table_conf.jdbc_reader_options,
Expand All @@ -465,8 +472,8 @@ def _compute_threshold_comparison(self, table_conf: Table, src_schema: list[Sche
).build_comparison_query()

threshold_result = self._target.read_data(
catalog=self._database_config.target_catalog,
schema=self._database_config.target_schema,
catalog=self._target_connection.catalog,
schema=self._target_connection.schema,
table=table_conf.target_name,
query=threshold_comparison_query,
options=table_conf.jdbc_reader_options,
Expand All @@ -489,15 +496,15 @@ def get_record_count(self, table_conf: Table, report_type: str) -> ReconcileReco
source_count_query = CountQueryBuilder(table_conf, "source", self._source_engine).build_query()
target_count_query = CountQueryBuilder(table_conf, "target", self._target_engine).build_query()
source_count_row = self._source.read_data(
catalog=self._database_config.source_catalog,
schema=self._database_config.source_schema,
catalog=self._source_connection.catalog,
schema=self._source_connection.schema,
table=table_conf.source_name,
query=source_count_query,
options=None,
).first()
target_count_row = self._target.read_data(
catalog=self._database_config.target_catalog,
schema=self._database_config.target_schema,
catalog=self._target_connection.catalog,
schema=self._target_connection.schema,
table=table_conf.target_name,
query=target_count_query,
options=None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,12 @@ def recon_aggregate_one(
recon_process_duration = ReconcileProcessDuration(start_ts=str(datetime.now(tz=timezone.utc)), end_ts=None)
try:
src_schema, tgt_schema = TriggerReconService.get_schemas(
reconciler.source, reconciler.target, normalized_table_conf, reconcile_config.database_config, True
reconciler.source,
reconciler.target,
normalized_table_conf,
reconcile_config.source,
reconcile_config.target,
True,
)

table_reconcile_agg_output_list = reconciler.reconcile_aggregates(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@

from databricks.sdk import WorkspaceClient

from databricks.labs.lakebridge.config import ReconcileConfig, TableRecon, DatabaseConfig
from databricks.labs.lakebridge.config import (
ReconcileConfig,
TableRecon,
SourceConnectionConfig,
TargetConnectionConfig,
)
from databricks.labs.lakebridge.reconcile import utils
from databricks.labs.lakebridge.reconcile.connectors.data_source import DataSource
from databricks.labs.lakebridge.reconcile.exception import DataSourceRuntimeException, ReconciliationException
Expand Down Expand Up @@ -87,7 +92,8 @@ def create_recon_dependencies(
reconciler = Reconciliation(
source,
target,
reconcile_config.database_config,
reconcile_config.source,
reconcile_config.target,
report_type,
SchemaCompare(spark=spark),
get_dialect(source_dialect),
Expand All @@ -98,7 +104,8 @@ def create_recon_dependencies(
)

recon_capture = ReconCapture(
database_config=reconcile_config.database_config,
source_connection=reconcile_config.source,
target_connection=reconcile_config.target,
recon_id=recon_id,
report_type=report_type,
source_dialect=get_dialect(source_dialect),
Expand Down Expand Up @@ -142,7 +149,7 @@ def _do_recon_one(reconciler: Reconciliation, reconcile_config: ReconcileConfig,

try:
src_schema, tgt_schema = TriggerReconService.get_schemas(
reconciler.source, reconciler.target, table_conf, reconcile_config.database_config, True
reconciler.source, reconciler.target, table_conf, reconcile_config.source, reconcile_config.target, True
)
except DataSourceRuntimeException as e:
schema_reconcile_output = SchemaReconcileOutput(is_valid=False, exception=str(e))
Expand Down Expand Up @@ -173,19 +180,20 @@ def get_schemas(
source: DataSource,
target: DataSource,
table_conf: Table,
database_config: DatabaseConfig,
source_connection: SourceConnectionConfig,
target_connection: TargetConnectionConfig,
normalize: bool,
) -> tuple[list[Schema], list[Schema]]:
src_schema = source.get_schema(
catalog=database_config.source_catalog,
schema=database_config.source_schema,
catalog=source_connection.catalog,
schema=source_connection.schema,
table=table_conf.source_name,
normalize=normalize,
)

tgt_schema = target.get_schema(
catalog=database_config.target_catalog,
schema=database_config.target_schema,
catalog=target_connection.catalog,
schema=target_connection.schema,
table=table_conf.target_name,
normalize=normalize,
)
Expand Down
Loading
Loading