Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import dataclasses

from databricks.labs.lakebridge.reconcile.connectors.data_source import DataSource
from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils
from databricks.labs.lakebridge.reconcile.recon_config import (
Table,
Aggregate,
Expand Down Expand Up @@ -48,7 +49,22 @@ def _normalize_aggs(self, table: Table):

def _normalize_agg(self, agg: Aggregate) -> Aggregate:
normalized = dataclasses.replace(agg)
normalized.agg_columns = [self.source.normalize_identifier(c).ansi_normalized for c in normalized.agg_columns]

# `*` is not an identifier — it's the SQL star, only valid in COUNT(*). Skip
# identifier normalization (otherwise it would be wrapped in backticks and
# produce invalid SQL like `count(`*`)`). Accept both the raw "*" and the
# ansi-normalized "`*`" form on input, and store as the raw "*" downstream.
def _is_star(col: str) -> bool:
return DialectUtils.unnormalize_identifier(col) == "*"

for col in normalized.agg_columns:
if _is_star(col) and normalized.type.lower() != "count":
raise ValueError(
f"Invalid aggregate: '*' is only supported with type='count', got type='{normalized.type}'."
)
normalized.agg_columns = [
"*" if _is_star(c) else self.source.normalize_identifier(c).ansi_normalized for c in normalized.agg_columns
]
normalized.group_by_columns = (
[self.source.normalize_identifier(c).ansi_normalized for c in normalized.group_by_columns]
if normalized.group_by_columns
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,18 @@ def _get_mapping_cols_with_alias(self, cols_list: list[str], agg_type: str):
"""
cols_with_mapping: list[exp.Expression] = []
for col in cols_list:
# The literal `*` (COUNT(*)) is preserved by NormalizeReconConfigService
# rather than wrapped in backticks, so a single equality check is
# sufficient here. Emit a Star expression so sqlglot renders count(*)
# instead of count(`*`).
if col == "*":
cols_with_mapping.append(
exp.Alias(
this=exp.Star(),
alias=exp.Identifier(this=f"{agg_type.lower()}<#>*", quoted=False),
)
)
continue
column_expr = build_column(
this=(
self._build_column_name_source_normalized(self._get_mapping_col(col))
Expand Down
65 changes: 65 additions & 0 deletions tests/unit/reconcile/query_builder/test_aggregate_query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import pytest

from databricks.labs.lakebridge.reconcile.query_builder.aggregate_query import AggregateQueryBuilder
from databricks.labs.lakebridge.reconcile.recon_config import Aggregate, Table
from databricks.labs.lakebridge.transpiler.sqlglot.dialect_utils import get_dialect


def _build_table(aggregates: list[Aggregate]) -> Table:
    """Return a ``supplier`` -> ``target_supplier`` table config carrying *aggregates*."""
    table_conf = Table(
        source_name="supplier",
        target_name="target_supplier",
        aggregates=aggregates,
    )
    return table_conf


def test_count_star_emits_unquoted_star(fake_databricks_datasource, normalize_config_service):
    """A raw ``*`` count aggregate must render as COUNT(*), never COUNT(`*`)."""
    count_star = Aggregate(agg_columns=["*"], type="count")
    normalized = normalize_config_service.normalize_recon_table_config(_build_table([count_star]))

    builder = AggregateQueryBuilder(
        normalized,
        [],
        "source",
        get_dialect("databricks"),
        fake_databricks_datasource,
    )
    rules = builder.build_queries()

    assert len(rules) == 1
    lowered_sql = rules[0].query.lower()
    # The star must reach the generated SQL without identifier quoting.
    assert "count(*)" in lowered_sql
    # Regression guard: a backtick-wrapped star must not appear.
    assert "count(`*`)" not in lowered_sql
    # The aggregate column's alias must still be present in the query text.
    assert "source_count_*" in lowered_sql


def test_count_star_normalized_input_emits_unquoted_star(fake_databricks_datasource, normalize_config_service):
    """An already backtick-wrapped star in agg_columns must also render as COUNT(*)."""
    quoted_star = Aggregate(agg_columns=["`*`"], type="count")
    normalized = normalize_config_service.normalize_recon_table_config(_build_table([quoted_star]))

    builder = AggregateQueryBuilder(
        normalized,
        [],
        "target",
        get_dialect("databricks"),
        fake_databricks_datasource,
    )
    rules = builder.build_queries()

    assert len(rules) == 1
    lowered_sql = rules[0].query.lower()
    # Same expectations as for the raw "*" input: unquoted star only.
    assert "count(*)" in lowered_sql
    assert "count(`*`)" not in lowered_sql


def test_count_star_alongside_named_column(fake_databricks_datasource, normalize_config_service):
    """One count aggregate over ``*`` and a named column must yield both COUNT forms."""
    mixed_count = Aggregate(agg_columns=["*", "s_acctbal"], type="count")
    normalized = normalize_config_service.normalize_recon_table_config(_build_table([mixed_count]))

    builder = AggregateQueryBuilder(
        normalized,
        [],
        "source",
        get_dialect("databricks"),
        fake_databricks_datasource,
    )
    first_rule = builder.build_queries()[0]

    lowered_sql = first_rule.query.lower()
    assert "count(*)" in lowered_sql
    assert "count(`s_acctbal`)" in lowered_sql
    # Quoting applied to the named column must not leak onto the star.
    assert "count(`*`)" not in lowered_sql


def test_star_with_non_count_aggregate_raises(normalize_config_service):
    """A ``*`` aggregate with type='sum' must be rejected with ValueError during normalization."""
    sum_star = Aggregate(agg_columns=["*"], type="sum")
    table_conf = _build_table([sum_star])
    expected_message = r"'\*' is only supported with type='count'"

    with pytest.raises(ValueError, match=expected_message):
        normalize_config_service.normalize_recon_table_config(table_conf)
Loading