#include #include #include #include #include #include namespace LibDEFLATE { constexpr size_t s_max_length = 258; constexpr size_t s_max_distance = 32768; constexpr size_t s_max_symbols = 288; constexpr uint8_t s_max_bits = 15; struct Leaf { uint16_t code; uint8_t length; }; static BAN::ErrorOr create_huffman_tree(BAN::Span freq, BAN::Span output) { ASSERT(freq.size() <= s_max_symbols); ASSERT(freq.size() == output.size()); struct node_t { size_t symbol; size_t freq; node_t* left; node_t* right; }; #if LIBDEFLATE_AVOID_STACK BAN::Vector nodes; TRY(nodes.resize(s_max_symbols)); #else BAN::Array nodes; #endif size_t node_count = 0; for (size_t sym = 0; sym < freq.size(); sym++) { if (freq[sym] == 0) continue; nodes[node_count] = static_cast(BAN::allocator(sizeof(node_t))); if (nodes[node_count] == nullptr) { for (size_t j = 0; j < node_count; j++) BAN::deallocator(nodes[j]); return BAN::Error::from_errno(ENOMEM); } *nodes[node_count++] = { .symbol = sym, .freq = freq[sym], .left = nullptr, .right = nullptr, }; } for (auto& symbol : output) symbol = { .code = 0, .length = 0 }; if (node_count == 0) { output[0] = { .code = 0, .length = 1 }; return {}; } static void (*free_tree)(node_t*) = [](node_t* root) -> void { if (root == nullptr) return; free_tree(root->left); free_tree(root->right); BAN::deallocator(root); }; const auto comp = [](const node_t* a, const node_t* b) -> bool { if (a->freq != b->freq) return a->freq > b->freq; return a->symbol > b->symbol; }; auto end_it = nodes.begin() + node_count; BAN::make_heap(nodes.begin(), end_it, comp); while (nodes.begin() + 1 != end_it) { node_t* parent = static_cast(BAN::allocator(sizeof(node_t))); if (parent == nullptr) { for (auto it = nodes.begin(); it != end_it; it++) free_tree(*it); return BAN::Error::from_errno(ENOMEM); } node_t* node1 = nodes.front(); BAN::pop_heap(nodes.begin(), end_it--, comp); node_t* node2 = nodes.front(); BAN::pop_heap(nodes.begin(), end_it--, comp); *parent = { .symbol = 0, .freq = node1->freq + node2->freq, .left = node1, .right = node2, }; *end_it++ = parent; BAN::push_heap(nodes.begin(), end_it, comp); } static uint16_t (*gather_lengths)(const node_t*, BAN::Span, uint16_t) = [](const node_t* node, BAN::Span symbols, uint16_t depth) -> uint16_t { if (node == nullptr) return 0; uint16_t count = (depth > s_max_bits); if (node->left == nullptr && node->right == nullptr) symbols[node->symbol].length = BAN::Math::min(depth, s_max_bits); else { count += gather_lengths(node->left, symbols, depth + 1); count += gather_lengths(node->right, symbols, depth + 1); } return count; }; const auto too_long_count = gather_lengths(nodes[0], output, 0); free_tree(nodes[0]); uint16_t bl_count[s_max_bits + 1] {}; for (size_t sym = 0; sym < freq.size(); sym++) if (const uint8_t len = output[sym].length) bl_count[len]++; if (too_long_count > 0) { for (size_t i = 0; i < too_long_count / 2; i++) { uint16_t bits = s_max_bits - 1; while (bl_count[bits] == 0) bits--; bl_count[bits + 0]--; bl_count[bits + 1] += 2; bl_count[s_max_bits]--; } struct SymFreq { size_t symbol; size_t freq; }; BAN::Vector sym_freq; for (size_t sym = 0; sym < output.size(); sym++) if (freq[sym] != 0) TRY(sym_freq.push_back({ .symbol = sym, .freq = freq[sym] })); BAN::sort::sort(sym_freq.begin(), sym_freq.end(), [](auto a, auto b) { return a.freq < b.freq; } ); size_t index = 0; for (uint16_t bits = s_max_bits; bits > 0; bits--) for (size_t i = 0; i < bl_count[bits]; i++) output[sym_freq[index++].symbol].length = bits; ASSERT(index == sym_freq.size()); } uint16_t next_code[s_max_bits + 1] {}; uint16_t code = 0; for (uint8_t bits = 1; bits <= s_max_bits; bits++) { code = (code + bl_count[bits - 1]) << 1; next_code[bits] = code; } for (size_t sym = 0; sym < freq.size(); sym++) if (const uint16_t len = output[sym].length) output[sym].code = next_code[len]++; return {}; } struct Encoding { uint16_t symbol; uint16_t extra_data { 0 }; uint8_t extra_len { 0 }; }; static constexpr Encoding get_len_encoding(uint16_t length) { ASSERT(3 <= length && length <= s_max_length); constexpr uint16_t base[] { 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 17, 19, 23, 27, 31, 35, 43, 51, 59, 67, 83, 99, 115, 131, 163, 195, 227, 258 }; constexpr uint8_t extra_bits[] { 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 0 }; constexpr size_t count = sizeof(base) / sizeof(*base); for (size_t i = 0;; i++) { if (i + 1 < count && length >= base[i + 1]) continue; return { .symbol = static_cast(257 + i), .extra_data = static_cast(length - base[i]), .extra_len = extra_bits[i], }; } } static constexpr Encoding get_dist_encoding(uint16_t distance) { ASSERT(1 <= distance && distance <= s_max_distance); constexpr uint16_t base[] { 1, 2, 3, 4, 5, 7, 9, 13, 17, 25, 33, 49, 65, 97, 129, 193, 257, 385, 513, 769, 1025, 1537, 2049, 3073, 4097, 6145, 8193, 12289, 16385, 24577 }; constexpr uint8_t extra_bits[] { 0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13 }; constexpr size_t count = sizeof(base) / sizeof(*base); for (size_t i = 0;; i++) { if (i + 1 < count && distance >= base[i + 1]) continue; return { .symbol = static_cast(i), .extra_data = static_cast(distance - base[i]), .extra_len = extra_bits[i], }; } } static void get_frequencies(BAN::Span entries, BAN::Span lit_len_freq, BAN::Span dist_freq) { ASSERT(lit_len_freq.size() == 286); ASSERT(dist_freq.size() == 30); for (auto entry : entries) { switch (entry.type) { case Compressor::LZ77Entry::Type::Literal: lit_len_freq[entry.as.literal]++; break; case Compressor::LZ77Entry::Type::DistLength: lit_len_freq[get_len_encoding(entry.as.dist_length.length).symbol]++; dist_freq[get_dist_encoding(entry.as.dist_length.distance).symbol]++; break; } } lit_len_freq[256]++; } struct CodeLengthInfo { uint16_t hlit; uint8_t hdist; uint8_t hclen; BAN::Vector encoding; BAN::Array code_length; BAN::Array code_length_tree; }; static BAN::ErrorOr build_code_length_info(BAN::Span lit_len_tree, BAN::Span dist_tree) { CodeLengthInfo result; const auto append_tree = [&result](BAN::Span& tree) -> BAN::ErrorOr { while (!tree.empty() && tree[tree.size() - 1].length == 0) tree = tree.slice(0, tree.size() - 1); for (size_t i = 0; i < tree.size();) { size_t count = 1; while (i + count < tree.size() && tree[i].length == tree[i + count].length) count++; if (tree[i].length == 0) { if (count > 138) count = 138; if (count < 3) { for (size_t j = 0; j < count; j++) TRY(result.encoding.push_back({ .symbol = 0 })); } else if (count < 11) { TRY(result.encoding.push_back({ .symbol = 17, .extra_data = static_cast(count - 3), .extra_len = 3, })); } else { TRY(result.encoding.push_back({ .symbol = 18, .extra_data = static_cast(count - 11), .extra_len = 7, })); } } else { if (count >= 3 && !result.encoding.empty() && result.encoding.back().symbol == tree[i].length) { if (count > 6) count = 6; TRY(result.encoding.push_back({ .symbol = 16, .extra_data = static_cast(count - 3), .extra_len = 2, })); } else { count = 1; TRY(result.encoding.push_back({ .symbol = tree[i].length })); } } i += count; } return {}; }; TRY(append_tree(lit_len_tree)); result.hlit = lit_len_tree.size(); TRY(append_tree(dist_tree)); result.hdist = dist_tree.size(); BAN::Array code_len_freq(0); for (auto entry : result.encoding) code_len_freq[entry.symbol]++; TRY(create_huffman_tree(code_len_freq.span(), result.code_length_tree.span())); constexpr uint8_t code_length_order[] { 16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15 }; for (size_t i = 0; i < result.code_length_tree.size(); i++) result.code_length[i] = result.code_length_tree[code_length_order[i]].length; result.hclen = 19; while (result.hclen > 4 && result.code_length[result.hclen - 1] == 0) result.hclen--; return BAN::move(result); } uint32_t Compressor::get_hash_key(BAN::ConstByteSpan needle) const { ASSERT(needle.size() >= 3); return (needle[2] << 16) | (needle[1] << 8) | needle[0]; } BAN::ErrorOr Compressor::update_hash_chain(size_t count) { if (m_hash_chain.size() >= s_max_distance * 2) { const uint8_t* current = m_data.data() + m_hash_chain_index + count; for (auto& [_, chain] : m_hash_chain) { for (auto it = chain.begin(); it != chain.end(); it++) { const size_t distance = current - it->data(); if (distance < s_max_distance) continue; while (it != chain.end()) it = chain.remove(it); break; } } } for (size_t i = 0; i < count; i++) { auto slice = m_data.slice(m_hash_chain_index + i); if (slice.size() < 3) break; const uint32_t key = get_hash_key(slice); auto it = m_hash_chain.find(key); if (it != m_hash_chain.end()) TRY(it->value.insert(it->value.begin(), slice)); else { HashChain new_chain; TRY(new_chain.push_back(slice)); TRY(m_hash_chain.insert(key, BAN::move(new_chain))); } } m_hash_chain_index += count; return {}; } BAN::ErrorOr Compressor::find_longest_match(BAN::ConstByteSpan needle) const { LZ77Entry result = { .type = LZ77Entry::Type::Literal, .as = { .literal = needle[0] } }; if (needle.size() < 3) return result; const uint32_t key = get_hash_key(needle); auto it = m_hash_chain.find(key); if (it == m_hash_chain.end()) return result; auto& chain = it->value; for (const auto node : chain) { const size_t distance = needle.data() - node.data(); if (distance > s_max_distance) break; size_t length = 3; const size_t max_length = BAN::Math::min(needle.size(), s_max_length); while (length < max_length && needle[length] == node[length]) length++; if (result.type != LZ77Entry::Type::DistLength || length > result.as.dist_length.length) { result = LZ77Entry { .type = LZ77Entry::Type::DistLength, .as = { .dist_length = { .length = static_cast(length), .distance = static_cast(distance), } } }; } } return result; } BAN::ErrorOr> Compressor::lz77_compress(BAN::ConstByteSpan data) { BAN::Vector result; size_t advance = 0; for (size_t i = 0; i < data.size(); i += advance) { TRY(update_hash_chain(advance)); auto match = TRY(find_longest_match(data.slice(i))); if (match.type == LZ77Entry::Type::Literal) { TRY(result.push_back(match)); advance = 1; continue; } ASSERT(match.type == LZ77Entry::Type::DistLength); auto lazy_match = TRY(find_longest_match(data.slice(i + 1))); if (lazy_match.type == LZ77Entry::Type::DistLength && lazy_match.as.dist_length.length > match.as.dist_length.length) { TRY(result.push_back({ .type = LZ77Entry::Type::Literal, .as = { .literal = data[i] }})); TRY(result.push_back(lazy_match)); advance = 1 + lazy_match.as.dist_length.length; } else { TRY(result.push_back(match)); advance = match.as.dist_length.length; } } return result; } BAN::ErrorOr Compressor::compress_block(BAN::ConstByteSpan data, bool final) { // FIXME: use fixed trees or uncompressed blocks auto lz77_entries = TRY(lz77_compress(data)); #if LIBDEFLATE_AVOID_STACK BAN::Vector lit_len_freq, dist_freq; TRY(lit_len_freq.resize(286, 0)); TRY(dist_freq.resize(30, 0)); #else BAN::Array lit_len_freq(0); BAN::Array dist_freq(0); #endif get_frequencies(lz77_entries.span(), lit_len_freq.span(), dist_freq.span()); #if LIBDEFLATE_AVOID_STACK BAN::Vector lit_len_tree, dist_tree; TRY(lit_len_tree.resize(286)); TRY(dist_tree.resize(30)); #else BAN::Array lit_len_tree; BAN::Array dist_tree; #endif TRY(create_huffman_tree(lit_len_freq.span(), lit_len_tree.span())); TRY(create_huffman_tree(dist_freq.span(), dist_tree.span())); auto info = TRY(build_code_length_info(lit_len_tree.span(), dist_tree.span())); TRY(m_stream.write_bits(final, 1)); TRY(m_stream.write_bits(2, 2)); TRY(m_stream.write_bits(info.hlit - 257, 5)); TRY(m_stream.write_bits(info.hdist - 1, 5)); TRY(m_stream.write_bits(info.hclen - 4, 4)); for (size_t i = 0; i < info.hclen; i++) TRY(m_stream.write_bits(info.code_length[i], 3)); for (const auto entry : info.encoding) { const auto symbol = info.code_length_tree[entry.symbol]; TRY(m_stream.write_bits(reverse_bits(symbol.code, symbol.length), symbol.length)); TRY(m_stream.write_bits(entry.extra_data, entry.extra_len)); } for (const auto entry : lz77_entries) { switch (entry.type) { case LZ77Entry::Type::Literal: { const auto symbol = lit_len_tree[entry.as.literal]; TRY(m_stream.write_bits(reverse_bits(symbol.code, symbol.length), symbol.length)); break; } case LZ77Entry::Type::DistLength: { const auto len_encoding = get_len_encoding(entry.as.dist_length.length); const auto len_code = lit_len_tree[len_encoding.symbol]; TRY(m_stream.write_bits(reverse_bits(len_code.code, len_code.length), len_code.length)); TRY(m_stream.write_bits(len_encoding.extra_data, len_encoding.extra_len)); const auto dist_encoding = get_dist_encoding(entry.as.dist_length.distance); const auto dist_code = dist_tree[dist_encoding.symbol]; TRY(m_stream.write_bits(reverse_bits(dist_code.code, dist_code.length), dist_code.length)); TRY(m_stream.write_bits(dist_encoding.extra_data, dist_encoding.extra_len)); break; } } } const auto end_code = lit_len_tree[256]; TRY(m_stream.write_bits(reverse_bits(end_code.code, end_code.length), end_code.length)); return {}; } BAN::ErrorOr> Compressor::compress() { uint32_t checksum = 0; switch (m_type) { case StreamType::Raw: break; case StreamType::Zlib: TRY(m_stream.write_bits(0x78, 8)); // deflate with 32k window TRY(m_stream.write_bits(0x9C, 8)); // default compression checksum = calculate_adler32(m_data); break; } constexpr size_t max_block_size = 16 * 1024; while (!m_data.empty()) { const size_t block_size = BAN::Math::min(m_data.size(), max_block_size); TRY(compress_block(m_data.slice(0, block_size), block_size == m_data.size())); m_data = m_data.slice(block_size); } TRY(m_stream.pad_to_byte_boundary()); switch (m_type) { case StreamType::Raw: break; case StreamType::Zlib: TRY(m_stream.write_bits(checksum >> 24, 8)); TRY(m_stream.write_bits(checksum >> 16, 8)); TRY(m_stream.write_bits(checksum >> 8, 8)); TRY(m_stream.write_bits(checksum >> 0, 8)); break; } return m_stream.take_buffer(); } }