banan-os/userspace/libraries/LibDEFLATE/Compressor.cpp

621 lines
16 KiB
C++

#include <LibDEFLATE/Compressor.h>
#include <LibDEFLATE/Utils.h>
#include <BAN/Array.h>
#include <BAN/Heap.h>
#include <BAN/Optional.h>
#include <BAN/Sort.h>
namespace LibDEFLATE
{
// Limits of the DEFLATE format used throughout the compressor.
// Longest back-reference a length code can express (see get_len_encoding).
static constexpr size_t s_max_length { 258 };
// Largest back-reference distance, i.e. the sliding window size.
static constexpr size_t s_max_distance { 32768 };
// Upper bound on the alphabet size handed to create_huffman_tree.
static constexpr size_t s_max_symbols { 288 };
// Longest Huffman code length a dynamic block may use.
static constexpr uint8_t s_max_bits { 15 };
// One row of a canonical Huffman table: the code assigned to a symbol and
// its bit length. Codes are kept MSB-first; callers bit-reverse them when
// writing to the stream. A length of 0 marks a symbol without a code.
struct Leaf { uint16_t code; uint8_t length; };
// Builds a length-limited canonical Huffman code for one alphabet.
//
// `freq[sym]` is the number of occurrences of symbol `sym`; `output` (same
// size as `freq`) receives one Leaf per symbol. Unused symbols (frequency 0)
// get length 0; used symbols get a code of at most s_max_bits bits. Codes are
// assigned canonically and stored MSB-first (not yet bit-reversed).
//
// Degenerate alphabets:
//  - no used symbol: symbol 0 alone gets a 1-bit code, so the transmitted
//    table is never empty;
//  - exactly one used symbol: it and an unused dummy symbol both get 1-bit
//    codes. Without this the lone leaf would be the tree root at depth 0 and
//    receive no code at all, and a single 1-bit code would be an incomplete
//    code (which decoders reject for the code-length alphabet).
//
// Returns ENOMEM if a tree node cannot be allocated.
static BAN::ErrorOr<void> create_huffman_tree(BAN::Span<const size_t> freq, BAN::Span<Leaf> output)
{
	ASSERT(freq.size() <= s_max_symbols);
	ASSERT(freq.size() == output.size());

	struct node_t
	{
		size_t symbol;
		size_t freq;
		node_t* left;
		node_t* right;
	};

#if LIBDEFLATE_AVOID_STACK
	BAN::Vector<node_t*> nodes;
	TRY(nodes.resize(s_max_symbols));
#else
	BAN::Array<node_t*, s_max_symbols> nodes;
#endif

	// One leaf node per symbol that actually occurs.
	size_t node_count = 0;
	for (size_t sym = 0; sym < freq.size(); sym++)
	{
		if (freq[sym] == 0)
			continue;
		nodes[node_count] = static_cast<node_t*>(BAN::allocator(sizeof(node_t)));
		if (nodes[node_count] == nullptr)
		{
			for (size_t j = 0; j < node_count; j++)
				BAN::deallocator(nodes[j]);
			return BAN::Error::from_errno(ENOMEM);
		}
		*nodes[node_count++] = {
			.symbol = sym,
			.freq = freq[sym],
			.left = nullptr,
			.right = nullptr,
		};
	}

	for (auto& symbol : output)
		symbol = { .code = 0, .length = 0 };

	if (node_count == 0)
	{
		output[0] = { .code = 0, .length = 1 };
		return {};
	}

	if (node_count == 1)
	{
		// A single leaf never gets merged, so it would end up as the root with
		// depth 0 and length 0 (no code). Pair it with an unused dummy symbol
		// and give both a 1-bit code; canonical order gives the smaller
		// symbol code 0.
		ASSERT(output.size() >= 2);
		const size_t used = nodes[0]->symbol;
		BAN::deallocator(nodes[0]);
		const size_t dummy = (used == 0) ? 1 : 0;
		const size_t lower = (used < dummy) ? used : dummy;
		const size_t upper = (used < dummy) ? dummy : used;
		output[lower] = { .code = 0, .length = 1 };
		output[upper] = { .code = 1, .length = 1 };
		return {};
	}

	static void (*free_tree)(node_t*) =
		[](node_t* root) -> void {
			if (root == nullptr)
				return;
			free_tree(root->left);
			free_tree(root->right);
			BAN::deallocator(root);
		};

	// Min-heap on (freq, symbol): the comparator returns "greater".
	const auto comp =
		[](const node_t* a, const node_t* b) -> bool {
			if (a->freq != b->freq)
				return a->freq > b->freq;
			return a->symbol > b->symbol;
		};

	// Standard Huffman construction: repeatedly merge the two rarest nodes.
	auto end_it = nodes.begin() + node_count;
	BAN::make_heap(nodes.begin(), end_it, comp);
	while (nodes.begin() + 1 != end_it)
	{
		node_t* parent = static_cast<node_t*>(BAN::allocator(sizeof(node_t)));
		if (parent == nullptr)
		{
			for (auto it = nodes.begin(); it != end_it; it++)
				free_tree(*it);
			return BAN::Error::from_errno(ENOMEM);
		}

		node_t* node1 = nodes.front();
		BAN::pop_heap(nodes.begin(), end_it--, comp);
		node_t* node2 = nodes.front();
		BAN::pop_heap(nodes.begin(), end_it--, comp);

		*parent = {
			.symbol = 0,
			.freq = node1->freq + node2->freq,
			.left = node1,
			.right = node2,
		};

		*end_it++ = parent;
		BAN::push_heap(nodes.begin(), end_it, comp);
	}

	// Records each leaf's depth, clamped to s_max_bits, as its code length.
	// Returns how many nodes (leaves and internal) lie deeper than s_max_bits.
	static uint16_t (*gather_lengths)(const node_t*, BAN::Span<Leaf>, uint16_t) =
		[](const node_t* node, BAN::Span<Leaf> symbols, uint16_t depth) -> uint16_t {
			if (node == nullptr)
				return 0;
			uint16_t count = (depth > s_max_bits);
			if (node->left == nullptr && node->right == nullptr)
				symbols[node->symbol].length = BAN::Math::min<uint16_t>(depth, s_max_bits);
			else
			{
				count += gather_lengths(node->left, symbols, depth + 1);
				count += gather_lengths(node->right, symbols, depth + 1);
			}
			return count;
		};

	const auto too_long_count = gather_lengths(nodes[0], output, 0);
	free_tree(nodes[0]);

	uint16_t bl_count[s_max_bits + 1] {};
	for (size_t sym = 0; sym < freq.size(); sym++)
		if (const uint8_t len = output[sym].length)
			bl_count[len]++;

	if (too_long_count > 0)
	{
		// Clamping over-long leaves over-subscribes the code by
		// too_long_count units of 2^-(s_max_bits + 1). Each step below moves
		// one shorter leaf one level down and uses the freed slot for a
		// clamped leaf, removing two such units.
		for (size_t i = 0; i < too_long_count / 2; i++)
		{
			uint16_t bits = s_max_bits - 1;
			while (bl_count[bits] == 0)
				bits--;
			bl_count[bits + 0]--;
			bl_count[bits + 1] += 2;
			bl_count[s_max_bits]--;
		}

		// Re-assign lengths: least frequent symbols get the longest codes.
		struct SymFreq
		{
			size_t symbol;
			size_t freq;
		};

		BAN::Vector<SymFreq> sym_freq;
		for (size_t sym = 0; sym < output.size(); sym++)
			if (freq[sym] != 0)
				TRY(sym_freq.push_back({ .symbol = sym, .freq = freq[sym] }));

		BAN::sort::sort(sym_freq.begin(), sym_freq.end(),
			[](auto a, auto b) { return a.freq < b.freq; }
		);

		size_t index = 0;
		for (uint16_t bits = s_max_bits; bits > 0; bits--)
			for (size_t i = 0; i < bl_count[bits]; i++)
				output[sym_freq[index++].symbol].length = bits;
		ASSERT(index == sym_freq.size());
	}

	// Canonical code assignment from the per-length counts.
	uint16_t next_code[s_max_bits + 1] {};
	uint16_t code = 0;
	for (uint8_t bits = 1; bits <= s_max_bits; bits++)
	{
		code = (code + bl_count[bits - 1]) << 1;
		next_code[bits] = code;
	}

	for (size_t sym = 0; sym < freq.size(); sym++)
		if (const uint16_t len = output[sym].length)
			output[sym].code = next_code[len]++;

	return {};
}
// A symbol to be Huffman-encoded, followed by `extra_len` raw bits of
// `extra_data` (used by length/distance codes and code-length repeat codes).
struct Encoding
{
	uint16_t symbol;            // index into the relevant alphabet
	uint16_t extra_data { 0 };  // value of the trailing extra bits
	uint8_t  extra_len  { 0 };  // number of trailing extra bits
};
// Maps a match length (3..258) to its literal/length alphabet symbol
// (257..285) and the extra bits holding the offset from that symbol's base.
static constexpr Encoding get_len_encoding(uint16_t length)
{
	ASSERT(3 <= length && length <= s_max_length);

	constexpr uint16_t base_length[] {
		3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 17, 19, 23, 27, 31, 35, 43, 51, 59, 67, 83, 99, 115, 131, 163, 195, 227, 258
	};
	constexpr uint8_t extra_length_bits[] {
		0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 0
	};

	// Bases are strictly increasing, so the bucket is the last one whose base
	// does not exceed `length`; scan down from the top.
	size_t index = sizeof(base_length) / sizeof(base_length[0]) - 1;
	while (index > 0 && base_length[index] > length)
		index--;

	Encoding encoding { static_cast<uint16_t>(257 + index) };
	encoding.extra_data = static_cast<uint16_t>(length - base_length[index]);
	encoding.extra_len = extra_length_bits[index];
	return encoding;
}
// Maps a back-reference distance (1..32768) to its distance alphabet symbol
// (0..29) and the extra bits holding the offset from that symbol's base.
static constexpr Encoding get_dist_encoding(uint16_t distance)
{
	ASSERT(1 <= distance && distance <= s_max_distance);

	constexpr uint16_t base_distance[] {
		1, 2, 3, 4, 5, 7, 9, 13, 17, 25, 33, 49, 65, 97, 129, 193, 257, 385, 513, 769, 1025, 1537, 2049, 3073, 4097, 6145, 8193, 12289, 16385, 24577
	};
	constexpr uint8_t extra_distance_bits[] {
		0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13
	};

	// Bases are strictly increasing: take the last bucket whose base is not
	// above `distance`, scanning down from the largest.
	size_t index = sizeof(base_distance) / sizeof(base_distance[0]) - 1;
	while (index > 0 && base_distance[index] > distance)
		index--;

	Encoding encoding { static_cast<uint16_t>(index) };
	encoding.extra_data = static_cast<uint16_t>(distance - base_distance[index]);
	encoding.extra_len = extra_distance_bits[index];
	return encoding;
}
// Counts how often each literal/length symbol (286 slots) and each distance
// symbol (30 slots) is needed to encode `entries`. The end-of-block symbol
// 256 is counted once too, because every block ends with it.
static void get_frequencies(BAN::Span<const Compressor::LZ77Entry> entries, BAN::Span<size_t> lit_len_freq, BAN::Span<size_t> dist_freq)
{
	ASSERT(lit_len_freq.size() == 286);
	ASSERT(dist_freq.size() == 30);

	for (const auto& entry : entries)
	{
		if (entry.type == Compressor::LZ77Entry::Type::Literal)
		{
			lit_len_freq[entry.as.literal]++;
		}
		else if (entry.type == Compressor::LZ77Entry::Type::DistLength)
		{
			const auto& match = entry.as.dist_length;
			lit_len_freq[get_len_encoding(match.length).symbol]++;
			dist_freq[get_dist_encoding(match.distance).symbol]++;
		}
	}

	// end-of-block marker
	lit_len_freq[256]++;
}
// Data needed to write the header of a dynamic-Huffman block.
struct CodeLengthInfo
{
	uint16_t hlit;                          // literal/length code lengths sent (trailing zeros trimmed)
	uint8_t hdist;                          // distance code lengths sent (trailing zeros trimmed)
	uint8_t hclen;                          // code-length-code lengths sent (at least 4)
	BAN::Vector<Encoding> encoding;         // run-length encoded lengths of both trees
	BAN::Array<uint8_t, 19> code_length;    // code-length-code lengths, in transmission order
	BAN::Array<Leaf, 19> code_length_tree;  // Huffman table of the code-length alphabet, by symbol
};
// Describes the literal/length and distance trees for a dynamic block header.
// The code lengths of both trees are run-length encoded into one sequence of
// code-length symbols:
//   0..15  literal code length
//   16     repeat the previous length 3..6 times (2 extra bits)
//   17     repeat zero 3..10 times (3 extra bits)
//   18     repeat zero 11..138 times (7 extra bits)
// A Huffman tree is then built for that 19-symbol alphabet, and the
// hlit/hdist/hclen counts are computed.
static BAN::ErrorOr<CodeLengthInfo> build_code_length_info(BAN::Span<const Leaf> lit_len_tree, BAN::Span<const Leaf> dist_tree)
{
CodeLengthInfo result;
// Appends the run-length encoding of `tree` to result.encoding. `tree` is
// taken by reference and trimmed of trailing zero-length entries; the caller
// reads the trimmed size back as hlit/hdist.
const auto append_tree =
[&result](BAN::Span<const Leaf>& tree) -> BAN::ErrorOr<void>
{
while (!tree.empty() && tree[tree.size() - 1].length == 0)
tree = tree.slice(0, tree.size() - 1);
for (size_t i = 0; i < tree.size();)
{
// length of the run of equal code lengths starting at i
size_t count = 1;
while (i + count < tree.size() && tree[i].length == tree[i + count].length)
count++;
if (tree[i].length == 0)
{
// zero run: literal zeros for short runs, otherwise symbol 17 or 18
if (count > 138)
count = 138;
if (count < 3)
{
for (size_t j = 0; j < count; j++)
TRY(result.encoding.push_back({ .symbol = 0 }));
}
else if (count < 11)
{
TRY(result.encoding.push_back({
.symbol = 17,
.extra_data = static_cast<uint8_t>(count - 3),
.extra_len = 3,
}));
}
else
{
TRY(result.encoding.push_back({
.symbol = 18,
.extra_data = static_cast<uint8_t>(count - 11),
.extra_len = 7,
}));
}
}
else
{
// Symbol 16 repeats the previously emitted length. It is used only when
// the last emitted symbol is that same literal length. Both trees share one
// sequence, so at the start of the distance tree this may refer to the last
// literal/length entry.
// NOTE(review): if the previous symbol was itself a 16, the comparison
// fails and the length is emitted literally again; valid, but not optimal.
if (count >= 3 && !result.encoding.empty() && result.encoding.back().symbol == tree[i].length)
{
if (count > 6)
count = 6;
TRY(result.encoding.push_back({
.symbol = 16,
.extra_data = static_cast<uint8_t>(count - 3),
.extra_len = 2,
}));
}
else
{
// emit one literal length; a following run of the same length can
// then be covered by symbol 16 on the next iteration
count = 1;
TRY(result.encoding.push_back({ .symbol = tree[i].length }));
}
}
i += count;
}
return {};
};
TRY(append_tree(lit_len_tree));
result.hlit = lit_len_tree.size();
TRY(append_tree(dist_tree));
result.hdist = dist_tree.size();
// Huffman code for the code-length alphabet itself
BAN::Array<size_t, 19> code_len_freq(0);
for (auto entry : result.encoding)
code_len_freq[entry.symbol]++;
TRY(create_huffman_tree(code_len_freq.span(), result.code_length_tree.span()));
// order in which the code-length-code lengths are transmitted
constexpr uint8_t code_length_order[] {
16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15
};
for (size_t i = 0; i < result.code_length_tree.size(); i++)
result.code_length[i] = result.code_length_tree[code_length_order[i]].length;
// drop trailing zero lengths, but always send at least 4
result.hclen = 19;
while (result.hclen > 4 && result.code_length[result.hclen - 1] == 0)
result.hclen--;
return BAN::move(result);
}
// Packs the first three bytes of `needle` into a 24-bit key, with byte 0 in
// the lowest bits. The key is exact, so equal keys mean equal 3-byte prefixes.
uint32_t Compressor::get_hash_key(BAN::ConstByteSpan needle) const
{
	ASSERT(needle.size() >= 3);
	uint32_t key = 0;
	for (size_t byte = 3; byte-- > 0;)
		key = (key << 8) | needle[byte];
	return key;
}
// Adds the `count` positions starting at m_hash_chain_index (an offset into
// m_data) to the hash chains, then advances m_hash_chain_index by `count`.
// Each chain holds spans into m_data whose first three bytes hash to the
// chain's key, newest position first.
BAN::ErrorOr<void> Compressor::update_hash_chain(size_t count)
{
// Pruning: once the map holds s_max_distance * 2 keys, drop chain entries
// too far behind the position after this update to be referenced.
// NOTE(review): this tests the number of distinct keys, not entries, and
// emptied chains are never erased. Once the threshold is reached, this full
// pass therefore runs on every call.
if (m_hash_chain.size() >= s_max_distance * 2)
{
const uint8_t* current = m_data.data() + m_hash_chain_index + count;
for (auto& [_, chain] : m_hash_chain)
{
for (auto it = chain.begin(); it != chain.end(); it++)
{
const size_t distance = current - it->data();
if (distance < s_max_distance)
continue;
// chains are newest-first, so every remaining entry is even older
while (it != chain.end())
it = chain.remove(it);
break;
}
}
}
for (size_t i = 0; i < count; i++)
{
auto slice = m_data.slice(m_hash_chain_index + i);
// fewer than 3 bytes left: cannot form a key (nor a match)
if (slice.size() < 3)
break;
const uint32_t key = get_hash_key(slice);
auto it = m_hash_chain.find(key);
if (it != m_hash_chain.end())
TRY(it->value.insert(it->value.begin(), slice));
else
{
HashChain new_chain;
TRY(new_chain.push_back(slice));
TRY(m_hash_chain.insert(key, BAN::move(new_chain)));
}
}
// advanced by the full count even if the loop stopped early near the end
m_hash_chain_index += count;
return {};
}
// Returns the longest earlier match for the bytes at the start of `needle`,
// as a DistLength entry; ties keep the nearest candidate. Falls back to a
// Literal entry for needle[0] when fewer than 3 bytes remain or no chain
// entry exists for the 3-byte key. `needle` must be non-empty.
BAN::ErrorOr<Compressor::LZ77Entry> Compressor::find_longest_match(BAN::ConstByteSpan needle) const
{
	LZ77Entry result = {
		.type = LZ77Entry::Type::Literal,
		.as = { .literal = needle[0] }
	};

	if (needle.size() < 3)
		return result;

	const uint32_t key = get_hash_key(needle);
	auto it = m_hash_chain.find(key);
	if (it == m_hash_chain.end())
		return result;

	// Longest possible match here: bounded by the remaining input and by the
	// DEFLATE length limit. Loop-invariant, so computed once.
	const size_t max_length = BAN::Math::min(needle.size(), s_max_length);

	auto& chain = it->value;
	for (const auto& node : chain)
	{
		// chains are newest-first, so distances only grow from here on
		const size_t distance = needle.data() - node.data();
		if (distance > s_max_distance)
			break;

		// the hash key is the exact first three bytes, so they already match
		size_t length = 3;
		while (length < max_length && needle[length] == node[length])
			length++;

		if (result.type != LZ77Entry::Type::DistLength || length > result.as.dist_length.length)
		{
			result = LZ77Entry {
				.type = LZ77Entry::Type::DistLength,
				.as = {
					.dist_length = {
						.length = static_cast<uint16_t>(length),
						.distance = static_cast<uint16_t>(distance),
					}
				}
			};
		}

		// Only strictly longer matches replace the current one, so nothing
		// further down the chain can beat a maximal match; stop scanning.
		// On repetitive input this avoids walking the entire chain per byte.
		if (length == max_length)
			break;
	}

	return result;
}
// Greedy LZ77 parse of `data` with one step of lazy matching.
// Assumes `data` starts at m_data + m_hash_chain_index, since
// update_hash_chain() indexes positions relative to m_data. On return, every
// byte of `data` has been added to the hash chain, and m_hash_chain_index has
// advanced by data.size().
BAN::ErrorOr<BAN::Vector<Compressor::LZ77Entry>> Compressor::lz77_compress(BAN::ConstByteSpan data)
{
	BAN::Vector<LZ77Entry> result;

	// Bytes consumed by the previous entry. They are indexed at the top of
	// the next iteration, so the chain only ever holds positions before `i`.
	size_t advance = 0;
	for (size_t i = 0; i < data.size(); i += advance)
	{
		TRY(update_hash_chain(advance));

		auto match = TRY(find_longest_match(data.slice(i)));
		if (match.type == LZ77Entry::Type::Literal)
		{
			TRY(result.push_back(match));
			advance = 1;
			continue;
		}

		ASSERT(match.type == LZ77Entry::Type::DistLength);

		// Lazy matching: if the match starting one byte later is longer, emit
		// this byte as a literal and take that match instead.
		auto lazy_match = TRY(find_longest_match(data.slice(i + 1)));
		if (lazy_match.type == LZ77Entry::Type::DistLength && lazy_match.as.dist_length.length > match.as.dist_length.length)
		{
			TRY(result.push_back({ .type = LZ77Entry::Type::Literal, .as = { .literal = data[i] }}));
			TRY(result.push_back(lazy_match));
			advance = 1 + lazy_match.as.dist_length.length;
		}
		else
		{
			TRY(result.push_back(match));
			advance = match.as.dist_length.length;
		}
	}

	// Also index the bytes covered by the last entry. Otherwise they are never
	// added to the hash chain, m_hash_chain_index lags behind the consumed
	// data, and the next block cannot reference them.
	TRY(update_hash_chain(advance));

	return result;
}
// Compresses `data` into one dynamic-Huffman DEFLATE block and appends it to
// m_stream. `final` sets the BFINAL bit of the block header.
BAN::ErrorOr<void> Compressor::compress_block(BAN::ConstByteSpan data, bool final)
{
	// FIXME: use fixed trees or uncompressed blocks
	auto lz77_entries = TRY(lz77_compress(data));

#if LIBDEFLATE_AVOID_STACK
	BAN::Vector<size_t> lit_len_freq, dist_freq;
	TRY(lit_len_freq.resize(286, 0));
	TRY(dist_freq.resize(30, 0));
#else
	BAN::Array<size_t, 286> lit_len_freq(0);
	BAN::Array<size_t, 30> dist_freq(0);
#endif
	get_frequencies(lz77_entries.span(), lit_len_freq.span(), dist_freq.span());

#if LIBDEFLATE_AVOID_STACK
	BAN::Vector<Leaf> lit_len_tree, dist_tree;
	TRY(lit_len_tree.resize(286));
	TRY(dist_tree.resize(30));
#else
	BAN::Array<Leaf, 286> lit_len_tree;
	BAN::Array<Leaf, 30> dist_tree;
#endif
	TRY(create_huffman_tree(lit_len_freq.span(), lit_len_tree.span()));
	TRY(create_huffman_tree(dist_freq.span(), dist_tree.span()));

	auto info = TRY(build_code_length_info(lit_len_tree.span(), dist_tree.span()));

	// Huffman codes are stored MSB-first in the tables, so reverse them
	// before writing.
	const auto emit_code = [this](const Leaf& leaf) -> BAN::ErrorOr<void> {
		TRY(m_stream.write_bits(reverse_bits(leaf.code, leaf.length), leaf.length));
		return {};
	};
	// Extra bits are written as plain integers right after their symbol.
	const auto emit_extra = [this](const Encoding& encoding) -> BAN::ErrorOr<void> {
		TRY(m_stream.write_bits(encoding.extra_data, encoding.extra_len));
		return {};
	};

	// Block header: BFINAL, BTYPE = 2 (dynamic Huffman), table sizes.
	TRY(m_stream.write_bits(final, 1));
	TRY(m_stream.write_bits(2, 2));
	TRY(m_stream.write_bits(info.hlit - 257, 5));
	TRY(m_stream.write_bits(info.hdist - 1, 5));
	TRY(m_stream.write_bits(info.hclen - 4, 4));

	// Code-length-code lengths, 3 bits each, in transmission order.
	for (size_t idx = 0; idx < info.hclen; idx++)
		TRY(m_stream.write_bits(info.code_length[idx], 3));

	// Run-length encoded code lengths of both trees.
	for (const auto& cl : info.encoding)
	{
		TRY(emit_code(info.code_length_tree[cl.symbol]));
		TRY(emit_extra(cl));
	}

	// The compressed data itself.
	for (const auto& entry : lz77_entries)
	{
		if (entry.type == LZ77Entry::Type::Literal)
		{
			TRY(emit_code(lit_len_tree[entry.as.literal]));
		}
		else if (entry.type == LZ77Entry::Type::DistLength)
		{
			const auto length = get_len_encoding(entry.as.dist_length.length);
			TRY(emit_code(lit_len_tree[length.symbol]));
			TRY(emit_extra(length));

			const auto distance = get_dist_encoding(entry.as.dist_length.distance);
			TRY(emit_code(dist_tree[distance.symbol]));
			TRY(emit_extra(distance));
		}
	}

	// End-of-block symbol.
	TRY(emit_code(lit_len_tree[256]));
	return {};
}
// Compresses all of m_data into a DEFLATE stream and returns the finished
// buffer taken from m_stream. For StreamType::Zlib, the stream is wrapped in
// a 2-byte zlib header and a big-endian Adler-32 trailer. Input is split into
// dynamic-Huffman blocks of at most 16 KiB; the last one carries BFINAL.
BAN::ErrorOr<BAN::Vector<uint8_t>> Compressor::compress()
{
	uint32_t checksum = 0;
	switch (m_type)
	{
		case StreamType::Raw:
			break;
		case StreamType::Zlib:
			TRY(m_stream.write_bits(0x78, 8)); // deflate with 32k window
			TRY(m_stream.write_bits(0x9C, 8)); // default compression
			checksum = calculate_adler32(m_data);
			break;
	}

	if (m_data.empty())
	{
		// A DEFLATE stream needs at least one final block, but the loop below
		// emits none for empty input. Write an empty fixed-Huffman block:
		// BFINAL = 1, BTYPE = 01, then end-of-block (symbol 256 = 7 zero bits).
		TRY(m_stream.write_bits(1, 1));
		TRY(m_stream.write_bits(1, 2));
		TRY(m_stream.write_bits(0, 7));
	}

	constexpr size_t max_block_size = 16 * 1024;
	while (!m_data.empty())
	{
		const size_t block_size = BAN::Math::min<size_t>(m_data.size(), max_block_size);
		TRY(compress_block(m_data.slice(0, block_size), block_size == m_data.size()));
		m_data = m_data.slice(block_size);
		// update_hash_chain() addresses positions as m_data + m_hash_chain_index,
		// and m_data now starts at the next block, so restart the index. Existing
		// chain entries are spans into the same underlying buffer and stay valid.
		m_hash_chain_index = 0;
	}

	TRY(m_stream.pad_to_byte_boundary());

	switch (m_type)
	{
		case StreamType::Raw:
			break;
		case StreamType::Zlib:
			TRY(m_stream.write_bits(checksum >> 24, 8));
			TRY(m_stream.write_bits(checksum >> 16, 8));
			TRY(m_stream.write_bits(checksum >> 8, 8));
			TRY(m_stream.write_bits(checksum >> 0, 8));
			break;
	}

	return m_stream.take_buffer();
}
}