Skip to content

Commit f775130

Browse files
Add take_events()
1 parent 948dd05 commit f775130

1 file changed

Lines changed: 136 additions & 32 deletions

File tree

faust/streams.py

Lines changed: 136 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,8 @@ def _iter_ll(self, dir_: _LinkedListDirection) -> Iterator[StreamT]:
206206
seen: Set[StreamT] = set()
207207
while node:
208208
if node in seen:
209-
raise RuntimeError(f"Loop in Stream.{dir_.attr}: Call support!")
209+
raise RuntimeError(
210+
f"Loop in Stream.{dir_.attr}: Call support!")
210211
seen.add(node)
211212
yield node
212213
node = dir_.getter(node)
@@ -270,9 +271,7 @@ def _chain(self, **kwargs: Any) -> StreamT:
270271

271272
def noack(self) -> "StreamT":
272273
"""Create new stream where acks are manual."""
273-
self._next = new_stream = self.clone(
274-
enable_acks=False,
275-
)
274+
self._next = new_stream = self.clone(enable_acks=False, )
276275
return new_stream
277276

278277
async def items(self) -> AsyncIterator[Tuple[K, T_co]]:
@@ -299,7 +298,8 @@ async def events(self) -> AsyncIterable[EventT]:
299298
if self.current_event is not None:
300299
yield self.current_event
301300

