
Commit a076df4

Majdoddin authored and tmm1 committed
switch to regex::Regex (#22)
Based on openai#331. Uses regex::Regex in _encode_ordinary_native instead of fancy-regex, for roughly a 6x speedup. To make the split patterns compatible with the regex crate, the whitespace-handling parts of the patterns are dropped and whitespace gaps are handled in Rust code instead, still producing exactly the same output. _encode_native now calls _encode_ordinary_native_impl directly (_encode_ordinary_native is now just a wrapper around _encode_ordinary_native_impl).
1 parent a52c83f commit a076df4
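
For context on an API difference visible throughout the diff: `fancy_regex::Regex::find_iter` yields `Result<Match, _>` because backtracking can fail at match time, while `regex::Regex::find_iter` yields `Match` directly, which is why the `mat.unwrap()` calls disappear below. A minimal, illustrative sketch of the two call shapes (the pattern here is a stand-in, not one of the tokenizer's patterns):

```rust
// Illustrative only: the same iteration written against both crates.
fn main() {
    let text = "ab 12";

    // fancy-regex: matching can fail at runtime, so each item is a Result.
    let fr = fancy_regex::Regex::new(r"\w+").unwrap();
    for mat in fr.find_iter(text) {
        let piece = mat.unwrap().as_str();
        println!("fancy: {piece}");
    }

    // regex: once the pattern compiles, matching cannot fail; items are plain Matches.
    let re = regex::Regex::new(r"\w+").unwrap();
    for mat in re.find_iter(text) {
        let piece = mat.as_str();
        println!("regex: {piece}");
    }
}
```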

File tree: 3 files changed, +89 -44 lines changed

src/corebpe.rs

Lines changed: 86 additions & 26 deletions
```diff
@@ -1,4 +1,5 @@
-use fancy_regex::Regex;
+use fancy_regex::Regex as FancyRegex;
+use regex::Regex;
 use rustc_hash::FxHashMap as HashMap;
 use rustc_hash::FxHashSet as HashSet;
 use thread_local::ThreadLocal;
@@ -133,9 +134,9 @@ pub struct CoreBPE {
     decoder: HashMap<Rank, &'static [u8]>,
     special_tokens_decoder: HashMap<Rank, Vec<u8>>,
     regex: Regex,
-    special_regex: Regex,
+    special_regex: FancyRegex,
     regex_tls: ThreadLocal<Regex>,
-    special_regex_tls: ThreadLocal<Regex>,
+    special_regex_tls: ThreadLocal<FancyRegex>,
     sorted_token_bytes: Vec<&'static [u8]>,
 }

@@ -144,7 +145,7 @@ impl CoreBPE {
         self.regex_tls.get_or(|| self.regex.clone())
     }

-    fn _get_tl_special_regex(&self) -> &Regex {
+    fn _get_tl_special_regex(&self) -> &FancyRegex {
         self.special_regex_tls.get_or(|| self.special_regex.clone())
     }
```
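The thread-local regex caching kept above is easy to read in isolation. A standalone sketch of the pattern (an illustrative `Splitter` struct, not the crate's `CoreBPE`):

```rust
// Each thread lazily clones its own compiled Regex, so hot matching paths
// never contend on a shared instance. Illustrative struct, not the crate's code.
use regex::Regex;
use thread_local::ThreadLocal;

struct Splitter {
    regex: Regex,
    regex_tls: ThreadLocal<Regex>,
}

impl Splitter {
    fn new(pattern: &str) -> Result<Self, regex::Error> {
        Ok(Self {
            regex: Regex::new(pattern)?,
            regex_tls: ThreadLocal::new(),
        })
    }

    fn tl_regex(&self) -> &Regex {
        // First call on a given thread clones the template; later calls reuse it.
        self.regex_tls.get_or(|| self.regex.clone())
    }
}

fn main() {
    let splitter = Splitter::new(r"\p{L}+").unwrap();
    assert!(splitter.tl_regex().is_match("hello"));
}
```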
```diff
@@ -161,24 +162,85 @@ impl CoreBPE {
         ret
     }

-    fn _encode_ordinary_native(&self, text: &str) -> Vec<Rank> {
+    fn _encode_ordinary_native_impl(&self, text: &str, ret: &mut Vec<Rank>) -> usize {
         // This is the core of the encoding logic; the other functions in here
         // just make things complicated :-)
         let regex = self._get_tl_regex();
-        let mut ret = vec![];
+        let mut last_end = 0;
+        let mut last_piece_token_len = 0;
+        let mut piece: &[u8] = &[];
         for mat in regex.find_iter(text) {
-            let piece = mat.unwrap().as_str().as_bytes();
+            piece = mat.as_str().as_bytes();
+            let start = mat.start();
+            let end = mat.end();
+
+            // If there is a whitespace gap between this piece and the previous piece, add its tokens
+            if last_end < start {
+                // If the current piece starts with whitespace, the whole gap is one new piece
+                if mat
+                    .as_str()
+                    .chars()
+                    .next()
+                    .map_or(false, |c| c.is_whitespace())
+                {
+                    let wpiece = text[last_end..start].as_bytes();
+                    match self.encoder.get(wpiece) {
+                        Some(token) => ret.push(*token),
+                        None => ret.extend(&byte_pair_encode(wpiece, &self.encoder)),
+                    }
+                // Otherwise the last char of the gap makes a piece, and the rest (if any) makes another piece
+                } else {
+                    let last_char_size = &text[last_end..start]
+                        .chars()
+                        .next_back()
+                        .unwrap()
+                        .len_utf8();
+                    // Example for gpt-4o: in the text "= 6", "=" and "6" are matches and " " is the gap,
+                    // so the gap makes just one piece
+                    if last_char_size < &(start - last_end) {
+                        let wpiece1 = text[last_end..start - last_char_size].as_bytes();
+                        match self.encoder.get(wpiece1) {
+                            Some(token) => ret.push(*token),
+                            None => ret.extend(&byte_pair_encode(wpiece1, &self.encoder)),
+                        }
+                    }
+                    let wpiece2 = text[start - last_char_size..start].as_bytes();
+                    match self.encoder.get(wpiece2) {
+                        Some(token) => ret.push(*token),
+                        None => ret.extend(&byte_pair_encode(wpiece2, &self.encoder)),
+                    }
+                }
+            }
+            last_end = end;
+
+            // Now add the piece's tokens
             match self.encoder.get(piece) {
                 Some(token) => ret.push(*token),
                 None => ret.extend(&byte_pair_encode(piece, &self.encoder)),
             }
         }
-        ret
+        // Gap of whitespace at the end of the text
+        if last_end < text.len() {
+            piece = text[last_end..text.len()].as_bytes();
+            match self.encoder.get(piece) {
+                Some(token) => ret.push(*token),
+                None => ret.extend(&byte_pair_encode(piece, &self.encoder)),
+            }
+        }
+
+        if !piece.is_empty() {
+            last_piece_token_len = match self.encoder.get(piece) {
+                Some(token) => 1,
+                None => byte_pair_encode(piece, &self.encoder).len(),
+            };
+        };
+
+        last_piece_token_len
     }

     fn _encode_native(&self, text: &str, allowed_special: &HashSet<&str>) -> (Vec<Rank>, usize) {
         let special_regex = self._get_tl_special_regex();
-        let regex = self._get_tl_regex();
+
         let mut ret = vec![];

         let mut start = 0;
```
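The gap handling above is easier to follow on a toy input. Below is a standalone sketch of just the splitting rule, with BPE left out and a simplified stand-in pattern (not the real cl100k/o200k pattern): regex matches become pieces, and the whitespace gaps between them are split the way `_encode_ordinary_native_impl` does it.

```rust
// Standalone sketch of the gap rule, illustrative only.
use regex::Regex;

fn split_with_gaps<'a>(text: &'a str, re: &Regex) -> Vec<&'a str> {
    let mut pieces = Vec::new();
    let mut last_end = 0;
    for mat in re.find_iter(text) {
        let (start, end) = (mat.start(), mat.end());
        if last_end < start {
            let gap = &text[last_end..start];
            if mat.as_str().chars().next().map_or(false, |c| c.is_whitespace()) {
                // Current match starts with whitespace: the whole gap is one piece.
                pieces.push(gap);
            } else {
                // Otherwise the gap's last char is its own piece, preceded by
                // the rest of the gap if there is any.
                let last_char_size = gap.chars().next_back().unwrap().len_utf8();
                if last_char_size < gap.len() {
                    pieces.push(&gap[..gap.len() - last_char_size]);
                }
                pieces.push(&gap[gap.len() - last_char_size..]);
            }
        }
        pieces.push(mat.as_str());
        last_end = end;
    }
    if last_end < text.len() {
        // Trailing whitespace gap at the end of the text.
        pieces.push(&text[last_end..]);
    }
    pieces
}

fn main() {
    // Whitespace alternatives removed from the pattern, as in this commit.
    let re = Regex::new(r"\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+").unwrap();
    // "=" and "6" are matches, " " is the gap and becomes its own piece.
    assert_eq!(split_with_gaps("= 6", &re), vec!["=", " ", "6"]);
}
```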
```diff
@@ -201,17 +263,10 @@ impl CoreBPE {
             }
             let end = next_special.map_or(text.len(), |m| m.start());

-            // Okay, here we go, compare this logic to _encode_ordinary_native
-            for mat in regex.find_iter(&text[start..end]) {
-                let piece = mat.unwrap().as_str().as_bytes();
-                if let Some(token) = self.encoder.get(piece) {
-                    last_piece_token_len = 1;
-                    ret.push(*token);
-                    continue;
-                }
-                let tokens = byte_pair_encode(piece, &self.encoder);
-                last_piece_token_len = tokens.len();
-                ret.extend(&tokens);
+            if end > start {
+                // The regex is not fetched here and passed in; _encode_ordinary_native_impl
+                // gets it itself, which seems harmless.
+                last_piece_token_len =
+                    self._encode_ordinary_native_impl(&text[start..end], &mut ret);
             }

             match next_special {
@@ -271,6 +326,13 @@ impl CoreBPE {
         (tokens, last_piece_token_len)
     }

+    fn _encode_ordinary_native(&self, text: &str) -> Vec<Rank> {
+        // This wrapper is needed for callers that do not pass a ret buffer.
+        let mut ret = vec![];
+        self._encode_ordinary_native_impl(text, &mut ret);
+        ret
+    }
+
     fn _encode_unstable_native(
         &self,
         text: &str,
```
```diff
@@ -302,7 +364,7 @@ impl CoreBPE {
         // Separating this from the loop below helps with performance in a common case.
         let mut point = self
             .sorted_token_bytes
-            .partition_point(|x| *x < unstable_bytes.as_slice());
+            .partition_point(|x| &x[..] < unstable_bytes.as_slice());
         while point < self.sorted_token_bytes.len()
             && self.sorted_token_bytes[point].starts_with(&unstable_bytes)
         {
@@ -318,9 +380,7 @@ impl CoreBPE {
         for i in 1..unstable_bytes.len() {
             let prefix = &unstable_bytes[..i];
             let suffix = &unstable_bytes[i..];
-            let mut point = self
-                .sorted_token_bytes
-                .partition_point(|x| *x < suffix);
+            let mut point = self.sorted_token_bytes.partition_point(|x| &x[..] < suffix);
             // TODO: Perf optimisation if suffix starts with " "?
             while point < self.sorted_token_bytes.len()
                 && self.sorted_token_bytes[point].starts_with(suffix)
```
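The `partition_point` calls above are a lower-bound binary search over the sorted token bytes, followed by a linear walk while entries still share the prefix; this is how candidate tokens extending the unstable bytes are enumerated. A self-contained sketch with made-up data (not the real vocabulary):

```rust
// Binary-search a sorted Vec<&[u8]> for the first entry >= the query,
// then walk forward while entries still start with the query.
fn tokens_with_prefix<'a>(sorted: &[&'a [u8]], prefix: &[u8]) -> Vec<&'a [u8]> {
    let mut point = sorted.partition_point(|x| &x[..] < prefix);
    let mut out = Vec::new();
    while point < sorted.len() && sorted[point].starts_with(prefix) {
        out.push(sorted[point]);
        point += 1;
    }
    out
}

fn main() {
    // Must be sorted, as sorted_token_bytes is in CoreBPE. Data is illustrative.
    let sorted: Vec<&[u8]> = vec![&b"ab"[..], &b"abc"[..], &b"abd"[..], &b"b"[..]];
    assert_eq!(
        tokens_with_prefix(&sorted, b"ab"),
        vec![&b"ab"[..], &b"abc"[..], &b"abd"[..]]
    );
}
```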
```diff
@@ -393,15 +453,15 @@ impl CoreBPE {
         encoder: HashMap<Vec<u8>, Rank>,
         special_tokens_encoder: HashMap<String, Rank>,
         pattern: &str,
-    ) -> Result<Self, fancy_regex::Error> {
+    ) -> Result<Self, regex::Error> {
         let regex = Regex::new(pattern)?;

         let special_regex = {
             let parts = special_tokens_encoder
                 .keys()
                 .map(|s| fancy_regex::escape(s))
                 .collect::<Vec<_>>();
-            Regex::new(&parts.join("|"))?
+            FancyRegex::new(&parts.join("|")).unwrap()
         };

         // Use unsafe to extend the lifetime of references to the encoder's keys
```
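The special-token matcher stays on fancy-regex. Since special tokens are literal strings, the pattern is just their escaped alternation. A small sketch with an assumed token set (not the real encoder's):

```rust
// Escape each literal special token and join them into one alternation,
// compiled with fancy-regex. Token set here is assumed for illustration.
fn main() -> Result<(), fancy_regex::Error> {
    let special_tokens = ["<|endoftext|>", "<|fim_prefix|>"];
    let parts: Vec<String> = special_tokens
        .iter()
        .map(|s| fancy_regex::escape(s).into_owned())
        .collect();
    let special_regex = fancy_regex::Regex::new(&parts.join("|"))?;
    assert!(special_regex.is_match("hello <|endoftext|>")?);
    Ok(())
}
```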

src/encoding.rs

Lines changed: 0 additions & 13 deletions
```diff
@@ -16,8 +16,6 @@ include!(concat!(env!("OUT_DIR"), "/odht_gen.rs"));
 pub struct Encoding {
     /// The name of the encoding.
     pub name: String,
-    /// The regular expression pattern used to split text into pieces.
-    pat_str: String,
     /// The maximum length of the keys in `mergeable_ranks`.
     mergeable_ranks_max_key_len: usize,
     /// All prefixes of the mergeable ranks. May or may not be tokens themselves!
@@ -117,7 +115,6 @@ impl Encoding {

         Ok(Self {
             name: name.to_string(),
-            pat_str: pat_str.to_string(),
             mergeable_ranks_max_key_len,
             prefixes_of_mergeable_ranks,
             special_tokens,
@@ -468,16 +465,6 @@ impl Encoding {
         self.core_bpe.encode_single_piece(text_or_bytes)
     }

-    /// Encodes a string into tokens, but do regex splitting in Rust.
-    fn _encode_only_native_bpe(&self, text: &str) -> Vec<Rank> {
-        let re = Regex::new(&self.pat_str).unwrap();
-        let mut ret = Vec::new();
-        for piece in re.find_iter(text) {
-            ret.extend(self.core_bpe.encode_single_piece(piece.as_str().as_bytes()));
-        }
-        ret
-    }
-
     /// Encodes bytes into tokens.
     fn _encode_bytes(&self, text: &[u8]) -> Vec<Rank> {
         self.core_bpe._encode_bytes(text)
```

src/openai_public.rs

Lines changed: 3 additions & 5 deletions
```diff
@@ -29,7 +29,7 @@ impl EncodingFactory {
     // The pattern in the original GPT-2 release is:
     // r"'s|'t|'re|'ve|'m|'ll|'d| ?[\p{L}]+| ?[\p{N}]+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
     // This is equivalent, but executes faster:
-    const LEGACY_SPLITTER_REGEX: &str = r"'(?:[sdmt]|ll|ve|re)| ?\p{L}++| ?\p{N}++| ?[^\s\p{L}\p{N}]++|\s++$|\s+(?!\S)|\s";
+    const LEGACY_SPLITTER_REGEX: &str = r"'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+";

     pub fn gpt2() -> Result<Encoding, EncodingFactoryError> {
         // todo!
@@ -114,7 +114,7 @@ impl EncodingFactory {
         special_tokens.shrink_to_fit();
         // use faster version from tiktoken upstream https://github.com/openai/tiktoken/pull/258/files#r1487668172
         // const PATTERN: &str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+";
-        const PATTERN: &str = r"'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}++|\p{N}{1,3}+| ?[^\s\p{L}\p{N}]++[\r\n]*+|\s++$|\s*[\r\n]|\s+(?!\S)|\s";
+        const PATTERN: &str = r"'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]";
         Encoding::new(
             "cl100k_base",
             PATTERN,
```
```diff
@@ -142,8 +142,6 @@ impl EncodingFactory {
             r"\p{N}{1,3}",
             r" ?[^\s\p{L}\p{N}]+[\r\n/]*",
             r"\s*[\r\n]+",
-            r"\s+(?!\S)",
-            r"\s+",
         ].join("|");

         Encoding::new("o200k_base", pat_str, mergeable_ranks, special_tokens, None)
@@ -204,7 +202,7 @@ impl EncodingFactory {
             special_tokens.into_iter().enumerate().map(|(i, token)| (token, (num_base_tokens + i) as Rank)).collect();
         special_tokens_map.shrink_to_fit();

-        let pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+";
+        let pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+";

         let vocab_size = num_base_tokens + special_tokens_map.len();
         Encoding::new(
```
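A quick way to see why the trailing whitespace alternatives had to be dropped from these patterns: the regex crate does not support look-around such as `\s+(?!\S)` (and likewise does not support the possessive quantifiers `++`/`?+` used in the old cl100k pattern), while fancy-regex accepts them. An illustrative check:

```rust
// Illustrative check (not part of the crate): the dropped look-ahead
// alternative only compiles under fancy-regex.
fn main() {
    let old_tail = r"\s*[\r\n]+|\s+(?!\S)|\s+"; // tail of the old patterns
    let new_tail = r"\s*[\r\n]+";               // tail kept by the new o200k pattern

    assert!(fancy_regex::Regex::new(old_tail).is_ok());
    assert!(regex::Regex::new(old_tail).is_err()); // look-around is unsupported
    assert!(regex::Regex::new(new_tail).is_ok());
}
```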
