diff --git a/src/libos/src/net/socket/unix/stream/stream.rs b/src/libos/src/net/socket/unix/stream/stream.rs index d7e1d16f..e3d16a62 100644 --- a/src/libos/src/net/socket/unix/stream/stream.rs +++ b/src/libos/src/net/socket/unix/stream/stream.rs @@ -5,6 +5,7 @@ use events::{Event, EventFilter, Notifier, Observer}; use fs::channel::Channel; use fs::IoEvents; use fs::{CreationFlags, FileMode}; +use net::socket::{Iovs, MsgHdr, MsgHdrMut}; use std::fmt; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; @@ -223,6 +224,32 @@ impl Stream { Ok((data_len, addr)) } + pub fn sendmsg(&self, msg_hdr: &MsgHdr, flags: SendFlags) -> Result { + if !flags.is_empty() { + warn!("unsupported flags: {:?}", flags); + } + if msg_hdr.get_control().is_some() { + warn!("sendmsg with msg_control is not supported"); + } + + let bufs = msg_hdr.get_iovs().as_slices(); + self.writev(bufs) + } + + pub fn recvmsg(&self, msg_hdr: &mut MsgHdrMut, flags: RecvFlags) -> Result { + if !flags.is_empty() { + warn!("unsupported flags: {:?}", flags); + } + + let bufs = msg_hdr.get_iovs_mut().as_slices_mut(); + let data_len = self.readv(bufs)?; + + // For stream socket, the msg_name is ignored. And other fields are not supported. + msg_hdr.set_name_len(0); + + Ok(data_len) + } + /// perform shutdown on the socket. pub fn shutdown(&self, how: HowToShut) -> Result<()> { if let Status::Connected(ref end) = &*self.inner() { diff --git a/src/libos/src/net/syscalls.rs b/src/libos/src/net/syscalls.rs index fe02a6d4..dc962433 100644 --- a/src/libos/src/net/syscalls.rs +++ b/src/libos/src/net/syscalls.rs @@ -489,23 +489,27 @@ pub fn do_sendmsg(fd: c_int, msg_ptr: *const msghdr, flags_c: c_int) -> Result Re fd, msg_mut_ptr, flags_c ); + let mut msg_hdr_mut = { + let msg_hdr_mut_c = { + from_user::check_mut_ptr(msg_mut_ptr)?; + let msg_hdr_mut_c = unsafe { &mut *msg_mut_ptr }; + msg_hdr_mut_c.check_member_ptrs()?; + msg_hdr_mut_c + }; + unsafe { MsgHdrMut::from_c(msg_hdr_mut_c) }? + }; + + let flags = RecvFlags::from_bits_truncate(flags_c); + let file_ref = current!().file(fd as FileDesc)?; if let Ok(socket) = file_ref.as_host_socket() { - let msg_mut_c = { - from_user::check_mut_ptr(msg_mut_ptr)?; - let msg_mut_c = unsafe { &mut *msg_mut_ptr }; - msg_mut_c.check_member_ptrs()?; - msg_mut_c - }; - let mut msg_mut = unsafe { MsgHdrMut::from_c(msg_mut_c)? }; - - let flags = RecvFlags::from_bits_truncate(flags_c); - socket - .recvmsg(&mut msg_mut, flags) + .recvmsg(&mut msg_hdr_mut, flags) .map(|bytes_recvd| bytes_recvd as isize) } else if let Ok(socket) = file_ref.as_unix_socket() { - return_errno!(EOPNOTSUPP, "does not support unix socket") + socket + .recvmsg(&mut msg_hdr_mut, flags) + .map(|bytes_recvd| bytes_recvd as isize) } else { return_errno!(ENOTSOCK, "not a socket") } diff --git a/test/unix_socket/main.c b/test/unix_socket/main.c index 156127c5..c328ea0a 100644 --- a/test/unix_socket/main.c +++ b/test/unix_socket/main.c @@ -465,6 +465,99 @@ int test_epoll_wait() { return 0; } +int client_sendmsg(int server_fd, char *buf) { + int ret = 0; + struct msghdr msg; + struct iovec iov[1]; + msg.msg_name = NULL; + msg.msg_namelen = 0; + iov[0].iov_base = buf; + iov[0].iov_len = strlen(buf); + msg.msg_iov = iov; + msg.msg_iovlen = 1; + msg.msg_control = 0; + msg.msg_controllen = 0; + msg.msg_flags = 0; + + ret = sendmsg(server_fd, &msg, 0); + if (ret <= 0) { + THROW_ERROR("sendmsg failed"); + } + + msg.msg_iov = NULL; + msg.msg_iovlen = 0; + + ret = sendmsg(server_fd, &msg, 0); + if (ret != 0) { + THROW_ERROR("empty sendmsg failed"); + } + return ret; +} + + +int server_recvmsg(int client_fd) { + int ret = 0; + const int buf_size = 10; + char buf[3][buf_size]; + struct msghdr msg; + struct iovec iov[3]; + char result_buf[] = {ECHO_MSG}; // 30 bytes + + memset(&msg, 0, sizeof(struct msghdr)); + msg.msg_name = NULL; + msg.msg_namelen = 0; + iov[0].iov_base = buf[0]; + iov[0].iov_len = buf_size; + iov[1].iov_base = buf[1]; + iov[1].iov_len = buf_size; + iov[2].iov_base = buf[2]; + iov[2].iov_len = buf_size; + msg.msg_iov = iov; + msg.msg_iovlen = 3; + msg.msg_control = 0; + msg.msg_controllen = 0; + msg.msg_flags = 0; + + ret = recvmsg(client_fd, &msg, 0); + if (ret <= 0) { + THROW_ERROR("recvmsg failed"); + } else { + if (strncmp(buf[0], result_buf, buf_size) != 0 && + strncmp(buf[1], result_buf + buf_size, buf_size) != 0 && + strncmp(buf[0], result_buf + buf_size * 2, buf_size) != 0) { + printf("recvmsg : %d, msg: %s, %s, %s\n", ret, buf[0], buf[1], buf[2]); + THROW_ERROR("msg recvmsg mismatch"); + } + } + return ret; +} + +int test_sendmsg_recvmsg() { + int ret = 0; + char test_buf[] = {ECHO_MSG}; + int socks[2]; + + ret = socketpair(AF_UNIX, SOCK_STREAM, 0, socks); + if (ret < 0) { + THROW_ERROR("socket pair create failed"); + } + + int server_fd = socks[0]; + int client_fd = socks[1]; + + ret = client_sendmsg(server_fd, test_buf); + if (ret < 0) { + THROW_ERROR("client_sendmsg failed"); + } + + ret = server_recvmsg(client_fd); + if (ret < 0) { + THROW_ERROR("server_recvmsg failed"); + } + + return ret; +} + static test_case_t test_cases[] = { TEST_CASE(test_unix_socket_inter_process), TEST_CASE(test_socketpair_inter_process), @@ -474,6 +567,7 @@ static test_case_t test_cases[] = { TEST_CASE(test_ioctl_fionread), TEST_CASE(test_unix_socket_rename), TEST_CASE(test_epoll_wait), + TEST_CASE(test_sendmsg_recvmsg), }; int main(int argc, const char *argv[]) {