banan-os/userspace/resolver/main.cpp

262 lines
5.7 KiB
C++

#include <BAN/ByteSpan.h>
#include <BAN/Endianness.h>
#include <BAN/HashMap.h>
#include <BAN/IPv4.h>
#include <BAN/String.h>
#include <BAN/StringView.h>
#include <BAN/Vector.h>
#include <arpa/inet.h>
#include <netinet/in.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/un.h>
#include <unistd.h>
// Fixed-size DNS message header, overlaid directly on the send/receive buffers
// (see send_dns_query() and read_dns_response()). Every field is a
// BAN::NetworkEndian<uint16_t>; the rest of the message starts at `data`.
struct DNSPacket
{
// Transaction id: written by send_dns_query(), compared by read_dns_response().
BAN::NetworkEndian<uint16_t> identification { 0 };
// Header flags: send_dns_query() stores 0x0100 here; read_dns_response()
// reports the low 4 bits as the rcode when they are non-zero.
BAN::NetworkEndian<uint16_t> flags { 0 };
// Number of question entries: set to 1 when querying, and used by
// read_dns_response() to skip the questions echoed in a reply.
BAN::NetworkEndian<uint16_t> question_count { 0 };
// The three counters below are not read anywhere in the visible part of this file.
BAN::NetworkEndian<uint16_t> answer_count { 0 };
BAN::NetworkEndian<uint16_t> authority_RR_count { 0 };
BAN::NetworkEndian<uint16_t> additional_RR_count { 0 };
// Flexible array member: the message bytes that follow the header.
uint8_t data[];
};
// The overlay only works if the header is exactly 12 bytes with no padding.
static_assert(sizeof(DNSPacket) == 12);
// Overlay for the fixed-size front of one answer record, as used by
// read_dns_response(). The fields live in a raw byte array and are exposed
// through accessors that reinterpret fixed byte offsets of that array.
// NOTE(review): the name field is treated as exactly 2 bytes, which presumably
// assumes the name is a compressed pointer rather than inline labels - confirm.
struct DNSAnswer
{
// 12 bytes: name (2) + type (2) + class (2) + ttl (4) + data_len (2).
uint8_t __storage[12];
// Bytes 0x00-0x01.
BAN::NetworkEndian<uint16_t>& name() { return *reinterpret_cast<BAN::NetworkEndian<uint16_t>*>(__storage + 0x00); };
// Bytes 0x02-0x03; compared against QTYPE::A by read_dns_response().
BAN::NetworkEndian<uint16_t>& type() { return *reinterpret_cast<BAN::NetworkEndian<uint16_t>*>(__storage + 0x02); };
// Bytes 0x04-0x05.
BAN::NetworkEndian<uint16_t>& class_() { return *reinterpret_cast<BAN::NetworkEndian<uint16_t>*>(__storage + 0x04); };
// Bytes 0x06-0x09; note the 32-bit value starts at an offset that is not a multiple of 4.
BAN::NetworkEndian<uint32_t>& ttl() { return *reinterpret_cast<BAN::NetworkEndian<uint32_t>*>(__storage + 0x06); };
// Bytes 0x0A-0x0B; read_dns_response() requires this to be 4.
BAN::NetworkEndian<uint16_t>& data_len() { return *reinterpret_cast<BAN::NetworkEndian<uint16_t>*>(__storage + 0x0A); };
// Flexible array member: the record data that follows the fixed fields.
uint8_t data[];
};
// The overlay only works if the fixed part is exactly 12 bytes.
static_assert(sizeof(DNSAnswer) == 12);
// Record type codes placed in the question of a query and compared against the
// type field of an answer. Only A is referenced in the visible part of this file.
enum QTYPE : uint16_t
{
	A     = 1,  // 0x0001
	CNAME = 5,  // 0x0005
	AAAA  = 28, // 0x001C
};
// One resolved address together with its expiry time. read_dns_response()
// produces these and main() stores them in its cache keyed by the queried name.
struct DNSEntry
{
// Expiry time: read_dns_response() sets it to time(nullptr) plus the answer's
// TTL; main() uses 0 for the fallback entry it builds when resolution fails.
time_t valid_until { 0 };
// The resolved address; main() copies its `raw` member into the reply.
BAN::IPv4Address address { 0 };
};
// Builds a single-question DNS query (type QTYPE::A, class 0x0001) for `domain`
// and sends it over `socket` to the nameserver 8.8.8.8 on port 53.
//
// socket: socket handed to sendto()
// domain: dot-separated name; every label is written as a length byte followed
//         by the label bytes, and the name is terminated by a zero byte
// id:     transaction id stored in the header
//
// Returns true when sendto() succeeds. Returns false, after printing a message
// to stderr, when the name cannot be encoded or sendto() fails.
bool send_dns_query(int socket, BAN::StringView domain, uint16_t id)
{
	// Limits from RFC 1035: a label holds at most 63 bytes (larger length bytes
	// collide with the compression-pointer encoding) and an encoded name,
	// including its length bytes and terminating zero, at most 255 bytes.
	constexpr size_t max_label_length = 63;
	constexpr size_t max_name_length = 255;

	static uint8_t buffer[4096];
	// With the name capped at 255 bytes the whole query always fits the buffer.
	static_assert(sizeof(DNSPacket) + max_name_length + 2 * sizeof(uint16_t) <= sizeof(buffer));
	memset(buffer, 0, sizeof(buffer));

	DNSPacket& request = *reinterpret_cast<DNSPacket*>(buffer);
	request.identification = id;
	request.flags = 0x0100;
	request.question_count = 1;

	size_t idx = 0;

	// The domain originates from a client request, so oversized input is
	// rejected instead of asserting or writing past the buffer.
	auto labels = MUST(BAN::StringView(domain).split('.'));
	for (auto label : labels)
	{
		if (label.size() > max_label_length)
		{
			fprintf(stderr, "DNS label too long\n");
			return false;
		}
		// length byte + label bytes + the terminating zero written after the loop
		if (idx + 1 + label.size() + 1 > max_name_length)
		{
			fprintf(stderr, "DNS name too long\n");
			return false;
		}
		request.data[idx++] = static_cast<uint8_t>(label.size());
		for (char c : label)
			request.data[idx++] = c;
	}
	request.data[idx++] = 0x00;

	// The name can end at an odd offset, so the 16-bit fields are copied in
	// instead of being stored through a uint16_t pointer.
	const uint16_t qtype = htons(QTYPE::A);
	memcpy(&request.data[idx], &qtype, sizeof(qtype));
	idx += sizeof(qtype);

	const uint16_t qclass = htons(0x0001);
	memcpy(&request.data[idx], &qclass, sizeof(qclass));
	idx += sizeof(qclass);

	// Value-initialised so no indeterminate bytes are handed to sendto().
	sockaddr_in nameserver {};
	nameserver.sin_family = AF_INET;
	nameserver.sin_port = htons(53);
	nameserver.sin_addr.s_addr = inet_addr("8.8.8.8");

	if (sendto(socket, &request, sizeof(DNSPacket) + idx, 0, reinterpret_cast<sockaddr*>(&nameserver), sizeof(nameserver)) == -1)
	{
		perror("sendto");
		return false;
	}

	return true;
}
// Receives one datagram from `socket` and extracts the IPv4 address from the
// first answer record of the DNS reply it contains.
//
// socket: socket handed to recvfrom()
// id:     transaction id the reply must carry
//
// Returns the address together with its expiry time (now + the answer's TTL).
// Returns an empty optional, after printing a message to stderr, when:
//   - recvfrom() fails,
//   - the datagram is shorter than a DNS header or ends inside a record,
//   - the id does not match or the low 4 flag bits (rcode) are non-zero,
//   - the reply carries no answers,
//   - the first answer is not an A record with 4 bytes of data.
// Only the first answer is examined; later answers are ignored.
BAN::Optional<DNSEntry> read_dns_response(int socket, uint16_t id)
{
	static uint8_t buffer[4096];

	const ssize_t nrecv = recvfrom(socket, buffer, sizeof(buffer), 0, nullptr, nullptr);
	if (nrecv == -1)
	{
		perror("recvfrom");
		return {};
	}

	// The buffer is static and never cleared, so bytes beyond `nrecv` belong to
	// an earlier reply. Everything below must stay inside the received bytes.
	if (static_cast<size_t>(nrecv) < sizeof(DNSPacket))
	{
		fprintf(stderr, "DNS reply too short\n");
		return {};
	}
	const size_t data_size = static_cast<size_t>(nrecv) - sizeof(DNSPacket);

	DNSPacket& reply = *reinterpret_cast<DNSPacket*>(buffer);
	if (reply.identification != id)
	{
		fprintf(stderr, "Reply to invalid packet\n");
		return {};
	}
	if (reply.flags & 0x0F)
	{
		fprintf(stderr, "DNS error (rcode %u)\n", (unsigned)(reply.flags & 0xF));
		return {};
	}
	if (reply.answer_count == 0)
	{
		fprintf(stderr, "DNS reply has no answers\n");
		return {};
	}

	// Skip the questions echoed back in front of the answers.
	size_t idx = 0;
	for (size_t i = 0; i < reply.question_count; i++)
	{
		while (idx < data_size && reply.data[idx])
			idx += reply.data[idx] + 1;
		// terminating zero (1) + question type (2) + question class (2)
		idx += 5;
		if (idx > data_size)
		{
			fprintf(stderr, "corrupted package\n");
			return {};
		}
	}

	if (data_size - idx < sizeof(DNSAnswer))
	{
		fprintf(stderr, "corrupted package\n");
		return {};
	}

	// DNSAnswer places the type field 2 bytes after the start of the record,
	// which only holds when the name is a compression pointer (top two bits of
	// its first byte set, RFC 1035 section 4.1.4).
	if ((reply.data[idx] & 0xC0) != 0xC0)
	{
		fprintf(stderr, "Unsupported answer name encoding\n");
		return {};
	}

	DNSAnswer& answer = *reinterpret_cast<DNSAnswer*>(&reply.data[idx]);
	if (answer.type() != QTYPE::A)
	{
		fprintf(stderr, "Not A record\n");
		return {};
	}
	if (answer.data_len() != 4)
	{
		fprintf(stderr, "corrupted package\n");
		return {};
	}
	if (data_size - idx - sizeof(DNSAnswer) < sizeof(uint32_t))
	{
		fprintf(stderr, "corrupted package\n");
		return {};
	}

	// The record data is not guaranteed to be 4-byte aligned, so copy it out
	// instead of dereferencing a uint32_t pointer into the buffer.
	uint32_t raw_address;
	memcpy(&raw_address, answer.data, sizeof(raw_address));

	DNSEntry result;
	result.valid_until = time(nullptr) + answer.ttl();
	result.address = BAN::IPv4Address(raw_address);

	return result;
}
// Creates the UNIX-domain SOCK_SEQPACKET socket on which clients submit
// queries: binds it to /tmp/resolver.sock, sets the path's mode to 0777 and
// starts listening with a backlog of 10.
//
// Returns the listening descriptor, or -1 after reporting the failed call with
// perror(). The descriptor is closed on every failure after its creation.
int create_service_socket()
{
	static constexpr char socket_path[] = "/tmp/resolver.sock";

	const int fd = ::socket(AF_UNIX, SOCK_SEQPACKET, 0);
	if (fd == -1)
	{
		perror("socket");
		return -1;
	}

	// Shared failure path: report first (close() may change errno), then clean up.
	const auto fail = [fd](const char* what) -> int
	{
		perror(what);
		close(fd);
		return -1;
	};

	sockaddr_un addr;
	addr.sun_family = AF_UNIX;
	strcpy(addr.sun_path, socket_path);

	if (bind(fd, reinterpret_cast<sockaddr*>(&addr), sizeof(addr)) == -1)
		return fail("bind");

	if (chmod(socket_path, 0777) == -1)
		return fail("chmod");

	if (listen(fd, 10) == -1)
		return fail("listen");

	return fd;
}
// Reads one query from a connected client and returns it as a string.
//
// socket: connected client socket handed to recv()
//
// Returns the received bytes up to the first NUL byte, or an empty optional
// after perror("recv") when recv() fails.
BAN::Optional<BAN::String> read_service_query(int socket)
{
	static char buffer[4096];

	// Leave one byte free for the terminator: a message filling the whole
	// buffer would otherwise make the write below land one past the array.
	const ssize_t nrecv = recv(socket, buffer, sizeof(buffer) - 1, 0);
	if (nrecv == -1)
	{
		perror("recv");
		return {};
	}

	buffer[nrecv] = '\0';
	return BAN::String(buffer);
}
// Resolver service loop.
//
// Creates the client-facing service socket and a UDP socket for DNS traffic,
// then serves clients one at a time forever: read the queried name, answer from
// the cache while the entry is still valid, otherwise query the nameserver and
// cache a successful result. The reply is a sockaddr_storage with family
// AF_INET holding the resolved address, or INADDR_ANY when resolution failed.
//
// Returns 1 if either socket cannot be created; otherwise it does not return.
int main(int, char**)
{
	srand(time(nullptr));

	int service_socket = create_service_socket();
	if (service_socket == -1)
		return 1;

	int dns_socket = socket(AF_INET, SOCK_DGRAM, 0);
	if (dns_socket == -1)
	{
		perror("socket");
		return 1;
	}

	BAN::HashMap<BAN::String, DNSEntry> dns_cache;

	for (;;)
	{
		int client = accept(service_socket, nullptr, nullptr);
		if (client == -1)
		{
			perror("accept");
			continue;
		}

		auto query = read_service_query(client);
		if (!query.has_value())
		{
			close(client);
			continue;
		}

		BAN::Optional<DNSEntry> result;

		// Serve from the cache while the entry is valid; drop it once expired.
		if (dns_cache.contains(*query))
		{
			auto& cached = dns_cache[*query];
			if (time(nullptr) <= cached.valid_until)
				result = cached;
			else
				dns_cache.remove(*query);
		}

		if (!result.has_value())
		{
			// Mask instead of `% 0xFFFF` so that every 16-bit id, including
			// 0xFFFF, can be produced.
			const uint16_t id = static_cast<uint16_t>(rand() & 0xFFFF);
			if (send_dns_query(dns_socket, *query, id))
			{
				result = read_dns_response(dns_socket, id);
				// Caching is best effort: a failed insert only costs a later lookup.
				if (result.has_value())
					(void)dns_cache.insert(*query, *result);
			}
		}

		// Failure is reported to the client as the INADDR_ANY address.
		if (!result.has_value())
			result = DNSEntry { .valid_until = 0, .address = BAN::IPv4Address(INADDR_ANY) };

		// The whole structure is sent to the client, so it is zeroed first to
		// avoid leaking uninitialised stack bytes beyond the 4 address bytes.
		sockaddr_storage storage {};
		storage.ss_family = AF_INET;
		memcpy(storage.ss_storage, &result->address.raw, sizeof(result->address.raw));

		if (send(client, &storage, sizeof(storage), 0) == -1)
			perror("send");

		close(client);
	}

	return 0;
}