forked from Bananymous/banan-os
				
			Kernel/LibC: Implement `socketpair` for UNIX sockets
This commit is contained in:
		
							parent
							
								
									12b93567f7
								
							
						
					
					
						commit
						89c9bfd052
					
				|  | @ -17,6 +17,7 @@ namespace Kernel | ||||||
| 
 | 
 | ||||||
| 	public: | 	public: | ||||||
| 		static BAN::ErrorOr<BAN::RefPtr<UnixDomainSocket>> create(Socket::Type, const Socket::Info&); | 		static BAN::ErrorOr<BAN::RefPtr<UnixDomainSocket>> create(Socket::Type, const Socket::Info&); | ||||||
|  | 		BAN::ErrorOr<void> make_socket_pair(UnixDomainSocket&); | ||||||
| 
 | 
 | ||||||
| 	protected: | 	protected: | ||||||
| 		virtual BAN::ErrorOr<long> accept_impl(sockaddr*, socklen_t*, int) override; | 		virtual BAN::ErrorOr<long> accept_impl(sockaddr*, socklen_t*, int) override; | ||||||
|  |  | ||||||
|  | @ -26,6 +26,7 @@ namespace Kernel | ||||||
| 		BAN::ErrorOr<int> open(BAN::StringView absolute_path, int flags); | 		BAN::ErrorOr<int> open(BAN::StringView absolute_path, int flags); | ||||||
| 
 | 
 | ||||||
| 		BAN::ErrorOr<int> socket(int domain, int type, int protocol); | 		BAN::ErrorOr<int> socket(int domain, int type, int protocol); | ||||||
|  | 		BAN::ErrorOr<void> socketpair(int domain, int type, int protocol, int socket_vector[2]); | ||||||
| 
 | 
 | ||||||
| 		BAN::ErrorOr<void> pipe(int fds[2]); | 		BAN::ErrorOr<void> pipe(int fds[2]); | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -118,6 +118,7 @@ namespace Kernel | ||||||
| 		BAN::ErrorOr<long> sys_fchownat(int fd, const char* path, uid_t uid, gid_t gid, int flag); | 		BAN::ErrorOr<long> sys_fchownat(int fd, const char* path, uid_t uid, gid_t gid, int flag); | ||||||
| 
 | 
 | ||||||
