From 34de3acc122f152b29ec977634ee71b0f88c510a Mon Sep 17 00:00:00 2001 From: Dan Gohman Date: Wed, 22 Jan 2025 21:17:28 -0800 Subject: [PATCH] Report the number of received bytes in `recv_uninit` and `recvfrom_uninit`. As discussed in #1159, add a length to the return value of `recv_uninit` and `recvfrom_uninit` to report the original received length, which with `RecvFlags::TRUNC` may differ from the returned buffer length. --- src/net/send_recv/mod.rs | 33 ++++++++++++++++----------------- tests/net/recv_trunc.rs | 15 +++++++++++---- 2 files changed, 27 insertions(+), 21 deletions(-) diff --git a/src/net/send_recv/mod.rs b/src/net/send_recv/mod.rs index 2a4d1db35..898231e0b 100644 --- a/src/net/send_recv/mod.rs +++ b/src/net/send_recv/mod.rs @@ -70,26 +70,23 @@ pub fn recv(fd: Fd, buf: &mut [u8], flags: RecvFlags) -> io::Result( fd: Fd, buf: &mut [MaybeUninit], flags: RecvFlags, -) -> io::Result<(&mut [u8], &mut [MaybeUninit])> { +) -> io::Result<(&mut [u8], &mut [MaybeUninit], usize)> { let length = unsafe { backend::net::syscalls::recv(fd.as_fd(), buf.as_mut_ptr().cast::(), buf.len(), flags)? }; // If the `TRUNC` flag is set, the returned `length` may be longer than the // buffer length. - Ok(unsafe { split_init(buf, min(length, buf.len())) }) + let (init, uninit) = unsafe { split_init(buf, min(length, buf.len())) }; + Ok((init, uninit, length)) } /// `send(fd, buf, flags)`—Writes data to a socket. @@ -167,19 +164,21 @@ pub fn recvfrom( /// /// This is equivalent to [`recvfrom`], except that it can read into /// uninitialized memory. It returns the slice that was initialized by this -/// function and the slice that remains uninitialized. -/// -/// Because this interface returns the length via the returned slice, it's -/// unsable to return the untruncated length that would be returned when the -/// `RecvFlags::TRUNC` flag is used. If you need the untruncated length, use -/// [`recvfrom`]. +/// function, the slice that remains uninitialized, the number of bytes +/// received before any truncation due to the `RecvFlags::TRUNC` flag, and +/// the address of the sender if known. #[allow(clippy::type_complexity)] #[inline] pub fn recvfrom_uninit( fd: Fd, buf: &mut [MaybeUninit], flags: RecvFlags, -) -> io::Result<(&mut [u8], &mut [MaybeUninit], Option)> { +) -> io::Result<( + &mut [u8], + &mut [MaybeUninit], + usize, + Option, +)> { let (length, addr) = unsafe { backend::net::syscalls::recvfrom( fd.as_fd(), @@ -192,7 +191,7 @@ pub fn recvfrom_uninit( // If the `TRUNC` flag is set, the returned `length` may be longer than the // buffer length. let (init, uninit) = unsafe { split_init(buf, min(length, buf.len())) }; - Ok((init, uninit, addr)) + Ok((init, uninit, length, addr)) } /// `sendto(fd, buf, flags, addr)`—Writes data to a socket to a specific IP diff --git a/tests/net/recv_trunc.rs b/tests/net/recv_trunc.rs index 3e17c9c93..86932bbca 100644 --- a/tests/net/recv_trunc.rs +++ b/tests/net/recv_trunc.rs @@ -23,8 +23,9 @@ fn net_recv_uninit_trunc() { #[cfg(not(any(apple, solarish, target_os = "netbsd")))] { let mut response = [MaybeUninit::::zeroed(); 5]; - let (init, uninit) = rustix::net::recv_uninit(&receiver, &mut response, RecvFlags::TRUNC) - .expect("recv_uninit"); + let (init, uninit, length) = + rustix::net::recv_uninit(&receiver, &mut response, RecvFlags::TRUNC) + .expect("recv_uninit"); // We used the `TRUNC` flag, so we should have only gotten 5 bytes. assert_eq!(init, b"Hello"); @@ -34,17 +35,23 @@ fn net_recv_uninit_trunc() { let n = rustix::net::sendto_unix(&sender, request, SendFlags::empty(), &name).expect("send"); assert_eq!(n, request.len()); + + // Check the `length`. + assert_eq!(length, 15); } // This time receive it without `TRUNC`. This should fail. let mut response = [MaybeUninit::::zeroed(); 5]; - let (init, uninit) = rustix::net::recv_uninit(&receiver, &mut response, RecvFlags::empty()) - .expect("recv_uninit"); + let (init, uninit, length) = + rustix::net::recv_uninit(&receiver, &mut response, RecvFlags::empty()) + .expect("recv_uninit"); // We didn't use the `TRUNC` flag, so we should have received 15 bytes, // truncated to 5 bytes. assert_eq!(init, b"Hello"); assert!(uninit.is_empty()); + + assert_eq!(length, 5); } /// Test `recvmsg` with the `RecvFlags::Trunc` flag.