[libos] Implement async network framework based on IO_Uring

This commit is contained in:
ClawSeven 2024-04-30 14:54:23 +08:00 committed by volcano
parent 9d4dcc2b21
commit f8be7e7454
33 changed files with 4958 additions and 27 deletions

3
.gitmodules vendored

@ -24,3 +24,6 @@
[submodule "deps/resolv-conf"]
path = deps/resolv-conf
url = https://github.com/tailhook/resolv-conf.git
[submodule "deps/io-uring"]
path = deps/io-uring
url = https://github.com/occlum/io-uring.git

@ -42,6 +42,7 @@ submodule: githooks init-submodule
@cp deps/sefs/sefs-cli/lib/libsefs-cli_sim.so build/lib @cp deps/sefs/sefs-cli/lib/libsefs-cli_sim.so build/lib
@cp deps/sefs/sefs-cli/lib/libsefs-cli.signed.so build/lib @cp deps/sefs/sefs-cli/lib/libsefs-cli.signed.so build/lib
@cp deps/sefs/sefs-cli/enclave/Enclave.config.xml build/sefs-cli.Enclave.xml @cp deps/sefs/sefs-cli/enclave/Enclave.config.xml build/sefs-cli.Enclave.xml
@cd deps/io-uring/ocalls && cargo clean && cargo build --release
else else
submodule: githooks init-submodule submodule: githooks init-submodule
@rm -rf build @rm -rf build
@ -60,6 +61,7 @@ submodule: githooks init-submodule
@cp deps/sefs/sefs-cli/lib/libsefs-cli_sim.so build/lib @cp deps/sefs/sefs-cli/lib/libsefs-cli_sim.so build/lib
@cp deps/sefs/sefs-cli/lib/libsefs-cli.signed.so build/lib @cp deps/sefs/sefs-cli/lib/libsefs-cli.signed.so build/lib
@cp deps/sefs/sefs-cli/enclave/Enclave.config.xml build/sefs-cli.Enclave.xml @cp deps/sefs/sefs-cli/enclave/Enclave.config.xml build/sefs-cli.Enclave.xml
@cd deps/io-uring/ocalls && cargo clean && cargo build --release
endif endif
init-submodule: init-submodule:

1
deps/io-uring vendored Submodule

@ -0,0 +1 @@
Subproject commit c654c4925bb0b013d3eec736015f8ac4888722be

@ -7,6 +7,8 @@ enclave {
from "sgx_net.edl" import *;
from "sgx_occlum_utils.edl" import *;
from "sgx_vdso_time_ocalls.edl" import *;
from "sgx_thread.edl" import *;
from "sgx_io_uring_ocalls.edl" import *;
include "sgx_quote.h"
include "occlum_edl_types.h"

186
src/libos/Cargo.lock generated

@ -10,16 +10,21 @@ dependencies = [
"atomic", "atomic",
"bitflags", "bitflags",
"bitvec 1.0.1", "bitvec 1.0.1",
"byteorder",
"ctor", "ctor",
"derive_builder", "derive_builder",
"downcast-rs",
"errno", "errno",
"goblin", "goblin",
"intrusive-collections", "intrusive-collections",
"io-uring-callback",
"itertools", "itertools",
"keyable-arc",
"lazy_static", "lazy_static",
"log", "log",
"memoffset 0.6.5", "memoffset 0.6.5",
"modular-bitfield", "modular-bitfield",
"num_enum",
"rcore-fs", "rcore-fs",
"rcore-fs-devfs", "rcore-fs-devfs",
"rcore-fs-mountfs", "rcore-fs-mountfs",
@ -32,6 +37,7 @@ dependencies = [
"scroll", "scroll",
"serde", "serde",
"serde_json", "serde_json",
"sgx-untrusted-alloc",
"sgx_cov", "sgx_cov",
"sgx_tcrypto", "sgx_tcrypto",
"sgx_trts", "sgx_trts",
@ -112,6 +118,12 @@ dependencies = [
"wyz", "wyz",
] ]
[[package]]
name = "byteorder"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
[[package]] [[package]]
name = "cc" name = "cc"
version = "1.0.73" version = "1.0.73"
@ -203,6 +215,12 @@ dependencies = [
"syn", "syn",
] ]
[[package]]
name = "downcast-rs"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ea835d29036a4087793836fa931b08837ad5e957da9e23886b29586fb9b6650"
[[package]] [[package]]
name = "either" name = "either"
version = "1.8.0" version = "1.8.0"
@ -237,6 +255,67 @@ version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c"
[[package]]
name = "futures"
version = "0.3.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "23342abe12aba583913b2e62f22225ff9c950774065e4bfb61a19cd9770fec40"
dependencies = [
"futures-channel",
"futures-core",
"futures-io",
"futures-sink",
"futures-task",
"futures-util",
]
[[package]]
name = "futures-channel"
version = "0.3.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "955518d47e09b25bbebc7a18df10b81f0c766eaf4c4f1cccef2fca5f2a4fb5f2"
dependencies = [
"futures-core",
"futures-sink",
]
[[package]]
name = "futures-core"
version = "0.3.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4bca583b7e26f571124fe5b7561d49cb2868d79116cfa0eefce955557c6fee8c"
[[package]]
name = "futures-io"
version = "0.3.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4fff74096e71ed47f8e023204cfd0aa1289cd54ae5430a9523be060cdb849964"
[[package]]
name = "futures-sink"
version = "0.3.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f43be4fe21a13b9781a69afa4985b0f6ee0e1afab2c6f454a8cf30e2b2237b6e"
[[package]]
name = "futures-task"
version = "0.3.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "76d3d132be6c0e6aa1534069c705a74a5997a356c0dc2f86a47765e5617c5b65"
[[package]]
name = "futures-util"
version = "0.3.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26b01e40b772d54cf6c6d721c1d1abd0647a0106a12ecaa1c186273392a69533"
dependencies = [
"futures-core",
"futures-sink",
"futures-task",
"pin-project-lite",
"pin-utils",
]
[[package]] [[package]]
name = "goblin" name = "goblin"
version = "0.5.4" version = "0.5.4"
@ -267,6 +346,36 @@ dependencies = [
"memoffset 0.5.6", "memoffset 0.5.6",
] ]
[[package]]
name = "io-uring"
version = "0.5.9"
dependencies = [
"bitflags",
"libc",
"sgx_libc",
"sgx_trts",
"sgx_tstd",
"sgx_types",
]
[[package]]
name = "io-uring-callback"
version = "0.1.0"
dependencies = [
"atomic",
"cfg-if",
"futures",
"io-uring",
"lazy_static",
"libc",
"lock_api",
"log",
"sgx_libc",
"sgx_tstd",
"slab",
"spin 0.7.1",
]
[[package]] [[package]]
name = "itertools" name = "itertools"
version = "0.10.3" version = "0.10.3"
@ -283,6 +392,10 @@ dependencies = [
"sgx_tstd", "sgx_tstd",
] ]
[[package]]
name = "keyable-arc"
version = "0.1.0"
[[package]] [[package]]
name = "lazy_static" name = "lazy_static"
version = "1.4.0" version = "1.4.0"
@ -298,6 +411,15 @@ version = "0.2.132"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8371e4e5341c3a96db127eb2465ac681ced4c433e01dd0e938adbef26ba93ba5" checksum = "8371e4e5341c3a96db127eb2465ac681ced4c433e01dd0e938adbef26ba93ba5"
[[package]]
name = "lock_api"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dd96ffd135b2fd7b973ac026d28085defbe8983df057ced3eb4f2130b0831312"
dependencies = [
"scopeguard",
]
[[package]] [[package]]
name = "log" name = "log"
version = "0.4.17" version = "0.4.17"
@ -346,6 +468,38 @@ dependencies = [
"syn", "syn",
] ]
[[package]]
name = "num_enum"
version = "0.5.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1f646caf906c20226733ed5b1374287eb97e3c2a5c227ce668c1f2ce20ae57c9"
dependencies = [
"num_enum_derive",
]
[[package]]
name = "num_enum_derive"
version = "0.5.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dcbff9bc912032c62bf65ef1d5aea88983b420f4f839db1e9b0c281a25c9c799"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "pin-project-lite"
version = "0.2.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58"
[[package]]
name = "pin-utils"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"
[[package]] [[package]]
name = "plain" name = "plain"
version = "0.2.3" version = "0.2.3"
@ -601,6 +755,12 @@ version = "1.0.11"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4501abdff3ae82a1c1b477a17252eb69cee9e66eb915c1abaa4f44d873df9f09" checksum = "4501abdff3ae82a1c1b477a17252eb69cee9e66eb915c1abaa4f44d873df9f09"
[[package]]
name = "scopeguard"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]] [[package]]
name = "scroll" name = "scroll"
version = "0.11.0" version = "0.11.0"
@ -648,6 +808,23 @@ dependencies = [
"sgx_tstd", "sgx_tstd",
] ]
[[package]]
name = "sgx-untrusted-alloc"
version = "0.1.0"
dependencies = [
"cfg-if",
"errno",
"intrusive-collections",
"lazy_static",
"libc",
"log",
"sgx_libc",
"sgx_trts",
"sgx_tstd",
"sgx_types",
"spin 0.7.1",
]
[[package]] [[package]]
name = "sgx_alloc" name = "sgx_alloc"
version = "1.1.6" version = "1.1.6"
@ -753,6 +930,15 @@ dependencies = [
"sgx_build_helper", "sgx_build_helper",
] ]
[[package]]
name = "slab"
version = "0.4.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f92a496fb766b417c996b9c5e57daf2f7ad3b0bebe1ccfca4856390e3d3bb67"
dependencies = [
"autocfg 1.1.0",
]
[[package]] [[package]]
name = "spin" name = "spin"
version = "0.5.2" version = "0.5.2"

@ -7,6 +7,7 @@ use super::*;
use crate::exception::*; use crate::exception::*;
use crate::fs::HostStdioFds; use crate::fs::HostStdioFds;
use crate::interrupt; use crate::interrupt;
use crate::io_uring::ENABLE_URING;
use crate::process::idle_reap_zombie_children; use crate::process::idle_reap_zombie_children;
use crate::process::{ProcessFilter, SpawnAttr}; use crate::process::{ProcessFilter, SpawnAttr};
use crate::signal::SigNum; use crate::signal::SigNum;
@ -101,11 +102,14 @@ pub extern "C" fn occlum_ecall_init(
vm::init_user_space(); vm::init_user_space();
if ENABLE_URING.load(Ordering::Relaxed) {
crate::io_uring::MULTITON.poll_completions();
}
// Register exception handlers (support cpuid & rdtsc for now) // Register exception handlers (support cpuid & rdtsc for now)
register_exception_handlers(); register_exception_handlers();
HAS_INIT.store(true, Ordering::Release); HAS_INIT.store(true, Ordering::Release);
// Enable global backtrace // Enable global backtrace
unsafe { backtrace::enable_backtrace(&ENCLAVE_PATH, PrintFormat::Short) }; unsafe { backtrace::enable_backtrace(&ENCLAVE_PATH, PrintFormat::Short) };

169
src/libos/src/io_uring.rs Normal file

@ -0,0 +1,169 @@
use core::sync::atomic::{AtomicBool, AtomicU32, AtomicUsize};
use std::{collections::HashMap, thread::current};
use crate::util::sync::Mutex;
use alloc::{sync::Arc, vec::Vec};
use atomic::Ordering;
use io_uring_callback::{Builder, IoUring};
use keyable_arc::KeyableArc;
use crate::config::LIBOS_CONFIG;
// The number of sockets to reach the network bandwidth threshold of one io_uring instance
const SOCKET_THRESHOLD_PER_URING: u32 = 1;

lazy_static::lazy_static! {
    /// The global set of io_uring instances shared by all sockets.
    pub static ref MULTITON: UringSet = {
        let uring_set = UringSet::new();
        uring_set
    };

    /// Whether io_uring is enabled at all, derived from the LibOS config
    /// (`feature.io_uring > 0`).
    pub static ref ENABLE_URING: AtomicBool = AtomicBool::new(LIBOS_CONFIG.feature.io_uring > 0);

    // Four uring instances are sufficient to reach the network bandwidth threshold of host kernel.
    /// Upper bound on the number of io_uring instances, taken from the config
    /// value `feature.io_uring` and capped at 16 (asserted below).
    pub static ref URING_LIMIT: AtomicUsize = {
        let uring_limit = LIBOS_CONFIG.feature.io_uring;
        assert!(uring_limit <= 16, "io_uring limit must not exceed 16");
        AtomicUsize::new(uring_limit as usize)
    };
}
/// Per-instance bookkeeping for one io_uring: how many sockets are attached
/// and whether its completion-polling thread has been started.
#[derive(Clone, Copy, Default)]
struct UringState {
    // Number of sockets currently registered to this io_uring instance.
    registered_num: u32,
    is_enable_poll: bool, // CQE polling thread
}

impl UringState {
    /// Accounts one more socket attached to this instance.
    fn register_one_socket(&mut self) {
        self.registered_num += 1;
    }

    /// Accounts one socket detached from this instance.
    // NOTE(review): underflows (panics in debug builds) if called more times
    // than `register_one_socket`; callers must keep these balanced.
    fn unregister_one_socket(&mut self) {
        self.registered_num -= 1;
    }

    /// Spawns the CQE-polling thread for `uring`, at most once per state.
    fn enable_poll(&mut self, uring: Arc<IoUring>) {
        if !self.is_enable_poll {
            self.is_enable_poll = true;
            // Detached thread that polls completions forever; it is never
            // joined and keeps its own Arc to the uring alive.
            std::thread::spawn(move || loop {
                let min_complete = 1;
                let polling_retries = 10000;
                uring.poll_completions(min_complete, polling_retries);
            });
        }
    }
}
/// A set of io_uring instances with per-instance socket counts, used to
/// load-balance sockets across urings.
pub struct UringSet {
    // Map from each io_uring instance to its bookkeeping state.
    urings: Mutex<HashMap<KeyableArc<IoUring>, UringState>>,
    // Number of io_uring instances built so far (<= URING_LIMIT).
    running_uring_num: AtomicU32,
}

impl UringSet {
    /// Creates an empty set with no io_uring instances.
    pub fn new() -> Self {
        let urings = Mutex::new(HashMap::new());
        let running_uring_num = AtomicU32::new(0);
        Self {
            urings,
            running_uring_num,
        }
    }

    /// Eagerly builds `URING_LIMIT` io_uring instances (SQPOLL mode, 256
    /// entries each) and starts a completion-polling thread for each.
    ///
    /// NOTE(review): despite its name, this method does not poll by itself;
    /// polling happens on the threads spawned via `UringState::enable_poll`.
    pub fn poll_completions(&self) {
        let mut guard = self.urings.lock();
        let uring_limit = URING_LIMIT.load(Ordering::Relaxed) as u32;
        for _ in 0..uring_limit {
            // SQPOLL with 500 ms kernel-thread idle; 256 submission entries.
            let uring: KeyableArc<IoUring> = Arc::new(
                Builder::new()
                    .setup_sqpoll(500 /* ms */)
                    .build(256)
                    .unwrap(),
            )
            .into();
            let mut state = UringState::default();
            state.enable_poll(uring.clone().into());
            guard.insert(uring.clone(), state);
            self.running_uring_num.fetch_add(1, Ordering::Relaxed);
        }
    }

    /// Picks an io_uring instance for a new socket, registers the socket on
    /// it, and returns the instance.
    ///
    /// While fewer than `URING_LIMIT` instances exist, new instances are
    /// built on demand once the registered-socket count warrants it;
    /// afterwards the least-loaded existing instance is chosen.
    pub fn get_uring(&self) -> Arc<IoUring> {
        let mut map = self.urings.lock();
        let running_uring_num = self.running_uring_num.load(Ordering::Relaxed);
        let uring_limit = URING_LIMIT.load(Ordering::Relaxed) as u32;
        assert!(running_uring_num <= uring_limit);
        let init_stage = running_uring_num < uring_limit;

        // Construct an io_uring instance and initiate a polling thread
        if init_stage {
            let should_build_uring = {
                // Sum registered socket
                let total_socket_num = map
                    .values()
                    .fold(0, |acc, state| acc + state.registered_num)
                    + 1;
                // Determine the number of available io_uring
                let uring_num = (total_socket_num / SOCKET_THRESHOLD_PER_URING) + 1;
                let existed_uring_num = self.running_uring_num.load(Ordering::Relaxed);
                assert!(existed_uring_num <= uring_num);
                existed_uring_num < uring_num
            };

            if should_build_uring {
                let uring: KeyableArc<IoUring> = Arc::new(
                    Builder::new()
                        .setup_sqpoll(500 /* ms */)
                        .build(256)
                        .unwrap(),
                )
                .into();
                let mut state = UringState::default();
                // The calling socket counts as the first registration.
                state.register_one_socket();
                state.enable_poll(uring.clone().into());
                map.insert(uring.clone(), state);
                self.running_uring_num.fetch_add(1, Ordering::Relaxed);
                return uring.into();
            }
        }

        // Link the file to the io_uring instance with the least load.
        // NOTE(review): panics if the map is empty; callers are expected to
        // reach here only when at least one instance exists.
        let (mut uring, mut state) = map
            .iter_mut()
            .min_by_key(|(_, &mut state)| state.registered_num)
            .unwrap();

        // Re-select io_uring instance with least task load
        if !init_stage {
            let min_registered_num = state.registered_num;
            // Among instances tied on socket count, prefer the one with the
            // smallest in-flight task load.
            (uring, state) = map
                .iter_mut()
                .filter(|(_, state)| state.registered_num == min_registered_num)
                .min_by_key(|(uring, _)| uring.task_load())
                .unwrap();
        } else {
            // At the initial stage, without constructing additional io_uring instances,
            // there exists a singular io_uring which has the minimum number of registered sockets.
        }

        // Update io_uring instance states
        state.register_one_socket();
        assert!(state.is_enable_poll);
        uring.clone().into()
    }

    /// Unregisters a socket from `uring` and detaches `fd` from it.
    pub fn disattach_uring(&self, fd: usize, uring: Arc<IoUring>) {
        let uring: KeyableArc<IoUring> = uring.into();
        let mut map = self.urings.lock();
        let mut state = map.get_mut(&uring).unwrap();
        state.unregister_one_socket();
        // Release the map lock before the (potentially slow) detach call.
        drop(map);
        uring.disattach_fd(fd);
    }
}

@ -28,6 +28,8 @@
#![feature(is_some_and)] #![feature(is_some_and)]
// for edmm_api macro // for edmm_api macro
#![feature(linkage)] #![feature(linkage)]
#![feature(new_uninit)]
#![feature(raw_ref_op)]
#[macro_use] #[macro_use]
extern crate alloc; extern crate alloc;
@ -66,7 +68,6 @@ extern crate intrusive_collections;
extern crate itertools; extern crate itertools;
extern crate modular_bitfield; extern crate modular_bitfield;
extern crate resolv_conf; extern crate resolv_conf;
extern crate vdso_time;
use sgx_trts::libc; use sgx_trts::libc;
use sgx_types::*; use sgx_types::*;
@ -82,15 +83,18 @@ mod prelude;
#[macro_use] #[macro_use]
mod error; mod error;
#[macro_use]
mod net;
mod config; mod config;
mod entry; mod entry;
mod events; mod events;
mod exception; mod exception;
mod fs; mod fs;
mod interrupt; mod interrupt;
mod io_uring;
mod ipc; mod ipc;
mod misc; mod misc;
mod net;
mod process; mod process;
mod sched; mod sched;
mod signal; mod signal;

@ -7,12 +7,13 @@ pub use self::io_multiplexing::{
PollEventFlags, PollFd, THREAD_NOTIFIERS, PollEventFlags, PollFd, THREAD_NOTIFIERS,
}; };
pub use self::socket::{ pub use self::socket::{
mmsghdr, msghdr, msghdr_mut, socketpair, unix_socket, AddressFamily, AsUnixSocket, FileFlags, socketpair, unix_socket, AsUnixSocket, Domain, HostSocket, HostSocketType, Iovs, IovsMut,
HostSocket, HostSocketType, HowToShut, Iovs, IovsMut, MsgHdr, MsgHdrFlags, MsgHdrMut, RawAddr, SliceAsLibcIovec, UnixAddr,
RecvFlags, SendFlags, SliceAsLibcIovec, SockAddr, SocketType, UnixAddr,
}; };
pub use self::syscalls::*; pub use self::syscalls::*;
mod io_multiplexing; mod io_multiplexing;
mod socket; pub(crate) mod socket;
mod syscalls; mod syscalls;
pub use self::syscalls::*;

@ -1,21 +1,15 @@
use super::*; use super::*;
mod address_family;
mod flags;
mod host; mod host;
mod iovs; pub(crate) mod sockopt;
mod msg;
mod shutdown;
mod socket_address;
mod socket_type;
mod unix; mod unix;
pub(crate) mod uring;
pub(crate) mod util;
pub use self::address_family::AddressFamily;
pub use self::flags::{FileFlags, MsgHdrFlags, RecvFlags, SendFlags};
pub use self::host::{HostSocket, HostSocketType}; pub use self::host::{HostSocket, HostSocketType};
pub use self::iovs::{Iovs, IovsMut, SliceAsLibcIovec}; pub use self::unix::{socketpair, unix_socket, AsUnixSocket};
pub use self::msg::{mmsghdr, msghdr, msghdr_mut, CMessages, CmsgData, MsgHdr, MsgHdrMut}; pub use self::util::{
pub use self::shutdown::HowToShut; Addr, AnyAddr, CMessages, CSockAddr, CmsgData, Domain, Iovs, IovsMut, Ipv4Addr, Ipv4SocketAddr,
pub use self::socket_address::SockAddr; Ipv6SocketAddr, MsgFlags, RawAddr, RecvFlags, SendFlags, Shutdown, SliceAsLibcIovec,
pub use self::socket_type::SocketType; SocketProtocol, Type, UnixAddr,
pub use self::unix::{socketpair, unix_socket, AsUnixSocket, UnixAddr}; };

