Add sendmmsg syscall

This commit is contained in:
zongmin.gu 2021-04-29 16:46:57 +08:00 committed by Tate, Hongliang Tian
parent 437b6245d3
commit 070bdf6f39
7 changed files with 144 additions and 18 deletions

@ -7,7 +7,7 @@ pub use self::io_multiplexing::{
PollEventFlags, PollFd, THREAD_NOTIFIERS,
};
pub use self::socket::{
msghdr, msghdr_mut, socketpair, unix_socket, AddressFamily, AsUnixSocket, FileFlags,
mmsghdr, msghdr, msghdr_mut, socketpair, unix_socket, AddressFamily, AsUnixSocket, FileFlags,
HostSocket, HostSocketType, HowToShut, Iovs, IovsMut, MsgHdr, MsgHdrFlags, MsgHdrMut,
RecvFlags, SendFlags, SliceAsLibcIovec, SockAddr, SocketType, UnixAddr,
};

@ -14,7 +14,7 @@ pub use self::address_family::AddressFamily;
pub use self::flags::{FileFlags, MsgHdrFlags, RecvFlags, SendFlags};
pub use self::host::{HostSocket, HostSocketType};
pub use self::iovs::{Iovs, IovsMut, SliceAsLibcIovec};
pub use self::msg::{msghdr, msghdr_mut, MsgHdr, MsgHdrMut};
pub use self::msg::{mmsghdr, msghdr, msghdr_mut, MsgHdr, MsgHdrMut};
pub use self::shutdown::HowToShut;
pub use self::socket_address::SockAddr;
pub use self::socket_type::SocketType;

@ -14,6 +14,13 @@ pub struct msghdr {
pub msg_flags: c_int,
}
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct mmsghdr {
pub msg_hdr: msghdr,
pub msg_len: c_uint,
}
/// C struct for a socket message with mutable pointers
#[repr(C)]
#[derive(Debug, Copy, Clone)]

