From fe613e427439ba7890235907c703d341ef776e95 Mon Sep 17 00:00:00 2001 From: Bananymous Date: Tue, 21 Apr 2026 00:06:46 +0300 Subject: [PATCH] BAN: Rewrite HashSet Instead of representing the map as vector of linked lists which required an allocation for every insertion and deallocation for removal, we now store a single big contiguous block of memory and use hash chains to handle collisions. This intuitively feels much better although I did not run any benchmarks. --- BAN/include/BAN/HashSet.h | 484 ++++++++++++++++++++++++-------------- 1 file changed, 313 insertions(+), 171 deletions(-) diff --git a/BAN/include/BAN/HashSet.h b/BAN/include/BAN/HashSet.h index d8a5305e..3883eb38 100644 --- a/BAN/include/BAN/HashSet.h +++ b/BAN/include/BAN/HashSet.h @@ -2,198 +2,340 @@ #include #include -#include -#include #include #include -#include +#include namespace BAN { - template> - class HashSet + template + class HashSetIterator { public: - using value_type = T; - using size_type = size_t; - using iterator = IteratorDouble; - using const_iterator = ConstIteratorDouble; + HashSetIterator() = default; + + T& operator*() + { + ASSERT(m_bucket); + return *m_bucket->element(); + } + const T& operator*() const + { + ASSERT(m_bucket); + return *m_bucket->element(); + } + + T* operator->() + { + ASSERT(m_bucket); + return m_bucket->element(); + } + const T* operator->() const + { + ASSERT(m_bucket); + return m_bucket->element(); + } + + HashSetIterator& operator++() + { + ASSERT(m_bucket); + m_bucket++; + skip_to_valid_bucket(); + return *this; + } + HashSetIterator operator++(int) + { + auto temp = *this; + ++(*this); + return temp; + } + + bool operator==(HashSetIterator other) const + { + return m_bucket == other.m_bucket; + } + bool operator!=(HashSetIterator other) const + { + return m_bucket != other.m_bucket; + } + + private: + explicit HashSetIterator(Bucket* bucket) + : m_bucket(bucket) + { + if (m_bucket != nullptr) + skip_to_valid_bucket(); + } + + void 
skip_to_valid_bucket() + { + while (!m_bucket->used && !m_bucket->end) + m_bucket++; + if (m_bucket->end) + m_bucket = nullptr; + } + + private: + Bucket* m_bucket { nullptr }; + friend HashSet; + }; + + template, typename COMP = BAN::equal> + class HashSet + { + private: + struct Bucket + { + alignas(T) uint8_t storage[sizeof(T)]; + hash_t hash; + uint8_t used : 1; + uint8_t removed : 1; + uint8_t chain_start : 1; + uint8_t end : 1; + + T* element() { return reinterpret_cast(storage); } + const T* element() const { return reinterpret_cast(storage); } + }; + + public: + using value_type = T; + using size_type = size_t; + using iterator = HashSetIterator; + using const_iterator = HashSetIterator; public: HashSet() = default; - HashSet(const HashSet&); - HashSet(HashSet&&); + ~HashSet() { clear(); } - HashSet& operator=(const HashSet&); - HashSet& operator=(HashSet&&); + HashSet(const HashSet& other) { *this = other; } + HashSet& operator=(const HashSet& other) + { + clear(); - ErrorOr insert(const T&); - ErrorOr insert(T&&); - void remove(const T&); - void clear(); + MUST(reserve(other.size())); + for (auto& bucket : other) + MUST(insert(bucket)); - ErrorOr reserve(size_type); + return *this; + } - iterator begin() { return iterator(m_buckets.end(), m_buckets.begin()); } - iterator end() { return iterator(m_buckets.end(), m_buckets.end()); } - const_iterator begin() const { return const_iterator(m_buckets.end(), m_buckets.begin()); } - const_iterator end() const { return const_iterator(m_buckets.end(), m_buckets.end()); } + HashSet(HashSet&& other) { *this = BAN::move(other); } + HashSet& operator=(HashSet&& other) + { + clear(); - bool contains(const T&) const; + m_buckets = other.m_buckets; + m_capacity = other.m_capacity; + m_size = other.m_size; + m_removed = other.m_removed; - size_type size() const; - bool empty() const; + other.m_buckets = nullptr; + other.m_capacity = 0; + other.m_size = 0; + other.m_removed = 0; + + return *this; + } + + iterator begin() { 
return iterator(m_buckets); } + iterator end() { return iterator(nullptr); } + const_iterator begin() const { return const_iterator(m_buckets); } + const_iterator end() const { return const_iterator(nullptr); } + + ErrorOr insert(const T& value) + { + return insert(T(value)); + } + + ErrorOr insert(T&& value) + { + if (should_rehash_with_size(m_size + 1)) + TRY(rehash(m_size * 2)); + + bool first = true; + const hash_t orig_hash = HASH()(value); + for (auto hash = orig_hash;; hash = get_next_hash_in_chain(hash, orig_hash), first = false) + { + auto& bucket = m_buckets[hash & (m_capacity - 1)]; + + if (!first) + bucket.chain_start = false; + + if (bucket.used) + { + if (!COMP()(*bucket.element(), value)) + continue; + *bucket.element() = BAN::move(value); + } + else + { + m_removed -= bucket.removed; + bucket.used = true; + bucket.removed = false; + new (bucket.element()) T(BAN::move(value)); + m_size++; + } + + if (first) + bucket.chain_start = true; + bucket.hash = orig_hash; + + return iterator(&bucket); + } + } + + void remove(const T& value) + { + if (auto it = find(value); it != end()) + remove(it); + } + + iterator remove(iterator it) + { + auto& bucket = *it.m_bucket; + bucket.element()->~T(); + bucket.used = false; + bucket.removed = true; + m_size--; + m_removed++; + return iterator(&bucket); + } + + template + iterator find(const U& value) + { + return iterator(const_cast(find_impl(value).m_bucket)); + } + + template + const_iterator find(const U& value) const + { + return find_impl(value); + } + + void clear() + { + if (m_buckets == nullptr) + return; + + for (size_type i = 0; i < m_capacity; i++) + { + auto& bucket = m_buckets[i]; + if (bucket.used) + bucket.element()->~T(); + } + + BAN::deallocator(m_buckets); + m_buckets = nullptr; + m_capacity = 0; + m_size = 0; + m_removed = 0; + } + + ErrorOr reserve(size_type size) + { + if (should_rehash_with_size(size)) + TRY(rehash(size * 2)); + return {}; + } + + bool contains(const T& value) const + { + 
return find(value) != end(); + } + + size_type capacity() const + { + return m_capacity; + } + + size_type size() const + { + return m_size; + } + + bool empty() const + { + return m_size == 0; + } private: - ErrorOr rebucket(size_type); - LinkedList& get_bucket(const T&); - const LinkedList& get_bucket(const T&) const; + ErrorOr rehash(size_type new_capacity) + { + new_capacity = BAN::Math::max(16, BAN::Math::max(new_capacity, m_size + 1)); + new_capacity = BAN::Math::round_up_to_power_of_two(new_capacity); + + void* new_buckets = BAN::allocator((new_capacity + 1) * sizeof(Bucket)); + if (new_buckets == nullptr) + return BAN::Error::from_errno(ENOMEM); + memset(new_buckets, 0, (new_capacity + 1) * sizeof(Bucket)); + + Bucket* old_buckets = m_buckets; + const size_type old_capacity = m_capacity; + + m_buckets = static_cast(new_buckets); + m_capacity = new_capacity; + m_size = 0; + m_removed = 0; + + for (size_type i = 0; i < old_capacity; i++) + { + auto& old_bucket = old_buckets[i]; + if (!old_bucket.used) + continue; + MUST(insert(BAN::move(*old_bucket.element()))); + old_bucket.element()->~T(); + } + + m_buckets[m_capacity].end = true; + + BAN::deallocator(old_buckets); + + return {}; + } + + template requires requires(const T& a, const U& b) { COMP()(a, b); HASH()(b); } + const_iterator find_impl(const U& value) const + { + if (m_capacity == 0) + return end(); + + bool first = true; + const hash_t orig_hash = HASH()(value); + for (auto hash = orig_hash;; hash = get_next_hash_in_chain(hash, orig_hash), first = false) + { + auto& bucket = m_buckets[hash & (m_capacity - 1)]; + if (bucket.used && bucket.hash == orig_hash && COMP()(*bucket.element(), value)) + return const_iterator(&bucket); + if (!bucket.used && !bucket.removed) + return end(); + if (!first && bucket.chain_start) + return end(); + } + } + + bool should_rehash_with_size(size_type size) const + { + if (m_capacity < 16) + return true; + if (size + m_removed > m_capacity / 4 * 3) + return true; + 
return false; + } + + hash_t get_next_hash_in_chain(hash_t prev_hash, hash_t orig_hash) const + { + // TODO: does this even provide better performance than `return prev_hash + 1` + // when using "good" hash functions + return prev_hash * 1103515245 + (orig_hash | 1); + } private: - Vector> m_buckets; - size_type m_size = 0; + Bucket* m_buckets { nullptr }; + size_type m_capacity { 0 }; + size_type m_size { 0 }; + size_type m_removed { 0 }; }; - template - HashSet::HashSet(const HashSet& other) - : m_buckets(other.m_buckets) - , m_size(other.m_size) - { - } - - template - HashSet::HashSet(HashSet&& other) - : m_buckets(move(other.m_buckets)) - , m_size(other.m_size) - { - other.clear(); - } - - template - HashSet& HashSet::operator=(const HashSet& other) - { - clear(); - m_buckets = other.m_buckets; - m_size = other.m_size; - return *this; - } - - template - HashSet& HashSet::operator=(HashSet&& other) - { - clear(); - m_buckets = move(other.m_buckets); - m_size = other.m_size; - other.clear(); - return *this; - } - - template - ErrorOr HashSet::insert(const T& key) - { - return insert(move(T(key))); - } - - template - ErrorOr HashSet::insert(T&& key) - { - if (!empty() && get_bucket(key).contains(key)) - return {}; - - TRY(rebucket(m_size + 1)); - TRY(get_bucket(key).push_back(move(key))); - m_size++; - return {}; - } - - template - void HashSet::remove(const T& key) - { - if (empty()) return; - auto& bucket = get_bucket(key); - for (auto it = bucket.begin(); it != bucket.end(); it++) - { - if (*it == key) - { - bucket.remove(it); - m_size--; - break; - } - } - } - - template - void HashSet::clear() - { - m_buckets.clear(); - m_size = 0; - } - - template - ErrorOr HashSet::reserve(size_type size) - { - TRY(rebucket(size)); - return {}; - } - - template - bool HashSet::contains(const T& key) const - { - if (empty()) return false; - return get_bucket(key).contains(key); - } - - template - typename HashSet::size_type HashSet::size() const - { - return m_size; - } - - 
template - bool HashSet::empty() const - { - return m_size == 0; - } - - template - ErrorOr HashSet::rebucket(size_type bucket_count) - { - if (m_buckets.size() >= bucket_count) - return {}; - - size_type new_bucket_count = Math::max(bucket_count, m_buckets.size() * 2); - Vector> new_buckets; - if (new_buckets.resize(new_bucket_count).is_error()) - return Error::from_errno(ENOMEM); - - for (auto& bucket : m_buckets) - { - for (auto it = bucket.begin(); it != bucket.end();) - { - size_type new_bucket_index = HASH()(*it) % new_buckets.size(); - it = bucket.move_element_to_other_linked_list(new_buckets[new_bucket_index], new_buckets[new_bucket_index].end(), it); - } - } - - m_buckets = move(new_buckets); - return {}; - } - - template - LinkedList& HashSet::get_bucket(const T& key) - { - ASSERT(!m_buckets.empty()); - size_type index = HASH()(key) % m_buckets.size(); - return m_buckets[index]; - } - - template - const LinkedList& HashSet::get_bucket(const T& key) const - { - ASSERT(!m_buckets.empty()); - size_type index = HASH()(key) % m_buckets.size(); - return m_buckets[index]; - } - }