@ -0,0 +1,241 @@
use core::time::Duration;
use std::marker::PhantomData;
use std::sync::atomic::{AtomicBool, Ordering};
use super::Timeout;
use io_uring_callback::IoUring;
use libc::ocall::getsockname as do_getsockname;
use libc::ocall::shutdown as do_shutdown;
use libc::ocall::socket as do_socket;
use libc::ocall::socketpair as do_socketpair;
use crate::events::Pollee;
use crate::fs::{IoEvents, IoNotifier};
use crate::net::socket::uring::runtime::Runtime;
use crate::prelude::*;
/// The common parts of all stream sockets.
pub struct Common<A: Addr + 'static, R: Runtime> {
    // The underlying host OS file descriptor.
    host_fd: FileDesc,
    // Socket type (e.g. stream or datagram), fixed at creation.
    type_: Type,
    // Whether I/O on this socket is non-blocking.
    nonblocking: AtomicBool,
    // Set once `set_closed` is called; reset by `reset_closed`.
    is_closed: AtomicBool,
    // I/O readiness events and their observers.
    pollee: Pollee,
    // Mutable bound/peer addresses, behind a lock.
    inner: Mutex<Inner<A>>,
    // Send/receive timeouts.
    timeout: Mutex<Timeout>,
    // Pending error for `getsockopt(SO_ERROR)`; taken (cleared) on read.
    errno: Mutex<Option<Errno>>,
    // The io_uring instance this socket submits requests to.
    io_uring: Arc<IoUring>,
    phantom_data: PhantomData<(A, R)>,
}
impl<A: Addr + 'static, R: Runtime> Common<A, R> {
    /// Creates a new host socket via the `socket` OCALL and fresh LibOS-side
    /// state, selecting an io_uring instance through `R::io_uring()`.
    pub fn new(type_: Type, nonblocking: bool, protocol: Option<i32>) -> Result<Self> {
        let domain_c = A::domain() as libc::c_int;
        let type_c = type_ as libc::c_int;
        // `None` maps to protocol 0, letting the host pick the default.
        let protocol = protocol.unwrap_or(0) as libc::c_int;
        let host_fd = try_libc!(do_socket(domain_c, type_c, protocol)) as FileDesc;
        let nonblocking = AtomicBool::new(nonblocking);
        let is_closed = AtomicBool::new(false);
        let pollee = Pollee::new(IoEvents::empty());
        let inner = Mutex::new(Inner::new());
        let timeout = Mutex::new(Timeout::new());
        let io_uring = R::io_uring();
        let errno = Mutex::new(None);
        Ok(Self {
            host_fd,
            type_,
            nonblocking,
            is_closed,
            pollee,
            inner,
            timeout,
            errno,
            io_uring,
            phantom_data: PhantomData,
        })
    }

    /// Socketpair-style constructor; currently a stub that always fails.
    pub fn new_pair(sock_type: Type, nonblocking: bool) -> Result<(Self, Self)> {
        return_errno!(EINVAL, "Unix is unsupported");
    }

    /// Wraps an already-open host fd (e.g. one returned by a host `accept`)
    /// in fresh LibOS-side state.
    pub fn with_host_fd(host_fd: FileDesc, type_: Type, nonblocking: bool) -> Self {
        let nonblocking = AtomicBool::new(nonblocking);
        let is_closed = AtomicBool::new(false);
        let pollee = Pollee::new(IoEvents::empty());
        let inner = Mutex::new(Inner::new());
        let timeout = Mutex::new(Timeout::new());
        let io_uring = R::io_uring();
        let errno = Mutex::new(None);
        Self {
            host_fd,
            type_,
            nonblocking,
            is_closed,
            pollee,
            inner,
            timeout,
            errno,
            io_uring,
            phantom_data: PhantomData,
        }
    }

    /// Returns (a clone of) the io_uring instance assigned to this socket.
    pub fn io_uring(&self) -> Arc<IoUring> {
        self.io_uring.clone()
    }

    /// Returns the underlying host file descriptor.
    pub fn host_fd(&self) -> FileDesc {
        self.host_fd
    }

    /// Returns the socket type fixed at creation.
    pub fn type_(&self) -> Type {
        self.type_
    }

    /// Whether the socket is in non-blocking mode.
    pub fn nonblocking(&self) -> bool {
        self.nonblocking.load(Ordering::Relaxed)
    }

    /// Switches the socket between blocking and non-blocking mode.
    pub fn set_nonblocking(&self, is_nonblocking: bool) {
        self.nonblocking.store(is_nonblocking, Ordering::Relaxed)
    }

    /// Returns the notifier used to observe I/O readiness events.
    pub fn notifier(&self) -> &IoNotifier {
        self.pollee.notifier()
    }

    /// Returns the send timeout, if one is configured.
    pub fn send_timeout(&self) -> Option<Duration> {
        self.timeout.lock().sender_timeout()
    }

    /// Returns the receive timeout, if one is configured.
    pub fn recv_timeout(&self) -> Option<Duration> {
        self.timeout.lock().receiver_timeout()
    }

    /// Sets the send timeout.
    pub fn set_send_timeout(&self, timeout: Duration) {
        self.timeout.lock().set_sender(timeout)
    }

    /// Sets the receive timeout.
    pub fn set_recv_timeout(&self, timeout: Duration) {
        self.timeout.lock().set_receiver(timeout)
    }

    /// Whether `set_closed` has been called (and not reset).
    pub fn is_closed(&self) -> bool {
        self.is_closed.load(Ordering::Relaxed)
    }

    /// Marks the socket as closed.
    pub fn set_closed(&self) {
        self.is_closed.store(true, Ordering::Relaxed)
    }

    /// Clears the closed flag.
    pub fn reset_closed(&self) {
        self.is_closed.store(false, Ordering::Relaxed)
    }

    /// Returns the pollee that tracks this socket's I/O events.
    pub fn pollee(&self) -> &Pollee {
        &self.pollee
    }

    /// Returns the cached local address, if one has been recorded.
    #[allow(unused)]
    pub fn addr(&self) -> Option<A> {
        let inner = self.inner.lock();
        inner.addr.clone()
    }

    /// Records the local address.
    pub fn set_addr(&self, addr: &A) {
        let mut inner = self.inner.lock();
        inner.addr = Some(addr.clone())
    }

    /// Queries the host for the socket's actual local address via the
    /// `getsockname` OCALL (authoritative, e.g. after binding port 0).
    pub fn get_addr_from_host(&self) -> Result<A> {
        let mut c_addr: libc::sockaddr_storage = unsafe { std::mem::zeroed() };
        let mut c_addr_len = std::mem::size_of::<libc::sockaddr_storage>() as u32;
        try_libc!(do_getsockname(
            self.host_fd as _,
            &mut c_addr as *mut libc::sockaddr_storage as *mut _,
            &mut c_addr_len as *mut _,
        ));
        A::from_c_storage(&c_addr, c_addr_len as _)
    }

    /// Returns the cached peer address, if one has been recorded.
    pub fn peer_addr(&self) -> Option<A> {
        let inner = self.inner.lock();
        inner.peer_addr.clone()
    }

    /// Records the peer address.
    pub fn set_peer_addr(&self, peer_addr: &A) {
        let mut inner = self.inner.lock();
        inner.peer_addr = Some(peer_addr.clone());
    }

    /// Clears the recorded peer address.
    pub fn reset_peer_addr(&self) {
        let mut inner = self.inner.lock();
        inner.peer_addr = None;
    }

    // For getsockopt SO_ERROR command
    /// Takes (and clears) the pending socket error, if any.
    pub fn errno(&self) -> Option<Errno> {
        let mut errno_option = self.errno.lock();
        errno_option.take()
    }

    /// Records a pending socket error for a later `SO_ERROR` query.
    pub fn set_errno(&self, errno: Errno) {
        let mut errno_option = self.errno.lock();
        *errno_option = Some(errno);
    }

    /// Shuts down the host socket in the given direction(s) via the
    /// `shutdown` OCALL.
    pub fn host_shutdown(&self, how: Shutdown) -> Result<()> {
        trace!("host shutdown: {:?}", how);
        match how {
            Shutdown::Write => {
                try_libc!(do_shutdown(self.host_fd as _, libc::SHUT_WR));
            }
            Shutdown::Read => {
                try_libc!(do_shutdown(self.host_fd as _, libc::SHUT_RD));
            }
            Shutdown::Both => {
                try_libc!(do_shutdown(self.host_fd as _, libc::SHUT_RDWR));
            }
        }
        Ok(())
    }
}
impl<A: Addr + 'static, R: Runtime> std::fmt::Debug for Common<A, R> {
    // Note: this locks `inner` while formatting, so do not call it while the
    // `inner` lock is already held.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Common")
            .field("host_fd", &self.host_fd)
            .field("type", &self.type_)
            .field("nonblocking", &self.nonblocking)
            .field("pollee", &self.pollee)
            .field("inner", &self.inner.lock())
            .finish()
    }
}
impl<A: Addr + 'static, R: Runtime> Drop for Common<A, R> {
    fn drop(&mut self) {
        // Best-effort close of the host fd; failure can only be logged since
        // `drop` cannot propagate errors.
        if let Err(e) = super::do_close(self.host_fd) {
            log::error!("do_close failed, host_fd: {}, err: {:?}", self.host_fd, e);
        }
        // Unregister this socket from its io_uring instance.
        R::disattach_io_uring(self.host_fd as usize, self.io_uring())
    }
}
/// Lock-protected mutable address state of a `Common` socket.
#[derive(Debug)]
struct Inner<A: Addr + 'static> {
    // Local (bound) address, if known.
    addr: Option<A>,
    // Connected peer address, if any.
    peer_addr: Option<A>,
}

impl<A: Addr + 'static> Inner<A> {
    /// Creates state with no addresses recorded yet.
    pub fn new() -> Self {
        Self {
            addr: None,
            peer_addr: None,
        }
    }
}

@ -0,0 +1,7 @@
mod common;
mod operation;
mod timeout;
pub use self::common::Common;
pub use self::operation::{do_bind, do_close, do_connect, do_unlink};
pub use self::timeout::Timeout;

@ -0,0 +1,44 @@
use std::ffi::CString;
use std::mem::{self, MaybeUninit};
use crate::prelude::*;
/// Binds the host socket `host_fd` to `addr` via the `bind` OCALL.
pub fn do_bind<A: Addr>(host_fd: FileDesc, addr: &A) -> Result<()> {
    // Convert the LibOS address into a C sockaddr_storage plus length.
    let (storage, storage_len) = addr.to_c_storage();
    try_libc!(libc::ocall::bind(
        host_fd as i32,
        &storage as *const _ as _,
        storage_len as u32
    ));
    Ok(())
}
/// Closes the host file descriptor via the `close` OCALL.
pub fn do_close(host_fd: FileDesc) -> Result<()> {
    try_libc!(libc::ocall::close(host_fd as i32));
    Ok(())
}
/// Removes `path` on the host via the `unlink` OCALL.
///
/// Fails with `EINVAL` if the path contains an interior NUL byte (such a
/// path cannot be represented as a C string).
//
// Improvement: takes `&str` instead of the non-idiomatic `&String`; existing
// callers passing `&String` still compile via deref coercion. Also calls
// `as_ptr()` directly on the `CString` instead of going through `as_c_str()`.
pub fn do_unlink(path: &str) -> Result<()> {
    let c_string = CString::new(path).map_err(|_| errno!(EINVAL, "cstring new failure"))?;
    // `c_string` outlives the OCALL, so the raw pointer stays valid.
    let c_path = c_string.as_ptr();
    try_libc!(libc::ocall::unlink(c_path));
    Ok(())
}
/// Connects the host socket to `addr` via the `connect` OCALL.
///
/// Passing `None` connects to an `AF_UNSPEC` address, which per POSIX
/// dissolves an existing datagram-socket association ("disconnect").
pub fn do_connect<A: Addr>(host_fd: FileDesc, addr: Option<&A>) -> Result<()> {
    let fd = host_fd as i32;
    let (c_addr_storage, c_addr_len) = match addr {
        Some(addr_inner) => addr_inner.to_c_storage(),
        None => {
            // Fix: the original used `MaybeUninit::uninit().assume_init()`,
            // which is undefined behavior (it materializes an uninitialized
            // value). Zero-initialize instead, matching how
            // `Common::get_addr_from_host` builds its sockaddr_storage.
            let mut sockaddr_storage: libc::sockaddr_storage = unsafe { mem::zeroed() };
            sockaddr_storage.ss_family = libc::AF_UNSPEC as _;
            // Only the family field is meaningful for a disconnect.
            (sockaddr_storage, mem::size_of::<libc::sa_family_t>())
        }
    };
    let c_addr_ptr = &c_addr_storage as *const _ as _;
    let c_addr_len = c_addr_len as u32;
    try_libc!(libc::ocall::connect(fd, c_addr_ptr, c_addr_len));
    Ok(())
}

@ -0,0 +1,32 @@
use std::time::Duration;
/// Per-socket send/receive timeout configuration.
///
/// `None` means no timeout is set for that direction; `Some(d)` bounds how
/// long the corresponding operation may wait.
#[derive(Clone, Debug)]
pub struct Timeout {
    // Timeout applied to send operations; `None` = unlimited.
    sender: Option<Duration>,
    // Timeout applied to receive operations; `None` = unlimited.
    receiver: Option<Duration>,
}

impl Timeout {
    /// Creates a configuration with neither timeout set.
    pub fn new() -> Self {
        Timeout {
            sender: None,
            receiver: None,
        }
    }

    /// Returns the configured send timeout, if any.
    pub fn sender_timeout(&self) -> Option<Duration> {
        self.sender
    }

    /// Returns the configured receive timeout, if any.
    pub fn receiver_timeout(&self) -> Option<Duration> {
        self.receiver
    }

    /// Sets the send timeout.
    pub fn set_sender(&mut self, timeout: Duration) {
        self.sender = Some(timeout);
    }

    /// Sets the receive timeout.
    pub fn set_receiver(&mut self, timeout: Duration) {
        self.receiver = Some(timeout);
    }
}

@ -0,0 +1,494 @@
use core::time::Duration;
use crate::{
events::{Observer, Poller},
fs::{IoNotifier, StatusFlags},
match_ioctl_cmd_mut,
net::socket::MsgFlags,
};
use super::*;
use crate::fs::IoEvents as Events;
use crate::fs::{GetIfConf, GetIfReqWithRawCmd, GetReadBufLen, IoctlCmd};
/// A datagram (UDP-style) socket whose host I/O goes through io_uring.
pub struct DatagramSocket<A: Addr + 'static, R: Runtime> {
    // State shared with sender/receiver: host fd, pollee, addresses, timeouts.
    common: Arc<Common<A, R>>,
    // Bind/connect state machine.
    state: RwLock<State>,
    // Outgoing-datagram half.
    sender: Arc<Sender<A, R>>,
    // Incoming-datagram half.
    receiver: Arc<Receiver<A, R>>,
}
impl<A: Addr, R: Runtime> DatagramSocket<A, R> {
    /// Creates an unbound, unconnected datagram socket (`Type::DGRAM`).
    pub fn new(nonblocking: bool) -> Result<Self> {
        let common = Arc::new(Common::new(Type::DGRAM, nonblocking, None)?);
        let state = RwLock::new(State::new());
        let sender = Sender::new(common.clone());
        let receiver = Receiver::new(common.clone());
        Ok(Self {
            common,
            state,
            sender,
            receiver,
        })
    }
    /// Creates a connected pair of datagram sockets.
    ///
    /// NOTE(review): `Common::new_pair` currently always fails with EINVAL
    /// ("Unix is unsupported"), so this constructor cannot succeed yet.
    pub fn new_pair(nonblocking: bool) -> Result<(Self, Self)> {
        let (common1, common2) = Common::new_pair(Type::DGRAM, nonblocking)?;
        let socket1 = Self::new_connected(common1);
        let socket2 = Self::new_connected(common2);
        Ok((socket1, socket2))
    }
    /// Wraps an already-connected `Common` into a socket and immediately
    /// starts the async receive path.
    fn new_connected(common: Common<A, R>) -> Self {
        let common = Arc::new(common);
        let state = RwLock::new(State::new_connected());
        let sender = Sender::new(common.clone());
        let receiver = Receiver::new(common.clone());
        receiver.initiate_async_recv();
        Self {
            common,
            state,
            sender,
            receiver,
        }
    }
    /// Returns the address domain determined by the address type `A`.
    pub fn domain(&self) -> Domain {
        A::domain()
    }
    /// Returns the underlying host file descriptor.
    pub fn host_fd(&self) -> FileDesc {
        self.common.host_fd()
    }
    /// Returns the socket's status flags.
    pub fn status_flags(&self) -> StatusFlags {
        // Only support O_NONBLOCK
        if self.common.nonblocking() {
            StatusFlags::O_NONBLOCK
        } else {
            StatusFlags::empty()
        }
    }
    /// Updates the socket's status flags; all flags other than O_NONBLOCK
    /// are silently ignored.
    pub fn set_status_flags(&self, new_flags: StatusFlags) -> Result<()> {
        // Only support O_NONBLOCK
        let nonblocking = new_flags.is_nonblocking();
        self.common.set_nonblocking(nonblocking);
        Ok(())
    }
    /// When creating a datagram socket, you can use `bind` to bind the socket
    /// to an address, so that another socket can send data to this address.
    ///
    /// Binding is divided into explicit and implicit. Invoking `bind` is
    /// explicit binding, while invoking `sendto` / `sendmsg` / `connect`
    /// will trigger implicit binding.
    ///
    /// Datagram sockets can only bind once. You should use explicit binding or
    /// just implicit binding. The explicit binding will fail if it happens after
    /// an implicit binding.
    pub fn bind(&self, addr: &A) -> Result<()> {
        let mut state = self.state.write().unwrap();
        if state.is_bound() {
            return_errno!(EINVAL, "The socket is already bound to an address");
        }
        do_bind(self.host_fd(), addr)?;
        self.common.set_addr(addr);
        state.mark_explicit_bind();
        // Start async recv after explicit binding or implicit binding
        self.receiver.initiate_async_recv();
        Ok(())
    }
    /// Datagram sockets provide only connectionless interactions, but datagram
    /// sockets can also use connect to associate a socket with a specific address.
    /// After connection, any data sent on the socket is automatically addressed to the
    /// connected peer, and only data received from that peer is delivered to the user.
    ///
    /// Unlike stream sockets, datagram sockets can connect multiple times. But the socket
    /// can only connect to one peer at the same time; a second connect will change the
    /// peer address, and a connect to an address with family AF_UNSPEC will dissolve the
    /// association ("disconnect" or "unconnect").
    ///
    /// Before connection you can only use `sendto` / `sendmsg` / `recvfrom` / `recvmsg`.
    /// Only after connection, you can use `read` / `recv` / `write` / `send`.
    /// And you can ignore the address in `sendto` / `sendmsg` if you just want to
    /// send data to the connected peer.
    ///
    /// Ref 1: http://osr507doc.xinuos.com/en/netguide/disockD.connecting_datagrams.html
    /// Ref 2: https://www.masterraghu.com/subjects/np/introduction/unix_network_programming_v1.3/ch08lev1sec11.html
    pub fn connect(&self, peer_addr: Option<&A>) -> Result<()> {
        let mut state = self.state.write().unwrap();
        // if previous peer.is_default() and peer_addr.is_none()
        // is unspec, so the situation exists that both
        // !state.is_connected() and peer_addr.is_none() are true.
        if let Some(peer) = peer_addr {
            do_connect(self.host_fd(), Some(peer))?;
            // A successful (re-)connect clears any earlier shutdown state.
            self.receiver.reset_shutdown();
            self.sender.reset_shutdown();
            self.common.set_peer_addr(peer);
            // A "default" peer address behaves like AF_UNSPEC: disconnect.
            if peer.is_default() {
                state.mark_disconnected();
            } else {
                state.mark_connected();
            }
            if !state.is_bound() {
                state.mark_implicit_bind();
                // Start async recv after explicit binding or implicit binding
                self.receiver.initiate_async_recv();
            }
            // TODO: update binding address in some cases
            // For a ipv4 socket bound to 0.0.0.0 (INADDR_ANY), if you do connection
            // to 127.0.0.1 (Local IP address), the IP address of the socket will
            // change to 127.0.0.1 too. And if connect to non-local IP address, linux
            // will assign a address to the socket.
            // In both cases, we should update the binding address that we stored.
        } else {
            // `None` means connect(AF_UNSPEC): dissolve the association.
            do_connect::<A>(self.host_fd(), None)?;
            self.common.reset_peer_addr();
            state.mark_disconnected();
            // TODO: clear binding in some cases.
            // Disconnect will effect the binding address. In Linux, for socket that
            // explicit bound to local IP address, disconnect will clear the binding address,
            // but leave the port intact. For socket with implicit bound, disconnect will
            // clear both the address and port.
        }
        Ok(())
    }
    // Close the datagram socket, cancel pending iouring requests
    pub fn close(&self) -> Result<()> {
        // Mark both directions shut down and the socket closed before
        // cancelling outstanding io_uring requests.
        self.sender.shutdown();
        self.receiver.shutdown();
        self.common.set_closed();
        self.cancel_requests();
        Ok(())
    }
    /// Shuts down the udp socket. This syscall is very TCP-oriented, but it is also
    /// useful for udp sockets. Unlike tcp, shutdown does nothing on the wire; it only
    /// changes shutdown states. The shutdown states block the io-uring requests for
    /// receiving or sending messages.
    pub fn shutdown(&self, how: Shutdown) -> Result<()> {
        let state = self.state.read().unwrap();
        if !state.is_connected() {
            return_errno!(ENOTCONN, "The udp socket is not connected");
        }
        // Release the read lock before touching sender/receiver/pollee.
        drop(state);
        match how {
            Shutdown::Read => {
                self.common.host_shutdown(how)?;
                self.receiver.shutdown();
                // Wake readers: a shut-down read side is always "readable".
                self.common.pollee().add_events(Events::IN);
            }
            Shutdown::Write => {
                // Only shut down the host's write side once our send queue
                // has drained; otherwise just mark the LibOS-side state.
                if self.sender.is_empty() {
                    self.common.host_shutdown(how)?;
                }
                self.sender.shutdown();
                self.common.pollee().add_events(Events::OUT);
            }
            Shutdown::Both => {
                self.common.host_shutdown(Shutdown::Read)?;
                if self.sender.is_empty() {
                    self.common.host_shutdown(Shutdown::Write)?;
                }
                self.receiver.shutdown();
                self.sender.shutdown();
                self.common
                    .pollee()
                    .add_events(Events::IN | Events::OUT | Events::HUP);
            }
        }
        Ok(())
    }
    /// Reads one datagram into `buf`; thin wrapper over `readv`.
    pub fn read(&self, buf: &mut [u8]) -> Result<usize> {
        self.readv(&mut [buf])
    }
    /// Reads one datagram into the given buffers, discarding the sender
    /// address and ancillary data.
    pub fn readv(&self, bufs: &mut [&mut [u8]]) -> Result<usize> {
        // NOTE(review): the state read lock is acquired and immediately
        // dropped; it does not appear to guard anything here — confirm intent.
        let state = self.state.read().unwrap();
        drop(state);
        self.recvmsg(bufs, RecvFlags::empty(), None)
            .map(|(ret, ..)| ret)
    }
/// You can not invoke `recvfrom` directly after creating a datagram socket.
/// That is because `recvfrom` doesn't provide an implicit binding. If you
/// don't do an explicit or implicit binding, the sender doesn't know where
/// to send the data.
///
/// Returns `(bytes_received, src_addr, msg_flags, control_len)`; all the
/// actual work is delegated to the `Receiver` half.
pub fn recvmsg(
    &self,
    bufs: &mut [&mut [u8]],
    flags: RecvFlags,
    control: Option<&mut [u8]>,
) -> Result<(usize, Option<A>, MsgFlags, usize)> {
    self.receiver.recvmsg(bufs, flags, control)
}
/// Write one datagram from a single buffer; equivalent to `writev` with one iovec.
pub fn write(&self, buf: &[u8]) -> Result<usize> {
    self.writev(&[buf])
}

/// Gather-write one datagram to the connected peer (no explicit address,
/// no flags, no ancillary data).
pub fn writev(&self, bufs: &[&[u8]]) -> Result<usize> {
    self.sendmsg(bufs, None, SendFlags::empty(), None)
}
/// Send one datagram assembled from `bufs` with optional destination
/// `addr`, send `flags` and ancillary data `control`.
///
/// If `addr` is `None` the socket must already be connected; otherwise
/// EDESTADDRREQ is returned (mirrors Linux).
pub fn sendmsg(
    &self,
    bufs: &[&[u8]],
    addr: Option<&A>,
    flags: SendFlags,
    control: Option<&[u8]>,
) -> Result<usize> {
    let state = self.state.read().unwrap();
    if addr.is_none() && !state.is_connected() {
        return_errno!(EDESTADDRREQ, "Destination address required");
    }
    drop(state);
    let res = if let Some(addr) = addr {
        self.sender.sendmsg(bufs, addr, flags, control)
    } else {
        // Connected socket: fall back to the stored peer address.
        let peer = self.common.peer_addr();
        if let Some(peer) = peer.as_ref() {
            self.sender.sendmsg(bufs, peer, flags, control)
        } else {
            return_errno!(EDESTADDRREQ, "Destination address required");
        }
    };
    // Sending performs an implicit bind when the socket is still unbound.
    // NOTE(review): the implicit bind is marked even when the send above
    // returned an error -- confirm this matches the intended semantics.
    let mut state = self.state.write().unwrap();
    if !state.is_bound() {
        state.mark_implicit_bind();
        // Start async recv after explicit binding or implicit binding
        self.receiver.initiate_async_recv();
    }
    res
}
/// Poll the readiness events of this socket, optionally registering `poller`
/// for notification when the events change.
pub fn poll(&self, mask: Events, poller: Option<&Poller>) -> Events {
    let pollee = self.common.pollee();
    pollee.poll(mask, poller)
}

/// Return the local address of the socket (`getsockname` semantics).
pub fn addr(&self) -> Result<A> {
    let common = &self.common;
    // Always get addr from host.
    // Because for IP socket, users can specify "0" as port and the kernel should select a usable port for him.
    // Thus, when calling getsockname, this should be updated.
    let addr = common.get_addr_from_host()?;
    common.set_addr(&addr);
    Ok(addr)
}

/// Return the event notifier used to observe readiness changes on this socket.
pub fn notifier(&self) -> &IoNotifier {
    let notifier = self.common.notifier();
    notifier
}

/// Return the connected peer address (`getpeername`); ENOTCONN when unconnected.
pub fn peer_addr(&self) -> Result<A> {
    let state = self.state.read().unwrap();
    if !state.is_connected() {
        return_errno!(ENOTCONN, "the socket is not connected");
    }
    // A connected socket always has a stored peer address, so unwrap is safe.
    Ok(self.common.peer_addr().unwrap())
}

/// The pending socket error, if any (SO_ERROR-style reporting).
pub fn errno(&self) -> Option<Errno> {
    self.common.errno()
}
/// Dispatch ioctl / socket-option commands for the datagram socket.
///
/// Raw get/setsockopt commands are forwarded to the host fd; timeout and
/// metadata queries are answered from enclave-side state. Unknown commands
/// yield EINVAL.
pub fn ioctl(&self, cmd: &mut dyn IoctlCmd) -> Result<()> {
    match_ioctl_cmd_mut!(&mut *cmd, {
        cmd: GetSockOptRawCmd => {
            cmd.execute(self.host_fd())?;
        },
        cmd: SetSockOptRawCmd => {
            cmd.execute(self.host_fd())?;
        },
        cmd: SetRecvTimeoutCmd => {
            self.set_recv_timeout(*cmd.timeout());
        },
        cmd: SetSendTimeoutCmd => {
            self.set_send_timeout(*cmd.timeout());
        },
        cmd: GetRecvTimeoutCmd => {
            let timeval = timeout_to_timeval(self.recv_timeout());
            cmd.set_output(timeval);
        },
        cmd: GetSendTimeoutCmd => {
            let timeval = timeout_to_timeval(self.send_timeout());
            cmd.set_output(timeval);
        },
        cmd: GetAcceptConnCmd => {
            // Datagram doesn't support listen
            cmd.set_output(0);
        },
        cmd: GetDomainCmd => {
            cmd.set_output(self.domain() as _);
        },
        cmd: GetErrorCmd => {
            // Reading SO_ERROR also clears the pending error.
            let error: i32 = self.errno().map(|err| err as i32).unwrap_or(0);
            cmd.set_output(error);
        },
        cmd: GetPeerNameCmd => {
            let peer = self.peer_addr()?;
            cmd.set_output(AddrStorage(peer.to_c_storage()));
        },
        cmd: GetTypeCmd => {
            cmd.set_output(self.common.type_() as _);
        },
        cmd: GetIfReqWithRawCmd => {
            cmd.execute(self.host_fd())?;
        },
        cmd: GetIfConf => {
            cmd.execute(self.host_fd())?;
        },
        cmd: GetReadBufLen => {
            // FIONREAD: bytes of the buffered datagram ready to be read.
            let read_buf_len = self.receiver.ready_len();
            cmd.set_output(read_buf_len as _);
        },
        _ => {
            return_errno!(EINVAL, "Not supported yet");
        }
    });
    Ok(())
}
/// Send timeout configured via SO_SNDTIMEO, if any.
fn send_timeout(&self) -> Option<Duration> {
    self.common.send_timeout()
}

/// Receive timeout configured via SO_RCVTIMEO, if any.
fn recv_timeout(&self) -> Option<Duration> {
    self.common.recv_timeout()
}

/// Set the send timeout (SO_SNDTIMEO).
fn set_send_timeout(&self, timeout: Duration) {
    self.common.set_send_timeout(timeout);
}

/// Set the receive timeout (SO_RCVTIMEO).
fn set_recv_timeout(&self, timeout: Duration) {
    self.common.set_recv_timeout(timeout);
}

/// Cancel outstanding io_uring requests during close: abort the pending
/// recv and drain (or, after a linger period, cancel) queued sends.
fn cancel_requests(&self) {
    self.receiver.cancel_recv_requests();
    self.sender.try_clear_msg_queue_when_close();
}
}
impl<A: Addr + 'static, R: Runtime> Drop for DatagramSocket<A, R> {
    /// Mark the socket closed on drop so no further io_uring requests are
    /// submitted on behalf of this file.
    fn drop(&mut self) {
        self.common.set_closed();
    }
}
impl<A: Addr + 'static, R: Runtime> std::fmt::Debug for DatagramSocket<A, R> {
    // Only the common part is printed; sender/receiver internals are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("DatagramSocket")
            .field("common", &self.common)
            .finish()
    }
}
/// Tracks the binding and connection status of a datagram socket.
#[derive(Debug)]
struct State {
    // How (and whether) the socket was bound to a local address.
    bind_state: BindState,
    // Whether the socket has a connected peer.
    is_connected: bool,
}

