forked from Bananymous/banan-os
				
			Userspace: Add DNS cache to resolver
Also the format of resolver reply is now just sockaddr_storage with family set and address in the storage field.
This commit is contained in:
		
							parent
							
								
									6fb69a1dc2
								
							
						
					
					
						commit
						065ee9004c
					
				| 
						 | 
					@ -1,7 +1,12 @@
 | 
				
			||||||
 | 
					#include <arpa/inet.h>
 | 
				
			||||||
 | 
					#include <netinet/in.h>
 | 
				
			||||||
#include <stdio.h>
 | 
					#include <stdio.h>
 | 
				
			||||||
#include <string.h>
 | 
					#include <string.h>
 | 
				
			||||||
#include <sys/socket.h>
 | 
					#include <sys/socket.h>
 | 
				
			||||||
#include <sys/un.h>
 | 
					#include <sys/un.h>
 | 
				
			||||||
 | 
					#include <unistd.h>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#define MAX(a, b) ((a) < (b) ? (b) : (a))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
int main(int argc, char** argv)
 | 
					int main(int argc, char** argv)
 | 
				
			||||||
{
 | 
					{
 | 
				
			||||||
| 
						 | 
					@ -33,15 +38,17 @@ int main(int argc, char** argv)
 | 
				
			||||||
		return 1;
 | 
							return 1;
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	char buffer[128];
 | 
						sockaddr_storage storage;
 | 
				
			||||||
	ssize_t nrecv = recv(socket, buffer, sizeof(buffer), 0);
 | 
						if (recv(socket, &storage, sizeof(storage), 0) == -1)
 | 
				
			||||||
	if (nrecv == -1)
 | 
					 | 
				
			||||||
	{
 | 
						{
 | 
				
			||||||
		perror("recv");
 | 
							perror("recv");
 | 
				
			||||||
		return 1;
 | 
							return 1;
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	buffer[nrecv] = '\0';
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	printf("%s\n", buffer);
 | 
						close(socket);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						char buffer[MAX(INET_ADDRSTRLEN, INET6_ADDRSTRLEN)];
 | 
				
			||||||
 | 
						printf("%s\n", inet_ntop(storage.ss_family, storage.ss_storage, buffer, sizeof(buffer)));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return 0;
 | 
						return 0;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -39,6 +39,19 @@ struct DNSAnswer
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
static_assert(sizeof(DNSAnswer) == 12);
 | 
					static_assert(sizeof(DNSAnswer) == 12);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					enum QTYPE : uint16_t
 | 
				
			||||||
 | 
					{
 | 
				
			||||||
 | 
						A		= 0x0001,
 | 
				
			||||||
 | 
						CNAME	= 0x0005,
 | 
				
			||||||
 | 
						AAAA	= 0x001C,
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					struct DNSEntry
 | 
				
			||||||
 | 
					{
 | 
				
			||||||
 | 
						time_t				valid_until	{ 0 };
 | 
				
			||||||
 | 
						BAN::IPv4Address	address 	{ 0 };
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
bool send_dns_query(int socket, BAN::StringView domain, uint16_t id)
 | 
					bool send_dns_query(int socket, BAN::StringView domain, uint16_t id)
 | 
				
			||||||
{
 | 
					{
 | 
				
			||||||
	static uint8_t buffer[4096];
 | 
						static uint8_t buffer[4096];
 | 
				
			||||||
| 
						 | 
					@ -61,8 +74,8 @@ bool send_dns_query(int socket, BAN::StringView domain, uint16_t id)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	request.data[idx++] = 0x00;
 | 
						request.data[idx++] = 0x00;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	*(uint16_t*)&request.data[idx] = htons(0x01); idx += 2;
 | 
						*(uint16_t*)&request.data[idx] = htons(QTYPE::A); idx += 2;
 | 
				
			||||||
	*(uint16_t*)&request.data[idx] = htons(0x01); idx += 2;
 | 
						*(uint16_t*)&request.data[idx] = htons(0x0001); idx += 2;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	sockaddr_in nameserver;
 | 
						sockaddr_in nameserver;
 | 
				
			||||||
	nameserver.sin_family = AF_INET;
 | 
						nameserver.sin_family = AF_INET;
 | 
				
			||||||
| 
						 | 
					@ -77,7 +90,7 @@ bool send_dns_query(int socket, BAN::StringView domain, uint16_t id)
 | 
				
			||||||
	return true;
 | 
						return true;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
BAN::Optional<BAN::String> read_dns_response(int socket, uint16_t id)
 | 
					BAN::Optional<DNSEntry> read_dns_response(int socket, uint16_t id)
 | 
				
			||||||
{
 | 
					{
 | 
				
			||||||
	static uint8_t buffer[4096];
 | 
						static uint8_t buffer[4096];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -109,13 +122,22 @@ BAN::Optional<BAN::String> read_dns_response(int socket, uint16_t id)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	DNSAnswer& answer = *reinterpret_cast<DNSAnswer*>(&reply.data[idx]);
 | 
						DNSAnswer& answer = *reinterpret_cast<DNSAnswer*>(&reply.data[idx]);
 | 
				
			||||||
 | 
						if (answer.type() != QTYPE::A)
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							fprintf(stderr, "Not A record\n");
 | 
				
			||||||
 | 
							return {};
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
	if (answer.data_len() != 4)
 | 
						if (answer.data_len() != 4)
 | 
				
			||||||
	{
 | 
						{
 | 
				
			||||||
		fprintf(stderr, "Not IPv4\n");
 | 
							fprintf(stderr, "corrupted package\n");
 | 
				
			||||||
		return {};
 | 
							return {};
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return inet_ntoa({ .s_addr = *reinterpret_cast<uint32_t*>(answer.data) });
 | 
						DNSEntry result;
 | 
				
			||||||
 | 
						result.valid_until	= time(nullptr) + answer.ttl();
 | 
				
			||||||
 | 
						result.address		= BAN::IPv4Address(*reinterpret_cast<uint32_t*>(answer.data));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return result;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
int create_service_socket()
 | 
					int create_service_socket()
 | 
				
			||||||
| 
						 | 
					@ -182,6 +204,8 @@ int main(int, char**)
 | 
				
			||||||
		return 1;
 | 
							return 1;
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						BAN::HashMap<BAN::String, DNSEntry> dns_cache;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for (;;)
 | 
						for (;;)
 | 
				
			||||||
	{
 | 
						{
 | 
				
			||||||
		int client = accept(service_socket, nullptr, nullptr);
 | 
							int client = accept(service_socket, nullptr, nullptr);
 | 
				
			||||||
| 
						 | 
					@ -193,24 +217,43 @@ int main(int, char**)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		auto query = read_service_query(client);
 | 
							auto query = read_service_query(client);
 | 
				
			||||||
		if (!query.has_value())
 | 
							if (!query.has_value())
 | 
				
			||||||
			continue;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		uint16_t id = rand() % 0xFFFF;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		if (send_dns_query(dns_socket, *query, id))
 | 
					 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
			auto response = read_dns_response(dns_socket, id);
 | 
								close(client);
 | 
				
			||||||
			if (response.has_value())
 | 
								continue;
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							BAN::Optional<DNSEntry> result;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if (dns_cache.contains(*query))
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								auto& cached = dns_cache[*query];
 | 
				
			||||||
 | 
								if (time(nullptr) <= cached.valid_until)
 | 
				
			||||||
 | 
									result = cached;
 | 
				
			||||||
 | 
								else
 | 
				
			||||||
 | 
									dns_cache.remove(*query);
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if (!result.has_value())
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								uint16_t id = rand() % 0xFFFF;
 | 
				
			||||||
 | 
								if (send_dns_query(dns_socket, *query, id))
 | 
				
			||||||
			{
 | 
								{
 | 
				
			||||||
				if (send(client, response->data(), response->size() + 1, 0) == -1)
 | 
									result = read_dns_response(dns_socket, id);
 | 
				
			||||||
					perror("send");
 | 
									if (result.has_value())
 | 
				
			||||||
				close(client);
 | 
										(void)dns_cache.insert(*query, *result);
 | 
				
			||||||
				continue;
 | 
					 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		char message[] = "unavailable";
 | 
							if (!result.has_value())
 | 
				
			||||||
		send(client, message, sizeof(message), 0);
 | 
								result = DNSEntry { .valid_until = 0, .address = BAN::IPv4Address(INADDR_ANY) };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							sockaddr_storage storage;
 | 
				
			||||||
 | 
							storage.ss_family = AF_INET;
 | 
				
			||||||
 | 
							memcpy(storage.ss_storage, &result->address.raw, sizeof(result->address.raw));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if (send(client, &storage, sizeof(storage), 0) == -1)
 | 
				
			||||||
 | 
								perror("send");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		close(client);
 | 
							close(client);
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue