diff --git a/src/databricks/labs/lakebridge/reconcile/query_builder/aggregate_query.py b/src/databricks/labs/lakebridge/reconcile/query_builder/aggregate_query.py index b4a802f833..3f516e3fec 100644 --- a/src/databricks/labs/lakebridge/reconcile/query_builder/aggregate_query.py +++ b/src/databricks/labs/lakebridge/reconcile/query_builder/aggregate_query.py @@ -58,6 +58,22 @@ def _get_mapping_cols_with_alias(self, cols_list: list[str], agg_type: str): """ cols_with_mapping: list[exp.Expression] = [] for col in cols_list: + # Special-case the literal star for COUNT(*). + # The column-name normalizer would render "*" as the backtick-quoted + # identifier `*`, producing COUNT(`*`) which fails at SQL analysis. + # NormalizeReconConfigService also rewrites entries in agg_columns to + # the ansi-normalized form, so by the time we get here the value can + # be the literal "*" or the wrapped "`*`". Bypass identifier handling + # in either case and emit a Star expression so the downstream builder + # produces COUNT(*). + if DialectUtils.unnormalize_identifier(col) == "*" and agg_type.lower() == "count": + cols_with_mapping.append( + exp.Alias( + this=exp.Star(), + alias=exp.Identifier(this=f"{agg_type.lower()}<#>*", quoted=False), + ) + ) + continue column_expr = build_column( this=( self._build_column_name_source_normalized(self._get_mapping_col(col)) diff --git a/tests/unit/reconcile/query_builder/test_aggregate_query.py b/tests/unit/reconcile/query_builder/test_aggregate_query.py new file mode 100644 index 0000000000..b08810ce46 --- /dev/null +++ b/tests/unit/reconcile/query_builder/test_aggregate_query.py @@ -0,0 +1,71 @@ +from databricks.labs.lakebridge.reconcile.query_builder.aggregate_query import AggregateQueryBuilder +from databricks.labs.lakebridge.reconcile.recon_config import Aggregate, Table +from databricks.labs.lakebridge.transpiler.sqlglot.dialect_utils import get_dialect + + +def _build_table(aggregates: list[Aggregate]) -> Table: + return Table(source_name="supplier", target_name="target_supplier", aggregates=aggregates) + + +def test_count_star_emits_unquoted_star(fake_databricks_datasource, normalize_config_service): + """Aggregate(agg_columns=["*"], type="count") must produce COUNT(*), not COUNT(`*`).""" + table_conf = _build_table([Aggregate(agg_columns=["*"], type="count")]) + normalized = normalize_config_service.normalize_recon_table_config(table_conf) + + rules = AggregateQueryBuilder( + normalized, [], "source", get_dialect("databricks"), fake_databricks_datasource + ).build_queries() + + assert len(rules) == 1 + sql = rules[0].query + assert "count(*)" in sql.lower() + # Regression: the column-name normalizer must not emit COUNT(`*`) + assert "count(`*`)" not in sql.lower() + # Alias of the aggregate column survives normalization + assert "source_count_*" in sql.lower() + + +def test_count_star_normalized_input_emits_unquoted_star(fake_databricks_datasource, normalize_config_service): + """The same fix must hold when agg_columns is already in ansi-normalized form (`*`).""" + table_conf = _build_table([Aggregate(agg_columns=["`*`"], type="count")]) + normalized = normalize_config_service.normalize_recon_table_config(table_conf) + + rules = AggregateQueryBuilder( + normalized, [], "target", get_dialect("databricks"), fake_databricks_datasource + ).build_queries() + + assert len(rules) == 1 + sql = rules[0].query + assert "count(*)" in sql.lower() + assert "count(`*`)" not in sql.lower() + + +def test_count_star_alongside_named_column(fake_databricks_datasource, normalize_config_service): + """COUNT(*) and COUNT(