Add support for indexing unix domain socket file with inode

This commit is contained in:
Hui, Chunyang 2021-06-01 08:27:31 +00:00 committed by Zongmin.Gu
parent 2cedafeacb
commit 0dc85f8229
5 changed files with 174 additions and 38 deletions

@ -10,7 +10,7 @@ lazy_static! {
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Addr {
File(UnixPath),
File(Option<usize>, UnixPath), // An optional inode number and path. Use inode if there is one.
Abstract(String),
}
@ -54,7 +54,7 @@ impl Addr {
return_errno!(EINVAL, "no null in the address");
}
Ok(Self::File(UnixPath::new(path_cstr.to_str().unwrap())))
Ok(Self::File(None, UnixPath::new(path_cstr.to_str().unwrap())))
}
}
@ -76,7 +76,7 @@ impl Addr {
pub fn path_str(&self) -> &str {
match self {
Self::File(unix_path) => &unix_path.path_str(),
Self::File(_, unix_path) => &unix_path.path_str(),
Self::Abstract(path) => &path,
}
}
@ -86,7 +86,7 @@ impl Addr {
addr.sun_family = AddressFamily::LOCAL as libc::sa_family_t;
let addr_len = match self {
Self::File(unix_path) => {
Self::File(_, unix_path) => {
let path_str = unix_path.path_str();
let buf_len = path_str.len();
/// addr is initialized to all zeros and try_from_raw guarentees

@ -7,9 +7,27 @@ lazy_static! {
pub(super) static ref ADDRESS_SPACE: AddressSpace = AddressSpace::new();
}
#[derive(PartialEq, Eq, PartialOrd, Ord)]
pub enum AddressSpaceKey {
FileKey(usize),
AbstrKey(String),
}
impl AddressSpaceKey {
pub fn from_inode(inode: usize) -> Self {
AddressSpaceKey::FileKey(inode)
}
pub fn from_path(path: String) -> Self {
AddressSpaceKey::AbstrKey(path)
}
}
pub struct AddressSpace {
file: SgxMutex<BTreeMap<String, Option<Arc<Listener>>>>,
abstr: SgxMutex<BTreeMap<String, Option<Arc<Listener>>>>,
// For "file", use inode number as "key" instead of path string so that listeners can still
// be reached even if the socket file is moved or renamed.
file: SgxMutex<BTreeMap<AddressSpaceKey, Option<Arc<Listener>>>>,
abstr: SgxMutex<BTreeMap<AddressSpaceKey, Option<Arc<Listener>>>>,
}
impl AddressSpace {
@ -21,7 +39,7 @@ impl AddressSpace {
}
pub fn add_binder(&self, addr: &Addr) -> Result<()> {
let key = Self::get_key(addr);
let key = Self::get_key(addr).ok_or_else(|| errno!(EINVAL, "can't find socket file"))?;
let mut space = self.get_space(addr);
if space.contains_key(&key) {
return_errno!(EADDRINUSE, "the addr is already bound");
@ -32,7 +50,7 @@ impl AddressSpace {
}
pub fn add_listener(&self, addr: &Addr, capacity: usize, nonblocking: bool) -> Result<()> {
let key = Self::get_key(addr);
let key = Self::get_key(addr).ok_or_else(|| errno!(EINVAL, "the socket is not bound"))?;
let mut space = self.get_space(addr);
if let Some(option) = space.get(&key) {
@ -48,7 +66,7 @@ impl AddressSpace {
}
pub fn resize_listener(&self, addr: &Addr, capacity: usize) -> Result<()> {
let key = Self::get_key(addr);
let key = Self::get_key(addr).ok_or_else(|| errno!(EINVAL, "the socket is not bound"))?;
let mut space = self.get_space(addr);
if let Some(option) = space.get(&key) {
@ -78,27 +96,54 @@ impl AddressSpace {
pub fn get_listener_ref(&self, addr: &Addr) -> Option<Arc<Listener>> {
let key = Self::get_key(addr);
let space = self.get_space(addr);
space.get(&key).map(|x| x.clone()).flatten()
if let Some(key) = key {
let space = self.get_space(addr);
space.get(&key).map(|x| x.clone()).flatten()
} else {
None
}
}
pub fn remove_addr(&self, addr: &Addr) {
let key = Self::get_key(addr);
let mut space = self.get_space(addr);
space.remove(&key);
}
fn get_space(&self, addr: &Addr) -> SgxMutexGuard<'_, BTreeMap<String, Option<Arc<Listener>>>> {
match addr {
Addr::File(unix_path) => self.file.lock().unwrap(),
Addr::Abstract(path) => self.abstr.lock().unwrap(),
if let Some(key) = key {
let mut space = self.get_space(addr);
space.remove(&key);
} else {
warn!("address space key not exit: {:?}", addr);
}
}
fn get_key(addr: &Addr) -> String {
fn get_space(
&self,
addr: &Addr,
) -> SgxMutexGuard<'_, BTreeMap<AddressSpaceKey, Option<Arc<Listener>>>> {
match addr {
Addr::File(unix_path) => unix_path.absolute(),
Addr::Abstract(path) => addr.path_str().to_string(),
Addr::File(_, _) => self.file.lock().unwrap(),
Addr::Abstract(_) => self.abstr.lock().unwrap(),
}
}
fn get_key(addr: &Addr) -> Option<AddressSpaceKey> {
trace!("addr = {:?}", addr);
match addr {
Addr::File(inode_num, unix_path) if inode_num.is_some() => {
Some(AddressSpaceKey::from_inode(inode_num.unwrap()))
}
Addr::File(_, unix_path) => {
let inode = {
let file_path = unix_path.absolute();
let current = current!();
let fs = current.fs().read().unwrap();
fs.lookup_inode(&file_path)
};
if let Ok(inode) = inode {
Some(AddressSpaceKey::from_inode(inode.metadata().unwrap().inode))
} else {
None
}
}
Addr::Abstract(path) => Some(AddressSpaceKey::from_path(addr.path_str().to_string())),
}
}
}

@ -3,6 +3,7 @@ use super::endpoint::{end_pair, Endpoint, RelayNotifier};
use super::*;
use events::{Event, EventFilter, Notifier, Observer};
use fs::channel::Channel;
use fs::CreationFlags;
use fs::IoEvents;
use std::fmt;
use std::sync::atomic::{AtomicBool, Ordering};
@ -68,8 +69,19 @@ impl Stream {
return_errno!(ENOTCONN, "the socket is not connected");
}
// TODO: create the corresponding file in the fs
pub fn bind(&self, addr: &Addr) -> Result<()> {
pub fn bind(&self, addr: &mut Addr) -> Result<()> {
if let Addr::File(inode_num, path) = addr {
// create the corresponding file in the fs and fill Addr with its inode
let corresponding_inode_num = {
let current = current!();
let fs = current.fs().read().unwrap();
let file_ref =
fs.open_file(path.path_str(), CreationFlags::O_CREAT.bits(), 0o777)?;
file_ref.metadata()?.inode
};
*inode_num = Some(corresponding_inode_num);
}
match &mut *self.inner() {
Status::Idle(ref mut info) => {
if info.addr().is_some() {
@ -105,6 +117,7 @@ impl Stream {
match &*inner {
Status::Idle(info) => {
if let Some(addr) = info.addr() {
warn!("addr = {:?}", addr);
ADDRESS_SPACE.add_listener(addr, capacity, info.nonblocking())?;
*inner = Status::Listening(addr.clone());
} else {

@ -44,9 +44,9 @@ pub fn do_bind(fd: c_int, addr: *const libc::sockaddr, addr_len: libc::socklen_t
trace!("bind to addr: {:?}", sock_addr);
socket.bind(&sock_addr)?;
} else if let Ok(unix_socket) = file_ref.as_unix_socket() {
let unix_addr = unsafe { UnixAddr::try_from_raw(addr, addr_len)? };
let mut unix_addr = unsafe { UnixAddr::try_from_raw(addr, addr_len)? };
trace!("bind to addr: {:?}", unix_addr);
unix_socket.bind(&unix_addr)?;
unix_socket.bind(&mut unix_addr)?;
} else {
return_errno!(ENOTSOCK, "not a socket");
}
@ -231,16 +231,19 @@ pub fn do_getsockopt(
fd, level, optname, optval, optlen
);
let file_ref = current!().file(fd as FileDesc)?;
let socket = file_ref.as_host_socket()?;
let ret = try_libc!(libc::ocall::getsockopt(
socket.raw_host_fd() as i32,
level,
optname,
optval,
optlen
));
Ok(ret as isize)
if let Ok(socket) = file_ref.as_host_socket() {
let ret = try_libc!(libc::ocall::getsockopt(
socket.raw_host_fd() as i32,
level,
optname,
optval,
optlen
));
Ok(ret as isize)
} else {
warn!("getsockeopt is not implemented for non-host socket.");
Ok(0 as isize)
}
}
pub fn do_getpeername(

@ -25,7 +25,7 @@ int create_connected_sockets(int *sockets, char *sock_path) {
memset(&addr, 0, sizeof(struct sockaddr_un)); //Clear structure
addr.sun_family = AF_UNIX;
strcpy(addr.sun_path, sock_path);
socklen_t addr_len = strlen(addr.sun_path) + sizeof(addr.sun_family);
socklen_t addr_len = strlen(addr.sun_path) + sizeof(addr.sun_family) + 1;
if (bind(listen_fd, (struct sockaddr *)&addr, addr_len) == -1) {
close(listen_fd);
THROW_ERROR("failed to bind");
@ -36,7 +36,7 @@ int create_connected_sockets(int *sockets, char *sock_path) {
THROW_ERROR("failed to listen");
}
int client_fd = socket(AF_UNIX, SOCK_STREAM, PF_UNIX);
int client_fd = socket(AF_UNIX, SOCK_STREAM, 0);
if (client_fd == -1) {
close(listen_fd);
THROW_ERROR("failed to create a unix socket");
@ -65,6 +65,75 @@ int create_connceted_sockets_default(int *sockets) {
return create_connected_sockets(sockets, "unix_socket_default_path");
}
int create_connected_sockets_then_rename(int *sockets) {
char *socket_original_path = "/tmp/socket_tmp";
char *socket_ready_path = "/tmp/.socket_tmp";
int listen_fd = socket(AF_UNIX, SOCK_STREAM, 0);
if (listen_fd == -1) {
THROW_ERROR("failed to create a unix socket");
}
struct sockaddr_un addr;
memset(&addr, 0, sizeof(struct sockaddr_un)); //Clear structure
addr.sun_family = AF_UNIX;
strcpy(addr.sun_path, socket_original_path);
// About addr_len (from man page):
// a UNIX domain socket can be bound to a null-terminated
// filesystem pathname using bind(2). When the address of
// a pathname socket is returned (by one of the system
// calls noted above), its length is:
// offsetof(struct sockaddr_un, sun_path) + strlen(sun_path) + 1
socklen_t addr_len = strlen(addr.sun_path) + sizeof(addr.sun_family) + 1;
if (bind(listen_fd, (struct sockaddr *)&addr, addr_len) == -1) {
close(listen_fd);
THROW_ERROR("failed to bind");
}
if (listen(listen_fd, 5) == -1) {
close(listen_fd);
THROW_ERROR("failed to listen");
}
// rename to new path
unlink(socket_ready_path);
if (rename(socket_original_path, socket_ready_path) < 0) {
THROW_ERROR("failed to rename");
}
int client_fd = socket(AF_UNIX, SOCK_STREAM, 0);
if (client_fd == -1) {
close(listen_fd);
THROW_ERROR("failed to create a unix socket");
}
struct sockaddr_un addr_client;
memset(&addr_client, 0, sizeof(struct sockaddr_un)); //Clear structure
addr_client.sun_family = AF_UNIX;
strcpy(addr_client.sun_path, "/proc/self/root");
strcat(addr_client.sun_path, socket_ready_path);
socklen_t client_addr_len = strlen(addr_client.sun_path) + sizeof(
addr_client.sun_family) + 1;
if (connect(client_fd, (struct sockaddr *)&addr_client, client_addr_len) == -1) {
close(listen_fd);
close(client_fd);
THROW_ERROR("failed to connect");
}
int accepted_fd = accept(listen_fd, (struct sockaddr *)&addr_client, &client_addr_len);
if (accepted_fd == -1) {
close(listen_fd);
close(client_fd);
THROW_ERROR("failed to accept socket");
}
sockets[0] = client_fd;
sockets[1] = accepted_fd;
close(listen_fd);
return 0;
}
int verify_child_echo(int *connected_sockets) {
const char *child_prog = "/bin/hello_world";
const char *child_argv[3] = { child_prog, ECHO_MSG, NULL };
@ -190,6 +259,11 @@ int test_socketpair_inter_process() {
return test_connected_sockets_inter_process(create_connceted_sockets_default);
}
// To emulate JVM bahaviour on UDS
int test_unix_socket_rename() {
return test_connected_sockets_inter_process(create_connected_sockets_then_rename);
}
int test_poll() {
int socks[2];
if (socketpair(AF_UNIX, SOCK_STREAM, 0, socks) < 0) {
@ -301,6 +375,7 @@ static test_case_t test_cases[] = {
TEST_CASE(test_poll),
TEST_CASE(test_getname),
TEST_CASE(test_ioctl_fionread),
TEST_CASE(test_unix_socket_rename),
};
int main(int argc, const char *argv[]) {