diff --git a/userspace/libraries/LibGUI/Window.cpp b/userspace/libraries/LibGUI/Window.cpp index b77c10ed..f633bb9a 100644 --- a/userspace/libraries/LibGUI/Window.cpp +++ b/userspace/libraries/LibGUI/Window.cpp @@ -16,47 +16,6 @@ namespace LibGUI { - struct ReceivePacket - { - PacketType type; - BAN::Vector data_with_type; - }; - - static BAN::ErrorOr recv_packet(int socket) - { - uint32_t packet_size; - - { - const ssize_t nrecv = recv(socket, &packet_size, sizeof(uint32_t), 0); - if (nrecv < 0) - return BAN::Error::from_errno(errno); - if (nrecv == 0) - return BAN::Error::from_errno(ECONNRESET); - } - - if (packet_size < sizeof(uint32_t)) - return BAN::Error::from_literal("invalid packet, does not fit packet id"); - - BAN::Vector packet_data; - TRY(packet_data.resize(packet_size)); - - size_t total_recv = 0; - while (total_recv < packet_size) - { - const ssize_t nrecv = recv(socket, packet_data.data() + total_recv, packet_size - total_recv, 0); - if (nrecv < 0) - return BAN::Error::from_errno(errno); - if (nrecv == 0) - return BAN::Error::from_errno(ECONNRESET); - total_recv += nrecv; - } - - return ReceivePacket { - *reinterpret_cast(packet_data.data()), - packet_data - }; - } - Window::~Window() { cleanup(); @@ -64,7 +23,7 @@ namespace LibGUI BAN::ErrorOr> Window::create(uint32_t width, uint32_t height, BAN::StringView title, Attributes attributes) { - int server_fd = socket(AF_UNIX, SOCK_STREAM, 0); + int server_fd = socket(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0); if (server_fd == -1) return BAN::Error::from_errno(errno); BAN::ScopeGuard server_closer([server_fd] { close(server_fd); }); @@ -74,16 +33,10 @@ namespace LibGUI return BAN::Error::from_errno(errno); BAN::ScopeGuard epoll_closer([epoll_fd] { close(epoll_fd); }); - epoll_event epoll_event { - .events = EPOLLIN, - .data = { .fd = server_fd }, - }; + epoll_event epoll_event { .events = EPOLLIN, .data = { .fd = server_fd } }; if (epoll_ctl(epoll_fd, EPOLL_CTL_ADD, server_fd, &epoll_event) == -1) return BAN::Error::from_errno(errno); - if (fcntl(server_fd, F_SETFD, fcntl(server_fd, F_GETFD) | FD_CLOEXEC) == -1) - return BAN::Error::from_errno(errno); - timespec start_time; clock_gettime(CLOCK_MONOTONIC, &start_time); @@ -107,19 +60,23 @@ namespace LibGUI nanosleep(&sleep_time, nullptr); } + auto window = TRY(BAN::UniqPtr::create(server_fd, epoll_fd, attributes)); + WindowPacket::WindowCreate create_packet; create_packet.width = width; create_packet.height = height; create_packet.attributes = attributes; TRY(create_packet.title.append(title)); - TRY(create_packet.send_serialized(server_fd)); - - auto window = TRY(BAN::UniqPtr::create(server_fd, epoll_fd, attributes)); + window->send_packet(create_packet, __FUNCTION__); bool resized = false; window->set_resize_window_event_callback([&]() { resized = true; }); while (!resized) + { + // FIXME: timeout? + window->wait_events(); window->poll_events(); + } window->set_resize_window_event_callback({}); server_closer.disable(); @@ -128,6 +85,30 @@ namespace LibGUI return window; } + template + void Window::send_packet(const T& packet, BAN::StringView function) + { + const size_t serialized_size = packet.serialized_size(); + if (serialized_size > m_out_buffer.size()) + { + dwarnln("cannot to send {} byte packet", serialized_size); + return on_socket_error(function); + } + + packet.serialize(m_in_buffer.span()); + + size_t total_sent = 0; + while (total_sent < serialized_size) + { + const ssize_t nsend = send(m_server_fd, m_in_buffer.data() + total_sent, serialized_size - total_sent, 0); + if (nsend < 0) + dwarnln("send: {}", strerror(errno)); + if (nsend <= 0) + return on_socket_error(function); + total_sent += nsend; + } + } + BAN::ErrorOr Window::set_root_widget(BAN::RefPtr widget) { TRY(widget->set_fixed_geometry({ 0, 0, m_width, m_height })); @@ -186,36 +167,28 @@ namespace LibGUI packet.y = y; packet.width = width; packet.height = height; - - if (auto ret = packet.send_serialized(m_server_fd); ret.is_error()) - return on_socket_error(__FUNCTION__); + send_packet(packet, __FUNCTION__); } void Window::set_mouse_relative(bool enabled) { WindowPacket::WindowSetMouseRelative packet; packet.enabled = enabled; - - if (auto ret = packet.send_serialized(m_server_fd); ret.is_error()) - return on_socket_error(__FUNCTION__); + send_packet(packet, __FUNCTION__); } void Window::set_fullscreen(bool fullscreen) { WindowPacket::WindowSetFullscreen packet; packet.fullscreen = fullscreen; - - if (auto ret = packet.send_serialized(m_server_fd); ret.is_error()) - return on_socket_error(__FUNCTION__); + send_packet(packet, __FUNCTION__); } void Window::set_title(BAN::StringView title) { WindowPacket::WindowSetTitle packet; MUST(packet.title.append(title)); - - if (auto ret = packet.send_serialized(m_server_fd); ret.is_error()) - return on_socket_error(__FUNCTION__); + send_packet(packet, __FUNCTION__); } void Window::set_position(int32_t x, int32_t y) @@ -223,9 +196,7 @@ namespace LibGUI WindowPacket::WindowSetPosition packet; packet.x = x; packet.y = y; - - if (auto ret = packet.send_serialized(m_server_fd); ret.is_error()) - return on_socket_error(__FUNCTION__); + send_packet(packet, __FUNCTION__); } void Window::set_cursor_visible(bool visible) @@ -247,9 +218,7 @@ namespace LibGUI MUST(packet.pixels.resize(pixels.size())); for (size_t i = 0; i < packet.pixels.size(); i++) packet.pixels[i] = pixels[i]; - - if (auto ret = packet.send_serialized(m_server_fd); ret.is_error()) - return on_socket_error(__FUNCTION__); + send_packet(packet, __FUNCTION__); } void Window::set_min_size(uint32_t width, uint32_t height) @@ -257,9 +226,7 @@ namespace LibGUI WindowPacket::WindowSetMinSize packet; packet.width = width; packet.height = height; - - if (auto ret = packet.send_serialized(m_server_fd); ret.is_error()) - return on_socket_error(__FUNCTION__); + send_packet(packet, __FUNCTION__); } void Window::set_max_size(uint32_t width, uint32_t height) @@ -267,19 +234,14 @@ namespace LibGUI WindowPacket::WindowSetMaxSize packet; packet.width = width; packet.height = height; - - if (auto ret = packet.send_serialized(m_server_fd); ret.is_error()) - return on_socket_error(__FUNCTION__); + send_packet(packet, __FUNCTION__); } void Window::set_attributes(Attributes attributes) { WindowPacket::WindowSetAttributes packet; packet.attributes = attributes; - - if (auto ret = packet.send_serialized(m_server_fd); ret.is_error()) - return on_socket_error(__FUNCTION__); - + send_packet(packet, __FUNCTION__); m_attributes = attributes; } @@ -288,9 +250,7 @@ namespace LibGUI WindowPacket::WindowSetSize packet; packet.width = width; packet.height = height; - - if (auto ret = packet.send_serialized(m_server_fd); ret.is_error()) - return on_socket_error(__FUNCTION__); + send_packet(packet, __FUNCTION__); } void Window::on_socket_error(BAN::StringView function) @@ -347,82 +307,121 @@ namespace LibGUI void Window::poll_events() { -#define TRY_OR_BREAK(...) ({ auto&& e = (__VA_ARGS__); if (e.is_error()) break; e.release_value(); }) for (;;) { epoll_event event; if (epoll_wait(m_epoll_fd, &event, 1, 0) == 0) break; - auto packet_or_error = recv_packet(m_server_fd); - if (packet_or_error.is_error()) + if (event.events & (EPOLLHUP | EPOLLERR)) return on_socket_error(__FUNCTION__); - const auto [packet_type, packet_data] = packet_or_error.release_value(); - switch (packet_type) + ASSERT(event.events & EPOLLIN); + { - case PacketType::DestroyWindowEvent: - exit(1); - case PacketType::CloseWindowEvent: - if (m_close_window_event_callback) - m_close_window_event_callback(); - else - exit(0); + const ssize_t nrecv = recv(m_server_fd, m_in_buffer.data() + m_in_buffer_size, m_in_buffer.size() - m_in_buffer_size, 0); + if (nrecv <= 0) + return on_socket_error(__FUNCTION__); + if (nrecv > 0) + m_in_buffer_size += nrecv; + } + + size_t bytes_handled = 0; + while (m_in_buffer_size - bytes_handled >= sizeof(PacketHeader)) + { + BAN::ConstByteSpan packet_span = m_in_buffer.span().slice(bytes_handled); + const auto header = packet_span.as(); + if (packet_span.size() < header.size || header.size < sizeof(LibGUI::PacketHeader)) break; - case PacketType::ResizeWindowEvent: + packet_span = packet_span.slice(0, header.size); + + switch (header.type) { - MUST(handle_resize_event(TRY_OR_BREAK(EventPacket::ResizeWindowEvent::deserialize(packet_data.span())))); - if (m_resize_window_event_callback) - m_resize_window_event_callback(); - break; - } - case PacketType::WindowShownEvent: - if (m_window_shown_event_callback) - m_window_shown_event_callback(TRY_OR_BREAK(EventPacket::WindowShownEvent::deserialize(packet_data.span())).event); - break; - case PacketType::WindowFocusEvent: - if (m_window_focus_event_callback) - m_window_focus_event_callback(TRY_OR_BREAK(EventPacket::WindowFocusEvent::deserialize(packet_data.span())).event); - break; - case PacketType::WindowFullscreenEvent: - if (m_window_fullscreen_event_callback) - m_window_fullscreen_event_callback(TRY_OR_BREAK(EventPacket::WindowFullscreenEvent::deserialize(packet_data.span())).event); - break; - case PacketType::KeyEvent: - if (m_key_event_callback) - m_key_event_callback(TRY_OR_BREAK(EventPacket::KeyEvent::deserialize(packet_data.span())).event); - break; - case PacketType::MouseButtonEvent: - { - auto event = TRY_OR_BREAK(EventPacket::MouseButtonEvent::deserialize(packet_data.span())).event; - if (m_mouse_button_event_callback) - m_mouse_button_event_callback(event); - if (m_root_widget) - m_root_widget->on_mouse_button(event); - break; - } - case PacketType::MouseMoveEvent: - { - auto event = TRY_OR_BREAK(EventPacket::MouseMoveEvent::deserialize(packet_data.span())).event; - if (m_mouse_move_event_callback) - m_mouse_move_event_callback(event); - if (m_root_widget) +#define TRY_OR_BREAK(...) ({ auto&& e = (__VA_ARGS__); if (e.is_error()) break; e.release_value(); }) + case PacketType::DestroyWindowEvent: + exit(1); + case PacketType::CloseWindowEvent: + if (m_close_window_event_callback) + m_close_window_event_callback(); + else + exit(0); + break; + case PacketType::ResizeWindowEvent: { - m_root_widget->before_mouse_move(); - m_root_widget->on_mouse_move(event); - m_root_widget->after_mouse_move(); + MUST(handle_resize_event(TRY_OR_BREAK(EventPacket::ResizeWindowEvent::deserialize(packet_span)))); + if (m_resize_window_event_callback) + m_resize_window_event_callback(); + break; } - break; + case PacketType::WindowShownEvent: + if (m_window_shown_event_callback) + m_window_shown_event_callback(TRY_OR_BREAK(EventPacket::WindowShownEvent::deserialize(packet_span)).event); + break; + case PacketType::WindowFocusEvent: + if (m_window_focus_event_callback) + m_window_focus_event_callback(TRY_OR_BREAK(EventPacket::WindowFocusEvent::deserialize(packet_span)).event); + break; + case PacketType::WindowFullscreenEvent: + if (m_window_fullscreen_event_callback) + m_window_fullscreen_event_callback(TRY_OR_BREAK(EventPacket::WindowFullscreenEvent::deserialize(packet_span)).event); + break; + case PacketType::KeyEvent: + if (m_key_event_callback) + m_key_event_callback(TRY_OR_BREAK(EventPacket::KeyEvent::deserialize(packet_span)).event); + break; + case PacketType::MouseButtonEvent: + { + auto event = TRY_OR_BREAK(EventPacket::MouseButtonEvent::deserialize(packet_span)).event; + if (m_mouse_button_event_callback) + m_mouse_button_event_callback(event); + if (m_root_widget) + m_root_widget->on_mouse_button(event); + break; + } + case PacketType::MouseMoveEvent: + { + auto event = TRY_OR_BREAK(EventPacket::MouseMoveEvent::deserialize(packet_span)).event; + if (m_mouse_move_event_callback) + m_mouse_move_event_callback(event); + if (m_root_widget) + { + m_root_widget->before_mouse_move(); + m_root_widget->on_mouse_move(event); + m_root_widget->after_mouse_move(); + } + break; + } + case PacketType::MouseScrollEvent: + if (m_mouse_scroll_event_callback) + m_mouse_scroll_event_callback(TRY_OR_BREAK(EventPacket::MouseScrollEvent::deserialize(packet_span)).event); + break; +#undef TRY_OR_BREAK + default: + dprintln("unhandled packet type: {}", static_cast(header.type)); + break; + } + + bytes_handled += header.size; + } + + // NOTE: this will only move a single partial packet, so this is fine + m_in_buffer_size -= bytes_handled; + memmove( + m_in_buffer.data(), + m_in_buffer.data() + bytes_handled, + m_in_buffer_size + ); + + if (m_in_buffer_size >= sizeof(LibGUI::PacketHeader)) + { + const auto header = BAN::ConstByteSpan(m_in_buffer.span()).as(); + if (header.size < sizeof(LibGUI::PacketHeader) || header.size > m_in_buffer.size()) + { + dwarnln("server tried to send a {} byte packet", header.size); + return on_socket_error(__FUNCTION__); } - case PacketType::MouseScrollEvent: - if (m_mouse_scroll_event_callback) - m_mouse_scroll_event_callback(TRY_OR_BREAK(EventPacket::MouseScrollEvent::deserialize(packet_data.span())).event); - break; - default: - break; } } -#undef TRY_OR_BREAK if (m_root_widget) { diff --git a/userspace/libraries/LibGUI/include/LibGUI/Packet.h b/userspace/libraries/LibGUI/include/LibGUI/Packet.h index 31a6e6d8..4c3ac003 100644 --- a/userspace/libraries/LibGUI/include/LibGUI/Packet.h +++ b/userspace/libraries/LibGUI/include/LibGUI/Packet.h @@ -1,15 +1,13 @@ #pragma once +#include #include #include -#include +#include #include #include -#include -#include - #define FOR_EACH_0(macro) #define FOR_EACH_2(macro, type, name) macro(type, name) #define FOR_EACH_4(macro, type, name, ...) macro(type, name) FOR_EACH_2(macro, __VA_ARGS__) @@ -31,7 +29,7 @@ #define FIELD_DECL(type, name) type name; #define ADD_SERIALIZED_SIZE(type, name) serialized_size += Serialize::serialized_size_impl(this->name); -#define SEND_SERIALIZED(type, name) TRY(Serialize::send_serialized_impl(socket, this->name)); +#define SERIALIZE(type, name) Serialize::serialize_impl(buffer, this->name); #define DESERIALIZE(type, name) value.name = TRY(Serialize::deserialize_impl(buffer)); #define DEFINE_PACKET_EXTRA(name, extra, ...) \ @@ -44,26 +42,28 @@ \ FOR_EACH(FIELD_DECL, __VA_ARGS__) \ \ - size_t serialized_size() \ + size_t serialized_size() const \ { \ - size_t serialized_size = Serialize::serialized_size_impl(type_u32); \ + size_t serialized_size = 0; \ + serialized_size += Serialize::serialized_size_impl(0); \ + serialized_size += Serialize::serialized_size_impl(type_u32); \ FOR_EACH(ADD_SERIALIZED_SIZE, __VA_ARGS__) \ return serialized_size; \ } \ \ - BAN::ErrorOr send_serialized(int socket) \ + void serialize(BAN::ByteSpan buffer) const \ { \ const uint32_t serialized_size = this->serialized_size(); \ - TRY(Serialize::send_serialized_impl(socket, serialized_size)); \ - TRY(Serialize::send_serialized_impl(socket, type_u32)); \ - FOR_EACH(SEND_SERIALIZED, __VA_ARGS__) \ - return {}; \ + Serialize::serialize_impl(buffer, serialized_size); \ + Serialize::serialize_impl(buffer, type_u32); \ + FOR_EACH(SERIALIZE, __VA_ARGS__); \ } \ \ static BAN::ErrorOr deserialize(BAN::ConstByteSpan buffer) \ { \ + const uint32_t size_u32 = TRY(Serialize::deserialize_impl(buffer)); \ const uint32_t type_u32 = TRY(Serialize::deserialize_impl(buffer)); \ - if (type_u32 != name::type_u32) \ + if (type_u32 != name::type_u32 || size_u32 != buffer.size() + 2 * sizeof(uint32_t)) \ return BAN::Error::from_errno(EINVAL); \ name value; \ FOR_EACH(DESERIALIZE, __VA_ARGS__) \ @@ -90,19 +90,11 @@ namespace LibGUI namespace Serialize { - inline BAN::ErrorOr send_raw_data(int socket, BAN::ConstByteSpan data) + inline void append_raw_data(BAN::ByteSpan& buffer, BAN::ConstByteSpan data) { - size_t send_done = 0; - while (send_done < data.size()) - { - const ssize_t nsend = ::send(socket, data.data() + send_done, data.size() - send_done, 0); - if (nsend < 0) - return BAN::Error::from_errno(errno); - if (nsend == 0) - return BAN::Error::from_errno(ECONNRESET); - send_done += nsend; - } - return {}; + ASSERT(buffer.size() >= data.size()); + memcpy(buffer.data(), data.data(), data.size()); + buffer = buffer.slice(data.size()); } template requires BAN::is_pod_v @@ -112,10 +104,9 @@ namespace LibGUI } template requires BAN::is_pod_v - inline BAN::ErrorOr send_serialized_impl(int socket, const T& value) + inline void serialize_impl(BAN::ByteSpan& buffer, const T& value) { - TRY(send_raw_data(socket, BAN::ConstByteSpan::from(value))); - return {}; + append_raw_data(buffer, BAN::ConstByteSpan::from(value)); } template requires BAN::is_pod_v @@ -135,13 +126,12 @@ namespace LibGUI } template requires BAN::is_same_v - inline BAN::ErrorOr send_serialized_impl(int socket, const T& value) + inline void serialize_impl(BAN::ByteSpan& buffer, const T& value) { const uint32_t value_size = value.size(); - TRY(send_raw_data(socket, BAN::ConstByteSpan::from(value_size))); + append_raw_data(buffer, BAN::ConstByteSpan::from(value_size)); auto* u8_data = reinterpret_cast(value.data()); - TRY(send_raw_data(socket, BAN::ConstByteSpan(u8_data, value.size()))); - return {}; + append_raw_data(buffer, BAN::ConstByteSpan(u8_data, value.size())); } template requires BAN::is_same_v @@ -173,13 +163,12 @@ namespace LibGUI } template - inline BAN::ErrorOr send_serialized_impl(int socket, const T& vector) + inline void serialize_impl(BAN::ByteSpan& buffer, const T& vector) { const uint32_t value_size = vector.size(); - TRY(send_raw_data(socket, BAN::ConstByteSpan::from(value_size))); + append_raw_data(buffer, BAN::ConstByteSpan::from(value_size)); for (const auto& element : vector) - TRY(send_serialized_impl(socket, element)); - return {}; + serialize_impl(buffer, element); } template @@ -226,6 +215,12 @@ namespace LibGUI MouseScrollEvent, }; + struct PacketHeader + { + uint32_t size; + PacketType type; + }; + namespace WindowPacket { diff --git a/userspace/libraries/LibGUI/include/LibGUI/Window.h b/userspace/libraries/LibGUI/include/LibGUI/Window.h index a3ea642c..e3252ff4 100644 --- a/userspace/libraries/LibGUI/include/LibGUI/Window.h +++ b/userspace/libraries/LibGUI/include/LibGUI/Window.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include @@ -93,6 +94,9 @@ namespace LibGUI BAN::ErrorOr handle_resize_event(const EventPacket::ResizeWindowEvent&); + template + void send_packet(const T& packet, BAN::StringView function); + private: const int m_server_fd; const int m_epoll_fd; @@ -119,6 +123,11 @@ namespace LibGUI BAN::Function m_mouse_move_event_callback; BAN::Function m_mouse_scroll_event_callback; + size_t m_in_buffer_size { 0 }; + BAN::Array m_in_buffer; + + BAN::Array m_out_buffer; + friend class BAN::UniqPtr; }; diff --git a/userspace/programs/WindowServer/Window.cpp b/userspace/programs/WindowServer/Window.cpp index 78f20293..149a25ec 100644 --- a/userspace/programs/WindowServer/Window.cpp +++ b/userspace/programs/WindowServer/Window.cpp @@ -16,7 +16,22 @@ Window::~Window() smo_delete(m_smo_key); LibGUI::EventPacket::DestroyWindowEvent packet; - (void)packet.send_serialized(m_client_fd); + + BAN::Vector buffer; + if (!buffer.resize(packet.serialized_size()).is_error()) + { + packet.serialize(buffer.span()); + + size_t total_sent = 0; + while (total_sent < buffer.size()) + { + const ssize_t nsend = send(m_client_fd, buffer.data() + total_sent, buffer.size() - total_sent, 0); + if (nsend <= 0) + break; + total_sent += nsend; + } + } + close(m_client_fd); } diff --git a/userspace/programs/WindowServer/WindowServer.cpp b/userspace/programs/WindowServer/WindowServer.cpp index fef32bdd..b1d7eca3 100644 --- a/userspace/programs/WindowServer/WindowServer.cpp +++ b/userspace/programs/WindowServer/WindowServer.cpp @@ -86,7 +86,7 @@ void WindowServer::on_window_create(int fd, const LibGUI::WindowPacket::WindowCr response.width = window->client_width(); response.height = window->client_height(); response.smo_key = window->smo_key(); - if (auto ret = response.send_serialized(fd); ret.is_error()) + if (auto ret = append_serialized_packet(response, fd); ret.is_error()) { dwarnln("could not respond to window create request: {}", ret.error()); return; @@ -205,7 +205,7 @@ void WindowServer::on_window_set_attributes(int fd, const LibGUI::WindowPacket:: .shown = target_window->get_attributes().shown, }, }; - if (auto ret = event_packet.send_serialized(target_window->client_fd()); ret.is_error()) + if (auto ret = append_serialized_packet(event_packet, target_window->client_fd()); ret.is_error()) dwarnln("could not send window shown event: {}", ret.error()); if (packet.attributes.focusable && packet.attributes.shown && m_state == State::Normal) @@ -312,7 +312,7 @@ void WindowServer::on_window_set_fullscreen(int fd, const LibGUI::WindowPacket:: auto event_packet = LibGUI::EventPacket::WindowFullscreenEvent { .event = { .fullscreen = false } }; - if (auto ret = event_packet.send_serialized(m_focused_window->client_fd()); ret.is_error()) + if (auto ret = append_serialized_packet(event_packet, m_focused_window->client_fd()); ret.is_error()) dwarnln("could not send window fullscreen event: {}", ret.error()); m_state = State::Normal; @@ -347,7 +347,7 @@ void WindowServer::on_window_set_fullscreen(int fd, const LibGUI::WindowPacket:: auto event_packet = LibGUI::EventPacket::WindowFullscreenEvent { .event = { .fullscreen = true } }; - if (auto ret = event_packet.send_serialized(target_window->client_fd()); ret.is_error()) + if (auto ret = append_serialized_packet(event_packet, target_window->client_fd()); ret.is_error()) dwarnln("could not send window fullscreen event: {}", ret.error()); m_state = State::Fullscreen; @@ -488,7 +488,7 @@ void WindowServer::on_key_event(LibInput::KeyEvent event) if (m_is_mod_key_held && event.pressed() && event.key == LibInput::Key::Q) { LibGUI::EventPacket::CloseWindowEvent packet; - if (auto ret = packet.send_serialized(m_focused_window->client_fd()); ret.is_error()) + if (auto ret = append_serialized_packet(packet, m_focused_window->client_fd()); ret.is_error()) dwarnln("could not send window close event: {}", ret.error()); return; } @@ -521,7 +521,7 @@ void WindowServer::on_key_event(LibInput::KeyEvent event) auto event_packet = LibGUI::EventPacket::WindowFullscreenEvent { .event = { .fullscreen = (m_state == State::Fullscreen) } }; - if (auto ret = event_packet.send_serialized(m_focused_window->client_fd()); ret.is_error()) + if (auto ret = append_serialized_packet(event_packet, m_focused_window->client_fd()); ret.is_error()) dwarnln("could not send window fullscreen event: {}", ret.error()); invalidate(m_framebuffer.area()); @@ -530,7 +530,7 @@ void WindowServer::on_key_event(LibInput::KeyEvent event) LibGUI::EventPacket::KeyEvent packet; packet.event = event; - if (auto ret = packet.send_serialized(m_focused_window->client_fd()); ret.is_error()) + if (auto ret = append_serialized_packet(packet, m_focused_window->client_fd()); ret.is_error()) dwarnln("could not send key event: {}", ret.error()); } @@ -545,7 +545,7 @@ void WindowServer::on_mouse_button(LibInput::MouseButtonEvent event) packet.event.pressed = event.pressed; packet.event.x = 0; packet.event.y = 0; - if (auto ret = packet.send_serialized(m_focused_window->client_fd()); ret.is_error()) + if (auto ret = append_serialized_packet(packet, m_focused_window->client_fd()); ret.is_error()) dwarnln("could not send mouse button event: {}", ret.error()); return; } @@ -604,7 +604,7 @@ void WindowServer::on_mouse_button(LibInput::MouseButtonEvent event) if (event.button == LibInput::MouseButton::Left && !event.pressed && target_window->close_button_area().contains(m_cursor)) { LibGUI::EventPacket::CloseWindowEvent packet; - if (auto ret = packet.send_serialized(target_window->client_fd()); ret.is_error()) + if (auto ret = append_serialized_packet(packet, target_window->client_fd()); ret.is_error()) dwarnln("could not send close window event: {}", ret.error()); break; } @@ -618,7 +618,7 @@ void WindowServer::on_mouse_button(LibInput::MouseButtonEvent event) packet.event.pressed = event.pressed; packet.event.x = m_cursor.x - target_window->client_x(); packet.event.y = m_cursor.y - target_window->client_y(); - if (auto ret = packet.send_serialized(target_window->client_fd()); ret.is_error()) + if (auto ret = append_serialized_packet(packet, target_window->client_fd()); ret.is_error()) { dwarnln("could not send mouse button event: {}", ret.error()); return; @@ -649,7 +649,7 @@ void WindowServer::on_mouse_button(LibInput::MouseButtonEvent event) event.width = m_focused_window->client_width(); event.height = m_focused_window->client_height(); event.smo_key = m_focused_window->smo_key(); - if (auto ret = event.send_serialized(m_focused_window->client_fd()); ret.is_error()) + if (auto ret = append_serialized_packet(event, m_focused_window->client_fd()); ret.is_error()) { dwarnln("could not respond to window resize request: {}", ret.error()); return; @@ -698,7 +698,7 @@ void WindowServer::on_mouse_move_impl(int32_t new_x, int32_t new_y) LibGUI::EventPacket::MouseMoveEvent packet; packet.event.x = m_cursor.x - m_focused_window->client_x(); packet.event.y = m_cursor.y - m_focused_window->client_y(); - if (auto ret = packet.send_serialized(m_focused_window->client_fd()); ret.is_error()) + if (auto ret = append_serialized_packet(packet, m_focused_window->client_fd()); ret.is_error()) { dwarnln("could not send mouse move event: {}", ret.error()); return; @@ -736,7 +736,7 @@ void WindowServer::on_mouse_move(LibInput::MouseMoveEvent event) LibGUI::EventPacket::MouseMoveEvent packet; packet.event.x = event.rel_x; packet.event.y = -event.rel_y; - if (auto ret = packet.send_serialized(m_focused_window->client_fd()); ret.is_error()) + if (auto ret = append_serialized_packet(packet, m_focused_window->client_fd()); ret.is_error()) dwarnln("could not send mouse move event: {}", ret.error()); return; } @@ -807,7 +807,7 @@ void WindowServer::on_mouse_scroll(LibInput::MouseScrollEvent event) { LibGUI::EventPacket::MouseScrollEvent packet; packet.event.scroll = event.scroll; - if (auto ret = packet.send_serialized(m_focused_window->client_fd()); ret.is_error()) + if (auto ret = append_serialized_packet(packet, m_focused_window->client_fd()); ret.is_error()) { dwarnln("could not send mouse scroll event: {}", ret.error()); return; @@ -831,7 +831,7 @@ void WindowServer::set_focused_window(BAN::RefPtr window) { LibGUI::EventPacket::WindowFocusEvent packet; packet.event.focused = false; - if (auto ret = packet.send_serialized(m_focused_window->client_fd()); ret.is_error()) + if (auto ret = append_serialized_packet(packet, m_focused_window->client_fd()); ret.is_error()) dwarnln("could not send window focus event: {}", ret.error()); } @@ -851,7 +851,7 @@ void WindowServer::set_focused_window(BAN::RefPtr window) { LibGUI::EventPacket::WindowFocusEvent packet; packet.event.focused = true; - if (auto ret = packet.send_serialized(m_focused_window->client_fd()); ret.is_error()) + if (auto ret = append_serialized_packet(packet, m_focused_window->client_fd()); ret.is_error()) dwarnln("could not send window focus event: {}", ret.error()); } } @@ -1526,7 +1526,7 @@ BAN::RefPtr WindowServer::find_hovered_window() const return {}; } -bool WindowServer::resize_window(BAN::RefPtr window, uint32_t width, uint32_t height) const +bool WindowServer::resize_window(BAN::RefPtr window, uint32_t width, uint32_t height) { if (auto ret = window->resize(width, height); ret.is_error()) { @@ -1538,7 +1538,7 @@ bool WindowServer::resize_window(BAN::RefPtr window, uint32_t width, uin response.width = window->client_width(); response.height = window->client_height(); response.smo_key = window->smo_key(); - if (auto ret = response.send_serialized(window->client_fd()); ret.is_error()) + if (auto ret = append_serialized_packet(response, window->client_fd()); ret.is_error()) { dwarnln("could not respond to window resize request: {}", ret.error()); return false; @@ -1547,13 +1547,10 @@ bool WindowServer::resize_window(BAN::RefPtr window, uint32_t width, uin return true; } -void WindowServer::add_client_fd(int fd) +BAN::ErrorOr WindowServer::add_client_fd(int fd) { - if (auto ret = m_client_data.emplace(fd); ret.is_error()) - { - dwarnln("could not add client: {}", ret.error()); - return; - } + TRY(m_client_data.emplace(fd)); + return {}; } void WindowServer::remove_client_fd(int fd) @@ -1612,3 +1609,31 @@ WindowServer::ClientData& WindowServer::get_client_data(int fd) ASSERT_NOT_REACHED(); } + +// TODO: this epoll stuff is very hacky + +#include + +extern int g_epoll_fd; + +template +BAN::ErrorOr WindowServer::append_serialized_packet(const T& packet, int fd) +{ + const size_t serialized_size = packet.serialized_size(); + + auto& client_data = m_client_data[fd]; + if (client_data.out_buffer_size + serialized_size > client_data.out_buffer.size()) + return BAN::Error::from_errno(ENOBUFS); + + if (client_data.out_buffer_size == 0) + { + epoll_event event { .events = EPOLLIN | EPOLLOUT, .data = { .fd = fd } }; + if (epoll_ctl(g_epoll_fd, EPOLL_CTL_MOD, fd, &event) == -1) + dwarnln("epoll_ctl add EPOLLOUT: {}", strerror(errno)); + } + + packet.serialize(client_data.out_buffer.span().slice(client_data.out_buffer_size, serialized_size)); + client_data.out_buffer_size += serialized_size; + + return {}; +} diff --git a/userspace/programs/WindowServer/WindowServer.h b/userspace/programs/WindowServer/WindowServer.h index b357f5e7..0f92bf57 100644 --- a/userspace/programs/WindowServer/WindowServer.h +++ b/userspace/programs/WindowServer/WindowServer.h @@ -20,8 +20,10 @@ class WindowServer public: struct ClientData { - size_t packet_buffer_nread = 0; - BAN::Vector packet_buffer; + size_t in_buffer_size { 0 }; + BAN::Array in_buffer; + size_t out_buffer_size { 0 }; + BAN::Array out_buffer; }; public: @@ -54,7 +56,7 @@ public: Rectangle cursor_area() const; Rectangle resize_area(Position cursor) const; - void add_client_fd(int fd); + BAN::ErrorOr add_client_fd(int fd); void remove_client_fd(int fd); ClientData& get_client_data(int fd); @@ -65,11 +67,14 @@ private: void mark_pending_sync(Rectangle area); - bool resize_window(BAN::RefPtr window, uint32_t width, uint32_t height) const; + bool resize_window(BAN::RefPtr window, uint32_t width, uint32_t height); BAN::RefPtr find_window_with_fd(int fd) const; BAN::RefPtr find_hovered_window() const; + template + BAN::ErrorOr append_serialized_packet(const T& packet, int fd); + private: struct RangeList { diff --git a/userspace/programs/WindowServer/main.cpp b/userspace/programs/WindowServer/main.cpp index 9d894378..3ad136d7 100644 --- a/userspace/programs/WindowServer/main.cpp +++ b/userspace/programs/WindowServer/main.cpp @@ -145,6 +145,8 @@ int open_server_fd() return server_fd; } +int g_epoll_fd = -1; + int main() { srand(time(nullptr)); @@ -157,8 +159,8 @@ int main() return 1; } - int epoll_fd = epoll_create1(EPOLL_CLOEXEC); - if (epoll_fd == -1) + g_epoll_fd = epoll_create1(EPOLL_CLOEXEC); + if (g_epoll_fd == -1) { dwarnln("epoll_create1: {}", strerror(errno)); return 1; @@ -169,7 +171,7 @@ int main() .events = EPOLLIN, .data = { .fd = server_fd }, }; - if (epoll_ctl(epoll_fd, EPOLL_CTL_ADD, server_fd, &event) == -1) + if (epoll_ctl(g_epoll_fd, EPOLL_CTL_ADD, server_fd, &event) == -1) { dwarnln("epoll_ctl server: {}", strerror(errno)); return 1; @@ -214,7 +216,7 @@ int main() .events = EPOLLIN, .data = { .fd = keyboard_fd }, }; - if (epoll_ctl(epoll_fd, EPOLL_CTL_ADD, keyboard_fd, &event) == -1) + if (epoll_ctl(g_epoll_fd, EPOLL_CTL_ADD, keyboard_fd, &event) == -1) { dwarnln("epoll_ctl keyboard: {}", strerror(errno)); close(keyboard_fd); @@ -231,7 +233,7 @@ int main() .events = EPOLLIN, .data = { .fd = mouse_fd }, }; - if (epoll_ctl(epoll_fd, EPOLL_CTL_ADD, mouse_fd, &event) == -1) + if (epoll_ctl(g_epoll_fd, EPOLL_CTL_ADD, mouse_fd, &event) == -1) { dwarnln("epoll_ctl mouse: {}", strerror(errno)); close(mouse_fd); @@ -283,7 +285,7 @@ int main() timeout.tv_nsec = (sync_interval_us - (current_us - last_sync_us)) * 1000; epoll_event events[16]; - int epoll_events = epoll_pwait2(epoll_fd, events, 16, &timeout, nullptr); + int epoll_events = epoll_pwait2(g_epoll_fd, events, 16, &timeout, nullptr); if (epoll_events == -1 && errno != EINTR) { dwarnln("epoll_pwait2: {}", strerror(errno)); @@ -296,25 +298,28 @@ int main() { ASSERT(events[i].events & EPOLLIN); - int window_fd = accept4(server_fd, nullptr, nullptr, SOCK_NONBLOCK | SOCK_CLOEXEC); - if (window_fd == -1) + int client_fd = accept4(server_fd, nullptr, nullptr, SOCK_NONBLOCK | SOCK_CLOEXEC); + if (client_fd == -1) { dwarnln("accept: {}", strerror(errno)); continue; } - epoll_event event { - .events = EPOLLIN, - .data = { .fd = window_fd }, - }; - if (epoll_ctl(epoll_fd, EPOLL_CTL_ADD, window_fd, &event) == -1) + epoll_event event { .events = EPOLLIN, .data = { .fd = client_fd } }; + if (epoll_ctl(g_epoll_fd, EPOLL_CTL_ADD, client_fd, &event) == -1) { dwarnln("epoll_ctl: {}", strerror(errno)); - close(window_fd); + close(client_fd); + continue; + } + + if (auto ret = window_server.add_client_fd(client_fd); ret.is_error()) + { + dwarnln("add_client: {}", ret.error()); + close(client_fd); continue; } - window_server.add_client_fd(window_fd); continue; } @@ -361,99 +366,127 @@ int main() } const int client_fd = events[i].data.fd; - if (events[i].events & EPOLLHUP) + if (events[i].events & (EPOLLHUP | EPOLLERR)) { - epoll_ctl(epoll_fd, EPOLL_CTL_DEL, client_fd, nullptr); + epoll_ctl(g_epoll_fd, EPOLL_CTL_DEL, client_fd, nullptr); window_server.remove_client_fd(client_fd); continue; } - ASSERT(events[i].events & EPOLLIN); - auto& client_data = window_server.get_client_data(client_fd); - if (client_data.packet_buffer.empty()) + if (events[i].events & EPOLLOUT) { - uint32_t packet_size; - const ssize_t nrecv = recv(client_fd, &packet_size, sizeof(uint32_t), 0); - if (nrecv < 0) - dwarnln("recv 1: {}", strerror(errno)); - if (nrecv > 0 && nrecv != sizeof(uint32_t)) - dwarnln("could not read packet size with a single recv call, closing connection..."); - if (nrecv != sizeof(uint32_t)) + ASSERT(client_data.out_buffer_size > 0); + + const ssize_t nsend = send(client_fd, client_data.out_buffer.data(), client_data.out_buffer_size, 0); + if (nsend < 0 && !(errno == EWOULDBLOCK || errno == EAGAIN)) { - epoll_ctl(epoll_fd, EPOLL_CTL_DEL, client_fd, nullptr); + dwarnln("send: {}", strerror(errno)); + epoll_ctl(g_epoll_fd, EPOLL_CTL_DEL, client_fd, nullptr); window_server.remove_client_fd(client_fd); break; } - if (packet_size < 4) + if (nsend > 0) { - dwarnln("client sent invalid packet, closing connection..."); - epoll_ctl(epoll_fd, EPOLL_CTL_DEL, client_fd, nullptr); - window_server.remove_client_fd(client_fd); - break; + client_data.out_buffer_size -= nsend; + if (client_data.out_buffer_size == 0) + { + epoll_event event { .events = EPOLLIN, .data = { .fd = client_fd } }; + if (epoll_ctl(g_epoll_fd, EPOLL_CTL_MOD, client_fd, &event) == -1) + dwarnln("epoll_ctl remove EPOLLOUT: {}", strerror(errno)); + } + else + { + // TODO: maybe use a ring buffer so we don't have to memmove everything not sent + memmove( + client_data.out_buffer.data(), + client_data.out_buffer.data() + nsend, + client_data.out_buffer_size + ); + } } - - // this is a bit harsh, but i don't want to work on skipping streaming packets - if (client_data.packet_buffer.resize(packet_size).is_error()) - { - dwarnln("could not allocate memory for client packet, closing connection..."); - epoll_ctl(epoll_fd, EPOLL_CTL_DEL, client_fd, nullptr); - window_server.remove_client_fd(client_fd); - break; - } - - client_data.packet_buffer_nread = 0; - continue; } - const ssize_t nrecv = recv( - client_fd, - client_data.packet_buffer.data() + client_data.packet_buffer_nread, - client_data.packet_buffer.size() - client_data.packet_buffer_nread, - 0 - ); - if (nrecv < 0) - dwarnln("recv 2: {}", strerror(errno)); - if (nrecv <= 0) - { - epoll_ctl(epoll_fd, EPOLL_CTL_DEL, client_fd, nullptr); - window_server.remove_client_fd(client_fd); - break; - } - - client_data.packet_buffer_nread += nrecv; - if (client_data.packet_buffer_nread < client_data.packet_buffer.size()) + if (!(events[i].events & EPOLLIN)) continue; - ASSERT(client_data.packet_buffer.size() >= sizeof(uint32_t)); - - switch (*reinterpret_cast(client_data.packet_buffer.data())) { + const ssize_t nrecv = recv( + client_fd, + client_data.in_buffer.data() + client_data.in_buffer_size, + client_data.in_buffer.size() - client_data.in_buffer_size, + 0 + ); + if (nrecv < 0 && !(errno == EWOULDBLOCK || errno == EAGAIN)) + { + dwarnln("recv: {}", strerror(errno)); + epoll_ctl(g_epoll_fd, EPOLL_CTL_DEL, client_fd, nullptr); + window_server.remove_client_fd(client_fd); + break; + } + if (nrecv > 0) + client_data.in_buffer_size += nrecv; + } + + size_t bytes_handled = 0; + while (client_data.in_buffer_size - bytes_handled >= sizeof(LibGUI::PacketHeader)) + { + BAN::ConstByteSpan packet_span = client_data.in_buffer.span().slice(bytes_handled, client_data.in_buffer_size - bytes_handled); + const auto header = packet_span.as(); + if (packet_span.size() < header.size || header.size < sizeof(LibGUI::PacketHeader)) + break; + packet_span = packet_span.slice(0, header.size); + + switch (header.type) + { #define WINDOW_PACKET_CASE(enum, function) \ - case LibGUI::PacketType::enum: \ - if (auto ret = LibGUI::WindowPacket::enum::deserialize(client_data.packet_buffer.span()); !ret.is_error()) \ - window_server.function(client_fd, ret.release_value()); \ - break - WINDOW_PACKET_CASE(WindowCreate, on_window_create); - WINDOW_PACKET_CASE(WindowInvalidate, on_window_invalidate); - WINDOW_PACKET_CASE(WindowSetPosition, on_window_set_position); - WINDOW_PACKET_CASE(WindowSetAttributes, on_window_set_attributes); - WINDOW_PACKET_CASE(WindowSetMouseRelative, on_window_set_mouse_relative); - WINDOW_PACKET_CASE(WindowSetSize, on_window_set_size); - WINDOW_PACKET_CASE(WindowSetMinSize, on_window_set_min_size); - WINDOW_PACKET_CASE(WindowSetMaxSize, on_window_set_max_size); - WINDOW_PACKET_CASE(WindowSetFullscreen, on_window_set_fullscreen); - WINDOW_PACKET_CASE(WindowSetTitle, on_window_set_title); - WINDOW_PACKET_CASE(WindowSetCursor, on_window_set_cursor); + case LibGUI::PacketType::enum: \ + if (auto ret = LibGUI::WindowPacket::enum::deserialize(packet_span); !ret.is_error()) \ + window_server.function(client_fd, ret.release_value()); \ + else \ + derrorln("invalid packet: {}", ret.error()); \ + break + WINDOW_PACKET_CASE(WindowCreate, on_window_create); + WINDOW_PACKET_CASE(WindowInvalidate, on_window_invalidate); + WINDOW_PACKET_CASE(WindowSetPosition, on_window_set_position); + WINDOW_PACKET_CASE(WindowSetAttributes, on_window_set_attributes); + WINDOW_PACKET_CASE(WindowSetMouseRelative, on_window_set_mouse_relative); + WINDOW_PACKET_CASE(WindowSetSize, on_window_set_size); + WINDOW_PACKET_CASE(WindowSetMinSize, on_window_set_min_size); + WINDOW_PACKET_CASE(WindowSetMaxSize, on_window_set_max_size); + WINDOW_PACKET_CASE(WindowSetFullscreen, on_window_set_fullscreen); + WINDOW_PACKET_CASE(WindowSetTitle, on_window_set_title); + WINDOW_PACKET_CASE(WindowSetCursor, on_window_set_cursor); #undef WINDOW_PACKET_CASE - default: - dprintln("unhandled packet type: {}", *reinterpret_cast(client_data.packet_buffer.data())); + default: + dprintln("unhandled packet type: {}", static_cast(header.type)); + break; + } + + bytes_handled += header.size; } - client_data.packet_buffer.clear(); - client_data.packet_buffer_nread = 0; + // NOTE: this will only move a single partial packet, so this is fine + client_data.in_buffer_size -= bytes_handled; + memmove( + client_data.in_buffer.data(), + client_data.in_buffer.data() + bytes_handled, + client_data.in_buffer_size + ); + + if (client_data.in_buffer_size >= sizeof(LibGUI::PacketHeader)) + { + const auto header = BAN::ConstByteSpan(client_data.in_buffer.span()).as(); + if (header.size < sizeof(LibGUI::PacketHeader) || header.size > client_data.in_buffer.size()) + { + dwarnln("client tried to send a {} byte packet", header.size); + epoll_ctl(g_epoll_fd, EPOLL_CTL_DEL, client_fd, nullptr); + window_server.remove_client_fd(client_fd); + break; + } + } } } }