302-
async def take(self, max_: int, within: Seconds) -> AsyncIterable[Sequence[T_co]]:
301+
async def take(self, max_: int,
302+
within: Seconds) -> AsyncIterable[Sequence[T_co]]:
303303
"""Buffer n values at a time and yield a list of buffered values.
304304
305305
Arguments:
@@ -340,7 +340,8 @@ async def add_to_buffer(value: T) -> T:
340340
buffer_add(cast(T_co, value))
341341
event = self.current_event
342342
if event is None:
343-
raise RuntimeError("Take buffer found current_event is None")
343+
raise RuntimeError(
344+
"Take buffer found current_event is None")
344345
event_add(event)
345346
if buffer_size() >= max_:
346347
# signal that the buffer is full and should be emptied.
@@ -391,9 +392,102 @@ async def add_to_buffer(value: T) -> T:
391392
self.enable_acks = stream_enable_acks
392393
self._processors.remove(add_to_buffer)
393394

395+
async def take_events(self, max_: int,
396+
within: Seconds) -> AsyncIterable[Sequence[EventT]]:
397+
"""Buffer n events at a time and yield a list of buffered events.
398+
Arguments:
399+
max_: Max number of messages to receive. When more than this
400+
number of messages are received within the specified number of
401+
seconds then we flush the buffer immediately.
402+
within: Timeout for when we give up waiting for another value,
403+
and process the values we have.
404+
Warning: If there's no timeout (i.e. `timeout=None`),
405+
the agent is likely to stall and block buffered events for an
406+
unreasonable length of time(!).
407+
"""
408+
buffer: List[T_co] = []
409+
events: List[EventT] = []
410+
buffer_add = buffer.append
411+
event_add = events.append
412+
buffer_size = buffer.__len__
413+
buffer_full = asyncio.Event()
414+
buffer_consumed = asyncio.Event()
415+
timeout = want_seconds(within) if within else None
416+
stream_enable_acks: bool = self.enable_acks
417+
418+
buffer_consuming: Optional[asyncio.Future] = None
419+
420+
channel_it = aiter(self.channel)
421+
422+
# We add this processor to populate the buffer, and the stream
423+
# is passively consumed in the background (enable_passive below).
424+
async def add_to_buffer(value: T) -> T:
425+
try:
426+
# buffer_consuming is set when consuming buffer after timeout.
427+
nonlocal buffer_consuming
428+
if buffer_consuming is not None:
429+
try:
430+
await buffer_consuming
431+
finally:
432+
buffer_consuming = None
433+
buffer_add(cast(T_co, value))
434+
event = self.current_event
435+
if event is None:
436+
raise RuntimeError(
437+
'Take buffer found current_event is None')
438+
event_add(event)
439+
if buffer_size() >= max_:
440+
# signal that the buffer is full and should be emptied.
441+
buffer_full.set()
442+
# strict wait for buffer to be consumed after buffer full.
443+
# If max is 1000, we are not allowed to return 1001 values.
444+
buffer_consumed.clear()
445+
await self.wait(buffer_consumed)
446+
except CancelledError: # pragma: no cover
447+
raise
448+
except Exception as exc:
449+
self.log.exception('Error adding to take buffer: %r', exc)
450+
await self.crash(exc)
451+
return value
452+
453+
# Disable acks to ensure this method acks manually
454+
# events only after they are consumed by the user
455+
self.enable_acks = False
456+
457+
self.add_processor(add_to_buffer)
458+
self._enable_passive(cast(ChannelT, channel_it))
459+
try:
460+
while not self.should_stop:
461+
# wait until buffer full, or timeout
462+
await self.wait_for_stopped(buffer_full, timeout=timeout)
463+
if buffer:
464+
# make sure background thread does not add new items to
465+
# buffer while we read.
466+
buffer_consuming = self.loop.create_future()
467+
try:
468+
yield list(events)
469+
finally:
470+
buffer.clear()
471+
for event in events:
472+
await self.ack(event)
473+
events.clear()
474+
# allow writing to buffer again
475+
notify(buffer_consuming)
476+
buffer_full.clear()
477+
buffer_consumed.set()
478+
else: # pragma: no cover
479+
pass
480+
else: # pragma: no cover
481+
pass
482+
483+
finally:
484+
# Restore last behaviour of "enable_acks"
485+
self.enable_acks = stream_enable_acks
486+
self._processors.remove(add_to_buffer)
487+
394488
async def take_with_timestamp(
395-
self, max_: int, within: Seconds, timestamp_field_name: str
396-
) -> AsyncIterable[Sequence[T_co]]:
489+
self, max_: int, within: Seconds,
490+
timestamp_field_name: str) -> AsyncIterable[Sequence[T_co]]:
397491
"""Buffer n values at a time and yield a list of buffered values with the
398492
timestamp when the message was added to kafka.
399493
@@ -439,7 +533,8 @@ async def add_to_buffer(value: T) -> T:
439533
value[timestamp_field_name] = event.message.timestamp
440534
buffer_add(value)
441535
if event is None:
442-
raise RuntimeError("Take buffer found current_event is None")
536+
raise RuntimeError(
537+
"Take buffer found current_event is None")
443538
event_add(event)
444539
if buffer_size() >= max_:
445540
# signal that the buffer is full and should be emptied.
@@ -498,9 +593,8 @@ def enumerate(self, start: int = 0) -> AsyncIterable[Tuple[int, T_co]]:
498593
"""
499594
return aenumerate(self, start)
500595

501-
async def noack_take(
502-
self, max_: int, within: Seconds
503-
) -> AsyncIterable[Sequence[T_co]]:
596+
async def noack_take(self, max_: int,
597+
within: Seconds) -> AsyncIterable[Sequence[T_co]]:
504598
"""
505599
Buffer n values at a time and yield a list of buffered values.
506600
:param max_: Max number of messages to receive. When more than this
@@ -543,7 +637,8 @@ async def add_to_buffer(value: T) -> T:
543637
event = self.current_event
544638
buffer_add(cast(T_co, event))
545639
if event is None:
546-
raise RuntimeError("Take buffer found current_event is None")
640+
raise RuntimeError(
641+
"Take buffer found current_event is None")
547642

548643
event_add(event)
549644
if buffer_size() >= max_:
@@ -628,8 +723,7 @@ async def mytask(stream):
628723
return self
629724
if self.concurrency_index is not None:
630725
raise ImproperlyConfigured(
631-
"Agent with concurrency>1 cannot use stream.through!"
632-
)
726+
"Agent with concurrency>1 cannot use stream.through!")
633727
# ridiculous mypy
634728
if isinstance(channel, str):
635729
channelchannel = cast(ChannelT, self.derive_topic(channel))
@@ -638,7 +732,8 @@ async def mytask(stream):
638732

639733
channel_it = aiter(channelchannel)
640734
if self._next is not None:
641-
raise ImproperlyConfigured("Stream is already using group_by/through")
735+
raise ImproperlyConfigured(
736+
"Stream is already using group_by/through")
642737
through = self._chain(channel=channel_it)
643738

