From f8be7e74544c323f89f2808acaff04018e222cd2 Mon Sep 17 00:00:00 2001 From: ClawSeven Date: Tue, 30 Apr 2024 14:54:23 +0800 Subject: [PATCH] [libos] Implement async network framework based on IO_Uring --- .gitmodules | 3 + Makefile | 2 + deps/io-uring | 1 + src/Enclave.edl | 2 + src/libos/Cargo.lock | 186 ++++++ src/libos/src/entry.rs | 6 +- src/libos/src/io_uring.rs | 169 +++++ src/libos/src/lib.rs | 8 +- src/libos/src/net/mod.rs | 9 +- src/libos/src/net/socket/mod.rs | 24 +- .../src/net/socket/uring/common/common.rs | 241 +++++++ src/libos/src/net/socket/uring/common/mod.rs | 7 + .../src/net/socket/uring/common/operation.rs | 44 ++ .../src/net/socket/uring/common/timeout.rs | 32 + .../src/net/socket/uring/datagram/generic.rs | 494 ++++++++++++++ .../src/net/socket/uring/datagram/mod.rs | 20 + .../src/net/socket/uring/datagram/receiver.rs | 382 +++++++++++ .../src/net/socket/uring/datagram/sender.rs | 406 ++++++++++++ src/libos/src/net/socket/uring/file_impl.rs | 79 +++ src/libos/src/net/socket/uring/mod.rs | 12 + src/libos/src/net/socket/uring/runtime.rs | 12 + src/libos/src/net/socket/uring/socket_file.rs | 446 +++++++++++++ src/libos/src/net/socket/uring/stream/mod.rs | 610 ++++++++++++++++++ .../net/socket/uring/stream/states/connect.rs | 227 +++++++ .../uring/stream/states/connected/mod.rs | 114 ++++ .../uring/stream/states/connected/recv.rs | 458 +++++++++++++ .../uring/stream/states/connected/send.rs | 466 +++++++++++++ .../net/socket/uring/stream/states/init.rs | 72 +++ .../net/socket/uring/stream/states/listen.rs | 430 ++++++++++++ .../src/net/socket/uring/stream/states/mod.rs | 9 + src/libos/src/prelude.rs | 3 + src/libos/src/syscall/mod.rs | 7 +- src/pal/Makefile | 4 +- 33 files changed, 4958 insertions(+), 27 deletions(-) create mode 160000 deps/io-uring create mode 100644 src/libos/src/io_uring.rs create mode 100644 src/libos/src/net/socket/uring/common/common.rs create mode 100644 src/libos/src/net/socket/uring/common/mod.rs create mode 100644 
src/libos/src/net/socket/uring/common/operation.rs create mode 100644 src/libos/src/net/socket/uring/common/timeout.rs create mode 100644 src/libos/src/net/socket/uring/datagram/generic.rs create mode 100644 src/libos/src/net/socket/uring/datagram/mod.rs create mode 100644 src/libos/src/net/socket/uring/datagram/receiver.rs create mode 100644 src/libos/src/net/socket/uring/datagram/sender.rs create mode 100644 src/libos/src/net/socket/uring/file_impl.rs create mode 100644 src/libos/src/net/socket/uring/mod.rs create mode 100644 src/libos/src/net/socket/uring/runtime.rs create mode 100644 src/libos/src/net/socket/uring/socket_file.rs create mode 100644 src/libos/src/net/socket/uring/stream/mod.rs create mode 100644 src/libos/src/net/socket/uring/stream/states/connect.rs create mode 100644 src/libos/src/net/socket/uring/stream/states/connected/mod.rs create mode 100644 src/libos/src/net/socket/uring/stream/states/connected/recv.rs create mode 100644 src/libos/src/net/socket/uring/stream/states/connected/send.rs create mode 100644 src/libos/src/net/socket/uring/stream/states/init.rs create mode 100644 src/libos/src/net/socket/uring/stream/states/listen.rs create mode 100644 src/libos/src/net/socket/uring/stream/states/mod.rs diff --git a/.gitmodules b/.gitmodules index 7c840d97..647774e5 100644 --- a/.gitmodules +++ b/.gitmodules @@ -24,3 +24,6 @@ [submodule "deps/resolv-conf"] path = deps/resolv-conf url = https://github.com/tailhook/resolv-conf.git +[submodule "deps/io-uring"] + path = deps/io-uring + url = https://github.com/occlum/io-uring.git diff --git a/Makefile b/Makefile index 6c9178e7..2ea5c417 100644 --- a/Makefile +++ b/Makefile @@ -42,6 +42,7 @@ submodule: githooks init-submodule @cp deps/sefs/sefs-cli/lib/libsefs-cli_sim.so build/lib @cp deps/sefs/sefs-cli/lib/libsefs-cli.signed.so build/lib @cp deps/sefs/sefs-cli/enclave/Enclave.config.xml build/sefs-cli.Enclave.xml + @cd deps/io-uring/ocalls && cargo clean && cargo build --release else submodule: 
githooks init-submodule @rm -rf build @@ -60,6 +61,7 @@ submodule: githooks init-submodule @cp deps/sefs/sefs-cli/lib/libsefs-cli_sim.so build/lib @cp deps/sefs/sefs-cli/lib/libsefs-cli.signed.so build/lib @cp deps/sefs/sefs-cli/enclave/Enclave.config.xml build/sefs-cli.Enclave.xml + @cd deps/io-uring/ocalls && cargo clean && cargo build --release endif init-submodule: diff --git a/deps/io-uring b/deps/io-uring new file mode 160000 index 00000000..c654c492 --- /dev/null +++ b/deps/io-uring @@ -0,0 +1 @@ +Subproject commit c654c4925bb0b013d3eec736015f8ac4888722be diff --git a/src/Enclave.edl b/src/Enclave.edl index 366de52c..f8c77bf4 100644 --- a/src/Enclave.edl +++ b/src/Enclave.edl @@ -7,6 +7,8 @@ enclave { from "sgx_net.edl" import *; from "sgx_occlum_utils.edl" import *; from "sgx_vdso_time_ocalls.edl" import *; + from "sgx_thread.edl" import *; + from "sgx_io_uring_ocalls.edl" import *; include "sgx_quote.h" include "occlum_edl_types.h" diff --git a/src/libos/Cargo.lock b/src/libos/Cargo.lock index c4a241ce..8533a22e 100644 --- a/src/libos/Cargo.lock +++ b/src/libos/Cargo.lock @@ -10,16 +10,21 @@ dependencies = [ "atomic", "bitflags", "bitvec 1.0.1", + "byteorder", "ctor", "derive_builder", + "downcast-rs", "errno", "goblin", "intrusive-collections", + "io-uring-callback", "itertools", + "keyable-arc", "lazy_static", "log", "memoffset 0.6.5", "modular-bitfield", + "num_enum", "rcore-fs", "rcore-fs-devfs", "rcore-fs-mountfs", @@ -32,6 +37,7 @@ dependencies = [ "scroll", "serde", "serde_json", + "sgx-untrusted-alloc", "sgx_cov", "sgx_tcrypto", "sgx_trts", @@ -112,6 +118,12 @@ dependencies = [ "wyz", ] +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + [[package]] name = "cc" version = "1.0.73" @@ -203,6 +215,12 @@ dependencies = [ "syn", ] +[[package]] +name = "downcast-rs" +version = "1.2.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ea835d29036a4087793836fa931b08837ad5e957da9e23886b29586fb9b6650" + [[package]] name = "either" version = "1.8.0" @@ -237,6 +255,67 @@ version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" +[[package]] +name = "futures" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23342abe12aba583913b2e62f22225ff9c950774065e4bfb61a19cd9770fec40" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-channel" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "955518d47e09b25bbebc7a18df10b81f0c766eaf4c4f1cccef2fca5f2a4fb5f2" +dependencies = [ + "futures-core", + "futures-sink", +] + +[[package]] +name = "futures-core" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4bca583b7e26f571124fe5b7561d49cb2868d79116cfa0eefce955557c6fee8c" + +[[package]] +name = "futures-io" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fff74096e71ed47f8e023204cfd0aa1289cd54ae5430a9523be060cdb849964" + +[[package]] +name = "futures-sink" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f43be4fe21a13b9781a69afa4985b0f6ee0e1afab2c6f454a8cf30e2b2237b6e" + +[[package]] +name = "futures-task" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76d3d132be6c0e6aa1534069c705a74a5997a356c0dc2f86a47765e5617c5b65" + +[[package]] +name = "futures-util" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26b01e40b772d54cf6c6d721c1d1abd0647a0106a12ecaa1c186273392a69533" +dependencies = [ + 
"futures-core", + "futures-sink", + "futures-task", + "pin-project-lite", + "pin-utils", +] + [[package]] name = "goblin" version = "0.5.4" @@ -267,6 +346,36 @@ dependencies = [ "memoffset 0.5.6", ] +[[package]] +name = "io-uring" +version = "0.5.9" +dependencies = [ + "bitflags", + "libc", + "sgx_libc", + "sgx_trts", + "sgx_tstd", + "sgx_types", +] + +[[package]] +name = "io-uring-callback" +version = "0.1.0" +dependencies = [ + "atomic", + "cfg-if", + "futures", + "io-uring", + "lazy_static", + "libc", + "lock_api", + "log", + "sgx_libc", + "sgx_tstd", + "slab", + "spin 0.7.1", +] + [[package]] name = "itertools" version = "0.10.3" @@ -283,6 +392,10 @@ dependencies = [ "sgx_tstd", ] +[[package]] +name = "keyable-arc" +version = "0.1.0" + [[package]] name = "lazy_static" version = "1.4.0" @@ -298,6 +411,15 @@ version = "0.2.132" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8371e4e5341c3a96db127eb2465ac681ced4c433e01dd0e938adbef26ba93ba5" +[[package]] +name = "lock_api" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd96ffd135b2fd7b973ac026d28085defbe8983df057ced3eb4f2130b0831312" +dependencies = [ + "scopeguard", +] + [[package]] name = "log" version = "0.4.17" @@ -346,6 +468,38 @@ dependencies = [ "syn", ] +[[package]] +name = "num_enum" +version = "0.5.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f646caf906c20226733ed5b1374287eb97e3c2a5c227ce668c1f2ce20ae57c9" +dependencies = [ + "num_enum_derive", +] + +[[package]] +name = "num_enum_derive" +version = "0.5.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dcbff9bc912032c62bf65ef1d5aea88983b420f4f839db1e9b0c281a25c9c799" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "pin-project-lite" +version = "0.2.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58" + +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + [[package]] name = "plain" version = "0.2.3" @@ -601,6 +755,12 @@ version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4501abdff3ae82a1c1b477a17252eb69cee9e66eb915c1abaa4f44d873df9f09" +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + [[package]] name = "scroll" version = "0.11.0" @@ -648,6 +808,23 @@ dependencies = [ "sgx_tstd", ] +[[package]] +name = "sgx-untrusted-alloc" +version = "0.1.0" +dependencies = [ + "cfg-if", + "errno", + "intrusive-collections", + "lazy_static", + "libc", + "log", + "sgx_libc", + "sgx_trts", + "sgx_tstd", + "sgx_types", + "spin 0.7.1", +] + [[package]] name = "sgx_alloc" version = "1.1.6" @@ -753,6 +930,15 @@ dependencies = [ "sgx_build_helper", ] +[[package]] +name = "slab" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f92a496fb766b417c996b9c5e57daf2f7ad3b0bebe1ccfca4856390e3d3bb67" +dependencies = [ + "autocfg 1.1.0", +] + [[package]] name = "spin" version = "0.5.2" diff --git a/src/libos/src/entry.rs b/src/libos/src/entry.rs index 6a36a7b2..d8262ec1 100644 --- a/src/libos/src/entry.rs +++ b/src/libos/src/entry.rs @@ -7,6 +7,7 @@ use super::*; use crate::exception::*; use crate::fs::HostStdioFds; use crate::interrupt; +use crate::io_uring::ENABLE_URING; use crate::process::idle_reap_zombie_children; use crate::process::{ProcessFilter, SpawnAttr}; use crate::signal::SigNum; @@ -101,11 +102,14 @@ pub extern "C" fn occlum_ecall_init( vm::init_user_space(); + if ENABLE_URING.load(Ordering::Relaxed) { + 
crate::io_uring::MULTITON.poll_completions(); + } + // Register exception handlers (support cpuid & rdtsc for now) register_exception_handlers(); HAS_INIT.store(true, Ordering::Release); - // Enable global backtrace unsafe { backtrace::enable_backtrace(&ENCLAVE_PATH, PrintFormat::Short) }; diff --git a/src/libos/src/io_uring.rs b/src/libos/src/io_uring.rs new file mode 100644 index 00000000..e629ad32 --- /dev/null +++ b/src/libos/src/io_uring.rs @@ -0,0 +1,169 @@ +use core::sync::atomic::{AtomicBool, AtomicU32, AtomicUsize}; +use std::{collections::HashMap, thread::current}; + +use crate::util::sync::Mutex; +use alloc::{sync::Arc, vec::Vec}; +use atomic::Ordering; +use io_uring_callback::{Builder, IoUring}; +use keyable_arc::KeyableArc; + +use crate::config::LIBOS_CONFIG; + +// The number of sockets to reach the network bandwidth threshold of one io_uring instance +const SOCKET_THRESHOLD_PER_URING: u32 = 1; + +lazy_static::lazy_static! { + pub static ref MULTITON: UringSet = { + let uring_set = UringSet::new(); + uring_set + }; + + pub static ref ENABLE_URING: AtomicBool = AtomicBool::new(LIBOS_CONFIG.feature.io_uring > 0); + + // Four uring instances are sufficient to reach the network bandwidth threshold of host kernel. 
+ pub static ref URING_LIMIT: AtomicUsize = { + let uring_limit = LIBOS_CONFIG.feature.io_uring; + assert!(uring_limit <= 16, "io_uring limit must not exceed 16"); + AtomicUsize::new(uring_limit as usize) + }; +} + +#[derive(Clone, Copy, Default)] +struct UringState { + registered_num: u32, + is_enable_poll: bool, // CQE polling thread +} + +impl UringState { + fn register_one_socket(&mut self) { + self.registered_num += 1; + } + + fn unregister_one_socket(&mut self) { + self.registered_num -= 1; + } + + fn enable_poll(&mut self, uring: Arc) { + if !self.is_enable_poll { + self.is_enable_poll = true; + std::thread::spawn(move || loop { + let min_complete = 1; + let polling_retries = 10000; + uring.poll_completions(min_complete, polling_retries); + }); + } + } +} + +pub struct UringSet { + urings: Mutex, UringState>>, + running_uring_num: AtomicU32, +} + +impl UringSet { + pub fn new() -> Self { + let urings = Mutex::new(HashMap::new()); + let running_uring_num = AtomicU32::new(0); + Self { + urings, + running_uring_num, + } + } + + pub fn poll_completions(&self) { + let mut guard = self.urings.lock(); + let uring_limit = URING_LIMIT.load(Ordering::Relaxed) as u32; + + for _ in 0..uring_limit { + let uring: KeyableArc = Arc::new( + Builder::new() + .setup_sqpoll(500 /* ms */) + .build(256) + .unwrap(), + ) + .into(); + let mut state = UringState::default(); + state.enable_poll(uring.clone().into()); + + guard.insert(uring.clone(), state); + self.running_uring_num.fetch_add(1, Ordering::Relaxed); + } + } + + pub fn get_uring(&self) -> Arc { + let mut map = self.urings.lock(); + let running_uring_num = self.running_uring_num.load(Ordering::Relaxed); + let uring_limit = URING_LIMIT.load(Ordering::Relaxed) as u32; + assert!(running_uring_num <= uring_limit); + + let init_stage = running_uring_num < uring_limit; + + // Construct an io_uring instance and initiate a polling thread + if init_stage { + let should_build_uring = { + // Sum registered socket + let 
total_socket_num = map + .values() + .fold(0, |acc, state| acc + state.registered_num) + + 1; + // Determine the number of available io_uring + let uring_num = (total_socket_num / SOCKET_THRESHOLD_PER_URING) + 1; + let existed_uring_num = self.running_uring_num.load(Ordering::Relaxed); + assert!(existed_uring_num <= uring_num); + existed_uring_num < uring_num + }; + + if should_build_uring { + let uring: KeyableArc = Arc::new( + Builder::new() + .setup_sqpoll(500 /* ms */) + .build(256) + .unwrap(), + ) + .into(); + let mut state = UringState::default(); + state.register_one_socket(); + state.enable_poll(uring.clone().into()); + + map.insert(uring.clone(), state); + self.running_uring_num.fetch_add(1, Ordering::Relaxed); + return uring.into(); + } + } + + // Link the file to the io_uring instance with the least load. + let (mut uring, mut state) = map + .iter_mut() + .min_by_key(|(_, &mut state)| state.registered_num) + .unwrap(); + + // Re-select io_uring instance with least task load + if !init_stage { + let min_registered_num = state.registered_num; + (uring, state) = map + .iter_mut() + .filter(|(_, state)| state.registered_num == min_registered_num) + .min_by_key(|(uring, _)| uring.task_load()) + .unwrap(); + } else { + // At the initial stage, without constructing additional io_uring instances, + // there exists a singular io_uring which has the minimum number of registered sockets. 
+ } + + // Update io_uring instance states + state.register_one_socket(); + assert!(state.is_enable_poll); + + uring.clone().into() + } + + pub fn disattach_uring(&self, fd: usize, uring: Arc) { + let uring: KeyableArc = uring.into(); + let mut map = self.urings.lock(); + let mut state = map.get_mut(&uring).unwrap(); + state.unregister_one_socket(); + drop(map); + + uring.disattach_fd(fd); + } +} diff --git a/src/libos/src/lib.rs b/src/libos/src/lib.rs index 54239950..83e2240f 100644 --- a/src/libos/src/lib.rs +++ b/src/libos/src/lib.rs @@ -28,6 +28,8 @@ #![feature(is_some_and)] // for edmm_api macro #![feature(linkage)] +#![feature(new_uninit)] +#![feature(raw_ref_op)] #[macro_use] extern crate alloc; @@ -66,7 +68,6 @@ extern crate intrusive_collections; extern crate itertools; extern crate modular_bitfield; extern crate resolv_conf; -extern crate vdso_time; use sgx_trts::libc; use sgx_types::*; @@ -82,15 +83,18 @@ mod prelude; #[macro_use] mod error; +#[macro_use] +mod net; + mod config; mod entry; mod events; mod exception; mod fs; mod interrupt; +mod io_uring; mod ipc; mod misc; -mod net; mod process; mod sched; mod signal; diff --git a/src/libos/src/net/mod.rs b/src/libos/src/net/mod.rs index b8518a15..04bb2172 100644 --- a/src/libos/src/net/mod.rs +++ b/src/libos/src/net/mod.rs @@ -7,12 +7,13 @@ pub use self::io_multiplexing::{ PollEventFlags, PollFd, THREAD_NOTIFIERS, }; pub use self::socket::{ - mmsghdr, msghdr, msghdr_mut, socketpair, unix_socket, AddressFamily, AsUnixSocket, FileFlags, - HostSocket, HostSocketType, HowToShut, Iovs, IovsMut, MsgHdr, MsgHdrFlags, MsgHdrMut, - RecvFlags, SendFlags, SliceAsLibcIovec, SockAddr, SocketType, UnixAddr, + socketpair, unix_socket, AsUnixSocket, Domain, HostSocket, HostSocketType, Iovs, IovsMut, + RawAddr, SliceAsLibcIovec, UnixAddr, }; pub use self::syscalls::*; mod io_multiplexing; -mod socket; +pub(crate) mod socket; mod syscalls; + +pub use self::syscalls::*; diff --git a/src/libos/src/net/socket/mod.rs 
b/src/libos/src/net/socket/mod.rs index f57d2fb0..37c72801 100644 --- a/src/libos/src/net/socket/mod.rs +++ b/src/libos/src/net/socket/mod.rs @@ -1,21 +1,15 @@ use super::*; -mod address_family; -mod flags; mod host; -mod iovs; -mod msg; -mod shutdown; -mod socket_address; -mod socket_type; +pub(crate) mod sockopt; mod unix; +pub(crate) mod uring; +pub(crate) mod util; -pub use self::address_family::AddressFamily; -pub use self::flags::{FileFlags, MsgHdrFlags, RecvFlags, SendFlags}; pub use self::host::{HostSocket, HostSocketType}; -pub use self::iovs::{Iovs, IovsMut, SliceAsLibcIovec}; -pub use self::msg::{mmsghdr, msghdr, msghdr_mut, CMessages, CmsgData, MsgHdr, MsgHdrMut}; -pub use self::shutdown::HowToShut; -pub use self::socket_address::SockAddr; -pub use self::socket_type::SocketType; -pub use self::unix::{socketpair, unix_socket, AsUnixSocket, UnixAddr}; +pub use self::unix::{socketpair, unix_socket, AsUnixSocket}; +pub use self::util::{ + Addr, AnyAddr, CMessages, CSockAddr, CmsgData, Domain, Iovs, IovsMut, Ipv4Addr, Ipv4SocketAddr, + Ipv6SocketAddr, MsgFlags, RawAddr, RecvFlags, SendFlags, Shutdown, SliceAsLibcIovec, + SocketProtocol, Type, UnixAddr, +}; diff --git a/src/libos/src/net/socket/uring/common/common.rs b/src/libos/src/net/socket/uring/common/common.rs new file mode 100644 index 00000000..864ee2cc --- /dev/null +++ b/src/libos/src/net/socket/uring/common/common.rs @@ -0,0 +1,241 @@ +use core::time::Duration; +use std::marker::PhantomData; +use std::sync::atomic::{AtomicBool, Ordering}; + +use super::Timeout; +use io_uring_callback::IoUring; + +use libc::ocall::getsockname as do_getsockname; +use libc::ocall::shutdown as do_shutdown; +use libc::ocall::socket as do_socket; +use libc::ocall::socketpair as do_socketpair; + +use crate::events::Pollee; +use crate::fs::{IoEvents, IoNotifier}; +use crate::net::socket::uring::runtime::Runtime; +use crate::prelude::*; + +/// The common parts of all stream sockets. 
+pub struct Common { + host_fd: FileDesc, + type_: Type, + nonblocking: AtomicBool, + is_closed: AtomicBool, + pollee: Pollee, + inner: Mutex>, + timeout: Mutex, + errno: Mutex>, + io_uring: Arc, + phantom_data: PhantomData<(A, R)>, +} + +impl Common { + pub fn new(type_: Type, nonblocking: bool, protocol: Option) -> Result { + let domain_c = A::domain() as libc::c_int; + let type_c = type_ as libc::c_int; + let protocol = protocol.unwrap_or(0) as libc::c_int; + let host_fd = try_libc!(do_socket(domain_c, type_c, protocol)) as FileDesc; + let nonblocking = AtomicBool::new(nonblocking); + let is_closed = AtomicBool::new(false); + let pollee = Pollee::new(IoEvents::empty()); + let inner = Mutex::new(Inner::new()); + let timeout = Mutex::new(Timeout::new()); + let io_uring = R::io_uring(); + let errno = Mutex::new(None); + Ok(Self { + host_fd, + type_, + nonblocking, + is_closed, + pollee, + inner, + timeout, + errno, + io_uring, + phantom_data: PhantomData, + }) + } + + pub fn new_pair(sock_type: Type, nonblocking: bool) -> Result<(Self, Self)> { + return_errno!(EINVAL, "Unix is unsupported"); + } + + pub fn with_host_fd(host_fd: FileDesc, type_: Type, nonblocking: bool) -> Self { + let nonblocking = AtomicBool::new(nonblocking); + let is_closed = AtomicBool::new(false); + let pollee = Pollee::new(IoEvents::empty()); + let inner = Mutex::new(Inner::new()); + let timeout = Mutex::new(Timeout::new()); + let io_uring = R::io_uring(); + let errno = Mutex::new(None); + Self { + host_fd, + type_, + nonblocking, + is_closed, + pollee, + inner, + timeout, + errno, + io_uring, + phantom_data: PhantomData, + } + } + + pub fn io_uring(&self) -> Arc { + self.io_uring.clone() + } + + pub fn host_fd(&self) -> FileDesc { + self.host_fd + } + + pub fn type_(&self) -> Type { + self.type_ + } + + pub fn nonblocking(&self) -> bool { + self.nonblocking.load(Ordering::Relaxed) + } + + pub fn set_nonblocking(&self, is_nonblocking: bool) { + self.nonblocking.store(is_nonblocking, 
Ordering::Relaxed) + } + + pub fn notifier(&self) -> &IoNotifier { + self.pollee.notifier() + } + + pub fn send_timeout(&self) -> Option { + self.timeout.lock().sender_timeout() + } + + pub fn recv_timeout(&self) -> Option { + self.timeout.lock().receiver_timeout() + } + + pub fn set_send_timeout(&self, timeout: Duration) { + self.timeout.lock().set_sender(timeout) + } + + pub fn set_recv_timeout(&self, timeout: Duration) { + self.timeout.lock().set_receiver(timeout) + } + + pub fn is_closed(&self) -> bool { + self.is_closed.load(Ordering::Relaxed) + } + + pub fn set_closed(&self) { + self.is_closed.store(true, Ordering::Relaxed) + } + + pub fn reset_closed(&self) { + self.is_closed.store(false, Ordering::Relaxed) + } + + pub fn pollee(&self) -> &Pollee { + &self.pollee + } + + #[allow(unused)] + pub fn addr(&self) -> Option { + let inner = self.inner.lock(); + inner.addr.clone() + } + + pub fn set_addr(&self, addr: &A) { + let mut inner = self.inner.lock(); + inner.addr = Some(addr.clone()) + } + + pub fn get_addr_from_host(&self) -> Result { + let mut c_addr: libc::sockaddr_storage = unsafe { std::mem::zeroed() }; + let mut c_addr_len = std::mem::size_of::() as u32; + try_libc!(do_getsockname( + self.host_fd as _, + &mut c_addr as *mut libc::sockaddr_storage as *mut _, + &mut c_addr_len as *mut _, + )); + A::from_c_storage(&c_addr, c_addr_len as _) + } + + pub fn peer_addr(&self) -> Option { + let inner = self.inner.lock(); + inner.peer_addr.clone() + } + + pub fn set_peer_addr(&self, peer_addr: &A) { + let mut inner = self.inner.lock(); + inner.peer_addr = Some(peer_addr.clone()); + } + + pub fn reset_peer_addr(&self) { + let mut inner = self.inner.lock(); + inner.peer_addr = None; + } + + // For getsockopt SO_ERROR command + pub fn errno(&self) -> Option { + let mut errno_option = self.errno.lock(); + errno_option.take() + } + + pub fn set_errno(&self, errno: Errno) { + let mut errno_option = self.errno.lock(); + *errno_option = Some(errno); + } + + pub fn 
host_shutdown(&self, how: Shutdown) -> Result<()> { + trace!("host shutdown: {:?}", how); + match how { + Shutdown::Write => { + try_libc!(do_shutdown(self.host_fd as _, libc::SHUT_WR)); + } + Shutdown::Read => { + try_libc!(do_shutdown(self.host_fd as _, libc::SHUT_RD)); + } + Shutdown::Both => { + try_libc!(do_shutdown(self.host_fd as _, libc::SHUT_RDWR)); + } + } + Ok(()) + } +} + +impl std::fmt::Debug for Common { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Common") + .field("host_fd", &self.host_fd) + .field("type", &self.type_) + .field("nonblocking", &self.nonblocking) + .field("pollee", &self.pollee) + .field("inner", &self.inner.lock()) + .finish() + } +} + +impl Drop for Common { + fn drop(&mut self) { + if let Err(e) = super::do_close(self.host_fd) { + log::error!("do_close failed, host_fd: {}, err: {:?}", self.host_fd, e); + } + + R::disattach_io_uring(self.host_fd as usize, self.io_uring()) + } +} + +#[derive(Debug)] +struct Inner { + addr: Option, + peer_addr: Option, +} + +impl Inner { + pub fn new() -> Self { + Self { + addr: None, + peer_addr: None, + } + } +} diff --git a/src/libos/src/net/socket/uring/common/mod.rs b/src/libos/src/net/socket/uring/common/mod.rs new file mode 100644 index 00000000..1af5cbcd --- /dev/null +++ b/src/libos/src/net/socket/uring/common/mod.rs @@ -0,0 +1,7 @@ +mod common; +mod operation; +mod timeout; + +pub use self::common::Common; +pub use self::operation::{do_bind, do_close, do_connect, do_unlink}; +pub use self::timeout::Timeout; diff --git a/src/libos/src/net/socket/uring/common/operation.rs b/src/libos/src/net/socket/uring/common/operation.rs new file mode 100644 index 00000000..71032bbe --- /dev/null +++ b/src/libos/src/net/socket/uring/common/operation.rs @@ -0,0 +1,44 @@ +use std::ffi::CString; +use std::mem::{self, MaybeUninit}; + +use crate::prelude::*; + +pub fn do_bind(host_fd: FileDesc, addr: &A) -> Result<()> { + let fd = host_fd as i32; + let (c_addr_storage, 
c_addr_len) = addr.to_c_storage(); + let c_addr_ptr = &c_addr_storage as *const _ as _; + let c_addr_len = c_addr_len as u32; + try_libc!(libc::ocall::bind(fd, c_addr_ptr, c_addr_len)); + Ok(()) +} + +pub fn do_close(host_fd: FileDesc) -> Result<()> { + let fd = host_fd as i32; + try_libc!(libc::ocall::close(fd)); + Ok(()) +} + +pub fn do_unlink(path: &String) -> Result<()> { + let c_string = + CString::new(path.as_bytes()).map_err(|_| errno!(EINVAL, "cstring new failure"))?; + let c_path = c_string.as_c_str().as_ptr(); + try_libc!(libc::ocall::unlink(c_path)); + Ok(()) +} + +pub fn do_connect(host_fd: FileDesc, addr: Option<&A>) -> Result<()> { + let fd = host_fd as i32; + let (c_addr_storage, c_addr_len) = match addr { + Some(addr_inner) => addr_inner.to_c_storage(), + None => { + let mut sockaddr_storage = + unsafe { MaybeUninit::::uninit().assume_init() }; + sockaddr_storage.ss_family = libc::AF_UNSPEC as _; + (sockaddr_storage, mem::size_of::()) + } + }; + let c_addr_ptr = &c_addr_storage as *const _ as _; + let c_addr_len = c_addr_len as u32; + try_libc!(libc::ocall::connect(fd, c_addr_ptr, c_addr_len)); + Ok(()) +} diff --git a/src/libos/src/net/socket/uring/common/timeout.rs b/src/libos/src/net/socket/uring/common/timeout.rs new file mode 100644 index 00000000..6df904a0 --- /dev/null +++ b/src/libos/src/net/socket/uring/common/timeout.rs @@ -0,0 +1,32 @@ +use std::time::Duration; + +#[derive(Clone, Debug)] +pub struct Timeout { + sender: Option, + receiver: Option, +} + +impl Timeout { + pub fn new() -> Self { + Self { + sender: None, + receiver: None, + } + } + + pub fn sender_timeout(&self) -> Option { + self.sender + } + + pub fn receiver_timeout(&self) -> Option { + self.receiver + } + + pub fn set_sender(&mut self, timeout: Duration) { + self.sender = Some(timeout); + } + + pub fn set_receiver(&mut self, timeout: Duration) { + self.receiver = Some(timeout); + } +} diff --git a/src/libos/src/net/socket/uring/datagram/generic.rs 
b/src/libos/src/net/socket/uring/datagram/generic.rs new file mode 100644 index 00000000..f9b5d3af --- /dev/null +++ b/src/libos/src/net/socket/uring/datagram/generic.rs @@ -0,0 +1,494 @@ +use core::time::Duration; + +use crate::{ + events::{Observer, Poller}, + fs::{IoNotifier, StatusFlags}, + match_ioctl_cmd_mut, + net::socket::MsgFlags, +}; + +use super::*; +use crate::fs::IoEvents as Events; +use crate::fs::{GetIfConf, GetIfReqWithRawCmd, GetReadBufLen, IoctlCmd}; + +pub struct DatagramSocket { + common: Arc>, + state: RwLock, + sender: Arc>, + receiver: Arc>, +} + +impl DatagramSocket { + pub fn new(nonblocking: bool) -> Result { + let common = Arc::new(Common::new(Type::DGRAM, nonblocking, None)?); + let state = RwLock::new(State::new()); + let sender = Sender::new(common.clone()); + let receiver = Receiver::new(common.clone()); + Ok(Self { + common, + state, + sender, + receiver, + }) + } + + pub fn new_pair(nonblocking: bool) -> Result<(Self, Self)> { + let (common1, common2) = Common::new_pair(Type::DGRAM, nonblocking)?; + let socket1 = Self::new_connected(common1); + let socket2 = Self::new_connected(common2); + Ok((socket1, socket2)) + } + + fn new_connected(common: Common) -> Self { + let common = Arc::new(common); + let state = RwLock::new(State::new_connected()); + let sender = Sender::new(common.clone()); + let receiver = Receiver::new(common.clone()); + receiver.initiate_async_recv(); + Self { + common, + state, + sender, + receiver, + } + } + + pub fn domain(&self) -> Domain { + A::domain() + } + + pub fn host_fd(&self) -> FileDesc { + self.common.host_fd() + } + + pub fn status_flags(&self) -> StatusFlags { + // Only support O_NONBLOCK + if self.common.nonblocking() { + StatusFlags::O_NONBLOCK + } else { + StatusFlags::empty() + } + } + + pub fn set_status_flags(&self, new_flags: StatusFlags) -> Result<()> { + // Only support O_NONBLOCK + let nonblocking = new_flags.is_nonblocking(); + self.common.set_nonblocking(nonblocking); + Ok(()) + } + + /// 
When creating a datagram socket, you can use `bind` to bind the socket + /// to a address, hence another socket can send data to this address. + /// + /// Binding is divided into explicit and implicit. Invoking `bind` is + /// explicit binding, while invoking `sendto` / `sendmsg` / `connect` + /// will trigger implicit binding. + /// + /// Datagram sockets can only bind once. You should use explicit binding or + /// just implicit binding. The explicit binding will failed if it happens after + /// a implicit binding. + pub fn bind(&self, addr: &A) -> Result<()> { + let mut state = self.state.write().unwrap(); + if state.is_bound() { + return_errno!(EINVAL, "The socket is already bound to an address"); + } + + do_bind(self.host_fd(), addr)?; + + self.common.set_addr(addr); + state.mark_explicit_bind(); + // Start async recv after explicit binding or implicit binding + self.receiver.initiate_async_recv(); + + Ok(()) + } + + /// Datagram sockets provide only connectionless interactions, But datagram sockets + /// can also use connect to associate a socket with a specific address. + /// After connection, any data sent on the socket is automatically addressed to the + /// connected peer, and only data received from that peer is delivered to the user. + /// + /// Unlike stream sockets, datagram sockets can connect multiple times. But the socket + /// can only connect to one peer in the same time; a second connect will change the + /// peer address, and a connect to a address with family AF_UNSPEC will dissolve the + /// association ("disconnect" or "unconnect"). + /// + /// Before connection you can only use `sendto` / `sendmsg` / `recvfrom` / `recvmsg`. + /// Only after connection, you can use `read` / `recv` / `write` / `send`. + /// And you can ignore the address in `sendto` / `sendmsg` if you just want to + /// send data to the connected peer. 
+ /// + /// Ref 1: http://osr507doc.xinuos.com/en/netguide/disockD.connecting_datagrams.html + /// Ref 2: https://www.masterraghu.com/subjects/np/introduction/unix_network_programming_v1.3/ch08lev1sec11.html + pub fn connect(&self, peer_addr: Option<&A>) -> Result<()> { + let mut state = self.state.write().unwrap(); + + // if previous peer.is_default() and peer_addr.is_none() + // is unspec, so the situation exists that both + // !state.is_connected() and peer_addr.is_none() are true. + + if let Some(peer) = peer_addr { + do_connect(self.host_fd(), Some(peer))?; + + self.receiver.reset_shutdown(); + self.sender.reset_shutdown(); + self.common.set_peer_addr(peer); + + if peer.is_default() { + state.mark_disconnected(); + } else { + state.mark_connected(); + } + if !state.is_bound() { + state.mark_implicit_bind(); + // Start async recv after explicit binding or implicit binding + self.receiver.initiate_async_recv(); + } + + // TODO: update binding address in some cases + // For a ipv4 socket bound to 0.0.0.0 (INADDR_ANY), if you do connection + // to 127.0.0.1 (Local IP address), the IP address of the socket will + // change to 127.0.0.1 too. And if connect to non-local IP address, linux + // will assign a address to the socket. + // In both cases, we should update the binding address that we stored. + } else { + do_connect::(self.host_fd(), None)?; + + self.common.reset_peer_addr(); + state.mark_disconnected(); + + // TODO: clear binding in some cases. + // Disconnect will effect the binding address. In Linux, for socket that + // explicit bound to local IP address, disconnect will clear the binding address, + // but leave the port intact. For socket with implicit bound, disconnect will + // clear both the address and port. 
+ } + Ok(()) + } + + // Close the datagram socket, cancel pending iouring requests + pub fn close(&self) -> Result<()> { + self.sender.shutdown(); + self.receiver.shutdown(); + self.common.set_closed(); + self.cancel_requests(); + Ok(()) + } + + /// Shutdown the udp socket. This syscall is very TCP-oriented, but it is also useful for udp socket. + /// Unlike tcp, shutdown does nothing on the wire, it only changes shutdown states. + /// The shutdown states block the io-uring request of receiving or sending message. + pub fn shutdown(&self, how: Shutdown) -> Result<()> { + let state = self.state.read().unwrap(); + if !state.is_connected() { + return_errno!(ENOTCONN, "The udp socket is not connected"); + } + drop(state); + match how { + Shutdown::Read => { + self.common.host_shutdown(how)?; + self.receiver.shutdown(); + self.common.pollee().add_events(Events::IN); + } + Shutdown::Write => { + if self.sender.is_empty() { + self.common.host_shutdown(how)?; + } + self.sender.shutdown(); + self.common.pollee().add_events(Events::OUT); + } + Shutdown::Both => { + self.common.host_shutdown(Shutdown::Read)?; + if self.sender.is_empty() { + self.common.host_shutdown(Shutdown::Write)?; + } + self.receiver.shutdown(); + self.sender.shutdown(); + self.common + .pollee() + .add_events(Events::IN | Events::OUT | Events::HUP); + } + } + Ok(()) + } + + pub fn read(&self, buf: &mut [u8]) -> Result { + self.readv(&mut [buf]) + } + + pub fn readv(&self, bufs: &mut [&mut [u8]]) -> Result { + let state = self.state.read().unwrap(); + drop(state); + + self.recvmsg(bufs, RecvFlags::empty(), None) + .map(|(ret, ..)| ret) + } + + /// You can not invoke `recvfrom` directly after creating a datagram socket. + /// That is because `recvfrom` doesn't provide an implicit binding. If you + /// don't do an explicit or implicit binding, the sender doesn't know where + /// to send the data.
+ pub fn recvmsg( + &self, + bufs: &mut [&mut [u8]], + flags: RecvFlags, + control: Option<&mut [u8]>, + ) -> Result<(usize, Option, MsgFlags, usize)> { + self.receiver.recvmsg(bufs, flags, control) + } + + pub fn write(&self, buf: &[u8]) -> Result { + self.writev(&[buf]) + } + + pub fn writev(&self, bufs: &[&[u8]]) -> Result { + self.sendmsg(bufs, None, SendFlags::empty(), None) + } + + pub fn sendmsg( + &self, + bufs: &[&[u8]], + addr: Option<&A>, + flags: SendFlags, + control: Option<&[u8]>, + ) -> Result { + let state = self.state.read().unwrap(); + if addr.is_none() && !state.is_connected() { + return_errno!(EDESTADDRREQ, "Destination address required"); + } + + drop(state); + let res = if let Some(addr) = addr { + self.sender.sendmsg(bufs, addr, flags, control) + } else { + let peer = self.common.peer_addr(); + if let Some(peer) = peer.as_ref() { + self.sender.sendmsg(bufs, peer, flags, control) + } else { + return_errno!(EDESTADDRREQ, "Destination address required"); + } + }; + + let mut state = self.state.write().unwrap(); + if !state.is_bound() { + state.mark_implicit_bind(); + // Start async recv after explicit binding or implicit binding + self.receiver.initiate_async_recv(); + } + + res + } + + pub fn poll(&self, mask: Events, poller: Option<&Poller>) -> Events { + let pollee = self.common.pollee(); + pollee.poll(mask, poller) + } + + pub fn addr(&self) -> Result { + let common = &self.common; + + // Always get addr from host. + // Because for IP socket, users can specify "0" as port and the kernel should select a usable port for him. + // Thus, when calling getsockname, this should be updated. 
+ let addr = common.get_addr_from_host()?; + common.set_addr(&addr); + Ok(addr) + } + + pub fn notifier(&self) -> &IoNotifier { + let notifier = self.common.notifier(); + notifier + } + + pub fn peer_addr(&self) -> Result { + let state = self.state.read().unwrap(); + if !state.is_connected() { + return_errno!(ENOTCONN, "the socket is not connected"); + } + Ok(self.common.peer_addr().unwrap()) + } + + pub fn errno(&self) -> Option { + self.common.errno() + } + + pub fn ioctl(&self, cmd: &mut dyn IoctlCmd) -> Result<()> { + match_ioctl_cmd_mut!(&mut *cmd, { + cmd: GetSockOptRawCmd => { + cmd.execute(self.host_fd())?; + }, + cmd: SetSockOptRawCmd => { + cmd.execute(self.host_fd())?; + }, + cmd: SetRecvTimeoutCmd => { + self.set_recv_timeout(*cmd.timeout()); + }, + cmd: SetSendTimeoutCmd => { + self.set_send_timeout(*cmd.timeout()); + }, + cmd: GetRecvTimeoutCmd => { + let timeval = timeout_to_timeval(self.recv_timeout()); + cmd.set_output(timeval); + }, + cmd: GetSendTimeoutCmd => { + let timeval = timeout_to_timeval(self.send_timeout()); + cmd.set_output(timeval); + }, + cmd: GetAcceptConnCmd => { + // Datagram doesn't support listen + cmd.set_output(0); + }, + cmd: GetDomainCmd => { + cmd.set_output(self.domain() as _); + }, + cmd: GetErrorCmd => { + let error: i32 = self.errno().map(|err| err as i32).unwrap_or(0); + cmd.set_output(error); + }, + cmd: GetPeerNameCmd => { + let peer = self.peer_addr()?; + cmd.set_output(AddrStorage(peer.to_c_storage())); + }, + cmd: GetTypeCmd => { + cmd.set_output(self.common.type_() as _); + }, + cmd: GetIfReqWithRawCmd => { + cmd.execute(self.host_fd())?; + }, + cmd: GetIfConf => { + cmd.execute(self.host_fd())?; + }, + cmd: GetReadBufLen => { + let read_buf_len = self.receiver.ready_len(); + cmd.set_output(read_buf_len as _); + }, + _ => { + return_errno!(EINVAL, "Not supported yet"); + } + }); + Ok(()) + } + + fn send_timeout(&self) -> Option { + self.common.send_timeout() + } + + fn recv_timeout(&self) -> Option { + 
self.common.recv_timeout() + } + + fn set_send_timeout(&self, timeout: Duration) { + self.common.set_send_timeout(timeout); + } + + fn set_recv_timeout(&self, timeout: Duration) { + self.common.set_recv_timeout(timeout); + } + + fn cancel_requests(&self) { + self.receiver.cancel_recv_requests(); + self.sender.try_clear_msg_queue_when_close(); + } +} + +impl Drop for DatagramSocket { + fn drop(&mut self) { + self.common.set_closed(); + } +} + +impl std::fmt::Debug for DatagramSocket { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("DatagramSocket") + .field("common", &self.common) + .finish() + } +} + +#[derive(Debug)] +struct State { + bind_state: BindState, + is_connected: bool, +} + +impl State { + pub fn new() -> Self { + Self { + bind_state: BindState::Unbound, + is_connected: false, + } + } + + pub fn new_connected() -> Self { + Self { + bind_state: BindState::Unbound, + is_connected: true, + } + } + + pub fn is_bound(&self) -> bool { + self.bind_state.is_bound() + } + + #[allow(dead_code)] + pub fn is_explicit_bound(&self) -> bool { + self.bind_state.is_explicit_bound() + } + + #[allow(dead_code)] + pub fn is_implicit_bound(&self) -> bool { + self.bind_state.is_implicit_bound() + } + + pub fn is_connected(&self) -> bool { + self.is_connected + } + + pub fn mark_explicit_bind(&mut self) { + self.bind_state = BindState::ExplicitBound; + } + + pub fn mark_implicit_bind(&mut self) { + self.bind_state = BindState::ImplicitBound; + } + + pub fn mark_connected(&mut self) { + self.is_connected = true; + } + + pub fn mark_disconnected(&mut self) { + self.is_connected = false; + } +} + +#[derive(Debug)] +enum BindState { + Unbound, + ExplicitBound, + ImplicitBound, +} + +impl BindState { + pub fn is_bound(&self) -> bool { + match self { + Self::Unbound => false, + _ => true, + } + } + + #[allow(dead_code)] + pub fn is_explicit_bound(&self) -> bool { + match self { + Self::ExplicitBound => true, + _ => false, + } + } + + 
#[allow(dead_code)] + pub fn is_implicit_bound(&self) -> bool { + match self { + Self::ImplicitBound => true, + _ => false, + } + } +} diff --git a/src/libos/src/net/socket/uring/datagram/mod.rs b/src/libos/src/net/socket/uring/datagram/mod.rs new file mode 100644 index 00000000..6ccb9d5c --- /dev/null +++ b/src/libos/src/net/socket/uring/datagram/mod.rs @@ -0,0 +1,20 @@ +//! Datagram sockets. +mod generic; +mod receiver; +mod sender; + +use self::receiver::Receiver; +use self::sender::Sender; +use crate::net::socket::sockopt::*; +use crate::net::socket::uring::common::{do_bind, do_connect, Common}; +use crate::net::socket::uring::runtime::Runtime; +use crate::prelude::*; + +pub use generic::DatagramSocket; + +use crate::net::socket::sockopt::{ + timeout_to_timeval, GetRecvTimeoutCmd, GetSendTimeoutCmd, SetRecvTimeoutCmd, SetSendTimeoutCmd, +}; + +const MAX_BUF_SIZE: usize = 64 * 1024; +const OPTMEM_MAX: usize = 64 * 1024; diff --git a/src/libos/src/net/socket/uring/datagram/receiver.rs b/src/libos/src/net/socket/uring/datagram/receiver.rs new file mode 100644 index 00000000..2ade1ef7 --- /dev/null +++ b/src/libos/src/net/socket/uring/datagram/receiver.rs @@ -0,0 +1,382 @@ +use core::time::Duration; +use std::mem::MaybeUninit; + +use crate::events::Poller; +use crate::net::socket::MsgFlags; +use io_uring_callback::{Fd, IoHandle}; +use sgx_untrusted_alloc::{MaybeUntrusted, UntrustedBox}; + +use crate::fs::IoEvents as Events; +use crate::net::socket::uring::common::Common; +use crate::net::socket::uring::runtime::Runtime; +use crate::prelude::*; + +pub struct Receiver { + common: Arc>, + inner: Mutex, +} + +impl Receiver { + pub fn new(common: Arc>) -> Arc { + let inner = Mutex::new(Inner::new()); + Arc::new(Self { common, inner }) + } + + pub fn recvmsg( + self: &Arc, + bufs: &mut [&mut [u8]], + flags: RecvFlags, + mut control: Option<&mut [u8]>, + ) -> Result<(usize, Option, MsgFlags, usize)> { + let mask = Events::IN; + // Initialize the poller only when needed + 
let mut poller = None; + let mut timeout = self.common.recv_timeout(); + loop { + // Attempt to recv + let res = self.try_recvmsg(bufs, flags, &mut control); + if !res.has_errno(EAGAIN) { + return res; + } + + // Need more handles for flags not MSG_DONTWAIT + // recv*(MSG_ERRQUEUE) never blocks, even without MSG_DONTWAIT + if self.common.nonblocking() + || flags.contains(RecvFlags::MSG_DONTWAIT) + || flags.contains(RecvFlags::MSG_ERRQUEUE) + { + return_errno!(EAGAIN, "no data are present to be received"); + } + + // Wait for interesting events by polling + if poller.is_none() { + let new_poller = Poller::new(); + self.common.pollee().connect_poller(mask, &new_poller); + poller = Some(new_poller); + } + let events = self.common.pollee().poll(mask, None); + if events.is_empty() { + let ret = poller.as_ref().unwrap().wait_timeout(timeout.as_mut()); + if let Err(e) = ret { + warn!("recv wait errno = {:?}", e.errno()); + match e.errno() { + ETIMEDOUT => { + return_errno!(EAGAIN, "timeout reached") + } + _ => { + return_errno!(e.errno(), "wait error") + } + } + } + } + } + } + + fn try_recvmsg( + self: &Arc, + bufs: &mut [&mut [u8]], + flags: RecvFlags, + control: &mut Option<&mut [u8]>, + ) -> Result<(usize, Option, MsgFlags, usize)> { + let mut inner = self.inner.lock(); + + if !flags.is_empty() && flags.contains(RecvFlags::MSG_OOB | RecvFlags::MSG_CMSG_CLOEXEC) { + // todo!("Support other flags"); + return_errno!(EINVAL, "the socket flags is not supported"); + } + + // Mark the socket as non-readable since Datagram uses single packet + self.common.pollee().del_events(Events::IN); + + let mut recv_bytes = 0; + let mut msg_flags = MsgFlags::empty(); + let recv_addr = inner.get_packet_addr(); + let msg_controllen = inner.control_len.unwrap_or(0); + let user_controllen = control.as_ref().map_or(0, |buf| buf.len()); + + // Copy ancillary data from control buffer + if user_controllen > super::OPTMEM_MAX { + return_errno!(EINVAL, "invalid msg control length"); + } + + if 
user_controllen < msg_controllen { + msg_flags = msg_flags | MsgFlags::MSG_CTRUNC + } + + if msg_controllen > 0 { + let copied_bytes = msg_controllen.min(user_controllen); + control + .as_mut() + .map(|buf| buf[..copied_bytes].copy_from_slice(&inner.msg_control[..copied_bytes])); + } + + // Copy data from the recv buffer to the bufs + let copied_bytes = inner.try_copy_buf(bufs); + if let Some(copied_bytes) = copied_bytes { + let bufs_len: usize = bufs.iter().map(|buf| buf.len()).sum(); + + // If user provided buffer length is smaller than kernel received datagram length, + // discard the datagram and set MsgFlags::MSG_TRUNC in returned msg_flags. + if bufs_len < inner.recv_len().unwrap() { + // update msg.msg_flags to MSG_TRUNC + msg_flags = msg_flags | MsgFlags::MSG_TRUNC + }; + + // If user provided flags contain MSG_TRUNC, the return received length should be + // kernel receiver buffer length, vice versa should return truly copied bytes length. + recv_bytes = if flags.contains(RecvFlags::MSG_TRUNC) { + inner.recv_len().unwrap() + } else { + copied_bytes + }; + + // When flags contain MSG_PEEK and there is data in socket recv buffer, + // it is unnecessary to send blocking recv request (do_recv) to fetch data + // from iouring buffer, which may flush the data in recv buffer. + // When flags don't contain MSG_PEEK or there is no available data, + // it is time to send blocking request to iouring for notifying events. + if !flags.contains(RecvFlags::MSG_PEEK) { + self.do_recv(&mut inner); + } + return Ok((recv_bytes, recv_addr, msg_flags, msg_controllen)); + } + + // In some situantions of MSG_ERRQUEUE, users only require control buffer but setting iovec length to zero. 
+ if msg_controllen > 0 { + return Ok((recv_bytes, recv_addr, msg_flags, msg_controllen)); + } + + // Handle iouring message error + if let Some(errno) = inner.error { + // Reset error + inner.error = None; + self.common.pollee().del_events(Events::ERR); + return_errno!(errno, "recv failed"); + } + + if inner.is_shutdown { + if self.common.nonblocking() + || flags.contains(RecvFlags::MSG_DONTWAIT) + || flags.contains(RecvFlags::MSG_ERRQUEUE) + { + return_errno!(Errno::EWOULDBLOCK, "the socket recv has been shutdown"); + } else { + return Ok((0, None, msg_flags, 0)); + } + } + + self.do_recv(&mut inner); + return_errno!(EAGAIN, "try recv again"); + } + + fn do_recv(self: &Arc, inner: &mut MutexGuard) { + if inner.io_handle.is_some() || self.common.is_closed() { + return; + } + // Clear recv_len and error + inner.recv_len.take(); + inner.control_len.take(); + inner.error.take(); + + if inner.is_shutdown { + info!("do_recv early return, the socket recv has been shutdown"); + return; + } + + let receiver = self.clone(); + // Init the callback invoked upon the completion of the async recv + let complete_fn = move |retval: i32| { + let mut inner = receiver.inner.lock(); + + // Release the handle to the async recv + inner.io_handle.take(); + + // Handle error + if retval < 0 { + // TODO: Should we filter the error case? Do we have the ability to filter? + // We only filter the normal case now. According to the man page of recvmsg, + // these errors should not happen, since our fd and arguments should always + // be valid unless being attacked. 
+ + // TODO: guard against Iago attack through errno + let errno = Errno::from(-retval as u32); + inner.error = Some(errno); + receiver.common.set_errno(errno); + // TODO: add PRI event if set SO_SELECT_ERR_QUEUE + receiver.common.pollee().add_events(Events::ERR); + return; + } + + // Handle the normal case of a successful read + inner.recv_len = Some(retval as usize); + + let control_len = inner.req.msg.msg_controllen; + inner.control_len = Some(control_len); + + receiver.common.pollee().add_events(Events::IN); + + // We don't do_recv() here, since do_recv() will clear the recv message. + }; + + // Generate the async recv request + let msghdr_ptr = inner.new_recv_req(); + + // Submit the async recv to io_uring + let io_uring = self.common.io_uring(); + let host_fd = Fd(self.common.host_fd() as _); + let handle = unsafe { io_uring.recvmsg(host_fd, msghdr_ptr, 0, complete_fn) }; + inner.io_handle.replace(handle); + } + + pub fn initiate_async_recv(self: &Arc) { + let mut inner = self.inner.lock(); + self.do_recv(&mut inner); + } + + pub fn cancel_recv_requests(&self) { + { + let inner = self.inner.lock(); + if let Some(io_handle) = &inner.io_handle { + let io_uring = self.common.io_uring(); + unsafe { io_uring.cancel(io_handle) }; + } else { + return; + } + } + + // wait for the cancel to complete + let poller = Poller::new(); + let mask = Events::ERR | Events::IN; + self.common.pollee().connect_poller(mask, &poller); + + loop { + let pending_request_exist = { + let inner = self.inner.lock(); + inner.io_handle.is_some() + }; + + if pending_request_exist { + let mut timeout = Some(Duration::from_secs(10)); + let ret = poller.wait_timeout(timeout.as_mut()); + if let Err(e) = ret { + warn!("wait cancel recv request error = {:?}", e.errno()); + continue; + } + } else { + break; + } + } + } + + /// Shutdown udp receiver. + pub fn shutdown(&self) { + let mut inner = self.inner.lock(); + inner.is_shutdown = true; + } + + /// Reset udp receiver shutdown state. 
+ pub fn reset_shutdown(&self) { + let mut inner = self.inner.lock(); + inner.is_shutdown = false; + } + + pub fn ready_len(&self) -> usize { + let inner = self.inner.lock(); + inner.recv_len().unwrap_or(0) + } +} + +struct Inner { + recv_buf: UntrustedBox<[u8]>, + // Datagram sockets in various domains permit zero-length datagrams. + // Hence the recv_len might be 0. + recv_len: Option, + // When the recv_buf content length is greater than user buffer, + // store the offset for the recv_buf for read loop + recv_buf_offset: usize, + msg_control: UntrustedBox<[u8]>, + control_len: Option, + req: UntrustedBox, + io_handle: Option, + error: Option, + is_shutdown: bool, +} + +unsafe impl Send for Inner {} + +impl Inner { + pub fn new() -> Self { + Self { + recv_buf: UntrustedBox::new_uninit_slice(super::MAX_BUF_SIZE), + recv_len: None, + recv_buf_offset: 0, + msg_control: UntrustedBox::new_uninit_slice(super::OPTMEM_MAX), + control_len: None, + req: UntrustedBox::new_uninit(), + io_handle: None, + error: None, + is_shutdown: false, + } + } + + pub fn new_recv_req(&mut self) -> *mut libc::msghdr { + let iovec = libc::iovec { + iov_base: self.recv_buf.as_mut_ptr() as _, + iov_len: self.recv_buf.len(), + }; + + let msghdr_ptr = &raw mut self.req.msg; + + let mut msg: libc::msghdr = unsafe { MaybeUninit::zeroed().assume_init() }; + msg.msg_iov = &raw mut self.req.iovec as _; + msg.msg_iovlen = 1; + msg.msg_name = &raw mut self.req.addr as _; + msg.msg_namelen = std::mem::size_of::() as _; + + msg.msg_control = self.msg_control.as_mut_ptr() as _; + msg.msg_controllen = self.msg_control.len() as _; + + self.req.msg = msg; + self.req.iovec = iovec; + + msghdr_ptr + } + + pub fn try_copy_buf(&self, bufs: &mut [&mut [u8]]) -> Option { + self.recv_len.map(|recv_len| { + let mut copy_len = 0; + for buf in bufs { + let recv_buf = &self.recv_buf[copy_len..recv_len]; + if buf.len() <= recv_buf.len() { + buf.copy_from_slice(&recv_buf[..buf.len()]); + copy_len += buf.len(); + } else { 
+ buf[..recv_buf.len()].copy_from_slice(&recv_buf[..]); + copy_len += recv_buf.len(); + break; + } + } + copy_len + }) + } + + pub fn recv_len(&self) -> Option { + self.recv_len + } + + /// Return the addr of the received packet if udp socket is not connected. + /// Return None if udp socket is connected. + pub fn get_packet_addr(&self) -> Option { + let recv_addr_len = self.req.msg.msg_namelen as usize; + A::from_c_storage(&self.req.addr, recv_addr_len).ok() + } +} + +#[repr(C)] +struct RecvReq { + msg: libc::msghdr, + iovec: libc::iovec, + addr: libc::sockaddr_storage, +} + +unsafe impl MaybeUntrusted for RecvReq {} diff --git a/src/libos/src/net/socket/uring/datagram/sender.rs b/src/libos/src/net/socket/uring/datagram/sender.rs new file mode 100644 index 00000000..a21daa2a --- /dev/null +++ b/src/libos/src/net/socket/uring/datagram/sender.rs @@ -0,0 +1,406 @@ +use core::time::Duration; +use std::ptr::{self}; + +use io_uring_callback::{Fd, IoHandle}; +use libc::c_void; +use sgx_untrusted_alloc::{MaybeUntrusted, UntrustedBox}; +use std::collections::VecDeque; + +use crate::events::Poller; +use crate::fs::IoEvents as Events; +use crate::net::socket::uring::common::Common; +use crate::net::socket::uring::runtime::Runtime; +use crate::prelude::*; +use crate::util::sync::MutexGuard; + +const SENDMSG_QUEUE_LEN: usize = 16; + +pub struct Sender { + common: Arc>, + inner: Mutex, +} + +impl Sender { + pub fn new(common: Arc>) -> Arc { + common.pollee().add_events(Events::OUT); + let inner = Mutex::new(Inner::new()); + Arc::new(Self { common, inner }) + } + + /// Shutdown udp sender. + pub fn shutdown(&self) { + let mut inner = self.inner.lock(); + inner.is_shutdown = ShutdownStatus::PreShutdown; + } + + /// Reset udp sender shutdown state. + pub fn reset_shutdown(&self) { + let mut inner = self.inner.lock(); + inner.is_shutdown = ShutdownStatus::Running; + } + + /// Whether no buffer in sender. 
+ pub fn is_empty(&self) -> bool { + let inner = self.inner.lock(); + inner.msg_queue.is_empty() + } + + // Normally, We will always try to send as long as the kernel send buf is not empty. + // However, if the user calls close, we will wait LINGER time + // and then cancel on-going or new-issued send requests. + pub fn try_clear_msg_queue_when_close(&self) { + let inner = self.inner.lock(); + debug_assert!(inner.is_shutdown()); + if inner.msg_queue.is_empty() { + return; + } + + // Wait for linger time to empty the kernel buffer or cancel subsequent requests. + drop(inner); + const DEFUALT_LINGER_TIME: usize = 10; + let poller = Poller::new(); + let mask = Events::ERR | Events::OUT; + self.common.pollee().connect_poller(mask, &poller); + + loop { + let pending_request_exist = { + let inner = self.inner.lock(); + inner.io_handle.is_some() + }; + + if pending_request_exist { + let mut timeout = Some(Duration::from_secs(DEFUALT_LINGER_TIME as u64)); + let ret = poller.wait_timeout(timeout.as_mut()); + trace!("wait empty send buffer ret = {:?}", ret); + if let Err(_) = ret { + // No complete request to wake. Just cancel the send requests. 
+ let io_uring = self.common.io_uring(); + let inner = self.inner.lock(); + if let Some(io_handle) = &inner.io_handle { + unsafe { io_uring.cancel(io_handle) }; + // Loop again to wait the cancel request to complete + continue; + } else { + // No pending request, just break + break; + } + } + } else { + // There is no pending requests + break; + } + } + } + + pub fn sendmsg( + self: &Arc, + bufs: &[&[u8]], + addr: &A, + flags: SendFlags, + control: Option<&[u8]>, + ) -> Result { + if !flags.is_empty() + && flags.intersects(!(SendFlags::MSG_DONTWAIT | SendFlags::MSG_NOSIGNAL)) + { + error!("Not supported flags: {:?}", flags); + return_errno!(EINVAL, "not supported flags"); + } + let mask = Events::OUT; + // Initialize the poller only when needed + let mut poller = None; + let mut timeout = self.common.send_timeout(); + loop { + // Attempt to write + let res = self.try_sendmsg(bufs, addr, control); + if !res.has_errno(EAGAIN) { + return res; + } + + // Still some buffer contents pending + if self.common.nonblocking() || flags.contains(SendFlags::MSG_DONTWAIT) { + return_errno!(EAGAIN, "try write again"); + } + + // Wait for interesting events by polling + if poller.is_none() { + let new_poller = Poller::new(); + self.common.pollee().connect_poller(mask, &new_poller); + poller = Some(new_poller); + } + + let events = self.common.pollee().poll(mask, None); + if events.is_empty() { + let ret = poller.as_ref().unwrap().wait_timeout(timeout.as_mut()); + if let Err(e) = ret { + warn!("send wait errno = {:?}", e.errno()); + match e.errno() { + ETIMEDOUT => { + return_errno!(EAGAIN, "timeout reached") + } + _ => { + return_errno!(e.errno(), "wait error") + } + } + } + } + } + } + + fn try_sendmsg( + self: &Arc, + bufs: &[&[u8]], + addr: &A, + control: Option<&[u8]>, + ) -> Result { + let mut inner = self.inner.lock(); + if inner.is_shutdown() { + return_errno!(EPIPE, "the write has been shutdown") + } + + if let Some(errno) = inner.error { + // Reset error + inner.error = 
None; + self.common.pollee().del_events(Events::ERR); + return_errno!(errno, "write failed"); + } + + let buf_len: usize = bufs.iter().map(|buf| buf.len()).sum(); + let mut msg = DataMsg::new(buf_len); + let total_copied = msg.copy_buf(bufs)?; + msg.copy_control(control)?; + + let msghdr_ptr = new_send_req(&mut msg, addr); + + if !inner.msg_queue.push_msg(msg) { + // Msg queue can not push this msg, mark the socket as non-writable + self.common.pollee().del_events(Events::OUT); + return_errno!(EAGAIN, "try write again"); + } + + // Since the send buffer is not empty, try to flush the buffer + if inner.io_handle.is_none() { + self.do_send(&mut inner, msghdr_ptr); + } + Ok(total_copied) + } + + fn do_send(self: &Arc, inner: &mut MutexGuard, msghdr_ptr: *const libc::msghdr) { + debug_assert!(!inner.msg_queue.is_empty()); + debug_assert!(inner.io_handle.is_none()); + let sender = self.clone(); + // Submit the async send to io_uring + let complete_fn = move |retval: i32| { + let mut inner = sender.inner.lock(); + trace!("send request complete with retval: {}", retval); + + // Release the handle to the async recv + inner.io_handle.take(); + + if retval < 0 { + // TODO: add PRI event if set SO_SELECT_ERR_QUEUE + let errno = Errno::from(-retval as u32); + + inner.error = Some(errno); + sender.common.set_errno(errno); + sender.common.pollee().add_events(Events::ERR); + return; + } + + // Need to handle normal case + inner.msg_queue.pop_msg(); + sender.common.pollee().add_events(Events::OUT); + if !inner.msg_queue.is_empty() { + let msghdr_ptr = inner.msg_queue.first_msg_ptr(); + debug_assert!(msghdr_ptr.is_some()); + sender.do_send(&mut inner, msghdr_ptr.unwrap()); + } else if inner.is_shutdown == ShutdownStatus::PreShutdown { + // The buffer is empty and the write side is shutdown by the user. + // We can safely shutdown host file here. 
+ let _ = sender.common.host_shutdown(Shutdown::Write); + inner.is_shutdown = ShutdownStatus::PostShutdown + } + }; + + // Generate the async recv request + let io_uring = self.common.io_uring(); + let host_fd = Fd(self.common.host_fd() as _); + let handle = unsafe { io_uring.sendmsg(host_fd, msghdr_ptr, 0, complete_fn) }; + inner.io_handle.replace(handle); + } +} + +fn new_send_req(dmsg: &mut DataMsg, addr: &A) -> *const libc::msghdr { + let iovec = libc::iovec { + iov_base: dmsg.send_buf.as_ptr() as _, + iov_len: dmsg.send_buf.len(), + }; + + let (control, controllen) = match &dmsg.control { + Some(control) => (control.as_mut_ptr() as *mut c_void, control.len()), + None => (ptr::null_mut(), 0), + }; + + dmsg.req.iovec = iovec; + + dmsg.req.msg.msg_iov = &raw mut dmsg.req.iovec as _; + dmsg.req.msg.msg_iovlen = 1; + + let (c_addr_storage, c_addr_len) = addr.to_c_storage(); + + dmsg.req.addr = c_addr_storage; + dmsg.req.msg.msg_name = &raw mut dmsg.req.addr as _; + dmsg.req.msg.msg_namelen = c_addr_len as _; + dmsg.req.msg.msg_control = control; + dmsg.req.msg.msg_controllen = controllen; + + &mut dmsg.req.msg +} + +pub struct Inner { + io_handle: Option, + error: Option, + is_shutdown: ShutdownStatus, + msg_queue: MsgQueue, +} + +unsafe impl Send for Inner {} + +impl Inner { + pub fn new() -> Self { + Self { + io_handle: None, + error: None, + is_shutdown: ShutdownStatus::Running, + msg_queue: MsgQueue::new(), + } + } + + /// Obtain udp sender shutdown state. 
+ #[inline(always)] + pub fn is_shutdown(&self) -> bool { + self.is_shutdown == ShutdownStatus::PreShutdown + || self.is_shutdown == ShutdownStatus::PostShutdown + } +} + +#[repr(C)] +struct SendReq { + msg: libc::msghdr, + iovec: libc::iovec, + addr: libc::sockaddr_storage, +} + +unsafe impl MaybeUntrusted for SendReq {} + +struct MsgQueue { + queue: VecDeque, + curr_size: usize, +} + +impl MsgQueue { + #[inline(always)] + fn new() -> Self { + Self { + queue: VecDeque::with_capacity(SENDMSG_QUEUE_LEN), + curr_size: 0, + } + } + + #[inline(always)] + fn size(&self) -> usize { + self.curr_size + } + + #[inline(always)] + fn is_empty(&self) -> bool { + self.queue.is_empty() + } + + // Push datagram msg, return true if succeed, + // return false if buffer is full. + #[inline(always)] + fn push_msg(&mut self, msg: DataMsg) -> bool { + let total_len = msg.len() + self.size(); + if total_len <= super::MAX_BUF_SIZE { + self.curr_size = total_len; + self.queue.push_back(msg); + return true; + } + false + } + + #[inline(always)] + fn pop_msg(&mut self) { + if let Some(msg) = self.queue.pop_front() { + self.curr_size = self.size() - msg.len(); + } + } + + #[inline(always)] + fn first_msg_ptr(&self) -> Option<*const libc::msghdr> { + self.queue + .front() + .map(|data_msg| &data_msg.req.msg as *const libc::msghdr) + } +} + +// Datagram msg contents in untrusted region +struct DataMsg { + req: UntrustedBox, + send_buf: UntrustedBox<[u8]>, + control: Option>, +} + +impl DataMsg { + #[inline(always)] + fn new(buf_len: usize) -> Self { + Self { + req: UntrustedBox::::new_uninit(), + send_buf: UntrustedBox::new_uninit_slice(buf_len), + control: None, + } + } + + #[inline(always)] + fn copy_buf(&mut self, bufs: &[&[u8]]) -> Result { + let total_len = self.send_buf.len(); + if total_len > super::MAX_BUF_SIZE { + return_errno!(EMSGSIZE, "the message is too large") + } + // Copy data from the bufs to the send buffer + let mut total_copied = 0; + for buf in bufs { + 
self.send_buf[total_copied..(total_copied + buf.len())].copy_from_slice(buf); + total_copied += buf.len(); + } + Ok(total_copied) + } + + #[inline(always)] + fn copy_control(&mut self, control: Option<&[u8]>) -> Result { + if let Some(msg_control) = control { + let send_controllen = msg_control.len(); + if send_controllen > super::OPTMEM_MAX { + return_errno!(EINVAL, "invalid msg control length"); + } + let mut send_control_buf = UntrustedBox::new_uninit_slice(send_controllen); + send_control_buf.copy_from_slice(&msg_control[..send_controllen]); + + self.control = Some(send_control_buf); + return Ok(send_controllen); + }; + Ok(0) + } + + #[inline(always)] + fn len(&self) -> usize { + self.send_buf.len() + } +} + +#[derive(Debug, PartialEq)] +enum ShutdownStatus { + Running, // not shutdown + PreShutdown, // start the shutdown process, set by calling shutdown syscall + PostShutdown, // shutdown process is done, set when the buffer is empty +} diff --git a/src/libos/src/net/socket/uring/file_impl.rs b/src/libos/src/net/socket/uring/file_impl.rs new file mode 100644 index 00000000..521744a7 --- /dev/null +++ b/src/libos/src/net/socket/uring/file_impl.rs @@ -0,0 +1,79 @@ +use super::socket_file::SocketFile; +use crate::fs::{ + AccessMode, FileDesc, HostFd, IoEvents, IoNotifier, IoctlCmd, IoctlRawCmd, StatusFlags, +}; +use crate::prelude::*; +use std::{io::SeekFrom, os::unix::raw::off_t}; + +impl File for SocketFile { + fn read(&self, buf: &mut [u8]) -> Result { + self.read(buf) + } + + fn readv(&self, bufs: &mut [&mut [u8]]) -> Result { + self.readv(bufs) + } + + fn write(&self, buf: &[u8]) -> Result { + self.write(buf) + } + + fn writev(&self, bufs: &[&[u8]]) -> Result { + self.writev(bufs) + } + + fn read_at(&self, offset: usize, buf: &mut [u8]) -> Result { + if offset != 0 { + return_errno!(ESPIPE, "a nonzero position is not supported"); + } + self.read(buf) + } + + fn write_at(&self, offset: usize, buf: &[u8]) -> Result { + if offset != 0 { + return_errno!(ESPIPE, 
"a nonzero position is not supported"); + } + self.write(buf) + } + + fn seek(&self, pos: SeekFrom) -> Result { + return_errno!(ESPIPE, "Socket does not support seek") + } + + fn ioctl(&self, cmd: &mut dyn IoctlCmd) -> Result<()> { + self.ioctl(cmd) + } + + fn notifier(&self) -> Option<&IoNotifier> { + Some(self.notifier()) + } + + fn access_mode(&self) -> Result { + Ok(AccessMode::O_RDWR) + } + + fn status_flags(&self) -> Result { + Ok(self.status_flags()) + } + + fn set_status_flags(&self, new_status_flags: StatusFlags) -> Result<()> { + self.set_status_flags(new_status_flags) + } + + fn poll_new(&self) -> IoEvents { + let mask = IoEvents::all(); + self.poll(mask, None) + } + + fn host_fd(&self) -> Option<&HostFd> { + None + } + + fn update_host_events(&self, ready: &IoEvents, mask: &IoEvents, trigger_notifier: bool) { + unreachable!() + } + + fn as_any(&self) -> &dyn core::any::Any { + self + } +} diff --git a/src/libos/src/net/socket/uring/mod.rs b/src/libos/src/net/socket/uring/mod.rs new file mode 100644 index 00000000..b39bc3c0 --- /dev/null +++ b/src/libos/src/net/socket/uring/mod.rs @@ -0,0 +1,12 @@ +#![feature(stmt_expr_attributes)] +#![feature(new_uninit)] +#![feature(raw_ref_op)] + +pub mod common; +pub mod datagram; +pub mod file_impl; +pub mod runtime; +pub mod socket_file; +pub mod stream; + +pub use self::socket_file::UringSocketType; diff --git a/src/libos/src/net/socket/uring/runtime.rs b/src/libos/src/net/socket/uring/runtime.rs new file mode 100644 index 00000000..7d55143d --- /dev/null +++ b/src/libos/src/net/socket/uring/runtime.rs @@ -0,0 +1,12 @@ +use alloc::sync::Arc; +use io_uring_callback::IoUring; + +/// The runtime support for HostSocket. +/// +/// This trait provides a common interface for user-implemented runtimes +/// that support HostSocket. Currently, the only dependency is a singleton +/// of IoUring instance. 
+pub trait Runtime: Send + Sync + 'static { + fn io_uring() -> Arc; + fn disattach_io_uring(fd: usize, uring: Arc); +} diff --git a/src/libos/src/net/socket/uring/socket_file.rs b/src/libos/src/net/socket/uring/socket_file.rs new file mode 100644 index 00000000..d7ee0595 --- /dev/null +++ b/src/libos/src/net/socket/uring/socket_file.rs @@ -0,0 +1,446 @@ +use self::impls::{Ipv4Datagram, Ipv6Datagram}; +use crate::events::{Observer, Poller}; +use crate::net::socket::{MsgFlags, SocketProtocol}; + +use self::impls::{Ipv4Stream, Ipv6Stream}; +use crate::fs::{AccessMode, IoEvents, IoNotifier, IoctlCmd, StatusFlags}; +use crate::net::socket::{AnyAddr, Ipv4SocketAddr, Ipv6SocketAddr}; +use crate::prelude::*; + +#[derive(Debug)] +pub struct SocketFile { + socket: AnySocket, +} + +// Apply a function to all variants of AnySocket enum. +macro_rules! apply_fn_on_any_socket { + ($any_socket:expr, |$socket:ident| { $($fn_body:tt)* }) => {{ + let any_socket: &AnySocket = $any_socket; + match any_socket { + AnySocket::Ipv4Stream($socket) => { + $($fn_body)* + } + AnySocket::Ipv6Stream($socket) => { + $($fn_body)* + } + AnySocket::Ipv4Datagram($socket) => { + $($fn_body)* + } + AnySocket::Ipv6Datagram($socket) => { + $($fn_body)* + } + } + }} +} + +pub trait UringSocketType { + fn as_uring_socket(&self) -> Result<&SocketFile>; +} + +impl UringSocketType for FileRef { + fn as_uring_socket(&self) -> Result<&SocketFile> { + self.as_any() + .downcast_ref::() + .ok_or_else(|| errno!(ENOTSOCK, "not a uring socket")) + } +} + +#[derive(Debug)] +enum AnySocket { + Ipv4Stream(Ipv4Stream), + Ipv6Stream(Ipv6Stream), + Ipv4Datagram(Ipv4Datagram), + Ipv6Datagram(Ipv6Datagram), +} + +// Implement the common methods required by FileHandle +impl SocketFile { + pub fn read(&self, buf: &mut [u8]) -> Result { + apply_fn_on_any_socket!(&self.socket, |socket| { socket.read(buf) }) + } + + pub fn readv(&self, bufs: &mut [&mut [u8]]) -> Result { + apply_fn_on_any_socket!(&self.socket, |socket| { 
socket.readv(bufs) }) + } + + pub fn write(&self, buf: &[u8]) -> Result { + apply_fn_on_any_socket!(&self.socket, |socket| { socket.write(buf) }) + } + + pub fn writev(&self, bufs: &[&[u8]]) -> Result { + apply_fn_on_any_socket!(&self.socket, |socket| { socket.writev(bufs) }) + } + + pub fn access_mode(&self) -> AccessMode { + // We consider all sockets both readable and writable + AccessMode::O_RDWR + } + + pub fn status_flags(&self) -> StatusFlags { + apply_fn_on_any_socket!(&self.socket, |socket| { socket.status_flags() }) + } + + pub fn host_fd_inner(&self) -> FileDesc { + apply_fn_on_any_socket!(&self.socket, |socket| { socket.host_fd() }) + } + + pub fn set_status_flags(&self, new_flags: StatusFlags) -> Result<()> { + apply_fn_on_any_socket!(&self.socket, |socket| { + socket.set_status_flags(new_flags) + }) + } + + pub fn notifier(&self) -> &IoNotifier { + apply_fn_on_any_socket!(&self.socket, |socket| { socket.notifier() }) + } + + pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { + apply_fn_on_any_socket!(&self.socket, |socket| { socket.poll(mask, poller) }) + } + + pub fn ioctl(&self, cmd: &mut dyn IoctlCmd) -> Result<()> { + apply_fn_on_any_socket!(&self.socket, |socket| { socket.ioctl(cmd) }) + } + + pub fn get_type(&self) -> Type { + match self.socket { + AnySocket::Ipv4Stream(_) | AnySocket::Ipv6Stream(_) => Type::STREAM, + AnySocket::Ipv4Datagram(_) | AnySocket::Ipv6Datagram(_) => Type::DGRAM, + } + } +} + +// Implement socket-specific methods +impl SocketFile { + pub fn new( + domain: Domain, + protocol: SocketProtocol, + socket_type: Type, + nonblocking: bool, + ) -> Result { + match socket_type { + Type::STREAM => { + if protocol != SocketProtocol::IPPROTO_IP && protocol != SocketProtocol::IPPROTO_TCP + { + return_errno!(EPROTONOSUPPORT, "Protocol not supported"); + } + let any_socket = match domain { + Domain::INET => { + let ipv4_stream = Ipv4Stream::new(nonblocking)?; + AnySocket::Ipv4Stream(ipv4_stream) + } + 
Domain::INET6 => { + let ipv6_stream = Ipv6Stream::new(nonblocking)?; + AnySocket::Ipv6Stream(ipv6_stream) + } + _ => { + panic!() + } + }; + let new_self = Self { socket: any_socket }; + Ok(new_self) + } + Type::DGRAM => { + if protocol != SocketProtocol::IPPROTO_IP && protocol != SocketProtocol::IPPROTO_UDP + { + return_errno!(EPROTONOSUPPORT, "Protocol not supported"); + } + let any_socket = match domain { + Domain::INET => { + let ipv4_datagram = Ipv4Datagram::new(nonblocking)?; + AnySocket::Ipv4Datagram(ipv4_datagram) + } + Domain::INET6 => { + let ipv6_datagram = Ipv6Datagram::new(nonblocking)?; + AnySocket::Ipv6Datagram(ipv6_datagram) + } + _ => { + return_errno!(EINVAL, "not support yet"); + } + }; + let new_self = Self { socket: any_socket }; + Ok(new_self) + } + Type::RAW => { + return_errno!(EINVAL, "RAW socket not supported"); + } + _ => { + return_errno!(ESOCKTNOSUPPORT, "socket type not supported"); + } + } + } + + pub fn domain(&self) -> Domain { + apply_fn_on_any_socket!(&self.socket, |socket| { socket.domain() }) + } + + pub fn is_stream(&self) -> bool { + matches!(&self.socket, AnySocket::Ipv4Stream(_)) + } + + pub fn connect(&self, addr: &AnyAddr) -> Result<()> { + match &self.socket { + AnySocket::Ipv4Stream(ipv4_stream) => { + let ip_addr = addr.to_ipv4()?; + ipv4_stream.connect(ip_addr) + } + AnySocket::Ipv6Stream(ipv6_stream) => { + let ip_addr = addr.to_ipv6()?; + ipv6_stream.connect(ip_addr) + } + AnySocket::Ipv4Datagram(ipv4_datagram) => { + let mut ip_addr = None; + if !addr.is_unspec() { + let ipv4_addr = addr.to_ipv4()?; + ip_addr = Some(ipv4_addr); + } + ipv4_datagram.connect(ip_addr) + } + AnySocket::Ipv6Datagram(ipv6_datagram) => { + let mut ip_addr = None; + if !addr.is_unspec() { + let ipv6_addr = addr.to_ipv6()?; + ip_addr = Some(ipv6_addr); + } + ipv6_datagram.connect(ip_addr) + } + _ => { + return_errno!(EINVAL, "connect is not supported"); + } + } + } + + pub fn bind(&self, addr: &mut AnyAddr) -> Result<()> { + match 
&self.socket { + AnySocket::Ipv4Stream(ipv4_stream) => { + let ip_addr = addr.to_ipv4()?; + ipv4_stream.bind(ip_addr) + } + AnySocket::Ipv6Stream(ipv6_stream) => { + let ip_addr = addr.to_ipv6()?; + ipv6_stream.bind(ip_addr) + } + AnySocket::Ipv4Datagram(ipv4_datagram) => { + let ip_addr = addr.to_ipv4()?; + ipv4_datagram.bind(ip_addr) + } + + AnySocket::Ipv6Datagram(ipv6_datagram) => { + let ip_addr = addr.to_ipv6()?; + ipv6_datagram.bind(ip_addr) + } + + _ => { + return_errno!(EINVAL, "bind is not supported"); + } + } + } + + pub fn listen(&self, backlog: u32) -> Result<()> { + match &self.socket { + AnySocket::Ipv4Stream(ip_stream) => ip_stream.listen(backlog), + AnySocket::Ipv6Stream(ip_stream) => ip_stream.listen(backlog), + _ => { + return_errno!(EOPNOTSUPP, "The socket is not of a listen supported type"); + } + } + } + + pub fn accept(&self, nonblocking: bool) -> Result { + let accepted_any_socket = match &self.socket { + AnySocket::Ipv4Stream(ipv4_stream) => { + let accepted_ipv4_stream = ipv4_stream.accept(nonblocking)?; + AnySocket::Ipv4Stream(accepted_ipv4_stream) + } + AnySocket::Ipv6Stream(ipv6_stream) => { + let accepted_ipv6_stream = ipv6_stream.accept(nonblocking)?; + AnySocket::Ipv6Stream(accepted_ipv6_stream) + } + _ => { + return_errno!(EOPNOTSUPP, "The socket is not of a accept supported type"); + } + }; + let accepted_socket_file = SocketFile { + socket: accepted_any_socket, + }; + Ok(accepted_socket_file) + } + + pub fn recvfrom(&self, buf: &mut [u8], flags: RecvFlags) -> Result<(usize, Option)> { + let (bytes_recv, addr_recv, _, _) = self.recvmsg(&mut [buf], flags, None)?; + Ok((bytes_recv, addr_recv)) + } + + pub fn recvmsg( + &self, + bufs: &mut [&mut [u8]], + flags: RecvFlags, + control: Option<&mut [u8]>, + ) -> Result<(usize, Option, MsgFlags, usize)> { + // TODO: support msg_flags and msg_control + Ok(match &self.socket { + AnySocket::Ipv4Stream(ipv4_stream) => { + let (bytes_recv, addr_recv, msg_flags) = ipv4_stream.recvmsg(bufs, 
flags)?; + ( + bytes_recv, + addr_recv.map(|addr| AnyAddr::Ipv4(addr)), + msg_flags, + 0, + ) + } + AnySocket::Ipv6Stream(ipv6_stream) => { + let (bytes_recv, addr_recv, msg_flags) = ipv6_stream.recvmsg(bufs, flags)?; + ( + bytes_recv, + addr_recv.map(|addr| AnyAddr::Ipv6(addr)), + msg_flags, + 0, + ) + } + AnySocket::Ipv4Datagram(ipv4_datagram) => { + let (bytes_recv, addr_recv, msg_flags, msg_controllen) = + ipv4_datagram.recvmsg(bufs, flags, control)?; + ( + bytes_recv, + addr_recv.map(|addr| AnyAddr::Ipv4(addr)), + msg_flags, + msg_controllen, + ) + } + AnySocket::Ipv6Datagram(ipv6_datagram) => { + let (bytes_recv, addr_recv, msg_flags, msg_controllen) = + ipv6_datagram.recvmsg(bufs, flags, control)?; + ( + bytes_recv, + addr_recv.map(|addr| AnyAddr::Ipv6(addr)), + msg_flags, + msg_controllen, + ) + } + _ => { + return_errno!(EINVAL, "recvfrom is not supported"); + } + }) + } + + pub fn sendto(&self, buf: &[u8], addr: Option, flags: SendFlags) -> Result { + self.sendmsg(&[buf], addr, flags, None) + } + + pub fn sendmsg( + &self, + bufs: &[&[u8]], + addr: Option, + flags: SendFlags, + control: Option<&[u8]>, + ) -> Result { + let res = match &self.socket { + AnySocket::Ipv4Stream(ipv4_stream) => ipv4_stream.sendmsg(bufs, flags), + AnySocket::Ipv6Stream(ipv6_stream) => ipv6_stream.sendmsg(bufs, flags), + AnySocket::Ipv4Datagram(ipv4_datagram) => { + let ip_addr = if let Some(addr) = addr.as_ref() { + Some(addr.to_ipv4()?) + } else { + None + }; + ipv4_datagram.sendmsg(bufs, ip_addr, flags, control) + } + AnySocket::Ipv6Datagram(ipv6_datagram) => { + let ip_addr = if let Some(addr) = addr.as_ref() { + Some(addr.to_ipv6()?) 
+ } else { + None + }; + ipv6_datagram.sendmsg(bufs, ip_addr, flags, control) + } + _ => { + return_errno!(EINVAL, "sendmsg is not supported"); + } + }; + if res.has_errno(EPIPE) && !flags.contains(SendFlags::MSG_NOSIGNAL) { + crate::signal::do_tkill(current!().tid(), crate::signal::SIGPIPE.as_u8() as i32); + } + + res + } + + pub fn addr(&self) -> Result { + Ok(match &self.socket { + AnySocket::Ipv4Stream(ipv4_stream) => AnyAddr::Ipv4(ipv4_stream.addr()?), + AnySocket::Ipv6Stream(ipv6_stream) => AnyAddr::Ipv6(ipv6_stream.addr()?), + AnySocket::Ipv4Datagram(ipv4_datagram) => AnyAddr::Ipv4(ipv4_datagram.addr()?), + AnySocket::Ipv6Datagram(ipv6_datagram) => AnyAddr::Ipv6(ipv6_datagram.addr()?), + _ => { + return_errno!(EINVAL, "addr is not supported"); + } + }) + } + + pub fn peer_addr(&self) -> Result { + Ok(match &self.socket { + AnySocket::Ipv4Stream(ipv4_stream) => AnyAddr::Ipv4(ipv4_stream.peer_addr()?), + AnySocket::Ipv6Stream(ipv6_stream) => AnyAddr::Ipv6(ipv6_stream.peer_addr()?), + AnySocket::Ipv4Datagram(ipv4_datagram) => AnyAddr::Ipv4(ipv4_datagram.peer_addr()?), + AnySocket::Ipv6Datagram(ipv6_datagram) => AnyAddr::Ipv6(ipv6_datagram.peer_addr()?), + _ => { + return_errno!(EINVAL, "peer_addr is not supported"); + } + }) + } + + pub fn shutdown(&self, how: Shutdown) -> Result<()> { + match &self.socket { + AnySocket::Ipv4Stream(ipv4_stream) => ipv4_stream.shutdown(how), + AnySocket::Ipv6Stream(ipv6_stream) => ipv6_stream.shutdown(how), + AnySocket::Ipv4Datagram(ipv4_datagram) => ipv4_datagram.shutdown(how), + AnySocket::Ipv6Datagram(ipv6_datagram) => ipv6_datagram.shutdown(how), + _ => { + return_errno!(EINVAL, "shutdown is not supported"); + } + } + } + + pub fn close(&self) -> Result<()> { + match &self.socket { + AnySocket::Ipv4Stream(ipv4_stream) => ipv4_stream.close(), + AnySocket::Ipv6Stream(ipv6_stream) => ipv6_stream.close(), + AnySocket::Ipv4Datagram(ipv4_datagram) => ipv4_datagram.close(), + AnySocket::Ipv6Datagram(ipv6_datagram) => 
ipv6_datagram.close(), + _ => Ok(()), + } + } +} + +impl Drop for SocketFile { + fn drop(&mut self) { + self.close(); + } +} + +mod impls { + use super::*; + use io_uring_callback::IoUring; + + pub type Ipv4Stream = + crate::net::socket::uring::stream::StreamSocket; + pub type Ipv6Stream = + crate::net::socket::uring::stream::StreamSocket; + + pub type Ipv4Datagram = + crate::net::socket::uring::datagram::DatagramSocket; + pub type Ipv6Datagram = + crate::net::socket::uring::datagram::DatagramSocket; + + pub struct SocketRuntime; + impl crate::net::socket::uring::runtime::Runtime for SocketRuntime { + // Assign an IO-Uring instance for newly created socket + fn io_uring() -> Arc { + crate::io_uring::MULTITON.get_uring() + } + + // Disattach IO-Uring instance with closed socket + fn disattach_io_uring(fd: usize, uring: Arc) { + crate::io_uring::MULTITON.disattach_uring(fd, uring); + } + } +} diff --git a/src/libos/src/net/socket/uring/stream/mod.rs b/src/libos/src/net/socket/uring/stream/mod.rs new file mode 100644 index 00000000..fa4e070c --- /dev/null +++ b/src/libos/src/net/socket/uring/stream/mod.rs @@ -0,0 +1,610 @@ +mod states; + +use core::hint; +use core::sync::atomic::AtomicUsize; +use core::time::Duration; + +use atomic::Ordering; + +use self::states::{ConnectedStream, ConnectingStream, InitStream, ListenerStream}; +use crate::events::Observer; +use crate::fs::{ + GetIfConf, GetIfReqWithRawCmd, GetReadBufLen, IoEvents, IoNotifier, IoctlCmd, SetNonBlocking, + StatusFlags, +}; + +use crate::net::socket::uring::common::Common; +use crate::net::socket::uring::runtime::Runtime; +use crate::prelude::*; + +use crate::events::Poller; +use crate::net::socket::{sockopt::*, MsgFlags}; + +lazy_static! { + pub static ref SEND_BUF_SIZE: AtomicUsize = AtomicUsize::new(2565 * 1024 + 1); // Default Linux send buffer size is 2.5MB. 
+ pub static ref RECV_BUF_SIZE: AtomicUsize = AtomicUsize::new(256 * 1024 + 1); +} + +pub struct StreamSocket { + state: RwLock>, + common: Arc>, +} + +enum State { + // Start state + Init(Arc>), + // Intermediate state + Connect(Arc>), + // Final state 1 + Connected(Arc>), + // Final state 2 + Listen(Arc>), +} + +impl StreamSocket { + pub fn new(nonblocking: bool) -> Result { + let init_stream = InitStream::new(nonblocking)?; + let common = init_stream.common().clone(); + + let fd = common.host_fd(); + debug!("host fd: {}", fd); + + let init_state = State::Init(init_stream); + Ok(Self { + state: RwLock::new(init_state), + common, + }) + } + + pub fn new_pair(nonblocking: bool) -> Result<(Self, Self)> { + let (common1, common2) = Common::new_pair(Type::STREAM, nonblocking)?; + let connected1 = ConnectedStream::new(Arc::new(common1)); + let connected2 = ConnectedStream::new(Arc::new(common2)); + let socket1 = Self::new_connected(connected1); + let socket2 = Self::new_connected(connected2); + Ok((socket1, socket2)) + } + + fn new_connected(connected_stream: Arc>) -> Self { + let common = connected_stream.common().clone(); + let state = RwLock::new(State::Connected(connected_stream)); + Self { state, common } + } + + fn try_switch_to_connected_state( + connecting_stream: &Arc>, + ) -> Option>> { + // Previously, I thought connecting state only exists for non-blocking socket. However, some applications can set non-blocking for + // connect syscall and after the connect returns, set the socket to blocking socket. Thus, this function shouldn't assert the connecting + // stream is non-blocking socket. 
+ if connecting_stream.check_connection() { + let common = connecting_stream.common().clone(); + common.set_peer_addr(connecting_stream.peer_addr()); + Some(ConnectedStream::new(common)) + } else { + None + } + } + + pub fn domain(&self) -> Domain { + A::domain() + } + + pub fn errno(&self) -> Option { + self.common.errno() + } + + pub fn host_fd(&self) -> FileDesc { + let state = self.state.read().unwrap(); + state.common().host_fd() + } + + pub fn status_flags(&self) -> StatusFlags { + // Only support O_NONBLOCK + let state = self.state.read().unwrap(); + if state.common().nonblocking() { + StatusFlags::O_NONBLOCK + } else { + StatusFlags::empty() + } + } + + pub fn set_status_flags(&self, new_flags: StatusFlags) -> Result<()> { + // Only support O_NONBLOCK + let state = self.state.read().unwrap(); + let nonblocking = new_flags.is_nonblocking(); + state.common().set_nonblocking(nonblocking); + Ok(()) + } + + pub fn bind(&self, addr: &A) -> Result<()> { + let state = self.state.read().unwrap(); + match &*state { + State::Init(init_stream) => init_stream.bind(addr), + _ => { + return_errno!(EINVAL, "cannot bind"); + } + } + } + + pub fn listen(&self, backlog: u32) -> Result<()> { + let mut state = self.state.write().unwrap(); + match &*state { + State::Init(init_stream) => { + let common = init_stream.common().clone(); + let listener = ListenerStream::new(backlog, common)?; + *state = State::Listen(listener); + Ok(()) + } + _ => { + return_errno!(EINVAL, "cannot listen"); + } + } + } + + pub fn connect(&self, peer_addr: &A) -> Result<()> { + // Create the new intermediate state of connecting and save the + // old state of init in case of failure to connect. + let (init_stream, connecting_stream) = { + let mut state = self.state.write().unwrap(); + match &*state { + State::Init(init_stream) => { + let connecting_stream = { + let common = init_stream.common().clone(); + ConnectingStream::new(peer_addr, common)? 
+ }; + let init_stream = init_stream.clone(); + *state = State::Connect(connecting_stream.clone()); + (init_stream, connecting_stream) + } + State::Connect(connecting_stream) => { + if let Some(connected_stream) = + Self::try_switch_to_connected_state(connecting_stream) + { + *state = State::Connected(connected_stream); + return_errno!(EISCONN, "the socket is already connected"); + } else { + // Not connected, keep the connecting state and try connect + let init_stream = + InitStream::new_with_common(connecting_stream.common().clone())?; + (init_stream, connecting_stream.clone()) + } + } + State::Connected(_) => { + return_errno!(EISCONN, "the socket is already connected"); + } + State::Listen(_) => { + return_errno!(EINVAL, "the socket is listening"); + } + } + }; + + let res = connecting_stream.connect(); + + // If success, then the state is switched to connected; otherwise, for blocking socket + // the state is restored to the init state, and for non-blocking socket, the state + // keeps in connecting state. 
+ match &res { + Ok(()) => { + let connected_stream = { + let common = init_stream.common().clone(); + common.set_peer_addr(peer_addr); + ConnectedStream::new(common) + }; + + let mut state = self.state.write().unwrap(); + *state = State::Connected(connected_stream); + } + Err(_) => { + if !connecting_stream.common().nonblocking() { + let mut state = self.state.write().unwrap(); + *state = State::Init(init_stream); + } + } + } + res + } + + pub fn accept(&self, nonblocking: bool) -> Result { + let listener_stream = { + let state = self.state.read().unwrap(); + match &*state { + State::Listen(listener_stream) => listener_stream.clone(), + _ => { + return_errno!(EINVAL, "the socket is not listening"); + } + } + }; + + let connected_stream = listener_stream.accept(nonblocking)?; + + let new_self = Self::new_connected(connected_stream); + Ok(new_self) + } + + pub fn read(&self, buf: &mut [u8]) -> Result { + self.readv(&mut [buf]) + } + + pub fn readv(&self, bufs: &mut [&mut [u8]]) -> Result { + let ret = self.recvmsg(bufs, RecvFlags::empty())?; + Ok(ret.0) + } + + /// Receive messages from connected socket + /// + /// Linux behavior: + /// Unlike datagram socket, `recvfrom` / `recvmsg` of stream socket will + /// ignore the address even if user specified it. 
+ pub fn recvmsg( + &self, + buf: &mut [&mut [u8]], + flags: RecvFlags, + ) -> Result<(usize, Option, MsgFlags)> { + let connected_stream = { + let mut state = self.state.write().unwrap(); + match &*state { + State::Connected(connected_stream) => connected_stream.clone(), + State::Connect(connecting_stream) => { + if let Some(connected_stream) = + Self::try_switch_to_connected_state(connecting_stream) + { + *state = State::Connected(connected_stream.clone()); + connected_stream + } else { + return_errno!(ENOTCONN, "the socket is not connected"); + } + } + _ => { + return_errno!(ENOTCONN, "the socket is not connected"); + } + } + }; + + let recv_len = connected_stream.recvmsg(buf, flags)?; + Ok((recv_len, None, MsgFlags::empty())) + } + + pub fn write(&self, buf: &[u8]) -> Result { + self.writev(&[buf]) + } + + pub fn writev(&self, bufs: &[&[u8]]) -> Result { + self.sendmsg(bufs, SendFlags::empty()) + } + + pub fn sendmsg(&self, bufs: &[&[u8]], flags: SendFlags) -> Result { + let connected_stream = { + let mut state = self.state.write().unwrap(); + match &*state { + State::Connected(connected_stream) => connected_stream.clone(), + State::Connect(connecting_stream) => { + if let Some(connected_stream) = + Self::try_switch_to_connected_state(connecting_stream) + { + *state = State::Connected(connected_stream.clone()); + connected_stream + } else { + return_errno!(ENOTCONN, "the socket is not connected"); + } + } + _ => { + return_errno!(EPIPE, "the socket is not connected"); + } + } + }; + + connected_stream.sendmsg(bufs, flags) + } + + pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { + let state = self.state.read().unwrap(); + let pollee = state.common().pollee(); + pollee.poll(mask, poller) + } + + pub fn addr(&self) -> Result { + let state = self.state.read().unwrap(); + let common = state.common(); + + // Always get addr from host. + // Because for IP socket, users can specify "0" as port and the kernel should select a usable port for him. 
+ // Thus, when calling getsockname, this should be updated. + let addr = common.get_addr_from_host()?; + common.set_addr(&addr); + Ok(addr) + } + + pub fn notifier(&self) -> &IoNotifier { + { + let mut state = self.state.write().unwrap(); + // Try switch to connected state to receive endpoint status + if let State::Connect(connecting_stream) = &*state { + if let Some(connected_stream) = + Self::try_switch_to_connected_state(connecting_stream) + { + *state = State::Connected(connected_stream.clone()); + } + } + // `state` goes out of scope here and the lock is implicitly released. + } + + self.common.notifier() + } + + pub fn peer_addr(&self) -> Result { + let mut state = self.state.write().unwrap(); + match &*state { + State::Connected(connected_stream) => { + Ok(connected_stream.common().peer_addr().unwrap()) + } + State::Connect(connecting_stream) => { + if let Some(connected_stream) = + Self::try_switch_to_connected_state(connecting_stream) + { + *state = State::Connected(connected_stream.clone()); + Ok(connected_stream.common().peer_addr().unwrap()) + } else { + return_errno!(ENOTCONN, "the socket is not connected"); + } + } + _ => return_errno!(ENOTCONN, "the socket is not connected"), + } + } + + pub fn ioctl(&self, cmd: &mut dyn IoctlCmd) -> Result<()> { + let mut state = self.state.write().unwrap(); + match &*state { + State::Connect(connecting_stream) => { + if let Some(connected_stream) = + Self::try_switch_to_connected_state(connecting_stream) + { + *state = State::Connected(connected_stream.clone()); + } + } + _ => {} + } + drop(state); + crate::match_ioctl_cmd_mut!(&mut *cmd, { + cmd: GetSockOptRawCmd => { + cmd.execute(self.host_fd())?; + }, + cmd: SetSockOptRawCmd => { + cmd.execute(self.host_fd())?; + }, + cmd: SetRecvTimeoutCmd => { + self.set_recv_timeout(*cmd.timeout()); + }, + cmd: SetSendTimeoutCmd => { + self.set_send_timeout(*cmd.timeout()); + }, + cmd: GetRecvTimeoutCmd => { + let timeval = timeout_to_timeval(self.recv_timeout()); + 
cmd.set_output(timeval); + }, + cmd: GetSendTimeoutCmd => { + let timeval = timeout_to_timeval(self.send_timeout()); + cmd.set_output(timeval); + }, + cmd: SetSndBufSizeCmd => { + cmd.update_host(self.host_fd())?; + let buf_size = cmd.buf_size(); + self.set_kernel_send_buf_size(buf_size); + }, + cmd: SetRcvBufSizeCmd => { + cmd.update_host(self.host_fd())?; + let buf_size = cmd.buf_size(); + self.set_kernel_recv_buf_size(buf_size); + }, + cmd: GetSndBufSizeCmd => { + let buf_size = SEND_BUF_SIZE.load(Ordering::Relaxed); + cmd.set_output(buf_size); + }, + cmd: GetRcvBufSizeCmd => { + let buf_size = RECV_BUF_SIZE.load(Ordering::Relaxed); + cmd.set_output(buf_size); + }, + cmd: GetAcceptConnCmd => { + let mut is_listen = false; + let state = self.state.read().unwrap(); + if let State::Listen(_listener_stream) = &*state { + is_listen = true; + } + cmd.set_output(is_listen as _); + }, + cmd: GetDomainCmd => { + cmd.set_output(self.domain() as _); + }, + cmd: GetPeerNameCmd => { + let peer = self.peer_addr()?; + cmd.set_output(AddrStorage(peer.to_c_storage())); + }, + cmd: GetErrorCmd => { + let error: i32 = self.errno().map(|err| err as i32).unwrap_or(0); + cmd.set_output(error); + }, + cmd: GetTypeCmd => { + let state = self.state.read().unwrap(); + cmd.set_output(state.common().type_() as _); + }, + cmd: SetNonBlocking => { + let state = self.state.read().unwrap(); + state.common().set_nonblocking(*cmd.input() != 0); + }, + cmd: GetReadBufLen => { + let state = self.state.read().unwrap(); + if let State::Connected(connected_stream) = &*state { + let read_buf_len = connected_stream.bytes_to_consume(); + cmd.set_output(read_buf_len as _); + } else { + return_errno!(ENOTCONN, "unconnected socket"); + } + }, + cmd: GetIfReqWithRawCmd => { + cmd.execute(self.host_fd())?; + }, + cmd: GetIfConf => { + cmd.execute(self.host_fd())?; + }, + _ => { + return_errno!(EINVAL, "Not supported yet"); + } + }); + Ok(()) + } + + fn set_kernel_send_buf_size(&self, buf_size: usize) { + let 
state = self.state.read().unwrap(); + match &*state { + State::Init(_) | State::Listen(_) | State::Connect(_) => { + // The kernel buffer is only created when the socket is connected. Just update the static variable. + SEND_BUF_SIZE.store(buf_size, Ordering::Relaxed); + } + State::Connected(connected_stream) => { + connected_stream.try_update_send_buf_size(buf_size); + } + } + } + + fn set_kernel_recv_buf_size(&self, buf_size: usize) { + let state = self.state.read().unwrap(); + match &*state { + State::Init(_) | State::Listen(_) | State::Connect(_) => { + // The kernel buffer is only created when the socket is connected. Just update the static variable. + RECV_BUF_SIZE.store(buf_size, Ordering::Relaxed); + } + State::Connected(connected_stream) => { + connected_stream.try_update_recv_buf_size(buf_size); + } + } + } + + pub fn shutdown(&self, shutdown: Shutdown) -> Result<()> { + let mut state = self.state.write().unwrap(); + match &*state { + State::Listen(listener_stream) => { + // listening socket can be shutdown and then re-use by calling listen again. + listener_stream.shutdown(shutdown)?; + if shutdown.should_shut_read() { + // Cancel pending accept requests. This is necessary because the socket is reusable. 
+ listener_stream.cancel_accept_requests(); + // Set init state + let init_stream = + InitStream::new_with_common(listener_stream.common().clone())?; + let init_state = State::Init(init_stream); + *state = init_state; + Ok(()) + } else { + // shutdown the writer of the listener expect to have no effect + Ok(()) + } + } + State::Connected(connected_stream) => connected_stream.shutdown(shutdown), + State::Connect(connecting_stream) => { + if let Some(connected_stream) = + Self::try_switch_to_connected_state(connecting_stream) + { + connected_stream.shutdown(shutdown)?; + *state = State::Connected(connected_stream); + Ok(()) + } else { + return_errno!(ENOTCONN, "the socket is not connected"); + } + } + _ => { + return_errno!(ENOTCONN, "the socket is not connected"); + } + } + } + + pub fn close(&self) -> Result<()> { + let state = self.state.read().unwrap(); + match &*state { + State::Init(_) => {} + State::Listen(listener_stream) => { + listener_stream.common().set_closed(); + listener_stream.cancel_accept_requests(); + } + State::Connect(connecting_stream) => { + connecting_stream.common().set_closed(); + let need_wait = true; + connecting_stream.cancel_connect_request(need_wait); + } + State::Connected(connected_stream) => { + connected_stream.set_closed(); + connected_stream.cancel_recv_requests(); + connected_stream.try_empty_send_buf_when_close(); + } + } + Ok(()) + } + + fn send_timeout(&self) -> Option { + let state = self.state.read().unwrap(); + state.common().send_timeout() + } + + fn recv_timeout(&self) -> Option { + let state = self.state.read().unwrap(); + state.common().recv_timeout() + } + + fn set_send_timeout(&self, timeout: Duration) { + let state = self.state.read().unwrap(); + state.common().set_send_timeout(timeout); + } + + fn set_recv_timeout(&self, timeout: Duration) { + let state = self.state.read().unwrap(); + state.common().set_recv_timeout(timeout); + } + + /* + pub fn poll_by(&self, mask: Events, mut poller: Option<&mut Poller>) -> Events 
{ + let state = self.state.read(); + match *state { + Init(init_stream) => init_stream.poll_by(mask, poller), + Connect(connect_stream) => connect_stream.poll_by(mask, poller), + Connected(connected_stream) = connected_stream.poll_by(mask, poller), + Listen(listener_stream) = listener_stream.poll_by(mask, poller), + } + } + */ +} + +impl Drop for StreamSocket { + fn drop(&mut self) { + let state = self.state.read().unwrap(); + state.common().set_closed(); + drop(state); + } +} + +impl std::fmt::Debug for State { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let inner: &dyn std::fmt::Debug = match self { + State::Init(inner) => inner as _, + State::Connect(inner) => inner as _, + State::Connected(inner) => inner as _, + State::Listen(inner) => inner as _, + }; + f.debug_tuple("State").field(inner).finish() + } +} + +impl std::fmt::Debug for StreamSocket { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("StreamSocket").finish() + } +} + +impl State { + fn common(&self) -> &Common { + match self { + Self::Init(stream) => stream.common(), + Self::Connect(stream) => stream.common(), + Self::Connected(stream) => stream.common(), + Self::Listen(stream) => stream.common(), + } + } +} diff --git a/src/libos/src/net/socket/uring/stream/states/connect.rs b/src/libos/src/net/socket/uring/stream/states/connect.rs new file mode 100644 index 00000000..5a06a7b2 --- /dev/null +++ b/src/libos/src/net/socket/uring/stream/states/connect.rs @@ -0,0 +1,227 @@ +use core::time::Duration; +use std::marker::PhantomData; +use std::sync::atomic::{AtomicBool, Ordering}; + +use io_uring_callback::{Fd, IoHandle}; +use sgx_untrusted_alloc::UntrustedBox; + +use crate::events::Poller; +use crate::fs::IoEvents; +use crate::net::socket::uring::common::Common; +use crate::net::socket::uring::runtime::Runtime; +use crate::prelude::*; + +/// A stream socket that is in its connecting state. 
+pub struct ConnectingStream { + common: Arc>, + peer_addr: A, + req: Mutex>, + connected: AtomicBool, // Mainly use for nonblocking socket to update status asynchronously +} + +struct ConnectReq { + io_handle: Option, + c_addr: UntrustedBox, + c_addr_len: usize, + errno: Option, + phantom_data: PhantomData, +} + +impl ConnectingStream { + pub fn new(peer_addr: &A, common: Arc>) -> Result> { + let req = Mutex::new(ConnectReq::new(peer_addr)); + let new_self = Self { + common, + peer_addr: peer_addr.clone(), + req, + connected: AtomicBool::new(false), + }; + Ok(Arc::new(new_self)) + } + + /// Connect to the peer address. + pub fn connect(self: &Arc) -> Result<()> { + let pollee = self.common.pollee(); + pollee.reset_events(); + + self.initiate_async_connect(); + + if self.common.nonblocking() { + return_errno!(EINPROGRESS, "non-blocking connect request in progress"); + } + + // Wait for the async connect to complete + let mask = IoEvents::OUT; + let poller = Poller::new(); + pollee.connect_poller(mask, &poller); + let mut timeout = self.common.send_timeout(); + loop { + let events = pollee.poll(mask, None); + if !events.is_empty() { + break; + } + let ret = poller.wait_timeout(timeout.as_mut()); + if let Err(e) = ret { + let errno = e.errno(); + warn!("connect wait errno = {:?}", errno); + match errno { + ETIMEDOUT => { + // Cancel connect request if timeout. No need to wait for cancel to complete. 
+ self.cancel_connect_request(false); + // This error code is the same as the connect timeout error code on Linux + return_errno!(EINPROGRESS, "timeout reached") + } + _ => { + return_errno!(e.errno(), "wait error") + } + } + } + } + + // Finish the async connect + let req = self.req.lock(); + if let Some(e) = req.errno { + return_errno!(e, "connect failed"); + } + Ok(()) + } + + fn initiate_async_connect(self: &Arc) { + let io_uring = self.common.io_uring(); + let mut req = self.req.lock(); + // Skip if there is pending request + if req.io_handle.is_some() { + return; + } + + let arc_self = self.clone(); + let callback = move |retval: i32| { + // Guard against Iago attack + assert!(retval <= 0); + debug!("connect request complete with retval: {}", retval); + + let mut req = arc_self.req.lock(); + // Release the handle to the async connect + req.io_handle.take(); + + if retval == 0 { + arc_self.connected.store(true, Ordering::Relaxed); + arc_self.common.pollee().add_events(IoEvents::OUT); + } else { + // Store the errno + let errno = Errno::from(-retval as u32); + req.errno = Some(errno); + drop(req); + arc_self.common.set_errno(errno); + arc_self.connected.store(false, Ordering::Relaxed); + + let events = if errno == ENOTCONN || errno == ECONNRESET || errno == ECONNREFUSED { + IoEvents::HUP | IoEvents::IN | IoEvents::ERR + } else { + IoEvents::ERR + }; + arc_self.common.pollee().add_events(events); + } + }; + + let host_fd = self.common.host_fd() as _; + let c_addr_ptr = req.c_addr.as_ptr(); + let c_addr_len = req.c_addr_len; + let io_handle = unsafe { + io_uring.connect( + Fd(host_fd), + c_addr_ptr as *const libc::sockaddr, + c_addr_len as u32, + callback, + ) + }; + req.io_handle = Some(io_handle); + } + + pub fn cancel_connect_request(&self, need_wait: bool) { + { + let io_uring = self.common.io_uring(); + let req = self.req.lock(); + if let Some(io_handle) = &req.io_handle { + unsafe { io_uring.cancel(io_handle) }; + } else { + return; + } + } + + // Wait for the
cancel to complete if needed + if !need_wait { + return; + } + + let poller = Poller::new(); + let mask = IoEvents::ERR | IoEvents::IN; + self.common.pollee().connect_poller(mask, &poller); + + loop { + let pending_request_exist = { + let req = self.req.lock(); + req.io_handle.is_some() + }; + + if pending_request_exist { + let mut timeout = Some(Duration::from_secs(10)); + let ret = poller.wait_timeout(timeout.as_mut()); + if let Err(e) = ret { + warn!("wait cancel connect request error = {:?}", e.errno()); + continue; + } + } else { + break; + } + } + } + + #[allow(dead_code)] + pub fn peer_addr(&self) -> &A { + &self.peer_addr + } + + pub fn common(&self) -> &Arc> { + &self.common + } + + // This can be used in connecting state to check non-blocking connect status. + pub fn check_connection(&self) -> bool { + // It is fine whether the load happens before or after the store operation + self.connected.load(Ordering::Relaxed) + } +} + +impl ConnectReq { + pub fn new(peer_addr: &A) -> Self { + let (c_addr_storage, c_addr_len) = peer_addr.to_c_storage(); + Self { + io_handle: None, + c_addr: UntrustedBox::new(c_addr_storage), + c_addr_len, + errno: None, + phantom_data: PhantomData, + } + } +} + +impl std::fmt::Debug for ConnectingStream { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ConnectingStream") + .field("common", &self.common) + .field("peer_addr", &self.peer_addr) + .field("req", &*self.req.lock()) + .field("connected", &self.connected) + .finish() + } +} + +impl std::fmt::Debug for ConnectReq { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ConnectReq") + .field("io_handle", &self.io_handle) + .field("errno", &self.errno) + .finish() + } +} diff --git a/src/libos/src/net/socket/uring/stream/states/connected/mod.rs b/src/libos/src/net/socket/uring/stream/states/connected/mod.rs new file mode 100644 index 00000000..3e30adc3 --- /dev/null +++ 
b/src/libos/src/net/socket/uring/stream/states/connected/mod.rs @@ -0,0 +1,114 @@ +use atomic::Ordering; + +use self::recv::Receiver; +use self::send::Sender; +use crate::fs::IoEvents as Events; +use crate::net::socket::sockopt::SockOptName; +use crate::net::socket::uring::common::Common; +use crate::net::socket::uring::runtime::Runtime; +use crate::prelude::*; + +mod recv; +mod send; + +pub struct ConnectedStream { + common: Arc>, + sender: Sender, + receiver: Receiver, +} + +impl ConnectedStream { + pub fn new(common: Arc>) -> Arc { + common.pollee().reset_events(); + common.pollee().add_events(Events::OUT); + + let fd = common.host_fd(); + + let sender = Sender::new(); + let receiver = Receiver::new(); + let new_self = Arc::new(Self { + common, + sender, + receiver, + }); + + // Start async recv requests right as early as possible to support poll and + // improve performance. If we don't start recv requests early, the poll() + // might block forever when user just invokes poll(Event::In) without read(). + // Once we have recv requests completed, we can have Event::In in the events. + new_self.initiate_async_recv(); + + new_self + } + + pub fn common(&self) -> &Arc> { + &self.common + } + + pub fn shutdown(&self, how: Shutdown) -> Result<()> { + // Do host shutdown + // For shutdown write, don't call host_shutdown until the content in the pending buffer is sent. + // For shutdown read, ignore the pending buffer. + let (shut_write, send_buf_is_empty, shut_read) = ( + how.should_shut_write(), + self.sender.is_empty(), + how.should_shut_read(), + ); + match (shut_write, send_buf_is_empty, shut_read) { + // As long as send buf is empty, just shutdown. + (_, true, _) => self.common.host_shutdown(how)?, + // If not shutdown write, just shutdown. + (false, _, _) => self.common.host_shutdown(how)?, + // If shutdown both but the send buf is not empty, only shutdown read. 
+ (true, false, true) => self.common.host_shutdown(Shutdown::Read)?, + // If shutdown write but the send buf is not empty, don't do shutdown. + (true, false, false) => {} + } + + // Set internal state and trigger events. + if shut_read { + self.receiver.shutdown(); + self.common.pollee().add_events(Events::IN); + } + if shut_write { + self.sender.shutdown(); + self.common.pollee().add_events(Events::OUT); + } + + if shut_read && shut_write { + self.common.pollee().add_events(Events::HUP); + } + + Ok(()) + } + + pub fn set_closed(&self) { + // Mark the sender and receiver to shutdown to prevent submitting new requests. + self.receiver.shutdown(); + self.sender.shutdown(); + + self.common.set_closed(); + } + + // Other methods are implemented in the send and receive modules +} + +impl std::fmt::Debug for ConnectedStream { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ConnectedStream") + .field("common", &self.common) + .field("sender", &self.sender) + .field("receiver", &self.receiver) + .finish() + } +} + +fn new_msghdr(iovecs_ptr: *mut libc::iovec, iovecs_len: usize) -> libc::msghdr { + use std::mem::MaybeUninit; + // Safety. Setting all fields to zeros is a valid state for msghdr. 
+ let mut msghdr: libc::msghdr = unsafe { MaybeUninit::zeroed().assume_init() }; + msghdr.msg_iov = iovecs_ptr; + msghdr.msg_iovlen = iovecs_len as _; + // We do want to leave all other fields as zeros + msghdr +} diff --git a/src/libos/src/net/socket/uring/stream/states/connected/recv.rs b/src/libos/src/net/socket/uring/stream/states/connected/recv.rs new file mode 100644 index 00000000..3a82cdfe --- /dev/null +++ b/src/libos/src/net/socket/uring/stream/states/connected/recv.rs @@ -0,0 +1,458 @@ +use core::hint; +use core::sync::atomic::AtomicBool; +use core::time::Duration; +use std::mem::MaybeUninit; +use std::ptr::{self}; + +use atomic::Ordering; +use io_uring_callback::{Fd, IoHandle}; +use sgx_untrusted_alloc::{MaybeUntrusted, UntrustedBox}; + +use super::ConnectedStream; +use crate::net::socket::uring::runtime::Runtime; +use crate::net::socket::uring::stream::RECV_BUF_SIZE; +use crate::prelude::*; +use crate::untrusted::UntrustedCircularBuf; +use crate::util::sync::{Mutex, MutexGuard}; + +use crate::events::Poller; +use crate::fs::IoEvents as Events; + +impl ConnectedStream { + pub fn recvmsg(self: &Arc, bufs: &mut [&mut [u8]], flags: RecvFlags) -> Result { + let total_len: usize = bufs.iter().map(|buf| buf.len()).sum(); + if total_len == 0 { + return Ok(0); + } + + let mut total_received = 0; + let mut iov_buffer_index = 0; + let mut iov_buffer_offset = 0; + + let mask = Events::IN; + // Initialize the poller only when needed + let mut poller = None; + let mut timeout = self.common.recv_timeout(); + loop { + // Attempt to read + let res = self.try_recvmsg(bufs, flags, iov_buffer_index, iov_buffer_offset); + + match res { + Ok((received_size, index, offset)) => { + total_received += received_size; + + if !flags.contains(RecvFlags::MSG_WAITALL) || total_received == total_len { + return Ok(total_received); + } else { + // save the index and offset for the next round + iov_buffer_index = index; + iov_buffer_offset = offset; + } + } + Err(e) => { + if e.errno() 
!= EAGAIN { + return Err(e); + } + } + }; + + if self.common.nonblocking() || flags.contains(RecvFlags::MSG_DONTWAIT) { + return_errno!(EAGAIN, "no data are present to be received"); + } + + // Wait for interesting events by polling + if poller.is_none() { + let new_poller = Poller::new(); + self.common.pollee().connect_poller(mask, &new_poller); + poller = Some(new_poller); + } + let events = self.common.pollee().poll(mask, None); + if events.is_empty() { + let ret = poller.as_ref().unwrap().wait_timeout(timeout.as_mut()); + if let Err(e) = ret { + warn!("recv wait errno = {:?}", e.errno()); + // For recv with MSG_WAITALL, return total received bytes if timeout or interrupt + if flags.contains(RecvFlags::MSG_WAITALL) && total_received > 0 { + return Ok(total_received); + } + match e.errno() { + ETIMEDOUT => { + return_errno!(EAGAIN, "timeout reached") + } + _ => { + return_errno!(e.errno(), "wait error") + } + } + } + } + } + } + + fn try_recvmsg( + self: &Arc, + bufs: &mut [&mut [u8]], + flags: RecvFlags, + iov_buffer_index: usize, + iov_buffer_offset: usize, + ) -> Result<(usize, usize, usize)> { + let mut inner = self.receiver.inner.lock(); + + if !flags.is_empty() + && flags.intersects(!(RecvFlags::MSG_DONTWAIT | RecvFlags::MSG_WAITALL)) + { + warn!("Unsupported flags: {:?}", flags); + return_errno!(EINVAL, "flags not supported"); + } + + let res = { + let mut total_consumed = 0; + let mut iov_buffer_index = iov_buffer_index; + let mut iov_buffer_offset = iov_buffer_offset; + + // save the received data from bufs[iov_buffer_index][iov_buffer_offset..] 
+ for (_, buf) in bufs.iter_mut().skip(iov_buffer_index).enumerate() { + let this_consumed = inner.recv_buf.consume(&mut buf[iov_buffer_offset..]); + if this_consumed == 0 { + break; + } + total_consumed += this_consumed; + + // if the buffer is not full, then the try_recvmsg will be used again + // next time, the data will be stored from the offset + if this_consumed < buf[iov_buffer_offset..].len() { + iov_buffer_offset += this_consumed; + break; + } else { + iov_buffer_index += 1; + iov_buffer_offset = 0; + } + } + (total_consumed, iov_buffer_index, iov_buffer_offset) + }; + + if self.receiver.need_update() { + // Only update the recv buf when it is empty and there is no pending recv request + if inner.recv_buf.is_empty() && inner.io_handle.is_none() { + self.receiver.set_need_update(false); + inner.update_buf_size(RECV_BUF_SIZE.load(Ordering::Relaxed)); + } + } + + if inner.end_of_file { + return Ok(res); + } + + if inner.recv_buf.is_empty() { + // Mark the socket as non-readable + self.common.pollee().del_events(Events::IN); + } + + if res.0 > 0 { + self.do_recv(&mut inner); + return Ok(res); + } + + // Only when there are no data available in the recv buffer, shall we check + // the following error conditions. + // + // Case 1: If the read side of the connection has been shutdown... + if inner.is_shutdown { + return_errno!(EPIPE, "read side is shutdown"); + } + // Case 2: If the connenction has been broken... + if let Some(errno) = inner.fatal { + // Reset error + inner.fatal = None; + self.common.pollee().del_events(Events::ERR); + return_errno!(errno, "read failed"); + } + + self.do_recv(&mut inner); + return_errno!(EAGAIN, "try read again"); + } + + fn do_recv(self: &Arc, inner: &mut MutexGuard) { + if inner.recv_buf.is_full() + || inner.is_shutdown + || inner.io_handle.is_some() + || inner.end_of_file + || self.common.is_closed() + { + // Delete ERR events from sender. 
If io_handle is some, the recv request must be + // pending and the events can't be for the reciever. Just delete this event. + // This can happen when send request is timeout and canceled. + let events = self.common.pollee().poll(Events::IN, None); + if events.contains(Events::ERR) && inner.io_handle.is_some() { + self.common.pollee().del_events(Events::ERR); + } + return; + } + + // Init the callback invoked upon the completion of the async recv + let stream = self.clone(); + let complete_fn = move |retval: i32| { + // let mut inner = stream.receiver.inner.lock().unwrap(); + let mut inner = stream.receiver.inner.lock(); + trace!("recv request complete with retval: {:?}", retval); + + // Release the handle to the async recv + inner.io_handle.take(); + + // Handle error + if retval < 0 { + // TODO: guard against Iago attack through errno + // We should return here, The error may be due to network reasons + // or because the request was cancelled. We don't want to start a + // new request after cancelled a request. + let errno = Errno::from(-retval as u32); + inner.fatal = Some(errno); + stream.common.set_errno(errno); + + let events = if errno == ENOTCONN || errno == ECONNRESET || errno == ECONNREFUSED { + Events::HUP | Events::IN | Events::ERR + } else { + Events::ERR + }; + stream.common.pollee().add_events(events); + + return; + } + // Handle end of file + else if retval == 0 { + inner.end_of_file = true; + stream.common.pollee().add_events(Events::IN); + return; + } + + // Handle the normal case of a successful read + let nbytes = retval as usize; + inner.recv_buf.produce_without_copy(nbytes); + + // Now that we have produced non-zero bytes, the buf must become + // ready to read. 
+ stream.common.pollee().add_events(Events::IN); + + stream.do_recv(&mut inner); + }; + + // Generate the async recv request + let msghdr_ptr = inner.new_recv_req(); + + // Submit the async recv to io_uring + let io_uring = self.common.io_uring(); + let host_fd = Fd(self.common.host_fd() as _); + + let handle = unsafe { io_uring.recvmsg(host_fd, msghdr_ptr, 0, complete_fn) }; + inner.io_handle.replace(handle); + } + + pub(super) fn initiate_async_recv(self: &Arc) { + // trace!("initiate async recv"); + let mut inner = self.receiver.inner.lock(); + self.do_recv(&mut inner); + } + + pub fn cancel_recv_requests(&self) { + { + let inner = self.receiver.inner.lock(); + if let Some(io_handle) = &inner.io_handle { + let io_uring = self.common.io_uring(); + unsafe { io_uring.cancel(io_handle) }; + } else { + return; + } + } + + // wait for the cancel to complete + let poller = Poller::new(); + let mask = Events::ERR | Events::IN; + self.common.pollee().connect_poller(mask, &poller); + + loop { + let pending_request_exist = { + let inner = self.receiver.inner.lock(); + inner.io_handle.is_some() + }; + + if pending_request_exist { + let mut timeout = Some(Duration::from_secs(10)); + let ret = poller.wait_timeout(timeout.as_mut()); + if let Err(e) = ret { + warn!("wait cancel recv request error = {:?}", e.errno()); + continue; + } + } else { + break; + } + } + } + + pub fn bytes_to_consume(self: &Arc) -> usize { + let inner = self.receiver.inner.lock(); + inner.recv_buf.consumable() + } + + // This function will try to update the kernel recv buf size. + // For socket recv, there will always be a pending request in advance. Thus,we can only update the kernel + // buffer when a recv request is done and the kernel buffer is empty. Here, we just set the update flag. 
+ pub fn try_update_recv_buf_size(&self, buf_size: usize) { + let pre_buf_size = RECV_BUF_SIZE.swap(buf_size, Ordering::Relaxed); + if buf_size == pre_buf_size { + return; + } + + self.receiver.set_need_update(true); + } +} + +pub struct Receiver { + inner: Mutex, + need_update: AtomicBool, +} + +impl Receiver { + pub fn new() -> Self { + let inner = Mutex::new(Inner::new()); + let need_update = AtomicBool::new(false); + + Self { inner, need_update } + } + + pub fn shutdown(&self) { + let mut inner = self.inner.lock(); + inner.is_shutdown = true; + } + + pub fn set_need_update(&self, need_update: bool) { + self.need_update.store(need_update, Ordering::Relaxed) + } + + pub fn need_update(&self) -> bool { + self.need_update.load(Ordering::Relaxed) + } +} + +impl std::fmt::Debug for Receiver { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Receiver") + .field("inner", &self.inner.lock()) + .finish() + } +} + +struct Inner { + recv_buf: UntrustedCircularBuf, + recv_req: UntrustedBox, + io_handle: Option, + is_shutdown: bool, + end_of_file: bool, + fatal: Option, +} + +// Safety. `RecvReq` does not implement `Send`. But since all pointers in `RecvReq` +// refer to `recv_buf`, we can be sure that it is ok for `RecvReq` to move between +// threads. All other fields in `RecvReq` implement `Send` as well. So the entirety +// of `Inner` is `Send`-safe. 
+unsafe impl Send for Inner {} + +impl Inner { + pub fn new() -> Self { + Self { + recv_buf: UntrustedCircularBuf::with_capacity(RECV_BUF_SIZE.load(Ordering::Relaxed)), + recv_req: UntrustedBox::new_uninit(), + io_handle: None, + is_shutdown: false, + end_of_file: false, + fatal: None, + } + } + + fn update_buf_size(&mut self, buf_size: usize) { + debug_assert!(self.recv_buf.is_empty() && self.io_handle.is_none()); + let new_recv_buf = UntrustedCircularBuf::with_capacity(buf_size); + self.recv_buf = new_recv_buf; + } + + /// Constructs a new recv request according to the receiver's internal state. + /// + /// The new `RecvReq` will be put into `self.recv_req`, which is a location that is + /// accessible by io_uring. A pointer to the C version of the resulting `RecvReq`, + /// which is `libc::msghdr`, will be returned. + /// + /// The buffer used in the new `RecvReq` is part of `self.recv_buf`. + pub fn new_recv_req(&mut self) -> *mut libc::msghdr { + let (iovecs, iovecs_len) = self.gen_iovecs_from_recv_buf(); + + let msghdr_ptr: *mut libc::msghdr = &mut self.recv_req.msg; + let iovecs_ptr: *mut libc::iovec = &mut self.recv_req.iovecs as *mut _ as _; + + let msg = super::new_msghdr(iovecs_ptr, iovecs_len); + + self.recv_req.msg = msg; + self.recv_req.iovecs = iovecs; + + msghdr_ptr + } + + fn gen_iovecs_from_recv_buf(&mut self) -> ([libc::iovec; 2], usize) { + let mut iovecs_len = 0; + let mut iovecs = unsafe { MaybeUninit::<[libc::iovec; 2]>::uninit().assume_init() }; + self.recv_buf.with_producer_view(|part0, part1| { + debug_assert!(part0.len() > 0); + + iovecs[0] = libc::iovec { + iov_base: part0.as_ptr() as _, + iov_len: part0.len() as _, + }; + + iovecs[1] = if part1.len() > 0 { + iovecs_len = 2; + libc::iovec { + iov_base: part1.as_ptr() as _, + iov_len: part1.len() as _, + } + } else { + iovecs_len = 1; + libc::iovec { + iov_base: ptr::null_mut(), + iov_len: 0, + } + }; + + // Only access the producer's buffer; zero bytes produced for now. 
+ 0 + }); + debug_assert!(iovecs_len > 0); + (iovecs, iovecs_len) + } +} + +impl std::fmt::Debug for Inner { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Inner") + .field("recv_buf", &self.recv_buf) + .field("io_handle", &self.io_handle) + .field("is_shutdown", &self.is_shutdown) + .field("end_of_file", &self.end_of_file) + .field("fatal", &self.fatal) + .finish() + } +} + +#[repr(C)] +struct RecvReq { + msg: libc::msghdr, + iovecs: [libc::iovec; 2], +} + +// Safety. RecvReq is a C-style struct. +unsafe impl MaybeUntrusted for RecvReq {} + +// Acquired by `IoUringCell`. +impl Copy for RecvReq {} + +impl Clone for RecvReq { + fn clone(&self) -> Self { + *self + } +} diff --git a/src/libos/src/net/socket/uring/stream/states/connected/send.rs b/src/libos/src/net/socket/uring/stream/states/connected/send.rs new file mode 100644 index 00000000..6b0e8088 --- /dev/null +++ b/src/libos/src/net/socket/uring/stream/states/connected/send.rs @@ -0,0 +1,466 @@ +use core::hint; +use core::sync::atomic::AtomicBool; +use core::time::Duration; +use std::mem::MaybeUninit; +use std::ptr::{self}; + +use atomic::Ordering; +use io_uring_callback::{Fd, IoHandle}; +use log::error; +use sgx_untrusted_alloc::{MaybeUntrusted, UntrustedBox}; + +use super::ConnectedStream; +use crate::net::socket::uring::runtime::Runtime; +use crate::net::socket::uring::stream::SEND_BUF_SIZE; +use crate::prelude::*; +use crate::untrusted::UntrustedCircularBuf; + +use crate::util::sync::{Mutex, MutexGuard}; + +use crate::events::Poller; +use crate::fs::IoEvents as Events; + +impl ConnectedStream { + // We make sure the all the buffer contents are buffered in kernel and then return. + pub fn sendmsg(self: &Arc, bufs: &[&[u8]], flags: SendFlags) -> Result { + let total_len: usize = bufs.iter().map(|buf| buf.len()).sum(); + if total_len == 0 { + return Ok(0); + } + + let mut send_len = 0; + // variables to track the position of async sendmsg. 
+ let mut iov_buf_id = 0; // user buffer id tracker + let mut iov_buf_index = 0; // user buffer index tracker + + let mask = Events::OUT; + // Initialize the poller only when needed + let mut poller = None; + let mut timeout = self.common.send_timeout(); + loop { + // Attempt to write + let res = self.try_sendmsg(bufs, flags, &mut iov_buf_id, &mut iov_buf_index); + if let Ok(len) = res { + send_len += len; + // Sent all or sent partial but it is nonblocking, return bytes sent + if send_len == total_len + || self.common.nonblocking() + || flags.contains(SendFlags::MSG_DONTWAIT) + { + return Ok(send_len); + } + } else if !res.has_errno(EAGAIN) { + return res; + } + + // Still some buffer contents pending + if self.common.nonblocking() || flags.contains(SendFlags::MSG_DONTWAIT) { + return_errno!(EAGAIN, "try write again"); + } + + // Wait for interesting events by polling + if poller.is_none() { + let new_poller = Poller::new(); + self.common.pollee().connect_poller(mask, &new_poller); + poller = Some(new_poller); + } + + let events = self.common.pollee().poll(mask, None); + if events.is_empty() { + let ret = poller.as_ref().unwrap().wait_timeout(timeout.as_mut()); + if let Err(e) = ret { + warn!("send wait errno = {:?}", e.errno()); + match e.errno() { + ETIMEDOUT => { + // Just cancel send requests if timeout + self.cancel_send_requests(); + return_errno!(EAGAIN, "timeout reached") + } + _ => { + return_errno!(e.errno(), "wait error") + } + } + } + } + } + } + + fn try_sendmsg( + self: &Arc, + bufs: &[&[u8]], + flags: SendFlags, + iov_buf_id: &mut usize, + iov_buf_index: &mut usize, + ) -> Result { + let mut inner = self.sender.inner.lock(); + + if !flags.is_empty() + && flags.intersects( + !(SendFlags::MSG_DONTWAIT | SendFlags::MSG_NOSIGNAL | SendFlags::MSG_MORE), + ) + { + error!("Not supported flags: {:?}", flags); + return_errno!(EINVAL, "not supported flags"); + } + + // Check for error condition before write. + // + // Case 1. 
If the write side of the connection has been shutdown... + if inner.is_shutdown() { + return_errno!(EPIPE, "write side is shutdown"); + } + // Case 2. If the connenction has been broken... + if let Some(errno) = inner.fatal { + // Reset error + inner.fatal = None; + self.common.pollee().del_events(Events::ERR); + return_errno!(errno, "write failed"); + } + + // Copy data from the bufs to the send buffer + // If the send buffer is full, update the user buffer tracker, return error to wait for events + // And once there is free space, continue from the user buffer tracker + let nbytes = { + let mut total_produced = 0; + let last_time_buf_id = iov_buf_id.clone(); + let mut last_time_buf_idx = iov_buf_index.clone(); + for (_i, buf) in bufs.iter().skip(last_time_buf_id).enumerate() { + let i = _i + last_time_buf_id; // After skipping ,the index still starts from 0 + let this_produced = inner.send_buf.produce(&buf[last_time_buf_idx..]); + total_produced += this_produced; + if this_produced < buf[last_time_buf_idx..].len() { + // Send buffer is full. + *iov_buf_id = i; + *iov_buf_index = last_time_buf_idx + this_produced; + break; + } else { + // For next buffer, start from the front + last_time_buf_idx = 0; + } + } + total_produced + }; + + if inner.send_buf.is_full() { + // Mark the socket as non-writable + self.common.pollee().del_events(Events::OUT); + } + + // Since the send buffer is not empty, we can try to flush the buffer + if inner.io_handle.is_none() { + self.do_send(&mut inner); + } + + if nbytes > 0 { + Ok(nbytes) + } else { + return_errno!(EAGAIN, "try write again"); + } + } + + fn do_send(self: &Arc, inner: &mut MutexGuard) { + // This function can also be called even if the socket is set to shutdown by shutdown syscall. This is due to the + // async behaviour that the kernel may return to user before actually issuing the request. We should + // keep sending the request as long as the send buffer is not empty even if the socket is shutdown. 
+ debug_assert!(inner.is_shutdown != ShutdownStatus::PostShutdown); + debug_assert!(!inner.send_buf.is_empty()); + debug_assert!(inner.io_handle.is_none()); + + // Init the callback invoked upon the completion of the async send + let stream = self.clone(); + let complete_fn = move |retval: i32| { + let mut inner = stream.sender.inner.lock(); + + trace!("send request complete with retval: {}", retval); + // Release the handle to the async send + inner.io_handle.take(); + + // Handle error + if retval < 0 { + // TODO: guard against Iago attack through errno + // TODO: should we ignore EINTR and try again? + let errno = Errno::from(-retval as u32); + + inner.fatal = Some(errno); + stream.common.set_errno(errno); + stream.common.pollee().add_events(Events::ERR); + return; + } + assert!(retval != 0); + + // Handle the normal case of a successful write + let nbytes = retval as usize; + inner.send_buf.consume_without_copy(nbytes); + + // Now that we have consume non-zero bytes, the buf must become + // ready to write. + stream.common.pollee().add_events(Events::OUT); + + // Attempt to send again if there are available data in the buf. + if !inner.send_buf.is_empty() { + stream.do_send(&mut inner); + } else if inner.is_shutdown == ShutdownStatus::PreShutdown { + // The buffer is empty and the write side is shutdown by the user. We can safely shutdown host file here. + let _ = stream.common.host_shutdown(Shutdown::Write); + inner.is_shutdown = ShutdownStatus::PostShutdown + } else if stream.sender.need_update() { + // send_buf is empty. 
We can try to update the send_buf + stream.sender.set_need_update(false); + inner.update_buf_size(SEND_BUF_SIZE.load(Ordering::Relaxed)); + } + }; + + // Generate the async send request + let msghdr_ptr = inner.new_send_req(); + + trace!("send submit request"); + // Submit the async send to io_uring + let io_uring = self.common.io_uring(); + let host_fd = Fd(self.common.host_fd() as _); + let handle = unsafe { io_uring.sendmsg(host_fd, msghdr_ptr, 0, complete_fn) }; + inner.io_handle.replace(handle); + } + + pub fn cancel_send_requests(&self) { + let io_uring = self.common.io_uring(); + let inner = self.sender.inner.lock(); + if let Some(io_handle) = &inner.io_handle { + unsafe { io_uring.cancel(io_handle) }; + } + } + + // This function will try to update the kernel buf size. + // If the kernel buf is currently empty, the size will be updated immediately. + // If the kernel buf is not empty, update the flag in Sender and update the kernel buf after send. + pub fn try_update_send_buf_size(&self, buf_size: usize) { + let pre_buf_size = SEND_BUF_SIZE.swap(buf_size, Ordering::Relaxed); + if pre_buf_size == buf_size { + return; + } + + // Try to acquire the lock. If success, try directly update here. + // If failure, don't wait because there is pending send request. + if let Some(mut inner) = self.sender.inner.try_lock() { + if inner.send_buf.is_empty() && inner.io_handle.is_none() { + inner.update_buf_size(buf_size); + return; + } + } + + // Can't easily aquire lock or the sendbuf is not empty. Update the flag only + self.sender.set_need_update(true); + } + + // Normally, We will always try to send as long as the kernel send buf is not empty. However, if the user calls close, we will wait LINGER time + // and then cancel on-going or new-issued send requests. 
+ pub fn try_empty_send_buf_when_close(&self) { + // let inner = self.sender.inner.lock().unwrap(); + let inner = self.sender.inner.lock(); + debug_assert!(inner.is_shutdown()); + if inner.send_buf.is_empty() { + return; + } + + // Wait for linger time to empty the kernel buffer or cancel subsequent requests. + drop(inner); + const DEFUALT_LINGER_TIME: usize = 10; + let poller = Poller::new(); + let mask = Events::ERR | Events::OUT; + self.common.pollee().connect_poller(mask, &poller); + + loop { + let pending_request_exist = { + // let inner = self.sender.inner.lock().unwrap(); + let inner = self.sender.inner.lock(); + inner.io_handle.is_some() + }; + + if pending_request_exist { + let mut timeout = Some(Duration::from_secs(DEFUALT_LINGER_TIME as u64)); + let ret = poller.wait_timeout(timeout.as_mut()); + trace!("wait empty send buffer ret = {:?}", ret); + if let Err(_) = ret { + // No complete request to wake. Just cancel the send requests. + let io_uring = self.common.io_uring(); + let inner = self.sender.inner.lock(); + if let Some(io_handle) = &inner.io_handle { + unsafe { io_uring.cancel(io_handle) }; + // Loop again to wait the cancel request to complete + continue; + } else { + // No pending request, just break + break; + } + } + } else { + // There is no pending requests + break; + } + } + } +} + +pub struct Sender { + inner: Mutex, + need_update: AtomicBool, +} + +impl Sender { + pub fn new() -> Self { + let inner = Mutex::new(Inner::new()); + let need_update = AtomicBool::new(false); + Self { inner, need_update } + } + + pub fn shutdown(&self) { + let mut inner = self.inner.lock(); + inner.is_shutdown = ShutdownStatus::PreShutdown; + } + + pub fn is_empty(&self) -> bool { + let inner = self.inner.lock(); + inner.send_buf.is_empty() + } + + pub fn set_need_update(&self, need_update: bool) { + self.need_update.store(need_update, Ordering::Relaxed) + } + + pub fn need_update(&self) -> bool { + self.need_update.load(Ordering::Relaxed) + } +} + +impl 
std::fmt::Debug for Sender { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Sender") + .field("inner", &self.inner.lock()) + .finish() + } +} + +struct Inner { + send_buf: UntrustedCircularBuf, + send_req: UntrustedBox, + io_handle: Option, + is_shutdown: ShutdownStatus, + fatal: Option, +} + +// Safety. `SendReq` does not implement `Send`. But since all pointers in `SengReq` +// refer to `send_buf`, we can be sure that it is ok for `SendReq` to move between +// threads. All other fields in `SendReq` implement `Send` as well. So the entirety +// of `Inner` is `Send`-safe. +unsafe impl Send for Inner {} + +impl Inner { + pub fn new() -> Self { + Self { + send_buf: UntrustedCircularBuf::with_capacity(SEND_BUF_SIZE.load(Ordering::Relaxed)), + send_req: UntrustedBox::new_uninit(), + io_handle: None, + is_shutdown: ShutdownStatus::Running, + fatal: None, + } + } + + fn update_buf_size(&mut self, buf_size: usize) { + debug_assert!(self.send_buf.is_empty() && self.io_handle.is_none()); + let new_send_buf = UntrustedCircularBuf::with_capacity(buf_size); + self.send_buf = new_send_buf; + } + + pub fn is_shutdown(&self) -> bool { + self.is_shutdown == ShutdownStatus::PreShutdown + || self.is_shutdown == ShutdownStatus::PostShutdown + } + + /// Constructs a new send request according to the sender's internal state. + /// + /// The new `SendReq` will be put into `self.send_req`, which is a location that is + /// accessible by io_uring. A pointer to the C version of the resulting `SendReq`, + /// which is `libc::msghdr`, will be returned. + /// + /// The buffer used in the new `SendReq` is part of `self.send_buf`. 
+ pub fn new_send_req(&mut self) -> *mut libc::msghdr { + let (iovecs, iovecs_len) = self.gen_iovecs_from_send_buf(); + + let msghdr_ptr: *mut libc::msghdr = &mut self.send_req.msg; + let iovecs_ptr: *mut libc::iovec = &mut self.send_req.iovecs as *mut _ as _; + + let msg = super::new_msghdr(iovecs_ptr, iovecs_len); + + self.send_req.msg = msg; + self.send_req.iovecs = iovecs; + + msghdr_ptr + } + + fn gen_iovecs_from_send_buf(&mut self) -> ([libc::iovec; 2], usize) { + let mut iovecs_len = 0; + let mut iovecs = unsafe { MaybeUninit::<[libc::iovec; 2]>::uninit().assume_init() }; + self.send_buf.with_consumer_view(|part0, part1| { + debug_assert!(part0.len() > 0); + + iovecs[0] = libc::iovec { + iov_base: part0.as_ptr() as _, + iov_len: part0.len() as _, + }; + + iovecs[1] = if part1.len() > 0 { + iovecs_len = 2; + libc::iovec { + iov_base: part1.as_ptr() as _, + iov_len: part1.len() as _, + } + } else { + iovecs_len = 1; + libc::iovec { + iov_base: ptr::null_mut(), + iov_len: 0, + } + }; + + // Only access the consumer's buffer; zero bytes consumed for now. + 0 + }); + debug_assert!(iovecs_len > 0); + (iovecs, iovecs_len) + } +} + +impl std::fmt::Debug for Inner { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Inner") + .field("send_buf", &self.send_buf) + .field("io_handle", &self.io_handle) + .field("is_shutdown", &self.is_shutdown) + .field("fatal", &self.fatal) + .finish() + } +} + +#[repr(C)] +struct SendReq { + msg: libc::msghdr, + iovecs: [libc::iovec; 2], +} + +// Safety. SendReq is a C-style struct. +unsafe impl MaybeUntrusted for SendReq {} + +// Acquired by `IoUringCell`. 
+impl Copy for SendReq {} + +impl Clone for SendReq { + fn clone(&self) -> Self { + *self + } +} + +#[derive(Debug, PartialEq)] +enum ShutdownStatus { + Running, // not shutdown + PreShutdown, // start the shutdown process, set by calling shutdown syscall + PostShutdown, // shutdown process is done, set when the buffer is empty +} diff --git a/src/libos/src/net/socket/uring/stream/states/init.rs b/src/libos/src/net/socket/uring/stream/states/init.rs new file mode 100644 index 00000000..5edcecbe --- /dev/null +++ b/src/libos/src/net/socket/uring/stream/states/init.rs @@ -0,0 +1,72 @@ +use crate::fs::IoEvents; +use crate::net::socket::uring::common::Common; +use crate::net::socket::uring::runtime::Runtime; +use crate::prelude::*; + +/// A stream socket that is in its initial state. +pub struct InitStream { + common: Arc>, + inner: Mutex, +} + +struct Inner { + has_bound: bool, +} + +impl InitStream { + pub fn new(nonblocking: bool) -> Result> { + let common = Arc::new(Common::new(Type::STREAM, nonblocking, None)?); + common.pollee().add_events(IoEvents::HUP | IoEvents::OUT); + let inner = Mutex::new(Inner::new()); + let new_self = Self { common, inner }; + Ok(Arc::new(new_self)) + } + + pub fn new_with_common(common: Arc>) -> Result> { + let inner = Mutex::new(Inner { + has_bound: common.addr().is_some(), + }); + let new_self = Self { common, inner }; + Ok(Arc::new(new_self)) + } + + pub fn bind(&self, addr: &A) -> Result<()> { + let mut inner = self.inner.lock(); + if inner.has_bound { + return_errno!(EINVAL, "the socket is already bound to an address"); + } + + crate::net::socket::uring::common::do_bind(self.common.host_fd(), addr)?; + + inner.has_bound = true; + self.common.set_addr(addr); + Ok(()) + } + + pub fn common(&self) -> &Arc> { + &self.common + } +} + +impl std::fmt::Debug for InitStream { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("InitStream") + .field("common", &self.common) + .field("inner", 
&*self.inner.lock()) + .finish() + } +} + +impl Inner { + pub fn new() -> Self { + Self { has_bound: false } + } +} + +impl std::fmt::Debug for Inner { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Inner") + .field("has_bound", &self.has_bound) + .finish() + } +} diff --git a/src/libos/src/net/socket/uring/stream/states/listen.rs b/src/libos/src/net/socket/uring/stream/states/listen.rs new file mode 100644 index 00000000..24a91da0 --- /dev/null +++ b/src/libos/src/net/socket/uring/stream/states/listen.rs @@ -0,0 +1,430 @@ +use core::time::Duration; +use std::collections::VecDeque; +use std::marker::PhantomData; +use std::mem::size_of; + +use io_uring_callback::{Fd, IoHandle}; +use sgx_untrusted_alloc::{MaybeUntrusted, UntrustedBox}; + +use super::ConnectedStream; +use crate::events::Poller; +use crate::fs::IoEvents; +use crate::net::socket::uring::common::{do_close, Common}; +use crate::net::socket::uring::runtime::Runtime; +use crate::prelude::*; +use libc::ocall::shutdown as do_shutdown; + +// We issue the async accept request ahead of time. But with big backlog number, +// there will be too many pending requests, which could be harmful to the system. +const PENDING_ASYNC_ACCEPT_NUM_MAX: usize = 128; + +/// A listener stream, ready to accept incoming connections. +pub struct ListenerStream { + common: Arc>, + inner: Mutex>, +} + +impl ListenerStream { + /// Creates a new listener stream. + pub fn new(backlog: u32, common: Arc>) -> Result> { + // Here we use different variables for the backlog. For the libos, as we will issue async accept request + // ahead of time, and cacel when the socket closes, we set the libos backlog to a reasonable value which + // is no greater than the max value we set to save resources and make it more efficient. For the host, + // we just use the backlog value for maximum connection. 
+ let libos_backlog = std::cmp::min(backlog, PENDING_ASYNC_ACCEPT_NUM_MAX as u32); + let host_backlog = backlog; + + let inner = Inner::new(libos_backlog)?; + Self::do_listen(common.host_fd(), host_backlog)?; + + common.pollee().reset_events(); + let new_self = Arc::new(Self { + common, + inner: Mutex::new(inner), + }); + + // Start async accept requests right as early as possible to improve performance + { + let inner = new_self.inner.lock(); + new_self.initiate_async_accepts(inner); + } + + Ok(new_self) + } + + fn do_listen(host_fd: FileDesc, backlog: u32) -> Result<()> { + try_libc!(libc::ocall::listen(host_fd as _, backlog as _)); + Ok(()) + } + + pub fn accept(self: &Arc, nonblocking: bool) -> Result>> { + let mask = IoEvents::IN; + // Init the poller only when needed + let mut poller = None; + let mut timeout = self.common.recv_timeout(); + loop { + // Attempt to accept + let res = self.try_accept(nonblocking); + if !res.has_errno(EAGAIN) { + return res; + } + + if self.common.nonblocking() { + return_errno!(EAGAIN, "no connections are present to be accepted"); + } + + // Ensure the poller is initialized + if poller.is_none() { + let new_poller = Poller::new(); + self.common.pollee().connect_poller(mask, &new_poller); + poller = Some(new_poller); + } + // Wait for interesting events by polling + + let events = self.common.pollee().poll(mask, None); + if events.is_empty() { + let ret = poller.as_ref().unwrap().wait_timeout(timeout.as_mut()); + if let Err(e) = ret { + warn!("accept wait errno = {:?}", e.errno()); + match e.errno() { + ETIMEDOUT => { + return_errno!(EAGAIN, "timeout reached") + } + _ => { + return_errno!(e.errno(), "wait error") + } + } + } + } + } + } + + pub fn try_accept(self: &Arc, nonblocking: bool) -> Result>> { + let mut inner = self.inner.lock(); + + if let Some(errno) = inner.fatal { + // Reset error + inner.fatal = None; + self.common.pollee().del_events(IoEvents::ERR); + return_errno!(errno, "accept failed"); + } + + let (accepted_fd, 
accepted_addr) = inner.backlog.pop_completed_req().ok_or_else(|| { + self.common.pollee().del_events(IoEvents::IN); + errno!(EAGAIN, "try accept again") + })?; + + if !inner.backlog.has_completed_reqs() { + self.common.pollee().del_events(IoEvents::IN); + } + + self.initiate_async_accepts(inner); + + let common = { + let common = Arc::new(Common::with_host_fd(accepted_fd, Type::STREAM, nonblocking)); + common.set_peer_addr(&accepted_addr); + common + }; + let accepted_stream = ConnectedStream::new(common); + Ok(accepted_stream) + } + + fn initiate_async_accepts(self: &Arc, mut inner: MutexGuard>) { + let backlog = &mut inner.backlog; + while backlog.has_free_entries() { + backlog.start_new_req(self); + } + } + + pub fn common(&self) -> &Arc> { + &self.common + } + + pub fn cancel_accept_requests(&self) { + { + // Set the listener stream as closed to prevent submitting new request in the callback fn + self.common().set_closed(); + let io_uring = self.common.io_uring(); + let inner = self.inner.lock(); + for entry in inner.backlog.entries.iter() { + if let Entry::Pending { io_handle } = entry { + unsafe { io_uring.cancel(&io_handle) }; + } + } + } + + // wait for all the cancel requests to complete + let poller = Poller::new(); + let mask = IoEvents::ERR | IoEvents::IN; + self.common.pollee().connect_poller(mask, &poller); + + loop { + let pending_entry_exists = { + let inner = self.inner.lock(); + inner + .backlog + .entries + .iter() + .find(|entry| match entry { + Entry::Pending { .. 
} => true, + _ => false, + }) + .is_some() + }; + + if pending_entry_exists { + let mut timeout = Some(Duration::from_secs(20)); + let ret = poller.wait_timeout(timeout.as_mut()); + if let Err(e) = ret { + warn!("wait cancel accept request error = {:?}", e.errno()); + continue; + } + } else { + // Reset the stream for re-listen + self.common().reset_closed(); + return; + } + } + } + + pub fn shutdown(&self, how: Shutdown) -> Result<()> { + if how == Shutdown::Both { + self.common.host_shutdown(Shutdown::Both)?; + self.common + .pollee() + .add_events(IoEvents::IN | IoEvents::OUT | IoEvents::HUP); + } else if how.should_shut_read() { + self.common.host_shutdown(Shutdown::Read)?; + self.common.pollee().add_events(IoEvents::IN); + } else if how.should_shut_write() { + self.common.host_shutdown(Shutdown::Write)?; + self.common.pollee().add_events(IoEvents::OUT); + } + Ok(()) + } +} + +impl std::fmt::Debug for ListenerStream { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ListenerStream") + .field("common", &self.common) + .field("inner", &self.inner.lock()) + .finish() + } +} + +/// The mutable, internal state of a listener stream. +struct Inner { + backlog: Backlog, + fatal: Option, +} + +impl Inner { + pub fn new(backlog: u32) -> Result { + Ok(Inner { + backlog: Backlog::with_capacity(backlog as usize)?, + fatal: None, + }) + } +} + +impl std::fmt::Debug for Inner { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Inner") + .field("backlog", &self.backlog) + .field("fatal", &self.fatal) + .finish() + } +} + +/// An entry in the backlog. +#[derive(Debug)] +enum Entry { + /// The entry is free to use. + Free, + /// The entry is a pending accept request. + Pending { io_handle: IoHandle }, + /// The entry is a completed accept request. + Completed { host_fd: FileDesc }, +} + +impl Default for Entry { + fn default() -> Self { + Self::Free + } +} + +/// An async io_uring accept request. 
+#[derive(Copy, Clone)] +#[repr(C)] +struct AcceptReq { + c_addr: libc::sockaddr_storage, + c_addr_len: libc::socklen_t, +} + +// Safety. AcceptReq is a C-style struct with C-style fields. +unsafe impl MaybeUntrusted for AcceptReq {} + +/// A backlog of incoming connections of a listener stream. +/// +/// With backlog, we can start async accept requests, keep track of the pending requests, +/// and maintain the ones that have completed. +struct Backlog { + // The entries in the backlog. + entries: Box<[Entry]>, + // Arguments of the io_uring requests submitted for the entries in the backlog. + reqs: UntrustedBox<[AcceptReq]>, + // The indexes of completed entries. + completed: VecDeque, + // The number of free entries. + num_free: usize, + phantom_data: PhantomData, +} + +impl Backlog { + pub fn with_capacity(capacity: usize) -> Result { + if capacity == 0 { + return_errno!(EINVAL, "capacity cannot be zero"); + } + + let entries = (0..capacity) + .map(|_| Entry::Free) + .collect::>() + .into_boxed_slice(); + let reqs = UntrustedBox::new_uninit_slice(capacity); + let completed = VecDeque::new(); + let num_free = capacity; + let new_self = Self { + entries, + reqs, + completed, + num_free, + phantom_data: PhantomData, + }; + Ok(new_self) + } + + pub fn has_free_entries(&self) -> bool { + self.num_free > 0 + } + + /// Start a new async accept request, turning a free entry into a pending one. 
+ pub fn start_new_req(&mut self, stream: &Arc>) { + if stream.common.is_closed() { + return; + } + debug_assert!(self.has_free_entries()); + + let entry_idx = self + .entries + .iter() + .position(|entry| matches!(entry, Entry::Free)) + .unwrap(); + + let (c_addr_ptr, c_addr_len_ptr) = { + let accept_req = &mut self.reqs[entry_idx]; + accept_req.c_addr_len = size_of::() as _; + + let c_addr_ptr = &mut accept_req.c_addr as *mut _ as _; + let c_addr_len_ptr = &mut accept_req.c_addr_len as _; + (c_addr_ptr, c_addr_len_ptr) + }; + + let callback = { + let stream = stream.clone(); + move |retval: i32| { + let mut inner = stream.inner.lock(); + + trace!("accept request complete with retval: {:?}", retval); + + if retval < 0 { + // Since most errors that may result from the accept syscall are _not fatal_, + // we simply ignore the errno code and try again. + // + // According to the man page, Linux may report the network errors on an + // newly-accepted socket through the accept system call. Thus, we should not + // treat the listener socket as "broken" simply because an error is returned + // from the accept syscall. + // + // TODO: throw fatal errors to the upper layer. + let errno = Errno::from(-retval as u32); + log::error!("Accept error: errno = {}", errno); + + inner.backlog.entries[entry_idx] = Entry::Free; + inner.backlog.num_free += 1; + + // When canceling request, a poller might be waiting for this to return. 
+ inner.fatal = Some(errno); + stream.common.set_errno(errno); + stream.common.pollee().add_events(IoEvents::ERR); + + // After getting the error from the accept system call, we should not start + // the async accept requests again, because this may cause a large number of + // io-uring requests to be retried + return; + } + + let host_fd = retval as FileDesc; + inner.backlog.entries[entry_idx] = Entry::Completed { host_fd }; + inner.backlog.completed.push_back(entry_idx); + + stream.common.pollee().add_events(IoEvents::IN); + + stream.initiate_async_accepts(inner); + } + }; + let io_uring = stream.common.io_uring(); + let fd = stream.common.host_fd() as i32; + let flags = 0; + let io_handle = + unsafe { io_uring.accept(Fd(fd), c_addr_ptr, c_addr_len_ptr, flags, callback) }; + self.entries[entry_idx] = Entry::Pending { io_handle }; + self.num_free -= 1; + } + + pub fn has_completed_reqs(&self) -> bool { + self.completed.len() > 0 + } + + /// Pop a completed async accept request, turing a completed entry into a free one. 
+ pub fn pop_completed_req(&mut self) -> Option<(FileDesc, A)> { + let completed_idx = self.completed.pop_front()?; + let accepted_addr = { + let AcceptReq { c_addr, c_addr_len } = self.reqs[completed_idx].clone(); + A::from_c_storage(&c_addr, c_addr_len as _).unwrap() + }; + let accepted_fd = { + let entry = &mut self.entries[completed_idx]; + let accepted_fd = match entry { + Entry::Completed { host_fd } => *host_fd, + _ => unreachable!("the entry should have been completed"), + }; + self.num_free += 1; + *entry = Entry::Free; + accepted_fd + }; + Some((accepted_fd, accepted_addr)) + } +} + +impl std::fmt::Debug for Backlog { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Backlog") + .field("entries", &self.entries) + .field("completed", &self.completed) + .field("num_free", &self.num_free) + .finish() + } +} + +impl Drop for Backlog { + fn drop(&mut self) { + for entry in self.entries.iter() { + if let Entry::Completed { host_fd } = entry { + if let Err(e) = do_close(*host_fd) { + log::error!("close fd failed, host_fd: {}, err: {}", host_fd, e); + } + } + } + } +} diff --git a/src/libos/src/net/socket/uring/stream/states/mod.rs b/src/libos/src/net/socket/uring/stream/states/mod.rs new file mode 100644 index 00000000..96b6e065 --- /dev/null +++ b/src/libos/src/net/socket/uring/stream/states/mod.rs @@ -0,0 +1,9 @@ +mod connect; +mod connected; +mod init; +mod listen; + +pub use self::connect::ConnectingStream; +pub use self::connected::ConnectedStream; +pub use self::init::InitStream; +pub use self::listen::ListenerStream; diff --git a/src/libos/src/prelude.rs b/src/libos/src/prelude.rs index ee4044d5..653fdfe3 100644 --- a/src/libos/src/prelude.rs +++ b/src/libos/src/prelude.rs @@ -17,8 +17,11 @@ pub use std::sync::{ pub use crate::error::Result; pub use crate::error::*; pub use crate::fs::{File, FileDesc, FileRef}; +pub use crate::net::socket::util::Addr; +pub use crate::net::socket::{Domain, RecvFlags, SendFlags, 
Shutdown, Type}; pub use crate::process::{pid_t, uid_t}; pub use crate::util::sync::RwLock; +pub use crate::util::sync::{Mutex, MutexGuard}; macro_rules! debug_trace { () => { diff --git a/src/libos/src/syscall/mod.rs b/src/libos/src/syscall/mod.rs index cd3c2340..25cca5fe 100644 --- a/src/libos/src/syscall/mod.rs +++ b/src/libos/src/syscall/mod.rs @@ -44,8 +44,7 @@ use crate::net::{ do_accept, do_accept4, do_bind, do_connect, do_epoll_create, do_epoll_create1, do_epoll_ctl, do_epoll_pwait, do_epoll_wait, do_getpeername, do_getsockname, do_getsockopt, do_listen, do_poll, do_ppoll, do_pselect6, do_recvfrom, do_recvmsg, do_select, do_sendmmsg, do_sendmsg, - do_sendto, do_setsockopt, do_shutdown, do_socket, do_socketpair, mmsghdr, msghdr, msghdr_mut, - sigset_argpack, + do_sendto, do_setsockopt, do_shutdown, do_socket, do_socketpair, mmsghdr, sigset_argpack, }; use crate::process::{ do_arch_prctl, do_clone, do_execve, do_exit, do_exit_group, do_futex, do_get_robust_list, @@ -143,8 +142,8 @@ macro_rules! 
process_syscall_table_with_callback { (Accept = 43) => do_accept(fd: c_int, addr: *mut libc::sockaddr, addr_len: *mut libc::socklen_t), (Sendto = 44) => do_sendto(fd: c_int, base: *const c_void, len: size_t, flags: c_int, addr: *const libc::sockaddr, addr_len: libc::socklen_t), (Recvfrom = 45) => do_recvfrom(fd: c_int, base: *mut c_void, len: size_t, flags: c_int, addr: *mut libc::sockaddr, addr_len: *mut libc::socklen_t), - (Sendmsg = 46) => do_sendmsg(fd: c_int, msg_ptr: *const msghdr, flags_c: c_int), - (Recvmsg = 47) => do_recvmsg(fd: c_int, msg_mut_ptr: *mut msghdr_mut, flags_c: c_int), + (Sendmsg = 46) => do_sendmsg(fd: c_int, msg_ptr: *const libc::msghdr, flags_c: c_int), + (Recvmsg = 47) => do_recvmsg(fd: c_int, msg_mut_ptr: *mut libc::msghdr, flags_c: c_int), (Shutdown = 48) => do_shutdown(fd: c_int, how: c_int), (Bind = 49) => do_bind(fd: c_int, addr: *const libc::sockaddr, addr_len: libc::socklen_t), (Listen = 50) => do_listen(fd: c_int, backlog: c_int), diff --git a/src/pal/Makefile b/src/pal/Makefile index 5e33ba67..cff7e33c 100644 --- a/src/pal/Makefile +++ b/src/pal/Makefile @@ -54,6 +54,7 @@ LINK_FLAGS += -lsgx_quote_ex_sim endif endif +LINK_FLAGS += -L$(PROJECT_DIR)/deps/io-uring/ocalls/target/release/ -lsgx_io_uring_ocalls ALL_BUILD_SUBDIRS := $(sort $(patsubst %/,%,$(dir $(LIBOCCLUM_PAL_SO_REAL) $(EDL_C_OBJS) $(C_OBJS) $(CXX_OBJS) $(VDSO_OBJS)))) .PHONY: all format format-check clean @@ -79,7 +80,8 @@ $(OBJ_DIR)/pal/$(SRC_OBJ)/Enclave_u.c: $(SGX_EDGER8R) ../Enclave.edl $(SGX_EDGER8R) $(SGX_EDGER8R_MODE) --untrusted $(CUR_DIR)/../Enclave.edl \ --search-path $(SGX_SDK)/include \ --search-path $(RUST_SGX_SDK_DIR)/edl/ \ - --search-path $(CRATES_DIR)/vdso-time/ocalls + --search-path $(CRATES_DIR)/vdso-time/ocalls \ + --search-path $(PROJECT_DIR)/deps/io-uring/ocalls @echo "GEN <= $@" $(OBJ_DIR)/pal/$(SRC_OBJ)/%.o: src/%.c