diff --git a/kernel/include/kernel/Networking/IPv4Layer.h b/kernel/include/kernel/Networking/IPv4Layer.h index 9f3f51e4..b0a0bc8e 100644 --- a/kernel/include/kernel/Networking/IPv4Layer.h +++ b/kernel/include/kernel/Networking/IPv4Layer.h @@ -45,7 +45,7 @@ namespace Kernel void add_ipv4_packet(NetworkInterface&, BAN::ConstByteSpan); virtual void unbind_socket(uint16_t port) override; - virtual BAN::ErrorOr bind_socket_to_unused(BAN::RefPtr, const sockaddr* send_address, socklen_t send_address_len) override; + virtual BAN::ErrorOr bind_socket_with_target(BAN::RefPtr, const sockaddr* target_address, socklen_t target_address_len) override; virtual BAN::ErrorOr bind_socket_to_address(BAN::RefPtr, const sockaddr* address, socklen_t address_len) override; virtual BAN::ErrorOr get_socket_address(BAN::RefPtr, sockaddr* address, socklen_t* address_len) override; @@ -59,6 +59,8 @@ namespace Kernel void add_ipv4_header(BAN::ByteSpan packet, BAN::IPv4Address src_ipv4, BAN::IPv4Address dst_ipv4, uint8_t protocol) const; + BAN::ErrorOr find_free_port(); + void packet_handle_task(); BAN::ErrorOr handle_ipv4_packet(NetworkInterface&, BAN::ByteSpan); diff --git a/kernel/include/kernel/Networking/NetworkLayer.h b/kernel/include/kernel/Networking/NetworkLayer.h index 299c7477..54183603 100644 --- a/kernel/include/kernel/Networking/NetworkLayer.h +++ b/kernel/include/kernel/Networking/NetworkLayer.h @@ -23,7 +23,7 @@ namespace Kernel virtual ~NetworkLayer() {} virtual void unbind_socket(uint16_t port) = 0; - virtual BAN::ErrorOr bind_socket_to_unused(BAN::RefPtr, const sockaddr* send_address, socklen_t send_address_len) = 0; + virtual BAN::ErrorOr bind_socket_with_target(BAN::RefPtr, const sockaddr* target_address, socklen_t target_address_len) = 0; virtual BAN::ErrorOr bind_socket_to_address(BAN::RefPtr, const sockaddr* address, socklen_t address_len) = 0; virtual BAN::ErrorOr get_socket_address(BAN::RefPtr, sockaddr* address, socklen_t* address_len) = 0; diff --git a/kernel/include/kernel/Networking/NetworkSocket.h b/kernel/include/kernel/Networking/NetworkSocket.h index 66175a53..6c33fb0a 100644 --- a/kernel/include/kernel/Networking/NetworkSocket.h +++ b/kernel/include/kernel/Networking/NetworkSocket.h @@ -5,6 +5,8 @@ #include #include +#include + namespace Kernel { @@ -24,10 +26,10 @@ namespace Kernel static constexpr uint16_t PORT_NONE = 0; public: - void bind_interface_and_port(NetworkInterface*, uint16_t port); + void bind_address_and_port(const sockaddr*, socklen_t); ~NetworkSocket(); - NetworkInterface& interface() { ASSERT(m_interface); return *m_interface; } + BAN::ErrorOr> interface(const sockaddr* target, socklen_t target_len); virtual size_t protocol_header_size() const = 0; virtual void add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port, PseudoHeader) = 0; @@ -35,7 +37,19 @@ namespace Kernel virtual void receive_packet(BAN::ConstByteSpan, const sockaddr* sender, socklen_t sender_len) = 0; - bool is_bound() const { return m_interface != nullptr; } + bool is_bound() const { return m_address_len >= static_cast(sizeof(sa_family_t)) && m_address.ss_family != AF_UNSPEC; } + in_port_t bound_port() const + { + ASSERT(is_bound()); + ASSERT(m_address.ss_family == AF_INET && m_address_len >= static_cast(sizeof(sockaddr_in))); + return BAN::network_endian_to_host(reinterpret_cast(&m_address)->sin_port); + } + + const sockaddr* address() const { return reinterpret_cast(&m_address); } + socklen_t address_len() const { return m_address_len; } + + private: + bool can_interface_send_to(const NetworkInterface&, const sockaddr*, socklen_t) const; protected: NetworkSocket(NetworkLayer&, const Socket::Info&); @@ -45,9 +59,9 @@ namespace Kernel virtual BAN::ErrorOr getpeername_impl(sockaddr*, socklen_t*) override = 0; protected: - NetworkLayer& m_network_layer; - NetworkInterface* m_interface = nullptr; - uint16_t m_port { PORT_NONE }; + NetworkLayer& m_network_layer; + sockaddr_storage m_address { .ss_family = AF_UNSPEC, .ss_storage = {} }; + socklen_t m_address_len { 0 }; }; } diff --git a/kernel/kernel/Networking/IPv4Layer.cpp b/kernel/kernel/Networking/IPv4Layer.cpp index 0304db5c..5165f73d 100644 --- a/kernel/kernel/Networking/IPv4Layer.cpp +++ b/kernel/kernel/Networking/IPv4Layer.cpp @@ -75,35 +75,58 @@ namespace Kernel m_bound_sockets.remove(it); } - BAN::ErrorOr IPv4Layer::bind_socket_to_unused(BAN::RefPtr socket, const sockaddr* address, socklen_t address_len) + BAN::ErrorOr IPv4Layer::find_free_port() { - if (!address || address_len < (socklen_t)sizeof(sockaddr_in)) - return BAN::Error::from_errno(EINVAL); - if (address->sa_family != AF_INET) - return BAN::Error::from_errno(EAFNOSUPPORT); - auto& sockaddr_in = *reinterpret_cast(address); - SpinLockGuard _(m_bound_socket_lock); - uint16_t port = NetworkSocket::PORT_NONE; - for (uint32_t i = 0; i < 100 && port == NetworkSocket::PORT_NONE; i++) - if (uint32_t temp = 0xC000 | (Random::get_u32() & 0x3FFF); !m_bound_sockets.contains(temp)) - port = temp; - for (uint32_t temp = 0xC000; temp < 0xFFFF && port == NetworkSocket::PORT_NONE; temp++) - if (!m_bound_sockets.contains(temp)) - port = temp; - if (port == NetworkSocket::PORT_NONE) - { - dwarnln("No ports available"); - return BAN::Error::from_errno(EAGAIN); - } - dprintln_if(DEBUG_IPV4, "using port {}", port); + for (uint32_t i = 0; i < 100; i++) + if (uint32_t port = 0xC000 | (Random::get_u32() & 0x3FFF); !m_bound_sockets.contains(port)) + return port; - struct sockaddr_in target; - target.sin_family = AF_INET; - target.sin_port = BAN::host_to_network_endian(port); - target.sin_addr.s_addr = sockaddr_in.sin_addr.s_addr; - return bind_socket_to_address(socket, (sockaddr*)&target, sizeof(sockaddr_in)); + for (uint32_t port = 0xC000; port < 0xFFFF; port++) + if (!m_bound_sockets.contains(port)) + return port; + + dwarnln("No ports available"); + return BAN::Error::from_errno(EAGAIN); + } + + BAN::ErrorOr IPv4Layer::bind_socket_with_target(BAN::RefPtr socket, const sockaddr* target, socklen_t target_len) + { + if (!target || target_len < (socklen_t)sizeof(sockaddr_in)) + return BAN::Error::from_errno(EINVAL); + if (target->sa_family != AF_INET) + return BAN::Error::from_errno(EAFNOSUPPORT); + auto& sockaddr_in = *reinterpret_cast(target); + + auto interface = + TRY([&sockaddr_in]() -> BAN::ErrorOr> { + const auto ipv4 = BAN::IPv4Address { sockaddr_in.sin_addr.s_addr }; + + // try to find an interface in the same subnet + const auto& all_interfaces = NetworkManager::get().interfaces(); + for (const auto& interface : all_interfaces) + { + const auto netmask = interface->get_netmask(); + if (ipv4.mask(netmask) == interface->get_ipv4_address().mask(netmask)) + return interface; + } + + // fallback to non-loopback interface + // FIXME: make sure target is reachable + for (const auto& interface : all_interfaces) + if (interface->type() != NetworkInterface::Type::Loopback) + return interface; + + return BAN::Error::from_errno(EHOSTUNREACH); + }()); + + // FIXME: race condition with port allocation/binding + struct sockaddr_in bind_address; + bind_address.sin_family = AF_INET; + bind_address.sin_port = BAN::host_to_network_endian(TRY(find_free_port())); + bind_address.sin_addr.s_addr = interface->get_ipv4_address().raw; + return bind_socket_to_address(socket, (sockaddr*)&bind_address, sizeof(bind_address)); } BAN::ErrorOr IPv4Layer::bind_socket_to_address(BAN::RefPtr socket, const sockaddr* address, socklen_t address_len) @@ -114,33 +137,47 @@ namespace Kernel return BAN::Error::from_errno(EAFNOSUPPORT); auto& sockaddr_in = *reinterpret_cast(address); - const uint16_t port = BAN::host_to_network_endian(sockaddr_in.sin_port); - if (port == NetworkSocket::PORT_NONE) - return bind_socket_to_unused(socket, address, address_len); - const auto ipv4 = BAN::IPv4Address { sockaddr_in.sin_addr.s_addr }; - BAN::RefPtr bind_interface; - for (auto interface : NetworkManager::get().interfaces()) - { - if (interface->type() != NetworkInterface::Type::Loopback) - bind_interface = interface; - const auto netmask = interface->get_netmask(); - if (ipv4.mask(netmask) != interface->get_ipv4_address().mask(netmask)) - continue; - bind_interface = interface; - break; - } + TRY([&sockaddr_in]() -> BAN::ErrorOr { + const auto ipv4 = BAN::IPv4Address { sockaddr_in.sin_addr.s_addr }; + + if (ipv4 == 0) + return {}; + + const auto& all_interfaces = NetworkManager::get().interfaces(); + for (const auto& interface : all_interfaces) + { + switch (interface->type()) + { + case NetworkInterface::Type::Ethernet: + if (ipv4 == interface->get_ipv4_address()) + return {}; + break; + case NetworkInterface::Type::Loopback: + const auto netmask = interface->get_netmask(); + if (ipv4.mask(netmask) == interface->get_ipv4_address().mask(netmask)) + return {}; + break; + } + } - if (!bind_interface) return BAN::Error::from_errno(EADDRNOTAVAIL); + }()); + + struct sockaddr_in bind_address; + memcpy(&bind_address, address, sizeof(sockaddr_in)); SpinLockGuard _(m_bound_socket_lock); + if (bind_address.sin_port == 0) + bind_address.sin_port = TRY(find_free_port()); + const uint16_t port = BAN::host_to_network_endian(bind_address.sin_port); + if (m_bound_sockets.contains(port)) return BAN::Error::from_errno(EADDRINUSE); TRY(m_bound_sockets.insert(port, TRY(socket->get_weak_ptr()))); - socket->bind_interface_and_port(bind_interface.ptr(), port); + socket->bind_address_and_port(reinterpret_cast(&bind_address), sizeof(bind_address)); return {}; } @@ -173,18 +210,36 @@ namespace Kernel return BAN::Error::from_errno(EINVAL); if (address == nullptr || address_len != sizeof(sockaddr_in)) return BAN::Error::from_errno(EINVAL); - auto& sockaddr_in = *reinterpret_cast(address); + auto interface = TRY(socket.interface(address, address_len)); + + auto& sockaddr_in = *reinterpret_cast(address); auto dst_port = BAN::host_to_network_endian(sockaddr_in.sin_port); auto dst_ipv4 = BAN::IPv4Address { sockaddr_in.sin_addr.s_addr }; - auto dst_mac = TRY(m_arp_table->get_mac_from_ipv4(socket.interface(), dst_ipv4)); + auto dst_mac = TRY(m_arp_table->get_mac_from_ipv4(*interface, dst_ipv4)); + + if (interface->type() == NetworkInterface::Type::Loopback) + { + BAN::RefPtr receiver; + + { + SpinLockGuard _(m_bound_socket_lock); + auto receiver_it = m_bound_sockets.find(dst_port); + if (receiver_it != m_bound_sockets.end()) + receiver = receiver_it->value.lock(); + } + + if (!receiver) + return BAN::Error::from_errno(EADDRNOTAVAIL); + TRY(socket.interface(receiver->address(), receiver->address_len())); + } BAN::Vector packet_buffer; TRY(packet_buffer.resize(buffer.size() + sizeof(IPv4Header) + socket.protocol_header_size())); auto packet = BAN::ByteSpan { packet_buffer.span() }; auto pseudo_header = PseudoHeader { - .src_ipv4 = socket.interface().get_ipv4_address(), + .src_ipv4 = interface->get_ipv4_address(), .dst_ipv4 = dst_ipv4, .protocol = socket.protocol() }; @@ -201,12 +256,12 @@ namespace Kernel ); add_ipv4_header( packet, - socket.interface().get_ipv4_address(), + interface->get_ipv4_address(), dst_ipv4, socket.protocol() ); - TRY(socket.interface().send_bytes(dst_mac, EtherType::IPv4, packet)); + TRY(interface->send_bytes(dst_mac, EtherType::IPv4, packet)); return buffer.size(); } diff --git a/kernel/kernel/Networking/NetworkSocket.cpp b/kernel/kernel/Networking/NetworkSocket.cpp index b3fb906b..b689eeaa 100644 --- a/kernel/kernel/Networking/NetworkSocket.cpp +++ b/kernel/kernel/Networking/NetworkSocket.cpp @@ -15,12 +15,97 @@ namespace Kernel { } - void NetworkSocket::bind_interface_and_port(NetworkInterface* interface, uint16_t port) + bool NetworkSocket::can_interface_send_to(const NetworkInterface& interface, const sockaddr* target, socklen_t target_len) const { - ASSERT(!m_interface); - ASSERT(interface); - m_interface = interface; - m_port = port; + ASSERT(target); + ASSERT(target_len >= static_cast(sizeof(sockaddr_in))); + ASSERT(target->sa_family == AF_INET); + + const auto target_ipv4 = BAN::IPv4Address { + reinterpret_cast(target)->sin_addr.s_addr + }; + + switch (interface.type()) + { + case NetworkInterface::Type::Ethernet: + // FIXME: this is not really correct :D + return target_ipv4.octets[0] != IN_LOOPBACKNET; + case NetworkInterface::Type::Loopback: + return target_ipv4.octets[0] == IN_LOOPBACKNET; + } + + ASSERT_NOT_REACHED(); + } + + BAN::ErrorOr> NetworkSocket::interface(const sockaddr* target, socklen_t target_len) + { + ASSERT(m_network_layer.domain() == NetworkSocket::Domain::INET); + ASSERT(is_bound()); + + if (target != nullptr) + { + ASSERT(target_len >= static_cast(sizeof(sockaddr_in))); + ASSERT(target->sa_family == AF_INET); + } + + const auto& all_interfaces = NetworkManager::get().interfaces(); + + const auto bound_ipv4 = BAN::IPv4Address { + reinterpret_cast(&m_address)->sin_addr.s_addr + }; + + // find the bound interface + if (bound_ipv4 != 0) + { + for (const auto& interface : all_interfaces) + { + const auto netmask = interface->get_netmask(); + if (bound_ipv4.mask(netmask) != interface->get_ipv4_address().mask(netmask)) + continue; + if (target && !can_interface_send_to(*interface, target, target_len)) + continue; + return interface; + } + + return BAN::Error::from_errno(EADDRNOTAVAIL); + } + + // try to find an interface in the same subnet as target + if (target != nullptr) + { + const auto target_ipv4 = BAN::IPv4Address { + reinterpret_cast(target)->sin_addr.s_addr + }; + + for (const auto& interface : all_interfaces) + { + const auto netmask = interface->get_netmask(); + if (target_ipv4.mask(netmask) == interface->get_ipv4_address().mask(netmask)) + return interface; + } + } + + // return any interface (prefer non-loopback) + for (const auto& interface : all_interfaces) + if (interface->type() != NetworkInterface::Type::Loopback) + if (!target || can_interface_send_to(*interface, target, target_len)) + return interface; + for (const auto& interface : all_interfaces) + if (interface->type() == NetworkInterface::Type::Loopback) + if (!target || can_interface_send_to(*interface, target, target_len)) + return interface; + + return BAN::Error::from_errno(EHOSTUNREACH); + } + + void NetworkSocket::bind_address_and_port(const sockaddr* addr, socklen_t addr_len) + { + ASSERT(!is_bound()); + ASSERT(addr->sa_family != AF_UNSPEC); + ASSERT(addr_len <= static_cast(sizeof(sockaddr_storage))); + + memcpy(&m_address, addr, addr_len); + m_address_len = addr_len; } BAN::ErrorOr NetworkSocket::ioctl_impl(int request, void* arg) @@ -30,12 +115,8 @@ namespace Kernel dprintln("No argument provided"); return BAN::Error::from_errno(EINVAL); } - if (m_interface == nullptr) - { - dprintln("No interface bound"); - return BAN::Error::from_errno(EADDRNOTAVAIL); - } + auto interface = TRY(this->interface(nullptr, 0)); auto* ifreq = reinterpret_cast(arg); switch (request) @@ -44,7 +125,7 @@ namespace Kernel { auto& ifru_addr = *reinterpret_cast(&ifreq->ifr_ifru.ifru_addr); ifru_addr.sin_family = AF_INET; - ifru_addr.sin_addr.s_addr = m_interface->get_ipv4_address().raw; + ifru_addr.sin_addr.s_addr = interface->get_ipv4_address().raw; return 0; } case SIOCSIFADDR: @@ -52,15 +133,15 @@ namespace Kernel auto& ifru_addr = *reinterpret_cast(&ifreq->ifr_ifru.ifru_addr); if (ifru_addr.sin_family != AF_INET) return BAN::Error::from_errno(EADDRNOTAVAIL); - m_interface->set_ipv4_address(BAN::IPv4Address { ifru_addr.sin_addr.s_addr }); - dprintln("IPv4 address set to {}", m_interface->get_ipv4_address()); + interface->set_ipv4_address(BAN::IPv4Address { ifru_addr.sin_addr.s_addr }); + dprintln("IPv4 address set to {}", interface->get_ipv4_address()); return 0; } case SIOCGIFNETMASK: { auto& ifru_netmask = *reinterpret_cast(&ifreq->ifr_ifru.ifru_netmask); ifru_netmask.sin_family = AF_INET; - ifru_netmask.sin_addr.s_addr = m_interface->get_netmask().raw; + ifru_netmask.sin_addr.s_addr = interface->get_netmask().raw; return 0; } case SIOCSIFNETMASK: @@ -68,15 +149,15 @@ namespace Kernel auto& ifru_netmask = *reinterpret_cast(&ifreq->ifr_ifru.ifru_netmask); if (ifru_netmask.sin_family != AF_INET) return BAN::Error::from_errno(EADDRNOTAVAIL); - m_interface->set_netmask(BAN::IPv4Address { ifru_netmask.sin_addr.s_addr }); - dprintln("Netmask set to {}", m_interface->get_netmask()); + interface->set_netmask(BAN::IPv4Address { ifru_netmask.sin_addr.s_addr }); + dprintln("Netmask set to {}", interface->get_netmask()); return 0; } case SIOCGIFGWADDR: { auto& ifru_gwaddr = *reinterpret_cast(&ifreq->ifr_ifru.ifru_gwaddr); ifru_gwaddr.sin_family = AF_INET; - ifru_gwaddr.sin_addr.s_addr = m_interface->get_gateway().raw; + ifru_gwaddr.sin_addr.s_addr = interface->get_gateway().raw; return 0; } case SIOCSIFGWADDR: @@ -84,13 +165,13 @@ namespace Kernel auto& ifru_gwaddr = *reinterpret_cast(&ifreq->ifr_ifru.ifru_gwaddr); if (ifru_gwaddr.sin_family != AF_INET) return BAN::Error::from_errno(EADDRNOTAVAIL); - m_interface->set_gateway(BAN::IPv4Address { ifru_gwaddr.sin_addr.s_addr }); - dprintln("Gateway set to {}", m_interface->get_gateway()); + interface->set_gateway(BAN::IPv4Address { ifru_gwaddr.sin_addr.s_addr }); + dprintln("Gateway set to {}", interface->get_gateway()); return 0; } case SIOCGIFHWADDR: { - auto mac_address = m_interface->get_mac_address(); + auto mac_address = interface->get_mac_address(); ifreq->ifr_ifru.ifru_hwaddr.sa_family = AF_INET; memcpy(ifreq->ifr_ifru.ifru_hwaddr.sa_data, &mac_address, sizeof(mac_address)); return 0; @@ -98,9 +179,9 @@ namespace Kernel case SIOCGIFNAME: { auto& ifrn_name = ifreq->ifr_ifrn.ifrn_name; - ASSERT(m_interface->name().size() < sizeof(ifrn_name)); - memcpy(ifrn_name, m_interface->name().data(), m_interface->name().size()); - ifrn_name[m_interface->name().size()] = '\0'; + ASSERT(interface->name().size() < sizeof(ifrn_name)); + memcpy(ifrn_name, interface->name().data(), interface->name().size()); + ifrn_name[interface->name().size()] = '\0'; return 0; } default: diff --git a/kernel/kernel/Networking/TCPSocket.cpp b/kernel/kernel/Networking/TCPSocket.cpp index 377e8ca6..9d041be1 100644 --- a/kernel/kernel/Networking/TCPSocket.cpp +++ b/kernel/kernel/Networking/TCPSocket.cpp @@ -95,8 +95,8 @@ namespace Kernel } return_inode->m_mutex.lock(); - return_inode->m_port = m_port; - return_inode->m_interface = m_interface; + memcpy(&return_inode->m_address, &connection.target.address, connection.target.address_len); + return_inode->m_address_len = connection.target.address_len; return_inode->m_listen_parent = this; return_inode->m_connection_info.emplace(connection.target); return_inode->m_recv_window.start_seq = connection.target_start_seq; @@ -152,13 +152,18 @@ namespace Kernel }; if (!is_bound()) - TRY(m_network_layer.bind_socket_to_unused(this, address, address_len)); + TRY(m_network_layer.bind_socket_with_target(this, address, address_len)); m_connection_info.emplace(sockaddr_storage {}, address_len, true); memcpy(&m_connection_info->address, address, address_len); m_next_flags = SYN; - TRY(m_network_layer.sendto(*this, {}, address, address_len)); + if (m_network_layer.sendto(*this, {}, address, address_len).is_error()) + { + set_connection_as_closed(); + return BAN::Error::from_errno(ECONNREFUSED); + } + m_next_flags = 0; m_state = State::SynSent; @@ -410,8 +415,8 @@ namespace Kernel memset(&header, 0, sizeof(TCPHeader)); memset(header.options, TCPOption::End, m_tcp_options_bytes); + header.src_port = bound_port(); header.dst_port = dst_port; - header.src_port = m_port; header.seq_number = m_send_window.current_seq + m_send_window.has_ghost_byte; header.ack_number = m_recv_window.start_seq + m_recv_window.data_size + m_recv_window.has_ghost_byte; header.data_offset = (sizeof(TCPHeader) + m_tcp_options_bytes) / sizeof(uint32_t); @@ -423,7 +428,15 @@ namespace Kernel if (m_state == State::Closed || m_state == State::SynReceived) { - add_tcp_header_option<0, TCPOption::MaximumSeqmentSize>(header, m_interface->payload_mtu() - m_network_layer.header_size()); + const sockaddr_in target { + .sin_family = AF_INET, + .sin_port = dst_port, + .sin_addr = { .s_addr = pseudo_header.dst_ipv4.raw }, + .sin_zero = {}, + }; + auto interface = MUST(this->interface(reinterpret_cast(&target), sizeof(target))); + + add_tcp_header_option<0, TCPOption::MaximumSeqmentSize>(header, interface->payload_mtu() - m_network_layer.header_size()); if (m_connection_info->has_window_scale) add_tcp_header_option<4, TCPOption::WindowScale>(header, m_recv_window.scale_shift); @@ -451,11 +464,16 @@ namespace Kernel if (sender->sa_family == AF_INET) { + auto interface_or_error = interface(sender, sender_len); + if (interface_or_error.is_error()) + return; + auto interface = interface_or_error.release_value(); + auto& addr_in = *reinterpret_cast(sender); checksum = calculate_internet_checksum(buffer, PseudoHeader { .src_ipv4 = BAN::IPv4Address(addr_in.sin_addr.s_addr), - .dst_ipv4 = m_interface->get_ipv4_address(), + .dst_ipv4 = interface->get_ipv4_address(), .protocol = NetworkProtocol::TCP, .extra = buffer.size() } @@ -663,11 +681,11 @@ namespace Kernel // NOTE: Only listen socket can unbind the socket as // listen socket is always alive to redirect packets if (!m_listen_parent) - m_network_layer.unbind_socket(m_port); + m_network_layer.unbind_socket(bound_port()); else m_listen_parent->remove_listen_child(this); - m_interface = nullptr; - m_port = PORT_NONE; + m_address.ss_family = AF_UNSPEC; + m_address_len = 0; dprintln_if(DEBUG_TCP, "Socket unbound"); } diff --git a/kernel/kernel/Networking/UDPSocket.cpp b/kernel/kernel/Networking/UDPSocket.cpp index caf94682..c1795502 100644 --- a/kernel/kernel/Networking/UDPSocket.cpp +++ b/kernel/kernel/Networking/UDPSocket.cpp @@ -30,15 +30,15 @@ namespace Kernel UDPSocket::~UDPSocket() { if (is_bound()) - m_network_layer.unbind_socket(m_port); - m_port = PORT_NONE; - m_interface = nullptr; + m_network_layer.unbind_socket(bound_port()); + m_address.ss_family = AF_UNSPEC; + m_address_len = 0; } void UDPSocket::add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port, PseudoHeader) { auto& header = packet.as(); - header.src_port = m_port; + header.src_port = bound_port(); header.dst_port = dst_port; header.length = packet.size(); header.checksum = 0; @@ -115,7 +115,6 @@ namespace Kernel dprintln("No interface bound"); return BAN::Error::from_errno(EINVAL); } - ASSERT(m_port != PORT_NONE); SpinLockGuard guard(m_packet_lock); @@ -176,7 +175,7 @@ namespace Kernel dwarnln("ignoring sendmsg control message"); if (!is_bound()) - TRY(m_network_layer.bind_socket_to_unused(this, static_cast(message.msg_name), message.msg_namelen)); + TRY(m_network_layer.bind_socket_with_target(this, static_cast(message.msg_name), message.msg_namelen)); const size_t total_send_size = [&message]() -> size_t {