impl State {
    /// A fresh socket: unbound and unconnected.
    pub fn new() -> Self {
        Self {
            bind_state: BindState::Unbound,
            is_connected: false,
        }
    }

    /// A socket that starts out connected but still unbound.
    pub fn new_connected() -> Self {
        Self {
            bind_state: BindState::Unbound,
            is_connected: true,
        }
    }

    /// Bound either explicitly or implicitly.
    pub fn is_bound(&self) -> bool {
        self.bind_state.is_bound()
    }

    #[allow(dead_code)]
    pub fn is_explicit_bound(&self) -> bool {
        self.bind_state.is_explicit_bound()
    }

    #[allow(dead_code)]
    pub fn is_implicit_bound(&self) -> bool {
        self.bind_state.is_implicit_bound()
    }

    pub fn is_connected(&self) -> bool {
        self.is_connected
    }

    /// Record an explicit `bind()` by the user.
    pub fn mark_explicit_bind(&mut self) {
        self.bind_state = BindState::ExplicitBound;
    }

    /// Record a kernel-assigned binding caused by a send on an unbound socket.
    pub fn mark_implicit_bind(&mut self) {
        self.bind_state = BindState::ImplicitBound;
    }

    pub fn mark_connected(&mut self) {
        self.is_connected = true;
    }

    pub fn mark_disconnected(&mut self) {
        self.is_connected = false;
    }
}
/// How a datagram socket became bound to a local address, if at all.
#[derive(Debug)]
enum BindState {
    Unbound,
    ExplicitBound,
    ImplicitBound,
}

impl BindState {
    /// Bound in any way, whether explicitly or implicitly.
    pub fn is_bound(&self) -> bool {
        !matches!(self, Self::Unbound)
    }

    /// Bound by an explicit `bind()` call only.
    #[allow(dead_code)]
    pub fn is_explicit_bound(&self) -> bool {
        matches!(self, Self::ExplicitBound)
    }

    /// Bound implicitly by a send on an unbound socket only.
    #[allow(dead_code)]
    pub fn is_implicit_bound(&self) -> bool {
        matches!(self, Self::ImplicitBound)
    }
}

@ -0,0 +1,20 @@
//! Datagram sockets.
mod generic;
mod receiver;
mod sender;
use self::receiver::Receiver;
use self::sender::Sender;
use crate::net::socket::sockopt::*;
use crate::net::socket::uring::common::{do_bind, do_connect, Common};
use crate::net::socket::uring::runtime::Runtime;
use crate::prelude::*;
pub use generic::DatagramSocket;
use crate::net::socket::sockopt::{
timeout_to_timeval, GetRecvTimeoutCmd, GetSendTimeoutCmd, SetRecvTimeoutCmd, SetSendTimeoutCmd,
};
const MAX_BUF_SIZE: usize = 64 * 1024;
const OPTMEM_MAX: usize = 64 * 1024;

@ -0,0 +1,382 @@
use core::time::Duration;
use std::mem::MaybeUninit;
use crate::events::Poller;
use crate::net::socket::MsgFlags;
use io_uring_callback::{Fd, IoHandle};
use sgx_untrusted_alloc::{MaybeUntrusted, UntrustedBox};
use crate::fs::IoEvents as Events;
use crate::net::socket::uring::common::Common;
use crate::net::socket::uring::runtime::Runtime;
use crate::prelude::*;
/// The receiving half of a datagram socket.
///
/// Keeps at most one async `recvmsg` io_uring request in flight; the
/// received datagram is buffered in `inner` until the user reads it.
pub struct Receiver<A: Addr + 'static, R: Runtime> {
    common: Arc<Common<A, R>>,
    inner: Mutex<Inner>,
}
impl<A: Addr, R: Runtime> Receiver<A, R> {
    /// Create a receiver sharing `common` with the owning socket.
    pub fn new(common: Arc<Common<A, R>>) -> Arc<Self> {
        let inner = Mutex::new(Inner::new());
        Arc::new(Self { common, inner })
    }

    /// Receive one datagram, blocking unless the socket or `flags` forbid it.
    ///
    /// Returns `(bytes_received, src_addr, msg_flags, control_len)`.
    /// Repeatedly calls `try_recvmsg`, and while it yields EAGAIN waits for
    /// `IN` events with the configured receive timeout.
    pub fn recvmsg(
        self: &Arc<Self>,
        bufs: &mut [&mut [u8]],
        flags: RecvFlags,
        mut control: Option<&mut [u8]>,
    ) -> Result<(usize, Option<A>, MsgFlags, usize)> {
        let mask = Events::IN;
        // Initialize the poller only when needed
        let mut poller = None;
        let mut timeout = self.common.recv_timeout();
        loop {
            // Attempt to recv
            let res = self.try_recvmsg(bufs, flags, &mut control);
            if !res.has_errno(EAGAIN) {
                return res;
            }
            // Need more handles for flags not MSG_DONTWAIT
            // recv*(MSG_ERRQUEUE) never blocks, even without MSG_DONTWAIT
            if self.common.nonblocking()
                || flags.contains(RecvFlags::MSG_DONTWAIT)
                || flags.contains(RecvFlags::MSG_ERRQUEUE)
            {
                return_errno!(EAGAIN, "no data are present to be received");
            }
            // Wait for interesting events by polling
            if poller.is_none() {
                let new_poller = Poller::new();
                self.common.pollee().connect_poller(mask, &new_poller);
                poller = Some(new_poller);
            }
            let events = self.common.pollee().poll(mask, None);
            if events.is_empty() {
                let ret = poller.as_ref().unwrap().wait_timeout(timeout.as_mut());
                if let Err(e) = ret {
                    warn!("recv wait errno = {:?}", e.errno());
                    match e.errno() {
                        ETIMEDOUT => {
                            // SO_RCVTIMEO expiry is reported as EAGAIN, as on Linux.
                            return_errno!(EAGAIN, "timeout reached")
                        }
                        _ => {
                            return_errno!(e.errno(), "wait error")
                        }
                    }
                }
            }
        }
    }

    /// One non-blocking receive attempt against the buffered datagram.
    ///
    /// Copies ancillary data and payload out of the untrusted buffers,
    /// computes MSG_TRUNC/MSG_CTRUNC, and re-arms the async recv. Yields
    /// EAGAIN when no datagram has completed yet.
    fn try_recvmsg(
        self: &Arc<Self>,
        bufs: &mut [&mut [u8]],
        flags: RecvFlags,
        control: &mut Option<&mut [u8]>,
    ) -> Result<(usize, Option<A>, MsgFlags, usize)> {
        let mut inner = self.inner.lock();
        // NOTE(review): `contains(MSG_OOB | MSG_CMSG_CLOEXEC)` is true only
        // when BOTH flags are set; if the intent is to reject either flag
        // alone, `intersects` should be used -- confirm.
        if !flags.is_empty() && flags.contains(RecvFlags::MSG_OOB | RecvFlags::MSG_CMSG_CLOEXEC) {
            // todo!("Support other flags");
            return_errno!(EINVAL, "the socket flags is not supported");
        }

        // Mark the socket as non-readable since Datagram uses single packet
        self.common.pollee().del_events(Events::IN);

        let mut recv_bytes = 0;
        let mut msg_flags = MsgFlags::empty();
        let recv_addr = inner.get_packet_addr();

        let msg_controllen = inner.control_len.unwrap_or(0);
        let user_controllen = control.as_ref().map_or(0, |buf| buf.len());

        // Copy ancillary data from control buffer
        if user_controllen > super::OPTMEM_MAX {
            return_errno!(EINVAL, "invalid msg control length");
        }
        // Control data that doesn't fit the user buffer is truncated (MSG_CTRUNC).
        if user_controllen < msg_controllen {
            msg_flags = msg_flags | MsgFlags::MSG_CTRUNC
        }
        if msg_controllen > 0 {
            let copied_bytes = msg_controllen.min(user_controllen);
            control
                .as_mut()
                .map(|buf| buf[..copied_bytes].copy_from_slice(&inner.msg_control[..copied_bytes]));
        }

        // Copy data from the recv buffer to the bufs
        let copied_bytes = inner.try_copy_buf(bufs);
        if let Some(copied_bytes) = copied_bytes {
            let bufs_len: usize = bufs.iter().map(|buf| buf.len()).sum();
            // If user provided buffer length is smaller than kernel received datagram length,
            // discard the datagram and set MsgFlags::MSG_TRUNC in returned msg_flags.
            if bufs_len < inner.recv_len().unwrap() {
                // update msg.msg_flags to MSG_TRUNC
                msg_flags = msg_flags | MsgFlags::MSG_TRUNC
            };

            // If user provided flags contain MSG_TRUNC, the return received length should be
            // kernel receiver buffer length, vice versa should return truly copied bytes length.
            recv_bytes = if flags.contains(RecvFlags::MSG_TRUNC) {
                inner.recv_len().unwrap()
            } else {
                copied_bytes
            };

            // When flags contain MSG_PEEK and there is data in socket recv buffer,
            // it is unnecessary to send blocking recv request (do_recv) to fetch data
            // from iouring buffer, which may flush the data in recv buffer.
            // When flags don't contain MSG_PEEK or there is no available data,
            // it is time to send blocking request to iouring for notifying events.
            if !flags.contains(RecvFlags::MSG_PEEK) {
                self.do_recv(&mut inner);
            }
            return Ok((recv_bytes, recv_addr, msg_flags, msg_controllen));
        }

        // In some situations of MSG_ERRQUEUE, users only require control buffer but setting iovec length to zero.
        if msg_controllen > 0 {
            return Ok((recv_bytes, recv_addr, msg_flags, msg_controllen));
        }

        // Handle iouring message error
        if let Some(errno) = inner.error {
            // Reset error
            inner.error = None;
            self.common.pollee().del_events(Events::ERR);
            return_errno!(errno, "recv failed");
        }

        // A shut-down read side returns EOF (0 bytes) for blocking callers.
        if inner.is_shutdown {
            if self.common.nonblocking()
                || flags.contains(RecvFlags::MSG_DONTWAIT)
                || flags.contains(RecvFlags::MSG_ERRQUEUE)
            {
                return_errno!(Errno::EWOULDBLOCK, "the socket recv has been shutdown");
            } else {
                return Ok((0, None, msg_flags, 0));
            }
        }

        self.do_recv(&mut inner);
        return_errno!(EAGAIN, "try recv again");
    }

    /// Submit a new async recvmsg request if none is in flight and the
    /// socket is neither closed nor shut down. The completion callback
    /// stores the result (or error) in `inner` and raises pollee events.
    fn do_recv(self: &Arc<Self>, inner: &mut MutexGuard<Inner>) {
        if inner.io_handle.is_some() || self.common.is_closed() {
            return;
        }
        // Clear recv_len and error
        inner.recv_len.take();
        inner.control_len.take();
        inner.error.take();

        if inner.is_shutdown {
            info!("do_recv early return, the socket recv has been shutdown");
            return;
        }

        let receiver = self.clone();
        // Init the callback invoked upon the completion of the async recv
        let complete_fn = move |retval: i32| {
            let mut inner = receiver.inner.lock();

            // Release the handle to the async recv
            inner.io_handle.take();

            // Handle error
            if retval < 0 {
                // TODO: Should we filter the error case? Do we have the ability to filter?
                // We only filter the normal case now. According to the man page of recvmsg,
                // these errors should not happen, since our fd and arguments should always
                // be valid unless being attacked.

                // TODO: guard against Iago attack through errno
                let errno = Errno::from(-retval as u32);

                inner.error = Some(errno);
                receiver.common.set_errno(errno);
                // TODO: add PRI event if set SO_SELECT_ERR_QUEUE
                receiver.common.pollee().add_events(Events::ERR);
                return;
            }

            // Handle the normal case of a successful read
            inner.recv_len = Some(retval as usize);
            let control_len = inner.req.msg.msg_controllen;
            inner.control_len = Some(control_len);
            receiver.common.pollee().add_events(Events::IN);
            // We don't do_recv() here, since do_recv() will clear the recv message.
        };

        // Generate the async recv request
        let msghdr_ptr = inner.new_recv_req();

        // Submit the async recv to io_uring
        let io_uring = self.common.io_uring();
        let host_fd = Fd(self.common.host_fd() as _);
        let handle = unsafe { io_uring.recvmsg(host_fd, msghdr_ptr, 0, complete_fn) };
        inner.io_handle.replace(handle);
    }

    /// Kick off the first async recv (called after the socket gets bound).
    pub fn initiate_async_recv(self: &Arc<Self>) {
        let mut inner = self.inner.lock();
        self.do_recv(&mut inner);
    }

    /// Cancel the in-flight recv request (if any) and wait until its
    /// completion callback has run, polling with a 10s timeout per round.
    pub fn cancel_recv_requests(&self) {
        {
            let inner = self.inner.lock();
            if let Some(io_handle) = &inner.io_handle {
                let io_uring = self.common.io_uring();
                unsafe { io_uring.cancel(io_handle) };
            } else {
                return;
            }
        }

        // wait for the cancel to complete
        let poller = Poller::new();
        let mask = Events::ERR | Events::IN;
        self.common.pollee().connect_poller(mask, &poller);

        loop {
            let pending_request_exist = {
                let inner = self.inner.lock();
                inner.io_handle.is_some()
            };

            if pending_request_exist {
                let mut timeout = Some(Duration::from_secs(10));
                let ret = poller.wait_timeout(timeout.as_mut());
                if let Err(e) = ret {
                    warn!("wait cancel recv request error = {:?}", e.errno());
                    continue;
                }
            } else {
                break;
            }
        }
    }

    /// Shutdown udp receiver.
    pub fn shutdown(&self) {
        let mut inner = self.inner.lock();
        inner.is_shutdown = true;
    }

    /// Reset udp receiver shutdown state.
    pub fn reset_shutdown(&self) {
        let mut inner = self.inner.lock();
        inner.is_shutdown = false;
    }

    /// Byte length of the buffered datagram, or 0 if none (used by FIONREAD).
    pub fn ready_len(&self) -> usize {
        let inner = self.inner.lock();
        inner.recv_len().unwrap_or(0)
    }
}
/// Receiver state protected by the `Receiver::inner` mutex.
struct Inner {
    // Untrusted buffer the host kernel writes the incoming datagram into.
    recv_buf: UntrustedBox<[u8]>,
    // Datagram sockets in various domains permit zero-length datagrams.
    // Hence the recv_len might be 0.
    recv_len: Option<usize>,
    // When the recv_buf content length is greater than user buffer,
    // store the offset for the recv_buf for read loop
    // NOTE(review): never read or updated in this file -- possibly unused.
    recv_buf_offset: usize,
    // Untrusted buffer for ancillary (control) data.
    msg_control: UntrustedBox<[u8]>,
    // Length of valid control data from the last completed recv.
    control_len: Option<usize>,
    // msghdr/iovec/addr request block handed to io_uring.
    req: UntrustedBox<RecvReq>,
    // Handle of the single in-flight async recv request, if any.
    io_handle: Option<IoHandle>,
    // Sticky error from the last failed recv; cleared when reported.
    error: Option<Errno>,
    // Set once the read side has been shut down.
    is_shutdown: bool,
}

// NOTE(review): assumed sound because all access is serialized by the
// enclosing mutex; the raw pointers in `req` are never shared -- confirm.
unsafe impl Send for Inner {}
impl Inner {
    /// Allocate the untrusted receive/control buffers and an empty request block.
    pub fn new() -> Self {
        Self {
            recv_buf: UntrustedBox::new_uninit_slice(super::MAX_BUF_SIZE),
            recv_len: None,
            recv_buf_offset: 0,
            msg_control: UntrustedBox::new_uninit_slice(super::OPTMEM_MAX),
            control_len: None,
            req: UntrustedBox::new_uninit(),
            io_handle: None,
            error: None,
            is_shutdown: false,
        }
    }

    /// Rebuild `req` as a fresh recvmsg request pointing at this Inner's
    /// untrusted buffers and return a pointer to its msghdr. The pointee
    /// stays valid for the lifetime of `self.req`.
    pub fn new_recv_req(&mut self) -> *mut libc::msghdr {
        let iovec = libc::iovec {
            iov_base: self.recv_buf.as_mut_ptr() as _,
            iov_len: self.recv_buf.len(),
        };

        let msghdr_ptr = &raw mut self.req.msg;

        // Build the msghdr so its pointers reference the sibling fields of
        // the same RecvReq (iovec, addr) and this Inner's control buffer.
        let mut msg: libc::msghdr = unsafe { MaybeUninit::zeroed().assume_init() };
        msg.msg_iov = &raw mut self.req.iovec as _;
        msg.msg_iovlen = 1;
        msg.msg_name = &raw mut self.req.addr as _;
        msg.msg_namelen = std::mem::size_of::<libc::sockaddr_storage>() as _;
        msg.msg_control = self.msg_control.as_mut_ptr() as _;
        msg.msg_controllen = self.msg_control.len() as _;

        self.req.msg = msg;
        self.req.iovec = iovec;

        msghdr_ptr
    }

    /// Scatter-copy the buffered datagram into `bufs`.
    ///
    /// Returns `None` when no datagram has been received; otherwise the
    /// number of bytes copied (which may be less than the datagram length
    /// when `bufs` is too small).
    pub fn try_copy_buf(&self, bufs: &mut [&mut [u8]]) -> Option<usize> {
        self.recv_len.map(|recv_len| {
            let mut copy_len = 0;
            for buf in bufs {
                // Remaining unread part of the datagram.
                let recv_buf = &self.recv_buf[copy_len..recv_len];
                if buf.len() <= recv_buf.len() {
                    buf.copy_from_slice(&recv_buf[..buf.len()]);
                    copy_len += buf.len();
                } else {
                    buf[..recv_buf.len()].copy_from_slice(&recv_buf[..]);
                    copy_len += recv_buf.len();
                    break;
                }
            }
            copy_len
        })
    }

    /// Length of the buffered datagram, if one has completed.
    pub fn recv_len(&self) -> Option<usize> {
        self.recv_len
    }

    /// Return the addr of the received packet if udp socket is not connected.
    /// Return None if udp socket is connected.
    // NOTE(review): this simply parses msg_name/msg_namelen from the last
    // completed recv; the connected-socket claim above depends on the host
    // kernel's behaviour -- confirm.
    pub fn get_packet_addr<A: Addr>(&self) -> Option<A> {
        let recv_addr_len = self.req.msg.msg_namelen as usize;
        A::from_c_storage(&self.req.addr, recv_addr_len).ok()
    }
}
/// Untrusted-memory request block for an async `recvmsg`: the msghdr plus
/// the iovec and address storage it points to, kept adjacent so the host
/// kernel can fill the whole request outside the enclave.
#[repr(C)]
struct RecvReq {
    msg: libc::msghdr,
    iovec: libc::iovec,
    addr: libc::sockaddr_storage,
}

// RecvReq consists solely of plain-old-data libc structs, so placing it in
// untrusted memory is acceptable.
unsafe impl MaybeUntrusted for RecvReq {}

@ -0,0 +1,406 @@
use core::time::Duration;
use std::ptr::{self};
use io_uring_callback::{Fd, IoHandle};
use libc::c_void;
use sgx_untrusted_alloc::{MaybeUntrusted, UntrustedBox};
use std::collections::VecDeque;
use crate::events::Poller;
use crate::fs::IoEvents as Events;
use crate::net::socket::uring::common::Common;
use crate::net::socket::uring::runtime::Runtime;
use crate::prelude::*;
use crate::util::sync::MutexGuard;
const SENDMSG_QUEUE_LEN: usize = 16;
/// The sending half of a datagram socket.
///
/// Queues outgoing datagrams in untrusted memory and keeps at most one
/// async `sendmsg` io_uring request in flight, chaining the next request
/// from the completion callback.
pub struct Sender<A: Addr + 'static, R: Runtime> {
    common: Arc<Common<A, R>>,
    inner: Mutex<Inner>,
}
impl<A: Addr, R: Runtime> Sender<A, R> {
    /// Create a sender; a fresh datagram socket is immediately writable.
    pub fn new(common: Arc<Common<A, R>>) -> Arc<Self> {
        common.pollee().add_events(Events::OUT);
        let inner = Mutex::new(Inner::new());
        Arc::new(Self { common, inner })
    }

    /// Shutdown udp sender.
    pub fn shutdown(&self) {
        let mut inner = self.inner.lock();
        inner.is_shutdown = ShutdownStatus::PreShutdown;
    }

    /// Reset udp sender shutdown state.
    pub fn reset_shutdown(&self) {
        let mut inner = self.inner.lock();
        inner.is_shutdown = ShutdownStatus::Running;
    }

    /// Whether no buffer in sender.
    pub fn is_empty(&self) -> bool {
        let inner = self.inner.lock();
        inner.msg_queue.is_empty()
    }

    // Normally, We will always try to send as long as the kernel send buf is not empty.
    // However, if the user calls close, we will wait LINGER time
    // and then cancel on-going or new-issued send requests.
    pub fn try_clear_msg_queue_when_close(&self) {
        let inner = self.inner.lock();
        debug_assert!(inner.is_shutdown());
        if inner.msg_queue.is_empty() {
            return;
        }

        // Wait for linger time to empty the kernel buffer or cancel subsequent requests.
        drop(inner);
        // NOTE(review): "DEFUALT" is a typo for "DEFAULT" (local constant only).
        const DEFUALT_LINGER_TIME: usize = 10;
        let poller = Poller::new();
        let mask = Events::ERR | Events::OUT;
        self.common.pollee().connect_poller(mask, &poller);

        loop {
            let pending_request_exist = {
                let inner = self.inner.lock();
                inner.io_handle.is_some()
            };

            if pending_request_exist {
                let mut timeout = Some(Duration::from_secs(DEFUALT_LINGER_TIME as u64));
                let ret = poller.wait_timeout(timeout.as_mut());
                trace!("wait empty send buffer ret = {:?}", ret);
                if let Err(_) = ret {
                    // No complete request to wake. Just cancel the send requests.
                    let io_uring = self.common.io_uring();
                    let inner = self.inner.lock();
                    if let Some(io_handle) = &inner.io_handle {
                        unsafe { io_uring.cancel(io_handle) };
                        // Loop again to wait the cancel request to complete
                        continue;
                    } else {
                        // No pending request, just break
                        break;
                    }
                }
            } else {
                // There is no pending requests
                break;
            }
        }
    }

    /// Send one datagram to `addr`, blocking (until the queue accepts it)
    /// unless the socket or `flags` demand non-blocking behaviour.
    ///
    /// Only MSG_DONTWAIT and MSG_NOSIGNAL are supported; other flags yield EINVAL.
    pub fn sendmsg(
        self: &Arc<Self>,
        bufs: &[&[u8]],
        addr: &A,
        flags: SendFlags,
        control: Option<&[u8]>,
    ) -> Result<usize> {
        if !flags.is_empty()
            && flags.intersects(!(SendFlags::MSG_DONTWAIT | SendFlags::MSG_NOSIGNAL))
        {
            error!("Not supported flags: {:?}", flags);
            return_errno!(EINVAL, "not supported flags");
        }

        let mask = Events::OUT;
        // Initialize the poller only when needed
        let mut poller = None;
        let mut timeout = self.common.send_timeout();
        loop {
            // Attempt to write
            let res = self.try_sendmsg(bufs, addr, control);
            if !res.has_errno(EAGAIN) {
                return res;
            }

            // Still some buffer contents pending
            if self.common.nonblocking() || flags.contains(SendFlags::MSG_DONTWAIT) {
                return_errno!(EAGAIN, "try write again");
            }

            // Wait for interesting events by polling
            if poller.is_none() {
                let new_poller = Poller::new();
                self.common.pollee().connect_poller(mask, &new_poller);
                poller = Some(new_poller);
            }
            let events = self.common.pollee().poll(mask, None);
            if events.is_empty() {
                let ret = poller.as_ref().unwrap().wait_timeout(timeout.as_mut());
                if let Err(e) = ret {
                    warn!("send wait errno = {:?}", e.errno());
                    match e.errno() {
                        ETIMEDOUT => {
                            // SO_SNDTIMEO expiry is reported as EAGAIN, as on Linux.
                            return_errno!(EAGAIN, "timeout reached")
                        }
                        _ => {
                            return_errno!(e.errno(), "wait error")
                        }
                    }
                }
            }
        }
    }

