Compare commits

...

2 Commits

Author SHA1 Message Date
Bananymous b853d29992 Kernel: Fix unix domain socket close detection 2025-04-22 08:36:44 +03:00
Bananymous 33a0f562d3 resolver: Add support for CNAME
Also rework resolver's send format and convert test-tcp and nslookup to
use getaddrinfo
2025-04-22 08:36:44 +03:00
6 changed files with 277 additions and 134 deletions

View File

@ -60,7 +60,10 @@ namespace Kernel
{ {
auto& connection_info = m_info.get<ConnectionInfo>(); auto& connection_info = m_info.get<ConnectionInfo>();
if (auto connection = connection_info.connection.lock(); connection && connection->m_info.has<ConnectionInfo>()) if (auto connection = connection_info.connection.lock(); connection && connection->m_info.has<ConnectionInfo>())
{
connection->m_info.get<ConnectionInfo>().target_closed = true; connection->m_info.get<ConnectionInfo>().target_closed = true;
connection->m_packet_thread_blocker.unblock();
}
} }
} }
@ -350,6 +353,9 @@ namespace Kernel
} }
BAN::ErrorOr<size_t> UnixDomainSocket::recvfrom_impl(BAN::ByteSpan buffer, sockaddr*, socklen_t*) BAN::ErrorOr<size_t> UnixDomainSocket::recvfrom_impl(BAN::ByteSpan buffer, sockaddr*, socklen_t*)
{
auto state = m_packet_lock.lock();
while (m_packet_size_total == 0)
{ {
if (m_info.has<ConnectionInfo>()) if (m_info.has<ConnectionInfo>())
{ {
@ -361,9 +367,6 @@ namespace Kernel
return BAN::Error::from_errno(ENOTCONN); return BAN::Error::from_errno(ENOTCONN);
} }
auto state = m_packet_lock.lock();
while (m_packet_size_total == 0)
{
m_packet_lock.unlock(state); m_packet_lock.unlock(state);
TRY(Thread::current().block_or_eintr_indefinite(m_packet_thread_blocker)); TRY(Thread::current().block_or_eintr_indefinite(m_packet_thread_blocker));
state = m_packet_lock.lock(); state = m_packet_lock.lock();

View File

@ -107,7 +107,7 @@ int getaddrinfo(const char* __restrict nodename, const char* __restrict servname
goto error_close_socket; goto error_close_socket;
sockaddr_storage storage; sockaddr_storage storage;
if (recv(resolver_sock, &storage, sizeof(storage), 0) == -1) if (recv(resolver_sock, &storage, sizeof(storage), 0) < static_cast<ssize_t>(sizeof(sockaddr_in)))
goto error_close_socket; goto error_close_socket;
close(resolver_sock); close(resolver_sock);
@ -115,12 +115,12 @@ int getaddrinfo(const char* __restrict nodename, const char* __restrict servname
if (storage.ss_family != AF_INET) if (storage.ss_family != AF_INET)
return EAI_FAIL; return EAI_FAIL;
ipv4_addr = *reinterpret_cast<in_addr_t*>(storage.ss_storage); ipv4_addr = reinterpret_cast<sockaddr_in&>(storage).sin_addr.s_addr;
} }
{ {
addrinfo* ai = (addrinfo*)malloc(sizeof(addrinfo) + sizeof(sockaddr_in)); addrinfo* ai = (addrinfo*)malloc(sizeof(addrinfo) + sizeof(sockaddr_in));
if (*res == nullptr) if (ai == nullptr)
return EAI_MEMORY; return EAI_MEMORY;
sockaddr_in* sa_in = reinterpret_cast<sockaddr_in*>(reinterpret_cast<uintptr_t>(ai) + sizeof(addrinfo)); sockaddr_in* sa_in = reinterpret_cast<sockaddr_in*>(reinterpret_cast<uintptr_t>(ai) + sizeof(addrinfo));
@ -193,7 +193,7 @@ struct hostent* gethostbyname(const char* name)
goto error_close_socket; goto error_close_socket;
sockaddr_storage storage; sockaddr_storage storage;
if (recv(socket, &storage, sizeof(storage), 0) == -1) if (recv(socket, &storage, sizeof(storage), 0) < static_cast<ssize_t>(sizeof(sockaddr_in)))
goto error_close_socket; goto error_close_socket;
close(socket); close(socket);
@ -201,7 +201,7 @@ struct hostent* gethostbyname(const char* name)
if (storage.ss_family != AF_INET) if (storage.ss_family != AF_INET)
return nullptr; return nullptr;
addr_buffer = *reinterpret_cast<in_addr_t*>(storage.ss_storage); addr_buffer = reinterpret_cast<sockaddr_in&>(storage).sin_addr.s_addr;
} }
return &hostent; return &hostent;

View File

@ -1,10 +1,7 @@
#include <arpa/inet.h> #include <arpa/inet.h>
#include <netdb.h>
#include <netinet/in.h> #include <netinet/in.h>
#include <stdio.h> #include <stdio.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <unistd.h>
#define MAX(a, b) ((a) < (b) ? (b) : (a)) #define MAX(a, b) ((a) < (b) ? (b) : (a))
@ -16,39 +13,37 @@ int main(int argc, char** argv)
return 1; return 1;
} }
int socket = ::socket(AF_UNIX, SOCK_SEQPACKET, 0); const addrinfo hints {
if (socket == -1) .ai_flags = 0,
.ai_family = AF_INET,
.ai_socktype = SOCK_STREAM,
.ai_protocol = 0,
.ai_addrlen = 0,
.ai_addr = nullptr,
.ai_canonname = nullptr,
.ai_next = nullptr,
};
addrinfo* result;
if (int ret = getaddrinfo(argv[1], nullptr, &hints, &result); ret != 0)
{ {
perror("socket"); fprintf(stderr, "getaddrinfo: %s\n", gai_strerror(ret));
return 1; return 1;
} }
sockaddr_un addr; for (addrinfo* ai = result; ai; ai = ai->ai_next)
addr.sun_family = AF_UNIX;
strcpy(addr.sun_path, "/tmp/resolver.sock");
if (connect(socket, (sockaddr*)&addr, sizeof(addr)) == -1)
{ {
perror("connect"); if (ai->ai_family != AF_INET)
return 1; continue;
}
if (send(socket, argv[1], strlen(argv[1]), 0) == -1) char buffer[NI_MAXHOST];
{ if (inet_ntop(ai->ai_family, &reinterpret_cast<sockaddr_in*>(ai->ai_addr)->sin_addr, buffer, sizeof(buffer)) == nullptr)
perror("send"); continue;
return 1;
}
sockaddr_storage storage;
if (recv(socket, &storage, sizeof(storage), 0) == -1)
{
perror("recv");
return 1;
}
close(socket);
char buffer[MAX(INET_ADDRSTRLEN, INET6_ADDRSTRLEN)];
printf("%s\n", inet_ntop(storage.ss_family, storage.ss_storage, buffer, sizeof(buffer)));
printf("%s\n", buffer);
return 0;
}
fprintf(stderr, "no address information available\n");
return 0; return 0;
} }

