diff --git a/src/libos/src/net/syscalls.rs b/src/libos/src/net/syscalls.rs index c3a5381e..11786739 100644 --- a/src/libos/src/net/syscalls.rs +++ b/src/libos/src/net/syscalls.rs @@ -278,21 +278,37 @@ pub fn do_sendto( addr_len: libc::socklen_t, ) -> Result { debug!( - "sendto: fd: {}, base: {:?}, len: {}, addr: {:?}, addr_len: {}", - fd, base, len, addr, addr_len + "sendto: fd: {}, base: {:?}, len: {}, flags: {} addr: {:?}, addr_len: {}", + fd, base, len, flags, addr, addr_len ); - let file_ref = current!().file(fd as FileDesc)?; - let socket = file_ref.as_socket()?; + from_user::check_array(base as *const u8, len)?; - let ret = try_libc!(libc::ocall::sendto( - socket.fd(), - base, - len, - flags, - addr, - addr_len - )); - Ok(ret as isize) + let file_ref = current!().file(fd as FileDesc)?; + if let Ok(socket) = file_ref.as_socket() { + // TODO: check addr and addr_len according to connection mode + let ret = try_libc!(libc::ocall::sendto( + socket.fd(), + base, + len, + flags, + addr, + addr_len + )); + Ok(ret as isize) + } else if let Ok(unix) = file_ref.as_unix_socket() { + if !addr.is_null() || addr_len != 0 { + return_errno!(EISCONN, "Only connection-mode socket is supported"); + } + + if !unix.is_connected() { + return_errno!(ENOTCONN, "the socket has not been connected yet"); + } + + let data = unsafe { std::slice::from_raw_parts(base as *const u8, len) }; + unix.write(data).map(|u| u as isize) + } else { + return_errno!(EBADF, "unsupported file type"); + } } pub fn do_recvfrom( diff --git a/src/libos/src/net/unix_socket.rs b/src/libos/src/net/unix_socket.rs index ba5fb837..8d3e8a8d 100644 --- a/src/libos/src/net/unix_socket.rs +++ b/src/libos/src/net/unix_socket.rs @@ -150,6 +150,14 @@ impl UnixSocketFile { } } } + + pub fn is_connected(&self) -> bool { + if let Status::Connected(_) = self.inner.lock().unwrap().status { + true + } else { + false + } + } } impl Debug for UnixSocketFile { diff --git a/test/unix_socket/main.c b/test/unix_socket/main.c index a98d93ba..8c830d92 100644 --- a/test/unix_socket/main.c +++ b/test/unix_socket/main.c @@ -101,8 +101,14 @@ int verify_connection(int src_sock, int dest_sock) { char buf[1024]; int i; for (i = 0; i < 100; i++) { - if (write(src_sock, ECHO_MSG, sizeof(ECHO_MSG)) < 0) { - THROW_ERROR("writing server message"); + if (i % 2 == 0) { + if (write(src_sock, ECHO_MSG, sizeof(ECHO_MSG)) < 0) { + THROW_ERROR("writing server message"); + } + } else { + if (sendto(src_sock, ECHO_MSG, sizeof(ECHO_MSG), 0, NULL, 0) < 0) { + THROW_ERROR("sendto server message"); + } } if (read(dest_sock, buf, 1024) < 0) {