Shell: rewrite the whole shell to use tokens instead of raw strings

tab completion is still running with raw strings and that has to be
fixed in the future.
This commit is contained in:
2024-10-13 21:56:59 +03:00
parent dab6e5a60f
commit 8adc97980a
18 changed files with 2721 additions and 1732 deletions

View File

@@ -0,0 +1,682 @@
#include "Alias.h"
#include "Builtin.h"
#include "Input.h"
#include <BAN/ScopeGuard.h>
#include <BAN/Sort.h>
#include <ctype.h>
#include <dirent.h>
#include <pwd.h>
#include <sys/stat.h>
#include <unistd.h>
static struct termios s_original_termios;
static struct termios s_raw_termios;
static bool s_termios_initialized { false };
static BAN::Vector<BAN::String> list_matching_entries(BAN::StringView path, BAN::StringView start, bool require_executable)
{
ASSERT(path.size() < PATH_MAX);
char path_cstr[PATH_MAX];
memcpy(path_cstr, path.data(), path.size());
path_cstr[path.size()] = '\0';
DIR* dirp = opendir(path_cstr);
if (dirp == nullptr)
return {};
BAN::Vector<BAN::String> result;
dirent* entry;
while ((entry = readdir(dirp)))
{
if (entry->d_name[0] == '.' && !start.starts_with("."_sv))
continue;
if (strncmp(entry->d_name, start.data(), start.size()))
continue;
struct stat st;
if (fstatat(dirfd(dirp), entry->d_name, &st, 0))
continue;
if (require_executable)
{
if (S_ISDIR(st.st_mode))
continue;
if (!(st.st_mode & (S_IXUSR | S_IXGRP | S_IXUSR)))
continue;
}
MUST(result.emplace_back(entry->d_name + start.size()));
if (S_ISDIR(st.st_mode))
MUST(result.back().push_back('/'));
}
closedir(dirp);
return BAN::move(result);
}
struct TabCompletion
{
bool should_escape_spaces { false };
BAN::StringView prefix;
BAN::Vector<BAN::String> completions;
};
static TabCompletion list_tab_completion_entries(BAN::StringView current_input)
{
enum class CompletionType
{
Command,
File,
};
BAN::StringView prefix = current_input;
BAN::String last_argument;
CompletionType completion_type = CompletionType::Command;
bool should_escape_spaces = true;
for (size_t i = 0; i < current_input.size(); i++)
{
if (current_input[i] == '\\')
{
i++;
if (i < current_input.size())
MUST(last_argument.push_back(current_input[i]));
}
else if (isspace(current_input[i]) || current_input[i] == ';' || current_input[i] == '|' || current_input.substring(i).starts_with("&&"_sv))
{
if (!isspace(current_input[i]))
completion_type = CompletionType::Command;
else if (!last_argument.empty())
completion_type = CompletionType::File;
if (auto rest = current_input.substring(i); rest.starts_with("||"_sv) || rest.starts_with("&&"_sv))
i++;
prefix = current_input.substring(i + 1);
last_argument.clear();
should_escape_spaces = true;
}
else if (current_input[i] == '\'' || current_input[i] == '"')
{
const char quote_type = current_input[i++];
while (i < current_input.size() && current_input[i] != quote_type)
MUST(last_argument.push_back(current_input[i++]));
should_escape_spaces = false;
}
else
{
MUST(last_argument.push_back(current_input[i]));
}
}
if (last_argument.sv().contains('/'))
completion_type = CompletionType::File;
BAN::Vector<BAN::String> result;
switch (completion_type)
{
case CompletionType::Command:
{
const char* path_env = getenv("PATH");
if (path_env)
{
auto splitted_path_env = MUST(BAN::StringView(path_env).split(':'));
for (auto path : splitted_path_env)
{
auto matching_entries = list_matching_entries(path, last_argument, true);
MUST(result.reserve(result.size() + matching_entries.size()));
for (auto&& entry : matching_entries)
MUST(result.push_back(BAN::move(entry)));
}
}
Builtin::get().for_each_builtin(
[&](BAN::StringView name, const Builtin::BuiltinCommand&) -> BAN::Iteration
{
if (name.starts_with(last_argument))
MUST(result.emplace_back(name.substring(last_argument.size())));
return BAN::Iteration::Continue;
}
);
Alias::get().for_each_alias(
[&](BAN::StringView name, BAN::StringView) -> BAN::Iteration
{
if (name.starts_with(last_argument))
MUST(result.emplace_back(name.substring(last_argument.size())));
return BAN::Iteration::Continue;
}
);
break;
}
case CompletionType::File:
{
BAN::String dir_path;
if (last_argument.sv().starts_with("/"_sv))
MUST(dir_path.push_back('/'));
else
{
char cwd_buffer[PATH_MAX];
if (getcwd(cwd_buffer, sizeof(cwd_buffer)) == nullptr)
return {};
MUST(dir_path.reserve(strlen(cwd_buffer) + 1));
MUST(dir_path.append(cwd_buffer));
MUST(dir_path.push_back('/'));
}
auto match_against = last_argument.sv();
if (auto idx = match_against.rfind('/'); idx.has_value())
{
MUST(dir_path.append(match_against.substring(0, idx.value())));
match_against = match_against.substring(idx.value() + 1);
}
result = list_matching_entries(dir_path, match_against, false);
break;
}
}
if (auto idx = prefix.rfind('/'); idx.has_value())
prefix = prefix.substring(idx.value() + 1);
return { should_escape_spaces, prefix, BAN::move(result) };
}
static int character_length(BAN::StringView prompt)
{
int length { 0 };
bool in_escape { false };
for (char c : prompt)
{
if (in_escape)
{
if (isalpha(c))
in_escape = false;
}
else
{
if (c == '\e')
in_escape = true;
else if (((uint8_t)c & 0xC0) != 0x80)
length++;
}
}
return length;
}
BAN::String Input::parse_ps1_prompt()
{
const char* raw_prompt = getenv("PS1");
if (raw_prompt == nullptr)
return "$ "_sv;
BAN::String prompt;
for (int i = 0; raw_prompt[i]; i++)
{
char ch = raw_prompt[i];
if (ch == '\\')
{
switch (raw_prompt[++i])
{
case 'e':
MUST(prompt.push_back('\e'));
break;
case 'n':
MUST(prompt.push_back('\n'));
break;
case '\\':
MUST(prompt.push_back('\\'));
break;
case '~':
{
char buffer[256];
if (getcwd(buffer, sizeof(buffer)) == nullptr)
strcpy(buffer, strerrorname_np(errno));
const char* home = getenv("HOME");
size_t home_len = home ? strlen(home) : 0;
if (home && strncmp(buffer, home, home_len) == 0)
{
MUST(prompt.push_back('~'));
MUST(prompt.append(buffer + home_len));
}
else
{
MUST(prompt.append(buffer));
}
break;
}
case 'u':
{
static char* username = nullptr;
if (username == nullptr)
{
auto* passwd = getpwuid(geteuid());
if (passwd == nullptr)
break;
username = new char[strlen(passwd->pw_name) + 1];
strcpy(username, passwd->pw_name);
endpwent();
}
MUST(prompt.append(username));
break;
}
case 'h':
{
MUST(prompt.append(m_hostname));
break;
}
case '\0':
MUST(prompt.push_back('\\'));
break;
default:
MUST(prompt.push_back('\\'));
MUST(prompt.push_back(*raw_prompt));
break;
}
}
else
{
MUST(prompt.push_back(ch));
}
}
return prompt;
}
BAN::Optional<BAN::String> Input::get_input(BAN::Optional<BAN::StringView> custom_prompt)
{
tcsetattr(0, TCSANOW, &s_raw_termios);
BAN::ScopeGuard _([] { tcsetattr(0, TCSANOW, &s_original_termios); });
BAN::String ps1_prompt;
if (!custom_prompt.has_value())
ps1_prompt = parse_ps1_prompt();
const auto print_prompt =
[&]()
{
if (custom_prompt.has_value())
printf("%.*s", (int)custom_prompt->size(), custom_prompt->data());
else
printf("%.*s", (int)ps1_prompt.size(), ps1_prompt.data());
};
const auto prompt_length =
[&]() -> int
{
if (custom_prompt.has_value())
return custom_prompt->size();
return character_length(ps1_prompt);
};
print_prompt();
fflush(stdout);
while (true)
{
int chi = getchar();
if (chi == EOF)
{
if (errno != EINTR)
{
perror("getchar");
exit(1);
}
clearerr(stdin);
m_buffers = m_history;
MUST(m_buffers.emplace_back(""_sv));
m_buffer_index = m_buffers.size() - 1;
m_buffer_col = 0;
putchar('\n');
print_prompt();
fflush(stdout);
continue;
}
uint8_t ch = chi;
if (ch != '\t')
{
m_tab_completions.clear();
m_tab_index.clear();
}
if (m_waiting_utf8 > 0)
{
m_waiting_utf8--;
ASSERT((ch & 0xC0) == 0x80);
putchar(ch);
MUST(m_buffers[m_buffer_index].insert(ch, m_buffer_col++));
if (m_waiting_utf8 == 0)
{
printf("\e[s%s\e[u", m_buffers[m_buffer_index].data() + m_buffer_col);
fflush(stdout);
}
continue;
}
else if (ch & 0x80)
{
if ((ch & 0xE0) == 0xC0)
m_waiting_utf8 = 1;
else if ((ch & 0xF0) == 0xE0)
m_waiting_utf8 = 2;
else if ((ch & 0xF8) == 0xF0)
m_waiting_utf8 = 3;
else
ASSERT_NOT_REACHED();
putchar(ch);
MUST(m_buffers[m_buffer_index].insert(ch, m_buffer_col++));
continue;
}
switch (ch)
{
case '\e':
{
ch = getchar();
if (ch != '[')
break;
ch = getchar();
int value = 0;
while (isdigit(ch))
{
value = (value * 10) + (ch - '0');
ch = getchar();
}
switch (ch)
{
case 'A':
if (m_buffer_index > 0)
{
m_buffer_index--;
m_buffer_col = m_buffers[m_buffer_index].size();
printf("\e[%dG%s\e[K", prompt_length() + 1, m_buffers[m_buffer_index].data());
fflush(stdout);
}
break;
case 'B':
if (m_buffer_index < m_buffers.size() - 1)
{
m_buffer_index++;
m_buffer_col = m_buffers[m_buffer_index].size();
printf("\e[%dG%s\e[K", prompt_length() + 1, m_buffers[m_buffer_index].data());
fflush(stdout);
}
break;
case 'C':
if (m_buffer_col < m_buffers[m_buffer_index].size())
{
m_buffer_col++;
while ((m_buffers[m_buffer_index][m_buffer_col - 1] & 0xC0) == 0x80)
m_buffer_col++;
printf("\e[C");
fflush(stdout);
}
break;
case 'D':
if (m_buffer_col > 0)
{
while ((m_buffers[m_buffer_index][m_buffer_col - 1] & 0xC0) == 0x80)
m_buffer_col--;
m_buffer_col--;
printf("\e[D");
fflush(stdout);
}
break;
case '~':
switch (value)
{
case 3: // delete
if (m_buffer_col >= m_buffers[m_buffer_index].size())
break;
m_buffers[m_buffer_index].remove(m_buffer_col);
while (m_buffer_col < m_buffers[m_buffer_index].size() && (m_buffers[m_buffer_index][m_buffer_col] & 0xC0) == 0x80)
m_buffers[m_buffer_index].remove(m_buffer_col);
printf("\e[s%s \e[u", m_buffers[m_buffer_index].data() + m_buffer_col);
fflush(stdout);
break;
}
break;
}
break;
}
case '\x0C': // ^L
{
int x = prompt_length() + character_length(m_buffers[m_buffer_index].sv().substring(m_buffer_col)) + 1;
printf("\e[H\e[J");
print_prompt();
printf("%s\e[u\e[1;%dH", m_buffers[m_buffer_index].data(), x);
fflush(stdout);
break;
}
case '\b':
if (m_buffer_col <= 0)
break;
while ((m_buffers[m_buffer_index][m_buffer_col - 1] & 0xC0) == 0x80)
m_buffer_col--;
m_buffer_col--;
printf("\e[D");
fflush(stdout);
break;
case '\x01': // ^A
m_buffer_col = 0;
printf("\e[%dG", prompt_length() + 1);
fflush(stdout);
break;
case '\x03': // ^C
putchar('\n');
print_prompt();
fflush(stdout);
m_buffers[m_buffer_index].clear();
m_buffer_col = 0;
break;
case '\x04': // ^D
if (!m_buffers[m_buffer_index].empty())
break;
putchar('\n');
return {};
case '\x7F': // backspace
if (m_buffer_col <= 0)
break;
while ((m_buffers[m_buffer_index][m_buffer_col - 1] & 0xC0) == 0x80)
m_buffers[m_buffer_index].remove(--m_buffer_col);
m_buffers[m_buffer_index].remove(--m_buffer_col);
printf("\b\e[s%s \e[u", m_buffers[m_buffer_index].data() + m_buffer_col);
fflush(stdout);
break;
case '\n':
{
BAN::String input;
MUST(input.append(m_buffers[m_buffer_index]));
if (!m_buffers[m_buffer_index].empty())
{
MUST(m_history.push_back(m_buffers[m_buffer_index]));
m_buffers = m_history;
MUST(m_buffers.emplace_back(""_sv));
}
m_buffer_index = m_buffers.size() - 1;
m_buffer_col = 0;
putchar('\n');
return input;
}
case '\t':
{
// FIXME: tab completion is really hacked together currently.
// this should ask token parser about the current parse state
// and do completions based on that, not raw strings
if (m_buffer_col != m_buffers[m_buffer_index].size())
continue;
if (m_tab_completions.has_value())
{
ASSERT(m_tab_completions->size() >= 2);
if (!m_tab_index.has_value())
m_tab_index = 0;
else
{
MUST(m_buffers[m_buffer_index].resize(m_tab_completion_keep));
m_buffer_col = m_tab_completion_keep;
*m_tab_index = (*m_tab_index + 1) % m_tab_completions->size();
}
MUST(m_buffers[m_buffer_index].append(m_tab_completions.value()[*m_tab_index]));
m_buffer_col += m_tab_completions.value()[*m_tab_index].size();
printf("\e[%dG%s\e[K", prompt_length() + 1, m_buffers[m_buffer_index].data());
fflush(stdout);
break;
}
m_tab_completion_keep = m_buffer_col;
auto [should_escape_spaces, prefix, completions] = list_tab_completion_entries(m_buffers[m_buffer_index].sv().substring(0, m_tab_completion_keep));
BAN::sort::sort(completions.begin(), completions.end(),
[](const BAN::String& a, const BAN::String& b) {
if (auto cmp = strcmp(a.data(), b.data()))
return cmp < 0;
return a.size() < b.size();
}
);
for (size_t i = 1; i < completions.size();)
{
if (completions[i - 1] == completions[i])
completions.remove(i);
else
i++;
}
if (completions.empty())
break;
size_t all_match_len = 0;
for (;;)
{
if (completions.front().size() <= all_match_len)
break;
const char target = completions.front()[all_match_len];
bool all_matched = true;
for (const auto& completion : completions)
{
if (completion.size() > all_match_len && completion[all_match_len] == target)
continue;
all_matched = false;
break;
}
if (!all_matched)
break;
all_match_len++;
}
if (all_match_len)
{
auto completion = completions.front().sv().substring(0, all_match_len);
BAN::String temp_escaped;
if (should_escape_spaces)
{
MUST(temp_escaped.append(completion));
for (size_t i = 0; i < temp_escaped.size(); i++)
{
if (!isspace(temp_escaped[i]))
continue;
MUST(temp_escaped.insert('\\', i));
i++;
}
completion = temp_escaped.sv();
if (!m_buffers[m_buffer_index].empty() && m_buffers[m_buffer_index].back() == '\\' && completion.front() == '\\')
completion = completion.substring(1);
}
m_buffer_col += completion.size();
MUST(m_buffers[m_buffer_index].append(completion));
printf("%.*s", (int)completion.size(), completion.data());
fflush(stdout);
break;
}
if (completions.size() == 1)
{
ASSERT(all_match_len == completions.front().size());
break;
}
printf("\n");
for (size_t i = 0; i < completions.size(); i++)
{
if (i != 0)
printf(" ");
const char* format = completions[i].sv().contains(' ') ? "'%.*s%s'" : "%.*s%s";
printf(format, (int)prefix.size(), prefix.data(), completions[i].data());
}
printf("\n");
print_prompt();
printf("%s", m_buffers[m_buffer_index].data());
fflush(stdout);
if (should_escape_spaces)
{
for (auto& completion : completions)
{
for (size_t i = 0; i < completion.size(); i++)
{
if (!isspace(completion[i]))
continue;
MUST(completion.insert('\\', i));
i++;
}
}
}
m_tab_completion_keep = m_buffer_col;
m_tab_completions = BAN::move(completions);
break;
}
default:
MUST(m_buffers[m_buffer_index].insert(ch, m_buffer_col++));
if (m_buffer_col == m_buffers[m_buffer_index].size())
putchar(ch);
else
printf("%c\e[s%s\e[u", ch, m_buffers[m_buffer_index].data() + m_buffer_col);
fflush(stdout);
break;
}
}
}
Input::Input()
{
if (!s_termios_initialized)
{
tcgetattr(0, &s_original_termios);
s_raw_termios = s_original_termios;
s_raw_termios.c_lflag &= ~(ECHO | ICANON);
atexit([] { tcsetattr(0, TCSANOW, &s_original_termios); });
s_termios_initialized = true;
}
char hostname_buffer[HOST_NAME_MAX];
if (gethostname(hostname_buffer, sizeof(hostname_buffer)) == 0) {
MUST(m_hostname.append(hostname_buffer));
}
}