Add sendmsg/recvmsg syscalls
1. Add a separate net/ directory for the network subsystem; 2. Move some existing socket code to net/; 3. Implement sendmsg/recvmsg with OCalls; 4. Extend client/server test cases.
This commit is contained in:
parent
2052447950
commit
0cef5b1b53
@ -40,5 +40,30 @@ enclave {
|
||||
[out] sgx_report_t* qe_report,
|
||||
[out, size=quote_buf_len] uint8_t* quote_buf,
|
||||
uint32_t quote_buf_len);
|
||||
};
|
||||
|
||||
int64_t ocall_sendmsg(
|
||||
int sockfd,
|
||||
[in, size=msg_namelen] const void* msg_name,
|
||||
socklen_t msg_namelen,
|
||||
[in, size=buf_len] const void* buf,
|
||||
size_t buf_len,
|
||||
[in, size=msg_controllen] const void* msg_control,
|
||||
size_t msg_controllen,
|
||||
int flags
|
||||
) propagate_errno;
|
||||
|
||||
int64_t ocall_recvmsg(
|
||||
int sockfd,
|
||||
[out, size=msg_namelen] void *msg_name,
|
||||
socklen_t msg_namelen,
|
||||
[out] socklen_t* msg_namelen_recv,
|
||||
[out, size=buf_len] void* buf,
|
||||
size_t buf_len,
|
||||
[out, size=msg_controllen] void *msg_control,
|
||||
size_t msg_controllen,
|
||||
[out] size_t* msg_controllen_recv,
|
||||
[out] int* msg_flags_recv,
|
||||
int flags
|
||||
) propagate_errno;
|
||||
};
|
||||
};
|
||||
|
@ -141,6 +141,7 @@ pub enum Errno {
|
||||
EHWPOISON = 133,
|
||||
// Note: always keep the last item in sync with ERRNO_MAX
|
||||
}
|
||||
const ERRNO_MIN: u32 = Errno::EPERM as u32;
|
||||
const ERRNO_MAX: u32 = Errno::EHWPOISON as u32;
|
||||
|
||||
impl Errno {
|
||||
@ -192,10 +193,8 @@ impl Errno {
|
||||
}
|
||||
|
||||
impl From<u32> for Errno {
|
||||
fn from(mut raw_errno: u32) -> Self {
|
||||
if raw_errno > ERRNO_MAX {
|
||||
raw_errno = 0;
|
||||
}
|
||||
fn from(raw_errno: u32) -> Self {
|
||||
assert!(ERRNO_MIN <= raw_errno && raw_errno <= ERRNO_MAX);
|
||||
unsafe { core::mem::transmute(raw_errno as u8) }
|
||||
}
|
||||
}
|
||||
|
@ -44,7 +44,7 @@ macro_rules! return_errno {
|
||||
macro_rules! try_libc {
|
||||
($ret: expr) => {{
|
||||
let ret = unsafe { $ret };
|
||||
if ret == -1 {
|
||||
if ret < 0 {
|
||||
let errno = unsafe { libc::errno() };
|
||||
return_errno!(Errno::from(errno as u32), "libc error");
|
||||
}
|
||||
|
@ -1,5 +1,6 @@
|
||||
use super::*;
|
||||
|
||||
use net::{AsSocket, SocketFile};
|
||||
use process::Process;
|
||||
use rcore_fs::vfs::{FileType, FsError, INode, Metadata, Timespec};
|
||||
use std::io::{Read, Seek, SeekFrom, Write};
|
||||
@ -19,7 +20,6 @@ pub use self::io_multiplexing::*;
|
||||
pub use self::ioctl::*;
|
||||
pub use self::pipe::Pipe;
|
||||
pub use self::root_inode::ROOT_INODE;
|
||||
pub use self::socket_file::{AsSocket, SocketFile};
|
||||
pub use self::unix_socket::{AsUnixSocket, UnixSocketFile};
|
||||
use sgx_trts::libc::S_IWUSR;
|
||||
use std::any::Any;
|
||||
@ -39,7 +39,6 @@ mod ioctl;
|
||||
mod pipe;
|
||||
mod root_inode;
|
||||
mod sgx_impl;
|
||||
mod socket_file;
|
||||
mod unix_socket;
|
||||
|
||||
pub fn do_open(path: &str, flags: u32, mode: u32) -> Result<FileDesc> {
|
||||
|
@ -57,6 +57,7 @@ mod entry;
|
||||
mod exception;
|
||||
mod fs;
|
||||
mod misc;
|
||||
mod net;
|
||||
mod process;
|
||||
mod syscall;
|
||||
mod time;
|
||||
|
88
src/libos/src/net/iovs.rs
Normal file
88
src/libos/src/net/iovs.rs
Normal file
@ -0,0 +1,88 @@
|
||||
//! I/O vectors
|
||||
|
||||
use super::*;
|
||||
|
||||
/// A memory safe, immutable version of C iovec array
|
||||
pub struct Iovs<'a> {
|
||||
iovs: Vec<&'a [u8]>,
|
||||
}
|
||||
|
||||
impl<'a> Iovs<'a> {
|
||||
pub fn new(slices: Vec<&'a [u8]>) -> Iovs {
|
||||
Self { iovs: slices }
|
||||
}
|
||||
|
||||
pub fn as_slices(&self) -> &[&[u8]] {
|
||||
&self.iovs[..]
|
||||
}
|
||||
|
||||
pub fn total_bytes(&self) -> usize {
|
||||
self.iovs.iter().map(|s| s.len()).sum()
|
||||
}
|
||||
|
||||
pub fn gather_to_vec(&self) -> Vec<u8> {
|
||||
Self::gather_slices_to_vec(&self.iovs[..])
|
||||
}
|
||||
|
||||
fn gather_slices_to_vec(slices: &[&[u8]]) -> Vec<u8> {
|
||||
let vec_len = slices.iter().map(|slice| slice.len()).sum();
|
||||
let mut vec = Vec::with_capacity(vec_len);
|
||||
for slice in slices {
|
||||
vec.extend_from_slice(slice);
|
||||
}
|
||||
vec
|
||||
}
|
||||
}
|
||||
|
||||
/// A memory safe, mutable version of C iovec array
|
||||
pub struct IovsMut<'a> {
|
||||
iovs: Vec<&'a mut [u8]>,
|
||||
}
|
||||
|
||||
impl<'a> IovsMut<'a> {
|
||||
pub fn new(slices: Vec<&'a mut [u8]>) -> Self {
|
||||
Self { iovs: slices }
|
||||
}
|
||||
|
||||
pub fn as_slices<'b>(&'b self) -> &'b [&'a [u8]] {
|
||||
let slices_mut: &'b [&'a mut [u8]] = &self.iovs[..];
|
||||
// We are "downgrading" _mutable_ slices to _immutable_ ones. It should be
|
||||
// safe to do this transmute
|
||||
unsafe { std::mem::transmute(slices_mut) }
|
||||
}
|
||||
|
||||
pub fn as_slices_mut<'b>(&'b mut self) -> &'b mut [&'a mut [u8]] {
|
||||
&mut self.iovs[..]
|
||||
}
|
||||
|
||||
pub fn total_bytes(&self) -> usize {
|
||||
self.iovs.iter().map(|s| s.len()).sum()
|
||||
}
|
||||
|
||||
pub fn gather_to_vec(&self) -> Vec<u8> {
|
||||
Iovs::gather_slices_to_vec(self.as_slices())
|
||||
}
|
||||
|
||||
pub fn scatter_copy_from(&mut self, data: &[u8]) -> usize {
|
||||
let mut total_nbytes = 0;
|
||||
let mut remain_slice = data;
|
||||
for iov in &mut self.iovs {
|
||||
if remain_slice.len() == 0 {
|
||||
break;
|
||||
}
|
||||
|
||||
let copy_nbytes = remain_slice.len().min(iov.len());
|
||||
let dst_slice = unsafe {
|
||||
debug_assert!(iov.len() >= copy_nbytes);
|
||||
iov.get_unchecked_mut(..copy_nbytes)
|
||||
};
|
||||
let (src_slice, _remain_slice) = remain_slice.split_at(copy_nbytes);
|
||||
dst_slice.copy_from_slice(src_slice);
|
||||
|
||||
remain_slice = _remain_slice;
|
||||
total_nbytes += copy_nbytes;
|
||||
}
|
||||
debug_assert!(remain_slice.len() == 0);
|
||||
total_nbytes
|
||||
}
|
||||
}
|
13
src/libos/src/net/mod.rs
Normal file
13
src/libos/src/net/mod.rs
Normal file
@ -0,0 +1,13 @@
|
||||
use super::*;
|
||||
|
||||
mod iovs;
|
||||
mod msg;
|
||||
mod msg_flags;
|
||||
mod socket_file;
|
||||
mod syscalls;
|
||||
|
||||
pub use self::iovs::{Iovs, IovsMut};
|
||||
pub use self::msg::{msghdr, msghdr_mut, MsgHdr, MsgHdrMut};
|
||||
pub use self::msg_flags::MsgFlags;
|
||||
pub use self::socket_file::{AsSocket, SocketFile};
|
||||
pub use self::syscalls::{do_recvmsg, do_sendmsg};
|
231
src/libos/src/net/msg.rs
Normal file
231
src/libos/src/net/msg.rs
Normal file
@ -0,0 +1,231 @@
|
||||
/// Socket message and its flags.
|
||||
use super::*;
|
||||
|
||||
/// C struct for a socket message with const pointers
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct msghdr {
|
||||
pub msg_name: *const c_void,
|
||||
pub msg_namelen: libc::socklen_t,
|
||||
pub msg_iov: *const libc::iovec,
|
||||
pub msg_iovlen: size_t,
|
||||
pub msg_control: *const c_void,
|
||||
pub msg_controllen: size_t,
|
||||
pub msg_flags: c_int,
|
||||
}
|
||||
|
||||
/// C struct for a socket message with mutable pointers
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct msghdr_mut {
|
||||
pub msg_name: *mut c_void,
|
||||
pub msg_namelen: libc::socklen_t,
|
||||
pub msg_iov: *mut libc::iovec,
|
||||
pub msg_iovlen: size_t,
|
||||
pub msg_control: *mut c_void,
|
||||
pub msg_controllen: size_t,
|
||||
pub msg_flags: c_int,
|
||||
}
|
||||
|
||||
/// MsgHdr is a memory-safe, immutable wrapper of msghdr
|
||||
pub struct MsgHdr<'a> {
|
||||
name: Option<&'a [u8]>,
|
||||
iovs: Iovs<'a>,
|
||||
control: Option<&'a [u8]>,
|
||||
flags: MsgFlags,
|
||||
c_self: &'a msghdr,
|
||||
}
|
||||
|
||||
impl<'a> MsgHdr<'a> {
|
||||
/// Wrap a unsafe msghdr into a safe MsgHdr
|
||||
pub unsafe fn from_c(c_msg: &'a msghdr) -> Result<MsgHdr> {
|
||||
// Convert c_msg's (*mut T, usize)-pair fields to Option<&mut [T]>
|
||||
let name_opt_slice =
|
||||
new_optional_slice(c_msg.msg_name as *const u8, c_msg.msg_namelen as usize);
|
||||
let iovs_opt_slice = new_optional_slice(
|
||||
c_msg.msg_iov as *const libc::iovec,
|
||||
c_msg.msg_iovlen as usize,
|
||||
);
|
||||
let control_opt_slice = new_optional_slice(
|
||||
c_msg.msg_control as *const u8,
|
||||
c_msg.msg_controllen as usize,
|
||||
);
|
||||
|
||||
let flags = MsgFlags::from_u32(c_msg.msg_flags as u32)?;
|
||||
|
||||
let iovs = {
|
||||
let iovs_vec = match iovs_opt_slice {
|
||||
Some(iovs_slice) => iovs_slice
|
||||
.iter()
|
||||
.flat_map(|iov| new_optional_slice(iov.iov_base as *const u8, iov.iov_len))
|
||||
.collect(),
|
||||
None => Vec::new(),
|
||||
};
|
||||
Iovs::new(iovs_vec)
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
name: name_opt_slice,
|
||||
iovs: iovs,
|
||||
control: control_opt_slice,
|
||||
flags: flags,
|
||||
c_self: c_msg,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn get_iovs(&self) -> &Iovs {
|
||||
&self.iovs
|
||||
}
|
||||
|
||||
pub fn get_name(&self) -> Option<&[u8]> {
|
||||
self.name
|
||||
}
|
||||
|
||||
pub fn get_control(&self) -> Option<&[u8]> {
|
||||
self.control
|
||||
}
|
||||
|
||||
pub fn get_flags(&self) -> MsgFlags {
|
||||
self.flags
|
||||
}
|
||||
}
|
||||
|
||||
/// MsgHdrMut is a memory-safe, mutable wrapper of msghdr_mut
|
||||
pub struct MsgHdrMut<'a> {
|
||||
name: Option<&'a mut [u8]>,
|
||||
iovs: IovsMut<'a>,
|
||||
control: Option<&'a mut [u8]>,
|
||||
flags: MsgFlags,
|
||||
c_self: &'a mut msghdr_mut,
|
||||
}
|
||||
|
||||
// TODO: use macros to eliminate redundant code between MsgHdr and MsgHdrMut
|
||||
impl<'a> MsgHdrMut<'a> {
|
||||
/// Wrap a unsafe msghdr_mut into a safe MsgHdrMut
|
||||
pub unsafe fn from_c(c_msg: &'a mut msghdr_mut) -> Result<MsgHdrMut> {
|
||||
// Convert c_msg's (*mut T, usize)-pair fields to Option<&mut [T]>
|
||||
let name_opt_slice =
|
||||
new_optional_slice_mut(c_msg.msg_name as *mut u8, c_msg.msg_namelen as usize);
|
||||
let iovs_opt_slice =
|
||||
new_optional_slice_mut(c_msg.msg_iov as *mut libc::iovec, c_msg.msg_iovlen as usize);
|
||||
let control_opt_slice =
|
||||
new_optional_slice_mut(c_msg.msg_control as *mut u8, c_msg.msg_controllen as usize);
|
||||
|
||||
let flags = MsgFlags::from_u32(c_msg.msg_flags as u32)?;
|
||||
|
||||
let iovs = {
|
||||
let iovs_vec = match iovs_opt_slice {
|
||||
Some(iovs_slice) => iovs_slice
|
||||
.iter()
|
||||
.flat_map(|iov| new_optional_slice_mut(iov.iov_base as *mut u8, iov.iov_len))
|
||||
.collect(),
|
||||
None => Vec::new(),
|
||||
};
|
||||
IovsMut::new(iovs_vec)
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
name: name_opt_slice,
|
||||
iovs: iovs,
|
||||
control: control_opt_slice,
|
||||
flags: flags,
|
||||
c_self: c_msg,
|
||||
})
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
// Immutable interfaces (same as MsgHdr)
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
|
||||
pub fn get_iovs(&self) -> &IovsMut {
|
||||
&self.iovs
|
||||
}
|
||||
|
||||
pub fn get_name(&self) -> Option<&[u8]> {
|
||||
self.name.as_ref().map(|name| &name[..])
|
||||
}
|
||||
|
||||
pub fn get_control(&self) -> Option<&[u8]> {
|
||||
self.control.as_ref().map(|control| &control[..])
|
||||
}
|
||||
|
||||
pub fn get_flags(&self) -> MsgFlags {
|
||||
self.flags
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
// Mutable interfaces (unique to MsgHdrMut)
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
|
||||
pub fn get_iovs_mut<'b>(&'b mut self) -> &'b mut IovsMut<'a> {
|
||||
&mut self.iovs
|
||||
}
|
||||
|
||||
pub fn get_name_mut(&mut self) -> Option<&mut [u8]> {
|
||||
self.name.as_mut().map(|name| &mut name[..])
|
||||
}
|
||||
|
||||
pub fn get_name_max_len(&self) -> usize {
|
||||
self.name.as_ref().map(|name| name.len()).unwrap_or(0)
|
||||
}
|
||||
|
||||
pub fn set_name_len(&mut self, new_name_len: usize) -> Result<()> {
|
||||
if new_name_len > self.get_name_max_len() {
|
||||
return_errno!(EINVAL, "new_name_len is too big");
|
||||
}
|
||||
self.c_self.msg_namelen = new_name_len as libc::socklen_t;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn get_control_mut(&mut self) -> Option<&mut [u8]> {
|
||||
self.control.as_mut().map(|control| &mut control[..])
|
||||
}
|
||||
|
||||
pub fn get_control_max_len(&self) -> usize {
|
||||
self.control
|
||||
.as_ref()
|
||||
.map(|control| control.len())
|
||||
.unwrap_or(0)
|
||||
}
|
||||
|
||||
pub fn set_control_len(&mut self, new_control_len: usize) -> Result<()> {
|
||||
if new_control_len > self.get_control_max_len() {
|
||||
return_errno!(EINVAL, "new_control_len is too big");
|
||||
}
|
||||
self.c_self.msg_controllen = new_control_len;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn get_name_and_control_mut(&mut self) -> (Option<&mut [u8]>, Option<&mut [u8]>) {
|
||||
(
|
||||
self.name.as_mut().map(|name| &mut name[..]),
|
||||
self.control.as_mut().map(|control| &mut control[..]),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn set_flags(&mut self, flags: MsgFlags) {
|
||||
self.flags = flags;
|
||||
self.c_self.msg_flags = flags.to_u32() as i32;
|
||||
}
|
||||
}
|
||||
|
||||
unsafe fn new_optional_slice<'a, T>(slice_ptr: *const T, slice_size: usize) -> Option<&'a [T]> {
|
||||
if !slice_ptr.is_null() {
|
||||
let slice = core::slice::from_raw_parts::<T>(slice_ptr, slice_size);
|
||||
Some(slice)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
unsafe fn new_optional_slice_mut<'a, T>(
|
||||
slice_ptr: *mut T,
|
||||
slice_size: usize,
|
||||
) -> Option<&'a mut [T]> {
|
||||
if !slice_ptr.is_null() {
|
||||
let slice = core::slice::from_raw_parts_mut::<T>(slice_ptr, slice_size);
|
||||
Some(slice)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
17
src/libos/src/net/msg_flags.rs
Normal file
17
src/libos/src/net/msg_flags.rs
Normal file
@ -0,0 +1,17 @@
|
||||
use super::*;
|
||||
|
||||
// TODO: use bitflag! to make this memory safe
|
||||
#[derive(Debug, Copy, Clone, Default)]
|
||||
pub struct MsgFlags {
|
||||
bits: u32,
|
||||
}
|
||||
|
||||
impl MsgFlags {
|
||||
pub fn from_u32(c_flags: u32) -> Result<MsgFlags> {
|
||||
Ok(MsgFlags { bits: 0 })
|
||||
}
|
||||
|
||||
pub fn to_u32(&self) -> u32 {
|
||||
self.bits
|
||||
}
|
||||
}
|
@ -1,16 +1,22 @@
|
||||
use super::*;
|
||||
|
||||
mod recv;
|
||||
mod send;
|
||||
|
||||
use fs::{File, FileRef, IoctlCmd};
|
||||
use std::any::Any;
|
||||
use std::io::{Read, Seek, SeekFrom, Write};
|
||||
|
||||
/// Native Linux socket
|
||||
#[derive(Debug)]
|
||||
pub struct SocketFile {
|
||||
fd: c_int,
|
||||
host_fd: c_int,
|
||||
}
|
||||
|
||||
impl SocketFile {
|
||||
pub fn new(domain: c_int, socket_type: c_int, protocol: c_int) -> Result<Self> {
|
||||
let ret = try_libc!(libc::ocall::socket(domain, socket_type, protocol));
|
||||
Ok(SocketFile { fd: ret })
|
||||
Ok(SocketFile { host_fd: ret })
|
||||
}
|
||||
|
||||
pub fn accept(
|
||||
@ -19,32 +25,28 @@ impl SocketFile {
|
||||
addr_len: *mut libc::socklen_t,
|
||||
flags: c_int,
|
||||
) -> Result<Self> {
|
||||
let ret = try_libc!(libc::ocall::accept4(self.fd, addr, addr_len, flags));
|
||||
Ok(SocketFile { fd: ret })
|
||||
let ret = try_libc!(libc::ocall::accept4(self.host_fd, addr, addr_len, flags));
|
||||
Ok(SocketFile { host_fd: ret })
|
||||
}
|
||||
|
||||
pub fn fd(&self) -> c_int {
|
||||
self.fd
|
||||
self.host_fd
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for SocketFile {
|
||||
fn drop(&mut self) {
|
||||
let ret = unsafe { libc::ocall::close(self.fd) };
|
||||
if ret < 0 {
|
||||
let errno = unsafe { libc::errno() };
|
||||
warn!(
|
||||
"socket (host fd: {}) close failed: errno = {}",
|
||||
self.fd, errno
|
||||
);
|
||||
}
|
||||
let ret = unsafe { libc::ocall::close(self.host_fd) };
|
||||
assert!(ret == 0);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: rewrite read/write/readv/writev as send/recv
|
||||
// TODO: implement readfrom/sendto
|
||||
impl File for SocketFile {
|
||||
fn read(&self, buf: &mut [u8]) -> Result<usize> {
|
||||
let ret = try_libc!(libc::ocall::read(
|
||||
self.fd,
|
||||
self.host_fd,
|
||||
buf.as_mut_ptr() as *mut c_void,
|
||||
buf.len()
|
||||
));
|
||||
@ -53,7 +55,7 @@ impl File for SocketFile {
|
||||
|
||||
fn write(&self, buf: &[u8]) -> Result<usize> {
|
||||
let ret = try_libc!(libc::ocall::write(
|
||||
self.fd,
|
||||
self.host_fd,
|
||||
buf.as_ptr() as *const c_void,
|
||||
buf.len()
|
||||
));
|
||||
@ -100,25 +102,6 @@ impl File for SocketFile {
|
||||
return_errno!(ESPIPE, "Socket does not support seek")
|
||||
}
|
||||
|
||||
fn metadata(&self) -> Result<Metadata> {
|
||||
Ok(Metadata {
|
||||
dev: 0,
|
||||
inode: 0,
|
||||
size: 0,
|
||||
blk_size: 0,
|
||||
blocks: 0,
|
||||
atime: Timespec { sec: 0, nsec: 0 },
|
||||
mtime: Timespec { sec: 0, nsec: 0 },
|
||||
ctime: Timespec { sec: 0, nsec: 0 },
|
||||
type_: FileType::Socket,
|
||||
mode: 0,
|
||||
nlinks: 0,
|
||||
uid: 0,
|
||||
gid: 0,
|
||||
rdev: 0,
|
||||
})
|
||||
}
|
||||
|
||||
fn ioctl(&self, cmd: &mut IoctlCmd) -> Result<()> {
|
||||
let cmd_num = cmd.cmd_num() as c_int;
|
||||
let cmd_arg_ptr = cmd.arg_ptr() as *const c_int;
|
138
src/libos/src/net/socket_file/recv.rs
Normal file
138
src/libos/src/net/socket_file/recv.rs
Normal file
@ -0,0 +1,138 @@
|
||||
use super::*;
|
||||
|
||||
impl SocketFile {
|
||||
// TODO: need sockaddr type to implement send/sento
|
||||
/*
|
||||
pub fn recv(&self, buf: &mut [u8], flags: MsgFlags) -> Result<usize> {
|
||||
let (bytes_recvd, _) = self.recvfrom(buf, flags, None)?;
|
||||
Ok(bytes_recvd)
|
||||
}
|
||||
|
||||
pub fn recvfrom(&self, buf: &mut [u8], flags: MsgFlags, src_addr: Option<&mut [u8]>) -> Result<(usize, usize)> {
|
||||
let (bytes_recvd, src_addr_len, _, _) = self.do_recvmsg(
|
||||
&mut buf[..],
|
||||
flags,
|
||||
src_addr,
|
||||
None,
|
||||
)?;
|
||||
Ok((bytes_recvd, src_addr_len))
|
||||
}*/
|
||||
|
||||
pub fn recvmsg<'a, 'b>(&self, msg: &'b mut MsgHdrMut<'a>, flags: MsgFlags) -> Result<usize> {
|
||||
// Allocate a single data buffer is big enough for all iovecs of msg.
|
||||
// This is a workaround for the OCall that takes only one data buffer.
|
||||
let mut data_buf = {
|
||||
let data_buf_len = msg.get_iovs().total_bytes();
|
||||
let data_vec = vec![0; data_buf_len];
|
||||
data_vec.into_boxed_slice()
|
||||
};
|
||||
|
||||
let (bytes_recvd, namelen_recvd, controllen_recvd, flags_recvd) = {
|
||||
let data = &mut data_buf[..];
|
||||
// Acquire mutable references to the name and control buffers
|
||||
let (name, control) = msg.get_name_and_control_mut();
|
||||
// Fill the data, the name, and the control buffers
|
||||
self.do_recvmsg(data, flags, name, control)?
|
||||
};
|
||||
|
||||
// Update the lengths and flags
|
||||
msg.set_name_len(namelen_recvd)?;
|
||||
msg.set_control_len(controllen_recvd)?;
|
||||
msg.set_flags(flags_recvd);
|
||||
|
||||
let recv_data = &data_buf[..bytes_recvd];
|
||||
// TODO: avoid this one extra copy due to the intermediate data buffer
|
||||
msg.get_iovs_mut().scatter_copy_from(recv_data);
|
||||
|
||||
Ok(bytes_recvd)
|
||||
}
|
||||
|
||||
fn do_recvmsg(
|
||||
&self,
|
||||
data: &mut [u8],
|
||||
flags: MsgFlags,
|
||||
mut name: Option<&mut [u8]>,
|
||||
mut control: Option<&mut [u8]>,
|
||||
) -> Result<(usize, usize, usize, MsgFlags)> {
|
||||
// Prepare the arguments for OCall
|
||||
// Host socket fd
|
||||
let host_fd = self.host_fd;
|
||||
// Name
|
||||
let (msg_name, msg_namelen) = name.get_mut_ptr_and_len();
|
||||
let msg_name = msg_name as *mut c_void;
|
||||
let mut msg_namelen_recvd = 0_u32;
|
||||
// Data
|
||||
let msg_data = data.as_mut_ptr();
|
||||
let msg_datalen = data.len();
|
||||
// Control
|
||||
let (msg_control, msg_controllen) = control.get_mut_ptr_and_len();
|
||||
let msg_control = msg_control as *mut c_void;
|
||||
let mut msg_controllen_recvd = 0;
|
||||
// Flags
|
||||
let flags = flags.to_u32() as i32;
|
||||
let mut msg_flags_recvd = 0;
|
||||
|
||||
// Do OCall
|
||||
let retval = try_libc!({
|
||||
let mut retval = 0_isize;
|
||||
let status = ocall_recvmsg(
|
||||
&mut retval as *mut isize,
|
||||
host_fd,
|
||||
msg_name,
|
||||
msg_namelen as u32,
|
||||
&mut msg_namelen_recvd as *mut u32,
|
||||
msg_data,
|
||||
msg_datalen,
|
||||
msg_control,
|
||||
msg_controllen,
|
||||
&mut msg_controllen_recvd as *mut usize,
|
||||
&mut msg_flags_recvd as *mut i32,
|
||||
flags,
|
||||
);
|
||||
assert!(status == sgx_status_t::SGX_SUCCESS);
|
||||
|
||||
// TODO: what if retval < 0 but buffers are modified by the
|
||||
// untrusted OCall? We reset the potentially tampered buffers.
|
||||
retval
|
||||
});
|
||||
|
||||
// Check values returned from outside the enclave
|
||||
let bytes_recvd = {
|
||||
// Guarantted by try_libc!
|
||||
debug_assert!(retval >= 0);
|
||||
let retval = retval as usize;
|
||||
|
||||
// Check bytes_recvd returned from outside the enclave
|
||||
assert!(retval <= data.len());
|
||||
retval
|
||||
};
|
||||
let msg_namelen_recvd = msg_namelen_recvd as usize;
|
||||
assert!(msg_namelen_recvd <= msg_namelen);
|
||||
assert!(msg_controllen_recvd <= msg_controllen);
|
||||
let flags_recvd = MsgFlags::from_u32(msg_flags_recvd as u32)?;
|
||||
|
||||
Ok((
|
||||
bytes_recvd,
|
||||
msg_namelen_recvd,
|
||||
msg_controllen_recvd,
|
||||
flags_recvd,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
fn ocall_recvmsg(
|
||||
ret: *mut ssize_t,
|
||||
fd: c_int,
|
||||
msg_name: *mut c_void,
|
||||
msg_namelen: libc::socklen_t,
|
||||
msg_namelen_recv: *mut libc::socklen_t,
|
||||
msg_data: *mut u8,
|
||||
msg_data: size_t,
|
||||
msg_control: *mut c_void,
|
||||
msg_controllen: size_t,
|
||||
msg_controllen_recv: *mut size_t,
|
||||
msg_flags: *mut c_int,
|
||||
flags: c_int,
|
||||
) -> sgx_status_t;
|
||||
}
|
84
src/libos/src/net/socket_file/send.rs
Normal file
84
src/libos/src/net/socket_file/send.rs
Normal file
@ -0,0 +1,84 @@
|
||||
use super::*;
|
||||
|
||||
impl SocketFile {
|
||||
// TODO: need sockaddr type to implement send/sento
|
||||
/*
|
||||
pub fn send(&self, buf: &[u8], flags: MsgFlags) -> Result<usize> {
|
||||
self.sendto(buf, flags, None)
|
||||
}
|
||||
|
||||
pub fn sendto(&self, buf: &[u8], flags: MsgFlags, dest_addr: Option<&[u8]>) -> Result<usize> {
|
||||
Self::do_sendmsg(
|
||||
self.host_fd,
|
||||
&buf[..],
|
||||
flags,
|
||||
dest_addr,
|
||||
None)
|
||||
}
|
||||
*/
|
||||
|
||||
pub fn sendmsg<'a, 'b>(&self, msg: &'b MsgHdr<'a>, flags: MsgFlags) -> Result<usize> {
|
||||
// Copy data in iovs into a single buffer
|
||||
let data_buf = msg.get_iovs().gather_to_vec();
|
||||
|
||||
self.do_sendmsg(&data_buf[..], flags, msg.get_name(), msg.get_control())
|
||||
}
|
||||
|
||||
fn do_sendmsg(
|
||||
&self,
|
||||
data: &[u8],
|
||||
flags: MsgFlags,
|
||||
name: Option<&[u8]>,
|
||||
control: Option<&[u8]>,
|
||||
) -> Result<usize> {
|
||||
let bytes_sent = try_libc!({
|
||||
// Prepare the arguments for OCall
|
||||
let mut retval: isize = 0;
|
||||
// Host socket fd
|
||||
let host_fd = self.host_fd;
|
||||
// Name
|
||||
let (msg_name, msg_namelen) = name.get_ptr_and_len();
|
||||
let msg_name = msg_name as *const c_void;
|
||||
// Data
|
||||
let msg_data = data.as_ptr();
|
||||
let msg_datalen = data.len();
|
||||
// Control
|
||||
let (msg_control, msg_controllen) = control.get_ptr_and_len();
|
||||
let msg_control = msg_control as *const c_void;
|
||||
// Flags
|
||||
let flags = flags.to_u32() as i32;
|
||||
|
||||
// Do OCall
|
||||
let status = ocall_sendmsg(
|
||||
&mut retval as *mut isize,
|
||||
host_fd,
|
||||
msg_name,
|
||||
msg_namelen as u32,
|
||||
msg_data,
|
||||
msg_datalen,
|
||||
msg_control,
|
||||
msg_controllen,
|
||||
flags,
|
||||
);
|
||||
assert!(status == sgx_status_t::SGX_SUCCESS);
|
||||
|
||||
retval
|
||||
});
|
||||
debug_assert!(bytes_sent >= 0);
|
||||
Ok(bytes_sent as usize)
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
fn ocall_sendmsg(
|
||||
ret: *mut ssize_t,
|
||||
fd: c_int,
|
||||
msg_name: *const c_void,
|
||||
msg_namelen: libc::socklen_t,
|
||||
msg_data: *const u8,
|
||||
msg_datalen: size_t,
|
||||
msg_control: *const c_void,
|
||||
msg_controllen: size_t,
|
||||
flags: c_int,
|
||||
) -> sgx_status_t;
|
||||
}
|
126
src/libos/src/net/syscalls.rs
Normal file
126
src/libos/src/net/syscalls.rs
Normal file
@ -0,0 +1,126 @@
|
||||
use super::*;
|
||||
|
||||
use fs::{AsUnixSocket, File, FileDesc, FileRef, UnixSocketFile};
|
||||
use process::Process;
|
||||
use util::mem_util::from_user;
|
||||
|
||||
pub fn do_sendmsg(fd: c_int, msg_ptr: *const msghdr, flags_c: c_int) -> Result<isize> {
|
||||
info!(
|
||||
"sendmsg: fd: {}, msg: {:?}, flags: 0x{:x}",
|
||||
fd, msg_ptr, flags_c
|
||||
);
|
||||
let current_ref = process::get_current();
|
||||
let mut proc = current_ref.lock().unwrap();
|
||||
let file_ref = proc.get_files().lock().unwrap().get(fd as FileDesc)?;
|
||||
|
||||
if let Ok(socket) = file_ref.as_socket() {
|
||||
let msg_c = {
|
||||
from_user::check_ptr(msg_ptr)?;
|
||||
let msg_c = unsafe { &*msg_ptr };
|
||||
msg_c.check_member_ptrs()?;
|
||||
msg_c
|
||||
};
|
||||
let msg = unsafe { MsgHdr::from_c(&msg_c)? };
|
||||
|
||||
let flags = MsgFlags::from_u32(flags_c as u32)?;
|
||||
|
||||
socket
|
||||
.sendmsg(&msg, flags)
|
||||
.map(|bytes_sent| bytes_sent as isize)
|
||||
} else if let Ok(socket) = file_ref.as_unix_socket() {
|
||||
return_errno!(EBADF, "does not support unix socket")
|
||||
} else {
|
||||
return_errno!(EBADF, "not a socket")
|
||||
}
|
||||
}
|
||||
|
||||
pub fn do_recvmsg(fd: c_int, msg_mut_ptr: *mut msghdr_mut, flags_c: c_int) -> Result<isize> {
|
||||
info!(
|
||||
"recvmsg: fd: {}, msg: {:?}, flags: 0x{:x}",
|
||||
fd, msg_mut_ptr, flags_c
|
||||
);
|
||||
let current_ref = process::get_current();
|
||||
let mut proc = current_ref.lock().unwrap();
|
||||
let file_ref = proc.get_files().lock().unwrap().get(fd as FileDesc)?;
|
||||
|
||||
if let Ok(socket) = file_ref.as_socket() {
|
||||
let msg_mut_c = {
|
||||
from_user::check_mut_ptr(msg_mut_ptr)?;
|
||||
let msg_mut_c = unsafe { &mut *msg_mut_ptr };
|
||||
msg_mut_c.check_member_ptrs()?;
|
||||
msg_mut_c
|
||||
};
|
||||
let mut msg_mut = unsafe { MsgHdrMut::from_c(msg_mut_c)? };
|
||||
|
||||
let flags = MsgFlags::from_u32(flags_c as u32)?;
|
||||
|
||||
socket
|
||||
.recvmsg(&mut msg_mut, flags)
|
||||
.map(|bytes_recvd| bytes_recvd as isize)
|
||||
} else if let Ok(socket) = file_ref.as_unix_socket() {
|
||||
return_errno!(EBADF, "does not support unix socket")
|
||||
} else {
|
||||
return_errno!(EBADF, "not a socket")
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(non_camel_case_types)]
|
||||
trait c_msghdr_ext {
|
||||
fn check_member_ptrs(&self) -> Result<()>;
|
||||
}
|
||||
|
||||
impl c_msghdr_ext for msghdr {
|
||||
// TODO: implement this!
|
||||
fn check_member_ptrs(&self) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
/*
|
||||
///user space check
|
||||
pub unsafe fn check_from_user(user_hdr: *const msghdr) -> Result<()> {
|
||||
Self::check_pointer(user_hdr, from_user::check_ptr)
|
||||
}
|
||||
|
||||
///Check msghdr ptr
|
||||
pub unsafe fn check_pointer(
|
||||
user_hdr: *const msghdr,
|
||||
check_ptr: fn(*const u8) -> Result<()>,
|
||||
) -> Result<()> {
|
||||
check_ptr(user_hdr as *const u8)?;
|
||||
|
||||
if (*user_hdr).msg_name.is_null() ^ ((*user_hdr).msg_namelen == 0) {
|
||||
return_errno!(EINVAL, "name length is invalid");
|
||||
}
|
||||
|
||||
if (*user_hdr).msg_iov.is_null() ^ ((*user_hdr).msg_iovlen == 0) {
|
||||
return_errno!(EINVAL, "iov length is invalid");
|
||||
}
|
||||
|
||||
if (*user_hdr).msg_control.is_null() ^ ((*user_hdr).msg_controllen == 0) {
|
||||
return_errno!(EINVAL, "control length is invalid");
|
||||
}
|
||||
|
||||
if !(*user_hdr).msg_name.is_null() {
|
||||
check_ptr((*user_hdr).msg_name as *const u8)?;
|
||||
}
|
||||
|
||||
if !(*user_hdr).msg_iov.is_null() {
|
||||
check_ptr((*user_hdr).msg_iov as *const u8)?;
|
||||
let iov_slice = slice::from_raw_parts((*user_hdr).msg_iov, (*user_hdr).msg_iovlen);
|
||||
for iov in iov_slice {
|
||||
check_ptr(iov.iov_base as *const u8)?;
|
||||
}
|
||||
}
|
||||
|
||||
if !(*user_hdr).msg_control.is_null() {
|
||||
check_ptr((*user_hdr).msg_control as *const u8)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
*/
|
||||
}
|
||||
|
||||
impl c_msghdr_ext for msghdr_mut {
|
||||
fn check_member_ptrs(&self) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
@ -31,3 +31,29 @@ pub fn align_down(addr: usize, align: usize) -> usize {
|
||||
pub fn unbox<T>(value: Box<T>) -> T {
|
||||
*value
|
||||
}
|
||||
|
||||
pub trait SliceOptionExt<T> {
|
||||
fn get_ptr_and_len(&self) -> (*const T, usize);
|
||||
}
|
||||
|
||||
impl<T> SliceOptionExt<T> for Option<&[T]> {
|
||||
fn get_ptr_and_len(&self) -> (*const T, usize) {
|
||||
match self {
|
||||
Some(self_slice) => (self_slice.as_ptr(), self_slice.len()),
|
||||
None => (std::ptr::null(), 0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait MutSliceOptionExt<T> {
|
||||
fn get_mut_ptr_and_len(&mut self) -> (*mut T, usize);
|
||||
}
|
||||
|
||||
impl<T> MutSliceOptionExt<T> for Option<&mut [T]> {
|
||||
fn get_mut_ptr_and_len(&mut self) -> (*mut T, usize) {
|
||||
match self {
|
||||
Some(self_slice) => (self_slice.as_mut_ptr(), self_slice.len()),
|
||||
None => (std::ptr::null_mut(), 0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -7,8 +7,12 @@
|
||||
//! 3. Dispatch the syscall to `do_*` (at this file)
|
||||
//! 4. Do some memory checks then call `mod::do_*` (at each module)
|
||||
|
||||
use fs::*;
|
||||
use fs::{
|
||||
AccessFlags, AccessModes, AsUnixSocket, FcntlCmd, File, FileDesc, FileRef, IoctlCmd,
|
||||
UnixSocketFile, AT_FDCWD,
|
||||
};
|
||||
use misc::{resource_t, rlimit_t, utsname_t};
|
||||
use net::{msghdr, msghdr_mut, AsSocket, SocketFile};
|
||||
use process::{pid_t, ChildProcessFilter, CloneFlags, CpuSet, FileAction, FutexFlags, FutexOp};
|
||||
use std::ffi::{CStr, CString};
|
||||
use std::ptr;
|
||||
@ -307,6 +311,7 @@ pub extern "C" fn dispatch_syscall(
|
||||
arg4 as *mut libc::sockaddr,
|
||||
arg5 as *mut libc::socklen_t,
|
||||
),
|
||||
|
||||
SYS_SOCKETPAIR => do_socketpair(
|
||||
arg0 as c_int,
|
||||
arg1 as c_int,
|
||||
@ -314,6 +319,9 @@ pub extern "C" fn dispatch_syscall(
|
||||
arg3 as *mut c_int,
|
||||
),
|
||||
|
||||
SYS_SENDMSG => net::do_sendmsg(arg0 as c_int, arg1 as *const msghdr, arg2 as c_int),
|
||||
SYS_RECVMSG => net::do_recvmsg(arg0 as c_int, arg1 as *mut msghdr_mut, arg2 as c_int),
|
||||
|
||||
_ => do_unknown(num, arg0, arg1, arg2, arg3, arg4, arg5),
|
||||
};
|
||||
|
||||
|
@ -15,6 +15,7 @@
|
||||
#include <sys/syscall.h>
|
||||
#include <sys/time.h>
|
||||
#include <unistd.h>
|
||||
#include <sys/socket.h>
|
||||
|
||||
#include <sgx_eid.h>
|
||||
#include <sgx_error.h>
|
||||
@ -296,6 +297,65 @@ void ocall_sync(void) {
|
||||
sync();
|
||||
}
|
||||
|
||||
ssize_t ocall_sendmsg(int sockfd,
|
||||
const void *msg_name,
|
||||
socklen_t msg_namelen,
|
||||
const void *buf,
|
||||
size_t buf_len,
|
||||
const void *msg_control,
|
||||
size_t msg_controllen,
|
||||
int flags)
|
||||
{
|
||||
struct iovec msg_iov = { .iov_base = (void*)buf, .iov_len = buf_len };
|
||||
struct iovec* p_msg_iov = buf != NULL ? &msg_iov : NULL;
|
||||
size_t msg_iovlen = buf != NULL ? 1 : 0;
|
||||
|
||||
struct msghdr msg = {
|
||||
(void*) msg_name,
|
||||
msg_namelen,
|
||||
p_msg_iov,
|
||||
msg_iovlen,
|
||||
(void*) msg_control,
|
||||
msg_controllen,
|
||||
0,
|
||||
};
|
||||
return sendmsg(sockfd, &msg, flags);
|
||||
}
|
||||
|
||||
ssize_t ocall_recvmsg(int sockfd,
|
||||
void *msg_name,
|
||||
socklen_t msg_namelen,
|
||||
socklen_t* msg_namelen_recv,
|
||||
void *buf,
|
||||
size_t buf_len,
|
||||
void *msg_control,
|
||||
size_t msg_controllen,
|
||||
size_t* msg_controllen_recv,
|
||||
int* msg_flags_recv,
|
||||
int flags)
|
||||
{
|
||||
struct iovec msg_iov = { .iov_base = buf, .iov_len = buf_len };
|
||||
struct iovec* p_msg_iov = buf != NULL ? &msg_iov : NULL;
|
||||
size_t msg_iovlen = buf != NULL ? 1 : 0;
|
||||
|
||||
struct msghdr msg = {
|
||||
msg_name,
|
||||
msg_namelen,
|
||||
p_msg_iov,
|
||||
msg_iovlen,
|
||||
msg_control,
|
||||
msg_controllen,
|
||||
0,
|
||||
};
|
||||
ssize_t ret = recvmsg(sockfd, &msg, flags);
|
||||
if (ret < 0) return ret;
|
||||
|
||||
*msg_namelen_recv = msg.msg_namelen;
|
||||
*msg_controllen_recv = msg.msg_controllen;
|
||||
*msg_flags_recv = msg.msg_flags;
|
||||
return ret;
|
||||
}
|
||||
|
||||
// ==========================================================================
|
||||
// Main
|
||||
// ==========================================================================
|
||||
|
@ -9,45 +9,138 @@
|
||||
#include <spawn.h>
|
||||
#include <unistd.h>
|
||||
|
||||
int main(int argc, const char *argv[]) {
|
||||
const char* message = "Hello world!";
|
||||
int ret;
|
||||
#include "test.h"
|
||||
|
||||
if (argc != 3) {
|
||||
printf("usage: ./client <ipaddress> <port>\n");
|
||||
return 0;
|
||||
}
|
||||
#define RESPONSE "ACK"
|
||||
#define DEFAULT_MSG "Hello World!\n"
|
||||
|
||||
int sockfd = socket(AF_INET, SOCK_STREAM, 0);
|
||||
if (sockfd < 0) {
|
||||
printf("create socket error: %s(errno: %d)\n", strerror(errno), errno);
|
||||
return -1;
|
||||
}
|
||||
int connect_with_server(const char *addr_string, const char *port_string) {
|
||||
//"NULL" addr means connectionless, no need to connect to server
|
||||
if (strcmp(addr_string, "NULL") == 0)
|
||||
return 0;
|
||||
|
||||
struct sockaddr_in servaddr;
|
||||
memset(&servaddr, 0, sizeof(servaddr));
|
||||
servaddr.sin_family = AF_INET;
|
||||
servaddr.sin_port = htons((uint16_t)strtol(argv[2], NULL, 10));
|
||||
int ret = 0;
|
||||
int sockfd = socket(AF_INET, SOCK_STREAM, 0);
|
||||
if (sockfd < 0)
|
||||
THROW_ERROR("create socket error");
|
||||
|
||||
ret = inet_pton(AF_INET, argv[1], &servaddr.sin_addr);
|
||||
if (ret <= 0) {
|
||||
printf("inet_pton error for %s\n", argv[1]);
|
||||
return -1;
|
||||
}
|
||||
struct sockaddr_in servaddr;
|
||||
memset(&servaddr, 0, sizeof(servaddr));
|
||||
servaddr.sin_family = AF_INET;
|
||||
servaddr.sin_port = htons((uint16_t)strtol(port_string, NULL, 10));
|
||||
ret = inet_pton(AF_INET, addr_string, &servaddr.sin_addr);
|
||||
if (ret <= 0) {
|
||||
close(sockfd);
|
||||
THROW_ERROR("inet_pton error");
|
||||
}
|
||||
|
||||
ret = connect(sockfd, (struct sockaddr *) &servaddr, sizeof(servaddr));
|
||||
if (ret < 0) {
|
||||
printf("connect error: %s(errno: %d)\n", strerror(errno), errno);
|
||||
return -1;
|
||||
}
|
||||
ret = connect(sockfd, (struct sockaddr *) &servaddr, sizeof(servaddr));
|
||||
if (ret < 0) {
|
||||
close(sockfd);
|
||||
THROW_ERROR("connect error");
|
||||
}
|
||||
|
||||
printf("send msg to server: %s\n", message);
|
||||
ret = send(sockfd, message, strlen(message), 0);
|
||||
if (ret < 0) {
|
||||
printf("send msg error: %s(errno: %d)\n", strerror(errno), errno);
|
||||
return -1;
|
||||
}
|
||||
|
||||
close(sockfd);
|
||||
return 0;
|
||||
return sockfd;
|
||||
}
|
||||
|
||||
int neogotiate_msg(int server_fd, char *buf, int buf_size) {
|
||||
if (read(server_fd, buf, buf_size) < 0)
|
||||
THROW_ERROR("read failed");
|
||||
|
||||
if (write(server_fd, RESPONSE, sizeof(RESPONSE)) < 0) {
|
||||
THROW_ERROR("write failed");
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
int client_send(int server_fd, char *buf) {
|
||||
if (send(server_fd, buf, strlen(buf), 0) < 0)
|
||||
THROW_ERROR("send msg error");
|
||||
return 0;
|
||||
}
|
||||
|
||||
int client_sendmsg(int server_fd, char *buf) {
|
||||
int ret = 0;
|
||||
struct msghdr msg;
|
||||
struct iovec iov[1];
|
||||
msg.msg_name = NULL;
|
||||
msg.msg_namelen = 0;
|
||||
iov[0].iov_base = buf;
|
||||
iov[0].iov_len = strlen(buf);
|
||||
msg.msg_iov = iov;
|
||||
msg.msg_iovlen = 1;
|
||||
msg.msg_control = 0;
|
||||
msg.msg_controllen = 0;
|
||||
msg.msg_flags = 0;
|
||||
|
||||
ret = sendmsg(server_fd, &msg, 0);
|
||||
if (ret <= 0)
|
||||
THROW_ERROR("sendmsg failed");
|
||||
return ret;
|
||||
}
|
||||
|
||||
int client_connectionless_sendmsg(char *buf) {
|
||||
int ret = 0;
|
||||
struct msghdr msg;
|
||||
struct iovec iov[1];
|
||||
struct sockaddr_in servaddr;
|
||||
memset(&servaddr, 0, sizeof(servaddr));
|
||||
|
||||
servaddr.sin_family = AF_INET;
|
||||
servaddr.sin_port = htons(9900);
|
||||
servaddr.sin_addr.s_addr= htonl(INADDR_ANY);
|
||||
|
||||
msg.msg_name = &servaddr;
|
||||
msg.msg_namelen = sizeof(servaddr);
|
||||
iov[0].iov_base = buf;
|
||||
iov[0].iov_len = strlen(buf);
|
||||
msg.msg_iov = iov;
|
||||
msg.msg_iovlen = 1;
|
||||
msg.msg_control = 0;
|
||||
msg.msg_controllen = 0;
|
||||
msg.msg_flags = 0;
|
||||
|
||||
int server_fd = socket(AF_INET, SOCK_DGRAM, 0);
|
||||
if (server_fd < 0)
|
||||
THROW_ERROR("create socket error");
|
||||
|
||||
ret = sendmsg(server_fd, &msg, 0);
|
||||
if (ret <= 0)
|
||||
THROW_ERROR("sendmsg failed");
|
||||
return ret;
|
||||
}
|
||||
|
||||
int main(int argc, const char *argv[]) {
|
||||
if (argc != 3) {
|
||||
THROW_ERROR("usage: ./client <ipaddress> <port>\n");
|
||||
}
|
||||
|
||||
int ret = 0;
|
||||
const int buf_size = 100;
|
||||
char buf[buf_size];
|
||||
int port = strtol(argv[2], NULL, 10);
|
||||
int server_fd = connect_with_server(argv[1], argv[2]);
|
||||
|
||||
switch (port)
|
||||
{
|
||||
case 8800:
|
||||
neogotiate_msg(server_fd, buf, buf_size);
|
||||
break;
|
||||
case 8801:
|
||||
neogotiate_msg(server_fd, buf, buf_size);
|
||||
ret = client_send(server_fd, buf);
|
||||
break;
|
||||
case 8802:
|
||||
neogotiate_msg(server_fd, buf, buf_size);
|
||||
ret = client_sendmsg(server_fd, buf);
|
||||
break;
|
||||
case 8803:
|
||||
ret = client_connectionless_sendmsg(DEFAULT_MSG);
|
||||
break;
|
||||
default:
|
||||
ret = client_send(server_fd, DEFAULT_MSG);
|
||||
}
|
||||
|
||||
close(server_fd);
|
||||
return ret;
|
||||
}
|
||||
|
@ -4,62 +4,259 @@
|
||||
#include <errno.h>
|
||||
#include <sys/types.h>
|
||||
#include <sys/socket.h>
|
||||
#include <sys/wait.h>
|
||||
#include <netinet/in.h>
|
||||
#include <arpa/inet.h>
|
||||
#include <spawn.h>
|
||||
#include <unistd.h>
|
||||
|
||||
int main(int argc, const char *argv[]) {
|
||||
const int BUF_SIZE = 0x1000;
|
||||
int ret;
|
||||
#include "test.h"
|
||||
|
||||
int listenfd = socket(AF_INET, SOCK_STREAM, 0);
|
||||
if (listenfd < 0) {
|
||||
printf("create socket error: %s(errno: %d)\n", strerror(errno), errno);
|
||||
return -1;
|
||||
}
|
||||
#define ECHO_MSG "msg for client/server test"
|
||||
#define RESPONSE "ACK"
|
||||
#define DEFAULT_MSG "Hello World!\n"
|
||||
|
||||
struct sockaddr_in servaddr;
|
||||
memset(&servaddr, 0, sizeof(servaddr));
|
||||
servaddr.sin_family = AF_INET;
|
||||
servaddr.sin_addr.s_addr = htonl(INADDR_ANY);
|
||||
servaddr.sin_port = htons(6666);
|
||||
int connect_with_child(int port, int *child_pid) {
|
||||
int ret = 0;
|
||||
int listen_fd = socket(AF_INET, SOCK_STREAM, 0);
|
||||
if (listen_fd < 0)
|
||||
THROW_ERROR("create socket error");
|
||||
int reuse = 1;
|
||||
if (setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)) < 0)
|
||||
THROW_ERROR("setsockopt port to reuse failed");
|
||||
|
||||
int reuse = 1;
|
||||
if (setsockopt(listenfd, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)) < 0)
|
||||
perror("setsockopt port to reuse failed");
|
||||
struct sockaddr_in servaddr;
|
||||
memset(&servaddr, 0, sizeof(servaddr));
|
||||
servaddr.sin_family = AF_INET;
|
||||
servaddr.sin_addr.s_addr = htonl(INADDR_ANY);
|
||||
servaddr.sin_port = htons(port);
|
||||
ret = bind(listen_fd, (struct sockaddr *) &servaddr, sizeof(servaddr));
|
||||
if (ret < 0) {
|
||||
close(listen_fd);
|
||||
THROW_ERROR("bind socket failed");
|
||||
}
|
||||
|
||||
ret = bind(listenfd, (struct sockaddr *) &servaddr, sizeof(servaddr));
|
||||
if (ret < 0) {
|
||||
printf("bind socket error: %s(errno: %d)\n", strerror(errno), errno);
|
||||
return -1;
|
||||
}
|
||||
ret = listen(listen_fd, 10);
|
||||
if (ret < 0) {
|
||||
close(listen_fd);
|
||||
THROW_ERROR("listen socket error");
|
||||
}
|
||||
|
||||
ret = listen(listenfd, 10);
|
||||
if (ret < 0) {
|
||||
printf("listen socket error: %s(errno: %d)\n", strerror(errno), errno);
|
||||
return -1;
|
||||
}
|
||||
char port_string[8];
|
||||
sprintf(port_string, "%d", port);
|
||||
char* client_argv[] = {"client", "127.0.0.1", port_string, NULL};
|
||||
ret = posix_spawn(child_pid, "/bin/client", NULL, NULL, client_argv, NULL);
|
||||
if (ret < 0) {
|
||||
close(listen_fd);
|
||||
THROW_ERROR("spawn client process error");
|
||||
}
|
||||
|
||||
int client_pid;
|
||||
char* client_argv[] = {"client", "127.0.0.1", "6666", NULL};
|
||||
ret = posix_spawn(&client_pid, "/bin/client", NULL, NULL, client_argv, NULL);
|
||||
if (ret < 0) {
|
||||
printf("spawn client process error: %s(errno: %d)\n", strerror(errno), errno);
|
||||
return -1;
|
||||
}
|
||||
int connected_fd = accept(listen_fd, (struct sockaddr *) NULL, NULL);
|
||||
if (connected_fd < 0) {
|
||||
close(listen_fd);
|
||||
THROW_ERROR("accept socket error");
|
||||
}
|
||||
|
||||
printf("====== waiting for client's request ======\n");
|
||||
int connect_fd = accept(listenfd, (struct sockaddr *) NULL, NULL);
|
||||
if (connect_fd < 0) {
|
||||
printf("accept socket error: %s(errno: %d)", strerror(errno), errno);
|
||||
return -1;
|
||||
}
|
||||
char buff[BUF_SIZE];
|
||||
int n = recv(connect_fd, buff, BUF_SIZE, 0);
|
||||
buff[n] = '\0';
|
||||
printf("recv msg from client: %s\n", buff);
|
||||
|
||||
close(connect_fd);
|
||||
close(listenfd);
|
||||
return 0;
|
||||
close(listen_fd);
|
||||
return connected_fd;
|
||||
}
|
||||
|
||||
int neogotiate_msg(int client_fd) {
|
||||
char buf[16];
|
||||
if (write(client_fd, ECHO_MSG, sizeof(ECHO_MSG)) < 0)
|
||||
THROW_ERROR("write failed");
|
||||
|
||||
if (read(client_fd, buf, 16) < 0)
|
||||
THROW_ERROR("read failed");
|
||||
|
||||
if (strncmp(buf, RESPONSE, sizeof(RESPONSE)) != 0) {
|
||||
THROW_ERROR("msg recv mismatch");
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
int server_recv(int client_fd) {
|
||||
const int buf_size = 32;
|
||||
char buf[buf_size];
|
||||
|
||||
if (recv(client_fd, buf, buf_size, 0) <= 0)
|
||||
THROW_ERROR("msg recv failed");
|
||||
|
||||
if (strncmp(buf, ECHO_MSG, sizeof(ECHO_MSG)) != 0) {
|
||||
THROW_ERROR("msg recv mismatch");
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
int server_recvmsg(int client_fd) {
|
||||
int ret = 0;
|
||||
const int buf_size = 1000;
|
||||
char buf[buf_size];
|
||||
struct msghdr msg;
|
||||
struct iovec iov[1];
|
||||
|
||||
msg.msg_name = NULL;
|
||||
msg.msg_namelen = 0;
|
||||
iov[0].iov_base = buf;
|
||||
iov[0].iov_len = buf_size;
|
||||
msg.msg_iov = iov;
|
||||
msg.msg_iovlen = 1;
|
||||
msg.msg_control = 0;
|
||||
msg.msg_controllen = 0;
|
||||
msg.msg_flags = 0;
|
||||
|
||||
ret = recvmsg(client_fd, &msg, 0);
|
||||
if (ret <= 0) {
|
||||
THROW_ERROR("recvmsg failed");
|
||||
} else {
|
||||
if (strncmp(buf, ECHO_MSG, sizeof(ECHO_MSG)) != 0) {
|
||||
printf("recvmsg : %d, msg: %s\n", ret, buf);
|
||||
THROW_ERROR("msg recvmsg mismatch");
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
int server_connectionless_recvmsg() {
|
||||
int ret = 0;
|
||||
const int buf_size = 1000;
|
||||
char buf[buf_size];
|
||||
struct msghdr msg;
|
||||
struct iovec iov[1];
|
||||
|
||||
struct sockaddr_in servaddr;
|
||||
struct sockaddr_in clientaddr;
|
||||
memset(&servaddr, 0, sizeof(servaddr));
|
||||
memset(&clientaddr, 0, sizeof(clientaddr));
|
||||
|
||||
int sock = socket(AF_INET, SOCK_DGRAM, 0);
|
||||
if (sock < 0)
|
||||
THROW_ERROR("create socket error");
|
||||
|
||||
servaddr.sin_family = AF_INET;
|
||||
servaddr.sin_addr.s_addr = htonl(INADDR_ANY);
|
||||
servaddr.sin_port = htons(9900);
|
||||
ret = bind(sock, (struct sockaddr *) &servaddr, sizeof(servaddr));
|
||||
if (ret < 0) {
|
||||
close(sock);
|
||||
THROW_ERROR("bind socket failed");
|
||||
}
|
||||
|
||||
msg.msg_name = &clientaddr;
|
||||
msg.msg_namelen = sizeof(clientaddr);
|
||||
iov[0].iov_base = buf;
|
||||
iov[0].iov_len = buf_size;
|
||||
msg.msg_iov = iov;
|
||||
msg.msg_iovlen = 1;
|
||||
msg.msg_control = 0;
|
||||
msg.msg_controllen = 0;
|
||||
msg.msg_flags = 0;
|
||||
|
||||
ret = recvmsg(sock, &msg, 0);
|
||||
if (ret <= 0) {
|
||||
THROW_ERROR("recvmsg failed");
|
||||
} else {
|
||||
if (strncmp(buf, DEFAULT_MSG, sizeof(DEFAULT_MSG)) != 0) {
|
||||
printf("recvmsg : %d, msg: %s\n", ret, buf);
|
||||
THROW_ERROR("msg recvmsg mismatch");
|
||||
} else {
|
||||
inet_ntop(AF_INET, &clientaddr.sin_addr,
|
||||
buf, sizeof(buf));
|
||||
if(strcmp(buf, "127.0.0.1") !=0) {
|
||||
printf("from port %d and address %s\n", ntohs(clientaddr.sin_port), buf);
|
||||
THROW_ERROR("client addr mismatch");
|
||||
}
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
int wait_for_child_exit(int child_pid) {
|
||||
int status = 0;
|
||||
if (wait4(child_pid, &status, 0, NULL) < 0) {
|
||||
THROW_ERROR("failed to wait4 the child process");
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
int test_read_write() {
|
||||
int ret = 0;
|
||||
int child_pid = 0;
|
||||
int client_fd = connect_with_child(8800, &child_pid);
|
||||
if (client_fd < 0)
|
||||
THROW_ERROR("connect failed");
|
||||
else
|
||||
ret = neogotiate_msg(client_fd);
|
||||
|
||||
//wait for the child to exit for next spawn
|
||||
int status = 0;
|
||||
if (wait4(child_pid, &status, 0, NULL) < 0) {
|
||||
THROW_ERROR("failed to wait4 the child process");
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
int test_send_recv() {
|
||||
int ret = 0;
|
||||
int child_pid = 0;
|
||||
int client_fd = connect_with_child(8801, &child_pid);
|
||||
if (client_fd < 0)
|
||||
THROW_ERROR("connect failed");
|
||||
|
||||
if (neogotiate_msg(client_fd) < 0)
|
||||
THROW_ERROR("neogotiate failed");
|
||||
|
||||
ret = server_recv(client_fd);
|
||||
if (ret < 0) return -1;
|
||||
|
||||
ret = wait_for_child_exit(child_pid);
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
int test_sendmsg_recvmsg() {
|
||||
int ret = 0;
|
||||
int child_pid = 0;
|
||||
int client_fd = connect_with_child(8802, &child_pid);
|
||||
if (client_fd < 0)
|
||||
THROW_ERROR("connect failed");
|
||||
|
||||
if (neogotiate_msg(client_fd) < 0)
|
||||
THROW_ERROR("neogotiate failed");
|
||||
|
||||
ret = server_recvmsg(client_fd);
|
||||
if (ret < 0) return -1;
|
||||
|
||||
ret = wait_for_child_exit(child_pid);
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
int test_sendmsg_recvmsg_connectionless() {
|
||||
int ret = 0;
|
||||
int child_pid = 0;
|
||||
|
||||
char* client_argv[] = {"client", "NULL", "8803", NULL};
|
||||
ret = posix_spawn(&child_pid, "/bin/client", NULL, NULL, client_argv, NULL);
|
||||
if (ret < 0) {
|
||||
THROW_ERROR("spawn client process error");
|
||||
}
|
||||
|
||||
ret = server_connectionless_recvmsg();
|
||||
if (ret < 0) return -1;
|
||||
|
||||
ret = wait_for_child_exit(child_pid);
|
||||
|
||||
return ret;
|
||||
}
|
||||
static test_case_t test_cases[] = {
|
||||
TEST_CASE(test_read_write),
|
||||
TEST_CASE(test_send_recv),
|
||||
TEST_CASE(test_sendmsg_recvmsg),
|
||||
TEST_CASE(test_sendmsg_recvmsg_connectionless),
|
||||
};
|
||||
|
||||
int main(int argc, const char* argv[]) {
|
||||
return test_suite_run(test_cases, ARRAY_SIZE(test_cases));
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user