diff --git a/scripts/benchmark_batch_decoding.py b/scripts/benchmark_batch_decoding.py new file mode 100644 index 00000000..83183f09 --- /dev/null +++ b/scripts/benchmark_batch_decoding.py @@ -0,0 +1,151 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import json +import os +import platform +import statistics +import time +from collections.abc import Callable +from pathlib import Path +from typing import Any + +import tiktoken + + +def _tool_json(i: int) -> str: + return json.dumps( + { + "tool": "lookup_order", + "arguments": { + "order_id": f"ord_{i:06d}", + "urgent": i % 13 == 0, + }, + }, + separators=(",", ":"), + ) + + +def make_workloads() -> dict[str, list[str]]: + return { + "tiny_10k": [f"hello world {i}" for i in range(10_000)], + "chat_messages_10k": [ + f"user: customer {i} asked whether order ord_{i:06d} can be rerouted before " + "the warehouse batch closes." + for i in range(10_000) + ], + "tool_json_5k": [_tool_json(i) for i in range(5_000)], + "medium_1k": [ + ("The quick brown fox jumps over the lazy dog. " * 20) + str(i) + for i in range(1_000) + ], + "long_100": [ + ("The quick brown fox jumps over the lazy dog. " * 2_000) + str(i) + for i in range(100) + ], + "mixed_2k": [ + ("short " + str(i)) + if i % 2 + else ("The quick brown fox jumps over the lazy dog. " * 200) + str(i) + for i in range(2_000) + ], + } + + +def measure( + name: str, + fn: Callable[[list[list[int]]], list[str] | list[bytes]], + batch: list[list[int]], + reps: int, + warmups: int, +) -> dict[str, Any]: + for _ in range(warmups): + fn(batch[: min(len(batch), 128)]) + + times = [] + for _ in range(reps): + start = time.perf_counter() + out = fn(batch) + elapsed = time.perf_counter() - start + if len(out) != len(batch): + raise RuntimeError(f"{name}: expected {len(batch)} outputs, got {len(out)}") + times.append(elapsed) + + best = min(times) + median = statistics.median(times) + p95 = sorted(times)[max(0, int(len(times) * 0.95) - 1)] + return { + "best_s": best, + "median_s": median, + "p95_s": p95, + } + + +def main() -> None: + parser = argparse.ArgumentParser(description="Benchmark tiktoken batch decoding workloads.") + parser.add_argument("--encoding", default="cl100k_base") + parser.add_argument("--num-threads", type=int, default=8) + parser.add_argument("--reps", type=int, default=10) + parser.add_argument("--warmups", type=int, default=2) + parser.add_argument("--json-output", type=Path) + args = parser.parse_args() + + enc = tiktoken.get_encoding(args.encoding) + workloads = make_workloads() + benchmarks: dict[str, Callable[[list[list[int]]], list[str] | list[bytes]]] = { + "decode_batch": lambda batch: enc.decode_batch(batch, num_threads=args.num_threads), + "decode_bytes_batch": lambda batch: enc.decode_bytes_batch( + batch, num_threads=args.num_threads + ), + } + + result: dict[str, Any] = { + "environment": { + "python": platform.python_version(), + "platform": platform.platform(), + "machine": platform.machine(), + "cpu_count": os.cpu_count(), + "tiktoken_file": tiktoken.__file__, + "encoding": args.encoding, + "num_threads": args.num_threads, + "reps": args.reps, + "warmups": args.warmups, + }, + "results": {}, + } + + for workload_name, docs in workloads.items(): + batch = enc.encode_ordinary_batch(docs, num_threads=args.num_threads) + num_bytes = sum(len(doc.encode("utf-8")) for doc in docs) + num_tokens = sum(map(len, batch)) + workload_result: dict[str, Any] = { + "documents": len(batch), + "bytes": num_bytes, + "tokens": num_tokens, + "avg_tokens": num_tokens / len(batch), + "benchmarks": {}, + } + print( + f"{workload_name}: docs={len(batch)} bytes={num_bytes} " + f"tokens={num_tokens} avg_tokens={workload_result['avg_tokens']:.1f}" + ) + for bench_name, bench_fn in benchmarks.items(): + metrics = measure(bench_name, bench_fn, batch, args.reps, args.warmups) + metrics["docs_per_s"] = len(batch) / metrics["best_s"] + metrics["mb_per_s"] = num_bytes / 1_000_000 / metrics["best_s"] + workload_result["benchmarks"][bench_name] = metrics + print( + f" {bench_name}: best={metrics['best_s'] * 1000:.3f}ms " + f"median={metrics['median_s'] * 1000:.3f}ms " + f"docs/s={metrics['docs_per_s']:.0f} " + f"MB/s={metrics['mb_per_s']:.2f}" + ) + result["results"][workload_name] = workload_result + + if args.json_output is not None: + args.json_output.write_text(json.dumps(result, indent=2) + "\n") + + +if __name__ == "__main__": + main() diff --git a/scripts/benchmark_batch_encoding.py b/scripts/benchmark_batch_encoding.py new file mode 100644 index 00000000..bb9cf304 --- /dev/null +++ b/scripts/benchmark_batch_encoding.py @@ -0,0 +1,156 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import json +import os +import platform +import statistics +import time +from collections.abc import Callable +from pathlib import Path +from typing import Any + +import tiktoken + + +def _tool_json(i: int) -> str: + return json.dumps( + { + "tool": "lookup_order", + "arguments": { + "order_id": f"ord_{i:06d}", + "include": ["status", "refunds", "shipments"], + "urgent": i % 13 == 0, + }, + }, + separators=(",", ":"), + ) + + +def _chat_message(i: int) -> str: + roles = ["system", "user", "assistant", "tool"] + return ( + f"{roles[i % len(roles)]}: customer {i} asked whether order ord_{i:06d} " + "can be rerouted before the warehouse batch closes." + ) + + +def _rag_snippet(i: int) -> str: + return ( + f"doc={i} title=Returns policy. Customers can request a return within 30 days " + "when the item is unused, in original packaging, and accompanied by a receipt. " + f"Region shard {i % 17}." + ) + + +def make_workloads() -> dict[str, list[str]]: + return { + "tiny_10k": [f"hello world {i}" for i in range(10_000)], + "chat_messages_10k": [_chat_message(i) for i in range(10_000)], + "tool_json_5k": [_tool_json(i) for i in range(5_000)], + "rag_snippets_5k": [_rag_snippet(i) for i in range(5_000)], + "medium_1k": [("The quick brown fox jumps over the lazy dog. " * 20) + str(i) for i in range(1_000)], + } + + +def measure( + name: str, + fn: Callable[[list[str]], list[list[int]]], + docs: list[str], + reps: int, + warmups: int, +) -> dict[str, Any]: + for _ in range(warmups): + fn(docs[: min(len(docs), 128)]) + + times = [] + token_count = None + for _ in range(reps): + start = time.perf_counter() + out = fn(docs) + elapsed = time.perf_counter() - start + if len(out) != len(docs): + raise RuntimeError(f"{name}: expected {len(docs)} outputs, got {len(out)}") + current_token_count = sum(map(len, out)) + if token_count is None: + token_count = current_token_count + elif current_token_count != token_count: + raise RuntimeError(f"{name}: token count changed between runs") + times.append(elapsed) + + best = min(times) + median = statistics.median(times) + p95 = sorted(times)[max(0, int(len(times) * 0.95) - 1)] + return { + "best_s": best, + "median_s": median, + "p95_s": p95, + "tokens": token_count, + } + + +def main() -> None: + parser = argparse.ArgumentParser(description="Benchmark tiktoken batch encoding workloads.") + parser.add_argument("--encoding", default="cl100k_base") + parser.add_argument("--num-threads", type=int, default=8) + parser.add_argument("--reps", type=int, default=10) + parser.add_argument("--warmups", type=int, default=2) + parser.add_argument("--json-output", type=Path) + args = parser.parse_args() + + enc = tiktoken.get_encoding(args.encoding) + workloads = make_workloads() + benchmarks: dict[str, Callable[[list[str]], list[list[int]]]] = { + "encode_batch": lambda docs: enc.encode_batch(docs, num_threads=args.num_threads), + "encode_ordinary_batch": lambda docs: enc.encode_ordinary_batch( + docs, num_threads=args.num_threads + ), + } + + result: dict[str, Any] = { + "environment": { + "python": platform.python_version(), + "platform": platform.platform(), + "machine": platform.machine(), + "cpu_count": os.cpu_count(), + "tiktoken_file": tiktoken.__file__, + "encoding": args.encoding, + "num_threads": args.num_threads, + "reps": args.reps, + "warmups": args.warmups, + }, + "results": {}, + } + + for workload_name, docs in workloads.items(): + num_bytes = sum(len(doc.encode("utf-8")) for doc in docs) + workload_result: dict[str, Any] = { + "documents": len(docs), + "bytes": num_bytes, + "avg_chars": sum(map(len, docs)) / len(docs), + "benchmarks": {}, + } + print( + f"{workload_name}: docs={len(docs)} bytes={num_bytes} " + f"avg_chars={workload_result['avg_chars']:.1f}" + ) + for bench_name, bench_fn in benchmarks.items(): + metrics = measure(bench_name, bench_fn, docs, args.reps, args.warmups) + metrics["docs_per_s"] = len(docs) / metrics["best_s"] + metrics["mb_per_s"] = num_bytes / 1_000_000 / metrics["best_s"] + workload_result["benchmarks"][bench_name] = metrics + print( + f" {bench_name}: best={metrics['best_s'] * 1000:.3f}ms " + f"median={metrics['median_s'] * 1000:.3f}ms " + f"docs/s={metrics['docs_per_s']:.0f} " + f"MB/s={metrics['mb_per_s']:.2f}" + ) + result["results"][workload_name] = workload_result + + if args.json_output is not None: + args.json_output.write_text(json.dumps(result, indent=2) + "\n") + + +if __name__ == "__main__": + main() diff --git a/scripts/benchmark_special_encoding.py b/scripts/benchmark_special_encoding.py new file mode 100644 index 00000000..b947502d --- /dev/null +++ b/scripts/benchmark_special_encoding.py @@ -0,0 +1,161 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import json +import os +import platform +import statistics +import time +from collections.abc import Callable +from pathlib import Path +from typing import Any + +import tiktoken + + +def make_single_workloads() -> dict[str, str]: + return { + "tiny": "hello world", + "chat": "user: customer asked whether order can be rerouted before close.", + "medium": "The quick brown fox jumps over the lazy dog. " * 100, + "long": "The quick brown fox jumps over the lazy dog. " * 5_000, + } + + +def make_batch_workloads() -> dict[str, list[str]]: + return { + "tiny_10k": [f"hello world {i}" for i in range(10_000)], + "chat_10k": [ + f"user: customer {i} asked whether order ord_{i:06d} can be rerouted before close." + for i in range(10_000) + ], + "tool_json_5k": [ + json.dumps( + { + "tool": "lookup_order", + "arguments": { + "order_id": f"ord_{i:06d}", + "urgent": False, + }, + }, + separators=(",", ":"), + ) + for i in range(5_000) + ], + } + + +def measure_single( + fn: Callable[[str], list[int]], text: str, reps: int, warmups: int +) -> dict[str, Any]: + for _ in range(warmups): + fn(text) + + times = [] + out = [] + for _ in range(reps): + start = time.perf_counter() + out = fn(text) + times.append(time.perf_counter() - start) + + return { + "best_s": min(times), + "median_s": statistics.median(times), + "tokens": len(out), + } + + +def measure_batch( + fn: Callable[[list[str]], list[list[int]]], docs: list[str], reps: int, warmups: int +) -> dict[str, Any]: + for _ in range(warmups): + fn(docs[: min(len(docs), 128)]) + + times = [] + out: list[list[int]] = [] + for _ in range(reps): + start = time.perf_counter() + out = fn(docs) + times.append(time.perf_counter() - start) + + return { + "best_s": min(times), + "median_s": statistics.median(times), + "tokens": sum(map(len, out)), + } + + +def main() -> None: + parser = argparse.ArgumentParser(description="Benchmark tiktoken special-token encode paths.") + parser.add_argument("--encoding", default="o200k_harmony") + parser.add_argument("--single-reps", type=int, default=2_000) + parser.add_argument("--batch-reps", type=int, default=10) + parser.add_argument("--warmups", type=int, default=5) + parser.add_argument("--json-output", type=Path) + args = parser.parse_args() + + enc = tiktoken.get_encoding(args.encoding) + result: dict[str, Any] = { + "environment": { + "python": platform.python_version(), + "platform": platform.platform(), + "machine": platform.machine(), + "cpu_count": os.cpu_count(), + "tiktoken_file": tiktoken.__file__, + "encoding": args.encoding, + "special_tokens": len(enc.special_tokens_set), + }, + "single": {}, + "batch": {}, + } + + single_benchmarks: dict[str, Callable[[str], list[int]]] = { + "encode": enc.encode, + "encode_disallowed_special_empty": lambda text: enc.encode(text, disallowed_special=()), + "encode_ordinary": enc.encode_ordinary, + } + for workload_name, text in make_single_workloads().items(): + print(f"{workload_name}: bytes={len(text.encode('utf-8'))}") + workload_result = {} + reps = args.single_reps if len(text) < 1_000 else max(100, args.single_reps // 20) + for bench_name, bench_fn in single_benchmarks.items(): + metrics = measure_single(bench_fn, text, reps, args.warmups) + workload_result[bench_name] = metrics + print( + f" {bench_name}: best={metrics['best_s'] * 1_000_000:.3f}us " + f"median={metrics['median_s'] * 1_000_000:.3f}us " + f"tokens={metrics['tokens']}" + ) + result["single"][workload_name] = workload_result + + batch_benchmarks: dict[str, Callable[[list[str]], list[list[int]]]] = { + "encode_batch": enc.encode_batch, + "encode_batch_disallowed_special_empty": lambda docs: enc.encode_batch( + docs, disallowed_special=() + ), + "encode_ordinary_batch": enc.encode_ordinary_batch, + } + for workload_name, docs in make_batch_workloads().items(): + num_bytes = sum(len(doc.encode("utf-8")) for doc in docs) + print(f"{workload_name}: docs={len(docs)} bytes={num_bytes}") + workload_result = {} + for bench_name, bench_fn in batch_benchmarks.items(): + metrics = measure_batch(bench_fn, docs, args.batch_reps, args.warmups) + metrics["docs_per_s"] = len(docs) / metrics["best_s"] + metrics["mb_per_s"] = num_bytes / 1_000_000 / metrics["best_s"] + workload_result[bench_name] = metrics + print( + f" {bench_name}: best={metrics['best_s'] * 1000:.3f}ms " + f"median={metrics['median_s'] * 1000:.3f}ms " + f"docs/s={metrics['docs_per_s']:.0f} " + f"MB/s={metrics['mb_per_s']:.2f}" + ) + result["batch"][workload_name] = workload_result + + if args.json_output is not None: + args.json_output.write_text(json.dumps(result, indent=2) + "\n") + + +if __name__ == "__main__": + main() diff --git a/scripts/benchmark_token_decoding.py b/scripts/benchmark_token_decoding.py new file mode 100644 index 00000000..b3ea20e9 --- /dev/null +++ b/scripts/benchmark_token_decoding.py @@ -0,0 +1,153 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import json +import os +import platform +import statistics +import time +from collections.abc import Callable +from pathlib import Path +from typing import Any + +import tiktoken + + +def _tool_json(i: int) -> str: + return json.dumps( + { + "tool": "lookup_order", + "arguments": { + "order_id": f"ord_{i:06d}", + "urgent": i % 13 == 0, + "destination": f"warehouse-{i % 17}", + }, + }, + separators=(",", ":"), + ) + + +def make_workloads() -> dict[str, str]: + chat = "\n".join( + ( + f"user: customer {i} asked whether order ord_{i:06d} can be rerouted before " + "the warehouse batch closes.\n" + f"assistant: I checked the carrier window and found option {i % 5}." + ) + for i in range(5_000) + ) + unicode_notes = "\n".join( + f"{i}: 我非常渴望与人工智能一起工作. நடிகர் சூர்யா. Ġ除." + for i in range(5_000) + ) + return { + "tiny_lines_20k": "\n".join(f"hello world {i}" for i in range(20_000)), + "chat_transcript_5k": chat, + "tool_json_10k": "\n".join(_tool_json(i) for i in range(10_000)), + "unicode_notes_5k": unicode_notes, + "long_doc": ("The quick brown fox jumps over the lazy dog. " * 40_000), + } + + +def _validate_decode_tokens_bytes(out: Any, tokens: list[int]) -> None: + if len(out) != len(tokens): + raise RuntimeError(f"expected {len(tokens)} token byte chunks, got {len(out)}") + + +def _validate_decode_with_offsets(out: Any, tokens: list[int]) -> None: + text, offsets = out + if not isinstance(text, str): + raise RuntimeError("decode_with_offsets returned non-string text") + if len(offsets) != len(tokens): + raise RuntimeError(f"expected {len(tokens)} offsets, got {len(offsets)}") + + +def measure( + name: str, + fn: Callable[[list[int]], Any], + tokens: list[int], + reps: int, + warmups: int, +) -> dict[str, Any]: + for _ in range(warmups): + fn(tokens) + + times = [] + for _ in range(reps): + start = time.perf_counter() + out = fn(tokens) + elapsed = time.perf_counter() - start + if name == "decode_tokens_bytes": + _validate_decode_tokens_bytes(out, tokens) + elif name == "decode_with_offsets": + _validate_decode_with_offsets(out, tokens) + times.append(elapsed) + + best = min(times) + median = statistics.median(times) + p95 = sorted(times)[max(0, int(len(times) * 0.95) - 1)] + return { + "best_s": best, + "median_s": median, + "p95_s": p95, + } + + +def main() -> None: + parser = argparse.ArgumentParser(description="Benchmark tiktoken token decoding workloads.") + parser.add_argument("--encoding", default="cl100k_base") + parser.add_argument("--reps", type=int, default=10) + parser.add_argument("--warmups", type=int, default=2) + parser.add_argument("--json-output", type=Path) + args = parser.parse_args() + + enc = tiktoken.get_encoding(args.encoding) + workloads = make_workloads() + benchmarks: dict[str, Callable[[list[int]], Any]] = { + "decode_tokens_bytes": enc.decode_tokens_bytes, + "decode_with_offsets": enc.decode_with_offsets, + } + + result: dict[str, Any] = { + "environment": { + "python": platform.python_version(), + "platform": platform.platform(), + "machine": platform.machine(), + "cpu_count": os.cpu_count(), + "tiktoken_file": tiktoken.__file__, + "encoding": args.encoding, + "reps": args.reps, + "warmups": args.warmups, + }, + "results": {}, + } + + for workload_name, text in workloads.items(): + tokens = enc.encode_ordinary(text) + num_bytes = len(text.encode("utf-8")) + workload_result: dict[str, Any] = { + "bytes": num_bytes, + "tokens": len(tokens), + "benchmarks": {}, + } + print(f"{workload_name}: bytes={num_bytes} tokens={len(tokens)}") + for bench_name, bench_fn in benchmarks.items(): + metrics = measure(bench_name, bench_fn, tokens, args.reps, args.warmups) + metrics["tokens_per_s"] = len(tokens) / metrics["best_s"] + metrics["mb_per_s"] = num_bytes / 1_000_000 / metrics["best_s"] + workload_result["benchmarks"][bench_name] = metrics + print( + f" {bench_name}: best={metrics['best_s'] * 1000:.3f}ms " + f"median={metrics['median_s'] * 1000:.3f}ms " + f"tokens/s={metrics['tokens_per_s']:.0f} " + f"MB/s={metrics['mb_per_s']:.2f}" + ) + result["results"][workload_name] = workload_result + + if args.json_output is not None: + args.json_output.write_text(json.dumps(result, indent=2) + "\n") + + +if __name__ == "__main__": + main() diff --git a/src/lib.rs b/src/lib.rs index ea54eac8..cf711f76 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,6 @@ use std::collections::HashSet; use std::num::NonZeroU64; +use std::sync::OnceLock; use std::thread; use fancy_regex::Regex; @@ -322,12 +323,39 @@ pub struct CoreBPE { special_tokens_encoder: HashMap, decoder: HashMap>, special_tokens_decoder: HashMap>, + special_regex_pattern: String, regex_tls: Vec, - special_regex_tls: Vec, - sorted_token_bytes: Vec>, + special_regex_tls: OnceLock>, + #[allow(dead_code)] + sorted_token_bytes: OnceLock>>, + token_bytes_by_first_byte: OnceLock>>>, } impl CoreBPE { + #[allow(dead_code)] + fn sorted_token_bytes(&self) -> &[Vec] { + self.sorted_token_bytes.get_or_init(|| { + let mut sorted_token_bytes: Vec> = self.encoder.keys().cloned().collect(); + sorted_token_bytes.sort(); + sorted_token_bytes + }) + } + + fn token_bytes_by_first_byte(&self) -> &[Vec>] { + self.token_bytes_by_first_byte.get_or_init(|| { + let mut groups = vec![Vec::new(); 256]; + for token_bytes in self.encoder.keys() { + if let Some(&first_byte) = token_bytes.first() { + groups[first_byte as usize].push(token_bytes.clone()); + } + } + for group in &mut groups { + group.sort(); + } + groups + }) + } + fn _get_tl_regex(&self) -> &Regex { // See performance notes above for what this is about // It's also a little janky, please make a better version of it! @@ -336,7 +364,14 @@ impl CoreBPE { } fn _get_tl_special_regex(&self) -> &Regex { - &self.special_regex_tls[hash_current_thread() % MAX_NUM_THREADS] + let special_regex_tls = self.special_regex_tls.get_or_init(|| { + let special_regex = Regex::new(&self.special_regex_pattern) + .expect("escaped special token regex should compile"); + (0..MAX_NUM_THREADS) + .map(|_| special_regex.clone()) + .collect() + }); + &special_regex_tls[hash_current_thread() % MAX_NUM_THREADS] } /// Decodes tokens into a list of bytes. @@ -511,16 +546,15 @@ impl CoreBPE { // This is the easy bit. Just find all single tokens that start with unstable_bytes // (including tokens that exactly match unstable_bytes) // Separating this from the loop below helps with performance in a common case. - let mut point = self - .sorted_token_bytes - .partition_point(|x| x.as_slice() < unstable_bytes.as_slice()); - while point < self.sorted_token_bytes.len() - && self.sorted_token_bytes[point].starts_with(&unstable_bytes) - { - completions.insert(vec![ - self.encoder[self.sorted_token_bytes[point].as_slice()], - ]); - point += 1; + let token_bytes_by_first_byte = self.token_bytes_by_first_byte(); + if let Some(&first_byte) = unstable_bytes.first() { + let token_bytes = &token_bytes_by_first_byte[first_byte as usize]; + let mut point = + token_bytes.partition_point(|x| x.as_slice() < unstable_bytes.as_slice()); + while point < token_bytes.len() && token_bytes[point].starts_with(&unstable_bytes) { + completions.insert(vec![self.encoder[token_bytes[point].as_slice()]]); + point += 1; + } } // Now apply even more brute force. At every (other) possible position for the straddling @@ -529,44 +563,43 @@ impl CoreBPE { for i in 1..unstable_bytes.len() { let prefix = &unstable_bytes[..i]; let suffix = &unstable_bytes[i..]; - let mut point = self - .sorted_token_bytes - .partition_point(|x| x.as_slice() < suffix); // TODO: Perf optimisation if suffix starts with " "? - while point < self.sorted_token_bytes.len() - && self.sorted_token_bytes[point].starts_with(suffix) - { - let possibility = [prefix, self.sorted_token_bytes[point].as_slice()].concat(); - let encoded = match std::str::from_utf8(&possibility) { - // Morally, this is byte_pair_encode(&possibility, &self.encoder) - // But we might have introduced a regex split which would prevent merges. - // (particularly possible in the presence of unstable regex splits) - // So convert to UTF-8 and do regex splitting. - // E.g. with cl100k_base " !" gets split to " " + " !", - // but byte_pair_encode(" !") != byte_pair_encode(" ") - Ok(s) => self.encode_ordinary(s), - - // Technically, whether or not this arm is correct depends on whether there - // would be a regex split before the UTF-8 truncation point. - // Probably niche enough that no one will ever notice (after all, people didn't - // notice all the big holes in the previous unstable token implementation) - Err(_) => byte_pair_encode(&possibility, &self.encoder), - // Something like the following is intriguing but incorrect: - // Err(e) => self.encode_ordinary(unsafe { - // std::str::from_utf8_unchecked(&possibility[..e.valid_up_to()]) - // }), - }; - let mut seq = Vec::new(); - let mut seq_len = 0; - for token in encoded { - seq.push(token); - seq_len += self.decoder[&token].len(); - if seq_len >= unstable_bytes.len() { - break; + if let Some(&first_byte) = suffix.first() { + let token_bytes = &token_bytes_by_first_byte[first_byte as usize]; + let mut point = token_bytes.partition_point(|x| x.as_slice() < suffix); + while point < token_bytes.len() && token_bytes[point].starts_with(suffix) { + let possibility = [prefix, token_bytes[point].as_slice()].concat(); + let encoded = match std::str::from_utf8(&possibility) { + // Morally, this is byte_pair_encode(&possibility, &self.encoder) + // But we might have introduced a regex split which would prevent merges. + // (particularly possible in the presence of unstable regex splits) + // So convert to UTF-8 and do regex splitting. + // E.g. with cl100k_base " !" gets split to " " + " !", + // but byte_pair_encode(" !") != byte_pair_encode(" ") + Ok(s) => self.encode_ordinary(s), + + // Technically, whether or not this arm is correct depends on whether there + // would be a regex split before the UTF-8 truncation point. + // Probably niche enough that no one will ever notice (after all, people didn't + // notice all the big holes in the previous unstable token implementation) + Err(_) => byte_pair_encode(&possibility, &self.encoder), + // Something like the following is intriguing but incorrect: + // Err(e) => self.encode_ordinary(unsafe { + // std::str::from_utf8_unchecked(&possibility[..e.valid_up_to()]) + // }), + }; + let mut seq = Vec::new(); + let mut seq_len = 0; + for token in encoded { + seq.push(token); + seq_len += self.decoder[&token].len(); + if seq_len >= unstable_bytes.len() { + break; + } } + completions.insert(seq); + point += 1; } - completions.insert(seq); - point += 1; } } @@ -622,13 +655,11 @@ impl CoreBPE { ) -> Result> { let regex = Regex::new(pattern)?; - let special_regex = { - let parts = special_tokens_encoder - .keys() - .map(|s| fancy_regex::escape(s)) - .collect::>(); - Regex::new(&parts.join("|"))? - }; + let special_regex_pattern = special_tokens_encoder + .keys() + .map(|s| fancy_regex::escape(s)) + .collect::>() + .join("|"); let decoder: HashMap> = encoder.iter().map(|(k, v)| (*v, k.clone())).collect(); @@ -646,19 +677,16 @@ impl CoreBPE { .collect(); // Clone because I don't know how to tell Rust I'm not going to change the map - let mut sorted_token_bytes: Vec> = encoder.keys().cloned().collect(); - sorted_token_bytes.sort(); - Ok(Self { encoder, special_tokens_encoder, decoder, special_tokens_decoder, + special_regex_pattern, regex_tls: (0..MAX_NUM_THREADS).map(|_| regex.clone()).collect(), - special_regex_tls: (0..MAX_NUM_THREADS) - .map(|_| special_regex.clone()) - .collect(), - sorted_token_bytes, + special_regex_tls: OnceLock::new(), + sorted_token_bytes: OnceLock::new(), + token_bytes_by_first_byte: OnceLock::new(), }) } @@ -677,7 +705,6 @@ impl CoreBPE { #[cfg(test)] mod tests { - use fancy_regex::Regex; use rustc_hash::FxHashMap as HashMap; use crate::{Rank, byte_pair_split}; diff --git a/src/py.rs b/src/py.rs index 60b62452..64bffc51 100644 --- a/src/py.rs +++ b/src/py.rs @@ -4,12 +4,210 @@ use pyo3::{ IntoPyObjectExt, PyResult, exceptions, prelude::*, pybacked::PyBackedStr, - types::{PyBytes, PyList}, + types::{PyBytes, PyDict, PyList}, }; use rustc_hash::FxHashMap as HashMap; use crate::{CoreBPE, Rank, byte_pair_encode}; +fn is_ascii_whitespace(byte: u8) -> bool { + matches!(byte, b' ' | b'\t' | b'\n' | b'\r' | 0x0b | 0x0c) +} + +fn bytes_repr(bytes: &[u8]) -> String { + let mut out = String::from("b'"); + for &byte in bytes { + match byte { + b'\'' => out.push_str("\\'"), + b'\\' => out.push_str("\\\\"), + b'\n' => out.push_str("\\n"), + b'\r' => out.push_str("\\r"), + b'\t' => out.push_str("\\t"), + 0x20..=0x7e => out.push(byte as char), + _ => out.push_str(&format!("\\x{byte:02x}")), + } + } + out.push('\''); + out +} + +fn parse_bpe_error(line: &[u8], source: &str) -> PyErr { + exceptions::PyValueError::new_err(format!( + "Error parsing line {} in {source}", + bytes_repr(line) + )) +} + +fn base64_value(byte: u8) -> Option { + match byte { + b'A'..=b'Z' => Some(byte - b'A'), + b'a'..=b'z' => Some(byte - b'a' + 26), + b'0'..=b'9' => Some(byte - b'0' + 52), + b'+' => Some(62), + b'/' => Some(63), + _ => None, + } +} + +fn decode_base64(input: &[u8]) -> Option> { + if input.is_empty() || input.len() % 4 != 0 { + return None; + } + + let padding = input.iter().rev().take_while(|&&byte| byte == b'=').count(); + if padding > 2 { + return None; + } + + let output_len = input.len() / 4 * 3 - padding; + let mut output = Vec::with_capacity(output_len); + + for (chunk_index, chunk) in input.chunks_exact(4).enumerate() { + let is_last = chunk_index == input.len() / 4 - 1; + if !is_last && chunk.contains(&b'=') { + return None; + } + + let a = base64_value(chunk[0])?; + let b = base64_value(chunk[1])?; + output.push((a << 2) | (b >> 4)); + + match (chunk[2], chunk[3]) { + (b'=', b'=') if is_last => {} + (b'=', _) => return None, + (c, b'=') if is_last => { + let c = base64_value(c)?; + output.push((b << 4) | (c >> 2)); + } + (c, d) => { + let c = base64_value(c)?; + let d = base64_value(d)?; + output.push((b << 4) | (c >> 2)); + output.push((c << 6) | d); + } + } + } + + Some(output) +} + +fn split_bpe_line(line: &[u8]) -> Option<(&[u8], &[u8])> { + let mut fields = line + .split(|&byte| is_ascii_whitespace(byte)) + .filter(|field| !field.is_empty()); + let token = fields.next()?; + let rank = fields.next()?; + if fields.next().is_some() { + return None; + } + Some((token, rank)) +} + +fn parse_rank(bytes: &[u8]) -> Option { + if bytes.is_empty() { + return None; + } + + let mut rank: u64 = 0; + for &byte in bytes { + if !byte.is_ascii_digit() { + return None; + } + rank = rank.checked_mul(10)?.checked_add(u64::from(byte - b'0'))?; + if rank > u64::from(Rank::MAX) { + return None; + } + } + Some(rank as Rank) +} + +fn for_each_bpe_entry( + contents: &[u8], + source: &str, + mut f: impl FnMut(Vec, Rank) -> PyResult<()>, +) -> PyResult<()> { + for mut line in contents.split(|&byte| byte == b'\n') { + if line.ends_with(b"\r") { + line = &line[..line.len() - 1]; + } + if line.is_empty() { + continue; + } + + let (token, rank) = split_bpe_line(line).ok_or_else(|| parse_bpe_error(line, source))?; + let token = decode_base64(token).ok_or_else(|| parse_bpe_error(line, source))?; + let rank = parse_rank(rank).ok_or_else(|| parse_bpe_error(line, source))?; + f(token, rank)?; + } + + Ok(()) +} + +#[pyfunction] +fn load_tiktoken_bpe(py: Python, contents: &[u8], source: &str) -> PyResult> { + let ret = PyDict::new(py); + + for_each_bpe_entry(contents, source, |token, rank| { + ret.set_item(PyBytes::new(py, &token), rank)?; + Ok(()) + })?; + + Ok(ret.into()) +} + +#[pyfunction] +fn load_tiktoken_bpe_core( + contents: &[u8], + source: &str, + special_tokens_encoder: HashMap, + pattern: &str, +) -> PyResult<(CoreBPE, usize, Rank)> { + let mut encoder = HashMap::default(); + let mut max_rank = 0; + + for_each_bpe_entry(contents, source, |token, rank| { + max_rank = max_rank.max(rank); + encoder.insert(token, rank); + Ok(()) + })?; + + let n_mergeable_ranks = encoder.len(); + let core_bpe = CoreBPE::new_internal(encoder, special_tokens_encoder, pattern) + .map_err(|e| exceptions::PyValueError::new_err(e.to_string()))?; + Ok((core_bpe, n_mergeable_ranks, max_rank)) +} + +fn decode_token_bytes(core_bpe: &CoreBPE, token: Rank) -> Result<&[u8], Rank> { + if let Some(bytes) = core_bpe.decoder.get(&token) { + return Ok(bytes); + } + if let Some(bytes) = core_bpe.special_tokens_decoder.get(&token) { + return Ok(bytes); + } + Err(token) +} + +fn decode_with_offsets(core_bpe: &CoreBPE, tokens: &[Rank]) -> Result<(Vec, Vec), Rank> { + let mut text = Vec::with_capacity(tokens.len() * 4); + let mut offsets = Vec::with_capacity(tokens.len()); + let mut text_len = 0usize; + + for &token in tokens { + let token_bytes = decode_token_bytes(core_bpe, token)?; + let starts_with_continuation = token_bytes + .first() + .is_some_and(|&byte| (0x80..0xC0).contains(&byte)); + offsets.push(text_len.saturating_sub(starts_with_continuation as usize)); + text_len += token_bytes + .iter() + .filter(|&&byte| !(0x80..0xC0).contains(&byte)) + .count(); + text.extend_from_slice(token_bytes); + } + + Ok((text, offsets)) +} + #[pymethods] impl CoreBPE { #[new] @@ -31,6 +229,15 @@ impl CoreBPE { py.detach(|| self.encode_ordinary(text)) } + #[pyo3(name = "encode_ordinary_batch")] + fn py_encode_ordinary_batch(&self, py: Python, text: Vec) -> Vec> { + py.detach(|| { + text.iter() + .map(|text| self.encode_ordinary(text.as_ref())) + .collect() + }) + } + #[pyo3(name = "encode")] fn py_encode( &self, @@ -48,6 +255,25 @@ impl CoreBPE { }) } + #[pyo3(name = "encode_batch")] + fn py_encode_batch( + &self, + py: Python, + text: Vec, + allowed_special: HashSet, + ) -> PyResult>> { + py.detach(|| { + let allowed_special: HashSet<&str> = + allowed_special.iter().map(|s| s.as_ref()).collect(); + text.iter() + .map(|text| match self.encode(text.as_ref(), &allowed_special) { + Ok((tokens, _)) => Ok(tokens), + Err(e) => Err(PyErr::new::(e.message)), + }) + .collect() + }) + } + fn encode_to_tiktoken_buffer( &self, py: Python, @@ -161,14 +387,58 @@ impl CoreBPE { } } - fn decode_single_token_bytes(&self, py: Python, token: Rank) -> PyResult> { - if let Some(bytes) = self.decoder.get(&token) { - return Ok(PyBytes::new(py, bytes).into()); + #[pyo3(name = "decode_bytes_batch")] + fn py_decode_bytes_batch( + &self, + py: Python, + batch: Vec>, + ) -> Result>, PyErr> { + match py.detach(|| { + batch + .iter() + .map(|tokens| self.decode_bytes(tokens)) + .collect::, _>>() + }) { + Ok(bytes_batch) => Ok(bytes_batch + .iter() + .map(|bytes| PyBytes::new(py, bytes).into()) + .collect()), + Err(e) => Err(pyo3::exceptions::PyKeyError::new_err(format!("{}", e))), + } + } + + #[pyo3(name = "decode_tokens_bytes")] + fn py_decode_tokens_bytes( + &self, + py: Python, + tokens: Vec, + ) -> Result>, PyErr> { + tokens + .iter() + .map(|&token| match decode_token_bytes(self, token) { + Ok(bytes) => Ok(PyBytes::new(py, bytes).into()), + Err(token) => Err(pyo3::exceptions::PyKeyError::new_err(token.to_string())), + }) + .collect() + } + + #[pyo3(name = "decode_with_offsets")] + fn py_decode_with_offsets( + &self, + py: Python, + tokens: Vec, + ) -> Result<(Py, Vec), PyErr> { + match py.detach(|| decode_with_offsets(self, &tokens)) { + Ok((text, offsets)) => Ok((PyBytes::new(py, &text).into(), offsets)), + Err(token) => Err(pyo3::exceptions::PyKeyError::new_err(token.to_string())), } - if let Some(bytes) = self.special_tokens_decoder.get(&token) { - return Ok(PyBytes::new(py, bytes).into()); + } + + fn decode_single_token_bytes(&self, py: Python, token: Rank) -> PyResult> { + match decode_token_bytes(self, token) { + Ok(bytes) => Ok(PyBytes::new(py, bytes).into()), + Err(token) => Err(PyErr::new::(token.to_string())), } - Err(PyErr::new::(token.to_string())) } // ==================== @@ -176,11 +446,19 @@ impl CoreBPE { // ==================== fn token_byte_values(&self, py: Python) -> Vec> { - self.sorted_token_bytes + self.sorted_token_bytes() .iter() .map(|x| PyBytes::new(py, x).into()) .collect() } + + fn mergeable_ranks(&self, py: Python) -> PyResult> { + let ret = PyDict::new(py); + for (token, rank) in &self.encoder { + ret.set_item(PyBytes::new(py, token), *rank)?; + } + Ok(ret.into()) + } } #[pyclass(frozen)] @@ -251,5 +529,7 @@ impl TiktokenBuffer { #[pymodule(gil_used = false)] fn _tiktoken(_py: Python, m: &Bound) -> PyResult<()> { m.add_class::()?; + m.add_function(wrap_pyfunction!(load_tiktoken_bpe, m)?)?; + m.add_function(wrap_pyfunction!(load_tiktoken_bpe_core, m)?)?; Ok(()) } diff --git a/tests/test_encoding.py b/tests/test_encoding.py index b77ca135..0a436ee2 100644 --- a/tests/test_encoding.py +++ b/tests/test_encoding.py @@ -152,7 +152,29 @@ def test_basic_roundtrip(make_enc): def test_hyp_roundtrip(make_enc: Callable[[], tiktoken.Encoding], text): enc = make_enc() - assert text == enc.decode(enc.encode(text)) + assert text == enc.decode(enc.encode(text, disallowed_special=())) + + +@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES) +def test_encode_with_unstable_invariants(make_enc: Callable[[], tiktoken.Encoding]): + enc = make_enc() + text = "hello fanta" + + stable_tokens, completions = enc.encode_with_unstable(text) + + assert text.encode().startswith(enc.decode_bytes(stable_tokens)) + assert completions + assert all( + enc.decode_bytes(stable_tokens + completion).startswith(text.encode()) + for completion in completions + ) + + +def test_encode_with_unstable_disallowed_special(): + enc = tiktoken.get_encoding("o200k_harmony") + + with pytest.raises(ValueError, match="<\\|endoftext\\|>"): + enc.encode_with_unstable("hello <|endoftext|>") @pytest.mark.parametrize("make_enc", ENCODING_FACTORIES) @@ -222,6 +244,13 @@ def test_special_token(): assert fip not in tokens assert fim in tokens + enc = tiktoken.get_encoding("o200k_harmony") + assert enc.encode("hello world") == enc.encode_ordinary("hello world") + assert enc.encode("<|message|>", disallowed_special=()) == enc.encode_ordinary("<|message|>") + assert enc.encode_batch(["hello world"]) == [enc.encode_ordinary("hello world")] + with pytest.raises(ValueError): + enc.encode("<|message|>") + @pytest.mark.parametrize("make_enc", ENCODING_FACTORIES) @hypothesis.given(text=st.text()) @@ -244,12 +273,109 @@ def test_batch_encode(make_enc: Callable[[], tiktoken.Encoding]): assert enc.encode_batch([text1]) == [enc.encode(text1)] assert enc.encode_batch([text1, text2]) == [enc.encode(text1), enc.encode(text2)] + assert enc.encode_batch([text1, text2], num_threads=1) == [ + enc.encode(text1), + enc.encode(text2), + ] assert enc.encode_ordinary_batch([text1]) == [enc.encode_ordinary(text1)] assert enc.encode_ordinary_batch([text1, text2]) == [ enc.encode_ordinary(text1), enc.encode_ordinary(text2), ] + assert enc.encode_ordinary_batch([text1, text2], num_threads=1) == [ + enc.encode_ordinary(text1), + enc.encode_ordinary(text2), + ] + + +@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES) +def test_encode_ordinary_batch_edge_cases(make_enc: Callable[[], tiktoken.Encoding]): + enc = make_enc() + + assert enc.encode_ordinary_batch([]) == [] + assert enc.encode_ordinary_batch(["hello", "\ud800"]) == [ + enc.encode_ordinary("hello"), + enc.encode_ordinary("\ud800"), + ] + assert enc.encode_batch(["hello", "\ud800"], disallowed_special=()) == [ + enc.encode("hello", disallowed_special=()), + enc.encode("\ud800", disallowed_special=()), + ] + + with pytest.raises(ValueError, match="max_workers must be greater than 0"): + enc.encode_ordinary_batch(["hello"], num_threads=0) + + with pytest.raises(ValueError, match="max_workers must be greater than 0"): + enc.encode_batch(["hello"], num_threads=0) + + +@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES) +def test_encode_batch_special_tokens(make_enc: Callable[[], tiktoken.Encoding]): + enc = make_enc() + special = next(iter(enc.special_tokens_set)) + + assert enc.encode_batch([special], allowed_special={special}) == [ + enc.encode(special, allowed_special={special}) + ] + + with pytest.raises(ValueError, match="disallowed special token"): + enc.encode_batch([special]) + + +@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES) +def test_decode_batch_edge_cases(make_enc: Callable[[], tiktoken.Encoding]): + enc = make_enc() + tokens = [enc.encode_ordinary("hello"), enc.encode_ordinary("world")] + + assert enc.decode_batch(tokens) == [enc.decode(t) for t in tokens] + assert enc.decode_batch(tokens, num_threads=1) == [enc.decode(t) for t in tokens] + assert enc.decode_bytes_batch(tokens) == [enc.decode_bytes(t) for t in tokens] + assert enc.decode_bytes_batch(tokens, num_threads=1) == [enc.decode_bytes(t) for t in tokens] + + assert enc.decode_batch([]) == [] + assert enc.decode_bytes_batch([]) == [] + + with pytest.raises(ValueError, match="max_workers must be greater than 0"): + enc.decode_batch(tokens, num_threads=0) + + with pytest.raises(ValueError, match="max_workers must be greater than 0"): + enc.decode_bytes_batch(tokens, num_threads=0) + + with pytest.raises(KeyError): + enc.decode_batch([[enc.max_token_value + 1]]) + + with pytest.raises(KeyError): + enc.decode_bytes_batch([[enc.max_token_value + 1]]) + + +@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES) +def test_decode_batch_errors(make_enc: Callable[[], tiktoken.Encoding]): + enc = make_enc() + tokens = enc._encode_bytes(b"\xff") + + assert enc.decode_batch([tokens], errors="replace") == [enc.decode(tokens, errors="replace")] + assert enc.decode_batch([tokens], errors="ignore") == [enc.decode(tokens, errors="ignore")] + + with pytest.raises(UnicodeDecodeError): + enc.decode_batch([tokens], errors="strict") + + +@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES) +def test_decode_tokens_bytes_edge_cases(make_enc: Callable[[], tiktoken.Encoding]): + enc = make_enc() + tokens = enc.encode("hello world", allowed_special="all") + + assert enc.decode_tokens_bytes(tokens) == [ + enc.decode_single_token_bytes(token) for token in tokens + ] + assert enc.decode_tokens_bytes(token for token in tokens) == [ + enc.decode_single_token_bytes(token) for token in tokens + ] + assert enc.decode_tokens_bytes([]) == [] + + with pytest.raises(KeyError): + enc.decode_tokens_bytes([enc.max_token_value + 1]) @pytest.mark.parametrize("make_enc", ENCODING_FACTORIES) diff --git a/tests/test_load.py b/tests/test_load.py new file mode 100644 index 00000000..46c266b8 --- /dev/null +++ b/tests/test_load.py @@ -0,0 +1,74 @@ +import pytest + +import tiktoken +import tiktoken_ext.openai_public +from tiktoken.load import ( + _load_tiktoken_bpe_python, + _load_tiktoken_bpe_core, + load_tiktoken_bpe, +) + + +def test_load_tiktoken_bpe_rust_matches_python(tmp_path): + bpe_file = tmp_path / "tiny.tiktoken" + bpe_file.write_bytes(b"IQ== 0\nIg== 1\n4pyT 2\n") + + assert load_tiktoken_bpe(str(bpe_file)) == _load_tiktoken_bpe_python(str(bpe_file)) + + +def test_load_tiktoken_bpe_core(tmp_path): + bpe_file = tmp_path / "tiny.tiktoken" + bpe_file.write_bytes(b"IQ== 0\nIg== 1\n4pyT 2\n") + + core_bpe, mergeable_ranks_len, mergeable_ranks_max_token_value = _load_tiktoken_bpe_core( + str(bpe_file), + special_tokens={"<|special|>": 3}, + pat_str=r""".+""", + ) + + assert mergeable_ranks_len == 3 + assert mergeable_ranks_max_token_value == 2 + assert core_bpe.encode_single_token(b"!") == 0 + assert core_bpe.encode_single_token("<|special|>".encode()) == 3 + + +def test_load_tiktoken_bpe_parse_error_includes_source(tmp_path): + bpe_file = tmp_path / "bad.tiktoken" + bpe_file.write_bytes(b"IQ== 0 extra\n") + + with pytest.raises(ValueError, match="bad.tiktoken"): + load_tiktoken_bpe(str(bpe_file)) + + +def test_public_encoding_mergeable_ranks_materialize_lazily(): + enc = tiktoken.get_encoding("cl100k_base") + + assert enc.__dict__["_mergeable_ranks"] is None + assert enc._mergeable_ranks[b"!"] == 0 + assert isinstance(enc.__dict__["_mergeable_ranks"], dict) + + +def test_extending_public_encoding_after_lazy_construction(): + base = tiktoken.get_encoding("cl100k_base") + enc = tiktoken.Encoding( + name="cl100k_test", + pat_str=base._pat_str, + mergeable_ranks=base._mergeable_ranks, + special_tokens={ + **base._special_tokens, + "<|test|>": 100264, + }, + ) + + assert enc.encode("hello <|test|>", allowed_special="all") == [15339, 220, 100264] + + +def test_openai_public_constructor_private_core_path_still_constructs(): + constructor_args = tiktoken_ext.openai_public.cl100k_base() + assert constructor_args["mergeable_ranks"][b"!"] == 0 + assert dict(constructor_args["mergeable_ranks"])[b"!"] == 0 + + enc = tiktoken.Encoding(**tiktoken_ext.openai_public.cl100k_base()) + + assert enc.encode("hello world") == [15339, 1917] + assert enc.__dict__["_mergeable_ranks"] is None diff --git a/tests/test_offsets.py b/tests/test_offsets.py index 31b7f8d4..746aa402 100644 --- a/tests/test_offsets.py +++ b/tests/test_offsets.py @@ -49,10 +49,15 @@ def test_hyp_offsets(make_enc: Callable[[], tiktoken.Encoding], data): def test_basic_offsets(): enc = tiktoken.get_encoding("cl100k_base") + assert enc.decode_with_offsets([]) == ("", []) + prompt = "hello world" p, o = enc.decode_with_offsets(enc.encode(prompt)) assert p == prompt assert o == [0, 5] + p, o = enc.decode_with_offsets(token for token in enc.encode(prompt)) + assert p == prompt + assert o == [0, 5] prompt = "hello world<|endoftext|> green cow" p, o = enc.decode_with_offsets(enc.encode(prompt, allowed_special="all")) @@ -77,3 +82,13 @@ def test_basic_offsets(): p, o = enc.decode_with_offsets(enc.encode(prompt)) assert p == prompt assert o == [0, 1] + + +def test_offsets_errors(): + enc = tiktoken.get_encoding("cl100k_base") + + with pytest.raises(KeyError): + enc.decode_with_offsets([enc.max_token_value + 1]) + + with pytest.raises(UnicodeDecodeError): + enc.decode_with_offsets(enc._encode_bytes(b"\x80")) diff --git a/tiktoken/core.py b/tiktoken/core.py index 530f8f59..2296d190 100644 --- a/tiktoken/core.py +++ b/tiktoken/core.py @@ -1,8 +1,9 @@ from __future__ import annotations import functools +from collections.abc import Mapping from concurrent.futures import ThreadPoolExecutor -from typing import TYPE_CHECKING, AbstractSet, Collection, Literal, NoReturn, Sequence +from typing import TYPE_CHECKING, AbstractSet, Collection, Iterator, Literal, NoReturn, Sequence from tiktoken import _tiktoken @@ -13,15 +14,62 @@ import numpy.typing as npt +class _LazyMergeableRanks(Mapping[bytes, int]): + def __init__( + self, + core_bpe: _tiktoken.CoreBPE, + n_mergeable_ranks: int, + max_token_value: int, + ): + self._core_bpe = core_bpe + self._n_mergeable_ranks = n_mergeable_ranks + self._max_token_value = max_token_value + self._mergeable_ranks: dict[bytes, int] | None = None + + @property + def n_mergeable_ranks(self) -> int: + return self._n_mergeable_ranks + + @property + def max_token_value(self) -> int: + return self._max_token_value + + @property + def core_bpe(self) -> _tiktoken.CoreBPE: + return self._core_bpe + + def _materialized(self) -> dict[bytes, int]: + mergeable_ranks = self._mergeable_ranks + if mergeable_ranks is None: + mergeable_ranks = self._core_bpe.mergeable_ranks() + self._mergeable_ranks = mergeable_ranks + return mergeable_ranks + + def __getitem__(self, key: bytes) -> int: + return self._materialized()[key] + + def __iter__(self) -> Iterator[bytes]: + return iter(self._materialized()) + + def __len__(self) -> int: + return self._n_mergeable_ranks + + def copy(self) -> dict[bytes, int]: + return self._materialized().copy() + + class Encoding: def __init__( self, name: str, *, pat_str: str, - mergeable_ranks: dict[bytes, int], + mergeable_ranks: dict[bytes, int] | _LazyMergeableRanks | None, special_tokens: dict[str, int], explicit_n_vocab: int | None = None, + _core_bpe: _tiktoken.CoreBPE | None = None, + _mergeable_ranks_len: int | None = None, + _mergeable_ranks_max_token_value: int | None = None, ): """Creates an Encoding object. @@ -41,24 +89,61 @@ def __init__( self.name = name self._pat_str = pat_str + if isinstance(mergeable_ranks, _LazyMergeableRanks): + if _core_bpe is None: + _core_bpe = mergeable_ranks.core_bpe + if _mergeable_ranks_len is None: + _mergeable_ranks_len = mergeable_ranks.n_mergeable_ranks + if _mergeable_ranks_max_token_value is None: + _mergeable_ranks_max_token_value = mergeable_ranks.max_token_value + mergeable_ranks = None + self._mergeable_ranks = mergeable_ranks self._special_tokens = special_tokens + mergeable_ranks_len = ( + len(mergeable_ranks) if mergeable_ranks is not None else _mergeable_ranks_len + ) + mergeable_ranks_max_token_value = ( + max(mergeable_ranks.values()) + if mergeable_ranks is not None + else _mergeable_ranks_max_token_value + ) + assert mergeable_ranks_len is not None + assert mergeable_ranks_max_token_value is not None + self.max_token_value = max( - max(mergeable_ranks.values()), max(special_tokens.values(), default=0) + mergeable_ranks_max_token_value, max(special_tokens.values(), default=0) ) if explicit_n_vocab: - assert len(mergeable_ranks) + len(special_tokens) == explicit_n_vocab + assert mergeable_ranks_len + len(special_tokens) == explicit_n_vocab assert self.max_token_value == explicit_n_vocab - 1 # Contains on set is significantly faster than on dict_values self._special_token_values = set(self._special_tokens.values()) + self._special_tokens_set_frozen = frozenset(self._special_tokens) - self._core_bpe = _tiktoken.CoreBPE(mergeable_ranks, special_tokens, pat_str) + if _core_bpe is not None: + self._core_bpe = _core_bpe + else: + assert mergeable_ranks is not None + self._core_bpe = _tiktoken.CoreBPE(mergeable_ranks, special_tokens, pat_str) def __repr__(self) -> str: return f"" + @property + def _mergeable_ranks(self) -> dict[bytes, int]: + mergeable_ranks = self.__dict__.get("_mergeable_ranks") + if mergeable_ranks is None: + mergeable_ranks = self._core_bpe.mergeable_ranks() + self.__dict__["_mergeable_ranks"] = mergeable_ranks + return mergeable_ranks + + @_mergeable_ranks.setter + def _mergeable_ranks(self, mergeable_ranks: dict[bytes, int] | None) -> None: + self.__dict__["_mergeable_ranks"] = mergeable_ranks + # ==================== # Encoding # ==================== @@ -116,14 +201,22 @@ def encode( if allowed_special == "all": allowed_special = self.special_tokens_set if disallowed_special == "all": - disallowed_special = self.special_tokens_set - allowed_special + disallowed_special = ( + self._special_tokens_set_frozen + if not allowed_special + else self._special_tokens_set_frozen - allowed_special + ) if disallowed_special: if not isinstance(disallowed_special, frozenset): disallowed_special = frozenset(disallowed_special) - if match := _special_token_regex(disallowed_special).search(text): - raise_disallowed_special_token(match.group()) + common_prefix = _special_token_common_prefix(disallowed_special) + if not common_prefix or common_prefix in text: + if match := _special_token_regex(disallowed_special).search(text): + raise_disallowed_special_token(match.group()) try: + if not allowed_special: + return self._core_bpe.encode_ordinary(text) return self._core_bpe.encode(text, allowed_special) except UnicodeEncodeError: # BPE operates on bytes, but the regex operates on unicode. If we pass a str that is @@ -133,6 +226,8 @@ def encode( # string, but given that this is input we want to support, maybe that's okay. # Also we use errors="replace" to handle weird things like lone surrogates. text = text.encode("utf-16", "surrogatepass").decode("utf-16", "replace") + if not allowed_special: + return self._core_bpe.encode_ordinary(text) return self._core_bpe.encode(text, allowed_special) def encode_to_numpy( @@ -149,12 +244,18 @@ def encode_to_numpy( if allowed_special == "all": allowed_special = self.special_tokens_set if disallowed_special == "all": - disallowed_special = self.special_tokens_set - allowed_special + disallowed_special = ( + self._special_tokens_set_frozen + if not allowed_special + else self._special_tokens_set_frozen - allowed_special + ) if disallowed_special: if not isinstance(disallowed_special, frozenset): disallowed_special = frozenset(disallowed_special) - if match := _special_token_regex(disallowed_special).search(text): - raise_disallowed_special_token(match.group()) + common_prefix = _special_token_common_prefix(disallowed_special) + if not common_prefix or common_prefix in text: + if match := _special_token_regex(disallowed_special).search(text): + raise_disallowed_special_token(match.group()) import numpy as np @@ -171,6 +272,25 @@ def encode_ordinary_batch(self, text: list[str], *, num_threads: int = 8) -> lis [[31373, 995], [11274, 16390, 995]] ``` """ + if num_threads <= 0: + raise ValueError("max_workers must be greater than 0") + + try: + batch_len = len(text) + except TypeError: + batch_len = None + + if batch_len == 0: + return [] + + if _use_native_batch(text, batch_len, num_threads): + try: + return self._core_bpe.encode_ordinary_batch(text) + except (TypeError, UnicodeEncodeError): + # Match encode_ordinary's surrogate fixup behavior by falling back to the + # per-string path when any string cannot be passed to Rust as UTF-8. + pass + encoder = functools.partial(self.encode_ordinary) with ThreadPoolExecutor(num_threads) as e: return list(e.map(encoder, text)) @@ -195,10 +315,43 @@ def encode_batch( if allowed_special == "all": allowed_special = self.special_tokens_set if disallowed_special == "all": - disallowed_special = self.special_tokens_set - allowed_special + disallowed_special = ( + self._special_tokens_set_frozen + if not allowed_special + else self._special_tokens_set_frozen - allowed_special + ) if not isinstance(disallowed_special, frozenset): disallowed_special = frozenset(disallowed_special) + if num_threads <= 0: + raise ValueError("max_workers must be greater than 0") + + try: + batch_len = len(text) + except TypeError: + batch_len = None + + if batch_len == 0: + return [] + + if _use_native_batch(text, batch_len, num_threads): + try: + if disallowed_special: + common_prefix = _special_token_common_prefix(disallowed_special) + special_regex = None + for piece in text: + if common_prefix and common_prefix not in piece: + continue + if special_regex is None: + special_regex = _special_token_regex(disallowed_special) + if match := special_regex.search(piece): + raise_disallowed_special_token(match.group()) + if not allowed_special: + return self._core_bpe.encode_ordinary_batch(text) + return self._core_bpe.encode_batch(text, allowed_special) + except (TypeError, UnicodeEncodeError): + pass + encoder = functools.partial( self.encode, allowed_special=allowed_special, disallowed_special=disallowed_special ) @@ -233,12 +386,18 @@ def encode_with_unstable( if allowed_special == "all": allowed_special = self.special_tokens_set if disallowed_special == "all": - disallowed_special = self.special_tokens_set - allowed_special + disallowed_special = ( + self._special_tokens_set_frozen + if not allowed_special + else self._special_tokens_set_frozen - allowed_special + ) if disallowed_special: if not isinstance(disallowed_special, frozenset): disallowed_special = frozenset(disallowed_special) - if match := _special_token_regex(disallowed_special).search(text): - raise_disallowed_special_token(match.group()) + common_prefix = _special_token_common_prefix(disallowed_special) + if not common_prefix or common_prefix in text: + if match := _special_token_regex(disallowed_special).search(text): + raise_disallowed_special_token(match.group()) return self._core_bpe.encode_with_unstable(text, allowed_special) @@ -307,7 +466,10 @@ def decode_tokens_bytes(self, tokens: Sequence[int]) -> list[bytes]: >>> enc.decode_tokens_bytes([31373, 995]) [b'hello', b' world'] """ - return [self.decode_single_token_bytes(token) for token in tokens] + try: + return self._core_bpe.decode_tokens_bytes(tokens) + except TypeError: + return [self.decode_single_token_bytes(token) for token in tokens] def decode_with_offsets(self, tokens: Sequence[int]) -> tuple[str, list[int]]: """Decodes a list of tokens into a string and a list of offsets. @@ -322,22 +484,46 @@ def decode_with_offsets(self, tokens: Sequence[int]) -> tuple[str, list[int]]: >>> enc.decode_with_offsets([31373, 995]) ('hello world', [0, 5]) """ - token_bytes = self.decode_tokens_bytes(tokens) + try: + text_bytes, offsets = self._core_bpe.decode_with_offsets(tokens) + text = text_bytes.decode("utf-8", errors="strict") + return text, offsets + except TypeError: + token_bytes = self.decode_tokens_bytes(tokens) - text_len = 0 - offsets = [] - for token in token_bytes: - offsets.append(max(0, text_len - (0x80 <= token[0] < 0xC0))) - text_len += sum(1 for c in token if not 0x80 <= c < 0xC0) + text_len = 0 + offsets = [] + for token in token_bytes: + offsets.append(max(0, text_len - (0x80 <= token[0] < 0xC0))) + text_len += sum(1 for c in token if not 0x80 <= c < 0xC0) - # TODO: assess correctness for errors="ignore" and errors="replace" - text = b"".join(token_bytes).decode("utf-8", errors="strict") - return text, offsets + text = b"".join(token_bytes).decode("utf-8", errors="strict") + return text, offsets def decode_batch( self, batch: Sequence[Sequence[int]], *, errors: str = "replace", num_threads: int = 8 ) -> list[str]: """Decodes a batch (list of lists of tokens) into a list of strings.""" + if num_threads <= 0: + raise ValueError("max_workers must be greater than 0") + + try: + batch_len = len(batch) + except TypeError: + batch_len = None + + if batch_len == 0: + return [] + + if _use_native_decode_batch(batch, batch_len, num_threads): + try: + return [ + text.decode("utf-8", errors=errors) + for text in self._core_bpe.decode_bytes_batch(batch) + ] + except TypeError: + pass + decoder = functools.partial(self.decode, errors=errors) with ThreadPoolExecutor(num_threads) as e: return list(e.map(decoder, batch)) @@ -346,6 +532,23 @@ def decode_bytes_batch( self, batch: Sequence[Sequence[int]], *, num_threads: int = 8 ) -> list[bytes]: """Decodes a batch (list of lists of tokens) into a list of bytes.""" + if num_threads <= 0: + raise ValueError("max_workers must be greater than 0") + + try: + batch_len = len(batch) + except TypeError: + batch_len = None + + if batch_len == 0: + return [] + + if _use_native_decode_batch(batch, batch_len, num_threads): + try: + return self._core_bpe.decode_bytes_batch(batch) + except TypeError: + pass + with ThreadPoolExecutor(num_threads) as e: return list(e.map(self.decode_bytes, batch)) @@ -438,6 +641,67 @@ def _special_token_regex(tokens: frozenset[str]) -> re.Pattern[str]: return re.compile(f"({inner})") +@functools.lru_cache(maxsize=128) +def _special_token_common_prefix(tokens: frozenset[str]) -> str: + if not tokens: + return "" + first = min(tokens) + last = max(tokens) + for i, char in enumerate(first): + if i == len(last) or last[i] != char: + return first[:i] + return first + + +def _use_native_batch(text: list[str], batch_len: int | None, num_threads: int) -> bool: + if batch_len is None: + return False + if num_threads == 1: + return True + + try: + head = min(batch_len, 32) + sample_chars = 0 + sample_count = 0 + for i in range(head): + sample_chars += len(text[i]) + sample_count += 1 + for i in range(max(head, batch_len - 32), batch_len): + sample_chars += len(text[i]) + sample_count += 1 + except (IndexError, TypeError): + return False + + return sample_chars <= sample_count * 256 + + +def _use_native_decode_batch( + batch: Sequence[Sequence[int]], batch_len: int | None, num_threads: int +) -> bool: + if batch_len is None: + return False + if num_threads == 1: + return True + + try: + head = min(batch_len, 32) + sample_tokens = 0 + sample_count = 0 + for i in range(head): + sample_tokens += len(batch[i]) + sample_count += 1 + for i in range(max(head, batch_len - 32), batch_len): + sample_tokens += len(batch[i]) + sample_count += 1 + except (IndexError, TypeError): + return False + + if sample_tokens <= sample_count * 256: + return True + + return batch_len >= 1000 and sample_tokens <= sample_count * 2048 + + def raise_disallowed_special_token(token: str) -> NoReturn: raise ValueError( f"Encountered text corresponding to disallowed special token {token!r}.\n" diff --git a/tiktoken/load.py b/tiktoken/load.py index 3c76bcb3..fe6d5f83 100644 --- a/tiktoken/load.py +++ b/tiktoken/load.py @@ -1,6 +1,7 @@ from __future__ import annotations import base64 +import binascii import hashlib import os @@ -158,6 +159,29 @@ def dump_tiktoken_bpe(bpe_ranks: dict[bytes, int], tiktoken_bpe_file: str) -> No def load_tiktoken_bpe(tiktoken_bpe_file: str, expected_hash: str | None = None) -> dict[bytes, int]: # NB: do not add caching to this function + contents = read_file_cached(tiktoken_bpe_file, expected_hash) + from tiktoken import _tiktoken + + return _tiktoken.load_tiktoken_bpe(contents, tiktoken_bpe_file) + + +def _load_tiktoken_bpe_core( + tiktoken_bpe_file: str, + *, + special_tokens: dict[str, int], + pat_str: str, + expected_hash: str | None = None, +) -> tuple[object, int, int]: + # NB: do not add caching to this function + contents = read_file_cached(tiktoken_bpe_file, expected_hash) + from tiktoken import _tiktoken + + return _tiktoken.load_tiktoken_bpe_core(contents, tiktoken_bpe_file, special_tokens, pat_str) + + +def _load_tiktoken_bpe_python( + tiktoken_bpe_file: str, expected_hash: str | None = None +) -> dict[bytes, int]: contents = read_file_cached(tiktoken_bpe_file, expected_hash) ret = {} for line in contents.splitlines(): @@ -165,7 +189,7 @@ def load_tiktoken_bpe(tiktoken_bpe_file: str, expected_hash: str | None = None) continue try: token, rank = line.split() - ret[base64.b64decode(token)] = int(rank) + ret[binascii.a2b_base64(token)] = int(rank) except Exception as e: raise ValueError(f"Error parsing line {line!r} in {tiktoken_bpe_file}") from e return ret diff --git a/tiktoken_ext/openai_public.py b/tiktoken_ext/openai_public.py index 02c9ee20..6b3b8d01 100644 --- a/tiktoken_ext/openai_public.py +++ b/tiktoken_ext/openai_public.py @@ -1,4 +1,5 @@ -from tiktoken.load import data_gym_to_mergeable_bpe_ranks, load_tiktoken_bpe +from tiktoken.core import _LazyMergeableRanks +from tiktoken.load import data_gym_to_mergeable_bpe_ranks, load_tiktoken_bpe, _load_tiktoken_bpe_core ENDOFTEXT = "<|endoftext|>" FIM_PREFIX = "<|fim_prefix|>" @@ -13,6 +14,21 @@ r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}++| ?\p{N}++| ?[^\s\p{L}\p{N}]++|\s++$|\s+(?!\S)|\s""" ) +# This regex could be made more efficient. If I was the one working on this encoding, I would +# have done a few other things differently too, e.g. I think you can allocate tokens more +# efficiently across languages. +o200k_pat_str = "|".join( + [ + r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?""", + r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?""", + r"""\p{N}{1,3}""", + r""" ?[^\s\p{L}\p{N}]+[\r\n/]*""", + r"""\s*[\r\n]+""", + r"""\s+(?!\S)""", + r"""\s+""", + ] +) + def gpt2(): mergeable_ranks = data_gym_to_mergeable_bpe_ranks( @@ -30,53 +46,74 @@ def gpt2(): } +def _load_tiktoken_bpe_args(tiktoken_bpe_file, *, special_tokens, pat_str, expected_hash): + core_bpe, mergeable_ranks_len, mergeable_ranks_max_token_value = _load_tiktoken_bpe_core( + tiktoken_bpe_file, + special_tokens=special_tokens, + pat_str=pat_str, + expected_hash=expected_hash, + ) + return { + "mergeable_ranks": _LazyMergeableRanks( + core_bpe, mergeable_ranks_len, mergeable_ranks_max_token_value + ), + "_core_bpe": core_bpe, + "_mergeable_ranks_len": mergeable_ranks_len, + "_mergeable_ranks_max_token_value": mergeable_ranks_max_token_value, + } + + def r50k_base(): - mergeable_ranks = load_tiktoken_bpe( + special_tokens = {ENDOFTEXT: 50256} + bpe_args = _load_tiktoken_bpe_args( "https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken", + special_tokens=special_tokens, + pat_str=r50k_pat_str, expected_hash="306cd27f03c1a714eca7108e03d66b7dc042abe8c258b44c199a7ed9838dd930", ) return { "name": "r50k_base", "explicit_n_vocab": 50257, "pat_str": r50k_pat_str, - "mergeable_ranks": mergeable_ranks, - "special_tokens": {ENDOFTEXT: 50256}, + "special_tokens": special_tokens, + **bpe_args, } def p50k_base(): - mergeable_ranks = load_tiktoken_bpe( + special_tokens = {ENDOFTEXT: 50256} + bpe_args = _load_tiktoken_bpe_args( "https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken", + special_tokens=special_tokens, + pat_str=r50k_pat_str, expected_hash="94b5ca7dff4d00767bc256fdd1b27e5b17361d7b8a5f968547f9f23eb70d2069", ) return { "name": "p50k_base", "explicit_n_vocab": 50281, "pat_str": r50k_pat_str, - "mergeable_ranks": mergeable_ranks, - "special_tokens": {ENDOFTEXT: 50256}, + "special_tokens": special_tokens, + **bpe_args, } def p50k_edit(): - mergeable_ranks = load_tiktoken_bpe( + special_tokens = {ENDOFTEXT: 50256, FIM_PREFIX: 50281, FIM_MIDDLE: 50282, FIM_SUFFIX: 50283} + bpe_args = _load_tiktoken_bpe_args( "https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken", + special_tokens=special_tokens, + pat_str=r50k_pat_str, expected_hash="94b5ca7dff4d00767bc256fdd1b27e5b17361d7b8a5f968547f9f23eb70d2069", ) - special_tokens = {ENDOFTEXT: 50256, FIM_PREFIX: 50281, FIM_MIDDLE: 50282, FIM_SUFFIX: 50283} return { "name": "p50k_edit", "pat_str": r50k_pat_str, - "mergeable_ranks": mergeable_ranks, "special_tokens": special_tokens, + **bpe_args, } def cl100k_base(): - mergeable_ranks = load_tiktoken_bpe( - "https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken", - expected_hash="223921b76ee99bde995b7ff738513eef100fb51d18c93597a113bcffe865b2a7", - ) special_tokens = { ENDOFTEXT: 100257, FIM_PREFIX: 100258, @@ -84,51 +121,43 @@ def cl100k_base(): FIM_SUFFIX: 100260, ENDOFPROMPT: 100276, } + pat_str = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}++|\p{N}{1,3}+| ?[^\s\p{L}\p{N}]++[\r\n]*+|\s++$|\s*[\r\n]|\s+(?!\S)|\s""" + bpe_args = _load_tiktoken_bpe_args( + "https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken", + special_tokens=special_tokens, + pat_str=pat_str, + expected_hash="223921b76ee99bde995b7ff738513eef100fb51d18c93597a113bcffe865b2a7", + ) return { "name": "cl100k_base", - "pat_str": r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}++|\p{N}{1,3}+| ?[^\s\p{L}\p{N}]++[\r\n]*+|\s++$|\s*[\r\n]|\s+(?!\S)|\s""", - "mergeable_ranks": mergeable_ranks, + "pat_str": pat_str, "special_tokens": special_tokens, + **bpe_args, } def o200k_base(): - mergeable_ranks = load_tiktoken_bpe( + special_tokens = {ENDOFTEXT: 199999, ENDOFPROMPT: 200018} + bpe_args = _load_tiktoken_bpe_args( "https://openaipublic.blob.core.windows.net/encodings/o200k_base.tiktoken", + special_tokens=special_tokens, + pat_str=o200k_pat_str, expected_hash="446a9538cb6c348e3516120d7c08b09f57c36495e2acfffe59a5bf8b0cfb1a2d", ) - special_tokens = {ENDOFTEXT: 199999, ENDOFPROMPT: 200018} - # This regex could be made more efficient. If I was the one working on this encoding, I would - # have done a few other things differently too, e.g. I think you can allocate tokens more - # efficiently across languages. - pat_str = "|".join( - [ - r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?""", - r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?""", - r"""\p{N}{1,3}""", - r""" ?[^\s\p{L}\p{N}]+[\r\n/]*""", - r"""\s*[\r\n]+""", - r"""\s+(?!\S)""", - r"""\s+""", - ] - ) return { "name": "o200k_base", - "pat_str": pat_str, - "mergeable_ranks": mergeable_ranks, + "pat_str": o200k_pat_str, "special_tokens": special_tokens, + **bpe_args, } def o200k_harmony(): - base_enc = o200k_base() name = "o200k_harmony" - pat_str = base_enc["pat_str"] - mergeable_ranks = base_enc["mergeable_ranks"] special_tokens = { - **base_enc["special_tokens"], + ENDOFTEXT: 199999, + ENDOFPROMPT: 200018, "<|startoftext|>": 199998, - "<|endoftext|>": 199999, "<|reserved_200000|>": 200000, "<|reserved_200001|>": 200001, "<|return|>": 200002, @@ -143,11 +172,17 @@ def o200k_harmony(): "<|reserved_200011|>": 200011, "<|call|>": 200012, } | {f"<|reserved_{i}|>": i for i in range(200013, 201088)} + bpe_args = _load_tiktoken_bpe_args( + "https://openaipublic.blob.core.windows.net/encodings/o200k_base.tiktoken", + special_tokens=special_tokens, + pat_str=o200k_pat_str, + expected_hash="446a9538cb6c348e3516120d7c08b09f57c36495e2acfffe59a5bf8b0cfb1a2d", + ) return { "name": name, - "pat_str": pat_str, - "mergeable_ranks": mergeable_ranks, + "pat_str": o200k_pat_str, "special_tokens": special_tokens, + **bpe_args, }