Kernel: Rework socket binding to an address

Sockets are no longer bound to an interface, but to an IPv4 address. This
allows servers at 0.0.0.0 to talk to multiple different interfaces.
This commit is contained in:
2025-12-30 16:11:06 +02:00
parent efdbd1576f
commit f06e5d33e7
7 changed files with 263 additions and 94 deletions

View File

@@ -75,35 +75,58 @@ namespace Kernel
m_bound_sockets.remove(it);
}
BAN::ErrorOr<void> IPv4Layer::bind_socket_to_unused(BAN::RefPtr<NetworkSocket> socket, const sockaddr* address, socklen_t address_len)
BAN::ErrorOr<in_port_t> IPv4Layer::find_free_port()
{
if (!address || address_len < (socklen_t)sizeof(sockaddr_in))
return BAN::Error::from_errno(EINVAL);
if (address->sa_family != AF_INET)
return BAN::Error::from_errno(EAFNOSUPPORT);
auto& sockaddr_in = *reinterpret_cast<const struct sockaddr_in*>(address);
SpinLockGuard _(m_bound_socket_lock);
uint16_t port = NetworkSocket::PORT_NONE;
for (uint32_t i = 0; i < 100 && port == NetworkSocket::PORT_NONE; i++)
if (uint32_t temp = 0xC000 | (Random::get_u32() & 0x3FFF); !m_bound_sockets.contains(temp))
port = temp;
for (uint32_t temp = 0xC000; temp < 0xFFFF && port == NetworkSocket::PORT_NONE; temp++)
if (!m_bound_sockets.contains(temp))
port = temp;
if (port == NetworkSocket::PORT_NONE)
{
dwarnln("No ports available");
return BAN::Error::from_errno(EAGAIN);
}
dprintln_if(DEBUG_IPV4, "using port {}", port);
for (uint32_t i = 0; i < 100; i++)
if (uint32_t port = 0xC000 | (Random::get_u32() & 0x3FFF); !m_bound_sockets.contains(port))
return port;
struct sockaddr_in target;
target.sin_family = AF_INET;
target.sin_port = BAN::host_to_network_endian(port);
target.sin_addr.s_addr = sockaddr_in.sin_addr.s_addr;
return bind_socket_to_address(socket, (sockaddr*)&target, sizeof(sockaddr_in));
for (uint32_t port = 0xC000; port < 0xFFFF; port++)
if (!m_bound_sockets.contains(port))
return port;
dwarnln("No ports available");
return BAN::Error::from_errno(EAGAIN);
}
// Bind an unbound socket so that it can talk to `target`.
//
// An interface is picked from the destination address: the first interface whose
// subnet contains the target, otherwise the first non-loopback interface. The socket
// is then bound to that interface's IPv4 address and a port from find_free_port()
// by delegating to bind_socket_to_address().
//
// Errors:
//   EINVAL        `target` is null or `target_len` is smaller than sockaddr_in
//   EAFNOSUPPORT  `target` is not an AF_INET address
//   EHOSTUNREACH  no interface shares a subnet with the target and there is no
//                 non-loopback interface to fall back to
//   anything propagated from find_free_port() or bind_socket_to_address()
BAN::ErrorOr<void> IPv4Layer::bind_socket_with_target(BAN::RefPtr<NetworkSocket> socket, const sockaddr* target, socklen_t target_len)
{
	if (!target || target_len < (socklen_t)sizeof(sockaddr_in))
		return BAN::Error::from_errno(EINVAL);
	if (target->sa_family != AF_INET)
		return BAN::Error::from_errno(EAFNOSUPPORT);

	const auto& target_in = *reinterpret_cast<const sockaddr_in*>(target);

	auto interface =
		TRY([&target_in]() -> BAN::ErrorOr<BAN::RefPtr<NetworkInterface>> {
			const auto ipv4 = BAN::IPv4Address { target_in.sin_addr.s_addr };

			// try to find an interface in the same subnet
			const auto& all_interfaces = NetworkManager::get().interfaces();
			for (const auto& interface : all_interfaces)
			{
				const auto netmask = interface->get_netmask();
				if (ipv4.mask(netmask) == interface->get_ipv4_address().mask(netmask))
					return interface;
			}

			// fallback to non-loopback interface
			// FIXME: make sure target is reachable
			for (const auto& interface : all_interfaces)
				if (interface->type() != NetworkInterface::Type::Loopback)
					return interface;

			return BAN::Error::from_errno(EHOSTUNREACH);
		}());

	// FIXME: race condition with port allocation/binding
	const auto port = TRY(find_free_port());

	// Initialise every member (including sin_zero): this structure is handed to
	// bind_socket_to_address(), so no uninitialised stack bytes may remain in it.
	const sockaddr_in bind_address {
		.sin_family = AF_INET,
		.sin_port = BAN::host_to_network_endian(port),
		.sin_addr = { .s_addr = interface->get_ipv4_address().raw },
		.sin_zero = {},
	};

	return bind_socket_to_address(socket, reinterpret_cast<const sockaddr*>(&bind_address), sizeof(bind_address));
}
BAN::ErrorOr<void> IPv4Layer::bind_socket_to_address(BAN::RefPtr<NetworkSocket> socket, const sockaddr* address, socklen_t address_len)
@@ -114,33 +137,47 @@ namespace Kernel
return BAN::Error::from_errno(EAFNOSUPPORT);
auto& sockaddr_in = *reinterpret_cast<const struct sockaddr_in*>(address);
const uint16_t port = BAN::host_to_network_endian(sockaddr_in.sin_port);
if (port == NetworkSocket::PORT_NONE)
return bind_socket_to_unused(socket, address, address_len);
const auto ipv4 = BAN::IPv4Address { sockaddr_in.sin_addr.s_addr };
BAN::RefPtr<NetworkInterface> bind_interface;
for (auto interface : NetworkManager::get().interfaces())
{
if (interface->type() != NetworkInterface::Type::Loopback)
bind_interface = interface;
const auto netmask = interface->get_netmask();
if (ipv4.mask(netmask) != interface->get_ipv4_address().mask(netmask))
continue;
bind_interface = interface;
break;
}
TRY([&sockaddr_in]() -> BAN::ErrorOr<void> {
const auto ipv4 = BAN::IPv4Address { sockaddr_in.sin_addr.s_addr };
if (ipv4 == 0)
return {};
const auto& all_interfaces = NetworkManager::get().interfaces();
for (const auto& interface : all_interfaces)
{
switch (interface->type())
{
case NetworkInterface::Type::Ethernet:
if (ipv4 == interface->get_ipv4_address())
return {};
break;
case NetworkInterface::Type::Loopback:
const auto netmask = interface->get_netmask();
if (ipv4.mask(netmask) == interface->get_ipv4_address().mask(netmask))
return {};
break;
}
}
if (!bind_interface)
return BAN::Error::from_errno(EADDRNOTAVAIL);
}());
struct sockaddr_in bind_address;
memcpy(&bind_address, address, sizeof(sockaddr_in));
SpinLockGuard _(m_bound_socket_lock);
if (bind_address.sin_port == 0)
bind_address.sin_port = TRY(find_free_port());
const uint16_t port = BAN::host_to_network_endian(bind_address.sin_port);
if (m_bound_sockets.contains(port))
return BAN::Error::from_errno(EADDRINUSE);
TRY(m_bound_sockets.insert(port, TRY(socket->get_weak_ptr())));
socket->bind_interface_and_port(bind_interface.ptr(), port);
socket->bind_address_and_port(reinterpret_cast<struct sockaddr*>(&bind_address), sizeof(bind_address));
return {};
}
@@ -173,18 +210,36 @@ namespace Kernel
return BAN::Error::from_errno(EINVAL);
if (address == nullptr || address_len != sizeof(sockaddr_in))
return BAN::Error::from_errno(EINVAL);
auto& sockaddr_in = *reinterpret_cast<const struct sockaddr_in*>(address);
auto interface = TRY(socket.interface(address, address_len));
auto& sockaddr_in = *reinterpret_cast<const struct sockaddr_in*>(address);
auto dst_port = BAN::host_to_network_endian(sockaddr_in.sin_port);
auto dst_ipv4 = BAN::IPv4Address { sockaddr_in.sin_addr.s_addr };
auto dst_mac = TRY(m_arp_table->get_mac_from_ipv4(socket.interface(), dst_ipv4));
auto dst_mac = TRY(m_arp_table->get_mac_from_ipv4(*interface, dst_ipv4));
if (interface->type() == NetworkInterface::Type::Loopback)
{
BAN::RefPtr<NetworkSocket> receiver;
{
SpinLockGuard _(m_bound_socket_lock);
auto receiver_it = m_bound_sockets.find(dst_port);
if (receiver_it != m_bound_sockets.end())
receiver = receiver_it->value.lock();
}
if (!receiver)
return BAN::Error::from_errno(EADDRNOTAVAIL);
TRY(socket.interface(receiver->address(), receiver->address_len()));
}
BAN::Vector<uint8_t> packet_buffer;
TRY(packet_buffer.resize(buffer.size() + sizeof(IPv4Header) + socket.protocol_header_size()));
auto packet = BAN::ByteSpan { packet_buffer.span() };
auto pseudo_header = PseudoHeader {
.src_ipv4 = socket.interface().get_ipv4_address(),
.src_ipv4 = interface->get_ipv4_address(),
.dst_ipv4 = dst_ipv4,
.protocol = socket.protocol()
};
@@ -201,12 +256,12 @@ namespace Kernel
);
add_ipv4_header(
packet,
socket.interface().get_ipv4_address(),
interface->get_ipv4_address(),
dst_ipv4,
socket.protocol()
);
TRY(socket.interface().send_bytes(dst_mac, EtherType::IPv4, packet));
TRY(interface->send_bytes(dst_mac, EtherType::IPv4, packet));
return buffer.size();
}

