diff --git a/src/libos/src/vm/process_vm.rs b/src/libos/src/vm/process_vm.rs index dd8591e3..88418385 100644 --- a/src/libos/src/vm/process_vm.rs +++ b/src/libos/src/vm/process_vm.rs @@ -373,6 +373,12 @@ impl ProcessVM { } pub fn mprotect(&self, addr: usize, size: usize, perms: VMPerms) -> Result<()> { + let size = { + if size == 0 { + return Ok(()); + } + align_up(size, PAGE_SIZE) + }; let protect_range = VMRange::new_with_size(addr, size)?; if !self.process_range.range().is_superset_of(&protect_range) { return_errno!(ENOMEM, "invalid range"); diff --git a/test/mmap/main.c b/test/mmap/main.c index b1a9ff08..ad2463b3 100644 --- a/test/mmap/main.c +++ b/test/mmap/main.c @@ -9,6 +9,7 @@ #include #include #include +#include #include "test.h" // ============================================================================ @@ -1061,6 +1062,35 @@ int test_mprotect_with_invalid_prot() { return 0; } +int test_mprotect_with_non_page_aligned_size() { + int flags = MAP_PRIVATE | MAP_ANONYMOUS; + void *buf = mmap(NULL, PAGE_SIZE * 2, PROT_NONE, flags, -1, 0); + if (buf == MAP_FAILED) { + THROW_ERROR("mmap failed"); + } + + // Use raw syscall interface becase libc wrapper will handle non-page-aligned address + // and will not cause failure. + // Raw mprotect syscall with non-page-aligned address should fail. + int ret = syscall(SYS_mprotect, buf + 10, PAGE_SIZE, PROT_WRITE); + if (ret == 0 || errno != EINVAL) { + THROW_ERROR("mprotect with non-page-aligned address should fail with EINVAL"); + } + + // According to man page of mprotect, this syscall require a page aligned start address, but the size could be any value. + // Raw mprotect syscall with non-page-aligned size should succeed. + ret = syscall(SYS_mprotect, buf, PAGE_SIZE + 100, PROT_WRITE); + if (ret < 0) { + THROW_ERROR("mprotect with non-page-aligned size failed"); + } + + // Mprotect succeeded and the pages are writable. + *(char *)buf = 1; + *(char *)(buf + PAGE_SIZE) = 1; + + return 0; +} + // ============================================================================ // Test suite main // ============================================================================ @@ -1100,6 +1130,7 @@ static test_case_t test_cases[] = { TEST_CASE(test_mprotect_with_zero_len), TEST_CASE(test_mprotect_with_invalid_addr), TEST_CASE(test_mprotect_with_invalid_prot), + TEST_CASE(test_mprotect_with_non_page_aligned_size), }; int main() {