diff --git a/src/libos/src/net/socket_file/recv.rs b/src/libos/src/net/socket_file/recv.rs index 58ffa869..33f3437b 100644 --- a/src/libos/src/net/socket_file/recv.rs +++ b/src/libos/src/net/socket_file/recv.rs @@ -82,7 +82,7 @@ impl SocketFile { let msg_control = msg_control as *mut c_void; let mut msg_controllen_recvd = 0; // Flags - let flags = flags.bits(); + let raw_flags = flags.bits(); let mut msg_flags_recvd = 0; // Do OCall @@ -100,7 +100,7 @@ impl SocketFile { msg_controllen, &mut msg_controllen_recvd as *mut usize, &mut msg_flags_recvd as *mut i32, - flags, + raw_flags, ); assert!(status == sgx_status_t::SGX_SUCCESS); @@ -109,6 +109,8 @@ impl SocketFile { retval }); + let flags_recvd = MsgHdrFlags::from_bits(msg_flags_recvd).unwrap(); + // Check values returned from outside the enclave let bytes_recvd = { // Guarantted by try_libc! @@ -117,14 +119,19 @@ impl SocketFile { // Check bytes_recvd returned from outside the enclave let max_bytes_recvd = data.iter().map(|x| x.len()).sum(); - assert!(retval <= max_bytes_recvd); + + // For MSG_TRUNC recvmsg returns the real length of the packet or datagram, + // even when it was longer than the passed buffer. + if flags.contains(RecvFlags::MSG_TRUNC) && retval > max_bytes_recvd { + assert!(flags_recvd.contains(MsgHdrFlags::MSG_TRUNC)); + } else { + assert!(retval <= max_bytes_recvd); + } retval }; let msg_namelen_recvd = msg_namelen_recvd as usize; assert!(msg_namelen_recvd <= msg_namelen); assert!(msg_controllen_recvd <= msg_controllen); - let flags_recvd = MsgHdrFlags::from_bits(msg_flags_recvd).unwrap(); - Ok(( bytes_recvd, msg_namelen_recvd,