diff --git a/src/libos/src/file.rs b/src/libos/src/file.rs index f653c0c6..367f114f 100644 --- a/src/libos/src/file.rs +++ b/src/libos/src/file.rs @@ -8,8 +8,9 @@ use std::io::{Read, Write, Seek, SeekFrom}; pub trait File : Debug + Sync + Send { fn read(&self, buf: &mut [u8]) -> Result; fn write(&self, buf: &[u8]) -> Result; + fn readv<'a, 'b>(&self, bufs: &'a mut [&'b mut [u8]]) -> Result; + fn writev<'a, 'b>(&self, bufs: &'a [&'b [u8]]) -> Result; //pub seek(&mut self, ) -> Result; - } pub type FileRef = Arc>; @@ -48,6 +49,18 @@ impl File for SgxFile { let inner = inner_guard.borrow_mut(); inner.write(buf) } + + fn readv<'a, 'b>(&self, bufs: &'a mut [&'b mut [u8]]) -> Result { + let mut inner_guard = self.inner.lock().unwrap(); + let inner = inner_guard.borrow_mut(); + inner.readv(bufs) + } + + fn writev<'a, 'b>(&self, bufs: &'a [&'b [u8]]) -> Result { + let mut inner_guard = self.inner.lock().unwrap(); + let inner = inner_guard.borrow_mut(); + inner.writev(bufs) + } } #[derive(Clone)] @@ -67,6 +80,7 @@ impl SgxFileInner { let file = file_guard.borrow_mut(); let seek_pos = SeekFrom::Start(self.pos as u64); + // TODO: recover from error file.seek(seek_pos).map_err( |e| Error::new(Errno::EINVAL, "Failed to seek to a position"))?; @@ -95,6 +109,70 @@ impl SgxFileInner { self.pos += read_len; Ok(read_len) } + + pub fn writev<'a, 'b>(&mut self, bufs: &'a [&'b [u8]]) -> Result { + let mut file_guard = self.file.lock().unwrap(); + let file = file_guard.borrow_mut(); + + let seek_pos = SeekFrom::Start(self.pos as u64); + file.seek(seek_pos).map_err( + |e| Error::new(Errno::EINVAL, "Failed to seek to a position"))?; + + let mut total_bytes = 0; + for buf in bufs { + match file.write(buf) { + Ok(this_bytes) => { + if this_bytes == 0 { break; } + + total_bytes += this_bytes; + if this_bytes < buf.len() { break; } + } + Err(e) => { + match total_bytes { + // a complete failure + 0 => return Err(Error::new(Errno::EINVAL, "Failed to write")), + // a partially failure + _ => break, + } + } + } + } + + self.pos += total_bytes; + Ok(total_bytes) + } + + fn readv<'a, 'b>(&mut self, bufs: &'a mut [&'b mut [u8]]) -> Result { + let mut file_guard = self.file.lock().unwrap(); + let file = file_guard.borrow_mut(); + + let seek_pos = SeekFrom::Start(self.pos as u64); + file.seek(seek_pos).map_err( + |e| Error::new(Errno::EINVAL, "Failed to seek to a position"))?; + + let mut total_bytes = 0; + for buf in bufs { + match file.read(buf) { + Ok(this_bytes) => { + if this_bytes == 0 { break; } + + total_bytes += this_bytes; + if this_bytes < buf.len() { break; } + } + Err(e) => { + match total_bytes { + // a complete failure + 0 => return Err(Error::new(Errno::EINVAL, "Failed to write")), + // a partially failure + _ => break, + } + } + } + } + + self.pos += total_bytes; + Ok(total_bytes) + } } unsafe impl Send for SgxFileInner {} @@ -130,6 +208,33 @@ impl File for StdoutFile { fn read(&self, buf: &mut [u8]) -> Result { Err(Error::new(Errno::EBADF, "Stdout does not support reading")) } + + fn writev<'a, 'b>(&self, bufs: &'a [&'b [u8]]) -> Result { + let mut guard = self.inner.lock(); + let mut total_bytes = 0; + for buf in bufs { + match guard.write(buf) { + Ok(this_len) => { + if this_len == 0 { break; } + total_bytes += this_len; + if this_len < buf.len() { break; } + } + Err(e) => { + match total_bytes { + // a complete failure + 0 => return Err(Error::new(Errno::EINVAL, "Failed to write")), + // a partially failure + _ => break, + } + } + } + } + Ok(total_bytes) + } + + fn readv<'a, 'b>(&self, bufs: &'a mut [&'b mut [u8]]) -> Result { + Err(Error::new(Errno::EBADF, "Stdout does not support reading")) + } } impl Debug for StdoutFile { @@ -165,6 +270,33 @@ impl File for StdinFile { fn write(&self, buf: &[u8]) -> Result { Err(Error::new(Errno::EBADF, "Stdin does not support reading")) } + + fn readv<'a, 'b>(&self, bufs: &'a mut [&'b mut [u8]]) -> Result { + let mut guard = self.inner.lock(); + let mut total_bytes = 0; + for buf in bufs { + match guard.read(buf) { + Ok(this_len) => { + if this_len == 0 { break; } + total_bytes += this_len; + if this_len < buf.len() { break; } + } + Err(e) => { + match total_bytes { + // a complete failure + 0 => return Err(Error::new(Errno::EINVAL, "Failed to write")), + // a partially failure + _ => break, + } + } + } + } + Ok(total_bytes) + } + + fn writev<'a, 'b>(&self, bufs: &'a [&'b [u8]]) -> Result { + Err(Error::new(Errno::EBADF, "Stdin does not support reading")) + } } impl Debug for StdinFile { diff --git a/src/libos/src/fs.rs b/src/libos/src/fs.rs index a40f8b0b..18c60b16 100644 --- a/src/libos/src/fs.rs +++ b/src/libos/src/fs.rs @@ -63,6 +63,22 @@ pub fn do_read(fd: FileDesc, buf: &mut [u8]) -> Result { file_ref.read(buf) } +pub fn do_writev<'a, 'b>(fd: FileDesc, bufs: &'a [&'b [u8]]) -> Result { + let current_ref = process::get_current(); + let current_process = current_ref.lock().unwrap(); + let file_ref = current_process.file_table.get(fd) + .ok_or_else(|| Error::new(Errno::EBADF, "Invalid file descriptor [do_write]"))?; + file_ref.writev(bufs) +} + +pub fn do_readv<'a, 'b>(fd: FileDesc, bufs: &'a mut [&'b mut [u8]]) -> Result { + let current_ref = process::get_current(); + let current_process = current_ref.lock().unwrap(); + let file_ref = current_process.file_table.get(fd) + .ok_or_else(|| Error::new(Errno::EBADF, "Invalid file descriptor [do_read]"))?; + file_ref.readv(bufs) +} + pub fn do_close(fd: FileDesc) -> Result<(), Error> { let current_ref = process::get_current(); let mut current_process = current_ref.lock().unwrap(); diff --git a/src/libos/src/process.rs b/src/libos/src/process.rs index 44a9d983..43666822 100644 --- a/src/libos/src/process.rs +++ b/src/libos/src/process.rs @@ -77,6 +77,7 @@ pub fn do_spawn>(elf_path: &P, argv: &[CString], envp: &[CString] //let stdin = Arc::new(SgxMutex::new(Box::new(StdinFile::new()))); let stdin : Arc> = Arc::new(Box::new(StdinFile::new())); let stdout : Arc> = Arc::new(Box::new(StdoutFile::new())); + // TODO: implement and use a real stderr let stderr = stdout.clone(); file_table.put(stdin); file_table.put(stdout); diff --git a/src/libos/src/syscall.h b/src/libos/src/syscall.h index 494901c0..7a240a1e 100644 --- a/src/libos/src/syscall.h +++ b/src/libos/src/syscall.h @@ -4,6 +4,8 @@ #include #include "syscall_nr.h" +struct iovec; + #ifdef __cplusplus extern "C" { #endif @@ -12,6 +14,8 @@ extern int occlum_open(const char* path, int flags, int mode); extern int occlum_close(int fd); extern ssize_t occlum_read(int fd, void* buf, size_t size); extern ssize_t occlum_write(int fd, const void* buf, size_t size); +extern ssize_t occlum_readv(int fd, struct iovec* iov, int count); +extern ssize_t occlum_writev(int fd, const struct iovec* iov, int count); extern int occlum_spawn(int* child_pid, const char* path, const char** argv, const char** envp); diff --git a/src/libos/src/syscall.rs b/src/libos/src/syscall.rs index 22cd49ad..ea94cead 100644 --- a/src/libos/src/syscall.rs +++ b/src/libos/src/syscall.rs @@ -5,18 +5,33 @@ use std::ffi::{CStr, CString}; //use std::libc_fs as fs; //use std::libc_io as io; -fn check_ptr_from_user(user_ptr: *const T) -> Result<*const T, Error> { - Ok(user_ptr) +#[allow(non_camel_case_types)] +pub struct iovec_t { + base: *const c_void, + len: size_t, } -fn check_mut_ptr_from_user(user_ptr: *mut T) -> Result<*mut T, Error> { - Ok(user_ptr) + +fn check_ptr_from_user(user_ptr: *const T) -> Result<(), Error> { + Ok(()) +} + +fn check_mut_ptr_from_user(user_ptr: *mut T) -> Result<(), Error> { + Ok(()) +} + +fn check_array_from_user(user_buf: *const T, count: usize) -> Result<(), Error> { + Ok(()) +} + +fn check_mut_array_from_user(user_buf: *mut T, count: usize) -> Result<(), Error> { + Ok(()) } fn clone_string_from_user_safely(user_ptr: *const c_char) -> Result { - let user_ptr = check_ptr_from_user(user_ptr)?; + check_ptr_from_user(user_ptr)?; let string = unsafe { CStr::from_ptr(user_ptr).to_string_lossy().into_owned() }; @@ -31,6 +46,93 @@ fn clone_cstrings_from_user_safely(user_ptr: *const *const c_char) } +fn do_read(fd: c_int, buf: *mut c_void, size: size_t) + -> Result +{ + let fd = fd as file_table::FileDesc; + let safe_buf = { + let buf = buf as *mut u8; + let size = size as usize; + check_mut_array_from_user(buf, size)?; + unsafe { std::slice::from_raw_parts_mut(buf, size) } + }; + fs::do_read(fd, safe_buf) +} + +fn do_write(fd: c_int, buf: *const c_void, size: size_t) + -> Result +{ + let fd = fd as file_table::FileDesc; + let safe_buf = { + let buf = buf as *mut u8; + let size = size as usize; + check_array_from_user(buf, size)?; + unsafe { std::slice::from_raw_parts(buf, size) } + }; + fs::do_write(fd, safe_buf) +} + +fn do_writev(fd: c_int, iov: *const iovec_t, count: c_int) + -> Result +{ + let fd = fd as file_table::FileDesc; + + let count = { + if count < 0 { + return Err(Error::new(Errno::EINVAL, "Invalid count of iovec")); + } + count as usize + }; + + check_array_from_user(iov, count); + let bufs_vec = { + let mut bufs_vec = Vec::with_capacity(count); + for iov_i in 0..count { + let iov_ptr = unsafe { iov.offset(iov_i as isize) }; + let iov = unsafe { &*iov_ptr }; + let buf = unsafe { + std::slice::from_raw_parts(iov.base as * const u8, iov.len) + }; + bufs_vec[iov_i] = buf; + } + bufs_vec + }; + let bufs = &bufs_vec[..]; + + fs::do_writev(fd, bufs) +} + +fn do_readv(fd: c_int, iov: *mut iovec_t, count: c_int) + -> Result +{ + let fd = fd as file_table::FileDesc; + + let count = { + if count < 0 { + return Err(Error::new(Errno::EINVAL, "Invalid count of iovec")); + } + count as usize + }; + + check_array_from_user(iov, count); + let mut bufs_vec = { + let mut bufs_vec = Vec::with_capacity(count); + for iov_i in 0..count { + let iov_ptr = unsafe { iov.offset(iov_i as isize) }; + let iov = unsafe { &*iov_ptr }; + let buf = unsafe { + std::slice::from_raw_parts_mut(iov.base as * mut u8, iov.len) + }; + bufs_vec[iov_i] = buf; + } + bufs_vec + }; + let bufs = &mut bufs_vec[..]; + + fs::do_readv(fd, bufs) +} + + #[no_mangle] pub extern "C" fn occlum_open(path_buf: * const c_char, flags: c_int, mode: c_int) -> c_int { let path = unsafe { @@ -60,10 +162,7 @@ pub extern "C" fn occlum_close(fd: c_int) -> c_int { #[no_mangle] pub extern "C" fn occlum_read(fd: c_int, buf: * mut c_void, size: size_t) -> ssize_t { - let buf = unsafe { - std::slice::from_raw_parts_mut(buf as *mut u8, size as usize) - }; - match fs::do_read(fd as file_table::FileDesc, buf) { + match do_read(fd, buf, size) { Ok(read_len) => { read_len as ssize_t }, @@ -75,16 +174,7 @@ pub extern "C" fn occlum_read(fd: c_int, buf: * mut c_void, size: size_t) -> ssi #[no_mangle] pub extern "C" fn occlum_write(fd: c_int, buf: * const c_void, size: size_t) -> ssize_t { -/* let str_from_c = unsafe { - CStr::from_ptr(buf as * const i8).to_string_lossy().into_owned() - }; - println!("occlum_write: {}", str_from_c); - size as ssize_t -*/ - let buf = unsafe { - std::slice::from_raw_parts(buf as *const u8, size as usize) - }; - match fs::do_write(fd as file_table::FileDesc, buf) { + match do_write(fd, buf, size) { Ok(write_len) => { write_len as ssize_t }, @@ -94,6 +184,31 @@ pub extern "C" fn occlum_write(fd: c_int, buf: * const c_void, size: size_t) -> } } +#[no_mangle] +pub extern "C" fn occlum_readv(fd: c_int, iov: * mut iovec_t, count: c_int) -> ssize_t { + match do_readv(fd, iov, count) { + Ok(read_len) => { + read_len as ssize_t + }, + Err(e) => { + e.errno.as_retval() as ssize_t + } + } +} + +#[no_mangle] +pub extern "C" fn occlum_writev(fd: c_int, iov: * const iovec_t, count: c_int) -> ssize_t { + match do_writev(fd, iov, count) { + Ok(write_len) => { + write_len as ssize_t + }, + Err(e) => { + e.errno.as_retval() as ssize_t + } + } +} + + #[no_mangle] pub extern "C" fn occlum_getpid() -> c_uint { @@ -118,7 +233,7 @@ fn do_spawn(child_pid_ptr: *mut c_uint, envp: *const *const c_char) -> Result<(), Error> { - let child_pid_ptr = check_mut_ptr_from_user(child_pid_ptr)?; + check_mut_ptr_from_user(child_pid_ptr)?; let path = clone_string_from_user_safely(path)?; let argv = clone_cstrings_from_user_safely(argv)?; let envp = clone_cstrings_from_user_safely(envp)?; diff --git a/src/libos/src/syscall_entry.c b/src/libos/src/syscall_entry.c index aef447ef..3efe20ef 100644 --- a/src/libos/src/syscall_entry.c +++ b/src/libos/src/syscall_entry.c @@ -40,6 +40,20 @@ long dispatch_syscall(int num, long arg0, long arg1, long arg2, long arg3, long ret = occlum_read(fd, buf, buf_size); break; } + case SYS_writev: { + DECL_SYSCALL_ARG(int, fd, arg0); + DECL_SYSCALL_ARG(const struct iovec*, iov, arg1); + DECL_SYSCALL_ARG(int, count, arg2); + ret = occlum_writev(fd, iov, count); + break; + } + case SYS_readv: { + DECL_SYSCALL_ARG(int, fd, arg0); + DECL_SYSCALL_ARG(struct iovec*, iov, arg1); + DECL_SYSCALL_ARG(int, count, arg2); + ret = occlum_readv(fd, iov, count); + break; + } case SYS_spawn: { DECL_SYSCALL_ARG(int*, child_pid, arg0); DECL_SYSCALL_ARG(const char*, path, arg1);