diff --git a/src/exo/api/adapters/claude.py b/src/exo/api/adapters/claude.py index a54de764c6..c0f179b6d1 100644 --- a/src/exo/api/adapters/claude.py +++ b/src/exo/api/adapters/claude.py @@ -273,6 +273,7 @@ async def collect_claude_response( thinking_parts: list[str] = [] tool_use_blocks: list[ClaudeToolUseBlock] = [] stop_reason: ClaudeStopReason | None = None + matched_stop_sequence: str | None = None last_usage: Usage | None = None error_message: str | None = None @@ -303,12 +304,19 @@ async def collect_claude_response( else: text_parts.append(chunk.text) + if chunk.matched_stop_sequence is not None: + matched_stop_sequence = chunk.matched_stop_sequence if chunk.finish_reason is not None: stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason) if error_message is not None: raise ValueError(error_message) + # A matched stop sequence is reported as "stop_sequence", not the generic + # "end_turn" that a bare "stop" finish_reason maps to. + if matched_stop_sequence is not None: + stop_reason = "stop_sequence" + combined_text = "".join(text_parts) combined_thinking = "".join(thinking_parts) @@ -333,6 +341,7 @@ async def collect_claude_response( model=model, content=content, stop_reason=stop_reason, + stop_sequence=matched_stop_sequence, usage=ClaudeUsage( input_tokens=input_tokens, output_tokens=output_tokens, @@ -362,6 +371,7 @@ async def generate_claude_stream( output_tokens = 0 stop_reason: ClaudeStopReason | None = None + matched_stop_sequence: str | None = None last_usage: Usage | None = None next_block_index = 0 @@ -454,9 +464,16 @@ async def generate_claude_stream( ) yield f"event: content_block_delta\ndata: {delta_event.model_dump_json()}\n\n" + if chunk.matched_stop_sequence is not None: + matched_stop_sequence = chunk.matched_stop_sequence if chunk.finish_reason is not None: stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason) + # A matched stop sequence is reported as "stop_sequence", not the generic + # "end_turn" that a bare "stop" finish_reason maps to. + if matched_stop_sequence is not None: + stop_reason = "stop_sequence" + # Use actual token count from usage if available if last_usage is not None: output_tokens = last_usage.completion_tokens @@ -480,7 +497,9 @@ async def generate_claude_stream( # message_delta message_delta = ClaudeMessageDeltaEvent( - delta=ClaudeMessageDelta(stop_reason=stop_reason), + delta=ClaudeMessageDelta( + stop_reason=stop_reason, stop_sequence=matched_stop_sequence + ), usage=ClaudeMessageDeltaUsage(output_tokens=output_tokens), ) yield f"event: message_delta\ndata: {message_delta.model_dump_json()}\n\n" diff --git a/src/exo/api/tests/test_claude_stop_sequence.py b/src/exo/api/tests/test_claude_stop_sequence.py new file mode 100644 index 0000000000..556cbb55e8 --- /dev/null +++ b/src/exo/api/tests/test_claude_stop_sequence.py @@ -0,0 +1,143 @@ +"""Tests for stop-sequence reporting in the Claude Messages API adapter. + +When generation stops because a user-supplied stop sequence matched, the +response must report stop_reason="stop_sequence" and echo the matched sequence +in the stop_sequence field — not the generic "end_turn" used for a natural EOS. +""" + +import json +from collections.abc import AsyncGenerator +from typing import Any, cast + +from exo.api.adapters.claude import ( + ClaudeMessagesResponse, + collect_claude_response, + generate_claude_stream, +) +from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk +from exo.shared.types.common import CommandId, ModelId + +MODEL = ModelId("test-model") +COMMAND_ID = CommandId("cmd_test123") + + +async def _chunks_to_stream( + chunks: list[ErrorChunk | ToolCallChunk | TokenChunk], +) -> AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None]: + for chunk in chunks: + yield chunk + + +async def _collect_response( + chunks: list[ErrorChunk | ToolCallChunk | TokenChunk], +) -> ClaudeMessagesResponse: + parts: list[str] = [] + async for part in collect_claude_response( + COMMAND_ID, MODEL, _chunks_to_stream(chunks) + ): + parts.append(part) + return ClaudeMessagesResponse.model_validate_json("".join(parts)) + + +def _parse_sse_events(events: list[str]) -> list[dict[str, Any]]: + parsed: list[dict[str, Any]] = [] + for event_str in events: + for line in event_str.strip().split("\n"): + if line.startswith("data: "): + parsed.append(cast(dict[str, Any], json.loads(line[6:]))) + return parsed + + +class TestCollectClaudeResponseStopSequence: + async def test_matched_stop_sequence_reports_stop_sequence(self): + chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = [ + TokenChunk(model=MODEL, text="ABC", token_id=1, usage=None), + TokenChunk( + model=MODEL, + text="", + token_id=2, + usage=None, + finish_reason="stop", + matched_stop_sequence="END", + ), + ] + response = await _collect_response(chunks) + + assert response.stop_reason == "stop_sequence" + assert response.stop_sequence == "END" + # The stop sequence itself is never part of the emitted text. + text_blocks = [b for b in response.content if b.type == "text"] + assert "END" not in "".join(b.text for b in text_blocks) + + async def test_natural_eos_still_reports_end_turn(self): + chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = [ + TokenChunk(model=MODEL, text="Hello", token_id=1, usage=None), + TokenChunk( + model=MODEL, text="", token_id=2, usage=None, finish_reason="stop" + ), + ] + response = await _collect_response(chunks) + + assert response.stop_reason == "end_turn" + assert response.stop_sequence is None + + async def test_length_limit_reports_max_tokens(self): + chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = [ + TokenChunk( + model=MODEL, + text="Hello", + token_id=1, + usage=None, + finish_reason="length", + ), + ] + response = await _collect_response(chunks) + + assert response.stop_reason == "max_tokens" + assert response.stop_sequence is None + + +class TestStreamingClaudeResponseStopSequence: + async def _stream_events( + self, chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] + ) -> list[dict[str, Any]]: + events: list[str] = [] + async for event in generate_claude_stream( + COMMAND_ID, MODEL, _chunks_to_stream(chunks) + ): + events.append(event) + return _parse_sse_events(events) + + async def test_streaming_matched_stop_sequence(self): + chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = [ + TokenChunk(model=MODEL, text="ABC", token_id=1, usage=None), + TokenChunk( + model=MODEL, + text="", + token_id=2, + usage=None, + finish_reason="stop", + matched_stop_sequence="END", + ), + ] + parsed = await self._stream_events(chunks) + + message_deltas = [p for p in parsed if p.get("type") == "message_delta"] + assert len(message_deltas) == 1 + delta = cast(dict[str, Any], message_deltas[0]["delta"]) + assert delta["stop_reason"] == "stop_sequence" + assert delta["stop_sequence"] == "END" + + async def test_streaming_natural_eos(self): + chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = [ + TokenChunk(model=MODEL, text="Hi", token_id=1, usage=None), + TokenChunk( + model=MODEL, text="", token_id=2, usage=None, finish_reason="stop" + ), + ] + parsed = await self._stream_events(chunks) + + message_deltas = [p for p in parsed if p.get("type") == "message_delta"] + delta = cast(dict[str, Any], message_deltas[0]["delta"]) + assert delta["stop_reason"] == "end_turn" + assert delta["stop_sequence"] is None diff --git a/src/exo/shared/types/chunks.py b/src/exo/shared/types/chunks.py index 82425d9f9c..2d6757c63d 100644 --- a/src/exo/shared/types/chunks.py +++ b/src/exo/shared/types/chunks.py @@ -29,6 +29,9 @@ class TokenChunk(BaseChunk): logprob: float | None = None top_logprobs: list[TopLogprobItem] | None = None is_thinking: bool = False + # Set to the matched stop sequence when generation stopped because a + # user-supplied stop sequence was hit (distinguishes it from a natural EOS). + matched_stop_sequence: str | None = None class ErrorChunk(BaseChunk): diff --git a/src/exo/shared/types/worker/runner_response.py b/src/exo/shared/types/worker/runner_response.py index 9fb3301904..e25755740c 100644 --- a/src/exo/shared/types/worker/runner_response.py +++ b/src/exo/shared/types/worker/runner_response.py @@ -25,6 +25,9 @@ class GenerationResponse(BaseRunnerResponse): stats: GenerationStats | None = None usage: Usage | None is_thinking: bool = False + # The stop sequence that terminated generation, when ``finish_reason`` is + # "stop" because a user-supplied stop sequence matched (vs. a natural EOS). + matched_stop_sequence: str | None = None class ImageGenerationResponse(BaseRunnerResponse): diff --git a/src/exo/worker/engines/mlx/generator/batch_generate.py b/src/exo/worker/engines/mlx/generator/batch_generate.py index 5c7394ddec..8fa14ca106 100644 --- a/src/exo/worker/engines/mlx/generator/batch_generate.py +++ b/src/exo/worker/engines/mlx/generator/batch_generate.py @@ -41,6 +41,7 @@ prefill, ) from exo.worker.engines.mlx.generator.remote_prefill import remote_prefill +from exo.worker.engines.mlx.generator.stop_sequences import scan_stop_sequences from exo.worker.engines.mlx.patches.opt_batch_gen import ( set_needs_topk, take_ready_topk, @@ -369,28 +370,29 @@ def step(self) -> list[tuple[int, GenerationResponse]]: f"[bench] uid={response.uid} tok#{state.completion_tokens} {text!r} t={delta:.4f}s" ) state.generated_text_parts.append(text) - state.potential_stop_sequence_text += text - finish_reason: FinishReason | None = cast( - FinishReason | None, response.finish_reason - ) + # Hold back any trailing partial stop-sequence match so a multi-token + # stop sequence never leaks its leading bytes into output. + model_finish_reason = cast(FinishReason | None, response.finish_reason) task_params = state.task_params stop_sequences = _stop_sequences(task_params) - max_stop_len = max((len(s) for s in stop_sequences), default=0) - - if stop_sequences: - for stop_seq in stop_sequences: - if stop_seq in state.potential_stop_sequence_text: - stop_index = state.potential_stop_sequence_text.find(stop_seq) - text_before_stop = state.potential_stop_sequence_text[ - :stop_index - ] - chunk_start = len(state.potential_stop_sequence_text) - len( - text - ) - text = text_before_stop[chunk_start:] - finish_reason = "stop" - break + + state.potential_stop_sequence_text += text + text, matched_stop_sequence, state.potential_stop_sequence_text = ( + scan_stop_sequences(state.potential_stop_sequence_text, stop_sequences) + ) + + finish_reason: FinishReason | None + if matched_stop_sequence is not None: + finish_reason = "stop" + elif model_finish_reason is not None: + # Natural EOS / length limit: flush any held-back text — it is + # real output that merely looked like a stop-sequence prefix. + text += state.potential_stop_sequence_text + state.potential_stop_sequence_text = "" + finish_reason = model_finish_reason + else: + finish_reason = None is_done = finish_reason is not None @@ -457,19 +459,13 @@ def step(self) -> list[tuple[int, GenerationResponse]]: finish_reason=finish_reason, stats=stats, usage=usage, + matched_stop_sequence=matched_stop_sequence, ), ) ) if is_done: del self._active_tasks[response.uid] - elif ( - max_stop_len > 0 - and len(state.potential_stop_sequence_text) > max_stop_len - ): - state.potential_stop_sequence_text = state.potential_stop_sequence_text[ - -max_stop_len: - ] _step_elapsed = time.perf_counter() - _step_tic _overhead = _step_elapsed - _next_elapsed diff --git a/src/exo/worker/engines/mlx/generator/generate.py b/src/exo/worker/engines/mlx/generator/generate.py index 2e3d051251..10e23382c0 100644 --- a/src/exo/worker/engines/mlx/generator/generate.py +++ b/src/exo/worker/engines/mlx/generator/generate.py @@ -56,6 +56,7 @@ MAX_TOKENS, ) from exo.worker.engines.mlx.generator.remote_prefill import remote_prefill +from exo.worker.engines.mlx.generator.stop_sequences import scan_stop_sequences from exo.worker.engines.mlx.types import KVCacheType, Model from exo.worker.engines.mlx.utils_mlx import ( apply_chat_template, @@ -624,7 +625,6 @@ def mlx_generate( if task.stop is not None else [] ) - max_stop_len = max((len(s) for s in stop_sequences), default=0) maybe_vision_ctx = ( patch_embed_tokens( @@ -710,7 +710,9 @@ def mlx_generate( last_token = prompt_tokens[-2:] max_tokens = task.max_output_tokens or MAX_TOKENS - accumulated_text = "" + # Text decoded but not yet emitted because it could be the start of a stop + # sequence spanning multiple tokens. See scan_stop_sequences. + pending_stop_text = "" generated_text_parts: list[str] = [] generation_start_time = time.perf_counter() usage: Usage | None = None @@ -733,26 +735,26 @@ def mlx_generate( start=1, ): generated_text_parts.append(out.text) - accumulated_text += out.text - # Check for stop sequences - text = out.text - finish_reason: FinishReason | None = cast( - FinishReason | None, out.finish_reason + # Check for stop sequences, holding back any trailing partial match so a + # multi-token stop sequence never leaks its leading bytes into output. + model_finish_reason = cast(FinishReason | None, out.finish_reason) + pending_stop_text += out.text + text, matched_stop_sequence, pending_stop_text = scan_stop_sequences( + pending_stop_text, stop_sequences ) - stop_matched = False - - if stop_sequences: - for stop_seq in stop_sequences: - if stop_seq in accumulated_text: - # Trim text to just before the stop sequence - stop_index = accumulated_text.find(stop_seq) - text_before_stop = accumulated_text[:stop_index] - chunk_start = len(accumulated_text) - len(out.text) - text = text_before_stop[chunk_start:] - finish_reason = "stop" - stop_matched = True - break + + finish_reason: FinishReason | None + if matched_stop_sequence is not None: + finish_reason = "stop" + elif model_finish_reason is not None: + # Natural EOS / length limit: flush any held-back text — it is real + # output that merely looked like the start of a stop sequence. + text += pending_stop_text + pending_stop_text = "" + finish_reason = model_finish_reason + else: + finish_reason = None is_done = finish_reason is not None @@ -765,7 +767,9 @@ def mlx_generate( generation_tokens=int(out.generation_tokens), peak_memory_usage=Memory.from_gb(out.peak_memory), ) - if not stop_matched and out.finish_reason not in get_args(FinishReason): + if matched_stop_sequence is None and out.finish_reason not in get_args( + FinishReason + ): logger.warning( f"Model generated unexpected finish_reason: {out.finish_reason}" ) @@ -816,12 +820,9 @@ def mlx_generate( finish_reason=finish_reason, stats=stats, usage=usage, + matched_stop_sequence=matched_stop_sequence, ) if is_done: mx_barrier(group) break - - # Limit accumulated_text to what's needed for stop sequence detection - if max_stop_len > 0 and len(accumulated_text) > max_stop_len: - accumulated_text = accumulated_text[-max_stop_len:] diff --git a/src/exo/worker/engines/mlx/generator/stop_sequences.py b/src/exo/worker/engines/mlx/generator/stop_sequences.py new file mode 100644 index 0000000000..f7c36b012d --- /dev/null +++ b/src/exo/worker/engines/mlx/generator/stop_sequences.py @@ -0,0 +1,60 @@ +"""Streaming-safe stop-sequence scanning. + +Stop sequences can span multiple decoded tokens, so a naive substring check on +each token's text leaks the leading bytes of a sequence before the whole +sequence has been generated (e.g. ``"END"`` arriving as ``"E"`` then ``"ND"`` +would stream the ``"E"`` before the match is recognised). The scanner below +holds back any trailing text that might still grow into a stop sequence until it +is known to be safe to emit. +""" + + +def scan_stop_sequences( + buffer: str, + stop_sequences: list[str], +) -> tuple[str, str | None, str]: + """Split a running text buffer around stop sequences. + + The caller appends each newly decoded chunk of text to the ``pending`` value + returned by the previous call and passes the result back in as ``buffer``. + + Returns ``(emit, matched, pending)``: + + - ``emit``: text that is safe to stream now. It contains everything before + the first fully matched stop sequence and excludes any trailing partial + match that could still grow into a stop sequence on the next token. + - ``matched``: the stop sequence that was fully matched, or ``None``. + - ``pending``: a trailing partial match to carry into the next call. Empty + when a sequence matched or when nothing needs to be held back. Bounded by + the length of the longest stop sequence minus one. + """ + # Ignore empty stop sequences — they would match everywhere and stop + # generation immediately with no output. + active = [stop_sequence for stop_sequence in stop_sequences if stop_sequence] + if not active: + return buffer, None, "" + + # Earliest fully matched stop sequence wins: generation stops there. + earliest_index = len(buffer) + matched: str | None = None + for stop_sequence in active: + index = buffer.find(stop_sequence) + if index != -1 and index < earliest_index: + earliest_index = index + matched = stop_sequence + if matched is not None: + return buffer[:earliest_index], matched, "" + + # No full match: hold back the longest suffix of the buffer that is a proper + # prefix of any stop sequence, since the next token might complete it. + hold = 0 + for stop_sequence in active: + max_partial = min(len(stop_sequence) - 1, len(buffer)) + for length in range(max_partial, hold, -1): + if stop_sequence.startswith(buffer[-length:]): + hold = length + break + if hold == 0: + return buffer, None, "" + split = len(buffer) - hold + return buffer[:split], None, buffer[split:] diff --git a/src/exo/worker/engines/mlx/generator/tests/test_stop_sequences.py b/src/exo/worker/engines/mlx/generator/tests/test_stop_sequences.py new file mode 100644 index 0000000000..06b5a1d4bf --- /dev/null +++ b/src/exo/worker/engines/mlx/generator/tests/test_stop_sequences.py @@ -0,0 +1,98 @@ +"""Tests for streaming-safe stop-sequence scanning. + +These cover the multi-token leak bug: a stop sequence spanning more than one +decoded token must not leak its leading bytes into the output before the whole +sequence is recognised. +""" + +from exo.worker.engines.mlx.generator.stop_sequences import scan_stop_sequences + + +def _run_stream(tokens: list[str], stop_sequences: list[str]) -> tuple[str, str | None]: + """Mirror the generator loop's use of scan_stop_sequences. + + Feeds ``tokens`` one at a time, accumulating the emitted text exactly as + generate.py / batch_generate.py do, and flushing any held-back text on a + natural end-of-stream. Returns ``(emitted_text, matched_stop_sequence)``. + """ + pending = "" + emitted: list[str] = [] + matched: str | None = None + for token in tokens: + pending += token + emit, hit, pending = scan_stop_sequences(pending, stop_sequences) + emitted.append(emit) + if hit is not None: + matched = hit + break + else: + # Natural EOS: flush whatever was held back — it is real output. + emitted.append(pending) + return "".join(emitted), matched + + +class TestScanStopSequencesContract: + def test_no_stop_sequences_passes_through(self): + assert scan_stop_sequences("hello", []) == ("hello", None, "") + + def test_empty_stop_sequence_is_ignored(self): + assert scan_stop_sequences("hello", [""]) == ("hello", None, "") + + def test_full_match_trims_and_reports(self): + assert scan_stop_sequences("abEND", ["END"]) == ("ab", "END", "") + + def test_match_inside_single_buffer(self): + assert scan_stop_sequences("Hello END world", ["END"]) == ( + "Hello ", + "END", + "", + ) + + def test_partial_suffix_is_held_back(self): + # "EN" could still grow into "END" on the next token, so hold it. + assert scan_stop_sequences("abEN", ["END"]) == ("ab", None, "EN") + + def test_non_matching_tail_is_not_held(self): + assert scan_stop_sequences("abXY", ["END"]) == ("abXY", None, "") + + def test_earliest_full_match_wins(self): + emit, matched, pending = scan_stop_sequences("aBBBcQQQ", ["QQQ", "BBB"]) + assert (emit, matched, pending) == ("a", "BBB", "") + + def test_longest_partial_across_sequences_is_held(self): + # "XY" is a prefix of "XYZ"; "Q" is unrelated. Hold "XY". + assert scan_stop_sequences("aXY", ["XYZ", "Q"]) == ("a", None, "XY") + + +class TestScanStopSequencesStreaming: + def test_multi_token_stop_does_not_leak(self): + # "END" arriving as "E" then "ND" must not emit the "E". + assert _run_stream(["E", "ND"], ["END"]) == ("", "END") + + def test_text_then_multi_token_stop(self): + assert _run_stream(["ABC", "DEF", "E", "ND"], ["END"]) == ("ABCDEF", "END") + + def test_single_token_stop(self): + assert _run_stream(["END"], ["END"]) == ("", "END") + + def test_false_partial_is_released_on_continuation(self): + # "E" is held, then "X" arrives proving it was not a stop sequence. + assert _run_stream(["E", "X", "tra"], ["END"]) == ("EXtra", None) + + def test_held_partial_is_flushed_on_natural_eos(self): + # Generation ends while a partial match is buffered: it is real output. + assert _run_stream(["partialEN"], ["END"]) == ("partialEN", None) + + def test_no_stop_present(self): + assert _run_stream(["foo", "bar"], ["END"]) == ("foobar", None) + + def test_stop_split_across_three_tokens(self): + assert _run_stream(["12", "34", "56"], ["3456"]) == ("12", "3456") + + def test_multiple_stops_second_matches(self): + assert _run_stream(["A", "B"], ["XYZ", "B"]) == ("A", "B") + + def test_pending_is_bounded_by_longest_stop(self): + # However long the run, never hold back more than len(stop) - 1. + _emit, _matched, pending = scan_stop_sequences("x" * 100 + "EN", ["END"]) + assert pending == "EN" diff --git a/src/exo/worker/runner/llm_inference/model_output_parsers.py b/src/exo/worker/runner/llm_inference/model_output_parsers.py index 4952688dce..3329c95079 100644 --- a/src/exo/worker/runner/llm_inference/model_output_parsers.py +++ b/src/exo/worker/runner/llm_inference/model_output_parsers.py @@ -142,6 +142,7 @@ def map_responses_to_chunks( logprob=response.logprob, top_logprobs=response.top_logprobs, is_thinking=response.is_thinking, + matched_stop_sequence=response.matched_stop_sequence, ) case ToolCallResponse(): return ToolCallChunk(