diff --git a/userspace/Shell/main.cpp b/userspace/Shell/main.cpp index b8b31f68..5791b87b 100644 --- a/userspace/Shell/main.cpp +++ b/userspace/Shell/main.cpp @@ -21,7 +21,7 @@ extern char** environ; static const char* argv0 = nullptr; static int last_return = 0; -BAN::Vector parse_command(BAN::StringView); +BAN::Vector> parse_command(BAN::StringView); BAN::Optional parse_dollar(BAN::StringView command, size_t& i) { @@ -151,7 +151,7 @@ BAN::Optional parse_dollar(BAN::StringView command, size_t& i) return "$"sv; } -BAN::Vector parse_command(BAN::StringView command) +BAN::Vector> parse_command(BAN::StringView command_view) { enum class State { @@ -160,13 +160,14 @@ BAN::Vector parse_command(BAN::StringView command) DoubleQuote, }; - BAN::Vector result; + BAN::Vector> result; + BAN::Vector command_args; State state = State::Normal; - BAN::String current; - for (size_t i = 0; i < command.size(); i++) + BAN::String current_arg; + for (size_t i = 0; i < command_view.size(); i++) { - char c = command[i]; + char c = command_view[i]; switch (state) { @@ -177,22 +178,31 @@ BAN::Vector parse_command(BAN::StringView command) state = State::DoubleQuote; else if (c == '$') { - auto expansion = parse_dollar(command, i); + auto expansion = parse_dollar(command_view, i); if (!expansion.has_value()) { fprintf(stderr, "bad substitution\n"); return {}; } - MUST(current.append(expansion.value())); + MUST(current_arg.append(expansion.value())); + } + else if (c == '|') + { + if (!current_arg.empty()) + MUST(command_args.push_back(current_arg)); + current_arg.clear(); + + MUST(result.push_back(command_args)); + command_args.clear(); } else if (!isspace(c)) - MUST(current.push_back(c)); + MUST(current_arg.push_back(c)); else { - if (!current.empty()) + if (!current_arg.empty()) { - MUST(result.push_back(current)); - current.clear(); + MUST(command_args.push_back(current_arg)); + current_arg.clear(); } } break; @@ -200,43 +210,59 @@ BAN::Vector parse_command(BAN::StringView command) if (c == '\'') state = State::Normal; else - MUST(current.push_back(c)); + MUST(current_arg.push_back(c)); break; case State::DoubleQuote: if (c == '"') state = State::Normal; else if (c != '$') - MUST(current.push_back(c)); + MUST(current_arg.push_back(c)); else { - auto expansion = parse_dollar(command, i); + auto expansion = parse_dollar(command_view, i); if (!expansion.has_value()) { fprintf(stderr, "bad substitution\n"); return {}; } - MUST(current.append(expansion.value())); + MUST(current_arg.append(expansion.value())); } break; } } // FIXME: handle state != State::Normal - MUST(result.push_back(BAN::move(current))); + MUST(command_args.push_back(BAN::move(current_arg))); + MUST(result.push_back(BAN::move(command_args))); return result; } -int execute_command(BAN::Vector& args) +int execute_command(BAN::Vector& args, int fd_in, int fd_out); + +BAN::Optional execute_builtin(BAN::Vector& args, int fd_in, int fd_out) { if (args.empty()) return 0; - + + FILE* fout = stdout; + bool should_close = false; + if (fd_out != STDOUT_FILENO) + { + int fd_dup = dup(fd_out); + if (fd_dup == -1) + ERROR_RETURN("dup", 1); + fout = fdopen(fd_dup, "w"); + if (fout == nullptr) + ERROR_RETURN("fdopen", 1); + should_close = true; + } + BAN::ScopeGuard _([fout, should_close] { if (should_close) fclose(fout); }); + if (args.front() == "clear"sv) { - fprintf(stdout, "\e[H\e[J"); - fflush(stdout); - return 0; + fprintf(fout, "\e[H\e[J"); + fflush(fout); } else if (args.front() == "exit"sv) { @@ -265,7 +291,7 @@ int execute_command(BAN::Vector& args) { char** current = environ; while (*current) - printf("%s\n", *current++); + fprintf(fout, "%s\n", *current++); } else if (args.front() == "page-fault-test"sv) { @@ -277,7 +303,7 @@ int execute_command(BAN::Vector& args) pid_t pid = fork(); if (pid == 0) { - printf("child\n"); + fprintf(fout, "child\n"); for (;;); } if (pid == -1) @@ -298,6 +324,7 @@ int execute_command(BAN::Vector& args) pid_t pid = fork(); if (pid == 0) { + dup2(fileno(fout), STDOUT_FILENO); if (signal(SIGSEGV, [](int) { printf("SIGSEGV\n"); }) == SIG_ERR) { perror("signal"); @@ -328,18 +355,18 @@ int execute_command(BAN::Vector& args) } else if (args.front() == "printf-test"sv) { - printf(" 0.0: %f\n", 0.0f); - printf(" 123.0: %f\n", 123.0f); - printf(" 0.123: %f\n", 0.123f); - printf(" NAN: %f\n", NAN); - printf("+INF: %f\n", INFINITY); - printf("-INF: %f\n", -INFINITY); + fprintf(fout, " 0.0: %f\n", 0.0f); + fprintf(fout, " 123.0: %f\n", 123.0f); + fprintf(fout, " 0.123: %f\n", 0.123f); + fprintf(fout, " NAN: %f\n", NAN); + fprintf(fout, "+INF: %f\n", INFINITY); + fprintf(fout, "-INF: %f\n", -INFINITY); } else if (args.front() == "cd"sv) { if (args.size() > 2) { - printf("cd: too many arguments\n"); + fprintf(fout, "cd: too many arguments\n"); return 1; } @@ -367,7 +394,7 @@ int execute_command(BAN::Vector& args) if (clock_gettime(CLOCK_MONOTONIC, &start) == -1) ERROR_RETURN("clock_gettime", 1); - int ret = execute_command(args); + int ret = execute_command(args, fd_in, fd_out); if (clock_gettime(CLOCK_MONOTONIC, &end) == -1) ERROR_RETURN("clock_gettime", 1); @@ -379,45 +406,154 @@ int execute_command(BAN::Vector& args) int secs = total_ns / 1'000'000'000; int msecs = (total_ns % 1'000'000'000) / 1'000'000; - printf("took %d.%03d s\n", secs, msecs); + fprintf(fout, "took %d.%03d s\n", secs, msecs); return ret; } else { - BAN::Vector cmd_args; - MUST(cmd_args.reserve(args.size() + 1)); - for (const auto& arg : args) - MUST(cmd_args.push_back((char*)arg.data())); - MUST(cmd_args.push_back(nullptr)); + return {}; + } - pid_t pid = fork(); - if (pid == 0) + return 0; +} + +pid_t execute_command_no_wait(BAN::Vector& args, int fd_in, int fd_out) +{ + if (args.empty()) + return 0; + + BAN::Vector cmd_args; + MUST(cmd_args.reserve(args.size() + 1)); + for (const auto& arg : args) + MUST(cmd_args.push_back((char*)arg.data())); + MUST(cmd_args.push_back(nullptr)); + + pid_t pid = fork(); + if (pid == 0) + { + if (fd_in != STDIN_FILENO) { - execv(cmd_args.front(), cmd_args.data()); - perror("execv"); - exit(1); + if (dup2(fd_in, STDIN_FILENO) == -1) + { + perror("dup2"); + exit(1); + } + close(fd_in); + } + if (fd_out != STDOUT_FILENO) + { + if (dup2(fd_out, STDOUT_FILENO) == -1) + { + perror("dup2"); + exit(1); + } + close(fd_out); } - if (pid == -1) - ERROR_RETURN("fork", 1); - if (tcsetpgrp(0, pid) == -1) - ERROR_RETURN("tcsetpgrp", 1); + execv(cmd_args.front(), cmd_args.data()); + perror("execv"); + exit(1); + } + + return pid; +} + +int execute_command(BAN::Vector& args, int fd_in, int fd_out) +{ + pid_t pid = execute_command_no_wait(args, fd_in, fd_out); + if (pid == -1) + ERROR_RETURN("fork", 1); + + if (tcsetpgrp(0, pid) == -1) + ERROR_RETURN("tcsetpgrp", 1); + + int status; + if (waitpid(pid, &status, 0) == -1) + ERROR_RETURN("waitpid", 1); + + if (tcsetpgrp(0, getpid()) == -1) + ERROR_RETURN("tcsetpgrp", 1); + + if (WIFSIGNALED(status)) + fprintf(stderr, "Terminated by signal %d\n", WTERMSIG(status)); + + return WEXITSTATUS(status); +} + +int execute_piped_commands(BAN::Vector>& commands) +{ + if (commands.empty()) + return 0; + + if (commands.size() == 1) + { + auto& command = commands.front(); + if (auto ret = execute_builtin(command, STDIN_FILENO, STDOUT_FILENO); ret.has_value()) + return ret.value(); + return execute_command(command, STDIN_FILENO, STDOUT_FILENO); + } + + BAN::Vector exit_codes(commands.size(), 0); + BAN::Vector processes(commands.size(), -1); + + int next_stdin = STDIN_FILENO; + for (size_t i = 0; i < commands.size(); i++) + { + bool first = (i == 0); + bool last = (i == commands.size() - 1); + + int pipefd[2] { -1, STDOUT_FILENO }; + if (!last && pipe(pipefd) == -1) + { + if (i > 0) + close(next_stdin); + perror("pipe"); + break; + } + + auto builtin_ret = execute_builtin(commands[i], next_stdin, pipefd[1]); + if (builtin_ret.has_value()) + exit_codes[i] = builtin_ret.value(); + else + { + pid_t pid = execute_command_no_wait(commands[i], next_stdin, pipefd[1]); + processes[i] = pid; + if (first && tcsetpgrp(0, pid) == -1) + ERROR_RETURN("tcsetpgrp", 1); + } + + if (next_stdin != STDIN_FILENO) + close(next_stdin); + if (pipefd[1] != STDOUT_FILENO) + close(pipefd[1]); + next_stdin = pipefd[0]; + } + + for (size_t i = 0; i < commands.size(); i++) + { + if (processes[i] == -1) + continue; int status; - if (waitpid(pid, &status, 0) == -1) - ERROR_RETURN("waitpid", 1); - - if (tcsetpgrp(0, getpid()) == -1) - ERROR_RETURN("tcsetpgrp", 1); + if (waitpid(processes[i], &status, 0) == -1) + { + perror("waitpid"); + exit_codes[i] = 69420; + continue; + } if (WIFSIGNALED(status)) fprintf(stderr, "Terminated by signal %d\n", WTERMSIG(status)); - return WEXITSTATUS(status); + if (WEXITSTATUS(status)) + exit_codes[i] = WEXITSTATUS(status); } - return 0; + if (tcsetpgrp(0, getpid()) == -1) + ERROR_RETURN("tcsetpgrp", 1); + + return exit_codes.back(); } int character_length(BAN::StringView prompt) @@ -544,8 +680,8 @@ int main(int argc, char** argv) BAN::String command; MUST(command.append(argv[2])); - auto arguments = parse_command(command); - return execute_command(arguments); + auto commands = parse_command(command); + return execute_piped_commands(commands); } printf("unknown argument '%s'\n", argv[1]); @@ -657,8 +793,8 @@ int main(int argc, char** argv) if (!buffers[index].empty()) { tcsetattr(0, TCSANOW, &old_termios); - auto parsed_arguments = parse_command(buffers[index]); - last_return = execute_command(parsed_arguments); + auto commands = parse_command(buffers[index]); + last_return = execute_piped_commands(commands); tcsetattr(0, TCSANOW, &new_termios); MUST(history.push_back(buffers[index])); buffers = history;