diff --git a/src/libos/src/net/mod.rs b/src/libos/src/net/mod.rs index 9441ec1e..490f2e91 100644 --- a/src/libos/src/net/mod.rs +++ b/src/libos/src/net/mod.rs @@ -11,7 +11,7 @@ mod unix_socket; pub use self::iovs::{Iovs, IovsMut, SliceAsLibcIovec}; pub use self::msg::{msghdr, msghdr_mut, MsgHdr, MsgHdrMut}; -pub use self::msg_flags::MsgFlags; +pub use self::msg_flags::{MsgHdrFlags, RecvFlags, SendFlags}; pub use self::socket_file::{AsSocket, SocketFile}; pub use self::syscalls::*; pub use self::unix_socket::{AsUnixSocket, UnixSocketFile}; diff --git a/src/libos/src/net/msg.rs b/src/libos/src/net/msg.rs index 9007f69a..92433d55 100644 --- a/src/libos/src/net/msg.rs +++ b/src/libos/src/net/msg.rs @@ -32,7 +32,7 @@ pub struct MsgHdr<'a> { name: Option<&'a [u8]>, iovs: Iovs<'a>, control: Option<&'a [u8]>, - flags: MsgFlags, + flags: MsgHdrFlags, c_self: &'a msghdr, } @@ -51,7 +51,7 @@ impl<'a> MsgHdr<'a> { c_msg.msg_controllen as usize, ); - let flags = MsgFlags::from_u32(c_msg.msg_flags as u32)?; + let flags = MsgHdrFlags::from_bits_truncate(c_msg.msg_flags); let iovs = { let iovs_vec = match iovs_opt_slice { @@ -85,7 +85,7 @@ impl<'a> MsgHdr<'a> { self.control } - pub fn get_flags(&self) -> MsgFlags { + pub fn get_flags(&self) -> MsgHdrFlags { self.flags } } @@ -95,7 +95,7 @@ pub struct MsgHdrMut<'a> { name: Option<&'a mut [u8]>, iovs: IovsMut<'a>, control: Option<&'a mut [u8]>, - flags: MsgFlags, + flags: MsgHdrFlags, c_self: &'a mut msghdr_mut, } @@ -111,7 +111,7 @@ impl<'a> MsgHdrMut<'a> { let control_opt_slice = new_optional_slice_mut(c_msg.msg_control as *mut u8, c_msg.msg_controllen as usize); - let flags = MsgFlags::from_u32(c_msg.msg_flags as u32)?; + let flags = MsgHdrFlags::from_bits_truncate(c_msg.msg_flags); let iovs = { let iovs_vec = match iovs_opt_slice { @@ -149,7 +149,7 @@ impl<'a> MsgHdrMut<'a> { self.control.as_ref().map(|control| &control[..]) } - pub fn get_flags(&self) -> MsgFlags { + pub fn get_flags(&self) -> MsgHdrFlags { self.flags } @@ -203,9 +203,9 @@ impl<'a> MsgHdrMut<'a> { ) } - pub fn set_flags(&mut self, flags: MsgFlags) { + pub fn set_flags(&mut self, flags: MsgHdrFlags) { self.flags = flags; - self.c_self.msg_flags = flags.to_u32() as i32; + self.c_self.msg_flags = flags.bits(); } } diff --git a/src/libos/src/net/msg_flags.rs b/src/libos/src/net/msg_flags.rs index 096e00ec..de696916 100644 --- a/src/libos/src/net/msg_flags.rs +++ b/src/libos/src/net/msg_flags.rs @@ -1,17 +1,36 @@ use super::*; -// TODO: use bitflag! to make this memory safe -#[derive(Debug, Copy, Clone, Default)] -pub struct MsgFlags { - bits: u32, -} - -impl MsgFlags { - pub fn from_u32(c_flags: u32) -> Result { - Ok(MsgFlags { bits: 0 }) - } - - pub fn to_u32(&self) -> u32 { - self.bits +bitflags! { + pub struct SendFlags: i32 { + const MSG_OOB = 0x01; + const MSG_DONTROUTE = 0x04; + const MSG_DONTWAIT = 0x40; // Nonblocking io + const MSG_EOR = 0x80; // End of record + const MSG_CONFIRM = 0x0800; // Confirm path validity + const MSG_NOSIGNAL = 0x4000; // Do not generate SIGPIPE + const MSG_MORE = 0x8000; // Sender will send more + } +} + +bitflags! { + pub struct RecvFlags: i32 { + const MSG_OOB = 0x01; + const MSG_PEEK = 0x02; + const MSG_TRUNC = 0x20; + const MSG_DONTWAIT = 0x40; // Nonblocking io + const MSG_WAITALL = 0x0100; // Wait for a full request + const MSG_ERRQUEUE = 0x2000; // Fetch message from error queue + const MSG_CMSG_CLOEXEC = 0x40000000; // Set close_on_exec for file descriptor received through M_RIGHTS + } +} + +bitflags! { + pub struct MsgHdrFlags: i32 { + const MSG_OOB = 0x01; + const MSG_CTRUNC = 0x08; + const MSG_TRUNC = 0x20; + const MSG_EOR = 0x80; // End of record + const MSG_ERRQUEUE = 0x2000; // Fetch message from error queue + const MSG_NOTIFICATION = 0x8000; // Only applicable to SCTP socket } } diff --git a/src/libos/src/net/socket_file/recv.rs b/src/libos/src/net/socket_file/recv.rs index 7f688ab1..58ffa869 100644 --- a/src/libos/src/net/socket_file/recv.rs +++ b/src/libos/src/net/socket_file/recv.rs @@ -4,12 +4,12 @@ use crate::untrusted::{SliceAsMutPtrAndLen, SliceAsPtrAndLen, UntrustedSliceAllo impl SocketFile { // TODO: need sockaddr type to implement send/sento /* - pub fn recv(&self, buf: &mut [u8], flags: MsgFlags) -> Result { + pub fn recv(&self, buf: &mut [u8], flags: MsgHdrFlags) -> Result { let (bytes_recvd, _) = self.recvfrom(buf, flags, None)?; Ok(bytes_recvd) } - pub fn recvfrom(&self, buf: &mut [u8], flags: MsgFlags, src_addr: Option<&mut [u8]>) -> Result<(usize, usize)> { + pub fn recvfrom(&self, buf: &mut [u8], flags: MsgHdrFlags, src_addr: Option<&mut [u8]>) -> Result<(usize, usize)> { let (bytes_recvd, src_addr_len, _, _) = self.do_recvmsg( &mut buf[..], flags, @@ -19,7 +19,7 @@ impl SocketFile { Ok((bytes_recvd, src_addr_len)) }*/ - pub fn recvmsg<'a, 'b>(&self, msg: &'b mut MsgHdrMut<'a>, flags: MsgFlags) -> Result { + pub fn recvmsg<'a, 'b>(&self, msg: &'b mut MsgHdrMut<'a>, flags: RecvFlags) -> Result { // Alloc untrusted iovecs to receive data via OCall let msg_iov = msg.get_iovs(); let u_slice_alloc = UntrustedSliceAlloc::new(msg_iov.total_bytes())?; @@ -62,10 +62,10 @@ impl SocketFile { fn do_recvmsg( &self, data: &mut [&mut [u8]], - flags: MsgFlags, + flags: RecvFlags, mut name: Option<&mut [u8]>, mut control: Option<&mut [u8]>, - ) -> Result<(usize, usize, usize, MsgFlags)> { + ) -> Result<(usize, usize, usize, MsgHdrFlags)> { // Prepare the arguments for OCall // Host socket fd let host_fd = self.host_fd; @@ -82,7 +82,7 @@ impl SocketFile { let msg_control = msg_control as *mut c_void; let mut msg_controllen_recvd = 0; // Flags - let flags = flags.to_u32() as i32; + let flags = flags.bits(); let mut msg_flags_recvd = 0; // Do OCall @@ -123,7 +123,7 @@ impl SocketFile { let msg_namelen_recvd = msg_namelen_recvd as usize; assert!(msg_namelen_recvd <= msg_namelen); assert!(msg_controllen_recvd <= msg_controllen); - let flags_recvd = MsgFlags::from_u32(msg_flags_recvd as u32)?; + let flags_recvd = MsgHdrFlags::from_bits(msg_flags_recvd).unwrap(); Ok(( bytes_recvd, diff --git a/src/libos/src/net/socket_file/send.rs b/src/libos/src/net/socket_file/send.rs index 635f8d46..b2581f1b 100644 --- a/src/libos/src/net/socket_file/send.rs +++ b/src/libos/src/net/socket_file/send.rs @@ -18,7 +18,7 @@ impl SocketFile { } */ - pub fn sendmsg<'a, 'b>(&self, msg: &'b MsgHdr<'a>, flags: MsgFlags) -> Result { + pub fn sendmsg<'a, 'b>(&self, msg: &'b MsgHdr<'a>, flags: SendFlags) -> Result { // Copy message's iovecs into untrusted iovecs let msg_iov = msg.get_iovs(); let u_slice_alloc = UntrustedSliceAlloc::new(msg_iov.total_bytes())?; @@ -39,7 +39,7 @@ impl SocketFile { fn do_sendmsg( &self, data: &[&[u8]], - flags: MsgFlags, + flags: SendFlags, name: Option<&[u8]>, control: Option<&[u8]>, ) -> Result { @@ -57,7 +57,7 @@ impl SocketFile { let (msg_control, msg_controllen) = control.as_ptr_and_len(); let msg_control = msg_control as *const c_void; // Flags - let flags = flags.to_u32() as i32; + let flags = flags.bits(); let bytes_sent = try_libc!({ // Do OCall diff --git a/src/libos/src/net/syscalls.rs b/src/libos/src/net/syscalls.rs index 9ae77ae5..39c5bed4 100644 --- a/src/libos/src/net/syscalls.rs +++ b/src/libos/src/net/syscalls.rs @@ -23,7 +23,7 @@ pub fn do_sendmsg(fd: c_int, msg_ptr: *const msghdr, flags_c: c_int) -> Result Re }; let mut msg_mut = unsafe { MsgHdrMut::from_c(msg_mut_c)? }; - let flags = MsgFlags::from_u32(flags_c as u32)?; + let flags = RecvFlags::from_bits_truncate(flags_c); socket .recvmsg(&mut msg_mut, flags)