From 1f9b296ae73951d5081e7cc0605425e1b56e4ac8 Mon Sep 17 00:00:00 2001 From: Bananymous Date: Sun, 23 Nov 2025 02:25:05 +0200 Subject: [PATCH] cp: Add -r/--recursive flag --- userspace/programs/cp/CMakeLists.txt | 1 - userspace/programs/cp/main.cpp | 306 ++++++++++++++++++--------- 2 files changed, 204 insertions(+), 103 deletions(-) diff --git a/userspace/programs/cp/CMakeLists.txt b/userspace/programs/cp/CMakeLists.txt index 25f24ca3..8bdf6587 100644 --- a/userspace/programs/cp/CMakeLists.txt +++ b/userspace/programs/cp/CMakeLists.txt @@ -3,7 +3,6 @@ set(SOURCES ) add_executable(cp ${SOURCES}) -banan_include_headers(cp ban) banan_link_library(cp libc) install(TARGETS cp OPTIONAL) diff --git a/userspace/programs/cp/main.cpp b/userspace/programs/cp/main.cpp index e3d3acdc..16298e09 100644 --- a/userspace/programs/cp/main.cpp +++ b/userspace/programs/cp/main.cpp @@ -1,143 +1,245 @@ -#include -#include +#include +#include #include +#include +#include #include #include #include #include -#define STR_STARTS_WITH(str, arg) (strncmp(str, arg, sizeof(arg) - 1) == 0) -#define STR_EQUAL(str, arg) (strcmp(str, arg) == 0) - -bool copy_file(const BAN::String& source, BAN::String destination) +static int copy_file(const char* src, const char* dst) { - struct stat st; - if (stat(source.data(), &st) == -1) + struct stat src_st; + if (stat(src, &src_st) == -1) + return errno; + + struct stat dst_st; + if (stat(dst, &dst_st) == 0) { - fprintf(stderr, "%s: ", source.data()); - perror("stat"); - return false; - } - if (S_ISDIR(st.st_mode)) - { - fprintf(stderr, "%s: is a directory\n", source.data()); - return false; + if (S_ISDIR(dst_st.st_mode)) + return EINVAL; + if (unlinkat(AT_FDCWD, dst, 0) == -1) + return errno; } - if (stat(destination.data(), &st) != -1 && S_ISDIR(st.st_mode)) + if (S_ISREG(src_st.st_mode)) { - MUST(destination.push_back('/')); - MUST(destination.append(MUST(source.sv().split('/')).back())); - } + const int src_fd = open(src, O_RDONLY); + const int dst_fd = open(dst, O_RDWR | O_CREAT | O_EXCL, src_st.st_mode); - int src_fd = open(source.data(), O_RDONLY); - if (src_fd == -1) - { - fprintf(stderr, "%s: ", source.data()); - perror("open"); - return false; - } - - int dest_fd = open(destination.data(), O_CREAT | O_TRUNC | O_WRONLY, 0644); - if (dest_fd == -1) - { - fprintf(stderr, "%s: ", destination.data()); - perror("open"); - close(src_fd); - return false; - } - - bool ret = true; - char buffer[1024]; - while (ssize_t nread = read(src_fd, buffer, sizeof(buffer))) - { - if (nread < 0) + if (src_fd == -1 || dst_fd == -1) { - fprintf(stderr, "%s: ", source.data()); - perror("read"); - ret = false; - break; + if (src_fd != -1) + close(src_fd); + if (dst_fd != -1) + close(dst_fd); + return errno; } - ssize_t written = 0; - while (written < nread) + int result = 0; + + char buffer[512]; + for (;;) { - ssize_t nwrite = write(dest_fd, buffer, nread - written); - if (nwrite < 0) + const ssize_t nread = read(src_fd, buffer, 512); + if (nread <= 0) { - fprintf(stderr, "%s: ", destination.data()); - perror("write"); - ret = false; - } - if (nwrite <= 0) + if (nread == -1) + result = errno; break; - written += nwrite; + } + + ssize_t total_written = 0; + while (total_written < nread) + { + const ssize_t nwrite = write(dst_fd, buffer + total_written, nread - total_written); + if (nwrite < 0) + { + result = errno; + break; + } + total_written += nwrite; + } } - if (written < nread) - break; + close(src_fd); + close(dst_fd); + return result; } - close(src_fd); - close(dest_fd); - return ret; + if (S_ISLNK(src_st.st_mode)) + { + char* buffer = static_cast(malloc(512)); + if (buffer == nullptr) + return errno; + ssize_t buffer_size = 512; + + ssize_t link_len; + while ((link_len = readlink(src, buffer, buffer_size)) == buffer_size) + { + buffer_size *= 2; + void* new_buffer = realloc(buffer, buffer_size); + if (new_buffer == nullptr) + { + free(buffer); + return errno; + } + buffer = static_cast(new_buffer); + } + + int result = 0; + if (link_len == -1) + result = errno; + if (result == 0 && symlink(buffer, dst) == -1) + result = errno; + free(buffer); + return result; + } + + fprintf(stddbg, "move file with mode %07o to another filesystem\n", src_st.st_mode); + return ENOTSUP; } -bool copy_file_to_directory(const BAN::String& source, const BAN::String& destination) +static int copy_directory(const char* src, const char* dst) { - auto temp = destination; - MUST(temp.append(MUST(source.sv().split('/')).back())); - return copy_file(source, destination); -} + struct stat src_st; + if (stat(src, &src_st) == -1) + return errno; -void usage(const char* argv0, int ret) -{ - FILE* out = (ret == 0) ? stdout : stderr; - fprintf(out, "usage: %s [OPTIONS]... SOURCE... DEST\n", argv0); - fprintf(out, "Copies files SOURCE... to DEST\n"); - fprintf(out, "OPTIONS:\n"); - fprintf(out, " -h, --help\n"); - fprintf(out, " Show this message and exit\n"); - exit(ret); + struct stat dst_st; + if (stat(dst, &dst_st) == 0) + { + if (!S_ISDIR(dst_st.st_mode)) + return ENOTDIR; + if (rmdir(dst) == -1) + return errno; + } + + if (mkdir(dst, src_st.st_mode) == -1) + return errno; + + DIR* dirp = opendir(src); + if (dirp == nullptr) + return errno; + + int result = 0; + + dirent* dent; + while ((dent = readdir(dirp))) + { + if (strcmp(dent->d_name, ".") == 0 || strcmp(dent->d_name, "..") == 0) + continue; + + bool name_too_long = false; + + char src_buffer[PATH_MAX]; + if (snprintf(src_buffer, PATH_MAX, "%s/%s", src, dent->d_name) > PATH_MAX) + name_too_long = true; + + char dst_buffer[PATH_MAX]; + if (snprintf(dst_buffer, PATH_MAX, "%s/%s", dst, dent->d_name) > PATH_MAX) + name_too_long = true; + + if (name_too_long) + result = ENAMETOOLONG; + else + { + auto* copy_func = (dent->d_type == DT_DIR) ? copy_directory : copy_file; + if (int ret = copy_func(src_buffer, dst_buffer)) + result = ret; + } + } + + closedir(dirp); + + return result; } int main(int argc, char** argv) { - BAN::Vector src; - BAN::StringView dest; + bool recursive = false; - int i = 1; - for (; i < argc; i++) + for (;;) { - if (STR_EQUAL(argv[i], "-h") || STR_EQUAL(argv[i], "--help")) - { - usage(argv[0], 0); - } - else if (argv[i][0] == '-') - { - fprintf(stderr, "Unknown argument %s\n", argv[i]); - usage(argv[0], 1); - } - else - { + static option long_options[] { + { "help", no_argument, nullptr, 'h' }, + { "recursive", no_argument, nullptr, 'r' }, + }; + + int ch = getopt_long(argc, argv, "hr", long_options, nullptr); + if (ch == -1) break; + + switch (ch) + { + case 'h': + printf("usage: %s [OPTIONS]... SOURCE... DEST\n", argv[0]); + printf("Copies files SOURCE... to DEST\n"); + printf("OPTIONS:\n"); + printf(" -h, --help Show this message and exit\n"); + return 0; + case 'r': + recursive = true; + break; + case '?': + fprintf(stderr, "invalid option %c\n", optopt); + fprintf(stderr, "see '%s --help' for usage\n", argv[0]); + return 1; } } - for (; i < argc - 1; i++) - MUST(src.push_back(argv[i])); - dest = argv[argc - 1]; - - if (src.empty()) + const int src_count = argc - optind - 1; + if (src_count < 1) { - fprintf(stderr, "Missing destination operand\n"); - usage(argv[0], 1); + fprintf(stderr, "missing destination operand\n"); + return 1; } - int ret = 0; - for (auto file_path : src) - if (!copy_file(file_path, dest)) - ret = 1; + const char* dest = argv[argc - 1]; - return ret; + if (src_count >= 2) + { + struct stat st; + if (stat(dest, &st) == -1 || !S_ISDIR(st.st_mode)) + { + fprintf(stderr, "destination is not a directory\n"); + return 1; + } + } + + int result = 0; + for (int i = optind; i < argc - 1; i++) + { + struct stat src_st; + if (stat(argv[i], &src_st) == -1) + { + fprintf(stderr, "%s: %s\n", argv[0], strerror(errno)); + continue; + } + + if (!recursive && S_ISDIR(src_st.st_mode)) + { + fprintf(stderr, "%s: %s\n", argv[0], strerror(EISDIR)); + continue; + } + + struct stat dst_st; + if (stat(dest, &dst_st) == 0 && S_ISDIR(dst_st.st_mode)) + { + static char buffer[PATH_MAX]; + if (snprintf(buffer, PATH_MAX, "%s/%s", dest, basename(argv[i])) > PATH_MAX) + return ENAMETOOLONG; + dest = buffer; + } + + auto* copy_func = S_ISDIR(src_st.st_mode) ? copy_directory : copy_file; + if (int ret = copy_func(argv[i], dest); ret != 0) + { + fprintf(stderr, "%s: %s\n", argv[0], strerror(ret)); + result = 1; + } + } + + return result; }