Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions dali/python/nvidia/dali/external_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def call_and_feed(self, pipeline, batch_size):
try:
if self.batch:
callback_out = self.callback(*self.callback_args(None))
batch_size = len(callback_out[0]) if self.is_multioutput else len(callback_out)
else:
callback_out = [self.callback(*self.callback_args(i)) for i in range(batch_size)]
self.current_sample += batch_size
Expand All @@ -118,6 +119,8 @@ def call_and_feed(self, pipeline, batch_size):
for op in self.instances:
if self.batch:
data = callback_out[op._output_index]
if len(data) != batch_size:
raise RuntimeError("External source returned outputs with different batch sizes.")
else:
# extract a single output
data = [callback_out[i][op._output_index] for i in range(batch_size)]
Expand All @@ -126,6 +129,7 @@ def call_and_feed(self, pipeline, batch_size):
data = callback_out
op = self.instances[0]
pipeline.feed_input(op._name, data, op._layout, self._cuda_stream, self.use_copy_kernel)
return batch_size

def _is_generator_function(x):
"""Checks whether x is a generator function or a callable object
Expand Down
19 changes: 15 additions & 4 deletions dali/python/nvidia/dali/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,12 +465,16 @@ def _prepare_graph(self, define_graph = None):

def _setup_input_callbacks(self):
    """Collect the external-source callback groups used by this pipeline.

    Scans ``self._ops`` for operators accepted by
    ``_is_external_source_with_callback`` and stores their distinct
    ``_group`` objects in ``self._input_callbacks``. Groups whose ``batch``
    flag is truthy are placed ahead of all remaining groups.
    """
    # Function-scope import, kept as in the original
    # (presumably to avoid a circular import -- TODO confirm).
    from nvidia.dali.external_source import _is_external_source_with_callback

    # Sets deduplicate groups: several operators may share one group.
    groups_batch = set()
    groups_sample = set()
    for op in self._ops:
        if not _is_external_source_with_callback(op):
            continue
        group = op._group
        if group.batch:
            groups_batch.add(group)
        else:
            groups_sample.add(group)

    # Batch-mode groups go first: _run_input_callbacks adopts the batch size
    # returned by the first group it runs, and only batch-mode groups derive
    # that size from the callback's output rather than from the size passed in.
    self._input_callbacks = list(groups_batch) + list(groups_sample)

def build(self, define_graph = None):
"""Build the pipeline.
Expand Down Expand Up @@ -983,9 +987,16 @@ def _run_input_callbacks(self):
return

stop_iter = False
batch_size = self._max_batch_size
first = True
for group in self._input_callbacks:
try:
group.call_and_feed(self, self._max_batch_size)
actual_batch_size = group.call_and_feed(self, batch_size)
if first:
batch_size = actual_batch_size
first = False
elif actual_batch_size != batch_size:
raise RuntimeError("Batch size inconsistency between ExternalSource operators")
except StopIteration:
stop_iter = True
if stop_iter:
Expand Down
24 changes: 24 additions & 0 deletions dali/test/python/test_external_source_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from test_utils import check_output
import random
from collections import Iterable
import itertools
datapy = np

make_array = np.array
Expand Down Expand Up @@ -699,3 +700,26 @@ def test_iter_setup_zero_copy():
# use -4 rather than -1, as -1 sometimes works, sometimes not due to being close to the limit
for additional_num_keep_samples in [-4, 0, 1]:
yield _test_iter_setup_zero_copy, use_fn_api, by_name, as_tensor, device, additional_num_keep_samples

def test_external_source_variable_batch_size():
    # Feeds the same samples through a batch-mode and a sample-mode external
    # source and checks that each run produces outputs whose length equals
    # the size of the corresponding batch in `batch_data` (2, then 5, then 1).
    batch_data = [
        [[1, 2, 3], [4, 5]],
        [[7, 8, 9, 10], [11], [12], [13], [14]],
        [[15, 16]],
    ]
    batch_data = [[np.array(x) for x in b] for b in batch_data]
    # The sample-mode source receives the same samples, flattened in order.
    sample_data = list(itertools.chain.from_iterable(batch_data))
    # NOTE(review): the first positional argument (5) is presumably the
    # maximum batch size; it matches the largest batch above -- confirm.
    pipe = Pipeline(5, 3, 0)
    with pipe:
        ext_batch = fn.external_source(batch_data, cycle="quiet")
        ext_sample = fn.external_source(sample_data, cycle="quiet", batch=False)
        pipe.set_outputs(ext_batch, ext_sample)
    pipe.build()
    # Two passes over the data; the second pass presumably relies on
    # cycle="quiet" restarting both sources -- confirm.
    for _ in range(2):
        for expected in batch_data:
            batch, sample = pipe.run()
            n = len(expected)
            assert len(batch) == n
            assert len(sample) == n
            check_output((batch,), expected)
            check_output((sample,), expected)