banan-os/userspace/programs/Shell/Input.cpp

742 lines
17 KiB
C++
Raw Permalink Normal View History

#include "Alias.h"
#include "Builtin.h"
#include "Input.h"
#include <BAN/ScopeGuard.h>
#include <BAN/Sort.h>
#include <ctype.h>
#include <dirent.h>
#include <pwd.h>
#include <sys/stat.h>
#include <unistd.h>
static struct termios s_original_termios;
static struct termios s_raw_termios;
static bool s_termios_initialized { false };
static BAN::String s_history_path;
static BAN::Vector<BAN::String> s_history;
static constexpr size_t s_history_max_size = 1000;
static BAN::Vector<BAN::String> list_matching_entries(BAN::StringView path, BAN::StringView start, bool require_executable)
{
ASSERT(path.size() < PATH_MAX);
char path_cstr[PATH_MAX];
memcpy(path_cstr, path.data(), path.size());
path_cstr[path.size()] = '\0';
DIR* dirp = opendir(path_cstr);
if (dirp == nullptr)
return {};
BAN::Vector<BAN::String> result;
dirent* entry;
while ((entry = readdir(dirp)))
{
if (entry->d_name[0] == '.' && !start.starts_with("."_sv))
continue;
if (strncmp(entry->d_name, start.data(), start.size()))
continue;
struct stat st;
if (fstatat(dirfd(dirp), entry->d_name, &st, 0))
continue;
if (require_executable)
{
if (S_ISDIR(st.st_mode))
continue;
if (!(st.st_mode & (S_IXUSR | S_IXGRP | S_IXUSR)))
continue;
}
MUST(result.emplace_back(entry->d_name + start.size()));
if (S_ISDIR(st.st_mode))
MUST(result.back().push_back('/'));
}
closedir(dirp);
return BAN::move(result);
}
struct TabCompletion
{
bool should_escape_spaces { false };
BAN::StringView prefix;
BAN::Vector<BAN::String> completions;
};
static TabCompletion list_tab_completion_entries(BAN::StringView current_input)
{
enum class CompletionType
{
Command,
File,
};
BAN::StringView prefix = current_input;
BAN::String last_argument;
CompletionType completion_type = CompletionType::Command;
bool should_escape_spaces = true;
for (size_t i = 0; i < current_input.size(); i++)
{
if (current_input[i] == '\\')
{
i++;
if (i < current_input.size())
MUST(last_argument.push_back(current_input[i]));
}
else if (isspace(current_input[i]) || current_input[i] == ';' || current_input[i] == '|' || current_input.substring(i).starts_with("&&"_sv))
{
if (!isspace(current_input[i]))
completion_type = CompletionType::Command;
else if (!last_argument.empty())
completion_type = CompletionType::File;
if (auto rest = current_input.substring(i); rest.starts_with("||"_sv) || rest.starts_with("&&"_sv))
i++;
prefix = current_input.substring(i + 1);
last_argument.clear();
should_escape_spaces = true;
}
else if (current_input[i] == '\'' || current_input[i] == '"')
{
const char quote_type = current_input[i++];
while (i < current_input.size() && current_input[i] != quote_type)
MUST(last_argument.push_back(current_input[i++]));
should_escape_spaces = false;
}
else
{
MUST(last_argument.push_back(current_input[i]));
}
}
if (last_argument.sv().contains('/'))
completion_type = CompletionType::File;
BAN::Vector<BAN::String> result;
switch (completion_type)
{
case CompletionType::Command:
{
const char* path_env = getenv("PATH");
if (path_env)
{
auto splitted_path_env = MUST(BAN::StringView(path_env).split(':'));
for (auto path : splitted_path_env)
{
auto matching_entries = list_matching_entries(path, last_argument, true);
MUST(result.reserve(result.size() + matching_entries.size()));
for (auto&& entry : matching_entries)
MUST(result.push_back(BAN::move(entry)));
}
}
Builtin::get().for_each_builtin(
[&](BAN::StringView name, const Builtin::BuiltinCommand&) -> BAN::Iteration
{
if (name.starts_with(last_argument))
MUST(result.emplace_back(name.substring(last_argument.size())));
return BAN::Iteration::Continue;
}
);
Alias::get().for_each_alias(
[&](BAN::StringView name, BAN::StringView) -> BAN::Iteration
{
if (name.starts_with(last_argument))
MUST(result.emplace_back(name.substring(last_argument.size())));
return BAN::Iteration::Continue;
}
);
break;
}
case CompletionType::File:
{
BAN::String dir_path;
if (last_argument.sv().starts_with("/"_sv))
MUST(dir_path.push_back('/'));
else
{
char cwd_buffer[PATH_MAX];
if (getcwd(cwd_buffer, sizeof(cwd_buffer)) == nullptr)
return {};
MUST(dir_path.reserve(strlen(cwd_buffer) + 1));
MUST(dir_path.append(cwd_buffer));
MUST(dir_path.push_back('/'));
}
auto match_against = last_argument.sv();
if (auto idx = match_against.rfind('/'); idx.has_value())
{
MUST(dir_path.append(match_against.substring(0, idx.value())));
match_against = match_against.substring(idx.value() + 1);
}
result = list_matching_entries(dir_path, match_against, false);
break;
}
}
if (auto idx = prefix.rfind('/'); idx.has_value())
prefix = prefix.substring(idx.value() + 1);
return { should_escape_spaces, prefix, BAN::move(result) };
}
static int character_length(BAN::StringView prompt)
{
int length { 0 };
bool in_escape { false };
for (char c : prompt)
{
if (in_escape)
{
if (isalpha(c))
in_escape = false;
}
else
{
if (c == '\e')
in_escape = true;
else if (((uint8_t)c & 0xC0) != 0x80)
length++;
}
}
return length;
}
BAN::String Input::parse_ps1_prompt()
{
const char* raw_prompt = getenv("PS1");
if (raw_prompt == nullptr)
return "$ "_sv;
BAN::String prompt;
for (int i = 0; raw_prompt[i]; i++)
{
char ch = raw_prompt[i];
if (ch == '\\')
{
switch (raw_prompt[++i])
{
case 'e':
MUST(prompt.push_back('\e'));
break;
case 'n':
MUST(prompt.push_back('\n'));
break;
case '\\':
MUST(prompt.push_back('\\'));
break;
case '~':
{
char buffer[256];
if (getcwd(buffer, sizeof(buffer)) == nullptr)
strcpy(buffer, strerrorname_np(errno));
const char* home = getenv("HOME");
size_t home_len = home ? strlen(home) : 0;
if (home && strncmp(buffer, home, home_len) == 0)
{
MUST(prompt.push_back('~'));
MUST(prompt.append(buffer + home_len));
}
else
{
MUST(prompt.append(buffer));
}
break;
}
case 'u':
{
static char* username = nullptr;
if (username == nullptr)
{
auto* passwd = getpwuid(geteuid());
if (passwd == nullptr)
break;
username = new char[strlen(passwd->pw_name) + 1];
strcpy(username, passwd->pw_name);
endpwent();
}
MUST(prompt.append(username));
break;
}
case 'h':
{
MUST(prompt.append(m_hostname));
break;
}
case '\0':
MUST(prompt.push_back('\\'));
break;
default:
MUST(prompt.push_back('\\'));
MUST(prompt.push_back(*raw_prompt));
break;
}
}
else
{
MUST(prompt.push_back(ch));
}
}
return prompt;
}
BAN::Optional<BAN::String> Input::get_input(BAN::Optional<BAN::StringView> custom_prompt)
{
tcsetattr(0, TCSANOW, &s_raw_termios);
BAN::ScopeGuard _([] { tcsetattr(0, TCSANOW, &s_original_termios); });
BAN::String ps1_prompt;
if (!custom_prompt.has_value())
ps1_prompt = parse_ps1_prompt();
const auto print_prompt =
[&]()
{
if (custom_prompt.has_value())
printf("%.*s", (int)custom_prompt->size(), custom_prompt->data());
else
printf("%.*s", (int)ps1_prompt.size(), ps1_prompt.data());
};
const auto prompt_length =
[&]() -> int
{
if (custom_prompt.has_value())
return custom_prompt->size();
return character_length(ps1_prompt);
};
print_prompt();
fflush(stdout);
while (true)
{
int chi = getchar();
if (chi == EOF)
{
if (errno != EINTR)
{
perror("getchar");
exit(1);
}
clearerr(stdin);
m_buffers = s_history;
MUST(m_buffers.emplace_back(""_sv));
m_buffer_index = m_buffers.size() - 1;
m_buffer_col = 0;
putchar('\n');
print_prompt();
fflush(stdout);
continue;
}
uint8_t ch = chi;
if (ch != '\t')
{
m_tab_completions.clear();
m_tab_index.clear();
}
if (m_waiting_utf8 > 0)
{
m_waiting_utf8--;
ASSERT((ch & 0xC0) == 0x80);
putchar(ch);
MUST(m_buffers[m_buffer_index].insert(ch, m_buffer_col++));
if (m_waiting_utf8 == 0)
{
printf("\e[s%s\e[u", m_buffers[m_buffer_index].data() + m_buffer_col);
fflush(stdout);
}
continue;
}
else if (ch & 0x80)
{
if ((ch & 0xE0) == 0xC0)
m_waiting_utf8 = 1;
else if ((ch & 0xF0) == 0xE0)
m_waiting_utf8 = 2;
else if ((ch & 0xF8) == 0xF0)
m_waiting_utf8 = 3;
else
ASSERT_NOT_REACHED();
putchar(ch);
MUST(m_buffers[m_buffer_index].insert(ch, m_buffer_col++));
continue;
}
switch (ch)
{
case '\e':
{
ch = getchar();
if (ch != '[')
break;
ch = getchar();
int value = 0;
while (isdigit(ch))
{
value = (value * 10) + (ch - '0');
ch = getchar();
}
switch (ch)
{
case 'A':
if (m_buffer_index > 0)
{
m_buffer_index--;
m_buffer_col = m_buffers[m_buffer_index].size();
printf("\e[%dG%s\e[K", prompt_length() + 1, m_buffers[m_buffer_index].data());
fflush(stdout);
}
break;
case 'B':
if (m_buffer_index < m_buffers.size() - 1)
{
m_buffer_index++;
m_buffer_col = m_buffers[m_buffer_index].size();
printf("\e[%dG%s\e[K", prompt_length() + 1, m_buffers[m_buffer_index].data());
fflush(stdout);
}
break;
case 'C':
if (m_buffer_col < m_buffers[m_buffer_index].size())
{
m_buffer_col++;
while ((m_buffers[m_buffer_index][m_buffer_col - 1] & 0xC0) == 0x80)
m_buffer_col++;
printf("\e[C");
fflush(stdout);
}
break;
case 'D':
if (m_buffer_col > 0)
{
while ((m_buffers[m_buffer_index][m_buffer_col - 1] & 0xC0) == 0x80)
m_buffer_col--;
m_buffer_col--;
printf("\e[D");
fflush(stdout);
}
break;
case '~':
switch (value)
{
case 3: // delete
if (m_buffer_col >= m_buffers[m_buffer_index].size())
break;
m_buffers[m_buffer_index].remove(m_buffer_col);
while (m_buffer_col < m_buffers[m_buffer_index].size() && (m_buffers[m_buffer_index][m_buffer_col] & 0xC0) == 0x80)
m_buffers[m_buffer_index].remove(m_buffer_col);
printf("\e[s%s \e[u", m_buffers[m_buffer_index].data() + m_buffer_col);
fflush(stdout);
break;
}
break;
}
break;
}
case '\x0C': // ^L
{
int x = prompt_length() + character_length(m_buffers[m_buffer_index].sv().substring(m_buffer_col)) + 1;
printf("\e[H\e[J");
print_prompt();
printf("%s\e[u\e[1;%dH", m_buffers[m_buffer_index].data(), x);
fflush(stdout);
break;
}
case '\b':
if (m_buffer_col <= 0)
break;
while ((m_buffers[m_buffer_index][m_buffer_col - 1] & 0xC0) == 0x80)
m_buffers[m_buffer_index].remove(--m_buffer_col);
m_buffers[m_buffer_index].remove(--m_buffer_col);
printf("\b\e[s%s \e[u", m_buffers[m_buffer_index].data() + m_buffer_col);
fflush(stdout);
break;
case '\x01': // ^A
m_buffer_col = 0;
printf("\e[%dG", prompt_length() + 1);
fflush(stdout);
break;
case '\x03': // ^C
putchar('\n');
print_prompt();
fflush(stdout);
m_buffers[m_buffer_index].clear();
m_buffer_col = 0;
break;
case '\x04': // ^D
if (!m_buffers[m_buffer_index].empty())
break;
putchar('\n');
return {};
case '\x7F':
if (m_buffer_col >= m_buffers[m_buffer_index].size())
break;
m_buffers[m_buffer_index].remove(m_buffer_col);
while (m_buffer_col < m_buffers[m_buffer_index].size() && (m_buffers[m_buffer_index][m_buffer_col] & 0xC0) == 0x80)
m_buffers[m_buffer_index].remove(m_buffer_col);
printf("\e[s%s \e[u", m_buffers[m_buffer_index].data() + m_buffer_col);
fflush(stdout);
break;
case '\n':
{
BAN::String input;
MUST(input.append(m_buffers[m_buffer_index]));
if (!m_buffers[m_buffer_index].empty())
{
MUST(s_history.push_back(m_buffers[m_buffer_index]));
if (s_history.size() > s_history_max_size)
s_history.remove(0);
m_buffers = s_history;
MUST(m_buffers.emplace_back(""_sv));
}
m_buffer_index = m_buffers.size() - 1;
m_buffer_col = 0;
putchar('\n');
return input;
}
case '\t':
{
// FIXME: tab completion is really hacked together currently.
// this should ask token parser about the current parse state
// and do completions based on that, not raw strings
if (m_buffer_col != m_buffers[m_buffer_index].size())
continue;
if (m_tab_completions.has_value())
{
ASSERT(m_tab_completions->size() >= 2);
if (!m_tab_index.has_value())
m_tab_index = 0;
else
{
MUST(m_buffers[m_buffer_index].resize(m_tab_completion_keep));
m_buffer_col = m_tab_completion_keep;
*m_tab_index = (*m_tab_index + 1) % m_tab_completions->size();
}
MUST(m_buffers[m_buffer_index].append(m_tab_completions.value()[*m_tab_index]));
m_buffer_col += m_tab_completions.value()[*m_tab_index].size();
printf("\e[%dG%s\e[K", prompt_length() + 1, m_buffers[m_buffer_index].data());
fflush(stdout);
break;
}
m_tab_completion_keep = m_buffer_col;
auto [should_escape_spaces, prefix, completions] = list_tab_completion_entries(m_buffers[m_buffer_index].sv().substring(0, m_tab_completion_keep));
BAN::sort::sort(completions.begin(), completions.end(),
[](const BAN::String& a, const BAN::String& b) {
if (auto cmp = strcmp(a.data(), b.data()))
return cmp < 0;
return a.size() < b.size();
}
);
for (size_t i = 1; i < completions.size();)
{
if (completions[i - 1] == completions[i])
completions.remove(i);
else
i++;
}
if (completions.empty())
break;
size_t all_match_len = 0;
for (;;)
{
if (completions.front().size() <= all_match_len)
break;
const char target = completions.front()[all_match_len];
bool all_matched = true;
for (const auto& completion : completions)
{
if (completion.size() > all_match_len && completion[all_match_len] == target)
continue;
all_matched = false;
break;
}
if (!all_matched)
break;
all_match_len++;
}
if (all_match_len)
{
auto completion = completions.front().sv().substring(0, all_match_len);
BAN::String temp_escaped;
if (should_escape_spaces)
{
MUST(temp_escaped.append(completion));
for (size_t i = 0; i < temp_escaped.size(); i++)
{
if (!isspace(temp_escaped[i]))
continue;
MUST(temp_escaped.insert('\\', i));
i++;
}
completion = temp_escaped.sv();
if (!m_buffers[m_buffer_index].empty() && m_buffers[m_buffer_index].back() == '\\' && completion.front() == '\\')
completion = completion.substring(1);
}
m_buffer_col += completion.size();
MUST(m_buffers[m_buffer_index].append(completion));
printf("%.*s", (int)completion.size(), completion.data());
fflush(stdout);
break;
}
if (completions.size() == 1)
{
ASSERT(all_match_len == completions.front().size());
break;
}
printf("\n");
for (size_t i = 0; i < completions.size(); i++)
{
if (i != 0)
printf(" ");
const char* format = completions[i].sv().contains(' ') ? "'%.*s%s'" : "%.*s%s";
printf(format, (int)prefix.size(), prefix.data(), completions[i].data());
}
printf("\n");
print_prompt();
printf("%s", m_buffers[m_buffer_index].data());
fflush(stdout);
if (should_escape_spaces)
{
for (auto& completion : completions)
{
for (size_t i = 0; i < completion.size(); i++)
{
if (!isspace(completion[i]))
continue;
MUST(completion.insert('\\', i));
i++;
}
}
}
m_tab_completion_keep = m_buffer_col;
m_tab_completions = BAN::move(completions);
break;
}
default:
MUST(m_buffers[m_buffer_index].insert(ch, m_buffer_col++));
if (m_buffer_col == m_buffers[m_buffer_index].size())
putchar(ch);
else
printf("%c\e[s%s\e[u", ch, m_buffers[m_buffer_index].data() + m_buffer_col);
fflush(stdout);
break;
}
}
}
Input::Input()
{
if (!s_termios_initialized)
{
tcgetattr(0, &s_original_termios);
s_raw_termios = s_original_termios;
s_raw_termios.c_lflag &= ~(ECHO | ICANON);
atexit([] { tcsetattr(0, TCSANOW, &s_original_termios); });
s_termios_initialized = true;
}
char hostname_buffer[HOST_NAME_MAX];
if (gethostname(hostname_buffer, sizeof(hostname_buffer)) == 0) {
MUST(m_hostname.append(hostname_buffer));
}
if (const char* home_env = getenv("HOME"))
{
MUST(s_history_path.append(home_env));
MUST(s_history_path.append("/.shell_history"));
atexit([] {
FILE* history_fp = fopen(s_history_path.data(), "w");
if (history_fp == nullptr)
return;
for (const auto& line : s_history)
{
fwrite(line.data(), line.size(), 1, history_fp);
fputc('\n', history_fp);
}
fclose(history_fp);
});
FILE* history_fp = fopen(s_history_path.data(), "r");
if (history_fp == nullptr)
goto history_initialized;
char buffer[128];
BAN::String current_line;
while (fgets(buffer, sizeof(buffer), history_fp))
{
MUST(current_line.append(buffer));
if (current_line.back() != '\n')
continue;
current_line.pop_back();
MUST(s_history.push_back(BAN::move(current_line)));
if (s_history.size() > s_history_max_size)
s_history.remove(0);
current_line.clear();
}
fclose(history_fp);
if (!current_line.empty())
{
MUST(s_history.push_back(BAN::move(current_line)));
if (s_history.size() > s_history_max_size)
s_history.remove(0);
}
}
history_initialized:
m_buffers = s_history;
MUST(m_buffers.emplace_back(""_sv));
m_buffer_index = m_buffers.size() - 1;
}