diff --git a/scripts/regenerate_train_data.py b/scripts/regenerate_train_data.py index d38392b69..d41edc8d2 100644 --- a/scripts/regenerate_train_data.py +++ b/scripts/regenerate_train_data.py @@ -131,9 +131,22 @@ def parse_arguments(): nargs="+", help="Server address and port for sglang model server", ) + server_group.add_argument( + "--api-key", + type=str, + default="None", + help="API key for the endpoint", + ) return parser.parse_args() +def normalize_server_address(address: str) -> str: + """Add http:// prefix if no protocol is specified.""" + if not address.startswith(("http://", "https://")): + return f"http://{address}" + return address + + def get_random_reasoning_effort() -> str: """Get a random reasoning effort level for the model with weighted probabilities.""" # usage example: https://huggingface.co/openai/gpt-oss-20b/discussions/28 @@ -198,7 +211,7 @@ def call_sglang( max_tokens=None, ) -> str: """Send a batch of prompts to sglang /v1/completions.""" - client = OpenAI(base_url=f"http://{server_address}/v1", api_key="None") + client = OpenAI(base_url=f"{server_address}/v1", api_key=args.api_key) messages = data["conversations"] regenerated_messages = [] @@ -248,6 +261,11 @@ def main(): # Parse command line arguments args = parse_arguments() + # Normalize server addresses to ensure they have protocol prefix + args.server_address = [ + normalize_server_address(addr) for addr in args.server_address + ] + # Validate parameters if not (0.0 <= args.temperature <= 1.0): raise ValueError("Temperature must be between 0.0 and 1.0") @@ -261,6 +279,7 @@ def main(): print(f" Concurrency: {args.concurrency}") print(f" Temperature: {args.temperature}") print(f" API URL: {args.server_address}") + print(f" API key set: {args.api_key != "None"}") print(f" Input file: {args.input_file_path}") print(f" Output file: {args.output_file_path}") print(f" Resume mode: {args.resume}")