diff --git a/userspace/libraries/LibC/sys/socket.cpp b/userspace/libraries/LibC/sys/socket.cpp index d6a5cd59..251fb57b 100644 --- a/userspace/libraries/LibC/sys/socket.cpp +++ b/userspace/libraries/LibC/sys/socket.cpp @@ -1,3 +1,6 @@ +#include + +#include #include #include #include @@ -33,7 +36,7 @@ int listen(int socket, int backlog) ssize_t recv(int socket, void* __restrict buffer, size_t length, int flags) { - pthread_testcancel(); + // cancellation point in recvfrom return recvfrom(socket, buffer, length, flags, nullptr, nullptr); } @@ -53,7 +56,7 @@ ssize_t recvfrom(int socket, void* __restrict buffer, size_t length, int flags, ssize_t send(int socket, const void* message, size_t length, int flags) { - pthread_testcancel(); + // cancellation point in sendto return sendto(socket, message, length, flags, nullptr, 0); } @@ -71,6 +74,74 @@ ssize_t sendto(int socket, const void* message, size_t length, int flags, const return syscall(SYS_SENDTO, &arguments); } +ssize_t recvmsg(int socket, struct msghdr* message, int flags) +{ + if (CMSG_FIRSTHDR(message)) + { + dwarnln("TODO: recvmsg ancillary data"); + errno = ENOTSUP; + return -1; + } + + size_t total_recv = 0; + + for (int i = 0; i < message->msg_iovlen; i++) + { + const ssize_t nrecv = recvfrom( + socket, + message->msg_iov[i].iov_base, + message->msg_iov[i].iov_len, + flags, + static_cast(message->msg_name), + &message->msg_namelen + ); + + if (nrecv < 0) + return -1; + + total_recv += nrecv; + + if (static_cast(nrecv) < message->msg_iov[i].iov_len) + break; + } + + return total_recv; +} + +ssize_t sendmsg(int socket, const struct msghdr* message, int flags) +{ + if (CMSG_FIRSTHDR(message)) + { + dwarnln("TODO: sendmsg ancillary data"); + errno = ENOTSUP; + return -1; + } + + size_t total_sent = 0; + + for (int i = 0; i < message->msg_iovlen; i++) + { + const ssize_t nsend = sendto( + socket, + message->msg_iov[i].iov_base, + message->msg_iov[i].iov_len, + flags, + static_cast(message->msg_name), + message->msg_namelen + ); + + if (nsend < 0) + return -1; + + total_sent += nsend; + + if (static_cast(nsend) < message->msg_iov[i].iov_len) + break; + } + + return total_sent; +} + int socket(int domain, int type, int protocol) { return syscall(SYS_SOCKET, domain, type, protocol);