Add sendmmsg syscall

This commit is contained in:
zongmin.gu 2021-04-29 16:46:57 +08:00 committed by Tate, Hongliang Tian
parent 437b6245d3
commit 070bdf6f39
7 changed files with 144 additions and 18 deletions

@ -7,7 +7,7 @@ pub use self::io_multiplexing::{
PollEventFlags, PollFd, THREAD_NOTIFIERS, PollEventFlags, PollFd, THREAD_NOTIFIERS,
}; };
pub use self::socket::{ pub use self::socket::{
msghdr, msghdr_mut, socketpair, unix_socket, AddressFamily, AsUnixSocket, FileFlags, mmsghdr, msghdr, msghdr_mut, socketpair, unix_socket, AddressFamily, AsUnixSocket, FileFlags,
HostSocket, HostSocketType, HowToShut, Iovs, IovsMut, MsgHdr, MsgHdrFlags, MsgHdrMut, HostSocket, HostSocketType, HowToShut, Iovs, IovsMut, MsgHdr, MsgHdrFlags, MsgHdrMut,
RecvFlags, SendFlags, SliceAsLibcIovec, SockAddr, SocketType, UnixAddr, RecvFlags, SendFlags, SliceAsLibcIovec, SockAddr, SocketType, UnixAddr,
}; };

@ -14,7 +14,7 @@ pub use self::address_family::AddressFamily;
pub use self::flags::{FileFlags, MsgHdrFlags, RecvFlags, SendFlags}; pub use self::flags::{FileFlags, MsgHdrFlags, RecvFlags, SendFlags};
pub use self::host::{HostSocket, HostSocketType}; pub use self::host::{HostSocket, HostSocketType};
pub use self::iovs::{Iovs, IovsMut, SliceAsLibcIovec}; pub use self::iovs::{Iovs, IovsMut, SliceAsLibcIovec};
pub use self::msg::{msghdr, msghdr_mut, MsgHdr, MsgHdrMut}; pub use self::msg::{mmsghdr, msghdr, msghdr_mut, MsgHdr, MsgHdrMut};
pub use self::shutdown::HowToShut; pub use self::shutdown::HowToShut;
pub use self::socket_address::SockAddr; pub use self::socket_address::SockAddr;
pub use self::socket_type::SocketType; pub use self::socket_type::SocketType;

@ -14,6 +14,13 @@ pub struct msghdr {
pub msg_flags: c_int, pub msg_flags: c_int,
} }
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct mmsghdr {
pub msg_hdr: msghdr,
pub msg_len: c_uint,
}
/// C struct for a socket message with mutable pointers /// C struct for a socket message with mutable pointers
#[repr(C)] #[repr(C)]
#[derive(Debug, Copy, Clone)] #[derive(Debug, Copy, Clone)]

