diff --git a/Cargo.toml b/Cargo.toml index 4aeafeb2..02466c61 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -75,7 +75,7 @@ widestring = "1.0.2" windows-sys = "0.61.0" thiserror = "2.0.3" smallvec = "1.13.2" -synchrony = "0.1.7" +synchrony = "0.1.8" thin-cell = "0.2.1" slotmap = "1.1.1" crossfire = "3.1.5" diff --git a/compio-driver/src/sys/driver/mod.rs b/compio-driver/src/sys/driver/mod.rs index ed6002d4..224a17a3 100644 --- a/compio-driver/src/sys/driver/mod.rs +++ b/compio-driver/src/sys/driver/mod.rs @@ -21,3 +21,11 @@ cfg_if::cfg_if! { crate::assert_not_impl!(Driver, Send); crate::assert_not_impl!(Driver, Sync); + +/// An operation that can be optimized by making use of the "poll-first" +/// feature. +pub trait PollFirst { + /// Poll first before syscall. This is only meaningful for io-uring. It sets + /// `IORING_RECVSEND_POLL_FIRST` flag in the `ioprio` of the SQE. + fn poll_first(&mut self); +} diff --git a/compio-driver/src/sys/extra/mod.rs b/compio-driver/src/sys/extra/mod.rs index 16ffeafe..c9abd5a0 100644 --- a/compio-driver/src/sys/extra/mod.rs +++ b/compio-driver/src/sys/extra/mod.rs @@ -96,8 +96,7 @@ impl Extra { /// /// # Behaviour /// - /// This method must be used only on the flags for any of the `receive` - /// variants supported by `IO_URING`. The driver will try to check whether + /// This method must be used only on `IO_URING`. The driver will try to check whether /// the `IORING_CQE_F_SOCK_NONEMPTY` flag was set by the kernel for the CQE. /// On other platforms, this will always return the [`Unsupported`] error. /// diff --git a/compio-driver/src/sys/op/managed/fallback.rs b/compio-driver/src/sys/op/managed/fallback.rs index b690b013..a31d3d9b 100644 --- a/compio-driver/src/sys/op/managed/fallback.rs +++ b/compio-driver/src/sys/op/managed/fallback.rs @@ -5,7 +5,7 @@ use rustix::net::RecvFlags; use socket2::SockAddr; use crate::{ - AsFd, BufferPool, BufferRef, + AsFd, BufferPool, BufferRef, PollFirst, op::{RecvMsg, TakeBuffer}, sys::op::{Read, ReadAt, Recv, RecvFrom}, }; @@ -69,11 +69,12 @@ impl RecvManaged { op: Recv::new(fd, pool.pop()?.with_capacity(len), flags), }) } +} - /// This method sets the `IORING_RECVSEND_POLL_FIRST` flag in the `ioprio` - /// of the SQE on the IO_URING driver. - // This method has been added here for the sake of API compatibility. - pub fn poll_first(&mut self) {} +impl PollFirst for RecvManaged { + fn poll_first(&mut self) { + self.op.poll_first(); + } } impl TakeBuffer for RecvManaged { @@ -96,11 +97,12 @@ impl RecvFromManaged { op: RecvFrom::new(fd, pool.pop()?.with_capacity(len), flags), }) } +} - /// This method sets the `IORING_RECVSEND_POLL_FIRST` flag in the `ioprio` - /// of the SQE on the IO_URING driver. - // This method has been added here for the sake of API compatibility. - pub fn poll_first(&mut self) {} +impl PollFirst for RecvFromManaged { + fn poll_first(&mut self) { + self.op.poll_first(); + } } impl TakeBuffer for RecvFromManaged { @@ -129,11 +131,12 @@ impl RecvMsgManaged { op: RecvMsg::new(fd, [pool.pop()?.with_capacity(len)], control, flags), }) } +} - /// This method sets the `IORING_RECVSEND_POLL_FIRST` flag in the `ioprio` - /// of the SQE on the IO_URING driver. - // This method has been added here for the sake of API compatibility. - pub fn poll_first(&mut self) {} +impl PollFirst for RecvMsgManaged { + fn poll_first(&mut self) { + self.op.poll_first(); + } } impl TakeBuffer for RecvMsgManaged { diff --git a/compio-driver/src/sys/op/managed/fusion.rs b/compio-driver/src/sys/op/managed/fusion.rs index c511037a..35fa6189 100644 --- a/compio-driver/src/sys/op/managed/fusion.rs +++ b/compio-driver/src/sys/op/managed/fusion.rs @@ -3,7 +3,9 @@ use rustix::net::RecvFlags; use socket2::SockAddr; use super::{fallback, iour}; -use crate::{BufferPool, BufferRef, IourOpCode, OpEntry, OpType, PollOpCode, sys::pal::*}; +use crate::{ + BufferPool, BufferRef, IourOpCode, OpEntry, OpType, PollFirst, PollOpCode, sys::pal::*, +}; macro_rules! mop { (<$($ty:ident: $trait:ident),* $(,)?> $name:ident( $($arg:ident: $arg_t:ty),* $(,)? ) with $pool:ident) => { @@ -131,10 +133,8 @@ mop!( RecvMulti(fd: S, pool: &BufferPool, len: usize, flags: RecvFlags) mop!( RecvFromMulti(fd: S, pool: &BufferPool, flags: RecvFlags) with pool; RecvFromMultiResult); mop!( RecvMsgMulti(fd: S, pool: &BufferPool, control_len: usize, flags: RecvFlags) with pool; RecvMsgMultiResult); -impl RecvManaged { - /// This method sets the `IORING_RECVSEND_POLL_FIRST` flag in the `ioprio` - /// of the SQE on the IO_URING driver. - pub fn poll_first(&mut self) { +impl PollFirst for RecvManaged { + fn poll_first(&mut self) { match self.inner { RecvManagedInner::Poll(ref mut i) => i.poll_first(), RecvManagedInner::IoUring(ref mut i) => i.poll_first(), @@ -142,10 +142,8 @@ impl RecvManaged { } } -impl RecvFromManaged { - /// This method sets the `IORING_RECVSEND_POLL_FIRST` flag in the `ioprio` - /// of the SQE on the IO_URING driver. - pub fn poll_first(&mut self) { +impl PollFirst for RecvFromManaged { + fn poll_first(&mut self) { match self.inner { RecvFromManagedInner::Poll(ref mut i) => i.poll_first(), RecvFromManagedInner::IoUring(ref mut i) => i.poll_first(), @@ -153,10 +151,8 @@ impl RecvFromManaged { } } -impl RecvMsgManaged { - /// This method sets the `IORING_RECVSEND_POLL_FIRST` flag in the `ioprio` - /// of the SQE on the IO_URING driver. - pub fn poll_first(&mut self) { +impl PollFirst for RecvMsgManaged { + fn poll_first(&mut self) { match self.inner { RecvMsgManagedInner::Poll(ref mut i) => i.poll_first(), RecvMsgManagedInner::IoUring(ref mut i) => i.poll_first(), diff --git a/compio-driver/src/sys/op/managed/iour.rs b/compio-driver/src/sys/op/managed/iour.rs index 255c4812..38b2e4c3 100644 --- a/compio-driver/src/sys/op/managed/iour.rs +++ b/compio-driver/src/sys/op/managed/iour.rs @@ -12,7 +12,7 @@ use rustix::net::RecvFlags; use socket2::{SockAddr, SockAddrStorage, socklen_t}; use crate::{ - BufferPool, BufferRef, Extra, IourOpCode as OpCode, OpEntry, + BufferPool, BufferRef, Extra, IourOpCode as OpCode, OpEntry, PollFirst, op::TakeBuffer, sys::pal::{is_kernel_at_least, set_poll_first}, }; @@ -162,10 +162,10 @@ impl RecvManaged { poll_first: false, }) } +} - /// This method sets the `IORING_RECVSEND_POLL_FIRST` flag in the `ioprio` - /// of the SQE on the IO_URING driver. - pub fn poll_first(&mut self) { +impl PollFirst for RecvManaged { + fn poll_first(&mut self) { self.poll_first = true; } } @@ -250,10 +250,10 @@ impl RecvFromManaged { poll_first: false, }) } +} - /// This method sets the `IORING_RECVSEND_POLL_FIRST` flag in the `ioprio` - /// of the SQE on the IO_URING driver. - pub fn poll_first(&mut self) { +impl PollFirst for RecvFromManaged { + fn poll_first(&mut self) { self.poll_first = true; } } @@ -330,10 +330,10 @@ impl RecvMsgManaged { control_len: 0, }) } +} - /// This method sets the `IORING_RECVSEND_POLL_FIRST` flag in the `ioprio` - /// of the SQE on the IO_URING driver. - pub fn poll_first(&mut self) { +impl PollFirst for RecvMsgManaged { + fn poll_first(&mut self) { self.op.poll_first(); } } diff --git a/compio-driver/src/sys/op/socket/iour.rs b/compio-driver/src/sys/op/socket/iour.rs index a2d3be86..7c6ed70e 100644 --- a/compio-driver/src/sys/op/socket/iour.rs +++ b/compio-driver/src/sys/op/socket/iour.rs @@ -94,14 +94,15 @@ unsafe impl OpCode for Accept { type Control = (); fn create_entry(&mut self, _: &mut Self::Control) -> OpEntry { - opcode::Accept::new( + let entry = opcode::Accept::new( Fd(self.fd.as_fd().as_raw_fd()), unsafe { self.buffer.view_as::() }, &raw mut self.addr_len, ) .flags(libc::SOCK_CLOEXEC) - .build() - .into() + .build(); + let entry = set_poll_first(entry, self.poll_first); + entry.into() } unsafe fn set_result(&mut self, _: &mut Self::Control, res: &io::Result, _: &Extra) { diff --git a/compio-driver/src/sys/op/socket/mod.rs b/compio-driver/src/sys/op/socket/mod.rs index cc0ca5b6..3bf87ce7 100644 --- a/compio-driver/src/sys/op/socket/mod.rs +++ b/compio-driver/src/sys/op/socket/mod.rs @@ -15,7 +15,7 @@ mod_use![stub]; use rustix::net::{RecvFlags, SendFlags}; -use crate::sys::prelude::*; +use crate::{PollFirst, sys::prelude::*}; /// Connect to a remote address. pub struct Connect { @@ -261,10 +261,10 @@ impl RecvMsg { poll_first: false, } } +} - /// This method sets the `IORING_RECVSEND_POLL_FIRST` flag in the `ioprio` - /// of the SQE on the IO_URING driver. - pub fn poll_first(&mut self) { +impl PollFirst for RecvMsg { + fn poll_first(&mut self) { self.poll_first = true; } } @@ -291,10 +291,10 @@ impl Recv { poll_first: false, } } +} - /// This method sets the `IORING_RECVSEND_POLL_FIRST` flag in the `ioprio` - /// of the SQE on the IO_URING driver. - pub fn poll_first(&mut self) { +impl PollFirst for Recv { + fn poll_first(&mut self) { self.poll_first = true; } } @@ -317,10 +317,10 @@ impl RecvVectored { poll_first: false, } } +} - /// This method sets the `IORING_RECVSEND_POLL_FIRST` flag in the `ioprio` - /// of the SQE on the IO_URING driver. - pub fn poll_first(&mut self) { +impl PollFirst for RecvVectored { + fn poll_first(&mut self) { self.poll_first = true; } } @@ -359,10 +359,10 @@ impl RecvFromVectored { buffer, } } +} - /// This method sets the `IORING_RECVSEND_POLL_FIRST` flag in the `ioprio` - /// of the SQE on the IO_URING driver. - pub fn poll_first(&mut self) { +impl PollFirst for RecvFromVectored { + fn poll_first(&mut self) { self.header.poll_first = true; } } @@ -384,10 +384,10 @@ impl RecvFrom { buffer, } } +} - /// This method sets the `IORING_RECVSEND_POLL_FIRST` flag in the `ioprio` - /// of the SQE on the IO_URING driver. - pub fn poll_first(&mut self) { +impl PollFirst for RecvFrom { + fn poll_first(&mut self) { self.header.poll_first = true; } } diff --git a/compio-driver/src/sys/op/socket/unix.rs b/compio-driver/src/sys/op/socket/unix.rs index d1a9eb53..4ddee2e7 100644 --- a/compio-driver/src/sys/op/socket/unix.rs +++ b/compio-driver/src/sys/op/socket/unix.rs @@ -9,7 +9,7 @@ use rustix::{ }, }; -use crate::sys::op::*; +use crate::{PollFirst, sys::op::*}; impl Accept { pub(crate) fn call(&mut self) -> io::Result { @@ -299,6 +299,7 @@ pub struct Accept { pub(crate) buffer: SockAddrStorage, pub(crate) addr_len: socklen_t, pub(crate) accepted_fd: Option, + pub(crate) poll_first: bool, } impl Accept { @@ -311,10 +312,17 @@ impl Accept { buffer, addr_len, accepted_fd: None, + poll_first: false, } } } +impl PollFirst for Accept { + fn poll_first(&mut self) { + self.poll_first = true; + } +} + impl IntoInner for Accept { type Inner = (Socket2, SockAddr); diff --git a/compio-driver/src/sys/pal/iour/mod.rs b/compio-driver/src/sys/pal/iour/mod.rs index 967ea7e8..a06edbfd 100644 --- a/compio-driver/src/sys/pal/iour/mod.rs +++ b/compio-driver/src/sys/pal/iour/mod.rs @@ -57,7 +57,11 @@ pub fn is_kernel_at_least(v: impl Into) -> bool { } pub(crate) fn set_poll_first(mut entry: Entry, flag: bool) -> Entry { - if flag && is_kernel_at_least((5, 19)) { + let version = match entry.get_opcode() as u8 { + io_uring::opcode::Accept::CODE => (6, 10), + _ => (5, 19), + }; + if flag && is_kernel_at_least(version) { let sqe = &raw mut entry as *mut io_uring_sqe; unsafe { (*sqe).ioprio |= IORING_RECVSEND_POLL_FIRST as u16; diff --git a/compio-net/src/socket/linux.rs b/compio-net/src/socket/linux.rs index 5fc7747d..6a352f4d 100644 --- a/compio-net/src/socket/linux.rs +++ b/compio-net/src/socket/linux.rs @@ -1,24 +1,21 @@ -use std::{ops::Deref, sync::atomic::Ordering}; +use std::sync::atomic::Ordering; +use compio_driver::{Extra, PollFirst}; #[cfg(feature = "sync")] -use synchrony::sync::atomic::AtomicI8; +use synchrony::sync::atomic::AtomicU8; #[cfg(not(feature = "sync"))] -use synchrony::unsync::atomic::AtomicI8; +use synchrony::unsync::atomic::AtomicU8; -// We are not on the IO_URING driver and hence retrieving socket state is -// not supported. -const UNSUPPORTED: i8 = -1; +const RECV_OFFSET: usize = 0; +const ACCEPT_OFFSET: usize = 2; -// The socket was empty after the receive operation. -const EMPTY: i8 = 0; - -// The socket was not-empty after the last receive operation and has more -// data to be read. -const NON_EMPTY: i8 = 1; +const UNSET: u8 = 0; +const EMPTY: u8 = 1; +const NON_EMPTY: u8 = 2; #[derive(Debug)] pub(super) struct SocketState { - state: AtomicI8, + state: AtomicU8, } impl Default for SocketState { @@ -30,39 +27,60 @@ impl Default for SocketState { impl SocketState { pub(super) fn new() -> Self { Self { - state: AtomicI8::new(-1), + state: AtomicU8::new(0), } } - pub(super) fn get(&self) -> Option { - match self.load(Ordering::Relaxed) { - UNSUPPORTED => None, + fn get_bit(&self, offset: usize) -> Option { + let state = self.state.load(Ordering::Relaxed); + match (state >> offset) & 0b11 { + UNSET => None, EMPTY => Some(false), NON_EMPTY => Some(true), _ => unreachable!(), } } - pub(super) fn set(&self, extra: &compio_driver::Extra) { + fn set_bit(&self, offset: usize, value: bool) { + let bits = if value { NON_EMPTY } else { EMPTY } << offset; + self.state + .update(Ordering::Relaxed, Ordering::Relaxed, |state| { + (state & !(0b11 << offset)) | bits + }); + } + + fn set_op(&self, offset: usize, op: &mut impl PollFirst) { + if self.get_bit(offset) == Some(false) { + op.poll_first(); + } + } + + pub(super) fn set_recv(&self, extra: &Extra) { + if let Ok(n) = extra.sock_nonempty() { + self.set_bit(RECV_OFFSET, n); + } + } + + pub(super) fn set_recv_op(&self, op: &mut impl PollFirst) { + self.set_op(RECV_OFFSET, op); + } + + pub(super) fn set_accept(&self, extra: &Extra) { if let Ok(n) = extra.sock_nonempty() { - self.store(n as i8, Ordering::Relaxed); + self.set_bit(ACCEPT_OFFSET, n); } } + + pub(super) fn set_accept_op(&self, op: &mut impl PollFirst) { + self.set_op(ACCEPT_OFFSET, op); + } } impl Clone for SocketState { fn clone(&self) -> Self { let current = self.state.load(Ordering::Relaxed); Self { - state: AtomicI8::new(current), + state: AtomicU8::new(current), } } } - -impl Deref for SocketState { - type Target = AtomicI8; - - fn deref(&self) -> &Self::Target { - &self.state - } -} diff --git a/compio-net/src/socket/mod.rs b/compio-net/src/socket/mod.rs index 53f29a78..aa58d699 100644 --- a/compio-net/src/socket/mod.rs +++ b/compio-net/src/socket/mod.rs @@ -54,22 +54,8 @@ cfg_if::cfg_if! { #[path = "linux.rs"] mod sys; } else { - mod sys { - #[derive(Default, Clone, Debug)] - pub(super) struct SocketState; - - impl SocketState { - pub(super) fn new() -> Self { - SocketState - } - - pub(super) fn get(&self) -> Option { - None - } - - pub(super) fn set(&self, _: &compio_driver::Extra) {} - } - } + #[path = "stub.rs"] + mod sys; } } @@ -164,8 +150,11 @@ impl Socket { #[cfg(unix)] pub async fn accept(&self) -> io::Result<(Self, SockAddr)> { - let op = Accept::new(self.to_shared_fd()); - let (_, op) = buf_try!(@try compio_runtime::submit(op).await); + let mut op = Accept::new(self.to_shared_fd()); + self.state.set_accept_op(&mut op); + let (BufResult(res, op), extra) = compio_runtime::submit(op).with_extra().await; + res?; + self.state.set_accept(&extra); let (accept_sock, addr) = op.into_inner(); let accept_sock = Self::from_socket2(accept_sock)?; Ok((accept_sock, addr)) @@ -246,11 +235,9 @@ impl Socket { pub async fn recv(&self, buffer: B, flags: RecvFlags) -> BufResult { let fd = self.to_shared_fd(); let mut op = Recv::new(fd, buffer, flags); - if self.state.get() == Some(false) { - op.poll_first(); - } + self.state.set_recv_op(&mut op); let (res, extra) = compio_runtime::submit(op).with_extra().await; - self.state.set(&extra); + self.state.set_recv(&extra); let res = res.into_inner(); unsafe { res.map_advanced() } } @@ -262,11 +249,9 @@ impl Socket { ) -> BufResult { let fd = self.to_shared_fd(); let mut op = RecvVectored::new(fd, buffer, flags); - if self.state.get() == Some(false) { - op.poll_first(); - } + self.state.set_recv_op(&mut op); let (res, extra) = compio_runtime::submit(op).with_extra().await; - self.state.set(&extra); + self.state.set_recv(&extra); let res = res.into_inner(); unsafe { res.map_vec_advanced() } } @@ -280,14 +265,12 @@ impl Socket { let (res, extra) = Runtime::with_current(|rt| { let buffer_pool = rt.buffer_pool()?; let mut op = RecvManaged::new(fd, &buffer_pool, len, flags)?; - if self.state.get() == Some(false) { - op.poll_first(); - } + self.state.set_recv_op(&mut op); io::Result::Ok(rt.submit(op).with_extra()) })? .await; - self.state.set(&extra); + self.state.set_recv(&extra); unsafe { res.take_buffer() } } @@ -346,11 +329,9 @@ impl Socket { ) -> BufResult<(usize, Option), T> { let fd = self.to_shared_fd(); let mut op = RecvFrom::new(fd, buffer, flags); - if self.state.get() == Some(false) { - op.poll_first(); - } + self.state.set_recv_op(&mut op); let (res, extra) = compio_runtime::submit(op).with_extra().await; - self.state.set(&extra); + self.state.set_recv(&extra); let res = res.into_inner().map_addr(); unsafe { res.map_advanced() } } @@ -362,11 +343,9 @@ impl Socket { ) -> BufResult<(usize, Option), T> { let fd = self.to_shared_fd(); let mut op = RecvFromVectored::new(fd, buffer, flags); - if self.state.get() == Some(false) { - op.poll_first(); - } + self.state.set_recv_op(&mut op); let (res, extra) = compio_runtime::submit(op).with_extra().await; - self.state.set(&extra); + self.state.set_recv(&extra); let res = res.into_inner().map_addr(); unsafe { res.map_vec_advanced() } } @@ -380,13 +359,11 @@ impl Socket { let (inner, extra) = Runtime::with_current(|rt| { let buffer_pool = rt.buffer_pool()?; let mut op = RecvFromManaged::new(fd, &buffer_pool, len, flags)?; - if self.state.get() == Some(false) { - op.poll_first(); - } + self.state.set_recv_op(&mut op); io::Result::Ok(rt.submit(op).with_extra()) })? .await; - self.state.set(&extra); + self.state.set_recv(&extra); let (len, op) = buf_try!(@try inner); // Kernel returns 0 for the operation, drop the buffer and return Ok(None) if len == 0 { @@ -435,11 +412,9 @@ impl Socket { ) -> BufResult<(usize, usize, Option), (T, C)> { let fd = self.to_shared_fd(); let mut op = RecvMsg::new(fd, buffer, control, flags); - if self.state.get() == Some(false) { - op.poll_first(); - } + self.state.set_recv_op(&mut op); let (res, extra) = compio_runtime::submit(op).with_extra().await; - self.state.set(&extra); + self.state.set_recv(&extra); let res = res.into_inner().map_addr(); unsafe { res.map_vec_advanced() } } @@ -454,13 +429,11 @@ impl Socket { let (inner, extra) = Runtime::with_current(|rt| { let buffer_pool = rt.buffer_pool()?; let mut op = RecvMsgManaged::new(fd, &buffer_pool, len, control, flags)?; - if self.state.get() == Some(false) { - op.poll_first(); - } + self.state.set_recv_op(&mut op); io::Result::Ok(rt.submit(op).with_extra()) })? .await; - self.state.set(&extra); + self.state.set_recv(&extra); let (len, op) = buf_try!(@try inner); // Kernel returns 0 for the operation, drop the buffer and return Ok(None) if len == 0 { diff --git a/compio-net/src/socket/stub.rs b/compio-net/src/socket/stub.rs new file mode 100644 index 00000000..c8628aa8 --- /dev/null +++ b/compio-net/src/socket/stub.rs @@ -0,0 +1,19 @@ +use compio_driver::{Extra, PollFirst}; + +#[derive(Debug, Default, Clone)] +pub(super) struct SocketState; + +#[allow(dead_code)] +impl SocketState { + pub(super) fn new() -> Self { + SocketState + } + + pub(super) fn set_recv(&self, _: &Extra) {} + + pub(super) fn set_recv_op(&self, _: &mut impl PollFirst) {} + + pub(super) fn set_accept(&self, _: &Extra) {} + + pub(super) fn set_accept_op(&self, _: &mut impl PollFirst) {} +}