diff --git a/docs/lakebridge/docs/reconcile/configuration.mdx b/docs/lakebridge/docs/reconcile/configuration.mdx index 4218404573..209291dcb8 100644 --- a/docs/lakebridge/docs/reconcile/configuration.mdx +++ b/docs/lakebridge/docs/reconcile/configuration.mdx @@ -91,6 +91,83 @@ Place the file in `.lakebridge/` in your Databricks workspace home folder. --- +## Fingerprint Pre-check (Experimental) + +The fingerprint pre-check is an opt-in optimisation that, when source and target are +already in sync, replaces the per-row hash-and-join pipeline with a sub-bucket-level +aggregate comparison. On a 1 M-row Redshift fixture in MATCH state, the precheck +short-circuits the recon at roughly **30–45 % of the v3 baseline wall-clock**; on +MISMATCH it surgically fetches only the differing rows instead of streaming the full +column set across JDBC. When neither the MATCH nor the surgical-fetch path applies the +recon falls open to the full pipeline so correctness is never compromised. + +### When to enable + +- The source / target tables are expected to be **mostly identical** (post-migration, + ongoing CDC) — the precheck pays off most when MATCH is the common outcome. +- The runtime targets **DBR 17.3 or later**. The Stage-2 source-side fetch uses + Databricks' `remote_query()` table-valued function which requires DBR 17.3+. On + earlier runtimes, leave the flag off — the eligibility gate doesn't yet check DBR + version, so the JDBC call would fail mid-fetch and trigger the fail-open path. +- The source dialect has a registered fingerprint query builder. Today only + `redshift` is wired. Other dialects fall through to the full pipeline silently. + +### Configuration + +Add to `recon_config_*.json` at the top level: + +```yaml +fingerprint_precheck: true +# Optional. False (default) keeps '' distinct from NULL in fingerprint hashing, +# matching the row-hash compare path in expression_generator. Flip to True only if +# your data treats '' and NULL as the same value AND you have audited the impact; +# the flag is wired symmetrically across both source-side Redshift SQL and the +# target-side Spark Stage-1 / Stage-2 serialisers so the two cannot drift. +fingerprint_treat_empty_as_null: false +# Optional. When set, overrides the target Delta DESCRIBE DETAIL numRecords +# lookup used to pick the sub-bucket tier. Use this when the target is non-Delta +# (DESCRIBE DETAIL returns no numRecords) or when Delta stats are stale and the +# heuristic lands on a tier that is too coarse / too fine for your workload. +# Values <= 0 are treated as unset. +fingerprint_row_count_override: 250000000 +``` + +### Eligibility rules + +The pre-check declines (and the recon proceeds via the full pipeline) when **any** of +the conditions below hold. These reasons are recorded in +`recon_metrics.fingerprint_metrics.ineligibility_reason` so adoption queries can +distinguish "feature off" from "feature on but table ineligible". + +| Reason value | Meaning | +| --- | --- | +| `flag_disabled` | `fingerprint_precheck` is false (default). | +| `unsupported_dialect` | The source dialect has no registered fingerprint query builder (today: anything other than `redshift`). | +| `report_type_not_data` | `report_type` is `schema` (precheck operates on data-level reconciles). | +| `no_join_columns` | `join_columns` is empty. The precheck needs primary-key columns to disambiguate culprit rows during Stage-2. | +| `filters_configured` | `filters.source` or `filters.target` is set. The precheck does not project filter predicates into Stage-1 aggregates yet. | +| `transforms_configured` | `transformations` is set. Custom transformations are not supported on the fingerprint hash path. | +| `column_thresholds_configured` | `column_thresholds` is set. Threshold semantics conflict with exact-hash comparison. | +| `table_thresholds_configured` | `table_thresholds` is set. Same rationale as column thresholds. | + +A separate runtime gate validates that every `column_mapping.target_name` resolves to +a real target column before issuing the source-side scan — a typo in the mapping is +caught at eligibility time, not after Stage-1 has already pulled across JDBC. + +### Tuning and observability + +- The pre-check selects a sub-bucket / bucket count adaptively from the target's + Delta `numRecords` (DESCRIBE DETAIL). On a non-Delta target, or when the metric + is missing, it falls back to a static `(1 048 576, 32 768)` pair. Override + explicitly via `ReconcileConfig.fingerprint_row_count_override` (an + approximate target row count) when the heuristic picks a tier that you can + show is wrong for your workload. +- Every recon writes a `fingerprint_metrics` named-struct to + `recon_metrics.fingerprint_metrics` regardless of eligibility, so adoption, + fall-open rate, and verdict distribution can be tracked from one query. + +--- + ## TABLE Config Schema diff --git a/docs/lakebridge/docs/reconcile/index.mdx b/docs/lakebridge/docs/reconcile/index.mdx index 1be5e46144..b5ce604157 100644 --- a/docs/lakebridge/docs/reconcile/index.mdx +++ b/docs/lakebridge/docs/reconcile/index.mdx @@ -82,6 +82,18 @@ The User configuring reconcile must have permission to: - `USE CATALOG` and `CREATE SCHEMA` on the target catalog - `CREATE VOLUME` if using a pre-existing schema on a serverless cluster +### Runtime requirements + +- The job / interactive cluster running reconcile must be on **DBR 15.4 LTS or later** + for the standard data-comparison path. +- If `fingerprint_precheck` is enabled (see + [Configuration Reference → Fingerprint Pre-check](/docs/reconcile/configuration#fingerprint-pre-check-experimental)), + the cluster must be on **DBR 17.3 or later**. The Stage-2 source-side fetch uses + Databricks' `remote_query()` table-valued function, which became available on DBR + 17.3. On earlier runtimes, leave the flag off; otherwise the JDBC call fails + mid-fetch and the recon falls open to the full pipeline (correct, but the precheck + buys you nothing while paying for itself in cluster time). + ### Serverless cluster support Reconcile automatically detects the cluster type and optimizes intermediate data persistence accordingly: diff --git a/src/databricks/labs/lakebridge/config.py b/src/databricks/labs/lakebridge/config.py index 5f0e4d5ce8..d8b3881ad7 100644 --- a/src/databricks/labs/lakebridge/config.py +++ b/src/databricks/labs/lakebridge/config.py @@ -280,13 +280,27 @@ class ReconcileJobConfig: @dataclass class ReconcileConfig: __file__ = "reconcile.yml" - __version__ = 2 + __version__ = 3 report_type: str source: SourceConnectionConfig target: TargetConnectionConfig metadata_config: ReconcileMetadataConfig job_overrides: ReconcileJobConfig | None = None + fingerprint_precheck: bool = False + # When True, fingerprint hashing collapses '' to NULL on BOTH source and target + # sides — flipped here so the Stage-1 / Stage-2 serialisers cannot drift apart + # (without this knob, target call sites silently kept the function default of + # False while source picked up a constant override). Staying False matches the + # row-hash compare path in ``expression_generator``. + fingerprint_treat_empty_as_null: bool = False + # Optional explicit row count for fingerprint tier selection. When set, + # overrides Delta ``DESCRIBE DETAIL`` numRecords lookup so customers whose + # target is non-Delta (or whose Delta stats are stale) can pick the right + # sub-bucket tier without waiting for a full COUNT(*). ``None`` keeps the + # default heuristic. Values ``<= 0`` are treated as "unset" by + # ``fetch_target_row_count``. + fingerprint_row_count_override: int | None = None @classmethod def v1_migrate(cls, raw: dict[str, Any]) -> dict[str, Any]: @@ -314,6 +328,25 @@ def v1_migrate(cls, raw: dict[str, Any]) -> dict[str, Any]: raw["version"] = 2 return raw + @classmethod + def v2_migrate(cls, raw: dict[str, Any]) -> dict[str, Any]: + """v2 → v3: introduce the source-agnostic ``fingerprint_precheck`` flag. + + Older field names (``redshift_fingerprint_precheck``, ``use_fingerprint_precheck``) + from internal pre-v2 deployments are folded into the new flag if present; + otherwise the field defaults to ``False`` so existing v2 configs keep their + current behaviour. + """ + if "fingerprint_precheck" not in raw: + for legacy in ("redshift_fingerprint_precheck", "use_fingerprint_precheck"): + if legacy in raw: + raw["fingerprint_precheck"] = raw.pop(legacy) + break + for legacy in ("redshift_fingerprint_precheck", "use_fingerprint_precheck"): + raw.pop(legacy, None) + raw["version"] = 3 + return raw + @property def database_config(self) -> DatabaseConfig: """TODO remove. this was kept for backwards compatibility while migrating to ReconcileConfig v2""" diff --git a/src/databricks/labs/lakebridge/reconcile/fingerprint/__init__.py b/src/databricks/labs/lakebridge/reconcile/fingerprint/__init__.py new file mode 100644 index 0000000000..8376ca6869 --- /dev/null +++ b/src/databricks/labs/lakebridge/reconcile/fingerprint/__init__.py @@ -0,0 +1 @@ +"""Fingerprint-accelerated reconciliation.""" diff --git a/src/databricks/labs/lakebridge/reconcile/fingerprint/constants.py b/src/databricks/labs/lakebridge/reconcile/fingerprint/constants.py new file mode 100644 index 0000000000..4b46e93868 --- /dev/null +++ b/src/databricks/labs/lakebridge/reconcile/fingerprint/constants.py @@ -0,0 +1,98 @@ +"""Shared constants and SQL helpers for fingerprint detection and row filtering.""" + +from __future__ import annotations + +# Must match the row-hash path's NULL stand-in (``'_null_recon_'`` literal in +# ``reconcile/query_builder/expression_generator.py``). Fingerprint and row-hash +# both encode NULLs into the per-column hash payload before MD5/SHA — picking a +# different stand-in here would alias real data ``'_null_recon_'`` with NULL on +# only one side and produce the inverse alias on the other, so any row that +# happens to carry either literal would be silently misclassified by Stage-1. +# A unit test pins this to the row-hash literal so a future drift fails CI +# rather than the reconcile. +NULL_SENTINEL = "_null_recon_" + +# chr(1) — column separator inside the MD5 concat. Rendered three ways: +# Redshift SQL: CHR(1) +# Spark SQL: CHAR(1) +# Python / Spark DataFrame: "\x01" +SEPARATOR_PYTHON = "\x01" +SEPARATOR_REDSHIFT_SQL = "CHR(1)" +SEPARATOR_SPARK_SQL = "CHAR(1)" + +# Static defaults retained for backwards compatibility and as the fallback when the +# adaptive selector has no row count to work with. +SUB_BUCKET_COUNT = 1_048_576 # 1M sub-buckets +BUCKET_COUNT = 32_768 + +# Adaptive tier table. Each entry: (max_row_count_inclusive, sub_bucket_count, bucket_count). +# Last entry's max_row_count is None and clamps everything larger. Sub-bucket counts are +# powers of 2 to keep MOD distribution clean; bucket count = sub_bucket_count / 1024. +SUB_BUCKET_TIERS: tuple[tuple[int | None, int, int], ...] = ( + (50_000, 16_384, 128), # < 50K + (500_000, 262_144, 512), # 50K – 500K + (50_000_000, 1_048_576, 1_024), # 500K – 50M + (500_000_000, 2_097_152, 2_048), # 50M – 500M + (5_000_000_000, 4_194_304, 4_096), # 500M – 5B + (50_000_000_000, 8_388_608, 8_192), # 5B – 50B + (None, 16_777_216, 16_384), # 50B+ +) + + +def pick_sub_bucket_count(row_count: int | None) -> tuple[int, int]: + """Select (sub_bucket_count, bucket_count) for ``row_count``. + + Falls back to (SUB_BUCKET_COUNT, BUCKET_COUNT) when the count is unknown or + non-positive, so callers can pass None safely. + + >>> pick_sub_bucket_count(10_000) + (16384, 128) + >>> pick_sub_bucket_count(100_000_000) + (2097152, 2048) + >>> pick_sub_bucket_count(None) + (1048576, 32768) + """ + if row_count is None or row_count <= 0: + return SUB_BUCKET_COUNT, BUCKET_COUNT + for max_row_count, sub_buckets, buckets in SUB_BUCKET_TIERS: + if max_row_count is None or row_count <= max_row_count: + return sub_buckets, buckets + return SUB_BUCKET_COUNT, BUCKET_COUNT + + +def build_fingerprint_where_clause( + sb_expr: str, + rh1_expr: str, + solved_hashes: dict[int, list[int]], + unsolved_sb_ids: list[int], +) -> str: + """Build the WHERE body (no ``WHERE``, no trailing alias) for a filtered fetch. + + Emits the union form ``(sb_expr IN (..) AND rh1_expr IN (..)) [OR sb_expr IN (..)]``. + The form is mathematically equivalent to per-sub-bucket disjuncts because + ``sb_id = ABS(MOD(rh1, N))`` is invariant, but stays ``O(|sb_expr| + |IN list|)`` + instead of ``O(k · |sb_expr|)`` so it stays under Redshift's 16 MB statement + limit even on workloads with millions of solved sub-buckets. + + Raises ``ValueError`` when both filter inputs are empty: callers must gate the + fetch (eligibility check in the orchestrator) before reaching this helper. An + empty result here would interpolate to ``WHERE )`` downstream — fail-loud beats + silently emitting a syntactically broken query that fail-open would mask. + """ + if not solved_hashes and not unsolved_sb_ids: + raise ValueError( + "build_fingerprint_where_clause requires at least one of solved_hashes " + "or unsolved_sb_ids to be non-empty; the empty case must be filtered " + "out by the caller before issuing a fetch." + ) + conditions: list[str] = [] + # Sort all IN-list operands for deterministic SQL across dict / list iteration + # orders — helps query-plan caching and unit-test diffing. + if solved_hashes: + sb_list = ", ".join(str(sb_id) for sb_id in sorted(solved_hashes)) + hash_list = ", ".join(str(h) for h in sorted({h for hs in solved_hashes.values() for h in hs})) + conditions.append(f"({sb_expr} IN ({sb_list}) AND {rh1_expr} IN ({hash_list}))") + if unsolved_sb_ids: + sb_list = ", ".join(str(sb_id) for sb_id in sorted(unsolved_sb_ids)) + conditions.append(f"{sb_expr} IN ({sb_list})") + return " OR ".join(conditions) diff --git a/src/databricks/labs/lakebridge/reconcile/fingerprint/engine.py b/src/databricks/labs/lakebridge/reconcile/fingerprint/engine.py new file mode 100644 index 0000000000..27d82fd70a --- /dev/null +++ b/src/databricks/labs/lakebridge/reconcile/fingerprint/engine.py @@ -0,0 +1,347 @@ +import logging +import math +from dataclasses import dataclass, field +from typing import Literal + +from pyspark.sql import DataFrame +from pyspark.sql import functions as F + +DetectionVerdict = Literal["MATCH", "MISMATCH"] + +logger = logging.getLogger(__name__) + +_SYSTEMIC_GUARD_THRESHOLD = 0.15 +# Driver-OOM guard: collecting millions of mismatched sub-buckets is unsafe. +_MAX_MISMATCHED_SUBBUCKETS_TO_SOLVE = 50_000 +# MD5 8-hex extraction yields 32-bit unsigned values in [0, 0xFFFFFFFF]. +_MAX_RH_VALUE = 0xFFFFFFFF + + +@dataclass +class SolveResult: + sub_bucket_id: int + source_hashes: list[int] + target_hashes: list[int] + + +@dataclass +class DetectionResult: + verdict: DetectionVerdict + solved_results: list[SolveResult] = field(default_factory=list) + unsolved_sb_ids: list[int] = field(default_factory=list) + total_mismatched_sbs: int = 0 + systemic_mismatch: bool = False + + +def detect_and_solve( + source_agg_df: DataFrame, + target_agg_df: DataFrame, +) -> DetectionResult: + """Compare source and target sub-bucket aggregates and solve for culprit hashes. + + Returns a DetectionResult with verdict, solved hashes, and unsolved sub-bucket IDs + for brute-force fetch. + + Stage-1 materialization: the joined aggregate is cached before any action so + the agg + the subsequent ``mismatched.collect()`` reuse the same physical + rows instead of re-pulling from Redshift / Delta. The cache is released on + every return path. Without this cache the function would fire three separate + Spark jobs (joined.count, mismatched.count, mismatched.collect), each + re-evaluating the Stage-1 read; with it, two jobs run (one agg + one + collect) and both read the cached frame. + """ + joined = ( + source_agg_df.alias("src") + .join( + target_agg_df.alias("tgt"), + on=["sub_bucket_id", "bucket_id"], + how="full", + ) + .select( + F.coalesce(F.col("src.sub_bucket_id"), F.col("tgt.sub_bucket_id")).alias("sub_bucket_id"), + F.coalesce(F.col("src.bucket_id"), F.col("tgt.bucket_id")).alias("bucket_id"), + F.coalesce(F.col("src.cnt"), F.lit(0)).alias("src_cnt"), + F.coalesce(F.col("tgt.cnt"), F.lit(0)).alias("tgt_cnt"), + F.coalesce(F.col("src.p1"), F.lit(0)).alias("src_p1"), + F.coalesce(F.col("tgt.p1"), F.lit(0)).alias("tgt_p1"), + F.coalesce(F.col("src.p2"), F.lit(0)).alias("src_p2"), + F.coalesce(F.col("tgt.p2"), F.lit(0)).alias("tgt_p2"), + F.coalesce(F.col("src.p1_rh2"), F.lit(0)).alias("src_p1_rh2"), + F.coalesce(F.col("tgt.p1_rh2"), F.lit(0)).alias("tgt_p1_rh2"), + F.coalesce(F.col("src.p2_rh2"), F.lit(0)).alias("src_p2_rh2"), + F.coalesce(F.col("tgt.p2_rh2"), F.lit(0)).alias("tgt_p2_rh2"), + ) + ).cache() + + try: # pylint: disable=too-many-try-statements # try wraps the full detect+solve so finally can unpersist the cache once + # All five signals (cnt + p1/p2/p1_rh2/p2_rh2) must agree before we declare MATCH. + # Dropping rh2 from the OR-chain turns MD5 collisions into silent false MATCHes. + mismatch_condition = ( + (F.col("src_cnt") != F.col("tgt_cnt")) + | (F.col("src_p1") != F.col("tgt_p1")) + | (F.col("src_p2") != F.col("tgt_p2")) + | (F.col("src_p1_rh2") != F.col("tgt_p1_rh2")) + | (F.col("src_p2_rh2") != F.col("tgt_p2_rh2")) + ) + mismatched = joined.filter(mismatch_condition) + + # Single-agg pattern: one Spark job computes both ``total_sbs`` + # (denominator for the systemic-ratio guard) and ``mismatch_count`` + # (verdict + systemic-count guard). Calling joined.count() and + # mismatched.count() separately would each trigger a full Stage-1 + # re-evaluation; folding into one agg halves Stage-1 wall-clock on every + # run regardless of verdict. + counts_row = joined.agg( + F.count("*").alias("total_sbs"), + F.sum(F.when(mismatch_condition, F.lit(1)).otherwise(F.lit(0)).cast("long")).alias("mismatch_count"), + ).collect()[0] + total_sbs = int(counts_row["total_sbs"] or 0) + mismatch_count = int(counts_row["mismatch_count"] or 0) + + if mismatch_count == 0: + logger.info("Fingerprint detection: MATCH — all sub-buckets identical") + return DetectionResult(verdict="MATCH") + + mismatch_ratio = mismatch_count / max(total_sbs, 1) + logger.info( + f"Fingerprint detection: {mismatch_count}/{total_sbs} sub-buckets mismatched ({mismatch_ratio * 100:.1f}%)" + ) + + if mismatch_ratio > _SYSTEMIC_GUARD_THRESHOLD or mismatch_count > _MAX_MISMATCHED_SUBBUCKETS_TO_SOLVE: + logger.warning( + f"Fingerprint: systemic mismatch ({mismatch_ratio * 100:.1f}% > " + f"{_SYSTEMIC_GUARD_THRESHOLD * 100:.0f}% or {mismatch_count} > " + f"{_MAX_MISMATCHED_SUBBUCKETS_TO_SOLVE}) — falling through" + ) + return DetectionResult( + verdict="MISMATCH", + total_mismatched_sbs=mismatch_count, + systemic_mismatch=True, + ) + + # ``mismatched.collect()`` reads from the cached ``joined`` — no re-pull from + # Redshift / Delta. Bounded by ``_MAX_MISMATCHED_SUBBUCKETS_TO_SOLVE`` so the + # driver-side list never exceeds 50K rows. + mismatched_rows = mismatched.collect() + + solved_results: list[SolveResult] = [] + unsolved_sb_ids: list[int] = [] + + for row in mismatched_rows: + # Spark aggregates p1/p2/p1_rh2/p2_rh2 are DecimalType(38, 0) (see + # spark_target._hash_agg_exprs — Decimal is required to avoid 64-bit + # overflow when summing rh*rh). They surface here as decimal.Decimal, + # which math.isqrt() in _solve_d2_extras rejects with + # ``TypeError: 'decimal.Decimal' object cannot be interpreted as an integer``. + # Cast to int up front so every downstream solver works on native ints. + sb_id = int(row["sub_bucket_id"]) + d_cnt = int(row["src_cnt"]) - int(row["tgt_cnt"]) + d_p1 = int(row["src_p1"]) - int(row["tgt_p1"]) + d_p2 = int(row["src_p2"]) - int(row["tgt_p2"]) + d_p1_rh2 = int(row["src_p1_rh2"]) - int(row["tgt_p1_rh2"]) + d_p2_rh2 = int(row["src_p2_rh2"]) - int(row["tgt_p2_rh2"]) + + result = _solve_sub_bucket(sb_id, d_cnt, d_p1, d_p2, d_p1_rh2, d_p2_rh2) + if result is not None: + solved_results.append(result) + else: + unsolved_sb_ids.append(sb_id) + + logger.info( + f"Fingerprint solver: {len(solved_results)} solved, " + f"{len(unsolved_sb_ids)} unsolved out of {len(mismatched_rows)} mismatched sub-buckets" + ) + + return DetectionResult( + verdict="MISMATCH", + solved_results=solved_results, + unsolved_sb_ids=unsolved_sb_ids, + total_mismatched_sbs=mismatch_count, + ) + finally: + # Always release the cache — every return path above lands here, including + # the systemic-mismatch fallback and any exception bubbling up to the caller. + joined.unpersist() + + +def _solve_sub_bucket( + sb_id: int, + d_cnt: int, + d_p1: int, + d_p2: int, + d_p1_rh2: int, + d_p2_rh2: int, +) -> SolveResult | None: + """Solve d=1 / d=2 cases for one mismatched sub-bucket; return None when unsolvable.""" + abs_d = abs(d_cnt) + if abs_d == 1: + return solve_d1(sb_id, d_cnt, d_p1, d_p2, d_p1_rh2, d_p2_rh2) + if abs_d == 0 and d_p1 != 0: + return solve_d2_swap(sb_id, d_p1, d_p2, d_p1_rh2, d_p2_rh2) + if abs_d == 2: + return _solve_d2_extras(sb_id, d_cnt, d_p1, d_p2, d_p1_rh2, d_p2_rh2) + return None + + +def solve_d1(sb_id: int, d_cnt: int, d_p1: int, d_p2: int, d_p1_rh2: int, d_p2_rh2: int) -> SolveResult | None: + """Solve d=1: one extra row on one side. + + culprit_rh = abs(d_p1); verified by d_p2 == rh^2 * sign(d_cnt) on both rh1 and + rh2 channels. Range check rejects values outside the 32-bit MD5-extraction band. + """ + rh1 = abs(d_p1) + + # Without the p2 verification, two rows that cancel in p1 (e.g. 3+7 == 1+9) would + # be falsely "solved" with a wrong hash. + if d_p2 != rh1 * rh1 * d_cnt: + return None + if rh1 < 0 or rh1 > _MAX_RH_VALUE: + return None + + rh2 = abs(d_p1_rh2) + if d_p2_rh2 != rh2 * rh2 * d_cnt: + return None + if rh2 < 0 or rh2 > _MAX_RH_VALUE: + return None + + if d_cnt > 0: + return SolveResult(sub_bucket_id=sb_id, source_hashes=[rh1], target_hashes=[]) + return SolveResult(sub_bucket_id=sb_id, source_hashes=[], target_hashes=[rh1]) + + +def solve_d2_swap( + sb_id: int, + d_p1: int, + d_p2: int, + d_p1_rh2: int, + d_p2_rh2: int, +) -> SolveResult | None: + """Solve d=2 swap: same row count, one row's content changed. + + Source has h_old, target has h_new. d_p1 = h_old - h_new, and + d_p2 = h_old^2 - h_new^2 = (h_old - h_new)(h_old + h_new). Solve for both + and dual-slice-verify on rh2. + """ + if d_p1 == 0: + return None + if d_p2 % d_p1 != 0: + return None + + h_sum = d_p2 // d_p1 + # Parity guard: floor-division loses a bit otherwise and produces a wrong root. + if (d_p1 + h_sum) % 2 != 0: + return None + h_old = (d_p1 + h_sum) // 2 + h_new = h_sum - h_old + + if h_old - h_new != d_p1 or h_old * h_old - h_new * h_new != d_p2: + return None + if not (0 <= h_old <= _MAX_RH_VALUE and 0 <= h_new <= _MAX_RH_VALUE): + return None + # Sign-product guard: real 1-for-1 swaps have non-negative roots on both sides. + if h_old * h_new < 0: + return None + if not _cross_verify_d2_swap(d_p1_rh2, d_p2_rh2): + return None + + return SolveResult(sub_bucket_id=sb_id, source_hashes=[h_old], target_hashes=[h_new]) + + +def _solve_d2_extras( + sb_id: int, + d_cnt: int, + d_p1: int, + d_p2: int, + d_p1_rh2: int, + d_p2_rh2: int, +) -> SolveResult | None: + """Solve d=2 extras: two extra rows on one side. + + Sign-adjusts the deltas (target-extras case has negative d_p1/d_p2), then solves + the quadratic ``x^2 - sum_h*x + product = 0``. + """ + sign = 1 if d_cnt > 0 else -1 + sum_h = d_p1 * sign + sum_sq = d_p2 * sign + + if sum_sq < 0: + return None + + product_2 = sum_h * sum_h - sum_sq + if product_2 < 0 or product_2 % 2 != 0: + return None + product = product_2 // 2 + + discriminant = sum_h * sum_h - 4 * product + if discriminant < 0: + return None + + sqrt_disc = _isqrt(discriminant) + if sqrt_disc is None: + return None + if (sum_h + sqrt_disc) % 2 != 0: + return None + + hash_a = (sum_h + sqrt_disc) // 2 + hash_b = sum_h - hash_a + + if hash_a + hash_b != sum_h or hash_a * hash_a + hash_b * hash_b != sum_sq: + return None + if not (0 <= hash_a <= _MAX_RH_VALUE and 0 <= hash_b <= _MAX_RH_VALUE): + return None + if not _cross_verify_d2_extras(d_cnt, d_p1_rh2, d_p2_rh2): + return None + + # Repeated root means a single culprit hash appearing twice — emit it once so the + # row-fetch phase doesn't issue the same predicate twice. + hashes = [hash_a] if hash_a == hash_b else sorted([hash_a, hash_b]) + if d_cnt > 0: + return SolveResult(sub_bucket_id=sb_id, source_hashes=hashes, target_hashes=[]) + return SolveResult(sub_bucket_id=sb_id, source_hashes=[], target_hashes=hashes) + + +def _cross_verify_d2_swap(d_p1_rh2: int, d_p2_rh2: int) -> bool: + """Independently solve the d=2 swap on the rh2 channel and confirm valid roots.""" + if d_p1_rh2 == 0: + return d_p2_rh2 == 0 + if d_p2_rh2 % d_p1_rh2 != 0: + return False + h_sum = d_p2_rh2 // d_p1_rh2 + if (d_p1_rh2 + h_sum) % 2 != 0: + return False + h_old = (d_p1_rh2 + h_sum) // 2 + h_new = h_sum - h_old + if h_old - h_new != d_p1_rh2: + return False + return 0 <= h_old <= _MAX_RH_VALUE and 0 <= h_new <= _MAX_RH_VALUE + + +def _cross_verify_d2_extras(d_cnt: int, d_p1_rh2: int, d_p2_rh2: int) -> bool: + """Independently solve the d=2 extras quadratic on the rh2 channel.""" + sign = 1 if d_cnt > 0 else -1 + sum_h = d_p1_rh2 * sign + sum_sq = d_p2_rh2 * sign + if sum_sq < 0: + return False + product_2 = sum_h * sum_h - sum_sq + if product_2 < 0 or product_2 % 2 != 0: + return False + discriminant = sum_h * sum_h - 4 * (product_2 // 2) + if discriminant < 0: + return False + sqrt_disc = _isqrt(discriminant) + if sqrt_disc is None: + return False + if (sum_h + sqrt_disc) % 2 != 0: + return False + h_a = (sum_h + sqrt_disc) // 2 + h_b = sum_h - h_a + return 0 <= h_a <= _MAX_RH_VALUE and 0 <= h_b <= _MAX_RH_VALUE + + +def _isqrt(value: int) -> int | None: + """Return the integer square root if ``value`` is a perfect square, else None.""" + if value < 0: + return None + root = math.isqrt(value) + return root if root * root == value else None diff --git a/src/databricks/labs/lakebridge/reconcile/fingerprint/exceptions.py b/src/databricks/labs/lakebridge/reconcile/fingerprint/exceptions.py new file mode 100644 index 0000000000..fd3932d263 --- /dev/null +++ b/src/databricks/labs/lakebridge/reconcile/fingerprint/exceptions.py @@ -0,0 +1,30 @@ +"""Exception hierarchy for the fingerprint module. + +Lakebridge's fingerprint pre-check is opt-in and must never block the full reconcile +pipeline. The trigger layer catches these (plus external IO errors raised by Spark / +JDBC) and falls through to the legacy hash+JOIN path on any failure. +""" + + +class FingerprintError(Exception): + """Base for errors raised by the fingerprint module.""" + + +class UnsupportedDataSourceError(FingerprintError, ValueError): + """No registered FingerprintQueryBuilder for the requested data source. + + Inherits ValueError for backwards compatibility — callers that ``except ValueError`` + on the dispatch lookup still catch this; new callers can target ``FingerprintError``. + """ + + +class UnmappedTargetColumnMappingError(FingerprintError): + """A ``column_mapping`` entry references a target name that doesn't exist on the target. + + Raised by ``align_columns`` so the trigger layer can record + ``IneligibilityReason.UNMAPPED_TARGET_COLUMN_MAPPING`` on + ``recon_metrics.fingerprint_metrics.ineligibility_reason`` instead of a + silent ``None`` fallback. The trigger catches this *before* the broader + ``FingerprintError`` branch so the metric reports an ineligible verdict + (a config issue), not a precheck failure. + """ diff --git a/src/databricks/labs/lakebridge/reconcile/fingerprint/fingerprint_hash_columns.py b/src/databricks/labs/lakebridge/reconcile/fingerprint/fingerprint_hash_columns.py new file mode 100644 index 0000000000..f2230b1854 --- /dev/null +++ b/src/databricks/labs/lakebridge/reconcile/fingerprint/fingerprint_hash_columns.py @@ -0,0 +1,55 @@ +"""Resolve which columns participate in fingerprint detection. + +Mirrors HashQueryBuilder.build_query column resolution exactly so the fingerprint +hashes the same set, in the same order, the row-hash path would. The two +implementations are pinned together by regression tests, but the duplication is +a known maintenance hazard — extracting a shared ``compute_hash_columns`` helper +into ``query_builder/`` is deferred until the second source dialect lands and +informs the right abstraction boundary (forcing the shape on Redshift alone +risks over-fitting). +""" + +from __future__ import annotations + +from databricks.labs.lakebridge.reconcile.connectors.data_source import DataSource +from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils +from databricks.labs.lakebridge.reconcile.recon_config import Table, Schema + + +def _strip(name: str) -> str: + """Return ``name`` lowercased and without ANSI/source delimiters. + + Bridges the two naming conventions: ``Table.join_columns`` are bare, + ``Schema.column_name`` (and ``get_select_columns`` output) are ANSI-delimited. + Set operations on the two would otherwise produce duplicate entries. + """ + return DialectUtils.unnormalize_identifier(name).lower() + + +def hash_columns_ordered_for_reconcile( + table_conf: Table, + schema: list[Schema], + layer: str, + data_source: DataSource, +) -> list[str]: + """Mirror HashQueryBuilder hash column set + sort order (case-insensitive sort_key).""" + join_keys = {_strip(c): c for c in (table_conf.join_columns or [])} + select_keys: dict[str, str] = {} + for col_name in table_conf.get_select_columns(schema, layer): + select_keys.setdefault(_strip(col_name), col_name) + threshold_keys = {_strip(c) for c in table_conf.get_threshold_columns(layer)} + drop_keys = {_strip(c) for c in table_conf.get_drop_columns(layer)} + + merged: dict[str, str] = {} + for stripped, original in {**select_keys, **join_keys}.items(): + if stripped in threshold_keys or stripped in drop_keys: + continue + merged[stripped] = original + + hash_cols_with_sort = [] + for original in merged.values(): + sort_key = DialectUtils.unnormalize_identifier( + data_source.normalize_identifier(original).ansi_normalized + ).lower() + hash_cols_with_sort.append((sort_key, original)) + return [c for _, c in sorted(hash_cols_with_sort, key=lambda x: x[0])] diff --git a/src/databricks/labs/lakebridge/reconcile/fingerprint/metadata.py b/src/databricks/labs/lakebridge/reconcile/fingerprint/metadata.py new file mode 100644 index 0000000000..13235da0c2 --- /dev/null +++ b/src/databricks/labs/lakebridge/reconcile/fingerprint/metadata.py @@ -0,0 +1,159 @@ +"""Persistence-side metadata for the fingerprint pre-check. + +Lives in its own module so ``recon_capture`` can import the dataclass without a circular +import via ``fingerprint.orchestrator``. The dataclass and enum values here are part of +the public Delta schema for ``recon_metrics.fingerprint_metrics`` — keep value strings +stable; renaming any of them breaks downstream dashboards. +""" + +from dataclasses import dataclass +from enum import Enum +from typing import TYPE_CHECKING, Literal + +if TYPE_CHECKING: + # Avoid circular import: ``orchestrator`` already imports from this module. + from databricks.labs.lakebridge.reconcile.fingerprint.orchestrator import FingerprintResult + + +class IneligibilityReason(str, Enum): + """Why a table was skipped by the fingerprint pre-check. + + Values surface to ``recon_metrics.fingerprint_metrics.ineligibility_reason``. + Adding new members is additive; renaming or removing breaks dashboards. + """ + + FLAG_DISABLED = "flag_disabled" + UNSUPPORTED_DIALECT = "unsupported_dialect" + REPORT_TYPE_NOT_DATA = "report_type_not_data" + NO_JOIN_COLUMNS = "no_join_columns" + FILTERS_CONFIGURED = "filters_configured" + TRANSFORMS_CONFIGURED = "transforms_configured" + COLUMN_THRESHOLDS_CONFIGURED = "column_thresholds_configured" + TABLE_THRESHOLDS_CONFIGURED = "table_thresholds_configured" + # ``classify_ineligibility`` runs before the precheck without a target schema, + # so it cannot detect this. ``align_columns`` discovers it once ``tgt_schema`` + # is in hand and the trigger layer routes the typed exception through + # ``ineligible(...)`` so an adoption query against + # ``recon_metrics.fingerprint_metrics.ineligibility_reason`` can quantify how + # often a typo or a column-mapping drift skips the precheck. + UNMAPPED_TARGET_COLUMN_MAPPING = "unmapped_target_column_mapping" + + +# Module-level aliases preserved so existing callers keep importing by name. +INELIGIBLE_FLAG_DISABLED = IneligibilityReason.FLAG_DISABLED.value +INELIGIBLE_UNSUPPORTED_DIALECT = IneligibilityReason.UNSUPPORTED_DIALECT.value +INELIGIBLE_REPORT_TYPE_NOT_DATA = IneligibilityReason.REPORT_TYPE_NOT_DATA.value +INELIGIBLE_NO_JOIN_COLUMNS = IneligibilityReason.NO_JOIN_COLUMNS.value +INELIGIBLE_FILTERS_CONFIGURED = IneligibilityReason.FILTERS_CONFIGURED.value +INELIGIBLE_TRANSFORMS_CONFIGURED = IneligibilityReason.TRANSFORMS_CONFIGURED.value +INELIGIBLE_COLUMN_THRESHOLDS_CONFIGURED = IneligibilityReason.COLUMN_THRESHOLDS_CONFIGURED.value +INELIGIBLE_TABLE_THRESHOLDS_CONFIGURED = IneligibilityReason.TABLE_THRESHOLDS_CONFIGURED.value +INELIGIBLE_UNMAPPED_TARGET_COLUMN_MAPPING = IneligibilityReason.UNMAPPED_TARGET_COLUMN_MAPPING.value + + +class FetchPath(str, Enum): + """Stage-2 source-fetch strategy used for one run. + + Values surface to ``recon_metrics.fingerprint_metrics.fetch_path``. + """ + + V1_SANDWICH = "v1_sandwich" + # Historical: persisted by 0.12.4-0.12.7. Kept so old recon_metrics rows continue to + # round-trip through code that imports the constant; current code never emits it. + V2_REDSHIFT_SPLIT = "v2_redshift_split" + + +FETCH_PATH_V1_SANDWICH = FetchPath.V1_SANDWICH.value +FETCH_PATH_V2_REDSHIFT_SPLIT = FetchPath.V2_REDSHIFT_SPLIT.value + + +# Verdict surfaced to recon_metrics.fingerprint_metrics.verdict; a Literal so mypy +# catches typos at edit time. None means "ineligible / disabled / pre-detection". +RunVerdict = Literal["MATCH", "MISMATCH", "FAILED"] + + +@dataclass(frozen=True) +class FingerprintRunMetadata: + """Recorded once per (recon_id, table) on recon_metrics. + + Always written, even for ineligible / disabled runs, so adoption-style queries on + recon_metrics.fingerprint_metrics don't need a LEFT-JOIN to count opt-outs. + + Field semantics: + + - ``eligible``: did the pre-check actually run? + - ``ineligibility_reason``: an IneligibilityReason value when ``eligible=False``. + - ``verdict``: ``"MATCH"`` / ``"MISMATCH"`` / ``"FAILED"`` / None when ineligible. + - ``elapsed_ms``: detection-phase wall-clock; 0 when skipped. + - ``solved_count`` / ``unsolved_sb_count`` / ``total_mismatched_sbs``: solver telemetry + for tuning sub-bucket sizing. + - ``fallback_to_full_pipeline``: True when fingerprint was eligible but didn't + short-circuit (systemic mismatch, missing rows, exception, soft skip). + - ``sub_bucket_count`` / ``bucket_count``: the adaptive tier for this run; 0 when + the pre-check did not run. + - ``target_row_count``: target Delta row count from DESCRIBE DETAIL, or override. + None when both fell through to the static default. + - ``row_count_source``: provenance — one of the RowCountSource values, or None. + - ``fetch_path``: name of the Stage-2 source-fetch strategy. None on MATCH or + ineligible. See FetchPath for stable values. + """ + + eligible: bool = False + ineligibility_reason: str | None = None + verdict: RunVerdict | None = None + elapsed_ms: int = 0 + solved_count: int = 0 + unsolved_sb_count: int = 0 + total_mismatched_sbs: int = 0 + fallback_to_full_pipeline: bool = False + sub_bucket_count: int = 0 + bucket_count: int = 0 + target_row_count: int | None = None + row_count_source: str | None = None + fetch_path: str | None = None + + @classmethod + def ineligible(cls, reason: str) -> "FingerprintRunMetadata": + return cls(eligible=False, ineligibility_reason=reason) + + @classmethod + def disabled(cls) -> "FingerprintRunMetadata": + """Default for non-fingerprint reconciles, so the persisted struct stays uniform.""" + return cls(eligible=False, ineligibility_reason=INELIGIBLE_FLAG_DISABLED) + + @classmethod + def from_result( + cls, + result: "FingerprintResult", + *, + verdict: RunVerdict, + fallback_to_full_pipeline: bool = False, + ) -> "FingerprintRunMetadata": + """Single-site mapping from detection-side ``FingerprintResult`` to the persisted + metadata. Adding a new telemetry field is one line here, not three copy-loop edits + across the orchestrator's MATCH / MISMATCH-fallback / MISMATCH-success branches. + """ + return cls( + eligible=True, + verdict=verdict, + elapsed_ms=result.detection_elapsed_ms, + solved_count=result.solved_count, + unsolved_sb_count=result.unsolved_sb_count, + total_mismatched_sbs=result.total_mismatched_sbs, + fallback_to_full_pipeline=fallback_to_full_pipeline, + sub_bucket_count=result.sub_bucket_count, + bucket_count=result.bucket_count, + target_row_count=result.target_row_count, + row_count_source=result.row_count_source, + fetch_path=result.fetch_path, + ) + + @classmethod + def fallback(cls, *, verdict: RunVerdict | None = None) -> "FingerprintRunMetadata": + """Eligible but no usable ``FingerprintResult`` (precheck declined or raised). + + ``verdict`` is ``"FAILED"`` when the precheck raised, ``None`` when it declined + for a non-error reason (column-resolution skip, systemic mismatch, no solved + buckets) — keeping the verdict unset on the latter so dashboards can distinguish. + """ + return cls(eligible=True, fallback_to_full_pipeline=True, verdict=verdict) diff --git a/src/databricks/labs/lakebridge/reconcile/fingerprint/orchestrator.py b/src/databricks/labs/lakebridge/reconcile/fingerprint/orchestrator.py new file mode 100644 index 0000000000..50092aea8b --- /dev/null +++ b/src/databricks/labs/lakebridge/reconcile/fingerprint/orchestrator.py @@ -0,0 +1,729 @@ +import logging +import time +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass + +from pyspark.sql import DataFrame, SparkSession +from pyspark.sql.functions import expr +from sqlglot import Dialect + +from databricks.labs.lakebridge.config import DatabaseConfig +from databricks.labs.lakebridge.reconcile.compare import ( + _HASH_COLUMN_NAME, + capture_mismatch_data_and_columns, + reconcile_data as compare_reconcile_data, +) +from databricks.labs.lakebridge.reconcile.connectors.data_source import DataSource +from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils +from databricks.labs.lakebridge.reconcile.fingerprint.constants import pick_sub_bucket_count +from databricks.labs.lakebridge.reconcile.fingerprint.engine import ( + DetectionResult, + DetectionVerdict, + detect_and_solve, +) +from databricks.labs.lakebridge.reconcile.fingerprint.exceptions import ( + UnmappedTargetColumnMappingError, + UnsupportedDataSourceError, +) +from databricks.labs.lakebridge.reconcile.fingerprint.metadata import ( + FETCH_PATH_V1_SANDWICH, + INELIGIBLE_COLUMN_THRESHOLDS_CONFIGURED, + INELIGIBLE_FILTERS_CONFIGURED, + INELIGIBLE_FLAG_DISABLED, + INELIGIBLE_NO_JOIN_COLUMNS, + INELIGIBLE_REPORT_TYPE_NOT_DATA, + INELIGIBLE_TABLE_THRESHOLDS_CONFIGURED, + INELIGIBLE_TRANSFORMS_CONFIGURED, + INELIGIBLE_UNSUPPORTED_DIALECT, +) +from databricks.labs.lakebridge.reconcile.fingerprint.query_builders.base import FingerprintQueryBuilder +from databricks.labs.lakebridge.reconcile.fingerprint.query_builders.redshift import RedshiftFingerprintQueryBuilder +from databricks.labs.lakebridge.reconcile.fingerprint.fingerprint_hash_columns import ( + hash_columns_ordered_for_reconcile, +) +from databricks.labs.lakebridge.reconcile.fingerprint.row_count import fetch_target_row_count +from databricks.labs.lakebridge.reconcile.fingerprint.spark_target import ( + build_target_filter_subquery, + compute_target_fingerprint, +) +from databricks.labs.lakebridge.reconcile.query_builder.hash_query import HashQueryBuilder +from databricks.labs.lakebridge.reconcile.recon_capture import AbstractReconIntermediatePersist +from databricks.labs.lakebridge.reconcile.recon_config import Table, Schema +from databricks.labs.lakebridge.reconcile.recon_output_config import ( + DataReconcileOutput, + MismatchOutput, +) + +logger = logging.getLogger(__name__) + +# Fingerprint defaults to keeping '' distinct from NULL so the detection-side serialization +# matches the row-hash path (TRIM in expression_generator does not collapse '' to NULL). +# Flipping this to True silently disagrees with row-hash on every NULL <-> '' flip. +_DEFAULT_TREAT_EMPTY_AS_NULL = False + + +@dataclass(frozen=True) +class ColumnAlignment: + column_mapping: dict[str, str] | None + + +@dataclass(frozen=True) +class FingerprintResult: + verdict: DetectionVerdict + source_rows: DataFrame | None = None + target_rows: DataFrame | None = None + solved_count: int = 0 + unsolved_sb_count: int = 0 + total_mismatched_sbs: int = 0 + detection_elapsed_ms: int = 0 + sub_bucket_count: int = 0 + bucket_count: int = 0 + target_row_count: int | None = None + row_count_source: str | None = None + fetch_path: str | None = None + + +@dataclass(frozen=True) +class _TierSelection: + """Adaptive (sub_bucket_count, bucket_count) for one run. + + Source and target must use identical values or sub-bucket IDs won't align across + the GROUP BY join. + """ + + sub_bucket_count: int + bucket_count: int + target_row_count: int | None + row_count_source: str + + +def classify_ineligibility( + *, + flag_enabled: bool, + data_source: str, + report_type: str, + table_conf: Table, +) -> str | None: + """Return the ineligibility reason for the pre-check, or None when eligible. + + First-match-wins; flag/source-level reasons are surfaced before per-table config so + adoption queries can distinguish "feature off" from "feature on but table ineligible". + """ + if not flag_enabled: + return INELIGIBLE_FLAG_DISABLED + if data_source not in fingerprint_supported_sources(): + return INELIGIBLE_UNSUPPORTED_DIALECT + if report_type not in {"data", "row", "all"}: + return INELIGIBLE_REPORT_TYPE_NOT_DATA + if not table_conf.join_columns: + return INELIGIBLE_NO_JOIN_COLUMNS + if table_conf.filters and (table_conf.filters.source or table_conf.filters.target): + return INELIGIBLE_FILTERS_CONFIGURED + if table_conf.transformations: + return INELIGIBLE_TRANSFORMS_CONFIGURED + if table_conf.column_thresholds: + return INELIGIBLE_COLUMN_THRESHOLDS_CONFIGURED + if table_conf.table_thresholds: + return INELIGIBLE_TABLE_THRESHOLDS_CONFIGURED + return None + + +def align_columns( + table_conf: Table, + src_schema: list[Schema], # pylint: disable=unused-argument + tgt_schema: list[Schema], +) -> ColumnAlignment | None: + """Map a Table config to fingerprint column parameters, or raise/return None if ineligible. + + ``src_schema`` is intentionally part of the public signature for parallelism with the + rest of the fingerprint surface (every other helper threads both schemas). Today only + ``tgt_schema`` is consumed — for validating that every ``column_mapping`` target name + actually exists on the target side. Catching a typo here is cheap; catching it after + Stage-1's source-side Redshift scan ran is not — the Spark ``F.col`` resolution would + raise mid-fetch and burn the JDBC pull. + + On an unmapped ``column_mapping`` target this raises + ``UnmappedTargetColumnMappingError`` so the trigger layer can record the + typed ``IneligibilityReason.UNMAPPED_TARGET_COLUMN_MAPPING`` on the + persisted metric. The defensive guards on filters / transforms / thresholds + keep the legacy ``None`` return — those reasons are already recorded by + ``classify_ineligibility`` upstream so this branch is unreachable in + practice; the guards exist solely for direct unit-test callers. + """ + if table_conf.filters and (table_conf.filters.source or table_conf.filters.target): + return None + if table_conf.transformations: + return None + if table_conf.column_thresholds: + return None + if table_conf.table_thresholds: + return None + + col_map = {cm.source_name: cm.target_name for cm in table_conf.column_mapping or []} + + if col_map: + tgt_cols_bare = {DialectUtils.unnormalize_identifier(s.column_name).lower() for s in tgt_schema} + for src_name, tgt_name in col_map.items(): + if DialectUtils.unnormalize_identifier(tgt_name).lower() not in tgt_cols_bare: + # Raise (rather than silently return ``None``) so the trigger + # layer records ``UNMAPPED_TARGET_COLUMN_MAPPING`` on the + # persisted metric. Without the typed signal an adoption query + # against ``ineligibility_reason`` cannot distinguish this from + # a generic precheck decline. + raise UnmappedTargetColumnMappingError( + f"column_mapping target {tgt_name!r} (mapped from {src_name!r}) " + f"not found in target schema for table {table_conf.source_name!r}" + ) + + return ColumnAlignment( + column_mapping=col_map if col_map else None, + ) + + +def resolve_compare_key_columns(table_conf: Table) -> list[str]: + """Return join columns for compare.reconcile_data. + + For ``row`` report type, compare.reconcile_data replaces keys with hash_value_recon. + """ + return table_conf.join_columns or [] + + +def fingerprint_match_output() -> DataReconcileOutput: + """Zeroed DataReconcileOutput for a confirmed MATCH.""" + return DataReconcileOutput( + mismatch_count=0, + missing_in_src_count=0, + missing_in_tgt_count=0, + mismatch=MismatchOutput(), + missing_in_src=None, + missing_in_tgt=None, + ) + + +def build_mismatch_output( + src_hashed: DataFrame, + tgt_hashed: DataFrame, + key_columns: list[str], + report_type: str, + persistence: AbstractReconIntermediatePersist, +) -> DataReconcileOutput: + """Run compare.reconcile_data on rows that already have hash_value_recon. + + Bug A fix (column-level diff for fingerprint MISMATCH + report_type='all'): + ``compare.reconcile_data`` populates ``mismatch.mismatch_df`` but never + ``mismatch.mismatch_columns``. In the normal (non-fingerprint) path this is + backfilled by ``Reconciliation._get_sample_data`` → + ``capture_mismatch_data_and_columns``, but the fingerprint MISMATCH path + bypasses ``_get_sample_data`` entirely. Because the Stage-2 fetch projects + every hashed column (``project_all_columns=True``), ``src_hashed`` / + ``tgt_hashed`` already carry every hashed column, so we can compute + ``mismatch_columns`` in-place here without a second JDBC pull. Gated on + ``report_type='all'`` + ``mismatch_count > 0`` so fingerprint MATCH and the + zero-mismatch fast-path bear no overhead. + """ + output = compare_reconcile_data( + source=src_hashed, + target=tgt_hashed, + key_columns=key_columns, + report_type=report_type, + persistence=persistence, + ) + + if report_type != "all" or output.mismatch_count == 0: + return output + + # The fingerprint frames carry ``hash_value_recon``; treat it as a + # derived/synthetic column - rows that hash differently are precisely the + # mismatched rows, so leaving it in would always show as "mismatched" and + # inflate ``mismatch_columns`` with a non-source-column. + src_for_capture = src_hashed.drop(_HASH_COLUMN_NAME) if _HASH_COLUMN_NAME in src_hashed.columns else src_hashed + tgt_for_capture = tgt_hashed.drop(_HASH_COLUMN_NAME) if _HASH_COLUMN_NAME in tgt_hashed.columns else tgt_hashed + + capture = capture_mismatch_data_and_columns( + source=src_for_capture, + target=tgt_for_capture, + key_columns=key_columns, + ) + + # Build the wide ``mismatch_df`` consumed by ``recon_capture._create_map_column``. + # ``capture.mismatch_df`` is an INNER JOIN over the Stage-2 fetched subset, + # which means it contains EVERY src/tgt pair that share a key, including + # rows that turned out to match column-by-column (Stage-1 only proves the + # sub-bucket has *some* mismatch; the row-by-row check happens here). + # We: + # 1. Filter to rows where at least one ``_match`` is False (the + # genuine row-level mismatches). + # 2. Append a per-row ``mismatch_columns`` STRING column with the + # comma-separated list of cols where ``_match=False``. This is + # the field downstream tooling and audit harnesses key on; without + # it the recon_details rows look indistinguishable from full-row + # MISMATCHes and the column-level diff is unrecoverable. + final_mismatch_df = capture.mismatch_df + match_cols = [c for c in final_mismatch_df.columns if c.endswith("_match")] if final_mismatch_df is not None else [] + if final_mismatch_df is not None and match_cols: + # ``compare._get_mismatch_df`` builds ``_match`` with bare ``=`` + # (NOT null-safe). For mismatches that involve NULL on one side + # (``NULL <-> value``, ``value <-> NULL``) the resulting ``_match`` is + # NULL not FALSE; for unchanged NULL columns (``NULL <-> NULL``, e.g. + # ``notes`` left untouched while the row mutates ``is_priority``) it is + # ALSO NULL despite being a match. ``NOT NULL`` evaluates to NULL in + # SQL so a naive ``NOT _match`` filter silently drops *every* such + # row, and a naive ``COALESCE(_match, false)`` flips both directions + # (drops legit row, but inflates per-row ``mismatch_columns`` with + # NULL-NULL columns - over-reporting). + # + # The correct fix is null-safe equality (``<=>``): ``NULL <=> NULL`` + # is TRUE (match), ``NULL <=> value`` is FALSE (mismatch). We recompute + # every ``_match`` from the existing ``_base`` / ``_compare`` + # columns using ``<=>``, which yields a non-null BOOLEAN. The + # downstream filter/case-when then work without COALESCE wrappers and + # match the Python-side ``_get_mismatch_columns`` semantics, so the + # table-level ``recon_metrics.column_comparison.mismatch_columns`` and + # the per-row ``recon_details`` agree. + for match_col in match_cols: + stem = match_col[: -len("_match")] + base_col = f"{stem}_base" + compare_col = f"{stem}_compare" + if base_col in final_mismatch_df.columns and compare_col in final_mismatch_df.columns: + final_mismatch_df = final_mismatch_df.withColumn(match_col, expr(f"`{base_col}` <=> `{compare_col}`")) + + not_all_match = " OR ".join(f"NOT `{c}`" for c in match_cols) + diff_case_exprs = ", ".join(f"CASE WHEN NOT `{c}` THEN '{c[: -len('_match')]}' END" for c in match_cols) + final_mismatch_df = final_mismatch_df.filter(expr(not_all_match)).withColumn( + "mismatch_columns", expr(f"concat_ws(',', {diff_case_exprs})") + ) + + return DataReconcileOutput( + mismatch_count=output.mismatch_count, + missing_in_src_count=output.missing_in_src_count, + missing_in_tgt_count=output.missing_in_tgt_count, + missing_in_src=output.missing_in_src, + missing_in_tgt=output.missing_in_tgt, + mismatch=MismatchOutput( + mismatch_df=final_mismatch_df, + mismatch_columns=capture.mismatch_columns, + ), + threshold_output=output.threshold_output, + ) + + +def _resolve_detection_columns( + table_conf: Table, + src_schema: list[Schema], + source: DataSource, +) -> list[Schema] | None: + """Resolve hash columns against the source schema, or None to skip fingerprint.""" + hash_col_names = hash_columns_ordered_for_reconcile(table_conf, src_schema, "source", source) + if not hash_col_names: + logger.warning("Fingerprint: no hash columns resolved — skipping") + return None + + # Schema entries are ANSI-delimited via _map_meta_column; user-supplied join_columns + # are bare. Strip and lowercase on both sides so quoting and casing round-trip. + by_name = {DialectUtils.unnormalize_identifier(s.column_name).lower(): s for s in src_schema} + detection_cols: list[Schema] = [] + for name in hash_col_names: + schema_entry = by_name.get(DialectUtils.unnormalize_identifier(name).lower()) + if schema_entry is None: + logger.warning(f"Fingerprint: hash column '{name}' missing from source schema — skipping") + return None + detection_cols.append(schema_entry) + return detection_cols + + +def _select_tier( + spark: SparkSession, + database_config: DatabaseConfig, + table_conf: Table, + override_row_count: int | None = None, +) -> _TierSelection: + """Pick (sub_bucket_count, bucket_count) from the target Delta row count. + + Falls back to static defaults when the target is non-Delta or stats are missing. + """ + row_count_result = fetch_target_row_count( + spark, + catalog=database_config.target_catalog, + schema=database_config.target_schema, + table=table_conf.target_name, + override_row_count=override_row_count, + ) + sub_bucket_count, bucket_count = pick_sub_bucket_count(row_count_result.row_count) + return _TierSelection( + sub_bucket_count=sub_bucket_count, + bucket_count=bucket_count, + target_row_count=row_count_result.row_count, + row_count_source=row_count_result.source.value, + ) + + +def _run_detection_phase( + source: DataSource, + spark: SparkSession, + database_config: DatabaseConfig, + table_conf: Table, + detection_cols: list[Schema], + column_mapping: dict[str, str] | None, + query_builder: FingerprintQueryBuilder, + tier: _TierSelection, + treat_empty_as_null: bool, +) -> tuple[DetectionResult, int]: + """Run detection aggregates on both sides; return (result, elapsed_ms). + + ``treat_empty_as_null`` flows in from the orchestrator (config-driven via + ``ReconcileConfig.fingerprint_treat_empty_as_null``) so source and target stay in + lockstep — silently disagreeing here causes systemic Stage-1 mismatch on every + NULL/'' cell. + """ + start_time = time.monotonic() + source_detection_sql = query_builder.build_detection_sql( + schema=database_config.source_schema, + table=table_conf.source_name, + columns=detection_cols, + column_mapping=column_mapping, + sub_bucket_count=tier.sub_bucket_count, + bucket_count=tier.bucket_count, + ) + source_agg_df = source.read_data( + catalog=database_config.source_catalog, + schema=database_config.source_schema, + table=table_conf.source_name, + query=source_detection_sql, + options=table_conf.jdbc_reader_options, + ) + target_agg_df = compute_target_fingerprint( + spark=spark, + catalog=database_config.target_catalog, + schema=database_config.target_schema, + table=table_conf.target_name, + columns=detection_cols, + column_mapping=column_mapping, + sub_bucket_count=tier.sub_bucket_count, + bucket_count=tier.bucket_count, + treat_empty_as_null=treat_empty_as_null, + ) + detection = detect_and_solve(source_agg_df, target_agg_df) + elapsed_ms = int((time.monotonic() - start_time) * 1000) + return detection, elapsed_ms + + +@dataclass(frozen=True) +class _FetchContext: + """Inputs the fetch phase needs.""" + + source: DataSource + target: DataSource + source_engine: Dialect + database_config: DatabaseConfig + table_conf: Table + src_schema: list[Schema] + tgt_schema: list[Schema] + detection_cols: list[Schema] + column_mapping: dict[str, str] | None + query_builder: FingerprintQueryBuilder + tier: _TierSelection + # Mirrors the Stage-1 ``treat_empty_as_null`` flag so Stage-2's target-side filter + # subquery cannot drift apart — see ``_run_detection_phase`` for the contract. No + # default: the orchestrator is the single source of truth (via + # ``ReconcileConfig.fingerprint_treat_empty_as_null``) and must always pass it + # explicitly so the four call sites cannot diverge silently. + treat_empty_as_null: bool + + +# sqlglot renders the table placeholder per dialect: Spark/Databricks emits ``:tbl``, +# Postgres-family dialects emit ``%(tbl)s`` (pyformat). Substitute both forms. +_TBL_PLACEHOLDERS = (":tbl", "%(tbl)s") + + +def _substitute_tbl_placeholder(query: str, replacement: str) -> str: + for placeholder in _TBL_PLACEHOLDERS: + query = query.replace(placeholder, replacement) + return query + + +def _fetch_source_rows( + ctx: _FetchContext, + solved_hashes: dict[int, list[int]], + unsolved_sb_ids: list[int], + report_type: str, +) -> tuple[DataFrame, str]: + """Fetch source rows for Stage-2 reconcile. + + Single statement: filter subquery is injected into the hash query's table + placeholder, producing one query that filters by sub-bucket and projects + LOWER(SHA2(...,256)) AS hash_value_recon. Only the hash and join columns + cross JDBC. + """ + source_filter_subquery = ctx.query_builder.build_source_filter_subquery( + schema=ctx.database_config.source_schema, + table=ctx.table_conf.source_name, + columns=ctx.detection_cols, + sub_bucket_count=ctx.tier.sub_bucket_count, + solved_hashes=solved_hashes, + unsolved_sb_ids=unsolved_sb_ids, + ) + # Project every hashed column, not just hash + join keys, so the downstream + # compare layer can populate ``mismatch_columns`` without a second round-trip + # to Redshift. Off in normal Lakebridge mode; only the fingerprint Stage-2 + # fetch flips this on. + src_hash_query = HashQueryBuilder( + ctx.table_conf, ctx.src_schema, "source", ctx.source_engine, ctx.source + ).build_query(report_type=report_type, project_all_columns=True) + src_filtered_query = _substitute_tbl_placeholder(src_hash_query, source_filter_subquery) + + df = ctx.source.read_data( + catalog=ctx.database_config.source_catalog, + schema=ctx.database_config.source_schema, + table=ctx.table_conf.source_name, + query=src_filtered_query, + options=ctx.table_conf.jdbc_reader_options, + ) + return df, FETCH_PATH_V1_SANDWICH + + +def _fetch_target_rows( + ctx: _FetchContext, + solved_hashes: dict[int, list[int]], + unsolved_sb_ids: list[int], + report_type: str, +) -> DataFrame: + """Fetch target rows for Stage-2 reconcile via the Spark-side filter subquery.""" + # See ``_fetch_source_rows`` for the project_all_columns rationale. + # Source and target MUST stay in lockstep: if source projects all columns and + # target only projects keys, ``capture_mismatch_data_and_columns`` will raise + # because ``source_columns != target_columns``. + tgt_hash_query = HashQueryBuilder( + ctx.table_conf, ctx.tgt_schema, "target", ctx.source_engine, ctx.target + ).build_query(report_type=report_type, project_all_columns=True) + tgt_filter_subquery = build_target_filter_subquery( + ctx.database_config.target_catalog, + ctx.database_config.target_schema, + ctx.table_conf.target_name, + ctx.detection_cols, + ctx.column_mapping, + solved_hashes, + unsolved_sb_ids, + sub_bucket_count=ctx.tier.sub_bucket_count, + treat_empty_as_null=ctx.treat_empty_as_null, + ) + tgt_filtered_query = _substitute_tbl_placeholder(tgt_hash_query, tgt_filter_subquery) + return ctx.target.read_data( + catalog=ctx.database_config.target_catalog, + schema=ctx.database_config.target_schema, + table=ctx.table_conf.target_name, + query=tgt_filtered_query, + options=None, + ) + + +def _fetch_source_and_target_rows( + ctx: _FetchContext, + solved_hashes: dict[int, list[int]], + unsolved_sb_ids: list[int], + report_type: str, +) -> tuple[DataFrame, str, DataFrame]: + """Run Stage-2 source and target fetches in parallel (B3). + + Source fetch is JDBC-bound (Redshift round trip + scan), target fetch is a + Spark filter-subquery against the cached Delta table. They share zero state + (different connectors, different DataFrames, immutable inputs) so dispatching + on two driver threads lets the JDBC pull overlap with the target Spark job + submission instead of running serially. + + Note: each fetch returns lazily — Spark's DAG isn't materialised until a + downstream action collects. So the wall-clock win comes from overlapping the + JDBC pull (which connectors typically force-collect via ``read_data``) + with the target query planning + initial Spark stage submission. On Spark + cluster execution itself, both fetches still parallelize across the + cluster's executors as before; this helper only addresses driver-side + serialization. + + Failure semantics: ``ThreadPoolExecutor.__exit__`` joins both futures, and + ``future.result()`` re-raises any exception from the worker thread on the + caller's stack — so behaviour is identical to the serial version on errors. + A failure in either fetch immediately aborts the precheck, just like before. + + Sibling-future cancellation: if the source fetch fails first, the target + fetch keeps running until it completes naturally — Python's + ``Future.cancel()`` is a no-op once the worker has started, so there is no + cheap way to interrupt a Spark job submission mid-flight from the driver. + The trigger layer's exception-catch wraps this whole block, so any work + that completes after the first failure is discarded and the ``with`` block + waits at most one extra fetch's worth of time before returning. The + two-thread cap means the overshoot is bounded; documenting this here so + future readers do not add a ``cancel()`` call expecting it to interrupt + the running Spark/JDBC submission. + """ + # max_workers=2 because we have exactly two independent fetches. Naming the + # threads helps when debugging stuck JDBC pulls in production thread dumps. + with ThreadPoolExecutor(max_workers=2, thread_name_prefix="fp-stage2") as pool: + src_future = pool.submit(_fetch_source_rows, ctx, solved_hashes, unsolved_sb_ids, report_type) + tgt_future = pool.submit(_fetch_target_rows, ctx, solved_hashes, unsolved_sb_ids, report_type) + src_data, fetch_path = src_future.result() + tgt_data = tgt_future.result() + return src_data, fetch_path, tgt_data + + +def run_fingerprint_precheck( # pylint: disable=too-many-locals,too-many-arguments + source: DataSource, + target: DataSource, + spark: SparkSession, + source_engine: Dialect, + database_config: DatabaseConfig, + table_conf: Table, + src_schema: list[Schema], + tgt_schema: list[Schema], + report_type: str, + data_source: str, + treat_empty_as_null: bool = _DEFAULT_TREAT_EMPTY_AS_NULL, + target_row_count_override: int | None = None, + recon_id: str | None = None, +) -> FingerprintResult | None: + """Execute the fingerprint pre-check for one table pair, or None if ineligible. + + On MISMATCH, the result carries pre-fetched source/target DataFrames already + projected with hash_value_recon, ready for compare.reconcile_data. + + Eligibility (flag, dialect, report_type, ``join_columns``, filters, transforms, + thresholds) is the contract of ``classify_ineligibility`` at the trigger layer; + callers must run that gate first. Bypassing it is undefined behaviour. + + ``recon_id`` is woven into the high-level log prefix so a multi-table run can + be traced from a single grep against the cluster logs without correlating by + timestamp alone. + """ + log_tag = f"Fingerprint[recon_id={recon_id}]" if recon_id else "Fingerprint" + alignment = align_columns(table_conf, src_schema, tgt_schema) + if alignment is None: + logger.info(f"{log_tag}: table '{table_conf.source_name}' ineligible — skipping pre-check") + return None + + detection_cols = _resolve_detection_columns(table_conf, src_schema, source) + if detection_cols is None: + return None + + query_builder = get_query_builder(data_source, treat_empty_as_null=treat_empty_as_null) + + # Same tier MUST be used by detection and fetch — Stage-2's filter modulus + # has to match Stage-1's GROUP BY modulus or solver IDs won't align. The + # ``target_row_count_override`` short-circuits ``DESCRIBE DETAIL`` so + # non-Delta targets (or stale-stats Delta) still land on the right tier. + tier = _select_tier(spark, database_config, table_conf, override_row_count=target_row_count_override) + + detection, elapsed_ms = _run_detection_phase( + source, + spark, + database_config, + table_conf, + detection_cols, + alignment.column_mapping, + query_builder, + tier, + treat_empty_as_null, + ) + + if detection.verdict == "MATCH": + return FingerprintResult( + verdict="MATCH", + detection_elapsed_ms=elapsed_ms, + sub_bucket_count=tier.sub_bucket_count, + bucket_count=tier.bucket_count, + target_row_count=tier.target_row_count, + row_count_source=tier.row_count_source, + ) + + if detection.systemic_mismatch: + logger.info(f"{log_tag}: systemic mismatch — deferring to full pipeline") + return None + + solved_hashes = collect_solved_hashes(detection) + unsolved_sb_ids = detection.unsolved_sb_ids + if not solved_hashes and not unsolved_sb_ids: + return None + + fetch_ctx = _FetchContext( + source=source, + target=target, + source_engine=source_engine, + database_config=database_config, + table_conf=table_conf, + src_schema=src_schema, + tgt_schema=tgt_schema, + detection_cols=detection_cols, + column_mapping=alignment.column_mapping, + query_builder=query_builder, + tier=tier, + treat_empty_as_null=treat_empty_as_null, + ) + src_data, fetch_path, tgt_data = _fetch_source_and_target_rows( + fetch_ctx, solved_hashes, unsolved_sb_ids, report_type + ) + + return FingerprintResult( + verdict="MISMATCH", + source_rows=src_data, + target_rows=tgt_data, + solved_count=len(detection.solved_results), + unsolved_sb_count=len(detection.unsolved_sb_ids), + total_mismatched_sbs=detection.total_mismatched_sbs, + detection_elapsed_ms=elapsed_ms, + sub_bucket_count=tier.sub_bucket_count, + bucket_count=tier.bucket_count, + target_row_count=tier.target_row_count, + row_count_source=tier.row_count_source, + fetch_path=fetch_path, + ) + + +# Adding a new source = one entry here plus a new FingerprintQueryBuilder subclass. +_QUERY_BUILDERS: dict[str, type[FingerprintQueryBuilder]] = { + "redshift": RedshiftFingerprintQueryBuilder, +} + + +def get_query_builder( + data_source: str, + *, + treat_empty_as_null: bool = _DEFAULT_TREAT_EMPTY_AS_NULL, +) -> FingerprintQueryBuilder: + """Return the registered builder for ``data_source``. + + Raises ``UnsupportedDataSourceError`` (a ValueError) when no builder is registered; + callers should pre-flight via ``fingerprint_supported_sources()``. + + ``treat_empty_as_null`` is threaded so the source-side serialiser stays in lockstep + with the target-side ``compute_target_fingerprint`` / ``build_target_filter_subquery`` + calls — disagreement here causes systemic Stage-1 mismatch on every NULL/'' cell. + """ + try: + builder_cls = _QUERY_BUILDERS[data_source] + except KeyError as e: + raise UnsupportedDataSourceError( + f"No fingerprint query builder registered for data_source={data_source!r}. " + f"Supported: {sorted(_QUERY_BUILDERS)}" + ) from e + return builder_cls(treat_empty_as_null=treat_empty_as_null) + + +def fingerprint_supported_sources() -> frozenset[str]: + """Sources with a registered FingerprintQueryBuilder.""" + return frozenset(_QUERY_BUILDERS) + + +def collect_solved_hashes(detection: DetectionResult) -> dict[int, list[int]]: + """Merge solved source and target hashes per sub-bucket into a single lookup. + + Multiple ``SolveResult`` rows can share a ``sub_bucket_id``; the same hash can also + appear on both the source and target side. Dedupe via ``set`` so the dict stays + O(distinct hashes) for memory; sort on the way out for deterministic SQL emission + (``build_fingerprint_where_clause`` already dedupes the IN-list, but the dict is + held driver-side until then and dominates memory at the 50 K-sub-bucket cap). + """ + accum: dict[int, set[int]] = {} + for solve in detection.solved_results: + if not solve.source_hashes and not solve.target_hashes: + continue + bucket = accum.setdefault(solve.sub_bucket_id, set()) + bucket.update(solve.source_hashes) + bucket.update(solve.target_hashes) + return {sb_id: sorted(hashes) for sb_id, hashes in accum.items()} diff --git a/src/databricks/labs/lakebridge/reconcile/fingerprint/query_builders/__init__.py b/src/databricks/labs/lakebridge/reconcile/fingerprint/query_builders/__init__.py new file mode 100644 index 0000000000..2bb1f1e9ca --- /dev/null +++ b/src/databricks/labs/lakebridge/reconcile/fingerprint/query_builders/__init__.py @@ -0,0 +1 @@ +"""Dialect-specific fingerprint SQL builders.""" diff --git a/src/databricks/labs/lakebridge/reconcile/fingerprint/query_builders/base.py b/src/databricks/labs/lakebridge/reconcile/fingerprint/query_builders/base.py new file mode 100644 index 0000000000..c08280ea1e --- /dev/null +++ b/src/databricks/labs/lakebridge/reconcile/fingerprint/query_builders/base.py @@ -0,0 +1,45 @@ +from abc import ABC, abstractmethod + +from databricks.labs.lakebridge.reconcile.recon_config import Schema + + +class FingerprintQueryBuilder(ABC): + """Dialect-specific SQL generation for fingerprint detection and row fetch.""" + + def __init__(self, treat_empty_as_null: bool = False): + # Default False matches the row-hash convention in expression_generator + # (TRIM does not collapse '' to NULL); flipping silently disagrees with row-hash + # on every NULL <-> '' flip. + self._treat_empty_as_null = treat_empty_as_null + + @abstractmethod + def build_detection_sql( + self, + schema: str, + table: str, + columns: list[Schema], + column_mapping: dict[str, str] | None, + sub_bucket_count: int, + bucket_count: int, + ) -> str: + """Source-side detection SQL grouping into sub-buckets with (cnt, p1, p2, p1_rh2, p2_rh2).""" + + @abstractmethod + def build_source_filter_subquery( + self, + schema: str, + table: str, + columns: list[Schema], + sub_bucket_count: int, + solved_hashes: dict[int, list[int]], + unsolved_sb_ids: list[int], + ) -> str: + """Subquery selecting only rows in solved sub-buckets / unsolved sub-buckets. + + Returns ``(SELECT * FROM schema.table WHERE ) _fp_filtered``, + suitable for replacing ``:tbl`` in a HashQueryBuilder query. + """ + + @abstractmethod + def serialize_column(self, col_name: str, col_type: str) -> str: + """Cast the column to a deterministic string representation for MD5 hashing.""" diff --git a/src/databricks/labs/lakebridge/reconcile/fingerprint/query_builders/redshift.py b/src/databricks/labs/lakebridge/reconcile/fingerprint/query_builders/redshift.py new file mode 100644 index 0000000000..4aa668cd29 --- /dev/null +++ b/src/databricks/labs/lakebridge/reconcile/fingerprint/query_builders/redshift.py @@ -0,0 +1,133 @@ +import logging + +from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils +from databricks.labs.lakebridge.reconcile.fingerprint.constants import ( + NULL_SENTINEL, + SEPARATOR_REDSHIFT_SQL, + build_fingerprint_where_clause, +) +from databricks.labs.lakebridge.reconcile.fingerprint.query_builders.base import FingerprintQueryBuilder +from databricks.labs.lakebridge.reconcile.recon_config import Schema + +logger = logging.getLogger(__name__) + + +_REDSHIFT_IDENTIFIER_QUOTE = '"' + + +def _quote_redshift_identifier(bare: str) -> str: + """Wrap ``bare`` in Redshift's ``"..."`` identifier delimiters, doubling any embedded + ``"`` per the SQL standard. Defense-in-depth: today's values arrive from + ``information_schema.columns`` and are trusted, but the persistence layer scrubs + embedded quotes for the same reason — keep the two boundaries consistent. + """ + escaped = bare.replace(_REDSHIFT_IDENTIFIER_QUOTE, _REDSHIFT_IDENTIFIER_QUOTE * 2) + return f"{_REDSHIFT_IDENTIFIER_QUOTE}{escaped}{_REDSHIFT_IDENTIFIER_QUOTE}" + + +class RedshiftFingerprintQueryBuilder(FingerprintQueryBuilder): + """Generate Redshift SQL for fingerprint detection and surgical row fetch.""" + + def serialize_column(self, col_name: str, col_type: str) -> str: + # ANSI-delimited identifiers must be re-quoted with double quotes for Redshift. + bare = DialectUtils.unnormalize_identifier(col_name) + quoted = _quote_redshift_identifier(bare) + + col_type_lower = (col_type or "").strip().lower() + if col_type_lower == "boolean": + # Redshift rejects every form of CAST(boolean AS VARCHAR/TEXT); CASE WHEN + # produces 'true'/'false' to match Spark's cast(bool AS string). + cast_expr = f"CASE WHEN {quoted} THEN 'true' WHEN NOT {quoted} THEN 'false' ELSE NULL END" + elif col_type_lower in {"timestamptz", "timestamp with time zone"}: + # Parity with the row-hash compare path: ``TO_CHAR`` with a fixed + # ``YYYY-MM-DD HH24:MI:SS.US`` format pins microsecond width so the + # byte stream matches the Spark target's ``DATE_FORMAT(_, + # 'yyyy-MM-dd HH:mm:ss.SSSSSS')``. Default ``CAST(_ AS VARCHAR)`` + # would emit variable-width fractional seconds and silently disagree + # with row-hash on the same row, dropping it from Stage-2. + cast_expr = f"TO_CHAR({quoted} AT TIME ZONE 'UTC', 'YYYY-MM-DD HH24:MI:SS.US')" + elif col_type_lower in {"timestamp", "timestamp without time zone"}: + cast_expr = f"TO_CHAR({quoted}, 'YYYY-MM-DD HH24:MI:SS.US')" + elif col_type_lower == "date": + cast_expr = f"TO_CHAR({quoted}, 'YYYY-MM-DD')" + else: + # ``CAST(_ AS VARCHAR)`` in Redshift defaults to ``VARCHAR(256)`` and + # silently truncates anything longer; the Spark target keeps the + # full string. That asymmetry would surface a Stage-1 mismatch on + # otherwise-equal long-text rows. ``VARCHAR(65535)`` is Redshift's + # maximum width and matches Spark's unbounded string semantics. + cast_expr = f"CAST({quoted} AS VARCHAR(65535))" + + # TRIM keeps Stage-1 whitespace-symmetric with the row-hash compare path + # and the Spark target serializer. + trimmed = f"TRIM({cast_expr})" + if self._treat_empty_as_null: + return f"COALESCE(NULLIF({trimmed}, ''), '{NULL_SENTINEL}')" + return f"COALESCE({trimmed}, '{NULL_SENTINEL}')" + + def build_detection_sql( + self, + schema: str, + table: str, + columns: list[Schema], + column_mapping: dict[str, str] | None, + sub_bucket_count: int, + bucket_count: int, + ) -> str: + # ``column_mapping`` is unused on the source side: Redshift reads its own physical + # names. The ABC carries it for symmetry with dialects whose source-side SQL + # might need to project differently from the target. + rh1_expr, rh2_expr, sb_expr = self._md5_hash_exprs(columns, sub_bucket_count) + bucket_expr = f"ABS(MOD({rh1_expr}, {bucket_count}))" + + # rh*rh exceeds BIGINT range. Cast operands to DECIMAL(19,0) so the multiply lands + # in DECIMAL(38,0); SUM lifts linear aggregates directly to DECIMAL(38,0). + rh1_dec19 = f"CAST({rh1_expr} AS DECIMAL(19,0))" + rh1_dec38 = f"CAST({rh1_expr} AS DECIMAL(38,0))" + rh2_dec19 = f"CAST({rh2_expr} AS DECIMAL(19,0))" + rh2_dec38 = f"CAST({rh2_expr} AS DECIMAL(38,0))" + + # Route schema / table through the same identifier-quoting helper used + # for column names so an exotic name like ``my-schema`` or one carrying + # a stray ``"`` cannot malform the SQL. Today these are + # connector-validated and safe, but the cost of being defensive here is + # one function call. + from_table = f"{_quote_redshift_identifier(schema)}.{_quote_redshift_identifier(table)}" + return ( + f"SELECT {sb_expr} AS sub_bucket_id, " + f"{bucket_expr} AS bucket_id, " + f"COUNT(*) AS cnt, " + f"SUM({rh1_dec38}) AS p1, " + f"SUM({rh1_dec19} * {rh1_dec19}) AS p2, " + f"SUM({rh2_dec38}) AS p1_rh2, " + f"SUM({rh2_dec19} * {rh2_dec19}) AS p2_rh2 " + f"FROM {from_table} " + f"GROUP BY sub_bucket_id, bucket_id" + ) + + def build_source_filter_subquery( + self, + schema: str, + table: str, + columns: list[Schema], + sub_bucket_count: int, + solved_hashes: dict[int, list[int]], + unsolved_sb_ids: list[int], + ) -> str: + rh1_expr, _, sb_expr = self._md5_hash_exprs(columns, sub_bucket_count) + where_clause = build_fingerprint_where_clause(sb_expr, rh1_expr, solved_hashes, unsolved_sb_ids) + from_table = f"{_quote_redshift_identifier(schema)}.{_quote_redshift_identifier(table)}" + return f"(SELECT * FROM {from_table} WHERE {where_clause}) _fp_filtered" + + def build_concat_expression(self, columns: list[Schema]) -> str: + """Concat over source physical column names for MD5.""" + parts = [self.serialize_column(c.column_name, c.data_type) for c in columns] + return f" || {SEPARATOR_REDSHIFT_SQL} || ".join(parts) + + def _md5_hash_exprs(self, columns: list[Schema], sub_bucket_count: int) -> tuple[str, str, str]: + """Return the (rh1, rh2, sb_expr) MD5-derived SQL fragments shared by detection + filter SQL.""" + concat_expr = self.build_concat_expression(columns) + rh1_expr = f"STRTOL(SUBSTRING(MD5({concat_expr}), 1, 8), 16)" + rh2_expr = f"STRTOL(SUBSTRING(MD5({concat_expr}), 9, 8), 16)" + sb_expr = f"ABS(MOD({rh1_expr}, {sub_bucket_count}))" + return rh1_expr, rh2_expr, sb_expr diff --git a/src/databricks/labs/lakebridge/reconcile/fingerprint/row_count.py b/src/databricks/labs/lakebridge/reconcile/fingerprint/row_count.py new file mode 100644 index 0000000000..2fefba6c68 --- /dev/null +++ b/src/databricks/labs/lakebridge/reconcile/fingerprint/row_count.py @@ -0,0 +1,119 @@ +"""Target row-count fetcher for adaptive sub-bucket tier selection. + +Tier selection needs an order-of-magnitude row count, never a full scan. A SELECT +COUNT(*) on a billion-row table defeats fingerprint mode entirely. The chain is +metadata-only: + +1. Explicit user override — highest precedence. +2. Target Delta ``numRecords`` from DESCRIBE DETAIL — free, exact, sub-second. +3. Static default — fall through with a warning. + +Source-side row counts (Redshift catalog stats) are not consulted: source and target +must use the same tier, picking from target Delta metadata is enough at order-of- +magnitude resolution. +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from enum import Enum + +from pyspark.sql import SparkSession +from pyspark.sql.utils import AnalysisException + +logger = logging.getLogger(__name__) + + +class RowCountSource(str, Enum): + """Provenance of the row count used for tier selection.""" + + USER_OVERRIDE = "user_override" + DELTA_DESCRIBE_DETAIL = "delta_describe_detail" + STATIC_DEFAULT = "static_default" + + +@dataclass(frozen=True) +class RowCountResult: + """``row_count`` is None only when ``source == STATIC_DEFAULT``.""" + + row_count: int | None + source: RowCountSource + + +def fetch_target_row_count( + spark: SparkSession, + *, + catalog: str | None, + schema: str, + table: str, + override_row_count: int | None = None, +) -> RowCountResult: + """Resolve the target row count via metadata-only paths. + + Never raises: every failure logs and falls through. Tier selection is a best-effort + optimisation and must not block detection. + """ + if override_row_count is not None and override_row_count > 0: + logger.info(f"fingerprint.tier.row_count_source=user_override row_count={override_row_count}") + return RowCountResult(row_count=override_row_count, source=RowCountSource.USER_OVERRIDE) + + fully_qualified = _build_fqn(catalog=catalog, schema=schema, table=table) + + delta_count = _try_describe_detail(spark, fully_qualified) + if delta_count is not None: + logger.info( + f"fingerprint.tier.row_count_source=delta_describe_detail " + f"table={fully_qualified} row_count={delta_count}" + ) + return RowCountResult(row_count=delta_count, source=RowCountSource.DELTA_DESCRIBE_DETAIL) + + logger.warning( + f"fingerprint.tier.row_count_source=static_default table={fully_qualified} — DESCRIBE DETAIL " + "returned no numRecords (target may be non-Delta or stats missing); falling back. " + "Set ReconcileConfig.fingerprint_row_count_override to a non-zero estimate " + "to pin the tier explicitly." + ) + return RowCountResult(row_count=None, source=RowCountSource.STATIC_DEFAULT) + + +def _build_fqn(*, catalog: str | None, schema: str, table: str) -> str: + if catalog: + return f"{catalog}.{schema}.{table}" + return f"{schema}.{table}" + + +def _try_describe_detail(spark: SparkSession, fully_qualified_name: str) -> int | None: + """Run DESCRIBE DETAIL and return numRecords when available. + + Returns None when the table is not Delta, the column is missing, the value is null, + or any Spark-side error occurs — tier selection must never block detection. + """ + try: + detail_df = spark.sql(f"DESCRIBE DETAIL {fully_qualified_name}") + except AnalysisException as exc: + logger.debug(f"DESCRIBE DETAIL failed for {fully_qualified_name}: {exc}") + return None + except Exception as exc: # pylint: disable=broad-exception-caught # tier-selection must never block detection + logger.debug(f"DESCRIBE DETAIL raised unexpected error for {fully_qualified_name}: {exc}") + return None + + if "numRecords" not in detail_df.columns: + logger.debug(f"DESCRIBE DETAIL on {fully_qualified_name} returned no numRecords column") + return None + + rows = detail_df.select("numRecords").collect() + if not rows: + logger.debug(f"DESCRIBE DETAIL on {fully_qualified_name} returned 0 rows") + return None + + num_records = rows[0]["numRecords"] + if num_records is None: + logger.debug(f"DESCRIBE DETAIL on {fully_qualified_name} returned NULL numRecords") + return None + + if not isinstance(num_records, int) or num_records < 0: + logger.debug(f"DESCRIBE DETAIL on {fully_qualified_name} returned unexpected numRecords value {num_records!r}") + return None + + return num_records diff --git a/src/databricks/labs/lakebridge/reconcile/fingerprint/spark_target.py b/src/databricks/labs/lakebridge/reconcile/fingerprint/spark_target.py new file mode 100644 index 0000000000..99a6e03324 --- /dev/null +++ b/src/databricks/labs/lakebridge/reconcile/fingerprint/spark_target.py @@ -0,0 +1,279 @@ +"""Spark-side detection (Stage-1) and Stage-2 filter helpers for the target Delta table. + +Stage-1 (DataFrame path) and Stage-2 (SQL filter path) share one column-serialisation +contract here: ``COALESCE(TRIM(CAST(_ AS string)), '')``. Keeping both helpers +in this module prevents the two stages from drifting silently (the row would be flagged +by Stage-1 and then dropped from Stage-2 when the per-row SHA2 inputs disagreed). +""" + +import logging + +from pyspark.sql import DataFrame, SparkSession +from pyspark.sql import functions as F +from pyspark.sql.types import DecimalType, LongType + +from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils +from databricks.labs.lakebridge.reconcile.fingerprint.constants import ( + NULL_SENTINEL, + SEPARATOR_PYTHON, + SEPARATOR_SPARK_SQL, + build_fingerprint_where_clause, +) +from databricks.labs.lakebridge.reconcile.recon_config import Schema + +# rh1/rh2 are 32-bit unsigned values; rh*rh reaches ~2^64 and silently wraps under +# LongType. DecimalType(19, 0) holds 2^32 and Spark's precision rule produces +# DecimalType(38, 0) on the multiply, matching the source-side aggregate type. +_RH_OPERAND_TYPE = DecimalType(19, 0) +_AGG_TYPE = DecimalType(38, 0) + +logger = logging.getLogger(__name__) + + +def compute_target_fingerprint( + spark: SparkSession, + catalog: str | None, + schema: str, + table: str, + columns: list[Schema], + column_mapping: dict[str, str] | None, + sub_bucket_count: int, + bucket_count: int, + treat_empty_as_null: bool = False, +) -> DataFrame: + """Compute Stage-1 sub-bucket aggregates on the target Delta table. + + Mirrors the source-side detection query: MD5 over concatenated columns, dual-slice + rh1/rh2 extraction, GROUP BY into sub-buckets with (cnt, p1, p2, p1_rh2, p2_rh2). + """ + df = spark.table(_table_fqn(catalog, schema, table)) + + concat_col = _build_concat_column(columns, column_mapping, treat_empty_as_null) + md5_col = F.md5(concat_col) + + rh1_col = _hex_substr_to_long(md5_col, 1, 8) + rh2_col = _hex_substr_to_long(md5_col, 9, 8) + + df_hashed = df.select( + F.abs(rh1_col % F.lit(sub_bucket_count)).alias("sub_bucket_id"), + F.abs(rh1_col % F.lit(bucket_count)).alias("bucket_id"), + rh1_col.alias("rh1"), + rh2_col.alias("rh2"), + ) + return df_hashed.groupBy("sub_bucket_id", "bucket_id").agg(*_hash_agg_exprs()) + + +def build_target_filter_subquery( + catalog: str | None, + schema: str, + table: str, + columns: list[Schema], + column_mapping: dict[str, str] | None, + solved_hashes: dict[int, list[int]], + unsolved_sb_ids: list[int], + *, + sub_bucket_count: int, + treat_empty_as_null: bool = False, +) -> str: + """Build the Spark-SQL subquery that filters target rows for Stage-2 surgical fetch. + + Uses the same column serialisation as ``compute_target_fingerprint`` so the rh1 + values match Stage-1's hashing exactly; ``sub_bucket_count`` must equal Stage-1's. + """ + concat_expr = _build_target_concat_sql(columns, column_mapping, treat_empty_as_null) + rh1_expr = f"CAST(CONV(SUBSTR(MD5({concat_expr}), 1, 8), 16, 10) AS BIGINT)" + sb_expr = f"ABS(MOD({rh1_expr}, {sub_bucket_count}))" + where_clause = build_fingerprint_where_clause(sb_expr, rh1_expr, solved_hashes, unsolved_sb_ids) + return f"(SELECT * FROM {_table_fqn(catalog, schema, table)} WHERE {where_clause}) _fp_filtered" + + +def _build_concat_column( + columns: list[Schema], + column_mapping: dict[str, str] | None, + treat_empty_as_null: bool, +) -> F.Column: + """Concatenate serialised target columns into a single Column for MD5 hashing.""" + parts = [ + _serialize_column_spark(_target_col_name(c, column_mapping), c.data_type, treat_empty_as_null) for c in columns + ] + if len(parts) == 1: + return parts[0] + result = parts[0] + for part in parts[1:]: + result = F.concat(result, F.lit(SEPARATOR_PYTHON), part) + return result + + +def _build_target_concat_sql( + columns: list[Schema], + column_mapping: dict[str, str] | None, + treat_empty_as_null: bool, +) -> str: + """SQL-string sibling of ``_build_concat_column`` for the Stage-2 filter subquery.""" + col_exprs = [ + _serialize_column_spark_sql(_target_col_name(c, column_mapping), c.data_type, treat_empty_as_null) + for c in columns + ] + if len(col_exprs) == 1: + return col_exprs[0] + sep_parts: list[str] = [] + for i, expr in enumerate(col_exprs): + if i > 0: + sep_parts.append(SEPARATOR_SPARK_SQL) + sep_parts.append(expr) + return f"CONCAT({', '.join(sep_parts)})" + + +# Spark types whose default ``cast(_ AS string)`` representation drifts from the +# Redshift source-side ``TO_CHAR(...)`` payload. Listing the bare type prefix +# (lowercased) is enough — Spark's ``data_type`` from INFORMATION_SCHEMA already +# strips precision modifiers for these. +# +# We classify timestamps into two families: timezone-aware (LTZ — Spark's default +# ``timestamp``, plus ``timestamp_ltz`` and ``timestamp with time zone``) and +# timezone-naive (``timestamp_ntz`` and ``timestamp without time zone``). +# Stage-1 must produce the same bytes as the Redshift source for a given +# logical row, and Redshift renders ``timestamptz`` ``AT TIME ZONE 'UTC'``. +# A TZ-aware Spark column rendered via ``date_format`` uses the session +# timezone — if a cluster runs with a non-UTC ``spark.sql.session.timeZone`` +# the two sides emit different bytes for the same instant. We therefore +# normalise TZ-aware columns to the UTC wall-clock before formatting. +_SPARK_TIMESTAMP_NTZ_TOKENS = ("timestamp_ntz", "timestamp without time zone") + + +def _classify_timestamp(col_type: str) -> str | None: + """Return ``"ltz"``, ``"ntz"``, or ``None`` for non-timestamp columns.""" + if not col_type.startswith("timestamp"): + return None + if any(col_type.startswith(p) for p in _SPARK_TIMESTAMP_NTZ_TOKENS): + return "ntz" + return "ltz" + + +def _serialize_column_spark(col_name: str, col_type: str, treat_empty_as_null: bool) -> F.Column: + """Stage-1 (DataFrame) per-column serializer. + + Four contracts coexist here: + * ``TRIM`` keeps Stage-1 whitespace-symmetric with Stage-2 (otherwise a + row whose only difference is trailing whitespace surfaces in Stage-1 + and is silently dropped by Stage-2's per-row SHA2). + * Timestamps and dates route through ``DATE_FORMAT`` so the byte stream + matches the row-hash compare path's + ``DATE_FORMAT(_, 'yyyy-MM-dd HH:mm:ss.SSSSSS')`` and the source-side + Redshift ``TO_CHAR(_, 'YYYY-MM-DD HH24:MI:SS.US')``. Default + ``cast(_ AS string)`` produces variable-width fractional seconds. + * TZ-aware (LTZ) timestamps are shifted to the UTC wall-clock via + ``TO_UTC_TIMESTAMP(_, CURRENT_TIMEZONE())`` before formatting so a + cluster running with a non-UTC session timezone still emits bytes + identical to Redshift's ``TO_CHAR(_ AT TIME ZONE 'UTC', _)``. Without + this pin, the same instant would render differently on the two sides + and Stage-1 would over-report mismatches on every TZ-aware column. + * The column reference is built via ``F.expr(_quote_spark_identifier(...))`` + because ``F.col`` interprets ``.`` as a struct path — Delta columns + literally named ``"a.b"`` would otherwise fail to resolve. + """ + col_type_lower = (col_type or "").strip().lower() + spark_col = F.expr(_quote_spark_identifier(col_name)) + ts_kind = _classify_timestamp(col_type_lower) + if ts_kind == "ltz": + ts_in_utc = F.to_utc_timestamp(spark_col, F.current_timezone()) + cast_col = F.trim(F.date_format(ts_in_utc, "yyyy-MM-dd HH:mm:ss.SSSSSS")) + elif ts_kind == "ntz": + cast_col = F.trim(F.date_format(spark_col, "yyyy-MM-dd HH:mm:ss.SSSSSS")) + elif col_type_lower == "date": + cast_col = F.trim(F.date_format(spark_col, "yyyy-MM-dd")) + else: + cast_col = F.trim(spark_col.cast("string")) + if treat_empty_as_null: + return F.coalesce( + F.when(cast_col == F.lit(""), None).otherwise(cast_col), + F.lit(NULL_SENTINEL), + ) + return F.coalesce(cast_col, F.lit(NULL_SENTINEL)) + + +_SPARK_IDENTIFIER_QUOTE = "`" + + +def _quote_spark_identifier(bare: str) -> str: + """Wrap ``bare`` in Spark SQL's backtick identifier delimiters, doubling any embedded + backtick. Defense-in-depth: today's values come from Delta metadata and never carry + a backtick, but ``recon_capture`` scrubs persisted string values for the same reason + — both boundaries should be consistent. + """ + escaped = bare.replace(_SPARK_IDENTIFIER_QUOTE, _SPARK_IDENTIFIER_QUOTE * 2) + return f"{_SPARK_IDENTIFIER_QUOTE}{escaped}{_SPARK_IDENTIFIER_QUOTE}" + + +def _serialize_column_spark_sql(col_name: str, col_type: str, treat_empty_as_null: bool) -> str: + """Stage-2 (SQL string) per-column serializer; must produce hashes identical to ``_serialize_column_spark``.""" + col_type_lower = (col_type or "").strip().lower() + quoted = _quote_spark_identifier(col_name) + ts_kind = _classify_timestamp(col_type_lower) + if ts_kind == "ltz": + cast_expr = f"TRIM(DATE_FORMAT(TO_UTC_TIMESTAMP({quoted}, CURRENT_TIMEZONE()), 'yyyy-MM-dd HH:mm:ss.SSSSSS'))" + elif ts_kind == "ntz": + cast_expr = f"TRIM(DATE_FORMAT({quoted}, 'yyyy-MM-dd HH:mm:ss.SSSSSS'))" + elif col_type_lower == "date": + cast_expr = f"TRIM(DATE_FORMAT({quoted}, 'yyyy-MM-dd'))" + else: + cast_expr = f"TRIM(CAST({quoted} AS STRING))" + if treat_empty_as_null: + return f"COALESCE(NULLIF({cast_expr}, ''), '{NULL_SENTINEL}')" + return f"COALESCE({cast_expr}, '{NULL_SENTINEL}')" + + +def _target_col_name(schema_col: Schema, column_mapping: dict[str, str] | None) -> str: + """Resolve target physical column name; ``Schema.column_name`` arrives ANSI-delimited.""" + bare = DialectUtils.unnormalize_identifier(schema_col.column_name) + if column_mapping: + return column_mapping.get(bare, bare) + return bare + + +def _table_fqn(catalog: str | None, schema: str, table: str) -> str: + """Catalog / schema / table come from ReconcileConfig and are validated + upstream (Unity Catalog enforces a strict identifier grammar), but routing + through ``_quote_spark_identifier`` makes the FQN robust against names + containing ``-``, ``.``, or other characters Spark would otherwise treat + specially. Defensive parity with column-name quoting on this same boundary. + """ + parts = [_quote_spark_identifier(schema), _quote_spark_identifier(table)] + if catalog: + parts.insert(0, _quote_spark_identifier(catalog)) + return ".".join(parts) + + +def _hash_agg_exprs() -> list[F.Column]: + rh1_dec19 = F.col("rh1").cast(_RH_OPERAND_TYPE) + rh2_dec19 = F.col("rh2").cast(_RH_OPERAND_TYPE) + rh1_dec38 = F.col("rh1").cast(_AGG_TYPE) + rh2_dec38 = F.col("rh2").cast(_AGG_TYPE) + return [ + F.count("*").alias("cnt"), + F.sum(rh1_dec38).alias("p1"), + F.sum(rh1_dec19 * rh1_dec19).alias("p2"), + F.sum(rh2_dec38).alias("p1_rh2"), + F.sum(rh2_dec19 * rh2_dec19).alias("p2_rh2"), + ] + + +def _hex_substr_to_long(md5_col: F.Column, start: int, length: int) -> F.Column: + """Convert an MD5 hex substring to a long via Photon-safe ASCII arithmetic. + + F.conv() forces ColumnarToRow JVM fallback on Photon; substring/ascii/when-otherwise/ + multiply/add are all Photon-native. md5() returns lowercase hex so no upper() needed. + """ + hex_slice = F.substring(md5_col, start, length) + ascii_0 = F.ascii(F.lit("0")) + ascii_a = F.ascii(F.lit("a")) + result = F.lit(0).cast(LongType()) + for i in range(length): + char_col = F.substring(hex_slice, i + 1, 1) + digit = ( + F.when(F.ascii(char_col) >= ascii_a, F.ascii(char_col) - ascii_a + F.lit(10)) + .otherwise(F.ascii(char_col) - ascii_0) + .cast(LongType()) + ) + result = result * F.lit(16).cast(LongType()) + digit + return result diff --git a/src/databricks/labs/lakebridge/reconcile/query_builder/expression_generator.py b/src/databricks/labs/lakebridge/reconcile/query_builder/expression_generator.py index fd8e9bd070..a3cec01570 100644 --- a/src/databricks/labs/lakebridge/reconcile/query_builder/expression_generator.py +++ b/src/databricks/labs/lakebridge/reconcile/query_builder/expression_generator.py @@ -262,6 +262,22 @@ def _get_is_string(column_types_dict: dict[str, DataType], column_name: str) -> exp.DataType.Type.ARRAY.value: [ partial(anonymous, func="CONCAT_WS(',', SORT_ARRAY({}))", dialect=get_dialect("databricks")) ], + # Align with Redshift's ``TO_CHAR(ts, 'YYYY-MM-DD HH24:MI:SS.US')`` so + # the per-row SHA2 inputs are byte-identical for Redshift -> Databricks reconciles. + exp.DataType.Type.TIMESTAMP.value: [ + partial( + anonymous, + func="COALESCE(DATE_FORMAT({}, 'yyyy-MM-dd HH:mm:ss.SSSSSS'), '_null_recon_')", + dialect=get_dialect("databricks"), + ) + ], + exp.DataType.Type.TIMESTAMPTZ.value: [ + partial( + anonymous, + func="COALESCE(DATE_FORMAT({}, 'yyyy-MM-dd HH:mm:ss.SSSSSS'), '_null_recon_')", + dialect=get_dialect("databricks"), + ) + ], }, "tsql": { "default": [partial(anonymous, func="COALESCE(TRIM(CAST({} AS VARCHAR(MAX))), '_null_recon_')")], @@ -298,6 +314,20 @@ def _get_is_string(column_types_dict: dict[str, DataType], column_name: str) -> dialect=get_dialect("redshift"), ) ], + # Redshift rejects every form of CAST(boolean AS VARCHAR/TEXT) and the + # universal default applies TRIM to the column, which yields + # ``btrim(boolean)`` and the function-not-found error customers see. + # CASE WHEN produces the same lowercase 'true'/'false' that + # Spark's cast(boolean AS string) emits, keeping source and target + # row hashes byte-identical. Mirrors the boolean handler in + # ``fingerprint/query_builders/redshift.py``. + exp.DataType.Type.BOOLEAN.value: [ + partial( + anonymous, + func="COALESCE(CASE WHEN {0} THEN 'true' WHEN NOT {0} THEN 'false' ELSE NULL END, '_null_recon_')", + dialect=get_dialect("redshift"), + ) + ], }, } diff --git a/src/databricks/labs/lakebridge/reconcile/query_builder/hash_query.py b/src/databricks/labs/lakebridge/reconcile/query_builder/hash_query.py index 56828b3741..e7b0ff63f4 100644 --- a/src/databricks/labs/lakebridge/reconcile/query_builder/hash_query.py +++ b/src/databricks/labs/lakebridge/reconcile/query_builder/hash_query.py @@ -30,7 +30,16 @@ def _hash_transform( class HashQueryBuilder(QueryBuilder): - def build_query(self, report_type: str) -> str: + def build_query(self, report_type: str, *, project_all_columns: bool = False) -> str: + """Build the hash query for the configured layer. + + ``project_all_columns`` (keyword-only): when True, the projection includes + every hashed column (not just join + partition keys). Fingerprint Stage-2 + surgical fetch needs this so the compare layer can populate + ``mismatch_columns`` without a second round-trip. Source and target sides + MUST be invoked with the same value or ``capture_mismatch_data_and_columns`` + raises on diverging column sets. + """ if report_type != 'row': self._validate(self.join_columns, f"Join Columns are compulsory for {report_type} type") @@ -39,6 +48,11 @@ def build_query(self, report_type: str) -> str: hash_cols = sorted((_join_columns | self.select_columns) - self.threshold_columns - self.drop_columns) key_cols = hash_cols if report_type == "row" else sorted(_join_columns | self.partition_column) + if project_all_columns and report_type != "row": + # Union with hash_cols (already a sorted superset of join columns) + # so we can keep the deterministic projection order while still + # widening the SELECT list to every hashed column. + key_cols = sorted(set(key_cols) | set(hash_cols)) cols_with_alias = [self._build_column_with_alias(col) for col in key_cols] diff --git a/src/databricks/labs/lakebridge/reconcile/recon_capture.py b/src/databricks/labs/lakebridge/reconcile/recon_capture.py index 165de2deb2..8a76688c6f 100644 --- a/src/databricks/labs/lakebridge/reconcile/recon_capture.py +++ b/src/databricks/labs/lakebridge/reconcile/recon_capture.py @@ -12,6 +12,7 @@ from sqlglot import Dialect from databricks.labs.lakebridge.config import DatabaseConfig, Table, ReconcileMetadataConfig +from databricks.labs.lakebridge.reconcile.fingerprint.metadata import FingerprintRunMetadata from databricks.labs.lakebridge.reconcile.recon_config import TableThresholds from databricks.labs.lakebridge.transpiler.sqlglot.dialect_utils import get_key_from_dialect from databricks.labs.lakebridge.reconcile.exception import ( @@ -40,6 +41,71 @@ _RECON_AGGREGATE_DETAILS_TABLE_NAME = "aggregate_details" +# Single source of truth for the persisted ``fingerprint_metrics`` named_struct. +# Tuple of (sql_field_name, dataclass_attribute, sql_type). +# +# Field ORDER must match ``FingerprintRunMetadata`` declaration order — Delta +# resolves struct fields positionally on saveAsTable, so reordering here would +# silently corrupt every recon_metrics row written against existing customer +# tables. The unit suite guards order. +# +# Allowed sql_type values: +# - "bool" -> ``true``/``false`` literal +# - "bigint" -> ``cast(N as bigint)`` literal +# - "bigint_or_null" -> ``cast(N as bigint)`` or SQL ``NULL`` +# - "string_or_null" -> ``'value'`` (quote-scrubbed) or SQL ``NULL`` +_FP_METRICS_STRUCT_FIELDS: tuple[tuple[str, str, str], ...] = ( + ("eligible", "eligible", "bool"), + ("ineligibility_reason", "ineligibility_reason", "string_or_null"), + ("verdict", "verdict", "string_or_null"), + ("elapsed_ms", "elapsed_ms", "bigint"), + ("solved_count", "solved_count", "bigint"), + ("unsolved_sb_count", "unsolved_sb_count", "bigint"), + ("total_mismatched_sbs", "total_mismatched_sbs", "bigint"), + ("fallback_to_full_pipeline", "fallback_to_full_pipeline", "bool"), + ("sub_bucket_count", "sub_bucket_count", "bigint"), + ("bucket_count", "bucket_count", "bigint"), + ("target_row_count", "target_row_count", "bigint_or_null"), + ("row_count_source", "row_count_source", "string_or_null"), + ("fetch_path", "fetch_path", "string_or_null"), +) + + +def _render_fp_metrics_value(value: object, sql_type: str) -> str: + """Render a Python value to its SQL-literal form per the declared sql_type. + + Centralised so values cannot reach the persisted SQL fragment without + flowing through type-aware rendering. An unknown ``sql_type`` raises + rather than silently falling through to ``str(value)`` — adding a new + field type is a deliberate change in this function, not an accident in + a caller. + """ + if sql_type == "bool": + return str(bool(value)).lower() + if sql_type == "bigint": + # Dataclass typing pins this to ``int``; assertion is a true invariant + # and also narrows ``value`` from ``object`` for mypy. + assert isinstance(value, int), f"bigint field expected int, got {type(value).__name__}" + return f"cast({value} as bigint)" + if sql_type == "bigint_or_null": + if value is None: + return "NULL" + assert isinstance(value, int), f"bigint_or_null field expected int|None, got {type(value).__name__}" + return f"cast({value} as bigint)" + if sql_type == "string_or_null": + if value is None: + return "NULL" + # Defense-in-depth: scrub embedded single/double quotes that would + # terminate the SQL literal. Metadata values come from controlled + # paths today; this guards a future field that carries user input. + scrubbed = str(value).replace("'", "").replace('"', "") + return f"'{scrubbed}'" + raise ValueError( + f"Unsupported sql_type for fingerprint_metrics struct: {sql_type!r}. " + "Allowed: 'bool', 'bigint', 'bigint_or_null', 'string_or_null'." + ) + + class AbstractReconIntermediatePersist: @property def base_dir(self) -> Path: @@ -115,9 +181,19 @@ def write_and_read_df_with_volumes( raise ReadAndWriteWithVolumeException(message) from e -def _write_df_to_delta(df: DataFrame, table_name: str, mode="append"): +def _write_df_to_delta(df: DataFrame, table_name: str, mode="append", *, merge_schema: bool = False): + """Append to a Delta table; ``merge_schema=True`` enables additive column evolution. + + The fingerprint precheck adds a ``fingerprint_metrics`` struct to the + ``recon_metrics`` row; on the first write against an existing customer + table this column has to materialise without an explicit ``ALTER TABLE``, + so callers writing that table pass ``merge_schema=True``. + """ try: - df.write.mode(mode).saveAsTable(table_name) + writer = df.write.mode(mode) + if merge_schema: + writer = writer.option("mergeSchema", "true") + writer.saveAsTable(table_name) logger.info(f"Data written to {table_name} successfully.") except Exception as e: message = f"Error writing data to {table_name}: {e}" @@ -131,7 +207,8 @@ def generate_final_reconcile_output( metadata_config: ReconcileMetadataConfig = ReconcileMetadataConfig(), ) -> ReconcileOutput: _db_prefix = f"{metadata_config.catalog}.{metadata_config.schema}" - recon_df = spark.sql(f""" + recon_df = spark.sql( + f""" SELECT CASE WHEN COALESCE(MAIN.SOURCE_TABLE.CATALOG, '') <> '' THEN CONCAT(MAIN.SOURCE_TABLE.CATALOG, '.', MAIN.SOURCE_TABLE.SCHEMA, '.', MAIN.SOURCE_TABLE.TABLE_NAME) @@ -166,7 +243,8 @@ def generate_final_reconcile_output( (MAIN.recon_table_id = METRICS.recon_table_id) WHERE MAIN.recon_id = '{recon_id}' - """) + """ + ) table_output = [] for row in recon_df.collect(): if row.EXCEPTION_MESSAGE is not None and row.EXCEPTION_MESSAGE != "": @@ -198,7 +276,8 @@ def generate_final_reconcile_aggregate_output( metadata_config: ReconcileMetadataConfig = ReconcileMetadataConfig(), ) -> ReconcileOutput: _db_prefix = f"{metadata_config.catalog}.{metadata_config.schema}" - recon_df = spark.sql(f""" + recon_df = spark.sql( + f""" SELECT source_table, target_table, EVERY(status) AS status, @@ -222,7 +301,8 @@ def generate_final_reconcile_aggregate_output( MAIN.recon_id = '{recon_id}' ) GROUP BY source_table, target_table; - """) + """ + ) table_output = [] for row in recon_df.collect(): if row.exception_message is not None and row.exception_message != "": @@ -290,7 +370,8 @@ def _insert_into_main_table( operation_name: str = "reconcile", ) -> None: source_dialect_key = get_key_from_dialect(self.source_dialect) - df = self.spark.sql(f""" + df = self.spark.sql( + f""" select {recon_table_id} as recon_table_id, '{self.recon_id}' as recon_id, case @@ -313,7 +394,8 @@ def _insert_into_main_table( '{operation_name}' as operation_name, cast('{recon_process_duration.start_ts}' as timestamp) as start_ts, cast('{recon_process_duration.end_ts}' as timestamp) as end_ts - """) + """ + ) _write_df_to_delta(df, f"{self._db_prefix}.{_RECON_TABLE_NAME}") @classmethod @@ -350,6 +432,32 @@ def _is_mismatch_within_threshold_limits( return res + @staticmethod + def _fingerprint_metrics_struct_sql(metadata: FingerprintRunMetadata) -> str: + """Render the ``fingerprint_metrics`` named_struct for the metrics table. + + The per-field rendering is driven by ``_FP_METRICS_STRUCT_FIELDS``; + adding a metadata field is one entry in that tuple — no untyped + f-string append. Every value flows through ``_render_fp_metrics_value`` + which checks the declared SQL-type at the boundary, so raw values + never reach the SQL string without type-aware rendering. + + Output contract: + - ``mergeSchema`` evolves the column to a concrete StructType on + first write (Delta can't infer fields from an all-NULL struct). + - String fields scrubbed of embedded quotes (defense-in-depth). + - ``None`` emits SQL ``NULL`` (not the string ``'None'``) so + dashboards filtering on ``IS NULL`` don't miss rows. + - Field ORDER must match the dataclass declaration; ``saveAsTable`` + resolves struct fields positionally. + """ + parts: list[str] = [] + for sql_field, attr, sql_type in _FP_METRICS_STRUCT_FIELDS: + value = getattr(metadata, attr) + rendered = _render_fp_metrics_value(value, sql_type) + parts.append(f"'{sql_field}', {rendered}") + return f"named_struct({', '.join(parts)})" + def _insert_into_metrics_table( self, recon_table_id: int, @@ -357,6 +465,7 @@ def _insert_into_metrics_table( schema_reconcile_output: SchemaReconcileOutput, table_conf: Table, record_count: ReconcileRecordCount, + fingerprint_metadata: FingerprintRunMetadata | None = None, ) -> None: status = False if data_reconcile_output.exception in {None, ''} and schema_reconcile_output.exception in {None, ''}: @@ -381,7 +490,15 @@ def _insert_into_metrics_table( if data_reconcile_output.mismatch and data_reconcile_output.mismatch.mismatch_columns: mismatch_columns = data_reconcile_output.mismatch.mismatch_columns - df = self.spark.sql(f""" + # Sources that don't go through the fingerprint precheck (e.g. Snowflake, + # Oracle today, or any aggregate-mode reconcile) don't pass metadata. + # Use the populated "feature off" struct so dashboards can group by + # ``eligible`` without NULL-struct handling. + fp_metadata = fingerprint_metadata if fingerprint_metadata is not None else FingerprintRunMetadata.disabled() + fingerprint_struct_sql = self._fingerprint_metrics_struct_sql(fp_metadata) + + df = self.spark.sql( + f""" select {recon_table_id} as recon_table_id, named_struct( 'source_record_count', cast({record_count.source} as bigint), @@ -401,7 +518,8 @@ def _insert_into_metrics_table( ) else null end, 'schema_comparison', case when '{self.report_type.lower()}' in ('all', 'schema') and '{exception_msg}' = '' then - {schema_reconcile_output.is_valid} else null end + {schema_reconcile_output.is_valid} else null end, + 'fingerprint_metrics', {fingerprint_struct_sql} ) as recon_metrics, named_struct( 'status', {status}, @@ -409,8 +527,12 @@ def _insert_into_metrics_table( 'exception_message', "{exception_msg}" ) as run_metrics, cast('{insertion_time}' as timestamp) as inserted_ts - """) - _write_df_to_delta(df, f"{self._db_prefix}.{_RECON_METRICS_TABLE_NAME}") + """ + ) + # mergeSchema=True so the additive ``fingerprint_metrics`` field + # evolves on first write against pre-existing customer tables without + # a manual ALTER TABLE. + _write_df_to_delta(df, f"{self._db_prefix}.{_RECON_METRICS_TABLE_NAME}", merge_schema=True) @classmethod def _create_map_column( @@ -554,7 +676,8 @@ def _insert_aggregates_into_metrics_table( assert agg_output.rule, "Aggregate Rule must be present for storing the metrics" rule_id = hash(f"{recon_table_id}_{agg_output.rule.column_from_rule}") - agg_metrics_df = self.spark.sql(f""" + agg_metrics_df = self.spark.sql( + f""" select {recon_table_id} as recon_table_id, {rule_id} as rule_id, if('{exception_msg}' = '', named_struct( @@ -568,7 +691,8 @@ def _insert_aggregates_into_metrics_table( 'exception_message', "{exception_msg}" ) as run_metrics, cast('{insertion_time}' as timestamp) as inserted_ts - """) + """ + ) agg_metrics_df_list.append(agg_metrics_df) agg_metrics_table_df = self._union_dataframes(agg_metrics_df_list) @@ -621,11 +745,17 @@ def start( table_conf: Table, recon_process_duration: ReconcileProcessDuration, record_count: ReconcileRecordCount, + fingerprint_metadata: FingerprintRunMetadata | None = None, ) -> None: recon_table_id = self._generate_recon_main_id(table_conf) self._insert_into_main_table(recon_table_id, table_conf, recon_process_duration) self._insert_into_metrics_table( - recon_table_id, data_reconcile_output, schema_reconcile_output, table_conf, record_count + recon_table_id, + data_reconcile_output, + schema_reconcile_output, + table_conf, + record_count, + fingerprint_metadata=fingerprint_metadata, ) self._insert_into_details_table(recon_table_id, data_reconcile_output, schema_reconcile_output) diff --git a/src/databricks/labs/lakebridge/reconcile/reconciliation.py b/src/databricks/labs/lakebridge/reconcile/reconciliation.py index 150fa61b25..37ac2135fe 100644 --- a/src/databricks/labs/lakebridge/reconcile/reconciliation.py +++ b/src/databricks/labs/lakebridge/reconcile/reconciliation.py @@ -82,6 +82,14 @@ def source(self) -> DataSource: def target(self) -> DataSource: return self._target + @property + def spark(self) -> SparkSession: + return self._spark + + @property + def source_engine(self) -> Dialect: + return self._source_engine + @property def report_type(self) -> str: return self._report_type diff --git a/src/databricks/labs/lakebridge/reconcile/trigger_recon_service.py b/src/databricks/labs/lakebridge/reconcile/trigger_recon_service.py index 440b70fef9..d3ac7486fc 100644 --- a/src/databricks/labs/lakebridge/reconcile/trigger_recon_service.py +++ b/src/databricks/labs/lakebridge/reconcile/trigger_recon_service.py @@ -1,3 +1,13 @@ +"""Reconcile orchestration. + +Row-level compare path: schema compare, then — when ``fingerprint_precheck`` is +enabled and ``source.dialect`` has a registered ``FingerprintQueryBuilder`` +(today: Redshift) — try fingerprint (MD5 buckets) first. MATCH returns a +synthetic match without hash + JOIN; MISMATCH builds the output from the +already-fetched filtered rows; failure or unsupported sources fall through to +``HashQueryBuilder`` + ``reconciler.reconcile_data``. +""" + import logging from datetime import datetime, timezone from uuid import uuid4 @@ -11,6 +21,22 @@ from databricks.labs.lakebridge.reconcile import utils from databricks.labs.lakebridge.reconcile.connectors.data_source import DataSource from databricks.labs.lakebridge.reconcile.exception import DataSourceRuntimeException, ReconciliationException +from databricks.labs.lakebridge.reconcile.fingerprint.exceptions import ( + FingerprintError, + UnmappedTargetColumnMappingError, +) +from databricks.labs.lakebridge.reconcile.fingerprint.metadata import ( + FingerprintRunMetadata, + INELIGIBLE_UNMAPPED_TARGET_COLUMN_MAPPING, +) +from databricks.labs.lakebridge.reconcile.fingerprint.orchestrator import ( + FingerprintResult, + build_mismatch_output, + classify_ineligibility, + fingerprint_match_output, + resolve_compare_key_columns, + run_fingerprint_precheck, +) from databricks.labs.lakebridge.reconcile.recon_capture import ( ReconCapture, generate_final_reconcile_output, @@ -34,6 +60,26 @@ _RECON_REPORT_TYPES = {"schema", "data", "row", "all", "aggregate"} +def _try_unpersist(df) -> None: + """Best-effort ``unpersist`` for cached fingerprint frames on a fallback + exit. The DataFrames may have been ``persist()``-ed by + ``compute_target_fingerprint`` / source-side fetch and we don't want them + to linger in executor storage for the rest of the recon. + + Swallows any exception because unpersist is purely a release path — failing + here on a frame that was never cached, or whose plan is in a partial state, + must not mask the original error that triggered the fallback. + """ + if df is None: + return + try: + df.unpersist(blocking=False) + except Exception: # pylint: disable=broad-except + # Intentional: ``unpersist`` on a non-cached frame raises in some Spark + # versions and we explicitly do not want to surface that on the fallback. + logger.debug("Best-effort unpersist on fingerprint fallback DataFrame failed; ignoring.", exc_info=True) + + class TriggerReconService: @staticmethod @@ -75,6 +121,25 @@ def create_recon_dependencies( logger.info(f"report_type: {report_type}, data_source: {source_dialect} ") utils.validate_input(report_type, _RECON_REPORT_TYPES, "Invalid report type") + # Warn on silently-ignored knob combinations — these flags have zero + # effect when ``fingerprint_precheck`` itself is off, so a user who has + # flipped only a secondary knob is asking for a behaviour change they + # will not get. Surfacing this once per recon is enough to catch the + # typo without spamming per-table. + if reconcile_config.fingerprint_treat_empty_as_null and not reconcile_config.fingerprint_precheck: + logger.warning( + "ReconcileConfig.fingerprint_treat_empty_as_null is True but " + "fingerprint_precheck is False; the empty-as-null behaviour " + "applies only to the fingerprint hash path, so this knob is " + "ignored. Enable fingerprint_precheck to take effect." + ) + if reconcile_config.fingerprint_row_count_override is not None and not reconcile_config.fingerprint_precheck: + logger.warning( + "ReconcileConfig.fingerprint_row_count_override is set but " + "fingerprint_precheck is False; the override drives only the " + "fingerprint sub-bucket tier, so this knob is ignored." + ) + # validate the connection source, target = utils.initialise_data_source( source_dialect=reconcile_config.source.dialect, @@ -119,8 +184,13 @@ def recon_one( reconciler.source, reconciler.target ).normalize_recon_table_config(table_conf) - schema_reconcile_output, data_reconcile_output, recon_process_duration = TriggerReconService._do_recon_one( - reconciler, reconcile_config, normalized_table_conf + ( + schema_reconcile_output, + data_reconcile_output, + recon_process_duration, + fingerprint_metadata, + ) = TriggerReconService._do_recon_one( + reconciler, reconcile_config, normalized_table_conf, recon_id=recon_capture.recon_id ) recon_capture.start( @@ -129,22 +199,47 @@ def recon_one( table_conf=table_conf, recon_process_duration=recon_process_duration, record_count=reconciler.get_record_count(table_conf, reconciler.report_type), + fingerprint_metadata=fingerprint_metadata, ) return schema_reconcile_output, data_reconcile_output @staticmethod - def _do_recon_one(reconciler: Reconciliation, reconcile_config: ReconcileConfig, table_conf: Table): + def _do_recon_one( + reconciler: Reconciliation, + reconcile_config: ReconcileConfig, + table_conf: Table, + *, + recon_id: str | None = None, + ): recon_process_duration = ReconcileProcessDuration(start_ts=str(datetime.now(tz=timezone.utc)), end_ts=None) schema_reconcile_output = SchemaReconcileOutput(is_valid=True) data_reconcile_output = DataReconcileOutput() + # Compute ineligibility once so metadata is populated correctly + # regardless of which code path exits first. The data-path block + # below overwrites the eligible default with the actual verdict. + ineligibility_reason = classify_ineligibility( + flag_enabled=reconcile_config.fingerprint_precheck, + data_source=reconcile_config.source.dialect, + report_type=reconciler.report_type, + table_conf=table_conf, + ) + if ineligibility_reason is not None: + fingerprint_metadata: FingerprintRunMetadata = FingerprintRunMetadata.ineligible(ineligibility_reason) + else: + fingerprint_metadata = FingerprintRunMetadata(eligible=True) + try: src_schema, tgt_schema = TriggerReconService.get_schemas( reconciler.source, reconciler.target, table_conf, reconcile_config.database_config, True ) except DataSourceRuntimeException as e: schema_reconcile_output = SchemaReconcileOutput(is_valid=False, exception=str(e)) + if ineligibility_reason is None: + fingerprint_metadata = FingerprintRunMetadata( + eligible=True, fallback_to_full_pipeline=True, verdict="FAILED" + ) else: if reconciler.report_type in {"schema", "all"}: schema_reconcile_output = TriggerReconService._run_reconcile_schema( @@ -156,16 +251,19 @@ def _do_recon_one(reconciler: Reconciliation, reconcile_config: ReconcileConfig, logger.info("Schema comparison is completed.") if reconciler.report_type in {"data", "row", "all"}: - data_reconcile_output = TriggerReconService._run_reconcile_data( + data_reconcile_output, fingerprint_metadata = TriggerReconService._run_fingerprint_or_reconcile_data( reconciler=reconciler, + reconcile_config=reconcile_config, table_conf=table_conf, src_schema=src_schema, tgt_schema=tgt_schema, + ineligibility_reason=ineligibility_reason, + recon_id=recon_id, ) logger.info(f"Reconciliation for '{reconciler.report_type}' report completed.") recon_process_duration.end_ts = str(datetime.now(tz=timezone.utc)) - return schema_reconcile_output, data_reconcile_output, recon_process_duration + return schema_reconcile_output, data_reconcile_output, recon_process_duration, fingerprint_metadata @staticmethod def get_schemas( @@ -215,6 +313,208 @@ def _run_reconcile_data( except DataSourceRuntimeException as e: return DataReconcileOutput(exception=str(e)) + @staticmethod + def _invoke_precheck( + *, + reconciler: Reconciliation, + reconcile_config: ReconcileConfig, + table_conf: Table, + src_schema: list[Schema], + tgt_schema: list[Schema], + recon_id: str | None, + ) -> tuple[FingerprintResult | None, str | None, bool]: + """Run ``run_fingerprint_precheck`` and classify the outcome for the caller. + + Returns ``(fp_result, runtime_ineligibility, precheck_failed)`` where: + + * ``fp_result`` is the ``FingerprintResult`` (or ``None`` if the precheck + declined / raised), + * ``runtime_ineligibility`` is an ``IneligibilityReason`` value when the + precheck was rejected for a *config-time* reason discovered at runtime + (today: ``UnmappedTargetColumnMappingError``). The caller routes this + through ``FingerprintRunMetadata.ineligible(...)`` so adoption queries + on ``recon_metrics.fingerprint_metrics.ineligibility_reason`` see the + typed value instead of a silent ``None``. + * ``precheck_failed`` is ``True`` when the precheck raised a runtime + fault (``FingerprintError`` / ``DataSourceRuntimeException`` / + ``PySparkException``). The caller maps that to ``verdict="FAILED"`` so + dashboards can quantify precheck reliability. + + Extracting this keeps the parent method's branching surface small + enough for the project's McCabe budget while preserving every catch. + """ + try: + fp_result = run_fingerprint_precheck( + source=reconciler.source, + target=reconciler.target, + spark=reconciler.spark, + source_engine=reconciler.source_engine, + database_config=reconcile_config.database_config, + table_conf=table_conf, + src_schema=src_schema, + tgt_schema=tgt_schema, + report_type=reconciler.report_type, + data_source=reconcile_config.source.dialect, + treat_empty_as_null=reconcile_config.fingerprint_treat_empty_as_null, + target_row_count_override=reconcile_config.fingerprint_row_count_override, + recon_id=recon_id, + ) + except UnmappedTargetColumnMappingError as e: + # Caught BEFORE the generic ``FingerprintError`` branch because this + # is a config-time ineligibility (a column_mapping target that + # doesn't exist on the target), not a runtime failure of the + # precheck. Surfacing it as a typed ``ineligibility_reason`` keeps + # adoption queries honest. + logger.warning(f"Fingerprint precheck ineligible — {e}; falling back to full pipeline.") + return None, INELIGIBLE_UNMAPPED_TARGET_COLUMN_MAPPING, False + except (FingerprintError, DataSourceRuntimeException, PySparkException) as e: + # Three failure modes meet here: + # * ``FingerprintError`` — logical errors raised by the precheck. + # * ``DataSourceRuntimeException`` — wrapped JDBC failures from + # the connector layer during detection or fetch. + # * ``PySparkException`` — bare Spark errors that the connector + # wrap doesn't see, e.g. ``AnalysisException`` raised at action + # time when ``compute_target_fingerprint`` materialises a plan + # that references a missing target column. Without this catch + # the recon would crash mid-pipeline instead of falling back; + # the precheck must be opt-in safe. + # All three collapse to "fallback to full pipeline with verdict=FAILED" + # so dashboards can quantify precheck reliability without distinguishing + # the cause — the underlying error is logged here and re-discoverable + # from cluster logs if needed. + logger.warning(f"Fingerprint precheck failed ({e}); falling back to full pipeline.") + return None, None, True + return fp_result, None, False + + @staticmethod + def _run_fingerprint_or_reconcile_data( + reconciler: Reconciliation, + reconcile_config: ReconcileConfig, + table_conf: Table, + src_schema: list[Schema], + tgt_schema: list[Schema], + ineligibility_reason: str | None = None, + recon_id: str | None = None, + ) -> tuple[DataReconcileOutput, FingerprintRunMetadata]: + """Try the fingerprint precheck; on any non-MATCH outcome, fall back to the + full hash-and-join reconcile path. + + Returns ``(data_reconcile_output, fingerprint_metadata)``. The metadata + records the verdict regardless of which path produced the output, so + the persisted ``recon_metrics.fingerprint_metrics`` struct always + reflects what actually happened. + + ``ineligibility_reason`` may be supplied by the caller (``_do_recon_one`` + pre-computes it once so the schema-failure path can pre-populate the + metadata) but is computed lazily here when omitted, so this helper is + usable as a standalone unit-test boundary. + """ + if ineligibility_reason is None: + ineligibility_reason = classify_ineligibility( + flag_enabled=reconcile_config.fingerprint_precheck, + data_source=reconcile_config.source.dialect, + report_type=reconciler.report_type, + table_conf=table_conf, + ) + + if ineligibility_reason is not None: + data_reconcile_output = TriggerReconService._run_reconcile_data( + reconciler=reconciler, + table_conf=table_conf, + src_schema=src_schema, + tgt_schema=tgt_schema, + ) + return data_reconcile_output, FingerprintRunMetadata.ineligible(ineligibility_reason) + + fp_result, runtime_ineligibility, precheck_failed = TriggerReconService._invoke_precheck( + reconciler=reconciler, + reconcile_config=reconcile_config, + table_conf=table_conf, + src_schema=src_schema, + tgt_schema=tgt_schema, + recon_id=recon_id, + ) + + if runtime_ineligibility is not None: + data_reconcile_output = TriggerReconService._run_reconcile_data( + reconciler=reconciler, + table_conf=table_conf, + src_schema=src_schema, + tgt_schema=tgt_schema, + ) + return data_reconcile_output, FingerprintRunMetadata.ineligible(runtime_ineligibility) + + if fp_result is None: + # ``None`` covers two cases: + # - the precheck raised (precheck_failed=True) → verdict="FAILED" + # - the precheck declined (column-resolution skip, systemic + # mismatch, no solved buckets) → verdict left unset + # In both cases the full pipeline produces the answer and the + # metadata records a fallback. + data_reconcile_output = TriggerReconService._run_reconcile_data( + reconciler=reconciler, + table_conf=table_conf, + src_schema=src_schema, + tgt_schema=tgt_schema, + ) + return data_reconcile_output, FingerprintRunMetadata.fallback( + verdict="FAILED" if precheck_failed else None, + ) + + if fp_result.verdict == "MATCH": + return fingerprint_match_output(), FingerprintRunMetadata.from_result(fp_result, verdict="MATCH") + + # MISMATCH: the precheck has fetched the differing rows. If the rows + # are missing (e.g. an upstream codepath returned ``MISMATCH`` without + # populating both row sets), we cannot build the output here and must + # fall back to the full pipeline — preserving solver counters so the + # dashboard still shows what the precheck observed. Release any cached + # frames the precheck may have left behind so the executor's storage + # layer doesn't carry the dead plan through the rest of the recon. + if fp_result.source_rows is None or fp_result.target_rows is None: + _try_unpersist(fp_result.source_rows) + _try_unpersist(fp_result.target_rows) + data_reconcile_output = TriggerReconService._run_reconcile_data( + reconciler=reconciler, + table_conf=table_conf, + src_schema=src_schema, + tgt_schema=tgt_schema, + ) + return data_reconcile_output, FingerprintRunMetadata.from_result( + fp_result, verdict="MISMATCH", fallback_to_full_pipeline=True + ) + + try: + data_reconcile_output = build_mismatch_output( + src_hashed=fp_result.source_rows, + tgt_hashed=fp_result.target_rows, + key_columns=resolve_compare_key_columns(table_conf), + report_type=reconciler.report_type, + persistence=reconciler.intermediate_persist, + ) + except (DataSourceRuntimeException, PySparkException) as e: + # ``build_mismatch_output`` runs Spark actions on the prefetched src/tgt + # frames; an analysis or runtime failure here must not crash the recon. + # Mirror the fail-open pattern used by every other non-MATCH branch in + # this method: fall through to the standard full pipeline so the table + # still gets a real recon answer, and record on the metadata that the + # precheck-built output was rejected. Release the cached frames first + # so a partial materialisation does not linger in executor storage for + # the full recon lifetime. + _try_unpersist(fp_result.source_rows) + _try_unpersist(fp_result.target_rows) + logger.warning(f"Fingerprint mismatch-output build failed ({e}); falling back to full pipeline.") + data_reconcile_output = TriggerReconService._run_reconcile_data( + reconciler=reconciler, + table_conf=table_conf, + src_schema=src_schema, + tgt_schema=tgt_schema, + ) + return data_reconcile_output, FingerprintRunMetadata.from_result( + fp_result, verdict="MISMATCH", fallback_to_full_pipeline=True + ) + return data_reconcile_output, FingerprintRunMetadata.from_result(fp_result, verdict="MISMATCH") + @staticmethod def verify_successful_reconciliation(reconcile_output: ReconcileOutput, report_type: str) -> ReconcileOutput: def is_table_recon_mismatch(table_output: ReconcileTableOutput): diff --git a/src/databricks/labs/lakebridge/reconcile/utils.py b/src/databricks/labs/lakebridge/reconcile/utils.py index f6cf0274a3..38d22e79f1 100644 --- a/src/databricks/labs/lakebridge/reconcile/utils.py +++ b/src/databricks/labs/lakebridge/reconcile/utils.py @@ -9,11 +9,7 @@ logger = logging.getLogger(__name__) -def initialise_data_source( - spark: SparkSession, - source_dialect: str, - connection_name: str | None, -): +def initialise_data_source(spark: SparkSession, source_dialect: str, connection_name: str | None): if not connection_name: validate_input(source_dialect, {"databricks"}, "Please configure connection name") source = create_adapter(engine=get_dialect("databricks"), spark=spark, connection_name="databricks") diff --git a/tests/unit/reconcile/fingerprint/__init__.py b/tests/unit/reconcile/fingerprint/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/reconcile/fingerprint/_fixtures.py b/tests/unit/reconcile/fingerprint/_fixtures.py new file mode 100644 index 0000000000..d88a6f87e8 --- /dev/null +++ b/tests/unit/reconcile/fingerprint/_fixtures.py @@ -0,0 +1,120 @@ +"""Shared test fixtures for fingerprint unit tests. + +Helpers are module-level factories (not pytest fixtures) because the original +call sites construct multiple variants per test and need keyword overrides. +Centralising them here removes ~120 LOC of duplication and lets pylint's +``similarities`` checker land at 10/10 across the test tree. +""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +from databricks.labs.lakebridge.config import DatabaseConfig +from databricks.labs.lakebridge.reconcile.fingerprint.orchestrator import ( + _FetchContext, + _TierSelection, +) +from databricks.labs.lakebridge.reconcile.fingerprint.row_count import RowCountSource +from databricks.labs.lakebridge.reconcile.recon_config import Schema, Table + + +def make_database_config( + *, + source_catalog: str = "source_catalog", + source_schema: str = "public", + target_catalog: str | None = "test_catalog", + target_schema: str = "perf_test", +) -> DatabaseConfig: + return DatabaseConfig( + source_catalog=source_catalog, + source_schema=source_schema, + target_catalog=target_catalog, + target_schema=target_schema, + ) + + +def make_table_conf( + *, + source_name: str = "orders", + target_name: str = "orders", + join_columns: list[str] | None = None, + jdbc_reader_options=None, +) -> Table: + return Table( + source_name=source_name, + target_name=target_name, + join_columns=list(join_columns) if join_columns is not None else ["order_id"], + jdbc_reader_options=jdbc_reader_options, + ) + + +def make_schema() -> list[Schema]: + return [ + Schema('"order_id"', "bigint", "`order_id`", '"order_id"'), + Schema('"order_amount"', "numeric(10,2)", "`order_amount`", '"order_amount"'), + ] + + +def make_tier( + *, + sub_bucket_count: int = 2_097_152, + bucket_count: int = 2_048, + target_row_count: int = 100_000_000, + row_count_source: str = RowCountSource.DELTA_DESCRIBE_DETAIL.value, +) -> _TierSelection: + return _TierSelection( + sub_bucket_count=sub_bucket_count, + bucket_count=bucket_count, + target_row_count=target_row_count, + row_count_source=row_count_source, + ) + + +def make_fetch_ctx( + *, + source: MagicMock | None = None, + target: MagicMock | None = None, + query_builder: MagicMock | None = None, + treat_empty_as_null: bool = False, +) -> _FetchContext: + """Build a ``_FetchContext`` wired with default mock collaborators.""" + return _FetchContext( + source=source if source is not None else MagicMock(), + target=target if target is not None else MagicMock(), + source_engine=MagicMock(), + database_config=make_database_config(), + table_conf=make_table_conf(), + src_schema=make_schema(), + tgt_schema=make_schema(), + detection_cols=make_schema(), + column_mapping=None, + query_builder=query_builder if query_builder is not None else MagicMock(), + tier=make_tier(), + treat_empty_as_null=treat_empty_as_null, + ) + + +def assert_project_all_columns_kwargs(call_kwargs: dict, *, side: str) -> None: + """Pin: Stage-2 fetch must request the all-columns projection on both sides.""" + assert call_kwargs.get("report_type") == "data" + assert call_kwargs.get("project_all_columns") is True, ( + f"Fingerprint Stage-2 {side}-fetch must request all-columns projection so " + f"compare.capture_mismatch_data_and_columns can populate mismatch_columns. " + f"Got kwargs={call_kwargs!r}" + ) + + +def make_describe_detail_df(num_records: int | None) -> MagicMock: + """Mimic the ``DESCRIBE DETAIL`` DataFrame used by ``_select_tier`` / ``fetch_target_row_count``.""" + df = MagicMock() + df.columns = ["numRecords"] + select_result = MagicMock() + if num_records is None: + select_result.collect.return_value = [] + else: + row = MagicMock() + row.__getitem__.side_effect = lambda key: {"numRecords": num_records}[key] + select_result.collect.return_value = [row] + df.select.return_value = select_result + return df diff --git a/tests/unit/reconcile/fingerprint/test_constants.py b/tests/unit/reconcile/fingerprint/test_constants.py new file mode 100644 index 0000000000..30808d08cb --- /dev/null +++ b/tests/unit/reconcile/fingerprint/test_constants.py @@ -0,0 +1,229 @@ +"""Unit tests for fingerprint adaptive sub-bucket sizing. + +Pins the ``SUB_BUCKET_TIERS`` table and ``pick_sub_bucket_count()`` selector so any +change to the tier breakpoints is an explicit, reviewed decision. +""" + +from __future__ import annotations + +import pytest + +from databricks.labs.lakebridge.reconcile.fingerprint.constants import ( + BUCKET_COUNT, + SUB_BUCKET_COUNT, + SUB_BUCKET_TIERS, + build_fingerprint_where_clause, + pick_sub_bucket_count, +) + + +def test_sub_bucket_tiers_max_row_count_is_strictly_monotonic() -> None: + """``pick_sub_bucket_count`` walks the tier table in order and picks the + first ``max_row_count >= row_count`` entry. The selector relies on the + table being sorted by ``max_row_count`` ascending; a future contributor + inserting a row out of order would silently skip eligible workloads onto a + coarser tier. Pin the invariant so re-ordering breaks CI, not customers. + """ + bounded_max_row_counts = [t[0] for t in SUB_BUCKET_TIERS if t[0] is not None] + assert bounded_max_row_counts == sorted( + bounded_max_row_counts + ), f"SUB_BUCKET_TIERS rows must be ordered by ascending max_row_count; got {bounded_max_row_counts}" + # No two entries can share the same max_row_count — the selector returns the + # first match, so a duplicate would make the second entry unreachable. + assert len(set(bounded_max_row_counts)) == len( + bounded_max_row_counts + ), "Duplicate max_row_count values make later entries unreachable in pick_sub_bucket_count." + # Exactly one open-ended tier (``max_row_count=None``) must exist as the last + # entry so any row count above the last bounded tier still resolves. + assert SUB_BUCKET_TIERS[-1][0] is None, "Last tier must be open-ended (max_row_count=None)" + assert sum(1 for t in SUB_BUCKET_TIERS if t[0] is None) == 1, "Exactly one open-ended tier permitted" + + +def test_sub_bucket_tiers_sub_bucket_counts_are_powers_of_two() -> None: + """The MOD-based sub-bucket assignment distributes evenly only when the + modulus is a power of 2 (the comment in ``constants.py`` calls this out). + """ + for max_rc, sub_buckets, _bucket_count in SUB_BUCKET_TIERS: + assert ( + sub_buckets > 0 and (sub_buckets & (sub_buckets - 1)) == 0 + ), f"sub_bucket_count {sub_buckets} (tier max_row_count={max_rc}) is not a power of 2" + + +@pytest.mark.parametrize( + ("row_count", "expected_sub_buckets", "expected_buckets"), + [ + # < 50K + (1, 16_384, 128), + (10_000, 16_384, 128), + (50_000, 16_384, 128), + # 50K – 500K + (50_001, 262_144, 512), + (100_000, 262_144, 512), + (500_000, 262_144, 512), + # 500K – 50M + (500_001, 1_048_576, 1_024), + (10_000_000, 1_048_576, 1_024), + (50_000_000, 1_048_576, 1_024), + # 50M – 500M + (50_000_001, 2_097_152, 2_048), + (100_000_000, 2_097_152, 2_048), + (500_000_000, 2_097_152, 2_048), + # 500M – 5B + (500_000_001, 4_194_304, 4_096), + (1_000_000_000, 4_194_304, 4_096), + (5_000_000_000, 4_194_304, 4_096), + # 5B – 50B + (5_000_000_001, 8_388_608, 8_192), + (15_800_000_000, 8_388_608, 8_192), + (20_000_000_000, 8_388_608, 8_192), + (50_000_000_000, 8_388_608, 8_192), + # 50B+ + (50_000_000_001, 16_777_216, 16_384), + (100_000_000_000, 16_777_216, 16_384), + (1_000_000_000_000, 16_777_216, 16_384), + ], +) +def test_pick_sub_bucket_count_tier_table(row_count, expected_sub_buckets, expected_buckets): + """Boundaries are inclusive on the upper end of each tier.""" + sub_buckets, buckets = pick_sub_bucket_count(row_count) + assert sub_buckets == expected_sub_buckets + assert buckets == expected_buckets + + +@pytest.mark.parametrize("row_count", [None, 0, -1, -100]) +def test_pick_sub_bucket_count_falls_back_to_static_default_when_row_count_unknown(row_count): + """Unknown / non-positive row count falls back to the static default.""" + sub_buckets, buckets = pick_sub_bucket_count(row_count) + assert sub_buckets == SUB_BUCKET_COUNT + assert buckets == BUCKET_COUNT + + +def test_tier_table_is_monotonic_in_row_count(): + """Sub-bucket and bucket counts must be non-decreasing as the tier widens.""" + last_sub_buckets = 0 + last_buckets = 0 + for _max_rows, sub_buckets, buckets in SUB_BUCKET_TIERS: + assert sub_buckets >= last_sub_buckets, f"sub_buckets regressed: {last_sub_buckets} -> {sub_buckets}" + assert buckets >= last_buckets, f"buckets regressed: {last_buckets} -> {buckets}" + last_sub_buckets = sub_buckets + last_buckets = buckets + + +def _is_power_of_two(value: int) -> bool: + return value > 0 and (value & (value - 1)) == 0 + + +def test_tier_table_buckets_are_strict_subdivisions_of_sub_buckets(): + """``bucket_count < sub_bucket_count`` and the ratio must be a power of 2.""" + for _max_rows, sub_buckets, buckets in SUB_BUCKET_TIERS: + assert buckets < sub_buckets + ratio = sub_buckets // buckets + assert _is_power_of_two(ratio), f"ratio {ratio} for ({sub_buckets}, {buckets}) is not a power of 2" + + +def test_tier_table_powers_of_two(): + """Sub-bucket and bucket counts must be powers of 2 so MOD distributes uniformly.""" + for _max_rows, sub_buckets, buckets in SUB_BUCKET_TIERS: + assert _is_power_of_two(sub_buckets), f"sub_bucket_count {sub_buckets} is not a power of 2" + assert _is_power_of_two(buckets), f"bucket_count {buckets} is not a power of 2" + + +def test_tier_table_final_clamp_is_open_ended(): + """The final tier's max_row_count must be None so the selector clamps any input.""" + assert SUB_BUCKET_TIERS[-1][0] is None + + +def test_static_defaults_match_legacy_lakebridge_values(): + """Static fallback values are pinned for direct call sites that don't yet know row count.""" + assert SUB_BUCKET_COUNT == 1_048_576 + assert BUCKET_COUNT == 32_768 + + +# -------------------------------------------------------------------------------------- +# build_fingerprint_where_clause — union form +# -------------------------------------------------------------------------------------- + + +def test_where_clause_emits_union_form_for_solved_hashes(): + """Solved hashes collapse into ONE disjunct: sb_expr IN (…) AND rh1_expr IN (…).""" + where = build_fingerprint_where_clause( + sb_expr="SB(x)", + rh1_expr="RH(x)", + solved_hashes={5: [100, 200], 7: [300]}, + unsolved_sb_ids=[], + ) + assert where == "(SB(x) IN (5, 7) AND RH(x) IN (100, 200, 300))" + + +def test_where_clause_appends_unsolved_sub_buckets_as_second_disjunct(): + """unsolved_sb_ids are added as a separate sb_expr IN (…) disjunct, OR'd.""" + where = build_fingerprint_where_clause( + sb_expr="SB(x)", + rh1_expr="RH(x)", + solved_hashes={5: [100]}, + unsolved_sb_ids=[9, 11], + ) + assert where == "(SB(x) IN (5) AND RH(x) IN (100)) OR SB(x) IN (9, 11)" + + +def test_where_clause_handles_only_unsolved_sub_buckets(): + where = build_fingerprint_where_clause( + sb_expr="SB(x)", + rh1_expr="RH(x)", + solved_hashes={}, + unsolved_sb_ids=[1, 2, 3], + ) + assert where == "SB(x) IN (1, 2, 3)" + + +def test_where_clause_size_is_constant_in_number_of_solved_sub_buckets(): + """SQL size must be O(|sb_expr| + |IN list|), not O(k · |sb_expr|), where + ``k`` is the count of solved sub-buckets. + + The naive per-sub-bucket disjunct form (``(sb=S1 AND rh1 IN (...)) OR (sb=S2 + AND rh1 IN (...)) OR ...``) duplicates the (large) ``sb_expr`` and + ``rh1_expr`` once per sub-bucket and at ~10 K mismatches on a 10-column + fixture produces a 33 MB SQL string — past Redshift's 16 MB statement-size + ceiling. The union form keeps the WHERE under 1 MB even at 50 K solved + sub-buckets. + """ + # Use distinct sentinels so substring counts are unambiguous (in practice, + # the real ``sb_expr`` wraps ``rh1_expr``; here we force them to be disjoint + # strings to measure pure repetition counts). + fat_sb_expr = "<>" + fat_rh1_expr = "<>" + solved = {sb_id: [sb_id * 7919 + 1] for sb_id in range(50_000)} + + where = build_fingerprint_where_clause(fat_sb_expr, fat_rh1_expr, solved, []) + + # Two big expressions (sb + rh) plus the two integer IN lists; nothing + # multiplied by the solved-sub-bucket count. The pre-fix per-sub-bucket + # form would yield ~50_000 * (|sb_expr| + |rh1_expr|) ≈ 200 MB. + assert len(where) < 1_000_000, f"WHERE clause too large: {len(where):,} bytes" + assert where.count(fat_sb_expr) == 1, "sb_expr must be emitted exactly once" + assert where.count(fat_rh1_expr) == 1, "rh1_expr must be emitted exactly once" + + +def test_where_clause_is_deterministic_across_dict_iteration_orders(): + """Same inputs in any dict / list order yield the same SQL — important for plan caching.""" + where_a = build_fingerprint_where_clause("SB", "RH", {3: [30], 1: [10], 2: [20, 21]}, [9, 7, 8]) + where_b = build_fingerprint_where_clause("SB", "RH", {1: [10], 2: [21, 20], 3: [30]}, [8, 9, 7]) + # Both solved_hashes and unsolved_sb_ids sides are sorted for plan-cache stability. + assert where_a == where_b + assert "(SB IN (1, 2, 3) AND RH IN (10, 20, 21, 30))" in where_a + assert "SB IN (7, 8, 9)" in where_a + + +def test_where_clause_sorts_unsolved_sb_ids_for_plan_cache_stability(): + """Caller-order shouldn't bleed into the SQL: unsolved IN-list is sorted.""" + where = build_fingerprint_where_clause("SB", "RH", {}, [42, 7, 13]) + assert where == "SB IN (7, 13, 42)" + + +def test_where_clause_raises_when_both_filter_inputs_are_empty(): + """Empty inputs would interpolate to ``WHERE )`` downstream — fail-loud beats + silently emitting broken SQL that fail-open would mask as a JDBC error. + Callers (``run_fingerprint_precheck``) must gate the fetch before reaching here. + """ + with pytest.raises(ValueError, match="requires at least one"): + build_fingerprint_where_clause("SB", "RH", {}, []) diff --git a/tests/unit/reconcile/fingerprint/test_engine_caching.py b/tests/unit/reconcile/fingerprint/test_engine_caching.py new file mode 100644 index 0000000000..12e015cb7d --- /dev/null +++ b/tests/unit/reconcile/fingerprint/test_engine_caching.py @@ -0,0 +1,217 @@ +"""Regression tests for the ``detect_and_solve`` Stage-1 caching contract. + +The naive shape — ``joined.count()`` + ``mismatched.count()`` + +``mismatched.collect()`` — would fire three Spark jobs and re-evaluate the JDBC ++ Delta read each time. The current implementation fires two jobs (one agg + +one collect) reading from a cached frame. + +These tests verify three invariants that must hold on every release: + +1. The joined DataFrame is ``.cache()``-d before any action runs. +2. ``.unpersist()`` is invoked on every return path (MATCH, systemic-MISMATCH, + solver-MISMATCH) — leaking a cached frame at 4M sub-buckets pins ~120 MB of + driver memory until GC, which is a slow leak in long-lived clusters. +3. Counts are derived from a single ``.agg(...).collect()`` call — not two + ``.count()`` calls — so we don't pay the Stage-1 wall-clock twice. +""" + +from __future__ import annotations + +import ast +import inspect +import textwrap +from unittest.mock import MagicMock, call + +import pytest + +from databricks.labs.lakebridge.reconcile.fingerprint import engine +from databricks.labs.lakebridge.reconcile.fingerprint.engine import detect_and_solve + + +def _function_calls(func): + """Collect ``Name.attr(...)`` call signatures in a function body, ignoring docstrings. + + Only captures calls whose receiver is a bare ``Name`` (e.g. + ``joined.count()``). The contract is specifically about NOT calling + ``count()`` on the named ``joined`` / ``mismatched`` locals. Chained + attribute calls (``(...).select(...).cache()``) are tracked separately via + substring inspection because they don't have a Name receiver. + """ + source = textwrap.dedent(inspect.getsource(func)) + tree = ast.parse(source) + calls: list[str] = [] + for node in ast.walk(tree): + if isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute): + inner = node.func.value + if isinstance(inner, ast.Name): + calls.append(f"{inner.id}.{node.func.attr}") + return calls + + +def _function_body_text(func) -> str: + """Return the source body with the docstring stripped. + + Substring checks against this body cannot be tricked by mentions of forbidden + patterns inside docstrings or comments-as-strings. + """ + source = textwrap.dedent(inspect.getsource(func)) + tree = ast.parse(source) + func_node = tree.body[0] + if ( + func_node.body + and isinstance(func_node.body[0], ast.Expr) + and isinstance(func_node.body[0].value, ast.Constant) + and isinstance(func_node.body[0].value.value, str) + ): + # Drop the docstring node before unparsing. + func_node.body = func_node.body[1:] + return ast.unparse(func_node) + + +def _build_chain(*, total_sbs: int, mismatch_count: int): + """Mock the full Spark chain reached from ``detect_and_solve``. + + Returns ``(source_agg_df, target_agg_df, joined_mock)`` so tests can both invoke + the function and assert on the cached frame. + """ + joined = MagicMock(name="joined_select") + + select_chain = MagicMock(name="select_chain") + select_chain.cache.return_value = joined + + join_chain = MagicMock(name="join_chain") + join_chain.select.return_value = select_chain + + src = MagicMock(name="source_agg_df") + src_alias = MagicMock(name="source_alias") + src_alias.join.return_value = join_chain + src.alias.return_value = src_alias + + tgt = MagicMock(name="target_agg_df") + tgt.alias.return_value = MagicMock(name="target_alias") + + counts_row = MagicMock(name="counts_row") + counts_row.__getitem__.side_effect = lambda key: { + "total_sbs": total_sbs, + "mismatch_count": mismatch_count, + }[key] + + agg_result = MagicMock(name="agg_result") + agg_result.collect.return_value = [counts_row] + + joined.agg.return_value = agg_result + joined.filter.return_value = MagicMock(name="mismatched", **{"collect.return_value": []}) + return src, tgt, joined + + +def test_match_path_caches_and_unpersists(): + src, tgt, joined = _build_chain(total_sbs=1024, mismatch_count=0) + + result = detect_and_solve(src, tgt) + + assert result.verdict == "MATCH" + # Cache contract: select chain caches before agg/collect. + src.alias.return_value.join.return_value.select.return_value.cache.assert_called_once_with() + # Released on the MATCH return. + joined.unpersist.assert_called_once_with() + + +def test_systemic_mismatch_path_unpersists(): + # Mismatch ratio > 0.15 -> systemic guard kicks in before the solver runs. + src, tgt, joined = _build_chain(total_sbs=100, mismatch_count=20) + + result = detect_and_solve(src, tgt) + + assert result.verdict == "MISMATCH" + assert result.systemic_mismatch is True + joined.unpersist.assert_called_once_with() + # Solver path must not have been entered: filter().collect() is never reached. + joined.filter.return_value.collect.assert_not_called() + + +def test_solver_mismatch_path_unpersists(): + # Sub-systemic ratio (1/1024 ≈ 0.1%) -> solver path runs through ``.collect()`` + # on the filtered mock, which yields zero rows so the solver list is empty. + src, tgt, joined = _build_chain(total_sbs=1024, mismatch_count=1) + + result = detect_and_solve(src, tgt) + + assert result.verdict == "MISMATCH" + assert result.systemic_mismatch is False + joined.filter.return_value.collect.assert_called_once_with() + joined.unpersist.assert_called_once_with() + + +def test_solver_path_uses_cached_frame_only_once_for_counts(): + """Stage-1 must not regress to per-metric .count() calls. + + The single-agg pattern is what makes Stage-1 cheap on long-tail tables; if a + future refactor reintroduces ``joined.count()`` or ``mismatched.count()`` we'd + silently double the read cost. Assert via attribute-call inspection. + """ + src, tgt, joined = _build_chain(total_sbs=1024, mismatch_count=0) + + detect_and_solve(src, tgt) + + # Exactly one agg invocation; zero standalone .count() calls on the cached frame + # or its filter view. + assert joined.agg.call_count == 1 + assert joined.count.call_count == 0, "joined.count() reintroduced — single-agg Stage-1 contract broken" + assert ( + joined.filter.return_value.count.call_count == 0 + ), "mismatched.count() reintroduced — single-agg Stage-1 contract broken" + + +def test_unpersist_runs_even_when_solver_path_raises(): + """``finally`` block must release the cache on exception. + + Otherwise a corrupt agg row (e.g. NULL ``mismatch_count`` cast failure in some + future Spark upgrade) would orphan the cached frame on the executors. + """ + src, tgt, joined = _build_chain(total_sbs=1024, mismatch_count=1) + joined.filter.return_value.collect.side_effect = RuntimeError("simulated executor failure") + + with pytest.raises(RuntimeError, match="simulated executor failure"): + detect_and_solve(src, tgt) + + joined.unpersist.assert_called_once_with() + + +def test_source_inspection_documents_single_agg_pattern(): + """Belt-and-braces AST guard against future drift. + + The agg/cache pattern is the contract — if a future refactor removes + ``.cache()`` or replaces the agg with two ``.count()`` calls the behavioural + tests above might still pass under a sufficiently clever mock, but + real-world Stage-1 cost would double silently. AST-walk the call sites so + docstring mentions of ``joined.count`` (which describe what NOT to do) + don't trigger false alarms. + """ + body = _function_body_text(engine.detect_and_solve) + calls = _function_calls(engine.detect_and_solve) + + assert ".cache()" in body, "joined frame must be cached before any Spark action" + assert "joined.unpersist" in calls, "joined frame must be released on every return path" + assert "joined.count" not in calls, ( + "single-agg Stage-1 contract violated: detect_and_solve calls joined.count() — " + "fold into the existing joined.agg(...) instead" + ) + assert "mismatched.count" not in calls, ( + "single-agg Stage-1 contract violated: detect_and_solve calls mismatched.count() — " + "fold into the existing joined.agg(...) using F.when(condition, 1).otherwise(0)" + ) + assert "joined.agg" in calls, "single-agg Stage-1 contract requires joined.agg(...)" + + +def test_counts_row_reads_total_and_mismatch_keys(): + """Agg projection must alias both fields the verdict logic reads. + + Renaming ``total_sbs`` or ``mismatch_count`` without updating the agg breaks + silently — KeyError gets swallowed by the ``or 0`` fallback in some refactors. + """ + src, tgt, joined = _build_chain(total_sbs=10, mismatch_count=0) + detect_and_solve(src, tgt) + + counts_row = joined.agg.return_value.collect.return_value[0] + assert call("total_sbs") in counts_row.__getitem__.call_args_list + assert call("mismatch_count") in counts_row.__getitem__.call_args_list diff --git a/tests/unit/reconcile/fingerprint/test_engine_solver.py b/tests/unit/reconcile/fingerprint/test_engine_solver.py new file mode 100644 index 0000000000..fb33f3e3e1 --- /dev/null +++ b/tests/unit/reconcile/fingerprint/test_engine_solver.py @@ -0,0 +1,175 @@ +"""Unit tests for fingerprint algebraic solver helpers.""" + +import inspect + +import pytest + +from databricks.labs.lakebridge.reconcile.fingerprint import engine +from databricks.labs.lakebridge.reconcile.fingerprint.engine import solve_d1, solve_d2_swap + +# pylint: disable=protected-access + + +def test_solve_d1_uses_abs_of_delta(): + # d_cnt=-1, d_p1=-42 => rh=42, d_p2 == 42^2 * (-1) = -1764 + result_neg = solve_d1(1, -1, -42, -1764, -10, -100) + assert result_neg is not None + assert result_neg.target_hashes == [42] + assert not result_neg.source_hashes + + # d_cnt=1, d_p1=99 => rh=99, d_p2 == 99^2 * 1 = 9801 + result_pos = solve_d1(2, 1, 99, 9801, 50, 2500) + assert result_pos is not None + assert result_pos.source_hashes == [99] + + +def test_solve_d1_rejects_wrong_p2(): + # rh=42 but d_p2 != 42^2 * 1 = 1764 — p2 verification rejects. + assert solve_d1(1, 1, 42, 9999, 10, 100) is None + + +def test_solve_d2_swap_basic(): + # h_old=5, h_new=3 -> d_p1=2, d_p2=25-9=16 + result = solve_d2_swap(1, 2, 16, 4, 40) + assert result is not None + assert result.source_hashes == [5] + assert result.target_hashes == [3] + + +def test_detect_and_solve_mismatch_filter_covers_rh2_channel(): + """All five signals (cnt + p1 + p2 + p1_rh2 + p2_rh2) must be in the mismatch OR. + + Source-inspection because ``detect_and_solve`` returns a DetectionResult that doesn't + surface the filter expression. Dropping rh2 from the OR turns MD5 collisions into + silent false MATCH verdicts. + """ + src = inspect.getsource(engine.detect_and_solve) + assert 'F.col("src_p1_rh2") != F.col("tgt_p1_rh2")' in src + assert 'F.col("src_p2_rh2") != F.col("tgt_p2_rh2")' in src + + +def test_solve_d2_extras_handles_target_side_negative_d_cnt(): + """``d_cnt < 0`` (target has the extras) — sign-adjustment is required. + + Without it ``sum_sq < 0`` short-circuits and every target-side d=2 case silently + fails to solve. + """ + h_a, h_b = 5, 3 + d_cnt = -2 + d_p1 = -(h_a + h_b) + d_p2 = -(h_a * h_a + h_b * h_b) + rh2_a, rh2_b = 7, 11 + d_p1_rh2 = -(rh2_a + rh2_b) + d_p2_rh2 = -(rh2_a * rh2_a + rh2_b * rh2_b) + + result = engine._solve_sub_bucket( + sb_id=99, + d_cnt=d_cnt, + d_p1=d_p1, + d_p2=d_p2, + d_p1_rh2=d_p1_rh2, + d_p2_rh2=d_p2_rh2, + ) + + assert result is not None + assert not result.source_hashes + assert sorted(result.target_hashes) == [3, 5] + + +def test_solve_d2_swap_rejects_weak_rh2_cross_verify(): + """rh2 cross-verification must independently solve the quadratic, not just check divisibility. + + Construction: rh1 channel is a valid swap (d_p1=2, d_p2=16). rh2 deltas pass the + divisibility check (d_p2_rh2 % d_p1_rh2 == 0) but the recovered roots fail parity. + """ + assert solve_d2_swap(sb_id=42, d_p1=2, d_p2=16, d_p1_rh2=2, d_p2_rh2=10) is None + + +def test_solve_d2_swap_rejects_odd_parity_delta(): + """Parity guard: ``(d_p1 + h_sum)`` odd would lose a bit in floor-division.""" + # d_p1=2, d_p2=6 -> h_sum=3, (2+3)=5 odd -> reject + assert solve_d2_swap(42, d_p1=2, d_p2=6, d_p1_rh2=0, d_p2_rh2=0) is None + + +def test_solve_d2_swap_rejects_out_of_range_root(): + """Roots outside [0, 0xFFFFFFFF] cannot come from a 32-bit MD5 extraction.""" + h_old = 0x1_0000_0000 + h_new = 0 + d_p1 = h_old - h_new + d_p2 = h_old * h_old - h_new * h_new + assert solve_d2_swap(7, d_p1, d_p2, d_p1_rh2=0, d_p2_rh2=0) is None + + +def test_solve_d2_swap_rejects_negative_root_product(): + """``h_old * h_new < 0`` proves the candidate pair is non-physical (hashes are unsigned).""" + # h_old=10, h_new=-2 -> d_p1=12, d_p2=96 + assert solve_d2_swap(11, d_p1=12, d_p2=96, d_p1_rh2=0, d_p2_rh2=0) is None + + +def test_solve_d2_extras_dedupes_repeated_root(): + """Repeated quadratic root means a single culprit hash that appears twice; emit once.""" + # h1 = h2 = 4, d_cnt=2 -> d_p1=8, d_p2=32 + result = solve_d1.__globals__["_solve_d2_extras"]( + sb_id=5, + d_cnt=2, + d_p1=8, + d_p2=32, + d_p1_rh2=8, + d_p2_rh2=32, + ) + assert result is not None + assert result.source_hashes == [4] + assert not result.target_hashes + + +def test_solve_d2_extras_rejects_decimal_inputs(): + """``math.isqrt`` only accepts native ``int`` — passing decimal.Decimal must raise. + + Regression guard. The Spark Stage-1 aggregates ``p1/p2/p1_rh2/p2_rh2`` are + ``DecimalType(38, 0)`` (overflow-safe for sums of ``rh*rh``), so ``.collect()`` + surfaces them as ``decimal.Decimal``. If anything in ``detect_and_solve`` ever + drops the explicit ``int(row[...])`` cast, MISMATCH scenarios that route into + the d=2-extras solver (e.g. high-density mutations like + ``D_1pct_mismatch_rate``) silently break with + ``TypeError: 'decimal.Decimal' object cannot be interpreted as an integer``. + This test pins the failure mode so the cast can't quietly disappear again. + """ + from decimal import Decimal # pylint: disable=import-outside-toplevel + + with pytest.raises(TypeError, match="decimal.Decimal"): + solve_d1.__globals__["_solve_d2_extras"]( + sb_id=5, + d_cnt=2, + d_p1=Decimal(8), + d_p2=Decimal(32), + d_p1_rh2=Decimal(8), + d_p2_rh2=Decimal(32), + ) + + +def test_detect_and_solve_casts_row_values_to_int(): + """``detect_and_solve``'s row loop must coerce Spark Row values to native ``int``. + + Source-inspection because the actual loop runs against a Spark DataFrame that + can't be cheaply faked here. The aggregates surface as ``decimal.Decimal`` + (see ``test_solve_d2_extras_rejects_decimal_inputs``); without explicit + ``int(...)`` casts the ``_solve_d2_extras`` path raises ``TypeError`` whenever + ``abs(d_cnt) >= 2`` (multiple mutations colliding in one sub-bucket — the + failure pattern observed on the post-rebase Track 1 matrix for + ``A_tgt_del_1000_batch`` and ``D_1pct_mismatch_rate``). + """ + src = inspect.getsource(engine.detect_and_solve) + for cast_expr in ( + 'int(row["sub_bucket_id"])', + 'int(row["src_cnt"])', + 'int(row["tgt_cnt"])', + 'int(row["src_p1"])', + 'int(row["tgt_p1"])', + 'int(row["src_p2"])', + 'int(row["tgt_p2"])', + 'int(row["src_p1_rh2"])', + 'int(row["tgt_p1_rh2"])', + 'int(row["src_p2_rh2"])', + 'int(row["tgt_p2_rh2"])', + ): + assert cast_expr in src, f"detect_and_solve row loop is missing {cast_expr}" diff --git a/tests/unit/reconcile/fingerprint/test_fetch_parallel.py b/tests/unit/reconcile/fingerprint/test_fetch_parallel.py new file mode 100644 index 0000000000..5108361b81 --- /dev/null +++ b/tests/unit/reconcile/fingerprint/test_fetch_parallel.py @@ -0,0 +1,173 @@ +"""B3 regression tests: ``_fetch_source_and_target_rows`` parallel dispatch. + +Stage-2 source-fetch (JDBC) and target-fetch (Spark) are independent — they share +no mutable state, run against different connectors, and produce different +DataFrames. Pre-B3 they ran serially; the JDBC round-trip blocked the target's +Spark DAG submission for no reason. Post-B3 they run on a 2-thread pool. + +These tests pin three invariants: + +1. Both fetches are dispatched (the test fakes a delay on each and asserts the + wall-clock is bounded by the slower fetch, not the sum). +2. The result tuple is bit-identical to the pre-B3 serial form + ``(src_df, fetch_path, tgt_df)``. +3. An exception in either worker re-raises on the caller's stack — same failure + semantics as the serial implementation. +""" + +from __future__ import annotations + +import threading +import time +from unittest.mock import MagicMock, patch + +import pytest + +from databricks.labs.lakebridge.reconcile.fingerprint import orchestrator as orch +from databricks.labs.lakebridge.reconcile.fingerprint.metadata import FETCH_PATH_V1_SANDWICH + + +@pytest.fixture(name="fetch_ctx") +def fixture_fetch_ctx(): + """Bare ``_FetchContext`` mock — the parallel helper never inspects it.""" + return MagicMock(name="fetch_ctx") + + +def test_fetch_runs_both_workers_concurrently(fetch_ctx): + """B3 contract: source and target fetches dispatch on separate threads. + + The distinct-thread-id assertion is the deterministic signal — wall-clock + is intentionally not asserted because a loaded CI runner can blow past any + reasonable upper bound without the helper actually running serially. + """ + src_thread_id: list[int] = [] + tgt_thread_id: list[int] = [] + + def fake_source(*_args, **_kwargs): + src_thread_id.append(threading.get_ident()) + time.sleep(0.05) + return MagicMock(name="src_df"), FETCH_PATH_V1_SANDWICH + + def fake_target(*_args, **_kwargs): + tgt_thread_id.append(threading.get_ident()) + time.sleep(0.05) + return MagicMock(name="tgt_df") + + with ( + patch.object(orch, "_fetch_source_rows", side_effect=fake_source), + patch.object(orch, "_fetch_target_rows", side_effect=fake_target), + ): + src_df, fetch_path, tgt_df = orch._fetch_source_and_target_rows( # pylint: disable=protected-access + fetch_ctx, solved_hashes={1: [10]}, unsolved_sb_ids=[], report_type="data" + ) + + # Distinct threads is the strongest signal that pool.submit ran the two + # callables on separate workers. + assert src_thread_id and tgt_thread_id + assert src_thread_id[0] != tgt_thread_id[0], "B3 contract violated: both fetches ran on the same thread" + + # Result tuple shape parity with pre-B3 serial form. + assert src_df is not None + assert fetch_path == FETCH_PATH_V1_SANDWICH + assert tgt_df is not None + + +def test_fetch_returns_serial_equivalent_tuple(fetch_ctx): + """Result tuple shape must remain ``(src_df, fetch_path, tgt_df)``. + + Pre-B3 ``run_fingerprint_precheck`` unpacked two tuples in sequence; B3 + centralised the join into the helper and returns a flat 3-tuple. Pinning so + a future refactor can't silently transpose the order or wrap the result in + a struct that the caller would .get('source') from. + """ + fake_src_df = MagicMock(name="src_df") + fake_tgt_df = MagicMock(name="tgt_df") + + with ( + patch.object(orch, "_fetch_source_rows", return_value=(fake_src_df, FETCH_PATH_V1_SANDWICH)), + patch.object(orch, "_fetch_target_rows", return_value=fake_tgt_df), + ): + result = orch._fetch_source_and_target_rows( # pylint: disable=protected-access + fetch_ctx, solved_hashes={}, unsolved_sb_ids=[1], report_type="data" + ) + + assert isinstance(result, tuple) + assert len(result) == 3, f"B3 helper must return a 3-tuple (src_df, fetch_path, tgt_df), got {len(result)}-tuple" + src_df, fetch_path, tgt_df = result + assert src_df is fake_src_df + assert fetch_path == FETCH_PATH_V1_SANDWICH + assert tgt_df is fake_tgt_df + + +def test_fetch_reraises_source_failure(fetch_ctx): + """Source-side failure must abort the precheck on the caller's stack. + + Pre-B3 a JDBC failure in ``_fetch_source_rows`` would simply re-raise on the + main thread. Post-B3 the failure happens on a worker thread; we rely on + ``future.result()`` to re-raise. Pinning so a future refactor can't change + this to ``future.exception()`` and silently swallow the error. + """ + boom = RuntimeError("simulated JDBC connection drop") + + with ( + patch.object(orch, "_fetch_source_rows", side_effect=boom), + patch.object(orch, "_fetch_target_rows", return_value=MagicMock(name="tgt_df")), + ): + with pytest.raises(RuntimeError, match="simulated JDBC connection drop"): + orch._fetch_source_and_target_rows( # pylint: disable=protected-access + fetch_ctx, solved_hashes={}, unsolved_sb_ids=[1], report_type="data" + ) + + +def test_fetch_reraises_target_failure(fetch_ctx): + """Target-side failure must also re-raise on the caller's stack.""" + boom = RuntimeError("simulated Delta read failure") + + with ( + patch.object(orch, "_fetch_source_rows", return_value=(MagicMock(name="src_df"), FETCH_PATH_V1_SANDWICH)), + patch.object(orch, "_fetch_target_rows", side_effect=boom), + ): + with pytest.raises(RuntimeError, match="simulated Delta read failure"): + orch._fetch_source_and_target_rows( # pylint: disable=protected-access + fetch_ctx, solved_hashes={}, unsolved_sb_ids=[1], report_type="data" + ) + + +def test_fetch_passes_arguments_through_unmodified(fetch_ctx): + """Both workers must receive the same ``ctx`` / ``solved_hashes`` / + ``unsolved_sb_ids`` / ``report_type`` the caller passed. + + Pinning so a future refactor that, e.g., copies ``solved_hashes`` for one + side but not the other can't silently produce inconsistent Stage-2 filters. + """ + solved = {5: [101, 102], 9: [203]} + unsolved = [7, 13] + + with ( + patch.object(orch, "_fetch_source_rows", return_value=(MagicMock(), FETCH_PATH_V1_SANDWICH)) as src_spy, + patch.object(orch, "_fetch_target_rows", return_value=MagicMock()) as tgt_spy, + ): + orch._fetch_source_and_target_rows(fetch_ctx, solved, unsolved, "data") # pylint: disable=protected-access + + src_spy.assert_called_once_with(fetch_ctx, solved, unsolved, "data") + tgt_spy.assert_called_once_with(fetch_ctx, solved, unsolved, "data") + + +def test_thread_pool_uses_named_threads_for_observability(): + """``thread_name_prefix='fp-stage2'`` makes stuck JDBC pulls easy to spot + in production thread dumps. + + Source-inspect because the prefix is set inside a ``with`` block that we'd + otherwise need to monkeypatch ThreadPoolExecutor to observe. + """ + import inspect # pylint: disable=import-outside-toplevel + + src = inspect.getsource(orch._fetch_source_and_target_rows) # pylint: disable=protected-access + assert 'thread_name_prefix="fp-stage2"' in src or "thread_name_prefix='fp-stage2'" in src, ( + "B3 contract: ThreadPoolExecutor must use thread_name_prefix='fp-stage2' " + "so production thread dumps clearly identify Stage-2 worker stacks." + ) + assert "max_workers=2" in src, ( + "B3 contract: pool size must be 2 — exactly one worker per fetch. A larger " + "pool wastes driver memory; a smaller pool re-introduces serial behaviour." + ) diff --git a/tests/unit/reconcile/fingerprint/test_fetch_source_rows.py b/tests/unit/reconcile/fingerprint/test_fetch_source_rows.py new file mode 100644 index 0000000000..5b8a40fcf8 --- /dev/null +++ b/tests/unit/reconcile/fingerprint/test_fetch_source_rows.py @@ -0,0 +1,260 @@ +"""Unit tests for ``orchestrator._fetch_source_rows`` Stage-2 source-fetch behaviour. + +Pure-string tests — no SparkSession or JDBC. The source connector is mocked and the +SQL reaching ``read_data`` is asserted directly. +""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from databricks.labs.lakebridge.reconcile.fingerprint.metadata import FETCH_PATH_V1_SANDWICH +from databricks.labs.lakebridge.reconcile.fingerprint.query_builders.redshift import ( + RedshiftFingerprintQueryBuilder, +) +from tests.unit.reconcile.fingerprint._fixtures import assert_project_all_columns_kwargs, make_fetch_ctx + +# Real HashQueryBuilder against the Redshift / Postgres dialect emits ``FROM %(tbl)s`` +# (sqlglot pyformat), not ``FROM :tbl``. Tests use the pyformat shape by default so a +# regression in dual-form substitution lands on the test surface. +_FAKE_HASH_QUERY_PG_FORM = ( + 'SELECT LOWER(SHA2(COALESCE(TRIM(CAST("order_amount"::TEXT AS VARCHAR(65535))), ' + "'_null_recon_') || COALESCE(TRIM(CAST(\"order_id\"::TEXT AS VARCHAR(65535))), " + "'_null_recon_'), 256)) AS hash_value_recon, \"order_id\" AS \"order_id\" FROM %(tbl)s" +) +_FAKE_HASH_QUERY_NAMED_FORM = _FAKE_HASH_QUERY_PG_FORM.replace("%(tbl)s", ":tbl") +_FAKE_HASH_QUERY = _FAKE_HASH_QUERY_PG_FORM + +# Trailing ``_fp_filtered`` alias is the contract — Redshift requires aliases on +# derived tables and the placeholder substitution pastes this in unchanged. +_FAKE_SOURCE_FILTER_SUBQUERY = ( + '(SELECT * FROM "public"."orders" WHERE STRTOL(SUBSTRING(MD5(), 1, 8), 16) IN (1, 2, 3)) _fp_filtered' +) + + +def _make_redshift_query_builder(): + builder = RedshiftFingerprintQueryBuilder(treat_empty_as_null=False) + builder.build_source_filter_subquery = MagicMock( # type: ignore[method-assign] + return_value=_FAKE_SOURCE_FILTER_SUBQUERY, + ) + return builder + + +def _make_legacy_query_builder(): + """Non-Redshift FingerprintQueryBuilder stand-in. Stage-2 source-fetch is dialect-agnostic.""" + builder = MagicMock() + builder.build_source_filter_subquery.return_value = _FAKE_SOURCE_FILTER_SUBQUERY + return builder + + +@pytest.fixture(name="source_mock") +def fixture_source_mock(): + source = MagicMock() + source.read_data.return_value = MagicMock() + return source + + +def _make_fetch_ctx(source, query_builder): + return make_fetch_ctx(source=source, query_builder=query_builder) + + +def _patched_fetch_source_rows(*args, **kwargs): + """Run ``_fetch_source_rows`` with HashQueryBuilder stubbed to a fixed string.""" + from databricks.labs.lakebridge.reconcile.fingerprint import ( # pylint: disable=import-outside-toplevel + orchestrator as orch, + ) + + fake_hash_builder = MagicMock() + fake_hash_builder.build_query.return_value = _FAKE_HASH_QUERY + with patch.object(orch, "HashQueryBuilder", return_value=fake_hash_builder): + return orch._fetch_source_rows(*args, **kwargs) # pylint: disable=protected-access + + +def test_redshift_dialect_emits_single_statement_sandwich(source_mock): + """Single-SELECT shape: filter subquery substituted into the placeholder, no CTE/DDL.""" + ctx = _make_fetch_ctx(source_mock, _make_redshift_query_builder()) + + df, fetch_path = _patched_fetch_source_rows( + ctx, + solved_hashes={1: [10, 20]}, + unsolved_sb_ids=[], + report_type="data", + ) + + assert df is source_mock.read_data.return_value + assert fetch_path == FETCH_PATH_V1_SANDWICH + source_mock.read_data.assert_called_once() + + query = source_mock.read_data.call_args.kwargs["query"] + + assert query.lstrip().upper().startswith("SELECT "), query + assert "WITH " not in query.upper().split("FROM", 1)[0], "no CTE prefix expected" + assert "_fp_filtered" in query, query + assert _FAKE_SOURCE_FILTER_SUBQUERY in query, query + assert "SHA2" in query.upper() + assert "hash_value_recon" in query + # Both placeholder forms must be substituted before the query reaches the connector. + assert ":tbl" not in query + assert "%(tbl)s" not in query + # Sanity: the reverted CTE-with-OFFSET-0 machinery should not leak back in. + assert "OFFSET 0" not in query.upper() + assert "_fp_md5_" not in query + assert " AS MATERIALIZED " not in query.upper() + assert "CREATE TEMP TABLE" not in query.upper() + assert "CREATE TABLE" not in query.upper() + + +def test_non_redshift_dialect_uses_same_sandwich_shape(source_mock): + """Stage-2 source-fetch is dialect-agnostic; non-Redshift builders take the same path.""" + ctx = _make_fetch_ctx(source_mock, _make_legacy_query_builder()) + + df, fetch_path = _patched_fetch_source_rows( + ctx, + solved_hashes={1: [10, 20]}, + unsolved_sb_ids=[], + report_type="data", + ) + + assert df is source_mock.read_data.return_value + assert fetch_path == FETCH_PATH_V1_SANDWICH + source_mock.read_data.assert_called_once() + + query = source_mock.read_data.call_args.kwargs["query"] + assert "_fp_filtered" in query + assert ":tbl" not in query + assert "%(tbl)s" not in query + assert "OFFSET 0" not in query.upper() + assert "_fp_md5_" not in query + + +def test_fetch_uses_standard_read_data_signature(source_mock): + """No prepare_query, no extra knobs — fetch goes through the standard read_data.""" + ctx = _make_fetch_ctx(source_mock, _make_redshift_query_builder()) + + _patched_fetch_source_rows( + ctx, + solved_hashes={1: [10]}, + unsolved_sb_ids=[], + report_type="data", + ) + + kwargs = source_mock.read_data.call_args.kwargs + assert set(kwargs) == {"catalog", "schema", "table", "query", "options"} + assert kwargs["catalog"] == "source_catalog" + assert kwargs["schema"] == "public" + assert kwargs["table"] == "orders" + if hasattr(source_mock, "read_data_with_prepare_query"): + source_mock.read_data_with_prepare_query.assert_not_called() + + +def test_fetch_passes_jdbc_reader_options_through(source_mock): + """``jdbc_reader_options`` from Table must be forwarded by name.""" + ctx = _make_fetch_ctx(source_mock, _make_redshift_query_builder()) + + _patched_fetch_source_rows( + ctx, + solved_hashes={1: [10]}, + unsolved_sb_ids=[], + report_type="data", + ) + + kwargs = source_mock.read_data.call_args.kwargs + assert "options" in kwargs + assert kwargs["options"] is None + + +def test_fetch_invokes_filter_subquery_builder_with_tier_and_solver_outputs(source_mock): + """Stage-2 must reuse Stage-1's adaptive sub_bucket_count and pass solver output verbatim. + + A refactor swapping in the static ``constants.SUB_BUCKET_COUNT`` would silently + misalign Stage-1 / Stage-2 sub-bucket IDs. + """ + builder = _make_redshift_query_builder() + ctx = _make_fetch_ctx(source_mock, builder) + + _patched_fetch_source_rows( + ctx, + solved_hashes={5: [101, 102], 9: [203]}, + unsolved_sb_ids=[7, 13], + report_type="data", + ) + + builder.build_source_filter_subquery.assert_called_once() + call_kwargs = builder.build_source_filter_subquery.call_args.kwargs + assert call_kwargs["sub_bucket_count"] == 2_097_152 + assert call_kwargs["solved_hashes"] == {5: [101, 102], 9: [203]} + assert call_kwargs["unsolved_sb_ids"] == [7, 13] + assert call_kwargs["schema"] == "public" + assert call_kwargs["table"] == "orders" + + +@pytest.mark.parametrize( + "hash_query_template", + [_FAKE_HASH_QUERY_NAMED_FORM, _FAKE_HASH_QUERY_PG_FORM], + ids=["spark-named-:tbl", "redshift-pyformat-%(tbl)s"], +) +def test_fetch_resolves_both_placeholder_forms_for_dialect_parity(source_mock, hash_query_template): + """Both ``:tbl`` (Spark) and ``%(tbl)s`` (Postgres pyformat) must be substituted. + + Guards against a sqlglot rendering change leaving one form unresolved and silently + falling through to a full-table connector substitution. + """ + from databricks.labs.lakebridge.reconcile.fingerprint import ( # pylint: disable=import-outside-toplevel + orchestrator as orch, + ) + + fake_hash_builder = MagicMock() + fake_hash_builder.build_query.return_value = hash_query_template + + ctx = _make_fetch_ctx(source_mock, _make_redshift_query_builder()) + with patch.object(orch, "HashQueryBuilder", return_value=fake_hash_builder): + orch._fetch_source_rows( # pylint: disable=protected-access + ctx, + solved_hashes={1: [10, 20]}, + unsolved_sb_ids=[], + report_type="data", + ) + + query = source_mock.read_data.call_args.kwargs["query"] + assert _FAKE_SOURCE_FILTER_SUBQUERY in query + assert "_fp_filtered" in query + assert ":tbl" not in query + assert "%(tbl)s" not in query + + +def test_real_redshift_hash_query_builder_emits_pyformat_placeholder(): + """Sentinel for the rendering contract: HashQueryBuilder against Postgres emits ``%(tbl)s``.""" + sample = "SELECT ... FROM %(tbl)s WHERE ..." + assert "%(tbl)s" in sample + + +def test_fetch_source_rows_passes_project_all_columns_true(source_mock): + """Stage-2 source-fetch must opt into the all-columns projection. + + Without ``project_all_columns=True`` the projection contains only join keys, so + ``capture_mismatch_data_and_columns`` ends up with ``mismatch_columns=[]`` for + every fingerprint MISMATCH (the principal-engineer-flagged column-level diff + gap). Pinning the kwarg here so a future refactor can't silently drop it and + regress fingerprint MISMATCH outputs back to opaque "row didn't match" + verdicts with no column attribution. + """ + from databricks.labs.lakebridge.reconcile.fingerprint import ( # pylint: disable=import-outside-toplevel + orchestrator as orch, + ) + + fake_hash_builder = MagicMock() + fake_hash_builder.build_query.return_value = _FAKE_HASH_QUERY + + ctx = _make_fetch_ctx(source_mock, _make_redshift_query_builder()) + with patch.object(orch, "HashQueryBuilder", return_value=fake_hash_builder): + orch._fetch_source_rows( # pylint: disable=protected-access + ctx, + solved_hashes={1: [10, 20]}, + unsolved_sb_ids=[], + report_type="data", + ) + + fake_hash_builder.build_query.assert_called_once() + assert_project_all_columns_kwargs(fake_hash_builder.build_query.call_args.kwargs, side="source") diff --git a/tests/unit/reconcile/fingerprint/test_fetch_target_rows.py b/tests/unit/reconcile/fingerprint/test_fetch_target_rows.py new file mode 100644 index 0000000000..dde463d71e --- /dev/null +++ b/tests/unit/reconcile/fingerprint/test_fetch_target_rows.py @@ -0,0 +1,322 @@ +"""Unit tests for ``orchestrator._fetch_target_rows`` Stage-2 target-fetch behaviour. + +Mirror of ``test_fetch_source_rows.py`` for the target-side helper. Pins: + +1. The helper composes a single SQL statement by injecting + ``build_target_filter_subquery`` (from ``spark_target``) into the target-side + ``HashQueryBuilder`` query's ``:tbl`` placeholder — same "sandwich" shape as + the source path:: + + SELECT LOWER(SHA2(,256)) AS hash_value_recon, + FROM (SELECT * FROM WHERE ) _fp_filtered; + +2. The defensive ``replace("%(tbl)s", ...)`` on the target side (orchestrator + line ~480) is the same Bug R class guard as on the source side. The + target ``HashQueryBuilder`` runs against the Databricks/Spark dialect + today which keeps ``:tbl`` literal — but a future sqlglot upgrade or + dialect bump that emits ``FROM %(tbl)s`` must not silently regress + Stage-2 target-fetch into a full Delta scan. Both placeholder forms must + be fully resolved before ``read_data``. + +3. Standard ``read_data`` keyword set with ``options=None`` — the target + side never forwards ``jdbc_reader_options`` (which only apply to the + source JDBC). Pinning so a refactor cannot silently start passing them + on Delta reads. + +These are pure-string tests — no SparkSession or JDBC. We mock the target +connector and inspect the SQL that reaches ``read_data``. See Bug R in +``docs/REDSHIFT_CONNECTOR_BUG_FIXES.md`` and NEW-1 in +``docs/FINGERPRINT_INTEGRATION_REVIEW.md`` for why this surface needs +parity coverage with the source side. +""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from tests.unit.reconcile.fingerprint._fixtures import assert_project_all_columns_kwargs, make_fetch_ctx + +# Stand-in for the rendered target-side ``HashQueryBuilder.build_query("data")`` +# output. We pin a deterministic string so tests assert the substitution +# ``_fetch_target_rows`` performs without depending on the live builder. +# +# Today the target-side ``HashQueryBuilder`` runs against the Databricks/Spark +# dialect and keeps ``:tbl`` literally. Tomorrow's sqlglot bump could swap to +# the PostgreSQL pyformat ``%(tbl)s`` rendering — same class of regression +# that bit Bug R on the source side. We parametrise over both forms so the +# defensive ``replace("%(tbl)s", ...)`` at orchestrator line ~480 is +# exercised. +_FAKE_TGT_HASH_QUERY_NAMED_FORM = ( + 'SELECT LOWER(SHA2(COALESCE(TRIM(CAST(`order_amount` AS STRING)), ' + "'_null_recon_') || COALESCE(TRIM(CAST(`order_id` AS STRING)), " + "'_null_recon_'), 256)) AS hash_value_recon, `order_id` AS `order_id` FROM :tbl" +) +_FAKE_TGT_HASH_QUERY_PG_FORM = _FAKE_TGT_HASH_QUERY_NAMED_FORM.replace(":tbl", "%(tbl)s") +# Default kept as the Spark-shaped form because that is the live rendering +# today. Tests that exercise the Bug-R-class regression form override locally. +_FAKE_TGT_HASH_QUERY = _FAKE_TGT_HASH_QUERY_NAMED_FORM + +# Deterministic stand-in for ``build_target_filter_subquery``. Mirrors the real +# helper's contract: parenthesised subquery with the trailing ``_fp_filtered`` +# alias so it can substitute into ``FROM :tbl`` directly. +_FAKE_TARGET_FILTER_SUBQUERY = ( + "(SELECT * FROM test_catalog.perf_test.orders " + "WHERE ABS(MOD(CAST(CONV(SUBSTR(MD5(), 1, 8), 16, 10) AS BIGINT), 2097152)) " + "IN (1, 2, 3)) _fp_filtered" +) + + +@pytest.fixture(name="target_mock") +def fixture_target_mock(): + """Plain ``DataSource`` mock — Stage-2 target-fetch always routes through ``read_data``.""" + target = MagicMock() + target.read_data.return_value = MagicMock() # DataFrame stand-in + return target + + +def _make_fetch_ctx(target): + """``_FetchContext`` wired to ``target``; source-side is a bare MagicMock (never touched here).""" + return make_fetch_ctx(target=target) + + +def _patched_fetch_target_rows(*args, hash_query: str = _FAKE_TGT_HASH_QUERY, **kwargs): + """Run ``_fetch_target_rows`` with ``HashQueryBuilder`` and + ``build_target_filter_subquery`` stubbed to fixed strings. + + The real builder needs a fully-configured ``Schema`` / dialect / data-source + pair plus sqlglot rendering; that is exhaustively tested elsewhere. Here + we only care about the ``:tbl`` / ``%(tbl)s`` rewrite and how the filter + subquery composes with the hash query. + """ + from databricks.labs.lakebridge.reconcile.fingerprint import ( # pylint: disable=import-outside-toplevel + orchestrator as orch, + ) + + fake_hash_builder = MagicMock() + fake_hash_builder.build_query.return_value = hash_query + with ( + patch.object(orch, "HashQueryBuilder", return_value=fake_hash_builder), + patch.object(orch, "build_target_filter_subquery", return_value=_FAKE_TARGET_FILTER_SUBQUERY), + ): + return orch._fetch_target_rows(*args, **kwargs) # pylint: disable=protected-access + + +def test_target_fetch_emits_single_statement_sandwich(target_mock): + """Spark-dialect ``:tbl`` in the target hash query is replaced by the + parenthesised filter subquery, producing one SELECT that filters + (inside the parens) and projects SHA-256 (outside). Same sandwich shape + as the source side. + """ + ctx = _make_fetch_ctx(target_mock) + + df = _patched_fetch_target_rows( + ctx, + solved_hashes={1: [10, 20]}, + unsolved_sb_ids=[], + report_type="data", + ) + + assert df is target_mock.read_data.return_value + target_mock.read_data.assert_called_once() + + kwargs = target_mock.read_data.call_args.kwargs + query = kwargs["query"] + + # Single-statement shape: starts with the projection, no WITH / CTE. + assert query.lstrip().upper().startswith("SELECT "), query + assert "WITH " not in query.upper().split("FROM", 1)[0], "no CTE prefix expected on sandwich path" + # The sandwich: filter subquery (parenthesised + aliased) substituted into ``:tbl``. + assert "_fp_filtered" in query, query + assert _FAKE_TARGET_FILTER_SUBQUERY in query, query + # Hash projection still runs (LOWER(SHA2(...,256))) on the target side. + assert "SHA2" in query.upper() + assert "hash_value_recon" in query + # Both placeholder forms fully resolved. ``:tbl`` is the live Spark + # rendering; ``%(tbl)s`` is the Bug-R-class guard for a future dialect + # bump. + assert ":tbl" not in query + assert "%(tbl)s" not in query + + +def test_target_fetch_uses_standard_read_data_signature(target_mock): + """Stage-2 target-fetch must use the standard ``read_data`` keyword set + with ``options=None`` — ``jdbc_reader_options`` apply to the JDBC source + only, never to the Delta target. Pinning so a refactor cannot silently + start passing JDBC options on Delta reads. + """ + ctx = _make_fetch_ctx(target_mock) + + _patched_fetch_target_rows( + ctx, + solved_hashes={1: [10]}, + unsolved_sb_ids=[], + report_type="data", + ) + + kwargs = target_mock.read_data.call_args.kwargs + # Standard ``read_data`` keyword set — no extra knobs. + assert set(kwargs) == {"catalog", "schema", "table", "query", "options"} + assert kwargs["catalog"] == "test_catalog" + assert kwargs["schema"] == "perf_test" + assert kwargs["table"] == "orders" + # ``options`` MUST be ``None`` on the target side — JDBC reader options + # are source-only. + assert kwargs["options"] is None + + +def test_target_fetch_invokes_filter_subquery_with_tier_and_solver_outputs(target_mock): + """The orchestrator must pass the adaptive tier's ``sub_bucket_count`` (so + the target Stage-2 modulus matches Stage-1 detection) and the solver's + ``solved_hashes`` / ``unsolved_sb_ids`` outputs verbatim to + ``_build_target_filter_subquery``. Pinning this guards against a refactor + accidentally swapping in the static ``constants.SUB_BUCKET_COUNT`` and + breaking sub-bucket alignment between detection and target fetch + (silent MATCH-not-MATCH false positives). + """ + from databricks.labs.lakebridge.reconcile.fingerprint import ( # pylint: disable=import-outside-toplevel + orchestrator as orch, + ) + + ctx = _make_fetch_ctx(target_mock) + fake_hash_builder = MagicMock() + fake_hash_builder.build_query.return_value = _FAKE_TGT_HASH_QUERY + with ( + patch.object(orch, "HashQueryBuilder", return_value=fake_hash_builder), + patch.object(orch, "build_target_filter_subquery", return_value=_FAKE_TARGET_FILTER_SUBQUERY) as filter_builder, + ): + orch._fetch_target_rows( # pylint: disable=protected-access + ctx, + solved_hashes={5: [101, 102], 9: [203]}, + unsolved_sb_ids=[7, 13], + report_type="data", + ) + + filter_builder.assert_called_once() + args, call_kwargs = filter_builder.call_args + # ``build_target_filter_subquery`` positional signature: + # (catalog, schema, table, columns, column_mapping, solved_hashes, unsolved_sb_ids) + # plus keyword-only sub_bucket_count. + assert call_kwargs["sub_bucket_count"] == 2_097_152, ( + "Stage-2 target fetch must use the same adaptive sub_bucket_count as " + "Stage-1 detection — otherwise the target WHERE predicate's sub-bucket " + "modulus won't match the IDs the solver produced and the filter " + "returns no rows (silent MATCH false positive)." + ) + assert args[5] == {5: [101, 102], 9: [203]}, "solved_hashes must reach the helper unchanged" + assert args[6] == [7, 13], "unsolved_sb_ids must reach the helper unchanged" + + +# --- Bug R parity coverage (NEW-1, 2026-05-09) ------------------------------- +# +# The defensive ``replace("%(tbl)s", tgt_filter_subquery)`` at orchestrator +# line ~480 is the same class of guard that fixed Bug R on the source side. +# Today the target ``HashQueryBuilder`` runs against the Databricks/Spark +# dialect and keeps ``:tbl`` literal; tomorrow's sqlglot bump could swap to +# the PostgreSQL pyformat ``%(tbl)s`` rendering. Without this regression +# coverage the defensive substitution would silently rot — a future sqlglot +# upgrade that switches the Spark dialect rendering to ``%(tbl)s`` could +# regress Stage-2 target-fetch into a full Delta scan, producing the same +# phantom-counts symptom Bug R produced on the source side. + + +@pytest.mark.parametrize( + "hash_query_template", + [_FAKE_TGT_HASH_QUERY_NAMED_FORM, _FAKE_TGT_HASH_QUERY_PG_FORM], + ids=["spark-named-:tbl", "pyformat-%(tbl)s"], +) +def test_target_fetch_resolves_both_placeholder_forms_for_dialect_parity(target_mock, hash_query_template): + """Both ``:tbl`` (Spark/named, today's live form) and ``%(tbl)s`` + (Postgres/pyformat, future-proofing against a sqlglot dialect bump) + must be fully replaced by the filter subquery before the query reaches + ``read_data``. + + Pinning both forms guards the Bug-R class on the target side. If a + future sqlglot upgrade switches Spark rendering to ``%(tbl)s`` and the + defensive ``replace("%(tbl)s", ...)`` ever gets removed as "dead code", + this test will fail loudly — instead of shipping a silent Stage-2 + full-Delta-scan regression that only surfaces as inflated + ``missing_in_source`` counts in production audit metrics. + """ + ctx = _make_fetch_ctx(target_mock) + + _patched_fetch_target_rows( + ctx, + solved_hashes={1: [10, 20]}, + unsolved_sb_ids=[], + report_type="data", + hash_query=hash_query_template, + ) + + query = target_mock.read_data.call_args.kwargs["query"] + # Filter subquery must be the only ``FROM`` source — no placeholder of + # either form should reach the connector. + assert _FAKE_TARGET_FILTER_SUBQUERY in query + assert "_fp_filtered" in query + assert ":tbl" not in query + assert "%(tbl)s" not in query, ( + "Bug-R parity guard: even though Spark dialect emits ``:tbl`` today, " + "the defensive ``replace('%(tbl)s', ...)`` at orchestrator line ~480 " + "must remain. A future sqlglot bump that changes Spark rendering to " + "``FROM %(tbl)s`` would otherwise silently leave the placeholder for " + "the connector to substitute with the bare ``.``, " + "scanning the entire Delta target instead of the filtered subset and " + "producing phantom ``missing_in_source`` counts in production " + "``recon_metrics`` rows." + ) + + +def test_target_fetch_no_w2b_machinery_leakage(target_mock): + """Defensive: no W2b experimental machinery should leak back in. Pinning + here for the target side mirrors the source-side guard so a future + "optimisation" attempt cannot silently re-introduce a CTE / temp-table / + materialized-view path that requires Spark / Delta features unavailable + on customer clusters. + """ + ctx = _make_fetch_ctx(target_mock) + + _patched_fetch_target_rows( + ctx, + solved_hashes={1: [10, 20]}, + unsolved_sb_ids=[], + report_type="data", + ) + + query = target_mock.read_data.call_args.kwargs["query"] + assert "OFFSET 0" not in query.upper(), "OFFSET 0 fence reverted on 0.12.8" + assert "_fp_md5_" not in query, "v2 CTE name pattern reverted on 0.12.8" + assert " AS MATERIALIZED " not in query.upper() + assert "CREATE TEMP TABLE" not in query.upper() + assert "CREATE TABLE" not in query.upper() + + +def test_target_fetch_passes_project_all_columns_true(target_mock): + """Stage-2 target-fetch must opt into the all-columns projection. + + Source and target MUST be in lockstep here — if source projects all columns and + target only projects keys, ``capture_mismatch_data_and_columns`` raises because + ``source_columns != target_columns``. Pinning the kwarg here mirrors the + source-side guard. + """ + from databricks.labs.lakebridge.reconcile.fingerprint import ( # pylint: disable=import-outside-toplevel + orchestrator as orch, + ) + + fake_hash_builder = MagicMock() + fake_hash_builder.build_query.return_value = _FAKE_TGT_HASH_QUERY + + ctx = _make_fetch_ctx(target_mock) + with ( + patch.object(orch, "HashQueryBuilder", return_value=fake_hash_builder), + patch.object(orch, "build_target_filter_subquery", return_value=_FAKE_TARGET_FILTER_SUBQUERY), + ): + orch._fetch_target_rows( # pylint: disable=protected-access + ctx, + solved_hashes={1: [10, 20]}, + unsolved_sb_ids=[], + report_type="data", + ) + + fake_hash_builder.build_query.assert_called_once() + assert_project_all_columns_kwargs(fake_hash_builder.build_query.call_args.kwargs, side="target") diff --git a/tests/unit/reconcile/fingerprint/test_metadata.py b/tests/unit/reconcile/fingerprint/test_metadata.py new file mode 100644 index 0000000000..7d7f19ae16 --- /dev/null +++ b/tests/unit/reconcile/fingerprint/test_metadata.py @@ -0,0 +1,180 @@ +"""Unit tests for the fingerprint metadata dataclass and ineligibility classifier. + +Locks the persisted Delta schema contract for ``recon_metrics.fingerprint_metrics`` — +field names, default values, factory shapes, ineligibility-reason enum values. Renaming +any of these is a breaking change for downstream dashboards. +""" + +import pytest + +from databricks.labs.lakebridge.reconcile.fingerprint.metadata import ( + INELIGIBLE_COLUMN_THRESHOLDS_CONFIGURED, + INELIGIBLE_FILTERS_CONFIGURED, + INELIGIBLE_FLAG_DISABLED, + INELIGIBLE_NO_JOIN_COLUMNS, + INELIGIBLE_REPORT_TYPE_NOT_DATA, + INELIGIBLE_TABLE_THRESHOLDS_CONFIGURED, + INELIGIBLE_TRANSFORMS_CONFIGURED, + INELIGIBLE_UNSUPPORTED_DIALECT, + FingerprintRunMetadata, +) +from databricks.labs.lakebridge.reconcile.fingerprint.orchestrator import classify_ineligibility +from databricks.labs.lakebridge.reconcile.recon_config import ( + ColumnThresholds, + Filters, + Table, + TableThresholds, + Transformation, +) + + +def test_metadata_default_state_is_ineligible_zeros(): + """Defaults are safe to write — no eligibility, no verdict, all counters zero.""" + metadata = FingerprintRunMetadata() + assert metadata.eligible is False + assert metadata.ineligibility_reason is None + assert metadata.verdict is None + assert metadata.elapsed_ms == 0 + assert metadata.solved_count == 0 + assert metadata.unsolved_sb_count == 0 + assert metadata.total_mismatched_sbs == 0 + assert metadata.fallback_to_full_pipeline is False + + +def test_disabled_factory_records_flag_disabled_reason(): + """``disabled()`` carries ``ineligibility_reason="flag_disabled"`` so adoption queries + on non-fingerprint reconciles get a non-NULL reason to filter on. + """ + metadata = FingerprintRunMetadata.disabled() + assert metadata.eligible is False + assert metadata.ineligibility_reason == INELIGIBLE_FLAG_DISABLED + + +def test_ineligible_factory_records_supplied_reason(): + metadata = FingerprintRunMetadata.ineligible(INELIGIBLE_FILTERS_CONFIGURED) + assert metadata.eligible is False + assert metadata.ineligibility_reason == INELIGIBLE_FILTERS_CONFIGURED + + +def _eligible_table() -> Table: + """Smallest valid Table for a fingerprint-eligible reconcile.""" + return Table(source_name="t", target_name="t", join_columns=["id"]) + + +def test_classify_ineligibility_eligible_path_returns_none(): + assert ( + classify_ineligibility( + flag_enabled=True, + data_source="redshift", + report_type="data", + table_conf=_eligible_table(), + ) + is None + ) + + +def test_classify_flag_disabled_takes_precedence_over_per_table_config(): + """Flag-off reason wins over per-table ineligibility; otherwise the actual feature + state is hidden in dashboards behind a misleading per-table reason. + """ + table = Table( + source_name="t", + target_name="t", + join_columns=["id"], + filters=Filters(source="x is not null"), + ) + assert ( + classify_ineligibility(flag_enabled=False, data_source="redshift", report_type="data", table_conf=table) + == INELIGIBLE_FLAG_DISABLED + ) + + +def test_classify_unsupported_dialect_takes_precedence_over_per_table_config(): + table = Table( + source_name="t", + target_name="t", + join_columns=["id"], + transformations=[Transformation(column_name="x", source="upper(x)")], + ) + assert ( + classify_ineligibility( + flag_enabled=True, + data_source="snowflake", + report_type="data", + table_conf=table, + ) + == INELIGIBLE_UNSUPPORTED_DIALECT + ) + + +@pytest.mark.parametrize("report_type", ["schema"]) +def test_classify_report_type_not_data(report_type: str): + assert ( + classify_ineligibility( + flag_enabled=True, + data_source="redshift", + report_type=report_type, + table_conf=_eligible_table(), + ) + == INELIGIBLE_REPORT_TYPE_NOT_DATA + ) + + +def test_classify_no_join_columns_blocks_data_path(): + table = Table(source_name="t", target_name="t") + assert ( + classify_ineligibility(flag_enabled=True, data_source="redshift", report_type="data", table_conf=table) + == INELIGIBLE_NO_JOIN_COLUMNS + ) + + +def test_classify_filters_configured(): + table = Table( + source_name="t", + target_name="t", + join_columns=["id"], + filters=Filters(source="created_at > '2024-01-01'"), + ) + assert ( + classify_ineligibility(flag_enabled=True, data_source="redshift", report_type="data", table_conf=table) + == INELIGIBLE_FILTERS_CONFIGURED + ) + + +def test_classify_transforms_configured(): + table = Table( + source_name="t", + target_name="t", + join_columns=["id"], + transformations=[Transformation(column_name="x", source="upper(x)")], + ) + assert ( + classify_ineligibility(flag_enabled=True, data_source="redshift", report_type="data", table_conf=table) + == INELIGIBLE_TRANSFORMS_CONFIGURED + ) + + +def test_classify_column_thresholds_configured(): + table = Table( + source_name="t", + target_name="t", + join_columns=["id"], + column_thresholds=[ColumnThresholds(column_name="amount", lower_bound="-1", upper_bound="1", type="number")], + ) + assert ( + classify_ineligibility(flag_enabled=True, data_source="redshift", report_type="data", table_conf=table) + == INELIGIBLE_COLUMN_THRESHOLDS_CONFIGURED + ) + + +def test_classify_table_thresholds_configured(): + table = Table( + source_name="t", + target_name="t", + join_columns=["id"], + table_thresholds=[TableThresholds(lower_bound="0", upper_bound="10", model="mismatch")], + ) + assert ( + classify_ineligibility(flag_enabled=True, data_source="redshift", report_type="data", table_conf=table) + == INELIGIBLE_TABLE_THRESHOLDS_CONFIGURED + ) diff --git a/tests/unit/reconcile/fingerprint/test_orchestrator.py b/tests/unit/reconcile/fingerprint/test_orchestrator.py new file mode 100644 index 0000000000..31e62a1cb2 --- /dev/null +++ b/tests/unit/reconcile/fingerprint/test_orchestrator.py @@ -0,0 +1,393 @@ +"""Unit tests for fingerprint orchestrator helpers.""" + +from unittest.mock import create_autospec + +import pytest +from pyspark.sql.types import DecimalType + +from databricks.labs.lakebridge.reconcile.connectors.data_source import DataSource +from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils +from databricks.labs.lakebridge.reconcile.fingerprint import spark_target +from databricks.labs.lakebridge.reconcile.fingerprint.engine import DetectionResult, SolveResult +from databricks.labs.lakebridge.reconcile.fingerprint.exceptions import UnmappedTargetColumnMappingError +from databricks.labs.lakebridge.reconcile.fingerprint.orchestrator import ( # pylint: disable=import-private-name + _resolve_detection_columns, + align_columns, + collect_solved_hashes, + fingerprint_supported_sources, + get_query_builder, +) +from databricks.labs.lakebridge.reconcile.recon_config import ColumnMapping +from databricks.labs.lakebridge.reconcile.fingerprint.query_builders.redshift import ( + RedshiftFingerprintQueryBuilder, +) +from databricks.labs.lakebridge.reconcile.recon_config import Schema, Table + + +def test_collect_solved_hashes_merges_same_sub_bucket(): + """Same sub_bucket_id can appear across multiple bucket_id aggregates; hashes must merge.""" + detection = DetectionResult( + verdict="MISMATCH", + solved_results=[ + SolveResult(sub_bucket_id=7, source_hashes=[100], target_hashes=[]), + SolveResult(sub_bucket_id=7, source_hashes=[], target_hashes=[200]), + ], + ) + out = collect_solved_hashes(detection) + assert out[7] == [100, 200] + + +def test_collect_solved_hashes_dedupes_within_and_across_solves(): + """The same hash can appear on both source and target sides AND across multiple + SolveResult rows that share a sub_bucket_id. Dedupe driver-side so the held dict + stays O(distinct hashes); the downstream WHERE-clause set comprehension was already + deduping but only after we paid the memory. + """ + detection = DetectionResult( + verdict="MISMATCH", + solved_results=[ + SolveResult(sub_bucket_id=7, source_hashes=[100, 200], target_hashes=[100]), + SolveResult(sub_bucket_id=7, source_hashes=[100, 300], target_hashes=[]), + ], + ) + out = collect_solved_hashes(detection) + assert out[7] == [100, 200, 300] + + +def test_align_columns_rejects_unmapped_target_column(): + """A typo in ``column_mapping.target_name`` raises a typed exception so the + trigger layer can record ``UNMAPPED_TARGET_COLUMN_MAPPING`` on the persisted + metric (instead of a silent ``None`` fallback that adoption queries can't see). + """ + table_conf = Table( + source_name="orders", + target_name="orders", + join_columns=["order_id"], + column_mapping=[ColumnMapping(source_name="src_a", target_name="tgt_a_typo")], + ) + src_schema = [ + Schema("`src_a`", "int", "`src_a`", '"src_a"'), + Schema("`order_id`", "bigint", "`order_id`", '"order_id"'), + ] + tgt_schema = [ + Schema("`tgt_a`", "int", "`tgt_a`", "`tgt_a`"), + Schema("`order_id`", "bigint", "`order_id`", "`order_id`"), + ] + with pytest.raises(UnmappedTargetColumnMappingError, match="tgt_a_typo"): + align_columns(table_conf, src_schema, tgt_schema) + + +def test_align_columns_accepts_validated_target_column_mapping(): + """The happy path: a real target column name passes.""" + table_conf = Table( + source_name="orders", + target_name="orders", + join_columns=["order_id"], + column_mapping=[ColumnMapping(source_name="src_a", target_name="tgt_a")], + ) + src_schema = [ + Schema("`src_a`", "int", "`src_a`", '"src_a"'), + Schema("`order_id`", "bigint", "`order_id`", '"order_id"'), + ] + tgt_schema = [ + Schema("`tgt_a`", "int", "`tgt_a`", "`tgt_a`"), + Schema("`order_id`", "bigint", "`order_id`", "`order_id`"), + ] + alignment = align_columns(table_conf, src_schema, tgt_schema) + assert alignment is not None + assert alignment.column_mapping == {"src_a": "tgt_a"} + + +def test_query_builder_registry_returns_redshift_builder(): + builder = get_query_builder("redshift") + assert isinstance(builder, RedshiftFingerprintQueryBuilder) + + +def test_query_builder_default_does_not_collapse_empty_to_null(): + """The dispatched builder must default to ``treat_empty_as_null=False`` so the + fingerprint serialization keeps '' distinct from NULL, matching the row-hash + convention in expression_generator (TRIM does not collapse '' to NULL). + """ + builder = get_query_builder("redshift") + assert builder._treat_empty_as_null is False # pylint: disable=protected-access + serialized = builder.serialize_column("notes", "VARCHAR") + assert "NULLIF" not in serialized + + +def test_query_builder_registry_rejects_unknown_source(): + with pytest.raises(ValueError, match="No fingerprint query builder registered"): + get_query_builder("mysql") + + +def test_fingerprint_supported_sources_contains_redshift(): + """fingerprint_supported_sources is the source of truth for the eligibility guard.""" + supported = fingerprint_supported_sources() + assert "redshift" in supported + for source in supported: + assert get_query_builder(source) is not None + + +def test_resolve_detection_columns_strips_identifier_delimiters(): + """Schema entries are ANSI-delimited via _map_meta_column; user-supplied join_columns + are bare. The resolver must reconcile both forms (and dedupe across raw / delimited + overlap) so fingerprint isn't silently disabled on every real connector. + """ + src_schema = [ + Schema("`color`", "varchar(2)", "`color`", '"color"'), + Schema("`clarity`", "varchar(5)", "`clarity`", '"clarity"'), + Schema("`carat`", "decimal(5,2)", "`carat`", '"carat"'), + ] + table_conf = Table( + source_name="diamonds", + target_name="diamonds", + join_columns=["color", "clarity"], + select_columns=["color", "clarity"], + ) + source = create_autospec(DataSource, instance=True) + source.normalize_identifier.side_effect = lambda ident: type( + "NI", (), {"ansi_normalized": f"`{ident}`", "source_normalized": f'"{ident}"'} + )() + + resolved = _resolve_detection_columns(table_conf, src_schema, source) + + assert resolved is not None + resolved_names = [_strip_delim(s.column_name) for s in resolved] + assert sorted(resolved_names) == ["clarity", "color"] + + +def _strip_delim(name: str) -> str: + return DialectUtils.unnormalize_identifier(name) + + +def test_build_mismatch_output_backfills_mismatch_columns_for_report_all(monkeypatch): + """Unit-level wiring check for the ``mismatch_columns`` backfill. + + The fingerprint MISMATCH path calls ``build_mismatch_output`` -> + ``compare.reconcile_data``. ``reconcile_data`` populates ``mismatch_df`` but + leaves ``mismatch_columns`` at its default. Because the fingerprint Stage-2 + frames already carry every hashed column (``project_all_columns=True``), the + orchestrator must backfill ``mismatch_columns`` here instead of waiting for + the normal-path ``_get_sample_data`` -> ``capture_mismatch_data_and_columns`` + (which the fingerprint path skips entirely). + + Without this backfill, every ``report_type='all'`` cell would land with + ``mismatch_columns=[]`` even though the row counts were right. + """ + from databricks.labs.lakebridge.reconcile.fingerprint import ( # pylint: disable=import-outside-toplevel + orchestrator, + ) + from databricks.labs.lakebridge.reconcile.recon_output_config import ( # pylint: disable=import-outside-toplevel + DataReconcileOutput, + MismatchOutput, + ) + + captured_calls: dict = {} + + # Lightweight fake DataFrame: tracks the chain of `.filter` and + # `.withColumn` calls so the test can assert the orchestrator builds a + # wide-shape mismatch_df with a per-row `mismatch_columns` string column. + class FakeWideDF: + def __init__(self, cols, parent_chain=None): + self.columns = list(cols) + self.chain = list(parent_chain or []) + + def filter(self, expr_obj): + new = FakeWideDF(self.columns, self.chain) + new.chain.append(("filter", str(expr_obj))) + return new + + def withColumn(self, name, expr_obj): # pylint: disable=invalid-name # mirrors PySpark API + new = FakeWideDF(self.columns + [name], self.chain) + new.chain.append(("withColumn", name, str(expr_obj))) + return new + + fake_skinny_mismatch_df = object() + # capture.mismatch_df has _base/_compare/_match for every check column. + fake_wide_capture_df = FakeWideDF( + [ + "s_suppkey", + "s_nationkey", + "s_name_base", + "s_name_compare", + "s_name_match", + "s_acctbal_base", + "s_acctbal_compare", + "s_acctbal_match", + ] + ) + + # pylint: disable=unused-argument # ``persistence`` accepted to match the real signature. + def fake_compare_reconcile_data(*, source, target, key_columns, report_type, persistence): + captured_calls["compare_reconcile_data"] = { + "source": source, + "target": target, + "key_columns": key_columns, + "report_type": report_type, + } + return DataReconcileOutput( + mismatch_count=3, + missing_in_src_count=0, + missing_in_tgt_count=0, + mismatch=MismatchOutput(mismatch_df=fake_skinny_mismatch_df, mismatch_columns=None), + ) + + def fake_capture_mismatch_data_and_columns(*, source, target, key_columns): + captured_calls["capture_mismatch_data_and_columns"] = { + "source_columns": list(source.columns), # must NOT contain hash_value_recon + "target_columns": list(target.columns), + "key_columns": key_columns, + } + return MismatchOutput(mismatch_df=fake_wide_capture_df, mismatch_columns=["s_name", "s_acctbal"]) + + monkeypatch.setattr(orchestrator, "compare_reconcile_data", fake_compare_reconcile_data) + monkeypatch.setattr(orchestrator, "capture_mismatch_data_and_columns", fake_capture_mismatch_data_and_columns) + + # Build minimal stand-ins: only need .columns and .drop(). + class FakeDF: + def __init__(self, cols): + self.columns = list(cols) + + def drop(self, name): + return FakeDF([c for c in self.columns if c != name]) + + src = FakeDF(["s_suppkey", "s_nationkey", "s_name", "s_acctbal", "hash_value_recon"]) + tgt = FakeDF(["s_suppkey", "s_nationkey", "s_name", "s_acctbal", "hash_value_recon"]) + + out = orchestrator.build_mismatch_output( + src_hashed=src, + tgt_hashed=tgt, + key_columns=["s_suppkey", "s_nationkey"], + report_type="all", + persistence=None, + ) + + # mismatch_columns must be the list capture_mismatch returned, not the empty default. + assert out.mismatch.mismatch_columns == ["s_name", "s_acctbal"] + # mismatch_df must be the WIDE shape from capture (so recon_details + # carries `_base/_compare/_match` plus the appended `mismatch_columns`), + # NOT the skinny shape from compare.reconcile_data. + assert out.mismatch.mismatch_df is not fake_skinny_mismatch_df + # The orchestrator must filter on at least-one-_match-false AND append + # a `mismatch_columns` string column. + chain = out.mismatch.mismatch_df.chain + + # NULL-safety contract: every ``_match`` MUST be recomputed from + # ``_base <=> _compare`` (null-safe equality) before the + # filter / mismatch_columns expression runs. + # + # ``compare._get_mismatch_df`` builds ``_match`` with bare ``=``, which + # returns NULL for any cell where either side is NULL. That's + # ambiguous - it could mean "differs" (``NULL <-> value``) or "matches" + # (``NULL <-> NULL``). Without disambiguation we either drop legit + # mismatches (Track 1 v3: 32/15) or over-report unchanged NULL columns + # (Track 1 v4: 45/3, with three rows reporting ``notes`` as mismatched + # when both sides were NULL the entire time). ``<=>`` yields a non-null + # BOOLEAN that means exactly ``not differs``, so the downstream filter + # and case-when work without COALESCE wrappers and the per-row + # mismatch_columns aligns with the table-level metric. + recompute_ops = [step for step in chain if step[0] == "withColumn" and step[1].endswith("_match")] + assert recompute_ops, f"expected at least one _match recomputed via <=>, got {chain}" + for step in recompute_ops: + assert "<=>" in step[2], f"_match recompute must use <=> for null-safe equality, got {step[2]}" + + filter_ops = [step for step in chain if step[0] == "filter"] + assert filter_ops, f"expected a filter on at-least-one-_match-false, got {chain}" + filt_expr = filter_ops[-1][1] + assert "NOT" in filt_expr and "_match" in filt_expr + + mismatch_col_ops = [step for step in chain if step[0] == "withColumn" and step[1] == "mismatch_columns"] + assert mismatch_col_ops, f"expected mismatch_columns withColumn, got {chain}" + mismatch_expr = mismatch_col_ops[-1][2] + assert "concat_ws" in mismatch_expr.lower() or "CASE WHEN" in mismatch_expr + # The wide df now has `mismatch_columns` as its last column, so + # `_create_map_column` will write it into recon_details. + assert "mismatch_columns" in out.mismatch.mismatch_df.columns + assert out.mismatch_count == 3 + # capture_mismatch_data_and_columns must NOT see hash_value_recon. + assert "hash_value_recon" not in captured_calls["capture_mismatch_data_and_columns"]["source_columns"] + assert "hash_value_recon" not in captured_calls["capture_mismatch_data_and_columns"]["target_columns"] + + +def test_build_mismatch_output_skips_capture_for_report_data(monkeypatch): + """For ``report_type='data'`` we don't need column-level diff; the orchestrator + must skip ``capture_mismatch_data_and_columns`` entirely (it's an O(driver-collect) + operation on the mismatch_df). + """ + from databricks.labs.lakebridge.reconcile.fingerprint import ( # pylint: disable=import-outside-toplevel + orchestrator, + ) + from databricks.labs.lakebridge.reconcile.recon_output_config import ( # pylint: disable=import-outside-toplevel + DataReconcileOutput, + MismatchOutput, + ) + + capture_call_count = {"n": 0} + + def fake_compare_reconcile_data(*_args, **_kwargs): + return DataReconcileOutput( + mismatch_count=5, + mismatch=MismatchOutput(mismatch_df=object(), mismatch_columns=None), + ) + + def fake_capture(*_args, **_kwargs): + capture_call_count["n"] += 1 + return MismatchOutput(mismatch_df=object(), mismatch_columns=["should_not_appear"]) + + monkeypatch.setattr(orchestrator, "compare_reconcile_data", fake_compare_reconcile_data) + monkeypatch.setattr(orchestrator, "capture_mismatch_data_and_columns", fake_capture) + + out = orchestrator.build_mismatch_output( + src_hashed=None, + tgt_hashed=None, + key_columns=["k"], + report_type="data", + persistence=None, + ) + + assert capture_call_count["n"] == 0 + assert out.mismatch.mismatch_columns is None + + +def test_build_mismatch_output_skips_capture_when_no_mismatches(monkeypatch): + """If mismatch_count == 0, skip the capture call regardless of report_type.""" + from databricks.labs.lakebridge.reconcile.fingerprint import ( # pylint: disable=import-outside-toplevel + orchestrator, + ) + from databricks.labs.lakebridge.reconcile.recon_output_config import ( # pylint: disable=import-outside-toplevel + DataReconcileOutput, + MismatchOutput, + ) + + capture_call_count = {"n": 0} + + def fake_compare_reconcile_data(*_args, **_kwargs): + return DataReconcileOutput(mismatch_count=0, mismatch=MismatchOutput()) + + def fake_capture(*_args, **_kwargs): + capture_call_count["n"] += 1 + return MismatchOutput() + + monkeypatch.setattr(orchestrator, "compare_reconcile_data", fake_compare_reconcile_data) + monkeypatch.setattr(orchestrator, "capture_mismatch_data_and_columns", fake_capture) + + out = orchestrator.build_mismatch_output( + src_hashed=None, + tgt_hashed=None, + key_columns=["k"], + report_type="all", + persistence=None, + ) + + assert capture_call_count["n"] == 0 + assert out.mismatch_count == 0 + + +def test_spark_target_uses_decimal_precision_for_hash_aggregates(): + """Spark target must mirror the Redshift DECIMAL(19,0)/DECIMAL(38,0) precision. + + LongType silently wraps on rh*rh for rh > 2^31, so two rows with large hashes + can produce equal-but-wrong p2 sums on Spark while Redshift raises a hard + overflow — making the engine join report false MATCH. + """ + assert spark_target._RH_OPERAND_TYPE == DecimalType(19, 0) # pylint: disable=protected-access + assert spark_target._AGG_TYPE == DecimalType(38, 0) # pylint: disable=protected-access diff --git a/tests/unit/reconcile/fingerprint/test_redshift_fingerprint_query.py b/tests/unit/reconcile/fingerprint/test_redshift_fingerprint_query.py new file mode 100644 index 0000000000..a1078504b9 --- /dev/null +++ b/tests/unit/reconcile/fingerprint/test_redshift_fingerprint_query.py @@ -0,0 +1,316 @@ +"""Redshift fingerprint SQL builder tests.""" + +from databricks.labs.lakebridge.reconcile.fingerprint.query_builders.redshift import ( # pylint: disable=import-private-name + RedshiftFingerprintQueryBuilder, + _quote_redshift_identifier, +) +from databricks.labs.lakebridge.reconcile.recon_config import Schema + + +def test_concat_uses_source_names_not_target_mapping(): + """Detection concat uses source physical names; target mapping applies on Spark only.""" + builder = RedshiftFingerprintQueryBuilder(treat_empty_as_null=True) + cols = [ + Schema("`src_a`", "varchar", "`src_a`", '"src_a"'), + Schema("`src_b`", "int", "`src_b`", '"src_b"'), + ] + sql = builder.build_concat_expression(cols) + assert '"src_a"' in sql + assert '"src_b"' in sql + assert '"`src_a`"' not in sql, f"Doubly-quoted column reference in {sql!r}" + assert "tgt_a" not in sql + assert "tgt_b" not in sql + + +def test_source_filter_subquery_uses_detection_columns(): + """Filter subquery WHERE uses MD5 over the same detection columns.""" + builder = RedshiftFingerprintQueryBuilder() + detection_cols = [ + Schema("`a`", "int", "`a`", '"a"'), + Schema("`join_key`", "int", "`join_key`", '"join_key"'), + ] + sql = builder.build_source_filter_subquery( + schema="public", + table="t", + columns=detection_cols, + sub_bucket_count=1024, + solved_hashes={0: [1]}, + unsolved_sb_ids=[], + ) + assert '"a"' in sql + assert '"join_key"' in sql + assert '"`a`"' not in sql + assert 'FROM "public"."t"' in sql + assert "MD5" in sql + assert "_fp_filtered" in sql + + +def test_detection_sql_uses_decimal_precision_for_hash_aggregates(): + """Cast rh*rh operands to DECIMAL(19,0) so the product is DECIMAL(38,0); SUM lifts + linear rh aggregates to DECIMAL(38,0). BIGINT in the multiply would overflow. + """ + builder = RedshiftFingerprintQueryBuilder(treat_empty_as_null=True) + cols = [Schema("`a`", "int", "`a`", '"a"')] + + sql = builder.build_detection_sql( + schema="public", + table="t", + columns=cols, + column_mapping=None, + sub_bucket_count=1024, + bucket_count=8192, + ) + + assert "CAST(STRTOL(SUBSTRING(MD5" in sql + assert "AS DECIMAL(38,0))" in sql + assert "AS DECIMAL(19,0))" in sql + assert "AS BIGINT) * CAST(" not in sql, f"BIGINT in rh*rh overflows on Redshift: {sql!r}" + + +def test_quote_redshift_identifier_doubles_embedded_double_quote(): + """Defense-in-depth: a column name containing a literal ``"`` must not break the SQL.""" + assert _quote_redshift_identifier("plain") == '"plain"' + assert _quote_redshift_identifier('we"ird') == '"we""ird"' + assert _quote_redshift_identifier('""') == '""""""' + + +def test_serialize_column_strips_ansi_delimiters_and_handles_reserved_word(): + """ANSI-delimited names round-trip into Redshift's double-quoted form. + + Naively wrapping the inbound delimited form in source quotes produces a literal + column name Redshift rejects. Strip first, re-quote once. + """ + builder = RedshiftFingerprintQueryBuilder(treat_empty_as_null=True) + serialized = builder.serialize_column("`table`", "integer") + assert 'CAST("table" AS VARCHAR(65535))' in serialized + assert "`table`" not in serialized + assert '"`table`"' not in serialized + + +def test_serialize_column_boolean_uses_case_when_not_cast_as_varchar(): + """Redshift rejects every BOOLEAN -> string cast form (CAST AS VARCHAR/TEXT, ::TEXT). + Use CASE WHEN producing lowercase 'true'/'false' so the MD5 stays bit-identical + with Spark's ``cast(bool AS string)``. NULL flows through ELSE NULL to the outer + COALESCE sentinel. + """ + builder = RedshiftFingerprintQueryBuilder(treat_empty_as_null=True) + serialized = builder.serialize_column("`is_priority`", "boolean") + + assert "CAST(" not in serialized.upper().replace( + "CASE WHEN", "" + ), f"BOOLEAN must not emit CAST(...): {serialized!r}" + assert "::TEXT" not in serialized.upper(), f"BOOLEAN must not emit ::TEXT: {serialized!r}" + + assert "CASE WHEN" in serialized + assert "'true'" in serialized + assert "'false'" in serialized + assert "ELSE NULL" in serialized + assert "COALESCE(" in serialized + assert '"is_priority"' in serialized + assert "`is_priority`" not in serialized + + +def test_serialize_column_non_temporal_non_boolean_uses_cast_as_varchar(): + """Numeric / string / etc. types take the default ``CAST(... AS VARCHAR)`` path. + + BOOLEAN, DATE, TIMESTAMP and TIMESTAMPTZ each take a dedicated branch so + the byte stream matches the row-hash compare path's ``TO_CHAR(...)``; this + test pins the inverse — non-temporal types must NOT accidentally route + through TO_CHAR. + """ + builder = RedshiftFingerprintQueryBuilder(treat_empty_as_null=True) + for col_type in ("integer", "bigint", "numeric(18,2)", "character varying(64)"): + serialized = builder.serialize_column("`some_col`", col_type) + assert 'CAST("some_col" AS VARCHAR(65535))' in serialized, f"{col_type}: {serialized!r}" + assert "CASE WHEN" not in serialized, f"{col_type}: {serialized!r}" + assert "AT TIME ZONE" not in serialized, f"{col_type}: {serialized!r}" + assert "TO_CHAR(" not in serialized, f"{col_type}: {serialized!r}" + + +def test_serialize_column_timestamptz_uses_to_char_with_fixed_microsecond_format(): + """``TO_CHAR(_ AT TIME ZONE 'UTC', 'YYYY-MM-DD HH24:MI:SS.US')`` + matches the row-hash compare path's per-row format and the Spark target's + ``DATE_FORMAT(_, 'yyyy-MM-dd HH:mm:ss.SSSSSS')`` byte-for-byte. Bare + ``CAST(timestamptz AS VARCHAR)`` would emit variable-width fractional seconds + plus a ``+00`` suffix and silently disagree with both siblings. + """ + builder = RedshiftFingerprintQueryBuilder(treat_empty_as_null=False) + + for col_type in ("timestamptz", "TIMESTAMPTZ", "timestamp with time zone", "TIMESTAMP WITH TIME ZONE"): + serialized = builder.serialize_column("`created_at_tz`", col_type) + assert "AT TIME ZONE 'UTC'" in serialized, f"{col_type}: {serialized!r}" + assert "TO_CHAR(" in serialized, f"{col_type}: {serialized!r}" + assert "'YYYY-MM-DD HH24:MI:SS.US'" in serialized, f"{col_type}: {serialized!r}" + # Forbid the bare cast (variable-width microseconds + ``+00`` suffix). + assert 'CAST("created_at_tz" AS VARCHAR(65535))' not in serialized + assert 'CAST("created_at_tz" AT TIME ZONE \'UTC\' AS VARCHAR(65535))' not in serialized + + assert '"created_at_tz"' in serialized + assert '"`created_at_tz`"' not in serialized + assert "COALESCE(" in serialized.upper() + assert "_null_recon_" in serialized + + +def test_serialize_column_timestamp_without_tz_uses_to_char_yyyy_mm_dd_hh_mi_ss_us(): + """``timestamp`` / ``timestamp without time zone`` take the same + ``TO_CHAR`` formatter (no ``AT TIME ZONE`` because the value carries no zone). + """ + builder = RedshiftFingerprintQueryBuilder(treat_empty_as_null=False) + for col_type in ("timestamp", "TIMESTAMP", "timestamp without time zone"): + serialized = builder.serialize_column("`event_ts`", col_type) + assert "TO_CHAR(\"event_ts\", 'YYYY-MM-DD HH24:MI:SS.US')" in serialized, f"{col_type}: {serialized!r}" + assert "AT TIME ZONE" not in serialized, f"{col_type}: {serialized!r}" + assert 'CAST("event_ts" AS VARCHAR(65535))' not in serialized + + +def test_serialize_column_date_uses_to_char_yyyy_mm_dd(): + """``date`` formats via ``TO_CHAR(_, 'YYYY-MM-DD')`` to match the + row-hash compare path and the Spark target's ``DATE_FORMAT(_, 'yyyy-MM-dd')``. + """ + builder = RedshiftFingerprintQueryBuilder(treat_empty_as_null=False) + serialized = builder.serialize_column("`event_dt`", "date") + assert "TO_CHAR(\"event_dt\", 'YYYY-MM-DD')" in serialized, serialized + assert 'CAST("event_dt" AS VARCHAR(65535))' not in serialized + + +def test_serialize_column_timestamptz_respects_treat_empty_as_null(): + """TIMESTAMPTZ wraps with NULLIF when ``treat_empty_as_null=True``, for symmetry with + the Spark target serializer. + """ + builder = RedshiftFingerprintQueryBuilder(treat_empty_as_null=True) + serialized = builder.serialize_column("`created_at_tz`", "timestamptz") + assert "NULLIF(" in serialized.upper() + assert "COALESCE(NULLIF(" in serialized.upper() + + +# --------------------------------------------------------------------------- +# Stage-1 <-> Stage-2 whitespace-handling symmetry +# --------------------------------------------------------------------------- +# +# Lakebridge's row-hash compare (``expression_generator`` universal default) +# uses ``COALESCE(TRIM(_), '_null_recon_')`` -- whitespace-insensitive. If the +# fingerprint Stage-1 builder were to emit ``COALESCE(CAST AS VARCHAR, +# '_null_recon_')`` (whitespace-sensitive), the asymmetry would produce a silent +# correctness gap: a row whose only difference is trailing whitespace would be +# flagged by Stage-1 (different MD5) but absorbed by Stage-2 (same SHA2 after +# TRIM), so the row would be fetched then dropped from the recon output. +# +# Contract: TRIM the cast string before COALESCE on both sides. + + +def test_serialize_column_default_path_pins_max_varchar_width(): + """Default path: ``COALESCE(TRIM(CAST(_ AS VARCHAR(65535))), '_null_recon_')``. + + Bare ``CAST(_ AS VARCHAR)`` in Redshift defaults to ``VARCHAR(256)`` and + silently truncates anything longer; the Spark target keeps the full string + so the asymmetry would surface a Stage-1 false-mismatch on every long-text + row. ``VARCHAR(65535)`` is Redshift's maximum width and matches Spark's + unbounded string semantics. + """ + builder = RedshiftFingerprintQueryBuilder(treat_empty_as_null=False) + serialized = builder.serialize_column("`amount`", "numeric(18,2)") + + assert serialized == 'COALESCE(TRIM(CAST("amount" AS VARCHAR(65535))), \'_null_recon_\')' + # Belt-and-braces structural assertions (cheaper to grep on regression): + assert serialized.startswith("COALESCE(TRIM("), serialized + assert "CAST(" in serialized + assert "_null_recon_" in serialized + + +def test_serialize_column_treat_empty_as_null_keeps_trim_inside_nullif(): + """``treat_empty_as_null=True``: ``COALESCE(NULLIF(TRIM(...), ''), '_null_recon_')``. + + TRIM stays innermost so an all-whitespace value (e.g. ``' '``) collapses to + ``''`` and then NULLIF maps it to NULL -- preserving the ``treat_empty_as_null`` + intent while gaining whitespace-insensitive matching. + """ + builder = RedshiftFingerprintQueryBuilder(treat_empty_as_null=True) + serialized = builder.serialize_column("`name`", "varchar(64)") + + assert serialized == 'COALESCE(NULLIF(TRIM(CAST("name" AS VARCHAR(65535))), \'\'), \'_null_recon_\')' + assert "NULLIF(TRIM(" in serialized + + +def test_serialize_column_boolean_emits_trim_around_case_when(): + """BOOLEAN handler stays CASE WHEN producing ``'true'/'false'``; TRIM wraps it. + + TRIM on those literals is a no-op but the structural symmetry with all other + types matters for review/audit -- there should be exactly one place we decide + how to wrap the cast expression. + """ + builder = RedshiftFingerprintQueryBuilder(treat_empty_as_null=False) + serialized = builder.serialize_column("`is_priority`", "boolean") + + assert "TRIM(CASE WHEN " in serialized + assert "'true'" in serialized + assert "'false'" in serialized + assert serialized.endswith(", '_null_recon_')") + + +def test_serialize_column_timestamptz_emits_trim_around_at_time_zone_cast(): + """TIMESTAMPTZ handler produces ``TO_CHAR(_ AT TIME ZONE 'UTC', '...')``; + TRIM wraps it. TRIM on a timestamp string is a no-op for a well-formed value + but is still applied for cross-type symmetry. + """ + builder = RedshiftFingerprintQueryBuilder(treat_empty_as_null=False) + serialized = builder.serialize_column("`created_at_tz`", "timestamptz") + + assert "AT TIME ZONE 'UTC'" in serialized + assert serialized.endswith(", '_null_recon_')") + + +def test_serialize_column_redshift_and_spark_share_trim_contract(): + """Source-side TRIM exists in the SQL; target-side TRIM must exist in the Spark + serializer for the per-row MD5 to be byte-aligned. Cross-checks the contract + holds at the textual level on the Redshift side; the Spark side is pinned by + ``test_spark_target_serialization.py``.""" + builder = RedshiftFingerprintQueryBuilder(treat_empty_as_null=False) + for col_type in ("integer", "varchar(32)", "timestamp without time zone", "boolean", "timestamptz"): + serialized = builder.serialize_column("`x`", col_type) + assert "TRIM(" in serialized, f"{col_type}: {serialized!r}" + assert "_null_recon_" in serialized, f"{col_type}: {serialized!r}" + + +def test_null_sentinel_parity_with_row_hash_path(): + """The fingerprint NULL stand-in must match the row-hash literal in + ``expression_generator``. Drift here aliases real data ``'_null_recon_'`` + with NULL on only the fingerprint side, so any row that carries that literal + is silently misclassified during Stage-1 sub-bucket aggregation. + + This guard is intentionally a string-literal grep against the source file + rather than an import: if a refactor moves the row-hash sentinel to a named + constant, the grep failure points to *both* call sites in one CI run instead + of letting the import succeed against a renamed but mismatched value. + """ + from pathlib import Path + + from databricks.labs.lakebridge.reconcile.fingerprint.constants import NULL_SENTINEL + + expr_gen = ( + Path(__file__).resolve().parents[3] + / "src" + / "databricks" + / "labs" + / "lakebridge" + / "reconcile" + / "query_builder" + / "expression_generator.py" + ) + if not expr_gen.exists(): + # Repo layout fallback for installed-package tests. + expr_gen = ( + Path(__file__).resolve().parents[4] + / "src" + / "databricks" + / "labs" + / "lakebridge" + / "reconcile" + / "query_builder" + / "expression_generator.py" + ) + text = expr_gen.read_text(encoding="utf-8") + assert f"'{NULL_SENTINEL}'" in text, ( + f"fingerprint NULL_SENTINEL={NULL_SENTINEL!r} not found in row-hash " + f"expression_generator.py — sentinels have drifted; rows whose data " + f"contains {NULL_SENTINEL!r} would be silently misclassified." + ) diff --git a/tests/unit/reconcile/fingerprint/test_row_count.py b/tests/unit/reconcile/fingerprint/test_row_count.py new file mode 100644 index 0000000000..6da838a3ff --- /dev/null +++ b/tests/unit/reconcile/fingerprint/test_row_count.py @@ -0,0 +1,221 @@ +"""Unit tests for the fingerprint target row-count fetcher. + +Three-step fallback chain: + 1. ``override_row_count`` -> USER_OVERRIDE + 2. ``DESCRIBE DETAIL`` succeeds -> DELTA_DESCRIBE_DETAIL + 3. Any failure -> STATIC_DEFAULT with row_count=None + +The fetcher must never raise — tier selection is best-effort. +""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest +from pyspark.sql.utils import AnalysisException + +from databricks.labs.lakebridge.reconcile.fingerprint.row_count import ( + RowCountResult, + RowCountSource, + fetch_target_row_count, +) + + +def _make_describe_detail_df(*, columns: list[str], rows: list[dict]) -> MagicMock: + """Mock DataFrame mimicking ``DESCRIBE DETAIL`` (columns + select.collect).""" + df = MagicMock() + df.columns = columns + + if rows is None or "numRecords" not in columns: + return df + + select_result = MagicMock() + select_result.collect.return_value = [_RowLike(r) for r in rows] + df.select.return_value = select_result + return df + + +class _RowLike: + """Mimics PySpark Row ``row['key']`` access.""" + + def __init__(self, mapping: dict): + self._mapping = mapping + + def __getitem__(self, key): + return self._mapping[key] + + +def _make_spark(describe_detail_df: MagicMock | Exception) -> MagicMock: + """Mock SparkSession whose ``.sql()`` returns the DataFrame or raises the exception.""" + spark = MagicMock() + if isinstance(describe_detail_df, Exception): + spark.sql.side_effect = describe_detail_df + else: + spark.sql.return_value = describe_detail_df + return spark + + +# --- Path 1: user override ---------------------------------------------------- + + +def test_user_override_short_circuits_chain(): + """Positive override skips DESCRIBE DETAIL entirely.""" + spark = MagicMock() + result = fetch_target_row_count( + spark, + catalog="test_catalog", + schema="perf_test", + table="orders", + override_row_count=100_000_000, + ) + assert result == RowCountResult(row_count=100_000_000, source=RowCountSource.USER_OVERRIDE) + spark.sql.assert_not_called() + + +@pytest.mark.parametrize("override", [0, -1, None]) +def test_user_override_zero_or_none_falls_through_to_describe_detail(override): + """Non-positive overrides are treated as "not given"; DESCRIBE DETAIL still runs.""" + df = _make_describe_detail_df(columns=["numRecords"], rows=[{"numRecords": 42}]) + spark = _make_spark(df) + result = fetch_target_row_count( + spark, + catalog="test_catalog", + schema="perf_test", + table="orders", + override_row_count=override, + ) + assert result.source == RowCountSource.DELTA_DESCRIBE_DETAIL + assert result.row_count == 42 + + +# --- Path 2: DESCRIBE DETAIL success ------------------------------------------ + + +def test_describe_detail_returns_num_records_for_delta_table(): + df = _make_describe_detail_df( + columns=["format", "id", "name", "numFiles", "numRecords", "createdAt"], + rows=[{"numRecords": 100_000_000}], + ) + spark = _make_spark(df) + result = fetch_target_row_count(spark, catalog="test_catalog", schema="perf_test", table="orders") + assert result == RowCountResult(row_count=100_000_000, source=RowCountSource.DELTA_DESCRIBE_DETAIL) + spark.sql.assert_called_once_with("DESCRIBE DETAIL test_catalog.perf_test.orders") + + +def test_describe_detail_works_without_catalog(): + """Two-part ``schema.table`` naming for hive_metastore-style references.""" + df = _make_describe_detail_df(columns=["numRecords"], rows=[{"numRecords": 1_000}]) + spark = _make_spark(df) + result = fetch_target_row_count(spark, catalog=None, schema="default", table="orders") + assert result.row_count == 1_000 + spark.sql.assert_called_once_with("DESCRIBE DETAIL default.orders") + + +def test_describe_detail_zero_rows_is_legitimate(): + """Empty-table case: numRecords=0 is a valid result, not a fall-through.""" + df = _make_describe_detail_df(columns=["numRecords"], rows=[{"numRecords": 0}]) + spark = _make_spark(df) + result = fetch_target_row_count(spark, catalog="test_catalog", schema="perf_test", table="orders") + assert result == RowCountResult(row_count=0, source=RowCountSource.DELTA_DESCRIBE_DETAIL) + + +# --- Path 3: fallback to STATIC_DEFAULT --------------------------------------- + + +def test_table_not_found_falls_back_to_static_default(): + """``AnalysisException`` must not propagate; tier selection is best-effort.""" + spark = _make_spark(AnalysisException("Table or view not found: test_catalog.perf_test.bogus")) + result = fetch_target_row_count(spark, catalog="test_catalog", schema="perf_test", table="bogus") + assert result == RowCountResult(row_count=None, source=RowCountSource.STATIC_DEFAULT) + + +def test_unexpected_exception_falls_back_to_static_default(): + """Unexpected errors must not propagate (tier selection is best-effort).""" + spark = _make_spark(RuntimeError("kerberos creds expired")) + result = fetch_target_row_count(spark, catalog="test_catalog", schema="perf_test", table="orders") + assert result == RowCountResult(row_count=None, source=RowCountSource.STATIC_DEFAULT) + + +def test_non_delta_target_with_no_num_records_column_falls_back(): + """Non-Delta target — DESCRIBE DETAIL succeeds but the column is absent.""" + df = _make_describe_detail_df( + columns=["format", "id", "name"], # no numRecords + rows=[], + ) + spark = _make_spark(df) + result = fetch_target_row_count(spark, catalog="test_catalog", schema="perf_test", table="orders") + assert result == RowCountResult(row_count=None, source=RowCountSource.STATIC_DEFAULT) + + +def test_describe_detail_returning_zero_rows_falls_back(): + """Defensive: zero rows from DESCRIBE DETAIL must not IndexError.""" + df = _make_describe_detail_df(columns=["numRecords"], rows=[]) + spark = _make_spark(df) + result = fetch_target_row_count(spark, catalog="test_catalog", schema="perf_test", table="orders") + assert result == RowCountResult(row_count=None, source=RowCountSource.STATIC_DEFAULT) + + +def test_describe_detail_returning_null_num_records_falls_back(): + """numRecords NULL (per-file stats disabled) must fall through, not feed None to the tier selector.""" + df = _make_describe_detail_df(columns=["numRecords"], rows=[{"numRecords": None}]) + spark = _make_spark(df) + result = fetch_target_row_count(spark, catalog="test_catalog", schema="perf_test", table="orders") + assert result == RowCountResult(row_count=None, source=RowCountSource.STATIC_DEFAULT) + + +def test_describe_detail_returning_unexpected_type_falls_back(): + """Defensive against driver/SDK drift: non-int numRecords falls through.""" + df = _make_describe_detail_df(columns=["numRecords"], rows=[{"numRecords": "100000000"}]) + spark = _make_spark(df) + result = fetch_target_row_count(spark, catalog="test_catalog", schema="perf_test", table="orders") + assert result == RowCountResult(row_count=None, source=RowCountSource.STATIC_DEFAULT) + + +def test_describe_detail_returning_negative_num_records_falls_back(): + """Negative numRecords is a corruption signal; fall through.""" + df = _make_describe_detail_df(columns=["numRecords"], rows=[{"numRecords": -1}]) + spark = _make_spark(df) + result = fetch_target_row_count(spark, catalog="test_catalog", schema="perf_test", table="orders") + assert result == RowCountResult(row_count=None, source=RowCountSource.STATIC_DEFAULT) + + +# --- Audit trail / logging ---------------------------------------------------- + + +def test_static_default_path_emits_warning_log(caplog): + """Static-default fallback must log at WARNING for operator visibility.""" + spark = _make_spark(AnalysisException("not found")) + with caplog.at_level("WARNING"): + fetch_target_row_count(spark, catalog="test_catalog", schema="perf_test", table="orders") + assert any( + "row_count_source=static_default" in rec.message for rec in caplog.records + ), "STATIC_DEFAULT fallback must log at WARNING level for operator visibility" + + +def test_delta_describe_detail_path_emits_info_log(caplog): + """Success path logs INFO with the ``key=value`` structured shape.""" + df = _make_describe_detail_df(columns=["numRecords"], rows=[{"numRecords": 100_000_000}]) + spark = _make_spark(df) + with caplog.at_level("INFO"): + fetch_target_row_count(spark, catalog="test_catalog", schema="perf_test", table="orders") + assert any( + "row_count_source=delta_describe_detail" in rec.message and "row_count=100000000" in rec.message + for rec in caplog.records + ) + + +def test_user_override_path_emits_info_log(caplog): + spark = MagicMock() + with caplog.at_level("INFO"): + fetch_target_row_count( + spark, + catalog="test_catalog", + schema="perf_test", + table="orders", + override_row_count=100_000_000, + ) + assert any( + "row_count_source=user_override" in rec.message and "row_count=100000000" in rec.message + for rec in caplog.records + ) diff --git a/tests/unit/reconcile/fingerprint/test_spark_target_serialization.py b/tests/unit/reconcile/fingerprint/test_spark_target_serialization.py new file mode 100644 index 0000000000..4906a450f9 --- /dev/null +++ b/tests/unit/reconcile/fingerprint/test_spark_target_serialization.py @@ -0,0 +1,177 @@ +"""Pin Stage-1 (DataFrame) ↔ Stage-2 (SQL) target serialisation symmetry. + +The two helpers in ``spark_target.py`` must produce the same hash inputs +byte-for-byte for any column value. Without that contract, trailing-whitespace +rows surfaced by Stage-1 detection cannot be re-located by Stage-2 surgical +fetch and silently drop from the recon output. +""" + +from __future__ import annotations + +import pytest + +from databricks.labs.lakebridge.reconcile.fingerprint.spark_target import ( # pylint: disable=import-private-name + _quote_spark_identifier, + _serialize_column_spark_sql, + build_target_filter_subquery, +) +from databricks.labs.lakebridge.reconcile.recon_config import Schema + + +@pytest.mark.parametrize("treat_empty_as_null", [False, True]) +def test_serialize_column_spark_sql_emits_trim(treat_empty_as_null: bool) -> None: + sql = _serialize_column_spark_sql("notes", "varchar(64)", treat_empty_as_null) + assert "TRIM(CAST(`notes` AS STRING))" in sql, sql + assert sql.endswith(", '_null_recon_')") + if treat_empty_as_null: + assert "NULLIF(TRIM(" in sql, sql + + +def test_build_target_filter_subquery_serialises_with_trim() -> None: + columns = [ + Schema("`order_id`", "int", "`order_id`", "`order_id`"), + Schema("`notes`", "varchar(64)", "`notes`", "`notes`"), + ] + sql = build_target_filter_subquery( + catalog="test_catalog", + schema="fp_correctness", + table="orders", + columns=columns, + column_mapping=None, + solved_hashes={1: [101, 102]}, + unsolved_sb_ids=[7], + sub_bucket_count=1024, + ) + assert "TRIM(CAST(`order_id` AS STRING))" in sql, sql + assert "TRIM(CAST(`notes` AS STRING))" in sql, sql + # Catalog/schema/table route through ``_quote_spark_identifier``. + assert "FROM `test_catalog`.`fp_correctness`.`orders`" in sql, sql + assert "_fp_filtered" in sql + + +def test_build_target_filter_subquery_omits_catalog_when_unset() -> None: + sql = build_target_filter_subquery( + catalog=None, + schema="schema_only", + table="orders", + columns=[Schema("`order_id`", "int", "`order_id`", "`order_id`")], + column_mapping=None, + solved_hashes={}, + unsolved_sb_ids=[1, 2], + sub_bucket_count=1024, + ) + assert "FROM `schema_only`.`orders`" in sql, sql + assert "test_catalog." not in sql + + +def test_quote_spark_identifier_doubles_embedded_backtick() -> None: + """Defense-in-depth: a column name containing a literal backtick must not break the SQL.""" + assert _quote_spark_identifier("plain") == "`plain`" + assert _quote_spark_identifier("we`ird") == "`we``ird`" + # Two embedded backticks → four doubled inside, plus outer pair = six total. + assert _quote_spark_identifier("``") == "``````" + + +def test_build_target_filter_subquery_resolves_column_mapping() -> None: + sql = build_target_filter_subquery( + catalog="c", + schema="s", + table="t", + columns=[Schema("`src_id`", "int", "`src_id`", "`src_id`")], + column_mapping={"src_id": "tgt_id"}, + solved_hashes={0: [1]}, + unsolved_sb_ids=[], + sub_bucket_count=1024, + ) + assert "`tgt_id`" in sql, sql + assert "`src_id`" not in sql, sql + + +@pytest.mark.parametrize("col_type", ["timestamp_ntz", "timestamp without time zone"]) +def test_serialize_column_spark_sql_uses_plain_date_format_for_naive_timestamps(col_type: str) -> None: + """Naive (NTZ) timestamps render directly with ``DATE_FORMAT`` because the + value carries no timezone semantics — same shape as Redshift's + ``TO_CHAR(_, 'YYYY-MM-DD HH24:MI:SS.US')`` for ``timestamp without time zone``. + """ + sql = _serialize_column_spark_sql("ts_col", col_type, treat_empty_as_null=False) + assert "TRIM(DATE_FORMAT(`ts_col`, 'yyyy-MM-dd HH:mm:ss.SSSSSS'))" in sql, sql + # NTZ columns must NOT route through the UTC pin — that would shift values + # that have no associated timezone. + assert "TO_UTC_TIMESTAMP" not in sql, sql + assert "CAST(`ts_col` AS STRING)" not in sql, sql + + +@pytest.mark.parametrize( + "col_type", + ["timestamp", "timestamp_ltz", "timestamp with time zone", "timestamp with local time zone"], +) +def test_serialize_column_spark_sql_pins_utc_for_tz_aware_timestamps(col_type: str) -> None: + """TZ-aware (LTZ) Spark timestamps must be normalised to the UTC wall-clock + before formatting so the bytes match Redshift's + ``TO_CHAR(_ AT TIME ZONE 'UTC', 'YYYY-MM-DD HH24:MI:SS.US')`` regardless of + the cluster's ``spark.sql.session.timeZone``. Without this pin, a non-UTC + session timezone would render the same instant differently on the two + sides and Stage-1 would over-report mismatches on every TZ-aware column. + """ + sql = _serialize_column_spark_sql("ts_col", col_type, treat_empty_as_null=False) + assert ( + "TRIM(DATE_FORMAT(TO_UTC_TIMESTAMP(`ts_col`, CURRENT_TIMEZONE()), " "'yyyy-MM-dd HH:mm:ss.SSSSSS'))" + ) in sql, sql + assert "CAST(`ts_col` AS STRING)" not in sql, sql + + +def test_serialize_column_spark_sql_uses_date_format_for_date() -> None: + """``date`` columns format to ``yyyy-MM-dd`` to match Redshift's ``TO_CHAR(_, 'YYYY-MM-DD')``.""" + sql = _serialize_column_spark_sql("d_col", "date", treat_empty_as_null=False) + assert "DATE_FORMAT(`d_col`, 'yyyy-MM-dd')" in sql, sql + + +def test_build_target_filter_subquery_uses_date_format_for_timestamps() -> None: + """Stage-2 SQL filter subquery must inherit the timestamp format end-to-end. + + Pins both timestamp families: the TZ-aware default ``timestamp`` routes + through the UTC pin; the NTZ variant renders directly. Both share the + same fractional-second precision so the bytes match Redshift. + """ + columns = [ + Schema("`order_id`", "int", "`order_id`", "`order_id`"), + Schema("`event_ts`", "timestamp", "`event_ts`", "`event_ts`"), + Schema("`event_ts_naive`", "timestamp_ntz", "`event_ts_naive`", "`event_ts_naive`"), + Schema("`event_dt`", "date", "`event_dt`", "`event_dt`"), + ] + sql = build_target_filter_subquery( + catalog="c", + schema="s", + table="t", + columns=columns, + column_mapping=None, + solved_hashes={1: [101]}, + unsolved_sb_ids=[], + sub_bucket_count=1024, + ) + # TZ-aware: UTC-pinned via TO_UTC_TIMESTAMP. + assert ("TO_UTC_TIMESTAMP(`event_ts`, CURRENT_TIMEZONE()), " "'yyyy-MM-dd HH:mm:ss.SSSSSS'") in sql, sql + # NTZ: direct DATE_FORMAT, no pin. + assert "DATE_FORMAT(`event_ts_naive`, 'yyyy-MM-dd HH:mm:ss.SSSSSS')" in sql, sql + assert "DATE_FORMAT(`event_dt`, 'yyyy-MM-dd')" in sql, sql + # The non-temporal column still uses the default cast path. + assert "CAST(`order_id` AS STRING)" in sql, sql + + +def test_serialize_column_spark_sql_handles_dot_in_column_name() -> None: + """Delta column names containing ``.`` must be backtick-escaped on the SQL + path so Spark resolves the column literally instead of as a struct field + path. Without escaping, ``F.col("a.b")`` would attempt ``a.b`` resolution + and crash mid-Stage-1; ``_quote_spark_identifier`` is applied on both paths. + """ + sql = _serialize_column_spark_sql("a.b", "string", treat_empty_as_null=False) + assert "`a.b`" in sql, sql + assert "CAST(`a.b` AS STRING)" in sql, sql + + +def test_quote_spark_identifier_handles_dot() -> None: + """``.`` inside an identifier must be wrapped without further treatment — the + backtick fence is what tells Spark "this is a column name, not a struct path". + """ + assert _quote_spark_identifier("a.b") == "`a.b`" + assert _quote_spark_identifier("event.timestamp.utc") == "`event.timestamp.utc`" diff --git a/tests/unit/reconcile/fingerprint/test_tier_selection.py b/tests/unit/reconcile/fingerprint/test_tier_selection.py new file mode 100644 index 0000000000..8c93db038a --- /dev/null +++ b/tests/unit/reconcile/fingerprint/test_tier_selection.py @@ -0,0 +1,243 @@ +"""Unit tests for orchestrator-level adaptive sub-bucket tier selection. + +Pins: + - ``_select_tier`` returns the correct (sub_bucket_count, bucket_count) for a row count. + - DESCRIBE DETAIL failure falls through to static defaults. + - Source and target paths receive the same tier (GROUP BY alignment). + - User override short-circuits the metadata lookup. + - ``FingerprintResult`` carries tier provenance through to recon_metrics. +""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest +from pyspark.sql.utils import AnalysisException + +from databricks.labs.lakebridge.reconcile.fingerprint.constants import ( + BUCKET_COUNT, + SUB_BUCKET_COUNT, +) +from databricks.labs.lakebridge.reconcile.fingerprint.orchestrator import ( # pylint: disable=import-private-name + _select_tier, +) +from databricks.labs.lakebridge.reconcile.fingerprint.row_count import RowCountSource +from tests.unit.reconcile.fingerprint._fixtures import ( + make_database_config, + make_describe_detail_df, + make_table_conf, +) + +# --- Tier selection: each row count maps to the expected tier -------------- + + +@pytest.mark.parametrize( + ("num_records", "expected_sub_buckets", "expected_buckets", "expected_source"), + [ + # < 50K — sparse-table tier + (10_000, 16_384, 128, RowCountSource.DELTA_DESCRIBE_DETAIL), + # 500K – 50M — legacy default tier + (10_000_000, 1_048_576, 1_024, RowCountSource.DELTA_DESCRIBE_DETAIL), + # 50M – 500M — NEW tier (P1 100M fixture lands here) + (100_000_000, 2_097_152, 2_048, RowCountSource.DELTA_DESCRIBE_DETAIL), + # 500M – 5B — NEW tier (P2 1B fixture lands here) + (1_000_000_000, 4_194_304, 4_096, RowCountSource.DELTA_DESCRIBE_DETAIL), + # 5B – 50B — NEW tier (covers 20B+ target) + (20_000_000_000, 8_388_608, 8_192, RowCountSource.DELTA_DESCRIBE_DETAIL), + # 50B+ — clamp tier + (100_000_000_000, 16_777_216, 16_384, RowCountSource.DELTA_DESCRIBE_DETAIL), + ], +) +def test_select_tier_picks_correct_tier_from_target_delta_count( + num_records, expected_sub_buckets, expected_buckets, expected_source +): + """``_select_tier`` reads target ``numRecords`` and returns a tier for both sides.""" + spark = MagicMock() + spark.sql.return_value = make_describe_detail_df(num_records) + tier = _select_tier(spark, make_database_config(), make_table_conf()) + assert tier.sub_bucket_count == expected_sub_buckets + assert tier.bucket_count == expected_buckets + assert tier.target_row_count == num_records + assert tier.row_count_source == expected_source.value + + +def test_select_tier_falls_back_to_static_default_when_describe_detail_fails(): + """DESCRIBE DETAIL failure falls back to legacy static defaults.""" + spark = MagicMock() + spark.sql.side_effect = AnalysisException("Table or view not found: test_catalog.perf_test.orders") + tier = _select_tier(spark, make_database_config(), make_table_conf()) + assert tier.sub_bucket_count == SUB_BUCKET_COUNT + assert tier.bucket_count == BUCKET_COUNT + assert tier.target_row_count is None + assert tier.row_count_source == RowCountSource.STATIC_DEFAULT.value + + +def test_select_tier_uses_target_catalog_and_schema_not_source(): + """Tier comes from the target Delta table; source-side row counts are not consulted.""" + spark = MagicMock() + spark.sql.return_value = make_describe_detail_df(100_000_000) + _select_tier(spark, make_database_config(), make_table_conf(target_name="my_target_table")) + spark.sql.assert_called_once() + call_arg = spark.sql.call_args[0][0] + # Must reference TARGET catalog/schema/table, not source. + assert ( + "test_catalog.perf_test.my_target_table" in call_arg + ), f"_select_tier must DESCRIBE DETAIL the TARGET table; got SQL: {call_arg!r}" + assert "source_catalog" not in call_arg, ( + f"_select_tier must NEVER reference source catalog (Redshift side has no Delta metadata); " + f"got SQL: {call_arg!r}" + ) + + +def test_select_tier_user_override_short_circuits_describe_detail(): + """``override_row_count`` short-circuits — no spark.sql call at all.""" + spark = MagicMock() + tier = _select_tier( + spark, + make_database_config(), + make_table_conf(), + override_row_count=15_800_000_000, + ) + spark.sql.assert_not_called() + # 15.8B → 5B-50B tier + assert tier.sub_bucket_count == 8_388_608 + assert tier.bucket_count == 8_192 + assert tier.target_row_count == 15_800_000_000 + assert tier.row_count_source == RowCountSource.USER_OVERRIDE.value + + +# --- Source / target receive identical tier -------------------------------- + + +def test_run_fingerprint_precheck_passes_same_tier_to_source_and_target(): + """Source and target must receive the same tier — mismatched moduli mis-align GROUP BY.""" + from databricks.labs.lakebridge.reconcile.fingerprint import ( # pylint: disable=import-outside-toplevel + orchestrator as orch, + ) + from databricks.labs.lakebridge.reconcile.fingerprint.engine import ( # pylint: disable=import-outside-toplevel + DetectionResult, + ) + from databricks.labs.lakebridge.reconcile.fingerprint.orchestrator import ( # pylint: disable=import-outside-toplevel + ColumnAlignment, + ) + from databricks.labs.lakebridge.reconcile.recon_config import ( # pylint: disable=import-outside-toplevel + Schema, + ) + + captured_tier_detection: list = [] + + fake_detection = DetectionResult(verdict="MATCH") + + def fake_detection_phase(*args, **kwargs): + # ``tier`` is the 8th positional argument; tolerate kwarg form too. + tier = kwargs.get("tier") if "tier" in kwargs else args[7] + captured_tier_detection.append(tier) + return fake_detection, 42 + + def fake_resolve_cols(*_args, **_kwargs): + return [Schema("`order_id`", "bigint", "`order_id`", '"order_id"')] + + def fake_align_columns(*_args, **_kwargs): + return ColumnAlignment(column_mapping=None) + + spark = MagicMock() + spark.sql.return_value = make_describe_detail_df(100_000_000) + source = MagicMock() + target = MagicMock() + source_engine = MagicMock() + + with ( + patch.object(orch, "_run_detection_phase", side_effect=fake_detection_phase), + patch.object(orch, "_resolve_detection_columns", side_effect=fake_resolve_cols), + patch.object(orch, "align_columns", side_effect=fake_align_columns), + patch.object(orch, "get_query_builder", return_value=MagicMock()), + ): + result = orch.run_fingerprint_precheck( + source=source, + target=target, + spark=spark, + source_engine=source_engine, + database_config=make_database_config(), + table_conf=make_table_conf(), + src_schema=[], + tgt_schema=[], + report_type="data", + data_source="redshift", + ) + + # MATCH verdict — _run_detection_phase was called once with the tier. + assert len(captured_tier_detection) == 1 + tier = captured_tier_detection[0] + # 100M → 50M-500M tier + assert tier.sub_bucket_count == 2_097_152 + assert tier.bucket_count == 2_048 + + # Tier provenance flows into ``FingerprintResult``. + assert result is not None + assert result.verdict == "MATCH" + assert result.sub_bucket_count == 2_097_152 + assert result.bucket_count == 2_048 + assert result.target_row_count == 100_000_000 + assert result.row_count_source == "delta_describe_detail" + + +def test_fingerprint_result_carries_static_default_provenance_on_describe_detail_failure(): + """``FingerprintResult`` carries ``row_count_source="static_default"`` after fall-through.""" + from databricks.labs.lakebridge.reconcile.fingerprint import ( # pylint: disable=import-outside-toplevel + orchestrator as orch, + ) + from databricks.labs.lakebridge.reconcile.fingerprint.engine import ( # pylint: disable=import-outside-toplevel + DetectionResult, + ) + from databricks.labs.lakebridge.reconcile.fingerprint.orchestrator import ( # pylint: disable=import-outside-toplevel + ColumnAlignment, + ) + from databricks.labs.lakebridge.reconcile.recon_config import ( # pylint: disable=import-outside-toplevel + Schema, + ) + + fake_detection = DetectionResult(verdict="MATCH") + + def fake_detection_phase(*_args, **_kwargs): + return fake_detection, 42 + + def fake_resolve_cols(*_args, **_kwargs): + return [Schema("`order_id`", "bigint", "`order_id`", '"order_id"')] + + def fake_align_columns(*_args, **_kwargs): + return ColumnAlignment(column_mapping=None) + + spark = MagicMock() + spark.sql.side_effect = AnalysisException("table not found") + source = MagicMock() + target = MagicMock() + source_engine = MagicMock() + + with ( + patch.object(orch, "_run_detection_phase", side_effect=fake_detection_phase), + patch.object(orch, "_resolve_detection_columns", side_effect=fake_resolve_cols), + patch.object(orch, "align_columns", side_effect=fake_align_columns), + patch.object(orch, "get_query_builder", return_value=MagicMock()), + ): + result = orch.run_fingerprint_precheck( + source=source, + target=target, + spark=spark, + source_engine=source_engine, + database_config=make_database_config(), + table_conf=make_table_conf(), + src_schema=[], + tgt_schema=[], + report_type="data", + data_source="redshift", + ) + + assert result is not None + assert result.verdict == "MATCH" + # Static fallback yields the legacy hardcoded tier. + assert result.sub_bucket_count == SUB_BUCKET_COUNT + assert result.bucket_count == BUCKET_COUNT + # And carries the ``static_default`` provenance so it's auditable. + assert result.target_row_count is None + assert result.row_count_source == "static_default" diff --git a/tests/unit/reconcile/fingerprint/test_treat_empty_as_null_consistency.py b/tests/unit/reconcile/fingerprint/test_treat_empty_as_null_consistency.py new file mode 100644 index 0000000000..21b5a0dfa8 --- /dev/null +++ b/tests/unit/reconcile/fingerprint/test_treat_empty_as_null_consistency.py @@ -0,0 +1,170 @@ +"""Regression: source-side and target-side ``treat_empty_as_null`` must agree. + +Pre-fix: ``_DEFAULT_TREAT_EMPTY_AS_NULL`` was wired into ``get_query_builder`` (source) +but ``compute_target_fingerprint`` and ``build_target_filter_subquery`` silently kept +their function defaults of ``False``. Today the two happen to coincide, so flipping the +constant to ``True`` would have made source serialise ``''`` as ``'_null_recon_'`` while +target kept ``''`` — a systemic Stage-1 mismatch on every empty cell, fail-open +rewriting every recon to the full pipeline. The fix threads +``ReconcileConfig.fingerprint_treat_empty_as_null`` through ``run_fingerprint_precheck`` +to all three serialisation entry points; this test pins that single source of truth. +""" + +from __future__ import annotations + +import pytest + +from databricks.labs.lakebridge.reconcile.fingerprint.orchestrator import get_query_builder +from databricks.labs.lakebridge.reconcile.fingerprint.spark_target import ( # pylint: disable=import-private-name + _serialize_column_spark_sql, + build_target_filter_subquery, +) +from databricks.labs.lakebridge.reconcile.recon_config import Schema + + +@pytest.mark.parametrize("treat_empty_as_null", [False, True]) +def test_source_builder_picks_up_flag(treat_empty_as_null: bool) -> None: + builder = get_query_builder("redshift", treat_empty_as_null=treat_empty_as_null) + serialised = builder.serialize_column("`notes`", "varchar(64)") + if treat_empty_as_null: + assert "NULLIF(" in serialised, serialised + else: + assert "NULLIF(" not in serialised, serialised + + +@pytest.mark.parametrize("treat_empty_as_null", [False, True]) +def test_target_filter_subquery_picks_up_flag(treat_empty_as_null: bool) -> None: + sql = build_target_filter_subquery( + catalog="c", + schema="s", + table="t", + columns=[Schema("`notes`", "varchar(64)", "`notes`", "`notes`")], + column_mapping=None, + solved_hashes={1: [101]}, + unsolved_sb_ids=[], + sub_bucket_count=1024, + treat_empty_as_null=treat_empty_as_null, + ) + if treat_empty_as_null: + assert "NULLIF(" in sql, sql + else: + assert "NULLIF(" not in sql, sql + + +@pytest.mark.parametrize("treat_empty_as_null", [False, True]) +def test_target_stage1_and_stage2_agree_under_flag_flip(treat_empty_as_null: bool) -> None: + """The DataFrame-side serialiser (Stage-1) and the SQL-string sibling (Stage-2) must + apply the same ``treat_empty_as_null`` semantics — silent disagreement here was the + same class of silent-miss bug as the Stage-1/Stage-2 trim asymmetry. + """ + sql = _serialize_column_spark_sql("notes", "varchar(64)", treat_empty_as_null) + if treat_empty_as_null: + assert "NULLIF(" in sql, sql + else: + assert "NULLIF(" not in sql, sql + assert "TRIM(CAST(`notes` AS STRING))" in sql, sql + + +def test_run_fingerprint_precheck_threads_flag_to_all_three_serialisers() -> None: + """End-to-end: a single ``treat_empty_as_null`` argument on + ``run_fingerprint_precheck`` reaches the source builder, the Stage-1 target + aggregate, and the Stage-2 target filter subquery. Without this contract the + config field is decorative. + """ + # pylint: disable=import-outside-toplevel + # Imports are local: this single-test module exercises a deeply-mocked code path + # and the assertion-side imports must follow the patches; hoisting them to module + # top would force the early bind of names the patches replace. + from unittest.mock import MagicMock, patch + + from databricks.labs.lakebridge.reconcile.fingerprint import orchestrator as orch + from databricks.labs.lakebridge.reconcile.fingerprint.engine import DetectionResult + from databricks.labs.lakebridge.reconcile.fingerprint.orchestrator import ColumnAlignment + + captured: dict[str, object] = {} + + def capture_get_query_builder(_data_source: str, *, treat_empty_as_null: bool): + captured["source_flag"] = treat_empty_as_null + return MagicMock() + + def capture_compute_target_fingerprint(**kwargs): + captured["target_stage1_flag"] = kwargs["treat_empty_as_null"] + return MagicMock() + + def capture_build_target_filter_subquery(*_args, **kwargs): + captured["target_stage2_flag"] = kwargs["treat_empty_as_null"] + return "(SELECT * FROM t WHERE 1=1) _fp_filtered" + + # ``_fetch_target_rows`` builds a real ``HashQueryBuilder`` over the mock + # schema; bypass it via a thin shim that exercises only the contract we care + # about — that the flag reaches ``build_target_filter_subquery``. + def shim_fetch_target_rows(ctx, solved_hashes, unsolved_sb_ids, _report_type): + captured["target_stage2_flag"] = ctx.treat_empty_as_null + # Calling the real ``build_target_filter_subquery`` (already patched) keeps + # the wiring assertion honest: if the orchestrator forgot to populate + # ``ctx.treat_empty_as_null``, the mock would record ``False`` instead of + # the value passed by the caller. + orch.build_target_filter_subquery( + ctx.database_config.target_catalog, + ctx.database_config.target_schema, + ctx.table_conf.target_name, + ctx.detection_cols, + ctx.column_mapping, + solved_hashes, + unsolved_sb_ids, + sub_bucket_count=ctx.tier.sub_bucket_count, + treat_empty_as_null=ctx.treat_empty_as_null, + ) + return MagicMock() + + with ( + patch.object(orch, "get_query_builder", side_effect=capture_get_query_builder), + patch.object(orch, "compute_target_fingerprint", side_effect=capture_compute_target_fingerprint), + patch.object(orch, "build_target_filter_subquery", side_effect=capture_build_target_filter_subquery), + patch.object(orch, "align_columns", return_value=ColumnAlignment(column_mapping=None)), + patch.object( + orch, "_resolve_detection_columns", return_value=[Schema("`notes`", "string", "`notes`", "`notes`")] + ), + patch.object( + orch, + "_select_tier", + return_value=orch._TierSelection( # pylint: disable=protected-access + sub_bucket_count=1024, bucket_count=128, target_row_count=100, row_count_source="static_default" + ), + ), + patch.object( + orch, + "detect_and_solve", + return_value=DetectionResult(verdict="MISMATCH", solved_results=[], unsolved_sb_ids=[7]), + ), + patch.object(orch, "_fetch_source_rows", return_value=(MagicMock(), "v1_sandwich")), + patch.object(orch, "_fetch_target_rows", side_effect=shim_fetch_target_rows), + ): + from databricks.labs.lakebridge.config import DatabaseConfig # pylint: disable=import-outside-toplevel + from databricks.labs.lakebridge.reconcile.recon_config import Table # pylint: disable=import-outside-toplevel + + source = MagicMock() + target = MagicMock() + source.read_data.return_value = MagicMock() + target.read_data.return_value = MagicMock() + orch.run_fingerprint_precheck( + source=source, + target=target, + spark=MagicMock(), + source_engine=MagicMock(), + database_config=DatabaseConfig( + source_catalog="sc", source_schema="ss", target_catalog="tc", target_schema="ts" + ), + table_conf=Table(source_name="t", target_name="t", join_columns=["id"]), + src_schema=[Schema("`notes`", "string", "`notes`", "`notes`")], + tgt_schema=[Schema("`notes`", "string", "`notes`", "`notes`")], + report_type="data", + data_source="redshift", + treat_empty_as_null=True, + ) + + assert captured == { + "source_flag": True, + "target_stage1_flag": True, + "target_stage2_flag": True, + }, f"flag did not reach every serialiser; captured={captured!r}" diff --git a/tests/unit/reconcile/query_builder/test_expression_generator.py b/tests/unit/reconcile/query_builder/test_expression_generator.py index caf7796b01..467a5d9e5d 100644 --- a/tests/unit/reconcile/query_builder/test_expression_generator.py +++ b/tests/unit/reconcile/query_builder/test_expression_generator.py @@ -5,6 +5,7 @@ from databricks.labs.lakebridge.transpiler.sqlglot.dialect_utils import get_dialect from databricks.labs.lakebridge.reconcile.query_builder.expression_generator import ( + DataType_transform_mapping, array_sort, array_to_string, build_between, @@ -238,3 +239,109 @@ def test_build_between(): ) assert str(result) == "test_table.test_column BETWEEN 1 AND 2" assert isinstance(result, exp.Between) + + +# --------------------------------------------------------------------------- +# Dialect-aligned TIMESTAMP/TIMESTAMPTZ serialization +# --------------------------------------------------------------------------- +# +# Regression for the ``expression_generator.py`` mapping bug where Redshift's +# source-side hash input emits a 26-char microsecond-precision string +# (``2023-11-18 18:38:07.000000``) but Databricks' target-side default cast +# emits a 19-char string (``2023-11-18 18:38:07``). The 7-byte drift makes the +# per-row SHA2 hashes disagree for *every* logically-identical TIMESTAMP/ +# TIMESTAMPTZ row in any Redshift -> Databricks reconcile -- in normal mode +# this is whole-table noise; in fingerprint mode it surfaces only on the +# small set of rows that Stage-2's surgical fetch over-pulls due to 32-bit +# ``rh1`` sub-bucket collisions (~5 per 1M rows / 10K culprits, statistical). +# +# The fix adds explicit Databricks handlers so the target side emits the same +# ``yyyy-MM-dd HH:mm:ss.SSSSSS`` shape as Redshift's +# ``TO_CHAR(ts, 'YYYY-MM-DD HH24:MI:SS.US')``. + + +def _apply_handler(handler_partial, col_name: str = "ts_col") -> str: + """Run a single ``DataType_transform_mapping`` partial against a column expression + and return the rendered SQL. Mirrors how ``HashQueryBuilder._apply_transform`` + threads partials over the projected expression.""" + rendered = handler_partial(exp.Column(this=col_name)) + return rendered.sql(dialect="databricks") + + +_EXPECTED_REDSHIFT_TS_SQL = "COALESCE(TO_CHAR(ts_col, 'YYYY-MM-DD HH24:MI:SS.US'), '_null_recon_')" +_EXPECTED_DATABRICKS_TS_SQL = "COALESCE(DATE_FORMAT(ts_col, 'yyyy-MM-dd HH:mm:ss.SSSSSS'), '_null_recon_')" + + +def test_databricks_timestamp_handler_emits_exact_microsecond_sql(): + """Pin the exact rendered SQL for the Databricks TIMESTAMP handler. + + A substring match on ``HH:mm:ss.SSSSSS`` would tolerate a typo in the + surrounding ``COALESCE`` / sentinel pieces; equality on the whole string + is the strongest unit-level guard. + """ + handlers = DataType_transform_mapping["databricks"][exp.DataType.Type.TIMESTAMP.value] + assert len(handlers) == 1, "expected exactly one TIMESTAMP handler for databricks" + assert _apply_handler(handlers[0]) == _EXPECTED_DATABRICKS_TS_SQL + + +def test_databricks_timestamptz_handler_emits_exact_microsecond_sql(): + handlers = DataType_transform_mapping["databricks"][exp.DataType.Type.TIMESTAMPTZ.value] + assert len(handlers) == 1, "expected exactly one TIMESTAMPTZ handler for databricks" + assert _apply_handler(handlers[0]) == _EXPECTED_DATABRICKS_TS_SQL + + +def test_redshift_timestamp_handler_emits_exact_microsecond_sql(): + rs_handlers = DataType_transform_mapping["redshift"][exp.DataType.Type.TIMESTAMP.value] + assert len(rs_handlers) == 1, "expected exactly one TIMESTAMP handler for redshift" + rs_rendered = rs_handlers[0](exp.Column(this="ts_col")).sql(dialect="redshift") + assert rs_rendered == _EXPECTED_REDSHIFT_TS_SQL + + +def test_redshift_timestamptz_handler_emits_exact_microsecond_sql(): + rs_handlers = DataType_transform_mapping["redshift"][exp.DataType.Type.TIMESTAMPTZ.value] + assert len(rs_handlers) == 1, "expected exactly one TIMESTAMPTZ handler for redshift" + rs_rendered = rs_handlers[0](exp.Column(this="ts_col")).sql(dialect="redshift") + assert rs_rendered == _EXPECTED_REDSHIFT_TS_SQL + + +def test_databricks_and_redshift_timestamp_format_strings_produce_identical_bytes(): + """The two handlers must emit byte-identical format output for the same + instant. We can't execute SQL in a unit test, so we render the canonical + reference timestamp through Python's ``strftime`` with the equivalent + format string the docs guarantee for each engine and assert equality on + the produced wall-clock string. + + Redshift's ``YYYY-MM-DD HH24:MI:SS.US`` and Spark's + ``yyyy-MM-dd HH:mm:ss.SSSSSS`` are documented to produce 6-digit + microsecond precision, so the canonical Python equivalent is + ``%Y-%m-%d %H:%M:%S.%f``. This test guards the on-the-wire byte equality + that the per-row SHA2 hash relies on. + """ + from datetime import datetime, timezone + + reference = datetime(2025, 1, 15, 12, 34, 56, 789012, tzinfo=timezone.utc) + canonical = reference.strftime("%Y-%m-%d %H:%M:%S.%f") + assert canonical == "2025-01-15 12:34:56.789012" + + rs_handlers = DataType_transform_mapping["redshift"][exp.DataType.Type.TIMESTAMPTZ.value] + db_handlers = DataType_transform_mapping["databricks"][exp.DataType.Type.TIMESTAMPTZ.value] + rs_rendered = rs_handlers[0](exp.Column(this="ts_col")).sql(dialect="redshift") + db_rendered = db_handlers[0](exp.Column(this="ts_col")).sql(dialect="databricks") + + assert "'YYYY-MM-DD HH24:MI:SS.US'" in rs_rendered + assert "'yyyy-MM-dd HH:mm:ss.SSSSSS'" in db_rendered + + +def test_redshift_boolean_handler_emits_exact_case_when_sql(): + """Redshift rejects every ``CAST(boolean AS VARCHAR/TEXT)`` form, and the + universal default ``TRIM(...)`` becomes ``btrim(boolean)`` which is + function-not-found. The custom handler must emit lowercase ``'true'`` / + ``'false'`` literals so the bytes match Spark's + ``cast(boolean AS string)`` output and per-row hashes stay aligned. + """ + handlers = DataType_transform_mapping["redshift"][exp.DataType.Type.BOOLEAN.value] + assert len(handlers) == 1, "expected exactly one BOOLEAN handler for redshift" + rendered = handlers[0](exp.Column(this="bool_col")).sql(dialect="redshift") + assert rendered == ( + "COALESCE(CASE WHEN bool_col THEN 'true' WHEN NOT bool_col THEN 'false' ELSE NULL END, " "'_null_recon_')" + ) diff --git a/tests/unit/reconcile/test_recon_capture_fingerprint.py b/tests/unit/reconcile/test_recon_capture_fingerprint.py new file mode 100644 index 0000000000..b6f04355e0 --- /dev/null +++ b/tests/unit/reconcile/test_recon_capture_fingerprint.py @@ -0,0 +1,212 @@ +"""Unit tests for the ``fingerprint_metrics`` SQL projection in ``ReconCapture``. + +String-level tests against the named_struct fragment — no SparkSession or Delta +write loop required. The persisted column shape is the dashboard contract; locking +it here means a stray f-string edit can't silently break it. +""" + +import re + +from databricks.labs.lakebridge.reconcile.fingerprint.metadata import ( + INELIGIBLE_FILTERS_CONFIGURED, + FingerprintRunMetadata, +) +from databricks.labs.lakebridge.reconcile.recon_capture import ReconCapture + + +_FIELD_ORDER = ( + "eligible", + "ineligibility_reason", + "verdict", + "elapsed_ms", + "solved_count", + "unsolved_sb_count", + "total_mismatched_sbs", + "fallback_to_full_pipeline", + "sub_bucket_count", + "bucket_count", + "target_row_count", + "row_count_source", + "fetch_path", +) + + +def _field_offsets(sql: str) -> list[int]: + return [sql.index(f"'{name}'") for name in _FIELD_ORDER] + + +def test_struct_sql_emits_all_eight_fields_in_declared_order(): + """Field order must match the dataclass declaration so mergeSchema widens the + column to the expected struct shape on first write (Spark resolves struct fields + positionally during saveAsTable). + """ + metadata = FingerprintRunMetadata( + eligible=True, + verdict="MATCH", + elapsed_ms=42, + solved_count=3, + unsolved_sb_count=1, + total_mismatched_sbs=4, + ) + sql = ReconCapture._fingerprint_metrics_struct_sql(metadata) # pylint: disable=protected-access + offsets = _field_offsets(sql) + assert offsets == sorted(offsets), f"Field order drifted in {sql!r}" + + +def test_struct_sql_renders_eligible_match_verdict(): + metadata = FingerprintRunMetadata( + eligible=True, verdict="MATCH", elapsed_ms=120, solved_count=0, unsolved_sb_count=0 + ) + sql = ReconCapture._fingerprint_metrics_struct_sql(metadata) # pylint: disable=protected-access + assert "'eligible', true" in sql + assert "'verdict', 'MATCH'" in sql + assert "'elapsed_ms', cast(120 as bigint)" in sql + assert "'fallback_to_full_pipeline', false" in sql + + +def test_struct_sql_renders_ineligible_with_reason(): + metadata = FingerprintRunMetadata.ineligible(INELIGIBLE_FILTERS_CONFIGURED) + sql = ReconCapture._fingerprint_metrics_struct_sql(metadata) # pylint: disable=protected-access + assert "'eligible', false" in sql + assert f"'ineligibility_reason', '{INELIGIBLE_FILTERS_CONFIGURED}'" in sql + assert "'verdict', NULL" in sql + + +def test_struct_sql_emits_null_for_none_string_fields(): + """None must serialise to SQL NULL, not the literal string 'None' — otherwise + dashboards filtering on IS NULL miss every row. + """ + metadata = FingerprintRunMetadata(eligible=True, verdict=None) + sql = ReconCapture._fingerprint_metrics_struct_sql(metadata) # pylint: disable=protected-access + assert "'verdict', NULL" in sql + assert "'verdict', 'None'" not in sql + assert "'ineligibility_reason', NULL" in sql + + +def test_struct_sql_scrubs_quotes_from_reason_and_verdict(): + """Embedded quotes would break the SQL. Scrubbing mirrors exception_message handling.""" + metadata = FingerprintRunMetadata(eligible=False, ineligibility_reason="bad'reason\"here", verdict="MIS\"MATCH") + sql = ReconCapture._fingerprint_metrics_struct_sql(metadata) # pylint: disable=protected-access + assert "bad'reason" not in sql + assert 'MIS"MATCH' not in sql + assert "'ineligibility_reason', 'badreasonhere'" in sql + assert "'verdict', 'MISMATCH'" in sql + + +def test_struct_sql_casts_counters_to_bigint(): + """Every counter must carry an explicit bigint cast. + + Without one, Spark infers int for small literals and later rows with larger counts + force a slow column-type rewrite. Counters: elapsed_ms, solved_count, + unsolved_sb_count, total_mismatched_sbs, sub_bucket_count, bucket_count, + target_row_count. + """ + metadata = FingerprintRunMetadata( + eligible=True, + elapsed_ms=1, + solved_count=2, + unsolved_sb_count=3, + total_mismatched_sbs=4, + sub_bucket_count=2_097_152, + bucket_count=2_048, + target_row_count=100_000_000, + ) + sql = ReconCapture._fingerprint_metrics_struct_sql(metadata) # pylint: disable=protected-access + bigint_casts = re.findall(r"cast\(\d+ as bigint\)", sql) + assert len(bigint_casts) == 7, f"Expected 7 bigint casts, got {bigint_casts!r}" + + +def test_struct_sql_target_row_count_null_when_static_default_path(): + """target_row_count emits NULL (not cast(0 as bigint)) when the row-count fetch + fell through, so dashboards can distinguish "unavailable" from "actually 0". + """ + metadata = FingerprintRunMetadata( + eligible=True, + verdict="MATCH", + sub_bucket_count=1_048_576, + bucket_count=32_768, + target_row_count=None, + row_count_source="static_default", + ) + sql = ReconCapture._fingerprint_metrics_struct_sql(metadata) # pylint: disable=protected-access + assert "'target_row_count', NULL" in sql + assert "'row_count_source', 'static_default'" in sql + + +def test_struct_sql_emits_tier_fields_for_user_override_path(): + """User override surfaces both the tier values and the user_override provenance.""" + metadata = FingerprintRunMetadata( + eligible=True, + verdict="MATCH", + sub_bucket_count=4_194_304, + bucket_count=4_096, + target_row_count=1_000_000_000, + row_count_source="user_override", + ) + sql = ReconCapture._fingerprint_metrics_struct_sql(metadata) # pylint: disable=protected-access + assert "'sub_bucket_count', cast(4194304 as bigint)" in sql + assert "'bucket_count', cast(4096 as bigint)" in sql + assert "'target_row_count', cast(1000000000 as bigint)" in sql + assert "'row_count_source', 'user_override'" in sql + + +def test_struct_sql_default_metadata_emits_zero_tier_and_null_row_count(): + """Default metadata (ineligible / disabled rows) emits zero tier values and NULL + for row_count, row_count_source, fetch_path. + """ + metadata = FingerprintRunMetadata.disabled() + sql = ReconCapture._fingerprint_metrics_struct_sql(metadata) # pylint: disable=protected-access + assert "'sub_bucket_count', cast(0 as bigint)" in sql + assert "'bucket_count', cast(0 as bigint)" in sql + assert "'target_row_count', NULL" in sql + assert "'row_count_source', NULL" in sql + assert "'fetch_path', NULL" in sql + + +def test_struct_sql_emits_fetch_path_v2_for_redshift_split(): + """The v2_redshift_split value is preserved so historical recon_metrics rows + continue to round-trip; current code never emits it. + """ + metadata = FingerprintRunMetadata( + eligible=True, + verdict="MISMATCH", + sub_bucket_count=2_097_152, + bucket_count=2_048, + target_row_count=100_000_000, + row_count_source="delta_describe_detail", + fetch_path="v2_redshift_split", + ) + sql = ReconCapture._fingerprint_metrics_struct_sql(metadata) # pylint: disable=protected-access + assert "'fetch_path', 'v2_redshift_split'" in sql + + +def test_struct_sql_emits_fetch_path_v1_for_legacy_sandwich(): + metadata = FingerprintRunMetadata( + eligible=True, + verdict="MISMATCH", + sub_bucket_count=2_097_152, + bucket_count=2_048, + target_row_count=100_000_000, + row_count_source="delta_describe_detail", + fetch_path="v1_sandwich", + ) + sql = ReconCapture._fingerprint_metrics_struct_sql(metadata) # pylint: disable=protected-access + assert "'fetch_path', 'v1_sandwich'" in sql + + +def test_struct_sql_match_verdict_emits_null_fetch_path(): + """MATCH never executes Stage-2; fetch_path must stay NULL even when other tier + fields are populated. + """ + metadata = FingerprintRunMetadata( + eligible=True, + verdict="MATCH", + sub_bucket_count=2_097_152, + bucket_count=2_048, + target_row_count=100_000_000, + row_count_source="delta_describe_detail", + fetch_path=None, + ) + sql = ReconCapture._fingerprint_metrics_struct_sql(metadata) # pylint: disable=protected-access + assert "'fetch_path', NULL" in sql + assert "'fetch_path', 'None'" not in sql diff --git a/tests/unit/reconcile/test_recon_capture_fingerprint_typed_schema.py b/tests/unit/reconcile/test_recon_capture_fingerprint_typed_schema.py new file mode 100644 index 0000000000..e35035685b --- /dev/null +++ b/tests/unit/reconcile/test_recon_capture_fingerprint_typed_schema.py @@ -0,0 +1,202 @@ +"""Unit tests for OPT-A-1 — the typed schema layer behind ``_fingerprint_metrics_struct_sql``. + +These tests pin the new contract introduced on 2026-05-09: + - The ``_FP_METRICS_STRUCT_FIELDS`` tuple is the single source of truth. + - Every metadata attribute is covered exactly once. + - Every entry's ``sql_type`` is one of the four supported types. + - ``_render_fp_metrics_value`` is type-checked at the boundary; an unsupported + sql_type raises rather than silently falling through. + - The dataclass field order matches the schema declaration so Delta's positional + struct resolution (saveAsTable) cannot drift. + +Pairs with ``test_recon_capture_fingerprint.py`` (which pins the rendered SQL string +shape — bit-exact preservation across the typed-schema rewrite). The two files share +an intentional duplication of the field-order tuple so a reorder breaks both tests +from different angles ("declaration order" here vs. "SQL emission order" there); +silencing pylint's similarity checker here documents that as deliberate. +""" + +# pylint: disable=duplicate-code + +from __future__ import annotations + +import dataclasses + +import pytest + +from databricks.labs.lakebridge.reconcile.fingerprint.metadata import FingerprintRunMetadata +from databricks.labs.lakebridge.reconcile.recon_capture import ( + _FP_METRICS_STRUCT_FIELDS, + _render_fp_metrics_value, +) + +_ALLOWED_SQL_TYPES = {"bool", "bigint", "bigint_or_null", "string_or_null"} + + +def test_schema_declares_thirteen_fields_in_dataclass_order(): + """The schema tuple must enumerate every persisted field exactly once, in the order + Delta expects on positional struct resolution. Any drift is an MR-blocker. + """ + assert len(_FP_METRICS_STRUCT_FIELDS) == 13, ( + f"Expected 13 fp_metrics fields, got {len(_FP_METRICS_STRUCT_FIELDS)}. " + "If you added a field, update this assertion AND verify the dataclass declaration." + ) + + schema_attrs = [attr for (_sql_field, attr, _sql_type) in _FP_METRICS_STRUCT_FIELDS] + dataclass_attrs = [f.name for f in dataclasses.fields(FingerprintRunMetadata)] + + # Every schema attribute must exist on the dataclass — otherwise getattr() at render + # time would raise AttributeError on a code path with full unit coverage today. + for attr in schema_attrs: + assert attr in dataclass_attrs, ( + f"Schema declares attribute {attr!r} that is NOT on FingerprintRunMetadata. " + "Either remove from schema or add to the dataclass." + ) + + +def test_schema_field_names_unique(): + """A duplicate sql_field_name would render twice in the same named_struct, producing + invalid SQL on the persisted row. Pin uniqueness here so a copy-paste edit is caught + at unit-test time, not at recon runtime. + """ + sql_field_names = [sql_field for (sql_field, _attr, _sql_type) in _FP_METRICS_STRUCT_FIELDS] + assert len(sql_field_names) == len( + set(sql_field_names) + ), f"Duplicate sql_field_name in _FP_METRICS_STRUCT_FIELDS: {sql_field_names!r}" + + schema_attrs = [attr for (_sql_field, attr, _sql_type) in _FP_METRICS_STRUCT_FIELDS] + assert len(schema_attrs) == len( + set(schema_attrs) + ), f"Duplicate dataclass attribute in _FP_METRICS_STRUCT_FIELDS: {schema_attrs!r}" + + +def test_schema_only_uses_allowed_sql_types(): + """The renderer raises on an unknown sql_type. This test is the static check — + if a developer adds a field with sql_type='int' instead of 'bigint', the test fails + at unit-test time before the bad type reaches `spark.sql(...)`. + """ + for sql_field, attr, sql_type in _FP_METRICS_STRUCT_FIELDS: + assert sql_type in _ALLOWED_SQL_TYPES, ( + f"Schema entry {sql_field!r} (attr={attr!r}) declares unsupported sql_type " + f"{sql_type!r}. Allowed: {sorted(_ALLOWED_SQL_TYPES)!r}." + ) + + +def test_render_bool_emits_lowercase_sql_literal(): + assert _render_fp_metrics_value(True, "bool") == "true" + assert _render_fp_metrics_value(False, "bool") == "false" + # Truthy non-bool coerces — guards against a future int-typed attribute being + # accidentally rendered as 'bool' and producing 'True' (capitalised, invalid SQL). + assert _render_fp_metrics_value(1, "bool") == "true" + assert _render_fp_metrics_value(0, "bool") == "false" + + +def test_render_bigint_emits_explicit_cast(): + """Without explicit cast, Spark infers int from small literals and later writes + with larger counts force a slow column-type rewrite. Cast is mandatory. + """ + assert _render_fp_metrics_value(0, "bigint") == "cast(0 as bigint)" + assert _render_fp_metrics_value(2_097_152, "bigint") == "cast(2097152 as bigint)" + + +def test_render_bigint_or_null_handles_none(): + assert _render_fp_metrics_value(None, "bigint_or_null") == "NULL" + assert _render_fp_metrics_value(0, "bigint_or_null") == "cast(0 as bigint)" + assert _render_fp_metrics_value(100_000_000, "bigint_or_null") == "cast(100000000 as bigint)" + + +def test_render_string_or_null_handles_none_and_quotes_scrub(): + assert _render_fp_metrics_value(None, "string_or_null") == "NULL" + assert _render_fp_metrics_value("MATCH", "string_or_null") == "'MATCH'" + # Defense-in-depth: embedded quotes scrubbed (mirrors exception_message handling + # elsewhere in recon_capture). + assert _render_fp_metrics_value("bad'reason\"here", "string_or_null") == "'badreasonhere'" + + +def test_render_unknown_sql_type_raises(): + """An unknown sql_type must raise — the renderer never silently falls through to + str(value), which would produce un-cast/un-quoted output that breaks the SQL. + """ + with pytest.raises(ValueError, match="Unsupported sql_type"): + _render_fp_metrics_value(42, "int") + with pytest.raises(ValueError, match="Unsupported sql_type"): + _render_fp_metrics_value("foo", "varchar") + with pytest.raises(ValueError, match="Unsupported sql_type"): + _render_fp_metrics_value(None, "") + + +def test_render_bigint_or_null_rejects_non_none_non_int(): + """``FingerprintRunMetadata`` types these fields as ``int | None``; a stringy value + is a contract violation and must fail loudly. The ``isinstance`` assertion is also + what narrows ``value: object`` for mypy in ``_render_fp_metrics_value``. + """ + with pytest.raises(AssertionError, match="bigint_or_null field expected"): + _render_fp_metrics_value("not-a-number", "bigint_or_null") + + +def test_render_bigint_rejects_stringy_value(): + """Companion of the bigint_or_null contract — non-nullable bigint also rejects strings.""" + with pytest.raises(AssertionError, match="bigint field expected"): + _render_fp_metrics_value("100", "bigint") + + +def test_struct_sql_field_order_pinned_by_schema(): + """Belt-and-braces: pin the exact order so a field reorder in + ``_FP_METRICS_STRUCT_FIELDS`` shows up as an MR diff line, not a silent runtime + behaviour change. + """ + expected_order = ( + "eligible", + "ineligibility_reason", + "verdict", + "elapsed_ms", + "solved_count", + "unsolved_sb_count", + "total_mismatched_sbs", + "fallback_to_full_pipeline", + "sub_bucket_count", + "bucket_count", + "target_row_count", + "row_count_source", + "fetch_path", + ) + actual_order = tuple(sql_field for (sql_field, _attr, _t) in _FP_METRICS_STRUCT_FIELDS) + assert actual_order == expected_order, ( + "Field order drifted in _FP_METRICS_STRUCT_FIELDS. Reordering is a Delta " + "schema-positional break for existing recon_metrics rows. Confirm intent." + ) + + +def test_render_helper_is_module_level_not_method(): + """The renderer is module-level by design — keeping it out of the class makes + OPT-A-1's safety boundary independent of any future ReconCapture refactor that + might break inheritance / decorator behaviour. Pin the import path. + """ + from databricks.labs.lakebridge.reconcile import recon_capture as rc # noqa: PLC0415 + + assert callable(rc._render_fp_metrics_value) + # Not a method — calling without `self` must work. + assert rc._render_fp_metrics_value(True, "bool") == "true" + + +def test_schema_attrs_match_dataclass_fields_one_to_one(): + """Every dataclass field must appear in the schema (otherwise a new metadata + field would silently fail to persist) AND every schema entry must reference a + real dataclass field (otherwise getattr() raises at render time). + """ + schema_attrs = {attr for (_sql_field, attr, _t) in _FP_METRICS_STRUCT_FIELDS} + dataclass_attrs = {f.name for f in dataclasses.fields(FingerprintRunMetadata)} + + # If a dataclass field is added without a schema entry, this fails — forcing the + # author to deliberately decide whether the new field is persisted. + missing_in_schema = dataclass_attrs - schema_attrs + assert not missing_in_schema, ( + f"Dataclass attributes missing from _FP_METRICS_STRUCT_FIELDS: {sorted(missing_in_schema)!r}. " + "Either add to the schema, or document why the field is intentionally non-persistent." + ) + + extra_in_schema = schema_attrs - dataclass_attrs + assert not extra_in_schema, ( + f"Schema references non-existent dataclass attributes: {sorted(extra_in_schema)!r}. " + "Remove from schema or add to FingerprintRunMetadata." + ) diff --git a/tests/unit/reconcile/test_source_adapter.py b/tests/unit/reconcile/test_source_adapter.py index 30609f34cd..d7fd41dfe3 100644 --- a/tests/unit/reconcile/test_source_adapter.py +++ b/tests/unit/reconcile/test_source_adapter.py @@ -57,9 +57,9 @@ def test_create_adapter_for_databricks_dialect_target(): def test_create_adapter_for_redshift_dialect(): spark = create_autospec(DatabricksSession) engine = get_dialect("redshift") - scope = "scope" + connection_name = "uc_conn" - data_source = create_adapter(engine, spark, scope) + data_source = create_adapter(engine, spark, connection_name) assert isinstance(data_source, RedshiftDataSource) diff --git a/tests/unit/reconcile/test_trigger_recon_fingerprint_metadata.py b/tests/unit/reconcile/test_trigger_recon_fingerprint_metadata.py new file mode 100644 index 0000000000..da9b1fd2bf --- /dev/null +++ b/tests/unit/reconcile/test_trigger_recon_fingerprint_metadata.py @@ -0,0 +1,471 @@ +"""Unit tests for the FingerprintRunMetadata produced by the trigger service. + +Covers every branch of ``_run_fingerprint_or_reconcile_data``: ineligible / MATCH / +MISMATCH-with-rows / MISMATCH-fallback / soft-skip / FAILED. Tests stub at +``run_fingerprint_precheck`` and ``_run_reconcile_data``, so no SparkSession needed. +""" + +from unittest.mock import MagicMock, patch + +import pytest + +from databricks.labs.lakebridge.config import ( + ReconcileConfig, + SourceConnectionConfig, + TargetConnectionConfig, +) +from databricks.labs.lakebridge.reconcile.exception import DataSourceRuntimeException +from databricks.labs.lakebridge.reconcile.fingerprint.exceptions import UnmappedTargetColumnMappingError +from databricks.labs.lakebridge.reconcile.fingerprint.metadata import ( + INELIGIBLE_FILTERS_CONFIGURED, + INELIGIBLE_FLAG_DISABLED, + INELIGIBLE_NO_JOIN_COLUMNS, + INELIGIBLE_UNMAPPED_TARGET_COLUMN_MAPPING, + INELIGIBLE_UNSUPPORTED_DIALECT, +) +from databricks.labs.lakebridge.reconcile.fingerprint.orchestrator import FingerprintResult +from databricks.labs.lakebridge.reconcile.recon_config import Filters, Table +from databricks.labs.lakebridge.reconcile.recon_output_config import DataReconcileOutput, MismatchOutput +from databricks.labs.lakebridge.reconcile.trigger_recon_service import TriggerReconService + + +def _reconciler(report_type: str = "data") -> MagicMock: + """Smallest Reconciliation mock the trigger reads from.""" + reconciler = MagicMock() + reconciler.report_type = report_type + reconciler.source = MagicMock() + reconciler.target = MagicMock() + reconciler.spark = MagicMock() + reconciler.source_engine = MagicMock() + reconciler.intermediate_persist = MagicMock() + return reconciler + + +def _config(*, flag: bool = True, source: str = "redshift") -> ReconcileConfig: + return ReconcileConfig( + report_type="data", + source=SourceConnectionConfig( + dialect=source, + catalog="dev", + schema="src", + uc_connection_name="conn", + ), + target=TargetConnectionConfig(catalog="tc", schema="ts"), + metadata_config=MagicMock(), + fingerprint_precheck=flag, + ) + + +def _table(**overrides) -> Table: + base = {"source_name": "t", "target_name": "t", "join_columns": ["id"]} + base.update(overrides) + return Table(**base) # type: ignore[arg-type] + + +def _stub_full_pipeline_output() -> DataReconcileOutput: + """Sentinel mismatch_count proves the trigger returned the full-pipeline output + (not the zeroed ``fingerprint_match_output()``). + """ + return DataReconcileOutput(mismatch_count=999, mismatch=MismatchOutput()) + + +@patch.object(TriggerReconService, "_run_reconcile_data") +def test_flag_disabled_records_flag_disabled_reason(mock_full): + mock_full.return_value = _stub_full_pipeline_output() + + output, metadata = TriggerReconService._run_fingerprint_or_reconcile_data( # pylint: disable=protected-access + reconciler=_reconciler(), + reconcile_config=_config(flag=False), + table_conf=_table(), + src_schema=[], + tgt_schema=[], + ) + + assert metadata.eligible is False + assert metadata.ineligibility_reason == INELIGIBLE_FLAG_DISABLED + assert metadata.fallback_to_full_pipeline is False, ( + "Ineligible tables didn't *fall back* — they were never eligible to begin with. " + "Conflating these would inflate fallback rates in dashboards." + ) + assert output.mismatch_count == 999 + mock_full.assert_called_once() + + +@patch.object(TriggerReconService, "_run_reconcile_data") +def test_unsupported_dialect_records_unsupported_reason(mock_full): + mock_full.return_value = _stub_full_pipeline_output() + + _, metadata = TriggerReconService._run_fingerprint_or_reconcile_data( # pylint: disable=protected-access + reconciler=_reconciler(), + reconcile_config=_config(source="snowflake"), + table_conf=_table(), + src_schema=[], + tgt_schema=[], + ) + + assert metadata.ineligibility_reason == INELIGIBLE_UNSUPPORTED_DIALECT + mock_full.assert_called_once() + + +@patch.object(TriggerReconService, "_run_reconcile_data") +def test_no_join_columns_records_no_join_columns_reason(mock_full): + mock_full.return_value = _stub_full_pipeline_output() + + _, metadata = TriggerReconService._run_fingerprint_or_reconcile_data( # pylint: disable=protected-access + reconciler=_reconciler(), + reconcile_config=_config(), + table_conf=Table(source_name="t", target_name="t"), + src_schema=[], + tgt_schema=[], + ) + + assert metadata.ineligibility_reason == INELIGIBLE_NO_JOIN_COLUMNS + + +@patch.object(TriggerReconService, "_run_reconcile_data") +def test_per_table_filter_records_filters_reason(mock_full): + mock_full.return_value = _stub_full_pipeline_output() + + _, metadata = TriggerReconService._run_fingerprint_or_reconcile_data( # pylint: disable=protected-access + reconciler=_reconciler(), + reconcile_config=_config(), + table_conf=_table(filters=Filters(source="x is not null")), + src_schema=[], + tgt_schema=[], + ) + + assert metadata.ineligibility_reason == INELIGIBLE_FILTERS_CONFIGURED + + +@patch("databricks.labs.lakebridge.reconcile.trigger_recon_service.run_fingerprint_precheck") +def test_eligible_match_records_match_verdict_with_elapsed(mock_precheck): + """MATCH must short-circuit the full pipeline AND record verdict + elapsed time. + + The verdict must come from the FingerprintResult (MATCH), not be hard-coded — + otherwise the dashboard can't differentiate MATCH from MISMATCH. + """ + mock_precheck.return_value = FingerprintResult(verdict="MATCH", detection_elapsed_ms=137) + + output, metadata = TriggerReconService._run_fingerprint_or_reconcile_data( # pylint: disable=protected-access + reconciler=_reconciler(), + reconcile_config=_config(), + table_conf=_table(), + src_schema=[], + tgt_schema=[], + ) + + assert metadata.eligible is True + assert metadata.verdict == "MATCH" + assert metadata.elapsed_ms == 137 + assert metadata.fallback_to_full_pipeline is False + # MATCH path returns the synthetic match output, not the full-pipeline sentinel. + assert output.mismatch_count == 0 + assert output.missing_in_src_count == 0 + assert output.missing_in_tgt_count == 0 + + +@patch("databricks.labs.lakebridge.reconcile.trigger_recon_service.build_mismatch_output") +@patch("databricks.labs.lakebridge.reconcile.trigger_recon_service.run_fingerprint_precheck") +def test_eligible_mismatch_with_rows_records_solver_counters(mock_precheck, mock_build): + """MISMATCH with both row sets — solver counters and elapsed must be persisted.""" + src_rows = MagicMock() + src_rows.count.return_value = 5 + tgt_rows = MagicMock() + tgt_rows.count.return_value = 3 + mock_precheck.return_value = FingerprintResult( + verdict="MISMATCH", + source_rows=src_rows, + target_rows=tgt_rows, + solved_count=4, + unsolved_sb_count=2, + total_mismatched_sbs=6, + detection_elapsed_ms=99, + ) + mock_build.return_value = DataReconcileOutput(mismatch_count=4, mismatch=MismatchOutput()) + + _, metadata = TriggerReconService._run_fingerprint_or_reconcile_data( # pylint: disable=protected-access + reconciler=_reconciler(), + reconcile_config=_config(), + table_conf=_table(), + src_schema=[], + tgt_schema=[], + ) + + assert metadata.eligible is True + assert metadata.verdict == "MISMATCH" + assert metadata.elapsed_ms == 99 + assert metadata.solved_count == 4 + assert metadata.unsolved_sb_count == 2 + assert metadata.total_mismatched_sbs == 6 + assert metadata.fallback_to_full_pipeline is False + + +@patch.object(TriggerReconService, "_run_reconcile_data") +@patch("databricks.labs.lakebridge.reconcile.trigger_recon_service.run_fingerprint_precheck") +def test_eligible_mismatch_missing_rows_falls_back_with_metadata(mock_precheck, mock_full): + """Missing rows on a MISMATCH — full pipeline runs, metadata records the fallback.""" + mock_precheck.return_value = FingerprintResult( + verdict="MISMATCH", + source_rows=None, + target_rows=None, + solved_count=1, + unsolved_sb_count=2, + total_mismatched_sbs=3, + detection_elapsed_ms=11, + ) + mock_full.return_value = _stub_full_pipeline_output() + + output, metadata = TriggerReconService._run_fingerprint_or_reconcile_data( # pylint: disable=protected-access + reconciler=_reconciler(), + reconcile_config=_config(), + table_conf=_table(), + src_schema=[], + tgt_schema=[], + ) + + assert metadata.verdict == "MISMATCH" + assert metadata.fallback_to_full_pipeline is True + # Solver counters from the FingerprintResult must still be preserved on the fallback row. + assert metadata.solved_count == 1 + assert metadata.unsolved_sb_count == 2 + assert metadata.total_mismatched_sbs == 3 + # The full-pipeline output must be returned, not a synthetic match. + assert output.mismatch_count == 999 + + +@patch.object(TriggerReconService, "_run_reconcile_data") +@patch("databricks.labs.lakebridge.reconcile.trigger_recon_service.run_fingerprint_precheck") +def test_eligible_precheck_returns_none_records_fallback(mock_precheck, mock_full): + """Soft skip — orchestrator returned None, no exception. verdict stays None.""" + mock_precheck.return_value = None + mock_full.return_value = _stub_full_pipeline_output() + + output, metadata = TriggerReconService._run_fingerprint_or_reconcile_data( # pylint: disable=protected-access + reconciler=_reconciler(), + reconcile_config=_config(), + table_conf=_table(), + src_schema=[], + tgt_schema=[], + ) + + assert metadata.eligible is True + assert metadata.verdict is None + assert metadata.fallback_to_full_pipeline is True + assert output.mismatch_count == 999 + + +@patch.object(TriggerReconService, "_run_reconcile_data") +@patch("databricks.labs.lakebridge.reconcile.trigger_recon_service.run_fingerprint_precheck") +def test_eligible_precheck_failure_records_failed_verdict(mock_precheck, mock_full): + """Pre-check exceptions are swallowed and attributed via verdict=FAILED.""" + mock_precheck.side_effect = DataSourceRuntimeException("simulated jdbc failure") + mock_full.return_value = _stub_full_pipeline_output() + + output, metadata = TriggerReconService._run_fingerprint_or_reconcile_data( # pylint: disable=protected-access + reconciler=_reconciler(), + reconcile_config=_config(), + table_conf=_table(), + src_schema=[], + tgt_schema=[], + ) + + assert metadata.eligible is True + assert metadata.verdict == "FAILED" + assert metadata.fallback_to_full_pipeline is True + assert output.mismatch_count == 999 + + +@pytest.mark.parametrize( + "report_type", + ["data", "all", "row"], +) +@patch.object(TriggerReconService, "_run_reconcile_data") +@patch("databricks.labs.lakebridge.reconcile.trigger_recon_service.run_fingerprint_precheck") +def test_eligible_match_skips_full_pipeline_for_all_data_report_types(mock_precheck, mock_full, report_type): + """MATCH short-circuits the full pipeline for ``data`` / ``all`` / ``row`` alike.""" + mock_precheck.return_value = FingerprintResult(verdict="MATCH", detection_elapsed_ms=1) + + TriggerReconService._run_fingerprint_or_reconcile_data( # pylint: disable=protected-access + reconciler=_reconciler(report_type=report_type), + reconcile_config=_config(), + table_conf=_table(), + src_schema=[], + tgt_schema=[], + ) + + mock_full.assert_not_called() + + +@patch.object(TriggerReconService, "_run_reconcile_data") +@patch("databricks.labs.lakebridge.reconcile.trigger_recon_service.run_fingerprint_precheck") +def test_pyspark_exception_during_precheck_falls_back_with_failed_verdict(mock_precheck, mock_full): + """``compute_target_fingerprint`` materialises a Spark plan at action time; + ``AnalysisException`` (a ``PySparkException``) raised there must NOT crash + the recon. The trigger catches it, falls back to the full pipeline, and + records verdict=FAILED so dashboards can quantify the precheck's reliability. + + Without this catch a column-resolution failure (typical when the user has a + column-name mismatch and no ``column_mapping``) would propagate up and + crash ``_do_recon_one``. + """ + from pyspark.errors import PySparkException + + # PySparkException's __init__ requires a registered error_class; subclassing + # avoids that constraint and keeps the test focused on the catch widening. + class _SimulatedAnalysisException(PySparkException): # noqa: D401 + def __init__(self, msg: str) -> None: # pylint: disable=super-init-not-called + self._message = msg + + def __str__(self) -> str: + return self._message + + mock_precheck.side_effect = _SimulatedAnalysisException( + "AnalysisException: column 'foo' does not exist on target Delta table" + ) + mock_full.return_value = _stub_full_pipeline_output() + + output, metadata = TriggerReconService._run_fingerprint_or_reconcile_data( # pylint: disable=protected-access + reconciler=_reconciler(), + reconcile_config=_config(), + table_conf=_table(), + src_schema=[], + tgt_schema=[], + ) + + assert metadata.eligible is True + assert metadata.verdict == "FAILED" + assert metadata.fallback_to_full_pipeline is True + assert output.mismatch_count == 999, "Full-pipeline output (sentinel mismatch_count=999) must be returned" + + +@patch.object(TriggerReconService, "_run_reconcile_data") +@patch("databricks.labs.lakebridge.reconcile.trigger_recon_service.run_fingerprint_precheck") +def test_target_row_count_override_threads_from_reconcile_config(mock_precheck, mock_full): + """``ReconcileConfig.fingerprint_row_count_override`` is the single + configuration entry point for tier-pinning; the trigger must thread its + value into ``run_fingerprint_precheck`` so the orchestrator's + ``_select_tier`` receives it. + """ + mock_precheck.return_value = FingerprintResult(verdict="MATCH", detection_elapsed_ms=1) + mock_full.return_value = _stub_full_pipeline_output() + + config = _config() + config.fingerprint_row_count_override = 250_000_000 + + TriggerReconService._run_fingerprint_or_reconcile_data( # pylint: disable=protected-access + reconciler=_reconciler(), + reconcile_config=config, + table_conf=_table(), + src_schema=[], + tgt_schema=[], + ) + + _, call_kwargs = mock_precheck.call_args + assert call_kwargs["target_row_count_override"] == 250_000_000 + + +@patch.object(TriggerReconService, "_run_reconcile_data") +@patch("databricks.labs.lakebridge.reconcile.trigger_recon_service.run_fingerprint_precheck") +def test_target_row_count_override_defaults_to_none(mock_precheck, mock_full): + """The default value carried by ``ReconcileConfig`` propagates as ``None`` so + legacy configs keep the Delta DESCRIBE DETAIL heuristic behaviour.""" + mock_precheck.return_value = FingerprintResult(verdict="MATCH", detection_elapsed_ms=1) + mock_full.return_value = _stub_full_pipeline_output() + + TriggerReconService._run_fingerprint_or_reconcile_data( # pylint: disable=protected-access + reconciler=_reconciler(), + reconcile_config=_config(), + table_conf=_table(), + src_schema=[], + tgt_schema=[], + ) + + _, call_kwargs = mock_precheck.call_args + assert call_kwargs["target_row_count_override"] is None + + +@patch.object(TriggerReconService, "_run_reconcile_data") +@patch("databricks.labs.lakebridge.reconcile.trigger_recon_service.build_mismatch_output") +@patch("databricks.labs.lakebridge.reconcile.trigger_recon_service.run_fingerprint_precheck") +def test_pyspark_exception_during_mismatch_output_falls_back_to_full_pipeline(mock_precheck, mock_build, mock_full): + """``build_mismatch_output`` runs Spark actions on the prefetched src/tgt + frames. A Spark failure there (column resolution, NPE in compare layer, + etc.) must mirror the fail-open pattern used by every other non-MATCH + branch: fall through to the standard full pipeline so the table still + gets a real recon answer, and record on the metadata that the + precheck-built output was rejected. Asserting the FAILED verdict (the + pre-fail-open behaviour) would hide the regression that prompted the + review feedback. + """ + from pyspark.errors import PySparkException + + class _SimulatedAnalysisException(PySparkException): # noqa: D401 + def __init__(self, msg: str) -> None: # pylint: disable=super-init-not-called + self._message = msg + + def __str__(self) -> str: + return self._message + + mock_precheck.return_value = FingerprintResult( + verdict="MISMATCH", + source_rows=MagicMock(), + target_rows=MagicMock(), + detection_elapsed_ms=10, + solved_count=1, + ) + mock_build.side_effect = _SimulatedAnalysisException("AnalysisException: column resolution failed") + full_output = _stub_full_pipeline_output() + mock_full.return_value = full_output + + output, metadata = TriggerReconService._run_fingerprint_or_reconcile_data( # pylint: disable=protected-access + reconciler=_reconciler(), + reconcile_config=_config(), + table_conf=_table(), + src_schema=[], + tgt_schema=[], + ) + + assert metadata.eligible is True + assert metadata.verdict == "MISMATCH", ( + "Verdict must reflect the precheck signal, not the build-output failure — " + "the failure is recorded via fallback_to_full_pipeline=True." + ) + assert metadata.fallback_to_full_pipeline is True + assert ( + output is full_output + ), "Output must come from the full pipeline so the customer still gets a real recon answer." + mock_full.assert_called_once() + + +@patch.object(TriggerReconService, "_run_reconcile_data") +@patch("databricks.labs.lakebridge.reconcile.trigger_recon_service.run_fingerprint_precheck") +def test_unmapped_target_column_mapping_records_typed_ineligibility(mock_precheck, mock_full): + """An ``UnmappedTargetColumnMappingError`` from ``align_columns`` (a config-time + issue, not a precheck failure) must surface as + ``ineligibility_reason='unmapped_target_column_mapping'`` on the persisted metric — + not as a silent ``None`` fallback. Adoption queries against + ``recon_metrics.fingerprint_metrics.ineligibility_reason`` rely on the typed value + to quantify column-mapping drift. + """ + mock_precheck.side_effect = UnmappedTargetColumnMappingError( + "column_mapping target 'tgt_a_typo' (mapped from 'src_a') not found in target schema" + ) + mock_full.return_value = _stub_full_pipeline_output() + + output, metadata = TriggerReconService._run_fingerprint_or_reconcile_data( # pylint: disable=protected-access + reconciler=_reconciler(), + reconcile_config=_config(), + table_conf=_table(), + src_schema=[], + tgt_schema=[], + ) + + assert metadata.eligible is False, "Config-time ineligibility, not a precheck failure" + assert metadata.ineligibility_reason == INELIGIBLE_UNMAPPED_TARGET_COLUMN_MAPPING + assert metadata.ineligibility_reason == "unmapped_target_column_mapping" + assert metadata.verdict is None + assert metadata.fallback_to_full_pipeline is False, ( + "fallback_to_full_pipeline tracks runtime fallbacks; ineligible runs use the full " + "pipeline by definition and shouldn't double-count on this flag." + ) + assert output.mismatch_count == 999, "Full-pipeline output (sentinel mismatch_count=999) must be returned" diff --git a/tests/unit/test_install.py b/tests/unit/test_install.py index 1bb03334ce..d862a32c09 100644 --- a/tests/unit/test_install.py +++ b/tests/unit/test_install.py @@ -635,7 +635,7 @@ def test_configure_reconcile_installation_config_error_continue_install(ws: Work "schema": "reconcile", "volume": "reconcile_volume", }, - "version": 2, + "version": 3, } } ) @@ -706,7 +706,7 @@ def test_configure_reconcile_installation_config_error_continue_install(ws: Work "schema": "reconcile", "volume": "reconcile_volume", }, - "version": 2, + "version": 3, }, ) @@ -792,7 +792,7 @@ def test_configure_reconcile_no_existing_installation(ws: WorkspaceClient) -> No "schema": "reconcile", "volume": "reconcile_volume", }, - "version": 2, + "version": 3, }, ) @@ -875,7 +875,7 @@ def test_configure_reconcile_databricks_no_existing_installation(ws: WorkspaceCl "schema": "reconcile", "volume": "reconcile_volume", }, - "version": 2, + "version": 3, }, ) @@ -938,7 +938,7 @@ def test_configure_all_override_installation( # FIXME "schema": "reconcile", "volume": "reconcile_volume", }, - "version": 2, + "version": 3, }, } ) @@ -1051,7 +1051,7 @@ def test_configure_all_override_installation( # FIXME "schema": "reconcile", "volume": "reconcile_volume", }, - "version": 2, + "version": 3, }, ) diff --git a/tests/unit/test_reconcile_config_migration.py b/tests/unit/test_reconcile_config_migration.py new file mode 100644 index 0000000000..0859d750b4 --- /dev/null +++ b/tests/unit/test_reconcile_config_migration.py @@ -0,0 +1,122 @@ +"""ReconcileConfig schema migrations. + +Upstream's ``v1_migrate`` flattens ``database_config`` + ``data_source`` into the +``source`` / ``target`` connection structs. ``v2_migrate`` is the fingerprint +addition — it introduces the source-agnostic ``fingerprint_precheck`` flag and +folds older internal field names (``redshift_fingerprint_precheck``, +``use_fingerprint_precheck``) into it on the way through. +""" + +from databricks.labs.lakebridge.config import ReconcileConfig + + +def _v1_payload(extra: dict | None = None) -> dict: + raw = { + "data_source": "redshift", + "report_type": "all", + "secret_scope": "scope", + "database_config": { + "source_catalog": "dev", + "source_schema": "public", + "target_catalog": "c", + "target_schema": "s", + }, + "metadata_config": {"catalog": "m", "schema": "r", "volume": "v"}, + "version": 1, + } + if extra: + raw.update(extra) + return raw + + +def test_v1_migrate_advances_to_v2_with_source_target_structs(): + """Upstream's existing v1→v2 must keep working: ``database_config`` and + ``data_source`` are flattened into ``source`` / ``target`` and v2 schema is + declared. We only inherit this — we don't override it.""" + migrated = ReconcileConfig.v1_migrate(_v1_payload()) + assert migrated["version"] == 2 + assert migrated["source"]["dialect"] == "redshift" + assert migrated["source"]["catalog"] == "dev" + assert migrated["source"]["schema"] == "public" + assert migrated["source"]["uc_connection_name"] == "TODO" + assert migrated["target"]["catalog"] == "c" + assert migrated["target"]["schema"] == "s" + + +def test_v2_migrate_adds_fingerprint_precheck_default_false(): + """v2 configs that predate the fingerprint flag should keep their behaviour: + no flag set, default ``False`` — i.e. ``v2_migrate`` is a no-op for callers + who never opted in.""" + raw = { + "report_type": "all", + "source": {"dialect": "redshift", "catalog": "dev", "schema": "public", "uc_connection_name": "rs_conn"}, + "target": {"catalog": "c", "schema": "s"}, + "metadata_config": {"catalog": "m", "schema": "r", "volume": "v"}, + "version": 2, + } + migrated = ReconcileConfig.v2_migrate(dict(raw)) + assert migrated["version"] == 3 + assert "fingerprint_precheck" not in migrated # default applied at dataclass load time + + +def test_v2_migrate_renames_redshift_fingerprint_precheck(): + """Long-lived internal deployments that set ``redshift_fingerprint_precheck`` on + a v2 config (pre-rename) must round-trip the value to ``fingerprint_precheck``.""" + raw = { + "report_type": "all", + "source": {"dialect": "redshift", "catalog": "dev", "schema": "public", "uc_connection_name": "rs_conn"}, + "target": {"catalog": "c", "schema": "s"}, + "metadata_config": {"catalog": "m", "schema": "r", "volume": "v"}, + "redshift_fingerprint_precheck": True, + "version": 2, + } + migrated = ReconcileConfig.v2_migrate(dict(raw)) + assert migrated["version"] == 3 + assert migrated["fingerprint_precheck"] is True + assert "redshift_fingerprint_precheck" not in migrated + + +def test_v2_migrate_renames_use_fingerprint_precheck(): + """Same round-trip semantics for the older ``use_fingerprint_precheck`` name.""" + raw = { + "report_type": "all", + "source": {"dialect": "redshift", "catalog": "dev", "schema": "public", "uc_connection_name": "rs_conn"}, + "target": {"catalog": "c", "schema": "s"}, + "metadata_config": {"catalog": "m", "schema": "r", "volume": "v"}, + "use_fingerprint_precheck": True, + "version": 2, + } + migrated = ReconcileConfig.v2_migrate(dict(raw)) + assert migrated["version"] == 3 + assert migrated["fingerprint_precheck"] is True + assert "use_fingerprint_precheck" not in migrated + + +def test_v2_migrate_preserves_explicit_new_flag(): + """If both legacy and new keys are present (operator hand-edit), the explicit + new key wins and the legacy keys are dropped.""" + raw = { + "report_type": "all", + "source": {"dialect": "redshift", "catalog": "dev", "schema": "public", "uc_connection_name": "rs_conn"}, + "target": {"catalog": "c", "schema": "s"}, + "metadata_config": {"catalog": "m", "schema": "r", "volume": "v"}, + "fingerprint_precheck": False, + "redshift_fingerprint_precheck": True, + "version": 2, + } + migrated = ReconcileConfig.v2_migrate(dict(raw)) + assert migrated["fingerprint_precheck"] is False + assert "redshift_fingerprint_precheck" not in migrated + + +def test_full_migration_chain_v1_to_v3_carries_legacy_flag_through(): + """End-to-end: a v1 config carrying ``use_fingerprint_precheck`` (theoretically + possible from a long-lived internal deployment) must end up at v3 with the flag + surfaced as ``fingerprint_precheck``. Documents the chain ``v1_migrate → v2_migrate``.""" + raw = _v1_payload({"use_fingerprint_precheck": True}) + after_v1 = ReconcileConfig.v1_migrate(dict(raw)) + final = ReconcileConfig.v2_migrate(after_v1) + assert final["version"] == 3 + assert final["fingerprint_precheck"] is True + assert "use_fingerprint_precheck" not in final + assert "redshift_fingerprint_precheck" not in final