Refactor Unix socket

1. Implement type-safe functions;
2. Improve the correctness of nearly all the functions;
3. Improve the readability by introducing Listener and Endpoint for StreamUnix;
4. Substitue RingBuf with Channel in Unix socket.
This commit is contained in:
He Sun 2020-10-29 15:59:49 +08:00 committed by Tate, Hongliang Tian
parent a09c01819b
commit 3b915db774
25 changed files with 1138 additions and 965 deletions

10
src/libos/Cargo.lock generated

@ -11,6 +11,7 @@ dependencies = [
"derive_builder", "derive_builder",
"lazy_static", "lazy_static",
"log", "log",
"memoffset",
"rcore-fs", "rcore-fs",
"rcore-fs-mountfs", "rcore-fs-mountfs",
"rcore-fs-ramfs", "rcore-fs-ramfs",
@ -242,6 +243,15 @@ dependencies = [
"cfg-if", "cfg-if",
] ]
[[package]]
name = "memoffset"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "157b4208e3059a8f9e78d559edc658e13df41410cb3ae03979c83130067fdd87"
dependencies = [
"autocfg 1.0.1",
]
[[package]] [[package]]
name = "proc-macro2" name = "proc-macro2"
version = "1.0.19" version = "1.0.19"

@ -23,6 +23,7 @@ rcore-fs-mountfs = { path = "../../deps/sefs/rcore-fs-mountfs" }
rcore-fs-unionfs = { path = "../../deps/sefs/rcore-fs-unionfs" } rcore-fs-unionfs = { path = "../../deps/sefs/rcore-fs-unionfs" }
serde = { path = "../../deps/serde-sgx/serde", features = ["derive"] } serde = { path = "../../deps/serde-sgx/serde", features = ["derive"] }
serde_json = { path = "../../deps/serde-json-sgx" } serde_json = { path = "../../deps/serde-json-sgx" }
memoffset = "0.6.1"
[patch.'https://github.com/apache/teaclave-sgx-sdk.git'] [patch.'https://github.com/apache/teaclave-sgx-sdk.git']
sgx_tstd = { path = "../../deps/rust-sgx-sdk/sgx_tstd" } sgx_tstd = { path = "../../deps/rust-sgx-sdk/sgx_tstd" }

