Use mutex in Pthread test

This commit is contained in:
Tate, Hongliang Tian 2019-04-06 20:09:03 +08:00 committed by Tate Tian
parent 660d0931cd
commit bd82b27762

@ -6,36 +6,76 @@
* Child threads * Child threads
*/ */
#define NTHREADS 4 #define NTHREADS (4)
#define STACK_SIZE (8 * 1024) #define STACK_SIZE (8 * 1024)
static void* thread_func(void* arg) { #define LOCAL_COUNT (100000L)
int* tid = arg; #define EXPECTED_GLOBAL_COUNT (LOCAL_COUNT * NTHREADS)
printf("tid = %d\n", *tid);
struct thread_arg {
int ti;
long local_count;
volatile long* global_count;
pthread_mutex_t* mutex;
};
static void* thread_func(void* _arg) {
struct thread_arg* arg = _arg;
printf("Thread #%d: started\n", arg->ti);
for (long i = 0; i < arg->local_count; i++) {
pthread_mutex_lock(arg->mutex);
(*arg->global_count)++;
pthread_mutex_unlock(arg->mutex);
}
printf("Thread #%d: completed\n", arg->ti);
return NULL; return NULL;
} }
int main(int argc, const char* argv[]) { int main(int argc, const char* argv[]) {
/*
* Multiple threads are to increase a global counter concurrently
*/
volatile long global_count = 0;
pthread_t threads[NTHREADS]; pthread_t threads[NTHREADS];
int thread_data[NTHREADS]; struct thread_arg thread_args[NTHREADS];
/*
printf("Creating %d threads...", NTHREADS); * Protect the counter with a mutex
*/
pthread_mutex_t mutex;
pthread_mutex_init(&mutex, NULL);
/*
* Start the threads
*/
for (int ti = 0; ti < NTHREADS; ti++) { for (int ti = 0; ti < NTHREADS; ti++) {
thread_data[ti] = ti; struct thread_arg* thread_arg = &thread_args[ti];
if (pthread_create(&threads[ti], NULL, thread_func, &thread_data[ti]) < 0) { thread_arg->ti = ti;
thread_arg->local_count = LOCAL_COUNT;
thread_arg->global_count = &global_count;
thread_arg->mutex = &mutex;
if (pthread_create(&threads[ti], NULL, thread_func, thread_arg) < 0) {
printf("ERROR: pthread_create failed (ti = %d)\n", ti); printf("ERROR: pthread_create failed (ti = %d)\n", ti);
return -1; return -1;
} }
} }
printf("done.\n"); /*
* Wait for the threads to finish
printf("Waiting for %d threads to exit...", NTHREADS); */
for (int ti = 0; ti < NTHREADS; ti++) { for (int ti = 0; ti < NTHREADS; ti++) {
if (pthread_join(threads[ti], NULL) < 0) { if (pthread_join(threads[ti], NULL) < 0) {
printf("ERROR: pthread_join failed (ti = %d)\n", ti); printf("ERROR: pthread_join failed (ti = %d)\n", ti);
return -1; return -1;
} }
} }
printf("done.\n"); /*
* Check the correctness of the concurrent counter
*/
if (global_count != EXPECTED_GLOBAL_COUNT) {
printf("ERROR: incorrect global_count (actual = %ld, expected = %ld)\n",
global_count, EXPECTED_GLOBAL_COUNT);
return -1;
}
pthread_mutex_destroy(&mutex);
return 0; return 0;
} }