From d7e5c56e94b01fd2302adf924f992911db60974c Mon Sep 17 00:00:00 2001 From: Bananymous Date: Thu, 17 Oct 2024 01:36:59 +0300 Subject: [PATCH] userspace: Use SOCK_STREAM instead of SOCK_SEQPACKET for WindowServer This makes more sense if we have longer packages --- userspace/libraries/LibGUI/Window.cpp | 117 +++++--- .../libraries/LibGUI/include/LibGUI/Packet.h | 241 ++++++++++++++++ .../libraries/LibGUI/include/LibGUI/Window.h | 109 +------ userspace/programs/Terminal/Terminal.cpp | 4 +- userspace/programs/Terminal/Terminal.h | 2 +- userspace/programs/WindowServer/Window.cpp | 5 +- .../programs/WindowServer/WindowServer.cpp | 273 +++++++++--------- .../programs/WindowServer/WindowServer.h | 15 +- userspace/programs/WindowServer/main.cpp | 80 ++++- userspace/tests/test-window/main.cpp | 14 +- 10 files changed, 568 insertions(+), 292 deletions(-) create mode 100644 userspace/libraries/LibGUI/include/LibGUI/Packet.h diff --git a/userspace/libraries/LibGUI/Window.cpp b/userspace/libraries/LibGUI/Window.cpp index e8951fd5..5983e145 100644 --- a/userspace/libraries/LibGUI/Window.cpp +++ b/userspace/libraries/LibGUI/Window.cpp @@ -16,6 +16,47 @@ namespace LibGUI { + struct ReceivePacket + { + PacketType type; + BAN::Vector data_with_type; + }; + + static BAN::ErrorOr recv_packet(int socket) + { + uint32_t packet_size; + + { + const ssize_t nrecv = recv(socket, &packet_size, sizeof(uint32_t), 0); + if (nrecv < 0) + return BAN::Error::from_errno(errno); + if (nrecv == 0) + return BAN::Error::from_errno(ECONNRESET); + } + + if (packet_size < sizeof(uint32_t)) + return BAN::Error::from_literal("invalid packet, does not fit packet id"); + + BAN::Vector packet_data; + TRY(packet_data.resize(packet_size)); + + size_t total_recv = 0; + while (total_recv < packet_size) + { + const ssize_t nrecv = recv(socket, packet_data.data() + total_recv, packet_size - total_recv, 0); + if (nrecv < 0) + return BAN::Error::from_errno(errno); + if (nrecv == 0) + return BAN::Error::from_errno(ECONNRESET); + total_recv += nrecv; + } + + return ReceivePacket { + *reinterpret_cast(packet_data.data()), + packet_data + }; + } + Window::~Window() { munmap(m_framebuffer_smo, m_width * m_height * 4); @@ -24,13 +65,10 @@ namespace LibGUI BAN::ErrorOr> Window::create(uint32_t width, uint32_t height, BAN::StringView title) { - if (title.size() >= sizeof(WindowCreatePacket::title)) - return BAN::Error::from_errno(EINVAL); - BAN::Vector framebuffer; TRY(framebuffer.resize(width * height, 0xFF000000)); - int server_fd = socket(AF_UNIX, SOCK_SEQPACKET, 0); + int server_fd = socket(AF_UNIX, SOCK_STREAM, 0); if (server_fd == -1) return BAN::Error::from_errno(errno); BAN::ScopeGuard server_closer([server_fd] { close(server_fd); }); @@ -61,31 +99,32 @@ namespace LibGUI nanosleep(&sleep_time, nullptr); } - WindowCreatePacket packet; - packet.width = width; - packet.height = height; - strncpy(packet.title, title.data(), title.size()); - packet.title[title.size()] = '\0'; - if (send(server_fd, &packet, sizeof(packet), 0) != sizeof(packet)) - return BAN::Error::from_errno(errno); + WindowPacket::WindowCreate create_packet; + create_packet.width = width; + create_packet.height = height; + TRY(create_packet.title.append(title)); + TRY(create_packet.send_serialized(server_fd)); - WindowCreateResponse response; - if (recv(server_fd, &response, sizeof(response), 0) != sizeof(response)) - return BAN::Error::from_errno(errno); + const auto [response_type, response_data ] = TRY(recv_packet(server_fd)); + if (response_type != PacketType::WindowCreateResponse) + return BAN::Error::from_literal("Server responded with invalid packet"); - void* framebuffer_addr = smo_map(response.framebuffer_smo_key); + const auto create_response = TRY(WindowPacket::WindowCreateResponse::deserialize(response_data.span())); + void* framebuffer_addr = smo_map(create_response.smo_key); if (framebuffer_addr == nullptr) return BAN::Error::from_errno(errno); - server_closer.disable(); - - return TRY(BAN::UniqPtr::create( + auto window = TRY(BAN::UniqPtr::create( server_fd, static_cast(framebuffer_addr), BAN::move(framebuffer), width, height )); + + server_closer.disable(); + + return window; } void Window::fill_rect(int32_t x, int32_t y, uint32_t width, uint32_t height, uint32_t color) @@ -211,14 +250,23 @@ namespace LibGUI for (uint32_t i = 0; i < height; i++) memcpy(&m_framebuffer_smo[(y + i) * m_width + x], &m_framebuffer[(y + i) * m_width + x], width * sizeof(uint32_t)); - WindowInvalidatePacket packet; + WindowPacket::WindowInvalidate packet; packet.x = x; packet.y = y; packet.width = width; packet.height = height; - return send(m_server_fd, &packet, sizeof(packet), 0) == sizeof(packet); + + if (auto ret = packet.send_serialized(m_server_fd); ret.is_error()) + { + dprintln("Failed to send packet: {}", ret.error().get_message()); + return false; + } + + return true; } +#define TRY_OR_BREAK(...) ({ auto&& e = (__VA_ARGS__); if (e.is_error()) break; e.release_value(); }) + void Window::poll_events() { for (;;) @@ -232,35 +280,38 @@ namespace LibGUI if (!FD_ISSET(m_server_fd, &fds)) break; - EventPacket packet; - if (recv(m_server_fd, &packet, sizeof(packet), 0) <= 0) + auto packet_or_error = recv_packet(m_server_fd); + if (packet_or_error.is_error()) break; - switch (packet.type) + const auto [packet_type, packet_data] = packet_or_error.release_value(); + switch (packet_type) { - case EventPacket::Type::DestroyWindow: + case PacketType::DestroyWindowEvent: exit(1); - case EventPacket::Type::CloseWindow: + case PacketType::CloseWindowEvent: if (m_close_window_event_callback) m_close_window_event_callback(); else exit(0); break; - case EventPacket::Type::KeyEvent: + case PacketType::KeyEvent: if (m_key_event_callback) - m_key_event_callback(packet.key_event); + m_key_event_callback(TRY_OR_BREAK(EventPacket::KeyEvent::deserialize(packet_data.span())).event); break; - case EventPacket::Type::MouseButtonEvent: + case PacketType::MouseButtonEvent: if (m_mouse_button_event_callback) - m_mouse_button_event_callback(packet.mouse_button_event); + m_mouse_button_event_callback(TRY_OR_BREAK(EventPacket::MouseButtonEvent::deserialize(packet_data.span())).event); break; - case EventPacket::Type::MouseMoveEvent: + case PacketType::MouseMoveEvent: if (m_mouse_move_event_callback) - m_mouse_move_event_callback(packet.mouse_move_event); + m_mouse_move_event_callback(TRY_OR_BREAK(EventPacket::MouseMoveEvent::deserialize(packet_data.span())).event); break; - case EventPacket::Type::MouseScrollEvent: + case PacketType::MouseScrollEvent: if (m_mouse_scroll_event_callback) - m_mouse_scroll_event_callback(packet.mouse_scroll_event); + m_mouse_scroll_event_callback(TRY_OR_BREAK(EventPacket::MouseScrollEvent::deserialize(packet_data.span())).event); + break; + default: break; } } diff --git a/userspace/libraries/LibGUI/include/LibGUI/Packet.h b/userspace/libraries/LibGUI/include/LibGUI/Packet.h new file mode 100644 index 00000000..728703a5 --- /dev/null +++ b/userspace/libraries/LibGUI/include/LibGUI/Packet.h @@ -0,0 +1,241 @@ +#pragma once + +#include +#include +#include + +#include +#include + +#include +#include + +#define FOR_EACH_0(macro) +#define FOR_EACH_2(macro, type, name) macro(type, name) +#define FOR_EACH_4(macro, type, name, ...) macro(type, name) FOR_EACH_2(macro, __VA_ARGS__) +#define FOR_EACH_6(macro, type, name, ...) macro(type, name) FOR_EACH_4(macro, __VA_ARGS__) +#define FOR_EACH_8(macro, type, name, ...) macro(type, name) FOR_EACH_6(macro, __VA_ARGS__) + +#define CONCATENATE_2(arg1, arg2) arg1 ## arg2 +#define CONCATENATE_1(arg1, arg2) CONCATENATE_2(arg1, arg2) +#define CONCATENATE(arg1, arg2) CONCATENATE_1(arg1, arg2) + +#define FOR_EACH_NARG(...) FOR_EACH_NARG_(__VA_ARGS__ __VA_OPT__(,) FOR_EACH_RSEQ_N()) +#define FOR_EACH_NARG_(...) FOR_EACH_ARG_N(__VA_ARGS__) +#define FOR_EACH_ARG_N(_1, _2, _3, _4, _5, _6, _7, _8, N, ...) N +#define FOR_EACH_RSEQ_N() 8, 7, 6, 5, 4, 3, 2, 1, 0 + +#define FOR_EACH_(N, what, ...) CONCATENATE(FOR_EACH_, N)(what __VA_OPT__(,) __VA_ARGS__) +#define FOR_EACH(what, ...) FOR_EACH_(FOR_EACH_NARG(__VA_ARGS__), what __VA_OPT__(,) __VA_ARGS__) + +#define FIELD_DECL(type, name) type name; +#define ADD_SERIALIZED_SIZE(type, name) serialized_size += Serialize::serialized_size_impl(this->name); +#define SEND_SERIALIZED(type, name) TRY(Serialize::send_serialized_impl(socket, this->name)); +#define DESERIALIZE(type, name) value.name = TRY(Serialize::deserialize_impl(buffer)); + +#define DEFINE_PACKET_EXTRA(name, extra, ...) \ + struct name \ + { \ + static constexpr PacketType type = PacketType::name; \ + static constexpr uint32_t type_u32 = static_cast(type); \ + \ + extra; \ + \ + FOR_EACH(FIELD_DECL, __VA_ARGS__) \ + \ + size_t serialized_size() \ + { \ + size_t serialized_size = Serialize::serialized_size_impl(type_u32); \ + FOR_EACH(ADD_SERIALIZED_SIZE, __VA_ARGS__) \ + return serialized_size; \ + } \ + \ + BAN::ErrorOr send_serialized(int socket) \ + { \ + const uint32_t serialized_size = this->serialized_size(); \ + TRY(Serialize::send_serialized_impl(socket, serialized_size)); \ + TRY(Serialize::send_serialized_impl(socket, type_u32)); \ + FOR_EACH(SEND_SERIALIZED, __VA_ARGS__) \ + return {}; \ + } \ + \ + static BAN::ErrorOr deserialize(BAN::ConstByteSpan buffer) \ + { \ + const uint32_t type_u32 = TRY(Serialize::deserialize_impl(buffer)); \ + if (type_u32 != name::type_u32) \ + return BAN::Error::from_errno(EINVAL); \ + name value; \ + FOR_EACH(DESERIALIZE, __VA_ARGS__) \ + return value; \ + } \ + } + +#define DEFINE_PACKET(name, ...) DEFINE_PACKET_EXTRA(name, , __VA_ARGS__) + +namespace LibGUI +{ + + static constexpr BAN::StringView s_window_server_socket = "/tmp/window-server.socket"_sv; + + namespace Serialize + { + + inline BAN::ErrorOr send_raw_data(int socket, BAN::ConstByteSpan data) + { + size_t send_done = 0; + while (send_done < data.size()) + { + const ssize_t nsend = ::send(socket, data.data() + send_done, data.size() - send_done, 0); + if (nsend < 0) + return BAN::Error::from_errno(errno); + if (nsend == 0) + return BAN::Error::from_errno(ECONNRESET); + send_done += nsend; + } + return {}; + } + + template requires BAN::is_pod_v + inline size_t serialized_size_impl(const T&) + { + return sizeof(T); + } + + template requires BAN::is_pod_v + inline BAN::ErrorOr send_serialized_impl(int socket, const T& value) + { + TRY(send_raw_data(socket, BAN::ConstByteSpan::from(value))); + return {}; + } + + template requires BAN::is_pod_v + inline BAN::ErrorOr deserialize_impl(BAN::ConstByteSpan& buffer) + { + if (buffer.size() < sizeof(T)) + return BAN::Error::from_errno(ENOBUFS); + const T value = buffer.as(); + buffer = buffer.slice(sizeof(T)); + return value; + } + + template requires BAN::is_same_v + inline size_t serialized_size_impl(const T& value) + { + return sizeof(uint32_t) + value.size(); + } + + template requires BAN::is_same_v + inline BAN::ErrorOr send_serialized_impl(int socket, const T& value) + { + const uint32_t value_size = value.size(); + TRY(send_raw_data(socket, BAN::ConstByteSpan::from(value_size))); + auto* u8_data = reinterpret_cast(value.data()); + TRY(send_raw_data(socket, BAN::ConstByteSpan(u8_data, value.size()))); + return {}; + } + + template requires BAN::is_same_v + inline BAN::ErrorOr deserialize_impl(BAN::ConstByteSpan& buffer) + { + if (buffer.size() < sizeof(uint32_t)) + return BAN::Error::from_errno(ENOBUFS); + const uint32_t string_len = buffer.as(); + buffer = buffer.slice(sizeof(uint32_t)); + + if (buffer.size() < string_len) + return BAN::Error::from_errno(ENOBUFS); + + BAN::String string; + TRY(string.resize(string_len)); + memcpy(string.data(), buffer.data(), string_len); + buffer = buffer.slice(string_len); + + return string; + } + + } + + enum class PacketType : uint32_t + { + WindowCreate, + WindowCreateResponse, + WindowInvalidate, + + DestroyWindowEvent, + CloseWindowEvent, + KeyEvent, + MouseButtonEvent, + MouseMoveEvent, + MouseScrollEvent, + }; + + namespace WindowPacket + { + + DEFINE_PACKET(WindowCreate, + uint32_t, width, + uint32_t, height, + BAN::String, title + ); + + DEFINE_PACKET(WindowCreateResponse, + long, smo_key + ); + + DEFINE_PACKET(WindowInvalidate, + uint32_t, x, + uint32_t, y, + uint32_t, width, + uint32_t, height + ); + + } + + namespace EventPacket + { + + DEFINE_PACKET( + DestroyWindowEvent + ); + + DEFINE_PACKET( + CloseWindowEvent + ); + + DEFINE_PACKET_EXTRA( + KeyEvent, + using event_t = LibInput::KeyEvent, + event_t, event + ); + + DEFINE_PACKET_EXTRA( + MouseButtonEvent, + struct event_t { + LibInput::MouseButton button; + bool pressed; + int32_t x; + int32_t y; + }, + event_t, event + ); + + DEFINE_PACKET_EXTRA( + MouseMoveEvent, + struct event_t { + int32_t x; + int32_t y; + }, + event_t, event + ); + + DEFINE_PACKET_EXTRA( + MouseScrollEvent, + struct event_t { + int32_t scroll; + }, + event_t, event + ); + + } + +} diff --git a/userspace/libraries/LibGUI/include/LibGUI/Window.h b/userspace/libraries/LibGUI/include/LibGUI/Window.h index 849503c3..8a1c19d8 100644 --- a/userspace/libraries/LibGUI/include/LibGUI/Window.h +++ b/userspace/libraries/LibGUI/include/LibGUI/Window.h @@ -4,100 +4,13 @@ #include #include -#include -#include - -#include -#include +#include namespace LibFont { class Font; } namespace LibGUI { - static constexpr BAN::StringView s_window_server_socket = "/tmp/window-server.socket"_sv; - - enum WindowPacketType : uint8_t - { - INVALID, - CreateWindow, - Invalidate, - COUNT - }; - - struct WindowCreatePacket - { - WindowPacketType type = WindowPacketType::CreateWindow; - uint32_t width; - uint32_t height; - char title[52]; - }; - - struct WindowInvalidatePacket - { - WindowPacketType type = WindowPacketType::Invalidate; - uint32_t x; - uint32_t y; - uint32_t width; - uint32_t height; - }; - - struct WindowCreateResponse - { - long framebuffer_smo_key; - }; - - struct WindowPacket - { - WindowPacket() - : type(WindowPacketType::INVALID) - { } - - union - { - WindowPacketType type; - WindowCreatePacket create; - WindowInvalidatePacket invalidate; - }; - }; - - struct EventPacket - { - enum class Type : uint8_t - { - DestroyWindow, - CloseWindow, - KeyEvent, - MouseButtonEvent, - MouseMoveEvent, - MouseScrollEvent, - }; - using KeyEvent = LibInput::KeyEvent; - using MouseButton = LibInput::MouseButton; - struct MouseButtonEvent - { - MouseButton button; - bool pressed; - int32_t x; - int32_t y; - }; - struct MouseMoveEvent - { - int32_t x; - int32_t y; - }; - using MouseScrollEvent = LibInput::MouseScrollEvent; - - Type type; - union - { - KeyEvent key_event; - MouseButtonEvent mouse_button_event; - MouseMoveEvent mouse_move_event; - MouseScrollEvent mouse_scroll_event; - }; - }; - class Window { public: @@ -140,11 +53,11 @@ namespace LibGUI uint32_t height() const { return m_height; } void poll_events(); - void set_close_window_event_callback(BAN::Function callback) { m_close_window_event_callback = callback; } - void set_key_event_callback(BAN::Function callback) { m_key_event_callback = callback; } - void set_mouse_button_event_callback(BAN::Function callback) { m_mouse_button_event_callback = callback; } - void set_mouse_move_event_callback(BAN::Function callback) { m_mouse_move_event_callback = callback; } - void set_mouse_scroll_event_callback(BAN::Function callback) { m_mouse_scroll_event_callback = callback; } + void set_close_window_event_callback(BAN::Function callback) { m_close_window_event_callback = callback; } + void set_key_event_callback(BAN::Function callback) { m_key_event_callback = callback; } + void set_mouse_button_event_callback(BAN::Function callback) { m_mouse_button_event_callback = callback; } + void set_mouse_move_event_callback(BAN::Function callback) { m_mouse_move_event_callback = callback; } + void set_mouse_scroll_event_callback(BAN::Function callback) { m_mouse_scroll_event_callback = callback; } int server_fd() const { return m_server_fd; } @@ -167,11 +80,11 @@ namespace LibGUI uint32_t m_width; uint32_t m_height; - BAN::Function m_close_window_event_callback; - BAN::Function m_key_event_callback; - BAN::Function m_mouse_button_event_callback; - BAN::Function m_mouse_move_event_callback; - BAN::Function m_mouse_scroll_event_callback; + BAN::Function m_close_window_event_callback; + BAN::Function m_key_event_callback; + BAN::Function m_mouse_button_event_callback; + BAN::Function m_mouse_move_event_callback; + BAN::Function m_mouse_scroll_event_callback; friend class BAN::UniqPtr; }; diff --git a/userspace/programs/Terminal/Terminal.cpp b/userspace/programs/Terminal/Terminal.cpp index a1846eb5..d2766327 100644 --- a/userspace/programs/Terminal/Terminal.cpp +++ b/userspace/programs/Terminal/Terminal.cpp @@ -126,7 +126,7 @@ void Terminal::run() MUST(m_cursor_buffer.resize(m_font.width() * m_font.height(), m_bg_color)); show_cursor(); - m_window->set_key_event_callback([&](LibGUI::EventPacket::KeyEvent event) { on_key_event(event); }); + m_window->set_key_event_callback([&](LibGUI::EventPacket::KeyEvent::event_t event) { on_key_event(event); }); const int max_fd = BAN::Math::max(m_shell_info.pts_master, m_window->server_fd()); while (!s_shell_exited) @@ -576,7 +576,7 @@ Rectangle Terminal::putchar(uint8_t ch) return should_invalidate; } -void Terminal::on_key_event(LibGUI::EventPacket::KeyEvent event) +void Terminal::on_key_event(LibGUI::EventPacket::KeyEvent::event_t event) { if (event.released()) return; diff --git a/userspace/programs/Terminal/Terminal.h b/userspace/programs/Terminal/Terminal.h index a3e6f598..238a365f 100644 --- a/userspace/programs/Terminal/Terminal.h +++ b/userspace/programs/Terminal/Terminal.h @@ -43,7 +43,7 @@ private: void hide_cursor(); void show_cursor(); - void on_key_event(LibGUI::EventPacket::KeyEvent); + void on_key_event(LibGUI::EventPacket::KeyEvent::event_t); void start_shell(); diff --git a/userspace/programs/WindowServer/Window.cpp b/userspace/programs/WindowServer/Window.cpp index b2fcc354..ff73ba64 100644 --- a/userspace/programs/WindowServer/Window.cpp +++ b/userspace/programs/WindowServer/Window.cpp @@ -27,9 +27,8 @@ Window::~Window() munmap(m_fb_addr, client_width() * client_height() * 4); smo_delete(m_smo_key); - LibGUI::EventPacket event; - event.type = LibGUI::EventPacket::Type::DestroyWindow; - send(m_client_fd, &event, sizeof(event), 0); + LibGUI::EventPacket::DestroyWindowEvent packet; + (void)packet.send_serialized(m_client_fd); close(m_client_fd); } diff --git a/userspace/programs/WindowServer/WindowServer.cpp b/userspace/programs/WindowServer/WindowServer.cpp index 0bd6c03e..15558cfb 100644 --- a/userspace/programs/WindowServer/WindowServer.cpp +++ b/userspace/programs/WindowServer/WindowServer.cpp @@ -2,6 +2,7 @@ #include "WindowServer.h" #include +#include #include #include @@ -26,107 +27,109 @@ WindowServer::WindowServer(Framebuffer& framebuffer, int32_t corner_radius) BAN::ErrorOr WindowServer::set_background_image(BAN::UniqPtr image) { if (image->width() != (uint64_t)m_framebuffer.width || image->height() != (uint64_t)m_framebuffer.height) - image = TRY(image->resize(m_framebuffer.width, m_framebuffer.height)); + image = TRY(image->resize(m_framebuffer.width, m_framebuffer.height, LibImage::Image::ResizeAlgorithm::Linear)); m_background_image = BAN::move(image); invalidate(m_framebuffer.area()); return {}; } -void WindowServer::on_window_packet(int fd, LibGUI::WindowPacket packet) +void WindowServer::on_window_create(int fd, const LibGUI::WindowPacket::WindowCreate& packet) { - switch (packet.type) + for (auto& window : m_client_windows) { - case LibGUI::WindowPacketType::CreateWindow: - { - // FIXME: This should be probably allowed - for (auto& window : m_client_windows) - { - if (window->client_fd() == fd) - { - dwarnln("client {} tried to create window while already owning a window", fd); - return; - } - } - - const size_t window_fb_bytes = packet.create.width * packet.create.height * 4; - - long smo_key = smo_create(window_fb_bytes, PROT_READ | PROT_WRITE); - if (smo_key == -1) - { - dwarnln("smo_create: {}", strerror(errno)); - break; - } - - Rectangle window_area { - static_cast((m_framebuffer.width - packet.create.width) / 2), - static_cast((m_framebuffer.height - packet.create.height) / 2), - static_cast(packet.create.width), - static_cast(packet.create.height) - }; - - packet.create.title[sizeof(packet.create.title) - 1] = '\0'; - - // Window::Window(int fd, Rectangle area, long smo_key, BAN::StringView title, const LibFont::Font& font) - auto window = MUST(BAN::RefPtr::create( - fd, - window_area, - smo_key, - packet.create.title, - m_font - )); - MUST(m_client_windows.push_back(window)); - set_focused_window(window); - - LibGUI::WindowCreateResponse response; - response.framebuffer_smo_key = smo_key; - if (send(window->client_fd(), &response, sizeof(response), 0) != sizeof(response)) - { - dwarnln("send: {}", strerror(errno)); - break; - } - - break; - } - case LibGUI::WindowPacketType::Invalidate: - { - if (packet.invalidate.width == 0 || packet.invalidate.height == 0) - break; - - BAN::RefPtr target_window; - for (auto& window : m_client_windows) - { - if (window->client_fd() == fd) - { - target_window = window; - break; - } - } - if (!target_window) - { - dwarnln("client {} tried to invalidate window while not owning a window", fd); - break; - } - - const int32_t br_x = packet.invalidate.x + packet.invalidate.width - 1; - const int32_t br_y = packet.invalidate.y + packet.invalidate.height - 1; - if (!target_window->client_size().contains({ br_x, br_y })) - { - dwarnln("Invalid Invalidate packet parameters"); - break; - } - - invalidate({ - target_window->client_x() + static_cast(packet.invalidate.x), - target_window->client_y() + static_cast(packet.invalidate.y), - static_cast(packet.invalidate.width), - static_cast(packet.invalidate.height), - }); - - break; - } - default: - ASSERT_NOT_REACHED(); + if (window->client_fd() != fd) + continue; + dwarnln("client with window tried to create another one"); + return; } + + const size_t window_fb_bytes = packet.width * packet.height * 4; + + long smo_key = smo_create(window_fb_bytes, PROT_READ | PROT_WRITE); + if (smo_key == -1) + { + dwarnln("smo_create: {}", strerror(errno)); + return; + } + BAN::ScopeGuard smo_deleter([smo_key] { smo_delete(smo_key); }); + + Rectangle window_area { + static_cast((m_framebuffer.width - packet.width) / 2), + static_cast((m_framebuffer.height - packet.height) / 2), + static_cast(packet.width), + static_cast(packet.height) + }; + + // Window::Window(int fd, Rectangle area, long smo_key, BAN::StringView title, const LibFont::Font& font) + auto window_or_error = (BAN::RefPtr::create( + fd, + window_area, + smo_key, + packet.title, + m_font + )); + if (window_or_error.is_error()) + { + dwarnln("could not create window for client: {}", window_or_error.error()); + return; + } + auto window = window_or_error.release_value(); + + if (auto ret = m_client_windows.push_back(window); ret.is_error()) + { + dwarnln("could not create window for client: {}", ret.error()); + return; + } + BAN::ScopeGuard window_popper([&] { m_client_windows.pop_back(); }); + + LibGUI::WindowPacket::WindowCreateResponse response; + response.smo_key = smo_key; + if (auto ret = response.send_serialized(fd); ret.is_error()) + { + dwarnln("could not respond to window create request: {}", ret.error()); + return; + } + + smo_deleter.disable(); + window_popper.disable(); + + set_focused_window(window); +} + +void WindowServer::on_window_invalidate(int fd, const LibGUI::WindowPacket::WindowInvalidate& packet) +{ + if (packet.width == 0 || packet.height == 0) + return; + + BAN::RefPtr target_window; + for (auto& window : m_client_windows) + { + if (window->client_fd() != fd) + continue; + target_window = window; + break; + } + + if (!target_window) + { + dwarnln("client tried to invalidate window while not owning a window"); + return; + } + + const int32_t br_x = packet.x + packet.width - 1; + const int32_t br_y = packet.y + packet.height - 1; + if (!target_window->client_size().contains({ br_x, br_y })) + { + dwarnln("invalid Invalidate packet parameters"); + return; + } + + invalidate({ + target_window->client_x() + static_cast(packet.x), + target_window->client_y() + static_cast(packet.y), + static_cast(packet.width), + static_cast(packet.height), + }); } void WindowServer::on_key_event(LibInput::KeyEvent event) @@ -173,10 +176,10 @@ void WindowServer::on_key_event(LibInput::KeyEvent event) if (m_focused_window) { - LibGUI::EventPacket packet; - packet.type = LibGUI::EventPacket::Type::KeyEvent; - packet.key_event = event; - send(m_focused_window->client_fd(), &packet, sizeof(packet), 0); + LibGUI::EventPacket::KeyEvent packet; + packet.event = event; + if (auto ret = packet.send_serialized(m_focused_window->client_fd()); ret.is_error()) + dwarnln("could not send key event: {}", ret.error()); } } @@ -207,20 +210,26 @@ void WindowServer::on_mouse_button(LibInput::MouseButtonEvent event) else if (!event.pressed && event.button == LibInput::MouseButton::Left && target_window->close_button_area().contains(m_cursor)) { // NOTE: we always have target window if code reaches here - LibGUI::EventPacket packet; - packet.type = LibGUI::EventPacket::Type::CloseWindow; - send(m_focused_window->client_fd(), &packet, sizeof(packet), 0); + LibGUI::EventPacket::CloseWindowEvent packet; + if (auto ret = packet.send_serialized(m_focused_window->client_fd()); ret.is_error()) + { + dwarnln("could not send close window event: {}", ret.error()); + return; + } } else if (target_window->client_area().contains(m_cursor)) { // NOTE: we always have target window if code reaches here - LibGUI::EventPacket packet; - packet.type = LibGUI::EventPacket::Type::MouseButtonEvent; - packet.mouse_button_event.button = event.button; - packet.mouse_button_event.pressed = event.pressed; - packet.mouse_button_event.x = m_cursor.x - m_focused_window->client_x(); - packet.mouse_button_event.y = m_cursor.y - m_focused_window->client_y(); - send(m_focused_window->client_fd(), &packet, sizeof(packet), 0); + LibGUI::EventPacket::MouseButtonEvent packet; + packet.event.button = event.button; + packet.event.pressed = event.pressed; + packet.event.x = m_cursor.x - m_focused_window->client_x(); + packet.event.y = m_cursor.y - m_focused_window->client_y(); + if (auto ret = packet.send_serialized(m_focused_window->client_fd()); ret.is_error()) + { + dwarnln("could not send mouse button event event: {}", ret.error()); + return; + } } } @@ -265,11 +274,14 @@ void WindowServer::on_mouse_move(LibInput::MouseMoveEvent event) if (m_focused_window) { - LibGUI::EventPacket packet; - packet.type = LibGUI::EventPacket::Type::MouseMoveEvent; - packet.mouse_move_event.x = m_cursor.x - m_focused_window->client_x(); - packet.mouse_move_event.y = m_cursor.y - m_focused_window->client_y(); - send(m_focused_window->client_fd(), &packet, sizeof(packet), 0); + LibGUI::EventPacket::MouseMoveEvent packet; + packet.event.x = m_cursor.x - m_focused_window->client_x(); + packet.event.y = m_cursor.y - m_focused_window->client_y(); + if (auto ret = packet.send_serialized(m_focused_window->client_fd()); ret.is_error()) + { + dwarnln("could not send mouse move event event: {}", ret.error()); + return; + } } } @@ -277,10 +289,13 @@ void WindowServer::on_mouse_scroll(LibInput::MouseScrollEvent event) { if (m_focused_window) { - LibGUI::EventPacket packet; - packet.type = LibGUI::EventPacket::Type::MouseScrollEvent; - packet.mouse_scroll_event = event; - send(m_focused_window->client_fd(), &packet, sizeof(packet), 0); + LibGUI::EventPacket::MouseScrollEvent packet; + packet.event.scroll = event.scroll; + if (auto ret = packet.send_serialized(m_focused_window->client_fd()); ret.is_error()) + { + dwarnln("could not send mouse scroll event event: {}", ret.error()); + return; + } } } @@ -595,19 +610,19 @@ Rectangle WindowServer::cursor_area() const void WindowServer::add_client_fd(int fd) { - MUST(m_client_fds.push_back(fd)); + if (auto ret = m_client_data.emplace(fd); ret.is_error()) + { + dwarnln("could not add client: {}", ret.error()); + return; + } } void WindowServer::remove_client_fd(int fd) { - for (size_t i = 0; i < m_client_fds.size(); i++) - { - if (m_client_fds[i] == fd) - { - m_client_fds.remove(i); - break; - } - } + auto it = m_client_data.find(fd); + if (it == m_client_data.end()) + return; + m_client_data.remove(it); for (size_t i = 0; i < m_client_windows.size(); i++) { @@ -635,7 +650,7 @@ void WindowServer::remove_client_fd(int fd) int WindowServer::get_client_fds(fd_set& fds) const { int max_fd = 0; - for (int fd : m_client_fds) + for (const auto& [fd, _] : m_client_data) { FD_SET(fd, &fds); max_fd = BAN::Math::max(max_fd, fd); @@ -643,13 +658,13 @@ int WindowServer::get_client_fds(fd_set& fds) const return max_fd; } -void WindowServer::for_each_client_fd(const BAN::Function& callback) +void WindowServer::for_each_client_fd(const BAN::Function& callback) { m_deleted_window = false; - for (int fd : m_client_fds) + for (auto& [fd, cliend_data] : m_client_data) { if (m_deleted_window) break; - callback(fd); + callback(fd, cliend_data); } } diff --git a/userspace/programs/WindowServer/WindowServer.h b/userspace/programs/WindowServer/WindowServer.h index 0752bb8c..1b8c0dcd 100644 --- a/userspace/programs/WindowServer/WindowServer.h +++ b/userspace/programs/WindowServer/WindowServer.h @@ -18,12 +18,20 @@ class WindowServer { +public: + struct ClientData + { + size_t packet_buffer_nread = 0; + BAN::Vector packet_buffer; + }; + public: WindowServer(Framebuffer& framebuffer, int32_t corner_radius); BAN::ErrorOr set_background_image(BAN::UniqPtr); - void on_window_packet(int fd, LibGUI::WindowPacket); + void on_window_create(int fd, const LibGUI::WindowPacket::WindowCreate&); + void on_window_invalidate(int fd, const LibGUI::WindowPacket::WindowInvalidate&); void on_key_event(LibInput::KeyEvent event); void on_mouse_button(LibInput::MouseButtonEvent event); @@ -39,14 +47,15 @@ public: void add_client_fd(int fd); void remove_client_fd(int fd); int get_client_fds(fd_set& fds) const; - void for_each_client_fd(const BAN::Function& callback); + void for_each_client_fd(const BAN::Function& callback); bool is_stopped() const { return m_is_stopped; } private: Framebuffer& m_framebuffer; BAN::Vector> m_client_windows; - BAN::Vector m_client_fds; + + BAN::HashMap m_client_data; const int32_t m_corner_radius; diff --git a/userspace/programs/WindowServer/main.cpp b/userspace/programs/WindowServer/main.cpp index 05503e6c..35dde3e4 100644 --- a/userspace/programs/WindowServer/main.cpp +++ b/userspace/programs/WindowServer/main.cpp @@ -114,7 +114,7 @@ int open_server_fd() if (stat(LibGUI::s_window_server_socket.data(), &st) != -1) unlink(LibGUI::s_window_server_socket.data()); - int server_fd = socket(AF_UNIX, SOCK_SEQPACKET | SOCK_CLOEXEC, 0); + int server_fd = socket(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0); if (server_fd == -1) { perror("socket"); @@ -178,12 +178,6 @@ int main() dprintln("Window server started"); - size_t window_packet_sizes[LibGUI::WindowPacketType::COUNT] {}; - window_packet_sizes[LibGUI::WindowPacketType::INVALID] = 0; - window_packet_sizes[LibGUI::WindowPacketType::CreateWindow] = sizeof(LibGUI::WindowCreatePacket); - window_packet_sizes[LibGUI::WindowPacketType::Invalidate] = sizeof(LibGUI::WindowInvalidatePacket); - static_assert(LibGUI::WindowPacketType::COUNT == 3); - auto config = parse_config(); WindowServer window_server(framebuffer, config.corner_radius); @@ -281,13 +275,49 @@ int main() } window_server.for_each_client_fd( - [&](int fd) -> BAN::Iteration + [&](int fd, WindowServer::ClientData& client_data) -> BAN::Iteration { if (!FD_ISSET(fd, &fds)) return BAN::Iteration::Continue; - LibGUI::WindowPacket packet; - ssize_t nrecv = recv(fd, &packet, sizeof(packet), 0); + if (client_data.packet_buffer.empty()) + { + uint32_t packet_size; + const ssize_t nrecv = recv(fd, &packet_size, sizeof(uint32_t), 0); + if (nrecv < 0) + dwarnln("recv: {}", strerror(errno)); + if (nrecv > 0 && nrecv != sizeof(uint32_t)) + dwarnln("could not read packet size with a single recv call, closing connection..."); + if (nrecv != sizeof(uint32_t)) + { + window_server.remove_client_fd(fd); + return BAN::Iteration::Continue; + } + + if (packet_size < 4) + { + dwarnln("client sent invalid packet, closing connection..."); + return BAN::Iteration::Continue; + } + + // this is a bit harsh, but i don't want to work on skipping streaming packets + if (client_data.packet_buffer.resize(packet_size).is_error()) + { + dwarnln("could not allocate memory for client packet, closing connection..."); + window_server.remove_client_fd(fd); + return BAN::Iteration::Continue; + } + + client_data.packet_buffer_nread = 0; + return BAN::Iteration::Continue; + } + + const ssize_t nrecv = recv( + fd, + client_data.packet_buffer.data() + client_data.packet_buffer_nread, + client_data.packet_buffer.size() - client_data.packet_buffer_nread, + 0 + ); if (nrecv < 0) dwarnln("recv: {}", strerror(errno)); if (nrecv <= 0) @@ -296,12 +326,30 @@ int main() return BAN::Iteration::Continue; } - if (packet.type == LibGUI::WindowPacketType::INVALID || packet.type >= LibGUI::WindowPacketType::COUNT) - dwarnln("Invalid WindowPacket (type {})", (int)packet.type); - if (static_cast(nrecv) != window_packet_sizes[packet.type]) - dwarnln("Invalid WindowPacket size (type {}, size {})", (int)packet.type, nrecv); - else - window_server.on_window_packet(fd, packet); + client_data.packet_buffer_nread += nrecv; + if (client_data.packet_buffer_nread < client_data.packet_buffer.size()) + return BAN::Iteration::Continue; + + ASSERT(client_data.packet_buffer.size() >= sizeof(uint32_t)); + + switch (*reinterpret_cast(client_data.packet_buffer.data())) + { + case LibGUI::PacketType::WindowCreate: + { + if (auto ret = LibGUI::WindowPacket::WindowCreate::deserialize(client_data.packet_buffer.span()); !ret.is_error()) + window_server.on_window_create(fd, ret.release_value()); + break; + } + case LibGUI::PacketType::WindowInvalidate: + if (auto ret = LibGUI::WindowPacket::WindowInvalidate::deserialize(client_data.packet_buffer.span()); !ret.is_error()) + window_server.on_window_invalidate(fd, ret.release_value()); + break; + default: + dprintln("unhandled packet type: {}", *reinterpret_cast(client_data.packet_buffer.data())); + } + + client_data.packet_buffer.clear(); + client_data.packet_buffer_nread = 0; return BAN::Iteration::Continue; } ); diff --git a/userspace/tests/test-window/main.cpp b/userspace/tests/test-window/main.cpp index ab62f722..98ba3fc6 100644 --- a/userspace/tests/test-window/main.cpp +++ b/userspace/tests/test-window/main.cpp @@ -32,19 +32,19 @@ int main() auto window = window_or_error.release_value(); window->set_close_window_event_callback([&] { running = false; }); window->set_mouse_button_event_callback( - [&](LibGUI::EventPacket::MouseButtonEvent event) + [&](LibGUI::EventPacket::MouseButtonEvent::event_t event) { - if (event.pressed && event.button == LibGUI::EventPacket::MouseButton::Left) + if (event.pressed && event.button == LibInput::MouseButton::Left) randomize_color(window); const char* button; switch (event.button) { - case LibGUI::EventPacket::MouseButton::Left: button = "left"; break; - case LibGUI::EventPacket::MouseButton::Right: button = "right"; break; - case LibGUI::EventPacket::MouseButton::Middle: button = "middle"; break; - case LibGUI::EventPacket::MouseButton::Extra1: button = "extra1"; break; - case LibGUI::EventPacket::MouseButton::Extra2: button = "extra2"; break; + case LibInput::MouseButton::Left: button = "left"; break; + case LibInput::MouseButton::Right: button = "right"; break; + case LibInput::MouseButton::Middle: button = "middle"; break; + case LibInput::MouseButton::Extra1: button = "extra1"; break; + case LibInput::MouseButton::Extra2: button = "extra2"; break; } dprintln("mouse button '{}' {} at {}, {}", button, event.pressed ? "pressed" : "released", event.x, event.y); }