Add sendmsg/recvmsg support for unix domain socket

This commit is contained in:
Hui, Chunyang 2022-08-17 08:15:17 +00:00 committed by volcano
parent 71c4937b45
commit 6cb9ca7e44
3 changed files with 153 additions and 24 deletions

@ -5,6 +5,7 @@ use events::{Event, EventFilter, Notifier, Observer};
use fs::channel::Channel; use fs::channel::Channel;
use fs::IoEvents; use fs::IoEvents;
use fs::{CreationFlags, FileMode}; use fs::{CreationFlags, FileMode};
use net::socket::{Iovs, MsgHdr, MsgHdrMut};
use std::fmt; use std::fmt;
use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc; use std::sync::Arc;
@ -223,6 +224,32 @@ impl Stream {
Ok((data_len, addr)) Ok((data_len, addr))
} }
pub fn sendmsg(&self, msg_hdr: &MsgHdr, flags: SendFlags) -> Result<usize> {
if !flags.is_empty() {
warn!("unsupported flags: {:?}", flags);
}
if msg_hdr.get_control().is_some() {
warn!("sendmsg with msg_control is not supported");
}
let bufs = msg_hdr.get_iovs().as_slices();
self.writev(bufs)
}
pub fn recvmsg(&self, msg_hdr: &mut MsgHdrMut, flags: RecvFlags) -> Result<usize> {
if !flags.is_empty() {
warn!("unsupported flags: {:?}", flags);
}
let bufs = msg_hdr.get_iovs_mut().as_slices_mut();
let data_len = self.readv(bufs)?;
// For stream socket, the msg_name is ignored. And other fields are not supported.
msg_hdr.set_name_len(0);
Ok(data_len)
}
/// perform shutdown on the socket. /// perform shutdown on the socket.
pub fn shutdown(&self, how: HowToShut) -> Result<()> { pub fn shutdown(&self, how: HowToShut) -> Result<()> {
if let Status::Connected(ref end) = &*self.inner() { if let Status::Connected(ref end) = &*self.inner() {

@ -489,23 +489,27 @@ pub fn do_sendmsg(fd: c_int, msg_ptr: *const msghdr, flags_c: c_int) -> Result<i
fd, msg_ptr, flags_c fd, msg_ptr, flags_c
); );
let file_ref = current!().file(fd as FileDesc)?; let msg_hdr = {
if let Ok(socket) = file_ref.as_host_socket() { let msg_hdr_c = {
let msg_c = {
from_user::check_ptr(msg_ptr)?; from_user::check_ptr(msg_ptr)?;
let msg_c = unsafe { &*msg_ptr }; let msg_hdr_c = unsafe { &*msg_ptr };
msg_c.check_member_ptrs()?; msg_hdr_c.check_member_ptrs()?;
msg_c msg_hdr_c
};
unsafe { MsgHdr::from_c(&msg_hdr_c)? }
}; };
let msg = unsafe { MsgHdr::from_c(&msg_c)? };
let flags = SendFlags::from_bits_truncate(flags_c); let flags = SendFlags::from_bits_truncate(flags_c);
let file_ref = current!().file(fd as FileDesc)?;
if let Ok(socket) = file_ref.as_host_socket() {
socket socket
.sendmsg(&msg, flags) .sendmsg(&msg_hdr, flags)
.map(|bytes_sent| bytes_sent as isize) .map(|bytes_sent| bytes_sent as isize)
} else if let Ok(socket) = file_ref.as_unix_socket() { } else if let Ok(socket) = file_ref.as_unix_socket() {
return_errno!(EOPNOTSUPP, "does not support unix socket") socket
.sendmsg(&msg_hdr, flags)
.map(|bytes_sent| bytes_sent as isize)
} else { } else {
return_errno!(ENOTSOCK, "not a socket") return_errno!(ENOTSOCK, "not a socket")
} }
@ -517,23 +521,27 @@ pub fn do_recvmsg(fd: c_int, msg_mut_ptr: *mut msghdr_mut, flags_c: c_int) -> Re
fd, msg_mut_ptr, flags_c fd, msg_mut_ptr, flags_c
); );
let file_ref = current!().file(fd as FileDesc)?; let mut msg_hdr_mut = {
if let Ok(socket) = file_ref.as_host_socket() { let msg_hdr_mut_c = {
let msg_mut_c = {
from_user::check_mut_ptr(msg_mut_ptr)?; from_user::check_mut_ptr(msg_mut_ptr)?;
let msg_mut_c = unsafe { &mut *msg_mut_ptr }; let msg_hdr_mut_c = unsafe { &mut *msg_mut_ptr };
msg_mut_c.check_member_ptrs()?; msg_hdr_mut_c.check_member_ptrs()?;
msg_mut_c msg_hdr_mut_c
};
unsafe { MsgHdrMut::from_c(msg_hdr_mut_c) }?
}; };
let mut msg_mut = unsafe { MsgHdrMut::from_c(msg_mut_c)? };
let flags = RecvFlags::from_bits_truncate(flags_c); let flags = RecvFlags::from_bits_truncate(flags_c);
let file_ref = current!().file(fd as FileDesc)?;
if let Ok(socket) = file_ref.as_host_socket() {
socket socket
.recvmsg(&mut msg_mut, flags) .recvmsg(&mut msg_hdr_mut, flags)
.map(|bytes_recvd| bytes_recvd as isize) .map(|bytes_recvd| bytes_recvd as isize)
} else if let Ok(socket) = file_ref.as_unix_socket() { } else if let Ok(socket) = file_ref.as_unix_socket() {
return_errno!(EOPNOTSUPP, "does not support unix socket") socket
.recvmsg(&mut msg_hdr_mut, flags)
.map(|bytes_recvd| bytes_recvd as isize)
} else { } else {
return_errno!(ENOTSOCK, "not a socket") return_errno!(ENOTSOCK, "not a socket")
} }

