//! Implementations of radix keys and sorting functions. use core::mem; use crate::{double_buffer::DoubleBuffer, Key}; /// Unsigned integers used as sorting keys for radix sort. /// /// These keys can be sorted bitwise. For conversion from scalar types, see /// [`Scalar::to_radix_key()`]. /// /// [`Scalar::to_radix_key()`]: ../scalar/trait.Scalar.html#tymethod.to_radix_key pub trait RadixKey: Key { /// Sorts the slice using provided key extraction function. /// Runs one of the other functions, based on the length of the slice. #[inline] fn radix_sort(slice: &mut [T], mut key_fn: F, unopt: bool) where F: FnMut(&T) -> Self, { // Sorting has no meaningful behavior on zero-sized types. if mem::size_of::() == 0 { return; } let len = slice.len(); if len < 2 { return; } #[cfg(not(any(target_pointer_width = "16", target_pointer_width = "32")))] { if len <= u32::MAX as usize { Self::radix_sort_u32(slice, |t| key_fn(t), unopt); return; } } Self::radix_sort_usize(slice, |t| key_fn(t), unopt); } /// Sorting for slices with up to `u32::MAX` elements, which is a majority /// of cases. Uses `u32` indices for histograms and offsets to save cache /// space. #[cfg(not(any(target_pointer_width = "16", target_pointer_width = "32")))] fn radix_sort_u32(slice: &mut [T], key_fn: F, unopt: bool) where F: FnMut(&T) -> Self; /// Sorting function for slices with up to `usize::MAX` elements. fn radix_sort_usize(slice: &mut [T], key_fn: F, unopt: bool) where F: FnMut(&T) -> Self; } macro_rules! sort_impl { ($name:ident, $radix_key_type:ty, $offset_type:ty) => { #[inline(never)] // Don't inline, the offset array needs a lot of stack fn $name(input: &mut [T], mut key_fn: F, unopt: bool) where F: FnMut(&T) -> $radix_key_type, { // This implementation is radix 256, so the size of a digit is 8 bits / one byte. // You can experiment with different digit sizes by changing this constant, but // according to my benchmarks, the overhead from arbitrary shifting and masking // will be higher than what you save by having less digits. const DIGIT_BITS: usize = 8; const RADIX_KEY_BITS: usize = mem::size_of::<$radix_key_type>() * 8; // Have one bucket for each possible value of the digit const BUCKET_COUNT: usize = 1 << DIGIT_BITS; const DIGIT_COUNT: usize = (RADIX_KEY_BITS + DIGIT_BITS - 1) / DIGIT_BITS; let digit_skip_enabled: bool = !unopt; /// Extracts the digit from the key, starting with the least significant digit. /// The digit is used as a bucket index. #[inline(always)] fn extract_digit(key: $radix_key_type, digit: usize) -> usize { const DIGIT_MASK: $radix_key_type = ((1 << DIGIT_BITS) - 1) as $radix_key_type; ((key >> (digit * DIGIT_BITS)) & DIGIT_MASK) as usize } // In the worst case (`u128` key, `input.len() >= u32::MAX`) uses 32 KiB on the stack. let mut offsets = [[0 as $offset_type; BUCKET_COUNT]; DIGIT_COUNT]; let mut skip_digit = [false; DIGIT_COUNT]; { // Calculate bucket offsets for each digit. // Calculate histograms/bucket sizes and store in `offsets`. for t in input.iter() { let key = key_fn(t); for digit in 0..DIGIT_COUNT { offsets[digit][extract_digit(key, digit)] += 1; } } if digit_skip_enabled { // For each digit, check if all the elements are in the same bucket. // If so, we can skip the whole digit. Instead of checking all the buckets, // we pick a key and check whether the bucket contains all the elements. let last_key = key_fn(input.last().unwrap()); for digit in 0..DIGIT_COUNT { let last_bucket = extract_digit(last_key, digit); let skip = offsets[digit][last_bucket] == input.len() as $offset_type; skip_digit[digit] = skip; } } // Turn the histogram/bucket sizes into bucket offsets by calculating a prefix sum. // Sizes: |---b1---|-b2-|---b3---|----b4----| // Offsets: 0 b1 b1+b2 b1+b2+b3 for digit in 0..DIGIT_COUNT { if !(digit_skip_enabled && skip_digit[digit]) { let mut offset_acc = 0; for count in offsets[digit].iter_mut() { let offset = offset_acc; offset_acc += *count; *count = offset; } } } // The `offsets` array now contains bucket offsets for each digit. } let len = input.len(); // Drop impl of DoubleBuffer ensures that `input` is consistent, // e.g. in case of panic in the key function. let mut buffer = DoubleBuffer::new(input); // This is the main sorting loop. We sort the elements by each digit of the key, // starting from the least-significant. After sorting by the last, most significant // digit, our elements are sorted. for digit in 0..DIGIT_COUNT { if !(digit_skip_enabled && skip_digit[digit]) { // Initial offset of each bucket. let init_offsets = &offsets[digit]; // Offset of the first empty index in each bucket. let mut working_offsets = *init_offsets; buffer.scatter(|t| { let key = key_fn(t); let bucket = extract_digit(key, digit); let offset = &mut working_offsets[bucket]; let index = *offset as usize; // Increment the offset of the bucket. Use wrapping add in case the // key function is unreliable and the bucket overflowed. *offset = offset.wrapping_add(1); index }); // Check that each bucket had the same number of insertions as we expected. // If this is not true, then the key function is unreliable and some elements // in the write buffer were not written to. // // If the key function is unreliable, but the sizes of buckets ended up being // the same, it would not get detected. This is sound, the only consequence is // that the elements won't be sorted right. { // The `working_offsets` array now contains the end offset of each bucket. // If the bucket is full, the working offset is now equal to the original // offset of the next bucket. The working offset of the last bucket should // be equal to the number of elements. let bucket_sizes_match = working_offsets[0..BUCKET_COUNT - 1] == offsets[digit][1..BUCKET_COUNT] && working_offsets[BUCKET_COUNT - 1] == len as $offset_type; if !bucket_sizes_match { // The bucket sizes do not match expected sizes, the key function is // unreliable (programming mistake). // // The Drop impl will copy the last completed buffer into the slice. drop(buffer); panic!( "The key function is not reliable: when called repeatedly, \ it returned different keys for the same element." ) } } unsafe { // SAFETY: we just ensured that every index was written to. buffer.swap(); } } } // The Drop impl will copy the last completed buffer into the slice. drop(buffer); } }; } macro_rules! radix_key_impl { ($($key_type:ty)*) => ($( impl RadixKey for $key_type { #[cfg(not(any(target_pointer_width = "16", target_pointer_width = "32")))] sort_impl!(radix_sort_u32, $key_type, u32); sort_impl!(radix_sort_usize, $key_type, usize); } )*) } radix_key_impl! { u8 u16 u32 u64 u128 }