diff --git a/userspace/programs/Shell/main.cpp b/userspace/programs/Shell/main.cpp index 92d21f76..fbe72081 100644 --- a/userspace/programs/Shell/main.cpp +++ b/userspace/programs/Shell/main.cpp @@ -25,7 +25,47 @@ static int last_return = 0; static BAN::String hostname; -static BAN::Vector> parse_command(BAN::StringView); +struct SingleCommand +{ + BAN::Vector arguments; +}; + +struct PipedCommand +{ + BAN::Vector commands; +}; + +struct CommandList +{ + enum class Condition + { + Always, + OnSuccess, + OnFailure, + }; + + struct Command + { + BAN::String expression; + Condition condition; + }; + BAN::Vector commands; +}; + +static BAN::StringView strip_whitespace(BAN::StringView sv) +{ + size_t leading = 0; + while (leading < sv.size() && isspace(sv[leading])) + leading++; + sv = sv.substring(leading); + + size_t trailing = 0; + while (trailing < sv.size() && isspace(sv[sv.size() - trailing - 1])) + trailing++; + sv = sv.substring(0, sv.size() - trailing); + + return sv; +} static BAN::Optional parse_dollar(BAN::StringView command, size_t& i) { @@ -157,22 +197,7 @@ static BAN::Optional parse_dollar(BAN::StringView command, size_t& return temp; } -static BAN::StringView strip_whitespace(BAN::StringView sv) -{ - size_t leading = 0; - while (leading < sv.size() && isspace(sv[leading])) - leading++; - sv = sv.substring(leading); - - size_t trailing = 0; - while (trailing < sv.size() && isspace(sv[sv.size() - trailing - 1])) - trailing++; - sv = sv.substring(0, sv.size() - trailing); - - return sv; -} - -static BAN::Vector> parse_command(BAN::StringView command_view) +static PipedCommand parse_piped_command(BAN::StringView command_view) { enum class State { @@ -183,11 +208,10 @@ static BAN::Vector> parse_command(BAN::StringView comma command_view = strip_whitespace(command_view); - BAN::Vector> result; - BAN::Vector command_args; - State state = State::Normal; - BAN::String current_arg; + SingleCommand current_command; + BAN::String current_argument; + PipedCommand result; for (size_t i = 0; i < command_view.size(); i++) { char c = command_view[i]; @@ -198,7 +222,7 @@ static BAN::Vector> parse_command(BAN::StringView comma if (next == '\'' || next == '"') { if (i + 1 < command_view.size()) - MUST(current_arg.push_back(next)); + MUST(current_argument.push_back(next)); i++; continue; } @@ -219,25 +243,25 @@ static BAN::Vector> parse_command(BAN::StringView comma fprintf(stderr, "bad substitution\n"); return {}; } - MUST(current_arg.append(expansion.value())); + MUST(current_argument.append(expansion.value())); } else if (c == '|') { - if (!current_arg.empty()) - MUST(command_args.push_back(current_arg)); - current_arg.clear(); + if (!current_argument.empty()) + MUST(current_command.arguments.push_back(current_argument)); + current_argument.clear(); - MUST(result.push_back(command_args)); - command_args.clear(); + MUST(result.commands.push_back(current_command)); + current_command.arguments.clear(); } else if (!isspace(c)) - MUST(current_arg.push_back(c)); + MUST(current_argument.push_back(c)); else { - if (!current_arg.empty()) + if (!current_argument.empty()) { - MUST(command_args.push_back(current_arg)); - current_arg.clear(); + MUST(current_command.arguments.push_back(current_argument)); + current_argument.clear(); } } break; @@ -245,13 +269,13 @@ static BAN::Vector> parse_command(BAN::StringView comma if (c == '\'') state = State::Normal; else - MUST(current_arg.push_back(c)); + MUST(current_argument.push_back(c)); break; case State::DoubleQuote: if (c == '"') state = State::Normal; else if (c != '$') - MUST(current_arg.push_back(c)); + MUST(current_argument.push_back(c)); else { auto expansion = parse_dollar(command_view, i); @@ -260,26 +284,88 @@ static BAN::Vector> parse_command(BAN::StringView comma fprintf(stderr, "bad substitution\n"); return {}; } - MUST(current_arg.append(expansion.value())); + MUST(current_argument.append(expansion.value())); } break; } } // FIXME: handle state != State::Normal - MUST(command_args.push_back(BAN::move(current_arg))); - MUST(result.push_back(BAN::move(command_args))); + MUST(current_command.arguments.push_back(BAN::move(current_argument))); + MUST(result.commands.push_back(BAN::move(current_command))); - return result; + return BAN::move(result); } -static int execute_command(BAN::Vector& args, int fd_in, int fd_out); +static CommandList parse_command_list(BAN::StringView command_view) +{ + CommandList result; + CommandList::Condition next_condition = CommandList::Condition::Always; + for (size_t i = 0; i < command_view.size(); i++) + { + const char current = command_view[i]; + switch (current) + { + case '\\': + i++; + break; + case '\'': + case '"': + while (++i < command_view.size()) + { + if (command_view[i] == '\\') + i++; + else if (command_view[i] == current) + break; + } + break; + case ';': + MUST(result.commands.emplace_back( + strip_whitespace(command_view.substring(0, i)), + next_condition + )); + command_view = strip_whitespace(command_view.substring(i + 1)); + next_condition = CommandList::Condition::Always; + i = -1; + break; + case '|': + case '&': + if (i + 1 >= command_view.size() || command_view[i + 1] != current) + break; + MUST(result.commands.emplace_back( + strip_whitespace(command_view.substring(0, i)), + next_condition + )); + command_view = strip_whitespace(command_view.substring(i + 2)); + next_condition = (current == '|') ? CommandList::Condition::OnFailure : CommandList::Condition::OnSuccess; + i = -1; + break; + } + } + + MUST(result.commands.emplace_back( + strip_whitespace(command_view), + next_condition + )); + + for (const auto& [expression, _] : result.commands) + { + if (!expression.empty()) + continue; + fprintf(stderr, "expected an expression\n"); + return {}; + } + + return BAN::move(result); +} + +static int execute_command(const SingleCommand& command, int fd_in, int fd_out); static int source_script(const BAN::String& path); -static BAN::Optional execute_builtin(BAN::Vector& args, int fd_in, int fd_out) +static BAN::Optional execute_builtin(const SingleCommand& command, int fd_in, int fd_out) { - if (args.empty()) + if (command.arguments.empty()) return 0; FILE* fout = stdout; @@ -296,19 +382,19 @@ static BAN::Optional execute_builtin(BAN::Vector& args, int fd } BAN::ScopeGuard _([fout, should_close] { if (should_close) fclose(fout); }); - if (args.front() == "clear"_sv) + if (command.arguments.front() == "clear"_sv) { fprintf(fout, "\e[H\e[2J"); fflush(fout); } - else if (args.front() == "exit"_sv) + else if (command.arguments.front() == "exit"_sv) { exit(0); } - else if (args.front() == "export"_sv) + else if (command.arguments.front() == "export"_sv) { bool first = false; - for (const auto& arg : args) + for (const auto& argument : command.arguments) { if (first) { @@ -316,7 +402,7 @@ static BAN::Optional execute_builtin(BAN::Vector& args, int fd continue; } - auto split = MUST(arg.sv().split('=', true)); + auto split = MUST(argument.sv().split('=', true)); if (split.size() != 2) continue; @@ -324,24 +410,24 @@ static BAN::Optional execute_builtin(BAN::Vector& args, int fd ERROR_RETURN("setenv", 1); } } - else if (args.front() == "source"_sv) + else if (command.arguments.front() == "source"_sv) { - if (args.size() != 2) + if (command.arguments.size() != 2) { fprintf(fout, "usage: source FILE\n"); return 1; } - return source_script(args[1]); + return source_script(command.arguments[1]); } - else if (args.front() == "env"_sv) + else if (command.arguments.front() == "env"_sv) { char** current = environ; while (*current) fprintf(fout, "%s\n", *current++); } - else if (args.front() == "cd"_sv) + else if (command.arguments.front() == "cd"_sv) { - if (args.size() > 2) + if (command.arguments.size() > 2) { fprintf(fout, "cd: too many arguments\n"); return 1; @@ -349,7 +435,7 @@ static BAN::Optional execute_builtin(BAN::Vector& args, int fd BAN::StringView path; - if (args.size() == 1) + if (command.arguments.size() == 1) { if (const char* path_env = getenv("HOME")) path = path_env; @@ -357,21 +443,24 @@ static BAN::Optional execute_builtin(BAN::Vector& args, int fd return 0; } else - path = args[1]; + path = command.arguments[1]; if (chdir(path.data()) == -1) ERROR_RETURN("chdir", 1); } - else if (args.front() == "time"_sv) + else if (command.arguments.front() == "time"_sv) { - args.remove(0); + SingleCommand timed_command; + MUST(timed_command.arguments.reserve(command.arguments.size() - 1)); + for (size_t i = 1; i < command.arguments.size(); i++) + timed_command.arguments[i - 1] = command.arguments[i]; timespec start, end; if (clock_gettime(CLOCK_MONOTONIC, &start) == -1) ERROR_RETURN("clock_gettime", 1); - int ret = execute_command(args, fd_in, fd_out); + int ret = execute_command(timed_command, fd_in, fd_out); if (clock_gettime(CLOCK_MONOTONIC, &end) == -1) ERROR_RETURN("clock_gettime", 1); @@ -387,7 +476,7 @@ static BAN::Optional execute_builtin(BAN::Vector& args, int fd return ret; } - else if (args.front() == "start-gui"_sv) + else if (command.arguments.front() == "start-gui"_sv) { pid_t pid = fork(); if (pid == 0) @@ -404,20 +493,19 @@ static BAN::Optional execute_builtin(BAN::Vector& args, int fd return 0; } -static pid_t execute_command_no_wait(BAN::Vector& args, int fd_in, int fd_out, pid_t pgrp) +static pid_t execute_command_no_wait(const SingleCommand& command, int fd_in, int fd_out, pid_t pgrp) { - if (args.empty()) - return 0; + ASSERT(!command.arguments.empty()); BAN::Vector cmd_args; - MUST(cmd_args.reserve(args.size() + 1)); - for (const auto& arg : args) + MUST(cmd_args.reserve(command.arguments.size() + 1)); + for (const auto& arg : command.arguments) MUST(cmd_args.push_back((char*)arg.data())); MUST(cmd_args.push_back(nullptr)); // do PATH resolution BAN::String executable_file; - if (!args.front().sv().contains('/')) + if (!command.arguments.front().sv().contains('/')) { const char* path_env_cstr = getenv("PATH"); if (path_env_cstr == nullptr) @@ -428,7 +516,7 @@ static pid_t execute_command_no_wait(BAN::Vector& args, int fd_in, { BAN::String test_file = path_env; MUST(test_file.push_back('/')); - MUST(test_file.append(args.front())); + MUST(test_file.append(command.arguments.front())); struct stat st; if (stat(test_file.data(), &st) == 0) @@ -440,7 +528,7 @@ static pid_t execute_command_no_wait(BAN::Vector& args, int fd_in, } else { - executable_file = args.front(); + executable_file = command.arguments.front(); } // Verify that the file exists is executable @@ -448,7 +536,7 @@ static pid_t execute_command_no_wait(BAN::Vector& args, int fd_in, struct stat st; if (executable_file.empty() || stat(executable_file.data(), &st) == -1) { - fprintf(stderr, "command not found: %s\n", args.front().data()); + fprintf(stderr, "command not found: %s\n", command.arguments.front().data()); return -1; } if ((st.st_mode & 0111) == 0) @@ -458,7 +546,7 @@ static pid_t execute_command_no_wait(BAN::Vector& args, int fd_in, } } - pid_t pid = fork(); + const pid_t pid = fork(); if (pid == 0) { if (fd_in != STDIN_FILENO) @@ -503,9 +591,9 @@ static pid_t execute_command_no_wait(BAN::Vector& args, int fd_in, return pid; } -static int execute_command(BAN::Vector& args, int fd_in, int fd_out) +static int execute_command(const SingleCommand& command, int fd_in, int fd_out) { - pid_t pid = execute_command_no_wait(args, fd_in, fd_out, 0); + const pid_t pid = execute_command_no_wait(command, fd_in, fd_out, 0); if (pid == -1) return 1; @@ -522,27 +610,27 @@ static int execute_command(BAN::Vector& args, int fd_in, int fd_out return WEXITSTATUS(status); } -static int execute_piped_commands(BAN::Vector>& commands) +static int execute_piped_commands(const PipedCommand& piped_command) { - if (commands.empty()) + if (piped_command.commands.empty()) return 0; - if (commands.size() == 1) + if (piped_command.commands.size() == 1) { - auto& command = commands.front(); + auto& command = piped_command.commands.front(); if (auto ret = execute_builtin(command, STDIN_FILENO, STDOUT_FILENO); ret.has_value()) return ret.value(); return execute_command(command, STDIN_FILENO, STDOUT_FILENO); } - BAN::Vector exit_codes(commands.size(), 0); - BAN::Vector processes(commands.size(), -1); + BAN::Vector exit_codes(piped_command.commands.size(), 0); + BAN::Vector processes(piped_command.commands.size(), -1); pid_t pgrp = 0; int next_stdin = STDIN_FILENO; - for (size_t i = 0; i < commands.size(); i++) + for (size_t i = 0; i < piped_command.commands.size(); i++) { - bool last = (i == commands.size() - 1); + const bool last = (i == piped_command.commands.size() - 1); int pipefd[2] { -1, STDOUT_FILENO }; if (!last && pipe(pipefd) == -1) @@ -553,12 +641,12 @@ static int execute_piped_commands(BAN::Vector>& command break; } - auto builtin_ret = execute_builtin(commands[i], next_stdin, pipefd[1]); + auto builtin_ret = execute_builtin(piped_command.commands[i], next_stdin, pipefd[1]); if (builtin_ret.has_value()) exit_codes[i] = builtin_ret.value(); else { - pid_t pid = execute_command_no_wait(commands[i], next_stdin, pipefd[1], pgrp); + pid_t pid = execute_command_no_wait(piped_command.commands[i], next_stdin, pipefd[1], pgrp); processes[i] = pid; if (pgrp == 0) pgrp = pid; @@ -571,7 +659,7 @@ static int execute_piped_commands(BAN::Vector>& command next_stdin = pipefd[0]; } - for (size_t i = 0; i < commands.size(); i++) + for (size_t i = 0; i < piped_command.commands.size(); i++) { if (processes[i] == -1) continue; @@ -599,15 +687,42 @@ static int execute_piped_commands(BAN::Vector>& command static int parse_and_execute_command(BAN::StringView command) { + command = strip_whitespace(command); if (command.empty()) return 0; - auto parsed_commands = parse_command(command); - if (parsed_commands.empty()) + + auto command_list = parse_command_list(command); + if (command_list.commands.empty()) return 0; + tcsetattr(0, TCSANOW, &old_termios); - int ret = execute_piped_commands(parsed_commands); + + last_return = 0; + for (const auto& [expression, condition] : command_list.commands) + { + bool should_run = false; + switch (condition) + { + case CommandList::Condition::Always: + should_run = true; + break; + case CommandList::Condition::OnSuccess: + should_run = (last_return == 0); + break; + case CommandList::Condition::OnFailure: + should_run = (last_return != 0); + break; + } + + if (!should_run) + continue; + + last_return = execute_piped_commands(parse_piped_command(expression)); + } + tcsetattr(0, TCSANOW, &new_termios); - return ret; + + return last_return; } static int source_script(const BAN::String& path) @@ -814,9 +929,7 @@ int main(int argc, char** argv) printf("-c requires an argument\n"); return 1; } - - auto commands = parse_command(BAN::String(argv[i + 1])); - return execute_piped_commands(commands); + return parse_and_execute_command(BAN::String(argv[i + 1])); } else if (strcmp(argv[i], "-v") == 0 || strcmp(argv[i], "--version") == 0) { @@ -985,7 +1098,7 @@ int main(int argc, char** argv) putchar('\n'); if (!buffers[index].empty()) { - last_return = parse_and_execute_command(buffers[index]); + parse_and_execute_command(buffers[index]); MUST(history.push_back(buffers[index])); buffers = history; MUST(buffers.emplace_back(""_sv));