diff --git a/src/exo/api/main.py b/src/exo/api/main.py index 8fe0cfbecb..f2209551d9 100644 --- a/src/exo/api/main.py +++ b/src/exo/api/main.py @@ -702,19 +702,19 @@ async def cancel_command(self, command_id: CommandId) -> CancelCommandResponse: ) async def _token_chunk_stream( - self, command_id: CommandId + self, + command_id: CommandId, + recv: Receiver[TokenChunk | ErrorChunk | ToolCallChunk | PrefillProgressChunk], ) -> AsyncGenerator[ TokenChunk | ErrorChunk | ToolCallChunk | PrefillProgressChunk, None ]: """Yield chunks for a given command until completion. This is the internal low-level stream used by all API adapters. + The caller must register the token queue and pass the receive end + here before dispatching the command, so no tokens are dropped. """ try: - self._text_generation_queues[command_id], recv = channel[ - TokenChunk | ErrorChunk | ToolCallChunk | PrefillProgressChunk - ]() - with recv as token_chunks: async for chunk in token_chunks: yield chunk @@ -736,7 +736,11 @@ async def _token_chunk_stream( del self._text_generation_queues[command_id] async def _collect_text_generation_with_stats( - self, command_id: CommandId + self, + command_id: CommandId, + token_stream: AsyncGenerator[ + TokenChunk | ErrorChunk | ToolCallChunk | PrefillProgressChunk, None + ], ) -> BenchChatCompletionResponse: sampler = PowerSampler(get_node_system=lambda: self.state.node_system) text_parts: list[str] = [] @@ -749,7 +753,7 @@ async def _collect_text_generation_with_stats( async with anyio.create_task_group() as tg: tg.start_soon(sampler.run) - async for chunk in self._token_chunk_stream(command_id): + async for chunk in token_stream: if isinstance(chunk, PrefillProgressChunk): continue @@ -811,13 +815,29 @@ async def _trigger_notify_user_to_download_model(self, model_id: ModelId) -> Non async def _send_text_generation_with_images( self, task_params: TextGenerationTaskParams - ) -> TextGeneration: + ) -> tuple[ + TextGeneration, + AsyncGenerator[ + TokenChunk | ErrorChunk | ToolCallChunk | PrefillProgressChunk, None + ], + ]: + """Build a TextGeneration command, register its token queue, then dispatch it. + + The token queue is registered before the command is sent so that tokens + emitted by workers before the HTTP consumer starts iterating are never + dropped. All callers must use the returned stream and must not call + _token_chunk_stream(command.command_id) separately. + """ task_params = task_params.with_card_sampling_defaults() images = task_params.images if not images: command = TextGeneration(task_params=task_params) + self._text_generation_queues[command.command_id], recv = channel[ + TokenChunk | ErrorChunk | ToolCallChunk | PrefillProgressChunk + ]() + token_stream = self._token_chunk_stream(command.command_id, recv) await self._send(command) - return command + return command, token_stream hashes = [hashlib.sha256(img.encode("ascii")).hexdigest() for img in images] all_hashes = {idx: Base64ImageHash(h) for idx, h in enumerate(hashes)} @@ -825,6 +845,10 @@ async def _send_text_generation_with_images( update={"images": [], "image_hashes": all_hashes} ) command = TextGeneration(task_params=task_params) + self._text_generation_queues[command.command_id], recv = channel[ + TokenChunk | ErrorChunk | ToolCallChunk | PrefillProgressChunk + ]() + token_stream = self._token_chunk_stream(command.command_id, recv) new_images: list[tuple[int, str]] = [] for idx, (img, h) in enumerate(zip(images, hashes, strict=True)): @@ -834,7 +858,7 @@ async def _send_text_generation_with_images( if not new_images: await self._send(command) - return command + return command, token_stream all_chunks: list[tuple[int, str]] = [] for img_idx, img_data in new_images: @@ -856,7 +880,7 @@ async def _send_text_generation_with_images( ) await self._send(command) - return command + return command, token_stream async def chat_completions( self, payload: ChatCompletionRequest @@ -868,14 +892,16 @@ async def chat_completions( ) task_params = task_params.model_copy(update={"model": resolved_model}) - command = await self._send_text_generation_with_images(task_params) + command, token_stream = await self._send_text_generation_with_images( + task_params + ) if payload.stream: return StreamingResponse( with_sse_keepalive( generate_chat_stream( command.command_id, - self._token_chunk_stream(command.command_id), + token_stream, ), ), media_type="text/event-stream", @@ -889,7 +915,7 @@ async def chat_completions( return StreamingResponse( collect_chat_response( command.command_id, - self._token_chunk_stream(command.command_id), + token_stream, ), media_type="application/json", ) @@ -911,14 +937,16 @@ async def bench_chat_completions( } ) - command = await self._send_text_generation_with_images(task_params) + command, token_stream = await self._send_text_generation_with_images( + task_params + ) if payload.stream: return StreamingResponse( with_sse_keepalive( generate_chat_stream( command.command_id, - self._token_chunk_stream(command.command_id), + token_stream, ), ), media_type="text/event-stream", @@ -929,7 +957,9 @@ async def bench_chat_completions( }, ) - return await self._collect_text_generation_with_stats(command.command_id) + return await self._collect_text_generation_with_stats( + command.command_id, token_stream + ) async def _resolve_and_validate_text_model(self, model_id: ModelId) -> ModelId: """Validate a text model exists and return the resolved model ID. @@ -1489,7 +1519,9 @@ async def claude_messages( ) task_params = task_params.model_copy(update={"model": resolved_model}) - command = await self._send_text_generation_with_images(task_params) + command, token_stream = await self._send_text_generation_with_images( + task_params + ) if payload.stream: return StreamingResponse( @@ -1497,7 +1529,7 @@ async def claude_messages( generate_claude_stream( command.command_id, payload.model, - self._token_chunk_stream(command.command_id), + token_stream, ), ), media_type="text/event-stream", @@ -1512,7 +1544,7 @@ async def claude_messages( collect_claude_response( command.command_id, payload.model, - self._token_chunk_stream(command.command_id), + token_stream, ), media_type="application/json", ) @@ -1525,7 +1557,9 @@ async def openai_responses( resolved_model = await self._resolve_and_validate_text_model(task_params.model) task_params = task_params.model_copy(update={"model": resolved_model}) - command = await self._send_text_generation_with_images(task_params) + command, token_stream = await self._send_text_generation_with_images( + task_params + ) if payload.stream: return StreamingResponse( @@ -1533,7 +1567,7 @@ async def openai_responses( generate_responses_stream( command.command_id, payload.model, - self._token_chunk_stream(command.command_id), + token_stream, ), ), media_type="text/event-stream", @@ -1549,7 +1583,7 @@ async def openai_responses( collect_responses_response( command.command_id, payload.model, - self._token_chunk_stream(command.command_id), + token_stream, ), media_type="application/json", ) @@ -1570,13 +1604,15 @@ async def ollama_chat( ) task_params = task_params.model_copy(update={"model": resolved_model}) - command = await self._send_text_generation_with_images(task_params) + command, token_stream = await self._send_text_generation_with_images( + task_params + ) if payload.stream: return StreamingResponse( generate_ollama_chat_stream( command.command_id, - self._token_chunk_stream(command.command_id), + token_stream, ), media_type="application/x-ndjson", headers={ @@ -1589,7 +1625,7 @@ async def ollama_chat( return StreamingResponse( collect_ollama_chat_response( command.command_id, - self._token_chunk_stream(command.command_id), + token_stream, ), media_type="application/json", ) @@ -1606,13 +1642,15 @@ async def ollama_generate( ) task_params = task_params.model_copy(update={"model": resolved_model}) - command = await self._send_text_generation_with_images(task_params) + command, token_stream = await self._send_text_generation_with_images( + task_params + ) if payload.stream: return StreamingResponse( generate_ollama_generate_stream( command.command_id, - self._token_chunk_stream(command.command_id), + token_stream, ), media_type="application/x-ndjson", headers={ @@ -1625,7 +1663,7 @@ async def ollama_generate( return StreamingResponse( collect_ollama_generate_response( command.command_id, - self._token_chunk_stream(command.command_id), + token_stream, ), media_type="application/json", )