View File

@ -6,4 +6,6 @@ add_executable(resolver ${SOURCES})
banan_link_library(resolver ban) banan_link_library(resolver ban)
banan_link_library(resolver libc) banan_link_library(resolver libc)
target_compile_options(resolver PRIVATE -Wno-maybe-uninitialized)
install(TARGETS resolver OPTIONAL) install(TARGETS resolver OPTIONAL)

View File

@ -43,6 +43,7 @@ static_assert(sizeof(DNSAnswer) == 12);
enum QTYPE : uint16_t enum QTYPE : uint16_t
{ {
INVALID = 0x0000,
A = 0x0001, A = 0x0001,
CNAME = 0x0005, CNAME = 0x0005,
AAAA = 0x001C, AAAA = 0x001C,
@ -50,16 +51,85 @@ enum QTYPE : uint16_t
struct DNSEntry struct DNSEntry
{ {
time_t valid_until { 0 }; DNSEntry(BAN::IPv4Address&& address, time_t valid_until)
BAN::IPv4Address address { 0 }; : type(QTYPE::A)
, valid_until(valid_until)
, address(BAN::move(address))
{}
DNSEntry(BAN::String&& cname, time_t valid_until)
: type(QTYPE::CNAME)
, valid_until(valid_until)
, cname(BAN::move(cname))
{}
DNSEntry(DNSEntry&& other)
{
*this = BAN::move(other);
}
~DNSEntry() { clear(); }
DNSEntry& operator=(DNSEntry&& other)
{
clear();
valid_until = other.valid_until;
switch (type = other.type)
{
case QTYPE::A:
new (&address) BAN::IPv4Address(BAN::move(other.address));
break;
case QTYPE::CNAME:
new (&cname) BAN::String(BAN::move(other.cname));
break;
case QTYPE::INVALID:
case QTYPE::AAAA:
ASSERT_NOT_REACHED();
}
other.clear();
return *this;
}
void clear()
{
switch (type)
{
case QTYPE::A:
using BAN::IPv4Address;
address.~IPv4Address();
break;
case QTYPE::CNAME:
using BAN::String;
cname.~String();
break;
case QTYPE::AAAA:
ASSERT_NOT_REACHED();
case QTYPE::INVALID:
break;
}
type = QTYPE::INVALID;
}
QTYPE type;
time_t valid_until;
union {
BAN::IPv4Address address;
BAN::String cname;
};
}; };
struct DNSResponse struct DNSResponse
{ {
uint16_t id; struct NameEntryPair
{
BAN::String name;
DNSEntry entry; DNSEntry entry;
}; };
uint16_t id;
BAN::Vector<NameEntryPair> entries;
};
bool send_dns_query(int socket, BAN::StringView domain, uint16_t id) bool send_dns_query(int socket, BAN::StringView domain, uint16_t id)
{ {
static uint8_t buffer[4096]; static uint8_t buffer[4096];
@ -110,36 +180,83 @@ BAN::Optional<DNSResponse> read_dns_response(int socket)
} }
DNSPacket& reply = *reinterpret_cast<DNSPacket*>(buffer); DNSPacket& reply = *reinterpret_cast<DNSPacket*>(buffer);
if (reply.flags & 0x0F)
{
dprintln("DNS error (rcode {})", (unsigned)(reply.flags & 0xF));
return {};
}
size_t idx = 0;
for (size_t i = 0; i < reply.question_count; i++)
{
while (reply.data[idx])
idx += reply.data[idx] + 1;
idx += 5;
}
DNSAnswer& answer = *reinterpret_cast<DNSAnswer*>(&reply.data[idx]);
if (answer.type() != QTYPE::A)
{
dprintln("Not A record, but {}", static_cast<uint16_t>(answer.type()));
return {};
}
if (answer.data_len() != 4)
{
dprintln("corrupted package");
return {};
}
DNSResponse result; DNSResponse result;
result.id = reply.identification; result.id = reply.identification;
result.entry.valid_until = time(nullptr) + answer.ttl();
result.entry.address = BAN::IPv4Address(*reinterpret_cast<uint32_t*>(answer.data)); if (reply.flags & 0x0F)
{
dprintln("DNS error (rcode {})", (unsigned)(reply.flags & 0xF));
return result;
}
size_t idx = reply.data - buffer;
for (size_t i = 0; i < reply.question_count; i++)
{
while (buffer[idx])
idx += buffer[idx] + 1;
idx += 5;
}
const auto read_name =
[](size_t idx) -> BAN::String
{
BAN::String result;
while (buffer[idx])
{
if ((buffer[idx] & 0xC0) == 0xC0)
{
idx = ((buffer[idx] & 0x3F) << 8) | buffer[idx + 1];
continue;
}
MUST(result.append(BAN::StringView(reinterpret_cast<const char*>(&buffer[idx + 1]), buffer[idx])));
MUST(result.push_back('.'));
idx += buffer[idx] + 1;
}
if (!result.empty())
result.pop_back();
return result;
};
for (size_t i = 0; i < reply.answer_count; i++)
{
auto& answer = *reinterpret_cast<DNSAnswer*>(&buffer[idx]);
auto name = read_name(answer.__storage - buffer);
if (answer.type() == QTYPE::A)
{
if (answer.data_len() != 4)
{
dprintln("Invalid A record size {}", (uint16_t)answer.data_len());
return result;
}
MUST(result.entries.push_back({
.name = BAN::move(name),
.entry = {
BAN::IPv4Address(*reinterpret_cast<uint32_t*>(answer.data)),
time(nullptr) + answer.ttl(),
},
}));
}
else if (answer.type() == QTYPE::CNAME)
{
auto target = read_name(answer.data - buffer);
MUST(result.entries.push_back({
.name = BAN::move(name),
.entry = {
BAN::move(target),
time(nullptr) + answer.ttl()
},
}));
}
idx += sizeof(DNSAnswer) + answer.data_len();
}
return result; return result;
} }
@ -193,6 +310,32 @@ BAN::Optional<BAN::String> read_service_query(int socket)
return BAN::String(buffer); return BAN::String(buffer);
} }
BAN::Optional<BAN::IPv4Address> resolve_from_dns_cache(BAN::HashMap<BAN::String, DNSEntry>& dns_cache, const BAN::String& domain)
{
for (auto it = dns_cache.find(domain); it != dns_cache.end();)
{
if (time(nullptr) > it->value.valid_until)
{
dns_cache.remove(it);
return {};
}
switch (it->value.type)
{
case QTYPE::A:
return it->value.address;
case QTYPE::CNAME:
it = dns_cache.find(it->value.cname);
break;
case QTYPE::AAAA:
case QTYPE::INVALID:
ASSERT_NOT_REACHED();
}
}
return {};
}
int main(int, char**) int main(int, char**)
{ {
srand(time(nullptr)); srand(time(nullptr));
@ -266,17 +409,42 @@ int main(int, char**)
if (!result.has_value()) if (!result.has_value())
continue; continue;
for (auto&& [name, entry] : result->entries)
MUST(dns_cache.insert_or_assign(BAN::move(name), BAN::move(entry)));
for (auto& client : clients) for (auto& client : clients)
{ {
if (client.query_id != result->id) if (client.query_id != result->id)
continue; continue;
(void)dns_cache.insert(client.query, result->entry); auto resolved = resolve_from_dns_cache(dns_cache, client.query);
if (!resolved.has_value())
{
auto it = dns_cache.find(client.query);
if (it == dns_cache.end())
{
client.close = true;
break;
}
for (;;)
{
ASSERT(it->value.type == QTYPE::CNAME);
auto next = dns_cache.find(it->value.cname);
if (next == dns_cache.end())
break;
it = next;
}
send_dns_query(service_socket, it->value.cname, client.query_id);
break;
}
sockaddr_storage storage; const sockaddr_in addr {
storage.ss_family = AF_INET; .sin_family = AF_INET,
memcpy(storage.ss_storage, &result->entry.address.raw, sizeof(result->entry.address.raw)); .sin_port = 0,
if (send(client.socket, &storage, sizeof(storage), 0) == -1) .sin_addr = { .s_addr = resolved->raw },
};
if (send(client.socket, &addr, sizeof(addr), 0) == -1)
dprintln("send: {}", strerror(errno)); dprintln("send: {}", strerror(errno));
client.close = true; client.close = true;
break; break;
@ -308,30 +476,22 @@ int main(int, char**)
continue; continue;
} }
BAN::Optional<DNSEntry> result; BAN::Optional<BAN::IPv4Address> result;
if (*hostname && strcmp(query->data(), hostname) == 0) if (*hostname && strcmp(query->data(), hostname) == 0)
{ result = BAN::IPv4Address(ntohl(INADDR_LOOPBACK));
result = DNSEntry { else if (auto resolved = resolve_from_dns_cache(dns_cache, query.value()); resolved.has_value())
.valid_until = time(nullptr), result = resolved.release_value();
.address = ntohl(INADDR_LOOPBACK),
};
}
else if (dns_cache.contains(*query))
{
auto& cached = dns_cache[*query];
if (time(nullptr) <= cached.valid_until)
result = cached;
else
dns_cache.remove(*query);
}
if (result.has_value()) if (result.has_value())
{ {
sockaddr_storage storage; const sockaddr_in addr {
storage.ss_family = AF_INET; .sin_family = AF_INET,
memcpy(storage.ss_storage, &result->address.raw, sizeof(result->address.raw)); .sin_port = 0,
if (send(client.socket, &storage, sizeof(storage), 0) == -1) .sin_addr = { .s_addr = result->raw },
};
if (send(client.socket, &addr, sizeof(addr), 0) == -1)
dprintln("send: {}", strerror(errno)); dprintln("send: {}", strerror(errno));
client.close = true; client.close = true;
continue; continue;

View File

@ -1,4 +1,5 @@
#include <arpa/inet.h> #include <arpa/inet.h>
#include <netdb.h>
#include <netinet/in.h> #include <netinet/in.h>
#include <stdio.h> #include <stdio.h>
#include <string.h> #include <string.h>
@ -8,46 +9,28 @@
in_addr_t get_ipv4_address(const char* query) in_addr_t get_ipv4_address(const char* query)
{ {
if (in_addr_t ipv4 = inet_addr(query); ipv4 != (in_addr_t)(-1)) const addrinfo hints {
return ipv4; .ai_flags = 0,
.ai_family = AF_INET,
.ai_socktype = SOCK_STREAM,
.ai_protocol = 0,
.ai_addrlen = 0,
.ai_addr = nullptr,
.ai_canonname = nullptr,
.ai_next = nullptr,
};
int socket = ::socket(AF_UNIX, SOCK_SEQPACKET, 0); addrinfo* result;
if (socket == -1) if (getaddrinfo(query, nullptr, &hints, &result) != 0)
{
perror("socket");
return -1; return -1;
}
sockaddr_un addr; for (addrinfo* ai = result; ai; ai = ai->ai_next)
addr.sun_family = AF_UNIX; if (ai->ai_family != AF_INET)
strcpy(addr.sun_path, "/tmp/resolver.sock"); return reinterpret_cast<sockaddr_in*>(ai->ai_addr)->sin_addr.s_addr;
if (connect(socket, (sockaddr*)&addr, sizeof(addr)) == -1)
{
perror("connect");
close(socket);
return -1; return -1;
} }
if (send(socket, query, strlen(query), 0) == -1)
{
perror("send");
close(socket);
return -1;
}
sockaddr_storage storage;
if (recv(socket, &storage, sizeof(storage), 0) == -1)
{
perror("recv");
close(socket);
return -1;
}
close(socket);
return *reinterpret_cast<in_addr_t*>(storage.ss_storage);
}
int main(int argc, char** argv) int main(int argc, char** argv)
{ {
if (argc != 2) if (argc != 2)