@@ -206,7 +206,8 @@ def _iter_ll(self, dir_: _LinkedListDirection) -> Iterator[StreamT]:
206206 seen : Set [StreamT ] = set ()
207207 while node :
208208 if node in seen :
209- raise RuntimeError (f"Loop in Stream.{ dir_ .attr } : Call support!" )
209+ raise RuntimeError (
210+ f"Loop in Stream.{ dir_ .attr } : Call support!" )
210211 seen .add (node )
211212 yield node
212213 node = dir_ .getter (node )
@@ -270,9 +271,7 @@ def _chain(self, **kwargs: Any) -> StreamT:
270271
271272 def noack (self ) -> "StreamT" :
272273 """Create new stream where acks are manual."""
273- self ._next = new_stream = self .clone (
274- enable_acks = False ,
275- )
274+ self ._next = new_stream = self .clone (enable_acks = False , )
276275 return new_stream
277276
278277 async def items (self ) -> AsyncIterator [Tuple [K , T_co ]]:
@@ -299,7 +298,8 @@ async def events(self) -> AsyncIterable[EventT]:
299298 if self .current_event is not None :
300299 yield self .current_event
301300
302- async def take (self , max_ : int , within : Seconds ) -> AsyncIterable [Sequence [T_co ]]:
301+ async def take (self , max_ : int ,
302+ within : Seconds ) -> AsyncIterable [Sequence [T_co ]]:
303303 """Buffer n values at a time and yield a list of buffered values.
304304
305305 Arguments:
@@ -340,7 +340,8 @@ async def add_to_buffer(value: T) -> T:
340340 buffer_add (cast (T_co , value ))
341341 event = self .current_event
342342 if event is None :
343- raise RuntimeError ("Take buffer found current_event is None" )
343+ raise RuntimeError (
344+ "Take buffer found current_event is None" )
344345 event_add (event )
345346 if buffer_size () >= max_ :
346347 # signal that the buffer is full and should be emptied.
@@ -391,9 +392,102 @@ async def add_to_buffer(value: T) -> T:
391392 self .enable_acks = stream_enable_acks
392393 self ._processors .remove (add_to_buffer )
393394
395+ async def take_events (self , max_ : int ,
396+ within : Seconds ) -> AsyncIterable [Sequence [EventT ]]:
397+ """Buffer n events at a time and yield a list of buffered events.
398+ Arguments:
399+ max_: Max number of messages to receive. When more than this
400+ number of messages are received within the specified number of
401+ seconds then we flush the buffer immediately.
402+ within: Timeout for when we give up waiting for another value,
403+ and process the values we have.
404+ Warning: If there's no timeout (i.e. `timeout=None`),
405+ the agent is likely to stall and block buffered events for an
406+ unreasonable length of time(!).
407+ """
408+ buffer : List [T_co ] = []
409+ events : List [EventT ] = []
410+ buffer_add = buffer .append
411+ event_add = events .append
412+ buffer_size = buffer .__len__
413+ buffer_full = asyncio .Event ()
414+ buffer_consumed = asyncio .Event ()
415+ timeout = want_seconds (within ) if within else None
416+ stream_enable_acks : bool = self .enable_acks
417+
418+ buffer_consuming : Optional [asyncio .Future ] = None
419+
420+ channel_it = aiter (self .channel )
421+
422+ # We add this processor to populate the buffer, and the stream
423+ # is passively consumed in the background (enable_passive below).
424+ async def add_to_buffer (value : T ) -> T :
425+ try :
426+ # buffer_consuming is set when consuming buffer after timeout.
427+ nonlocal buffer_consuming
428+ if buffer_consuming is not None :
429+ try :
430+ await buffer_consuming
431+ finally :
432+ buffer_consuming = None
433+ buffer_add (cast (T_co , value ))
434+ event = self .current_event
435+ if event is None :
436+ raise RuntimeError (
437+ 'Take buffer found current_event is None' )
438+ event_add (event )
439+ if buffer_size () >= max_ :
440+ # signal that the buffer is full and should be emptied.
441+ buffer_full .set ()
442+ # strict wait for buffer to be consumed after buffer full.
443+ # If max is 1000, we are not allowed to return 1001 values.
444+ buffer_consumed .clear ()
445+ await self .wait (buffer_consumed )
446+ except CancelledError : # pragma: no cover
447+ raise
448+ except Exception as exc :
449+ self .log .exception ('Error adding to take buffer: %r' , exc )
450+ await self .crash (exc )
451+ return value
452+
453+ # Disable acks to ensure this method acks manually
454+ # events only after they are consumed by the user
455+ self .enable_acks = False
456+
457+ self .add_processor (add_to_buffer )
458+ self ._enable_passive (cast (ChannelT , channel_it ))
459+ try :
460+ while not self .should_stop :
461+ # wait until buffer full, or timeout
462+ await self .wait_for_stopped (buffer_full , timeout = timeout )
463+ if buffer :
464+ # make sure background thread does not add new items to
465+ # buffer while we read.
466+ buffer_consuming = self .loop .create_future ()
467+ try :
468+ yield list (events )
469+ finally :
470+ buffer .clear ()
471+ for event in events :
472+ await self .ack (event )
473+ events .clear ()
474+ # allow writing to buffer again
475+ notify (buffer_consuming )
476+ buffer_full .clear ()
477+ buffer_consumed .set ()
478+ else : # pragma: no cover
479+ pass
480+ else : # pragma: no cover
481+ pass
482+
483+ finally :
484+ # Restore last behaviour of "enable_acks"
485+ self .enable_acks = stream_enable_acks
486+ self ._processors .remove (add_to_buffer )
487+
394488 async def take_with_timestamp (
395- self , max_ : int , within : Seconds , timestamp_field_name : str
396- ) -> AsyncIterable [Sequence [T_co ]]:
489+ self , max_ : int , within : Seconds ,
490+ timestamp_field_name : str ) -> AsyncIterable [Sequence [T_co ]]:
397491 """Buffer n values at a time and yield a list of buffered values with the
398492 timestamp when the message was added to kafka.
399493
@@ -439,7 +533,8 @@ async def add_to_buffer(value: T) -> T:
439533 value [timestamp_field_name ] = event .message .timestamp
440534 buffer_add (value )
441535 if event is None :
442- raise RuntimeError ("Take buffer found current_event is None" )
536+ raise RuntimeError (
537+ "Take buffer found current_event is None" )
443538 event_add (event )
444539 if buffer_size () >= max_ :
445540 # signal that the buffer is full and should be emptied.
@@ -498,9 +593,8 @@ def enumerate(self, start: int = 0) -> AsyncIterable[Tuple[int, T_co]]:
498593 """
499594 return aenumerate (self , start )
500595
501- async def noack_take (
502- self , max_ : int , within : Seconds
503- ) -> AsyncIterable [Sequence [T_co ]]:
596+ async def noack_take (self , max_ : int ,
597+ within : Seconds ) -> AsyncIterable [Sequence [T_co ]]:
504598 """
505599 Buffer n values at a time and yield a list of buffered values.
506600 :param max_: Max number of messages to receive. When more than this
@@ -543,7 +637,8 @@ async def add_to_buffer(value: T) -> T:
543637 event = self .current_event
544638 buffer_add (cast (T_co , event ))
545639 if event is None :
546- raise RuntimeError ("Take buffer found current_event is None" )
640+ raise RuntimeError (
641+ "Take buffer found current_event is None" )
547642
548643 event_add (event )
549644 if buffer_size () >= max_ :
@@ -628,8 +723,7 @@ async def mytask(stream):
628723 return self
629724 if self .concurrency_index is not None :
630725 raise ImproperlyConfigured (
631- "Agent with concurrency>1 cannot use stream.through!"
632- )
726+ "Agent with concurrency>1 cannot use stream.through!" )
633727 # ridiculous mypy
634728 if isinstance (channel , str ):
635729 channelchannel = cast (ChannelT , self .derive_topic (channel ))
@@ -638,7 +732,8 @@ async def mytask(stream):
638732
639733 channel_it = aiter (channelchannel )
640734 if self ._next is not None :
641- raise ImproperlyConfigured ("Stream is already using group_by/through" )
735+ raise ImproperlyConfigured (
736+ "Stream is already using group_by/through" )
642737 through = self ._chain (channel = channel_it )
643738
644739 async def forward (value : T ) -> T :
@@ -649,12 +744,17 @@ async def forward(value: T) -> T:
649744 self ._enable_passive (cast (ChannelT , channel_it ), declare = True )
650745 return through
651746
652- def _enable_passive (self , channel : ChannelT , * , declare : bool = False ) -> None :
747+ def _enable_passive (self ,
748+ channel : ChannelT ,
749+ * ,
750+ declare : bool = False ) -> None :
653751 if not self ._passive :
654752 self ._passive = True
655753 self .add_future (self ._passive_drainer (channel , declare ))
656754
657- async def _passive_drainer (self , channel : ChannelT , declare : bool = False ) -> None :
755+ async def _passive_drainer (self ,
756+ channel : ChannelT ,
757+ declare : bool = False ) -> None :
658758 try :
659759 if declare :
660760 await channel .maybe_declare ()
@@ -760,13 +860,13 @@ def get_key(withdrawal):
760860 channel : ChannelT
761861 if self .concurrency_index is not None :
762862 raise ImproperlyConfigured (
763- "Agent with concurrency>1 cannot use stream.group_by!"
764- )
863+ "Agent with concurrency>1 cannot use stream.group_by!" )
765864 if not name :
766865 if isinstance (key , FieldDescriptorT ):
767866 name = key .ident
768867 else :
769- raise TypeError ("group_by with callback must set name=topic_suffix" )
868+ raise TypeError (
869+ "group_by with callback must set name=topic_suffix" )
770870 if topic is not None :
771871 channel = topic
772872 else :
@@ -776,9 +876,10 @@ def get_key(withdrawal):
776876 prefix = self .prefix + "-"
777877 suffix = f"-{ name } -repartition"
778878 p = partitions if partitions else self .app .conf .topic_partitions
779- channel = cast (ChannelT , self .channel ).derive (
780- prefix = prefix , suffix = suffix , partitions = p , internal = True
781- )
879+ channel = cast (ChannelT , self .channel ).derive (prefix = prefix ,
880+ suffix = suffix ,
881+ partitions = p ,
882+ internal = True )
782883 format_key = self ._format_key
783884
784885 channel_it = aiter (channel )
@@ -789,7 +890,8 @@ def get_key(withdrawal):
789890 async def repartition (value : T ) -> T :
790891 event = self .current_event
791892 if event is None :
792- raise RuntimeError ("Cannot repartition stream with non-topic channel" )
893+ raise RuntimeError (
894+ "Cannot repartition stream with non-topic channel" )
793895 new_key = await format_key (key , value )
794896 await event .forward (channel , key = new_key )
795897 return value
@@ -1070,10 +1172,9 @@ async def _py_aiter(self) -> AsyncIterator[T_co]:
10701172 tp = message .tp
10711173 offset = message .offset
10721174
1073- if (
1074- not self .app .flow_control .is_active ()
1075- or message .generation_id != self .app .consumer_generation_id
1076- ):
1175+ if (not self .app .flow_control .is_active ()
1176+ or message .generation_id !=
1177+ self .app .consumer_generation_id ):
10771178 value = skipped_value
10781179 self .log .dev (
10791180 "Skipping message %r with generation_id %r because "
@@ -1092,7 +1193,8 @@ async def _py_aiter(self) -> AsyncIterator[T_co]:
10921193 # XXX ugh this should be in the consumer somehow
10931194
10941195 # call Sensors
1095- sensor_state = on_stream_event_in (tp , offset , self , event )
1196+ sensor_state = on_stream_event_in (
1197+ tp , offset , self , event )
10961198
10971199 # set task-local current_event
10981200 _current_event_contextvar .set (create_ref (event ))
@@ -1125,13 +1227,15 @@ async def _py_aiter(self) -> AsyncIterator[T_co]:
11251227 self .current_event = None
11261228 # We want to ack the filtered out message
11271229 # otherwise the lag would increase
1128- if event is not None and (do_ack or value is skipped_value ):
1230+ if event is not None and (do_ack
1231+ or value is skipped_value ):
11291232 # This inlines self.ack
11301233 last_stream_to_ack = event .ack ()
11311234 message = event .message
11321235 tp = event .message .tp
11331236 offset = event .message .offset
1134- on_stream_event_out (tp , offset , self , event , sensor_state )
1237+ on_stream_event_out (tp , offset , self , event ,
1238+ sensor_state )
11351239 if last_stream_to_ack :
11361240 on_message_out (tp , offset , message )
11371241
0 commit comments