    /// One non-blocking attempt: copy the datagram into untrusted memory
    /// and enqueue it; EAGAIN when the queue's byte budget is exhausted.
    fn try_sendmsg(
        self: &Arc<Self>,
        bufs: &[&[u8]],
        addr: &A,
        control: Option<&[u8]>,
    ) -> Result<usize> {
        let mut inner = self.inner.lock();

        if inner.is_shutdown() {
            return_errno!(EPIPE, "the write has been shutdown")
        }

        if let Some(errno) = inner.error {
            // Reset error
            inner.error = None;
            self.common.pollee().del_events(Events::ERR);
            return_errno!(errno, "write failed");
        }

        let buf_len: usize = bufs.iter().map(|buf| buf.len()).sum();
        let mut msg = DataMsg::new(buf_len);
        let total_copied = msg.copy_buf(bufs)?;
        msg.copy_control(control)?;
        // The request block lives inside the (heap-allocated) DataMsg, so
        // the pointer stays valid after the msg is moved into the queue.
        let msghdr_ptr = new_send_req(&mut msg, addr);
        if !inner.msg_queue.push_msg(msg) {
            // Msg queue can not push this msg, mark the socket as non-writable
            self.common.pollee().del_events(Events::OUT);
            return_errno!(EAGAIN, "try write again");
        }

        // Since the send buffer is not empty, try to flush the buffer
        if inner.io_handle.is_none() {
            self.do_send(&mut inner, msghdr_ptr);
        }
        Ok(total_copied)
    }

    /// Submit one async sendmsg request. Its completion callback pops the
    /// sent message, chains the next queued one, and finishes a deferred
    /// write-side shutdown once the queue drains.
    fn do_send(self: &Arc<Self>, inner: &mut MutexGuard<Inner>, msghdr_ptr: *const libc::msghdr) {
        debug_assert!(!inner.msg_queue.is_empty());
        debug_assert!(inner.io_handle.is_none());
        let sender = self.clone();

        // Submit the async send to io_uring
        let complete_fn = move |retval: i32| {
            let mut inner = sender.inner.lock();
            trace!("send request complete with retval: {}", retval);

            // Release the handle to the async recv
            inner.io_handle.take();

            if retval < 0 {
                // TODO: add PRI event if set SO_SELECT_ERR_QUEUE
                let errno = Errno::from(-retval as u32);
                inner.error = Some(errno);
                sender.common.set_errno(errno);
                sender.common.pollee().add_events(Events::ERR);
                return;
            }

            // Need to handle normal case
            inner.msg_queue.pop_msg();
            sender.common.pollee().add_events(Events::OUT);

            if !inner.msg_queue.is_empty() {
                let msghdr_ptr = inner.msg_queue.first_msg_ptr();
                debug_assert!(msghdr_ptr.is_some());
                sender.do_send(&mut inner, msghdr_ptr.unwrap());
            } else if inner.is_shutdown == ShutdownStatus::PreShutdown {
                // The buffer is empty and the write side is shutdown by the user.
                // We can safely shutdown host file here.
                let _ = sender.common.host_shutdown(Shutdown::Write);
                inner.is_shutdown = ShutdownStatus::PostShutdown
            }
        };

        // Generate the async recv request
        let io_uring = self.common.io_uring();
        let host_fd = Fd(self.common.host_fd() as _);
        let handle = unsafe { io_uring.sendmsg(host_fd, msghdr_ptr, 0, complete_fn) };
        inner.io_handle.replace(handle);
    }
}
/// Fill `dmsg.req` with a complete sendmsg request targeting `addr` and
/// return a pointer to its msghdr. The pointee lives in untrusted memory
/// and remains valid while `dmsg` stays queued.
fn new_send_req<A: Addr>(dmsg: &mut DataMsg, addr: &A) -> *const libc::msghdr {
    let iovec = libc::iovec {
        iov_base: dmsg.send_buf.as_ptr() as _,
        iov_len: dmsg.send_buf.len(),
    };
    // Optional ancillary data; null/0 when the caller supplied no control buffer.
    let (control, controllen) = match &dmsg.control {
        Some(control) => (control.as_mut_ptr() as *mut c_void, control.len()),
        None => (ptr::null_mut(), 0),
    };

    dmsg.req.iovec = iovec;
    // The msghdr references sibling fields of the same SendReq, so every
    // pointer stays valid as long as the request block itself is alive.
    dmsg.req.msg.msg_iov = &raw mut dmsg.req.iovec as _;
    dmsg.req.msg.msg_iovlen = 1;

    let (c_addr_storage, c_addr_len) = addr.to_c_storage();
    dmsg.req.addr = c_addr_storage;
    dmsg.req.msg.msg_name = &raw mut dmsg.req.addr as _;
    dmsg.req.msg.msg_namelen = c_addr_len as _;
    dmsg.req.msg.msg_control = control;
    dmsg.req.msg.msg_controllen = controllen;

    &mut dmsg.req.msg
}
/// Sender state protected by the `Sender::inner` mutex.
pub struct Inner {
    // Handle of the single in-flight async send request, if any.
    io_handle: Option<IoHandle>,
    // Sticky error from the last failed send; cleared when reported.
    error: Option<Errno>,
    // Progress of the write-side shutdown.
    is_shutdown: ShutdownStatus,
    // Outgoing datagrams not yet accepted by the host kernel.
    msg_queue: MsgQueue,
}

// NOTE(review): assumed sound because all access is serialized by the
// enclosing mutex -- confirm no raw pointer in here is shared elsewhere.
unsafe impl Send for Inner {}
impl Inner {
    /// A sender with no pending request, no sticky error, an empty message
    /// queue, and the write side still running.
    pub fn new() -> Self {
        Self {
            io_handle: None,
            error: None,
            is_shutdown: ShutdownStatus::Running,
            msg_queue: MsgQueue::new(),
        }
    }

    /// Whether the write side has been shut down, either still draining
    /// queued sends (PreShutdown) or fully finished (PostShutdown).
    #[inline(always)]
    pub fn is_shutdown(&self) -> bool {
        matches!(
            self.is_shutdown,
            ShutdownStatus::PreShutdown | ShutdownStatus::PostShutdown
        )
    }
}
/// Untrusted-memory request block for an async `sendmsg`: the msghdr plus
/// the iovec and destination address storage it points to, kept adjacent
/// so the host kernel can read the whole request outside the enclave.
#[repr(C)]
struct SendReq {
    msg: libc::msghdr,
    iovec: libc::iovec,
    addr: libc::sockaddr_storage,
}

// SendReq consists solely of plain-old-data libc structs, so placing it in
// untrusted memory is acceptable.
unsafe impl MaybeUntrusted for SendReq {}
/// FIFO of pending outgoing datagrams plus a running total of queued bytes.
struct MsgQueue {
    queue: VecDeque<DataMsg>,
    // Sum of the payload lengths of all queued messages.
    curr_size: usize,
}
impl MsgQueue {
    /// An empty queue with capacity pre-reserved for the typical depth.
    #[inline(always)]
    fn new() -> Self {
        Self {
            queue: VecDeque::with_capacity(SENDMSG_QUEUE_LEN),
            curr_size: 0,
        }
    }

    /// Total payload bytes currently queued.
    #[inline(always)]
    fn size(&self) -> usize {
        self.curr_size
    }

    /// Whether no message is queued.
    #[inline(always)]
    fn is_empty(&self) -> bool {
        self.queue.is_empty()
    }

    // Push datagram msg, return true if succeed,
    // return false if buffer is full.
    #[inline(always)]
    fn push_msg(&mut self, msg: DataMsg) -> bool {
        let new_size = self.curr_size + msg.len();
        if new_size > super::MAX_BUF_SIZE {
            return false;
        }
        self.curr_size = new_size;
        self.queue.push_back(msg);
        true
    }

    /// Dequeue the oldest message; a no-op on an empty queue.
    #[inline(always)]
    fn pop_msg(&mut self) {
        if let Some(msg) = self.queue.pop_front() {
            self.curr_size -= msg.len();
        }
    }

    /// Pointer to the msghdr of the front message, if any. The pointee
    /// lives in untrusted memory and stays valid while the message is queued.
    #[inline(always)]
    fn first_msg_ptr(&self) -> Option<*const libc::msghdr> {
        self.queue
            .front()
            .map(|data_msg| &data_msg.req.msg as *const libc::msghdr)
    }
}
// Datagram msg contents in untrusted region
struct DataMsg {
    // msghdr/iovec/addr request block handed to io_uring.
    req: UntrustedBox<SendReq>,
    // Payload bytes to send.
    send_buf: UntrustedBox<[u8]>,
    // Optional ancillary (control) data.
    control: Option<UntrustedBox<[u8]>>,
}
impl DataMsg {
    /// Allocate an untrusted request block plus a payload buffer of `buf_len` bytes.
    #[inline(always)]
    fn new(buf_len: usize) -> Self {
        Self {
            req: UntrustedBox::<SendReq>::new_uninit(),
            send_buf: UntrustedBox::new_uninit_slice(buf_len),
            control: None,
        }
    }

    /// Copy the user's iovecs into the untrusted send buffer.
    /// Errors with EMSGSIZE when the total payload exceeds MAX_BUF_SIZE.
    // NOTE(review): this size check runs after `new()` has already
    // allocated `buf_len` bytes; checking before allocating would avoid an
    // oversized untrusted allocation -- confirm whether that matters here.
    #[inline(always)]
    fn copy_buf(&mut self, bufs: &[&[u8]]) -> Result<usize> {
        let total_len = self.send_buf.len();
        if total_len > super::MAX_BUF_SIZE {
            return_errno!(EMSGSIZE, "the message is too large")
        }
        // Copy data from the bufs to the send buffer
        let mut total_copied = 0;
        for buf in bufs {
            self.send_buf[total_copied..(total_copied + buf.len())].copy_from_slice(buf);
            total_copied += buf.len();
        }
        Ok(total_copied)
    }

    /// Copy optional ancillary data into an untrusted control buffer.
    /// Errors with EINVAL above OPTMEM_MAX; returns Ok(0) when absent.
    #[inline(always)]
    fn copy_control(&mut self, control: Option<&[u8]>) -> Result<usize> {
        if let Some(msg_control) = control {
            let send_controllen = msg_control.len();
            if send_controllen > super::OPTMEM_MAX {
                return_errno!(EINVAL, "invalid msg control length");
            }
            let mut send_control_buf = UntrustedBox::new_uninit_slice(send_controllen);
            send_control_buf.copy_from_slice(&msg_control[..send_controllen]);
            self.control = Some(send_control_buf);
            return Ok(send_controllen);
        };
        Ok(0)
    }

    /// Payload length in bytes.
    #[inline(always)]
    fn len(&self) -> usize {
        self.send_buf.len()
    }
}
/// Lifecycle of the sender's write-side shutdown.
#[derive(Debug, PartialEq)]
enum ShutdownStatus {
    // not shutdown
    Running,
    // start the shutdown process, set by calling shutdown syscall
    PreShutdown,
    // shutdown process is done, set when the buffer is empty
    PostShutdown,
}

@ -0,0 +1,79 @@
use super::socket_file::SocketFile;
use crate::fs::{
AccessMode, FileDesc, HostFd, IoEvents, IoNotifier, IoctlCmd, IoctlRawCmd, StatusFlags,
};
use crate::prelude::*;
use std::{io::SeekFrom, os::unix::raw::off_t};
// File trait implementation that forwards to SocketFile's inherent methods.
impl File for SocketFile {
    fn read(&self, buf: &mut [u8]) -> Result<usize> {
        self.read(buf)
    }

    fn readv(&self, bufs: &mut [&mut [u8]]) -> Result<usize> {
        self.readv(bufs)
    }

    fn write(&self, buf: &[u8]) -> Result<usize> {
        self.write(buf)
    }

    fn writev(&self, bufs: &[&[u8]]) -> Result<usize> {
        self.writev(bufs)
    }

    // Sockets are not seekable; positioned reads are only valid at offset 0.
    fn read_at(&self, offset: usize, buf: &mut [u8]) -> Result<usize> {
        if offset != 0 {
            return_errno!(ESPIPE, "a nonzero position is not supported");
        }
        self.read(buf)
    }

    fn write_at(&self, offset: usize, buf: &[u8]) -> Result<usize> {
        if offset != 0 {
            return_errno!(ESPIPE, "a nonzero position is not supported");
        }
        self.write(buf)
    }

    fn seek(&self, pos: SeekFrom) -> Result<off_t> {
        return_errno!(ESPIPE, "Socket does not support seek")
    }

    fn ioctl(&self, cmd: &mut dyn IoctlCmd) -> Result<()> {
        self.ioctl(cmd)
    }

    fn notifier(&self) -> Option<&IoNotifier> {
        Some(self.notifier())
    }

    fn access_mode(&self) -> Result<AccessMode> {
        Ok(AccessMode::O_RDWR)
    }

    fn status_flags(&self) -> Result<StatusFlags> {
        Ok(self.status_flags())
    }

    fn set_status_flags(&self, new_status_flags: StatusFlags) -> Result<()> {
        self.set_status_flags(new_status_flags)
    }

    fn poll_new(&self) -> IoEvents {
        let mask = IoEvents::all();
        self.poll(mask, None)
    }

    // Uring sockets manage their host fd internally; none is exposed here.
    fn host_fd(&self) -> Option<&HostFd> {
        None
    }

    // Host events are driven by io_uring completions, never by this path.
    fn update_host_events(&self, ready: &IoEvents, mask: &IoEvents, trigger_notifier: bool) {
        unreachable!()
    }

    fn as_any(&self) -> &dyn core::any::Any {
        self
    }
}

@ -0,0 +1,12 @@
#![feature(stmt_expr_attributes)]
#![feature(new_uninit)]
#![feature(raw_ref_op)]
pub mod common;
pub mod datagram;
pub mod file_impl;
pub mod runtime;
pub mod socket_file;
pub mod stream;
pub use self::socket_file::UringSocketType;

@ -0,0 +1,12 @@
use alloc::sync::Arc;
use io_uring_callback::IoUring;
/// The runtime support for HostSocket.
///
/// This trait provides a common interface for user-implemented runtimes
/// that support HostSocket. Currently, the only dependency is a singleton
/// of IoUring instance.
pub trait Runtime: Send + Sync + 'static {
    /// The io_uring instance that new requests should be submitted to.
    fn io_uring() -> Arc<IoUring>;
    // NOTE(review): "disattach" is likely a typo for "detach"; renaming
    // would break every implementor, so the name is kept as-is.
    /// Release the association between `fd` and its io_uring instance.
    fn disattach_io_uring(fd: usize, uring: Arc<IoUring>);
}

@ -0,0 +1,446 @@
use self::impls::{Ipv4Datagram, Ipv6Datagram};
use crate::events::{Observer, Poller};
use crate::net::socket::{MsgFlags, SocketProtocol};
use self::impls::{Ipv4Stream, Ipv6Stream};
use crate::fs::{AccessMode, IoEvents, IoNotifier, IoctlCmd, StatusFlags};
use crate::net::socket::{AnyAddr, Ipv4SocketAddr, Ipv6SocketAddr};
use crate::prelude::*;
/// A file-like wrapper around any supported uring-backed socket
/// (IPv4/IPv6, stream or datagram).
#[derive(Debug)]
pub struct SocketFile {
    socket: AnySocket,
}
// Apply a function to all variants of AnySocket enum.
// `$socket` is bound to the concrete socket value inside `$fn_body`, so one
// body is expanded once per variant instead of being repeated by hand.
macro_rules! apply_fn_on_any_socket {
    ($any_socket:expr, |$socket:ident| { $($fn_body:tt)* }) => {{
        let any_socket: &AnySocket = $any_socket;
        match any_socket {
            AnySocket::Ipv4Stream($socket) => {
                $($fn_body)*
            }
            AnySocket::Ipv6Stream($socket) => {
                $($fn_body)*
            }
            AnySocket::Ipv4Datagram($socket) => {
                $($fn_body)*
            }
            AnySocket::Ipv6Datagram($socket) => {
                $($fn_body)*
            }
        }
    }}
}
/// Downcast helper: view a generic file handle as a uring-backed socket.
pub trait UringSocketType {
    /// Return the underlying `SocketFile`, or ENOTSOCK if this is not one.
    fn as_uring_socket(&self) -> Result<&SocketFile>;
}
impl UringSocketType for FileRef {
    fn as_uring_socket(&self) -> Result<&SocketFile> {
        // Any non-socket file (or a non-uring socket) fails the downcast.
        self.as_any()
            .downcast_ref::<SocketFile>()
            .ok_or_else(|| errno!(ENOTSOCK, "not a uring socket"))
    }
}
/// The concrete socket behind a `SocketFile`: one variant per supported
/// (domain, type) combination.
#[derive(Debug)]
enum AnySocket {
    Ipv4Stream(Ipv4Stream),
    Ipv6Stream(Ipv6Stream),
    Ipv4Datagram(Ipv4Datagram),
    Ipv6Datagram(Ipv6Datagram),
}
// Implement the common methods required by FileHandle
impl SocketFile {
    pub fn read(&self, buf: &mut [u8]) -> Result<usize> {
        apply_fn_on_any_socket!(&self.socket, |socket| { socket.read(buf) })
    }

    pub fn readv(&self, bufs: &mut [&mut [u8]]) -> Result<usize> {
        apply_fn_on_any_socket!(&self.socket, |socket| { socket.readv(bufs) })
    }

    pub fn write(&self, buf: &[u8]) -> Result<usize> {
        apply_fn_on_any_socket!(&self.socket, |socket| { socket.write(buf) })
    }

    pub fn writev(&self, bufs: &[&[u8]]) -> Result<usize> {
        apply_fn_on_any_socket!(&self.socket, |socket| { socket.writev(bufs) })
    }

    pub fn access_mode(&self) -> AccessMode {
        // We consider all sockets both readable and writable
        AccessMode::O_RDWR
    }

    pub fn status_flags(&self) -> StatusFlags {
        apply_fn_on_any_socket!(&self.socket, |socket| { socket.status_flags() })
    }

    /// The raw host fd backing the concrete socket.
    pub fn host_fd_inner(&self) -> FileDesc {
        apply_fn_on_any_socket!(&self.socket, |socket| { socket.host_fd() })
    }

    pub fn set_status_flags(&self, new_flags: StatusFlags) -> Result<()> {
        apply_fn_on_any_socket!(&self.socket, |socket| {
            socket.set_status_flags(new_flags)
        })
    }

    pub fn notifier(&self) -> &IoNotifier {
        apply_fn_on_any_socket!(&self.socket, |socket| { socket.notifier() })
    }

    pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
        apply_fn_on_any_socket!(&self.socket, |socket| { socket.poll(mask, poller) })
    }

    pub fn ioctl(&self, cmd: &mut dyn IoctlCmd) -> Result<()> {
        apply_fn_on_any_socket!(&self.socket, |socket| { socket.ioctl(cmd) })
    }

    /// The socket type (SOCK_STREAM or SOCK_DGRAM) of the wrapped socket.
    pub fn get_type(&self) -> Type {
        match self.socket {
            AnySocket::Ipv4Stream(_) | AnySocket::Ipv6Stream(_) => Type::STREAM,
            AnySocket::Ipv4Datagram(_) | AnySocket::Ipv6Datagram(_) => Type::DGRAM,
        }
    }
}
// Implement socket-specific methods
impl SocketFile {
/// Create a new uring-backed socket file.
///
/// Only INET/INET6 stream (TCP) and datagram (UDP) sockets are supported;
/// `protocol` must match the socket type or be IPPROTO_IP ("default").
/// Unsupported domains, protocols and types yield errors rather than panics.
pub fn new(
    domain: Domain,
    protocol: SocketProtocol,
    socket_type: Type,
    nonblocking: bool,
) -> Result<Self> {
    match socket_type {
        Type::STREAM => {
            if protocol != SocketProtocol::IPPROTO_IP && protocol != SocketProtocol::IPPROTO_TCP
            {
                return_errno!(EPROTONOSUPPORT, "Protocol not supported");
            }
            let any_socket = match domain {
                Domain::INET => {
                    let ipv4_stream = Ipv4Stream::new(nonblocking)?;
                    AnySocket::Ipv4Stream(ipv4_stream)
                }
                Domain::INET6 => {
                    let ipv6_stream = Ipv6Stream::new(nonblocking)?;
                    AnySocket::Ipv6Stream(ipv6_stream)
                }
                _ => {
                    // Fix: return an error instead of panicking. An
                    // unsupported domain is a user error, not a libos bug;
                    // this also matches the DGRAM branch below.
                    return_errno!(EINVAL, "not support yet");
                }
            };
            let new_self = Self { socket: any_socket };
            Ok(new_self)
        }
        Type::DGRAM => {
            if protocol != SocketProtocol::IPPROTO_IP && protocol != SocketProtocol::IPPROTO_UDP
            {
                return_errno!(EPROTONOSUPPORT, "Protocol not supported");
            }
            let any_socket = match domain {
                Domain::INET => {
                    let ipv4_datagram = Ipv4Datagram::new(nonblocking)?;
                    AnySocket::Ipv4Datagram(ipv4_datagram)
                }
                Domain::INET6 => {
                    let ipv6_datagram = Ipv6Datagram::new(nonblocking)?;
                    AnySocket::Ipv6Datagram(ipv6_datagram)
                }
                _ => {
                    return_errno!(EINVAL, "not support yet");
                }
            };
            let new_self = Self { socket: any_socket };
            Ok(new_self)
        }
        Type::RAW => {
            return_errno!(EINVAL, "RAW socket not supported");
        }
        _ => {
            return_errno!(ESOCKTNOSUPPORT, "socket type not supported");
        }
    }
}
pub fn domain(&self) -> Domain {
apply_fn_on_any_socket!(&self.socket, |socket| { socket.domain() })
}
pub fn is_stream(&self) -> bool {
matches!(&self.socket, AnySocket::Ipv4Stream(_))
}
pub fn connect(&self, addr: &AnyAddr) -> Result<()> {
match &self.socket {
AnySocket::Ipv4Stream(ipv4_stream) => {
let ip_addr = addr.to_ipv4()?;
ipv4_stream.connect(ip_addr)
}
AnySocket::Ipv6Stream(ipv6_stream) => {
let ip_addr = addr.to_ipv6()?;
ipv6_stream.connect(ip_addr)
}
AnySocket::Ipv4Datagram(ipv4_datagram) => {
let mut ip_addr = None;
if !addr.is_unspec() {
let ipv4_addr = addr.to_ipv4()?;
ip_addr = Some(ipv4_addr);
}
ipv4_datagram.connect(ip_addr)
}
AnySocket::Ipv6Datagram(ipv6_datagram) => {
let mut ip_addr = None;
if !addr.is_unspec() {
let ipv6_addr = addr.to_ipv6()?;
ip_addr = Some(ipv6_addr);
}
ipv6_datagram.connect(ip_addr)
}
_ => {
return_errno!(EINVAL, "connect is not supported");
}
}
}
pub fn bind(&self, addr: &mut AnyAddr) -> Result<()> {
match &self.socket {
AnySocket::Ipv4Stream(ipv4_stream) => {
let ip_addr = addr.to_ipv4()?;
ipv4_stream.bind(ip_addr)
}
AnySocket::Ipv6Stream(ipv6_stream) => {
let ip_addr = addr.to_ipv6()?;
ipv6_stream.bind(ip_addr)
}
AnySocket::Ipv4Datagram(ipv4_datagram) => {
let ip_addr = addr.to_ipv4()?;
ipv4_datagram.bind(ip_addr)
}
AnySocket::Ipv6Datagram(ipv6_datagram) => {
let ip_addr = addr.to_ipv6()?;
ipv6_datagram.bind(ip_addr)
}
_ => {
return_errno!(EINVAL, "bind is not supported");
}
}
}
pub fn listen(&self, backlog: u32) -> Result<()> {
match &self.socket {
AnySocket::Ipv4Stream(ip_stream) => ip_stream.listen(backlog),
AnySocket::Ipv6Stream(ip_stream) => ip_stream.listen(backlog),
_ => {
return_errno!(EOPNOTSUPP, "The socket is not of a listen supported type");
}
}
}
pub fn accept(&self, nonblocking: bool) -> Result<Self> {
let accepted_any_socket = match &self.socket {
AnySocket::Ipv4Stream(ipv4_stream) => {
let accepted_ipv4_stream = ipv4_stream.accept(nonblocking)?;
AnySocket::Ipv4Stream(accepted_ipv4_stream)
}
AnySocket::Ipv6Stream(ipv6_stream) => {
let accepted_ipv6_stream = ipv6_stream.accept(nonblocking)?;
AnySocket::Ipv6Stream(accepted_ipv6_stream)
}
_ => {
return_errno!(EOPNOTSUPP, "The socket is not of a accept supported type");
}
};
let accepted_socket_file = SocketFile {
socket: accepted_any_socket,
};
Ok(accepted_socket_file)
}
pub fn recvfrom(&self, buf: &mut [u8], flags: RecvFlags) -> Result<(usize, Option<AnyAddr>)> {
let (bytes_recv, addr_recv, _, _) = self.recvmsg(&mut [buf], flags, None)?;
Ok((bytes_recv, addr_recv))
}
pub fn recvmsg(
&self,
bufs: &mut [&mut [u8]],
flags: RecvFlags,
control: Option<&mut [u8]>,
) -> Result<(usize, Option<AnyAddr>, MsgFlags, usize)> {
// TODO: support msg_flags and msg_control
Ok(match &self.socket {
AnySocket::Ipv4Stream(ipv4_stream) => {
let (bytes_recv, addr_recv, msg_flags) = ipv4_stream.recvmsg(bufs, flags)?;
(
bytes_recv,
addr_recv.map(|addr| AnyAddr::Ipv4(addr)),
msg_flags,
0,
)
}
AnySocket::Ipv6Stream(ipv6_stream) => {
let (bytes_recv, addr_recv, msg_flags) = ipv6_stream.recvmsg(bufs, flags)?;
(
bytes_recv,
addr_recv.map(|addr| AnyAddr::Ipv6(addr)),
msg_flags,
0,
)
}
AnySocket::Ipv4Datagram(ipv4_datagram) => {
let (bytes_recv, addr_recv, msg_flags, msg_controllen) =
ipv4_datagram.recvmsg(bufs, flags, control)?;
(
bytes_recv,
addr_recv.map(|addr| AnyAddr::Ipv4(addr)),
msg_flags,
msg_controllen,
)
}
AnySocket::Ipv6Datagram(ipv6_datagram) => {
let (bytes_recv, addr_recv, msg_flags, msg_controllen) =
ipv6_datagram.recvmsg(bufs, flags, control)?;
(
bytes_recv,
addr_recv.map(|addr| AnyAddr::Ipv6(addr)),
msg_flags,
msg_controllen,
)
}
_ => {
return_errno!(EINVAL, "recvfrom is not supported");
}
})
}
pub fn sendto(&self, buf: &[u8], addr: Option<AnyAddr>, flags: SendFlags) -> Result<usize> {
self.sendmsg(&[buf], addr, flags, None)
}
pub fn sendmsg(
&self,
bufs: &[&[u8]],
addr: Option<AnyAddr>,
flags: SendFlags,
control: Option<&[u8]>,
) -> Result<usize> {
let res = match &self.socket {
AnySocket::Ipv4Stream(ipv4_stream) => ipv4_stream.sendmsg(bufs, flags),
AnySocket::Ipv6Stream(ipv6_stream) => ipv6_stream.sendmsg(bufs, flags),
AnySocket::Ipv4Datagram(ipv4_datagram) => {
let ip_addr = if let Some(addr) = addr.as_ref() {
Some(addr.to_ipv4()?)
} else {
None
};
ipv4_datagram.sendmsg(bufs, ip_addr, flags, control)
}
AnySocket::Ipv6Datagram(ipv6_datagram) => {
let ip_addr = if let Some(addr) = addr.as_ref() {
Some(addr.to_ipv6()?)
} else {
None
};
ipv6_datagram.sendmsg(bufs, ip_addr, flags, control)
}
_ => {
return_errno!(EINVAL, "sendmsg is not supported");
}
};
if res.has_errno(EPIPE) && !flags.contains(SendFlags::MSG_NOSIGNAL) {
crate::signal::do_tkill(current!().tid(), crate::signal::SIGPIPE.as_u8() as i32);
}
res
}
pub fn addr(&self) -> Result<AnyAddr> {
Ok(match &self.socket {
AnySocket::Ipv4Stream(ipv4_stream) => AnyAddr::Ipv4(ipv4_stream.addr()?),
AnySocket::Ipv6Stream(ipv6_stream) => AnyAddr::Ipv6(ipv6_stream.addr()?),
AnySocket::Ipv4Datagram(ipv4_datagram) => AnyAddr::Ipv4(ipv4_datagram.addr()?),
AnySocket::Ipv6Datagram(ipv6_datagram) => AnyAddr::Ipv6(ipv6_datagram.addr()?),
_ => {
return_errno!(EINVAL, "addr is not supported");
}
})
}
pub fn peer_addr(&self) -> Result<AnyAddr> {
Ok(match &self.socket {
AnySocket::Ipv4Stream(ipv4_stream) => AnyAddr::Ipv4(ipv4_stream.peer_addr()?),
AnySocket::Ipv6Stream(ipv6_stream) => AnyAddr::Ipv6(ipv6_stream.peer_addr()?),
AnySocket::Ipv4Datagram(ipv4_datagram) => AnyAddr::Ipv4(ipv4_datagram.peer_addr()?),
AnySocket::Ipv6Datagram(ipv6_datagram) => AnyAddr::Ipv6(ipv6_datagram.peer_addr()?),
_ => {
return_errno!(EINVAL, "peer_addr is not supported");
}
})
}
pub fn shutdown(&self, how: Shutdown) -> Result<()> {
match &self.socket {
AnySocket::Ipv4Stream(ipv4_stream) => ipv4_stream.shutdown(how),
AnySocket::Ipv6Stream(ipv6_stream) => ipv6_stream.shutdown(how),
AnySocket::Ipv4Datagram(ipv4_datagram) => ipv4_datagram.shutdown(how),
AnySocket::Ipv6Datagram(ipv6_datagram) => ipv6_datagram.shutdown(how),
_ => {
return_errno!(EINVAL, "shutdown is not supported");
}
}
}
pub fn close(&self) -> Result<()> {
match &self.socket {
AnySocket::Ipv4Stream(ipv4_stream) => ipv4_stream.close(),
AnySocket::Ipv6Stream(ipv6_stream) => ipv6_stream.close(),
AnySocket::Ipv4Datagram(ipv4_datagram) => ipv4_datagram.close(),
AnySocket::Ipv6Datagram(ipv6_datagram) => ipv6_datagram.close(),
_ => Ok(()),
}
}
}
impl Drop for SocketFile {
    fn drop(&mut self) {
        // `close()` returns a `Result`, but errors cannot be propagated out
        // of `drop`; discard it explicitly instead of silently ignoring a
        // must_use value.
        let _ = self.close();
    }
}
// Concrete instantiations of the generic uring socket types used by this
// module, plus the Runtime glue that hands each socket an io_uring instance.
mod impls {
    use super::*;
    use io_uring_callback::IoUring;
    // Type aliases binding the generic stream/datagram implementations to
    // IPv4/IPv6 addresses and this module's SocketRuntime.
    pub type Ipv4Stream =
        crate::net::socket::uring::stream::StreamSocket<Ipv4SocketAddr, SocketRuntime>;
    pub type Ipv6Stream =
        crate::net::socket::uring::stream::StreamSocket<Ipv6SocketAddr, SocketRuntime>;
    pub type Ipv4Datagram =
        crate::net::socket::uring::datagram::DatagramSocket<Ipv4SocketAddr, SocketRuntime>;
    pub type Ipv6Datagram =
        crate::net::socket::uring::datagram::DatagramSocket<Ipv6SocketAddr, SocketRuntime>;
    // Zero-sized marker type implementing the socket Runtime trait.
    pub struct SocketRuntime;
    impl crate::net::socket::uring::runtime::Runtime for SocketRuntime {
        // Assign an io_uring instance for a newly created socket, picked
        // from the global MULTITON pool.
        fn io_uring() -> Arc<IoUring> {
            crate::io_uring::MULTITON.get_uring()
        }
        // Detach ("disattach" in the trait's naming) the io_uring instance
        // from a closed socket so the pool can rebalance it.
        fn disattach_io_uring(fd: usize, uring: Arc<IoUring>) {
            crate::io_uring::MULTITON.disattach_uring(fd, uring);
        }
    }
}

@ -0,0 +1,610 @@
mod states;
use core::hint;
use core::sync::atomic::AtomicUsize;
use core::time::Duration;
use atomic::Ordering;
use self::states::{ConnectedStream, ConnectingStream, InitStream, ListenerStream};
use crate::events::Observer;
use crate::fs::{
GetIfConf, GetIfReqWithRawCmd, GetReadBufLen, IoEvents, IoNotifier, IoctlCmd, SetNonBlocking,
StatusFlags,
};
use crate::net::socket::uring::common::Common;
use crate::net::socket::uring::runtime::Runtime;
use crate::prelude::*;
use crate::events::Poller;
use crate::net::socket::{sockopt::*, MsgFlags};
// Default kernel buffer sizes used when creating connected streams. These
// are process-wide knobs: SO_SNDBUF/SO_RCVBUF ioctls update them for
// not-yet-connected sockets (see set_kernel_send_buf_size below).
// NOTE(review): SEND is 2565 KiB (~2.5 MB) while RECV is 256 KiB, and both
// carry a "+ 1" — confirm the asymmetry and the off-by-one are intentional.
lazy_static! {
    pub static ref SEND_BUF_SIZE: AtomicUsize = AtomicUsize::new(2565 * 1024 + 1); // Default Linux send buffer size is 2.5MB.
    pub static ref RECV_BUF_SIZE: AtomicUsize = AtomicUsize::new(256 * 1024 + 1);
}
/// A connection-oriented (stream) socket implemented over io_uring.
///
/// The socket is a state machine (`State`) guarded by a `RwLock`; `common`
/// is a direct handle to the state-independent shared data (host fd,
/// pollee, notifier, timeouts) so it can be reached without taking the
/// state lock.
pub struct StreamSocket<A: Addr + 'static, R: Runtime> {
    // Current position in the Init -> Connect -> Connected / Listen machine.
    state: RwLock<State<A, R>>,
    // Shared per-socket data; also reachable via state.common().
    common: Arc<Common<A, R>>,
}
/// The state machine of a stream socket.
///
/// Transitions: Init -> Connect -> Connected (via connect/accept), or
/// Init -> Listen (via listen). A shut-down listener can fall back to Init.
enum State<A: Addr + 'static, R: Runtime> {
    // Start state
    Init(Arc<InitStream<A, R>>),
    // Intermediate state
    Connect(Arc<ConnectingStream<A, R>>),
    // Final state 1
    Connected(Arc<ConnectedStream<A, R>>),
    // Final state 2
    Listen(Arc<ListenerStream<A, R>>),
}
impl<A: Addr, R: Runtime> StreamSocket<A, R> {
    /// Create a new socket in the `Init` state.
    pub fn new(nonblocking: bool) -> Result<Self> {
        let init_stream = InitStream::new(nonblocking)?;
        let common = init_stream.common().clone();
        let fd = common.host_fd();
        debug!("host fd: {}", fd);
        let init_state = State::Init(init_stream);
        Ok(Self {
            state: RwLock::new(init_state),
            common,
        })
    }

    /// Create a pair of already-connected sockets (socketpair(2)-style).
    pub fn new_pair(nonblocking: bool) -> Result<(Self, Self)> {
        let (common1, common2) = Common::new_pair(Type::STREAM, nonblocking)?;
        let connected1 = ConnectedStream::new(Arc::new(common1));
        let connected2 = ConnectedStream::new(Arc::new(common2));
        let socket1 = Self::new_connected(connected1);
        let socket2 = Self::new_connected(connected2);
        Ok((socket1, socket2))
    }

    /// Wrap an existing connected stream (used by accept and new_pair).
    fn new_connected(connected_stream: Arc<ConnectedStream<A, R>>) -> Self {
        let common = connected_stream.common().clone();
        let state = RwLock::new(State::Connected(connected_stream));
        Self { state, common }
    }

    /// If the async connect has completed, build the Connected-state stream;
    /// otherwise return None. The caller is responsible for storing the
    /// returned stream into `self.state`.
    fn try_switch_to_connected_state(
        connecting_stream: &Arc<ConnectingStream<A, R>>,
    ) -> Option<Arc<ConnectedStream<A, R>>> {
        // Previously, I thought connecting state only exists for non-blocking socket. However, some applications can set non-blocking for
        // connect syscall and after the connect returns, set the socket to blocking socket. Thus, this function shouldn't assert the connecting
        // stream is non-blocking socket.
        if connecting_stream.check_connection() {
            let common = connecting_stream.common().clone();
            common.set_peer_addr(connecting_stream.peer_addr());
            Some(ConnectedStream::new(common))
        } else {
            None
        }
    }

    /// The address domain, determined by the address type parameter.
    pub fn domain(&self) -> Domain {
        A::domain()
    }

    /// The pending socket error, if any (consumed by SO_ERROR).
    pub fn errno(&self) -> Option<Errno> {
        self.common.errno()
    }

    /// The underlying host file descriptor.
    pub fn host_fd(&self) -> FileDesc {
        let state = self.state.read().unwrap();
        state.common().host_fd()
    }

    pub fn status_flags(&self) -> StatusFlags {
        // Only support O_NONBLOCK
        let state = self.state.read().unwrap();
        if state.common().nonblocking() {
            StatusFlags::O_NONBLOCK
        } else {
            StatusFlags::empty()
        }
    }

    pub fn set_status_flags(&self, new_flags: StatusFlags) -> Result<()> {
        // Only support O_NONBLOCK
        let state = self.state.read().unwrap();
        let nonblocking = new_flags.is_nonblocking();
        state.common().set_nonblocking(nonblocking);
        Ok(())
    }

    /// Bind to a local address; only legal in the `Init` state.
    pub fn bind(&self, addr: &A) -> Result<()> {
        let state = self.state.read().unwrap();
        match &*state {
            State::Init(init_stream) => init_stream.bind(addr),
            _ => {
                return_errno!(EINVAL, "cannot bind");
            }
        }
    }

    /// Transition Init -> Listen, creating the backlog of accept requests.
    pub fn listen(&self, backlog: u32) -> Result<()> {
        let mut state = self.state.write().unwrap();
        match &*state {
            State::Init(init_stream) => {
                let common = init_stream.common().clone();
                let listener = ListenerStream::new(backlog, common)?;
                *state = State::Listen(listener);
                Ok(())
            }
            _ => {
                return_errno!(EINVAL, "cannot listen");
            }
        }
    }

    /// Connect to `peer_addr`.
    ///
    /// Blocking sockets wait for completion; non-blocking sockets return
    /// EINPROGRESS and complete asynchronously (later calls observe the
    /// result via `try_switch_to_connected_state`).
    pub fn connect(&self, peer_addr: &A) -> Result<()> {
        // Create the new intermediate state of connecting and save the
        // old state of init in case of failure to connect.
        let (init_stream, connecting_stream) = {
            let mut state = self.state.write().unwrap();
            match &*state {
                State::Init(init_stream) => {
                    let connecting_stream = {
                        let common = init_stream.common().clone();
                        ConnectingStream::new(peer_addr, common)?
                    };
                    let init_stream = init_stream.clone();
                    *state = State::Connect(connecting_stream.clone());
                    (init_stream, connecting_stream)
                }
                State::Connect(connecting_stream) => {
                    if let Some(connected_stream) =
                        Self::try_switch_to_connected_state(connecting_stream)
                    {
                        *state = State::Connected(connected_stream);
                        return_errno!(EISCONN, "the socket is already connected");
                    } else {
                        // Not connected, keep the connecting state and try connect
                        let init_stream =
                            InitStream::new_with_common(connecting_stream.common().clone())?;
                        (init_stream, connecting_stream.clone())
                    }
                }
                State::Connected(_) => {
                    return_errno!(EISCONN, "the socket is already connected");
                }
                State::Listen(_) => {
                    return_errno!(EINVAL, "the socket is listening");
                }
            }
        };

        let res = connecting_stream.connect();

        // If success, then the state is switched to connected; otherwise, for blocking socket
        // the state is restored to the init state, and for non-blocking socket, the state
        // keeps in connecting state.
        match &res {
            Ok(()) => {
                let connected_stream = {
                    let common = init_stream.common().clone();
                    common.set_peer_addr(peer_addr);
                    ConnectedStream::new(common)
                };
                let mut state = self.state.write().unwrap();
                *state = State::Connected(connected_stream);
            }
            Err(_) => {
                if !connecting_stream.common().nonblocking() {
                    let mut state = self.state.write().unwrap();
                    *state = State::Init(init_stream);
                }
            }
        }
        res
    }

    /// Accept one pending connection; only legal in the `Listen` state.
    pub fn accept(&self, nonblocking: bool) -> Result<Self> {
        let listener_stream = {
            let state = self.state.read().unwrap();
            match &*state {
                State::Listen(listener_stream) => listener_stream.clone(),
                _ => {
                    return_errno!(EINVAL, "the socket is not listening");
                }
            }
        };

        let connected_stream = listener_stream.accept(nonblocking)?;
        let new_self = Self::new_connected(connected_stream);
        Ok(new_self)
    }

    /// read(2): single-buffer wrapper over readv.
    pub fn read(&self, buf: &mut [u8]) -> Result<usize> {
        self.readv(&mut [buf])
    }

    /// readv(2): wrapper over recvmsg with empty flags.
    pub fn readv(&self, bufs: &mut [&mut [u8]]) -> Result<usize> {
        let ret = self.recvmsg(bufs, RecvFlags::empty())?;
        Ok(ret.0)
    }

    /// Receive messages from connected socket
    ///
    /// Linux behavior:
    /// Unlike datagram socket, `recvfrom` / `recvmsg` of stream socket will
    /// ignore the address even if user specified it.
    pub fn recvmsg(
        &self,
        buf: &mut [&mut [u8]],
        flags: RecvFlags,
    ) -> Result<(usize, Option<A>, MsgFlags)> {
        let connected_stream = {
            let mut state = self.state.write().unwrap();
            match &*state {
                State::Connected(connected_stream) => connected_stream.clone(),
                State::Connect(connecting_stream) => {
                    // A non-blocking connect may have completed in the
                    // background; promote the state before failing.
                    if let Some(connected_stream) =
                        Self::try_switch_to_connected_state(connecting_stream)
                    {
                        *state = State::Connected(connected_stream.clone());
                        connected_stream
                    } else {
                        return_errno!(ENOTCONN, "the socket is not connected");
                    }
                }
                _ => {
                    return_errno!(ENOTCONN, "the socket is not connected");
                }
            }
        };

        let recv_len = connected_stream.recvmsg(buf, flags)?;
        Ok((recv_len, None, MsgFlags::empty()))
    }

    /// write(2): single-buffer wrapper over writev.
    pub fn write(&self, buf: &[u8]) -> Result<usize> {
        self.writev(&[buf])
    }

    /// writev(2): wrapper over sendmsg with empty flags.
    pub fn writev(&self, bufs: &[&[u8]]) -> Result<usize> {
        self.sendmsg(bufs, SendFlags::empty())
    }

    /// Send data on a connected socket; promotes a completed background
    /// connect just like `recvmsg` does.
    pub fn sendmsg(&self, bufs: &[&[u8]], flags: SendFlags) -> Result<usize> {
        let connected_stream = {
            let mut state = self.state.write().unwrap();
            match &*state {
                State::Connected(connected_stream) => connected_stream.clone(),
                State::Connect(connecting_stream) => {
                    if let Some(connected_stream) =
                        Self::try_switch_to_connected_state(connecting_stream)
                    {
                        *state = State::Connected(connected_stream.clone());
                        connected_stream
                    } else {
                        return_errno!(ENOTCONN, "the socket is not connected");
                    }
                }
                _ => {
                    // EPIPE (not ENOTCONN) for sends on a never-connected or
                    // listening socket, matching Linux write semantics.
                    return_errno!(EPIPE, "the socket is not connected");
                }
            }
        };

        connected_stream.sendmsg(bufs, flags)
    }

    /// Poll the pollee for events in `mask`, optionally registering `poller`.
    pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
        let state = self.state.read().unwrap();
        let pollee = state.common().pollee();
        pollee.poll(mask, poller)
    }

    /// getsockname(2): the local address, refreshed from the host.
    pub fn addr(&self) -> Result<A> {
        let state = self.state.read().unwrap();
        let common = state.common();

        // Always get addr from host.
        // Because for IP socket, users can specify "0" as port and the kernel should select a usable port for him.
        // Thus, when calling getsockname, this should be updated.
        let addr = common.get_addr_from_host()?;
        common.set_addr(&addr);
        Ok(addr)
    }

    /// The IO notifier; opportunistically promotes a completed background
    /// connect so observers see endpoint status updates.
    pub fn notifier(&self) -> &IoNotifier {
        {
            let mut state = self.state.write().unwrap();
            // Try switch to connected state to receive endpoint status
            if let State::Connect(connecting_stream) = &*state {
                if let Some(connected_stream) =
                    Self::try_switch_to_connected_state(connecting_stream)
                {
                    *state = State::Connected(connected_stream.clone());
                }
            }
            // `state` goes out of scope here and the lock is implicitly released.
        }
        self.common.notifier()
    }

    /// getpeername(2): the peer address; only meaningful once connected.
    pub fn peer_addr(&self) -> Result<A> {
        let mut state = self.state.write().unwrap();
        match &*state {
            State::Connected(connected_stream) => {
                Ok(connected_stream.common().peer_addr().unwrap())
            }
            State::Connect(connecting_stream) => {
                if let Some(connected_stream) =
                    Self::try_switch_to_connected_state(connecting_stream)
                {
                    *state = State::Connected(connected_stream.clone());
                    Ok(connected_stream.common().peer_addr().unwrap())
                } else {
                    return_errno!(ENOTCONN, "the socket is not connected");
                }
            }
            _ => return_errno!(ENOTCONN, "the socket is not connected"),
        }
    }

    /// Dispatch socket ioctls/sockopts. Promotes a completed background
    /// connect first so commands like SO_ERROR/GetPeerName see fresh state.
    pub fn ioctl(&self, cmd: &mut dyn IoctlCmd) -> Result<()> {
        let mut state = self.state.write().unwrap();
        match &*state {
            State::Connect(connecting_stream) => {
                if let Some(connected_stream) =
                    Self::try_switch_to_connected_state(connecting_stream)
                {
                    *state = State::Connected(connected_stream.clone());
                }
            }
            _ => {}
        }
        // Release the write lock: several command handlers below re-acquire
        // the state lock themselves.
        drop(state);

        crate::match_ioctl_cmd_mut!(&mut *cmd, {
            cmd: GetSockOptRawCmd => {
                cmd.execute(self.host_fd())?;
            },
            cmd: SetSockOptRawCmd => {
                cmd.execute(self.host_fd())?;
            },
            cmd: SetRecvTimeoutCmd => {
                self.set_recv_timeout(*cmd.timeout());
            },
            cmd: SetSendTimeoutCmd => {
                self.set_send_timeout(*cmd.timeout());
            },
            cmd: GetRecvTimeoutCmd => {
                let timeval = timeout_to_timeval(self.recv_timeout());
                cmd.set_output(timeval);
            },
            cmd: GetSendTimeoutCmd => {
                let timeval = timeout_to_timeval(self.send_timeout());
                cmd.set_output(timeval);
            },
            cmd: SetSndBufSizeCmd => {
                cmd.update_host(self.host_fd())?;
                let buf_size = cmd.buf_size();
                self.set_kernel_send_buf_size(buf_size);
            },
            cmd: SetRcvBufSizeCmd => {
                cmd.update_host(self.host_fd())?;
                let buf_size = cmd.buf_size();
                self.set_kernel_recv_buf_size(buf_size);
            },
            cmd: GetSndBufSizeCmd => {
                let buf_size = SEND_BUF_SIZE.load(Ordering::Relaxed);
                cmd.set_output(buf_size);
            },
            cmd: GetRcvBufSizeCmd => {
                let buf_size = RECV_BUF_SIZE.load(Ordering::Relaxed);
                cmd.set_output(buf_size);
            },
            cmd: GetAcceptConnCmd => {
                // SO_ACCEPTCONN: whether the socket is in the listening state.
                let mut is_listen = false;
                let state = self.state.read().unwrap();
                if let State::Listen(_listener_stream) = &*state {
                    is_listen = true;
                }
                cmd.set_output(is_listen as _);
            },
            cmd: GetDomainCmd => {
                cmd.set_output(self.domain() as _);
            },
            cmd: GetPeerNameCmd => {
                let peer = self.peer_addr()?;
                cmd.set_output(AddrStorage(peer.to_c_storage()));
            },
            cmd: GetErrorCmd => {
                // SO_ERROR: 0 when there is no pending error.
                let error: i32 = self.errno().map(|err| err as i32).unwrap_or(0);
                cmd.set_output(error);
            },
            cmd: GetTypeCmd => {
                let state = self.state.read().unwrap();
                cmd.set_output(state.common().type_() as _);
            },
            cmd: SetNonBlocking => {
                let state = self.state.read().unwrap();
                state.common().set_nonblocking(*cmd.input() != 0);
            },
            cmd: GetReadBufLen => {
                // FIONREAD: bytes available for reading without blocking.
                let state = self.state.read().unwrap();
                if let State::Connected(connected_stream) = &*state {
                    let read_buf_len = connected_stream.bytes_to_consume();
                    cmd.set_output(read_buf_len as _);
                } else {
                    return_errno!(ENOTCONN, "unconnected socket");
                }
            },
            cmd: GetIfReqWithRawCmd => {
                cmd.execute(self.host_fd())?;
            },
            cmd: GetIfConf => {
                cmd.execute(self.host_fd())?;
            },
            _ => {
                return_errno!(EINVAL, "Not supported yet");
            }
        });
        Ok(())
    }

    /// Apply SO_SNDBUF: before connection, only the global default changes;
    /// once connected, resize the live send buffer too.
    fn set_kernel_send_buf_size(&self, buf_size: usize) {
        let state = self.state.read().unwrap();
        match &*state {
            State::Init(_) | State::Listen(_) | State::Connect(_) => {
                // The kernel buffer is only created when the socket is connected. Just update the static variable.
                SEND_BUF_SIZE.store(buf_size, Ordering::Relaxed);
            }
            State::Connected(connected_stream) => {
                connected_stream.try_update_send_buf_size(buf_size);
            }
        }
    }

    /// Apply SO_RCVBUF; mirror of `set_kernel_send_buf_size`.
    fn set_kernel_recv_buf_size(&self, buf_size: usize) {
        let state = self.state.read().unwrap();
        match &*state {
            State::Init(_) | State::Listen(_) | State::Connect(_) => {
                // The kernel buffer is only created when the socket is connected. Just update the static variable.
                RECV_BUF_SIZE.store(buf_size, Ordering::Relaxed);
            }
            State::Connected(connected_stream) => {
                connected_stream.try_update_recv_buf_size(buf_size);
            }
        }
    }

    /// shutdown(2) on the current state.
    ///
    /// A listener shut for reading drops back to the `Init` state so it can
    /// be reused by calling `listen` again.
    pub fn shutdown(&self, shutdown: Shutdown) -> Result<()> {
        let mut state = self.state.write().unwrap();
        match &*state {
            State::Listen(listener_stream) => {
                // listening socket can be shutdown and then re-use by calling listen again.
                listener_stream.shutdown(shutdown)?;
                if shutdown.should_shut_read() {
                    // Cancel pending accept requests. This is necessary because the socket is reusable.
                    listener_stream.cancel_accept_requests();
                    // Set init state
                    let init_stream =
                        InitStream::new_with_common(listener_stream.common().clone())?;
                    let init_state = State::Init(init_stream);
                    *state = init_state;
                    Ok(())
                } else {
                    // Shutting down only the write half of a listener is
                    // expected to have no effect.
                    Ok(())
                }
            }
            State::Connected(connected_stream) => connected_stream.shutdown(shutdown),
            State::Connect(connecting_stream) => {
                if let Some(connected_stream) =
                    Self::try_switch_to_connected_state(connecting_stream)
                {
                    connected_stream.shutdown(shutdown)?;
                    *state = State::Connected(connected_stream);
                    Ok(())
                } else {
                    return_errno!(ENOTCONN, "the socket is not connected");
                }
            }
            _ => {
                return_errno!(ENOTCONN, "the socket is not connected");
            }
        }
    }

    /// Close the socket: mark it closed and cancel in-flight io_uring
    /// requests appropriate to the current state.
    pub fn close(&self) -> Result<()> {
        let state = self.state.read().unwrap();
        match &*state {
            State::Init(_) => {}
            State::Listen(listener_stream) => {
                listener_stream.common().set_closed();
                listener_stream.cancel_accept_requests();
            }
            State::Connect(connecting_stream) => {
                connecting_stream.common().set_closed();
                let need_wait = true;
                connecting_stream.cancel_connect_request(need_wait);
            }
            State::Connected(connected_stream) => {
                connected_stream.set_closed();
                connected_stream.cancel_recv_requests();
                // Flush any buffered outgoing data before the fd goes away.
                connected_stream.try_empty_send_buf_when_close();
            }
        }
        Ok(())
    }

    // SO_SNDTIMEO / SO_RCVTIMEO accessors; stored on the shared Common.
    fn send_timeout(&self) -> Option<Duration> {
        let state = self.state.read().unwrap();
        state.common().send_timeout()
    }

    fn recv_timeout(&self) -> Option<Duration> {
        let state = self.state.read().unwrap();
        state.common().recv_timeout()
    }

    fn set_send_timeout(&self, timeout: Duration) {
        let state = self.state.read().unwrap();
        state.common().set_send_timeout(timeout);
    }

    fn set_recv_timeout(&self, timeout: Duration) {
        let state = self.state.read().unwrap();
        state.common().set_recv_timeout(timeout);
    }

    // NOTE(review): dead draft below — it is commented out and contains
    // syntax errors (`=` instead of `=>`); kept only for reference.
    /*
        pub fn poll_by(&self, mask: Events, mut poller: Option<&mut Poller>) -> Events {
            let state = self.state.read();
            match *state {
                Init(init_stream) => init_stream.poll_by(mask, poller),
                Connect(connect_stream) => connect_stream.poll_by(mask, poller),
                Connected(connected_stream) = connected_stream.poll_by(mask, poller),
                Listen(listener_stream) = listener_stream.poll_by(mask, poller),
            }
        }
    */
}
impl<A: Addr + 'static, R: Runtime> Drop for StreamSocket<A, R> {
    fn drop(&mut self) {
        // Mark the shared state closed when the socket is destroyed; the
        // read guard is a temporary and is released at end of statement.
        self.state.read().unwrap().common().set_closed();
    }
}
impl<A: Addr + 'static, R: Runtime> std::fmt::Debug for State<A, R> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Forward to whichever variant is active, erased to `dyn Debug`.
        let active: &dyn std::fmt::Debug = match self {
            Self::Init(stream) => stream as _,
            Self::Connect(stream) => stream as _,
            Self::Connected(stream) => stream as _,
            Self::Listen(stream) => stream as _,
        };
        f.debug_tuple("State").field(active).finish()
    }
}
impl<A: Addr + 'static, R: Runtime> std::fmt::Debug for StreamSocket<A, R> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Deliberately opaque: printing the inner state would require
        // taking locks inside a formatter.
        write!(f, "StreamSocket")
    }
}
impl<A: Addr + 'static, R: Runtime> State<A, R> {
fn common(&self) -> &Common<A, R> {
match self {
Self::Init(stream) => stream.common(),
Self::Connect(stream) => stream.common(),
Self::Connected(stream) => stream.common(),
Self::Listen(stream) => stream.common(),
}
}
}

