Skip to content

Commit efdd017

Browse files
spkrka and claude committed
feat(sparkey): Deduplicate SparkeyReader instances across DoFn clones
Beam creates one DoFn clone per vCPU thread (e.g. 80 on n4-standard-80). Previously, each clone independently called uri.getReader() which downloads files from GCS and opens new SparkeyReader instances — duplicating work and wasting file descriptors and mmap regions. This adds a static ConcurrentHashMap<String, CompletableFuture<SparkeyReader>> cache so the first thread loads the reader and all others wait on the same future. The reader is reused for the lifetime of the JVM, which is safe for Dataflow batch (one pipeline per JVM). The cache is used by SparkeySideInput, LargeMapSideInput, and LargeSetSideInput. No API changes, no new dependencies. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 1480594 commit efdd017

File tree

1 file changed

+38
-10
lines changed
  • scio-extra/src/main/scala/com/spotify/scio/extra/sparkey

1 file changed

+38
-10
lines changed

scio-extra/src/main/scala/com/spotify/scio/extra/sparkey/package.scala

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ import org.apache.beam.sdk.util.CoderUtils
3232
import org.apache.beam.sdk.values.PCollectionView
3333
import org.slf4j.LoggerFactory
3434

35+
import java.util.concurrent.{CompletableFuture, ConcurrentHashMap}
3536
import scala.util.hashing.MurmurHash3
3637

3738
/**
@@ -545,12 +546,11 @@ package object sparkey extends SparkeyReaderInstances with SparkeyCoders {
545546
mapFn: SparkeyReader => T
546547
) extends SideInput[T] {
547548
override def updateCacheOnGlobalWindow: Boolean = false
548-
override def get[I, O](context: DoFn[I, O]#ProcessContext): T =
549-
mapFn(
550-
SparkeySideInput.checkMemory(
551-
context.sideInput(view).getReader(RemoteFileUtil.create(context.getPipelineOptions))
552-
)
553-
)
549+
override def get[I, O](context: DoFn[I, O]#ProcessContext): T = {
550+
val uri = context.sideInput(view)
551+
val rfu = RemoteFileUtil.create(context.getPipelineOptions)
552+
mapFn(SparkeySideInput.getOrCreateReader(uri, rfu))
553+
}
554554
}
555555

556556
/**
@@ -561,8 +561,10 @@ package object sparkey extends SparkeyReaderInstances with SparkeyCoders {
561561
extends SideInput[SparkeyMap[K, V]] {
562562
override def updateCacheOnGlobalWindow: Boolean = false
563563
override def get[I, O](context: DoFn[I, O]#ProcessContext): SparkeyMap[K, V] = {
564+
val uri = context.sideInput(view)
565+
val rfu = RemoteFileUtil.create(context.getPipelineOptions)
564566
new SparkeyMap(
565-
context.sideInput(view).getReader(RemoteFileUtil.create(context.getPipelineOptions)),
567+
SparkeySideInput.getOrCreateReader(uri, rfu),
566568
CoderMaterializer.beam(context.getPipelineOptions, Coder[K]),
567569
CoderMaterializer.beam(context.getPipelineOptions, Coder[V])
568570
)
@@ -576,16 +578,42 @@ package object sparkey extends SparkeyReaderInstances with SparkeyCoders {
576578
private class LargeSetSideInput[K: Coder](val view: PCollectionView[SparkeyUri])
577579
extends SideInput[SparkeySet[K]] {
578580
override def updateCacheOnGlobalWindow: Boolean = false
579-
override def get[I, O](context: DoFn[I, O]#ProcessContext): SparkeySet[K] =
581+
override def get[I, O](context: DoFn[I, O]#ProcessContext): SparkeySet[K] = {
582+
val uri = context.sideInput(view)
583+
val rfu = RemoteFileUtil.create(context.getPipelineOptions)
580584
new SparkeySet(
581-
context.sideInput(view).getReader(RemoteFileUtil.create(context.getPipelineOptions)),
585+
SparkeySideInput.getOrCreateReader(uri, rfu),
582586
CoderMaterializer.beam(context.getPipelineOptions, Coder[K])
583587
)
588+
}
584589
}
585590

591+
// Readers are cached for the lifetime of the JVM and never closed. This is intentional:
592+
// Beam side inputs have no close/teardown lifecycle, and in batch pipelines the JVM exits
593+
// when the pipeline finishes.
594+
// Note: the cache is keyed by URI path only. If the same path is rewritten with different
595+
// data and a new pipeline is run in the same JVM (e.g. DirectRunner, REPL), stale readers
596+
// will be returned. This is acceptable for Dataflow batch (one pipeline per JVM).
586597
private object SparkeySideInput {
587598
private val logger = LoggerFactory.getLogger(this.getClass)
588-
def checkMemory(reader: SparkeyReader): SparkeyReader = {
599+
600+
private val readerCache =
601+
new ConcurrentHashMap[String, CompletableFuture[SparkeyReader]]()
602+
603+
def getOrCreateReader(uri: SparkeyUri, rfu: RemoteFileUtil): SparkeyReader =
604+
readerCache
605+
.computeIfAbsent(
606+
uri.path,
607+
_ =>
608+
CompletableFuture.supplyAsync { () =>
609+
val reader = uri.getReader(rfu)
610+
checkMemory(reader)
611+
reader
612+
}
613+
)
614+
.join()
615+
616+
private def checkMemory(reader: SparkeyReader): SparkeyReader = {
589617
val memoryBytes = java.lang.management.ManagementFactory.getOperatingSystemMXBean
590618
.asInstanceOf[com.sun.management.OperatingSystemMXBean]
591619
.getTotalPhysicalMemorySize

0 commit comments

Comments
 (0)