Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/pr-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ jobs:
strategy:
fail-fast: false
matrix:
info: [{"num_gpus": 0, "test_file": "test_megatron_argument_validation.py"}, {"num_gpus": 0, "test_file": "test_dp_schedule.py"}, {"num_gpus": 0, "test_file": "test_cp_utils.py"}, {"num_gpus": 0, "test_file": "test_metric_report.py"}, {"num_gpus": 0, "test_file": "test_metric_report_dist.py"}, {"num_gpus": 0, "test_file": "test_loss_cp_invariance.py"}, {"num_gpus": 0, "test_file": "test_value_temperature.py"}, {"num_gpus": 0, "test_file": "test_rm_f1.py"}, {"num_gpus": 0, "test_file": "test_rm_gpqa.py"}, {"num_gpus": 0, "test_file": "test_rm_math.py"}, {"num_gpus": 0, "test_file": "test_rm_math_dapo.py"}, {"num_gpus": 0, "test_file": "test_rm_deepscaler.py"}, {"num_gpus": 0, "test_file": "test_sample.py"}, {"num_gpus": 0, "test_file": "test_agent_trajectory.py"}, {"num_gpus": 0, "test_file": "test_rollout_validation.py"}, {"num_gpus": 0, "test_file": "test_placement_group.py"}, {"num_gpus": 0, "test_file": "test_external_sglang_engines.py"}, {"num_gpus": 0, "test_file": "utils/test_hf_checkpoint_saver.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_rollout_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_runtime_hook_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_path_loading_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_generate_contracts.py"}]
info: [{"num_gpus": 0, "test_file": "test_megatron_argument_validation.py"}, {"num_gpus": 0, "test_file": "test_dp_schedule.py"}, {"num_gpus": 0, "test_file": "test_cp_utils.py"}, {"num_gpus": 0, "test_file": "test_metric_report.py"}, {"num_gpus": 0, "test_file": "test_metric_report_dist.py"}, {"num_gpus": 0, "test_file": "test_loss_cp_invariance.py"}, {"num_gpus": 0, "test_file": "test_value_temperature.py"}, {"num_gpus": 0, "test_file": "test_rm_f1.py"}, {"num_gpus": 0, "test_file": "test_rm_gpqa.py"}, {"num_gpus": 0, "test_file": "test_rm_math.py"}, {"num_gpus": 0, "test_file": "test_rm_math_dapo.py"}, {"num_gpus": 0, "test_file": "test_rm_deepscaler.py"}, {"num_gpus": 0, "test_file": "test_sample.py"}, {"num_gpus": 0, "test_file": "test_agent_trajectory.py"}, {"num_gpus": 0, "test_file": "test_rollout_validation.py"}, {"num_gpus": 0, "test_file": "test_placement_group.py"}, {"num_gpus": 0, "test_file": "test_external_sglang_engines.py"}, {"num_gpus": 0, "test_file": "test_sglang_control_plane_auth.py"}, {"num_gpus": 0, "test_file": "utils/test_hf_checkpoint_saver.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_rollout_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_runtime_hook_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_path_loading_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_generate_contracts.py"}]
defaults:
run:
working-directory: ${{ github.workspace }}
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/pr-test.yml.j2
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
{'test_file': 'test_rollout_validation.py', 'num_gpus': 0},
{'test_file': 'test_placement_group.py', 'num_gpus': 0},
{'test_file': 'test_external_sglang_engines.py', 'num_gpus': 0},
{'test_file': 'test_sglang_control_plane_auth.py', 'num_gpus': 0},
{'test_file': 'utils/test_hf_checkpoint_saver.py', 'num_gpus': 0},
{'test_file': 'plugin_contracts/test_plugin_rollout_contracts.py', 'num_gpus': 0},
{'test_file': 'plugin_contracts/test_plugin_runtime_hook_contracts.py', 'num_gpus': 0},
Expand Down
14 changes: 9 additions & 5 deletions slime/backends/sglang_utils/external.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

import requests

from slime.utils.http_utils import bearer_auth_headers

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -55,11 +57,11 @@ def external_engine_init_kwargs(info: ExternalEngineInfo) -> dict:
return init_kwargs


def get_server_info(url: str, timeout: float = 30.0) -> dict:
def get_server_info(url: str, timeout: float = 30.0, api_key: str | None = None) -> dict:
errors = []
for endpoint in ("/server_info", "/get_server_info"):
try:
response = requests.get(f"{url}{endpoint}", timeout=timeout)
response = requests.get(f"{url}{endpoint}", timeout=timeout, headers=bearer_auth_headers(api_key))
response.raise_for_status()
return response.json()
except Exception as exc:
Expand All @@ -76,13 +78,15 @@ def _infer_worker_type(server_info: dict) -> str:
return "regular"


