use io::statistics::Instances;

use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap};
use std::hash::Hash;

/// A newtype for `u8` used to count the length of a key in bits.
#[derive(
    Debug,
    Default,
    Display,
    Serialize,
    Deserialize,
    From,
    Into,
    Add,
    AddAssign,
    Sub,
    SubAssign,
    Clone,
    Copy,
    PartialOrd,
    Ord,
    PartialEq,
    Eq,
)]
pub struct BitLen(u8);

/// Convenience implementation of operator `<<` in
/// `bits << bit_len`.
impl std::ops::Shl<BitLen> for u32 {
    type Output = u32;
    fn shl(self, rhs: BitLen) -> u32 {
        self << Into::<u8>::into(rhs)
    }
}
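
// A small illustrative check of the convenience `<<` above: shifting a `u32`
// by a `BitLen` behaves like shifting by the underlying `u8`.
#[test]
fn test_shl_bit_len() {
    assert_eq!(1u32 << BitLen(3), 0b1000);
    assert_eq!(0b11u32 << BitLen(0), 0b11);
}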

/// The largest acceptable length for a key, in bits.
///
/// Hardcoded in the format.
const MAX_CODE_BIT_LENGTH: u8 = 20;

/// A Huffman key.
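///
/// For example, a key with `bits == 0b10` and `bit_len == BitLen(2)`
/// represents the two-bit sequence `10`; any higher-weight bits of `bits`
/// are ignored.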
#[derive(Debug)]
struct Key {
    /// The bits in the key.
    ///
    /// Note that we only use the `bit_len` lowest-weight bits.
    /// Any other bit is ignored.
    bits: u32,

    /// The number of bits of `bits` to use.
    bit_len: BitLen,
}

/// A node in the Huffman tree.
struct Node<T> {
    /// The total number of instances of all `NodeContent::Leaf(T)` in this subtree.
    instances: Instances,

    /// The content of the node.
    content: NodeContent<T>,
}

/// Contents of a node in the Huffman tree.
enum NodeContent<T> {
    /// A value from the stream of values.
    Leaf(T),

    /// An internal node obtained by joining two subtrees.
    Internal {
        left: Box<NodeContent<T>>,
        right: Box<NodeContent<T>>,
    },
}

/// Custom ordering of `Node`.
///
/// We compare *only* by number of instances.
impl<T> PartialOrd for Node<T> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        self.instances.partial_cmp(&other.instances)
    }
}
impl<T> Ord for Node<T> {
    fn cmp(&self, other: &Self) -> Ordering {
        self.instances.cmp(&other.instances)
    }
}
impl<T> PartialEq for Node<T> {
    fn eq(&self, other: &Self) -> bool {
        self.instances.eq(&other.instances)
    }
}
impl<T> Eq for Node<T> {}
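
// An illustrative sketch of the ordering above: wrapped in `std::cmp::Reverse`,
// the node with the fewest instances is the first to leave a `BinaryHeap`,
// which is what `compute_bit_lengths` relies on below.
#[test]
fn test_node_ordering_by_instances() {
    use std::cmp::Reverse;
    let mut heap = BinaryHeap::new();
    heap.push(Reverse(Node {
        instances: 3.into(),
        content: NodeContent::Leaf('a'),
    }));
    heap.push(Reverse(Node {
        instances: 1.into(),
        content: NodeContent::Leaf('b'),
    }));
    // The rarest node comes out first.
    let rarest = heap.pop().unwrap();
    assert!(rarest.0.instances == 1.into());
}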

/// Keys associated with a sequence of values.
#[derive(Debug)]
pub struct Keys<T>
where
    T: Ord + Clone,
{
    /// The sequence of keys.
    ///
    /// Order is meaningful.
    keys: Vec<(T, Key)>,
}

impl<T> Keys<T>
where
    T: Ord + Clone,
{
    /// Compute a `Keys` from a sequence of values.
    ///
    /// `max_bit_len` specifies the largest acceptable bit length.
    /// If the `Keys` cannot be computed without exceeding this bit length,
    /// fail with `Err(problematic_bit_length)`.
    ///
    /// The current implementation only attempts to produce the best compression
    /// level. This may cause us to exceed `max_bit_len` even though an
    /// alternative table, with a lower compression level, would let us
    /// proceed without exceeding `max_bit_len`.
    ///
    /// # Performance
    ///
    /// Values (type `T`) will be cloned regularly, so you should make
    /// sure that their cloning is reasonably cheap.
    pub fn from_sequence<S>(source: S, max_bit_len: u8) -> Result<Self, u8>
    where
        S: IntoIterator<Item = T>,
        T: PartialEq + Hash,
    {
        // Count the values.
        let mut map = HashMap::new();
        for item in source {
            let counter = map.entry(item).or_insert(0.into());
            *counter += 1.into();
        }
        // Then compute the `Keys`.
        Self::from_instances(map, max_bit_len)
    }

    /// Compute a `Keys` from a sequence of values
    /// with a number of instances already attached.
    ///
    /// The current implementation only attempts to produce the best compression
    /// level. This may cause us to exceed `max_bit_len` even though an
    /// alternative table, with a lower compression level, would let us
    /// proceed without exceeding `max_bit_len`.
    ///
    /// # Requirement
    ///
    /// Values of `T` in the source MUST be distinct.
    pub fn from_instances<S>(source: S, max_bit_len: u8) -> Result<Self, u8>
    where
        S: IntoIterator<Item = (T, Instances)>,
    {
        let mut bit_lengths = Self::compute_bit_lengths(source, max_bit_len)?;
        if bit_lengths.is_empty() {
            // No values at all: nothing to encode, hence no keys.
            return Ok(Self { keys: vec![] });
        }

        // Canonicalize order: sort by (BitLen, T).
        // As values of `T` are distinct, this ordering is total, so the
        // resulting table is deterministic.
        bit_lengths.sort_unstable_by_key(|&(ref value, ref bit_len)| (*bit_len, value.clone()));

        // The bits associated to the next value.
        let mut bits = 0;
        let mut keys = Vec::with_capacity(bit_lengths.len());

        for i in 0..bit_lengths.len() - 1 {
            let (bit_len, symbol, next_bit_len) = (
                bit_lengths[i].1,
                bit_lengths[i].0.clone(),
                bit_lengths[i + 1].1,
            );
            keys.push((symbol, Key { bits, bit_len }));
            bits = (bits + 1) << (next_bit_len - bit_len);
        }
        // Handle the last element.
        let (ref symbol, bit_len) = bit_lengths[bit_lengths.len() - 1];
        keys.push((symbol.clone(), Key { bits, bit_len }));

        Ok(Self { keys })
    }

    /// Convert a sequence of values labelled by their number of instances
    /// into a sequence of values labelled by the length of their path
    /// in the Huffman tree, i.e. the bit length of their Huffman key.
    ///
    /// Values that have 0 instances are skipped.
    pub fn compute_bit_lengths<S>(source: S, max_bit_len: u8) -> Result<Vec<(T, BitLen)>, u8>
    where
        S: IntoIterator<Item = (T, Instances)>,
    {
        // Build a min-heap sorted by number of instances.
        use std::cmp::Reverse;
        let mut heap = BinaryHeap::new();

        // Skip values that have 0 instances.
        for (value, instances) in source {
            if !instances.is_zero() {
                heap.push(Reverse(Node {
                    instances,
                    content: NodeContent::Leaf(value),
                }));
            }
        }

        let len = heap.len();
        if len == 0 {
            // Special case: no tree to build.
            return Ok(vec![]);
        }

        // Take the two rarest nodes, merge them behind a prefix,
        // turn them into a single node with the combined number of
        // instances. Repeat until a single root remains.
        while heap.len() > 1 {
            let left = heap.pop().unwrap();
            let right = heap.pop().unwrap();
            heap.push(Reverse(Node {
                instances: left.0.instances + right.0.instances,
                content: NodeContent::Internal {
                    left: Box::new(left.0.content),
                    right: Box::new(right.0.content),
                },
            }));
        }

        // Convert the tree into bit lengths.
        let root = heap.pop().unwrap(); // We have checked above that there is at least one value.
        let mut bit_lengths = Vec::with_capacity(len);
        // Walk the tree, recording the depth (= bit length) of each leaf.
        // Fail with `Err(depth)` if a leaf ends up deeper than `max_bit_len`.
        fn aux<T>(
            bit_lengths: &mut Vec<(T, BitLen)>,
            max_bit_len: u8,
            depth: u8,
            node: &NodeContent<T>,
        ) -> Result<(), u8>
        where
            T: Clone,
        {
            match *node {
                NodeContent::Leaf(ref value) => {
                    if depth > max_bit_len {
                        return Err(depth);
                    }
                    bit_lengths.push((value.clone(), BitLen(depth)));
                    Ok(())
                }
                NodeContent::Internal {
                    ref left,
                    ref right,
                } => {
                    aux(bit_lengths, max_bit_len, depth + 1, left)?;
                    aux(bit_lengths, max_bit_len, depth + 1, right)?;
                    Ok(())
                }
            }
        }
        aux(&mut bit_lengths, max_bit_len, 0, &root.0.content)?;

        Ok(bit_lengths)
    }
}
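
// An illustrative sketch of `compute_bit_lengths` on its own: values with
// 0 instances are dropped and the rarest values receive the longest bit
// lengths. As elsewhere in this file, `Instances` values are built with
// `.into()` from integer literals.
#[test]
fn test_compute_bit_lengths() {
    let source = vec![
        ('a', 5.into()),
        ('b', 1.into()),
        ('c', 1.into()),
        ('d', 0.into()),
    ];
    let mut lengths = Keys::compute_bit_lengths(source, std::u8::MAX).unwrap();
    // The traversal order of the tree is not specified, so sort by value
    // before checking.
    lengths.sort_unstable_by_key(|&(value, _)| value);

    // 'd' has 0 instances, so it is skipped.
    assert_eq!(lengths.len(), 3);
    // 'a' is the most frequent value, so it gets the shortest path.
    assert_eq!(lengths[0], ('a', BitLen(1)));
    assert_eq!(lengths[1], ('b', BitLen(2)));
    assert_eq!(lengths[2], ('c', BitLen(2)));
}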

#[test]
fn test_coded_from_sequence() {
    let sample = "appl";
    let coded = Keys::from_sequence(sample.chars(), std::u8::MAX).unwrap();

    // Symbol 'p' appears twice, so "appl" contains only 3 distinct symbols,
    // hence 3 keys.
    assert_eq!(coded.keys.len(), 3);

    // Check order of symbols.
    assert_eq!(coded.keys[0].0, 'p');
    assert_eq!(coded.keys[1].0, 'a');
    assert_eq!(coded.keys[2].0, 'l');

    // Check bit length of symbols.
    assert_eq!(coded.keys[0].1.bit_len, 1.into());
    assert_eq!(coded.keys[1].1.bit_len, 2.into());
    assert_eq!(coded.keys[2].1.bit_len, 2.into());

    // Check code of symbols.
    assert_eq!(coded.keys[0].1.bits, 0b00);
    assert_eq!(coded.keys[1].1.bits, 0b10);
    assert_eq!(coded.keys[2].1.bits, 0b11);

    // Let's try again with a limit of 1-bit paths.
    assert_eq!(Keys::from_sequence(sample.chars(), 1).unwrap_err(), 2);
}
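
// A companion sketch exercising `from_instances` directly, with the same
// distribution of instances as "appl" above ('p' twice, 'a' and 'l' once),
// so we expect the same canonical keys. `Instances` values are built with
// `.into()`, as in `from_sequence`.
#[test]
fn test_coded_from_instances() {
    let source = vec![('a', 1.into()), ('p', 2.into()), ('l', 1.into())];
    let coded = Keys::from_instances(source, std::u8::MAX).unwrap();

    assert_eq!(coded.keys.len(), 3);

    // Same canonical order and codes as in `test_coded_from_sequence`.
    assert_eq!(coded.keys[0].0, 'p');
    assert_eq!(coded.keys[0].1.bits, 0b0);
    assert_eq!(coded.keys[0].1.bit_len, 1.into());

    assert_eq!(coded.keys[1].0, 'a');
    assert_eq!(coded.keys[1].1.bits, 0b10);
    assert_eq!(coded.keys[1].1.bit_len, 2.into());

    assert_eq!(coded.keys[2].0, 'l');
    assert_eq!(coded.keys[2].1.bits, 0b11);
    assert_eq!(coded.keys[2].1.bit_len, 2.into());
}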