Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 67 additions & 29 deletions src/exo/api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,19 +702,19 @@ async def cancel_command(self, command_id: CommandId) -> CancelCommandResponse:
)

async def _token_chunk_stream(
self, command_id: CommandId
self,
command_id: CommandId,
recv: Receiver[TokenChunk | ErrorChunk | ToolCallChunk | PrefillProgressChunk],
) -> AsyncGenerator[
TokenChunk | ErrorChunk | ToolCallChunk | PrefillProgressChunk, None
]:
"""Yield chunks for a given command until completion.

This is the internal low-level stream used by all API adapters.
The caller must register the token queue and pass the receive end
here before dispatching the command, so no tokens are dropped.
"""
try:
self._text_generation_queues[command_id], recv = channel[
TokenChunk | ErrorChunk | ToolCallChunk | PrefillProgressChunk
]()

with recv as token_chunks:
async for chunk in token_chunks:
yield chunk
Expand All @@ -736,7 +736,11 @@ async def _token_chunk_stream(
del self._text_generation_queues[command_id]

async def _collect_text_generation_with_stats(
self, command_id: CommandId
self,
command_id: CommandId,
token_stream: AsyncGenerator[
TokenChunk | ErrorChunk | ToolCallChunk | PrefillProgressChunk, None
],
) -> BenchChatCompletionResponse:
sampler = PowerSampler(get_node_system=lambda: self.state.node_system)
text_parts: list[str] = []
Expand All @@ -749,7 +753,7 @@ async def _collect_text_generation_with_stats(
async with anyio.create_task_group() as tg:
tg.start_soon(sampler.run)

async for chunk in self._token_chunk_stream(command_id):
async for chunk in token_stream:
if isinstance(chunk, PrefillProgressChunk):
continue

Expand Down Expand Up @@ -811,20 +815,40 @@ async def _trigger_notify_user_to_download_model(self, model_id: ModelId) -> Non

async def _send_text_generation_with_images(
self, task_params: TextGenerationTaskParams
) -> TextGeneration:
) -> tuple[
TextGeneration,
AsyncGenerator[
TokenChunk | ErrorChunk | ToolCallChunk | PrefillProgressChunk, None
],
]:
"""Build a TextGeneration command, register its token queue, then dispatch it.

The token queue is registered before the command is sent so that tokens
emitted by workers before the HTTP consumer starts iterating are never
dropped. All callers must use the returned stream and must not call
_token_chunk_stream(command.command_id) separately.
"""
task_params = task_params.with_card_sampling_defaults()
images = task_params.images
if not images:
command = TextGeneration(task_params=task_params)
self._text_generation_queues[command.command_id], recv = channel[
TokenChunk | ErrorChunk | ToolCallChunk | PrefillProgressChunk
]()
token_stream = self._token_chunk_stream(command.command_id, recv)
await self._send(command)
return command
return command, token_stream

hashes = [hashlib.sha256(img.encode("ascii")).hexdigest() for img in images]
all_hashes = {idx: Base64ImageHash(h) for idx, h in enumerate(hashes)}
task_params = task_params.model_copy(
update={"images": [], "image_hashes": all_hashes}
)
command = TextGeneration(task_params=task_params)
self._text_generation_queues[command.command_id], recv = channel[
TokenChunk | ErrorChunk | ToolCallChunk | PrefillProgressChunk
]()
token_stream = self._token_chunk_stream(command.command_id, recv)

new_images: list[tuple[int, str]] = []
for idx, (img, h) in enumerate(zip(images, hashes, strict=True)):
Expand All @@ -834,7 +858,7 @@ async def _send_text_generation_with_images(

if not new_images:
await self._send(command)
return command
return command, token_stream

all_chunks: list[tuple[int, str]] = []
for img_idx, img_data in new_images:
Expand All @@ -856,7 +880,7 @@ async def _send_text_generation_with_images(
)

await self._send(command)
return command
return command, token_stream

async def chat_completions(
self, payload: ChatCompletionRequest
Expand All @@ -868,14 +892,16 @@ async def chat_completions(
)
task_params = task_params.model_copy(update={"model": resolved_model})

command = await self._send_text_generation_with_images(task_params)
command, token_stream = await self._send_text_generation_with_images(
task_params
)

if payload.stream:
return StreamingResponse(
with_sse_keepalive(
generate_chat_stream(
command.command_id,
self._token_chunk_stream(command.command_id),
token_stream,
),
),
media_type="text/event-stream",
Expand All @@ -889,7 +915,7 @@ async def chat_completions(
return StreamingResponse(
collect_chat_response(
command.command_id,
self._token_chunk_stream(command.command_id),
token_stream,
),
media_type="application/json",
)
Expand All @@ -911,14 +937,16 @@ async def bench_chat_completions(
}
)

command = await self._send_text_generation_with_images(task_params)
command, token_stream = await self._send_text_generation_with_images(
task_params
)

if payload.stream:
return StreamingResponse(
with_sse_keepalive(
generate_chat_stream(
command.command_id,
self._token_chunk_stream(command.command_id),
token_stream,
),
),
media_type="text/event-stream",
Expand All @@ -929,7 +957,9 @@ async def bench_chat_completions(
},
)

return await self._collect_text_generation_with_stats(command.command_id)
return await self._collect_text_generation_with_stats(
command.command_id, token_stream
)

async def _resolve_and_validate_text_model(self, model_id: ModelId) -> ModelId:
"""Validate a text model exists and return the resolved model ID.
Expand Down Expand Up @@ -1489,15 +1519,17 @@ async def claude_messages(
)
task_params = task_params.model_copy(update={"model": resolved_model})

command = await self._send_text_generation_with_images(task_params)
command, token_stream = await self._send_text_generation_with_images(
task_params
)

if payload.stream:
return StreamingResponse(
with_sse_keepalive(
generate_claude_stream(
command.command_id,
payload.model,
self._token_chunk_stream(command.command_id),
token_stream,
),
),
media_type="text/event-stream",
Expand All @@ -1512,7 +1544,7 @@ async def claude_messages(
collect_claude_response(
command.command_id,
payload.model,
self._token_chunk_stream(command.command_id),
token_stream,
),
media_type="application/json",
)
Expand All @@ -1525,15 +1557,17 @@ async def openai_responses(
resolved_model = await self._resolve_and_validate_text_model(task_params.model)
task_params = task_params.model_copy(update={"model": resolved_model})

command = await self._send_text_generation_with_images(task_params)
command, token_stream = await self._send_text_generation_with_images(
task_params
)

if payload.stream:
return StreamingResponse(
with_sse_keepalive(
generate_responses_stream(
command.command_id,
payload.model,
self._token_chunk_stream(command.command_id),
token_stream,
),
),
media_type="text/event-stream",
Expand All @@ -1549,7 +1583,7 @@ async def openai_responses(
collect_responses_response(
command.command_id,
payload.model,
self._token_chunk_stream(command.command_id),
token_stream,
),
media_type="application/json",
)
Expand All @@ -1570,13 +1604,15 @@ async def ollama_chat(
)
task_params = task_params.model_copy(update={"model": resolved_model})

command = await self._send_text_generation_with_images(task_params)
command, token_stream = await self._send_text_generation_with_images(
task_params
)

if payload.stream:
return StreamingResponse(
generate_ollama_chat_stream(
command.command_id,
self._token_chunk_stream(command.command_id),
token_stream,
),
media_type="application/x-ndjson",
headers={
Expand All @@ -1589,7 +1625,7 @@ async def ollama_chat(
return StreamingResponse(
collect_ollama_chat_response(
command.command_id,
self._token_chunk_stream(command.command_id),
token_stream,
),
media_type="application/json",
)
Expand All @@ -1606,13 +1642,15 @@ async def ollama_generate(
)
task_params = task_params.model_copy(update={"model": resolved_model})

command = await self._send_text_generation_with_images(task_params)
command, token_stream = await self._send_text_generation_with_images(
task_params
)

if payload.stream:
return StreamingResponse(
generate_ollama_generate_stream(
command.command_id,
self._token_chunk_stream(command.command_id),
token_stream,
),
media_type="application/x-ndjson",
headers={
Expand All @@ -1625,7 +1663,7 @@ async def ollama_generate(
return StreamingResponse(
collect_ollama_generate_response(
command.command_id,
self._token_chunk_stream(command.command_id),
token_stream,
),
media_type="application/json",
)
Expand Down