Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 56 additions & 17 deletions slime/agent/adapters/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
SGLANG_URL_KEY = web.AppKey("sglang_url", object)
TOOL_PARSER_KEY = web.AppKey("tool_parser", object)
REASONING_PARSER_KEY = web.AppKey("reasoning_parser", object)
# Pooled per-app SGLang client, registered by sglang_client_context and reused across turns.
SGLANG_CLIENT_KEY = web.AppKey("sglang_client", aiohttp.ClientSession)


@dataclasses.dataclass
Expand Down Expand Up @@ -51,6 +53,7 @@ def __init__(self, *, tokenizer, sglang_url, tool_parser=None, reasoning_parser=
self.app[SGLANG_URL_KEY] = sglang_url.rstrip("/") if isinstance(sglang_url, str) else sglang_url
self.app[TOOL_PARSER_KEY] = tool_parser
self.app[REASONING_PARSER_KEY] = reasoning_parser
self.app.cleanup_ctx.append(sglang_client_context)

def open_session(
self,
Expand Down Expand Up @@ -174,6 +177,42 @@ def _sampling_params(session: Any, body: dict, *, max_token_keys: tuple[str, ...
return sp


async def sglang_client_context(app: web.Application):
# limit=0: connection concurrency is governed by the rollout scheduler, not the pool.
connector = aiohttp.TCPConnector(limit=0, ttl_dns_cache=300, keepalive_timeout=60)
timeout = aiohttp.ClientTimeout(total=None, sock_read=900)
app[SGLANG_CLIENT_KEY] = aiohttp.ClientSession(timeout=timeout, connector=connector)
try:
yield
finally:
await app[SGLANG_CLIENT_KEY].close()


async def _post_sglang_generate(
client: aiohttp.ClientSession,
sglang_url: str,
*,
rid: str,
prompt_ids: list[int],
sampling_params: dict,
headers: dict[str, str] | None,
) -> dict:
async with client.post(
f"{sglang_url}/generate",
json={
"rid": rid,
"input_ids": prompt_ids,
"sampling_params": sampling_params,
"return_logprob": True,
},
headers=headers,
) as r:
if r.status >= 400:
text = await r.text()
raise RuntimeError(f"sglang upstream {r.status}: {text[:400]}")
return await r.json(content_type=None)


async def call_sglang_generate(
prompt_ids: list[int],
session: Any,
Expand Down Expand Up @@ -201,33 +240,33 @@ async def call_sglang_generate(
sp["max_new_tokens"] = min(int(sp.get("max_new_tokens", remaining_context)), remaining_context)

sglang_url = app[SGLANG_URL_KEY]
client = app[SGLANG_CLIENT_KEY]
rid = uuid.uuid4().hex
headers = {"X-SMG-Routing-Key": session_id} if session_id and session_id != "default" else None
timeout = aiohttp.ClientTimeout(total=None, sock_read=900)
try:
async with aiohttp.ClientSession(timeout=timeout) as sess, sess.post(
f"{sglang_url}/generate",
json={
"rid": rid,
"input_ids": prompt_ids,
"sampling_params": sp,
"return_logprob": True,
},
headers=headers,
) as r:
if r.status >= 400:
text = await r.text()
raise RuntimeError(f"sglang upstream {r.status}: {text[:400]}")
data = await r.json(content_type=None)
try:
data = await _post_sglang_generate(
client, sglang_url, rid=rid, prompt_ids=prompt_ids, sampling_params=sp, headers=headers
)
except aiohttp.ClientConnectorError:
# A connector error is raised before any request bytes reach SGLang, so retrying
# with the SAME rid cannot double-generate. Errors after the request may have
# reached the server are NOT retried and fall through to the abort path below.
logger.warning("[%s] retrying SGLang generate after connector failure", log_prefix)
data = await _post_sglang_generate(
client, sglang_url, rid=rid, prompt_ids=prompt_ids, sampling_params=sp, headers=headers
)
meta = data.get("meta_info") or {}
output_token_logprobs = meta.get("output_token_logprobs") or []
output_ids = [x[1] for x in output_token_logprobs]
output_log_probs = [float(x[0]) for x in output_token_logprobs]
finish = (meta.get("finish_reason") or {}).get("type", "stop") or "stop"
except (asyncio.CancelledError, aiohttp.ClientError, asyncio.TimeoutError):
try:
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=5)) as s2:
await s2.post(f"{sglang_url}/abort_request", json={"rid": rid})
async with client.post(
f"{sglang_url}/abort_request", json={"rid": rid}, timeout=aiohttp.ClientTimeout(total=5)
):
pass
except Exception:
pass
raise
Expand Down
145 changes: 138 additions & 7 deletions tests/test_agent_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import sys
from pathlib import Path

import aiohttp
import pytest
from aiohttp import web
from aiohttp.test_utils import TestClient, TestServer
Expand All @@ -12,7 +13,7 @@
sys.path.insert(0, str(REPO_ROOT))

from slime.agent.adapters import anthropic, openai
from slime.agent.adapters.common import SGLANG_URL_KEY
from slime.agent.adapters.common import SGLANG_CLIENT_KEY, SGLANG_URL_KEY
from slime.agent.trajectory import TurnRecord


Expand Down Expand Up @@ -819,12 +820,16 @@ async def handle_generate(request):
await server.start_server()
try:
session = openai.Session(sampling_defaults={"max_new_tokens": 9})
turn = await openai._generate(
[11, 12],
session,
{"max_tokens": 3, "temperature": 0.25, "stop": ["</s>"]},
{SGLANG_URL_KEY: str(server.make_url("")).rstrip("/")},
)
async with aiohttp.ClientSession() as client:
turn = await openai._generate(
[11, 12],
session,
{"max_tokens": 3, "temperature": 0.25, "stop": ["</s>"]},
{
SGLANG_URL_KEY: str(server.make_url("")).rstrip("/"),
SGLANG_CLIENT_KEY: client,
},
)
finally:
await server.close()

Expand All @@ -841,5 +846,131 @@ async def handle_generate(request):
asyncio.run(run_case())


class FakeConnectionKey:
host = "10.0.1.4"
port = 4049
ssl = True


class FakeSGLangResponse:
status = 200

def __init__(self, token_id: int) -> None:
self.token_id = token_id

async def __aenter__(self):
return self

async def __aexit__(self, exc_type, exc, tb):
return None

async def json(self, content_type=None):
return {
"meta_info": {
"output_token_logprobs": [[-0.1, self.token_id]],
"finish_reason": {"type": "stop"},
}
}


class FakeSGLangClient:
# Stands in for the app-owned pooled client: each /generate post consumes one scripted
# outcome (a token id to return or an exception to raise); /abort_request always succeeds.
def __init__(self, generate_outcomes: list[int | Exception]) -> None:
self.generate_outcomes = list(generate_outcomes)
self.posts: list[dict] = []

def post(self, url, *, json=None, headers=None, timeout=None):
self.posts.append({"url": url, "json": json})
if url.endswith("/abort_request"):
return FakeSGLangResponse(0)
outcome = self.generate_outcomes.pop(0)
if isinstance(outcome, Exception):
raise outcome
return FakeSGLangResponse(outcome)


@pytest.mark.unit
def test_openai_generate_reuses_app_owned_sglang_client():
# Both turns post through the same app-owned pooled client instead of a fresh session each.
async def run_case():
client = FakeSGLangClient([800, 801])
session = openai.Session(sampling_defaults={"max_new_tokens": 9})
app = {SGLANG_URL_KEY: "http://router", SGLANG_CLIENT_KEY: client}

first = await openai._generate([11], session, {"max_tokens": 3}, app)
second = await openai._generate([12], session, {"max_tokens": 4}, app)

assert first.output_ids == [800]
assert second.output_ids == [801]
assert [p["url"] for p in client.posts] == ["http://router/generate", "http://router/generate"]

asyncio.run(run_case())


@pytest.mark.unit
def test_openai_generate_retries_connector_error_once_with_same_rid():
# A pre-connect connector error is retried exactly once with the SAME rid: no request
# bytes reached SGLang, so the retry cannot double-generate.
async def run_case():
client = FakeSGLangClient([aiohttp.ClientConnectorError(FakeConnectionKey(), OSError("bad fd")), 901])
session = openai.Session(sampling_defaults={"max_new_tokens": 9})
app = {SGLANG_URL_KEY: "http://router", SGLANG_CLIENT_KEY: client}

turn = await openai._generate([11], session, {"max_tokens": 3}, app)

assert turn.output_ids == [901]
assert [p["url"] for p in client.posts] == ["http://router/generate", "http://router/generate"]
assert client.posts[0]["json"]["rid"] == client.posts[1]["json"]["rid"]

asyncio.run(run_case())


@pytest.mark.unit
def test_openai_generate_does_not_retry_after_request_reaches_server():
# Once the request may have reached SGLang, a failure must abort by rid, never re-issue;
# the scripted success after the disconnect makes a wrongly-retried request fail loudly.
async def run_case():
client = FakeSGLangClient([aiohttp.ServerDisconnectedError(), 999])
session = openai.Session(sampling_defaults={"max_new_tokens": 9})
app = {SGLANG_URL_KEY: "http://router", SGLANG_CLIENT_KEY: client}

with pytest.raises(aiohttp.ServerDisconnectedError):
await openai._generate([11], session, {"max_tokens": 3}, app)

assert [p["url"] for p in client.posts] == ["http://router/generate", "http://router/abort_request"]
assert client.posts[1]["json"]["rid"] == client.posts[0]["json"]["rid"]

asyncio.run(run_case())


@pytest.mark.unit
def test_openai_generate_retries_connector_error_only_once():
# A second consecutive connector error propagates and aborts by rid; a third attempt
# would hit the scripted success, so an unbounded retry fails loudly.
async def run_case():
client = FakeSGLangClient(
[
aiohttp.ClientConnectorError(FakeConnectionKey(), OSError("bad fd")),
aiohttp.ClientConnectorError(FakeConnectionKey(), OSError("bad fd")),
999,
]
)
session = openai.Session(sampling_defaults={"max_new_tokens": 9})
app = {SGLANG_URL_KEY: "http://router", SGLANG_CLIENT_KEY: client}

with pytest.raises(aiohttp.ClientConnectorError):
await openai._generate([11], session, {"max_tokens": 3}, app)

assert [p["url"] for p in client.posts] == [
"http://router/generate",
"http://router/generate",
"http://router/abort_request",
]
assert client.posts[0]["json"]["rid"] == client.posts[1]["json"]["rid"]

asyncio.run(run_case())


if __name__ == "__main__":
raise SystemExit(pytest.main([__file__]))
Loading