diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/catalog/AbstractDeltaCatalog.scala b/spark/src/main/scala/org/apache/spark/sql/delta/catalog/AbstractDeltaCatalog.scala index 12bcfaf8cce..e5ce325c9da 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/catalog/AbstractDeltaCatalog.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/catalog/AbstractDeltaCatalog.scala @@ -58,7 +58,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, QualifiedColTyp import org.apache.spark.sql.connector.catalog.{DelegatingCatalogExtension, Identifier, StagedTable, StagingTableCatalog, SupportsWrite, Table, TableCapability, TableCatalog, TableCatalogCapability, TableChange, V1Table} import org.apache.spark.sql.connector.catalog.TableCapability._ import org.apache.spark.sql.connector.catalog.TableChange._ -import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Literal, NamedReference, Transform} +import org.apache.spark.sql.connector.expressions.{ClusterByTransform => SparkClusterByTransform, FieldReference, IdentityTransform, Literal, NamedReference, Transform} import org.apache.spark.sql.connector.write.{LogicalWriteInfo, SupportsTruncate, V1Write, WriteBuilder} import org.apache.spark.sql.execution.datasources.{DataSource, PartitioningUtils} import org.apache.spark.sql.internal.SQLConf @@ -789,6 +789,15 @@ class AbstractDeltaCatalog extends DelegatingCatalogExtension } clusterBySpec = Some(ClusterBySpec(columnNames)) + // Spark 4.0+ DataFrameWriterV2.clusterBy() (PySpark / Scala API) passes Spark's real + // ClusterByTransform here instead of Delta's TempClusterByTransform, since OSS Spark now + // implements the API natively. + case SparkClusterByTransform(columnNames) => + if (clusterBySpec.nonEmpty) { + throw SparkException.internalError("Cannot have multiple cluster by transforms.") + } + clusterBySpec = Some(ClusterBySpec(columnNames)) + case transform => throw DeltaErrors.operationNotSupportedException(s"Partitioning by expressions") } diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/catalog/DeltaTableV2.scala b/spark/src/main/scala/org/apache/spark/sql/delta/catalog/DeltaTableV2.scala index 9508477ad28..a9abff0bb80 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/catalog/DeltaTableV2.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/catalog/DeltaTableV2.scala @@ -46,7 +46,7 @@ import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table, TableCapabi import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import org.apache.spark.sql.connector.catalog.TableCapability._ import org.apache.spark.sql.connector.catalog.V1Table -import org.apache.spark.sql.connector.expressions._ +import org.apache.spark.sql.connector.expressions.{ClusterByTransform => SparkClusterByTransform, _} import org.apache.spark.sql.connector.write.{LogicalWriteInfo, SupportsDynamicOverwrite, SupportsOverwrite, SupportsTruncate, V1Write, WriteBuilder} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.LogicalRelation @@ -233,9 +233,19 @@ class DeltaTableV2 private( override def schema(): StructType = tableSchema override def partitioning(): Array[Transform] = { - initialSnapshot.metadata.partitionColumns.map { col => - new IdentityTransform(new FieldReference(Seq(col))) - }.toArray + val partitionTransforms = initialSnapshot.metadata.partitionColumns.map { col => + new IdentityTransform(new FieldReference(Seq(col))): Transform + } + // Expose clustering as a Spark `ClusterByTransform` so that callers that compare table + // partitioning to user-provided transforms (e.g. Spark's DataFrameWriter + // `checkPartitioningMatchesV2Table`) can see Delta's clustering columns. Without this, + // `df.write.format("delta").clusterBy(...).mode("append").saveAsTable(t)` fails on any + // existing clustered Delta table because Spark passes a `ClusterByTransform` while Delta + // would otherwise return an empty array. + val clusteringTransforms = clusterBySpec.toSeq.map { spec => + SparkClusterByTransform(spec.columnNames): Transform + } + (partitionTransforms ++ clusteringTransforms).toArray } override def properties(): ju.Map[String, String] = { diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/commands/WriteIntoDelta.scala b/spark/src/main/scala/org/apache/spark/sql/delta/commands/WriteIntoDelta.scala index 0f71bbb9687..556bc832114 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/commands/WriteIntoDelta.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/commands/WriteIntoDelta.scala @@ -87,7 +87,8 @@ case class WriteIntoDelta( override val data: DataFrame, val catalogTableOpt: Option[CatalogTable] = None, schemaInCatalog: Option[StructType] = None, - isInsertReplaceUsingByName: Boolean = false + isInsertReplaceUsingByName: Boolean = false, + clusterBySpec: Option[ClusterBySpec] = None ) extends LeafRunnableCommand with ImplicitMetadataOperation @@ -109,7 +110,7 @@ case class WriteIntoDelta( } val taggedCommitData = writeAndReturnCommitData( - txn, sparkSession + txn, sparkSession, clusterBySpec ) val operation = DeltaOperations.Write( mode = mode, diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/sources/DeltaDataSource.scala b/spark/src/main/scala/org/apache/spark/sql/delta/sources/DeltaDataSource.scala index 4e839fbb5a5..4aa9c721a0b 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/sources/DeltaDataSource.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/sources/DeltaDataSource.scala @@ -32,6 +32,8 @@ import org.apache.spark.sql.delta.commands.{ import org.apache.spark.sql.delta.commands.cdc.CDCReader import org.apache.spark.sql.delta.logging.DeltaLogKeys import org.apache.spark.sql.delta.metering.DeltaLogging +import org.apache.spark.sql.delta.skipping.clustering.ClusteredTableUtils +import org.apache.spark.sql.delta.skipping.clustering.temp.ClusterBySpec import org.apache.spark.sql.delta.util.{PartitionUtils, Utils} import org.apache.hadoop.fs.Path import org.json4s.{Formats, NoTypeHints} @@ -44,6 +46,7 @@ import org.apache.spark.sql.catalyst.expressions.{EqualTo, Expression, Literal} import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.connector.catalog.{SupportsV1OverwriteWithSaveAsTable, Table, TableProvider} import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.execution.datasources.DataSourceUtils import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.sources._ import org.apache.spark.sql.streaming.OutputMode @@ -252,6 +255,21 @@ class DeltaDataSource .map(DeltaDataSource.decodePartitioningColumns) .getOrElse(Nil) + // `df.write.format("delta").clusterBy(...).save(path)` encodes clustering columns into + // options under DataSourceUtils.CLUSTERING_COLUMNS_KEY. Decode them here so the + // resulting Delta table records the clustering columns (create) and so that appending + // to an existing clustered table with mismatching clusterBy fails fast. + val clusterBySpec = parameters.get(DataSourceUtils.CLUSTERING_COLUMNS_KEY) + .map(DataSourceUtils.decodePartitioningColumns) + .filter(_.nonEmpty) + .map(ClusterBySpec.fromColumnNames) + // When a new clustered table is created via `df.write.format("delta").clusterBy(...).save`, + // we must enable the clustering table feature in the protocol; otherwise commit fails with + // DELTA_DOMAIN_METADATA_NOT_SUPPORTED. + val clusteringFeatureProps = clusterBySpec + .map(_ => ClusteredTableUtils.getTableFeatureProperties(Map.empty)) + .getOrElse(Map.empty[String, String]) + val deltaLog = Utils.getDeltaLogFromTableOrPath( sqlContext.sparkSession, catalogTableOpt, new Path(path), parameters) val deltaOptions = new DeltaOptions(parameters, sqlContext.sparkSession.sessionState.conf) @@ -271,11 +289,12 @@ class DeltaDataSource options = deltaOptions, partitionColumns = partitionColumns, configuration = DeltaConfigs.validateConfigurations( - parameters.filterKeys(_.startsWith("delta.")).toMap), + parameters.filterKeys(_.startsWith("delta.")).toMap) ++ clusteringFeatureProps, data = data, // empty catalogTable is acceptable as the code path is only for path based writes // (df.write.save("path")) which does not need to use/update catalog - catalogTableOpt = None) + catalogTableOpt = None, + clusterBySpec = clusterBySpec) val finalWriteCmd = if (deltaOptions.isReplaceOnOrUsingDefined) { DeltaInsertReplaceOnOrUsingCommand.createCmdForSaveAndSaveAsTable( deltaTable = DeltaTableV2( diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/skipping/clustering/ClusteredTableDDLSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/skipping/clustering/ClusteredTableDDLSuite.scala index 0bfd7b98909..dcb61ff1676 100644 --- a/spark/src/test/scala/org/apache/spark/sql/delta/skipping/clustering/ClusteredTableDDLSuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/delta/skipping/clustering/ClusteredTableDDLSuite.scala @@ -1252,8 +1252,117 @@ trait ClusteredTableDDLWithV2Base trait ClusteredTableDDLWithV2 extends ClusteredTableDDLWithV2Base +trait ClusteredTableDataFrameWriterV2Suite + extends ClusteredTableCreateOrReplaceDDLSuite + with SharedSparkSession { + import testImplicits._ + + override protected def supportedClauses: Seq[String] = Seq("CREATE", "REPLACE") + + test("DataFrameWriterV2: clusterBy on create/replace") { + val testTable = this.testTable + withTable(testTable) { + Seq((1, "US", "F"), (2, "DE", "M")).toDF("id", "country", "gender") + .writeTo(testTable).using("delta").clusterBy("country", "gender").create() + verifyClusteringColumns(TableIdentifier(testTable), Seq("country", "gender")) + + Seq((3, "JP", "F")).toDF("id", "country", "gender") + .writeTo(testTable).using("delta").clusterBy("gender").replace() + verifyClusteringColumns(TableIdentifier(testTable), Seq("gender")) + + Seq((4, "FR", "M")).toDF("id", "country", "gender") + .writeTo(testTable).using("delta").clusterBy("country").createOrReplace() + verifyClusteringColumns(TableIdentifier(testTable), Seq("country")) + } + } + + test("DataFrameWriterV2: clusterBy with nested column") { + val testTable = this.testTable + withTable(testTable) { + spark.sql( + s"""CREATE TABLE $testTable (id INT, info STRUCT) USING delta""") + spark.table(testTable).limit(0) + .writeTo(testTable).using("delta").clusterBy("info.city").replace() + verifyClusteringColumns(TableIdentifier(testTable), Seq("info.city")) + } + } + + test("DataFrameWriter v1: saveAsTable create with clusterBy") { + val testTable = this.testTable + withTable(testTable) { + Seq((1, "US", "F"), (2, "DE", "M")).toDF("id", "country", "gender").write + .format("delta").clusterBy("country", "gender").saveAsTable(testTable) + verifyClusteringColumns(TableIdentifier(testTable), Seq("country", "gender")) + } + } + + test("DataFrameWriter v1: save(path) create with clusterBy") { + withTempDir { dir => + val path = dir.getCanonicalPath + Seq((1, "US", "F"), (2, "DE", "M")).toDF("id", "country", "gender").write + .format("delta").clusterBy("country", "gender").save(path) + val (_, snapshot) = + DeltaLog.forTableWithSnapshot(spark, new org.apache.hadoop.fs.Path(path)) + val cols = + org.apache.spark.sql.delta.skipping.clustering.ClusteringColumnInfo + .extractLogicalNames(snapshot) + assert(cols === Seq("country", "gender"), + s"expected clustering columns Seq(country, gender), got $cols") + } + } + + test("DataFrameWriter v1: append with matching clusterBy is allowed") { + val testTable = this.testTable + withTable(testTable) { + Seq((1, "US", "F")).toDF("id", "country", "gender").write + .format("delta").clusterBy("country", "gender").saveAsTable(testTable) + Seq((2, "DE", "M")).toDF("id", "country", "gender").write + .format("delta").mode("append").clusterBy("country", "gender").saveAsTable(testTable) + verifyClusteringColumns(TableIdentifier(testTable), Seq("country", "gender")) + assert(spark.table(testTable).count() === 2) + } + } + + test("DataFrameWriter v1: append with mismatching clusterBy throws") { + val testTable = this.testTable + withTable(testTable) { + Seq((1, "US", "F")).toDF("id", "country", "gender").write + .format("delta").clusterBy("country", "gender").saveAsTable(testTable) + val ex = intercept[Exception] { + Seq((2, "DE", "M")).toDF("id", "country", "gender").write + .format("delta").mode("append").clusterBy("gender").saveAsTable(testTable) + } + val msg = Option(ex.getMessage).getOrElse("") + assert( + msg.toLowerCase.contains("cluster") && + (msg.contains("do not match") || msg.contains("does not match") || + msg.contains("mismatch") || msg.contains("different")), + s"expected clustering mismatch error, got: $msg") + } + } + + test("DataFrameWriter v1: save(path) append with mismatching clusterBy throws") { + withTempDir { dir => + val path = dir.getCanonicalPath + Seq((1, "US", "F")).toDF("id", "country", "gender").write + .format("delta").clusterBy("country", "gender").save(path) + // Path-based append does NOT go through Spark's checkPartitioningMatchesV2Table; the + // mismatch must be caught by Delta's own validateClusteringColumnsInSnapshot. + val ex = intercept[Exception] { + Seq((2, "DE", "M")).toDF("id", "country", "gender").write + .format("delta").mode("append").clusterBy("gender").save(path) + } + val msg = Option(ex.getMessage).getOrElse("") + assert( + msg.toLowerCase.contains("cluster"), + s"expected clustering mismatch error, got: $msg") + } + } +} + trait ClusteredTableDDLDataSourceV2SuiteBase extends ClusteredTableDDLWithV2 + with ClusteredTableDataFrameWriterV2Suite with ClusteredTableDDLSuite { test("Create clustered table from external location, " + "location has clustered table, schema not specified, cluster by not specified") {