View File

@@ -15,12 +15,97 @@ namespace Kernel
{
}
void NetworkSocket::bind_interface_and_port(NetworkInterface* interface, uint16_t port)
bool NetworkSocket::can_interface_send_to(const NetworkInterface& interface, const sockaddr* target, socklen_t target_len) const
{
ASSERT(!m_interface);
ASSERT(interface);
m_interface = interface;
m_port = port;
ASSERT(target);
ASSERT(target_len >= static_cast<socklen_t>(sizeof(sockaddr_in)));
ASSERT(target->sa_family == AF_INET);
const auto target_ipv4 = BAN::IPv4Address {
reinterpret_cast<const sockaddr_in*>(target)->sin_addr.s_addr
};
switch (interface.type())
{
case NetworkInterface::Type::Ethernet:
// FIXME: this is not really correct :D
return target_ipv4.octets[0] != IN_LOOPBACKNET;
case NetworkInterface::Type::Loopback:
return target_ipv4.octets[0] == IN_LOOPBACKNET;
}
ASSERT_NOT_REACHED();
}
// Select the network interface this bound INET socket should use.
//
// `target` is optional; when given it must be an AF_INET address of at least
// sizeof(sockaddr_in) bytes (asserted).
//
// Selection order:
//   1. Socket bound to a non-zero address: the first interface whose subnet contains
//      the bound address and, if a target was given, that can send to it.
//      Fails with EADDRNOTAVAIL when there is none.
//   2. Socket bound to address 0 and a target was given: the first interface whose
//      subnet contains the target.
//   3. Otherwise the first non-loopback interface, then the first loopback interface,
//      skipping interfaces that cannot send to the target (if any).
//      Fails with EHOSTUNREACH when nothing qualifies.
BAN::ErrorOr<BAN::RefPtr<NetworkInterface>> NetworkSocket::interface(const sockaddr* target, socklen_t target_len)
{
	ASSERT(m_network_layer.domain() == NetworkSocket::Domain::INET);
	ASSERT(is_bound());

	const bool has_target = (target != nullptr);
	if (has_target)
	{
		ASSERT(target_len >= static_cast<socklen_t>(sizeof(sockaddr_in)));
		ASSERT(target->sa_family == AF_INET);
	}

	// true when there is no target, or when `candidate` is able to send to it
	const auto reaches_target =
		[&](const auto& candidate) -> bool
		{
			return !has_target || can_interface_send_to(*candidate, target, target_len);
		};

	// true when `address` equals the candidate's own address under the candidate's netmask
	const auto shares_subnet =
		[](const auto& candidate, const BAN::IPv4Address& address) -> bool
		{
			const auto netmask = candidate->get_netmask();
			return address.mask(netmask) == candidate->get_ipv4_address().mask(netmask);
		};

	const auto& all_interfaces = NetworkManager::get().interfaces();

	const auto bound_ipv4 = BAN::IPv4Address {
		reinterpret_cast<const sockaddr_in*>(&m_address)->sin_addr.s_addr
	};

	// socket is bound to a specific address, only its interface is acceptable
	if (bound_ipv4 != 0)
	{
		for (const auto& candidate : all_interfaces)
			if (shares_subnet(candidate, bound_ipv4) && reaches_target(candidate))
				return candidate;
		return BAN::Error::from_errno(EADDRNOTAVAIL);
	}

	// socket is bound to any address, prefer an interface in the target's subnet
	if (has_target)
	{
		const auto target_ipv4 = BAN::IPv4Address {
			reinterpret_cast<const sockaddr_in*>(target)->sin_addr.s_addr
		};

		for (const auto& candidate : all_interfaces)
			if (shares_subnet(candidate, target_ipv4))
				return candidate;
	}

	// otherwise take any usable interface, non-loopback ones first
	for (const auto& candidate : all_interfaces)
		if (candidate->type() != NetworkInterface::Type::Loopback && reaches_target(candidate))
			return candidate;

	for (const auto& candidate : all_interfaces)
		if (candidate->type() == NetworkInterface::Type::Loopback && reaches_target(candidate))
			return candidate;

	return BAN::Error::from_errno(EHOSTUNREACH);
}
// Record `addr` as the address this socket is bound to.
//
// The first `addr_len` bytes of `addr` are copied into m_address and the length is
// stored in m_address_len. No port is handled separately here; only the raw address
// bytes are copied.
//
// Preconditions (asserted): the socket is not bound yet, `addr` has an address family
// other than AF_UNSPEC, and `addr_len` does not exceed sizeof(sockaddr_storage).
void NetworkSocket::bind_address_and_port(const sockaddr* addr, socklen_t addr_len)
{
ASSERT(!is_bound());
ASSERT(addr->sa_family != AF_UNSPEC);
// the copy below must not write more than sizeof(sockaddr_storage) bytes
ASSERT(addr_len <= static_cast<socklen_t>(sizeof(sockaddr_storage)));
memcpy(&m_address, addr, addr_len);
m_address_len = addr_len;
}
BAN::ErrorOr<long> NetworkSocket::ioctl_impl(int request, void* arg)
@@ -30,12 +115,8 @@ namespace Kernel
dprintln("No argument provided");
return BAN::Error::from_errno(EINVAL);
}
if (m_interface == nullptr)
{
dprintln("No interface bound");
return BAN::Error::from_errno(EADDRNOTAVAIL);
}
auto interface = TRY(this->interface(nullptr, 0));
auto* ifreq = reinterpret_cast<struct ifreq*>(arg);
switch (request)
@@ -44,7 +125,7 @@ namespace Kernel
{
auto& ifru_addr = *reinterpret_cast<sockaddr_in*>(&ifreq->ifr_ifru.ifru_addr);
ifru_addr.sin_family = AF_INET;
ifru_addr.sin_addr.s_addr = m_interface->get_ipv4_address().raw;
ifru_addr.sin_addr.s_addr = interface->get_ipv4_address().raw;
return 0;
}
case SIOCSIFADDR:
@@ -52,15 +133,15 @@ namespace Kernel
auto& ifru_addr = *reinterpret_cast<const sockaddr_in*>(&ifreq->ifr_ifru.ifru_addr);
if (ifru_addr.sin_family != AF_INET)
return BAN::Error::from_errno(EADDRNOTAVAIL);
m_interface->set_ipv4_address(BAN::IPv4Address { ifru_addr.sin_addr.s_addr });
dprintln("IPv4 address set to {}", m_interface->get_ipv4_address());
interface->set_ipv4_address(BAN::IPv4Address { ifru_addr.sin_addr.s_addr });
dprintln("IPv4 address set to {}", interface->get_ipv4_address());
return 0;
}
case SIOCGIFNETMASK:
{
auto& ifru_netmask = *reinterpret_cast<sockaddr_in*>(&ifreq->ifr_ifru.ifru_netmask);
ifru_netmask.sin_family = AF_INET;
ifru_netmask.sin_addr.s_addr = m_interface->get_netmask().raw;
ifru_netmask.sin_addr.s_addr = interface->get_netmask().raw;
return 0;
}
case SIOCSIFNETMASK:
@@ -68,15 +149,15 @@ namespace Kernel
auto& ifru_netmask = *reinterpret_cast<const sockaddr_in*>(&ifreq->ifr_ifru.ifru_netmask);
if (ifru_netmask.sin_family != AF_INET)
return BAN::Error::from_errno(EADDRNOTAVAIL);
m_interface->set_netmask(BAN::IPv4Address { ifru_netmask.sin_addr.s_addr });
dprintln("Netmask set to {}", m_interface->get_netmask());
interface->set_netmask(BAN::IPv4Address { ifru_netmask.sin_addr.s_addr });
dprintln("Netmask set to {}", interface->get_netmask());
return 0;
}
case SIOCGIFGWADDR:
{
auto& ifru_gwaddr = *reinterpret_cast<sockaddr_in*>(&ifreq->ifr_ifru.ifru_gwaddr);
ifru_gwaddr.sin_family = AF_INET;
ifru_gwaddr.sin_addr.s_addr = m_interface->get_gateway().raw;
ifru_gwaddr.sin_addr.s_addr = interface->get_gateway().raw;
return 0;
}
case SIOCSIFGWADDR:
@@ -84,13 +165,13 @@ namespace Kernel
auto& ifru_gwaddr = *reinterpret_cast<const sockaddr_in*>(&ifreq->ifr_ifru.ifru_gwaddr);
if (ifru_gwaddr.sin_family != AF_INET)
return BAN::Error::from_errno(EADDRNOTAVAIL);
m_interface->set_gateway(BAN::IPv4Address { ifru_gwaddr.sin_addr.s_addr });
dprintln("Gateway set to {}", m_interface->get_gateway());
interface->set_gateway(BAN::IPv4Address { ifru_gwaddr.sin_addr.s_addr });
dprintln("Gateway set to {}", interface->get_gateway());
return 0;
}
case SIOCGIFHWADDR:
{
auto mac_address = m_interface->get_mac_address();
auto mac_address = interface->get_mac_address();
ifreq->ifr_ifru.ifru_hwaddr.sa_family = AF_INET;
memcpy(ifreq->ifr_ifru.ifru_hwaddr.sa_data, &mac_address, sizeof(mac_address));
return 0;
@@ -98,9 +179,9 @@ namespace Kernel
case SIOCGIFNAME:
{
auto& ifrn_name = ifreq->ifr_ifrn.ifrn_name;
ASSERT(m_interface->name().size() < sizeof(ifrn_name));
memcpy(ifrn_name, m_interface->name().data(), m_interface->name().size());
ifrn_name[m_interface->name().size()] = '\0';
ASSERT(interface->name().size() < sizeof(ifrn_name));
memcpy(ifrn_name, interface->name().data(), interface->name().size());
ifrn_name[interface->name().size()] = '\0';
return 0;
}
default:

View File

@@ -95,8 +95,8 @@ namespace Kernel
}
return_inode->m_mutex.lock();
return_inode->m_port = m_port;
return_inode->m_interface = m_interface;
memcpy(&return_inode->m_address, &connection.target.address, connection.target.address_len);
return_inode->m_address_len = connection.target.address_len;
return_inode->m_listen_parent = this;
return_inode->m_connection_info.emplace(connection.target);
return_inode->m_recv_window.start_seq = connection.target_start_seq;
@@ -152,13 +152,18 @@ namespace Kernel
};
if (!is_bound())
TRY(m_network_layer.bind_socket_to_unused(this, address, address_len));
TRY(m_network_layer.bind_socket_with_target(this, address, address_len));
m_connection_info.emplace(sockaddr_storage {}, address_len, true);
memcpy(&m_connection_info->address, address, address_len);
m_next_flags = SYN;
TRY(m_network_layer.sendto(*this, {}, address, address_len));
if (m_network_layer.sendto(*this, {}, address, address_len).is_error())
{
set_connection_as_closed();
return BAN::Error::from_errno(ECONNREFUSED);
}
m_next_flags = 0;
m_state = State::SynSent;
@@ -410,8 +415,8 @@ namespace Kernel
memset(&header, 0, sizeof(TCPHeader));
memset(header.options, TCPOption::End, m_tcp_options_bytes);
header.src_port = bound_port();
header.dst_port = dst_port;
header.src_port = m_port;
header.seq_number = m_send_window.current_seq + m_send_window.has_ghost_byte;
header.ack_number = m_recv_window.start_seq + m_recv_window.data_size + m_recv_window.has_ghost_byte;
header.data_offset = (sizeof(TCPHeader) + m_tcp_options_bytes) / sizeof(uint32_t);
@@ -423,7 +428,15 @@ namespace Kernel
if (m_state == State::Closed || m_state == State::SynReceived)
{
add_tcp_header_option<0, TCPOption::MaximumSeqmentSize>(header, m_interface->payload_mtu() - m_network_layer.header_size());
const sockaddr_in target {
.sin_family = AF_INET,
.sin_port = dst_port,
.sin_addr = { .s_addr = pseudo_header.dst_ipv4.raw },
.sin_zero = {},
};
auto interface = MUST(this->interface(reinterpret_cast<const sockaddr*>(&target), sizeof(target)));
add_tcp_header_option<0, TCPOption::MaximumSeqmentSize>(header, interface->payload_mtu() - m_network_layer.header_size());
if (m_connection_info->has_window_scale)
add_tcp_header_option<4, TCPOption::WindowScale>(header, m_recv_window.scale_shift);
@@ -451,11 +464,16 @@ namespace Kernel
if (sender->sa_family == AF_INET)
{
auto interface_or_error = interface(sender, sender_len);
if (interface_or_error.is_error())
return;
auto interface = interface_or_error.release_value();
auto& addr_in = *reinterpret_cast<const sockaddr_in*>(sender);
checksum = calculate_internet_checksum(buffer,
PseudoHeader {
.src_ipv4 = BAN::IPv4Address(addr_in.sin_addr.s_addr),
.dst_ipv4 = m_interface->get_ipv4_address(),
.dst_ipv4 = interface->get_ipv4_address(),
.protocol = NetworkProtocol::TCP,
.extra = buffer.size()
}
@@ -663,11 +681,11 @@ namespace Kernel
// NOTE: Only listen socket can unbind the socket as
// listen socket is always alive to redirect packets
if (!m_listen_parent)
m_network_layer.unbind_socket(m_port);
m_network_layer.unbind_socket(bound_port());
else
m_listen_parent->remove_listen_child(this);
m_interface = nullptr;
m_port = PORT_NONE;
m_address.ss_family = AF_UNSPEC;
m_address_len = 0;
dprintln_if(DEBUG_TCP, "Socket unbound");
}

View File

@@ -30,15 +30,15 @@ namespace Kernel
UDPSocket::~UDPSocket()
{
if (is_bound())
m_network_layer.unbind_socket(m_port);
m_port = PORT_NONE;
m_interface = nullptr;
m_network_layer.unbind_socket(bound_port());
m_address.ss_family = AF_UNSPEC;
m_address_len = 0;
}
void UDPSocket::add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port, PseudoHeader)
{
auto& header = packet.as<UDPHeader>();
header.src_port = m_port;
header.src_port = bound_port();
header.dst_port = dst_port;
header.length = packet.size();
header.checksum = 0;
@@ -115,7 +115,6 @@ namespace Kernel
dprintln("No interface bound");
return BAN::Error::from_errno(EINVAL);
}
ASSERT(m_port != PORT_NONE);
SpinLockGuard guard(m_packet_lock);
@@ -176,7 +175,7 @@ namespace Kernel
dwarnln("ignoring sendmsg control message");
if (!is_bound())
TRY(m_network_layer.bind_socket_to_unused(this, static_cast<sockaddr*>(message.msg_name), message.msg_namelen));
TRY(m_network_layer.bind_socket_with_target(this, static_cast<sockaddr*>(message.msg_name), message.msg_namelen));
const size_t total_send_size =
[&message]() -> size_t {