[Libos] Refactor io_uring and related module implementions

This commit is contained in:
ClawSeven 2024-05-24 17:18:50 +08:00 committed by volcano
parent b80450ef96
commit 070a024c0d
42 changed files with 478 additions and 518 deletions

@ -26,7 +26,11 @@ impl Synchronizer for EdgeSync {
return Ok(()); return Ok(());
} }
loop { loop {
self.host_eventfd.poll(timeout)?; if let Err(error) = self.host_eventfd.poll(timeout) {
self.state.store(INIT, Ordering::Relaxed);
return Err(error);
}
if self if self
.state .state
.compare_exchange(NOTIFIED, INIT, Ordering::Acquire, Ordering::Acquire) .compare_exchange(NOTIFIED, INIT, Ordering::Acquire, Ordering::Acquire)
@ -47,8 +51,7 @@ impl Synchronizer for EdgeSync {
// Need to change timeout from `Option<&mut Duration>` to `&mut Option<Duration>` // Need to change timeout from `Option<&mut Duration>` to `&mut Option<Duration>`
// so that the Rust compiler is happy about using the variable in a loop. // so that the Rust compiler is happy about using the variable in a loop.
let ret = self.host_eventfd.poll_mut(remain.as_mut()); let ret = self.host_eventfd.poll_mut(remain.as_mut());
// Wait for something to happen, assuming it's still set to PARKED. // Wait for something to happen, assuming it's still set to NOTIFIED.
// futex_wait(&self.state, PARKED, Some(timeout));
// This is not just a store, because we need to establish a // This is not just a store, because we need to establish a
// release-acquire ordering with unpark(). // release-acquire ordering with unpark().
if self.state.swap(INIT, Ordering::Acquire) == NOTIFIED { if self.state.swap(INIT, Ordering::Acquire) == NOTIFIED {

@ -27,7 +27,7 @@ use crate::prelude::*;
/// Although it is safer to use AcqRelhere using `Release` would be enough. /// Although it is safer to use AcqRelhere using `Release` would be enough.
pub struct WaiterQueue<Sync: Synchronizer = LevelSync> { pub struct WaiterQueue<Sync: Synchronizer = LevelSync> {
count: AtomicUsize, count: AtomicUsize,
wakers: SgxMutex<VecDeque<Waker<Sync>>>, wakers: Mutex<VecDeque<Waker<Sync>>>,
} }
impl<Sync: Synchronizer> WaiterQueue<Sync> { impl<Sync: Synchronizer> WaiterQueue<Sync> {
@ -35,7 +35,7 @@ impl<Sync: Synchronizer> WaiterQueue<Sync> {
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
count: AtomicUsize::new(0), count: AtomicUsize::new(0),
wakers: SgxMutex::new(VecDeque::new()), wakers: Mutex::new(VecDeque::new()),
} }
} }
@ -54,7 +54,7 @@ impl<Sync: Synchronizer> WaiterQueue<Sync> {
pub fn reset_and_enqueue(&self, waiter: &Waiter<Sync>) { pub fn reset_and_enqueue(&self, waiter: &Waiter<Sync>) {
waiter.reset(); waiter.reset();
let mut wakers = self.wakers.lock().unwrap(); let mut wakers = self.wakers.lock();
self.count.fetch_add(1, Ordering::Release); self.count.fetch_add(1, Ordering::Release);
wakers.push_back(waiter.waker()); wakers.push_back(waiter.waker());
} }
@ -78,7 +78,7 @@ impl<Sync: Synchronizer> WaiterQueue<Sync> {
// Dequeue wakers // Dequeue wakers
let to_wake = { let to_wake = {
let mut wakers = self.wakers.lock().unwrap(); let mut wakers = self.wakers.lock();
let max_count = max_count.min(wakers.len()); let max_count = max_count.min(wakers.len());
let to_wake: Vec<Waker<Sync>> = wakers.drain(..max_count).collect(); let to_wake: Vec<Waker<Sync>> = wakers.drain(..max_count).collect();
self.count.fetch_sub(to_wake.len(), Ordering::Release); self.count.fetch_sub(to_wake.len(), Ordering::Release);

@ -36,7 +36,8 @@ impl GetIfConf {
} }
let mut if_conf = self.to_raw_ifconf(); let mut if_conf = self.to_raw_ifconf();
get_ifconf_by_host(fd, &mut if_conf)?; GetIfConf::get_ifconf_by_host(fd, &mut if_conf)?;
self.set_len(if_conf.ifc_len as usize); self.set_len(if_conf.ifc_len as usize);
Ok(()) Ok(())
} }
@ -78,13 +79,10 @@ impl GetIfConf {
}, },
} }
} }
}
impl IoctlCmd for GetIfConf {} fn get_ifconf_by_host(fd: FileDesc, if_conf: &mut IfConf) -> Result<()> {
const SIOCGIFCONF: u32 = 0x8912;
const SIOCGIFCONF: u32 = 0x8912;
fn get_ifconf_by_host(fd: FileDesc, if_conf: &mut IfConf) -> Result<()> {
extern "C" { extern "C" {
// Used to ioctl arguments with pointer members. // Used to ioctl arguments with pointer members.
// //
@ -133,4 +131,7 @@ fn get_ifconf_by_host(fd: FileDesc, if_conf: &mut IfConf) -> Result<()> {
}); });
Ok(()) Ok(())
}
} }
impl IoctlCmd for GetIfConf {}

@ -34,13 +34,12 @@ impl GetIfReqWithRawCmd {
pub fn execute(&mut self, fd: FileDesc) -> Result<()> { pub fn execute(&mut self, fd: FileDesc) -> Result<()> {
let input_if_req = self.inner.input(); let input_if_req = self.inner.input();
let output_if_req = get_ifreq_by_host(fd, self.raw_cmd, input_if_req)?; let output_if_req = GetIfReqWithRawCmd::get_ifreq_by_host(fd, self.raw_cmd, input_if_req)?;
self.inner.set_output(output_if_req); self.inner.set_output(output_if_req);
Ok(()) Ok(())
} }
}
fn get_ifreq_by_host(fd: FileDesc, cmd: u32, req: &IfReq) -> Result<IfReq> { fn get_ifreq_by_host(fd: FileDesc, cmd: u32, req: &IfReq) -> Result<IfReq> {
let mut if_req: IfReq = req.clone(); let mut if_req: IfReq = req.clone();
try_libc!({ try_libc!({
let mut retval: i32 = 0; let mut retval: i32 = 0;
@ -68,6 +67,7 @@ fn get_ifreq_by_host(fd: FileDesc, cmd: u32, req: &IfReq) -> Result<IfReq> {
retval retval
}); });
Ok(if_req) Ok(if_req)
}
} }
impl IoctlCmd for GetIfReqWithRawCmd {} impl IoctlCmd for GetIfReqWithRawCmd {}

@ -214,7 +214,6 @@ pub use self::set_close_on_exec::*;
pub use self::set_nonblocking::SetNonBlocking; pub use self::set_nonblocking::SetNonBlocking;
pub use self::termios::*; pub use self::termios::*;
pub use self::winsize::*; pub use self::winsize::*;
pub use net::socket::sockopt::SetSockOptRawCmd;
mod get_ifconf; mod get_ifconf;
mod get_ifreq; mod get_ifreq;

