resolver: use select for client communication

This commit is contained in:
Bananymous 2024-02-12 23:47:39 +02:00
parent 2ab3eb4109
commit 420a7b60ca
1 changed files with 117 additions and 44 deletions

View File

@ -11,6 +11,7 @@
#include <netinet/in.h> #include <netinet/in.h>
#include <stdio.h> #include <stdio.h>
#include <stdlib.h> #include <stdlib.h>
#include <sys/select.h>
#include <sys/socket.h> #include <sys/socket.h>
#include <sys/stat.h> #include <sys/stat.h>
#include <sys/un.h> #include <sys/un.h>
@ -53,6 +54,12 @@ struct DNSEntry
BAN::IPv4Address address { 0 }; BAN::IPv4Address address { 0 };
}; };
struct DNSResponse
{
uint16_t id;
DNSEntry entry;
};
bool send_dns_query(int socket, BAN::StringView domain, uint16_t id) bool send_dns_query(int socket, BAN::StringView domain, uint16_t id)
{ {
static uint8_t buffer[4096]; static uint8_t buffer[4096];
@ -91,7 +98,7 @@ bool send_dns_query(int socket, BAN::StringView domain, uint16_t id)
return true; return true;
} }
BAN::Optional<DNSEntry> read_dns_response(int socket, uint16_t id) BAN::Optional<DNSResponse> read_dns_response(int socket)
{ {
static uint8_t buffer[4096]; static uint8_t buffer[4096];
@ -103,11 +110,6 @@ BAN::Optional<DNSEntry> read_dns_response(int socket, uint16_t id)
} }
DNSPacket& reply = *reinterpret_cast<DNSPacket*>(buffer); DNSPacket& reply = *reinterpret_cast<DNSPacket*>(buffer);
if (reply.identification != id)
{
dprintln("Reply to invalid packet");
return {};
}
if (reply.flags & 0x0F) if (reply.flags & 0x0F)
{ {
dprintln("DNS error (rcode {})", (unsigned)(reply.flags & 0xF)); dprintln("DNS error (rcode {})", (unsigned)(reply.flags & 0xF));
@ -134,9 +136,10 @@ BAN::Optional<DNSEntry> read_dns_response(int socket, uint16_t id)
return {}; return {};
} }
DNSEntry result; DNSResponse result;
result.valid_until = time(nullptr) + answer.ttl(); result.id = reply.identification;
result.address = BAN::IPv4Address(*reinterpret_cast<uint32_t*>(answer.data)); result.entry.valid_until = time(nullptr) + answer.ttl();
result.entry.address = BAN::IPv4Address(*reinterpret_cast<uint32_t*>(answer.data));
return result; return result;
} }
@ -207,55 +210,125 @@ int main(int, char**)
BAN::HashMap<BAN::String, DNSEntry> dns_cache; BAN::HashMap<BAN::String, DNSEntry> dns_cache;
struct Client
{
const int socket;
bool close { false };
uint16_t query_id { 0 };
BAN::String query;
};
BAN::LinkedList<Client> clients;
for (;;) for (;;)
{ {
int client = accept(service_socket, nullptr, nullptr); int max_sock = BAN::Math::max(service_socket, dns_socket);
if (client == -1)
fd_set fds;
FD_ZERO(&fds);
FD_SET(service_socket, &fds);
FD_SET(dns_socket, &fds);
for (auto& client : clients)
{ {
dprintln("accept: {}", strerror(errno)); FD_SET(client.socket, &fds);
max_sock = BAN::Math::max(max_sock, client.socket);
}
int nselect = select(max_sock + 1, &fds, nullptr, nullptr, nullptr);
if (nselect == -1)
{
perror("select");
continue; continue;
} }
auto query = read_service_query(client); if (FD_ISSET(service_socket, &fds))
if (!query.has_value())
{ {
close(client); int client = accept(service_socket, nullptr, nullptr);
continue; if (client == -1)
}
BAN::Optional<DNSEntry> result;
if (dns_cache.contains(*query))
{
auto& cached = dns_cache[*query];
if (time(nullptr) <= cached.valid_until)
result = cached;
else
dns_cache.remove(*query);
}
if (!result.has_value())
{
uint16_t id = rand() % 0xFFFF;
if (send_dns_query(dns_socket, *query, id))
{ {
result = read_dns_response(dns_socket, id); perror("accept");
if (result.has_value()) continue;
(void)dns_cache.insert(*query, *result); }
MUST(clients.emplace_back(client));
}
if (FD_ISSET(dns_socket, &fds))
{
auto result = read_dns_response(dns_socket);
if (!result.has_value())
continue;
for (auto& client : clients)
{
if (client.query_id != result->id)
continue;
(void)dns_cache.insert(client.query, result->entry);
sockaddr_storage storage;
storage.ss_family = AF_INET;
memcpy(storage.ss_storage, &result->entry.address.raw, sizeof(result->entry.address.raw));
if (send(client.socket, &storage, sizeof(storage), 0) == -1)
dprintln("send: {}", strerror(errno));
client.close = true;
break;
} }
} }
if (!result.has_value()) for (auto& client : clients)
result = DNSEntry { .valid_until = 0, .address = BAN::IPv4Address(INADDR_ANY) }; {
if (!FD_ISSET(client.socket, &fds))
continue;
sockaddr_storage storage; if (!client.query.empty())
storage.ss_family = AF_INET; {
memcpy(storage.ss_storage, &result->address.raw, sizeof(result->address.raw)); dprintln("Client already has a query");
continue;
}
if (send(client, &storage, sizeof(storage), 0) == -1) auto query = read_service_query(client.socket);
dprintln("send: {}", strerror(errno)); if (!query.has_value())
continue;
close(client); BAN::Optional<DNSEntry> result;
if (dns_cache.contains(*query))
{
auto& cached = dns_cache[*query];
if (time(nullptr) <= cached.valid_until)
result = cached;
else
dns_cache.remove(*query);
}
if (result.has_value())
{
sockaddr_storage storage;
storage.ss_family = AF_INET;
memcpy(storage.ss_storage, &result->address.raw, sizeof(result->address.raw));
if (send(client.socket, &storage, sizeof(storage), 0) == -1)
dprintln("send: {}", strerror(errno));
client.close = true;
continue;
}
client.query = query.release_value();
client.query_id = rand() % 0xFFFF;
send_dns_query(dns_socket, client.query, client.query_id);
}
for (auto it = clients.begin(); it != clients.end();)
{
if (!it->close)
{
it++;
continue;
}
close(it->socket);
it = clients.remove(it);
}
} }
return 0; return 0;