| 		BAN::ErrorOr<long> sys_socket(int domain, int type, int protocol); | 		BAN::ErrorOr<long> sys_socket(int domain, int type, int protocol); | ||||||
|  | 		BAN::ErrorOr<long> sys_socketpair(int domain, int type, int protocol, int socket_vector[2]); | ||||||
| 		BAN::ErrorOr<long> sys_getsockname(int socket, sockaddr* address, socklen_t* address_len); | 		BAN::ErrorOr<long> sys_getsockname(int socket, sockaddr* address, socklen_t* address_len); | ||||||
| 		BAN::ErrorOr<long> sys_getpeername(int socket, sockaddr* address, socklen_t* address_len); | 		BAN::ErrorOr<long> sys_getpeername(int socket, sockaddr* address, socklen_t* address_len); | ||||||
| 		BAN::ErrorOr<long> sys_getsockopt(int socket, int level, int option_name, void* option_value, socklen_t* option_len); | 		BAN::ErrorOr<long> sys_getsockopt(int socket, int level, int option_name, void* option_value, socklen_t* option_len); | ||||||
|  |  | ||||||
|  | @ -129,6 +129,25 @@ namespace Kernel | ||||||
| 		return socket; | 		return socket; | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	BAN::ErrorOr<void> NetworkManager::connect_sockets(Socket::Domain domain, BAN::RefPtr<Socket> socket1, BAN::RefPtr<Socket> socket2) | ||||||
|  | 	{ | ||||||
|  | 		switch (domain) | ||||||
|  | 		{ | ||||||
|  | 			case Socket::Domain::UNIX: | ||||||
|  | 			{ | ||||||
|  | 				auto* usocket1 = static_cast<UnixDomainSocket*>(socket1.ptr()); | ||||||
|  | 				auto* usocket2 = static_cast<UnixDomainSocket*>(socket2.ptr()); | ||||||
|  | 				TRY(usocket1->make_socket_pair(*usocket2)); | ||||||
|  | 				break; | ||||||
|  | 			} | ||||||
|  | 			default: | ||||||
|  | 				dwarnln("TODO: connect {} domain sockets", static_cast<int>(domain)); | ||||||
|  | 				return BAN::Error::from_errno(ENOTSUP); | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		return {}; | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	void NetworkManager::on_receive(NetworkInterface& interface, BAN::ConstByteSpan packet) | 	void NetworkManager::on_receive(NetworkInterface& interface, BAN::ConstByteSpan packet) | ||||||
| 	{ | 	{ | ||||||
| 		if (packet.size() < sizeof(EthernetHeader)) | 		if (packet.size() < sizeof(EthernetHeader)) | ||||||
|  |  | ||||||
|  | @ -69,6 +69,20 @@ namespace Kernel | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	BAN::ErrorOr<void> UnixDomainSocket::make_socket_pair(UnixDomainSocket& other) | ||||||
|  | 	{ | ||||||
|  | 		if (!m_info.has<ConnectionInfo>() || !other.m_info.has<ConnectionInfo>()) | ||||||
|  | 			return BAN::Error::from_errno(EINVAL); | ||||||
|  | 
 | ||||||
|  | 		TRY(this->get_weak_ptr()); | ||||||
|  | 		TRY(other.get_weak_ptr()); | ||||||
|  | 
 | ||||||
|  | 		this->m_info.get<ConnectionInfo>().connection = MUST(other.get_weak_ptr()); | ||||||
|  | 		other.m_info.get<ConnectionInfo>().connection = MUST(this->get_weak_ptr()); | ||||||
|  | 
 | ||||||
|  | 		return {}; | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	BAN::ErrorOr<long> UnixDomainSocket::accept_impl(sockaddr* address, socklen_t* address_len, int flags) | 	BAN::ErrorOr<long> UnixDomainSocket::accept_impl(sockaddr* address, socklen_t* address_len, int flags) | ||||||
| 	{ | 	{ | ||||||
| 		if (!m_info.has<ConnectionInfo>()) | 		if (!m_info.has<ConnectionInfo>()) | ||||||
|  |  | ||||||
|  | @ -83,50 +83,57 @@ namespace Kernel | ||||||
| 		return open(TRY(VirtualFileSystem::get().file_from_absolute_path(m_credentials, absolute_path, flags)), flags); | 		return open(TRY(VirtualFileSystem::get().file_from_absolute_path(m_credentials, absolute_path, flags)), flags); | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	BAN::ErrorOr<int> OpenFileDescriptorSet::socket(int domain, int type, int protocol) | 	struct SocketInfo | ||||||
| 	{ | 	{ | ||||||
| 		bool valid_protocol = true; | 		Socket::Domain domain; | ||||||
|  | 		Socket::Type type; | ||||||
|  | 		int status_flags; | ||||||
|  | 		int descriptor_flags; | ||||||
|  | 	}; | ||||||
| 
 | 
 | ||||||
| 		Socket::Domain sock_domain; | 	static BAN::ErrorOr<SocketInfo> parse_socket_info(int domain, int type, int protocol) | ||||||
|  | 	{ | ||||||
|  | 		SocketInfo info; | ||||||
|  | 
 | ||||||
|  | 		bool valid_protocol = true; | ||||||
| 		switch (domain) | 		switch (domain) | ||||||
| 		{ | 		{ | ||||||
| 			case AF_INET: | 			case AF_INET: | ||||||
| 				sock_domain = Socket::Domain::INET; | 				info.domain = Socket::Domain::INET; | ||||||
| 				break; | 				break; | ||||||
| 			case AF_INET6: | 			case AF_INET6: | ||||||
| 				sock_domain = Socket::Domain::INET6; | 				info.domain = Socket::Domain::INET6; | ||||||
| 				break; | 				break; | ||||||
| 			case AF_UNIX: | 			case AF_UNIX: | ||||||
| 				sock_domain = Socket::Domain::UNIX; | 				info.domain = Socket::Domain::UNIX; | ||||||
| 				valid_protocol = false; | 				valid_protocol = false; | ||||||
| 				break; | 				break; | ||||||
| 			default: | 			default: | ||||||
| 				return BAN::Error::from_errno(EPROTOTYPE); | 				return BAN::Error::from_errno(EPROTOTYPE); | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		int status_flags = 0; | 		info.status_flags = 0; | ||||||
| 		int descriptor_flags = 0; | 		info.descriptor_flags = 0; | ||||||
| 		if (type & SOCK_NONBLOCK) | 		if (type & SOCK_NONBLOCK) | ||||||
| 			status_flags |= O_NONBLOCK; | 			info.status_flags |= O_NONBLOCK; | ||||||
| 		if (type & SOCK_CLOEXEC) | 		if (type & SOCK_CLOEXEC) | ||||||
| 			descriptor_flags |= O_CLOEXEC; | 			info.descriptor_flags |= O_CLOEXEC; | ||||||
| 		type &= ~(SOCK_NONBLOCK | SOCK_CLOEXEC); | 		type &= ~(SOCK_NONBLOCK | SOCK_CLOEXEC); | ||||||
| 
 | 
 | ||||||
| 		Socket::Type sock_type; |  | ||||||
| 		switch (type) | 		switch (type) | ||||||
| 		{ | 		{ | ||||||
| 			case SOCK_STREAM: | 			case SOCK_STREAM: | ||||||
| 				sock_type = Socket::Type::STREAM; | 				info.type = Socket::Type::STREAM; | ||||||
| 				if (protocol != IPPROTO_TCP) | 				if (protocol != IPPROTO_TCP) | ||||||
| 					valid_protocol = false; | 					valid_protocol = false; | ||||||
| 				break; | 				break; | ||||||
| 			case SOCK_DGRAM: | 			case SOCK_DGRAM: | ||||||
| 				sock_type = Socket::Type::DGRAM; | 				info.type = Socket::Type::DGRAM; | ||||||
| 				if (protocol != IPPROTO_UDP) | 				if (protocol != IPPROTO_UDP) | ||||||
| 					valid_protocol = false; | 					valid_protocol = false; | ||||||
| 				break; | 				break; | ||||||
| 			case SOCK_SEQPACKET: | 			case SOCK_SEQPACKET: | ||||||
| 				sock_type = Socket::Type::SEQPACKET; | 				info.type = Socket::Type::SEQPACKET; | ||||||
| 				valid_protocol = false; | 				valid_protocol = false; | ||||||
| 				break; | 				break; | ||||||
| 			default: | 			default: | ||||||
|  | @ -136,15 +143,39 @@ namespace Kernel | ||||||
| 		if (protocol && !valid_protocol) | 		if (protocol && !valid_protocol) | ||||||
| 			return BAN::Error::from_errno(EPROTONOSUPPORT); | 			return BAN::Error::from_errno(EPROTONOSUPPORT); | ||||||
| 
 | 
 | ||||||
| 		auto socket = TRY(NetworkManager::get().create_socket(sock_domain, sock_type, 0777, m_credentials.euid(), m_credentials.egid())); | 		return info; | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	BAN::ErrorOr<int> OpenFileDescriptorSet::socket(int domain, int type, int protocol) | ||||||
|  | 	{ | ||||||
|  | 		auto sock_info = TRY(parse_socket_info(domain, type, protocol)); | ||||||
|  | 		auto socket = TRY(NetworkManager::get().create_socket(sock_info.domain, sock_info.type, 0777, m_credentials.euid(), m_credentials.egid())); | ||||||
| 
 | 
 | ||||||
| 		LockGuard _(m_mutex); | 		LockGuard _(m_mutex); | ||||||
| 		int fd = TRY(get_free_fd()); | 		int fd = TRY(get_free_fd()); | ||||||
| 		m_open_files[fd].description = TRY(BAN::RefPtr<OpenFileDescription>::create(VirtualFileSystem::File(socket, "<socket>"_sv), 0, O_RDWR | status_flags)); | 		m_open_files[fd].description = TRY(BAN::RefPtr<OpenFileDescription>::create(VirtualFileSystem::File(socket, "<socket>"_sv), 0, O_RDWR | sock_info.status_flags)); | ||||||
| 		m_open_files[fd].descriptor_flags = descriptor_flags; | 		m_open_files[fd].descriptor_flags = sock_info.descriptor_flags; | ||||||
| 		return fd; | 		return fd; | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	BAN::ErrorOr<void> OpenFileDescriptorSet::socketpair(int domain, int type, int protocol, int socket_vector[2]) | ||||||
|  | 	{ | ||||||
|  | 		auto sock_info = TRY(parse_socket_info(domain, type, protocol)); | ||||||
|  | 
 | ||||||
|  | 		auto socket1 = TRY(NetworkManager::get().create_socket(sock_info.domain, sock_info.type, 0600, m_credentials.euid(), m_credentials.egid())); | ||||||
|  | 		auto socket2 = TRY(NetworkManager::get().create_socket(sock_info.domain, sock_info.type, 0600, m_credentials.euid(), m_credentials.egid())); | ||||||
|  | 		TRY(NetworkManager::get().connect_sockets(sock_info.domain, socket1, socket2)); | ||||||
|  | 
 | ||||||
|  | 		LockGuard _(m_mutex); | ||||||
|  | 
 | ||||||
|  | 		TRY(get_free_fd_pair(socket_vector)); | ||||||
|  | 		m_open_files[socket_vector[0]].description = TRY(BAN::RefPtr<OpenFileDescription>::create(VirtualFileSystem::File(socket1, "<socketpair>"_sv), 0, O_RDWR | sock_info.status_flags)); | ||||||
|  | 		m_open_files[socket_vector[0]].descriptor_flags = sock_info.descriptor_flags; | ||||||
|  | 		m_open_files[socket_vector[1]].description = TRY(BAN::RefPtr<OpenFileDescription>::create(VirtualFileSystem::File(socket2, "<socketpair>"_sv), 0, O_RDWR | sock_info.status_flags)); | ||||||
|  | 		m_open_files[socket_vector[1]].descriptor_flags = sock_info.descriptor_flags; | ||||||
|  | 		return {}; | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	BAN::ErrorOr<void> OpenFileDescriptorSet::pipe(int fds[2]) | 	BAN::ErrorOr<void> OpenFileDescriptorSet::pipe(int fds[2]) | ||||||
| 	{ | 	{ | ||||||
| 		LockGuard _(m_mutex); | 		LockGuard _(m_mutex); | ||||||
|  |  | ||||||
|  | @ -1270,6 +1270,14 @@ namespace Kernel | ||||||
| 		return TRY(m_open_file_descriptors.socket(domain, type, protocol)); | 		return TRY(m_open_file_descriptors.socket(domain, type, protocol)); | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	BAN::ErrorOr<long> Process::sys_socketpair(int domain, int type, int protocol, int socket_vector[2]) | ||||||
|  | 	{ | ||||||
|  | 		LockGuard _(m_process_lock); | ||||||
|  | 		TRY(validate_pointer_access(socket_vector, sizeof(int) * 2, true)); | ||||||
|  | 		TRY(m_open_file_descriptors.socketpair(domain, type, protocol, socket_vector)); | ||||||
|  | 		return 0; | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	BAN::ErrorOr<long> Process::sys_getsockname(int socket, sockaddr* address, socklen_t* address_len) | 	BAN::ErrorOr<long> Process::sys_getsockname(int socket, sockaddr* address, socklen_t* address_len) | ||||||
| 	{ | 	{ | ||||||
| 		LockGuard _(m_process_lock); | 		LockGuard _(m_process_lock); | ||||||
|  |  | ||||||
|  | @ -64,6 +64,7 @@ __BEGIN_DECLS | ||||||
| 	O(SYS_FCHOWNAT,			fchownat)		\ | 	O(SYS_FCHOWNAT,			fchownat)		\ | ||||||
| 	O(SYS_LOAD_KEYMAP,		load_keymap)	\ | 	O(SYS_LOAD_KEYMAP,		load_keymap)	\ | ||||||
| 	O(SYS_SOCKET,			socket)			\ | 	O(SYS_SOCKET,			socket)			\ | ||||||
|  | 	O(SYS_SOCKETPAIR,		socketpair)		\ | ||||||
| 	O(SYS_BIND,				bind)			\ | 	O(SYS_BIND,				bind)			\ | ||||||
| 	O(SYS_SENDTO,			sendto)			\ | 	O(SYS_SENDTO,			sendto)			\ | ||||||
| 	O(SYS_RECVFROM,			recvfrom)		\ | 	O(SYS_RECVFROM,			recvfrom)		\ | ||||||
|  |  | ||||||
|  | @ -68,6 +68,11 @@ int socket(int domain, int type, int protocol) | ||||||
| 	return syscall(SYS_SOCKET, domain, type, protocol); | 	return syscall(SYS_SOCKET, domain, type, protocol); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | int socketpair(int domain, int type, int protocol, int socket_vector[2]) | ||||||
|  | { | ||||||
|  | 	return syscall(SYS_SOCKETPAIR, domain, type, protocol, socket_vector); | ||||||
|  | } | ||||||
|  | 
 | ||||||
| int getsockname(int socket, struct sockaddr* __restrict address, socklen_t* __restrict address_len) | int getsockname(int socket, struct sockaddr* __restrict address, socklen_t* __restrict address_len) | ||||||
| { | { | ||||||
| 	return syscall(SYS_GETSOCKNAME, socket, address, address_len); | 	return syscall(SYS_GETSOCKNAME, socket, address, address_len); | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue