-
Notifications
You must be signed in to change notification settings - Fork 785
Expand file tree
/
Copy pathSeqeraTaskHandler.groovy
More file actions
426 lines (386 loc) · 16.2 KB
/
SeqeraTaskHandler.groovy
File metadata and controls
426 lines (386 loc) · 16.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
/*
* Copyright 2013-2026, Seqera Labs
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.seqera.executor
import java.nio.file.Path
import groovy.transform.CompileStatic
import groovy.transform.PackageScope
import groovy.util.logging.Slf4j
import io.seqera.executor.Labels
import io.seqera.sched.api.schema.v1a1.AcceleratorType
import io.seqera.sched.api.schema.v1a1.GetTaskLogsResponse
import io.seqera.sched.api.schema.v1a1.NextflowTask
import io.seqera.sched.api.schema.v1a1.ResourceLimit
import io.seqera.sched.api.schema.v1a1.ResourceRequirement
import io.seqera.sched.api.schema.v1a1.Task
import io.seqera.sched.api.schema.v1a1.TaskState as SchedTaskState
import io.seqera.sched.api.schema.v1a1.TaskStatus as SchedTaskStatus
import io.seqera.sched.client.SchedClient
import io.seqera.util.SchemaMapperUtil
import nextflow.cloud.types.CloudMachineInfo
import nextflow.exception.ProcessException
import nextflow.exception.ProcessUnrecoverableException
import nextflow.util.Duration
import nextflow.util.MemoryUnit
import nextflow.fusion.FusionAwareTask
import nextflow.processor.TaskHandler
import nextflow.processor.TaskRun
import nextflow.processor.TaskStatus
import nextflow.trace.TraceRecord
/**
* Task handler for the Seqera scheduler executor.
*
* <p>Manages the lifecycle of a single task submitted to the Seqera scheduler,
* including submission via batch submitter, status polling, completion handling,
* and trace record enrichment with machine info and spot interruption metadata.
*
* @author Paolo Di Tommaso <paolo.ditommaso@gmail.com>
*/
@Slf4j
@CompileStatic
class SeqeraTaskHandler extends TaskHandler implements FusionAwareTask {
private SchedClient client
private SeqeraExecutor executor
private Path exitFile
private Path outputFile
private Path errorFile
private volatile String taskId
/**
* Cached task state from last describeTask call, used for trace record metadata
*/
private volatile SchedTaskState cachedTaskState
/**
* Cached machine info extracted from task attempts
*/
private volatile CloudMachineInfo machineInfo
/**
 * Create the handler for the given task.
 *
 * @param task the task run managed by this handler
 * @param executor the Seqera executor; it supplies the scheduler client used by this handler
 */
SeqeraTaskHandler(TaskRun task, SeqeraExecutor executor) {
    super(task)
    this.executor = executor
    this.client = executor.getClient()
    // these files are accessed via the NF runtime, keep them based on CloudStoragePath
    final workDir = task.workDir
    this.exitFile = workDir.resolve(TaskRun.CMD_EXIT)
    this.outputFile = workDir.resolve(TaskRun.CMD_OUTFILE)
    this.errorFile = workDir.resolve(TaskRun.CMD_ERRFILE)
}
/**
 * Build the Fusion launcher for this task.
 * Fusion must be enabled, otherwise the assertion fails.
 */
@Override
void prepareLauncher() {
    assert fusionEnabled()
    fusionLauncher().build()
}
/**
 * Convert this task into a scheduler {@link Task} and enqueue it for batch submission.
 *
 * <p>The request carries the cpu/memory requirement (plus accelerator settings when
 * configured), the optional resource limits, the machine requirement, the Fusion command
 * and environment, and the Nextflow task metadata. The handler status is not modified
 * by this method.
 *
 * @throws ProcessUnrecoverableException if the task does not define a container image
 */
@Override
void submit() {
    executor.ensureRunCreated()

    // cpu and memory request - falls back to 1 cpu and 1024 MiB when the directive is not set
    final int cpuShares = (task.config.getCpus() ?: 1) * 1024
    final requestedMemory = task.config.getMemory()
    final int memoryMiB = requestedMemory ? (int) (requestedMemory.toBytes() / (1024 * 1024)) : 1024
    final resourceReq = new ResourceRequirement()
        .cpuShares(cpuShares)
        .memoryMiB(memoryMiB)

    // optional accelerator settings
    final accelerator = task.config.getAccelerator()
    if( accelerator ) {
        // use the requested count, or the limit when no request is given
        resourceReq.acceleratorCount(accelerator.request ?: accelerator.limit)
        // the accelerator type is always reported as GPU
        resourceReq.acceleratorType(AcceleratorType.GPU)
        // accelerator model name e.g. "nvidia-tesla-v100", "nvidia-a10g"
        if( accelerator.type )
            resourceReq.acceleratorName(accelerator.type)
    }

    // machine requirement: config settings merged with the task platform, disk and snapshot settings
    final machineReq = SchemaMapperUtil.toMachineRequirement(
        executor.getSeqeraConfig().machineRequirement,
        task.getContainerPlatform(),
        task.config.getDisk(),
        fusionConfig().snapshotsEnabled()
    )

    // upper bounds taken from the process `resourceLimits` directive (may be null)
    final resourceLim = toResourceLimit()

    // a container image is mandatory for this executor
    final container = task.getContainer()
    if( !container )
        throw new ProcessUnrecoverableException("Process `${task.lazyName()}` failed because the container image was not specified -- the Seqera executor requires all processes define a container image")

    // Nextflow-side metadata attached to the scheduler task
    final nextflowMeta = new NextflowTask()
        .taskId(task.id?.intValue())
        .hash(task.hash?.toString())
        .workDir(task.getWorkDirStr())

    // assemble the scheduler task
    final schedTask = new Task()
        .name(task.lazyName())
        .image(container)
        .command(fusionSubmitCli())
        .environment(getTaskEnvironment())
        .resourceRequirement(resourceReq)
        .resourceLimit(resourceLim)
        .machineRequirement(machineReq)
        .nextflow(nextflowMeta)

    // only the labels that differ from the run-level baseline are attached to the task
    final taskLabels = Labels.toStringMap(task.config.getResourceLabels())
    final delta = Labels.delta(taskLabels, executor.runResourceLabels)
    if( delta )
        schedTask.labels(delta)

    log.debug "[SEQERA] Enqueueing task for batch submission: ${schedTask}"
    // hand over to the batch submitter - the status is set later by the setBatchTaskId callback
    executor.getBatchSubmitter().submit(this, schedTask)
}
/**
 * Compose the environment passed to the scheduler task.
 *
 * <p>User-configured variables (if any) are copied first, then the Fusion
 * variables are added on top, so a Fusion variable wins on a name clash.
 *
 * @return the Fusion environment alone when no user environment is configured,
 *         otherwise a new map holding both
 */
protected Map<String, String> getTaskEnvironment() {
    final configEnv = executor.getSeqeraConfig()?.taskEnvironment
    final fusionEnv = fusionLauncher().fusionEnv()
    if( configEnv ) {
        final merged = new LinkedHashMap<String, String>(configEnv)
        merged.putAll(fusionEnv)
        return merged
    }
    return fusionEnv
}
/**
 * Callback invoked by the batch submitter once the batch was submitted successfully.
 * Stores the scheduler task id and marks the handler as submitted.
 *
 * @param taskId the id assigned to this task by the scheduler
 */
void setBatchTaskId(String taskId) {
    // the id is stored before the status changes, so it is available once the handler reports SUBMITTED
    this.taskId = taskId
    status = TaskStatus.SUBMITTED
    log.debug "[SEQERA] Process `${task.lazyName()}` submitted > taskId=$taskId; work-dir=${task.getWorkDirStr()}"
}
/**
 * Callback invoked by the batch submitter when the batch submission fails.
 * Records the failure on the task and marks the handler as completed.
 *
 * @param cause the exception that made the submission fail
 */
void onBatchSubmitFailure(Exception cause) {
    log.debug "[SEQERA] Batch submission failed for task ${task.lazyName()}: ${cause.message}"
    // the error is picked up by checkIfCompleted, which reports completion when both error and COMPLETED are set
    task.error = cause
    status = TaskStatus.COMPLETED
}
/**
 * Map the process {@code resourceLimits} directive to a {@link ResourceLimit}.
 * Only the {@code memory} and {@code cpus} limits are considered.
 *
 * @return the resource limit, or {@code null} when neither limit is defined
 */
protected ResourceLimit toResourceLimit() {
    final memoryLimit = task.config.getResourceLimit('memory') as MemoryUnit
    final cpusLimit = task.config.getResourceLimit('cpus') as Integer
    if( !(memoryLimit || cpusLimit) )
        return null

    final limit = new ResourceLimit()
    if( cpusLimit )
        limit.cpuShares(cpusLimit * 1024)
    if( memoryLimit )
        limit.memoryMiB((int)(memoryLimit.toBytes() / (1024 * 1024)))
    return limit
}
/**
 * Query the scheduler for the current task state, cache it, and return its status.
 *
 * @return the status reported by the scheduler for this task
 */
protected SchedTaskStatus schedTaskStatus() {
    final response = client.describeTask(taskId)
    // keep the full state: it is reused later for trace record metadata
    cachedTaskState = response.getTaskState()
    return cachedTaskState.getStatus()
}
/**
 * Check whether a submitted task has started, moving the handler to RUNNING if so.
 * A task that already terminated is reported as running too.
 *
 * @return {@code true} when the handler switched to the RUNNING status, {@code false} otherwise
 */
@Override
boolean checkIfRunning() {
    if( !isSubmitted() )
        return false

    final schedStatus = schedTaskStatus()
    log.debug "[SEQERA] checkIfRunning taskId=${taskId}; status=${schedStatus}"
    if( !isRunningOrTerminated(schedStatus) )
        return false

    status = TaskStatus.RUNNING
    return true
}
/**
 * Check whether the task has terminated and, if so, finalize it: exit status,
 * error (when failed without an exit code), stdout/stderr and COMPLETED status.
 *
 * @return {@code true} when the task is completed, {@code false} otherwise
 */
@Override
boolean checkIfCompleted() {
    // batch submission failure: the error was set and the handler completed without ever reaching RUNNING
    if( task.error && isCompleted() )
        return true
    if( !isRunning() )
        return false

    final schedStatus = schedTaskStatus()
    log.debug "[SEQERA] checkIfCompleted status=${schedStatus}"
    if( !isTerminated(schedStatus) )
        return false

    log.debug "[SEQERA] Process `${task.lazyName()}` - terminated taskId=$taskId; status=$schedStatus"
    // finalize the task
    task.exitStatus = readExitFile()
    final boolean failed = isFailed(schedStatus)
    // failed without an exit code: report the error message held by the task state
    if( failed && task.exitStatus == Integer.MAX_VALUE ) {
        final errorMessage = cachedTaskState?.getErrorMessage() ?: "Task failed for unknown reason"
        task.error = new ProcessException(errorMessage)
    }
    // scheduler logs are fetched for failed tasks only; otherwise the work-dir files are used
    final GetTaskLogsResponse logs = failed ? getTaskLogs(taskId) : null
    task.stdout = logs?.stdout ?: outputFile
    task.stderr = logs?.stderr ?: errorFile
    status = TaskStatus.COMPLETED
    return true
}
/**
 * @param status the scheduler task status
 * @return {@code true} when the status is RUNNING or any terminal status
 */
protected boolean isRunningOrTerminated(SchedTaskStatus status) {
    if( status == SchedTaskStatus.RUNNING )
        return true
    return isTerminated(status)
}
/**
 * @param status the scheduler task status
 * @return {@code true} when the status is SUCCEEDED, FAILED or CANCELLED
 */
protected boolean isTerminated(SchedTaskStatus status) {
    return status == SchedTaskStatus.SUCCEEDED ||
        status == SchedTaskStatus.FAILED ||
        status == SchedTaskStatus.CANCELLED
}
/**
 * @param status the scheduler task status
 * @return {@code true} only when the status is FAILED
 */
protected boolean isFailed(SchedTaskStatus status) {
    return SchedTaskStatus.FAILED == status
}
/**
 * Fetch the logs of the given task from the scheduler.
 *
 * @param taskId the scheduler task id
 * @return the response returned by the scheduler client
 */
protected GetTaskLogsResponse getTaskLogs(String taskId) {
    final logs = client.getTaskLogs(taskId)
    return logs
}
/**
 * Ask the scheduler to cancel this task.
 * Nothing is done when no task id was assigned yet; a cancellation failure
 * is logged as a warning and not propagated.
 */
@Override
protected void killTask() {
    if( taskId ) {
        log.debug "[SEQERA] Cancel taskId=${taskId}"
        try {
            client.cancelTask(taskId)
        }
        catch (Throwable t) {
            // best effort: a failed cancellation must not break the caller
            log.warn "[SEQERA] Failed to cancel task ${taskId}", t
        }
    }
    else {
        log.trace "[SEQERA] Skip cancel - taskId not yet assigned"
    }
}
/**
 * Read the task exit status from the exit file in the task work directory.
 *
 * @return the exit status, or {@link Integer#MAX_VALUE} when the file cannot
 *         be read or its content cannot be converted to an integer
 */
@PackageScope
Integer readExitFile() {
    Integer result
    try {
        result = exitFile.text as Integer
        log.trace "[SEQERA] Read exit file for taskId $taskId; exit=${result}"
    }
    catch (Exception e) {
        log.debug "[SEQERA] Cannot read exit status for task: `${task.lazyName()}` - ${e.message}"
        // MAX_VALUE signals that the exit code could not be retrieved
        result = Integer.MAX_VALUE
    }
    return result
}
/**
 * Resolve the machine info of the last task attempt found in the cached task state.
 * Once resolved, the value is cached and returned by subsequent calls.
 *
 * @return the machine type, zone and price model as a {@link CloudMachineInfo};
 *         {@code null} when no task state, attempt or machine info is available,
 *         or when the lookup fails
 */
protected CloudMachineInfo getMachineInfo() {
    if( machineInfo )
        return machineInfo
    if( !cachedTaskState )
        return null

    try {
        // a null or empty attempts list means there is nothing to report
        final attempts = cachedTaskState.getAttempts()
        final lastInfo = attempts ? attempts.get(attempts.size() - 1).getMachineInfo() : null
        if( !lastInfo )
            return null

        // map the Sched API machine info to the Nextflow representation
        machineInfo = new CloudMachineInfo(
            type: lastInfo.getType(),
            zone: lastInfo.getZone(),
            priceModel: SchemaMapperUtil.toPriceModel(lastInfo.getPriceModel())
        )
        log.trace "[SEQERA] taskId=$taskId => machineInfo=$machineInfo"
        return machineInfo
    }
    catch (Exception e) {
        log.debug "[SEQERA] Unable to get machine info for taskId=$taskId - ${e.message}"
        return null
    }
}
/**
 * Report the number of spot interruptions recorded in the cached task state.
 *
 * @return the spot interruptions count; {@code null} when no task id is assigned,
 *         the handler is not completed, or no task state has been cached
 */
protected Integer getNumSpotInterruptions() {
    final boolean available = taskId && isCompleted() && cachedTaskState
    return available ? cachedTaskState.getNumSpotInterruptions() : null
}
/**
 * Report the log stream identifier held by the cached task state.
 *
 * @return the log stream id, or {@code null} when no task state has been cached
 */
protected String getLogStreamId() {
    final state = cachedTaskState
    return state != null ? state.getLogStreamId() : null
}
/**
 * Report the native backend ID for this task (ECS task ARN or Docker container ID).
 *
 * <p>NOTE(review): the value is the {@code id} of the cached task state itself, not
 * of a task attempt - confirm that this field carries the backend-native identifier.
 *
 * @return the id held by the cached task state, or {@code null} when no state has been cached
 */
protected String getNativeId() {
    final state = cachedTaskState
    return state != null ? state.getId() : null
}
/**
 * Collect the resources allocated to this task.
 *
 * <p>The resources of the last task attempt are preferred; when there are no attempts,
 * or the last attempt has no resources, the resource allocation of the task state is used.
 * Only non-null fields are included, in a fixed order.
 *
 * @return a map of the allocated resource fields, or {@code null} when no task state
 *         is cached, no resources are found, or every field is null
 */
protected Map<String,Object> getResourceAllocation() {
    final state = cachedTaskState
    if( !state )
        return null

    // last attempt first, then fall back to the task-level allocation
    def resources = null
    final attempts = state.getAttempts()
    if( attempts )
        resources = attempts.get(attempts.size() - 1).getResources()
    if( !resources )
        resources = state.getResourceAllocation()
    if( !resources )
        return null

    final fields = new LinkedHashMap<String,Object>()
    final cpuShares = resources.getCpuShares()
    if( cpuShares != null )
        fields.put('cpuShares', cpuShares)
    final memoryMiB = resources.getMemoryMiB()
    if( memoryMiB != null )
        fields.put('memoryMiB', memoryMiB)
    final acceleratorCount = resources.getAcceleratorCount()
    if( acceleratorCount != null )
        fields.put('acceleratorCount', acceleratorCount)
    final acceleratorType = resources.getAcceleratorType()
    if( acceleratorType != null )
        fields.put('acceleratorType', acceleratorType.toString())
    final acceleratorName = resources.getAcceleratorName()
    if( acceleratorName != null )
        fields.put('acceleratorName', acceleratorName)
    final time = resources.getTime()
    if( time != null )
        fields.put('time', time)

    if( fields.isEmpty() )
        return null
    return fields
}
/**
 * Report the time granted to this task, in milliseconds.
 *
 * <p>The value is parsed from the {@code time} field of the resource allocation held by
 * the cached task state; when that is not available the task {@code time} setting is used.
 *
 * @return the granted time in milliseconds, or {@code null} when neither source defines it
 */
protected Long getGrantedTime() {
    final String time = cachedTaskState?.getResourceAllocation()?.getTime()
    if( time != null )
        return Duration.of(time).toMillis()
    return task.config.getTime()?.toMillis()
}
/**
 * Build the trace record for this task, adding the native id, machine info,
 * spot interruptions count, log stream id and resource allocation to the record
 * created by the parent class.
 *
 * @return the trace record with the additional metadata fields
 */
@Override
TraceRecord getTraceRecord() {
    final record = super.getTraceRecord()
    record.put('native_id', getNativeId())
    record.machineInfo = getMachineInfo()
    record.numSpotInterruptions = getNumSpotInterruptions()
    record.logStreamId = getLogStreamId()
    record.resourceAllocation = getResourceAllocation()
    // the executor name carries the cloud backend for cost tracking
    // NOTE(review): the backend suffix is fixed to "aws" - confirm this holds for every deployment
    final String backendName = "${SeqeraExecutor.SEQERA}/aws"
    record.executorName = backendName
    return record
}
}