diff --git a/userspace/libraries/LibC/netdb.cpp b/userspace/libraries/LibC/netdb.cpp index a7ec5944..14231326 100644 --- a/userspace/libraries/LibC/netdb.cpp +++ b/userspace/libraries/LibC/netdb.cpp @@ -107,7 +107,7 @@ int getaddrinfo(const char* __restrict nodename, const char* __restrict servname goto error_close_socket; sockaddr_storage storage; - if (recv(resolver_sock, &storage, sizeof(storage), 0) == -1) + if (recv(resolver_sock, &storage, sizeof(storage), 0) < static_cast(sizeof(sockaddr_in))) goto error_close_socket; close(resolver_sock); @@ -115,12 +115,12 @@ int getaddrinfo(const char* __restrict nodename, const char* __restrict servname if (storage.ss_family != AF_INET) return EAI_FAIL; - ipv4_addr = *reinterpret_cast(storage.ss_storage); + ipv4_addr = reinterpret_cast(storage).sin_addr.s_addr; } { addrinfo* ai = (addrinfo*)malloc(sizeof(addrinfo) + sizeof(sockaddr_in)); - if (*res == nullptr) + if (ai == nullptr) return EAI_MEMORY; sockaddr_in* sa_in = reinterpret_cast(reinterpret_cast(ai) + sizeof(addrinfo)); @@ -193,7 +193,7 @@ struct hostent* gethostbyname(const char* name) goto error_close_socket; sockaddr_storage storage; - if (recv(socket, &storage, sizeof(storage), 0) == -1) + if (recv(socket, &storage, sizeof(storage), 0) < static_cast(sizeof(sockaddr_in))) goto error_close_socket; close(socket); @@ -201,7 +201,7 @@ struct hostent* gethostbyname(const char* name) if (storage.ss_family != AF_INET) return nullptr; - addr_buffer = *reinterpret_cast(storage.ss_storage); + addr_buffer = reinterpret_cast(storage).sin_addr.s_addr; } return &hostent; diff --git a/userspace/programs/nslookup/main.cpp b/userspace/programs/nslookup/main.cpp index c053f3be..ba095d0a 100644 --- a/userspace/programs/nslookup/main.cpp +++ b/userspace/programs/nslookup/main.cpp @@ -1,10 +1,7 @@ #include +#include #include #include -#include -#include -#include -#include #define MAX(a, b) ((a) < (b) ? (b) : (a)) @@ -16,39 +13,37 @@ int main(int argc, char** argv) return 1; } - int socket = ::socket(AF_UNIX, SOCK_SEQPACKET, 0); - if (socket == -1) + const addrinfo hints { + .ai_flags = 0, + .ai_family = AF_INET, + .ai_socktype = SOCK_STREAM, + .ai_protocol = 0, + .ai_addrlen = 0, + .ai_addr = nullptr, + .ai_canonname = nullptr, + .ai_next = nullptr, + }; + + addrinfo* result; + if (int ret = getaddrinfo(argv[1], nullptr, &hints, &result); ret != 0) { - perror("socket"); + fprintf(stderr, "getaddrinfo: %s\n", gai_strerror(ret)); return 1; } - sockaddr_un addr; - addr.sun_family = AF_UNIX; - strcpy(addr.sun_path, "/tmp/resolver.sock"); - if (connect(socket, (sockaddr*)&addr, sizeof(addr)) == -1) + for (addrinfo* ai = result; ai; ai = ai->ai_next) { - perror("connect"); - return 1; + if (ai->ai_family != AF_INET) + continue; + + char buffer[NI_MAXHOST]; + if (inet_ntop(ai->ai_family, &reinterpret_cast(ai->ai_addr)->sin_addr, buffer, sizeof(buffer)) == nullptr) + continue; + + printf("%s\n", buffer); + return 0; } - if (send(socket, argv[1], strlen(argv[1]), 0) == -1) - { - perror("send"); - return 1; - } - - sockaddr_storage storage; - if (recv(socket, &storage, sizeof(storage), 0) == -1) - { - perror("recv"); - return 1; - } - - close(socket); - - char buffer[MAX(INET_ADDRSTRLEN, INET6_ADDRSTRLEN)]; - printf("%s\n", inet_ntop(storage.ss_family, storage.ss_storage, buffer, sizeof(buffer))); - + fprintf(stderr, "no address information available\n"); return 0; } diff --git a/userspace/programs/resolver/CMakeLists.txt b/userspace/programs/resolver/CMakeLists.txt index ca488ed9..1139ecc6 100644 --- a/userspace/programs/resolver/CMakeLists.txt +++ b/userspace/programs/resolver/CMakeLists.txt @@ -6,4 +6,6 @@ add_executable(resolver ${SOURCES}) banan_link_library(resolver ban) banan_link_library(resolver libc) +target_compile_options(resolver PRIVATE -Wno-maybe-uninitialized) + install(TARGETS resolver OPTIONAL) diff --git a/userspace/programs/resolver/main.cpp b/userspace/programs/resolver/main.cpp index ecc1bbde..3201908d 100644 --- a/userspace/programs/resolver/main.cpp +++ b/userspace/programs/resolver/main.cpp @@ -43,6 +43,7 @@ static_assert(sizeof(DNSAnswer) == 12); enum QTYPE : uint16_t { + INVALID = 0x0000, A = 0x0001, CNAME = 0x0005, AAAA = 0x001C, @@ -50,14 +51,83 @@ enum QTYPE : uint16_t struct DNSEntry { - time_t valid_until { 0 }; - BAN::IPv4Address address { 0 }; + DNSEntry(BAN::IPv4Address&& address, time_t valid_until) + : type(QTYPE::A) + , valid_until(valid_until) + , address(BAN::move(address)) + {} + + DNSEntry(BAN::String&& cname, time_t valid_until) + : type(QTYPE::CNAME) + , valid_until(valid_until) + , cname(BAN::move(cname)) + {} + + DNSEntry(DNSEntry&& other) + { + *this = BAN::move(other); + } + + ~DNSEntry() { clear(); } + + DNSEntry& operator=(DNSEntry&& other) + { + clear(); + valid_until = other.valid_until; + switch (type = other.type) + { + case QTYPE::A: + new (&address) BAN::IPv4Address(BAN::move(other.address)); + break; + case QTYPE::CNAME: + new (&cname) BAN::String(BAN::move(other.cname)); + break; + case QTYPE::INVALID: + case QTYPE::AAAA: + ASSERT_NOT_REACHED(); + } + other.clear(); + return *this; + } + + void clear() + { + switch (type) + { + case QTYPE::A: + using BAN::IPv4Address; + address.~IPv4Address(); + break; + case QTYPE::CNAME: + using BAN::String; + cname.~String(); + break; + case QTYPE::AAAA: + ASSERT_NOT_REACHED(); + case QTYPE::INVALID: + break; + } + type = QTYPE::INVALID; + } + + QTYPE type; + time_t valid_until; + union { + BAN::IPv4Address address; + BAN::String cname; + }; }; struct DNSResponse { + struct NameEntryPair + { + BAN::String name; + DNSEntry entry; + }; + uint16_t id; - DNSEntry entry; + BAN::Vector entries; }; bool send_dns_query(int socket, BAN::StringView domain, uint16_t id) @@ -110,36 +180,83 @@ BAN::Optional read_dns_response(int socket) } DNSPacket& reply = *reinterpret_cast(buffer); - if (reply.flags & 0x0F) - { - dprintln("DNS error (rcode {})", (unsigned)(reply.flags & 0xF)); - return {}; - } - - size_t idx = 0; - for (size_t i = 0; i < reply.question_count; i++) - { - while (reply.data[idx]) - idx += reply.data[idx] + 1; - idx += 5; - } - - DNSAnswer& answer = *reinterpret_cast(&reply.data[idx]); - if (answer.type() != QTYPE::A) - { - dprintln("Not A record, but {}", static_cast(answer.type())); - return {}; - } - if (answer.data_len() != 4) - { - dprintln("corrupted package"); - return {}; - } DNSResponse result; result.id = reply.identification; - result.entry.valid_until = time(nullptr) + answer.ttl(); - result.entry.address = BAN::IPv4Address(*reinterpret_cast(answer.data)); + + if (reply.flags & 0x0F) + { + dprintln("DNS error (rcode {})", (unsigned)(reply.flags & 0xF)); + return result; + } + + size_t idx = reply.data - buffer; + for (size_t i = 0; i < reply.question_count; i++) + { + while (buffer[idx]) + idx += buffer[idx] + 1; + idx += 5; + } + + const auto read_name = + [](size_t idx) -> BAN::String + { + BAN::String result; + while (buffer[idx]) + { + if ((buffer[idx] & 0xC0) == 0xC0) + { + idx = ((buffer[idx] & 0x3F) << 8) | buffer[idx + 1]; + continue; + } + + MUST(result.append(BAN::StringView(reinterpret_cast(&buffer[idx + 1]), buffer[idx]))); + MUST(result.push_back('.')); + idx += buffer[idx] + 1; + } + + if (!result.empty()) + result.pop_back(); + return result; + }; + + for (size_t i = 0; i < reply.answer_count; i++) + { + auto& answer = *reinterpret_cast(&buffer[idx]); + + auto name = read_name(answer.__storage - buffer); + + if (answer.type() == QTYPE::A) + { + if (answer.data_len() != 4) + { + dprintln("Invalid A record size {}", (uint16_t)answer.data_len()); + return result; + } + + MUST(result.entries.push_back({ + .name = BAN::move(name), + .entry = { + BAN::IPv4Address(*reinterpret_cast(answer.data)), + time(nullptr) + answer.ttl(), + }, + })); + } + else if (answer.type() == QTYPE::CNAME) + { + auto target = read_name(answer.data - buffer); + + MUST(result.entries.push_back({ + .name = BAN::move(name), + .entry = { + BAN::move(target), + time(nullptr) + answer.ttl() + }, + })); + } + + idx += sizeof(DNSAnswer) + answer.data_len(); + } return result; } @@ -193,6 +310,33 @@ BAN::Optional read_service_query(int socket) return BAN::String(buffer); } +BAN::Optional resolve_from_dns_cache(BAN::HashMap& dns_cache, const BAN::String& domain) +{ + for (auto it = dns_cache.find(domain); it != dns_cache.end();) + { + if (time(nullptr) > it->value.valid_until) + { + dprintln("{} timedout", it->key); + dns_cache.remove(it); + return {}; + } + + switch (it->value.type) + { + case QTYPE::A: + return it->value.address; + case QTYPE::CNAME: + it = dns_cache.find(it->value.cname); + break; + case QTYPE::AAAA: + case QTYPE::INVALID: + ASSERT_NOT_REACHED(); + } + } + + return {}; +} + int main(int, char**) { srand(time(nullptr)); @@ -266,17 +410,42 @@ int main(int, char**) if (!result.has_value()) continue; + for (auto&& [name, entry] : result->entries) + MUST(dns_cache.insert_or_assign(BAN::move(name), BAN::move(entry))); + for (auto& client : clients) { if (client.query_id != result->id) continue; - (void)dns_cache.insert(client.query, result->entry); + auto resolved = resolve_from_dns_cache(dns_cache, client.query); + if (!resolved.has_value()) + { + auto it = dns_cache.find(client.query); + if (it == dns_cache.end()) + { + client.close = true; + break; + } + for (;;) + { + ASSERT(it->value.type == QTYPE::CNAME); + auto next = dns_cache.find(it->value.cname); + if (next == dns_cache.end()) + break; + it = next; + } + send_dns_query(service_socket, it->value.cname, client.query_id); + break; + } - sockaddr_storage storage; - storage.ss_family = AF_INET; - memcpy(storage.ss_storage, &result->entry.address.raw, sizeof(result->entry.address.raw)); - if (send(client.socket, &storage, sizeof(storage), 0) == -1) + const sockaddr_in addr { + .sin_family = AF_INET, + .sin_port = 0, + .sin_addr = { .s_addr = resolved->raw }, + }; + + if (send(client.socket, &addr, sizeof(addr), 0) == -1) dprintln("send: {}", strerror(errno)); client.close = true; break; @@ -308,30 +477,22 @@ int main(int, char**) continue; } - BAN::Optional result; + BAN::Optional result; if (*hostname && strcmp(query->data(), hostname) == 0) - { - result = DNSEntry { - .valid_until = time(nullptr), - .address = ntohl(INADDR_LOOPBACK), - }; - } - else if (dns_cache.contains(*query)) - { - auto& cached = dns_cache[*query]; - if (time(nullptr) <= cached.valid_until) - result = cached; - else - dns_cache.remove(*query); - } + result = BAN::IPv4Address(ntohl(INADDR_LOOPBACK)); + else if (auto resolved = resolve_from_dns_cache(dns_cache, query.value()); resolved.has_value()) + result = resolved.release_value(); if (result.has_value()) { - sockaddr_storage storage; - storage.ss_family = AF_INET; - memcpy(storage.ss_storage, &result->address.raw, sizeof(result->address.raw)); - if (send(client.socket, &storage, sizeof(storage), 0) == -1) + const sockaddr_in addr { + .sin_family = AF_INET, + .sin_port = 0, + .sin_addr = { .s_addr = result->raw }, + }; + + if (send(client.socket, &addr, sizeof(addr), 0) == -1) dprintln("send: {}", strerror(errno)); client.close = true; continue; diff --git a/userspace/tests/test-tcp/main.cpp b/userspace/tests/test-tcp/main.cpp index b767996a..b2d64788 100644 --- a/userspace/tests/test-tcp/main.cpp +++ b/userspace/tests/test-tcp/main.cpp @@ -1,4 +1,5 @@ #include +#include #include #include #include @@ -8,44 +9,26 @@ in_addr_t get_ipv4_address(const char* query) { - if (in_addr_t ipv4 = inet_addr(query); ipv4 != (in_addr_t)(-1)) - return ipv4; + const addrinfo hints { + .ai_flags = 0, + .ai_family = AF_INET, + .ai_socktype = SOCK_STREAM, + .ai_protocol = 0, + .ai_addrlen = 0, + .ai_addr = nullptr, + .ai_canonname = nullptr, + .ai_next = nullptr, + }; - int socket = ::socket(AF_UNIX, SOCK_SEQPACKET, 0); - if (socket == -1) - { - perror("socket"); + addrinfo* result; + if (getaddrinfo(query, nullptr, &hints, &result) != 0) return -1; - } - sockaddr_un addr; - addr.sun_family = AF_UNIX; - strcpy(addr.sun_path, "/tmp/resolver.sock"); - if (connect(socket, (sockaddr*)&addr, sizeof(addr)) == -1) - { - perror("connect"); - close(socket); - return -1; - } + for (addrinfo* ai = result; ai; ai = ai->ai_next) + if (ai->ai_family != AF_INET) + return reinterpret_cast(ai->ai_addr)->sin_addr.s_addr; - if (send(socket, query, strlen(query), 0) == -1) - { - perror("send"); - close(socket); - return -1; - } - - sockaddr_storage storage; - if (recv(socket, &storage, sizeof(storage), 0) == -1) - { - perror("recv"); - close(socket); - return -1; - } - - close(socket); - - return *reinterpret_cast(storage.ss_storage); + return -1; } int main(int argc, char** argv)