diff --git a/userspace/Shell/main.cpp b/userspace/Shell/main.cpp index 5987888c70..98e14f2491 100644 --- a/userspace/Shell/main.cpp +++ b/userspace/Shell/main.cpp @@ -1,6 +1,8 @@ +#include #include #include #include + #include #include #include @@ -10,6 +12,59 @@ struct termios old_termios, new_termios; +BAN::Optional parse_dollar(BAN::StringView command, size_t& i) +{ + ASSERT(command[i] == '$'); + + if (++i >= command.size()) + return "$"sv; + + if (isalnum(command[i])) + { + size_t len = 1; + for (; i + len < command.size(); len++) + if (!isalnum(command[i + len])) + break; + BAN::String name = command.substring(i, len); + i += len - 1; + + if (const char* value = getenv(name.data())) + return BAN::StringView(value); + return ""sv; + } + else if (command[i] == '{') + { + size_t len = 1; + for (; i + len < command.size(); len++) + { + if (command[i + len] == '}') + break; + if (!isalnum(command[i + len])) + return {}; + } + + if (i + len >= command.size()) + return {}; + + BAN::String name = command.substring(i + 1, len - 1); + i += len; + + if (const char* value = getenv(name.data())) + return BAN::StringView(value); + return ""sv; + } + else if (command[i] == '[') + { + return {}; + } + else if (command[i] == '(') + { + return {}; + } + + return "$"sv; +} + BAN::Vector parse_command(BAN::StringView command) { enum class State @@ -23,8 +78,10 @@ BAN::Vector parse_command(BAN::StringView command) State state = State::Normal; BAN::String current; - for (char c : command) + for (size_t i = 0; i < command.size(); i++) { + char c = command[i]; + switch (state) { case State::Normal: @@ -32,6 +89,16 @@ BAN::Vector parse_command(BAN::StringView command) state = State::SingleQuote; else if (c == '"') state = State::DoubleQuote; + else if (c == '$') + { + auto expansion = parse_dollar(command, i); + if (!expansion.has_value()) + { + fprintf(stderr, "bad substitution\n"); + return {}; + } + MUST(current.append(expansion.value())); + } else if (!isspace(c)) MUST(current.push_back(c)); else @@ -52,8 +119,18 @@ BAN::Vector parse_command(BAN::StringView command) case State::DoubleQuote: if (c == '"') state = State::Normal; - else + else if (c != '$') MUST(current.push_back(c)); + else + { + auto expansion = parse_dollar(command, i); + if (!expansion.has_value()) + { + fprintf(stderr, "bad substitution\n"); + return {}; + } + MUST(current.append(expansion.value())); + } break; } } @@ -181,7 +258,7 @@ BAN::String get_prompt() { const char* raw_prompt = getenv("PS1"); if (raw_prompt == nullptr) - raw_prompt = "\e[32muser@host\e[m:\e[34m\\~\e[m$ "; + return ""sv; BAN::String prompt; for (int i = 0; raw_prompt[i]; i++) @@ -254,6 +331,7 @@ int main(int argc, char** argv) { if (argc >= 1) setenv("SHELL", argv[0], true); + setenv("PS1", "\e[32muser@host\e[m:\e[34m\\~\e[m$ ", false); tcgetattr(0, &old_termios);