@ -48,7 +48,7 @@ pub fn do_bind(fd: c_int, addr: *const libc::sockaddr, addr_len: libc::socklen_t
trace!("bind to addr: {:?}", unix_addr); trace!("bind to addr: {:?}", unix_addr);
unix_socket.bind(&unix_addr)?; unix_socket.bind(&unix_addr)?;
} else { } else {
return_errno!(EBADF, "not a socket"); return_errno!(ENOTSOCK, "not a socket");
} }
Ok(0) Ok(0)
@ -61,7 +61,7 @@ pub fn do_listen(fd: c_int, backlog: c_int) -> Result<isize> {
} else if let Ok(unix_socket) = file_ref.as_unix_socket() { } else if let Ok(unix_socket) = file_ref.as_unix_socket() {
unix_socket.listen(backlog)?; unix_socket.listen(backlog)?;
} else { } else {
return_errno!(EBADF, "not a socket"); return_errno!(ENOTSOCK, "not a socket");
} }
Ok(0) Ok(0)
@ -99,7 +99,7 @@ pub fn do_connect(
unix_socket.connect(&addr)?; unix_socket.connect(&addr)?;
} else { } else {
return_errno!(EBADF, "not a socket"); return_errno!(ENOTSOCK, "not a socket");
} }
Ok(0) Ok(0)
@ -170,7 +170,7 @@ pub fn do_accept4(
} }
Ok(new_fd as isize) Ok(new_fd as isize)
} else { } else {
return_errno!(EBADF, "not a socket"); return_errno!(ENOTSOCK, "not a socket");
} }
} }
@ -215,7 +215,7 @@ pub fn do_setsockopt(
warn!("setsockopt for unix socket is unimplemented"); warn!("setsockopt for unix socket is unimplemented");
Ok(0) Ok(0)
} else { } else {
return_errno!(EBADF, "not a socket") return_errno!(ENOTSOCK, "not a socket")
} }
} }
@ -275,7 +275,7 @@ pub fn do_getpeername(
} }
Ok(0) Ok(0)
} else { } else {
return_errno!(EBADF, "not a socket") return_errno!(ENOTSOCK, "not a socket")
} }
} }
@ -322,7 +322,7 @@ pub fn do_getsockname(
} }
Ok(0) Ok(0)
} else { } else {
return_errno!(EBADF, "not a socket"); return_errno!(ENOTSOCK, "not a socket");
} }
} }
@ -440,7 +440,7 @@ pub fn do_recvfrom(
} }
Ok(data_len as isize) Ok(data_len as isize)
} else { } else {
return_errno!(EBADF, "not a socket"); return_errno!(ENOTSOCK, "not a socket");
} }
} }
@ -497,9 +497,9 @@ pub fn do_sendmsg(fd: c_int, msg_ptr: *const msghdr, flags_c: c_int) -> Result<i
.sendmsg(&msg, flags) .sendmsg(&msg, flags)
.map(|bytes_sent| bytes_sent as isize) .map(|bytes_sent| bytes_sent as isize)
} else if let Ok(socket) = file_ref.as_unix_socket() { } else if let Ok(socket) = file_ref.as_unix_socket() {
return_errno!(EBADF, "does not support unix socket") return_errno!(EOPNOTSUPP, "does not support unix socket")
} else { } else {
return_errno!(EBADF, "not a socket") return_errno!(ENOTSOCK, "not a socket")
} }
} }
@ -525,9 +525,63 @@ pub fn do_recvmsg(fd: c_int, msg_mut_ptr: *mut msghdr_mut, flags_c: c_int) -> Re
.recvmsg(&mut msg_mut, flags) .recvmsg(&mut msg_mut, flags)
.map(|bytes_recvd| bytes_recvd as isize) .map(|bytes_recvd| bytes_recvd as isize)
} else if let Ok(socket) = file_ref.as_unix_socket() { } else if let Ok(socket) = file_ref.as_unix_socket() {
return_errno!(EBADF, "does not support unix socket") return_errno!(EOPNOTSUPP, "does not support unix socket")
} else { } else {
return_errno!(EBADF, "not a socket") return_errno!(ENOTSOCK, "not a socket")
}
}
pub fn do_sendmmsg(
fd: c_int,
msgvec_ptr: *mut mmsghdr,
vlen: c_uint,
flags_c: c_int,
) -> Result<isize> {
debug!(
"sendmmsg: fd: {}, msg: {:?}, flags: 0x{:x}",
fd, msgvec_ptr, flags_c
);
from_user::check_ptr(msgvec_ptr)?;
let mut msgvec = unsafe { std::slice::from_raw_parts_mut(msgvec_ptr, vlen as usize) };
let flags = SendFlags::from_bits_truncate(flags_c);
let file_ref = current!().file(fd as FileDesc)?;
if let Ok(socket) = file_ref.as_host_socket() {
let mut send_count = 0;
for mmsg in (msgvec) {
if !mmsg.msg_hdr.check_member_ptrs().is_ok() {
break;
}
let msg = unsafe {
if let Ok(msg) = MsgHdr::from_c({ &mmsg.msg_hdr }) {
msg
} else {
break;
}
};
if socket
.sendmsg(&msg, flags)
.map(|bytes_sent| {
mmsg.msg_len = bytes_sent as u32;
mmsg.msg_len
})
.is_ok()
{
send_count += 1;
} else {
break;
}
}
Ok(send_count as isize)
} else if let Ok(socket) = file_ref.as_unix_socket() {
return_errno!(EOPNOTSUPP, "does not support unix socket")
} else {
return_errno!(ENOTSOCK, "not a socket")
} }
} }

