BAN: Fix HashSet

This commit is contained in:
2026-05-02 13:12:22 +03:00
parent 1602b195c5
commit 21a2e7fd51

View File

@@ -60,7 +60,7 @@ namespace BAN
void skip_to_valid_bucket() void skip_to_valid_bucket()
{ {
while (!m_bucket->used && !m_bucket->end) while (m_bucket->state != Bucket::USED && !m_bucket->end)
m_bucket++; m_bucket++;
if (m_bucket->end) if (m_bucket->end)
m_bucket = nullptr; m_bucket = nullptr;
@@ -83,10 +83,13 @@ namespace BAN
private: private:
struct Bucket struct Bucket
{ {
static constexpr uint8_t UNUSED = 0;
static constexpr uint8_t USED = 1;
static constexpr uint8_t REMOVED = 2;
alignas(T) uint8_t storage[sizeof(T)]; alignas(T) uint8_t storage[sizeof(T)];
hash_t hash; hash_t hash;
uint8_t used : 1; uint8_t state : 2;
uint8_t removed : 1;
uint8_t chain_start : 1; uint8_t chain_start : 1;
uint8_t end : 1; uint8_t end : 1;
@@ -148,38 +151,7 @@ namespace BAN
{ {
if (should_rehash_with_size(m_size + 1)) if (should_rehash_with_size(m_size + 1))
TRY(rehash(m_size * 2)); TRY(rehash(m_size * 2));
return insert_impl(BAN::move(value), HASH()(value));
bool first = true;
const hash_t orig_hash = HASH()(value);
for (auto hash = orig_hash;; hash = get_next_hash_in_chain(hash, orig_hash), first = false)
{
auto& bucket = m_buckets[hash & (m_capacity - 1)];
if (!first)
bucket.chain_start = false;
if (bucket.used)
{
if (!COMP()(*bucket.element(), value))
continue;
bucket.element()->~T();
new (bucket.element()) T(BAN::move(value));
}
else
{
m_removed -= bucket.removed;
bucket.used = true;
bucket.removed = false;
new (bucket.element()) T(BAN::move(value));
m_size++;
}
if (first)
bucket.chain_start = true;
bucket.hash = orig_hash;
return iterator(&bucket);
}
} }
template<detail::HashSetFindable<T, HASH, COMP> U> template<detail::HashSetFindable<T, HASH, COMP> U>
@@ -193,8 +165,7 @@ namespace BAN
{ {
auto& bucket = *it.m_bucket; auto& bucket = *it.m_bucket;
bucket.element()->~T(); bucket.element()->~T();
bucket.used = false; bucket.state = Bucket::REMOVED;
bucket.removed = true;
m_size--; m_size--;
m_removed++; m_removed++;
return iterator(&bucket); return iterator(&bucket);
@@ -218,11 +189,8 @@ namespace BAN
return; return;
for (size_type i = 0; i < m_capacity; i++) for (size_type i = 0; i < m_capacity; i++)
{ if (m_buckets[i].state == Bucket::USED)
auto& bucket = m_buckets[i]; m_buckets[i].element()->~T();
if (bucket.used)
bucket.element()->~T();
}
BAN::deallocator(m_buckets); BAN::deallocator(m_buckets);
m_buckets = nullptr; m_buckets = nullptr;
@@ -281,9 +249,9 @@ namespace BAN
for (size_type i = 0; i < old_capacity; i++) for (size_type i = 0; i < old_capacity; i++)
{ {
auto& old_bucket = old_buckets[i]; auto& old_bucket = old_buckets[i];
if (!old_bucket.used) if (old_bucket.state != Bucket::USED)
continue; continue;
MUST(insert(BAN::move(*old_bucket.element()))); insert_impl(BAN::move(*old_bucket.element()), old_bucket.hash);
old_bucket.element()->~T(); old_bucket.element()->~T();
} }
@@ -305,15 +273,66 @@ namespace BAN
for (auto hash = orig_hash;; hash = get_next_hash_in_chain(hash, orig_hash), first = false) for (auto hash = orig_hash;; hash = get_next_hash_in_chain(hash, orig_hash), first = false)
{ {
auto& bucket = m_buckets[hash & (m_capacity - 1)]; auto& bucket = m_buckets[hash & (m_capacity - 1)];
if (bucket.used && bucket.hash == orig_hash && COMP()(*bucket.element(), value)) if (bucket.state == Bucket::USED && bucket.hash == orig_hash && COMP()(*bucket.element(), value))
return const_iterator(&bucket); return const_iterator(&bucket);
if (!bucket.used && !bucket.removed) if (bucket.state == Bucket::UNUSED)
return end(); return end();
if (!first && bucket.chain_start) if (!first && bucket.chain_start)
return end(); return end();
} }
} }
iterator insert_impl(T&& value, hash_t orig_hash)
{
ASSERT(!should_rehash_with_size(m_size + 1));
Bucket* target = nullptr;
bool first = true;
for (auto hash = orig_hash;; hash = get_next_hash_in_chain(hash, orig_hash), first = false)
{
auto& bucket = m_buckets[hash & (m_capacity - 1)];
if (!first)
bucket.chain_start = false;
if (bucket.state == Bucket::USED)
{
if (bucket.hash != orig_hash || !COMP()(*bucket.element(), value))
continue;
target = &bucket;
break;
}
if (target == nullptr)
target = &bucket;
if (bucket.state == Bucket::UNUSED)
break;
}
switch (target->state)
{
case Bucket::USED:
target->element()->~T();
break;
case Bucket::REMOVED:
m_removed--;
[[fallthrough]];
case Bucket::UNUSED:
m_size++;
break;
}
target->chain_start = first && target->state == Bucket::UNUSED;
target->hash = orig_hash;
target->state = Bucket::USED;
new (target->element()) T(BAN::move(value));
return iterator(target);
}
bool should_rehash_with_size(size_type size) const bool should_rehash_with_size(size_type size) const
{ {
if (m_capacity < 16) if (m_capacity < 16)