Add sendmsg/recvmsg support for unix domain socket
This commit is contained in:
		
							parent
							
								
									71c4937b45
								
							
						
					
					
						commit
						6cb9ca7e44
					
				| @ -5,6 +5,7 @@ use events::{Event, EventFilter, Notifier, Observer}; | |||||||
| use fs::channel::Channel; | use fs::channel::Channel; | ||||||
| use fs::IoEvents; | use fs::IoEvents; | ||||||
| use fs::{CreationFlags, FileMode}; | use fs::{CreationFlags, FileMode}; | ||||||
|  | use net::socket::{Iovs, MsgHdr, MsgHdrMut}; | ||||||
| use std::fmt; | use std::fmt; | ||||||
| use std::sync::atomic::{AtomicBool, Ordering}; | use std::sync::atomic::{AtomicBool, Ordering}; | ||||||
| use std::sync::Arc; | use std::sync::Arc; | ||||||
| @ -223,6 +224,32 @@ impl Stream { | |||||||
|         Ok((data_len, addr)) |         Ok((data_len, addr)) | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |     pub fn sendmsg(&self, msg_hdr: &MsgHdr, flags: SendFlags) -> Result<usize> { | ||||||
|  |         if !flags.is_empty() { | ||||||
|  |             warn!("unsupported flags: {:?}", flags); | ||||||
|  |         } | ||||||
|  |         if msg_hdr.get_control().is_some() { | ||||||
|  |             warn!("sendmsg with msg_control is not supported"); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         let bufs = msg_hdr.get_iovs().as_slices(); | ||||||
|  |         self.writev(bufs) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn recvmsg(&self, msg_hdr: &mut MsgHdrMut, flags: RecvFlags) -> Result<usize> { | ||||||
|  |         if !flags.is_empty() { | ||||||
|  |             warn!("unsupported flags: {:?}", flags); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         let bufs = msg_hdr.get_iovs_mut().as_slices_mut(); | ||||||
|  |         let data_len = self.readv(bufs)?; | ||||||
|  | 
 | ||||||
|  |         // For stream socket, the msg_name is ignored. And other fields are not supported.
 | ||||||
|  |         msg_hdr.set_name_len(0); | ||||||
|  | 
 | ||||||
|  |         Ok(data_len) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     /// perform shutdown on the socket.
 |     /// perform shutdown on the socket.
 | ||||||
|     pub fn shutdown(&self, how: HowToShut) -> Result<()> { |     pub fn shutdown(&self, how: HowToShut) -> Result<()> { | ||||||
|         if let Status::Connected(ref end) = &*self.inner() { |         if let Status::Connected(ref end) = &*self.inner() { | ||||||
|  | |||||||
| @ -489,23 +489,27 @@ pub fn do_sendmsg(fd: c_int, msg_ptr: *const msghdr, flags_c: c_int) -> Result<i | |||||||
|         fd, msg_ptr, flags_c |         fd, msg_ptr, flags_c | ||||||
|     ); |     ); | ||||||
| 
 | 
 | ||||||
|  |     let msg_hdr = { | ||||||
|  |         let msg_hdr_c = { | ||||||
|  |             from_user::check_ptr(msg_ptr)?; | ||||||
|  |             let msg_hdr_c = unsafe { &*msg_ptr }; | ||||||
|  |             msg_hdr_c.check_member_ptrs()?; | ||||||
|  |             msg_hdr_c | ||||||
|  |         }; | ||||||
|  |         unsafe { MsgHdr::from_c(&msg_hdr_c)? } | ||||||
|  |     }; | ||||||
|  | 
 | ||||||
|  |     let flags = SendFlags::from_bits_truncate(flags_c); | ||||||
|  | 
 | ||||||
|     let file_ref = current!().file(fd as FileDesc)?; |     let file_ref = current!().file(fd as FileDesc)?; | ||||||
|     if let Ok(socket) = file_ref.as_host_socket() { |     if let Ok(socket) = file_ref.as_host_socket() { | ||||||
|         let msg_c = { |  | ||||||
|             from_user::check_ptr(msg_ptr)?; |  | ||||||
|             let msg_c = unsafe { &*msg_ptr }; |  | ||||||
|             msg_c.check_member_ptrs()?; |  | ||||||
|             msg_c |  | ||||||
|         }; |  | ||||||
|         let msg = unsafe { MsgHdr::from_c(&msg_c)? }; |  | ||||||
| 
 |  | ||||||
|         let flags = SendFlags::from_bits_truncate(flags_c); |  | ||||||
| 
 |  | ||||||
|         socket |         socket | ||||||
|             .sendmsg(&msg, flags) |             .sendmsg(&msg_hdr, flags) | ||||||
|             .map(|bytes_sent| bytes_sent as isize) |             .map(|bytes_sent| bytes_sent as isize) | ||||||
|     } else if let Ok(socket) = file_ref.as_unix_socket() { |     } else if let Ok(socket) = file_ref.as_unix_socket() { | ||||||
|         return_errno!(EOPNOTSUPP, "does not support unix socket") |         socket | ||||||
|  |             .sendmsg(&msg_hdr, flags) | ||||||
|  |             .map(|bytes_sent| bytes_sent as isize) | ||||||
|     } else { |     } else { | ||||||
|         return_errno!(ENOTSOCK, "not a socket") |         return_errno!(ENOTSOCK, "not a socket") | ||||||
|     } |     } | ||||||
| @ -517,23 +521,27 @@ pub fn do_recvmsg(fd: c_int, msg_mut_ptr: *mut msghdr_mut, flags_c: c_int) -> Re | |||||||
|         fd, msg_mut_ptr, flags_c |         fd, msg_mut_ptr, flags_c | ||||||
|     ); |     ); | ||||||
| 
 | 
 | ||||||
|  |     let mut msg_hdr_mut = { | ||||||
|  |         let msg_hdr_mut_c = { | ||||||
|  |             from_user::check_mut_ptr(msg_mut_ptr)?; | ||||||
|  |             let msg_hdr_mut_c = unsafe { &mut *msg_mut_ptr }; | ||||||
|  |             msg_hdr_mut_c.check_member_ptrs()?; | ||||||
|  |             msg_hdr_mut_c | ||||||
|  |         }; | ||||||
|  |         unsafe { MsgHdrMut::from_c(msg_hdr_mut_c) }? | ||||||
|  |     }; | ||||||
|  | 
 | ||||||
|  |     let flags = RecvFlags::from_bits_truncate(flags_c); | ||||||
|  | 
 | ||||||
|     let file_ref = current!().file(fd as FileDesc)?; |     let file_ref = current!().file(fd as FileDesc)?; | ||||||
|     if let Ok(socket) = file_ref.as_host_socket() { |     if let Ok(socket) = file_ref.as_host_socket() { | ||||||
|         let msg_mut_c = { |  | ||||||
|             from_user::check_mut_ptr(msg_mut_ptr)?; |  | ||||||
|             let msg_mut_c = unsafe { &mut *msg_mut_ptr }; |  | ||||||
|             msg_mut_c.check_member_ptrs()?; |  | ||||||
|             msg_mut_c |  | ||||||
|         }; |  | ||||||
|         let mut msg_mut = unsafe { MsgHdrMut::from_c(msg_mut_c)? }; |  | ||||||
| 
 |  | ||||||
|         let flags = RecvFlags::from_bits_truncate(flags_c); |  | ||||||
| 
 |  | ||||||
|         socket |         socket | ||||||
|             .recvmsg(&mut msg_mut, flags) |             .recvmsg(&mut msg_hdr_mut, flags) | ||||||
|             .map(|bytes_recvd| bytes_recvd as isize) |             .map(|bytes_recvd| bytes_recvd as isize) | ||||||
|     } else if let Ok(socket) = file_ref.as_unix_socket() { |     } else if let Ok(socket) = file_ref.as_unix_socket() { | ||||||
|         return_errno!(EOPNOTSUPP, "does not support unix socket") |         socket | ||||||
|  |             .recvmsg(&mut msg_hdr_mut, flags) | ||||||
|  |             .map(|bytes_recvd| bytes_recvd as isize) | ||||||
|     } else { |     } else { | ||||||
|         return_errno!(ENOTSOCK, "not a socket") |         return_errno!(ENOTSOCK, "not a socket") | ||||||
|     } |     } | ||||||
|  | |||||||
| @ -465,6 +465,99 @@ int test_epoll_wait() { | |||||||
|     return 0; |     return 0; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | int client_sendmsg(int server_fd, char *buf) { | ||||||
|  |     int ret = 0; | ||||||
|  |     struct msghdr msg; | ||||||
|  |     struct iovec iov[1]; | ||||||
|  |     msg.msg_name = NULL; | ||||||
|  |     msg.msg_namelen = 0; | ||||||
|  |     iov[0].iov_base = buf; | ||||||
|  |     iov[0].iov_len = strlen(buf); | ||||||
|  |     msg.msg_iov = iov; | ||||||
|  |     msg.msg_iovlen = 1; | ||||||
|  |     msg.msg_control = 0; | ||||||
|  |     msg.msg_controllen = 0; | ||||||
|  |     msg.msg_flags = 0; | ||||||
|  | 
 | ||||||
|  |     ret = sendmsg(server_fd, &msg, 0); | ||||||
|  |     if (ret <= 0) { | ||||||
|  |         THROW_ERROR("sendmsg failed"); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     msg.msg_iov = NULL; | ||||||
|  |     msg.msg_iovlen = 0; | ||||||
|  | 
 | ||||||
|  |     ret = sendmsg(server_fd, &msg, 0); | ||||||
|  |     if (ret != 0) { | ||||||
|  |         THROW_ERROR("empty sendmsg failed"); | ||||||
|  |     } | ||||||
|  |     return ret; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | int server_recvmsg(int client_fd) { | ||||||
|  |     int ret = 0; | ||||||
|  |     const int buf_size = 10; | ||||||
|  |     char buf[3][buf_size]; | ||||||
|  |     struct msghdr msg; | ||||||
|  |     struct iovec iov[3]; | ||||||
|  |     char result_buf[] = {ECHO_MSG}; // 30 bytes
 | ||||||
|  | 
 | ||||||
|  |     memset(&msg, 0, sizeof(struct msghdr)); | ||||||
|  |     msg.msg_name  = NULL; | ||||||
|  |     msg.msg_namelen  = 0; | ||||||
|  |     iov[0].iov_base = buf[0]; | ||||||
|  |     iov[0].iov_len = buf_size; | ||||||
|  |     iov[1].iov_base = buf[1]; | ||||||
|  |     iov[1].iov_len = buf_size; | ||||||
|  |     iov[2].iov_base = buf[2]; | ||||||
|  |     iov[2].iov_len = buf_size; | ||||||
|  |     msg.msg_iov = iov; | ||||||
|  |     msg.msg_iovlen = 3; | ||||||
|  |     msg.msg_control = 0; | ||||||
|  |     msg.msg_controllen = 0; | ||||||
|  |     msg.msg_flags = 0; | ||||||
|  | 
 | ||||||
|  |     ret = recvmsg(client_fd, &msg, 0); | ||||||
|  |     if (ret <= 0) { | ||||||
|  |         THROW_ERROR("recvmsg failed"); | ||||||
|  |     } else { | ||||||
|  |         if (strncmp(buf[0], result_buf, buf_size) != 0 && | ||||||
|  |                 strncmp(buf[1], result_buf + buf_size, buf_size) != 0 && | ||||||
|  |                 strncmp(buf[0], result_buf + buf_size * 2, buf_size) != 0) { | ||||||
|  |             printf("recvmsg : %d, msg: %s,  %s, %s\n", ret, buf[0], buf[1], buf[2]); | ||||||
|  |             THROW_ERROR("msg recvmsg mismatch"); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |     return ret; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | int test_sendmsg_recvmsg() { | ||||||
|  |     int ret = 0; | ||||||
|  |     char test_buf[] = {ECHO_MSG}; | ||||||
|  |     int socks[2]; | ||||||
|  | 
 | ||||||
|  |     ret = socketpair(AF_UNIX, SOCK_STREAM, 0, socks); | ||||||
|  |     if (ret < 0) { | ||||||
|  |         THROW_ERROR("socket pair create failed"); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     int server_fd = socks[0]; | ||||||
|  |     int client_fd = socks[1]; | ||||||
|  | 
 | ||||||
|  |     ret = client_sendmsg(server_fd, test_buf); | ||||||
|  |     if (ret < 0) { | ||||||
|  |         THROW_ERROR("client_sendmsg failed"); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     ret = server_recvmsg(client_fd); | ||||||
|  |     if (ret < 0) { | ||||||
|  |         THROW_ERROR("server_recvmsg failed"); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     return ret; | ||||||
|  | } | ||||||
|  | 
 | ||||||
| static test_case_t test_cases[] = { | static test_case_t test_cases[] = { | ||||||
|     TEST_CASE(test_unix_socket_inter_process), |     TEST_CASE(test_unix_socket_inter_process), | ||||||
|     TEST_CASE(test_socketpair_inter_process), |     TEST_CASE(test_socketpair_inter_process), | ||||||
| @ -474,6 +567,7 @@ static test_case_t test_cases[] = { | |||||||
|     TEST_CASE(test_ioctl_fionread), |     TEST_CASE(test_ioctl_fionread), | ||||||
|     TEST_CASE(test_unix_socket_rename), |     TEST_CASE(test_unix_socket_rename), | ||||||
|     TEST_CASE(test_epoll_wait), |     TEST_CASE(test_epoll_wait), | ||||||
|  |     TEST_CASE(test_sendmsg_recvmsg), | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| int main(int argc, const char *argv[]) { | int main(int argc, const char *argv[]) { | ||||||
|  | |||||||
		Loading…
	
		Reference in New Issue
	
	Block a user