forked from Bananymous/banan-os
Kernel/LibC: Implement `socketpair` for UNIX sockets
This commit is contained in:
parent
12b93567f7
commit
89c9bfd052
|
@ -17,6 +17,7 @@ namespace Kernel
|
||||||
|
|
||||||
public:
|
public:
|
||||||
static BAN::ErrorOr<BAN::RefPtr<UnixDomainSocket>> create(Socket::Type, const Socket::Info&);
|
static BAN::ErrorOr<BAN::RefPtr<UnixDomainSocket>> create(Socket::Type, const Socket::Info&);
|
||||||
|
BAN::ErrorOr<void> make_socket_pair(UnixDomainSocket&);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
virtual BAN::ErrorOr<long> accept_impl(sockaddr*, socklen_t*, int) override;
|
virtual BAN::ErrorOr<long> accept_impl(sockaddr*, socklen_t*, int) override;
|
||||||
|
|
|
@ -26,6 +26,7 @@ namespace Kernel
|
||||||
BAN::ErrorOr<int> open(BAN::StringView absolute_path, int flags);
|
BAN::ErrorOr<int> open(BAN::StringView absolute_path, int flags);
|
||||||
|
|
||||||
BAN::ErrorOr<int> socket(int domain, int type, int protocol);
|
BAN::ErrorOr<int> socket(int domain, int type, int protocol);
|
||||||
|
BAN::ErrorOr<void> socketpair(int domain, int type, int protocol, int socket_vector[2]);
|
||||||
|
|
||||||
BAN::ErrorOr<void> pipe(int fds[2]);
|
BAN::ErrorOr<void> pipe(int fds[2]);
|
||||||
|
|
||||||
|
|
|
@ -118,6 +118,7 @@ namespace Kernel
|
||||||
BAN::ErrorOr<long> sys_fchownat(int fd, const char* path, uid_t uid, gid_t gid, int flag);
|
BAN::ErrorOr<long> sys_fchownat(int fd, const char* path, uid_t uid, gid_t gid, int flag);
|
||||||
|
|
||||||
BAN::ErrorOr<long> sys_socket(int domain, int type, int protocol);
|
BAN::ErrorOr<long> sys_socket(int domain, int type, int protocol);
|
||||||
|
BAN::ErrorOr<long> sys_socketpair(int domain, int type, int protocol, int socket_vector[2]);
|
||||||
BAN::ErrorOr<long> sys_getsockname(int socket, sockaddr* address, socklen_t* address_len);
|
BAN::ErrorOr<long> sys_getsockname(int socket, sockaddr* address, socklen_t* address_len);
|
||||||
BAN::ErrorOr<long> sys_getpeername(int socket, sockaddr* address, socklen_t* address_len);
|
BAN::ErrorOr<long> sys_getpeername(int socket, sockaddr* address, socklen_t* address_len);
|
||||||
BAN::ErrorOr<long> sys_getsockopt(int socket, int level, int option_name, void* option_value, socklen_t* option_len);
|
BAN::ErrorOr<long> sys_getsockopt(int socket, int level, int option_name, void* option_value, socklen_t* option_len);
|
||||||
|
|
|
@ -129,6 +129,25 @@ namespace Kernel
|
||||||
return socket;
|
return socket;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
BAN::ErrorOr<void> NetworkManager::connect_sockets(Socket::Domain domain, BAN::RefPtr<Socket> socket1, BAN::RefPtr<Socket> socket2)
|
||||||
|
{
|
||||||
|
switch (domain)
|
||||||
|
{
|
||||||
|
case Socket::Domain::UNIX:
|
||||||
|
{
|
||||||
|
auto* usocket1 = static_cast<UnixDomainSocket*>(socket1.ptr());
|
||||||
|
auto* usocket2 = static_cast<UnixDomainSocket*>(socket2.ptr());
|
||||||
|
TRY(usocket1->make_socket_pair(*usocket2));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
dwarnln("TODO: connect {} domain sockets", static_cast<int>(domain));
|
||||||
|
return BAN::Error::from_errno(ENOTSUP);
|
||||||
|
}
|
||||||
|
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
void NetworkManager::on_receive(NetworkInterface& interface, BAN::ConstByteSpan packet)
|
void NetworkManager::on_receive(NetworkInterface& interface, BAN::ConstByteSpan packet)
|
||||||
{
|
{
|
||||||
if (packet.size() < sizeof(EthernetHeader))
|
if (packet.size() < sizeof(EthernetHeader))
|
||||||
|
|
|
@ -69,6 +69,20 @@ namespace Kernel
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
BAN::ErrorOr<void> UnixDomainSocket::make_socket_pair(UnixDomainSocket& other)
|
||||||
|
{
|
||||||
|
if (!m_info.has<ConnectionInfo>() || !other.m_info.has<ConnectionInfo>())
|
||||||
|
return BAN::Error::from_errno(EINVAL);
|
||||||
|
|
||||||
|
TRY(this->get_weak_ptr());
|
||||||
|
TRY(other.get_weak_ptr());
|
||||||
|
|
||||||
|
this->m_info.get<ConnectionInfo>().connection = MUST(other.get_weak_ptr());
|
||||||
|
other.m_info.get<ConnectionInfo>().connection = MUST(this->get_weak_ptr());
|
||||||
|
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
BAN::ErrorOr<long> UnixDomainSocket::accept_impl(sockaddr* address, socklen_t* address_len, int flags)
|
BAN::ErrorOr<long> UnixDomainSocket::accept_impl(sockaddr* address, socklen_t* address_len, int flags)
|
||||||
{
|
{
|
||||||
if (!m_info.has<ConnectionInfo>())
|
if (!m_info.has<ConnectionInfo>())
|
||||||
|
|
|
@ -83,50 +83,57 @@ namespace Kernel
|
||||||
return open(TRY(VirtualFileSystem::get().file_from_absolute_path(m_credentials, absolute_path, flags)), flags);
|
return open(TRY(VirtualFileSystem::get().file_from_absolute_path(m_credentials, absolute_path, flags)), flags);
|
||||||
}
|
}
|
||||||
|
|
||||||
BAN::ErrorOr<int> OpenFileDescriptorSet::socket(int domain, int type, int protocol)
|
struct SocketInfo
|
||||||
{
|
{
|
||||||
bool valid_protocol = true;
|
Socket::Domain domain;
|
||||||
|
Socket::Type type;
|
||||||
|
int status_flags;
|
||||||
|
int descriptor_flags;
|
||||||
|
};
|
||||||
|
|
||||||
Socket::Domain sock_domain;
|
static BAN::ErrorOr<SocketInfo> parse_socket_info(int domain, int type, int protocol)
|
||||||
|
{
|
||||||
|
SocketInfo info;
|
||||||
|
|
||||||
|
bool valid_protocol = true;
|
||||||
switch (domain)
|
switch (domain)
|
||||||
{
|
{
|
||||||
case AF_INET:
|
case AF_INET:
|
||||||
sock_domain = Socket::Domain::INET;
|
info.domain = Socket::Domain::INET;
|
||||||
break;
|
break;
|
||||||
case AF_INET6:
|
case AF_INET6:
|
||||||
sock_domain = Socket::Domain::INET6;
|
info.domain = Socket::Domain::INET6;
|
||||||
break;
|
break;
|
||||||
case AF_UNIX:
|
case AF_UNIX:
|
||||||
sock_domain = Socket::Domain::UNIX;
|
info.domain = Socket::Domain::UNIX;
|
||||||
valid_protocol = false;
|
valid_protocol = false;
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
return BAN::Error::from_errno(EPROTOTYPE);
|
return BAN::Error::from_errno(EPROTOTYPE);
|
||||||
}
|
}
|
||||||
|
|
||||||
int status_flags = 0;
|
info.status_flags = 0;
|
||||||
int descriptor_flags = 0;
|
info.descriptor_flags = 0;
|
||||||
if (type & SOCK_NONBLOCK)
|
if (type & SOCK_NONBLOCK)
|
||||||
status_flags |= O_NONBLOCK;
|
info.status_flags |= O_NONBLOCK;
|
||||||
if (type & SOCK_CLOEXEC)
|
if (type & SOCK_CLOEXEC)
|
||||||
descriptor_flags |= O_CLOEXEC;
|
info.descriptor_flags |= O_CLOEXEC;
|
||||||
type &= ~(SOCK_NONBLOCK | SOCK_CLOEXEC);
|
type &= ~(SOCK_NONBLOCK | SOCK_CLOEXEC);
|
||||||
|
|
||||||
Socket::Type sock_type;
|
|
||||||
switch (type)
|
switch (type)
|
||||||
{
|
{
|
||||||
case SOCK_STREAM:
|
case SOCK_STREAM:
|
||||||
sock_type = Socket::Type::STREAM;
|
info.type = Socket::Type::STREAM;
|
||||||
if (protocol != IPPROTO_TCP)
|
if (protocol != IPPROTO_TCP)
|
||||||
valid_protocol = false;
|
valid_protocol = false;
|
||||||
break;
|
break;
|
||||||
case SOCK_DGRAM:
|
case SOCK_DGRAM:
|
||||||
sock_type = Socket::Type::DGRAM;
|
info.type = Socket::Type::DGRAM;
|
||||||
if (protocol != IPPROTO_UDP)
|
if (protocol != IPPROTO_UDP)
|
||||||
valid_protocol = false;
|
valid_protocol = false;
|
||||||
break;
|
break;
|
||||||
case SOCK_SEQPACKET:
|
case SOCK_SEQPACKET:
|
||||||
sock_type = Socket::Type::SEQPACKET;
|
info.type = Socket::Type::SEQPACKET;
|
||||||
valid_protocol = false;
|
valid_protocol = false;
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
|
@ -136,15 +143,39 @@ namespace Kernel
|
||||||
if (protocol && !valid_protocol)
|
if (protocol && !valid_protocol)
|
||||||
return BAN::Error::from_errno(EPROTONOSUPPORT);
|
return BAN::Error::from_errno(EPROTONOSUPPORT);
|
||||||
|
|
||||||
auto socket = TRY(NetworkManager::get().create_socket(sock_domain, sock_type, 0777, m_credentials.euid(), m_credentials.egid()));
|
return info;
|
||||||
|
}
|
||||||
|
|
||||||
|
BAN::ErrorOr<int> OpenFileDescriptorSet::socket(int domain, int type, int protocol)
|
||||||
|
{
|
||||||
|
auto sock_info = TRY(parse_socket_info(domain, type, protocol));
|
||||||
|
auto socket = TRY(NetworkManager::get().create_socket(sock_info.domain, sock_info.type, 0777, m_credentials.euid(), m_credentials.egid()));
|
||||||
|
|
||||||
LockGuard _(m_mutex);
|
LockGuard _(m_mutex);
|
||||||
int fd = TRY(get_free_fd());
|
int fd = TRY(get_free_fd());
|
||||||
m_open_files[fd].description = TRY(BAN::RefPtr<OpenFileDescription>::create(VirtualFileSystem::File(socket, "<socket>"_sv), 0, O_RDWR | status_flags));
|
m_open_files[fd].description = TRY(BAN::RefPtr<OpenFileDescription>::create(VirtualFileSystem::File(socket, "<socket>"_sv), 0, O_RDWR | sock_info.status_flags));
|
||||||
m_open_files[fd].descriptor_flags = descriptor_flags;
|
m_open_files[fd].descriptor_flags = sock_info.descriptor_flags;
|
||||||
return fd;
|
return fd;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
BAN::ErrorOr<void> OpenFileDescriptorSet::socketpair(int domain, int type, int protocol, int socket_vector[2])
|
||||||
|
{
|
||||||
|
auto sock_info = TRY(parse_socket_info(domain, type, protocol));
|
||||||
|
|
||||||
|
auto socket1 = TRY(NetworkManager::get().create_socket(sock_info.domain, sock_info.type, 0600, m_credentials.euid(), m_credentials.egid()));
|
||||||
|
auto socket2 = TRY(NetworkManager::get().create_socket(sock_info.domain, sock_info.type, 0600, m_credentials.euid(), m_credentials.egid()));
|
||||||
|
TRY(NetworkManager::get().connect_sockets(sock_info.domain, socket1, socket2));
|
||||||
|
|
||||||
|
LockGuard _(m_mutex);
|
||||||
|
|
||||||
|
TRY(get_free_fd_pair(socket_vector));
|
||||||
|
m_open_files[socket_vector[0]].description = TRY(BAN::RefPtr<OpenFileDescription>::create(VirtualFileSystem::File(socket1, "<socketpair>"_sv), 0, O_RDWR | sock_info.status_flags));
|
||||||
|
m_open_files[socket_vector[0]].descriptor_flags = sock_info.descriptor_flags;
|
||||||
|
m_open_files[socket_vector[1]].description = TRY(BAN::RefPtr<OpenFileDescription>::create(VirtualFileSystem::File(socket2, "<socketpair>"_sv), 0, O_RDWR | sock_info.status_flags));
|
||||||
|
m_open_files[socket_vector[1]].descriptor_flags = sock_info.descriptor_flags;
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
BAN::ErrorOr<void> OpenFileDescriptorSet::pipe(int fds[2])
|
BAN::ErrorOr<void> OpenFileDescriptorSet::pipe(int fds[2])
|
||||||
{
|
{
|
||||||
LockGuard _(m_mutex);
|
LockGuard _(m_mutex);
|
||||||
|
|
|
@ -1270,6 +1270,14 @@ namespace Kernel
|
||||||
return TRY(m_open_file_descriptors.socket(domain, type, protocol));
|
return TRY(m_open_file_descriptors.socket(domain, type, protocol));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
BAN::ErrorOr<long> Process::sys_socketpair(int domain, int type, int protocol, int socket_vector[2])
|
||||||
|
{
|
||||||
|
LockGuard _(m_process_lock);
|
||||||
|
TRY(validate_pointer_access(socket_vector, sizeof(int) * 2, true));
|
||||||
|
TRY(m_open_file_descriptors.socketpair(domain, type, protocol, socket_vector));
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
BAN::ErrorOr<long> Process::sys_getsockname(int socket, sockaddr* address, socklen_t* address_len)
|
BAN::ErrorOr<long> Process::sys_getsockname(int socket, sockaddr* address, socklen_t* address_len)
|
||||||
{
|
{
|
||||||
LockGuard _(m_process_lock);
|
LockGuard _(m_process_lock);
|
||||||
|
|
|
@ -64,6 +64,7 @@ __BEGIN_DECLS
|
||||||
O(SYS_FCHOWNAT, fchownat) \
|
O(SYS_FCHOWNAT, fchownat) \
|
||||||
O(SYS_LOAD_KEYMAP, load_keymap) \
|
O(SYS_LOAD_KEYMAP, load_keymap) \
|
||||||
O(SYS_SOCKET, socket) \
|
O(SYS_SOCKET, socket) \
|
||||||
|
O(SYS_SOCKETPAIR, socketpair) \
|
||||||
O(SYS_BIND, bind) \
|
O(SYS_BIND, bind) \
|
||||||
O(SYS_SENDTO, sendto) \
|
O(SYS_SENDTO, sendto) \
|
||||||
O(SYS_RECVFROM, recvfrom) \
|
O(SYS_RECVFROM, recvfrom) \
|
||||||
|
|
|
@ -68,6 +68,11 @@ int socket(int domain, int type, int protocol)
|
||||||
return syscall(SYS_SOCKET, domain, type, protocol);
|
return syscall(SYS_SOCKET, domain, type, protocol);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int socketpair(int domain, int type, int protocol, int socket_vector[2])
|
||||||
|
{
|
||||||
|
return syscall(SYS_SOCKETPAIR, domain, type, protocol, socket_vector);
|
||||||
|
}
|
||||||
|
|
||||||
int getsockname(int socket, struct sockaddr* __restrict address, socklen_t* __restrict address_len)
|
int getsockname(int socket, struct sockaddr* __restrict address, socklen_t* __restrict address_len)
|
||||||
{
|
{
|
||||||
return syscall(SYS_GETSOCKNAME, socket, address, address_len);
|
return syscall(SYS_GETSOCKNAME, socket, address, address_len);
|
||||||
|
|
Loading…
Reference in New Issue