Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -84,5 +84,10 @@ if(ipo_supported AND (CMAKE_BUILD_TYPE STREQUAL "Release"))
set_target_properties(_simple_ans PROPERTIES INTERPROCEDURAL_OPTIMIZATION TRUE)
endif()

set_target_properties(_simple_ans PROPERTIES
CXX_VISIBILITY_PRESET hidden
VISIBILITY_INLINES_HIDDEN ON
)

# Install rules
install(TARGETS _simple_ans LIBRARY DESTINATION simple_ans)
67 changes: 20 additions & 47 deletions simple_ans/cpp/simple_ans.hpp
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
#pragma once

#include <algorithm>
#include <bit>
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <execution>
#include <limits>
#include <numeric>
#include <stdexcept>
#include <vector>

#include <tsl/robin_map.h>

#include "libdivide.h"
#include <libdivide.h>

namespace simple_ans
{
Expand All @@ -20,7 +23,7 @@ struct EncodedData
};

// Helper function to verify if a number is a power of 2.
// Returns true only when exactly one bit of x is set; x == 0 yields false.
inline bool is_power_of_2(uint32_t x)
constexpr bool is_power_of_2(uint32_t x)
{
// x & (x - 1) clears the lowest set bit, so it is zero only when at most one
// bit was set; the leading `x &&` rejects the zero case.
return x && !(x & (x - 1));
}
Expand Down Expand Up @@ -113,29 +116,17 @@ EncodedData ans_encode_t(const T* signal,
"Value range of T must fit in int64_t for table lookup");

// Calculate L and verify it's a power of 2
uint32_t index_size = 0;
for (size_t i = 0; i < num_symbols; ++i)
{
index_size += symbol_counts[i];
}
const uint32_t index_size = std::reduce(symbol_counts, symbol_counts + num_symbols, 0u);
if (!is_power_of_2(index_size))
{
throw std::invalid_argument("L must be a power of 2");
}

int PRECISION_BITS = 0;
while ((1U << PRECISION_BITS) < index_size)
{
PRECISION_BITS++;
}
const auto PRECISION_BITS = std::bit_width(index_size - 1);

// Pre-compute cumulative sums
std::vector<uint32_t> C(num_symbols);
C[0] = 0;
for (size_t i = 1; i < num_symbols; ++i)
{
C[i] = C[i - 1] + symbol_counts[i - 1];
}
std::exclusive_scan(symbol_counts, symbol_counts + num_symbols, C.begin(), 0u);

// Precompute libdivide dividers for each symbol count
std::vector<libdivide::divider<uint64_t>> fast_dividers(num_symbols);
Expand All @@ -157,15 +148,12 @@ EncodedData ans_encode_t(const T* signal,
}

// Map lookups can be a bottleneck, so we use a lookup array if the number of symbols is "small"
const bool use_lookup_array = (max_symbol - min_symbol + 1) <= lookup_array_threshold;
std::vector<size_t> symbol_index_lookup_array;
const auto array_size = max_symbol - min_symbol + 1;
const bool use_lookup_array = array_size <= lookup_array_threshold;
std::vector<size_t> symbol_index_lookup_array(0);
if (use_lookup_array)
{
symbol_index_lookup_array.resize(max_symbol - min_symbol + 1);

std::fill(symbol_index_lookup_array.begin(),
symbol_index_lookup_array.end(),
std::numeric_limits<size_t>::max());
symbol_index_lookup_array.resize(array_size);

for (size_t i = 0; i < num_symbols; ++i)
{
Expand Down Expand Up @@ -241,40 +229,25 @@ void ans_decode_t(T* output,
size_t num_symbols)
{
// very important that this is signed, because it becomes -1
int32_t word_idx = num_words - 1;
auto word_idx = static_cast<int32_t>(num_words) - 1;
// Calculate index size and verify it's a power of 2
uint32_t index_size = 0;
for (size_t i = 0; i < num_symbols; ++i)
{
index_size += symbol_counts[i];
}
const uint32_t index_size = std::reduce(symbol_counts, symbol_counts + num_symbols, 0u);
if (!is_power_of_2(index_size))
{
throw std::invalid_argument("L must be a power of 2");
}

int PRECISION_BITS = 0;
while ((1U << PRECISION_BITS) < index_size)
{
PRECISION_BITS++;
}
const auto PRECISION_BITS = std::bit_width(index_size - 1);

// Pre-compute cumulative sums
std::vector<uint32_t> C(num_symbols);
C[0] = 0;
for (size_t i = 1; i < num_symbols; ++i)
{
C[i] = C[i - 1] + symbol_counts[i - 1];
}
std::exclusive_scan(symbol_counts, symbol_counts + num_symbols, C.begin(), 0u);

// Create symbol lookup table
std::vector<uint32_t> symbol_lookup(index_size);
for (size_t s = 0; s < num_symbols; ++s)
for (uint32_t s = 0; s < num_symbols; ++s)
{
for (uint32_t j = 0; j < symbol_counts[s]; ++j)
{
symbol_lookup[C[s] + j] = s;
}
std::fill_n(symbol_lookup.begin() + C[s], symbol_counts[s], s);
}

// Decode symbols in reverse order
Expand Down