diff --git a/src/libos/src/net/socket/unix/addr.rs b/src/libos/src/net/socket/unix/addr.rs index 397b2851..c79a6f25 100644 --- a/src/libos/src/net/socket/unix/addr.rs +++ b/src/libos/src/net/socket/unix/addr.rs @@ -10,7 +10,7 @@ lazy_static! { #[derive(Clone, Debug, Eq, PartialEq)] pub enum Addr { - File(UnixPath), + File(Option, UnixPath), // An optional inode number and path. Use inode if there is one. Abstract(String), } @@ -54,7 +54,7 @@ impl Addr { return_errno!(EINVAL, "no null in the address"); } - Ok(Self::File(UnixPath::new(path_cstr.to_str().unwrap()))) + Ok(Self::File(None, UnixPath::new(path_cstr.to_str().unwrap()))) } } @@ -76,7 +76,7 @@ impl Addr { pub fn path_str(&self) -> &str { match self { - Self::File(unix_path) => &unix_path.path_str(), + Self::File(_, unix_path) => &unix_path.path_str(), Self::Abstract(path) => &path, } } @@ -86,7 +86,7 @@ impl Addr { addr.sun_family = AddressFamily::LOCAL as libc::sa_family_t; let addr_len = match self { - Self::File(unix_path) => { + Self::File(_, unix_path) => { let path_str = unix_path.path_str(); let buf_len = path_str.len(); /// addr is initialized to all zeros and try_from_raw guarentees diff --git a/src/libos/src/net/socket/unix/stream/address_space.rs b/src/libos/src/net/socket/unix/stream/address_space.rs index b9055145..7cbfc14e 100644 --- a/src/libos/src/net/socket/unix/stream/address_space.rs +++ b/src/libos/src/net/socket/unix/stream/address_space.rs @@ -7,9 +7,27 @@ lazy_static! { pub(super) static ref ADDRESS_SPACE: AddressSpace = AddressSpace::new(); } +#[derive(PartialEq, Eq, PartialOrd, Ord)] +pub enum AddressSpaceKey { + FileKey(usize), + AbstrKey(String), +} + +impl AddressSpaceKey { + pub fn from_inode(inode: usize) -> Self { + AddressSpaceKey::FileKey(inode) + } + + pub fn from_path(path: String) -> Self { + AddressSpaceKey::AbstrKey(path) + } +} + pub struct AddressSpace { - file: SgxMutex>>>, - abstr: SgxMutex>>>, + // For "file", use inode number as "key" instead of path string so that listeners can still + // be reached even if the socket file is moved or renamed. + file: SgxMutex>>>, + abstr: SgxMutex>>>, } impl AddressSpace { @@ -21,7 +39,7 @@ impl AddressSpace { } pub fn add_binder(&self, addr: &Addr) -> Result<()> { - let key = Self::get_key(addr); + let key = Self::get_key(addr).ok_or_else(|| errno!(EINVAL, "can't find socket file"))?; let mut space = self.get_space(addr); if space.contains_key(&key) { return_errno!(EADDRINUSE, "the addr is already bound"); @@ -32,7 +50,7 @@ impl AddressSpace { } pub fn add_listener(&self, addr: &Addr, capacity: usize, nonblocking: bool) -> Result<()> { - let key = Self::get_key(addr); + let key = Self::get_key(addr).ok_or_else(|| errno!(EINVAL, "the socket is not bound"))?; let mut space = self.get_space(addr); if let Some(option) = space.get(&key) { @@ -48,7 +66,7 @@ impl AddressSpace { } pub fn resize_listener(&self, addr: &Addr, capacity: usize) -> Result<()> { - let key = Self::get_key(addr); + let key = Self::get_key(addr).ok_or_else(|| errno!(EINVAL, "the socket is not bound"))?; let mut space = self.get_space(addr); if let Some(option) = space.get(&key) { @@ -78,27 +96,54 @@ impl AddressSpace { pub fn get_listener_ref(&self, addr: &Addr) -> Option> { let key = Self::get_key(addr); - let space = self.get_space(addr); - space.get(&key).map(|x| x.clone()).flatten() + if let Some(key) = key { + let space = self.get_space(addr); + space.get(&key).map(|x| x.clone()).flatten() + } else { + None + } } pub fn remove_addr(&self, addr: &Addr) { let key = Self::get_key(addr); - let mut space = self.get_space(addr); - space.remove(&key); - } - - fn get_space(&self, addr: &Addr) -> SgxMutexGuard<'_, BTreeMap>>> { - match addr { - Addr::File(unix_path) => self.file.lock().unwrap(), - Addr::Abstract(path) => self.abstr.lock().unwrap(), + if let Some(key) = key { + let mut space = self.get_space(addr); + space.remove(&key); + } else { + warn!("address space key not exit: {:?}", addr); } } - fn get_key(addr: &Addr) -> String { + fn get_space( + &self, + addr: &Addr, + ) -> SgxMutexGuard<'_, BTreeMap>>> { match addr { - Addr::File(unix_path) => unix_path.absolute(), - Addr::Abstract(path) => addr.path_str().to_string(), + Addr::File(_, _) => self.file.lock().unwrap(), + Addr::Abstract(_) => self.abstr.lock().unwrap(), + } + } + + fn get_key(addr: &Addr) -> Option { + trace!("addr = {:?}", addr); + match addr { + Addr::File(inode_num, unix_path) if inode_num.is_some() => { + Some(AddressSpaceKey::from_inode(inode_num.unwrap())) + } + Addr::File(_, unix_path) => { + let inode = { + let file_path = unix_path.absolute(); + let current = current!(); + let fs = current.fs().read().unwrap(); + fs.lookup_inode(&file_path) + }; + if let Ok(inode) = inode { + Some(AddressSpaceKey::from_inode(inode.metadata().unwrap().inode)) + } else { + None + } + } + Addr::Abstract(path) => Some(AddressSpaceKey::from_path(addr.path_str().to_string())), } } } diff --git a/src/libos/src/net/socket/unix/stream/stream.rs b/src/libos/src/net/socket/unix/stream/stream.rs index b5923cc9..8c037c6b 100644 --- a/src/libos/src/net/socket/unix/stream/stream.rs +++ b/src/libos/src/net/socket/unix/stream/stream.rs @@ -3,6 +3,7 @@ use super::endpoint::{end_pair, Endpoint, RelayNotifier}; use super::*; use events::{Event, EventFilter, Notifier, Observer}; use fs::channel::Channel; +use fs::CreationFlags; use fs::IoEvents; use std::fmt; use std::sync::atomic::{AtomicBool, Ordering}; @@ -68,8 +69,19 @@ impl Stream { return_errno!(ENOTCONN, "the socket is not connected"); } - // TODO: create the corresponding file in the fs - pub fn bind(&self, addr: &Addr) -> Result<()> { + pub fn bind(&self, addr: &mut Addr) -> Result<()> { + if let Addr::File(inode_num, path) = addr { + // create the corresponding file in the fs and fill Addr with its inode + let corresponding_inode_num = { + let current = current!(); + let fs = current.fs().read().unwrap(); + let file_ref = + fs.open_file(path.path_str(), CreationFlags::O_CREAT.bits(), 0o777)?; + file_ref.metadata()?.inode + }; + *inode_num = Some(corresponding_inode_num); + } + match &mut *self.inner() { Status::Idle(ref mut info) => { if info.addr().is_some() { @@ -105,6 +117,7 @@ impl Stream { match &*inner { Status::Idle(info) => { if let Some(addr) = info.addr() { + warn!("addr = {:?}", addr); ADDRESS_SPACE.add_listener(addr, capacity, info.nonblocking())?; *inner = Status::Listening(addr.clone()); } else { diff --git a/src/libos/src/net/syscalls.rs b/src/libos/src/net/syscalls.rs index 2e114955..042e73c9 100644 --- a/src/libos/src/net/syscalls.rs +++ b/src/libos/src/net/syscalls.rs @@ -44,9 +44,9 @@ pub fn do_bind(fd: c_int, addr: *const libc::sockaddr, addr_len: libc::socklen_t trace!("bind to addr: {:?}", sock_addr); socket.bind(&sock_addr)?; } else if let Ok(unix_socket) = file_ref.as_unix_socket() { - let unix_addr = unsafe { UnixAddr::try_from_raw(addr, addr_len)? }; + let mut unix_addr = unsafe { UnixAddr::try_from_raw(addr, addr_len)? }; trace!("bind to addr: {:?}", unix_addr); - unix_socket.bind(&unix_addr)?; + unix_socket.bind(&mut unix_addr)?; } else { return_errno!(ENOTSOCK, "not a socket"); } @@ -231,16 +231,19 @@ pub fn do_getsockopt( fd, level, optname, optval, optlen ); let file_ref = current!().file(fd as FileDesc)?; - let socket = file_ref.as_host_socket()?; - - let ret = try_libc!(libc::ocall::getsockopt( - socket.raw_host_fd() as i32, - level, - optname, - optval, - optlen - )); - Ok(ret as isize) + if let Ok(socket) = file_ref.as_host_socket() { + let ret = try_libc!(libc::ocall::getsockopt( + socket.raw_host_fd() as i32, + level, + optname, + optval, + optlen + )); + Ok(ret as isize) + } else { + warn!("getsockeopt is not implemented for non-host socket."); + Ok(0 as isize) + } } pub fn do_getpeername( diff --git a/test/unix_socket/main.c b/test/unix_socket/main.c index 3ca99f9a..bab423ed 100644 --- a/test/unix_socket/main.c +++ b/test/unix_socket/main.c @@ -25,7 +25,7 @@ int create_connected_sockets(int *sockets, char *sock_path) { memset(&addr, 0, sizeof(struct sockaddr_un)); //Clear structure addr.sun_family = AF_UNIX; strcpy(addr.sun_path, sock_path); - socklen_t addr_len = strlen(addr.sun_path) + sizeof(addr.sun_family); + socklen_t addr_len = strlen(addr.sun_path) + sizeof(addr.sun_family) + 1; if (bind(listen_fd, (struct sockaddr *)&addr, addr_len) == -1) { close(listen_fd); THROW_ERROR("failed to bind"); @@ -36,7 +36,7 @@ int create_connected_sockets(int *sockets, char *sock_path) { THROW_ERROR("failed to listen"); } - int client_fd = socket(AF_UNIX, SOCK_STREAM, PF_UNIX); + int client_fd = socket(AF_UNIX, SOCK_STREAM, 0); if (client_fd == -1) { close(listen_fd); THROW_ERROR("failed to create a unix socket"); @@ -65,6 +65,75 @@ int create_connceted_sockets_default(int *sockets) { return create_connected_sockets(sockets, "unix_socket_default_path"); } +int create_connected_sockets_then_rename(int *sockets) { + char *socket_original_path = "/tmp/socket_tmp"; + char *socket_ready_path = "/tmp/.socket_tmp"; + int listen_fd = socket(AF_UNIX, SOCK_STREAM, 0); + if (listen_fd == -1) { + THROW_ERROR("failed to create a unix socket"); + } + + struct sockaddr_un addr; + memset(&addr, 0, sizeof(struct sockaddr_un)); //Clear structure + addr.sun_family = AF_UNIX; + strcpy(addr.sun_path, socket_original_path); + + // About addr_len (from man page): + // a UNIX domain socket can be bound to a null-terminated + // filesystem pathname using bind(2). When the address of + // a pathname socket is returned (by one of the system + // calls noted above), its length is: + // offsetof(struct sockaddr_un, sun_path) + strlen(sun_path) + 1 + socklen_t addr_len = strlen(addr.sun_path) + sizeof(addr.sun_family) + 1; + if (bind(listen_fd, (struct sockaddr *)&addr, addr_len) == -1) { + close(listen_fd); + THROW_ERROR("failed to bind"); + } + + if (listen(listen_fd, 5) == -1) { + close(listen_fd); + THROW_ERROR("failed to listen"); + } + + // rename to new path + unlink(socket_ready_path); + if (rename(socket_original_path, socket_ready_path) < 0) { + THROW_ERROR("failed to rename"); + } + + int client_fd = socket(AF_UNIX, SOCK_STREAM, 0); + if (client_fd == -1) { + close(listen_fd); + THROW_ERROR("failed to create a unix socket"); + } + + struct sockaddr_un addr_client; + memset(&addr_client, 0, sizeof(struct sockaddr_un)); //Clear structure + addr_client.sun_family = AF_UNIX; + strcpy(addr_client.sun_path, "/proc/self/root"); + strcat(addr_client.sun_path, socket_ready_path); + + socklen_t client_addr_len = strlen(addr_client.sun_path) + sizeof( + addr_client.sun_family) + 1; + if (connect(client_fd, (struct sockaddr *)&addr_client, client_addr_len) == -1) { + close(listen_fd); + close(client_fd); + THROW_ERROR("failed to connect"); + } + + int accepted_fd = accept(listen_fd, (struct sockaddr *)&addr_client, &client_addr_len); + if (accepted_fd == -1) { + close(listen_fd); + close(client_fd); + THROW_ERROR("failed to accept socket"); + } + + sockets[0] = client_fd; + sockets[1] = accepted_fd; + close(listen_fd); + return 0; +} + int verify_child_echo(int *connected_sockets) { const char *child_prog = "/bin/hello_world"; const char *child_argv[3] = { child_prog, ECHO_MSG, NULL }; @@ -190,6 +259,11 @@ int test_socketpair_inter_process() { return test_connected_sockets_inter_process(create_connceted_sockets_default); } +// To emulate JVM bahaviour on UDS +int test_unix_socket_rename() { + return test_connected_sockets_inter_process(create_connected_sockets_then_rename); +} + int test_poll() { int socks[2]; if (socketpair(AF_UNIX, SOCK_STREAM, 0, socks) < 0) { @@ -301,6 +375,7 @@ static test_case_t test_cases[] = { TEST_CASE(test_poll), TEST_CASE(test_getname), TEST_CASE(test_ioctl_fionread), + TEST_CASE(test_unix_socket_rename), }; int main(int argc, const char *argv[]) {