Add socketpair syscall

1. Support creating socket pairs of the UNIX domain
2. Add test for socketpair in test/unix_socket
3. Refactor unix_socket test
This commit is contained in:
He Sun 2019-11-22 08:25:29 +00:00 committed by Tate, Hongliang Tian
parent 9c4391b32d
commit dc14f27a29
4 changed files with 234 additions and 80 deletions

@ -2,7 +2,7 @@ use super::*;
use alloc::prelude::ToString;
use std::collections::btree_map::BTreeMap;
use std::fmt;
use std::sync::atomic::spin_loop_hint;
use std::sync::atomic::{spin_loop_hint, AtomicUsize, Ordering};
use std::sync::SgxMutex as Mutex;
use util::ring_buf::{RingBuf, RingBufReader, RingBufWriter};
@ -88,6 +88,9 @@ impl File for UnixSocketFile {
}
}
static SOCKETPAIR_NUM: AtomicUsize = AtomicUsize::new(0);
const SOCK_PATH_PREFIX: &str = "socketpair_";
impl UnixSocketFile {
pub fn new(socket_type: c_int, protocol: c_int) -> Result<Self> {
let inner = UnixSocket::new(socket_type, protocol)?;
@ -123,6 +126,34 @@ impl UnixSocketFile {
let mut inner = self.inner.lock().unwrap();
inner.poll()
}
pub fn socketpair(socket_type: i32, protocol: i32) -> Result<(Self, Self)> {
let listen_socket = Self::new(socket_type, protocol)?;
let bound_path = listen_socket.bind_until_success();
listen_socket.listen()?;
let client_socket = Self::new(socket_type, protocol)?;
client_socket.connect(&bound_path)?;
let accepted_socket = listen_socket.accept()?;
Ok((client_socket, accepted_socket))
}
fn bind_until_success(&self) -> String {
let mut path = SOCK_PATH_PREFIX.to_string();
let mut index = SOCKETPAIR_NUM.load(Ordering::SeqCst);
path.push_str(&index.to_string());
while self.bind(&path).is_err() {
if index == std::usize::MAX {
SOCKETPAIR_NUM.store(0, Ordering::SeqCst); //flip SOCKETPAIR_NUM
}
index += 1;
path = SOCK_PATH_PREFIX.to_string();
path.push_str(&index.to_string());
}
SOCKETPAIR_NUM.fetch_max(index + 1, Ordering::SeqCst);
path
}
}
impl Debug for UnixSocketFile {

@ -8,6 +8,7 @@
#![feature(range_contains)]
#![feature(core_intrinsics)]
#![feature(stmt_expr_attributes)]
#![feature(atomic_min_max)]
#[macro_use]
extern crate alloc;

@ -295,6 +295,12 @@ pub extern "C" fn dispatch_syscall(
arg4 as *mut libc::sockaddr,
arg5 as *mut libc::socklen_t,
),
SYS_SOCKETPAIR => do_socketpair(
arg0 as c_int,
arg1 as c_int,
arg2 as c_int,
arg3 as *mut c_int,
),
_ => do_unknown(num, arg0, arg1, arg2, arg3, arg4, arg5),
};
@ -1026,7 +1032,7 @@ fn do_sched_setaffinity(pid: pid_t, cpusize: size_t, buf: *const c_uchar) -> Res
fn do_socket(domain: c_int, socket_type: c_int, protocol: c_int) -> Result<isize> {
info!(
"socket: domain: {}, socket_type: {}, protocol: {}",
"socket: domain: {}, socket_type: 0x{:x}, protocol: {}",
domain, socket_type, protocol
);
@ -1320,6 +1326,46 @@ fn do_recvfrom(
Ok(ret as isize)
}
fn do_socketpair(
domain: c_int,
socket_type: c_int,
protocol: c_int,
sv: *mut c_int,
) -> Result<isize> {
info!(
"socketpair: domain: {}, type:0x{:x}, protocol: {}",
domain, socket_type, protocol
);
let mut sock_pair = unsafe {
check_mut_array(sv, 2)?;
std::slice::from_raw_parts_mut(sv as *mut u32, 2)
};
if (domain == libc::AF_UNIX) {
let (client_socket, server_socket) =
UnixSocketFile::socketpair(socket_type as i32, protocol as i32)?;
let current_ref = process::get_current();
let mut proc = current_ref.lock().unwrap();
sock_pair[0] = proc
.get_files()
.lock()
.unwrap()
.put(Arc::new(Box::new(client_socket)), false);
sock_pair[1] = proc
.get_files()
.lock()
.unwrap()
.put(Arc::new(Box::new(server_socket)), false);
info!("socketpair: ({}, {})", sock_pair[0], sock_pair[1]);
Ok(0)
} else if (domain == libc::AF_TIPC) {
return_errno!(EAFNOSUPPORT, "cluster domain sockets not supported")
} else {
return_errno!(EAFNOSUPPORT, "domain not supported")
}
}
fn do_select(
nfds: c_int,
readfds: *mut libc::fd_set,

@ -9,106 +9,182 @@
#include <string.h>
#include <spawn.h>
const char SOCK_PATH[] = "echo_socket";
#include "test.h"
int create_server_socket() {
int fd = socket(AF_UNIX, SOCK_STREAM, 0);
if (fd == -1) {
printf("ERROR: failed to create a unix socket\n");
return -1;
}
#define ECHO_MSG "echo msg for unix_socket test"
#define THROW_ERROR(msg) do { \
printf("\t\tERROR: %s in func %s at line %d of file %s\n", \
(msg), __func__, __LINE__, __FILE__); \
return -1; \
} while(0)
struct sockaddr_un local;
local.sun_family = AF_UNIX;
strcpy(local.sun_path, SOCK_PATH);
socklen_t len = strlen(local.sun_path) + sizeof(local.sun_family);
int create_connected_sockets(int *sockets, char *sock_path) {
int listen_fd = socket(AF_UNIX, SOCK_STREAM, 0);
if (listen_fd == -1) {
THROW_ERROR("failed to create a unix socket");
}
if (bind(fd, (struct sockaddr *)&local, len) == -1) {
printf("ERROR: failed to bind\n");
return -1;
}
struct sockaddr_un addr;
memset(&addr, 0, sizeof(struct sockaddr_un)); //Clear structure
addr.sun_family = AF_UNIX;
strcpy(addr.sun_path, sock_path);
socklen_t addr_len = strlen(addr.sun_path) + sizeof(addr.sun_family);
if (bind(listen_fd, (struct sockaddr *)&addr, addr_len) == -1) {
close(listen_fd);
THROW_ERROR("failed to bind");
}
if (listen(fd, 5) == -1) {
printf("ERROR: failed to listen\n");
return -1;
}
return fd;
if (listen(listen_fd, 5) == -1) {
close(listen_fd);
THROW_ERROR("failed to listen");
}
int client_fd = socket(AF_UNIX, SOCK_STREAM, 0);
if (client_fd == -1) {
close(listen_fd);
THROW_ERROR("failed to create a unix socket");
}
if (connect(client_fd, (struct sockaddr *)&addr, addr_len) == -1) {
close(listen_fd);
close(client_fd);
THROW_ERROR("failed to connect");
}
int accepted_fd = accept(listen_fd, (struct sockaddr *)&addr, &addr_len);
if (accepted_fd == -1) {
close(listen_fd);
close(client_fd);
THROW_ERROR("failed to accept socket");
}
sockets[0] = client_fd;
sockets[1] = accepted_fd;
close(listen_fd);
return 0;
}
int create_client_socket() {
int fd = socket(AF_UNIX, SOCK_STREAM, 0);
if (fd == -1) {
printf("ERROR: failed to create a unix socket\n");
return -1;
}
struct sockaddr_un remote;
remote.sun_family = AF_UNIX;
strcpy(remote.sun_path, SOCK_PATH);
socklen_t len = strlen(remote.sun_path) + sizeof(remote.sun_family);
if (connect(fd, (struct sockaddr *)&remote, len) == -1) {
printf("ERROR: failed to connect\n");
return -1;
}
return fd;
int create_connceted_sockets_default(int *sockets) {
return create_connected_sockets(sockets, "unix_socket_default_path");
}
int main(int argc, const char* argv[]) {
int listen_fd = create_server_socket();
if (listen_fd == -1) {
printf("ERROR: failed to create server socket\n");
return -1;
}
int socket_rd_fd = create_client_socket();
if (socket_rd_fd == -1) {
printf("ERROR: failed to create client socket\n");
return -1;
}
struct sockaddr_un remote;
socklen_t len = sizeof(remote);
int socket_wr_fd = accept(listen_fd, (struct sockaddr *)&remote, &len);
if (socket_wr_fd == -1) {
printf("ERROR: failed to accept socket\n");
return -1;
}
// The following is same as 'pipe'
posix_spawn_file_actions_t file_actions;
posix_spawn_file_actions_init(&file_actions);
posix_spawn_file_actions_adddup2(&file_actions, socket_wr_fd, STDOUT_FILENO);
posix_spawn_file_actions_addclose(&file_actions, socket_rd_fd);
const char* msg = "Echo!\n";
int verify_child_echo(int *connected_sockets) {
const char* child_prog = "/bin/hello_world";
const char* child_argv[3] = { child_prog, msg, NULL };
const char* child_argv[3] = { child_prog, ECHO_MSG, NULL };
int child_pid;
posix_spawn_file_actions_t file_actions;
posix_spawn_file_actions_init(&file_actions);
posix_spawn_file_actions_adddup2(&file_actions, connected_sockets[0], STDOUT_FILENO);
posix_spawn_file_actions_addclose(&file_actions, connected_sockets[1]);
if (posix_spawn(&child_pid, child_prog, &file_actions,
NULL, (char*const*)child_argv, NULL) < 0) {
printf("ERROR: failed to spawn a child process\n");
return -1;
THROW_ERROR("failed to spawn a child process");
}
close(socket_wr_fd);
const char* expected_str = msg;
size_t expected_len = strlen(expected_str);
char actual_str[32] = {0};
ssize_t actual_len;
//TODO: implement blocking read
do {
actual_len = read(socket_rd_fd, actual_str, sizeof(actual_str) - 1);
actual_len = read(connected_sockets[1], actual_str, 32);
} while (actual_len == 0);
if (strncmp(expected_str, actual_str, expected_len) != 0) {
printf("ERROR: received string is not as expected\n");
return -1;
if (strncmp(actual_str, ECHO_MSG, sizeof(ECHO_MSG) - 1) != 0) {
printf("data read is :%s\n", actual_str);
THROW_ERROR("received string is not as expected");
}
int status = 0;
if (wait4(child_pid, &status, 0, NULL) < 0) {
printf("ERROR: failed to wait4 the child process\n");
return -1;
THROW_ERROR("failed to wait4 the child process");
}
return 0;
}
int verify_connection(int src_sock, int dest_sock) {
char buf[1024];
int i;
for (i = 0; i < 100; i++) {
if (write(src_sock, ECHO_MSG, sizeof(ECHO_MSG)) < 0) {
THROW_ERROR("writing server message");
}
if (read(dest_sock, buf, 1024) < 0) {
THROW_ERROR("reading server message");
}
if (strncmp(buf, ECHO_MSG, sizeof(ECHO_MSG)) != 0) {
THROW_ERROR("msg received mismatch");
}
}
return 0;
}
//this value should not be too large as one pair consumes 2MB memory
#define PAIR_NUM 15
int test_multiple_socketpairs() {
int sockets[PAIR_NUM][2];
int i;
int ret = 0;
for(i = 0; i < PAIR_NUM; i++) {
if (socketpair(AF_UNIX, SOCK_STREAM, 0, sockets[i]) < 0) {
THROW_ERROR("opening stream socket pair");
}
if(verify_connection(sockets[i][0], sockets[i][1]) < 0) {
ret = -1;
goto cleanup;
}
if(verify_connection(sockets[i][1], sockets[i][0]) < 0) {
ret = -1;
goto cleanup;
}
}
i--;
cleanup:
for(; i >= 0; i--){
close(sockets[i][0]);
close(sockets[i][1]);
}
return ret;
}
int socketpair_default(int *sockets) {
return socketpair(AF_UNIX, SOCK_STREAM, 0, sockets);
}
typedef int(*create_connection_func_t)(int *);
int test_connected_sockets_inter_process(create_connection_func_t fn) {
int ret = 0;
int sockets[2];
if (fn(sockets) < 0)
return -1;
ret = verify_child_echo(sockets);
close(sockets[0]);
close(sockets[1]);
return ret;
}
int test_unix_socket_inter_process() {
return test_connected_sockets_inter_process(socketpair_default);
}
int test_socketpair_inter_process() {
return test_connected_sockets_inter_process(create_connceted_sockets_default);
}
static test_case_t test_cases[] = {
TEST_CASE(test_unix_socket_inter_process),
TEST_CASE(test_socketpair_inter_process),
TEST_CASE(test_multiple_socketpairs),
};
int main(int argc, const char* argv[]) {
return test_suite_run(test_cases, ARRAY_SIZE(test_cases));
}