diff --git a/build.sbt b/build.sbt
index 80c4f7d74e..210ed0eeb7 100644
--- a/build.sbt
+++ b/build.sbt
@@ -115,7 +115,7 @@ val scalaCollectionCompatVersion = "2.14.0"
 val scalaMacrosVersion = "2.1.1"
 val scalatestVersion = "3.2.19"
 val shapelessVersion = "2.3.13"
-val sparkeyVersion = "3.5.1"
+val sparkeyVersion = "3.7.0"
 val tensorFlowVersion = "1.1.0"
 val tensorFlowMetadataVersion = "1.16.1"
 val testContainersVersion = "0.44.1"
@@ -330,6 +330,17 @@ ThisBuild / mimaBinaryIssueFilters ++= Seq(
   ),
   ProblemFilters.exclude[MissingClassProblem](
     "org.apache.beam.sdk.extensions.sorter.BufferedExternalSorter$Options"
+  ),
+  // Private method sparkeySideInput(String, Function1) replaced by sparkeySideInputWithConfig
+  ProblemFilters.exclude[IncompatibleMethTypeProblem](
+    "com.spotify.scio.extra.sparkey.package#SparkeyScioContext.sparkeySideInput$extension"
+  ),
+  // Scala 2.12 generates different extension method names
+  ProblemFilters.exclude[DirectMissingMethodProblem](
+    "com.spotify.scio.extra.sparkey.package#SparkeyScioContext.sparkeySideInput$extension0"
+  ),
+  ProblemFilters.exclude[DirectMissingMethodProblem](
+    "com.spotify.scio.extra.sparkey.package#SparkeyScioContext.sparkeySideInput$extension1"
   )
 )
 
diff --git a/scio-extra/src/main/scala/com/spotify/scio/extra/sparkey/SparkeyReadConfig.scala b/scio-extra/src/main/scala/com/spotify/scio/extra/sparkey/SparkeyReadConfig.scala
new file mode 100644
index 0000000000..9809cec10d
--- /dev/null
+++ b/scio-extra/src/main/scala/com/spotify/scio/extra/sparkey/SparkeyReadConfig.scala
@@ -0,0 +1,41 @@
+/*
+ * Copyright 2026 Spotify AB.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package com.spotify.scio.extra.sparkey
+
+import com.spotify.sparkey.LoadMode
+
+/**
+ * Configuration for reading Sparkey side inputs.
+ *
+ * @param loadMode
+ *   page cache prefetch mode for mmap-backed shards. Ignored for heap-backed shards.
+ * @param heapBudgetBytes
+ *   maximum bytes to read into JVM heap across all shards in this side input. Shards are loaded
+ *   largest-first to maximize heap utilization. Shards that don't fit within the remaining budget
+ *   fall back to memory-mapped files. Default is 0 (all mmap, matching current behavior).
+ *   Any value <= 0 disables heap loading.
+ */
+final case class SparkeyReadConfig(
+  loadMode: LoadMode = LoadMode.NONE,
+  heapBudgetBytes: Long = 0
+)
+
+object SparkeyReadConfig {
+  /** Default configuration: `LoadMode.NONE` and a heap budget of 0. */
+  val Default: SparkeyReadConfig = SparkeyReadConfig()
+}
diff --git a/scio-extra/src/main/scala/com/spotify/scio/extra/sparkey/SparkeyUri.scala b/scio-extra/src/main/scala/com/spotify/scio/extra/sparkey/SparkeyUri.scala
index f0d4ed4e0d..697c2b735f 100644
--- a/scio-extra/src/main/scala/com/spotify/scio/extra/sparkey/SparkeyUri.scala
+++ b/scio-extra/src/main/scala/com/spotify/scio/extra/sparkey/SparkeyUri.scala
@@ -29,6 +29,8 @@ import org.apache.beam.sdk.options.PipelineOptions
 import java.nio.file.Path
 import java.util.UUID
 
