Kernel: Fix most of mutex + block race conditions
All block functions now take an optional mutex parameter that is atomically unlocked, instead of requiring the caller to unlock it beforehand. This prevents a ton of race conditions everywhere in the code!
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
#include <kernel/Lock/SpinLockAsMutex.h>
|
||||
#include <kernel/Networking/ARPTable.h>
|
||||
#include <kernel/Scheduler.h>
|
||||
#include <kernel/Timer/Timer.h>
|
||||
@@ -158,16 +159,15 @@ namespace Kernel
|
||||
for (;;)
|
||||
{
|
||||
PendingArpPacket pending = ({
|
||||
auto state = m_pending_lock.lock();
|
||||
SpinLockGuard guard(m_pending_lock);
|
||||
while (m_pending_packets.empty())
|
||||
{
|
||||
m_pending_lock.unlock(state);
|
||||
m_pending_thread_blocker.block_with_timeout_ms(100);
|
||||
state = m_pending_lock.lock();
|
||||
SpinLockGuardAsMutex smutex(guard);
|
||||
m_pending_thread_blocker.block_indefinite(&smutex);
|
||||
}
|
||||
|
||||
auto packet = m_pending_packets.front();
|
||||
m_pending_packets.pop();
|
||||
m_pending_lock.unlock(state);
|
||||
|
||||
packet;
|
||||
});
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#include <kernel/Memory/Heap.h>
|
||||
#include <kernel/Memory/PageTable.h>
|
||||
#include <kernel/Lock/SpinLockAsMutex.h>
|
||||
#include <kernel/Networking/ICMP.h>
|
||||
#include <kernel/Networking/IPv4Layer.h>
|
||||
#include <kernel/Networking/NetworkManager.h>
|
||||
@@ -331,16 +332,15 @@ namespace Kernel
|
||||
for (;;)
|
||||
{
|
||||
PendingIPv4Packet pending = ({
|
||||
auto state = m_pending_lock.lock();
|
||||
SpinLockGuard guard(m_pending_lock);
|
||||
while (m_pending_packets.empty())
|
||||
{
|
||||
m_pending_lock.unlock(state);
|
||||
m_pending_thread_blocker.block_with_timeout_ms(100);
|
||||
state = m_pending_lock.lock();
|
||||
SpinLockGuardAsMutex smutex(guard);
|
||||
m_pending_thread_blocker.block_indefinite(&smutex);
|
||||
}
|
||||
|
||||
auto packet = m_pending_packets.front();
|
||||
m_pending_packets.pop();
|
||||
m_pending_lock.unlock(state);
|
||||
|
||||
packet;
|
||||
});
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
#include <kernel/Lock/SpinLockAsMutex.h>
|
||||
#include <kernel/Networking/NetworkManager.h>
|
||||
#include <kernel/Networking/RTL8169/Definitions.h>
|
||||
#include <kernel/Networking/RTL8169/RTL8169.h>
|
||||
@@ -205,13 +206,18 @@ namespace Kernel
|
||||
return BAN::Error::from_errno(EADDRNOTAVAIL);
|
||||
|
||||
auto state = m_lock.lock();
|
||||
|
||||
const uint32_t tx_current = m_tx_current;
|
||||
m_tx_current = (m_tx_current + 1) % m_tx_descriptor_count;
|
||||
m_lock.unlock(state);
|
||||
|
||||
auto& descriptor = reinterpret_cast<volatile RTL8169Descriptor*>(m_tx_descriptor_region->vaddr())[tx_current];
|
||||
while (descriptor.command & RTL8169_DESC_CMD_OWN)
|
||||
m_thread_blocker.block_with_timeout_ms(100);
|
||||
{
|
||||
SpinLockAsMutex smutex(m_lock, state);
|
||||
m_thread_blocker.block_indefinite(&smutex);
|
||||
}
|
||||
|
||||
m_lock.unlock(state);
|
||||
|
||||
auto* tx_buffer = reinterpret_cast<uint8_t*>(m_tx_buffer_region->vaddr() + tx_current * buffer_size);
|
||||
|
||||
@@ -246,7 +252,10 @@ namespace Kernel
|
||||
}
|
||||
|
||||
if (interrupt_status & RTL8169_IR_TOK)
|
||||
{
|
||||
SpinLockGuard _(m_lock);
|
||||
m_thread_blocker.unblock();
|
||||
}
|
||||
|
||||
if (interrupt_status & RTL8169_IR_RER)
|
||||
dwarnln("Rx error");
|
||||
|
||||
@@ -73,10 +73,7 @@ namespace Kernel
|
||||
return BAN::Error::from_errno(EINVAL);
|
||||
|
||||
while (m_pending_connections.empty())
|
||||
{
|
||||
LockFreeGuard _(m_mutex);
|
||||
TRY(Thread::current().block_or_eintr_or_timeout_ms(m_thread_blocker, 100, false));
|
||||
}
|
||||
TRY(Thread::current().block_or_eintr_indefinite(m_thread_blocker, &m_mutex));
|
||||
|
||||
auto connection = m_pending_connections.front();
|
||||
m_pending_connections.pop();
|
||||
@@ -111,12 +108,7 @@ namespace Kernel
|
||||
|
||||
const uint64_t wake_time_ms = SystemTimer::get().ms_since_boot() + 5000;
|
||||
while (!return_inode->m_has_connected)
|
||||
{
|
||||
if (SystemTimer::get().ms_since_boot() >= wake_time_ms)
|
||||
return BAN::Error::from_errno(ECONNABORTED);
|
||||
LockFreeGuard free(m_mutex);
|
||||
TRY(Thread::current().block_or_eintr_or_waketime_ms(return_inode->m_thread_blocker, wake_time_ms, true));
|
||||
}
|
||||
TRY(Thread::current().block_or_eintr_or_waketime_ms(return_inode->m_thread_blocker, wake_time_ms, true, &m_mutex));
|
||||
|
||||
if (address)
|
||||
{
|
||||
@@ -168,12 +160,7 @@ namespace Kernel
|
||||
|
||||
const uint64_t wake_time_ms = SystemTimer::get().ms_since_boot() + 5000;
|
||||
while (!m_has_connected)
|
||||
{
|
||||
if (SystemTimer::get().ms_since_boot() >= wake_time_ms)
|
||||
return BAN::Error::from_errno(ECONNREFUSED);
|
||||
LockFreeGuard free(m_mutex);
|
||||
TRY(Thread::current().block_or_eintr_or_waketime_ms(m_thread_blocker, wake_time_ms, true));
|
||||
}
|
||||
TRY(Thread::current().block_or_eintr_or_waketime_ms(m_thread_blocker, wake_time_ms, true, &m_mutex));
|
||||
|
||||
return {};
|
||||
}
|
||||
@@ -208,8 +195,7 @@ namespace Kernel
|
||||
{
|
||||
if (m_state != State::Established)
|
||||
return return_with_maybe_zero();
|
||||
LockFreeGuard free(m_mutex);
|
||||
TRY(Thread::current().block_or_eintr_or_timeout_ms(m_thread_blocker, 100, false));
|
||||
TRY(Thread::current().block_or_eintr_indefinite(m_thread_blocker, &m_mutex));
|
||||
}
|
||||
|
||||
const uint32_t to_recv = BAN::Math::min<uint32_t>(buffer.size(), m_recv_window.data_size);
|
||||
@@ -239,8 +225,7 @@ namespace Kernel
|
||||
{
|
||||
if (m_state != State::Established)
|
||||
return return_with_maybe_zero();
|
||||
LockFreeGuard free(m_mutex);
|
||||
TRY(Thread::current().block_or_eintr_or_timeout_ms(m_thread_blocker, 100, false));
|
||||
TRY(Thread::current().block_or_eintr_indefinite(m_thread_blocker, &m_mutex));
|
||||
}
|
||||
|
||||
const size_t to_send = BAN::Math::min<size_t>(message.size(), m_send_window.buffer->size() - m_send_window.data_size);
|
||||
@@ -519,8 +504,10 @@ namespace Kernel
|
||||
}
|
||||
auto socket = it->value;
|
||||
|
||||
LockFreeGuard _(m_mutex);
|
||||
m_mutex.unlock();
|
||||
socket->receive_packet(buffer, sender, sender_len);
|
||||
m_mutex.lock();
|
||||
|
||||
return;
|
||||
}
|
||||
break;
|
||||
@@ -660,116 +647,114 @@ namespace Kernel
|
||||
BAN::RefPtr<TCPSocket> keep_alive { this };
|
||||
this->unref();
|
||||
|
||||
LockGuard _(m_mutex);
|
||||
|
||||
while (m_process)
|
||||
{
|
||||
const uint64_t current_ms = SystemTimer::get().ms_since_boot();
|
||||
|
||||
if (m_state == State::TimeWait && current_ms >= m_time_wait_start_ms + 30'000)
|
||||
{
|
||||
LockGuard _(m_mutex);
|
||||
set_connection_as_closed();
|
||||
continue;
|
||||
}
|
||||
|
||||
if (m_state == State::TimeWait && current_ms >= m_time_wait_start_ms + 30'000)
|
||||
// This is the last instance
|
||||
if (ref_count() == 1)
|
||||
{
|
||||
if (m_state == State::Listen)
|
||||
{
|
||||
set_connection_as_closed();
|
||||
continue;
|
||||
}
|
||||
|
||||
// This is the last instance
|
||||
if (ref_count() == 1)
|
||||
if (m_state == State::Established)
|
||||
{
|
||||
if (m_state == State::Listen)
|
||||
{
|
||||
set_connection_as_closed();
|
||||
continue;
|
||||
}
|
||||
if (m_state == State::Established)
|
||||
{
|
||||
m_next_flags = FIN | ACK;
|
||||
m_next_state = State::FinWait1;
|
||||
}
|
||||
}
|
||||
|
||||
if (m_next_flags)
|
||||
{
|
||||
ASSERT(m_connection_info.has_value());
|
||||
auto* target_address = reinterpret_cast<const sockaddr*>(&m_connection_info->address);
|
||||
auto target_address_len = m_connection_info->address_len;
|
||||
if (auto ret = m_network_layer.sendto(*this, {}, target_address, target_address_len); ret.is_error())
|
||||
dwarnln("{}", ret.error());
|
||||
const bool hungup_before = has_hungup_impl();
|
||||
m_state = m_next_state;
|
||||
if (m_state == State::Established)
|
||||
m_has_connected = true;
|
||||
if (!hungup_before && has_hungup_impl())
|
||||
epoll_notify(EPOLLHUP);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (m_send_window.data_size > 0 && m_send_window.current_ack - m_send_window.has_ghost_byte > m_send_window.start_seq)
|
||||
{
|
||||
uint32_t acknowledged_bytes = m_send_window.current_ack - m_send_window.start_seq - m_send_window.has_ghost_byte;
|
||||
ASSERT(acknowledged_bytes <= m_send_window.data_size);
|
||||
|
||||
m_send_window.data_size -= acknowledged_bytes;
|
||||
m_send_window.start_seq += acknowledged_bytes;
|
||||
|
||||
if (m_send_window.data_size > 0)
|
||||
{
|
||||
auto* send_buffer = reinterpret_cast<uint8_t*>(m_send_window.buffer->vaddr());
|
||||
memmove(send_buffer, send_buffer + acknowledged_bytes, m_send_window.data_size);
|
||||
}
|
||||
|
||||
m_send_window.sent_size -= acknowledged_bytes;
|
||||
|
||||
epoll_notify(EPOLLOUT);
|
||||
|
||||
dprintln_if(DEBUG_TCP, "Target acknowledged {} bytes", acknowledged_bytes);
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
const bool should_retransmit = m_send_window.data_size > 0 && current_ms >= m_send_window.last_send_ms + retransmit_timeout_ms;
|
||||
|
||||
if (m_send_window.data_size > m_send_window.sent_size || should_retransmit)
|
||||
{
|
||||
ASSERT(m_connection_info.has_value());
|
||||
auto* target_address = reinterpret_cast<const sockaddr*>(&m_connection_info->address);
|
||||
auto target_address_len = m_connection_info->address_len;
|
||||
|
||||
const uint32_t send_base = should_retransmit ? 0 : m_send_window.sent_size;
|
||||
|
||||
const uint32_t total_send = BAN::Math::min<uint32_t>(m_send_window.data_size - send_base, m_send_window.scaled_size());
|
||||
|
||||
m_send_window.current_seq = m_send_window.start_seq;
|
||||
|
||||
auto* send_buffer = reinterpret_cast<const uint8_t*>(m_send_window.buffer->vaddr() + send_base);
|
||||
for (uint32_t i = 0; i < total_send;)
|
||||
{
|
||||
const uint32_t to_send = BAN::Math::min(total_send - i, m_send_window.mss);
|
||||
|
||||
auto message = BAN::ConstByteSpan(send_buffer + i, to_send);
|
||||
|
||||
m_next_flags = ACK;
|
||||
if (auto ret = m_network_layer.sendto(*this, message, target_address, target_address_len); ret.is_error())
|
||||
{
|
||||
dwarnln("{}", ret.error());
|
||||
break;
|
||||
}
|
||||
|
||||
dprintln_if(DEBUG_TCP, "Sent {} bytes", to_send);
|
||||
|
||||
m_send_window.sent_size += to_send;
|
||||
m_send_window.current_seq += to_send;
|
||||
i += to_send;
|
||||
}
|
||||
|
||||
m_send_window.last_send_ms = current_ms;
|
||||
|
||||
continue;
|
||||
m_next_flags = FIN | ACK;
|
||||
m_next_state = State::FinWait1;
|
||||
}
|
||||
}
|
||||
|
||||
if (m_next_flags)
|
||||
{
|
||||
ASSERT(m_connection_info.has_value());
|
||||
auto* target_address = reinterpret_cast<const sockaddr*>(&m_connection_info->address);
|
||||
auto target_address_len = m_connection_info->address_len;
|
||||
if (auto ret = m_network_layer.sendto(*this, {}, target_address, target_address_len); ret.is_error())
|
||||
dwarnln("{}", ret.error());
|
||||
const bool hungup_before = has_hungup_impl();
|
||||
m_state = m_next_state;
|
||||
if (m_state == State::Established)
|
||||
m_has_connected = true;
|
||||
if (!hungup_before && has_hungup_impl())
|
||||
epoll_notify(EPOLLHUP);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (m_send_window.data_size > 0 && m_send_window.current_ack - m_send_window.has_ghost_byte > m_send_window.start_seq)
|
||||
{
|
||||
uint32_t acknowledged_bytes = m_send_window.current_ack - m_send_window.start_seq - m_send_window.has_ghost_byte;
|
||||
ASSERT(acknowledged_bytes <= m_send_window.data_size);
|
||||
|
||||
m_send_window.data_size -= acknowledged_bytes;
|
||||
m_send_window.start_seq += acknowledged_bytes;
|
||||
|
||||
if (m_send_window.data_size > 0)
|
||||
{
|
||||
auto* send_buffer = reinterpret_cast<uint8_t*>(m_send_window.buffer->vaddr());
|
||||
memmove(send_buffer, send_buffer + acknowledged_bytes, m_send_window.data_size);
|
||||
}
|
||||
|
||||
m_send_window.sent_size -= acknowledged_bytes;
|
||||
|
||||
epoll_notify(EPOLLOUT);
|
||||
|
||||
dprintln_if(DEBUG_TCP, "Target acknowledged {} bytes", acknowledged_bytes);
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
const bool should_retransmit = m_send_window.data_size > 0 && current_ms >= m_send_window.last_send_ms + retransmit_timeout_ms;
|
||||
|
||||
if (m_send_window.data_size > m_send_window.sent_size || should_retransmit)
|
||||
{
|
||||
ASSERT(m_connection_info.has_value());
|
||||
auto* target_address = reinterpret_cast<const sockaddr*>(&m_connection_info->address);
|
||||
auto target_address_len = m_connection_info->address_len;
|
||||
|
||||
const uint32_t send_base = should_retransmit ? 0 : m_send_window.sent_size;
|
||||
|
||||
const uint32_t total_send = BAN::Math::min<uint32_t>(m_send_window.data_size - send_base, m_send_window.scaled_size());
|
||||
|
||||
m_send_window.current_seq = m_send_window.start_seq;
|
||||
|
||||
auto* send_buffer = reinterpret_cast<const uint8_t*>(m_send_window.buffer->vaddr() + send_base);
|
||||
for (uint32_t i = 0; i < total_send;)
|
||||
{
|
||||
const uint32_t to_send = BAN::Math::min(total_send - i, m_send_window.mss);
|
||||
|
||||
auto message = BAN::ConstByteSpan(send_buffer + i, to_send);
|
||||
|
||||
m_next_flags = ACK;
|
||||
if (auto ret = m_network_layer.sendto(*this, message, target_address, target_address_len); ret.is_error())
|
||||
{
|
||||
dwarnln("{}", ret.error());
|
||||
break;
|
||||
}
|
||||
|
||||
dprintln_if(DEBUG_TCP, "Sent {} bytes", to_send);
|
||||
|
||||
m_send_window.sent_size += to_send;
|
||||
m_send_window.current_seq += to_send;
|
||||
i += to_send;
|
||||
}
|
||||
|
||||
m_send_window.last_send_ms = current_ms;
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
m_thread_blocker.unblock();
|
||||
m_thread_blocker.block_with_wake_time_ms(current_ms + retransmit_timeout_ms);
|
||||
m_thread_blocker.block_with_wake_time_ms(current_ms + retransmit_timeout_ms, &m_mutex);
|
||||
}
|
||||
|
||||
m_thread_blocker.unblock();
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
#include <kernel/Lock/SpinLockAsMutex.h>
|
||||
#include <kernel/Memory/Heap.h>
|
||||
#include <kernel/Networking/UDPSocket.h>
|
||||
#include <kernel/Thread.h>
|
||||
@@ -93,12 +94,12 @@ namespace Kernel
|
||||
}
|
||||
ASSERT(m_port != PORT_NONE);
|
||||
|
||||
auto state = m_packet_lock.lock();
|
||||
SpinLockGuard guard(m_packet_lock);
|
||||
|
||||
while (m_packets.empty())
|
||||
{
|
||||
m_packet_lock.unlock(state);
|
||||
TRY(Thread::current().block_or_eintr_indefinite(m_packet_thread_blocker));
|
||||
state = m_packet_lock.lock();
|
||||
SpinLockGuardAsMutex smutex(guard);
|
||||
TRY(Thread::current().block_or_eintr_indefinite(m_packet_thread_blocker, &smutex));
|
||||
}
|
||||
|
||||
auto packet_info = m_packets.front();
|
||||
@@ -120,8 +121,6 @@ namespace Kernel
|
||||
|
||||
m_packet_total_size -= packet_info.packet_size;
|
||||
|
||||
m_packet_lock.unlock(state);
|
||||
|
||||
if (address && address_len)
|
||||
{
|
||||
if (*address_len > (socklen_t)sizeof(sockaddr_storage))
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#include <BAN/HashMap.h>
|
||||
#include <kernel/FS/VirtualFileSystem.h>
|
||||
#include <kernel/Lock/SpinLockAsMutex.h>
|
||||
#include <kernel/Networking/NetworkManager.h>
|
||||
#include <kernel/Networking/UNIX/Socket.h>
|
||||
#include <kernel/Scheduler.h>
|
||||
@@ -16,6 +17,8 @@ namespace Kernel
|
||||
|
||||
static constexpr size_t s_packet_buffer_size = 10 * PAGE_SIZE;
|
||||
|
||||
// FIXME: why is this using spinlocks instead of mutexes??
|
||||
|
||||
BAN::ErrorOr<BAN::RefPtr<UnixDomainSocket>> UnixDomainSocket::create(Socket::Type socket_type, const Socket::Info& info)
|
||||
{
|
||||
auto socket = TRY(BAN::RefPtr<UnixDomainSocket>::create(socket_type, info));
|
||||
@@ -91,13 +94,16 @@ namespace Kernel
|
||||
if (!connection_info.listening)
|
||||
return BAN::Error::from_errno(EINVAL);
|
||||
|
||||
while (connection_info.pending_connections.empty())
|
||||
TRY(Thread::current().block_or_eintr_indefinite(connection_info.pending_thread_blocker));
|
||||
|
||||
BAN::RefPtr<UnixDomainSocket> pending;
|
||||
|
||||
{
|
||||
SpinLockGuard _(connection_info.pending_lock);
|
||||
SpinLockGuard guard(connection_info.pending_lock);
|
||||
|
||||
SpinLockGuardAsMutex smutex(guard);
|
||||
while (connection_info.pending_connections.empty())
|
||||
TRY(Thread::current().block_or_eintr_indefinite(connection_info.pending_thread_blocker, &smutex));
|
||||
|
||||
pending = connection_info.pending_connections.front();
|
||||
connection_info.pending_connections.pop();
|
||||
connection_info.pending_thread_blocker.unblock();
|
||||
@@ -176,16 +182,18 @@ namespace Kernel
|
||||
for (;;)
|
||||
{
|
||||
auto& target_info = target->m_info.get<ConnectionInfo>();
|
||||
|
||||
SpinLockGuard guard(target_info.pending_lock);
|
||||
|
||||
if (target_info.pending_connections.size() < target_info.pending_connections.capacity())
|
||||
{
|
||||
SpinLockGuard _(target_info.pending_lock);
|
||||
if (target_info.pending_connections.size() < target_info.pending_connections.capacity())
|
||||
{
|
||||
MUST(target_info.pending_connections.push(this));
|
||||
target_info.pending_thread_blocker.unblock();
|
||||
break;
|
||||
}
|
||||
MUST(target_info.pending_connections.push(this));
|
||||
target_info.pending_thread_blocker.unblock();
|
||||
break;
|
||||
}
|
||||
TRY(Thread::current().block_or_eintr_indefinite(target_info.pending_thread_blocker));
|
||||
|
||||
SpinLockGuardAsMutex smutex(guard);
|
||||
TRY(Thread::current().block_or_eintr_indefinite(target_info.pending_thread_blocker, &smutex));
|
||||
}
|
||||
|
||||
target->epoll_notify(EPOLLIN);
|
||||
@@ -269,9 +277,8 @@ namespace Kernel
|
||||
auto state = m_packet_lock.lock();
|
||||
while (m_packet_sizes.full() || m_packet_size_total + packet.size() > s_packet_buffer_size)
|
||||
{
|
||||
m_packet_lock.unlock(state);
|
||||
TRY(Thread::current().block_or_eintr_indefinite(m_packet_thread_blocker));
|
||||
state = m_packet_lock.lock();
|
||||
SpinLockAsMutex smutex(m_packet_lock, state);
|
||||
TRY(Thread::current().block_or_eintr_indefinite(m_packet_thread_blocker, &smutex));
|
||||
}
|
||||
|
||||
uint8_t* packet_buffer = reinterpret_cast<uint8_t*>(m_packet_buffer->vaddr() + m_packet_size_total);
|
||||
@@ -405,9 +412,8 @@ namespace Kernel
|
||||
}
|
||||
}
|
||||
|
||||
m_packet_lock.unlock(state);
|
||||
TRY(Thread::current().block_or_eintr_indefinite(m_packet_thread_blocker));
|
||||
state = m_packet_lock.lock();
|
||||
SpinLockAsMutex smutex(m_packet_lock, state);
|
||||
TRY(Thread::current().block_or_eintr_indefinite(m_packet_thread_blocker, &smutex));
|
||||
}
|
||||
|
||||
uint8_t* packet_buffer = reinterpret_cast<uint8_t*>(m_packet_buffer->vaddr());
|
||||
|
||||
Reference in New Issue
Block a user