@ -0,0 +1,227 @@
use core::time::Duration;
use std::marker::PhantomData;
use std::sync::atomic::{AtomicBool, Ordering};
use io_uring_callback::{Fd, IoHandle};
use sgx_untrusted_alloc::UntrustedBox;
use crate::events::Poller;
use crate::fs::IoEvents;
use crate::net::socket::uring::common::Common;
use crate::net::socket::uring::runtime::Runtime;
use crate::prelude::*;
/// A stream socket that is in its connecting state.
///
/// Holds the in-flight io_uring connect request (`req`) and a flag that
/// the completion callback flips once the host reports success.
pub struct ConnectingStream<A: Addr + 'static, R: Runtime> {
    // Shared per-socket data (host fd, pollee, timeouts, ...).
    common: Arc<Common<A, R>>,
    // The address we are connecting to; recorded on Common once connected.
    peer_addr: A,
    // The single outstanding async connect request, if submitted.
    req: Mutex<ConnectReq<A>>,
    connected: AtomicBool, // Mainly use for nonblocking socket to update status asynchronously
}
/// Bookkeeping for one async connect submission.
struct ConnectReq<A: Addr> {
    // Handle of the in-flight io_uring request; None when idle/completed.
    io_handle: Option<IoHandle>,
    // The peer address marshalled into untrusted memory, as required for
    // the host kernel to read it during the connect syscall.
    c_addr: UntrustedBox<libc::sockaddr_storage>,
    c_addr_len: usize,
    // Errno reported by the completion callback on failure.
    errno: Option<Errno>,
    phantom_data: PhantomData<A>,
}
impl<A: Addr + 'static, R: Runtime> ConnectingStream<A, R> {
    /// Create a connecting stream targeting `peer_addr`, sharing `common`
    /// with the originating Init-state socket.
    pub fn new(peer_addr: &A, common: Arc<Common<A, R>>) -> Result<Arc<Self>> {
        let req = Mutex::new(ConnectReq::new(peer_addr));
        let new_self = Self {
            common,
            peer_addr: peer_addr.clone(),
            req,
            connected: AtomicBool::new(false),
        };
        Ok(Arc::new(new_self))
    }

    /// Connect to the peer address.
    ///
    /// Submits the async connect; non-blocking sockets return EINPROGRESS
    /// immediately, blocking sockets wait (bounded by SO_SNDTIMEO) for the
    /// completion callback to raise OUT/ERR events on the pollee.
    pub fn connect(self: &Arc<Self>) -> Result<()> {
        let pollee = self.common.pollee();
        pollee.reset_events();

        self.initiate_async_connect();
        if self.common.nonblocking() {
            return_errno!(EINPROGRESS, "non-blocking connect request in progress");
        }

        // Wait for the async connect to complete
        let mask = IoEvents::OUT;
        let poller = Poller::new();
        pollee.connect_poller(mask, &poller);
        let mut timeout = self.common.send_timeout();
        loop {
            let events = pollee.poll(mask, None);
            if !events.is_empty() {
                break;
            }
            let ret = poller.wait_timeout(timeout.as_mut());
            if let Err(e) = ret {
                let errno = e.errno();
                warn!("connect wait errno = {:?}", errno);
                match errno {
                    ETIMEDOUT => {
                        // Cancel connect request if timeout. No need to wait for cancel to complete.
                        self.cancel_connect_request(false);
                        // This error code is same as the connect timeout error code on Linux
                        return_errno!(EINPROGRESS, "timeout reached")
                    }
                    _ => {
                        return_errno!(e.errno(), "wait error")
                    }
                }
            }
        }

        // Finish the async connect: the completion callback stored any
        // failure errno in the request.
        let req = self.req.lock();
        if let Some(e) = req.errno {
            return_errno!(e, "connect failed");
        }
        Ok(())
    }

    /// Submit the connect SQE to io_uring, unless one is already in flight.
    /// The completion callback updates `connected`, the pending errno, and
    /// the pollee events.
    fn initiate_async_connect(self: &Arc<Self>) {
        let io_uring = self.common.io_uring();
        let mut req = self.req.lock();
        // Skip if there is pending request
        if req.io_handle.is_some() {
            return;
        }

        let arc_self = self.clone();
        let callback = move |retval: i32| {
            // Guard against Iago attack: the untrusted host must not report
            // a positive return value for connect.
            assert!(retval <= 0);
            debug!("connect request complete with retval: {}", retval);
            let mut req = arc_self.req.lock();

            // Release the handle to the async connect
            req.io_handle.take();

            if retval == 0 {
                arc_self.connected.store(true, Ordering::Relaxed);
                arc_self.common.pollee().add_events(IoEvents::OUT);
            } else {
                // Store the errno
                let errno = Errno::from(-retval as u32);
                req.errno = Some(errno);
                drop(req);
                arc_self.common.set_errno(errno);
                arc_self.connected.store(false, Ordering::Relaxed);
                // Connection-level failures also look readable/hung-up to
                // pollers, mirroring Linux poll semantics for failed connect.
                let events = if errno == ENOTCONN || errno == ECONNRESET || errno == ECONNREFUSED {
                    IoEvents::HUP | IoEvents::IN | IoEvents::ERR
                } else {
                    IoEvents::ERR
                };
                arc_self.common.pollee().add_events(events);
            }
        };

        let host_fd = self.common.host_fd() as _;
        let c_addr_ptr = req.c_addr.as_ptr();
        let c_addr_len = req.c_addr_len;
        let io_handle = unsafe {
            io_uring.connect(
                Fd(host_fd),
                c_addr_ptr as *const libc::sockaddr,
                c_addr_len as u32,
                callback,
            )
        };
        req.io_handle = Some(io_handle);
    }

    /// Cancel the in-flight connect request, if any. When `need_wait` is
    /// true, block (with a 10s retry timeout) until the completion callback
    /// has released the io_handle, so the request no longer references us.
    pub fn cancel_connect_request(&self, need_wait: bool) {
        {
            let io_uring = self.common.io_uring();
            let req = self.req.lock();
            if let Some(io_handle) = &req.io_handle {
                unsafe { io_uring.cancel(io_handle) };
            } else {
                return;
            }
        }

        // Wait for the cancel to complete if needed
        if !need_wait {
            return;
        }

        let poller = Poller::new();
        let mask = IoEvents::ERR | IoEvents::IN;
        self.common.pollee().connect_poller(mask, &poller);

        loop {
            let pending_request_exist = {
                let req = self.req.lock();
                req.io_handle.is_some()
            };

            if pending_request_exist {
                let mut timeout = Some(Duration::from_secs(10));
                let ret = poller.wait_timeout(timeout.as_mut());
                if let Err(e) = ret {
                    warn!("wait cancel connect request error = {:?}", e.errno());
                    continue;
                }
            } else {
                break;
            }
        }
    }

    #[allow(dead_code)]
    pub fn peer_addr(&self) -> &A {
        &self.peer_addr
    }

    pub fn common(&self) -> &Arc<Common<A, R>> {
        &self.common
    }

    // This can be used in connecting state to check non-blocking connect status.
    pub fn check_connection(&self) -> bool {
        // It is fine whether the load happens before or after the store operation
        self.connected.load(Ordering::Relaxed)
    }
}
impl<A: Addr> ConnectReq<A> {
pub fn new(peer_addr: &A) -> Self {
let (c_addr_storage, c_addr_len) = peer_addr.to_c_storage();
Self {
io_handle: None,
c_addr: UntrustedBox::new(c_addr_storage),
c_addr_len,
errno: None,
phantom_data: PhantomData,
}
}
}
impl<A: Addr, R: Runtime> std::fmt::Debug for ConnectingStream<A, R> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Note: takes the req lock to render the pending request.
        let mut d = f.debug_struct("ConnectingStream");
        d.field("common", &self.common);
        d.field("peer_addr", &self.peer_addr);
        d.field("req", &*self.req.lock());
        d.field("connected", &self.connected);
        d.finish()
    }
}
impl<A: Addr> std::fmt::Debug for ConnectReq<A> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The raw address bytes in untrusted memory are intentionally omitted.
        let mut d = f.debug_struct("ConnectReq");
        d.field("io_handle", &self.io_handle);
        d.field("errno", &self.errno);
        d.finish()
    }
}

@ -0,0 +1,114 @@
use atomic::Ordering;
use self::recv::Receiver;
use self::send::Sender;
use crate::fs::IoEvents as Events;
use crate::net::socket::sockopt::SockOptName;
use crate::net::socket::uring::common::Common;
use crate::net::socket::uring::runtime::Runtime;
use crate::prelude::*;
mod recv;
mod send;
/// A fully-connected stream socket: a sender and a receiver over the shared
/// per-socket data. Send/recv logic lives in the `send`/`recv` submodules.
pub struct ConnectedStream<A: Addr + 'static, R: Runtime> {
    common: Arc<Common<A, R>>,
    // Outgoing byte stream and its in-flight io_uring send requests.
    sender: Sender,
    // Incoming byte stream and its in-flight io_uring recv requests.
    receiver: Receiver,
}
impl<A: Addr + 'static, R: Runtime> ConnectedStream<A, R> {
    /// Wrap `common` as a connected stream: the pollee starts writable
    /// (OUT), and async recv requests are kicked off immediately.
    pub fn new(common: Arc<Common<A, R>>) -> Arc<Self> {
        common.pollee().reset_events();
        common.pollee().add_events(Events::OUT);
        let fd = common.host_fd();
        let sender = Sender::new();
        let receiver = Receiver::new();

        let new_self = Arc::new(Self {
            common,
            sender,
            receiver,
        });

        // Start async recv requests right as early as possible to support poll and
        // improve performance. If we don't start recv requests early, the poll()
        // might block forever when user just invokes poll(Event::In) without read().
        // Once we have recv requests completed, we can have Event::In in the events.
        new_self.initiate_async_recv();

        new_self
    }

    pub fn common(&self) -> &Arc<Common<A, R>> {
        &self.common
    }

    /// shutdown(2) for a connected stream.
    ///
    /// Write-shutdown is deferred while the send buffer still holds data,
    /// so that buffered bytes are flushed before FIN is sent.
    pub fn shutdown(&self, how: Shutdown) -> Result<()> {
        // Do host shutdown
        // For shutdown write, don't call host_shutdown until the content in the pending buffer is sent.
        // For shutdown read, ignore the pending buffer.
        let (shut_write, send_buf_is_empty, shut_read) = (
            how.should_shut_write(),
            self.sender.is_empty(),
            how.should_shut_read(),
        );
        match (shut_write, send_buf_is_empty, shut_read) {
            // As long as send buf is empty, just shutdown.
            (_, true, _) => self.common.host_shutdown(how)?,
            // If not shutdown write, just shutdown.
            (false, _, _) => self.common.host_shutdown(how)?,
            // If shutdown both but the send buf is not empty, only shutdown read.
            (true, false, true) => self.common.host_shutdown(Shutdown::Read)?,
            // If shutdown write but the send buf is not empty, don't do shutdown.
            (true, false, false) => {}
        }

        // Set internal state and trigger events.
        if shut_read {
            self.receiver.shutdown();
            self.common.pollee().add_events(Events::IN);
        }
        if shut_write {
            self.sender.shutdown();
            self.common.pollee().add_events(Events::OUT);
        }
        if shut_read && shut_write {
            self.common.pollee().add_events(Events::HUP);
        }

        Ok(())
    }

    /// Mark the stream closed so no new io_uring requests are submitted.
    pub fn set_closed(&self) {
        // Mark the sender and receiver to shutdown to prevent submitting new requests.
        self.receiver.shutdown();
        self.sender.shutdown();
        self.common.set_closed();
    }

    // Other methods are implemented in the send and receive modules
}
impl<A: Addr + 'static, R: Runtime> std::fmt::Debug for ConnectedStream<A, R> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Build the debug representation field by field.
        let mut dbg = f.debug_struct("ConnectedStream");
        dbg.field("common", &self.common);
        dbg.field("sender", &self.sender);
        dbg.field("receiver", &self.receiver);
        dbg.finish()
    }
}
/// Builds a `libc::msghdr` that points at the given iovec array; every other
/// field is left zeroed.
fn new_msghdr(iovecs_ptr: *mut libc::iovec, iovecs_len: usize) -> libc::msghdr {
    // Safety. An all-zeros bit pattern is a valid state for msghdr.
    let mut msghdr: libc::msghdr = unsafe { std::mem::zeroed() };
    msghdr.msg_iov = iovecs_ptr;
    msghdr.msg_iovlen = iovecs_len as _;
    // We do want to leave all other fields as zeros.
    msghdr
}

@ -0,0 +1,458 @@
use core::hint;
use core::sync::atomic::AtomicBool;
use core::time::Duration;
use std::mem::MaybeUninit;
use std::ptr::{self};
use atomic::Ordering;
use io_uring_callback::{Fd, IoHandle};
use sgx_untrusted_alloc::{MaybeUntrusted, UntrustedBox};
use super::ConnectedStream;
use crate::net::socket::uring::runtime::Runtime;
use crate::net::socket::uring::stream::RECV_BUF_SIZE;
use crate::prelude::*;
use crate::untrusted::UntrustedCircularBuf;
use crate::util::sync::{Mutex, MutexGuard};
use crate::events::Poller;
use crate::fs::IoEvents as Events;
impl<A: Addr + 'static, R: Runtime> ConnectedStream<A, R> {
    /// Receives data into `bufs`, returning the number of bytes received.
    ///
    /// For blocking sockets this waits for incoming data (honoring the recv
    /// timeout). With `MSG_WAITALL` it keeps looping until all of `bufs` is
    /// filled, a timeout/interrupt occurs, or an error is reported.
    pub fn recvmsg(self: &Arc<Self>, bufs: &mut [&mut [u8]], flags: RecvFlags) -> Result<usize> {
        let total_len: usize = bufs.iter().map(|buf| buf.len()).sum();
        if total_len == 0 {
            return Ok(0);
        }

        let mut total_received = 0;
        // Position inside `bufs` where the next try_recvmsg round resumes
        // (only advanced when MSG_WAITALL makes us loop).
        let mut iov_buffer_index = 0;
        let mut iov_buffer_offset = 0;

        let mask = Events::IN;
        // Initialize the poller only when needed
        let mut poller = None;
        let mut timeout = self.common.recv_timeout();
        loop {
            // Attempt to read
            let res = self.try_recvmsg(bufs, flags, iov_buffer_index, iov_buffer_offset);
            match res {
                Ok((received_size, index, offset)) => {
                    total_received += received_size;
                    if !flags.contains(RecvFlags::MSG_WAITALL) || total_received == total_len {
                        return Ok(total_received);
                    } else {
                        // save the index and offset for the next round
                        iov_buffer_index = index;
                        iov_buffer_offset = offset;
                    }
                }
                Err(e) => {
                    if e.errno() != EAGAIN {
                        return Err(e);
                    }
                }
            };

            if self.common.nonblocking() || flags.contains(RecvFlags::MSG_DONTWAIT) {
                return_errno!(EAGAIN, "no data are present to be received");
            }

            // Wait for interesting events by polling
            if poller.is_none() {
                let new_poller = Poller::new();
                self.common.pollee().connect_poller(mask, &new_poller);
                poller = Some(new_poller);
            }
            let events = self.common.pollee().poll(mask, None);
            if events.is_empty() {
                let ret = poller.as_ref().unwrap().wait_timeout(timeout.as_mut());
                if let Err(e) = ret {
                    warn!("recv wait errno = {:?}", e.errno());
                    // For recv with MSG_WAITALL, return total received bytes if timeout or interrupt
                    if flags.contains(RecvFlags::MSG_WAITALL) && total_received > 0 {
                        return Ok(total_received);
                    }
                    match e.errno() {
                        ETIMEDOUT => {
                            return_errno!(EAGAIN, "timeout reached")
                        }
                        _ => {
                            return_errno!(e.errno(), "wait error")
                        }
                    }
                }
            }
        }
    }

    /// One non-blocking attempt to copy buffered data into `bufs`, starting
    /// at `bufs[iov_buffer_index][iov_buffer_offset..]`.
    ///
    /// Returns `(bytes_copied, next_index, next_offset)` so MSG_WAITALL
    /// callers can resume where this round stopped.
    fn try_recvmsg(
        self: &Arc<Self>,
        bufs: &mut [&mut [u8]],
        flags: RecvFlags,
        iov_buffer_index: usize,
        iov_buffer_offset: usize,
    ) -> Result<(usize, usize, usize)> {
        let mut inner = self.receiver.inner.lock();

        // Only MSG_DONTWAIT and MSG_WAITALL are supported.
        if !flags.is_empty()
            && flags.intersects(!(RecvFlags::MSG_DONTWAIT | RecvFlags::MSG_WAITALL))
        {
            warn!("Unsupported flags: {:?}", flags);
            return_errno!(EINVAL, "flags not supported");
        }

        let res = {
            let mut total_consumed = 0;
            let mut iov_buffer_index = iov_buffer_index;
            let mut iov_buffer_offset = iov_buffer_offset;
            // save the received data to bufs[iov_buffer_index][iov_buffer_offset..]
            for (_, buf) in bufs.iter_mut().skip(iov_buffer_index).enumerate() {
                let this_consumed = inner.recv_buf.consume(&mut buf[iov_buffer_offset..]);
                if this_consumed == 0 {
                    break;
                }
                total_consumed += this_consumed;
                // if this buffer is not filled completely, try_recvmsg will be called again;
                // next time, copying resumes from this offset
                if this_consumed < buf[iov_buffer_offset..].len() {
                    iov_buffer_offset += this_consumed;
                    break;
                } else {
                    iov_buffer_index += 1;
                    iov_buffer_offset = 0;
                }
            }
            (total_consumed, iov_buffer_index, iov_buffer_offset)
        };

        if self.receiver.need_update() {
            // Only update the recv buf when it is empty and there is no pending recv request
            if inner.recv_buf.is_empty() && inner.io_handle.is_none() {
                self.receiver.set_need_update(false);
                inner.update_buf_size(RECV_BUF_SIZE.load(Ordering::Relaxed));
            }
        }

        if inner.end_of_file {
            return Ok(res);
        }

        if inner.recv_buf.is_empty() {
            // Mark the socket as non-readable
            self.common.pollee().del_events(Events::IN);
        }

        if res.0 > 0 {
            // Refill the recv buffer with a new async request before returning.
            self.do_recv(&mut inner);
            return Ok(res);
        }

        // Only when there are no data available in the recv buffer, shall we check
        // the following error conditions.
        //
        // Case 1: If the read side of the connection has been shutdown...
        if inner.is_shutdown {
            return_errno!(EPIPE, "read side is shutdown");
        }
        // Case 2: If the connection has been broken...
        if let Some(errno) = inner.fatal {
            // Reset error
            inner.fatal = None;
            self.common.pollee().del_events(Events::ERR);
            return_errno!(errno, "read failed");
        }

        self.do_recv(&mut inner);
        return_errno!(EAGAIN, "try read again");
    }

    /// Submits a new async recv request to io_uring, unless the buffer is
    /// full, the socket is shut down or closed, EOF was seen, or a request
    /// is already in flight.
    fn do_recv(self: &Arc<Self>, inner: &mut MutexGuard<Inner>) {
        if inner.recv_buf.is_full()
            || inner.is_shutdown
            || inner.io_handle.is_some()
            || inner.end_of_file
            || self.common.is_closed()
        {
            // Delete ERR events from the sender. If io_handle is some, the recv request must
            // be pending and the events can't be for the receiver. Just delete this event.
            // This can happen when a send request is timed out and canceled.
            let events = self.common.pollee().poll(Events::IN, None);
            if events.contains(Events::ERR) && inner.io_handle.is_some() {
                self.common.pollee().del_events(Events::ERR);
            }
            return;
        }

        // Init the callback invoked upon the completion of the async recv
        let stream = self.clone();
        let complete_fn = move |retval: i32| {
            let mut inner = stream.receiver.inner.lock();
            trace!("recv request complete with retval: {:?}", retval);

            // Release the handle to the async recv
            inner.io_handle.take();

            // Handle error
            if retval < 0 {
                // TODO: guard against Iago attack through errno
                // We should return here. The error may be due to network reasons
                // or because the request was canceled. We don't want to start a
                // new request after canceling a request.
                let errno = Errno::from(-retval as u32);
                inner.fatal = Some(errno);
                stream.common.set_errno(errno);

                // Connection-level errors also raise HUP | IN so blocked readers wake up.
                let events = if errno == ENOTCONN || errno == ECONNRESET || errno == ECONNREFUSED {
                    Events::HUP | Events::IN | Events::ERR
                } else {
                    Events::ERR
                };
                stream.common.pollee().add_events(events);
                return;
            }
            // Handle end of file
            else if retval == 0 {
                inner.end_of_file = true;
                stream.common.pollee().add_events(Events::IN);
                return;
            }

            // Handle the normal case of a successful read
            let nbytes = retval as usize;
            inner.recv_buf.produce_without_copy(nbytes);

            // Now that we have produced non-zero bytes, the buf must become
            // ready to read.
            stream.common.pollee().add_events(Events::IN);

            // Chain the next recv request immediately.
            stream.do_recv(&mut inner);
        };

        // Generate the async recv request
        let msghdr_ptr = inner.new_recv_req();

        // Submit the async recv to io_uring
        let io_uring = self.common.io_uring();
        let host_fd = Fd(self.common.host_fd() as _);
        let handle = unsafe { io_uring.recvmsg(host_fd, msghdr_ptr, 0, complete_fn) };
        inner.io_handle.replace(handle);
    }

    /// Kicks off the first async recv request; called once at stream creation.
    pub(super) fn initiate_async_recv(self: &Arc<Self>) {
        let mut inner = self.receiver.inner.lock();
        self.do_recv(&mut inner);
    }

    /// Cancels the pending recv request (if any) and blocks until it completes.
    pub fn cancel_recv_requests(&self) {
        {
            let inner = self.receiver.inner.lock();
            if let Some(io_handle) = &inner.io_handle {
                let io_uring = self.common.io_uring();
                unsafe { io_uring.cancel(io_handle) };
            } else {
                return;
            }
        }

        // wait for the cancel to complete
        let poller = Poller::new();
        let mask = Events::ERR | Events::IN;
        self.common.pollee().connect_poller(mask, &poller);

        loop {
            let pending_request_exist = {
                let inner = self.receiver.inner.lock();
                inner.io_handle.is_some()
            };

            if pending_request_exist {
                let mut timeout = Some(Duration::from_secs(10));
                let ret = poller.wait_timeout(timeout.as_mut());
                if let Err(e) = ret {
                    warn!("wait cancel recv request error = {:?}", e.errno());
                    continue;
                }
            } else {
                break;
            }
        }
    }

    /// Returns the number of buffered bytes ready to be consumed by the user.
    pub fn bytes_to_consume(self: &Arc<Self>) -> usize {
        let inner = self.receiver.inner.lock();
        inner.recv_buf.consumable()
    }

    // This function will try to update the recv buf size.
    // For socket recv, there will always be a pending request issued in advance. Thus, we can
    // only update the buffer when a recv request is done and the buffer is empty. Here, we just
    // set the update flag; the actual resize happens in try_recvmsg.
    pub fn try_update_recv_buf_size(&self, buf_size: usize) {
        let pre_buf_size = RECV_BUF_SIZE.swap(buf_size, Ordering::Relaxed);
        if buf_size == pre_buf_size {
            return;
        }
        self.receiver.set_need_update(true);
    }
}
/// The receiving half of a connected stream.
pub struct Receiver {
    inner: Mutex<Inner>,
    need_update: AtomicBool,
}

