Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions dali/pipeline/executor/async_separated_pipelined_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

#include "dali/pipeline/executor/async_separated_pipelined_executor.h"

#include <algorithm>

namespace dali {

void AsyncSeparatedPipelinedExecutor::RunCPU() {
Expand All @@ -39,14 +41,16 @@ void AsyncSeparatedPipelinedExecutor::Prefetch() {
RunGPU();
}

for (int i = 0; i < queue_sizes_.cpu_size; i++) {
int cpu_only_prefetch_count =
std::max(0, queue_sizes_.cpu_size - queue_sizes_.gpu_size);
for (int i = 0; i < cpu_only_prefetch_count; i++) {
RunCPU();
}
}

int AsyncSeparatedPipelinedExecutor::InputFeedCount(std::string_view op_name) {
(void)graph_->Node(op_name);
return queue_sizes_.cpu_size + queue_sizes_.gpu_size;
return std::max(queue_sizes_.cpu_size, queue_sizes_.gpu_size);
}

} // namespace dali
88 changes: 40 additions & 48 deletions dali/python/nvidia/dali/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2017-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2017-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -1591,14 +1591,18 @@ def _prefetch(self):
raise RuntimeError("The pipeline was destroyed.")
self._schedule_py_workers()

# We probably need some benchmarking before we remove this code path
if not self._exec_separated:
# A larger separated CPU queue leaves CPU-only iterations after backend
# Prefetch. If a Python source reaches end of epoch, those iterations
# cannot be advanced through Mixed/GPU without feeding more CPU work.
cpu_queue_is_longer = self._cpu_queue_size > self._gpu_queue_size
if not self._exec_separated or cpu_queue_is_longer:
self._legacy_interleaved_prefetch()
return

# The new way: try to run the inputs and then feed them, finally call _pipe.Prefetch()
# If this fails, we just run `_pipe.Run()` a bunch of times. This will likely blow up for
# separated queues, which are not properly supported anyway.
# The new way: try to run the inputs and then feed them, finally call
# _pipe.Prefetch(). If this fails, we just run `_pipe.Run()` a bunch of
# times. This will likely blow up for separated queues, which are not
# properly supported anyway.
iters_fed = 0
self._first_iter = False
iters_fed, success = self._prefetch_inputs()
Expand All @@ -1613,7 +1617,12 @@ def _prefetch(self):
# Running all callbacks at once, then feeding, then running - may affect the performance
# of the 1st iteration.
def _legacy_interleaved_prefetch(self):
for _ in range(self._cpu_queue_size):
prefetch_count = (
max(self._cpu_queue_size, self._gpu_queue_size)
if self._exec_separated
else self._cpu_queue_size
)
Comment thread
JanuszL marked this conversation as resolved.
for _ in range(prefetch_count):
try:
self._first_iter = False
self._iter_setup()
Expand All @@ -1631,7 +1640,7 @@ def _prefetch_inputs(self):

Comment thread
JanuszL marked this conversation as resolved.
Outdated
if success:
if self._exec_separated:
prefetch_count = self._cpu_queue_size + self._gpu_queue_size
prefetch_count = max(self._cpu_queue_size, self._gpu_queue_size)
else:
prefetch_count = self._cpu_queue_size

Expand Down Expand Up @@ -1955,53 +1964,36 @@ def _iter_setup(self):
if iters == 0:
self.iter_setup()

def _run_input_callbacks(self, is_prefetch=False):
def _run_input_callbacks(self):
if self._input_callbacks is None:
return 0, True

done = False
stop_iter = False
iter = 0
while not done and not stop_iter:
done = True
batches = [] # data from external source callbacks is gathered here
for i, group in enumerate(self._parallel_input_callbacks):
try:
count = group.feed_count(self) if is_prefetch else 1
if iter < count:
batches.append(
group.schedule_and_receive(
self, self._py_pool, i, self._max_batch_size, self._epoch_idx
)
)
if iter + 1 < count:
done = False
except StopIteration:
stop_iter = True
for group in self._seq_input_callbacks:
try:
count = group.feed_count(self) if is_prefetch else 1
if iter < count:
batches.append(group.get_batch(self, self._max_batch_size, self._epoch_idx))
if iter + 1 < count:
done = False
except StopIteration:
stop_iter = True

if stop_iter:
return iter, False

batches = [] # data from external source callbacks is gathered here
for i, group in enumerate(self._parallel_input_callbacks):
try:
self.iter_setup()
batches.append(
group.schedule_and_receive(
self, self._py_pool, i, self._max_batch_size, self._epoch_idx
)
)
except StopIteration:
return 0, False
for group in self._seq_input_callbacks:
try:
batches.append(group.get_batch(self, self._max_batch_size, self._epoch_idx))
except StopIteration:
return iter, False
return 0, False

try:
self.iter_setup()
except StopIteration:
return 0, False

# we only fill external source queues when we know that all callbacks succeeded
for batch in batches:
batch.feed()
# we only fill external source queues when we know that all callbacks succeeded
for batch in batches:
batch.feed()

iter += 1
return iter, True
return 1, True

def iter_setup(self):
"""A deprecated method of providing the pipeline with external inputs.
Expand Down
50 changes: 49 additions & 1 deletion dali/test/python/test_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2017-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2017-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -28,6 +28,7 @@
import weakref
import gc
from webdataset_base import generate_temp_index_file as generate_temp_wds_index
from nose2.tools import params

from test_utils import (
check_batch,
Expand Down Expand Up @@ -1893,6 +1894,53 @@ def my_pipe():
my_pipe(device_id=0, seed=1234, num_threads=3, set_affinity=True, py_num_workers=3)


@params((2, 2), (3, 2), (2, 3))
def test_separated_queue_external_source_drains_prefetched_batches(cpu_size, gpu_size):
batch_size = 4
num_batches = 10
image_pattern = os.path.join(jpeg_folder, "*", "*.jpg")
paths = sorted(glob.glob(image_pattern))[: batch_size * num_batches]
assert len(paths) == batch_size * num_batches

def batches():
for i in range(num_batches):
batch_paths = paths[i * batch_size : (i + 1) * batch_size]
yield [np.fromfile(path, dtype=np.uint8) for path in batch_paths]

@dali.pipeline_def(
batch_size=batch_size,
num_threads=4,
device_id=0,
prefetch_queue_depth={"cpu_size": cpu_size, "gpu_size": gpu_size},
)
def pipe():
encoded = fn.external_source(
source=batches,
batch=True,
cycle="raise",
)
decoded = fn.decoders.image(
encoded,
device="mixed",
output_type=types.RGB,
)
return decoded

p = pipe()
p.build()
for _ in range(num_batches):
out = p.run()[0]
assert len(out) == batch_size
decoded = out.as_cpu()
for sample_idx in range(batch_size):
sample = decoded.at(sample_idx)
assert sample.ndim == 3
assert sample.shape[-1] == 3
assert np.any(sample)
with assert_raises(StopIteration):
p.run()


def test_not_iterable():
import nvidia.dali._utils.hacks as hacks
import collections.abc
Expand Down