Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,22 @@ def _get_mapping_cols_with_alias(self, cols_list: list[str], agg_type: str):
"""
cols_with_mapping: list[exp.Expression] = []
for col in cols_list:
# Special-case the literal star for COUNT(*).
# The column-name normalizer would render "*" as the backtick-quoted
# identifier `*`, producing COUNT(`*`) which fails at SQL analysis.
# NormalizeReconConfigService also rewrites entries in agg_columns to
# the ansi-normalized form, so by the time we get here the value can
# be the literal "*" or the wrapped "`*`". Bypass identifier handling
# in either case and emit a Star expression so the downstream builder
# produces COUNT(*).
if DialectUtils.unnormalize_identifier(col) == "*" and agg_type.lower() == "count":
cols_with_mapping.append(
exp.Alias(
this=exp.Star(),
alias=exp.Identifier(this=f"{agg_type.lower()}<#>*", quoted=False),
)
)
continue

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I reviewed this in the original PR

column_expr = build_column(
this=(
self._build_column_name_source_normalized(self._get_mapping_col(col))
Expand Down
71 changes: 71 additions & 0 deletions tests/unit/reconcile/query_builder/test_aggregate_query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from databricks.labs.lakebridge.reconcile.query_builder.aggregate_query import AggregateQueryBuilder
from databricks.labs.lakebridge.reconcile.recon_config import Aggregate, Table
from databricks.labs.lakebridge.transpiler.sqlglot.dialect_utils import get_dialect


def _build_table(aggregates: list[Aggregate]) -> Table:
return Table(source_name="supplier", target_name="target_supplier", aggregates=aggregates)


def test_count_star_emits_unquoted_star(fake_databricks_datasource, normalize_config_service):
"""Aggregate(agg_columns=["*"], type="count") must produce COUNT(*), not COUNT(`*`)."""
table_conf = _build_table([Aggregate(agg_columns=["*"], type="count")])
normalized = normalize_config_service.normalize_recon_table_config(table_conf)

rules = AggregateQueryBuilder(
normalized, [], "source", get_dialect("databricks"), fake_databricks_datasource
).build_queries()

assert len(rules) == 1
sql = rules[0].query
assert "count(*)" in sql.lower()
# Regression: the column-name normalizer must not emit COUNT(`*`)
assert "count(`*`)" not in sql.lower()
# Alias of the aggregate column survives normalization
assert "source_count_*" in sql.lower()


def test_count_star_normalized_input_emits_unquoted_star(fake_databricks_datasource, normalize_config_service):
"""The same fix must hold when agg_columns is already in ansi-normalized form (`*`)."""
table_conf = _build_table([Aggregate(agg_columns=["`*`"], type="count")])
normalized = normalize_config_service.normalize_recon_table_config(table_conf)

rules = AggregateQueryBuilder(
normalized, [], "target", get_dialect("databricks"), fake_databricks_datasource
).build_queries()

assert len(rules) == 1
sql = rules[0].query
assert "count(*)" in sql.lower()
assert "count(`*`)" not in sql.lower()


def test_count_star_alongside_named_column(fake_databricks_datasource, normalize_config_service):
"""COUNT(*) and COUNT(<col>) must coexist in a single aggregate query."""
table_conf = _build_table([Aggregate(agg_columns=["*", "s_acctbal"], type="count")])
normalized = normalize_config_service.normalize_recon_table_config(table_conf)

rules = AggregateQueryBuilder(
normalized, [], "source", get_dialect("databricks"), fake_databricks_datasource
).build_queries()

sql = rules[0].query.lower()
assert "count(*)" in sql
assert "count(`s_acctbal`)" in sql
# The * branch must not pollute the named-column branch with backticks around *
assert "count(`*`)" not in sql


def test_star_with_non_count_aggregate_is_unchanged(fake_databricks_datasource, normalize_config_service):
"""The fast-path must only apply to type='count'. Other aggregates keep the existing behavior."""
table_conf = _build_table([Aggregate(agg_columns=["*"], type="sum")])
normalized = normalize_config_service.normalize_recon_table_config(table_conf)

rules = AggregateQueryBuilder(
normalized, [], "source", get_dialect("databricks"), fake_databricks_datasource
).build_queries()

sql = rules[0].query.lower()
# Existing path quotes the identifier; we must not silently turn this into sum(*)
assert "sum(*)" not in sql
assert "sum(`*`)" in sql
Loading