diff --git a/src/Enclave.edl b/src/Enclave.edl index ca08ec0e..c1ada329 100644 --- a/src/Enclave.edl +++ b/src/Enclave.edl @@ -79,6 +79,24 @@ enclave { * EPERM - No permission to send the signal or to the process. */ public int occlum_ecall_kill(int pid, int sig); + + + /* + * Broadcast interrupts to LibOS threads. + * + * Interrupts are sent to whichever LibOS threads that have pending + * events (e.g., signals). By interrupting the execution of these + * threads, the LibOS is given a chance to handle these events despite + * of the specific user workloads. + * + * @retval On success, return a non-negative value, which is the number + * of LibOS threads to which interrupts are sent. On error, return + * -errno. + * + * The possible values of errno are + * EAGAIN - The LibOS is not initialized. + */ + public int occlum_ecall_broadcast_interrupts(void); }; untrusted { @@ -100,7 +118,10 @@ enclave { void occlum_ocall_clock_getres(clockid_t clockid, [out] struct timespec* res); void occlum_ocall_rdtsc([out] uint32_t* low, [out] uint32_t* high); - void occlum_ocall_nanosleep([in] const struct timespec* req); + int occlum_ocall_nanosleep( + [in] const struct timespec* req, + [out] struct timespec* rem + ) propagate_errno; void occlum_ocall_sync(void); @@ -181,5 +202,7 @@ enclave { [in, out, size=len] void *arg, size_t len ) propagate_errno; + + int occlum_ocall_tkill(int tid, int signum) propagate_errno; }; }; diff --git a/src/libos/src/entry.rs b/src/libos/src/entry.rs index ce8d9870..120dd92a 100644 --- a/src/libos/src/entry.rs +++ b/src/libos/src/entry.rs @@ -6,6 +6,7 @@ use std::sync::Once; use super::*; use crate::exception::*; use crate::fs::HostStdioFds; +use crate::interrupt; use crate::process::ProcessFilter; use crate::signal::SigNum; use crate::time::up_time::init; @@ -65,6 +66,7 @@ pub extern "C" fn occlum_ecall_init(log_level: *const c_char, instance_dir: *con // Register exception handlers (support cpuid & rdtsc for now) register_exception_handlers(); + unsafe { let dir_str: &str = CStr::from_ptr(instance_dir).to_str().unwrap(); INSTANCE_DIR.push_str(dir_str); @@ -72,6 +74,8 @@ pub extern "C" fn occlum_ecall_init(log_level: *const c_char, instance_dir: *con ENCLAVE_PATH.push_str("/build/lib/libocclum-libos.signed.so"); } + interrupt::init(); + HAS_INIT.store(true, Ordering::SeqCst); // Init boot up time stamp here. @@ -156,6 +160,25 @@ pub extern "C" fn occlum_ecall_kill(pid: i32, sig: i32) -> i32 { .unwrap_or(ecall_errno!(EFAULT)) } +#[no_mangle] +pub extern "C" fn occlum_ecall_broadcast_interrupts() -> i32 { + if HAS_INIT.load(Ordering::SeqCst) == false { + return ecall_errno!(EAGAIN); + } + + let _ = unsafe { backtrace::enable_backtrace(&ENCLAVE_PATH, PrintFormat::Short) }; + panic::catch_unwind(|| { + backtrace::__rust_begin_short_backtrace(|| match interrupt::broadcast_interrupts() { + Ok(count) => count as i32, + Err(e) => { + eprintln!("failed to broadcast interrupts: {}", e.backtrace()); + ecall_errno!(e.errno()) + } + }) + }) + .unwrap_or(ecall_errno!(EFAULT)) +} + fn parse_log_level(level_chars: *const c_char) -> Result { const DEFAULT_LEVEL: LevelFilter = LevelFilter::Off; diff --git a/src/libos/src/interrupt/mod.rs b/src/libos/src/interrupt/mod.rs new file mode 100644 index 00000000..7e3bd374 --- /dev/null +++ b/src/libos/src/interrupt/mod.rs @@ -0,0 +1,97 @@ +use crate::prelude::*; +use crate::process::ThreadRef; +use crate::syscall::{CpuContext, SyscallNum}; + +pub use self::sgx::sgx_interrupt_info_t; + +mod sgx; + +pub fn init() { + unsafe { + let status = sgx::sgx_interrupt_init(handle_interrupt); + assert!(status == sgx_status_t::SGX_SUCCESS); + } +} + +extern "C" fn handle_interrupt(info: *mut sgx_interrupt_info_t) -> i32 { + extern "C" { + fn __occlum_syscall_c_abi(num: u32, info: *mut sgx_interrupt_info_t) -> u32; + } + unsafe { __occlum_syscall_c_abi(SyscallNum::HandleInterrupt as u32, info) }; + unreachable!(); +} + +pub fn do_handle_interrupt( + info: *mut sgx_interrupt_info_t, + cpu_context: *mut CpuContext, +) -> Result { + let info = unsafe { &*info }; + let context = unsafe { &mut *cpu_context }; + // The cpu context is overriden so that it is as if the syscall is called from where the + // interrupt happened + *context = CpuContext::from_sgx(&info.cpu_context); + Ok(0) +} + +/// Broadcast interrupts to threads by sending POSIX signals. +pub fn broadcast_interrupts() -> Result { + let should_interrupt_thread = |thread: &&ThreadRef| -> bool { + // TODO: check Thread::sig_mask to reduce false positives + thread.process().is_forced_to_exit() + || !thread.sig_queues().lock().unwrap().empty() + || !thread.process().sig_queues().lock().unwrap().empty() + }; + + let num_signaled_threads = crate::process::table::get_all_threads() + .iter() + .filter(should_interrupt_thread) + .map(|thread| { + let host_tid = { + let sched = thread.sched().lock().unwrap(); + match sched.host_tid() { + None => return false, + Some(host_tid) => host_tid, + } + }; + let signum = 64; // real-time signal 64 is used to notify interrupts + let is_signaled = unsafe { + let mut retval = 0; + let status = occlum_ocall_tkill(&mut retval, host_tid, signum); + assert!(status == sgx_status_t::SGX_SUCCESS); + if retval == 0 { + true + } else { + false + } + }; + is_signaled + }) + .filter(|&is_signaled| is_signaled) + .count(); + Ok(num_signaled_threads) +} + +extern "C" { + fn occlum_ocall_tkill(retval: &mut i32, host_tid: pid_t, signum: i32) -> sgx_status_t; +} + +pub fn enable_current_thread() { + // Interruptible range + let (addr, size) = { + let thread = current!(); + let vm = thread.vm(); + let range = vm.get_process_range(); + (range.start(), range.size()) + }; + unsafe { + let status = sgx::sgx_interrupt_enable(addr, size); + assert!(status == sgx_status_t::SGX_SUCCESS); + } +} + +pub fn disable_current_thread() { + unsafe { + let status = sgx::sgx_interrupt_disable(); + assert!(status == sgx_status_t::SGX_SUCCESS); + } +} diff --git a/src/libos/src/interrupt/sgx.rs b/src/libos/src/interrupt/sgx.rs new file mode 100644 index 00000000..f84993d8 --- /dev/null +++ b/src/libos/src/interrupt/sgx.rs @@ -0,0 +1,17 @@ +use crate::prelude::*; + +#[repr(C)] +#[derive(Default, Clone, Copy)] +#[allow(non_camel_case_types)] +pub struct sgx_interrupt_info_t { + pub cpu_context: sgx_cpu_context_t, +} + +#[allow(non_camel_case_types)] +pub type sgx_interrupt_handler_t = extern "C" fn(info: *mut sgx_interrupt_info_t) -> int32_t; + +extern "C" { + pub fn sgx_interrupt_init(handler: sgx_interrupt_handler_t) -> sgx_status_t; + pub fn sgx_interrupt_enable(code_addr: usize, code_size: usize) -> sgx_status_t; + pub fn sgx_interrupt_disable() -> sgx_status_t; +} diff --git a/src/libos/src/lib.rs b/src/libos/src/lib.rs index bc51de48..174aff72 100644 --- a/src/libos/src/lib.rs +++ b/src/libos/src/lib.rs @@ -58,6 +58,7 @@ mod config; mod entry; mod exception; mod fs; +mod interrupt; mod misc; mod net; mod process; diff --git a/src/libos/src/process/do_futex.rs b/src/libos/src/process/do_futex.rs index 15f08095..93c19083 100644 --- a/src/libos/src/process/do_futex.rs +++ b/src/libos/src/process/do_futex.rs @@ -125,7 +125,7 @@ pub fn futex_wait( // Must make sure that no locks are holded by this thread before wait drop(futex_bucket); - futex_item.wait_timeout(timeout) + futex_item.wait(timeout) } /// Do futex wake @@ -230,19 +230,14 @@ impl FutexItem { self.waiter.wake() } - pub fn wait_timeout(&self, timeout: &Option) -> Result<()> { - match timeout { - None => self.waiter.wait(), - Some(ts) => { - if let Err(e) = self.waiter.wait_timeout(&ts) { - let (_, futex_bucket_ref) = FUTEX_BUCKETS.get_bucket(self.key); - let mut futex_bucket = futex_bucket_ref.lock().unwrap(); - futex_bucket.dequeue_item(self); - return_errno!(e.errno(), "futex wait with timeout error"); - } - Ok(()) - } + pub fn wait(&self, timeout: &Option) -> Result<()> { + if let Err(e) = self.waiter.wait_timeout(&timeout) { + let (_, futex_bucket_ref) = FUTEX_BUCKETS.get_bucket(self.key); + let mut futex_bucket = futex_bucket_ref.lock().unwrap(); + futex_bucket.dequeue_item(self); + return_errno!(e.errno(), "futex wait timeout or interrupted"); } + Ok(()) } } @@ -367,18 +362,7 @@ impl Waiter { } } - pub fn wait(&self) -> Result<()> { - let current = unsafe { sgx_thread_get_self() }; - if current != self.thread { - return Ok(()); - } - while self.is_woken.load(Ordering::SeqCst) == false { - wait_event(self.thread); - } - Ok(()) - } - - pub fn wait_timeout(&self, timeout: ×pec_t) -> Result<()> { + pub fn wait_timeout(&self, timeout: &Option) -> Result<()> { let current = unsafe { sgx_thread_get_self() }; if current != self.thread { return Ok(()); @@ -386,8 +370,6 @@ impl Waiter { while self.is_woken.load(Ordering::SeqCst) == false { if let Err(e) = wait_event_timeout(self.thread, timeout) { self.is_woken.store(true, Ordering::SeqCst); - // Do sanity check here, only possible errnos here are ETIMEDOUT, EAGAIN and EINTR - debug_assert!(e.errno() == ETIMEDOUT || e.errno() == EAGAIN || e.errno() == EINTR); return_errno!(e.errno(), "wait_timeout error"); } } @@ -410,34 +392,31 @@ impl PartialEq for Waiter { unsafe impl Send for Waiter {} unsafe impl Sync for Waiter {} -fn wait_event(thread: *const c_void) { - let mut ret: c_int = 0; - let mut sgx_ret: c_int = 0; - unsafe { - sgx_ret = sgx_thread_wait_untrusted_event_ocall(&mut ret as *mut c_int, thread); - } - if ret != 0 || sgx_ret != 0 { - panic!("ERROR: sgx_thread_wait_untrusted_event_ocall failed"); - } -} - -fn wait_event_timeout(thread: *const c_void, timeout: ×pec_t) -> Result<()> { +fn wait_event_timeout(thread: *const c_void, timeout: &Option) -> Result<()> { let mut ret: c_int = 0; let mut sgx_ret: c_int = 0; + let timeout_ptr = timeout + .as_ref() + .map(|timeout_ref| timeout_ref as *const _) + .unwrap_or(0 as *const _); let mut errno: c_int = 0; unsafe { sgx_ret = sgx_thread_wait_untrusted_event_timeout_ocall( &mut ret as *mut c_int, thread, - timeout.sec(), - timeout.nsec(), + timeout_ptr, &mut errno as *mut c_int, ); - } - if ret != 0 || sgx_ret != 0 { - panic!("ERROR: sgx_thread_wait_untrusted_event_timeout_ocall failed"); + assert!(sgx_ret == 0); + assert!(ret == 0); } if errno != 0 { + // Do sanity check here, only possible errnos here are ETIMEDOUT, EAGAIN and EINTR + assert!( + (timeout.is_some() && errno == Errno::ETIMEDOUT as i32) + || errno == Errno::EINTR as i32 + || errno == Errno::EAGAIN as i32 + ); return_errno!( Errno::from(errno as u32), "sgx_thread_wait_untrusted_event_timeout_ocall error" @@ -460,14 +439,10 @@ fn set_event(thread: *const c_void) { extern "C" { fn sgx_thread_get_self() -> *const c_void; - /* Go outside and wait on my untrusted event */ - fn sgx_thread_wait_untrusted_event_ocall(ret: *mut c_int, self_thread: *const c_void) -> c_int; - fn sgx_thread_wait_untrusted_event_timeout_ocall( ret: *mut c_int, self_thread: *const c_void, - sec: c_long, - nsec: c_long, + ts: *const timespec_t, errno: *mut c_int, ) -> c_int; diff --git a/src/libos/src/process/table.rs b/src/libos/src/process/table.rs index 21a3ebe3..3284c3ff 100644 --- a/src/libos/src/process/table.rs +++ b/src/libos/src/process/table.rs @@ -14,6 +14,15 @@ pub fn get_all_processes() -> Vec { .collect() } +pub fn get_all_threads() -> Vec { + THREAD_TABLE + .lock() + .unwrap() + .iter() + .map(|(_, proc_ref)| proc_ref.clone()) + .collect() +} + pub(super) fn add_process(process: ProcessRef) -> Result<()> { PROCESS_TABLE.lock().unwrap().add(process.pid(), process) } diff --git a/src/libos/src/process/task/exec.rs b/src/libos/src/process/task/exec.rs index 178ca48e..9d8023fd 100644 --- a/src/libos/src/process/task/exec.rs +++ b/src/libos/src/process/task/exec.rs @@ -1,5 +1,6 @@ use super::super::{current, TermStatus, ThreadRef}; use super::Task; +use crate::interrupt; use crate::prelude::*; /// Enqueue a new thread so that it can be executed later. @@ -39,11 +40,15 @@ pub fn exec(libos_tid: pid_t, host_tid: pid_t) -> Result { // Enable current::get() from now on current::set(this_thread.clone()); + interrupt::enable_current_thread(); + unsafe { // task may only be modified by this function; so no lock is needed do_exec_task(this_thread.task() as *const Task as *mut Task); } + interrupt::disable_current_thread(); + let term_status = this_thread.inner().term_status().unwrap(); match term_status { TermStatus::Exited(status) => { diff --git a/src/libos/src/sched/sched_agent.rs b/src/libos/src/sched/sched_agent.rs index 6d0d006c..a5541fea 100644 --- a/src/libos/src/sched/sched_agent.rs +++ b/src/libos/src/sched/sched_agent.rs @@ -76,6 +76,13 @@ impl SchedAgent { Self { inner } } + pub fn host_tid(&self) -> Option { + match self.inner() { + Inner::Detached { .. } => None, + Inner::Attached { host_tid, .. } => Some(*host_tid), + } + } + pub fn affinity(&self) -> &CpuSet { match self.inner() { Inner::Detached { affinity } => affinity.as_ref(), diff --git a/src/libos/src/signal/sig_action.rs b/src/libos/src/signal/sig_action.rs index 556d2e4c..68b8c7fd 100644 --- a/src/libos/src/signal/sig_action.rs +++ b/src/libos/src/signal/sig_action.rs @@ -83,6 +83,9 @@ impl SigActionFlags { pub fn from_u32(bits: u32) -> Result { let flags = Self::from_bits(bits).ok_or_else(|| errno!(EINVAL, "invalid sigaction flags"))?; + if flags.contains(SigActionFlags::SA_RESTART) { + warn!("SA_RESTART is not supported"); + } Ok(flags) } diff --git a/src/libos/src/syscall/mod.rs b/src/libos/src/syscall/mod.rs index 2e8ae4f5..d4de72e9 100644 --- a/src/libos/src/syscall/mod.rs +++ b/src/libos/src/syscall/mod.rs @@ -26,6 +26,7 @@ use crate::fs::{ do_sync, do_truncate, do_unlink, do_write, do_writev, iovec_t, File, FileDesc, FileRef, HostStdioFds, Stat, }; +use crate::interrupt::{do_handle_interrupt, sgx_interrupt_info_t}; use crate::misc::{resource_t, rlimit_t, sysinfo_t, utsname_t}; use crate::net::{ do_accept, do_accept4, do_bind, do_connect, do_epoll_create, do_epoll_create1, do_epoll_ctl, @@ -407,8 +408,8 @@ macro_rules! process_syscall_table_with_callback { // Occlum-specific system calls (Spawn = 360) => do_spawn(child_pid_ptr: *mut u32, path: *const i8, argv: *const *const i8, envp: *const *const i8, fdop_list: *const FdOp), - // Exception handling (HandleException = 361) => do_handle_exception(info: *mut sgx_exception_info_t, context: *mut CpuContext), + (HandleInterrupt = 362) => do_handle_interrupt(info: *mut sgx_interrupt_info_t, context: *mut CpuContext), } }; } @@ -622,7 +623,10 @@ fn do_syscall(user_context: &mut CpuContext) { } else if syscall_num == SyscallNum::HandleException { // syscall.args[0] == info syscall.args[1] = user_context as *mut _ as isize; - } else if syscall_num == SyscallNum::Sigaltstack { + } else if syscall.num == SyscallNum::HandleInterrupt { + // syscall.args[0] == info + syscall.args[1] = user_context as *mut _ as isize; + } else if syscall.num == SyscallNum::Sigaltstack { // syscall.args[0] == new_ss // syscall.args[1] == old_ss syscall.args[2] = user_context as *const _ as isize; @@ -686,13 +690,16 @@ fn do_syscall(user_context: &mut CpuContext) { trace!("Retval = {:?}", retval); // Put the return value into user_context.rax, except for syscalls that may - // modify user_context directly. Currently, there are two such syscalls: - // SigReturn and HandleException. + // modify user_context directly. Currently, there are three such syscalls: + // SigReturn, HandleException, and HandleInterrupt. // // Sigreturn restores `user_context` to the state when the last signal // handler is executed. So in the case of sigreturn, `user_context` should // be kept intact. - if num != SyscallNum::RtSigreturn as u32 && num != SyscallNum::HandleException as u32 { + if num != SyscallNum::RtSigreturn as u32 + && num != SyscallNum::HandleException as u32 + && num != SyscallNum::HandleInterrupt as u32 + { user_context.rax = retval as u64; } @@ -807,13 +814,17 @@ fn do_clock_getres(clockid: clockid_t, res_u: *mut timespec_t) -> Result // TODO: handle remainder fn do_nanosleep(req_u: *const timespec_t, rem_u: *mut timespec_t) -> Result { - check_ptr(req_u)?; - if !rem_u.is_null() { + let req = { + check_ptr(req_u)?; + timespec_t::from_raw_ptr(req_u)? + }; + let rem = if !rem_u.is_null() { check_mut_ptr(rem_u)?; - } - - let req = timespec_t::from_raw_ptr(req_u)?; - time::do_nanosleep(&req)?; + Some(unsafe { &mut *rem_u }) + } else { + None + }; + time::do_nanosleep(&req, rem)?; Ok(0) } diff --git a/src/libos/src/time/mod.rs b/src/libos/src/time/mod.rs index c5161f5e..92e815a9 100644 --- a/src/libos/src/time/mod.rs +++ b/src/libos/src/time/mod.rs @@ -160,12 +160,27 @@ pub fn do_clock_getres(clockid: ClockID) -> Result { Ok(res) } -pub fn do_nanosleep(req: ×pec_t) -> Result<()> { +pub fn do_nanosleep(req: ×pec_t, rem: Option<&mut timespec_t>) -> Result<()> { extern "C" { - fn occlum_ocall_nanosleep(req: *const timespec_t) -> sgx_status_t; + fn occlum_ocall_nanosleep( + ret: *mut i32, + req: *const timespec_t, + rem: *mut timespec_t, + ) -> sgx_status_t; } unsafe { - occlum_ocall_nanosleep(req as *const timespec_t); + let mut ret = 0; + let mut u_rem: timespec_t = timespec_t { sec: 0, nsec: 0 }; + let sgx_status = occlum_ocall_nanosleep(&mut ret, req, &mut u_rem); + assert!(sgx_status == sgx_status_t::SGX_SUCCESS); + assert!(ret == 0 || libc::errno() == Errno::EINTR as i32); + if ret != 0 { + assert!(u_rem.as_duration() <= req.as_duration()); + if let Some(rem) = rem { + *rem = u_rem; + } + return_errno!(EINTR, "sleep interrupted"); + } } Ok(()) } diff --git a/src/pal/src/ocalls/signal.c b/src/pal/src/ocalls/signal.c new file mode 100644 index 00000000..69f08072 --- /dev/null +++ b/src/pal/src/ocalls/signal.c @@ -0,0 +1,7 @@ +#include "ocalls.h" + +int occlum_ocall_tkill(int tid, int signum) { + int tgid = getpid(); + int ret = tgkill(tgid, tid, signum); + return ret; +} diff --git a/src/pal/src/ocalls/spawn.c b/src/pal/src/ocalls/spawn.c index b4656537..7b1c0a8f 100644 --- a/src/pal/src/ocalls/spawn.c +++ b/src/pal/src/ocalls/spawn.c @@ -1,6 +1,7 @@ #include #include #include "ocalls.h" +#include "../pal_thread_counter.h" typedef struct { sgx_enclave_id_t enclave_id; @@ -17,11 +18,13 @@ void *exec_libos_thread(void *_thread_data) { host_tid); if (status != SGX_SUCCESS) { const char *sgx_err = pal_get_sgx_error_msg(status); - PAL_ERROR("Failed to enter the enclave to execute a LibOS thread: %s", sgx_err); + PAL_ERROR("Failed to enter the enclave to execute a LibOS thread (host tid = %d): %s", + host_tid, sgx_err); exit(EXIT_FAILURE); } free(thread_data); + pal_thread_counter_dec(); return NULL; } @@ -34,13 +37,15 @@ int occlum_ocall_exec_thread_async(int libos_tid) { thread_data->enclave_id = pal_get_enclave_id(); thread_data->libos_tid = libos_tid; + pal_thread_counter_inc(); if ((ret = pthread_create(&thread, NULL, exec_libos_thread, thread_data)) < 0) { + pal_thread_counter_dec(); free(thread_data); return -1; } pthread_detach(thread); - // Note: thread_data is freed just before the thread exits + // Note: thread_data is freed and thread counter is decreased just before the thread exits return 0; } diff --git a/src/pal/src/ocalls/time.c b/src/pal/src/ocalls/time.c index c8e852e8..7adb7fca 100644 --- a/src/pal/src/ocalls/time.c +++ b/src/pal/src/ocalls/time.c @@ -14,8 +14,8 @@ void occlum_ocall_clock_getres(int clockid, struct timespec *res) { clock_getres(clockid, res); } -void occlum_ocall_nanosleep(const struct timespec *req) { - nanosleep(req, NULL); +int occlum_ocall_nanosleep(const struct timespec *req, struct timespec *rem) { + return nanosleep(req, rem); } int occlum_ocall_thread_getcpuclock(struct timespec *tp) { diff --git a/src/pal/src/pal_api.c b/src/pal/src/pal_api.c index c4d358a6..11f6f483 100644 --- a/src/pal/src/pal_api.c +++ b/src/pal/src/pal_api.c @@ -2,8 +2,10 @@ #include "Enclave_u.h" #include "pal_enclave.h" #include "pal_error.h" +#include "pal_interrupt_thread.h" #include "pal_log.h" #include "pal_syscall.h" +#include "pal_thread_counter.h" #include "errno2str.h" int occlum_pal_get_version(void) { @@ -11,8 +13,6 @@ int occlum_pal_get_version(void) { } int occlum_pal_init(const struct occlum_pal_attr *attr) { - errno = 0; - if (attr == NULL) { errno = EINVAL; return -1; @@ -40,20 +40,30 @@ int occlum_pal_init(const struct occlum_pal_attr *attr) { if (ecall_status != SGX_SUCCESS) { const char *sgx_err = pal_get_sgx_error_msg(ecall_status); PAL_ERROR("Failed to do ECall: %s", sgx_err); - return -1; + goto on_destroy_enclave; } if (ecall_ret < 0) { errno = -ecall_ret; PAL_ERROR("occlum_ecall_init returns %s", errno2str(errno)); - return -1; + goto on_destroy_enclave; } + + if (pal_interrupt_thread_start() < 0) { + PAL_ERROR("Failed to start the interrupt thread: %s", errno2str(errno)); + goto on_destroy_enclave; + } + return 0; +on_destroy_enclave: + if (pal_destroy_enclave() < 0) { + PAL_WARN("Cannot destroy the enclave"); + } + return -1; } int occlum_pal_create_process(struct occlum_pal_create_process_args *args) { int ecall_ret = 0; // libos_tid - errno = 0; if (args->path == NULL || args->argv == NULL || args->pid == NULL) { errno = EINVAL; return -1; @@ -99,8 +109,10 @@ int occlum_pal_exec(struct occlum_pal_exec_args *args) { return -1; } + pal_thread_counter_inc(); sgx_status_t ecall_status = occlum_ecall_exec_thread(eid, &ecall_ret, args->pid, host_tid); + pal_thread_counter_dec(); if (ecall_status != SGX_SUCCESS) { const char *sgx_err = pal_get_sgx_error_msg(ecall_status); PAL_ERROR("Failed to do ECall: %s", sgx_err); @@ -113,12 +125,11 @@ int occlum_pal_exec(struct occlum_pal_exec_args *args) { } *args->exit_value = ecall_ret; + return 0; } int occlum_pal_kill(int pid, int sig) { - errno = 0; - sgx_enclave_id_t eid = pal_get_enclave_id(); if (eid == SGX_INVALID_ENCLAVE_ID) { errno = ENOENT; @@ -143,8 +154,6 @@ int occlum_pal_kill(int pid, int sig) { } int occlum_pal_destroy(void) { - errno = 0; - sgx_enclave_id_t eid = pal_get_enclave_id(); if (eid == SGX_INVALID_ENCLAVE_ID) { PAL_ERROR("Enclave is not initialized yet."); @@ -152,10 +161,17 @@ int occlum_pal_destroy(void) { return -1; } - if (pal_destroy_enclave() < 0) { - return -1; + int ret = 0; + + if (pal_interrupt_thread_stop() < 0) { + ret = -1; + PAL_WARN("Cannot stop the interrupt thread: %s", errno2str(errno)); } - return 0; + if (pal_destroy_enclave() < 0) { + ret = -1; + PAL_WARN("Cannot destroy the enclave"); + } + return ret; } int pal_get_version(void) __attribute__((weak, alias ("occlum_pal_get_version"))); diff --git a/src/pal/src/pal_enclave.c b/src/pal/src/pal_enclave.c index 5dc72273..b117c1a0 100644 --- a/src/pal/src/pal_enclave.c +++ b/src/pal/src/pal_enclave.c @@ -130,12 +130,8 @@ int pal_init_enclave(const char *instance_dir) { } int pal_destroy_enclave(void) { - // TODO: destroy the enclave gracefully - // We cannot destroy the enclave gracefully since we may still have - // running threads that are using the enclave at this point, which blocks - // sgx_destory_enclave call. We need to implement exit_group syscall and - // handle signal and exceptions properly. - //sgx_destroy_enclave(global_eid); + sgx_destroy_enclave(global_eid); + global_eid = SGX_INVALID_ENCLAVE_ID; return 0; } diff --git a/src/pal/src/pal_interrupt_thread.c b/src/pal/src/pal_interrupt_thread.c new file mode 100644 index 00000000..1202df71 --- /dev/null +++ b/src/pal/src/pal_interrupt_thread.c @@ -0,0 +1,81 @@ +#include +#include "Enclave_u.h" +#include "pal_enclave.h" +#include "pal_error.h" +#include "pal_interrupt_thread.h" +#include "pal_log.h" +#include "pal_syscall.h" +#include "pal_thread_counter.h" +#include "errno2str.h" + +#define MS (1000*1000L) // 1ms = 1,000,000ns + +static pthread_t thread; +static int is_running = 0; + +static void *thread_func(void *_data) { + sgx_enclave_id_t eid = pal_get_enclave_id(); + + int counter = 0; + do { + int num_broadcast_threads = 0; + sgx_status_t ecall_status = occlum_ecall_broadcast_interrupts(eid, + &num_broadcast_threads); + if (ecall_status != SGX_SUCCESS) { + const char *sgx_err = pal_get_sgx_error_msg(ecall_status); + PAL_ERROR("Failed to do ECall: occlum_ecall_broadcast_interrupts: %s", sgx_err); + exit(EXIT_FAILURE); + } + if (ecall_status == SGX_SUCCESS && num_broadcast_threads < 0) { + int errno_ = -num_broadcast_threads; + PAL_ERROR("Unexpcted error from cclum_ecall_broadcast_interrupts: %s", errno2str(errno_)); + exit(EXIT_FAILURE); + } + + struct timespec timeout = { .tv_sec = 0, .tv_nsec = 25 * MS }; + counter = pal_thread_counter_wait_zero(&timeout); + } while (counter > 0); + + return NULL; +} + +int pal_interrupt_thread_start(void) { + if (is_running) { + errno = EEXIST; + PAL_ERROR("The interrupt thread is already running: %s", errno2str(errno)); + return -1; + } + + is_running = 1; + pal_thread_counter_inc(); + + int ret = 0; + if ((ret = pthread_create(&thread, NULL, thread_func, NULL))) { + is_running = 0; + pal_thread_counter_dec(); + + errno = ret; + PAL_ERROR("Failed to start the interrupt thread: %s", errno2str(errno)); + return -1; + } + return 0; +} + +int pal_interrupt_thread_stop(void) { + if (!is_running) { + errno = ENOENT; + return -1; + } + + is_running = 0; + pal_thread_counter_dec(); + + int ret = 0; + if ((ret = pthread_join(thread, NULL))) { + errno = ret; + PAL_ERROR("Failed to free the interrupt thread: %s", errno2str(errno)); + return -1; + } + + return 0; +} diff --git a/src/pal/src/pal_interrupt_thread.h b/src/pal/src/pal_interrupt_thread.h new file mode 100644 index 00000000..0a2dccf5 --- /dev/null +++ b/src/pal/src/pal_interrupt_thread.h @@ -0,0 +1,8 @@ +#ifndef __PAL_INTERRUPT_H__ +#define __PAL_INTERRUPT_H__ + +int pal_interrupt_thread_start(void); + +int pal_interrupt_thread_stop(void); + +#endif /* __PAL_INTERRUPT_H__ */ diff --git a/src/pal/src/pal_syscall.h b/src/pal/src/pal_syscall.h index c29624fb..13aa145b 100644 --- a/src/pal/src/pal_syscall.h +++ b/src/pal/src/pal_syscall.h @@ -2,9 +2,15 @@ #define __PAL_SYSCALL_H__ #define _GNU_SOURCE +#include +#include +#include #include #include -#define gettid() syscall(__NR_gettid) +#define gettid() ((pid_t)syscall(__NR_gettid)) +#define tgkill(tgid, tid, signum) ((int)syscall(__NR_tgkill, (tgid), (tid), (signum))); +#define futex_wait(addr, val, timeout) ((int)syscall(__NR_futex, (addr), FUTEX_WAIT, (val), (timeout))) +#define futex_wake(addr) ((int)syscall(__NR_futex, (addr), FUTEX_WAKE, 1)) #endif /* __PAL_SYSCALL_H__ */ diff --git a/src/pal/src/pal_thread_counter.c b/src/pal/src/pal_thread_counter.c new file mode 100644 index 00000000..96756891 --- /dev/null +++ b/src/pal/src/pal_thread_counter.c @@ -0,0 +1,31 @@ +#include +#include "pal_syscall.h" +#include "pal_thread_counter.h" + +volatile int pal_thread_counter = 0; + +void pal_thread_counter_inc(void) { + __atomic_add_fetch(&pal_thread_counter, 1, __ATOMIC_SEQ_CST); +} + +void pal_thread_counter_dec(void) { + int val = __atomic_sub_fetch(&pal_thread_counter, 1, __ATOMIC_SEQ_CST); + assert(val >= 0); + + (void)futex_wake(&pal_thread_counter); +} + +int pal_thread_counter_get(void) { + return __atomic_load_n(&pal_thread_counter, __ATOMIC_SEQ_CST); +} + +int pal_thread_counter_wait_zero(const struct timespec *timeout) { + int old_val = pal_thread_counter_get(); + if (old_val == 0) { return 0; } + + (void)futex_wait(&pal_thread_counter, old_val, timeout); + + int new_val = pal_thread_counter_get(); + return new_val; +} + diff --git a/src/pal/src/pal_thread_counter.h b/src/pal/src/pal_thread_counter.h new file mode 100644 index 00000000..19b4830a --- /dev/null +++ b/src/pal/src/pal_thread_counter.h @@ -0,0 +1,20 @@ +#ifndef __PAL_THREAD_COUNTER_H__ +#define __PAL_THREAD_COUNTER_H__ + +#include + +// An atomic counter for threads + +// Increase the counter atomically +void pal_thread_counter_inc(void); + +// Decrease the counter atomically. Don't try to decrease the value below zero. +void pal_thread_counter_dec(void); + +// Get the value of the counter atomically +int pal_thread_counter_get(void); + +// Wait for counter to be zero until a timeout +int pal_thread_counter_wait_zero(const struct timespec *timeout); + +#endif /* __PAL_THREAD_COUNTER_H__ */ diff --git a/test/exit_group/main.c b/test/exit_group/main.c index fb541685..4402d601 100644 --- a/test/exit_group/main.c +++ b/test/exit_group/main.c @@ -1,6 +1,7 @@ #define _GNU_SOURCE #include #include +#include #include #include #include @@ -17,28 +18,23 @@ // Type 1: a busy loop thread static void *busyloop_thread_func(void *_) { - while (1) { - // By calling getpid, we give the LibOS a chance to force the thread - // to terminate if exit_group is called by any thread in a thread group - getpid(); - } + while (1) { } return NULL; } // Type 2: a sleeping thread -//static void* sleeping_thread_func(void* _) { -// unsigned int a_year_in_sec = 365 * 24 * 60 * 60; -// sleep(a_year_in_sec); -// return NULL; -//} +static void *sleeping_thread_func(void *_) { + unsigned int a_year_in_sec = 365 * 24 * 60 * 60; + sleep(a_year_in_sec); + return NULL; +} -// Type 3: a thead that keeps waiting on a futex -//static void* futex_wait_thread_func(void* _) { -// // Wait on a futex forever -// int my_private_futex = 0; -// syscall(SYS_futex, &my_private_futex, FUTEX_WAIT, my_private_futex); -// return NULL; -//} +// Type 3: a thead that waits on a futex FOREVER +static void *futex_wait_thread_func(void *_) { + int my_private_futex = 0; + syscall(SYS_futex, &my_private_futex, FUTEX_WAIT, my_private_futex, NULL); + return NULL; +} // exit_group syscall should terminate all threads in a thread group. int test_exit_group_to_force_threads_terminate(void) { @@ -48,22 +44,20 @@ int test_exit_group_to_force_threads_terminate(void) { printf("ERROR: pthread_create failed\n"); return -1; } - - // Disable below two test cases, needs interrupt support - // pthread_t sleeping_thread; - // if (pthread_create(&sleeping_thread, NULL, sleeping_thread_func, NULL) < 0) { - // printf("ERROR: pthread_create failed\n"); - // return -1; - // } - // pthread_t futex_wait_thread; - // if (pthread_create(&futex_wait_thread, NULL, futex_wait_thread_func, NULL) < 0) { - // printf("ERROR: pthread_create failed\n"); - // return -1; - // } + pthread_t sleeping_thread; + if (pthread_create(&sleeping_thread, NULL, sleeping_thread_func, NULL) < 0) { + printf("ERROR: pthread_create failed\n"); + return -1; + } + pthread_t futex_wait_thread; + if (pthread_create(&futex_wait_thread, NULL, futex_wait_thread_func, NULL) < 0) { + printf("ERROR: pthread_create failed\n"); + return -1; + } // Sleep for a while to make sure all three threads are running - useconds_t _200ms = 200 * 1000; - usleep(_200ms); + useconds_t half_second = 500 * 1000; // in us + usleep(half_second); // exit_group syscall will be called eventually by libc's exit, after the // main function returns. If Occlum can terminate normally, this means