diff --git a/userspace/programs/Shell/main.cpp b/userspace/programs/Shell/main.cpp index 034c1ba4..2e11d854 100644 --- a/userspace/programs/Shell/main.cpp +++ b/userspace/programs/Shell/main.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -1243,6 +1244,101 @@ static void print_prompt() fflush(stdout); } +static bool detect_cursor_position_support() +{ + constexpr auto getchar_nonblock = + []() -> char + { + fd_set fds; + FD_ZERO(&fds); + FD_SET(STDIN_FILENO, &fds); + + timeval timeout; + timeout.tv_sec = 0; + timeout.tv_usec = 100'000; + + int nselect = select(STDIN_FILENO + 1, &fds, nullptr, nullptr, &timeout); + if (nselect != 1) + return '\0'; + + char ch; + if (read(STDIN_FILENO, &ch, 1) != 1) + return '\0'; + return ch; + }; + + if (write(STDOUT_FILENO, "\e[6n", 4) != 4) + return false; + + char ch = getchar_nonblock(); + if (ch != '\e') + { + if (ch != '\0') + ungetc(ch, stdin); + return false; + } + if (getchar_nonblock() != '[') + return false; + + int cur; + while (isdigit(cur = getchar_nonblock())) + ; + if (cur != ';') + return false; + while (isdigit(cur = getchar_nonblock())) + ; + if (cur != 'R') + return false; + + return true; +} + +struct CursorPosition +{ + int x; + int y; +}; + +static BAN::Optional try_read_cursor_position() +{ +#if __banan_os__ + return {}; +#endif + + static BAN::Optional s_supports_cursor_position; + if (!s_supports_cursor_position.has_value()) + s_supports_cursor_position = detect_cursor_position_support(); + + if (!s_supports_cursor_position.value()) + return {}; + + if (write(STDOUT_FILENO, "\e[6n", 4) != 4) + return {}; + + char ch = getchar(); + if (ch != '\e') + { + ungetc(ch, stdin); + return {}; + } + if (getchar() != '[') + return {}; + + int cur, x = 0, y = 0; + while (isdigit(cur = getchar())) + y = (y * 10) + (cur - '0'); + if (cur != ';') + return {}; + while (isdigit(cur = getchar())) + x = (x * 10) + (cur - '0'); + if (cur != 'R') + return {}; + + if (x > 0) x--; + if (y > 0) y--; + return CursorPosition { x, y }; +} + int main(int argc, char** argv) { realpath(argv[0], s_shell_path); @@ -1466,6 +1562,10 @@ int main(int argc, char** argv) MUST(history.push_back(buffers[index])); buffers = history; MUST(buffers.emplace_back(""_sv)); + + auto cursor_pos = try_read_cursor_position(); + if (cursor_pos.has_value() && cursor_pos->x > 0) + printf("\e[7m%%\e[m\n"); } print_prompt(); index = buffers.size() - 1;