Musa-Cpp-Lib-V2/lib/Base/RadixSort.cpp

128 lines
3.4 KiB
C++

// State for an index-producing radix sort: the sort does not move the input keys,
// it fills `ranks` with input indices arranged in sorted order.
struct RadixSort {
// Input indices in sorted order once a sort has completed (read through rank()).
ArrayView<u32> ranks;
// Scratch buffer each radix pass scatters into; swapped with `ranks` after every pass.
ArrayView<u32> ranks2;
// Allocator used for both buffers; radix_sort_init substitutes context_allocator() when `proc` is null.
Allocator allocator;
// Cleared by radix_sort_init; set once `ranks` holds indices that a later pass may read as its input order.
bool valid_ranks;
};
// Prepares `r` for sorting: creates the two index buffers with `items_to_allocate`
// entries each and marks the ranks as not yet computed.
// If `r` carries no allocator (proc is null), the context allocator is adopted and stored in `r`.
// NOTE(review): buffers already held by `r` are overwritten, not released - call radix_sort_free first.
void radix_sort_init (RadixSort* r, u32 items_to_allocate) {
    // Nothing has been sorted yet, so no pass may treat `ranks` as an existing ordering.
    r->valid_ranks = false;

    const bool has_allocator = (r->allocator.proc != nullptr);
    if (!has_allocator) r->allocator = context_allocator();

    // Presumably makes r->allocator the active allocator for the constructions below - confirm against push_allocator.
    push_allocator(r->allocator);
    r->ranks  = ArrayView<u32>(items_to_allocate);
    r->ranks2 = ArrayView<u32>(items_to_allocate);
}
// Releases both index buffers of `r` with r->allocator pushed.
// `r` must carry an allocator (asserted), i.e. it should have gone through radix_sort_init.
void radix_sort_free (RadixSort* r) {
    Assert(r->allocator.proc != nullptr);
    push_allocator(r->allocator);
    array_free(r->ranks);
    array_free(r->ranks2);
    // The buffers are gone, so whatever ordering they held must never be reused by a later sort.
    r->valid_ranks = false;
}
// RadixSort provides an array of indices in sorted order.
// rank(r, i) returns the entry of r->ranks at position `i`.
// With ARRAY_ENABLE_BOUNDS_CHECKING set, an `i` outside [0, r->ranks.count) triggers debug_break()
// before the access is attempted.
u32 rank (RadixSort* r, s64 i) {
    Assert(r != nullptr);
#if ARRAY_ENABLE_BOUNDS_CHECKING
    const bool in_bounds = (i >= 0) && (i < r->ranks.count);
    if (!in_bounds) {
        debug_break(); /*INDEX OOB*/
    }
#endif
    return r->ranks[i];
}
// Counts, for every byte position of the keys, how often each byte value occurs.
//   r         - unused; kept so existing call sites keep compiling.
//   buffer    - `count` keys of type T, read byte by byte.
//   count     - number of keys in `buffer`.
//   histogram - 256 * sizeof(T) counters laid out as sizeof(T) consecutive tables of 256 entries;
//               histogram[256 * b + v] is incremented for every key whose byte at offset b equals v.
//               Counters are only incremented here, so the caller must zero them beforehand.
// Works for any key width: the per-key loop has a compile-time trip count, so no
// width-specific unrolling (and no indexing past the last table for narrow keys) is needed.
template <typename T> void create_histograms (RadixSort* r, T* buffer, u32 count, u32* histogram) {
    constexpr u32 bucket_count = sizeof(T);
    (void)r;

    u8* p  = (u8*)buffer;
    u8* pe = p + count * sizeof(T);
    // Advance one whole key at a time so `p` always lands exactly on `pe`.
    for (; p != pe; p += bucket_count) {
        for (u32 b = 0; b < bucket_count; b += 1) {
            histogram[256 * b + p[b]] += 1;
        }
    }
}
// Fills r->ranks with the indices of input[0..count) in ascending key order, using a
// least-significant-byte-first radix sort (one stable scatter pass per key byte).
//   r     - sort state; r->ranks and r->ranks2 must each hold at least `count` entries.
//           If r->valid_ranks is already true on entry, the first executed pass reads the
//           existing r->ranks as its starting order, so those entries must be indices < count.
//   input - keys; never modified. Keys are compared byte-wise as unsigned values with byte 0
//           treated as least significant (assumes little endian).
//   count - number of keys. A null `input` or a zero `count` leaves `r` untouched.
template <typename T> void radix_sort (RadixSort* r, T* input, u32 count) {
    // sizeof yields a size_t, so the byte offsets derived from T_SIZE below are computed at
    // pointer width instead of wrapping at 32 bits for very large inputs.
    constexpr auto T_SIZE = sizeof(T);

    // Nothing to sort; this also keeps the pass-skip test below from reading input[0]
    // of an empty buffer.
    if (input == nullptr || count == 0) return;

    // Allocate histograms & offsets on the stack:
    u32  histogram [256 * T_SIZE] = {};
    u32* link [256];
    create_histograms(r, input, count, histogram);

    // Radix sort, j is the pass number, (0 = LSB, P = MSB)
    for (u32 j = 0; j < T_SIZE; j += 1) {
        u32* h = &histogram[j * 256];
        u8* input_bytes = (u8*)input + j; // Assumes little endian!

        // If one bucket holds every key, all keys share this byte and the pass would not
        // change the order - skip it.
        if (h[input_bytes[0]] == count) continue;

        // Create offsets: link[v] points at the first output slot for byte value v.
        link[0] = r->ranks2.data;
        for (u32 i = 1; i < 256; i += 1) { // 1..255
            link[i] = link[i-1] + h[i-1];
        }

        // Perform Radix Sort
        if (!r->valid_ranks) {
            // No previous ordering: scatter the keys in input order.
            for (u32 i = 0; i < count; i += 1) {
                u8 digit = input_bytes[i * T_SIZE];
                *link[digit] = i;
                link[digit] += 1;
            }
            r->valid_ranks = true;
        } else {
            // Scatter in the order produced so far, which keeps equal digits in that order.
            for (u32 i = 0; i < count; i += 1) {
                u32 idx  = r->ranks[i];
                u8 digit = input_bytes[idx * T_SIZE];
                *link[digit] = idx;
                link[digit] += 1;
            }
        }

        // Swap pointers for next pass. Valid indices - the most recent ones - are in ranks after the swap.
        ArrayView<u32> ranks2_temp = r->ranks2;
        r->ranks2 = r->ranks;
        r->ranks  = ranks2_temp;
    }

    // Every pass was skipped (all values were equal) and there is no earlier ordering;
    // generate linear ranks.
    if (!r->valid_ranks) {
        for (u32 i = 0; i < count; i += 1) {
            r->ranks[i] = i;
        }
        r->valid_ranks = true;
    }
}
// NOTE: For a small number of elements it's more efficient to use insertion sort
// Sorts `count` u64 keys by index: afterwards r->ranks holds the indices of `input`
// in ascending key order. A null `input` or a zero `count` is a no-op.
// The index buffers are (re)created whenever their size differs from `count`, so repeated
// calls with the same `count` reuse the buffers and the previous ranks, while any change
// in `count` starts from fresh buffers.
void radix_sort_u64 (RadixSort* r, u64* input, u32 count) {
    if (input == nullptr || count == 0) return;

    // Buffers sized for a different item count cannot be reused: a smaller buffer would be
    // overrun by the scatter passes, and ranks left over from a sort of another size would
    // reference items that do not exist in this input. radix_sort_init also clears
    // valid_ranks, so the sort below starts from scratch in that case.
    if (r->ranks.count != count) {
        if (r->ranks.count != 0) {
            radix_sort_free(r);
        }
        radix_sort_init(r, count);
    }
    radix_sort(r, input, count);
}