Refactor the futex implementation

1. Use multiple futex buckets to reduce lock contention on futex data
structures
2. Add FUTEX_REQUEUE support
3. Add the condition variable test case
This commit is contained in:
LI Qing 2019-11-18 02:47:12 +00:00 committed by Tate, Hongliang Tian
parent b91566d486
commit 4ee3396152
4 changed files with 329 additions and 86 deletions

@ -1,4 +1,7 @@
use super::*; use super::*;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::intrinsics::atomic_load;
use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::atomic::{AtomicBool, Ordering};
/// `FutexOp`, `FutexFlags`, and `futex_op_and_flags_from_u32` are helper types and /// `FutexOp`, `FutexFlags`, and `futex_op_and_flags_from_u32` are helper types and
@ -66,29 +69,123 @@ pub fn futex_op_and_flags_from_u32(bits: u32) -> Result<(FutexOp, FutexFlags)> {
/// Do futex wait /// Do futex wait
pub fn futex_wait(futex_addr: *const i32, futex_val: i32) -> Result<()> { pub fn futex_wait(futex_addr: *const i32, futex_val: i32) -> Result<()> {
// Get and lock the futex bucket
let futex_key = FutexKey::new(futex_addr); let futex_key = FutexKey::new(futex_addr);
let futex_item = FUTEX_TABLE.lock().unwrap().get_or_new_item(futex_key); let (_, futex_bucket_ref) = FUTEX_BUCKETS.get_bucket(futex_key);
let mut futex_bucket = futex_bucket_ref.lock().unwrap();
futex_item.wait(futex_val); // Check the futex value
if futex_key.load_val() != futex_val {
return_errno!(EAGAIN, "futex value does not match")
}
// Why do we lock the bucket first and only then check the futex value?
//
// CPU 0 <Waiter> CPU 1 <Waker>
// (user mode)
// val = *futex_addr;
// syscall(FUTEX_WAIT);
// (kernel mode)
// futex_wait(futex_addr, val) {
// bucket = get_bucket();
// actual_val = *futex_addr;
// (user mode)
// *futex_addr = new_val;
// syscall(FUTEX_WAKE);
// (kernel mode)
// futex_wake(futex_addr) {
// bucket = get_bucket();
// lock(bucket);
// bucket.dequeue_and_wake_items()
// unlock(bucket);
// return;
// }
// if actual_val == val {
// lock(bucket);
// bucket.enqueue_item();
// unlock(bucket);
// wait();
// }
// }
// If the waiter on CPU 0 does not lock the bucket before checking the futex value,
// it can miss the transition of the futex value from val to new_val and still enqueue
// itself to the bucket after the waker has finished, so the waiter waits forever.
FUTEX_TABLE.lock().unwrap().put_item(futex_item); let futex_item = FutexItem::new(futex_key);
Ok(()) futex_bucket.enqueue_item(futex_item.clone());
// Must make sure that no locks are held by this thread before waiting
drop(futex_bucket);
futex_item.wait()
} }
/// Do futex wake /// Do futex wake
pub fn futex_wake(futex_addr: *const i32, max_count: usize) -> Result<usize> { pub fn futex_wake(futex_addr: *const i32, max_count: usize) -> Result<usize> {
// Get and lock the futex bucket
let futex_key = FutexKey::new(futex_addr); let futex_key = FutexKey::new(futex_addr);
let futex_item = FUTEX_TABLE.lock().unwrap().get_item(futex_key)?; let (_, futex_bucket_ref) = FUTEX_BUCKETS.get_bucket(futex_key);
let count = futex_item.wake(max_count); let mut futex_bucket = futex_bucket_ref.lock().unwrap();
FUTEX_TABLE.lock().unwrap().put_item(futex_item);
// Dequeue and wake up the items in the bucket
let count = futex_bucket.dequeue_and_wake_items(futex_key, max_count);
Ok(count) Ok(count)
} }
lazy_static! { /// Do futex requeue
static ref FUTEX_TABLE: SgxMutex<FutexTable> = { SgxMutex::new(FutexTable::new()) }; pub fn futex_requeue(
futex_addr: *const i32,
max_nwakes: usize,
max_nrequeues: usize,
futex_new_addr: *const i32,
) -> Result<usize> {
if futex_new_addr == futex_addr {
return futex_wake(futex_addr, max_nwakes);
}
let futex_key = FutexKey::new(futex_addr);
let futex_new_key = FutexKey::new(futex_new_addr);
let (bucket_idx, futex_bucket_ref) = FUTEX_BUCKETS.get_bucket(futex_key);
let (new_bucket_idx, futex_new_bucket_ref) = FUTEX_BUCKETS.get_bucket(futex_new_key);
let nwakes = {
if bucket_idx != new_bucket_idx {
let (mut futex_bucket, mut futex_new_bucket) = {
if bucket_idx < new_bucket_idx {
let mut futex_bucket = futex_bucket_ref.lock().unwrap();
let mut futex_new_bucket = futex_new_bucket_ref.lock().unwrap();
(futex_bucket, futex_new_bucket)
} else {
// bucket_idx > new_bucket_idx
let mut futex_new_bucket = futex_new_bucket_ref.lock().unwrap();
let mut futex_bucket = futex_bucket_ref.lock().unwrap();
(futex_bucket, futex_new_bucket)
}
};
let nwakes = futex_bucket.dequeue_and_wake_items(futex_key, max_nwakes);
futex_bucket.requeue_items_to_another_bucket(
futex_key,
&mut futex_new_bucket,
futex_new_key,
max_nrequeues,
);
nwakes
} else {
// bucket_idx == new_bucket_idx
let mut futex_bucket = futex_bucket_ref.lock().unwrap();
let nwakes = futex_bucket.dequeue_and_wake_items(futex_key, max_nwakes);
futex_bucket.update_item_keys(futex_key, futex_new_key, max_nrequeues);
nwakes
}
};
Ok(nwakes)
} }
#[derive(PartialEq, Eq, Hash, Copy, Clone)] // Make sure futex bucket count is the power of 2
// Number of futex buckets (1 << 8 = 256).
const BUCKET_COUNT: usize = 1 << 8;
// Bit mask used to reduce a hash value to a bucket index; `BUCKET_COUNT - 1`
// only selects every index in `0..BUCKET_COUNT` when the count is a power of 2.
const BUCKET_MASK: usize = BUCKET_COUNT - 1;
lazy_static! {
// The global vector of futex buckets, created with `BUCKET_COUNT` buckets.
static ref FUTEX_BUCKETS: FutexBucketVec = { FutexBucketVec::new(BUCKET_COUNT) };
}
#[derive(PartialEq, Copy, Clone)]
struct FutexKey(usize); struct FutexKey(usize);
impl FutexKey { impl FutexKey {
@ -97,97 +194,131 @@ impl FutexKey {
} }
pub fn load_val(&self) -> i32 { pub fn load_val(&self) -> i32 {
unsafe { *(self.0 as *const i32) } unsafe { atomic_load(self.0 as *const i32) }
}
pub fn addr(&self) -> usize {
self.0
} }
} }
#[derive(Clone)]
struct FutexItem { struct FutexItem {
key: FutexKey, key: FutexKey,
queue: SgxMutex<VecDeque<WaiterRef>>, waiter: WaiterRef,
} }
impl FutexItem { impl FutexItem {
pub fn new(key: FutexKey) -> FutexItem { pub fn new(key: FutexKey) -> FutexItem {
FutexItem { FutexItem {
key: key, key: key,
queue: SgxMutex::new(VecDeque::new()), waiter: Arc::new(Waiter::new()),
} }
} }
pub fn wake(&self, max_count: usize) -> usize { pub fn wake(&self) {
let mut queue = self.queue.lock().unwrap(); self.waiter.wake()
let mut count = 0;
while count < max_count {
let waiter = {
let waiter_option = queue.pop_front();
if waiter_option.is_none() {
break;
} }
waiter_option.unwrap()
}; pub fn wait(&self) -> Result<()> {
waiter.wake(); self.waiter.wait()
}
}
/// A futex bucket: a queue of `FutexItem`s.
struct FutexBucket {
// Items are appended at the back by `enqueue_item`.
queue: VecDeque<FutexItem>,
}
type FutexBucketRef = Arc<SgxMutex<FutexBucket>>;
impl FutexBucket {
/// Create a bucket whose item queue is empty.
pub fn new() -> FutexBucket {
    let queue = VecDeque::new();
    FutexBucket { queue }
}
/// Append `item` to the back of this bucket's queue.
pub fn enqueue_item(&mut self, item: FutexItem) {
    let pending = &mut self.queue;
    pending.push_back(item);
}
pub fn dequeue_and_wake_items(&mut self, key: FutexKey, max_count: usize) -> usize {
let mut count = 0;
let mut idx = 0;
while count < max_count && idx < self.queue.len() {
if key == self.queue[idx].key {
if let Some(item) = self.queue.swap_remove_back(idx) {
item.wake();
count += 1; count += 1;
} }
} else {
idx += 1;
}
}
count count
} }
pub fn wait(&self, futex_val: i32) -> () { pub fn update_item_keys(&mut self, key: FutexKey, new_key: FutexKey, max_count: usize) -> () {
let mut queue = self.queue.lock().unwrap(); let mut count = 0;
if self.key.load_val() != futex_val { for item in self.queue.iter_mut() {
return; if count == max_count {
break;
}
if (*item).key == key {
(*item).key = new_key;
count += 1;
}
}
} }
let waiter = Arc::new(Waiter::new()); pub fn requeue_items_to_another_bucket(
queue.push_back(waiter.clone()); &mut self,
drop(queue); key: FutexKey,
another: &mut Self,
// Must make sure that no locks are holded by this thread before sleep new_key: FutexKey,
waiter.wait(); max_nrequeues: usize,
) -> () {
let mut count = 0;
let mut idx = 0;
while count < max_nrequeues && idx < self.queue.len() {
if key == self.queue[idx].key {
if let Some(mut item) = self.queue.swap_remove_back(idx) {
item.key = new_key;
another.enqueue_item(item);
count += 1;
}
} else {
idx += 1;
}
}
} }
} }
type FutexItemRef = Arc<FutexItem>; struct FutexBucketVec {
vec: Vec<FutexBucketRef>,
struct FutexTable {
table: HashMap<FutexKey, FutexItemRef>,
} }
impl FutexTable { impl FutexBucketVec {
pub fn new() -> FutexTable { pub fn new(size: usize) -> FutexBucketVec {
FutexTable { let mut buckets = FutexBucketVec {
table: HashMap::new(), vec: Vec::with_capacity(size),
};
for idx in 0..size {
let bucket = Arc::new(SgxMutex::new(FutexBucket::new()));
buckets.vec.push(bucket);
} }
buckets
} }
pub fn get_or_new_item(&mut self, key: FutexKey) -> FutexItemRef { pub fn get_bucket(&self, key: FutexKey) -> (usize, FutexBucketRef) {
let table = &mut self.table; let idx = BUCKET_MASK & {
let item = table // The addr is the multiples of 4, so we ignore the last 2 bits
.entry(key) let addr = key.addr() >> 2;
.or_insert_with(|| Arc::new(FutexItem::new(key))); let mut s = DefaultHasher::new();
item.clone() addr.hash(&mut s);
} s.finish() as usize
};
pub fn get_item(&mut self, key: FutexKey) -> Result<FutexItemRef> { (idx, self.vec[idx].clone())
let table = &mut self.table;
table
.get_mut(&key)
.map(|item| item.clone())
.ok_or_else(|| errno!(ENOENT, "futex key cannot be found"))
}
pub fn put_item(&mut self, item: FutexItemRef) {
let table = &mut self.table;
// If there are only two references, one is the given argument, the
// other in the table, then it is time to release the futex item.
// This is because we are holding the lock of futex table and the
// reference count cannot be possibly increased by other threads.
if Arc::strong_count(&item) == 2 {
// Release the last but one reference
let key = item.key;
drop(item);
// Release the last reference
table.remove(&key);
}
} }
} }
@ -196,6 +327,7 @@ struct Waiter {
thread: *const c_void, thread: *const c_void,
is_woken: AtomicBool, is_woken: AtomicBool,
} }
type WaiterRef = Arc<Waiter>; type WaiterRef = Arc<Waiter>;
impl Waiter { impl Waiter {
@ -206,16 +338,22 @@ impl Waiter {
} }
} }
pub fn wait(&self) { pub fn wait(&self) -> Result<()> {
while self.is_woken.load(Ordering::SeqCst) != true { let current = unsafe { sgx_thread_get_self() };
if current != self.thread {
return Ok(());
}
while self.is_woken.load(Ordering::SeqCst) == false {
wait_event(self.thread); wait_event(self.thread);
} }
Ok(())
} }
pub fn wake(&self) { pub fn wake(&self) {
self.is_woken.store(true, Ordering::SeqCst); if self.is_woken.fetch_or(true, Ordering::SeqCst) == false {
set_event(self.thread); set_event(self.thread);
} }
}
} }
unsafe impl Send for Waiter {} unsafe impl Send for Waiter {}

