From fecda6a03483feefd78b155cfb31695406007211 Mon Sep 17 00:00:00 2001 From: Bananymous Date: Sun, 26 Oct 2025 22:25:11 +0200 Subject: [PATCH] userspace: Add LibDEFLATE This can be used to compress and decompress DEFLATE data either in raw or zlib format --- userspace/libraries/CMakeLists.txt | 1 + userspace/libraries/LibDEFLATE/CMakeLists.txt | 12 + userspace/libraries/LibDEFLATE/Compressor.cpp | 620 ++++++++++++++++++ .../libraries/LibDEFLATE/Decompressor.cpp | 277 ++++++++ .../libraries/LibDEFLATE/HuffmanTree.cpp | 141 ++++ .../LibDEFLATE/include/LibDEFLATE/BitStream.h | 118 ++++ .../include/LibDEFLATE/Compressor.h | 67 ++ .../include/LibDEFLATE/Decompressor.h | 46 ++ .../include/LibDEFLATE/HuffmanTree.h | 61 ++ .../include/LibDEFLATE/StreamType.h | 12 + .../LibDEFLATE/include/LibDEFLATE/Utils.h | 30 + 11 files changed, 1385 insertions(+) create mode 100644 userspace/libraries/LibDEFLATE/CMakeLists.txt create mode 100644 userspace/libraries/LibDEFLATE/Compressor.cpp create mode 100644 userspace/libraries/LibDEFLATE/Decompressor.cpp create mode 100644 userspace/libraries/LibDEFLATE/HuffmanTree.cpp create mode 100644 userspace/libraries/LibDEFLATE/include/LibDEFLATE/BitStream.h create mode 100644 userspace/libraries/LibDEFLATE/include/LibDEFLATE/Compressor.h create mode 100644 userspace/libraries/LibDEFLATE/include/LibDEFLATE/Decompressor.h create mode 100644 userspace/libraries/LibDEFLATE/include/LibDEFLATE/HuffmanTree.h create mode 100644 userspace/libraries/LibDEFLATE/include/LibDEFLATE/StreamType.h create mode 100644 userspace/libraries/LibDEFLATE/include/LibDEFLATE/Utils.h diff --git a/userspace/libraries/CMakeLists.txt b/userspace/libraries/CMakeLists.txt index fc205f26..e9753a61 100644 --- a/userspace/libraries/CMakeLists.txt +++ b/userspace/libraries/CMakeLists.txt @@ -1,6 +1,7 @@ set(USERSPACE_LIBRARIES LibAudio LibC + LibDEFLATE LibDL LibELF LibFont diff --git a/userspace/libraries/LibDEFLATE/CMakeLists.txt b/userspace/libraries/LibDEFLATE/CMakeLists.txt 
new file mode 100644 index 00000000..8e5b4a1a --- /dev/null +++ b/userspace/libraries/LibDEFLATE/CMakeLists.txt @@ -0,0 +1,12 @@ +set(LIBDEFLATE_SOURCES + Compressor.cpp + Decompressor.cpp + HuffmanTree.cpp +) + +add_library(libdeflate ${LIBDEFLATE_SOURCES}) +banan_link_library(libdeflate ban) +banan_link_library(libdeflate libc) + +banan_install_headers(libdeflate) +install(TARGETS libdeflate OPTIONAL) diff --git a/userspace/libraries/LibDEFLATE/Compressor.cpp b/userspace/libraries/LibDEFLATE/Compressor.cpp new file mode 100644 index 00000000..4664d041 --- /dev/null +++ b/userspace/libraries/LibDEFLATE/Compressor.cpp @@ -0,0 +1,620 @@ +#include +#include + +#include +#include +#include +#include + +namespace LibDEFLATE +{ + + /* longest match and farthest back-reference that find_longest_match() will produce */ constexpr size_t s_max_length = 258; + constexpr size_t s_max_distance = 32768; + + /* upper bounds on the symbol count and code length handled by create_huffman_tree() */ constexpr size_t s_max_symbols = 288; + constexpr uint8_t s_max_bits = 15; + + /* code and code length assigned to one symbol */ struct Leaf + { + uint16_t code; + uint8_t length; + }; + + /* Fills output[sym] with a code and code length for every symbol, derived from the occurrence counts in freq (same size as output). Fails with ENOMEM when a tree node cannot be allocated. */ static BAN::ErrorOr create_huffman_tree(BAN::Span freq, BAN::Span output) + { + ASSERT(freq.size() <= s_max_symbols); + ASSERT(freq.size() == output.size()); + + struct node_t + { + size_t symbol; + size_t freq; + node_t* left; + node_t* right; + }; + +#if LIBDEFLATE_AVOID_STACK + BAN::Vector nodes; + TRY(nodes.resize(s_max_symbols)); +#else + BAN::Array nodes; +#endif + + /* one heap-allocated leaf per symbol with a non-zero count; on failure the leaves made so far are released */ size_t node_count = 0; + for (size_t sym = 0; sym < freq.size(); sym++) + { + if (freq[sym] == 0) + continue; + nodes[node_count] = static_cast(BAN::allocator(sizeof(node_t))); + if (nodes[node_count] == nullptr) + { + for (size_t j = 0; j < node_count; j++) + BAN::deallocator(nodes[j]); + return BAN::Error::from_errno(ENOMEM); + } + *nodes[node_count++] = { + .symbol = sym, + .freq = freq[sym], + .left = nullptr, + .right = nullptr, + }; + } + + for (auto& symbol : output) + symbol = { .code = 0, .length = 0 }; + + /* no symbol is used at all: give symbol 0 a 1-bit code and stop */ if (node_count == 0) + { + output[0] = { .code = 0, .length = 1 }; + return {}; + } + + /* recursively releases a node and everything below it */ static void (*free_tree)(node_t*) = + [](node_t* 
root) -> void { + if (root == nullptr) + return; + free_tree(root->left); + free_tree(root->right); + BAN::deallocator(root); + }; + + const auto comp = + [](const node_t* a, const node_t* b) -> bool { + if (a->freq != b->freq) + return a->freq > b->freq; + return a->symbol > b->symbol; + }; + + auto end_it = nodes.begin() + node_count; + BAN::make_heap(nodes.begin(), end_it, comp); + + while (nodes.begin() + 1 != end_it) + { + node_t* parent = static_cast(BAN::allocator(sizeof(node_t))); + if (parent == nullptr) + { + for (auto it = nodes.begin(); it != end_it; it++) + free_tree(*it); + return BAN::Error::from_errno(ENOMEM); + } + + node_t* node1 = nodes.front(); + BAN::pop_heap(nodes.begin(), end_it--, comp); + + node_t* node2 = nodes.front(); + BAN::pop_heap(nodes.begin(), end_it--, comp); + + *parent = { + .symbol = 0, + .freq = node1->freq + node2->freq, + .left = node1, + .right = node2, + }; + + *end_it++ = parent; + BAN::push_heap(nodes.begin(), end_it, comp); + } + + static uint16_t (*gather_lengths)(const node_t*, BAN::Span, uint16_t) = + [](const node_t* node, BAN::Span symbols, uint16_t depth) -> uint16_t { + if (node == nullptr) + return 0; + uint16_t count = (depth > s_max_bits); + if (node->left == nullptr && node->right == nullptr) + symbols[node->symbol].length = BAN::Math::min(depth, s_max_bits); + else + { + count += gather_lengths(node->left, symbols, depth + 1); + count += gather_lengths(node->right, symbols, depth + 1); + } + return count; + }; + + const auto too_long_count = gather_lengths(nodes[0], output, 0); + free_tree(nodes[0]); + + uint16_t bl_count[s_max_bits + 1] {}; + for (size_t sym = 0; sym < freq.size(); sym++) + if (const uint8_t len = output[sym].length) + bl_count[len]++; + + if (too_long_count > 0) + { + for (size_t i = 0; i < too_long_count / 2; i++) + { + uint16_t bits = s_max_bits - 1; + while (bl_count[bits] == 0) + bits--; + bl_count[bits + 0]--; + bl_count[bits + 1] += 2; + bl_count[s_max_bits]--; + } + + struct SymFreq 
+ { + size_t symbol; + size_t freq; + }; + + /* hand the adjusted per-length counts back out: symbols sorted by ascending frequency receive the longest lengths first */ BAN::Vector sym_freq; + for (size_t sym = 0; sym < output.size(); sym++) + if (freq[sym] != 0) + TRY(sym_freq.push_back({ .symbol = sym, .freq = freq[sym] })); + + BAN::sort::sort(sym_freq.begin(), sym_freq.end(), + [](auto a, auto b) { return a.freq < b.freq; } + ); + + size_t index = 0; + for (uint16_t bits = s_max_bits; bits > 0; bits--) + for (size_t i = 0; i < bl_count[bits]; i++) + output[sym_freq[index++].symbol].length = bits; + ASSERT(index == sym_freq.size()); + } + + /* first code of every length, computed from the per-length counts */ uint16_t next_code[s_max_bits + 1] {}; + uint16_t code = 0; + for (uint8_t bits = 1; bits <= s_max_bits; bits++) + { + code = (code + bl_count[bits - 1]) << 1; + next_code[bits] = code; + } + + /* symbols of equal length get consecutive codes in symbol order */ for (size_t sym = 0; sym < freq.size(); sym++) + if (const uint16_t len = output[sym].length) + output[sym].code = next_code[len]++; + + return {}; + } + + /* a symbol together with the extra bits written right after it */ struct Encoding + { + uint16_t symbol; + uint16_t extra_data { 0 }; + uint8_t extra_len { 0 }; + }; + + /* maps a match length in 3..s_max_length to its symbol (257 + table index) and extra bits */ static constexpr Encoding get_len_encoding(uint16_t length) + { + ASSERT(3 <= length && length <= s_max_length); + + constexpr uint16_t base[] { + 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 17, 19, 23, 27, 31, 35, 43, 51, 59, 67, 83, 99, 115, 131, 163, 195, 227, 258 + }; + constexpr uint8_t extra_bits[] { + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 0 + }; + constexpr size_t count = sizeof(base) / sizeof(*base); + + /* stop at the last entry whose base is <= length */ for (size_t i = 0;; i++) + { + if (i + 1 < count && length >= base[i + 1]) + continue; + return { + .symbol = static_cast(257 + i), + .extra_data = static_cast(length - base[i]), + .extra_len = extra_bits[i], + }; + } + } + + /* maps a match distance in 1..s_max_distance to its symbol (the table index) and extra bits */ static constexpr Encoding get_dist_encoding(uint16_t distance) + { + ASSERT(1 <= distance && distance <= s_max_distance); + + constexpr uint16_t base[] { + 1, 2, 3, 4, 5, 7, 9, 13, 17, 25, 33, 49, 65, 97, 129, 193, 257, 385, 513, 769, 1025, 1537, 2049, 3073, 4097, 6145, 8193, 12289, 16385, 24577 + }; + constexpr uint8_t extra_bits[] { 
+ 0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13 + }; + constexpr size_t count = sizeof(base) / sizeof(*base); + + /* stop at the last entry whose base is <= distance */ for (size_t i = 0;; i++) + { + if (i + 1 < count && distance >= base[i + 1]) + continue; + return { + .symbol = static_cast(i), + .extra_data = static_cast(distance - base[i]), + .extra_len = extra_bits[i], + }; + } + } + + /* counts how often each literal/length symbol and each distance symbol occurs in entries; the spans must hold 286 and 30 counters */ static void get_frequencies(BAN::Span entries, BAN::Span lit_len_freq, BAN::Span dist_freq) + { + ASSERT(lit_len_freq.size() == 286); + ASSERT(dist_freq.size() == 30); + + for (auto entry : entries) + { + switch (entry.type) + { + case Compressor::LZ77Entry::Type::Literal: + lit_len_freq[entry.as.literal]++; + break; + case Compressor::LZ77Entry::Type::DistLength: + lit_len_freq[get_len_encoding(entry.as.dist_length.length).symbol]++; + dist_freq[get_dist_encoding(entry.as.dist_length.distance).symbol]++; + break; + } + } + + /* symbol 256 is written once at the end of every block by compress_block() */ lit_len_freq[256]++; + } + + /* block header values plus the run-length encoded code lengths of both trees */ struct CodeLengthInfo + { + uint16_t hlit; + uint8_t hdist; + uint8_t hclen; + BAN::Vector encoding; + BAN::Array code_length; + BAN::Array code_length_tree; + }; + + /* run-length encodes the code lengths of both trees into symbols 0..18 and builds the tree used to write those symbols */ static BAN::ErrorOr build_code_length_info(BAN::Span lit_len_tree, BAN::Span dist_tree) + { + CodeLengthInfo result; + + /* appends one tree's lengths to result.encoding; tree is taken by reference so the trimmed size is visible to the caller */ const auto append_tree = + [&result](BAN::Span& tree) -> BAN::ErrorOr + { + while (!tree.empty() && tree[tree.size() - 1].length == 0) + tree = tree.slice(0, tree.size() - 1); + + for (size_t i = 0; i < tree.size();) + { + /* count = length of the run of equal lengths starting at i */ size_t count = 1; + while (i + count < tree.size() && tree[i].length == tree[i + count].length) + count++; + + /* runs of zero: fewer than 3 are written one by one, symbol 17 covers 3..10 and symbol 18 covers 11..138 */ if (tree[i].length == 0) + { + if (count > 138) + count = 138; + + if (count < 3) + { + for (size_t j = 0; j < count; j++) + TRY(result.encoding.push_back({ .symbol = 0 })); + } + else if (count < 11) + { + TRY(result.encoding.push_back({ + .symbol = 17, + .extra_data = static_cast(count - 3), + .extra_len = 3, + })); + } + else + { + TRY(result.encoding.push_back({ + .symbol = 18, + .extra_data = static_cast(count - 11), + 
.extra_len = 7, + })); + } + } + else + { + /* at least 3 repeats of the length that was written just before: symbol 16 covers 3..6 of them */ if (count >= 3 && !result.encoding.empty() && result.encoding.back().symbol == tree[i].length) + { + if (count > 6) + count = 6; + TRY(result.encoding.push_back({ + .symbol = 16, + .extra_data = static_cast(count - 3), + .extra_len = 2, + })); + } + else + { + count = 1; + TRY(result.encoding.push_back({ .symbol = tree[i].length })); + } + } + + i += count; + } + + return {}; + }; + + TRY(append_tree(lit_len_tree)); + result.hlit = lit_len_tree.size(); + + TRY(append_tree(dist_tree)); + result.hdist = dist_tree.size(); + + /* build the tree for the code-length symbols from how often each one was emitted */ BAN::Array code_len_freq(0); + for (auto entry : result.encoding) + code_len_freq[entry.symbol]++; + TRY(create_huffman_tree(code_len_freq.span(), result.code_length_tree.span())); + + /* lengths are stored in code_length_order; trailing zero entries are dropped but at least 4 are kept */ constexpr uint8_t code_length_order[] { + 16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15 + }; + for (size_t i = 0; i < result.code_length_tree.size(); i++) + result.code_length[i] = result.code_length_tree[code_length_order[i]].length; + result.hclen = 19; + while (result.hclen > 4 && result.code_length[result.hclen - 1] == 0) + result.hclen--; + + return BAN::move(result); + } + + /* packs the first three bytes of needle into the key used for m_hash_chain */ uint32_t Compressor::get_hash_key(BAN::ConstByteSpan needle) const + { + ASSERT(needle.size() >= 3); + return (needle[2] << 16) | (needle[1] << 8) | needle[0]; + } + + /* inserts the next count positions of m_data, starting at m_hash_chain_index, into the hash chains */ BAN::ErrorOr Compressor::update_hash_chain(size_t count) + { + /* once the map holds this many keys, cut every chain at its first entry that lies s_max_distance or more behind */ if (m_hash_chain.size() >= s_max_distance * 2) + { + const uint8_t* current = m_data.data() + m_hash_chain_index + count; + for (auto& [_, chain] : m_hash_chain) + { + for (auto it = chain.begin(); it != chain.end(); it++) + { + const size_t distance = current - it->data(); + if (distance < s_max_distance) + continue; + + while (it != chain.end()) + it = chain.remove(it); + break; + } + } + } + + for (size_t i = 0; i < count; i++) + { + auto slice = m_data.slice(m_hash_chain_index + i); + /* fewer than three bytes left: no key can be formed */ if (slice.size() < 3) + break; + + const uint32_t key = get_hash_key(slice); + + auto it = 
m_hash_chain.find(key); + /* the newest position goes to the front of its chain */ if (it != m_hash_chain.end()) + TRY(it->value.insert(it->value.begin(), slice)); + else + { + HashChain new_chain; + TRY(new_chain.push_back(slice)); + TRY(m_hash_chain.insert(key, BAN::move(new_chain))); + } + } + + m_hash_chain_index += count; + + return {}; + } + + /* returns the longest recorded earlier occurrence of the start of needle as a DistLength entry, or a Literal holding needle[0] when there is none */ BAN::ErrorOr Compressor::find_longest_match(BAN::ConstByteSpan needle) const + { + LZ77Entry result = { + .type = LZ77Entry::Type::Literal, + .as = { .literal = needle[0] } + }; + + if (needle.size() < 3) + return result; + + const uint32_t key = get_hash_key(needle); + + auto it = m_hash_chain.find(key); + if (it == m_hash_chain.end()) + return result; + + auto& chain = it->value; + for (const auto node : chain) + { + const size_t distance = needle.data() - node.data(); + /* stop at the first chain entry that is too far back */ if (distance > s_max_distance) + break; + + /* every chain entry shares the three key bytes with needle, so comparing starts at length 3 */ size_t length = 3; + const size_t max_length = BAN::Math::min(needle.size(), s_max_length); + while (length < max_length && needle[length] == node[length]) + length++; + + /* keep the first candidate and afterwards only strictly longer ones */ if (result.type != LZ77Entry::Type::DistLength || length > result.as.dist_length.length) + { + result = LZ77Entry { + .type = LZ77Entry::Type::DistLength, + .as = { + .dist_length = { + .length = static_cast(length), + .distance = static_cast(distance), + } + } + }; + } + } + + return result; + } + + /* turns data into a list of literals and back-references */ BAN::ErrorOr> Compressor::lz77_compress(BAN::ConstByteSpan data) + { + BAN::Vector result; + + /* the positions consumed by one iteration are added to the hash chains at the start of the next one */ size_t advance = 0; + for (size_t i = 0; i < data.size(); i += advance) + { + TRY(update_hash_chain(advance)); + + auto match = TRY(find_longest_match(data.slice(i))); + if (match.type == LZ77Entry::Type::Literal) + { + TRY(result.push_back(match)); + advance = 1; + continue; + } + + ASSERT(match.type == LZ77Entry::Type::DistLength); + + /* if the match starting one byte later is longer, emit this byte as a literal and take that match instead */ auto lazy_match = TRY(find_longest_match(data.slice(i + 1))); + if (lazy_match.type == LZ77Entry::Type::DistLength && lazy_match.as.dist_length.length > match.as.dist_length.length) + { + TRY(result.push_back({ .type = LZ77Entry::Type::Literal, .as = { .literal = data[i] 
}})); + TRY(result.push_back(lazy_match)); + advance = 1 + lazy_match.as.dist_length.length; + } + else + { + TRY(result.push_back(match)); + advance = match.as.dist_length.length; + } + } + + return result; + } + + /* writes data to m_stream as one block of type 2: the final flag, the code lengths of both trees, then the encoded entries */ BAN::ErrorOr Compressor::compress_block(BAN::ConstByteSpan data, bool final) + { + // FIXME: use fixed trees or uncompressed blocks + + auto lz77_entries = TRY(lz77_compress(data)); + +#if LIBDEFLATE_AVOID_STACK + BAN::Vector lit_len_freq, dist_freq; + TRY(lit_len_freq.resize(286, 0)); + TRY(dist_freq.resize(30, 0)); +#else + BAN::Array lit_len_freq(0); + BAN::Array dist_freq(0); +#endif + + get_frequencies(lz77_entries.span(), lit_len_freq.span(), dist_freq.span()); + +#if LIBDEFLATE_AVOID_STACK + BAN::Vector lit_len_tree, dist_tree; + TRY(lit_len_tree.resize(286)); + TRY(dist_tree.resize(30)); +#else + BAN::Array lit_len_tree; + BAN::Array dist_tree; +#endif + + TRY(create_huffman_tree(lit_len_freq.span(), lit_len_tree.span())); + TRY(create_huffman_tree(dist_freq.span(), dist_tree.span())); + + auto info = TRY(build_code_length_info(lit_len_tree.span(), dist_tree.span())); + + /* block header: 1-bit final flag, then block type 2 in two bits */ TRY(m_stream.write_bits(final, 1)); + TRY(m_stream.write_bits(2, 2)); + + /* the three counts are stored with offsets 257, 1 and 4 */ TRY(m_stream.write_bits(info.hlit - 257, 5)); + TRY(m_stream.write_bits(info.hdist - 1, 5)); + TRY(m_stream.write_bits(info.hclen - 4, 4)); + + for (size_t i = 0; i < info.hclen; i++) + TRY(m_stream.write_bits(info.code_length[i], 3)); + + /* codes are bit-reversed before being written; extra bits are written as they are */ for (const auto entry : info.encoding) + { + const auto symbol = info.code_length_tree[entry.symbol]; + TRY(m_stream.write_bits(reverse_bits(symbol.code, symbol.length), symbol.length)); + TRY(m_stream.write_bits(entry.extra_data, entry.extra_len)); + } + + for (const auto entry : lz77_entries) + { + switch (entry.type) + { + case LZ77Entry::Type::Literal: + { + const auto symbol = lit_len_tree[entry.as.literal]; + TRY(m_stream.write_bits(reverse_bits(symbol.code, symbol.length), symbol.length)); + break; + } + case LZ77Entry::Type::DistLength: + { + const auto 
len_encoding = get_len_encoding(entry.as.dist_length.length); + const auto len_code = lit_len_tree[len_encoding.symbol]; + TRY(m_stream.write_bits(reverse_bits(len_code.code, len_code.length), len_code.length)); + TRY(m_stream.write_bits(len_encoding.extra_data, len_encoding.extra_len)); + + const auto dist_encoding = get_dist_encoding(entry.as.dist_length.distance); + const auto dist_code = dist_tree[dist_encoding.symbol]; + TRY(m_stream.write_bits(reverse_bits(dist_code.code, dist_code.length), dist_code.length)); + TRY(m_stream.write_bits(dist_encoding.extra_data, dist_encoding.extra_len)); + + break; + } + } + } + + /* symbol 256 ends the block */ const auto end_code = lit_len_tree[256]; + TRY(m_stream.write_bits(reverse_bits(end_code.code, end_code.length), end_code.length)); + + return {}; + } + + /* Compresses m_data and returns the written bytes. For StreamType::Zlib the blocks are preceded by a 2-byte header and followed by the big-endian Adler-32 of the input. */ BAN::ErrorOr> Compressor::compress() + { + uint32_t checksum = 0; + switch (m_type) + { + case StreamType::Raw: + break; + case StreamType::Zlib: + TRY(m_stream.write_bits(0x78, 8)); // deflate with 32k window + TRY(m_stream.write_bits(0x9C, 8)); // default compression + checksum = calculate_adler32(m_data); + break; + } + + constexpr size_t max_block_size = 16 * 1024; + + /* empty input still needs one block: a final block of type 1 (fixed trees) holding only the 7-bit all-zero end-of-block code */ + if (m_data.empty()) + { + TRY(m_stream.write_bits(1, 1)); + TRY(m_stream.write_bits(1, 2)); + TRY(m_stream.write_bits(0, 7)); + } + + /* walk m_data by offset instead of shrinking it: update_hash_chain() indexes m_data with m_hash_chain_index, which keeps counting across blocks */ + for (size_t offset = 0; offset < m_data.size();) + { + const size_t block_size = BAN::Math::min(m_data.size() - offset, max_block_size); + const bool is_final = (offset + block_size == m_data.size()); + TRY(compress_block(m_data.slice(offset, block_size), is_final)); + offset += block_size; + } + + TRY(m_stream.pad_to_byte_boundary()); + + switch (m_type) + { + case StreamType::Raw: + break; + case StreamType::Zlib: + TRY(m_stream.write_bits(checksum >> 24, 8)); + TRY(m_stream.write_bits(checksum >> 16, 8)); + TRY(m_stream.write_bits(checksum >> 8, 8)); + TRY(m_stream.write_bits(checksum >> 0, 8)); + break; + } + + return m_stream.take_buffer(); + } + +} diff --git a/userspace/libraries/LibDEFLATE/Decompressor.cpp b/userspace/libraries/LibDEFLATE/Decompressor.cpp new file mode 100644 index 00000000..c5d757bc --- /dev/null +++ b/userspace/libraries/LibDEFLATE/Decompressor.cpp @@ -0,0 +1,277 @@ 
+#include +#include + +namespace LibDEFLATE +{ + + union ZLibHeader + { + struct + { + uint8_t cm : 4; + uint8_t cinfo : 4; + uint8_t fcheck : 5; + uint8_t fdict : 1; + uint8_t flevel : 2; + }; + struct + { + uint8_t raw1; + uint8_t raw2; + }; + }; + + BAN::ErrorOr Decompressor::read_symbol(const HuffmanTree& tree) + { + const uint8_t instant_bits = tree.instant_bits(); + + uint16_t code = reverse_bits(TRY(m_stream.peek_bits(instant_bits)), instant_bits); + if (auto symbol = tree.get_symbol_instant(code); symbol.has_value()) + { + MUST(m_stream.take_bits(symbol->len)); + return symbol->symbol; + } + + MUST(m_stream.take_bits(instant_bits)); + + uint8_t len = instant_bits; + while (len < tree.max_bits()) + { + code = (code << 1) | TRY(m_stream.take_bits(1)); + len++; + if (auto symbol = tree.get_symbol(code, len); symbol.has_value()) + return symbol.value(); + } + + return BAN::Error::from_errno(EINVAL); + } + + BAN::ErrorOr Decompressor::inflate_block(const HuffmanTree& length_tree, const HuffmanTree& distance_tree) + { + uint16_t symbol; + while ((symbol = TRY(read_symbol(length_tree))) != 256) + { + if (symbol < 256) + { + TRY(m_output.push_back(symbol)); + continue; + } + + constexpr uint16_t length_base[] { + 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 17, 19, 23, 27, 31, 35, 43, 51, 59, 67, 83, 99, 115, 131, 163, 195, 227, 258 + }; + constexpr uint8_t length_extra_bits[] { + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 0 + }; + + constexpr uint16_t distance_base[] { + 1, 2, 3, 4, 5, 7, 9, 13, 17, 25, 33, 49, 65, 97, 129, 193, 257, 385, 513, 769, 1025, 1537, 2049, 3073, 4097, 6145, 8193, 12289, 16385, 24577 + }; + constexpr uint8_t distance_extra_bits[] { + 0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13 + }; + + if (symbol > 285) + return BAN::Error::from_errno(EINVAL); + symbol -= 257; + + const uint16_t length = length_base[symbol] + 
TRY(m_stream.take_bits(length_extra_bits[symbol])); + + uint16_t distance_code; + if (distance_tree.empty()) + distance_code = reverse_bits(TRY(m_stream.take_bits(5)), 5); + else + distance_code = TRY(read_symbol(distance_tree)); + if (distance_code > 29) + return BAN::Error::from_errno(EINVAL); + + const uint16_t distance = distance_base[distance_code] + TRY(m_stream.take_bits(distance_extra_bits[distance_code])); + + const size_t orig_size = m_output.size(); + const size_t offset = orig_size - distance; + TRY(m_output.resize(orig_size + length)); + for (size_t i = 0; i < length; i++) + m_output[orig_size + i] = m_output[offset + i]; + } + + return {}; + } + + BAN::ErrorOr Decompressor::handle_header() + { + switch (m_type) + { + case StreamType::Raw: + return {}; + case StreamType::Zlib: + { + ZLibHeader header; + header.raw1 = TRY(m_stream.take_bits(8)); + header.raw2 = TRY(m_stream.take_bits(8)); + + if (((header.raw1 << 8) | header.raw2) % 31) + { + dwarnln("zlib header checksum failed"); + return BAN::Error::from_errno(EINVAL); + } + + if (header.cm != 8) + { + dwarnln("zlib does not use DEFLATE"); + return BAN::Error::from_errno(EINVAL); + } + + if (header.fdict) + { + TRY(m_stream.take_bits(16)); + TRY(m_stream.take_bits(16)); + } + + return {}; + } + } + + ASSERT_NOT_REACHED(); + } + + BAN::ErrorOr Decompressor::handle_footer() + { + switch (m_type) + { + case StreamType::Raw: + return {}; + case StreamType::Zlib: + { + m_stream.skip_to_byte_boundary(); + + uint32_t adler32 = 0; + for (size_t i = 0; i < 4; i++) + adler32 = (adler32 << 8) | TRY(m_stream.take_bits(8)); + + if (adler32 != calculate_adler32(m_output.span())) + { + dwarnln("zlib final adler32 checksum failed"); + return BAN::Error::from_errno(EINVAL); + } + + return {}; + } + } + + ASSERT_NOT_REACHED(); + } + + BAN::ErrorOr Decompressor::decompress_type0() + { + m_stream.skip_to_byte_boundary(); + const uint16_t len = TRY(m_stream.take_bits(16)); + const uint16_t nlen = 
TRY(m_stream.take_bits(16)); + if (len != 0xFFFF - nlen) + return BAN::Error::from_errno(EINVAL); + + const size_t orig_size = m_output.size(); + TRY(m_output.resize(orig_size + len)); + TRY(m_stream.take_byte_aligned(&m_output[orig_size], len)); + + return {}; + } + + BAN::ErrorOr Decompressor::decompress_type1() + { + if (!m_fixed_tree.has_value()) + m_fixed_tree = TRY(HuffmanTree::fixed_tree()); + TRY(inflate_block(m_fixed_tree.value(), {})); + return {}; + } + + BAN::ErrorOr Decompressor::decompress_type2() + { + constexpr uint8_t code_length_order[] { + 16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15 + }; + + const uint16_t hlit = TRY(m_stream.take_bits(5)) + 257; + const uint8_t hdist = TRY(m_stream.take_bits(5)) + 1; + const uint8_t hclen = TRY(m_stream.take_bits(4)) + 4; + + uint8_t code_lengths[19] {}; + for (size_t i = 0; i < hclen; i++) + code_lengths[code_length_order[i]] = TRY(m_stream.take_bits(3)); + const auto code_length_tree = TRY(HuffmanTree::create({ code_lengths, 19 })); + + uint8_t bit_lengths[286 + 32] {}; + size_t bit_lengths_len = 0; + + uint16_t last_symbol = 0; + while (bit_lengths_len < hlit + hdist) + { + uint16_t symbol = TRY(read_symbol(code_length_tree)); + if (symbol > 18) + return BAN::Error::from_errno(EINVAL); + + uint8_t count; + if (symbol <= 15) + { + count = 1; + } + else if (symbol == 16) + { + symbol = last_symbol; + count = TRY(m_stream.take_bits(2)) + 3; + } + else if (symbol == 17) + { + symbol = 0; + count = TRY(m_stream.take_bits(3)) + 3; + } + else + { + symbol = 0; + count = TRY(m_stream.take_bits(7)) + 11; + } + + ASSERT(bit_lengths_len + count <= hlit + hdist); + + for (uint8_t i = 0; i < count; i++) + bit_lengths[bit_lengths_len++] = symbol; + last_symbol = symbol; + } + + TRY(inflate_block( + TRY(HuffmanTree::create({ bit_lengths, hlit })), + TRY(HuffmanTree::create({ bit_lengths + hlit, hdist })) + )); + + return {}; + } + + BAN::ErrorOr> Decompressor::decompress() + { + TRY(handle_header()); + 
+ /* one block per iteration: a 1-bit final flag followed by a 2-bit block type; type 3 is rejected */ bool bfinal = false; + while (!bfinal) + { + bfinal = TRY(m_stream.take_bits(1)); + switch (TRY(m_stream.take_bits(2))) + { + case 0b00: + TRY(decompress_type0()); + break; + case 0b01: + TRY(decompress_type1()); + break; + case 0b10: + TRY(decompress_type2()); + break; + default: + return BAN::Error::from_errno(EINVAL); + } + } + + TRY(handle_footer()); + + return BAN::move(m_output); + } + +} diff --git a/userspace/libraries/LibDEFLATE/HuffmanTree.cpp b/userspace/libraries/LibDEFLATE/HuffmanTree.cpp new file mode 100644 index 00000000..17bcc252 --- /dev/null +++ b/userspace/libraries/LibDEFLATE/HuffmanTree.cpp @@ -0,0 +1,141 @@ +#include + +namespace LibDEFLATE +{ + + /* copies the bit counts of other and moves its tables into this tree */ HuffmanTree& HuffmanTree::operator=(HuffmanTree&& other) + { + m_instant_bits = other.m_instant_bits; + m_min_bits = other.m_min_bits; + m_max_bits = other.m_max_bits; + + m_instant = BAN::move(other.m_instant); + m_min_code = BAN::move(other.m_min_code); + m_slow_table = BAN::move(other.m_slow_table); + + return *this; + } + + /* builds a tree from one code length per symbol; a length of 0 marks an unused symbol */ BAN::ErrorOr HuffmanTree::create(BAN::Span bit_lengths) + { + HuffmanTree result; + TRY(result.initialize(bit_lengths)); + return result; + } + + /* assigns a code to every used symbol and fills the lookup tables */ BAN::ErrorOr HuffmanTree::initialize(BAN::Span bit_lengths) + { + m_max_bits = 0; + m_min_bits = MAX_BITS; + + /* count the symbols of each length and remember the highest used symbol */ uint16_t max_sym = 0; + uint16_t bl_count[MAX_BITS + 1] {}; + for (size_t sym = 0; sym < bit_lengths.size(); sym++) + { + if (bit_lengths[sym] == 0) + continue; + m_max_bits = BAN::Math::max(bit_lengths[sym], m_max_bits); + m_min_bits = BAN::Math::min(bit_lengths[sym], m_min_bits); + bl_count[bit_lengths[sym]]++; + max_sym = sym; + } + + uint16_t next_code[MAX_BITS + 1] {}; + + /* first code of every length; m_min_code keeps it for the later lookups */ uint16_t code = 0; + for (uint8_t bits = 1; bits <= MAX_BITS; bits++) + { + code = (code + bl_count[bits - 1]) << 1; + next_code[bits] = code; + m_min_code[bits] = code; + } + + /* symbols of equal length receive consecutive codes in symbol order */ BAN::Vector tree; + TRY(tree.resize(max_sym + 1, { .code = 0, .len = 0 })); + for (uint16_t sym = 0; sym <= max_sym; sym++) + { + tree[sym].len = bit_lengths[sym]; + if 
(const uint8_t len = tree[sym].len) + tree[sym].code = next_code[len]++; + } + + TRY(build_instant_table(tree.span())); + TRY(build_slow_table(tree.span())); + + return {}; + } + + BAN::ErrorOr HuffmanTree::build_instant_table(BAN::Span tree) + { + m_instant_bits = BAN::Math::min(9, m_max_bits); + TRY(m_instant.resize(1 << m_instant_bits, {})); + + for (uint16_t sym = 0; sym < tree.size(); sym++) + { + if (tree[sym].len == 0 || tree[sym].len > m_instant_bits) + continue; + const uint16_t code = tree[sym].code; + const uint16_t shift = m_instant_bits - tree[sym].len; + for (uint16_t j = code << shift; j < (code + 1) << shift; j++) + m_instant[j] = { sym, tree[sym].len }; + } + + return {}; + } + + BAN::ErrorOr HuffmanTree::build_slow_table(BAN::Span tree) + { + TRY(m_slow_table.resize(MAX_BITS + 1)); + for (uint16_t sym = 0; sym < tree.size(); sym++) + { + const auto leaf = tree[sym]; + if (leaf.len == 0) + continue; + const size_t offset = leaf.code - m_min_code[leaf.len]; + if (offset >= m_slow_table[leaf.len].size()) + TRY(m_slow_table[leaf.len].resize(offset + 1)); + m_slow_table[leaf.len][offset] = sym; + } + + return {}; + } + + + BAN::ErrorOr HuffmanTree::fixed_tree() + { + struct BitLengths + { + consteval BitLengths() + { + size_t i = 0; + for (; i <= 143; i++) values[i] = 8; + for (; i <= 255; i++) values[i] = 9; + for (; i <= 279; i++) values[i] = 7; + for (; i <= 287; i++) values[i] = 8; + } + + BAN::Array values; + }; + static constexpr BitLengths bit_lengths; + return TRY(HuffmanTree::create(bit_lengths.values.span())); + } + + BAN::Optional HuffmanTree::get_symbol_instant(uint16_t code) const + { + ASSERT(code < m_instant.size()); + if (const auto entry = m_instant[code]; entry.len) + return entry; + return {}; + } + + BAN::Optional HuffmanTree::get_symbol(uint16_t code, uint8_t len) const + { + ASSERT(len <= m_max_bits); + const auto& symbols = m_slow_table[len]; + const size_t offset = code - m_min_code[len]; + if (symbols.size() <= offset) + return 
{}; + return symbols[offset]; + } + +} diff --git a/userspace/libraries/LibDEFLATE/include/LibDEFLATE/BitStream.h b/userspace/libraries/LibDEFLATE/include/LibDEFLATE/BitStream.h new file mode 100644 index 00000000..dbde3cad --- /dev/null +++ b/userspace/libraries/LibDEFLATE/include/LibDEFLATE/BitStream.h @@ -0,0 +1,118 @@ +#pragma once + +#include +#include + +namespace LibDEFLATE +{ + + class BitInputStream + { + public: + BitInputStream(BAN::ConstByteSpan data) + : m_data(data) + { } + + BAN::ErrorOr peek_bits(size_t count) + { + ASSERT(count <= 16); + + while (m_bit_buffer_len < count) + { + if (m_data.empty()) + return BAN::Error::from_errno(ENOBUFS); + m_bit_buffer |= m_data[0] << m_bit_buffer_len; + m_bit_buffer_len += 8; + m_data = m_data.slice(1); + } + + return m_bit_buffer & ((1 << count) - 1); + } + + BAN::ErrorOr take_bits(size_t count) + { + const uint16_t result = TRY(peek_bits(count)); + m_bit_buffer >>= count; + m_bit_buffer_len -= count; + return result; + } + + BAN::ErrorOr take_byte_aligned(uint8_t* output, size_t bytes) + { + ASSERT(m_bit_buffer % 8 == 0); + + while (m_bit_buffer_len && bytes) + { + *output++ = m_bit_buffer; + m_bit_buffer >>= 8; + m_bit_buffer_len -= 8; + bytes--; + } + + if (bytes > m_data.size()) + return BAN::Error::from_errno(EINVAL); + memcpy(output, m_data.data(), bytes); + m_data = m_data.slice(bytes); + + return {}; + } + + void skip_to_byte_boundary() + { + const size_t bits_to_remove = m_bit_buffer_len % 8; + m_bit_buffer >>= bits_to_remove; + m_bit_buffer_len -= bits_to_remove; + } + + private: + BAN::ConstByteSpan m_data; + uint32_t m_bit_buffer { 0 }; + uint8_t m_bit_buffer_len { 0 }; + }; + + class BitOutputStream + { + public: + BAN::ErrorOr write_bits(uint16_t value, size_t count) + { + ASSERT(m_bit_buffer_len < 8); + ASSERT(count <= 16); + + const uint16_t mask = (1 << count) - 1; + m_bit_buffer |= (value & mask) << m_bit_buffer_len; + m_bit_buffer_len += count; + + while (m_bit_buffer_len >= 8) + { + 
TRY(m_data.push_back(m_bit_buffer)); + m_bit_buffer >>= 8; + m_bit_buffer_len -= 8; + } + + return {}; + } + + /* writes out the partially filled last byte, if there is one */ BAN::ErrorOr pad_to_byte_boundary() + { + ASSERT(m_bit_buffer_len < 8); + if (m_bit_buffer_len == 0) + return {}; + TRY(m_data.push_back(m_bit_buffer)); + m_bit_buffer = 0; + m_bit_buffer_len = 0; + return {}; + } + + /* hands the written bytes to the caller; no partial byte may be pending */ BAN::Vector take_buffer() + { + ASSERT(m_bit_buffer_len == 0); + return BAN::move(m_data); + } + + private: + BAN::Vector m_data; + uint32_t m_bit_buffer { 0 }; + uint8_t m_bit_buffer_len { 0 }; + }; + +} diff --git a/userspace/libraries/LibDEFLATE/include/LibDEFLATE/Compressor.h b/userspace/libraries/LibDEFLATE/include/LibDEFLATE/Compressor.h new file mode 100644 index 00000000..fccc524d --- /dev/null +++ b/userspace/libraries/LibDEFLATE/include/LibDEFLATE/Compressor.h @@ -0,0 +1,67 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include + +namespace LibDEFLATE +{ + + /* compresses the bytes given at construction into a raw or zlib-wrapped stream; compress() does the work */ class Compressor + { + BAN_NON_COPYABLE(Compressor); + BAN_NON_MOVABLE(Compressor); + + public: + using HashChain = BAN::LinkedList; + + /* either one literal byte or a (length, distance) back-reference; type tells which union member is valid */ struct LZ77Entry + { + enum class Type + { + Literal, + DistLength, + } type; + union + { + uint8_t literal; + struct + { + uint16_t length; + uint16_t distance; + } dist_length; + } as; + }; + + public: + Compressor(BAN::ConstByteSpan data, StreamType type) + : m_type(type) + , m_data(data) + { } + + /* returns the complete compressed stream */ BAN::ErrorOr> compress(); + + private: + BAN::ErrorOr compress_block(BAN::ConstByteSpan, bool final); + + uint32_t get_hash_key(BAN::ConstByteSpan needle) const; + BAN::ErrorOr update_hash_chain(size_t count); + + BAN::ErrorOr find_longest_match(BAN::ConstByteSpan needle) const; + BAN::ErrorOr> lz77_compress(BAN::ConstByteSpan data); + + private: + const StreamType m_type; + BAN::ConstByteSpan m_data; + BitOutputStream m_stream; + + /* number of m_data positions update_hash_chain() has processed so far */ size_t m_hash_chain_index { 0 }; + /* maps a three-byte key to the recorded positions that start with those bytes */ BAN::HashMap m_hash_chain; + }; + +} diff --git a/userspace/libraries/LibDEFLATE/include/LibDEFLATE/Decompressor.h 
b/userspace/libraries/LibDEFLATE/include/LibDEFLATE/Decompressor.h new file mode 100644 index 00000000..f293a475 --- /dev/null +++ b/userspace/libraries/LibDEFLATE/include/LibDEFLATE/Decompressor.h @@ -0,0 +1,46 @@ +#pragma once + + +#include +#include +#include + +#include +#include +#include + +namespace LibDEFLATE +{ + + /* decodes a raw or zlib-wrapped stream given at construction; decompress() returns the decoded bytes */ class Decompressor + { + BAN_NON_COPYABLE(Decompressor); + BAN_NON_MOVABLE(Decompressor); + + public: + Decompressor(BAN::ConstByteSpan data, StreamType type) + : m_type(type) + , m_stream(data) + { } + + BAN::ErrorOr> decompress(); + + private: + BAN::ErrorOr read_symbol(const HuffmanTree& tree); + BAN::ErrorOr inflate_block(const HuffmanTree& length_tree, const HuffmanTree& distance_tree); + + /* one handler per 2-bit block type value 0, 1 and 2 */ BAN::ErrorOr decompress_type0(); + BAN::ErrorOr decompress_type1(); + BAN::ErrorOr decompress_type2(); + + BAN::ErrorOr handle_header(); + BAN::ErrorOr handle_footer(); + + private: + const StreamType m_type; + BitInputStream m_stream; + /* decoded bytes; back-references in inflate_block() copy from here */ BAN::Vector m_output; + /* created on first use by decompress_type1() */ BAN::Optional m_fixed_tree; + }; + +} diff --git a/userspace/libraries/LibDEFLATE/include/LibDEFLATE/HuffmanTree.h b/userspace/libraries/LibDEFLATE/include/LibDEFLATE/HuffmanTree.h new file mode 100644 index 00000000..89fdb7e1 --- /dev/null +++ b/userspace/libraries/LibDEFLATE/include/LibDEFLATE/HuffmanTree.h @@ -0,0 +1,61 @@ +#pragma once + +#include +#include +#include +#include + +namespace LibDEFLATE +{ + + /* lookup tables that map codes back to symbols, built from one code length per symbol */ class HuffmanTree + { + BAN_NON_COPYABLE(HuffmanTree); + + public: + static constexpr uint8_t MAX_BITS = 15; + + /* code and code length of one symbol */ struct Leaf + { + uint16_t code; + uint8_t len; + }; + + /* entry of the instant table: a symbol and the length of its code; len 0 marks an unused entry */ struct Instant + { + uint16_t symbol; + uint8_t len; + }; + + HuffmanTree() {} + HuffmanTree(HuffmanTree&& other) { *this = BAN::move(other); } + HuffmanTree& operator=(HuffmanTree&& other); + + static BAN::ErrorOr create(BAN::Span bit_lengths); + + static BAN::ErrorOr fixed_tree(); + /* lookup of an instant_bits() wide pattern */ BAN::Optional get_symbol_instant(uint16_t code) const; + + /* lookup of a code of exactly len bits */ BAN::Optional get_symbol(uint16_t code, uint8_t len) const; + + uint8_t 
instant_bits() const { return m_instant_bits; } + uint8_t min_bits() const { return m_min_bits; } + uint8_t max_bits() const { return m_max_bits; } + /* true while m_min_bits is 0, which is its value in a default-constructed tree */ bool empty() const { return m_min_bits == 0; } + + private: + BAN::ErrorOr initialize(BAN::Span bit_lengths); + BAN::ErrorOr build_instant_table(BAN::Span tree); + BAN::ErrorOr build_slow_table(BAN::Span tree); + + private: + uint8_t m_instant_bits { 0 }; + uint8_t m_min_bits { 0 }; + uint8_t m_max_bits { 0 }; + + /* indexed by an m_instant_bits wide pattern */ BAN::Vector m_instant; + /* first code of every code length */ BAN::Array m_min_code; + /* per code length: symbols indexed by code minus m_min_code of that length */ BAN::Vector> m_slow_table; + }; + +} diff --git a/userspace/libraries/LibDEFLATE/include/LibDEFLATE/StreamType.h b/userspace/libraries/LibDEFLATE/include/LibDEFLATE/StreamType.h new file mode 100644 index 00000000..8ec35f2e --- /dev/null +++ b/userspace/libraries/LibDEFLATE/include/LibDEFLATE/StreamType.h @@ -0,0 +1,12 @@ +#pragma once + +namespace LibDEFLATE +{ + + /* Zlib streams carry the 2-byte header and Adler-32 trailer handled by Compressor and Decompressor; Raw streams carry neither */ enum class StreamType + { + Raw, + Zlib, + }; + +} diff --git a/userspace/libraries/LibDEFLATE/include/LibDEFLATE/Utils.h b/userspace/libraries/LibDEFLATE/include/LibDEFLATE/Utils.h new file mode 100644 index 00000000..699550b1 --- /dev/null +++ b/userspace/libraries/LibDEFLATE/include/LibDEFLATE/Utils.h @@ -0,0 +1,30 @@ +#pragma once + +#include + +namespace LibDEFLATE +{ + + /* Adler-32 of data: s1 is the running byte sum starting at 1, s2 the running sum of s1, both modulo 65521; the result is s2 in the high and s1 in the low 16 bits */ inline uint32_t calculate_adler32(BAN::ConstByteSpan data) + { + uint32_t s1 = 1; + uint32_t s2 = 0; + + for (size_t i = 0; i < data.size(); i++) + { + s1 = (s1 + data[i]) % 65521; + s2 = (s2 + s1) % 65521; + } + + return (s2 << 16) | s1; + } + + /* returns the lowest count bits of value in reversed order */ inline constexpr uint16_t reverse_bits(uint16_t value, size_t count) + { + uint16_t reverse = 0; + for (uint8_t bit = 0; bit < count; bit++) + reverse |= ((value >> bit) & 1) << (count - bit - 1); + return reverse; + } + +}