diff --git a/userspace/nslookup/main.cpp b/userspace/nslookup/main.cpp index 6943ca48a0..c053f3be8b 100644 --- a/userspace/nslookup/main.cpp +++ b/userspace/nslookup/main.cpp @@ -1,7 +1,12 @@ +#include +#include #include #include #include #include +#include + +#define MAX(a, b) ((a) < (b) ? (b) : (a)) int main(int argc, char** argv) { @@ -33,15 +38,17 @@ int main(int argc, char** argv) return 1; } - char buffer[128]; - ssize_t nrecv = recv(socket, buffer, sizeof(buffer), 0); - if (nrecv == -1) + sockaddr_storage storage; + if (recv(socket, &storage, sizeof(storage), 0) == -1) { perror("recv"); return 1; } - buffer[nrecv] = '\0'; - printf("%s\n", buffer); + close(socket); + + char buffer[MAX(INET_ADDRSTRLEN, INET6_ADDRSTRLEN)]; + printf("%s\n", inet_ntop(storage.ss_family, storage.ss_storage, buffer, sizeof(buffer))); + return 0; } diff --git a/userspace/resolver/main.cpp b/userspace/resolver/main.cpp index 7f2f80e974..7d557d68b6 100644 --- a/userspace/resolver/main.cpp +++ b/userspace/resolver/main.cpp @@ -39,6 +39,19 @@ struct DNSAnswer }; static_assert(sizeof(DNSAnswer) == 12); +enum QTYPE : uint16_t +{ + A = 0x0001, + CNAME = 0x0005, + AAAA = 0x001C, +}; + +struct DNSEntry +{ + time_t valid_until { 0 }; + BAN::IPv4Address address { 0 }; +}; + bool send_dns_query(int socket, BAN::StringView domain, uint16_t id) { static uint8_t buffer[4096]; @@ -61,8 +74,8 @@ bool send_dns_query(int socket, BAN::StringView domain, uint16_t id) } request.data[idx++] = 0x00; - *(uint16_t*)&request.data[idx] = htons(0x01); idx += 2; - *(uint16_t*)&request.data[idx] = htons(0x01); idx += 2; + *(uint16_t*)&request.data[idx] = htons(QTYPE::A); idx += 2; + *(uint16_t*)&request.data[idx] = htons(0x0001); idx += 2; sockaddr_in nameserver; nameserver.sin_family = AF_INET; @@ -77,7 +90,7 @@ bool send_dns_query(int socket, BAN::StringView domain, uint16_t id) return true; } -BAN::Optional read_dns_response(int socket, uint16_t id) +BAN::Optional read_dns_response(int socket, uint16_t id) { static uint8_t buffer[4096]; @@ -109,13 +122,22 @@ BAN::Optional read_dns_response(int socket, uint16_t id) } DNSAnswer& answer = *reinterpret_cast(&reply.data[idx]); + if (answer.type() != QTYPE::A) + { + fprintf(stderr, "Not A record\n"); + return {}; + } if (answer.data_len() != 4) { - fprintf(stderr, "Not IPv4\n"); + fprintf(stderr, "corrupted package\n"); return {}; } - return inet_ntoa({ .s_addr = *reinterpret_cast(answer.data) }); + DNSEntry result; + result.valid_until = time(nullptr) + answer.ttl(); + result.address = BAN::IPv4Address(*reinterpret_cast(answer.data)); + + return result; } int create_service_socket() @@ -182,6 +204,8 @@ int main(int, char**) return 1; } + BAN::HashMap dns_cache; + for (;;) { int client = accept(service_socket, nullptr, nullptr); @@ -193,24 +217,43 @@ int main(int, char**) auto query = read_service_query(client); if (!query.has_value()) - continue; - - uint16_t id = rand() % 0xFFFF; - - if (send_dns_query(dns_socket, *query, id)) { - auto response = read_dns_response(dns_socket, id); - if (response.has_value()) + close(client); + continue; + } + + BAN::Optional result; + + if (dns_cache.contains(*query)) + { + auto& cached = dns_cache[*query]; + if (time(nullptr) <= cached.valid_until) + result = cached; + else + dns_cache.remove(*query); + } + + if (!result.has_value()) + { + uint16_t id = rand() % 0xFFFF; + if (send_dns_query(dns_socket, *query, id)) { - if (send(client, response->data(), response->size() + 1, 0) == -1) - perror("send"); - close(client); - continue; + result = read_dns_response(dns_socket, id); + if (result.has_value()) + (void)dns_cache.insert(*query, *result); } } - char message[] = "unavailable"; - send(client, message, sizeof(message), 0); + if (!result.has_value()) + result = DNSEntry { .valid_until = 0, .address = BAN::IPv4Address(INADDR_ANY) }; + + sockaddr_storage storage; + storage.ss_family = AF_INET; + memcpy(storage.ss_storage, &result->address.raw, sizeof(result->address.raw)); + + if (send(client, &storage, sizeof(storage), 0) == -1) + perror("send"); + close(client); }