Change the flags used in sendmsg/recvmsg from raw int to memory-safe type
This commit is contained in:
		
							parent
							
								
									eff91daac9
								
							
						
					
					
						commit
						e2edaa49c0
					
				| @ -11,7 +11,7 @@ mod unix_socket; | |||||||
| 
 | 
 | ||||||
| pub use self::iovs::{Iovs, IovsMut, SliceAsLibcIovec}; | pub use self::iovs::{Iovs, IovsMut, SliceAsLibcIovec}; | ||||||
| pub use self::msg::{msghdr, msghdr_mut, MsgHdr, MsgHdrMut}; | pub use self::msg::{msghdr, msghdr_mut, MsgHdr, MsgHdrMut}; | ||||||
| pub use self::msg_flags::MsgFlags; | pub use self::msg_flags::{MsgHdrFlags, RecvFlags, SendFlags}; | ||||||
| pub use self::socket_file::{AsSocket, SocketFile}; | pub use self::socket_file::{AsSocket, SocketFile}; | ||||||
| pub use self::syscalls::*; | pub use self::syscalls::*; | ||||||
| pub use self::unix_socket::{AsUnixSocket, UnixSocketFile}; | pub use self::unix_socket::{AsUnixSocket, UnixSocketFile}; | ||||||
|  | |||||||
| @ -32,7 +32,7 @@ pub struct MsgHdr<'a> { | |||||||
|     name: Option<&'a [u8]>, |     name: Option<&'a [u8]>, | ||||||
|     iovs: Iovs<'a>, |     iovs: Iovs<'a>, | ||||||
|     control: Option<&'a [u8]>, |     control: Option<&'a [u8]>, | ||||||
|     flags: MsgFlags, |     flags: MsgHdrFlags, | ||||||
|     c_self: &'a msghdr, |     c_self: &'a msghdr, | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| @ -51,7 +51,7 @@ impl<'a> MsgHdr<'a> { | |||||||
|             c_msg.msg_controllen as usize, |             c_msg.msg_controllen as usize, | ||||||
|         ); |         ); | ||||||
| 
 | 
 | ||||||
|         let flags = MsgFlags::from_u32(c_msg.msg_flags as u32)?; |         let flags = MsgHdrFlags::from_bits_truncate(c_msg.msg_flags); | ||||||
| 
 | 
 | ||||||
