[net] Support send/receive control message in unix socket
This commit is contained in:
		
							parent
							
								
									b0de80bd50
								
							
						
					
					
						commit
						56add87c76
					
				| @ -14,7 +14,7 @@ pub use self::address_family::AddressFamily; | ||||
| pub use self::flags::{FileFlags, MsgHdrFlags, RecvFlags, SendFlags}; | ||||
| pub use self::host::{HostSocket, HostSocketType}; | ||||
| pub use self::iovs::{Iovs, IovsMut, SliceAsLibcIovec}; | ||||
| pub use self::msg::{mmsghdr, msghdr, msghdr_mut, MsgHdr, MsgHdrMut}; | ||||
| pub use self::msg::{mmsghdr, msghdr, msghdr_mut, CMessages, CmsgData, MsgHdr, MsgHdrMut}; | ||||
| pub use self::shutdown::HowToShut; | ||||
| pub use self::socket_address::SockAddr; | ||||
| pub use self::socket_type::SocketType; | ||||
|  | ||||
| @ -219,6 +219,119 @@ impl<'a> MsgHdrMut<'a> { | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| /// This struct is used to iterate through the control messages.
 | ||||
| ///
 | ||||
| /// `cmsghdr` is a C struct for ancillary data object information of a unix socket.
 | ||||
| pub struct CMessages<'a> { | ||||
|     buffer: &'a [u8], | ||||
|     current: Option<&'a libc::cmsghdr>, | ||||
| } | ||||
| 
 | ||||
