Files
banan-os/userspace/libraries/LibDEFLATE/Decompressor.cpp
Bananymous 3ebadc5c74 LibDEFLATE: Optimize decompression
Instead of calculating bit-by-bit crc32, we now calculate a lookup table
during compile time. The old crc32 calculation was taking almost 50% of
the decompression time.

Also handle multiple symbols at once without outputting to user. It is
much more efficient to output many bytes instead of the up to 258 that a
single symbol can decode to :^)
2026-04-14 01:50:30 +03:00

691 lines
17 KiB
C++

#include <LibDEFLATE/Decompressor.h>
#include <LibDEFLATE/Utils.h>
#include <BAN/ScopeGuard.h>
namespace LibDEFLATE
{
union ZLibHeader
{
struct
{
uint8_t cm : 4;
uint8_t cinfo : 4;
uint8_t fcheck : 5;
uint8_t fdict : 1;
uint8_t flevel : 2;
};
struct
{
uint8_t raw1;
uint8_t raw2;
};
};
struct crc32_table_t
{
consteval crc32_table_t()
{
for (uint32_t i = 0; i < 256; i++)
{
uint32_t crc32 = i;
for (size_t j = 0; j < 8; j++) {
if (crc32 & 1)
crc32 = (crc32 >> 1) ^ 0xEDB88320;
else
crc32 >>= 1;
}
table[i] = crc32;
}
}
uint32_t table[256];
};
static constexpr crc32_table_t s_crc32_table;
BAN::ErrorOr<uint16_t> Decompressor::read_symbol(const HuffmanTree& tree)
{
const uint8_t instant_bits = tree.instant_bits();
uint16_t code = reverse_bits(TRY(m_stream.peek_bits(instant_bits)), instant_bits);
if (auto symbol = tree.get_symbol_instant(code); symbol.has_value())
{
MUST(m_stream.take_bits(symbol->len));
return symbol->symbol;
}
MUST(m_stream.take_bits(instant_bits));
uint8_t len = instant_bits;
while (len < tree.max_bits())
{
code = (code << 1) | TRY(m_stream.take_bits(1));
len++;
if (auto symbol = tree.get_symbol(code, len); symbol.has_value())
return symbol.value();
}
return BAN::Error::from_errno(EINVAL);
}
BAN::ErrorOr<void> Decompressor::handle_header()
{
switch (m_type)
{
case StreamType::Raw:
return {};
case StreamType::Zlib:
{
ZLibHeader header;
header.raw1 = TRY(m_stream.take_bits(8));
header.raw2 = TRY(m_stream.take_bits(8));
if (((header.raw1 << 8) | header.raw2) % 31)
{
dwarnln("zlib header checksum failed");
return BAN::Error::from_errno(EINVAL);
}
if (header.cm != 8)
{
dwarnln("zlib does not use DEFLATE");
return BAN::Error::from_errno(EINVAL);
}
if (header.fdict)
{
TRY(m_stream.take_bits(16));
TRY(m_stream.take_bits(16));
}
m_stream_info.zlib = {
.s1 = 1,
.s2 = 0,
.adler32 = 0,
};
return {};
}
case StreamType::GZip:
{
const uint8_t id1 = TRY(m_stream.take_bits(8));
const uint8_t id2 = TRY(m_stream.take_bits(8));
if (id1 != 0x1F || id2 != 0x8B)
{
dwarnln("gzip header invalid identification");
return BAN::Error::from_errno(EINVAL);
}
const uint8_t cm = TRY(m_stream.take_bits(8));
if (cm != 8)
{
dwarnln("gzip does not use DEFLATE");
return BAN::Error::from_errno(EINVAL);
}
const uint8_t flg = TRY(m_stream.take_bits(8));
TRY(m_stream.take_bits(16)); // mtime
TRY(m_stream.take_bits(16));
TRY(m_stream.take_bits(8)); // xfl
TRY(m_stream.take_bits(8)); // os
// extra fields
if (flg & (1 << 2))
{
const uint16_t xlen = TRY(m_stream.take_bits(16));
for (size_t i = 0; i < xlen; i++)
TRY(m_stream.take_bits(8));
}
// file name
if (flg & (1 << 3))
while (TRY(m_stream.take_bits(8)) != '\0')
continue;
// file comment
if (flg & (1 << 4))
while (TRY(m_stream.take_bits(8)) != '\0')
continue;
// crc16
// TODO: validate
if (flg & (1 << 1))
TRY(m_stream.take_bits(16));
m_stream_info.gzip = {
.crc32 = 0xFFFFFFFF,
.isize = 0,
};
return {};
}
}
ASSERT_NOT_REACHED();
}
BAN::ErrorOr<void> Decompressor::handle_footer()
{
switch (m_type)
{
case StreamType::Raw:
return {};
case StreamType::Zlib:
{
m_stream.skip_to_byte_boundary();
uint32_t adler32 = 0;
for (size_t i = 0; i < 4; i++)
adler32 = (adler32 << 8) | TRY(m_stream.take_bits(8));
auto& zlib = m_stream_info.zlib;
zlib.adler32 = (zlib.s2 << 16) | zlib.s1;
if (adler32 != zlib.adler32)
{
dwarnln("zlib final adler32 checksum failed {8h} vs {8h}", adler32, zlib.adler32);
return BAN::Error::from_errno(EINVAL);
}
return {};
}
case StreamType::GZip:
{
m_stream.skip_to_byte_boundary();
auto& gzip = m_stream_info.gzip;
gzip.crc32 ^= 0xFFFFFFFF;
const uint32_t crc32 =
static_cast<uint32_t>(TRY(m_stream.take_bits(16))) |
static_cast<uint32_t>(TRY(m_stream.take_bits(16))) << 16;
if (crc32 != gzip.crc32)
{
dwarnln("gzip final crc32 checksum failed {8h} vs {8h}", crc32, gzip.crc32);
return BAN::Error::from_errno(EINVAL);
}
const uint32_t isize =
static_cast<uint32_t>(TRY(m_stream.take_bits(16))) |
static_cast<uint32_t>(TRY(m_stream.take_bits(16))) << 16;
if (isize != gzip.isize)
{
dwarnln("gzip final isize does not match {} vs {}", isize, gzip.isize);
return BAN::Error::from_errno(EINVAL);
}
return {};
}
}
ASSERT_NOT_REACHED();
}
BAN::ErrorOr<void> Decompressor::handle_dynamic_header()
{
constexpr uint8_t code_length_order[] {
16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15
};
const uint16_t hlit = TRY(m_stream.take_bits(5)) + 257;
const uint8_t hdist = TRY(m_stream.take_bits(5)) + 1;
const uint8_t hclen = TRY(m_stream.take_bits(4)) + 4;
uint8_t code_lengths[19] {};
for (size_t i = 0; i < hclen; i++)
code_lengths[code_length_order[i]] = TRY(m_stream.take_bits(3));
const auto code_length_tree = TRY(HuffmanTree::create({ code_lengths, 19 }));
uint8_t bit_lengths[286 + 32] {};
size_t bit_lengths_len = 0;
uint16_t last_symbol = 0;
while (bit_lengths_len < hlit + hdist)
{
uint16_t symbol = TRY(read_symbol(code_length_tree));
if (symbol > 18)
return BAN::Error::from_errno(EINVAL);
uint8_t count;
if (symbol <= 15)
{
count = 1;
}
else if (symbol == 16)
{
symbol = last_symbol;
count = TRY(m_stream.take_bits(2)) + 3;
}
else if (symbol == 17)
{
symbol = 0;
count = TRY(m_stream.take_bits(3)) + 3;
}
else
{
symbol = 0;
count = TRY(m_stream.take_bits(7)) + 11;
}
ASSERT(bit_lengths_len + count <= hlit + hdist);
for (uint8_t i = 0; i < count; i++)
bit_lengths[bit_lengths_len++] = symbol;
last_symbol = symbol;
}
m_length_tree = TRY(HuffmanTree::create({ bit_lengths, hlit }));
m_distance_tree = TRY(HuffmanTree::create({ bit_lengths + hlit, hdist }));
return {};
}
BAN::ErrorOr<void> Decompressor::handle_symbol()
{
uint16_t symbol = TRY(read_symbol(m_length_tree));
if (symbol == 256)
{
m_state = State::BlockHeader;
return {};
}
if (symbol < 256)
{
m_window[(m_window_tail + m_window_size) % total_window_size] = symbol;
m_produced_bytes++;
if (m_window_size < total_window_size)
m_window_size++;
else
m_window_tail = (m_window_tail + 1) % total_window_size;
return {};
}
constexpr uint16_t length_base[] {
3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 17, 19, 23, 27, 31, 35, 43, 51, 59, 67, 83, 99, 115, 131, 163, 195, 227, 258
};
constexpr uint8_t length_extra_bits[] {
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 0
};
constexpr uint16_t distance_base[] {
1, 2, 3, 4, 5, 7, 9, 13, 17, 25, 33, 49, 65, 97, 129, 193, 257, 385, 513, 769, 1025, 1537, 2049, 3073, 4097, 6145, 8193, 12289, 16385, 24577
};
constexpr uint8_t distance_extra_bits[] {
0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13
};
if (symbol > 285)
return BAN::Error::from_errno(EINVAL);
symbol -= 257;
const uint16_t length = length_base[symbol] + TRY(m_stream.take_bits(length_extra_bits[symbol]));
uint16_t distance_code;
if (m_distance_tree.empty())
distance_code = reverse_bits(TRY(m_stream.take_bits(5)), 5);
else
distance_code = TRY(read_symbol(m_distance_tree));
if (distance_code > 29)
return BAN::Error::from_errno(EINVAL);
const uint16_t distance = distance_base[distance_code] + TRY(m_stream.take_bits(distance_extra_bits[distance_code]));
if (distance > m_window_size)
return BAN::Error::from_errno(EINVAL);
const size_t offset = m_window_size - distance;
for (size_t i = 0; i < length; i++)
m_window[(m_window_tail + m_window_size + i) % total_window_size] = m_window[(m_window_tail + offset + i) % total_window_size];
m_window_size += length;
m_produced_bytes += length;
if (m_window_size > total_window_size)
{
const size_t extra = m_window_size - total_window_size;
m_window_tail = (m_window_tail + extra) % total_window_size;
m_window_size = total_window_size;
}
return {};
}
void Decompressor::write_data_to_output(BAN::ByteSpan& output)
{
if (m_produced_bytes == 0)
return;
ASSERT(m_produced_bytes <= m_window_size);
const size_t unwritten_tail = (m_window_tail + m_window_size - m_produced_bytes) % total_window_size;
const size_t to_write = BAN::Math::min(output.size(), m_produced_bytes);
const size_t before_wrap = BAN::Math::min(total_window_size - unwritten_tail, to_write);
memcpy(output.data(), m_window.data() + unwritten_tail, before_wrap);
if (const size_t after_wrap = to_write - before_wrap)
memcpy(output.data() + before_wrap, m_window.data(), after_wrap);
switch (m_type)
{
case StreamType::Raw:
break;
case StreamType::Zlib:
{
auto& zlib = m_stream_info.zlib;
for (size_t i = 0; i < to_write; i++)
{
zlib.s1 = (zlib.s1 + output[i]) % 65521;
zlib.s2 = (zlib.s2 + zlib.s1) % 65521;
}
break;
}
case StreamType::GZip:
{
auto& gzip = m_stream_info.gzip;
gzip.isize += to_write;
for (size_t i = 0; i < to_write; i++)
gzip.crc32 = (gzip.crc32 >> 8) ^ s_crc32_table.table[(gzip.crc32 ^ output[i]) & 0xFF];
break;
}
}
m_produced_bytes -= to_write;
output = output.slice(to_write);
}
BAN::ErrorOr<BAN::Vector<uint8_t>> Decompressor::decompress(BAN::ConstByteSpan input)
{
BAN::Vector<uint8_t> full_output;
TRY(full_output.resize(2 * input.size()));
size_t total_output_size { 0 };
for (;;)
{
size_t input_consumed, output_produced;
const auto status = TRY(decompress(input, input_consumed, full_output.span().slice(total_output_size), output_produced));
input = input.slice(input_consumed);
total_output_size += output_produced;
switch (status)
{
case Status::Done:
TRY(full_output.resize(total_output_size));
(void)full_output.shrink_to_fit();
return full_output;
case Status::NeedMoreOutput:
TRY(full_output.resize(full_output.size() * 2));
break;
case Status::NeedMoreInput:
return BAN::Error::from_errno(EINVAL);
}
}
}
BAN::ErrorOr<BAN::Vector<uint8_t>> Decompressor::decompress(BAN::Span<const BAN::ConstByteSpan> input)
{
size_t total_input_size = 0;
for (const auto& buffer : input)
total_input_size += buffer.size();
BAN::Vector<uint8_t> full_output;
TRY(full_output.resize(2 * total_input_size));
BAN::Vector<uint8_t> input_buffer;
TRY(input_buffer.resize(BAN::Math::min<size_t>(32 * 1024, total_input_size)));
size_t input_buffer_index = 0;
size_t input_buffer_size = 0;
const auto append_input_data =
[&]() -> bool
{
bool did_append = false;
while (!input.empty() && input_buffer_size < input_buffer.size())
{
if (input_buffer_index >= input[0].size())
{
input_buffer_index = 0;
input = input.slice(1);
continue;
}
const size_t to_copy = BAN::Math::min(input[0].size() - input_buffer_index, input_buffer.size() - input_buffer_size);
memcpy(input_buffer.data() + input_buffer_size, input[0].data() + input_buffer_index, to_copy);
input_buffer_size += to_copy;
input_buffer_index += to_copy;
did_append = true;
}
return did_append;
};
append_input_data();
size_t total_output_size = 0;
for (;;)
{
size_t input_consumed, output_produced;
const auto status = TRY(decompress(
input_buffer.span().slice(0, input_buffer_size),
input_consumed,
full_output.span().slice(total_output_size),
output_produced
));
if (input_consumed)
{
memmove(input_buffer.data(), input_buffer.data() + input_consumed, input_buffer_size - input_consumed);
input_buffer_size -= input_consumed;
}
total_output_size += output_produced;
switch (status)
{
case Status::Done:
TRY(full_output.resize(total_output_size));
(void)full_output.shrink_to_fit();
return full_output;
case Status::NeedMoreOutput:
TRY(full_output.resize(full_output.size() * 2));
break;
case Status::NeedMoreInput:
if (!append_input_data())
return BAN::Error::from_errno(EINVAL);
break;
}
}
}
BAN::ErrorOr<Decompressor::Status> Decompressor::decompress(BAN::ConstByteSpan input, size_t& input_consumed, BAN::ByteSpan output, size_t& output_produced)
{
const size_t original_input_size = input.size();
const size_t original_output_size = output.size();
BAN::ScopeGuard _([&] {
input_consumed = original_input_size - m_stream.unprocessed_bytes();
output_produced = original_output_size - output.size();
m_stream.drop_unprocessed_data();
});
m_stream.set_data(input);
if (m_window.empty())
TRY(m_window.resize(total_window_size));
write_data_to_output(output);
if (m_produced_bytes > 0)
return Status::NeedMoreOutput;
while (m_state != State::Done)
{
bool need_more_input = false;
bool restore_saved_stream = false;
auto saved_stream = m_stream;
switch (m_state)
{
case State::Done:
ASSERT_NOT_REACHED();
case State::StreamHeader:
{
if (auto ret = handle_header(); !ret.is_error())
m_state = State::BlockHeader;
else
{
if (ret.error().get_error_code() != ENOBUFS)
return ret.release_error();
need_more_input = true;
restore_saved_stream = true;
}
break;
}
case State::StreamFooter:
{
if (auto ret = handle_footer(); !ret.is_error())
m_state = State::Done;
else
{
if (ret.error().get_error_code() != ENOBUFS)
return ret.release_error();
need_more_input = true;
restore_saved_stream = true;
}
break;
}
case State::BlockHeader:
{
if (m_bfinal)
{
m_state = State::StreamFooter;
break;
}
if (m_stream.available_bits() < 3)
{
need_more_input = true;
break;
}
m_bfinal = MUST(m_stream.take_bits(1));
switch (MUST(m_stream.take_bits(2)))
{
case 0b00:
m_state = State::LiteralHeader;
break;
case 0b01:
m_length_tree = TRY(HuffmanTree::fixed_tree());
m_distance_tree = {};
m_state = State::Symbol;
break;
case 0b10:
m_state = State::DynamicHeader;
break;
default:
return BAN::Error::from_errno(EINVAL);
}
break;
}
case State::LiteralHeader:
{
if (m_stream.available_bytes() < 4)
{
need_more_input = true;
break;
}
m_stream.skip_to_byte_boundary();
const uint16_t len = MUST(m_stream.take_bits(16));
const uint16_t nlen = MUST(m_stream.take_bits(16));
if (len != 0xFFFF - nlen)
return BAN::Error::from_errno(EINVAL);
m_raw_bytes_left = len;
m_state = State::ReadRaw;
break;
}
case State::DynamicHeader:
{
if (auto ret = handle_dynamic_header(); !ret.is_error())
m_state = State::Symbol;
else
{
if (ret.error().get_error_code() != ENOBUFS)
return ret.release_error();
need_more_input = true;
restore_saved_stream = true;
}
break;
}
case State::ReadRaw:
{
const size_t window_head = (m_window_tail + m_window_size) % total_window_size;
// FIXME: m_raw_bytes_left can be up to 64KB
const size_t max_bytes_to_read = BAN::Math::min<size_t>(m_raw_bytes_left, total_window_size);
const size_t can_read = BAN::Math::min(max_bytes_to_read, m_stream.available_bytes());
const size_t before_wrap = BAN::Math::min(total_window_size - window_head, can_read);
MUST(m_stream.take_byte_aligned(BAN::ByteSpan(m_window.span()).slice(window_head, before_wrap)));
if (const size_t after_wrap = can_read - before_wrap)
MUST(m_stream.take_byte_aligned(BAN::ByteSpan(m_window.span()).slice(0, after_wrap)));
m_window_size += can_read;
m_produced_bytes += can_read;
if (m_window_size > total_window_size)
{
const size_t extra = m_window_size - total_window_size;
m_window_tail = (m_window_tail + extra) % total_window_size;
m_window_size = total_window_size;
}
m_raw_bytes_left -= can_read;
if (m_raw_bytes_left == 0)
m_state = State::BlockHeader;
else if (m_stream.available_bytes() == 0)
need_more_input = true;
break;
}
case State::Symbol:
{
while (m_produced_bytes + 258 < total_window_size && m_state == State::Symbol)
{
saved_stream = m_stream;
if (auto ret = handle_symbol(); ret.is_error())
{
if (ret.error().get_error_code() != ENOBUFS)
return ret.release_error();
need_more_input = true;
restore_saved_stream = true;
break;
}
}
break;
}
}
if (need_more_input)
{
if (restore_saved_stream)
m_stream = saved_stream;
return Status::NeedMoreInput;
}
write_data_to_output(output);
if (m_produced_bytes > 0)
return Status::NeedMoreOutput;
}
return Status::Done;
}
}