diff --git a/src/libos/src/net/socket/mod.rs b/src/libos/src/net/socket/mod.rs index 38203bcb..f57d2fb0 100644 --- a/src/libos/src/net/socket/mod.rs +++ b/src/libos/src/net/socket/mod.rs @@ -14,7 +14,7 @@ pub use self::address_family::AddressFamily; pub use self::flags::{FileFlags, MsgHdrFlags, RecvFlags, SendFlags}; pub use self::host::{HostSocket, HostSocketType}; pub use self::iovs::{Iovs, IovsMut, SliceAsLibcIovec}; -pub use self::msg::{mmsghdr, msghdr, msghdr_mut, MsgHdr, MsgHdrMut}; +pub use self::msg::{mmsghdr, msghdr, msghdr_mut, CMessages, CmsgData, MsgHdr, MsgHdrMut}; pub use self::shutdown::HowToShut; pub use self::socket_address::SockAddr; pub use self::socket_type::SocketType; diff --git a/src/libos/src/net/socket/msg.rs b/src/libos/src/net/socket/msg.rs index 522b9102..47cb7a6f 100644 --- a/src/libos/src/net/socket/msg.rs +++ b/src/libos/src/net/socket/msg.rs @@ -219,6 +219,119 @@ impl<'a> MsgHdrMut<'a> { } } +/// This struct is used to iterate through the control messages. +/// +/// `cmsghdr` is a C struct for ancillary data object information of a unix socket. +pub struct CMessages<'a> { + buffer: &'a [u8], + current: Option<&'a libc::cmsghdr>, +} + +impl<'a> Iterator for CMessages<'a> { + type Item = CmsgData<'a>; + + fn next(&mut self) -> Option { + let cmsg = unsafe { + let mut msg: libc::msghdr = core::mem::zeroed(); + msg.msg_control = self.buffer.as_ptr() as *mut _; + msg.msg_controllen = self.buffer.len() as _; + + let cmsg = if let Some(current) = self.current { + libc::CMSG_NXTHDR(&msg, current) + } else { + libc::CMSG_FIRSTHDR(&msg) + }; + cmsg.as_ref()? + }; + + self.current = Some(cmsg); + CmsgData::try_from_cmsghdr(cmsg) + } +} + +impl<'a> CMessages<'a> { + pub fn from_bytes(msg_control: &'a mut [u8]) -> Self { + Self { + buffer: msg_control, + current: None, + } + } +} + +/// Control message data of variable type. The data resides next to `cmsghdr`. +pub enum CmsgData<'a> { + ScmRights(ScmRights<'a>), + ScmCredentials, +} + +impl<'a> CmsgData<'a> { + /// Create an `CmsgData::ScmRights` variant. + /// + /// # Safety + /// + /// `data` must contain a valid control message and the control message must be type of + /// `SOL_SOCKET` and level of `SCM_RIGHTS`. + unsafe fn as_rights(data: &'a mut [u8]) -> Self { + let scm_rights = ScmRights { data }; + CmsgData::ScmRights(scm_rights) + } + + /// Create an `CmsgData::ScmCredentials` variant. + /// + /// # Safety + /// + /// `data` must contain a valid control message and the control message must be type of + /// `SOL_SOCKET` and level of `SCM_CREDENTIALS`. + unsafe fn as_credentials(_data: &'a [u8]) -> Self { + CmsgData::ScmCredentials + } + + fn try_from_cmsghdr(cmsg: &'a libc::cmsghdr) -> Option { + unsafe { + let cmsg_len_zero = libc::CMSG_LEN(0) as usize; + let data_len = (*cmsg).cmsg_len as usize - cmsg_len_zero; + let data = libc::CMSG_DATA(cmsg); + let data = core::slice::from_raw_parts_mut(data, data_len); + + match (*cmsg).cmsg_level { + libc::SOL_SOCKET => match (*cmsg).cmsg_type { + libc::SCM_RIGHTS => Some(CmsgData::as_rights(data)), + libc::SCM_CREDENTIALS => Some(CmsgData::as_credentials(data)), + _ => None, + }, + _ => None, + } + } + } +} + +/// The data unit of this control message is file descriptor(s). +/// +/// The level is equal to `SOL_SOCKET` and the type is equal to `SCM_RIGHTS`. +pub struct ScmRights<'a> { + data: &'a mut [u8], +} + +impl<'a> ScmRights<'a> { + /// Iterate and reassign each fd in data buffer, given a reassignment function. + pub fn iter_and_reassign_fds(&mut self, reassign_fd_fn: F) + where + F: Fn(FileDesc) -> FileDesc, + { + for fd_bytes in self.data.chunks_exact_mut(core::mem::size_of::()) { + let old_fd = FileDesc::from_ne_bytes(fd_bytes.try_into().unwrap()); + let reassigned_fd = reassign_fd_fn(old_fd); + fd_bytes.copy_from_slice(&reassigned_fd.to_ne_bytes()); + } + } + + pub fn iter_fds(&self) -> impl Iterator + '_ { + self.data + .chunks_exact(core::mem::size_of::()) + .map(|fd_bytes| FileDesc::from_ne_bytes(fd_bytes.try_into().unwrap())) + } +} + unsafe fn new_optional_slice<'a, T>(slice_ptr: *const T, slice_size: usize) -> Option<&'a [T]> { if !slice_ptr.is_null() { let slice = core::slice::from_raw_parts::(slice_ptr, slice_size); diff --git a/src/libos/src/net/socket/unix/stream/endpoint.rs b/src/libos/src/net/socket/unix/stream/endpoint.rs index 1e427600..9ddbb7f0 100644 --- a/src/libos/src/net/socket/unix/stream/endpoint.rs +++ b/src/libos/src/net/socket/unix/stream/endpoint.rs @@ -17,12 +17,14 @@ pub fn end_pair(nonblocking: bool) -> Result<(Endpoint, Endpoint)> { reader: con_a, writer: pro_b, peer: Weak::default(), + ancillary: RwLock::new(None), }); let end_b = Arc::new(Inner { addr: RwLock::new(None), reader: con_b, writer: pro_a, peer: Arc::downgrade(&end_a), + ancillary: RwLock::new(None), }); unsafe { @@ -41,6 +43,7 @@ pub struct Inner { reader: Consumer, writer: Producer, peer: Weak, + ancillary: RwLock>, } impl Inner { @@ -119,6 +122,18 @@ impl Inner { events } + pub fn ancillary(&self) -> Option { + self.ancillary.read().unwrap().clone() + } + + pub fn set_ancillary(&self, ancillary: Ancillary) { + self.ancillary.write().unwrap().insert(ancillary); + } + + pub fn peer_ancillary(&self) -> Option { + self.peer.upgrade().map(|end| end.ancillary()).flatten() + } + pub(self) fn register_relay_notifier(&self, observer: &Arc) { self.reader.notifier().register( Arc::downgrade(observer) as Weak>, @@ -138,6 +153,18 @@ impl Inner { } } +/// Ancillary data of connected unix socket's sent/received control message. +#[derive(Clone, Debug)] +pub struct Ancillary { + pub(super) tid: pid_t, // currently store tid to locate file table +} + +impl Ancillary { + pub fn tid(&self) -> pid_t { + self.tid + } +} + // TODO: Add SO_SNDBUF and SO_RCVBUF to set/getsockopt to dynamcally change the size. // This value is got from /proc/sys/net/core/rmem_max and wmem_max that are same on linux. pub const DEFAULT_BUF_SIZE: usize = 208 * 1024; diff --git a/src/libos/src/net/socket/unix/stream/stream.rs b/src/libos/src/net/socket/unix/stream/stream.rs index e3d16a62..4bea1a07 100644 --- a/src/libos/src/net/socket/unix/stream/stream.rs +++ b/src/libos/src/net/socket/unix/stream/stream.rs @@ -1,11 +1,11 @@ use super::address_space::ADDRESS_SPACE; -use super::endpoint::{end_pair, Endpoint, RelayNotifier}; +use super::endpoint::{end_pair, Ancillary, Endpoint, RelayNotifier}; use super::*; use events::{Event, EventFilter, Notifier, Observer}; use fs::channel::Channel; use fs::IoEvents; use fs::{CreationFlags, FileMode}; -use net::socket::{Iovs, MsgHdr, MsgHdrMut}; +use net::socket::{CMessages, CmsgData, Iovs, MsgHdr, MsgHdrMut}; use std::fmt; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; @@ -161,6 +161,9 @@ impl Stream { if let Some(self_addr) = self_addr_opt { end_self.set_addr(self_addr); } + end_self.set_ancillary(Ancillary { + tid: current!().tid(), + }); ADDRESS_SPACE .push_incoming(addr, end_incoming) @@ -190,6 +193,9 @@ impl Stream { Status::Listening(addr) => { let endpoint = ADDRESS_SPACE.pop_incoming(&addr)?; endpoint.set_nonblocking(flags.contains(FileFlags::SOCK_NONBLOCK)); + endpoint.set_ancillary(Ancillary { + tid: current!().tid(), + }); let notifier = Arc::new(RelayNotifier::new()); notifier.observe_endpoint(&endpoint); @@ -228,12 +234,14 @@ impl Stream { if !flags.is_empty() { warn!("unsupported flags: {:?}", flags); } - if msg_hdr.get_control().is_some() { - warn!("sendmsg with msg_control is not supported"); - } let bufs = msg_hdr.get_iovs().as_slices(); - self.writev(bufs) + let mut data_len = self.writev(bufs)?; + + if let Some(msg_control) = msg_hdr.get_control() { + data_len += self.write(msg_control)?; + } + Ok(data_len) } pub fn recvmsg(&self, msg_hdr: &mut MsgHdrMut, flags: RecvFlags) -> Result { @@ -242,11 +250,33 @@ impl Stream { } let bufs = msg_hdr.get_iovs_mut().as_slices_mut(); - let data_len = self.readv(bufs)?; + let mut data_len = self.readv(bufs)?; // For stream socket, the msg_name is ignored. And other fields are not supported. msg_hdr.set_name_len(0); + if let Some(msg_control) = msg_hdr.get_control_mut() { + data_len += self.read(msg_control)?; + + // For each control message that contains file descriptors (SOL_SOCKET and SCM_RIGHTS), + // reassign each fd in the message in receive end. + for cmsg in CMessages::from_bytes(msg_control) { + if let CmsgData::ScmRights(mut scm_rights) = cmsg { + let send_tid = self.peer_ancillary().unwrap().tid(); + scm_rights.iter_and_reassign_fds(|send_fd| { + let ipc_file = process::table::get_thread(send_tid) + .unwrap() + .files() + .lock() + .unwrap() + .get(send_fd) + .unwrap(); + current!().add_file(ipc_file.clone(), false) + }) + } + // Unix credentials need not to be handled here + } + } Ok(data_len) } @@ -281,6 +311,28 @@ impl Stream { pub(super) fn inner(&self) -> SgxMutexGuard<'_, Status> { self.inner.lock().unwrap() } + + fn ancillary(&self) -> Option { + match &*self.inner() { + Status::Idle(_) => None, + Status::Listening(_) => None, + Status::Connected(endpoint) => endpoint.ancillary(), + } + } + + fn peer_ancillary(&self) -> Option { + if let Status::Connected(endpoint) = &*self.inner() { + endpoint.peer_ancillary() + } else { + None + } + } + + fn set_ancillary(&self, ancillary: Ancillary) { + if let Status::Connected(endpoint) = &*self.inner() { + endpoint.set_ancillary(ancillary) + } + } } impl Debug for Stream {