Add socketpair syscall
1. Support creating socket pairs of the UNIX domain 2. Add test for socketpair in test/unix_socket 3. Refactor unix_socket test
This commit is contained in:
parent
9c4391b32d
commit
dc14f27a29
@ -2,7 +2,7 @@ use super::*;
|
||||
use alloc::prelude::ToString;
|
||||
use std::collections::btree_map::BTreeMap;
|
||||
use std::fmt;
|
||||
use std::sync::atomic::spin_loop_hint;
|
||||
use std::sync::atomic::{spin_loop_hint, AtomicUsize, Ordering};
|
||||
use std::sync::SgxMutex as Mutex;
|
||||
use util::ring_buf::{RingBuf, RingBufReader, RingBufWriter};
|
||||
|
||||
@ -88,6 +88,9 @@ impl File for UnixSocketFile {
|
||||
}
|
||||
}
|
||||
|
||||
static SOCKETPAIR_NUM: AtomicUsize = AtomicUsize::new(0);
|
||||
const SOCK_PATH_PREFIX: &str = "socketpair_";
|
||||
|
||||
impl UnixSocketFile {
|
||||
pub fn new(socket_type: c_int, protocol: c_int) -> Result<Self> {
|
||||
let inner = UnixSocket::new(socket_type, protocol)?;
|
||||
@ -123,6 +126,34 @@ impl UnixSocketFile {
|
||||
let mut inner = self.inner.lock().unwrap();
|
||||
inner.poll()
|
||||
}
|
||||
|
||||
pub fn socketpair(socket_type: i32, protocol: i32) -> Result<(Self, Self)> {
|
||||
let listen_socket = Self::new(socket_type, protocol)?;
|
||||
let bound_path = listen_socket.bind_until_success();
|
||||
listen_socket.listen()?;
|
||||
|
||||
let client_socket = Self::new(socket_type, protocol)?;
|
||||
client_socket.connect(&bound_path)?;
|
||||
|
||||
let accepted_socket = listen_socket.accept()?;
|
||||
Ok((client_socket, accepted_socket))
|
||||
}
|
||||
|
||||
fn bind_until_success(&self) -> String {
|
||||
let mut path = SOCK_PATH_PREFIX.to_string();
|
||||
let mut index = SOCKETPAIR_NUM.load(Ordering::SeqCst);
|
||||
path.push_str(&index.to_string());
|
||||
while self.bind(&path).is_err() {
|
||||
if index == std::usize::MAX {
|
||||
SOCKETPAIR_NUM.store(0, Ordering::SeqCst); //flip SOCKETPAIR_NUM
|
||||
}
|
||||
index += 1;
|
||||
path = SOCK_PATH_PREFIX.to_string();
|
||||
path.push_str(&index.to_string());
|
||||
}
|
||||
SOCKETPAIR_NUM.fetch_max(index + 1, Ordering::SeqCst);
|
||||
path
|
||||
}
|
||||
}
|
||||
|
||||
impl Debug for UnixSocketFile {
|
||||
|
@ -8,6 +8,7 @@
|
||||
#![feature(range_contains)]
|
||||
#![feature(core_intrinsics)]
|
||||
#![feature(stmt_expr_attributes)]
|
||||
#![feature(atomic_min_max)]
|
||||
|
||||
#[macro_use]
|
||||
extern crate alloc;
|
||||
|
@ -295,6 +295,12 @@ pub extern "C" fn dispatch_syscall(
|
||||
arg4 as *mut libc::sockaddr,
|
||||
arg5 as *mut libc::socklen_t,
|
||||
),
|
||||
SYS_SOCKETPAIR => do_socketpair(
|
||||
arg0 as c_int,
|
||||
arg1 as c_int,
|
||||
arg2 as c_int,
|
||||
arg3 as *mut c_int,
|
||||
),
|
||||
|
||||
_ => do_unknown(num, arg0, arg1, arg2, arg3, arg4, arg5),
|
||||
};
|
||||
@ -1026,7 +1032,7 @@ fn do_sched_setaffinity(pid: pid_t, cpusize: size_t, buf: *const c_uchar) -> Res
|
||||
|
||||
fn do_socket(domain: c_int, socket_type: c_int, protocol: c_int) -> Result<isize> {
|
||||
info!(
|
||||
"socket: domain: {}, socket_type: {}, protocol: {}",
|
||||
"socket: domain: {}, socket_type: 0x{:x}, protocol: {}",
|
||||
domain, socket_type, protocol
|
||||
);
|
||||
|
||||
@ -1320,6 +1326,46 @@ fn do_recvfrom(
|
||||
Ok(ret as isize)
|
||||
}
|
||||
|
||||
fn do_socketpair(
|
||||
domain: c_int,
|
||||
socket_type: c_int,
|
||||
protocol: c_int,
|
||||
sv: *mut c_int,
|
||||
) -> Result<isize> {
|
||||
info!(
|
||||
"socketpair: domain: {}, type:0x{:x}, protocol: {}",
|
||||
domain, socket_type, protocol
|
||||
);
|
||||
let mut sock_pair = unsafe {
|
||||
check_mut_array(sv, 2)?;
|
||||
std::slice::from_raw_parts_mut(sv as *mut u32, 2)
|
||||
};
|
||||
|
||||
if (domain == libc::AF_UNIX) {
|
||||
let (client_socket, server_socket) =
|
||||
UnixSocketFile::socketpair(socket_type as i32, protocol as i32)?;
|
||||
let current_ref = process::get_current();
|
||||
let mut proc = current_ref.lock().unwrap();
|
||||
sock_pair[0] = proc
|
||||
.get_files()
|
||||
.lock()
|
||||
.unwrap()
|
||||
.put(Arc::new(Box::new(client_socket)), false);
|
||||
sock_pair[1] = proc
|
||||
.get_files()
|
||||
.lock()
|
||||
.unwrap()
|
||||
.put(Arc::new(Box::new(server_socket)), false);
|
||||
|
||||
info!("socketpair: ({}, {})", sock_pair[0], sock_pair[1]);
|
||||
Ok(0)
|
||||
} else if (domain == libc::AF_TIPC) {
|
||||
return_errno!(EAFNOSUPPORT, "cluster domain sockets not supported")
|
||||
} else {
|
||||
return_errno!(EAFNOSUPPORT, "domain not supported")
|
||||
}
|
||||
}
|
||||
|
||||
fn do_select(
|
||||
nfds: c_int,
|
||||
readfds: *mut libc::fd_set,
|
||||
|
@ -9,106 +9,182 @@
|
||||
#include <string.h>
|
||||
#include <spawn.h>
|
||||
|
||||
const char SOCK_PATH[] = "echo_socket";
|
||||
#include "test.h"
|
||||
|
||||
int create_server_socket() {
|
||||
int fd = socket(AF_UNIX, SOCK_STREAM, 0);
|
||||
if (fd == -1) {
|
||||
printf("ERROR: failed to create a unix socket\n");
|
||||
return -1;
|
||||
}
|
||||
#define ECHO_MSG "echo msg for unix_socket test"
|
||||
#define THROW_ERROR(msg) do { \
|
||||
printf("\t\tERROR: %s in func %s at line %d of file %s\n", \
|
||||
(msg), __func__, __LINE__, __FILE__); \
|
||||
return -1; \
|
||||
} while(0)
|
||||
|
||||
struct sockaddr_un local;
|
||||
local.sun_family = AF_UNIX;
|
||||
strcpy(local.sun_path, SOCK_PATH);
|
||||
socklen_t len = strlen(local.sun_path) + sizeof(local.sun_family);
|
||||
int create_connected_sockets(int *sockets, char *sock_path) {
|
||||
int listen_fd = socket(AF_UNIX, SOCK_STREAM, 0);
|
||||
if (listen_fd == -1) {
|
||||
THROW_ERROR("failed to create a unix socket");
|
||||
}
|
||||
|
||||
if (bind(fd, (struct sockaddr *)&local, len) == -1) {
|
||||
printf("ERROR: failed to bind\n");
|
||||
return -1;
|
||||
}
|
||||
struct sockaddr_un addr;
|
||||
memset(&addr, 0, sizeof(struct sockaddr_un)); //Clear structure
|
||||
addr.sun_family = AF_UNIX;
|
||||
strcpy(addr.sun_path, sock_path);
|
||||
socklen_t addr_len = strlen(addr.sun_path) + sizeof(addr.sun_family);
|
||||
if (bind(listen_fd, (struct sockaddr *)&addr, addr_len) == -1) {
|
||||
close(listen_fd);
|
||||
THROW_ERROR("failed to bind");
|
||||
}
|
||||
|
||||
if (listen(fd, 5) == -1) {
|
||||
printf("ERROR: failed to listen\n");
|
||||
return -1;
|
||||
}
|
||||
return fd;
|
||||
if (listen(listen_fd, 5) == -1) {
|
||||
close(listen_fd);
|
||||
THROW_ERROR("failed to listen");
|
||||
}
|
||||
|
||||
int client_fd = socket(AF_UNIX, SOCK_STREAM, 0);
|
||||
if (client_fd == -1) {
|
||||
close(listen_fd);
|
||||
THROW_ERROR("failed to create a unix socket");
|
||||
}
|
||||
|
||||
if (connect(client_fd, (struct sockaddr *)&addr, addr_len) == -1) {
|
||||
close(listen_fd);
|
||||
close(client_fd);
|
||||
THROW_ERROR("failed to connect");
|
||||
}
|
||||
|
||||
int accepted_fd = accept(listen_fd, (struct sockaddr *)&addr, &addr_len);
|
||||
if (accepted_fd == -1) {
|
||||
close(listen_fd);
|
||||
close(client_fd);
|
||||
THROW_ERROR("failed to accept socket");
|
||||
}
|
||||
|
||||
sockets[0] = client_fd;
|
||||
sockets[1] = accepted_fd;
|
||||
close(listen_fd);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int create_client_socket() {
|
||||
int fd = socket(AF_UNIX, SOCK_STREAM, 0);
|
||||
if (fd == -1) {
|
||||
printf("ERROR: failed to create a unix socket\n");
|
||||
return -1;
|
||||
}
|
||||
|
||||
struct sockaddr_un remote;
|
||||
remote.sun_family = AF_UNIX;
|
||||
strcpy(remote.sun_path, SOCK_PATH);
|
||||
socklen_t len = strlen(remote.sun_path) + sizeof(remote.sun_family);
|
||||
|
||||
if (connect(fd, (struct sockaddr *)&remote, len) == -1) {
|
||||
printf("ERROR: failed to connect\n");
|
||||
return -1;
|
||||
}
|
||||
return fd;
|
||||
int create_connceted_sockets_default(int *sockets) {
|
||||
return create_connected_sockets(sockets, "unix_socket_default_path");
|
||||
}
|
||||
|
||||
int main(int argc, const char* argv[]) {
|
||||
int listen_fd = create_server_socket();
|
||||
if (listen_fd == -1) {
|
||||
printf("ERROR: failed to create server socket\n");
|
||||
return -1;
|
||||
}
|
||||
|
||||
int socket_rd_fd = create_client_socket();
|
||||
if (socket_rd_fd == -1) {
|
||||
printf("ERROR: failed to create client socket\n");
|
||||
return -1;
|
||||
}
|
||||
|
||||
struct sockaddr_un remote;
|
||||
socklen_t len = sizeof(remote);
|
||||
int socket_wr_fd = accept(listen_fd, (struct sockaddr *)&remote, &len);
|
||||
if (socket_wr_fd == -1) {
|
||||
printf("ERROR: failed to accept socket\n");
|
||||
return -1;
|
||||
}
|
||||
|
||||
// The following is same as 'pipe'
|
||||
|
||||
posix_spawn_file_actions_t file_actions;
|
||||
posix_spawn_file_actions_init(&file_actions);
|
||||
posix_spawn_file_actions_adddup2(&file_actions, socket_wr_fd, STDOUT_FILENO);
|
||||
posix_spawn_file_actions_addclose(&file_actions, socket_rd_fd);
|
||||
|
||||
const char* msg = "Echo!\n";
|
||||
int verify_child_echo(int *connected_sockets) {
|
||||
const char* child_prog = "/bin/hello_world";
|
||||
const char* child_argv[3] = { child_prog, msg, NULL };
|
||||
const char* child_argv[3] = { child_prog, ECHO_MSG, NULL };
|
||||
int child_pid;
|
||||
posix_spawn_file_actions_t file_actions;
|
||||
|
||||
posix_spawn_file_actions_init(&file_actions);
|
||||
posix_spawn_file_actions_adddup2(&file_actions, connected_sockets[0], STDOUT_FILENO);
|
||||
posix_spawn_file_actions_addclose(&file_actions, connected_sockets[1]);
|
||||
|
||||
if (posix_spawn(&child_pid, child_prog, &file_actions,
|
||||
NULL, (char*const*)child_argv, NULL) < 0) {
|
||||
printf("ERROR: failed to spawn a child process\n");
|
||||
return -1;
|
||||
THROW_ERROR("failed to spawn a child process");
|
||||
}
|
||||
close(socket_wr_fd);
|
||||
|
||||
const char* expected_str = msg;
|
||||
size_t expected_len = strlen(expected_str);
|
||||
char actual_str[32] = {0};
|
||||
ssize_t actual_len;
|
||||
//TODO: implement blocking read
|
||||
do {
|
||||
actual_len = read(socket_rd_fd, actual_str, sizeof(actual_str) - 1);
|
||||
actual_len = read(connected_sockets[1], actual_str, 32);
|
||||
} while (actual_len == 0);
|
||||
if (strncmp(expected_str, actual_str, expected_len) != 0) {
|
||||
printf("ERROR: received string is not as expected\n");
|
||||
return -1;
|
||||
if (strncmp(actual_str, ECHO_MSG, sizeof(ECHO_MSG) - 1) != 0) {
|
||||
printf("data read is :%s\n", actual_str);
|
||||
THROW_ERROR("received string is not as expected");
|
||||
}
|
||||
|
||||
int status = 0;
|
||||
if (wait4(child_pid, &status, 0, NULL) < 0) {
|
||||
printf("ERROR: failed to wait4 the child process\n");
|
||||
return -1;
|
||||
THROW_ERROR("failed to wait4 the child process");
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int verify_connection(int src_sock, int dest_sock) {
|
||||
char buf[1024];
|
||||
int i;
|
||||
for (i = 0; i < 100; i++) {
|
||||
if (write(src_sock, ECHO_MSG, sizeof(ECHO_MSG)) < 0) {
|
||||
THROW_ERROR("writing server message");
|
||||
}
|
||||
|
||||
if (read(dest_sock, buf, 1024) < 0) {
|
||||
THROW_ERROR("reading server message");
|
||||
}
|
||||
|
||||
if (strncmp(buf, ECHO_MSG, sizeof(ECHO_MSG)) != 0) {
|
||||
THROW_ERROR("msg received mismatch");
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
//this value should not be too large as one pair consumes 2MB memory
|
||||
#define PAIR_NUM 15
|
||||
|
||||
int test_multiple_socketpairs() {
|
||||
int sockets[PAIR_NUM][2];
|
||||
int i;
|
||||
int ret = 0;
|
||||
|
||||
for(i = 0; i < PAIR_NUM; i++) {
|
||||
if (socketpair(AF_UNIX, SOCK_STREAM, 0, sockets[i]) < 0) {
|
||||
THROW_ERROR("opening stream socket pair");
|
||||
}
|
||||
|
||||
if(verify_connection(sockets[i][0], sockets[i][1]) < 0) {
|
||||
ret = -1;
|
||||
goto cleanup;
|
||||
}
|
||||
|
||||
if(verify_connection(sockets[i][1], sockets[i][0]) < 0) {
|
||||
ret = -1;
|
||||
goto cleanup;
|
||||
}
|
||||
}
|
||||
i--;
|
||||
cleanup:
|
||||
for(; i >= 0; i--){
|
||||
close(sockets[i][0]);
|
||||
close(sockets[i][1]);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
int socketpair_default(int *sockets) {
|
||||
return socketpair(AF_UNIX, SOCK_STREAM, 0, sockets);
|
||||
}
|
||||
|
||||
typedef int(*create_connection_func_t)(int *);
|
||||
int test_connected_sockets_inter_process(create_connection_func_t fn) {
|
||||
int ret = 0;
|
||||
int sockets[2];
|
||||
if (fn(sockets) < 0)
|
||||
return -1;
|
||||
|
||||
ret = verify_child_echo(sockets);
|
||||
|
||||
close(sockets[0]);
|
||||
close(sockets[1]);
|
||||
return ret;
|
||||
}
|
||||
|
||||
int test_unix_socket_inter_process() {
|
||||
return test_connected_sockets_inter_process(socketpair_default);
|
||||
}
|
||||
|
||||
int test_socketpair_inter_process() {
|
||||
return test_connected_sockets_inter_process(create_connceted_sockets_default);
|
||||
}
|
||||
|
||||
static test_case_t test_cases[] = {
|
||||
TEST_CASE(test_unix_socket_inter_process),
|
||||
TEST_CASE(test_socketpair_inter_process),
|
||||
TEST_CASE(test_multiple_socketpairs),
|
||||
};
|
||||
|
||||
int main(int argc, const char* argv[]) {
|
||||
return test_suite_run(test_cases, ARRAY_SIZE(test_cases));
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user