use alloc::vec::Vec; use core::cmp::Ordering; use crate::{ bit_io::BitWriter, fse::fse_encoder::{self, FSEEncoder}, }; pub(crate) struct HuffmanEncoder<'output, 'table, V: AsMut>> { table: &'table HuffmanTable, writer: &'output mut BitWriter, } impl>> HuffmanEncoder<'_, '_, V> { pub fn new<'o, 't>( table: &'t HuffmanTable, writer: &'o mut BitWriter, ) -> HuffmanEncoder<'o, 't, V> { HuffmanEncoder { table, writer } } /// Encodes the data using the provided table /// Writes /// * Table description /// * Encoded data /// * Padding bits to fill up last byte pub fn encode(&mut self, data: &[u8], with_table: bool) { if with_table { self.write_table(); } Self::encode_stream(self.table, self.writer, data); } /// Encodes the data using the provided table in 4 concatenated streams /// Writes /// * Table description /// * Jumptable /// * Encoded data in 4 streams, each padded to fill the last byte pub fn encode4x(&mut self, data: &[u8], with_table: bool) { assert!(data.len() >= 4); // Split data in 4 equally sized parts (the last one might be a bit smaller than the rest) let split_size = data.len().div_ceil(4); let src1 = &data[..split_size]; let src2 = &data[split_size..split_size * 2]; let src3 = &data[split_size * 2..split_size * 3]; let src4 = &data[split_size * 3..]; // Write table description if with_table { self.write_table(); } // Reserve space for the jump table, will be changed later let size_idx = self.writer.index(); self.writer.write_bits(0u16, 16); self.writer.write_bits(0u16, 16); self.writer.write_bits(0u16, 16); // Write the 4 streams, noting the sizes of the encoded streams let index_before = self.writer.index(); Self::encode_stream(self.table, self.writer, src1); let size1 = (self.writer.index() - index_before) / 8; let index_before = self.writer.index(); Self::encode_stream(self.table, self.writer, src2); let size2 = (self.writer.index() - index_before) / 8; let index_before = self.writer.index(); Self::encode_stream(self.table, self.writer, src3); let size3 = (self.writer.index() - index_before) / 8; Self::encode_stream(self.table, self.writer, src4); // Sanity check, if this doesn't hold we produce a broken stream assert!(size1 <= u16::MAX as usize); assert!(size2 <= u16::MAX as usize); assert!(size3 <= u16::MAX as usize); // Update the jumptable with the real sizes self.writer.change_bits(size_idx, size1 as u16, 16); self.writer.change_bits(size_idx + 16, size2 as u16, 16); self.writer.change_bits(size_idx + 32, size3 as u16, 16); } /// Encode one stream and pad it to fill the last byte fn encode_stream>>( table: &HuffmanTable, writer: &mut BitWriter, data: &[u8], ) { for symbol in data.iter().rev() { let (code, num_bits) = table.codes[*symbol as usize]; debug_assert!(num_bits > 0); writer.write_bits(code, num_bits as usize); } let bits_to_fill = writer.misaligned(); if bits_to_fill == 0 { writer.write_bits(1u32, 8); } else { writer.write_bits(1u32, bits_to_fill); } } pub(super) fn weights(&self) -> Vec { let max = self.table.codes.iter().map(|(_, nb)| nb).max().unwrap(); let weights = self .table .codes .iter() .copied() .map(|(_, nb)| if nb == 0 { 0 } else { max - nb + 1 }) .collect::>(); weights } fn write_table(&mut self) { // TODO strategy for determining this? let weights = self.weights(); let weights = &weights[..weights.len() - 1]; // dont encode last weight if weights.len() > 16 { let size_idx = self.writer.index(); self.writer.write_bits(0u8, 8); let idx_before = self.writer.index(); let mut encoder = FSEEncoder::new( fse_encoder::build_table_from_data(weights.iter().copied(), 6, true), self.writer, ); encoder.encode_interleaved(weights); let encoded_len = (self.writer.index() - idx_before) / 8; assert!(encoded_len < 128); self.writer.change_bits(size_idx, encoded_len as u8, 8); } else { self.writer.write_bits(weights.len() as u8 + 127, 8); let pairs = weights.chunks_exact(2); let remainder = pairs.remainder(); for pair in pairs.into_iter() { let weight1 = pair[0]; let weight2 = pair[1]; assert!(weight1 < 16); assert!(weight2 < 16); self.writer.write_bits(weight2, 4); self.writer.write_bits(weight1, 4); } if !remainder.is_empty() { let weight = remainder[0]; assert!(weight < 16); self.writer.write_bits(weight << 4, 8); } } } } pub struct HuffmanTable { /// Index is the symbol, values are the bitstring in the lower bits of the u32 and the amount of bits in the u8 codes: Vec<(u32, u8)>, } impl HuffmanTable { pub fn build_from_data(data: &[u8]) -> Self { let mut counts = [0; 256]; let mut max = 0; for x in data { counts[*x as usize] += 1; max = max.max(*x); } Self::build_from_counts(&counts[..=max as usize]) } pub fn build_from_counts(counts: &[usize]) -> Self { assert!(counts.len() <= 256); let zeros = counts.iter().filter(|x| **x == 0).count(); let mut weights = distribute_weights(counts.len() - zeros); let limit = weights.len().ilog2() as usize + 2; redistribute_weights(&mut weights, limit); weights.reverse(); let mut counts_sorted = counts.iter().enumerate().collect::>(); counts_sorted.sort_by(|(_, c1), (_, c2)| c1.cmp(c2)); let mut weights_distributed = alloc::vec![0; counts.len()]; for (idx, count) in counts_sorted { if *count == 0 { weights_distributed[idx] = 0; } else { weights_distributed[idx] = weights.pop().unwrap(); } } Self::build_from_weights(&weights_distributed) } pub fn build_from_weights(weights: &[usize]) -> Self { let mut sorted = Vec::with_capacity(weights.len()); struct SortEntry { symbol: u8, weight: usize, } // TODO this doesn't need to be a temporary Vec, it could be done in a [_; 264] // only non-zero weights are interesting here for (symbol, weight) in weights.iter().copied().enumerate() { if weight > 0 { sorted.push(SortEntry { symbol: symbol as u8, weight, }); } } // We process symbols ordered by weight and then ordered by symbol sorted.sort_by(|left, right| match left.weight.cmp(&right.weight) { Ordering::Equal => left.symbol.cmp(&right.symbol), other => other, }); // Prepare huffman table with placeholders let mut table = HuffmanTable { codes: Vec::with_capacity(weights.len()), }; for _ in 0..weights.len() { table.codes.push((0, 0)); } // Determine the number of bits needed for codes with the lowest weight let weight_sum = sorted.iter().map(|e| 1 << (e.weight - 1)).sum::(); if !weight_sum.is_power_of_two() { panic!("This is an internal error"); } let max_num_bits = highest_bit_set(weight_sum) - 1; // this is a log_2 of a clean power of two // Starting at the symbols with the lowest weight we update the placeholders in the table let mut current_code = 0; let mut current_weight = 0; let mut current_num_bits = 0; for entry in sorted.iter() { // If the entry isn't the same weight as the last one we need to change a few things if current_weight != entry.weight { // The code shifts by the difference of the weights to allow for enough unique values current_code >>= entry.weight - current_weight; // Encoding a symbol of this weight will take less bits than the previous weight current_num_bits = max_num_bits - entry.weight + 1; // Run the next update when the weight changes again current_weight = entry.weight; } table.codes[entry.symbol as usize] = (current_code as u32, current_num_bits as u8); current_code += 1; } table } pub fn can_encode(&self, other: &Self) -> Option { if other.codes.len() > self.codes.len() { return None; } let mut sum = 0; for ((_, other_num_bits), (_, self_num_bits)) in other.codes.iter().zip(self.codes.iter()) { if *other_num_bits != 0 && *self_num_bits == 0 { return None; } sum += other_num_bits.abs_diff(*self_num_bits) as usize; } Some(sum) } } /// Assert that the provided value is greater than zero, and returns index of the first set bit fn highest_bit_set(x: usize) -> usize { assert!(x > 0); usize::BITS as usize - x.leading_zeros() as usize } #[test] fn huffman() { let table = HuffmanTable::build_from_weights(&[2, 2, 2, 1, 1]); assert_eq!(table.codes[0], (1, 2)); assert_eq!(table.codes[1], (2, 2)); assert_eq!(table.codes[2], (3, 2)); assert_eq!(table.codes[3], (0, 3)); assert_eq!(table.codes[4], (1, 3)); let table = HuffmanTable::build_from_weights(&[4, 3, 2, 0, 1, 1]); assert_eq!(table.codes[0], (1, 1)); assert_eq!(table.codes[1], (1, 2)); assert_eq!(table.codes[2], (1, 3)); assert_eq!(table.codes[3], (0, 0)); assert_eq!(table.codes[4], (0, 4)); assert_eq!(table.codes[5], (1, 4)); } /// Distributes weights that add up to a clean power of two fn distribute_weights(amount: usize) -> Vec { assert!(amount >= 2); assert!(amount <= 256); let mut weights = Vec::new(); // This is the trivial power of two we always need weights.push(1); weights.push(1); // This is the weight we are adding right now let mut target_weight = 1; // Counts how many times we have added weights let mut weight_counter = 2; // We always add a power of 2 new weights so that the weights that we add equal // the weights are already in the vec if raised to the power of two. // This means we double the weights in the vec -> results in a new power of two // // Example: [1, 1] -> [1,1,2] (2^1 + 2^1 == 2^2) // // Example: [1, 1] -> [1,1,1,1] (2^1 + 2^1 == 2^1 + 2^1) // [1,1,1,1] -> [1,1,1,1,3] (2^1 + 2^1 + 2^1 + 2^1 == 2^3) while weights.len() < amount { let mut add_new = 1 << (weight_counter - target_weight); let available_space = amount - weights.len(); // If the amount of new weights needed to get to the next power of two would exceed amount // We instead add 1 of a bigger weight and start the cycle again if add_new > available_space { // TODO we could maybe instead do this until add_new <= available_space? // target_weight += 1 // add_new /= 2 target_weight = weight_counter; add_new = 1; } for _ in 0..add_new { weights.push(target_weight); } weight_counter += 1; } assert_eq!(amount, weights.len()); weights } /// Sometimes distribute_weights generates weights that require too many bits to encode /// This redistributes the weights to have less variance by raising the lower weights while still maintaining the /// required attributes of the weight distribution fn redistribute_weights(weights: &mut [usize], max_num_bits: usize) { let weight_sum_log = weights .iter() .copied() .map(|x| 1 << x) .sum::() .ilog2() as usize; // Nothing needs to be done, this is already fine if weight_sum_log < max_num_bits { return; } // We need to decrease the weight difference by the difference between weight_sum_log and max_num_bits let decrease_weights_by = weight_sum_log - max_num_bits + 1; // To do that we raise the lower weights up by that difference, recording how much weight we added in the process let mut added_weights = 0; for weight in weights.iter_mut() { if *weight < decrease_weights_by { for add in *weight..decrease_weights_by { added_weights += 1 << add; } *weight = decrease_weights_by; } } // Then we reduce weights until the added weights are equaled out while added_weights > 0 { // Find the highest weight that is still lower or equal to the added weight let mut current_idx = 0; let mut current_weight = 0; for (idx, weight) in weights.iter().copied().enumerate() { if 1 << (weight - 1) > added_weights { break; } if weight > current_weight { current_weight = weight; current_idx = idx; } } // Reduce that weight by 1 added_weights -= 1 << (current_weight - 1); weights[current_idx] -= 1; } // At the end we normalize the weights so that they start at 1 again if weights[0] > 1 { let offset = weights[0] - 1; for weight in weights.iter_mut() { *weight -= offset; } } } #[test] fn weights() { // assert_eq!(distribute_weights(5).as_slice(), &[1, 1, 2, 3, 4]); for amount in 2..=256 { let mut weights = distribute_weights(amount); assert_eq!(weights.len(), amount); let sum = weights .iter() .copied() .map(|weight| 1 << weight) .sum::(); assert!(sum.is_power_of_two()); for num_bit_limit in (amount.ilog2() as usize + 1)..=11 { redistribute_weights(&mut weights, num_bit_limit); let sum = weights .iter() .copied() .map(|weight| 1 << weight) .sum::(); assert!(sum.is_power_of_two()); assert!( sum.ilog2() <= 11, "Max bits too big: sum: {} {weights:?}", sum ); let codes = HuffmanTable::build_from_weights(&weights).codes; for (code, num_bits) in codes.iter().copied() { for (code2, num_bits2) in codes.iter().copied() { if num_bits == 0 || num_bits2 == 0 || (code, num_bits) == (code2, num_bits2) { continue; } if num_bits <= num_bits2 { let code2_shifted = code2 >> (num_bits2 - num_bits); assert_ne!( code, code2_shifted, "{:b},{num_bits:} is prefix of {:b},{num_bits2:}", code, code2 ); } } } } } } #[test] fn counts() { let counts = &[3, 0, 4, 1, 5]; let table = HuffmanTable::build_from_counts(counts).codes; assert_eq!(table[1].1, 0); assert!(table[3].1 >= table[0].1); assert!(table[0].1 >= table[2].1); assert!(table[2].1 >= table[4].1); let counts = &[3, 0, 4, 0, 7, 2, 2, 2, 0, 2, 2, 1, 5]; let table = HuffmanTable::build_from_counts(counts).codes; assert_eq!(table[1].1, 0); assert_eq!(table[3].1, 0); assert_eq!(table[8].1, 0); assert!(table[11].1 >= table[5].1); assert!(table[5].1 >= table[6].1); assert!(table[6].1 >= table[7].1); assert!(table[7].1 >= table[9].1); assert!(table[9].1 >= table[10].1); assert!(table[10].1 >= table[0].1); assert!(table[0].1 >= table[2].1); assert!(table[2].1 >= table[12].1); assert!(table[12].1 >= table[4].1); } #[test] fn from_data() { let counts = &[3, 0, 4, 1, 5]; let table = HuffmanTable::build_from_counts(counts).codes; let data = &[0, 2, 4, 4, 0, 3, 2, 2, 0, 2]; let table2 = HuffmanTable::build_from_data(data).codes; assert_eq!(table, table2); }