Commit a26c577

Merge pull request #19 from github/expose-hash-factor
Expose hash factor in API
2 parents f0c9def + cbaaa2d commit a26c577

File tree

3 files changed: +83 -53 lines changed

crates/bpe/src/byte_pair_encoding.rs

Lines changed: 83 additions & 53 deletions
@@ -63,6 +63,8 @@ pub struct BytePairEncoding {
     /// But we don't have efficient access to it and therefore store it here again.
     /// If there is none, then the value is set to u32::MAX.
     next_prefix_match: Vec<u32>,
+    /// Hash factor used to prevent hash collisions.
+    hash_factor: u64,
 }

 fn serialize_daac<S: Serializer>(
@@ -156,25 +158,51 @@ fn token_bytes<'a>(all_tokens: &'a [u8], token_starts: &[u32], token_id: u32) ->
     &all_tokens[token_range(token_starts, token_id)]
 }

-fn hash_bytes(bytes: &[u8]) -> u32 {
-    hash_bytes_with_factor(bytes, 17846336922010275747)
-}
-
-fn hash_bytes_with_factor(bytes: &[u8], factor: u64) -> u32 {
+fn hash_bytes(bytes: &[u8], factor: u64) -> u32 {
     let mut hasher = FnvHasher::default();
     bytes.hash(&mut hasher);
     // Note: we save 1/3 of space for the hashmap by only using the most significant bits of the hash.
     // To make them unique for the given tokens, we have to add unfortunately another multiplication.
     ((hasher.finish().wrapping_mul(factor)) >> 32) as u32
 }

+/// Find a suitable hash factor for the given tiktoken dictionary that prevents collisions
+/// when constructing a [`BytePairEncoding`] from those tokens.
+#[cfg(all(feature = "tiktoken-rs", feature = "rand"))]
+pub fn find_hash_factor_for_tiktoken(bpe: &tiktoken_rs::CoreBPE, len: usize) -> u64 {
+    find_hash_factor_for_dictionary((0..len).map(|i| bpe._decode_native(&[i])))
+}
+
+/// Find a suitable hash factor for a set of given tokens that prevents collisions when
+/// constructing a [`BytePairEncoding`] from those tokens.
+#[cfg(feature = "rand")]
+pub fn find_hash_factor_for_dictionary(iter: impl Iterator<Item = Vec<u8>>) -> u64 {
+    use std::collections::HashSet;
+
+    use rand::Rng;
+
+    let all_tokens = iter.collect_vec();
+    let mut rnd = rand::thread_rng();
+    loop {
+        let factor: u64 = rnd.gen();
+        let mut seen = HashSet::new();
+        if all_tokens
+            .iter()
+            .all(|token| seen.insert(hash_bytes(token, factor)))
+        {
+            return factor;
+        }
+    }
+}
+
 fn find_token_by_bytes(
     all_tokens: &[u8],
     token_starts: &[u32],
     bytes_hash_to_token: &FnvHashMap<u32, u32>,
     bytes: &[u8],
+    hash_factor: u64,
 ) -> Option<u32> {
-    let hash = hash_bytes(bytes);
+    let hash = hash_bytes(bytes, hash_factor);
     let token = *bytes_hash_to_token.get(&hash)?;
     if token_bytes(all_tokens, token_starts, token) == bytes {
         Some(token)
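
The new `find_hash_factor_for_dictionary` loop above simply rejection-samples random `u64` factors until `hash_bytes` maps every token in the dictionary to a distinct 32-bit value. A minimal usage sketch, assuming the crate is consumed under the path `bpe::byte_pair_encoding` with the `rand` feature enabled, and using a made-up three-token dictionary:

```rust
use bpe::byte_pair_encoding::{find_hash_factor_for_dictionary, BytePairEncoding};

fn build_toy_encoder() -> BytePairEncoding {
    // Made-up toy dictionary; real dictionaries come from a tokenizer dump.
    let tokens: Vec<Vec<u8>> = vec![b"a".to_vec(), b"b".to_vec(), b"ab".to_vec()];
    // Retries random u64 factors until the multiply-and-shift hash is
    // collision-free over the whole dictionary.
    let factor = find_hash_factor_for_dictionary(tokens.clone().into_iter());
    BytePairEncoding::from_dictionary(tokens.into_iter(), Some(factor))
}
```

For a tiny dictionary almost any factor works on the first try; the search only becomes expensive for dictionaries with hundreds of thousands of tokens, which is why the result is meant to be computed once and reused.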
@@ -192,20 +220,40 @@ impl BytePairEncoding {
         &BPE_O200K
     }

-    /// Construct a BytePairEncoding instance frmo a tiktoken dictionary.
+    /// Construct a BytePairEncoding instance from a tiktoken dictionary.
+    /// A suitable hash factor may be necessary to prevent hash collisions,
+    /// which can be found using [`find_hash_factor_for_tiktoken`].
+    ///
+    /// The recommended approach is to store the serialized value and reuse that,
+    /// to prevent repeating the cost of computing the hash factor and encoding.
     #[cfg(feature = "tiktoken-rs")]
-    pub fn from_tiktoken(tiktoken_bpe: &tiktoken_rs::CoreBPE, num_tokens: usize) -> Self {
-        Self::from_dictionary((0..num_tokens).map(|i| tiktoken_bpe._decode_native(&[i])))
+    pub fn from_tiktoken(
+        tiktoken_bpe: &tiktoken_rs::CoreBPE,
+        num_tokens: usize,
+        hash_factor: Option<u64>,
+    ) -> Self {
+        Self::from_dictionary(
+            (0..num_tokens).map(|i| tiktoken_bpe._decode_native(&[i])),
+            hash_factor,
+        )
     }

-    /// Construct a BytePairEncoding instance from an iterator which enumerates all tokens.
-    pub fn from_dictionary(iter: impl Iterator<Item = Vec<u8>>) -> Self {
+    /// Construct a BytePairEncoding instance from an iterator that enumerates all tokens.
+    /// A suitable hash factor may be necessary to prevent hash collisions, which can be
+    /// found using [`find_hash_factor_for_dictionary`].
+    ///
+    /// The recommended approach is to store the serialized value and reuse that,
+    /// to prevent repeating the cost of computing the hash factor and encoding.
+    pub fn from_dictionary(iter: impl Iterator<Item = Vec<u8>>, hash_factor: Option<u64>) -> Self {
+        let hash_factor = hash_factor
+            .inspect(|f| assert_ne!(*f, 0, "hash factor must be larger than zero"))
+            .unwrap_or(1);
         let mut all_tokens = Vec::new();
         let mut all_tokens_rev = Vec::new();
         let mut token_starts = vec![0];
         let mut bytes_hash_to_token = FnvHashMap::default();
         for (i, token) in iter.enumerate() {
-            bytes_hash_to_token.insert(hash_bytes(&token), i as u32);
+            bytes_hash_to_token.insert(hash_bytes(&token, hash_factor), i as u32);
             all_tokens_rev.extend(token.iter().copied().rev());
             all_tokens.extend(token);
             token_starts.push(all_tokens.len() as u32);
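
Putting the two new parameters together, here is a hedged sketch of the construction path the updated doc comments describe, assuming the crate is used as `bpe` with the `tiktoken-rs` feature enabled. The token count 100256 and the factor 17846336922010275747 are the exact values the updated `update_token_dicts` test below passes for cl100k:

```rust
use bpe::byte_pair_encoding::BytePairEncoding;

fn build_cl100k() -> BytePairEncoding {
    let tiktoken = tiktoken_rs::cl100k_base().expect("tiktoken initialization must not fail!");
    // Passing `None` instead would fall back to factor 1, which does not
    // protect against hash collisions.
    BytePairEncoding::from_tiktoken(&tiktoken, 100256, Some(17846336922010275747))
}
```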
@@ -236,9 +284,13 @@ impl BytePairEncoding {
             let mut token1 = next_prefix_match[id];
             while token1 != u32::MAX {
                 let rest = &token[token_range(&token_starts, token1).len()..];
-                if let Some(token2) =
-                    find_token_by_bytes(&all_tokens, &token_starts, &bytes_hash_to_token, rest)
-                {
+                if let Some(token2) = find_token_by_bytes(
+                    &all_tokens,
+                    &token_starts,
+                    &bytes_hash_to_token,
+                    rest,
+                    hash_factor,
+                ) {
                     if token1 < id as u32
                         && token2 < id as u32
                         && is_valid_token_pair(&pair_lookup, &split_table, token1, token2)
@@ -264,6 +316,7 @@ impl BytePairEncoding {
             next_prefix_match,
             pair_lookup,
             split_table,
+            hash_factor,
         }
     }

@@ -308,6 +361,7 @@ impl BytePairEncoding {
             &self.token_starts,
             &self.bytes_hash_to_token,
             bytes,
+            self.hash_factor,
         )
     }

@@ -557,68 +611,44 @@ mod tests {

 #[cfg(test)]
 mod data {
-    use std::collections::HashSet;
     use std::fs::File;
     use std::path::PathBuf;

-    use rand::Rng;
     use serde::Serialize;
-    use tiktoken_rs::{cl100k_base, o200k_base};
-
-    use super::*;

-    const BPE_CL100K_LEN: usize = 100256;
-    const BPE_O200K_LEN: usize = 199998;
-
-    /// Use this to find a hashing factor for [`hash_bytes`] that prevents collisions.
-    /// 1. Ensure all supported tokenizers are in the list.
-    /// 2. Update the hash factor in [`hash_bytes`].
-    /// 3. Run [`update_token_dicts`] tests below to update data files.
-    #[test]
-    #[ignore = "run manually to find a suitable hash factor"]
-    fn find_hash_factor() {
-        let bpes = &mut [
-            (cl100k_base().unwrap(), BPE_CL100K_LEN),
-            (o200k_base().unwrap(), BPE_O200K_LEN),
-        ];
-        let mut rnd = rand::thread_rng();
-        loop {
-            let factor: u64 = rnd.gen();
-            if bpes.iter().all(|(bpe, len)| {
-                let mut seen = HashSet::with_capacity(*len);
-                (0..*len)
-                    .all(|i| seen.insert(hash_bytes_with_factor(&bpe._decode_native(&[i]), factor)))
-            }) {
-                println!("hash factor: {factor}");
-                return;
-            }
-        }
-    }
+    use crate::byte_pair_encoding::BytePairEncoding;

     #[test]
     fn update_token_dicts() {
         serialize_tokens(
-            &cl100k_base().expect("tiktoken initialization must not fail!"),
-            BPE_CL100K_LEN,
             "cl100k",
+            &tiktoken_rs::cl100k_base().expect("tiktoken initialization must not fail!"),
+            100256,
+            17846336922010275747,
         );
         serialize_tokens(
-            &o200k_base().expect("tiktoken initialization must not fail!"),
-            BPE_O200K_LEN,
             "o200k",
+            &tiktoken_rs::o200k_base().expect("tiktoken initialization must not fail!"),
+            199998,
+            17846336922010275747,
         );
     }

     #[track_caller]
-    fn serialize_tokens(dict: &tiktoken_rs::CoreBPE, num_tokens: usize, name: &str) {
+    fn serialize_tokens(
+        name: &str,
+        dict: &tiktoken_rs::CoreBPE,
+        num_tokens: usize,
+        hash_factor: u64,
+    ) {
         let path = PathBuf::from(file!());
         let dir = path.parent().unwrap();
         let data_file = dir.join(format!("data/bpe_{name}.dict"));
         let current_dir = std::env::current_dir().unwrap();
         let abs_path = current_dir.parent().unwrap().parent().unwrap();
         let file = File::create(abs_path.join(data_file)).unwrap();
         let mut serializer = rmp_serde::Serializer::new(file);
-        BytePairEncoding::from_tiktoken(dict, num_tokens)
+        BytePairEncoding::from_tiktoken(dict, num_tokens, Some(hash_factor))
             .serialize(&mut serializer)
             .unwrap();
     }
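
The doc comments recommend serializing the constructed encoder and reusing those bytes, which is exactly what `serialize_tokens` does with `rmp_serde`. A hedged round-trip sketch of that reuse pattern, assuming `BytePairEncoding` also implements `Deserialize` (only the `Serialize` side is visible in this diff) and that the crate is used as `bpe`:

```rust
use serde::Serialize;

use bpe::byte_pair_encoding::BytePairEncoding;

fn roundtrip(bpe: &BytePairEncoding) -> BytePairEncoding {
    // Serialize once, e.g. at build time, as the test above does.
    let mut buf = Vec::new();
    bpe.serialize(&mut rmp_serde::Serializer::new(&mut buf))
        .unwrap();
    // Deserializing restores the stored hash factor along with the tables,
    // so neither the factor search nor the dictionary encoding is repeated
    // at startup.
    rmp_serde::from_slice(&buf).unwrap()
}
```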
crates/bpe/src/data/bpe_cl100k.dict

9 Bytes
Binary file not shown.

crates/bpe/src/data/bpe_o200k.dict

9 Bytes
Binary file not shown.
