@@ -1,4 +1,5 @@
-use fancy_regex::Regex;
+use fancy_regex::Regex as FancyRegex;
+use regex::Regex;
 use rustc_hash::FxHashMap as HashMap;
 use rustc_hash::FxHashSet as HashSet;
 use thread_local::ThreadLocal;
@@ -133,9 +134,9 @@ pub struct CoreBPE {
     decoder: HashMap<Rank, &'static [u8]>,
     special_tokens_decoder: HashMap<Rank, Vec<u8>>,
     regex: Regex,
-    special_regex: Regex,
+    special_regex: FancyRegex,
     regex_tls: ThreadLocal<Regex>,
-    special_regex_tls: ThreadLocal<Regex>,
+    special_regex_tls: ThreadLocal<FancyRegex>,
     sorted_token_bytes: Vec<&'static [u8]>,
 }
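Context for the split above, sketched under the assumption that both crates are dependencies: the tiktoken patterns use look-ahead (e.g. `\s+(?!\S)` in cl100k_base), which the linear-time `regex` crate rejects by design, while the backtracking `fancy_regex` engine accepts it. That is why the struct now carries both engine types.

```rust
// Minimal sketch; `\s+(?!\S)` is the look-ahead alternative from
// tiktoken's patterns, not something introduced by this diff.
fn main() {
    let lookahead = r"\s+(?!\S)";
    // The `regex` crate guarantees linear time and so rejects look-around...
    assert!(regex::Regex::new(lookahead).is_err());
    // ...while the backtracking `fancy_regex` engine accepts it.
    assert!(fancy_regex::Regex::new(lookahead).is_ok());
}
```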
@@ -144,7 +145,7 @@ impl CoreBPE {
         self.regex_tls.get_or(|| self.regex.clone())
     }

-    fn _get_tl_special_regex(&self) -> &Regex {
+    fn _get_tl_special_regex(&self) -> &FancyRegex {
         self.special_regex_tls.get_or(|| self.special_regex.clone())
     }
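A minimal, self-contained sketch of the per-thread caching idiom these accessors use (the struct and names here are illustrative, not from the diff): each thread lazily clones the master regex once via `ThreadLocal::get_or` and then reuses its own copy, so matching never contends on shared scratch state.

```rust
use regex::Regex;
use thread_local::ThreadLocal;

// Illustrative holder; the real CoreBPE keeps both a Regex and a FancyRegex.
struct Cached {
    regex: Regex,
    regex_tls: ThreadLocal<Regex>,
}

impl Cached {
    // First call on a given thread clones the master regex; later calls
    // return that thread's cached copy.
    fn get(&self) -> &Regex {
        self.regex_tls.get_or(|| self.regex.clone())
    }
}

fn main() {
    let c = Cached { regex: Regex::new(r"\w+").unwrap(), regex_tls: ThreadLocal::new() };
    assert!(c.get().is_match("abc"));
}
```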
@@ -161,24 +162,85 @@ impl CoreBPE {
         ret
     }

-    fn _encode_ordinary_native(&self, text: &str) -> Vec<Rank> {
+    fn _encode_ordinary_native_impl(&self, text: &str, ret: &mut Vec<Rank>) -> usize {
         // This is the core of the encoding logic; the other functions in here
         // just make things complicated :-)
         let regex = self._get_tl_regex();
-        let mut ret = vec![];
+        let mut last_end = 0;
+        let mut last_piece_token_len = 0;
+        let mut piece: &[u8] = &[];
         for mat in regex.find_iter(text) {
-            let piece = mat.unwrap().as_str().as_bytes();
+            piece = mat.as_str().as_bytes();
+            let start = mat.start();
+            let end = mat.end();
+
+            // If there is a whitespace gap between this piece and the previous
+            // one, add the gap's tokens first.
+            if last_end < start {
+                // If the current piece starts with whitespace, the whole gap is one new piece.
+                if mat
+                    .as_str()
+                    .chars()
+                    .next()
+                    .map_or(false, |c| c.is_whitespace())
+                {
+                    let wpiece = text[last_end..start].as_bytes();
+                    match self.encoder.get(wpiece) {
+                        Some(token) => ret.push(*token),
+                        None => ret.extend(&byte_pair_encode(wpiece, &self.encoder)),
+                    }
+                // Otherwise the gap's last char makes one piece, and the rest
+                // (if any) makes another.
+                } else {
+                    let last_char_size = text[last_end..start]
+                        .chars()
+                        .next_back()
+                        .unwrap()
+                        .len_utf8();
+                    // Example for gpt-4o: in the text "= 6", "=" and "6" are matches
+                    // and " " is the gap, so the gap makes just one piece.
+                    if last_char_size < start - last_end {
+                        let wpiece1 = text[last_end..start - last_char_size].as_bytes();
+                        match self.encoder.get(wpiece1) {
+                            Some(token) => ret.push(*token),
+                            None => ret.extend(&byte_pair_encode(wpiece1, &self.encoder)),
+                        }
+                    }
+                    let wpiece2 = text[start - last_char_size..start].as_bytes();
+                    match self.encoder.get(wpiece2) {
+                        Some(token) => ret.push(*token),
+                        None => ret.extend(&byte_pair_encode(wpiece2, &self.encoder)),
+                    }
+                }
+            }
+            last_end = end;
+
+            // Now add the piece's own tokens.
             match self.encoder.get(piece) {
                 Some(token) => ret.push(*token),
                 None => ret.extend(&byte_pair_encode(piece, &self.encoder)),
             }
         }
-        ret
+        // A gap of whitespace at the end of the text.
+        if last_end < text.len() {
+            piece = text[last_end..].as_bytes();
+            match self.encoder.get(piece) {
+                Some(token) => ret.push(*token),
+                None => ret.extend(&byte_pair_encode(piece, &self.encoder)),
+            }
+        }
+
+        if !piece.is_empty() {
+            last_piece_token_len = match self.encoder.get(piece) {
+                Some(_) => 1,
+                None => byte_pair_encode(piece, &self.encoder).len(),
+            };
+        }
+
+        last_piece_token_len
     }

     fn _encode_native(&self, text: &str, allowed_special: &HashSet<&str>) -> (Vec<Rank>, usize) {
         let special_regex = self._get_tl_special_regex();
-        let regex = self._get_tl_regex();
+
         let mut ret = vec![];

         let mut start = 0;
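To make the gap handling above concrete, here is a dependency-free sketch of just the splitting rule, with the encoder lookups elided; it mirrors what the removed `\s+(?!\S)` look-ahead used to produce. The function name and shape are illustrative only.

```rust
// Given the span of text between two regex matches, emit it either whole
// (when the next match itself starts with whitespace) or split so that the
// gap's final character forms its own piece.
fn split_gap<'a>(gap: &'a str, next_starts_with_ws: bool) -> Vec<&'a str> {
    if gap.is_empty() {
        return vec![];
    }
    if next_starts_with_ws {
        return vec![gap]; // the whole gap is one piece
    }
    let last_char_size = gap.chars().next_back().unwrap().len_utf8();
    if last_char_size < gap.len() {
        // e.g. "   " splits into "  " and " "
        vec![&gap[..gap.len() - last_char_size], &gap[gap.len() - last_char_size..]]
    } else {
        vec![gap] // e.g. the single space in "= 6"
    }
}

fn main() {
    assert_eq!(split_gap(" ", false), vec![" "]);
    assert_eq!(split_gap("   ", false), vec!["  ", " "]);
    assert_eq!(split_gap("   ", true), vec!["   "]);
}
```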
@@ -201,17 +263,10 @@ impl CoreBPE {
             }
             let end = next_special.map_or(text.len(), |m| m.start());

-            // Okay, here we go, compare this logic to _encode_ordinary_native
-            for mat in regex.find_iter(&text[start..end]) {
-                let piece = mat.unwrap().as_str().as_bytes();
-                if let Some(token) = self.encoder.get(piece) {
-                    last_piece_token_len = 1;
-                    ret.push(*token);
-                    continue;
-                }
-                let tokens = byte_pair_encode(piece, &self.encoder);
-                last_piece_token_len = tokens.len();
-                ret.extend(&tokens);
+            if end > start {
+                // The regex is no longer fetched and passed in here;
+                // _encode_ordinary_native_impl fetches its own, which seems harmless.
+                last_piece_token_len =
+                    self._encode_ordinary_native_impl(&text[start..end], &mut ret);
             }

             match next_special {
@@ -271,6 +326,13 @@ impl CoreBPE {
         (tokens, last_piece_token_len)
     }

+    fn _encode_ordinary_native(&self, text: &str) -> Vec<Rank> {
+        // This wrapper is needed for callers that do not pass in a ret buffer.
+        let mut ret = vec![];
+        self._encode_ordinary_native_impl(text, &mut ret);
+        ret
+    }
+
     fn _encode_unstable_native(
         &self,
         text: &str,
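The shape of this refactor, sketched in isolation with hypothetical stand-ins: the `_impl` variant appends into a caller-supplied buffer and reports the token count of the final piece, so `_encode_native` can reuse one `Vec` across segments, while the thin wrapper preserves the old allocate-and-return signature.

```rust
// Placeholder tokenizer; only the calling convention is the point.
fn encode_impl(text: &str, ret: &mut Vec<u32>) -> usize {
    ret.extend(text.bytes().map(u32::from)); // stand-in "tokenization"
    1 // stand-in last-piece token count
}

fn encode(text: &str) -> Vec<u32> {
    let mut ret = vec![];
    encode_impl(text, &mut ret);
    ret
}
```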
@@ -302,7 +364,7 @@ impl CoreBPE {
         // Separating this from the loop below helps with performance in a common case.
         let mut point = self
             .sorted_token_bytes
-            .partition_point(|x| *x < unstable_bytes.as_slice());
+            .partition_point(|x| &x[..] < unstable_bytes.as_slice());
         while point < self.sorted_token_bytes.len()
             && self.sorted_token_bytes[point].starts_with(&unstable_bytes)
         {
@@ -318,9 +380,7 @@ impl CoreBPE {
         for i in 1..unstable_bytes.len() {
             let prefix = &unstable_bytes[..i];
             let suffix = &unstable_bytes[i..];
-            let mut point = self
-                .sorted_token_bytes
-                .partition_point(|x| *x < suffix);
+            let mut point = self.sorted_token_bytes.partition_point(|x| &x[..] < suffix);
             // TODO: Perf optimisation if suffix starts with " "?
             while point < self.sorted_token_bytes.len()
                 && self.sorted_token_bytes[point].starts_with(suffix)
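Both hunks perform the same lookup, sketched below with invented data: `partition_point` binary-searches for the first sorted token not lexicographically below the query, then a linear scan collects every token starting with it. The `&x[..]` form reborrows the closure's `&&'static [u8]` argument as a plain `&[u8]` so the comparison type-checks cleanly.

```rust
// `sorted` stands in for sorted_token_bytes (sorted byte-lexicographically).
fn tokens_with_prefix<'a>(sorted: &[&'a [u8]], query: &[u8]) -> Vec<&'a [u8]> {
    // First index whose token is >= query.
    let mut point = sorted.partition_point(|x| &x[..] < query);
    let mut out = vec![];
    while point < sorted.len() && sorted[point].starts_with(query) {
        out.push(sorted[point]);
        point += 1;
    }
    out
}

fn main() {
    let sorted: Vec<&[u8]> = vec![b"a", b"ab", b"abc", b"b"];
    assert_eq!(tokens_with_prefix(&sorted, b"ab"), [&b"ab"[..], &b"abc"[..]]);
}
```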
@@ -393,15 +453,15 @@ impl CoreBPE {
         encoder: HashMap<Vec<u8>, Rank>,
         special_tokens_encoder: HashMap<String, Rank>,
         pattern: &str,
-    ) -> Result<Self, fancy_regex::Error> {
+    ) -> Result<Self, regex::Error> {
         let regex = Regex::new(pattern)?;

         let special_regex = {
             let parts = special_tokens_encoder
                 .keys()
                 .map(|s| fancy_regex::escape(s))
                 .collect::<Vec<_>>();
-            Regex::new(&parts.join("|"))?
+            FancyRegex::new(&parts.join("|")).unwrap()
         };

         // Use unsafe to extend the lifetime of references to the encoder's keys
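A hedged sketch of what the changed signature implies, approximated outside the struct (the real constructor and its callers are not shown in this hunk): invalid user-supplied patterns now surface as `regex::Error`, while the special-token regex, being an alternation of escaped literals, is assumed to always compile, hence the `.unwrap()`.

```rust
// Standalone approximation of the constructor's regex setup.
fn compile(pattern: &str, specials: &[&str]) -> Result<regex::Regex, regex::Error> {
    let main = regex::Regex::new(pattern)?; // errors propagate as regex::Error
    let parts: Vec<String> = specials
        .iter()
        .map(|s| fancy_regex::escape(s).into_owned())
        .collect();
    // Escaped literals joined with '|' always form a valid pattern.
    let _special = fancy_regex::Regex::new(&parts.join("|")).unwrap();
    Ok(main)
}
```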