From 6d27595195900bab01c02742009a4fdd9a8e5e11 Mon Sep 17 00:00:00 2001 From: LI Qing Date: Tue, 28 Apr 2020 06:26:07 +0000 Subject: [PATCH] Fix the negative offset check for pread/pwrite --- src/libos/src/fs/file_ops/read.rs | 7 +- src/libos/src/fs/file_ops/write.rs | 7 +- src/libos/src/fs/syscalls.rs | 4 +- src/libos/src/syscall/mod.rs | 4 +- test/file/main.c | 312 +++++++++++++++++++---------- 5 files changed, 221 insertions(+), 113 deletions(-) diff --git a/src/libos/src/fs/file_ops/read.rs b/src/libos/src/fs/file_ops/read.rs index 17b1294a..7af358fe 100644 --- a/src/libos/src/fs/file_ops/read.rs +++ b/src/libos/src/fs/file_ops/read.rs @@ -12,8 +12,11 @@ pub fn do_readv(fd: FileDesc, bufs: &mut [&mut [u8]]) -> Result { file_ref.readv(bufs) } -pub fn do_pread(fd: FileDesc, buf: &mut [u8], offset: usize) -> Result { +pub fn do_pread(fd: FileDesc, buf: &mut [u8], offset: off_t) -> Result { debug!("pread: fd: {}, offset: {}", fd, offset); + if offset < 0 { + return_errno!(EINVAL, "the offset is negative"); + } let file_ref = current!().file(fd)?; - file_ref.read_at(offset, buf) + file_ref.read_at(offset as usize, buf) } diff --git a/src/libos/src/fs/file_ops/write.rs b/src/libos/src/fs/file_ops/write.rs index 2be419c6..34e736a9 100644 --- a/src/libos/src/fs/file_ops/write.rs +++ b/src/libos/src/fs/file_ops/write.rs @@ -12,8 +12,11 @@ pub fn do_writev(fd: FileDesc, bufs: &[&[u8]]) -> Result { file_ref.writev(bufs) } -pub fn do_pwrite(fd: FileDesc, buf: &[u8], offset: usize) -> Result { +pub fn do_pwrite(fd: FileDesc, buf: &[u8], offset: off_t) -> Result { debug!("pwrite: fd: {}, offset: {}", fd, offset); + if offset < 0 { + return_errno!(EINVAL, "the offset is negative"); + } let file_ref = current!().file(fd)?; - file_ref.write_at(offset, buf) + file_ref.write_at(offset as usize, buf) } diff --git a/src/libos/src/fs/syscalls.rs b/src/libos/src/fs/syscalls.rs index 27494127..da1debc2 100644 --- a/src/libos/src/fs/syscalls.rs +++ b/src/libos/src/fs/syscalls.rs @@ -124,7 +124,7 @@ pub fn do_readv(fd: FileDesc, iov: *mut iovec_t, count: i32) -> Result { Ok(len as isize) } -pub fn do_pread(fd: FileDesc, buf: *mut u8, size: usize, offset: usize) -> Result { +pub fn do_pread(fd: FileDesc, buf: *mut u8, size: usize, offset: off_t) -> Result { let safe_buf = { from_user::check_mut_array(buf, size)?; unsafe { std::slice::from_raw_parts_mut(buf, size) } @@ -133,7 +133,7 @@ pub fn do_pread(fd: FileDesc, buf: *mut u8, size: usize, offset: usize) -> Resul Ok(len as isize) } -pub fn do_pwrite(fd: FileDesc, buf: *const u8, size: usize, offset: usize) -> Result { +pub fn do_pwrite(fd: FileDesc, buf: *const u8, size: usize, offset: off_t) -> Result { let safe_buf = { from_user::check_array(buf, size)?; unsafe { std::slice::from_raw_parts(buf, size) } diff --git a/src/libos/src/syscall/mod.rs b/src/libos/src/syscall/mod.rs index a1d9e7d7..d556878e 100644 --- a/src/libos/src/syscall/mod.rs +++ b/src/libos/src/syscall/mod.rs @@ -88,8 +88,8 @@ macro_rules! process_syscall_table_with_callback { (RtSigprocmask = 14) => do_rt_sigprocmask(), (RtSigreturn = 15) => handle_unsupported(), (Ioctl = 16) => do_ioctl(fd: FileDesc, cmd: u32, argp: *mut u8), - (Pread64 = 17) => do_pread(fd: FileDesc, buf: *mut u8, size: usize, offset: usize), - (Pwrite64 = 18) => do_pwrite(fd: FileDesc, buf: *const u8, size: usize, offset: usize), + (Pread64 = 17) => do_pread(fd: FileDesc, buf: *mut u8, size: usize, offset: off_t), + (Pwrite64 = 18) => do_pwrite(fd: FileDesc, buf: *const u8, size: usize, offset: off_t), (Readv = 19) => do_readv(fd: FileDesc, iov: *mut iovec_t, count: i32), (Writev = 20) => do_writev(fd: FileDesc, iov: *const iovec_t, count: i32), (Access = 21) => do_access(path: *const i8, mode: u32), diff --git a/test/file/main.c b/test/file/main.c index dc0ebb77..00a8f3c6 100644 --- a/test/file/main.c +++ b/test/file/main.c @@ -1,116 +1,218 @@ -#include +#include #include +#include #include #include #include #include +#include "test.h" -int main(int argc, const char* argv[]) { - const char* file_name = "/root/test_filesystem_file_read_write.txt"; - int fd, flags, mode, len; - off_t offset; - const char* write_msg = "Hello World\n"; - char read_buf[128] = {0}; +// ============================================================================ +// Helper function +// ============================================================================ - // write - flags = O_WRONLY | O_CREAT| O_TRUNC; - mode = 00666; - if ((fd = open(file_name, flags, mode)) < 0) { - printf("ERROR: failed to open a file for write\n"); - return -1; - } - if ((len = write(fd, write_msg, strlen(write_msg))) <= 0) { - printf("ERROR: failed to write to the file\n"); - return -1; - } - - // lseek - if ((offset = lseek(fd, 0, SEEK_END)) != 12) { - printf("ERROR: failed to lseek the file\n"); - return -1; - } - close(fd); - - // read - flags = O_RDONLY; - if ((fd = open(file_name, flags)) < 0) { - printf("ERROR: failed to open a file for read\n"); - return -1; - } - if ((len = read(fd, read_buf, sizeof(read_buf) - 1)) <= 0) { - printf("ERROR: failed to read from the file\n"); - return -1; +static int create_file(const char *file_path) { + int fd; + int flags = O_RDONLY | O_CREAT| O_TRUNC; + int mode = 00666; + fd = open(file_path, flags, mode); + if (fd < 0) { + THROW_ERROR("failed to create a file"); } close(fd); - - if (strcmp(write_msg, read_buf) != 0) { - printf("ERROR: the message read from the file is not as it was written\n"); - return -1; - } - - // writev - flags = O_WRONLY | O_CREAT| O_TRUNC; - if ((fd = open(file_name, flags)) < 0) { - printf("ERROR: failed to open a file for write\n"); - return -1; - } - - const char* iov_msg[2] = {"hello_", "world!"}; - struct iovec iov[2]; - for(int i=0; i<2; ++i) { - iov[i].iov_base = (void*)iov_msg[i]; - iov[i].iov_len = strlen(iov_msg[i]); - } - if ((len = writev(fd, iov, 2)) != 12) { - printf("ERROR: failed to write vectors to the file\n"); - return -1; - } - - // pwrite - if ((len = pwrite(fd, " ", 1, 5)) != 1) { - printf("ERROR: failed to pwrite to the file\n"); - } - - close(fd); - - flags = O_RDONLY; - if ((fd = open(file_name, flags)) < 0) { - printf("ERROR: failed to open a file for read\n"); - return -1; - } - - // lseek - if ((offset = lseek(fd, 2, SEEK_SET)) != 2) { - printf("ERROR: failed to lseek the file\n"); - return -1; - } - - // readv - iov[0].iov_base = read_buf; - iov[0].iov_len = 3; - iov[1].iov_base = read_buf + 5; - iov[1].iov_len = 20; - if ((len = readv(fd, iov, 2)) != 10) { - printf("ERROR: failed to read vectors from the file\n"); - return -1; - } - - if (memcmp(read_buf, "llo", 3) != 0 - || memcmp(read_buf + 5, " world!", 7) != 0) { - printf("ERROR: the message read from the file is not as it was written\n"); - return -1; - } - - // pread - if ((len = pread(fd, read_buf, sizeof(read_buf) - 1, 4)) != 8) { - printf("ERROR: failed to pread from the file\n"); - } - if (memcmp(read_buf, "o world!", 8) != 0) { - printf("ERROR: the message read from the file is not as it was written\n"); - return -1; - } - close(fd); - - printf("File write and read successfully\n"); return 0; } + +static int remove_file(const char *file_path) { + int ret; + ret = unlink(file_path); + if (ret < 0) { + THROW_ERROR("failed to unlink the created file"); + } + return 0; +} + +// ============================================================================ +// Test cases for file +// ============================================================================ + +static int __test_write_read(const char *file_path) { + char *write_str = "Hello World\n"; + char read_buf[128] = { 0 }; + int fd; + + fd = open(file_path, O_WRONLY); + if (fd < 0) { + THROW_ERROR("failed to open a file to write"); + } + if (write(fd, write_str, strlen(write_str)) <= 0) { + THROW_ERROR("failed to write"); + } + close(fd); + fd = open(file_path, O_RDONLY); + if (fd < 0) { + THROW_ERROR("failed to open a file to read"); + } + if (read(fd, read_buf, sizeof(read_buf)) != strlen(write_str)) { + THROW_ERROR("failed to read"); + } + if (strcmp(write_str, read_buf) != 0) { + THROW_ERROR("the message read from the file is not as it was written"); + } + close(fd); + return 0; +} + +static int __test_pwrite_pread(const char *file_path) { + char *write_str = "Hello World\n"; + char read_buf[128] = { 0 }; + int ret, fd; + + fd = open(file_path, O_WRONLY); + if (fd < 0) { + THROW_ERROR("failed to open a file to pwrite"); + } + if (pwrite(fd, write_str, strlen(write_str), 1) <= 0) { + THROW_ERROR("failed to pwrite"); + } + ret = pwrite(fd, write_str, strlen(write_str), -1); + if (ret >= 0 || errno != EINVAL) { + THROW_ERROR("check pwrite with negative offset fail"); + } + close(fd); + fd = open(file_path, O_RDONLY); + if (fd < 0) { + THROW_ERROR("failed to open a file to pread"); + } + if (pread(fd, read_buf, sizeof(read_buf), 1) != strlen(write_str)) { + THROW_ERROR("failed to pread"); + } + if (strcmp(write_str, read_buf) != 0) { + THROW_ERROR("the message read from the file is not as it was written"); + } + ret = pread(fd, write_str, strlen(write_str), -1); + if (ret >= 0 || errno != EINVAL) { + THROW_ERROR("check pread with negative offset fail"); + } + close(fd); + return 0; +} + +static int __test_writev_readv(const char *file_path) { + const char* iov_msg[2] = {"hello_", "world!"}; + char read_buf[128] = { 0 }; + struct iovec iov[2]; + int fd, len = 0; + + fd = open(file_path, O_WRONLY); + if (fd < 0) { + THROW_ERROR("failed to open a file to writev"); + } + for(int i = 0; i < 2; ++i) { + iov[i].iov_base = (void*)iov_msg[i]; + iov[i].iov_len = strlen(iov_msg[i]); + len += iov[i].iov_len; + } + if (writev(fd, iov, 2) != len) { + THROW_ERROR("failed to write vectors to the file"); + return -1; + } + close(fd); + fd = open(file_path, O_RDONLY); + if (fd < 0) { + THROW_ERROR("failed to open a file to readv"); + } + iov[0].iov_base = read_buf; + iov[0].iov_len = strlen(iov_msg[0]); + iov[1].iov_base = read_buf + strlen(iov_msg[0]); + iov[1].iov_len = strlen(iov_msg[1]); + if (readv(fd, iov, 2) != len) { + THROW_ERROR("failed to read vectors from the file"); + } + if (memcmp(read_buf, iov_msg[0], strlen(iov_msg[0])) != 0 || + memcmp(read_buf + strlen(iov_msg[0]), iov_msg[1], strlen(iov_msg[1])) != 0) { + THROW_ERROR("the message read from the file is not as it was written"); + } + close(fd); + return 0; +} + +static int __test_lseek(const char *file_path) { + char *write_str = "Hello World\n"; + char read_buf[128] = { 0 }; + int fd, offset, ret; + + fd = open(file_path, O_RDWR); + if (fd < 0) { + THROW_ERROR("failed to open a file to read/write"); + } + if (write(fd, write_str, strlen(write_str)) <= 0) { + THROW_ERROR("failed to write"); + } + /* make sure offset is in range (0, strlen(write_str)) */ + offset = 2; + if (lseek(fd, offset, SEEK_SET) != offset) { + THROW_ERROR("failed to lseek the file"); + } + if (read(fd, read_buf, sizeof(read_buf)) >= strlen(write_str)) { + THROW_ERROR("failed to read from offset"); + } + if (strcmp(write_str + offset, read_buf) != 0) { + THROW_ERROR("the message read from the offset is wrong"); + } + offset = -1; + ret = lseek(fd, offset, SEEK_SET); + if (ret >= 0 || errno != EINVAL) { + THROW_ERROR("check lseek with negative offset fail"); + } + if (lseek(fd, 0, SEEK_END) != strlen(write_str)) { + THROW_ERROR("faild to lseek to the end of the file"); + } + close(fd); + return 0; +} + +typedef int(*test_file_func_t)(const char *); + +static int test_file_framework(test_file_func_t fn) { + const char *file_path = "/root/test_filesystem_file_read_write.txt"; + + if (create_file(file_path) < 0) + return -1; + if (fn(file_path) < 0) + return -1; + if (remove_file(file_path) < 0) + return -1; + return 0; +} + +static int test_write_read() { + return test_file_framework(__test_write_read); +} + +static int test_pwrite_pread() { + return test_file_framework(__test_pwrite_pread); +} + +static int test_writev_readv() { + return test_file_framework(__test_writev_readv); +} + +static int test_lseek() { + return test_file_framework(__test_lseek); +} + +// ============================================================================ +// Test suite main +// ============================================================================ + +static test_case_t test_cases[] = { + TEST_CASE(test_write_read), + TEST_CASE(test_pwrite_pread), + TEST_CASE(test_writev_readv), + TEST_CASE(test_lseek), +}; + +int main(int argc, const char *argv[]) { + return test_suite_run(test_cases, ARRAY_SIZE(test_cases)); +}