forked from Bananymous/banan-os
				
			resolver: use select for client communication
This commit is contained in:
		
							parent
							
								
									2ab3eb4109
								
							
						
					
					
						commit
						420a7b60ca
					
				| 
						 | 
					@ -11,6 +11,7 @@
 | 
				
			||||||
#include <netinet/in.h>
 | 
					#include <netinet/in.h>
 | 
				
			||||||
#include <stdio.h>
 | 
					#include <stdio.h>
 | 
				
			||||||
#include <stdlib.h>
 | 
					#include <stdlib.h>
 | 
				
			||||||
 | 
					#include <sys/select.h>
 | 
				
			||||||
#include <sys/socket.h>
 | 
					#include <sys/socket.h>
 | 
				
			||||||
#include <sys/stat.h>
 | 
					#include <sys/stat.h>
 | 
				
			||||||
#include <sys/un.h>
 | 
					#include <sys/un.h>
 | 
				
			||||||
| 
						 | 
					@ -53,6 +54,12 @@ struct DNSEntry
 | 
				
			||||||
	BAN::IPv4Address	address 	{ 0 };
 | 
						BAN::IPv4Address	address 	{ 0 };
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					struct DNSResponse
 | 
				
			||||||
 | 
					{
 | 
				
			||||||
 | 
						uint16_t id;
 | 
				
			||||||
 | 
						DNSEntry entry;
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
bool send_dns_query(int socket, BAN::StringView domain, uint16_t id)
 | 
					bool send_dns_query(int socket, BAN::StringView domain, uint16_t id)
 | 
				
			||||||
{
 | 
					{
 | 
				
			||||||
	static uint8_t buffer[4096];
 | 
						static uint8_t buffer[4096];
 | 
				
			||||||
| 
						 | 
					@ -91,7 +98,7 @@ bool send_dns_query(int socket, BAN::StringView domain, uint16_t id)
 | 
				
			||||||
	return true;
 | 
						return true;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
BAN::Optional<DNSEntry> read_dns_response(int socket, uint16_t id)
 | 
					BAN::Optional<DNSResponse> read_dns_response(int socket)
 | 
				
			||||||
{
 | 
					{
 | 
				
			||||||
	static uint8_t buffer[4096];
 | 
						static uint8_t buffer[4096];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -103,11 +110,6 @@ BAN::Optional<DNSEntry> read_dns_response(int socket, uint16_t id)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	DNSPacket& reply = *reinterpret_cast<DNSPacket*>(buffer);
 | 
						DNSPacket& reply = *reinterpret_cast<DNSPacket*>(buffer);
 | 
				
			||||||
	if (reply.identification != id)
 | 
					 | 
				
			||||||
	{
 | 
					 | 
				
			||||||
		dprintln("Reply to invalid packet");
 | 
					 | 
				
			||||||
		return {};
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if (reply.flags & 0x0F)
 | 
						if (reply.flags & 0x0F)
 | 
				
			||||||
	{
 | 
						{
 | 
				
			||||||
		dprintln("DNS error (rcode {})", (unsigned)(reply.flags & 0xF));
 | 
							dprintln("DNS error (rcode {})", (unsigned)(reply.flags & 0xF));
 | 
				
			||||||
| 
						 | 
					@ -134,9 +136,10 @@ BAN::Optional<DNSEntry> read_dns_response(int socket, uint16_t id)
 | 
				
			||||||
		return {};
 | 
							return {};
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	DNSEntry result;
 | 
						DNSResponse result;
 | 
				
			||||||
	result.valid_until	= time(nullptr) + answer.ttl();
 | 
						result.id = reply.identification;
 | 
				
			||||||
	result.address		= BAN::IPv4Address(*reinterpret_cast<uint32_t*>(answer.data));
 | 
						result.entry.valid_until = time(nullptr) + answer.ttl();
 | 
				
			||||||
 | 
						result.entry.address = BAN::IPv4Address(*reinterpret_cast<uint32_t*>(answer.data));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return result;
 | 
						return result;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -207,55 +210,125 @@ int main(int, char**)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	BAN::HashMap<BAN::String, DNSEntry> dns_cache;
 | 
						BAN::HashMap<BAN::String, DNSEntry> dns_cache;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						struct Client
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							const int	socket;
 | 
				
			||||||
 | 
							bool		close { false };
 | 
				
			||||||
 | 
							uint16_t	query_id { 0 };
 | 
				
			||||||
 | 
							BAN::String	query;
 | 
				
			||||||
 | 
						};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						BAN::LinkedList<Client> clients;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for (;;)
 | 
						for (;;)
 | 
				
			||||||
	{
 | 
						{
 | 
				
			||||||
		int client = accept(service_socket, nullptr, nullptr);
 | 
							int max_sock = BAN::Math::max(service_socket, dns_socket);
 | 
				
			||||||
		if (client == -1)
 | 
					
 | 
				
			||||||
 | 
							fd_set fds;
 | 
				
			||||||
 | 
							FD_ZERO(&fds);
 | 
				
			||||||
 | 
							FD_SET(service_socket, &fds);
 | 
				
			||||||
 | 
							FD_SET(dns_socket, &fds);
 | 
				
			||||||
 | 
							for (auto& client : clients)
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
			dprintln("accept: {}", strerror(errno));
 | 
								FD_SET(client.socket, &fds);
 | 
				
			||||||
 | 
								max_sock = BAN::Math::max(max_sock, client.socket);
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							int nselect = select(max_sock + 1, &fds, nullptr, nullptr, nullptr);
 | 
				
			||||||
 | 
							if (nselect == -1)
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								perror("select");
 | 
				
			||||||
			continue;
 | 
								continue;
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		auto query = read_service_query(client);
 | 
							if (FD_ISSET(service_socket, &fds))
 | 
				
			||||||
		if (!query.has_value())
 | 
					 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
			close(client);
 | 
								int client = accept(service_socket, nullptr, nullptr);
 | 
				
			||||||
			continue;
 | 
								if (client == -1)
 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		BAN::Optional<DNSEntry> result;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		if (dns_cache.contains(*query))
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			auto& cached = dns_cache[*query];
 | 
					 | 
				
			||||||
			if (time(nullptr) <= cached.valid_until)
 | 
					 | 
				
			||||||
				result = cached;
 | 
					 | 
				
			||||||
			else
 | 
					 | 
				
			||||||
				dns_cache.remove(*query);
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		if (!result.has_value())
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			uint16_t id = rand() % 0xFFFF;
 | 
					 | 
				
			||||||
			if (send_dns_query(dns_socket, *query, id))
 | 
					 | 
				
			||||||
			{
 | 
								{
 | 
				
			||||||
				result = read_dns_response(dns_socket, id);
 | 
									perror("accept");
 | 
				
			||||||
				if (result.has_value())
 | 
									continue;
 | 
				
			||||||
					(void)dns_cache.insert(*query, *result);
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								MUST(clients.emplace_back(client));
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if (FD_ISSET(dns_socket, &fds))
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								auto result = read_dns_response(dns_socket);
 | 
				
			||||||
 | 
								if (!result.has_value())
 | 
				
			||||||
 | 
									continue;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								for (auto& client : clients)
 | 
				
			||||||
 | 
								{
 | 
				
			||||||
 | 
									if (client.query_id != result->id)
 | 
				
			||||||
 | 
										continue;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									(void)dns_cache.insert(client.query, result->entry);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									sockaddr_storage storage;
 | 
				
			||||||
 | 
									storage.ss_family = AF_INET;
 | 
				
			||||||
 | 
									memcpy(storage.ss_storage, &result->entry.address.raw, sizeof(result->entry.address.raw));
 | 
				
			||||||
 | 
									if (send(client.socket, &storage, sizeof(storage), 0) == -1)
 | 
				
			||||||
 | 
										dprintln("send: {}", strerror(errno));
 | 
				
			||||||
 | 
									client.close = true;
 | 
				
			||||||
 | 
									break;
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if (!result.has_value())
 | 
							for (auto& client : clients)
 | 
				
			||||||
			result = DNSEntry { .valid_until = 0, .address = BAN::IPv4Address(INADDR_ANY) };
 | 
							{
 | 
				
			||||||
 | 
								if (!FD_ISSET(client.socket, &fds))
 | 
				
			||||||
 | 
									continue;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		sockaddr_storage storage;
 | 
								if (!client.query.empty())
 | 
				
			||||||
		storage.ss_family = AF_INET;
 | 
								{
 | 
				
			||||||
		memcpy(storage.ss_storage, &result->address.raw, sizeof(result->address.raw));
 | 
									dprintln("Client already has a query");
 | 
				
			||||||
 | 
									continue;
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if (send(client, &storage, sizeof(storage), 0) == -1)
 | 
								auto query = read_service_query(client.socket);
 | 
				
			||||||
			dprintln("send: {}", strerror(errno));
 | 
								if (!query.has_value())
 | 
				
			||||||
 | 
									continue;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		close(client);
 | 
								BAN::Optional<DNSEntry> result;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								if (dns_cache.contains(*query))
 | 
				
			||||||
 | 
								{
 | 
				
			||||||
 | 
									auto& cached = dns_cache[*query];
 | 
				
			||||||
 | 
									if (time(nullptr) <= cached.valid_until)
 | 
				
			||||||
 | 
										result = cached;
 | 
				
			||||||
 | 
									else
 | 
				
			||||||
 | 
										dns_cache.remove(*query);
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								if (result.has_value())
 | 
				
			||||||
 | 
								{
 | 
				
			||||||
 | 
									sockaddr_storage storage;
 | 
				
			||||||
 | 
									storage.ss_family = AF_INET;
 | 
				
			||||||
 | 
									memcpy(storage.ss_storage, &result->address.raw, sizeof(result->address.raw));
 | 
				
			||||||
 | 
									if (send(client.socket, &storage, sizeof(storage), 0) == -1)
 | 
				
			||||||
 | 
										dprintln("send: {}", strerror(errno));
 | 
				
			||||||
 | 
									client.close = true;
 | 
				
			||||||
 | 
									continue;
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								client.query = query.release_value();
 | 
				
			||||||
 | 
								client.query_id = rand() % 0xFFFF;
 | 
				
			||||||
 | 
								send_dns_query(dns_socket, client.query, client.query_id);
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							for (auto it = clients.begin(); it != clients.end();)
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								if (!it->close)
 | 
				
			||||||
 | 
								{
 | 
				
			||||||
 | 
									it++;
 | 
				
			||||||
 | 
									continue;
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								close(it->socket);
 | 
				
			||||||
 | 
								it = clients.remove(it);
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return 0;
 | 
						return 0;
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue