From c54d9b3f6089dbe3d24d2c321c98f60dd17c839e Mon Sep 17 00:00:00 2001 From: Bananymous Date: Sun, 6 Oct 2024 18:18:56 +0300 Subject: [PATCH] Shell: Implement simple tab completion for commands and files --- userspace/programs/Shell/main.cpp | 282 +++++++++++++++++++++++++++++- 1 file changed, 281 insertions(+), 1 deletion(-) diff --git a/userspace/programs/Shell/main.cpp b/userspace/programs/Shell/main.cpp index 4a6ba00990..88bb5c692f 100644 --- a/userspace/programs/Shell/main.cpp +++ b/userspace/programs/Shell/main.cpp @@ -1,9 +1,11 @@ #include #include +#include #include #include #include +#include #include #include #include @@ -834,6 +836,169 @@ static int source_shellrc() return 0; } +static BAN::Vector list_matching_entries(BAN::StringView path, BAN::StringView start, bool require_executable) +{ + ASSERT(path.size() < PATH_MAX); + + char path_cstr[PATH_MAX]; + memcpy(path_cstr, path.data(), path.size()); + path_cstr[path.size()] = '\0'; + + DIR* dirp = opendir(path_cstr); + if (dirp == nullptr) + return {}; + + BAN::Vector result; + + dirent* entry; + while ((entry = readdir(dirp))) + { + if (entry->d_name[0] == '.' && !start.starts_with("."_sv)) + continue; + if (strncmp(entry->d_name, start.data(), start.size())) + continue; + + struct stat st; + if (fstatat(dirfd(dirp), entry->d_name, &st, 0)) + continue; + + if (require_executable) + { + if (S_ISDIR(st.st_mode)) + continue; + if (!(st.st_mode & (S_IXUSR | S_IXGRP | S_IXUSR))) + continue; + } + + MUST(result.emplace_back(entry->d_name + start.size())); + if (S_ISDIR(st.st_mode)) + MUST(result.back().push_back('/')); + } + + closedir(dirp); + + return BAN::move(result); +} + +struct TabCompletion +{ + bool should_escape_spaces; + BAN::StringView prefix; + BAN::Vector completions; +}; + +static TabCompletion list_tab_completion_entries(BAN::StringView command) +{ + enum class CompletionType + { + Command, + File, + }; + + BAN::StringView prefix = command; + BAN::String last_argument; + CompletionType completion_type = CompletionType::Command; + + bool should_escape_spaces = true; + for (size_t i = 0; i < command.size(); i++) + { + if (command[i] == '\\') + { + i++; + if (i < command.size()) + MUST(last_argument.push_back(command[i])); + } + else if (isspace(command[i]) || command[i] == ';' || command[i] == '|' || command.substring(i).starts_with("&&"_sv)) + { + if (!isspace(command[i])) + completion_type = CompletionType::Command; + else if (!last_argument.empty()) + completion_type = CompletionType::File; + if (auto rest = command.substring(i); rest.starts_with("||"_sv) || rest.starts_with("&&"_sv)) + i++; + prefix = command.substring(i + 1); + last_argument.clear(); + should_escape_spaces = true; + } + else if (command[i] == '\'' || command[i] == '"') + { + const char quote_type = command[i++]; + while (i < command.size() && command[i] != quote_type) + MUST(last_argument.push_back(command[i++])); + should_escape_spaces = false; + } + else + { + MUST(last_argument.push_back(command[i])); + } + } + + if (last_argument.sv().contains('/')) + completion_type = CompletionType::File; + + BAN::Vector result; + switch (completion_type) + { + case CompletionType::Command: + { + const char* path_env = getenv("PATH"); + if (path_env) + { + auto splitted_path_env = MUST(BAN::StringView(path_env).split(':')); + for (auto path : splitted_path_env) + { + auto matching_entries = list_matching_entries(path, last_argument, true); + MUST(result.reserve(result.size() + matching_entries.size())); + for (auto&& entry : matching_entries) + MUST(result.push_back(BAN::move(entry))); + } + } + + for (const auto& [builtin_name, _] : s_builtin_commands) + { + if (!builtin_name.sv().starts_with(last_argument)) + continue; + MUST(result.emplace_back(builtin_name.sv().substring(last_argument.size()))); + } + + // TODO: match aliases when added + + break; + } + case CompletionType::File: + { + BAN::String dir_path; + if (last_argument.sv().starts_with("/"_sv)) + MUST(dir_path.push_back('/')); + else + { + char cwd_buffer[PATH_MAX]; + if (getcwd(cwd_buffer, sizeof(cwd_buffer)) == nullptr) + return {}; + MUST(dir_path.reserve(strlen(cwd_buffer) + 1)); + MUST(dir_path.append(cwd_buffer)); + MUST(dir_path.push_back('/')); + } + + auto match_against = last_argument.sv(); + if (auto idx = match_against.rfind('/'); idx.has_value()) + { + MUST(dir_path.append(match_against.substring(0, idx.value()))); + match_against = match_against.substring(idx.value() + 1); + } + + result = list_matching_entries(dir_path, match_against, false); + + break; + } + } + + if (auto idx = prefix.rfind('/'); idx.has_value()) + prefix = prefix.substring(idx.value() + 1); + + return { should_escape_spaces, prefix, BAN::move(result) }; +} + static int character_length(BAN::StringView prompt) { int length { 0 }; @@ -1018,6 +1183,10 @@ int main(int argc, char** argv) size_t index = 0; size_t col = 0; + BAN::Optional tab_index; + BAN::Optional> tab_completions; + size_t tab_completion_keep = 0; + int waiting_utf8 = 0; print_prompt(); @@ -1043,6 +1212,11 @@ int main(int argc, char** argv) } uint8_t ch = chi; + if (ch != '\t') + { + tab_completions.clear(); + tab_index.clear(); + } if (waiting_utf8 > 0) { @@ -1168,8 +1342,114 @@ int main(int argc, char** argv) col = 0; break; case '\t': - // FIXME: Implement tab completion or something + { + if (col != buffers[index].size()) + continue; + + if (tab_completions.has_value()) + { + ASSERT(tab_completions->size() >= 2); + + if (!tab_index.has_value()) + tab_index = 0; + else + { + MUST(buffers[index].resize(tab_completion_keep)); + col = tab_completion_keep; + *tab_index = (*tab_index + 1) % tab_completions->size(); + } + + MUST(buffers[index].append(tab_completions.value()[*tab_index])); + col += tab_completions.value()[*tab_index].size(); + + printf("\e[%dG%s\e[K", prompt_length() + 1, buffers[index].data()); + fflush(stdout); + + break; + } + + tab_completion_keep = col; + auto [should_escape_spaces, prefix, completions] = list_tab_completion_entries(buffers[index].sv().substring(0, tab_completion_keep)); + + if (completions.empty()) + break; + + size_t all_match_len = 0; + for (;;) + { + if (completions.front().size() <= all_match_len) + break; + const char target = completions.front()[all_match_len]; + + bool all_matched = true; + for (const auto& completion : completions) + { + if (completion.size() > all_match_len && completion[all_match_len] == target) + continue; + all_matched = false; + break; + } + + if (!all_matched) + break; + all_match_len++; + } + + if (all_match_len) + { + col += all_match_len; + MUST(buffers[index].append(completions.front().sv().substring(0, all_match_len))); + printf("%.*s", (int)all_match_len, completions.front().data()); + fflush(stdout); + break; + } + + if (completions.size() == 1) + { + ASSERT(all_match_len == completions.front().size()); + break; + } + + BAN::sort::sort(completions.begin(), completions.end(), + [](const BAN::String& a, const BAN::String& b) { + if (auto cmp = strcmp(a.data(), b.data())) + return cmp < 0; + return a.size() < b.size(); + } + ); + + printf("\n"); + for (size_t i = 0; i < completions.size(); i++) + { + if (i != 0) + printf(" "); + const char* format = completions[i].sv().contains(' ') ? "'%.*s%s'" : "%.*s%s"; + printf(format, (int)prefix.size(), prefix.data(), completions[i].data()); + } + printf("\n"); + print_prompt(); + printf("%s", buffers[index].data()); + fflush(stdout); + + if (should_escape_spaces) + { + for (auto& completion : completions) + { + for (size_t i = 0; i < completion.size(); i++) + { + if (!isspace(completion[i])) + continue; + MUST(completion.insert('\\', i)); + i++; + } + } + } + + tab_completion_keep = col; + tab_completions = BAN::move(completions); + break; + } default: MUST(buffers[index].insert(ch, col++)); if (col == buffers[index].size())