@@ -63,6 +63,8 @@ pub struct BytePairEncoding {
     /// But we don't have efficient access to it and therefore store it here again.
     /// If there is none, then the value is set to u32::MAX.
     next_prefix_match: Vec<u32>,
+    /// Hash factor used to prevent hash collisions.
+    hash_factor: u64,
 }
 
 fn serialize_daac<S: Serializer>(
@@ -156,25 +158,51 @@ fn token_bytes<'a>(all_tokens: &'a [u8], token_starts: &[u32], token_id: u32) ->
     &all_tokens[token_range(token_starts, token_id)]
 }
 
-fn hash_bytes(bytes: &[u8]) -> u32 {
-    hash_bytes_with_factor(bytes, 17846336922010275747)
-}
-
-fn hash_bytes_with_factor(bytes: &[u8], factor: u64) -> u32 {
+fn hash_bytes(bytes: &[u8], factor: u64) -> u32 {
     let mut hasher = FnvHasher::default();
     bytes.hash(&mut hasher);
     // Note: we save 1/3 of space for the hashmap by only using the most significant bits of the hash.
     // To make them unique for the given tokens, we unfortunately have to add another multiplication.
     ((hasher.finish().wrapping_mul(factor)) >> 32) as u32
 }
 
+/// Find a suitable hash factor for the given tiktoken dictionary that prevents collisions
+/// when constructing a [`BytePairEncoding`] from those tokens.
+#[cfg(all(feature = "tiktoken-rs", feature = "rand"))]
+pub fn find_hash_factor_for_tiktoken(bpe: &tiktoken_rs::CoreBPE, len: usize) -> u64 {
+    find_hash_factor_for_dictionary((0..len).map(|i| bpe._decode_native(&[i])))
+}
+
+/// Find a suitable hash factor for a set of given tokens that prevents collisions when
+/// constructing a [`BytePairEncoding`] from those tokens.
+#[cfg(feature = "rand")]
+pub fn find_hash_factor_for_dictionary(iter: impl Iterator<Item = Vec<u8>>) -> u64 {
+    use std::collections::HashSet;
+
+    use rand::Rng;
+
+    let all_tokens = iter.collect_vec();
+    let mut rnd = rand::thread_rng();
+    loop {
+        let factor: u64 = rnd.gen();
+        let mut seen = HashSet::new();
+        if all_tokens
+            .iter()
+            .all(|token| seen.insert(hash_bytes(token, factor)))
+        {
+            return factor;
+        }
+    }
+}
+
 fn find_token_by_bytes(
     all_tokens: &[u8],
     token_starts: &[u32],
     bytes_hash_to_token: &FnvHashMap<u32, u32>,
     bytes: &[u8],
+    hash_factor: u64,
 ) -> Option<u32> {
-    let hash = hash_bytes(bytes);
+    let hash = hash_bytes(bytes, hash_factor);
     let token = *bytes_hash_to_token.get(&hash)?;
     if token_bytes(all_tokens, token_starts, token) == bytes {
         Some(token)
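
Taken together, the new `hash_factor` parameter and the factor search make the truncated 32-bit hashes collision-free for a concrete dictionary. A minimal usage sketch (the toy token list and the `bpe::byte_pair_encoding` module path are assumptions inferred from the test import further down; requires the `rand` feature):

use bpe::byte_pair_encoding::{find_hash_factor_for_dictionary, BytePairEncoding};

fn main() {
    // Toy dictionary: two single-byte tokens plus their merge. Real
    // dictionaries come from a tokenizer definition such as tiktoken's.
    let tokens: Vec<Vec<u8>> = vec![b"a".to_vec(), b"b".to_vec(), b"ab".to_vec()];
    // Retry random factors until the upper 32 bits of the multiplied
    // FNV hashes are unique across all tokens of this dictionary.
    let factor = find_hash_factor_for_dictionary(tokens.clone().into_iter());
    // The factor is only meaningful for the dictionary it was found for.
    let _bpe = BytePairEncoding::from_dictionary(tokens.into_iter(), Some(factor));
}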
@@ -192,20 +220,40 @@ impl BytePairEncoding {
         &BPE_O200K
     }
 
-    /// Construct a BytePairEncoding instance frmo a tiktoken dictionary.
+    /// Construct a BytePairEncoding instance from a tiktoken dictionary.
+    /// A suitable hash factor may be necessary to prevent hash collisions,
+    /// which can be found using [`find_hash_factor_for_tiktoken`].
+    ///
+    /// The recommended approach is to store the serialized value and reuse that,
+    /// to prevent repeating the cost of computing the hash factor and encoding.
     #[cfg(feature = "tiktoken-rs")]
-    pub fn from_tiktoken(tiktoken_bpe: &tiktoken_rs::CoreBPE, num_tokens: usize) -> Self {
-        Self::from_dictionary((0..num_tokens).map(|i| tiktoken_bpe._decode_native(&[i])))
+    pub fn from_tiktoken(
+        tiktoken_bpe: &tiktoken_rs::CoreBPE,
+        num_tokens: usize,
+        hash_factor: Option<u64>,
+    ) -> Self {
+        Self::from_dictionary(
+            (0..num_tokens).map(|i| tiktoken_bpe._decode_native(&[i])),
+            hash_factor,
+        )
     }
 
-    /// Construct a BytePairEncoding instance from an iterator which enumerates all tokens.
-    pub fn from_dictionary(iter: impl Iterator<Item = Vec<u8>>) -> Self {
+    /// Construct a BytePairEncoding instance from an iterator that enumerates all tokens.
+    /// A suitable hash factor may be necessary to prevent hash collisions, which can be
+    /// found using [`find_hash_factor_for_dictionary`].
+    ///
+    /// The recommended approach is to store the serialized value and reuse that,
+    /// to prevent repeating the cost of computing the hash factor and encoding.
+    pub fn from_dictionary(iter: impl Iterator<Item = Vec<u8>>, hash_factor: Option<u64>) -> Self {
+        let hash_factor = hash_factor
+            .inspect(|f| assert_ne!(*f, 0, "hash factor must be larger than zero"))
+            .unwrap_or(1);
         let mut all_tokens = Vec::new();
         let mut all_tokens_rev = Vec::new();
         let mut token_starts = vec![0];
         let mut bytes_hash_to_token = FnvHashMap::default();
         for (i, token) in iter.enumerate() {
-            bytes_hash_to_token.insert(hash_bytes(&token), i as u32);
+            bytes_hash_to_token.insert(hash_bytes(&token, hash_factor), i as u32);
             all_tokens_rev.extend(token.iter().copied().rev());
             all_tokens.extend(token);
             token_starts.push(all_tokens.len() as u32);
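
The doc comments above recommend paying the factor search and construction cost once and persisting the result. A sketch of that workflow with rmp_serde, mirroring the updated test at the bottom of this diff (the output file name is an assumption; requires the `rand` and `tiktoken-rs` features):

use std::fs::File;

use bpe::byte_pair_encoding::{find_hash_factor_for_tiktoken, BytePairEncoding};
use serde::Serialize;

fn main() {
    let tiktoken_bpe = tiktoken_rs::o200k_base().unwrap();
    let num_tokens = 199998; // o200k token count, as used in the tests below
    // Expensive, but done only once: search a factor and build the encoding.
    let factor = find_hash_factor_for_tiktoken(&tiktoken_bpe, num_tokens);
    let bpe = BytePairEncoding::from_tiktoken(&tiktoken_bpe, num_tokens, Some(factor));
    // Persist the fully constructed instance for later runs.
    let file = File::create("bpe_o200k.dict").unwrap();
    bpe.serialize(&mut rmp_serde::Serializer::new(file)).unwrap();
}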
@@ -236,9 +284,13 @@ impl BytePairEncoding {
             let mut token1 = next_prefix_match[id];
             while token1 != u32::MAX {
                 let rest = &token[token_range(&token_starts, token1).len()..];
-                if let Some(token2) =
-                    find_token_by_bytes(&all_tokens, &token_starts, &bytes_hash_to_token, rest)
-                {
+                if let Some(token2) = find_token_by_bytes(
+                    &all_tokens,
+                    &token_starts,
+                    &bytes_hash_to_token,
+                    rest,
+                    hash_factor,
+                ) {
                     if token1 < id as u32
                         && token2 < id as u32
                         && is_valid_token_pair(&pair_lookup, &split_table, token1, token2)
@@ -264,6 +316,7 @@ impl BytePairEncoding {
             next_prefix_match,
             pair_lookup,
             split_table,
+            hash_factor,
         }
     }
 
@@ -308,6 +361,7 @@ impl BytePairEncoding {
             &self.token_starts,
             &self.bytes_hash_to_token,
             bytes,
+            self.hash_factor,
         )
     }
 
@@ -557,68 +611,44 @@ mod tests {
 
 #[cfg(test)]
 mod data {
-    use std::collections::HashSet;
     use std::fs::File;
     use std::path::PathBuf;
 
-    use rand::Rng;
     use serde::Serialize;
-    use tiktoken_rs::{cl100k_base, o200k_base};
-
-    use super::*;
 
-    const BPE_CL100K_LEN: usize = 100256;
-    const BPE_O200K_LEN: usize = 199998;
-
-    /// Use this to find a hashing factor for [`hash_bytes`] that prevents collisions.
-    /// 1. Ensure all supported tokenizers are in the list.
-    /// 2. Update the hash factor in [`hash_bytes`].
-    /// 3. Run [`update_token_dicts`] tests below to update data files.
-    #[test]
-    #[ignore = "run manually to find a suitable hash factor"]
-    fn find_hash_factor() {
-        let bpes = &mut [
-            (cl100k_base().unwrap(), BPE_CL100K_LEN),
-            (o200k_base().unwrap(), BPE_O200K_LEN),
-        ];
-        let mut rnd = rand::thread_rng();
-        loop {
-            let factor: u64 = rnd.gen();
-            if bpes.iter().all(|(bpe, len)| {
-                let mut seen = HashSet::with_capacity(*len);
-                (0..*len)
-                    .all(|i| seen.insert(hash_bytes_with_factor(&bpe._decode_native(&[i]), factor)))
-            }) {
-                println!("hash factor: {factor}");
-                return;
-            }
-        }
-    }
+    use crate::byte_pair_encoding::BytePairEncoding;
 
     #[test]
     fn update_token_dicts() {
         serialize_tokens(
-            &cl100k_base().expect("tiktoken initialization must not fail!"),
-            BPE_CL100K_LEN,
             "cl100k",
+            &tiktoken_rs::cl100k_base().expect("tiktoken initialization must not fail!"),
+            100256,
+            17846336922010275747,
         );
         serialize_tokens(
-            &o200k_base().expect("tiktoken initialization must not fail!"),
-            BPE_O200K_LEN,
             "o200k",
+            &tiktoken_rs::o200k_base().expect("tiktoken initialization must not fail!"),
+            199998,
+            17846336922010275747,
         );
     }
 
     #[track_caller]
-    fn serialize_tokens(dict: &tiktoken_rs::CoreBPE, num_tokens: usize, name: &str) {
+    fn serialize_tokens(
+        name: &str,
+        dict: &tiktoken_rs::CoreBPE,
+        num_tokens: usize,
+        hash_factor: u64,
+    ) {
         let path = PathBuf::from(file!());
         let dir = path.parent().unwrap();
         let data_file = dir.join(format!("data/bpe_{name}.dict"));
         let current_dir = std::env::current_dir().unwrap();
         let abs_path = current_dir.parent().unwrap().parent().unwrap();
         let file = File::create(abs_path.join(data_file)).unwrap();
         let mut serializer = rmp_serde::Serializer::new(file);
-        BytePairEncoding::from_tiktoken(dict, num_tokens)
+        BytePairEncoding::from_tiktoken(dict, num_tokens, Some(hash_factor))
             .serialize(&mut serializer)
             .unwrap();
     }
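
On later runs, the stored dictionary can be loaded back directly, skipping both the factor search and the construction. A sketch, assuming `BytePairEncoding` also implements `Deserialize` as its serialization support suggests:

use bpe::byte_pair_encoding::BytePairEncoding;

fn main() {
    // Restores the encoding, including the hash factor it was built with.
    let bytes = std::fs::read("bpe_o200k.dict").unwrap();
    let _bpe: BytePairEncoding = rmp_serde::from_slice(&bytes).unwrap();
}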