diff --git a/pyproject.toml b/pyproject.toml index 787073d2c..9a0c84d03 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,6 +63,7 @@ dependencies = [ "loguru>=0.7.0", "tomli-w>=1.0.0", "renderers>=0.1.8.dev40", + "google-genai>=2.8.0", ] [dependency-groups] diff --git a/tests/v1/test_clients.py b/tests/v1/test_clients.py new file mode 100644 index 000000000..1b208e2bf --- /dev/null +++ b/tests/v1/test_clients.py @@ -0,0 +1,983 @@ +from unittest.mock import AsyncMock, Mock + +import pytest +import verifiers.v1 as vf +from anthropic.types import Message as AnthropicMessage +from google.genai import types as google_types +from openai.types.chat import ChatCompletion +from openai.types.responses import Response as OpenAIResponse +from pydantic import TypeAdapter + +from verifiers.v1 import graph +from verifiers.v1.clients.anthropic import ( + AnthropicMessagesClient, + content_to_wire as anthropic_content, + messages_to_wire as anthropic_messages, +) +from verifiers.v1.clients.anthropic import ( + response_from_wire as anthropic_response, +) +from verifiers.v1.clients.google import ( + GoogleResponsesClient, + content_to_wire as google_content, +) +from verifiers.v1.clients.google import messages_to_wire as google_messages +from verifiers.v1.clients.google import response_from_wire as google_response +from verifiers.v1.clients.openai import ( + OpenAIChatCompletionsClient, + content_to_wire as chat_content, + message_to_wire as chat_message, +) +from verifiers.v1.clients.openai_responses import ( + OpenAIResponsesClient, + content_to_wire as responses_content, +) +from verifiers.v1.clients.openai_responses import message_to_wire as responses_message +from verifiers.v1.clients.openai_responses import ( + response_from_wire as responses_response, +) +from verifiers.v1.interception.server import parse_message, serialize_completion +from verifiers.v1.types import content_to_parts + + +class AsyncItems: + def __init__(self, items=(), final=None): + self.items = 
iter(items) + self.final = final + + def __aiter__(self): + return self + + async def __anext__(self): + try: + return next(self.items) + except StopIteration: + raise StopAsyncIteration + + async def __aenter__(self): + return self + + async def __aexit__(self, *exc): + return None + + async def get_final_message(self): + return self.final + + async def get_final_completion(self): + return self.final + + +def round_trip(response: vf.Response, prompt: vf.Messages) -> vf.AssistantMessage: + trace = vf.Trace(task=vf.Task(idx=0, instruction="use a tool")) + graph.add_turn(trace, prompt, response) + return vf.Trace.model_validate(trace.to_wire()).assistant_messages[0] + + +def image(url: str, detail: str | None = None) -> vf.ImageUrlContentPart: + return vf.ImageUrlContentPart( + image_url=vf.ImageUrlSource.model_validate({"url": url, "detail": detail}) + ) + + +def test_image_detail_round_trips_to_openai_clients(): + content = content_to_parts( + [ + { + "type": "image_url", + "image_url": {"url": "https://example.com/image.png", "detail": "high"}, + } + ] + ) + + assert isinstance(content, list) + assert chat_content(content)[0]["image_url"]["detail"] == "high" + assert responses_content(content)[0]["detail"] == "high" + assert chat_content([image("data:image/png;base64,aW1hZ2U=")])[0]["image_url"][ + "url" + ].startswith("data:image/png;base64,") + assert responses_content([image("data:image/png;base64,aW1hZ2U=")])[0][ + "image_url" + ].startswith("data:image/png;base64,") + assert ( + responses_content([image("https://example.com/image.png", "original")])[0][ + "detail" + ] + == "original" + ) + with pytest.raises(ValueError, match="does not support image detail"): + chat_content([image("https://example.com/image.png", "original")]) + + +def test_openai_system_image_support_matches_native_apis(): + system = vf.SystemMessage(content=[vf.TextContentPart(text="System")]) + assert chat_message(system)["content"] == [{"type": "text", "text": "System"}] + + system = 
vf.SystemMessage(content=[image("data:image/png;base64,aW1hZ2U=")]) + with pytest.raises(ValueError, match="system messages do not support images"): + chat_message(system) + assert responses_message(system)[0]["content"][0]["type"] == "input_image" + + +def test_anthropic_uses_native_url_and_base64_image_sources(): + content = anthropic_content( + [ + image("https://example.com/image.png"), + image("data:IMAGE/PNG;charset=utf-8;BASE64,aW1hZ2U="), + ] + ) + + assert isinstance(content, list) + assert content[0]["source"] == { + "type": "url", + "url": "https://example.com/image.png", + } + assert content[1]["source"] == { + "type": "base64", + "media_type": "image/png", + "data": "aW1hZ2U=", + } + with pytest.raises(ValueError, match="must be base64 encoded"): + anthropic_content([image("data:image/png,image")]) + + +def test_google_uses_inline_images(): + parts = google_content( + [ + image( + "data:IMAGE/PNG;charset=utf-8;BASE64,aW1hZ2U=", + "high", + ) + ] + ) + + assert parts[0].inline_data is not None + assert parts[0].inline_data.mime_type == "image/png" + assert parts[0].inline_data.data == b"image" + assert parts[0].media_resolution is not None + assert parts[0].media_resolution.level == ( + google_types.PartMediaResolutionLevel.MEDIA_RESOLUTION_HIGH + ) + with pytest.raises(ValueError, match="must use data URIs"): + google_content([image("https://example.com/image.png")]) + with pytest.raises(ValueError, match="must be base64 encoded"): + google_content([image("data:image/png,image")]) + + +def test_openai_responses_preserves_native_output(): + output = [ + { + "id": "reasoning_1", + "type": "reasoning", + "summary": [{"type": "summary_text", "text": "Need weather."}], + }, + { + "type": "function_call", + "call_id": "call_1", + "name": "weather", + "arguments": '{"city":"Berlin"}', + }, + ] + response = responses_response( + OpenAIResponse.model_validate( + { + "id": "resp_1", + "created_at": 0, + "model": "gpt-test", + "object": "response", + "status": 
"completed", + "output": output, + "parallel_tool_calls": False, + "tool_choice": "auto", + "tools": [], + "usage": { + "input_tokens": 10, + "output_tokens": 5, + "total_tokens": 15, + "input_tokens_details": {"cached_tokens": 0}, + "output_tokens_details": {"reasoning_tokens": 0}, + }, + } + ) + ) + + assistant = round_trip(response, [vf.UserMessage(content="Weather?")]) + + assert responses_message(assistant) == output + assert response.finish_reason == "tool_calls" + + +def test_openai_responses_accepts_reasoning_only(): + output = [ + { + "id": "reasoning_1", + "type": "reasoning", + "status": "incomplete", + "summary": [], + } + ] + response = responses_response( + OpenAIResponse.model_validate( + { + "id": "resp_1", + "created_at": 0, + "model": "gpt-test", + "object": "response", + "status": "incomplete", + "incomplete_details": {"reason": "max_output_tokens"}, + "output": output, + "parallel_tool_calls": False, + "tool_choice": "auto", + "tools": [], + "usage": { + "input_tokens": 10, + "output_tokens": 5, + "total_tokens": 15, + "input_tokens_details": {"cached_tokens": 0}, + "output_tokens_details": {"reasoning_tokens": 5}, + }, + } + ) + ) + + assert response.message.content is None + assert response.message.reasoning_content is None + assert response.message.tool_calls is None + assert response.message.provider_state == output + assert response.finish_reason == "length" + + +def test_interception_preserves_reasoning_state(): + response = vf.Response( + id="response_1", + created=0, + model="model", + message=vf.AssistantMessage( + reasoning_content="Still working.", + provider_state=[{"type": "reasoning", "id": "reasoning_1"}], + ), + finish_reason="length", + ) + + completion = ChatCompletion.model_validate(serialize_completion(response, "model")) + raw = completion.choices[0].message.model_dump(mode="json", exclude_none=True) + message = parse_message(raw) + + assert isinstance(message, vf.AssistantMessage) + assert message.content is None + assert 
message.reasoning_content == "Still working." + assert message.provider_state == [{"type": "reasoning", "id": "reasoning_1"}] + assert message.tool_calls is None + + +def test_interception_preserves_provider_state_for_tool_continuation(): + provider_state = [ + {"type": "reasoning", "id": "reasoning_1", "summary": []}, + { + "type": "function_call", + "call_id": "call_1", + "name": "weather", + "arguments": '{"city":"Berlin"}', + }, + ] + response = vf.Response( + id="response_1", + created=0, + model="model", + message=vf.AssistantMessage( + provider_state=provider_state, + tool_calls=[ + vf.ToolCall( + id="call_1", + name="weather", + arguments='{"city":"Berlin"}', + ) + ], + ), + finish_reason="tool_calls", + ) + + completion = ChatCompletion.model_validate(serialize_completion(response, "model")) + raw = completion.choices[0].message.model_dump(mode="json", exclude_none=True) + message = parse_message(raw) + + assert isinstance(message, vf.AssistantMessage) + assert message.provider_state == provider_state + assert message.tool_calls == response.message.tool_calls + assert responses_message(message) == provider_state + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("prompt_token_ids", "expected_prompt_ids"), + [([1, 2], [1, 2]), (None, [])], +) +async def test_openai_chat_preserves_tokens_and_logprobs( + prompt_token_ids, expected_prompt_ids +): + completion = ChatCompletion.model_validate( + { + "id": "chatcmpl_1", + "created": 0, + "model": "vllm-test", + "object": "chat.completion", + "prompt_token_ids": prompt_token_ids, + "choices": [ + { + "index": 0, + "finish_reason": "stop", + "token_ids": [3, 4], + "message": { + "role": "assistant", + "content": "Hi", + "reasoning": "Think", + "reasoning_details": [ + { + "type": "reasoning.text", + "text": "Think", + "format": "anthropic-claude-v1", + "index": 0, + } + ], + }, + "logprobs": { + "content": [ + { + "token": "H", + "bytes": [72], + "logprob": -0.1, + "top_logprobs": [], + }, + { + "token": "i", 
+ "bytes": [105], + "logprob": -0.2, + "top_logprobs": [], + }, + ] + }, + } + ], + "usage": { + "prompt_tokens": 2, + "completion_tokens": 2, + "total_tokens": 4, + }, + } + ) + openai = Mock() + openai.chat.completions.create = AsyncMock(return_value=completion) + client = OpenAIChatCompletionsClient(openai) + + response = await client.get_response( + [vf.UserMessage(content="Hello")], + "vllm-test", + vf.SamplingConfig.model_validate( + { + "logprobs": True, + "extra_body": {"return_token_ids": True}, + } + ), + ) + + await_args = openai.chat.completions.create.await_args + assert await_args is not None + assert await_args.kwargs["logprobs"] is True + assert await_args.kwargs["extra_body"] == {"return_token_ids": True} + assert response.message.reasoning_content == "Think" + assert response.message.provider_state is not None + assert chat_message(response.message)["reasoning_details"] == ( + response.message.provider_state + ) + assert response.tokens is not None + assert response.tokens.prompt_ids == expected_prompt_ids + assert response.tokens.completion_ids == [3, 4] + assert response.tokens.completion_logprobs == [-0.1, -0.2] + + +@pytest.mark.asyncio +async def test_openai_chat_passes_reasoning_options_through(): + completion = ChatCompletion.model_validate( + { + "id": "chatcmpl_1", + "created": 0, + "model": "anthropic/claude-haiku-4.5", + "object": "chat.completion", + "choices": [ + { + "index": 0, + "finish_reason": "stop", + "message": {"role": "assistant", "content": "Hi"}, + } + ], + } + ) + openai = Mock() + openai.chat.completions.create = AsyncMock(return_value=completion) + client = OpenAIChatCompletionsClient(openai) + + await client.get_response( + [vf.UserMessage(content="Hello")], + "anthropic/claude-opus-4.6", + vf.SamplingConfig.model_validate( + { + "reasoning_effort": "high", + "verbosity": "max", + "extra_body": {"reasoning": {"enabled": True}}, + } + ), + ) + + request = openai.chat.completions.create.await_args.kwargs + assert 
request["reasoning_effort"] == "high" + assert request["verbosity"] == "max" + assert request["extra_body"] == {"reasoning": {"enabled": True}} + + +@pytest.mark.asyncio +async def test_openai_chat_aggregates_stream(): + completion = ChatCompletion.model_validate( + { + "id": "chatcmpl_1", + "created": 0, + "model": "vllm-test", + "object": "chat.completion", + "prompt_token_ids": [1, 2], + "choices": [ + { + "index": 0, + "finish_reason": "tool_calls", + "token_ids": [3, 4], + "message": { + "role": "assistant", + "reasoning_content": "Think more", + "tool_calls": [ + { + "id": "call_1", + "type": "function", + "function": { + "name": "weather", + "arguments": '{"city":"Berlin"}', + }, + } + ], + }, + "logprobs": { + "content": [ + { + "token": "a", + "bytes": [97], + "logprob": -0.1, + "top_logprobs": [], + }, + { + "token": "b", + "bytes": [98], + "logprob": -0.2, + "top_logprobs": [], + }, + ] + }, + } + ], + "usage": { + "prompt_tokens": 2, + "completion_tokens": 2, + "total_tokens": 4, + }, + } + ) + openai = Mock() + openai.chat.completions.stream = Mock(return_value=AsyncItems(final=completion)) + openai.chat.completions.create = AsyncMock() + client = OpenAIChatCompletionsClient(openai) + + response = await client.get_response( + [vf.UserMessage(content="Weather?")], + "vllm-test", + vf.SamplingConfig.model_validate( + { + "stream": True, + "logprobs": True, + "extra_body": {"return_token_ids": True}, + } + ), + ) + + request = openai.chat.completions.stream.call_args.kwargs + assert "stream" not in request + assert request["stream_options"]["include_usage"] is True + assert response.message.reasoning_content == "Think more" + assert response.message.tool_calls == [ + vf.ToolCall( + id="call_1", + name="weather", + arguments='{"city":"Berlin"}', + ) + ] + assert response.finish_reason == "tool_calls" + assert response.usage == vf.Usage(prompt_tokens=2, completion_tokens=2) + assert response.tokens is not None + assert response.tokens.prompt_ids == [1, 2] + 
assert response.tokens.completion_ids == [3, 4] + assert response.tokens.completion_logprobs == [-0.1, -0.2] + openai.chat.completions.create.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_openai_responses_aggregates_stream(): + final = OpenAIResponse.model_validate( + { + "id": "resp_1", + "created_at": 0, + "model": "gpt-test", + "object": "response", + "status": "incomplete", + "incomplete_details": {"reason": "max_output_tokens"}, + "output": [ + { + "id": "msg_1", + "type": "message", + "role": "assistant", + "status": "incomplete", + "content": [ + {"type": "output_text", "text": "Hi", "annotations": []} + ], + } + ], + "parallel_tool_calls": False, + "tool_choice": "auto", + "tools": [], + "usage": { + "input_tokens": 1, + "output_tokens": 1, + "total_tokens": 2, + "input_tokens_details": {"cached_tokens": 0}, + "output_tokens_details": {"reasoning_tokens": 0}, + }, + } + ) + event = Mock(type="response.incomplete", response=final) + openai = Mock() + openai.responses.create = AsyncMock(return_value=AsyncItems([event])) + client = OpenAIResponsesClient(openai) + + response = await client.get_response( + [vf.UserMessage(content="Hello")], + "gpt-test", + vf.SamplingConfig.model_validate({"stream": True}), + ) + + assert openai.responses.create.await_args.kwargs["stream"] is True + assert response.message.content == "Hi" + assert response.finish_reason == "length" + assert response.usage == vf.Usage(prompt_tokens=1, completion_tokens=1) + + +def test_anthropic_preserves_thinking_blocks(): + thinking = { + "type": "thinking", + "thinking": "Need weather.", + "signature": "signed-thinking", + } + response = anthropic_response( + AnthropicMessage.model_validate( + { + "id": "msg_1", + "model": "claude-test", + "role": "assistant", + "type": "message", + "stop_reason": "tool_use", + "content": [ + thinking, + { + "type": "tool_use", + "id": "call_1", + "name": "weather", + "input": {"city": "Berlin"}, + }, + ], + "usage": {"input_tokens": 10, 
"output_tokens": 5}, + } + ) + ) + + assistant = round_trip(response, [vf.UserMessage(content="Weather?")]) + prompt = anthropic_messages( + [ + assistant, + vf.ToolMessage(tool_call_id="call_1", content='{"temp":20}'), + ] + ) + + assert list(prompt[0]["content"])[0] == thinking + assert list(prompt[1]["content"])[0]["tool_use_id"] == "call_1" + + +@pytest.mark.parametrize( + ("block", "expected_reasoning"), + [ + ( + { + "type": "thinking", + "thinking": "Still working.", + "signature": "signed-thinking", + }, + "Still working.", + ), + ({"type": "redacted_thinking", "data": "opaque"}, None), + ], +) +def test_anthropic_accepts_reasoning_only(block, expected_reasoning): + response = anthropic_response( + AnthropicMessage.model_validate( + { + "id": "msg_1", + "model": "claude-test", + "role": "assistant", + "type": "message", + "stop_reason": "max_tokens", + "content": [block], + "usage": {"input_tokens": 10, "output_tokens": 5}, + } + ) + ) + + assert response.message.content is None + assert response.message.reasoning_content == expected_reasoning + assert response.message.tool_calls is None + assert response.message.provider_state == [block] + assert response.finish_reason == "length" + + +@pytest.mark.asyncio +async def test_anthropic_requires_max_tokens(): + anthropic = Mock() + anthropic.messages.create = AsyncMock() + client = AnthropicMessagesClient(anthropic) + + with pytest.raises(ValueError, match="requires max_tokens"): + await client.get_response( + [vf.UserMessage(content="Hello")], + "claude-test", + vf.SamplingConfig(), + ) + + +@pytest.mark.asyncio +async def test_anthropic_passes_native_options(): + anthropic = Mock() + anthropic.messages.create = AsyncMock( + return_value=AnthropicMessage.model_validate( + { + "id": "msg_1", + "model": "claude-test", + "role": "assistant", + "type": "message", + "stop_reason": "end_turn", + "content": [{"type": "text", "text": "Hi"}], + "usage": {"input_tokens": 1, "output_tokens": 1}, + } + ) + ) + client = 
AnthropicMessagesClient(anthropic) + + await client.get_response( + [ + vf.SystemMessage(content=[vf.TextContentPart(text="System")]), + vf.UserMessage(content="Hello"), + ], + "claude-test", + vf.SamplingConfig.model_validate( + { + "max_tokens": 100, + "thinking": {"type": "adaptive"}, + "output_config": {"effort": "high"}, + } + ), + ) + + await_args = anthropic.messages.create.await_args + assert await_args is not None + request = await_args.kwargs + assert request["system"] == "System" + assert request["max_tokens"] == 100 + assert request["thinking"] == {"type": "adaptive"} + assert request["output_config"] == {"effort": "high"} + + with pytest.raises(ValueError, match="system messages do not support images"): + await client.get_response( + [ + vf.SystemMessage(content=[image("data:image/png;base64,aW1hZ2U=")]), + vf.UserMessage(content="Hello"), + ], + "claude-test", + vf.SamplingConfig(max_tokens=100), + ) + + +@pytest.mark.asyncio +async def test_anthropic_aggregates_stream(): + final = AnthropicMessage.model_validate( + { + "id": "msg_1", + "model": "claude-test", + "role": "assistant", + "type": "message", + "stop_reason": "end_turn", + "content": [{"type": "text", "text": "Hi"}], + "usage": {"input_tokens": 1, "output_tokens": 1}, + } + ) + anthropic = Mock() + anthropic.messages.stream = Mock(return_value=AsyncItems(final=final)) + anthropic.messages.create = AsyncMock() + client = AnthropicMessagesClient(anthropic) + + response = await client.get_response( + [vf.UserMessage(content="Hello")], + "claude-test", + vf.SamplingConfig.model_validate({"max_tokens": 100, "stream": True}), + ) + + request = anthropic.messages.stream.call_args.kwargs + assert "stream" not in request + assert response.message.content == "Hi" + assert response.finish_reason == "stop" + anthropic.messages.create.assert_not_awaited() + + +def test_google_preserves_thought_signatures(): + part = { + "functionCall": { + "id": "call_1", + "name": "weather", + "args": {"city": 
"Berlin"}, + }, + "thoughtSignature": "c2lnbmVkLXRob3VnaHQ=", + } + response = google_response( + google_types.GenerateContentResponse.model_validate( + { + "responseId": "response_1", + "modelVersion": "gemini-test", + "candidates": [ + { + "finishReason": "STOP", + "content": {"role": "model", "parts": [part]}, + } + ], + "usageMetadata": {"promptTokenCount": 10, "totalTokenCount": 15}, + } + ), + "gemini-test", + ) + + assistant = round_trip(response, [vf.UserMessage(content="Weather?")]) + prompt = google_messages( + [ + assistant, + vf.ToolMessage(tool_call_id="call_1", content='{"temp":20}'), + ] + ) + + assert prompt[0].parts is not None + assert ( + prompt[0].parts[0].model_dump(mode="json", by_alias=True, exclude_none=True) + == part + ) + assert prompt[1].parts is not None + assert prompt[1].parts[0].function_response is not None + assert prompt[1].parts[0].function_response.name == "weather" + + +@pytest.mark.parametrize( + ("part", "expected_reasoning"), + [ + ({"text": "Still working.", "thought": True}, "Still working."), + ({"thoughtSignature": "c2lnbmVkLXRob3VnaHQ="}, None), + ], +) +def test_google_accepts_reasoning_only(part, expected_reasoning): + response = google_response( + google_types.GenerateContentResponse.model_validate( + { + "responseId": "response_1", + "modelVersion": "gemini-test", + "candidates": [ + { + "finishReason": "MAX_TOKENS", + "content": {"role": "model", "parts": [part]}, + } + ], + } + ), + "gemini-test", + ) + + assert response.message.content is None + assert response.message.reasoning_content == expected_reasoning + assert response.message.tool_calls is None + assert response.message.provider_state == [part] + assert response.finish_reason == "length" + + +@pytest.mark.asyncio +async def test_google_uses_native_config(): + response = google_types.GenerateContentResponse.model_validate( + { + "responseId": "response_1", + "modelVersion": "gemini-test", + "candidates": [ + { + "finishReason": "STOP", + "content": { + 
"role": "model", + "parts": [{"text": "Hi"}], + }, + } + ], + } + ) + google = Mock() + google.aio.models.generate_content = AsyncMock(return_value=response) + client = GoogleResponsesClient(google) + + await client.get_response( + [ + vf.SystemMessage(content="System"), + vf.UserMessage(content="Hello"), + ], + "google/gemini-test", + vf.SamplingConfig.model_validate( + { + "max_tokens": 100, + "stop": "done", + "top_k": 20, + "logprobs": True, + "top_logprobs": 5, + "thinking_config": {"thinking_budget": 1000}, + } + ), + tools=[ + vf.Tool( + name="weather", + description="Get weather", + parameters={"type": "object"}, + ) + ], + ) + + await_args = google.aio.models.generate_content.await_args + assert await_args is not None + request = await_args.kwargs + assert request["model"] == "gemini-test" + assert request["config"].max_output_tokens == 100 + assert request["config"].stop_sequences == ["done"] + assert request["config"].top_k == 20 + assert request["config"].response_logprobs is True + assert request["config"].logprobs == 5 + assert request["config"].system_instruction == "System" + assert request["config"].thinking_config.thinking_budget == 1000 + assert request["config"].tools[0].function_declarations[0].name == "weather" + + await client.get_response( + [ + vf.SystemMessage(content=[image("data:image/png;base64,aW1hZ2U=")]), + vf.UserMessage(content="Hello"), + ], + "google/gemini-test", + vf.SamplingConfig(), + ) + + system_instruction = google.aio.models.generate_content.await_args.kwargs[ + "config" + ].system_instruction + assert isinstance(system_instruction, list) + assert system_instruction[0].inline_data is not None + assert system_instruction[0].inline_data.data == b"image" + + +@pytest.mark.asyncio +async def test_google_aggregates_stream(): + chunks = [ + google_types.GenerateContentResponse.model_validate( + { + "responseId": "response_1", + "modelVersion": "gemini-test", + "candidates": [ + { + "index": 0, + "content": { + "role": "model", 
+ "parts": [ + { + "text": "Think", + "thought": True, + "thoughtSignature": "c2lnbmVkLXRob3VnaHQ=", + } + ], + }, + } + ], + } + ), + google_types.GenerateContentResponse.model_validate( + { + "responseId": "response_1", + "modelVersion": "gemini-test", + "candidates": [ + { + "index": 0, + "finishReason": "STOP", + "content": { + "role": "model", + "parts": [{"text": "Hi"}], + }, + } + ], + "usageMetadata": { + "promptTokenCount": 1, + "candidatesTokenCount": 1, + "thoughtsTokenCount": 2, + "toolUsePromptTokenCount": 3, + "totalTokenCount": 7, + }, + } + ), + ] + google = Mock() + google.aio.models.generate_content_stream = AsyncMock( + return_value=AsyncItems(chunks) + ) + google.aio.models.generate_content = AsyncMock() + client = GoogleResponsesClient(google) + + response = await client.get_response( + [vf.UserMessage(content="Hello")], + "google/gemini-test", + vf.SamplingConfig.model_validate({"stream": True}), + ) + + request = google.aio.models.generate_content_stream.await_args.kwargs + assert "stream" not in request["config"].model_dump(exclude_none=True) + assert response.message.reasoning_content == "Think" + assert response.message.content == "Hi" + assert response.message.provider_state is not None + assert response.message.provider_state[0]["thoughtSignature"] == ( + "c2lnbmVkLXRob3VnaHQ=" + ) + assert response.finish_reason == "stop" + assert response.usage == vf.Usage(prompt_tokens=1, completion_tokens=1) + google.aio.models.generate_content.assert_not_awaited() + + +def test_client_config_protocols(): + adapter = TypeAdapter(vf.ClientConfig) + for protocol in ( + "openai", + "openai_responses", + "anthropic_messages", + "google_responses", + "renderers", + ): + assert adapter.validate_python({"type": protocol}).type == protocol diff --git a/tests/v1/test_graph.py b/tests/v1/test_graph.py index f9d6ddc6d..b3debb32f 100644 --- a/tests/v1/test_graph.py +++ b/tests/v1/test_graph.py @@ -1,3 +1,4 @@ +import pytest import verifiers.v1 as vf from 
verifiers.v1 import graph @@ -88,3 +89,32 @@ def test_prompt_supplied_assistant_messages_are_not_sampled_turns(): assert [n.sampled for n in trace.nodes] == [False, False, False, True] assert trace.num_turns == 1 assert trace.assistant_messages == [response] + + +@pytest.mark.parametrize( + "message", + [ + vf.AssistantMessage(content="answer"), + vf.AssistantMessage(reasoning_content="thinking"), + vf.AssistantMessage(provider_state=[{"type": "reasoning"}]), + vf.AssistantMessage( + tool_calls=[vf.ToolCall(id="call_0", name="lookup", arguments="{}")] + ), + ], +) +def test_trace_has_response_for_any_model_output(message): + trace = vf.Trace(task=vf.Task(idx=0, instruction="question")) + graph.add_turn(trace, [vf.UserMessage(content="question")], _response(message)) + + assert trace.has_response + + +def test_trace_has_no_response_for_empty_assistant_message(): + trace = vf.Trace(task=vf.Task(idx=0, instruction="question")) + graph.add_turn( + trace, + [vf.UserMessage(content="question")], + _response(vf.AssistantMessage()), + ) + + assert not trace.has_response diff --git a/verifiers/v1/README.md b/verifiers/v1/README.md index c4862a06f..e0b46ab50 100644 --- a/verifiers/v1/README.md +++ b/verifiers/v1/README.md @@ -160,10 +160,13 @@ uv run eval wiki-search-v1 -n 1 --harness.id compact # fresh prompt each turn ### Clients The client sits *behind* the interception server, so the harness only ever speaks plain -chat-completions: +chat-completions. 
The model endpoint can use any built-in provider protocol: ```bash -uv run eval gsm8k-v1 -n 1 # openai (default): text in / text out +uv run eval gsm8k-v1 -n 1 # OpenAI chat completions (default) +uv run eval gsm8k-v1 -n 1 --client.type openai_responses +uv run eval gsm8k-v1 -n 1 --client.type anthropic_messages --sampling.max-tokens 16384 +uv run eval gsm8k-v1 -n 1 --client.type google_responses uv run eval gsm8k-v1 -n 1 --client.type renderers \ # renderers: client-side tokenization → --client.base-url http://localhost:8000/v1 # token-in/out traces (needs a vLLM engine) ``` diff --git a/verifiers/v1/clients/__init__.py b/verifiers/v1/clients/__init__.py index d21301ad7..045cee36e 100644 --- a/verifiers/v1/clients/__init__.py +++ b/verifiers/v1/clients/__init__.py @@ -1,14 +1,20 @@ -"""The client abstraction and its OpenAI-compatible + renderer implementations.""" +"""The client abstraction and built-in provider implementations.""" +from verifiers.v1.clients.anthropic import AnthropicMessagesClient from verifiers.v1.clients.client import Client, RetryingClient, RolloutContext from verifiers.v1.clients.config import ( + AnthropicMessagesClientConfig, BaseClientConfig, ClientConfig, + GoogleResponsesClientConfig, OpenAIClientConfig, + OpenAIResponsesClientConfig, RendererClientConfig, resolve_client, ) +from verifiers.v1.clients.google import GoogleResponsesClient from verifiers.v1.clients.openai import OpenAIChatCompletionsClient +from verifiers.v1.clients.openai_responses import OpenAIResponsesClient from verifiers.v1.clients.renderer import RendererClient __all__ = [ @@ -18,8 +24,14 @@ "BaseClientConfig", "ClientConfig", "OpenAIClientConfig", + "OpenAIResponsesClientConfig", + "AnthropicMessagesClientConfig", + "GoogleResponsesClientConfig", "RendererClientConfig", "resolve_client", "OpenAIChatCompletionsClient", + "OpenAIResponsesClient", + "AnthropicMessagesClient", + "GoogleResponsesClient", "RendererClient", ] diff --git a/verifiers/v1/clients/anthropic.py 
b/verifiers/v1/clients/anthropic.py new file mode 100644 index 000000000..419691410 --- /dev/null +++ b/verifiers/v1/clients/anthropic.py @@ -0,0 +1,237 @@ +"""Anthropic Messages API client.""" + +import json +import time +from typing import Any, cast + +from anthropic import AnthropicError, AsyncAnthropic +from anthropic.types import ( + Base64ImageSourceParam, + ContentBlockParam, + ImageBlockParam, + Message as AnthropicMessage, + MessageParam, + RedactedThinkingBlock, + TextBlock, + TextBlockParam, + ThinkingBlock, + ToolParam as AnthropicTool, + ToolResultBlockParam, + ToolUseBlock, + ToolUseBlockParam, + URLImageSourceParam, +) + +from verifiers.v1.clients.client import Client +from verifiers.v1.errors import ModelError +from verifiers.v1.types import ( + AssistantMessage, + FinishReason, + Message, + Messages, + Response, + SamplingConfig, + SystemMessage, + TextContentPart, + Tool, + ToolCall, + ToolMessage, + Usage, + UserMessage, +) + + +FINISH_REASONS: dict[str, FinishReason] = { + "end_turn": "stop", + "stop_sequence": "stop", + "max_tokens": "length", + "tool_use": "tool_calls", + "refusal": "stop", +} + + +def content_to_wire(content) -> str | list[ContentBlockParam]: + if isinstance(content, str): + return content + parts: list[ContentBlockParam] = [] + for part in content: + if isinstance(part, TextContentPart): + parts.append(TextBlockParam(type="text", text=part.text)) + continue + url = part.image_url.url + if url.startswith("data:"): + metadata, data = url.removeprefix("data:").split(",", 1) + media_type, *parameters = metadata.split(";") + if not any(parameter.lower() == "base64" for parameter in parameters): + raise ValueError("Anthropic image data URIs must be base64 encoded") + source: Base64ImageSourceParam | URLImageSourceParam = ( + Base64ImageSourceParam( + type="base64", media_type=cast(Any, media_type.lower()), data=data + ) + ) + else: + source = URLImageSourceParam(type="url", url=url) + parts.append(ImageBlockParam(type="image", 
def system_to_wire(messages: Messages) -> str:
    """Join system messages into Anthropic's top-level `system` string.

    Non-system messages are ignored. The text parts of a single system message
    are concatenated directly; separate system messages are joined by a blank
    line.

    Raises:
        ValueError: If a system message contains a non-text content part.
    """
    texts: list[str] = []
    for message in messages:
        if not isinstance(message, SystemMessage):
            continue
        if isinstance(message.content, str):
            texts.append(message.content)
        elif all(isinstance(part, TextContentPart) for part in message.content):
            texts.append(
                "".join(
                    part.text
                    for part in message.content
                    if isinstance(part, TextContentPart)
                )
            )
        else:
            raise ValueError("Anthropic system messages do not support images")
    return "\n\n".join(texts)


def message_to_wire(message: Message) -> MessageParam | None:
    """Convert one message to an Anthropic `MessageParam`.

    Returns `None` for system messages, which callers send through the request's
    top-level `system` field instead. A tool message becomes a user message with
    a single `tool_result` block. An assistant message replays its
    `provider_state` blocks first, then its text, then one `tool_use` block per
    tool call.
    """
    if isinstance(message, SystemMessage):
        return None
    if isinstance(message, UserMessage):
        return MessageParam(role="user", content=content_to_wire(message.content))
    if isinstance(message, ToolMessage):
        return MessageParam(
            role="user",
            content=[
                ToolResultBlockParam(
                    type="tool_result",
                    tool_use_id=message.tool_call_id,
                    content=message.content,
                )
            ],
        )
    assert isinstance(message, AssistantMessage)
    content = [cast(ContentBlockParam, block) for block in message.provider_state or []]
    if message.content:
        content.append(TextBlockParam(type="text", text=message.content))
    for call in message.tool_calls or []:
        content.append(
            ToolUseBlockParam(
                type="tool_use",
                id=call.id,
                name=call.name,
                # An argument-less call may carry an empty arguments string;
                # decode it as an empty object instead of failing in json.loads.
                input=json.loads(call.arguments or "{}"),
            )
        )
    return MessageParam(role="assistant", content=content)


def messages_to_wire(messages: Messages) -> list[MessageParam]:
    """Convert the prompt, folding consecutive tool results into one user message
    (Anthropic requires tool results as blocks of the following user turn)."""
    prompt: list[MessageParam] = []
    for message in messages:
        wire = message_to_wire(message)
        if wire is None:  # system messages go in the top-level `system` field
            continue
        last_content = prompt[-1]["content"] if prompt else None
        if (
            isinstance(message, ToolMessage)
            and isinstance(last_content, list)
            and prompt[-1]["role"] == "user"
        ):
            # Append this tool result to the block list of the previous user turn.
            last_content.extend(cast(list[ContentBlockParam], wire["content"]))
        else:
            prompt.append(wire)
    return prompt


def response_from_wire(response: AnthropicMessage) -> Response:
    """Convert an Anthropic message into a `Response`.

    Text blocks are concatenated into `content`; thinking blocks contribute
    their text to `reasoning_content`. Thinking and redacted-thinking blocks are
    kept as JSON in `provider_state` so they can be replayed on a later turn.
    `created` is stamped with the local clock via `time.time()`.

    Raises:
        ModelError: If the message has no text, thinking, or tool-use block.
    """
    content = ""
    reasoning = ""
    thinking: list[dict] = []
    tool_calls: list[ToolCall] = []
    for block in response.content:
        if isinstance(block, TextBlock):
            content += block.text
        elif isinstance(block, ThinkingBlock):
            thinking.append(block.model_dump(mode="json"))
            reasoning += block.thinking
        elif isinstance(block, RedactedThinkingBlock):
            # Redacted blocks have no readable text but must still be replayed.
            thinking.append(block.model_dump(mode="json"))
        elif isinstance(block, ToolUseBlock):
            tool_calls.append(
                ToolCall(
                    id=block.id,
                    name=block.name,
                    arguments=json.dumps(block.input),
                )
            )
    if not content and not thinking and not tool_calls:
        raise ModelError("Anthropic Messages returned no output")
    return Response(
        id=response.id,
        created=int(time.time()),
        model=response.model,
        message=AssistantMessage(
            content=content or None,
            reasoning_content=reasoning or None,
            tool_calls=tool_calls or None,
            provider_state=thinking or None,
        ),
        finish_reason=FINISH_REASONS.get(response.stop_reason or ""),
        usage=Usage(
            prompt_tokens=response.usage.input_tokens,
            completion_tokens=response.usage.output_tokens,
        ),
    )


class AnthropicMessagesClient(Client):
    """`Client` backed by the Anthropic Messages API."""

    def __init__(self, anthropic: AsyncAnthropic) -> None:
        self.anthropic = anthropic

    async def get_response(
        self,
        prompt: Messages,
        model: str,
        sampling_args: SamplingConfig,
        tools: list[Tool] | None = None,
    ) -> Response:
        """Send `prompt` to the Messages API and return the parsed response.

        `stream` selects the streaming endpoint, `n` is dropped, and `stop` is
        renamed to `stop_sequences`; all other sampling fields pass through.

        Raises:
            ValueError: If `max_tokens` is not set in `sampling_args`.
            ModelError: If the SDK raises `AnthropicError` or the reply is empty.
        """
        sampling: dict[str, Any] = sampling_args.model_dump(exclude_none=True)
        streaming = bool(sampling.pop("stream", False))
        sampling.pop("n", None)  # Anthropic has no n parameter
        if "max_tokens" not in sampling:
            raise ValueError("Anthropic Messages requires max_tokens")
        if stop := sampling.pop("stop", None):
            sampling["stop_sequences"] = [stop] if isinstance(stop, str) else stop
        body: dict[str, Any] = {
            "model": model,
            "messages": messages_to_wire(prompt),
            **sampling,
        }
        if system := system_to_wire(prompt):
            body["system"] = system
        if tools:
            body["tools"] = [
                AnthropicTool(
                    name=tool.name,
                    description=tool.description,
                    input_schema=tool.parameters,
                )
                for tool in tools
            ]
        try:
            if streaming:
                # Only the aggregated final message is used; deltas are discarded.
                async with self.anthropic.messages.stream(**body) as stream:
                    response = await stream.get_final_message()
            else:
                response = await self.anthropic.messages.create(**body)
        except AnthropicError as e:
            raise ModelError(str(e)) from e
        return response_from_wire(response)

    async def close(self) -> None:
        """Close the underlying Anthropic SDK client."""
        await self.anthropic.close()
""" import os from typing import Annotated, Literal +from anthropic import AsyncAnthropic +from google import genai +from google.genai import types as google_types from openai import AsyncOpenAI from pydantic import Field, model_validator from pydantic_config import BaseConfig from renderers import RendererConfig +from verifiers.v1.clients.anthropic import AnthropicMessagesClient from verifiers.v1.clients.client import Client +from verifiers.v1.clients.google import GoogleResponsesClient from verifiers.v1.clients.openai import OpenAIChatCompletionsClient +from verifiers.v1.clients.openai_responses import OpenAIResponsesClient from verifiers.v1.clients.renderer import RendererClient PRIME_INFERENCE_HOST = "pinference.ai" @@ -25,7 +29,7 @@ class BaseClientConfig(BaseConfig): - """An OpenAI-compatible endpoint. The API key is read from an env var.""" + """A model endpoint. The API key is read from an env var.""" base_url: str = "https://api.pinference.ai/api/v1" api_key_var: str = "PRIME_API_KEY" @@ -48,6 +52,28 @@ class OpenAIClientConfig(BaseClientConfig): type: Literal["openai"] = "openai" +class OpenAIResponsesClientConfig(BaseClientConfig): + """An OpenAI-compatible Responses API endpoint.""" + + type: Literal["openai_responses"] = "openai_responses" + + +class AnthropicMessagesClientConfig(BaseClientConfig): + """The Anthropic Messages API.""" + + type: Literal["anthropic_messages"] = "anthropic_messages" + base_url: str = "https://api.anthropic.com" + api_key_var: str = "ANTHROPIC_API_KEY" + + +class GoogleResponsesClientConfig(BaseClientConfig): + """The Google Gemini generateContent API.""" + + type: Literal["google_responses"] = "google_responses" + base_url: str = "https://generativelanguage.googleapis.com/" + api_key_var: str = "GEMINI_API_KEY" + + class RendererClientConfig(BaseClientConfig): """A vLLM `/inference/v1/generate` endpoint with client-side tokenization, so responses carry token ids + logprobs. 
Needs a running vLLM engine.""" @@ -68,7 +94,12 @@ class RendererClientConfig(BaseClientConfig): # Discriminated union for a CLI-selectable client (`--client.type renderers`). ClientConfig = Annotated[ - OpenAIClientConfig | RendererClientConfig, Field(discriminator="type") + OpenAIClientConfig + | OpenAIResponsesClientConfig + | AnthropicMessagesClientConfig + | GoogleResponsesClientConfig + | RendererClientConfig, + Field(discriminator="type"), ] @@ -87,4 +118,25 @@ def make_openai_client(config: BaseClientConfig) -> AsyncOpenAI: config=config.renderer, renderer_model_name=config.renderer_model_name, ) + if isinstance(config, OpenAIResponsesClientConfig): + return OpenAIResponsesClient(make_openai_client(config)) + if isinstance(config, AnthropicMessagesClientConfig): + return AnthropicMessagesClient( + AsyncAnthropic( + base_url=config.base_url, + api_key=os.environ.get(config.api_key_var, "EMPTY"), + default_headers=config.headers or None, + ) + ) + if isinstance(config, GoogleResponsesClientConfig): + return GoogleResponsesClient( + genai.Client( + api_key=os.environ.get(config.api_key_var, "EMPTY"), + http_options=google_types.HttpOptions( + base_url=config.base_url, + api_version="v1beta", + headers=config.headers or None, + ), + ) + ) return OpenAIChatCompletionsClient(make_openai_client(config)) diff --git a/verifiers/v1/clients/google.py b/verifiers/v1/clients/google.py new file mode 100644 index 000000000..2de692a70 --- /dev/null +++ b/verifiers/v1/clients/google.py @@ -0,0 +1,302 @@ +"""Google Gemini generateContent client.""" + +import base64 +import json +import time +from typing import Any, cast + +from google import genai +from google.genai import errors, types + +from verifiers.v1.clients.client import Client +from verifiers.v1.errors import ModelError +from verifiers.v1.types import ( + AssistantMessage, + FinishReason, + Messages, + Response, + SamplingConfig, + SystemMessage, + TextContentPart, + Tool, + ToolCall, + ToolMessage, + Usage, + 
# Maps the portable image `detail` value to Google's per-part media resolution.
# Values without an entry (e.g. "auto" or None) leave the resolution unset.
MEDIA_RESOLUTIONS = {
    "low": types.PartMediaResolutionLevel.MEDIA_RESOLUTION_LOW,
    "high": types.PartMediaResolutionLevel.MEDIA_RESOLUTION_HIGH,
    "original": types.PartMediaResolutionLevel.MEDIA_RESOLUTION_ULTRA_HIGH,
}


def content_to_wire(content) -> list[types.Part]:
    """Convert message content (a string or a list of content parts) to Parts.

    Text becomes a text part. Images must be base64 `data:` URIs and are sent
    as inline bytes with the lower-cased MIME type from the URI.

    Raises:
        ValueError: If an image is not a `data:` URI, the URI has no `,`
            separating metadata from data, or it is not base64 encoded.
    """
    if isinstance(content, str):
        return [types.Part.from_text(text=content)]
    parts: list[types.Part] = []
    for part in content:
        if isinstance(part, TextContentPart):
            parts.append(types.Part.from_text(text=part.text))
            continue
        url = part.image_url.url
        if not url.startswith("data:"):
            raise ValueError("Google images must use data URIs")
        # partition() instead of a two-way unpack of split(): a URI without a
        # comma now fails with an explicit message rather than an unpack error.
        metadata, separator, data = url.removeprefix("data:").partition(",")
        if not separator:
            raise ValueError("Google image data URIs must contain a ',' before the data")
        mime_type, *parameters = metadata.split(";")
        if not any(parameter.lower() == "base64" for parameter in parameters):
            raise ValueError("Google image data URIs must be base64 encoded")
        parts.append(
            types.Part.from_bytes(
                data=base64.b64decode(data),
                mime_type=mime_type.lower(),
                media_resolution=MEDIA_RESOLUTIONS.get(part.image_url.detail),
            )
        )
    return parts


def system_to_wire(messages: Messages) -> str | list[types.Part] | None:
    """Google takes system instructions in the request config: joined text when all
    system messages are plain strings, content parts otherwise (e.g. images).

    Returns `None` when the prompt has no system message.
    """
    systems = [m for m in messages if isinstance(m, SystemMessage)]
    if not systems:
        return None
    if all(isinstance(m.content, str) for m in systems):
        return "\n\n".join(cast(str, m.content) for m in systems)
    return [part for m in systems for part in content_to_wire(m.content)]


def assistant_to_wire(message: AssistantMessage) -> types.Content:
    """Convert an assistant message to a `model`-role `Content`."""
    # Native parts carry thought signatures required on the next turn, so replay
    # them unchanged when available; otherwise rebuild from the portable fields.
    if message.provider_state:
        parts = [types.Part.model_validate(part) for part in message.provider_state]
        return types.Content(role="model", parts=parts)
    parts: list[types.Part] = []
    if message.reasoning_content:
        parts.append(types.Part(text=message.reasoning_content, thought=True))
    if message.content:
        parts.append(types.Part.from_text(text=message.content))
    for call in message.tool_calls or []:
        parts.append(
            types.Part(
                function_call=types.FunctionCall(
                    id=call.id,
                    name=call.name,
                    # An argument-less call may carry an empty arguments string;
                    # decode it as an empty object instead of failing.
                    args=json.loads(call.arguments or "{}"),
                )
            )
        )
    return types.Content(role="model", parts=parts)


def tool_result_to_wire(message: ToolMessage, name: str) -> types.Part:
    """Convert a tool result to a `function_response` part for function `name`."""
    # Google wants a JSON object; wrap non-object results.
    try:
        result = json.loads(message.content)
    except json.JSONDecodeError:
        result = message.content
    if not isinstance(result, dict):
        result = {"result": result}
    return types.Part(
        function_response=types.FunctionResponse(
            id=message.tool_call_id,
            name=name,
            response=result,
        )
    )


def messages_to_wire(messages: Messages) -> list[types.Content]:
    """Convert the prompt, folding consecutive tool results into one user turn
    (Google expects parallel function responses grouped in a single message).

    Raises:
        ValueError: If a tool result's `tool_call_id` matches no tool call of an
            earlier assistant message in the prompt.
    """
    prompt: list[types.Content] = []
    call_names: dict[str, str] = {}  # function responses must repeat the call's name
    for message in messages:
        if isinstance(message, SystemMessage):  # sent as system_instruction
            continue
        if isinstance(message, UserMessage):
            prompt.append(
                types.Content(role="user", parts=content_to_wire(message.content))
            )
        elif isinstance(message, AssistantMessage):
            call_names.update({call.id: call.name for call in message.tool_calls or []})
            prompt.append(assistant_to_wire(message))
        else:
            name = call_names.get(message.tool_call_id)
            if name is None:
                # Explicit validation error, matching the other input checks in
                # this module, instead of a bare KeyError from the lookup.
                raise ValueError(
                    f"Google tool result {message.tool_call_id!r} has no matching tool call"
                )
            part = tool_result_to_wire(message, name)
            last_parts = prompt[-1].parts if prompt else None
            if (
                last_parts
                and prompt[-1].role == "user"
                and all(item.function_response for item in last_parts)
            ):
                last_parts.append(part)
            else:
                prompt.append(types.Content(role="user", parts=[part]))
    return prompt


def merge_stream_chunks(
    chunks: list[types.GenerateContentResponse],
) -> types.GenerateContentResponse:
    """Combine streamed chunks into one response (Google streams response deltas
    with no final-aggregate helper). The last chunk carries the metadata (usage,
    finish reason); candidate zero's content parts concatenate across chunks.

    Raises:
        ModelError: If `chunks` is empty.
    """
    if not chunks:
        raise ModelError("Google stream returned no chunks")
    response = chunks[-1].model_copy(deep=True)
    # Keep only candidate zero (an index of None is treated as zero).
    candidates = [
        candidate
        for chunk in chunks
        for candidate in chunk.candidates or []
        if candidate.index in (None, 0)
    ]
    if not candidates:
        return response
    candidate = candidates[-1].model_copy(deep=True)
    candidate.content = types.Content(
        role=candidate.content.role if candidate.content else None,
        parts=[
            part
            for item in candidates
            if item.content
            for part in item.content.parts or []
        ],
    )
    response.candidates = [candidate]
    return response


def response_from_wire(response: types.GenerateContentResponse, model: str) -> Response:
    """Convert a generateContent response into a `Response`.

    Only the first candidate is read. Thought parts feed `reasoning_content`,
    other text feeds `content`, and every part is kept as JSON in
    `provider_state` for replay. `model` is the fallback when the response has
    no `model_version`.

    Raises:
        ModelError: If there is no candidate, no content part, or no usable
            output (text, reasoning, or tool call).
    """
    if not response.candidates:
        raise ModelError("Google returned no candidates")
    candidate = response.candidates[0]
    parts = candidate.content.parts if candidate.content else None
    if not parts:
        raise ModelError("Google returned no content")

    content = ""
    reasoning = ""
    has_reasoning = False
    tool_calls: list[ToolCall] = []
    for part in parts:
        # A thought signature counts as reasoning even without readable text.
        has_reasoning = has_reasoning or bool(part.thought or part.thought_signature)
        if part.text:
            if part.thought:
                reasoning += part.text
            else:
                content += part.text
        if part.function_call and part.function_call.name:
            tool_calls.append(
                ToolCall(
                    # Synthesize a positional id when the call carries none.
                    id=part.function_call.id or f"call_{len(tool_calls)}",
                    name=part.function_call.name,
                    arguments=json.dumps(part.function_call.args or {}),
                )
            )
    if not content and not has_reasoning and not tool_calls:
        raise ModelError("Google returned no output")

    finish_reason: FinishReason = None
    if tool_calls:
        finish_reason = "tool_calls"
    elif candidate.finish_reason == types.FinishReason.STOP:
        finish_reason = "stop"
    elif candidate.finish_reason == types.FinishReason.MAX_TOKENS:
        finish_reason = "length"

    usage = response.usage_metadata
    prompt_tokens = usage.prompt_token_count if usage else None
    # NOTE(review): candidates_token_count may exclude thinking tokens
    # (thoughts_token_count) — confirm whether completion_tokens should add them.
    completion_tokens = usage.candidates_token_count if usage else None
    return Response(
        id=response.response_id or "",
        created=(
            int(response.create_time.timestamp())
            if response.create_time
            else int(time.time())
        ),
        model=response.model_version or model,
        message=AssistantMessage(
            content=content or None,
            reasoning_content=reasoning or None,
            tool_calls=tool_calls or None,
            provider_state=[
                part.model_dump(mode="json", by_alias=True, exclude_none=True)
                for part in parts
            ],
        ),
        finish_reason=finish_reason,
        usage=(
            Usage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
            )
            if prompt_tokens is not None and completion_tokens is not None
            else None
        ),
    )


class GoogleResponsesClient(Client):
    """`Client` backed by the Google Gemini generateContent API."""

    def __init__(self, google: genai.Client) -> None:
        self.google = google

    async def get_response(
        self,
        prompt: Messages,
        model: str,
        sampling_args: SamplingConfig,
        tools: list[Tool] | None = None,
    ) -> Response:
        """Send `prompt` to generateContent and return the parsed response.

        Sampling fields are renamed to Google's config names (`max_tokens`,
        `n`, `stop`, `logprobs`, `top_logprobs`); the rest pass through to
        `GenerateContentConfig`.

        Raises:
            ModelError: If the SDK raises `errors.APIError` or the reply is empty.
        """
        sampling: dict[str, Any] = sampling_args.model_dump(exclude_none=True)
        streaming = bool(sampling.pop("stream", False))
        if max_tokens := sampling.pop("max_tokens", None):
            sampling["max_output_tokens"] = max_tokens
        if n := sampling.pop("n", None):
            sampling["candidate_count"] = n
        if stop := sampling.pop("stop", None):
            sampling["stop_sequences"] = [stop] if isinstance(stop, str) else stop
        # OpenAI's `logprobs` flag is Google's `response_logprobs`; OpenAI's
        # `top_logprobs` count is what Google calls `logprobs`.
        logprobs = sampling.pop("logprobs", None)
        if logprobs is not None:
            sampling["response_logprobs"] = bool(logprobs)
        if top_logprobs := sampling.pop("top_logprobs", None):
            sampling["logprobs"] = top_logprobs
        if (system := system_to_wire(prompt)) is not None:
            sampling["system_instruction"] = system
        if tools:
            sampling["tools"] = [
                types.Tool(
                    function_declarations=[
                        types.FunctionDeclaration(
                            name=tool.name,
                            description=tool.description,
                            parameters_json_schema=tool.parameters,
                        )
                        for tool in tools
                    ]
                )
            ]
        request = {
            # eval model ids may carry a provider prefix (e.g. "google/gemini-...")
            "model": model.rsplit("/", 1)[-1],
            "contents": cast(types.ContentListUnion, messages_to_wire(prompt)),
            "config": types.GenerateContentConfig.model_validate(sampling),
        }
        try:
            if streaming:
                stream = await self.google.aio.models.generate_content_stream(**request)
                response = merge_stream_chunks([chunk async for chunk in stream])
            else:
                response = await self.google.aio.models.generate_content(**request)
        except errors.APIError as e:
            raise ModelError(str(e)) from e
        return response_from_wire(response, model)

    async def close(self) -> None:
        """Close the async client first, then the sync client."""
        await self.google.aio.aclose()
        self.google.close()
def content_to_wire(content) -> str | list[ChatCompletionContentPartParam]:
    """Translate message content to Chat Completions wire form.

    A plain string is returned unchanged; a list of content parts becomes a list
    of text / `image_url` part params.

    Raises:
        ValueError: If an image part asks for detail "original".
    """
    if isinstance(content, str):
        return content

    def to_param(item) -> ChatCompletionContentPartParam:
        # Text parts map one-to-one; anything else is treated as an image part.
        if isinstance(item, TextContentPart):
            return ChatCompletionContentPartTextParam(type="text", text=item.text)
        source = item.image_url
        if source.detail == "original":
            raise ValueError("OpenAI Chat does not support image detail='original'")
        return ChatCompletionContentPartImageParam(
            type="image_url",
            image_url=cast(Any, source.model_dump(exclude_none=True)),
        )

    return [to_param(item) for item in content]


def message_to_wire(message: Message) -> ChatCompletionMessageParam:
    """Translate one message to its Chat Completions message param.

    Assistant messages carry their tool calls and, when present, their
    `provider_state` under the `reasoning_details` key. System messages accept
    a string or text-only parts.

    Raises:
        ValueError: If a system message contains a non-text part, or a user
            image part asks for detail "original".
    """
    if isinstance(message, AssistantMessage):
        assistant = ChatCompletionAssistantMessageParam(
            role="assistant", content=message.content
        )
        calls = [
            ChatCompletionMessageFunctionToolCallParam(
                id=call.id,
                type="function",
                function={"name": call.name, "arguments": call.arguments},
            )
            for call in message.tool_calls or []
        ]
        if calls:
            assistant["tool_calls"] = calls
        if message.provider_state:
            # Not part of the typed param, hence the cast.
            cast(Any, assistant)["reasoning_details"] = message.provider_state
        return assistant

    if isinstance(message, ToolMessage):
        return ChatCompletionToolMessageParam(
            role="tool",
            tool_call_id=message.tool_call_id,
            content=message.content,
        )

    if isinstance(message, SystemMessage):
        if isinstance(message.content, str):
            return ChatCompletionSystemMessageParam(
                role="system", content=message.content
            )
        # Validate every part before building anything, so an image anywhere
        # in the list is rejected up front.
        if not all(isinstance(item, TextContentPart) for item in message.content):
            raise ValueError("OpenAI Chat system messages do not support images")
        text_params = [
            ChatCompletionContentPartTextParam(type="text", text=item.text)
            for item in message.content
            if isinstance(item, TextContentPart)
        ]
        return ChatCompletionSystemMessageParam(role="system", content=text_params)

    assert isinstance(message, UserMessage)
    user_content = content_to_wire(message.content)
    return ChatCompletionUserMessageParam(role="user", content=user_content)


def tool_to_wire(tool: Tool) -> ChatCompletionFunctionToolParam:
    """Translate a tool to a Chat Completions function-tool param.

    `strict` is included only when the tool sets it.
    """
    definition = FunctionDefinition(
        name=tool.name,
        description=tool.description,
        parameters=tool.parameters,
    )
    strict = tool.strict
    if strict is not None:
        definition["strict"] = strict
    return ChatCompletionFunctionToolParam(type="function", function=definition)
choice.finish_reason if choice.finish_reason in FINISH_REASONS else None + finish = cast( + FinishReason, + choice.finish_reason + if choice.finish_reason in ("stop", "length", "tool_calls") + else None, ) usage = ( Usage( @@ -134,8 +185,14 @@ def response_from_wire(completion) -> Response: model=completion.model, message=AssistantMessage( content=message.content, - reasoning_content=getattr(message, "reasoning_content", None), + reasoning_content=cast( + str | None, + extra.get("reasoning_content") or extra.get("reasoning"), + ), tool_calls=tool_calls, + provider_state=cast( + list[dict[str, Any]] | None, extra.get("reasoning_details") + ), ), finish_reason=finish, usage=usage, @@ -154,15 +211,27 @@ async def get_response( sampling_args: SamplingConfig, tools: list[Tool] | None = None, ) -> Response: - body: dict = { + sampling: dict[str, Any] = sampling_args.model_dump(exclude_none=True) + streaming = bool(sampling.pop("stream", False)) + if streaming: + # Usage only arrives on the final chunk if asked for. 
# Response `status` values mapped to the portable finish reason.
FINISH_REASONS: dict[str, FinishReason] = {"completed": "stop", "incomplete": "length"}
# Stream event types whose `response` attribute holds the terminal response.
FINAL_EVENTS = ("response.completed", "response.incomplete", "response.failed")


def content_to_wire(content) -> str | list[ResponseInputContentParam]:
    """Convert message content to Responses input content.

    A plain string is returned unchanged; parts become `input_text` /
    `input_image` items, with image detail defaulting to "auto".
    """
    if isinstance(content, str):
        return content
    return [
        ResponseInputTextParam(type="input_text", text=part.text)
        if part.type == "text"
        else ResponseInputImageParam(
            type="input_image",
            image_url=part.image_url.url,
            detail=part.image_url.detail or "auto",
        )
        for part in content
    ]


def message_to_wire(message: Message) -> list[ResponseInputItemParam]:
    """Convert one message to the Responses input items that represent it.

    A tool message becomes one `function_call_output`. An assistant message
    becomes its stored `provider_state` items, or else a text message followed
    by one `function_call` per tool call. System and user messages become a
    single input message.
    """
    if isinstance(message, ToolMessage):
        return [
            FunctionCallOutput(
                type="function_call_output",
                call_id=message.tool_call_id,
                output=message.content,
            )
        ]
    if isinstance(message, AssistantMessage):
        # Native output items carry reasoning and tool-call state required by the
        # Responses API on the next turn, so replay them unchanged when available.
        if message.provider_state:
            return cast(list[ResponseInputItemParam], message.provider_state)
        items: list[ResponseInputItemParam] = []
        if message.content:
            items.append(
                EasyInputMessageParam(role="assistant", content=message.content)
            )
        items.extend(
            ResponseFunctionToolCallParam(
                type="function_call",
                call_id=call.id,
                name=call.name,
                arguments=call.arguments,
            )
            for call in message.tool_calls or []
        )
        return items
    assert isinstance(message, (SystemMessage, UserMessage))
    return [
        EasyInputMessageParam(
            role=message.role,
            content=content_to_wire(message.content),
        )
    ]


def response_from_wire(response: OpenAIResponse) -> Response:
    """Convert a Responses API response into a `Response`.

    Visible text and refusals form `content`, reasoning summaries and reasoning
    content form `reasoning_content`, and every output item is kept as JSON in
    `provider_state` for replay on the next turn.

    Raises:
        ModelError: If there is no text, reasoning item, or tool call. When the
            response carries an `error`, its message is appended to the text.
    """
    content = response.output_text
    reasoning: list[str] = []
    has_reasoning = False
    tool_calls: list[ToolCall] = []
    # `output_text` is only the visible text. Inspect output items for refusals,
    # reasoning summaries, tool calls, and the provider state needed for continuation.
    for item in response.output:
        if isinstance(item, ResponseOutputMessage):
            content += "".join(
                part.refusal
                for part in item.content
                if isinstance(part, ResponseOutputRefusal)
            )
        elif isinstance(item, ResponseReasoningItem):
            has_reasoning = True
            reasoning += [part.text for part in item.summary]
            reasoning += [part.text for part in item.content or []]
        elif isinstance(item, ResponseFunctionToolCall):
            tool_calls.append(
                ToolCall(
                    id=item.call_id,
                    name=item.name,
                    arguments=item.arguments,
                )
            )
    if not content and not has_reasoning and not tool_calls:
        # The streaming path treats `response.failed` as terminal, so a failed
        # response can land here; report its cause rather than only a generic
        # "no output".
        detail = f": {response.error.message}" if response.error else ""
        raise ModelError(f"OpenAI Responses returned no output{detail}")

    return Response(
        id=response.id,
        created=int(response.created_at),
        model=response.model,
        message=AssistantMessage(
            content=content or None,
            reasoning_content="\n".join(reasoning) or None,
            tool_calls=tool_calls or None,
            provider_state=[
                item.model_dump(mode="json", exclude_none=True)
                for item in response.output
            ],
        ),
        finish_reason=(
            "tool_calls" if tool_calls else FINISH_REASONS.get(response.status or "")
        ),
        usage=(
            Usage(
                prompt_tokens=response.usage.input_tokens,
                completion_tokens=response.usage.output_tokens,
            )
            if response.usage
            else None
        ),
    )


class OpenAIResponsesClient(Client):
    """`Client` backed by the OpenAI Responses API."""

    def __init__(self, openai: AsyncOpenAI) -> None:
        self.openai = openai

    async def get_response(
        self,
        prompt: Messages,
        model: str,
        sampling_args: SamplingConfig,
        tools: list[Tool] | None = None,
    ) -> Response:
        """Send `prompt` to the Responses API and return the parsed response.

        `max_tokens` is renamed to `max_output_tokens`; all other sampling
        fields pass through.

        Raises:
            ValueError: If `stop` is set or `n` is not 1.
            ModelError: If the SDK raises `OpenAIError`, the stream ends without
                a terminal event, or the response has no output.
        """
        sampling: dict[str, Any] = sampling_args.model_dump(exclude_none=True)
        streaming = bool(sampling.pop("stream", False))
        if max_tokens := sampling.pop("max_tokens", None):
            sampling["max_output_tokens"] = max_tokens
        if sampling.pop("stop", None):
            raise ValueError("OpenAI Responses does not support stop sequences")
        if sampling.pop("n", 1) != 1:
            raise ValueError("OpenAI Responses only supports n=1")
        body: dict[str, Any] = {
            "model": model,
            "input": [item for message in prompt for item in message_to_wire(message)],
            **sampling,
        }
        if tools:
            body["tools"] = [
                FunctionToolParam(
                    type="function",
                    name=tool.name,
                    description=tool.description,
                    parameters=tool.parameters,
                    strict=tool.strict,
                )
                for tool in tools
            ]
        try:
            if streaming:
                # The SDK's final-response helper rejects valid incomplete
                # responses, so pick the terminal event out of the stream ourselves.
                response = None
                async with await self.openai.responses.create(
                    **body, stream=True
                ) as events:
                    async for event in events:
                        if event.type in FINAL_EVENTS:
                            response = event.response
                            break
                if response is None:
                    raise ModelError(
                        "OpenAI Responses stream ended without a final response"
                    )
            else:
                response = await self.openai.responses.create(**body)
        except OpenAIError as e:
            raise model_error(e) from e
        return response_from_wire(response)

    async def close(self) -> None:
        """Close the underlying OpenAI SDK client."""
        await self.openai.close()
311a68d43..27e20232a 100644 --- a/verifiers/v1/graph.py +++ b/verifiers/v1/graph.py @@ -103,6 +103,13 @@ def message_hash(message: Message) -> str: if isinstance(message, AssistantMessage): if message.reasoning_content is not None: parts += ["reasoning_content", message.reasoning_content] + if message.provider_state is not None: + parts += [ + "provider_state", + json.dumps( + message.provider_state, sort_keys=True, separators=(",", ":") + ), + ] for tc in message.tool_calls or []: parts += [tc.id, tc.name, _canonical_tool_arguments(tc.arguments)] elif isinstance(message, ToolMessage): diff --git a/verifiers/v1/interception/server.py b/verifiers/v1/interception/server.py index 0a01b98dc..9c8a0669a 100644 --- a/verifiers/v1/interception/server.py +++ b/verifiers/v1/interception/server.py @@ -76,7 +76,10 @@ def parse_message(raw: dict) -> Message: for c in (raw.get("tool_calls") or []) ] or None return AssistantMessage( - content=_content_text(content) or None, tool_calls=calls + content=_content_text(content) or None, + reasoning_content=raw.get("reasoning_content") or raw.get("reasoning"), + tool_calls=calls, + provider_state=raw.get("reasoning_details") or raw.get("provider_state"), ) return UserMessage(content=content_to_parts(content)) @@ -99,6 +102,10 @@ def parse_tools(raw: list[dict] | None) -> list[Tool] | None: def serialize_completion(response: Response, model: str) -> dict: """A `Response` -> an OpenAI chat.completion dict the program's SDK expects.""" message: dict = {"role": "assistant", "content": response.message.content} + if response.message.reasoning_content is not None: + message["reasoning_content"] = response.message.reasoning_content + if response.message.provider_state is not None: + message["reasoning_details"] = response.message.provider_state if response.message.tool_calls: message["tool_calls"] = [ { diff --git a/verifiers/v1/trace.py b/verifiers/v1/trace.py index f3a97c9e2..7c5d61d2f 100644 --- a/verifiers/v1/trace.py +++ 
@property
def has_response(self) -> bool:
    """Whether the most recent assistant message produced any output.

    Content, reasoning content, provider state and tool calls each count as
    output; the result is False when there is no assistant message at all.
    """
    latest = self._last_assistant()
    if not latest:
        return False
    reply = latest.message
    return any(
        (
            reply.content,
            reply.reasoning_content,
            reply.provider_state,
            reply.tool_calls,
        )
    )