diff --git a/src/databricks/labs/lakebridge/reconcile/normalize_recon_config_service.py b/src/databricks/labs/lakebridge/reconcile/normalize_recon_config_service.py index 1d17b94f76..c28995c601 100644 --- a/src/databricks/labs/lakebridge/reconcile/normalize_recon_config_service.py +++ b/src/databricks/labs/lakebridge/reconcile/normalize_recon_config_service.py @@ -1,6 +1,7 @@ import dataclasses from databricks.labs.lakebridge.reconcile.connectors.data_source import DataSource +from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils from databricks.labs.lakebridge.reconcile.recon_config import ( Table, Aggregate, @@ -48,7 +49,22 @@ def _normalize_aggs(self, table: Table): def _normalize_agg(self, agg: Aggregate) -> Aggregate: normalized = dataclasses.replace(agg) - normalized.agg_columns = [self.source.normalize_identifier(c).ansi_normalized for c in normalized.agg_columns] + + # `*` is not an identifier — it's the SQL star, only valid in COUNT(*). Skip + # identifier normalization (otherwise it would be wrapped in backticks and + # produce invalid SQL like `count(`*`)`). Accept both the raw "*" and the + # ansi-normalized "`*`" form on input, and store as the raw "*" downstream. + def _is_star(col: str) -> bool: + return DialectUtils.unnormalize_identifier(col) == "*" + + for col in normalized.agg_columns: + if _is_star(col) and normalized.type.lower() != "count": + raise ValueError( + f"Invalid aggregate: '*' is only supported with type='count', got type='{normalized.type}'." + ) + normalized.agg_columns = [ + "*" if _is_star(c) else self.source.normalize_identifier(c).ansi_normalized for c in normalized.agg_columns + ] normalized.group_by_columns = ( [self.source.normalize_identifier(c).ansi_normalized for c in normalized.group_by_columns] if normalized.group_by_columns diff --git a/src/databricks/labs/lakebridge/reconcile/query_builder/aggregate_query.py b/src/databricks/labs/lakebridge/reconcile/query_builder/aggregate_query.py index b4a802f833..af62608645 100644 --- a/src/databricks/labs/lakebridge/reconcile/query_builder/aggregate_query.py +++ b/src/databricks/labs/lakebridge/reconcile/query_builder/aggregate_query.py @@ -58,6 +58,18 @@ def _get_mapping_cols_with_alias(self, cols_list: list[str], agg_type: str): """ cols_with_mapping: list[exp.Expression] = [] for col in cols_list: + # The literal `*` (COUNT(*)) is preserved by NormalizeReconConfigService + # rather than wrapped in backticks, so a single equality check is + # sufficient here. Emit a Star expression so sqlglot renders count(*) + # instead of count(`*`). + if col == "*": + cols_with_mapping.append( + exp.Alias( + this=exp.Star(), + alias=exp.Identifier(this=f"{agg_type.lower()}<#>*", quoted=False), + ) + ) + continue column_expr = build_column( this=( self._build_column_name_source_normalized(self._get_mapping_col(col)) diff --git a/tests/unit/reconcile/query_builder/test_aggregate_query.py b/tests/unit/reconcile/query_builder/test_aggregate_query.py new file mode 100644 index 0000000000..6e0a5df4d7 --- /dev/null +++ b/tests/unit/reconcile/query_builder/test_aggregate_query.py @@ -0,0 +1,65 @@ +import pytest + +from databricks.labs.lakebridge.reconcile.query_builder.aggregate_query import AggregateQueryBuilder +from databricks.labs.lakebridge.reconcile.recon_config import Aggregate, Table +from databricks.labs.lakebridge.transpiler.sqlglot.dialect_utils import get_dialect + + +def _build_table(aggregates: list[Aggregate]) -> Table: + return Table(source_name="supplier", target_name="target_supplier", aggregates=aggregates) + + +def test_count_star_emits_unquoted_star(fake_databricks_datasource, normalize_config_service): + """Aggregate(agg_columns=["*"], type="count") must produce COUNT(*), not COUNT(`*`).""" + table_conf = _build_table([Aggregate(agg_columns=["*"], type="count")]) + normalized = normalize_config_service.normalize_recon_table_config(table_conf) + + rules = AggregateQueryBuilder( + normalized, [], "source", get_dialect("databricks"), fake_databricks_datasource + ).build_queries() + + assert len(rules) == 1 + sql = rules[0].query + assert "count(*)" in sql.lower() + # Regression: the column-name normalizer must not emit COUNT(`*`) + assert "count(`*`)" not in sql.lower() + # Alias of the aggregate column survives normalization + assert "source_count_*" in sql.lower() + + +def test_count_star_normalized_input_emits_unquoted_star(fake_databricks_datasource, normalize_config_service): + """The same fix must hold when agg_columns is already in ansi-normalized form (`*`).""" + table_conf = _build_table([Aggregate(agg_columns=["`*`"], type="count")]) + normalized = normalize_config_service.normalize_recon_table_config(table_conf) + + rules = AggregateQueryBuilder( + normalized, [], "target", get_dialect("databricks"), fake_databricks_datasource + ).build_queries() + + assert len(rules) == 1 + sql = rules[0].query + assert "count(*)" in sql.lower() + assert "count(`*`)" not in sql.lower() + + +def test_count_star_alongside_named_column(fake_databricks_datasource, normalize_config_service): + """COUNT(*) and COUNT() must coexist in a single aggregate query.""" + table_conf = _build_table([Aggregate(agg_columns=["*", "s_acctbal"], type="count")]) + normalized = normalize_config_service.normalize_recon_table_config(table_conf) + + rules = AggregateQueryBuilder( + normalized, [], "source", get_dialect("databricks"), fake_databricks_datasource + ).build_queries() + + sql = rules[0].query.lower() + assert "count(*)" in sql + assert "count(`s_acctbal`)" in sql + # The * branch must not pollute the named-column branch with backticks around * + assert "count(`*`)" not in sql + + +def test_star_with_non_count_aggregate_raises(normalize_config_service): + """`*` only makes sense with type='count'. Reject other types early at normalize time.""" + table_conf = _build_table([Aggregate(agg_columns=["*"], type="sum")]) + with pytest.raises(ValueError, match=r"'\*' is only supported with type='count'"): + normalize_config_service.normalize_recon_table_config(table_conf)