diff --git a/userspace/resolver/main.cpp b/userspace/resolver/main.cpp index a98fcb99c2..ac4bfb7290 100644 --- a/userspace/resolver/main.cpp +++ b/userspace/resolver/main.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -53,6 +54,12 @@ struct DNSEntry BAN::IPv4Address address { 0 }; }; +struct DNSResponse +{ + uint16_t id; + DNSEntry entry; +}; + bool send_dns_query(int socket, BAN::StringView domain, uint16_t id) { static uint8_t buffer[4096]; @@ -91,7 +98,7 @@ bool send_dns_query(int socket, BAN::StringView domain, uint16_t id) return true; } -BAN::Optional read_dns_response(int socket, uint16_t id) +BAN::Optional read_dns_response(int socket) { static uint8_t buffer[4096]; @@ -103,11 +110,6 @@ BAN::Optional read_dns_response(int socket, uint16_t id) } DNSPacket& reply = *reinterpret_cast(buffer); - if (reply.identification != id) - { - dprintln("Reply to invalid packet"); - return {}; - } if (reply.flags & 0x0F) { dprintln("DNS error (rcode {})", (unsigned)(reply.flags & 0xF)); @@ -134,9 +136,10 @@ BAN::Optional read_dns_response(int socket, uint16_t id) return {}; } - DNSEntry result; - result.valid_until = time(nullptr) + answer.ttl(); - result.address = BAN::IPv4Address(*reinterpret_cast(answer.data)); + DNSResponse result; + result.id = reply.identification; + result.entry.valid_until = time(nullptr) + answer.ttl(); + result.entry.address = BAN::IPv4Address(*reinterpret_cast(answer.data)); return result; } @@ -207,55 +210,125 @@ int main(int, char**) BAN::HashMap dns_cache; + struct Client + { + const int socket; + bool close { false }; + uint16_t query_id { 0 }; + BAN::String query; + }; + + BAN::LinkedList clients; + for (;;) { - int client = accept(service_socket, nullptr, nullptr); - if (client == -1) + int max_sock = BAN::Math::max(service_socket, dns_socket); + + fd_set fds; + FD_ZERO(&fds); + FD_SET(service_socket, &fds); + FD_SET(dns_socket, &fds); + for (auto& client : clients) { - dprintln("accept: {}", strerror(errno)); + FD_SET(client.socket, &fds); + max_sock = BAN::Math::max(max_sock, client.socket); + } + + int nselect = select(max_sock + 1, &fds, nullptr, nullptr, nullptr); + if (nselect == -1) + { + perror("select"); continue; } - auto query = read_service_query(client); - if (!query.has_value()) + if (FD_ISSET(service_socket, &fds)) { - close(client); - continue; - } - - BAN::Optional result; - - if (dns_cache.contains(*query)) - { - auto& cached = dns_cache[*query]; - if (time(nullptr) <= cached.valid_until) - result = cached; - else - dns_cache.remove(*query); - } - - if (!result.has_value()) - { - uint16_t id = rand() % 0xFFFF; - if (send_dns_query(dns_socket, *query, id)) + int client = accept(service_socket, nullptr, nullptr); + if (client == -1) { - result = read_dns_response(dns_socket, id); - if (result.has_value()) - (void)dns_cache.insert(*query, *result); + perror("accept"); + continue; + } + + MUST(clients.emplace_back(client)); + } + + if (FD_ISSET(dns_socket, &fds)) + { + auto result = read_dns_response(dns_socket); + if (!result.has_value()) + continue; + + for (auto& client : clients) + { + if (client.query_id != result->id) + continue; + + (void)dns_cache.insert(client.query, result->entry); + + sockaddr_storage storage; + storage.ss_family = AF_INET; + memcpy(storage.ss_storage, &result->entry.address.raw, sizeof(result->entry.address.raw)); + if (send(client.socket, &storage, sizeof(storage), 0) == -1) + dprintln("send: {}", strerror(errno)); + client.close = true; + break; } } - if (!result.has_value()) - result = DNSEntry { .valid_until = 0, .address = BAN::IPv4Address(INADDR_ANY) }; + for (auto& client : clients) + { + if (!FD_ISSET(client.socket, &fds)) + continue; - sockaddr_storage storage; - storage.ss_family = AF_INET; - memcpy(storage.ss_storage, &result->address.raw, sizeof(result->address.raw)); + if (!client.query.empty()) + { + dprintln("Client already has a query"); + continue; + } - if (send(client, &storage, sizeof(storage), 0) == -1) - dprintln("send: {}", strerror(errno)); + auto query = read_service_query(client.socket); + if (!query.has_value()) + continue; - close(client); + BAN::Optional result; + + if (dns_cache.contains(*query)) + { + auto& cached = dns_cache[*query]; + if (time(nullptr) <= cached.valid_until) + result = cached; + else + dns_cache.remove(*query); + } + + if (result.has_value()) + { + sockaddr_storage storage; + storage.ss_family = AF_INET; + memcpy(storage.ss_storage, &result->address.raw, sizeof(result->address.raw)); + if (send(client.socket, &storage, sizeof(storage), 0) == -1) + dprintln("send: {}", strerror(errno)); + client.close = true; + continue; + } + + client.query = query.release_value(); + client.query_id = rand() % 0xFFFF; + send_dns_query(dns_socket, client.query, client.query_id); + } + + for (auto it = clients.begin(); it != clients.end();) + { + if (!it->close) + { + it++; + continue; + } + + close(it->socket); + it = clients.remove(it); + } } return 0;