#include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include struct DNSPacket { BAN::NetworkEndian identification { 0 }; BAN::NetworkEndian flags { 0 }; BAN::NetworkEndian question_count { 0 }; BAN::NetworkEndian answer_count { 0 }; BAN::NetworkEndian authority_RR_count { 0 }; BAN::NetworkEndian additional_RR_count { 0 }; uint8_t data[]; }; static_assert(sizeof(DNSPacket) == 12); struct DNSAnswer { uint8_t __storage[12]; BAN::NetworkEndian& name() { return *reinterpret_cast*>(__storage + 0x00); }; BAN::NetworkEndian& type() { return *reinterpret_cast*>(__storage + 0x02); }; BAN::NetworkEndian& class_() { return *reinterpret_cast*>(__storage + 0x04); }; BAN::NetworkEndian& ttl() { return *reinterpret_cast*>(__storage + 0x06); }; BAN::NetworkEndian& data_len() { return *reinterpret_cast*>(__storage + 0x0A); }; uint8_t data[]; }; static_assert(sizeof(DNSAnswer) == 12); enum QTYPE : uint16_t { A = 0x0001, CNAME = 0x0005, AAAA = 0x001C, }; struct DNSEntry { time_t valid_until { 0 }; BAN::IPv4Address address { 0 }; }; bool send_dns_query(int socket, BAN::StringView domain, uint16_t id) { static uint8_t buffer[4096]; memset(buffer, 0, sizeof(buffer)); DNSPacket& request = *reinterpret_cast(buffer); request.identification = id; request.flags = 0x0100; request.question_count = 1; size_t idx = 0; auto labels = MUST(BAN::StringView(domain).split('.')); for (auto label : labels) { ASSERT(label.size() <= 0xFF); request.data[idx++] = label.size(); for (char c : label) request.data[idx++] = c; } request.data[idx++] = 0x00; *(uint16_t*)&request.data[idx] = htons(QTYPE::A); idx += 2; *(uint16_t*)&request.data[idx] = htons(0x0001); idx += 2; sockaddr_in nameserver; nameserver.sin_family = AF_INET; nameserver.sin_port = htons(53); nameserver.sin_addr.s_addr = inet_addr("8.8.8.8"); if (sendto(socket, &request, sizeof(DNSPacket) + idx, 0, (sockaddr*)&nameserver, sizeof(nameserver)) == -1) { dprintln("sendto: {}", strerror(errno)); return false; } return true; } BAN::Optional read_dns_response(int socket, uint16_t id) { static uint8_t buffer[4096]; ssize_t nrecv = recvfrom(socket, buffer, sizeof(buffer), 0, nullptr, nullptr); if (nrecv == -1) { dprintln("recvfrom: {}", strerror(errno)); return {}; } DNSPacket& reply = *reinterpret_cast(buffer); if (reply.identification != id) { dprintln("Reply to invalid packet"); return {}; } if (reply.flags & 0x0F) { dprintln("DNS error (rcode {})", (unsigned)(reply.flags & 0xF)); return {}; } size_t idx = 0; for (size_t i = 0; i < reply.question_count; i++) { while (reply.data[idx]) idx += reply.data[idx] + 1; idx += 5; } DNSAnswer& answer = *reinterpret_cast(&reply.data[idx]); if (answer.type() != QTYPE::A) { dprintln("Not A record"); return {}; } if (answer.data_len() != 4) { dprintln("corrupted package"); return {}; } DNSEntry result; result.valid_until = time(nullptr) + answer.ttl(); result.address = BAN::IPv4Address(*reinterpret_cast(answer.data)); return result; } int create_service_socket() { int socket = ::socket(AF_UNIX, SOCK_SEQPACKET, 0); if (socket == -1) { dprintln("socket: {}", strerror(errno)); return -1; } sockaddr_un addr; addr.sun_family = AF_UNIX; strcpy(addr.sun_path, "/tmp/resolver.sock"); if (bind(socket, (sockaddr*)&addr, sizeof(addr)) == -1) { dprintln("bind: {}", strerror(errno)); close(socket); return -1; } if (chmod("/tmp/resolver.sock", 0777) == -1) { dprintln("chmod: {}", strerror(errno)); close(socket); return -1; } if (listen(socket, 10) == -1) { dprintln("listen: {}", strerror(errno)); close(socket); return -1; } return socket; } BAN::Optional read_service_query(int socket) { static char buffer[4096]; ssize_t nrecv = recv(socket, buffer, sizeof(buffer), 0); if (nrecv == -1) { dprintln("recv: {}", strerror(errno)); return {}; } buffer[nrecv] = '\0'; return BAN::String(buffer); } int main(int, char**) { srand(time(nullptr)); int service_socket = create_service_socket(); if (service_socket == -1) return 1; int dns_socket = socket(AF_INET, SOCK_DGRAM, 0); if (dns_socket == -1) { dprintln("socket: {}", strerror(errno)); return 1; } BAN::HashMap dns_cache; for (;;) { int client = accept(service_socket, nullptr, nullptr); if (client == -1) { dprintln("accept: {}", strerror(errno)); continue; } auto query = read_service_query(client); if (!query.has_value()) { close(client); continue; } BAN::Optional result; if (dns_cache.contains(*query)) { auto& cached = dns_cache[*query]; if (time(nullptr) <= cached.valid_until) result = cached; else dns_cache.remove(*query); } if (!result.has_value()) { uint16_t id = rand() % 0xFFFF; if (send_dns_query(dns_socket, *query, id)) { result = read_dns_response(dns_socket, id); if (result.has_value()) (void)dns_cache.insert(*query, *result); } } if (!result.has_value()) result = DNSEntry { .valid_until = 0, .address = BAN::IPv4Address(INADDR_ANY) }; sockaddr_storage storage; storage.ss_family = AF_INET; memcpy(storage.ss_storage, &result->address.raw, sizeof(result->address.raw)); if (send(client, &storage, sizeof(storage), 0) == -1) dprintln("send: {}", strerror(errno)); close(client); } return 0; }