userspace: Add LibDEFLATE

This can be used to compress and decompress DEFLATE data in either raw
or zlib format.
This commit is contained in:
Bananymous 2025-10-26 22:25:11 +02:00
parent 9f0addbd8b
commit fecda6a034
11 changed files with 1385 additions and 0 deletions

View File

@ -1,6 +1,7 @@
set(USERSPACE_LIBRARIES set(USERSPACE_LIBRARIES
LibAudio LibAudio
LibC LibC
LibDEFLATE
LibDL LibDL
LibELF LibELF
LibFont LibFont

View File

@ -0,0 +1,12 @@
# Translation units compiled into the LibDEFLATE library.
set(LIBDEFLATE_SOURCES
Compressor.cpp
Decompressor.cpp
HuffmanTree.cpp
)
# Library target built from the sources listed above.
add_library(libdeflate ${LIBDEFLATE_SOURCES})
# banan_link_library / banan_install_headers are helpers defined elsewhere in
# the build system (not visible here); presumably they link the target against
# another library and install its public headers.
banan_link_library(libdeflate ban)
banan_link_library(libdeflate libc)
banan_install_headers(libdeflate)
install(TARGETS libdeflate OPTIONAL)

View File

@ -0,0 +1,620 @@
#include <LibDEFLATE/Compressor.h>
#include <LibDEFLATE/Utils.h>
#include <BAN/Array.h>
#include <BAN/Heap.h>
#include <BAN/Optional.h>
#include <BAN/Sort.h>
namespace LibDEFLATE
{
// Longest match length that get_len_encoding() accepts.
constexpr size_t s_max_length = 258;
// Largest back-reference distance that get_dist_encoding() accepts.
constexpr size_t s_max_distance = 32768;
// Upper bound on the alphabet size handled by create_huffman_tree().
constexpr size_t s_max_symbols = 288;
// Longest code length create_huffman_tree() will assign to a symbol.
constexpr uint8_t s_max_bits = 15;
// Huffman code assigned to one symbol by create_huffman_tree().
struct Leaf
{
// Code value; written to the stream bit-reversed by compress_block().
uint16_t code;
// Number of bits in `code`; 0 marks a symbol without a code.
uint8_t length;
};
// Builds a canonical Huffman code for the alphabet described by `freq`.
//
// freq:   occurrence count per symbol; symbols with a count of 0 get no code
// output: receives one Leaf per symbol (same size as `freq`); unused symbols
//         are written as { code 0, length 0 }
//
// Code lengths are capped at s_max_bits. Returns ENOMEM if a tree node or the
// temporary vectors cannot be allocated.
static BAN::ErrorOr<void> create_huffman_tree(BAN::Span<const size_t> freq, BAN::Span<Leaf> output)
{
	ASSERT(freq.size() <= s_max_symbols);
	ASSERT(freq.size() == output.size());

	struct node_t
	{
		size_t symbol;
		size_t freq;
		node_t* left;
		node_t* right;
	};

#if LIBDEFLATE_AVOID_STACK
	BAN::Vector<node_t*> nodes;
	TRY(nodes.resize(s_max_symbols));
#else
	BAN::Array<node_t*, s_max_symbols> nodes;
#endif

	// Create one leaf node for every symbol that actually occurs.
	size_t node_count = 0;
	for (size_t sym = 0; sym < freq.size(); sym++)
	{
		if (freq[sym] == 0)
			continue;
		nodes[node_count] = static_cast<node_t*>(BAN::allocator(sizeof(node_t)));
		if (nodes[node_count] == nullptr)
		{
			for (size_t j = 0; j < node_count; j++)
				BAN::deallocator(nodes[j]);
			return BAN::Error::from_errno(ENOMEM);
		}
		*nodes[node_count++] = {
			.symbol = sym,
			.freq = freq[sym],
			.left = nullptr,
			.right = nullptr,
		};
	}

	for (auto& symbol : output)
		symbol = { .code = 0, .length = 0 };

	// No symbol is used: still hand out a single 1-bit code for symbol 0 so
	// the resulting tree is never completely empty.
	if (node_count == 0)
	{
		output[0] = { .code = 0, .length = 1 };
		return {};
	}

	// Exactly one symbol is used: the tree would consist of a lone root leaf
	// at depth 0, which would give the symbol a zero-length code (i.e. mark it
	// as unused). DEFLATE encodes a single code with one bit, not zero bits.
	if (node_count == 1)
	{
		output[nodes[0]->symbol] = { .code = 0, .length = 1 };
		BAN::deallocator(nodes[0]);
		return {};
	}

	// Recursively frees a (sub)tree. Declared static so the capture-less
	// lambda can refer to itself through the function pointer.
	static void (*free_tree)(node_t*) =
		[](node_t* root) -> void {
			if (root == nullptr)
				return;
			free_tree(root->left);
			free_tree(root->right);
			BAN::deallocator(root);
		};

	// Orders the heap so that the node with the smallest frequency (ties
	// broken by smallest symbol) is at the front.
	const auto comp =
		[](const node_t* a, const node_t* b) -> bool {
			if (a->freq != b->freq)
				return a->freq > b->freq;
			return a->symbol > b->symbol;
		};

	auto end_it = nodes.begin() + node_count;
	BAN::make_heap(nodes.begin(), end_it, comp);

	// Repeatedly merge the two least frequent nodes until one root remains.
	while (nodes.begin() + 1 != end_it)
	{
		node_t* parent = static_cast<node_t*>(BAN::allocator(sizeof(node_t)));
		if (parent == nullptr)
		{
			for (auto it = nodes.begin(); it != end_it; it++)
				free_tree(*it);
			return BAN::Error::from_errno(ENOMEM);
		}

		node_t* node1 = nodes.front();
		BAN::pop_heap(nodes.begin(), end_it--, comp);

		node_t* node2 = nodes.front();
		BAN::pop_heap(nodes.begin(), end_it--, comp);

		*parent = {
			.symbol = 0,
			.freq = node1->freq + node2->freq,
			.left = node1,
			.right = node2,
		};

		*end_it++ = parent;
		BAN::push_heap(nodes.begin(), end_it, comp);
	}

	// Stores each leaf's depth (clamped to s_max_bits) as its code length and
	// returns how many visited nodes were deeper than s_max_bits.
	static uint16_t (*gather_lengths)(const node_t*, BAN::Span<Leaf>, uint16_t) =
		[](const node_t* node, BAN::Span<Leaf> symbols, uint16_t depth) -> uint16_t {
			if (node == nullptr)
				return 0;
			uint16_t count = (depth > s_max_bits);
			if (node->left == nullptr && node->right == nullptr)
				symbols[node->symbol].length = BAN::Math::min<uint16_t>(depth, s_max_bits);
			else
			{
				count += gather_lengths(node->left, symbols, depth + 1);
				count += gather_lengths(node->right, symbols, depth + 1);
			}
			return count;
		};

	const auto too_long_count = gather_lengths(nodes[0], output, 0);
	free_tree(nodes[0]);

	// Number of symbols per code length.
	uint16_t bl_count[s_max_bits + 1] {};
	for (size_t sym = 0; sym < freq.size(); sym++)
		if (const uint8_t len = output[sym].length)
			bl_count[len]++;

	if (too_long_count > 0)
	{
		// Some codes were clamped to s_max_bits. Rebalance the length
		// histogram by moving a shorter code one level down, which frees room
		// for codes at the maximum length.
		for (size_t i = 0; i < too_long_count / 2; i++)
		{
			uint16_t bits = s_max_bits - 1;
			while (bl_count[bits] == 0)
				bits--;
			bl_count[bits + 0]--;
			bl_count[bits + 1] += 2;
			bl_count[s_max_bits]--;
		}

		struct SymFreq
		{
			size_t symbol;
			size_t freq;
		};

		BAN::Vector<SymFreq> sym_freq;
		for (size_t sym = 0; sym < output.size(); sym++)
			if (freq[sym] != 0)
				TRY(sym_freq.push_back({ .symbol = sym, .freq = freq[sym] }));

		BAN::sort::sort(sym_freq.begin(), sym_freq.end(),
			[](auto a, auto b) { return a.freq < b.freq; }
		);

		// Hand the longest codes to the least frequent symbols.
		size_t index = 0;
		for (uint16_t bits = s_max_bits; bits > 0; bits--)
			for (size_t i = 0; i < bl_count[bits]; i++)
				output[sym_freq[index++].symbol].length = bits;
		ASSERT(index == sym_freq.size());
	}

	// Canonical code assignment: first code of each length, then consecutive
	// codes in symbol order.
	uint16_t next_code[s_max_bits + 1] {};
	uint16_t code = 0;
	for (uint8_t bits = 1; bits <= s_max_bits; bits++)
	{
		code = (code + bl_count[bits - 1]) << 1;
		next_code[bits] = code;
	}

	for (size_t sym = 0; sym < freq.size(); sym++)
		if (const uint16_t len = output[sym].length)
			output[sym].code = next_code[len]++;

	return {};
}
// A symbol together with the extra bits that follow it in the stream.
struct Encoding
{
// Alphabet symbol to emit.
uint16_t symbol;
// Value of the extra bits written after the symbol's code.
uint16_t extra_data { 0 };
// Number of extra bits; 0 means the symbol has none.
uint8_t extra_len { 0 };
};
// Maps a match length (3..s_max_length) to its length symbol (257 + table
// index) plus the extra bits that select the exact length within the
// symbol's range.
static constexpr Encoding get_len_encoding(uint16_t length)
{
	ASSERT(3 <= length && length <= s_max_length);

	constexpr uint16_t base[] {
		3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 17, 19, 23, 27, 31, 35, 43, 51, 59, 67, 83, 99, 115, 131, 163, 195, 227, 258
	};
	constexpr uint8_t extra_bits[] {
		0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 0
	};
	constexpr size_t count = sizeof(base) / sizeof(*base);

	// Advance to the last table entry whose base does not exceed `length`.
	size_t index = 0;
	while (index + 1 < count && length >= base[index + 1])
		index++;

	return {
		.symbol = static_cast<uint16_t>(257 + index),
		.extra_data = static_cast<uint16_t>(length - base[index]),
		.extra_len = extra_bits[index],
	};
}
// Maps a back-reference distance (1..s_max_distance) to its distance symbol
// (the table index) plus the extra bits that select the exact distance within
// the symbol's range.
static constexpr Encoding get_dist_encoding(uint16_t distance)
{
	ASSERT(1 <= distance && distance <= s_max_distance);

	constexpr uint16_t base[] {
		1, 2, 3, 4, 5, 7, 9, 13, 17, 25, 33, 49, 65, 97, 129, 193, 257, 385, 513, 769, 1025, 1537, 2049, 3073, 4097, 6145, 8193, 12289, 16385, 24577
	};
	constexpr uint8_t extra_bits[] {
		0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13
	};
	constexpr size_t count = sizeof(base) / sizeof(*base);

	// Advance to the last table entry whose base does not exceed `distance`.
	size_t index = 0;
	while (index + 1 < count && distance >= base[index + 1])
		index++;

	return {
		.symbol = static_cast<uint16_t>(index),
		.extra_data = static_cast<uint16_t>(distance - base[index]),
		.extra_len = extra_bits[index],
	};
}
// Counts how often each literal/length symbol and each distance symbol is
// needed to encode `entries`. The counters are incremented, not reset, so the
// caller passes them in zeroed. Symbol 256 is always counted once.
static void get_frequencies(BAN::Span<const Compressor::LZ77Entry> entries, BAN::Span<size_t> lit_len_freq, BAN::Span<size_t> dist_freq)
{
	ASSERT(lit_len_freq.size() == 286);
	ASSERT(dist_freq.size() == 30);

	using Type = Compressor::LZ77Entry::Type;

	for (const auto& item : entries)
	{
		if (item.type == Type::Literal)
		{
			lit_len_freq[item.as.literal]++;
		}
		else if (item.type == Type::DistLength)
		{
			const auto length_symbol = get_len_encoding(item.as.dist_length.length).symbol;
			const auto distance_symbol = get_dist_encoding(item.as.dist_length.distance).symbol;
			lit_len_freq[length_symbol]++;
			dist_freq[distance_symbol]++;
		}
	}

	lit_len_freq[256]++;
}
// Everything compress_block() needs to write the header of a dynamic block.
struct CodeLengthInfo
{
// Number of literal/length code lengths that are transmitted.
uint16_t hlit;
// Number of distance code lengths that are transmitted.
uint8_t hdist;
// Number of entries of `code_length` that are transmitted.
uint8_t hclen;
// Run-length encoded code lengths of both trees (symbols 0..18 + extra bits).
BAN::Vector<Encoding> encoding;
// Code lengths of the code-length alphabet, in transmission order.
BAN::Array<uint8_t, 19> code_length;
// Huffman codes of the code-length alphabet, indexed by symbol.
BAN::Array<Leaf, 19> code_length_tree;
};
// Run-length encodes the code lengths of the literal/length and distance trees
// into the 19-symbol code-length alphabet and builds the Huffman code for that
// alphabet. Returns an error if a vector allocation or the tree construction
// fails.
static BAN::ErrorOr<CodeLengthInfo> build_code_length_info(BAN::Span<const Leaf> lit_len_tree, BAN::Span<const Leaf> dist_tree)
{
CodeLengthInfo result;
// Appends the run-length encoding of one tree to result.encoding.
// `tree` is taken by reference: trailing zero lengths are sliced off and the
// caller reads the shortened size afterwards.
const auto append_tree =
[&result](BAN::Span<const Leaf>& tree) -> BAN::ErrorOr<void>
{
while (!tree.empty() && tree[tree.size() - 1].length == 0)
tree = tree.slice(0, tree.size() - 1);
for (size_t i = 0; i < tree.size();)
{
// Length of the run of identical code lengths starting at i.
size_t count = 1;
while (i + count < tree.size() && tree[i].length == tree[i + count].length)
count++;
if (tree[i].length == 0)
{
// Zero runs: symbol 17 covers 3..10 zeros, symbol 18 covers 11..138;
// shorter runs are emitted as individual zeros.
if (count > 138)
count = 138;
if (count < 3)
{
for (size_t j = 0; j < count; j++)
TRY(result.encoding.push_back({ .symbol = 0 }));
}
else if (count < 11)
{
TRY(result.encoding.push_back({
.symbol = 17,
.extra_data = static_cast<uint8_t>(count - 3),
.extra_len = 3,
}));
}
else
{
TRY(result.encoding.push_back({
.symbol = 18,
.extra_data = static_cast<uint8_t>(count - 11),
.extra_len = 7,
}));
}
}
else
{
// Symbol 16 repeats the previous length 3..6 times; it is only used when
// the entry emitted just before is the same length as a plain symbol.
if (count >= 3 && !result.encoding.empty() && result.encoding.back().symbol == tree[i].length)
{
if (count > 6)
count = 6;
TRY(result.encoding.push_back({
.symbol = 16,
.extra_data = static_cast<uint8_t>(count - 3),
.extra_len = 2,
}));
}
else
{
// Emit the length itself and consume only one entry of the run.
count = 1;
TRY(result.encoding.push_back({ .symbol = tree[i].length }));
}
}
i += count;
}
return {};
};
TRY(append_tree(lit_len_tree));
result.hlit = lit_len_tree.size();
TRY(append_tree(dist_tree));
result.hdist = dist_tree.size();
// Frequencies of the code-length symbols used above.
BAN::Array<size_t, 19> code_len_freq(0);
for (auto entry : result.encoding)
code_len_freq[entry.symbol]++;
// NOTE(review): compress_block() writes each of these code lengths with 3 bits,
// but create_huffman_tree() only limits lengths to s_max_bits (15). A length
// above 7 would be truncated there -- confirm whether that can occur.
TRY(create_huffman_tree(code_len_freq.span(), result.code_length_tree.span()));
// Order in which the code-length code lengths are transmitted.
constexpr uint8_t code_length_order[] {
16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15
};
for (size_t i = 0; i < result.code_length_tree.size(); i++)
result.code_length[i] = result.code_length_tree[code_length_order[i]].length;
// Trailing zero entries are not transmitted, but at least 4 always are.
result.hclen = 19;
while (result.hclen > 4 && result.code_length[result.hclen - 1] == 0)
result.hclen--;
return BAN::move(result);
}
// Packs the first three bytes of `needle` into a 24-bit key used to index the
// hash chains (byte 0 in the lowest bits).
uint32_t Compressor::get_hash_key(BAN::ConstByteSpan needle) const
{
	ASSERT(needle.size() >= 3);

	const uint32_t low = needle[0];
	const uint32_t mid = needle[1];
	const uint32_t high = needle[2];
	return low | (mid << 8) | (high << 16);
}
// Registers the next `count` positions of m_data (starting at
// m_hash_chain_index) in the hash chains so later matches can refer to them,
// then advances m_hash_chain_index by `count`.
// Returns an error if a list or map insertion fails.
BAN::ErrorOr<void> Compressor::update_hash_chain(size_t count)
{
// Once the map holds many keys, drop chain entries that are too far behind
// the position reached after this update to be referenced any more.
if (m_hash_chain.size() >= s_max_distance * 2)
{
const uint8_t* current = m_data.data() + m_hash_chain_index + count;
for (auto& [_, chain] : m_hash_chain)
{
for (auto it = chain.begin(); it != chain.end(); it++)
{
const size_t distance = current - it->data();
if (distance < s_max_distance)
continue;
// New entries are inserted at the front (see below), so everything from
// the first out-of-range entry onwards is removed.
while (it != chain.end())
it = chain.remove(it);
break;
}
}
}
for (size_t i = 0; i < count; i++)
{
auto slice = m_data.slice(m_hash_chain_index + i);
// Fewer than 3 bytes left: no key can be formed for this or later positions.
if (slice.size() < 3)
break;
const uint32_t key = get_hash_key(slice);
auto it = m_hash_chain.find(key);
if (it != m_hash_chain.end())
// Insert at the front so each chain is ordered newest first.
TRY(it->value.insert(it->value.begin(), slice));
else
{
HashChain new_chain;
TRY(new_chain.push_back(slice));
TRY(m_hash_chain.insert(key, BAN::move(new_chain)));
}
}
m_hash_chain_index += count;
return {};
}
// Finds the longest earlier occurrence of the data at the start of `needle`.
//
// needle: the not yet encoded data; must contain at least one byte
//
// Returns a DistLength entry for the longest match of at least 3 bytes found
// in the hash chain (the first one found wins on ties), or a Literal entry
// for needle[0] when there is no such match.
BAN::ErrorOr<Compressor::LZ77Entry> Compressor::find_longest_match(BAN::ConstByteSpan needle) const
{
	LZ77Entry result = {
		.type = LZ77Entry::Type::Literal,
		.as = { .literal = needle[0] }
	};

	if (needle.size() < 3)
		return result;

	const uint32_t key = get_hash_key(needle);
	auto it = m_hash_chain.find(key);
	if (it == m_hash_chain.end())
		return result;

	// Loop invariant, so computed once instead of per chain entry.
	const size_t max_length = BAN::Math::min(needle.size(), s_max_length);

	auto& chain = it->value;
	for (const auto node : chain)
	{
		const size_t distance = needle.data() - node.data();
		if (distance > s_max_distance)
			break;

		// The key covers the first three bytes, so those already match.
		size_t length = 3;
		while (length < max_length && needle[length] == node[length])
			length++;

		if (result.type != LZ77Entry::Type::DistLength || length > result.as.dist_length.length)
		{
			result = LZ77Entry {
				.type = LZ77Entry::Type::DistLength,
				.as = {
					.dist_length = {
						.length = static_cast<uint16_t>(length),
						.distance = static_cast<uint16_t>(distance),
					}
				}
			};

			// No later entry can beat a maximal match; stop walking the chain.
			// This keeps highly repetitive input from comparing max_length
			// bytes against every entry of a very long chain.
			if (length == max_length)
				break;
		}
	}

	return result;
}
// Converts `data` into a sequence of literals and (length, distance)
// back-references. A match is deferred by one byte when the match starting at
// the following byte is longer.
BAN::ErrorOr<BAN::Vector<Compressor::LZ77Entry>> Compressor::lz77_compress(BAN::ConstByteSpan data)
{
	BAN::Vector<LZ77Entry> entries;

	size_t position = 0;
	// Bytes consumed by the previous iteration; they are added to the hash
	// chain before the next search.
	size_t consumed = 0;

	while (position < data.size())
	{
		TRY(update_hash_chain(consumed));

		const auto current = TRY(find_longest_match(data.slice(position)));
		if (current.type == LZ77Entry::Type::Literal)
		{
			TRY(entries.push_back(current));
			consumed = 1;
		}
		else
		{
			ASSERT(current.type == LZ77Entry::Type::DistLength);

			const auto next = TRY(find_longest_match(data.slice(position + 1)));
			const bool next_is_longer =
				next.type == LZ77Entry::Type::DistLength &&
				next.as.dist_length.length > current.as.dist_length.length;

			if (next_is_longer)
			{
				// Emit the current byte as a literal and take the longer match.
				TRY(entries.push_back({ .type = LZ77Entry::Type::Literal, .as = { .literal = data[position] }}));
				TRY(entries.push_back(next));
				consumed = 1 + next.as.dist_length.length;
			}
			else
			{
				TRY(entries.push_back(current));
				consumed = current.as.dist_length.length;
			}
		}

		position += consumed;
	}

	return entries;
}
// Encodes `data` as one dynamic-Huffman block and appends it to m_stream.
// `final` is written as the block's first bit.
BAN::ErrorOr<void> Compressor::compress_block(BAN::ConstByteSpan data, bool final)
{
	// FIXME: use fixed trees or uncompressed blocks

	auto lz77_entries = TRY(lz77_compress(data));

#if LIBDEFLATE_AVOID_STACK
	BAN::Vector<size_t> lit_len_freq, dist_freq;
	TRY(lit_len_freq.resize(286, 0));
	TRY(dist_freq.resize(30, 0));
#else
	BAN::Array<size_t, 286> lit_len_freq(0);
	BAN::Array<size_t, 30> dist_freq(0);
#endif

	get_frequencies(lz77_entries.span(), lit_len_freq.span(), dist_freq.span());

#if LIBDEFLATE_AVOID_STACK
	BAN::Vector<Leaf> lit_len_tree, dist_tree;
	TRY(lit_len_tree.resize(286));
	TRY(dist_tree.resize(30));
#else
	BAN::Array<Leaf, 286> lit_len_tree;
	BAN::Array<Leaf, 30> dist_tree;
#endif

	TRY(create_huffman_tree(lit_len_freq.span(), lit_len_tree.span()));
	TRY(create_huffman_tree(dist_freq.span(), dist_tree.span()));

	auto info = TRY(build_code_length_info(lit_len_tree.span(), dist_tree.span()));

	// Writes one Huffman code; the code is bit-reversed before being handed
	// to the bit stream.
	const auto write_code =
		[this](const Leaf& leaf) -> BAN::ErrorOr<void>
		{
			return m_stream.write_bits(reverse_bits(leaf.code, leaf.length), leaf.length);
		};

	// Block header: final flag, block type 2, then the three counts.
	TRY(m_stream.write_bits(final, 1));
	TRY(m_stream.write_bits(2, 2));
	TRY(m_stream.write_bits(info.hlit - 257, 5));
	TRY(m_stream.write_bits(info.hdist - 1, 5));
	TRY(m_stream.write_bits(info.hclen - 4, 4));

	for (size_t index = 0; index < info.hclen; index++)
		TRY(m_stream.write_bits(info.code_length[index], 3));

	// Run-length encoded code lengths of both trees.
	for (const auto& item : info.encoding)
	{
		TRY(write_code(info.code_length_tree[item.symbol]));
		TRY(m_stream.write_bits(item.extra_data, item.extra_len));
	}

	// Block payload.
	for (const auto& item : lz77_entries)
	{
		switch (item.type)
		{
			case LZ77Entry::Type::Literal:
				TRY(write_code(lit_len_tree[item.as.literal]));
				break;
			case LZ77Entry::Type::DistLength:
			{
				const auto length_info = get_len_encoding(item.as.dist_length.length);
				TRY(write_code(lit_len_tree[length_info.symbol]));
				TRY(m_stream.write_bits(length_info.extra_data, length_info.extra_len));

				const auto distance_info = get_dist_encoding(item.as.dist_length.distance);
				TRY(write_code(dist_tree[distance_info.symbol]));
				TRY(m_stream.write_bits(distance_info.extra_data, distance_info.extra_len));
				break;
			}
		}
	}

	// End-of-block symbol.
	TRY(write_code(lit_len_tree[256]));

	return {};
}
// Compresses the data given to the constructor and returns the finished
// stream. For StreamType::Zlib the stream is wrapped in a two-byte header and
// a big-endian Adler-32 trailer; for StreamType::Raw it is emitted bare.
// Returns an error if any allocation made while encoding fails.
BAN::ErrorOr<BAN::Vector<uint8_t>> Compressor::compress()
{
	uint32_t checksum = 0;
	switch (m_type)
	{
		case StreamType::Raw:
			break;
		case StreamType::Zlib:
			TRY(m_stream.write_bits(0x78, 8)); // deflate with 32k window
			TRY(m_stream.write_bits(0x9C, 8)); // default compression
			checksum = calculate_adler32(m_data);
			break;
	}

	if (m_data.empty())
	{
		// A DEFLATE stream needs at least one block marked final. For empty
		// input emit a final fixed-Huffman block holding only end-of-block:
		// BFINAL = 1, BTYPE = 01, then the 7-bit all-zero code of symbol 256.
		TRY(m_stream.write_bits(1, 1));
		TRY(m_stream.write_bits(1, 2));
		TRY(m_stream.write_bits(0, 7));
	}

	// m_data is deliberately left untouched while iterating:
	// update_hash_chain() indexes m_data with m_hash_chain_index, which keeps
	// counting across blocks, so the span must keep its original start.
	constexpr size_t max_block_size = 16 * 1024;
	for (size_t offset = 0; offset < m_data.size();)
	{
		const size_t block_size = BAN::Math::min<size_t>(m_data.size() - offset, max_block_size);
		const bool is_final = (offset + block_size == m_data.size());
		TRY(compress_block(m_data.slice(offset, block_size), is_final));
		offset += block_size;
	}

	TRY(m_stream.pad_to_byte_boundary());

	switch (m_type)
	{
		case StreamType::Raw:
			break;
		case StreamType::Zlib:
			// Checksum is stored most significant byte first.
			TRY(m_stream.write_bits(checksum >> 24, 8));
			TRY(m_stream.write_bits(checksum >> 16, 8));
			TRY(m_stream.write_bits(checksum >> 8, 8));
			TRY(m_stream.write_bits(checksum >> 0, 8));
			break;
	}

	return m_stream.take_buffer();
}
}

View File

@ -0,0 +1,277 @@
#include <LibDEFLATE/Decompressor.h>
#include <LibDEFLATE/Utils.h>
namespace LibDEFLATE
{
// Two-byte zlib stream header, accessible either as raw bytes or as
// bit-fields. handle_header() fills raw1/raw2 and then reads cm and fdict.
// NOTE(review): the bit-field view relies on the compiler allocating
// bit-fields starting from the least significant bit -- confirm for all
// supported targets.
union ZLibHeader
{
struct
{
// First byte: handle_header() requires cm == 8.
uint8_t cm : 4;
uint8_t cinfo : 4;
// Second byte: when fdict is set, handle_header() skips 32 more bits.
uint8_t fcheck : 5;
uint8_t fdict : 1;
uint8_t flevel : 2;
};
struct
{
uint8_t raw1;
uint8_t raw2;
};
};
// Decodes the next symbol from the stream using `tree`.
// Codes of up to tree.instant_bits() bits are resolved with a single table
// lookup; longer codes are resolved one bit at a time.
// Returns EINVAL if no code of up to tree.max_bits() bits matches, or the
// stream's error if it runs out of data.
BAN::ErrorOr<uint16_t> Decompressor::read_symbol(const HuffmanTree& tree)
{
	const uint8_t fast_bits = tree.instant_bits();

	const uint16_t peeked = TRY(m_stream.peek_bits(fast_bits));
	uint16_t code = reverse_bits(peeked, fast_bits);

	auto fast_hit = tree.get_symbol_instant(code);
	if (fast_hit.has_value())
	{
		// Only consume as many bits as the matched code actually has.
		MUST(m_stream.take_bits(fast_hit->len));
		return fast_hit->symbol;
	}

	MUST(m_stream.take_bits(fast_bits));

	for (uint8_t len = fast_bits + 1; len <= tree.max_bits(); len++)
	{
		code = (code << 1) | TRY(m_stream.take_bits(1));
		auto slow_hit = tree.get_symbol(code, len);
		if (slow_hit.has_value())
			return slow_hit.value();
	}

	return BAN::Error::from_errno(EINVAL);
}
// Decodes the payload of one Huffman-coded block into m_output, stopping at
// the end-of-block symbol (256).
//
// length_tree:   tree for literal/length symbols
// distance_tree: tree for distance symbols; when it is empty(), distance codes
//                are read as plain 5-bit values instead
//
// Returns EINVAL for invalid symbols or for a back-reference that points
// before the start of the output, and propagates stream/allocation errors.
BAN::ErrorOr<void> Decompressor::inflate_block(const HuffmanTree& length_tree, const HuffmanTree& distance_tree)
{
	constexpr uint16_t length_base[] {
		3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 17, 19, 23, 27, 31, 35, 43, 51, 59, 67, 83, 99, 115, 131, 163, 195, 227, 258
	};
	constexpr uint8_t length_extra_bits[] {
		0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 0
	};
	constexpr uint16_t distance_base[] {
		1, 2, 3, 4, 5, 7, 9, 13, 17, 25, 33, 49, 65, 97, 129, 193, 257, 385, 513, 769, 1025, 1537, 2049, 3073, 4097, 6145, 8193, 12289, 16385, 24577
	};
	constexpr uint8_t distance_extra_bits[] {
		0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13
	};

	uint16_t symbol;
	while ((symbol = TRY(read_symbol(length_tree))) != 256)
	{
		if (symbol < 256)
		{
			TRY(m_output.push_back(symbol));
			continue;
		}

		if (symbol > 285)
			return BAN::Error::from_errno(EINVAL);
		symbol -= 257;

		const uint16_t length = length_base[symbol] + TRY(m_stream.take_bits(length_extra_bits[symbol]));

		uint16_t distance_code;
		if (distance_tree.empty())
			distance_code = reverse_bits(TRY(m_stream.take_bits(5)), 5);
		else
			distance_code = TRY(read_symbol(distance_tree));
		if (distance_code > 29)
			return BAN::Error::from_errno(EINVAL);

		const uint16_t distance = distance_base[distance_code] + TRY(m_stream.take_bits(distance_extra_bits[distance_code]));

		const size_t orig_size = m_output.size();

		// The input is untrusted: a distance reaching before the first output
		// byte would underflow `offset` and read outside of m_output.
		if (distance > orig_size)
			return BAN::Error::from_errno(EINVAL);

		const size_t offset = orig_size - distance;
		TRY(m_output.resize(orig_size + length));

		// Copy byte by byte in ascending order: source and destination may
		// overlap when length > distance, which repeats the copied pattern.
		for (size_t i = 0; i < length; i++)
			m_output[orig_size + i] = m_output[offset + i];
	}

	return {};
}
// Consumes and validates the container header that precedes the compressed
// data. Raw streams have none; zlib streams carry a two-byte header,
// optionally followed by 32 more bits that are skipped when fdict is set.
// Returns EINVAL when the header check or the compression method is wrong.
BAN::ErrorOr<void> Decompressor::handle_header()
{
	if (m_type == StreamType::Raw)
		return {};
	if (m_type != StreamType::Zlib)
		ASSERT_NOT_REACHED();

	ZLibHeader header;
	header.raw1 = TRY(m_stream.take_bits(8));
	header.raw2 = TRY(m_stream.take_bits(8));

	// Both bytes read as one big-endian 16-bit value must be a multiple of 31.
	const uint16_t check_value = (header.raw1 << 8) | header.raw2;
	if (check_value % 31 != 0)
	{
		dwarnln("zlib header checksum failed");
		return BAN::Error::from_errno(EINVAL);
	}

	if (header.cm != 8)
	{
		dwarnln("zlib does not use DEFLATE");
		return BAN::Error::from_errno(EINVAL);
	}

	if (header.fdict)
	{
		TRY(m_stream.take_bits(16));
		TRY(m_stream.take_bits(16));
	}

	return {};
}
// Consumes and verifies the container trailer that follows the compressed
// data. Raw streams have none; zlib streams end with a big-endian Adler-32
// of the decompressed output, stored at the next byte boundary.
// Returns EINVAL when the checksum does not match.
BAN::ErrorOr<void> Decompressor::handle_footer()
{
	if (m_type == StreamType::Raw)
		return {};
	if (m_type != StreamType::Zlib)
		ASSERT_NOT_REACHED();

	m_stream.skip_to_byte_boundary();

	// Assemble the stored checksum, most significant byte first.
	uint32_t stored_checksum = 0;
	for (size_t remaining = 4; remaining > 0; remaining--)
	{
		const uint32_t next_byte = TRY(m_stream.take_bits(8));
		stored_checksum = (stored_checksum << 8) | next_byte;
	}

	const uint32_t actual_checksum = calculate_adler32(m_output.span());
	if (stored_checksum != actual_checksum)
	{
		dwarnln("zlib final adler32 checksum failed");
		return BAN::Error::from_errno(EINVAL);
	}

	return {};
}
// Decodes a stored (uncompressed) block: after skipping to the next byte
// boundary it reads LEN and its one's complement NLEN, then copies LEN bytes
// verbatim to m_output.
// Returns EINVAL if LEN and NLEN disagree or the input is shorter than LEN.
BAN::ErrorOr<void> Decompressor::decompress_type0()
{
	m_stream.skip_to_byte_boundary();

	const uint16_t len = TRY(m_stream.take_bits(16));
	const uint16_t nlen = TRY(m_stream.take_bits(16));
	if (len != 0xFFFF - nlen)
		return BAN::Error::from_errno(EINVAL);

	// An empty stored block is valid and carries no payload. Return before
	// taking the address of m_output[orig_size], which would index one past
	// the end of the vector when nothing is appended.
	if (len == 0)
		return {};

	const size_t orig_size = m_output.size();
	TRY(m_output.resize(orig_size + len));
	TRY(m_stream.take_byte_aligned(&m_output[orig_size], len));

	return {};
}
// Decodes a block that uses the fixed Huffman tree. The tree is built on
// first use and cached in m_fixed_tree.
BAN::ErrorOr<void> Decompressor::decompress_type1()
{
	if (!m_fixed_tree.has_value())
		m_fixed_tree = TRY(HuffmanTree::fixed_tree());

	// A default-constructed (empty) tree makes inflate_block() read distance
	// codes as plain 5-bit values.
	const HuffmanTree no_distance_tree;
	TRY(inflate_block(m_fixed_tree.value(), no_distance_tree));

	return {};
}
// Decodes a block that carries its own Huffman trees: reads the code-length
// code, uses it to read the run-length encoded code lengths of the
// literal/length and distance trees, then inflates the payload.
// Returns EINVAL for malformed headers and propagates stream/allocation
// errors.
BAN::ErrorOr<void> Decompressor::decompress_type2()
{
	// Order in which the code-length code lengths are stored in the stream.
	constexpr uint8_t code_length_order[] {
		16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15
	};

	const uint16_t hlit  = TRY(m_stream.take_bits(5)) + 257;
	const uint8_t  hdist = TRY(m_stream.take_bits(5)) + 1;
	const uint8_t  hclen = TRY(m_stream.take_bits(4)) + 4;

	uint8_t code_lengths[19] {};
	for (size_t i = 0; i < hclen; i++)
		code_lengths[code_length_order[i]] = TRY(m_stream.take_bits(3));
	const auto code_length_tree = TRY(HuffmanTree::create({ code_lengths, 19 }));

	// Sized for the largest values the 5-bit fields can encode
	// (hlit <= 31 + 257 = 288, hdist <= 31 + 1 = 32) so that no header can
	// make the loop below write past the end of the array.
	uint8_t bit_lengths[288 + 32] {};
	size_t bit_lengths_len = 0;

	uint16_t last_symbol = 0;
	while (bit_lengths_len < hlit + hdist)
	{
		uint16_t symbol = TRY(read_symbol(code_length_tree));
		if (symbol > 18)
			return BAN::Error::from_errno(EINVAL);

		uint8_t count;
		if (symbol <= 15)
		{
			// Literal code length.
			count = 1;
		}
		else if (symbol == 16)
		{
			// Repeat the previous code length 3..6 times.
			symbol = last_symbol;
			count = TRY(m_stream.take_bits(2)) + 3;
		}
		else if (symbol == 17)
		{
			// 3..10 zero lengths.
			symbol = 0;
			count = TRY(m_stream.take_bits(3)) + 3;
		}
		else
		{
			// 11..138 zero lengths.
			symbol = 0;
			count = TRY(m_stream.take_bits(7)) + 11;
		}

		// The input is untrusted: a run that extends past the announced
		// number of code lengths is a malformed stream, not a program bug.
		if (bit_lengths_len + count > static_cast<size_t>(hlit) + hdist)
			return BAN::Error::from_errno(EINVAL);

		for (uint8_t i = 0; i < count; i++)
			bit_lengths[bit_lengths_len++] = symbol;
		last_symbol = symbol;
	}

	TRY(inflate_block(
		TRY(HuffmanTree::create({ bit_lengths, hlit })),
		TRY(HuffmanTree::create({ bit_lengths + hlit, hdist }))
	));

	return {};
}
// Decompresses the whole input: container header, every block up to and
// including the one flagged as final, then the container trailer.
// Returns the decompressed bytes, or EINVAL for an unknown block type.
BAN::ErrorOr<BAN::Vector<uint8_t>> Decompressor::decompress()
{
	TRY(handle_header());

	bool is_last_block;
	do
	{
		// Each block starts with a 1-bit final flag and a 2-bit type.
		is_last_block = TRY(m_stream.take_bits(1));
		const uint16_t block_type = TRY(m_stream.take_bits(2));

		if (block_type == 0b00)
			TRY(decompress_type0());
		else if (block_type == 0b01)
			TRY(decompress_type1());
		else if (block_type == 0b10)
			TRY(decompress_type2());
		else
			return BAN::Error::from_errno(EINVAL);
	} while (!is_last_block);

	TRY(handle_footer());

	return BAN::move(m_output);
}
}

View File

@ -0,0 +1,141 @@
#include <LibDEFLATE/HuffmanTree.h>
namespace LibDEFLATE
{
// Move assignment: takes over the lookup tables of `other` and copies its
// bit-count fields.
HuffmanTree& HuffmanTree::operator=(HuffmanTree&& other)
{
	// Lookup tables.
	m_instant    = BAN::move(other.m_instant);
	m_min_code   = BAN::move(other.m_min_code);
	m_slow_table = BAN::move(other.m_slow_table);

	// Plain scalar state.
	m_instant_bits = other.m_instant_bits;
	m_min_bits     = other.m_min_bits;
	m_max_bits     = other.m_max_bits;

	return *this;
}
// Builds a decoding tree from per-symbol code lengths (index = symbol,
// 0 = symbol unused). Propagates any error reported by initialize().
BAN::ErrorOr<HuffmanTree> HuffmanTree::create(BAN::Span<const uint8_t> bit_lengths)
{
	HuffmanTree tree;
	TRY(tree.initialize(bit_lengths));
	return tree;
}
// Derives the canonical Huffman codes for the given code lengths and builds
// both lookup tables.
//
// bit_lengths: code length per symbol (index = symbol, 0 = symbol unused)
//
// Returns EINVAL if a length exceeds MAX_BITS or if the lengths describe more
// codes than fit into the code space, and propagates allocation errors.
BAN::ErrorOr<void> HuffmanTree::initialize(BAN::Span<const uint8_t> bit_lengths)
{
	m_max_bits = 0;
	m_min_bits = MAX_BITS;

	uint16_t max_sym = 0;
	uint16_t bl_count[MAX_BITS + 1] {};
	for (size_t sym = 0; sym < bit_lengths.size(); sym++)
	{
		const uint8_t len = bit_lengths[sym];
		if (len == 0)
			continue;
		// Lengths may come from untrusted input; a larger value would index
		// past the end of bl_count.
		if (len > MAX_BITS)
			return BAN::Error::from_errno(EINVAL);
		m_max_bits = BAN::Math::max<uint8_t>(len, m_max_bits);
		m_min_bits = BAN::Math::min<uint8_t>(len, m_min_bits);
		bl_count[len]++;
		max_sym = sym;
	}

	// Reject over-subscribed length sets: with more codes of some length than
	// there are free code values, the canonical codes below would no longer
	// fit into their bit width and the table builders would index out of
	// range. Incomplete sets (unused code values) are still accepted.
	uint32_t available = 1;
	for (uint8_t bits = 1; bits <= MAX_BITS; bits++)
	{
		available <<= 1;
		if (bl_count[bits] > available)
			return BAN::Error::from_errno(EINVAL);
		available -= bl_count[bits];
	}

	// First canonical code of every length.
	uint16_t next_code[MAX_BITS + 1] {};
	uint16_t code = 0;
	for (uint8_t bits = 1; bits <= MAX_BITS; bits++)
	{
		code = (code + bl_count[bits - 1]) << 1;
		next_code[bits] = code;
		m_min_code[bits] = code;
	}

	// Assign consecutive codes to the symbols of each length, in symbol order.
	BAN::Vector<Leaf> tree;
	TRY(tree.resize(max_sym + 1, { .code = 0, .len = 0 }));
	for (uint16_t sym = 0; sym <= max_sym; sym++)
	{
		tree[sym].len = bit_lengths[sym];
		if (const uint8_t len = tree[sym].len)
			tree[sym].code = next_code[len]++;
	}

	TRY(build_instant_table(tree.span()));
	TRY(build_slow_table(tree.span()));

	return {};
}
// Builds the direct lookup table used for codes of at most m_instant_bits
// bits (9, or fewer when the longest code is shorter). Every table index
// whose leading bits equal a symbol's code maps to that symbol.
BAN::ErrorOr<void> HuffmanTree::build_instant_table(BAN::Span<const Leaf> tree)
{
	m_instant_bits = BAN::Math::min<uint8_t>(9, m_max_bits);
	TRY(m_instant.resize(1 << m_instant_bits, {}));

	for (uint16_t sym = 0; sym < tree.size(); sym++)
	{
		const uint8_t len = tree[sym].len;
		const bool has_short_code = (len != 0 && len <= m_instant_bits);
		if (!has_short_code)
			continue;

		// The code occupies the top `len` bits of the index; the remaining
		// `shift` low bits can take any value.
		const uint16_t code = tree[sym].code;
		const uint16_t shift = m_instant_bits - len;
		for (uint16_t slot = code << shift; slot < (code + 1) << shift; slot++)
			m_instant[slot] = { sym, len };
	}

	return {};
}
// Builds the per-length lookup table: m_slow_table[len][code - first code of
// that length] holds the symbol. Buckets grow on demand.
BAN::ErrorOr<void> HuffmanTree::build_slow_table(BAN::Span<const Leaf> tree)
{
	TRY(m_slow_table.resize(MAX_BITS + 1));

	for (uint16_t sym = 0; sym < tree.size(); sym++)
	{
		const uint8_t len = tree[sym].len;
		if (len == 0)
			continue;

		auto& bucket = m_slow_table[len];
		const size_t slot = tree[sym].code - m_min_code[len];
		if (bucket.size() <= slot)
			TRY(bucket.resize(slot + 1));
		bucket[slot] = sym;
	}

	return {};
}
// Creates the tree for the predefined 288-symbol code: symbols 0..143 use 8
// bits, 144..255 use 9, 256..279 use 7 and 280..287 use 8. The length table
// is computed at compile time.
BAN::ErrorOr<HuffmanTree> HuffmanTree::fixed_tree()
{
	struct BitLengths
	{
		consteval BitLengths()
		{
			for (size_t sym = 0; sym <= 287; sym++)
			{
				if (sym <= 143)
					values[sym] = 8;
				else if (sym <= 255)
					values[sym] = 9;
				else if (sym <= 279)
					values[sym] = 7;
				else
					values[sym] = 8;
			}
		}

		BAN::Array<uint8_t, 288> values;
	};

	static constexpr BitLengths bit_lengths;
	return TRY(HuffmanTree::create(bit_lengths.values.span()));
}
// Looks up `code` (instant_bits() wide) in the direct table. Returns the
// symbol and its real code length, or an empty optional when no short code
// matches.
BAN::Optional<HuffmanTree::Instant> HuffmanTree::get_symbol_instant(uint16_t code) const
{
	ASSERT(code < m_instant.size());

	const auto& slot = m_instant[code];
	// A length of 0 marks a slot that no code maps to.
	if (slot.len == 0)
		return {};
	return slot;
}
// Looks up a code of exactly `len` bits in the per-length table. Returns the
// symbol, or an empty optional when no symbol has that code.
BAN::Optional<uint16_t> HuffmanTree::get_symbol(uint16_t code, uint8_t len) const
{
	ASSERT(len <= m_max_bits);

	// Codes below the first code of this length wrap to a huge index and are
	// rejected by the bounds check as well.
	const size_t slot = code - m_min_code[len];
	const auto& bucket = m_slow_table[len];
	if (slot < bucket.size())
		return bucket[slot];
	return {};
}
}

View File

@ -0,0 +1,118 @@
#pragma once
#include <BAN/Vector.h>
#include <BAN/ByteSpan.h>
namespace LibDEFLATE
{
// Reads a byte span as a stream of bits, least significant bit of each byte
// first. Bytes are pulled into a small buffer on demand.
class BitInputStream
{
public:
	BitInputStream(BAN::ConstByteSpan data)
		: m_data(data)
	{ }

	// Returns the next `count` (<= 16) bits without consuming them.
	// Fails with ENOBUFS when the input holds fewer than `count` more bits.
	BAN::ErrorOr<uint16_t> peek_bits(size_t count)
	{
		ASSERT(count <= 16);

		while (m_bit_buffer_len < count)
		{
			if (m_data.empty())
				return BAN::Error::from_errno(ENOBUFS);
			m_bit_buffer |= m_data[0] << m_bit_buffer_len;
			m_bit_buffer_len += 8;
			m_data = m_data.slice(1);
		}

		return m_bit_buffer & ((1 << count) - 1);
	}

	// Returns the next `count` (<= 16) bits and consumes them.
	BAN::ErrorOr<uint16_t> take_bits(size_t count)
	{
		const uint16_t result = TRY(peek_bits(count));
		m_bit_buffer >>= count;
		m_bit_buffer_len -= count;
		return result;
	}

	// Copies `bytes` whole bytes to `output`. The stream must be positioned
	// on a byte boundary. Fails with EINVAL when the input is too short.
	BAN::ErrorOr<void> take_byte_aligned(uint8_t* output, size_t bytes)
	{
		// Alignment is a property of the number of buffered bits, not of the
		// buffered value itself.
		ASSERT(m_bit_buffer_len % 8 == 0);

		// Drain whole bytes that were already pulled into the bit buffer.
		while (m_bit_buffer_len && bytes)
		{
			*output++ = m_bit_buffer;
			m_bit_buffer >>= 8;
			m_bit_buffer_len -= 8;
			bytes--;
		}

		if (bytes > m_data.size())
			return BAN::Error::from_errno(EINVAL);
		memcpy(output, m_data.data(), bytes);
		m_data = m_data.slice(bytes);

		return {};
	}

	// Discards the buffered bits that remain of a partially consumed byte.
	void skip_to_byte_boundary()
	{
		const size_t bits_to_remove = m_bit_buffer_len % 8;
		m_bit_buffer >>= bits_to_remove;
		m_bit_buffer_len -= bits_to_remove;
	}

private:
	// Input bytes that have not been pulled into the bit buffer yet.
	BAN::ConstByteSpan m_data;
	// Buffered bits; the next bit to be read is the least significant one.
	uint32_t m_bit_buffer { 0 };
	// Number of valid bits in m_bit_buffer.
	uint8_t m_bit_buffer_len { 0 };
};
// Collects bits into a growing byte vector, filling each byte from its least
// significant bit upwards.
class BitOutputStream
{
public:
	// Appends the low `count` (<= 16) bits of `value`. Completed bytes are
	// moved to the output vector; fails if that vector cannot grow.
	BAN::ErrorOr<void> write_bits(uint16_t value, size_t count)
	{
		ASSERT(m_bit_buffer_len < 8);
		ASSERT(count <= 16);

		const uint32_t payload = value & ((1u << count) - 1);
		m_bit_buffer |= payload << m_bit_buffer_len;
		m_bit_buffer_len += count;

		// Flush every complete byte, lowest byte first.
		for (; m_bit_buffer_len >= 8; m_bit_buffer_len -= 8)
		{
			TRY(m_data.push_back(static_cast<uint8_t>(m_bit_buffer & 0xFF)));
			m_bit_buffer >>= 8;
		}

		return {};
	}

	// Flushes a partially filled byte, leaving its unused high bits zero.
	// Does nothing when the stream already sits on a byte boundary.
	BAN::ErrorOr<void> pad_to_byte_boundary()
	{
		ASSERT(m_bit_buffer_len < 8);

		if (m_bit_buffer_len != 0)
		{
			TRY(m_data.push_back(static_cast<uint8_t>(m_bit_buffer & 0xFF)));
			m_bit_buffer = 0;
			m_bit_buffer_len = 0;
		}

		return {};
	}

	// Hands the collected bytes to the caller. The stream must be on a byte
	// boundary.
	BAN::Vector<uint8_t> take_buffer()
	{
		ASSERT(m_bit_buffer_len == 0);
		return BAN::move(m_data);
	}

private:
	// Completed output bytes.
	BAN::Vector<uint8_t> m_data;
	// Bits that do not form a complete byte yet.
	uint32_t m_bit_buffer { 0 };
	// Number of valid bits in m_bit_buffer.
	uint8_t m_bit_buffer_len { 0 };
};
}

View File

@ -0,0 +1,67 @@
#pragma once
#include <BAN/ByteSpan.h>
#include <BAN/HashMap.h>
#include <BAN/LinkedList.h>
#include <BAN/NoCopyMove.h>
#include <BAN/Vector.h>
#include <LibDEFLATE/BitStream.h>
#include <LibDEFLATE/StreamType.h>
namespace LibDEFLATE
{
// Compresses a byte span into a DEFLATE stream, either bare or wrapped in the
// zlib container depending on the StreamType given at construction.
class Compressor
{
BAN_NON_COPYABLE(Compressor);
BAN_NON_MOVABLE(Compressor);
public:
// Earlier input positions that share the same 3-byte hash key.
using HashChain = BAN::LinkedList<BAN::ConstByteSpan>;
// One element of the intermediate representation: either a single literal
// byte or a (length, distance) back-reference. `type` selects which member
// of `as` is valid.
struct LZ77Entry
{
enum class Type
{
Literal,
DistLength,
} type;
union
{
uint8_t literal;
struct
{
uint16_t length;
uint16_t distance;
} dist_length;
} as;
};
public:
// `data` is only referenced, not copied (it is stored as a span).
Compressor(BAN::ConstByteSpan data, StreamType type)
: m_type(type)
, m_data(data)
{ }
// Compresses the data and returns the resulting byte stream.
BAN::ErrorOr<BAN::Vector<uint8_t>> compress();
private:
// Encodes one block; `final` marks the last block of the stream.
BAN::ErrorOr<void> compress_block(BAN::ConstByteSpan, bool final);
// Builds the hash key from the first three bytes of `needle`.
uint32_t get_hash_key(BAN::ConstByteSpan needle) const;
// Adds the next `count` input positions to the hash chains.
BAN::ErrorOr<void> update_hash_chain(size_t count);
// Returns the longest earlier match for `needle`, or a literal.
BAN::ErrorOr<LZ77Entry> find_longest_match(BAN::ConstByteSpan needle) const;
// Turns `data` into a sequence of literals and back-references.
BAN::ErrorOr<BAN::Vector<LZ77Entry>> lz77_compress(BAN::ConstByteSpan data);
private:
const StreamType m_type;
BAN::ConstByteSpan m_data;
BitOutputStream m_stream;
// Offset into m_data of the next position to add to the hash chains.
size_t m_hash_chain_index { 0 };
BAN::HashMap<uint32_t, HashChain> m_hash_chain;
};
}

View File

@ -0,0 +1,46 @@
#pragma once
#include <BAN/ByteSpan.h>
#include <BAN/NoCopyMove.h>
#include <BAN/Vector.h>
#include <LibDEFLATE/BitStream.h>
#include <LibDEFLATE/HuffmanTree.h>
#include <LibDEFLATE/StreamType.h>
namespace LibDEFLATE
{
// Decompresses a DEFLATE stream, either bare or wrapped in the zlib container
// depending on the StreamType given at construction.
class Decompressor
{
BAN_NON_COPYABLE(Decompressor);
BAN_NON_MOVABLE(Decompressor);
public:
// `data` is only referenced, not copied (it is wrapped in a BitInputStream
// over the span).
Decompressor(BAN::ConstByteSpan data, StreamType type)
: m_type(type)
, m_stream(data)
{ }
// Decompresses the whole input and returns the output bytes.
BAN::ErrorOr<BAN::Vector<uint8_t>> decompress();
private:
// Decodes the next symbol from the stream using `tree`.
BAN::ErrorOr<uint16_t> read_symbol(const HuffmanTree& tree);
// Decodes the payload of a Huffman-coded block into m_output.
BAN::ErrorOr<void> inflate_block(const HuffmanTree& length_tree, const HuffmanTree& distance_tree);
// Handlers for the three block types: stored, fixed trees, dynamic trees.
BAN::ErrorOr<void> decompress_type0();
BAN::ErrorOr<void> decompress_type1();
BAN::ErrorOr<void> decompress_type2();
// Container header/trailer handling; no-ops for StreamType::Raw.
BAN::ErrorOr<void> handle_header();
BAN::ErrorOr<void> handle_footer();
private:
const StreamType m_type;
BitInputStream m_stream;
BAN::Vector<uint8_t> m_output;
// Fixed tree, created on first use by decompress_type1().
BAN::Optional<HuffmanTree> m_fixed_tree;
};
}

View File

@ -0,0 +1,61 @@
#pragma once
#include <BAN/Array.h>
#include <BAN/NoCopyMove.h>
#include <BAN/Optional.h>
#include <BAN/Vector.h>
namespace LibDEFLATE
{
// Decoding tables for a canonical Huffman code described by per-symbol code
// lengths. Short codes are resolved through a direct table, longer codes
// through per-length tables.
class HuffmanTree
{
BAN_NON_COPYABLE(HuffmanTree);
public:
// Longest supported code length in bits.
static constexpr uint8_t MAX_BITS = 15;
// Code assigned to a symbol; len == 0 means the symbol has no code.
struct Leaf
{
uint16_t code;
uint8_t len;
};
// Direct-table entry; len == 0 means no code maps to the entry.
struct Instant
{
uint16_t symbol;
uint8_t len;
};
// A default-constructed tree has no codes and reports empty() == true.
HuffmanTree() {}
HuffmanTree(HuffmanTree&& other) { *this = BAN::move(other); }
HuffmanTree& operator=(HuffmanTree&& other);
// Builds a tree from code lengths (index = symbol, 0 = unused symbol).
static BAN::ErrorOr<HuffmanTree> create(BAN::Span<const uint8_t> bit_lengths);
// Builds the tree of the predefined fixed code.
static BAN::ErrorOr<HuffmanTree> fixed_tree();
// Direct lookup of a code that is instant_bits() wide.
BAN::Optional<Instant> get_symbol_instant(uint16_t code) const;
// Lookup of a code that is exactly `len` bits long.
BAN::Optional<uint16_t> get_symbol(uint16_t code, uint8_t len) const;
uint8_t instant_bits() const { return m_instant_bits; }
uint8_t min_bits() const { return m_min_bits; }
uint8_t max_bits() const { return m_max_bits; }
// True only while m_min_bits still has its default of 0, i.e. for a tree that
// was never initialized (initialize() sets it to MAX_BITS or less).
bool empty() const { return m_min_bits == 0; }
private:
BAN::ErrorOr<void> initialize(BAN::Span<const uint8_t> bit_lengths);
BAN::ErrorOr<void> build_instant_table(BAN::Span<const Leaf> tree);
BAN::ErrorOr<void> build_slow_table(BAN::Span<const Leaf> tree);
private:
// Width in bits of the index into m_instant.
uint8_t m_instant_bits { 0 };
// Shortest and longest code length in use.
uint8_t m_min_bits { 0 };
uint8_t m_max_bits { 0 };
// Direct table indexed by the first m_instant_bits bits of the input.
BAN::Vector<Instant> m_instant;
// First canonical code of each length.
BAN::Array<uint16_t, MAX_BITS + 1> m_min_code;
// m_slow_table[len][code - m_min_code[len]] is the symbol for that code.
BAN::Vector<BAN::Vector<uint16_t>> m_slow_table;
};
}

View File

@ -0,0 +1,12 @@
#pragma once
namespace LibDEFLATE
{
// Container format surrounding the DEFLATE data.
enum class StreamType
{
// Bare DEFLATE data without header or trailer.
Raw,
// DEFLATE data wrapped in a zlib header and Adler-32 trailer.
Zlib,
};
}

View File

@ -0,0 +1,30 @@
#pragma once
#include <BAN/ByteSpan.h>
namespace LibDEFLATE
{
// Computes the Adler-32 checksum of `data`: two running sums modulo 65521,
// returned with the second sum in the upper 16 bits.
inline uint32_t calculate_adler32(BAN::ConstByteSpan data)
{
	constexpr uint32_t modulus = 65521;

	uint32_t byte_sum = 1;
	uint32_t sum_of_sums = 0;

	const size_t length = data.size();
	for (size_t index = 0; index < length; index++)
	{
		byte_sum = (byte_sum + data[index]) % modulus;
		sum_of_sums = (sum_of_sums + byte_sum) % modulus;
	}

	return (sum_of_sums << 16) | byte_sum;
}
// Returns the lowest `count` bits of `value` in reversed order; bits above
// `count` are ignored. `count` is expected to be at most 16.
inline constexpr uint16_t reverse_bits(uint16_t value, size_t count)
{
	uint16_t result = 0;
	for (size_t remaining = count; remaining > 0; remaining--)
	{
		// Move the current lowest input bit onto the low end of the result.
		result = static_cast<uint16_t>((result << 1) | (value & 1));
		value >>= 1;
	}
	return result;
}
}