@ -112,7 +112,7 @@ impl File for PipeReader {
fn ioctl(&self, cmd: &mut dyn IoctlCmd) -> Result<()> { fn ioctl(&self, cmd: &mut dyn IoctlCmd) -> Result<()> {
match_ioctl_cmd_auto_error!(cmd, { match_ioctl_cmd_auto_error!(cmd, {
cmd : GetReadBufLen => { cmd : GetReadBufLen => {
let read_buf_len = self.consumer.ready_len(); let read_buf_len = self.get_ready_len().min(std::i32::MAX as usize) as i32;
cmd.set_output(read_buf_len as _); cmd.set_output(read_buf_len as _);
}, },
}); });

@ -104,13 +104,11 @@ impl UringSet {
// Sum registered socket // Sum registered socket
let total_socket_num = map let total_socket_num = map
.values() .values()
.fold(0, |acc, state| acc + state.registered_num) .fold(0, |acc, state| acc + state.registered_num);
+ 1;
// Determine the number of available io_uring // Determine the number of available io_uring
let uring_num = (total_socket_num / SOCKET_THRESHOLD_PER_URING) + 1; let uring_num = (total_socket_num / SOCKET_THRESHOLD_PER_URING) + 1;
let existed_uring_num = self.running_uring_num.load(Ordering::Relaxed);
assert!(existed_uring_num <= uring_num); running_uring_num < uring_num
existed_uring_num < uring_num
}; };
if should_build_uring { if should_build_uring {
@ -134,7 +132,7 @@ impl UringSet {
// Link the file to the io_uring instance with the least load. // Link the file to the io_uring instance with the least load.
let (mut uring, mut state) = map let (mut uring, mut state) = map
.iter_mut() .iter_mut()
.min_by_key(|(_, &mut state)| state.registered_num) .min_by_key(|(_, state)| state.registered_num)
.unwrap(); .unwrap();
// Re-select io_uring instance with least task load // Re-select io_uring instance with least task load

@ -7,13 +7,17 @@ pub use self::io_multiplexing::{
PollEventFlags, PollFd, THREAD_NOTIFIERS, PollEventFlags, PollFd, THREAD_NOTIFIERS,
}; };
pub use self::socket::{ pub use self::socket::{
socketpair, unix_socket, AsUnixSocket, Domain, HostSocket, HostSocketType, Iovs, IovsMut, mmsghdr, socketpair, unix_socket, Addr, AnyAddr, AsUnixSocket, Domain, GetAcceptConnCmd,
RawAddr, SliceAsLibcIovec, UnixAddr, GetDomainCmd, GetErrorCmd, GetOutputAsBytes, GetPeerNameCmd, GetRecvBufSizeCmd,
GetRecvTimeoutCmd, GetSendBufSizeCmd, GetSendTimeoutCmd, GetSockOptRawCmd, GetTypeCmd,
HostSocket, HostSocketType, Iovs, IovsMut, RecvFlags, SendFlags, SetRecvBufSizeCmd,
SetRecvTimeoutCmd, SetSendBufSizeCmd, SetSendTimeoutCmd, SetSockOptRawCmd, Shutdown,
SliceAsLibcIovec, SockAddr, SockOptName, SocketFile, SocketType, UnixAddr, UringSocketType,
}; };
pub use self::syscalls::*; pub use self::syscalls::*;
mod io_multiplexing; mod io_multiplexing;
pub(crate) mod socket; mod socket;
mod syscalls; mod syscalls;
pub use self::syscalls::*; pub use self::syscalls::*;

@ -26,7 +26,7 @@ pub struct HostSocket {
impl HostSocket { impl HostSocket {
pub fn new( pub fn new(
domain: Domain, domain: Domain,
socket_type: Type, socket_type: SocketType,
socket_flags: SocketFlags, socket_flags: SocketFlags,
protocol: i32, protocol: i32,
) -> Result<Self> { ) -> Result<Self> {
@ -49,7 +49,7 @@ impl HostSocket {
}) })
} }
pub fn bind(&self, addr: &RawAddr) -> Result<()> { pub fn bind(&self, addr: &SockAddr) -> Result<()> {
let (addr_ptr, addr_len) = addr.as_ptr_and_len(); let (addr_ptr, addr_len) = addr.as_ptr_and_len();
let ret = try_libc!(libc::ocall::bind( let ret = try_libc!(libc::ocall::bind(
@ -65,8 +65,8 @@ impl HostSocket {
Ok(()) Ok(())
} }
pub fn accept(&self, flags: SocketFlags) -> Result<(Self, Option<RawAddr>)> { pub fn accept(&self, flags: SocketFlags) -> Result<(Self, Option<SockAddr>)> {
let mut sockaddr = RawAddr::default(); let mut sockaddr = SockAddr::default();
let mut addr_len = sockaddr.len(); let mut addr_len = sockaddr.len();
let raw_host_fd = try_libc!(libc::ocall::accept4( let raw_host_fd = try_libc!(libc::ocall::accept4(
@ -86,8 +86,8 @@ impl HostSocket {
Ok((HostSocket::from_host_fd(host_fd)?, addr_option)) Ok((HostSocket::from_host_fd(host_fd)?, addr_option))
} }
pub fn addr(&self) -> Result<RawAddr> { pub fn addr(&self) -> Result<SockAddr> {
let mut sockaddr = RawAddr::default(); let mut sockaddr = SockAddr::default();
let mut addr_len = sockaddr.len(); let mut addr_len = sockaddr.len();
try_libc!(libc::ocall::getsockname( try_libc!(libc::ocall::getsockname(
self.raw_host_fd() as i32, self.raw_host_fd() as i32,
@ -99,8 +99,8 @@ impl HostSocket {
Ok(sockaddr) Ok(sockaddr)
} }
pub fn peer_addr(&self) -> Result<RawAddr> { pub fn peer_addr(&self) -> Result<SockAddr> {
let mut sockaddr = RawAddr::default(); let mut sockaddr = SockAddr::default();
let mut addr_len = sockaddr.len(); let mut addr_len = sockaddr.len();
try_libc!(libc::ocall::getpeername( try_libc!(libc::ocall::getpeername(
self.raw_host_fd() as i32, self.raw_host_fd() as i32,
@ -112,7 +112,7 @@ impl HostSocket {
Ok(sockaddr) Ok(sockaddr)
} }
pub fn connect(&self, addr: &Option<RawAddr>) -> Result<()> { pub fn connect(&self, addr: Option<&SockAddr>) -> Result<()> {
debug!("connect: host_fd: {}, addr {:?}", self.raw_host_fd(), addr); debug!("connect: host_fd: {}, addr {:?}", self.raw_host_fd(), addr);
let (addr_ptr, addr_len) = if let Some(sock_addr) = addr { let (addr_ptr, addr_len) = if let Some(sock_addr) = addr {
@ -133,14 +133,13 @@ impl HostSocket {
&self, &self,
buf: &[u8], buf: &[u8],
flags: SendFlags, flags: SendFlags,
addr_option: &Option<RawAddr>, addr_option: Option<AnyAddr>,
) -> Result<usize> { ) -> Result<usize> {
let bufs = vec![buf]; let bufs = vec![buf];
self.sendmsg(&bufs, flags, addr_option, None) self.sendmsg(&bufs, flags, addr_option, None)
} }
pub fn recvfrom(&self, buf: &mut [u8], flags: RecvFlags) -> Result<(usize, Option<RawAddr>)> { pub fn recvfrom(&self, buf: &mut [u8], flags: RecvFlags) -> Result<(usize, Option<AnyAddr>)> {
let mut sockaddr = RawAddr::default();
let mut bufs = vec![buf]; let mut bufs = vec![buf];
let (bytes_recv, recv_addr, _, _) = self.recvmsg(&mut bufs, flags, None)?; let (bytes_recv, recv_addr, _, _) = self.recvmsg(&mut bufs, flags, None)?;

@ -12,7 +12,7 @@ impl HostSocket {
data: &mut [&mut [u8]], data: &mut [&mut [u8]],
flags: RecvFlags, flags: RecvFlags,
control: Option<&mut [u8]>, control: Option<&mut [u8]>,
) -> Result<(usize, Option<RawAddr>, MsgFlags, usize)> { ) -> Result<(usize, Option<AnyAddr>, MsgFlags, usize)> {
let current = current!(); let current = current!();
let data_length = data.iter().map(|s| s.len()).sum(); let data_length = data.iter().map(|s| s.len()).sum();
let mut ocall_alloc; let mut ocall_alloc;
@ -54,10 +54,10 @@ impl HostSocket {
data: &mut [UntrustedSlice], data: &mut [UntrustedSlice],
flags: RecvFlags, flags: RecvFlags,
mut control: Option<&mut [u8]>, mut control: Option<&mut [u8]>,
) -> Result<(usize, Option<RawAddr>, MsgFlags, usize)> { ) -> Result<(usize, Option<AnyAddr>, MsgFlags, usize)> {
// Prepare the arguments for OCall // Prepare the arguments for OCall
let host_fd = self.raw_host_fd() as i32; let host_fd = self.raw_host_fd() as i32;
let mut addr = RawAddr::default(); let mut addr = SockAddr::default();
let mut msg_name = addr.as_mut_ptr(); let mut msg_name = addr.as_mut_ptr();
let mut msg_namelen = addr.len(); let mut msg_namelen = addr.len();
let mut msg_namelen_recvd = 0_u32; let mut msg_namelen_recvd = 0_u32;
@ -122,16 +122,16 @@ impl HostSocket {
}; };
let msg_namelen_recvd = msg_namelen_recvd as usize; let msg_namelen_recvd = msg_namelen_recvd as usize;
let raw_addr = if msg_namelen_recvd == 0 { let raw_addr = (msg_namelen_recvd != 0).then(|| {
None addr.set_len(msg_namelen_recvd);
} else { addr
addr.set_len(msg_namelen_recvd)?; });
Some(addr)
}; let addr = raw_addr.map(|addr| AnyAddr::Raw(addr));
assert!(msg_namelen_recvd <= msg_namelen); assert!(msg_namelen_recvd <= msg_namelen);
assert!(msg_controllen_recvd <= msg_controllen); assert!(msg_controllen_recvd <= msg_controllen);
Ok((bytes_recvd, raw_addr, flags_recvd, msg_controllen_recvd)) Ok((bytes_recvd, addr, flags_recvd, msg_controllen_recvd))
} }
} }

@ -2,14 +2,14 @@ use super::*;
impl HostSocket { impl HostSocket {
pub fn send(&self, buf: &[u8], flags: SendFlags) -> Result<usize> { pub fn send(&self, buf: &[u8], flags: SendFlags) -> Result<usize> {
self.sendto(buf, flags, &None) self.sendto(buf, flags, None)
} }
pub fn sendmsg( pub fn sendmsg(
&self, &self,
data: &[&[u8]], data: &[&[u8]],
flags: SendFlags, flags: SendFlags,
addr: &Option<RawAddr>, addr: Option<AnyAddr>,
control: Option<&[u8]>, control: Option<&[u8]>,
) -> Result<usize> { ) -> Result<usize> {
let current = current!(); let current = current!();
@ -34,8 +34,14 @@ impl HostSocket {
bufs bufs
}; };
let name = addr.as_ref().map(|raw_addr| raw_addr.as_slice()); let raw_addr = addr.map(|addr| addr.to_raw());
self.do_sendmsg_untrusted_data(&u_data, flags, name, control)
self.do_sendmsg_untrusted_data(
&u_data,
flags,
raw_addr.as_ref().map(|addr| addr.as_slice()),
control,
)
} }
fn do_sendmsg_untrusted_data( fn do_sendmsg_untrusted_data(

@ -41,7 +41,7 @@ impl File for HostSocket {
} }
fn writev(&self, bufs: &[&[u8]]) -> Result<usize> { fn writev(&self, bufs: &[&[u8]]) -> Result<usize> {
self.sendmsg(bufs, SendFlags::empty(), &None, None) self.sendmsg(bufs, SendFlags::empty(), None, None)
} }
fn seek(&self, pos: SeekFrom) -> Result<off_t> { fn seek(&self, pos: SeekFrom) -> Result<off_t> {

@ -1,15 +1,22 @@
use super::*; use super::*;
mod host; mod host;
pub(crate) mod sockopt; mod sockopt;
mod unix; mod unix;
pub(crate) mod uring; mod uring;
pub(crate) mod util; mod util;
pub use self::host::{HostSocket, HostSocketType}; pub use self::host::{HostSocket, HostSocketType};
pub use self::unix::{socketpair, unix_socket, AsUnixSocket}; pub use self::unix::{socketpair, unix_socket, AsUnixSocket};
pub use self::util::{ pub use self::util::{
Addr, AnyAddr, CMessages, CSockAddr, CmsgData, Domain, Iovs, IovsMut, Ipv4Addr, Ipv4SocketAddr, mmsghdr, Addr, AnyAddr, CMessages, CSockAddr, CmsgData, Domain, Iovs, IovsMut, Ipv4Addr,
Ipv6SocketAddr, MsgFlags, RawAddr, RecvFlags, SendFlags, Shutdown, SliceAsLibcIovec, Ipv4SocketAddr, Ipv6SocketAddr, MsgFlags, RecvFlags, SendFlags, Shutdown, SliceAsLibcIovec,
SocketProtocol, Type, UnixAddr, SockAddr, SocketFlags, SocketProtocol, SocketType, UnixAddr,
}; };
pub use sockopt::{
GetAcceptConnCmd, GetDomainCmd, GetErrorCmd, GetOutputAsBytes, GetPeerNameCmd,
GetRecvBufSizeCmd, GetRecvTimeoutCmd, GetSendBufSizeCmd, GetSendTimeoutCmd, GetSockOptRawCmd,
GetTypeCmd, SetRecvBufSizeCmd, SetRecvTimeoutCmd, SetSendBufSizeCmd, SetSendTimeoutCmd,
SetSockOptRawCmd, SockOptName,
};
pub use uring::{socket_file::SocketFile, UringSocketType};

@ -1,12 +1,11 @@
use super::{GetRecvTimeoutCmd, GetSendTimeoutCmd}; use super::{GetRecvTimeoutCmd, GetSendTimeoutCmd};
use super::{ use super::{
GetAcceptConnCmd, GetDomainCmd, GetErrorCmd, GetPeerNameCmd, GetRcvBufSizeCmd, GetAcceptConnCmd, GetDomainCmd, GetErrorCmd, GetPeerNameCmd, GetRecvBufSizeCmd,
GetSndBufSizeCmd, GetSockOptRawCmd, GetTypeCmd, GetSendBufSizeCmd, GetSockOptRawCmd, GetTypeCmd,
}; };
use libc::timeval; use libc::timeval;
use std::time::Duration;
use crate::prelude::*; use crate::prelude::*;
@ -60,7 +59,7 @@ impl GetOutputAsBytes for GetTypeCmd {
} }
} }
impl GetOutputAsBytes for GetSndBufSizeCmd { impl GetOutputAsBytes for GetSendBufSizeCmd {
fn get_output_as_bytes(&self) -> Option<&[u8]> { fn get_output_as_bytes(&self) -> Option<&[u8]> {
self.output().map(|val_ref| unsafe { self.output().map(|val_ref| unsafe {
std::slice::from_raw_parts( std::slice::from_raw_parts(
@ -71,7 +70,7 @@ impl GetOutputAsBytes for GetSndBufSizeCmd {
} }
} }
impl GetOutputAsBytes for GetRcvBufSizeCmd { impl GetOutputAsBytes for GetRecvBufSizeCmd {
fn get_output_as_bytes(&self) -> Option<&[u8]> { fn get_output_as_bytes(&self) -> Option<&[u8]> {
self.output().map(|val_ref| unsafe { self.output().map(|val_ref| unsafe {
std::slice::from_raw_parts( std::slice::from_raw_parts(

@ -1,7 +1,7 @@
crate::impl_ioctl_cmd! { crate::impl_ioctl_cmd! {
pub struct GetSndBufSizeCmd<Input=(), Output=usize> {} pub struct GetSendBufSizeCmd<Input=(), Output=usize> {}
} }
crate::impl_ioctl_cmd! { crate::impl_ioctl_cmd! {
pub struct GetRcvBufSizeCmd<Input=(), Output=usize> {} pub struct GetRecvBufSizeCmd<Input=(), Output=usize> {}
} }

@ -16,10 +16,10 @@ pub use get_domain::GetDomainCmd;
pub use get_error::GetErrorCmd; pub use get_error::GetErrorCmd;
pub use get_output::*; pub use get_output::*;
pub use get_peername::{AddrStorage, GetPeerNameCmd}; pub use get_peername::{AddrStorage, GetPeerNameCmd};
pub use get_sockbuf::{GetRcvBufSizeCmd, GetSndBufSizeCmd}; pub use get_sockbuf::{GetRecvBufSizeCmd, GetSendBufSizeCmd};
pub use get_type::GetTypeCmd; pub use get_type::GetTypeCmd;
pub use set::{setsockopt_by_host, SetSockOptRawCmd}; pub use set::{setsockopt_by_host, SetSockOptRawCmd};
pub use set_sockbuf::{SetRcvBufSizeCmd, SetSndBufSizeCmd}; pub use set_sockbuf::{SetRecvBufSizeCmd, SetSendBufSizeCmd};
pub use timeout::{ pub use timeout::{
timeout_to_timeval, GetRecvTimeoutCmd, GetSendTimeoutCmd, SetRecvTimeoutCmd, SetSendTimeoutCmd, timeout_to_timeval, GetRecvTimeoutCmd, GetSendTimeoutCmd, SetRecvTimeoutCmd, SetSendTimeoutCmd,
}; };

@ -2,14 +2,16 @@ use crate::{fs::IoctlCmd, prelude::*};
use libc::ocall::setsockopt as do_setsockopt; use libc::ocall::setsockopt as do_setsockopt;
#[derive(Debug)] #[derive(Debug)]
pub struct SetSockOptRawCmd { pub struct SetSockOptRawCmd<'a> {
level: i32, level: i32,
optname: i32, optname: i32,
optval: &'static [u8], optval: &'a [u8],
} }
impl SetSockOptRawCmd { impl IoctlCmd for SetSockOptRawCmd<'static> {}
pub fn new(level: i32, optname: i32, optval: &'static [u8]) -> Self {
impl<'a> SetSockOptRawCmd<'a> {
pub fn new(level: i32, optname: i32, optval: &'a [u8]) -> SetSockOptRawCmd<'a> {
Self { Self {
level, level,
optname, optname,
@ -23,8 +25,6 @@ impl SetSockOptRawCmd {
} }
} }
impl IoctlCmd for SetSockOptRawCmd {}
pub fn setsockopt_by_host(fd: FileDesc, level: i32, optname: i32, optval: &[u8]) -> Result<()> { pub fn setsockopt_by_host(fd: FileDesc, level: i32, optname: i32, optval: &[u8]) -> Result<()> {
try_libc!(do_setsockopt( try_libc!(do_setsockopt(
fd as _, fd as _,

@ -1,23 +1,18 @@
use super::set::setsockopt_by_host; use super::set::setsockopt_by_host;
use crate::{fs::IoctlCmd, prelude::*}; use crate::{fs::IoctlCmd, prelude::*};
#[derive(Debug)] crate::impl_ioctl_cmd! {
pub struct SetSndBufSizeCmd { pub struct SetSendBufSizeCmd<Input=usize, Output=()> {}
buf_size: usize,
} }
impl SetSndBufSizeCmd { crate::impl_ioctl_cmd! {
pub fn new(buf_size: usize) -> Self { pub struct SetRecvBufSizeCmd<Input=usize, Output=()> {}
Self { buf_size } }
}
pub fn buf_size(&self) -> usize {
self.buf_size
}
impl SetSendBufSizeCmd {
pub fn update_host(&self, fd: FileDesc) -> Result<()> { pub fn update_host(&self, fd: FileDesc) -> Result<()> {
// The buf size for host call should be divided by 2 because the value will be doubled by host kernel. // The buf size for host call should be divided by 2 because the value will be doubled by host kernel.
let host_call_buf_size = (self.buf_size / 2).to_ne_bytes(); let host_call_buf_size = (self.input / 2).to_ne_bytes();
// Setting SO_SNDBUF for host socket needs to respect /proc/sys/net/core/wmem_max. Thus, the value might be different on host, but it is fine. // Setting SO_SNDBUF for host socket needs to respect /proc/sys/net/core/wmem_max. Thus, the value might be different on host, but it is fine.
setsockopt_by_host( setsockopt_by_host(
@ -29,25 +24,10 @@ impl SetSndBufSizeCmd {
} }
} }
impl IoctlCmd for SetSndBufSizeCmd {} impl SetRecvBufSizeCmd {
#[derive(Debug)]
pub struct SetRcvBufSizeCmd {
buf_size: usize,
}
impl SetRcvBufSizeCmd {
pub fn new(buf_size: usize) -> Self {
Self { buf_size }
}
pub fn buf_size(&self) -> usize {
self.buf_size
}
pub fn update_host(&self, fd: FileDesc) -> Result<()> { pub fn update_host(&self, fd: FileDesc) -> Result<()> {
// The buf size for host call should be divided by 2 because the value will be doubled by host kernel. // The buf size for host call should be divided by 2 because the value will be doubled by host kernel.
let host_call_buf_size = (self.buf_size / 2).to_ne_bytes(); let host_call_buf_size = (self.input / 2).to_ne_bytes();
// Setting SO_RCVBUF for host socket needs to respect /proc/sys/net/core/rmem_max. Thus, the value might be different on host, but it is fine. // Setting SO_RCVBUF for host socket needs to respect /proc/sys/net/core/rmem_max. Thus, the value might be different on host, but it is fine.
setsockopt_by_host( setsockopt_by_host(
@ -58,5 +38,3 @@ impl SetRcvBufSizeCmd {
) )
} }
} }
impl IoctlCmd for SetRcvBufSizeCmd {}

@ -3,34 +3,12 @@ use crate::prelude::*;
use libc::{suseconds_t, time_t}; use libc::{suseconds_t, time_t};
use std::time::Duration; use std::time::Duration;
#[derive(Debug)] crate::impl_ioctl_cmd! {
pub struct SetSendTimeoutCmd(Duration); pub struct SetSendTimeoutCmd<Input=Duration, Output=()> {}
impl IoctlCmd for SetSendTimeoutCmd {}
impl SetSendTimeoutCmd {
pub fn new(timeout: Duration) -> Self {
Self(timeout)
}
pub fn timeout(&self) -> &Duration {
&self.0
}
} }
#[derive(Debug)] crate::impl_ioctl_cmd! {
pub struct SetRecvTimeoutCmd(Duration); pub struct SetRecvTimeoutCmd<Input=Duration, Output=()> {}
impl IoctlCmd for SetRecvTimeoutCmd {}
impl SetRecvTimeoutCmd {
pub fn new(timeout: Duration) -> Self {
Self(timeout)
}
pub fn timeout(&self) -> &Duration {
&self.0
}
} }
crate::impl_ioctl_cmd! { crate::impl_ioctl_cmd! {

@ -5,12 +5,12 @@ mod stream;
pub use self::stream::Stream; pub use self::stream::Stream;
//TODO: rewrite this file when a new kind of uds is added //TODO: rewrite this file when a new kind of uds is added
pub fn unix_socket(socket_type: Type, flags: SocketFlags, protocol: i32) -> Result<Stream> { pub fn unix_socket(socket_type: SocketType, flags: SocketFlags, protocol: i32) -> Result<Stream> {
if protocol != 0 && protocol != Domain::LOCAL as i32 { if protocol != 0 && protocol != Domain::LOCAL as i32 {
return_errno!(EPROTONOSUPPORT, "protocol is not supported"); return_errno!(EPROTONOSUPPORT, "protocol is not supported");
} }
if socket_type == Type::STREAM { if socket_type == SocketType::STREAM {
Ok(Stream::new(flags)) Ok(Stream::new(flags))
} else { } else {
return_errno!(ESOCKTNOSUPPORT, "only stream type is supported"); return_errno!(ESOCKTNOSUPPORT, "only stream type is supported");
@ -18,7 +18,7 @@ pub fn unix_socket(socket_type: Type, flags: SocketFlags, protocol: i32) -> Resu
} }
pub fn socketpair( pub fn socketpair(
socket_type: Type, socket_type: SocketType,
flags: SocketFlags, flags: SocketFlags,
protocol: i32, protocol: i32,
) -> Result<(Stream, Stream)> { ) -> Result<(Stream, Stream)> {
@ -26,7 +26,7 @@ pub fn socketpair(
return_errno!(EPROTONOSUPPORT, "protocol is not supported"); return_errno!(EPROTONOSUPPORT, "protocol is not supported");
} }
if socket_type == Type::STREAM { if socket_type == SocketType::STREAM {
Stream::socketpair(flags) Stream::socketpair(flags)
} else { } else {
return_errno!(ESOCKTNOSUPPORT, "only stream type is supported"); return_errno!(ESOCKTNOSUPPORT, "only stream type is supported");

@ -72,7 +72,10 @@ impl Stream {
return_errno!(ENOTCONN, "the socket is not connected"); return_errno!(ENOTCONN, "the socket is not connected");
} }
pub fn bind(&self, addr: &mut UnixAddr) -> Result<()> { pub fn bind(&self, addr: &UnixAddr) -> Result<()> {
let mut unix_addr = addr.clone();
let addr = &mut unix_addr;
if let UnixAddr::File(inode_num, path) = addr { if let UnixAddr::File(inode_num, path) = addr {
// create the corresponding file in the fs and fill Addr with its inode // create the corresponding file in the fs and fill Addr with its inode
let corresponding_inode_num = { let corresponding_inode_num = {
@ -218,7 +221,7 @@ impl Stream {
} }
// TODO: handle flags // TODO: handle flags
pub fn sendto(&self, buf: &[u8], flags: SendFlags, addr: &Option<UnixAddr>) -> Result<usize> { pub fn sendto(&self, buf: &[u8], flags: SendFlags, addr: Option<&UnixAddr>) -> Result<usize> {
self.write(buf) self.write(buf)
} }
@ -255,7 +258,7 @@ impl Stream {
bufs: &mut [&mut [u8]], bufs: &mut [&mut [u8]],
flags: RecvFlags, flags: RecvFlags,
control: Option<&mut [u8]>, control: Option<&mut [u8]>,
) -> Result<(usize, usize)> { ) -> Result<(usize, Option<AnyAddr>, MsgFlags, usize)> {
if !flags.is_empty() { if !flags.is_empty() {
warn!("unsupported flags: {:?}", flags); warn!("unsupported flags: {:?}", flags);
} }
@ -288,7 +291,7 @@ impl Stream {
0 0
}; };
Ok((data_len, control_len)) Ok((data_len, None, MsgFlags::empty(), control_len))
} }
/// perform shutdown on the socket. /// perform shutdown on the socket.

@ -18,7 +18,7 @@ use crate::prelude::*;
/// The common parts of all stream sockets. /// The common parts of all stream sockets.
pub struct Common<A: Addr + 'static, R: Runtime> { pub struct Common<A: Addr + 'static, R: Runtime> {
host_fd: FileDesc, host_fd: FileDesc,
type_: Type, type_: SocketType,
nonblocking: AtomicBool, nonblocking: AtomicBool,
is_closed: AtomicBool, is_closed: AtomicBool,
pollee: Pollee, pollee: Pollee,
@ -30,7 +30,7 @@ pub struct Common<A: Addr + 'static, R: Runtime> {
} }
impl<A: Addr + 'static, R: Runtime> Common<A, R> { impl<A: Addr + 'static, R: Runtime> Common<A, R> {
pub fn new(type_: Type, nonblocking: bool, protocol: Option<i32>) -> Result<Self> { pub fn new(type_: SocketType, nonblocking: bool, protocol: Option<i32>) -> Result<Self> {
let domain_c = A::domain() as libc::c_int; let domain_c = A::domain() as libc::c_int;
let type_c = type_ as libc::c_int; let type_c = type_ as libc::c_int;
let protocol = protocol.unwrap_or(0) as libc::c_int; let protocol = protocol.unwrap_or(0) as libc::c_int;
@ -56,11 +56,11 @@ impl<A: Addr + 'static, R: Runtime> Common<A, R> {
}) })
} }
pub fn new_pair(sock_type: Type, nonblocking: bool) -> Result<(Self, Self)> { pub fn new_pair(sock_type: SocketType, nonblocking: bool) -> Result<(Self, Self)> {
return_errno!(EINVAL, "Unix is unsupported"); return_errno!(EINVAL, "Unix is unsupported");
} }
pub fn with_host_fd(host_fd: FileDesc, type_: Type, nonblocking: bool) -> Self { pub fn with_host_fd(host_fd: FileDesc, type_: SocketType, nonblocking: bool) -> Self {
let nonblocking = AtomicBool::new(nonblocking); let nonblocking = AtomicBool::new(nonblocking);
let is_closed = AtomicBool::new(false); let is_closed = AtomicBool::new(false);
let pollee = Pollee::new(IoEvents::empty()); let pollee = Pollee::new(IoEvents::empty());
@ -90,7 +90,7 @@ impl<A: Addr + 'static, R: Runtime> Common<A, R> {
self.host_fd self.host_fd
} }
pub fn type_(&self) -> Type { pub fn type_(&self) -> SocketType {
self.type_ self.type_
} }

@ -20,7 +20,7 @@ pub struct DatagramSocket<A: Addr + 'static, R: Runtime> {
impl<A: Addr, R: Runtime> DatagramSocket<A, R> { impl<A: Addr, R: Runtime> DatagramSocket<A, R> {
pub fn new(nonblocking: bool) -> Result<Self> { pub fn new(nonblocking: bool) -> Result<Self> {
let common = Arc::new(Common::new(Type::DGRAM, nonblocking, None)?); let common = Arc::new(Common::new(SocketType::DGRAM, nonblocking, None)?);
let state = RwLock::new(State::new()); let state = RwLock::new(State::new());
let sender = Sender::new(common.clone()); let sender = Sender::new(common.clone());
let receiver = Receiver::new(common.clone()); let receiver = Receiver::new(common.clone());
@ -33,7 +33,7 @@ impl<A: Addr, R: Runtime> DatagramSocket<A, R> {
} }
pub fn new_pair(nonblocking: bool) -> Result<(Self, Self)> { pub fn new_pair(nonblocking: bool) -> Result<(Self, Self)> {
let (common1, common2) = Common::new_pair(Type::DGRAM, nonblocking)?; let (common1, common2) = Common::new_pair(SocketType::DGRAM, nonblocking)?;
let socket1 = Self::new_connected(common1); let socket1 = Self::new_connected(common1);
let socket2 = Self::new_connected(common2); let socket2 = Self::new_connected(common2);
Ok((socket1, socket2)) Ok((socket1, socket2))
@ -321,10 +321,10 @@ impl<A: Addr, R: Runtime> DatagramSocket<A, R> {
cmd.execute(self.host_fd())?; cmd.execute(self.host_fd())?;
}, },
cmd: SetRecvTimeoutCmd => { cmd: SetRecvTimeoutCmd => {
self.set_recv_timeout(*cmd.timeout()); self.set_recv_timeout(*cmd.input());
}, },
cmd: SetSendTimeoutCmd => { cmd: SetSendTimeoutCmd => {
self.set_send_timeout(*cmd.timeout()); self.set_send_timeout(*cmd.input());
}, },
cmd: GetRecvTimeoutCmd => { cmd: GetRecvTimeoutCmd => {
let timeval = timeout_to_timeval(self.recv_timeout()); let timeval = timeout_to_timeval(self.recv_timeout());

@ -102,10 +102,10 @@ impl SocketFile {
apply_fn_on_any_socket!(&self.socket, |socket| { socket.ioctl(cmd) }) apply_fn_on_any_socket!(&self.socket, |socket| { socket.ioctl(cmd) })
} }
pub fn get_type(&self) -> Type { pub fn get_type(&self) -> SocketType {
match self.socket { match self.socket {
AnySocket::Ipv4Stream(_) | AnySocket::Ipv6Stream(_) => Type::STREAM, AnySocket::Ipv4Stream(_) | AnySocket::Ipv6Stream(_) => SocketType::STREAM,
AnySocket::Ipv4Datagram(_) | AnySocket::Ipv6Datagram(_) => Type::DGRAM, AnySocket::Ipv4Datagram(_) | AnySocket::Ipv6Datagram(_) => SocketType::DGRAM,
} }
} }
} }
@ -115,11 +115,11 @@ impl SocketFile {
pub fn new( pub fn new(
domain: Domain, domain: Domain,
protocol: SocketProtocol, protocol: SocketProtocol,
socket_type: Type, socket_type: SocketType,
nonblocking: bool, nonblocking: bool,
) -> Result<Self> { ) -> Result<Self> {
match socket_type { match socket_type {
Type::STREAM => { SocketType::STREAM => {
if protocol != SocketProtocol::IPPROTO_IP && protocol != SocketProtocol::IPPROTO_TCP if protocol != SocketProtocol::IPPROTO_IP && protocol != SocketProtocol::IPPROTO_TCP
{ {
return_errno!(EPROTONOSUPPORT, "Protocol not supported"); return_errno!(EPROTONOSUPPORT, "Protocol not supported");
@ -140,7 +140,7 @@ impl SocketFile {
let new_self = Self { socket: any_socket }; let new_self = Self { socket: any_socket };
Ok(new_self) Ok(new_self)
} }
Type::DGRAM => { SocketType::DGRAM => {
if protocol != SocketProtocol::IPPROTO_IP && protocol != SocketProtocol::IPPROTO_UDP if protocol != SocketProtocol::IPPROTO_IP && protocol != SocketProtocol::IPPROTO_UDP
{ {
return_errno!(EPROTONOSUPPORT, "Protocol not supported"); return_errno!(EPROTONOSUPPORT, "Protocol not supported");
@ -161,7 +161,7 @@ impl SocketFile {
let new_self = Self { socket: any_socket }; let new_self = Self { socket: any_socket };
Ok(new_self) Ok(new_self)
} }
Type::RAW => { SocketType::RAW => {
return_errno!(EINVAL, "RAW socket not supported"); return_errno!(EINVAL, "RAW socket not supported");
} }
_ => { _ => {
@ -210,7 +210,7 @@ impl SocketFile {
} }
} }
pub fn bind(&self, addr: &mut AnyAddr) -> Result<()> { pub fn bind(&self, addr: &AnyAddr) -> Result<()> {
match &self.socket { match &self.socket {
AnySocket::Ipv4Stream(ipv4_stream) => { AnySocket::Ipv4Stream(ipv4_stream) => {
let ip_addr = addr.to_ipv4()?; let ip_addr = addr.to_ipv4()?;

@ -57,7 +57,7 @@ impl<A: Addr, R: Runtime> StreamSocket<A, R> {
} }
pub fn new_pair(nonblocking: bool) -> Result<(Self, Self)> { pub fn new_pair(nonblocking: bool) -> Result<(Self, Self)> {
let (common1, common2) = Common::new_pair(Type::STREAM, nonblocking)?; let (common1, common2) = Common::new_pair(SocketType::STREAM, nonblocking)?;
let connected1 = ConnectedStream::new(Arc::new(common1)); let connected1 = ConnectedStream::new(Arc::new(common1));
let connected2 = ConnectedStream::new(Arc::new(common2)); let connected2 = ConnectedStream::new(Arc::new(common2));
let socket1 = Self::new_connected(connected1); let socket1 = Self::new_connected(connected1);
@ -373,10 +373,10 @@ impl<A: Addr, R: Runtime> StreamSocket<A, R> {
cmd.execute(self.host_fd())?; cmd.execute(self.host_fd())?;
}, },
cmd: SetRecvTimeoutCmd => { cmd: SetRecvTimeoutCmd => {
self.set_recv_timeout(*cmd.timeout()); self.set_recv_timeout(*cmd.input());
}, },
cmd: SetSendTimeoutCmd => { cmd: SetSendTimeoutCmd => {
self.set_send_timeout(*cmd.timeout()); self.set_send_timeout(*cmd.input());
}, },
cmd: GetRecvTimeoutCmd => { cmd: GetRecvTimeoutCmd => {
let timeval = timeout_to_timeval(self.recv_timeout()); let timeval = timeout_to_timeval(self.recv_timeout());
@ -386,21 +386,21 @@ impl<A: Addr, R: Runtime> StreamSocket<A, R> {
let timeval = timeout_to_timeval(self.send_timeout()); let timeval = timeout_to_timeval(self.send_timeout());
cmd.set_output(timeval); cmd.set_output(timeval);
}, },
cmd: SetSndBufSizeCmd => { cmd: SetSendBufSizeCmd => {
cmd.update_host(self.host_fd())?; cmd.update_host(self.host_fd())?;
let buf_size = cmd.buf_size(); let buf_size = *cmd.input();
self.set_kernel_send_buf_size(buf_size); self.set_kernel_send_buf_size(buf_size);
}, },
cmd: SetRcvBufSizeCmd => { cmd: SetRecvBufSizeCmd => {
cmd.update_host(self.host_fd())?; cmd.update_host(self.host_fd())?;
let buf_size = cmd.buf_size(); let buf_size = *cmd.input();
self.set_kernel_recv_buf_size(buf_size); self.set_kernel_recv_buf_size(buf_size);
}, },
cmd: GetSndBufSizeCmd => { cmd: GetSendBufSizeCmd => {
let buf_size = SEND_BUF_SIZE.load(Ordering::Relaxed); let buf_size = SEND_BUF_SIZE.load(Ordering::Relaxed);
cmd.set_output(buf_size); cmd.set_output(buf_size);
}, },
cmd: GetRcvBufSizeCmd => { cmd: GetRecvBufSizeCmd => {
let buf_size = RECV_BUF_SIZE.load(Ordering::Relaxed); let buf_size = RECV_BUF_SIZE.load(Ordering::Relaxed);
cmd.set_output(buf_size); cmd.set_output(buf_size);
}, },
@ -454,6 +454,8 @@ impl<A: Addr, R: Runtime> StreamSocket<A, R> {
} }
fn set_kernel_send_buf_size(&self, buf_size: usize) { fn set_kernel_send_buf_size(&self, buf_size: usize) {
// Setting the minimal buf_size to 128 Kbytes
let buf_size = (128 * 1024 + 1).max(buf_size);
let state = self.state.read().unwrap(); let state = self.state.read().unwrap();
match &*state { match &*state {
State::Init(_) | State::Listen(_) | State::Connect(_) => { State::Init(_) | State::Listen(_) | State::Connect(_) => {
@ -467,6 +469,8 @@ impl<A: Addr, R: Runtime> StreamSocket<A, R> {
} }
fn set_kernel_recv_buf_size(&self, buf_size: usize) { fn set_kernel_recv_buf_size(&self, buf_size: usize) {
// Setting the minimal buf_size to 128 Kbytes
let buf_size = (128 * 1024 + 1).max(buf_size);
let state = self.state.read().unwrap(); let state = self.state.read().unwrap();
match &*state { match &*state {
State::Init(_) | State::Listen(_) | State::Connect(_) => { State::Init(_) | State::Listen(_) | State::Connect(_) => {

@ -191,7 +191,6 @@ impl<A: Addr + 'static, R: Runtime> ConnectedStream<A, R> {
// Init the callback invoked upon the completion of the async recv // Init the callback invoked upon the completion of the async recv
let stream = self.clone(); let stream = self.clone();
let complete_fn = move |retval: i32| { let complete_fn = move |retval: i32| {
// let mut inner = stream.receiver.inner.lock().unwrap();
let mut inner = stream.receiver.inner.lock(); let mut inner = stream.receiver.inner.lock();
trace!("recv request complete with retval: {:?}", retval); trace!("recv request complete with retval: {:?}", retval);
@ -232,7 +231,9 @@ impl<A: Addr + 'static, R: Runtime> ConnectedStream<A, R> {
// ready to read. // ready to read.
stream.common.pollee().add_events(Events::IN); stream.common.pollee().add_events(Events::IN);
if !stream.receiver.need_update() {
stream.do_recv(&mut inner); stream.do_recv(&mut inner);
}
}; };
// Generate the async recv request // Generate the async recv request

@ -183,7 +183,14 @@ impl<A: Addr + 'static, R: Runtime> ConnectedStream<A, R> {
inner.fatal = Some(errno); inner.fatal = Some(errno);
stream.common.set_errno(errno); stream.common.set_errno(errno);
stream.common.pollee().add_events(Events::ERR);
let events = if errno == ENOTCONN || errno == ECONNRESET || errno == ECONNREFUSED {
Events::HUP | Events::OUT | Events::ERR
} else {
Events::ERR
};
stream.common.pollee().add_events(events);
return; return;
} }
assert!(retval != 0); assert!(retval != 0);

@ -15,7 +15,7 @@ struct Inner {
impl<A: Addr + 'static, R: Runtime> InitStream<A, R> { impl<A: Addr + 'static, R: Runtime> InitStream<A, R> {
pub fn new(nonblocking: bool) -> Result<Arc<Self>> { pub fn new(nonblocking: bool) -> Result<Arc<Self>> {
let common = Arc::new(Common::new(Type::STREAM, nonblocking, None)?); let common = Arc::new(Common::new(SocketType::STREAM, nonblocking, None)?);
common.pollee().add_events(IoEvents::HUP | IoEvents::OUT); common.pollee().add_events(IoEvents::HUP | IoEvents::OUT);
let inner = Mutex::new(Inner::new()); let inner = Mutex::new(Inner::new());
let new_self = Self { common, inner }; let new_self = Self { common, inner };

@ -121,7 +121,11 @@ impl<A: Addr + 'static, R: Runtime> ListenerStream<A, R> {
self.initiate_async_accepts(inner); self.initiate_async_accepts(inner);
let common = { let common = {
let common = Arc::new(Common::with_host_fd(accepted_fd, Type::STREAM, nonblocking)); let common = Arc::new(Common::with_host_fd(
accepted_fd,
SocketType::STREAM,
nonblocking,
));
common.set_peer_addr(&accepted_addr); common.set_peer_addr(&accepted_addr);
common common
}; };

@ -1,7 +1,8 @@
use crate::net::{Addr, Domain};
use std::any::Any; use std::any::Any;
use std::fmt::{self, Debug}; use std::fmt::{self, Debug};
use super::{Addr, CSockAddr, Domain, RawAddr}; use super::{CSockAddr, SockAddr};
use crate::prelude::*; use crate::prelude::*;
/// An IPv4 socket address, consisting of an IPv4 address and a port. /// An IPv4 socket address, consisting of an IPv4 address and a port.
@ -70,9 +71,9 @@ impl Ipv4SocketAddr {
} }
} }
pub fn to_raw(&self) -> RawAddr { pub fn to_raw(&self) -> SockAddr {
let (storage, len) = self.to_c_storage(); let (storage, len) = self.to_c_storage();
RawAddr::from_c_storage(&storage, len) SockAddr::from_c_storage(&storage, len)
} }
pub fn ip(&self) -> &Ipv4Addr { pub fn ip(&self) -> &Ipv4Addr {

@ -1,7 +1,7 @@
use std::any::Any; use std::any::Any;
use std::fmt::Debug; use std::fmt::Debug;
use super::RawAddr; use super::SockAddr;
use super::{Addr, CSockAddr, Domain}; use super::{Addr, CSockAddr, Domain};
use crate::prelude::*; use crate::prelude::*;
use libc::in6_addr; use libc::in6_addr;
@ -85,9 +85,9 @@ impl Ipv6SocketAddr {
} }
} }
pub fn to_raw(&self) -> RawAddr { pub fn to_raw(&self) -> SockAddr {
let (storage, len) = self.to_c_storage(); let (storage, len) = self.to_c_storage();
RawAddr::from_c_storage(&storage, len) SockAddr::from_c_storage(&storage, len)
} }
pub fn ip(&self) -> &Ipv6Addr { pub fn ip(&self) -> &Ipv6Addr {

@ -38,7 +38,7 @@ pub trait Addr: Clone + Debug + Default + PartialEq + Send + Sync {
pub use self::c_sock_addr::CSockAddr; pub use self::c_sock_addr::CSockAddr;
pub use self::ipv4::{Ipv4Addr, Ipv4SocketAddr}; pub use self::ipv4::{Ipv4Addr, Ipv4SocketAddr};
pub use self::ipv6::{Ipv6Addr, Ipv6SocketAddr}; pub use self::ipv6::{Ipv6Addr, Ipv6SocketAddr};
pub use self::raw_addr::RawAddr; pub use self::raw_addr::SockAddr;
pub use self::unix_addr::UnixAddr; pub use self::unix_addr::UnixAddr;
#[cfg(test)] #[cfg(test)]

@ -2,22 +2,22 @@ use super::*;
use std::*; use std::*;
#[derive(Copy, Clone)] #[derive(Copy, Clone)]
pub struct RawAddr { pub struct SockAddr {
storage: libc::sockaddr_storage, storage: libc::sockaddr_storage,
len: usize, len: usize,
} }
// TODO: add more fields // TODO: add more fields
impl fmt::Debug for RawAddr { impl fmt::Debug for SockAddr {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("RawAddr") f.debug_struct("SockAddr")
.field("family", &Domain::try_from(self.storage.ss_family).unwrap()) .field("family", &Domain::try_from(self.storage.ss_family).unwrap())
.field("len", &self.len) .field("len", &self.len)
.finish() .finish()
} }
} }
impl RawAddr { impl SockAddr {
pub fn from_c_storage(c_addr: &libc::sockaddr_storage, c_addr_len: usize) -> Self { pub fn from_c_storage(c_addr: &libc::sockaddr_storage, c_addr_len: usize) -> Self {
Self { Self {
storage: *c_addr, storage: *c_addr,
@ -118,7 +118,7 @@ impl RawAddr {
} }
} }
impl Default for RawAddr { impl Default for SockAddr {
fn default() -> Self { fn default() -> Self {
let mut storage: libc::sockaddr_storage = unsafe { mem::zeroed() }; let mut storage: libc::sockaddr_storage = unsafe { mem::zeroed() };
Self { Self {

@ -106,7 +106,7 @@ impl UnixAddr {
/// The '/0' at the end of Self::File counts /// The '/0' at the end of Self::File counts
match self.path_str() { match self.path_str() {
Ok(str) => str.len() + 1 + *SUN_PATH_OFFSET, Ok(str) => str.len() + 1 + *SUN_PATH_OFFSET,
Err(_) => 0, Err(_) => std::mem::size_of::<libc::sa_family_t>(),
} }
} }
@ -123,9 +123,9 @@ impl UnixAddr {
c_un_addr.to_c_storage() c_un_addr.to_c_storage()
} }
pub fn to_raw(&self) -> RawAddr { pub fn to_raw(&self) -> SockAddr {
let (storage, addr_len) = self.to_c_storage(); let (storage, addr_len) = self.to_c_storage();
RawAddr::from_c_storage(&storage, addr_len) SockAddr::from_c_storage(&storage, addr_len)
} }
fn to_c(&self) -> (libc::sockaddr_un, usize) { fn to_c(&self) -> (libc::sockaddr_un, usize) {

@ -3,7 +3,7 @@ use std::mem::{self, MaybeUninit};
use crate::net::socket::Domain; use crate::net::socket::Domain;
use crate::prelude::*; use crate::prelude::*;
use super::{Addr, CSockAddr, Ipv4Addr, Ipv4SocketAddr, Ipv6SocketAddr, RawAddr, UnixAddr}; use super::{Addr, CSockAddr, Ipv4Addr, Ipv4SocketAddr, Ipv6SocketAddr, SockAddr, UnixAddr};
use num_enum::IntoPrimitive; use num_enum::IntoPrimitive;
use std::path::Path; use std::path::Path;
@ -12,7 +12,7 @@ pub enum AnyAddr {
Ipv4(Ipv4SocketAddr), Ipv4(Ipv4SocketAddr),
Ipv6(Ipv6SocketAddr), Ipv6(Ipv6SocketAddr),
Unix(UnixAddr), Unix(UnixAddr),
Raw(RawAddr), Raw(SockAddr),
Unspec, Unspec,
} }
@ -33,7 +33,7 @@ impl AnyAddr {
Self::Unix(unix_addr) Self::Unix(unix_addr)
} }
_ => { _ => {
let raw_addr = RawAddr::from_c_storage(c_addr, c_addr_len); let raw_addr = SockAddr::from_c_storage(c_addr, c_addr_len);
Self::Raw(raw_addr) Self::Raw(raw_addr)
} }
}; };
@ -55,7 +55,7 @@ impl AnyAddr {
} }
} }
pub fn to_raw(&self) -> RawAddr { pub fn to_raw(&self) -> SockAddr {
match self { match self {
Self::Ipv4(ipv4_addr) => ipv4_addr.to_raw(), Self::Ipv4(ipv4_addr) => ipv4_addr.to_raw(),
Self::Ipv6(ipv6_addr) => ipv6_addr.to_raw(), Self::Ipv6(ipv6_addr) => ipv6_addr.to_raw(),
@ -65,7 +65,7 @@ impl AnyAddr {
let mut sockaddr_storage = let mut sockaddr_storage =
unsafe { MaybeUninit::<libc::sockaddr_storage>::uninit().assume_init() }; unsafe { MaybeUninit::<libc::sockaddr_storage>::uninit().assume_init() };
sockaddr_storage.ss_family = libc::AF_UNSPEC as _; sockaddr_storage.ss_family = libc::AF_UNSPEC as _;
RawAddr::from_c_storage(&sockaddr_storage, mem::size_of::<libc::sa_family_t>()) SockAddr::from_c_storage(&sockaddr_storage, mem::size_of::<libc::sa_family_t>())
} }
} }
} }
@ -77,13 +77,6 @@ impl AnyAddr {
} }
} }
pub fn as_ipv4(&self) -> Option<&Ipv4SocketAddr> {
match self {
Self::Ipv4(ipv4_addr) => Some(ipv4_addr),
_ => None,
}
}
pub fn to_ipv4(&self) -> Result<&Ipv4SocketAddr> { pub fn to_ipv4(&self) -> Result<&Ipv4SocketAddr> {
match self { match self {
Self::Ipv4(ipv4_addr) => Ok(ipv4_addr), Self::Ipv4(ipv4_addr) => Ok(ipv4_addr),

@ -1,4 +1,6 @@
use bitflags::bitflags; use bitflags::bitflags;
use sgx_trts::libc;
use std::ffi::c_uint;
// Flags to use when sending data through a socket // Flags to use when sending data through a socket
bitflags! { bitflags! {
@ -37,3 +39,18 @@ bitflags! {
const MSG_NOTIFICATION = 0x8000; // Only applicable to SCTP socket const MSG_NOTIFICATION = 0x8000; // Only applicable to SCTP socket
} }
} }
// Flags to use when creating a new socket
bitflags! {
pub struct SocketFlags: i32 {
const SOCK_NONBLOCK = 0x800;
const SOCK_CLOEXEC = 0x80000;
}
}
#[repr(C)]
#[derive(Copy, Clone)]
pub struct mmsghdr {
pub msg_hdr: libc::msghdr,
pub msg_len: c_uint,
}

@ -15,13 +15,13 @@ mod shutdown;
mod r#type; mod r#type;
pub use self::addr::{ pub use self::addr::{
Addr, CSockAddr, Ipv4Addr, Ipv4SocketAddr, Ipv6SocketAddr, RawAddr, UnixAddr, Addr, CSockAddr, Ipv4Addr, Ipv4SocketAddr, Ipv6SocketAddr, SockAddr, UnixAddr,
}; };
pub use self::any_addr::AnyAddr; pub use self::any_addr::AnyAddr;
pub use self::domain::Domain; pub use self::domain::Domain;
pub use self::flags::{MsgFlags, RecvFlags, SendFlags}; pub use self::flags::{mmsghdr, MsgFlags, RecvFlags, SendFlags, SocketFlags};
pub use self::iovs::{Iovs, IovsMut, SliceAsLibcIovec}; pub use self::iovs::{Iovs, IovsMut, SliceAsLibcIovec};
pub use self::msg::{CMessages, CmsgData}; pub use self::msg::{CMessages, CmsgData};
pub use self::protocol::SocketProtocol; pub use self::protocol::SocketProtocol;
pub use self::r#type::Type; pub use self::r#type::SocketType;
pub use self::shutdown::Shutdown; pub use self::shutdown::Shutdown;

@ -4,7 +4,7 @@ use num_enum::{IntoPrimitive, TryFromPrimitive};
/// A network type. /// A network type.
#[derive(Clone, Copy, Debug, Eq, PartialEq, IntoPrimitive, TryFromPrimitive)] #[derive(Clone, Copy, Debug, Eq, PartialEq, IntoPrimitive, TryFromPrimitive)]
#[repr(i32)] #[repr(i32)]
pub enum Type { pub enum SocketType {
STREAM = libc::SOCK_STREAM, STREAM = libc::SOCK_STREAM,
DGRAM = libc::SOCK_DGRAM, DGRAM = libc::SOCK_DGRAM,
RAW = libc::SOCK_RAW, RAW = libc::SOCK_RAW,

@ -1,5 +1,4 @@
use super::socket::{MsgFlags, SocketProtocol}; use super::socket::{mmsghdr, MsgFlags, SocketFlags, SocketProtocol};
use super::{socket::uring::socket_file::SocketFile, *};
use atomic::Ordering; use atomic::Ordering;
use core::f32::consts::E; use core::f32::consts::E;
@ -17,15 +16,6 @@ use std::convert::TryFrom;
use time::{timespec_t, timeval_t}; use time::{timespec_t, timeval_t};
use util::mem_util::from_user; use util::mem_util::from_user;
use super::socket::sockopt::{
GetAcceptConnCmd, GetDomainCmd, GetErrorCmd, GetOutputAsBytes, GetPeerNameCmd,
GetRcvBufSizeCmd, GetRecvTimeoutCmd, GetSendTimeoutCmd, GetSndBufSizeCmd, GetSockOptRawCmd,
GetTypeCmd, SetRcvBufSizeCmd, SetRecvTimeoutCmd, SetSendTimeoutCmd, SetSndBufSizeCmd,
SetSockOptRawCmd, SockOptName,
};
use super::socket::uring::UringSocketType;
use super::socket::AnyAddr;
use super::*; use super::*;
use crate::fs::StatusFlags; use crate::fs::StatusFlags;
@ -36,20 +26,13 @@ use crate::prelude::*;
const SOMAXCONN: u32 = 4096; const SOMAXCONN: u32 = 4096;
const SOCONN_DEFAULT: u32 = 16; const SOCONN_DEFAULT: u32 = 16;
#[repr(C)]
#[derive(Copy, Clone)]
pub struct mmsghdr {
pub msg_hdr: libc::msghdr,
pub msg_len: c_uint,
}
pub fn do_socket(domain: c_int, socket_type: c_int, protocol: c_int) -> Result<isize> { pub fn do_socket(domain: c_int, socket_type: c_int, protocol: c_int) -> Result<isize> {
let domain = Domain::try_from(domain as u16)?; let domain = Domain::try_from(domain as u16)?;
let flags = SocketFlags::from_bits_truncate(socket_type); let flags = SocketFlags::from_bits_truncate(socket_type);
let type_bits = socket_type & !flags.bits(); let type_bits = socket_type & !flags.bits();
let socket_type = let socket_type =
Type::try_from(type_bits).map_err(|_| errno!(EINVAL, "invalid socket type"))?; SocketType::try_from(type_bits).map_err(|_| errno!(EINVAL, "invalid socket type"))?;
debug!( debug!(
"socket domain: {:?}, type: {:?}, protocol: {:?}", "socket domain: {:?}, type: {:?}, protocol: {:?}",
@ -64,7 +47,8 @@ pub fn do_socket(domain: c_int, socket_type: c_int, protocol: c_int) -> Result<i
// Determine if type and domain match uring supported socket // Determine if type and domain match uring supported socket
let match_uring = move || { let match_uring = move || {
let is_uring_type = (socket_type == Type::DGRAM || socket_type == Type::STREAM); let is_uring_type =
(socket_type == SocketType::DGRAM || socket_type == SocketType::STREAM);
let is_uring_protocol = (protocol == SocketProtocol::IPPROTO_IP let is_uring_protocol = (protocol == SocketProtocol::IPPROTO_IP
|| protocol == SocketProtocol::IPPROTO_TCP || protocol == SocketProtocol::IPPROTO_TCP
|| protocol == SocketProtocol::IPPROTO_UDP); || protocol == SocketProtocol::IPPROTO_UDP);
@ -100,20 +84,24 @@ pub fn do_socket(domain: c_int, socket_type: c_int, protocol: c_int) -> Result<i
} }
pub fn do_bind(fd: c_int, addr: *const libc::sockaddr, addr_len: libc::socklen_t) -> Result<isize> { pub fn do_bind(fd: c_int, addr: *const libc::sockaddr, addr_len: libc::socklen_t) -> Result<isize> {
let addr = {
let addr_len = addr_len as usize; let addr_len = addr_len as usize;
let sockaddr_storage = copy_sock_addr_from_user(addr, addr_len)?; let sockaddr_storage = copy_sock_addr_from_user(addr, addr_len)?;
let mut addr = AnyAddr::from_c_storage(&sockaddr_storage, addr_len)?; let addr = AnyAddr::from_c_storage(&sockaddr_storage, addr_len)?;
addr
};
trace!("bind to addr: {:?}", addr); trace!("bind to addr: {:?}", addr);
let file_ref = current!().file(fd as FileDesc)?; let file_ref = current!().file(fd as FileDesc)?;
if let Ok(socket) = file_ref.as_host_socket() { if let Ok(socket) = file_ref.as_host_socket() {
let mut raw_addr = addr.to_raw(); let raw_addr = addr.to_raw();
socket.bind(&mut raw_addr)?; socket.bind(&raw_addr)?;
} else if let Ok(unix_socket) = file_ref.as_unix_socket() { } else if let Ok(unix_socket) = file_ref.as_unix_socket() {
let mut unix_addr = (addr.to_unix()?).clone(); let unix_addr = addr.to_unix()?;
unix_socket.bind(&mut unix_addr)?; unix_socket.bind(unix_addr)?;
} else if let Ok(uring_socket) = file_ref.as_uring_socket() { } else if let Ok(uring_socket) = file_ref.as_uring_socket() {
uring_socket.bind(&mut addr)?; uring_socket.bind(&addr)?;
} else { } else {
return_errno!(ENOTSOCK, "not a socket"); return_errno!(ENOTSOCK, "not a socket");
} }
@ -159,18 +147,21 @@ pub fn do_connect(
let file_ref = current!().file(fd as FileDesc)?; let file_ref = current!().file(fd as FileDesc)?;
if let Ok(socket) = file_ref.as_host_socket() { if let Ok(socket) = file_ref.as_host_socket() {
let addr_option = if addr_set { let addr_option = if addr_set {
Some(unsafe { RawAddr::try_from_raw(addr, addr_len as u32)? }) Some(unsafe { SockAddr::try_from_raw(addr, addr_len as u32)? })
} else { } else {
None None
}; };
socket.connect(&addr_option)?; socket.connect(addr_option.as_ref())?;
return Ok(0); return Ok(0);
}; };
let addr = {
let addr_len = addr_len as usize; let addr_len = addr_len as usize;
let sockaddr_storage = copy_sock_addr_from_user(addr, addr_len)?; let sockaddr_storage = copy_sock_addr_from_user(addr, addr_len)?;
let addr = AnyAddr::from_c_storage(&sockaddr_storage, addr_len)?; let addr = AnyAddr::from_c_storage(&sockaddr_storage, addr_len)?;
addr
};
if let Ok(unix_socket) = file_ref.as_unix_socket() { if let Ok(unix_socket) = file_ref.as_unix_socket() {
// TODO: support AF_UNSPEC address for datagram socket use // TODO: support AF_UNSPEC address for datagram socket use
@ -415,19 +406,17 @@ pub fn do_sendto(
let file_ref = current!().file(fd as FileDesc)?; let file_ref = current!().file(fd as FileDesc)?;
if let Ok(host_socket) = file_ref.as_host_socket() { if let Ok(host_socket) = file_ref.as_host_socket() {
let addr = addr.map(|any_addr| any_addr.to_raw());
host_socket host_socket
.sendto(buf, send_flags, &addr) .sendto(buf, send_flags, addr)
.map(|u| u as isize) .map(|u| u as isize)
} else if let Ok(unix_socket) = file_ref.as_unix_socket() { } else if let Ok(unix_socket) = file_ref.as_unix_socket() {
let addr = match addr { let addr = match addr {
Some(any_addr) => Some(any_addr.to_unix()?.clone()), Some(ref any_addr) => Some(any_addr.to_unix()?),
None => None, None => None,
}; };
unix_socket unix_socket
.sendto(buf, send_flags, &addr) .sendto(buf, send_flags, addr)
.map(|u| u as isize) .map(|u| u as isize)
} else if let Ok(uring_socket) = file_ref.as_uring_socket() { } else if let Ok(uring_socket) = file_ref.as_uring_socket() {
uring_socket uring_socket
@ -458,9 +447,7 @@ pub fn do_recvfrom(
let file_ref = current!().file(fd as FileDesc)?; let file_ref = current!().file(fd as FileDesc)?;
let (data_len, addr_recv) = if let Ok(socket) = file_ref.as_host_socket() { let (data_len, addr_recv) = if let Ok(socket) = file_ref.as_host_socket() {
socket socket.recvfrom(buf, recv_flags)?
.recvfrom(buf, recv_flags)
.map(|(len, addr_recv)| (len, addr_recv.map(|raw_addr| AnyAddr::Raw(raw_addr))))?
} else if let Ok(unix_socket) = file_ref.as_unix_socket() { } else if let Ok(unix_socket) = file_ref.as_unix_socket() {
unix_socket unix_socket
.recvfrom(buf, recv_flags) .recvfrom(buf, recv_flags)
@ -496,7 +483,7 @@ pub fn do_socketpair(
let file_flags = SocketFlags::from_bits_truncate(socket_type); let file_flags = SocketFlags::from_bits_truncate(socket_type);
let close_on_spawn = file_flags.contains(SocketFlags::SOCK_CLOEXEC); let close_on_spawn = file_flags.contains(SocketFlags::SOCK_CLOEXEC);
let sock_type = Type::try_from(socket_type & (!file_flags.bits())) let sock_type = SocketType::try_from(socket_type & (!file_flags.bits()))
.map_err(|_| errno!(EINVAL, "invalid socket type"))?; .map_err(|_| errno!(EINVAL, "invalid socket type"))?;
let domain = Domain::try_from(domain as u16)?; let domain = Domain::try_from(domain as u16)?;
@ -526,9 +513,8 @@ pub fn do_sendmsg(fd: c_int, msg_ptr: *const libc::msghdr, flags_c: c_int) -> Re
let file_ref = current!().file(fd as FileDesc)?; let file_ref = current!().file(fd as FileDesc)?;
if let Ok(host_socket) = file_ref.as_host_socket() { if let Ok(host_socket) = file_ref.as_host_socket() {
let raw_addr = addr.map(|addr| addr.to_raw());
host_socket host_socket
.sendmsg(&bufs[..], flags, &raw_addr, control) .sendmsg(&bufs[..], flags, addr, control)
.map(|bytes_send| bytes_send as isize) .map(|bytes_send| bytes_send as isize)
} else if let Ok(socket) = file_ref.as_unix_socket() { } else if let Ok(socket) = file_ref.as_unix_socket() {
socket socket
@ -554,20 +540,9 @@ pub fn do_recvmsg(fd: c_int, msg_mut_ptr: *mut libc::msghdr, flags_c: c_int) ->
let file_ref = current!().file(fd as FileDesc)?; let file_ref = current!().file(fd as FileDesc)?;
let (bytes_recv, recv_addr, msg_flags, msg_controllen) = let (bytes_recv, recv_addr, msg_flags, msg_controllen) =
if let Ok(host_socket) = file_ref.as_host_socket() { if let Ok(host_socket) = file_ref.as_host_socket() {
host_socket.recvmsg(&mut bufs[..], flags, control).map( host_socket.recvmsg(&mut bufs[..], flags, control)?
|(bytes, addr, msg, controllen)| {
(
bytes,
addr.map(|raw_addr| AnyAddr::Raw(raw_addr)),
msg,
controllen,
)
},
)?
} else if let Ok(unix_socket) = file_ref.as_unix_socket() { } else if let Ok(unix_socket) = file_ref.as_unix_socket() {
unix_socket.recvmsg(&mut bufs[..], flags, control).map( unix_socket.recvmsg(&mut bufs[..], flags, control)?
|(bytes_recvd, control_len)| (bytes_recvd, None, MsgFlags::empty(), control_len),
)?
} else if let Ok(uring_socket) = file_ref.as_uring_socket() { } else if let Ok(uring_socket) = file_ref.as_uring_socket() {
uring_socket.recvmsg(&mut bufs[..], flags, control)? uring_socket.recvmsg(&mut bufs[..], flags, control)?
} else { } else {
@ -610,13 +585,12 @@ pub fn do_sendmmsg(
if let Ok(host_socket) = file_ref.as_host_socket() { if let Ok(host_socket) = file_ref.as_host_socket() {
for mmsg in (msgvec) { for mmsg in (msgvec) {
let (any_addr, bufs, control) = extract_msghdr_from_user(&mmsg.msg_hdr)?; let (addr, bufs, control) = extract_msghdr_from_user(&mmsg.msg_hdr)?;
let raw_addr = any_addr.map(|any_addr| any_addr.to_raw());
if host_socket if host_socket
.sendmsg(&bufs[..], flags, &raw_addr, control) .sendmsg(&bufs[..], flags, addr, control)
.map(|bytes_send| { .map(|bytes_send| {
mmsg.msg_len += bytes_send as c_uint; mmsg.msg_len = bytes_send as c_uint;
bytes_send as isize bytes_send as isize
}) })
.is_ok() .is_ok()
@ -635,7 +609,7 @@ pub fn do_sendmmsg(
if uring_socket if uring_socket
.sendmsg(&bufs[..], addr, flags, control) .sendmsg(&bufs[..], addr, flags, control)
.map(|bytes_send| { .map(|bytes_send| {
mmsg.msg_len += bytes_send as c_uint; mmsg.msg_len = bytes_send as c_uint;
bytes_send as isize bytes_send as isize
}) })
.is_ok() .is_ok()
@ -1021,18 +995,17 @@ fn copy_sock_addr_from_user(
let sockaddr_src_buf = from_user::make_slice(addr as *const u8, addr_len)?; let sockaddr_src_buf = from_user::make_slice(addr as *const u8, addr_len)?;
let sockaddr_storage = { let sockaddr_storage = {
// Safety. The content will be initialized before function returns. let mut sockaddr_storage = MaybeUninit::<libc::sockaddr_storage>::uninit();
let mut sockaddr_storage =
unsafe { MaybeUninit::<libc::sockaddr_storage>::uninit().assume_init() };
// Safety. The dst slice is the only mutable reference to the sockaddr_storage // Safety. The dst slice is the only mutable reference to the sockaddr_storage
let sockaddr_dst_buf = unsafe { let sockaddr_dst_buf = unsafe {
let ptr = &mut sockaddr_storage as *mut _ as *mut u8; let ptr = sockaddr_storage.as_mut_ptr() as *mut u8;
let len = addr_len; let len = addr_len;
std::slice::from_raw_parts_mut(ptr, len) std::slice::from_raw_parts_mut(ptr, len)
}; };
sockaddr_dst_buf.copy_from_slice(sockaddr_src_buf); sockaddr_dst_buf.copy_from_slice(sockaddr_src_buf);
sockaddr_storage unsafe { sockaddr_storage.assume_init() }
}; };
Ok(sockaddr_storage) Ok(sockaddr_storage)
} }
@ -1088,7 +1061,7 @@ fn new_uring_getsockopt_cmd(
level: i32, level: i32,
optname: i32, optname: i32,
optlen: u32, optlen: u32,
socket_type: Type, socket_type: SocketType,
) -> Result<Box<dyn IoctlCmd>> { ) -> Result<Box<dyn IoctlCmd>> {
if level != libc::SOL_SOCKET { if level != libc::SOL_SOCKET {
return Ok(Box::new(GetSockOptRawCmd::new(level, optname, optlen))); return Ok(Box::new(GetSockOptRawCmd::new(level, optname, optlen)));
@ -1106,17 +1079,17 @@ fn new_uring_getsockopt_cmd(
SockOptName::SO_RCVTIMEO_OLD => Box::new(GetRecvTimeoutCmd::new(())), SockOptName::SO_RCVTIMEO_OLD => Box::new(GetRecvTimeoutCmd::new(())),
SockOptName::SO_SNDTIMEO_OLD => Box::new(GetSendTimeoutCmd::new(())), SockOptName::SO_SNDTIMEO_OLD => Box::new(GetSendTimeoutCmd::new(())),
SockOptName::SO_SNDBUF => { SockOptName::SO_SNDBUF => {
if socket_type == Type::STREAM { if socket_type == SocketType::STREAM {
// Implement dynamic buf size for stream socket only. // Implement dynamic buf size for stream socket only.
Box::new(GetSndBufSizeCmd::new(())) Box::new(GetSendBufSizeCmd::new(()))
} else { } else {
Box::new(GetSockOptRawCmd::new(level, optname, optlen)) Box::new(GetSockOptRawCmd::new(level, optname, optlen))
} }
} }
SockOptName::SO_RCVBUF => { SockOptName::SO_RCVBUF => {
if socket_type == Type::STREAM { if socket_type == SocketType::STREAM {
// Implement dynamic buf size for stream socket only. // Implement dynamic buf size for stream socket only.
Box::new(GetRcvBufSizeCmd::new(())) Box::new(GetRecvBufSizeCmd::new(()))
} else { } else {
Box::new(GetSockOptRawCmd::new(level, optname, optlen)) Box::new(GetSockOptRawCmd::new(level, optname, optlen))
} }
@ -1163,34 +1136,19 @@ fn new_uring_setsockopt_cmd(
level: i32, level: i32,
optname: i32, optname: i32,
optval: &'static [u8], optval: &'static [u8],
socket_type: Type, socket_type: SocketType,
) -> Result<Box<dyn IoctlCmd>> { ) -> Result<Box<dyn IoctlCmd>> {
if level != libc::SOL_SOCKET { if level != libc::SOL_SOCKET {
return Ok(Box::new(SetSockOptRawCmd::new(level, optname, optval))); return Ok(Box::new(SetSockOptRawCmd::new(level, optname, optval)));
} }
if optval.len() == 0 {
return_errno!(EINVAL, "Not a valid optval length");
}
let opt = let opt =
SockOptName::try_from(optname).map_err(|_| errno!(ENOPROTOOPT, "Not a valid optname"))?; SockOptName::try_from(optname).map_err(|_| errno!(ENOPROTOOPT, "Not a valid optname"))?;
let enable_uring = ENABLE_URING.load(Ordering::Relaxed);
if !enable_uring {
Ok(match opt {
SockOptName::SO_ACCEPTCONN
| SockOptName::SO_DOMAIN
| SockOptName::SO_PEERNAME
| SockOptName::SO_TYPE
| SockOptName::SO_ERROR
| SockOptName::SO_PEERCRED
| SockOptName::SO_SNDLOWAT
| SockOptName::SO_PEERSEC
| SockOptName::SO_PROTOCOL
| SockOptName::SO_MEMINFO
| SockOptName::SO_INCOMING_NAPI_ID
| SockOptName::SO_COOKIE
| SockOptName::SO_PEERGROUPS => return_errno!(ENOPROTOOPT, "it's a read-only option"),
_ => Box::new(SetSockOptRawCmd::new(level, optname, optval)),
})
} else {
Ok(match opt { Ok(match opt {
SockOptName::SO_ACCEPTCONN SockOptName::SO_ACCEPTCONN
| SockOptName::SO_DOMAIN | SockOptName::SO_DOMAIN
@ -1253,16 +1211,20 @@ fn new_uring_setsockopt_cmd(
} }
SockOptName::SO_SNDBUF => { SockOptName::SO_SNDBUF => {
// Implement dynamic buf size for stream socket only. // Implement dynamic buf size for stream socket only.
if socket_type != Type::STREAM { if socket_type != SocketType::STREAM {
Box::new(SetSockOptRawCmd::new(level, optname, optval)) Box::new(SetSockOptRawCmd::new(level, optname, optval))
} else { } else {
// Based on the man page: The minimum (doubled) value for this option is 2048. // Based on the man page: The minimum (doubled) value for this option is 2048.
// However, the test on Linux shows the minimum value (doubled) is 4608. Here, we just // However, the test on Linux shows the minimum value (doubled) is 4608. Here, we just
// use the same value as Linux. // use the same value as Linux.
let min_size = 1152; let min_size = 128 * 1024;
// For the max value, we choose 4MB (doubled) to assure the libos kernel buf won't be the bottleneck. // For the max value, we choose 4MB (doubled) to assure the libos kernel buf won't be the bottleneck.
let max_size = 2 * 1024 * 1024; let max_size = 2 * 1024 * 1024;
if optval.len() > 8 {
return_errno!(EINVAL, "optval size is invalid");
}
let mut send_buf_size = { let mut send_buf_size = {
let mut size = [0 as u8; std::mem::size_of::<usize>()]; let mut size = [0 as u8; std::mem::size_of::<usize>()];
let start_offset = size.len() - optval.len(); let start_offset = size.len() - optval.len();
@ -1279,11 +1241,11 @@ fn new_uring_setsockopt_cmd(
// Based on man page: The kernel doubles this value (to allow space for bookkeeping overhead) // Based on man page: The kernel doubles this value (to allow space for bookkeeping overhead)
// when it is set using setsockopt(2), and this doubled value is returned by getsockopt(2). // when it is set using setsockopt(2), and this doubled value is returned by getsockopt(2).
send_buf_size *= 2; send_buf_size *= 2;
Box::new(SetSndBufSizeCmd::new(send_buf_size)) Box::new(SetSendBufSizeCmd::new(send_buf_size))
} }
} }
SockOptName::SO_RCVBUF => { SockOptName::SO_RCVBUF => {
if socket_type != Type::STREAM { if socket_type != SocketType::STREAM {
Box::new(SetSockOptRawCmd::new(level, optname, optval)) Box::new(SetSockOptRawCmd::new(level, optname, optval))
} else { } else {
// Implement dynamic buf size for stream socket only. // Implement dynamic buf size for stream socket only.
@ -1291,10 +1253,14 @@ fn new_uring_setsockopt_cmd(
// Based on the man page: The minimum (doubled) value for this option is 256. // Based on the man page: The minimum (doubled) value for this option is 256.
// However, the test on Linux shows the minimum value (doubled) is 2304. Here, we just // However, the test on Linux shows the minimum value (doubled) is 2304. Here, we just
// use the same value as Linux. // use the same value as Linux.
let min_size = 1152; let min_size = 128 * 1024;
// For the max value, we choose 4MB (doubled) to assure the libos kernel buf won't be the bottleneck. // For the max value, we choose 4MB (doubled) to assure the libos kernel buf won't be the bottleneck.
let max_size = 2 * 1024 * 1024; let max_size = 2 * 1024 * 1024;
if optval.len() > 8 {
return_errno!(EINVAL, "optval size is invalid");
}
let mut recv_buf_size = { let mut recv_buf_size = {
let mut size = [0 as u8; std::mem::size_of::<usize>()]; let mut size = [0 as u8; std::mem::size_of::<usize>()];
let start_offset = size.len() - optval.len(); let start_offset = size.len() - optval.len();
@ -1312,12 +1278,11 @@ fn new_uring_setsockopt_cmd(
// Based on man page: The kernel doubles this value (to allow space for bookkeeping overhead) // Based on man page: The kernel doubles this value (to allow space for bookkeeping overhead)
// when it is set using setsockopt(2), and this doubled value is returned by getsockopt(2). // when it is set using setsockopt(2), and this doubled value is returned by getsockopt(2).
recv_buf_size *= 2; recv_buf_size *= 2;
Box::new(SetRcvBufSizeCmd::new(recv_buf_size)) Box::new(SetRecvBufSizeCmd::new(recv_buf_size))
} }
} }
_ => Box::new(SetSockOptRawCmd::new(level, optname, optval)), _ => Box::new(SetSockOptRawCmd::new(level, optname, optval)),
}) })
}
} }
fn get_optval(cmd: &dyn IoctlCmd) -> Result<&[u8]> { fn get_optval(cmd: &dyn IoctlCmd) -> Result<&[u8]> {
@ -1346,10 +1311,10 @@ fn get_optval(cmd: &dyn IoctlCmd) -> Result<&[u8]> {
cmd : GetSendTimeoutCmd => { cmd : GetSendTimeoutCmd => {
cmd.get_output_as_bytes() cmd.get_output_as_bytes()
}, },
cmd : GetSndBufSizeCmd => { cmd : GetSendBufSizeCmd => {
cmd.get_output_as_bytes() cmd.get_output_as_bytes()
}, },
cmd : GetRcvBufSizeCmd => { cmd : GetRecvBufSizeCmd => {
cmd.get_output_as_bytes() cmd.get_output_as_bytes()
}, },
_ => { _ => {
@ -1491,11 +1456,3 @@ fn extract_msghdr_mut_from_user<'a>(
Ok((msg_mut, name, control, bufs)) Ok((msg_mut, name, control, bufs))
} }
// Flags to use when creating a new socket
bitflags! {
pub struct SocketFlags: i32 {
const SOCK_NONBLOCK = 0x800;
const SOCK_CLOEXEC = 0x80000;
}
}

@ -17,8 +17,7 @@ pub use std::sync::{
pub use crate::error::Result; pub use crate::error::Result;
pub use crate::error::*; pub use crate::error::*;
pub use crate::fs::{File, FileDesc, FileRef}; pub use crate::fs::{File, FileDesc, FileRef};
pub use crate::net::socket::util::Addr; pub use crate::net::{Addr, Domain, RecvFlags, SendFlags, Shutdown, SocketType};
pub use crate::net::socket::{Domain, RecvFlags, SendFlags, Shutdown, Type};
pub use crate::process::{pid_t, uid_t}; pub use crate::process::{pid_t, uid_t};
pub use crate::util::sync::RwLock; pub use crate::util::sync::RwLock;
pub use crate::util::sync::{Mutex, MutexGuard}; pub use crate::util::sync::{Mutex, MutexGuard};

@ -14,7 +14,6 @@ use crate::misc::ResourceLimits;
use crate::prelude::*; use crate::prelude::*;
use crate::sched::{NiceValue, SchedAgent}; use crate::sched::{NiceValue, SchedAgent};
use crate::signal::{SigDispositions, SigQueues}; use crate::signal::{SigDispositions, SigQueues};
use crate::util::sync::Mutex;
use crate::vm::ProcessVM; use crate::vm::ProcessVM;
use self::pgrp::ProcessGrp; use self::pgrp::ProcessGrp;

@ -11,20 +11,20 @@ use atomic::Ordering;
use crate::process::{futex_wait, futex_wake}; use crate::process::{futex_wait, futex_wake};
#[derive(Default)] #[derive(Default)]
pub struct Mutex<T> { pub struct Mutex<T: ?Sized> {
value: UnsafeCell<T>,
inner: Box<MutexInner>, inner: Box<MutexInner>,
value: UnsafeCell<T>,
} }
unsafe impl<T: Send> Sync for Mutex<T> {} unsafe impl<T: Send + ?Sized> Sync for Mutex<T> {}
unsafe impl<T: Send> Send for Mutex<T> {} unsafe impl<T: Send + ?Sized> Send for Mutex<T> {}
pub struct MutexGuard<'a, T: 'a> { pub struct MutexGuard<'a, T: ?Sized + 'a> {
inner: &'a Mutex<T>, inner: &'a Mutex<T>,
} }
impl<T> !Send for MutexGuard<'_, T> {} impl<T: ?Sized> !Send for MutexGuard<'_, T> {}
unsafe impl<T: Sync> Sync for MutexGuard<'_, T> {} unsafe impl<T: Sync + ?Sized> Sync for MutexGuard<'_, T> {}
impl<T> Mutex<T> { impl<T> Mutex<T> {
#[inline] #[inline]
@ -34,14 +34,9 @@ impl<T> Mutex<T> {
inner: Box::new(MutexInner::new()), inner: Box::new(MutexInner::new()),
} }
} }
#[inline]
pub fn into_inner(self) -> T {
self.value.into_inner()
}
} }
impl<T> Mutex<T> { impl<T: ?Sized> Mutex<T> {
#[inline] #[inline]
pub fn lock(&self) -> MutexGuard<'_, T> { pub fn lock(&self) -> MutexGuard<'_, T> {
self.inner.lock(); self.inner.lock();
@ -50,7 +45,7 @@ impl<T> Mutex<T> {
#[inline] #[inline]
pub fn try_lock(&self) -> Option<MutexGuard<'_, T>> { pub fn try_lock(&self) -> Option<MutexGuard<'_, T>> {
self.inner.try_lock().map(|_| MutexGuard { inner: self }) self.inner.try_lock().then(|| MutexGuard { inner: self })
} }
#[inline] #[inline]
@ -67,9 +62,17 @@ impl<T> Mutex<T> {
pub fn get_mut(&mut self) -> &mut T { pub fn get_mut(&mut self) -> &mut T {
self.value.get_mut() self.value.get_mut()
} }
#[inline]
pub fn into_inner(self) -> T
where
T: Sized,
{
self.value.into_inner()
}
} }
impl<T> Deref for MutexGuard<'_, T> { impl<T: ?Sized> Deref for MutexGuard<'_, T> {
type Target = T; type Target = T;
fn deref(&self) -> &Self::Target { fn deref(&self) -> &Self::Target {
@ -77,13 +80,13 @@ impl<T> Deref for MutexGuard<'_, T> {
} }
} }
impl<T> DerefMut for MutexGuard<'_, T> { impl<T: ?Sized> DerefMut for MutexGuard<'_, T> {
fn deref_mut(&mut self) -> &mut Self::Target { fn deref_mut(&mut self) -> &mut Self::Target {
unsafe { &mut *self.inner.value.get() } unsafe { &mut *self.inner.value.get() }
} }
} }
impl<T> Drop for MutexGuard<'_, T> { impl<T: ?Sized> Drop for MutexGuard<'_, T> {
fn drop(&mut self) { fn drop(&mut self) {
unsafe { unsafe {
self.inner.force_unlock(); self.inner.force_unlock();
@ -91,13 +94,13 @@ impl<T> Drop for MutexGuard<'_, T> {
} }
} }
impl<T: fmt::Debug> fmt::Debug for MutexGuard<'_, T> { impl<T: fmt::Debug + ?Sized> fmt::Debug for MutexGuard<'_, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
fmt::Debug::fmt(&**self, f) fmt::Debug::fmt(&**self, f)
} }
} }
impl<T: fmt::Debug> fmt::Debug for Mutex<T> { impl<T: fmt::Debug + ?Sized> fmt::Debug for Mutex<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self.try_lock() { match self.try_lock() {
Some(guard) => write!(f, "Mutex {{ value: ") Some(guard) => write!(f, "Mutex {{ value: ")
@ -138,10 +141,10 @@ impl MutexInner {
} }
#[inline] #[inline]
pub fn try_lock(&self) -> Option<u32> { pub fn try_lock(&self) -> bool {
self.lock self.lock
.compare_exchange(0, 1, Ordering::Acquire, Ordering::Relaxed) .compare_exchange(0, 1, Ordering::Acquire, Ordering::Relaxed)
.ok() .is_ok()
} }
#[inline] #[inline]