diff --git a/src/libos/src/process/do_clone.rs b/src/libos/src/process/do_clone.rs index 6387e7e1..957676a9 100644 --- a/src/libos/src/process/do_clone.rs +++ b/src/libos/src/process/do_clone.rs @@ -50,6 +50,7 @@ pub fn do_clone( let files = current.files().clone(); let rlimits = current.rlimits().clone(); let fs = current.fs().clone(); + let name = current.name().clone(); let mut builder = ThreadBuilder::new() .process(current.process().clone()) @@ -57,6 +58,7 @@ pub fn do_clone( .task(task) .fs(fs) .files(files) + .name(name) .rlimits(rlimits); if let Some(ctid) = ctid { builder = builder.clear_ctid(ctid); diff --git a/src/libos/src/process/do_spawn/mod.rs b/src/libos/src/process/do_spawn/mod.rs index aeca3eb3..329a5332 100644 --- a/src/libos/src/process/do_spawn/mod.rs +++ b/src/libos/src/process/do_spawn/mod.rs @@ -6,6 +6,7 @@ use self::exec_loader::{load_exec_file_to_vec, load_file_to_vec}; use super::elf_file::{ElfFile, ElfHeader, ProgramHeader, ProgramHeaderExt}; use super::process::ProcessBuilder; use super::task::Task; +use super::thread::ThreadName; use super::{table, task, ProcessRef, ThreadRef}; use crate::fs::{ CreationFlags, File, FileDesc, FileTable, FsView, HostStdioFds, StdinFile, StdoutFile, @@ -180,6 +181,10 @@ fn new_process( let fs_ref = Arc::new(SgxMutex::new(current_ref.fs().lock().unwrap().clone())); let sched_ref = Arc::new(SgxMutex::new(current_ref.sched().lock().unwrap().clone())); + // Make the default thread name to be the process's corresponding elf file name + let elf_name = elf_path.rsplit('/').collect::>()[0]; + let thread_name = ThreadName::new(elf_name); + ProcessBuilder::new() .vm(vm_ref) .exec_path(&elf_path) @@ -188,6 +193,7 @@ fn new_process( .sched(sched_ref) .fs(fs_ref) .files(files_ref) + .name(thread_name) .build()? }; diff --git a/src/libos/src/process/mod.rs b/src/libos/src/process/mod.rs index a473ddf9..c9f45ff6 100644 --- a/src/libos/src/process/mod.rs +++ b/src/libos/src/process/mod.rs @@ -36,6 +36,7 @@ mod do_getpid; mod do_set_tid_address; mod do_spawn; mod do_wait4; +mod prctl; mod process; mod syscalls; mod term_status; diff --git a/src/libos/src/process/prctl/macros.rs b/src/libos/src/process/prctl/macros.rs new file mode 100644 index 00000000..d408c684 --- /dev/null +++ b/src/libos/src/process/prctl/macros.rs @@ -0,0 +1,36 @@ +// Macros to implement `PrctlCmd` given a list of prctl names, numbers, and argument types. + +// Implement `PrctlNum` and `PrctlCmd`. +#[macro_export] +macro_rules! impl_prctl_nums_and_cmds { + ($( $prctl_name: ident => ( $prctl_num: expr, $($prctl_type_tt: tt)* ) ),+,) => { + $(const $prctl_name:i32 = $prctl_num;)* + + impl_prctl_cmds! { + $( + $prctl_name => ( $($prctl_type_tt)*), + )* + } + } +} + +macro_rules! impl_prctl_cmds { + ($( $prctl_name: ident => ( $($prctl_type_tt: tt)* ) ),+,) => { + #[derive(Debug)] + #[allow(non_camel_case_types)] + pub enum PrctlCmd<'a> { + $( + $prctl_name( get_arg_type!($($prctl_type_tt)*) ), + )* + } + } +} + +macro_rules! get_arg_type { + (()) => { + () + }; + ($($prctl_type_tt: tt)*) => { + $($prctl_type_tt)* + }; +} diff --git a/src/libos/src/process/prctl/mod.rs b/src/libos/src/process/prctl/mod.rs new file mode 100644 index 00000000..c03735e6 --- /dev/null +++ b/src/libos/src/process/prctl/mod.rs @@ -0,0 +1,66 @@ +use alloc::string::String; +use alloc::vec::Vec; +use std::ffi::CString; +use std::os::raw::c_char; + +use super::thread::ThreadName; +use crate::prelude::*; +use crate::util::mem_util::from_user::{check_array, clone_cstring_safely}; + +#[macro_use] +mod macros; + +// Note: +// PrctlCmd has implied lifetime parameter `'a` +impl_prctl_nums_and_cmds! { + // Format: + // prctl_name => (prctl_num, prctl_type_arg, ... + PR_SET_NAME => (15, ThreadName), + // Get thread name + PR_GET_NAME => (16, &'a mut [u8]), +} + +impl<'a> PrctlCmd<'a> { + pub fn from_raw(cmd: i32, arg2: u64, arg3: u64, arg4: u64, arg5: u64) -> Result> { + Ok(match cmd { + PR_SET_NAME => { + check_array(arg2 as *const u8, ThreadName::max_len())?; + let raw_name = + unsafe { std::slice::from_raw_parts(arg2 as *const u8, ThreadName::max_len()) }; + let name = ThreadName::from_slice(raw_name); + PrctlCmd::PR_SET_NAME(name) + } + PR_GET_NAME => { + let buf_checked = { + check_array(arg2 as *mut u8, ThreadName::max_len())?; + unsafe { + std::slice::from_raw_parts_mut(arg2 as *mut u8, ThreadName::max_len()) + } + }; + PrctlCmd::PR_GET_NAME(buf_checked) + } + _ => { + debug!("prctl cmd num: {}", cmd); + return_errno!(EINVAL, "unsupported prctl command"); + } + }) + } +} + +pub fn do_prctl(cmd: PrctlCmd) -> Result { + debug!("prctl: {:?}", cmd); + + let current = current!(); + match cmd { + PrctlCmd::PR_SET_NAME(name) => { + current.set_name(name); + } + PrctlCmd::PR_GET_NAME(c_buf) => { + let name = current.name(); + c_buf.copy_from_slice(name.as_slice()); + } + _ => warn!("Prctl command not supported"), + } + + Ok(0) +} diff --git a/src/libos/src/process/process/builder.rs b/src/libos/src/process/process/builder.rs index e2e25902..2dc446d4 100644 --- a/src/libos/src/process/process/builder.rs +++ b/src/libos/src/process/process/builder.rs @@ -1,5 +1,5 @@ use super::super::task::Task; -use super::super::thread::{ThreadBuilder, ThreadId}; +use super::super::thread::{ThreadBuilder, ThreadId, ThreadName}; use super::super::{ FileTableRef, ForcedExitStatus, FsViewRef, ProcessRef, ProcessVMRef, ResourceLimitsRef, SchedAgentRef, @@ -77,6 +77,10 @@ impl ProcessBuilder { self.thread_builder(|tb| tb.rlimits(rlimits)) } + pub fn name(mut self, name: ThreadName) -> Self { + self.thread_builder(|tb| tb.name(name)) + } + pub fn build(mut self) -> Result { // Process's pid == Main thread's tid let tid = self.tid.take().unwrap_or_else(|| ThreadId::new()); diff --git a/src/libos/src/process/syscalls.rs b/src/libos/src/process/syscalls.rs index 1ca946d2..263220f0 100644 --- a/src/libos/src/process/syscalls.rs +++ b/src/libos/src/process/syscalls.rs @@ -1,13 +1,13 @@ -use std::ptr::NonNull; - use super::do_arch_prctl::ArchPrctlCode; use super::do_clone::CloneFlags; use super::do_futex::{FutexFlags, FutexOp}; use super::do_spawn::FileAction; +use super::prctl::PrctlCmd; use super::process::ProcessFilter; use crate::prelude::*; use crate::time::timespec_t; use crate::util::mem_util::from_user::*; +use std::ptr::NonNull; pub fn do_spawn( child_pid_ptr: *mut u32, @@ -173,6 +173,11 @@ pub fn do_futex( } } +pub fn do_prctl(option: i32, arg2: u64, arg3: u64, arg4: u64, arg5: u64) -> Result { + let prctl_cmd = super::prctl::PrctlCmd::from_raw(option, arg2, arg3, arg4, arg5)?; + super::prctl::do_prctl(prctl_cmd).map(|_| 0) +} + pub fn do_arch_prctl(code: u32, addr: *mut usize) -> Result { let code = ArchPrctlCode::from_u32(code)?; check_mut_ptr(addr)?; diff --git a/src/libos/src/process/thread/builder.rs b/src/libos/src/process/thread/builder.rs index 52ff3934..6cea5863 100644 --- a/src/libos/src/process/thread/builder.rs +++ b/src/libos/src/process/thread/builder.rs @@ -2,7 +2,7 @@ use std::ptr::NonNull; use super::{ FileTableRef, FsViewRef, ProcessRef, ProcessVM, ProcessVMRef, ResourceLimitsRef, SchedAgentRef, - SigQueues, SigSet, Task, Thread, ThreadId, ThreadInner, ThreadRef, + SigQueues, SigSet, Task, Thread, ThreadId, ThreadInner, ThreadName, ThreadRef, }; use crate::prelude::*; use crate::time::ThreadProfiler; @@ -20,6 +20,7 @@ pub struct ThreadBuilder { sched: Option, rlimits: Option, clear_ctid: Option>, + name: Option, } impl ThreadBuilder { @@ -34,6 +35,7 @@ impl ThreadBuilder { sched: None, rlimits: None, clear_ctid: None, + name: None, } } @@ -82,6 +84,11 @@ impl ThreadBuilder { self } + pub fn name(mut self, name: ThreadName) -> Self { + self.name = Some(name); + self + } + pub fn build(self) -> Result { let task = self .task @@ -99,6 +106,7 @@ impl ThreadBuilder { let files = self.files.unwrap_or_default(); let sched = self.sched.unwrap_or_default(); let rlimits = self.rlimits.unwrap_or_default(); + let name = SgxRwLock::new(self.name.unwrap_or_default()); let sig_queues = SgxMutex::new(SigQueues::new()); let sig_mask = SgxRwLock::new(SigSet::new_empty()); let sig_tmp_mask = SgxRwLock::new(SigSet::new_empty()); @@ -120,6 +128,7 @@ impl ThreadBuilder { files, sched, rlimits, + name, sig_queues, sig_mask, sig_tmp_mask, diff --git a/src/libos/src/process/thread/mod.rs b/src/libos/src/process/thread/mod.rs index a5317730..b613b74b 100644 --- a/src/libos/src/process/thread/mod.rs +++ b/src/libos/src/process/thread/mod.rs @@ -12,9 +12,11 @@ use crate::time::ThreadProfiler; pub use self::builder::ThreadBuilder; pub use self::id::ThreadId; +pub use self::name::ThreadName; mod builder; mod id; +mod name; pub struct Thread { // Low-level info @@ -24,6 +26,7 @@ pub struct Thread { // Mutable info clear_ctid: SgxRwLock>>, inner: SgxMutex, + name: SgxRwLock, // Process process: ProcessRef, // Resources @@ -131,6 +134,14 @@ impl Thread { *self.clear_ctid.write().unwrap() = new_clear_ctid; } + pub fn name(&self) -> ThreadName { + self.name.read().unwrap().clone() + } + + pub fn set_name(&self, new_name: ThreadName) { + *self.name.write().unwrap() = new_name; + } + pub(super) fn start(&self, host_tid: pid_t) { self.sched().lock().unwrap().attach(host_tid); self.inner().start(); diff --git a/src/libos/src/process/thread/name.rs b/src/libos/src/process/thread/name.rs new file mode 100644 index 00000000..9baf0b08 --- /dev/null +++ b/src/libos/src/process/thread/name.rs @@ -0,0 +1,52 @@ +use crate::prelude::*; + +use std::ffi::{CStr, CString}; + +// The thread name buffer should allow space for up to 16 bytes, including the terminating null byte. +const THREAD_NAME_MAX_LEN: usize = 16; + +/// A thread name represented in a fixed buffer of 16 bytes. +/// +/// The length is chosen to be consistent with Linux. +#[derive(Debug, Clone, Default)] +pub struct ThreadName { + buf: [u8; THREAD_NAME_MAX_LEN], + len: usize, // including null terminator +} + +impl ThreadName { + /// Construct a thread name from str + pub fn new(name: &str) -> Self { + Self::from_slice(CString::new(name).unwrap().as_bytes_with_nul()) + } + + pub const fn max_len() -> usize { + THREAD_NAME_MAX_LEN + } + + /// Construct a thread name from slice + pub fn from_slice(input: &[u8]) -> Self { + let mut buf = [0; THREAD_NAME_MAX_LEN]; + let mut len = THREAD_NAME_MAX_LEN; + for (i, b) in buf.iter_mut().take(THREAD_NAME_MAX_LEN - 1).enumerate() { + if input[i] == '\0' as u8 { + len = i + 1; + break; + } + *b = input[i]; + } + debug_assert!(buf[THREAD_NAME_MAX_LEN - 1] == 0); + Self { buf, len } + } + + /// Returns a byte slice + pub fn as_slice(&self) -> &[u8] { + &self.buf + } + + /// Converts to a CStr. + pub fn as_c_str(&self) -> &CStr { + // Note: from_bytes_with_nul will fail if slice has more than 1 '\0' at the end + CStr::from_bytes_with_nul(&self.buf[..self.len]).unwrap_or_default() + } +} diff --git a/src/libos/src/syscall/mod.rs b/src/libos/src/syscall/mod.rs index 3a898fbf..b3ee49d9 100644 --- a/src/libos/src/syscall/mod.rs +++ b/src/libos/src/syscall/mod.rs @@ -35,8 +35,8 @@ use crate::net::{ }; use crate::process::{ do_arch_prctl, do_clone, do_exit, do_exit_group, do_futex, do_getegid, do_geteuid, do_getgid, - do_getpgid, do_getpid, do_getppid, do_gettid, do_getuid, do_set_tid_address, do_spawn, - do_wait4, pid_t, FdOp, ThreadStatus, + do_getpgid, do_getpid, do_getppid, do_gettid, do_getuid, do_prctl, do_set_tid_address, + do_spawn, do_wait4, pid_t, FdOp, ThreadStatus, }; use crate::sched::{do_sched_getaffinity, do_sched_setaffinity, do_sched_yield}; use crate::signal::{ @@ -234,7 +234,7 @@ macro_rules! process_syscall_table_with_callback { (ModifyLdt = 154) => handle_unsupported(), (PivotRoot = 155) => handle_unsupported(), (SysCtl = 156) => handle_unsupported(), - (Prctl = 157) => handle_unsupported(), + (Prctl = 157) => do_prctl(option: i32, arg2: u64, arg3: u64, arg4: u64, arg5: u64), (ArchPrctl = 158) => do_arch_prctl(code: u32, addr: *mut usize), (Adjtimex = 159) => handle_unsupported(), (Setrlimit = 160) => handle_unsupported(), diff --git a/test/Makefile b/test/Makefile index 6febb32c..e9f8b682 100644 --- a/test/Makefile +++ b/test/Makefile @@ -18,7 +18,7 @@ TEST_DEPS := client data_sink TESTS ?= env empty hello_world malloc mmap file fs_perms getpid spawn sched pipe time \ truncate readdir mkdir open stat link symlink chmod chown tls pthread uname rlimit \ server server_epoll unix_socket cout hostfs cpuid rdtsc device sleep exit_group \ - ioctl fcntl eventfd emulate_syscall access signal sysinfo + ioctl fcntl eventfd emulate_syscall access signal sysinfo prctl # Benchmarks: need to be compiled and run by bench-% target BENCHES := spawn_and_exit_latency pipe_throughput unix_socket_throughput diff --git a/test/prctl/Makefile b/test/prctl/Makefile new file mode 100644 index 00000000..541855aa --- /dev/null +++ b/test/prctl/Makefile @@ -0,0 +1,5 @@ +include ../test_common.mk + +EXTRA_C_FLAGS := -Wno-stringop-truncation +EXTRA_LINK_FLAGS := +BIN_ARGS := diff --git a/test/prctl/main.c b/test/prctl/main.c new file mode 100644 index 00000000..ec22c90c --- /dev/null +++ b/test/prctl/main.c @@ -0,0 +1,140 @@ +#define _GNU_SOURCE +#include +#include +#include +#include +#include +#include +#include +#include "test.h" + +// ============================================================================ +// Helper function +// ============================================================================ +#define THREAD_NAME_LEN 16 + +extern char *program_invocation_short_name; +#define DEFAULT_NAME program_invocation_short_name // name of this executable + +static const char *LONG_NAME = "A very very long thread name that is over 16 bytes"; +static const char *NORMAL_NAME = "A thread name"; + +static int *test_thread_long_name(void *arg) { + char thread_name[THREAD_NAME_LEN] = {0}; + char correct_name[THREAD_NAME_LEN] = {0}; + + // Thread name can hold up to 16 bytes including null terminator + // Construct the "correct_name" from the "long_name" + strncpy(correct_name, LONG_NAME, THREAD_NAME_LEN - 1); + correct_name[THREAD_NAME_LEN - 1] = '\0'; + + if (prctl(PR_SET_NAME, LONG_NAME) != 0) { + printf("long name test set thread name error\n"); + return (int *) -1; + } + if (prctl(PR_GET_NAME, thread_name) != 0) { + printf("long name test set thread name error\n"); + return (int *) -1; + } + if (!strncmp(thread_name, correct_name, THREAD_NAME_LEN)) { + return NULL; + } else { + printf("test long thread name mismatch\n"); + return (int *) -1; + } +} + +static int *test_thread_normal_name(void *arg) { + char thread_name[THREAD_NAME_LEN] = {0}; + + if (prctl(PR_SET_NAME, NORMAL_NAME) != 0) { + printf("normal name test set thread name error\n"); + return (int *) -1; + }; + if (prctl(PR_GET_NAME, thread_name) != 0) { + printf("normal name test get thread name error\n"); + return (int *) -1; + } + if (!strncmp(thread_name, NORMAL_NAME, strlen(NORMAL_NAME))) { + return NULL; + } else { + printf("test normal thread name mismatch\n"); + return (int *) -1; + } +} + +static int *test_thread_default_name(void *arg) { + char thread_name[THREAD_NAME_LEN] = {0}; + + if (prctl(PR_GET_NAME, thread_name) != 0) { + printf("get thread default name error\n"); + return (int *) -1; + } + + // The DEFAULT_NAME could be longer than THREAD_NAME_LEN and thus will make the last byte + // to be the null-terminator. So we just compare length with "THREAD_NAME_LEN - 1" + if (!strncmp(thread_name, DEFAULT_NAME, THREAD_NAME_LEN - 1)) { + return NULL; + } else { + printf("test default thread name mismatch\n"); + return (int *) -1; + } +} + +// ============================================================================ +// Test cases +// ============================================================================ +static int test_prctl_set_get_long_name(void) { + pthread_t tid; + void *ret; + + if (pthread_create(&tid, NULL, (void *)test_thread_long_name, NULL)) { + THROW_ERROR("create test long name thread failed"); + } + pthread_join(tid, &ret); + if ((int *) ret) { + THROW_ERROR("test long name thread prctl error"); + } + return 0; +} + +static int test_prctl_set_get_normal_name(void) { + pthread_t tid; + void *ret; + + if (pthread_create(&tid, NULL, (void *)test_thread_normal_name, NULL)) { + THROW_ERROR("create test normal name thread failed"); + } + pthread_join(tid, &ret); + if ((int *) ret) { + THROW_ERROR("test normal name thread prctl error"); + } + return 0; +} + +static int test_prctl_get_default_thread_name(void) { + pthread_t tid; + void *ret; + + if (pthread_create(&tid, NULL, (void *)test_thread_default_name, NULL)) { + THROW_ERROR("create test default name thread failed"); + } + pthread_join(tid, &ret); + if ((int *) ret) { + THROW_ERROR("test default name thread prctl error"); + } + return 0; +} + +// ============================================================================ +// Test suite main +// ============================================================================ +static test_case_t test_cases[] = { + TEST_CASE(test_prctl_set_get_long_name), // over 16 bytes + TEST_CASE(test_prctl_set_get_normal_name), + TEST_CASE(test_prctl_get_default_thread_name), +}; + +int main() { + return test_suite_run(test_cases, ARRAY_SIZE(test_cases)); +}