644739
async def forward(value: T) -> T:
@@ -649,12 +744,17 @@ async def forward(value: T) -> T:
649744
self._enable_passive(cast(ChannelT, channel_it), declare=True)
650745
return through
651746

    def _enable_passive(self,
                        channel: ChannelT,
                        *,
                        declare: bool = False) -> None:
        # Start consuming this stream in the background so that processors
        # keep running even when the user is not iterating the stream
        # directly (used by take()/take_events() and through()).
        # Idempotent: only the first call spawns the drainer task.
        if not self._passive:
            self._passive = True
            # declare=True asks the drainer to declare the channel/topic
            # before draining (used by through(), which derives a topic).
            self.add_future(self._passive_drainer(channel, declare))
656754

657-
async def _passive_drainer(self, channel: ChannelT, declare: bool = False) -> None:
755+
async def _passive_drainer(self,
756+
channel: ChannelT,
757+
declare: bool = False) -> None:
658758
try:
659759
if declare:
660760
await channel.maybe_declare()
@@ -760,13 +860,13 @@ def get_key(withdrawal):
760860
channel: ChannelT
761861
if self.concurrency_index is not None:
762862
raise ImproperlyConfigured(
763-
"Agent with concurrency>1 cannot use stream.group_by!"
764-
)
863+
"Agent with concurrency>1 cannot use stream.group_by!")
765864
if not name:
766865
if isinstance(key, FieldDescriptorT):
767866
name = key.ident
768867
else:
769-
raise TypeError("group_by with callback must set name=topic_suffix")
868+
raise TypeError(
869+
"group_by with callback must set name=topic_suffix")
770870
if topic is not None:
771871
channel = topic
772872
else:
@@ -776,9 +876,10 @@ def get_key(withdrawal):
776876
prefix = self.prefix + "-"
777877
suffix = f"-{name}-repartition"
778878
p = partitions if partitions else self.app.conf.topic_partitions
779-
channel = cast(ChannelT, self.channel).derive(
780-
prefix=prefix, suffix=suffix, partitions=p, internal=True
781-
)
879+
channel = cast(ChannelT, self.channel).derive(prefix=prefix,
880+
suffix=suffix,
881+
partitions=p,
882+
internal=True)
782883
format_key = self._format_key
783884

784885
channel_it = aiter(channel)
@@ -789,7 +890,8 @@ def get_key(withdrawal):
789890
async def repartition(value: T) -> T:
790891
event = self.current_event
791892
if event is None:
792-
raise RuntimeError("Cannot repartition stream with non-topic channel")
893+
raise RuntimeError(
894+
"Cannot repartition stream with non-topic channel")
793895
new_key = await format_key(key, value)
794896
await event.forward(channel, key=new_key)
795897
return value
@@ -1070,10 +1172,9 @@ async def _py_aiter(self) -> AsyncIterator[T_co]:
10701172
tp = message.tp
10711173
offset = message.offset
10721174

1073-
if (
1074-
not self.app.flow_control.is_active()
1075-
or message.generation_id != self.app.consumer_generation_id
1076-
):
1175+
if (not self.app.flow_control.is_active()
1176+
or message.generation_id !=
1177+
self.app.consumer_generation_id):
10771178
value = skipped_value
10781179
self.log.dev(
10791180
"Skipping message %r with generation_id %r because "
@@ -1092,7 +1193,8 @@ async def _py_aiter(self) -> AsyncIterator[T_co]:
10921193
# XXX ugh this should be in the consumer somehow
10931194

10941195
# call Sensors
1095-
sensor_state = on_stream_event_in(tp, offset, self, event)
1196+
sensor_state = on_stream_event_in(
1197+
tp, offset, self, event)
10961198

10971199
# set task-local current_event
10981200
_current_event_contextvar.set(create_ref(event))
@@ -1125,13 +1227,15 @@ async def _py_aiter(self) -> AsyncIterator[T_co]:
11251227
self.current_event = None
11261228
# We want to ack the filtered out message
11271229
# otherwise the lag would increase
1128-
if event is not None and (do_ack or value is skipped_value):
1230+
if event is not None and (do_ack
1231+
or value is skipped_value):
11291232
# This inlines self.ack
11301233
last_stream_to_ack = event.ack()
11311234
message = event.message
11321235
tp = event.message.tp
11331236
offset = event.message.offset
1134-
on_stream_event_out(tp, offset, self, event, sensor_state)
1237+
on_stream_event_out(tp, offset, self, event,
1238+
sensor_state)
11351239
if last_stream_to_ack:
11361240
on_message_out(tp, offset, message)
11371241

0 commit comments

Comments
 (0)