diff --git a/userspace/libraries/LibImage/CMakeLists.txt b/userspace/libraries/LibImage/CMakeLists.txt index b6df7088..99893f3e 100644 --- a/userspace/libraries/LibImage/CMakeLists.txt +++ b/userspace/libraries/LibImage/CMakeLists.txt @@ -7,6 +7,7 @@ set(LIBIMAGE_SOURCES add_library(libimage ${LIBIMAGE_SOURCES}) banan_link_library(libimage ban) banan_link_library(libimage libc) +banan_link_library(libimage libdeflate) banan_install_headers(libimage) install(TARGETS libimage OPTIONAL) diff --git a/userspace/libraries/LibImage/PNG.cpp b/userspace/libraries/LibImage/PNG.cpp index 750f7fdf..544f759f 100644 --- a/userspace/libraries/LibImage/PNG.cpp +++ b/userspace/libraries/LibImage/PNG.cpp @@ -3,13 +3,13 @@ #include +#include + #include #define DEBUG_PNG 0 -// PNG https://www.w3.org/TR/png-3/ -// ZLIB https://www.rfc-editor.org/rfc/rfc1950 -// DEFLATE https://www.rfc-editor.org/rfc/rfc1951 +// https://www.w3.org/TR/png-3/ namespace LibImage { @@ -66,400 +66,12 @@ namespace LibImage InterlaceMethod interlace_method; } __attribute__((packed)); - struct ZLibStream - { - uint8_t cm : 4; - uint8_t cinfo : 4; - uint8_t fcheck : 5; - uint8_t fdict : 1; - uint8_t flevel : 2; - }; - struct PNGChunk { BAN::StringView name; BAN::ConstByteSpan data; }; - class BitBuffer - { - public: - BitBuffer(BAN::Vector data) - : m_data(data) - {} - - BAN::ErrorOr peek_bits(uint8_t count) - { - ASSERT(count <= 16); - - while (m_bit_buffer_len < count) - { - if (m_data.empty()) - return BAN::Error::from_errno(ENODATA); - m_bit_buffer |= m_data[0][0] << m_bit_buffer_len; - m_bit_buffer_len += 8; - if (m_data[0].size() > 1) - m_data[0] = m_data[0].slice(1); - else - m_data.remove(0); - } - - return m_bit_buffer & ((1 << count) - 1); - } - - void remove_bits(uint8_t count) - { - ASSERT(count <= 16); - ASSERT(m_bit_buffer_len >= count); - m_bit_buffer_len -= count; - m_bit_buffer >>= count; - } - - BAN::ErrorOr get_bits(uint8_t count) - { - uint16_t result = TRY(peek_bits(count)); - remove_bits(count); - return result; - } - - void skip_to_byte_boundary() - { - m_bit_buffer >>= m_bit_buffer_len % 8; - m_bit_buffer_len -= m_bit_buffer_len % 8; - } - - private: - BAN::Vector m_data; - uint32_t m_bit_buffer { 0 }; - uint8_t m_bit_buffer_len { 0 }; - }; - - constexpr uint16_t reverse_bits(uint16_t value, uint8_t count) - { - uint16_t reverse = 0; - for (uint8_t bit = 0; bit < count; bit++) - reverse |= ((value >> bit) & 1) << (count - bit - 1); - return reverse; - } - - class HuffmanTree - { - public: - static constexpr uint8_t MAX_BITS = 15; - - struct Leaf - { - uint16_t code; - uint8_t len; - }; - - public: - HuffmanTree() = default; - HuffmanTree(BAN::Vector&& leaves, uint8_t min_len, uint8_t max_len, uint8_t instant_max_bit) - : m_leaves(BAN::move(leaves)) - , m_min_bits(min_len), m_max_bits(max_len) - , m_instant_max_bit(instant_max_bit) - {} - - uint8_t min_bits() const { return m_min_bits; } - uint8_t max_bits() const { return m_max_bits; } - uint8_t instant_max_bit() const { return m_instant_max_bit; } - Leaf get_leaf(size_t index) const { return m_leaves[index]; } - bool empty() const { return m_leaves.empty(); } - - static BAN::ErrorOr create(const BAN::Vector& bit_lengths) - { - uint16_t bl_count[MAX_BITS] {}; - for (uint8_t bl : bit_lengths) - bl_count[bl]++; - bl_count[0] = 0; - - uint8_t min_bits = MAX_BITS; - uint8_t max_bits = 0; - for (uint8_t bits = 0; bits <= MAX_BITS; bits++) - { - if (bit_lengths[bits] == 0) - continue; - min_bits = BAN::Math::min(min_bits, bits); - max_bits = BAN::Math::max(max_bits, bits); - } - - uint8_t instant_max_bit = BAN::Math::min(10, max_bits); - uint16_t instant_mask = (1 << instant_max_bit) - 1; - - uint16_t code = 0; - uint16_t next_code[MAX_BITS + 1] {}; - for (uint8_t bits = 1; bits <= max_bits; bits++) - { - code = (code + bl_count[bits - 1]) << 1; - next_code[bits] = code; - } - - BAN::Vector leaves; - TRY(leaves.resize(1 << max_bits)); - - for (uint16_t n = 0; n < bit_lengths.size(); n++) - { - uint8_t bits = bit_lengths[n]; - if (bits == 0) - continue; - - uint16_t canonical = next_code[bits]; - next_code[bits]++; - - uint16_t reversed = reverse_bits(canonical, bits); - leaves[reversed] = Leaf { n, bits }; - - if (bits <= instant_max_bit) - { - uint16_t step = 1 << bits; - for (uint16_t spread = reversed + step; spread <= instant_mask; spread += step) - leaves[spread] = Leaf { n, bits }; - } - } - - return HuffmanTree(BAN::move(leaves), min_bits, max_bits, instant_max_bit); - } - - static BAN::ErrorOr fixed_tree() - { - BAN::Vector bit_lengths; - TRY(bit_lengths.resize(288)); - size_t i = 0; - for (; i <= 143; i++) bit_lengths[i] = 8; - for (; i <= 255; i++) bit_lengths[i] = 9; - for (; i <= 279; i++) bit_lengths[i] = 7; - for (; i <= 287; i++) bit_lengths[i] = 8; - return TRY(HuffmanTree::create(bit_lengths)); - } - - private: - BAN::Vector m_leaves; - uint8_t m_min_bits { 0 }; - uint8_t m_max_bits { 0 }; - uint8_t m_instant_max_bit { 0 }; - }; - - class DeflateDecoder - { - public: - DeflateDecoder(BAN::Vector data) - : m_buffer(BitBuffer(BAN::move(data))) - {} - - BAN::ErrorOr decode_stream() - { - while (!TRY(decode_block())) - continue; - - m_buffer.skip_to_byte_boundary(); - - uint32_t checksum = 0; - for (int i = 0; i < 4; i++) - checksum = (checksum << 8) | TRY(m_buffer.get_bits(8)); - - if (decoded_adler32() != checksum) - { - dwarnln_if(DEBUG_PNG, "decode checksum does not match"); - return BAN::Error::from_errno(EINVAL); - } - - return BAN::ByteSpan(m_decoded.span()); - } - - private: - uint32_t decoded_adler32() const - { - uint32_t a = 1; - uint32_t b = 0; - - for (uint8_t byte : m_decoded) - { - a = (a + byte) % 65521; - b = (b + a) % 65521; - } - - return (b << 16) | a; - } - - BAN::ErrorOr decode_block() - { - bool bfinal = TRY(m_buffer.get_bits(1)); - uint8_t btype = TRY(m_buffer.get_bits(2)); - - switch (btype) - { - case 0: TRY(decode_type0()); break; - case 1: TRY(decode_type1()); break; - case 2: TRY(decode_type2()); break; - default: - dwarnln_if(DEBUG_PNG, "Deflate block has invalid method {}", btype); - return BAN::Error::from_errno(EINVAL); - } - - return bfinal; - } - - BAN::ErrorOr decode_type0() - { - m_buffer.skip_to_byte_boundary(); - - uint16_t len = TRY(m_buffer.get_bits(16)); - uint16_t nlen = TRY(m_buffer.get_bits(16)); - if (len != 0xFFFF - nlen) - { - dwarnln_if(DEBUG_PNG, "Deflate block uncompressed data length is invalid"); - return BAN::Error::from_errno(EINVAL); - } - - TRY(m_decoded.reserve(m_decoded.size() + len)); - for (uint16_t i = 0; i < len; i++) - MUST(m_decoded.push_back(TRY(m_buffer.get_bits(8)))); - - return {}; - } - - BAN::ErrorOr decode_type1() - { - TRY(inflate_block(TRY(HuffmanTree::fixed_tree()), HuffmanTree())); - return {}; - } - - BAN::ErrorOr decode_type2() - { - static constexpr uint8_t code_length_order[] { - 16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15 - }; - - const uint16_t hlit = TRY(m_buffer.get_bits(5)) + 257; - const uint8_t hdist = TRY(m_buffer.get_bits(5)) + 1; - const uint8_t hclen = TRY(m_buffer.get_bits(4)) + 4; - - HuffmanTree code_length_tree; - { - BAN::Vector code_lengths; - TRY(code_lengths.resize(19, 0)); - for (uint8_t i = 0; i < hclen; i++) - code_lengths[code_length_order[i]] = TRY(m_buffer.get_bits(3)); - code_length_tree = TRY(HuffmanTree::create(code_lengths)); - } - - uint16_t last_symbol = 0; - BAN::Vector bit_lengths; - TRY(bit_lengths.reserve(288 + 32)); - while (bit_lengths.size() < hlit + hdist) - { - uint16_t symbol = TRY(read_symbol(code_length_tree)); - uint8_t count = 0; - - if (symbol <= 15) - { - count = 1; - } - else if (symbol == 16) - { - symbol = last_symbol; - count = TRY(m_buffer.get_bits(2)) + 3; - } - else if (symbol == 17) - { - symbol = 0; - count = TRY(m_buffer.get_bits(3)) + 3; - } - else if (symbol == 18) - { - symbol = 0; - count = TRY(m_buffer.get_bits(7)) + 11; - } - - for (uint8_t i = 0; i < count; i++) - TRY(bit_lengths.push_back(symbol)); - last_symbol = symbol; - } - - TRY(bit_lengths.resize(hlit + 32, 0)); - - BAN::Vector distance_lengths; - TRY(distance_lengths.resize(32)); - for (uint8_t i = 0; i < 32; i++) - distance_lengths[i] = bit_lengths[hlit + i]; - - TRY(bit_lengths.resize(hlit)); - TRY(bit_lengths.resize(288, 0)); - - TRY(inflate_block(TRY(HuffmanTree::create(bit_lengths)), TRY(HuffmanTree::create(distance_lengths)))); - return {}; - } - - BAN::ErrorOr inflate_block(const HuffmanTree& length_tree, const HuffmanTree& distance_tree) - { - static constexpr uint16_t length_base[] { - 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 17, 19, 23, 27, 31, 35, 43, 51, 59, 67, 83, 99, 115, 131, 163, 195, 227, 258 - }; - static constexpr uint8_t extra_length_bits[] { - 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 0 - }; - - static constexpr uint16_t distance_base[] { - 1, 2, 3, 4, 5, 7, 9, 13, 17, 25, 33, 49, 65, 97, 129, 193, 257, 385, 513, 769, 1025, 1537, 2049, 3073, 4097, 6145, 8193, 12289, 16385, 24577 - }; - static constexpr uint8_t extra_distance_bits[] { - 0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13 - }; - - uint16_t symbol; - while ((symbol = TRY(read_symbol(length_tree))) != 256) - { - if (symbol < 256) - { - TRY(m_decoded.push_back(symbol)); - continue; - } - - ASSERT(symbol <= 285); - symbol -= 257; - - const uint16_t length = length_base[symbol] + TRY(m_buffer.get_bits(extra_length_bits[symbol])); - - uint16_t distance_code; - if (distance_tree.empty()) - distance_code = reverse_bits(TRY(m_buffer.get_bits(5)), 5); - else - distance_code = TRY(read_symbol(distance_tree)); - ASSERT(distance_code <= 30); - - const size_t distance = distance_base[distance_code] + TRY(m_buffer.get_bits(extra_distance_bits[distance_code])); - - size_t offset = m_decoded.size() - distance; - for (size_t i = 0; i < length; i++) - TRY(m_decoded.push_back(m_decoded[offset + i])); - } - - return {}; - } - - BAN::ErrorOr read_symbol(const HuffmanTree& tree) - { - uint16_t compare = TRY(m_buffer.peek_bits(tree.max_bits())); - for (uint8_t bits = tree.instant_max_bit(); bits <= tree.max_bits(); bits++) - { - uint16_t mask = (1 << bits) - 1; - auto leaf = tree.get_leaf(compare & mask); - - if (leaf.len <= bits) - { - m_buffer.remove_bits(leaf.len); - return leaf.code; - } - } - return BAN::Error::from_errno(EINVAL); - } - - private: - BAN::Vector m_decoded; - BitBuffer m_buffer; - }; - BAN::ErrorOr read_and_take_chunk(BAN::ConstByteSpan& image_data) { if (image_data.size() < 12) @@ -545,7 +157,7 @@ namespace LibImage const auto extract_channel = [&](auto& bit_buffer) -> uint8_t { - uint16_t tmp = MUST(bit_buffer.get_bits(bits_per_channel)); + uint16_t tmp = MUST(bit_buffer.take_bits(bits_per_channel)); switch (bits_per_channel) { case 1: return tmp * 0xFF; @@ -576,7 +188,7 @@ namespace LibImage color.a = 0xFF; break; case ColourType::IndexedColour: - color = palette[MUST(bit_buffer.get_bits(bits_per_channel))]; + color = palette[MUST(bit_buffer.take_bits(bits_per_channel))]; break; case ColourType::GreyscaleAlpha: color.r = extract_channel(bit_buffer); @@ -620,9 +232,6 @@ namespace LibImage BAN::Vector zero_scanline; TRY(zero_scanline.resize(bytes_per_scanline, 0)); - BAN::Vector encoded_data_wrapper; - TRY(encoded_data_wrapper.push_back({})); - const uint8_t filter_offset = (bits_per_channel < 8) ? 1 : channels * (bits_per_channel / 8); for (uint64_t y = 0; y < image_height; y++) @@ -660,11 +269,9 @@ namespace LibImage return BAN::Error::from_errno(EINVAL); } - encoded_data_wrapper[0] = scanline; - BitBuffer bit_buffer(encoded_data_wrapper); - + LibDEFLATE::BitInputStream bit_stream(scanline); for (uint64_t x = 0; x < image_width; x++) - color_bitmap[y * image_width + x] = extract_color(bit_buffer); + color_bitmap[y * image_width + x] = extract_color(bit_stream); } return pitch * image_height; @@ -813,30 +420,27 @@ namespace LibImage } } - { - if (zlib_stream.empty() || zlib_stream.front().size() < 2) - { - dwarnln_if(DEBUG_PNG, "PNG does not have zlib stream"); - return BAN::Error::from_errno(EINVAL); - } - if (zlib_stream[0].as>() % 31) - { - dwarnln_if(DEBUG_PNG, "PNG zlib stream checksum failed"); - return BAN::Error::from_errno(EINVAL); - } + BAN::Vector zlib_stream_buf; + BAN::ConstByteSpan zlib_stream_span; - auto zlib_header = zlib_stream[0].as(); - if (zlib_header.fdict) + if (zlib_stream.empty()) + { + dwarnln_if(DEBUG_PNG, "PNG does not have zlib stream"); + return BAN::Error::from_errno(EINVAL); + } + + if (zlib_stream.size() == 1) + zlib_stream_span = zlib_stream.front(); + else + { + for (auto stream : zlib_stream) { - dwarnln_if(DEBUG_PNG, "PNG IDAT zlib stream has fdict set"); - return BAN::Error::from_errno(EINVAL); + const size_t old_size = zlib_stream_buf.size(); + TRY(zlib_stream_buf.resize(old_size + stream.size())); + for (size_t i = 0; i < stream.size(); i++) + zlib_stream_buf[old_size + i] = stream[i]; } - if (zlib_header.cm != 8) - { - dwarnln_if(DEBUG_PNG, "PNG IDAT has invalid zlib compression method {}", (uint8_t)zlib_header.cm); - return BAN::Error::from_errno(EINVAL); - } - zlib_stream[0] = zlib_stream[0].slice(2); + zlib_stream_span = zlib_stream_buf.span(); } uint64_t total_size = 0; @@ -844,8 +448,9 @@ namespace LibImage total_size += stream.size(); dprintln_if(DEBUG_PNG, "PNG has {} byte zlib stream", total_size); - DeflateDecoder decoder(BAN::move(zlib_stream)); - auto inflated_data = TRY(decoder.decode_stream()); + LibDEFLATE::Decompressor decompressor(zlib_stream_span, LibDEFLATE::StreamType::Zlib); + auto inflated_buffer = TRY(decompressor.decompress()); + auto inflated_data = inflated_buffer.span(); dprintln_if(DEBUG_PNG, " uncompressed size {}", inflated_data.size()); dprintln_if(DEBUG_PNG, " compression ratio {}", (double)inflated_data.size() / total_size);