@ -48,7 +48,7 @@ pub fn do_bind(fd: c_int, addr: *const libc::sockaddr, addr_len: libc::socklen_t
trace!("bind to addr: {:?}", unix_addr);
unix_socket.bind(&unix_addr)?;
} else {
return_errno!(EBADF, "not a socket");
return_errno!(ENOTSOCK, "not a socket");
}
Ok(0)
@ -61,7 +61,7 @@ pub fn do_listen(fd: c_int, backlog: c_int) -> Result<isize> {
} else if let Ok(unix_socket) = file_ref.as_unix_socket() {
unix_socket.listen(backlog)?;
} else {
return_errno!(EBADF, "not a socket");
return_errno!(ENOTSOCK, "not a socket");
}
Ok(0)
@ -99,7 +99,7 @@ pub fn do_connect(
unix_socket.connect(&addr)?;
} else {
return_errno!(EBADF, "not a socket");
return_errno!(ENOTSOCK, "not a socket");
}
Ok(0)
@ -170,7 +170,7 @@ pub fn do_accept4(
}
Ok(new_fd as isize)
} else {
return_errno!(EBADF, "not a socket");
return_errno!(ENOTSOCK, "not a socket");
}
}
@ -215,7 +215,7 @@ pub fn do_setsockopt(
warn!("setsockopt for unix socket is unimplemented");
Ok(0)
} else {
return_errno!(EBADF, "not a socket")
return_errno!(ENOTSOCK, "not a socket")
}
}
@ -275,7 +275,7 @@ pub fn do_getpeername(
}
Ok(0)
} else {
return_errno!(EBADF, "not a socket")
return_errno!(ENOTSOCK, "not a socket")
}
}
@ -322,7 +322,7 @@ pub fn do_getsockname(
}
Ok(0)
} else {
return_errno!(EBADF, "not a socket");
return_errno!(ENOTSOCK, "not a socket");
}
}
@ -440,7 +440,7 @@ pub fn do_recvfrom(
}
Ok(data_len as isize)
} else {
return_errno!(EBADF, "not a socket");
return_errno!(ENOTSOCK, "not a socket");
}
}
@ -497,9 +497,9 @@ pub fn do_sendmsg(fd: c_int, msg_ptr: *const msghdr, flags_c: c_int) -> Result<i
.sendmsg(&msg, flags)
.map(|bytes_sent| bytes_sent as isize)
} else if let Ok(socket) = file_ref.as_unix_socket() {
return_errno!(EBADF, "does not support unix socket")
return_errno!(EOPNOTSUPP, "does not support unix socket")
} else {
return_errno!(EBADF, "not a socket")
return_errno!(ENOTSOCK, "not a socket")
}
}
@ -525,9 +525,63 @@ pub fn do_recvmsg(fd: c_int, msg_mut_ptr: *mut msghdr_mut, flags_c: c_int) -> Re
.recvmsg(&mut msg_mut, flags)
.map(|bytes_recvd| bytes_recvd as isize)
} else if let Ok(socket) = file_ref.as_unix_socket() {
return_errno!(EBADF, "does not support unix socket")
return_errno!(EOPNOTSUPP, "does not support unix socket")
} else {
return_errno!(EBADF, "not a socket")
return_errno!(ENOTSOCK, "not a socket")
}
}
pub fn do_sendmmsg(
fd: c_int,
msgvec_ptr: *mut mmsghdr,
vlen: c_uint,
flags_c: c_int,
) -> Result<isize> {
debug!(
"sendmmsg: fd: {}, msg: {:?}, flags: 0x{:x}",
fd, msgvec_ptr, flags_c
);
from_user::check_ptr(msgvec_ptr)?;
let mut msgvec = unsafe { std::slice::from_raw_parts_mut(msgvec_ptr, vlen as usize) };
let flags = SendFlags::from_bits_truncate(flags_c);
let file_ref = current!().file(fd as FileDesc)?;
if let Ok(socket) = file_ref.as_host_socket() {
let mut send_count = 0;
for mmsg in (msgvec) {
if !mmsg.msg_hdr.check_member_ptrs().is_ok() {
break;
}
let msg = unsafe {
if let Ok(msg) = MsgHdr::from_c({ &mmsg.msg_hdr }) {
msg
} else {
break;
}
};
if socket
.sendmsg(&msg, flags)
.map(|bytes_sent| {
mmsg.msg_len = bytes_sent as u32;
mmsg.msg_len
})
.is_ok()
{
send_count += 1;
} else {
break;
}
}
Ok(send_count as isize)
} else if let Ok(socket) = file_ref.as_unix_socket() {
return_errno!(EOPNOTSUPP, "does not support unix socket")
} else {
return_errno!(ENOTSOCK, "not a socket")
}
}

