From 3b915db7747d8f30bc0a7020c93cbe6cae2282f0 Mon Sep 17 00:00:00 2001 From: He Sun Date: Thu, 29 Oct 2020 15:59:49 +0800 Subject: [PATCH] Refactor Unix socket 1. Implement type-safe functions; 2. Improve the correctness of nearly all the functions; 3. Improve the readability by introducing Listener and Endpoint for StreamUnix; 4. Substitue RingBuf with Channel in Unix socket. --- src/libos/Cargo.lock | 10 + src/libos/Cargo.toml | 1 + src/libos/src/fs/channel.rs | 64 ++- src/libos/src/fs/pipe.rs | 44 +- src/libos/src/lib.rs | 3 + src/libos/src/net/mod.rs | 6 +- .../{host_socket => host}/ioctl_impl.rs | 0 .../net/socket/{host_socket => host}/mod.rs | 5 + .../net/socket/{host_socket => host}/recv.rs | 0 .../net/socket/{host_socket => host}/send.rs | 0 .../{host_socket => host}/socket_file.rs | 0 src/libos/src/net/socket/mod.rs | 10 +- src/libos/src/net/socket/shutdown.rs | 28 ++ src/libos/src/net/socket/unix/addr.rs | 151 ++++++ src/libos/src/net/socket/unix/mod.rs | 49 ++ .../net/socket/unix/stream/address_space.rs | 93 ++++ .../src/net/socket/unix/stream/endpoint.rs | 110 +++++ src/libos/src/net/socket/unix/stream/file.rs | 95 ++++ src/libos/src/net/socket/unix/stream/mod.rs | 8 + .../src/net/socket/unix/stream/stream.rs | 325 +++++++++++++ src/libos/src/net/socket/unix_socket.rs | 390 ---------------- src/libos/src/net/syscalls.rs | 247 ++++++---- src/libos/src/util/mod.rs | 1 - src/libos/src/util/ring_buf.rs | 428 ------------------ test/unix_socket/main.c | 35 ++ 25 files changed, 1138 insertions(+), 965 deletions(-) rename src/libos/src/net/socket/{host_socket => host}/ioctl_impl.rs (100%) rename src/libos/src/net/socket/{host_socket => host}/mod.rs (96%) rename src/libos/src/net/socket/{host_socket => host}/recv.rs (100%) rename src/libos/src/net/socket/{host_socket => host}/send.rs (100%) rename src/libos/src/net/socket/{host_socket => host}/socket_file.rs (100%) create mode 100644 src/libos/src/net/socket/shutdown.rs create mode 100644 src/libos/src/net/socket/unix/addr.rs create mode 100644 src/libos/src/net/socket/unix/mod.rs create mode 100644 src/libos/src/net/socket/unix/stream/address_space.rs create mode 100644 src/libos/src/net/socket/unix/stream/endpoint.rs create mode 100644 src/libos/src/net/socket/unix/stream/file.rs create mode 100644 src/libos/src/net/socket/unix/stream/mod.rs create mode 100644 src/libos/src/net/socket/unix/stream/stream.rs delete mode 100644 src/libos/src/net/socket/unix_socket.rs delete mode 100644 src/libos/src/util/ring_buf.rs diff --git a/src/libos/Cargo.lock b/src/libos/Cargo.lock index 07900c3b..934440c3 100644 --- a/src/libos/Cargo.lock +++ b/src/libos/Cargo.lock @@ -11,6 +11,7 @@ dependencies = [ "derive_builder", "lazy_static", "log", + "memoffset", "rcore-fs", "rcore-fs-mountfs", "rcore-fs-ramfs", @@ -242,6 +243,15 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "memoffset" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "157b4208e3059a8f9e78d559edc658e13df41410cb3ae03979c83130067fdd87" +dependencies = [ + "autocfg 1.0.1", +] + [[package]] name = "proc-macro2" version = "1.0.19" diff --git a/src/libos/Cargo.toml b/src/libos/Cargo.toml index 9c5c62a1..4675fe27 100644 --- a/src/libos/Cargo.toml +++ b/src/libos/Cargo.toml @@ -23,6 +23,7 @@ rcore-fs-mountfs = { path = "../../deps/sefs/rcore-fs-mountfs" } rcore-fs-unionfs = { path = "../../deps/sefs/rcore-fs-unionfs" } serde = { path = "../../deps/serde-sgx/serde", features = ["derive"] } serde_json = { path = "../../deps/serde-json-sgx" } +memoffset = "0.6.1" [patch.'https://github.com/apache/teaclave-sgx-sdk.git'] sgx_tstd = { path = "../../deps/rust-sgx-sdk/sgx_tstd" } diff --git a/src/libos/src/fs/channel.rs b/src/libos/src/fs/channel.rs index 2ab9b737..f84bee85 100644 --- a/src/libos/src/fs/channel.rs +++ b/src/libos/src/fs/channel.rs @@ -76,6 +76,20 @@ impl Channel { let Channel { producer, consumer } = self; (producer, consumer) } + + pub fn items_to_consume(&self) -> usize { + self.consumer.items_to_consume() + } + + pub fn set_nonblocking(&self, nonblocking: bool) { + self.consumer.set_nonblocking(nonblocking); + self.producer.set_nonblocking(nonblocking); + } + + pub fn shutdown(&self) { + self.consumer.shutdown(); + self.producer.shutdown(); + } } impl Channel { @@ -264,7 +278,12 @@ impl Producer { impl Producer { pub fn push_slice(&self, items: &[I]) -> Result { - if items.len() == 0 { + self.push_slices(&[items]) + } + + pub fn push_slices(&self, item_slices: &[&[I]]) -> Result { + let len: usize = item_slices.iter().map(|slice| slice.len()).sum(); + if len == 0 { return Ok(0); } @@ -275,11 +294,21 @@ impl Producer { return_errno!(EPIPE, "one or both endpoints have been shutdown"); } - let count = rb_producer.push_slice(items); - if count > 0 { + let mut total_count = 0; + for items in item_slices { + let count = rb_producer.push_slice(items); + total_count += count; + if count < items.len() { + break; + } else { + continue; + } + } + + if total_count > 0 { drop(rb_producer); self.trigger_peer_events(&IoEvents::IN); - return Ok(count); + return Ok(total_count); } if self.is_nonblocking() { @@ -374,11 +403,20 @@ impl Consumer { pub fn is_peer_shutdown(&self) -> bool { self.state.is_producer_shutdown() } + + pub fn items_to_consume(&self) -> usize { + self.inner.lock().unwrap().len() + } } impl Consumer { pub fn pop_slice(&self, items: &mut [I]) -> Result { - if items.len() == 0 { + self.pop_slices(&mut [items]) + } + + pub fn pop_slices(&self, item_slices: &mut [&mut [I]]) -> Result { + let len: usize = item_slices.iter().map(|slice| slice.len()).sum(); + if len == 0 { return Ok(0); } @@ -389,11 +427,21 @@ impl Consumer { return_errno!(EPIPE, "this endpoint has been shutdown"); } - let count = rb_consumer.pop_slice(items); - if count > 0 { + let mut total_count = 0; + for items in item_slices.iter_mut() { + let count = rb_consumer.pop_slice(items); + total_count += count; + if count < items.len() { + break; + } else { + continue; + } + } + + if total_count > 0 { drop(rb_consumer); self.trigger_peer_events(&IoEvents::OUT); - return Ok(count); + return Ok(total_count); }; if self.is_peer_shutdown() { diff --git a/src/libos/src/fs/pipe.rs b/src/libos/src/fs/pipe.rs index 769b81dc..4a451201 100644 --- a/src/libos/src/fs/pipe.rs +++ b/src/libos/src/fs/pipe.rs @@ -41,27 +41,7 @@ impl File for PipeReader { } fn readv(&self, bufs: &mut [&mut [u8]]) -> Result { - let mut total_count = 0; - for buf in bufs { - match self.consumer.pop_slice(buf) { - Ok(count) => { - total_count += count; - if count < buf.len() { - break; - } else { - continue; - } - } - Err(e) => { - if total_count > 0 { - break; - } else { - return Err(e); - } - } - } - } - Ok(total_count) + self.consumer.pop_slices(bufs) } fn get_access_mode(&self) -> Result { @@ -120,27 +100,7 @@ impl File for PipeWriter { } fn writev(&self, bufs: &[&[u8]]) -> Result { - let mut total_count = 0; - for buf in bufs { - match self.producer.push_slice(buf) { - Ok(count) => { - total_count += count; - if count < buf.len() { - break; - } else { - continue; - } - } - Err(e) => { - if total_count > 0 { - break; - } else { - return Err(e); - } - } - } - } - Ok(total_count) + self.producer.push_slices(bufs) } fn seek(&self, pos: SeekFrom) -> Result { diff --git a/src/libos/src/lib.rs b/src/libos/src/lib.rs index d6fa9813..5a2b6a03 100644 --- a/src/libos/src/lib.rs +++ b/src/libos/src/lib.rs @@ -17,6 +17,7 @@ // for UntrustedSliceAlloc in slice_alloc #![feature(slice_ptr_get)] #![feature(maybe_uninit_extra)] +#![feature(get_mut_unchecked)] #[macro_use] extern crate alloc; @@ -46,6 +47,8 @@ extern crate derive_builder; extern crate ringbuf; extern crate serde; extern crate serde_json; +#[macro_use] +extern crate memoffset; use sgx_trts::libc; use sgx_types::*; diff --git a/src/libos/src/net/mod.rs b/src/libos/src/net/mod.rs index 257d8814..553ef7d2 100644 --- a/src/libos/src/net/mod.rs +++ b/src/libos/src/net/mod.rs @@ -7,9 +7,9 @@ pub use self::io_multiplexing::{ PollEventFlags, PollFd, THREAD_NOTIFIERS, }; pub use self::socket::{ - msghdr, msghdr_mut, AddressFamily, AsUnixSocket, FileFlags, HostSocket, HostSocketType, Iovs, - IovsMut, MsgHdr, MsgHdrFlags, MsgHdrMut, RecvFlags, SendFlags, SliceAsLibcIovec, SockAddr, - SocketType, UnixSocketFile, + msghdr, msghdr_mut, socketpair, unix_socket, AddressFamily, AsUnixSocket, FileFlags, + HostSocket, HostSocketType, HowToShut, Iovs, IovsMut, MsgHdr, MsgHdrFlags, MsgHdrMut, + RecvFlags, SendFlags, SliceAsLibcIovec, SockAddr, SocketType, UnixAddr, }; pub use self::syscalls::*; diff --git a/src/libos/src/net/socket/host_socket/ioctl_impl.rs b/src/libos/src/net/socket/host/ioctl_impl.rs similarity index 100% rename from src/libos/src/net/socket/host_socket/ioctl_impl.rs rename to src/libos/src/net/socket/host/ioctl_impl.rs diff --git a/src/libos/src/net/socket/host_socket/mod.rs b/src/libos/src/net/socket/host/mod.rs similarity index 96% rename from src/libos/src/net/socket/host_socket/mod.rs rename to src/libos/src/net/socket/host/mod.rs index e2dd9e91..fe85e2d8 100644 --- a/src/libos/src/net/socket/host_socket/mod.rs +++ b/src/libos/src/net/socket/host/mod.rs @@ -132,6 +132,11 @@ impl HostSocket { pub fn raw_host_fd(&self) -> FileDesc { self.host_fd.to_raw() } + + pub fn shutdown(&self, how: HowToShut) -> Result<()> { + try_libc!(libc::ocall::shutdown(self.raw_host_fd() as i32, how.bits())); + Ok(()) + } } pub trait HostSocketType { diff --git a/src/libos/src/net/socket/host_socket/recv.rs b/src/libos/src/net/socket/host/recv.rs similarity index 100% rename from src/libos/src/net/socket/host_socket/recv.rs rename to src/libos/src/net/socket/host/recv.rs diff --git a/src/libos/src/net/socket/host_socket/send.rs b/src/libos/src/net/socket/host/send.rs similarity index 100% rename from src/libos/src/net/socket/host_socket/send.rs rename to src/libos/src/net/socket/host/send.rs diff --git a/src/libos/src/net/socket/host_socket/socket_file.rs b/src/libos/src/net/socket/host/socket_file.rs similarity index 100% rename from src/libos/src/net/socket/host_socket/socket_file.rs rename to src/libos/src/net/socket/host/socket_file.rs diff --git a/src/libos/src/net/socket/mod.rs b/src/libos/src/net/socket/mod.rs index dab297ef..7fb632e4 100644 --- a/src/libos/src/net/socket/mod.rs +++ b/src/libos/src/net/socket/mod.rs @@ -2,18 +2,20 @@ use super::*; mod address_family; mod flags; -mod host_socket; +mod host; mod iovs; mod msg; +mod shutdown; mod socket_address; mod socket_type; -mod unix_socket; +mod unix; pub use self::address_family::AddressFamily; pub use self::flags::{FileFlags, MsgHdrFlags, RecvFlags, SendFlags}; -pub use self::host_socket::{HostSocket, HostSocketType}; +pub use self::host::{HostSocket, HostSocketType}; pub use self::iovs::{Iovs, IovsMut, SliceAsLibcIovec}; pub use self::msg::{msghdr, msghdr_mut, MsgHdr, MsgHdrMut}; +pub use self::shutdown::HowToShut; pub use self::socket_address::SockAddr; pub use self::socket_type::SocketType; -pub use self::unix_socket::{AsUnixSocket, UnixSocketFile}; +pub use self::unix::{socketpair, unix_socket, AsUnixSocket, UnixAddr}; diff --git a/src/libos/src/net/socket/shutdown.rs b/src/libos/src/net/socket/shutdown.rs new file mode 100644 index 00000000..8b07abb8 --- /dev/null +++ b/src/libos/src/net/socket/shutdown.rs @@ -0,0 +1,28 @@ +use super::*; + +bitflags! { + pub struct HowToShut: c_int { + const READ = 0; + const WRITE = 1; + const BOTH = 2; + } +} + +impl HowToShut { + pub fn try_from_raw(how: c_int) -> Result { + match how { + 0 => Ok(Self::READ), + 1 => Ok(Self::WRITE), + 2 => Ok(Self::BOTH), + _ => return_errno!(EINVAL, "invalid how"), + } + } + + pub fn to_shut_read(&self) -> bool { + *self == Self::READ || *self == Self::BOTH + } + + pub fn to_shut_write(&self) -> bool { + *self == Self::WRITE || *self == Self::BOTH + } +} diff --git a/src/libos/src/net/socket/unix/addr.rs b/src/libos/src/net/socket/unix/addr.rs new file mode 100644 index 00000000..6f082e08 --- /dev/null +++ b/src/libos/src/net/socket/unix/addr.rs @@ -0,0 +1,151 @@ +use super::*; +use std::path::{Path, PathBuf}; +use std::{cmp, mem, slice, str}; + +const MAX_PATH_LEN: usize = 108; +const SUN_FAMILY_LEN: usize = mem::size_of::(); +lazy_static! { + static ref SUN_PATH_OFFSET: usize = memoffset::offset_of!(libc::sockaddr_un, sun_path); +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum Addr { + File(UnixPath), + Abstract(String), +} + +impl Addr { + /// Caller should guarentee the sockaddr and addr_len are valid. + /// The pathname should end with a '\0' within the passed length. + /// The abstract name should both start and end with a '\0' within the passed length. + pub unsafe fn try_from_raw( + sockaddr: *const libc::sockaddr, + addr_len: libc::socklen_t, + ) -> Result { + let addr_len = addr_len as usize; + + // TODO: support autobind to validate when addr_len == SUN_FAMILY_LEN + if addr_len <= SUN_FAMILY_LEN { + return_errno!(EINVAL, "the address is too short."); + } + + if addr_len > MAX_PATH_LEN + *SUN_PATH_OFFSET { + return_errno!(EINVAL, "the address is too long."); + } + + if AddressFamily::try_from((*sockaddr).sa_family)? != AddressFamily::LOCAL { + return_errno!(EINVAL, "not a valid address for unix socket"); + } + + let sockaddr = sockaddr as *const libc::sockaddr_un; + let sun_path = (*sockaddr).sun_path; + + if sun_path[0] == 0 { + let path_ptr = sun_path[1..(addr_len - *SUN_PATH_OFFSET)].as_ptr(); + let path_slice = + slice::from_raw_parts(path_ptr as *const u8, addr_len - *SUN_PATH_OFFSET - 1); + + Ok(Self::Abstract( + str::from_utf8(&path_slice).unwrap().to_string(), + )) + } else { + let path_cstr = CStr::from_ptr(sun_path.as_ptr()); + if path_cstr.to_bytes_with_nul().len() > MAX_PATH_LEN { + return_errno!(EINVAL, "no null in the address"); + } + + Ok(Self::File(UnixPath::new(path_cstr.to_str().unwrap()))) + } + } + + pub fn copy_to_slice(&self, dst: &mut [u8]) -> usize { + let (raw_addr, addr_len) = self.to_raw(); + let src = + unsafe { std::slice::from_raw_parts(&raw_addr as *const _ as *const u8, addr_len) }; + let copied = std::cmp::min(dst.len(), addr_len); + dst[..copied].copy_from_slice(&src[..copied]); + copied + } + + pub fn raw_len(&self) -> usize { + /// The '/0' at the end of Self::File counts + self.path_str().len() + + 1 + + *SUN_PATH_OFFSET + } + + pub fn path_str(&self) -> &str { + match self { + Self::File(unix_path) => &unix_path.path_str(), + Self::Abstract(path) => &path, + } + } + + fn to_raw(&self) -> (libc::sockaddr_un, usize) { + let mut addr: libc::sockaddr_un = unsafe { mem::zeroed() }; + addr.sun_family = AddressFamily::LOCAL as libc::sa_family_t; + + let addr_len = match self { + Self::File(unix_path) => { + let path_str = unix_path.path_str(); + let buf_len = path_str.len(); + /// addr is initialized to all zeros and try_from_raw guarentees + /// unix_path length is shorter than sun_path, so sun_path here + /// will always have a null terminator + addr.sun_path[..buf_len] + .copy_from_slice(unsafe { &*(path_str.as_bytes() as *const _ as *const [i8]) }); + buf_len + *SUN_PATH_OFFSET + 1 + } + Self::Abstract(path_str) => { + addr.sun_path[0] = 0; + let buf_len = path_str.len() + 1; + addr.sun_path[1..buf_len] + .copy_from_slice(unsafe { &*(path_str.as_bytes() as *const _ as *const [i8]) }); + buf_len + *SUN_PATH_OFFSET + } + }; + + (addr, addr_len) + } +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct UnixPath { + inner: PathBuf, + /// Holds the cwd when a relative path is created + cwd: Option, +} + +impl UnixPath { + pub fn new(path: &str) -> Self { + let inner = PathBuf::from(path); + let is_absolute = inner.is_absolute(); + Self { + inner: inner, + cwd: if is_absolute { + None + } else { + let thread = current!(); + let fs = thread.fs().lock().unwrap(); + let cwd = fs.cwd().to_owned(); + + Some(cwd) + }, + } + } + + pub fn absolute(&self) -> String { + let path_str = self.path_str(); + if self.inner.is_absolute() { + path_str.to_string() + } else { + let mut prefix = path_str.to_owned(); + prefix.push_str(self.cwd.as_ref().unwrap()); + prefix + } + } + + pub fn path_str(&self) -> &str { + self.inner.to_str().unwrap() + } +} diff --git a/src/libos/src/net/socket/unix/mod.rs b/src/libos/src/net/socket/unix/mod.rs new file mode 100644 index 00000000..525d9f39 --- /dev/null +++ b/src/libos/src/net/socket/unix/mod.rs @@ -0,0 +1,49 @@ +use self::addr::Addr; +use super::*; + +mod addr; +mod stream; + +pub use self::addr::Addr as UnixAddr; +pub use self::stream::Stream; + +//TODO: rewrite this file when a new kind of uds is added +pub fn unix_socket(socket_type: SocketType, flags: FileFlags, protocol: i32) -> Result { + if protocol != 0 && protocol != AddressFamily::LOCAL as i32 { + return_errno!(EPROTONOSUPPORT, "protocol is not supported"); + } + + if socket_type == SocketType::STREAM { + Ok(Stream::new(flags)) + } else { + return_errno!(ESOCKTNOSUPPORT, "only stream type is supported"); + } +} + +pub fn socketpair( + socket_type: SocketType, + flags: FileFlags, + protocol: i32, +) -> Result<(Stream, Stream)> { + if protocol != 0 && protocol != AddressFamily::LOCAL as i32 { + return_errno!(EPROTONOSUPPORT, "protocol is not supported"); + } + + if socket_type == SocketType::STREAM { + Stream::socketpair(flags) + } else { + return_errno!(ESOCKTNOSUPPORT, "only stream type is supported"); + } +} + +pub trait AsUnixSocket { + fn as_unix_socket(&self) -> Result<&Stream>; +} + +impl AsUnixSocket for FileRef { + fn as_unix_socket(&self) -> Result<&Stream> { + self.as_any() + .downcast_ref::() + .ok_or_else(|| errno!(EBADF, "not a unix socket")) + } +} diff --git a/src/libos/src/net/socket/unix/stream/address_space.rs b/src/libos/src/net/socket/unix/stream/address_space.rs new file mode 100644 index 00000000..c9840e07 --- /dev/null +++ b/src/libos/src/net/socket/unix/stream/address_space.rs @@ -0,0 +1,93 @@ +use super::endpoint::Endpoint; +use super::stream::Listener; +use super::*; +use std::collections::btree_map::BTreeMap; + +lazy_static! { + pub(super) static ref ADDRESS_SPACE: AddressSpace = AddressSpace::new(); +} + +pub struct AddressSpace { + file: SgxMutex>>>, + abstr: SgxMutex>>>, +} + +impl AddressSpace { + pub fn new() -> Self { + Self { + file: SgxMutex::new(BTreeMap::new()), + abstr: SgxMutex::new(BTreeMap::new()), + } + } + + pub fn add_binder(&self, addr: &Addr) -> Result<()> { + let key = Self::get_key(addr); + let mut space = self.get_space(addr); + if space.contains_key(&key) { + return_errno!(EADDRINUSE, "the addr is already bound"); + } else { + space.insert(key, None); + Ok(()) + } + } + + pub fn add_listener(&self, addr: &Addr, capacity: usize) -> Result<()> { + let key = Self::get_key(addr); + let mut space = self.get_space(addr); + + if let Some(option) = space.get(&key) { + if let Some(listener) = option { + let new_listener = Listener::new(capacity)?; + for i in 0..std::cmp::min(listener.remaining(), capacity) { + new_listener.push_incoming(listener.pop_incoming().unwrap()); + } + space.insert(key, Some(Arc::new(new_listener))); + } else { + space.insert(key, Some(Arc::new(Listener::new(capacity)?))); + } + Ok(()) + } else { + return_errno!(EINVAL, "the socket is not bound"); + } + } + + pub fn push_incoming(&self, addr: &Addr, sock: Endpoint) -> Result<()> { + self.get_listener_ref(addr) + .ok_or_else(|| errno!(ECONNREFUSED, "no one's listening on the remote address"))? + .push_incoming(sock); + Ok(()) + } + + pub fn pop_incoming(&self, addr: &Addr) -> Result { + self.get_listener_ref(addr) + .ok_or_else(|| errno!(EINVAL, "the socket is not listening"))? + .pop_incoming() + .ok_or_else(|| errno!(EAGAIN, "No connection is incoming")) + } + + pub fn get_listener_ref(&self, addr: &Addr) -> Option> { + let key = Self::get_key(addr); + let space = self.get_space(addr); + space.get(&key).map(|x| x.clone()).flatten() + } + + pub fn remove_addr(&self, addr: &Addr) { + let key = Self::get_key(addr); + let mut space = self.get_space(addr); + space.remove(&key); + } + + fn get_space(&self, addr: &Addr) -> SgxMutexGuard<'_, BTreeMap>>> { + match addr { + Addr::File(unix_path) => self.file.lock().unwrap(), + Addr::Abstract(path) => self.abstr.lock().unwrap(), + } + } + + fn get_key(addr: &Addr) -> String { + match addr { + Addr::File(unix_path) => unix_path.absolute(), + Addr::Abstract(path) => addr.path_str().to_string(), + } + } +} diff --git a/src/libos/src/net/socket/unix/stream/endpoint.rs b/src/libos/src/net/socket/unix/stream/endpoint.rs new file mode 100644 index 00000000..fa3ab370 --- /dev/null +++ b/src/libos/src/net/socket/unix/stream/endpoint.rs @@ -0,0 +1,110 @@ +use super::*; +use alloc::sync::{Arc, Weak}; +use fs::channel::{Channel, Consumer, Producer}; + +pub type Endpoint = Arc; + +/// Constructor of two connected Endpoints +pub fn end_pair(nonblocking: bool) -> Result<(Endpoint, Endpoint)> { + let (pro_a, con_a) = Channel::new(DEFAULT_BUF_SIZE)?.split(); + let (pro_b, con_b) = Channel::new(DEFAULT_BUF_SIZE)?.split(); + + let mut end_a = Arc::new(Inner { + addr: RwLock::new(None), + reader: con_a, + writer: pro_b, + peer: Weak::default(), + }); + let end_b = Arc::new(Inner { + addr: RwLock::new(None), + reader: con_b, + writer: pro_a, + peer: Arc::downgrade(&end_a), + }); + + unsafe { + Arc::get_mut_unchecked(&mut end_a).peer = Arc::downgrade(&end_b); + } + + end_a.set_nonblocking(nonblocking); + end_b.set_nonblocking(nonblocking); + + Ok((end_a, end_b)) +} + +/// One end of the connected unix socket +pub struct Inner { + addr: RwLock>, + reader: Consumer, + writer: Producer, + peer: Weak, +} + +impl Inner { + pub fn addr(&self) -> Option { + self.addr.read().unwrap().clone() + } + + pub fn set_addr(&self, addr: &Addr) { + *self.addr.write().unwrap() = Some(addr.clone()); + } + + pub fn peer_addr(&self) -> Option { + self.peer.upgrade().map(|end| end.addr().clone()).flatten() + } + + pub fn set_nonblocking(&self, nonblocking: bool) { + self.reader.set_nonblocking(nonblocking); + self.writer.set_nonblocking(nonblocking); + } + + pub fn nonblocking(&self) -> bool { + let cons_nonblocking = self.reader.is_nonblocking(); + let prod_nonblocking = self.writer.is_nonblocking(); + assert_eq!(cons_nonblocking, prod_nonblocking); + cons_nonblocking + } + pub fn read(&self, buf: &mut [u8]) -> Result { + self.reader.pop_slice(buf) + } + + pub fn write(&self, buf: &[u8]) -> Result { + self.writer.push_slice(buf) + } + + pub fn readv(&self, bufs: &mut [&mut [u8]]) -> Result { + self.reader.pop_slices(bufs) + } + + pub fn writev(&self, bufs: &[&[u8]]) -> Result { + self.writer.push_slices(bufs) + } + + pub fn bytes_to_read(&self) -> usize { + self.reader.items_to_consume() + } + + pub fn shutdown(&self, how: HowToShut) -> Result<()> { + if !self.is_connected() { + return_errno!(ENOTCONN, "The socket is not connected."); + } + + if how.to_shut_read() { + self.reader.shutdown() + } + + if how.to_shut_write() { + self.writer.shutdown() + } + + Ok(()) + } + + fn is_connected(&self) -> bool { + self.peer.upgrade().is_some() + } +} + +// TODO: Add SO_SNDBUF and SO_RCVBUF to set/getsockopt to dynamcally change the size. +// This value is got from /proc/sys/net/core/rmem_max and wmem_max that are same on linux. +pub const DEFAULT_BUF_SIZE: usize = 208 * 1024; diff --git a/src/libos/src/net/socket/unix/stream/file.rs b/src/libos/src/net/socket/unix/stream/file.rs new file mode 100644 index 00000000..99c974d5 --- /dev/null +++ b/src/libos/src/net/socket/unix/stream/file.rs @@ -0,0 +1,95 @@ +use super::stream::Status; +use super::*; +use fs::{AccessMode, File, FileRef, IoctlCmd, StatusFlags}; +use std::any::Any; + +impl File for Stream { + fn read(&self, buf: &mut [u8]) -> Result { + match &*self.inner() { + Status::Connected(endpoint) => endpoint.read(buf), + _ => return_errno!(ENOTCONN, "unconnected socket"), + } + } + + fn write(&self, buf: &[u8]) -> Result { + match &*self.inner() { + Status::Connected(endpoint) => endpoint.write(buf), + _ => return_errno!(ENOTCONN, "unconnected socket"), + } + } + + fn read_at(&self, offset: usize, buf: &mut [u8]) -> Result { + if offset != 0 { + return_errno!(ESPIPE, "a nonzero position is not supported"); + } + self.read(buf) + } + + fn write_at(&self, offset: usize, buf: &[u8]) -> Result { + if offset != 0 { + return_errno!(ESPIPE, "a nonzero position is not supported"); + } + self.write(buf) + } + + fn readv(&self, bufs: &mut [&mut [u8]]) -> Result { + match &*self.inner() { + Status::Connected(endpoint) => endpoint.readv(bufs), + _ => return_errno!(ENOTCONN, "unconnected socket"), + } + } + + fn writev(&self, bufs: &[&[u8]]) -> Result { + match &*self.inner() { + Status::Connected(endpoint) => endpoint.writev(bufs), + _ => return_errno!(ENOTCONN, "unconnected socket"), + } + } + + fn ioctl(&self, cmd: &mut IoctlCmd) -> Result { + match cmd { + IoctlCmd::FIONREAD(arg) => match &*self.inner() { + Status::Connected(endpoint) => { + let bytes_to_read = endpoint.bytes_to_read().min(std::i32::MAX as usize) as i32; + **arg = bytes_to_read; + Ok(0) + } + _ => return_errno!(ENOTCONN, "unconnected socket"), + }, + _ => return_errno!(EINVAL, "unknown ioctl cmd for unix socket"), + } + } + + fn get_access_mode(&self) -> Result { + Ok(AccessMode::O_RDWR) + } + + fn get_status_flags(&self) -> Result { + if self.nonblocking() { + Ok(StatusFlags::O_NONBLOCK) + } else { + Ok(StatusFlags::empty()) + } + } + + fn set_status_flags(&self, new_status_flags: StatusFlags) -> Result<()> { + // Only O_NONBLOCK, O_ASYNC and O_DIRECT can be set + let status_flags = new_status_flags + & (StatusFlags::O_NONBLOCK | StatusFlags::O_ASYNC | StatusFlags::O_DIRECT); + + // Only O_NONBLOCK is supported + let nonblocking = new_status_flags.contains(StatusFlags::O_NONBLOCK); + self.set_nonblocking(nonblocking); + Ok(()) + } + + fn poll(&self) -> Result { + warn!("poll is not supported for unix_socket"); + let events = PollEventFlags::empty(); + Ok(events) + } + + fn as_any(&self) -> &dyn Any { + self + } +} diff --git a/src/libos/src/net/socket/unix/stream/mod.rs b/src/libos/src/net/socket/unix/stream/mod.rs new file mode 100644 index 00000000..4d53789f --- /dev/null +++ b/src/libos/src/net/socket/unix/stream/mod.rs @@ -0,0 +1,8 @@ +use super::*; + +mod address_space; +mod endpoint; +mod file; +mod stream; + +pub use stream::Stream; diff --git a/src/libos/src/net/socket/unix/stream/stream.rs b/src/libos/src/net/socket/unix/stream/stream.rs new file mode 100644 index 00000000..7de0ff42 --- /dev/null +++ b/src/libos/src/net/socket/unix/stream/stream.rs @@ -0,0 +1,325 @@ +use super::address_space::ADDRESS_SPACE; +use super::endpoint::{end_pair, Endpoint}; +use super::*; +use alloc::sync::Arc; +use fs::channel::Channel; +use std::fmt; +use std::sync::atomic::{AtomicBool, Ordering}; + +/// SOCK_STREAM Unix socket. It has three statuses: unconnected, listening and connected. When a +/// socket is created, it is in unconnected status. It will transfer to listening after listen is +/// called and connected after connect is called. A socket in connected status can be obtained +/// through a listening socket calling accept. Listening and connected are ultimate statuses. They +/// will not transfer to other statuses. +pub struct Stream { + inner: SgxMutex, +} + +impl Stream { + pub fn new(flags: FileFlags) -> Self { + Self { + inner: SgxMutex::new(Status::Unconnected(Info::new( + flags.contains(FileFlags::SOCK_NONBLOCK), + ))), + } + } + + pub fn socketpair(flags: FileFlags) -> Result<(Self, Self)> { + let nonblocking = flags.contains(FileFlags::SOCK_NONBLOCK); + let (end_a, end_b) = end_pair(nonblocking)?; + + let socket_a = Self { + inner: SgxMutex::new(Status::Connected(end_a)), + }; + + let socket_b = Self { + inner: SgxMutex::new(Status::Connected(end_b)), + }; + + Ok((socket_a, socket_b)) + } + + pub fn addr(&self) -> Option { + match &*self.inner() { + Status::Unconnected(info) => info.addr().clone(), + Status::Connected(endpoint) => endpoint.addr(), + Status::Listening(addr) => Some(addr).cloned(), + } + } + + pub fn peer_addr(&self) -> Result { + if let Status::Connected(endpoint) = &*self.inner() { + if let Some(addr) = endpoint.peer_addr() { + return Ok(addr); + } + } + return_errno!(ENOTCONN, "the socket is not connected"); + } + + // TODO: create the corresponding file in the fs + pub fn bind(&self, addr: &Addr) -> Result<()> { + match &mut *self.inner() { + Status::Unconnected(ref mut info) => { + if info.addr().is_some() { + return_errno!(EINVAL, "the socket is already bound"); + } + + // check the global address space to see if the address is avaiable before bind + ADDRESS_SPACE.add_binder(addr)?; + info.set_addr(addr); + } + Status::Connected(endpoint) => { + if endpoint.addr().is_some() { + return_errno!(EINVAL, "the socket is already bound"); + } + + ADDRESS_SPACE.add_binder(addr)?; + endpoint.set_addr(addr); + } + Status::Listening(_) => return_errno!(EINVAL, "the socket is already bound"), + } + + Ok(()) + } + + pub fn listen(&self, backlog: i32) -> Result<()> { + //TODO: restrict backlog accroding to /proc/sys/net/core/somaxconn + if backlog < 0 { + return_errno!(EINVAL, "negative backlog is not supported"); + } + let capacity = backlog as usize; + + let mut inner = self.inner(); + match &*inner { + Status::Unconnected(info) => { + if let Some(addr) = info.addr() { + ADDRESS_SPACE.add_listener(addr, capacity)?; + *inner = Status::Listening(addr.clone()); + } else { + return_errno!(EINVAL, "the socket is not bound"); + } + } + Status::Connected(_) => return_errno!(EINVAL, "the socket is already connected"), + /// Modify the capacity of the channel holding incoming sockets + Status::Listening(addr) => ADDRESS_SPACE.add_listener(&addr, capacity)?, + } + + Ok(()) + } + + pub fn connect(&self, addr: &Addr) -> Result<()> { + debug!("connect to {:?}", addr); + + let mut inner = self.inner(); + match &*inner { + Status::Unconnected(info) => { + let self_addr_opt = info.addr(); + if let Some(self_addr) = self_addr_opt { + if self_addr == addr { + return_errno!(EINVAL, "self connect is not supported"); + } + } + + let (end_self, end_incoming) = end_pair(info.nonblocking())?; + end_incoming.set_addr(addr); + if let Some(self_addr) = self_addr_opt { + end_self.set_addr(self_addr); + } + + ADDRESS_SPACE.push_incoming(addr, end_incoming)?; + + *inner = Status::Connected(end_self); + Ok(()) + } + Status::Connected(endpoint) => return_errno!(EISCONN, "already connected"), + Status::Listening(addr) => return_errno!(EINVAL, "invalid socket for connect"), + } + } + + pub fn accept(&self, flags: FileFlags) -> Result<(Self, Option)> { + match &*self.inner() { + Status::Listening(addr) => { + let endpoint = ADDRESS_SPACE.pop_incoming(&addr)?; + endpoint.set_nonblocking(flags.contains(FileFlags::SOCK_NONBLOCK)); + + let peer_addr = endpoint.peer_addr(); + + debug!("accept socket from {:?}", peer_addr); + + Ok(( + Self { + inner: SgxMutex::new(Status::Connected(endpoint)), + }, + peer_addr, + )) + } + _ => return_errno!(EINVAL, "the socket is not listening"), + } + } + + // TODO: handle flags + pub fn sendto(&self, buf: &[u8], flags: SendFlags, addr: &Option) -> Result { + self.write(buf) + } + + // TODO: handle flags + pub fn recvfrom(&self, buf: &mut [u8], flags: RecvFlags) -> Result<(usize, Option)> { + let data_len = self.read(buf)?; + let addr = self.peer_addr().ok(); + + debug!("recvfrom {:?}", addr); + + Ok((data_len, addr)) + } + + /// perform shutdown on the socket. + pub fn shutdown(&self, how: HowToShut) -> Result<()> { + if let Status::Connected(ref end) = &*self.inner() { + end.shutdown(how) + } else { + return_errno!(ENOTCONN, "The socket is not connected."); + } + } + + pub(super) fn nonblocking(&self) -> bool { + match &*self.inner() { + Status::Unconnected(info) => info.nonblocking(), + Status::Connected(endpoint) => endpoint.nonblocking(), + Status::Listening(addr) => ADDRESS_SPACE.get_listener_ref(&addr).unwrap().nonblocking(), + } + } + + pub(super) fn set_nonblocking(&self, nonblocking: bool) { + match &mut *self.inner() { + Status::Unconnected(ref mut info) => info.set_nonblocking(nonblocking), + Status::Connected(ref mut endpoint) => endpoint.set_nonblocking(nonblocking), + Status::Listening(addr) => ADDRESS_SPACE + .get_listener_ref(&addr) + .unwrap() + .set_nonblocking(nonblocking), + } + } + + pub(super) fn inner(&self) -> SgxMutexGuard<'_, Status> { + self.inner.lock().unwrap() + } +} + +impl Debug for Stream { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("Stream") + .field("addr", &self.addr()) + .field("nonblocking", &self.nonblocking()) + .finish() + } +} + +impl Drop for Stream { + fn drop(&mut self) { + match &*self.inner() { + Status::Unconnected(info) => { + if let Some(addr) = info.addr() { + ADDRESS_SPACE.remove_addr(&addr); + } + } + Status::Listening(addr) => { + let listener = ADDRESS_SPACE.get_listener_ref(&addr).unwrap(); + ADDRESS_SPACE.remove_addr(&addr); + /// handle the blocking of other sockets holding the reference to the listener, + /// e.g., pushing to a listener full of incoming sockets + listener.shutdown(); + } + _ => {} + } + } +} + +pub enum Status { + Unconnected(Info), + /// The listeners are stored in a global data structure indexed by the address. + /// The consitency of Status with that data structure should be carefully maintained. + Listening(Addr), + Connected(Endpoint), +} + +#[derive(Debug, Clone)] +pub struct Info { + addr: Option, + nonblocking: bool, +} + +impl Info { + pub fn new(nonblocking: bool) -> Self { + Self { + addr: None, + nonblocking: nonblocking, + } + } + + pub fn addr(&self) -> &Option { + &self.addr + } + + pub fn set_addr(&mut self, addr: &Addr) { + self.addr = Some(addr.clone()); + } + + pub fn nonblocking(&self) -> bool { + self.nonblocking + } + + pub fn set_nonblocking(&mut self, nonblocking: bool) { + self.nonblocking = nonblocking; + } +} + +pub struct Listener { + channel: Channel, + nonblocking: AtomicBool, +} + +impl Listener { + pub fn new(capacity: usize) -> Result { + let channel = Channel::new(capacity)?; + // It may incur blocking inside a blocking if the channel is blocking. Set the channel to + // nonblocking permanently to avoid the nested blocking. This also results in nonblocking + // accept and connect. Future work is needed to resolve this blocking issue to support + // blocking accept and connect. + channel.set_nonblocking(true); + /// The listener is blocking by default + let nonblocking = AtomicBool::new(true); + + Ok(Self { + channel, + nonblocking, + }) + } + + pub fn push_incoming(&self, stream_socket: Endpoint) { + self.channel.push(stream_socket); + } + + pub fn pop_incoming(&self) -> Option { + self.channel.pop().ok().flatten() + } + + pub fn remaining(&self) -> usize { + self.channel.items_to_consume() + } + + pub fn nonblocking(&self) -> bool { + warn!("the channel works in a nonblocking way regardless of the nonblocking status"); + + self.nonblocking.load(Ordering::Acquire) + } + + pub fn set_nonblocking(&self, nonblocking: bool) { + warn!("the channel works in a nonblocking way regardless of the nonblocking status"); + + self.nonblocking.store(nonblocking, Ordering::Release); + } + + pub fn shutdown(&self) { + self.channel.shutdown(); + } +} diff --git a/src/libos/src/net/socket/unix_socket.rs b/src/libos/src/net/socket/unix_socket.rs deleted file mode 100644 index 6838dddb..00000000 --- a/src/libos/src/net/socket/unix_socket.rs +++ /dev/null @@ -1,390 +0,0 @@ -use super::*; -use fs::{File, FileRef, IoctlCmd}; -use rcore_fs::vfs::{FileType, Metadata, Timespec}; -use std::any::Any; -use std::collections::btree_map::BTreeMap; -use std::fmt; -use std::sync::atomic::{spin_loop_hint, AtomicUsize, Ordering}; -use std::sync::SgxMutex as Mutex; -use util::ring_buf::{ring_buffer, RingBufReader, RingBufWriter}; - -pub struct UnixSocketFile { - inner: Mutex, -} - -// TODO: add enqueue_event and dequeue_event -impl File for UnixSocketFile { - fn read(&self, buf: &mut [u8]) -> Result { - let mut inner = self.inner.lock().unwrap(); - inner.read(buf) - } - - fn write(&self, buf: &[u8]) -> Result { - let mut inner = self.inner.lock().unwrap(); - inner.write(buf) - } - - fn read_at(&self, _offset: usize, buf: &mut [u8]) -> Result { - self.read(buf) - } - - fn write_at(&self, _offset: usize, buf: &[u8]) -> Result { - self.write(buf) - } - - fn readv(&self, bufs: &mut [&mut [u8]]) -> Result { - let mut inner = self.inner.lock().unwrap(); - inner.readv(bufs) - } - - fn writev(&self, bufs: &[&[u8]]) -> Result { - let mut inner = self.inner.lock().unwrap(); - inner.writev(bufs) - } - - fn metadata(&self) -> Result { - Ok(Metadata { - dev: 0, - inode: 0, - size: 0, - blk_size: 0, - blocks: 0, - atime: Timespec { sec: 0, nsec: 0 }, - mtime: Timespec { sec: 0, nsec: 0 }, - ctime: Timespec { sec: 0, nsec: 0 }, - type_: FileType::Socket, - mode: 0, - nlinks: 0, - uid: 0, - gid: 0, - rdev: 0, - }) - } - - fn ioctl(&self, cmd: &mut IoctlCmd) -> Result { - let mut inner = self.inner.lock().unwrap(); - inner.ioctl(cmd) - } - - fn poll(&self) -> Result { - let mut inner = self.inner.lock().unwrap(); - inner.poll() - } - - fn as_any(&self) -> &dyn Any { - self - } -} - -static SOCKETPAIR_NUM: AtomicUsize = AtomicUsize::new(0); -const SOCK_PATH_PREFIX: &str = "socketpair_"; - -impl UnixSocketFile { - pub fn new(socket_type: c_int, protocol: c_int) -> Result { - let inner = UnixSocket::new(socket_type, protocol)?; - Ok(UnixSocketFile { - inner: Mutex::new(inner), - }) - } - - pub fn bind(&self, path: impl AsRef) -> Result<()> { - let mut inner = self.inner.lock().unwrap(); - inner.bind(path) - } - - pub fn listen(&self) -> Result<()> { - let mut inner = self.inner.lock().unwrap(); - inner.listen() - } - - pub fn accept(&self) -> Result { - let mut inner = self.inner.lock().unwrap(); - let new_socket = inner.accept()?; - Ok(UnixSocketFile { - inner: Mutex::new(new_socket), - }) - } - - pub fn connect(&self, path: impl AsRef) -> Result<()> { - let mut inner = self.inner.lock().unwrap(); - inner.connect(path) - } - - pub fn socketpair(socket_type: i32, protocol: i32) -> Result<(Self, Self)> { - let listen_socket = Self::new(socket_type, protocol)?; - let bound_path = listen_socket.bind_until_success(); - listen_socket.listen()?; - - let client_socket = Self::new(socket_type, protocol)?; - client_socket.connect(&bound_path)?; - - let accepted_socket = listen_socket.accept()?; - Ok((client_socket, accepted_socket)) - } - - fn bind_until_success(&self) -> String { - loop { - let sock_path_suffix = SOCKETPAIR_NUM.fetch_add(1, Ordering::SeqCst); - let sock_path = format!("{}{}", SOCK_PATH_PREFIX, sock_path_suffix); - if self.bind(&sock_path).is_ok() { - return sock_path; - } - } - } - - pub fn is_connected(&self) -> bool { - if let Status::Connected(_) = self.inner.lock().unwrap().status { - true - } else { - false - } - } -} - -impl Debug for UnixSocketFile { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "UnixSocketFile {{ ... }}") - } -} - -pub trait AsUnixSocket { - fn as_unix_socket(&self) -> Result<&UnixSocketFile>; -} - -impl AsUnixSocket for FileRef { - fn as_unix_socket(&self) -> Result<&UnixSocketFile> { - self.as_any() - .downcast_ref::() - .ok_or_else(|| errno!(EBADF, "not a unix socket")) - } -} - -pub struct UnixSocket { - obj: Option>, - status: Status, -} - -enum Status { - None, - Listening, - Connected(Channel), -} - -impl UnixSocket { - /// C/S 1: Create a new unix socket - pub fn new(socket_type: c_int, protocol: c_int) -> Result { - if socket_type == libc::SOCK_STREAM && (protocol == 0 || protocol == libc::PF_UNIX) { - Ok(UnixSocket { - obj: None, - status: Status::None, - }) - } else { - // Return different error numbers according to input - return_errno!(ENOSYS, "unimplemented unix socket type") - } - } - - /// Server 2: Bind the socket to a file system path - pub fn bind(&mut self, path: impl AsRef) -> Result<()> { - // TODO: check permission - if self.obj.is_some() { - return_errno!(EINVAL, "The socket is already bound to an address."); - } - self.obj = Some(UnixSocketObject::create(path)?); - Ok(()) - } - - /// Server 3: Listen to a socket - pub fn listen(&mut self) -> Result<()> { - self.status = Status::Listening; - Ok(()) - } - - /// Server 4: Accept a connection on listening. - pub fn accept(&mut self) -> Result { - match self.status { - Status::Listening => {} - _ => return_errno!(EINVAL, "unix socket is not listening"), - }; - // FIXME: Block. Now spin loop. - let socket = loop { - if let Some(socket) = self.obj.as_mut().unwrap().pop() { - break socket; - } - spin_loop_hint(); - }; - Ok(socket) - } - - /// Client 2: Connect to a path - pub fn connect(&mut self, path: impl AsRef) -> Result<()> { - if let Status::Listening = self.status { - return_errno!(EINVAL, "unix socket is listening?"); - } - let obj = UnixSocketObject::get(path) - .ok_or_else(|| errno!(EINVAL, "unix socket path not found"))?; - // TODO: Mov the buffer allocation to function new to comply with the bahavior of unix - let (channel1, channel2) = Channel::new_pair()?; - self.status = Status::Connected(channel1); - obj.push(UnixSocket { - obj: Some(obj.clone()), - status: Status::Connected(channel2), - }); - Ok(()) - } - - pub fn read(&mut self, buf: &mut [u8]) -> Result { - self.channel_mut()?.reader.read_from_buffer(buf) - } - - pub fn readv(&mut self, bufs: &mut [&mut [u8]]) -> Result { - self.channel_mut()?.reader.read_from_vector(bufs) - } - - pub fn write(&mut self, buf: &[u8]) -> Result { - self.channel_mut()?.writer.write_to_buffer(buf) - } - - pub fn writev(&mut self, bufs: &[&[u8]]) -> Result { - self.channel_mut()?.writer.write_to_vector(bufs) - } - - fn poll(&self) -> Result { - let channel_result = self.channel(); - if let Ok(channel) = channel_result { - let readable = channel.reader.can_read() && !channel.reader.is_peer_closed(); - let writable = channel.writer.can_write() && !channel.writer.is_peer_closed(); - let events = if readable ^ writable { - if channel.reader.can_read() { - PollEventFlags::POLLRDHUP | PollEventFlags::POLLIN | PollEventFlags::POLLRDNORM - } else { - PollEventFlags::POLLRDHUP - } - // both readable and writable - } else if readable { - PollEventFlags::POLLIN - | PollEventFlags::POLLOUT - | PollEventFlags::POLLRDNORM - | PollEventFlags::POLLWRNORM - } else { - PollEventFlags::POLLHUP - }; - Ok(events) - } else { - // For the unconnected socket - // TODO: add write support for unconnected sockets like linux does - Ok(PollEventFlags::POLLHUP) - } - } - - pub fn ioctl(&self, cmd: &mut IoctlCmd) -> Result { - match cmd { - IoctlCmd::FIONREAD(arg) => { - let bytes_to_read = self - .channel()? - .reader - .bytes_to_read() - .min(std::i32::MAX as usize) as i32; - **arg = bytes_to_read; - } - _ => return_errno!(EINVAL, "unknown ioctl cmd for unix socket"), - } - Ok(0) - } - - fn channel_mut(&mut self) -> Result<&mut Channel> { - if let Status::Connected(ref mut channel) = &mut self.status { - Ok(channel) - } else { - return_errno!(EBADF, "UnixSocket is not connected") - } - } - - fn channel(&self) -> Result<&Channel> { - if let Status::Connected(channel) = &self.status { - Ok(channel) - } else { - return_errno!(EBADF, "UnixSocket is not connected") - } - } -} - -impl Drop for UnixSocket { - fn drop(&mut self) { - if let Status::Listening = self.status { - // Only remove the object when there is one - if let Some(obj) = self.obj.as_ref() { - UnixSocketObject::remove(&obj.path); - } - } - } -} - -pub struct UnixSocketObject { - path: String, - accepted_sockets: Mutex>, -} - -impl UnixSocketObject { - fn push(&self, unix_socket: UnixSocket) { - let mut queue = self.accepted_sockets.lock().unwrap(); - queue.push_back(unix_socket); - } - fn pop(&self) -> Option { - let mut queue = self.accepted_sockets.lock().unwrap(); - queue.pop_front() - } - fn get(path: impl AsRef) -> Option> { - let mut paths = UNIX_SOCKET_OBJS.lock().unwrap(); - paths.get(path.as_ref()).map(|obj| obj.clone()) - } - fn create(path: impl AsRef) -> Result> { - let mut paths = UNIX_SOCKET_OBJS.lock().unwrap(); - if paths.contains_key(path.as_ref()) { - return_errno!(EADDRINUSE, "unix socket path already exists"); - } - let obj = Arc::new(UnixSocketObject { - path: path.as_ref().to_string(), - accepted_sockets: Mutex::new(VecDeque::new()), - }); - paths.insert(path.as_ref().to_string(), obj.clone()); - Ok(obj) - } - fn remove(path: impl AsRef) { - let mut paths = UNIX_SOCKET_OBJS.lock().unwrap(); - paths.remove(path.as_ref()); - } -} - -struct Channel { - reader: RingBufReader, - writer: RingBufWriter, -} - -unsafe impl Send for Channel {} -unsafe impl Sync for Channel {} - -impl Channel { - fn new_pair() -> Result<(Channel, Channel)> { - let (reader1, writer1) = ring_buffer(DEFAULT_BUF_SIZE)?; - let (reader2, writer2) = ring_buffer(DEFAULT_BUF_SIZE)?; - let channel1 = Channel { - reader: reader1, - writer: writer2, - }; - let channel2 = Channel { - reader: reader2, - writer: writer1, - }; - Ok((channel1, channel2)) - } -} - -// TODO: Add SO_SNDBUF and SO_RCVBUF to set/getsockopt to dynamcally change the size. -// This value is got from /proc/sys/net/core/rmem_max and wmem_max that are same on linux. -pub const DEFAULT_BUF_SIZE: usize = 208 * 1024; - -lazy_static! { - static ref UNIX_SOCKET_OBJS: Mutex>> = - Mutex::new(BTreeMap::new()); -} diff --git a/src/libos/src/net/syscalls.rs b/src/libos/src/net/syscalls.rs index 76936660..f1a12100 100644 --- a/src/libos/src/net/syscalls.rs +++ b/src/libos/src/net/syscalls.rs @@ -18,7 +18,7 @@ pub fn do_socket(domain: c_int, socket_type: c_int, protocol: c_int) -> Result = match sock_domain { AddressFamily::LOCAL => { - let unix_socket = UnixSocketFile::new(socket_type, protocol)?; + let unix_socket = unix_socket(sock_type, file_flags, protocol)?; Arc::new(unix_socket) } _ => { @@ -38,19 +38,15 @@ pub fn do_bind(fd: c_int, addr: *const libc::sockaddr, addr_len: libc::socklen_t } from_user::check_array(addr as *const u8, addr_len as usize)?; - let sock_addr = unsafe { SockAddr::try_from_raw(addr, addr_len)? }; - trace!("bind to addr: {:?}", sock_addr); - let file_ref = current!().file(fd as FileDesc)?; if let Ok(socket) = file_ref.as_host_socket() { + let sock_addr = unsafe { SockAddr::try_from_raw(addr, addr_len)? }; + trace!("bind to addr: {:?}", sock_addr); socket.bind(&sock_addr)?; } else if let Ok(unix_socket) = file_ref.as_unix_socket() { - let addr = addr as *const libc::sockaddr_un; - from_user::check_ptr(addr)?; - let path = from_user::clone_cstring_safely(unsafe { (&*addr).sun_path.as_ptr() })? - .to_string_lossy() - .into_owned(); - unix_socket.bind(path)?; + let unix_addr = unsafe { UnixAddr::try_from_raw(addr, addr_len)? }; + trace!("bind to addr: {:?}", unix_addr); + unix_socket.bind(&unix_addr)?; } else { return_errno!(EBADF, "not a socket"); } @@ -63,7 +59,7 @@ pub fn do_listen(fd: c_int, backlog: c_int) -> Result { if let Ok(socket) = file_ref.as_host_socket() { socket.listen(backlog)?; } else if let Ok(unix_socket) = file_ref.as_unix_socket() { - unix_socket.listen()?; + unix_socket.listen(backlog)?; } else { return_errno!(EBADF, "not a socket"); } @@ -84,24 +80,26 @@ pub fn do_connect( from_user::check_array(addr as *const u8, addr_len as usize)?; } - let addr_option = if addr_set { - Some(unsafe { SockAddr::try_from_raw(addr, addr_len)? }) - } else { - None - }; - let file_ref = current!().file(fd as FileDesc)?; if let Ok(socket) = file_ref.as_host_socket() { + let addr_option = if addr_set { + Some(unsafe { SockAddr::try_from_raw(addr, addr_len)? }) + } else { + None + }; + socket.connect(&addr_option)?; } else if let Ok(unix_socket) = file_ref.as_unix_socket() { - let addr = addr as *const libc::sockaddr_un; - from_user::check_ptr(addr)?; - let path = from_user::clone_cstring_safely(unsafe { (&*addr).sun_path.as_ptr() })? - .to_string_lossy() - .into_owned(); - unix_socket.connect(path)?; + // TODO: support AF_UNSPEC address for datagram socket use + let addr = if addr_set { + unsafe { UnixAddr::try_from_raw(addr, addr_len)? } + } else { + return_errno!(EINVAL, "invalid address"); + }; + + unix_socket.connect(&addr)?; } else { - return_errno!(EBADF, "not a socket") + return_errno!(EBADF, "not a socket"); } Ok(0) @@ -131,47 +129,65 @@ pub fn do_accept4( let close_on_spawn = file_flags.contains(FileFlags::SOCK_CLOEXEC); let file_ref = current!().file(fd as FileDesc)?; - let new_fd = if let Ok(socket) = file_ref.as_host_socket() { + if let Ok(socket) = file_ref.as_host_socket() { let (new_socket_file, sock_addr_option) = socket.accept(file_flags)?; let new_file_ref: Arc = Arc::new(new_socket_file); let new_fd = current!().add_file(new_file_ref, close_on_spawn); - if addr_set && sock_addr_option.is_some() { - let sock_addr = sock_addr_option.unwrap(); - let mut buf = - unsafe { std::slice::from_raw_parts_mut(addr as *mut u8, *addr_len as usize) }; - sock_addr.copy_to_slice(&mut buf); - unsafe { - *addr_len = sock_addr.len() as u32; + if addr_set { + if let Some(sock_addr) = sock_addr_option { + let mut buf = + unsafe { std::slice::from_raw_parts_mut(addr as *mut u8, *addr_len as usize) }; + sock_addr.copy_to_slice(&mut buf); + unsafe { + *addr_len = sock_addr.len() as u32; + } + } else { + unsafe { + *addr_len = 0; + } } } - new_fd + Ok(new_fd as isize) } else if let Ok(unix_socket) = file_ref.as_unix_socket() { - let addr = addr as *mut libc::sockaddr_un; + let (new_socket_file, sock_addr_option) = unix_socket.accept(file_flags)?; + let new_file_ref: Arc = Arc::new(new_socket_file); + let new_fd = current!().add_file(new_file_ref, close_on_spawn); + if addr_set { - from_user::check_mut_ptr(addr)?; + if let Some(sock_addr) = sock_addr_option { + let mut buf = + unsafe { std::slice::from_raw_parts_mut(addr as *mut u8, *addr_len as usize) }; + sock_addr.copy_to_slice(&mut buf); + unsafe { + *addr_len = sock_addr.raw_len() as u32; + } + } else { + unsafe { + *addr_len = 0; + } + } } - // TODO: handle addr - let new_socket = unix_socket.accept()?; - let new_file_ref: Arc = Arc::new(new_socket); - current!().add_file(new_file_ref, false) + Ok(new_fd as isize) } else { return_errno!(EBADF, "not a socket"); - }; - - Ok(new_fd as isize) + } } pub fn do_shutdown(fd: c_int, how: c_int) -> Result { debug!("shutdown: fd: {}, how: {}", fd, how); + let how = HowToShut::try_from_raw(how)?; + let file_ref = current!().file(fd as FileDesc)?; if let Ok(socket) = file_ref.as_host_socket() { - let ret = try_libc!(libc::ocall::shutdown(socket.raw_host_fd() as i32, how)); - Ok(ret as isize) + socket.shutdown(how)?; + } else if let Ok(unix_socket) = file_ref.as_unix_socket() { + unix_socket.shutdown(how)?; } else { - // TODO: support unix socket - return_errno!(EBADF, "not a socket") + return_errno!(EBADF, "not a host socket") } + + Ok(0) } pub fn do_setsockopt( @@ -232,10 +248,14 @@ pub fn do_getpeername( addr: *mut libc::sockaddr, addr_len: *mut libc::socklen_t, ) -> Result { - debug!( - "getpeername: fd: {}, addr: {:?}, addr_len: {:?}", - fd, addr, addr_len - ); + let addr_set: bool = !addr.is_null(); + if addr_set { + from_user::check_ptr(addr_len)?; + from_user::check_mut_array(addr as *mut u8, unsafe { *addr_len } as usize)?; + } else { + return Ok(0); + } + let file_ref = current!().file(fd as FileDesc)?; if let Ok(socket) = file_ref.as_host_socket() { let ret = try_libc!(libc::ocall::getpeername( @@ -245,11 +265,15 @@ pub fn do_getpeername( )); Ok(ret as isize) } else if let Ok(unix_socket) = file_ref.as_unix_socket() { - warn!("getpeername for unix socket is unimplemented"); - return_errno!( - ENOTCONN, - "hack for php: Transport endpoint is not connected" - ) + let name = unix_socket.peer_addr()?; + let mut dst = unsafe { + std::slice::from_raw_parts_mut(addr as *mut _ as *mut u8, *addr_len as usize) + }; + name.copy_to_slice(dst); + unsafe { + *addr_len = name.raw_len() as u32; + } + Ok(0) } else { return_errno!(EBADF, "not a socket") } @@ -260,10 +284,18 @@ pub fn do_getsockname( addr: *mut libc::sockaddr, addr_len: *mut libc::socklen_t, ) -> Result { - debug!( - "getsockname: fd: {}, addr: {:?}, addr_len: {:?}", - fd, addr, addr_len - ); + let addr_set: bool = !addr.is_null(); + if addr_set { + from_user::check_ptr(addr_len)?; + from_user::check_mut_array(addr as *mut u8, unsafe { *addr_len } as usize)?; + } else { + return Ok(0); + } + + if unsafe { *addr_len } < std::mem::size_of::() as u32 { + return_errno!(EINVAL, "input length is too short"); + } + let file_ref = current!().file(fd as FileDesc)?; if let Ok(socket) = file_ref.as_host_socket() { let ret = try_libc!(libc::ocall::getsockname( @@ -273,10 +305,24 @@ pub fn do_getsockname( )); Ok(ret as isize) } else if let Ok(unix_socket) = file_ref.as_unix_socket() { - warn!("getsockname for unix socket is unimplemented"); + let name_opt = unix_socket.addr(); + if let Some(name) = name_opt { + let mut dst = unsafe { + std::slice::from_raw_parts_mut(addr as *mut _ as *mut u8, *addr_len as usize) + }; + name.copy_to_slice(dst); + unsafe { + *addr_len = name.raw_len() as u32; + } + } else { + unsafe { + (*addr).sa_family = AddressFamily::LOCAL as u16; + *addr_len = 2; + } + } Ok(0) } else { - return_errno!(EBADF, "not a socket") + return_errno!(EBADF, "not a socket"); } } @@ -306,24 +352,27 @@ pub fn do_sendto( let send_flags = SendFlags::from_bits(flags).unwrap(); - let addr_option = if addr_set { - Some(unsafe { SockAddr::try_from_raw(addr, addr_len)? }) - } else { - None - }; - let file_ref = current!().file(fd as FileDesc)?; if let Ok(socket) = file_ref.as_host_socket() { + let addr_option = if addr_set { + Some(unsafe { SockAddr::try_from_raw(addr, addr_len)? }) + } else { + None + }; + socket .sendto(buf, send_flags, &addr_option) .map(|u| u as isize) - } else if let Ok(unix) = file_ref.as_unix_socket() { - if !unix.is_connected() { - return_errno!(ENOTCONN, "the socket has not been connected yet"); - } + } else if let Ok(unix_socket) = file_ref.as_unix_socket() { + let addr_option = if addr_set { + Some(unsafe { UnixAddr::try_from_raw(addr, addr_len)? }) + } else { + None + }; - let data = unsafe { std::slice::from_raw_parts(base as *const u8, len) }; - unix.write(data).map(|u| u as isize) + unix_socket + .sendto(buf, send_flags, &addr_option) + .map(|u| u as isize) } else { return_errno!(EBADF, "unsupported file type"); } @@ -356,23 +405,43 @@ pub fn do_recvfrom( } let file_ref = current!().file(fd as FileDesc)?; - let (data_len, sock_addr_option) = if let Ok(socket) = file_ref.as_host_socket() { - socket.recvfrom(buf, recv_flags)? + if let Ok(socket) = file_ref.as_host_socket() { + let (data_len, sock_addr_option) = socket.recvfrom(buf, recv_flags)?; + if addr_set { + if let Some(sock_addr) = sock_addr_option { + let mut buf = + unsafe { std::slice::from_raw_parts_mut(addr as *mut u8, *addr_len as usize) }; + sock_addr.copy_to_slice(&mut buf); + unsafe { + *addr_len = sock_addr.len() as u32; + } + } else { + unsafe { + *addr_len = 0; + } + } + } + Ok(data_len as isize) + } else if let Ok(unix_socket) = file_ref.as_unix_socket() { + let (data_len, sock_addr_option) = unix_socket.recvfrom(buf, recv_flags)?; + if addr_set { + if let Some(sock_addr) = sock_addr_option { + let mut buf = + unsafe { std::slice::from_raw_parts_mut(addr as *mut u8, *addr_len as usize) }; + sock_addr.copy_to_slice(&mut buf); + unsafe { + *addr_len = sock_addr.raw_len() as u32; + } + } else { + unsafe { + *addr_len = 0; + } + } + } + Ok(data_len as isize) } else { return_errno!(EBADF, "not a socket"); - }; - - if addr_set && sock_addr_option.is_some() { - let sock_addr = sock_addr_option.unwrap(); - let mut buf = - unsafe { std::slice::from_raw_parts_mut(addr as *mut u8, *addr_len as usize) }; - sock_addr.copy_to_slice(&mut buf); - unsafe { - *addr_len = sock_addr.len() as u32; - } } - - Ok(data_len as isize) } pub fn do_socketpair( @@ -388,11 +457,11 @@ pub fn do_socketpair( let file_flags = FileFlags::from_bits_truncate(socket_type); let close_on_spawn = file_flags.contains(FileFlags::SOCK_CLOEXEC); + let sock_type = SocketType::try_from(socket_type & (!file_flags.bits()))?; let domain = AddressFamily::try_from(domain as u16)?; if (domain == AddressFamily::LOCAL) { - let (client_socket, server_socket) = - UnixSocketFile::socketpair(socket_type as i32, protocol as i32)?; + let (client_socket, server_socket) = socketpair(sock_type, file_flags, protocol as i32)?; let current = current!(); let mut files = current.files().lock().unwrap(); diff --git a/src/libos/src/util/mod.rs b/src/libos/src/util/mod.rs index 0dbf1cbc..92543d44 100644 --- a/src/libos/src/util/mod.rs +++ b/src/libos/src/util/mod.rs @@ -4,6 +4,5 @@ pub mod dirty; pub mod log; pub mod mem_util; pub mod mpx_util; -pub mod ring_buf; pub mod sgx; pub mod sync; diff --git a/src/libos/src/util/ring_buf.rs b/src/libos/src/util/ring_buf.rs deleted file mode 100644 index 2e4bbec2..00000000 --- a/src/libos/src/util/ring_buf.rs +++ /dev/null @@ -1,428 +0,0 @@ -use alloc::alloc::{alloc, dealloc, Layout}; - -use crate::net::{ - clear_notifier_status, notify_thread, wait_for_notification, IoEvent, PollEventFlags, -}; -use std::cmp::{max, min}; -use std::ptr; -use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; -use std::sync::Arc; - -use super::*; -use ringbuf::{Consumer, Producer, RingBuffer}; - -pub fn ring_buffer(capacity: usize) -> Result<(RingBufReader, RingBufWriter)> { - let meta = RingBufMeta::new(); - let buffer = RingBuffer::::new(capacity); - let (producer, consumer) = buffer.split(); - let meta_ref = Arc::new(meta); - - let reader = RingBufReader { - inner: consumer, - buffer: meta_ref.clone(), - }; - let writer = RingBufWriter { - inner: producer, - buffer: meta_ref, - }; - Ok((reader, writer)) -} - -struct RingBufMeta { - lock: Arc>, // lock for the synchronization of reader and writer - reader_closed: AtomicBool, // if reader has been dropped - writer_closed: AtomicBool, // if writer has been dropped - reader_wait_queue: SgxMutex>, - writer_wait_queue: SgxMutex>, - // TODO: support O_ASYNC and O_DIRECT in ringbuffer - blocking_read: AtomicBool, // if the read is blocking - blocking_write: AtomicBool, // if the write is blocking -} - -impl RingBufMeta { - pub fn new() -> RingBufMeta { - Self { - lock: Arc::new(SgxMutex::new(true)), - reader_closed: AtomicBool::new(false), - writer_closed: AtomicBool::new(false), - reader_wait_queue: SgxMutex::new(HashMap::new()), - writer_wait_queue: SgxMutex::new(HashMap::new()), - blocking_read: AtomicBool::new(true), - blocking_write: AtomicBool::new(true), - } - } - - pub fn is_reader_closed(&self) -> bool { - self.reader_closed.load(Ordering::SeqCst) - } - - pub fn close_reader(&self) { - self.reader_closed.store(true, Ordering::SeqCst); - } - - pub fn is_writer_closed(&self) -> bool { - self.writer_closed.load(Ordering::SeqCst) - } - - pub fn close_writer(&self) { - self.writer_closed.store(true, Ordering::SeqCst); - } - - pub fn reader_wait_queue(&self) -> &SgxMutex> { - &self.reader_wait_queue - } - - pub fn writer_wait_queue(&self) -> &SgxMutex> { - &self.writer_wait_queue - } - - pub fn enqueue_reader_event(&self, event: IoEvent) -> Result<()> { - self.reader_wait_queue - .lock() - .unwrap() - .insert(current!().tid(), event); - Ok(()) - } - - pub fn dequeue_reader_event(&self) -> Result<()> { - self.reader_wait_queue - .lock() - .unwrap() - .remove(¤t!().tid()) - .unwrap(); - Ok(()) - } - - pub fn enqueue_writer_event(&self, event: IoEvent) -> Result<()> { - self.writer_wait_queue - .lock() - .unwrap() - .insert(current!().tid(), event); - Ok(()) - } - - pub fn dequeue_writer_event(&self) -> Result<()> { - self.writer_wait_queue - .lock() - .unwrap() - .remove(¤t!().tid()) - .unwrap(); - Ok(()) - } - - pub fn blocking_read(&self) -> bool { - self.blocking_read.load(Ordering::SeqCst) - } - - pub fn set_non_blocking_read(&self) { - self.blocking_read.store(false, Ordering::SeqCst); - } - - pub fn set_blocking_read(&self) { - self.blocking_read.store(true, Ordering::SeqCst); - } - - pub fn blocking_write(&self) -> bool { - self.blocking_write.load(Ordering::SeqCst) - } - - pub fn set_non_blocking_write(&self) { - self.blocking_write.store(false, Ordering::SeqCst); - } - - pub fn set_blocking_write(&self) { - self.blocking_write.store(true, Ordering::SeqCst); - } -} - -pub struct RingBufReader { - inner: Consumer, - buffer: Arc, -} - -impl RingBufReader { - pub fn can_read(&self) -> bool { - self.bytes_to_read() != 0 - } - - pub fn read_from_buffer(&mut self, buffer: &mut [u8]) -> Result { - self.read(Some(buffer), None) - } - - pub fn read_from_vector(&mut self, buffers: &mut [&mut [u8]]) -> Result { - self.read(None, Some(buffers)) - } - - fn read( - &mut self, - buffer: Option<&mut [u8]>, - buffers: Option<&mut [&mut [u8]]>, - ) -> Result { - assert!(buffer.is_some() ^ buffers.is_some()); - // In case of write after can_read is false - let lock_ref = self.buffer.lock.clone(); - let lock_holder = lock_ref.lock(); - - if self.can_read() { - let count = if buffer.is_some() { - self.inner.pop_slice(buffer.unwrap()) - } else { - self.pop_slices(buffers.unwrap()) - }; - assert!(count > 0); - self.read_end(); - Ok(count) - } else { - if self.is_peer_closed() { - return Ok(0); - } - - if !self.buffer.blocking_read() { - return_errno!(EAGAIN, "No data to read"); - } else { - // Clear the status of notifier before enqueue - clear_notifier_status(current!().tid())?; - self.enqueue_event(IoEvent::BlockingRead)?; - drop(lock_holder); - drop(lock_ref); - let ret = wait_for_notification(); - self.dequeue_event()?; - ret?; - - let lock_ref = self.buffer.lock.clone(); - let lock_holder = lock_ref.lock(); - let count = if buffer.is_some() { - self.inner.pop_slice(buffer.unwrap()) - } else { - self.pop_slices(buffers.unwrap()) - }; - - if count > 0 { - self.read_end()?; - } else { - assert!(self.is_peer_closed()); - } - Ok(count) - } - } - } - - fn pop_slices(&mut self, buffers: &mut [&mut [u8]]) -> usize { - let mut total = 0; - for buf in buffers { - let count = self.inner.pop_slice(buf); - total += count; - if count < buf.len() { - break; - } - } - total - } - - pub fn bytes_to_read(&self) -> usize { - self.inner.len() - } - - fn read_end(&self) -> Result<()> { - for (tid, event) in &*self.buffer.writer_wait_queue().lock().unwrap() { - match event { - IoEvent::Poll(poll_events) => { - if !(poll_events.events() - & (PollEventFlags::POLLOUT | PollEventFlags::POLLWRNORM)) - .is_empty() - { - notify_thread(*tid)?; - } - } - IoEvent::Epoll(epoll_file) => unimplemented!(), - IoEvent::BlockingRead => unreachable!(), - IoEvent::BlockingWrite => notify_thread(*tid)?, - } - } - Ok(()) - } - - pub fn is_peer_closed(&self) -> bool { - self.buffer.is_writer_closed() - } - - pub fn enqueue_event(&self, event: IoEvent) -> Result<()> { - self.buffer.enqueue_reader_event(event) - } - - pub fn dequeue_event(&self) -> Result<()> { - self.buffer.dequeue_reader_event() - } - - pub fn set_non_blocking(&self) { - self.buffer.set_non_blocking_read() - } - - pub fn set_blocking(&self) { - self.buffer.set_blocking_read() - } - - fn before_drop(&self) { - for (tid, event) in &*self.buffer.writer_wait_queue().lock().unwrap() { - match event { - IoEvent::Poll(_) | IoEvent::BlockingWrite => notify_thread(*tid).unwrap(), - IoEvent::Epoll(epoll_file) => unimplemented!(), - IoEvent::BlockingRead => unreachable!(), - } - } - } -} - -impl Drop for RingBufReader { - fn drop(&mut self) { - debug!("reader drop"); - self.buffer.close_reader(); - if self.buffer.blocking_write() { - self.before_drop(); - } - } -} - -pub struct RingBufWriter { - inner: Producer, - buffer: Arc, -} - -impl RingBufWriter { - pub fn write_to_buffer(&mut self, buffer: &[u8]) -> Result { - self.write(Some(buffer), None) - } - - pub fn write_to_vector(&mut self, buffers: &[&[u8]]) -> Result { - self.write(None, Some(buffers)) - } - - fn write(&mut self, buffer: Option<&[u8]>, buffers: Option<&[&[u8]]>) -> Result { - assert!(buffer.is_some() ^ buffers.is_some()); - - // TODO: send SIGPIPE to the caller - if self.is_peer_closed() { - return_errno!(EPIPE, "reader side is closed"); - } - - // In case of read after can_write is false - let lock_ref = self.buffer.lock.clone(); - let lock_holder = lock_ref.lock(); - - if self.can_write() { - let count = if buffer.is_some() { - self.inner.push_slice(buffer.unwrap()) - } else { - self.push_slices(buffers.unwrap()) - }; - assert!(count > 0); - self.write_end(); - Ok(count) - } else { - if !self.buffer.blocking_write() { - return_errno!(EAGAIN, "No space to write"); - } - - // Clear the status of notifier before enqueue - clear_notifier_status(current!().tid()); - self.enqueue_event(IoEvent::BlockingWrite)?; - drop(lock_holder); - drop(lock_ref); - let ret = wait_for_notification(); - self.dequeue_event()?; - ret?; - - let lock_ref = self.buffer.lock.clone(); - let lock_holder = lock_ref.lock(); - let count = if buffer.is_some() { - self.inner.push_slice(buffer.unwrap()) - } else { - self.push_slices(buffers.unwrap()) - }; - - if count > 0 { - self.write_end(); - Ok(count) - } else { - return_errno!(EPIPE, "reader side is closed"); - } - } - } - - fn write_end(&self) -> Result<()> { - for (tid, event) in &*self.buffer.reader_wait_queue().lock().unwrap() { - match event { - IoEvent::Poll(poll_events) => { - if !(poll_events.events() - & (PollEventFlags::POLLIN | PollEventFlags::POLLRDNORM)) - .is_empty() - { - notify_thread(*tid)?; - } - } - IoEvent::Epoll(epoll_file) => unimplemented!(), - IoEvent::BlockingRead => notify_thread(*tid)?, - IoEvent::BlockingWrite => unreachable!(), - } - } - Ok(()) - } - - fn push_slices(&mut self, buffers: &[&[u8]]) -> usize { - let mut total = 0; - for buf in buffers { - let count = self.inner.push_slice(buf); - total += count; - if count < buf.len() { - break; - } - } - total - } - - pub fn can_write(&self) -> bool { - !self.inner.is_full() - } - - pub fn is_peer_closed(&self) -> bool { - self.buffer.is_reader_closed() - } - - pub fn enqueue_event(&self, event: IoEvent) -> Result<()> { - self.buffer.enqueue_writer_event(event) - } - - pub fn dequeue_event(&self) -> Result<()> { - self.buffer.dequeue_writer_event() - } - - pub fn set_non_blocking(&self) { - self.buffer.set_non_blocking_write() - } - - pub fn set_blocking(&self) { - self.buffer.set_blocking_write() - } - - fn before_drop(&self) { - for (tid, event) in &*self.buffer.reader_wait_queue().lock().unwrap() { - match event { - IoEvent::Poll(_) | IoEvent::BlockingRead => { - notify_thread(*tid).unwrap(); - } - IoEvent::Epoll(epoll_file) => unimplemented!(), - IoEvent::BlockingWrite => unreachable!(), - } - } - } -} - -impl Drop for RingBufWriter { - fn drop(&mut self) { - debug!("writer drop"); - self.buffer.close_writer(); - if self.buffer.blocking_read() { - self.before_drop(); - } - } -} diff --git a/test/unix_socket/main.c b/test/unix_socket/main.c index 13044706..82692fa8 100644 --- a/test/unix_socket/main.c +++ b/test/unix_socket/main.c @@ -203,6 +203,40 @@ int test_poll() { return 0; } +int test_getname() { + char name[] = "unix_socket_path"; + int sock = socket(AF_UNIX, SOCK_STREAM, 0); + if (sock == -1) { + THROW_ERROR("failed to create a unix socket"); + } + + struct sockaddr_un addr = {0}; + memset(&addr, 0, sizeof(struct sockaddr_un)); //Clear structure + addr.sun_family = AF_UNIX; + strcpy(addr.sun_path, name); + socklen_t addr_len = strlen(addr.sun_path) + sizeof(addr.sun_family) + 1; + if (bind(sock, (struct sockaddr *)&addr, addr_len) == -1) { + close(sock); + THROW_ERROR("failed to bind"); + } + + struct sockaddr_un ret_addr = {0}; + socklen_t ret_addr_len = sizeof(ret_addr); + + if (getsockname(sock, (struct sockaddr *)&ret_addr, &ret_addr_len) < 0) { + close(sock); + THROW_ERROR("failed to getsockname"); + } + + if (ret_addr_len != addr_len || strcmp(ret_addr.sun_path, name) != 0) { + close(sock); + THROW_ERROR("got name mismatched"); + } + + close(sock); + return 0; +} + static test_case_t test_cases[] = { TEST_CASE(test_unix_socket_inter_process), TEST_CASE(test_socketpair_inter_process), @@ -210,6 +244,7 @@ static test_case_t test_cases[] = { // TODO: recover the test after the unix sockets are rewritten by using // the new event subsystem //TEST_CASE(test_poll), + TEST_CASE(test_getname), }; int main(int argc, const char *argv[]) {