Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion scripts/regenerate_train_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,22 @@ def parse_arguments():
nargs="+",
help="Server address and port for sglang model server",
)
server_group.add_argument(
"--api-key",
type=str,
default="None",
help="API key for the endpoint",
)
return parser.parse_args()


def normalize_server_address(address: str) -> str:
    """Return *address* with an ``http://`` prefix if it has no protocol.

    Surrounding whitespace is removed before the check. The protocol test
    is case-insensitive, so an address such as ``HTTPS://host`` is treated
    as already having a protocol rather than being prefixed a second time.

    Args:
        address: Server address, e.g. ``"localhost:30000"`` or
            ``"https://example.com"``.

    Returns:
        The whitespace-stripped address, unchanged if it already starts
        with ``http://`` or ``https://`` (in any letter case), otherwise
        prefixed with ``http://``.
    """
    address = address.strip()
    # URL schemes are case-insensitive, so compare against a lowercased copy
    # while returning the caller's original spelling.
    if address.lower().startswith(("http://", "https://")):
        return address
    return f"http://{address}"


def get_random_reasoning_effort() -> str:
"""Get a random reasoning effort level for the model with weighted probabilities."""
# usage example: https://huggingface.co/openai/gpt-oss-20b/discussions/28
Expand Down Expand Up @@ -198,7 +211,7 @@ def call_sglang(
max_tokens=None,
) -> str:
"""Send a batch of prompts to sglang /v1/completions."""
client = OpenAI(base_url=f"http://{server_address}/v1", api_key="None")
client = OpenAI(base_url=f"{server_address}/v1", api_key=args.api_key)

messages = data["conversations"]
regenerated_messages = []
Expand Down Expand Up @@ -248,6 +261,11 @@ def main():
# Parse command line arguments
args = parse_arguments()

# Normalize server addresses to ensure they have protocol prefix
args.server_address = [
normalize_server_address(addr) for addr in args.server_address
]

# Validate parameters
if not (0.0 <= args.temperature <= 1.0):
raise ValueError("Temperature must be between 0.0 and 1.0")
Expand All @@ -261,6 +279,7 @@ def main():
print(f" Concurrency: {args.concurrency}")
print(f" Temperature: {args.temperature}")
print(f" API URL: {args.server_address}")
print(f" API key set: {args.api_key != "None"}")
print(f" Input file: {args.input_file_path}")
print(f" Output file: {args.output_file_path}")
print(f" Resume mode: {args.resume}")
Expand Down