[libos] Implement async network framework based on IO_Uring
This commit is contained in:
		
							parent
							
								
									9d4dcc2b21
								
							
						
					
					
						commit
						f8be7e7454
					
				
							
								
								
									
										3
									
								
								.gitmodules
									
									
									
									
										vendored
									
									
								
							
							
								
								
								
								
								
									
									
								
							
						
						
									
										3
									
								
								.gitmodules
									
									
									
									
										vendored
									
									
								
							| @ -24,3 +24,6 @@ | |||||||
| [submodule "deps/resolv-conf"] | [submodule "deps/resolv-conf"] | ||||||
| 	path = deps/resolv-conf | 	path = deps/resolv-conf | ||||||
| 	url = https://github.com/tailhook/resolv-conf.git | 	url = https://github.com/tailhook/resolv-conf.git | ||||||
|  | [submodule "deps/io-uring"] | ||||||
|  | 	path = deps/io-uring | ||||||
|  | 	url = https://github.com/occlum/io-uring.git | ||||||
|  | |||||||
							
								
								
									
										2
									
								
								Makefile
									
									
									
									
									
								
							
							
								
								
								
								
								
									
									
								
							
						
						
									
										2
									
								
								Makefile
									
									
									
									
									
								
							| @ -42,6 +42,7 @@ submodule: githooks init-submodule | |||||||
| 	@cp deps/sefs/sefs-cli/lib/libsefs-cli_sim.so build/lib | 	@cp deps/sefs/sefs-cli/lib/libsefs-cli_sim.so build/lib | ||||||
| 	@cp deps/sefs/sefs-cli/lib/libsefs-cli.signed.so build/lib | 	@cp deps/sefs/sefs-cli/lib/libsefs-cli.signed.so build/lib | ||||||
| 	@cp deps/sefs/sefs-cli/enclave/Enclave.config.xml build/sefs-cli.Enclave.xml | 	@cp deps/sefs/sefs-cli/enclave/Enclave.config.xml build/sefs-cli.Enclave.xml | ||||||
|  | 	@cd deps/io-uring/ocalls && cargo clean && cargo build --release | ||||||
| else | else | ||||||
| submodule: githooks init-submodule | submodule: githooks init-submodule | ||||||
| 	@rm -rf build | 	@rm -rf build | ||||||
| @ -60,6 +61,7 @@ submodule: githooks init-submodule | |||||||
| 	@cp deps/sefs/sefs-cli/lib/libsefs-cli_sim.so build/lib | 	@cp deps/sefs/sefs-cli/lib/libsefs-cli_sim.so build/lib | ||||||
| 	@cp deps/sefs/sefs-cli/lib/libsefs-cli.signed.so build/lib | 	@cp deps/sefs/sefs-cli/lib/libsefs-cli.signed.so build/lib | ||||||
| 	@cp deps/sefs/sefs-cli/enclave/Enclave.config.xml build/sefs-cli.Enclave.xml | 	@cp deps/sefs/sefs-cli/enclave/Enclave.config.xml build/sefs-cli.Enclave.xml | ||||||
|  | 	@cd deps/io-uring/ocalls && cargo clean && cargo build --release | ||||||
| endif | endif | ||||||
| 
 | 
 | ||||||
| init-submodule: | init-submodule: | ||||||
|  | |||||||
							
								
								
									
										1
									
								
								deps/io-uring
									
									
									
									
										vendored
									
									
										Submodule
									
								
							
							
								
								
								
								
								
									
									
								
							
						
						
									
										1
									
								
								deps/io-uring
									
									
									
									
										vendored
									
									
										Submodule
									
								
							| @ -0,0 +1 @@ | |||||||
|  | Subproject commit c654c4925bb0b013d3eec736015f8ac4888722be | ||||||
| @ -7,6 +7,8 @@ enclave { | |||||||
|     from "sgx_net.edl" import *; |     from "sgx_net.edl" import *; | ||||||
|     from "sgx_occlum_utils.edl" import *; |     from "sgx_occlum_utils.edl" import *; | ||||||
|     from "sgx_vdso_time_ocalls.edl" import *; |     from "sgx_vdso_time_ocalls.edl" import *; | ||||||
|  |     from "sgx_thread.edl" import *; | ||||||
|  |     from "sgx_io_uring_ocalls.edl" import *; | ||||||
| 
 | 
 | ||||||
|     include "sgx_quote.h" |     include "sgx_quote.h" | ||||||
|     include "occlum_edl_types.h" |     include "occlum_edl_types.h" | ||||||
|  | |||||||
							
								
								
									
										186
									
								
								src/libos/Cargo.lock
									
									
									
										generated
									
									
									
								
							
							
								
								
								
								
								
									
									
								
							
						
						
									
										186
									
								
								src/libos/Cargo.lock
									
									
									
										generated
									
									
									
								
							| @ -10,16 +10,21 @@ dependencies = [ | |||||||
|  "atomic", |  "atomic", | ||||||
|  "bitflags", |  "bitflags", | ||||||
|  "bitvec 1.0.1", |  "bitvec 1.0.1", | ||||||
|  |  "byteorder", | ||||||
|  "ctor", |  "ctor", | ||||||
|  "derive_builder", |  "derive_builder", | ||||||
|  |  "downcast-rs", | ||||||
|  "errno", |  "errno", | ||||||
|  "goblin", |  "goblin", | ||||||
|  "intrusive-collections", |  "intrusive-collections", | ||||||
|  |  "io-uring-callback", | ||||||
|  "itertools", |  "itertools", | ||||||
|  |  "keyable-arc", | ||||||
|  "lazy_static", |  "lazy_static", | ||||||
|  "log", |  "log", | ||||||
|  "memoffset 0.6.5", |  "memoffset 0.6.5", | ||||||
|  "modular-bitfield", |  "modular-bitfield", | ||||||
|  |  "num_enum", | ||||||
|  "rcore-fs", |  "rcore-fs", | ||||||
|  "rcore-fs-devfs", |  "rcore-fs-devfs", | ||||||
|  "rcore-fs-mountfs", |  "rcore-fs-mountfs", | ||||||
| @ -32,6 +37,7 @@ dependencies = [ | |||||||
|  "scroll", |  "scroll", | ||||||
|  "serde", |  "serde", | ||||||
|  "serde_json", |  "serde_json", | ||||||
|  |  "sgx-untrusted-alloc", | ||||||
|  "sgx_cov", |  "sgx_cov", | ||||||
|  "sgx_tcrypto", |  "sgx_tcrypto", | ||||||
|  "sgx_trts", |  "sgx_trts", | ||||||
| @ -112,6 +118,12 @@ dependencies = [ | |||||||
|  "wyz", |  "wyz", | ||||||
| ] | ] | ||||||
| 
 | 
 | ||||||
|  | [[package]] | ||||||
|  | name = "byteorder" | ||||||
|  | version = "1.5.0" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" | ||||||
|  | 
 | ||||||
| [[package]] | [[package]] | ||||||
| name = "cc" | name = "cc" | ||||||
| version = "1.0.73" | version = "1.0.73" | ||||||
| @ -203,6 +215,12 @@ dependencies = [ | |||||||
|  "syn", |  "syn", | ||||||
| ] | ] | ||||||
| 
 | 
 | ||||||
|  | [[package]] | ||||||
|  | name = "downcast-rs" | ||||||
|  | version = "1.2.0" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "9ea835d29036a4087793836fa931b08837ad5e957da9e23886b29586fb9b6650" | ||||||
|  | 
 | ||||||
| [[package]] | [[package]] | ||||||
| name = "either" | name = "either" | ||||||
| version = "1.8.0" | version = "1.8.0" | ||||||
| @ -237,6 +255,67 @@ version = "2.0.0" | |||||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
| checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" | checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" | ||||||
| 
 | 
 | ||||||
|  | [[package]] | ||||||
|  | name = "futures" | ||||||
|  | version = "0.3.28" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "23342abe12aba583913b2e62f22225ff9c950774065e4bfb61a19cd9770fec40" | ||||||
|  | dependencies = [ | ||||||
|  |  "futures-channel", | ||||||
|  |  "futures-core", | ||||||
|  |  "futures-io", | ||||||
|  |  "futures-sink", | ||||||
|  |  "futures-task", | ||||||
|  |  "futures-util", | ||||||
|  | ] | ||||||
|  | 
 | ||||||
|  | [[package]] | ||||||
|  | name = "futures-channel" | ||||||
|  | version = "0.3.28" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "955518d47e09b25bbebc7a18df10b81f0c766eaf4c4f1cccef2fca5f2a4fb5f2" | ||||||
|  | dependencies = [ | ||||||
|  |  "futures-core", | ||||||
|  |  "futures-sink", | ||||||
|  | ] | ||||||
|  | 
 | ||||||
|  | [[package]] | ||||||
|  | name = "futures-core" | ||||||
|  | version = "0.3.28" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "4bca583b7e26f571124fe5b7561d49cb2868d79116cfa0eefce955557c6fee8c" | ||||||
|  | 
 | ||||||
|  | [[package]] | ||||||
|  | name = "futures-io" | ||||||
|  | version = "0.3.28" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "4fff74096e71ed47f8e023204cfd0aa1289cd54ae5430a9523be060cdb849964" | ||||||
|  | 
 | ||||||
|  | [[package]] | ||||||
|  | name = "futures-sink" | ||||||
|  | version = "0.3.28" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "f43be4fe21a13b9781a69afa4985b0f6ee0e1afab2c6f454a8cf30e2b2237b6e" | ||||||
|  | 
 | ||||||
|  | [[package]] | ||||||
|  | name = "futures-task" | ||||||
|  | version = "0.3.28" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "76d3d132be6c0e6aa1534069c705a74a5997a356c0dc2f86a47765e5617c5b65" | ||||||
|  | 
 | ||||||
|  | [[package]] | ||||||
|  | name = "futures-util" | ||||||
|  | version = "0.3.28" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "26b01e40b772d54cf6c6d721c1d1abd0647a0106a12ecaa1c186273392a69533" | ||||||
|  | dependencies = [ | ||||||
|  |  "futures-core", | ||||||
|  |  "futures-sink", | ||||||
|  |  "futures-task", | ||||||
|  |  "pin-project-lite", | ||||||
|  |  "pin-utils", | ||||||
|  | ] | ||||||
|  | 
 | ||||||
| [[package]] | [[package]] | ||||||
| name = "goblin" | name = "goblin" | ||||||
| version = "0.5.4" | version = "0.5.4" | ||||||
| @ -267,6 +346,36 @@ dependencies = [ | |||||||
|  "memoffset 0.5.6", |  "memoffset 0.5.6", | ||||||
| ] | ] | ||||||
| 
 | 
 | ||||||
|  | [[package]] | ||||||
|  | name = "io-uring" | ||||||
|  | version = "0.5.9" | ||||||
|  | dependencies = [ | ||||||
|  |  "bitflags", | ||||||
|  |  "libc", | ||||||
|  |  "sgx_libc", | ||||||
|  |  "sgx_trts", | ||||||
|  |  "sgx_tstd", | ||||||
|  |  "sgx_types", | ||||||
|  | ] | ||||||
|  | 
 | ||||||
|  | [[package]] | ||||||
|  | name = "io-uring-callback" | ||||||
|  | version = "0.1.0" | ||||||
|  | dependencies = [ | ||||||
|  |  "atomic", | ||||||
|  |  "cfg-if", | ||||||
|  |  "futures", | ||||||
|  |  "io-uring", | ||||||
|  |  "lazy_static", | ||||||
|  |  "libc", | ||||||
|  |  "lock_api", | ||||||
|  |  "log", | ||||||
|  |  "sgx_libc", | ||||||
|  |  "sgx_tstd", | ||||||
|  |  "slab", | ||||||
|  |  "spin 0.7.1", | ||||||
|  | ] | ||||||
|  | 
 | ||||||
| [[package]] | [[package]] | ||||||
| name = "itertools" | name = "itertools" | ||||||
| version = "0.10.3" | version = "0.10.3" | ||||||
| @ -283,6 +392,10 @@ dependencies = [ | |||||||
|  "sgx_tstd", |  "sgx_tstd", | ||||||
| ] | ] | ||||||
| 
 | 
 | ||||||
|  | [[package]] | ||||||
|  | name = "keyable-arc" | ||||||
|  | version = "0.1.0" | ||||||
|  | 
 | ||||||
| [[package]] | [[package]] | ||||||
| name = "lazy_static" | name = "lazy_static" | ||||||
| version = "1.4.0" | version = "1.4.0" | ||||||
| @ -298,6 +411,15 @@ version = "0.2.132" | |||||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
| checksum = "8371e4e5341c3a96db127eb2465ac681ced4c433e01dd0e938adbef26ba93ba5" | checksum = "8371e4e5341c3a96db127eb2465ac681ced4c433e01dd0e938adbef26ba93ba5" | ||||||
| 
 | 
 | ||||||
|  | [[package]] | ||||||
|  | name = "lock_api" | ||||||
|  | version = "0.4.2" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "dd96ffd135b2fd7b973ac026d28085defbe8983df057ced3eb4f2130b0831312" | ||||||
|  | dependencies = [ | ||||||
|  |  "scopeguard", | ||||||
|  | ] | ||||||
|  | 
 | ||||||
| [[package]] | [[package]] | ||||||
| name = "log" | name = "log" | ||||||
| version = "0.4.17" | version = "0.4.17" | ||||||
| @ -346,6 +468,38 @@ dependencies = [ | |||||||
|  "syn", |  "syn", | ||||||
| ] | ] | ||||||
| 
 | 
 | ||||||
|  | [[package]] | ||||||
|  | name = "num_enum" | ||||||
|  | version = "0.5.11" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "1f646caf906c20226733ed5b1374287eb97e3c2a5c227ce668c1f2ce20ae57c9" | ||||||
|  | dependencies = [ | ||||||
|  |  "num_enum_derive", | ||||||
|  | ] | ||||||
|  | 
 | ||||||
|  | [[package]] | ||||||
|  | name = "num_enum_derive" | ||||||
|  | version = "0.5.11" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "dcbff9bc912032c62bf65ef1d5aea88983b420f4f839db1e9b0c281a25c9c799" | ||||||
|  | dependencies = [ | ||||||
|  |  "proc-macro2", | ||||||
|  |  "quote", | ||||||
|  |  "syn", | ||||||
|  | ] | ||||||
|  | 
 | ||||||
|  | [[package]] | ||||||
|  | name = "pin-project-lite" | ||||||
|  | version = "0.2.13" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58" | ||||||
|  | 
 | ||||||
|  | [[package]] | ||||||
|  | name = "pin-utils" | ||||||
|  | version = "0.1.0" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" | ||||||
|  | 
 | ||||||
| [[package]] | [[package]] | ||||||
| name = "plain" | name = "plain" | ||||||
| version = "0.2.3" | version = "0.2.3" | ||||||
| @ -601,6 +755,12 @@ version = "1.0.11" | |||||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
| checksum = "4501abdff3ae82a1c1b477a17252eb69cee9e66eb915c1abaa4f44d873df9f09" | checksum = "4501abdff3ae82a1c1b477a17252eb69cee9e66eb915c1abaa4f44d873df9f09" | ||||||
| 
 | 
 | ||||||
|  | [[package]] | ||||||
|  | name = "scopeguard" | ||||||
|  | version = "1.2.0" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" | ||||||
|  | 
 | ||||||
| [[package]] | [[package]] | ||||||
| name = "scroll" | name = "scroll" | ||||||
| version = "0.11.0" | version = "0.11.0" | ||||||
| @ -648,6 +808,23 @@ dependencies = [ | |||||||
|  "sgx_tstd", |  "sgx_tstd", | ||||||
| ] | ] | ||||||
| 
 | 
 | ||||||
|  | [[package]] | ||||||
|  | name = "sgx-untrusted-alloc" | ||||||
|  | version = "0.1.0" | ||||||
|  | dependencies = [ | ||||||
|  |  "cfg-if", | ||||||
|  |  "errno", | ||||||
|  |  "intrusive-collections", | ||||||
|  |  "lazy_static", | ||||||
|  |  "libc", | ||||||
|  |  "log", | ||||||
|  |  "sgx_libc", | ||||||
|  |  "sgx_trts", | ||||||
|  |  "sgx_tstd", | ||||||
|  |  "sgx_types", | ||||||
|  |  "spin 0.7.1", | ||||||
|  | ] | ||||||
|  | 
 | ||||||
| [[package]] | [[package]] | ||||||
| name = "sgx_alloc" | name = "sgx_alloc" | ||||||
| version = "1.1.6" | version = "1.1.6" | ||||||
| @ -753,6 +930,15 @@ dependencies = [ | |||||||
|  "sgx_build_helper", |  "sgx_build_helper", | ||||||
| ] | ] | ||||||
| 
 | 
 | ||||||
|  | [[package]] | ||||||
|  | name = "slab" | ||||||
|  | version = "0.4.9" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "8f92a496fb766b417c996b9c5e57daf2f7ad3b0bebe1ccfca4856390e3d3bb67" | ||||||
|  | dependencies = [ | ||||||
|  |  "autocfg 1.1.0", | ||||||
|  | ] | ||||||
|  | 
 | ||||||
