diff --git a/src/libos/src/net/syscalls.rs b/src/libos/src/net/syscalls.rs index dad9f2cc..c8c72ee7 100644 --- a/src/libos/src/net/syscalls.rs +++ b/src/libos/src/net/syscalls.rs @@ -1,10 +1,15 @@ -use super::*; +use super::socket::{MsgFlags, SocketProtocol}; +use super::{socket::uring::socket_file::SocketFile, *}; +use atomic::Ordering; +use core::f32::consts::E; +use num_enum::TryFromPrimitive; use std::mem::MaybeUninit; +use std::ptr; use std::time::Duration; use super::io_multiplexing::{AsEpollFile, EpollCtl, EpollFile, EpollFlags, FdSetExt, PollFd}; -use fs::{CreationFlags, File, FileDesc, FileRef}; +use fs::{CreationFlags, File, FileDesc, FileRef, IoctlCmd}; use misc::resource_t; use process::Process; use signal::{sigset_t, MaskOp, SigSet, SIGKILL, SIGSTOP}; @@ -12,42 +17,103 @@ use std::convert::TryFrom; use time::{timespec_t, timeval_t}; use util::mem_util::from_user; -pub fn do_socket(domain: c_int, socket_type: c_int, protocol: c_int) -> Result { - let sock_domain = AddressFamily::try_from(domain as u16)?; - let file_flags = FileFlags::from_bits_truncate(socket_type); - let sock_type = SocketType::try_from(socket_type & (!file_flags.bits()))?; +use super::socket::sockopt::{ + GetAcceptConnCmd, GetDomainCmd, GetErrorCmd, GetOutputAsBytes, GetPeerNameCmd, + GetRcvBufSizeCmd, GetRecvTimeoutCmd, GetSendTimeoutCmd, GetSndBufSizeCmd, GetSockOptRawCmd, + GetTypeCmd, SetRcvBufSizeCmd, SetRecvTimeoutCmd, SetSendTimeoutCmd, SetSndBufSizeCmd, + SetSockOptRawCmd, SockOptName, +}; +use super::socket::uring::UringSocketType; +use super::socket::AnyAddr; - let file_ref: Arc = match sock_domain { - AddressFamily::LOCAL => { - let unix_socket = unix_socket(sock_type, file_flags, protocol)?; - Arc::new(unix_socket) - } - _ => { - let socket = HostSocket::new(sock_domain, sock_type, file_flags, protocol)?; - Arc::new(socket) +use super::*; + +use crate::fs::StatusFlags; +use crate::io_uring::ENABLE_URING; +use crate::prelude::*; + +// 4096 is default max socket connection 
value in Ubuntu 20.04 +const SOMAXCONN: u32 = 4096; +const SOCONN_DEFAULT: u32 = 16; + +#[repr(C)] +#[derive(Copy, Clone)] +pub struct mmsghdr { + pub msg_hdr: libc::msghdr, + pub msg_len: c_uint, +} + +pub fn do_socket(domain: c_int, socket_type: c_int, protocol: c_int) -> Result { + let domain = Domain::try_from(domain as u16)?; + let flags = SocketFlags::from_bits_truncate(socket_type); + + let type_bits = socket_type & !flags.bits(); + let socket_type = + Type::try_from(type_bits).map_err(|_| errno!(EINVAL, "invalid socket type"))?; + + debug!( + "socket domain: {:?}, type: {:?}, protocol: {:?}", + domain, socket_type, protocol + ); + + let mut file_ref: Option> = None; + + if ENABLE_URING.load(Ordering::Relaxed) { + let protocol = SocketProtocol::try_from(protocol) + .map_err(|_| errno!(EINVAL, "Invalid or unsupported network protocol"))?; + + // Determine if type and domain match uring supported socket + let match_uring = move || { + let is_uring_type = (socket_type == Type::DGRAM || socket_type == Type::STREAM); + let is_uring_protocol = (protocol == SocketProtocol::IPPROTO_IP + || protocol == SocketProtocol::IPPROTO_TCP + || protocol == SocketProtocol::IPPROTO_UDP); + let is_uring_domain = (domain == Domain::INET || domain == Domain::INET6); + + is_uring_type && is_uring_protocol && is_uring_domain + }; + + if match_uring() { + let nonblocking = flags.contains(SocketFlags::SOCK_NONBLOCK); + let socket_file = SocketFile::new(domain, protocol, socket_type, nonblocking)?; + file_ref = Some(Arc::new(socket_file)); } }; - let close_on_spawn = file_flags.contains(FileFlags::SOCK_CLOEXEC); - let fd = current!().add_file(file_ref, close_on_spawn); + // Dispatch unsupported uring domain and flags to ocall + if file_ref.is_none() { + match domain { + Domain::LOCAL => { + let unix_socket = unix_socket(socket_type, flags, protocol)?; + file_ref = Some(Arc::new(unix_socket)); + } + _ => { + let socket = HostSocket::new(domain, socket_type, flags, protocol)?; + file_ref 
= Some(Arc::new(socket)); + } + } + }; + + let close_on_spawn = flags.contains(SocketFlags::SOCK_CLOEXEC); + let fd = current!().add_file(file_ref.unwrap(), close_on_spawn); Ok(fd as isize) } pub fn do_bind(fd: c_int, addr: *const libc::sockaddr, addr_len: libc::socklen_t) -> Result { - if addr.is_null() || addr_len == 0 { - return_errno!(EINVAL, "no address is specified"); - } - from_user::check_array(addr as *const u8, addr_len as usize)?; + let addr_len = addr_len as usize; + let sockaddr_storage = copy_sock_addr_from_user(addr, addr_len)?; + let mut addr = AnyAddr::from_c_storage(&sockaddr_storage, addr_len)?; + trace!("bind to addr: {:?}", addr); let file_ref = current!().file(fd as FileDesc)?; if let Ok(socket) = file_ref.as_host_socket() { - let sock_addr = unsafe { SockAddr::try_from_raw(addr, addr_len)? }; - trace!("bind to addr: {:?}", sock_addr); - socket.bind(&sock_addr)?; + let mut raw_addr = addr.to_raw(); + socket.bind(&mut raw_addr)?; } else if let Ok(unix_socket) = file_ref.as_unix_socket() { - let mut unix_addr = unsafe { UnixAddr::try_from_raw(addr, addr_len)? 
}; - trace!("bind to addr: {:?}", unix_addr); + let mut unix_addr = (addr.to_unix()?).clone(); unix_socket.bind(&mut unix_addr)?; + } else if let Ok(uring_socket) = file_ref.as_uring_socket() { + uring_socket.bind(&mut addr)?; } else { return_errno!(ENOTSOCK, "not a socket"); } @@ -61,6 +127,15 @@ pub fn do_listen(fd: c_int, backlog: c_int) -> Result { socket.listen(backlog)?; } else if let Ok(unix_socket) = file_ref.as_unix_socket() { unix_socket.listen(backlog)?; + } else if let Ok(uring_socket) = file_ref.as_uring_socket() { + let backlog: u32 = if backlog as u32 > SOMAXCONN { + SOMAXCONN + } else if backlog == 0 { + SOCONN_DEFAULT + } else { + backlog as u32 + }; + uring_socket.listen(backlog)?; } else { return_errno!(ENOTSOCK, "not a socket"); } @@ -84,21 +159,24 @@ pub fn do_connect( let file_ref = current!().file(fd as FileDesc)?; if let Ok(socket) = file_ref.as_host_socket() { let addr_option = if addr_set { - Some(unsafe { SockAddr::try_from_raw(addr, addr_len)? }) + Some(unsafe { RawAddr::try_from_raw(addr, addr_len as u32)? }) } else { None }; socket.connect(&addr_option)?; - } else if let Ok(unix_socket) = file_ref.as_unix_socket() { - // TODO: support AF_UNSPEC address for datagram socket use - let addr = if addr_set { - unsafe { UnixAddr::try_from_raw(addr, addr_len)? 
} - } else { - return_errno!(EINVAL, "invalid address"); - }; + return Ok(0); + }; - unix_socket.connect(&addr)?; + let addr_len = addr_len as usize; + let sockaddr_storage = copy_sock_addr_from_user(addr, addr_len)?; + let addr = AnyAddr::from_c_storage(&sockaddr_storage, addr_len)?; + + if let Ok(unix_socket) = file_ref.as_unix_socket() { + // TODO: support AF_UNSPEC address for datagram socket use + unix_socket.connect(addr.to_unix()?)?; + } else if let Ok(uring_socket) = file_ref.as_uring_socket() { + uring_socket.connect(&addr)?; } else { return_errno!(ENOTSOCK, "not a socket"); } @@ -120,70 +198,62 @@ pub fn do_accept4( addr_len: *mut libc::socklen_t, flags: c_int, ) -> Result { - let addr_set: bool = !addr.is_null(); - if addr_set { - from_user::check_ptr(addr_len)?; - from_user::check_mut_array(addr as *mut u8, unsafe { *addr_len } as usize)?; - } - - let file_flags = FileFlags::from_bits(flags).ok_or_else(|| errno!(EINVAL, "invalid flags"))?; - let close_on_spawn = file_flags.contains(FileFlags::SOCK_CLOEXEC); + let addr_and_addr_len = get_slice_from_sock_addr_ptr_mut(addr, addr_len)?; + let sock_flags = + SocketFlags::from_bits(flags).ok_or_else(|| errno!(EINVAL, "invalid flags"))?; + let close_on_spawn = sock_flags.contains(SocketFlags::SOCK_CLOEXEC); let file_ref = current!().file(fd as FileDesc)?; - if let Ok(socket) = file_ref.as_host_socket() { - let (new_socket_file, sock_addr_option) = socket.accept(file_flags)?; - let new_file_ref: Arc = Arc::new(new_socket_file); - let new_fd = current!().add_file(new_file_ref, close_on_spawn); - if addr_set { - if let Some(sock_addr) = sock_addr_option { - let mut buf = - unsafe { std::slice::from_raw_parts_mut(addr as *mut u8, *addr_len as usize) }; - sock_addr.copy_to_slice(&mut buf); - unsafe { - *addr_len = sock_addr.len() as u32; - } - } else { - unsafe { - *addr_len = 0; - } - } - } - Ok(new_fd as isize) - } else if let Ok(unix_socket) = file_ref.as_unix_socket() { - let (new_socket_file, sock_addr_option) 
= unix_socket.accept(file_flags)?; - let new_file_ref: Arc = Arc::new(new_socket_file); - let new_fd = current!().add_file(new_file_ref, close_on_spawn); + // Accept the socket + let (new_file_ref, sock_addr_option): (Arc, Option) = + if let Ok(socket) = file_ref.as_host_socket() { + let (new_socket_file, sock_addr_option) = socket.accept(sock_flags)?; + ( + Arc::new(new_socket_file), + sock_addr_option.map(|raw_addr| AnyAddr::Raw(raw_addr)), + ) + } else if let Ok(unix_socket) = file_ref.as_unix_socket() { + let (new_socket_file, sock_addr_option) = unix_socket.accept(sock_flags)?; + ( + Arc::new(new_socket_file), + sock_addr_option.map(|unix_addr| AnyAddr::Unix(unix_addr)), + ) + } else if let Ok(uring_socket) = file_ref.as_uring_socket() { + let nonblocking = sock_flags.contains(SocketFlags::SOCK_NONBLOCK); + let accepted_socket = uring_socket.accept(nonblocking)?; + let sock_addr = accepted_socket.peer_addr()?; + (Arc::new(accepted_socket), Some(sock_addr)) + } else { + return_errno!(ENOTSOCK, "not a socket"); + }; - if addr_set { - if let Some(sock_addr) = sock_addr_option { - let mut buf = - unsafe { std::slice::from_raw_parts_mut(addr as *mut u8, *addr_len as usize) }; - sock_addr.copy_to_slice(&mut buf); - unsafe { - *addr_len = sock_addr.raw_len() as u32; - } - } else { - unsafe { - *addr_len = 0; - } - } + let new_fd = current!().add_file(new_file_ref, close_on_spawn); + + // Output the address + if let Some((addr_mut, addr_len_mut)) = addr_and_addr_len { + if let Some(sock_addr) = sock_addr_option { + let (src_addr, src_addr_len) = sock_addr.to_c_storage(); + copy_sock_addr_to_user(src_addr, src_addr_len, addr_mut, addr_len_mut); + } else { + *addr_len_mut = 0; } - Ok(new_fd as isize) - } else { - return_errno!(ENOTSOCK, "not a socket"); } + + Ok(new_fd as isize) } pub fn do_shutdown(fd: c_int, how: c_int) -> Result { debug!("shutdown: fd: {}, how: {}", fd, how); - let how = HowToShut::try_from_raw(how)?; + let how = Shutdown::from_c(how as _)?; let 
file_ref = current!().file(fd as FileDesc)?; if let Ok(socket) = file_ref.as_host_socket() { socket.shutdown(how)?; } else if let Ok(unix_socket) = file_ref.as_unix_socket() { unix_socket.shutdown(how)?; + } else if let Ok(uring_socket) = file_ref.as_uring_socket() { + uring_socket.shutdown(how)?; } else { return_errno!(EBADF, "not a host socket") } @@ -203,21 +273,25 @@ pub fn do_setsockopt( fd, level, optname, optval, optlen ); let file_ref = current!().file(fd as FileDesc)?; - if let Ok(socket) = file_ref.as_host_socket() { - let ret = try_libc!(libc::ocall::setsockopt( - socket.raw_host_fd() as i32, - level, - optname, - optval, - optlen - )); - Ok(ret as isize) + + if optval as usize != 0 && optlen == 0 && ENABLE_URING.load(Ordering::Relaxed) { + return_errno!(EINVAL, "the optlen size is 0"); + } + + let optval = from_user::make_slice(optval as *const u8, optlen as usize)?; + + if let Ok(host_socket) = file_ref.as_host_socket() { + let mut cmd = new_host_setsockopt_cmd(level, optname, optval)?; + host_socket.ioctl(cmd.as_mut())?; } else if let Ok(unix_socket) = file_ref.as_unix_socket() { warn!("setsockopt for unix socket is unimplemented"); - Ok(0) + } else if let Ok(uring_socket) = file_ref.as_uring_socket() { + let mut cmd = new_uring_setsockopt_cmd(level, optname, optval, uring_socket.get_type())?; + uring_socket.ioctl(cmd.as_mut())?; } else { return_errno!(ENOTSOCK, "not a socket") } + Ok(0) } pub fn do_getsockopt( @@ -231,24 +305,35 @@ pub fn do_getsockopt( "getsockopt: fd: {}, level: {}, optname: {}, optval: {:?}, optlen: {:?}", fd, level, optname, optval, optlen ); - let file_ref = current!().file(fd as FileDesc)?; - let socket = file_ref.as_host_socket(); + let optlen_mut = from_user::make_mut_ref(optlen)?; + let optlen = *optlen_mut; + let optval_mut = from_user::make_mut_slice(optval as *mut u8, optlen as usize)?; - if let Ok(socket) = socket { - let ret = try_libc!(libc::ocall::getsockopt( - socket.raw_host_fd() as i32, - level, - optname, - 
optval, - optlen - )); - Ok(ret as isize) + // Man getsockopt: + // If the size of the option value is greater than option_len, the value stored in the object pointed to by the option_value argument will be silently truncated. + // Thus if the optlen is 0, nothing is returned to optval. We can just return here. + if optlen == 0 { + return Ok(0); + } + + let file_ref = current!().file(fd as FileDesc)?; + + if let Ok(host_socket) = file_ref.as_host_socket() { + let mut cmd = new_host_getsockopt_cmd(level, optname, optlen)?; + host_socket.ioctl(cmd.as_mut())?; + let src_optval = get_optval(cmd.as_ref())?; + copy_bytes_to_user(src_optval, optval_mut, optlen_mut); } else if let Ok(unix_socket) = file_ref.as_unix_socket() { warn!("getsockopt for unix socket is unimplemented"); - Ok(0) + } else if let Ok(uring_socket) = file_ref.as_uring_socket() { + let mut cmd = new_uring_getsockopt_cmd(level, optname, optlen, uring_socket.get_type())?; + uring_socket.ioctl(cmd.as_mut())?; + let src_optval = get_optval(cmd.as_ref())?; + copy_bytes_to_user(src_optval, optval_mut, optlen_mut); } else { return_errno!(ENOTSOCK, "not a socket") } + Ok(0) } pub fn do_getpeername( @@ -256,35 +341,27 @@ pub fn do_getpeername( addr: *mut libc::sockaddr, addr_len: *mut libc::socklen_t, ) -> Result { - let addr_set: bool = !addr.is_null(); - if addr_set { - from_user::check_ptr(addr_len)?; - from_user::check_mut_array(addr as *mut u8, unsafe { *addr_len } as usize)?; - } else { + let addr_and_addr_len = get_slice_from_sock_addr_ptr_mut(addr, addr_len)?; + if addr_and_addr_len.is_none() { return Ok(0); } let file_ref = current!().file(fd as FileDesc)?; - if let Ok(socket) = file_ref.as_host_socket() { - let ret = try_libc!(libc::ocall::getpeername( - socket.raw_host_fd() as i32, - addr, - addr_len - )); - Ok(ret as isize) + let (src_addr, src_addr_len) = if let Ok(host_socket) = file_ref.as_host_socket() { + host_socket.peer_addr()?.to_c_storage() } else if let Ok(unix_socket) = 
file_ref.as_unix_socket() { - let name = unix_socket.peer_addr()?; - let mut dst = unsafe { - std::slice::from_raw_parts_mut(addr as *mut _ as *mut u8, *addr_len as usize) - }; - name.copy_to_slice(dst); - unsafe { - *addr_len = name.raw_len() as u32; - } - Ok(0) + unix_socket.peer_addr()?.to_c_storage() + } else if let Ok(uring_socket) = file_ref.as_uring_socket() { + uring_socket.peer_addr()?.to_c_storage() } else { return_errno!(ENOTSOCK, "not a socket") + }; + + if let Some((addr_mut, addr_len_mut)) = addr_and_addr_len { + copy_sock_addr_to_user(src_addr, src_addr_len, addr_mut, addr_len_mut); } + + Ok(0) } pub fn do_getsockname( @@ -292,46 +369,27 @@ pub fn do_getsockname( addr: *mut libc::sockaddr, addr_len: *mut libc::socklen_t, ) -> Result { - let addr_set: bool = !addr.is_null(); - if addr_set { - from_user::check_ptr(addr_len)?; - from_user::check_mut_array(addr as *mut u8, unsafe { *addr_len } as usize)?; - } else { + let addr_and_addr_len = get_slice_from_sock_addr_ptr_mut(addr, addr_len)?; + if addr_and_addr_len.is_none() { return Ok(0); } - if unsafe { *addr_len } < std::mem::size_of::() as u32 { - return_errno!(EINVAL, "input length is too short"); - } - let file_ref = current!().file(fd as FileDesc)?; - if let Ok(socket) = file_ref.as_host_socket() { - let ret = try_libc!(libc::ocall::getsockname( - socket.raw_host_fd() as i32, - addr, - addr_len - )); - Ok(ret as isize) + let (src_addr, src_addr_len) = if let Ok(host_socket) = file_ref.as_host_socket() { + host_socket.addr()?.to_c_storage() } else if let Ok(unix_socket) = file_ref.as_unix_socket() { - let name_opt = unix_socket.addr(); - if let Some(name) = name_opt { - let mut dst = unsafe { - std::slice::from_raw_parts_mut(addr as *mut _ as *mut u8, *addr_len as usize) - }; - name.copy_to_slice(dst); - unsafe { - *addr_len = name.raw_len() as u32; - } - } else { - unsafe { - (*addr).sa_family = AddressFamily::LOCAL as u16; - *addr_len = 2; - } - } - Ok(0) + unix_socket.addr().to_c_storage() + } 
else if let Ok(uring_socket) = file_ref.as_uring_socket() { + uring_socket.addr()?.to_c_storage() } else { return_errno!(ENOTSOCK, "not a socket"); + }; + + if let Some((addr_mut, addr_len_mut)) = addr_and_addr_len { + copy_sock_addr_to_user(src_addr, src_addr_len, addr_mut, addr_len_mut); } + + Ok(0) } pub fn do_sendto( @@ -342,45 +400,43 @@ pub fn do_sendto( addr: *const libc::sockaddr, addr_len: libc::socklen_t, ) -> Result { - if len == 0 { - return Ok(0); - } - if addr.is_null() ^ (addr_len == 0) { - return_errno!(EINVAL, "addr and ddr_len should be both null"); + return_errno!(EINVAL, "addr and addr_len should be both null and 0 or not"); } + let addr = { + if addr.is_null() { + None + } else { + let addr_storage = copy_sock_addr_from_user(addr, addr_len as _)?; + Some(AnyAddr::from_c_storage(&addr_storage, addr_len as _)?) + } + }; from_user::check_array(base as *const u8, len)?; let buf = unsafe { std::slice::from_raw_parts(base as *const u8, len as usize) }; - let addr_set: bool = !addr.is_null(); - if addr_set { - from_user::check_mut_array(addr as *mut u8, addr_len as usize)?; - } - - let send_flags = SendFlags::from_bits(flags).unwrap(); + let send_flags = SendFlags::from_bits_truncate(flags); let file_ref = current!().file(fd as FileDesc)?; - if let Ok(socket) = file_ref.as_host_socket() { - let addr_option = if addr_set { - Some(unsafe { SockAddr::try_from_raw(addr, addr_len)? }) - } else { - None - }; + if let Ok(host_socket) = file_ref.as_host_socket() { + let addr = addr.map(|any_addr| any_addr.to_raw()); - socket - .sendto(buf, send_flags, &addr_option) + host_socket + .sendto(buf, send_flags, &addr) .map(|u| u as isize) } else if let Ok(unix_socket) = file_ref.as_unix_socket() { - let addr_option = if addr_set { - Some(unsafe { UnixAddr::try_from_raw(addr, addr_len)? 
}) - } else { - None + let addr = match addr { + Some(any_addr) => Some(any_addr.to_unix()?.clone()), + None => None, }; unix_socket - .sendto(buf, send_flags, &addr_option) + .sendto(buf, send_flags, &addr) .map(|u| u as isize) + } else if let Ok(uring_socket) = file_ref.as_uring_socket() { + uring_socket + .sendto(&buf, addr, send_flags) + .map(|bytes_send| bytes_send as isize) } else { return_errno!(EBADF, "unsupported file type"); } @@ -394,62 +450,41 @@ pub fn do_recvfrom( addr: *mut libc::sockaddr, addr_len: *mut libc::socklen_t, ) -> Result { - if addr.is_null() ^ addr_len.is_null() { - return_errno!(EINVAL, "addr and ddr_len should be both null"); - } + let addr_and_addr_len = get_slice_from_sock_addr_ptr_mut(addr, addr_len)?; from_user::check_array(base as *mut u8, len)?; let mut buf = unsafe { std::slice::from_raw_parts_mut(base as *mut u8, len as usize) }; // MSG_CTRUNC is a return flag but linux allows it to be set on input flags. // We just ignore it. - let recv_flags = RecvFlags::from_bits(flags & !(MsgHdrFlags::MSG_CTRUNC.bits())) + let recv_flags = RecvFlags::from_bits(flags & !(MsgFlags::MSG_CTRUNC.bits())) .ok_or_else(|| errno!(EINVAL, "invalid flags"))?; - let addr_set: bool = !addr.is_null(); - if addr_set { - from_user::check_ptr(addr_len)?; - from_user::check_mut_array(addr as *mut u8, unsafe { *addr_len } as usize)?; - } - let file_ref = current!().file(fd as FileDesc)?; - if let Ok(socket) = file_ref.as_host_socket() { - let (data_len, sock_addr_option) = socket.recvfrom(buf, recv_flags)?; - if addr_set { - if let Some(sock_addr) = sock_addr_option { - let mut buf = - unsafe { std::slice::from_raw_parts_mut(addr as *mut u8, *addr_len as usize) }; - sock_addr.copy_to_slice(&mut buf); - unsafe { - *addr_len = sock_addr.len() as u32; - } - } else { - unsafe { - *addr_len = 0; - } - } - } - Ok(data_len as isize) + let (data_len, addr_recv) = if let Ok(socket) = file_ref.as_host_socket() { + socket + .recvfrom(buf, recv_flags) + .map(|(len, 
addr_recv)| (len, addr_recv.map(|raw_addr| AnyAddr::Raw(raw_addr))))? } else if let Ok(unix_socket) = file_ref.as_unix_socket() { - let (data_len, sock_addr_option) = unix_socket.recvfrom(buf, recv_flags)?; - if addr_set { - if let Some(sock_addr) = sock_addr_option { - let mut buf = - unsafe { std::slice::from_raw_parts_mut(addr as *mut u8, *addr_len as usize) }; - sock_addr.copy_to_slice(&mut buf); - unsafe { - *addr_len = sock_addr.raw_len() as u32; - } - } else { - unsafe { - *addr_len = 0; - } - } - } - Ok(data_len as isize) + unix_socket + .recvfrom(buf, recv_flags) + .map(|(len, addr_recv)| (len, addr_recv.map(|unix_addr| AnyAddr::Unix(unix_addr))))? + } else if let Ok(uring_socket) = file_ref.as_uring_socket() { + uring_socket.recvfrom(&mut buf, recv_flags)? } else { return_errno!(ENOTSOCK, "not a socket"); + }; + + if let Some((addr_mut, addr_len_mut)) = addr_and_addr_len { + if let Some(addr_recv) = addr_recv { + let (c_addr_storage, c_addr_len) = addr_recv.to_c_storage(); + copy_sock_addr_to_user(c_addr_storage, c_addr_len, addr_mut, addr_len_mut); + } else { + // If addr_recv is not filled, set addr_len to 0 + *addr_len_mut = 0; + } } + Ok(data_len as isize) } pub fn do_socketpair( @@ -463,16 +498,17 @@ pub fn do_socketpair( std::slice::from_raw_parts_mut(sv as *mut u32, 2) }; - let file_flags = FileFlags::from_bits_truncate(socket_type); - let close_on_spawn = file_flags.contains(FileFlags::SOCK_CLOEXEC); - let sock_type = SocketType::try_from(socket_type & (!file_flags.bits()))?; + let file_flags = SocketFlags::from_bits_truncate(socket_type); + let close_on_spawn = file_flags.contains(SocketFlags::SOCK_CLOEXEC); + let sock_type = Type::try_from(socket_type & (!file_flags.bits())) + .map_err(|_| errno!(EINVAL, "invalid socket type"))?; - let domain = AddressFamily::try_from(domain as u16)?; - if (domain == AddressFamily::LOCAL) { - let (client_socket, server_socket) = socketpair(sock_type, file_flags, protocol as i32)?; + let domain = 
Domain::try_from(domain as u16)?; + if (domain == Domain::LOCAL) { + let (client_socket, server_socket) = socketpair(sock_type, file_flags, protocol)?; let current = current!(); - let mut files = current.files().lock().unwrap(); + let mut files = current.files().lock(); sock_pair[0] = files.put(Arc::new(client_socket), close_on_spawn); sock_pair[1] = files.put(Arc::new(server_socket), close_on_spawn); @@ -483,68 +519,79 @@ pub fn do_socketpair( } } -pub fn do_sendmsg(fd: c_int, msg_ptr: *const msghdr, flags_c: c_int) -> Result { +pub fn do_sendmsg(fd: c_int, msg_ptr: *const libc::msghdr, flags_c: c_int) -> Result { debug!( "sendmsg: fd: {}, msg: {:?}, flags: 0x{:x}", fd, msg_ptr, flags_c ); - let msg_hdr = { - let msg_hdr_c = { - from_user::check_ptr(msg_ptr)?; - let msg_hdr_c = unsafe { &*msg_ptr }; - msg_hdr_c.check_member_ptrs()?; - msg_hdr_c - }; - unsafe { MsgHdr::from_c(&msg_hdr_c)? } - }; - + let (addr, bufs, control) = extract_msghdr_from_user(msg_ptr)?; let flags = SendFlags::from_bits_truncate(flags_c); let file_ref = current!().file(fd as FileDesc)?; - if let Ok(socket) = file_ref.as_host_socket() { - socket - .sendmsg(&msg_hdr, flags) - .map(|bytes_sent| bytes_sent as isize) + if let Ok(host_socket) = file_ref.as_host_socket() { + let raw_addr = addr.map(|addr| addr.to_raw()); + host_socket + .sendmsg(&bufs[..], flags, &raw_addr, control) + .map(|bytes_send| bytes_send as isize) } else if let Ok(socket) = file_ref.as_unix_socket() { socket - .sendmsg(&msg_hdr, flags) + .sendmsg(&bufs[..], flags, control) .map(|bytes_sent| bytes_sent as isize) + } else if let Ok(uring_socket) = file_ref.as_uring_socket() { + uring_socket + .sendmsg(&bufs[..], addr, flags, control) + .map(|bytes_send| bytes_send as isize) } else { return_errno!(ENOTSOCK, "not a socket") } } -pub fn do_recvmsg(fd: c_int, msg_mut_ptr: *mut msghdr_mut, flags_c: c_int) -> Result { +pub fn do_recvmsg(fd: c_int, msg_mut_ptr: *mut libc::msghdr, flags_c: c_int) -> Result { debug!( "recvmsg: fd: 
{}, msg: {:?}, flags: 0x{:x}", fd, msg_mut_ptr, flags_c ); - - let mut msg_hdr_mut = { - let msg_hdr_mut_c = { - from_user::check_mut_ptr(msg_mut_ptr)?; - let msg_hdr_mut_c = unsafe { &mut *msg_mut_ptr }; - msg_hdr_mut_c.check_member_ptrs()?; - msg_hdr_mut_c - }; - unsafe { MsgHdrMut::from_c(msg_hdr_mut_c) }? - }; - + let (mut msg, mut addr, mut control, mut bufs) = extract_msghdr_mut_from_user(msg_mut_ptr)?; let flags = RecvFlags::from_bits_truncate(flags_c); let file_ref = current!().file(fd as FileDesc)?; - if let Ok(socket) = file_ref.as_host_socket() { - socket - .recvmsg(&mut msg_hdr_mut, flags) - .map(|bytes_recvd| bytes_recvd as isize) - } else if let Ok(socket) = file_ref.as_unix_socket() { - socket - .recvmsg(&mut msg_hdr_mut, flags) - .map(|bytes_recvd| bytes_recvd as isize) - } else { - return_errno!(ENOTSOCK, "not a socket") + let (bytes_recv, recv_addr, msg_flags, msg_controllen) = + if let Ok(host_socket) = file_ref.as_host_socket() { + host_socket.recvmsg(&mut bufs[..], flags, control).map( + |(bytes, addr, msg, controllen)| { + ( + bytes, + addr.map(|raw_addr| AnyAddr::Raw(raw_addr)), + msg, + controllen, + ) + }, + )? + } else if let Ok(unix_socket) = file_ref.as_unix_socket() { + unix_socket.recvmsg(&mut bufs[..], flags, control).map( + |(bytes_recvd, control_len)| (bytes_recvd, None, MsgFlags::empty(), control_len), + )? + } else if let Ok(uring_socket) = file_ref.as_uring_socket() { + uring_socket.recvmsg(&mut bufs[..], flags, control)? 
+ } else { + return_errno!(ENOTSOCK, "not a socket") + }; + + if let Some(addr) = addr { + if let Some(recv_addr) = recv_addr { + let (c_addr_storage, c_addr_len) = recv_addr.to_c_storage(); + copy_sock_addr_to_user(c_addr_storage, c_addr_len, addr, &mut msg.msg_namelen); + } } + + msg.msg_flags = msg_flags.bits(); + msg.msg_controllen = msg_controllen; + if msg_controllen == 0 { + msg.msg_control = ptr::null_mut(); + } + + Ok(bytes_recv as isize) } pub fn do_sendmmsg( @@ -559,31 +606,22 @@ pub fn do_sendmmsg( ); from_user::check_ptr(msgvec_ptr)?; - let mut msgvec = unsafe { std::slice::from_raw_parts_mut(msgvec_ptr, vlen as usize) }; + let flags = SendFlags::from_bits_truncate(flags_c); let file_ref = current!().file(fd as FileDesc)?; + let mut send_count = 0; - if let Ok(socket) = file_ref.as_host_socket() { - let mut send_count = 0; + if let Ok(host_socket) = file_ref.as_host_socket() { for mmsg in (msgvec) { - if !mmsg.msg_hdr.check_member_ptrs().is_ok() { - break; - } + let (any_addr, bufs, control) = extract_msghdr_from_user(&mmsg.msg_hdr)?; + let raw_addr = any_addr.map(|any_addr| any_addr.to_raw()); - let msg = unsafe { - if let Ok(msg) = MsgHdr::from_c({ &mmsg.msg_hdr }) { - msg - } else { - break; - } - }; - - if socket - .sendmsg(&msg, flags) - .map(|bytes_sent| { - mmsg.msg_len = bytes_sent as u32; - mmsg.msg_len + if host_socket + .sendmsg(&bufs[..], flags, &raw_addr, control) + .map(|bytes_send| { + mmsg.msg_len += bytes_send as c_uint; + bytes_send as isize }) .is_ok() { @@ -592,74 +630,30 @@ pub fn do_sendmmsg( break; } } - - Ok(send_count as isize) } else if let Ok(socket) = file_ref.as_unix_socket() { return_errno!(EOPNOTSUPP, "does not support unix socket") + } else if let Ok(uring_socket) = file_ref.as_uring_socket() { + for mmsg in (msgvec) { + let (addr, bufs, control) = extract_msghdr_from_user(&mmsg.msg_hdr)?; + + if uring_socket + .sendmsg(&bufs[..], addr, flags, control) + .map(|bytes_send| { + mmsg.msg_len += bytes_send as c_uint; + 
bytes_send as isize + }) + .is_ok() + { + send_count += 1; + } else { + break; + } + } } else { return_errno!(ENOTSOCK, "not a socket") } -} -#[allow(non_camel_case_types)] -trait c_msghdr_ext { - fn check_member_ptrs(&self) -> Result<()>; -} - -impl c_msghdr_ext for msghdr { - // TODO: implement this! - fn check_member_ptrs(&self) -> Result<()> { - Ok(()) - } - /* - ///user space check - pub unsafe fn check_from_user(user_hdr: *const msghdr) -> Result<()> { - Self::check_pointer(user_hdr, from_user::check_ptr) - } - - ///Check msghdr ptr - pub unsafe fn check_pointer( - user_hdr: *const msghdr, - check_ptr: fn(*const u8) -> Result<()>, - ) -> Result<()> { - check_ptr(user_hdr as *const u8)?; - - if (*user_hdr).msg_name.is_null() ^ ((*user_hdr).msg_namelen == 0) { - return_errno!(EINVAL, "name length is invalid"); - } - - if (*user_hdr).msg_iov.is_null() ^ ((*user_hdr).msg_iovlen == 0) { - return_errno!(EINVAL, "iov length is invalid"); - } - - if (*user_hdr).msg_control.is_null() ^ ((*user_hdr).msg_controllen == 0) { - return_errno!(EINVAL, "control length is invalid"); - } - - if !(*user_hdr).msg_name.is_null() { - check_ptr((*user_hdr).msg_name as *const u8)?; - } - - if !(*user_hdr).msg_iov.is_null() { - check_ptr((*user_hdr).msg_iov as *const u8)?; - let iov_slice = slice::from_raw_parts((*user_hdr).msg_iov, (*user_hdr).msg_iovlen); - for iov in iov_slice { - check_ptr(iov.iov_base as *const u8)?; - } - } - - if !(*user_hdr).msg_control.is_null() { - check_ptr((*user_hdr).msg_control as *const u8)?; - } - Ok(()) - } - */ -} - -impl c_msghdr_ext for msghdr_mut { - fn check_member_ptrs(&self) -> Result<()> { - Ok(()) - } + Ok(send_count as isize) } pub fn do_select( @@ -1013,3 +1007,495 @@ pub fn do_epoll_pwait( } do_epoll_wait(epfd, events, maxevents, timeout) } + +fn copy_sock_addr_from_user( + addr: *const libc::sockaddr, + addr_len: usize, +) -> Result { + // Check the address pointer and length + if addr.is_null() || addr_len == 0 { + return_errno!(EINVAL, 
"no address is specified"); + } + if addr_len > std::mem::size_of::() { + return_errno!( + EINVAL, + "addr len cannot be greater than sockaddr_storage's size" + ); + } + let sockaddr_src_buf = from_user::make_slice(addr as *const u8, addr_len)?; + + let sockaddr_storage = { + // Safety. The content will be initialized before function returns. + let mut sockaddr_storage = + unsafe { MaybeUninit::::uninit().assume_init() }; + // Safety. The dst slice is the only mutable reference to the sockaddr_storage + let sockaddr_dst_buf = unsafe { + let ptr = &mut sockaddr_storage as *mut _ as *mut u8; + let len = addr_len; + std::slice::from_raw_parts_mut(ptr, len) + }; + sockaddr_dst_buf.copy_from_slice(sockaddr_src_buf); + sockaddr_storage + }; + Ok(sockaddr_storage) +} + +fn get_slice_from_sock_addr_ptr_mut<'a>( + addr_ptr: *mut libc::sockaddr, + addr_len_ptr: *mut libc::socklen_t, +) -> Result> { + if addr_ptr.is_null() ^ addr_len_ptr.is_null() { + return_errno!(EINVAL, "addr and addr_len should be both null or not null"); + } + if addr_ptr.is_null() { + return Ok(None); + } + + let addr_len_mut = from_user::make_mut_ref(addr_len_ptr)?; + let addr_len = *addr_len_mut; + let addr_mut = from_user::make_mut_slice(addr_ptr as *mut u8, addr_len as usize)?; + Ok(Some((addr_mut, addr_len_mut))) +} + +fn copy_sock_addr_to_user( + src_addr: libc::sockaddr_storage, + src_addr_len: usize, + dst_addr: &mut [u8], + dst_addr_len: &mut u32, +) { + let len = std::cmp::min(src_addr_len, *dst_addr_len as usize); + let sockaddr_src_buf = unsafe { + let ptr = &src_addr as *const _ as *const u8; + std::slice::from_raw_parts(ptr, len) + }; + dst_addr[..len].copy_from_slice(sockaddr_src_buf); + *dst_addr_len = src_addr_len as u32; +} + +/// Create a new ioctl command for host socket getsockopt syscall +fn new_host_getsockopt_cmd(level: i32, optname: i32, optlen: u32) -> Result> { + if level != libc::SOL_SOCKET { + return Ok(Box::new(GetSockOptRawCmd::new(level, optname, optlen))); + } + + let 
opt =
        SockOptName::try_from(optname).map_err(|_| errno!(ENOPROTOOPT, "Not a valid optname"))?;

    Ok(match opt {
        SockOptName::SO_CNX_ADVICE => return_errno!(ENOPROTOOPT, "it's a write-only option"),
        _ => Box::new(GetSockOptRawCmd::new(level, optname, optlen)),
    })
}

/// Create a new ioctl command for uring socket getsockopt syscall
///
/// Options at a level other than `SOL_SOCKET` are always wrapped in a
/// `GetSockOptRawCmd`. For `SOL_SOCKET`, the options listed in the match below
/// get a dedicated command; every other valid option falls back to the raw
/// command.
///
/// # Errors
///
/// * `ENOPROTOOPT` if `optname` is not a known `SockOptName`, or if it is the
///   write-only `SO_CNX_ADVICE`.
fn new_uring_getsockopt_cmd(
    level: i32,
    optname: i32,
    optlen: u32,
    socket_type: Type,
) -> Result<Box<dyn IoctlCmd>> {
    if level != libc::SOL_SOCKET {
        return Ok(Box::new(GetSockOptRawCmd::new(level, optname, optlen)));
    }

    let opt =
        SockOptName::try_from(optname).map_err(|_| errno!(ENOPROTOOPT, "Not a valid optname"))?;

    Ok(match opt {
        SockOptName::SO_ACCEPTCONN => Box::new(GetAcceptConnCmd::new(())),
        SockOptName::SO_DOMAIN => Box::new(GetDomainCmd::new(())),
        SockOptName::SO_ERROR => Box::new(GetErrorCmd::new(())),
        SockOptName::SO_PEERNAME => Box::new(GetPeerNameCmd::new(())),
        SockOptName::SO_TYPE => Box::new(GetTypeCmd::new(())),
        SockOptName::SO_RCVTIMEO_OLD => Box::new(GetRecvTimeoutCmd::new(())),
        SockOptName::SO_SNDTIMEO_OLD => Box::new(GetSendTimeoutCmd::new(())),
        // Implement dynamic buf size for stream socket only; other socket types
        // fall through to the raw command in the catch-all arm.
        SockOptName::SO_SNDBUF if socket_type == Type::STREAM => {
            Box::new(GetSndBufSizeCmd::new(()))
        }
        SockOptName::SO_RCVBUF if socket_type == Type::STREAM => {
            Box::new(GetRcvBufSizeCmd::new(()))
        }
        SockOptName::SO_CNX_ADVICE => return_errno!(ENOPROTOOPT, "it's a write-only option"),
        _ => Box::new(GetSockOptRawCmd::new(level, optname, optlen)),
    })
}

/// Tell whether a `SOL_SOCKET` option can only be read, never set.
fn is_read_only_sockopt(opt: &SockOptName) -> bool {
    match opt {
        SockOptName::SO_ACCEPTCONN
        | SockOptName::SO_DOMAIN
        | SockOptName::SO_PEERNAME
        | SockOptName::SO_TYPE
        | SockOptName::SO_ERROR
        | SockOptName::SO_PEERCRED
        | SockOptName::SO_SNDLOWAT
        | SockOptName::SO_PEERSEC
        | SockOptName::SO_PROTOCOL
        | SockOptName::SO_MEMINFO
        | SockOptName::SO_INCOMING_NAPI_ID
        | SockOptName::SO_COOKIE
        | SockOptName::SO_PEERGROUPS => true,
        _ => false,
    }
}

/// Decode a `struct timeval` socket option value into a `Duration`.
///
/// A negative `tv_sec` is treated as zero seconds.
///
/// # Errors
///
/// * `EINVAL` if `optval` is shorter than `libc::timeval`.
/// * `EDOM` if `tv_usec` is outside `0..1_000_000`.
fn timeout_from_optval(optval: &[u8]) -> Result<Duration> {
    if optval.len() < std::mem::size_of::<libc::timeval>() {
        return_errno!(EINVAL, "invalid timeout option");
    }
    // SAFETY: the length check above guarantees that a whole `timeval` can be
    // read from `optval`. A byte slice carries no alignment guarantee, hence the
    // unaligned read instead of dereferencing a casted pointer.
    let timeval = unsafe { std::ptr::read_unaligned(optval.as_ptr() as *const libc::timeval) };

    let secs = if timeval.tv_sec < 0 { 0 } else { timeval.tv_sec };
    let usec = timeval.tv_usec;
    // One full second of microseconds is already out of range.
    if usec < 0 || usec >= 1_000_000 {
        return_errno!(EDOM, "time struct value is invalid");
    }
    // `usec` is below one million here, so the multiplication cannot overflow.
    Ok(Duration::new(secs as u64, usec as u32 * 1000))
}

/// Decode the buffer size requested through `SO_SNDBUF` / `SO_RCVBUF`.
///
/// The value is read as a native-endian `int` from the first four bytes of
/// `optval`; any extra bytes are ignored. A negative value is reinterpreted as
/// unsigned, so it ends up clamped to the maximum by the caller.
///
/// # Errors
///
/// * `EINVAL` if `optval` holds fewer than four bytes.
fn requested_buf_size_from_optval(optval: &[u8]) -> Result<usize> {
    let mut raw = [0u8; std::mem::size_of::<i32>()];
    if optval.len() < raw.len() {
        return_errno!(EINVAL, "optval is too short for a buffer size");
    }
    let int_len = raw.len();
    raw.copy_from_slice(&optval[..int_len]);
    Ok(i32::from_ne_bytes(raw) as u32 as usize)
}

/// Clamp a requested buffer size to the supported range and double it.
fn clamp_and_double_buf_size(requested: usize) -> usize {
    // Smallest accepted value before doubling.
    let min_size = 1152;
    // For the max value, we choose 4MB (doubled) to assure the libos kernel buf won't be the
    // bottleneck.
    let max_size = 2 * 1024 * 1024;

    // Based on man page: The kernel doubles this value (to allow space for bookkeeping overhead)
    // when it is set using setsockopt(2), and this doubled value is returned by getsockopt(2).
    requested.max(min_size).min(max_size) * 2
}

/// Create a new ioctl command for host socket setsockopt syscall
///
/// Everything is forwarded as a `SetSockOptRawCmd`, except that read-only
/// `SOL_SOCKET` options are rejected up front.
///
/// # Errors
///
/// * `ENOPROTOOPT` if `optname` is not a known `SockOptName` or names a
///   read-only option.
fn new_host_setsockopt_cmd(level: i32, optname: i32, optval: &[u8]) -> Result<Box<dyn IoctlCmd>> {
    if level != libc::SOL_SOCKET {
        return Ok(Box::new(SetSockOptRawCmd::new(level, optname, optval)));
    }

    let opt =
        SockOptName::try_from(optname).map_err(|_| errno!(ENOPROTOOPT, "Not a valid optname"))?;
    if is_read_only_sockopt(&opt) {
        return_errno!(ENOPROTOOPT, "it's a read-only option");
    }

    Ok(Box::new(SetSockOptRawCmd::new(level, optname, optval)))
}

/// Create a new ioctl command for uring socket setsockopt syscall
///
/// Options at a level other than `SOL_SOCKET` are wrapped in a
/// `SetSockOptRawCmd`. When `ENABLE_URING` is off, `SOL_SOCKET` options are
/// handled exactly like `new_host_setsockopt_cmd`. Otherwise the timeout options
/// and, for stream sockets only, the buffer size options get dedicated commands.
///
/// # Errors
///
/// * `ENOPROTOOPT` if `optname` is not a known `SockOptName` or names a
///   read-only option.
/// * `EINVAL` if `optval` is too short for a timeout or a buffer size.
/// * `EDOM` if a timeout carries an out-of-range `tv_usec`.
fn new_uring_setsockopt_cmd(
    level: i32,
    optname: i32,
    optval: &[u8],
    socket_type: Type,
) -> Result<Box<dyn IoctlCmd>> {
    if level != libc::SOL_SOCKET {
        return Ok(Box::new(SetSockOptRawCmd::new(level, optname, optval)));
    }

    let opt =
        SockOptName::try_from(optname).map_err(|_| errno!(ENOPROTOOPT, "Not a valid optname"))?;
    // Read-only options are refused whether or not uring is enabled.
    if is_read_only_sockopt(&opt) {
        return_errno!(ENOPROTOOPT, "it's a read-only option");
    }

    if !ENABLE_URING.load(Ordering::Relaxed) {
        return Ok(Box::new(SetSockOptRawCmd::new(level, optname, optval)));
    }

    Ok(match opt {
        SockOptName::SO_RCVTIMEO_OLD => {
            let timeout = timeout_from_optval(optval)?;
            trace!("recv timeout = {:?}", timeout);
            Box::new(SetRecvTimeoutCmd::new(timeout))
        }
        SockOptName::SO_SNDTIMEO_OLD => {
            let timeout = timeout_from_optval(optval)?;
            trace!("send timeout = {:?}", timeout);
            Box::new(SetSendTimeoutCmd::new(timeout))
        }
        // Implement dynamic buf size for stream socket only; other socket types
        // fall through to the raw command in the catch-all arm.
        SockOptName::SO_SNDBUF if socket_type == Type::STREAM => {
            // Based on the man page: The minimum (doubled) value for this option is 2048.
            // However, the test on Linux shows the minimum value (doubled) is 4608.
            // NOTE(review): the minimum applied here is 1152, i.e. 2304 once doubled, which
            // does not match the 4608 quoted above. Confirm which value is intended.
            let requested = requested_buf_size_from_optval(optval)?;
            trace!("set SO_SNDBUF size = {:?}", requested);
            Box::new(SetSndBufSizeCmd::new(clamp_and_double_buf_size(requested)))
        }
        SockOptName::SO_RCVBUF if socket_type == Type::STREAM => {
            // Based on the man page: The minimum (doubled) value for this option is 256.
            // However, the test on Linux shows the minimum value (doubled) is 2304. Here, we just
            // use the same value as Linux.
            let requested = requested_buf_size_from_optval(optval)?;
            trace!("set SO_RCVBUF size = {:?}", requested);
            Box::new(SetRcvBufSizeCmd::new(clamp_and_double_buf_size(requested)))
        }
        _ => Box::new(SetSockOptRawCmd::new(level, optname, optval)),
    })
}

/// Borrow the output bytes of an executed getsockopt command.
///
/// `cmd` is tried against each getsockopt command type listed below.
///
/// # Errors
///
/// * `EINVAL` if `cmd` is none of the listed command types.
/// * `EINVAL` if the matched command has no output available.
fn get_optval(cmd: &dyn IoctlCmd) -> Result<&[u8]> {
    crate::match_ioctl_cmd_ref!(cmd, {
        cmd : GetAcceptConnCmd => {
            cmd.get_output_as_bytes()
        },
        cmd : GetDomainCmd => {
            cmd.get_output_as_bytes()
        },
        cmd : GetPeerNameCmd => {
            cmd.get_output_as_bytes()
        },
        cmd : GetTypeCmd => {
            cmd.get_output_as_bytes()
        },
        cmd : GetSockOptRawCmd => {
            cmd.get_output_as_bytes()
        },
        cmd : GetErrorCmd => {
            cmd.get_output_as_bytes()
        },
        cmd : GetRecvTimeoutCmd => {
            cmd.get_output_as_bytes()
        },
        cmd : GetSendTimeoutCmd => {
            cmd.get_output_as_bytes()
        },
        cmd : GetSndBufSizeCmd => {
            cmd.get_output_as_bytes()
        },
        cmd : GetRcvBufSizeCmd => {
            cmd.get_output_as_bytes()
        },
        _ => {
            return_errno!(EINVAL, "invalid sockopt command");
        }
    })
    .ok_or_else(|| errno!(EINVAL, "no available output"))
}

/// Copy as many bytes of `src_buf` as fit into `dst_buf`.
///
/// On return `dst_len` holds the number of bytes actually copied, which is the
/// smaller of the two buffer lengths.
fn copy_bytes_to_user(src_buf: &[u8], dst_buf: &mut [u8], dst_len: &mut u32) {
    let copy_len = std::cmp::min(src_buf.len(), dst_buf.len());
    let (dst_head, _) = dst_buf.split_at_mut(copy_len);
    dst_head.copy_from_slice(&src_buf[..copy_len]);
    *dst_len = copy_len as u32;
}

/// Read a user-supplied `msghdr` for sending.
///
/// Returns, in this order, the optional destination address, the data buffers
/// described by the iovec array, and the optional control buffer.
///
/// # Errors
///
/// * `EINVAL` if a pointer is null while its length is non-zero, or the other
///   way round, for any of name, control or iov.
/// * Whatever `from_user`, `copy_sock_addr_from_user` or
///   `AnyAddr::from_c_storage` report for the referenced memory.
fn extract_msghdr_from_user<'a>(
    msg_ptr: *const libc::msghdr,
) -> Result<(Option<AnyAddr>, Vec<&'a [u8]>, Option<&'a [u8]>)> {
    let msg = from_user::make_ref(msg_ptr)?;

    // Destination address.
    let msg_name = msg.msg_name;
    let msg_namelen = msg.msg_namelen;
    if msg_name.is_null() ^ (msg_namelen == 0) {
        return_errno!(EINVAL, "name and namelen should be both null and 0 or not");
    }
    let name = if msg_name.is_null() {
        None
    } else {
        let sockaddr_storage = copy_sock_addr_from_user(msg_name as *const _, msg_namelen as _)?;
        Some(AnyAddr::from_c_storage(
            &sockaddr_storage,
            msg_namelen as _,
        )?)
    };

    // Ancillary (control) data.
    let msg_control = msg.msg_control;
    let msg_controllen = msg.msg_controllen;
    if msg_control.is_null() ^ (msg_controllen == 0) {
        return_errno!(
            EINVAL,
            "message control and controllen should be both null and 0 or not"
        );
    }
    let control = if msg_control.is_null() {
        None
    } else {
        Some(from_user::make_slice(
            msg_control as *const u8,
            msg_controllen as _,
        )?)
    };

    // Scatter/gather data buffers.
    let msg_iov = msg.msg_iov;
    let msg_iovlen = msg.msg_iovlen;
    if msg_iov.is_null() ^ (msg_iovlen == 0) {
        return_errno!(EINVAL, "iov and iovlen should be both null and 0 or not");
    }
    let mut bufs = Vec::new();
    if !msg_iov.is_null() {
        // Validate the iovec array before reserving room for its entries.
        let iovs = from_user::make_slice(msg_iov, msg_iovlen)?;
        bufs.reserve(msg_iovlen);
        for iov in iovs {
            bufs.push(from_user::make_slice(
                iov.iov_base as *const u8,
                iov.iov_len,
            )?);
        }
    }

    Ok((name, bufs, control))
}

/// Read a user-supplied `msghdr` for receiving.
///
/// Returns, in this order, the header itself, the optional address buffer, the
/// optional control buffer, and the data buffers described by the iovec array.
/// All of them are mutable.
///
/// # Errors
///
/// * `EINVAL` if a pointer is null while its length is non-zero, or the other
///   way round, for any of name, control or iov.
/// * Whatever `from_user` reports for the referenced memory.
fn extract_msghdr_mut_from_user<'a>(
    msg_mut_ptr: *mut libc::msghdr,
) -> Result<(
    &'a mut libc::msghdr,
    Option<&'a mut [u8]>,
    Option<&'a mut [u8]>,
    Vec<&'a mut [u8]>,
)> {
    let msg_mut = from_user::make_mut_ref(msg_mut_ptr)?;

    // Buffer that receives the peer address.
    let msg_name = msg_mut.msg_name;
    let msg_namelen = msg_mut.msg_namelen;
    if msg_name.is_null() ^ (msg_namelen == 0) {
        return_errno!(EINVAL, "name and namelen should be both null and 0 or not");
    }
    let name = if msg_name.is_null() {
        None
    } else {
        Some(from_user::make_mut_slice(
            msg_name as *mut u8,
            msg_namelen as usize,
        )?)
    };

    // Buffer that receives ancillary (control) data.
    let msg_control = msg_mut.msg_control;
    let msg_controllen = msg_mut.msg_controllen;
    if msg_control.is_null() ^ (msg_controllen == 0) {
        return_errno!(
            EINVAL,
            "message control and controllen should be both null and 0 or not"
        );
    }
    let control = if msg_control.is_null() {
        None
    } else {
        Some(from_user::make_mut_slice(
            msg_control as *mut u8,
            msg_controllen as usize,
        )?)
    };

    // Scatter/gather data buffers.
    let msg_iov = msg_mut.msg_iov;
    let msg_iovlen = msg_mut.msg_iovlen;
    if msg_iov.is_null() ^ (msg_iovlen == 0) {
        return_errno!(EINVAL, "iov and iovlen should be both null and 0 or not");
    }
    let mut bufs = Vec::new();
    if !msg_iov.is_null() {
        // Validate the iovec array before reserving room for its entries.
        let iovs = from_user::make_mut_slice(msg_iov, msg_iovlen)?;
        bufs.reserve(msg_iovlen);
        for iov in iovs {
            // In some situation using MSG_ERRQUEUE, users just require control buffers,
            // they may left iovec buffer all zero. It works in Linux.
            if iov.iov_base.is_null() {
                break;
            }
            bufs.push(from_user::make_mut_slice(
                iov.iov_base as *mut u8,
                iov.iov_len,
            )?);
        }
    }

    Ok((msg_mut, name, control, bufs))
}

// Flags to use when creating a new socket
bitflags! {
    /// Flag bits that can be combined with the socket type argument.
    pub struct SocketFlags: i32 {
        const SOCK_NONBLOCK = 0x800;
        const SOCK_CLOEXEC  = 0x80000;
    }
}