Add prctl support of PR_SET/GET_NAME options

This commit is contained in:
Hui, Chunyang 2020-06-11 07:41:52 +00:00
parent b86d8ed490
commit 3cd46fd224
14 changed files with 345 additions and 8 deletions

@ -50,6 +50,7 @@ pub fn do_clone(
let files = current.files().clone();
let rlimits = current.rlimits().clone();
let fs = current.fs().clone();
let name = current.name().clone();
let mut builder = ThreadBuilder::new()
.process(current.process().clone())
@ -57,6 +58,7 @@ pub fn do_clone(
.task(task)
.fs(fs)
.files(files)
.name(name)
.rlimits(rlimits);
if let Some(ctid) = ctid {
builder = builder.clear_ctid(ctid);

@ -6,6 +6,7 @@ use self::exec_loader::{load_exec_file_to_vec, load_file_to_vec};
use super::elf_file::{ElfFile, ElfHeader, ProgramHeader, ProgramHeaderExt};
use super::process::ProcessBuilder;
use super::task::Task;
use super::thread::ThreadName;
use super::{table, task, ProcessRef, ThreadRef};
use crate::fs::{
CreationFlags, File, FileDesc, FileTable, FsView, HostStdioFds, StdinFile, StdoutFile,
@ -180,6 +181,10 @@ fn new_process(
let fs_ref = Arc::new(SgxMutex::new(current_ref.fs().lock().unwrap().clone()));
let sched_ref = Arc::new(SgxMutex::new(current_ref.sched().lock().unwrap().clone()));
// Make the default thread name to be the process's corresponding elf file name
let elf_name = elf_path.rsplit('/').collect::<Vec<&str>>()[0];
let thread_name = ThreadName::new(elf_name);
ProcessBuilder::new()
.vm(vm_ref)
.exec_path(&elf_path)
@ -188,6 +193,7 @@ fn new_process(
.sched(sched_ref)
.fs(fs_ref)
.files(files_ref)
.name(thread_name)
.build()?
};

@ -36,6 +36,7 @@ mod do_getpid;
mod do_set_tid_address;
mod do_spawn;
mod do_wait4;
mod prctl;
mod process;
mod syscalls;
mod term_status;

@ -0,0 +1,36 @@
// Macros to implement `PrctlCmd` given a list of prctl names, numbers, and argument types.
// Implement `PrctlNum` and `PrctlCmd`.
#[macro_export]
macro_rules! impl_prctl_nums_and_cmds {
($( $prctl_name: ident => ( $prctl_num: expr, $($prctl_type_tt: tt)* ) ),+,) => {
$(const $prctl_name:i32 = $prctl_num;)*
impl_prctl_cmds! {
$(
$prctl_name => ( $($prctl_type_tt)*),
)*
}
}
}
macro_rules! impl_prctl_cmds {
($( $prctl_name: ident => ( $($prctl_type_tt: tt)* ) ),+,) => {
#[derive(Debug)]
#[allow(non_camel_case_types)]
pub enum PrctlCmd<'a> {
$(
$prctl_name( get_arg_type!($($prctl_type_tt)*) ),
)*
}
}
}
macro_rules! get_arg_type {
(()) => {
()
};
($($prctl_type_tt: tt)*) => {
$($prctl_type_tt)*
};
}

@ -0,0 +1,66 @@
use alloc::string::String;
use alloc::vec::Vec;
use std::ffi::CString;
use std::os::raw::c_char;
use super::thread::ThreadName;
use crate::prelude::*;
use crate::util::mem_util::from_user::{check_array, clone_cstring_safely};
#[macro_use]
mod macros;
// Note:
// PrctlCmd has implied lifetime parameter `'a`
impl_prctl_nums_and_cmds! {
// Format:
// prctl_name => (prctl_num, prctl_type_arg, ...
PR_SET_NAME => (15, ThreadName),
// Get thread name
PR_GET_NAME => (16, &'a mut [u8]),
}
impl<'a> PrctlCmd<'a> {
pub fn from_raw(cmd: i32, arg2: u64, arg3: u64, arg4: u64, arg5: u64) -> Result<PrctlCmd<'a>> {
Ok(match cmd {
PR_SET_NAME => {
check_array(arg2 as *const u8, ThreadName::max_len())?;
let raw_name =
unsafe { std::slice::from_raw_parts(arg2 as *const u8, ThreadName::max_len()) };
let name = ThreadName::from_slice(raw_name);
PrctlCmd::PR_SET_NAME(name)
}
PR_GET_NAME => {
let buf_checked = {
check_array(arg2 as *mut u8, ThreadName::max_len())?;
unsafe {
std::slice::from_raw_parts_mut(arg2 as *mut u8, ThreadName::max_len())
}
};
PrctlCmd::PR_GET_NAME(buf_checked)
}
_ => {
debug!("prctl cmd num: {}", cmd);
return_errno!(EINVAL, "unsupported prctl command");
}
})
}
}
pub fn do_prctl(cmd: PrctlCmd) -> Result<isize> {
debug!("prctl: {:?}", cmd);
let current = current!();
match cmd {
PrctlCmd::PR_SET_NAME(name) => {
current.set_name(name);
}
PrctlCmd::PR_GET_NAME(c_buf) => {
let name = current.name();
c_buf.copy_from_slice(name.as_slice());
}
_ => warn!("Prctl command not supported"),
}
Ok(0)
}

@ -1,5 +1,5 @@
use super::super::task::Task;
use super::super::thread::{ThreadBuilder, ThreadId};
use super::super::thread::{ThreadBuilder, ThreadId, ThreadName};
use super::super::{
FileTableRef, ForcedExitStatus, FsViewRef, ProcessRef, ProcessVMRef, ResourceLimitsRef,
SchedAgentRef,
@ -77,6 +77,10 @@ impl ProcessBuilder {
self.thread_builder(|tb| tb.rlimits(rlimits))
}
pub fn name(mut self, name: ThreadName) -> Self {
self.thread_builder(|tb| tb.name(name))
}
pub fn build(mut self) -> Result<ProcessRef> {
// Process's pid == Main thread's tid
let tid = self.tid.take().unwrap_or_else(|| ThreadId::new());

@ -1,13 +1,13 @@
use std::ptr::NonNull;
use super::do_arch_prctl::ArchPrctlCode;
use super::do_clone::CloneFlags;
use super::do_futex::{FutexFlags, FutexOp};
use super::do_spawn::FileAction;
use super::prctl::PrctlCmd;
use super::process::ProcessFilter;
use crate::prelude::*;
use crate::time::timespec_t;
use crate::util::mem_util::from_user::*;
use std::ptr::NonNull;
pub fn do_spawn(
child_pid_ptr: *mut u32,
@ -173,6 +173,11 @@ pub fn do_futex(
}
}
pub fn do_prctl(option: i32, arg2: u64, arg3: u64, arg4: u64, arg5: u64) -> Result<isize> {
let prctl_cmd = super::prctl::PrctlCmd::from_raw(option, arg2, arg3, arg4, arg5)?;
super::prctl::do_prctl(prctl_cmd).map(|_| 0)
}
pub fn do_arch_prctl(code: u32, addr: *mut usize) -> Result<isize> {
let code = ArchPrctlCode::from_u32(code)?;
check_mut_ptr(addr)?;

@ -2,7 +2,7 @@ use std::ptr::NonNull;
use super::{
FileTableRef, FsViewRef, ProcessRef, ProcessVM, ProcessVMRef, ResourceLimitsRef, SchedAgentRef,
SigQueues, SigSet, Task, Thread, ThreadId, ThreadInner, ThreadRef,
SigQueues, SigSet, Task, Thread, ThreadId, ThreadInner, ThreadName, ThreadRef,
};
use crate::prelude::*;
use crate::time::ThreadProfiler;
@ -20,6 +20,7 @@ pub struct ThreadBuilder {
sched: Option<SchedAgentRef>,
rlimits: Option<ResourceLimitsRef>,
clear_ctid: Option<NonNull<pid_t>>,
name: Option<ThreadName>,
}
impl ThreadBuilder {
@ -34,6 +35,7 @@ impl ThreadBuilder {
sched: None,
rlimits: None,
clear_ctid: None,
name: None,
}
}
@ -82,6 +84,11 @@ impl ThreadBuilder {
self
}
pub fn name(mut self, name: ThreadName) -> Self {
self.name = Some(name);
self
}
pub fn build(self) -> Result<ThreadRef> {
let task = self
.task
@ -99,6 +106,7 @@ impl ThreadBuilder {
let files = self.files.unwrap_or_default();
let sched = self.sched.unwrap_or_default();
let rlimits = self.rlimits.unwrap_or_default();
let name = SgxRwLock::new(self.name.unwrap_or_default());
let sig_queues = SgxMutex::new(SigQueues::new());
let sig_mask = SgxRwLock::new(SigSet::new_empty());
let sig_tmp_mask = SgxRwLock::new(SigSet::new_empty());
@ -120,6 +128,7 @@ impl ThreadBuilder {
files,
sched,
rlimits,
name,
sig_queues,
sig_mask,
sig_tmp_mask,

@ -12,9 +12,11 @@ use crate::time::ThreadProfiler;
pub use self::builder::ThreadBuilder;
pub use self::id::ThreadId;
pub use self::name::ThreadName;
mod builder;
mod id;
mod name;
pub struct Thread {
// Low-level info
@ -24,6 +26,7 @@ pub struct Thread {
// Mutable info
clear_ctid: SgxRwLock<Option<NonNull<pid_t>>>,
inner: SgxMutex<ThreadInner>,
name: SgxRwLock<ThreadName>,
// Process
process: ProcessRef,
// Resources
@ -131,6 +134,14 @@ impl Thread {
*self.clear_ctid.write().unwrap() = new_clear_ctid;
}
pub fn name(&self) -> ThreadName {
self.name.read().unwrap().clone()
}
pub fn set_name(&self, new_name: ThreadName) {
*self.name.write().unwrap() = new_name;
}
pub(super) fn start(&self, host_tid: pid_t) {
self.sched().lock().unwrap().attach(host_tid);
self.inner().start();

@ -0,0 +1,52 @@
use crate::prelude::*;
use std::ffi::{CStr, CString};
// The thread name buffer should allow space for up to 16 bytes, including the terminating null byte.
const THREAD_NAME_MAX_LEN: usize = 16;
/// A thread name represented in a fixed buffer of 16 bytes.
///
/// The length is chosen to be consistent with Linux.
#[derive(Debug, Clone, Default)]
pub struct ThreadName {
buf: [u8; THREAD_NAME_MAX_LEN],
len: usize, // including null terminator
}
impl ThreadName {
/// Construct a thread name from str
pub fn new(name: &str) -> Self {
Self::from_slice(CString::new(name).unwrap().as_bytes_with_nul())
}
pub const fn max_len() -> usize {
THREAD_NAME_MAX_LEN
}
/// Construct a thread name from slice
pub fn from_slice(input: &[u8]) -> Self {
let mut buf = [0; THREAD_NAME_MAX_LEN];
let mut len = THREAD_NAME_MAX_LEN;
for (i, b) in buf.iter_mut().take(THREAD_NAME_MAX_LEN - 1).enumerate() {
if input[i] == '\0' as u8 {
len = i + 1;
break;
}
*b = input[i];
}
debug_assert!(buf[THREAD_NAME_MAX_LEN - 1] == 0);
Self { buf, len }
}
/// Returns a byte slice
pub fn as_slice(&self) -> &[u8] {
&self.buf
}
/// Converts to a CStr.
pub fn as_c_str(&self) -> &CStr {
// Note: from_bytes_with_nul will fail if slice has more than 1 '\0' at the end
CStr::from_bytes_with_nul(&self.buf[..self.len]).unwrap_or_default()
}
}

@ -35,8 +35,8 @@ use crate::net::{
};
use crate::process::{
do_arch_prctl, do_clone, do_exit, do_exit_group, do_futex, do_getegid, do_geteuid, do_getgid,
do_getpgid, do_getpid, do_getppid, do_gettid, do_getuid, do_set_tid_address, do_spawn,
do_wait4, pid_t, FdOp, ThreadStatus,
do_getpgid, do_getpid, do_getppid, do_gettid, do_getuid, do_prctl, do_set_tid_address,
do_spawn, do_wait4, pid_t, FdOp, ThreadStatus,
};
use crate::sched::{do_sched_getaffinity, do_sched_setaffinity, do_sched_yield};
use crate::signal::{
@ -234,7 +234,7 @@ macro_rules! process_syscall_table_with_callback {
(ModifyLdt = 154) => handle_unsupported(),
(PivotRoot = 155) => handle_unsupported(),
(SysCtl = 156) => handle_unsupported(),
(Prctl = 157) => handle_unsupported(),
(Prctl = 157) => do_prctl(option: i32, arg2: u64, arg3: u64, arg4: u64, arg5: u64),
(ArchPrctl = 158) => do_arch_prctl(code: u32, addr: *mut usize),
(Adjtimex = 159) => handle_unsupported(),
(Setrlimit = 160) => handle_unsupported(),

@ -18,7 +18,7 @@ TEST_DEPS := client data_sink
TESTS ?= env empty hello_world malloc mmap file fs_perms getpid spawn sched pipe time \
truncate readdir mkdir open stat link symlink chmod chown tls pthread uname rlimit \
server server_epoll unix_socket cout hostfs cpuid rdtsc device sleep exit_group \
ioctl fcntl eventfd emulate_syscall access signal sysinfo
ioctl fcntl eventfd emulate_syscall access signal sysinfo prctl
# Benchmarks: need to be compiled and run by bench-% target
BENCHES := spawn_and_exit_latency pipe_throughput unix_socket_throughput

5
test/prctl/Makefile Normal file

@ -0,0 +1,5 @@
include ../test_common.mk
EXTRA_C_FLAGS := -Wno-stringop-truncation
EXTRA_LINK_FLAGS :=
BIN_ARGS :=

140
test/prctl/main.c Normal file

@ -0,0 +1,140 @@
#define _GNU_SOURCE
#include <stdlib.h>
#include <unistd.h>
#include <stdio.h>
#include <pthread.h>
#include <string.h>
#include <sys/prctl.h>
#include <errno.h>
#include "test.h"
// ============================================================================
// Helper function
// ============================================================================
#define THREAD_NAME_LEN 16
extern char *program_invocation_short_name;
#define DEFAULT_NAME program_invocation_short_name // name of this executable
static const char *LONG_NAME = "A very very long thread name that is over 16 bytes";
static const char *NORMAL_NAME = "A thread name";
static int *test_thread_long_name(void *arg) {
char thread_name[THREAD_NAME_LEN] = {0};
char correct_name[THREAD_NAME_LEN] = {0};
// Thread name can hold up to 16 bytes including null terminator
// Construct the "correct_name" from the "long_name"
strncpy(correct_name, LONG_NAME, THREAD_NAME_LEN - 1);
correct_name[THREAD_NAME_LEN - 1] = '\0';
if (prctl(PR_SET_NAME, LONG_NAME) != 0) {
printf("long name test set thread name error\n");
return (int *) -1;
}
if (prctl(PR_GET_NAME, thread_name) != 0) {
printf("long name test set thread name error\n");
return (int *) -1;
}
if (!strncmp(thread_name, correct_name, THREAD_NAME_LEN)) {
return NULL;
} else {
printf("test long thread name mismatch\n");
return (int *) -1;
}
}
static int *test_thread_normal_name(void *arg) {
char thread_name[THREAD_NAME_LEN] = {0};
if (prctl(PR_SET_NAME, NORMAL_NAME) != 0) {
printf("normal name test set thread name error\n");
return (int *) -1;
};
if (prctl(PR_GET_NAME, thread_name) != 0) {
printf("normal name test get thread name error\n");
return (int *) -1;
}
if (!strncmp(thread_name, NORMAL_NAME, strlen(NORMAL_NAME))) {
return NULL;
} else {
printf("test normal thread name mismatch\n");
return (int *) -1;
}
}
static int *test_thread_default_name(void *arg) {
char thread_name[THREAD_NAME_LEN] = {0};
if (prctl(PR_GET_NAME, thread_name) != 0) {
printf("get thread default name error\n");
return (int *) -1;
}
// The DEFAULT_NAME could be longer than THREAD_NAME_LEN and thus will make the last byte
// to be the null-terminator. So we just compare length with "THREAD_NAME_LEN - 1"
if (!strncmp(thread_name, DEFAULT_NAME, THREAD_NAME_LEN - 1)) {
return NULL;
} else {
printf("test default thread name mismatch\n");
return (int *) -1;
}
}
// ============================================================================
// Test cases
// ============================================================================
static int test_prctl_set_get_long_name(void) {
pthread_t tid;
void *ret;
if (pthread_create(&tid, NULL, (void *)test_thread_long_name, NULL)) {
THROW_ERROR("create test long name thread failed");
}
pthread_join(tid, &ret);
if ((int *) ret) {
THROW_ERROR("test long name thread prctl error");
}
return 0;
}
static int test_prctl_set_get_normal_name(void) {
pthread_t tid;
void *ret;
if (pthread_create(&tid, NULL, (void *)test_thread_normal_name, NULL)) {
THROW_ERROR("create test normal name thread failed");
}
pthread_join(tid, &ret);
if ((int *) ret) {
THROW_ERROR("test normal name thread prctl error");
}
return 0;
}
static int test_prctl_get_default_thread_name(void) {
pthread_t tid;
void *ret;
if (pthread_create(&tid, NULL, (void *)test_thread_default_name, NULL)) {
THROW_ERROR("create test default name thread failed");
}
pthread_join(tid, &ret);
if ((int *) ret) {
THROW_ERROR("test default name thread prctl error");
}
return 0;
}
// ============================================================================
// Test suite main
// ============================================================================
static test_case_t test_cases[] = {
TEST_CASE(test_prctl_set_get_long_name), // over 16 bytes
TEST_CASE(test_prctl_set_get_normal_name),
TEST_CASE(test_prctl_get_default_thread_name),
};
int main() {
return test_suite_run(test_cases, ARRAY_SIZE(test_cases));
}