+import org.slf4j.LoggerFactory
+
 import scala.collection.mutable
 import scala.jdk.CollectionConverters._
 
@@ -76,11 +78,22 @@ case class SparkeyUri(path: String) {
     rfu.download(downloadPaths.asJava).asScala
   }
 
-  def getReader(rfu: RemoteFileUtil): SparkeyReader = {
+  /** Open a reader for this URI using [[SparkeyReadConfig.Default]]. */
+  def getReader(rfu: RemoteFileUtil): SparkeyReader =
+    getReader(rfu, SparkeyReadConfig.Default)
+
+  /**
+   * Open a reader for this URI. An unsharded URI is opened on heap only when
+   * `config.heapBudgetBytes` is positive and covers the combined index and log file size;
+   * for a sharded URI the budget is passed on to `ShardedSparkeyUri.localReadersByShard`.
+   */
+  def getReader(rfu: RemoteFileUtil, config: SparkeyReadConfig): SparkeyReader = {
     if (!isSharded) {
-      val path =
+      val file =
         if (isLocal) new File(basePath) else downloadRemoteUris(Seq(basePath), rfu).head.toFile
-      Sparkey.open(path)
+      val totalSize = Sparkey.getIndexFile(file).length() + Sparkey.getLogFile(file).length()
+      val useHeap = config.heapBudgetBytes > 0 && totalSize <= config.heapBudgetBytes
+      Sparkey.reader().file(file).useHeap(useHeap).open()
     } else {
       val (basePaths, numShards) =
         ShardedSparkeyUri.basePathsAndCount(EmptyMatchTreatment.DISALLOW, globExpression)
@@ -91,7 +104,10 @@ case class SparkeyUri(path: String) {
           .map(_.toAbsolutePath.toString.replaceAll("\\.sp[il]$", ""))
           .toSet
       }
-      new ShardedSparkeyReader(ShardedSparkeyUri.localReadersByShard(paths), numShards)
+      new ShardedSparkeyReader(
+        ShardedSparkeyUri.localReadersByShard(paths, config.heapBudgetBytes),
+        numShards
+      )
     }
   }
 
@@ -129,11 +145,80 @@ private[sparkey] object ShardedSparkeyUri {
   private[sparkey] def localReadersByShard(
     localBasePaths: Iterable[String]
   ): Map[Short, SparkeyReader] =
-    localBasePaths.iterator.map { path =>
-      val (shardIndex, _) = shardsFromPath(path)
-      val reader = Sparkey.open(new File(path + ".spi"))
-      (shardIndex, reader)
-    }.toMap
+    localReadersByShard(localBasePaths, 0)
+
+  /**
+   * Open one reader per shard. With a positive `heapBudgetBytes`, shards are visited
+   * largest-first and every shard whose index + log size fits in the remaining budget is
+   * opened on heap; all other shards are opened without heap loading.
+   */
+  private[sparkey] def localReadersByShard(
+    localBasePaths: Iterable[String],
+    heapBudgetBytes: Long
+  ): Map[Short, SparkeyReader] = {
+    if (heapBudgetBytes <= 0) {
+      // No heap budget — all mmap (current behavior)
+      localBasePaths.iterator.map { path =>
+        val (shardIndex, _) = shardsFromPath(path)
+        val reader = Sparkey.open(new File(path + ".spi"))
+        (shardIndex, reader)
+      }.toMap
+    } else {
+      // Sort shards largest-first for greedy budget allocation
+      val shardsWithSize = localBasePaths
+        .map { path =>
+          val (shardIndex, _) = shardsFromPath(path)
+          val file = new File(path + ".spi")
+          val size = Sparkey.getIndexFile(file).length() + Sparkey.getLogFile(file).length()
+          (shardIndex, path, size)
+        }
+        .toSeq
+        .sortBy(-_._3)
+
+      case class Acc(
+        readers: List[(Short, SparkeyReader)] = Nil,
+        remainingBudget: Long = heapBudgetBytes,
+        heapShards: Int = 0,
+        mmapShards: Int = 0
+      )
+
+      val result = shardsWithSize.foldLeft(Acc()) { case (acc, (shardIndex, path, size)) =>
+        val file = new File(path + ".spi")
+        val useHeap = size <= acc.remainingBudget
+        val reader = Sparkey.reader().file(file).useHeap(useHeap).open()
+        if (useHeap) {
+          acc.copy(
+            readers = (shardIndex, reader) :: acc.readers,
+            remainingBudget = acc.remainingBudget - size,
+            heapShards = acc.heapShards + 1
+          )
+        } else {
+          acc.copy(
+            readers = (shardIndex, reader) :: acc.readers,
+            mmapShards = acc.mmapShards + 1
+          )
+        }
+      }
+
+      val logger = LoggerFactory.getLogger(classOf[ShardedSparkeyReader])
+      val totalSize = shardsWithSize.map(_._3).sum
+      val heapBytes = heapBudgetBytes - result.remainingBudget
+      logger.info(
+        "Opened {} shards: {} on heap ({} bytes), {} mmap. " +
+          "Total data: {} bytes, heap budget: {} bytes",
+        Array[AnyRef](
+          Integer.valueOf(result.heapShards + result.mmapShards),
+          Integer.valueOf(result.heapShards),
+          java.lang.Long.valueOf(heapBytes),
+          Integer.valueOf(result.mmapShards),
+          java.lang.Long.valueOf(totalSize),
+          java.lang.Long.valueOf(heapBudgetBytes)
+        ): _*
+      )
+
+      result.readers.toMap
+    }
+  }
 
   def basePathsAndCount(
     emptyMatchTreatment: EmptyMatchTreatment,
diff --git a/scio-extra/src/main/scala/com/spotify/scio/extra/sparkey/instances/ShardedSparkeyReader.scala b/scio-extra/src/main/scala/com/spotify/scio/extra/sparkey/instances/ShardedSparkeyReader.scala
index af372b2c2d..30b3cf5565 100644
--- a/scio-extra/src/main/scala/com/spotify/scio/extra/sparkey/instances/ShardedSparkeyReader.scala
+++ b/scio-extra/src/main/scala/com/spotify/scio/extra/sparkey/instances/ShardedSparkeyReader.scala
@@ -18,7 +18,7 @@ package com.spotify.scio.extra.sparkey.instances
 
 import java.util
 
-import com.spotify.sparkey.{IndexHeader, LogHeader, SparkeyReader}
+import com.spotify.sparkey.{IndexHeader, LoadMode, LoadResult, LogHeader, SparkeyReader}
 
 import scala.util.hashing.MurmurHash3
 import scala.jdk.CollectionConverters._
@@ -81,6 +81,10 @@ class ShardedSparkeyReader(val sparkeys: Map[Short, SparkeyReader], val numShard
   override def iterator(): util.Iterator[SparkeyReader.Entry] =
     sparkeys.values.map(_.iterator.asScala).reduce(_ ++ _).asJava
 
+  /** Run `load` on every shard reader and combine the per-shard results into one. */
+  override def load(mode: LoadMode, executor: java.util.concurrent.Executor): LoadResult =
+    LoadResult.combine(sparkeys.values.map(_.load(mode, executor)).toArray: _*)
+
   override def getLoadedBytes: Long = sparkeys.valuesIterator.map(_.getLoadedBytes).sum
 
   override def getTotalBytes: Long = sparkeys.valuesIterator.map(_.getTotalBytes).sum
diff --git a/scio-extra/src/main/scala/com/spotify/scio/extra/sparkey/package.scala b/scio-extra/src/main/scala/com/spotify/scio/extra/sparkey/package.scala
index f14b27dcd5..c2d54f7406 100644
--- a/scio-extra/src/main/scala/com/spotify/scio/extra/sparkey/package.scala
+++ b/scio-extra/src/main/scala/com/spotify/scio/extra/sparkey/package.scala
@@ -113,7 +113,11 @@ package object sparkey extends SparkeyReaderInstances with SparkeyCoders {
   /** Enhanced version of [[ScioContext]] with Sparkey methods. */
   implicit class SparkeyScioContext(private val self: ScioContext) extends AnyVal {
 
-    private def sparkeySideInput[T](basePath: String, mapFn: SparkeyReader => T): SideInput[T] = {
+    private def sparkeySideInputWithConfig[T](
+      basePath: String,
+      mapFn: SparkeyReader => T,
+      config: SparkeyReadConfig
+    ): SideInput[T] = {
       if (self.isTest) {
         val id = self.testId.get
         val view = TestDataManager
@@ -126,7 +130,7 @@ package object sparkey extends SparkeyReaderInstances with SparkeyCoders {
         val view: PCollectionView[SparkeyUri] = self
           .parallelize(paths)
           .applyInternal(View.asSingleton())
-        new SparkeySideInput(view, mapFn)
+        new SparkeySideInput(view, mapFn, config)
       }
     }
 
@@ -138,7 +142,21 @@ package object sparkey extends SparkeyReaderInstances with SparkeyCoders {
      */
     @experimental
     def sparkeySideInput(basePath: String): SideInput[SparkeyReader] =
-      sparkeySideInput(basePath, identity)
+      sparkeySideInputWithConfig(basePath, identity, SparkeyReadConfig.Default)
+
+    /**
+     * Create a SideInput of `SparkeyReader` from a [[SparkeyUri]] base path, to be used with
+     * [[com.spotify.scio.values.SCollection.withSideInputs SCollection.withSideInputs]]. If the
+     * provided base path ends with "*", it will be treated as a sharded collection of Sparkey
+     * files.
+     *
+     * @param config
+     *   read configuration including page cache prefetch mode and heap budget for loading shards
+     *   into JVM heap memory instead of memory-mapped files.
+     */
+    @experimental
+    def sparkeySideInput(basePath: String, config: SparkeyReadConfig): SideInput[SparkeyReader] =
+      sparkeySideInputWithConfig(basePath, identity, config)
 
     /**
      * Create a SideInput of `TypedSparkeyReader` from a [[SparkeyUri]] base path, to be used with
@@ -152,14 +170,15 @@ package object sparkey extends SparkeyReaderInstances with SparkeyCoders {
       decoder: Array[Byte] => T,
       cache: Cache[String, T] = null
     ): SideInput[TypedSparkeyReader[T]] =
-      sparkeySideInput(
+      sparkeySideInputWithConfig(
         basePath,
         reader =>
           new TypedSparkeyReader[T](
             reader,
             decoder,
             Option(cache).getOrElse(Cache.noOp[String, T])
-          )
+          ),
+        SparkeyReadConfig.Default
       )
 
     /**
@@ -171,7 +190,11 @@ package object sparkey extends SparkeyReaderInstances with SparkeyCoders {
       basePath: String,
       cache: Cache[String, String]
     ): SideInput[CachedStringSparkeyReader] =
-      sparkeySideInput(basePath, reader => new CachedStringSparkeyReader(reader, cache))
+      sparkeySideInputWithConfig(
+        basePath,
+        reader => new CachedStringSparkeyReader(reader, cache),
+        SparkeyReadConfig.Default
+      )
   }
 
   /** Enhanced version of [[com.spotify.scio.values.SCollection SCollection]] with Sparkey methods. */
@@ -476,7 +499,11 @@ package object sparkey extends SparkeyReaderInstances with SparkeyCoders {
      */
     @experimental
     def asSparkeySideInput: SideInput[SparkeyReader] =
-      new SparkeySideInput(self.applyInternal(View.asSingleton()), identity)
+      new SparkeySideInput(
+        self.applyInternal(View.asSingleton()),
+        identity,
+        SparkeyReadConfig.Default
+      )
 
     /**
      * Convert this SCollection to a SideInput of `SparkeyReader`, to be used with
@@ -490,7 +517,8 @@ package object sparkey extends SparkeyReaderInstances with SparkeyCoders {
     ): SideInput[TypedSparkeyReader[T]] = {
       new SparkeySideInput(
         self.applyInternal(View.asSingleton()),
-        reader => new TypedSparkeyReader[T](reader, decoder, cache)
+        reader => new TypedSparkeyReader[T](reader, decoder, cache),
+        SparkeyReadConfig.Default
       )
     }
 
@@ -504,7 +532,8 @@ package object sparkey extends SparkeyReaderInstances with SparkeyCoders {
     def asTypedSparkeySideInput[T](decoder: Array[Byte] => T): SideInput[TypedSparkeyReader[T]] =
       new SparkeySideInput(
         self.applyInternal(View.asSingleton()),
-        reader => new TypedSparkeyReader[T](reader, decoder, Cache.noOp)
+        reader => new TypedSparkeyReader[T](reader, decoder, Cache.noOp),
+        SparkeyReadConfig.Default
       )
 
     /**
@@ -517,7 +546,8 @@ package object sparkey extends SparkeyReaderInstances with SparkeyCoders {
     ): SideInput[CachedStringSparkeyReader] =
       new SparkeySideInput(
         self.applyInternal(View.asSingleton()),
-        reader => new CachedStringSparkeyReader(reader, cache)
+        reader => new CachedStringSparkeyReader(reader, cache),
+        SparkeyReadConfig.Default
       )
   }
 
@@ -542,15 +572,19 @@ package object sparkey extends SparkeyReaderInstances with SparkeyCoders {
 
   private class SparkeySideInput[T](
     val view: PCollectionView[SparkeyUri],
-    mapFn: SparkeyReader => T
+    mapFn: SparkeyReader => T,
+    config: SparkeyReadConfig
   ) extends SideInput[T] {
+    // Binary-compatible constructor (pre-3.6.1)
+    def this(view: PCollectionView[SparkeyUri], mapFn: SparkeyReader => T) =
+      this(view, mapFn, SparkeyReadConfig.Default)
     override def updateCacheOnGlobalWindow: Boolean = false
-    override def get[I, O](context: DoFn[I, O]#ProcessContext): T =
-      mapFn(
-        SparkeySideInput.checkMemory(
-          context.sideInput(view).getReader(RemoteFileUtil.create(context.getPipelineOptions))
-        )
-      )
+    override def get[I, O](context: DoFn[I, O]#ProcessContext): T = {
+      val uri = context.sideInput(view)
+      val rfu = RemoteFileUtil.create(context.getPipelineOptions)
+      val reader = SparkeySideInput.getOrCreateReader(uri, rfu, config)
+      mapFn(reader)
+    }
   }
 
   /**
@@ -585,6 +619,51 @@ package object sparkey extends SparkeyReaderInstances with SparkeyCoders {
 
   private object SparkeySideInput {
     private val logger = LoggerFactory.getLogger(this.getClass)
+
+    // Cache of opened readers held by this object. Keyed by path AND read config so that side
+    // inputs over the same files with different configs do not share the first-loaded reader.
+    private val readerCache =
+      new java.util.concurrent.ConcurrentHashMap[
+        (String, SparkeyReadConfig),
+        java.util.concurrent.CompletableFuture[SparkeyReader]
+      ]()
+
+    /**
+     * Return the cached reader for `uri` and `config`, opening and loading it on first use.
+     * Concurrent callers for the same key wait on the same future.
+     *
+     * A failed load is evicted from the cache so that a later call can retry, and the
+     * underlying failure is rethrown.
+     */
+    def getOrCreateReader(
+      uri: SparkeyUri,
+      rfu: RemoteFileUtil,
+      config: SparkeyReadConfig
+    ): SparkeyReader = {
+      val key = (uri.path, config)
+      val future = readerCache.computeIfAbsent(
+        key,
+        _ =>
+          // supplyAsync so computeIfAbsent returns quickly (it holds a bucket lock)
+          java.util.concurrent.CompletableFuture.supplyAsync { () =>
+            logger.info("Loading sparkey reader for {}", uri.path)
+            val reader = uri.getReader(rfu, config)
+            reader.load(config.loadMode)
+            checkMemory(reader)
+            reader
+          }
+      )
+      try {
+        future.get()
+      } catch {
+        case e: java.util.concurrent.ExecutionException =>
+          // Remove failed entry so next attempt can retry
+          readerCache.remove(key, future)
+          // Rethrow the underlying failure; fall back to the wrapper if no cause is attached
+          throw Option(e.getCause).getOrElse(e)
+      }
+    }
+
     def checkMemory(reader: SparkeyReader): SparkeyReader = {
       val memoryBytes = java.lang.management.ManagementFactory.getOperatingSystemMXBean
         .asInstanceOf[com.sun.management.OperatingSystemMXBean]
diff --git a/scio-extra/src/test/scala/com/spotify/scio/extra/sparkey/SparkeyTest.scala b/scio-extra/src/test/scala/com/spotify/scio/extra/sparkey/SparkeyTest.scala
index a8a128f3dd..1c674534bb 100644
--- a/scio-extra/src/test/scala/com/spotify/scio/extra/sparkey/SparkeyTest.scala
+++ b/scio-extra/src/test/scala/com/spotify/scio/extra/sparkey/SparkeyTest.scala
@@ -639,4 +639,51 @@ class SparkeyTest extends PipelineSpec {
     result should contain theSameElementsAs expectedOutput
   }
 
+  it should "support heap budget for sharded sparkey" in {
+    val (_, sparkeyUris) = runWithLocalOutput(_.parallelize(bigSideData).asSparkey(numShards = 4))
+    val sparkeyUri = sparkeyUris.head
+    val rfu = RemoteFileUtil.create(PipelineOptionsFactory.create())
+
+    // With no heap budget (default) — all mmap
+    val defaultReader = sparkeyUri.getReader(rfu, SparkeyReadConfig())
+    defaultReader.toMap shouldBe bigSideData.toMap
+    defaultReader.close()
+
+    // With large heap budget — all shards on heap
+    val heapReader = sparkeyUri.getReader(rfu, SparkeyReadConfig(heapBudgetBytes = Long.MaxValue))
+    heapReader.toMap shouldBe bigSideData.toMap
+    bigSideData.foreach { case (k, v) =>
+      heapReader.getAsString(k) shouldBe v
+    }
+    heapReader.close()
+
+    // With tiny heap budget (1 byte) — no shard fits, so all fall back to mmap
+    val tinyReader = sparkeyUri.getReader(rfu, SparkeyReadConfig(heapBudgetBytes = 1))
+    tinyReader.toMap shouldBe bigSideData.toMap
+    tinyReader.close()
+
+    FileUtils.deleteDirectory(new File(sparkeyUri.basePath))
+  }
+
+  it should "support heap budget for unsharded sparkey" in {
+    val (_, sparkeyUris) = runWithLocalOutput(_.parallelize(sideData).asSparkey)
+    val basePath = sparkeyUris.head.basePath
+    val sparkeyUri = SparkeyUri(basePath)
+    val rfu = RemoteFileUtil.create(PipelineOptionsFactory.create())
+
+    // Large budget — on heap
+    val heapReader = sparkeyUri.getReader(rfu, SparkeyReadConfig(heapBudgetBytes = Long.MaxValue))
+    heapReader.toMap shouldBe sideData.toMap
+    heapReader.close()
+
+    // Tiny budget — mmap
+    val mmapReader = sparkeyUri.getReader(rfu, SparkeyReadConfig(heapBudgetBytes = 1))
+    mmapReader.toMap shouldBe sideData.toMap
+    mmapReader.close()
+
+    for (ext <- Seq(".spi", ".spl")) {
+      new File(basePath + ext).delete()
+    }
+  }
+
 }