[net] Support send/receive control message in unix socket

This commit is contained in:
Shaowei Song 2023-06-25 17:14:15 +08:00 committed by volcano
parent b0de80bd50
commit 56add87c76
4 changed files with 200 additions and 8 deletions

@ -14,7 +14,7 @@ pub use self::address_family::AddressFamily;
pub use self::flags::{FileFlags, MsgHdrFlags, RecvFlags, SendFlags};
pub use self::host::{HostSocket, HostSocketType};
pub use self::iovs::{Iovs, IovsMut, SliceAsLibcIovec};
pub use self::msg::{mmsghdr, msghdr, msghdr_mut, MsgHdr, MsgHdrMut};
pub use self::msg::{mmsghdr, msghdr, msghdr_mut, CMessages, CmsgData, MsgHdr, MsgHdrMut};
pub use self::shutdown::HowToShut;
pub use self::socket_address::SockAddr;
pub use self::socket_type::SocketType;

@ -219,6 +219,119 @@ impl<'a> MsgHdrMut<'a> {
}
}
/// This struct is used to iterate through the control messages.
///
/// `cmsghdr` is a C struct for ancillary data object information of a unix socket.
pub struct CMessages<'a> {
buffer: &'a [u8],
current: Option<&'a libc::cmsghdr>,
}
impl<'a> Iterator for CMessages<'a> {
type Item = CmsgData<'a>;
fn next(&mut self) -> Option<Self::Item> {
let cmsg = unsafe {
let mut msg: libc::msghdr = core::mem::zeroed();
msg.msg_control = self.buffer.as_ptr() as *mut _;
msg.msg_controllen = self.buffer.len() as _;
let cmsg = if let Some(current) = self.current {
libc::CMSG_NXTHDR(&msg, current)
} else {
libc::CMSG_FIRSTHDR(&msg)
};
cmsg.as_ref()?
};
self.current = Some(cmsg);
CmsgData::try_from_cmsghdr(cmsg)
}
}
impl<'a> CMessages<'a> {
pub fn from_bytes(msg_control: &'a mut [u8]) -> Self {
Self {
buffer: msg_control,
current: None,
}
}
}
/// Control message data of variable type. The data resides next to `cmsghdr`.
pub enum CmsgData<'a> {
ScmRights(ScmRights<'a>),
ScmCredentials,
}
impl<'a> CmsgData<'a> {
/// Create an `CmsgData::ScmRights` variant.
///
/// # Safety
///
/// `data` must contain a valid control message and the control message must be type of
/// `SOL_SOCKET` and level of `SCM_RIGHTS`.
unsafe fn as_rights(data: &'a mut [u8]) -> Self {
let scm_rights = ScmRights { data };
CmsgData::ScmRights(scm_rights)
}
/// Create an `CmsgData::ScmCredentials` variant.
///
/// # Safety
///
/// `data` must contain a valid control message and the control message must be type of
/// `SOL_SOCKET` and level of `SCM_CREDENTIALS`.
unsafe fn as_credentials(_data: &'a [u8]) -> Self {
CmsgData::ScmCredentials
}
fn try_from_cmsghdr(cmsg: &'a libc::cmsghdr) -> Option<Self> {
unsafe {
let cmsg_len_zero = libc::CMSG_LEN(0) as usize;
let data_len = (*cmsg).cmsg_len as usize - cmsg_len_zero;
let data = libc::CMSG_DATA(cmsg);
let data = core::slice::from_raw_parts_mut(data, data_len);
match (*cmsg).cmsg_level {
libc::SOL_SOCKET => match (*cmsg).cmsg_type {
libc::SCM_RIGHTS => Some(CmsgData::as_rights(data)),
libc::SCM_CREDENTIALS => Some(CmsgData::as_credentials(data)),
_ => None,
},
_ => None,
}
}
}
}
/// The data unit of this control message is file descriptor(s).
///
/// The level is equal to `SOL_SOCKET` and the type is equal to `SCM_RIGHTS`.
pub struct ScmRights<'a> {
data: &'a mut [u8],
}
impl<'a> ScmRights<'a> {
/// Iterate and reassign each fd in data buffer, given a reassignment function.
pub fn iter_and_reassign_fds<F>(&mut self, reassign_fd_fn: F)
where
F: Fn(FileDesc) -> FileDesc,
{
for fd_bytes in self.data.chunks_exact_mut(core::mem::size_of::<FileDesc>()) {
let old_fd = FileDesc::from_ne_bytes(fd_bytes.try_into().unwrap());
let reassigned_fd = reassign_fd_fn(old_fd);
fd_bytes.copy_from_slice(&reassigned_fd.to_ne_bytes());
}
}
pub fn iter_fds(&self) -> impl Iterator<Item = FileDesc> + '_ {
self.data
.chunks_exact(core::mem::size_of::<FileDesc>())
.map(|fd_bytes| FileDesc::from_ne_bytes(fd_bytes.try_into().unwrap()))
}
}
unsafe fn new_optional_slice<'a, T>(slice_ptr: *const T, slice_size: usize) -> Option<&'a [T]> {
if !slice_ptr.is_null() {
let slice = core::slice::from_raw_parts::<T>(slice_ptr, slice_size);

@ -17,12 +17,14 @@ pub fn end_pair(nonblocking: bool) -> Result<(Endpoint, Endpoint)> {
reader: con_a,
writer: pro_b,
peer: Weak::default(),
ancillary: RwLock::new(None),
});
let end_b = Arc::new(Inner {
addr: RwLock::new(None),
reader: con_b,
writer: pro_a,
peer: Arc::downgrade(&end_a),
ancillary: RwLock::new(None),
});
unsafe {
@ -41,6 +43,7 @@ pub struct Inner {
reader: Consumer<u8>,
writer: Producer<u8>,
peer: Weak<Self>,
ancillary: RwLock<Option<Ancillary>>,
}
impl Inner {
@ -119,6 +122,18 @@ impl Inner {
events
}
pub fn ancillary(&self) -> Option<Ancillary> {
self.ancillary.read().unwrap().clone()
}
pub fn set_ancillary(&self, ancillary: Ancillary) {
self.ancillary.write().unwrap().insert(ancillary);
}
pub fn peer_ancillary(&self) -> Option<Ancillary> {
self.peer.upgrade().map(|end| end.ancillary()).flatten()
}
pub(self) fn register_relay_notifier(&self, observer: &Arc<RelayNotifier>) {
self.reader.notifier().register(
Arc::downgrade(observer) as Weak<dyn Observer<_>>,
@ -138,6 +153,18 @@ impl Inner {
}
}
/// Ancillary data of connected unix socket's sent/received control message.
#[derive(Clone, Debug)]
pub struct Ancillary {
pub(super) tid: pid_t, // currently store tid to locate file table
}
impl Ancillary {
pub fn tid(&self) -> pid_t {
self.tid
}
}
// TODO: Add SO_SNDBUF and SO_RCVBUF to set/getsockopt to dynamcally change the size.
// This value is got from /proc/sys/net/core/rmem_max and wmem_max that are same on linux.
pub const DEFAULT_BUF_SIZE: usize = 208 * 1024;

@ -1,11 +1,11 @@
use super::address_space::ADDRESS_SPACE;
use super::endpoint::{end_pair, Endpoint, RelayNotifier};
use super::endpoint::{end_pair, Ancillary, Endpoint, RelayNotifier};
use super::*;
use events::{Event, EventFilter, Notifier, Observer};
use fs::channel::Channel;
use fs::IoEvents;
use fs::{CreationFlags, FileMode};
use net::socket::{Iovs, MsgHdr, MsgHdrMut};
use net::socket::{CMessages, CmsgData, Iovs, MsgHdr, MsgHdrMut};
use std::fmt;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
@ -161,6 +161,9 @@ impl Stream {
if let Some(self_addr) = self_addr_opt {
end_self.set_addr(self_addr);
}
end_self.set_ancillary(Ancillary {
tid: current!().tid(),
});
ADDRESS_SPACE
.push_incoming(addr, end_incoming)
@ -190,6 +193,9 @@ impl Stream {
Status::Listening(addr) => {
let endpoint = ADDRESS_SPACE.pop_incoming(&addr)?;
endpoint.set_nonblocking(flags.contains(FileFlags::SOCK_NONBLOCK));
endpoint.set_ancillary(Ancillary {
tid: current!().tid(),
});
let notifier = Arc::new(RelayNotifier::new());
notifier.observe_endpoint(&endpoint);
@ -228,12 +234,14 @@ impl Stream {
if !flags.is_empty() {
warn!("unsupported flags: {:?}", flags);
}
if msg_hdr.get_control().is_some() {
warn!("sendmsg with msg_control is not supported");
}
let bufs = msg_hdr.get_iovs().as_slices();
self.writev(bufs)
let mut data_len = self.writev(bufs)?;
if let Some(msg_control) = msg_hdr.get_control() {
data_len += self.write(msg_control)?;
}
Ok(data_len)
}
pub fn recvmsg(&self, msg_hdr: &mut MsgHdrMut, flags: RecvFlags) -> Result<usize> {
@ -242,11 +250,33 @@ impl Stream {
}
let bufs = msg_hdr.get_iovs_mut().as_slices_mut();
let data_len = self.readv(bufs)?;
let mut data_len = self.readv(bufs)?;
// For stream socket, the msg_name is ignored. And other fields are not supported.
msg_hdr.set_name_len(0);
if let Some(msg_control) = msg_hdr.get_control_mut() {
data_len += self.read(msg_control)?;
// For each control message that contains file descriptors (SOL_SOCKET and SCM_RIGHTS),
// reassign each fd in the message in receive end.
for cmsg in CMessages::from_bytes(msg_control) {
if let CmsgData::ScmRights(mut scm_rights) = cmsg {
let send_tid = self.peer_ancillary().unwrap().tid();
scm_rights.iter_and_reassign_fds(|send_fd| {
let ipc_file = process::table::get_thread(send_tid)
.unwrap()
.files()
.lock()
.unwrap()
.get(send_fd)
.unwrap();
current!().add_file(ipc_file.clone(), false)
})
}
// Unix credentials need not to be handled here
}
}
Ok(data_len)
}
@ -281,6 +311,28 @@ impl Stream {
pub(super) fn inner(&self) -> SgxMutexGuard<'_, Status> {
self.inner.lock().unwrap()
}
fn ancillary(&self) -> Option<Ancillary> {
match &*self.inner() {
Status::Idle(_) => None,
Status::Listening(_) => None,
Status::Connected(endpoint) => endpoint.ancillary(),
}
}
fn peer_ancillary(&self) -> Option<Ancillary> {
if let Status::Connected(endpoint) = &*self.inner() {
endpoint.peer_ancillary()
} else {
None
}
}
fn set_ancillary(&self, ancillary: Ancillary) {
if let Status::Connected(endpoint) = &*self.inner() {
endpoint.set_ancillary(ancillary)
}
}
}
impl Debug for Stream {