@ -465,6 +465,99 @@ int test_epoll_wait() {
return 0; return 0;
} }
int client_sendmsg(int server_fd, char *buf) {
int ret = 0;
struct msghdr msg;
struct iovec iov[1];
msg.msg_name = NULL;
msg.msg_namelen = 0;
iov[0].iov_base = buf;
iov[0].iov_len = strlen(buf);
msg.msg_iov = iov;
msg.msg_iovlen = 1;
msg.msg_control = 0;
msg.msg_controllen = 0;
msg.msg_flags = 0;
ret = sendmsg(server_fd, &msg, 0);
if (ret <= 0) {
THROW_ERROR("sendmsg failed");
}
msg.msg_iov = NULL;
msg.msg_iovlen = 0;
ret = sendmsg(server_fd, &msg, 0);
if (ret != 0) {
THROW_ERROR("empty sendmsg failed");
}
return ret;
}
int server_recvmsg(int client_fd) {
int ret = 0;
const int buf_size = 10;
char buf[3][buf_size];
struct msghdr msg;
struct iovec iov[3];
char result_buf[] = {ECHO_MSG}; // 30 bytes
memset(&msg, 0, sizeof(struct msghdr));
msg.msg_name = NULL;
msg.msg_namelen = 0;
iov[0].iov_base = buf[0];
iov[0].iov_len = buf_size;
iov[1].iov_base = buf[1];
iov[1].iov_len = buf_size;
iov[2].iov_base = buf[2];
iov[2].iov_len = buf_size;
msg.msg_iov = iov;
msg.msg_iovlen = 3;
msg.msg_control = 0;
msg.msg_controllen = 0;
msg.msg_flags = 0;
ret = recvmsg(client_fd, &msg, 0);
if (ret <= 0) {
THROW_ERROR("recvmsg failed");
} else {
if (strncmp(buf[0], result_buf, buf_size) != 0 &&
strncmp(buf[1], result_buf + buf_size, buf_size) != 0 &&
strncmp(buf[0], result_buf + buf_size * 2, buf_size) != 0) {
printf("recvmsg : %d, msg: %s, %s, %s\n", ret, buf[0], buf[1], buf[2]);
THROW_ERROR("msg recvmsg mismatch");
}
}
return ret;
}
int test_sendmsg_recvmsg() {
int ret = 0;
char test_buf[] = {ECHO_MSG};
int socks[2];
ret = socketpair(AF_UNIX, SOCK_STREAM, 0, socks);
if (ret < 0) {
THROW_ERROR("socket pair create failed");
}
int server_fd = socks[0];
int client_fd = socks[1];
ret = client_sendmsg(server_fd, test_buf);
if (ret < 0) {
THROW_ERROR("client_sendmsg failed");
}
ret = server_recvmsg(client_fd);
if (ret < 0) {
THROW_ERROR("server_recvmsg failed");
}
return ret;
}
static test_case_t test_cases[] = { static test_case_t test_cases[] = {
TEST_CASE(test_unix_socket_inter_process), TEST_CASE(test_unix_socket_inter_process),
TEST_CASE(test_socketpair_inter_process), TEST_CASE(test_socketpair_inter_process),
@ -474,6 +567,7 @@ static test_case_t test_cases[] = {
TEST_CASE(test_ioctl_fionread), TEST_CASE(test_ioctl_fionread),
TEST_CASE(test_unix_socket_rename), TEST_CASE(test_unix_socket_rename),
TEST_CASE(test_epoll_wait), TEST_CASE(test_epoll_wait),
TEST_CASE(test_sendmsg_recvmsg),
}; };
int main(int argc, const char *argv[]) { int main(int argc, const char *argv[]) {