diff --git a/.github/workflows/integration_tests.yml b/.github/workflows/integration_tests.yml index 956710d0a..68db6c83b 100644 --- a/.github/workflows/integration_tests.yml +++ b/.github/workflows/integration_tests.yml @@ -20,23 +20,23 @@ jobs: - name: Clippy (test_cases guest) run: | cd tests - cargo clippy --locked -p test_cases --features guest -- -D warnings - + IPERF_DURATION=15 cargo clippy --locked -p test_cases --features guest -- -D warnings + - name: Clippy (test_cases host) run: | cd tests - PKG_CONFIG_PATH="$(realpath ../test-prefix/lib64/pkgconfig/)" LD_LIBRARY_PATH="$(realpath ../test-prefix/lib64/)" cargo clippy --locked -p test_cases --features host -- -D warnings - + IPERF_DURATION=15 PKG_CONFIG_PATH="$(realpath ../test-prefix/lib64/pkgconfig/)" LD_LIBRARY_PATH="$(realpath ../test-prefix/lib64/)" cargo clippy --locked -p test_cases --features host -- -D warnings + - name: Clippy (runner) run: | cd tests PKG_CONFIG_PATH="$(realpath ../test-prefix/lib64/pkgconfig/)" LD_LIBRARY_PATH="$(realpath ../test-prefix/lib64/)" cargo clippy --locked -p runner -- -D warnings - + - name: Clippy (guest-agent) run: | cd tests cargo clippy --locked --target x86_64-unknown-linux-musl -p guest-agent -- -D warnings - + - name: Enable KVM group perms run: | echo 'KERNEL=="kvm", GROUP="kvm", MODE="0666", OPTIONS+="static_node=kvm"' | sudo tee /etc/udev/rules.d/99-kvm4all.rules @@ -45,13 +45,13 @@ jobs: sudo usermod -a -G kvm $USER - name: Install additional packages - run: sudo apt-get install -y --no-install-recommends build-essential patchelf pkg-config net-tools - + run: sudo apt-get install -y --no-install-recommends build-essential patchelf pkg-config net-tools iperf3 + - name: Install libkrunfw run: TAG=`curl -sL https://api.github.com/repos/containers/libkrunfw/releases/latest |jq -r .tag_name` && curl -L -o /tmp/libkrunfw-x86_64.tgz https://github.com/containers/libkrunfw/releases/download/$TAG/libkrunfw-x86_64.tgz && mkdir tmp && tar xf 
/tmp/libkrunfw-x86_64.tgz -C tmp && sudo mv tmp/lib64/* /lib/x86_64-linux-gnu - + - name: Integration tests - run: KRUN_ENOMEM_WORKAROUND=1 KRUN_NO_UNSHARE=1 KRUN_TEST_BASE_DIR=/tmp/libkrun-tests make test TEST_FLAGS="--keep-all --github-summary" + run: KRUN_ENOMEM_WORKAROUND=1 KRUN_NO_UNSHARE=1 KRUN_TEST_BASE_DIR=/tmp/libkrun-tests make test NET=1 IPERF_DURATION=30 TEST_FLAGS="--keep-all --github-summary" - name: Upload test logs if: always() @@ -81,34 +81,34 @@ jobs: - name: Clippy (test_cases guest) run: | cd tests - cargo clippy --locked -p test_cases --features guest -- -D warnings - + IPERF_DURATION=15 cargo clippy --locked -p test_cases --features guest -- -D warnings + - name: Clippy (test_cases host) run: | cd tests - PKG_CONFIG_PATH="$(realpath ../test-prefix/lib64/pkgconfig/)" LD_LIBRARY_PATH="$(realpath ../test-prefix/lib64/)" cargo clippy --locked -p test_cases --features host -- -D warnings - + IPERF_DURATION=15 PKG_CONFIG_PATH="$(realpath ../test-prefix/lib64/pkgconfig/)" LD_LIBRARY_PATH="$(realpath ../test-prefix/lib64/)" cargo clippy --locked -p test_cases --features host -- -D warnings + - name: Clippy (runner) run: | cd tests PKG_CONFIG_PATH="$(realpath ../test-prefix/lib64/pkgconfig/)" LD_LIBRARY_PATH="$(realpath ../test-prefix/lib64/)" cargo clippy --locked -p runner -- -D warnings - + - name: Clippy (guest-agent) run: | cd tests cargo clippy --locked --target aarch64-unknown-linux-musl -p guest-agent -- -D warnings - + - name: Install additional packages - run: sudo apt-get install -y --no-install-recommends build-essential patchelf pkg-config net-tools - + run: sudo apt-get install -y --no-install-recommends build-essential patchelf pkg-config net-tools iperf3 + - name: Install libkrunfw run: TAG=`curl -sL https://api.github.com/repos/containers/libkrunfw/releases/latest |jq -r .tag_name` && curl -L -o /tmp/libkrunfw-aarch64.tgz https://github.com/containers/libkrunfw/releases/download/$TAG/libkrunfw-aarch64.tgz && mkdir tmp && tar xf 
/tmp/libkrunfw-aarch64.tgz -C tmp && sudo mv tmp/lib64/* /lib/aarch64-linux-gnu - name: Clean up tests directory run: rm -fr /tmp/libkrun-tests - + - name: Integration tests - run: KRUN_ENOMEM_WORKAROUND=1 KRUN_NO_UNSHARE=1 KRUN_TEST_BASE_DIR=/tmp/libkrun-tests make test TEST_FLAGS="--keep-all --github-summary" + run: KRUN_ENOMEM_WORKAROUND=1 KRUN_NO_UNSHARE=1 KRUN_TEST_BASE_DIR=/tmp/libkrun-tests make test NET=1 IPERF_DURATION=30 TEST_FLAGS="--keep-all --github-summary" - name: Upload test logs if: always() diff --git a/Makefile b/Makefile index 732207edc..75509ed70 100644 --- a/Makefile +++ b/Makefile @@ -232,14 +232,21 @@ clean: clean-all: clean clean-sysroot -test-prefix/lib64/libkrun.pc: $(LIBRARY_RELEASE_$(OS)) +test-prefix/$(LIBDIR_$(OS))/libkrun.pc: $(LIBRARY_RELEASE_$(OS)) mkdir -p test-prefix PREFIX="$$(realpath test-prefix)" make install -test-prefix: test-prefix/lib64/libkrun.pc +test-prefix: test-prefix/$(LIBDIR_$(OS))/libkrun.pc TEST ?= all TEST_FLAGS ?= +# Library path variable differs by OS +LIBPATH_VAR_Linux = LD_LIBRARY_PATH +LIBPATH_VAR_Darwin = DYLD_LIBRARY_PATH +# Extra library paths needed for tests (libkrunfw, llvm) +EXTRA_LIBPATH_Linux = +EXTRA_LIBPATH_Darwin = /opt/homebrew/opt/libkrunfw/lib:/opt/homebrew/opt/llvm/lib + test: test-prefix - cd tests; RUST_LOG=trace LD_LIBRARY_PATH="$$(realpath ../test-prefix/lib64/)" PKG_CONFIG_PATH="$$(realpath ../test-prefix/lib64/pkgconfig/)" ./run.sh test --test-case "$(TEST)" $(TEST_FLAGS) + cd tests; RUST_LOG=trace $(LIBPATH_VAR_$(OS))="$$(realpath ../test-prefix/$(LIBDIR_$(OS))/):$(EXTRA_LIBPATH_$(OS)):$${$(LIBPATH_VAR_$(OS))}" PKG_CONFIG_PATH="$$(realpath ../test-prefix/$(LIBDIR_$(OS))/pkgconfig/)" ./run.sh test --test-case "$(TEST)" $(TEST_FLAGS) diff --git a/examples/chroot_vm.c b/examples/chroot_vm.c index 4c25dabd3..ed6878738 100644 --- a/examples/chroot_vm.c +++ b/examples/chroot_vm.c @@ -27,6 +27,7 @@ enum net_mode { NET_MODE_PASST = 0, NET_MODE_TSI, + NET_MODE_TAP, }; static void 
print_help(char *const name) @@ -38,8 +39,9 @@ static void print_help(char *const name) " --log=PATH Write libkrun log to file or named pipe at PATH\n" " --color-log=PATH Write libkrun log to file or named pipe at PATH, use color\n" " --net=NET_MODE Set network mode\n" - " --passt-socket=PATH Instead of starting passt, connect to passt socket at PATH" - "NET_MODE can be either TSI (default) or PASST\n" + " --passt-socket=PATH Instead of starting passt, connect to passt socket at PATH\n" + " --tap=NAME Use TAP device NAME for networking\n" + "NET_MODE can be TSI (default), PASST, or TAP\n" "\n" "NEWROOT: the root directory of the vm\n" "COMMAND: the command you want to execute in the vm\n" @@ -54,6 +56,7 @@ static const struct option long_options[] = { { "color-log", required_argument, NULL, 'C' }, { "net_mode", required_argument, NULL, 'N' }, { "passt-socket", required_argument, NULL, 'P' }, + { "tap", required_argument, NULL, 'T' }, { NULL, 0, NULL, 0 } }; @@ -63,6 +66,7 @@ struct cmdline { uint32_t log_style; enum net_mode net_mode; char const *passt_socket_path; + char const *tap_name; char const *new_root; char *const *guest_argv; }; @@ -89,6 +93,7 @@ bool parse_cmdline(int argc, char *const argv[], struct cmdline *cmdline) .show_help = false, .net_mode = NET_MODE_TSI, .passt_socket_path = NULL, + .tap_name = NULL, .new_root = NULL, .guest_argv = NULL, .log_target = KRUN_LOG_TARGET_DEFAULT, @@ -116,6 +121,8 @@ bool parse_cmdline(int argc, char *const argv[], struct cmdline *cmdline) cmdline->net_mode = NET_MODE_TSI; } else if(strcasecmp("PASST", optarg) == 0) { cmdline->net_mode = NET_MODE_PASST; + } else if(strcasecmp("TAP", optarg) == 0) { + cmdline->net_mode = NET_MODE_TAP; } else { fprintf(stderr, "Unknown mode %s\n", optarg); return false; @@ -124,6 +131,10 @@ bool parse_cmdline(int argc, char *const argv[], struct cmdline *cmdline) case 'P': cmdline->passt_socket_path = optarg; break; + case 'T': + cmdline->tap_name = optarg; + cmdline->net_mode = 
NET_MODE_TAP; + break; case '?': return false; default: @@ -268,15 +279,18 @@ int main(int argc, char *const argv[]) return -1; } - // Map port 18000 in the host to 8000 in the guest (if networking uses TSI) - if (cmdline.net_mode == NET_MODE_TSI) { + // Configure networking based on mode + uint8_t mac[] = {0x5a, 0x94, 0xef, 0xe4, 0x0c, 0xee}; + switch (cmdline.net_mode) { + case NET_MODE_TSI: + // Map port 18000 in the host to 8000 in the guest if (err = krun_set_port_map(ctx_id, &port_map[0])) { errno = -err; perror("Error configuring port map"); return -1; } - } else { - uint8_t mac[] = {0x5a, 0x94, 0xef, 0xe4, 0x0c, 0xee}; + break; + case NET_MODE_PASST: if (cmdline.passt_socket_path != NULL) { if (err = krun_add_net_unixstream(ctx_id, cmdline.passt_socket_path, -1, &mac[0], COMPAT_NET_FEATURES, 0)) { errno = -err; @@ -285,17 +299,26 @@ int main(int argc, char *const argv[]) } } else { int passt_fd = start_passt(); - if (passt_fd < 0) { return -1; } - if (err = krun_add_net_unixstream(ctx_id, NULL, passt_fd, &mac[0], COMPAT_NET_FEATURES, 0)) { errno = -err; perror("Error configuring net mode"); return -1; } } + break; + case NET_MODE_TAP: + if (cmdline.tap_name == NULL) { + return -1; + } + if (err = krun_add_net_tap(ctx_id, (char *)cmdline.tap_name, &mac[0], COMPAT_NET_FEATURES, 0)) { + errno = -err; + perror("Error configuring TAP network"); + return -1; + } + break; } // Configure the rlimits that will be set in the guest diff --git a/examples/chroot_vm_test b/examples/chroot_vm_test new file mode 100755 index 000000000..c946c1b8a Binary files /dev/null and b/examples/chroot_vm_test differ diff --git a/examples/test_tap_net.sh b/examples/test_tap_net.sh new file mode 100755 index 000000000..a8fe41c48 --- /dev/null +++ b/examples/test_tap_net.sh @@ -0,0 +1,67 @@ +#!/bin/bash +# Test script for virtio-net with tap backend +# This script builds libkrun, compiles chroot_vm, and runs a network test + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && 
pwd)" +PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)" +PREFIX="$SCRIPT_DIR/libkrun-prefix" +CHROOT_ROOT="${CHROOT_ROOT:-/home/mhrica/c/my_rootfs2}" +TAP_DEVICE="${TAP_DEVICE:-tap0}" + +echo "=== libkrun tap network test ===" +echo "Project root: $PROJECT_ROOT" +echo "Prefix: $PREFIX" +echo "Chroot root: $CHROOT_ROOT" +echo "TAP device: $TAP_DEVICE" +echo "" + +# Build and install libkrun +cd "$PROJECT_ROOT" + +if [[ "$1" == "--rebuild" ]] || [[ ! -f "$PREFIX/lib64/libkrun.so" ]]; then + echo "=== Building libkrun ===" + make clean 2>/dev/null || true + make PREFIX="$PREFIX" + make install PREFIX="$PREFIX" +fi + +# Compile chroot_vm +echo "=== Compiling chroot_vm_test ===" +cd "$SCRIPT_DIR" +PKG_CONFIG_PATH="$PREFIX/lib64/pkgconfig" \ + gcc -g -o chroot_vm_test chroot_vm.c $(pkg-config --cflags --libs libkrun) + +# Create a named pipe for log output +LOG_PIPE="$SCRIPT_DIR/krun_log_pipe" +rm -f "$LOG_PIPE" +mkfifo "$LOG_PIPE" + +# Start log reader in background +echo "=== Starting log reader ===" +cat "$LOG_PIPE" & +LOG_PID=$! + +# Cleanup on exit +cleanup() { + kill $LOG_PID 2>/dev/null || true + rm -f "$LOG_PIPE" +} +trap cleanup EXIT + +# Run the test +echo "=== Running chroot_vm with tap backend ===" +echo "Command: LD_LIBRARY_PATH=$PREFIX/lib64 ./chroot_vm_test --color-log=$LOG_PIPE --net=TAP --tap=$TAP_DEVICE $CHROOT_ROOT /usr/bin/ping -c 3 8.8.8.8" +echo "" + +LD_LIBRARY_PATH="$PREFIX/lib64" \ + ./chroot_vm_test \ + --color-log="$LOG_PIPE" \ + --net=TAP \ + --tap="$TAP_DEVICE" \ + "$CHROOT_ROOT" \ + /guest_net_test.sh + +echo "" +echo "=== Test complete ===" diff --git a/include/libkrun.h b/include/libkrun.h index 325c59eed..0a7e1ec7b 100644 --- a/include/libkrun.h +++ b/include/libkrun.h @@ -326,6 +326,7 @@ int32_t krun_add_virtiofs2(uint32_t ctx_id, /* Send the VFKIT magic after establishing the connection, as required by gvproxy in vfkit mode. 
*/ #define NET_FLAG_VFKIT 1 << 0 +#define NET_FLAG_INCLUDE_VNET_HEADER 1 << 1 /* TSI (Transparent Socket Impersonation) feature flags for vsock */ #define KRUN_TSI_HIJACK_INET (1 << 0) diff --git a/src/devices/Cargo.toml b/src/devices/Cargo.toml index f78b67ae1..7136b929d 100644 --- a/src/devices/Cargo.toml +++ b/src/devices/Cargo.toml @@ -8,7 +8,8 @@ edition = "2021" tee = [] amd-sev = ["blk", "tee"] tdx = ["blk", "tee"] -net = [] +net = ["batch_queue"] +batch_queue = [] blk = [] efi = ["blk", "net"] gpu = ["rutabaga_gfx", "thiserror", "zerocopy", "krun_display"] @@ -24,7 +25,7 @@ crossbeam-channel = ">=0.5.15" libc = ">=0.2.39" libloading = "0.8" log = "0.4.0" -nix = { version = "0.30.1", features = ["ioctl", "net", "poll", "socket", "fs"] } +nix = { version = "0.30.1", features = ["ioctl", "net", "poll", "socket", "fs", "uio"] } pw = { package = "pipewire", version = "0.8.0", optional = true } rand = "0.9.2" thiserror = { version = "2.0", optional = true } diff --git a/src/devices/src/virtio/batch_queue/iovec_utils.rs b/src/devices/src/virtio/batch_queue/iovec_utils.rs new file mode 100644 index 000000000..040c73260 --- /dev/null +++ b/src/devices/src/virtio/batch_queue/iovec_utils.rs @@ -0,0 +1,122 @@ +// Copyright 2026 Red Hat, Inc. +// SPDX-License-Identifier: Apache-2.0 + +//! Utilities for working with iovec slices. + +use libc::iovec; +use std::io::IoSliceMut; + +/// Calculate total length of iovec slices. +/// Works with both IoSlice and IoSliceMut. +pub fn iovecs_len>(slices: &[T]) -> usize { + slices.iter().map(|s| s.len()).sum() +} + +/// Write data to iovecs, spanning multiple buffers if needed. 
+pub fn write_to_iovecs(slices: &mut [IoSliceMut], data: &[u8]) -> usize { + let mut written = 0; + for iov in slices.iter_mut() { + let remaining = data.len() - written; + if remaining == 0 { + break; + } + let take = remaining.min(iov.len()); + iov[..take].copy_from_slice(&data[written..written + take]); + written += take; + } + written +} + +/// Advance iovecs in place by `bytes`, removing fully consumed buffers (Vec version). +/// +/// Works with Vec, removing consumed iovecs from the front and +/// adjusting the first remaining iovec's pointer/length as needed. +pub fn advance_iovecs_vec(iovecs: &mut Vec>, bytes: usize) { + let mut remaining = bytes; + while remaining > 0 && !iovecs.is_empty() { + let first_len = iovecs[0].len(); + if first_len <= remaining { + iovecs.remove(0); + remaining -= first_len; + } else { + let ptr = iovecs[0].as_mut_ptr(); + let new_len = first_len - remaining; + // Safety: advancing pointer within same allocation + let new_slice = unsafe { std::slice::from_raw_parts_mut(ptr.add(remaining), new_len) }; + iovecs[0] = IoSliceMut::new(new_slice); + remaining = 0; + } + } +} + +/// Advance IoSlice Vec in place by `bytes`, removing fully consumed buffers. +/// +/// Works with Vec, removing consumed iovecs from the front and +/// adjusting the first remaining iovec's pointer/length as needed. +pub fn advance_tx_iovecs_vec(iovecs: &mut Vec>, bytes: usize) { + let mut remaining = bytes; + while remaining > 0 && !iovecs.is_empty() { + let first_len = iovecs[0].len(); + if first_len <= remaining { + iovecs.remove(0); + remaining -= first_len; + } else { + let ptr = iovecs[0].as_ptr(); + let new_len = first_len - remaining; + // Safety: advancing pointer within same allocation + let new_slice = unsafe { std::slice::from_raw_parts(ptr.add(remaining), new_len) }; + iovecs[0] = std::io::IoSlice::new(new_slice); + remaining = 0; + } + } +} + +/// Advance raw iovecs in place by `bytes`, removing fully consumed buffers. 
+/// +/// Works directly on `Vec` without going through `IoSliceMut`, avoiding +/// provenance issues when the iovecs originate from read-only memory (e.g., TX). +pub fn advance_raw_iovecs(iovecs: &mut Vec, bytes: usize) { + let mut remaining = bytes; + while remaining > 0 && !iovecs.is_empty() { + let first_len = iovecs[0].iov_len; + if first_len <= remaining { + iovecs.remove(0); + remaining -= first_len; + } else { + // Safety: advancing pointer within same allocation + iovecs[0].iov_base = unsafe { (iovecs[0].iov_base as *mut u8).add(remaining) as _ }; + iovecs[0].iov_len = first_len - remaining; + remaining = 0; + } + } +} + +/// Truncate iovecs in place to max_bytes total, returning the usable slice. +pub fn truncate_iovecs<'a, 'b>( + slices: &'a mut [IoSliceMut<'b>], + max_bytes: usize, +) -> &'a mut [IoSliceMut<'b>] { + let mut total: usize = 0; + for (i, slice) in slices.iter_mut().enumerate() { + let new_total = total.saturating_add(slice.len()); + + if new_total >= max_bytes { + // total <= max_bytes here (otherwise we'd have returned in a previous iteration), + // so this subtraction cannot underflow + let take = max_bytes - total; + // Last iovec is empty so we don't include it in the and + if take == 0 { + return &mut slices[..i]; + } + + let ptr = slice.as_mut_ptr(); + // SAFETY: `take <= len` because we only enter this branch when + // `total + len >= max_bytes`, which means `max_bytes - total <= len`. + // The pointer `ptr` is valid for `len` bytes, so it's valid for `take` bytes. + *slice = IoSliceMut::new(unsafe { std::slice::from_raw_parts_mut(ptr, take) }); + return &mut slices[..=i]; + } + total = new_total; + } + slices +} diff --git a/src/devices/src/virtio/batch_queue/mod.rs b/src/devices/src/virtio/batch_queue/mod.rs new file mode 100644 index 000000000..9d2834804 --- /dev/null +++ b/src/devices/src/virtio/batch_queue/mod.rs @@ -0,0 +1,134 @@ +// Copyright 2026 Red Hat, Inc. +// SPDX-License-Identifier: Apache-2.0 + +//! 
Batched virtio queue producer/consumer infrastructure. +//! +//! Provides generic queue handling suited for vectored I/O on virtio queues +//! (e.g. sending a whole descriptor chain in a single `writev`, supporting +//! partial writes, partial reads, etc.). +//! +//! The representation trait [`ChainsMemoryRepr`] allows backends to plug in +//! optimised layouts (e.g. `mmsghdr` for `sendmmsg`/`recvmmsg`). + +use std::io::IoSliceMut; + +use libc::iovec; + +use iovec_utils::{advance_raw_iovecs, truncate_iovecs}; + +pub mod iovec_utils; +mod rx_queue_producer; +mod tx_queue_consumer; + +pub use rx_queue_producer::{RxProducerBatch, RxQueueProducer}; +pub use tx_queue_consumer::{TxConsumerBatch, TxQueueConsumer}; + +/// Base trait for descriptor chain memory representation. +/// +/// # Safety +/// +/// - The iovecs stored in the representation point into guest memory owned by +/// the `TxQueueConsumer`/`RxQueueProducer`. The representation must not +/// outlive the consumer/producer — it must stay within the container. +/// - The consumer/producer guarantees that `clear()` is called before the +/// representation is dropped. `clear()` receives external `Meta` (e.g., +/// `Vec` capacity) needed to correctly free owned resources. Implementors +/// must release all owned memory in `clear()`. +pub unsafe trait ChainsMemoryRepr: Sized + Send { + /// User-defined metadata stored alongside each chain (e.g., Vec capacity). + type Meta: Default; + + /// Number of slices in this chain. + fn len(&self) -> usize; + + /// Check if empty. + fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Total bytes across all slices. + fn total_bytes(&self) -> usize; + + /// Release owned resources. Always called by the consumer/producer before + /// drop, with the external `Meta` needed for cleanup. + fn clear(&mut self, meta: &mut Self::Meta); +} + +/// Trait for representation types that support advancing (consuming bytes from front). 
+/// +/// # Safety +/// +/// Implementors must maintain iovec validity after advancing: the remaining +/// slices must still point to valid guest memory with correct lengths. +pub unsafe trait AdvanceBytes: ChainsMemoryRepr { + /// Advance slices by removing consumed bytes from the front. + fn advance(&mut self, bytes: usize); +} + +/// Trait for representation types that know how many bytes were received. +/// +/// Used by batch receive operations to report per-chain byte counts. +pub trait ReceivedLen: ChainsMemoryRepr { + /// Number of bytes received into this chain. + fn received_len(&self) -> usize; +} + +/// Trait for representation types that support truncating (limiting total bytes). +/// +/// # Safety +/// +/// Implementors must maintain iovec validity after truncating: the remaining +/// slices must still point to valid guest memory with correct lengths. +pub unsafe trait TruncateBytes: ChainsMemoryRepr { + /// Truncate slices to limit total bytes to `max_bytes`. + fn truncate_bytes(&mut self, max_bytes: usize); +} + +/// Wrapper around `Vec` that implements `Send`. +/// +/// # Safety +/// The raw pointers in `iovec` point to guest memory managed by the owning +/// `TxQueueConsumer`/`RxQueueProducer`. The memory is pinned and the struct +/// lifetime ensures the pointers remain valid. Transferring to another thread +/// is safe because we transfer ownership of the entire container. +#[derive(Debug, Default)] +#[repr(transparent)] +pub struct IovecVec(pub Vec); + +// Safety: See struct-level documentation +unsafe impl Send for IovecVec {} + +// ChainsMemoryRepr implemented for IovecVec - the default representation type. +// Raw iovec has no lifetime, avoiding the need for fake 'static lifetimes. 
+unsafe impl ChainsMemoryRepr for IovecVec { + type Meta = (); + + fn len(&self) -> usize { + self.0.len() + } + + fn total_bytes(&self) -> usize { + self.0.iter().map(|s| s.iov_len).sum() + } + + fn clear(&mut self, _meta: &mut ()) { + self.0.clear(); + } +} + +unsafe impl AdvanceBytes for IovecVec { + fn advance(&mut self, bytes: usize) { + advance_raw_iovecs(&mut self.0, bytes); + } +} + +unsafe impl TruncateBytes for IovecVec { + fn truncate_bytes(&mut self, max_bytes: usize) { + // Safety: IoSliceMut is #[repr(transparent)] over iovec. + let slices: &mut [IoSliceMut] = unsafe { + std::slice::from_raw_parts_mut(self.0.as_mut_ptr() as *mut IoSliceMut, self.0.len()) + }; + let keep = truncate_iovecs(slices, max_bytes).len(); + self.0.truncate(keep); + } +} diff --git a/src/devices/src/virtio/batch_queue/rx_queue_producer.rs b/src/devices/src/virtio/batch_queue/rx_queue_producer.rs new file mode 100644 index 000000000..1988c7952 --- /dev/null +++ b/src/devices/src/virtio/batch_queue/rx_queue_producer.rs @@ -0,0 +1,1053 @@ +// Copyright 2026 Red Hat, Inc. +// SPDX-License-Identifier: Apache-2.0 + +//! RX queue producer for batched virtio receive operations. + +use std::io::IoSliceMut; +use std::ops::Range; + +use libc::iovec; +use vm_memory::{GuestMemory, GuestMemoryMmap}; + +use super::super::queue::{DescriptorChain, Queue}; +use super::super::InterruptTransport; +use super::iovec_utils::write_to_iovecs; +use super::{AdvanceBytes, ChainsMemoryRepr, IovecVec, ReceivedLen, TruncateBytes}; + +/// Metadata for a pending descriptor chain. +#[derive(Debug)] +struct ChainMeta { + head_index: u16, + max_bytes: usize, + bytes_used: usize, + finished: bool, + /// User-defined metadata + user_meta: M, +} + +/// RxQueueProducer - owns the RX queue and provides buffers for receiving. +/// +/// Generic over representation type R, allowing different backends to use optimized +/// representations (e.g., mmsghdr for recvmmsg). Default is IovecVec. 
+/// +/// Pops descriptor chains from the virtio RX queue and provides writable +/// iovecs for receiving data. Unfinished chains are kept pending for the next +/// produce() call; finished chains get add_used() with their byte counts. +/// +/// The iovecs point into guest memory owned by `mem`. This is safe because +/// the struct owns the memory reference and outlives any use of the iovecs. +pub struct RxQueueProducer { + /// The virtio RX queue + queue: Queue, + /// Guest memory reference + mem: GuestMemoryMmap, + /// Interrupt for signaling guest + interrupt: InterruptTransport, + /// Maximum number of chains to keep pending at once. + max_chains: usize, + /// Per-chain representation (type depends on R) + chain_repr: Vec, + /// Metadata for each chain (parallel to chain_repr) + chain_meta: Vec>, +} + +impl RxQueueProducer { + /// Create a new RxQueueProducer with the given queue, memory, and interrupt. + pub fn new(queue: Queue, mem: GuestMemoryMmap, interrupt: InterruptTransport) -> Self { + let max_chains = queue.size as usize * 8; + Self { + queue, + mem, + interrupt, + max_chains, + chain_repr: Vec::new(), + chain_meta: Vec::new(), + } + } + + /// Set the maximum number of chains to keep pending at once. + pub fn set_max_chains(&mut self, max: usize) { + self.max_chains = max; + } + + /// Feed descriptor chains from the queue, converting each into the + /// representation type `R` via a user-supplied callback. + /// + /// The callback receives the chain's writable iovecs and returns an `(R, Meta)` + /// pair. It may mutate the iovecs before building `R` — for example, writing + /// a header and advancing past it so that subsequent I/O starts after the + /// header. Any bytes consumed by the callback are automatically tracked. + /// + /// Returns the number of chains added. 
+ /// + pub fn feed_with_transform(&mut self, mut transform: F) -> usize + where + F: for<'a> FnMut(Vec>) -> (R, R::Meta), + { + let mut added = 0; + + if let Err(e) = self.queue.disable_notification(&self.mem) { + warn!("Failed to disable queue notifications: {e:?}"); + } + 'next_chain: while self.pending_count() < self.max_chains { + let Some(head) = self.queue.pop(&self.mem) else { + // Queue exhausted: re-enable driver kicks. If more descriptors arrived in the + // meantime, loops back to pop them; otherwise break and expect the user to wake + // us up on the next kick. + match self.queue.enable_notification(&self.mem) { + Ok(true) => continue 'next_chain, + Ok(false) => break 'next_chain, + Err(e) => { + error!("Failed to re-enable queue notifications: {e:?}"); + break 'next_chain; + } + } + }; + + let head_index = head.index; + let mut iovecs: Vec> = Vec::new(); + + for desc in head.into_iter().filter(DescriptorChain::is_write_only) { + if let Some(iov) = unsafe { self.desc_to_ioslice_mut(&desc) } { + iovecs.push(iov); + } else { + log::error!("Invalid descriptor: {desc:?}, skipping the chain",); + continue 'next_chain; + } + } + + if iovecs.is_empty() { + log::warn!("Found empty chain, ignoring it"); + continue 'next_chain; + } + + // Compute original chain length before transformation + let max_bytes: usize = iovecs.iter().map(|iov| iov.len()).sum(); + + // Apply transformation (callback takes ownership, returns representation) + let (repr, user_meta) = transform(iovecs); + + // Track bytes already consumed by transform + let bytes_used = max_bytes - repr.total_bytes(); + + self.chain_repr.push(repr); + self.chain_meta.push(ChainMeta { + head_index, + max_bytes, + bytes_used, + finished: false, + user_meta, + }); + added += 1; + } + + added + } + + /// Number of chains pending (not yet sent) + pub fn pending_count(&self) -> usize { + self.chain_meta.len() + } + + /// Check if there are any pending chains + pub fn has_pending(&self) -> bool { + 
self.pending_count() > 0 + } + + /// Convert a descriptor to a mutable IoSlice pointing into guest memory. + /// + /// Returns None if the descriptor's memory region cannot be found or mapped. + /// + unsafe fn desc_to_ioslice_mut(&self, desc: &DescriptorChain) -> Option> { + let len = desc.len as usize; + let slice = self.mem.get_slice(desc.addr, len).ok()?; + let ptr = slice.ptr_guard_mut().as_ptr(); + + // Safety: We own the GuestMemoryMmap, so the memory is valid for our lifetime. + // The slice points into pinned guest memory that won't move. + let byte_slice = unsafe { std::slice::from_raw_parts_mut(ptr, len) }; + + // Transmute to 'static - safe because we own the memory reference + let static_slice: &mut [u8] = unsafe { std::mem::transmute(byte_slice) }; + + Some(IoSliceMut::new(static_slice)) + } + + /// Produce frames by calling the callback with a batch. + /// + /// The callback receives an `RxProducerBatch` which provides access to chains + /// and methods to mark them as complete. Returns the number of chains finished. 
+ pub fn produce(&mut self, f: F) -> usize + where + F: for<'a> FnOnce(&mut RxProducerBatch<'a, R>), + { + if self.chain_meta.is_empty() { + log::info!("produce: no chains pending, returning 0"); + return 0; + } + + log::info!( + "produce: {} chains pending, calling callback", + self.chain_meta.len() + ); + + let mut batch = RxProducerBatch { + chain_repr: &mut self.chain_repr, + chain_meta: &mut self.chain_meta, + queue: &mut self.queue, + mem: &self.mem, + first_unfinished: 0, + }; + + f(&mut batch); + let finished_count = self.compact(); + + if finished_count > 0 { + self.signal_used_if_needed(); + } + + log::trace!( + "produce: finished_count={} remaining={}", + finished_count, + self.chain_meta.len() + ); + + finished_count + } + + // Remove finished chains in O(n) by swapping unfinished to front, then truncating + // (for producer we don't care about the order of the descriptor chains) + fn compact(&mut self) -> usize { + let mut finished_count = 0; + let mut write = 0; + + for read in 0..self.chain_meta.len() { + if self.chain_meta[read].finished { + self.chain_repr[read].clear(&mut self.chain_meta[read].user_meta); + finished_count += 1; + } else { + if write != read { + self.chain_repr.swap(write, read); + self.chain_meta.swap(write, read); + } + write += 1; + } + } + + self.chain_repr.truncate(write); + self.chain_meta.truncate(write); + + finished_count + } + + /// Signal used queue interrupt if needed. + fn signal_used_if_needed(&mut self) { + match self.queue.needs_notification(&self.mem) { + Ok(true) => { + log::info!("RxQueueProducer: signaling used queue interrupt"); + self.interrupt.signal_used_queue(); + } + Ok(false) => { + log::info!("RxQueueProducer: needs_notification returned false, not signaling"); + } + Err(e) => { + log::error!("RxQueueProducer: needs_notification error: {e}"); + } + } + } +} + +/// Convenience methods for the default representation type (IovecVec). 
+impl RxQueueProducer { + /// Feed descriptor chains from queue without transformation. + /// + /// This is a convenience method for the common case where no header + /// transformation is needed. + pub fn feed(&mut self) -> usize { + self.feed_with_transform(|iovecs| { + let raw: Vec = unsafe { std::mem::transmute(iovecs) }; + (IovecVec(raw), ()) + }) + } +} + +/// Batch for producing RX chains. +/// +/// Provides access to pending chains and methods to mark them as complete. +/// Panics if you access or finish an already-finished chain. +pub struct RxProducerBatch<'a, R: ChainsMemoryRepr> { + chain_repr: &'a mut [R], + chain_meta: &'a mut [ChainMeta], + queue: &'a mut Queue, + mem: &'a GuestMemoryMmap, + /// Index of first unfinished chain. Chains 0..first_unfinished are finished. + /// For sequential finishing (0, 1, 2...), this advances efficiently. + first_unfinished: usize, +} + +impl RxProducerBatch<'_, R> { + /// Number of chains in the batch. + #[inline] + pub fn len(&self) -> usize { + self.chain_repr.len() + } + + /// Check if the batch is empty. + #[inline] + pub fn is_empty(&self) -> bool { + self.chain_repr.is_empty() + } + + /// Check if chain is already finished. + #[inline] + pub fn is_finished(&self, index: usize) -> bool { + self.chain_meta[index].finished + } + + /// Get bytes already produced for chain at index. + #[inline] + pub fn bytes_used(&self, index: usize) -> usize { + self.chain_meta[index].bytes_used + } + + /// Get maximum bytes the chain can hold. + #[inline] + pub fn max_bytes(&self, index: usize) -> usize { + self.chain_meta[index].max_bytes + } + + /// Get reference to the user-defined metadata for chain at index. + #[inline] + pub fn user_meta(&self, index: usize) -> &R::Meta { + &self.chain_meta[index].user_meta + } + + // Get mutable access to the chain at index. + /// + /// # Panics + /// + /// Panics if index is out of bounds or if the chain has already been finished. 
+ pub fn chain_mut(&mut self, index: usize) -> &mut R { + self.assert_not_finished(index); + &mut self.chain_repr[index] + } + + /// Get mutable access to chains in a range (checked). + /// + /// O(1) if chains are being finished sequentially, O(n) otherwise. + /// + /// # Panics + /// + /// Panics if any chain in the range has already been finished. + pub fn chains_mut(&mut self, range: Range) -> &mut [R] { + self.assert_range_not_finished(range.clone()); + &mut self.chain_repr[range] + } + + /// Finish a range of chains, reporting them to the guest. + /// + /// The received byte count should already have been set via + /// [`advance`](Self::advance). To set the byte count and finish in one + /// step, use [`complete`](Self::complete) or [`complete_many`](Self::complete_many). + /// + /// Chains can be finished out-of-order, but sequential finishing + /// (0, 1, 2...) is preferable. + /// + /// O(1) if chains are being finished sequentially, O(n) otherwise. + /// + /// # Panics + /// + /// Panics if any chain in the range has already been finished. 
+ pub fn finish_many(&mut self, range: Range) { + if range.is_empty() { + return; + } + + let range_start = range.start; + let range_end = range.end; + + for i in range { + self.assert_not_finished(i); + let meta = &mut self.chain_meta[i]; + meta.finished = true; + + log::trace!( + "finishing chain index={} head_index={} bytes_used={}", + i, + meta.head_index, + meta.bytes_used + ); + + if let Err(e) = self + .queue + .add_used(self.mem, meta.head_index, meta.bytes_used as u32) + { + log::error!("failed to add_used: {e}"); + } + } + + debug_assert!(range_start >= self.first_unfinished); + if range_start == self.first_unfinished { + // Jump to the end of the range we just verified and finished + self.first_unfinished = range_end; + + // Scan forward in case there were out-of-order finishes sitting ahead of us + while self.first_unfinished < self.chain_meta.len() + && self.chain_meta[self.first_unfinished].finished + { + self.first_unfinished += 1; + } + } + } + + /// Finish a chain, reporting it to the guest. + /// + /// The received byte count should already have been set via + /// [`advance`](Self::advance). To set the byte count and finish in one + /// step, use [`complete`](Self::complete). + /// + /// # Panics + /// + /// Panics if the chain at `index` has already been finished. + pub fn finish(&mut self, index: usize) { + self.finish_many(index..index + 1); + } + + #[track_caller] + #[inline] + fn assert_range_not_finished(&self, range: Range) { + // Fast path: if range starts at or after first_unfinished, all are unfinished + if range.start < self.first_unfinished { + // Slow path: range may include finished chains, check each + for i in range { + self.assert_not_finished(i); + } + } + } + + /// Set the received byte count and finish the chain, reporting it to the guest. + /// + /// This is the primary way to hand a received buffer back to the guest. 
+ /// If the byte count was already set via [`advance`](Self::advance), use + /// [`finish`](Self::finish) instead. + /// + /// See also [`complete_received`](Self::complete_received) when the chain + /// representation knows its own received length. + /// + /// # Panics + /// + /// Panics if the chain at `index` has already been finished. + pub fn complete(&mut self, index: usize, bytes: usize) { + let meta = &mut self.chain_meta[index]; + meta.bytes_used += bytes; + debug_assert!( + meta.bytes_used <= meta.max_bytes, + "complete: bytes_used {} exceeds max_bytes {}", + meta.bytes_used, + meta.max_bytes + ); + self.finish(index); + } + + #[track_caller] + #[inline] + fn assert_not_finished(&self, index: usize) { + assert!( + !self.is_finished(index), + "chain at index {index} already finished", + ); + } +} + +/// Methods for representation types that support advancing (for partial receives). +impl RxProducerBatch<'_, R> { + /// Advance bytes used for chain at index (partial receive). + /// + /// Updates bytes_used and advances the iovecs in place. + /// Chain remains pending for next produce() call. + /// + /// # Panics + /// + /// Panics if the chain at `index` has already been finished. + pub fn advance(&mut self, index: usize, bytes: usize) { + assert!( + !self.chain_meta[index].finished, + "advance: chain at index {} already finished", + index + ); + let meta = &mut self.chain_meta[index]; + meta.bytes_used += bytes; + debug_assert!( + meta.bytes_used <= meta.max_bytes, + "advance: bytes_used {} exceeds max_bytes {}", + meta.bytes_used, + meta.max_bytes + ); + self.chain_repr[index].advance(bytes); + } +} + +/// Methods for representation types that report their own received byte count. +impl RxProducerBatch<'_, R> { + /// Complete a chain, reading the received byte count from the chain's + /// [`ReceivedLen`] implementation and reporting it to the guest. + /// + /// # Panics + /// + /// Panics if the chain at `index` has already been finished. 
+ pub fn complete_received(&mut self, index: usize) { + self.complete_received_many(index..index + 1); + } + + /// Complete a range of chains, reading the received byte count from each + /// chain's [`ReceivedLen`] implementation and reporting them to the guest. + /// + /// + /// # Panics + /// + /// Panics if any chain in the range has already been finished. + pub fn complete_received_many(&mut self, range: Range) { + for i in range.clone() { + self.chain_meta[i].bytes_used += self.chain_repr[i].received_len(); + } + self.finish_many(range); + } +} + +/// Methods for representation types that support truncating (limiting receive size). +impl RxProducerBatch<'_, R> { + /// Truncate chain at index to limit receive to `max_bytes`. + /// + /// This is useful when you know the frame size ahead of time and want to + /// limit how much data can be received into the buffer. + /// + /// # Panics + /// + /// Panics if the chain at `index` has already been finished. + pub fn truncate(&mut self, index: usize, max_bytes: usize) { + assert!( + !self.chain_meta[index].finished, + "truncate: chain at index {} already finished", + index + ); + self.chain_repr[index].truncate_bytes(max_bytes); + } +} + +/// Specialized methods for the default IovecVec representation type. +impl RxProducerBatch<'_, IovecVec> { + /// Get a chain's iovecs as mutable IoSliceMut references. + /// + /// # Panics + /// + /// Panics if index is out of bounds or if the chain has already been finished. + pub fn io_slices_mut(&mut self, index: usize) -> &mut [IoSliceMut<'_>] { + assert!( + !self.chain_meta[index].finished, + "io_slices_mut: chain at index {} already finished", + index + ); + let slice = &mut self.chain_repr[index].0[..]; + // The lifetime is tied to &mut self, ensuring the iovecs remain valid. + unsafe { std::slice::from_raw_parts_mut(slice.as_mut_ptr().cast(), slice.len()) } + } + + /// Write data to chain and advance bytes_used (without finishing). 
+ /// + /// Useful for writing headers (e.g., vnet header for RX) before receiving + /// the actual payload. + /// + /// # Errors + /// + /// Returns `Err(())` if the chain doesn't have enough space for all the data. + /// + /// # Panics + /// + /// Panics if the chain at `index` has already been finished. + #[allow(clippy::result_unit_err)] + pub fn write_advance(&mut self, index: usize, data: &[u8]) -> Result<(), ()> { + let written = write_to_iovecs(self.io_slices_mut(index), data); + if written != data.len() { + return Err(()); + } + self.advance(index, written); + Ok(()) + } + + /// Write data to chain and complete it. + /// + /// Writes the data, advances bytes_used, and finishes the chain in one call. + /// + /// # Errors + /// + /// Returns `Err(())` if the chain doesn't have enough space for all the data. + /// + /// # Panics + /// + /// Panics if the chain at `index` has already been finished. + #[allow(clippy::result_unit_err)] + pub fn write_complete(&mut self, index: usize, data: &[u8]) -> Result<(), ()> { + let written = write_to_iovecs(self.io_slices_mut(index), data); + if written != data.len() { + return Err(()); + } + self.complete(index, written); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use std::io::IoSliceMut; + + use std::cell::Cell; + + use libc::iovec; + + use crate::virtio::batch_queue::iovec_utils::{advance_iovecs_vec, write_to_iovecs}; + use crate::virtio::batch_queue::{ChainsMemoryRepr, IovecVec, ReceivedLen}; + use crate::virtio::test_utils::{create_interrupt, ExpectedUsed, TestSetup}; + + use super::RxQueueProducer; + + /// Helper type alias for tests using default representation + type TestRxProducer = RxQueueProducer; + + /// Helper to convert IoSliceMut to IovecVec (for test callbacks) + fn to_iovec(iovecs: Vec>) -> IovecVec { + IovecVec(unsafe { std::mem::transmute(iovecs) }) + } + + #[test] + fn test_initial_state() { + let setup = TestSetup::new(); + let (queue, driver) = setup.create_queue(16); + let mut producer: 
TestRxProducer = + RxQueueProducer::new(queue, setup.mem().clone(), create_interrupt()); + + assert_eq!(producer.pending_count(), 0); + assert_eq!(producer.feed(), 0); + assert_eq!(producer.pending_count(), 0); + assert_eq!(producer.produce(|_batch| {}), 0); + driver.assert_used(&[]); + } + + #[test] + fn test_feed_single_writable_descriptor() { + let setup = TestSetup::new(); + let (queue, driver) = setup.create_queue(16); + driver.writable(&[1500]); + + let mut producer: TestRxProducer = + RxQueueProducer::new(queue, setup.mem().clone(), create_interrupt()); + + let added = producer.feed(); + + assert_eq!(added, 1); + assert_eq!(producer.pending_count(), 1); + } + + #[test] + fn test_feed_chained_writable_descriptors() { + let setup = TestSetup::new(); + let (queue, driver) = setup.create_queue(16); + // Chain of 2 writable descriptors + driver.writable(&[512, 1024]); + + let mut producer: TestRxProducer = + RxQueueProducer::new(queue, setup.mem().clone(), create_interrupt()); + + let added = producer.feed(); + + assert_eq!(added, 1); + assert_eq!(producer.pending_count(), 1); + + // Verify buffer structure via produce + producer.produce(|batch| { + assert_eq!(batch.len(), 1); + let chain = batch.io_slices_mut(0); + assert_eq!(chain.len(), 2); + assert_eq!(chain[0].len(), 512); + assert_eq!(chain[1].len(), 1024); + // Don't mark anything as finished + }); + + // We haven't finished anything + driver.assert_used(&[]); + } + + #[test] + fn test_feed_respects_max_frames() { + let setup = TestSetup::new(); + let (queue, driver) = setup.create_queue(16); + driver + .writable(&[1500]) + .writable(&[1500]) + .writable(&[1500]) + .writable(&[1500]) + .writable(&[1500]); + + let mut producer: TestRxProducer = + RxQueueProducer::new(queue, setup.mem().clone(), create_interrupt()); + producer.set_max_chains(2); + + let added = producer.feed(); + + assert_eq!(added, 2); + assert_eq!(producer.pending_count(), 2); + } + + #[test] + fn test_produce_via_write_bytes() { + let 
setup = TestSetup::new(); + let (queue, driver) = setup.create_queue(16); + driver.writable(&[10, 90]).writable(&[100]).writable(&[100]); + + let mut producer: TestRxProducer = + RxQueueProducer::new(queue, setup.mem().clone(), create_interrupt()); + + producer.feed(); + assert_eq!(producer.pending_count(), 3); + + let completed = producer.produce(|batch| { + assert_eq!(batch.max_bytes(0), 100); + batch.write_complete(0, b"Received packet 1").unwrap(); + assert_eq!(batch.bytes_used(0), 17); + assert!(batch.is_finished(0)); + + assert_eq!(batch.max_bytes(1), 100); + batch.write_complete(1, b"Received packet 2").unwrap(); + assert_eq!(batch.bytes_used(1), 17); + assert!(batch.is_finished(1)); + + // Third left unfinished + assert_eq!(batch.max_bytes(2), 100); + assert_eq!(batch.bytes_used(2), 0); + assert!(!batch.is_finished(2)); + }); + + assert_eq!(completed, 2); + assert_eq!(producer.pending_count(), 1); + + // Verify add_used was called with actual bytes written (17), not buffer capacity (1500) + // Also verifies the content written to guest memory + driver.assert_used(&[ + (0, ExpectedUsed::Writable(b"Received packet 1")), + (1, ExpectedUsed::Writable(b"Received packet 2")), + ]); + } + + #[test] + fn test_multiple_produce_cycles() { + // Each chain: 3 descriptors [6, 12, 6] = 24 bytes raw. + // Transform writes "HD" (2 bytes) header then advances past it. + // Usable iovecs after transform: [4, 12, 6] = 22 bytes. 
+ // + // Leftover state per cycle: + // cycle 1 → 2 leftover (1 partial, 1 untouched) + // cycle 2 → 3 leftover (complete the partial, 3 untouched) + // cycle 3 → 1 leftover (complete 2 of 3) + // cycle 4 → 0 leftover (drain everything) + let setup = TestSetup::new(); + let (queue, driver) = setup.create_queue(32); + + driver + .writable(&[6, 12, 6]) + .writable(&[6, 12, 6]) + .writable(&[6, 12, 6]); + + let mut producer: TestRxProducer = + RxQueueProducer::new(queue, setup.mem().clone(), create_interrupt()); + + let feed_with_hdr = |p: &mut TestRxProducer| { + p.feed_with_transform(|mut iovecs| { + write_to_iovecs(&mut iovecs, b"HD"); + advance_iovecs_vec(&mut iovecs, 2); + (to_iovec(iovecs), ()) + }) + }; + + // ── Cycle 1: feed 3, complete 1, partial 1, leave 1 untouched ─── + assert_eq!(feed_with_hdr(&mut producer), 3); + assert_eq!(producer.pending_count(), 3); + + let completed = producer.produce(|batch| { + // Chain 0: 18-byte write spanning all 3 iovecs, complete + batch.write_complete(0, b"aaaaaaaaaaaaaaaaaa").unwrap(); + + // Chain 1: partial write (4 bytes into first iovec) + let written = write_to_iovecs(batch.io_slices_mut(1), b"bbbb"); + assert_eq!(written, 4); + batch.advance(1, 4); + + // Chain 2: untouched + }); + assert_eq!(completed, 1); + assert_eq!(producer.pending_count(), 2); + + driver.assert_used(&[(0, ExpectedUsed::Writable(b"HDaaaaaaaaaaaaaaaaaa"))]); + + // ── Cycle 2: guest adds 2 buffers, complete the partial ───────── + driver + .writable(&[1, 1, 3, 3, 12, 6]) // 6 descriptors, HD consumes first two → [3, 3, 12, 6] usable + .writable(&[6, 12, 6]); + assert_eq!(feed_with_hdr(&mut producer), 2); + assert_eq!(producer.pending_count(), 4); + + let completed = producer.produce(|batch| { + // Batch[0] (chain 1): continue partial, write 8 more b's + let written = write_to_iovecs(batch.io_slices_mut(0), b"bbbbbbbb"); + assert_eq!(written, 8); + batch.complete(0, 8); + + // Batch[1..3]: untouched (simulating no more packets this cycle) + }); 
+ assert_eq!(completed, 1); + assert_eq!(producer.pending_count(), 3); + + driver.assert_used(&[ + (0, ExpectedUsed::Writable(b"HDaaaaaaaaaaaaaaaaaa")), + (1, ExpectedUsed::Writable(b"HDbbbbbbbbbbbb")), + ]); + + // ── Cycle 3: no new buffers, complete 2 of 3, leave 1 ────────── + let completed = producer.produce(|batch| { + assert_eq!(batch.len(), 3); + batch.write_complete(0, b"cccccccccccc").unwrap(); + batch.write_complete(1, b"dddddd").unwrap(); // spans [3, 3] boundary + // Batch[2]: untouched + }); + assert_eq!(completed, 2); + assert_eq!(producer.pending_count(), 1); + + driver.assert_used(&[ + (0, ExpectedUsed::Writable(b"HDaaaaaaaaaaaaaaaaaa")), + (1, ExpectedUsed::Writable(b"HDbbbbbbbbbbbb")), + (2, ExpectedUsed::Writable(b"HDcccccccccccc")), + (3, ExpectedUsed::Writable(b"HDdddddd")), + ]); + + // ── Cycle 4: guest adds 1 buffer, complete both remaining ─────── + driver.writable(&[6, 12, 6]); + assert_eq!(feed_with_hdr(&mut producer), 1); + assert_eq!(producer.pending_count(), 2); + + let completed = producer.produce(|batch| { + assert_eq!(batch.len(), 2); + batch.write_complete(0, b"eeee").unwrap(); + batch.write_complete(1, b"ffff").unwrap(); + }); + assert_eq!(completed, 2); + assert_eq!(producer.pending_count(), 0); + + // Letter = chain index: a=0, b=1, c=2, d=3, e=4, f=5 + driver.assert_used(&[ + (0, ExpectedUsed::Writable(b"HDaaaaaaaaaaaaaaaaaa")), + (1, ExpectedUsed::Writable(b"HDbbbbbbbbbbbb")), + (2, ExpectedUsed::Writable(b"HDcccccccccccc")), + (3, ExpectedUsed::Writable(b"HDdddddd")), + (4, ExpectedUsed::Writable(b"HDeeee")), + (5, ExpectedUsed::Writable(b"HDffff")), + ]); + } + + #[test] + fn test_out_of_order_completion() { + let setup = TestSetup::new(); + let (queue, driver) = setup.create_queue(16); + driver + .writable(&[2, 2]) + .writable(&[2, 2]) + .writable(&[2, 2]) + .writable(&[2, 2]); + + let mut producer: TestRxProducer = + RxQueueProducer::new(queue, setup.mem().clone(), create_interrupt()); + + producer.feed(); + 
assert_eq!(producer.pending_count(), 4); + + // Complete chains 3 and 1 (out of order), leave 0 and 2 pending + let completed = producer.produce(|batch| { + batch.write_complete(3, b"pkt3").unwrap(); + batch.write_complete(1, b"pkt1").unwrap(); + }); + + assert_eq!(completed, 2); + assert_eq!(producer.pending_count(), 2); + + // Used ring reflects completion order (3 then 1) + driver.assert_used(&[ + (3, ExpectedUsed::Writable(b"pkt3")), + (1, ExpectedUsed::Writable(b"pkt1")), + ]); + + // Complete remaining chains, also out of order + let completed = producer.produce(|batch| { + batch.write_complete(1, b"pkt2").unwrap(); + batch.write_complete(0, b"pkt0").unwrap(); + }); + + assert_eq!(completed, 2); + assert_eq!(producer.pending_count(), 0); + + // All 4 chains in used ring in the order they were completed + driver.assert_used(&[ + (3, ExpectedUsed::Writable(b"pkt3")), + (1, ExpectedUsed::Writable(b"pkt1")), + (2, ExpectedUsed::Writable(b"pkt2")), + (0, ExpectedUsed::Writable(b"pkt0")), + ]); + } + + /// Custom representation simulating recvmmsg-style batch receive. + /// Each chain stores iovecs + a filled received_len (like mmsghdr.msg_len). + struct CustomChainRepr { + iovecs: Vec, + received_len: Cell, + } + + impl CustomChainRepr { + /// Writes `data` across the iovec scatter list and sets received_len. + fn simulate_recv(&mut self, data: &[u8]) { + // Safety: IoSliceMut is #[repr(transparent)] over iovec. 
+ let slices: &mut [IoSliceMut] = unsafe { + std::slice::from_raw_parts_mut( + self.iovecs.as_mut_ptr() as *mut IoSliceMut, + self.iovecs.len(), + ) + }; + let written = write_to_iovecs(slices, data); + self.received_len.set(written); + } + } + + unsafe impl ChainsMemoryRepr for CustomChainRepr { + type Meta = u32; // tag to verify metadata works + + fn len(&self) -> usize { + self.iovecs.len() + } + + fn total_bytes(&self) -> usize { + self.iovecs.iter().map(|iov| iov.iov_len).sum() + } + + fn clear(&mut self, _meta: &mut u32) { + self.iovecs.clear(); + self.received_len.set(0); + } + } + + impl ReceivedLen for CustomChainRepr { + fn received_len(&self) -> usize { + self.received_len.get() + } + } + + unsafe impl Send for CustomChainRepr {} + + #[test] + fn test_complete_received_many() { + let setup = TestSetup::new(); + let (queue, driver) = setup.create_queue(16); + driver + .writable(&[100]) + .writable(&[100]) + .writable(&[100]) + .writable(&[100]); + + let mut producer: RxQueueProducer = + RxQueueProducer::new(queue, setup.mem().clone(), create_interrupt()); + + // Feed with meta tags 10, 20, 30, 40 + let mut tag = 0u32; + let added = producer.feed_with_transform(|iovecs| { + tag += 10; + let raw: Vec = unsafe { std::mem::transmute(iovecs) }; + let repr = CustomChainRepr { + iovecs: raw, + received_len: Cell::new(0), + }; + (repr, tag) + }); + assert_eq!(added, 4); + + // Simulate recvmmsg: kernel writes data + fills received_len on each repr. 
+ let completed = producer.produce(|batch| { + assert_eq!(batch.len(), 4); + + // Verify meta tags round-tripped + assert_eq!(*batch.user_meta(0), 10); + assert_eq!(*batch.user_meta(1), 20); + assert_eq!(*batch.user_meta(2), 30); + assert_eq!(*batch.user_meta(3), 40); + + // Simulate kernel writing data (like recvmmsg would) + batch.chain_mut(0).simulate_recv(b"aaaa"); + batch.chain_mut(1).simulate_recv(b"bbbbbbbb"); + // chain 2: no data yet, leave pending + batch.chain_mut(3).simulate_recv(b"dddddddddddd"); + + // Batch complete first two chains + batch.complete_received_many(0..2); + assert!(batch.is_finished(0)); + assert!(batch.is_finished(1)); + assert_eq!(batch.bytes_used(0), 4); + assert_eq!(batch.bytes_used(1), 8); + + // Single complete for chain 3 + batch.complete_received(3); + assert!(batch.is_finished(3)); + assert_eq!(batch.bytes_used(3), 12); + + // Chain 2 left pending + assert!(!batch.is_finished(2)); + }); + assert_eq!(completed, 3); + assert_eq!(producer.pending_count(), 1); + + driver.assert_used(&[ + (0, ExpectedUsed::Writable(b"aaaa")), + (1, ExpectedUsed::Writable(b"bbbbbbbb")), + (3, ExpectedUsed::Writable(b"dddddddddddd")), + ]); + + // ── Cycle 2: complete the remaining chain ───────────────────────── + let completed = producer.produce(|batch| { + assert_eq!(batch.len(), 1); + // Verify meta survived compaction (chain 2 had tag 30) + assert_eq!(*batch.user_meta(0), 30); + + batch.chain_mut(0).simulate_recv(b"cccccc"); + batch.complete_received(0); + assert_eq!(batch.bytes_used(0), 6); + }); + assert_eq!(completed, 1); + assert_eq!(producer.pending_count(), 0); + + driver.assert_used(&[ + (0, ExpectedUsed::Writable(b"aaaa")), + (1, ExpectedUsed::Writable(b"bbbbbbbb")), + (3, ExpectedUsed::Writable(b"dddddddddddd")), + (2, ExpectedUsed::Writable(b"cccccc")), + ]); + } + + #[test] + #[should_panic(expected = "already finished")] + fn test_double_finish_panics() { + let setup = TestSetup::new(); + let (queue, _driver) = 
setup.create_queue(16); + _driver.writable(&[100]); + + let mut producer: TestRxProducer = + RxQueueProducer::new(queue, setup.mem().clone(), create_interrupt()); + + producer.feed(); + producer.produce(|batch| { + batch.complete(0, 10); + batch.complete(0, 10); // panic: already finished + }); + } +} diff --git a/src/devices/src/virtio/batch_queue/tx_queue_consumer.rs b/src/devices/src/virtio/batch_queue/tx_queue_consumer.rs new file mode 100644 index 000000000..e295f265c --- /dev/null +++ b/src/devices/src/virtio/batch_queue/tx_queue_consumer.rs @@ -0,0 +1,945 @@ +// Copyright 2026 Red Hat, Inc. +// SPDX-License-Identifier: Apache-2.0 + +//! TX queue consumer for batched virtio transmit operations. + +use std::io::IoSlice; +use std::ops::Range; + +use libc::iovec; +use vm_memory::{GuestMemory, GuestMemoryMmap}; + +use super::super::queue::{DescriptorChain, Queue}; +use super::super::InterruptTransport; +use super::{AdvanceBytes, ChainsMemoryRepr, IovecVec}; + +/// Metadata for a pending descriptor chain. +#[derive(Debug, Clone)] +struct ChainMeta { + head_index: u16, + /// Total bytes in iovecs + max_bytes: usize, + /// Bytes from guest descriptors (for add_used reporting) + guest_len: usize, + /// Bytes sent so far (for partial send tracking) + bytes_used: usize, + finished: bool, + /// User-defined metadata + user_meta: M, +} + +/// TxQueueConsumer - owns the TX queue and manages chain batching. +/// +/// Generic over representation type R, allowing different backends to use optimized +/// representations (e.g., mmsghdr for sendmmsg). Default is IovecVec. +/// +/// The iovecs stored in chain representation point into guest memory owned by `mem`. +/// This is safe because the struct owns the memory reference and outlives any +/// use of the iovecs. 
+pub struct TxQueueConsumer { + /// The virtio TX queue (owned) + queue: Queue, + /// Guest memory reference + mem: GuestMemoryMmap, + /// Interrupt for signaling guest + interrupt: InterruptTransport, + /// Maximum number of chains to keep pending at once. + max_chains: usize, + /// Per-chain representation (type depends on R) + chain_repr: Vec, + /// Metadata for each chain (parallel to chain_repr) + chain_meta: Vec>, + + /// Number of chains fully sent + sent_chains: usize, +} + +impl TxQueueConsumer { + /// Create a new TxQueueConsumer with the given queue, memory, and interrupt. + pub fn new(queue: Queue, mem: GuestMemoryMmap, interrupt: InterruptTransport) -> Self { + let max_chains = queue.size as usize * 8; + Self { + queue, + mem, + interrupt, + max_chains, + chain_repr: Vec::new(), + chain_meta: Vec::new(), + sent_chains: 0, + } + } + + /// Set the maximum number of chains to keep pending at once. + pub fn set_max_chains(&mut self, max: usize) { + self.max_chains = max; + } + + /// Feed descriptor chains from the queue, converting each into the + /// representation type `R` via a user-supplied callback. + /// + /// The callback receives the chain's readable iovecs and returns an `(R, Meta)` + /// pair. It may mutate the iovecs before building `R` — for example, skipping + /// a header so that subsequent I/O starts after it. Any bytes consumed by + /// the callback are automatically tracked. + /// + /// Returns the number of chains added. + pub fn feed_with_transform(&mut self, mut transform: F) -> usize + where + F: for<'a> FnMut(Vec>) -> (R, R::Meta), + { + let mut added = 0; + + if let Err(e) = self.queue.disable_notification(&self.mem) { + warn!("Failed to disable queue notifications: {e:?}"); + } + 'next_chain: while self.pending_count() < self.max_chains { + let Some(head) = self.queue.pop(&self.mem) else { + // Queue exhausted: re-enable driver kicks. 
If more descriptors arrived in the + // meantime, loops back to pop them; otherwise break and expect the user to wake + // us up on the next kick. + match self.queue.enable_notification(&self.mem) { + Ok(true) => continue 'next_chain, + Ok(false) => break 'next_chain, + Err(e) => { + error!("Failed to re-enable queue notifications: {e:?}"); + break 'next_chain; + } + } + }; + + let head_index = head.index; + let mut iovecs: Vec> = Vec::new(); + + for desc in head.into_iter().filter(DescriptorChain::is_read_only) { + if let Some(iov) = unsafe { self.desc_to_ioslice(&desc) } { + iovecs.push(iov); + } else { + log::error!("Invalid descriptor: {desc:?}, skipping the chain",); + continue 'next_chain; + } + } + + if iovecs.is_empty() { + warn!("Found empty chain, ignoring it"); + continue 'next_chain; + } + + // Compute original chain length before transformation + let guest_len: usize = iovecs.iter().map(|s| s.len()).sum(); + + // Apply transformation (callback takes ownership, returns representation) + let (repr, user_meta) = transform(iovecs); + + // Compute final length + let max_bytes = repr.total_bytes(); + + // Track bytes already consumed by transform + let bytes_used = max_bytes - repr.total_bytes(); + + self.chain_repr.push(repr); + self.chain_meta.push(ChainMeta { + head_index, + max_bytes, + guest_len, + bytes_used, + finished: false, + user_meta, + }); + added += 1; + } + + added + } + + /// Number of chains pending + pub fn pending_count(&self) -> usize { + self.chain_meta.len() + } + + /// Check if there are any pending chains + pub fn has_pending(&self) -> bool { + self.pending_count() > 0 + } + + /// Consume pending chains using a callback that performs the actual I/O. + /// + /// The callback receives a `TxConsumerBatch` which provides: + /// - `chain(i)` - access to chain iovecs by index (panics if already finished) + /// - `finish(i)` / `finish_many(range)` - mark chains as finished + /// + /// Returns the number of chains finished. 
Finished chains are removed + /// from the pending list and interrupt is signaled if needed. + pub fn consume(&mut self, f: F) -> usize + where + F: for<'a> FnOnce(&mut TxConsumerBatch<'a, R>), + { + if !self.has_pending() { + return 0; + } + + let finished_count; + { + let pending_storage = &mut self.chain_repr[self.sent_chains..]; + let pending_meta = &mut self.chain_meta[self.sent_chains..]; + + let mut batch = TxConsumerBatch { + chain_repr: pending_storage, + chain_meta: pending_meta, + queue: &mut self.queue, + mem: &self.mem, + first_finished: 0, + }; + + f(&mut batch); + finished_count = batch.first_finished; + } + + // Update sent_chains based on what was finished + self.sent_chains += finished_count; + + if finished_count > 0 { + self.signal_used_if_needed(); + } + + log::trace!( + "consume: finished_count={} remaining={}", + finished_count, + self.chain_meta.len() + ); + + self.compact(); + finished_count + } + + /// Convert a descriptor to an IoSlice pointing into guest memory. + /// + unsafe fn desc_to_ioslice(&self, desc: &DescriptorChain) -> Option> { + let len = desc.len as usize; + let slice = self.mem.get_slice(desc.addr, len).ok()?; + let ptr = slice.ptr_guard_mut().as_ptr(); + + // Safety: We own the GuestMemoryMmap, so the memory is valid for our lifetime. + let byte_slice = unsafe { std::slice::from_raw_parts(ptr, len) }; + Some(IoSlice::new(byte_slice)) + } + + /// Clears the finished chains from the begining. + fn compact(&mut self) { + if self.sent_chains > 0 { + // Clear representation properly (calls R::clear with meta) + for i in 0..self.sent_chains { + self.chain_repr[i].clear(&mut self.chain_meta[i].user_meta); + } + self.chain_repr.drain(..self.sent_chains); + self.chain_meta.drain(..self.sent_chains); + self.sent_chains = 0; + } + } + + /// Signal used queue interrupt if needed. 
+ fn signal_used_if_needed(&mut self) { + match self.queue.needs_notification(&self.mem) { + Ok(true) => self.interrupt.signal_used_queue(), + Ok(false) => {} // No notification needed + Err(e) => { + log::error!("TxQueueConsumer: needs_notification error: {e}"); + } + } + } +} + +impl TxQueueConsumer { + /// Feed descriptor chains from queue without transformation. + /// + /// This is a convenience method for the common case where no header + /// transformation is needed. + pub fn feed(&mut self) -> usize { + self.feed_with_transform(|iovecs| { + let raw: Vec = unsafe { std::mem::transmute(iovecs) }; + (IovecVec(raw), ()) + }) + } +} + +/// Specialized methods for the default IovecVec representation type. +impl TxConsumerBatch<'_, IovecVec> { + /// Get a chain's iovecs as IoSlice references. + /// + /// # Panics + /// + /// Panics if index is out of bounds or if the chain has already been finished. + pub fn io_slices(&self, index: usize) -> &[IoSlice<'_>] { + assert!( + !self.chain_meta[index].finished, + "io_slices: chain at index {} already finished", + index + ); + let slice = &self.chain_repr[index].0[..]; + // iovec and IoSlice have the same memory layout + unsafe { std::slice::from_raw_parts(slice.as_ptr().cast(), slice.len()) } + } +} + +/// Batch for consuming TX chains. +/// +/// Provides access to pending chains and methods to mark them as finished. +/// +/// Panics if you access or finish an already-finished chain. +pub struct TxConsumerBatch<'a, R: ChainsMemoryRepr> { + chain_repr: &'a mut [R], + chain_meta: &'a mut [ChainMeta], + queue: &'a mut Queue, + mem: &'a GuestMemoryMmap, + /// Index of first unfinished chain. Chains 0..first_finished are finished. + /// For sequential finishing, this equals the number of finished chains. + first_finished: usize, +} + +impl TxConsumerBatch<'_, R> { + /// Number of pending chains in this batch. + #[inline] + pub fn len(&self) -> usize { + self.chain_repr.len() + } + + /// Check if the batch is empty. 
+ #[inline] + pub fn is_empty(&self) -> bool { + self.chain_repr.is_empty() + } + + /// Check if chain is already finished. + #[inline] + pub fn is_finished(&self, index: usize) -> bool { + self.chain_meta[index].finished + } + + /// Get bytes already consumed for chain at index. + #[inline] + pub fn bytes_used(&self, index: usize) -> usize { + self.chain_meta[index].bytes_used + } + + /// Get maximum bytes the chain can hold. + #[inline] + pub fn max_bytes(&self, index: usize) -> usize { + self.chain_meta[index].max_bytes + } + + /// Get access to a chain at index. + /// + /// # Panics + /// + /// Panics if index is out of bounds or if the chain has already been finished. + pub fn chain(&self, index: usize) -> &R { + self.assert_not_finished(index); + &self.chain_repr[index] + } + + /// Get access to chains in a range (checked). + /// + /// Returns a slice of chain representations for the given range. + /// + /// O(1) if chains are being finished sequentially, O(n) otherwise. + /// + /// # Panics + /// + /// Panics if any chain in the range has already been finished. + pub fn chains(&self, range: Range) -> &[R] { + // Fast path: if range starts at or after first_finished, all are unfinished + if range.start < self.first_finished { + // Slow path: range may include finished chains, check each + for i in range.clone() { + self.assert_not_finished(i); + } + } + &self.chain_repr[range] + } + + /// Get total bytes across all pending (non-finished) chains. + pub fn total_bytes(&self) -> usize { + self.chain_meta + .iter() + .filter(|m| !m.finished) + .map(|m| m.max_bytes) + .sum() + } + + /// Mark chain at index as finished. + /// + /// Calls add_used immediately. Chain will be removed after consume() returns. + /// + /// # Panics + /// + /// Panics if the chain at `index` has already been finished. 
+ pub fn finish(&mut self, index: usize) { + let meta = &mut self.chain_meta[index]; + assert!( + !meta.finished, + "finish: chain at index {} already finished", + index + ); + meta.finished = true; + log::trace!( + "finish: index={} head_index={} guest_len={}", + index, + meta.head_index, + meta.guest_len + ); + if let Err(e) = self + .queue + .add_used(self.mem, meta.head_index, meta.guest_len as u32) + { + log::error!("TxConsumerBatch: failed to add_used: {e}"); + } + + // Update first_finished for sequential finishing optimization + if index == self.first_finished { + while self.first_finished < self.chain_meta.len() + && self.chain_meta[self.first_finished].finished + { + self.first_finished += 1; + } + } + } + + /// Mark a range of chains as finished. + /// + /// # Panics + /// + /// Panics if any chain in the range has already been finished. + pub fn finish_many(&mut self, range: Range) { + for i in range { + self.finish(i); + } + } + + #[track_caller] + fn assert_not_finished(&self, index: usize) { + assert!( + !self.is_finished(index), + "chain at index {index} already finished", + ); + } +} + +/// Methods for representation types that support advancing (for partial sends). +impl TxConsumerBatch<'_, R> { + /// Advance bytes used for chain at index (partial send). + /// + /// Updates bytes_used and advances the iovecs in place. + /// Chain remains pending for next consume() call. + /// + /// # Panics + /// + /// Panics if the chain at `index` has already been finished. 
+ pub fn advance(&mut self, index: usize, bytes: usize) { + assert!( + !self.chain_meta[index].finished, + "advance: chain at index {} already finished", + index + ); + self.chain_meta[index].bytes_used += bytes; + self.chain_repr[index].advance(bytes); + } +} + +#[cfg(test)] +mod tests { + use std::io::IoSlice; + + use crate::virtio::batch_queue::IovecVec; + use crate::virtio::test_utils::{create_interrupt, ExpectedUsed, TestSetup}; + + use super::TxQueueConsumer; + + /// Helper type alias for tests using default representation + type TestTxConsumer = TxQueueConsumer; + + /// Helper to convert IoSlice to IovecVec (for test callbacks) + fn to_iovec(iovecs: Vec>) -> IovecVec { + IovecVec(unsafe { std::mem::transmute(iovecs) }) + } + + #[test] + fn test_new_consumer_is_empty() { + let setup = TestSetup::new(); + let (queue, _driver) = setup.create_queue(16); + let consumer: TestTxConsumer = + TxQueueConsumer::new(queue, setup.mem().clone(), create_interrupt()); + + assert_eq!(consumer.pending_count(), 0); + assert!(!consumer.has_pending()); + } + + #[test] + fn test_feed_single_descriptor() { + let setup = TestSetup::new(); + let (queue, driver) = setup.create_queue(16); + driver.readable(&[b"Hello, World!"]); + + let mut consumer: TestTxConsumer = + TxQueueConsumer::new(queue, setup.mem().clone(), create_interrupt()); + + let added = consumer.feed(); + + assert_eq!(added, 1); + assert_eq!(consumer.pending_count(), 1); + assert!(consumer.has_pending()); + + // Verify chain content via consume callback + let finished = consumer.consume(|batch| { + assert_eq!(batch.len(), 1); + assert_eq!(batch.io_slices(0).len(), 1); + assert_eq!(&*batch.io_slices(0)[0], b"Hello, World!"); + batch.finish(0); + }); + + assert_eq!(finished, 1); + driver.assert_used(&[(0, ExpectedUsed::Readable(13))]); + } + + #[test] + fn test_feed_chained_descriptors() { + let setup = TestSetup::new(); + let (queue, driver) = setup.create_queue(16); + // Chain of two descriptors + 
driver.readable(&[b"First", b"Second"]); + + let mut consumer: TestTxConsumer = + TxQueueConsumer::new(queue, setup.mem().clone(), create_interrupt()); + + let added = consumer.feed(); + + assert_eq!(added, 1); + assert_eq!(consumer.pending_count(), 1); + + let finished = consumer.consume(|batch| { + assert_eq!(batch.io_slices(0).len(), 2); + assert_eq!(&*batch.io_slices(0)[0], b"First"); + assert_eq!(&*batch.io_slices(0)[1], b"Second"); + batch.finish(0); + }); + + assert_eq!(finished, 1); + driver.assert_used(&[(0, ExpectedUsed::Readable(11))]); + } + + #[test] + fn test_feed_multiple_frames() { + let setup = TestSetup::new(); + let (queue, driver) = setup.create_queue(16); + driver + .readable(&[b"Frame1"]) + .readable(&[b"Frame2"]) + .readable(&[b"Frame3"]); + + let mut consumer: TestTxConsumer = + TxQueueConsumer::new(queue, setup.mem().clone(), create_interrupt()); + + let added = consumer.feed(); + + assert_eq!(added, 3); + assert_eq!(consumer.pending_count(), 3); + + let finished = consumer.consume(|batch| { + assert_eq!(batch.len(), 3); + batch.finish_many(0..3); + }); + + assert_eq!(finished, 3); + driver.assert_used(&[ + (0, ExpectedUsed::Readable(6)), + (1, ExpectedUsed::Readable(6)), + (2, ExpectedUsed::Readable(6)), + ]); + } + + #[test] + fn test_feed_respects_max_chains() { + let setup = TestSetup::new(); + let (queue, driver) = setup.create_queue(16); + driver + .readable(&[b"F0"]) + .readable(&[b"F1"]) + .readable(&[b"F2"]) + .readable(&[b"F3"]) + .readable(&[b"F4"]); + + let mut consumer: TestTxConsumer = + TxQueueConsumer::new(queue, setup.mem().clone(), create_interrupt()); + consumer.set_max_chains(2); + + let added = consumer.feed(); + assert_eq!(added, 2); + assert_eq!(consumer.pending_count(), 2); + + // Already at limit + let added2 = consumer.feed(); + assert_eq!(added2, 0); + assert_eq!(consumer.pending_count(), 2); + } + + #[test] + fn test_feed_transform_callback() { + let setup = TestSetup::new(); + let (queue, driver) = 
setup.create_queue(16); + driver.readable(&[b"TestData12345"]); + + let mut consumer: TestTxConsumer = + TxQueueConsumer::new(queue, setup.mem().clone(), create_interrupt()); + + let added = consumer.feed_with_transform(|mut iovecs| { + // Skip 4 bytes (like skipping vnet header) + if !iovecs.is_empty() && iovecs[0].len() >= 4 { + let first = &iovecs[0]; + let ptr = first.as_ptr(); + let new_len = first.len() - 4; + let new_slice = unsafe { std::slice::from_raw_parts(ptr.add(4), new_len) }; + iovecs[0] = IoSlice::new(new_slice); + } + (to_iovec(iovecs), ()) + }); + + assert_eq!(added, 1); + + consumer.consume(|batch| { + batch.finish(0); + }); + + // Original guest length is 13, not 9 + driver.assert_used(&[(0, ExpectedUsed::Readable(13))]); + } + + #[test] + fn test_consume_and_finish_all() { + let setup = TestSetup::new(); + let (queue, driver) = setup.create_queue(16); + driver + .readable(&[b"FirstChain"]) + .readable(&[b"SecondChain"]); + + let mut consumer: TestTxConsumer = + TxQueueConsumer::new(queue, setup.mem().clone(), create_interrupt()); + + consumer.feed(); + assert_eq!(consumer.pending_count(), 2); + + let finished = consumer.consume(|batch| { + assert_eq!(batch.total_bytes(), 21); + batch.finish_many(0..batch.len()); + }); + + assert_eq!(finished, 2); + assert_eq!(consumer.pending_count(), 0); + + driver.assert_used(&[ + (0, ExpectedUsed::Readable(10)), + (1, ExpectedUsed::Readable(11)), + ]); + } + + #[test] + fn test_consume_partial() { + let setup = TestSetup::new(); + let (queue, driver) = setup.create_queue(16); + driver + .readable(&[b"Chain00000"]) + .readable(&[b"Chain11111"]) + .readable(&[b"Chain22222"]); + + let mut consumer: TestTxConsumer = + TxQueueConsumer::new(queue, setup.mem().clone(), create_interrupt()); + + consumer.feed(); + + // Finish only first chain + let finished = consumer.consume(|batch| { + batch.finish(0); + }); + + assert_eq!(finished, 1); + assert_eq!(consumer.pending_count(), 2); + driver.assert_used(&[(0, 
ExpectedUsed::Readable(10))]); + } + + #[test] + fn test_compact() { + let setup = TestSetup::new(); + let (queue, driver) = setup.create_queue(16); + driver + .readable(&[b"test"]) + .readable(&[b"test"]) + .readable(&[b"test"]) + .readable(&[b"test"]) + .readable(&[b"test"]); + + let mut consumer: TestTxConsumer = + TxQueueConsumer::new(queue, setup.mem().clone(), create_interrupt()); + + consumer.feed(); + assert_eq!(consumer.pending_count(), 5); + + // Finish 3 chains (compact is called internally) + let finished = consumer.consume(|batch| { + batch.finish_many(0..3); + }); + assert_eq!(finished, 3); + assert_eq!(consumer.pending_count(), 2); + + driver.assert_used(&[ + (0, ExpectedUsed::Readable(4)), + (1, ExpectedUsed::Readable(4)), + (2, ExpectedUsed::Readable(4)), + ]); + } + + #[test] + fn test_empty_queue_returns_zero() { + let setup = TestSetup::new(); + let (queue, _driver) = setup.create_queue(16); + // Don't add any descriptors + + let mut consumer: TestTxConsumer = + TxQueueConsumer::new(queue, setup.mem().clone(), create_interrupt()); + + let added = consumer.feed(); + + assert_eq!(added, 0); + assert_eq!(consumer.pending_count(), 0); + // consume returns 0 when no pending chains + let finished = consumer.consume(|_batch| {}); + assert_eq!(finished, 0); + } + + #[test] + fn test_no_finish_preserves_pending() { + let setup = TestSetup::new(); + let (queue, driver) = setup.create_queue(16); + driver.readable(&[b"TestData"]); + + let mut consumer: TestTxConsumer = + TxQueueConsumer::new(queue, setup.mem().clone(), create_interrupt()); + + consumer.feed(); + + // Callback doesn't finish anything (simulating EAGAIN/WouldBlock) + let finished = consumer.consume(|_batch| {}); + assert_eq!(finished, 0); + assert_eq!(consumer.pending_count(), 1); + + // Nothing should be in used ring yet + assert_eq!(driver.used_count(), 0); + } + + #[test] + fn test_remove_header_byte_tracking() { + // Guest provides [header (12) | payload (100)]. 
+ // Transform skips header. byte_count = 100 (payload only). + // I/O returns 100 → chain finished. + let setup = TestSetup::new(); + let (queue, driver) = setup.create_queue(16); + + let mut data = vec![0x48u8; 12]; // header + data.extend(vec![0x50; 100]); // payload + driver.readable(&[&data]); + + let mut consumer: TestTxConsumer = + TxQueueConsumer::new(queue, setup.mem().clone(), create_interrupt()); + + let added = consumer.feed_with_transform(|mut iovecs| { + // Skip 12 bytes from first iovec + if !iovecs.is_empty() && iovecs[0].len() >= 12 { + let first = &iovecs[0]; + let ptr = first.as_ptr(); + let new_len = first.len() - 12; + let new_slice = unsafe { std::slice::from_raw_parts(ptr.add(12), new_len) }; + iovecs[0] = IoSlice::new(new_slice); + } + (to_iovec(iovecs), ()) + }); + assert_eq!(added, 1); + + let finished = consumer.consume(|batch| { + // Sum bytes in chain 0 (should be 100, not 112) + let total: usize = batch.io_slices(0).iter().map(|iov| iov.len()).sum(); + assert_eq!(total, 100); // payload only + batch.finish(0); + }); + + assert_eq!(finished, 1); + assert_eq!(consumer.pending_count(), 0); + + // add_used reports ORIGINAL guest length (112), not transformed (100) + driver.assert_used(&[(0, ExpectedUsed::Readable(112))]); + } + + #[test] + fn test_multi_cycle_partial_writes() { + // Tricky scenario: partial writes across multiple cycles. 
+ // Chain layout after transform: payload only (100 bytes after skipping 12-byte header) + let setup = TestSetup::new(); + let (queue, driver) = setup.create_queue(16); + + let mut data = vec![0x48u8; 12]; // virtio header (skipped) + data.extend(vec![0x50; 100]); // payload + driver.readable(&[&data]); + + let mut consumer: TestTxConsumer = + TxQueueConsumer::new(queue, setup.mem().clone(), create_interrupt()); + + let added = consumer.feed_with_transform(|mut iovecs| { + if !iovecs.is_empty() && iovecs[0].len() >= 12 { + let first = &iovecs[0]; + let ptr = first.as_ptr(); + let new_len = first.len() - 12; + let new_slice = unsafe { std::slice::from_raw_parts(ptr.add(12), new_len) }; + iovecs[0] = IoSlice::new(new_slice); + } + (to_iovec(iovecs), ()) + }); + assert_eq!(added, 1); + + // Cycle 1: 2 bytes sent (partial) + consumer.consume(|batch| batch.advance(0, 2)); + assert_eq!(consumer.pending_count(), 1); + + // Cycle 2: 50 more bytes (total 52) + consumer.consume(|batch| batch.advance(0, 50)); + assert_eq!(consumer.pending_count(), 1); + + // Cycle 3: remaining 48 bytes - now finished + consumer.consume(|batch| { + batch.advance(0, 48); + batch.finish(0); + }); + assert_eq!(consumer.pending_count(), 0); + + // add_used reports ORIGINAL guest length (112) + driver.assert_used(&[(0, ExpectedUsed::Readable(112))]); + } + + #[test] + fn test_stop_resume_across_compact() { + // Feed 2 chains, partial send, compact, feed more, continue. + // This tests that state is preserved when guest adds more descriptors mid-stream. 
+ let setup = TestSetup::new(); + let (queue, driver) = setup.create_queue(16); + + // First batch: 2 chains of 30 bytes each + let data = vec![0x50u8; 30]; + driver.readable(&[&data]).readable(&[&data]); + + let mut consumer: TestTxConsumer = + TxQueueConsumer::new(queue, setup.mem().clone(), create_interrupt()); + + consumer.feed(); + assert_eq!(consumer.pending_count(), 2); + + // Finish only first chain, advance partial on second + consumer.consume(|batch| { + batch.finish(0); + batch.advance(1, 15); + }); + assert_eq!(consumer.pending_count(), 1); + + // Only chain 0 in used ring so far + driver.assert_used(&[(0, ExpectedUsed::Readable(30))]); + + // Guest adds more descriptors (simulating queue refill) + driver.readable(&[&data]); // chain 2 + + consumer.feed(); + assert_eq!(consumer.pending_count(), 2); // chain 1 (partial) + chain 2 + + // Finish remaining chains + consumer.consume(|batch| { + batch.finish_many(0..2); + }); + assert_eq!(consumer.pending_count(), 0); + + // All 3 chains, including the one that crossed a compact boundary + driver.assert_used(&[ + (0, ExpectedUsed::Readable(30)), + (1, ExpectedUsed::Readable(30)), + (2, ExpectedUsed::Readable(30)), + ]); + } + + #[test] + fn test_out_of_order_finish() { + let setup = TestSetup::new(); + let (queue, driver) = setup.create_queue(16); + driver + .readable(&[b"pkt0"]) + .readable(&[b"pkt1"]) + .readable(&[b"pkt2"]) + .readable(&[b"pkt3"]); + + let mut consumer: TestTxConsumer = + TxQueueConsumer::new(queue, setup.mem().clone(), create_interrupt()); + + consumer.feed(); + assert_eq!(consumer.pending_count(), 4); + + // Finish chains 3 and 1 (out of order), leave 0 and 2 pending + let finished = consumer.consume(|batch| { + batch.finish(3); + batch.finish(1); + }); + + // first_finished never advanced past 0 (chain 0 not finished), + // so compact doesn't remove anything yet + assert_eq!(finished, 0); + assert_eq!(consumer.pending_count(), 4); + + // Used ring has both entries in finish-call order + 
driver.assert_used(&[ + (3, ExpectedUsed::ReadableAnyLen), + (1, ExpectedUsed::ReadableAnyLen), + ]); + + // Finish chain 0 — first_finished jumps 0→2 (skipping already-finished 1) + let finished = consumer.consume(|batch| { + batch.finish(0); + }); + + assert_eq!(finished, 2); // compact removes 0 and 1 + assert_eq!(consumer.pending_count(), 2); // chains 2 and 3 remain + + // Finish remaining: chain 2 (index 0) then chain 3 (index 1, already finished) + // Chain 3 was finished in the first cycle but not compacted until now + let finished = consumer.consume(|batch| { + batch.finish(0); // chain 2 + }); + + // first_finished: 0→1, then chain 1 (original 3) already finished → jumps to 2 + assert_eq!(finished, 2); + assert_eq!(consumer.pending_count(), 0); + + // All 4 in used ring in the order finish() was called + driver.assert_used(&[ + (3, ExpectedUsed::ReadableAnyLen), + (1, ExpectedUsed::ReadableAnyLen), + (0, ExpectedUsed::ReadableAnyLen), + (2, ExpectedUsed::ReadableAnyLen), + ]); + } + + #[test] + #[should_panic(expected = "already finished")] + fn test_double_finish_panics() { + let setup = TestSetup::new(); + let (queue, _driver) = setup.create_queue(16); + _driver.readable(&[b"data"]); + + let mut consumer: TestTxConsumer = + TxQueueConsumer::new(queue, setup.mem().clone(), create_interrupt()); + + consumer.feed(); + consumer.consume(|batch| { + batch.finish(0); + batch.finish(0); // panic: already finished + }); + } +} diff --git a/src/devices/src/virtio/mod.rs b/src/devices/src/virtio/mod.rs index 384aef5ac..0eeac2c19 100644 --- a/src/devices/src/virtio/mod.rs +++ b/src/devices/src/virtio/mod.rs @@ -12,6 +12,8 @@ use std::io::Error as IOError; #[cfg(not(feature = "tee"))] pub mod balloon; +#[cfg(feature = "batch_queue")] +pub mod batch_queue; #[allow(dead_code)] #[allow(non_camel_case_types)] pub mod bindings; @@ -36,6 +38,8 @@ mod queue; pub mod rng; #[cfg(feature = "snd")] pub mod snd; +#[cfg(all(feature = "batch_queue", test))] +pub(crate) mod 
test_utils; pub mod vsock; #[cfg(not(feature = "tee"))] diff --git a/src/devices/src/virtio/net/backend.rs b/src/devices/src/virtio/net/backend.rs index b73b910b9..25f7036b5 100644 --- a/src/devices/src/virtio/net/backend.rs +++ b/src/devices/src/virtio/net/backend.rs @@ -18,29 +18,39 @@ pub enum ConnectError { #[allow(dead_code)] #[derive(Debug)] pub enum ReadError { - /// Nothing was written - NothingRead, - /// Another internal error occurred + /// Backend process not running (EPIPE) + ProcessNotRunning, + /// Internal I/O error Internal(nix::Error), } #[allow(dead_code)] #[derive(Debug)] pub enum WriteError { - /// Nothing was written, you can drop the frame or try to resend it later - NothingWritten, - /// Part of the buffer was written, the write has to be finished using try_finish_write - PartialWrite, - /// Passt doesnt seem to be running (received EPIPE) + /// Backend process not running (EPIPE) ProcessNotRunning, - /// Another internal error occurred + /// Internal I/O error Internal(nix::Error), } +/// Network backend trait. +/// +/// Backends own both the socket and the queue consumers. The send/recv methods +/// operate on internal queues. EAGAIN is not an error - it just means nothing +/// happened this call. pub trait NetBackend { - fn read_frame(&mut self, buf: &mut [u8]) -> Result; - fn write_frame(&mut self, hdr_len: usize, buf: &mut [u8]) -> Result<(), WriteError>; - fn has_unfinished_write(&self) -> bool; - fn try_finish_write(&mut self, hdr_len: usize, buf: &[u8]) -> Result<(), WriteError>; + /// Send pending frames from the TX queue to the network. + /// + /// Pulls frames from internal TxQueueConsumer and sends using batched I/O. + /// EAGAIN returns Ok(()) - pending frames kept for retry. + fn send(&mut self) -> Result<(), WriteError>; + + /// Receive frames from the network into the RX queue. + /// + /// Reads from socket into internal RxQueueProvider. + /// EAGAIN returns Ok(()). 
+ fn recv(&mut self) -> Result<(), ReadError>; + + /// Returns the raw socket fd for epoll registration. fn raw_socket_fd(&self) -> RawFd; } diff --git a/src/devices/src/virtio/net/device.rs b/src/devices/src/virtio/net/device.rs index 9d4b4a1fc..4db4df134 100644 --- a/src/devices/src/virtio/net/device.rs +++ b/src/devices/src/virtio/net/device.rs @@ -20,25 +20,18 @@ use std::cmp; use std::io::Write; use std::os::fd::RawFd; use std::path::PathBuf; -use virtio_bindings::virtio_net::VIRTIO_NET_F_MAC; -use virtio_bindings::virtio_ring::VIRTIO_RING_F_EVENT_IDX; -use vm_memory::{ByteValued, GuestMemoryError, GuestMemoryMmap}; +use virtio_bindings::{virtio_net::VIRTIO_NET_F_MAC, virtio_ring::VIRTIO_RING_F_EVENT_IDX}; +use vm_memory::{ByteValued, GuestMemoryMmap}; const VIRTIO_F_VERSION_1: u32 = 32; -#[derive(Debug)] -pub enum FrontendError { - DescriptorChainTooSmall, - EmptyQueue, - GuestMemory(GuestMemoryError), - QueueError(QueueError), - ReadOnlyDescriptor, -} +// FrontendError removed - no longer used with vectored I/O #[derive(Debug)] pub enum RxError { Backend(ReadError), DeviceError(DeviceError), + QueueError(QueueError), } #[derive(Debug)] @@ -54,6 +47,7 @@ struct VirtioNetConfig { mac: [u8; 6], status: u16, max_virtqueue_pairs: u16, + include_vnet_header: bool, } // Safe because it only has data and has no implicit padding. 
@@ -88,6 +82,7 @@ impl Net { cfg_backend: VirtioNetBackend, mac: [u8; 6], features: u32, + include_vnet_header: bool, ) -> Result { let avail_features = features as u64 | (1 << VIRTIO_NET_F_MAC) @@ -98,6 +93,7 @@ impl Net { mac, status: 0, max_virtqueue_pairs: 0, + include_vnet_header, }; Ok(Net { @@ -187,6 +183,7 @@ impl VirtioDevice for Net { interrupt.clone(), mem.clone(), self.acked_features, + self.config.include_vnet_header, self.cfg_backend.clone(), ) { Ok(worker) => { diff --git a/src/devices/src/virtio/net/mod.rs b/src/devices/src/virtio/net/mod.rs index d3e7e9739..8b9010c2f 100644 --- a/src/devices/src/virtio/net/mod.rs +++ b/src/devices/src/virtio/net/mod.rs @@ -6,13 +6,20 @@ use virtio_bindings::virtio_net::virtio_net_hdr_v1; use super::QueueConfig; -pub const MAX_BUFFER_SIZE: usize = 65562; +/// Each frame forwarded to a unixstream backend is prepended by a 4 byte "header". +/// It is interpreted as a big-endian u32 integer and is the length of the following ethernet frame. +/// In order to avoid unnecessary allocations and copies, the TX buffer is allocated with extra +/// space to accommodate this header. 
+const FRAME_HEADER_LEN: usize = 4; +pub const MAX_BUFFER_SIZE: usize = 65562 + FRAME_HEADER_LEN; const QUEUE_SIZE: u16 = 1024; pub const NUM_QUEUES: usize = 2; pub static QUEUE_CONFIG: [QueueConfig; NUM_QUEUES] = [QueueConfig::new(QUEUE_SIZE); NUM_QUEUES]; mod backend; pub mod device; +#[cfg(target_os = "macos")] +mod socket_x; #[cfg(target_os = "linux")] mod tap; mod unixgram; @@ -23,13 +30,10 @@ fn vnet_hdr_len() -> usize { mem::size_of::() } -// This initializes to all 0 the virtio_net_hdr part of a buf and return the length of the header -// https://docs.oasis-open.org/virtio/virtio/v1.1/csprd01/virtio-v1.1-csprd01.html#x1-2050006 -fn write_virtio_net_hdr(buf: &mut [u8]) -> usize { - let len = vnet_hdr_len(); - buf[0..len].fill(0); - len -} +/// Default zeroed virtio_net_hdr_v1 (12 bytes) - used as prefix when receiving from backends +/// that don't include vnet headers (e.g., passt/unixstream) +/// https://docs.oasis-open.org/virtio/virtio/v1.1/csprd01/virtio-v1.1-csprd01.html#x1-2050006 +static DEFAULT_VNET_HDR: [u8; 12] = [0u8; 12]; pub use self::device::Net; #[derive(Debug)] diff --git a/src/devices/src/virtio/net/socket_x.rs b/src/devices/src/virtio/net/socket_x.rs new file mode 100644 index 000000000..3e6aca512 --- /dev/null +++ b/src/devices/src/virtio/net/socket_x.rs @@ -0,0 +1,81 @@ +// macOS-specific batch message syscalls (sendmsg_x/recvmsg_x) +// +// These are private Apple APIs that allow sending/receiving multiple messages +// in a single syscall, similar to Linux's sendmmsg/recvmmsg. +// +// Reference: https://github.com/nirs/vmnet-helper/blob/main/socket_x.h + +#![allow(dead_code)] +#![allow(non_camel_case_types)] + +#[cfg(target_os = "macos")] +pub mod macos { + use libc::{c_int, c_uint, c_void, iovec, socklen_t}; + + /// Extended message header for batch operations. + /// Similar to msghdr but includes msg_datalen for output. 
+ #[repr(C)] + pub struct msghdr_x { + pub msg_name: *mut c_void, + pub msg_namelen: socklen_t, + pub msg_iov: *mut iovec, + pub msg_iovlen: c_int, + pub msg_control: *mut c_void, + pub msg_controllen: socklen_t, + pub msg_flags: c_int, + pub msg_datalen: usize, // out: bytes transferred for this message + } + + impl Default for msghdr_x { + fn default() -> Self { + Self { + msg_name: std::ptr::null_mut(), + msg_namelen: 0, + msg_iov: std::ptr::null_mut(), + msg_iovlen: 0, + msg_control: std::ptr::null_mut(), + msg_controllen: 0, + msg_flags: 0, + msg_datalen: 0, + } + } + } + + extern "C" { + /// Send multiple datagrams in a single syscall. + /// + /// # Arguments + /// * `s` - Socket file descriptor + /// * `msgp` - Pointer to array of msghdr_x structures + /// * `cnt` - Number of messages to send + /// * `flags` - Only MSG_DONTWAIT is supported + /// + /// # Constraints + /// For each msghdr_x: msg_name, msg_namelen, msg_control, msg_controllen, + /// msg_flags, and msg_datalen must all be zero on input. + /// + /// # Returns + /// Number of datagrams sent, or -1 on error. + /// Each msghdr_x.msg_datalen is set to bytes sent for that message. + pub fn sendmsg_x(s: c_int, msgp: *const msghdr_x, cnt: c_uint, flags: c_int) -> isize; + + /// Receive multiple datagrams in a single syscall. + /// + /// # Arguments + /// * `s` - Socket file descriptor + /// * `msgp` - Pointer to array of msghdr_x structures + /// * `cnt` - Maximum number of messages to receive + /// * `flags` - Only MSG_DONTWAIT is supported + /// + /// # Constraints + /// For each msghdr_x: msg_flags must be zero on input. + /// + /// # Returns + /// Number of datagrams received (may be less than cnt), or -1 on error. + /// Each msghdr_x.msg_datalen is set to bytes received for that message. 
+ pub fn recvmsg_x(s: c_int, msgp: *mut msghdr_x, cnt: c_uint, flags: c_int) -> isize; + } +} + +#[cfg(target_os = "macos")] +pub use macos::*; diff --git a/src/devices/src/virtio/net/tap.rs b/src/devices/src/virtio/net/tap.rs index 1c8bde34e..a36367199 100644 --- a/src/devices/src/virtio/net/tap.rs +++ b/src/devices/src/virtio/net/tap.rs @@ -2,18 +2,23 @@ use libc::{ c_char, c_int, ifreq, IFF_NO_PI, IFF_TAP, IFF_VNET_HDR, TUN_F_CSUM, TUN_F_TSO4, TUN_F_TSO6, TUN_F_UFO, }; -use nix::fcntl::{fcntl, open, FcntlArg, OFlag}; +use nix::fcntl::{open, OFlag}; use nix::sys::stat::Mode; -use nix::unistd::{read, write}; +use nix::sys::uio::{readv, writev}; use nix::{ioctl_write_int, ioctl_write_ptr}; -use std::os::fd::{AsRawFd, OwnedFd, RawFd}; +use std::os::fd::{AsFd, AsRawFd, OwnedFd, RawFd}; use std::{io, mem, ptr}; +use utils::fd::SetNonblockingExt; use virtio_bindings::virtio_net::{ VIRTIO_NET_F_GUEST_CSUM, VIRTIO_NET_F_GUEST_TSO4, VIRTIO_NET_F_GUEST_TSO6, VIRTIO_NET_F_GUEST_UFO, }; +use vm_memory::GuestMemoryMmap; use super::backend::{ConnectError, NetBackend, ReadError, WriteError}; +use crate::virtio::batch_queue::{RxQueueProducer, TxQueueConsumer}; +use crate::virtio::queue::Queue; +use crate::virtio::InterruptTransport; ioctl_write_ptr!(tunsetiff, b'T', 202, c_int); ioctl_write_int!(tunsetoffload, b'T', 208); @@ -21,11 +26,20 @@ ioctl_write_ptr!(tunsetvnethdrsz, b'T', 216, c_int); pub struct Tap { fd: OwnedFd, + tx_consumer: TxQueueConsumer, + rx_producer: RxQueueProducer, } impl Tap { /// Create an endpoint using the file descriptor of a tap device - pub fn new(tap_name: String, vnet_features: u64) -> Result { + pub fn new( + tap_name: String, + vnet_features: u64, + tx_queue: Queue, + rx_queue: Queue, + mem: GuestMemoryMmap, + interrupt: InterruptTransport, + ) -> Result { let fd = match open("/dev/net/tun", OFlag::O_RDWR, Mode::empty()) { Ok(fd) => fd, Err(err) => return Err(ConnectError::OpenNetTun(err)), @@ -43,6 +57,8 @@ impl Tap { req.ifr_ifru.ifru_flags = 
IFF_TAP as i16 | IFF_NO_PI as i16 | IFF_VNET_HDR as i16; + log::info!("Tap::new() fd={} tap={}", fd.as_raw_fd(), tap_name); + let mut offload_flags: u64 = 0; if (vnet_features & (1 << VIRTIO_NET_F_GUEST_CSUM)) != 0 { offload_flags |= TUN_F_CSUM as u64; @@ -62,7 +78,7 @@ impl Tap { return Err(ConnectError::TunSetIff(io::Error::from(err))); } - // TODO(slp): replace hardcoded vnet size with cons + // TODO(slp): replace hardcoded vnet size with const if let Err(err) = tunsetvnethdrsz(fd.as_raw_fd(), &12) { return Err(ConnectError::TunSetVnetHdrSz(io::Error::from(err))); } @@ -72,53 +88,73 @@ impl Tap { } } - match fcntl(&fd, FcntlArg::F_GETFL) { - Ok(flags) => { - if let Err(e) = fcntl( - &fd, - FcntlArg::F_SETFL(OFlag::from_bits_truncate(flags) | OFlag::O_NONBLOCK), - ) { - warn!("error switching to non-blocking: id={fd:?}, err={e}"); - } - } - Err(e) => error!("couldn't obtain fd flags id={fd:?}, err={e}"), - }; + if let Err(e) = fd.set_nonblocking(true) { + log::warn!("Failed to set O_NONBLOCK on tap: {e}"); + } - Ok(Self { fd }) + let tx_consumer = TxQueueConsumer::new(tx_queue, mem.clone(), interrupt.clone()); + let rx_provider = RxQueueProducer::new(rx_queue, mem, interrupt); + + Ok(Self { + fd, + tx_consumer, + rx_producer: rx_provider, + }) } } impl NetBackend for Tap { - /// Try to read a frame from the tap devie. If no bytes are available reports - /// ReadError::NothingRead. - fn read_frame(&mut self, buf: &mut [u8]) -> Result { - let frame_length = match read(&self.fd, buf) { - Ok(f) => f, - #[allow(unreachable_patterns)] - Err(nix::Error::EAGAIN | nix::Error::EWOULDBLOCK) => { - return Err(ReadError::NothingRead) - } - Err(e) => { - return Err(ReadError::Internal(e)); + fn send(&mut self) -> Result<(), WriteError> { + let fd = self.fd.as_fd(); + + self.tx_consumer.feed(); + + // Each descriptor chain is one packet. TAP's writev combines iovecs into + // a single packet, so we can use it directly without flattening. + // One writev syscall per packet. 
+ self.tx_consumer.consume(|batch| { + for i in 0..batch.len() { + let chain = batch.io_slices(i); + if chain.is_empty() { + continue; + } + match writev(fd, chain) { + Ok(_) => batch.finish(i), + Err(nix::errno::Errno::EAGAIN) => break, + Err(e) => { + log::error!("writev to tap failed: {e:?}"); + break; + } + } } - }; - debug!("Read eth frame from tap: {frame_length} bytes"); - Ok(frame_length) - } + }); - /// Try to write a frame to the tap device. - fn write_frame(&mut self, _hdr_len: usize, buf: &mut [u8]) -> Result<(), WriteError> { - let ret = write(&self.fd, buf).map_err(WriteError::Internal)?; - debug!("Written frame size={}, written={}", buf.len(), ret); Ok(()) } - fn has_unfinished_write(&self) -> bool { - false - } + fn recv(&mut self) -> Result<(), ReadError> { + let fd = self.fd.as_fd(); + + self.rx_producer.feed(); + + self.rx_producer.produce(|batch| { + for i in 0..batch.len() { + let iovecs = batch.io_slices_mut(i); + if iovecs.is_empty() { + log::warn!("Chain {i} was empty"); + break; + } + + match readv(fd, iovecs) { + Ok(n) => batch.complete(i, n), + Err(nix::errno::Errno::EAGAIN) => break, + Err(e) => { + log::error!("readv from tap failed: {e:?}"); + } + } + } + }); - fn try_finish_write(&mut self, _hdr_len: usize, _buf: &[u8]) -> Result<(), WriteError> { - // The tap backend doesn't do partial writes. 
Ok(()) } diff --git a/src/devices/src/virtio/net/unixgram.rs b/src/devices/src/virtio/net/unixgram.rs index 04e230066..9621d9947 100644 --- a/src/devices/src/virtio/net/unixgram.rs +++ b/src/devices/src/virtio/net/unixgram.rs @@ -1,36 +1,178 @@ -use nix::fcntl::{fcntl, FcntlArg, OFlag}; +#[cfg(target_os = "macos")] +use libc::c_int; +use libc::iovec; +#[cfg(target_os = "linux")] +use libc::mmsghdr; use nix::sys::socket::{ - bind, connect, getsockopt, recv, send, setsockopt, socket, sockopt, AddressFamily, MsgFlags, + bind, connect, getsockopt, send, setsockopt, socket, sockopt, AddressFamily, MsgFlags, SockFlag, SockType, UnixAddr, }; -use nix::unistd::unlink; +use std::fs::remove_file; use std::os::fd::{AsRawFd, OwnedFd, RawFd}; use std::path::PathBuf; +use utils::fd::SetNonblockingExt; +use vm_memory::GuestMemoryMmap; use super::backend::{ConnectError, NetBackend, ReadError, WriteError}; -use super::write_virtio_net_hdr; +use crate::virtio::batch_queue::iovec_utils::{advance_tx_iovecs_vec, write_to_iovecs}; +use crate::virtio::batch_queue::{ChainsMemoryRepr, ReceivedLen, RxQueueProducer, TxQueueConsumer}; +use crate::virtio::queue::Queue; +use crate::virtio::InterruptTransport; + +#[cfg(target_os = "macos")] +use super::socket_x::msghdr_x; const VFKIT_MAGIC: [u8; 4] = *b"VFKT"; +// ============================================================================ +// MsgHdr - Chain representation that IS an mmsghdr/msghdr_x +// ============================================================================ + +#[cfg(target_os = "linux")] +type RawMsgHdr = mmsghdr; + +#[cfg(target_os = "macos")] +type RawMsgHdr = msghdr_x; + +/// Chain representation that wraps mmsghdr/msghdr_x. +/// +/// The iovec pointer is stored directly in the header, avoiding allocation +/// of a separate mmsghdr array for sendmmsg/sendmsg_x/recvmmsg/recvmsg_x. +/// +/// For RX, use `received_len()` to get the kernel-filled byte count. 
+/// +/// # Safety +/// Uses `mem::forget` to transfer iovec Vec ownership into the header. +/// The capacity is stored in `Meta` for proper cleanup via `Vec::from_raw_parts()`. +#[repr(transparent)] +pub struct MsgHdr(RawMsgHdr); + +// Safety: The raw pointer inside points to heap memory that we have exclusive ownership of. +// Transferring to another thread is safe because we transfer ownership of the entire struct. +unsafe impl Send for MsgHdr {} + +unsafe impl ChainsMemoryRepr for MsgHdr { + /// Stores the Vec capacity for cleanup + type Meta = usize; + + fn len(&self) -> usize { + #[cfg(target_os = "linux")] + { + self.0.msg_hdr.msg_iovlen + } + #[cfg(target_os = "macos")] + { + self.0.msg_iovlen as usize + } + } + + fn total_bytes(&self) -> usize { + let (ptr, len) = self.iov_ptr_len(); + if ptr.is_null() { + 0 + } else { + let slices = unsafe { std::slice::from_raw_parts(ptr as *const iovec, len) }; + slices.iter().map(|s| s.iov_len).sum() + } + } + + fn clear(&mut self, capacity: &mut Self::Meta) { + let (ptr, len) = self.iov_ptr_len(); + if !ptr.is_null() { + // Reconstruct Vec to drop it properly + unsafe { + let _: Vec = Vec::from_raw_parts(ptr, len, *capacity); + } + self.set_iov_null(); + *capacity = 0; + } + } +} + +impl MsgHdr { + /// Create MsgHdr from raw iovec pointer and length. 
+ #[inline] + fn from_raw(iov_ptr: *mut iovec, len: usize) -> Self { + #[cfg(target_os = "linux")] + { + let mut hdr: mmsghdr = unsafe { std::mem::zeroed() }; + hdr.msg_hdr.msg_iov = iov_ptr; + hdr.msg_hdr.msg_iovlen = len; + Self(hdr) + } + + #[cfg(target_os = "macos")] + { + Self(msghdr_x { + msg_iov: iov_ptr, + msg_iovlen: len as c_int, + ..Default::default() + }) + } + } + + #[inline] + fn iov_ptr_len(&self) -> (*mut iovec, usize) { + #[cfg(target_os = "linux")] + { + (self.0.msg_hdr.msg_iov, self.0.msg_hdr.msg_iovlen) + } + #[cfg(target_os = "macos")] + { + (self.0.msg_iov, self.0.msg_iovlen as usize) + } + } + + #[inline] + fn set_iov_null(&mut self) { + #[cfg(target_os = "linux")] + { + self.0.msg_hdr.msg_iov = std::ptr::null_mut(); + self.0.msg_hdr.msg_iovlen = 0; + } + #[cfg(target_os = "macos")] + { + self.0.msg_iov = std::ptr::null_mut(); + self.0.msg_iovlen = 0; + } + } +} + +impl ReceivedLen for MsgHdr { + #[cfg(target_os = "linux")] + #[inline] + fn received_len(&self) -> usize { + self.0.msg_len as usize + } + + #[cfg(target_os = "macos")] + #[inline] + fn received_len(&self) -> usize { + self.0.msg_datalen + } +} + pub struct Unixgram { fd: OwnedFd, + include_vnet_header: bool, + tx_consumer: TxQueueConsumer, + rx_producer: RxQueueProducer, } impl Unixgram { /// Create the backend with a pre-established connection to the userspace network proxy. - pub fn new(fd: OwnedFd) -> Self { + pub fn new( + fd: OwnedFd, + include_vnet_header: bool, + tx_queue: Queue, + rx_queue: Queue, + mem: GuestMemoryMmap, + interrupt: InterruptTransport, + ) -> Self { // Ensure the socket is in non-blocking mode. 
- match fcntl(&fd, FcntlArg::F_GETFL) { - Ok(flags) => match OFlag::from_bits(flags) { - Some(flags) => { - if let Err(e) = fcntl(&fd, FcntlArg::F_SETFL(flags | OFlag::O_NONBLOCK)) { - warn!("error switching to non-blocking: id={fd:?}, err={e}"); - } - } - None => error!("invalid fd flags id={fd:?}"), - }, - Err(e) => error!("couldn't obtain fd flags id={fd:?}, err={e}"), - }; + if let Err(e) = fd.set_nonblocking(true) { + log::error!("Failed to set O_NONBLOCK on unixgram socket: {e}"); + } #[cfg(target_os = "macos")] { @@ -47,11 +189,27 @@ impl Unixgram { }; } - Self { fd } + let tx_consumer = TxQueueConsumer::new(tx_queue, mem.clone(), interrupt.clone()); + let rx_producer = RxQueueProducer::new(rx_queue, mem, interrupt); + + Self { + fd, + include_vnet_header, + tx_consumer, + rx_producer, + } } /// Create the backend opening a connection to the userspace network proxy. - pub fn open(path: PathBuf, send_vfkit_magic: bool) -> Result { + pub fn open( + path: PathBuf, + send_vfkit_magic: bool, + include_vnet_header: bool, + tx_queue: Queue, + rx_queue: Queue, + mem: GuestMemoryMmap, + interrupt: InterruptTransport, + ) -> Result { // We cannot create a non-blocking socket on macOS here. This is done later in new(). 
let fd = socket( AddressFamily::Unix, @@ -64,7 +222,7 @@ impl Unixgram { let local_addr = UnixAddr::new(&PathBuf::from(format!("{}-krun.sock", path.display()))) .map_err(ConnectError::InvalidAddress)?; if let Some(path) = local_addr.path() { - _ = unlink(path); + _ = remove_file(path); } bind(fd.as_raw_fd(), &local_addr).map_err(ConnectError::Binding)?; @@ -81,7 +239,7 @@ impl Unixgram { log::warn!("Failed to increase SO_SNDBUF (performance may be decreased): {e}"); } if let Err(e) = setsockopt(&fd, sockopt::RcvBuf, &(7 * 1024 * 1024)) { - log::warn!("Failed to increase SO_SNDBUF (performance may be decreased): {e}"); + log::warn!("Failed to increase SO_RCVBUF (performance may be decreased): {e}"); } log::debug!( @@ -90,50 +248,285 @@ impl Unixgram { getsockopt(&fd, sockopt::RcvBuf) ); - Ok(Self::new(fd)) + Ok(Self::new( + fd, + include_vnet_header, + tx_queue, + rx_queue, + mem, + interrupt, + )) } } impl NetBackend for Unixgram { - /// Try to read a frame the proxy. If no bytes are available reports ReadError::NothingRead - fn read_frame(&mut self, buf: &mut [u8]) -> Result { - let hdr_len = write_virtio_net_hdr(buf); - let frame_length = match recv(self.fd.as_raw_fd(), &mut buf[hdr_len..], MsgFlags::empty()) { - Ok(f) => f, - #[allow(unreachable_patterns)] - Err(nix::Error::EAGAIN | nix::Error::EWOULDBLOCK) => { - return Err(ReadError::NothingRead) - } - Err(e) => { - return Err(ReadError::Internal(e)); + fn send(&mut self) -> Result<(), WriteError> { + let skip = if !self.include_vnet_header { + super::vnet_hdr_len() + } else { + 0 + }; + + // Feed frames from queue, skipping vnet header + let fed = self.tx_consumer.feed_with_transform(|mut iovecs| { + let orig_len = iovecs.len(); + let orig_bytes: usize = iovecs.iter().map(|s| s.len()).sum(); + if skip > 0 { + advance_tx_iovecs_vec(&mut iovecs, skip); } + let ptr = iovecs.as_mut_ptr() as *mut iovec; + let len = iovecs.len(); + let cap = iovecs.capacity(); + let total_bytes: usize = unsafe { + 
std::slice::from_raw_parts(ptr as *const iovec, len) + .iter() + .map(|iov| iov.iov_len) + .sum() + }; + log::info!( + "TX feed: orig_iovecs={} orig_bytes={} after_skip: iovecs={} bytes={} cap={}", + orig_len, + orig_bytes, + len, + total_bytes, + cap + ); + std::mem::forget(iovecs); + (MsgHdr::from_raw(ptr, len), cap) + }); + if fed > 0 { + log::info!( + "TX: fed {} chains, pending={}", + fed, + self.tx_consumer.pending_count() + ); + } + + if !self.tx_consumer.has_pending() { + return Ok(()); + } + + #[cfg(target_os = "linux")] + self.send_linux()?; + + #[cfg(target_os = "macos")] + self.send_macos()?; + + Ok(()) + } + + fn recv(&mut self) -> Result<(), ReadError> { + let vnet_offset = if !self.include_vnet_header { + super::vnet_hdr_len() + } else { + 0 }; - debug!("Read eth frame from proxy: {frame_length} bytes"); - Ok(hdr_len + frame_length) - } - - /// Try to write a frame to the proxy. - fn write_frame(&mut self, hdr_len: usize, buf: &mut [u8]) -> Result<(), WriteError> { - let ret = send(self.fd.as_raw_fd(), &buf[hdr_len..], MsgFlags::empty()) - .map_err(WriteError::Internal)?; - debug!( - "Written frame size={}, written={}", - buf.len() - hdr_len, - ret + log::info!( + "recv: include_vnet_header={} vnet_offset={}", + self.include_vnet_header, + vnet_offset ); + + // Feed chains from queue, writing vnet header and advancing iovecs during feed + let rx_fed = self.rx_producer.feed_with_transform(|mut iovecs| { + let orig_len = iovecs.len(); + let orig_bytes: usize = iovecs.iter().map(|s| s.len()).sum(); + if vnet_offset > 0 { + // Write default vnet header to beginning of buffer + write_to_iovecs(&mut iovecs, &super::DEFAULT_VNET_HDR); + // Advance iovecs past vnet header so receive goes after it + crate::virtio::batch_queue::iovec_utils::advance_iovecs_vec( + &mut iovecs, + vnet_offset, + ); + } + let ptr = iovecs.as_mut_ptr() as *mut iovec; + let len = iovecs.len(); + let cap = iovecs.capacity(); + log::info!( + "RX feed: orig_iovecs={} orig_bytes={} 
after_vnet: iovecs={} cap={}", + orig_len, + orig_bytes, + len, + cap + ); + std::mem::forget(iovecs); + (MsgHdr::from_raw(ptr, len), cap) + }); + if rx_fed > 0 { + log::info!( + "RX: fed {} chains, pending={}", + rx_fed, + self.rx_producer.pending_count() + ); + } + + #[cfg(target_os = "linux")] + self.recv_linux(); + + #[cfg(target_os = "macos")] + self.recv_macos(); + Ok(()) } - fn has_unfinished_write(&self) -> bool { - false + fn raw_socket_fd(&self) -> RawFd { + self.fd.as_raw_fd() } +} + +#[cfg(target_os = "linux")] +impl Unixgram { + fn send_linux(&mut self) -> Result<(), WriteError> { + let fd = self.fd.as_raw_fd(); + + self.tx_consumer.consume(|batch| { + let len = batch.len(); + let chains = batch.chains(0..len); + let ptr = chains.as_ptr() as *mut mmsghdr; + + let ret = unsafe { libc::sendmmsg(fd, ptr, len as libc::c_uint, libc::MSG_DONTWAIT) }; + + if ret < 0 { + let err = std::io::Error::last_os_error(); + match err.kind() { + std::io::ErrorKind::WouldBlock => {} + _ => { + log::error!("sendmmsg failed: {err}"); + } + } + return; + } + + batch.finish_many(0..ret as usize); + }); - fn try_finish_write(&mut self, _hdr_len: usize, _buf: &[u8]) -> Result<(), WriteError> { - // The unixgram backend doesn't do partial writes. 
Ok(()) } - fn raw_socket_fd(&self) -> RawFd { - self.fd.as_raw_fd() + fn recv_linux(&mut self) { + let fd = self.fd.as_raw_fd(); + + self.rx_producer.produce(|batch| { + let len = batch.len(); + let ret = { + let storage = batch.chains_mut(0..len); + let ptr = storage.as_mut_ptr() as *mut mmsghdr; + unsafe { + libc::recvmmsg( + fd, + ptr, + len as libc::c_uint, + libc::MSG_DONTWAIT, + std::ptr::null_mut(), + ) + } + }; + + match ret { + n if n > 0 => { + batch.complete_received_many(0..n as usize); + } + 0 => log::warn!("recvmmsg returned 0 (unexpected)"), + _ => { + let err = std::io::Error::last_os_error(); + if err.kind() != std::io::ErrorKind::WouldBlock { + log::error!("recvmmsg failed: {err}"); + } + } + } + }); + } +} + +#[cfg(target_os = "macos")] +impl Unixgram { + fn send_macos(&mut self) -> Result<(), WriteError> { + let fd = self.fd.as_raw_fd(); + + self.tx_consumer.consume(|batch| { + let len = batch.len(); + // Safety: No chains have been completed yet, so 0..len is valid. 
+ let storage = batch.chains(0..len); + let ptr = storage.as_ptr() as *const super::socket_x::msghdr_x; + + // Debug: log each msghdr_x before sending + for i in 0..len { + let hdr = unsafe { &*ptr.add(i) }; + let total: usize = if !hdr.msg_iov.is_null() && hdr.msg_iovlen > 0 { + unsafe { + std::slice::from_raw_parts(hdr.msg_iov, hdr.msg_iovlen as usize) + .iter() + .map(|iov| iov.iov_len) + .sum() + } + } else { + 0 + }; + log::info!( + "sendmsg_x[{}]: iovlen={} total_bytes={} msg_datalen={} msg_flags={} msg_name={:?} msg_control={:?}", + i, hdr.msg_iovlen, total, hdr.msg_datalen, hdr.msg_flags, hdr.msg_name, hdr.msg_control + ); + } + + let ret = unsafe { + super::socket_x::sendmsg_x( + fd, + ptr, + len as libc::c_uint, + libc::MSG_DONTWAIT, + ) + }; + + log::info!("sendmsg_x(fd={}, cnt={}) = {}", fd, len, ret); + + if ret < 0 { + let err = std::io::Error::last_os_error(); + log::info!("sendmsg_x error: {:?} (raw={})", err.kind(), err.raw_os_error().unwrap_or(-1)); + match err.kind() { + std::io::ErrorKind::WouldBlock => {} + _ => { + log::error!("sendmsg_x failed: {err:?}"); + } + } + return; + } + + batch.finish_many(0..ret as usize); + }); + + Ok(()) + } + + fn recv_macos(&mut self) { + let fd = self.fd.as_raw_fd(); + + self.rx_producer.produce(|batch| { + log::info!("recv_macos: {} chains available", batch.len()); + + let len = batch.len(); + let ret = { + let storage = batch.chains_mut(0..len); + let ptr = storage.as_mut_ptr() as *mut super::socket_x::msghdr_x; + unsafe { + super::socket_x::recvmsg_x(fd, ptr, len as libc::c_uint, libc::MSG_DONTWAIT) + } + }; + + log::info!("recvmsg_x(fd={}, cnt={}) = {}", fd, len, ret); + + match ret { + n if n > 0 => { + batch.complete_received_many(0..n as usize); + } + 0 => log::warn!("recvmsg_x returned 0 (unexpected)"), + _ => { + let err = std::io::Error::last_os_error(); + if err.kind() != std::io::ErrorKind::WouldBlock { + log::error!("recvmsg_x failed: {err}"); + } + } + } + }); } } diff --git 
a/src/devices/src/virtio/net/unixstream.rs b/src/devices/src/virtio/net/unixstream.rs index 023be6b28..0f43004ea 100644 --- a/src/devices/src/virtio/net/unixstream.rs +++ b/src/devices/src/virtio/net/unixstream.rs @@ -1,32 +1,87 @@ use nix::sys::socket::{ - connect, getsockopt, recv, send, setsockopt, socket, sockopt, AddressFamily, MsgFlags, - SockFlag, SockType, UnixAddr, -}; -use std::{ - os::fd::{AsRawFd, OwnedFd, RawFd}, - path::PathBuf, + connect, getsockopt, setsockopt, socket, sockopt, AddressFamily, SockFlag, SockType, UnixAddr, }; +use nix::sys::uio::readv; +use nix::unistd::read; +use std::io::IoSlice; +use std::os::fd::{AsFd, AsRawFd, BorrowedFd, OwnedFd, RawFd}; +use std::path::PathBuf; +use utils::fd::SetNonblockingExt; +use vm_memory::GuestMemoryMmap; +use crate::virtio::batch_queue::iovec_utils::{advance_tx_iovecs_vec, iovecs_len, truncate_iovecs}; +use crate::virtio::batch_queue::{IovecVec, RxQueueProducer, TxQueueConsumer}; use crate::virtio::net::backend::ConnectError; +use crate::virtio::queue::Queue; +use crate::virtio::InterruptTransport; use super::backend::{NetBackend, ReadError, WriteError}; -use super::write_virtio_net_hdr; +use super::FRAME_HEADER_LEN; + +/// Helper to convert IoSlice to IovecVec +fn to_iovec(iovecs: Vec>) -> IovecVec { + IovecVec(unsafe { std::mem::transmute::>, Vec>(iovecs) }) +} -/// Each frame the network proxy is prepended by a 4 byte "header". -/// It is interpreted as a big-endian u32 integer and is the length of the following ethernet frame. -const FRAME_HEADER_LEN: usize = 4; +/// Try to read/complete the frame length header. +/// Returns Some(frame_len) when complete, None if incomplete or EAGAIN. 
+fn try_read_frame_header( + fd: BorrowedFd, + header_buf: &mut [u8; FRAME_HEADER_LEN], + header_pos: &mut usize, + expecting: &mut Option, +) -> Option { + if let Some(len) = *expecting { + return Some(len as usize); + } + + let remaining = &mut header_buf[*header_pos..]; + match read(fd, remaining) { + Ok(n) if n > 0 => { + *header_pos += n; + if *header_pos == FRAME_HEADER_LEN { + let len = u32::from_be_bytes(*header_buf); + *expecting = Some(len); + *header_pos = 0; + Some(len as usize) + } else { + None + } + } + _ => None, + } +} pub struct Unixstream { fd: OwnedFd, - // 0 when a frame length has not been read - expecting_frame_length: u32, - // 0 if last write is fully complete, otherwise the length that was written - last_partial_write_length: usize, + backend_handles_vnet_hdr: bool, + tx_consumer: TxQueueConsumer, + rx_producer: RxQueueProducer, + /// For RX: partial frame length header buffer + rx_header_buf: [u8; FRAME_HEADER_LEN], + /// For RX: bytes read into rx_header_buf so far + rx_header_pos: usize, + /// For RX: expected frame length (None when header not yet complete) + expecting_frame_length: Option, + // TODO: lets have one allocation ptr for the u32 sending length box, and use that for every + // packet where we need to send the length or actually it could even be our expecting_frame_length LOL } impl Unixstream { /// Create the backend with a pre-established connection to the userspace network proxy. 
- pub fn new(fd: OwnedFd) -> Self { + pub fn new( + fd: OwnedFd, + backend_handles_vnet_hdr: bool, + tx_queue: Queue, + rx_queue: Queue, + mem: GuestMemoryMmap, + interrupt: InterruptTransport, + ) -> Self { + // Set socket to non-blocking mode (critical for epoll-based event loop) + if let Err(e) = fd.set_nonblocking(true) { + log::error!("Failed to set O_NONBLOCK on the socket: {e}"); + } + if let Err(e) = setsockopt(&fd, sockopt::SndBuf, &(16 * 1024 * 1024)) { log::warn!("Failed to increase SO_SNDBUF (performance may be decreased): {e}"); } @@ -37,22 +92,42 @@ impl Unixstream { getsockopt(&fd, sockopt::RcvBuf) ); + let tx_consumer = TxQueueConsumer::new(tx_queue, mem.clone(), interrupt.clone()); + let rx_provider = RxQueueProducer::new(rx_queue, mem, interrupt); + Self { fd, - expecting_frame_length: 0, - last_partial_write_length: 0, + backend_handles_vnet_hdr, + tx_consumer, + rx_producer: rx_provider, + rx_header_buf: [0u8; FRAME_HEADER_LEN], + rx_header_pos: 0, + expecting_frame_length: None, } } /// Create the backend opening a connection to the userspace network proxy. 
- pub fn open(path: PathBuf) -> Result { - let fd = socket( - AddressFamily::Unix, - SockType::Stream, - SockFlag::empty(), - None, - ) - .map_err(ConnectError::CreateSocket)?; + pub fn open( + path: PathBuf, + include_vnet_header: bool, + tx_queue: Queue, + rx_queue: Queue, + mem: GuestMemoryMmap, + interrupt: InterruptTransport, + ) -> Result { + #[cfg(target_os = "linux")] + let flags = SockFlag::SOCK_NONBLOCK | SockFlag::SOCK_CLOEXEC; + #[cfg(not(target_os = "linux"))] + let flags = SockFlag::empty(); + + let fd = socket(AddressFamily::Unix, SockType::Stream, flags, None) + .map_err(ConnectError::CreateSocket)?; + + // On macOS, set nonblocking after socket creation since SOCK_NONBLOCK isn't available + #[cfg(not(target_os = "linux"))] + fd.set_nonblocking(true).map_err(|e| { + ConnectError::CreateSocket(nix::Error::from_raw(e.raw_os_error().unwrap_or(libc::EIO))) + })?; let peer_addr = UnixAddr::new(&path).map_err(ConnectError::InvalidAddress)?; connect(fd.as_raw_fd(), &peer_addr).map_err(ConnectError::Binding)?; @@ -66,152 +141,124 @@ impl Unixstream { getsockopt(&fd, sockopt::RcvBuf) ); - Ok(Self { + Ok(Self::new( fd, - expecting_frame_length: 0, - last_partial_write_length: 0, - }) + include_vnet_header, + tx_queue, + rx_queue, + mem, + interrupt, + )) } +} - /// Try to read until filling the whole slice. 
- fn read_loop(&self, buf: &mut [u8], block_until_has_data: bool) -> Result<(), ReadError> { - let mut bytes_read = 0; - #[cfg(target_os = "linux")] - let flags = MsgFlags::MSG_DONTWAIT | MsgFlags::MSG_NOSIGNAL; - #[cfg(target_os = "macos")] - let flags = MsgFlags::MSG_DONTWAIT; - - if !block_until_has_data { - match recv(self.fd.as_raw_fd(), buf, flags) { - Ok(size) => bytes_read += size, - #[allow(unreachable_patterns)] - Err(nix::Error::EAGAIN | nix::Error::EWOULDBLOCK) => { - return Err(ReadError::NothingRead) - } - Err(e) => return Err(ReadError::Internal(e)), - } - } +impl NetBackend for Unixstream { + fn send(&mut self) -> Result<(), WriteError> { + log::trace!("Unixstream::send() called"); + let skip = if !self.backend_handles_vnet_hdr { + super::vnet_hdr_len() + } else { + 0 + }; - #[cfg(target_os = "linux")] - let flags = MsgFlags::MSG_WAITALL | MsgFlags::MSG_NOSIGNAL; - #[cfg(target_os = "macos")] - let flags = MsgFlags::MSG_WAITALL; - - while bytes_read < buf.len() { - match recv(self.fd.as_raw_fd(), &mut buf[bytes_read..], flags) { - #[allow(unreachable_patterns)] - Err(nix::Error::EAGAIN | nix::Error::EWOULDBLOCK) => { - log::warn!("read_loop: unexpected EAGAIN/EWOULDBLOCK on blocking socket"); - continue; - } - Err(e) => return Err(ReadError::Internal(e)), - Ok(size) => { - bytes_read += size; - //log::trace!("proxy recv {}/{}", bytes_read, buf.len()); - } - } + // Feed frames from queue, prepending frame length header + let fed = self.tx_consumer.feed_with_transform(|mut iovecs| { + // Skip vnet header + advance_tx_iovecs_vec(&mut iovecs, skip); + + // Calculate payload length (after vnet skip) + let payload_len = iovecs_len(&iovecs); + + // FIXME: This leaks memory! Need proper header storage in TxQueueConsumer. + // For now, Box::leak the header bytes to get 'static lifetime. 
+ let header = Box::leak(Box::new((payload_len as u32).to_be_bytes())); + iovecs.insert(0, IoSlice::new(header)); + (to_iovec(iovecs), ()) + }); + log::trace!( + "Unixstream::send() fed {} frames, pending={}", + fed, + self.tx_consumer.pending_count() + ); + + if !self.tx_consumer.has_pending() { + return Ok(()); } - Ok(()) - } + let fd = self.fd.as_fd(); - fn write_loop(&mut self, buf: &[u8]) -> Result<(), WriteError> { - let mut bytes_send = 0; + // Chains already have header prepended, just writev each one + self.tx_consumer.consume(|batch| { + for i in 0..batch.len() { + let chain = batch.io_slices(i); + if chain.is_empty() { + continue; + } - #[cfg(target_os = "linux")] - let flags = MsgFlags::MSG_DONTWAIT | MsgFlags::MSG_NOSIGNAL; - #[cfg(target_os = "macos")] - let flags = MsgFlags::MSG_DONTWAIT; - - while bytes_send < buf.len() { - match send(self.fd.as_raw_fd(), &buf[bytes_send..], flags) { - Ok(size) => bytes_send += size, - #[allow(unreachable_patterns)] - Err(nix::Error::EAGAIN | nix::Error::EWOULDBLOCK) => { - if bytes_send == 0 { - return Err(WriteError::NothingWritten); - } else { - log::trace!( - "Wrote {bytes_send} bytes, but socket blocked, will need try_finish_write() to finish" - ); - - self.last_partial_write_length += bytes_send; - return Err(WriteError::PartialWrite); + match nix::sys::uio::writev(fd, chain) { + Ok(_) => batch.finish(i), + Err(nix::errno::Errno::EAGAIN) => break, + Err(e) => { + log::error!("writev to unixstream failed: {e:?}"); + break; } } - Err(nix::Error::EPIPE) => return Err(WriteError::ProcessNotRunning), - Err(e) => return Err(WriteError::Internal(e)), } - } - self.last_partial_write_length = 0; + }); + Ok(()) } -} -impl NetBackend for Unixstream { - /// Try to read a frame from the proxy. 
If no bytes are available reports ReadError::NothingRead - fn read_frame(&mut self, buf: &mut [u8]) -> Result { - if self.expecting_frame_length == 0 { - self.expecting_frame_length = { - let mut frame_length_buf = [0u8; FRAME_HEADER_LEN]; - self.read_loop(&mut frame_length_buf, false)?; - u32::from_be_bytes(frame_length_buf) - }; - } + fn recv(&mut self) -> Result<(), ReadError> { + let fd = unsafe { BorrowedFd::borrow_raw(self.fd.as_raw_fd()) }; + let vnet_offset = if !self.backend_handles_vnet_hdr { + super::vnet_hdr_len() + } else { + 0 + }; - let hdr_len = write_virtio_net_hdr(buf); - let buf = &mut buf[hdr_len..]; - let frame_length = self.expecting_frame_length as usize; - self.read_loop(&mut buf[..frame_length], false)?; - self.expecting_frame_length = 0; - log::trace!("Read eth frame from network proxy: {frame_length} bytes"); - Ok(hdr_len + frame_length) - } + self.rx_producer.feed(); - /// Try to write a frame to the proxy. - /// (Will mutate and override parts of buf, with a frame header!) - /// - /// * `hdr_len` - specifies the size of any existing headers encapsulating the ethernet frame, - /// (such as vnet header), that can be overwritten. Must be >= FRAME_HEADER_LEN. - /// * `buf` - the buffer to write to the proxy, `buf[..hdr_len]` may be overwritten - /// - /// If this function returns WriteError::PartialWrite, you have to finish the write using - /// try_finish_write. 
- fn write_frame(&mut self, hdr_len: usize, buf: &mut [u8]) -> Result<(), WriteError> { - if self.last_partial_write_length != 0 { - panic!("Cannot write a frame to the proxy, while a partial write is not resolved."); - } - assert!( - hdr_len >= FRAME_HEADER_LEN, - "Not enough space to write the frame header" - ); - assert!(buf.len() > hdr_len); - let frame_length = buf.len() - hdr_len; + let header_buf = &mut self.rx_header_buf; + let header_pos = &mut self.rx_header_pos; + let expecting = &mut self.expecting_frame_length; - buf[hdr_len - FRAME_HEADER_LEN..hdr_len] - .copy_from_slice(&(frame_length as u32).to_be_bytes()); + self.rx_producer.produce(|batch| { + for i in 0..batch.len() { + // Read frame header + let frame_len = match try_read_frame_header(fd, header_buf, header_pos, expecting) { + Some(len) => len, + None => break, + }; + let total_len = vnet_offset + frame_len; - self.write_loop(&buf[hdr_len - FRAME_HEADER_LEN..])?; - Ok(()) - } + // Write vnet header at start of new frame + if batch.bytes_used(i) == 0 && vnet_offset > 0 { + // Header is small, chain should always have space + let _ = batch.write_advance(i, &super::DEFAULT_VNET_HDR); + } - fn has_unfinished_write(&self) -> bool { - self.last_partial_write_length != 0 - } + // Read payload (truncated to remaining frame bytes) + let remaining = total_len - batch.bytes_used(i); + let iovecs = truncate_iovecs(batch.io_slices_mut(i), remaining); - /// Try to finish a partial write - /// - /// If no partial write is required will do nothing and return Ok(()) - /// - /// * `hdr_len` - must be the same value as passed to write_frame, that caused the partial write - /// * `buf` - must be same buffer that was given to write_frame, that caused the partial write - fn try_finish_write(&mut self, hdr_len: usize, buf: &[u8]) -> Result<(), WriteError> { - if self.last_partial_write_length != 0 { - let already_written = self.last_partial_write_length; - log::trace!("Requested to finish partial write"); - 
self.write_loop(&buf[hdr_len - FRAME_HEADER_LEN + already_written..])?; - log::debug!("Finished partial write ({already_written}bytes written before)") - } + match readv(fd, iovecs) { + Ok(n) if n > 0 => { + batch.advance(i, n); + if batch.bytes_used(i) >= total_len { + batch.finish(i); + *expecting = None; + } + } + Ok(_) => break, // EOF or 0 bytes + Err(nix::errno::Errno::EAGAIN) => break, + Err(e) => { + log::error!("readv from unixstream failed: {e:?}"); + break; + } + } + } + }); Ok(()) } diff --git a/src/devices/src/virtio/net/worker.rs b/src/devices/src/virtio/net/worker.rs index 222b781a0..0919122fd 100644 --- a/src/devices/src/virtio/net/worker.rs +++ b/src/devices/src/virtio/net/worker.rs @@ -3,85 +3,99 @@ use crate::virtio::net::backend::ConnectError; use crate::virtio::net::tap::Tap; use crate::virtio::net::unixgram::Unixgram; use crate::virtio::net::unixstream::Unixstream; -use crate::virtio::net::{MAX_BUFFER_SIZE, QUEUE_SIZE}; use crate::virtio::{DeviceQueue, InterruptTransport}; use super::backend::{NetBackend, ReadError, WriteError}; -use super::device::{FrontendError, RxError, TxError, VirtioNetBackend}; -use super::vnet_hdr_len; +use super::device::VirtioNetBackend; use std::os::fd::{AsRawFd, FromRawFd, OwnedFd}; +use std::sync::Arc; use std::thread; -use std::{cmp, result}; use utils::epoll::{ControlOperation, Epoll, EpollEvent, EventSet}; -use vm_memory::{Bytes, GuestAddress, GuestMemoryMmap}; +use utils::eventfd::EventFd; +use vm_memory::GuestMemoryMmap; pub struct NetWorker { - rx_q: DeviceQueue, - tx_q: DeviceQueue, - interrupt: InterruptTransport, - - mem: GuestMemoryMmap, + rx_evt: Arc, + tx_evt: Arc, backend: Box, - - rx_frame_buf: [u8; MAX_BUFFER_SIZE], - rx_frame_buf_len: usize, - rx_has_deferred_frame: bool, - - tx_iovec: Vec<(GuestAddress, usize)>, - tx_frame_buf: [u8; MAX_BUFFER_SIZE], - tx_frame_len: usize, } impl NetWorker { + #[allow(clippy::too_many_arguments)] pub fn new( rx_q: DeviceQueue, tx_q: DeviceQueue, interrupt: 
InterruptTransport, mem: GuestMemoryMmap, _vnet_features: u64, + include_vnet_header: bool, cfg_backend: VirtioNetBackend, ) -> Result { - let backend = match cfg_backend { + let DeviceQueue { + queue: rx_queue, + event: rx_evt, + } = rx_q; + let DeviceQueue { + queue: tx_queue, + event: tx_evt, + } = tx_q; + + let backend: Box = match cfg_backend { VirtioNetBackend::UnixstreamFd(fd) => { - // SAFETY: we need to trust that the library user has configured - // the backend with a healthy file descriptor. let owned_fd = unsafe { OwnedFd::from_raw_fd(fd) }; - Box::new(Unixstream::new(owned_fd)) as Box - } - VirtioNetBackend::UnixstreamPath(path) => { - Box::new(Unixstream::open(path)?) as Box - } + Box::new(Unixstream::new( + owned_fd, + include_vnet_header, + tx_queue, + rx_queue, + mem, + interrupt, + )) + } + VirtioNetBackend::UnixstreamPath(path) => Box::new(Unixstream::open( + path, + include_vnet_header, + tx_queue, + rx_queue, + mem, + interrupt, + )?), VirtioNetBackend::UnixgramFd(fd) => { - // SAFETY: we need to trust that the library user has configured - // the backend with a healthy file descriptor. let owned_fd = unsafe { OwnedFd::from_raw_fd(fd) }; - Box::new(Unixgram::new(owned_fd)) as Box - } - VirtioNetBackend::UnixgramPath(path, vfkit_magic) => { - Box::new(Unixgram::open(path, vfkit_magic)?) as Box - } + Box::new(Unixgram::new( + owned_fd, + include_vnet_header, + tx_queue, + rx_queue, + mem, + interrupt, + )) + } + VirtioNetBackend::UnixgramPath(path, vfkit_magic) => Box::new(Unixgram::open( + path, + vfkit_magic, + include_vnet_header, + tx_queue, + rx_queue, + mem, + interrupt, + )?), #[cfg(target_os = "linux")] - VirtioNetBackend::Tap(tap_name) => { - Box::new(Tap::new(tap_name, _vnet_features)?) 
as Box - } + VirtioNetBackend::Tap(tap_name) => Box::new(Tap::new( + tap_name, + _vnet_features, + tx_queue, + rx_queue, + mem, + interrupt, + )?), }; Ok(Self { - rx_q, - tx_q, - - mem, + rx_evt, + tx_evt, backend, - interrupt, - - rx_frame_buf: [0u8; MAX_BUFFER_SIZE], - rx_frame_buf_len: 0, - rx_has_deferred_frame: false, - - tx_frame_buf: [0u8; MAX_BUFFER_SIZE], - tx_frame_len: 0, - tx_iovec: Vec::with_capacity(QUEUE_SIZE as usize), }) } @@ -93,30 +107,48 @@ impl NetWorker { } fn work(mut self) { - let virtq_rx_ev_fd = self.rx_q.event.as_raw_fd(); - let virtq_tx_ev_fd = self.tx_q.event.as_raw_fd(); + let virtq_rx_ev_fd = self.rx_evt.as_raw_fd(); + let virtq_tx_ev_fd = self.tx_evt.as_raw_fd(); let backend_socket = self.backend.raw_socket_fd(); let epoll = Epoll::new().unwrap(); - let _ = epoll.ctl( + if let Err(e) = epoll.ctl( ControlOperation::Add, virtq_rx_ev_fd, &EpollEvent::new(EventSet::IN, virtq_rx_ev_fd as u64), - ); - let _ = epoll.ctl( + ) { + log::error!( + "Failed to add rx_ev fd {} to epoll: {:?}", + virtq_rx_ev_fd, + e + ); + } + if let Err(e) = epoll.ctl( ControlOperation::Add, virtq_tx_ev_fd, &EpollEvent::new(EventSet::IN, virtq_tx_ev_fd as u64), - ); - let _ = epoll.ctl( + ) { + log::error!( + "Failed to add tx_ev fd {} to epoll: {:?}", + virtq_tx_ev_fd, + e + ); + } + if let Err(e) = epoll.ctl( ControlOperation::Add, backend_socket, &EpollEvent::new( - EventSet::IN | EventSet::OUT | EventSet::EDGE_TRIGGERED | EventSet::READ_HANG_UP, + EventSet::IN | EventSet::OUT | EventSet::READ_HANG_UP | EventSet::EDGE_TRIGGERED, backend_socket as u64, ), - ); + ) { + log::error!( + "Failed to add backend fd {} to epoll: {:?}", + backend_socket, + e + ); + } loop { let mut epoll_events = vec![EpollEvent::new(EventSet::empty(), 0); 32]; @@ -125,322 +157,82 @@ impl NetWorker { for event in &epoll_events[0..ev_cnt] { let source = event.fd(); let event_set = event.event_set(); - match event_set { - EventSet::IN if source == virtq_rx_ev_fd => { - 
self.process_rx_queue_event(); - } - EventSet::IN if source == virtq_tx_ev_fd => { - self.process_tx_queue_event(); - } - _ if source == backend_socket => { - if event_set.contains(EventSet::HANG_UP) - || event_set.contains(EventSet::READ_HANG_UP) - { - log::error!("Got {event_set:?} on backend fd, virtio-net will stop working"); - eprintln!("LIBKRUN VIRTIO-NET FATAL: Backend process seems to have quit or crashed! Networking is now disabled!"); - } else { - if event_set.contains(EventSet::IN) { - self.process_backend_socket_readable() - } + log::trace!( + "virtio-net epoll event: fd={} event_set={:?}", + source, + event_set + ); + + if source == virtq_rx_ev_fd && event_set.contains(EventSet::IN) { + log::trace!("virtio-net: rx queue event"); + self.process_rx_queue_event(); + } else if source == virtq_tx_ev_fd && event_set.contains(EventSet::IN) { + log::trace!("virtio-net: tx queue event"); + self.process_tx_queue_event(); + } else if source == backend_socket { + if event_set.contains(EventSet::HANG_UP) + || event_set.contains(EventSet::READ_HANG_UP) + { + log::error!( + "Got {event_set:?} on backend fd, virtio-net will stop working" + ); + eprintln!("LIBKRUN VIRTIO-NET FATAL: Backend process seems to have quit or crashed! 
Networking is now disabled!"); + } else { + if event_set.contains(EventSet::IN) { + self.process_rx(); + } - if event_set.contains(EventSet::OUT) { - self.process_backend_socket_writeable() - } + if event_set.contains(EventSet::OUT) { + self.process_tx(); } } - _ => { - log::warn!( - "Received unknown event: {event_set:?} from fd: {source:?}" - ); - } + } else { + log::warn!("Received unknown event: {event_set:?} from fd: {source:?}"); } } } Err(e) => { - debug!("vsock: failed to consume muxer epoll event: {e}"); + debug!("virtio-net: failed to consume epoll event: {e}"); } } } } - pub(crate) fn process_rx_queue_event(&mut self) { - if let Err(e) = self.rx_q.event.read() { + fn process_rx_queue_event(&mut self) { + if let Err(e) = self.rx_evt.read() { log::error!("Failed to get rx event from queue: {e:?}"); } - if let Err(e) = self.rx_q.queue.disable_notification(&self.mem) { - error!("error disabling queue notifications: {e:?}"); - } - if let Err(e) = self.process_rx() { - log::error!("Failed to process rx: {e:?} (triggered by queue event)") - }; - if let Err(e) = self.rx_q.queue.enable_notification(&self.mem) { - error!("error disabling queue notifications: {e:?}"); - } - } - - pub(crate) fn process_tx_queue_event(&mut self) { - match self.tx_q.event.read() { - Ok(_) => self.process_tx_loop(), - Err(e) => { - log::error!("Failed to get tx queue event from queue: {e:?}"); - } - } - } - - pub(crate) fn process_backend_socket_readable(&mut self) { - if let Err(e) = self.rx_q.queue.enable_notification(&self.mem) { - error!("error disabling queue notifications: {e:?}"); - } - if let Err(e) = self.process_rx() { - log::error!("Failed to process rx: {e:?} (triggered by backend socket readable)"); - }; - if let Err(e) = self.rx_q.queue.disable_notification(&self.mem) { - error!("error disabling queue notifications: {e:?}"); - } - } - - pub(crate) fn process_backend_socket_writeable(&mut self) { - match self - .backend - .try_finish_write(vnet_hdr_len(), 
&self.tx_frame_buf[..self.tx_frame_len]) - { - Ok(()) => self.process_tx_loop(), - Err(WriteError::PartialWrite | WriteError::NothingWritten) => {} - Err(e @ WriteError::Internal(_)) => { - log::error!("Failed to finish write: {e:?}"); - } - Err(e @ WriteError::ProcessNotRunning) => { - log::debug!("Failed to finish write: {e:?}"); - } - } - } - - fn process_rx(&mut self) -> result::Result<(), RxError> { - // if we have a deferred frame we try to process it first, - // if that is not possible, we don't continue processing other frames - if self.rx_has_deferred_frame { - if self.write_frame_to_guest() { - self.rx_has_deferred_frame = false; - } else { - return Ok(()); - } - } - - let mut signal_queue = false; - - // Read as many frames as possible. - let result = loop { - match self.read_into_rx_frame_buf_from_backend() { - Ok(()) => { - if self.write_frame_to_guest() { - signal_queue = true; - } else { - self.rx_has_deferred_frame = true; - break Ok(()); - } - } - Err(ReadError::NothingRead) => break Ok(()), - Err(e @ ReadError::Internal(_)) => break Err(RxError::Backend(e)), - } - }; - - // At this point we processed as many Rx frames as possible. - // We have to wake the guest if at least one descriptor chain has been used. 
- if signal_queue { - self.interrupt - .try_signal_used_queue() - .map_err(RxError::DeviceError)?; - } - - result + self.process_rx(); } - fn process_tx_loop(&mut self) { - loop { - self.tx_q.queue.disable_notification(&self.mem).unwrap(); - - if let Err(e) = self.process_tx() { - log::error!("Failed to process rx: {e:?} (triggered by backend socket readable)"); - }; - - if !self.tx_q.queue.enable_notification(&self.mem).unwrap() { - break; - } + fn process_tx_queue_event(&mut self) { + if let Err(e) = self.tx_evt.read() { + log::error!("Failed to get tx queue event from queue: {e:?}"); } + self.process_tx(); } - fn process_tx(&mut self) -> result::Result<(), TxError> { - let tx_queue = &mut self.tx_q.queue; - - if self.backend.has_unfinished_write() - && self - .backend - .try_finish_write(vnet_hdr_len(), &self.tx_frame_buf[..self.tx_frame_len]) - .is_err() - { - log::trace!("Cannot process tx because of unfinished partial write!"); - return Ok(()); - } - - let mut raise_irq = false; - - while let Some(head) = tx_queue.pop(&self.mem) { - let head_index = head.index; - let mut next_desc = Some(head); - - self.tx_iovec.clear(); - while let Some(desc) = next_desc { - if desc.is_write_only() { - self.tx_iovec.clear(); - break; - } - self.tx_iovec.push((desc.addr, desc.len as usize)); - next_desc = desc.next_descriptor(); - } - - // Copy buffer from across multiple descriptors. - let mut read_count = 0; - for (desc_addr, desc_len) in self.tx_iovec.drain(..) 
{ - let limit = cmp::min(read_count + desc_len, self.tx_frame_buf.len()); - - let read_result = self - .mem - .read_slice(&mut self.tx_frame_buf[read_count..limit], desc_addr); - match read_result { - Ok(()) => { - read_count += limit - read_count; - } - Err(e) => { - log::error!("Failed to read slice: {e:?}"); - read_count = 0; - break; - } - } + fn process_rx(&mut self) { + match self.backend.recv() { + Ok(()) => {} + Err(ReadError::ProcessNotRunning) => { + log::error!("RX error: backend process not running"); } - - self.tx_frame_len = read_count; - match self - .backend - .write_frame(vnet_hdr_len(), &mut self.tx_frame_buf[..read_count]) - { - Ok(()) => { - self.tx_frame_len = 0; - tx_queue - .add_used(&self.mem, head_index, 0) - .map_err(TxError::QueueError)?; - raise_irq = true; - } - Err(WriteError::NothingWritten) => { - tx_queue.undo_pop(); - break; - } - Err(WriteError::PartialWrite) => { - log::trace!("process_tx: partial write"); - /* - This situation should be pretty rare, assuming reasonably sized socket buffers. - We have written only a part of a frame to the backend socket (the socket is full). - - The frame we have read from the guest remains in tx_frame_buf, and will be sent - later. - - Note that we cannot wait for the backend to process our sending frames, because - the backend could be blocked on sending a remainder of a frame to us - us waiting - for backend would cause a deadlock. - */ - tx_queue - .add_used(&self.mem, head_index, 0) - .map_err(TxError::QueueError)?; - raise_irq = true; - break; - } - Err(e @ WriteError::Internal(_) | e @ WriteError::ProcessNotRunning) => { - return Err(TxError::Backend(e)) - } + Err(ReadError::Internal(e)) => { + log::error!("RX error: {e:?}"); } } - - if raise_irq && tx_queue.needs_notification(&self.mem).unwrap() { - self.interrupt - .try_signal_used_queue() - .map_err(TxError::DeviceError)?; - } - - Ok(()) } - // Copies a single frame from `self.rx_frame_buf` into the guest. 
- fn write_frame_to_guest_impl(&mut self) -> result::Result<(), FrontendError> { - let mut result: std::result::Result<(), FrontendError> = Ok(()); - - let queue = &mut self.rx_q.queue; - let head_descriptor = queue.pop(&self.mem).ok_or(FrontendError::EmptyQueue)?; - let head_index = head_descriptor.index; - - let mut frame_slice = &self.rx_frame_buf[..self.rx_frame_buf_len]; - - let frame_len = frame_slice.len(); - let mut maybe_next_descriptor = Some(head_descriptor); - while let Some(descriptor) = &maybe_next_descriptor { - if frame_slice.is_empty() { - break; + fn process_tx(&mut self) { + match self.backend.send() { + Ok(()) => {} + Err(WriteError::ProcessNotRunning) => { + log::error!("TX error: backend process not running"); } - - if !descriptor.is_write_only() { - result = Err(FrontendError::ReadOnlyDescriptor); - break; + Err(WriteError::Internal(e)) => { + log::error!("TX error: {e:?}"); } - - let len = std::cmp::min(frame_slice.len(), descriptor.len as usize); - match self.mem.write_slice(&frame_slice[..len], descriptor.addr) { - Ok(()) => { - frame_slice = &frame_slice[len..]; - } - Err(e) => { - log::error!("Failed to write slice: {e:?}"); - result = Err(FrontendError::GuestMemory(e)); - break; - } - }; - - maybe_next_descriptor = descriptor.next_descriptor(); } - if result.is_ok() && !frame_slice.is_empty() { - log::warn!("Receiving buffer is too small to hold frame of current size"); - result = Err(FrontendError::DescriptorChainTooSmall); - } - - // Mark the descriptor chain as used. If an error occurred, skip the descriptor chain. - let used_len = if result.is_err() { 0 } else { frame_len as u32 }; - queue - .add_used(&self.mem, head_index, used_len) - .map_err(FrontendError::QueueError)?; - result - } - - // Copies a single frame from `self.rx_frame_buf` into the guest. In case of an error retries - // the operation if possible. Returns true if the operation was successfull. 
- fn write_frame_to_guest(&mut self) -> bool { - let max_iterations = self.rx_q.queue.actual_size(); - for _ in 0..max_iterations { - match self.write_frame_to_guest_impl() { - Ok(()) => return true, - Err(FrontendError::EmptyQueue) => { - // retry - continue; - } - Err(_) => { - // retry - continue; - } - } - } - - false - } - - /// Fills self.rx_frame_buf with an ethernet frame from backend and prepends virtio_net_hdr to it - fn read_into_rx_frame_buf_from_backend(&mut self) -> result::Result<(), ReadError> { - self.rx_frame_buf_len = self.backend.read_frame(&mut self.rx_frame_buf)?; - Ok(()) } } diff --git a/src/devices/src/virtio/queue.rs b/src/devices/src/virtio/queue.rs index 2fb74289d..f1699678e 100644 --- a/src/devices/src/virtio/queue.rs +++ b/src/devices/src/virtio/queue.rs @@ -219,6 +219,18 @@ pub struct DescriptorChain<'a> { pub next: u16, } +impl<'a> fmt::Debug for DescriptorChain<'a> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("DescriptorChain") + .field("index", &self.index) + .field("addr", &format_args!("{:#018x}", self.addr.raw_value())) + .field("len", &self.len) + .field("flags", &format_args!("{:#06x}", self.flags)) + .field("next", &self.next) + .finish_non_exhaustive() + } +} + impl<'a> DescriptorChain<'a> { pub fn checked_new( mem: &GuestMemoryMmap, diff --git a/src/devices/src/virtio/test_utils.rs b/src/devices/src/virtio/test_utils.rs new file mode 100644 index 000000000..783b65038 --- /dev/null +++ b/src/devices/src/virtio/test_utils.rs @@ -0,0 +1,522 @@ +// Copyright 2026 Red Hat, Inc. +// SPDX-License-Identifier: Apache-2.0 + +//! Shared test utilities for TxQueueConsumer and RxQueueProducer tests. 
+
+use std::cell::{Cell, RefCell};
+use std::mem::size_of;
+
+use vm_memory::{Address, Bytes, GuestAddress, GuestMemoryMmap};
+
+use crate::legacy::DummyIrqChip;
+use crate::virtio::queue::tests::{VIRTQ_DESC_F_NEXT, VIRTQ_DESC_F_WRITE};
+use crate::virtio::queue::{Descriptor, Queue, VirtqUsedElem};
+use crate::virtio::InterruptTransport;
+
+const MEM_SIZE: u64 = 0x100000;
+/// Per-queue data region size (64 KB).
+const DATA_REGION_SIZE: u64 = 0x10000;
+/// Data regions start after queue structures.
+const DATA_BASE: u64 = 0x10000;
+
+/// Test setup that owns guest memory and allocates non-overlapping queues.
+pub struct TestSetup {
+    mem: GuestMemoryMmap,
+    /// Bump allocator for queue structures (low addresses).
+    next_struct_addr: Cell<u64>,
+    /// Number of queues created (used to partition data regions).
+    queue_count: Cell<usize>,
+}
+
+impl TestSetup {
+    pub fn new() -> Self {
+        Self {
+            mem: GuestMemoryMmap::from_ranges(&[(GuestAddress(0), MEM_SIZE as usize)]).unwrap(),
+            next_struct_addr: Cell::new(0),
+            queue_count: Cell::new(0),
+        }
+    }
+
+    pub fn mem(&self) -> &GuestMemoryMmap {
+        &self.mem
+    }
+
+    /// Allocate `size` bytes at the next `align`-byte boundary.
+    fn alloc(&self, size: u64, align: u64) -> u64 {
+        let addr = self.next_struct_addr.get();
+        let aligned = (addr + align - 1) & !(align - 1);
+        self.next_struct_addr.set(aligned + size);
+        assert!(
+            self.next_struct_addr.get() <= DATA_BASE,
+            "queue structures overflow into data area"
+        );
+        aligned
+    }
+
+    /// Create a queue with the given size and its corresponding driver.
+    pub fn create_queue(&self, size: u16) -> (Queue, VirtQueueDriver<'_>) {
+        let n = size as u64;
+        let ring_overhead = 3 * size_of::<u16>() as u64; // flags + idx + event
+        let desc_table = self.alloc(size_of::<Descriptor>() as u64 * n, 16);
+        let avail_ring = self.alloc(ring_overhead + size_of::<u16>() as u64 * n, 2);
+        let used_ring = self.alloc(ring_overhead + size_of::<VirtqUsedElem>() as u64 * n, 4);
+
+        let mut queue = Queue::new(size);
+        queue.size = size;
+        queue.ready = true;
+        queue.desc_table = GuestAddress(desc_table);
+        queue.avail_ring = GuestAddress(avail_ring);
+        queue.used_ring = GuestAddress(used_ring);
+
+        let idx = self.queue_count.get();
+        self.queue_count.set(idx + 1);
+        let data_addr = DATA_BASE + idx as u64 * DATA_REGION_SIZE;
+        assert!(
+            data_addr + DATA_REGION_SIZE <= MEM_SIZE,
+            "out of data regions"
+        );
+
+        let driver = VirtQueueDriver::new(&queue, &self.mem, data_addr);
+        (queue, driver)
+    }
+}
+
+/// Create an InterruptTransport for testing
+pub fn create_interrupt() -> InterruptTransport {
+    InterruptTransport::new(DummyIrqChip::new().into(), "test".to_string()).unwrap()
+}
+
+/// A segment within a descriptor chain (address + size + optional expected data)
+#[derive(Clone)]
+pub struct DescSegment {
+    /// Guest physical address of this segment
+    pub addr: u64,
+    /// Length of this segment
+    pub len: u32,
+    /// For readable segments: copy of expected data (None for writable)
+    pub expected_data: Option<Vec<u8>>,
+}
+
+/// Information about a built descriptor chain
+#[derive(Clone)]
+pub struct BuiltChain {
+    /// Head descriptor index (used in add_used)
+    pub head_index: u16,
+    /// Segments in this chain
+    pub segments: Vec<DescSegment>,
+}
+
+impl BuiltChain {
+    /// Total length of all segments in this chain
+    pub fn total_len(&self) -> u32 {
+        self.segments.iter().map(|s| s.len).sum()
+    }
+
+    /// Check if this chain is readable (TX - has expected data)
+    pub fn is_readable(&self) -> bool {
+        self.segments.iter().any(|s| s.expected_data.is_some())
+    }
+
+    /// Check if this chain is writable (RX - no expected data)
+    pub fn is_writable(&self) -> bool {
+        self.segments.iter().all(|s| s.expected_data.is_none())
+    }
+}
+
+/// Expected state for a chain in the used ring.
+#[derive(Debug, Clone)]
+pub enum ExpectedUsed<'a> {
+    /// Writable chain - verify content matches exactly
+    Writable(&'a [u8]),
+    /// Readable chain - verify wasn't modified, expect this length in used ring
+    Readable(u32),
+    /// Readable chain - verify wasn't modified, don't check length
+    ReadableAnyLen,
+}
+
+/// Simulates the guest driver side of a VirtIO queue for testing.
+///
+/// Communicates with the device ONLY through guest memory.
+/// Supports incremental descriptor addition during tests.
+/// Tracks chain metadata for verification (assert_used_len_exact, etc).
+pub struct VirtQueueDriver<'a> {
+    mem: &'a GuestMemoryMmap,
+    /// Queue size (max descriptors)
+    queue_size: u16,
+    /// Descriptor table address in guest memory
+    desc_table: GuestAddress,
+    /// Available ring address in guest memory
+    avail_ring: GuestAddress,
+    /// Used ring address in guest memory
+    used_ring: GuestAddress,
+    /// Next descriptor table index to use
+    desc_idx: Cell<usize>,
+    /// Next available ring index (initialized from memory)
+    avail_idx: Cell<u16>,
+    /// Next memory address for data allocation
+    next_addr: Cell<u64>,
+    /// Tracked chains for verification
+    chains: RefCell<Vec<BuiltChain>>,
+}
+
+impl<'a> VirtQueueDriver<'a> {
+    /// Create a new driver by extracting queue addresses from the Queue struct.
+    ///
+    /// The Queue reference is only used to get addresses - it is NOT stored.
+    /// All communication happens through guest memory.
+ pub fn new(queue: &Queue, mem: &'a GuestMemoryMmap, data_addr: u64) -> Self { + // Extract addresses from queue (not stored) + let desc_table = queue.desc_table; + let avail_ring = queue.avail_ring; + let used_ring = queue.used_ring; + let queue_size = queue.size; + + // Read current avail_idx from memory to support mid-test construction + let avail_idx_addr = avail_ring.unchecked_add(2); + let current_avail_idx: u16 = mem.read_obj(avail_idx_addr).unwrap_or(0); + + Self { + mem, + queue_size, + desc_table, + avail_ring, + used_ring, + desc_idx: Cell::new(current_avail_idx as usize), // Start after existing descriptors + avail_idx: Cell::new(current_avail_idx), + next_addr: Cell::new(data_addr), + chains: RefCell::new(Vec::new()), + } + } + + // ======================================================================== + // Chain building methods + // ======================================================================== + + /// Add a readable chain (for TX). Each slice in `segments` becomes a descriptor. 
+ /// + /// Simple case (1 descriptor): `driver.readable(&[b"data"])` + /// Chained case: `driver.readable(&[b"header", b"payload"])` + pub fn readable(&self, segments: &[&[u8]]) -> &Self { + assert!( + !segments.is_empty(), + "readable chain must have at least one segment" + ); + let head_idx = self.desc_idx.get() as u16; + let mut chain_segments = Vec::new(); + + for (i, data) in segments.iter().enumerate() { + let addr = self.next_addr.get(); + self.next_addr.set(addr + data.len() as u64); + assert!(self.next_addr.get() <= MEM_SIZE, "out of memory"); + + // Write data to guest memory + self.mem.write(data, GuestAddress(addr)).unwrap(); + + let idx = self.desc_idx.get(); + assert!(idx < self.queue_size as usize, "descriptor table full"); + + let is_last = i == segments.len() - 1; + let flags = if is_last { 0 } else { VIRTQ_DESC_F_NEXT }; + let next = if is_last { 0 } else { (idx + 1) as u16 }; + + // Write descriptor to guest memory + self.write_descriptor(idx, addr, data.len() as u32, flags, next); + self.desc_idx.set(idx + 1); + + chain_segments.push(DescSegment { + addr, + len: data.len() as u32, + expected_data: Some(data.to_vec()), + }); + } + + // Add to available ring + self.add_to_avail_ring(head_idx); + + // Track chain + self.chains.borrow_mut().push(BuiltChain { + head_index: head_idx, + segments: chain_segments, + }); + + self + } + + /// Add a chain with readable prefix and writable suffix. + /// + /// This is used to test that RX handlers correctly skip readable descriptors. 
+ /// Example: `driver.readable_then_writable(&[b"header"], &[1500])` + pub fn readable_then_writable(&self, readable: &[&[u8]], writable: &[u32]) -> &Self { + assert!( + !readable.is_empty() || !writable.is_empty(), + "chain must have at least one segment" + ); + let head_idx = self.desc_idx.get() as u16; + let mut chain_segments = Vec::new(); + let total_segments = readable.len() + writable.len(); + let mut segment_counter = 0; + + // Add readable descriptors + for data in readable.iter() { + let addr = self.next_addr.get(); + self.next_addr.set(addr + data.len() as u64); + assert!(self.next_addr.get() <= MEM_SIZE, "out of memory"); + + self.mem.write(data, GuestAddress(addr)).unwrap(); + + let idx = self.desc_idx.get(); + assert!(idx < self.queue_size as usize, "descriptor table full"); + + segment_counter += 1; + let is_last = segment_counter == total_segments; + let flags = if is_last { 0 } else { VIRTQ_DESC_F_NEXT }; + let next = if is_last { 0 } else { (idx + 1) as u16 }; + + self.write_descriptor(idx, addr, data.len() as u32, flags, next); + self.desc_idx.set(idx + 1); + + chain_segments.push(DescSegment { + addr, + len: data.len() as u32, + expected_data: Some(data.to_vec()), + }); + } + + // Add writable descriptors + for &len in writable.iter() { + let addr = self.next_addr.get(); + self.next_addr.set(addr + len as u64); + assert!(self.next_addr.get() <= MEM_SIZE, "out of memory"); + + let idx = self.desc_idx.get(); + assert!(idx < self.queue_size as usize, "descriptor table full"); + + segment_counter += 1; + let is_last = segment_counter == total_segments; + let flags = VIRTQ_DESC_F_WRITE | if is_last { 0 } else { VIRTQ_DESC_F_NEXT }; + let next = if is_last { 0 } else { (idx + 1) as u16 }; + + self.write_descriptor(idx, addr, len, flags, next); + self.desc_idx.set(idx + 1); + + chain_segments.push(DescSegment { + addr, + len, + expected_data: None, + }); + } + + self.add_to_avail_ring(head_idx); + + self.chains.borrow_mut().push(BuiltChain { + 
head_index: head_idx, + segments: chain_segments, + }); + + self + } + + /// Add a writable chain (for RX). Each length in `sizes` becomes a descriptor. + /// + /// Simple case (1 descriptor): `driver.writable(&[1500])` + /// Chained case: `driver.writable(&[12, 1500])` (e.g., header + payload) + pub fn writable(&self, sizes: &[u32]) -> &Self { + assert!( + !sizes.is_empty(), + "writable chain must have at least one segment" + ); + let head_idx = self.desc_idx.get() as u16; + let mut chain_segments = Vec::new(); + + for (i, &len) in sizes.iter().enumerate() { + let addr = self.next_addr.get(); + self.next_addr.set(addr + len as u64); + assert!(self.next_addr.get() <= MEM_SIZE, "out of memory"); + + let idx = self.desc_idx.get(); + assert!(idx < self.queue_size as usize, "descriptor table full"); + + let is_last = i == sizes.len() - 1; + let flags = VIRTQ_DESC_F_WRITE | if is_last { 0 } else { VIRTQ_DESC_F_NEXT }; + let next = if is_last { 0 } else { (idx + 1) as u16 }; + + // Write descriptor to guest memory + self.write_descriptor(idx, addr, len, flags, next); + self.desc_idx.set(idx + 1); + + chain_segments.push(DescSegment { + addr, + len, + expected_data: None, + }); + } + + // Add to available ring + self.add_to_avail_ring(head_idx); + + // Track chain + self.chains.borrow_mut().push(BuiltChain { + head_index: head_idx, + segments: chain_segments, + }); + + self + } + + fn write_descriptor(&self, idx: usize, addr: u64, len: u32, flags: u16, next: u16) { + let desc = Descriptor { + addr, + len, + flags, + next, + }; + let desc_addr = self.desc_table.unchecked_add((idx * 16) as u64); + self.mem.write_obj(desc, desc_addr).unwrap(); + } + + fn add_to_avail_ring(&self, desc_idx: u16) { + let avail_idx = self.avail_idx.get(); + + // Write descriptor index to ring[avail_idx] + // Available ring layout: flags(2) + idx(2) + ring[size](2*size) + let ring_entry_addr = self.avail_ring.unchecked_add(4 + (avail_idx as u64) * 2); + self.mem.write_obj(desc_idx, 
ring_entry_addr).unwrap(); + + // Increment and write avail idx + let new_avail_idx = avail_idx + 1; + self.avail_idx.set(new_avail_idx); + let avail_idx_addr = self.avail_ring.unchecked_add(2); + self.mem.write_obj(new_avail_idx, avail_idx_addr).unwrap(); + } + + // ======================================================================== + // Query methods + // ======================================================================== + + /// Get the used ring entries as (descriptor_id, len) pairs. + pub fn used_entries(&self) -> Vec<(u16, u32)> { + // Used ring layout: flags(2) + idx(2) + ring[size]({id:4, len:4}*size) + let used_idx_addr = self.used_ring.unchecked_add(2); + let used_idx: u16 = self.mem.read_obj(used_idx_addr).unwrap(); + + let mut entries = Vec::new(); + for i in 0..used_idx { + // Each used element is 8 bytes: u32 id, u32 len + let elem_addr = self.used_ring.unchecked_add(4 + (i as u64) * 8); + let id: u32 = self.mem.read_obj(elem_addr).unwrap(); + let len: u32 = self.mem.read_obj(elem_addr.unchecked_add(4)).unwrap(); + entries.push((id as u16, len)); + } + entries + } + + /// Get the number of used ring entries. + pub fn used_count(&self) -> u16 { + let used_idx_addr = self.used_ring.unchecked_add(2); + self.mem.read_obj(used_idx_addr).unwrap() + } + + /// Get the number of chains tracked. + pub fn chain_count(&self) -> usize { + self.chains.borrow().len() + } + + // ======================================================================== + // Verification methods + // ======================================================================== + + /// Assert the used ring matches expected entries. 
+ /// + /// Each entry is `(chain_idx, expected)` where `expected` is: + /// - `Writable(bytes)` - verify writable chain content matches + /// - `Readable(len)` - verify readable chain wasn't modified, check length + /// - `ReadableAnyLen` - verify readable chain wasn't modified, skip length check + #[track_caller] + pub fn assert_used(&self, expected: &[(usize, ExpectedUsed<'_>)]) { + let used = self.used_entries(); + let chains = self.chains.borrow(); + + assert_eq!( + used.len(), + expected.len(), + "used ring count mismatch: expected {}, got {}", + expected.len(), + used.len() + ); + + for (i, (chain_idx, expectation)) in expected.iter().enumerate() { + let chain = &chains[*chain_idx]; + let (actual_id, actual_len) = used[i]; + + // Verify descriptor ID + assert_eq!( + actual_id, chain.head_index, + "used[{}] descriptor id mismatch: expected {} (chain {}), got {}", + i, chain.head_index, chain_idx, actual_id + ); + + match expectation { + ExpectedUsed::Writable(expected_bytes) => { + // Verify length + assert_eq!( + actual_len, + expected_bytes.len() as u32, + "used[{}] length mismatch: expected {}, got {}", + i, + expected_bytes.len(), + actual_len + ); + // Verify content + let full = self.read_chain(chain); + let actual_data = &full[..expected_bytes.len().min(full.len())]; + assert_eq!( + actual_data, *expected_bytes, + "used[{}] content mismatch for chain {}: expected {:?}, got {:?}", + i, chain_idx, expected_bytes, actual_data + ); + } + ExpectedUsed::Readable(expected_len) => { + // Verify readable data wasn't modified + self.assert_chain_unchanged(&chains, *chain_idx); + // Verify length + assert_eq!( + actual_len, *expected_len, + "used[{}] length mismatch: expected {}, got {}", + i, expected_len, actual_len + ); + } + ExpectedUsed::ReadableAnyLen => { + // Verify readable data wasn't modified (skip length check) + self.assert_chain_unchanged(&chains, *chain_idx); + } + } + } + } + + /// Assert a single chain's readable segments weren't modified. 
+    fn assert_chain_unchanged(&self, chains: &[BuiltChain], chain_idx: usize) {
+        let chain = &chains[chain_idx];
+        for (seg_idx, seg) in chain.segments.iter().enumerate() {
+            if let Some(expected) = &seg.expected_data {
+                let mut actual = vec![0u8; seg.len as usize];
+                self.mem.read(&mut actual, GuestAddress(seg.addr)).unwrap();
+                assert_eq!(
+                    &actual, expected,
+                    "chain {} segment {} at addr {:x} was modified: expected {:?}, got {:?}",
+                    chain_idx, seg_idx, seg.addr, expected, actual
+                );
+            }
+        }
+    }
+
+    /// Read data from all segments of a chain into a contiguous Vec.
+    fn read_chain(&self, chain: &BuiltChain) -> Vec<u8> {
+        let mut data = Vec::new();
+        for seg in &chain.segments {
+            let mut buf = vec![0u8; seg.len as usize];
+            self.mem.read(&mut buf, GuestAddress(seg.addr)).unwrap();
+            data.extend(buf);
+        }
+        data
+    }
+}
diff --git a/src/libkrun/src/lib.rs b/src/libkrun/src/lib.rs
index 0a7117f37..0f269e63f 100644
--- a/src/libkrun/src/lib.rs
+++ b/src/libkrun/src/lib.rs
@@ -888,6 +888,9 @@ pub unsafe extern "C" fn krun_set_data_disk(ctx_id: u32, c_disk_path: *const c_c
 #[cfg(feature = "net")]
 const NET_FLAG_VFKIT: u32 = 1 << 0;
 
+#[cfg(feature = "net")]
+const NET_FLAG_INCLUDE_VNET_HEADER: u32 = 1 << 1;
+
 /* Taken from uapi/linux/virtio_net.h */
 #[cfg(feature = "net")]
 const NET_FEATURE_CSUM: u32 = 1 << 0;
@@ -965,19 +968,21 @@ pub unsafe extern "C" fn krun_add_net_unixstream(
         Err(_) => return -libc::EINVAL,
     };
 
-    /* The unixstream backend doesn't support any flags */
-    if flags != 0 {
+    if (features & !NET_ALL_FEATURES) != 0 {
         return -libc::EINVAL;
     }
 
-    if (features & !NET_ALL_FEATURES) != 0 {
+    // Unixstream backends don't support NET_FLAG_VFKIT.
+ if (flags & !NET_FLAG_INCLUDE_VNET_HEADER) != 0 { return -libc::EINVAL; } + let include_vnet_header: bool = flags & NET_FLAG_INCLUDE_VNET_HEADER != 0; + match CTX_MAP.lock().unwrap().entry(ctx_id) { Entry::Occupied(mut ctx_cfg) => { let cfg = ctx_cfg.get_mut(); - create_virtio_net(cfg, backend, mac, features); + create_virtio_net(cfg, backend, mac, features, include_vnet_header); } Entry::Vacant(_) => return -libc::ENOENT, } @@ -1020,10 +1025,11 @@ pub unsafe extern "C" fn krun_add_net_unixgram( return -libc::EINVAL; } - if (flags & !NET_FLAG_VFKIT) != 0 { + if (flags & !(NET_FLAG_VFKIT | NET_FLAG_INCLUDE_VNET_HEADER)) != 0 { return -libc::EINVAL; } let send_vfkit_magic: bool = flags & NET_FLAG_VFKIT != 0; + let include_vnet_header: bool = flags & NET_FLAG_INCLUDE_VNET_HEADER != 0; let backend = if let Some(path) = path { VirtioNetBackend::UnixgramPath(path, send_vfkit_magic) @@ -1034,7 +1040,7 @@ pub unsafe extern "C" fn krun_add_net_unixgram( match CTX_MAP.lock().unwrap().entry(ctx_id) { Entry::Occupied(mut ctx_cfg) => { let cfg = ctx_cfg.get_mut(); - create_virtio_net(cfg, backend, mac, features); + create_virtio_net(cfg, backend, mac, features, include_vnet_header); } Entry::Vacant(_) => return -libc::ENOENT, } @@ -1083,7 +1089,7 @@ pub unsafe extern "C" fn krun_add_net_tap( match CTX_MAP.lock().unwrap().entry(ctx_id) { Entry::Occupied(mut ctx_cfg) => { let cfg = ctx_cfg.get_mut(); - create_virtio_net(cfg, VirtioNetBackend::Tap(tap_name), mac, features); + create_virtio_net(cfg, VirtioNetBackend::Tap(tap_name), mac, features, true); } Entry::Vacant(_) => return -libc::ENOENT, } @@ -1963,12 +1969,14 @@ fn create_virtio_net( backend: VirtioNetBackend, mac: [u8; 6], features: u32, + include_vnet_header: bool, ) { let network_interface_config = NetworkInterfaceConfig { iface_id: format!("eth{}", ctx_cfg.net_index), backend, mac, features, + include_vnet_header, }; ctx_cfg.net_index += 1; ctx_cfg @@ -2615,7 +2623,7 @@ pub extern "C" fn krun_start_enter(ctx_id: 
u32) -> i32 { let mac = ctx_cfg .legacy_mac .unwrap_or([0x5a, 0x94, 0xef, 0xe4, 0x0c, 0xee]); - create_virtio_net(&mut ctx_cfg, backend, mac, NET_COMPAT_FEATURES); + create_virtio_net(&mut ctx_cfg, backend, mac, NET_COMPAT_FEATURES, false); } } diff --git a/src/utils/src/fd.rs b/src/utils/src/fd.rs new file mode 100644 index 000000000..14c58e7ba --- /dev/null +++ b/src/utils/src/fd.rs @@ -0,0 +1,51 @@ +// Copyright 2026 Red Hat, Inc. +// SPDX-License-Identifier: Apache-2.0 + +//! File descriptor utilities. + +use std::io; +use std::os::fd::{AsFd, BorrowedFd, RawFd}; + +use nix::fcntl::{fcntl, FcntlArg, OFlag}; + +/// Set non-blocking mode on a file descriptor. +/// +/// If `nonblock` is true, sets O_NONBLOCK. If false, clears it. +pub fn set_nonblocking(fd: impl AsFd, nonblock: bool) -> io::Result<()> { + let fd = fd.as_fd(); + let flags = fcntl(fd, FcntlArg::F_GETFL)?; + let old_flags = OFlag::from_bits_retain(flags); + + let new_flags = if nonblock { + old_flags | OFlag::O_NONBLOCK + } else { + old_flags & !OFlag::O_NONBLOCK + }; + + if new_flags != old_flags { + fcntl(fd, FcntlArg::F_SETFL(new_flags))?; + } + + Ok(()) +} + +/// Set non-blocking mode on a raw file descriptor. +/// +/// The caller must ensure `fd` is a valid file descriptor. +pub fn set_nonblocking_raw(fd: RawFd, nonblock: bool) -> io::Result<()> { + // SAFETY: Caller guarantees fd is valid + let borrowed = unsafe { BorrowedFd::borrow_raw(fd) }; + set_nonblocking(borrowed, nonblock) +} + +/// Extension trait for setting non-blocking mode on file descriptors. +pub trait SetNonblockingExt: AsFd { + /// Set non-blocking mode on this file descriptor. + /// + /// If `nonblock` is true, sets O_NONBLOCK. If false, clears it. 
+    fn set_nonblocking(&self, nonblock: bool) -> io::Result<()> {
+        set_nonblocking(self.as_fd(), nonblock)
+    }
+}
+
+impl<T: AsFd> SetNonblockingExt for T {}
diff --git a/src/utils/src/lib.rs b/src/utils/src/lib.rs
index f3b22a37b..dadda6480 100644
--- a/src/utils/src/lib.rs
+++ b/src/utils/src/lib.rs
@@ -6,6 +6,7 @@ pub use vmm_sys_util::{errno, tempdir, tempfile, terminal};
 pub use vmm_sys_util::{eventfd, ioctl};
 
 pub mod byte_order;
+pub mod fd;
 #[cfg(target_os = "linux")]
 pub mod linux;
 #[cfg(target_os = "linux")]
diff --git a/src/vmm/src/vmm_config/net.rs b/src/vmm/src/vmm_config/net.rs
index 444692d8f..b55b8f5c7 100644
--- a/src/vmm/src/vmm_config/net.rs
+++ b/src/vmm/src/vmm_config/net.rs
@@ -18,6 +18,8 @@ pub struct NetworkInterfaceConfig {
     pub mac: [u8; 6],
     /// virtio-net features for the network interface.
     pub features: u32,
+    /// Whether vnet headers should be sent to and received from the network backend.
+    pub include_vnet_header: bool,
 }
 
 /// Errors associated with `NetworkInterfaceConfig`.
@@ -65,7 +67,13 @@ impl NetBuilder {
     /// Creates a Net device from a NetworkInterfaceConfig.
pub fn create_net(cfg: NetworkInterfaceConfig) -> Result { // Create and return the Net device - Net::new(cfg.iface_id, cfg.backend, cfg.mac, cfg.features) - .map_err(NetworkInterfaceError::CreateNetworkDevice) + Net::new( + cfg.iface_id, + cfg.backend, + cfg.mac, + cfg.features, + cfg.include_vnet_header, + ) + .map_err(NetworkInterfaceError::CreateNetworkDevice) } } diff --git a/tests/Cargo.lock b/tests/Cargo.lock index 157d21c5f..cf6538482 100644 --- a/tests/Cargo.lock +++ b/tests/Cargo.lock @@ -221,6 +221,12 @@ dependencies = [ "either", ] +[[package]] +name = "itoa" +version = "1.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2" + [[package]] name = "krun-sys" version = "1.11.1" @@ -410,6 +416,49 @@ version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" +[[package]] +name = "serde" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.149" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" +dependencies = [ + "itoa", + "memchr", + "serde", + "serde_core", + 
"zmij", +] + [[package]] name = "shlex" version = "1.3.0" @@ -451,6 +500,8 @@ dependencies = [ "krun-sys", "macros", "nix", + "serde", + "serde_json", "tempdir", ] @@ -560,3 +611,9 @@ name = "windows_x86_64_msvc" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + +[[package]] +name = "zmij" +version = "1.0.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa" diff --git a/tests/README.md b/tests/README.md index bc61b9da2..2ad2c5699 100644 --- a/tests/README.md +++ b/tests/README.md @@ -1,9 +1,54 @@ # End-to-end tests -The testing framework here allows you to write code to configure libkrun (using the public API) and run some specific code in the guest. +The testing framework here allows you to write code to configure libkrun (using the public API) and run some specific code in the guest. ## Running the tests: The tests can be ran using `make test` (from the main libkrun directory). -You can also run `./run.sh` inside the `test` directory. When using the `./run.sh` script you probably want specify the `PKG_CONFIG_PATH` enviroment variable, otherwise you will be testing the system wide installation of libkrun. +You can also run `./run.sh` inside the `test` directory. When using the `./run.sh` script you probably want specify the `PKG_CONFIG_PATH` enviroment variable, otherwise you will be testing the system wide installation of libkrun. + +## Running on macOS + +### Prerequisites + +1. Install required build tools: + ```bash + brew install lld xz + rustup target add aarch64-unknown-linux-musl + ``` + +2. Install libkrunfw (required for non-EFI builds). 
Either via homebrew: + ```bash + brew install libkrunfw + ``` + + Or build from source: + ```bash + curl -LO https://github.com/containers/libkrunfw/releases/download/v5.2.0/libkrunfw-prebuilt-aarch64.tgz + tar -xzf libkrunfw-prebuilt-aarch64.tgz + cd libkrunfw + make + sudo make install + ``` + + If installed from source, add `/usr/local/lib` to your library path: + ```bash + export DYLD_LIBRARY_PATH="/usr/local/lib:${DYLD_LIBRARY_PATH}" + ``` + + The test harness automatically handles the library path for homebrew installations. + +### Running tests + +```bash +make test +``` ## Adding tests -To add a test you need to add a new rust module in the `test_cases` directory, implement the required host and guest side methods (see existing tests) and register the test in the `test_cases/src/lib.rs` to be ran. \ No newline at end of file +To add a test you need to add a new rust module in the `test_cases` directory, implement the required host and guest side methods (see existing tests) and register the test in the `test_cases/src/lib.rs` to be ran. + +## Rootfs images + +Some tests (e.g. the iperf3 performance tests) need a full Linux rootfs with extra packages installed. These are built automatically via podman and stored in podman's local image store (tagged as `libkrun-test-`). Podman's layer cache handles rebuild efficiency. + +Container image definitions are registered in the `rootfs_image()` function in `test_cases/src/lib.rs`. Tests refer to images by name only. Tests that need a rootfs will be skipped if podman is not installed. + +To clean up images: `podman rmi $(podman images --filter reference='libkrun-test-*' -q)` \ No newline at end of file diff --git a/tests/create_tap.sh b/tests/create_tap.sh new file mode 100755 index 000000000..ed78d83ca --- /dev/null +++ b/tests/create_tap.sh @@ -0,0 +1,71 @@ +#!/bin/bash +# Create a TAP device for libkrun net-tap testing +# Run with: sudo ./create_tap.sh +# +# This script: +# 1. 
Creates a persistent TAP device owned by the calling user +# 2. Configures IP address (10.0.0.1/24) +# 3. Sets up NAT/masquerading for internet access from guest + +set -e + +TAP_NAME="${1:-tap0}" +TAP_IP="10.0.0.1" +TAP_NETWORK="10.0.0.0/24" + +if [ "$(id -u)" -ne 0 ]; then + echo "This script must be run with sudo" + exit 1 +fi + +if [ -z "$SUDO_USER" ]; then + echo "Please run with sudo (not as root directly)" + exit 1 +fi + +# Check if tap already exists +if ip link show "$TAP_NAME" &>/dev/null; then + read -p "TAP device '$TAP_NAME' already exists. Delete and recreate? [y/N] " -n 1 -r + echo + if [[ $REPLY =~ ^[Yy]$ ]]; then + echo "Deleting existing $TAP_NAME..." + ip link delete "$TAP_NAME" + else + echo "Aborting." + exit 1 + fi +fi + +echo "Creating TAP device '$TAP_NAME' for user '$SUDO_USER'..." +ip tuntap add dev "$TAP_NAME" mode tap user "$SUDO_USER" vnet_hdr + +echo "Configuring IP address $TAP_IP/24..." +ip addr add "$TAP_IP/24" dev "$TAP_NAME" +ip link set "$TAP_NAME" up + +echo "Enabling IP forwarding..." +echo 1 > /proc/sys/net/ipv4/ip_forward + +# Find the default outgoing interface +DEFAULT_IF=$(ip route show default | awk '/default/ {print $5}' | head -1) +if [ -z "$DEFAULT_IF" ]; then + echo "Warning: Could not determine default interface for masquerading" +else + echo "Setting up NAT/masquerading via $DEFAULT_IF..." + # Remove old rule if exists (ignore errors) + iptables -t nat -D POSTROUTING -s "$TAP_NETWORK" -o "$DEFAULT_IF" -j MASQUERADE 2>/dev/null || true + # Add new rule + iptables -t nat -A POSTROUTING -s "$TAP_NETWORK" -o "$DEFAULT_IF" -j MASQUERADE +fi + +echo "" +echo "Done! TAP device '$TAP_NAME' is ready." 
+echo "" +echo "Host: $TAP_IP" +echo "Guest: Configure with 10.0.0.2/24, gateway $TAP_IP" +echo "" +echo "To run the test:" +echo " KRUN_NO_UNSHARE=1 LIBKRUN_TAP_NAME=$TAP_NAME make test NET=1 TEST=net-tap" +echo "" +echo "Note: KRUN_NO_UNSHARE=1 is required because the TAP device is in the host" +echo "network namespace, not the test's isolated namespace." diff --git a/tests/guest-agent/src/main.rs b/tests/guest-agent/src/main.rs index 1f9b7965c..668eb9688 100644 --- a/tests/guest-agent/src/main.rs +++ b/tests/guest-agent/src/main.rs @@ -8,7 +8,7 @@ fn run_guest_agent(test_name: &str) -> anyhow::Result<()> { .into_iter() .find(|t| t.name() == test_name) .context("No such test!")?; - let TestCase { test, name: _ } = test_case; + let TestCase { test, .. } = test_case; test.in_guest(); Ok(()) } diff --git a/tests/run.sh b/tests/run.sh index 128b3e546..df3a3c645 100755 --- a/tests/run.sh +++ b/tests/run.sh @@ -6,15 +6,50 @@ set -e +OS=$(uname -s) + # macOS uses the string "arm64" but Rust uses "aarch64" +ARCH=$(uname -m | sed 's/^arm64$/aarch64/') + +GUEST_TARGET="${ARCH}-unknown-linux-musl" + # Run the unit tests first (this tests the testing framework itself not libkrun) -cargo test -p test_cases --features guest +# Only run on Linux - guest code uses Linux-specific ioctls +if [ "$OS" = "Linux" ]; then + cargo test -p test_cases --features guest +fi + +# On macOS, we need to cross-compile for Linux musl +if [ "$OS" = "Darwin" ]; then + SYSROOT="../linux-sysroot" + if [ ! -d "$SYSROOT" ]; then + echo "ERROR: Linux sysroot not found at $SYSROOT" + echo "Run 'make' in the libkrun root directory first to create it." 
+ exit 1 + fi -GUEST_TARGET_ARCH="$(uname -m)-unknown-linux-musl" + export CARGO_TARGET_AARCH64_UNKNOWN_LINUX_MUSL_LINKER="clang" + export CARGO_TARGET_AARCH64_UNKNOWN_LINUX_MUSL_RUSTFLAGS="-C link-arg=-target -C link-arg=aarch64-linux-gnu -C link-arg=-fuse-ld=lld -C link-arg=--sysroot=$SYSROOT -C link-arg=-static" + echo "Cross-compiling guest-agent for $GUEST_TARGET" +fi -cargo build --target=$GUEST_TARGET_ARCH -p guest-agent +cargo build --target=$GUEST_TARGET -p guest-agent cargo build -p runner -export KRUN_TEST_GUEST_AGENT_PATH="target/$GUEST_TARGET_ARCH/debug/guest-agent" +# On macOS, the runner needs entitlements to use Hypervisor.framework +if [ "$OS" = "Darwin" ]; then + codesign --entitlements /dev/stdin --force -s - target/debug/runner <<'EOF' + + + + + com.apple.security.hypervisor + + + +EOF +fi + +export KRUN_TEST_GUEST_AGENT_PATH="target/$GUEST_TARGET/debug/guest-agent" # Build runner args: pass through all arguments RUNNER_ARGS="$*" @@ -24,7 +59,10 @@ if [ -n "${KRUN_TEST_BASE_DIR}" ]; then RUNNER_ARGS="${RUNNER_ARGS} --base-dir ${KRUN_TEST_BASE_DIR}" fi -if [ -z "${KRUN_NO_UNSHARE}" ] && which unshare 2>&1 >/dev/null; then +# Build rootfs images before entering the network namespace (needs internet + podman) +target/debug/runner build-images + +if [ "$OS" != "Darwin" ] && [ -z "${KRUN_NO_UNSHARE}" ] && which unshare 2>&1 >/dev/null; then unshare --user --map-root-user --net -- /bin/sh -c "ifconfig lo 127.0.0.1 && exec target/debug/runner ${RUNNER_ARGS}" else echo "WARNING: Running tests without a network namespace." 
diff --git a/tests/runner/src/main.rs b/tests/runner/src/main.rs index d3d3a702a..a7a466f82 100644 --- a/tests/runner/src/main.rs +++ b/tests/runner/src/main.rs @@ -8,12 +8,14 @@ use std::panic::catch_unwind; use std::path::{Path, PathBuf}; use std::process::{Command, Stdio}; use tempdir::TempDir; -use test_cases::{test_cases, Test, TestCase, TestSetup}; +use test_cases::{ + rootfs_images, test_cases, Report, ShouldRun, Test, TestCase, TestOutcome, TestSetup, +}; struct TestResult { name: String, - passed: bool, - log_path: PathBuf, + outcome: TestOutcome, + log_path: Option, } fn get_test(name: &str) -> anyhow::Result> { @@ -39,28 +41,39 @@ fn start_vm(test_setup: TestSetup) -> anyhow::Result<()> { } fn run_single_test( - test_case: &str, + test_case: &TestCase, base_dir: &Path, keep_all: bool, max_name_len: usize, ) -> anyhow::Result { + eprint!( + "[{}] {:. outcome, + Err(_) => TestOutcome::Fail, + }; + + match &outcome { + TestOutcome::Pass => { + eprintln!("OK"); + if !keep_all { + let _ = fs::remove_dir_all(&test_dir); + } + } + TestOutcome::Fail => { + eprintln!("FAIL"); + } + TestOutcome::Skip(reason) => { + eprintln!("SKIP ({})", reason); + } + TestOutcome::Report(report) => { + eprintln!("REPORT"); + eprintln!("{:2}", report.text()); } - } else { - eprintln!("FAIL"); } Ok(TestResult { - name: test_case.to_string(), - passed, - log_path, + name: test_case.name.to_string(), + outcome, + log_path: Some(log_path), }) } fn write_github_summary( results: &[TestResult], - num_ok: usize, - num_tests: usize, + num_pass: usize, + num_fail: usize, + num_skip: usize, + num_report: usize, ) -> anyhow::Result<()> { let summary_path = env::var("GITHUB_STEP_SUMMARY") .context("GITHUB_STEP_SUMMARY environment variable not set")?; @@ -106,33 +133,60 @@ fn write_github_summary( .open(&summary_path) .context("Failed to open GITHUB_STEP_SUMMARY")?; - let all_passed = num_ok == num_tests; - let status = if all_passed { "✅" } else { "❌" }; + let num_ran = num_pass + num_fail; + 
let status = if num_fail == 0 { "✅" } else { "❌" }; + let mut extra = Vec::new(); + if num_skip > 0 { + extra.push(format!("{num_skip} skipped")); + } + if num_report > 0 { + extra.push(format!("{num_report} reports")); + } + let extra_msg = if extra.is_empty() { + String::new() + } else { + format!(" ({})", extra.join(", ")) + }; writeln!( file, - "## {status} Integration Tests ({num_ok}/{num_tests} passed)\n" + "## {status} Integration Tests - {num_pass}/{num_ran} passed{extra_msg}\n" )?; for result in results { - let icon = if result.passed { "✅" } else { "❌" }; - let log_content = fs::read_to_string(&result.log_path).unwrap_or_default(); + let (icon, status_text) = match &result.outcome { + TestOutcome::Pass => ("✅", String::new()), + TestOutcome::Fail => ("❌", String::new()), + TestOutcome::Skip(reason) => ("⏭️", format!(" - {}", reason)), + TestOutcome::Report(_) => ("📊", String::new()), + }; writeln!(file, "
")?; - writeln!(file, "{icon} {}\n", result.name)?; - writeln!(file, "```")?; - // Limit log size to avoid huge summaries (2 MiB limit) - const MAX_LOG_SIZE: usize = 2 * 1024 * 1024; - let truncated = if log_content.len() > MAX_LOG_SIZE { - format!( - "... (truncated, showing last 1 MiB) ...\n{}", - &log_content[log_content.len() - MAX_LOG_SIZE..] - ) - } else { - log_content - }; - writeln!(file, "{truncated}")?; - writeln!(file, "```")?; + writeln!( + file, + "{icon} {}{}\n", + result.name, status_text + )?; + + if let TestOutcome::Report(report) = &result.outcome { + writeln!(file, "{}", report.gh_markdown())?; + } else if let Some(log_path) = &result.log_path { + let log_content = fs::read_to_string(log_path).unwrap_or_default(); + writeln!(file, "```")?; + // Limit log size to avoid huge summaries (2 MiB limit) + const MAX_LOG_SIZE: usize = 2 * 1024 * 1024; + let truncated = if log_content.len() > MAX_LOG_SIZE { + format!( + "... (truncated, showing last 1 MiB) ...\n{}", + &log_content[log_content.len() - MAX_LOG_SIZE..] + ) + } else { + log_content + }; + writeln!(file, "{truncated}")?; + writeln!(file, "```")?; + } + writeln!(file, "
\n")?; } @@ -151,46 +205,78 @@ fn run_tests( fs::create_dir_all(&path).context("Failed to create base directory")?; path } - None => TempDir::new("libkrun-tests") + None => TempDir::new_in("/tmp", "libkrun-tests") .context("Failed to create temp base directory")? .into_path(), }; let mut results: Vec = Vec::new(); + let all_tests = test_cases(); - if test_case == "all" { - let all_tests = test_cases(); - let max_name_len = all_tests.iter().map(|t| t.name.len()).max().unwrap_or(0); - - for TestCase { name, test: _ } in all_tests { - results.push(run_single_test(name, &base_dir, keep_all, max_name_len).context(name)?); - } + let tests_to_run: Vec<_> = if test_case == "all" { + all_tests } else { - let max_name_len = test_case.len(); - results.push( - run_single_test(test_case, &base_dir, keep_all, max_name_len) - .context(test_case.to_string())?, - ); + all_tests + .into_iter() + .filter(|t| t.name == test_case) + .collect() + }; + + if tests_to_run.is_empty() { + anyhow::bail!("No such test: {test_case}"); + } + + let max_name_len = tests_to_run.iter().map(|t| t.name.len()).max().unwrap_or(0); + + for tc in &tests_to_run { + results.push(run_single_test(tc, &base_dir, keep_all, max_name_len).context(tc.name)?); } - let num_tests = results.len(); - let num_ok = results.iter().filter(|r| r.passed).count(); + let num_pass = results + .iter() + .filter(|r| matches!(r.outcome, TestOutcome::Pass)) + .count(); + let num_fail = results + .iter() + .filter(|r| matches!(r.outcome, TestOutcome::Fail)) + .count(); + let num_skip = results + .iter() + .filter(|r| matches!(r.outcome, TestOutcome::Skip(_))) + .count(); + let num_report = results + .iter() + .filter(|r| matches!(r.outcome, TestOutcome::Report(_))) + .count(); + let num_ran = num_pass + num_fail; // Write GitHub Actions summary if requested if github_summary { - write_github_summary(&results, num_ok, num_tests)?; + write_github_summary(&results, num_pass, num_fail, num_skip, num_report)?; } - let num_failures = 
num_tests - num_ok; - if num_failures > 0 { + let mut extra = Vec::new(); + if num_skip > 0 { + extra.push(format!("{num_skip} skipped")); + } + if num_report > 0 { + extra.push(format!("{num_report} reports")); + } + let extra_msg = if extra.is_empty() { + String::new() + } else { + format!(" ({})", extra.join(", ")) + }; + + if num_fail > 0 { eprintln!("(See test artifacts at: {})", base_dir.display()); - println!("\nFAIL (PASSED {num_ok}/{num_tests})"); + println!("\nFAIL - {num_pass}/{num_ran} passed{extra_msg}"); anyhow::bail!("") } else { if keep_all { eprintln!("(See test artifacts at: {})", base_dir.display()); } - eprintln!("\nOK ({num_ok}/{num_tests} passed)"); + eprintln!("\nOK - {num_pass}/{num_ran} passed{extra_msg}"); } Ok(()) @@ -218,6 +304,8 @@ enum CliCommand { #[arg(long)] tmp_dir: PathBuf, }, + /// Build all registered rootfs images (requires network; run before unshare) + BuildImages, } impl Default for CliCommand { @@ -237,6 +325,19 @@ struct Cli { command: Option, } +fn build_images() -> anyhow::Result<()> { + use test_cases::rootfs; + + for (name, _) in rootfs_images() { + eprint!("Building rootfs image {name}..."); + match rootfs::build_rootfs(name) { + Ok(()) => eprintln!(" done"), + Err(e) => eprintln!(" skipped ({e})"), + } + } + Ok(()) +} + fn main() -> anyhow::Result<()> { let cli = Cli::parse(); let command = cli.command.unwrap_or_default(); @@ -249,5 +350,6 @@ fn main() -> anyhow::Result<()> { keep_all, github_summary, } => run_tests(&test_case, base_dir, keep_all, github_summary), + CliCommand::BuildImages => build_images(), } } diff --git a/tests/test_cases/Cargo.toml b/tests/test_cases/Cargo.toml index 34d646797..5ae646724 100644 --- a/tests/test_cases/Cargo.toml +++ b/tests/test_cases/Cargo.toml @@ -3,7 +3,7 @@ name = "test_cases" edition = "2021" [features] -host = ["krun-sys"] +host = ["krun-sys", "serde", "serde_json"] guest = [] [lib] @@ -12,6 +12,8 @@ name = "test_cases" [dependencies] krun-sys = { path = "../../krun-sys", 
optional = true } macros = { path = "../macros" } -nix = { version = "0.29.0", features = ["socket"] } +nix = { version = "0.29.0", features = ["socket", "ioctl"] } anyhow = "1.0.95" -tempdir = "0.3.7" \ No newline at end of file +serde = { version = "1", features = ["derive"], optional = true } +serde_json = { version = "1", optional = true } +tempdir = "0.3.7" diff --git a/tests/test_cases/build.rs b/tests/test_cases/build.rs new file mode 100644 index 000000000..08b7d380a --- /dev/null +++ b/tests/test_cases/build.rs @@ -0,0 +1,3 @@ +fn main() { + println!("cargo:rerun-if-env-changed=IPERF_DURATION"); +} diff --git a/tests/test_cases/src/common.rs b/tests/test_cases/src/common.rs index 6a3ee2483..0e84c25d8 100644 --- a/tests/test_cases/src/common.rs +++ b/tests/test_cases/src/common.rs @@ -50,3 +50,34 @@ pub fn setup_fs_and_enter(ctx: u32, test_setup: TestSetup) -> anyhow::Result<()> } unreachable!() } + +/// Like setup_fs_and_enter, but uses an existing rootfs directory (e.g. a Fedora rootfs with +/// extra packages installed). Copies the guest-agent into it and enters the VM. 
+pub fn setup_existing_rootfs_and_enter( + ctx: u32, + test_setup: TestSetup, + rootfs_dir: &Path, +) -> anyhow::Result<()> { + anyhow::ensure!( + rootfs_dir.is_dir(), + "rootfs directory not found: {}", + rootfs_dir.display() + ); + let path_str = CString::new(rootfs_dir.as_os_str().as_bytes()).context("CString::new")?; + copy_guest_agent(rootfs_dir)?; + unsafe { + krun_call!(krun_set_root(ctx, path_str.as_ptr()))?; + krun_call!(krun_set_workdir(ctx, c"/".as_ptr()))?; + let test_case_cstr = CString::new(test_setup.test_case).context("CString::new")?; + let argv = [test_case_cstr.as_ptr(), null()]; + let envp = [null()]; + krun_call!(krun_set_exec( + ctx, + c"/guest-agent".as_ptr(), + argv.as_ptr(), + envp.as_ptr(), + ))?; + krun_call!(krun_start_enter(ctx))?; + } + unreachable!() +} diff --git a/tests/test_cases/src/lib.rs b/tests/test_cases/src/lib.rs index dfe5211a0..ff81cddb0 100644 --- a/tests/test_cases/src/lib.rs +++ b/tests/test_cases/src/lib.rs @@ -10,9 +10,38 @@ use test_tsi_tcp_guest_connect::TestTsiTcpGuestConnect; mod test_tsi_tcp_guest_listen; use test_tsi_tcp_guest_listen::TestTsiTcpGuestListen; +pub(crate) mod test_net; +use test_net::TestNet; + +mod test_net_perf; +use test_net_perf::TestNetPerf; + mod test_multiport_console; use test_multiport_console::TestMultiportConsole; +pub enum ShouldRun { + Yes, + No(&'static str), +} + +impl ShouldRun { + /// Returns Yes unless on macOS, in which case returns No with the given reason. 
+ pub fn yes_unless_macos(reason: &'static str) -> Self { + if cfg!(target_os = "macos") { + ShouldRun::No(reason) + } else { + ShouldRun::Yes + } + } +} + +pub enum TestOutcome { + Pass, + Fail, + Skip(&'static str), + Report(Box), +} + pub fn test_cases() -> Vec { // Register your test here: vec![ @@ -39,13 +68,96 @@ pub fn test_cases() -> Vec { "tsi-tcp-guest-listen", Box::new(TestTsiTcpGuestListen::new()), ), + TestCase::new("net-passt", Box::new(TestNet::new_passt())), + TestCase::new("net-tap", Box::new(TestNet::new_tap())), + TestCase::new("net-gvproxy", Box::new(TestNet::new_gvproxy())), + TestCase::new("net-vmnet-helper", Box::new(TestNet::new_vmnet_helper())), TestCase::new("multiport-console", Box::new(TestMultiportConsole)), + TestCase::new( + "perf-net-passt-upload", + Box::new(TestNetPerf::new_passt_upload()), + ), + TestCase::new( + "perf-net-passt-download", + Box::new(TestNetPerf::new_passt_download()), + ), + TestCase::new( + "perf-net-tap-upload", + Box::new(TestNetPerf::new_tap_upload()), + ), + TestCase::new( + "perf-net-tap-download", + Box::new(TestNetPerf::new_tap_download()), + ), + TestCase::new( + "perf-net-gvproxy-upload", + Box::new(TestNetPerf::new_gvproxy_upload()), + ), + TestCase::new( + "perf-net-gvproxy-download", + Box::new(TestNetPerf::new_gvproxy_download()), + ), + TestCase::new( + "perf-net-vmnet-helper-upload", + Box::new(TestNetPerf::new_vmnet_helper_upload()), + ), + TestCase::new( + "perf-net-vmnet-helper-download", + Box::new(TestNetPerf::new_vmnet_helper_download()), + ), ] } +/// Registry of container images used by tests. +/// Each entry maps a name to a Containerfile that will be built and cached via podman. 
+#[host] +pub fn rootfs_images() -> &'static [(&'static str, &'static str)] { + &[( + "fedora-iperf3", + "\ +FROM fedora:43 +RUN dnf install -y iperf3 && dnf clean all +", + )] +} + //////////////////// // Implementation details: ////////////////// + +pub trait ReportImpl { + fn fmt_text(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result; + fn fmt_gh_markdown(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result; +} + +pub trait Report: ReportImpl { + fn text(&self) -> ReportText<'_, Self> { + ReportText(self) + } + + fn gh_markdown(&self) -> ReportGhMarkdown<'_, Self> { + ReportGhMarkdown(self) + } +} + +impl Report for T {} + +pub struct ReportText<'a, T: ReportImpl + ?Sized>(pub &'a T); + +impl std::fmt::Display for ReportText<'_, T> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt_text(f) + } +} + +pub struct ReportGhMarkdown<'a, T: ReportImpl + ?Sized>(pub &'a T); + +impl std::fmt::Display for ReportGhMarkdown<'_, T> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt_gh_markdown(f) + } +} + use macros::{guest, host}; #[host] use std::path::PathBuf; @@ -60,6 +172,13 @@ mod common; #[cfg(feature = "host")] mod krun; + +#[cfg(feature = "host")] +pub mod rootfs; + +#[cfg(feature = "guest")] +mod net_config; + mod tcp_tester; #[host] @@ -76,9 +195,18 @@ pub trait Test { fn start_vm(self: Box, test_setup: TestSetup) -> anyhow::Result<()>; /// Checks the output of the (host) process which started the VM - fn check(self: Box, child: Child) { + fn check(self: Box, child: Child) -> TestOutcome { let output = child.wait_with_output().unwrap(); - assert_eq!(String::from_utf8(output.stdout).unwrap(), "OK\n"); + if String::from_utf8(output.stdout).unwrap() == "OK\n" { + TestOutcome::Pass + } else { + TestOutcome::Fail + } + } + + /// Check if this test should run on this platform. 
+ fn should_run(&self) -> ShouldRun { + ShouldRun::Yes } } @@ -100,6 +228,12 @@ impl TestCase { Self { name, test } } + /// Check if this test should run on this platform. + #[host] + pub fn should_run(&self) -> ShouldRun { + self.test.should_run() + } + #[allow(dead_code)] pub fn name(&self) -> &'static str { self.name diff --git a/tests/test_cases/src/net_config.rs b/tests/test_cases/src/net_config.rs new file mode 100644 index 000000000..853f0853e --- /dev/null +++ b/tests/test_cases/src/net_config.rs @@ -0,0 +1,114 @@ +//! Shared network configuration utilities for guest-side network setup +//! +//! This module provides low-level network interface configuration using ioctls, +//! used by virtio-net tests to configure eth0 in the guest. + +use nix::sys::socket::{socket, AddressFamily, SockFlag, SockType}; +use std::os::fd::AsRawFd; + +// Network interface configuration constants +pub const IFNAMSIZ: usize = 16; +pub const IFF_UP: nix::libc::c_short = 0x1; +pub const IFF_RUNNING: nix::libc::c_short = 0x40; + +// ioctl numbers +const SIOCGIFFLAGS: u64 = 0x8913; +const SIOCSIFFLAGS: u64 = 0x8914; +const SIOCSIFADDR: u64 = 0x8916; +const SIOCSIFNETMASK: u64 = 0x891c; + +#[repr(C)] +#[derive(Default)] +pub struct Ifreq { + pub ifr_name: [u8; IFNAMSIZ], + pub ifr_ifru: IfreqIfru, +} + +#[repr(C)] +#[derive(Copy, Clone)] +pub union IfreqIfru { + pub ifru_flags: nix::libc::c_short, + pub ifru_addr: nix::libc::sockaddr, + pub _pad: [u8; 24], +} + +impl Default for IfreqIfru { + fn default() -> Self { + Self { _pad: [0u8; 24] } + } +} + +nix::ioctl_readwrite_bad!(ioctl_siocgifflags, SIOCGIFFLAGS, Ifreq); +nix::ioctl_readwrite_bad!(ioctl_siocsifflags, SIOCSIFFLAGS, Ifreq); +nix::ioctl_write_ptr_bad!(ioctl_siocsifaddr, SIOCSIFADDR, Ifreq); +nix::ioctl_write_ptr_bad!(ioctl_siocsifnetmask, SIOCSIFNETMASK, Ifreq); + +pub fn set_interface_name(ifr: &mut Ifreq, name: &str) { + let bytes = name.as_bytes(); + let len = bytes.len().min(IFNAMSIZ - 1); + 
ifr.ifr_name[..len].copy_from_slice(&bytes[..len]); + ifr.ifr_name[len] = 0; +} + +pub fn make_sockaddr_in(ip: [u8; 4]) -> nix::libc::sockaddr { + let mut addr: nix::libc::sockaddr_in = unsafe { std::mem::zeroed() }; + addr.sin_family = nix::libc::AF_INET as _; + addr.sin_addr.s_addr = u32::from_ne_bytes(ip); + unsafe { std::mem::transmute(addr) } +} + +/// Configure a network interface with IP address and netmask, and bring it UP +pub fn configure_interface(name: &str, ip: [u8; 4], netmask: [u8; 4]) -> nix::Result<()> { + let sock = socket( + AddressFamily::Inet, + SockType::Datagram, + SockFlag::empty(), + None, + )?; + let fd = sock.as_raw_fd(); + + // Set IP address + let mut ifr = Ifreq::default(); + set_interface_name(&mut ifr, name); + ifr.ifr_ifru.ifru_addr = make_sockaddr_in(ip); + unsafe { ioctl_siocsifaddr(fd, &ifr)? }; + + // Set netmask + let mut ifr = Ifreq::default(); + set_interface_name(&mut ifr, name); + ifr.ifr_ifru.ifru_addr = make_sockaddr_in(netmask); + unsafe { ioctl_siocsifnetmask(fd, &ifr)? }; + + // Bring interface UP + let mut ifr = Ifreq::default(); + set_interface_name(&mut ifr, name); + unsafe { ioctl_siocgifflags(fd, &mut ifr)? }; + unsafe { ifr.ifr_ifru.ifru_flags |= IFF_UP | IFF_RUNNING }; + unsafe { ioctl_siocsifflags(fd, &mut ifr)? 
}; + + Ok(()) +} + +/// Add a default route via the given gateway +pub fn add_default_route(gateway: [u8; 4]) -> nix::Result<()> { + use nix::libc; + + let sock = socket( + AddressFamily::Inet, + SockType::Datagram, + SockFlag::empty(), + None, + )?; + + let mut rt: libc::rtentry = unsafe { std::mem::zeroed() }; + rt.rt_dst = make_sockaddr_in([0, 0, 0, 0]); + rt.rt_gateway = make_sockaddr_in(gateway); + rt.rt_genmask = make_sockaddr_in([0, 0, 0, 0]); + rt.rt_flags = libc::RTF_UP | libc::RTF_GATEWAY; + + let ret = unsafe { libc::ioctl(sock.as_raw_fd(), libc::SIOCADDRT as _, &rt) }; + if ret < 0 { + return Err(nix::errno::Errno::last()); + } + Ok(()) +} diff --git a/tests/test_cases/src/rootfs.rs b/tests/test_cases/src/rootfs.rs new file mode 100644 index 000000000..701e39a5f --- /dev/null +++ b/tests/test_cases/src/rootfs.rs @@ -0,0 +1,142 @@ +//! Podman-based rootfs provisioning for tests that need a full Linux rootfs. +//! +//! `build_rootfs` builds the podman image and exports a rootfs tarball to +//! `/tmp/libkrun-test-rootfs-cache/`. This runs outside any namespace +//! (via `build-images`). `extract_rootfs` just extracts the cached tarball, +//! so it works inside the `unshare --user --net` namespace without podman. + +use anyhow::{bail, Context}; +use std::fs; +use std::io::Write; +use std::path::{Path, PathBuf}; +use std::process::{Command, Stdio}; + +const CACHE_DIR: &str = "/tmp/libkrun-test-rootfs-cache"; + +fn image_tag(name: &str) -> String { + format!("libkrun-test-{name}") +} + +fn tarball_path(name: &str) -> PathBuf { + Path::new(CACHE_DIR).join(format!("{name}.tar")) +} + +fn podman_available() -> bool { + Command::new("podman") + .arg("--version") + .stdout(Stdio::null()) + .stderr(Stdio::null()) + .status() + .map(|s| s.success()) + .unwrap_or(false) +} + +/// Checks whether the rootfs tarball for the given name has been built. 
+pub fn rootfs_is_built(name: &str) -> bool { + tarball_path(name).exists() +} + +/// Builds the podman image and exports a rootfs tarball to the cache. +/// Must be called outside any namespace (needs podman + network). +pub fn build_rootfs(name: &str) -> anyhow::Result<()> { + if !podman_available() { + bail!("podman not installed"); + } + + let tag = image_tag(name); + let containerfile = crate::rootfs_images() + .iter() + .find(|(n, _)| *n == name) + .unwrap_or_else(|| panic!("unknown rootfs image: {name}")) + .1; + + // Build image (podman layer cache makes this fast when unchanged) + let mut build = Command::new("podman") + .args(["build", "-t", &tag, "-f", "-", "."]) + .stdin(Stdio::piped()) + .stdout(Stdio::null()) + .stderr(Stdio::piped()) + .spawn() + .context("spawning podman build")?; + + build + .stdin + .take() + .unwrap() + .write_all(containerfile.as_bytes()) + .context("writing containerfile to podman stdin")?; + + let output = build + .wait_with_output() + .context("waiting for podman build")?; + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + bail!("podman build failed: {stderr}"); + } + + // Export rootfs tarball to cache + fs::create_dir_all(CACHE_DIR).context("creating rootfs cache directory")?; + + let create_out = Command::new("podman") + .args(["create", &tag]) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .output() + .context("podman create")?; + + if !create_out.status.success() { + let stderr = String::from_utf8_lossy(&create_out.stderr); + bail!("podman create failed: {stderr}"); + } + + let ctr_id = String::from_utf8(create_out.stdout) + .context("container id not utf-8")? 
+ .trim() + .to_string(); + + let tarball = tarball_path(name); + let tar_file = fs::File::create(&tarball).context("creating rootfs tarball file")?; + + let mut export = Command::new("podman") + .args(["export", &ctr_id]) + .stdout(tar_file) + .stderr(Stdio::piped()) + .spawn() + .context("podman export")?; + + let export_status = export.wait().context("waiting for podman export")?; + let _ = Command::new("podman").args(["rm", &ctr_id]).status(); + + if !export_status.success() { + let _ = fs::remove_file(&tarball); + bail!("podman export failed"); + } + + Ok(()) +} + +/// Extracts the cached rootfs tarball into `dest`. +/// The tarball must already exist (call `build_rootfs` first via `build-images`). +pub fn extract_rootfs(name: &str, dest: &Path) -> anyhow::Result<()> { + let tarball = tarball_path(name); + if !tarball.exists() { + bail!("rootfs tarball not found for {name} (run build-images first)"); + } + + fs::create_dir_all(dest).context("creating rootfs destination directory")?; + + let status = Command::new("tar") + .arg("-xf") + .arg(&tarball) + .arg("--no-same-owner") + .arg("-C") + .arg(dest) + .status() + .context("extracting rootfs")?; + + if !status.success() { + bail!("tar extraction failed"); + } + + Ok(()) +} diff --git a/tests/test_cases/src/tcp_tester.rs b/tests/test_cases/src/tcp_tester.rs index c90f12c3b..f518075d1 100644 --- a/tests/test_cases/src/tcp_tester.rs +++ b/tests/test_cases/src/tcp_tester.rs @@ -26,31 +26,15 @@ fn set_timeouts(stream: &mut TcpStream) { .unwrap(); } -fn connect(port: u16) -> TcpStream { - let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), port); - let mut tries = 0; - loop { - match TcpStream::connect(addr) { - Ok(stream) => return stream, - Err(err) => { - if tries == 5 { - panic!("Couldn't connect to server after 5 attempts: {err}"); - } - tries += 1; - thread::sleep(Duration::from_secs(1)); - } - } - } -} - #[derive(Debug, Copy, Clone)] pub struct TcpTester { + server_ip: Ipv4Addr, port: u16, } 
impl TcpTester { - pub const fn new(port: u16) -> Self { - Self { port } + pub const fn new(port: u16, server_ip: Ipv4Addr) -> Self { + Self { server_ip, port } } pub fn create_server_socket(&self) -> TcpListener { @@ -66,10 +50,24 @@ impl TcpTester { stream.write_all(b"bye!").unwrap(); // We leak the file descriptor for now, since there is no easy way to close it on libkrun exit mem::forget(listener); + mem::forget(stream); } pub fn run_client(&self) { - let mut stream = connect(self.port); + let addr = SocketAddr::new(IpAddr::V4(self.server_ip), self.port); + let mut tries = 0; + let mut stream = loop { + match TcpStream::connect(addr) { + Ok(stream) => break stream, + Err(err) => { + if tries == 5 { + panic!("Couldn't connect to {addr} after 5 attempts: {err}"); + } + tries += 1; + thread::sleep(Duration::from_secs(1)); + } + } + }; set_timeouts(&mut stream); expect_msg(&mut stream, b"ping!"); expect_wouldblock(&mut stream); diff --git a/tests/test_cases/src/test_net/gvproxy.rs b/tests/test_cases/src/test_net/gvproxy.rs new file mode 100644 index 000000000..10fccc2df --- /dev/null +++ b/tests/test_cases/src/test_net/gvproxy.rs @@ -0,0 +1,127 @@ +//! 
Gvproxy backend for virtio-net test (macOS only) + +use crate::{krun_call, ShouldRun, TestSetup}; +use krun_sys::{COMPAT_NET_FEATURES, NET_FLAG_VFKIT}; +use nix::libc; +use std::ffi::CString; + +type KrunAddNetUnixgramFn = unsafe extern "C" fn( + ctx_id: u32, + c_path: *const std::ffi::c_char, + fd: i32, + c_mac: *mut u8, + features: u32, + flags: u32, +) -> i32; + +fn get_krun_add_net_unixgram() -> KrunAddNetUnixgramFn { + let symbol = CString::new("krun_add_net_unixgram").unwrap(); + let ptr = unsafe { libc::dlsym(libc::RTLD_DEFAULT, symbol.as_ptr()) }; + assert!(!ptr.is_null(), "krun_add_net_unixgram not found"); + unsafe { std::mem::transmute(ptr) } +} + +fn gvproxy_path() -> Option { + let paths = [ + "/opt/homebrew/Cellar/podman/5.5.1/libexec/podman/gvproxy", + "/opt/homebrew/opt/podman/libexec/podman/gvproxy", + "/usr/libexec/podman/gvproxy", + "/usr/local/libexec/podman/gvproxy", + ]; + for path in paths { + if std::path::Path::new(path).exists() { + return Some(path.to_string()); + } + } + std::process::Command::new("which") + .arg("gvproxy") + .output() + .ok() + .and_then(|o| { + if o.status.success() { + String::from_utf8(o.stdout) + .ok() + .map(|s| s.trim().to_string()) + } else { + None + } + }) +} + +fn start_gvproxy( + socket_path: &str, + log_path: &std::path::Path, +) -> std::io::Result { + use std::process::{Command, Stdio}; + + let _ = Command::new("pkill").arg("-9").arg("gvproxy").status(); + + let gvproxy = gvproxy_path() + .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::NotFound, "gvproxy not found"))?; + + let log_file = std::fs::File::create(log_path)?; + + Command::new(&gvproxy) + .arg("--listen-vfkit") + .arg(format!("unixgram:{}", socket_path)) + .arg("-debug") + .stdin(Stdio::null()) + .stdout(Stdio::null()) + .stderr(log_file) + .spawn() +} + +fn wait_for_socket(path: &std::path::Path, timeout_ms: u64) -> bool { + let start = std::time::Instant::now(); + while start.elapsed().as_millis() < timeout_ms as u128 { + if 
path.exists() { + return true; + } + std::thread::sleep(std::time::Duration::from_millis(50)); + } + false +} + +pub(crate) fn should_run() -> ShouldRun { + #[cfg(not(target_os = "macos"))] + return ShouldRun::No("gvproxy unixgram only supported on macOS"); + + #[cfg(target_os = "macos")] + { + if gvproxy_path().is_none() { + return ShouldRun::No("gvproxy not installed"); + } + ShouldRun::Yes + } +} + +pub(crate) fn setup_backend(ctx: u32, test_setup: &TestSetup) -> anyhow::Result<()> { + let tmp_dir = test_setup + .tmp_dir + .canonicalize() + .unwrap_or_else(|_| test_setup.tmp_dir.clone()); + let socket_path = tmp_dir.join("gvproxy.sock"); + let gvproxy_log = tmp_dir.join("gvproxy.log"); + + let _gvproxy_child = start_gvproxy(socket_path.to_str().unwrap(), &gvproxy_log)?; + + anyhow::ensure!( + wait_for_socket(&socket_path, 5000), + "gvproxy failed to create socket" + ); + + let mut mac: [u8; 6] = [0x5a, 0x94, 0xef, 0xe4, 0x0c, 0xee]; + let c_socket_path = CString::new(socket_path.to_str().unwrap()).unwrap(); + + unsafe { + krun_call!(get_krun_add_net_unixgram()( + ctx, + c_socket_path.as_ptr(), + -1, + mac.as_mut_ptr(), + COMPAT_NET_FEATURES, + NET_FLAG_VFKIT, + ))?; + } + Ok(()) +} diff --git a/tests/test_cases/src/test_net/mod.rs b/tests/test_cases/src/test_net/mod.rs new file mode 100644 index 000000000..551d2cff2 --- /dev/null +++ b/tests/test_cases/src/test_net/mod.rs @@ -0,0 +1,184 @@ +//! Unified virtio-net integration tests +//! +//! All tests follow the same pattern: +//! 1. Host: Start backend + TCP server +//! 2. Guest: Configure eth0 with static IP +//! 3. 
Guest: Connect to host TCP server + +use crate::tcp_tester::TcpTester; +use macros::{guest, host}; + +#[host] +use crate::{ShouldRun, TestSetup}; + +#[cfg(feature = "host")] +pub(crate) mod gvproxy; +#[cfg(feature = "host")] +pub(crate) mod passt; +#[cfg(feature = "host")] +pub(crate) mod tap; +#[cfg(feature = "host")] +pub(crate) mod vmnet_helper; + +/// Virtio-net test with configurable backend +pub struct TestNet { + #[cfg(feature = "guest")] + guest_ip: [u8; 4], + #[cfg(feature = "guest")] + netmask: [u8; 4], + #[cfg(feature = "guest")] + gateway: Option<[u8; 4]>, + tcp_tester: TcpTester, + #[cfg(feature = "host")] + should_run: fn() -> ShouldRun, + #[cfg(feature = "host")] + setup_backend: fn(u32, &TestSetup) -> anyhow::Result<()>, + #[cfg(feature = "host")] + cleanup: Option, +} + +impl TestNet { + pub fn new_passt() -> Self { + Self { + #[cfg(feature = "guest")] + guest_ip: [169, 254, 2, 1], + #[cfg(feature = "guest")] + netmask: [255, 255, 0, 0], + #[cfg(feature = "guest")] + gateway: None, + tcp_tester: TcpTester::new(9000, [169, 254, 2, 2].into()), + #[cfg(feature = "host")] + should_run: passt::should_run, + #[cfg(feature = "host")] + setup_backend: passt::setup_backend, + #[cfg(feature = "host")] + cleanup: None, + } + } + + pub fn new_tap() -> Self { + Self { + #[cfg(feature = "guest")] + guest_ip: [10, 0, 0, 2], + #[cfg(feature = "guest")] + netmask: [255, 255, 255, 0], + #[cfg(feature = "guest")] + gateway: None, + tcp_tester: TcpTester::new(9001, [10, 0, 0, 1].into()), + #[cfg(feature = "host")] + should_run: tap::should_run, + #[cfg(feature = "host")] + setup_backend: tap::setup_backend, + #[cfg(feature = "host")] + cleanup: Some(tap::cleanup), + } + } + + pub fn new_gvproxy() -> Self { + Self { + #[cfg(feature = "guest")] + guest_ip: [192, 168, 127, 2], + #[cfg(feature = "guest")] + netmask: [255, 255, 255, 0], + #[cfg(feature = "guest")] + gateway: None, + tcp_tester: TcpTester::new(9002, [192, 168, 127, 254].into()), + #[cfg(feature = "host")] + 
should_run: gvproxy::should_run, + #[cfg(feature = "host")] + setup_backend: gvproxy::setup_backend, + #[cfg(feature = "host")] + cleanup: None, + } + } + + pub fn new_vmnet_helper() -> Self { + Self { + #[cfg(feature = "guest")] + guest_ip: [192, 168, 105, 2], + #[cfg(feature = "guest")] + netmask: [255, 255, 255, 0], + #[cfg(feature = "guest")] + gateway: None, + tcp_tester: TcpTester::new(9003, [192, 168, 105, 1].into()), + #[cfg(feature = "host")] + should_run: vmnet_helper::should_run, + #[cfg(feature = "host")] + setup_backend: vmnet_helper::setup_backend, + #[cfg(feature = "host")] + cleanup: None, + } + } +} + +#[host] +mod host { + use super::*; + use crate::common::setup_fs_and_enter; + use crate::{krun_call, krun_call_u32, Test, TestSetup}; + use krun_sys::*; + use std::thread; + + impl Test for TestNet { + fn should_run(&self) -> ShouldRun { + if unsafe { krun_call_u32!(krun_has_feature(KRUN_FEATURE_NET.into())) }.ok() != Some(1) + { + return ShouldRun::No("libkrun compiled without NET"); + } + (self.should_run)() + } + + fn check(self: Box, child: std::process::Child) -> crate::TestOutcome { + let output = child.wait_with_output().unwrap(); + if let Some(cleanup) = self.cleanup { + cleanup(); + } + if String::from_utf8(output.stdout).unwrap() == "OK\n" { + crate::TestOutcome::Pass + } else { + crate::TestOutcome::Fail + } + } + + fn start_vm(self: Box, test_setup: TestSetup) -> anyhow::Result<()> { + // Start TCP server + let tcp_tester = self.tcp_tester; + let listener = tcp_tester.create_server_socket(); + thread::spawn(move || tcp_tester.run_server(listener)); + + unsafe { + krun_call!(krun_set_log_level(KRUN_LOG_LEVEL_TRACE))?; + let ctx = krun_call_u32!(krun_create_ctx())?; + krun_call!(krun_set_vm_config(ctx, 1, 512))?; + + // Backend-specific setup + (self.setup_backend)(ctx, &test_setup)?; + + setup_fs_and_enter(ctx, test_setup)?; + } + Ok(()) + } + } +} + +#[guest] +mod guest { + use super::*; + use crate::net_config::configure_interface; + 
use crate::Test; + + impl Test for TestNet { + fn in_guest(self: Box) { + configure_interface("eth0", self.guest_ip, self.netmask) + .expect("Failed to configure eth0"); + + if let Some(gw) = self.gateway { + crate::net_config::add_default_route(gw).expect("Failed to add default route"); + } + + self.tcp_tester.run_client(); + + println!("OK"); + } + } +} diff --git a/tests/test_cases/src/test_net/passt.rs b/tests/test_cases/src/test_net/passt.rs new file mode 100644 index 000000000..97aa62f8a --- /dev/null +++ b/tests/test_cases/src/test_net/passt.rs @@ -0,0 +1,94 @@ +//! Passt backend for virtio-net test + +use crate::{krun_call, ShouldRun, TestSetup}; +use krun_sys::COMPAT_NET_FEATURES; +use nix::libc; +use std::ffi::CString; +use std::os::unix::io::RawFd; + +type KrunAddNetUnixstreamFn = unsafe extern "C" fn( + ctx_id: u32, + c_path: *const std::ffi::c_char, + fd: std::ffi::c_int, + c_mac: *mut u8, + features: u32, + flags: u32, +) -> i32; + +fn get_krun_add_net_unixstream() -> KrunAddNetUnixstreamFn { + let symbol = CString::new("krun_add_net_unixstream").unwrap(); + let ptr = unsafe { libc::dlsym(libc::RTLD_DEFAULT, symbol.as_ptr()) }; + assert!(!ptr.is_null(), "krun_add_net_unixstream not found"); + unsafe { std::mem::transmute(ptr) } +} + +fn passt_available() -> bool { + std::process::Command::new("which") + .arg("passt") + .output() + .map(|o| o.status.success()) + .unwrap_or(false) +} + +fn start_passt() -> std::io::Result { + let mut fds = [0 as libc::c_int; 2]; + if unsafe { libc::socketpair(libc::AF_UNIX, libc::SOCK_STREAM, 0, fds.as_mut_ptr()) } < 0 { + return Err(std::io::Error::last_os_error()); + } + let (parent_fd, child_fd) = (fds[0], fds[1]); + let child_fd_str = child_fd.to_string(); + + let pid = unsafe { libc::fork() }; + if pid < 0 { + return Err(std::io::Error::last_os_error()); + } + + if pid == 0 { + unsafe { libc::close(parent_fd) }; + let passt = CString::new("passt").unwrap(); + let arg_f = CString::new("-f").unwrap(); + let arg_fd = 
CString::new("--fd").unwrap(); + let arg_fd_val = CString::new(child_fd_str).unwrap(); + unsafe { + libc::execlp( + passt.as_ptr(), + passt.as_ptr(), + arg_f.as_ptr(), + arg_fd.as_ptr(), + arg_fd_val.as_ptr(), + std::ptr::null::(), + ); + } + std::process::exit(1); + } + + unsafe { libc::close(child_fd) }; + Ok(parent_fd) +} + +pub(crate) fn should_run() -> ShouldRun { + if cfg!(target_os = "macos") { + return ShouldRun::No("passt not supported on macOS"); + } + if !passt_available() { + return ShouldRun::No("passt not installed"); + } + ShouldRun::Yes +} + +pub(crate) fn setup_backend(ctx: u32, _test_setup: &TestSetup) -> anyhow::Result<()> { + let passt_fd = start_passt()?; + let mut mac: [u8; 6] = [0x5a, 0x94, 0xef, 0xe4, 0x0c, 0xee]; + + unsafe { + krun_call!(get_krun_add_net_unixstream()( + ctx, + std::ptr::null(), + passt_fd, + mac.as_mut_ptr(), + COMPAT_NET_FEATURES, + 0, + ))?; + } + Ok(()) +} diff --git a/tests/test_cases/src/test_net/tap.rs b/tests/test_cases/src/test_net/tap.rs new file mode 100644 index 000000000..a967cd9be --- /dev/null +++ b/tests/test_cases/src/test_net/tap.rs @@ -0,0 +1,172 @@ +//! 
TAP backend for virtio-net test + +use crate::{krun_call, ShouldRun, TestSetup}; +use krun_sys::COMPAT_NET_FEATURES; +use nix::libc; +use nix::sys::socket::{socket, AddressFamily, SockFlag, SockType}; +use std::ffi::CString; +use std::fs::OpenOptions; +use std::os::fd::AsRawFd; + +const DEFAULT_TAP_NAME: &str = "tap0"; +const HOST_IP: [u8; 4] = [10, 0, 0, 1]; +const NETMASK: [u8; 4] = [255, 255, 255, 0]; + +type KrunAddNetTapFn = unsafe extern "C" fn( + ctx_id: u32, + c_tap_name: *const std::ffi::c_char, + c_mac: *mut u8, + features: u32, + flags: u32, +) -> i32; + +fn get_krun_add_net_tap() -> KrunAddNetTapFn { + let symbol = CString::new("krun_add_net_tap").unwrap(); + let ptr = unsafe { libc::dlsym(libc::RTLD_DEFAULT, symbol.as_ptr()) }; + assert!(!ptr.is_null(), "krun_add_net_tap not found"); + unsafe { std::mem::transmute(ptr) } +} + +fn interface_exists(name: &str) -> bool { + std::path::Path::new(&format!("/sys/class/net/{}", name)).exists() +} + +// TAP device setup +const TUNSETIFF: libc::c_ulong = 0x400454ca; +const TUNSETPERSIST: libc::c_ulong = 0x400454cb; +const IFF_TAP: libc::c_short = 0x0002; +const IFF_NO_PI: libc::c_short = 0x1000; +const IFF_VNET_HDR: libc::c_short = 0x4000; +const IFNAMSIZ: usize = 16; +const IFF_UP: libc::c_short = 0x1; +const IFF_RUNNING: libc::c_short = 0x40; + +#[repr(C)] +struct Ifreq { + ifr_name: [u8; IFNAMSIZ], + ifr_ifru: IfreqIfru, +} + +#[repr(C)] +#[derive(Copy, Clone)] +union IfreqIfru { + ifru_flags: libc::c_short, + ifru_addr: libc::sockaddr, + _pad: [u8; 24], +} + +nix::ioctl_write_ptr_bad!(ioctl_tunsetiff, TUNSETIFF, Ifreq); +nix::ioctl_write_int_bad!(ioctl_tunsetpersist, TUNSETPERSIST); +nix::ioctl_readwrite_bad!(ioctl_siocsifaddr, 0x8916, Ifreq); +nix::ioctl_readwrite_bad!(ioctl_siocsifnetmask, 0x891c, Ifreq); +nix::ioctl_readwrite_bad!(ioctl_siocgifflags, 0x8913, Ifreq); +nix::ioctl_readwrite_bad!(ioctl_siocsifflags, 0x8914, Ifreq); + +fn set_interface_name(ifr: &mut Ifreq, name: &str) { + let bytes = 
name.as_bytes(); + let len = bytes.len().min(IFNAMSIZ - 1); + ifr.ifr_name = [0u8; IFNAMSIZ]; + ifr.ifr_name[..len].copy_from_slice(&bytes[..len]); +} + +fn make_sockaddr_in(ip: [u8; 4]) -> libc::sockaddr { + let mut addr: libc::sockaddr_in = unsafe { std::mem::zeroed() }; + addr.sin_family = libc::AF_INET as libc::sa_family_t; + addr.sin_addr.s_addr = u32::from_ne_bytes(ip); + unsafe { std::mem::transmute(addr) } +} + +fn create_tap(name: &str) -> std::io::Result<()> { + let tun = OpenOptions::new() + .read(true) + .write(true) + .open("/dev/net/tun")?; + let mut ifr: Ifreq = unsafe { std::mem::zeroed() }; + set_interface_name(&mut ifr, name); + ifr.ifr_ifru.ifru_flags = IFF_TAP | IFF_NO_PI | IFF_VNET_HDR; + unsafe { ioctl_tunsetiff(tun.as_raw_fd(), &ifr) }.map_err(std::io::Error::other)?; + unsafe { ioctl_tunsetpersist(tun.as_raw_fd(), 1) }.map_err(std::io::Error::other)?; + Ok(()) +} + +fn configure_host_interface(name: &str, ip: [u8; 4], netmask: [u8; 4]) -> nix::Result<()> { + let sock = socket( + AddressFamily::Inet, + SockType::Datagram, + SockFlag::empty(), + None, + )?; + let fd = sock.as_raw_fd(); + + let mut ifr: Ifreq = unsafe { std::mem::zeroed() }; + set_interface_name(&mut ifr, name); + ifr.ifr_ifru.ifru_addr = make_sockaddr_in(ip); + unsafe { ioctl_siocsifaddr(fd, &mut ifr)? }; + + let mut ifr: Ifreq = unsafe { std::mem::zeroed() }; + set_interface_name(&mut ifr, name); + ifr.ifr_ifru.ifru_addr = make_sockaddr_in(netmask); + unsafe { ioctl_siocsifnetmask(fd, &mut ifr)? }; + + let mut ifr: Ifreq = unsafe { std::mem::zeroed() }; + set_interface_name(&mut ifr, name); + unsafe { ioctl_siocgifflags(fd, &mut ifr)? }; + unsafe { ifr.ifr_ifru.ifru_flags |= IFF_UP | IFF_RUNNING }; + unsafe { ioctl_siocsifflags(fd, &mut ifr)? 
}; + + Ok(()) +} + +pub(crate) fn should_run() -> ShouldRun { + if cfg!(target_os = "macos") { + return ShouldRun::No("TAP not supported on macOS"); + } + if let Ok(tap_name) = std::env::var("LIBKRUN_TAP_NAME") { + if !interface_exists(&tap_name) { + return ShouldRun::No("TAP interface not found"); + } + } else if !std::path::Path::new("/dev/net/tun").exists() { + return ShouldRun::No("/dev/net/tun not available"); + } + ShouldRun::Yes +} + +pub(crate) fn cleanup() { + if let Ok(tun) = OpenOptions::new() + .read(true) + .write(true) + .open("/dev/net/tun") + { + let mut ifr: Ifreq = unsafe { std::mem::zeroed() }; + set_interface_name(&mut ifr, DEFAULT_TAP_NAME); + ifr.ifr_ifru.ifru_flags = IFF_TAP | IFF_NO_PI; + if unsafe { ioctl_tunsetiff(tun.as_raw_fd(), &ifr) }.is_ok() { + let _ = unsafe { ioctl_tunsetpersist(tun.as_raw_fd(), 0) }; + } + } +} + +pub(crate) fn setup_backend(ctx: u32, _test_setup: &TestSetup) -> anyhow::Result<()> { + let tap_name = if let Ok(name) = std::env::var("LIBKRUN_TAP_NAME") { + name + } else { + create_tap(DEFAULT_TAP_NAME)?; + configure_host_interface(DEFAULT_TAP_NAME, HOST_IP, NETMASK) + .map_err(|e| anyhow::anyhow!("Failed to configure TAP: {}", e))?; + DEFAULT_TAP_NAME.to_string() + }; + + let mut mac: [u8; 6] = [0x5a, 0x94, 0xef, 0xe4, 0x0c, 0xee]; + let tap_name_c = CString::new(tap_name).unwrap(); + + unsafe { + krun_call!(get_krun_add_net_tap()( + ctx, + tap_name_c.as_ptr(), + mac.as_mut_ptr(), + COMPAT_NET_FEATURES, + 0, + ))?; + } + Ok(()) +} diff --git a/tests/test_cases/src/test_net/vmnet_helper.rs b/tests/test_cases/src/test_net/vmnet_helper.rs new file mode 100644 index 000000000..a2bb32d0c --- /dev/null +++ b/tests/test_cases/src/test_net/vmnet_helper.rs @@ -0,0 +1,245 @@ +//! 
vmnet-helper backend for virtio-net test (macOS only) + +use crate::{krun_call, ShouldRun, TestSetup}; +use nix::libc; +use std::ffi::CString; +use std::io::{BufRead, BufReader, Read}; +use std::os::unix::io::FromRawFd; +use std::process::Command; + +type KrunAddNetUnixgramFn = unsafe extern "C" fn( + ctx_id: u32, + c_path: *const std::ffi::c_char, + fd: i32, + c_mac: *mut u8, + features: u32, + flags: u32, +) -> i32; + +fn get_krun_add_net_unixgram() -> KrunAddNetUnixgramFn { + let symbol = CString::new("krun_add_net_unixgram").unwrap(); + let ptr = unsafe { libc::dlsym(libc::RTLD_DEFAULT, symbol.as_ptr()) }; + assert!(!ptr.is_null(), "krun_add_net_unixgram not found"); + unsafe { std::mem::transmute(ptr) } +} + +fn vmnet_helper_path() -> Option { + let paths = [ + "/opt/vmnet-helper/bin/vmnet-helper", + "/opt/homebrew/opt/vmnet-helper/libexec/vmnet-helper", + "/opt/homebrew/bin/vmnet-helper", + "/usr/local/bin/vmnet-helper", + ]; + for path in paths { + if std::path::Path::new(path).exists() { + return Some(path.to_string()); + } + } + Command::new("which") + .arg("vmnet-helper") + .output() + .ok() + .and_then(|o| { + if o.status.success() { + String::from_utf8(o.stdout) + .ok() + .map(|s| s.trim().to_string()) + } else { + None + } + }) +} + +/// Parse a MAC address string like "1e:d4:d1:27:4b:bf" into 6 bytes. +fn parse_mac(s: &str) -> Option<[u8; 6]> { + let parts: Vec<&str> = s.split(':').collect(); + if parts.len() != 6 { + return None; + } + let mut mac = [0u8; 6]; + for (i, part) in parts.iter().enumerate() { + mac[i] = u8::from_str_radix(part, 16).ok()?; + } + Some(mac) +} + +struct VmnetConfig { + fd: i32, + mac: [u8; 6], +} + +/// Start vmnet-helper with `--fd 3`, wait for its JSON config on stdout, +/// and return the fd + MAC address from vmnet. +/// +/// Creates a `SOCK_DGRAM` socketpair, passes one end to vmnet-helper as fd 3 +/// (matching what `vmnet-client` does), and returns the other end for use +/// with `krun_add_net_unixgram`. 
+fn start_vmnet_helper(log_path: &std::path::Path) -> std::io::Result { + let helper = vmnet_helper_path().ok_or_else(|| { + std::io::Error::new(std::io::ErrorKind::NotFound, "vmnet-helper not found") + })?; + + // Create a SOCK_DGRAM socketpair + let mut fds = [0 as libc::c_int; 2]; + if unsafe { libc::socketpair(libc::AF_UNIX, libc::SOCK_DGRAM, 0, fds.as_mut_ptr()) } < 0 { + return Err(std::io::Error::last_os_error()); + } + let (our_fd, helper_fd) = (fds[0], fds[1]); + + // Create a pipe for reading vmnet-helper's stdout (JSON config) + let mut stdout_fds = [0 as libc::c_int; 2]; + if unsafe { libc::pipe(stdout_fds.as_mut_ptr()) } < 0 { + unsafe { + libc::close(our_fd); + libc::close(helper_fd); + } + return Err(std::io::Error::last_os_error()); + } + let (stdout_read, stdout_write) = (stdout_fds[0], stdout_fds[1]); + + let log_file = std::fs::File::create(log_path)?; + + let pid = unsafe { libc::fork() }; + if pid < 0 { + unsafe { + libc::close(our_fd); + libc::close(helper_fd); + libc::close(stdout_read); + libc::close(stdout_write); + } + return Err(std::io::Error::last_os_error()); + } + + if pid == 0 { + // Child process + unsafe { + libc::close(our_fd); + libc::close(stdout_read); + + // Redirect stdout to our pipe + libc::dup2(stdout_write, 1); + libc::close(stdout_write); + + // Redirect stderr to log file + use std::os::unix::io::AsRawFd; + libc::dup2(log_file.as_raw_fd(), 2); + + // Redirect stdin from /dev/null + let devnull = libc::open(c"/dev/null".as_ptr(), libc::O_RDONLY); + if devnull >= 0 { + libc::dup2(devnull, 0); + libc::close(devnull); + } + + // Place helper_fd at fd 3 + if helper_fd != 3 { + libc::dup2(helper_fd, 3); + libc::close(helper_fd); + } + + let helper_c = CString::new(helper.as_str()).unwrap(); + let arg_fd = CString::new("--fd").unwrap(); + let arg_fd_val = CString::new("3").unwrap(); + libc::execlp( + helper_c.as_ptr(), + helper_c.as_ptr(), + arg_fd.as_ptr(), + arg_fd_val.as_ptr(), + std::ptr::null::(), + ); + libc::_exit(1); + 
} + } + + // Parent process + unsafe { + libc::close(helper_fd); + libc::close(stdout_write); + } + + // Read the JSON config line from vmnet-helper's stdout. + // vmnet-helper writes a single JSON line then keeps running. + let stdout_file = unsafe { std::fs::File::from_raw_fd(stdout_read) }; + let reader = BufReader::new(stdout_file); + let mut config_line = String::new(); + reader + .take(4096) + .read_line(&mut config_line) + .map_err(|e| std::io::Error::other(format!("failed to read vmnet-helper config: {e}")))?; + + if config_line.is_empty() { + return Err(std::io::Error::other( + "vmnet-helper exited without producing config", + )); + } + + eprintln!("vmnet-helper config: {}", config_line.trim()); + + // Parse the MAC address from the JSON config. + // The JSON looks like: {"vmnet_mac_address":"1e:d4:d1:27:4b:bf",...} + let mac_str = config_line + .split("\"vmnet_mac_address\":\"") + .nth(1) + .and_then(|s| s.split('"').next()) + .ok_or_else(|| std::io::Error::other("vmnet_mac_address not found in config"))?; + + let mac = parse_mac(mac_str) + .ok_or_else(|| std::io::Error::other(format!("invalid MAC address: {mac_str}")))?; + + // Increase socket buffer sizes so libkrun's Unixgram backend (which uses + // the fd path and does NOT set these) can batch frames without drops. 
+ let buf_size: libc::c_int = 7 * 1024 * 1024; + unsafe { + libc::setsockopt( + our_fd, + libc::SOL_SOCKET, + libc::SO_SNDBUF, + &buf_size as *const _ as *const libc::c_void, + std::mem::size_of_val(&buf_size) as libc::socklen_t, + ); + libc::setsockopt( + our_fd, + libc::SOL_SOCKET, + libc::SO_RCVBUF, + &buf_size as *const _ as *const libc::c_void, + std::mem::size_of_val(&buf_size) as libc::socklen_t, + ); + } + + Ok(VmnetConfig { fd: our_fd, mac }) +} + +pub(crate) fn should_run() -> ShouldRun { + #[cfg(not(target_os = "macos"))] + return ShouldRun::No("vmnet-helper only supported on macOS"); + + #[cfg(target_os = "macos")] + { + if vmnet_helper_path().is_none() { + return ShouldRun::No("vmnet-helper not installed"); + } + ShouldRun::Yes + } +} + +pub(crate) fn setup_backend(ctx: u32, test_setup: &TestSetup) -> anyhow::Result<()> { + let tmp_dir = test_setup + .tmp_dir + .canonicalize() + .unwrap_or_else(|_| test_setup.tmp_dir.clone()); + let vmnet_log = tmp_dir.join("vmnet-helper.log"); + + let mut config = start_vmnet_helper(&vmnet_log)?; + + unsafe { + krun_call!(get_krun_add_net_unixgram()( + ctx, + std::ptr::null(), + config.fd, + config.mac.as_mut_ptr(), + 0, // no offloading - vmnet-helper uses raw ethernet frames + 0, // no VFKIT flag + ))?; + } + Ok(()) +} diff --git a/tests/test_cases/src/test_net_perf/mod.rs b/tests/test_cases/src/test_net_perf/mod.rs new file mode 100644 index 000000000..61d4c3152 --- /dev/null +++ b/tests/test_cases/src/test_net_perf/mod.rs @@ -0,0 +1,463 @@ +//! iperf3-based performance tests for virtio-net backends +//! +//! Each test: +//! 1. Host: Start iperf3 server + network backend +//! 2. Guest: Configure eth0, run iperf3 client +//! 3. Host: Parse iperf3 JSON output, produce markdown report +//! +//! Tests are parametrized by backend and direction (upload vs download). 
+ +use macros::{guest, host}; + +#[host] +use crate::{ShouldRun, TestSetup}; + +/// Virtio-net performance test with configurable backend and direction +pub struct TestNetPerf { + #[cfg(feature = "guest")] + guest_ip: [u8; 4], + #[cfg(feature = "guest")] + host_ip: [u8; 4], + #[cfg(feature = "guest")] + netmask: [u8; 4], + port: u16, + /// If true, run iperf3 with -R (reverse: server sends, client receives = download) + reverse: bool, + #[cfg(feature = "host")] + should_run: fn() -> ShouldRun, + #[cfg(feature = "host")] + setup_backend: fn(u32, &TestSetup) -> anyhow::Result<()>, + #[cfg(feature = "host")] + cleanup: Option, +} + +impl TestNetPerf { + pub fn new_passt_upload() -> Self { + Self { + #[cfg(feature = "guest")] + guest_ip: [169, 254, 2, 1], + #[cfg(feature = "guest")] + host_ip: [169, 254, 2, 2], + #[cfg(feature = "guest")] + netmask: [255, 255, 0, 0], + port: 15100, + reverse: false, + #[cfg(feature = "host")] + should_run: crate::test_net::passt::should_run, + #[cfg(feature = "host")] + setup_backend: crate::test_net::passt::setup_backend, + #[cfg(feature = "host")] + cleanup: None, + } + } + + pub fn new_passt_download() -> Self { + Self { + #[cfg(feature = "guest")] + guest_ip: [169, 254, 2, 1], + #[cfg(feature = "guest")] + host_ip: [169, 254, 2, 2], + #[cfg(feature = "guest")] + netmask: [255, 255, 0, 0], + port: 15110, + reverse: true, + #[cfg(feature = "host")] + should_run: crate::test_net::passt::should_run, + #[cfg(feature = "host")] + setup_backend: crate::test_net::passt::setup_backend, + #[cfg(feature = "host")] + cleanup: None, + } + } + + pub fn new_tap_upload() -> Self { + Self { + #[cfg(feature = "guest")] + guest_ip: [10, 0, 0, 2], + #[cfg(feature = "guest")] + host_ip: [10, 0, 0, 1], + #[cfg(feature = "guest")] + netmask: [255, 255, 255, 0], + port: 15101, + reverse: false, + #[cfg(feature = "host")] + should_run: crate::test_net::tap::should_run, + #[cfg(feature = "host")] + setup_backend: crate::test_net::tap::setup_backend, + 
#[cfg(feature = "host")] + cleanup: Some(crate::test_net::tap::cleanup), + } + } + + pub fn new_tap_download() -> Self { + Self { + #[cfg(feature = "guest")] + guest_ip: [10, 0, 0, 2], + #[cfg(feature = "guest")] + host_ip: [10, 0, 0, 1], + #[cfg(feature = "guest")] + netmask: [255, 255, 255, 0], + port: 15111, + reverse: true, + #[cfg(feature = "host")] + should_run: crate::test_net::tap::should_run, + #[cfg(feature = "host")] + setup_backend: crate::test_net::tap::setup_backend, + #[cfg(feature = "host")] + cleanup: Some(crate::test_net::tap::cleanup), + } + } + + pub fn new_gvproxy_upload() -> Self { + Self { + #[cfg(feature = "guest")] + guest_ip: [192, 168, 127, 2], + #[cfg(feature = "guest")] + host_ip: [192, 168, 127, 254], + #[cfg(feature = "guest")] + netmask: [255, 255, 255, 0], + port: 15102, + reverse: false, + #[cfg(feature = "host")] + should_run: crate::test_net::gvproxy::should_run, + #[cfg(feature = "host")] + setup_backend: crate::test_net::gvproxy::setup_backend, + #[cfg(feature = "host")] + cleanup: None, + } + } + + pub fn new_gvproxy_download() -> Self { + Self { + #[cfg(feature = "guest")] + guest_ip: [192, 168, 127, 2], + #[cfg(feature = "guest")] + host_ip: [192, 168, 127, 254], + #[cfg(feature = "guest")] + netmask: [255, 255, 255, 0], + port: 15112, + reverse: true, + #[cfg(feature = "host")] + should_run: crate::test_net::gvproxy::should_run, + #[cfg(feature = "host")] + setup_backend: crate::test_net::gvproxy::setup_backend, + #[cfg(feature = "host")] + cleanup: None, + } + } + + pub fn new_vmnet_helper_upload() -> Self { + Self { + #[cfg(feature = "guest")] + guest_ip: [192, 168, 105, 2], + #[cfg(feature = "guest")] + host_ip: [192, 168, 105, 1], + #[cfg(feature = "guest")] + netmask: [255, 255, 255, 0], + port: 15103, + reverse: false, + #[cfg(feature = "host")] + should_run: crate::test_net::vmnet_helper::should_run, + #[cfg(feature = "host")] + setup_backend: crate::test_net::vmnet_helper::setup_backend, + #[cfg(feature = "host")] + 
cleanup: None, + } + } + + pub fn new_vmnet_helper_download() -> Self { + Self { + #[cfg(feature = "guest")] + guest_ip: [192, 168, 105, 2], + #[cfg(feature = "guest")] + host_ip: [192, 168, 105, 1], + #[cfg(feature = "guest")] + netmask: [255, 255, 255, 0], + port: 15113, + reverse: true, + #[cfg(feature = "host")] + should_run: crate::test_net::vmnet_helper::should_run, + #[cfg(feature = "host")] + setup_backend: crate::test_net::vmnet_helper::setup_backend, + #[cfg(feature = "host")] + cleanup: None, + } + } +} + +#[host] +mod host { + use super::*; + use crate::common::setup_existing_rootfs_and_enter; + use crate::rootfs; + use crate::{krun_call, krun_call_u32, Test, TestOutcome, TestSetup}; + use krun_sys::*; + use std::process::{Child, Command, Stdio}; + + const ROOTFS_IMAGE: &str = "fedora-iperf3"; + + fn iperf3_available() -> bool { + Command::new("iperf3") + .arg("--version") + .stdout(Stdio::null()) + .stderr(Stdio::null()) + .status() + .map(|s| s.success()) + .unwrap_or(false) + } + + fn start_iperf_server(port: u16) -> std::io::Result { + Command::new("iperf3") + .arg("-s") + .arg("-p") + .arg(port.to_string()) + .arg("-1") // one-off: exit after first client + .stdout(Stdio::null()) + .stderr(Stdio::null()) + .spawn() + } + + #[derive(serde::Deserialize)] + struct Iperf3Output { + intervals: Vec, + end: Iperf3End, + } + + #[derive(serde::Deserialize)] + struct Iperf3Interval { + sum: Iperf3Sum, + } + + #[derive(serde::Deserialize)] + struct Iperf3End { + sum_sent: Iperf3Sum, + sum_received: Iperf3Sum, + } + + #[derive(serde::Deserialize)] + #[allow(dead_code)] + struct Iperf3Sum { + start: f64, + end: f64, + seconds: f64, + bytes: f64, + bits_per_second: f64, + } + + struct Iperf3Report { + output: Iperf3Output, + reverse: bool, + } + + impl Iperf3Report { + fn label(&self) -> &'static str { + if self.reverse { + "Download (host->guest)" + } else { + "Upload (guest->host)" + } + } + + fn summary(&self) -> &Iperf3Sum { + if self.reverse { + 
&self.output.end.sum_received + } else { + &self.output.end.sum_sent + } + } + } + + impl crate::ReportImpl for Iperf3Report { + fn fmt_text(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let i = f.width().unwrap_or(0); + writeln!(f, "{:i$}iperf3 — {}\n", "", self.label())?; + writeln!( + f, + "{:i$}{:<9} {:>18} {:>14}", + "", "Interval", "Throughput", "Transferred" + )?; + writeln!(f, "{:i$}{:-<9} {:-<18} {:-<14}", "", "", "", "")?; + for interval in &self.output.intervals { + let s = &interval.sum; + let iv = format!("{:.0}-{:.0}s", s.start, s.end); + writeln!( + f, + "{:i$}{:<9} {:>11.2} Gbit/s {:>10.2} GiB", + "", + iv, + s.bits_per_second / 1_000_000_000.0, + s.bytes / (1024.0 * 1024.0 * 1024.0), + )?; + } + let s = self.summary(); + writeln!(f, "{:i$}{:-<9} {:-<18} {:-<14}", "", "", "", "")?; + write!( + f, + "{:i$}{:<9} {:>11.2} Gbit/s {:>10.2} GiB", + "", + "Total", + s.bits_per_second / 1_000_000_000.0, + s.bytes / (1024.0 * 1024.0 * 1024.0), + ) + } + + fn fmt_gh_markdown(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!(f, "### iperf3 — {}\n", self.label())?; + writeln!(f, "| Interval | Throughput | Transferred |")?; + writeln!(f, "|----------|-----------|-------------|")?; + for interval in &self.output.intervals { + let s = &interval.sum; + writeln!( + f, + "| {:.0}-{:.0}s | {:.2} Gbit/s | {:.2} GiB |", + s.start, + s.end, + s.bits_per_second / 1_000_000_000.0, + s.bytes / (1024.0 * 1024.0 * 1024.0), + )?; + } + let s = self.summary(); + write!( + f, + "| **Total** | **{:.2} Gbit/s** | **{:.2} GiB** |", + s.bits_per_second / 1_000_000_000.0, + s.bytes / (1024.0 * 1024.0 * 1024.0), + ) + } + } + + impl Test for TestNetPerf { + fn should_run(&self) -> ShouldRun { + if option_env!("IPERF_DURATION").is_none() { + return ShouldRun::No("IPERF_DURATION not set"); + } + if unsafe { krun_call_u32!(krun_has_feature(KRUN_FEATURE_NET.into())) }.ok() != Some(1) + { + return ShouldRun::No("libkrun compiled without NET"); + } + 
let backend_result = (self.should_run)(); + if let ShouldRun::No(_) = backend_result { + return backend_result; + } + if !iperf3_available() { + return ShouldRun::No("iperf3 not installed on host"); + } + if !rootfs::rootfs_is_built(ROOTFS_IMAGE) { + return ShouldRun::No("rootfs not built (run: make test)"); + } + ShouldRun::Yes + } + + fn start_vm(self: Box, test_setup: TestSetup) -> anyhow::Result<()> { + // Start iperf3 server on host (one-off, exits after first client) + let mut iperf_server = start_iperf_server(self.port)?; + + // Give iperf3 server a moment to start + std::thread::sleep(std::time::Duration::from_millis(200)); + + // Check it's still running + if let Some(status) = iperf_server.try_wait()? { + anyhow::bail!("iperf3 server exited early: {status}"); + } + + unsafe { + krun_call!(krun_set_log_level(KRUN_LOG_LEVEL_TRACE))?; + let ctx = krun_call_u32!(krun_create_ctx())?; + krun_call!(krun_set_vm_config(ctx, 1, 512))?; + + // Backend-specific setup + (self.setup_backend)(ctx, &test_setup)?; + + let rootfs = test_setup.tmp_dir.join("rootfs"); + rootfs::extract_rootfs(ROOTFS_IMAGE, &rootfs)?; + setup_existing_rootfs_and_enter(ctx, test_setup, &rootfs)?; + } + Ok(()) + } + + fn check(self: Box, child: std::process::Child) -> TestOutcome { + let output = child.wait_with_output().unwrap(); + if let Some(cleanup) = self.cleanup { + cleanup(); + } + let stdout = String::from_utf8_lossy(&output.stdout).to_string(); + + match serde_json::from_str::(&stdout) { + Ok(iperf_output) => TestOutcome::Report(Box::new(Iperf3Report { + output: iperf_output, + reverse: self.reverse, + })), + Err(_) => TestOutcome::Fail, + } + } + } +} + +#[guest] +mod guest { + use super::*; + use crate::net_config::configure_interface; + use crate::Test; + use std::process::Command; + use std::time::Duration; + + impl Test for TestNetPerf { + fn in_guest(self: Box) { + // Configure eth0 with static IP + configure_interface("eth0", self.guest_ip, self.netmask) + .expect("Failed to 
configure eth0"); + + let host_ip = format!( + "{}.{}.{}.{}", + self.host_ip[0], self.host_ip[1], self.host_ip[2], self.host_ip[3] + ); + + // Give the network a moment to come up + std::thread::sleep(Duration::from_secs(2)); + + let Some(iperf_duration) = option_env!("IPERF_DURATION") else { + unreachable!() + }; + + // Run iperf3 client with JSON output, retry up to 5 times + let mut last_output = None; + for attempt in 0..5 { + if attempt > 0 { + std::thread::sleep(Duration::from_secs(2)); + } + + let mut cmd = Command::new("/usr/bin/iperf3"); + cmd.arg("-c") + .arg(&host_ip) + .arg("-p") + .arg(self.port.to_string()) + .arg("-t") + .arg(iperf_duration) + .arg("-J"); + + if self.reverse { + cmd.arg("-R"); + } + + let output = cmd.output().expect("Failed to run iperf3"); + + if output.status.success() { + // Print JSON output to stdout (host will read it) + let stdout = String::from_utf8(output.stdout).expect("iperf3 output not UTF-8"); + print!("{}", stdout); + return; + } + + last_output = Some(output); + } + + let output = last_output.unwrap(); + let stderr = String::from_utf8_lossy(&output.stderr); + let stdout = String::from_utf8_lossy(&output.stdout); + panic!( + "iperf3 failed after 5 attempts (exit={}):\nstderr: {}\nstdout: {}", + output.status, stderr, stdout + ); + } + } +} diff --git a/tests/test_cases/src/test_tsi_tcp_guest_connect.rs b/tests/test_cases/src/test_tsi_tcp_guest_connect.rs index 038501b37..1fb6c7424 100644 --- a/tests/test_cases/src/test_tsi_tcp_guest_connect.rs +++ b/tests/test_cases/src/test_tsi_tcp_guest_connect.rs @@ -1,5 +1,6 @@ use crate::tcp_tester::TcpTester; use macros::{guest, host}; +use std::net::Ipv4Addr; const PORT: u16 = 8000; @@ -10,7 +11,7 @@ pub struct TestTsiTcpGuestConnect { impl TestTsiTcpGuestConnect { pub fn new() -> TestTsiTcpGuestConnect { Self { - tcp_tester: TcpTester::new(PORT), + tcp_tester: TcpTester::new(PORT, Ipv4Addr::LOCALHOST), } } } diff --git a/tests/test_cases/src/test_tsi_tcp_guest_listen.rs 
b/tests/test_cases/src/test_tsi_tcp_guest_listen.rs index 9838ed893..a2f3a1cc8 100644 --- a/tests/test_cases/src/test_tsi_tcp_guest_listen.rs +++ b/tests/test_cases/src/test_tsi_tcp_guest_listen.rs @@ -1,5 +1,6 @@ use crate::tcp_tester::TcpTester; use macros::{guest, host}; +use std::net::Ipv4Addr; const PORT: u16 = 8001; @@ -10,7 +11,7 @@ pub struct TestTsiTcpGuestListen { impl TestTsiTcpGuestListen { pub fn new() -> Self { Self { - tcp_tester: TcpTester::new(PORT), + tcp_tester: TcpTester::new(PORT, Ipv4Addr::LOCALHOST), } } }