| impl<'a> Iterator for CMessages<'a> { | ||||
|     type Item = CmsgData<'a>; | ||||
| 
 | ||||
|     fn next(&mut self) -> Option<Self::Item> { | ||||
|         let cmsg = unsafe { | ||||
|             let mut msg: libc::msghdr = core::mem::zeroed(); | ||||
|             msg.msg_control = self.buffer.as_ptr() as *mut _; | ||||
|             msg.msg_controllen = self.buffer.len() as _; | ||||
| 
 | ||||
|             let cmsg = if let Some(current) = self.current { | ||||
|                 libc::CMSG_NXTHDR(&msg, current) | ||||
|             } else { | ||||
|                 libc::CMSG_FIRSTHDR(&msg) | ||||
|             }; | ||||
|             cmsg.as_ref()? | ||||
|         }; | ||||
| 
 | ||||
|         self.current = Some(cmsg); | ||||
|         CmsgData::try_from_cmsghdr(cmsg) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl<'a> CMessages<'a> { | ||||
|     pub fn from_bytes(msg_control: &'a mut [u8]) -> Self { | ||||
|         Self { | ||||
|             buffer: msg_control, | ||||
|             current: None, | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| /// Control message data of variable type. The data resides next to `cmsghdr`.
 | ||||
| pub enum CmsgData<'a> { | ||||
|     ScmRights(ScmRights<'a>), | ||||
|     ScmCredentials, | ||||
| } | ||||
| 
 | ||||
| impl<'a> CmsgData<'a> { | ||||
|     /// Create an `CmsgData::ScmRights` variant.
 | ||||
|     ///
 | ||||
|     /// # Safety
 | ||||
|     ///
 | ||||
|     /// `data` must contain a valid control message and the control message must be type of
 | ||||
|     /// `SOL_SOCKET` and level of `SCM_RIGHTS`.
 | ||||
|     unsafe fn as_rights(data: &'a mut [u8]) -> Self { | ||||
|         let scm_rights = ScmRights { data }; | ||||
|         CmsgData::ScmRights(scm_rights) | ||||
|     } | ||||
| 
 | ||||
|     /// Create an `CmsgData::ScmCredentials` variant.
 | ||||
|     ///
 | ||||
|     /// # Safety
 | ||||
|     ///
 | ||||
|     /// `data` must contain a valid control message and the control message must be type of
 | ||||
|     /// `SOL_SOCKET` and level of `SCM_CREDENTIALS`.
 | ||||
|     unsafe fn as_credentials(_data: &'a [u8]) -> Self { | ||||
|         CmsgData::ScmCredentials | ||||
|     } | ||||
| 
 | ||||
|     fn try_from_cmsghdr(cmsg: &'a libc::cmsghdr) -> Option<Self> { | ||||
|         unsafe { | ||||
|             let cmsg_len_zero = libc::CMSG_LEN(0) as usize; | ||||
|             let data_len = (*cmsg).cmsg_len as usize - cmsg_len_zero; | ||||
|             let data = libc::CMSG_DATA(cmsg); | ||||
|             let data = core::slice::from_raw_parts_mut(data, data_len); | ||||
| 
 | ||||
|             match (*cmsg).cmsg_level { | ||||
|                 libc::SOL_SOCKET => match (*cmsg).cmsg_type { | ||||
|                     libc::SCM_RIGHTS => Some(CmsgData::as_rights(data)), | ||||
|                     libc::SCM_CREDENTIALS => Some(CmsgData::as_credentials(data)), | ||||
|                     _ => None, | ||||
|                 }, | ||||
|                 _ => None, | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| /// The data unit of this control message is file descriptor(s).
 | ||||
| ///
 | ||||
| /// The level is equal to `SOL_SOCKET` and the type is equal to `SCM_RIGHTS`.
 | ||||
| pub struct ScmRights<'a> { | ||||
|     data: &'a mut [u8], | ||||
| } | ||||
| 
 | ||||
| impl<'a> ScmRights<'a> { | ||||
|     /// Iterate and reassign each fd in data buffer, given a reassignment function.
 | ||||
|     pub fn iter_and_reassign_fds<F>(&mut self, reassign_fd_fn: F) | ||||
|     where | ||||
|         F: Fn(FileDesc) -> FileDesc, | ||||
|     { | ||||
|         for fd_bytes in self.data.chunks_exact_mut(core::mem::size_of::<FileDesc>()) { | ||||
|             let old_fd = FileDesc::from_ne_bytes(fd_bytes.try_into().unwrap()); | ||||
|             let reassigned_fd = reassign_fd_fn(old_fd); | ||||
|             fd_bytes.copy_from_slice(&reassigned_fd.to_ne_bytes()); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     pub fn iter_fds(&self) -> impl Iterator<Item = FileDesc> + '_ { | ||||
|         self.data | ||||
|             .chunks_exact(core::mem::size_of::<FileDesc>()) | ||||
|             .map(|fd_bytes| FileDesc::from_ne_bytes(fd_bytes.try_into().unwrap())) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| unsafe fn new_optional_slice<'a, T>(slice_ptr: *const T, slice_size: usize) -> Option<&'a [T]> { | ||||
|     if !slice_ptr.is_null() { | ||||
|         let slice = core::slice::from_raw_parts::<T>(slice_ptr, slice_size); | ||||
|  | ||||
| @ -17,12 +17,14 @@ pub fn end_pair(nonblocking: bool) -> Result<(Endpoint, Endpoint)> { | ||||
|         reader: con_a, | ||||
|         writer: pro_b, | ||||
|         peer: Weak::default(), | ||||
|         ancillary: RwLock::new(None), | ||||
|     }); | ||||
|     let end_b = Arc::new(Inner { | ||||
|         addr: RwLock::new(None), | ||||
|         reader: con_b, | ||||
|         writer: pro_a, | ||||
|         peer: Arc::downgrade(&end_a), | ||||
|         ancillary: RwLock::new(None), | ||||
|     }); | ||||
| 
 | ||||
|     unsafe { | ||||
| @ -41,6 +43,7 @@ pub struct Inner { | ||||
|     reader: Consumer<u8>, | ||||
|     writer: Producer<u8>, | ||||
|     peer: Weak<Self>, | ||||
|     ancillary: RwLock<Option<Ancillary>>, | ||||
| } | ||||
| 
 | ||||
| impl Inner { | ||||
| @ -119,6 +122,18 @@ impl Inner { | ||||
|         events | ||||
|     } | ||||
| 
 | ||||
|     pub fn ancillary(&self) -> Option<Ancillary> { | ||||
|         self.ancillary.read().unwrap().clone() | ||||
|     } | ||||
| 
 | ||||
|     pub fn set_ancillary(&self, ancillary: Ancillary) { | ||||
|         self.ancillary.write().unwrap().insert(ancillary); | ||||
|     } | ||||
| 
 | ||||
|     pub fn peer_ancillary(&self) -> Option<Ancillary> { | ||||
|         self.peer.upgrade().map(|end| end.ancillary()).flatten() | ||||
|     } | ||||
| 
 | ||||
|     pub(self) fn register_relay_notifier(&self, observer: &Arc<RelayNotifier>) { | ||||
|         self.reader.notifier().register( | ||||
|             Arc::downgrade(observer) as Weak<dyn Observer<_>>, | ||||
| @ -138,6 +153,18 @@ impl Inner { | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| /// Ancillary data of connected unix socket's sent/received control message.
 | ||||
| #[derive(Clone, Debug)] | ||||
| pub struct Ancillary { | ||||
|     pub(super) tid: pid_t, // currently store tid to locate file table
 | ||||
| } | ||||
| 
 | ||||
| impl Ancillary { | ||||
|     pub fn tid(&self) -> pid_t { | ||||
|         self.tid | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| // TODO: Add SO_SNDBUF and SO_RCVBUF to set/getsockopt to dynamcally change the size.
 | ||||
| // This value is got from /proc/sys/net/core/rmem_max and wmem_max that are same on linux.
 | ||||
| pub const DEFAULT_BUF_SIZE: usize = 208 * 1024; | ||||
|  | ||||
| @ -1,11 +1,11 @@ | ||||
| use super::address_space::ADDRESS_SPACE; | ||||
| use super::endpoint::{end_pair, Endpoint, RelayNotifier}; | ||||
| use super::endpoint::{end_pair, Ancillary, Endpoint, RelayNotifier}; | ||||
| use super::*; | ||||
| use events::{Event, EventFilter, Notifier, Observer}; | ||||
| use fs::channel::Channel; | ||||
| use fs::IoEvents; | ||||
| use fs::{CreationFlags, FileMode}; | ||||
| use net::socket::{Iovs, MsgHdr, MsgHdrMut}; | ||||
| use net::socket::{CMessages, CmsgData, Iovs, MsgHdr, MsgHdrMut}; | ||||
| use std::fmt; | ||||
| use std::sync::atomic::{AtomicBool, Ordering}; | ||||
| use std::sync::Arc; | ||||
| @ -161,6 +161,9 @@ impl Stream { | ||||
|                 if let Some(self_addr) = self_addr_opt { | ||||
|                     end_self.set_addr(self_addr); | ||||
|                 } | ||||
|                 end_self.set_ancillary(Ancillary { | ||||
|                     tid: current!().tid(), | ||||
|                 }); | ||||
| 
 | ||||
|                 ADDRESS_SPACE | ||||
|                     .push_incoming(addr, end_incoming) | ||||
| @ -190,6 +193,9 @@ impl Stream { | ||||
|             Status::Listening(addr) => { | ||||
|                 let endpoint = ADDRESS_SPACE.pop_incoming(&addr)?; | ||||
|                 endpoint.set_nonblocking(flags.contains(FileFlags::SOCK_NONBLOCK)); | ||||
|                 endpoint.set_ancillary(Ancillary { | ||||
|                     tid: current!().tid(), | ||||
|                 }); | ||||
|                 let notifier = Arc::new(RelayNotifier::new()); | ||||
|                 notifier.observe_endpoint(&endpoint); | ||||
| 
 | ||||
| @ -228,12 +234,14 @@ impl Stream { | ||||
|         if !flags.is_empty() { | ||||
|             warn!("unsupported flags: {:?}", flags); | ||||
|         } | ||||
|         if msg_hdr.get_control().is_some() { | ||||
|             warn!("sendmsg with msg_control is not supported"); | ||||
|         } | ||||
| 
 | ||||
|         let bufs = msg_hdr.get_iovs().as_slices(); | ||||
|         self.writev(bufs) | ||||
|         let mut data_len = self.writev(bufs)?; | ||||
| 
 | ||||
|         if let Some(msg_control) = msg_hdr.get_control() { | ||||
|             data_len += self.write(msg_control)?; | ||||
|         } | ||||
|         Ok(data_len) | ||||
|     } | ||||
| 
 | ||||
|     pub fn recvmsg(&self, msg_hdr: &mut MsgHdrMut, flags: RecvFlags) -> Result<usize> { | ||||
| @ -242,11 +250,33 @@ impl Stream { | ||||
|         } | ||||
| 
 | ||||
|         let bufs = msg_hdr.get_iovs_mut().as_slices_mut(); | ||||
|         let data_len = self.readv(bufs)?; | ||||
|         let mut data_len = self.readv(bufs)?; | ||||
| 
 | ||||
|         // For stream socket, the msg_name is ignored. And other fields are not supported.
 | ||||
|         msg_hdr.set_name_len(0); | ||||
| 
 | ||||
|         if let Some(msg_control) = msg_hdr.get_control_mut() { | ||||
|             data_len += self.read(msg_control)?; | ||||
| 
 | ||||
|             // For each control message that contains file descriptors (SOL_SOCKET and SCM_RIGHTS),
 | ||||
|             // reassign each fd in the message in receive end.
 | ||||
|             for cmsg in CMessages::from_bytes(msg_control) { | ||||
|                 if let CmsgData::ScmRights(mut scm_rights) = cmsg { | ||||
|                     let send_tid = self.peer_ancillary().unwrap().tid(); | ||||
|                     scm_rights.iter_and_reassign_fds(|send_fd| { | ||||
|                         let ipc_file = process::table::get_thread(send_tid) | ||||
|                             .unwrap() | ||||
|                             .files() | ||||
|                             .lock() | ||||
|                             .unwrap() | ||||
|                             .get(send_fd) | ||||
|                             .unwrap(); | ||||
|                         current!().add_file(ipc_file.clone(), false) | ||||
|                     }) | ||||
|                 } | ||||
|                 // Unix credentials need not to be handled here
 | ||||
|             } | ||||
|         } | ||||
|         Ok(data_len) | ||||
|     } | ||||
| 
 | ||||
| @ -281,6 +311,28 @@ impl Stream { | ||||
|     pub(super) fn inner(&self) -> SgxMutexGuard<'_, Status> { | ||||
|         self.inner.lock().unwrap() | ||||
|     } | ||||
| 
 | ||||
|     fn ancillary(&self) -> Option<Ancillary> { | ||||
|         match &*self.inner() { | ||||
|             Status::Idle(_) => None, | ||||
|             Status::Listening(_) => None, | ||||
|             Status::Connected(endpoint) => endpoint.ancillary(), | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     fn peer_ancillary(&self) -> Option<Ancillary> { | ||||
|         if let Status::Connected(endpoint) = &*self.inner() { | ||||
|             endpoint.peer_ancillary() | ||||
|         } else { | ||||
|             None | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     fn set_ancillary(&self, ancillary: Ancillary) { | ||||
|         if let Status::Connected(endpoint) = &*self.inner() { | ||||
|             endpoint.set_ancillary(ancillary) | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl Debug for Stream { | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user