Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion src/exo/api/adapters/claude.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ async def collect_claude_response(
thinking_parts: list[str] = []
tool_use_blocks: list[ClaudeToolUseBlock] = []
stop_reason: ClaudeStopReason | None = None
matched_stop_sequence: str | None = None
last_usage: Usage | None = None
error_message: str | None = None

Expand Down Expand Up @@ -303,12 +304,19 @@ async def collect_claude_response(
else:
text_parts.append(chunk.text)

if chunk.matched_stop_sequence is not None:
matched_stop_sequence = chunk.matched_stop_sequence
if chunk.finish_reason is not None:
stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)

if error_message is not None:
raise ValueError(error_message)

# A matched stop sequence is reported as "stop_sequence", not the generic
# "end_turn" that a bare "stop" finish_reason maps to.
if matched_stop_sequence is not None:
stop_reason = "stop_sequence"

combined_text = "".join(text_parts)
combined_thinking = "".join(thinking_parts)

Expand All @@ -333,6 +341,7 @@ async def collect_claude_response(
model=model,
content=content,
stop_reason=stop_reason,
stop_sequence=matched_stop_sequence,
usage=ClaudeUsage(
input_tokens=input_tokens,
output_tokens=output_tokens,
Expand Down Expand Up @@ -362,6 +371,7 @@ async def generate_claude_stream(

output_tokens = 0
stop_reason: ClaudeStopReason | None = None
matched_stop_sequence: str | None = None
last_usage: Usage | None = None
next_block_index = 0

Expand Down Expand Up @@ -454,9 +464,16 @@ async def generate_claude_stream(
)
yield f"event: content_block_delta\ndata: {delta_event.model_dump_json()}\n\n"

if chunk.matched_stop_sequence is not None:
matched_stop_sequence = chunk.matched_stop_sequence
if chunk.finish_reason is not None:
stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)

# A matched stop sequence is reported as "stop_sequence", not the generic
# "end_turn" that a bare "stop" finish_reason maps to.
if matched_stop_sequence is not None:
stop_reason = "stop_sequence"

# Use actual token count from usage if available
if last_usage is not None:
output_tokens = last_usage.completion_tokens
Expand All @@ -480,7 +497,9 @@ async def generate_claude_stream(

# message_delta
message_delta = ClaudeMessageDeltaEvent(
delta=ClaudeMessageDelta(stop_reason=stop_reason),
delta=ClaudeMessageDelta(
stop_reason=stop_reason, stop_sequence=matched_stop_sequence
),
usage=ClaudeMessageDeltaUsage(output_tokens=output_tokens),
)
yield f"event: message_delta\ndata: {message_delta.model_dump_json()}\n\n"
Expand Down
143 changes: 143 additions & 0 deletions src/exo/api/tests/test_claude_stop_sequence.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
"""Tests for stop-sequence reporting in the Claude Messages API adapter.

When generation stops because a user-supplied stop sequence matched, the
response must report stop_reason="stop_sequence" and echo the matched sequence
in the stop_sequence field — not the generic "end_turn" used for a natural EOS.
"""

import json
from collections.abc import AsyncGenerator
from typing import Any, cast

from exo.api.adapters.claude import (
ClaudeMessagesResponse,
collect_claude_response,
generate_claude_stream,
)
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
from exo.shared.types.common import CommandId, ModelId

# Fixed model and command identifiers handed to the adapter by every test in
# this module; the tests only need them to be valid, stable values.
MODEL = ModelId("test-model")
COMMAND_ID = CommandId("cmd_test123")


async def _chunks_to_stream(
    chunks: list[ErrorChunk | ToolCallChunk | TokenChunk],
) -> AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None]:
    """Replay ``chunks`` in list order as an async generator.

    Lets a test hand a fixed list of chunks to an adapter that consumes an
    asynchronous chunk stream.
    """
    for item in chunks:
        yield item


async def _collect_response(
    chunks: list[ErrorChunk | ToolCallChunk | TokenChunk],
) -> ClaudeMessagesResponse:
    """Run the non-streaming adapter over ``chunks`` and parse what it emits.

    Every string yielded by ``collect_claude_response`` is concatenated and the
    result is validated as one ``ClaudeMessagesResponse`` JSON document.
    """
    stream = _chunks_to_stream(chunks)
    payload = "".join(
        [piece async for piece in collect_claude_response(COMMAND_ID, MODEL, stream)]
    )
    return ClaudeMessagesResponse.model_validate_json(payload)


def _parse_sse_events(events: list[str]) -> list[dict[str, Any]]:
parsed: list[dict[str, Any]] = []
for event_str in events:
for line in event_str.strip().split("\n"):
if line.startswith("data: "):
parsed.append(cast(dict[str, Any], json.loads(line[6:])))
return parsed


class TestCollectClaudeResponseStopSequence:
    """Non-streaming path: stop_reason / stop_sequence on the collected response."""

    async def test_matched_stop_sequence_reports_stop_sequence(self):
        """A final chunk carrying matched_stop_sequence yields "stop_sequence"."""
        body = TokenChunk(model=MODEL, text="ABC", token_id=1, usage=None)
        terminator = TokenChunk(
            model=MODEL,
            text="",
            token_id=2,
            usage=None,
            finish_reason="stop",
            matched_stop_sequence="END",
        )
        stream_input: list[ErrorChunk | ToolCallChunk | TokenChunk] = [
            body,
            terminator,
        ]

        response = await _collect_response(stream_input)

        assert response.stop_reason == "stop_sequence"
        assert response.stop_sequence == "END"
        # The stop sequence itself is never part of the emitted text.
        visible_text = "".join(b.text for b in response.content if b.type == "text")
        assert "END" not in visible_text

    async def test_natural_eos_still_reports_end_turn(self):
        """A bare "stop" finish_reason without a matched sequence stays "end_turn"."""
        body = TokenChunk(model=MODEL, text="Hello", token_id=1, usage=None)
        terminator = TokenChunk(
            model=MODEL, text="", token_id=2, usage=None, finish_reason="stop"
        )
        stream_input: list[ErrorChunk | ToolCallChunk | TokenChunk] = [
            body,
            terminator,
        ]

        response = await _collect_response(stream_input)

        assert response.stop_reason == "end_turn"
        assert response.stop_sequence is None

    async def test_length_limit_reports_max_tokens(self):
        """A "length" finish_reason is reported as "max_tokens" with no stop_sequence."""
        truncated = TokenChunk(
            model=MODEL,
            text="Hello",
            token_id=1,
            usage=None,
            finish_reason="length",
        )
        stream_input: list[ErrorChunk | ToolCallChunk | TokenChunk] = [truncated]

        response = await _collect_response(stream_input)

        assert response.stop_reason == "max_tokens"
        assert response.stop_sequence is None


class TestStreamingClaudeResponseStopSequence:
    """Streaming path: stop_reason / stop_sequence on the message_delta event."""

    async def _stream_events(
        self, chunks: list[ErrorChunk | ToolCallChunk | TokenChunk]
    ) -> list[dict[str, Any]]:
        """Drive ``generate_claude_stream`` over ``chunks`` and decode its SSE output.

        Returns the JSON payloads of all ``data`` lines, in emission order.
        """
        events: list[str] = []
        async for event in generate_claude_stream(
            COMMAND_ID, MODEL, _chunks_to_stream(chunks)
        ):
            events.append(event)
        return _parse_sse_events(events)

    async def test_streaming_matched_stop_sequence(self):
        """A matched stop sequence is surfaced on the single message_delta event."""
        chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = [
            TokenChunk(model=MODEL, text="ABC", token_id=1, usage=None),
            TokenChunk(
                model=MODEL,
                text="",
                token_id=2,
                usage=None,
                finish_reason="stop",
                matched_stop_sequence="END",
            ),
        ]
        parsed = await self._stream_events(chunks)

        message_deltas = [p for p in parsed if p.get("type") == "message_delta"]
        assert len(message_deltas) == 1
        delta = cast(dict[str, Any], message_deltas[0]["delta"])
        assert delta["stop_reason"] == "stop_sequence"
        assert delta["stop_sequence"] == "END"

    async def test_streaming_natural_eos(self):
        """A bare "stop" finish_reason streams "end_turn" and a null stop_sequence."""
        chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = [
            TokenChunk(model=MODEL, text="Hi", token_id=1, usage=None),
            TokenChunk(
                model=MODEL, text="", token_id=2, usage=None, finish_reason="stop"
            ),
        ]
        parsed = await self._stream_events(chunks)

        message_deltas = [p for p in parsed if p.get("type") == "message_delta"]
        # Guard the index below so a missing or duplicated event fails as an
        # assertion rather than an IndexError or a silently ignored extra event.
        assert len(message_deltas) == 1
        delta = cast(dict[str, Any], message_deltas[0]["delta"])
        assert delta["stop_reason"] == "end_turn"
        assert delta["stop_sequence"] is None

    async def test_streaming_length_limit(self):
        """A "length" finish_reason streams "max_tokens" and a null stop_sequence.

        Mirrors the non-streaming length-limit case so both adapter paths are
        covered for every stop reason exercised in this module.
        """
        chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = [
            TokenChunk(
                model=MODEL,
                text="Hello",
                token_id=1,
                usage=None,
                finish_reason="length",
            ),
        ]
        parsed = await self._stream_events(chunks)

        message_deltas = [p for p in parsed if p.get("type") == "message_delta"]
        assert len(message_deltas) == 1
        delta = cast(dict[str, Any], message_deltas[0]["delta"])
        assert delta["stop_reason"] == "max_tokens"
        assert delta["stop_sequence"] is None
3 changes: 3 additions & 0 deletions src/exo/shared/types/chunks.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ class TokenChunk(BaseChunk):
logprob: float | None = None
top_logprobs: list[TopLogprobItem] | None = None
is_thinking: bool = False
# Set to the matched stop sequence when generation stopped because a
# user-supplied stop sequence was hit (distinguishes it from a natural EOS).
matched_stop_sequence: str | None = None


class ErrorChunk(BaseChunk):
Expand Down
3 changes: 3 additions & 0 deletions src/exo/shared/types/worker/runner_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ class GenerationResponse(BaseRunnerResponse):
stats: GenerationStats | None = None
usage: Usage | None
is_thinking: bool = False
# The stop sequence that terminated generation, when ``finish_reason`` is
# "stop" because a user-supplied stop sequence matched (vs. a natural EOS).
matched_stop_sequence: str | None = None


class ImageGenerationResponse(BaseRunnerResponse):
Expand Down
48 changes: 22 additions & 26 deletions src/exo/worker/engines/mlx/generator/batch_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
prefill,
)
from exo.worker.engines.mlx.generator.remote_prefill import remote_prefill
from exo.worker.engines.mlx.generator.stop_sequences import scan_stop_sequences
from exo.worker.engines.mlx.patches.opt_batch_gen import (
set_needs_topk,
take_ready_topk,
Expand Down Expand Up @@ -369,28 +370,29 @@ def step(self) -> list[tuple[int, GenerationResponse]]:
f"[bench] uid={response.uid} tok#{state.completion_tokens} {text!r} t={delta:.4f}s"
)
state.generated_text_parts.append(text)
state.potential_stop_sequence_text += text

finish_reason: FinishReason | None = cast(
FinishReason | None, response.finish_reason
)
# Hold back any trailing partial stop-sequence match so a multi-token
# stop sequence never leaks its leading bytes into output.
model_finish_reason = cast(FinishReason | None, response.finish_reason)
task_params = state.task_params
stop_sequences = _stop_sequences(task_params)
max_stop_len = max((len(s) for s in stop_sequences), default=0)

if stop_sequences:
for stop_seq in stop_sequences:
if stop_seq in state.potential_stop_sequence_text:
stop_index = state.potential_stop_sequence_text.find(stop_seq)
text_before_stop = state.potential_stop_sequence_text[
:stop_index
]
chunk_start = len(state.potential_stop_sequence_text) - len(
text
)
text = text_before_stop[chunk_start:]
finish_reason = "stop"
break

state.potential_stop_sequence_text += text
text, matched_stop_sequence, state.potential_stop_sequence_text = (
scan_stop_sequences(state.potential_stop_sequence_text, stop_sequences)
)

finish_reason: FinishReason | None
if matched_stop_sequence is not None:
finish_reason = "stop"
elif model_finish_reason is not None:
# Natural EOS / length limit: flush any held-back text — it is
# real output that merely looked like a stop-sequence prefix.
text += state.potential_stop_sequence_text
state.potential_stop_sequence_text = ""
finish_reason = model_finish_reason
else:
finish_reason = None

is_done = finish_reason is not None

Expand Down Expand Up @@ -457,19 +459,13 @@ def step(self) -> list[tuple[int, GenerationResponse]]:
finish_reason=finish_reason,
stats=stats,
usage=usage,
matched_stop_sequence=matched_stop_sequence,
),
)
)

if is_done:
del self._active_tasks[response.uid]
elif (
max_stop_len > 0
and len(state.potential_stop_sequence_text) > max_stop_len
):
state.potential_stop_sequence_text = state.potential_stop_sequence_text[
-max_stop_len:
]

_step_elapsed = time.perf_counter() - _step_tic
_overhead = _step_elapsed - _next_elapsed
Expand Down
Loading