| [[package]] | [[package]] | ||||||
| name = "spin" | name = "spin" | ||||||
| version = "0.5.2" | version = "0.5.2" | ||||||
|  | |||||||
| @ -7,6 +7,7 @@ use super::*; | |||||||
| use crate::exception::*; | use crate::exception::*; | ||||||
| use crate::fs::HostStdioFds; | use crate::fs::HostStdioFds; | ||||||
| use crate::interrupt; | use crate::interrupt; | ||||||
|  | use crate::io_uring::ENABLE_URING; | ||||||
| use crate::process::idle_reap_zombie_children; | use crate::process::idle_reap_zombie_children; | ||||||
| use crate::process::{ProcessFilter, SpawnAttr}; | use crate::process::{ProcessFilter, SpawnAttr}; | ||||||
| use crate::signal::SigNum; | use crate::signal::SigNum; | ||||||
| @ -101,11 +102,14 @@ pub extern "C" fn occlum_ecall_init( | |||||||
| 
 | 
 | ||||||
|         vm::init_user_space(); |         vm::init_user_space(); | ||||||
| 
 | 
 | ||||||
|  |         if ENABLE_URING.load(Ordering::Relaxed) { | ||||||
|  |             crate::io_uring::MULTITON.poll_completions(); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|         // Register exception handlers (support cpuid & rdtsc for now)
 |         // Register exception handlers (support cpuid & rdtsc for now)
 | ||||||
|         register_exception_handlers(); |         register_exception_handlers(); | ||||||
| 
 | 
 | ||||||
|         HAS_INIT.store(true, Ordering::Release); |         HAS_INIT.store(true, Ordering::Release); | ||||||
| 
 |  | ||||||
|         // Enable global backtrace
 |         // Enable global backtrace
 | ||||||
|         unsafe { backtrace::enable_backtrace(&ENCLAVE_PATH, PrintFormat::Short) }; |         unsafe { backtrace::enable_backtrace(&ENCLAVE_PATH, PrintFormat::Short) }; | ||||||
| 
 | 
 | ||||||
|  | |||||||
							
								
								
									
										169
									
								
								src/libos/src/io_uring.rs
									
									
									
									
									
										Normal file
									
								
							
							
								
								
								
								
								
									
									
								
							
						
						
									
										169
									
								
								src/libos/src/io_uring.rs
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,169 @@ | |||||||
|  | use core::sync::atomic::{AtomicBool, AtomicU32, AtomicUsize}; | ||||||
|  | use std::{collections::HashMap, thread::current}; | ||||||
|  | 
 | ||||||
|  | use crate::util::sync::Mutex; | ||||||
|  | use alloc::{sync::Arc, vec::Vec}; | ||||||
|  | use atomic::Ordering; | ||||||
|  | use io_uring_callback::{Builder, IoUring}; | ||||||
|  | use keyable_arc::KeyableArc; | ||||||
|  | 
 | ||||||
|  | use crate::config::LIBOS_CONFIG; | ||||||
|  | 
 | ||||||
|  | // The number of sockets to reach the network bandwidth threshold of one io_uring instance
 | ||||||
|  | const SOCKET_THRESHOLD_PER_URING: u32 = 1; | ||||||
|  | 
 | ||||||
|  | lazy_static::lazy_static! { | ||||||
|  |     pub static ref MULTITON: UringSet = { | ||||||
|  |         let uring_set = UringSet::new(); | ||||||
|  |         uring_set | ||||||
|  |     }; | ||||||
|  | 
 | ||||||
|  |     pub static ref ENABLE_URING: AtomicBool = AtomicBool::new(LIBOS_CONFIG.feature.io_uring > 0); | ||||||
|  | 
 | ||||||
|  |     // Four uring instances are sufficient to reach the network bandwidth threshold of host kernel.
 | ||||||
|  |     pub static ref URING_LIMIT: AtomicUsize = { | ||||||
|  |         let uring_limit = LIBOS_CONFIG.feature.io_uring; | ||||||
|  |         assert!(uring_limit <= 16, "io_uring limit must not exceed 16"); | ||||||
|  |         AtomicUsize::new(uring_limit as usize) | ||||||
|  |     }; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | #[derive(Clone, Copy, Default)] | ||||||
|  | struct UringState { | ||||||
|  |     registered_num: u32, | ||||||
|  |     is_enable_poll: bool, // CQE polling thread
 | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl UringState { | ||||||
|  |     fn register_one_socket(&mut self) { | ||||||
|  |         self.registered_num += 1; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn unregister_one_socket(&mut self) { | ||||||
|  |         self.registered_num -= 1; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn enable_poll(&mut self, uring: Arc<IoUring>) { | ||||||
|  |         if !self.is_enable_poll { | ||||||
|  |             self.is_enable_poll = true; | ||||||
|  |             std::thread::spawn(move || loop { | ||||||
|  |                 let min_complete = 1; | ||||||
|  |                 let polling_retries = 10000; | ||||||
|  |                 uring.poll_completions(min_complete, polling_retries); | ||||||
|  |             }); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | pub struct UringSet { | ||||||
|  |     urings: Mutex<HashMap<KeyableArc<IoUring>, UringState>>, | ||||||
|  |     running_uring_num: AtomicU32, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl UringSet { | ||||||
|  |     pub fn new() -> Self { | ||||||
|  |         let urings = Mutex::new(HashMap::new()); | ||||||
|  |         let running_uring_num = AtomicU32::new(0); | ||||||
|  |         Self { | ||||||
|  |             urings, | ||||||
|  |             running_uring_num, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn poll_completions(&self) { | ||||||
|  |         let mut guard = self.urings.lock(); | ||||||
|  |         let uring_limit = URING_LIMIT.load(Ordering::Relaxed) as u32; | ||||||
|  | 
 | ||||||
|  |         for _ in 0..uring_limit { | ||||||
|  |             let uring: KeyableArc<IoUring> = Arc::new( | ||||||
|  |                 Builder::new() | ||||||
|  |                     .setup_sqpoll(500 /* ms */) | ||||||
|  |                     .build(256) | ||||||
|  |                     .unwrap(), | ||||||
|  |             ) | ||||||
|  |             .into(); | ||||||
|  |             let mut state = UringState::default(); | ||||||
|  |             state.enable_poll(uring.clone().into()); | ||||||
|  | 
 | ||||||
|  |             guard.insert(uring.clone(), state); | ||||||
|  |             self.running_uring_num.fetch_add(1, Ordering::Relaxed); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn get_uring(&self) -> Arc<IoUring> { | ||||||
|  |         let mut map = self.urings.lock(); | ||||||
|  |         let running_uring_num = self.running_uring_num.load(Ordering::Relaxed); | ||||||
|  |         let uring_limit = URING_LIMIT.load(Ordering::Relaxed) as u32; | ||||||
|  |         assert!(running_uring_num <= uring_limit); | ||||||
|  | 
 | ||||||
|  |         let init_stage = running_uring_num < uring_limit; | ||||||
|  | 
 | ||||||
|  |         // Construct an io_uring instance and initiate a polling thread
 | ||||||
|  |         if init_stage { | ||||||
|  |             let should_build_uring = { | ||||||
|  |                 // Sum registered socket
 | ||||||
|  |                 let total_socket_num = map | ||||||
|  |                     .values() | ||||||
|  |                     .fold(0, |acc, state| acc + state.registered_num) | ||||||
|  |                     + 1; | ||||||
|  |                 // Determine the number of available io_uring
 | ||||||
|  |                 let uring_num = (total_socket_num / SOCKET_THRESHOLD_PER_URING) + 1; | ||||||
|  |                 let existed_uring_num = self.running_uring_num.load(Ordering::Relaxed); | ||||||
|  |                 assert!(existed_uring_num <= uring_num); | ||||||
|  |                 existed_uring_num < uring_num | ||||||
|  |             }; | ||||||
|  | 
 | ||||||
|  |             if should_build_uring { | ||||||
|  |                 let uring: KeyableArc<IoUring> = Arc::new( | ||||||
|  |                     Builder::new() | ||||||
|  |                         .setup_sqpoll(500 /* ms */) | ||||||
|  |                         .build(256) | ||||||
|  |                         .unwrap(), | ||||||
|  |                 ) | ||||||
|  |                 .into(); | ||||||
|  |                 let mut state = UringState::default(); | ||||||
|  |                 state.register_one_socket(); | ||||||
|  |                 state.enable_poll(uring.clone().into()); | ||||||
|  | 
 | ||||||
|  |                 map.insert(uring.clone(), state); | ||||||
|  |                 self.running_uring_num.fetch_add(1, Ordering::Relaxed); | ||||||
|  |                 return uring.into(); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         // Link the file to the io_uring instance with the least load.
 | ||||||
|  |         let (mut uring, mut state) = map | ||||||
|  |             .iter_mut() | ||||||
|  |             .min_by_key(|(_, &mut state)| state.registered_num) | ||||||
|  |             .unwrap(); | ||||||
|  | 
 | ||||||
|  |         // Re-select io_uring instance with least task load
 | ||||||
|  |         if !init_stage { | ||||||
|  |             let min_registered_num = state.registered_num; | ||||||
|  |             (uring, state) = map | ||||||
|  |                 .iter_mut() | ||||||
|  |                 .filter(|(_, state)| state.registered_num == min_registered_num) | ||||||
|  |                 .min_by_key(|(uring, _)| uring.task_load()) | ||||||
|  |                 .unwrap(); | ||||||
|  |         } else { | ||||||
|  |             // At the initial stage, without constructing additional io_uring instances,
 | ||||||
|  |             // there exists a singular io_uring which has the minimum number of registered sockets.
 | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         // Update io_uring instance states
 | ||||||
|  |         state.register_one_socket(); | ||||||
|  |         assert!(state.is_enable_poll); | ||||||
|  | 
 | ||||||
|  |         uring.clone().into() | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn disattach_uring(&self, fd: usize, uring: Arc<IoUring>) { | ||||||
|  |         let uring: KeyableArc<IoUring> = uring.into(); | ||||||
|  |         let mut map = self.urings.lock(); | ||||||
|  |         let mut state = map.get_mut(&uring).unwrap(); | ||||||
|  |         state.unregister_one_socket(); | ||||||
|  |         drop(map); | ||||||
|  | 
 | ||||||
|  |         uring.disattach_fd(fd); | ||||||
|  |     } | ||||||
|  | } | ||||||
| @ -28,6 +28,8 @@ | |||||||
| #![feature(is_some_and)] | #![feature(is_some_and)] | ||||||
| // for edmm_api macro
 | // for edmm_api macro
 | ||||||
| #![feature(linkage)] | #![feature(linkage)] | ||||||
|  | #![feature(new_uninit)] | ||||||
|  | #![feature(raw_ref_op)] | ||||||
| 
 | 
 | ||||||
| #[macro_use] | #[macro_use] | ||||||
| extern crate alloc; | extern crate alloc; | ||||||
| @ -66,7 +68,6 @@ extern crate intrusive_collections; | |||||||
| extern crate itertools; | extern crate itertools; | ||||||
| extern crate modular_bitfield; | extern crate modular_bitfield; | ||||||
| extern crate resolv_conf; | extern crate resolv_conf; | ||||||
| extern crate vdso_time; |  | ||||||
| 
 | 
 | ||||||
| use sgx_trts::libc; | use sgx_trts::libc; | ||||||
| use sgx_types::*; | use sgx_types::*; | ||||||
| @ -82,15 +83,18 @@ mod prelude; | |||||||
| #[macro_use] | #[macro_use] | ||||||
| mod error; | mod error; | ||||||
| 
 | 
 | ||||||
|  | #[macro_use] | ||||||
|  | mod net; | ||||||
|  | 
 | ||||||
| mod config; | mod config; | ||||||
| mod entry; | mod entry; | ||||||
| mod events; | mod events; | ||||||
| mod exception; | mod exception; | ||||||
| mod fs; | mod fs; | ||||||
| mod interrupt; | mod interrupt; | ||||||
|  | mod io_uring; | ||||||
| mod ipc; | mod ipc; | ||||||
| mod misc; | mod misc; | ||||||
| mod net; |  | ||||||
| mod process; | mod process; | ||||||
| mod sched; | mod sched; | ||||||
| mod signal; | mod signal; | ||||||
|  | |||||||
| @ -7,12 +7,13 @@ pub use self::io_multiplexing::{ | |||||||
|     PollEventFlags, PollFd, THREAD_NOTIFIERS, |     PollEventFlags, PollFd, THREAD_NOTIFIERS, | ||||||
| }; | }; | ||||||
| pub use self::socket::{ | pub use self::socket::{ | ||||||
|     mmsghdr, msghdr, msghdr_mut, socketpair, unix_socket, AddressFamily, AsUnixSocket, FileFlags, |     socketpair, unix_socket, AsUnixSocket, Domain, HostSocket, HostSocketType, Iovs, IovsMut, | ||||||
|     HostSocket, HostSocketType, HowToShut, Iovs, IovsMut, MsgHdr, MsgHdrFlags, MsgHdrMut, |     RawAddr, SliceAsLibcIovec, UnixAddr, | ||||||
|     RecvFlags, SendFlags, SliceAsLibcIovec, SockAddr, SocketType, UnixAddr, |  | ||||||
| }; | }; | ||||||
| pub use self::syscalls::*; | pub use self::syscalls::*; | ||||||
| 
 | 
 | ||||||
| mod io_multiplexing; | mod io_multiplexing; | ||||||
| mod socket; | pub(crate) mod socket; | ||||||
| mod syscalls; | mod syscalls; | ||||||
|  | 
 | ||||||
|  | pub use self::syscalls::*; | ||||||
|  | |||||||
| @ -1,21 +1,15 @@ | |||||||
| use super::*; | use super::*; | ||||||
| 
 | 
 | ||||||
| mod address_family; |  | ||||||
| mod flags; |  | ||||||
| mod host; | mod host; | ||||||
| mod iovs; | pub(crate) mod sockopt; | ||||||
| mod msg; |  | ||||||
| mod shutdown; |  | ||||||
| mod socket_address; |  | ||||||
| mod socket_type; |  | ||||||
| mod unix; | mod unix; | ||||||
|  | pub(crate) mod uring; | ||||||
|  | pub(crate) mod util; | ||||||
| 
 | 
 | ||||||
| pub use self::address_family::AddressFamily; |  | ||||||
| pub use self::flags::{FileFlags, MsgHdrFlags, RecvFlags, SendFlags}; |  | ||||||
| pub use self::host::{HostSocket, HostSocketType}; | pub use self::host::{HostSocket, HostSocketType}; | ||||||
| pub use self::iovs::{Iovs, IovsMut, SliceAsLibcIovec}; | pub use self::unix::{socketpair, unix_socket, AsUnixSocket}; | ||||||
| pub use self::msg::{mmsghdr, msghdr, msghdr_mut, CMessages, CmsgData, MsgHdr, MsgHdrMut}; | pub use self::util::{ | ||||||
| pub use self::shutdown::HowToShut; |     Addr, AnyAddr, CMessages, CSockAddr, CmsgData, Domain, Iovs, IovsMut, Ipv4Addr, Ipv4SocketAddr, | ||||||
| pub use self::socket_address::SockAddr; |     Ipv6SocketAddr, MsgFlags, RawAddr, RecvFlags, SendFlags, Shutdown, SliceAsLibcIovec, | ||||||
| pub use self::socket_type::SocketType; |     SocketProtocol, Type, UnixAddr, | ||||||
| pub use self::unix::{socketpair, unix_socket, AsUnixSocket, UnixAddr}; | }; | ||||||
|  | |||||||
							
								
								
									
										241
									
								
								src/libos/src/net/socket/uring/common/common.rs
									
									
									
									
									
										Normal file
									
								
							
							
								
								
								
								
								
									
									
								
							
						
						
									
										241
									
								
								src/libos/src/net/socket/uring/common/common.rs
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,241 @@ | |||||||
|  | use core::time::Duration; | ||||||
|  | use std::marker::PhantomData; | ||||||
|  | use std::sync::atomic::{AtomicBool, Ordering}; | ||||||
|  | 
 | ||||||
|  | use super::Timeout; | ||||||
|  | use io_uring_callback::IoUring; | ||||||
|  | 
 | ||||||
|  | use libc::ocall::getsockname as do_getsockname; | ||||||
|  | use libc::ocall::shutdown as do_shutdown; | ||||||
|  | use libc::ocall::socket as do_socket; | ||||||
|  | use libc::ocall::socketpair as do_socketpair; | ||||||
|  | 
 | ||||||
|  | use crate::events::Pollee; | ||||||
|  | use crate::fs::{IoEvents, IoNotifier}; | ||||||
|  | use crate::net::socket::uring::runtime::Runtime; | ||||||
|  | use crate::prelude::*; | ||||||
|  | 
 | ||||||
|  | /// The common parts of all stream sockets.
 | ||||||
|  | pub struct Common<A: Addr + 'static, R: Runtime> { | ||||||
|  |     host_fd: FileDesc, | ||||||
|  |     type_: Type, | ||||||
|  |     nonblocking: AtomicBool, | ||||||
|  |     is_closed: AtomicBool, | ||||||
|  |     pollee: Pollee, | ||||||
|  |     inner: Mutex<Inner<A>>, | ||||||
|  |     timeout: Mutex<Timeout>, | ||||||
|  |     errno: Mutex<Option<Errno>>, | ||||||
|  |     io_uring: Arc<IoUring>, | ||||||
|  |     phantom_data: PhantomData<(A, R)>, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl<A: Addr + 'static, R: Runtime> Common<A, R> { | ||||||
|  |     pub fn new(type_: Type, nonblocking: bool, protocol: Option<i32>) -> Result<Self> { | ||||||
|  |         let domain_c = A::domain() as libc::c_int; | ||||||
|  |         let type_c = type_ as libc::c_int; | ||||||
|  |         let protocol = protocol.unwrap_or(0) as libc::c_int; | ||||||
|  |         let host_fd = try_libc!(do_socket(domain_c, type_c, protocol)) as FileDesc; | ||||||
|  |         let nonblocking = AtomicBool::new(nonblocking); | ||||||
|  |         let is_closed = AtomicBool::new(false); | ||||||
|  |         let pollee = Pollee::new(IoEvents::empty()); | ||||||
|  |         let inner = Mutex::new(Inner::new()); | ||||||
|  |         let timeout = Mutex::new(Timeout::new()); | ||||||
|  |         let io_uring = R::io_uring(); | ||||||
|  |         let errno = Mutex::new(None); | ||||||
|  |         Ok(Self { | ||||||
|  |             host_fd, | ||||||
|  |             type_, | ||||||
|  |             nonblocking, | ||||||
|  |             is_closed, | ||||||
|  |             pollee, | ||||||
|  |             inner, | ||||||
|  |             timeout, | ||||||
|  |             errno, | ||||||
|  |             io_uring, | ||||||
|  |             phantom_data: PhantomData, | ||||||
|  |         }) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn new_pair(sock_type: Type, nonblocking: bool) -> Result<(Self, Self)> { | ||||||
|  |         return_errno!(EINVAL, "Unix is unsupported"); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn with_host_fd(host_fd: FileDesc, type_: Type, nonblocking: bool) -> Self { | ||||||
|  |         let nonblocking = AtomicBool::new(nonblocking); | ||||||
|  |         let is_closed = AtomicBool::new(false); | ||||||
|  |         let pollee = Pollee::new(IoEvents::empty()); | ||||||
|  |         let inner = Mutex::new(Inner::new()); | ||||||
|  |         let timeout = Mutex::new(Timeout::new()); | ||||||
|  |         let io_uring = R::io_uring(); | ||||||
|  |         let errno = Mutex::new(None); | ||||||
|  |         Self { | ||||||
|  |             host_fd, | ||||||
|  |             type_, | ||||||
|  |             nonblocking, | ||||||
|  |             is_closed, | ||||||
|  |             pollee, | ||||||
|  |             inner, | ||||||
|  |             timeout, | ||||||
|  |             errno, | ||||||
|  |             io_uring, | ||||||
|  |             phantom_data: PhantomData, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn io_uring(&self) -> Arc<IoUring> { | ||||||
|  |         self.io_uring.clone() | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn host_fd(&self) -> FileDesc { | ||||||
|  |         self.host_fd | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn type_(&self) -> Type { | ||||||
|  |         self.type_ | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn nonblocking(&self) -> bool { | ||||||
|  |         self.nonblocking.load(Ordering::Relaxed) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn set_nonblocking(&self, is_nonblocking: bool) { | ||||||
|  |         self.nonblocking.store(is_nonblocking, Ordering::Relaxed) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn notifier(&self) -> &IoNotifier { | ||||||
|  |         self.pollee.notifier() | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn send_timeout(&self) -> Option<Duration> { | ||||||
|  |         self.timeout.lock().sender_timeout() | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn recv_timeout(&self) -> Option<Duration> { | ||||||
|  |         self.timeout.lock().receiver_timeout() | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn set_send_timeout(&self, timeout: Duration) { | ||||||
|  |         self.timeout.lock().set_sender(timeout) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn set_recv_timeout(&self, timeout: Duration) { | ||||||
|  |         self.timeout.lock().set_receiver(timeout) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn is_closed(&self) -> bool { | ||||||
|  |         self.is_closed.load(Ordering::Relaxed) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn set_closed(&self) { | ||||||
|  |         self.is_closed.store(true, Ordering::Relaxed) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn reset_closed(&self) { | ||||||
|  |         self.is_closed.store(false, Ordering::Relaxed) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn pollee(&self) -> &Pollee { | ||||||
|  |         &self.pollee | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     #[allow(unused)] | ||||||
|  |     pub fn addr(&self) -> Option<A> { | ||||||
|  |         let inner = self.inner.lock(); | ||||||
|  |         inner.addr.clone() | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn set_addr(&self, addr: &A) { | ||||||
|  |         let mut inner = self.inner.lock(); | ||||||
|  |         inner.addr = Some(addr.clone()) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn get_addr_from_host(&self) -> Result<A> { | ||||||
|  |         let mut c_addr: libc::sockaddr_storage = unsafe { std::mem::zeroed() }; | ||||||
|  |         let mut c_addr_len = std::mem::size_of::<libc::sockaddr_storage>() as u32; | ||||||
|  |         try_libc!(do_getsockname( | ||||||
|  |             self.host_fd as _, | ||||||
|  |             &mut c_addr as *mut libc::sockaddr_storage as *mut _, | ||||||
|  |             &mut c_addr_len as *mut _, | ||||||
|  |         )); | ||||||
|  |         A::from_c_storage(&c_addr, c_addr_len as _) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn peer_addr(&self) -> Option<A> { | ||||||
|  |         let inner = self.inner.lock(); | ||||||
|  |         inner.peer_addr.clone() | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn set_peer_addr(&self, peer_addr: &A) { | ||||||
|  |         let mut inner = self.inner.lock(); | ||||||
|  |         inner.peer_addr = Some(peer_addr.clone()); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn reset_peer_addr(&self) { | ||||||
|  |         let mut inner = self.inner.lock(); | ||||||
|  |         inner.peer_addr = None; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // For getsockopt SO_ERROR command
 | ||||||
|  |     pub fn errno(&self) -> Option<Errno> { | ||||||
|  |         let mut errno_option = self.errno.lock(); | ||||||
|  |         errno_option.take() | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn set_errno(&self, errno: Errno) { | ||||||
|  |         let mut errno_option = self.errno.lock(); | ||||||
|  |         *errno_option = Some(errno); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn host_shutdown(&self, how: Shutdown) -> Result<()> { | ||||||
|  |         trace!("host shutdown: {:?}", how); | ||||||
|  |         match how { | ||||||
|  |             Shutdown::Write => { | ||||||
|  |                 try_libc!(do_shutdown(self.host_fd as _, libc::SHUT_WR)); | ||||||
|  |             } | ||||||
|  |             Shutdown::Read => { | ||||||
|  |                 try_libc!(do_shutdown(self.host_fd as _, libc::SHUT_RD)); | ||||||
|  |             } | ||||||
|  |             Shutdown::Both => { | ||||||
|  |                 try_libc!(do_shutdown(self.host_fd as _, libc::SHUT_RDWR)); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |         Ok(()) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl<A: Addr + 'static, R: Runtime> std::fmt::Debug for Common<A, R> { | ||||||
|  |     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||||||
|  |         f.debug_struct("Common") | ||||||
|  |             .field("host_fd", &self.host_fd) | ||||||
|  |             .field("type", &self.type_) | ||||||
|  |             .field("nonblocking", &self.nonblocking) | ||||||
|  |             .field("pollee", &self.pollee) | ||||||
|  |             .field("inner", &self.inner.lock()) | ||||||
|  |             .finish() | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl<A: Addr + 'static, R: Runtime> Drop for Common<A, R> { | ||||||
|  |     fn drop(&mut self) { | ||||||
|  |         if let Err(e) = super::do_close(self.host_fd) { | ||||||
|  |             log::error!("do_close failed, host_fd: {}, err: {:?}", self.host_fd, e); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         R::disattach_io_uring(self.host_fd as usize, self.io_uring()) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | #[derive(Debug)] | ||||||
|  | struct Inner<A: Addr + 'static> { | ||||||
|  |     addr: Option<A>, | ||||||
|  |     peer_addr: Option<A>, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl<A: Addr + 'static> Inner<A> { | ||||||
|  |     pub fn new() -> Self { | ||||||
|  |         Self { | ||||||
|  |             addr: None, | ||||||
|  |             peer_addr: None, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
							
								
								
									
										7
									
								
								src/libos/src/net/socket/uring/common/mod.rs
									
									
									
									
									
										Normal file
									
								
							
							
								
								
								
								
								
									
									
								
							
						
						
									
										7
									
								
								src/libos/src/net/socket/uring/common/mod.rs
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,7 @@ | |||||||
|  | mod common; | ||||||
|  | mod operation; | ||||||
|  | mod timeout; | ||||||
|  | 
 | ||||||
|  | pub use self::common::Common; | ||||||
|  | pub use self::operation::{do_bind, do_close, do_connect, do_unlink}; | ||||||
|  | pub use self::timeout::Timeout; | ||||||
							
								
								
									
										44
									
								
								src/libos/src/net/socket/uring/common/operation.rs
									
									
									
									
									
										Normal file
									
								
							
							
								
								
								
								
								
									
									
								
							
						
						
									
										44
									
								
								src/libos/src/net/socket/uring/common/operation.rs
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,44 @@ | |||||||
|  | use std::ffi::CString; | ||||||
|  | use std::mem::{self, MaybeUninit}; | ||||||
|  | 
 | ||||||
|  | use crate::prelude::*; | ||||||
|  | 
 | ||||||
|  | pub fn do_bind<A: Addr>(host_fd: FileDesc, addr: &A) -> Result<()> { | ||||||
|  |     let fd = host_fd as i32; | ||||||
|  |     let (c_addr_storage, c_addr_len) = addr.to_c_storage(); | ||||||
|  |     let c_addr_ptr = &c_addr_storage as *const _ as _; | ||||||
|  |     let c_addr_len = c_addr_len as u32; | ||||||
|  |     try_libc!(libc::ocall::bind(fd, c_addr_ptr, c_addr_len)); | ||||||
|  |     Ok(()) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | pub fn do_close(host_fd: FileDesc) -> Result<()> { | ||||||
|  |     let fd = host_fd as i32; | ||||||
|  |     try_libc!(libc::ocall::close(fd)); | ||||||
|  |     Ok(()) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | pub fn do_unlink(path: &String) -> Result<()> { | ||||||
|  |     let c_string = | ||||||
|  |         CString::new(path.as_bytes()).map_err(|_| errno!(EINVAL, "cstring new failure"))?; | ||||||
|  |     let c_path = c_string.as_c_str().as_ptr(); | ||||||
|  |     try_libc!(libc::ocall::unlink(c_path)); | ||||||
|  |     Ok(()) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | pub fn do_connect<A: Addr>(host_fd: FileDesc, addr: Option<&A>) -> Result<()> { | ||||||
|  |     let fd = host_fd as i32; | ||||||
|  |     let (c_addr_storage, c_addr_len) = match addr { | ||||||
|  |         Some(addr_inner) => addr_inner.to_c_storage(), | ||||||
|  |         None => { | ||||||
|  |             let mut sockaddr_storage = | ||||||
|  |                 unsafe { MaybeUninit::<libc::sockaddr_storage>::uninit().assume_init() }; | ||||||
|  |             sockaddr_storage.ss_family = libc::AF_UNSPEC as _; | ||||||
|  |             (sockaddr_storage, mem::size_of::<libc::sa_family_t>()) | ||||||
|  |         } | ||||||
|  |     }; | ||||||
|  |     let c_addr_ptr = &c_addr_storage as *const _ as _; | ||||||
|  |     let c_addr_len = c_addr_len as u32; | ||||||
|  |     try_libc!(libc::ocall::connect(fd, c_addr_ptr, c_addr_len)); | ||||||
|  |     Ok(()) | ||||||
|  | } | ||||||
							
								
								
									
										32
									
								
								src/libos/src/net/socket/uring/common/timeout.rs
									
									
									
									
									
										Normal file
									
								
							
							
								
								
								
								
								
									
									
								
							
						
						
									
										32
									
								
								src/libos/src/net/socket/uring/common/timeout.rs
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,32 @@ | |||||||
|  | use std::time::Duration; | ||||||
|  | 
 | ||||||
|  | #[derive(Clone, Debug)] | ||||||
|  | pub struct Timeout { | ||||||
|  |     sender: Option<Duration>, | ||||||
|  |     receiver: Option<Duration>, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl Timeout { | ||||||
|  |     pub fn new() -> Self { | ||||||
|  |         Self { | ||||||
|  |             sender: None, | ||||||
|  |             receiver: None, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn sender_timeout(&self) -> Option<Duration> { | ||||||
|  |         self.sender | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn receiver_timeout(&self) -> Option<Duration> { | ||||||
|  |         self.receiver | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn set_sender(&mut self, timeout: Duration) { | ||||||
|  |         self.sender = Some(timeout); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn set_receiver(&mut self, timeout: Duration) { | ||||||
|  |         self.receiver = Some(timeout); | ||||||
|  |     } | ||||||
|  | } | ||||||
							
								
								
									
										494
									
								
								src/libos/src/net/socket/uring/datagram/generic.rs
									
									
									
									
									
										Normal file
									
								
							
							
								
								
								
								
								
									
									
								
							
						
						
									
										494
									
								
								src/libos/src/net/socket/uring/datagram/generic.rs
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,494 @@ | |||||||
|  | use core::time::Duration; | ||||||
|  | 
 | ||||||
|  | use crate::{ | ||||||
|  |     events::{Observer, Poller}, | ||||||
|  |     fs::{IoNotifier, StatusFlags}, | ||||||
|  |     match_ioctl_cmd_mut, | ||||||
|  |     net::socket::MsgFlags, | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | use super::*; | ||||||
|  | use crate::fs::IoEvents as Events; | ||||||
|  | use crate::fs::{GetIfConf, GetIfReqWithRawCmd, GetReadBufLen, IoctlCmd}; | ||||||
|  | 
 | ||||||
|  | pub struct DatagramSocket<A: Addr + 'static, R: Runtime> { | ||||||
|  |     common: Arc<Common<A, R>>, | ||||||
|  |     state: RwLock<State>, | ||||||
|  |     sender: Arc<Sender<A, R>>, | ||||||
|  |     receiver: Arc<Receiver<A, R>>, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl<A: Addr, R: Runtime> DatagramSocket<A, R> { | ||||||
|  |     pub fn new(nonblocking: bool) -> Result<Self> { | ||||||
|  |         let common = Arc::new(Common::new(Type::DGRAM, nonblocking, None)?); | ||||||
|  |         let state = RwLock::new(State::new()); | ||||||
|  |         let sender = Sender::new(common.clone()); | ||||||
|  |         let receiver = Receiver::new(common.clone()); | ||||||
|  |         Ok(Self { | ||||||
|  |             common, | ||||||
|  |             state, | ||||||
|  |             sender, | ||||||
|  |             receiver, | ||||||
|  |         }) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn new_pair(nonblocking: bool) -> Result<(Self, Self)> { | ||||||
|  |         let (common1, common2) = Common::new_pair(Type::DGRAM, nonblocking)?; | ||||||
|  |         let socket1 = Self::new_connected(common1); | ||||||
|  |         let socket2 = Self::new_connected(common2); | ||||||
|  |         Ok((socket1, socket2)) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn new_connected(common: Common<A, R>) -> Self { | ||||||
|  |         let common = Arc::new(common); | ||||||
|  |         let state = RwLock::new(State::new_connected()); | ||||||
|  |         let sender = Sender::new(common.clone()); | ||||||
|  |         let receiver = Receiver::new(common.clone()); | ||||||
|  |         receiver.initiate_async_recv(); | ||||||
|  |         Self { | ||||||
|  |             common, | ||||||
|  |             state, | ||||||
|  |             sender, | ||||||
|  |             receiver, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn domain(&self) -> Domain { | ||||||
|  |         A::domain() | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn host_fd(&self) -> FileDesc { | ||||||
|  |         self.common.host_fd() | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn status_flags(&self) -> StatusFlags { | ||||||
|  |         // Only support O_NONBLOCK
 | ||||||
|  |         if self.common.nonblocking() { | ||||||
|  |             StatusFlags::O_NONBLOCK | ||||||
|  |         } else { | ||||||
|  |             StatusFlags::empty() | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn set_status_flags(&self, new_flags: StatusFlags) -> Result<()> { | ||||||
|  |         // Only support O_NONBLOCK
 | ||||||
|  |         let nonblocking = new_flags.is_nonblocking(); | ||||||
|  |         self.common.set_nonblocking(nonblocking); | ||||||
|  |         Ok(()) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     /// When creating a datagram socket, you can use `bind` to bind the socket
 | ||||||
|  |     /// to a address, hence another socket can send data to this address.
 | ||||||
|  |     ///
 | ||||||
|  |     /// Binding is divided into explicit and implicit. Invoking `bind` is
 | ||||||
|  |     /// explicit binding, while invoking `sendto` / `sendmsg` / `connect`
 | ||||||
|  |     /// will trigger implicit binding.
 | ||||||
|  |     ///
 | ||||||
|  |     /// Datagram sockets can only bind once. You should use explicit binding or
 | ||||||
|  |     /// just implicit binding. The explicit binding will failed if it happens after
 | ||||||
|  |     /// a implicit binding.
 | ||||||
|  |     pub fn bind(&self, addr: &A) -> Result<()> { | ||||||
|  |         let mut state = self.state.write().unwrap(); | ||||||
|  |         if state.is_bound() { | ||||||
|  |             return_errno!(EINVAL, "The socket is already bound to an address"); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         do_bind(self.host_fd(), addr)?; | ||||||
|  | 
 | ||||||
|  |         self.common.set_addr(addr); | ||||||
|  |         state.mark_explicit_bind(); | ||||||
|  |         // Start async recv after explicit binding or implicit binding
 | ||||||
|  |         self.receiver.initiate_async_recv(); | ||||||
|  | 
 | ||||||
|  |         Ok(()) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     /// Datagram sockets provide only connectionless interactions, But datagram sockets
 | ||||||
|  |     /// can also use connect to associate a socket with a specific address.
 | ||||||
|  |     /// After connection, any data sent on the socket is automatically addressed to the
 | ||||||
|  |     /// connected peer, and only data received from that peer is delivered to the user.
 | ||||||
|  |     ///
 | ||||||
|  |     /// Unlike stream sockets, datagram sockets can connect multiple times. But the socket
 | ||||||
|  |     /// can only connect to one peer in the same time; a second connect will change the
 | ||||||
|  |     /// peer address, and a connect to a address with family AF_UNSPEC will dissolve the
 | ||||||
|  |     /// association ("disconnect" or "unconnect").
 | ||||||
|  |     ///
 | ||||||
|  |     /// Before connection you can only use `sendto` / `sendmsg` / `recvfrom` / `recvmsg`.
 | ||||||
|  |     /// Only after connection, you can use `read` / `recv` / `write` / `send`.
 | ||||||
|  |     /// And you can ignore the address in `sendto` / `sendmsg` if you just want to
 | ||||||
|  |     /// send data to the connected peer.
 | ||||||
|  |     ///
 | ||||||
|  |     /// Ref 1: http://osr507doc.xinuos.com/en/netguide/disockD.connecting_datagrams.html
 | ||||||
|  |     /// Ref 2: https://www.masterraghu.com/subjects/np/introduction/unix_network_programming_v1.3/ch08lev1sec11.html
 | ||||||
|  |     pub fn connect(&self, peer_addr: Option<&A>) -> Result<()> { | ||||||
|  |         let mut state = self.state.write().unwrap(); | ||||||
|  | 
 | ||||||
|  |         // if previous peer.is_default() and peer_addr.is_none()
 | ||||||
|  |         // is unspec, so the situation exists that both
 | ||||||
|  |         // !state.is_connected() and peer_addr.is_none() are true.
 | ||||||
|  | 
 | ||||||
|  |         if let Some(peer) = peer_addr { | ||||||
|  |             do_connect(self.host_fd(), Some(peer))?; | ||||||
|  | 
 | ||||||
|  |             self.receiver.reset_shutdown(); | ||||||
|  |             self.sender.reset_shutdown(); | ||||||
|  |             self.common.set_peer_addr(peer); | ||||||
|  | 
 | ||||||
|  |             if peer.is_default() { | ||||||
|  |                 state.mark_disconnected(); | ||||||
|  |             } else { | ||||||
|  |                 state.mark_connected(); | ||||||
|  |             } | ||||||
|  |             if !state.is_bound() { | ||||||
|  |                 state.mark_implicit_bind(); | ||||||
|  |                 // Start async recv after explicit binding or implicit binding
 | ||||||
|  |                 self.receiver.initiate_async_recv(); | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |         // TODO: update binding address in some cases
 | ||||||
|  |         // For a ipv4 socket bound to 0.0.0.0 (INADDR_ANY), if you do connection
 | ||||||
|  |         // to 127.0.0.1 (Local IP address), the IP address of the socket will
 | ||||||
|  |         // change to 127.0.0.1 too. And if connect to non-local IP address, linux
 | ||||||
|  |         // will assign a address to the socket.
 | ||||||
|  |         // In both cases, we should update the binding address that we stored.
 | ||||||
|  |         } else { | ||||||
|  |             do_connect::<A>(self.host_fd(), None)?; | ||||||
|  | 
 | ||||||
|  |             self.common.reset_peer_addr(); | ||||||
|  |             state.mark_disconnected(); | ||||||
|  | 
 | ||||||
|  |             // TODO: clear binding in some cases.
 | ||||||
|  |             // Disconnect will effect the binding address. In Linux, for socket that
 | ||||||
|  |             // explicit bound to local IP address, disconnect will clear the binding address,
 | ||||||
|  |             // but leave the port intact. For socket with implicit bound, disconnect will
 | ||||||
|  |             // clear both the address and port.
 | ||||||
|  |         } | ||||||
|  |         Ok(()) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // Close the datagram socket, cancel pending iouring requests
 | ||||||
|  |     pub fn close(&self) -> Result<()> { | ||||||
|  |         self.sender.shutdown(); | ||||||
|  |         self.receiver.shutdown(); | ||||||
|  |         self.common.set_closed(); | ||||||
|  |         self.cancel_requests(); | ||||||
|  |         Ok(()) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     /// Shutdown the udp socket. This syscall is very TCP-oriented, but it is also useful for udp socket.
 | ||||||
|  |     /// Not like tcp, shutdown does nothing on the wire, it only changes shutdown states.
 | ||||||
|  |     /// The shutdown states block the io-uring request of receiving or sending message.
 | ||||||
|  |     pub fn shutdown(&self, how: Shutdown) -> Result<()> { | ||||||
|  |         let state = self.state.read().unwrap(); | ||||||
|  |         if !state.is_connected() { | ||||||
|  |             return_errno!(ENOTCONN, "The udp socket is not connected"); | ||||||
|  |         } | ||||||
|  |         drop(state); | ||||||
|  |         match how { | ||||||
|  |             Shutdown::Read => { | ||||||
|  |                 self.common.host_shutdown(how)?; | ||||||
|  |                 self.receiver.shutdown(); | ||||||
|  |                 self.common.pollee().add_events(Events::IN); | ||||||
|  |             } | ||||||
|  |             Shutdown::Write => { | ||||||
|  |                 if self.sender.is_empty() { | ||||||
|  |                     self.common.host_shutdown(how)?; | ||||||
|  |                 } | ||||||
|  |                 self.sender.shutdown(); | ||||||
|  |                 self.common.pollee().add_events(Events::OUT); | ||||||
|  |             } | ||||||
|  |             Shutdown::Both => { | ||||||
|  |                 self.common.host_shutdown(Shutdown::Read)?; | ||||||
|  |                 if self.sender.is_empty() { | ||||||
|  |                     self.common.host_shutdown(Shutdown::Write)?; | ||||||
|  |                 } | ||||||
|  |                 self.receiver.shutdown(); | ||||||
|  |                 self.sender.shutdown(); | ||||||
|  |                 self.common | ||||||
|  |                     .pollee() | ||||||
|  |                     .add_events(Events::IN | Events::OUT | Events::HUP); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |         Ok(()) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn read(&self, buf: &mut [u8]) -> Result<usize> { | ||||||
|  |         self.readv(&mut [buf]) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn readv(&self, bufs: &mut [&mut [u8]]) -> Result<usize> { | ||||||
|  |         let state = self.state.read().unwrap(); | ||||||
|  |         drop(state); | ||||||
|  | 
 | ||||||
|  |         self.recvmsg(bufs, RecvFlags::empty(), None) | ||||||
|  |             .map(|(ret, ..)| ret) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     /// You can not invoke `recvfrom` directly after creating a datagram socket.
 | ||||||
|  |     /// That is because `recvfrom` doesn't privide a implicit binding. If you
 | ||||||
|  |     /// don't do a explicit or implicit binding, the sender doesn't know where
 | ||||||
|  |     /// to send the data.
 | ||||||
|  |     pub fn recvmsg( | ||||||
|  |         &self, | ||||||
|  |         bufs: &mut [&mut [u8]], | ||||||
|  |         flags: RecvFlags, | ||||||
|  |         control: Option<&mut [u8]>, | ||||||
|  |     ) -> Result<(usize, Option<A>, MsgFlags, usize)> { | ||||||
|  |         self.receiver.recvmsg(bufs, flags, control) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn write(&self, buf: &[u8]) -> Result<usize> { | ||||||
|  |         self.writev(&[buf]) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn writev(&self, bufs: &[&[u8]]) -> Result<usize> { | ||||||
|  |         self.sendmsg(bufs, None, SendFlags::empty(), None) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn sendmsg( | ||||||
|  |         &self, | ||||||
|  |         bufs: &[&[u8]], | ||||||
|  |         addr: Option<&A>, | ||||||
|  |         flags: SendFlags, | ||||||
|  |         control: Option<&[u8]>, | ||||||
|  |     ) -> Result<usize> { | ||||||
|  |         let state = self.state.read().unwrap(); | ||||||
|  |         if addr.is_none() && !state.is_connected() { | ||||||
|  |             return_errno!(EDESTADDRREQ, "Destination address required"); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         drop(state); | ||||||
|  |         let res = if let Some(addr) = addr { | ||||||
|  |             self.sender.sendmsg(bufs, addr, flags, control) | ||||||
|  |         } else { | ||||||
|  |             let peer = self.common.peer_addr(); | ||||||
|  |             if let Some(peer) = peer.as_ref() { | ||||||
|  |                 self.sender.sendmsg(bufs, peer, flags, control) | ||||||
|  |             } else { | ||||||
|  |                 return_errno!(EDESTADDRREQ, "Destination address required"); | ||||||
|  |             } | ||||||
|  |         }; | ||||||
|  | 
 | ||||||
|  |         let mut state = self.state.write().unwrap(); | ||||||
|  |         if !state.is_bound() { | ||||||
|  |             state.mark_implicit_bind(); | ||||||
|  |             // Start async recv after explicit binding or implicit binding
 | ||||||
|  |             self.receiver.initiate_async_recv(); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         res | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn poll(&self, mask: Events, poller: Option<&Poller>) -> Events { | ||||||
|  |         let pollee = self.common.pollee(); | ||||||
|  |         pollee.poll(mask, poller) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn addr(&self) -> Result<A> { | ||||||
|  |         let common = &self.common; | ||||||
|  | 
 | ||||||
|  |         // Always get addr from host.
 | ||||||
|  |         // Because for IP socket, users can specify "0" as port and the kernel should select a usable port for him.
 | ||||||
|  |         // Thus, when calling getsockname, this should be updated.
 | ||||||
|  |         let addr = common.get_addr_from_host()?; | ||||||
|  |         common.set_addr(&addr); | ||||||
|  |         Ok(addr) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn notifier(&self) -> &IoNotifier { | ||||||
|  |         let notifier = self.common.notifier(); | ||||||
|  |         notifier | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn peer_addr(&self) -> Result<A> { | ||||||
|  |         let state = self.state.read().unwrap(); | ||||||
|  |         if !state.is_connected() { | ||||||
|  |             return_errno!(ENOTCONN, "the socket is not connected"); | ||||||
|  |         } | ||||||
|  |         Ok(self.common.peer_addr().unwrap()) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn errno(&self) -> Option<Errno> { | ||||||
|  |         self.common.errno() | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn ioctl(&self, cmd: &mut dyn IoctlCmd) -> Result<()> { | ||||||
|  |         match_ioctl_cmd_mut!(&mut *cmd, { | ||||||
|  |             cmd: GetSockOptRawCmd => { | ||||||
|  |                 cmd.execute(self.host_fd())?; | ||||||
|  |             }, | ||||||
|  |             cmd: SetSockOptRawCmd => { | ||||||
|  |                 cmd.execute(self.host_fd())?; | ||||||
|  |             }, | ||||||
|  |             cmd: SetRecvTimeoutCmd => { | ||||||
|  |                 self.set_recv_timeout(*cmd.timeout()); | ||||||
|  |             }, | ||||||
|  |             cmd: SetSendTimeoutCmd => { | ||||||
|  |                 self.set_send_timeout(*cmd.timeout()); | ||||||
|  |             }, | ||||||
|  |             cmd: GetRecvTimeoutCmd => { | ||||||
|  |                 let timeval = timeout_to_timeval(self.recv_timeout()); | ||||||
|  |                 cmd.set_output(timeval); | ||||||
|  |             }, | ||||||
|  |             cmd: GetSendTimeoutCmd => { | ||||||
|  |                 let timeval = timeout_to_timeval(self.send_timeout()); | ||||||
|  |                 cmd.set_output(timeval); | ||||||
|  |             }, | ||||||
|  |             cmd: GetAcceptConnCmd => { | ||||||
|  |                 // Datagram doesn't support listen
 | ||||||
|  |                 cmd.set_output(0); | ||||||
|  |             }, | ||||||
|  |             cmd: GetDomainCmd => { | ||||||
|  |                 cmd.set_output(self.domain() as _); | ||||||
|  |             }, | ||||||
|  |             cmd: GetErrorCmd => { | ||||||
|  |                 let error: i32 = self.errno().map(|err| err as i32).unwrap_or(0); | ||||||
|  |                 cmd.set_output(error); | ||||||
|  |             }, | ||||||
|  |             cmd: GetPeerNameCmd => { | ||||||
|  |                 let peer = self.peer_addr()?; | ||||||
|  |                 cmd.set_output(AddrStorage(peer.to_c_storage())); | ||||||
|  |             }, | ||||||
|  |             cmd: GetTypeCmd => { | ||||||
|  |                 cmd.set_output(self.common.type_() as _); | ||||||
|  |             }, | ||||||
|  |             cmd: GetIfReqWithRawCmd => { | ||||||
|  |                 cmd.execute(self.host_fd())?; | ||||||
|  |             }, | ||||||
|  |             cmd: GetIfConf => { | ||||||
|  |                 cmd.execute(self.host_fd())?; | ||||||
|  |             }, | ||||||
|  |             cmd: GetReadBufLen => { | ||||||
|  |                 let read_buf_len = self.receiver.ready_len(); | ||||||
|  |                 cmd.set_output(read_buf_len as _); | ||||||
|  |             }, | ||||||
|  |             _ => { | ||||||
|  |                 return_errno!(EINVAL, "Not supported yet"); | ||||||
|  |             } | ||||||
|  |         }); | ||||||
|  |         Ok(()) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn send_timeout(&self) -> Option<Duration> { | ||||||
|  |         self.common.send_timeout() | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn recv_timeout(&self) -> Option<Duration> { | ||||||
|  |         self.common.recv_timeout() | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn set_send_timeout(&self, timeout: Duration) { | ||||||
|  |         self.common.set_send_timeout(timeout); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn set_recv_timeout(&self, timeout: Duration) { | ||||||
|  |         self.common.set_recv_timeout(timeout); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn cancel_requests(&self) { | ||||||
|  |         self.receiver.cancel_recv_requests(); | ||||||
|  |         self.sender.try_clear_msg_queue_when_close(); | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl<A: Addr + 'static, R: Runtime> Drop for DatagramSocket<A, R> { | ||||||
|  |     fn drop(&mut self) { | ||||||
|  |         self.common.set_closed(); | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl<A: Addr + 'static, R: Runtime> std::fmt::Debug for DatagramSocket<A, R> { | ||||||
|  |     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||||||
|  |         f.debug_struct("DatagramSocket") | ||||||
|  |             .field("common", &self.common) | ||||||
|  |             .finish() | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | #[derive(Debug)] | ||||||
|  | struct State { | ||||||
|  |     bind_state: BindState, | ||||||
|  |     is_connected: bool, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl State { | ||||||
|  |     pub fn new() -> Self { | ||||||
|  |         Self { | ||||||
|  |             bind_state: BindState::Unbound, | ||||||
|  |             is_connected: false, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn new_connected() -> Self { | ||||||
|  |         Self { | ||||||
|  |             bind_state: BindState::Unbound, | ||||||
|  |             is_connected: true, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn is_bound(&self) -> bool { | ||||||
|  |         self.bind_state.is_bound() | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     #[allow(dead_code)] | ||||||
|  |     pub fn is_explicit_bound(&self) -> bool { | ||||||
|  |         self.bind_state.is_explicit_bound() | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     #[allow(dead_code)] | ||||||
|  |     pub fn is_implicit_bound(&self) -> bool { | ||||||
|  |         self.bind_state.is_implicit_bound() | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn is_connected(&self) -> bool { | ||||||
|  |         self.is_connected | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn mark_explicit_bind(&mut self) { | ||||||
|  |         self.bind_state = BindState::ExplicitBound; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn mark_implicit_bind(&mut self) { | ||||||
|  |         self.bind_state = BindState::ImplicitBound; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn mark_connected(&mut self) { | ||||||
|  |         self.is_connected = true; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn mark_disconnected(&mut self) { | ||||||
|  |         self.is_connected = false; | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | #[derive(Debug)] | ||||||
|  | enum BindState { | ||||||
|  |     Unbound, | ||||||
|  |     ExplicitBound, | ||||||
|  |     ImplicitBound, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl BindState { | ||||||
|  |     pub fn is_bound(&self) -> bool { | ||||||
|  |         match self { | ||||||
|  |             Self::Unbound => false, | ||||||
|  |             _ => true, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     #[allow(dead_code)] | ||||||
|  |     pub fn is_explicit_bound(&self) -> bool { | ||||||
|  |         match self { | ||||||
|  |             Self::ExplicitBound => true, | ||||||
|  |             _ => false, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     #[allow(dead_code)] | ||||||
|  |     pub fn is_implicit_bound(&self) -> bool { | ||||||
|  |         match self { | ||||||
|  |             Self::ImplicitBound => true, | ||||||
|  |             _ => false, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
							
								
								
									
										20
									
								
								src/libos/src/net/socket/uring/datagram/mod.rs
									
									
									
									
									
										Normal file
									
								
							
							
								
								
								
								
								
									
									
								
							
						
						
									
										20
									
								
								src/libos/src/net/socket/uring/datagram/mod.rs
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,20 @@ | |||||||
|  | //! Datagram sockets.
 | ||||||
|  | mod generic; | ||||||
|  | mod receiver; | ||||||
|  | mod sender; | ||||||
|  | 
 | ||||||
|  | use self::receiver::Receiver; | ||||||
|  | use self::sender::Sender; | ||||||
|  | use crate::net::socket::sockopt::*; | ||||||
|  | use crate::net::socket::uring::common::{do_bind, do_connect, Common}; | ||||||
|  | use crate::net::socket::uring::runtime::Runtime; | ||||||
|  | use crate::prelude::*; | ||||||
|  | 
 | ||||||
|  | pub use generic::DatagramSocket; | ||||||
|  | 
 | ||||||
|  | use crate::net::socket::sockopt::{ | ||||||
|  |     timeout_to_timeval, GetRecvTimeoutCmd, GetSendTimeoutCmd, SetRecvTimeoutCmd, SetSendTimeoutCmd, | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | const MAX_BUF_SIZE: usize = 64 * 1024; | ||||||
|  | const OPTMEM_MAX: usize = 64 * 1024; | ||||||
							
								
								
									
										382
									
								
								src/libos/src/net/socket/uring/datagram/receiver.rs
									
									
									
									
									
										Normal file
									
								
							
							
								
								
								
								
								
									
									
								
							
						
						
									
										382
									
								
								src/libos/src/net/socket/uring/datagram/receiver.rs
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,382 @@ | |||||||
|  | use core::time::Duration; | ||||||
|  | use std::mem::MaybeUninit; | ||||||
|  | 
 | ||||||
|  | use crate::events::Poller; | ||||||
|  | use crate::net::socket::MsgFlags; | ||||||
|  | use io_uring_callback::{Fd, IoHandle}; | ||||||
|  | use sgx_untrusted_alloc::{MaybeUntrusted, UntrustedBox}; | ||||||
|  | 
 | ||||||
|  | use crate::fs::IoEvents as Events; | ||||||
|  | use crate::net::socket::uring::common::Common; | ||||||
|  | use crate::net::socket::uring::runtime::Runtime; | ||||||
|  | use crate::prelude::*; | ||||||
|  | 
 | ||||||
|  | pub struct Receiver<A: Addr + 'static, R: Runtime> { | ||||||
|  |     common: Arc<Common<A, R>>, | ||||||
|  |     inner: Mutex<Inner>, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl<A: Addr, R: Runtime> Receiver<A, R> { | ||||||
|  |     pub fn new(common: Arc<Common<A, R>>) -> Arc<Self> { | ||||||
|  |         let inner = Mutex::new(Inner::new()); | ||||||
|  |         Arc::new(Self { common, inner }) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn recvmsg( | ||||||
|  |         self: &Arc<Self>, | ||||||
|  |         bufs: &mut [&mut [u8]], | ||||||
|  |         flags: RecvFlags, | ||||||
|  |         mut control: Option<&mut [u8]>, | ||||||
|  |     ) -> Result<(usize, Option<A>, MsgFlags, usize)> { | ||||||
|  |         let mask = Events::IN; | ||||||
|  |         // Initialize the poller only when needed
 | ||||||
|  |         let mut poller = None; | ||||||
|  |         let mut timeout = self.common.recv_timeout(); | ||||||
|  |         loop { | ||||||
|  |             // Attempt to recv
 | ||||||
|  |             let res = self.try_recvmsg(bufs, flags, &mut control); | ||||||
|  |             if !res.has_errno(EAGAIN) { | ||||||
|  |                 return res; | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             // Need more handles for flags not MSG_DONTWAIT
 | ||||||
|  |             // recv*(MSG_ERRQUEUE) never blocks, even without MSG_DONTWAIT
 | ||||||
|  |             if self.common.nonblocking() | ||||||
|  |                 || flags.contains(RecvFlags::MSG_DONTWAIT) | ||||||
|  |                 || flags.contains(RecvFlags::MSG_ERRQUEUE) | ||||||
|  |             { | ||||||
|  |                 return_errno!(EAGAIN, "no data are present to be received"); | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             // Wait for interesting events by polling
 | ||||||
|  |             if poller.is_none() { | ||||||
|  |                 let new_poller = Poller::new(); | ||||||
|  |                 self.common.pollee().connect_poller(mask, &new_poller); | ||||||
|  |                 poller = Some(new_poller); | ||||||
|  |             } | ||||||
|  |             let events = self.common.pollee().poll(mask, None); | ||||||
|  |             if events.is_empty() { | ||||||
|  |                 let ret = poller.as_ref().unwrap().wait_timeout(timeout.as_mut()); | ||||||
|  |                 if let Err(e) = ret { | ||||||
|  |                     warn!("recv wait errno = {:?}", e.errno()); | ||||||
|  |                     match e.errno() { | ||||||
|  |                         ETIMEDOUT => { | ||||||
|  |                             return_errno!(EAGAIN, "timeout reached") | ||||||
|  |                         } | ||||||
|  |                         _ => { | ||||||
|  |                             return_errno!(e.errno(), "wait error") | ||||||
|  |                         } | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn try_recvmsg( | ||||||
|  |         self: &Arc<Self>, | ||||||
|  |         bufs: &mut [&mut [u8]], | ||||||
|  |         flags: RecvFlags, | ||||||
|  |         control: &mut Option<&mut [u8]>, | ||||||
|  |     ) -> Result<(usize, Option<A>, MsgFlags, usize)> { | ||||||
|  |         let mut inner = self.inner.lock(); | ||||||
|  | 
 | ||||||
|  |         if !flags.is_empty() && flags.contains(RecvFlags::MSG_OOB | RecvFlags::MSG_CMSG_CLOEXEC) { | ||||||
|  |             // todo!("Support other flags");
 | ||||||
|  |             return_errno!(EINVAL, "the socket flags is not supported"); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         // Mark the socket as non-readable since Datagram uses single packet
 | ||||||
|  |         self.common.pollee().del_events(Events::IN); | ||||||
|  | 
 | ||||||
|  |         let mut recv_bytes = 0; | ||||||
|  |         let mut msg_flags = MsgFlags::empty(); | ||||||
|  |         let recv_addr = inner.get_packet_addr(); | ||||||
|  |         let msg_controllen = inner.control_len.unwrap_or(0); | ||||||
|  |         let user_controllen = control.as_ref().map_or(0, |buf| buf.len()); | ||||||
|  | 
 | ||||||
|  |         // Copy ancillary data from control buffer
 | ||||||
|  |         if user_controllen > super::OPTMEM_MAX { | ||||||
|  |             return_errno!(EINVAL, "invalid msg control length"); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         if user_controllen < msg_controllen { | ||||||
|  |             msg_flags = msg_flags | MsgFlags::MSG_CTRUNC | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         if msg_controllen > 0 { | ||||||
|  |             let copied_bytes = msg_controllen.min(user_controllen); | ||||||
|  |             control | ||||||
|  |                 .as_mut() | ||||||
|  |                 .map(|buf| buf[..copied_bytes].copy_from_slice(&inner.msg_control[..copied_bytes])); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         // Copy data from the recv buffer to the bufs
 | ||||||
|  |         let copied_bytes = inner.try_copy_buf(bufs); | ||||||
|  |         if let Some(copied_bytes) = copied_bytes { | ||||||
|  |             let bufs_len: usize = bufs.iter().map(|buf| buf.len()).sum(); | ||||||
|  | 
 | ||||||
|  |             // If user provided buffer length is smaller than kernel received datagram length,
 | ||||||
|  |             // discard the datagram and set MsgFlags::MSG_TRUNC in returned msg_flags.
 | ||||||
|  |             if bufs_len < inner.recv_len().unwrap() { | ||||||
|  |                 // update msg.msg_flags to MSG_TRUNC
 | ||||||
|  |                 msg_flags = msg_flags | MsgFlags::MSG_TRUNC | ||||||
|  |             }; | ||||||
|  | 
 | ||||||
|  |             // If user provided flags contain MSG_TRUNC, the return received length should be
 | ||||||
|  |             // kernel receiver buffer length, vice versa should return truly copied bytes length.
 | ||||||
|  |             recv_bytes = if flags.contains(RecvFlags::MSG_TRUNC) { | ||||||
|  |                 inner.recv_len().unwrap() | ||||||
|  |             } else { | ||||||
|  |                 copied_bytes | ||||||
|  |             }; | ||||||
|  | 
 | ||||||
|  |             // When flags contain MSG_PEEK and there is data in socket recv buffer,
 | ||||||
|  |             // it is unnecessary to send blocking recv request (do_recv) to fetch data
 | ||||||
|  |             // from iouring buffer, which may flush the data in recv buffer.
 | ||||||
|  |             // When flags don't contain MSG_PEEK or there is no available data,
 | ||||||
|  |             // it is time to send blocking request to iouring for notifying events.
 | ||||||
|  |             if !flags.contains(RecvFlags::MSG_PEEK) { | ||||||
|  |                 self.do_recv(&mut inner); | ||||||
|  |             } | ||||||
|  |             return Ok((recv_bytes, recv_addr, msg_flags, msg_controllen)); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         // In some situantions of MSG_ERRQUEUE, users only require control buffer but setting iovec length to zero.
 | ||||||
|  |         if msg_controllen > 0 { | ||||||
|  |             return Ok((recv_bytes, recv_addr, msg_flags, msg_controllen)); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         // Handle iouring message error
 | ||||||
|  |         if let Some(errno) = inner.error { | ||||||
|  |             // Reset error
 | ||||||
|  |             inner.error = None; | ||||||
|  |             self.common.pollee().del_events(Events::ERR); | ||||||
|  |             return_errno!(errno, "recv failed"); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         if inner.is_shutdown { | ||||||
|  |             if self.common.nonblocking() | ||||||
|  |                 || flags.contains(RecvFlags::MSG_DONTWAIT) | ||||||
|  |                 || flags.contains(RecvFlags::MSG_ERRQUEUE) | ||||||
|  |             { | ||||||
|  |                 return_errno!(Errno::EWOULDBLOCK, "the socket recv has been shutdown"); | ||||||
|  |             } else { | ||||||
|  |                 return Ok((0, None, msg_flags, 0)); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         self.do_recv(&mut inner); | ||||||
|  |         return_errno!(EAGAIN, "try recv again"); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn do_recv(self: &Arc<Self>, inner: &mut MutexGuard<Inner>) { | ||||||
|  |         if inner.io_handle.is_some() || self.common.is_closed() { | ||||||
|  |             return; | ||||||
|  |         } | ||||||
|  |         // Clear recv_len and error
 | ||||||
|  |         inner.recv_len.take(); | ||||||
|  |         inner.control_len.take(); | ||||||
|  |         inner.error.take(); | ||||||
|  | 
 | ||||||
|  |         if inner.is_shutdown { | ||||||
|  |             info!("do_recv early return, the socket recv has been shutdown"); | ||||||
|  |             return; | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         let receiver = self.clone(); | ||||||
|  |         // Init the callback invoked upon the completion of the async recv
 | ||||||
|  |         let complete_fn = move |retval: i32| { | ||||||
|  |             let mut inner = receiver.inner.lock(); | ||||||
|  | 
 | ||||||
|  |             // Release the handle to the async recv
 | ||||||
|  |             inner.io_handle.take(); | ||||||
|  | 
 | ||||||
|  |             // Handle error
 | ||||||
|  |             if retval < 0 { | ||||||
|  |                 // TODO: Should we filter the error case? Do we have the ability to filter?
 | ||||||
|  |                 // We only filter the normal case now. According to the man page of recvmsg,
 | ||||||
|  |                 // these errors should not happen, since our fd and arguments should always
 | ||||||
|  |                 // be valid unless being attacked.
 | ||||||
|  | 
 | ||||||
|  |                 // TODO: guard against Iago attack through errno
 | ||||||
|  |                 let errno = Errno::from(-retval as u32); | ||||||
|  |                 inner.error = Some(errno); | ||||||
|  |                 receiver.common.set_errno(errno); | ||||||
|  |                 // TODO: add PRI event if set SO_SELECT_ERR_QUEUE
 | ||||||
|  |                 receiver.common.pollee().add_events(Events::ERR); | ||||||
|  |                 return; | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             // Handle the normal case of a successful read
 | ||||||
|  |             inner.recv_len = Some(retval as usize); | ||||||
|  | 
 | ||||||
|  |             let control_len = inner.req.msg.msg_controllen; | ||||||
|  |             inner.control_len = Some(control_len); | ||||||
|  | 
 | ||||||
|  |             receiver.common.pollee().add_events(Events::IN); | ||||||
|  | 
 | ||||||
|  |             // We don't do_recv() here, since do_recv() will clear the recv message.
 | ||||||
|  |         }; | ||||||
|  | 
 | ||||||
|  |         // Generate the async recv request
 | ||||||
|  |         let msghdr_ptr = inner.new_recv_req(); | ||||||
|  | 
 | ||||||
|  |         // Submit the async recv to io_uring
 | ||||||
|  |         let io_uring = self.common.io_uring(); | ||||||
|  |         let host_fd = Fd(self.common.host_fd() as _); | ||||||
|  |         let handle = unsafe { io_uring.recvmsg(host_fd, msghdr_ptr, 0, complete_fn) }; | ||||||
|  |         inner.io_handle.replace(handle); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn initiate_async_recv(self: &Arc<Self>) { | ||||||
|  |         let mut inner = self.inner.lock(); | ||||||
|  |         self.do_recv(&mut inner); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn cancel_recv_requests(&self) { | ||||||
|  |         { | ||||||
|  |             let inner = self.inner.lock(); | ||||||
|  |             if let Some(io_handle) = &inner.io_handle { | ||||||
|  |                 let io_uring = self.common.io_uring(); | ||||||
|  |                 unsafe { io_uring.cancel(io_handle) }; | ||||||
|  |             } else { | ||||||
|  |                 return; | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         // wait for the cancel to complete
 | ||||||
|  |         let poller = Poller::new(); | ||||||
|  |         let mask = Events::ERR | Events::IN; | ||||||
|  |         self.common.pollee().connect_poller(mask, &poller); | ||||||
|  | 
 | ||||||
|  |         loop { | ||||||
|  |             let pending_request_exist = { | ||||||
|  |                 let inner = self.inner.lock(); | ||||||
|  |                 inner.io_handle.is_some() | ||||||
|  |             }; | ||||||
|  | 
 | ||||||
|  |             if pending_request_exist { | ||||||
|  |                 let mut timeout = Some(Duration::from_secs(10)); | ||||||
|  |                 let ret = poller.wait_timeout(timeout.as_mut()); | ||||||
|  |                 if let Err(e) = ret { | ||||||
|  |                     warn!("wait cancel recv request error = {:?}", e.errno()); | ||||||
|  |                     continue; | ||||||
|  |                 } | ||||||
|  |             } else { | ||||||
|  |                 break; | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     /// Shutdown udp receiver.
 | ||||||
|  |     pub fn shutdown(&self) { | ||||||
|  |         let mut inner = self.inner.lock(); | ||||||
|  |         inner.is_shutdown = true; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     /// Reset udp receiver shutdown state.
 | ||||||
|  |     pub fn reset_shutdown(&self) { | ||||||
|  |         let mut inner = self.inner.lock(); | ||||||
|  |         inner.is_shutdown = false; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn ready_len(&self) -> usize { | ||||||
|  |         let inner = self.inner.lock(); | ||||||
|  |         inner.recv_len().unwrap_or(0) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | struct Inner { | ||||||
|  |     recv_buf: UntrustedBox<[u8]>, | ||||||
|  |     // Datagram sockets in various domains permit zero-length datagrams.
 | ||||||
|  |     // Hence the recv_len might be 0.
 | ||||||
|  |     recv_len: Option<usize>, | ||||||
|  |     // When the recv_buf content length is greater than user buffer,
 | ||||||
|  |     // store the offset for the recv_buf for read loop
 | ||||||
|  |     recv_buf_offset: usize, | ||||||
|  |     msg_control: UntrustedBox<[u8]>, | ||||||
|  |     control_len: Option<usize>, | ||||||
|  |     req: UntrustedBox<RecvReq>, | ||||||
|  |     io_handle: Option<IoHandle>, | ||||||
|  |     error: Option<Errno>, | ||||||
|  |     is_shutdown: bool, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | unsafe impl Send for Inner {} | ||||||
|  | 
 | ||||||
|  | impl Inner { | ||||||
|  |     pub fn new() -> Self { | ||||||
|  |         Self { | ||||||
|  |             recv_buf: UntrustedBox::new_uninit_slice(super::MAX_BUF_SIZE), | ||||||
|  |             recv_len: None, | ||||||
|  |             recv_buf_offset: 0, | ||||||
|  |             msg_control: UntrustedBox::new_uninit_slice(super::OPTMEM_MAX), | ||||||
|  |             control_len: None, | ||||||
|  |             req: UntrustedBox::new_uninit(), | ||||||
|  |             io_handle: None, | ||||||
|  |             error: None, | ||||||
|  |             is_shutdown: false, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn new_recv_req(&mut self) -> *mut libc::msghdr { | ||||||
|  |         let iovec = libc::iovec { | ||||||
|  |             iov_base: self.recv_buf.as_mut_ptr() as _, | ||||||
|  |             iov_len: self.recv_buf.len(), | ||||||
|  |         }; | ||||||
|  | 
 | ||||||
|  |         let msghdr_ptr = &raw mut self.req.msg; | ||||||
|  | 
 | ||||||
|  |         let mut msg: libc::msghdr = unsafe { MaybeUninit::zeroed().assume_init() }; | ||||||
|  |         msg.msg_iov = &raw mut self.req.iovec as _; | ||||||
|  |         msg.msg_iovlen = 1; | ||||||
|  |         msg.msg_name = &raw mut self.req.addr as _; | ||||||
|  |         msg.msg_namelen = std::mem::size_of::<libc::sockaddr_storage>() as _; | ||||||
|  | 
 | ||||||
|  |         msg.msg_control = self.msg_control.as_mut_ptr() as _; | ||||||
|  |         msg.msg_controllen = self.msg_control.len() as _; | ||||||
|  | 
 | ||||||
|  |         self.req.msg = msg; | ||||||
|  |         self.req.iovec = iovec; | ||||||
|  | 
 | ||||||
|  |         msghdr_ptr | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn try_copy_buf(&self, bufs: &mut [&mut [u8]]) -> Option<usize> { | ||||||
|  |         self.recv_len.map(|recv_len| { | ||||||
|  |             let mut copy_len = 0; | ||||||
|  |             for buf in bufs { | ||||||
|  |                 let recv_buf = &self.recv_buf[copy_len..recv_len]; | ||||||
|  |                 if buf.len() <= recv_buf.len() { | ||||||
|  |                     buf.copy_from_slice(&recv_buf[..buf.len()]); | ||||||
|  |                     copy_len += buf.len(); | ||||||
|  |                 } else { | ||||||
|  |                     buf[..recv_buf.len()].copy_from_slice(&recv_buf[..]); | ||||||
|  |                     copy_len += recv_buf.len(); | ||||||
|  |                     break; | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |             copy_len | ||||||
|  |         }) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn recv_len(&self) -> Option<usize> { | ||||||
|  |         self.recv_len | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     /// Return the addr of the received packet if udp socket is not connected.
 | ||||||
|  |     /// Return None if udp socket is connected.
 | ||||||
|  |     pub fn get_packet_addr<A: Addr>(&self) -> Option<A> { | ||||||
|  |         let recv_addr_len = self.req.msg.msg_namelen as usize; | ||||||
|  |         A::from_c_storage(&self.req.addr, recv_addr_len).ok() | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | #[repr(C)] | ||||||
|  | struct RecvReq { | ||||||
|  |     msg: libc::msghdr, | ||||||
|  |     iovec: libc::iovec, | ||||||
|  |     addr: libc::sockaddr_storage, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | unsafe impl MaybeUntrusted for RecvReq {} | ||||||
							
								
								
									
										406
									
								
								src/libos/src/net/socket/uring/datagram/sender.rs
									
									
									
									
									
										Normal file
									
								
							
							
								
								
								
								
								
									
									
								
							
						
						
									
										406
									
								
								src/libos/src/net/socket/uring/datagram/sender.rs
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,406 @@ | |||||||
|  | use core::time::Duration; | ||||||
|  | use std::ptr::{self}; | ||||||
|  | 
 | ||||||
|  | use io_uring_callback::{Fd, IoHandle}; | ||||||
|  | use libc::c_void; | ||||||
|  | use sgx_untrusted_alloc::{MaybeUntrusted, UntrustedBox}; | ||||||
|  | use std::collections::VecDeque; | ||||||
|  | 
 | ||||||
|  | use crate::events::Poller; | ||||||
|  | use crate::fs::IoEvents as Events; | ||||||
|  | use crate::net::socket::uring::common::Common; | ||||||
|  | use crate::net::socket::uring::runtime::Runtime; | ||||||
|  | use crate::prelude::*; | ||||||
|  | use crate::util::sync::MutexGuard; | ||||||
|  | 
 | ||||||
|  | const SENDMSG_QUEUE_LEN: usize = 16; | ||||||
|  | 
 | ||||||
|  | pub struct Sender<A: Addr + 'static, R: Runtime> { | ||||||
|  |     common: Arc<Common<A, R>>, | ||||||
|  |     inner: Mutex<Inner>, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl<A: Addr, R: Runtime> Sender<A, R> { | ||||||
|  |     pub fn new(common: Arc<Common<A, R>>) -> Arc<Self> { | ||||||
|  |         common.pollee().add_events(Events::OUT); | ||||||
|  |         let inner = Mutex::new(Inner::new()); | ||||||
|  |         Arc::new(Self { common, inner }) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     /// Shutdown udp sender.
 | ||||||
|  |     pub fn shutdown(&self) { | ||||||
|  |         let mut inner = self.inner.lock(); | ||||||
|  |         inner.is_shutdown = ShutdownStatus::PreShutdown; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     /// Reset udp sender shutdown state.
 | ||||||
|  |     pub fn reset_shutdown(&self) { | ||||||
|  |         let mut inner = self.inner.lock(); | ||||||
|  |         inner.is_shutdown = ShutdownStatus::Running; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     /// Whether no buffer in sender.
 | ||||||
|  |     pub fn is_empty(&self) -> bool { | ||||||
|  |         let inner = self.inner.lock(); | ||||||
|  |         inner.msg_queue.is_empty() | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // Normally, We will always try to send as long as the kernel send buf is not empty.
 | ||||||
|  |     // However, if the user calls close, we will wait LINGER time
 | ||||||
|  |     // and then cancel on-going or new-issued send requests.
 | ||||||
|  |     pub fn try_clear_msg_queue_when_close(&self) { | ||||||
|  |         let inner = self.inner.lock(); | ||||||
|  |         debug_assert!(inner.is_shutdown()); | ||||||
|  |         if inner.msg_queue.is_empty() { | ||||||
|  |             return; | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         // Wait for linger time to empty the kernel buffer or cancel subsequent requests.
 | ||||||
|  |         drop(inner); | ||||||
|  |         const DEFUALT_LINGER_TIME: usize = 10; | ||||||
|  |         let poller = Poller::new(); | ||||||
|  |         let mask = Events::ERR | Events::OUT; | ||||||
|  |         self.common.pollee().connect_poller(mask, &poller); | ||||||
|  | 
 | ||||||
|  |         loop { | ||||||
|  |             let pending_request_exist = { | ||||||
|  |                 let inner = self.inner.lock(); | ||||||
|  |                 inner.io_handle.is_some() | ||||||
|  |             }; | ||||||
|  | 
 | ||||||
|  |             if pending_request_exist { | ||||||
|  |                 let mut timeout = Some(Duration::from_secs(DEFUALT_LINGER_TIME as u64)); | ||||||
|  |                 let ret = poller.wait_timeout(timeout.as_mut()); | ||||||
|  |                 trace!("wait empty send buffer ret = {:?}", ret); | ||||||
|  |                 if let Err(_) = ret { | ||||||
|  |                     // No complete request to wake. Just cancel the send requests.
 | ||||||
|  |                     let io_uring = self.common.io_uring(); | ||||||
|  |                     let inner = self.inner.lock(); | ||||||
|  |                     if let Some(io_handle) = &inner.io_handle { | ||||||
|  |                         unsafe { io_uring.cancel(io_handle) }; | ||||||
|  |                         // Loop again to wait the cancel request to complete
 | ||||||
|  |                         continue; | ||||||
|  |                     } else { | ||||||
|  |                         // No pending request, just break
 | ||||||
|  |                         break; | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|  |             } else { | ||||||
|  |                 // There is no pending requests
 | ||||||
|  |                 break; | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn sendmsg( | ||||||
|  |         self: &Arc<Self>, | ||||||
|  |         bufs: &[&[u8]], | ||||||
|  |         addr: &A, | ||||||
|  |         flags: SendFlags, | ||||||
|  |         control: Option<&[u8]>, | ||||||
|  |     ) -> Result<usize> { | ||||||
|  |         if !flags.is_empty() | ||||||
|  |             && flags.intersects(!(SendFlags::MSG_DONTWAIT | SendFlags::MSG_NOSIGNAL)) | ||||||
|  |         { | ||||||
|  |             error!("Not supported flags: {:?}", flags); | ||||||
|  |             return_errno!(EINVAL, "not supported flags"); | ||||||
|  |         } | ||||||
|  |         let mask = Events::OUT; | ||||||
|  |         // Initialize the poller only when needed
 | ||||||
|  |         let mut poller = None; | ||||||
|  |         let mut timeout = self.common.send_timeout(); | ||||||
|  |         loop { | ||||||
|  |             // Attempt to write
 | ||||||
|  |             let res = self.try_sendmsg(bufs, addr, control); | ||||||
|  |             if !res.has_errno(EAGAIN) { | ||||||
|  |                 return res; | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             // Still some buffer contents pending
 | ||||||
|  |             if self.common.nonblocking() || flags.contains(SendFlags::MSG_DONTWAIT) { | ||||||
|  |                 return_errno!(EAGAIN, "try write again"); | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             // Wait for interesting events by polling
 | ||||||
|  |             if poller.is_none() { | ||||||
|  |                 let new_poller = Poller::new(); | ||||||
|  |                 self.common.pollee().connect_poller(mask, &new_poller); | ||||||
|  |                 poller = Some(new_poller); | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             let events = self.common.pollee().poll(mask, None); | ||||||
|  |             if events.is_empty() { | ||||||
|  |                 let ret = poller.as_ref().unwrap().wait_timeout(timeout.as_mut()); | ||||||
|  |                 if let Err(e) = ret { | ||||||
|  |                     warn!("send wait errno = {:?}", e.errno()); | ||||||
|  |                     match e.errno() { | ||||||
|  |                         ETIMEDOUT => { | ||||||
|  |                             return_errno!(EAGAIN, "timeout reached") | ||||||
|  |                         } | ||||||
|  |                         _ => { | ||||||
|  |                             return_errno!(e.errno(), "wait error") | ||||||
|  |                         } | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn try_sendmsg( | ||||||
|  |         self: &Arc<Self>, | ||||||
|  |         bufs: &[&[u8]], | ||||||
|  |         addr: &A, | ||||||
|  |         control: Option<&[u8]>, | ||||||
|  |     ) -> Result<usize> { | ||||||
|  |         let mut inner = self.inner.lock(); | ||||||
|  |         if inner.is_shutdown() { | ||||||
|  |             return_errno!(EPIPE, "the write has been shutdown") | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         if let Some(errno) = inner.error { | ||||||
|  |             // Reset error
 | ||||||
|  |             inner.error = None; | ||||||
|  |             self.common.pollee().del_events(Events::ERR); | ||||||
|  |             return_errno!(errno, "write failed"); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         let buf_len: usize = bufs.iter().map(|buf| buf.len()).sum(); | ||||||
|  |         let mut msg = DataMsg::new(buf_len); | ||||||
|  |         let total_copied = msg.copy_buf(bufs)?; | ||||||
|  |         msg.copy_control(control)?; | ||||||
|  | 
 | ||||||
|  |         let msghdr_ptr = new_send_req(&mut msg, addr); | ||||||
|  | 
 | ||||||
|  |         if !inner.msg_queue.push_msg(msg) { | ||||||
|  |             // Msg queue can not push this msg, mark the socket as non-writable
 | ||||||
|  |             self.common.pollee().del_events(Events::OUT); | ||||||
|  |             return_errno!(EAGAIN, "try write again"); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         // Since the send buffer is not empty, try to flush the buffer
 | ||||||
|  |         if inner.io_handle.is_none() { | ||||||
|  |             self.do_send(&mut inner, msghdr_ptr); | ||||||
|  |         } | ||||||
|  |         Ok(total_copied) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn do_send(self: &Arc<Self>, inner: &mut MutexGuard<Inner>, msghdr_ptr: *const libc::msghdr) { | ||||||
|  |         debug_assert!(!inner.msg_queue.is_empty()); | ||||||
|  |         debug_assert!(inner.io_handle.is_none()); | ||||||
|  |         let sender = self.clone(); | ||||||
|  |         // Submit the async send to io_uring
 | ||||||
|  |         let complete_fn = move |retval: i32| { | ||||||
|  |             let mut inner = sender.inner.lock(); | ||||||
|  |             trace!("send request complete with retval: {}", retval); | ||||||
|  | 
 | ||||||
|  |             // Release the handle to the async recv
 | ||||||
|  |             inner.io_handle.take(); | ||||||
|  | 
 | ||||||
|  |             if retval < 0 { | ||||||
|  |                 // TODO: add PRI event if set SO_SELECT_ERR_QUEUE
 | ||||||
|  |                 let errno = Errno::from(-retval as u32); | ||||||
|  | 
 | ||||||
|  |                 inner.error = Some(errno); | ||||||
|  |                 sender.common.set_errno(errno); | ||||||
|  |                 sender.common.pollee().add_events(Events::ERR); | ||||||
|  |                 return; | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             // Need to handle normal case
 | ||||||
|  |             inner.msg_queue.pop_msg(); | ||||||
|  |             sender.common.pollee().add_events(Events::OUT); | ||||||
|  |             if !inner.msg_queue.is_empty() { | ||||||
|  |                 let msghdr_ptr = inner.msg_queue.first_msg_ptr(); | ||||||
|  |                 debug_assert!(msghdr_ptr.is_some()); | ||||||
|  |                 sender.do_send(&mut inner, msghdr_ptr.unwrap()); | ||||||
|  |             } else if inner.is_shutdown == ShutdownStatus::PreShutdown { | ||||||
|  |                 // The buffer is empty and the write side is shutdown by the user.
 | ||||||
|  |                 // We can safely shutdown host file here.
 | ||||||
|  |                 let _ = sender.common.host_shutdown(Shutdown::Write); | ||||||
|  |                 inner.is_shutdown = ShutdownStatus::PostShutdown | ||||||
|  |             } | ||||||
|  |         }; | ||||||
|  | 
 | ||||||
|  |         // Generate the async recv request
 | ||||||
|  |         let io_uring = self.common.io_uring(); | ||||||
|  |         let host_fd = Fd(self.common.host_fd() as _); | ||||||
|  |         let handle = unsafe { io_uring.sendmsg(host_fd, msghdr_ptr, 0, complete_fn) }; | ||||||
|  |         inner.io_handle.replace(handle); | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | fn new_send_req<A: Addr>(dmsg: &mut DataMsg, addr: &A) -> *const libc::msghdr { | ||||||
|  |     let iovec = libc::iovec { | ||||||
|  |         iov_base: dmsg.send_buf.as_ptr() as _, | ||||||
|  |         iov_len: dmsg.send_buf.len(), | ||||||
|  |     }; | ||||||
|  | 
 | ||||||
|  |     let (control, controllen) = match &dmsg.control { | ||||||
|  |         Some(control) => (control.as_mut_ptr() as *mut c_void, control.len()), | ||||||
|  |         None => (ptr::null_mut(), 0), | ||||||
|  |     }; | ||||||
|  | 
 | ||||||
|  |     dmsg.req.iovec = iovec; | ||||||
|  | 
 | ||||||
|  |     dmsg.req.msg.msg_iov = &raw mut dmsg.req.iovec as _; | ||||||
|  |     dmsg.req.msg.msg_iovlen = 1; | ||||||
|  | 
 | ||||||
|  |     let (c_addr_storage, c_addr_len) = addr.to_c_storage(); | ||||||
|  | 
 | ||||||
|  |     dmsg.req.addr = c_addr_storage; | ||||||
|  |     dmsg.req.msg.msg_name = &raw mut dmsg.req.addr as _; | ||||||
|  |     dmsg.req.msg.msg_namelen = c_addr_len as _; | ||||||
|  |     dmsg.req.msg.msg_control = control; | ||||||
|  |     dmsg.req.msg.msg_controllen = controllen; | ||||||
|  | 
 | ||||||
|  |     &mut dmsg.req.msg | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | pub struct Inner { | ||||||
|  |     io_handle: Option<IoHandle>, | ||||||
|  |     error: Option<Errno>, | ||||||
|  |     is_shutdown: ShutdownStatus, | ||||||
|  |     msg_queue: MsgQueue, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | unsafe impl Send for Inner {} | ||||||
|  | 
 | ||||||
|  | impl Inner { | ||||||
|  |     pub fn new() -> Self { | ||||||
|  |         Self { | ||||||
|  |             io_handle: None, | ||||||
|  |             error: None, | ||||||
|  |             is_shutdown: ShutdownStatus::Running, | ||||||
|  |             msg_queue: MsgQueue::new(), | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     /// Obtain udp sender shutdown state.
 | ||||||
|  |     #[inline(always)] | ||||||
|  |     pub fn is_shutdown(&self) -> bool { | ||||||
|  |         self.is_shutdown == ShutdownStatus::PreShutdown | ||||||
|  |             || self.is_shutdown == ShutdownStatus::PostShutdown | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | #[repr(C)] | ||||||
|  | struct SendReq { | ||||||
|  |     msg: libc::msghdr, | ||||||
|  |     iovec: libc::iovec, | ||||||
|  |     addr: libc::sockaddr_storage, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | unsafe impl MaybeUntrusted for SendReq {} | ||||||
|  | 
 | ||||||
|  | struct MsgQueue { | ||||||
|  |     queue: VecDeque<DataMsg>, | ||||||
|  |     curr_size: usize, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl MsgQueue { | ||||||
|  |     #[inline(always)] | ||||||
|  |     fn new() -> Self { | ||||||
|  |         Self { | ||||||
|  |             queue: VecDeque::with_capacity(SENDMSG_QUEUE_LEN), | ||||||
|  |             curr_size: 0, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     #[inline(always)] | ||||||
|  |     fn size(&self) -> usize { | ||||||
|  |         self.curr_size | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     #[inline(always)] | ||||||
|  |     fn is_empty(&self) -> bool { | ||||||
|  |         self.queue.is_empty() | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // Push datagram msg, return true if succeed,
 | ||||||
|  |     // return false if buffer is full.
 | ||||||
|  |     #[inline(always)] | ||||||
|  |     fn push_msg(&mut self, msg: DataMsg) -> bool { | ||||||
|  |         let total_len = msg.len() + self.size(); | ||||||
|  |         if total_len <= super::MAX_BUF_SIZE { | ||||||
|  |             self.curr_size = total_len; | ||||||
|  |             self.queue.push_back(msg); | ||||||
|  |             return true; | ||||||
|  |         } | ||||||
|  |         false | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     #[inline(always)] | ||||||
|  |     fn pop_msg(&mut self) { | ||||||
|  |         if let Some(msg) = self.queue.pop_front() { | ||||||
|  |             self.curr_size = self.size() - msg.len(); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     #[inline(always)] | ||||||
|  |     fn first_msg_ptr(&self) -> Option<*const libc::msghdr> { | ||||||
|  |         self.queue | ||||||
|  |             .front() | ||||||
|  |             .map(|data_msg| &data_msg.req.msg as *const libc::msghdr) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Datagram msg contents in untrusted region
 | ||||||
|  | struct DataMsg { | ||||||
|  |     req: UntrustedBox<SendReq>, | ||||||
|  |     send_buf: UntrustedBox<[u8]>, | ||||||
|  |     control: Option<UntrustedBox<[u8]>>, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl DataMsg { | ||||||
|  |     #[inline(always)] | ||||||
|  |     fn new(buf_len: usize) -> Self { | ||||||
|  |         Self { | ||||||
|  |             req: UntrustedBox::<SendReq>::new_uninit(), | ||||||
|  |             send_buf: UntrustedBox::new_uninit_slice(buf_len), | ||||||
|  |             control: None, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     #[inline(always)] | ||||||
|  |     fn copy_buf(&mut self, bufs: &[&[u8]]) -> Result<usize> { | ||||||
|  |         let total_len = self.send_buf.len(); | ||||||
|  |         if total_len > super::MAX_BUF_SIZE { | ||||||
|  |             return_errno!(EMSGSIZE, "the message is too large") | ||||||
|  |         } | ||||||
|  |         // Copy data from the bufs to the send buffer
 | ||||||
|  |         let mut total_copied = 0; | ||||||
|  |         for buf in bufs { | ||||||
|  |             self.send_buf[total_copied..(total_copied + buf.len())].copy_from_slice(buf); | ||||||
|  |             total_copied += buf.len(); | ||||||
|  |         } | ||||||
|  |         Ok(total_copied) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     #[inline(always)] | ||||||
|  |     fn copy_control(&mut self, control: Option<&[u8]>) -> Result<usize> { | ||||||
|  |         if let Some(msg_control) = control { | ||||||
|  |             let send_controllen = msg_control.len(); | ||||||
|  |             if send_controllen > super::OPTMEM_MAX { | ||||||
|  |                 return_errno!(EINVAL, "invalid msg control length"); | ||||||
|  |             } | ||||||
|  |             let mut send_control_buf = UntrustedBox::new_uninit_slice(send_controllen); | ||||||
|  |             send_control_buf.copy_from_slice(&msg_control[..send_controllen]); | ||||||
|  | 
 | ||||||
|  |             self.control = Some(send_control_buf); | ||||||
|  |             return Ok(send_controllen); | ||||||
|  |         }; | ||||||
|  |         Ok(0) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     #[inline(always)] | ||||||
|  |     fn len(&self) -> usize { | ||||||
|  |         self.send_buf.len() | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | #[derive(Debug, PartialEq)] | ||||||
|  | enum ShutdownStatus { | ||||||
|  |     Running,      // not shutdown
 | ||||||
|  |     PreShutdown,  // start the shutdown process, set by calling shutdown syscall
 | ||||||
|  |     PostShutdown, // shutdown process is done, set when the buffer is empty
 | ||||||
|  | } | ||||||
							
								
								
									
										79
									
								
								src/libos/src/net/socket/uring/file_impl.rs
									
									
									
									
									
										Normal file
									
								
							
							
								
								
								
								
								
									
									
								
							
						
						
									
										79
									
								
								src/libos/src/net/socket/uring/file_impl.rs
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,79 @@ | |||||||
|  | use super::socket_file::SocketFile; | ||||||
|  | use crate::fs::{ | ||||||
|  |     AccessMode, FileDesc, HostFd, IoEvents, IoNotifier, IoctlCmd, IoctlRawCmd, StatusFlags, | ||||||
|  | }; | ||||||
|  | use crate::prelude::*; | ||||||
|  | use std::{io::SeekFrom, os::unix::raw::off_t}; | ||||||
|  | 
 | ||||||
|  | impl File for SocketFile { | ||||||
|  |     fn read(&self, buf: &mut [u8]) -> Result<usize> { | ||||||
|  |         self.read(buf) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn readv(&self, bufs: &mut [&mut [u8]]) -> Result<usize> { | ||||||
|  |         self.readv(bufs) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn write(&self, buf: &[u8]) -> Result<usize> { | ||||||
|  |         self.write(buf) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn writev(&self, bufs: &[&[u8]]) -> Result<usize> { | ||||||
|  |         self.writev(bufs) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn read_at(&self, offset: usize, buf: &mut [u8]) -> Result<usize> { | ||||||
|  |         if offset != 0 { | ||||||
|  |             return_errno!(ESPIPE, "a nonzero position is not supported"); | ||||||
|  |         } | ||||||
|  |         self.read(buf) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn write_at(&self, offset: usize, buf: &[u8]) -> Result<usize> { | ||||||
|  |         if offset != 0 { | ||||||
|  |             return_errno!(ESPIPE, "a nonzero position is not supported"); | ||||||
|  |         } | ||||||
|  |         self.write(buf) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn seek(&self, pos: SeekFrom) -> Result<off_t> { | ||||||
|  |         return_errno!(ESPIPE, "Socket does not support seek") | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn ioctl(&self, cmd: &mut dyn IoctlCmd) -> Result<()> { | ||||||
|  |         self.ioctl(cmd) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn notifier(&self) -> Option<&IoNotifier> { | ||||||
|  |         Some(self.notifier()) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn access_mode(&self) -> Result<AccessMode> { | ||||||
|  |         Ok(AccessMode::O_RDWR) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn status_flags(&self) -> Result<StatusFlags> { | ||||||
|  |         Ok(self.status_flags()) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn set_status_flags(&self, new_status_flags: StatusFlags) -> Result<()> { | ||||||
|  |         self.set_status_flags(new_status_flags) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn poll_new(&self) -> IoEvents { | ||||||
|  |         let mask = IoEvents::all(); | ||||||
|  |         self.poll(mask, None) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn host_fd(&self) -> Option<&HostFd> { | ||||||
|  |         None | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn update_host_events(&self, ready: &IoEvents, mask: &IoEvents, trigger_notifier: bool) { | ||||||
|  |         unreachable!() | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn as_any(&self) -> &dyn core::any::Any { | ||||||
|  |         self | ||||||
|  |     } | ||||||
|  | } | ||||||
							
								
								
									
										12
									
								
								src/libos/src/net/socket/uring/mod.rs
									
									
									
									
									
										Normal file
									
								
							
							
								
								
								
								
								
									
									
								
							
						
						
									
										12
									
								
								src/libos/src/net/socket/uring/mod.rs
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,12 @@ | |||||||
|  | #![feature(stmt_expr_attributes)] | ||||||
|  | #![feature(new_uninit)] | ||||||
|  | #![feature(raw_ref_op)] | ||||||
|  | 
 | ||||||
|  | pub mod common; | ||||||
|  | pub mod datagram; | ||||||
|  | pub mod file_impl; | ||||||
|  | pub mod runtime; | ||||||
|  | pub mod socket_file; | ||||||
|  | pub mod stream; | ||||||
|  | 
 | ||||||
|  | pub use self::socket_file::UringSocketType; | ||||||
							
								
								
									
										12
									
								
								src/libos/src/net/socket/uring/runtime.rs
									
									
									
									
									
										Normal file
									
								
							
							
								
								
								
								
								
									
									
								
							
						
						
									
										12
									
								
								src/libos/src/net/socket/uring/runtime.rs
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,12 @@ | |||||||
|  | use alloc::sync::Arc; | ||||||
|  | use io_uring_callback::IoUring; | ||||||
|  | 
 | ||||||
|  | /// The runtime support for HostSocket.
 | ||||||
|  | ///
 | ||||||
|  | /// This trait provides a common interface for user-implemented runtimes
 | ||||||
|  | /// that support HostSocket. Currently, the only dependency is a singleton
 | ||||||
|  | /// of IoUring instance.
 | ||||||
|  | pub trait Runtime: Send + Sync + 'static { | ||||||
|  |     fn io_uring() -> Arc<IoUring>; | ||||||
|  |     fn disattach_io_uring(fd: usize, uring: Arc<IoUring>); | ||||||
|  | } | ||||||
							
								
								
									
										446
									
								
								src/libos/src/net/socket/uring/socket_file.rs
									
									
									
									
									
										Normal file
									
								
							
							
								
								
								
								
								
									
									
								
							
						
						
									
										446
									
								
								src/libos/src/net/socket/uring/socket_file.rs
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,446 @@ | |||||||
|  | use self::impls::{Ipv4Datagram, Ipv6Datagram}; | ||||||
|  | use crate::events::{Observer, Poller}; | ||||||
|  | use crate::net::socket::{MsgFlags, SocketProtocol}; | ||||||
|  | 
 | ||||||
|  | use self::impls::{Ipv4Stream, Ipv6Stream}; | ||||||
|  | use crate::fs::{AccessMode, IoEvents, IoNotifier, IoctlCmd, StatusFlags}; | ||||||
|  | use crate::net::socket::{AnyAddr, Ipv4SocketAddr, Ipv6SocketAddr}; | ||||||
|  | use crate::prelude::*; | ||||||
|  | 
 | ||||||
|  | #[derive(Debug)] | ||||||
|  | pub struct SocketFile { | ||||||
|  |     socket: AnySocket, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Apply a function to all variants of AnySocket enum.
 | ||||||
|  | macro_rules! apply_fn_on_any_socket { | ||||||
|  |     ($any_socket:expr, |$socket:ident| { $($fn_body:tt)* }) => {{ | ||||||
|  |         let any_socket: &AnySocket = $any_socket; | ||||||
|  |         match any_socket { | ||||||
|  |             AnySocket::Ipv4Stream($socket) => { | ||||||
|  |                 $($fn_body)* | ||||||
|  |             } | ||||||
|  |             AnySocket::Ipv6Stream($socket) => { | ||||||
|  |                 $($fn_body)* | ||||||
|  |             } | ||||||
|  |             AnySocket::Ipv4Datagram($socket) => { | ||||||
|  |                 $($fn_body)* | ||||||
|  |             } | ||||||
|  |             AnySocket::Ipv6Datagram($socket) => { | ||||||
|  |                 $($fn_body)* | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     }} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | pub trait UringSocketType { | ||||||
|  |     fn as_uring_socket(&self) -> Result<&SocketFile>; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl UringSocketType for FileRef { | ||||||
|  |     fn as_uring_socket(&self) -> Result<&SocketFile> { | ||||||
|  |         self.as_any() | ||||||
|  |             .downcast_ref::<SocketFile>() | ||||||
|  |             .ok_or_else(|| errno!(ENOTSOCK, "not a uring socket")) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | #[derive(Debug)] | ||||||
|  | enum AnySocket { | ||||||
|  |     Ipv4Stream(Ipv4Stream), | ||||||
|  |     Ipv6Stream(Ipv6Stream), | ||||||
|  |     Ipv4Datagram(Ipv4Datagram), | ||||||
|  |     Ipv6Datagram(Ipv6Datagram), | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Implement the common methods required by FileHandle
 | ||||||
|  | impl SocketFile { | ||||||
|  |     pub fn read(&self, buf: &mut [u8]) -> Result<usize> { | ||||||
|  |         apply_fn_on_any_socket!(&self.socket, |socket| { socket.read(buf) }) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn readv(&self, bufs: &mut [&mut [u8]]) -> Result<usize> { | ||||||
|  |         apply_fn_on_any_socket!(&self.socket, |socket| { socket.readv(bufs) }) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn write(&self, buf: &[u8]) -> Result<usize> { | ||||||
|  |         apply_fn_on_any_socket!(&self.socket, |socket| { socket.write(buf) }) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn writev(&self, bufs: &[&[u8]]) -> Result<usize> { | ||||||
|  |         apply_fn_on_any_socket!(&self.socket, |socket| { socket.writev(bufs) }) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn access_mode(&self) -> AccessMode { | ||||||
|  |         // We consider all sockets both readable and writable
 | ||||||
|  |         AccessMode::O_RDWR | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn status_flags(&self) -> StatusFlags { | ||||||
|  |         apply_fn_on_any_socket!(&self.socket, |socket| { socket.status_flags() }) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn host_fd_inner(&self) -> FileDesc { | ||||||
|  |         apply_fn_on_any_socket!(&self.socket, |socket| { socket.host_fd() }) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn set_status_flags(&self, new_flags: StatusFlags) -> Result<()> { | ||||||
|  |         apply_fn_on_any_socket!(&self.socket, |socket| { | ||||||
|  |             socket.set_status_flags(new_flags) | ||||||
|  |         }) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn notifier(&self) -> &IoNotifier { | ||||||
|  |         apply_fn_on_any_socket!(&self.socket, |socket| { socket.notifier() }) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { | ||||||
|  |         apply_fn_on_any_socket!(&self.socket, |socket| { socket.poll(mask, poller) }) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn ioctl(&self, cmd: &mut dyn IoctlCmd) -> Result<()> { | ||||||
|  |         apply_fn_on_any_socket!(&self.socket, |socket| { socket.ioctl(cmd) }) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn get_type(&self) -> Type { | ||||||
|  |         match self.socket { | ||||||
|  |             AnySocket::Ipv4Stream(_) | AnySocket::Ipv6Stream(_) => Type::STREAM, | ||||||
|  |             AnySocket::Ipv4Datagram(_) | AnySocket::Ipv6Datagram(_) => Type::DGRAM, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Implement socket-specific methods
 | ||||||
|  | impl SocketFile { | ||||||
|  |     pub fn new( | ||||||
|  |         domain: Domain, | ||||||
|  |         protocol: SocketProtocol, | ||||||
|  |         socket_type: Type, | ||||||
|  |         nonblocking: bool, | ||||||
|  |     ) -> Result<Self> { | ||||||
|  |         match socket_type { | ||||||
|  |             Type::STREAM => { | ||||||
|  |                 if protocol != SocketProtocol::IPPROTO_IP && protocol != SocketProtocol::IPPROTO_TCP | ||||||
|  |                 { | ||||||
|  |                     return_errno!(EPROTONOSUPPORT, "Protocol not supported"); | ||||||
|  |                 } | ||||||
|  |                 let any_socket = match domain { | ||||||
|  |                     Domain::INET => { | ||||||
|  |                         let ipv4_stream = Ipv4Stream::new(nonblocking)?; | ||||||
|  |                         AnySocket::Ipv4Stream(ipv4_stream) | ||||||
|  |                     } | ||||||
|  |                     Domain::INET6 => { | ||||||
|  |                         let ipv6_stream = Ipv6Stream::new(nonblocking)?; | ||||||
|  |                         AnySocket::Ipv6Stream(ipv6_stream) | ||||||
|  |                     } | ||||||
|  |                     _ => { | ||||||
|  |                         panic!() | ||||||
|  |                     } | ||||||
|  |                 }; | ||||||
|  |                 let new_self = Self { socket: any_socket }; | ||||||
|  |                 Ok(new_self) | ||||||
|  |             } | ||||||
|  |             Type::DGRAM => { | ||||||
|  |                 if protocol != SocketProtocol::IPPROTO_IP && protocol != SocketProtocol::IPPROTO_UDP | ||||||
|  |                 { | ||||||
|  |                     return_errno!(EPROTONOSUPPORT, "Protocol not supported"); | ||||||
|  |                 } | ||||||
|  |                 let any_socket = match domain { | ||||||
|  |                     Domain::INET => { | ||||||
|  |                         let ipv4_datagram = Ipv4Datagram::new(nonblocking)?; | ||||||
|  |                         AnySocket::Ipv4Datagram(ipv4_datagram) | ||||||
|  |                     } | ||||||
|  |                     Domain::INET6 => { | ||||||
|  |                         let ipv6_datagram = Ipv6Datagram::new(nonblocking)?; | ||||||
|  |                         AnySocket::Ipv6Datagram(ipv6_datagram) | ||||||
|  |                     } | ||||||
|  |                     _ => { | ||||||
|  |                         return_errno!(EINVAL, "not support yet"); | ||||||
|  |                     } | ||||||
|  |                 }; | ||||||
|  |                 let new_self = Self { socket: any_socket }; | ||||||
|  |                 Ok(new_self) | ||||||
|  |             } | ||||||
|  |             Type::RAW => { | ||||||
|  |                 return_errno!(EINVAL, "RAW socket not supported"); | ||||||
|  |             } | ||||||
|  |             _ => { | ||||||
|  |                 return_errno!(ESOCKTNOSUPPORT, "socket type not supported"); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn domain(&self) -> Domain { | ||||||
|  |         apply_fn_on_any_socket!(&self.socket, |socket| { socket.domain() }) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn is_stream(&self) -> bool { | ||||||
|  |         matches!(&self.socket, AnySocket::Ipv4Stream(_)) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn connect(&self, addr: &AnyAddr) -> Result<()> { | ||||||
|  |         match &self.socket { | ||||||
|  |             AnySocket::Ipv4Stream(ipv4_stream) => { | ||||||
|  |                 let ip_addr = addr.to_ipv4()?; | ||||||
|  |                 ipv4_stream.connect(ip_addr) | ||||||
|  |             } | ||||||
|  |             AnySocket::Ipv6Stream(ipv6_stream) => { | ||||||
|  |                 let ip_addr = addr.to_ipv6()?; | ||||||
|  |                 ipv6_stream.connect(ip_addr) | ||||||
|  |             } | ||||||
|  |             AnySocket::Ipv4Datagram(ipv4_datagram) => { | ||||||
|  |                 let mut ip_addr = None; | ||||||
|  |                 if !addr.is_unspec() { | ||||||
|  |                     let ipv4_addr = addr.to_ipv4()?; | ||||||
|  |                     ip_addr = Some(ipv4_addr); | ||||||
|  |                 } | ||||||
|  |                 ipv4_datagram.connect(ip_addr) | ||||||
|  |             } | ||||||
|  |             AnySocket::Ipv6Datagram(ipv6_datagram) => { | ||||||
|  |                 let mut ip_addr = None; | ||||||
|  |                 if !addr.is_unspec() { | ||||||
|  |                     let ipv6_addr = addr.to_ipv6()?; | ||||||
|  |                     ip_addr = Some(ipv6_addr); | ||||||
|  |                 } | ||||||
|  |                 ipv6_datagram.connect(ip_addr) | ||||||
|  |             } | ||||||
|  |             _ => { | ||||||
|  |                 return_errno!(EINVAL, "connect is not supported"); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn bind(&self, addr: &mut AnyAddr) -> Result<()> { | ||||||
|  |         match &self.socket { | ||||||
|  |             AnySocket::Ipv4Stream(ipv4_stream) => { | ||||||
|  |                 let ip_addr = addr.to_ipv4()?; | ||||||
|  |                 ipv4_stream.bind(ip_addr) | ||||||
|  |             } | ||||||
|  |             AnySocket::Ipv6Stream(ipv6_stream) => { | ||||||
|  |                 let ip_addr = addr.to_ipv6()?; | ||||||
|  |                 ipv6_stream.bind(ip_addr) | ||||||
|  |             } | ||||||
|  |             AnySocket::Ipv4Datagram(ipv4_datagram) => { | ||||||
|  |                 let ip_addr = addr.to_ipv4()?; | ||||||
|  |                 ipv4_datagram.bind(ip_addr) | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             AnySocket::Ipv6Datagram(ipv6_datagram) => { | ||||||
|  |                 let ip_addr = addr.to_ipv6()?; | ||||||
|  |                 ipv6_datagram.bind(ip_addr) | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             _ => { | ||||||
|  |                 return_errno!(EINVAL, "bind is not supported"); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn listen(&self, backlog: u32) -> Result<()> { | ||||||
|  |         match &self.socket { | ||||||
|  |             AnySocket::Ipv4Stream(ip_stream) => ip_stream.listen(backlog), | ||||||
|  |             AnySocket::Ipv6Stream(ip_stream) => ip_stream.listen(backlog), | ||||||
|  |             _ => { | ||||||
|  |                 return_errno!(EOPNOTSUPP, "The socket is not of a listen supported type"); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn accept(&self, nonblocking: bool) -> Result<Self> { | ||||||
|  |         let accepted_any_socket = match &self.socket { | ||||||
|  |             AnySocket::Ipv4Stream(ipv4_stream) => { | ||||||
|  |                 let accepted_ipv4_stream = ipv4_stream.accept(nonblocking)?; | ||||||
|  |                 AnySocket::Ipv4Stream(accepted_ipv4_stream) | ||||||
|  |             } | ||||||
|  |             AnySocket::Ipv6Stream(ipv6_stream) => { | ||||||
|  |                 let accepted_ipv6_stream = ipv6_stream.accept(nonblocking)?; | ||||||
|  |                 AnySocket::Ipv6Stream(accepted_ipv6_stream) | ||||||
|  |             } | ||||||
|  |             _ => { | ||||||
|  |                 return_errno!(EOPNOTSUPP, "The socket is not of a accept supported type"); | ||||||
|  |             } | ||||||
|  |         }; | ||||||
|  |         let accepted_socket_file = SocketFile { | ||||||
|  |             socket: accepted_any_socket, | ||||||
|  |         }; | ||||||
|  |         Ok(accepted_socket_file) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn recvfrom(&self, buf: &mut [u8], flags: RecvFlags) -> Result<(usize, Option<AnyAddr>)> { | ||||||
|  |         let (bytes_recv, addr_recv, _, _) = self.recvmsg(&mut [buf], flags, None)?; | ||||||
|  |         Ok((bytes_recv, addr_recv)) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn recvmsg( | ||||||
|  |         &self, | ||||||
|  |         bufs: &mut [&mut [u8]], | ||||||
|  |         flags: RecvFlags, | ||||||
|  |         control: Option<&mut [u8]>, | ||||||
|  |     ) -> Result<(usize, Option<AnyAddr>, MsgFlags, usize)> { | ||||||
|  |         // TODO: support msg_flags and msg_control
 | ||||||
|  |         Ok(match &self.socket { | ||||||
|  |             AnySocket::Ipv4Stream(ipv4_stream) => { | ||||||
|  |                 let (bytes_recv, addr_recv, msg_flags) = ipv4_stream.recvmsg(bufs, flags)?; | ||||||
|  |                 ( | ||||||
|  |                     bytes_recv, | ||||||
|  |                     addr_recv.map(|addr| AnyAddr::Ipv4(addr)), | ||||||
|  |                     msg_flags, | ||||||
|  |                     0, | ||||||
|  |                 ) | ||||||
|  |             } | ||||||
|  |             AnySocket::Ipv6Stream(ipv6_stream) => { | ||||||
|  |                 let (bytes_recv, addr_recv, msg_flags) = ipv6_stream.recvmsg(bufs, flags)?; | ||||||
|  |                 ( | ||||||
|  |                     bytes_recv, | ||||||
|  |                     addr_recv.map(|addr| AnyAddr::Ipv6(addr)), | ||||||
|  |                     msg_flags, | ||||||
|  |                     0, | ||||||
|  |                 ) | ||||||
|  |             } | ||||||
|  |             AnySocket::Ipv4Datagram(ipv4_datagram) => { | ||||||
|  |                 let (bytes_recv, addr_recv, msg_flags, msg_controllen) = | ||||||
|  |                     ipv4_datagram.recvmsg(bufs, flags, control)?; | ||||||
|  |                 ( | ||||||
|  |                     bytes_recv, | ||||||
|  |                     addr_recv.map(|addr| AnyAddr::Ipv4(addr)), | ||||||
|  |                     msg_flags, | ||||||
|  |                     msg_controllen, | ||||||
|  |                 ) | ||||||
|  |             } | ||||||
|  |             AnySocket::Ipv6Datagram(ipv6_datagram) => { | ||||||
|  |                 let (bytes_recv, addr_recv, msg_flags, msg_controllen) = | ||||||
|  |                     ipv6_datagram.recvmsg(bufs, flags, control)?; | ||||||
|  |                 ( | ||||||
|  |                     bytes_recv, | ||||||
|  |                     addr_recv.map(|addr| AnyAddr::Ipv6(addr)), | ||||||
|  |                     msg_flags, | ||||||
|  |                     msg_controllen, | ||||||
|  |                 ) | ||||||
|  |             } | ||||||
|  |             _ => { | ||||||
|  |                 return_errno!(EINVAL, "recvfrom is not supported"); | ||||||
|  |             } | ||||||
|  |         }) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn sendto(&self, buf: &[u8], addr: Option<AnyAddr>, flags: SendFlags) -> Result<usize> { | ||||||
|  |         self.sendmsg(&[buf], addr, flags, None) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn sendmsg( | ||||||
|  |         &self, | ||||||
|  |         bufs: &[&[u8]], | ||||||
|  |         addr: Option<AnyAddr>, | ||||||
|  |         flags: SendFlags, | ||||||
|  |         control: Option<&[u8]>, | ||||||
|  |     ) -> Result<usize> { | ||||||
|  |         let res = match &self.socket { | ||||||
|  |             AnySocket::Ipv4Stream(ipv4_stream) => ipv4_stream.sendmsg(bufs, flags), | ||||||
|  |             AnySocket::Ipv6Stream(ipv6_stream) => ipv6_stream.sendmsg(bufs, flags), | ||||||
|  |             AnySocket::Ipv4Datagram(ipv4_datagram) => { | ||||||
|  |                 let ip_addr = if let Some(addr) = addr.as_ref() { | ||||||
|  |                     Some(addr.to_ipv4()?) | ||||||
|  |                 } else { | ||||||
|  |                     None | ||||||
|  |                 }; | ||||||
|  |                 ipv4_datagram.sendmsg(bufs, ip_addr, flags, control) | ||||||
|  |             } | ||||||
|  |             AnySocket::Ipv6Datagram(ipv6_datagram) => { | ||||||
|  |                 let ip_addr = if let Some(addr) = addr.as_ref() { | ||||||
|  |                     Some(addr.to_ipv6()?) | ||||||
|  |                 } else { | ||||||
|  |                     None | ||||||
|  |                 }; | ||||||
|  |                 ipv6_datagram.sendmsg(bufs, ip_addr, flags, control) | ||||||
|  |             } | ||||||
|  |             _ => { | ||||||
|  |                 return_errno!(EINVAL, "sendmsg is not supported"); | ||||||
|  |             } | ||||||
|  |         }; | ||||||
|  |         if res.has_errno(EPIPE) && !flags.contains(SendFlags::MSG_NOSIGNAL) { | ||||||
|  |             crate::signal::do_tkill(current!().tid(), crate::signal::SIGPIPE.as_u8() as i32); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         res | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn addr(&self) -> Result<AnyAddr> { | ||||||
|  |         Ok(match &self.socket { | ||||||
|  |             AnySocket::Ipv4Stream(ipv4_stream) => AnyAddr::Ipv4(ipv4_stream.addr()?), | ||||||
|  |             AnySocket::Ipv6Stream(ipv6_stream) => AnyAddr::Ipv6(ipv6_stream.addr()?), | ||||||
|  |             AnySocket::Ipv4Datagram(ipv4_datagram) => AnyAddr::Ipv4(ipv4_datagram.addr()?), | ||||||
|  |             AnySocket::Ipv6Datagram(ipv6_datagram) => AnyAddr::Ipv6(ipv6_datagram.addr()?), | ||||||
|  |             _ => { | ||||||
|  |                 return_errno!(EINVAL, "addr is not supported"); | ||||||
|  |             } | ||||||
|  |         }) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn peer_addr(&self) -> Result<AnyAddr> { | ||||||
|  |         Ok(match &self.socket { | ||||||
|  |             AnySocket::Ipv4Stream(ipv4_stream) => AnyAddr::Ipv4(ipv4_stream.peer_addr()?), | ||||||
|  |             AnySocket::Ipv6Stream(ipv6_stream) => AnyAddr::Ipv6(ipv6_stream.peer_addr()?), | ||||||
|  |             AnySocket::Ipv4Datagram(ipv4_datagram) => AnyAddr::Ipv4(ipv4_datagram.peer_addr()?), | ||||||
|  |             AnySocket::Ipv6Datagram(ipv6_datagram) => AnyAddr::Ipv6(ipv6_datagram.peer_addr()?), | ||||||
|  |             _ => { | ||||||
|  |                 return_errno!(EINVAL, "peer_addr is not supported"); | ||||||
|  |             } | ||||||
|  |         }) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn shutdown(&self, how: Shutdown) -> Result<()> { | ||||||
|  |         match &self.socket { | ||||||
|  |             AnySocket::Ipv4Stream(ipv4_stream) => ipv4_stream.shutdown(how), | ||||||
|  |             AnySocket::Ipv6Stream(ipv6_stream) => ipv6_stream.shutdown(how), | ||||||
|  |             AnySocket::Ipv4Datagram(ipv4_datagram) => ipv4_datagram.shutdown(how), | ||||||
|  |             AnySocket::Ipv6Datagram(ipv6_datagram) => ipv6_datagram.shutdown(how), | ||||||
|  |             _ => { | ||||||
|  |                 return_errno!(EINVAL, "shutdown is not supported"); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn close(&self) -> Result<()> { | ||||||
|  |         match &self.socket { | ||||||
|  |             AnySocket::Ipv4Stream(ipv4_stream) => ipv4_stream.close(), | ||||||
|  |             AnySocket::Ipv6Stream(ipv6_stream) => ipv6_stream.close(), | ||||||
|  |             AnySocket::Ipv4Datagram(ipv4_datagram) => ipv4_datagram.close(), | ||||||
|  |             AnySocket::Ipv6Datagram(ipv6_datagram) => ipv6_datagram.close(), | ||||||
|  |             _ => Ok(()), | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl Drop for SocketFile { | ||||||
|  |     fn drop(&mut self) { | ||||||
|  |         self.close(); | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | mod impls { | ||||||
|  |     use super::*; | ||||||
|  |     use io_uring_callback::IoUring; | ||||||
|  | 
 | ||||||
|  |     pub type Ipv4Stream = | ||||||
|  |         crate::net::socket::uring::stream::StreamSocket<Ipv4SocketAddr, SocketRuntime>; | ||||||
|  |     pub type Ipv6Stream = | ||||||
|  |         crate::net::socket::uring::stream::StreamSocket<Ipv6SocketAddr, SocketRuntime>; | ||||||
|  | 
 | ||||||
|  |     pub type Ipv4Datagram = | ||||||
|  |         crate::net::socket::uring::datagram::DatagramSocket<Ipv4SocketAddr, SocketRuntime>; | ||||||
|  |     pub type Ipv6Datagram = | ||||||
|  |         crate::net::socket::uring::datagram::DatagramSocket<Ipv6SocketAddr, SocketRuntime>; | ||||||
|  | 
 | ||||||
|  |     pub struct SocketRuntime; | ||||||
|  |     impl crate::net::socket::uring::runtime::Runtime for SocketRuntime { | ||||||
|  |         // Assign an IO-Uring instance for newly created socket
 | ||||||
|  |         fn io_uring() -> Arc<IoUring> { | ||||||
|  |             crate::io_uring::MULTITON.get_uring() | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         // Disattach IO-Uring instance with closed socket
 | ||||||
|  |         fn disattach_io_uring(fd: usize, uring: Arc<IoUring>) { | ||||||
|  |             crate::io_uring::MULTITON.disattach_uring(fd, uring); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
							
								
								
									
										610
									
								
								src/libos/src/net/socket/uring/stream/mod.rs
									
									
									
									
									
										Normal file
									
								
							
							
								
								
								
								
								
									
									
								
							
						
						
									
										610
									
								
								src/libos/src/net/socket/uring/stream/mod.rs
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,610 @@ | |||||||
|  | mod states; | ||||||
|  | 
 | ||||||
|  | use core::hint; | ||||||
|  | use core::sync::atomic::AtomicUsize; | ||||||
|  | use core::time::Duration; | ||||||
|  | 
 | ||||||
|  | use atomic::Ordering; | ||||||
|  | 
 | ||||||
|  | use self::states::{ConnectedStream, ConnectingStream, InitStream, ListenerStream}; | ||||||
|  | use crate::events::Observer; | ||||||
|  | use crate::fs::{ | ||||||
|  |     GetIfConf, GetIfReqWithRawCmd, GetReadBufLen, IoEvents, IoNotifier, IoctlCmd, SetNonBlocking, | ||||||
|  |     StatusFlags, | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | use crate::net::socket::uring::common::Common; | ||||||
|  | use crate::net::socket::uring::runtime::Runtime; | ||||||
|  | use crate::prelude::*; | ||||||
|  | 
 | ||||||
|  | use crate::events::Poller; | ||||||
|  | use crate::net::socket::{sockopt::*, MsgFlags}; | ||||||
|  | 
 | ||||||
|  | lazy_static! { | ||||||
|  |     pub static ref SEND_BUF_SIZE: AtomicUsize = AtomicUsize::new(2565 * 1024 + 1); // Default Linux send buffer size is 2.5MB.
 | ||||||
|  |     pub static ref RECV_BUF_SIZE: AtomicUsize = AtomicUsize::new(256 * 1024 + 1); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | pub struct StreamSocket<A: Addr + 'static, R: Runtime> { | ||||||
|  |     state: RwLock<State<A, R>>, | ||||||
|  |     common: Arc<Common<A, R>>, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | enum State<A: Addr + 'static, R: Runtime> { | ||||||
|  |     // Start state
 | ||||||
|  |     Init(Arc<InitStream<A, R>>), | ||||||
|  |     // Intermediate state
 | ||||||
|  |     Connect(Arc<ConnectingStream<A, R>>), | ||||||
|  |     // Final state 1
 | ||||||
|  |     Connected(Arc<ConnectedStream<A, R>>), | ||||||
|  |     // Final state 2
 | ||||||
|  |     Listen(Arc<ListenerStream<A, R>>), | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl<A: Addr, R: Runtime> StreamSocket<A, R> { | ||||||
|  |     pub fn new(nonblocking: bool) -> Result<Self> { | ||||||
|  |         let init_stream = InitStream::new(nonblocking)?; | ||||||
|  |         let common = init_stream.common().clone(); | ||||||
|  | 
 | ||||||
|  |         let fd = common.host_fd(); | ||||||
|  |         debug!("host fd: {}", fd); | ||||||
|  | 
 | ||||||
|  |         let init_state = State::Init(init_stream); | ||||||
|  |         Ok(Self { | ||||||
|  |             state: RwLock::new(init_state), | ||||||
|  |             common, | ||||||
|  |         }) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn new_pair(nonblocking: bool) -> Result<(Self, Self)> { | ||||||
|  |         let (common1, common2) = Common::new_pair(Type::STREAM, nonblocking)?; | ||||||
|  |         let connected1 = ConnectedStream::new(Arc::new(common1)); | ||||||
|  |         let connected2 = ConnectedStream::new(Arc::new(common2)); | ||||||
|  |         let socket1 = Self::new_connected(connected1); | ||||||
|  |         let socket2 = Self::new_connected(connected2); | ||||||
|  |         Ok((socket1, socket2)) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn new_connected(connected_stream: Arc<ConnectedStream<A, R>>) -> Self { | ||||||
|  |         let common = connected_stream.common().clone(); | ||||||
|  |         let state = RwLock::new(State::Connected(connected_stream)); | ||||||
|  |         Self { state, common } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn try_switch_to_connected_state( | ||||||
|  |         connecting_stream: &Arc<ConnectingStream<A, R>>, | ||||||
|  |     ) -> Option<Arc<ConnectedStream<A, R>>> { | ||||||
|  |         // Previously, I thought connecting state only exists for non-blocking socket. However, some applications can set non-blocking for
 | ||||||
|  |         // connect syscall and after the connect returns, set the socket to blocking socket. Thus, this function shouldn't assert the connecting
 | ||||||
|  |         // stream is non-blocking socket.
 | ||||||
|  |         if connecting_stream.check_connection() { | ||||||
|  |             let common = connecting_stream.common().clone(); | ||||||
|  |             common.set_peer_addr(connecting_stream.peer_addr()); | ||||||
|  |             Some(ConnectedStream::new(common)) | ||||||
|  |         } else { | ||||||
|  |             None | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn domain(&self) -> Domain { | ||||||
|  |         A::domain() | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn errno(&self) -> Option<Errno> { | ||||||
|  |         self.common.errno() | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn host_fd(&self) -> FileDesc { | ||||||
|  |         let state = self.state.read().unwrap(); | ||||||
|  |         state.common().host_fd() | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn status_flags(&self) -> StatusFlags { | ||||||
|  |         // Only support O_NONBLOCK
 | ||||||
|  |         let state = self.state.read().unwrap(); | ||||||
|  |         if state.common().nonblocking() { | ||||||
|  |             StatusFlags::O_NONBLOCK | ||||||
|  |         } else { | ||||||
|  |             StatusFlags::empty() | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn set_status_flags(&self, new_flags: StatusFlags) -> Result<()> { | ||||||
|  |         // Only support O_NONBLOCK
 | ||||||
|  |         let state = self.state.read().unwrap(); | ||||||
|  |         let nonblocking = new_flags.is_nonblocking(); | ||||||
|  |         state.common().set_nonblocking(nonblocking); | ||||||
|  |         Ok(()) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn bind(&self, addr: &A) -> Result<()> { | ||||||
|  |         let state = self.state.read().unwrap(); | ||||||
|  |         match &*state { | ||||||
|  |             State::Init(init_stream) => init_stream.bind(addr), | ||||||
|  |             _ => { | ||||||
|  |                 return_errno!(EINVAL, "cannot bind"); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn listen(&self, backlog: u32) -> Result<()> { | ||||||
|  |         let mut state = self.state.write().unwrap(); | ||||||
|  |         match &*state { | ||||||
|  |             State::Init(init_stream) => { | ||||||
|  |                 let common = init_stream.common().clone(); | ||||||
|  |                 let listener = ListenerStream::new(backlog, common)?; | ||||||
|  |                 *state = State::Listen(listener); | ||||||
|  |                 Ok(()) | ||||||
|  |             } | ||||||
|  |             _ => { | ||||||
|  |                 return_errno!(EINVAL, "cannot listen"); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn connect(&self, peer_addr: &A) -> Result<()> { | ||||||
|  |         // Create the new intermediate state of connecting and save the
 | ||||||
|  |         // old state of init in case of failure to connect.
 | ||||||
|  |         let (init_stream, connecting_stream) = { | ||||||
|  |             let mut state = self.state.write().unwrap(); | ||||||
|  |             match &*state { | ||||||
|  |                 State::Init(init_stream) => { | ||||||
|  |                     let connecting_stream = { | ||||||
|  |                         let common = init_stream.common().clone(); | ||||||
|  |                         ConnectingStream::new(peer_addr, common)? | ||||||
|  |                     }; | ||||||
|  |                     let init_stream = init_stream.clone(); | ||||||
|  |                     *state = State::Connect(connecting_stream.clone()); | ||||||
|  |                     (init_stream, connecting_stream) | ||||||
|  |                 } | ||||||
|  |                 State::Connect(connecting_stream) => { | ||||||
|  |                     if let Some(connected_stream) = | ||||||
|  |                         Self::try_switch_to_connected_state(connecting_stream) | ||||||
|  |                     { | ||||||
|  |                         *state = State::Connected(connected_stream); | ||||||
|  |                         return_errno!(EISCONN, "the socket is already connected"); | ||||||
|  |                     } else { | ||||||
|  |                         // Not connected, keep the connecting state and try connect
 | ||||||
|  |                         let init_stream = | ||||||
|  |                             InitStream::new_with_common(connecting_stream.common().clone())?; | ||||||
|  |                         (init_stream, connecting_stream.clone()) | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|  |                 State::Connected(_) => { | ||||||
|  |                     return_errno!(EISCONN, "the socket is already connected"); | ||||||
|  |                 } | ||||||
|  |                 State::Listen(_) => { | ||||||
|  |                     return_errno!(EINVAL, "the socket is listening"); | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         }; | ||||||
|  | 
 | ||||||
|  |         let res = connecting_stream.connect(); | ||||||
|  | 
 | ||||||
|  |         // If success, then the state is switched to connected; otherwise, for blocking socket
 | ||||||
|  |         // the state is restored to the init state, and for non-blocking socket, the state
 | ||||||
|  |         // keeps in connecting state.
 | ||||||
|  |         match &res { | ||||||
|  |             Ok(()) => { | ||||||
|  |                 let connected_stream = { | ||||||
|  |                     let common = init_stream.common().clone(); | ||||||
|  |                     common.set_peer_addr(peer_addr); | ||||||
|  |                     ConnectedStream::new(common) | ||||||
|  |                 }; | ||||||
|  | 
 | ||||||
|  |                 let mut state = self.state.write().unwrap(); | ||||||
|  |                 *state = State::Connected(connected_stream); | ||||||
|  |             } | ||||||
|  |             Err(_) => { | ||||||
|  |                 if !connecting_stream.common().nonblocking() { | ||||||
|  |                     let mut state = self.state.write().unwrap(); | ||||||
|  |                     *state = State::Init(init_stream); | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |         res | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn accept(&self, nonblocking: bool) -> Result<Self> { | ||||||
|  |         let listener_stream = { | ||||||
|  |             let state = self.state.read().unwrap(); | ||||||
|  |             match &*state { | ||||||
|  |                 State::Listen(listener_stream) => listener_stream.clone(), | ||||||
|  |                 _ => { | ||||||
|  |                     return_errno!(EINVAL, "the socket is not listening"); | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         }; | ||||||
|  | 
 | ||||||
|  |         let connected_stream = listener_stream.accept(nonblocking)?; | ||||||
|  | 
 | ||||||
|  |         let new_self = Self::new_connected(connected_stream); | ||||||
|  |         Ok(new_self) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn read(&self, buf: &mut [u8]) -> Result<usize> { | ||||||
|  |         self.readv(&mut [buf]) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn readv(&self, bufs: &mut [&mut [u8]]) -> Result<usize> { | ||||||
|  |         let ret = self.recvmsg(bufs, RecvFlags::empty())?; | ||||||
|  |         Ok(ret.0) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     /// Receive messages from connected socket
 | ||||||
|  |     ///
 | ||||||
|  |     /// Linux behavior:
 | ||||||
|  |     /// Unlike datagram socket, `recvfrom` / `recvmsg` of stream socket will
 | ||||||
|  |     /// ignore the address even if user specified it.
 | ||||||
|  |     pub fn recvmsg( | ||||||
|  |         &self, | ||||||
|  |         buf: &mut [&mut [u8]], | ||||||
|  |         flags: RecvFlags, | ||||||
|  |     ) -> Result<(usize, Option<A>, MsgFlags)> { | ||||||
|  |         let connected_stream = { | ||||||
|  |             let mut state = self.state.write().unwrap(); | ||||||
|  |             match &*state { | ||||||
|  |                 State::Connected(connected_stream) => connected_stream.clone(), | ||||||
|  |                 State::Connect(connecting_stream) => { | ||||||
|  |                     if let Some(connected_stream) = | ||||||
|  |                         Self::try_switch_to_connected_state(connecting_stream) | ||||||
|  |                     { | ||||||
|  |                         *state = State::Connected(connected_stream.clone()); | ||||||
|  |                         connected_stream | ||||||
|  |                     } else { | ||||||
|  |                         return_errno!(ENOTCONN, "the socket is not connected"); | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|  |                 _ => { | ||||||
|  |                     return_errno!(ENOTCONN, "the socket is not connected"); | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         }; | ||||||
|  | 
 | ||||||
|  |         let recv_len = connected_stream.recvmsg(buf, flags)?; | ||||||
|  |         Ok((recv_len, None, MsgFlags::empty())) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn write(&self, buf: &[u8]) -> Result<usize> { | ||||||
|  |         self.writev(&[buf]) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn writev(&self, bufs: &[&[u8]]) -> Result<usize> { | ||||||
|  |         self.sendmsg(bufs, SendFlags::empty()) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn sendmsg(&self, bufs: &[&[u8]], flags: SendFlags) -> Result<usize> { | ||||||
|  |         let connected_stream = { | ||||||
|  |             let mut state = self.state.write().unwrap(); | ||||||
|  |             match &*state { | ||||||
|  |                 State::Connected(connected_stream) => connected_stream.clone(), | ||||||
|  |                 State::Connect(connecting_stream) => { | ||||||
|  |                     if let Some(connected_stream) = | ||||||
|  |                         Self::try_switch_to_connected_state(connecting_stream) | ||||||
|  |                     { | ||||||
|  |                         *state = State::Connected(connected_stream.clone()); | ||||||
|  |                         connected_stream | ||||||
|  |                     } else { | ||||||
|  |                         return_errno!(ENOTCONN, "the socket is not connected"); | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|  |                 _ => { | ||||||
|  |                     return_errno!(EPIPE, "the socket is not connected"); | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         }; | ||||||
|  | 
 | ||||||
|  |         connected_stream.sendmsg(bufs, flags) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { | ||||||
|  |         let state = self.state.read().unwrap(); | ||||||
|  |         let pollee = state.common().pollee(); | ||||||
|  |         pollee.poll(mask, poller) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn addr(&self) -> Result<A> { | ||||||
|  |         let state = self.state.read().unwrap(); | ||||||
|  |         let common = state.common(); | ||||||
|  | 
 | ||||||
|  |         // Always get addr from host.
 | ||||||
|  |         // Because for IP socket, users can specify "0" as port and the kernel should select a usable port for him.
 | ||||||
|  |         // Thus, when calling getsockname, this should be updated.
 | ||||||
|  |         let addr = common.get_addr_from_host()?; | ||||||
|  |         common.set_addr(&addr); | ||||||
|  |         Ok(addr) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn notifier(&self) -> &IoNotifier { | ||||||
|  |         { | ||||||
|  |             let mut state = self.state.write().unwrap(); | ||||||
|  |             // Try switch to connected state to receive endpoint status
 | ||||||
|  |             if let State::Connect(connecting_stream) = &*state { | ||||||
|  |                 if let Some(connected_stream) = | ||||||
|  |                     Self::try_switch_to_connected_state(connecting_stream) | ||||||
|  |                 { | ||||||
|  |                     *state = State::Connected(connected_stream.clone()); | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |             // `state` goes out of scope here and the lock is implicitly released.
 | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         self.common.notifier() | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn peer_addr(&self) -> Result<A> { | ||||||
|  |         let mut state = self.state.write().unwrap(); | ||||||
|  |         match &*state { | ||||||
|  |             State::Connected(connected_stream) => { | ||||||
|  |                 Ok(connected_stream.common().peer_addr().unwrap()) | ||||||
|  |             } | ||||||
|  |             State::Connect(connecting_stream) => { | ||||||
|  |                 if let Some(connected_stream) = | ||||||
|  |                     Self::try_switch_to_connected_state(connecting_stream) | ||||||
|  |                 { | ||||||
|  |                     *state = State::Connected(connected_stream.clone()); | ||||||
|  |                     Ok(connected_stream.common().peer_addr().unwrap()) | ||||||
|  |                 } else { | ||||||
|  |                     return_errno!(ENOTCONN, "the socket is not connected"); | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |             _ => return_errno!(ENOTCONN, "the socket is not connected"), | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn ioctl(&self, cmd: &mut dyn IoctlCmd) -> Result<()> { | ||||||
|  |         let mut state = self.state.write().unwrap(); | ||||||
|  |         match &*state { | ||||||
|  |             State::Connect(connecting_stream) => { | ||||||
|  |                 if let Some(connected_stream) = | ||||||
|  |                     Self::try_switch_to_connected_state(connecting_stream) | ||||||
|  |                 { | ||||||
|  |                     *state = State::Connected(connected_stream.clone()); | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |             _ => {} | ||||||
|  |         } | ||||||
|  |         drop(state); | ||||||
|  |         crate::match_ioctl_cmd_mut!(&mut *cmd, { | ||||||
|  |             cmd: GetSockOptRawCmd => { | ||||||
|  |                 cmd.execute(self.host_fd())?; | ||||||
|  |             }, | ||||||
|  |             cmd: SetSockOptRawCmd => { | ||||||
|  |                 cmd.execute(self.host_fd())?; | ||||||
|  |             }, | ||||||
|  |             cmd: SetRecvTimeoutCmd => { | ||||||
|  |                 self.set_recv_timeout(*cmd.timeout()); | ||||||
|  |             }, | ||||||
|  |             cmd: SetSendTimeoutCmd => { | ||||||
|  |                 self.set_send_timeout(*cmd.timeout()); | ||||||
|  |             }, | ||||||
|  |             cmd: GetRecvTimeoutCmd => { | ||||||
|  |                 let timeval = timeout_to_timeval(self.recv_timeout()); | ||||||
|  |                 cmd.set_output(timeval); | ||||||
|  |             }, | ||||||
|  |             cmd: GetSendTimeoutCmd => { | ||||||
|  |                 let timeval = timeout_to_timeval(self.send_timeout()); | ||||||
|  |                 cmd.set_output(timeval); | ||||||
|  |             }, | ||||||
|  |             cmd: SetSndBufSizeCmd => { | ||||||
|  |                 cmd.update_host(self.host_fd())?; | ||||||
|  |                 let buf_size = cmd.buf_size(); | ||||||
|  |                 self.set_kernel_send_buf_size(buf_size); | ||||||
|  |             }, | ||||||
|  |             cmd: SetRcvBufSizeCmd => { | ||||||
|  |                 cmd.update_host(self.host_fd())?; | ||||||
|  |                 let buf_size = cmd.buf_size(); | ||||||
|  |                 self.set_kernel_recv_buf_size(buf_size); | ||||||
|  |             }, | ||||||
|  |             cmd: GetSndBufSizeCmd => { | ||||||
|  |                 let buf_size = SEND_BUF_SIZE.load(Ordering::Relaxed); | ||||||
|  |                 cmd.set_output(buf_size); | ||||||
|  |             }, | ||||||
|  |             cmd: GetRcvBufSizeCmd => { | ||||||
|  |                 let buf_size = RECV_BUF_SIZE.load(Ordering::Relaxed); | ||||||
|  |                 cmd.set_output(buf_size); | ||||||
|  |             }, | ||||||
|  |             cmd: GetAcceptConnCmd => { | ||||||
|  |                 let mut is_listen = false; | ||||||
|  |                 let state = self.state.read().unwrap(); | ||||||
|  |                 if let State::Listen(_listener_stream) = &*state { | ||||||
|  |                     is_listen = true; | ||||||
|  |                 } | ||||||
|  |                 cmd.set_output(is_listen as _); | ||||||
|  |             }, | ||||||
|  |             cmd: GetDomainCmd => { | ||||||
|  |                 cmd.set_output(self.domain() as _); | ||||||
|  |             }, | ||||||
|  |             cmd: GetPeerNameCmd => { | ||||||
|  |                 let peer = self.peer_addr()?; | ||||||
|  |                 cmd.set_output(AddrStorage(peer.to_c_storage())); | ||||||
|  |             }, | ||||||
|  |             cmd: GetErrorCmd => { | ||||||
|  |                 let error: i32 = self.errno().map(|err| err as i32).unwrap_or(0); | ||||||
|  |                 cmd.set_output(error); | ||||||
|  |             }, | ||||||
|  |             cmd: GetTypeCmd => { | ||||||
|  |                 let state = self.state.read().unwrap(); | ||||||
|  |                 cmd.set_output(state.common().type_() as _); | ||||||
|  |             }, | ||||||
|  |             cmd: SetNonBlocking => { | ||||||
|  |                 let state = self.state.read().unwrap(); | ||||||
|  |                 state.common().set_nonblocking(*cmd.input() != 0); | ||||||
|  |             }, | ||||||
|  |             cmd: GetReadBufLen => { | ||||||
|  |                 let state = self.state.read().unwrap(); | ||||||
|  |                 if let State::Connected(connected_stream) = &*state { | ||||||
|  |                     let read_buf_len = connected_stream.bytes_to_consume(); | ||||||
|  |                     cmd.set_output(read_buf_len as _); | ||||||
|  |                 } else { | ||||||
|  |                     return_errno!(ENOTCONN, "unconnected socket"); | ||||||
|  |                 } | ||||||
|  |             }, | ||||||
|  |             cmd: GetIfReqWithRawCmd => { | ||||||
|  |                 cmd.execute(self.host_fd())?; | ||||||
|  |             }, | ||||||
|  |             cmd: GetIfConf => { | ||||||
|  |                 cmd.execute(self.host_fd())?; | ||||||
|  |             }, | ||||||
|  |             _ => { | ||||||
|  |                 return_errno!(EINVAL, "Not supported yet"); | ||||||
|  |             } | ||||||
|  |         }); | ||||||
|  |         Ok(()) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn set_kernel_send_buf_size(&self, buf_size: usize) { | ||||||
|  |         let state = self.state.read().unwrap(); | ||||||
|  |         match &*state { | ||||||
|  |             State::Init(_) | State::Listen(_) | State::Connect(_) => { | ||||||
|  |                 // The kernel buffer is only created when the socket is connected. Just update the static variable.
 | ||||||
|  |                 SEND_BUF_SIZE.store(buf_size, Ordering::Relaxed); | ||||||
|  |             } | ||||||
|  |             State::Connected(connected_stream) => { | ||||||
|  |                 connected_stream.try_update_send_buf_size(buf_size); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn set_kernel_recv_buf_size(&self, buf_size: usize) { | ||||||
|  |         let state = self.state.read().unwrap(); | ||||||
|  |         match &*state { | ||||||
|  |             State::Init(_) | State::Listen(_) | State::Connect(_) => { | ||||||
|  |                 // The kernel buffer is only created when the socket is connected. Just update the static variable.
 | ||||||
|  |                 RECV_BUF_SIZE.store(buf_size, Ordering::Relaxed); | ||||||
|  |             } | ||||||
|  |             State::Connected(connected_stream) => { | ||||||
|  |                 connected_stream.try_update_recv_buf_size(buf_size); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn shutdown(&self, shutdown: Shutdown) -> Result<()> { | ||||||
|  |         let mut state = self.state.write().unwrap(); | ||||||
|  |         match &*state { | ||||||
|  |             State::Listen(listener_stream) => { | ||||||
|  |                 // listening socket can be shutdown and then re-use by calling listen again.
 | ||||||
|  |                 listener_stream.shutdown(shutdown)?; | ||||||
|  |                 if shutdown.should_shut_read() { | ||||||
|  |                     // Cancel pending accept requests. This is necessary because the socket is reusable.
 | ||||||
|  |                     listener_stream.cancel_accept_requests(); | ||||||
|  |                     // Set init state
 | ||||||
|  |                     let init_stream = | ||||||
|  |                         InitStream::new_with_common(listener_stream.common().clone())?; | ||||||
|  |                     let init_state = State::Init(init_stream); | ||||||
|  |                     *state = init_state; | ||||||
|  |                     Ok(()) | ||||||
|  |                 } else { | ||||||
|  |                     // shutdown the writer of the listener expect to have no effect
 | ||||||
|  |                     Ok(()) | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |             State::Connected(connected_stream) => connected_stream.shutdown(shutdown), | ||||||
|  |             State::Connect(connecting_stream) => { | ||||||
|  |                 if let Some(connected_stream) = | ||||||
|  |                     Self::try_switch_to_connected_state(connecting_stream) | ||||||
|  |                 { | ||||||
|  |                     connected_stream.shutdown(shutdown)?; | ||||||
|  |                     *state = State::Connected(connected_stream); | ||||||
|  |                     Ok(()) | ||||||
|  |                 } else { | ||||||
|  |                     return_errno!(ENOTCONN, "the socket is not connected"); | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |             _ => { | ||||||
|  |                 return_errno!(ENOTCONN, "the socket is not connected"); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn close(&self) -> Result<()> { | ||||||
|  |         let state = self.state.read().unwrap(); | ||||||
|  |         match &*state { | ||||||
|  |             State::Init(_) => {} | ||||||
|  |             State::Listen(listener_stream) => { | ||||||
|  |                 listener_stream.common().set_closed(); | ||||||
|  |                 listener_stream.cancel_accept_requests(); | ||||||
|  |             } | ||||||
|  |             State::Connect(connecting_stream) => { | ||||||
|  |                 connecting_stream.common().set_closed(); | ||||||
|  |                 let need_wait = true; | ||||||
|  |                 connecting_stream.cancel_connect_request(need_wait); | ||||||
|  |             } | ||||||
|  |             State::Connected(connected_stream) => { | ||||||
|  |                 connected_stream.set_closed(); | ||||||
|  |                 connected_stream.cancel_recv_requests(); | ||||||
|  |                 connected_stream.try_empty_send_buf_when_close(); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |         Ok(()) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn send_timeout(&self) -> Option<Duration> { | ||||||
|  |         let state = self.state.read().unwrap(); | ||||||
|  |         state.common().send_timeout() | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn recv_timeout(&self) -> Option<Duration> { | ||||||
|  |         let state = self.state.read().unwrap(); | ||||||
|  |         state.common().recv_timeout() | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn set_send_timeout(&self, timeout: Duration) { | ||||||
|  |         let state = self.state.read().unwrap(); | ||||||
|  |         state.common().set_send_timeout(timeout); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn set_recv_timeout(&self, timeout: Duration) { | ||||||
|  |         let state = self.state.read().unwrap(); | ||||||
|  |         state.common().set_recv_timeout(timeout); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     /* | ||||||
|  |         pub fn poll_by(&self, mask: Events, mut poller: Option<&mut Poller>) -> Events { | ||||||
|  |             let state = self.state.read(); | ||||||
|  |             match *state { | ||||||
|  |                 Init(init_stream) => init_stream.poll_by(mask, poller), | ||||||
|  |                 Connect(connect_stream) => connect_stream.poll_by(mask, poller), | ||||||
|  |                 Connected(connected_stream) = connected_stream.poll_by(mask, poller), | ||||||
|  |                 Listen(listener_stream) = listener_stream.poll_by(mask, poller), | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     */ | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl<A: Addr + 'static, R: Runtime> Drop for StreamSocket<A, R> { | ||||||
|  |     fn drop(&mut self) { | ||||||
|  |         let state = self.state.read().unwrap(); | ||||||
|  |         state.common().set_closed(); | ||||||
|  |         drop(state); | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl<A: Addr + 'static, R: Runtime> std::fmt::Debug for State<A, R> { | ||||||
|  |     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||||||
|  |         let inner: &dyn std::fmt::Debug = match self { | ||||||
|  |             State::Init(inner) => inner as _, | ||||||
|  |             State::Connect(inner) => inner as _, | ||||||
|  |             State::Connected(inner) => inner as _, | ||||||
|  |             State::Listen(inner) => inner as _, | ||||||
|  |         }; | ||||||
|  |         f.debug_tuple("State").field(inner).finish() | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl<A: Addr + 'static, R: Runtime> std::fmt::Debug for StreamSocket<A, R> { | ||||||
|  |     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||||||
|  |         f.debug_struct("StreamSocket").finish() | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl<A: Addr + 'static, R: Runtime> State<A, R> { | ||||||
|  |     fn common(&self) -> &Common<A, R> { | ||||||
|  |         match self { | ||||||
|  |             Self::Init(stream) => stream.common(), | ||||||
|  |             Self::Connect(stream) => stream.common(), | ||||||
|  |             Self::Connected(stream) => stream.common(), | ||||||
|  |             Self::Listen(stream) => stream.common(), | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
							
								
								
									
										227
									
								
								src/libos/src/net/socket/uring/stream/states/connect.rs
									
									
									
									
									
										Normal file
									
								
							
							
								
								
								
								
								
									
									
								
							
						
						
									
										227
									
								
								src/libos/src/net/socket/uring/stream/states/connect.rs
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,227 @@ | |||||||
|  | use core::time::Duration; | ||||||
|  | use std::marker::PhantomData; | ||||||
|  | use std::sync::atomic::{AtomicBool, Ordering}; | ||||||
|  | 
 | ||||||
|  | use io_uring_callback::{Fd, IoHandle}; | ||||||
|  | use sgx_untrusted_alloc::UntrustedBox; | ||||||
|  | 
 | ||||||
|  | use crate::events::Poller; | ||||||
|  | use crate::fs::IoEvents; | ||||||
|  | use crate::net::socket::uring::common::Common; | ||||||
|  | use crate::net::socket::uring::runtime::Runtime; | ||||||
|  | use crate::prelude::*; | ||||||
|  | 
 | ||||||
|  | /// A stream socket that is in its connecting state.
 | ||||||
|  | pub struct ConnectingStream<A: Addr + 'static, R: Runtime> { | ||||||
|  |     common: Arc<Common<A, R>>, | ||||||
|  |     peer_addr: A, | ||||||
|  |     req: Mutex<ConnectReq<A>>, | ||||||
|  |     connected: AtomicBool, // Mainly use for nonblocking socket to update status asynchronously
 | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | struct ConnectReq<A: Addr> { | ||||||
|  |     io_handle: Option<IoHandle>, | ||||||
|  |     c_addr: UntrustedBox<libc::sockaddr_storage>, | ||||||
|  |     c_addr_len: usize, | ||||||
|  |     errno: Option<Errno>, | ||||||
|  |     phantom_data: PhantomData<A>, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl<A: Addr + 'static, R: Runtime> ConnectingStream<A, R> { | ||||||
|  |     pub fn new(peer_addr: &A, common: Arc<Common<A, R>>) -> Result<Arc<Self>> { | ||||||
|  |         let req = Mutex::new(ConnectReq::new(peer_addr)); | ||||||
|  |         let new_self = Self { | ||||||
|  |             common, | ||||||
|  |             peer_addr: peer_addr.clone(), | ||||||
|  |             req, | ||||||
|  |             connected: AtomicBool::new(false), | ||||||
|  |         }; | ||||||
|  |         Ok(Arc::new(new_self)) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     /// Connect to the peer address.
 | ||||||
|  |     pub fn connect(self: &Arc<Self>) -> Result<()> { | ||||||
|  |         let pollee = self.common.pollee(); | ||||||
|  |         pollee.reset_events(); | ||||||
|  | 
 | ||||||
|  |         self.initiate_async_connect(); | ||||||
|  | 
 | ||||||
|  |         if self.common.nonblocking() { | ||||||
|  |             return_errno!(EINPROGRESS, "non-blocking connect request in progress"); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         // Wait for the async connect to complete
 | ||||||
|  |         let mask = IoEvents::OUT; | ||||||
|  |         let poller = Poller::new(); | ||||||
|  |         pollee.connect_poller(mask, &poller); | ||||||
|  |         let mut timeout = self.common.send_timeout(); | ||||||
|  |         loop { | ||||||
|  |             let events = pollee.poll(mask, None); | ||||||
|  |             if !events.is_empty() { | ||||||
|  |                 break; | ||||||
|  |             } | ||||||
|  |             let ret = poller.wait_timeout(timeout.as_mut()); | ||||||
|  |             if let Err(e) = ret { | ||||||
|  |                 let errno = e.errno(); | ||||||
|  |                 warn!("connect wait errno = {:?}", errno); | ||||||
|  |                 match errno { | ||||||
|  |                     ETIMEDOUT => { | ||||||
|  |                         // Cancel connect request if timeout. No need to wait for cancel to complete.
 | ||||||
|  |                         self.cancel_connect_request(false); | ||||||
|  |                         // This error code is same as the connect timeout error code on Linux
 | ||||||
|  |                         return_errno!(EINPROGRESS, "timeout reached") | ||||||
|  |                     } | ||||||
|  |                     _ => { | ||||||
|  |                         return_errno!(e.errno(), "wait error") | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         // Finish the async connect
 | ||||||
|  |         let req = self.req.lock(); | ||||||
|  |         if let Some(e) = req.errno { | ||||||
|  |             return_errno!(e, "connect failed"); | ||||||
|  |         } | ||||||
|  |         Ok(()) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn initiate_async_connect(self: &Arc<Self>) { | ||||||
|  |         let io_uring = self.common.io_uring(); | ||||||
|  |         let mut req = self.req.lock(); | ||||||
|  |         // Skip if there is pending request
 | ||||||
|  |         if req.io_handle.is_some() { | ||||||
|  |             return; | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         let arc_self = self.clone(); | ||||||
|  |         let callback = move |retval: i32| { | ||||||
|  |             // Guard against Igao attack
 | ||||||
|  |             assert!(retval <= 0); | ||||||
|  |             debug!("connect request complete with retval: {}", retval); | ||||||
|  | 
 | ||||||
|  |             let mut req = arc_self.req.lock(); | ||||||
|  |             // Release the handle to the async connect
 | ||||||
|  |             req.io_handle.take(); | ||||||
|  | 
 | ||||||
|  |             if retval == 0 { | ||||||
|  |                 arc_self.connected.store(true, Ordering::Relaxed); | ||||||
|  |                 arc_self.common.pollee().add_events(IoEvents::OUT); | ||||||
|  |             } else { | ||||||
|  |                 // Store the errno
 | ||||||
|  |                 let errno = Errno::from(-retval as u32); | ||||||
|  |                 req.errno = Some(errno); | ||||||
|  |                 drop(req); | ||||||
|  |                 arc_self.common.set_errno(errno); | ||||||
|  |                 arc_self.connected.store(false, Ordering::Relaxed); | ||||||
|  | 
 | ||||||
|  |                 let events = if errno == ENOTCONN || errno == ECONNRESET || errno == ECONNREFUSED { | ||||||
|  |                     IoEvents::HUP | IoEvents::IN | IoEvents::ERR | ||||||
|  |                 } else { | ||||||
|  |                     IoEvents::ERR | ||||||
|  |                 }; | ||||||
|  |                 arc_self.common.pollee().add_events(events); | ||||||
|  |             } | ||||||
|  |         }; | ||||||
|  | 
 | ||||||
|  |         let host_fd = self.common.host_fd() as _; | ||||||
|  |         let c_addr_ptr = req.c_addr.as_ptr(); | ||||||
|  |         let c_addr_len = req.c_addr_len; | ||||||
|  |         let io_handle = unsafe { | ||||||
|  |             io_uring.connect( | ||||||
|  |                 Fd(host_fd), | ||||||
|  |                 c_addr_ptr as *const libc::sockaddr, | ||||||
|  |                 c_addr_len as u32, | ||||||
|  |                 callback, | ||||||
|  |             ) | ||||||
|  |         }; | ||||||
|  |         req.io_handle = Some(io_handle); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn cancel_connect_request(&self, need_wait: bool) { | ||||||
|  |         { | ||||||
|  |             let io_uring = self.common.io_uring(); | ||||||
|  |             let req = self.req.lock(); | ||||||
|  |             if let Some(io_handle) = &req.io_handle { | ||||||
|  |                 unsafe { io_uring.cancel(io_handle) }; | ||||||
|  |             } else { | ||||||
|  |                 return; | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         // Wait for the cancel to complete if needed
 | ||||||
|  |         if !need_wait { | ||||||
|  |             return; | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         let poller = Poller::new(); | ||||||
|  |         let mask = IoEvents::ERR | IoEvents::IN; | ||||||
|  |         self.common.pollee().connect_poller(mask, &poller); | ||||||
|  | 
 | ||||||
|  |         loop { | ||||||
|  |             let pending_request_exist = { | ||||||
|  |                 let req = self.req.lock(); | ||||||
|  |                 req.io_handle.is_some() | ||||||
|  |             }; | ||||||
|  | 
 | ||||||
|  |             if pending_request_exist { | ||||||
|  |                 let mut timeout = Some(Duration::from_secs(10)); | ||||||
|  |                 let ret = poller.wait_timeout(timeout.as_mut()); | ||||||
|  |                 if let Err(e) = ret { | ||||||
|  |                     warn!("wait cancel connect request error = {:?}", e.errno()); | ||||||
|  |                     continue; | ||||||
|  |                 } | ||||||
|  |             } else { | ||||||
|  |                 break; | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     #[allow(dead_code)] | ||||||
|  |     pub fn peer_addr(&self) -> &A { | ||||||
|  |         &self.peer_addr | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn common(&self) -> &Arc<Common<A, R>> { | ||||||
|  |         &self.common | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // This can be used in connecting state to check non-blocking connect status.
 | ||||||
|  |     pub fn check_connection(&self) -> bool { | ||||||
|  |         // It is fine whether the load happens before or after the store operation
 | ||||||
|  |         self.connected.load(Ordering::Relaxed) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl<A: Addr> ConnectReq<A> { | ||||||
|  |     pub fn new(peer_addr: &A) -> Self { | ||||||
|  |         let (c_addr_storage, c_addr_len) = peer_addr.to_c_storage(); | ||||||
|  |         Self { | ||||||
|  |             io_handle: None, | ||||||
|  |             c_addr: UntrustedBox::new(c_addr_storage), | ||||||
|  |             c_addr_len, | ||||||
|  |             errno: None, | ||||||
|  |             phantom_data: PhantomData, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl<A: Addr, R: Runtime> std::fmt::Debug for ConnectingStream<A, R> { | ||||||
|  |     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||||||
|  |         f.debug_struct("ConnectingStream") | ||||||
|  |             .field("common", &self.common) | ||||||
|  |             .field("peer_addr", &self.peer_addr) | ||||||
|  |             .field("req", &*self.req.lock()) | ||||||
|  |             .field("connected", &self.connected) | ||||||
|  |             .finish() | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl<A: Addr> std::fmt::Debug for ConnectReq<A> { | ||||||
|  |     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||||||
|  |         f.debug_struct("ConnectReq") | ||||||
|  |             .field("io_handle", &self.io_handle) | ||||||
|  |             .field("errno", &self.errno) | ||||||
|  |             .finish() | ||||||
|  |     } | ||||||
|  | } | ||||||
							
								
								
									
										114
									
								
								src/libos/src/net/socket/uring/stream/states/connected/mod.rs
									
									
									
									
									
										Normal file
									
								
							
							
								
								
								
								
								
									
									
								
							
						
						
									
										114
									
								
								src/libos/src/net/socket/uring/stream/states/connected/mod.rs
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,114 @@ | |||||||
|  | use atomic::Ordering; | ||||||
|  | 
 | ||||||
|  | use self::recv::Receiver; | ||||||
|  | use self::send::Sender; | ||||||
|  | use crate::fs::IoEvents as Events; | ||||||
|  | use crate::net::socket::sockopt::SockOptName; | ||||||
|  | use crate::net::socket::uring::common::Common; | ||||||
|  | use crate::net::socket::uring::runtime::Runtime; | ||||||
|  | use crate::prelude::*; | ||||||
|  | 
 | ||||||
|  | mod recv; | ||||||
|  | mod send; | ||||||
|  | 
 | ||||||
|  | pub struct ConnectedStream<A: Addr + 'static, R: Runtime> { | ||||||
|  |     common: Arc<Common<A, R>>, | ||||||
|  |     sender: Sender, | ||||||
|  |     receiver: Receiver, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl<A: Addr + 'static, R: Runtime> ConnectedStream<A, R> { | ||||||
|  |     pub fn new(common: Arc<Common<A, R>>) -> Arc<Self> { | ||||||
|  |         common.pollee().reset_events(); | ||||||
|  |         common.pollee().add_events(Events::OUT); | ||||||
|  | 
 | ||||||
|  |         let fd = common.host_fd(); | ||||||
|  | 
 | ||||||
|  |         let sender = Sender::new(); | ||||||
|  |         let receiver = Receiver::new(); | ||||||
|  |         let new_self = Arc::new(Self { | ||||||
|  |             common, | ||||||
|  |             sender, | ||||||
|  |             receiver, | ||||||
|  |         }); | ||||||
|  | 
 | ||||||
|  |         // Start async recv requests right as early as possible to support poll and
 | ||||||
|  |         // improve performance. If we don't start recv requests early, the poll()
 | ||||||
|  |         // might block forever when user just invokes poll(Event::In) without read().
 | ||||||
|  |         // Once we have recv requests completed, we can have Event::In in the events.
 | ||||||
|  |         new_self.initiate_async_recv(); | ||||||
|  | 
 | ||||||
|  |         new_self | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn common(&self) -> &Arc<Common<A, R>> { | ||||||
|  |         &self.common | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn shutdown(&self, how: Shutdown) -> Result<()> { | ||||||
|  |         // Do host shutdown
 | ||||||
|  |         // For shutdown write, don't call host_shutdown until the content in the pending buffer is sent.
 | ||||||
|  |         // For shutdown read, ignore the pending buffer.
 | ||||||
|  |         let (shut_write, send_buf_is_empty, shut_read) = ( | ||||||
|  |             how.should_shut_write(), | ||||||
|  |             self.sender.is_empty(), | ||||||
|  |             how.should_shut_read(), | ||||||
|  |         ); | ||||||
|  |         match (shut_write, send_buf_is_empty, shut_read) { | ||||||
|  |             // As long as send buf is empty, just shutdown.
 | ||||||
|  |             (_, true, _) => self.common.host_shutdown(how)?, | ||||||
|  |             // If not shutdown write, just shutdown.
 | ||||||
|  |             (false, _, _) => self.common.host_shutdown(how)?, | ||||||
|  |             // If shutdown both but the send buf is not empty, only shutdown read.
 | ||||||
|  |             (true, false, true) => self.common.host_shutdown(Shutdown::Read)?, | ||||||
|  |             // If shutdown write but the send buf is not empty, don't do shutdown.
 | ||||||
|  |             (true, false, false) => {} | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         // Set internal state and trigger events.
 | ||||||
|  |         if shut_read { | ||||||
|  |             self.receiver.shutdown(); | ||||||
|  |             self.common.pollee().add_events(Events::IN); | ||||||
|  |         } | ||||||
|  |         if shut_write { | ||||||
|  |             self.sender.shutdown(); | ||||||
|  |             self.common.pollee().add_events(Events::OUT); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         if shut_read && shut_write { | ||||||
|  |             self.common.pollee().add_events(Events::HUP); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         Ok(()) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn set_closed(&self) { | ||||||
|  |         // Mark the sender and receiver to shutdown to prevent submitting new requests.
 | ||||||
|  |         self.receiver.shutdown(); | ||||||
|  |         self.sender.shutdown(); | ||||||
|  | 
 | ||||||
|  |         self.common.set_closed(); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // Other methods are implemented in the send and receive modules
 | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl<A: Addr + 'static, R: Runtime> std::fmt::Debug for ConnectedStream<A, R> { | ||||||
|  |     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||||||
|  |         f.debug_struct("ConnectedStream") | ||||||
|  |             .field("common", &self.common) | ||||||
|  |             .field("sender", &self.sender) | ||||||
|  |             .field("receiver", &self.receiver) | ||||||
|  |             .finish() | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | fn new_msghdr(iovecs_ptr: *mut libc::iovec, iovecs_len: usize) -> libc::msghdr { | ||||||
|  |     use std::mem::MaybeUninit; | ||||||
|  |     // Safety. Setting all fields to zeros is a valid state for msghdr.
 | ||||||
|  |     let mut msghdr: libc::msghdr = unsafe { MaybeUninit::zeroed().assume_init() }; | ||||||
|  |     msghdr.msg_iov = iovecs_ptr; | ||||||
|  |     msghdr.msg_iovlen = iovecs_len as _; | ||||||
|  |     // We do want to leave all other fields as zeros
 | ||||||
|  |     msghdr | ||||||
|  | } | ||||||
							
								
								
									
										458
									
								
								src/libos/src/net/socket/uring/stream/states/connected/recv.rs
									
									
									
									
									
										Normal file
									
								
							
							
								
								
								
								
								
									
									
								
							
						
						
									
										458
									
								
								src/libos/src/net/socket/uring/stream/states/connected/recv.rs
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,458 @@ | |||||||
|  | use core::hint; | ||||||
|  | use core::sync::atomic::AtomicBool; | ||||||
|  | use core::time::Duration; | ||||||
|  | use std::mem::MaybeUninit; | ||||||
|  | use std::ptr::{self}; | ||||||
|  | 
 | ||||||
|  | use atomic::Ordering; | ||||||
|  | use io_uring_callback::{Fd, IoHandle}; | ||||||
|  | use sgx_untrusted_alloc::{MaybeUntrusted, UntrustedBox}; | ||||||
|  | 
 | ||||||
|  | use super::ConnectedStream; | ||||||
|  | use crate::net::socket::uring::runtime::Runtime; | ||||||
|  | use crate::net::socket::uring::stream::RECV_BUF_SIZE; | ||||||
|  | use crate::prelude::*; | ||||||
|  | use crate::untrusted::UntrustedCircularBuf; | ||||||
|  | use crate::util::sync::{Mutex, MutexGuard}; | ||||||
|  | 
 | ||||||
|  | use crate::events::Poller; | ||||||
|  | use crate::fs::IoEvents as Events; | ||||||
|  | 
 | ||||||
|  | impl<A: Addr + 'static, R: Runtime> ConnectedStream<A, R> { | ||||||
|  |     pub fn recvmsg(self: &Arc<Self>, bufs: &mut [&mut [u8]], flags: RecvFlags) -> Result<usize> { | ||||||
|  |         let total_len: usize = bufs.iter().map(|buf| buf.len()).sum(); | ||||||
|  |         if total_len == 0 { | ||||||
|  |             return Ok(0); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         let mut total_received = 0; | ||||||
|  |         let mut iov_buffer_index = 0; | ||||||
|  |         let mut iov_buffer_offset = 0; | ||||||
|  | 
 | ||||||
|  |         let mask = Events::IN; | ||||||
|  |         // Initialize the poller only when needed
 | ||||||
|  |         let mut poller = None; | ||||||
|  |         let mut timeout = self.common.recv_timeout(); | ||||||
|  |         loop { | ||||||
|  |             // Attempt to read
 | ||||||
|  |             let res = self.try_recvmsg(bufs, flags, iov_buffer_index, iov_buffer_offset); | ||||||
|  | 
 | ||||||
|  |             match res { | ||||||
|  |                 Ok((received_size, index, offset)) => { | ||||||
|  |                     total_received += received_size; | ||||||
|  | 
 | ||||||
|  |                     if !flags.contains(RecvFlags::MSG_WAITALL) || total_received == total_len { | ||||||
|  |                         return Ok(total_received); | ||||||
|  |                     } else { | ||||||
|  |                         // save the index and offset for the next round
 | ||||||
|  |                         iov_buffer_index = index; | ||||||
|  |                         iov_buffer_offset = offset; | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|  |                 Err(e) => { | ||||||
|  |                     if e.errno() != EAGAIN { | ||||||
|  |                         return Err(e); | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|  |             }; | ||||||
|  | 
 | ||||||
|  |             if self.common.nonblocking() || flags.contains(RecvFlags::MSG_DONTWAIT) { | ||||||
|  |                 return_errno!(EAGAIN, "no data are present to be received"); | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             // Wait for interesting events by polling
 | ||||||
|  |             if poller.is_none() { | ||||||
|  |                 let new_poller = Poller::new(); | ||||||
|  |                 self.common.pollee().connect_poller(mask, &new_poller); | ||||||
|  |                 poller = Some(new_poller); | ||||||
|  |             } | ||||||
|  |             let events = self.common.pollee().poll(mask, None); | ||||||
|  |             if events.is_empty() { | ||||||
|  |                 let ret = poller.as_ref().unwrap().wait_timeout(timeout.as_mut()); | ||||||
|  |                 if let Err(e) = ret { | ||||||
|  |                     warn!("recv wait errno = {:?}", e.errno()); | ||||||
|  |                     // For recv with MSG_WAITALL, return total received bytes if timeout or interrupt
 | ||||||
|  |                     if flags.contains(RecvFlags::MSG_WAITALL) && total_received > 0 { | ||||||
|  |                         return Ok(total_received); | ||||||
|  |                     } | ||||||
|  |                     match e.errno() { | ||||||
|  |                         ETIMEDOUT => { | ||||||
|  |                             return_errno!(EAGAIN, "timeout reached") | ||||||
|  |                         } | ||||||
|  |                         _ => { | ||||||
|  |                             return_errno!(e.errno(), "wait error") | ||||||
|  |                         } | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn try_recvmsg( | ||||||
|  |         self: &Arc<Self>, | ||||||
|  |         bufs: &mut [&mut [u8]], | ||||||
|  |         flags: RecvFlags, | ||||||
|  |         iov_buffer_index: usize, | ||||||
|  |         iov_buffer_offset: usize, | ||||||
|  |     ) -> Result<(usize, usize, usize)> { | ||||||
|  |         let mut inner = self.receiver.inner.lock(); | ||||||
|  | 
 | ||||||
|  |         if !flags.is_empty() | ||||||
|  |             && flags.intersects(!(RecvFlags::MSG_DONTWAIT | RecvFlags::MSG_WAITALL)) | ||||||
|  |         { | ||||||
|  |             warn!("Unsupported flags: {:?}", flags); | ||||||
|  |             return_errno!(EINVAL, "flags not supported"); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         let res = { | ||||||
|  |             let mut total_consumed = 0; | ||||||
|  |             let mut iov_buffer_index = iov_buffer_index; | ||||||
|  |             let mut iov_buffer_offset = iov_buffer_offset; | ||||||
|  | 
 | ||||||
|  |             // save the received data from bufs[iov_buffer_index][iov_buffer_offset..]
 | ||||||
|  |             for (_, buf) in bufs.iter_mut().skip(iov_buffer_index).enumerate() { | ||||||
|  |                 let this_consumed = inner.recv_buf.consume(&mut buf[iov_buffer_offset..]); | ||||||
|  |                 if this_consumed == 0 { | ||||||
|  |                     break; | ||||||
|  |                 } | ||||||
|  |                 total_consumed += this_consumed; | ||||||
|  | 
 | ||||||
|  |                 // if the buffer is not full, then the try_recvmsg will be used again
 | ||||||
|  |                 // next time, the data will be stored from the offset
 | ||||||
|  |                 if this_consumed < buf[iov_buffer_offset..].len() { | ||||||
|  |                     iov_buffer_offset += this_consumed; | ||||||
|  |                     break; | ||||||
|  |                 } else { | ||||||
|  |                     iov_buffer_index += 1; | ||||||
|  |                     iov_buffer_offset = 0; | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |             (total_consumed, iov_buffer_index, iov_buffer_offset) | ||||||
|  |         }; | ||||||
|  | 
 | ||||||
|  |         if self.receiver.need_update() { | ||||||
|  |             // Only update the recv buf when it is empty and there is no pending recv request
 | ||||||
|  |             if inner.recv_buf.is_empty() && inner.io_handle.is_none() { | ||||||
|  |                 self.receiver.set_need_update(false); | ||||||
|  |                 inner.update_buf_size(RECV_BUF_SIZE.load(Ordering::Relaxed)); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         if inner.end_of_file { | ||||||
|  |             return Ok(res); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         if inner.recv_buf.is_empty() { | ||||||
|  |             // Mark the socket as non-readable
 | ||||||
|  |             self.common.pollee().del_events(Events::IN); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         if res.0 > 0 { | ||||||
|  |             self.do_recv(&mut inner); | ||||||
|  |             return Ok(res); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         // Only when there are no data available in the recv buffer, shall we check
 | ||||||
|  |         // the following error conditions.
 | ||||||
|  |         //
 | ||||||
|  |         // Case 1: If the read side of the connection has been shutdown...
 | ||||||
|  |         if inner.is_shutdown { | ||||||
|  |             return_errno!(EPIPE, "read side is shutdown"); | ||||||
|  |         } | ||||||
|  |         // Case 2: If the connenction has been broken...
 | ||||||
|  |         if let Some(errno) = inner.fatal { | ||||||
|  |             // Reset error
 | ||||||
|  |             inner.fatal = None; | ||||||
|  |             self.common.pollee().del_events(Events::ERR); | ||||||
|  |             return_errno!(errno, "read failed"); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         self.do_recv(&mut inner); | ||||||
|  |         return_errno!(EAGAIN, "try read again"); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn do_recv(self: &Arc<Self>, inner: &mut MutexGuard<Inner>) { | ||||||
|  |         if inner.recv_buf.is_full() | ||||||
|  |             || inner.is_shutdown | ||||||
|  |             || inner.io_handle.is_some() | ||||||
|  |             || inner.end_of_file | ||||||
|  |             || self.common.is_closed() | ||||||
|  |         { | ||||||
|  |             // Delete ERR events from sender. If io_handle is some, the recv request must be
 | ||||||
|  |             // pending and the events can't be for the reciever. Just delete this event.
 | ||||||
|  |             // This can happen when send request is timeout and canceled.
 | ||||||
|  |             let events = self.common.pollee().poll(Events::IN, None); | ||||||
|  |             if events.contains(Events::ERR) && inner.io_handle.is_some() { | ||||||
|  |                 self.common.pollee().del_events(Events::ERR); | ||||||
|  |             } | ||||||
|  |             return; | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         // Init the callback invoked upon the completion of the async recv
 | ||||||
|  |         let stream = self.clone(); | ||||||
|  |         let complete_fn = move |retval: i32| { | ||||||
|  |             // let mut inner = stream.receiver.inner.lock().unwrap();
 | ||||||
|  |             let mut inner = stream.receiver.inner.lock(); | ||||||
|  |             trace!("recv request complete with retval: {:?}", retval); | ||||||
|  | 
 | ||||||
|  |             // Release the handle to the async recv
 | ||||||
|  |             inner.io_handle.take(); | ||||||
|  | 
 | ||||||
|  |             // Handle error
 | ||||||
|  |             if retval < 0 { | ||||||
|  |                 // TODO: guard against Iago attack through errno
 | ||||||
|  |                 // We should return here, The error may be due to network reasons
 | ||||||
|  |                 // or because the request was cancelled. We don't want to start a
 | ||||||
|  |                 // new request after cancelled a request.
 | ||||||
|  |                 let errno = Errno::from(-retval as u32); | ||||||
|  |                 inner.fatal = Some(errno); | ||||||
|  |                 stream.common.set_errno(errno); | ||||||
|  | 
 | ||||||
|  |                 let events = if errno == ENOTCONN || errno == ECONNRESET || errno == ECONNREFUSED { | ||||||
|  |                     Events::HUP | Events::IN | Events::ERR | ||||||
|  |                 } else { | ||||||
|  |                     Events::ERR | ||||||
|  |                 }; | ||||||
|  |                 stream.common.pollee().add_events(events); | ||||||
|  | 
 | ||||||
|  |                 return; | ||||||
|  |             } | ||||||
|  |             // Handle end of file
 | ||||||
|  |             else if retval == 0 { | ||||||
|  |                 inner.end_of_file = true; | ||||||
|  |                 stream.common.pollee().add_events(Events::IN); | ||||||
|  |                 return; | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             // Handle the normal case of a successful read
 | ||||||
|  |             let nbytes = retval as usize; | ||||||
|  |             inner.recv_buf.produce_without_copy(nbytes); | ||||||
|  | 
 | ||||||
|  |             // Now that we have produced non-zero bytes, the buf must become
 | ||||||
|  |             // ready to read.
 | ||||||
|  |             stream.common.pollee().add_events(Events::IN); | ||||||
|  | 
 | ||||||
|  |             stream.do_recv(&mut inner); | ||||||
|  |         }; | ||||||
|  | 
 | ||||||
|  |         // Generate the async recv request
 | ||||||
|  |         let msghdr_ptr = inner.new_recv_req(); | ||||||
|  | 
 | ||||||
|  |         // Submit the async recv to io_uring
 | ||||||
|  |         let io_uring = self.common.io_uring(); | ||||||
|  |         let host_fd = Fd(self.common.host_fd() as _); | ||||||
|  | 
 | ||||||
|  |         let handle = unsafe { io_uring.recvmsg(host_fd, msghdr_ptr, 0, complete_fn) }; | ||||||
|  |         inner.io_handle.replace(handle); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub(super) fn initiate_async_recv(self: &Arc<Self>) { | ||||||
|  |         // trace!("initiate async recv");
 | ||||||
|  |         let mut inner = self.receiver.inner.lock(); | ||||||
|  |         self.do_recv(&mut inner); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn cancel_recv_requests(&self) { | ||||||
|  |         { | ||||||
|  |             let inner = self.receiver.inner.lock(); | ||||||
|  |             if let Some(io_handle) = &inner.io_handle { | ||||||
|  |                 let io_uring = self.common.io_uring(); | ||||||
|  |                 unsafe { io_uring.cancel(io_handle) }; | ||||||
|  |             } else { | ||||||
|  |                 return; | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         // wait for the cancel to complete
 | ||||||
|  |         let poller = Poller::new(); | ||||||
|  |         let mask = Events::ERR | Events::IN; | ||||||
|  |         self.common.pollee().connect_poller(mask, &poller); | ||||||
|  | 
 | ||||||
|  |         loop { | ||||||
|  |             let pending_request_exist = { | ||||||
|  |                 let inner = self.receiver.inner.lock(); | ||||||
|  |                 inner.io_handle.is_some() | ||||||
|  |             }; | ||||||
|  | 
 | ||||||
|  |             if pending_request_exist { | ||||||
|  |                 let mut timeout = Some(Duration::from_secs(10)); | ||||||
|  |                 let ret = poller.wait_timeout(timeout.as_mut()); | ||||||
|  |                 if let Err(e) = ret { | ||||||
|  |                     warn!("wait cancel recv request error = {:?}", e.errno()); | ||||||
|  |                     continue; | ||||||
|  |                 } | ||||||
|  |             } else { | ||||||
|  |                 break; | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn bytes_to_consume(self: &Arc<Self>) -> usize { | ||||||
|  |         let inner = self.receiver.inner.lock(); | ||||||
|  |         inner.recv_buf.consumable() | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // This function will try to update the kernel recv buf size.
 | ||||||
|  |     // For socket recv, there will always be a pending request in advance. Thus,we can only update the kernel
 | ||||||
|  |     // buffer when a recv request is done and the kernel buffer is empty. Here, we just set the update flag.
 | ||||||
|  |     pub fn try_update_recv_buf_size(&self, buf_size: usize) { | ||||||
|  |         let pre_buf_size = RECV_BUF_SIZE.swap(buf_size, Ordering::Relaxed); | ||||||
|  |         if buf_size == pre_buf_size { | ||||||
|  |             return; | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         self.receiver.set_need_update(true); | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | pub struct Receiver { | ||||||
|  |     inner: Mutex<Inner>, | ||||||
|  |     need_update: AtomicBool, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl Receiver { | ||||||
|  |     pub fn new() -> Self { | ||||||
|  |         let inner = Mutex::new(Inner::new()); | ||||||
|  |         let need_update = AtomicBool::new(false); | ||||||
|  | 
 | ||||||
|  |         Self { inner, need_update } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn shutdown(&self) { | ||||||
|  |         let mut inner = self.inner.lock(); | ||||||
|  |         inner.is_shutdown = true; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn set_need_update(&self, need_update: bool) { | ||||||
|  |         self.need_update.store(need_update, Ordering::Relaxed) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn need_update(&self) -> bool { | ||||||
|  |         self.need_update.load(Ordering::Relaxed) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl std::fmt::Debug for Receiver { | ||||||
|  |     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||||||
|  |         f.debug_struct("Receiver") | ||||||
|  |             .field("inner", &self.inner.lock()) | ||||||
|  |             .finish() | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | struct Inner { | ||||||
|  |     recv_buf: UntrustedCircularBuf, | ||||||
|  |     recv_req: UntrustedBox<RecvReq>, | ||||||
|  |     io_handle: Option<IoHandle>, | ||||||
|  |     is_shutdown: bool, | ||||||
|  |     end_of_file: bool, | ||||||
|  |     fatal: Option<Errno>, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Safety. `RecvReq` does not implement `Send`. But since all pointers in `RecvReq`
 | ||||||
|  | // refer to `recv_buf`, we can be sure that it is ok for `RecvReq` to move between
 | ||||||
|  | // threads. All other fields in `RecvReq` implement `Send` as well. So the entirety
 | ||||||
|  | // of `Inner` is `Send`-safe.
 | ||||||
|  | unsafe impl Send for Inner {} | ||||||
|  | 
 | ||||||
|  | impl Inner { | ||||||
|  |     pub fn new() -> Self { | ||||||
|  |         Self { | ||||||
|  |             recv_buf: UntrustedCircularBuf::with_capacity(RECV_BUF_SIZE.load(Ordering::Relaxed)), | ||||||
|  |             recv_req: UntrustedBox::new_uninit(), | ||||||
|  |             io_handle: None, | ||||||
|  |             is_shutdown: false, | ||||||
|  |             end_of_file: false, | ||||||
|  |             fatal: None, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn update_buf_size(&mut self, buf_size: usize) { | ||||||
|  |         debug_assert!(self.recv_buf.is_empty() && self.io_handle.is_none()); | ||||||
|  |         let new_recv_buf = UntrustedCircularBuf::with_capacity(buf_size); | ||||||
|  |         self.recv_buf = new_recv_buf; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     /// Constructs a new recv request according to the receiver's internal state.
 | ||||||
|  |     ///
 | ||||||
|  |     /// The new `RecvReq` will be put into `self.recv_req`, which is a location that is
 | ||||||
|  |     /// accessible by io_uring. A pointer to the C version of the resulting `RecvReq`,
 | ||||||
|  |     /// which is `libc::msghdr`, will be returned.
 | ||||||
|  |     ///
 | ||||||
|  |     /// The buffer used in the new `RecvReq` is part of `self.recv_buf`.
 | ||||||
|  |     pub fn new_recv_req(&mut self) -> *mut libc::msghdr { | ||||||
|  |         let (iovecs, iovecs_len) = self.gen_iovecs_from_recv_buf(); | ||||||
|  | 
 | ||||||
|  |         let msghdr_ptr: *mut libc::msghdr = &mut self.recv_req.msg; | ||||||
|  |         let iovecs_ptr: *mut libc::iovec = &mut self.recv_req.iovecs as *mut _ as _; | ||||||
|  | 
 | ||||||
|  |         let msg = super::new_msghdr(iovecs_ptr, iovecs_len); | ||||||
|  | 
 | ||||||
|  |         self.recv_req.msg = msg; | ||||||
|  |         self.recv_req.iovecs = iovecs; | ||||||
|  | 
 | ||||||
|  |         msghdr_ptr | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn gen_iovecs_from_recv_buf(&mut self) -> ([libc::iovec; 2], usize) { | ||||||
|  |         let mut iovecs_len = 0; | ||||||
|  |         let mut iovecs = unsafe { MaybeUninit::<[libc::iovec; 2]>::uninit().assume_init() }; | ||||||
|  |         self.recv_buf.with_producer_view(|part0, part1| { | ||||||
|  |             debug_assert!(part0.len() > 0); | ||||||
|  | 
 | ||||||
|  |             iovecs[0] = libc::iovec { | ||||||
|  |                 iov_base: part0.as_ptr() as _, | ||||||
|  |                 iov_len: part0.len() as _, | ||||||
|  |             }; | ||||||
|  | 
 | ||||||
|  |             iovecs[1] = if part1.len() > 0 { | ||||||
|  |                 iovecs_len = 2; | ||||||
|  |                 libc::iovec { | ||||||
|  |                     iov_base: part1.as_ptr() as _, | ||||||
|  |                     iov_len: part1.len() as _, | ||||||
|  |                 } | ||||||
|  |             } else { | ||||||
|  |                 iovecs_len = 1; | ||||||
|  |                 libc::iovec { | ||||||
|  |                     iov_base: ptr::null_mut(), | ||||||
|  |                     iov_len: 0, | ||||||
|  |                 } | ||||||
|  |             }; | ||||||
|  | 
 | ||||||
|  |             // Only access the producer's buffer; zero bytes produced for now.
 | ||||||
|  |             0 | ||||||
|  |         }); | ||||||
|  |         debug_assert!(iovecs_len > 0); | ||||||
|  |         (iovecs, iovecs_len) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl std::fmt::Debug for Inner { | ||||||
|  |     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||||||
|  |         f.debug_struct("Inner") | ||||||
|  |             .field("recv_buf", &self.recv_buf) | ||||||
|  |             .field("io_handle", &self.io_handle) | ||||||
|  |             .field("is_shutdown", &self.is_shutdown) | ||||||
|  |             .field("end_of_file", &self.end_of_file) | ||||||
|  |             .field("fatal", &self.fatal) | ||||||
|  |             .finish() | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | #[repr(C)] | ||||||
|  | struct RecvReq { | ||||||
|  |     msg: libc::msghdr, | ||||||
|  |     iovecs: [libc::iovec; 2], | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Safety. RecvReq is a C-style struct.
 | ||||||
|  | unsafe impl MaybeUntrusted for RecvReq {} | ||||||
|  | 
 | ||||||
|  | // Acquired by `IoUringCell<T: Copy>`.
 | ||||||
|  | impl Copy for RecvReq {} | ||||||
|  | 
 | ||||||
|  | impl Clone for RecvReq { | ||||||
|  |     fn clone(&self) -> Self { | ||||||
|  |         *self | ||||||
|  |     } | ||||||
|  | } | ||||||
							
								
								
									
										466
									
								
								src/libos/src/net/socket/uring/stream/states/connected/send.rs
									
									
									
									
									
										Normal file
									
								
							
							
								
								
								
								
								
									
									
								
							
						
						
									
										466
									
								
								src/libos/src/net/socket/uring/stream/states/connected/send.rs
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,466 @@ | |||||||
|  | use core::hint; | ||||||
|  | use core::sync::atomic::AtomicBool; | ||||||
|  | use core::time::Duration; | ||||||
|  | use std::mem::MaybeUninit; | ||||||
|  | use std::ptr::{self}; | ||||||
|  | 
 | ||||||
|  | use atomic::Ordering; | ||||||
|  | use io_uring_callback::{Fd, IoHandle}; | ||||||
|  | use log::error; | ||||||
|  | use sgx_untrusted_alloc::{MaybeUntrusted, UntrustedBox}; | ||||||
|  | 
 | ||||||
|  | use super::ConnectedStream; | ||||||
|  | use crate::net::socket::uring::runtime::Runtime; | ||||||
|  | use crate::net::socket::uring::stream::SEND_BUF_SIZE; | ||||||
|  | use crate::prelude::*; | ||||||
|  | use crate::untrusted::UntrustedCircularBuf; | ||||||
|  | 
 | ||||||
|  | use crate::util::sync::{Mutex, MutexGuard}; | ||||||
|  | 
 | ||||||
|  | use crate::events::Poller; | ||||||
|  | use crate::fs::IoEvents as Events; | ||||||
|  | 
 | ||||||
|  | impl<A: Addr + 'static, R: Runtime> ConnectedStream<A, R> { | ||||||
|  |     // We make sure the all the buffer contents are buffered in kernel and then return.
 | ||||||
|  |     pub fn sendmsg(self: &Arc<Self>, bufs: &[&[u8]], flags: SendFlags) -> Result<usize> { | ||||||
|  |         let total_len: usize = bufs.iter().map(|buf| buf.len()).sum(); | ||||||
|  |         if total_len == 0 { | ||||||
|  |             return Ok(0); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         let mut send_len = 0; | ||||||
|  |         // variables to track the position of async sendmsg.
 | ||||||
|  |         let mut iov_buf_id = 0; // user buffer id tracker
 | ||||||
|  |         let mut iov_buf_index = 0; // user buffer index tracker
 | ||||||
|  | 
 | ||||||
|  |         let mask = Events::OUT; | ||||||
|  |         // Initialize the poller only when needed
 | ||||||
|  |         let mut poller = None; | ||||||
|  |         let mut timeout = self.common.send_timeout(); | ||||||
|  |         loop { | ||||||
|  |             // Attempt to write
 | ||||||
|  |             let res = self.try_sendmsg(bufs, flags, &mut iov_buf_id, &mut iov_buf_index); | ||||||
|  |             if let Ok(len) = res { | ||||||
|  |                 send_len += len; | ||||||
|  |                 // Sent all or sent partial but it is nonblocking, return bytes sent
 | ||||||
|  |                 if send_len == total_len | ||||||
|  |                     || self.common.nonblocking() | ||||||
|  |                     || flags.contains(SendFlags::MSG_DONTWAIT) | ||||||
|  |                 { | ||||||
|  |                     return Ok(send_len); | ||||||
|  |                 } | ||||||
|  |             } else if !res.has_errno(EAGAIN) { | ||||||
|  |                 return res; | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             // Still some buffer contents pending
 | ||||||
|  |             if self.common.nonblocking() || flags.contains(SendFlags::MSG_DONTWAIT) { | ||||||
|  |                 return_errno!(EAGAIN, "try write again"); | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             // Wait for interesting events by polling
 | ||||||
|  |             if poller.is_none() { | ||||||
|  |                 let new_poller = Poller::new(); | ||||||
|  |                 self.common.pollee().connect_poller(mask, &new_poller); | ||||||
|  |                 poller = Some(new_poller); | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             let events = self.common.pollee().poll(mask, None); | ||||||
|  |             if events.is_empty() { | ||||||
|  |                 let ret = poller.as_ref().unwrap().wait_timeout(timeout.as_mut()); | ||||||
|  |                 if let Err(e) = ret { | ||||||
|  |                     warn!("send wait errno = {:?}", e.errno()); | ||||||
|  |                     match e.errno() { | ||||||
|  |                         ETIMEDOUT => { | ||||||
|  |                             // Just cancel send requests if timeout
 | ||||||
|  |                             self.cancel_send_requests(); | ||||||
|  |                             return_errno!(EAGAIN, "timeout reached") | ||||||
|  |                         } | ||||||
|  |                         _ => { | ||||||
|  |                             return_errno!(e.errno(), "wait error") | ||||||
|  |                         } | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn try_sendmsg( | ||||||
|  |         self: &Arc<Self>, | ||||||
|  |         bufs: &[&[u8]], | ||||||
|  |         flags: SendFlags, | ||||||
|  |         iov_buf_id: &mut usize, | ||||||
|  |         iov_buf_index: &mut usize, | ||||||
|  |     ) -> Result<usize> { | ||||||
|  |         let mut inner = self.sender.inner.lock(); | ||||||
|  | 
 | ||||||
|  |         if !flags.is_empty() | ||||||
|  |             && flags.intersects( | ||||||
|  |                 !(SendFlags::MSG_DONTWAIT | SendFlags::MSG_NOSIGNAL | SendFlags::MSG_MORE), | ||||||
|  |             ) | ||||||
|  |         { | ||||||
|  |             error!("Not supported flags: {:?}", flags); | ||||||
|  |             return_errno!(EINVAL, "not supported flags"); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         // Check for error condition before write.
 | ||||||
|  |         //
 | ||||||
|  |         // Case 1. If the write side of the connection has been shutdown...
 | ||||||
|  |         if inner.is_shutdown() { | ||||||
|  |             return_errno!(EPIPE, "write side is shutdown"); | ||||||
|  |         } | ||||||
|  |         // Case 2. If the connenction has been broken...
 | ||||||
|  |         if let Some(errno) = inner.fatal { | ||||||
|  |             // Reset error
 | ||||||
|  |             inner.fatal = None; | ||||||
|  |             self.common.pollee().del_events(Events::ERR); | ||||||
|  |             return_errno!(errno, "write failed"); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         // Copy data from the bufs to the send buffer
 | ||||||
|  |         // If the send buffer is full, update the user buffer tracker, return error to wait for events
 | ||||||
|  |         // And once there is free space, continue from the user buffer tracker
 | ||||||
|  |         let nbytes = { | ||||||
|  |             let mut total_produced = 0; | ||||||
|  |             let last_time_buf_id = iov_buf_id.clone(); | ||||||
|  |             let mut last_time_buf_idx = iov_buf_index.clone(); | ||||||
|  |             for (_i, buf) in bufs.iter().skip(last_time_buf_id).enumerate() { | ||||||
|  |                 let i = _i + last_time_buf_id; // After skipping ,the index still starts from 0
 | ||||||
|  |                 let this_produced = inner.send_buf.produce(&buf[last_time_buf_idx..]); | ||||||
|  |                 total_produced += this_produced; | ||||||
|  |                 if this_produced < buf[last_time_buf_idx..].len() { | ||||||
|  |                     // Send buffer is full.
 | ||||||
|  |                     *iov_buf_id = i; | ||||||
|  |                     *iov_buf_index = last_time_buf_idx + this_produced; | ||||||
|  |                     break; | ||||||
|  |                 } else { | ||||||
|  |                     // For next buffer, start from the front
 | ||||||
|  |                     last_time_buf_idx = 0; | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |             total_produced | ||||||
|  |         }; | ||||||
|  | 
 | ||||||
|  |         if inner.send_buf.is_full() { | ||||||
|  |             // Mark the socket as non-writable
 | ||||||
|  |             self.common.pollee().del_events(Events::OUT); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         // Since the send buffer is not empty, we can try to flush the buffer
 | ||||||
|  |         if inner.io_handle.is_none() { | ||||||
|  |             self.do_send(&mut inner); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         if nbytes > 0 { | ||||||
|  |             Ok(nbytes) | ||||||
|  |         } else { | ||||||
|  |             return_errno!(EAGAIN, "try write again"); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn do_send(self: &Arc<Self>, inner: &mut MutexGuard<Inner>) { | ||||||
|  |         // This function can also be called even if the socket is set to shutdown by shutdown syscall. This is due to the
 | ||||||
|  |         // async behaviour that the kernel may return to user before actually issuing the request. We should
 | ||||||
|  |         // keep sending the request as long as the send buffer is not empty even if the socket is shutdown.
 | ||||||
|  |         debug_assert!(inner.is_shutdown != ShutdownStatus::PostShutdown); | ||||||
|  |         debug_assert!(!inner.send_buf.is_empty()); | ||||||
|  |         debug_assert!(inner.io_handle.is_none()); | ||||||
|  | 
 | ||||||
|  |         // Init the callback invoked upon the completion of the async send
 | ||||||
|  |         let stream = self.clone(); | ||||||
|  |         let complete_fn = move |retval: i32| { | ||||||
|  |             let mut inner = stream.sender.inner.lock(); | ||||||
|  | 
 | ||||||
|  |             trace!("send request complete with retval: {}", retval); | ||||||
|  |             // Release the handle to the async send
 | ||||||
|  |             inner.io_handle.take(); | ||||||
|  | 
 | ||||||
|  |             // Handle error
 | ||||||
|  |             if retval < 0 { | ||||||
|  |                 // TODO: guard against Iago attack through errno
 | ||||||
|  |                 // TODO: should we ignore EINTR and try again?
 | ||||||
|  |                 let errno = Errno::from(-retval as u32); | ||||||
|  | 
 | ||||||
|  |                 inner.fatal = Some(errno); | ||||||
|  |                 stream.common.set_errno(errno); | ||||||
|  |                 stream.common.pollee().add_events(Events::ERR); | ||||||
|  |                 return; | ||||||
|  |             } | ||||||
|  |             assert!(retval != 0); | ||||||
|  | 
 | ||||||
|  |             // Handle the normal case of a successful write
 | ||||||
|  |             let nbytes = retval as usize; | ||||||
|  |             inner.send_buf.consume_without_copy(nbytes); | ||||||
|  | 
 | ||||||
|  |             // Now that we have consume non-zero bytes, the buf must become
 | ||||||
|  |             // ready to write.
 | ||||||
|  |             stream.common.pollee().add_events(Events::OUT); | ||||||
|  | 
 | ||||||
|  |             // Attempt to send again if there are available data in the buf.
 | ||||||
|  |             if !inner.send_buf.is_empty() { | ||||||
|  |                 stream.do_send(&mut inner); | ||||||
|  |             } else if inner.is_shutdown == ShutdownStatus::PreShutdown { | ||||||
|  |                 // The buffer is empty and the write side is shutdown by the user. We can safely shutdown host file here.
 | ||||||
|  |                 let _ = stream.common.host_shutdown(Shutdown::Write); | ||||||
|  |                 inner.is_shutdown = ShutdownStatus::PostShutdown | ||||||
|  |             } else if stream.sender.need_update() { | ||||||
|  |                 // send_buf is empty. We can try to update the send_buf
 | ||||||
|  |                 stream.sender.set_need_update(false); | ||||||
|  |                 inner.update_buf_size(SEND_BUF_SIZE.load(Ordering::Relaxed)); | ||||||
|  |             } | ||||||
|  |         }; | ||||||
|  | 
 | ||||||
|  |         // Generate the async send request
 | ||||||
|  |         let msghdr_ptr = inner.new_send_req(); | ||||||
|  | 
 | ||||||
|  |         trace!("send submit request"); | ||||||
|  |         // Submit the async send to io_uring
 | ||||||
|  |         let io_uring = self.common.io_uring(); | ||||||
|  |         let host_fd = Fd(self.common.host_fd() as _); | ||||||
|  |         let handle = unsafe { io_uring.sendmsg(host_fd, msghdr_ptr, 0, complete_fn) }; | ||||||
|  |         inner.io_handle.replace(handle); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn cancel_send_requests(&self) { | ||||||
|  |         let io_uring = self.common.io_uring(); | ||||||
|  |         let inner = self.sender.inner.lock(); | ||||||
|  |         if let Some(io_handle) = &inner.io_handle { | ||||||
|  |             unsafe { io_uring.cancel(io_handle) }; | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // This function will try to update the kernel buf size.
 | ||||||
|  |     // If the kernel buf is currently empty, the size will be updated immediately.
 | ||||||
|  |     // If the kernel buf is not empty, update the flag in Sender and update the kernel buf after send.
 | ||||||
|  |     pub fn try_update_send_buf_size(&self, buf_size: usize) { | ||||||
|  |         let pre_buf_size = SEND_BUF_SIZE.swap(buf_size, Ordering::Relaxed); | ||||||
|  |         if pre_buf_size == buf_size { | ||||||
|  |             return; | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         // Try to acquire the lock. If success, try directly update here.
 | ||||||
|  |         // If failure, don't wait because there is pending send request.
 | ||||||
|  |         if let Some(mut inner) = self.sender.inner.try_lock() { | ||||||
|  |             if inner.send_buf.is_empty() && inner.io_handle.is_none() { | ||||||
|  |                 inner.update_buf_size(buf_size); | ||||||
|  |                 return; | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         // Can't easily aquire lock or the sendbuf is not empty. Update the flag only
 | ||||||
|  |         self.sender.set_need_update(true); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // Normally, We will always try to send as long as the kernel send buf is not empty. However, if the user calls close, we will wait LINGER time
 | ||||||
|  |     // and then cancel on-going or new-issued send requests.
 | ||||||
|  |     pub fn try_empty_send_buf_when_close(&self) { | ||||||
|  |         // let inner = self.sender.inner.lock().unwrap();
 | ||||||
|  |         let inner = self.sender.inner.lock(); | ||||||
|  |         debug_assert!(inner.is_shutdown()); | ||||||
|  |         if inner.send_buf.is_empty() { | ||||||
|  |             return; | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         // Wait for linger time to empty the kernel buffer or cancel subsequent requests.
 | ||||||
|  |         drop(inner); | ||||||
|  |         const DEFUALT_LINGER_TIME: usize = 10; | ||||||
|  |         let poller = Poller::new(); | ||||||
|  |         let mask = Events::ERR | Events::OUT; | ||||||
|  |         self.common.pollee().connect_poller(mask, &poller); | ||||||
|  | 
 | ||||||
|  |         loop { | ||||||
|  |             let pending_request_exist = { | ||||||
|  |                 // let inner = self.sender.inner.lock().unwrap();
 | ||||||
|  |                 let inner = self.sender.inner.lock(); | ||||||
|  |                 inner.io_handle.is_some() | ||||||
|  |             }; | ||||||
|  | 
 | ||||||
|  |             if pending_request_exist { | ||||||
|  |                 let mut timeout = Some(Duration::from_secs(DEFUALT_LINGER_TIME as u64)); | ||||||
|  |                 let ret = poller.wait_timeout(timeout.as_mut()); | ||||||
|  |                 trace!("wait empty send buffer ret = {:?}", ret); | ||||||
|  |                 if let Err(_) = ret { | ||||||
|  |                     // No complete request to wake. Just cancel the send requests.
 | ||||||
|  |                     let io_uring = self.common.io_uring(); | ||||||
|  |                     let inner = self.sender.inner.lock(); | ||||||
|  |                     if let Some(io_handle) = &inner.io_handle { | ||||||
|  |                         unsafe { io_uring.cancel(io_handle) }; | ||||||
|  |                         // Loop again to wait the cancel request to complete
 | ||||||
|  |                         continue; | ||||||
|  |                     } else { | ||||||
|  |                         // No pending request, just break
 | ||||||
|  |                         break; | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|  |             } else { | ||||||
|  |                 // There is no pending requests
 | ||||||
|  |                 break; | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | pub struct Sender { | ||||||
|  |     inner: Mutex<Inner>, | ||||||
|  |     need_update: AtomicBool, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl Sender { | ||||||
|  |     pub fn new() -> Self { | ||||||
|  |         let inner = Mutex::new(Inner::new()); | ||||||
|  |         let need_update = AtomicBool::new(false); | ||||||
|  |         Self { inner, need_update } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn shutdown(&self) { | ||||||
|  |         let mut inner = self.inner.lock(); | ||||||
|  |         inner.is_shutdown = ShutdownStatus::PreShutdown; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn is_empty(&self) -> bool { | ||||||
|  |         let inner = self.inner.lock(); | ||||||
|  |         inner.send_buf.is_empty() | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn set_need_update(&self, need_update: bool) { | ||||||
|  |         self.need_update.store(need_update, Ordering::Relaxed) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn need_update(&self) -> bool { | ||||||
|  |         self.need_update.load(Ordering::Relaxed) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl std::fmt::Debug for Sender { | ||||||
|  |     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||||||
|  |         f.debug_struct("Sender") | ||||||
|  |             .field("inner", &self.inner.lock()) | ||||||
|  |             .finish() | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | struct Inner { | ||||||
|  |     send_buf: UntrustedCircularBuf, | ||||||
|  |     send_req: UntrustedBox<SendReq>, | ||||||
|  |     io_handle: Option<IoHandle>, | ||||||
|  |     is_shutdown: ShutdownStatus, | ||||||
|  |     fatal: Option<Errno>, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Safety. `SendReq` does not implement `Send`. But since all pointers in `SengReq`
 | ||||||
|  | // refer to `send_buf`, we can be sure that it is ok for `SendReq` to move between
 | ||||||
|  | // threads. All other fields in `SendReq` implement `Send` as well. So the entirety
 | ||||||
|  | // of `Inner` is `Send`-safe.
 | ||||||
|  | unsafe impl Send for Inner {} | ||||||
|  | 
 | ||||||
|  | impl Inner { | ||||||
|  |     pub fn new() -> Self { | ||||||
|  |         Self { | ||||||
|  |             send_buf: UntrustedCircularBuf::with_capacity(SEND_BUF_SIZE.load(Ordering::Relaxed)), | ||||||
|  |             send_req: UntrustedBox::new_uninit(), | ||||||
|  |             io_handle: None, | ||||||
|  |             is_shutdown: ShutdownStatus::Running, | ||||||
|  |             fatal: None, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn update_buf_size(&mut self, buf_size: usize) { | ||||||
|  |         debug_assert!(self.send_buf.is_empty() && self.io_handle.is_none()); | ||||||
|  |         let new_send_buf = UntrustedCircularBuf::with_capacity(buf_size); | ||||||
|  |         self.send_buf = new_send_buf; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn is_shutdown(&self) -> bool { | ||||||
|  |         self.is_shutdown == ShutdownStatus::PreShutdown | ||||||
|  |             || self.is_shutdown == ShutdownStatus::PostShutdown | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     /// Constructs a new send request according to the sender's internal state.
 | ||||||
|  |     ///
 | ||||||
|  |     /// The new `SendReq` will be put into `self.send_req`, which is a location that is
 | ||||||
|  |     /// accessible by io_uring. A pointer to the C version of the resulting `SendReq`,
 | ||||||
|  |     /// which is `libc::msghdr`, will be returned.
 | ||||||
|  |     ///
 | ||||||
|  |     /// The buffer used in the new `SendReq` is part of `self.send_buf`.
 | ||||||
|  |     pub fn new_send_req(&mut self) -> *mut libc::msghdr { | ||||||
|  |         let (iovecs, iovecs_len) = self.gen_iovecs_from_send_buf(); | ||||||
|  | 
 | ||||||
|  |         let msghdr_ptr: *mut libc::msghdr = &mut self.send_req.msg; | ||||||
|  |         let iovecs_ptr: *mut libc::iovec = &mut self.send_req.iovecs as *mut _ as _; | ||||||
|  | 
 | ||||||
|  |         let msg = super::new_msghdr(iovecs_ptr, iovecs_len); | ||||||
|  | 
 | ||||||
|  |         self.send_req.msg = msg; | ||||||
|  |         self.send_req.iovecs = iovecs; | ||||||
|  | 
 | ||||||
|  |         msghdr_ptr | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn gen_iovecs_from_send_buf(&mut self) -> ([libc::iovec; 2], usize) { | ||||||
|  |         let mut iovecs_len = 0; | ||||||
|  |         let mut iovecs = unsafe { MaybeUninit::<[libc::iovec; 2]>::uninit().assume_init() }; | ||||||
|  |         self.send_buf.with_consumer_view(|part0, part1| { | ||||||
|  |             debug_assert!(part0.len() > 0); | ||||||
|  | 
 | ||||||
|  |             iovecs[0] = libc::iovec { | ||||||
|  |                 iov_base: part0.as_ptr() as _, | ||||||
|  |                 iov_len: part0.len() as _, | ||||||
|  |             }; | ||||||
|  | 
 | ||||||
|  |             iovecs[1] = if part1.len() > 0 { | ||||||
|  |                 iovecs_len = 2; | ||||||
|  |                 libc::iovec { | ||||||
|  |                     iov_base: part1.as_ptr() as _, | ||||||
|  |                     iov_len: part1.len() as _, | ||||||
|  |                 } | ||||||
|  |             } else { | ||||||
|  |                 iovecs_len = 1; | ||||||
|  |                 libc::iovec { | ||||||
|  |                     iov_base: ptr::null_mut(), | ||||||
|  |                     iov_len: 0, | ||||||
|  |                 } | ||||||
|  |             }; | ||||||
|  | 
 | ||||||
|  |             // Only access the consumer's buffer; zero bytes consumed for now.
 | ||||||
|  |             0 | ||||||
|  |         }); | ||||||
|  |         debug_assert!(iovecs_len > 0); | ||||||
|  |         (iovecs, iovecs_len) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl std::fmt::Debug for Inner { | ||||||
|  |     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||||||
|  |         f.debug_struct("Inner") | ||||||
|  |             .field("send_buf", &self.send_buf) | ||||||
|  |             .field("io_handle", &self.io_handle) | ||||||
|  |             .field("is_shutdown", &self.is_shutdown) | ||||||
|  |             .field("fatal", &self.fatal) | ||||||
|  |             .finish() | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | #[repr(C)] | ||||||
|  | struct SendReq { | ||||||
|  |     msg: libc::msghdr, | ||||||
|  |     iovecs: [libc::iovec; 2], | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Safety. SendReq is a C-style struct.
 | ||||||
|  | unsafe impl MaybeUntrusted for SendReq {} | ||||||
|  | 
 | ||||||
|  | // Acquired by `IoUringCell<T: Copy>`.
 | ||||||
|  | impl Copy for SendReq {} | ||||||
|  | 
 | ||||||
|  | impl Clone for SendReq { | ||||||
|  |     fn clone(&self) -> Self { | ||||||
|  |         *self | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | #[derive(Debug, PartialEq)] | ||||||
|  | enum ShutdownStatus { | ||||||
|  |     Running,      // not shutdown
 | ||||||
|  |     PreShutdown,  // start the shutdown process, set by calling shutdown syscall
 | ||||||
|  |     PostShutdown, // shutdown process is done, set when the buffer is empty
 | ||||||
|  | } | ||||||
							
								
								
									
										72
									
								
								src/libos/src/net/socket/uring/stream/states/init.rs
									
									
									
									
									
										Normal file
									
								
							
							
								
								
								
								
								
									
									
								
							
						
						
									
										72
									
								
								src/libos/src/net/socket/uring/stream/states/init.rs
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,72 @@ | |||||||
|  | use crate::fs::IoEvents; | ||||||
|  | use crate::net::socket::uring::common::Common; | ||||||
|  | use crate::net::socket::uring::runtime::Runtime; | ||||||
|  | use crate::prelude::*; | ||||||
|  | 
 | ||||||
|  | /// A stream socket that is in its initial state.
 | ||||||
|  | pub struct InitStream<A: Addr + 'static, R: Runtime> { | ||||||
|  |     common: Arc<Common<A, R>>, | ||||||
|  |     inner: Mutex<Inner>, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | struct Inner { | ||||||
|  |     has_bound: bool, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl<A: Addr + 'static, R: Runtime> InitStream<A, R> { | ||||||
|  |     pub fn new(nonblocking: bool) -> Result<Arc<Self>> { | ||||||
|  |         let common = Arc::new(Common::new(Type::STREAM, nonblocking, None)?); | ||||||
|  |         common.pollee().add_events(IoEvents::HUP | IoEvents::OUT); | ||||||
|  |         let inner = Mutex::new(Inner::new()); | ||||||
|  |         let new_self = Self { common, inner }; | ||||||
|  |         Ok(Arc::new(new_self)) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn new_with_common(common: Arc<Common<A, R>>) -> Result<Arc<Self>> { | ||||||
|  |         let inner = Mutex::new(Inner { | ||||||
|  |             has_bound: common.addr().is_some(), | ||||||
|  |         }); | ||||||
|  |         let new_self = Self { common, inner }; | ||||||
|  |         Ok(Arc::new(new_self)) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn bind(&self, addr: &A) -> Result<()> { | ||||||
|  |         let mut inner = self.inner.lock(); | ||||||
|  |         if inner.has_bound { | ||||||
|  |             return_errno!(EINVAL, "the socket is already bound to an address"); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         crate::net::socket::uring::common::do_bind(self.common.host_fd(), addr)?; | ||||||
|  | 
 | ||||||
|  |         inner.has_bound = true; | ||||||
|  |         self.common.set_addr(addr); | ||||||
|  |         Ok(()) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn common(&self) -> &Arc<Common<A, R>> { | ||||||
|  |         &self.common | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl<A: Addr + 'static, R: Runtime> std::fmt::Debug for InitStream<A, R> { | ||||||
|  |     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||||||
|  |         f.debug_struct("InitStream") | ||||||
|  |             .field("common", &self.common) | ||||||
|  |             .field("inner", &*self.inner.lock()) | ||||||
|  |             .finish() | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl Inner { | ||||||
|  |     pub fn new() -> Self { | ||||||
|  |         Self { has_bound: false } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl std::fmt::Debug for Inner { | ||||||
|  |     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||||||
|  |         f.debug_struct("Inner") | ||||||
|  |             .field("has_bound", &self.has_bound) | ||||||
|  |             .finish() | ||||||
|  |     } | ||||||
|  | } | ||||||
							
								
								
									
										430
									
								
								src/libos/src/net/socket/uring/stream/states/listen.rs
									
									
									
									
									
										Normal file
									
								
							
							
								
								
								
								
								
									
									
								
							
						
						
									
										430
									
								
								src/libos/src/net/socket/uring/stream/states/listen.rs
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,430 @@ | |||||||
|  | use core::time::Duration; | ||||||
|  | use std::collections::VecDeque; | ||||||
|  | use std::marker::PhantomData; | ||||||
|  | use std::mem::size_of; | ||||||
|  | 
 | ||||||
|  | use io_uring_callback::{Fd, IoHandle}; | ||||||
|  | use sgx_untrusted_alloc::{MaybeUntrusted, UntrustedBox}; | ||||||
|  | 
 | ||||||
|  | use super::ConnectedStream; | ||||||
|  | use crate::events::Poller; | ||||||
|  | use crate::fs::IoEvents; | ||||||
|  | use crate::net::socket::uring::common::{do_close, Common}; | ||||||
|  | use crate::net::socket::uring::runtime::Runtime; | ||||||
|  | use crate::prelude::*; | ||||||
|  | use libc::ocall::shutdown as do_shutdown; | ||||||
|  | 
 | ||||||
|  | // We issue the async accept request ahead of time. But with big backlog number,
 | ||||||
|  | // there will be too many pending requests, which could be harmful to the system.
 | ||||||
|  | const PENDING_ASYNC_ACCEPT_NUM_MAX: usize = 128; | ||||||
|  | 
 | ||||||
|  | /// A listener stream, ready to accept incoming connections.
 | ||||||
|  | pub struct ListenerStream<A: Addr + 'static, R: Runtime> { | ||||||
|  |     common: Arc<Common<A, R>>, | ||||||
|  |     inner: Mutex<Inner<A>>, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl<A: Addr + 'static, R: Runtime> ListenerStream<A, R> { | ||||||
|  |     /// Creates a new listener stream.
 | ||||||
|  |     pub fn new(backlog: u32, common: Arc<Common<A, R>>) -> Result<Arc<Self>> { | ||||||
|  |         // Here we use different variables for the backlog. For the libos, as we will issue async accept request
 | ||||||
|  |         // ahead of time, and cacel when the socket closes, we set the libos backlog to a reasonable value which
 | ||||||
|  |         // is no greater than the max value we set to save resources and make it more efficient. For the host,
 | ||||||
|  |         // we just use the backlog value for maximum connection.
 | ||||||
|  |         let libos_backlog = std::cmp::min(backlog, PENDING_ASYNC_ACCEPT_NUM_MAX as u32); | ||||||
|  |         let host_backlog = backlog; | ||||||
|  | 
 | ||||||
|  |         let inner = Inner::new(libos_backlog)?; | ||||||
|  |         Self::do_listen(common.host_fd(), host_backlog)?; | ||||||
|  | 
 | ||||||
|  |         common.pollee().reset_events(); | ||||||
|  |         let new_self = Arc::new(Self { | ||||||
|  |             common, | ||||||
|  |             inner: Mutex::new(inner), | ||||||
|  |         }); | ||||||
|  | 
 | ||||||
|  |         // Start async accept requests right as early as possible to improve performance
 | ||||||
|  |         { | ||||||
|  |             let inner = new_self.inner.lock(); | ||||||
|  |             new_self.initiate_async_accepts(inner); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         Ok(new_self) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn do_listen(host_fd: FileDesc, backlog: u32) -> Result<()> { | ||||||
|  |         try_libc!(libc::ocall::listen(host_fd as _, backlog as _)); | ||||||
|  |         Ok(()) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn accept(self: &Arc<Self>, nonblocking: bool) -> Result<Arc<ConnectedStream<A, R>>> { | ||||||
|  |         let mask = IoEvents::IN; | ||||||
|  |         // Init the poller only when needed
 | ||||||
|  |         let mut poller = None; | ||||||
|  |         let mut timeout = self.common.recv_timeout(); | ||||||
|  |         loop { | ||||||
|  |             // Attempt to accept
 | ||||||
|  |             let res = self.try_accept(nonblocking); | ||||||
|  |             if !res.has_errno(EAGAIN) { | ||||||
|  |                 return res; | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             if self.common.nonblocking() { | ||||||
|  |                 return_errno!(EAGAIN, "no connections are present to be accepted"); | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             // Ensure the poller is initialized
 | ||||||
|  |             if poller.is_none() { | ||||||
|  |                 let new_poller = Poller::new(); | ||||||
|  |                 self.common.pollee().connect_poller(mask, &new_poller); | ||||||
|  |                 poller = Some(new_poller); | ||||||
|  |             } | ||||||
|  |             // Wait for interesting events by polling
 | ||||||
|  | 
 | ||||||
|  |             let events = self.common.pollee().poll(mask, None); | ||||||
|  |             if events.is_empty() { | ||||||
|  |                 let ret = poller.as_ref().unwrap().wait_timeout(timeout.as_mut()); | ||||||
|  |                 if let Err(e) = ret { | ||||||
|  |                     warn!("accept wait errno = {:?}", e.errno()); | ||||||
|  |                     match e.errno() { | ||||||
|  |                         ETIMEDOUT => { | ||||||
|  |                             return_errno!(EAGAIN, "timeout reached") | ||||||
|  |                         } | ||||||
|  |                         _ => { | ||||||
|  |                             return_errno!(e.errno(), "wait error") | ||||||
|  |                         } | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn try_accept(self: &Arc<Self>, nonblocking: bool) -> Result<Arc<ConnectedStream<A, R>>> { | ||||||
|  |         let mut inner = self.inner.lock(); | ||||||
|  | 
 | ||||||
|  |         if let Some(errno) = inner.fatal { | ||||||
|  |             // Reset error
 | ||||||
|  |             inner.fatal = None; | ||||||
|  |             self.common.pollee().del_events(IoEvents::ERR); | ||||||
|  |             return_errno!(errno, "accept failed"); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         let (accepted_fd, accepted_addr) = inner.backlog.pop_completed_req().ok_or_else(|| { | ||||||
|  |             self.common.pollee().del_events(IoEvents::IN); | ||||||
|  |             errno!(EAGAIN, "try accept again") | ||||||
|  |         })?; | ||||||
|  | 
 | ||||||
|  |         if !inner.backlog.has_completed_reqs() { | ||||||
|  |             self.common.pollee().del_events(IoEvents::IN); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         self.initiate_async_accepts(inner); | ||||||
|  | 
 | ||||||
|  |         let common = { | ||||||
|  |             let common = Arc::new(Common::with_host_fd(accepted_fd, Type::STREAM, nonblocking)); | ||||||
|  |             common.set_peer_addr(&accepted_addr); | ||||||
|  |             common | ||||||
|  |         }; | ||||||
|  |         let accepted_stream = ConnectedStream::new(common); | ||||||
|  |         Ok(accepted_stream) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn initiate_async_accepts(self: &Arc<Self>, mut inner: MutexGuard<Inner<A>>) { | ||||||
|  |         let backlog = &mut inner.backlog; | ||||||
|  |         while backlog.has_free_entries() { | ||||||
|  |             backlog.start_new_req(self); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn common(&self) -> &Arc<Common<A, R>> { | ||||||
|  |         &self.common | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn cancel_accept_requests(&self) { | ||||||
|  |         { | ||||||
|  |             // Set the listener stream as closed to prevent submitting new request in the callback fn
 | ||||||
|  |             self.common().set_closed(); | ||||||
|  |             let io_uring = self.common.io_uring(); | ||||||
|  |             let inner = self.inner.lock(); | ||||||
|  |             for entry in inner.backlog.entries.iter() { | ||||||
|  |                 if let Entry::Pending { io_handle } = entry { | ||||||
|  |                     unsafe { io_uring.cancel(&io_handle) }; | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         // wait for all the cancel requests to complete
 | ||||||
|  |         let poller = Poller::new(); | ||||||
|  |         let mask = IoEvents::ERR | IoEvents::IN; | ||||||
|  |         self.common.pollee().connect_poller(mask, &poller); | ||||||
|  | 
 | ||||||
|  |         loop { | ||||||
|  |             let pending_entry_exists = { | ||||||
|  |                 let inner = self.inner.lock(); | ||||||
|  |                 inner | ||||||
|  |                     .backlog | ||||||
|  |                     .entries | ||||||
|  |                     .iter() | ||||||
|  |                     .find(|entry| match entry { | ||||||
|  |                         Entry::Pending { .. } => true, | ||||||
|  |                         _ => false, | ||||||
|  |                     }) | ||||||
|  |                     .is_some() | ||||||
|  |             }; | ||||||
|  | 
 | ||||||
|  |             if pending_entry_exists { | ||||||
|  |                 let mut timeout = Some(Duration::from_secs(20)); | ||||||
|  |                 let ret = poller.wait_timeout(timeout.as_mut()); | ||||||
|  |                 if let Err(e) = ret { | ||||||
|  |                     warn!("wait cancel accept request error = {:?}", e.errno()); | ||||||
|  |                     continue; | ||||||
|  |                 } | ||||||
|  |             } else { | ||||||
|  |                 // Reset the stream for re-listen
 | ||||||
|  |                 self.common().reset_closed(); | ||||||
|  |                 return; | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn shutdown(&self, how: Shutdown) -> Result<()> { | ||||||
|  |         if how == Shutdown::Both { | ||||||
|  |             self.common.host_shutdown(Shutdown::Both)?; | ||||||
|  |             self.common | ||||||
|  |                 .pollee() | ||||||
|  |                 .add_events(IoEvents::IN | IoEvents::OUT | IoEvents::HUP); | ||||||
|  |         } else if how.should_shut_read() { | ||||||
|  |             self.common.host_shutdown(Shutdown::Read)?; | ||||||
|  |             self.common.pollee().add_events(IoEvents::IN); | ||||||
|  |         } else if how.should_shut_write() { | ||||||
|  |             self.common.host_shutdown(Shutdown::Write)?; | ||||||
|  |             self.common.pollee().add_events(IoEvents::OUT); | ||||||
|  |         } | ||||||
|  |         Ok(()) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl<A: Addr + 'static, R: Runtime> std::fmt::Debug for ListenerStream<A, R> { | ||||||
|  |     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||||||
|  |         f.debug_struct("ListenerStream") | ||||||
|  |             .field("common", &self.common) | ||||||
|  |             .field("inner", &self.inner.lock()) | ||||||
|  |             .finish() | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | /// The mutable, internal state of a listener stream.
 | ||||||
|  | struct Inner<A: Addr> { | ||||||
|  |     backlog: Backlog<A>, | ||||||
|  |     fatal: Option<Errno>, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl<A: Addr> Inner<A> { | ||||||
|  |     pub fn new(backlog: u32) -> Result<Self> { | ||||||
|  |         Ok(Inner { | ||||||
|  |             backlog: Backlog::with_capacity(backlog as usize)?, | ||||||
|  |             fatal: None, | ||||||
|  |         }) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl<A: Addr + 'static> std::fmt::Debug for Inner<A> { | ||||||
|  |     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||||||
|  |         f.debug_struct("Inner") | ||||||
|  |             .field("backlog", &self.backlog) | ||||||
|  |             .field("fatal", &self.fatal) | ||||||
|  |             .finish() | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | /// An entry in the backlog.
 | ||||||
|  | #[derive(Debug)] | ||||||
|  | enum Entry { | ||||||
|  |     /// The entry is free to use.
 | ||||||
|  |     Free, | ||||||
|  |     /// The entry is a pending accept request.
 | ||||||
|  |     Pending { io_handle: IoHandle }, | ||||||
|  |     /// The entry is a completed accept request.
 | ||||||
|  |     Completed { host_fd: FileDesc }, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl Default for Entry { | ||||||
|  |     fn default() -> Self { | ||||||
|  |         Self::Free | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | /// An async io_uring accept request.
 | ||||||
|  | #[derive(Copy, Clone)] | ||||||
|  | #[repr(C)] | ||||||
|  | struct AcceptReq { | ||||||
|  |     c_addr: libc::sockaddr_storage, | ||||||
|  |     c_addr_len: libc::socklen_t, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Safety. AcceptReq is a C-style struct with C-style fields.
 | ||||||
|  | unsafe impl MaybeUntrusted for AcceptReq {} | ||||||
|  | 
 | ||||||
|  | /// A backlog of incoming connections of a listener stream.
 | ||||||
|  | ///
 | ||||||
|  | /// With backlog, we can start async accept requests, keep track of the pending requests,
 | ||||||
|  | /// and maintain the ones that have completed.
 | ||||||
|  | struct Backlog<A: Addr> { | ||||||
|  |     // The entries in the backlog.
 | ||||||
|  |     entries: Box<[Entry]>, | ||||||
|  |     // Arguments of the io_uring requests submitted for the entries in the backlog.
 | ||||||
|  |     reqs: UntrustedBox<[AcceptReq]>, | ||||||
|  |     // The indexes of completed entries.
 | ||||||
|  |     completed: VecDeque<usize>, | ||||||
|  |     // The number of free entries.
 | ||||||
|  |     num_free: usize, | ||||||
|  |     phantom_data: PhantomData<A>, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl<A: Addr> Backlog<A> { | ||||||
|  |     pub fn with_capacity(capacity: usize) -> Result<Self> { | ||||||
|  |         if capacity == 0 { | ||||||
|  |             return_errno!(EINVAL, "capacity cannot be zero"); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         let entries = (0..capacity) | ||||||
|  |             .map(|_| Entry::Free) | ||||||
|  |             .collect::<Vec<Entry>>() | ||||||
|  |             .into_boxed_slice(); | ||||||
|  |         let reqs = UntrustedBox::new_uninit_slice(capacity); | ||||||
|  |         let completed = VecDeque::new(); | ||||||
|  |         let num_free = capacity; | ||||||
|  |         let new_self = Self { | ||||||
|  |             entries, | ||||||
|  |             reqs, | ||||||
|  |             completed, | ||||||
|  |             num_free, | ||||||
|  |             phantom_data: PhantomData, | ||||||
|  |         }; | ||||||
|  |         Ok(new_self) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn has_free_entries(&self) -> bool { | ||||||
|  |         self.num_free > 0 | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     /// Start a new async accept request, turning a free entry into a pending one.
 | ||||||
|  |     pub fn start_new_req<R: Runtime>(&mut self, stream: &Arc<ListenerStream<A, R>>) { | ||||||
|  |         if stream.common.is_closed() { | ||||||
|  |             return; | ||||||
|  |         } | ||||||
|  |         debug_assert!(self.has_free_entries()); | ||||||
|  | 
 | ||||||
|  |         let entry_idx = self | ||||||
|  |             .entries | ||||||
|  |             .iter() | ||||||
|  |             .position(|entry| matches!(entry, Entry::Free)) | ||||||
|  |             .unwrap(); | ||||||
|  | 
 | ||||||
|  |         let (c_addr_ptr, c_addr_len_ptr) = { | ||||||
|  |             let accept_req = &mut self.reqs[entry_idx]; | ||||||
|  |             accept_req.c_addr_len = size_of::<libc::sockaddr_storage>() as _; | ||||||
|  | 
 | ||||||
|  |             let c_addr_ptr = &mut accept_req.c_addr as *mut _ as _; | ||||||
|  |             let c_addr_len_ptr = &mut accept_req.c_addr_len as _; | ||||||
|  |             (c_addr_ptr, c_addr_len_ptr) | ||||||
|  |         }; | ||||||
|  | 
 | ||||||
|  |         let callback = { | ||||||
|  |             let stream = stream.clone(); | ||||||
|  |             move |retval: i32| { | ||||||
|  |                 let mut inner = stream.inner.lock(); | ||||||
|  | 
 | ||||||
|  |                 trace!("accept request complete with retval: {:?}", retval); | ||||||
|  | 
 | ||||||
|  |                 if retval < 0 { | ||||||
|  |                     // Since most errors that may result from the accept syscall are _not fatal_,
 | ||||||
|  |                     // we simply ignore the errno code and try again.
 | ||||||
|  |                     //
 | ||||||
|  |                     // According to the man page, Linux may report the network errors on an
 | ||||||
|  |                     // newly-accepted socket through the accept system call. Thus, we should not
 | ||||||
|  |                     // treat the listener socket as "broken" simply because an error is returned
 | ||||||
|  |                     // from the accept syscall.
 | ||||||
|  |                     //
 | ||||||
|  |                     // TODO: throw fatal errors to the upper layer.
 | ||||||
|  |                     let errno = Errno::from(-retval as u32); | ||||||
|  |                     log::error!("Accept error: errno = {}", errno); | ||||||
|  | 
 | ||||||
|  |                     inner.backlog.entries[entry_idx] = Entry::Free; | ||||||
|  |                     inner.backlog.num_free += 1; | ||||||
|  | 
 | ||||||
|  |                     // When canceling request, a poller might be waiting for this to return.
 | ||||||
|  |                     inner.fatal = Some(errno); | ||||||
|  |                     stream.common.set_errno(errno); | ||||||
|  |                     stream.common.pollee().add_events(IoEvents::ERR); | ||||||
|  | 
 | ||||||
|  |                     // After getting the error from the accept system call, we should not start
 | ||||||
|  |                     // the async accept requests again, because this may cause a large number of
 | ||||||
|  |                     // io-uring requests to be retried
 | ||||||
|  |                     return; | ||||||
|  |                 } | ||||||
|  | 
 | ||||||
|  |                 let host_fd = retval as FileDesc; | ||||||
|  |                 inner.backlog.entries[entry_idx] = Entry::Completed { host_fd }; | ||||||
|  |                 inner.backlog.completed.push_back(entry_idx); | ||||||
|  | 
 | ||||||
|  |                 stream.common.pollee().add_events(IoEvents::IN); | ||||||
|  | 
 | ||||||
|  |                 stream.initiate_async_accepts(inner); | ||||||
|  |             } | ||||||
|  |         }; | ||||||
|  |         let io_uring = stream.common.io_uring(); | ||||||
|  |         let fd = stream.common.host_fd() as i32; | ||||||
|  |         let flags = 0; | ||||||
|  |         let io_handle = | ||||||
|  |             unsafe { io_uring.accept(Fd(fd), c_addr_ptr, c_addr_len_ptr, flags, callback) }; | ||||||
|  |         self.entries[entry_idx] = Entry::Pending { io_handle }; | ||||||
|  |         self.num_free -= 1; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn has_completed_reqs(&self) -> bool { | ||||||
|  |         self.completed.len() > 0 | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     /// Pop a completed async accept request, turing a completed entry into a free one.
 | ||||||
|  |     pub fn pop_completed_req(&mut self) -> Option<(FileDesc, A)> { | ||||||
|  |         let completed_idx = self.completed.pop_front()?; | ||||||
|  |         let accepted_addr = { | ||||||
|  |             let AcceptReq { c_addr, c_addr_len } = self.reqs[completed_idx].clone(); | ||||||
|  |             A::from_c_storage(&c_addr, c_addr_len as _).unwrap() | ||||||
|  |         }; | ||||||
|  |         let accepted_fd = { | ||||||
|  |             let entry = &mut self.entries[completed_idx]; | ||||||
|  |             let accepted_fd = match entry { | ||||||
|  |                 Entry::Completed { host_fd } => *host_fd, | ||||||
|  |                 _ => unreachable!("the entry should have been completed"), | ||||||
|  |             }; | ||||||
|  |             self.num_free += 1; | ||||||
|  |             *entry = Entry::Free; | ||||||
|  |             accepted_fd | ||||||
|  |         }; | ||||||
|  |         Some((accepted_fd, accepted_addr)) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl<A: Addr + 'static> std::fmt::Debug for Backlog<A> { | ||||||
|  |     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||||||
|  |         f.debug_struct("Backlog") | ||||||
|  |             .field("entries", &self.entries) | ||||||
|  |             .field("completed", &self.completed) | ||||||
|  |             .field("num_free", &self.num_free) | ||||||
|  |             .finish() | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl<A: Addr> Drop for Backlog<A> { | ||||||
|  |     fn drop(&mut self) { | ||||||
|  |         for entry in self.entries.iter() { | ||||||
|  |             if let Entry::Completed { host_fd } = entry { | ||||||
|  |                 if let Err(e) = do_close(*host_fd) { | ||||||
|  |                     log::error!("close fd failed, host_fd: {}, err: {}", host_fd, e); | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
							
								
								
									
										9
									
								
								src/libos/src/net/socket/uring/stream/states/mod.rs
									
									
									
									
									
										Normal file
									
								
							
							
								
								
								
								
								
									
									
								
							
						
						
									
										9
									
								
								src/libos/src/net/socket/uring/stream/states/mod.rs
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,9 @@ | |||||||
|  | mod connect; | ||||||
|  | mod connected; | ||||||
|  | mod init; | ||||||
|  | mod listen; | ||||||
|  | 
 | ||||||
|  | pub use self::connect::ConnectingStream; | ||||||
|  | pub use self::connected::ConnectedStream; | ||||||
|  | pub use self::init::InitStream; | ||||||
|  | pub use self::listen::ListenerStream; | ||||||
| @ -17,8 +17,11 @@ pub use std::sync::{ | |||||||
| pub use crate::error::Result; | pub use crate::error::Result; | ||||||
| pub use crate::error::*; | pub use crate::error::*; | ||||||
| pub use crate::fs::{File, FileDesc, FileRef}; | pub use crate::fs::{File, FileDesc, FileRef}; | ||||||
|  | pub use crate::net::socket::util::Addr; | ||||||
|  | pub use crate::net::socket::{Domain, RecvFlags, SendFlags, Shutdown, Type}; | ||||||
| pub use crate::process::{pid_t, uid_t}; | pub use crate::process::{pid_t, uid_t}; | ||||||
| pub use crate::util::sync::RwLock; | pub use crate::util::sync::RwLock; | ||||||
|  | pub use crate::util::sync::{Mutex, MutexGuard}; | ||||||
| 
 | 
 | ||||||
| macro_rules! debug_trace { | macro_rules! debug_trace { | ||||||
|     () => { |     () => { | ||||||
|  | |||||||
| @ -44,8 +44,7 @@ use crate::net::{ | |||||||
|     do_accept, do_accept4, do_bind, do_connect, do_epoll_create, do_epoll_create1, do_epoll_ctl, |     do_accept, do_accept4, do_bind, do_connect, do_epoll_create, do_epoll_create1, do_epoll_ctl, | ||||||
|     do_epoll_pwait, do_epoll_wait, do_getpeername, do_getsockname, do_getsockopt, do_listen, |     do_epoll_pwait, do_epoll_wait, do_getpeername, do_getsockname, do_getsockopt, do_listen, | ||||||
|     do_poll, do_ppoll, do_pselect6, do_recvfrom, do_recvmsg, do_select, do_sendmmsg, do_sendmsg, |     do_poll, do_ppoll, do_pselect6, do_recvfrom, do_recvmsg, do_select, do_sendmmsg, do_sendmsg, | ||||||
|     do_sendto, do_setsockopt, do_shutdown, do_socket, do_socketpair, mmsghdr, msghdr, msghdr_mut, |     do_sendto, do_setsockopt, do_shutdown, do_socket, do_socketpair, mmsghdr, sigset_argpack, | ||||||
|     sigset_argpack, |  | ||||||
| }; | }; | ||||||
| use crate::process::{ | use crate::process::{ | ||||||
|     do_arch_prctl, do_clone, do_execve, do_exit, do_exit_group, do_futex, do_get_robust_list, |     do_arch_prctl, do_clone, do_execve, do_exit, do_exit_group, do_futex, do_get_robust_list, | ||||||
| @ -143,8 +142,8 @@ macro_rules! process_syscall_table_with_callback { | |||||||
|             (Accept = 43) => do_accept(fd: c_int, addr: *mut libc::sockaddr, addr_len: *mut libc::socklen_t), |             (Accept = 43) => do_accept(fd: c_int, addr: *mut libc::sockaddr, addr_len: *mut libc::socklen_t), | ||||||
|             (Sendto = 44) => do_sendto(fd: c_int, base: *const c_void, len: size_t, flags: c_int, addr: *const libc::sockaddr, addr_len: libc::socklen_t), |             (Sendto = 44) => do_sendto(fd: c_int, base: *const c_void, len: size_t, flags: c_int, addr: *const libc::sockaddr, addr_len: libc::socklen_t), | ||||||
|             (Recvfrom = 45) => do_recvfrom(fd: c_int, base: *mut c_void, len: size_t, flags: c_int, addr: *mut libc::sockaddr, addr_len: *mut libc::socklen_t), |             (Recvfrom = 45) => do_recvfrom(fd: c_int, base: *mut c_void, len: size_t, flags: c_int, addr: *mut libc::sockaddr, addr_len: *mut libc::socklen_t), | ||||||
|             (Sendmsg = 46) => do_sendmsg(fd: c_int, msg_ptr: *const msghdr, flags_c: c_int), |             (Sendmsg = 46) => do_sendmsg(fd: c_int, msg_ptr: *const libc::msghdr, flags_c: c_int), | ||||||
|             (Recvmsg = 47) => do_recvmsg(fd: c_int, msg_mut_ptr: *mut msghdr_mut, flags_c: c_int), |             (Recvmsg = 47) => do_recvmsg(fd: c_int, msg_mut_ptr: *mut libc::msghdr, flags_c: c_int), | ||||||
|             (Shutdown = 48) => do_shutdown(fd: c_int, how: c_int), |             (Shutdown = 48) => do_shutdown(fd: c_int, how: c_int), | ||||||
|             (Bind = 49) => do_bind(fd: c_int, addr: *const libc::sockaddr, addr_len: libc::socklen_t), |             (Bind = 49) => do_bind(fd: c_int, addr: *const libc::sockaddr, addr_len: libc::socklen_t), | ||||||
|             (Listen = 50) => do_listen(fd: c_int, backlog: c_int), |             (Listen = 50) => do_listen(fd: c_int, backlog: c_int), | ||||||
|  | |||||||
| @ -54,6 +54,7 @@ LINK_FLAGS += -lsgx_quote_ex_sim | |||||||
| endif | endif | ||||||
| endif | endif | ||||||
| 
 | 
 | ||||||
|  | LINK_FLAGS += -L$(PROJECT_DIR)/deps/io-uring/ocalls/target/release/ -lsgx_io_uring_ocalls | ||||||
| ALL_BUILD_SUBDIRS := $(sort $(patsubst %/,%,$(dir $(LIBOCCLUM_PAL_SO_REAL) $(EDL_C_OBJS) $(C_OBJS) $(CXX_OBJS) $(VDSO_OBJS)))) | ALL_BUILD_SUBDIRS := $(sort $(patsubst %/,%,$(dir $(LIBOCCLUM_PAL_SO_REAL) $(EDL_C_OBJS) $(C_OBJS) $(CXX_OBJS) $(VDSO_OBJS)))) | ||||||
| 
 | 
 | ||||||
| .PHONY: all format format-check clean | .PHONY: all format format-check clean | ||||||
| @ -79,7 +80,8 @@ $(OBJ_DIR)/pal/$(SRC_OBJ)/Enclave_u.c: $(SGX_EDGER8R) ../Enclave.edl | |||||||
| 		$(SGX_EDGER8R) $(SGX_EDGER8R_MODE) --untrusted $(CUR_DIR)/../Enclave.edl \
 | 		$(SGX_EDGER8R) $(SGX_EDGER8R_MODE) --untrusted $(CUR_DIR)/../Enclave.edl \
 | ||||||
| 		--search-path $(SGX_SDK)/include \
 | 		--search-path $(SGX_SDK)/include \
 | ||||||
| 		--search-path $(RUST_SGX_SDK_DIR)/edl/ \
 | 		--search-path $(RUST_SGX_SDK_DIR)/edl/ \
 | ||||||
| 		--search-path $(CRATES_DIR)/vdso-time/ocalls | 		--search-path $(CRATES_DIR)/vdso-time/ocalls \
 | ||||||
|  | 		--search-path $(PROJECT_DIR)/deps/io-uring/ocalls | ||||||
| 	@echo "GEN <= $@" | 	@echo "GEN <= $@" | ||||||
| 
 | 
 | ||||||
| $(OBJ_DIR)/pal/$(SRC_OBJ)/%.o: src/%.c | $(OBJ_DIR)/pal/$(SRC_OBJ)/%.o: src/%.c | ||||||
|  | |||||||
		Loading…
	
		Reference in New Issue
	
	Block a user