@ -1,6 +1,8 @@
pub use self::arch_prctl::{do_arch_prctl, ArchPrctlCode}; pub use self::arch_prctl::{do_arch_prctl, ArchPrctlCode};
pub use self::exit::{do_exit, do_wait4, ChildProcessFilter}; pub use self::exit::{do_exit, do_wait4, ChildProcessFilter};
pub use self::futex::{futex_op_and_flags_from_u32, futex_wait, futex_wake, FutexFlags, FutexOp}; pub use self::futex::{
futex_op_and_flags_from_u32, futex_requeue, futex_wait, futex_wake, FutexFlags, FutexOp,
};
pub use self::process::{Status, IDLE_PROCESS}; pub use self::process::{Status, IDLE_PROCESS};
pub use self::process_table::get; pub use self::process_table::get;
pub use self::sched::{do_sched_getaffinity, do_sched_setaffinity, do_sched_yield, CpuSet}; pub use self::sched::{do_sched_getaffinity, do_sched_setaffinity, do_sched_yield, CpuSet};

@ -43,8 +43,15 @@ pub extern "C" fn dispatch_syscall(
arg5: isize, arg5: isize,
) -> isize { ) -> isize {
debug!( debug!(
"syscall {}: {:#x}, {:#x}, {:#x}, {:#x}, {:#x}, {:#x}", "syscall tid:{}, num:{}: {:#x}, {:#x}, {:#x}, {:#x}, {:#x}, {:#x}",
num, arg0, arg1, arg2, arg3, arg4, arg5 process::do_gettid(),
num,
arg0,
arg1,
arg2,
arg3,
arg4,
arg5
); );
#[cfg(feature = "syscall_timing")] #[cfg(feature = "syscall_timing")]
let time_start = { let time_start = {
@ -178,6 +185,8 @@ pub extern "C" fn dispatch_syscall(
arg0 as *const i32, arg0 as *const i32,
arg1 as u32, arg1 as u32,
arg2 as i32, arg2 as i32,
arg3 as i32,
arg4 as *const i32,
// TODO: accept other optional arguments // TODO: accept other optional arguments
), ),
SYS_ARCH_PRCTL => do_arch_prctl(arg0 as u32, arg1 as *mut usize), SYS_ARCH_PRCTL => do_arch_prctl(arg0 as u32, arg1 as *mut usize),
@ -467,20 +476,36 @@ pub fn do_clone(
Ok(child_pid as isize) Ok(child_pid as isize)
} }
pub fn do_futex(futex_addr: *const i32, futex_op: u32, futex_val: i32) -> Result<isize> { pub fn do_futex(
futex_addr: *const i32,
futex_op: u32,
futex_val: i32,
timeout: i32,
futex_new_addr: *const i32,
) -> Result<isize> {
check_ptr(futex_addr)?; check_ptr(futex_addr)?;
let (futex_op, futex_flags) = process::futex_op_and_flags_from_u32(futex_op)?; let (futex_op, futex_flags) = process::futex_op_and_flags_from_u32(futex_op)?;
let get_futex_val = |val| -> Result<usize> {
if val < 0 {
return_errno!(EINVAL, "the futex val must not be negative");
}
Ok(val as usize)
};
match futex_op { match futex_op {
FutexOp::FUTEX_WAIT => process::futex_wait(futex_addr, futex_val).map(|_| 0), FutexOp::FUTEX_WAIT => process::futex_wait(futex_addr, futex_val).map(|_| 0),
FutexOp::FUTEX_WAKE => { FutexOp::FUTEX_WAKE => {
let max_count = { let max_count = get_futex_val(futex_val)?;
if futex_val < 0 {
return_errno!(EINVAL, "the count must not be negative");
}
futex_val as usize
};
process::futex_wake(futex_addr, max_count).map(|count| count as isize) process::futex_wake(futex_addr, max_count).map(|count| count as isize)
} }
FutexOp::FUTEX_REQUEUE => {
check_ptr(futex_new_addr)?;
let max_nwakes = get_futex_val(futex_val)?;
let max_nrequeues = get_futex_val(timeout)?;
process::futex_requeue(futex_addr, max_nwakes, max_nrequeues, futex_new_addr)
.map(|nwakes| nwakes as isize)
}
_ => return_errno!(ENOSYS, "the futex operation is not supported"), _ => return_errno!(ENOSYS, "the futex operation is not supported"),
} }
} }

@ -85,12 +85,90 @@ static int test_mutex_with_concurrent_counter(void) {
return 0; return 0;
} }
// ============================================================================
// The test case of waiting condition variable
// ============================================================================
#define WAIT_ROUND (100000)
// Arguments passed to each `thread_cond_wait` worker thread.
struct thread_cond_arg {
// Index of the thread; printed in the worker's log messages
int ti;
// Shared value the worker waits on; it blocks while this is 0
volatile unsigned int* val;
// Shared counter that a worker increments once it has finished all rounds
volatile int* exit_thread_count;
// Condition variable the worker waits on
pthread_cond_t* cond_val;
// Mutex held while checking `val` and waiting on `cond_val`
pthread_mutex_t* mutex;
};
// Worker thread body.
//
// For WAIT_ROUND rounds, lock the mutex and block on the condition variable
// until the shared value is non-zero. After the last round, report completion
// by incrementing the shared exit counter.
//
// `_arg` points to a `struct thread_cond_arg`. Always returns NULL.
static void* thread_cond_wait(void* _arg) {
    struct thread_cond_arg* arg = _arg;
    printf("Thread #%d: start to wait on condition variable.\n", arg->ti);
    for (unsigned int i = 0; i < WAIT_ROUND; ++i) {
        pthread_mutex_lock(arg->mutex);
        // Re-check the predicate after every wakeup
        while (*(arg->val) == 0) {
            pthread_cond_wait(arg->cond_val, arg->mutex);
        }
        pthread_mutex_unlock(arg->mutex);
    }
    // The exit counter is shared by all worker threads. `volatile` does not
    // make the read-modify-write atomic, so the increment is serialized with
    // the mutex; a lost update would keep the counter from ever reaching the
    // number of threads.
    pthread_mutex_lock(arg->mutex);
    (*arg->exit_thread_count)++;
    pthread_mutex_unlock(arg->mutex);
    printf("Thread #%d: exited.\n", arg->ti);
    return NULL;
}
// Test case: many threads repeatedly waiting on one condition variable.
//
// Starts NTHREADS workers running `thread_cond_wait`, then keeps toggling the
// shared value and broadcasting on the condition variable until every worker
// has reported that it exited. Finally joins all workers.
//
// Returns 0 on success, -1 if a thread cannot be created or joined.
static int test_mutex_with_cond_wait(void) {
    volatile unsigned int val = 0;
    volatile int exit_thread_count = 0;
    pthread_t threads[NTHREADS];
    struct thread_cond_arg thread_args[NTHREADS];
    pthread_cond_t cond_val = PTHREAD_COND_INITIALIZER;
    pthread_mutex_t mutex = PTHREAD_MUTEX_INITIALIZER;
    /*
     * Start the threads waiting on the condition variable
     */
    for (int ti = 0; ti < NTHREADS; ti++) {
        struct thread_cond_arg* thread_arg = &thread_args[ti];
        thread_arg->ti = ti;
        thread_arg->val = &val;
        thread_arg->exit_thread_count = &exit_thread_count;
        thread_arg->cond_val = &cond_val;
        thread_arg->mutex = &mutex;
        // pthread_create reports failure with a non-zero (positive) error
        // number, so a `< 0` check would never detect it.
        if (pthread_create(&threads[ti], NULL, thread_cond_wait, thread_arg) != 0) {
            printf("ERROR: pthread_create failed (ti = %d)\n", ti);
            return -1;
        }
    }
    /*
     * Unblock all threads currently waiting on the condition variable
     */
    while (exit_thread_count < NTHREADS) {
        pthread_mutex_lock(&mutex);
        val = 1;
        pthread_cond_broadcast(&cond_val);
        pthread_mutex_unlock(&mutex);
        pthread_mutex_lock(&mutex);
        val = 0;
        pthread_mutex_unlock(&mutex);
    }
    /*
     * Wait for the threads to finish
     */
    for (int ti = 0; ti < NTHREADS; ti++) {
        // pthread_join also returns a non-zero error number on failure
        if (pthread_join(threads[ti], NULL) != 0) {
            printf("ERROR: pthread_join failed (ti = %d)\n", ti);
            return -1;
        }
    }
    return 0;
}
// ============================================================================ // ============================================================================
// Test suite main // Test suite main
// ============================================================================ // ============================================================================
static test_case_t test_cases[] = { static test_case_t test_cases[] = {
TEST_CASE(test_mutex_with_concurrent_counter) TEST_CASE(test_mutex_with_concurrent_counter),
TEST_CASE(test_mutex_with_cond_wait),
}; };
int main() { int main() {