Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 123 additions & 44 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,60 +15,129 @@ mod py;
pub type Rank = u32;

/// Greedily applies byte-pair merges to `piece` and returns the resulting part boundaries.
///
/// The result is a vector of `(start, rank)` entries, one per surviving boundary, in increasing
/// order of `start`. It always begins at offset `0` and ends at `piece.len()`, so consecutive
/// entries delimit the merged parts. Every returned rank is `Rank::MAX`: merging only stops once
/// no adjacent pair of parts is present in `ranks`.
///
/// At each step the adjacent pair with the lowest rank is merged; ties are broken in favour of
/// the leftmost pair. A rank equal to `Rank::MAX` is treated as "not mergeable".
///
/// Inputs of length 0 or 1 cannot be merged and yield just their boundaries.
///
/// Candidates live in a binary heap, so with `n` bytes and `m` merges the work is roughly
/// `O((n + m) log n)` instead of rescanning every part after each merge.
fn _byte_pair_merge(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> {
    use std::cmp::Reverse;
    use std::collections::BinaryHeap;

    /// One boundary position in `piece`, linked to its surviving neighbours.
    #[derive(Clone, Copy)]
    struct Node {
        prev: Option<usize>,
        next: Option<usize>,
        alive: bool,
    }

    /// Rank of the pair of parts that starts at boundary `i`, or `Rank::MAX` when fewer than two
    /// parts follow `i` or the pair's bytes are not in `ranks`.
    ///
    /// Note that we hash bytes when indexing into `ranks`, not token pairs. As long as BPE is
    /// trained the way it currently is, this is equivalent. An easy way to break this would be to
    /// decouple merge priority from token index or to prevent specific token merges.
    fn pair_rank(
        ranks: &HashMap<Vec<u8>, Rank>,
        piece: &[u8],
        nodes: &[Node],
        i: usize,
    ) -> Rank {
        nodes[i]
            .next
            .and_then(|mid| nodes[mid].next)
            .and_then(|end| ranks.get(&piece[i..end]).copied())
            .unwrap_or(Rank::MAX)
    }

    let n_bytes = piece.len();
    if n_bytes < 2 {
        // Nothing to merge: return the boundaries of the (empty or single-byte) piece.
        return (0..=n_bytes).map(|i| (i, Rank::MAX)).collect();
    }

    // One node per boundary 0..=n_bytes; node `n_bytes` is the end sentinel and is never removed.
    let mut nodes: Vec<Node> = (0..=n_bytes)
        .map(|i| Node {
            prev: i.checked_sub(1),
            next: if i < n_bytes { Some(i + 1) } else { None },
            alive: true,
        })
        .collect();

    // Per-boundary version counter; a heap entry is stale once its recorded version differs.
    let mut ver: Vec<u32> = vec![0; nodes.len()];

    // Min-heap over (rank, left boundary, version): `Reverse` makes the lowest rank pop first
    // and, among equal ranks, the leftmost pair.
    let mut heap: BinaryHeap<Reverse<(Rank, usize, u32)>> = piece
        .windows(2)
        .enumerate()
        .filter_map(|(i, pair)| match ranks.get(pair) {
            Some(&rank) if rank != Rank::MAX => Some(Reverse((rank, i, 0u32))),
            _ => None,
        })
        .collect();

    while let Some(Reverse((_, left, seen_ver))) = heap.pop() {
        // Skip candidates whose left boundary was merged away or whose pair has changed since
        // the candidate was pushed.
        if !nodes[left].alive || ver[left] != seen_ver {
            continue;
        }
        let (mid, end) = match nodes[left]
            .next
            .and_then(|mid| nodes[mid].next.map(|end| (mid, end)))
        {
            Some(found) => found,
            None => continue,
        };

        // Merge the two parts by unlinking the boundary between them.
        nodes[left].next = Some(end);
        nodes[end].prev = Some(left);
        nodes[mid].alive = false;

        // Only the pairs starting at `left` and at its predecessor changed: invalidate their
        // old candidates and enqueue fresh ones where a merge is still possible.
        let before = nodes[left].prev;
        for idx in before.into_iter().chain(std::iter::once(left)) {
            ver[idx] = ver[idx].wrapping_add(1);
            let rank = pair_rank(ranks, piece, &nodes, idx);
            if rank != Rank::MAX {
                heap.push(Reverse((rank, idx, ver[idx])));
            }
        }
    }

    // Walk the surviving boundaries from the head; the walk always ends at the sentinel
    // `n_bytes`. The heap is exhausted, so no remaining pair has a rank below `Rank::MAX`.
    let mut parts: Vec<(usize, Rank)> = Vec::new();
    let mut cursor = Some(0usize);
    while let Some(idx) = cursor {
        parts.push((idx, Rank::MAX));
        cursor = nodes[idx].next;
    }
    parts
}

Expand Down Expand Up @@ -571,4 +640,14 @@ mod tests {
let res = byte_pair_split(b"abab", &ranks);
assert_eq!(res, vec![b"ab", b"ab"]);
}

#[test]
fn test__byte_pair_merge_boundaries() {
    // Merge the four-byte piece using the shared test ranks.
    let ranks = setup_ranks();
    let parts = super::_byte_pair_merge(&ranks, b"abcd");

    // The surviving boundaries must split the piece into two two-byte parts.
    let positions: Vec<usize> = parts.iter().map(|&(start, _)| start).collect();
    assert_eq!(positions, vec![0, 2, 4]);

    // The final entry is the end boundary and must carry the "no merge" rank.
    let last_rank = parts.last().map(|&(_, rank)| rank);
    assert_eq!(last_rank, Some(Rank::MAX));
}
}