Files
banan-os/userspace/programs/resolver/main.cpp
Bananymous 40f3546aca BAN: Rewrite HashMap as a wrapper around HashSet
There is really no need to have two implementation of the same thing.
Only difference now is that HashMap's value type has to be movable but
this wasn't an issue
2026-04-21 00:18:18 +03:00

523 lines
11 KiB
C++

#include <BAN/ByteSpan.h>
#include <BAN/Debug.h>
#include <BAN/Endianness.h>
#include <BAN/HashMap.h>
#include <BAN/IPv4.h>
#include <BAN/LinkedList.h>
#include <BAN/String.h>
#include <BAN/StringView.h>
#include <BAN/Vector.h>
#include <arpa/inet.h>
#include <netinet/in.h>
#include <stdio.h>
#include <stdlib.h>
#include <sys/select.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/un.h>
#include <unistd.h>
struct DNSPacket
{
BAN::NetworkEndian<uint16_t> identification { 0 };
BAN::NetworkEndian<uint16_t> flags { 0 };
BAN::NetworkEndian<uint16_t> question_count { 0 };
BAN::NetworkEndian<uint16_t> answer_count { 0 };
BAN::NetworkEndian<uint16_t> authority_RR_count { 0 };
BAN::NetworkEndian<uint16_t> additional_RR_count { 0 };
uint8_t data[];
};
static_assert(sizeof(DNSPacket) == 12);
struct DNSAnswer
{
uint8_t __storage[12];
BAN::NetworkEndian<uint16_t>& name() { return *reinterpret_cast<BAN::NetworkEndian<uint16_t>*>(__storage + 0x00); };
BAN::NetworkEndian<uint16_t>& type() { return *reinterpret_cast<BAN::NetworkEndian<uint16_t>*>(__storage + 0x02); };
BAN::NetworkEndian<uint16_t>& class_() { return *reinterpret_cast<BAN::NetworkEndian<uint16_t>*>(__storage + 0x04); };
BAN::NetworkEndian<uint32_t>& ttl() { return *reinterpret_cast<BAN::NetworkEndian<uint32_t>*>(__storage + 0x06); };
BAN::NetworkEndian<uint16_t>& data_len() { return *reinterpret_cast<BAN::NetworkEndian<uint16_t>*>(__storage + 0x0A); };
uint8_t data[];
};
static_assert(sizeof(DNSAnswer) == 12);
enum QTYPE : uint16_t
{
INVALID = 0x0000,
A = 0x0001,
CNAME = 0x0005,
AAAA = 0x001C,
};
struct DNSEntry
{
DNSEntry(BAN::IPv4Address&& address, time_t valid_until)
: type(QTYPE::A)
, valid_until(valid_until)
, address(BAN::move(address))
{}
DNSEntry(BAN::String&& cname, time_t valid_until)
: type(QTYPE::CNAME)
, valid_until(valid_until)
, cname(BAN::move(cname))
{}
DNSEntry(DNSEntry&& other)
{
*this = BAN::move(other);
}
~DNSEntry() { clear(); }
DNSEntry& operator=(DNSEntry&& other)
{
clear();
valid_until = other.valid_until;
switch (type = other.type)
{
case QTYPE::A:
new (&address) BAN::IPv4Address(BAN::move(other.address));
break;
case QTYPE::CNAME:
new (&cname) BAN::String(BAN::move(other.cname));
break;
case QTYPE::INVALID:
case QTYPE::AAAA:
ASSERT_NOT_REACHED();
}
other.clear();
return *this;
}
void clear()
{
switch (type)
{
case QTYPE::A:
using BAN::IPv4Address;
address.~IPv4Address();
break;
case QTYPE::CNAME:
using BAN::String;
cname.~String();
break;
case QTYPE::AAAA:
ASSERT_NOT_REACHED();
case QTYPE::INVALID:
break;
}
type = QTYPE::INVALID;
}
QTYPE type;
time_t valid_until;
union {
BAN::IPv4Address address;
BAN::String cname;
};
};
struct DNSResponse
{
struct NameEntryPair
{
BAN::String name;
DNSEntry entry;
};
uint16_t id;
BAN::Vector<NameEntryPair> entries;
};
bool send_dns_query(int socket, BAN::StringView domain, uint16_t id)
{
static uint8_t buffer[4096];
memset(buffer, 0, sizeof(buffer));
DNSPacket& request = *reinterpret_cast<DNSPacket*>(buffer);
request.identification = id;
request.flags = 0x0100;
request.question_count = 1;
size_t idx = 0;
auto labels = MUST(BAN::StringView(domain).split('.'));
for (auto label : labels)
{
ASSERT(label.size() <= 0xFF);
request.data[idx++] = label.size();
for (char c : label)
request.data[idx++] = c;
}
request.data[idx++] = 0x00;
*(uint16_t*)&request.data[idx] = htons(QTYPE::A); idx += 2;
*(uint16_t*)&request.data[idx] = htons(0x0001); idx += 2;
sockaddr_in nameserver;
nameserver.sin_family = AF_INET;
nameserver.sin_port = htons(53);
nameserver.sin_addr.s_addr = inet_addr("8.8.8.8");
if (sendto(socket, &request, sizeof(DNSPacket) + idx, 0, (sockaddr*)&nameserver, sizeof(nameserver)) == -1)
{
dprintln("sendto: {}", strerror(errno));
return false;
}
return true;
}
BAN::Optional<DNSResponse> read_dns_response(int socket)
{
static uint8_t buffer[4096];
ssize_t nrecv = recvfrom(socket, buffer, sizeof(buffer), 0, nullptr, nullptr);
if (nrecv == -1)
{
dprintln("recvfrom: {}", strerror(errno));
return {};
}
DNSPacket& reply = *reinterpret_cast<DNSPacket*>(buffer);
DNSResponse result;
result.id = reply.identification;
if (reply.flags & 0x0F)
{
dprintln("DNS error (rcode {})", (unsigned)(reply.flags & 0xF));
return result;
}
size_t idx = reply.data - buffer;
for (size_t i = 0; i < reply.question_count; i++)
{
while (buffer[idx])
idx += buffer[idx] + 1;
idx += 5;
}
const auto read_name =
[](size_t idx) -> BAN::String
{
BAN::String result;
while (buffer[idx])
{
if ((buffer[idx] & 0xC0) == 0xC0)
{
idx = ((buffer[idx] & 0x3F) << 8) | buffer[idx + 1];
continue;
}
MUST(result.append(BAN::StringView(reinterpret_cast<const char*>(&buffer[idx + 1]), buffer[idx])));
MUST(result.push_back('.'));
idx += buffer[idx] + 1;
}
if (!result.empty())
result.pop_back();
return result;
};
for (size_t i = 0; i < reply.answer_count; i++)
{
auto& answer = *reinterpret_cast<DNSAnswer*>(&buffer[idx]);
auto name = read_name(answer.__storage - buffer);
if (answer.type() == QTYPE::A)
{
if (answer.data_len() != 4)
{
dprintln("Invalid A record size {}", (uint16_t)answer.data_len());
return result;
}
MUST(result.entries.push_back({
.name = BAN::move(name),
.entry = {
BAN::IPv4Address(*reinterpret_cast<uint32_t*>(answer.data)),
time(nullptr) + answer.ttl(),
},
}));
}
else if (answer.type() == QTYPE::CNAME)
{
auto target = read_name(answer.data - buffer);
MUST(result.entries.push_back({
.name = BAN::move(name),
.entry = {
BAN::move(target),
time(nullptr) + answer.ttl()
},
}));
}
idx += sizeof(DNSAnswer) + answer.data_len();
}
return result;
}
int create_service_socket()
{
int socket = ::socket(AF_UNIX, SOCK_SEQPACKET, 0);
if (socket == -1)
{
dprintln("socket: {}", strerror(errno));
return -1;
}
sockaddr_un addr;
addr.sun_family = AF_UNIX;
strcpy(addr.sun_path, "/tmp/resolver.sock");
if (bind(socket, (sockaddr*)&addr, sizeof(addr)) == -1)
{
dprintln("bind: {}", strerror(errno));
close(socket);
return -1;
}
if (chmod("/tmp/resolver.sock", 0777) == -1)
{
dprintln("chmod: {}", strerror(errno));
close(socket);
return -1;
}
if (listen(socket, 10) == -1)
{
dprintln("listen: {}", strerror(errno));
close(socket);
return -1;
}
return socket;
}
BAN::Optional<BAN::String> read_service_query(int socket)
{
static char buffer[4096];
ssize_t nrecv = recv(socket, buffer, sizeof(buffer), 0);
if (nrecv == -1)
{
dprintln("recv: {}", strerror(errno));
return {};
}
buffer[nrecv] = '\0';
return BAN::String(buffer);
}
BAN::Optional<BAN::IPv4Address> resolve_from_dns_cache(BAN::HashMap<BAN::String, DNSEntry>& dns_cache, const BAN::String& domain)
{
for (auto it = dns_cache.find(domain); it != dns_cache.end();)
{
if (time(nullptr) > it->value.valid_until)
{
dns_cache.remove(it);
return {};
}
switch (it->value.type)
{
case QTYPE::A:
return it->value.address;
case QTYPE::CNAME:
it = dns_cache.find(it->value.cname);
break;
case QTYPE::AAAA:
case QTYPE::INVALID:
ASSERT_NOT_REACHED();
}
}
return {};
}
int main(int, char**)
{
srand(time(nullptr));
char hostname[HOST_NAME_MAX];
if (gethostname(hostname, sizeof(hostname)) == -1)
hostname[0] = '\0';
int service_socket = create_service_socket();
if (service_socket == -1)
return 1;
int dns_socket = socket(AF_INET, SOCK_DGRAM, 0);
if (dns_socket == -1)
{
dprintln("socket: {}", strerror(errno));
return 1;
}
BAN::HashMap<BAN::String, DNSEntry> dns_cache;
struct Client
{
Client(int socket)
: socket(socket)
{ }
const int socket;
bool close { false };
uint16_t query_id { 0 };
BAN::String query;
};
BAN::LinkedList<Client> clients;
for (;;)
{
int max_sock = BAN::Math::max(service_socket, dns_socket);
fd_set fds;
FD_ZERO(&fds);
FD_SET(service_socket, &fds);
FD_SET(dns_socket, &fds);
for (auto& client : clients)
{
FD_SET(client.socket, &fds);
max_sock = BAN::Math::max(max_sock, client.socket);
}
int nselect = select(max_sock + 1, &fds, nullptr, nullptr, nullptr);
if (nselect == -1)
{
perror("select");
continue;
}
if (FD_ISSET(service_socket, &fds))
{
int client = accept(service_socket, nullptr, nullptr);
if (client == -1)
{
perror("accept");
continue;
}
MUST(clients.emplace_back(client));
}
if (FD_ISSET(dns_socket, &fds))
{
auto result = read_dns_response(dns_socket);
if (!result.has_value())
continue;
for (auto&& [name, entry] : result->entries)
MUST(dns_cache.insert_or_assign(BAN::move(name), BAN::move(entry)));
for (auto& client : clients)
{
if (client.query_id != result->id)
continue;
auto resolved = resolve_from_dns_cache(dns_cache, client.query);
if (!resolved.has_value())
{
auto it = dns_cache.find(client.query);
if (it == dns_cache.end())
{
client.close = true;
break;
}
for (;;)
{
ASSERT(it->value.type == QTYPE::CNAME);
auto next = dns_cache.find(it->value.cname);
if (next == dns_cache.end())
break;
it = next;
}
send_dns_query(service_socket, it->value.cname, client.query_id);
break;
}
const sockaddr_in addr {
.sin_family = AF_INET,
.sin_port = 0,
.sin_addr = { .s_addr = resolved->raw },
.sin_zero = {},
};
if (send(client.socket, &addr, sizeof(addr), 0) == -1)
dprintln("send: {}", strerror(errno));
client.close = true;
break;
}
}
for (auto& client : clients)
{
if (!FD_ISSET(client.socket, &fds))
continue;
if (!client.query.empty())
{
static uint8_t buffer[4096];
ssize_t nrecv = recv(client.socket, buffer, sizeof(buffer), 0);
if (nrecv < 0)
dprintln("{}", strerror(errno));
if (nrecv <= 0)
client.close = true;
else
dprintln("Client already has a query");
continue;
}
auto query = read_service_query(client.socket);
if (!query.has_value())
{
client.close = true;
continue;
}
BAN::Optional<BAN::IPv4Address> result;
if (*hostname && strcmp(query->data(), hostname) == 0)
result = BAN::IPv4Address(ntohl(INADDR_LOOPBACK));
else if (auto resolved = resolve_from_dns_cache(dns_cache, query.value()); resolved.has_value())
result = resolved.release_value();
if (result.has_value())
{
const sockaddr_in addr {
.sin_family = AF_INET,
.sin_port = 0,
.sin_addr = { .s_addr = result->raw },
.sin_zero = {},
};
if (send(client.socket, &addr, sizeof(addr), 0) == -1)
dprintln("send: {}", strerror(errno));
client.close = true;
continue;
}
client.query = query.release_value();
client.query_id = rand() % 0xFFFF;
send_dns_query(dns_socket, client.query, client.query_id);
}
for (auto it = clients.begin(); it != clients.end();)
{
if (!it->close)
{
it++;
continue;
}
close(it->socket);
it = clients.remove(it);
}
}
return 0;
}