Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ val scalaCollectionCompatVersion = "2.14.0"
val scalaMacrosVersion = "2.1.1"
val scalatestVersion = "3.2.19"
val shapelessVersion = "2.3.13"
val sparkeyVersion = "3.5.1"
val sparkeyVersion = "3.7.0"
val tensorFlowVersion = "1.1.0"
val tensorFlowMetadataVersion = "1.16.1"
val testContainersVersion = "0.44.1"
Expand Down Expand Up @@ -330,6 +330,17 @@ ThisBuild / mimaBinaryIssueFilters ++= Seq(
),
ProblemFilters.exclude[MissingClassProblem](
"org.apache.beam.sdk.extensions.sorter.BufferedExternalSorter$Options"
),
// Private method sparkeySideInput(String, Function1) replaced by sparkeySideInputWithConfig
ProblemFilters.exclude[IncompatibleMethTypeProblem](
"com.spotify.scio.extra.sparkey.package#SparkeyScioContext.sparkeySideInput$extension"
),
// Scala 2.12 generates different extension method names
ProblemFilters.exclude[DirectMissingMethodProblem](
"com.spotify.scio.extra.sparkey.package#SparkeyScioContext.sparkeySideInput$extension0"
),
ProblemFilters.exclude[DirectMissingMethodProblem](
"com.spotify.scio.extra.sparkey.package#SparkeyScioContext.sparkeySideInput$extension1"
)
)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright 2026 Spotify AB.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package com.spotify.scio.extra.sparkey

import com.spotify.sparkey.LoadMode

/**
* Configuration for reading Sparkey side inputs.
*
* @param loadMode
* page cache prefetch mode for mmap-backed shards. Ignored for heap-backed shards.
* @param heapBudgetBytes
* maximum bytes to read into JVM heap across all shards in this side input. Shards are loaded
* largest-first to maximize heap utilization. Shards that don't fit within the remaining budget
 * fall back to memory-mapped files. Default is 0, meaning every shard is
 * memory-mapped (preserving the previous all-mmap behavior).
*/
// Plain immutable value object; see the scaladoc above for field semantics.
// The defaults (LoadMode.NONE, budget 0) reproduce the all-mmap behavior that
// existed before this config was introduced.
case class SparkeyReadConfig(
loadMode: LoadMode = LoadMode.NONE, // prefetch mode for mmap-backed shards; ignored for heap-backed shards
heapBudgetBytes: Long = 0 // 0 (or negative) means no heap loading: every shard stays memory-mapped
)