@ -36,8 +36,8 @@ use crate::misc::{resource_t, rlimit_t, sysinfo_t, utsname_t};
use crate::net::{ use crate::net::{
do_accept, do_accept4, do_bind, do_connect, do_epoll_create, do_epoll_create1, do_epoll_ctl, do_accept, do_accept4, do_bind, do_connect, do_epoll_create, do_epoll_create1, do_epoll_ctl,
do_epoll_pwait, do_epoll_wait, do_getpeername, do_getsockname, do_getsockopt, do_listen, do_epoll_pwait, do_epoll_wait, do_getpeername, do_getsockname, do_getsockopt, do_listen,
do_poll, do_recvfrom, do_recvmsg, do_select, do_sendmsg, do_sendto, do_setsockopt, do_shutdown, do_poll, do_recvfrom, do_recvmsg, do_select, do_sendmmsg, do_sendmsg, do_sendto, do_setsockopt,
do_socket, do_socketpair, msghdr, msghdr_mut, do_shutdown, do_socket, do_socketpair, mmsghdr, msghdr, msghdr_mut,
}; };
use crate::process::{ use crate::process::{
do_arch_prctl, do_clone, do_exit, do_exit_group, do_futex, do_getegid, do_geteuid, do_getgid, do_arch_prctl, do_clone, do_exit, do_exit_group, do_futex, do_getegid, do_geteuid, do_getgid,
@ -392,7 +392,7 @@ macro_rules! process_syscall_table_with_callback {
(OpenByHandleAt = 304) => handle_unsupported(), (OpenByHandleAt = 304) => handle_unsupported(),
(ClockAdjtime = 305) => handle_unsupported(), (ClockAdjtime = 305) => handle_unsupported(),
(Syncfs = 306) => handle_unsupported(), (Syncfs = 306) => handle_unsupported(),
(Sendmmsg = 307) => handle_unsupported(), (Sendmmsg = 307) => do_sendmmsg(fd: c_int, msg_ptr: *mut mmsghdr, vlen: c_uint, flags_c: c_int),
(Setns = 308) => handle_unsupported(), (Setns = 308) => handle_unsupported(),
(Getcpu = 309) => do_getcpu(cpu_ptr: *mut u32, node_ptr: *mut u32), (Getcpu = 309) => do_getcpu(cpu_ptr: *mut u32, node_ptr: *mut u32),
(ProcessVmReadv = 310) => handle_unsupported(), (ProcessVmReadv = 310) => handle_unsupported(),

@ -92,6 +92,43 @@ int client_sendmsg(int server_fd, char *buf) {
return ret; return ret;
} }
#ifdef __GLIBC__
struct mmsghdr {
struct msghdr msg;
unsigned int len;
};
int client_sendmmsg(int server_fd, char *buf) {
int ret = 0;
struct mmsghdr msg_v[2] = {};
struct iovec iov[1];
struct msghdr *msg_ptr = &msg_v[0].msg;
// Set msg0
msg_ptr->msg_name = NULL;
msg_ptr->msg_namelen = 0;
iov[0].iov_base = buf;
iov[0].iov_len = strlen(buf);
msg_ptr->msg_iov = iov;
msg_ptr->msg_iovlen = 1;
msg_ptr->msg_control = 0;
msg_ptr->msg_controllen = 0;
msg_ptr->msg_flags = 0;
// Set msg1
msg_v[1] = msg_v[0];
msg_ptr = &msg_v[1].msg;
msg_ptr->msg_iov = NULL;
msg_ptr->msg_iovlen = 0;
ret = sendmmsg(server_fd, msg_v, 2, 0);
if (ret != 2 || msg_v[0].len <= 0 || msg_v[1].len != 0) {
THROW_ERROR("sendmsg failed");
}
}
#endif
int client_connectionless_sendmsg(char *buf) { int client_connectionless_sendmsg(char *buf) {
int ret = 0; int ret = 0;
struct msghdr msg; struct msghdr msg;
@ -148,7 +185,12 @@ int main(int argc, const char *argv[]) {
neogotiate_msg(server_fd, buf, buf_size); neogotiate_msg(server_fd, buf, buf_size);
ret = client_sendmsg(server_fd, buf); ret = client_sendmsg(server_fd, buf);
break; break;
#ifdef __GLIBC__
case 8803: case 8803:
neogotiate_msg(server_fd, buf, buf_size);
ret = client_sendmmsg(server_fd, buf);
#endif
case 8804:
ret = client_connectionless_sendmsg(DEFAULT_MSG); ret = client_connectionless_sendmsg(DEFAULT_MSG);
break; break;
default: default:

@ -268,13 +268,33 @@ int test_sendmsg_recvmsg() {
return ret; return ret;
} }
int test_sendmmsg_recvmsg() {
int ret = 0;
int child_pid = 0;
int client_fd = connect_with_child(8803, &child_pid);
if (client_fd < 0) {
THROW_ERROR("connect failed");
}
if (neogotiate_msg(client_fd) < 0) {
THROW_ERROR("neogotiate failed");
}
ret = server_recvmsg(client_fd);
if (ret < 0) { return -1; }
ret = wait_for_child_exit(child_pid);
return ret;
}
int test_sendmsg_recvmsg_connectionless() { int test_sendmsg_recvmsg_connectionless() {
int ret = 0; int ret = 0;
int child_pid = 0; int child_pid = 0;
signal(SIGCHLD, proc_exit); signal(SIGCHLD, proc_exit);
char *client_argv[] = {"client", "NULL", "8803", NULL}; char *client_argv[] = {"client", "NULL", "8804", NULL};
ret = posix_spawn(&child_pid, "/bin/client", NULL, NULL, client_argv, NULL); ret = posix_spawn(&child_pid, "/bin/client", NULL, NULL, client_argv, NULL);
if (ret < 0) { if (ret < 0) {
THROW_ERROR("spawn client process error"); THROW_ERROR("spawn client process error");
@ -383,6 +403,9 @@ static test_case_t test_cases[] = {
TEST_CASE(test_read_write), TEST_CASE(test_read_write),
TEST_CASE(test_send_recv), TEST_CASE(test_send_recv),
TEST_CASE(test_sendmsg_recvmsg), TEST_CASE(test_sendmsg_recvmsg),
#ifdef __GLIBC__
TEST_CASE(test_sendmmsg_recvmsg),
#endif
TEST_CASE(test_sendmsg_recvmsg_connectionless), TEST_CASE(test_sendmsg_recvmsg_connectionless),
TEST_CASE(test_fcntl_setfl_and_getfl), TEST_CASE(test_fcntl_setfl_and_getfl),
TEST_CASE(test_poll), TEST_CASE(test_poll),