impl Receiver {
    pub fn new() -> Self {
        Self {
            inner: Mutex::new(Inner::new()),
            need_update: AtomicBool::new(false),
        }
    }

    /// Marks the read side as shut down.
    pub fn shutdown(&self) {
        self.inner.lock().is_shutdown = true;
    }

    /// Sets the "recv buffer needs resizing" flag.
    pub fn set_need_update(&self, need_update: bool) {
        self.need_update.store(need_update, Ordering::Relaxed)
    }

    /// Reads the "recv buffer needs resizing" flag.
    pub fn need_update(&self) -> bool {
        self.need_update.load(Ordering::Relaxed)
    }
}

impl std::fmt::Debug for Receiver {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("Receiver");
        dbg.field("inner", &self.inner.lock());
        dbg.finish()
    }
}
/// Mutable state of the receiver, protected by `Receiver::inner`.
struct Inner {
    // Untrusted circular buffer that io_uring writes received bytes into.
    recv_buf: UntrustedCircularBuf,
    // Untrusted storage for the msghdr/iovecs of the in-flight recv request.
    recv_req: UntrustedBox<RecvReq>,
    // Handle of the pending async recv request, if any.
    io_handle: Option<IoHandle>,
    // Set once the read side has been shut down.
    is_shutdown: bool,
    // Set once a recv request completed with 0 bytes (peer closed).
    end_of_file: bool,
    // Sticky error from a failed recv request; cleared when reported to the user.
    fatal: Option<Errno>,
}

// Safety. `RecvReq` does not implement `Send`. But since all pointers in `RecvReq`
// refer to `recv_buf`, we can be sure that it is ok for `RecvReq` to move between
// threads. All other fields in `Inner` implement `Send` as well. So the entirety
// of `Inner` is `Send`-safe.
unsafe impl Send for Inner {}
impl Inner {
    pub fn new() -> Self {
        Self {
            recv_buf: UntrustedCircularBuf::with_capacity(RECV_BUF_SIZE.load(Ordering::Relaxed)),
            recv_req: UntrustedBox::new_uninit(),
            io_handle: None,
            is_shutdown: false,
            end_of_file: false,
            fatal: None,
        }
    }

    /// Replaces the recv buffer with a new one of `buf_size` bytes.
    ///
    /// Must only be called when the buffer is empty and no recv request is in
    /// flight, since a pending request's iovecs point into the old buffer.
    fn update_buf_size(&mut self, buf_size: usize) {
        debug_assert!(self.recv_buf.is_empty() && self.io_handle.is_none());
        let new_recv_buf = UntrustedCircularBuf::with_capacity(buf_size);
        self.recv_buf = new_recv_buf;
    }

    /// Constructs a new recv request according to the receiver's internal state.
    ///
    /// The new `RecvReq` will be put into `self.recv_req`, which is a location that is
    /// accessible by io_uring. A pointer to the C version of the resulting `RecvReq`,
    /// which is `libc::msghdr`, will be returned.
    ///
    /// The buffer used in the new `RecvReq` is part of `self.recv_buf`.
    pub fn new_recv_req(&mut self) -> *mut libc::msghdr {
        let (iovecs, iovecs_len) = self.gen_iovecs_from_recv_buf();

        let msghdr_ptr: *mut libc::msghdr = &mut self.recv_req.msg;
        let iovecs_ptr: *mut libc::iovec = &mut self.recv_req.iovecs as *mut _ as _;

        let msg = super::new_msghdr(iovecs_ptr, iovecs_len);

        self.recv_req.msg = msg;
        self.recv_req.iovecs = iovecs;

        msghdr_ptr
    }

    /// Returns iovecs covering the free (producer) region of `recv_buf`, plus
    /// how many of the two entries are valid (2 when the region wraps around).
    fn gen_iovecs_from_recv_buf(&mut self) -> ([libc::iovec; 2], usize) {
        let mut iovecs_len = 0;
        // Fix: `MaybeUninit::uninit().assume_init()` is undefined behavior even
        // for plain-old-data arrays. An all-zeros iovec array is a valid value,
        // so use `zeroed()` instead; both entries are overwritten below anyway.
        let mut iovecs = unsafe { MaybeUninit::<[libc::iovec; 2]>::zeroed().assume_init() };
        self.recv_buf.with_producer_view(|part0, part1| {
            debug_assert!(part0.len() > 0);

            iovecs[0] = libc::iovec {
                iov_base: part0.as_ptr() as _,
                iov_len: part0.len() as _,
            };

            iovecs[1] = if part1.len() > 0 {
                iovecs_len = 2;
                libc::iovec {
                    iov_base: part1.as_ptr() as _,
                    iov_len: part1.len() as _,
                }
            } else {
                iovecs_len = 1;
                libc::iovec {
                    iov_base: ptr::null_mut(),
                    iov_len: 0,
                }
            };

            // Only access the producer's buffer; zero bytes produced for now.
            0
        });
        debug_assert!(iovecs_len > 0);
        (iovecs, iovecs_len)
    }
}
impl std::fmt::Debug for Inner {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("Inner");
        dbg.field("recv_buf", &self.recv_buf);
        dbg.field("io_handle", &self.io_handle);
        dbg.field("is_shutdown", &self.is_shutdown);
        dbg.field("end_of_file", &self.end_of_file);
        dbg.field("fatal", &self.fatal);
        dbg.finish()
    }
}

/// C-layout recv request: the msghdr plus the iovec array it points at.
#[repr(C)]
struct RecvReq {
    msg: libc::msghdr,
    iovecs: [libc::iovec; 2],
}

// Safety. RecvReq is a C-style struct.
unsafe impl MaybeUntrusted for RecvReq {}

// Required by `IoUringCell<T: Copy>`.
impl Copy for RecvReq {}

impl Clone for RecvReq {
    fn clone(&self) -> Self {
        *self
    }
}

@ -0,0 +1,466 @@
use core::hint;
use core::sync::atomic::AtomicBool;
use core::time::Duration;
use std::mem::MaybeUninit;
use std::ptr::{self};
use atomic::Ordering;
use io_uring_callback::{Fd, IoHandle};
use log::error;
use sgx_untrusted_alloc::{MaybeUntrusted, UntrustedBox};
use super::ConnectedStream;
use crate::net::socket::uring::runtime::Runtime;
use crate::net::socket::uring::stream::SEND_BUF_SIZE;
use crate::prelude::*;
use crate::untrusted::UntrustedCircularBuf;
use crate::util::sync::{Mutex, MutexGuard};
use crate::events::Poller;
use crate::fs::IoEvents as Events;
impl<A: Addr + 'static, R: Runtime> ConnectedStream<A, R> {
// We make sure the all the buffer contents are buffered in kernel and then return.
pub fn sendmsg(self: &Arc<Self>, bufs: &[&[u8]], flags: SendFlags) -> Result<usize> {
let total_len: usize = bufs.iter().map(|buf| buf.len()).sum();
if total_len == 0 {
return Ok(0);
}
let mut send_len = 0;
// variables to track the position of async sendmsg.
let mut iov_buf_id = 0; // user buffer id tracker
let mut iov_buf_index = 0; // user buffer index tracker
let mask = Events::OUT;
// Initialize the poller only when needed
let mut poller = None;
let mut timeout = self.common.send_timeout();
loop {
// Attempt to write
let res = self.try_sendmsg(bufs, flags, &mut iov_buf_id, &mut iov_buf_index);
if let Ok(len) = res {
send_len += len;
// Sent all or sent partial but it is nonblocking, return bytes sent
if send_len == total_len
|| self.common.nonblocking()
|| flags.contains(SendFlags::MSG_DONTWAIT)
{
return Ok(send_len);
}
} else if !res.has_errno(EAGAIN) {
return res;
}
// Still some buffer contents pending
if self.common.nonblocking() || flags.contains(SendFlags::MSG_DONTWAIT) {
return_errno!(EAGAIN, "try write again");
}
// Wait for interesting events by polling
if poller.is_none() {
let new_poller = Poller::new();
self.common.pollee().connect_poller(mask, &new_poller);
poller = Some(new_poller);
}
let events = self.common.pollee().poll(mask, None);
if events.is_empty() {
let ret = poller.as_ref().unwrap().wait_timeout(timeout.as_mut());
if let Err(e) = ret {
warn!("send wait errno = {:?}", e.errno());
match e.errno() {
ETIMEDOUT => {
// Just cancel send requests if timeout
self.cancel_send_requests();
return_errno!(EAGAIN, "timeout reached")
}
_ => {
return_errno!(e.errno(), "wait error")
}
}
}
}
}
}
fn try_sendmsg(
self: &Arc<Self>,
bufs: &[&[u8]],
flags: SendFlags,
iov_buf_id: &mut usize,
iov_buf_index: &mut usize,
) -> Result<usize> {
let mut inner = self.sender.inner.lock();
if !flags.is_empty()
&& flags.intersects(
!(SendFlags::MSG_DONTWAIT | SendFlags::MSG_NOSIGNAL | SendFlags::MSG_MORE),
)
{
error!("Not supported flags: {:?}", flags);
return_errno!(EINVAL, "not supported flags");
}
// Check for error condition before write.
//
// Case 1. If the write side of the connection has been shutdown...
if inner.is_shutdown() {
return_errno!(EPIPE, "write side is shutdown");
}
// Case 2. If the connenction has been broken...
if let Some(errno) = inner.fatal {
// Reset error
inner.fatal = None;
self.common.pollee().del_events(Events::ERR);
return_errno!(errno, "write failed");
}
// Copy data from the bufs to the send buffer
// If the send buffer is full, update the user buffer tracker, return error to wait for events
// And once there is free space, continue from the user buffer tracker
let nbytes = {
let mut total_produced = 0;
let last_time_buf_id = iov_buf_id.clone();
let mut last_time_buf_idx = iov_buf_index.clone();
for (_i, buf) in bufs.iter().skip(last_time_buf_id).enumerate() {
let i = _i + last_time_buf_id; // After skipping ,the index still starts from 0
let this_produced = inner.send_buf.produce(&buf[last_time_buf_idx..]);
total_produced += this_produced;
if this_produced < buf[last_time_buf_idx..].len() {
// Send buffer is full.
*iov_buf_id = i;
*iov_buf_index = last_time_buf_idx + this_produced;
break;
} else {
// For next buffer, start from the front
last_time_buf_idx = 0;
}
}
total_produced
};
if inner.send_buf.is_full() {
// Mark the socket as non-writable
self.common.pollee().del_events(Events::OUT);
}
// Since the send buffer is not empty, we can try to flush the buffer
if inner.io_handle.is_none() {
self.do_send(&mut inner);
}
if nbytes > 0 {
Ok(nbytes)
} else {
return_errno!(EAGAIN, "try write again");
}
}
fn do_send(self: &Arc<Self>, inner: &mut MutexGuard<Inner>) {
// This function can also be called even if the socket is set to shutdown by shutdown syscall. This is due to the
// async behaviour that the kernel may return to user before actually issuing the request. We should
// keep sending the request as long as the send buffer is not empty even if the socket is shutdown.
debug_assert!(inner.is_shutdown != ShutdownStatus::PostShutdown);
debug_assert!(!inner.send_buf.is_empty());
debug_assert!(inner.io_handle.is_none());
// Init the callback invoked upon the completion of the async send
let stream = self.clone();
let complete_fn = move |retval: i32| {
let mut inner = stream.sender.inner.lock();
trace!("send request complete with retval: {}", retval);
// Release the handle to the async send
inner.io_handle.take();
// Handle error
if retval < 0 {
// TODO: guard against Iago attack through errno
// TODO: should we ignore EINTR and try again?
let errno = Errno::from(-retval as u32);
inner.fatal = Some(errno);
stream.common.set_errno(errno);
stream.common.pollee().add_events(Events::ERR);
return;
}
assert!(retval != 0);
// Handle the normal case of a successful write
let nbytes = retval as usize;
inner.send_buf.consume_without_copy(nbytes);
// Now that we have consume non-zero bytes, the buf must become
// ready to write.
stream.common.pollee().add_events(Events::OUT);
// Attempt to send again if there are available data in the buf.
if !inner.send_buf.is_empty() {
stream.do_send(&mut inner);
} else if inner.is_shutdown == ShutdownStatus::PreShutdown {
// The buffer is empty and the write side is shutdown by the user. We can safely shutdown host file here.
let _ = stream.common.host_shutdown(Shutdown::Write);
inner.is_shutdown = ShutdownStatus::PostShutdown
} else if stream.sender.need_update() {
// send_buf is empty. We can try to update the send_buf
stream.sender.set_need_update(false);
inner.update_buf_size(SEND_BUF_SIZE.load(Ordering::Relaxed));
}
};
// Generate the async send request
let msghdr_ptr = inner.new_send_req();
trace!("send submit request");
// Submit the async send to io_uring
let io_uring = self.common.io_uring();
let host_fd = Fd(self.common.host_fd() as _);
let handle = unsafe { io_uring.sendmsg(host_fd, msghdr_ptr, 0, complete_fn) };
inner.io_handle.replace(handle);
}
pub fn cancel_send_requests(&self) {
let io_uring = self.common.io_uring();
let inner = self.sender.inner.lock();
if let Some(io_handle) = &inner.io_handle {
unsafe { io_uring.cancel(io_handle) };
}
}
// This function will try to update the kernel buf size.
// If the kernel buf is currently empty, the size will be updated immediately.
// If the kernel buf is not empty, update the flag in Sender and update the kernel buf after send.
pub fn try_update_send_buf_size(&self, buf_size: usize) {
let pre_buf_size = SEND_BUF_SIZE.swap(buf_size, Ordering::Relaxed);
if pre_buf_size == buf_size {
return;
}
// Try to acquire the lock. If success, try directly update here.
// If failure, don't wait because there is pending send request.
if let Some(mut inner) = self.sender.inner.try_lock() {
if inner.send_buf.is_empty() && inner.io_handle.is_none() {
inner.update_buf_size(buf_size);
return;
}
}
// Can't easily aquire lock or the sendbuf is not empty. Update the flag only
self.sender.set_need_update(true);
}
// Normally, We will always try to send as long as the kernel send buf is not empty. However, if the user calls close, we will wait LINGER time
// and then cancel on-going or new-issued send requests.
pub fn try_empty_send_buf_when_close(&self) {
// let inner = self.sender.inner.lock().unwrap();
let inner = self.sender.inner.lock();
debug_assert!(inner.is_shutdown());
if inner.send_buf.is_empty() {
return;
}
// Wait for linger time to empty the kernel buffer or cancel subsequent requests.
drop(inner);
const DEFUALT_LINGER_TIME: usize = 10;
let poller = Poller::new();
let mask = Events::ERR | Events::OUT;
self.common.pollee().connect_poller(mask, &poller);
loop {
let pending_request_exist = {
// let inner = self.sender.inner.lock().unwrap();
let inner = self.sender.inner.lock();
inner.io_handle.is_some()
};
if pending_request_exist {
let mut timeout = Some(Duration::from_secs(DEFUALT_LINGER_TIME as u64));
let ret = poller.wait_timeout(timeout.as_mut());
trace!("wait empty send buffer ret = {:?}", ret);
if let Err(_) = ret {
// No complete request to wake. Just cancel the send requests.
let io_uring = self.common.io_uring();
let inner = self.sender.inner.lock();
if let Some(io_handle) = &inner.io_handle {
unsafe { io_uring.cancel(io_handle) };
// Loop again to wait the cancel request to complete
continue;
} else {
// No pending request, just break
break;
}
}
} else {
// There is no pending requests
break;
}
}
}
}
/// The sending half of a connected stream.
pub struct Sender {
    inner: Mutex<Inner>,
    need_update: AtomicBool,
}

impl Sender {
    pub fn new() -> Self {
        Self {
            inner: Mutex::new(Inner::new()),
            need_update: AtomicBool::new(false),
        }
    }

    /// Starts the write-side shutdown sequence.
    pub fn shutdown(&self) {
        self.inner.lock().is_shutdown = ShutdownStatus::PreShutdown;
    }

    /// Whether the send buffer currently holds no pending bytes.
    pub fn is_empty(&self) -> bool {
        self.inner.lock().send_buf.is_empty()
    }

    /// Sets the "send buffer needs resizing" flag.
    pub fn set_need_update(&self, need_update: bool) {
        self.need_update.store(need_update, Ordering::Relaxed)
    }

    /// Reads the "send buffer needs resizing" flag.
    pub fn need_update(&self) -> bool {
        self.need_update.load(Ordering::Relaxed)
    }
}

impl std::fmt::Debug for Sender {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("Sender");
        dbg.field("inner", &self.inner.lock());
        dbg.finish()
    }
}
/// Mutable state of the sender, protected by `Sender::inner`.
struct Inner {
    // Untrusted circular buffer that io_uring reads outgoing bytes from.
    send_buf: UntrustedCircularBuf,
    // Untrusted storage for the msghdr/iovecs of the in-flight send request.
    send_req: UntrustedBox<SendReq>,
    // Handle of the pending async send request, if any.
    io_handle: Option<IoHandle>,
    // Progress of the write-side shutdown sequence (see `ShutdownStatus`).
    is_shutdown: ShutdownStatus,
    // Sticky error from a failed send request; cleared when reported to the user.
    fatal: Option<Errno>,
}

// Safety. `SendReq` does not implement `Send`. But since all pointers in `SendReq`
// refer to `send_buf`, we can be sure that it is ok for `SendReq` to move between
// threads. All other fields in `Inner` implement `Send` as well. So the entirety
// of `Inner` is `Send`-safe.
unsafe impl Send for Inner {}
impl Inner {
    pub fn new() -> Self {
        Self {
            send_buf: UntrustedCircularBuf::with_capacity(SEND_BUF_SIZE.load(Ordering::Relaxed)),
            send_req: UntrustedBox::new_uninit(),
            io_handle: None,
            is_shutdown: ShutdownStatus::Running,
            fatal: None,
        }
    }

    /// Replaces the send buffer with a new one of `buf_size` bytes.
    ///
    /// Must only be called when the buffer is empty and no send request is in
    /// flight, since a pending request's iovecs point into the old buffer.
    fn update_buf_size(&mut self, buf_size: usize) {
        debug_assert!(self.send_buf.is_empty() && self.io_handle.is_none());
        let new_send_buf = UntrustedCircularBuf::with_capacity(buf_size);
        self.send_buf = new_send_buf;
    }

    /// Whether the write side is shut down (either phase of the sequence).
    pub fn is_shutdown(&self) -> bool {
        self.is_shutdown == ShutdownStatus::PreShutdown
            || self.is_shutdown == ShutdownStatus::PostShutdown
    }

    /// Constructs a new send request according to the sender's internal state.
    ///
    /// The new `SendReq` will be put into `self.send_req`, which is a location that is
    /// accessible by io_uring. A pointer to the C version of the resulting `SendReq`,
    /// which is `libc::msghdr`, will be returned.
    ///
    /// The buffer used in the new `SendReq` is part of `self.send_buf`.
    pub fn new_send_req(&mut self) -> *mut libc::msghdr {
        let (iovecs, iovecs_len) = self.gen_iovecs_from_send_buf();

        let msghdr_ptr: *mut libc::msghdr = &mut self.send_req.msg;
        let iovecs_ptr: *mut libc::iovec = &mut self.send_req.iovecs as *mut _ as _;

        let msg = super::new_msghdr(iovecs_ptr, iovecs_len);

        self.send_req.msg = msg;
        self.send_req.iovecs = iovecs;

        msghdr_ptr
    }

    /// Returns iovecs covering the filled (consumer) region of `send_buf`, plus
    /// how many of the two entries are valid (2 when the region wraps around).
    fn gen_iovecs_from_send_buf(&mut self) -> ([libc::iovec; 2], usize) {
        let mut iovecs_len = 0;
        // Fix: `MaybeUninit::uninit().assume_init()` is undefined behavior even
        // for plain-old-data arrays. An all-zeros iovec array is a valid value,
        // so use `zeroed()` instead; both entries are overwritten below anyway.
        let mut iovecs = unsafe { MaybeUninit::<[libc::iovec; 2]>::zeroed().assume_init() };
        self.send_buf.with_consumer_view(|part0, part1| {
            debug_assert!(part0.len() > 0);

            iovecs[0] = libc::iovec {
                iov_base: part0.as_ptr() as _,
                iov_len: part0.len() as _,
            };

            iovecs[1] = if part1.len() > 0 {
                iovecs_len = 2;
                libc::iovec {
                    iov_base: part1.as_ptr() as _,
                    iov_len: part1.len() as _,
                }
            } else {
                iovecs_len = 1;
                libc::iovec {
                    iov_base: ptr::null_mut(),
                    iov_len: 0,
                }
            };

            // Only access the consumer's buffer; zero bytes consumed for now.
            0
        });
        debug_assert!(iovecs_len > 0);
        (iovecs, iovecs_len)
    }
}
impl std::fmt::Debug for Inner {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("Inner");
        dbg.field("send_buf", &self.send_buf);
        dbg.field("io_handle", &self.io_handle);
        dbg.field("is_shutdown", &self.is_shutdown);
        dbg.field("fatal", &self.fatal);
        dbg.finish()
    }
}

/// C-layout send request: the msghdr plus the iovec array it points at.
#[repr(C)]
struct SendReq {
    msg: libc::msghdr,
    iovecs: [libc::iovec; 2],
}

// Safety. SendReq is a C-style struct.
unsafe impl MaybeUntrusted for SendReq {}

// Required by `IoUringCell<T: Copy>`.
impl Copy for SendReq {}

impl Clone for SendReq {
    fn clone(&self) -> Self {
        *self
    }
}

/// Progress of the write-side shutdown sequence.
#[derive(Debug, PartialEq)]
enum ShutdownStatus {
    Running,      // not shutdown
    PreShutdown,  // shutdown started, set by calling the shutdown syscall
    PostShutdown, // shutdown is done, set once the send buffer drains
}

@ -0,0 +1,72 @@
use crate::fs::IoEvents;
use crate::net::socket::uring::common::Common;
use crate::net::socket::uring::runtime::Runtime;
use crate::prelude::*;
/// A stream socket that is in its initial state.
pub struct InitStream<A: Addr + 'static, R: Runtime> {
    // Shared socket state (host fd, pollee, addresses).
    common: Arc<Common<A, R>>,
    // Mutable state guarded by a lock (see `Inner`).
    inner: Mutex<Inner>,
}