// Shared default instance so call sites don't allocate a fresh config each time.
object SparkeyReadConfig {
val Default: SparkeyReadConfig = SparkeyReadConfig()
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ import org.apache.beam.sdk.options.PipelineOptions

import java.nio.file.Path
import java.util.UUID
import org.slf4j.LoggerFactory

import scala.collection.mutable
import scala.jdk.CollectionConverters._

Expand Down Expand Up @@ -76,11 +78,16 @@ case class SparkeyUri(path: String) {
rfu.download(downloadPaths.asJava).asScala
}

def getReader(rfu: RemoteFileUtil): SparkeyReader = {
def getReader(rfu: RemoteFileUtil): SparkeyReader =
getReader(rfu, SparkeyReadConfig.Default)

def getReader(rfu: RemoteFileUtil, config: SparkeyReadConfig): SparkeyReader = {
if (!isSharded) {
val path =
val file =
if (isLocal) new File(basePath) else downloadRemoteUris(Seq(basePath), rfu).head.toFile
Sparkey.open(path)
val totalSize = Sparkey.getIndexFile(file).length() + Sparkey.getLogFile(file).length()
val useHeap = config.heapBudgetBytes > 0 && totalSize <= config.heapBudgetBytes
Sparkey.reader().file(file).useHeap(useHeap).open()
} else {
val (basePaths, numShards) =
ShardedSparkeyUri.basePathsAndCount(EmptyMatchTreatment.DISALLOW, globExpression)
Expand All @@ -91,7 +98,10 @@ case class SparkeyUri(path: String) {
.map(_.toAbsolutePath.toString.replaceAll("\\.sp[il]$", ""))
.toSet
}
new ShardedSparkeyReader(ShardedSparkeyUri.localReadersByShard(paths), numShards)
new ShardedSparkeyReader(
ShardedSparkeyUri.localReadersByShard(paths, config.heapBudgetBytes),
numShards
)
}
}

Expand Down Expand Up @@ -129,11 +139,75 @@ private[sparkey] object ShardedSparkeyUri {
private[sparkey] def localReadersByShard(
localBasePaths: Iterable[String]
): Map[Short, SparkeyReader] =
localBasePaths.iterator.map { path =>
val (shardIndex, _) = shardsFromPath(path)
val reader = Sparkey.open(new File(path + ".spi"))
(shardIndex, reader)
}.toMap
localReadersByShard(localBasePaths, 0)

/**
 * Opens one `SparkeyReader` per local shard, keyed by shard index.
 *
 * With a positive `heapBudgetBytes`, shards are packed onto the JVM heap
 * greedily, largest-first: each shard whose total on-disk size (index + log)
 * fits in the remaining budget is opened heap-backed; the rest fall back to
 * memory-mapped files. A non-positive budget opens every shard mmap-backed
 * (the original behavior).
 *
 * @param localBasePaths shard base paths (without the ".spi"/".spl" suffix)
 * @param heapBudgetBytes max bytes of shard data to load onto the heap; <= 0 disables heap loading
 */
private[sparkey] def localReadersByShard(
  localBasePaths: Iterable[String],
  heapBudgetBytes: Long
): Map[Short, SparkeyReader] = {
  if (heapBudgetBytes <= 0) {
    // No heap budget: every shard is memory-mapped (pre-existing behavior).
    localBasePaths.iterator.map { path =>
      val (shardIndex, _) = shardsFromPath(path)
      (shardIndex, Sparkey.open(new File(path + ".spi")))
    }.toMap
  } else {
    // Resolve each shard's index File and total size (index + log) exactly once,
    // then sort largest-first so the greedy packing below maximizes heap use.
    // (The original built `new File(path + ".spi")` twice per shard.)
    val shardsBySizeDesc = localBasePaths
      .map { path =>
        val (shardIndex, _) = shardsFromPath(path)
        val indexFile = new File(path + ".spi")
        val size = Sparkey.getIndexFile(indexFile).length() + Sparkey.getLogFile(indexFile).length()
        (shardIndex, indexFile, size)
      }
      .toSeq
      .sortBy(-_._3)

    // Accumulator for the greedy pass: opened readers plus bookkeeping for logging.
    case class Acc(
      readers: List[(Short, SparkeyReader)] = Nil,
      remainingBudget: Long = heapBudgetBytes,
      heapShards: Int = 0,
      mmapShards: Int = 0
    )

    val result = shardsBySizeDesc.foldLeft(Acc()) { case (acc, (shardIndex, indexFile, size)) =>
      // A shard goes on heap only if it fits in what's left of the budget.
      val useHeap = size <= acc.remainingBudget
      val reader = Sparkey.reader().file(indexFile).useHeap(useHeap).open()
      if (useHeap) {
        acc.copy(
          readers = (shardIndex, reader) :: acc.readers,
          remainingBudget = acc.remainingBudget - size,
          heapShards = acc.heapShards + 1
        )
      } else {
        acc.copy(
          readers = (shardIndex, reader) :: acc.readers,
          mmapShards = acc.mmapShards + 1
        )
      }
    }

    val logger = LoggerFactory.getLogger(classOf[ShardedSparkeyReader])
    val totalSize = shardsBySizeDesc.iterator.map(_._3).sum
    val heapBytes = heapBudgetBytes - result.remainingBudget
    // SLF4J varargs overload: values are boxed explicitly to avoid the
    // (String, Object, Object) two-arg overload being selected.
    logger.info(
      "Opened {} shards: {} on heap ({} bytes), {} mmap. " +
        "Total data: {} bytes, heap budget: {} bytes",
      Array[AnyRef](
        Integer.valueOf(result.heapShards + result.mmapShards),
        Integer.valueOf(result.heapShards),
        java.lang.Long.valueOf(heapBytes),
        Integer.valueOf(result.mmapShards),
        java.lang.Long.valueOf(totalSize),
        java.lang.Long.valueOf(heapBudgetBytes)
      ): _*
    )

    result.readers.toMap
  }
}

def basePathsAndCount(
emptyMatchTreatment: EmptyMatchTreatment,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package com.spotify.scio.extra.sparkey.instances

import java.util

import com.spotify.sparkey.{IndexHeader, LogHeader, SparkeyReader}
import com.spotify.sparkey.{IndexHeader, LoadMode, LoadResult, LogHeader, SparkeyReader}

import scala.util.hashing.MurmurHash3
import scala.jdk.CollectionConverters._
Expand Down Expand Up @@ -81,6 +81,9 @@ class ShardedSparkeyReader(val sparkeys: Map[Short, SparkeyReader], val numShard
override def iterator(): util.Iterator[SparkeyReader.Entry] =
sparkeys.values.map(_.iterator.asScala).reduce(_ ++ _).asJava

// Fans `load` out to every underlying shard reader and merges the per-shard
// results into a single LoadResult via LoadResult.combine.
override def load(mode: LoadMode, executor: java.util.concurrent.Executor): LoadResult = {
  val perShard = sparkeys.valuesIterator.map(_.load(mode, executor)).toArray
  LoadResult.combine(perShard: _*)
}

override def getLoadedBytes: Long = sparkeys.valuesIterator.map(_.getLoadedBytes).sum

override def getTotalBytes: Long = sparkeys.valuesIterator.map(_.getTotalBytes).sum
Expand Down
102 changes: 85 additions & 17 deletions scio-extra/src/main/scala/com/spotify/scio/extra/sparkey/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,11 @@ package object sparkey extends SparkeyReaderInstances with SparkeyCoders {
/** Enhanced version of [[ScioContext]] with Sparkey methods. */
implicit class SparkeyScioContext(private val self: ScioContext) extends AnyVal {

private def sparkeySideInput[T](basePath: String, mapFn: SparkeyReader => T): SideInput[T] = {
private def sparkeySideInputWithConfig[T](
basePath: String,
mapFn: SparkeyReader => T,
config: SparkeyReadConfig
): SideInput[T] = {
if (self.isTest) {
val id = self.testId.get
val view = TestDataManager
Expand All @@ -126,7 +130,7 @@ package object sparkey extends SparkeyReaderInstances with SparkeyCoders {
val view: PCollectionView[SparkeyUri] = self
.parallelize(paths)
.applyInternal(View.asSingleton())
new SparkeySideInput(view, mapFn)
new SparkeySideInput(view, mapFn, config)
}
}

Expand All @@ -138,7 +142,21 @@ package object sparkey extends SparkeyReaderInstances with SparkeyCoders {
*/
@experimental
def sparkeySideInput(basePath: String): SideInput[SparkeyReader] =
sparkeySideInput(basePath, identity)
sparkeySideInputWithConfig(basePath, identity, SparkeyReadConfig.Default)

/**
* Create a SideInput of `SparkeyReader` from a [[SparkeyUri]] base path, to be used with
* [[com.spotify.scio.values.SCollection.withSideInputs SCollection.withSideInputs]]. If the
* provided base path ends with "*", it will be treated as a sharded collection of Sparkey
* files.
*
* @param config
* read configuration including page cache prefetch mode and heap budget for loading shards
* into JVM heap memory instead of memory-mapped files.
*/
@experimental
def sparkeySideInput(basePath: String, config: SparkeyReadConfig): SideInput[SparkeyReader] =
sparkeySideInputWithConfig(basePath, identity, config)

/**
* Create a SideInput of `TypedSparkeyReader` from a [[SparkeyUri]] base path, to be used with
Expand All @@ -152,14 +170,15 @@ package object sparkey extends SparkeyReaderInstances with SparkeyCoders {
decoder: Array[Byte] => T,
cache: Cache[String, T] = null
): SideInput[TypedSparkeyReader[T]] =
sparkeySideInput(
sparkeySideInputWithConfig(
basePath,
reader =>
new TypedSparkeyReader[T](
reader,
decoder,
Option(cache).getOrElse(Cache.noOp[String, T])
)
),
SparkeyReadConfig.Default
)

/**
Expand All @@ -171,7 +190,11 @@ package object sparkey extends SparkeyReaderInstances with SparkeyCoders {
basePath: String,
cache: Cache[String, String]
): SideInput[CachedStringSparkeyReader] =
sparkeySideInput(basePath, reader => new CachedStringSparkeyReader(reader, cache))
sparkeySideInputWithConfig(
basePath,
reader => new CachedStringSparkeyReader(reader, cache),
SparkeyReadConfig.Default
)
}

/** Enhanced version of [[com.spotify.scio.values.SCollection SCollection]] with Sparkey methods. */
Expand Down Expand Up @@ -476,7 +499,11 @@ package object sparkey extends SparkeyReaderInstances with SparkeyCoders {
*/
@experimental
def asSparkeySideInput: SideInput[SparkeyReader] =
new SparkeySideInput(self.applyInternal(View.asSingleton()), identity)
new SparkeySideInput(
self.applyInternal(View.asSingleton()),
identity,
SparkeyReadConfig.Default
)

/**
* Convert this SCollection to a SideInput of `SparkeyReader`, to be used with
Expand All @@ -490,7 +517,8 @@ package object sparkey extends SparkeyReaderInstances with SparkeyCoders {
): SideInput[TypedSparkeyReader[T]] = {
new SparkeySideInput(
self.applyInternal(View.asSingleton()),
reader => new TypedSparkeyReader[T](reader, decoder, cache)
reader => new TypedSparkeyReader[T](reader, decoder, cache),
SparkeyReadConfig.Default
)
}

Expand All @@ -504,7 +532,8 @@ package object sparkey extends SparkeyReaderInstances with SparkeyCoders {
def asTypedSparkeySideInput[T](decoder: Array[Byte] => T): SideInput[TypedSparkeyReader[T]] =
new SparkeySideInput(
self.applyInternal(View.asSingleton()),
reader => new TypedSparkeyReader[T](reader, decoder, Cache.noOp)
reader => new TypedSparkeyReader[T](reader, decoder, Cache.noOp),
SparkeyReadConfig.Default
)

/**
Expand All @@ -517,7 +546,8 @@ package object sparkey extends SparkeyReaderInstances with SparkeyCoders {
): SideInput[CachedStringSparkeyReader] =
new SparkeySideInput(
self.applyInternal(View.asSingleton()),
reader => new CachedStringSparkeyReader(reader, cache)
reader => new CachedStringSparkeyReader(reader, cache),
SparkeyReadConfig.Default
)
}

Expand All @@ -542,15 +572,19 @@ package object sparkey extends SparkeyReaderInstances with SparkeyCoders {

private class SparkeySideInput[T](
val view: PCollectionView[SparkeyUri],
mapFn: SparkeyReader => T
mapFn: SparkeyReader => T,
config: SparkeyReadConfig
) extends SideInput[T] {
// Binary-compatible constructor (pre-3.6.1)
def this(view: PCollectionView[SparkeyUri], mapFn: SparkeyReader => T) =
this(view, mapFn, SparkeyReadConfig.Default)
override def updateCacheOnGlobalWindow: Boolean = false
override def get[I, O](context: DoFn[I, O]#ProcessContext): T =
mapFn(
SparkeySideInput.checkMemory(
context.sideInput(view).getReader(RemoteFileUtil.create(context.getPipelineOptions))
)
)
override def get[I, O](context: DoFn[I, O]#ProcessContext): T = {
val uri = context.sideInput(view)
val rfu = RemoteFileUtil.create(context.getPipelineOptions)
val reader = SparkeySideInput.getOrCreateReader(uri, rfu, config)
mapFn(reader)
}
}

/**
Expand Down Expand Up @@ -585,6 +619,40 @@ package object sparkey extends SparkeyReaderInstances with SparkeyCoders {

private object SparkeySideInput {
private val logger = LoggerFactory.getLogger(this.getClass)

// Process-wide cache of opened readers, keyed by Sparkey base path. Values are
// futures so that concurrent first readers of the same path share one load.
// Successful entries are never evicted, so the reader is shared by every
// thread in this JVM that reads the same side input.
// NOTE(review): the key is only the path, so a cache hit ignores the caller's
// SparkeyReadConfig — two side inputs over the same path with different
// configs will share whichever reader was opened first. Confirm this is
// acceptable, or key by (path, config).
private val readerCache =
new java.util.concurrent.ConcurrentHashMap[
String,
java.util.concurrent.CompletableFuture[SparkeyReader]
]()

// Returns the shared reader for `uri`, opening it on first use via
// `uri.getReader(rfu, config)`, prefetching with `config.loadMode`, and
// logging a memory check. Blocks until the (possibly concurrent) load is done.
def getOrCreateReader(
uri: SparkeyUri,
rfu: RemoteFileUtil,
config: SparkeyReadConfig
): SparkeyReader = {
val future = readerCache.computeIfAbsent(
uri.path,
_ =>
// supplyAsync so computeIfAbsent returns quickly (it holds a bucket lock)
java.util.concurrent.CompletableFuture.supplyAsync { () =>
logger.info("Loading sparkey reader for {}", uri.path)
val reader = uri.getReader(rfu, config)
reader.load(config.loadMode)
checkMemory(reader)
reader
}
)
try {
// Blocks until the opening thread finishes; other callers piggyback here.
future.get()
} catch {
case e: java.util.concurrent.ExecutionException =>
// Remove failed entry so next attempt can retry. The two-arg remove only
// evicts if the mapping is still this same failed future, so a concurrent
// successful reload is never discarded.
readerCache.remove(uri.path, future)
throw e.getCause
}
}

def checkMemory(reader: SparkeyReader): SparkeyReader = {
val memoryBytes = java.lang.management.ManagementFactory.getOperatingSystemMXBean
.asInstanceOf[com.sun.management.OperatingSystemMXBean]
Expand Down
Loading
Loading