diff --git a/slime/agent/adapters/common.py b/slime/agent/adapters/common.py index ed5d2e06a4..ad30e89677 100644 --- a/slime/agent/adapters/common.py +++ b/slime/agent/adapters/common.py @@ -22,6 +22,8 @@ SGLANG_URL_KEY = web.AppKey("sglang_url", object) TOOL_PARSER_KEY = web.AppKey("tool_parser", object) REASONING_PARSER_KEY = web.AppKey("reasoning_parser", object) +# Pooled per-app SGLang client, registered by sglang_client_context and reused across turns. +SGLANG_CLIENT_KEY = web.AppKey("sglang_client", aiohttp.ClientSession) @dataclasses.dataclass @@ -51,6 +53,7 @@ def __init__(self, *, tokenizer, sglang_url, tool_parser=None, reasoning_parser= self.app[SGLANG_URL_KEY] = sglang_url.rstrip("/") if isinstance(sglang_url, str) else sglang_url self.app[TOOL_PARSER_KEY] = tool_parser self.app[REASONING_PARSER_KEY] = reasoning_parser + self.app.cleanup_ctx.append(sglang_client_context) def open_session( self, @@ -174,6 +177,42 @@ def _sampling_params(session: Any, body: dict, *, max_token_keys: tuple[str, ... return sp +async def sglang_client_context(app: web.Application): + # limit=0: connection concurrency is governed by the rollout scheduler, not the pool. + connector = aiohttp.TCPConnector(limit=0, ttl_dns_cache=300, keepalive_timeout=60) + timeout = aiohttp.ClientTimeout(total=None, sock_read=900) + app[SGLANG_CLIENT_KEY] = aiohttp.ClientSession(timeout=timeout, connector=connector) + try: + yield + finally: + await app[SGLANG_CLIENT_KEY].close() + + +async def _post_sglang_generate( + client: aiohttp.ClientSession, + sglang_url: str, + *, + rid: str, + prompt_ids: list[int], + sampling_params: dict, + headers: dict[str, str] | None, +) -> dict: + async with client.post( + f"{sglang_url}/generate", + json={ + "rid": rid, + "input_ids": prompt_ids, + "sampling_params": sampling_params, + "return_logprob": True, + }, + headers=headers, + ) as r: + if r.status >= 400: + text = await r.text() + raise RuntimeError(f"sglang upstream {r.status}: {text[:400]}") + return await r.json(content_type=None) + + async def call_sglang_generate( prompt_ids: list[int], session: Any, @@ -201,24 +240,22 @@ async def call_sglang_generate( sp["max_new_tokens"] = min(int(sp.get("max_new_tokens", remaining_context)), remaining_context) sglang_url = app[SGLANG_URL_KEY] + client = app[SGLANG_CLIENT_KEY] rid = uuid.uuid4().hex headers = {"X-SMG-Routing-Key": session_id} if session_id and session_id != "default" else None - timeout = aiohttp.ClientTimeout(total=None, sock_read=900) try: - async with aiohttp.ClientSession(timeout=timeout) as sess, sess.post( - f"{sglang_url}/generate", - json={ - "rid": rid, - "input_ids": prompt_ids, - "sampling_params": sp, - "return_logprob": True, - }, - headers=headers, - ) as r: - if r.status >= 400: - text = await r.text() - raise RuntimeError(f"sglang upstream {r.status}: {text[:400]}") - data = await r.json(content_type=None) + try: + data = await _post_sglang_generate( + client, sglang_url, rid=rid, prompt_ids=prompt_ids, sampling_params=sp, headers=headers + ) + except aiohttp.ClientConnectorError: + # A connector error is raised before any request bytes reach SGLang, so retrying + # with the SAME rid cannot double-generate. Errors after the request may have + # reached the server are NOT retried and fall through to the abort path below. + logger.warning("[%s] retrying SGLang generate after connector failure", log_prefix) + data = await _post_sglang_generate( + client, sglang_url, rid=rid, prompt_ids=prompt_ids, sampling_params=sp, headers=headers + ) meta = data.get("meta_info") or {} output_token_logprobs = meta.get("output_token_logprobs") or [] output_ids = [x[1] for x in output_token_logprobs] @@ -226,8 +263,10 @@ async def call_sglang_generate( finish = (meta.get("finish_reason") or {}).get("type", "stop") or "stop" except (asyncio.CancelledError, aiohttp.ClientError, asyncio.TimeoutError): try: - async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=5)) as s2: - await s2.post(f"{sglang_url}/abort_request", json={"rid": rid}) + async with client.post( + f"{sglang_url}/abort_request", json={"rid": rid}, timeout=aiohttp.ClientTimeout(total=5) + ): + pass except Exception: pass raise diff --git a/tests/test_agent_adapters.py b/tests/test_agent_adapters.py index 47906c2e82..38f2c3f877 100644 --- a/tests/test_agent_adapters.py +++ b/tests/test_agent_adapters.py @@ -3,6 +3,7 @@ import sys from pathlib import Path +import aiohttp import pytest from aiohttp import web from aiohttp.test_utils import TestClient, TestServer @@ -12,7 +13,7 @@ sys.path.insert(0, str(REPO_ROOT)) from slime.agent.adapters import anthropic, openai -from slime.agent.adapters.common import SGLANG_URL_KEY +from slime.agent.adapters.common import SGLANG_CLIENT_KEY, SGLANG_URL_KEY from slime.agent.trajectory import TurnRecord @@ -819,12 +820,16 @@ async def handle_generate(request): await server.start_server() try: session = openai.Session(sampling_defaults={"max_new_tokens": 9}) - turn = await openai._generate( - [11, 12], - session, - {"max_tokens": 3, "temperature": 0.25, "stop": [""]}, - {SGLANG_URL_KEY: str(server.make_url("")).rstrip("/")}, - ) + async with aiohttp.ClientSession() as client: + turn = await openai._generate( + [11, 12], + session, + {"max_tokens": 3, "temperature": 0.25, "stop": [""]}, + { + SGLANG_URL_KEY: str(server.make_url("")).rstrip("/"), + SGLANG_CLIENT_KEY: client, + }, + ) finally: await server.close() @@ -841,5 +846,131 @@ async def handle_generate(request): asyncio.run(run_case()) +class FakeConnectionKey: + host = "10.0.1.4" + port = 4049 + ssl = True + + +class FakeSGLangResponse: + status = 200 + + def __init__(self, token_id: int) -> None: + self.token_id = token_id + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return None + + async def json(self, content_type=None): + return { + "meta_info": { + "output_token_logprobs": [[-0.1, self.token_id]], + "finish_reason": {"type": "stop"}, + } + } + + +class FakeSGLangClient: + # Stands in for the app-owned pooled client: each /generate post consumes one scripted + # outcome (a token id to return or an exception to raise); /abort_request always succeeds. + def __init__(self, generate_outcomes: list[int | Exception]) -> None: + self.generate_outcomes = list(generate_outcomes) + self.posts: list[dict] = [] + + def post(self, url, *, json=None, headers=None, timeout=None): + self.posts.append({"url": url, "json": json}) + if url.endswith("/abort_request"): + return FakeSGLangResponse(0) + outcome = self.generate_outcomes.pop(0) + if isinstance(outcome, Exception): + raise outcome + return FakeSGLangResponse(outcome) + + +@pytest.mark.unit +def test_openai_generate_reuses_app_owned_sglang_client(): + # Both turns post through the same app-owned pooled client instead of a fresh session each. + async def run_case(): + client = FakeSGLangClient([800, 801]) + session = openai.Session(sampling_defaults={"max_new_tokens": 9}) + app = {SGLANG_URL_KEY: "http://router", SGLANG_CLIENT_KEY: client} + + first = await openai._generate([11], session, {"max_tokens": 3}, app) + second = await openai._generate([12], session, {"max_tokens": 4}, app) + + assert first.output_ids == [800] + assert second.output_ids == [801] + assert [p["url"] for p in client.posts] == ["http://router/generate", "http://router/generate"] + + asyncio.run(run_case()) + + +@pytest.mark.unit +def test_openai_generate_retries_connector_error_once_with_same_rid(): + # A pre-connect connector error is retried exactly once with the SAME rid: no request + # bytes reached SGLang, so the retry cannot double-generate. + async def run_case(): + client = FakeSGLangClient([aiohttp.ClientConnectorError(FakeConnectionKey(), OSError("bad fd")), 901]) + session = openai.Session(sampling_defaults={"max_new_tokens": 9}) + app = {SGLANG_URL_KEY: "http://router", SGLANG_CLIENT_KEY: client} + + turn = await openai._generate([11], session, {"max_tokens": 3}, app) + + assert turn.output_ids == [901] + assert [p["url"] for p in client.posts] == ["http://router/generate", "http://router/generate"] + assert client.posts[0]["json"]["rid"] == client.posts[1]["json"]["rid"] + + asyncio.run(run_case()) + + +@pytest.mark.unit +def test_openai_generate_does_not_retry_after_request_reaches_server(): + # Once the request may have reached SGLang, a failure must abort by rid, never re-issue; + # the scripted success after the disconnect makes a wrongly-retried request fail loudly. + async def run_case(): + client = FakeSGLangClient([aiohttp.ServerDisconnectedError(), 999]) + session = openai.Session(sampling_defaults={"max_new_tokens": 9}) + app = {SGLANG_URL_KEY: "http://router", SGLANG_CLIENT_KEY: client} + + with pytest.raises(aiohttp.ServerDisconnectedError): + await openai._generate([11], session, {"max_tokens": 3}, app) + + assert [p["url"] for p in client.posts] == ["http://router/generate", "http://router/abort_request"] + assert client.posts[1]["json"]["rid"] == client.posts[0]["json"]["rid"] + + asyncio.run(run_case()) + + +@pytest.mark.unit +def test_openai_generate_retries_connector_error_only_once(): + # A second consecutive connector error propagates and aborts by rid; a third attempt + # would hit the scripted success, so an unbounded retry fails loudly. + async def run_case(): + client = FakeSGLangClient( + [ + aiohttp.ClientConnectorError(FakeConnectionKey(), OSError("bad fd")), + aiohttp.ClientConnectorError(FakeConnectionKey(), OSError("bad fd")), + 999, + ] + ) + session = openai.Session(sampling_defaults={"max_new_tokens": 9}) + app = {SGLANG_URL_KEY: "http://router", SGLANG_CLIENT_KEY: client} + + with pytest.raises(aiohttp.ClientConnectorError): + await openai._generate([11], session, {"max_tokens": 3}, app) + + assert [p["url"] for p in client.posts] == [ + "http://router/generate", + "http://router/generate", + "http://router/abort_request", + ] + assert client.posts[0]["json"]["rid"] == client.posts[1]["json"]["rid"] + + asyncio.run(run_case()) + + if __name__ == "__main__": raise SystemExit(pytest.main([__file__]))