// Mutable state: whether bind() has been called on this socket.
struct Inner {
    has_bound: bool,
}
impl<A: Addr + 'static, R: Runtime> InitStream<A, R> {
    /// Creates a fresh, unbound stream socket.
    pub fn new(nonblocking: bool) -> Result<Arc<Self>> {
        let common = Arc::new(Common::new(Type::STREAM, nonblocking, None)?);
        common.pollee().add_events(IoEvents::HUP | IoEvents::OUT);
        Ok(Arc::new(Self {
            common,
            inner: Mutex::new(Inner::new()),
        }))
    }

    /// Wraps an existing `Common` in an init stream, keeping its bound state.
    pub fn new_with_common(common: Arc<Common<A, R>>) -> Result<Arc<Self>> {
        let inner = Mutex::new(Inner {
            has_bound: common.addr().is_some(),
        });
        Ok(Arc::new(Self { common, inner }))
    }

    /// Binds the socket to `addr`; fails with EINVAL if already bound.
    pub fn bind(&self, addr: &A) -> Result<()> {
        let mut state = self.inner.lock();
        if state.has_bound {
            return_errno!(EINVAL, "the socket is already bound to an address");
        }

        crate::net::socket::uring::common::do_bind(self.common.host_fd(), addr)?;

        state.has_bound = true;
        self.common.set_addr(addr);
        Ok(())
    }

    pub fn common(&self) -> &Arc<Common<A, R>> {
        &self.common
    }
}
impl<A: Addr + 'static, R: Runtime> std::fmt::Debug for InitStream<A, R> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("InitStream");
        dbg.field("common", &self.common);
        dbg.field("inner", &*self.inner.lock());
        dbg.finish()
    }
}

impl Inner {
    /// A socket starts out unbound.
    pub fn new() -> Self {
        Self { has_bound: false }
    }
}

impl std::fmt::Debug for Inner {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("Inner");
        dbg.field("has_bound", &self.has_bound);
        dbg.finish()
    }
}

@ -0,0 +1,430 @@
use core::time::Duration;
use std::collections::VecDeque;
use std::marker::PhantomData;
use std::mem::size_of;
use io_uring_callback::{Fd, IoHandle};
use sgx_untrusted_alloc::{MaybeUntrusted, UntrustedBox};
use super::ConnectedStream;
use crate::events::Poller;
use crate::fs::IoEvents;
use crate::net::socket::uring::common::{do_close, Common};
use crate::net::socket::uring::runtime::Runtime;
use crate::prelude::*;
use libc::ocall::shutdown as do_shutdown;
// We issue the async accept request ahead of time. But with big backlog number,
// there will be too many pending requests, which could be harmful to the system.
const PENDING_ASYNC_ACCEPT_NUM_MAX: usize = 128;
/// A listener stream, ready to accept incoming connections.
pub struct ListenerStream<A: Addr + 'static, R: Runtime> {
    // State shared with the async accept completion callbacks.
    common: Arc<Common<A, R>>,
    // Backlog of async accept requests plus any sticky accept error.
    inner: Mutex<Inner<A>>,
}
impl<A: Addr + 'static, R: Runtime> ListenerStream<A, R> {
/// Creates a new listener stream.
pub fn new(backlog: u32, common: Arc<Common<A, R>>) -> Result<Arc<Self>> {
// Here we use different variables for the backlog. For the libos, as we will issue async accept request
// ahead of time, and cacel when the socket closes, we set the libos backlog to a reasonable value which
// is no greater than the max value we set to save resources and make it more efficient. For the host,
// we just use the backlog value for maximum connection.
let libos_backlog = std::cmp::min(backlog, PENDING_ASYNC_ACCEPT_NUM_MAX as u32);
let host_backlog = backlog;
let inner = Inner::new(libos_backlog)?;
Self::do_listen(common.host_fd(), host_backlog)?;
common.pollee().reset_events();
let new_self = Arc::new(Self {
common,
inner: Mutex::new(inner),
});
// Start async accept requests right as early as possible to improve performance
{
let inner = new_self.inner.lock();
new_self.initiate_async_accepts(inner);
}
Ok(new_self)
}
fn do_listen(host_fd: FileDesc, backlog: u32) -> Result<()> {
try_libc!(libc::ocall::listen(host_fd as _, backlog as _));
Ok(())
}
pub fn accept(self: &Arc<Self>, nonblocking: bool) -> Result<Arc<ConnectedStream<A, R>>> {
let mask = IoEvents::IN;
// Init the poller only when needed
let mut poller = None;
let mut timeout = self.common.recv_timeout();
loop {
// Attempt to accept
let res = self.try_accept(nonblocking);
if !res.has_errno(EAGAIN) {
return res;
}
if self.common.nonblocking() {
return_errno!(EAGAIN, "no connections are present to be accepted");
}
// Ensure the poller is initialized
if poller.is_none() {
let new_poller = Poller::new();
self.common.pollee().connect_poller(mask, &new_poller);
poller = Some(new_poller);
}
// Wait for interesting events by polling
let events = self.common.pollee().poll(mask, None);
if events.is_empty() {
let ret = poller.as_ref().unwrap().wait_timeout(timeout.as_mut());
if let Err(e) = ret {
warn!("accept wait errno = {:?}", e.errno());
match e.errno() {
ETIMEDOUT => {
return_errno!(EAGAIN, "timeout reached")
}
_ => {
return_errno!(e.errno(), "wait error")
}
}
}
}
}
}
pub fn try_accept(self: &Arc<Self>, nonblocking: bool) -> Result<Arc<ConnectedStream<A, R>>> {
let mut inner = self.inner.lock();
if let Some(errno) = inner.fatal {
// Reset error
inner.fatal = None;
self.common.pollee().del_events(IoEvents::ERR);
return_errno!(errno, "accept failed");
}
let (accepted_fd, accepted_addr) = inner.backlog.pop_completed_req().ok_or_else(|| {
self.common.pollee().del_events(IoEvents::IN);
errno!(EAGAIN, "try accept again")
})?;
if !inner.backlog.has_completed_reqs() {
self.common.pollee().del_events(IoEvents::IN);
}
self.initiate_async_accepts(inner);
let common = {
let common = Arc::new(Common::with_host_fd(accepted_fd, Type::STREAM, nonblocking));
common.set_peer_addr(&accepted_addr);
common
};
let accepted_stream = ConnectedStream::new(common);
Ok(accepted_stream)
}
/// Refill the backlog: submit a new async accept request for every free
/// entry.
///
/// Fix: `Backlog::start_new_req` returns early — without consuming a free
/// entry — when the stream is closed, so a loop guarded only by
/// `has_free_entries()` would spin forever on a closed stream. Check the
/// closed flag in the loop condition as well.
fn initiate_async_accepts(self: &Arc<Self>, mut inner: MutexGuard<Inner<A>>) {
    let backlog = &mut inner.backlog;
    // Keep submitting until the backlog is full or the stream gets closed.
    while !self.common.is_closed() && backlog.has_free_entries() {
        backlog.start_new_req(self);
    }
}
/// Returns the state shared between this listener and its io_uring
/// completion callbacks.
pub fn common(&self) -> &Arc<Common<A, R>> {
    &self.common
}
/// Cancel all in-flight io_uring accept requests and block until none
/// remain pending.
///
/// Marks the stream closed first so completion callbacks do not submit new
/// requests, then asks io_uring to cancel every pending entry and waits —
/// polling on ERR|IN events raised by the callbacks — until the backlog has
/// no `Pending` entries left. Finally clears the closed flag so the stream
/// can listen again.
pub fn cancel_accept_requests(&self) {
    {
        // Set the listener stream as closed to prevent submitting new request in the callback fn
        self.common().set_closed();

        let io_uring = self.common.io_uring();
        let inner = self.inner.lock();
        for entry in inner.backlog.entries.iter() {
            if let Entry::Pending { io_handle } = entry {
                unsafe { io_uring.cancel(&io_handle) };
            }
        }
        // The inner lock is dropped here so the callbacks can make progress.
    }

    // wait for all the cancel requests to complete
    let poller = Poller::new();
    let mask = IoEvents::ERR | IoEvents::IN;
    self.common.pollee().connect_poller(mask, &poller);

    loop {
        let pending_entry_exists = {
            let inner = self.inner.lock();
            // Idiomatic form of `.find(..).is_some()` (clippy: search_is_some).
            inner
                .backlog
                .entries
                .iter()
                .any(|entry| matches!(entry, Entry::Pending { .. }))
        };

        if pending_entry_exists {
            // Bound each wait so a lost wakeup cannot hang us forever;
            // on error/timeout just re-check the backlog and wait again.
            let mut timeout = Some(Duration::from_secs(20));
            let ret = poller.wait_timeout(timeout.as_mut());
            if let Err(e) = ret {
                warn!("wait cancel accept request error = {:?}", e.errno());
                continue;
            }
        } else {
            // Reset the stream for re-listen
            self.common().reset_closed();
            return;
        }
    }
}
/// Shut down the host listener socket in the requested direction(s) and
/// raise the matching I/O events on the pollee.
pub fn shutdown(&self, how: Shutdown) -> Result<()> {
    // Full shutdown: signal readers, writers, and hang-up watchers.
    if how == Shutdown::Both {
        self.common.host_shutdown(Shutdown::Both)?;
        self.common
            .pollee()
            .add_events(IoEvents::IN | IoEvents::OUT | IoEvents::HUP);
        return Ok(());
    }

    // Read-side only.
    if how.should_shut_read() {
        self.common.host_shutdown(Shutdown::Read)?;
        self.common.pollee().add_events(IoEvents::IN);
        return Ok(());
    }

    // Write-side only.
    if how.should_shut_write() {
        self.common.host_shutdown(Shutdown::Write)?;
        self.common.pollee().add_events(IoEvents::OUT);
        return Ok(());
    }

    Ok(())
}
}
impl<A: Addr + 'static, R: Runtime> std::fmt::Debug for ListenerStream<A, R> {
    // Formats the listener with its shared state and its locked inner state.
    // NOTE(review): this acquires the `inner` lock; calling it while already
    // holding that lock would deadlock (assuming a non-reentrant mutex —
    // confirm against the Mutex implementation).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ListenerStream")
            .field("common", &self.common)
            .field("inner", &self.inner.lock())
            .finish()
    }
}
/// The mutable, internal state of a listener stream.
struct Inner<A: Addr> {
    // Pending and completed async accept requests.
    backlog: Backlog<A>,
    // The fatal errno recorded by a failed accept completion callback, if
    // any; reported once and cleared by `try_accept`.
    fatal: Option<Errno>,
}
impl<A: Addr> Inner<A> {
    /// Create the internal state with an accept backlog of the given
    /// capacity. Fails with `EINVAL` if the capacity is zero.
    pub fn new(backlog: u32) -> Result<Self> {
        let backlog = Backlog::with_capacity(backlog as usize)?;
        Ok(Self {
            backlog,
            fatal: None,
        })
    }
}
impl<A: Addr + 'static> std::fmt::Debug for Inner<A> {
    // Straightforward field-by-field debug output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Inner")
            .field("backlog", &self.backlog)
            .field("fatal", &self.fatal)
            .finish()
    }
}
/// An entry in the backlog.
#[derive(Debug)]
enum Entry {
    /// The entry is free to use.
    Free,
    /// The entry is a pending accept request.
    ///
    /// The handle refers to the in-flight io_uring accept request and is
    /// used to cancel it (see `cancel_accept_requests`).
    Pending { io_handle: IoHandle },
    /// The entry is a completed accept request.
    ///
    /// `host_fd` is the host-side fd of the newly accepted connection.
    Completed { host_fd: FileDesc },
}
impl Default for Entry {
fn default() -> Self {
Self::Free
}
}
/// An async io_uring accept request.
///
/// Holds the output arguments of the host accept: the peer address buffer
/// and its length. The host writes into this struct, so instances are kept
/// in untrusted memory (see `Backlog::reqs`).
#[derive(Copy, Clone)]
#[repr(C)]
struct AcceptReq {
    // Peer address, written by the host on completion.
    c_addr: libc::sockaddr_storage,
    // In: capacity of `c_addr`; out: actual address length.
    c_addr_len: libc::socklen_t,
}

// Safety. AcceptReq is a C-style struct with C-style fields.
unsafe impl MaybeUntrusted for AcceptReq {}
/// A backlog of incoming connections of a listener stream.
///
/// With backlog, we can start async accept requests, keep track of the pending requests,
/// and maintain the ones that have completed.
struct Backlog<A: Addr> {
    // The entries in the backlog.
    entries: Box<[Entry]>,
    // Arguments of the io_uring requests submitted for the entries in the backlog.
    // Indexed in lockstep with `entries`; lives in untrusted memory because
    // the host writes the accepted peer address into it.
    reqs: UntrustedBox<[AcceptReq]>,
    // The indexes of completed entries.
    completed: VecDeque<usize>,
    // The number of free entries.
    num_free: usize,
    // Ties the address type `A` to this backlog without storing one.
    phantom_data: PhantomData<A>,
}
impl<A: Addr> Backlog<A> {
    /// Create a backlog with room for `capacity` concurrent accept requests.
    ///
    /// Errors with `EINVAL` if `capacity` is zero.
    pub fn with_capacity(capacity: usize) -> Result<Self> {
        if capacity == 0 {
            return_errno!(EINVAL, "capacity cannot be zero");
        }

        let entries = (0..capacity)
            .map(|_| Entry::Free)
            .collect::<Vec<Entry>>()
            .into_boxed_slice();
        // Request arguments are written by the host, so they are placed in
        // untrusted memory; they are initialized before each submission.
        let reqs = UntrustedBox::new_uninit_slice(capacity);
        let completed = VecDeque::new();
        let num_free = capacity;
        let new_self = Self {
            entries,
            reqs,
            completed,
            num_free,
            phantom_data: PhantomData,
        };
        Ok(new_self)
    }

    /// Whether at least one entry is free for a new accept request.
    pub fn has_free_entries(&self) -> bool {
        self.num_free > 0
    }

    /// Start a new async accept request, turning a free entry into a pending one.
    ///
    /// No-op on a closed stream. The completion callback either records the
    /// failure in `Inner::fatal` or queues the accepted host fd and refills
    /// the backlog.
    pub fn start_new_req<R: Runtime>(&mut self, stream: &Arc<ListenerStream<A, R>>) {
        // Do not submit on a closed stream; note that no free entry is
        // consumed in this case.
        if stream.common.is_closed() {
            return;
        }

        debug_assert!(self.has_free_entries());

        let entry_idx = self
            .entries
            .iter()
            .position(|entry| matches!(entry, Entry::Free))
            .unwrap();

        // Prepare the untrusted sockaddr buffer for the host to fill in.
        let (c_addr_ptr, c_addr_len_ptr) = {
            let accept_req = &mut self.reqs[entry_idx];
            accept_req.c_addr_len = size_of::<libc::sockaddr_storage>() as _;
            let c_addr_ptr = &mut accept_req.c_addr as *mut _ as _;
            let c_addr_len_ptr = &mut accept_req.c_addr_len as _;
            (c_addr_ptr, c_addr_len_ptr)
        };

        let callback = {
            let stream = stream.clone();
            move |retval: i32| {
                let mut inner = stream.inner.lock();

                trace!("accept request complete with retval: {:?}", retval);

                if retval < 0 {
                    // Since most errors that may result from the accept syscall are _not fatal_,
                    // we simply ignore the errno code and try again.
                    //
                    // According to the man page, Linux may report the network errors on an
                    // newly-accepted socket through the accept system call. Thus, we should not
                    // treat the listener socket as "broken" simply because an error is returned
                    // from the accept syscall.
                    //
                    // TODO: throw fatal errors to the upper layer.
                    let errno = Errno::from(-retval as u32);
                    log::error!("Accept error: errno = {}", errno);
                    inner.backlog.entries[entry_idx] = Entry::Free;
                    inner.backlog.num_free += 1;

                    // When canceling request, a poller might be waiting for this to return.
                    inner.fatal = Some(errno);
                    stream.common.set_errno(errno);
                    stream.common.pollee().add_events(IoEvents::ERR);
                    // After getting the error from the accept system call, we should not start
                    // the async accept requests again, because this may cause a large number of
                    // io-uring requests to be retried
                    return;
                }

                // Success: the return value is the accepted host fd.
                let host_fd = retval as FileDesc;
                inner.backlog.entries[entry_idx] = Entry::Completed { host_fd };
                inner.backlog.completed.push_back(entry_idx);

                stream.common.pollee().add_events(IoEvents::IN);

                stream.initiate_async_accepts(inner);
            }
        };

        let io_uring = stream.common.io_uring();
        let fd = stream.common.host_fd() as i32;
        let flags = 0;
        let io_handle =
            unsafe { io_uring.accept(Fd(fd), c_addr_ptr, c_addr_len_ptr, flags, callback) };
        self.entries[entry_idx] = Entry::Pending { io_handle };
        self.num_free -= 1;
    }

    /// Whether at least one completed accept request is waiting to be popped.
    pub fn has_completed_reqs(&self) -> bool {
        // Idiomatic emptiness check (clippy: len_zero).
        !self.completed.is_empty()
    }

    /// Pop a completed async accept request, turing a completed entry into a free one.
    ///
    /// Returns the accepted host fd and the peer address, or `None` if no
    /// request has completed.
    pub fn pop_completed_req(&mut self) -> Option<(FileDesc, A)> {
        let completed_idx = self.completed.pop_front()?;

        let accepted_addr = {
            // AcceptReq is `Copy`; no need to call `.clone()` (clippy:
            // clone_on_copy).
            let AcceptReq { c_addr, c_addr_len } = self.reqs[completed_idx];
            A::from_c_storage(&c_addr, c_addr_len as _).unwrap()
        };

        let accepted_fd = {
            let entry = &mut self.entries[completed_idx];
            let accepted_fd = match entry {
                Entry::Completed { host_fd } => *host_fd,
                _ => unreachable!("the entry should have been completed"),
            };
            self.num_free += 1;
            *entry = Entry::Free;
            accepted_fd
        };

        Some((accepted_fd, accepted_addr))
    }
}
impl<A: Addr + 'static> std::fmt::Debug for Backlog<A> {
    // NOTE(review): `reqs` and `phantom_data` are left out of the output —
    // presumably because `reqs` lives in untrusted memory and may be
    // uninitialized; confirm before adding it.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Backlog")
            .field("entries", &self.entries)
            .field("completed", &self.completed)
            .field("num_free", &self.num_free)
            .finish()
    }
}
impl<A: Addr> Drop for Backlog<A> {
    // Close the host fds of accepted-but-never-consumed connections so they
    // do not leak when the backlog is dropped.
    //
    // NOTE(review): `Pending` entries are not handled here — presumably the
    // owner cancels them via `cancel_accept_requests` before dropping;
    // confirm that all drop paths do so.
    fn drop(&mut self) {
        for entry in self.entries.iter() {
            if let Entry::Completed { host_fd } = entry {
                if let Err(e) = do_close(*host_fd) {
                    log::error!("close fd failed, host_fd: {}, err: {}", host_fd, e);
                }
            }
        }
    }
}

@ -0,0 +1,9 @@
mod connect;
mod connected;
mod init;
mod listen;
pub use self::connect::ConnectingStream;
pub use self::connected::ConnectedStream;
pub use self::init::InitStream;
pub use self::listen::ListenerStream;

@ -17,8 +17,11 @@ pub use std::sync::{
pub use crate::error::Result; pub use crate::error::Result;
pub use crate::error::*; pub use crate::error::*;
pub use crate::fs::{File, FileDesc, FileRef}; pub use crate::fs::{File, FileDesc, FileRef};
pub use crate::net::socket::util::Addr;
pub use crate::net::socket::{Domain, RecvFlags, SendFlags, Shutdown, Type};
pub use crate::process::{pid_t, uid_t}; pub use crate::process::{pid_t, uid_t};
pub use crate::util::sync::RwLock; pub use crate::util::sync::RwLock;
pub use crate::util::sync::{Mutex, MutexGuard};
macro_rules! debug_trace { macro_rules! debug_trace {
() => { () => {

@ -44,8 +44,7 @@ use crate::net::{
do_accept, do_accept4, do_bind, do_connect, do_epoll_create, do_epoll_create1, do_epoll_ctl, do_accept, do_accept4, do_bind, do_connect, do_epoll_create, do_epoll_create1, do_epoll_ctl,
do_epoll_pwait, do_epoll_wait, do_getpeername, do_getsockname, do_getsockopt, do_listen, do_epoll_pwait, do_epoll_wait, do_getpeername, do_getsockname, do_getsockopt, do_listen,
do_poll, do_ppoll, do_pselect6, do_recvfrom, do_recvmsg, do_select, do_sendmmsg, do_sendmsg, do_poll, do_ppoll, do_pselect6, do_recvfrom, do_recvmsg, do_select, do_sendmmsg, do_sendmsg,
do_sendto, do_setsockopt, do_shutdown, do_socket, do_socketpair, mmsghdr, msghdr, msghdr_mut, do_sendto, do_setsockopt, do_shutdown, do_socket, do_socketpair, mmsghdr, sigset_argpack,
sigset_argpack,
}; };
use crate::process::{ use crate::process::{
do_arch_prctl, do_clone, do_execve, do_exit, do_exit_group, do_futex, do_get_robust_list, do_arch_prctl, do_clone, do_execve, do_exit, do_exit_group, do_futex, do_get_robust_list,
@ -143,8 +142,8 @@ macro_rules! process_syscall_table_with_callback {
(Accept = 43) => do_accept(fd: c_int, addr: *mut libc::sockaddr, addr_len: *mut libc::socklen_t), (Accept = 43) => do_accept(fd: c_int, addr: *mut libc::sockaddr, addr_len: *mut libc::socklen_t),
(Sendto = 44) => do_sendto(fd: c_int, base: *const c_void, len: size_t, flags: c_int, addr: *const libc::sockaddr, addr_len: libc::socklen_t), (Sendto = 44) => do_sendto(fd: c_int, base: *const c_void, len: size_t, flags: c_int, addr: *const libc::sockaddr, addr_len: libc::socklen_t),
(Recvfrom = 45) => do_recvfrom(fd: c_int, base: *mut c_void, len: size_t, flags: c_int, addr: *mut libc::sockaddr, addr_len: *mut libc::socklen_t), (Recvfrom = 45) => do_recvfrom(fd: c_int, base: *mut c_void, len: size_t, flags: c_int, addr: *mut libc::sockaddr, addr_len: *mut libc::socklen_t),
(Sendmsg = 46) => do_sendmsg(fd: c_int, msg_ptr: *const msghdr, flags_c: c_int), (Sendmsg = 46) => do_sendmsg(fd: c_int, msg_ptr: *const libc::msghdr, flags_c: c_int),
(Recvmsg = 47) => do_recvmsg(fd: c_int, msg_mut_ptr: *mut msghdr_mut, flags_c: c_int), (Recvmsg = 47) => do_recvmsg(fd: c_int, msg_mut_ptr: *mut libc::msghdr, flags_c: c_int),
(Shutdown = 48) => do_shutdown(fd: c_int, how: c_int), (Shutdown = 48) => do_shutdown(fd: c_int, how: c_int),
(Bind = 49) => do_bind(fd: c_int, addr: *const libc::sockaddr, addr_len: libc::socklen_t), (Bind = 49) => do_bind(fd: c_int, addr: *const libc::sockaddr, addr_len: libc::socklen_t),
(Listen = 50) => do_listen(fd: c_int, backlog: c_int), (Listen = 50) => do_listen(fd: c_int, backlog: c_int),

@ -54,6 +54,7 @@ LINK_FLAGS += -lsgx_quote_ex_sim
endif endif
endif endif
LINK_FLAGS += -L$(PROJECT_DIR)/deps/io-uring/ocalls/target/release/ -lsgx_io_uring_ocalls
ALL_BUILD_SUBDIRS := $(sort $(patsubst %/,%,$(dir $(LIBOCCLUM_PAL_SO_REAL) $(EDL_C_OBJS) $(C_OBJS) $(CXX_OBJS) $(VDSO_OBJS)))) ALL_BUILD_SUBDIRS := $(sort $(patsubst %/,%,$(dir $(LIBOCCLUM_PAL_SO_REAL) $(EDL_C_OBJS) $(C_OBJS) $(CXX_OBJS) $(VDSO_OBJS))))
.PHONY: all format format-check clean .PHONY: all format format-check clean
@ -79,7 +80,8 @@ $(OBJ_DIR)/pal/$(SRC_OBJ)/Enclave_u.c: $(SGX_EDGER8R) ../Enclave.edl
$(SGX_EDGER8R) $(SGX_EDGER8R_MODE) --untrusted $(CUR_DIR)/../Enclave.edl \ $(SGX_EDGER8R) $(SGX_EDGER8R_MODE) --untrusted $(CUR_DIR)/../Enclave.edl \
--search-path $(SGX_SDK)/include \ --search-path $(SGX_SDK)/include \
--search-path $(RUST_SGX_SDK_DIR)/edl/ \ --search-path $(RUST_SGX_SDK_DIR)/edl/ \
--search-path $(CRATES_DIR)/vdso-time/ocalls --search-path $(CRATES_DIR)/vdso-time/ocalls \
--search-path $(PROJECT_DIR)/deps/io-uring/ocalls
@echo "GEN <= $@" @echo "GEN <= $@"
$(OBJ_DIR)/pal/$(SRC_OBJ)/%.o: src/%.c $(OBJ_DIR)/pal/$(SRC_OBJ)/%.o: src/%.c