@ -36,8 +36,8 @@ use crate::misc::{resource_t, rlimit_t, sysinfo_t, utsname_t};
use crate::net::{
do_accept, do_accept4, do_bind, do_connect, do_epoll_create, do_epoll_create1, do_epoll_ctl,
do_epoll_pwait, do_epoll_wait, do_getpeername, do_getsockname, do_getsockopt, do_listen,
do_poll, do_recvfrom, do_recvmsg, do_select, do_sendmsg, do_sendto, do_setsockopt, do_shutdown,
do_socket, do_socketpair, msghdr, msghdr_mut,
do_poll, do_recvfrom, do_recvmsg, do_select, do_sendmmsg, do_sendmsg, do_sendto, do_setsockopt,
do_shutdown, do_socket, do_socketpair, mmsghdr, msghdr, msghdr_mut,
};
use crate::process::{
do_arch_prctl, do_clone, do_exit, do_exit_group, do_futex, do_getegid, do_geteuid, do_getgid,
@ -392,7 +392,7 @@ macro_rules! process_syscall_table_with_callback {
(OpenByHandleAt = 304) => handle_unsupported(),
(ClockAdjtime = 305) => handle_unsupported(),
(Syncfs = 306) => handle_unsupported(),
(Sendmmsg = 307) => handle_unsupported(),
(Sendmmsg = 307) => do_sendmmsg(fd: c_int, msg_ptr: *mut mmsghdr, vlen: c_uint, flags_c: c_int),
(Setns = 308) => handle_unsupported(),
(Getcpu = 309) => do_getcpu(cpu_ptr: *mut u32, node_ptr: *mut u32),
(ProcessVmReadv = 310) => handle_unsupported(),

@ -92,6 +92,43 @@ int client_sendmsg(int server_fd, char *buf) {
return ret;
}
#ifdef __GLIBC__
struct mmsghdr {
struct msghdr msg;
unsigned int len;
};
int client_sendmmsg(int server_fd, char *buf) {
int ret = 0;
struct mmsghdr msg_v[2] = {};
struct iovec iov[1];
struct msghdr *msg_ptr = &msg_v[0].msg;
// Set msg0
msg_ptr->msg_name = NULL;
msg_ptr->msg_namelen = 0;
iov[0].iov_base = buf;
iov[0].iov_len = strlen(buf);
msg_ptr->msg_iov = iov;
msg_ptr->msg_iovlen = 1;
msg_ptr->msg_control = 0;
msg_ptr->msg_controllen = 0;
msg_ptr->msg_flags = 0;
// Set msg1
msg_v[1] = msg_v[0];
msg_ptr = &msg_v[1].msg;
msg_ptr->msg_iov = NULL;
msg_ptr->msg_iovlen = 0;
ret = sendmmsg(server_fd, msg_v, 2, 0);
if (ret != 2 || msg_v[0].len <= 0 || msg_v[1].len != 0) {
THROW_ERROR("sendmsg failed");
}
}
#endif
int client_connectionless_sendmsg(char *buf) {
int ret = 0;
struct msghdr msg;
@ -148,7 +185,12 @@ int main(int argc, const char *argv[]) {
neogotiate_msg(server_fd, buf, buf_size);
ret = client_sendmsg(server_fd, buf);
break;
#ifdef __GLIBC__
case 8803:
neogotiate_msg(server_fd, buf, buf_size);
ret = client_sendmmsg(server_fd, buf);
#endif
case 8804:
ret = client_connectionless_sendmsg(DEFAULT_MSG);
break;
default:

@ -268,13 +268,33 @@ int test_sendmsg_recvmsg() {
return ret;
}
int test_sendmmsg_recvmsg() {
int ret = 0;
int child_pid = 0;
int client_fd = connect_with_child(8803, &child_pid);
if (client_fd < 0) {
THROW_ERROR("connect failed");
}
if (neogotiate_msg(client_fd) < 0) {
THROW_ERROR("neogotiate failed");
}
ret = server_recvmsg(client_fd);
if (ret < 0) { return -1; }
ret = wait_for_child_exit(child_pid);
return ret;
}
int test_sendmsg_recvmsg_connectionless() {
int ret = 0;
int child_pid = 0;
signal(SIGCHLD, proc_exit);
char *client_argv[] = {"client", "NULL", "8803", NULL};
char *client_argv[] = {"client", "NULL", "8804", NULL};
ret = posix_spawn(&child_pid, "/bin/client", NULL, NULL, client_argv, NULL);
if (ret < 0) {
THROW_ERROR("spawn client process error");
@ -383,6 +403,9 @@ static test_case_t test_cases[] = {
TEST_CASE(test_read_write),
TEST_CASE(test_send_recv),
TEST_CASE(test_sendmsg_recvmsg),
#ifdef __GLIBC__
TEST_CASE(test_sendmmsg_recvmsg),
#endif
TEST_CASE(test_sendmsg_recvmsg_connectionless),
TEST_CASE(test_fcntl_setfl_and_getfl),
TEST_CASE(test_poll),