
Commit 814ef8d

Fix example and add test reproducing it
1 parent: d9b2bee

File tree

3 files changed: +77 -13 lines

- crates/bpe/README.md
- crates/bpe/src/byte_pair_encoding.rs
- crates/bpe/src/lib.rs

crates/bpe/README.md

Lines changed: 9 additions & 9 deletions

@@ -94,35 +94,35 @@ Given a valid encoding sequence `e_0..e_i` and a valid encoding tuple `e_i e_j`,
 ## Novel Algorithm
 
 At a first glance, it seems impossible to achieve `O(n)` complexity while preserving the encoding output of the original BPE algorithm, since the original BPE algorithm needs to first scan the full input before it can make any encoding decision.
-For instance, the sequence `abab` would be encoded as `ab ab` when the dictionary contains the tokens `a b ab ba bc abc babc ababc` ordered by frequency. But appending a single character `ababc` would result in a pretty different tokenization: `ab a bc`. So without looking ahead it seems impossible to properly tokenize the text.
+For instance, the sequence `abac` would be encoded as `ab ac` when the dictionary contains the tokens `a b c ab cb ac` ordered by frequency. But appending a single character `abacb` would result in a pretty different tokenization: `ab a cb`. So without looking ahead it seems impossible to properly tokenize the text.
 
-The solution is to track the encodings of ALL text prefixes. For our example `ababc` we would get:
+The solution is to track the encodings of ALL text prefixes. For our example `abacb` we would get:
 
 - `a` ------> `a`
 - `ab` -----> `ab`
 - `aba` ----> `ab a`
-- `abab` ---> `ab ab`
-- `ababc` --> `ab a bc`
+- `abac` ---> `ab ac`
+- `abacb` --> `ab a cb`
 
 This can be done much more efficiently thanks to Corollary IIa, since now only the last token of every prefix has to be remembered:
 
 - `a` ------> `a`
 - `ab` -----> `ab`
 - `aba` ----> `a`
-- `abab` ---> `ab`
-- `ababc` --> `bc`
+- `abac` ---> `ac`
+- `abacb` --> `cb`
 
 In order to reconstruct the full encoding for a specific prefix, one simply starts with the last token of that prefix, shortens the prefix by the extracted token and looks up the token associated with the shortened prefix and so on until the beginning of the text is reached.
 
-For our example prefix `ababc`, this procedure executes the following steps and determines the correct encoding in reverse order:
+For our example prefix `abacb`, this procedure executes the following steps and determines the correct encoding in reverse order:
 
-- `ababc` -> `bc`
+- `abacb` -> `cb`
 - `aba` ---> `a`
 - `ab` ----> `ab`
 - `<empty>`
 
 The actual challenge is to determine for every prefix this last token efficiently.
-The prefix `abab` could for instance end with either the token `b` or `ab`, but only `ab` leads to a valid encoding sequence.
+The prefix `abac` could for instance end with either the token `c` or `ac`, but only `ac` leads to a valid encoding sequence.
 But, Corollary IIa tells us that **one and only one** last token can be the correct one and Corollary IIIa shows us how to find it:
 We only have to check whether a possible next token is "compatible" with its previous token, i.e. whether the two tokens form a valid encoding sequence.
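To make the reconstruction step concrete, here is a minimal standalone sketch; the `reconstruct` helper and the hard-coded last-token table are illustrative only and not part of the crate's API. It replays the backward walk for `abacb` using the last tokens `a, ab, a, ac, cb` listed above:

```rust
// Sketch only: given the last token of every prefix of `text`, walk backwards,
// repeatedly taking the last token of the current prefix and shortening the
// prefix by that token, until the beginning of the text is reached.
fn reconstruct(text: &str, last_token: &[&str]) -> Vec<String> {
    let mut tokens = Vec::new();
    let mut end = text.len();
    while end > 0 {
        let token = last_token[end - 1];
        tokens.push(token.to_string());
        end -= token.len();
    }
    tokens.reverse();
    tokens
}

fn main() {
    // Last tokens for the prefixes `a`, `ab`, `aba`, `abac`, `abacb`.
    let last_token = ["a", "ab", "a", "ac", "cb"];
    assert_eq!(reconstruct("abacb", &last_token), ["ab", "a", "cb"]);
}
```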

crates/bpe/src/byte_pair_encoding.rs

Lines changed: 7 additions & 4 deletions

@@ -176,12 +176,12 @@ pub fn find_hash_factor_for_tiktoken(bpe: &tiktoken_rs::CoreBPE, len: usize) ->
 /// Find a suitable hash factor for a set of given tokens that prevents collisions when
 /// constructing a [`BytePairEncoding`] from those tokens.
 #[cfg(feature = "rand")]
-pub fn find_hash_factor_for_dictionary(iter: impl Iterator<Item = Vec<u8>>) -> u64 {
+pub fn find_hash_factor_for_dictionary(tokens: impl IntoIterator<Item = Vec<u8>>) -> u64 {
     use std::collections::HashSet;
 
     use rand::Rng;
 
-    let all_tokens = iter.collect_vec();
+    let all_tokens = tokens.into_iter().collect_vec();
     let mut rnd = rand::thread_rng();
     loop {
         let factor: u64 = rnd.gen();
@@ -244,15 +244,18 @@ impl BytePairEncoding {
     ///
     /// The recommended approach is to store the serialized value and reuse that,
     /// to prevent repeating the cost of computing the hash factor and encoding.
-    pub fn from_dictionary(iter: impl Iterator<Item = Vec<u8>>, hash_factor: Option<u64>) -> Self {
+    pub fn from_dictionary(
+        tokens: impl IntoIterator<Item = Vec<u8>>,
+        hash_factor: Option<u64>,
+    ) -> Self {
         let hash_factor = hash_factor
             .inspect(|f| assert_ne!(*f, 0, "hash factor must be larger than zero"))
             .unwrap_or(1);
         let mut all_tokens = Vec::new();
         let mut all_tokens_rev = Vec::new();
        let mut token_starts = vec![0];
         let mut bytes_hash_to_token = FnvHashMap::default();
-        for (i, token) in iter.enumerate() {
+        for (i, token) in tokens.into_iter().enumerate() {
             bytes_hash_to_token.insert(hash_bytes(&token, hash_factor), i as u32);
             all_tokens_rev.extend(token.iter().copied().rev());
             all_tokens.extend(token);
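The signature change relaxes `Iterator` to `IntoIterator`, so a collection of tokens can be passed directly at the call site, as the new test below does. A hedged usage sketch (assuming the crate is imported as `bpe` and built with the `rand` feature for the hash-factor helper):

```rust
use bpe::byte_pair_encoding::{find_hash_factor_for_dictionary, BytePairEncoding};

fn build_example_bpe() -> BytePairEncoding {
    // An array of byte vectors can now be passed as-is; no explicit
    // `.into_iter()` is needed at the call site.
    let tokens = ["a", "b", "c", "ab", "cb", "ac"].map(|t| t.as_bytes().to_vec());
    // Optionally compute a collision-free hash factor up front (requires the
    // `rand` feature); passing `None` instead falls back to the default factor of 1.
    let hash_factor = find_hash_factor_for_dictionary(tokens.clone());
    BytePairEncoding::from_dictionary(tokens, Some(hash_factor))
}
```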

crates/bpe/src/lib.rs

Lines changed: 61 additions & 0 deletions

@@ -4,3 +4,64 @@ mod bitfield;
 pub mod byte_pair_encoding;
 pub mod interval_encoding;
 pub mod prependable_encoder;
+
+#[cfg(test)]
+mod tests {
+    use itertools::Itertools;
+
+    use crate::byte_pair_encoding::BytePairEncoding;
+
+    /// This test produces the output for the encoding example in the README.
+    #[test]
+    fn readme_example() {
+        let tokens = ["a", "b", "c", "ab", "cb", "ac"].map(|t| t.as_bytes().to_vec());
+        let bpe = BytePairEncoding::from_dictionary(tokens, None);
+        let text = "abacb";
+        let prefixes = (1..=text.len()).map(|end| &text[..end]).collect_vec();
+        let all_prefix_tokens = prefixes
+            .iter()
+            .map(|prefix| {
+                bpe.encode_via_backtracking(prefix.as_bytes())
+                    .into_iter()
+                    .map(|t| unsafe { String::from_utf8_unchecked(bpe.decode_tokens(&[t])) })
+                    .collect_vec()
+            })
+            .collect_vec();
+        let last_prefix_tokens = all_prefix_tokens
+            .iter()
+            .map(|tokens| tokens.last().unwrap())
+            .collect_vec();
+
+        println!("All tokens for each prefix of `{text}`:\n");
+        for (prefix, tokens) in prefixes.iter().zip(&all_prefix_tokens) {
+            println!(
+                "- `{prefix}` {}> `{}`",
+                "-".repeat(text.len() + 2 - prefix.len()),
+                tokens.join(" ")
+            );
+        }
+        println!();
+
+        println!("Last token for each prefix of `{text}`:\n");
+        for (prefix, token) in prefixes.iter().zip(&last_prefix_tokens) {
+            println!(
+                "- `{prefix}` {}> `{token}`",
+                "-".repeat(text.len() + 2 - prefix.len()),
+            );
+        }
+        println!();
+
+        println!("Tokenization of `{text}`:\n");
+        let mut remaining = text.len();
+        while remaining > 0 {
+            let prefix = &text[..remaining];
+            let token = last_prefix_tokens[remaining - 1];
+            println!(
+                "- `{prefix}` {}> `{token}`",
+                "-".repeat(text.len() + 2 - prefix.len()),
+            );
+            remaining -= token.len();
+        }
+        println!("- `<empty>`");
+    }
+}
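Since `readme_example` only prints (there are no assertions beyond the panics an `unwrap` would trigger), its output can be inspected with something like `cargo test readme_example -- --nocapture`. The three printed lists correspond to the README sections above: the full encoding of every prefix of `abacb`, the last token of every prefix, and the backward reconstruction of the final tokenization `ab a cb`.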