def discover_external_engines(addrs: list[str], timeout: float = 30.0) -> list[ExternalEngineInfo]:
def discover_external_engines(
addrs: list[str], timeout: float = 30.0, api_key: str | None = None
) -> list[ExternalEngineInfo]:
infos = []
for addr in addrs:
url = normalize_external_engine_addr(addr)
parsed = urlparse(url)
assert parsed.hostname is not None and parsed.port is not None
server_info = get_server_info(url, timeout=timeout)
server_info = get_server_info(url, timeout=timeout, api_key=api_key)

pp_size = int(server_info.get("pp_size") or server_info.get("pipeline_parallel_size") or 1)
tp_size = int(server_info.get("tp_size") or server_info.get("tensor_parallel_size") or 1)
Expand Down Expand Up @@ -110,7 +114,7 @@ def apply_external_engine_info_to_args(args, logger=None) -> None:
if not addrs:
raise ValueError("apply_external_engine_info_to_args requires --rollout-external-engine-addrs.")

infos = discover_external_engines(addrs)
infos = discover_external_engines(addrs, api_key=getattr(args, "sglang_api_key", None))
if not infos:
raise ValueError("--rollout-external-engine-addrs did not contain any engines.")

Expand Down
22 changes: 12 additions & 10 deletions slime/backends/sglang_utils/server_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
from typing import Any

from slime.utils.http_utils import get, post
from slime.utils.http_utils import bearer_auth_headers, get, post

logger = logging.getLogger(__name__)

Expand All @@ -29,25 +29,27 @@ def num_requests_from_load(load: Any) -> int:
return (running if isinstance(running, int) else 0) + (waiting if isinstance(waiting, int) else 0)


async def _abort_server_once(url: str) -> None:
async def _abort_server_once(url: str, api_key: str | None = None) -> None:
try:
await post(f"{url}/abort_request", {"abort_all": True})
await post(f"{url}/abort_request", {"abort_all": True}, headers=bearer_auth_headers(api_key))
except Exception as e:
logger.warning(f"Failed to abort SGLang server at {url}: {e}")


async def _get_server_num_requests(url: str) -> int:
return num_requests_from_load(await get(f"{url}/v1/loads?include=core"))
async def _get_server_num_requests(url: str, api_key: str | None = None) -> int:
return num_requests_from_load(await get(f"{url}/v1/loads?include=core", headers=bearer_auth_headers(api_key)))


async def abort_server_until_idle(url: str, retry_interval: int = ABORT_RETRY_INTERVAL_SECONDS) -> None:
async def abort_server_until_idle(
url: str, retry_interval: int = ABORT_RETRY_INTERVAL_SECONDS, api_key: str | None = None
) -> None:
attempt = 1
while True:
logger.info(f"Abort request for SGLang server {url}")
await _abort_server_once(url)
await _abort_server_once(url, api_key=api_key)

try:
num_requests = await _get_server_num_requests(url)
num_requests = await _get_server_num_requests(url, api_key=api_key)
except Exception as e:
logger.warning(f"Failed to get SGLang server load from {url}: {e}")
return
Expand All @@ -63,5 +65,5 @@ async def abort_server_until_idle(url: str, retry_interval: int = ABORT_RETRY_IN
attempt += 1


async def abort_servers_until_idle(urls: list[str]) -> None:
await asyncio.gather(*(abort_server_until_idle(url) for url in urls))
async def abort_servers_until_idle(urls: list[str], api_key: str | None = None) -> None:
await asyncio.gather(*(abort_server_until_idle(url, api_key=api_key) for url in urls))
81 changes: 57 additions & 24 deletions slime/backends/sglang_utils/sglang_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from slime.backends.sglang_utils.external import get_server_info
from slime.ray.ray_actor import RayActor
from slime.utils.http_utils import get_host_info
from slime.utils.http_utils import bearer_auth_headers, get_host_info

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -80,8 +80,8 @@ def launch_server_process(server_args: ServerArgs) -> multiprocessing.Process:
def _wait_server_healthy(base_url, api_key, is_process_alive):
headers = {
"Content-Type": "application/json; charset=utf-8",
"Authorization": f"Bearer {api_key}",
}
headers.update(bearer_auth_headers(api_key) or {})

with requests.Session() as session:
while True:
Expand Down Expand Up @@ -161,6 +161,7 @@ def _format_v6_uri(addr):
self.node_rank = server_args_dict["node_rank"]
self.server_host = server_args_dict["host"] # with [] if ipv6
self.server_port = server_args_dict["port"]
self.server_api_key = server_args_dict.get("api_key")

if self.args.rollout_external:
self._init_external(server_args_dict, external_engine_need_check_fields=external_engine_need_check_fields)
Expand All @@ -178,7 +179,9 @@ def _sanity_check_server_args(actual_server_args, expect_server_args):
actual_value == expect_value
), f"{name=} {expect_value=} {actual_value=} {expect_server_args=} {actual_server_args=}"

