Change the flags used in sendmsg/recvmsg from raw int to memory-safe type

This commit is contained in:
He Sun 2020-03-11 22:33:48 +08:00 committed by tate.thl
parent eff91daac9
commit e2edaa49c0
6 changed files with 53 additions and 34 deletions

@ -11,7 +11,7 @@ mod unix_socket;
pub use self::iovs::{Iovs, IovsMut, SliceAsLibcIovec};
pub use self::msg::{msghdr, msghdr_mut, MsgHdr, MsgHdrMut};
pub use self::msg_flags::MsgFlags;
pub use self::msg_flags::{MsgHdrFlags, RecvFlags, SendFlags};
pub use self::socket_file::{AsSocket, SocketFile};
pub use self::syscalls::*;
pub use self::unix_socket::{AsUnixSocket, UnixSocketFile};

@ -32,7 +32,7 @@ pub struct MsgHdr<'a> {
name: Option<&'a [u8]>,
iovs: Iovs<'a>,
control: Option<&'a [u8]>,
flags: MsgFlags,
flags: MsgHdrFlags,
c_self: &'a msghdr,
}
@ -51,7 +51,7 @@ impl<'a> MsgHdr<'a> {
c_msg.msg_controllen as usize,
);
let flags = MsgFlags::from_u32(c_msg.msg_flags as u32)?;
let flags = MsgHdrFlags::from_bits_truncate(c_msg.msg_flags);
let iovs = {
let iovs_vec = match iovs_opt_slice {
@ -85,7 +85,7 @@ impl<'a> MsgHdr<'a> {
self.control
}
pub fn get_flags(&self) -> MsgFlags {
pub fn get_flags(&self) -> MsgHdrFlags {
self.flags
}
}
@ -95,7 +95,7 @@ pub struct MsgHdrMut<'a> {
name: Option<&'a mut [u8]>,
iovs: IovsMut<'a>,
control: Option<&'a mut [u8]>,
flags: MsgFlags,
flags: MsgHdrFlags,
c_self: &'a mut msghdr_mut,
}
@ -111,7 +111,7 @@ impl<'a> MsgHdrMut<'a> {
let control_opt_slice =
new_optional_slice_mut(c_msg.msg_control as *mut u8, c_msg.msg_controllen as usize);
let flags = MsgFlags::from_u32(c_msg.msg_flags as u32)?;
let flags = MsgHdrFlags::from_bits_truncate(c_msg.msg_flags);
let iovs = {
let iovs_vec = match iovs_opt_slice {
@ -149,7 +149,7 @@ impl<'a> MsgHdrMut<'a> {
self.control.as_ref().map(|control| &control[..])
}
pub fn get_flags(&self) -> MsgFlags {
pub fn get_flags(&self) -> MsgHdrFlags {
self.flags
}
@ -203,9 +203,9 @@ impl<'a> MsgHdrMut<'a> {
)
}
pub fn set_flags(&mut self, flags: MsgFlags) {
pub fn set_flags(&mut self, flags: MsgHdrFlags) {
self.flags = flags;
self.c_self.msg_flags = flags.to_u32() as i32;
self.c_self.msg_flags = flags.bits();
}
}

@ -1,17 +1,36 @@
use super::*;
// TODO: use bitflag! to make this memory safe
#[derive(Debug, Copy, Clone, Default)]
pub struct MsgFlags {
bits: u32,
bitflags! {
pub struct SendFlags: i32 {
const MSG_OOB = 0x01;
const MSG_DONTROUTE = 0x04;
const MSG_DONTWAIT = 0x40; // Nonblocking io
const MSG_EOR = 0x80; // End of record
const MSG_CONFIRM = 0x0800; // Confirm path validity
const MSG_NOSIGNAL = 0x4000; // Do not generate SIGPIPE
const MSG_MORE = 0x8000; // Sender will send more
}
}
impl MsgFlags {
pub fn from_u32(c_flags: u32) -> Result<MsgFlags> {
Ok(MsgFlags { bits: 0 })
bitflags! {
pub struct RecvFlags: i32 {
const MSG_OOB = 0x01;
const MSG_PEEK = 0x02;
const MSG_TRUNC = 0x20;
const MSG_DONTWAIT = 0x40; // Nonblocking io
const MSG_WAITALL = 0x0100; // Wait for a full request
const MSG_ERRQUEUE = 0x2000; // Fetch message from error queue
const MSG_CMSG_CLOEXEC = 0x40000000; // Set close_on_exec for file descriptor received through M_RIGHTS
}
}
pub fn to_u32(&self) -> u32 {
self.bits
bitflags! {
pub struct MsgHdrFlags: i32 {
const MSG_OOB = 0x01;
const MSG_CTRUNC = 0x08;
const MSG_TRUNC = 0x20;
const MSG_EOR = 0x80; // End of record
const MSG_ERRQUEUE = 0x2000; // Fetch message from error queue
const MSG_NOTIFICATION = 0x8000; // Only applicable to SCTP socket
}
}

@ -4,12 +4,12 @@ use crate::untrusted::{SliceAsMutPtrAndLen, SliceAsPtrAndLen, UntrustedSliceAllo
impl SocketFile {
// TODO: need sockaddr type to implement send/sento
/*
pub fn recv(&self, buf: &mut [u8], flags: MsgFlags) -> Result<usize> {
pub fn recv(&self, buf: &mut [u8], flags: MsgHdrFlags) -> Result<usize> {
let (bytes_recvd, _) = self.recvfrom(buf, flags, None)?;
Ok(bytes_recvd)
}
pub fn recvfrom(&self, buf: &mut [u8], flags: MsgFlags, src_addr: Option<&mut [u8]>) -> Result<(usize, usize)> {
pub fn recvfrom(&self, buf: &mut [u8], flags: MsgHdrFlags, src_addr: Option<&mut [u8]>) -> Result<(usize, usize)> {
let (bytes_recvd, src_addr_len, _, _) = self.do_recvmsg(
&mut buf[..],
flags,
@ -19,7 +19,7 @@ impl SocketFile {
Ok((bytes_recvd, src_addr_len))
}*/
pub fn recvmsg<'a, 'b>(&self, msg: &'b mut MsgHdrMut<'a>, flags: MsgFlags) -> Result<usize> {
pub fn recvmsg<'a, 'b>(&self, msg: &'b mut MsgHdrMut<'a>, flags: RecvFlags) -> Result<usize> {
// Alloc untrusted iovecs to receive data via OCall
let msg_iov = msg.get_iovs();
let u_slice_alloc = UntrustedSliceAlloc::new(msg_iov.total_bytes())?;
@ -62,10 +62,10 @@ impl SocketFile {
fn do_recvmsg(
&self,
data: &mut [&mut [u8]],
flags: MsgFlags,
flags: RecvFlags,
mut name: Option<&mut [u8]>,
mut control: Option<&mut [u8]>,
) -> Result<(usize, usize, usize, MsgFlags)> {
) -> Result<(usize, usize, usize, MsgHdrFlags)> {
// Prepare the arguments for OCall
// Host socket fd
let host_fd = self.host_fd;
@ -82,7 +82,7 @@ impl SocketFile {
let msg_control = msg_control as *mut c_void;
let mut msg_controllen_recvd = 0;
// Flags
let flags = flags.to_u32() as i32;
let flags = flags.bits();
let mut msg_flags_recvd = 0;
// Do OCall
@ -123,7 +123,7 @@ impl SocketFile {
let msg_namelen_recvd = msg_namelen_recvd as usize;
assert!(msg_namelen_recvd <= msg_namelen);
assert!(msg_controllen_recvd <= msg_controllen);
let flags_recvd = MsgFlags::from_u32(msg_flags_recvd as u32)?;
let flags_recvd = MsgHdrFlags::from_bits(msg_flags_recvd).unwrap();
Ok((
bytes_recvd,

@ -18,7 +18,7 @@ impl SocketFile {
}
*/
pub fn sendmsg<'a, 'b>(&self, msg: &'b MsgHdr<'a>, flags: MsgFlags) -> Result<usize> {
pub fn sendmsg<'a, 'b>(&self, msg: &'b MsgHdr<'a>, flags: SendFlags) -> Result<usize> {
// Copy message's iovecs into untrusted iovecs
let msg_iov = msg.get_iovs();
let u_slice_alloc = UntrustedSliceAlloc::new(msg_iov.total_bytes())?;
@ -39,7 +39,7 @@ impl SocketFile {
fn do_sendmsg(
&self,
data: &[&[u8]],
flags: MsgFlags,
flags: SendFlags,
name: Option<&[u8]>,
control: Option<&[u8]>,
) -> Result<usize> {
@ -57,7 +57,7 @@ impl SocketFile {
let (msg_control, msg_controllen) = control.as_ptr_and_len();
let msg_control = msg_control as *const c_void;
// Flags
let flags = flags.to_u32() as i32;
let flags = flags.bits();
let bytes_sent = try_libc!({
// Do OCall

@ -23,7 +23,7 @@ pub fn do_sendmsg(fd: c_int, msg_ptr: *const msghdr, flags_c: c_int) -> Result<i
};
let msg = unsafe { MsgHdr::from_c(&msg_c)? };
let flags = MsgFlags::from_u32(flags_c as u32)?;
let flags = SendFlags::from_bits_truncate(flags_c);
socket
.sendmsg(&msg, flags)
@ -53,7 +53,7 @@ pub fn do_recvmsg(fd: c_int, msg_mut_ptr: *mut msghdr_mut, flags_c: c_int) -> Re
};
let mut msg_mut = unsafe { MsgHdrMut::from_c(msg_mut_c)? };
let flags = MsgFlags::from_u32(flags_c as u32)?;
let flags = RecvFlags::from_bits_truncate(flags_c);
socket
.recvmsg(&mut msg_mut, flags)