|         let iovs = { |         let iovs = { | ||||||
|             let iovs_vec = match iovs_opt_slice { |             let iovs_vec = match iovs_opt_slice { | ||||||
| @ -85,7 +85,7 @@ impl<'a> MsgHdr<'a> { | |||||||
|         self.control |         self.control | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     pub fn get_flags(&self) -> MsgFlags { |     pub fn get_flags(&self) -> MsgHdrFlags { | ||||||
|         self.flags |         self.flags | ||||||
|     } |     } | ||||||
| } | } | ||||||
| @ -95,7 +95,7 @@ pub struct MsgHdrMut<'a> { | |||||||
|     name: Option<&'a mut [u8]>, |     name: Option<&'a mut [u8]>, | ||||||
|     iovs: IovsMut<'a>, |     iovs: IovsMut<'a>, | ||||||
|     control: Option<&'a mut [u8]>, |     control: Option<&'a mut [u8]>, | ||||||
|     flags: MsgFlags, |     flags: MsgHdrFlags, | ||||||
|     c_self: &'a mut msghdr_mut, |     c_self: &'a mut msghdr_mut, | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| @ -111,7 +111,7 @@ impl<'a> MsgHdrMut<'a> { | |||||||
|         let control_opt_slice = |         let control_opt_slice = | ||||||
|             new_optional_slice_mut(c_msg.msg_control as *mut u8, c_msg.msg_controllen as usize); |             new_optional_slice_mut(c_msg.msg_control as *mut u8, c_msg.msg_controllen as usize); | ||||||
| 
 | 
 | ||||||
|         let flags = MsgFlags::from_u32(c_msg.msg_flags as u32)?; |         let flags = MsgHdrFlags::from_bits_truncate(c_msg.msg_flags); | ||||||
| 
 | 
 | ||||||
|         let iovs = { |         let iovs = { | ||||||
|             let iovs_vec = match iovs_opt_slice { |             let iovs_vec = match iovs_opt_slice { | ||||||
| @ -149,7 +149,7 @@ impl<'a> MsgHdrMut<'a> { | |||||||
|         self.control.as_ref().map(|control| &control[..]) |         self.control.as_ref().map(|control| &control[..]) | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     pub fn get_flags(&self) -> MsgFlags { |     pub fn get_flags(&self) -> MsgHdrFlags { | ||||||
|         self.flags |         self.flags | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| @ -203,9 +203,9 @@ impl<'a> MsgHdrMut<'a> { | |||||||
|         ) |         ) | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     pub fn set_flags(&mut self, flags: MsgFlags) { |     pub fn set_flags(&mut self, flags: MsgHdrFlags) { | ||||||
|         self.flags = flags; |         self.flags = flags; | ||||||
|         self.c_self.msg_flags = flags.to_u32() as i32; |         self.c_self.msg_flags = flags.bits(); | ||||||
|     } |     } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -1,17 +1,36 @@ | |||||||
| use super::*; | use super::*; | ||||||
| 
 | 
 | ||||||
| // TODO: use bitflag! to make this memory safe
 | bitflags! { | ||||||
| #[derive(Debug, Copy, Clone, Default)] |     pub struct SendFlags: i32 { | ||||||
| pub struct MsgFlags { |         const MSG_OOB          = 0x01; | ||||||
|     bits: u32, |         const MSG_DONTROUTE    = 0x04; | ||||||
| } |         const MSG_DONTWAIT     = 0x40;       // Nonblocking io
 | ||||||
| 
 |         const MSG_EOR          = 0x80;       // End of record
 | ||||||
| impl MsgFlags { |         const MSG_CONFIRM      = 0x0800;     // Confirm path validity
 | ||||||
|     pub fn from_u32(c_flags: u32) -> Result<MsgFlags> { |         const MSG_NOSIGNAL     = 0x4000;     // Do not generate SIGPIPE
 | ||||||
|         Ok(MsgFlags { bits: 0 }) |         const MSG_MORE         = 0x8000;     // Sender will send more
 | ||||||
|     } |     } | ||||||
| 
 | } | ||||||
|     pub fn to_u32(&self) -> u32 { | 
 | ||||||
|         self.bits | bitflags! { | ||||||
|  |     pub struct RecvFlags: i32 { | ||||||
|  |         const MSG_OOB          = 0x01; | ||||||
|  |         const MSG_PEEK         = 0x02; | ||||||
|  |         const MSG_TRUNC        = 0x20; | ||||||
|  |         const MSG_DONTWAIT     = 0x40;       // Nonblocking io
 | ||||||
|  |         const MSG_WAITALL      = 0x0100;     // Wait for a full request
 | ||||||
|  |         const MSG_ERRQUEUE     = 0x2000;     // Fetch message from error queue
 | ||||||
|  |         const MSG_CMSG_CLOEXEC = 0x40000000; // Set close_on_exec for file descriptor received through M_RIGHTS
 | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | bitflags! { | ||||||
|  |     pub struct MsgHdrFlags: i32 { | ||||||
|  |         const MSG_OOB          = 0x01; | ||||||
|  |         const MSG_CTRUNC       = 0x08; | ||||||
|  |         const MSG_TRUNC        = 0x20; | ||||||
|  |         const MSG_EOR          = 0x80;       // End of record
 | ||||||
|  |         const MSG_ERRQUEUE     = 0x2000;     // Fetch message from error queue
 | ||||||
|  |         const MSG_NOTIFICATION = 0x8000;     // Only applicable to SCTP socket
 | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  | |||||||
| @ -4,12 +4,12 @@ use crate::untrusted::{SliceAsMutPtrAndLen, SliceAsPtrAndLen, UntrustedSliceAllo | |||||||
| impl SocketFile { | impl SocketFile { | ||||||
|     // TODO: need sockaddr type to implement send/sento
 |     // TODO: need sockaddr type to implement send/sento
 | ||||||
|     /* |     /* | ||||||
|     pub fn recv(&self, buf: &mut [u8], flags: MsgFlags) -> Result<usize> { |     pub fn recv(&self, buf: &mut [u8], flags: MsgHdrFlags) -> Result<usize> { | ||||||
|         let (bytes_recvd, _) = self.recvfrom(buf, flags, None)?; |         let (bytes_recvd, _) = self.recvfrom(buf, flags, None)?; | ||||||
|         Ok(bytes_recvd) |         Ok(bytes_recvd) | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     pub fn recvfrom(&self, buf: &mut [u8], flags: MsgFlags, src_addr: Option<&mut [u8]>) -> Result<(usize, usize)> { |     pub fn recvfrom(&self, buf: &mut [u8], flags: MsgHdrFlags, src_addr: Option<&mut [u8]>) -> Result<(usize, usize)> { | ||||||
|         let (bytes_recvd, src_addr_len, _, _) = self.do_recvmsg( |         let (bytes_recvd, src_addr_len, _, _) = self.do_recvmsg( | ||||||
|             &mut buf[..], |             &mut buf[..], | ||||||
|             flags, |             flags, | ||||||
| @ -19,7 +19,7 @@ impl SocketFile { | |||||||
|         Ok((bytes_recvd, src_addr_len)) |         Ok((bytes_recvd, src_addr_len)) | ||||||
|     }*/ |     }*/ | ||||||
| 
 | 
 | ||||||
|     pub fn recvmsg<'a, 'b>(&self, msg: &'b mut MsgHdrMut<'a>, flags: MsgFlags) -> Result<usize> { |     pub fn recvmsg<'a, 'b>(&self, msg: &'b mut MsgHdrMut<'a>, flags: RecvFlags) -> Result<usize> { | ||||||
|         // Alloc untrusted iovecs to receive data via OCall
 |         // Alloc untrusted iovecs to receive data via OCall
 | ||||||
|         let msg_iov = msg.get_iovs(); |         let msg_iov = msg.get_iovs(); | ||||||
|         let u_slice_alloc = UntrustedSliceAlloc::new(msg_iov.total_bytes())?; |         let u_slice_alloc = UntrustedSliceAlloc::new(msg_iov.total_bytes())?; | ||||||
| @ -62,10 +62,10 @@ impl SocketFile { | |||||||
|     fn do_recvmsg( |     fn do_recvmsg( | ||||||
|         &self, |         &self, | ||||||
|         data: &mut [&mut [u8]], |         data: &mut [&mut [u8]], | ||||||
|         flags: MsgFlags, |         flags: RecvFlags, | ||||||
|         mut name: Option<&mut [u8]>, |         mut name: Option<&mut [u8]>, | ||||||
|         mut control: Option<&mut [u8]>, |         mut control: Option<&mut [u8]>, | ||||||
|     ) -> Result<(usize, usize, usize, MsgFlags)> { |     ) -> Result<(usize, usize, usize, MsgHdrFlags)> { | ||||||
|         // Prepare the arguments for OCall
 |         // Prepare the arguments for OCall
 | ||||||
|         // Host socket fd
 |         // Host socket fd
 | ||||||
|         let host_fd = self.host_fd; |         let host_fd = self.host_fd; | ||||||
| @ -82,7 +82,7 @@ impl SocketFile { | |||||||
|         let msg_control = msg_control as *mut c_void; |         let msg_control = msg_control as *mut c_void; | ||||||
|         let mut msg_controllen_recvd = 0; |         let mut msg_controllen_recvd = 0; | ||||||
|         // Flags
 |         // Flags
 | ||||||
|         let flags = flags.to_u32() as i32; |         let flags = flags.bits(); | ||||||
|         let mut msg_flags_recvd = 0; |         let mut msg_flags_recvd = 0; | ||||||
| 
 | 
 | ||||||
|         // Do OCall
 |         // Do OCall
 | ||||||
| @ -123,7 +123,7 @@ impl SocketFile { | |||||||
|         let msg_namelen_recvd = msg_namelen_recvd as usize; |         let msg_namelen_recvd = msg_namelen_recvd as usize; | ||||||
|         assert!(msg_namelen_recvd <= msg_namelen); |         assert!(msg_namelen_recvd <= msg_namelen); | ||||||
|         assert!(msg_controllen_recvd <= msg_controllen); |         assert!(msg_controllen_recvd <= msg_controllen); | ||||||
|         let flags_recvd = MsgFlags::from_u32(msg_flags_recvd as u32)?; |         let flags_recvd = MsgHdrFlags::from_bits(msg_flags_recvd).unwrap(); | ||||||
| 
 | 
 | ||||||
|         Ok(( |         Ok(( | ||||||
|             bytes_recvd, |             bytes_recvd, | ||||||
|  | |||||||
| @ -18,7 +18,7 @@ impl SocketFile { | |||||||
|     } |     } | ||||||
|     */ |     */ | ||||||
| 
 | 
 | ||||||
|     pub fn sendmsg<'a, 'b>(&self, msg: &'b MsgHdr<'a>, flags: MsgFlags) -> Result<usize> { |     pub fn sendmsg<'a, 'b>(&self, msg: &'b MsgHdr<'a>, flags: SendFlags) -> Result<usize> { | ||||||
|         // Copy message's iovecs into untrusted iovecs
 |         // Copy message's iovecs into untrusted iovecs
 | ||||||
|         let msg_iov = msg.get_iovs(); |         let msg_iov = msg.get_iovs(); | ||||||
|         let u_slice_alloc = UntrustedSliceAlloc::new(msg_iov.total_bytes())?; |         let u_slice_alloc = UntrustedSliceAlloc::new(msg_iov.total_bytes())?; | ||||||
| @ -39,7 +39,7 @@ impl SocketFile { | |||||||
|     fn do_sendmsg( |     fn do_sendmsg( | ||||||
|         &self, |         &self, | ||||||
|         data: &[&[u8]], |         data: &[&[u8]], | ||||||
|         flags: MsgFlags, |         flags: SendFlags, | ||||||
|         name: Option<&[u8]>, |         name: Option<&[u8]>, | ||||||
|         control: Option<&[u8]>, |         control: Option<&[u8]>, | ||||||
|     ) -> Result<usize> { |     ) -> Result<usize> { | ||||||
| @ -57,7 +57,7 @@ impl SocketFile { | |||||||
|         let (msg_control, msg_controllen) = control.as_ptr_and_len(); |         let (msg_control, msg_controllen) = control.as_ptr_and_len(); | ||||||
|         let msg_control = msg_control as *const c_void; |         let msg_control = msg_control as *const c_void; | ||||||
|         // Flags
 |         // Flags
 | ||||||
|         let flags = flags.to_u32() as i32; |         let flags = flags.bits(); | ||||||
| 
 | 
 | ||||||
|         let bytes_sent = try_libc!({ |         let bytes_sent = try_libc!({ | ||||||
|             // Do OCall
 |             // Do OCall
 | ||||||
|  | |||||||
| @ -23,7 +23,7 @@ pub fn do_sendmsg(fd: c_int, msg_ptr: *const msghdr, flags_c: c_int) -> Result<i | |||||||
|         }; |         }; | ||||||
|         let msg = unsafe { MsgHdr::from_c(&msg_c)? }; |         let msg = unsafe { MsgHdr::from_c(&msg_c)? }; | ||||||
| 
 | 
 | ||||||
|         let flags = MsgFlags::from_u32(flags_c as u32)?; |         let flags = SendFlags::from_bits_truncate(flags_c); | ||||||
| 
 | 
 | ||||||
|         socket |         socket | ||||||
|             .sendmsg(&msg, flags) |             .sendmsg(&msg, flags) | ||||||
| @ -53,7 +53,7 @@ pub fn do_recvmsg(fd: c_int, msg_mut_ptr: *mut msghdr_mut, flags_c: c_int) -> Re | |||||||
|         }; |         }; | ||||||
|         let mut msg_mut = unsafe { MsgHdrMut::from_c(msg_mut_c)? }; |         let mut msg_mut = unsafe { MsgHdrMut::from_c(msg_mut_c)? }; | ||||||
| 
 | 
 | ||||||
|         let flags = MsgFlags::from_u32(flags_c as u32)?; |         let flags = RecvFlags::from_bits_truncate(flags_c); | ||||||
| 
 | 
 | ||||||
|         socket |         socket | ||||||
|             .recvmsg(&mut msg_mut, flags) |             .recvmsg(&mut msg_mut, flags) | ||||||
|  | |||||||
		Loading…
	
		Reference in New Issue
	
	Block a user