diff --git a/src/lib.rs b/src/lib.rs index ea54eac8..589da30a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,6 +12,9 @@ mod py; pub type Rank = u32; +// Every distinct 2-byte sequence (256 byte values × 256). Indexed by (byte_a << 8) | byte_b. +const PAIR_TABLE_SIZE: usize = 256 * 256; + use std::collections::BinaryHeap; #[derive(Eq, PartialEq, Clone, Copy)] @@ -44,7 +47,11 @@ struct State { cur_rank: Rank, } -fn _byte_pair_merge_large(ranks: &HashMap, Rank>, piece: &[u8]) -> Vec { +fn _byte_pair_merge_large( + ranks: &HashMap, Rank>, + piece: &[u8], + pair_table: Option<&[Rank; PAIR_TABLE_SIZE]>, +) -> Vec { let mut state = Vec::with_capacity(piece.len()); state.push(State { prev: usize::MAX, @@ -56,7 +63,16 @@ fn _byte_pair_merge_large(ranks: &HashMap, Rank>, piece: &[u8]) -> Vec { + let r = table[((piece[i] as u16) << 8 | piece[i + 1] as u16) as usize]; + (r != Rank::MAX).then_some(r) + } + None => ranks.get(&piece[i..i + 2]).copied(), + }; + if let Some(rank) = rank_opt { heap.push(Merge { start: i, rank }); state[i].next_rank = rank; } @@ -137,7 +153,11 @@ fn _byte_pair_merge_large(ranks: &HashMap, Rank>, piece: &[u8]) -> Vec, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> { +fn _byte_pair_merge( + ranks: &HashMap, Rank>, + piece: &[u8], + pair_table: Option<&[Rank; PAIR_TABLE_SIZE]>, +) -> Vec<(usize, Rank)> { // This is a vector of (start, rank). // The rank is of the pair starting at position start. let mut parts = Vec::with_capacity(piece.len() + 1); @@ -145,9 +165,16 @@ fn _byte_pair_merge(ranks: &HashMap, Rank>, piece: &[u8]) -> Vec<(usize, // Note that we hash bytes when indexing into `ranks`, not token pairs. As long as we train BPE // the way we currently do, this is equivalent. An easy way to break this would be to decouple // merge priority from token index or to prevent specific token merges. + // + // When `pair_table` is Some, the initial pair scan uses a flat `PAIR_TABLE_SIZE`-entry + // array indexed by the 2-byte pair instead of the hashmap. Subsequent merges (3+ byte + // sequences) always go through the hashmap, since 3+ byte keys don't fit in a u16. let mut min_rank: (Rank, usize) = (Rank::MAX, usize::MAX); for i in 0..piece.len() - 1 { - let rank = *ranks.get(&piece[i..i + 2]).unwrap_or(&Rank::MAX); + let rank = match pair_table { + Some(table) => table[((piece[i] as u16) << 8 | piece[i + 1] as u16) as usize], + None => *ranks.get(&piece[i..i + 2]).unwrap_or(&Rank::MAX), + }; if rank < min_rank.0 { min_rank = (rank, i); } @@ -202,17 +229,39 @@ pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap, Rank>) -> Vec, Rank>, + pair_table: &[Rank; PAIR_TABLE_SIZE], +) -> Vec { + let piece_len = piece.len(); + + if piece_len == 1 { + return vec![ranks[piece]]; + } + if piece_len < 100 { + return _byte_pair_merge(ranks, piece, Some(pair_table)) .windows(2) .map(|part| ranks[&piece[part[0].0..part[1].0]]) .collect(); } - _byte_pair_merge_large(ranks, piece) + _byte_pair_merge_large(ranks, piece, Some(pair_table)) } pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap, Rank>) -> Vec<&'a [u8]> { assert!(piece.len() > 1); - _byte_pair_merge(ranks, piece) + _byte_pair_merge(ranks, piece, None) .windows(2) .map(|part| &piece[part[0].0..part[1].0]) .collect() @@ -325,6 +374,10 @@ pub struct CoreBPE { regex_tls: Vec, special_regex_tls: Vec, sorted_token_bytes: Vec>, + /// Precomputed 2-byte pair to rank lookup table (~256 KB, built once at + /// construction). Used by encoding methods to skip the hashmap lookup for + /// the hot initial adjacent-pair scan inside `_byte_pair_merge`. + pair_table: Box<[Rank; PAIR_TABLE_SIZE]>, } impl CoreBPE { @@ -366,7 +419,11 @@ impl CoreBPE { let piece = mat.unwrap().as_str().as_bytes(); match self.encoder.get(piece) { Some(token) => ret.push(*token), - None => ret.extend(&byte_pair_encode(piece, &self.encoder)), + None => ret.extend(&byte_pair_encode_with_table( + piece, + &self.encoder, + &self.pair_table, + )), } } ret @@ -418,7 +475,7 @@ impl CoreBPE { ret.push(*token); continue; } - let tokens = byte_pair_encode(piece, &self.encoder); + let tokens = byte_pair_encode_with_table(piece, &self.encoder, &self.pair_table); last_piece_token_len = tokens.len(); ret.extend(&tokens); } @@ -550,11 +607,12 @@ impl CoreBPE { // would be a regex split before the UTF-8 truncation point. // Probably niche enough that no one will ever notice (after all, people didn't // notice all the big holes in the previous unstable token implementation) - Err(_) => byte_pair_encode(&possibility, &self.encoder), - // Something like the following is intriguing but incorrect: - // Err(e) => self.encode_ordinary(unsafe { - // std::str::from_utf8_unchecked(&possibility[..e.valid_up_to()]) - // }), + Err(_) => { + byte_pair_encode_with_table(&possibility, &self.encoder, &self.pair_table) + } // Something like the following is intriguing but incorrect: + // Err(e) => self.encode_ordinary(unsafe { + // std::str::from_utf8_unchecked(&possibility[..e.valid_up_to()]) + // }), }; let mut seq = Vec::new(); let mut seq_len = 0; @@ -583,13 +641,15 @@ impl CoreBPE { if unstable_bytes.len() - last_decoded.1 > 0 && last_decoded.0.is_some_and(|c| c.is_whitespace()) { - let mut reencoded = byte_pair_encode( + let mut reencoded = byte_pair_encode_with_table( &unstable_bytes[..unstable_bytes.len() - last_decoded.1], &self.encoder, + &self.pair_table, ); - reencoded.extend(byte_pair_encode( + reencoded.extend(byte_pair_encode_with_table( &unstable_bytes[unstable_bytes.len() - last_decoded.1..], &self.encoder, + &self.pair_table, )); completions.insert(reencoded); } @@ -649,6 +709,15 @@ impl CoreBPE { let mut sorted_token_bytes: Vec> = encoder.keys().cloned().collect(); sorted_token_bytes.sort(); + // Build the 2-byte pair lookup table (~256 KB, sub-millisecond one-time cost). + let mut pair_table: Box<[Rank; PAIR_TABLE_SIZE]> = Box::new([Rank::MAX; PAIR_TABLE_SIZE]); + for (key, &rank) in &encoder { + if key.len() == 2 { + let idx = ((key[0] as u16) << 8 | key[1] as u16) as usize; + pair_table[idx] = rank; + } + } + Ok(Self { encoder, special_tokens_encoder, @@ -659,6 +728,7 @@ impl CoreBPE { .map(|_| special_regex.clone()) .collect(), sorted_token_bytes, + pair_table, }) } @@ -700,3 +770,114 @@ mod tests { assert_eq!(res, vec![b"ab", b"ab"]); } } + +/// Tests that the precomputed 2-byte pair lookup table produces output +/// byte-identical to the vanilla hashmap-lookup path. Covers both the +/// linear `_byte_pair_merge` (pieces < 100 bytes) and the heap-based +/// `_byte_pair_merge_large` (pieces >= 100 bytes) code paths. +#[cfg(test)] +mod pair_table_equivalence { + use super::*; + + /// Build a small synthetic encoder: all ASCII a-z and 0-9 as single-byte + /// tokens, plus a handful of common 2-byte and 3-byte merges. Enough to + /// exercise real merge dynamics without depending on a downloaded vocab. + fn synthetic_encoder() -> HashMap, Rank> { + let mut encoder = HashMap::default(); + let mut rank: Rank = 0; + for b in b'a'..=b'z' { + encoder.insert(vec![b], rank); + rank += 1; + } + for b in b'0'..=b'9' { + encoder.insert(vec![b], rank); + rank += 1; + } + for pair in ["th", "he", "in", "er", "an", "re", "on", "at", "es", "or"] { + encoder.insert(pair.as_bytes().to_vec(), rank); + rank += 1; + } + for triple in ["the", "and", "ing", "ion", "for"] { + encoder.insert(triple.as_bytes().to_vec(), rank); + rank += 1; + } + encoder + } + + fn build_pair_table(encoder: &HashMap, Rank>) -> Box<[Rank; PAIR_TABLE_SIZE]> { + let mut pair_table: Box<[Rank; PAIR_TABLE_SIZE]> = Box::new([Rank::MAX; PAIR_TABLE_SIZE]); + for (key, &rank) in encoder { + if key.len() == 2 { + let idx = ((key[0] as u16) << 8 | key[1] as u16) as usize; + pair_table[idx] = rank; + } + } + pair_table + } + + fn check_equivalence(piece: &[u8]) { + let encoder = synthetic_encoder(); + let pair_table = build_pair_table(&encoder); + let vanilla = byte_pair_encode(piece, &encoder); + let patched = byte_pair_encode_with_table(piece, &encoder, &pair_table); + assert_eq!( + vanilla, + patched, + "vanilla vs pair-table diverged on piece of length {}", + piece.len(), + ); + } + + /// Generate an alphabetic piece of the requested length by cycling a..z. + fn alpha_piece(n: usize) -> Vec { + (0..n).map(|i| b'a' + ((i % 26) as u8)).collect() + } + + #[test] + fn equivalence_length_1_direct_lookup() { + // Single byte takes the early-return path; pair table never consulted. + check_equivalence(b"a"); + check_equivalence(b"z"); + check_equivalence(b"0"); + } + + #[test] + fn equivalence_short_pieces_linear_path() { + // Pieces < 100 bytes go through `_byte_pair_merge` (linear scan). + check_equivalence(b"th"); + check_equivalence(b"the"); + check_equivalence(b"that"); + check_equivalence(b"information"); + check_equivalence(b"theandingionfor"); + } + + #[test] + fn equivalence_just_under_100b_cutoff() { + check_equivalence(&alpha_piece(50)); + check_equivalence(&alpha_piece(98)); + check_equivalence(&alpha_piece(99)); + } + + #[test] + fn equivalence_at_100b_cutoff() { + // 100 bytes is the boundary: dispatches to `_byte_pair_merge_large`. + check_equivalence(&alpha_piece(100)); + check_equivalence(&alpha_piece(101)); + } + + #[test] + fn equivalence_long_pieces_heap_path() { + // Well into the heap-based path. + check_equivalence(&alpha_piece(200)); + check_equivalence(&alpha_piece(500)); + check_equivalence(&alpha_piece(1000)); + } + + #[test] + fn equivalence_repeated_pairs() { + // Many identical 2-byte pairs to stress the initial-scan path. + check_equivalence(&[b'x'; 5]); + check_equivalence(&[b'x'; 99]); + check_equivalence(&[b'x'; 150]); + } +}