actual_server_args = get_server_info(f"http://{self.server_host}:{self.server_port}")
actual_server_args = get_server_info(
f"http://{self.server_host}:{self.server_port}", api_key=self.server_api_key
)
_sanity_check_server_args(actual_server_args, expect_server_args)
self._register_to_router(expect_server_args)

Expand All @@ -193,16 +196,21 @@ def _register_to_router(self, server_args_dict):

if self.node_rank == 0 and self.router_ip and self.router_port:
worker_url = f"http://{self.server_host}:{self.server_port}"
headers = self._router_auth_headers()
if parse(sglang_router.__version__) <= parse("0.2.1"):
assert self.worker_type == "regular", "pd disaggregation is not supported in old router."
response = requests.post(
f"http://{self.router_ip}:{self.router_port}/add_worker?url={worker_url}",
headers=headers,
)
else:
payload = {
"url": worker_url,
"worker_type": self.worker_type,
}
# The router uses this key to authenticate the traffic it forwards to the worker.
if worker_api_key := server_args_dict.get("api_key"):
payload["api_key"] = worker_api_key
if self.worker_type == "prefill":
bootstrap_port = server_args_dict.get("disaggregation_bootstrap_port")
if bootstrap_port is None:
Expand All @@ -214,6 +222,7 @@ def _register_to_router(self, server_args_dict):
response = requests.post(
f"http://{self.router_ip}:{self.router_port}/workers",
json=payload,
headers=headers,
)
response.raise_for_status()

Expand All @@ -231,14 +240,29 @@ def _make_request(self, endpoint: str, payload: dict | None = None):
return

url = f"http://{self.server_host}:{self.server_port}/{endpoint}"
response = requests.post(url, json=payload or {})
response = requests.post(url, json=payload or {}, headers=self._server_auth_headers())
try:
response.raise_for_status()
except requests.exceptions.HTTPError as e:
e.add_note(f"{response.text=}")
raise
return response.json()

def _server_auth_headers(self):
return bearer_auth_headers(self.server_api_key)

def _router_auth_headers(self):
return bearer_auth_headers(getattr(self.args, "router_api_key", None))

def _post_server_control(self, endpoint: str, payload: dict | None = None):
response = requests.post(
f"http://{self.server_host}:{self.server_port}/{endpoint}",
json=payload or {},
headers=self._server_auth_headers(),
)
response.raise_for_status()
return response

