diff --git a/.github/workflows/qemu.yml b/.github/workflows/qemu.yml index c18d4d1f3a..9266bb91a4 100644 --- a/.github/workflows/qemu.yml +++ b/.github/workflows/qemu.yml @@ -98,6 +98,7 @@ jobs: --enable-kvm \ --enable-slirp \ --enable-strip \ + --enable-vhost-kernel \ --static \ --disable-docs \ --disable-user \ @@ -117,6 +118,12 @@ jobs: - name: Install TPM 2.0 Reference Implementation build dependencies run: sudo apt install -y build-essential cmake pkg-config + - name: Install netcat used by test-in-svsm for vsock tests + run: sudo apt install -y ncat + + - name: Set up vhost-vsock permissions + run: sudo chmod 666 /dev/vhost-vsock + - name: Build test run: make bin/coconut-test-qemu.igvm diff --git a/Cargo.lock b/Cargo.lock index 41ff9d31d0..84368f2b0f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -650,6 +650,12 @@ dependencies = [ "zeroize", ] +[[package]] +name = "embedded-io" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "edd0f118536f44f5ccd48bcb8b111bdc3de888b58c74639dfb034a357d0f206d" + [[package]] name = "enumn" version = "0.1.14" @@ -2293,6 +2299,7 @@ name = "virtio-drivers" version = "0.7.5" dependencies = [ "bitflags", + "embedded-io", "enumn", "log", "zerocopy", diff --git a/Makefile b/Makefile index 3ebbb317fc..d206e5a777 100644 --- a/Makefile +++ b/Makefile @@ -4,7 +4,7 @@ SVSM_ARGS += --features ${FEATURES} XBUILD_ARGS += -f ${FEATURES} endif -FEATURES_TEST ?= vtpm,virtio-drivers,block +FEATURES_TEST ?= vtpm,virtio-drivers,block,vsock SVSM_ARGS_TEST += --no-default-features ifneq ($(FEATURES_TEST),) SVSM_ARGS_TEST += --features ${FEATURES_TEST} diff --git a/kernel/Cargo.toml b/kernel/Cargo.toml index 8a0f97fa15..d9902391db 100644 --- a/kernel/Cargo.toml +++ b/kernel/Cargo.toml @@ -84,6 +84,7 @@ verus = ["verus_all", "verify_proof/noverify", "verify_external/noverify"] noverify = [] virtio-drivers = ["dep:virtio-drivers"] block = [] +vsock = [] [dev-dependencies] sha2 = { workspace = true, features = 
["force-soft"] } diff --git a/kernel/src/error.rs b/kernel/src/error.rs index e4aed14e77..a5583d4f5a 100644 --- a/kernel/src/error.rs +++ b/kernel/src/error.rs @@ -35,6 +35,8 @@ use crate::tdx::TdxError; use crate::utils::immut_after_init::ImmutAfterInitError; #[cfg(feature = "virtio-drivers")] use crate::virtio::VirtioError; +#[cfg(feature = "vsock")] +use crate::vsock::VsockError; use elf::ElfError; use syscall::SysCallError; @@ -143,6 +145,9 @@ pub enum SvsmError { TeeAttestation(AttestationError), /// Errors related to ImmutAfterInitCell ImmutAfterInit(ImmutAfterInitError), + /// Errors related to vsock. + #[cfg(feature = "vsock")] + Vsock(VsockError), } impl From for SvsmError { @@ -183,6 +188,13 @@ impl From for SvsmError { } } +#[cfg(feature = "vsock")] +impl From for SvsmError { + fn from(err: VsockError) -> Self { + Self::Vsock(err) + } +} + impl From for SysCallError { fn from(err: SvsmError) -> Self { match err { diff --git a/kernel/src/lib.rs b/kernel/src/lib.rs index 328e67bccf..fa29d11cf5 100644 --- a/kernel/src/lib.rs +++ b/kernel/src/lib.rs @@ -49,6 +49,8 @@ pub mod utils; #[cfg(feature = "virtio-drivers")] pub mod virtio; pub mod vmm; +#[cfg(feature = "vsock")] +pub mod vsock; #[cfg(all(feature = "vtpm", not(test)))] pub mod vtpm; diff --git a/kernel/src/svsm.rs b/kernel/src/svsm.rs index d45bb63233..c090fc20e9 100755 --- a/kernel/src/svsm.rs +++ b/kernel/src/svsm.rs @@ -67,7 +67,7 @@ use svsm::svsm_paging::invalidate_early_boot_memory; use svsm::task::{KernelThreadStartInfo, schedule_init, start_kernel_task}; use svsm::types::PAGE_SIZE; use svsm::utils::{MemoryRegion, ScopedRef, round_to_pages}; -#[cfg(all(feature = "virtio-drivers", feature = "block"))] +#[cfg(all(feature = "virtio-drivers", any(feature = "block", feature = "vsock")))] use svsm::virtio::probe_mmio_slots; #[cfg(all(feature = "vtpm", not(test)))] use svsm::vtpm::vtpm_init; @@ -217,14 +217,21 @@ fn mapping_info_init(launch_info: &KernelLaunchInfo) { /// Returns an error when a 
virtio device is found but its driver initialization fails. #[cfg(feature = "virtio-drivers")] fn initialize_virtio_mmio(_boot_params: &BootParams<'_>) -> Result<(), SvsmError> { + #[cfg(any(feature = "block", feature = "vsock"))] + let mut slots = probe_mmio_slots(_boot_params); + #[cfg(feature = "block")] { use svsm::block::virtio_blk::initialize_block; - - let mut slots = probe_mmio_slots(_boot_params); initialize_block(&mut slots)?; } + #[cfg(feature = "vsock")] + { + use svsm::vsock::virtio_vsock::initialize_vsock; + initialize_vsock(&mut slots)?; + } + Ok(()) } diff --git a/kernel/src/testing.rs b/kernel/src/testing.rs index c2420b489a..106e064209 100644 --- a/kernel/src/testing.rs +++ b/kernel/src/testing.rs @@ -42,6 +42,8 @@ pub enum IORequest { GetLaunchMeasurement = 0x01, /// Virtio-blk tests: Get Sha256 hash of the svsm state disk image GetStateImageSha256 = 0x02, + /// Virtio-vsock tests: Ask host to start a vsock server + StartVsockServer = 0x03, } /// Return the serial port to communicate with the host for a given request diff --git a/kernel/src/vsock/api.rs b/kernel/src/vsock/api.rs new file mode 100644 index 0000000000..ecd2ab4c13 --- /dev/null +++ b/kernel/src/vsock/api.rs @@ -0,0 +1,110 @@ +// SPDX-License-Identifier: MIT +// +// Copyright (c) 2025 Red Hat, Inc. +// +// Author: Luigi Leonardi + +use crate::error::SvsmError; + +pub trait VsockTransport: Sync + Send { + /// Establishes a connection to a remote vsock endpoint. + /// + /// This method initiates a connection to the specified remote CID and port + /// using the provided local port. The call blocks until the connection is + /// established or fails. 
+ /// + /// # Parameters + /// + /// * `remote_cid` - The CID of the remote endpoint to connect to + /// * `local_port` - The local port to use for this connection + /// * `remote_port` - The remote port to connect to + /// + /// # Returns + /// + /// * `Ok()` if the connection was successfully established + /// * `Err(SvsmError)` if the connection failed + fn connect(&self, remote_cid: u32, local_port: u32, remote_port: u32) -> Result<(), SvsmError>; + + /// Sends data over an established vsock connection. + /// + /// Transmits the contents of the provided buffer to the remote endpoint. + /// The connection must have been previously established via `connect()`. + /// + /// # Parameters + /// + /// * `remote_cid` - The CID of the remote endpoint + /// * `local_port` - The local port of the connection + /// * `remote_port` - The remote port of the connection + /// * `buffer` - The data to send + /// + /// # Returns + /// + /// * `Ok(usize)` - The number of bytes successfully sent + /// * `Err(SvsmError)` if the send operation failed + fn send( + &self, + remote_cid: u32, + local_port: u32, + remote_port: u32, + buffer: &[u8], + ) -> Result; + + /// Receives data from an established vsock connection. + /// + /// Reads data from the remote endpoint into the provided buffer. This method + /// blocks until all data is available or an error occurs, in such case + /// returns all the received bytes, if any. + /// The connection must have been previously established via `connect()`. 
+ /// + /// # Parameters + /// + /// * `remote_cid` - The CID of the remote endpoint + /// * `local_port` - The local port of the connection + /// * `remote_port` - The remote port of the connection + /// * `buffer` - The buffer to receive data into + /// + /// # Returns + /// + /// * `Ok(usize)` - The number of bytes successfully received + /// * `Err(SvsmError)` if the receive operation failed + fn recv( + &self, + remote_cid: u32, + local_port: u32, + remote_port: u32, + buffer: &mut [u8], + ) -> Result; + + /// Shuts down a vsock connection. + /// + /// Initiates a graceful shutdown of the connection telling the peer that we won't + /// send or receive any more data. + /// + /// # Parameters + /// + /// * `remote_cid` - The CID of the remote endpoint + /// * `local_port` - The local port of the connection + /// * `remote_port` - The remote port of the connection + /// * `force` - Forcibly terminates the connection, without waiting for peer confirm + /// + /// # Returns + /// + /// * `Ok()` if the shutdown was successful + /// * `Err(SvsmError)` if the shutdown failed + fn shutdown( + &self, + remote_cid: u32, + local_port: u32, + remote_port: u32, + force: bool, + ) -> Result<(), SvsmError>; + + /// Returns whether the given local port is currently in use. + /// + /// # Returns + /// + /// * `Ok(true)` - The port is in use + /// * `Ok(false)` - The port is free + /// * `Err(SvsmError)` if the check could not be performed + fn is_local_port_used(&self, port: u32) -> Result; +} diff --git a/kernel/src/vsock/error.rs b/kernel/src/vsock/error.rs new file mode 100644 index 0000000000..764a6c024f --- /dev/null +++ b/kernel/src/vsock/error.rs @@ -0,0 +1,21 @@ +// SPDX-License-Identifier: MIT +// +// Copyright (c) 2025 Red Hat, Inc. +// +// Author: Luigi Leonardi + +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum VsockError { + /// A connection already exists. + ConnectionExists, + /// The device is not connected to any peer. 
+ NotConnected, + /// The peer socket has shutdown. + PeerSocketShutdown, + /// The local socket has been shutdown. + SocketShutdown, + /// No local ports are available. + NoPortsAvailable, + /// Generic error for socket operations on a vsock device. + DriverError, +} diff --git a/kernel/src/vsock/mod.rs b/kernel/src/vsock/mod.rs new file mode 100644 index 0000000000..6c59a28278 --- /dev/null +++ b/kernel/src/vsock/mod.rs @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: MIT +// +// Copyright (c) 2025 Red Hat, Inc. +// +// Author: Luigi Leonardi + +pub mod api; +pub mod error; +pub mod stream; +#[cfg(feature = "virtio-drivers")] +pub mod virtio_vsock; + +pub use error::VsockError; +/// Well-known CID for the host. +pub const VMADDR_CID_HOST: u32 = 2; +pub const VMADDR_PORT_ANY: u32 = u32::MAX; + +extern crate alloc; +use crate::{ + error::SvsmError, utils::immut_after_init::ImmutAfterInitCell, vsock::api::VsockTransport, +}; +use alloc::boxed::Box; +use core::ops::Deref; +use core::sync::atomic::{AtomicU32, Ordering}; + +// Currently only one vsock device is supported. +static VSOCK_DEVICE: ImmutAfterInitCell = ImmutAfterInitCell::uninit(); +// Ports below 1024 are reserved +const VSOCK_MIN_PORT: u32 = 1024; +// Number of maximum retries to get a local free port +const MAX_RETRIES: u32 = 5; + +struct VsockDriver { + first_free_port: AtomicU32, + transport: Box, +} + +impl VsockDriver { + /// Returns a free local port number for a new connection. + /// + /// Returns [`VsockError::NoPortsAvailable`] if all + /// ports are already in use. + fn get_first_free_port(&self) -> Result { + for _ in 0..MAX_RETRIES { + let candidate_port = + self.first_free_port + .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |port| { + if port >= VMADDR_PORT_ANY - 1 { + Some(VSOCK_MIN_PORT) + } else { + Some(port + 1) + } + }); + + // The closure always returns Some, so this never fails. + let candidate_port = candidate_port.unwrap(); + + if !self.is_local_port_used(candidate_port)? 
{ + return Ok(candidate_port); + } + } + + Err(SvsmError::Vsock(VsockError::NoPortsAvailable)) + } +} + +impl Deref for VsockDriver { + type Target = dyn VsockTransport; + + fn deref(&self) -> &Self::Target { + self.transport.as_ref() + } +} diff --git a/kernel/src/vsock/stream.rs b/kernel/src/vsock/stream.rs new file mode 100644 index 0000000000..0c3a03c5f4 --- /dev/null +++ b/kernel/src/vsock/stream.rs @@ -0,0 +1,301 @@ +// SPDX-License-Identifier: MIT +// +// Copyright (c) 2025 Red Hat, Inc. +// +// Author: Luigi Leonardi + +use crate::{ + error::SvsmError, + io::{Read, Write}, + vsock::{VSOCK_DEVICE, VsockError}, +}; + +#[derive(Debug, Eq, PartialEq)] +enum VsockStreamStatus { + Connected, + Closed, +} + +/// A vsock stream for communication between a virtual machine +/// and its host. +/// +/// `VsockStream` provides a TCP-like socket interface over the VSOCK transport, +/// which is designed for communication between a guest VM and its host. +/// It implements the [`Read`] and [`Write`] traits for I/O operations. +/// +/// # Examples +/// +/// ```no_run +/// use crate::svsm::io::{Read, Write}; +/// use crate::svsm::vsock::{VMADDR_CID_HOST, stream::VsockStream}; +/// use svsm::error; +/// +/// // Connect to host on port 12345 +/// let mut stream = VsockStream::connect(12345, VMADDR_CID_HOST)?; +/// +/// // Write data +/// let data = b"Hello, host!"; +/// stream.write(data)?; +/// +/// // Read response +/// let mut buffer = [0u8; 10]; +/// let n = stream.read(&mut buffer)?; +/// +/// // Explicitly shut down the connection +/// stream.shutdown()?; +/// # Ok::<(), error::SvsmError>(()) +/// ``` +/// +/// # Connection Lifecycle +/// +/// - A stream is created in the `Connected` state via [`connect()`](Self::connect). +/// - It can be explicitly closed using [`shutdown()`](Self::shutdown). +/// - When dropped, the stream automatically performs a force shutdown if still connected. 
+#[derive(Debug)] +pub struct VsockStream { + local_port: u32, + remote_port: u32, + remote_cid: u32, + status: VsockStreamStatus, +} + +impl VsockStream { + /// Establishes a VSOCK connection to a remote endpoint. + /// + /// Creates a new VSOCK stream and connects to the specified remote port and CID + /// The local port is automatically assigned from available ports. + /// + /// # Arguments + /// + /// * `remote_port` - The port number on the remote endpoint to connect to. + /// * `remote_cid` - The CID of the remote endpoint. + /// + /// # Returns + /// + /// Returns a connected `VsockStream` on success, or an error if: + /// - The VSOCK device is not available (`VsockError::DriverError`) + /// - No free local ports are available + /// - The connection fails + pub fn connect(remote_port: u32, remote_cid: u32) -> Result { + let local_port = VSOCK_DEVICE.get_first_free_port()?; + VSOCK_DEVICE.connect(remote_cid, local_port, remote_port)?; + + Ok(Self { + local_port, + remote_port, + remote_cid, + status: VsockStreamStatus::Connected, + }) + } + + /// Gracefully shuts down the VSOCK connection. + /// + /// Closes the connection and transitions the stream to the `Closed` state. + /// After calling this method, any subsequent read or write operations will fail. + /// + /// # Returns + /// + /// Returns `Ok()` on successful shutdown, or an error if: + /// - The VSOCK device is not available (`VsockError::DriverError`) + /// - The connection is already shutdown (`VsockError::SocketShutdown`) + /// - The shutdown operation fails + pub fn shutdown(&mut self) -> Result<(), SvsmError> { + if self.status == VsockStreamStatus::Closed { + return Err(SvsmError::Vsock(VsockError::SocketShutdown)); + } + + VSOCK_DEVICE.shutdown(self.remote_cid, self.local_port, self.remote_port, false)?; + self.status = VsockStreamStatus::Closed; + Ok(()) + } +} + +impl Read for VsockStream { + type Err = SvsmError; + + /// Perform a blocking read from the VSOCK stream into the provided buffer. 
+ /// + /// # Arguments + /// + /// * `buf` - The buffer to read data into. + /// + /// # Returns + /// + /// Returns the number of bytes read on success, or 0 if the peer shut + /// the connection down. + /// + /// # Errors + /// + /// Returns an error if: + /// - The VSOCK device is not available (`VsockError::DriverError`) + /// - The stream has been shut down (`VsockError::SocketShutdown`) + fn read(&mut self, buf: &mut [u8]) -> Result { + if self.status == VsockStreamStatus::Closed { + return Err(SvsmError::Vsock(VsockError::SocketShutdown)); + } + + match VSOCK_DEVICE.recv(self.remote_cid, self.local_port, self.remote_port, buf) { + Ok(value) => Ok(value), + Err(SvsmError::Vsock(VsockError::NotConnected)) => Ok(0), + Err(SvsmError::Vsock(VsockError::SocketShutdown)) => Ok(0), + Err(e) => Err(e), + } + } +} + +impl Write for VsockStream { + type Err = SvsmError; + + /// Writes data from the provided buffer to the VSOCK stream. + /// + /// # Arguments + /// + /// * `buf` - The buffer containing data to write. 
+ /// + /// # Returns + /// + /// Returns the number of bytes written on success, or an error if: + /// - The VSOCK device is not available (`VsockError::DriverError`) + /// - The send operation fails + fn write(&mut self, buf: &[u8]) -> Result { + if self.status == VsockStreamStatus::Closed { + return Err(SvsmError::Vsock(VsockError::SocketShutdown)); + } + + VSOCK_DEVICE.send(self.remote_cid, self.local_port, self.remote_port, buf) + } +} + +impl Drop for VsockStream { + fn drop(&mut self) { + if self.status == VsockStreamStatus::Closed || VSOCK_DEVICE.try_get_inner().is_err() { + return; + } + + let _ = VSOCK_DEVICE.shutdown(self.remote_cid, self.local_port, self.remote_port, true); + } +} + +#[cfg(all(test, test_in_svsm))] +mod tests { + use crate::{testutils::has_test_iorequests, vsock::VMADDR_CID_HOST}; + use zerocopy::IntoBytes; + + use super::*; + + fn start_vsock_server_host() -> u32 { + use crate::serial::Terminal; + use crate::testing::{IORequest, svsm_test_io}; + + let mut sp = svsm_test_io().unwrap(); + + sp.put_byte(IORequest::StartVsockServer as u8); + + let mut vsock_port: u32 = 0; + + sp.read(vsock_port.as_mut_bytes()) + .expect("unable to get the vsock port"); + + vsock_port + } + + #[test] + #[cfg_attr(not(test_in_svsm), ignore = "Can only be run inside guest")] + fn test_virtio_vsock_double_connect() { + if !has_test_iorequests() { + return; + } + + let remote_port = start_vsock_server_host(); + + let mut stream = + VsockStream::connect(remote_port, VMADDR_CID_HOST).expect("connection failed"); + + VsockStream::connect(remote_port, VMADDR_CID_HOST) + .expect_err("The second connection operation was expected to fail, but it succeeded."); + + stream.shutdown().expect("shutdown failed"); + } + + #[test] + #[cfg_attr(not(test_in_svsm), ignore = "Can only be run inside guest")] + fn test_virtio_vsock_write() { + if !has_test_iorequests() { + return; + } + + let remote_port = start_vsock_server_host(); + + let mut stream = + 
VsockStream::connect(remote_port, VMADDR_CID_HOST).expect("connection failed"); + + let buffer: &[u8] = b"Hello world!"; + + let n_bytes = stream.write(buffer).expect("write failed"); + assert_eq!(n_bytes, buffer.len(), "Sent less bytes than requested"); + + stream.shutdown().expect("shutdown failed"); + } + + #[test] + #[cfg_attr(not(test_in_svsm), ignore = "Can only be run inside guest")] + fn test_virtio_vsock_read() { + if !has_test_iorequests() { + return; + } + + let remote_port = start_vsock_server_host(); + + let mut stream = + VsockStream::connect(remote_port, VMADDR_CID_HOST).expect("connection failed"); + + let mut buffer: [u8; 11] = [0; 11]; + let n_bytes = stream.read(&mut buffer).expect("read failed"); + assert_eq!(n_bytes, buffer.len(), "Received less bytes than requested"); + + let string = core::str::from_utf8(&buffer).unwrap(); + log::info!("received: {string:?}"); + assert_eq!(string, "hello_world", "Received wrong message"); + + stream.shutdown().expect("shutdown failed"); + } + + #[test] + #[cfg_attr(not(test_in_svsm), ignore = "Can only be run inside guest")] + fn test_virtio_vsock_read_shutdown() { + if !has_test_iorequests() { + return; + } + + let remote_port = start_vsock_server_host(); + + let mut stream = + VsockStream::connect(remote_port, VMADDR_CID_HOST).expect("connection failed"); + + stream.shutdown().expect("shutdown failed"); + + let mut buffer: [u8; 11] = [0; 11]; + stream + .read(&mut buffer) + .expect_err("The read operation was expected to fail, but it succeeded"); + } + + #[test] + #[cfg_attr(not(test_in_svsm), ignore = "Can only be run inside guest")] + fn test_virtio_vsock_write_shutdown() { + if !has_test_iorequests() { + return; + } + + let remote_port = start_vsock_server_host(); + + let mut stream = + VsockStream::connect(remote_port, VMADDR_CID_HOST).expect("connection failed"); + + stream.shutdown().expect("shutdown failed"); + + stream + .write(b"hello world") + .expect_err("The write operation was expected to fail, 
but it succeeded"); + } +} diff --git a/kernel/src/vsock/virtio_vsock.rs b/kernel/src/vsock/virtio_vsock.rs new file mode 100644 index 0000000000..531a048509 --- /dev/null +++ b/kernel/src/vsock/virtio_vsock.rs @@ -0,0 +1,232 @@ +// SPDX-License-Identifier: MIT +// +// Copyright (c) 2025 Red Hat, Inc. +// +// Author: Luigi Leonardi + +use crate::error::SvsmError; +use crate::locking::SpinLock; +use crate::mm::GlobalRangeGuard; +use crate::virtio::VirtioError; +use crate::virtio::hal::SvsmHal; +use crate::virtio::mmio::{MmioSlot, MmioSlots}; +use crate::vsock::VSOCK_MIN_PORT; +use crate::vsock::VsockDriver; +use crate::vsock::api::VsockTransport; +use crate::vsock::{VSOCK_DEVICE, VsockError}; + +extern crate alloc; +use alloc::boxed::Box; +use core::sync::atomic::AtomicU32; + +use virtio_drivers::device::socket::VsockConnectionManager; +use virtio_drivers::device::socket::{VirtIOSocket, VsockAddr}; +use virtio_drivers::transport::DeviceType::Socket; +use virtio_drivers::transport::mmio::MmioTransport; + +pub struct VirtIOVsockDriver { + device: SpinLock>>, + _mmio_space: GlobalRangeGuard, +} + +impl core::fmt::Debug for VirtIOVsockDriver { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("VirtIOVsockDriver").finish() + } +} + +/// Initializes the global vsock device subsystem with a VirtIO vsock driver. +/// +/// This function searches for a virtio-vsock device in the MMIO slots list. +/// If discovered, the first virtio-vsock device will be initialized and +/// registered as the global vsock device. 
+/// **Only one vsock device is supported** +/// +/// # Arguments +/// +/// * `slots` - The virtio MMIO slots list +/// +/// # Returns +/// +/// * Returns Ok() if: +/// * The driver is correctly initialized +/// * No virtio-vsock devices are found +/// * Returns an error if: +/// * The driver initialization fails +/// * The global vsock device has already been initialized +pub fn initialize_vsock(slots: &mut MmioSlots) -> Result<(), SvsmError> { + let Some(slot) = slots.pop_slot(Socket) else { + return Ok(()); + }; + + let transport = VirtIOVsockDriver::new(slot)?; + + let driver = VsockDriver { + first_free_port: AtomicU32::new(VSOCK_MIN_PORT), + transport, + }; + + VSOCK_DEVICE.init(driver)?; + + Ok(()) +} + +impl VirtIOVsockDriver { + pub fn new(slot: MmioSlot) -> Result, SvsmError> { + let vsk = VirtIOSocket::new(slot.transport).map_err(|_| VirtioError::InvalidDevice)?; + let mgr = VsockConnectionManager::new(vsk); + + Ok(Box::new(VirtIOVsockDriver { + device: SpinLock::new(mgr), + _mmio_space: slot.mmio_range, + })) + } +} + +impl VsockTransport for VirtIOVsockDriver { + fn connect(&self, remote_cid: u32, local_port: u32, remote_port: u32) -> Result<(), SvsmError> { + let server_address = VsockAddr { + cid: remote_cid as u64, + port: remote_port, + }; + + self.device + .locked_do(|dev| dev.connect(server_address, local_port)) + .map_err(VsockError::from)?; + + loop { + // This global lock on the device is acquired and released on each iteration to + // allow some interleaving. In this way different processes can take this lock and + // perform some actions without having to wait for the connection to be fully + // established. + let mut dev = self.device.lock(); + + // For the connection to be established we need to wait for a `Connected` event. + // Unfortunately, because there could be multiple vsock streams open, the received + // event might not be related to this specific connection. 
For this reason, we wait + // for a generic event and then check the status of the connection in every iteration. + dev.wait_for_event().map_err(VsockError::from)?; + let status = dev + .is_connection_established(server_address, local_port) + .map_err(VsockError::from)?; + + if status { + return Ok(()); + } + } + } + + fn recv( + &self, + remote_cid: u32, + local_port: u32, + remote_port: u32, + buffer: &mut [u8], + ) -> Result { + let mut total_received: usize = 0; + let server_address = VsockAddr { + cid: remote_cid as u64, + port: remote_port, + }; + + loop { + // This global lock is acquired and released on each iteration to allow interleaving: + // In this way different processes can take this lock and perform some actions without + // having to wait for all the bytes to be received. + let mut dev = self.device.lock(); + + let received = match dev.recv(server_address, local_port, &mut buffer[total_received..]) + { + Ok(value) => value, + Err(error) => { + if total_received > 0 { + return Ok(total_received); + } else { + return Err(SvsmError::Vsock(VsockError::from(error))); + } + } + }; + log::debug!("[vsock] received: {received}"); + + total_received += received; + + let result = dev.update_credit(server_address, local_port); + if result.is_err() || total_received == buffer.len() { + break; + } + + // If we reach here, it means that we didn't receive all the requested bytes. + // So we need to block and wait for a `Received` event, that indicates that some + // more bytes are available to read. Because there could be multiple vsock streams + // open, the received event might not be related to this specific connection. For + // this reason we wait for a generic event. 
+ dev.wait_for_event().map_err(VsockError::from)?; + } + + Ok(total_received) + } + + fn send( + &self, + remote_cid: u32, + local_port: u32, + remote_port: u32, + buffer: &[u8], + ) -> Result { + let mut dev = self.device.lock(); + + let server_address = VsockAddr { + cid: remote_cid as u64, + port: remote_port, + }; + + dev.send(server_address, local_port, buffer) + .map_err(VsockError::from)?; + Ok(buffer.len()) + } + + fn shutdown( + &self, + remote_cid: u32, + local_port: u32, + remote_port: u32, + force: bool, + ) -> Result<(), SvsmError> { + let mut dev = self.device.lock(); + + let server_address = VsockAddr { + cid: remote_cid as u64, + port: remote_port, + }; + + if force { + dev.force_close(server_address, local_port) + .map_err(VsockError::from)?; + } else { + dev.shutdown(server_address, local_port) + .map_err(VsockError::from)?; + } + + Ok(()) + } + + fn is_local_port_used(&self, port: u32) -> Result { + let dev = self.device.lock(); + + Ok(dev.is_local_port_used(port)) + } +} + +impl From for VsockError { + fn from(e: virtio_drivers::Error) -> Self { + use virtio_drivers::Error::SocketDeviceError; + use virtio_drivers::device::socket::SocketError; + + match e { + SocketDeviceError(SocketError::ConnectionExists) => VsockError::ConnectionExists, + SocketDeviceError(SocketError::NotConnected) => VsockError::NotConnected, + SocketDeviceError(SocketError::PeerSocketShutdown) => VsockError::PeerSocketShutdown, + _ => VsockError::DriverError, + } + } +} diff --git a/scripts/launch_guest.sh b/scripts/launch_guest.sh index 8a070d8e71..591bc6c1ec 100755 --- a/scripts/launch_guest.sh +++ b/scripts/launch_guest.sh @@ -31,7 +31,8 @@ IGVM_OBJ="" SNAPSHOT="on" STATE_DEVICE="" -STATE_ENABLE="" +VSOCK_DEVICE="" +VIRTIO=0 while [[ $# -gt 0 ]]; do case $1 in @@ -51,8 +52,7 @@ while [[ $# -gt 0 ]]; do shift ;; --state) - STATE_ENABLE="x-svsm-virtio-mmio=on" - STATE_DEVICE+="-global virtio-mmio.force-legacy=false " + VIRTIO=1 STATE_DEVICE+="-drive 
file=$2,format=raw,if=none,id=svsm_storage,cache=none " STATE_DEVICE+="-device virtio-blk-device,drive=svsm_storage " shift @@ -87,6 +87,12 @@ while [[ $# -gt 0 ]]; do shift shift ;; + --vsock) + VIRTIO=1 + VSOCK_DEVICE="-device vhost-vsock-device,guest-cid=$2 " + shift + shift + ;; --) shift break @@ -102,6 +108,13 @@ while [[ $# -gt 0 ]]; do esac done +VIRTIO_ENABLE="" +VIRTIO_CONFIG="" +if [ "$VIRTIO" -eq 1 ]; then + VIRTIO_ENABLE="x-svsm-virtio-mmio=on" + VIRTIO_CONFIG="-global virtio-mmio.force-legacy=false " +fi + # Split the QEMU version number so we can specify the correct parameters QEMU_VERSION=$($QEMU --version | grep -Po '(?<=version )[^ ]+') QEMU_MAJOR=${QEMU_VERSION%%.*} @@ -170,7 +183,7 @@ fi $SUDO_CMD \ "$QEMU" \ -cpu $CPU \ - -machine $MACHINE,$STATE_ENABLE \ + -machine $MACHINE,$VIRTIO_ENABLE \ -object $MEMORY \ $IGVM_OBJ \ $SNP_GUEST \ @@ -187,5 +200,7 @@ $SUDO_CMD \ $COM4_SERIAL \ $QEMU_EXIT_DEVICE \ $QEMU_TEST_IO_DEVICE \ + $VIRTIO_CONFIG \ $STATE_DEVICE \ + $VSOCK_DEVICE \ "$@" diff --git a/scripts/test-in-svsm.sh b/scripts/test-in-svsm.sh index 5f6e78eca2..bc832008e3 100755 --- a/scripts/test-in-svsm.sh +++ b/scripts/test-in-svsm.sh @@ -8,6 +8,8 @@ set -e SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +VSOCK_PORT=12345 +VSOCK_CID=10 test_io(){ PIPE_IN=$1 @@ -28,6 +30,13 @@ test_io(){ "02") sha256sum "$TEST_DIR/svsm_state.raw" | cut -f 1 -d ' ' | xxd -p -r > "$PIPE_IN" ;; + "03") + # virtio-vsock in svsm does not handle half duplex connections. 
+ echo -n "hello_world" | ncat --no-shutdown -l --vsock -p $VSOCK_PORT & + sleep 1 + # write port number as an unsigned int and not as ascii + python3 -c "import sys,struct; sys.stdout.buffer.write(struct.pack('I', $VSOCK_PORT))" > $PIPE_IN + ;; "") # skip EOF ;; @@ -44,9 +53,6 @@ mkfifo $TEST_DIR/pipe.out # Create a raw disk image (512kB in size) for virtio-blk tests containing random data dd if=/dev/urandom of="$TEST_DIR/svsm_state.raw" bs=512 count=1024 -test_io $TEST_DIR/pipe.in $TEST_DIR/pipe.out & -TEST_IO_PID=$! - LAUNCH_GUEST_ARGS="" while [[ $# -gt 0 ]]; do @@ -55,6 +61,16 @@ while [[ $# -gt 0 ]]; do LAUNCH_GUEST_ARGS+="--nocc " shift ;; + --vsock-cid) + VSOCK_CID="$2" + shift + shift + ;; + --vsock-port) + VSOCK_PORT="$2" + shift + shift + ;; --) shift break @@ -66,9 +82,12 @@ while [[ $# -gt 0 ]]; do esac done +test_io $TEST_DIR/pipe.in $TEST_DIR/pipe.out & +TEST_IO_PID=$! $SCRIPT_DIR/launch_guest.sh --igvm $SCRIPT_DIR/../bin/coconut-test-qemu.igvm \ --state "$TEST_DIR/svsm_state.raw" \ + --vsock "$VSOCK_CID" \ --unit-tests $TEST_DIR/pipe \ $LAUNCH_GUEST_ARGS "$@" || svsm_exit_code=$? diff --git a/virtio-drivers/Cargo.toml b/virtio-drivers/Cargo.toml index b59654ae3b..91fe6270c8 100644 --- a/virtio-drivers/Cargo.toml +++ b/virtio-drivers/Cargo.toml @@ -18,10 +18,12 @@ log = { workspace = true } bitflags = { workspace = true } enumn = "0.1.14" zerocopy = { workspace = true, features = ["derive"] } +embedded-io = { version = "0.6.1", optional = true } [features] -default = ["alloc"] +default = ["alloc", "embedded-io"] alloc = ["zerocopy/alloc"] +embedded-io = ["dep:embedded-io"] [lints] workspace = true diff --git a/virtio-drivers/src/device/mod.rs b/virtio-drivers/src/device/mod.rs index 2d517353d7..76da9b782f 100644 --- a/virtio-drivers/src/device/mod.rs +++ b/virtio-drivers/src/device/mod.rs @@ -3,4 +3,6 @@ //! Drivers for specific VirtIO devices. 
pub mod blk; +pub mod socket; + pub(crate) mod common; diff --git a/virtio-drivers/src/device/socket/connectionmanager.rs b/virtio-drivers/src/device/socket/connectionmanager.rs new file mode 100644 index 0000000000..1299e794e0 --- /dev/null +++ b/virtio-drivers/src/device/socket/connectionmanager.rs @@ -0,0 +1,841 @@ +// SPDX-License-Identifier: MIT + +use super::{ + DEFAULT_RX_BUFFER_SIZE, DisconnectReason, SocketError, VirtIOSocket, VsockEvent, + VsockEventType, protocol::VsockAddr, vsock::ConnectionInfo, +}; +use crate::{Hal, Result, transport::Transport}; +use alloc::{boxed::Box, vec::Vec}; +use core::cmp::min; +use core::convert::TryInto; +use core::hint::spin_loop; +use log::debug; +use zerocopy::FromZeros; + +const DEFAULT_PER_CONNECTION_BUFFER_CAPACITY: u32 = 1024; + +/// A higher level interface for VirtIO socket (vsock) devices. +/// +/// This keeps track of multiple vsock connections. +/// +/// `RX_BUFFER_SIZE` is the size in bytes of each buffer used in the RX virtqueue. This must be +/// bigger than `size_of::()`. +/// +/// # Example +/// +/// ``` +/// # use virtio_drivers::{Error, Hal}; +/// # use virtio_drivers::transport::Transport; +/// use virtio_drivers::device::socket::{VirtIOSocket, VsockAddr, VsockConnectionManager}; +/// +/// # fn example(transport: T) -> Result<(), Error> { +/// let mut socket = VsockConnectionManager::new(VirtIOSocket::::new(transport)?); +/// +/// // Start a thread to call `socket.poll()` and handle events. +/// +/// let remote_address = VsockAddr { cid: 2, port: 42 }; +/// let local_port = 1234; +/// socket.connect(remote_address, local_port)?; +/// +/// // Wait until `socket.poll()` returns an event indicating that the socket is connected. 
+/// +/// socket.send(remote_address, local_port, "Hello world".as_bytes())?; +/// +/// socket.shutdown(remote_address, local_port)?; +/// # Ok(()) +/// # } +/// ``` +#[derive(Debug)] +pub struct VsockConnectionManager< + H: Hal, + T: Transport, + const RX_BUFFER_SIZE: usize = DEFAULT_RX_BUFFER_SIZE, +> { + driver: VirtIOSocket, + per_connection_buffer_capacity: u32, + connections: Vec, + listening_ports: Vec, +} + +#[derive(Debug)] +struct Connection { + info: ConnectionInfo, + buffer: RingBuffer, + established: bool, + /// The peer sent a SHUTDOWN request, but we haven't yet responded with a RST because there is + /// still data in the buffer. + peer_requested_shutdown: bool, +} + +impl Connection { + fn new(peer: VsockAddr, local_port: u32, buffer_capacity: u32) -> Self { + let mut info = ConnectionInfo::new(peer, local_port); + info.buf_alloc = buffer_capacity; + Self { + info, + buffer: RingBuffer::new(buffer_capacity.try_into().unwrap()), + established: false, + peer_requested_shutdown: false, + } + } +} + +impl + VsockConnectionManager +{ + /// Construct a new connection manager wrapping the given low-level VirtIO socket driver. + pub fn new(driver: VirtIOSocket) -> Self { + Self::new_with_capacity(driver, DEFAULT_PER_CONNECTION_BUFFER_CAPACITY) + } + + /// Construct a new connection manager wrapping the given low-level VirtIO socket driver, with + /// the given per-connection buffer capacity. + pub fn new_with_capacity( + driver: VirtIOSocket, + per_connection_buffer_capacity: u32, + ) -> Self { + Self { + driver, + connections: Vec::new(), + listening_ports: Vec::new(), + per_connection_buffer_capacity, + } + } + + /// Returns the CID which has been assigned to this guest. + pub fn guest_cid(&self) -> u64 { + self.driver.guest_cid() + } + + /// Returns true if the given local port is currently in use. 
+ pub fn is_local_port_used(&self, port: u32) -> bool { + if self.listening_ports.contains(&port) { + return true; + } + + self.connections + .iter() + .any(|connection| connection.info.src_port == port) + } + + /// Returns true if a connection has been established, false otherwise + pub fn is_connection_established( + &mut self, + destination: VsockAddr, + src_port: u32, + ) -> Result { + let (_, connection) = get_connection(&mut self.connections, destination, src_port)?; + + Ok(connection.established) + } + + /// Allows incoming connections on the given port number. + pub fn listen(&mut self, port: u32) { + if !self.listening_ports.contains(&port) { + self.listening_ports.push(port); + } + } + + /// Stops allowing incoming connections on the given port number. + pub fn unlisten(&mut self, port: u32) { + self.listening_ports.retain(|p| *p != port); + } + + /// Sends a request to connect to the given destination. + /// + /// This returns as soon as the request is sent; you should wait until `poll` returns a + /// `VsockEventType::Connected` event indicating that the peer has accepted the connection + /// before sending data. + pub fn connect(&mut self, destination: VsockAddr, src_port: u32) -> Result { + if self.connections.iter().any(|connection| { + connection.info.dst == destination && connection.info.src_port == src_port + }) { + return Err(SocketError::ConnectionExists.into()); + } + + let new_connection = + Connection::new(destination, src_port, self.per_connection_buffer_capacity); + + self.driver.connect(&new_connection.info)?; + debug!("Connection requested: {:?}", new_connection.info); + self.connections.push(new_connection); + Ok(()) + } + + /// Sends the buffer to the destination. 
+ pub fn send(&mut self, destination: VsockAddr, src_port: u32, buffer: &[u8]) -> Result { + let (_, connection) = get_connection(&mut self.connections, destination, src_port)?; + if connection.peer_requested_shutdown { + return Err(SocketError::NotConnected.into()); + } + + self.driver.send(buffer, &mut connection.info) + } + + /// Polls the vsock device to receive data or other updates. + pub fn poll(&mut self) -> Result> { + let guest_cid = self.driver.guest_cid(); + let connections = &mut self.connections; + let per_connection_buffer_capacity = self.per_connection_buffer_capacity; + + let result = self.driver.poll(|event, body| { + let connection = get_connection_for_event(connections, &event, guest_cid); + + // Skip events which don't match any connection we know about, unless they are a + // connection request. + let connection = if let Some((_, connection)) = connection { + connection + } else if let VsockEventType::ConnectionRequest = event.event_type { + // If the requested connection already exists or the CID isn't ours, ignore it. + if connection.is_some() || event.destination.cid != guest_cid { + return Ok(None); + } + // Add the new connection to our list, at least for now. It will be removed again + // below if we weren't listening on the port. + connections.push(Connection::new( + event.source, + event.destination.port, + per_connection_buffer_capacity, + )); + connections.last_mut().unwrap() + } else { + return Ok(None); + }; + + // Update stored connection info. + connection.info.update_for_event(&event); + + if let VsockEventType::Received { length } = event.event_type { + // Copy to buffer + if !connection.buffer.add(body) { + return Err(SocketError::OutputBufferTooShort(length).into()); + } + } + + Ok(Some(event)) + })?; + + let Some(event) = result else { + return Ok(None); + }; + + // The connection must exist because we found it above in the callback. 
+ let (connection_index, connection) = + get_connection_for_event(connections, &event, guest_cid).unwrap(); + + match event.event_type { + VsockEventType::ConnectionRequest => { + if self.listening_ports.contains(&event.destination.port) { + self.driver.accept(&connection.info)?; + connection.established = true; + } else { + // Reject the connection request and remove it from our list. + self.driver.force_close(&connection.info)?; + self.connections.swap_remove(connection_index); + + // No need to pass the request on to the client, as we've already rejected it. + return Ok(None); + } + } + VsockEventType::Connected => { + connection.established = true; + } + VsockEventType::Disconnected { reason } => { + // Wait until client reads all data before removing connection. + if connection.buffer.is_empty() { + if reason == DisconnectReason::Shutdown { + self.driver.force_close(&connection.info)?; + } + self.connections.swap_remove(connection_index); + } else { + connection.peer_requested_shutdown = true; + } + } + VsockEventType::Received { .. } => { + // Already copied the buffer in the callback above. + } + VsockEventType::CreditRequest => { + // If the peer requested credit, send an update. + self.driver.credit_update(&connection.info)?; + // No need to pass the request on to the client, we've already handled it. + return Ok(None); + } + VsockEventType::CreditUpdate => {} + } + + Ok(Some(event)) + } + + /// Reads data received from the given connection. + pub fn recv(&mut self, peer: VsockAddr, src_port: u32, buffer: &mut [u8]) -> Result { + let (connection_index, connection) = get_connection(&mut self.connections, peer, src_port)?; + + // Copy from ring buffer + let bytes_read = connection.buffer.drain(buffer); + + connection.info.done_forwarding(bytes_read); + + // If buffer is now empty and the peer requested shutdown, finish shutting down the + // connection. 
+ if connection.peer_requested_shutdown && connection.buffer.is_empty() { + self.driver.force_close(&connection.info)?; + self.connections.swap_remove(connection_index); + } + + Ok(bytes_read) + } + + /// Returns the number of bytes in the receive buffer available to be read by `recv`. + /// + /// When the available bytes is 0, it indicates that the receive buffer is empty and does not + /// contain any data. + pub fn recv_buffer_available_bytes(&mut self, peer: VsockAddr, src_port: u32) -> Result { + let (_, connection) = get_connection(&mut self.connections, peer, src_port)?; + Ok(connection.buffer.used()) + } + + /// Sends a credit update to the given peer. + pub fn update_credit(&mut self, peer: VsockAddr, src_port: u32) -> Result { + let (_, connection) = get_connection(&mut self.connections, peer, src_port)?; + if connection.peer_requested_shutdown { + return Err(SocketError::NotConnected.into()); + } + + self.driver.credit_update(&connection.info) + } + + /// Blocks until we get some event from the vsock device. + pub fn wait_for_event(&mut self) -> Result { + loop { + if let Some(event) = self.poll()? { + return Ok(event); + } else { + spin_loop(); + } + } + } + + /// Requests to shut down the connection cleanly, telling the peer that we won't send or receive + /// any more data. + /// + /// This returns as soon as the request is sent; you should wait until `poll` returns a + /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the + /// shutdown. + pub fn shutdown(&mut self, destination: VsockAddr, src_port: u32) -> Result { + let (_, connection) = get_connection(&mut self.connections, destination, src_port)?; + + self.driver.shutdown(&connection.info) + } + + /// Forcibly closes the connection without waiting for the peer. 
+ pub fn force_close(&mut self, destination: VsockAddr, src_port: u32) -> Result { + let (index, connection) = get_connection(&mut self.connections, destination, src_port)?; + + self.driver.force_close(&connection.info)?; + + self.connections.swap_remove(index); + Ok(()) + } +} + +/// Returns the connection from the given list matching the given peer address and local port, and +/// its index. +/// +/// Returns `Err(SocketError::NotConnected)` if there is no matching connection in the list. +fn get_connection( + connections: &mut [Connection], + peer: VsockAddr, + local_port: u32, +) -> core::result::Result<(usize, &mut Connection), SocketError> { + connections + .iter_mut() + .enumerate() + .find(|(_, connection)| { + connection.info.dst == peer && connection.info.src_port == local_port + }) + .ok_or(SocketError::NotConnected) +} + +/// Returns the connection from the given list matching the event, if any, and its index. +fn get_connection_for_event<'a>( + connections: &'a mut [Connection], + event: &VsockEvent, + local_cid: u64, +) -> Option<(usize, &'a mut Connection)> { + connections + .iter_mut() + .enumerate() + .find(|(_, connection)| event.matches_connection(&connection.info, local_cid)) +} + +#[derive(Debug)] +struct RingBuffer { + buffer: Box<[u8]>, + /// The number of bytes currently in the buffer. + used: usize, + /// The index of the first used byte in the buffer. + start: usize, +} + +impl RingBuffer { + pub fn new(capacity: usize) -> Self { + Self { + buffer: FromZeros::new_box_zeroed_with_elems(capacity).unwrap(), + used: 0, + start: 0, + } + } + + /// Returns the number of bytes currently used in the buffer. + pub fn used(&self) -> usize { + self.used + } + + /// Returns true iff there are currently no bytes in the buffer. + pub fn is_empty(&self) -> bool { + self.used == 0 + } + + /// Returns the number of bytes currently free in the buffer. 
+ pub fn free(&self) -> usize { + self.buffer.len() - self.used + } + + /// Adds the given bytes to the buffer if there is enough capacity for them all. + /// + /// Returns true if they were added, or false if they were not. + pub fn add(&mut self, bytes: &[u8]) -> bool { + if bytes.len() > self.free() { + return false; + } + + // The index of the first available position in the buffer. + let first_available = (self.start + self.used) % self.buffer.len(); + // The number of bytes to copy from `bytes` to `buffer` between `first_available` and + // `buffer.len()`. + let copy_length_before_wraparound = min(bytes.len(), self.buffer.len() - first_available); + self.buffer[first_available..first_available + copy_length_before_wraparound] + .copy_from_slice(&bytes[0..copy_length_before_wraparound]); + if let Some(bytes_after_wraparound) = bytes.get(copy_length_before_wraparound..) { + self.buffer[0..bytes_after_wraparound.len()].copy_from_slice(bytes_after_wraparound); + } + self.used += bytes.len(); + + true + } + + /// Reads and removes as many bytes as possible from the buffer, up to the length of the given + /// buffer. + pub fn drain(&mut self, out: &mut [u8]) -> usize { + let bytes_read = min(self.used, out.len()); + + // The number of bytes to copy out between `start` and the end of the buffer. + let read_before_wraparound = min(bytes_read, self.buffer.len() - self.start); + // The number of bytes to copy out from the beginning of the buffer after wrapping around. 
+ let read_after_wraparound = bytes_read + .checked_sub(read_before_wraparound) + .unwrap_or_default(); + + out[0..read_before_wraparound] + .copy_from_slice(&self.buffer[self.start..self.start + read_before_wraparound]); + out[read_before_wraparound..bytes_read] + .copy_from_slice(&self.buffer[0..read_after_wraparound]); + + self.used -= bytes_read; + self.start = (self.start + bytes_read) % self.buffer.len(); + + bytes_read + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + device::socket::{ + protocol::{ + SocketType, StreamShutdown, VirtioVsockConfig, VirtioVsockHdr, VirtioVsockOp, + }, + vsock::{QUEUE_SIZE, RX_QUEUE_IDX, TX_QUEUE_IDX, VsockBufferStatus}, + }, + hal::fake::FakeHal, + transport::{ + DeviceType, + fake::{FakeTransport, QueueStatus, State}, + }, + volatile::ReadOnly, + }; + use alloc::{sync::Arc, vec}; + use core::{mem::size_of, ptr::NonNull}; + use std::{sync::Mutex, thread}; + use zerocopy::{FromBytes, IntoBytes}; + + #[test] + fn send_recv() { + let host_cid = 2; + let guest_cid = 66; + let host_port = 1234; + let guest_port = 4321; + let host_address = VsockAddr { + cid: host_cid, + port: host_port, + }; + let hello_from_guest = "Hello from guest"; + let hello_from_host = "Hello from host"; + + let mut config_space = VirtioVsockConfig { + guest_cid_low: ReadOnly::new(66), + guest_cid_high: ReadOnly::new(0), + }; + let state = Arc::new(Mutex::new(State { + queues: vec![ + QueueStatus::default(), + QueueStatus::default(), + QueueStatus::default(), + ], + ..Default::default() + })); + let transport = FakeTransport { + device_type: DeviceType::Socket, + max_queue_size: 32, + device_features: 0, + config_space: NonNull::from(&mut config_space), + state: state.clone(), + }; + let mut socket = VsockConnectionManager::new( + VirtIOSocket::>::new(transport).unwrap(), + ); + + // Start a thread to simulate the device. + let handle = thread::spawn(move || { + // Wait for connection request. 
+ State::wait_until_queue_notified(&state, TX_QUEUE_IDX); + assert_eq!( + VirtioVsockHdr::read_from_bytes( + state + .lock() + .unwrap() + .read_from_queue::(TX_QUEUE_IDX) + .as_slice() + ) + .unwrap(), + VirtioVsockHdr { + op: VirtioVsockOp::Request.into(), + src_cid: guest_cid.into(), + dst_cid: host_cid.into(), + src_port: guest_port.into(), + dst_port: host_port.into(), + len: 0.into(), + socket_type: SocketType::Stream.into(), + flags: 0.into(), + buf_alloc: 1024.into(), + fwd_cnt: 0.into(), + } + ); + + // Accept connection and give the peer enough credit to send the message. + state.lock().unwrap().write_to_queue::( + RX_QUEUE_IDX, + VirtioVsockHdr { + op: VirtioVsockOp::Response.into(), + src_cid: host_cid.into(), + dst_cid: guest_cid.into(), + src_port: host_port.into(), + dst_port: guest_port.into(), + len: 0.into(), + socket_type: SocketType::Stream.into(), + flags: 0.into(), + buf_alloc: 50.into(), + fwd_cnt: 0.into(), + } + .as_bytes(), + ); + + // Expect the guest to send some data. + State::wait_until_queue_notified(&state, TX_QUEUE_IDX); + let request = state + .lock() + .unwrap() + .read_from_queue::(TX_QUEUE_IDX); + assert_eq!( + request.len(), + size_of::() + hello_from_guest.len() + ); + assert_eq!( + VirtioVsockHdr::read_from_prefix(request.as_slice()) + .unwrap() + .0, + VirtioVsockHdr { + op: VirtioVsockOp::Rw.into(), + src_cid: guest_cid.into(), + dst_cid: host_cid.into(), + src_port: guest_port.into(), + dst_port: host_port.into(), + len: (hello_from_guest.len() as u32).into(), + socket_type: SocketType::Stream.into(), + flags: 0.into(), + buf_alloc: 1024.into(), + fwd_cnt: 0.into(), + } + ); + assert_eq!( + &request[size_of::()..], + hello_from_guest.as_bytes() + ); + + println!("Host sending"); + + // Send a response. 
+ let mut response = vec![0; size_of::() + hello_from_host.len()]; + VirtioVsockHdr { + op: VirtioVsockOp::Rw.into(), + src_cid: host_cid.into(), + dst_cid: guest_cid.into(), + src_port: host_port.into(), + dst_port: guest_port.into(), + len: (hello_from_host.len() as u32).into(), + socket_type: SocketType::Stream.into(), + flags: 0.into(), + buf_alloc: 50.into(), + fwd_cnt: (hello_from_guest.len() as u32).into(), + } + .write_to_prefix(response.as_mut_slice()) + .unwrap(); + response[size_of::()..].copy_from_slice(hello_from_host.as_bytes()); + state + .lock() + .unwrap() + .write_to_queue::(RX_QUEUE_IDX, &response); + + // Expect a shutdown. + State::wait_until_queue_notified(&state, TX_QUEUE_IDX); + assert_eq!( + VirtioVsockHdr::read_from_bytes( + state + .lock() + .unwrap() + .read_from_queue::(TX_QUEUE_IDX) + .as_slice() + ) + .unwrap(), + VirtioVsockHdr { + op: VirtioVsockOp::Shutdown.into(), + src_cid: guest_cid.into(), + dst_cid: host_cid.into(), + src_port: guest_port.into(), + dst_port: host_port.into(), + len: 0.into(), + socket_type: SocketType::Stream.into(), + flags: (StreamShutdown::SEND | StreamShutdown::RECEIVE).into(), + buf_alloc: 1024.into(), + fwd_cnt: (hello_from_host.len() as u32).into(), + } + ); + }); + + socket.connect(host_address, guest_port).unwrap(); + assert_eq!( + socket.wait_for_event().unwrap(), + VsockEvent { + source: host_address, + destination: VsockAddr { + cid: guest_cid, + port: guest_port, + }, + event_type: VsockEventType::Connected, + buffer_status: VsockBufferStatus { + buffer_allocation: 50, + forward_count: 0, + }, + } + ); + println!("Guest sending"); + socket + .send(host_address, guest_port, "Hello from guest".as_bytes()) + .unwrap(); + println!("Guest waiting to receive."); + assert_eq!( + socket.wait_for_event().unwrap(), + VsockEvent { + source: host_address, + destination: VsockAddr { + cid: guest_cid, + port: guest_port, + }, + event_type: VsockEventType::Received { + length: hello_from_host.len() + }, + 
buffer_status: VsockBufferStatus { + buffer_allocation: 50, + forward_count: hello_from_guest.len() as u32, + }, + } + ); + println!("Guest getting received data."); + let mut buffer = [0u8; 64]; + assert_eq!( + socket.recv(host_address, guest_port, &mut buffer).unwrap(), + hello_from_host.len() + ); + assert_eq!( + &buffer[0..hello_from_host.len()], + hello_from_host.as_bytes() + ); + socket.shutdown(host_address, guest_port).unwrap(); + + handle.join().unwrap(); + } + + #[test] + fn incoming_connection() { + let host_cid = 2; + let guest_cid = 66; + let host_port = 1234; + let guest_port = 4321; + let wrong_guest_port = 4444; + let host_address = VsockAddr { + cid: host_cid, + port: host_port, + }; + + let mut config_space = VirtioVsockConfig { + guest_cid_low: ReadOnly::new(66), + guest_cid_high: ReadOnly::new(0), + }; + let state = Arc::new(Mutex::new(State { + queues: vec![ + QueueStatus::default(), + QueueStatus::default(), + QueueStatus::default(), + ], + ..Default::default() + })); + let transport = FakeTransport { + device_type: DeviceType::Socket, + max_queue_size: 32, + device_features: 0, + config_space: NonNull::from(&mut config_space), + state: state.clone(), + }; + let mut socket = VsockConnectionManager::new( + VirtIOSocket::>::new(transport).unwrap(), + ); + + socket.listen(guest_port); + + // Start a thread to simulate the device. + let handle = thread::spawn(move || { + // Send a connection request for a port the guest isn't listening on. + println!("Host sending connection request to wrong port"); + state.lock().unwrap().write_to_queue::( + RX_QUEUE_IDX, + VirtioVsockHdr { + op: VirtioVsockOp::Request.into(), + src_cid: host_cid.into(), + dst_cid: guest_cid.into(), + src_port: host_port.into(), + dst_port: wrong_guest_port.into(), + len: 0.into(), + socket_type: SocketType::Stream.into(), + flags: 0.into(), + buf_alloc: 50.into(), + fwd_cnt: 0.into(), + } + .as_bytes(), + ); + + // Expect a rejection. 
+ println!("Host waiting for rejection"); + State::wait_until_queue_notified(&state, TX_QUEUE_IDX); + assert_eq!( + VirtioVsockHdr::read_from_bytes( + state + .lock() + .unwrap() + .read_from_queue::(TX_QUEUE_IDX) + .as_slice() + ) + .unwrap(), + VirtioVsockHdr { + op: VirtioVsockOp::Rst.into(), + src_cid: guest_cid.into(), + dst_cid: host_cid.into(), + src_port: wrong_guest_port.into(), + dst_port: host_port.into(), + len: 0.into(), + socket_type: SocketType::Stream.into(), + flags: 0.into(), + buf_alloc: 1024.into(), + fwd_cnt: 0.into(), + } + ); + + // Send a connection request for a port the guest is listening on. + println!("Host sending connection request to right port"); + state.lock().unwrap().write_to_queue::( + RX_QUEUE_IDX, + VirtioVsockHdr { + op: VirtioVsockOp::Request.into(), + src_cid: host_cid.into(), + dst_cid: guest_cid.into(), + src_port: host_port.into(), + dst_port: guest_port.into(), + len: 0.into(), + socket_type: SocketType::Stream.into(), + flags: 0.into(), + buf_alloc: 50.into(), + fwd_cnt: 0.into(), + } + .as_bytes(), + ); + + // Expect a response. + println!("Host waiting for response"); + State::wait_until_queue_notified(&state, TX_QUEUE_IDX); + assert_eq!( + VirtioVsockHdr::read_from_bytes( + state + .lock() + .unwrap() + .read_from_queue::(TX_QUEUE_IDX) + .as_slice() + ) + .unwrap(), + VirtioVsockHdr { + op: VirtioVsockOp::Response.into(), + src_cid: guest_cid.into(), + dst_cid: host_cid.into(), + src_port: guest_port.into(), + dst_port: host_port.into(), + len: 0.into(), + socket_type: SocketType::Stream.into(), + flags: 0.into(), + buf_alloc: 1024.into(), + fwd_cnt: 0.into(), + } + ); + + println!("Host finished"); + }); + + // Expect an incoming connection. 
+ println!("Guest expecting incoming connection."); + assert_eq!( + socket.wait_for_event().unwrap(), + VsockEvent { + source: host_address, + destination: VsockAddr { + cid: guest_cid, + port: guest_port, + }, + event_type: VsockEventType::ConnectionRequest, + buffer_status: VsockBufferStatus { + buffer_allocation: 50, + forward_count: 0, + }, + } + ); + + handle.join().unwrap(); + } +} diff --git a/virtio-drivers/src/device/socket/error.rs b/virtio-drivers/src/device/socket/error.rs new file mode 100644 index 0000000000..81040e941b --- /dev/null +++ b/virtio-drivers/src/device/socket/error.rs @@ -0,0 +1,75 @@ +// SPDX-License-Identifier: MIT + +//! This module contains the error type for the VirtIO socket driver. + +use core::{fmt, result}; + +/// The error type of the VirtIO socket driver. +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +pub enum SocketError { + /// There is an existing connection. + ConnectionExists, + /// The device is not connected to any peer. + NotConnected, + /// Peer socket is shut down. + PeerSocketShutdown, + /// The given buffer is shorter than expected. + BufferTooShort, + /// The given buffer for output is shorter than expected. + OutputBufferTooShort(usize), + /// The given buffer has exceeded the maximum buffer size. + BufferTooLong(usize, usize), + /// Unknown operation. + UnknownOperation(u16), + /// Invalid operation. + InvalidOperation, + /// Invalid number. + InvalidNumber, + /// Unexpected data in packet. + UnexpectedDataInPacket, + /// Peer has insufficient buffer space, try again later. + InsufficientBufferSpaceInPeer, + /// Recycled a wrong buffer. + RecycledWrongBuffer, +} + +impl fmt::Display for SocketError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::ConnectionExists => write!( + f, + "There is an existing connection. Please close the current connection before attempting to connect again." + ), + Self::NotConnected => write!( + f, + "The device is not connected to any peer. 
Please connect it to a peer first." + ), + Self::PeerSocketShutdown => write!(f, "The peer socket is shutdown."), + Self::BufferTooShort => write!(f, "The given buffer is shorter than expected"), + Self::BufferTooLong(actual, max) => { + write!( + f, + "The given buffer length '{actual}' has exceeded the maximum allowed buffer length '{max}'" + ) + } + Self::OutputBufferTooShort(expected) => { + write!( + f, + "The given output buffer is too short. '{expected}' bytes is needed for the output buffer." + ) + } + Self::UnknownOperation(op) => { + write!(f, "The operation code '{op}' is unknown") + } + Self::InvalidOperation => write!(f, "Invalid operation"), + Self::InvalidNumber => write!(f, "Invalid number"), + Self::UnexpectedDataInPacket => write!(f, "No data is expected in the packet"), + Self::InsufficientBufferSpaceInPeer => { + write!(f, "Peer has insufficient buffer space, try again later") + } + Self::RecycledWrongBuffer => write!(f, "Recycled a wrong buffer"), + } + } +} + +pub type Result = result::Result; diff --git a/virtio-drivers/src/device/socket/mod.rs b/virtio-drivers/src/device/socket/mod.rs new file mode 100644 index 0000000000..0bd24d8dfc --- /dev/null +++ b/virtio-drivers/src/device/socket/mod.rs @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: MIT + +//! Driver for VirtIO socket devices. +//! +//! To use the driver, you should first create a [`VirtIOSocket`] instance with your VirtIO +//! transport, and then create a [`VsockConnectionManager`] wrapping it to keep track of +//! connections. If you want to manage connections yourself you can use the `VirtIOSocket` directly +//! for a lower-level interface. +//! +//! See [`VsockConnectionManager`] for a usage example. 
+ +#[cfg(feature = "alloc")] +mod connectionmanager; +mod error; +mod protocol; +#[cfg(feature = "alloc")] +mod vsock; + +#[cfg(feature = "alloc")] +pub use connectionmanager::VsockConnectionManager; +pub use error::SocketError; +pub use protocol::{StreamShutdown, VMADDR_CID_HOST, VsockAddr}; +#[cfg(feature = "alloc")] +pub use vsock::{ConnectionInfo, DisconnectReason, VirtIOSocket, VsockEvent, VsockEventType}; + +/// The size in bytes of each buffer used in the RX virtqueue. This must be bigger than +/// `size_of::<VirtioVsockHdr>()`. +const DEFAULT_RX_BUFFER_SIZE: usize = 512; diff --git a/virtio-drivers/src/device/socket/protocol.rs b/virtio-drivers/src/device/socket/protocol.rs new file mode 100644 index 0000000000..7c476f3fe2 --- /dev/null +++ b/virtio-drivers/src/device/socket/protocol.rs @@ -0,0 +1,235 @@ +// SPDX-License-Identifier: MIT + +//! This module defines the socket device protocol according to the virtio spec v1.1 5.10 Socket Device + +use super::error::{self, SocketError}; +use crate::volatile::ReadOnly; +use bitflags::bitflags; +use core::{ + convert::{TryFrom, TryInto}, + fmt, +}; +use zerocopy::{ + FromBytes, Immutable, IntoBytes, KnownLayout, + byteorder::{LittleEndian, U16, U32, U64}, +}; + +/// Well-known CID for the host. +pub const VMADDR_CID_HOST: u64 = 2; + +/// Currently only stream sockets are supported. type is 1 for stream socket types. +#[derive(Copy, Clone, Debug)] +#[repr(u16)] +pub enum SocketType { + /// Stream sockets provide in-order, guaranteed, connection-oriented delivery without message boundaries. + Stream = 1, + /// seqpacket socket type introduced in virtio-v1.2. + SeqPacket = 2, +} + +impl From for U16 { + fn from(socket_type: SocketType) -> Self { + (socket_type as u16).into() + } +} + +/// VirtioVsockConfig is the vsock device configuration space. +#[repr(C)] +pub struct VirtioVsockConfig { + /// The guest_cid field contains the guest’s context ID, which uniquely identifies + /// the device for its lifetime. 
The upper 32 bits of the CID are reserved and zeroed. + /// + /// According to virtio spec v1.1 2.4.1 Driver Requirements: Device Configuration Space, + /// drivers MUST NOT assume reads from fields greater than 32 bits wide are atomic. + /// So we need to split the u64 guest_cid into two parts. + pub guest_cid_low: ReadOnly, + pub guest_cid_high: ReadOnly, +} + +/// The message header for data packets sent on the tx/rx queues +#[repr(C, packed)] +#[derive(Clone, Copy, Debug, Eq, FromBytes, Immutable, IntoBytes, KnownLayout, PartialEq)] +pub struct VirtioVsockHdr { + pub src_cid: U64, + pub dst_cid: U64, + pub src_port: U32, + pub dst_port: U32, + pub len: U32, + pub socket_type: U16, + pub op: U16, + pub flags: U32, + /// Total receive buffer space for this socket. This includes both free and in-use buffers. + pub buf_alloc: U32, + /// Free-running bytes received counter. + pub fwd_cnt: U32, +} + +impl Default for VirtioVsockHdr { + fn default() -> Self { + Self { + src_cid: 0.into(), + dst_cid: 0.into(), + src_port: 0.into(), + dst_port: 0.into(), + len: 0.into(), + socket_type: SocketType::Stream.into(), + op: 0.into(), + flags: 0.into(), + buf_alloc: 0.into(), + fwd_cnt: 0.into(), + } + } +} + +impl VirtioVsockHdr { + /// Returns the length of the data. + pub fn len(&self) -> u32 { + u32::from(self.len) + } + + pub fn op(&self) -> error::Result { + self.op.try_into() + } + + pub fn source(&self) -> VsockAddr { + VsockAddr { + cid: self.src_cid.get(), + port: self.src_port.get(), + } + } + + pub fn destination(&self) -> VsockAddr { + VsockAddr { + cid: self.dst_cid.get(), + port: self.dst_port.get(), + } + } + + pub fn check_data_is_empty(&self) -> error::Result<()> { + if self.len() == 0 { + Ok(()) + } else { + Err(SocketError::UnexpectedDataInPacket) + } + } +} + +/// Socket address. +#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)] +pub struct VsockAddr { + /// Context Identifier. + pub cid: u64, + /// Port number. 
+ pub port: u32, +} + +/// An event sent to the event queue +#[derive(Copy, Clone, Debug, Default, IntoBytes, FromBytes, Immutable, KnownLayout)] +#[repr(C)] +pub struct VirtioVsockEvent { + // ID from the virtio_vsock_event_id struct in the virtio spec + pub id: U32, +} + +#[derive(Copy, Clone, Eq, PartialEq)] +#[repr(u16)] +pub enum VirtioVsockOp { + Invalid = 0, + + /* Connect operations */ + Request = 1, + Response = 2, + Rst = 3, + Shutdown = 4, + + /* To send payload */ + Rw = 5, + + /* Tell the peer our credit info */ + CreditUpdate = 6, + /* Request the peer to send the credit info to us */ + CreditRequest = 7, +} + +impl From for U16 { + fn from(op: VirtioVsockOp) -> Self { + (op as u16).into() + } +} + +impl TryFrom> for VirtioVsockOp { + type Error = SocketError; + + fn try_from(v: U16) -> Result { + let op = match u16::from(v) { + 0 => Self::Invalid, + 1 => Self::Request, + 2 => Self::Response, + 3 => Self::Rst, + 4 => Self::Shutdown, + 5 => Self::Rw, + 6 => Self::CreditUpdate, + 7 => Self::CreditRequest, + _ => return Err(SocketError::UnknownOperation(v.into())), + }; + Ok(op) + } +} + +impl fmt::Debug for VirtioVsockOp { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Invalid => write!(f, "VIRTIO_VSOCK_OP_INVALID"), + Self::Request => write!(f, "VIRTIO_VSOCK_OP_REQUEST"), + Self::Response => write!(f, "VIRTIO_VSOCK_OP_RESPONSE"), + Self::Rst => write!(f, "VIRTIO_VSOCK_OP_RST"), + Self::Shutdown => write!(f, "VIRTIO_VSOCK_OP_SHUTDOWN"), + Self::Rw => write!(f, "VIRTIO_VSOCK_OP_RW"), + Self::CreditUpdate => write!(f, "VIRTIO_VSOCK_OP_CREDIT_UPDATE"), + Self::CreditRequest => write!(f, "VIRTIO_VSOCK_OP_CREDIT_REQUEST"), + } + } +} + +bitflags! { + #[derive(Copy, Clone, Debug, Default, Eq, PartialEq)] + pub(crate) struct Feature: u64 { + /// stream socket type is supported. + const STREAM = 1 << 0; + /// seqpacket socket type is supported. 
+ const SEQ_PACKET = 1 << 1; + + // device independent + const NOTIFY_ON_EMPTY = 1 << 24; // legacy + const ANY_LAYOUT = 1 << 27; // legacy + const RING_INDIRECT_DESC = 1 << 28; + const RING_EVENT_IDX = 1 << 29; + const UNUSED = 1 << 30; // legacy + const VERSION_1 = 1 << 32; // detect legacy + + // since virtio v1.1 + const ACCESS_PLATFORM = 1 << 33; + const RING_PACKED = 1 << 34; + const IN_ORDER = 1 << 35; + const ORDER_PLATFORM = 1 << 36; + const SR_IOV = 1 << 37; + const NOTIFICATION_DATA = 1 << 38; + } +} + +bitflags! { + /// Flags sent with a shutdown request to hint that the peer won't send or receive more data. + #[derive(Copy, Clone, Debug, Default, Eq, PartialEq)] + pub struct StreamShutdown: u32 { + /// The peer will not receive any more data. + const RECEIVE = 1 << 0; + /// The peer will not send any more data. + const SEND = 1 << 1; + } +} + +impl From for U32 { + fn from(flags: StreamShutdown) -> Self { + flags.bits().into() + } +} diff --git a/virtio-drivers/src/device/socket/vsock.rs b/virtio-drivers/src/device/socket/vsock.rs new file mode 100644 index 0000000000..a1f57ce781 --- /dev/null +++ b/virtio-drivers/src/device/socket/vsock.rs @@ -0,0 +1,505 @@ +// SPDX-License-Identifier: MIT + +//! Driver for VirtIO socket devices. 
+#![deny(unsafe_op_in_unsafe_fn)] + +use super::DEFAULT_RX_BUFFER_SIZE; +use super::error::SocketError; +use super::protocol::{ + Feature, StreamShutdown, VirtioVsockConfig, VirtioVsockHdr, VirtioVsockOp, VsockAddr, +}; +use crate::Result; +use crate::hal::Hal; +use crate::queue::{VirtQueue, owning::OwningQueue}; +use crate::transport::Transport; +use crate::volatile::volread; +use core::mem::size_of; +use log::debug; +use zerocopy::{FromBytes, IntoBytes}; + +pub(crate) const RX_QUEUE_IDX: u16 = 0; +pub(crate) const TX_QUEUE_IDX: u16 = 1; +const EVENT_QUEUE_IDX: u16 = 2; + +pub(crate) const QUEUE_SIZE: usize = 8; +const SUPPORTED_FEATURES: Feature = Feature::RING_EVENT_IDX.union(Feature::RING_INDIRECT_DESC); + +/// Information about a particular vsock connection. +#[derive(Clone, Debug, Default, PartialEq, Eq)] +pub struct ConnectionInfo { + /// The address of the peer. + pub dst: VsockAddr, + /// The local port number associated with the connection. + pub src_port: u32, + /// The last `buf_alloc` value the peer sent to us, indicating how much receive buffer space in + /// bytes it has allocated for packet bodies. + peer_buf_alloc: u32, + /// The last `fwd_cnt` value the peer sent to us, indicating how many bytes of packet bodies it + /// has finished processing. + peer_fwd_cnt: u32, + /// The number of bytes of packet bodies which we have sent to the peer. + tx_cnt: u32, + /// The number of bytes of buffer space we have allocated to receive packet bodies from the + /// peer. + pub buf_alloc: u32, + /// The number of bytes of packet bodies which we have received from the peer and handled. + fwd_cnt: u32, + /// Whether we have recently requested credit from the peer. + /// + /// This is set to true when we send a `VIRTIO_VSOCK_OP_CREDIT_REQUEST`, and false when we + /// receive a `VIRTIO_VSOCK_OP_CREDIT_UPDATE`. 
+    has_pending_credit_request: bool,
+}
+
+impl ConnectionInfo {
+    /// Creates a new `ConnectionInfo` for the given peer address and local port, and default values
+    /// for everything else.
+    pub fn new(destination: VsockAddr, src_port: u32) -> Self {
+        Self {
+            dst: destination,
+            src_port,
+            ..Default::default()
+        }
+    }
+
+    /// Updates this connection info with the peer buffer allocation and forwarded count from the
+    /// given event.
+    ///
+    /// A `CreditUpdate` event additionally clears the pending-credit-request flag, so a new
+    /// credit request may be sent the next time the peer's buffer looks full.
+    pub fn update_for_event(&mut self, event: &VsockEvent) {
+        self.peer_buf_alloc = event.buffer_status.buffer_allocation;
+        self.peer_fwd_cnt = event.buffer_status.forward_count;
+
+        if let VsockEventType::CreditUpdate = event.event_type {
+            self.has_pending_credit_request = false;
+        }
+    }
+
+    /// Increases the forwarded count recorded for this connection by the given number of bytes.
+    ///
+    /// This should be called once received data has been passed to the client, so there is buffer
+    /// space available for more.
+    pub fn done_forwarding(&mut self, length: usize) {
+        // `fwd_cnt` is a free-running 32-bit counter, so it is advanced modulo 2^32: the
+        // truncating cast and the wrapping add are intentional. A plain `+=` would panic on
+        // overflow in builds with overflow checks enabled.
+        self.fwd_cnt = self.fwd_cnt.wrapping_add(length as u32);
+    }
+
+    /// Returns the number of bytes of RX buffer space the peer has available to receive packet body
+    /// data from us.
+    ///
+    /// Returns 0 when the peer reports less buffer space than the number of bytes currently in
+    /// flight.
+    fn peer_free(&self) -> u32 {
+        // Bytes we have sent which the peer has not yet reported as processed. Both counters are
+        // free-running, so the difference is taken modulo 2^32.
+        let in_flight = self.tx_cnt.wrapping_sub(self.peer_fwd_cnt);
+        // `peer_buf_alloc` and `peer_fwd_cnt` are copied from received packets in
+        // `update_for_event`, so they may be inconsistent with `tx_cnt`. Clamp to zero instead
+        // of underflowing, which would either panic or report a huge amount of free space.
+        self.peer_buf_alloc.saturating_sub(in_flight)
+    }
+
+    /// Builds a packet header addressed from `src_cid` and this connection's local port to the
+    /// peer, carrying our current `buf_alloc` and `fwd_cnt`.
+    ///
+    /// All remaining header fields are left at their defaults for the caller to override.
+    fn new_header(&self, src_cid: u64) -> VirtioVsockHdr {
+        VirtioVsockHdr {
+            src_cid: src_cid.into(),
+            dst_cid: self.dst.cid.into(),
+            src_port: self.src_port.into(),
+            dst_port: self.dst.port.into(),
+            buf_alloc: self.buf_alloc.into(),
+            fwd_cnt: self.fwd_cnt.into(),
+            ..Default::default()
+        }
+    }
+}
+
+/// An event received from a VirtIO socket device.
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct VsockEvent {
+    /// The source of the event, i.e. the peer who sent it.
+    pub source: VsockAddr,
+    /// The destination of the event, i.e. the CID and port on our side.
+    pub destination: VsockAddr,
+    /// The peer's buffer status for the connection.
+    pub buffer_status: VsockBufferStatus,
+    /// The type of event.
+ pub event_type: VsockEventType, +} + +impl VsockEvent { + /// Returns whether the event matches the given connection. + pub fn matches_connection(&self, connection_info: &ConnectionInfo, guest_cid: u64) -> bool { + self.source == connection_info.dst + && self.destination.cid == guest_cid + && self.destination.port == connection_info.src_port + } + + fn from_header(header: &VirtioVsockHdr) -> Result { + let op = header.op()?; + let buffer_status = VsockBufferStatus { + buffer_allocation: header.buf_alloc.into(), + forward_count: header.fwd_cnt.into(), + }; + let source = header.source(); + let destination = header.destination(); + + let event_type = match op { + VirtioVsockOp::Request => { + header.check_data_is_empty()?; + VsockEventType::ConnectionRequest + } + VirtioVsockOp::Response => { + header.check_data_is_empty()?; + VsockEventType::Connected + } + VirtioVsockOp::CreditUpdate => { + header.check_data_is_empty()?; + VsockEventType::CreditUpdate + } + VirtioVsockOp::Rst | VirtioVsockOp::Shutdown => { + header.check_data_is_empty()?; + debug!("Disconnected from the peer"); + let reason = if op == VirtioVsockOp::Rst { + DisconnectReason::Reset + } else { + DisconnectReason::Shutdown + }; + VsockEventType::Disconnected { reason } + } + VirtioVsockOp::Rw => VsockEventType::Received { + length: header.len() as usize, + }, + VirtioVsockOp::CreditRequest => { + header.check_data_is_empty()?; + VsockEventType::CreditRequest + } + VirtioVsockOp::Invalid => return Err(SocketError::InvalidOperation.into()), + }; + + Ok(VsockEvent { + source, + destination, + buffer_status, + event_type, + }) + } +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct VsockBufferStatus { + pub buffer_allocation: u32, + pub forward_count: u32, +} + +/// The reason why a vsock connection was closed. 
+#[derive(Copy, Clone, Debug, Eq, PartialEq)] +pub enum DisconnectReason { + /// The peer has either closed the connection in response to our shutdown request, or forcibly + /// closed it of its own accord. + Reset, + /// The peer asked to shut down the connection. + Shutdown, +} + +/// Details of the type of an event received from a VirtIO socket. +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum VsockEventType { + /// The peer requests to establish a connection with us. + ConnectionRequest, + /// The connection was successfully established. + Connected, + /// The connection was closed. + Disconnected { + /// The reason for the disconnection. + reason: DisconnectReason, + }, + /// Data was received on the connection. + Received { + /// The length of the data in bytes. + length: usize, + }, + /// The peer requests us to send a credit update. + CreditRequest, + /// The peer just sent us a credit update with nothing else. + CreditUpdate, +} + +/// Low-level driver for a VirtIO socket device. +/// +/// You probably want to use [`VsockConnectionManager`](super::VsockConnectionManager) rather than +/// using this directly. +/// +/// `RX_BUFFER_SIZE` is the size in bytes of each buffer used in the RX virtqueue. This must be +/// bigger than `size_of::()`. +#[derive(Debug)] +pub struct VirtIOSocket +{ + transport: T, + /// Virtqueue to receive packets. + rx: OwningQueue, + tx: VirtQueue, + /// Virtqueue to receive events from the device. + event: VirtQueue, + /// The guest_cid field contains the guest’s context ID, which uniquely identifies + /// the device for its lifetime. The upper 32 bits of the CID are reserved and zeroed. + guest_cid: u64, +} + +impl Drop + for VirtIOSocket +{ + fn drop(&mut self) { + // Clear any pointers pointing to DMA regions, so the device doesn't try to access them + // after they have been freed. 
+ self.transport.queue_unset(RX_QUEUE_IDX); + self.transport.queue_unset(TX_QUEUE_IDX); + self.transport.queue_unset(EVENT_QUEUE_IDX); + } +} + +impl VirtIOSocket { + /// Create a new VirtIO Vsock driver. + pub fn new(mut transport: T) -> Result { + assert!(RX_BUFFER_SIZE > size_of::()); + + let negotiated_features = transport.begin_init(SUPPORTED_FEATURES); + + let config = transport.config_space::()?; + debug!("config: {config:?}"); + // SAFETY: Safe because config is a valid pointer to the device configuration space. + let guest_cid = unsafe { + volread!(H, config, guest_cid_low) as u64 + | (volread!(H, config, guest_cid_high) as u64) << 32 + }; + debug!("guest cid: {guest_cid:?}"); + + let rx = VirtQueue::new( + &mut transport, + RX_QUEUE_IDX, + negotiated_features.contains(Feature::RING_INDIRECT_DESC), + negotiated_features.contains(Feature::RING_EVENT_IDX), + )?; + let tx = VirtQueue::new( + &mut transport, + TX_QUEUE_IDX, + negotiated_features.contains(Feature::RING_INDIRECT_DESC), + negotiated_features.contains(Feature::RING_EVENT_IDX), + )?; + let event = VirtQueue::new( + &mut transport, + EVENT_QUEUE_IDX, + negotiated_features.contains(Feature::RING_INDIRECT_DESC), + negotiated_features.contains(Feature::RING_EVENT_IDX), + )?; + + let rx = OwningQueue::new(rx)?; + + transport.finish_init(); + if rx.should_notify() { + transport.notify(RX_QUEUE_IDX); + } + + Ok(Self { + transport, + rx, + tx, + event, + guest_cid, + }) + } + + /// Returns the CID which has been assigned to this guest. + pub fn guest_cid(&self) -> u64 { + self.guest_cid + } + + /// Sends a request to connect to the given destination. + /// + /// This returns as soon as the request is sent; you should wait until `poll` returns a + /// `VsockEventType::Connected` event indicating that the peer has accepted the connection + /// before sending data. 
+    pub fn connect(&mut self, connection_info: &ConnectionInfo) -> Result {
+        let header = VirtioVsockHdr {
+            op: VirtioVsockOp::Request.into(),
+            ..connection_info.new_header(self.guest_cid)
+        };
+        // Sends a header only packet to the TX queue to connect the device to the listening socket
+        // at the given destination.
+        self.send_packet_to_tx_queue(&header, &[])
+    }
+
+    /// Accepts the given connection from a peer.
+    ///
+    /// This sends a header-only `VIRTIO_VSOCK_OP_RESPONSE` packet for the connection.
+    pub fn accept(&mut self, connection_info: &ConnectionInfo) -> Result {
+        let header = VirtioVsockHdr {
+            op: VirtioVsockOp::Response.into(),
+            ..connection_info.new_header(self.guest_cid)
+        };
+        self.send_packet_to_tx_queue(&header, &[])
+    }
+
+    /// Requests the peer to send us a credit update for the given connection.
+    fn request_credit(&mut self, connection_info: &ConnectionInfo) -> Result {
+        let header = VirtioVsockHdr {
+            op: VirtioVsockOp::CreditRequest.into(),
+            ..connection_info.new_header(self.guest_cid)
+        };
+        self.send_packet_to_tx_queue(&header, &[])
+    }
+
+    /// Sends the buffer to the destination.
+    ///
+    /// # Errors
+    ///
+    /// Returns `SocketError::InsufficientBufferSpaceInPeer` if the peer has not advertised enough
+    /// free buffer space for the whole of `buffer`; nothing is sent in that case and the caller
+    /// should try again later. Errors from queueing the packet are passed through.
+    pub fn send(&mut self, buffer: &[u8], connection_info: &mut ConnectionInfo) -> Result {
+        self.check_peer_buffer_is_sufficient(connection_info, buffer.len())?;
+
+        // This cast cannot truncate: the check above only succeeds when `buffer.len()` is no
+        // larger than the peer's free space, which is itself a `u32`.
+        let len = buffer.len() as u32;
+        let header = VirtioVsockHdr {
+            op: VirtioVsockOp::Rw.into(),
+            len: len.into(),
+            ..connection_info.new_header(self.guest_cid)
+        };
+        // `tx_cnt` is a free-running 32-bit counter which is expected to wrap once 4 GiB have
+        // been sent on the connection. A plain `+=` would panic at that point in builds with
+        // overflow checks enabled.
+        connection_info.tx_cnt = connection_info.tx_cnt.wrapping_add(len);
+        self.send_packet_to_tx_queue(&header, buffer)
+    }
+
+    /// Checks that the peer has enough free buffer space to receive `buffer_len` bytes.
+    ///
+    /// If it does not, a credit request is sent to the peer unless one is already pending, and
+    /// `SocketError::InsufficientBufferSpaceInPeer` is returned.
+    fn check_peer_buffer_is_sufficient(
+        &mut self,
+        connection_info: &mut ConnectionInfo,
+        buffer_len: usize,
+    ) -> Result {
+        if connection_info.peer_free() as usize >= buffer_len {
+            Ok(())
+        } else {
+            // Request an update of the cached peer credit, if we haven't already done so, and tell
+            // the caller to try again later.
+ if !connection_info.has_pending_credit_request { + self.request_credit(connection_info)?; + connection_info.has_pending_credit_request = true; + } + Err(SocketError::InsufficientBufferSpaceInPeer.into()) + } + } + + /// Tells the peer how much buffer space we have to receive data. + pub fn credit_update(&mut self, connection_info: &ConnectionInfo) -> Result { + let header = VirtioVsockHdr { + op: VirtioVsockOp::CreditUpdate.into(), + ..connection_info.new_header(self.guest_cid) + }; + self.send_packet_to_tx_queue(&header, &[]) + } + + /// Polls the RX virtqueue for the next event, and calls the given handler function to handle + /// it. + pub fn poll( + &mut self, + handler: impl FnOnce(VsockEvent, &[u8]) -> Result>, + ) -> Result> { + self.rx.poll(&mut self.transport, |buffer| { + let (header, body) = read_header_and_body(buffer)?; + VsockEvent::from_header(&header).and_then(|event| handler(event, body)) + }) + } + + /// Requests to shut down the connection cleanly, sending hints about whether we will send or + /// receive more data. + /// + /// This returns as soon as the request is sent; you should wait until `poll` returns a + /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the + /// shutdown. + pub fn shutdown_with_hints( + &mut self, + connection_info: &ConnectionInfo, + hints: StreamShutdown, + ) -> Result { + let header = VirtioVsockHdr { + op: VirtioVsockOp::Shutdown.into(), + flags: hints.into(), + ..connection_info.new_header(self.guest_cid) + }; + self.send_packet_to_tx_queue(&header, &[]) + } + + /// Requests to shut down the connection cleanly, telling the peer that we won't send or receive + /// any more data. + /// + /// This returns as soon as the request is sent; you should wait until `poll` returns a + /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the + /// shutdown. 
+ pub fn shutdown(&mut self, connection_info: &ConnectionInfo) -> Result { + self.shutdown_with_hints( + connection_info, + StreamShutdown::SEND | StreamShutdown::RECEIVE, + ) + } + + /// Forcibly closes the connection without waiting for the peer. + pub fn force_close(&mut self, connection_info: &ConnectionInfo) -> Result { + let header = VirtioVsockHdr { + op: VirtioVsockOp::Rst.into(), + ..connection_info.new_header(self.guest_cid) + }; + self.send_packet_to_tx_queue(&header, &[])?; + Ok(()) + } + + fn send_packet_to_tx_queue(&mut self, header: &VirtioVsockHdr, buffer: &[u8]) -> Result { + let _len = if buffer.is_empty() { + self.tx + .add_notify_wait_pop(&[header.as_bytes()], &mut [], &mut self.transport)? + } else { + self.tx.add_notify_wait_pop( + &[header.as_bytes(), buffer], + &mut [], + &mut self.transport, + )? + }; + Ok(()) + } +} + +fn read_header_and_body(buffer: &[u8]) -> Result<(VirtioVsockHdr, &[u8])> { + // This could fail if the device returns a buffer used length shorter than the header size. + let header = VirtioVsockHdr::read_from_prefix(buffer) + .map_err(|_| SocketError::BufferTooShort)? + .0; + let body_length = header.len() as usize; + + // This could fail if the device returns an unreasonably long body length. + let data_end = size_of::() + .checked_add(body_length) + .ok_or(SocketError::InvalidNumber)?; + // This could fail if the device returns a body length longer than buffer used length it + // returned. 
+ let data = buffer + .get(size_of::()..data_end) + .ok_or(SocketError::BufferTooShort)?; + Ok((header, data)) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + hal::fake::FakeHal, + transport::{ + DeviceType, + fake::{FakeTransport, QueueStatus, State}, + }, + volatile::ReadOnly, + }; + use alloc::{sync::Arc, vec}; + use core::ptr::NonNull; + use std::sync::Mutex; + + #[test] + fn config() { + let mut config_space = VirtioVsockConfig { + guest_cid_low: ReadOnly::new(66), + guest_cid_high: ReadOnly::new(0), + }; + let state = Arc::new(Mutex::new(State { + queues: vec![ + QueueStatus::default(), + QueueStatus::default(), + QueueStatus::default(), + ], + ..Default::default() + })); + let transport = FakeTransport { + device_type: DeviceType::Socket, + max_queue_size: 32, + device_features: 0, + config_space: NonNull::from(&mut config_space), + state: state.clone(), + }; + let socket = + VirtIOSocket::>::new(transport).unwrap(); + assert_eq!(socket.guest_cid(), 0x00_0000_0042); + } +} diff --git a/virtio-drivers/src/embedded_io.rs b/virtio-drivers/src/embedded_io.rs new file mode 100644 index 0000000000..2d8f99a8e5 --- /dev/null +++ b/virtio-drivers/src/embedded_io.rs @@ -0,0 +1,37 @@ +// SPDX-License-Identifier: MIT + +//! Implementation of the `embedded_io::Error` trait for `Error`.
+ +use crate::{Error, device::socket::SocketError}; +use embedded_io::ErrorKind; + +impl embedded_io::Error for Error { + fn kind(&self) -> ErrorKind { + match self { + Error::InvalidParam => ErrorKind::InvalidInput, + Error::DmaError => ErrorKind::OutOfMemory, + Error::Unsupported => ErrorKind::Unsupported, + Error::SocketDeviceError(e) => match e { + &SocketError::ConnectionExists => ErrorKind::AddrInUse, + SocketError::NotConnected => ErrorKind::NotConnected, + SocketError::PeerSocketShutdown => ErrorKind::ConnectionAborted, + SocketError::BufferTooShort => ErrorKind::InvalidInput, + SocketError::OutputBufferTooShort(_) => ErrorKind::InvalidInput, + SocketError::BufferTooLong(_, _) => ErrorKind::InvalidInput, + SocketError::InsufficientBufferSpaceInPeer => ErrorKind::WriteZero, + SocketError::UnknownOperation(_) + | SocketError::InvalidOperation + | SocketError::InvalidNumber + | SocketError::UnexpectedDataInPacket + | SocketError::RecycledWrongBuffer => ErrorKind::Other, + }, + Error::QueueFull + | Error::NotReady + | Error::WrongToken + | Error::AlreadyUsed + | Error::IoError + | Error::ConfigSpaceTooSmall + | Error::ConfigSpaceMissing => ErrorKind::Other, + } + } +} diff --git a/virtio-drivers/src/lib.rs b/virtio-drivers/src/lib.rs index 1382cfeb43..ac9ca8ebe0 100644 --- a/virtio-drivers/src/lib.rs +++ b/virtio-drivers/src/lib.rs @@ -21,6 +21,8 @@ extern crate alloc; pub mod device; +#[cfg(feature = "embedded-io")] +mod embedded_io; mod hal; mod queue; pub mod transport; @@ -62,6 +64,8 @@ pub enum Error { ConfigSpaceTooSmall, /// The device doesn't have any config space, but the driver expects some. ConfigSpaceMissing, + /// Error from the socket device. 
+ SocketDeviceError(device::socket::SocketError), } #[cfg(feature = "alloc")] @@ -95,10 +99,17 @@ impl Display for Error { "The device doesn't have any config space, but the driver expects some" ) } + Self::SocketDeviceError(e) => write!(f, "Error from the socket device: {e:?}"), } } } +impl From for Error { + fn from(e: device::socket::SocketError) -> Self { + Self::SocketDeviceError(e) + } +} + /// Align `size` up to a page. fn align_up(size: usize) -> usize { (size + PAGE_SIZE) & !(PAGE_SIZE - 1)