@ -76,6 +76,20 @@ impl<I> Channel<I> {
let Channel { producer, consumer } = self; let Channel { producer, consumer } = self;
(producer, consumer) (producer, consumer)
} }
pub fn items_to_consume(&self) -> usize {
self.consumer.items_to_consume()
}
pub fn set_nonblocking(&self, nonblocking: bool) {
self.consumer.set_nonblocking(nonblocking);
self.producer.set_nonblocking(nonblocking);
}
pub fn shutdown(&self) {
self.consumer.shutdown();
self.producer.shutdown();
}
} }
impl<I: Copy> Channel<I> { impl<I: Copy> Channel<I> {
@ -264,7 +278,12 @@ impl<I> Producer<I> {
impl<I: Copy> Producer<I> { impl<I: Copy> Producer<I> {
pub fn push_slice(&self, items: &[I]) -> Result<usize> { pub fn push_slice(&self, items: &[I]) -> Result<usize> {
if items.len() == 0 { self.push_slices(&[items])
}
pub fn push_slices(&self, item_slices: &[&[I]]) -> Result<usize> {
let len: usize = item_slices.iter().map(|slice| slice.len()).sum();
if len == 0 {
return Ok(0); return Ok(0);
} }
@ -275,11 +294,21 @@ impl<I: Copy> Producer<I> {
return_errno!(EPIPE, "one or both endpoints have been shutdown"); return_errno!(EPIPE, "one or both endpoints have been shutdown");
} }
let mut total_count = 0;
for items in item_slices {
let count = rb_producer.push_slice(items); let count = rb_producer.push_slice(items);
if count > 0 { total_count += count;
if count < items.len() {
break;
} else {
continue;
}
}
if total_count > 0 {
drop(rb_producer); drop(rb_producer);
self.trigger_peer_events(&IoEvents::IN); self.trigger_peer_events(&IoEvents::IN);
return Ok(count); return Ok(total_count);
} }
if self.is_nonblocking() { if self.is_nonblocking() {
@ -374,11 +403,20 @@ impl<I> Consumer<I> {
pub fn is_peer_shutdown(&self) -> bool { pub fn is_peer_shutdown(&self) -> bool {
self.state.is_producer_shutdown() self.state.is_producer_shutdown()
} }
pub fn items_to_consume(&self) -> usize {
self.inner.lock().unwrap().len()
}
} }
impl<I: Copy> Consumer<I> { impl<I: Copy> Consumer<I> {
pub fn pop_slice(&self, items: &mut [I]) -> Result<usize> { pub fn pop_slice(&self, items: &mut [I]) -> Result<usize> {
if items.len() == 0 { self.pop_slices(&mut [items])
}
pub fn pop_slices(&self, item_slices: &mut [&mut [I]]) -> Result<usize> {
let len: usize = item_slices.iter().map(|slice| slice.len()).sum();
if len == 0 {
return Ok(0); return Ok(0);
} }
@ -389,11 +427,21 @@ impl<I: Copy> Consumer<I> {
return_errno!(EPIPE, "this endpoint has been shutdown"); return_errno!(EPIPE, "this endpoint has been shutdown");
} }
let mut total_count = 0;
for items in item_slices.iter_mut() {
let count = rb_consumer.pop_slice(items); let count = rb_consumer.pop_slice(items);
if count > 0 { total_count += count;
if count < items.len() {
break;
} else {
continue;
}
}
if total_count > 0 {
drop(rb_consumer); drop(rb_consumer);
self.trigger_peer_events(&IoEvents::OUT); self.trigger_peer_events(&IoEvents::OUT);
return Ok(count); return Ok(total_count);
}; };
if self.is_peer_shutdown() { if self.is_peer_shutdown() {

@ -41,27 +41,7 @@ impl File for PipeReader {
} }
fn readv(&self, bufs: &mut [&mut [u8]]) -> Result<usize> { fn readv(&self, bufs: &mut [&mut [u8]]) -> Result<usize> {
let mut total_count = 0; self.consumer.pop_slices(bufs)
for buf in bufs {
match self.consumer.pop_slice(buf) {
Ok(count) => {
total_count += count;
if count < buf.len() {
break;
} else {
continue;
}
}
Err(e) => {
if total_count > 0 {
break;
} else {
return Err(e);
}
}
}
}
Ok(total_count)
} }
fn get_access_mode(&self) -> Result<AccessMode> { fn get_access_mode(&self) -> Result<AccessMode> {
@ -120,27 +100,7 @@ impl File for PipeWriter {
} }
fn writev(&self, bufs: &[&[u8]]) -> Result<usize> { fn writev(&self, bufs: &[&[u8]]) -> Result<usize> {
let mut total_count = 0; self.producer.push_slices(bufs)
for buf in bufs {
match self.producer.push_slice(buf) {
Ok(count) => {
total_count += count;
if count < buf.len() {
break;
} else {
continue;
}
}
Err(e) => {
if total_count > 0 {
break;
} else {
return Err(e);
}
}
}
}
Ok(total_count)
} }
fn seek(&self, pos: SeekFrom) -> Result<off_t> { fn seek(&self, pos: SeekFrom) -> Result<off_t> {

@ -17,6 +17,7 @@
// for UntrustedSliceAlloc in slice_alloc // for UntrustedSliceAlloc in slice_alloc
#![feature(slice_ptr_get)] #![feature(slice_ptr_get)]
#![feature(maybe_uninit_extra)] #![feature(maybe_uninit_extra)]
#![feature(get_mut_unchecked)]
#[macro_use] #[macro_use]
extern crate alloc; extern crate alloc;
@ -46,6 +47,8 @@ extern crate derive_builder;
extern crate ringbuf; extern crate ringbuf;
extern crate serde; extern crate serde;
extern crate serde_json; extern crate serde_json;
#[macro_use]
extern crate memoffset;
use sgx_trts::libc; use sgx_trts::libc;
use sgx_types::*; use sgx_types::*;

@ -7,9 +7,9 @@ pub use self::io_multiplexing::{
PollEventFlags, PollFd, THREAD_NOTIFIERS, PollEventFlags, PollFd, THREAD_NOTIFIERS,
}; };
pub use self::socket::{ pub use self::socket::{
msghdr, msghdr_mut, AddressFamily, AsUnixSocket, FileFlags, HostSocket, HostSocketType, Iovs, msghdr, msghdr_mut, socketpair, unix_socket, AddressFamily, AsUnixSocket, FileFlags,
IovsMut, MsgHdr, MsgHdrFlags, MsgHdrMut, RecvFlags, SendFlags, SliceAsLibcIovec, SockAddr, HostSocket, HostSocketType, HowToShut, Iovs, IovsMut, MsgHdr, MsgHdrFlags, MsgHdrMut,
SocketType, UnixSocketFile, RecvFlags, SendFlags, SliceAsLibcIovec, SockAddr, SocketType, UnixAddr,
}; };
pub use self::syscalls::*; pub use self::syscalls::*;

@ -132,6 +132,11 @@ impl HostSocket {
pub fn raw_host_fd(&self) -> FileDesc { pub fn raw_host_fd(&self) -> FileDesc {
self.host_fd.to_raw() self.host_fd.to_raw()
} }
pub fn shutdown(&self, how: HowToShut) -> Result<()> {
try_libc!(libc::ocall::shutdown(self.raw_host_fd() as i32, how.bits()));
Ok(())
}
} }
pub trait HostSocketType { pub trait HostSocketType {

@ -2,18 +2,20 @@ use super::*;
mod address_family; mod address_family;
mod flags; mod flags;
mod host_socket; mod host;
mod iovs; mod iovs;
mod msg; mod msg;
mod shutdown;
mod socket_address; mod socket_address;
mod socket_type; mod socket_type;
mod unix_socket; mod unix;
pub use self::address_family::AddressFamily; pub use self::address_family::AddressFamily;
pub use self::flags::{FileFlags, MsgHdrFlags, RecvFlags, SendFlags}; pub use self::flags::{FileFlags, MsgHdrFlags, RecvFlags, SendFlags};
pub use self::host_socket::{HostSocket, HostSocketType}; pub use self::host::{HostSocket, HostSocketType};
pub use self::iovs::{Iovs, IovsMut, SliceAsLibcIovec}; pub use self::iovs::{Iovs, IovsMut, SliceAsLibcIovec};
pub use self::msg::{msghdr, msghdr_mut, MsgHdr, MsgHdrMut}; pub use self::msg::{msghdr, msghdr_mut, MsgHdr, MsgHdrMut};
pub use self::shutdown::HowToShut;
pub use self::socket_address::SockAddr; pub use self::socket_address::SockAddr;
pub use self::socket_type::SocketType; pub use self::socket_type::SocketType;
pub use self::unix_socket::{AsUnixSocket, UnixSocketFile}; pub use self::unix::{socketpair, unix_socket, AsUnixSocket, UnixAddr};

@ -0,0 +1,28 @@
use super::*;
bitflags! {
pub struct HowToShut: c_int {
const READ = 0;
const WRITE = 1;
const BOTH = 2;
}
}
impl HowToShut {
pub fn try_from_raw(how: c_int) -> Result<Self> {
match how {
0 => Ok(Self::READ),
1 => Ok(Self::WRITE),
2 => Ok(Self::BOTH),
_ => return_errno!(EINVAL, "invalid how"),
}
}
pub fn to_shut_read(&self) -> bool {
*self == Self::READ || *self == Self::BOTH
}
pub fn to_shut_write(&self) -> bool {
*self == Self::WRITE || *self == Self::BOTH
}
}

@ -0,0 +1,151 @@
use super::*;
use std::path::{Path, PathBuf};
use std::{cmp, mem, slice, str};
const MAX_PATH_LEN: usize = 108;
const SUN_FAMILY_LEN: usize = mem::size_of::<libc::sa_family_t>();
lazy_static! {
static ref SUN_PATH_OFFSET: usize = memoffset::offset_of!(libc::sockaddr_un, sun_path);
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Addr {
File(UnixPath),
Abstract(String),
}
impl Addr {
/// Caller should guarentee the sockaddr and addr_len are valid.
/// The pathname should end with a '\0' within the passed length.
/// The abstract name should both start and end with a '\0' within the passed length.
pub unsafe fn try_from_raw(
sockaddr: *const libc::sockaddr,
addr_len: libc::socklen_t,
) -> Result<Self> {
let addr_len = addr_len as usize;
// TODO: support autobind to validate when addr_len == SUN_FAMILY_LEN
if addr_len <= SUN_FAMILY_LEN {
return_errno!(EINVAL, "the address is too short.");
}
if addr_len > MAX_PATH_LEN + *SUN_PATH_OFFSET {
return_errno!(EINVAL, "the address is too long.");
}
if AddressFamily::try_from((*sockaddr).sa_family)? != AddressFamily::LOCAL {
return_errno!(EINVAL, "not a valid address for unix socket");
}
let sockaddr = sockaddr as *const libc::sockaddr_un;
let sun_path = (*sockaddr).sun_path;
if sun_path[0] == 0 {
let path_ptr = sun_path[1..(addr_len - *SUN_PATH_OFFSET)].as_ptr();
let path_slice =
slice::from_raw_parts(path_ptr as *const u8, addr_len - *SUN_PATH_OFFSET - 1);
Ok(Self::Abstract(
str::from_utf8(&path_slice).unwrap().to_string(),
))
} else {
let path_cstr = CStr::from_ptr(sun_path.as_ptr());
if path_cstr.to_bytes_with_nul().len() > MAX_PATH_LEN {
return_errno!(EINVAL, "no null in the address");
}
Ok(Self::File(UnixPath::new(path_cstr.to_str().unwrap())))
}
}
pub fn copy_to_slice(&self, dst: &mut [u8]) -> usize {
let (raw_addr, addr_len) = self.to_raw();
let src =
unsafe { std::slice::from_raw_parts(&raw_addr as *const _ as *const u8, addr_len) };
let copied = std::cmp::min(dst.len(), addr_len);
dst[..copied].copy_from_slice(&src[..copied]);
copied
}
pub fn raw_len(&self) -> usize {
/// The '/0' at the end of Self::File counts
self.path_str().len()
+ 1
+ *SUN_PATH_OFFSET
}
pub fn path_str(&self) -> &str {
match self {
Self::File(unix_path) => &unix_path.path_str(),
Self::Abstract(path) => &path,
}
}
fn to_raw(&self) -> (libc::sockaddr_un, usize) {
let mut addr: libc::sockaddr_un = unsafe { mem::zeroed() };
addr.sun_family = AddressFamily::LOCAL as libc::sa_family_t;
let addr_len = match self {
Self::File(unix_path) => {
let path_str = unix_path.path_str();
let buf_len = path_str.len();
/// addr is initialized to all zeros and try_from_raw guarentees
/// unix_path length is shorter than sun_path, so sun_path here
/// will always have a null terminator
addr.sun_path[..buf_len]
.copy_from_slice(unsafe { &*(path_str.as_bytes() as *const _ as *const [i8]) });
buf_len + *SUN_PATH_OFFSET + 1
}
Self::Abstract(path_str) => {
addr.sun_path[0] = 0;
let buf_len = path_str.len() + 1;
addr.sun_path[1..buf_len]
.copy_from_slice(unsafe { &*(path_str.as_bytes() as *const _ as *const [i8]) });
buf_len + *SUN_PATH_OFFSET
}
};
(addr, addr_len)
}
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct UnixPath {
inner: PathBuf,
/// Holds the cwd when a relative path is created
cwd: Option<String>,
}
impl UnixPath {
pub fn new(path: &str) -> Self {
let inner = PathBuf::from(path);
let is_absolute = inner.is_absolute();
Self {
inner: inner,
cwd: if is_absolute {
None
} else {
let thread = current!();
let fs = thread.fs().lock().unwrap();
let cwd = fs.cwd().to_owned();
Some(cwd)
},
}
}
pub fn absolute(&self) -> String {
let path_str = self.path_str();
if self.inner.is_absolute() {
path_str.to_string()
} else {
let mut prefix = path_str.to_owned();
prefix.push_str(self.cwd.as_ref().unwrap());
prefix
}
}
pub fn path_str(&self) -> &str {
self.inner.to_str().unwrap()
}
}

@ -0,0 +1,49 @@
use self::addr::Addr;
use super::*;
mod addr;
mod stream;
pub use self::addr::Addr as UnixAddr;
pub use self::stream::Stream;
//TODO: rewrite this file when a new kind of uds is added
pub fn unix_socket(socket_type: SocketType, flags: FileFlags, protocol: i32) -> Result<Stream> {
if protocol != 0 && protocol != AddressFamily::LOCAL as i32 {
return_errno!(EPROTONOSUPPORT, "protocol is not supported");
}
if socket_type == SocketType::STREAM {
Ok(Stream::new(flags))
} else {
return_errno!(ESOCKTNOSUPPORT, "only stream type is supported");
}
}
pub fn socketpair(
socket_type: SocketType,
flags: FileFlags,
protocol: i32,
) -> Result<(Stream, Stream)> {
if protocol != 0 && protocol != AddressFamily::LOCAL as i32 {
return_errno!(EPROTONOSUPPORT, "protocol is not supported");
}
if socket_type == SocketType::STREAM {
Stream::socketpair(flags)
} else {
return_errno!(ESOCKTNOSUPPORT, "only stream type is supported");
}
}
pub trait AsUnixSocket {
fn as_unix_socket(&self) -> Result<&Stream>;
}
impl AsUnixSocket for FileRef {
fn as_unix_socket(&self) -> Result<&Stream> {
self.as_any()
.downcast_ref::<Stream>()
.ok_or_else(|| errno!(EBADF, "not a unix socket"))
}
}

@ -0,0 +1,93 @@
use super::endpoint::Endpoint;
use super::stream::Listener;
use super::*;
use std::collections::btree_map::BTreeMap;
lazy_static! {
pub(super) static ref ADDRESS_SPACE: AddressSpace = AddressSpace::new();
}
pub struct AddressSpace {
file: SgxMutex<BTreeMap<String, Option<Arc<Listener>>>>,
abstr: SgxMutex<BTreeMap<String, Option<Arc<Listener>>>>,
}
impl AddressSpace {
pub fn new() -> Self {
Self {
file: SgxMutex::new(BTreeMap::new()),
abstr: SgxMutex::new(BTreeMap::new()),
}
}
pub fn add_binder(&self, addr: &Addr) -> Result<()> {
let key = Self::get_key(addr);
let mut space = self.get_space(addr);
if space.contains_key(&key) {
return_errno!(EADDRINUSE, "the addr is already bound");
} else {
space.insert(key, None);
Ok(())
}
}
pub fn add_listener(&self, addr: &Addr, capacity: usize) -> Result<()> {
let key = Self::get_key(addr);
let mut space = self.get_space(addr);
if let Some(option) = space.get(&key) {
if let Some(listener) = option {
let new_listener = Listener::new(capacity)?;
for i in 0..std::cmp::min(listener.remaining(), capacity) {
new_listener.push_incoming(listener.pop_incoming().unwrap());
}
space.insert(key, Some(Arc::new(new_listener)));
} else {
space.insert(key, Some(Arc::new(Listener::new(capacity)?)));
}
Ok(())
} else {
return_errno!(EINVAL, "the socket is not bound");
}
}
pub fn push_incoming(&self, addr: &Addr, sock: Endpoint) -> Result<()> {
self.get_listener_ref(addr)
.ok_or_else(|| errno!(ECONNREFUSED, "no one's listening on the remote address"))?
.push_incoming(sock);
Ok(())
}
pub fn pop_incoming(&self, addr: &Addr) -> Result<Endpoint> {
self.get_listener_ref(addr)
.ok_or_else(|| errno!(EINVAL, "the socket is not listening"))?
.pop_incoming()
.ok_or_else(|| errno!(EAGAIN, "No connection is incoming"))
}
pub fn get_listener_ref(&self, addr: &Addr) -> Option<Arc<Listener>> {
let key = Self::get_key(addr);
let space = self.get_space(addr);
space.get(&key).map(|x| x.clone()).flatten()
}
pub fn remove_addr(&self, addr: &Addr) {
let key = Self::get_key(addr);
let mut space = self.get_space(addr);
space.remove(&key);
}
fn get_space(&self, addr: &Addr) -> SgxMutexGuard<'_, BTreeMap<String, Option<Arc<Listener>>>> {
match addr {
Addr::File(unix_path) => self.file.lock().unwrap(),
Addr::Abstract(path) => self.abstr.lock().unwrap(),
}
}
fn get_key(addr: &Addr) -> String {
match addr {
Addr::File(unix_path) => unix_path.absolute(),
Addr::Abstract(path) => addr.path_str().to_string(),
}
}
}

@ -0,0 +1,110 @@
use super::*;
use alloc::sync::{Arc, Weak};
use fs::channel::{Channel, Consumer, Producer};
pub type Endpoint = Arc<Inner>;
/// Constructor of two connected Endpoints
pub fn end_pair(nonblocking: bool) -> Result<(Endpoint, Endpoint)> {
let (pro_a, con_a) = Channel::new(DEFAULT_BUF_SIZE)?.split();
let (pro_b, con_b) = Channel::new(DEFAULT_BUF_SIZE)?.split();
let mut end_a = Arc::new(Inner {
addr: RwLock::new(None),
reader: con_a,
writer: pro_b,
peer: Weak::default(),
});
let end_b = Arc::new(Inner {
addr: RwLock::new(None),
reader: con_b,
writer: pro_a,
peer: Arc::downgrade(&end_a),
});
unsafe {
Arc::get_mut_unchecked(&mut end_a).peer = Arc::downgrade(&end_b);
}
end_a.set_nonblocking(nonblocking);
end_b.set_nonblocking(nonblocking);
Ok((end_a, end_b))
}
/// One end of the connected unix socket
pub struct Inner {
addr: RwLock<Option<Addr>>,
reader: Consumer<u8>,
writer: Producer<u8>,
peer: Weak<Self>,
}
impl Inner {
pub fn addr(&self) -> Option<Addr> {
self.addr.read().unwrap().clone()
}
pub fn set_addr(&self, addr: &Addr) {
*self.addr.write().unwrap() = Some(addr.clone());
}
pub fn peer_addr(&self) -> Option<Addr> {
self.peer.upgrade().map(|end| end.addr().clone()).flatten()
}
pub fn set_nonblocking(&self, nonblocking: bool) {
self.reader.set_nonblocking(nonblocking);
self.writer.set_nonblocking(nonblocking);
}
pub fn nonblocking(&self) -> bool {
let cons_nonblocking = self.reader.is_nonblocking();
let prod_nonblocking = self.writer.is_nonblocking();
assert_eq!(cons_nonblocking, prod_nonblocking);
cons_nonblocking
}
pub fn read(&self, buf: &mut [u8]) -> Result<usize> {
self.reader.pop_slice(buf)
}
pub fn write(&self, buf: &[u8]) -> Result<usize> {
self.writer.push_slice(buf)
}
pub fn readv(&self, bufs: &mut [&mut [u8]]) -> Result<usize> {
self.reader.pop_slices(bufs)
}
pub fn writev(&self, bufs: &[&[u8]]) -> Result<usize> {
self.writer.push_slices(bufs)
}
pub fn bytes_to_read(&self) -> usize {
self.reader.items_to_consume()
}
pub fn shutdown(&self, how: HowToShut) -> Result<()> {
if !self.is_connected() {
return_errno!(ENOTCONN, "The socket is not connected.");
}
if how.to_shut_read() {
self.reader.shutdown()
}
if how.to_shut_write() {
self.writer.shutdown()
}
Ok(())
}
fn is_connected(&self) -> bool {
self.peer.upgrade().is_some()
}
}
// TODO: Add SO_SNDBUF and SO_RCVBUF to set/getsockopt to dynamcally change the size.
// This value is got from /proc/sys/net/core/rmem_max and wmem_max that are same on linux.
pub const DEFAULT_BUF_SIZE: usize = 208 * 1024;

@ -0,0 +1,95 @@
use super::stream::Status;
use super::*;
use fs::{AccessMode, File, FileRef, IoctlCmd, StatusFlags};
use std::any::Any;
impl File for Stream {
fn read(&self, buf: &mut [u8]) -> Result<usize> {
match &*self.inner() {
Status::Connected(endpoint) => endpoint.read(buf),
_ => return_errno!(ENOTCONN, "unconnected socket"),
}
}
fn write(&self, buf: &[u8]) -> Result<usize> {
match &*self.inner() {
Status::Connected(endpoint) => endpoint.write(buf),
_ => return_errno!(ENOTCONN, "unconnected socket"),
}
}
fn read_at(&self, offset: usize, buf: &mut [u8]) -> Result<usize> {
if offset != 0 {
return_errno!(ESPIPE, "a nonzero position is not supported");
}
self.read(buf)
}
fn write_at(&self, offset: usize, buf: &[u8]) -> Result<usize> {
if offset != 0 {
return_errno!(ESPIPE, "a nonzero position is not supported");
}
self.write(buf)
}
fn readv(&self, bufs: &mut [&mut [u8]]) -> Result<usize> {
match &*self.inner() {
Status::Connected(endpoint) => endpoint.readv(bufs),
_ => return_errno!(ENOTCONN, "unconnected socket"),
}
}
fn writev(&self, bufs: &[&[u8]]) -> Result<usize> {
match &*self.inner() {
Status::Connected(endpoint) => endpoint.writev(bufs),
_ => return_errno!(ENOTCONN, "unconnected socket"),
}
}
fn ioctl(&self, cmd: &mut IoctlCmd) -> Result<i32> {
match cmd {
IoctlCmd::FIONREAD(arg) => match &*self.inner() {
Status::Connected(endpoint) => {
let bytes_to_read = endpoint.bytes_to_read().min(std::i32::MAX as usize) as i32;
**arg = bytes_to_read;
Ok(0)
}
_ => return_errno!(ENOTCONN, "unconnected socket"),
},
_ => return_errno!(EINVAL, "unknown ioctl cmd for unix socket"),
}
}
fn get_access_mode(&self) -> Result<AccessMode> {
Ok(AccessMode::O_RDWR)
}
fn get_status_flags(&self) -> Result<StatusFlags> {
if self.nonblocking() {
Ok(StatusFlags::O_NONBLOCK)
} else {
Ok(StatusFlags::empty())
}
}
fn set_status_flags(&self, new_status_flags: StatusFlags) -> Result<()> {
// Only O_NONBLOCK, O_ASYNC and O_DIRECT can be set
let status_flags = new_status_flags
& (StatusFlags::O_NONBLOCK | StatusFlags::O_ASYNC | StatusFlags::O_DIRECT);
// Only O_NONBLOCK is supported
let nonblocking = new_status_flags.contains(StatusFlags::O_NONBLOCK);
self.set_nonblocking(nonblocking);
Ok(())
}
fn poll(&self) -> Result<PollEventFlags> {
warn!("poll is not supported for unix_socket");
let events = PollEventFlags::empty();
Ok(events)
}
fn as_any(&self) -> &dyn Any {
self
}
}

@ -0,0 +1,8 @@
use super::*;
mod address_space;
mod endpoint;
mod file;
mod stream;
pub use stream::Stream;

@ -0,0 +1,325 @@
use super::address_space::ADDRESS_SPACE;
use super::endpoint::{end_pair, Endpoint};
use super::*;
use alloc::sync::Arc;
use fs::channel::Channel;
use std::fmt;
use std::sync::atomic::{AtomicBool, Ordering};
/// SOCK_STREAM Unix socket. It has three statuses: unconnected, listening and connected. When a
/// socket is created, it is in unconnected status. It will transfer to listening after listen is
/// called and connected after connect is called. A socket in connected status can be obtained
/// through a listening socket calling accept. Listening and connected are ultimate statuses. They
/// will not transfer to other statuses.
pub struct Stream {
inner: SgxMutex<Status>,
}
impl Stream {
pub fn new(flags: FileFlags) -> Self {
Self {
inner: SgxMutex::new(Status::Unconnected(Info::new(
flags.contains(FileFlags::SOCK_NONBLOCK),
))),
}
}
pub fn socketpair(flags: FileFlags) -> Result<(Self, Self)> {
let nonblocking = flags.contains(FileFlags::SOCK_NONBLOCK);
let (end_a, end_b) = end_pair(nonblocking)?;
let socket_a = Self {
inner: SgxMutex::new(Status::Connected(end_a)),
};
let socket_b = Self {
inner: SgxMutex::new(Status::Connected(end_b)),
};
Ok((socket_a, socket_b))
}
pub fn addr(&self) -> Option<Addr> {
match &*self.inner() {
Status::Unconnected(info) => info.addr().clone(),
Status::Connected(endpoint) => endpoint.addr(),
Status::Listening(addr) => Some(addr).cloned(),
}
}
pub fn peer_addr(&self) -> Result<Addr> {
if let Status::Connected(endpoint) = &*self.inner() {
if let Some(addr) = endpoint.peer_addr() {
return Ok(addr);
}
}
return_errno!(ENOTCONN, "the socket is not connected");
}
// TODO: create the corresponding file in the fs
pub fn bind(&self, addr: &Addr) -> Result<()> {
match &mut *self.inner() {
Status::Unconnected(ref mut info) => {
if info.addr().is_some() {
return_errno!(EINVAL, "the socket is already bound");
}
// check the global address space to see if the address is avaiable before bind
ADDRESS_SPACE.add_binder(addr)?;
info.set_addr(addr);
}
Status::Connected(endpoint) => {
if endpoint.addr().is_some() {
return_errno!(EINVAL, "the socket is already bound");
}
ADDRESS_SPACE.add_binder(addr)?;
endpoint.set_addr(addr);
}
Status::Listening(_) => return_errno!(EINVAL, "the socket is already bound"),
}
Ok(())
}
pub fn listen(&self, backlog: i32) -> Result<()> {
//TODO: restrict backlog accroding to /proc/sys/net/core/somaxconn
if backlog < 0 {
return_errno!(EINVAL, "negative backlog is not supported");
}
let capacity = backlog as usize;
let mut inner = self.inner();
match &*inner {
Status::Unconnected(info) => {
if let Some(addr) = info.addr() {
ADDRESS_SPACE.add_listener(addr, capacity)?;
*inner = Status::Listening(addr.clone());
} else {
return_errno!(EINVAL, "the socket is not bound");
}
}
Status::Connected(_) => return_errno!(EINVAL, "the socket is already connected"),
/// Modify the capacity of the channel holding incoming sockets
Status::Listening(addr) => ADDRESS_SPACE.add_listener(&addr, capacity)?,
}
Ok(())
}
pub fn connect(&self, addr: &Addr) -> Result<()> {
debug!("connect to {:?}", addr);
let mut inner = self.inner();
match &*inner {
Status::Unconnected(info) => {
let self_addr_opt = info.addr();
if let Some(self_addr) = self_addr_opt {
if self_addr == addr {
return_errno!(EINVAL, "self connect is not supported");
}
}
let (end_self, end_incoming) = end_pair(info.nonblocking())?;
end_incoming.set_addr(addr);
if let Some(self_addr) = self_addr_opt {
end_self.set_addr(self_addr);
}
ADDRESS_SPACE.push_incoming(addr, end_incoming)?;
*inner = Status::Connected(end_self);
Ok(())
}
Status::Connected(endpoint) => return_errno!(EISCONN, "already connected"),
Status::Listening(addr) => return_errno!(EINVAL, "invalid socket for connect"),
}
}
pub fn accept(&self, flags: FileFlags) -> Result<(Self, Option<Addr>)> {
match &*self.inner() {
Status::Listening(addr) => {
let endpoint = ADDRESS_SPACE.pop_incoming(&addr)?;
endpoint.set_nonblocking(flags.contains(FileFlags::SOCK_NONBLOCK));
let peer_addr = endpoint.peer_addr();
debug!("accept socket from {:?}", peer_addr);
Ok((
Self {
inner: SgxMutex::new(Status::Connected(endpoint)),
},
peer_addr,
))
}
_ => return_errno!(EINVAL, "the socket is not listening"),
}
}
// TODO: handle flags
pub fn sendto(&self, buf: &[u8], flags: SendFlags, addr: &Option<Addr>) -> Result<usize> {
self.write(buf)
}
// TODO: handle flags
pub fn recvfrom(&self, buf: &mut [u8], flags: RecvFlags) -> Result<(usize, Option<Addr>)> {
let data_len = self.read(buf)?;
let addr = self.peer_addr().ok();
debug!("recvfrom {:?}", addr);
Ok((data_len, addr))
}
/// perform shutdown on the socket.
pub fn shutdown(&self, how: HowToShut) -> Result<()> {
if let Status::Connected(ref end) = &*self.inner() {
end.shutdown(how)
} else {
return_errno!(ENOTCONN, "The socket is not connected.");
}
}
pub(super) fn nonblocking(&self) -> bool {
match &*self.inner() {
Status::Unconnected(info) => info.nonblocking(),
Status::Connected(endpoint) => endpoint.nonblocking(),
Status::Listening(addr) => ADDRESS_SPACE.get_listener_ref(&addr).unwrap().nonblocking(),
}
}
pub(super) fn set_nonblocking(&self, nonblocking: bool) {
match &mut *self.inner() {
Status::Unconnected(ref mut info) => info.set_nonblocking(nonblocking),
Status::Connected(ref mut endpoint) => endpoint.set_nonblocking(nonblocking),
Status::Listening(addr) => ADDRESS_SPACE
.get_listener_ref(&addr)
.unwrap()
.set_nonblocking(nonblocking),
}
}
pub(super) fn inner(&self) -> SgxMutexGuard<'_, Status> {
self.inner.lock().unwrap()
}
}
impl Debug for Stream {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("Stream")
.field("addr", &self.addr())
.field("nonblocking", &self.nonblocking())
.finish()
}
}
impl Drop for Stream {
fn drop(&mut self) {
match &*self.inner() {
Status::Unconnected(info) => {
if let Some(addr) = info.addr() {
ADDRESS_SPACE.remove_addr(&addr);
}
}
Status::Listening(addr) => {
let listener = ADDRESS_SPACE.get_listener_ref(&addr).unwrap();
ADDRESS_SPACE.remove_addr(&addr);
/// handle the blocking of other sockets holding the reference to the listener,
/// e.g., pushing to a listener full of incoming sockets
listener.shutdown();
}
_ => {}
}
}
}
pub enum Status {
Unconnected(Info),
/// The listeners are stored in a global data structure indexed by the address.
/// The consitency of Status with that data structure should be carefully maintained.
Listening(Addr),
Connected(Endpoint),
}
#[derive(Debug, Clone)]
pub struct Info {
addr: Option<Addr>,
nonblocking: bool,
}
impl Info {
pub fn new(nonblocking: bool) -> Self {
Self {
addr: None,
nonblocking: nonblocking,
}
}
pub fn addr(&self) -> &Option<Addr> {
&self.addr
}
pub fn set_addr(&mut self, addr: &Addr) {
self.addr = Some(addr.clone());
}
pub fn nonblocking(&self) -> bool {
self.nonblocking
}
pub fn set_nonblocking(&mut self, nonblocking: bool) {
self.nonblocking = nonblocking;
}
}
pub struct Listener {
channel: Channel<Endpoint>,
nonblocking: AtomicBool,
}
impl Listener {
pub fn new(capacity: usize) -> Result<Self> {
let channel = Channel::new(capacity)?;
// It may incur blocking inside a blocking if the channel is blocking. Set the channel to
// nonblocking permanently to avoid the nested blocking. This also results in nonblocking
// accept and connect. Future work is needed to resolve this blocking issue to support
// blocking accept and connect.
channel.set_nonblocking(true);
/// The listener is blocking by default
let nonblocking = AtomicBool::new(true);
Ok(Self {
channel,
nonblocking,
})
}
pub fn push_incoming(&self, stream_socket: Endpoint) {
self.channel.push(stream_socket);
}
pub fn pop_incoming(&self) -> Option<Endpoint> {
self.channel.pop().ok().flatten()
}
pub fn remaining(&self) -> usize {
self.channel.items_to_consume()
}
pub fn nonblocking(&self) -> bool {
warn!("the channel works in a nonblocking way regardless of the nonblocking status");
self.nonblocking.load(Ordering::Acquire)
}
pub fn set_nonblocking(&self, nonblocking: bool) {
warn!("the channel works in a nonblocking way regardless of the nonblocking status");
self.nonblocking.store(nonblocking, Ordering::Release);
}
pub fn shutdown(&self) {
self.channel.shutdown();
}
}

@ -1,390 +0,0 @@
use super::*;
use fs::{File, FileRef, IoctlCmd};
use rcore_fs::vfs::{FileType, Metadata, Timespec};
use std::any::Any;
use std::collections::btree_map::BTreeMap;
use std::fmt;
use std::sync::atomic::{spin_loop_hint, AtomicUsize, Ordering};
use std::sync::SgxMutex as Mutex;
use util::ring_buf::{ring_buffer, RingBufReader, RingBufWriter};
pub struct UnixSocketFile {
inner: Mutex<UnixSocket>,
}
// TODO: add enqueue_event and dequeue_event
impl File for UnixSocketFile {
fn read(&self, buf: &mut [u8]) -> Result<usize> {
let mut inner = self.inner.lock().unwrap();
inner.read(buf)
}
fn write(&self, buf: &[u8]) -> Result<usize> {
let mut inner = self.inner.lock().unwrap();
inner.write(buf)
}
fn read_at(&self, _offset: usize, buf: &mut [u8]) -> Result<usize> {
self.read(buf)
}
fn write_at(&self, _offset: usize, buf: &[u8]) -> Result<usize> {
self.write(buf)
}
fn readv(&self, bufs: &mut [&mut [u8]]) -> Result<usize> {
let mut inner = self.inner.lock().unwrap();
inner.readv(bufs)
}
fn writev(&self, bufs: &[&[u8]]) -> Result<usize> {
let mut inner = self.inner.lock().unwrap();
inner.writev(bufs)
}
fn metadata(&self) -> Result<Metadata> {
Ok(Metadata {
dev: 0,
inode: 0,
size: 0,
blk_size: 0,
blocks: 0,
atime: Timespec { sec: 0, nsec: 0 },
mtime: Timespec { sec: 0, nsec: 0 },
ctime: Timespec { sec: 0, nsec: 0 },
type_: FileType::Socket,
mode: 0,
nlinks: 0,
uid: 0,
gid: 0,
rdev: 0,
})
}
fn ioctl(&self, cmd: &mut IoctlCmd) -> Result<i32> {
let mut inner = self.inner.lock().unwrap();
inner.ioctl(cmd)
}
fn poll(&self) -> Result<PollEventFlags> {
let mut inner = self.inner.lock().unwrap();
inner.poll()
}
fn as_any(&self) -> &dyn Any {
self
}
}
static SOCKETPAIR_NUM: AtomicUsize = AtomicUsize::new(0);
const SOCK_PATH_PREFIX: &str = "socketpair_";
impl UnixSocketFile {
pub fn new(socket_type: c_int, protocol: c_int) -> Result<Self> {
let inner = UnixSocket::new(socket_type, protocol)?;
Ok(UnixSocketFile {
inner: Mutex::new(inner),
})
}
pub fn bind(&self, path: impl AsRef<str>) -> Result<()> {
let mut inner = self.inner.lock().unwrap();
inner.bind(path)
}
pub fn listen(&self) -> Result<()> {
let mut inner = self.inner.lock().unwrap();
inner.listen()
}
pub fn accept(&self) -> Result<UnixSocketFile> {
let mut inner = self.inner.lock().unwrap();
let new_socket = inner.accept()?;
Ok(UnixSocketFile {
inner: Mutex::new(new_socket),
})
}
pub fn connect(&self, path: impl AsRef<str>) -> Result<()> {
let mut inner = self.inner.lock().unwrap();
inner.connect(path)
}
pub fn socketpair(socket_type: i32, protocol: i32) -> Result<(Self, Self)> {
let listen_socket = Self::new(socket_type, protocol)?;
let bound_path = listen_socket.bind_until_success();
listen_socket.listen()?;
let client_socket = Self::new(socket_type, protocol)?;
client_socket.connect(&bound_path)?;
let accepted_socket = listen_socket.accept()?;
Ok((client_socket, accepted_socket))
}
fn bind_until_success(&self) -> String {
loop {
let sock_path_suffix = SOCKETPAIR_NUM.fetch_add(1, Ordering::SeqCst);
let sock_path = format!("{}{}", SOCK_PATH_PREFIX, sock_path_suffix);
if self.bind(&sock_path).is_ok() {
return sock_path;
}
}
}
pub fn is_connected(&self) -> bool {
if let Status::Connected(_) = self.inner.lock().unwrap().status {
true
} else {
false
}
}
}
impl Debug for UnixSocketFile {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "UnixSocketFile {{ ... }}")
}
}
pub trait AsUnixSocket {
fn as_unix_socket(&self) -> Result<&UnixSocketFile>;
}
impl AsUnixSocket for FileRef {
fn as_unix_socket(&self) -> Result<&UnixSocketFile> {
self.as_any()
.downcast_ref::<UnixSocketFile>()
.ok_or_else(|| errno!(EBADF, "not a unix socket"))
}
}
pub struct UnixSocket {
obj: Option<Arc<UnixSocketObject>>,
status: Status,
}
enum Status {
None,
Listening,
Connected(Channel),
}
impl UnixSocket {
/// C/S 1: Create a new unix socket
pub fn new(socket_type: c_int, protocol: c_int) -> Result<Self> {
if socket_type == libc::SOCK_STREAM && (protocol == 0 || protocol == libc::PF_UNIX) {
Ok(UnixSocket {
obj: None,
status: Status::None,
})
} else {
// Return different error numbers according to input
return_errno!(ENOSYS, "unimplemented unix socket type")
}
}
/// Server 2: Bind the socket to a file system path
pub fn bind(&mut self, path: impl AsRef<str>) -> Result<()> {
// TODO: check permission
if self.obj.is_some() {
return_errno!(EINVAL, "The socket is already bound to an address.");
}
self.obj = Some(UnixSocketObject::create(path)?);
Ok(())
}
/// Server 3: Listen to a socket
pub fn listen(&mut self) -> Result<()> {
self.status = Status::Listening;
Ok(())
}
/// Server 4: Accept a connection on listening.
pub fn accept(&mut self) -> Result<UnixSocket> {
match self.status {
Status::Listening => {}
_ => return_errno!(EINVAL, "unix socket is not listening"),
};
// FIXME: Block. Now spin loop.
let socket = loop {
if let Some(socket) = self.obj.as_mut().unwrap().pop() {
break socket;
}
spin_loop_hint();
};
Ok(socket)
}
/// Client 2: Connect to a path
pub fn connect(&mut self, path: impl AsRef<str>) -> Result<()> {
if let Status::Listening = self.status {
return_errno!(EINVAL, "unix socket is listening?");
}
let obj = UnixSocketObject::get(path)
.ok_or_else(|| errno!(EINVAL, "unix socket path not found"))?;
// TODO: Mov the buffer allocation to function new to comply with the bahavior of unix
let (channel1, channel2) = Channel::new_pair()?;
self.status = Status::Connected(channel1);
obj.push(UnixSocket {
obj: Some(obj.clone()),
status: Status::Connected(channel2),
});
Ok(())
}
pub fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
self.channel_mut()?.reader.read_from_buffer(buf)
}
pub fn readv(&mut self, bufs: &mut [&mut [u8]]) -> Result<usize> {
self.channel_mut()?.reader.read_from_vector(bufs)
}
pub fn write(&mut self, buf: &[u8]) -> Result<usize> {
self.channel_mut()?.writer.write_to_buffer(buf)
}
pub fn writev(&mut self, bufs: &[&[u8]]) -> Result<usize> {
self.channel_mut()?.writer.write_to_vector(bufs)
}
fn poll(&self) -> Result<PollEventFlags> {
let channel_result = self.channel();
if let Ok(channel) = channel_result {
let readable = channel.reader.can_read() && !channel.reader.is_peer_closed();
let writable = channel.writer.can_write() && !channel.writer.is_peer_closed();
let events = if readable ^ writable {
if channel.reader.can_read() {
PollEventFlags::POLLRDHUP | PollEventFlags::POLLIN | PollEventFlags::POLLRDNORM
} else {
PollEventFlags::POLLRDHUP
}
// both readable and writable
} else if readable {
PollEventFlags::POLLIN
| PollEventFlags::POLLOUT
| PollEventFlags::POLLRDNORM
| PollEventFlags::POLLWRNORM
} else {
PollEventFlags::POLLHUP
};
Ok(events)
} else {
// For the unconnected socket
// TODO: add write support for unconnected sockets like linux does
Ok(PollEventFlags::POLLHUP)
}
}
pub fn ioctl(&self, cmd: &mut IoctlCmd) -> Result<i32> {
match cmd {
IoctlCmd::FIONREAD(arg) => {
let bytes_to_read = self
.channel()?
.reader
.bytes_to_read()
.min(std::i32::MAX as usize) as i32;
**arg = bytes_to_read;
}
_ => return_errno!(EINVAL, "unknown ioctl cmd for unix socket"),
}
Ok(0)
}
fn channel_mut(&mut self) -> Result<&mut Channel> {
if let Status::Connected(ref mut channel) = &mut self.status {
Ok(channel)
} else {
return_errno!(EBADF, "UnixSocket is not connected")
}
}
fn channel(&self) -> Result<&Channel> {
if let Status::Connected(channel) = &self.status {
Ok(channel)
} else {
return_errno!(EBADF, "UnixSocket is not connected")
}
}
}
impl Drop for UnixSocket {
fn drop(&mut self) {
if let Status::Listening = self.status {
// Only remove the object when there is one
if let Some(obj) = self.obj.as_ref() {
UnixSocketObject::remove(&obj.path);
}
}
}
}
pub struct UnixSocketObject {
path: String,
accepted_sockets: Mutex<VecDeque<UnixSocket>>,
}
impl UnixSocketObject {
fn push(&self, unix_socket: UnixSocket) {
let mut queue = self.accepted_sockets.lock().unwrap();
queue.push_back(unix_socket);
}
fn pop(&self) -> Option<UnixSocket> {
let mut queue = self.accepted_sockets.lock().unwrap();
queue.pop_front()
}
fn get(path: impl AsRef<str>) -> Option<Arc<Self>> {
let mut paths = UNIX_SOCKET_OBJS.lock().unwrap();
paths.get(path.as_ref()).map(|obj| obj.clone())
}
fn create(path: impl AsRef<str>) -> Result<Arc<Self>> {
let mut paths = UNIX_SOCKET_OBJS.lock().unwrap();
if paths.contains_key(path.as_ref()) {
return_errno!(EADDRINUSE, "unix socket path already exists");
}
let obj = Arc::new(UnixSocketObject {
path: path.as_ref().to_string(),
accepted_sockets: Mutex::new(VecDeque::new()),
});
paths.insert(path.as_ref().to_string(), obj.clone());
Ok(obj)
}
fn remove(path: impl AsRef<str>) {
let mut paths = UNIX_SOCKET_OBJS.lock().unwrap();
paths.remove(path.as_ref());
}
}
struct Channel {
reader: RingBufReader,
writer: RingBufWriter,
}
unsafe impl Send for Channel {}
unsafe impl Sync for Channel {}
impl Channel {
fn new_pair() -> Result<(Channel, Channel)> {
let (reader1, writer1) = ring_buffer(DEFAULT_BUF_SIZE)?;
let (reader2, writer2) = ring_buffer(DEFAULT_BUF_SIZE)?;
let channel1 = Channel {
reader: reader1,
writer: writer2,
};
let channel2 = Channel {
reader: reader2,
writer: writer1,
};
Ok((channel1, channel2))
}
}
// TODO: Add SO_SNDBUF and SO_RCVBUF to set/getsockopt to dynamcally change the size.
// This value is got from /proc/sys/net/core/rmem_max and wmem_max that are same on linux.
pub const DEFAULT_BUF_SIZE: usize = 208 * 1024;
lazy_static! {
static ref UNIX_SOCKET_OBJS: Mutex<BTreeMap<String, Arc<UnixSocketObject>>> =
Mutex::new(BTreeMap::new());
}

@ -18,7 +18,7 @@ pub fn do_socket(domain: c_int, socket_type: c_int, protocol: c_int) -> Result<i
let file_ref: Arc<dyn File> = match sock_domain { let file_ref: Arc<dyn File> = match sock_domain {
AddressFamily::LOCAL => { AddressFamily::LOCAL => {
let unix_socket = UnixSocketFile::new(socket_type, protocol)?; let unix_socket = unix_socket(sock_type, file_flags, protocol)?;
Arc::new(unix_socket) Arc::new(unix_socket)
} }
_ => { _ => {
@ -38,19 +38,15 @@ pub fn do_bind(fd: c_int, addr: *const libc::sockaddr, addr_len: libc::socklen_t
} }
from_user::check_array(addr as *const u8, addr_len as usize)?; from_user::check_array(addr as *const u8, addr_len as usize)?;
let sock_addr = unsafe { SockAddr::try_from_raw(addr, addr_len)? };
trace!("bind to addr: {:?}", sock_addr);
let file_ref = current!().file(fd as FileDesc)?; let file_ref = current!().file(fd as FileDesc)?;
if let Ok(socket) = file_ref.as_host_socket() { if let Ok(socket) = file_ref.as_host_socket() {
let sock_addr = unsafe { SockAddr::try_from_raw(addr, addr_len)? };
trace!("bind to addr: {:?}", sock_addr);
socket.bind(&sock_addr)?; socket.bind(&sock_addr)?;
} else if let Ok(unix_socket) = file_ref.as_unix_socket() { } else if let Ok(unix_socket) = file_ref.as_unix_socket() {
let addr = addr as *const libc::sockaddr_un; let unix_addr = unsafe { UnixAddr::try_from_raw(addr, addr_len)? };
from_user::check_ptr(addr)?; trace!("bind to addr: {:?}", unix_addr);
let path = from_user::clone_cstring_safely(unsafe { (&*addr).sun_path.as_ptr() })? unix_socket.bind(&unix_addr)?;
.to_string_lossy()
.into_owned();
unix_socket.bind(path)?;
} else { } else {
return_errno!(EBADF, "not a socket"); return_errno!(EBADF, "not a socket");
} }
@ -63,7 +59,7 @@ pub fn do_listen(fd: c_int, backlog: c_int) -> Result<isize> {
if let Ok(socket) = file_ref.as_host_socket() { if let Ok(socket) = file_ref.as_host_socket() {
socket.listen(backlog)?; socket.listen(backlog)?;
} else if let Ok(unix_socket) = file_ref.as_unix_socket() { } else if let Ok(unix_socket) = file_ref.as_unix_socket() {
unix_socket.listen()?; unix_socket.listen(backlog)?;
} else { } else {
return_errno!(EBADF, "not a socket"); return_errno!(EBADF, "not a socket");
} }
@ -84,24 +80,26 @@ pub fn do_connect(
from_user::check_array(addr as *const u8, addr_len as usize)?; from_user::check_array(addr as *const u8, addr_len as usize)?;
} }
let file_ref = current!().file(fd as FileDesc)?;
if let Ok(socket) = file_ref.as_host_socket() {
let addr_option = if addr_set { let addr_option = if addr_set {
Some(unsafe { SockAddr::try_from_raw(addr, addr_len)? }) Some(unsafe { SockAddr::try_from_raw(addr, addr_len)? })
} else { } else {
None None
}; };
let file_ref = current!().file(fd as FileDesc)?;
if let Ok(socket) = file_ref.as_host_socket() {
socket.connect(&addr_option)?; socket.connect(&addr_option)?;
} else if let Ok(unix_socket) = file_ref.as_unix_socket() { } else if let Ok(unix_socket) = file_ref.as_unix_socket() {
let addr = addr as *const libc::sockaddr_un; // TODO: support AF_UNSPEC address for datagram socket use
from_user::check_ptr(addr)?; let addr = if addr_set {
let path = from_user::clone_cstring_safely(unsafe { (&*addr).sun_path.as_ptr() })? unsafe { UnixAddr::try_from_raw(addr, addr_len)? }
.to_string_lossy()
.into_owned();
unix_socket.connect(path)?;
} else { } else {
return_errno!(EBADF, "not a socket") return_errno!(EINVAL, "invalid address");
};
unix_socket.connect(&addr)?;
} else {
return_errno!(EBADF, "not a socket");
} }
Ok(0) Ok(0)
@ -131,47 +129,65 @@ pub fn do_accept4(
let close_on_spawn = file_flags.contains(FileFlags::SOCK_CLOEXEC); let close_on_spawn = file_flags.contains(FileFlags::SOCK_CLOEXEC);
let file_ref = current!().file(fd as FileDesc)?; let file_ref = current!().file(fd as FileDesc)?;
let new_fd = if let Ok(socket) = file_ref.as_host_socket() { if let Ok(socket) = file_ref.as_host_socket() {
let (new_socket_file, sock_addr_option) = socket.accept(file_flags)?; let (new_socket_file, sock_addr_option) = socket.accept(file_flags)?;
let new_file_ref: Arc<dyn File> = Arc::new(new_socket_file); let new_file_ref: Arc<dyn File> = Arc::new(new_socket_file);
let new_fd = current!().add_file(new_file_ref, close_on_spawn); let new_fd = current!().add_file(new_file_ref, close_on_spawn);
if addr_set && sock_addr_option.is_some() { if addr_set {
let sock_addr = sock_addr_option.unwrap(); if let Some(sock_addr) = sock_addr_option {
let mut buf = let mut buf =
unsafe { std::slice::from_raw_parts_mut(addr as *mut u8, *addr_len as usize) }; unsafe { std::slice::from_raw_parts_mut(addr as *mut u8, *addr_len as usize) };
sock_addr.copy_to_slice(&mut buf); sock_addr.copy_to_slice(&mut buf);
unsafe { unsafe {
*addr_len = sock_addr.len() as u32; *addr_len = sock_addr.len() as u32;
} }
} else {
unsafe {
*addr_len = 0;
} }
new_fd }
}
Ok(new_fd as isize)
} else if let Ok(unix_socket) = file_ref.as_unix_socket() { } else if let Ok(unix_socket) = file_ref.as_unix_socket() {
let addr = addr as *mut libc::sockaddr_un; let (new_socket_file, sock_addr_option) = unix_socket.accept(file_flags)?;
let new_file_ref: Arc<dyn File> = Arc::new(new_socket_file);
let new_fd = current!().add_file(new_file_ref, close_on_spawn);
if addr_set { if addr_set {
from_user::check_mut_ptr(addr)?; if let Some(sock_addr) = sock_addr_option {
let mut buf =
unsafe { std::slice::from_raw_parts_mut(addr as *mut u8, *addr_len as usize) };
sock_addr.copy_to_slice(&mut buf);
unsafe {
*addr_len = sock_addr.raw_len() as u32;
} }
// TODO: handle addr } else {
let new_socket = unix_socket.accept()?; unsafe {
let new_file_ref: Arc<dyn File> = Arc::new(new_socket); *addr_len = 0;
current!().add_file(new_file_ref, false) }
}
}
Ok(new_fd as isize)
} else { } else {
return_errno!(EBADF, "not a socket"); return_errno!(EBADF, "not a socket");
}; }
Ok(new_fd as isize)
} }
pub fn do_shutdown(fd: c_int, how: c_int) -> Result<isize> { pub fn do_shutdown(fd: c_int, how: c_int) -> Result<isize> {
debug!("shutdown: fd: {}, how: {}", fd, how); debug!("shutdown: fd: {}, how: {}", fd, how);
let how = HowToShut::try_from_raw(how)?;
let file_ref = current!().file(fd as FileDesc)?; let file_ref = current!().file(fd as FileDesc)?;
if let Ok(socket) = file_ref.as_host_socket() { if let Ok(socket) = file_ref.as_host_socket() {
let ret = try_libc!(libc::ocall::shutdown(socket.raw_host_fd() as i32, how)); socket.shutdown(how)?;
Ok(ret as isize) } else if let Ok(unix_socket) = file_ref.as_unix_socket() {
unix_socket.shutdown(how)?;
} else { } else {
// TODO: support unix socket return_errno!(EBADF, "not a host socket")
return_errno!(EBADF, "not a socket")
} }
Ok(0)
} }
pub fn do_setsockopt( pub fn do_setsockopt(
@ -232,10 +248,14 @@ pub fn do_getpeername(
addr: *mut libc::sockaddr, addr: *mut libc::sockaddr,
addr_len: *mut libc::socklen_t, addr_len: *mut libc::socklen_t,
) -> Result<isize> { ) -> Result<isize> {
debug!( let addr_set: bool = !addr.is_null();
"getpeername: fd: {}, addr: {:?}, addr_len: {:?}", if addr_set {
fd, addr, addr_len from_user::check_ptr(addr_len)?;
); from_user::check_mut_array(addr as *mut u8, unsafe { *addr_len } as usize)?;
} else {
return Ok(0);
}
let file_ref = current!().file(fd as FileDesc)?; let file_ref = current!().file(fd as FileDesc)?;
if let Ok(socket) = file_ref.as_host_socket() { if let Ok(socket) = file_ref.as_host_socket() {
let ret = try_libc!(libc::ocall::getpeername( let ret = try_libc!(libc::ocall::getpeername(
@ -245,11 +265,15 @@ pub fn do_getpeername(
)); ));
Ok(ret as isize) Ok(ret as isize)
} else if let Ok(unix_socket) = file_ref.as_unix_socket() { } else if let Ok(unix_socket) = file_ref.as_unix_socket() {
warn!("getpeername for unix socket is unimplemented"); let name = unix_socket.peer_addr()?;
return_errno!( let mut dst = unsafe {
ENOTCONN, std::slice::from_raw_parts_mut(addr as *mut _ as *mut u8, *addr_len as usize)
"hack for php: Transport endpoint is not connected" };
) name.copy_to_slice(dst);
unsafe {
*addr_len = name.raw_len() as u32;
}
Ok(0)
} else { } else {
return_errno!(EBADF, "not a socket") return_errno!(EBADF, "not a socket")
} }
@ -260,10 +284,18 @@ pub fn do_getsockname(
addr: *mut libc::sockaddr, addr: *mut libc::sockaddr,
addr_len: *mut libc::socklen_t, addr_len: *mut libc::socklen_t,
) -> Result<isize> { ) -> Result<isize> {
debug!( let addr_set: bool = !addr.is_null();
"getsockname: fd: {}, addr: {:?}, addr_len: {:?}", if addr_set {
fd, addr, addr_len from_user::check_ptr(addr_len)?;
); from_user::check_mut_array(addr as *mut u8, unsafe { *addr_len } as usize)?;
} else {
return Ok(0);
}
if unsafe { *addr_len } < std::mem::size_of::<libc::sa_family_t>() as u32 {
return_errno!(EINVAL, "input length is too short");
}
let file_ref = current!().file(fd as FileDesc)?; let file_ref = current!().file(fd as FileDesc)?;
if let Ok(socket) = file_ref.as_host_socket() { if let Ok(socket) = file_ref.as_host_socket() {
let ret = try_libc!(libc::ocall::getsockname( let ret = try_libc!(libc::ocall::getsockname(
@ -273,10 +305,24 @@ pub fn do_getsockname(
)); ));
Ok(ret as isize) Ok(ret as isize)
} else if let Ok(unix_socket) = file_ref.as_unix_socket() { } else if let Ok(unix_socket) = file_ref.as_unix_socket() {
warn!("getsockname for unix socket is unimplemented"); let name_opt = unix_socket.addr();
if let Some(name) = name_opt {
let mut dst = unsafe {
std::slice::from_raw_parts_mut(addr as *mut _ as *mut u8, *addr_len as usize)
};
name.copy_to_slice(dst);
unsafe {
*addr_len = name.raw_len() as u32;
}
} else {
unsafe {
(*addr).sa_family = AddressFamily::LOCAL as u16;
*addr_len = 2;
}
}
Ok(0) Ok(0)
} else { } else {
return_errno!(EBADF, "not a socket") return_errno!(EBADF, "not a socket");
} }
} }
@ -306,24 +352,27 @@ pub fn do_sendto(
let send_flags = SendFlags::from_bits(flags).unwrap(); let send_flags = SendFlags::from_bits(flags).unwrap();
let file_ref = current!().file(fd as FileDesc)?;
if let Ok(socket) = file_ref.as_host_socket() {
let addr_option = if addr_set { let addr_option = if addr_set {
Some(unsafe { SockAddr::try_from_raw(addr, addr_len)? }) Some(unsafe { SockAddr::try_from_raw(addr, addr_len)? })
} else { } else {
None None
}; };
let file_ref = current!().file(fd as FileDesc)?;
if let Ok(socket) = file_ref.as_host_socket() {
socket socket
.sendto(buf, send_flags, &addr_option) .sendto(buf, send_flags, &addr_option)
.map(|u| u as isize) .map(|u| u as isize)
} else if let Ok(unix) = file_ref.as_unix_socket() { } else if let Ok(unix_socket) = file_ref.as_unix_socket() {
if !unix.is_connected() { let addr_option = if addr_set {
return_errno!(ENOTCONN, "the socket has not been connected yet"); Some(unsafe { UnixAddr::try_from_raw(addr, addr_len)? })
} } else {
None
};
let data = unsafe { std::slice::from_raw_parts(base as *const u8, len) }; unix_socket
unix.write(data).map(|u| u as isize) .sendto(buf, send_flags, &addr_option)
.map(|u| u as isize)
} else { } else {
return_errno!(EBADF, "unsupported file type"); return_errno!(EBADF, "unsupported file type");
} }
@ -356,23 +405,43 @@ pub fn do_recvfrom(
} }
let file_ref = current!().file(fd as FileDesc)?; let file_ref = current!().file(fd as FileDesc)?;
let (data_len, sock_addr_option) = if let Ok(socket) = file_ref.as_host_socket() { if let Ok(socket) = file_ref.as_host_socket() {
socket.recvfrom(buf, recv_flags)? let (data_len, sock_addr_option) = socket.recvfrom(buf, recv_flags)?;
} else { if addr_set {
return_errno!(EBADF, "not a socket"); if let Some(sock_addr) = sock_addr_option {
};
if addr_set && sock_addr_option.is_some() {
let sock_addr = sock_addr_option.unwrap();
let mut buf = let mut buf =
unsafe { std::slice::from_raw_parts_mut(addr as *mut u8, *addr_len as usize) }; unsafe { std::slice::from_raw_parts_mut(addr as *mut u8, *addr_len as usize) };
sock_addr.copy_to_slice(&mut buf); sock_addr.copy_to_slice(&mut buf);
unsafe { unsafe {
*addr_len = sock_addr.len() as u32; *addr_len = sock_addr.len() as u32;
} }
} else {
unsafe {
*addr_len = 0;
}
}
} }
Ok(data_len as isize) Ok(data_len as isize)
} else if let Ok(unix_socket) = file_ref.as_unix_socket() {
let (data_len, sock_addr_option) = unix_socket.recvfrom(buf, recv_flags)?;
if addr_set {
if let Some(sock_addr) = sock_addr_option {
let mut buf =
unsafe { std::slice::from_raw_parts_mut(addr as *mut u8, *addr_len as usize) };
sock_addr.copy_to_slice(&mut buf);
unsafe {
*addr_len = sock_addr.raw_len() as u32;
}
} else {
unsafe {
*addr_len = 0;
}
}
}
Ok(data_len as isize)
} else {
return_errno!(EBADF, "not a socket");
}
} }
pub fn do_socketpair( pub fn do_socketpair(
@ -388,11 +457,11 @@ pub fn do_socketpair(
let file_flags = FileFlags::from_bits_truncate(socket_type); let file_flags = FileFlags::from_bits_truncate(socket_type);
let close_on_spawn = file_flags.contains(FileFlags::SOCK_CLOEXEC); let close_on_spawn = file_flags.contains(FileFlags::SOCK_CLOEXEC);
let sock_type = SocketType::try_from(socket_type & (!file_flags.bits()))?;
let domain = AddressFamily::try_from(domain as u16)?; let domain = AddressFamily::try_from(domain as u16)?;
if (domain == AddressFamily::LOCAL) { if (domain == AddressFamily::LOCAL) {
let (client_socket, server_socket) = let (client_socket, server_socket) = socketpair(sock_type, file_flags, protocol as i32)?;
UnixSocketFile::socketpair(socket_type as i32, protocol as i32)?;
let current = current!(); let current = current!();
let mut files = current.files().lock().unwrap(); let mut files = current.files().lock().unwrap();

@ -4,6 +4,5 @@ pub mod dirty;
pub mod log; pub mod log;
pub mod mem_util; pub mod mem_util;
pub mod mpx_util; pub mod mpx_util;
pub mod ring_buf;
pub mod sgx; pub mod sgx;
pub mod sync; pub mod sync;

@ -1,428 +0,0 @@
use alloc::alloc::{alloc, dealloc, Layout};
use crate::net::{
clear_notifier_status, notify_thread, wait_for_notification, IoEvent, PollEventFlags,
};
use std::cmp::{max, min};
use std::ptr;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
use super::*;
use ringbuf::{Consumer, Producer, RingBuffer};
pub fn ring_buffer(capacity: usize) -> Result<(RingBufReader, RingBufWriter)> {
let meta = RingBufMeta::new();
let buffer = RingBuffer::<u8>::new(capacity);
let (producer, consumer) = buffer.split();
let meta_ref = Arc::new(meta);
let reader = RingBufReader {
inner: consumer,
buffer: meta_ref.clone(),
};
let writer = RingBufWriter {
inner: producer,
buffer: meta_ref,
};
Ok((reader, writer))
}
struct RingBufMeta {
lock: Arc<SgxMutex<bool>>, // lock for the synchronization of reader and writer
reader_closed: AtomicBool, // if reader has been dropped
writer_closed: AtomicBool, // if writer has been dropped
reader_wait_queue: SgxMutex<HashMap<pid_t, IoEvent>>,
writer_wait_queue: SgxMutex<HashMap<pid_t, IoEvent>>,
// TODO: support O_ASYNC and O_DIRECT in ringbuffer
blocking_read: AtomicBool, // if the read is blocking
blocking_write: AtomicBool, // if the write is blocking
}
impl RingBufMeta {
pub fn new() -> RingBufMeta {
Self {
lock: Arc::new(SgxMutex::new(true)),
reader_closed: AtomicBool::new(false),
writer_closed: AtomicBool::new(false),
reader_wait_queue: SgxMutex::new(HashMap::new()),
writer_wait_queue: SgxMutex::new(HashMap::new()),
blocking_read: AtomicBool::new(true),
blocking_write: AtomicBool::new(true),
}
}
pub fn is_reader_closed(&self) -> bool {
self.reader_closed.load(Ordering::SeqCst)
}
pub fn close_reader(&self) {
self.reader_closed.store(true, Ordering::SeqCst);
}
pub fn is_writer_closed(&self) -> bool {
self.writer_closed.load(Ordering::SeqCst)
}
pub fn close_writer(&self) {
self.writer_closed.store(true, Ordering::SeqCst);
}
pub fn reader_wait_queue(&self) -> &SgxMutex<HashMap<pid_t, IoEvent>> {
&self.reader_wait_queue
}
pub fn writer_wait_queue(&self) -> &SgxMutex<HashMap<pid_t, IoEvent>> {
&self.writer_wait_queue
}
pub fn enqueue_reader_event(&self, event: IoEvent) -> Result<()> {
self.reader_wait_queue
.lock()
.unwrap()
.insert(current!().tid(), event);
Ok(())
}
pub fn dequeue_reader_event(&self) -> Result<()> {
self.reader_wait_queue
.lock()
.unwrap()
.remove(&current!().tid())
.unwrap();
Ok(())
}
pub fn enqueue_writer_event(&self, event: IoEvent) -> Result<()> {
self.writer_wait_queue
.lock()
.unwrap()
.insert(current!().tid(), event);
Ok(())
}
pub fn dequeue_writer_event(&self) -> Result<()> {
self.writer_wait_queue
.lock()
.unwrap()
.remove(&current!().tid())
.unwrap();
Ok(())
}
pub fn blocking_read(&self) -> bool {
self.blocking_read.load(Ordering::SeqCst)
}
pub fn set_non_blocking_read(&self) {
self.blocking_read.store(false, Ordering::SeqCst);
}
pub fn set_blocking_read(&self) {
self.blocking_read.store(true, Ordering::SeqCst);
}
pub fn blocking_write(&self) -> bool {
self.blocking_write.load(Ordering::SeqCst)
}
pub fn set_non_blocking_write(&self) {
self.blocking_write.store(false, Ordering::SeqCst);
}
pub fn set_blocking_write(&self) {
self.blocking_write.store(true, Ordering::SeqCst);
}
}
pub struct RingBufReader {
inner: Consumer<u8>,
buffer: Arc<RingBufMeta>,
}
impl RingBufReader {
pub fn can_read(&self) -> bool {
self.bytes_to_read() != 0
}
pub fn read_from_buffer(&mut self, buffer: &mut [u8]) -> Result<usize> {
self.read(Some(buffer), None)
}
pub fn read_from_vector(&mut self, buffers: &mut [&mut [u8]]) -> Result<usize> {
self.read(None, Some(buffers))
}
fn read(
&mut self,
buffer: Option<&mut [u8]>,
buffers: Option<&mut [&mut [u8]]>,
) -> Result<usize> {
assert!(buffer.is_some() ^ buffers.is_some());
// In case of write after can_read is false
let lock_ref = self.buffer.lock.clone();
let lock_holder = lock_ref.lock();
if self.can_read() {
let count = if buffer.is_some() {
self.inner.pop_slice(buffer.unwrap())
} else {
self.pop_slices(buffers.unwrap())
};
assert!(count > 0);
self.read_end();
Ok(count)
} else {
if self.is_peer_closed() {
return Ok(0);
}
if !self.buffer.blocking_read() {
return_errno!(EAGAIN, "No data to read");
} else {
// Clear the status of notifier before enqueue
clear_notifier_status(current!().tid())?;
self.enqueue_event(IoEvent::BlockingRead)?;
drop(lock_holder);
drop(lock_ref);
let ret = wait_for_notification();
self.dequeue_event()?;
ret?;
let lock_ref = self.buffer.lock.clone();
let lock_holder = lock_ref.lock();
let count = if buffer.is_some() {
self.inner.pop_slice(buffer.unwrap())
} else {
self.pop_slices(buffers.unwrap())
};
if count > 0 {
self.read_end()?;
} else {
assert!(self.is_peer_closed());
}
Ok(count)
}
}
}
fn pop_slices(&mut self, buffers: &mut [&mut [u8]]) -> usize {
let mut total = 0;
for buf in buffers {
let count = self.inner.pop_slice(buf);
total += count;
if count < buf.len() {
break;
}
}
total
}
pub fn bytes_to_read(&self) -> usize {
self.inner.len()
}
fn read_end(&self) -> Result<()> {
for (tid, event) in &*self.buffer.writer_wait_queue().lock().unwrap() {
match event {
IoEvent::Poll(poll_events) => {
if !(poll_events.events()
& (PollEventFlags::POLLOUT | PollEventFlags::POLLWRNORM))
.is_empty()
{
notify_thread(*tid)?;
}
}
IoEvent::Epoll(epoll_file) => unimplemented!(),
IoEvent::BlockingRead => unreachable!(),
IoEvent::BlockingWrite => notify_thread(*tid)?,
}
}
Ok(())
}
pub fn is_peer_closed(&self) -> bool {
self.buffer.is_writer_closed()
}
pub fn enqueue_event(&self, event: IoEvent) -> Result<()> {
self.buffer.enqueue_reader_event(event)
}
pub fn dequeue_event(&self) -> Result<()> {
self.buffer.dequeue_reader_event()
}
pub fn set_non_blocking(&self) {
self.buffer.set_non_blocking_read()
}
pub fn set_blocking(&self) {
self.buffer.set_blocking_read()
}
fn before_drop(&self) {
for (tid, event) in &*self.buffer.writer_wait_queue().lock().unwrap() {
match event {
IoEvent::Poll(_) | IoEvent::BlockingWrite => notify_thread(*tid).unwrap(),
IoEvent::Epoll(epoll_file) => unimplemented!(),
IoEvent::BlockingRead => unreachable!(),
}
}
}
}
impl Drop for RingBufReader {
fn drop(&mut self) {
debug!("reader drop");
self.buffer.close_reader();
if self.buffer.blocking_write() {
self.before_drop();
}
}
}
pub struct RingBufWriter {
inner: Producer<u8>,
buffer: Arc<RingBufMeta>,
}
impl RingBufWriter {
pub fn write_to_buffer(&mut self, buffer: &[u8]) -> Result<usize> {
self.write(Some(buffer), None)
}
pub fn write_to_vector(&mut self, buffers: &[&[u8]]) -> Result<usize> {
self.write(None, Some(buffers))
}
fn write(&mut self, buffer: Option<&[u8]>, buffers: Option<&[&[u8]]>) -> Result<usize> {
assert!(buffer.is_some() ^ buffers.is_some());
// TODO: send SIGPIPE to the caller
if self.is_peer_closed() {
return_errno!(EPIPE, "reader side is closed");
}
// In case of read after can_write is false
let lock_ref = self.buffer.lock.clone();
let lock_holder = lock_ref.lock();
if self.can_write() {
let count = if buffer.is_some() {
self.inner.push_slice(buffer.unwrap())
} else {
self.push_slices(buffers.unwrap())
};
assert!(count > 0);
self.write_end();
Ok(count)
} else {
if !self.buffer.blocking_write() {
return_errno!(EAGAIN, "No space to write");
}
// Clear the status of notifier before enqueue
clear_notifier_status(current!().tid());
self.enqueue_event(IoEvent::BlockingWrite)?;
drop(lock_holder);
drop(lock_ref);
let ret = wait_for_notification();
self.dequeue_event()?;
ret?;
let lock_ref = self.buffer.lock.clone();
let lock_holder = lock_ref.lock();
let count = if buffer.is_some() {
self.inner.push_slice(buffer.unwrap())
} else {
self.push_slices(buffers.unwrap())
};
if count > 0 {
self.write_end();
Ok(count)
} else {
return_errno!(EPIPE, "reader side is closed");
}
}
}
fn write_end(&self) -> Result<()> {
for (tid, event) in &*self.buffer.reader_wait_queue().lock().unwrap() {
match event {
IoEvent::Poll(poll_events) => {
if !(poll_events.events()
& (PollEventFlags::POLLIN | PollEventFlags::POLLRDNORM))
.is_empty()
{
notify_thread(*tid)?;
}
}
IoEvent::Epoll(epoll_file) => unimplemented!(),
IoEvent::BlockingRead => notify_thread(*tid)?,
IoEvent::BlockingWrite => unreachable!(),
}
}
Ok(())
}
fn push_slices(&mut self, buffers: &[&[u8]]) -> usize {
let mut total = 0;
for buf in buffers {
let count = self.inner.push_slice(buf);
total += count;
if count < buf.len() {
break;
}
}
total
}
pub fn can_write(&self) -> bool {
!self.inner.is_full()
}
pub fn is_peer_closed(&self) -> bool {
self.buffer.is_reader_closed()
}
pub fn enqueue_event(&self, event: IoEvent) -> Result<()> {
self.buffer.enqueue_writer_event(event)
}
pub fn dequeue_event(&self) -> Result<()> {
self.buffer.dequeue_writer_event()
}
pub fn set_non_blocking(&self) {
self.buffer.set_non_blocking_write()
}
pub fn set_blocking(&self) {
self.buffer.set_blocking_write()
}
fn before_drop(&self) {
for (tid, event) in &*self.buffer.reader_wait_queue().lock().unwrap() {
match event {
IoEvent::Poll(_) | IoEvent::BlockingRead => {
notify_thread(*tid).unwrap();
}
IoEvent::Epoll(epoll_file) => unimplemented!(),
IoEvent::BlockingWrite => unreachable!(),
}
}
}
}
impl Drop for RingBufWriter {
fn drop(&mut self) {
debug!("writer drop");
self.buffer.close_writer();
if self.buffer.blocking_read() {
self.before_drop();
}
}
}

@ -203,6 +203,40 @@ int test_poll() {
return 0; return 0;
} }
int test_getname() {
char name[] = "unix_socket_path";
int sock = socket(AF_UNIX, SOCK_STREAM, 0);
if (sock == -1) {
THROW_ERROR("failed to create a unix socket");
}
struct sockaddr_un addr = {0};
memset(&addr, 0, sizeof(struct sockaddr_un)); //Clear structure
addr.sun_family = AF_UNIX;
strcpy(addr.sun_path, name);
socklen_t addr_len = strlen(addr.sun_path) + sizeof(addr.sun_family) + 1;
if (bind(sock, (struct sockaddr *)&addr, addr_len) == -1) {
close(sock);
THROW_ERROR("failed to bind");
}
struct sockaddr_un ret_addr = {0};
socklen_t ret_addr_len = sizeof(ret_addr);
if (getsockname(sock, (struct sockaddr *)&ret_addr, &ret_addr_len) < 0) {
close(sock);
THROW_ERROR("failed to getsockname");
}
if (ret_addr_len != addr_len || strcmp(ret_addr.sun_path, name) != 0) {
close(sock);
THROW_ERROR("got name mismatched");
}
close(sock);
return 0;
}
static test_case_t test_cases[] = { static test_case_t test_cases[] = {
TEST_CASE(test_unix_socket_inter_process), TEST_CASE(test_unix_socket_inter_process),
TEST_CASE(test_socketpair_inter_process), TEST_CASE(test_socketpair_inter_process),
@ -210,6 +244,7 @@ static test_case_t test_cases[] = {
// TODO: recover the test after the unix sockets are rewritten by using // TODO: recover the test after the unix sockets are rewritten by using
// the new event subsystem // the new event subsystem
//TEST_CASE(test_poll), //TEST_CASE(test_poll),
TEST_CASE(test_getname),
}; };
int main(int argc, const char *argv[]) { int main(int argc, const char *argv[]) {