def health_generate(self, timeout: float = 5.0) -> bool:
"""Run /health_generate on the underlying SGLang HTTP server.

Expand All @@ -256,6 +280,7 @@ def health_generate(self, timeout: float = 5.0) -> bool:

response = requests.get(
f"http://{self.server_host}:{self.server_port}/health_generate",
headers=self._server_auth_headers(),
timeout=timeout,
)
response.raise_for_status()
Expand Down Expand Up @@ -293,13 +318,22 @@ def flush_cache(self):
# flush cache will not return status_code 200 when there are pending requests
for _ in range(60):
try:
response = requests.get(f"http://{self.server_host}:{self.server_port}/flush_cache")
response = requests.get(
f"http://{self.server_host}:{self.server_port}/flush_cache",
headers=self._server_auth_headers(),
)
if response.status_code == 200:
break
# An auth misconfiguration will never recover by retrying, so surface it
# immediately instead of spinning for the full timeout.
if response.status_code in (401, 403):
response.raise_for_status()
logger.info(f"Error flushing cache: HTTP {response.status_code} {response.text!r}")
time.sleep(1)
except NewConnectionError as e:
raise e
except requests.exceptions.HTTPError:
raise
except Exception as e:
logger.info(f"Error flushing cache: {e}")
time.sleep(1)
Expand All @@ -319,22 +353,29 @@ def shutdown(self):
logger.info(f"Shutdown engine {self.server_host}:{self.server_port}...")
if self.worker_type != "encoder" and self.node_rank == 0:
worker_url = f"http://{self.server_host}:{self.server_port}"
headers = self._router_auth_headers()
response = None
if parse(sglang_router.__version__) <= parse("0.2.1"):
response = requests.post(
f"http://{self.router_ip}:{self.router_port}/remove_worker?url=http://{self.server_host}:{self.server_port}"
f"http://{self.router_ip}:{self.router_port}/remove_worker?url=http://{self.server_host}:{self.server_port}",
headers=headers,
)
elif parse(sglang_router.__version__) < parse("0.3.0"):
worker_url = quote(worker_url, safe="")
response = requests.delete(f"http://{self.router_ip}:{self.router_port}/workers/{worker_url}")
response = requests.delete(
f"http://{self.router_ip}:{self.router_port}/workers/{worker_url}", headers=headers
)
else:
try:
all_workers = requests.get(f"http://{self.router_ip}:{self.router_port}/workers").json()["workers"]
all_workers = requests.get(
f"http://{self.router_ip}:{self.router_port}/workers", headers=headers
).json()["workers"]
for worker in all_workers:
if worker["url"] == worker_url:
worker_id = worker["id"]
response = requests.delete(
f"http://{self.router_ip}:{self.router_port}/workers/{worker_id}"
f"http://{self.router_ip}:{self.router_port}/workers/{worker_id}",
headers=headers,
)
break
else:
Expand All @@ -350,7 +391,7 @@ def get_weight_version(self):
if self.node_rank != 0:
return
url = f"http://{self.server_host}:{self.server_port}/get_weight_version"
response = requests.get(url)
response = requests.get(url, headers=self._server_auth_headers())
response.raise_for_status()
return response.json()["weight_version"]

Expand Down Expand Up @@ -469,14 +510,10 @@ def update_weights_from_distributed(
)

def pause_generation(self):
response = requests.post(f"http://{self.server_host}:{self.server_port}/pause_generation", json={})
response.raise_for_status()
return response
return self._post_server_control("pause_generation")

def continue_generation(self):
response = requests.post(f"http://{self.server_host}:{self.server_port}/continue_generation", json={})
response.raise_for_status()
return response
return self._post_server_control("continue_generation")

def post_process_weights(
self,
Expand Down Expand Up @@ -511,9 +548,9 @@ def start_profile(
with_stack: bool | None = None,
record_shapes: bool | None = None,
):
response = requests.post(
f"http://{self.server_host}:{self.server_port}/start_profile",
json={
return self._post_server_control(
"start_profile",
{
"output_dir": output_dir,
"start_step": start_step,
"num_steps": num_steps,
Expand All @@ -523,13 +560,9 @@ def start_profile(
"record_shapes": record_shapes,
},
)
response.raise_for_status()
return response

def stop_profile(self):
response = requests.post(f"http://{self.server_host}:{self.server_port}/stop_profile", json={})
response.raise_for_status()
return response
return self._post_server_control("stop_profile")

def simulate_crash(self):
if self.args.rollout_external or not getattr(self, "process", None):
Expand Down
13 changes: 9 additions & 4 deletions slime/rollout/sglang_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from slime.utils.async_utils import run
from slime.utils.data import Dataset
from slime.utils.eval_config import EvalDatasetConfig
from slime.utils.http_utils import get, get_rollout_num_engines, post
from slime.utils.http_utils import bearer_auth_headers, get, get_rollout_num_engines, post
from slime.utils.misc import SingletonMeta, load_function
from slime.utils.processing_utils import (
build_processor_kwargs,
Expand Down Expand Up @@ -355,14 +355,19 @@ async def abort(args: Namespace, rollout_id: int) -> list[list[Sample]]:
assert not state.aborted
state.aborted = True

router_headers = bearer_auth_headers(getattr(args, "router_api_key", None))
if parse(sglang_router.__version__) <= parse("0.2.1"):
response = await get(f"http://{args.sglang_router_ip}:{args.sglang_router_port}/list_workers")
response = await get(
f"http://{args.sglang_router_ip}:{args.sglang_router_port}/list_workers", headers=router_headers
)
urls = response["urls"]
else:
response = await get(f"http://{args.sglang_router_ip}:{args.sglang_router_port}/workers")
response = await get(
f"http://{args.sglang_router_ip}:{args.sglang_router_port}/workers", headers=router_headers
)
urls = [worker["url"] for worker in response["workers"]]

await abort_servers_until_idle(urls)
await abort_servers_until_idle(urls, api_key=getattr(args, "sglang_api_key", None))

# make sure all the pending tasks are finished
count = 0
Expand Down
Loading
Loading