diff --git a/Sources/DNSServer/DNSServer+Handle.swift b/Sources/DNSServer/DNSServer+Handle.swift index e258fb482..899e1dd85 100644 --- a/Sources/DNSServer/DNSServer+Handle.swift +++ b/Sources/DNSServer/DNSServer+Handle.swift @@ -19,29 +19,27 @@ import NIOCore import NIOPosix extension DNSServer { - /// Handles the DNS request. - /// - Parameters: - /// - outbound: The NIOAsyncChannelOutboundWriter for which to respond. - /// - packet: The request packet. + /// Handles a UDP DNS request and writes the response back to the sender. func handle( outbound: NIOAsyncChannelOutboundWriter>, packet: inout AddressedEnvelope ) async throws { - let chunkSize = 512 - var data = Data() + let data = Data(packet.data.readableBytesView) - self.log?.debug("reading data") - while packet.data.readableBytes > 0 { - if let chunk = packet.data.readBytes(length: min(chunkSize, packet.data.readableBytes)) { - data.append(contentsOf: chunk) - } - } + self.log?.debug("sending response for request") + let responseData = try await self.processRaw(data: data) + let rData = ByteBuffer(bytes: responseData) + try? await outbound.write(AddressedEnvelope(remoteAddress: packet.remoteAddress, data: rData)) + self.log?.debug("processing done") + } + /// Deserializes a raw DNS query, runs it through the handler chain, and returns + /// the serialized response bytes. Used by both UDP and TCP transports. + func processRaw(data: Data) async throws -> Data { self.log?.debug("deserializing message") let query = try Message(deserialize: data) self.log?.debug("processing query: \(query.questions)") - // always send response let responseData: Data do { self.log?.debug("awaiting processing") @@ -76,11 +74,6 @@ extension DNSServer { responseData = try response.serialize() } - self.log?.debug("sending response for \(query.id)") - let rData = ByteBuffer(bytes: responseData) - try? await outbound.write(AddressedEnvelope(remoteAddress: packet.remoteAddress, data: rData)) - - self.log?.debug("processing done") - + return responseData } } diff --git a/Sources/DNSServer/DNSServer+TCPHandle.swift b/Sources/DNSServer/DNSServer+TCPHandle.swift new file mode 100644 index 000000000..5a10c98be --- /dev/null +++ b/Sources/DNSServer/DNSServer+TCPHandle.swift @@ -0,0 +1,110 @@ +//===----------------------------------------------------------------------===// +// Copyright © 2025-2026 Apple Inc. and the container project authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//===----------------------------------------------------------------------===// + +import Foundation +import NIOCore +import NIOPosix +import Synchronization + +/// Tracks the activity timestamps for a specific TCP connection. +/// Used to enforce idle timeouts and disconnect inactive clients to prevent resource exhaustion. +private actor ConnectionActivity { + private var lastSeen: ContinuousClock.Instant + + init() { + lastSeen = ContinuousClock.now + } + + func ping() { + lastSeen = ContinuousClock.now + } + + func idle(after duration: Duration) -> Bool { + ContinuousClock.now - lastSeen >= duration + } +} + +extension DNSServer { + /// Handles a single active TCP connection from an inbound client. + /// + /// This method manages the lifecycle of the connection, reading length-prefixed DNS queries + /// iteratively and executing the underlying `processRaw` logic for each query concurrently + /// using Swift Concurrency. It enforces strict idle timeouts to prevent stale clients + /// from holding connections open indefinitely. + /// + /// - Parameter channel: The connected asynchronous TCP channel containing the message buffer streams. + func handleTCP(channel: NIOAsyncChannel) async { + do { + try await channel.executeThenClose { inbound, outbound in + let activity = ConnectionActivity() + try await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { + let pollInterval = min(Duration.seconds(1), self.tcpIdleTimeout) + while true { + try await Task.sleep(for: pollInterval) + if await activity.idle(after: self.tcpIdleTimeout) { break } + } + self.log?.debug("TCP DNS: idle timeout, closing connection") + throw CancellationError() + } + + group.addTask { + var buffer = ByteBuffer() + for try await chunk in inbound { + buffer.writeImmutableBuffer(chunk) + + while buffer.readableBytes >= 2 { + guard + let msgLen = buffer.getInteger( + at: buffer.readerIndex, as: UInt16.self + ) + else { break } + + guard msgLen > 0, msgLen <= Self.maxTCPMessageSize else { + self.log?.error( + "TCP DNS: unexpected frame size \(msgLen) bytes, closing connection") + return + } + + let needed = 2 + Int(msgLen) + guard buffer.readableBytes >= needed else { break } + + buffer.moveReaderIndex(forwardBy: 2) + guard let msgSlice = buffer.readSlice(length: Int(msgLen)) else { break } + let msgData = Data(msgSlice.readableBytesView) + + let responseData = try await self.processRaw(data: msgData) + + await activity.ping() + + var out = ByteBuffer() + out.writeInteger(UInt16(responseData.count)) + out.writeBytes(responseData) + try await outbound.write(out) + } + } + } + + try await group.next() + group.cancelAll() + } + } + } catch is CancellationError { + } catch { + log?.warning("TCP DNS: connection error: \(error)") + } + } +} diff --git a/Sources/DNSServer/DNSServer.swift b/Sources/DNSServer/DNSServer.swift index d38e55459..c386dd720 100644 --- a/Sources/DNSServer/DNSServer.swift +++ b/Sources/DNSServer/DNSServer.swift @@ -18,25 +18,50 @@ import Foundation import Logging import NIOCore import NIOPosix +import Synchronization -/// Provides a DNS server. -/// - Parameters: -/// - host: The host address on which to listen. -/// - port: The port for the server to listen. -public struct DNSServer { +private final class ConnectionCounter: Sendable { + private let storage: Mutex = .init(0) + + func tryIncrement(limit: Int) -> Bool { + storage.withLock { count in + guard count < limit else { return false } + count += 1 + return true + } + } + + func decrement() { + storage.withLock { $0 -= 1 } + } +} + +public struct DNSServer: @unchecked Sendable { public var handler: DNSHandler let log: Logger? + static let maxConcurrentConnections = 128 + + static let maxTCPMessageSize: UInt16 = 4096 + + let tcpIdleTimeout: Duration + + private let connections: ConnectionCounter + public init( handler: DNSHandler, - log: Logger? = nil + log: Logger? = nil, + tcpIdleTimeout: Duration = .seconds(30) ) { self.handler = handler self.log = log + self.tcpIdleTimeout = tcpIdleTimeout + self.connections = ConnectionCounter() } + // MARK: - UDP + public func run(host: String, port: Int) async throws { - // TODO: TCP server let srv = try await DatagramBootstrap(group: NIOSingletons.posixEventLoopGroup) .channelOption(.socketOption(.so_reuseaddr), value: 1) .bind(host: host, port: port) @@ -59,7 +84,6 @@ public struct DNSServer { } public func run(socketPath: String) async throws { - // TODO: TCP server let srv = try await DatagramBootstrap(group: NIOSingletons.posixEventLoopGroup) .bind(unixDomainSocketPath: socketPath, cleanupExistingSocketFile: true) .flatMapThrowing { channel in @@ -82,5 +106,80 @@ public struct DNSServer { } } + // MARK: - TCP + + public func runTCP(host: String, port: Int) async throws { + let server = try await ServerBootstrap(group: NIOSingletons.posixEventLoopGroup) + .serverChannelOption(.socketOption(.so_reuseaddr), value: 1) + .bind( + host: host, + port: port + ) { channel in + channel.eventLoop.makeCompletedFuture { + try NIOAsyncChannel( + wrappingChannelSynchronously: channel, + configuration: .init( + inboundType: ByteBuffer.self, + outboundType: ByteBuffer.self + ) + ) + } + } + + try await server.executeThenClose { inbound in + try await withThrowingDiscardingTaskGroup { group in + for try await child in inbound { + guard connections.tryIncrement(limit: Self.maxConcurrentConnections) else { + log?.warning( + "TCP DNS: connection limit (\(Self.maxConcurrentConnections)) reached, dropping connection") + try? await child.channel.close() + continue + } + + group.addTask { + defer { self.connections.decrement() } + await self.handleTCP(channel: child) + } + } + } + } + } + + public func runTCP(socketPath: String) async throws { + try? FileManager.default.removeItem(atPath: socketPath) + + let address = try SocketAddress(unixDomainSocketPath: socketPath) + let server = try await ServerBootstrap(group: NIOSingletons.posixEventLoopGroup) + .bind(to: address) { channel in + channel.eventLoop.makeCompletedFuture { + try NIOAsyncChannel( + wrappingChannelSynchronously: channel, + configuration: .init( + inboundType: ByteBuffer.self, + outboundType: ByteBuffer.self + ) + ) + } + } + + try await server.executeThenClose { inbound in + try await withThrowingDiscardingTaskGroup { group in + for try await child in inbound { + guard connections.tryIncrement(limit: Self.maxConcurrentConnections) else { + log?.warning( + "TCP DNS: connection limit (\(Self.maxConcurrentConnections)) reached, dropping connection") + try? await child.channel.close() + continue + } + + group.addTask { + defer { self.connections.decrement() } + await self.handleTCP(channel: child) + } + } + } + } + } + public func stop() async throws {} } diff --git a/Sources/Helpers/APIServer/APIServer+Start.swift b/Sources/Helpers/APIServer/APIServer+Start.swift index b4df3b09f..3c71d5d07 100644 --- a/Sources/Helpers/APIServer/APIServer+Start.swift +++ b/Sources/Helpers/APIServer/APIServer+Start.swift @@ -91,71 +91,43 @@ extension APIServer { $0[$1.key.rawValue] = $1.value }), log: log) - await withTaskGroup(of: Result.self) { group in + try await withThrowingTaskGroup(of: Void.self) { group in group.addTask { log.info("starting XPC server") - do { - try await server.listen() - return .success(()) - } catch { - return .failure(error) - } + try await server.listen() } - // start up host table DNS + // start up host table DNS (UDP and TCP) + let hostsResolver = ContainerDNSHandler(networkService: networkService) + let nxDomainResolver = NxDomainResolver() + let compositeResolver = CompositeResolver(handlers: [hostsResolver, nxDomainResolver]) + let hostsQueryValidator = StandardQueryValidator(handler: compositeResolver) + let dnsServer: DNSServer = DNSServer(handler: hostsQueryValidator, log: log) + group.addTask { - let hostsResolver = ContainerDNSHandler(networkService: networkService) - let nxDomainResolver = NxDomainResolver() - let compositeResolver = CompositeResolver(handlers: [hostsResolver, nxDomainResolver]) - let hostsQueryValidator = StandardQueryValidator(handler: compositeResolver) - let dnsServer: DNSServer = DNSServer(handler: hostsQueryValidator, log: log) log.info( - "starting DNS resolver for container hostnames", + "starting DNS resolver for container hostnames (UDP)", metadata: [ "host": "\(Self.listenAddress)", "port": "\(Self.dnsPort)", ] ) - do { - try await dnsServer.run(host: Self.listenAddress, port: Self.dnsPort) - return .success(()) - } catch { - return .failure(error) - } - + try await dnsServer.run(host: Self.listenAddress, port: Self.dnsPort) } - // start up realhost DNS group.addTask { - do { - let localhostResolver = LocalhostDNSHandler(log: log) - await localhostResolver.monitorResolvers() - - let nxDomainResolver = NxDomainResolver() - let compositeResolver = CompositeResolver(handlers: [localhostResolver, nxDomainResolver]) - let hostsQueryValidator = StandardQueryValidator(handler: compositeResolver) - let dnsServer: DNSServer = DNSServer(handler: hostsQueryValidator, log: log) - log.info( - "starting DNS resolver for localhost", - metadata: [ - "host": "\(Self.listenAddress)", - "port": "\(Self.localhostDNSPort)", - ] - ) - try await dnsServer.run(host: Self.listenAddress, port: Self.localhostDNSPort) - return .success(()) - } catch { - return .failure(error) - } + log.info( + "starting DNS resolver for container hostnames (TCP)", + metadata: [ + "host": "\(Self.listenAddress)", + "port": "\(Self.dnsPort)", + ] + ) + try await dnsServer.runTCP(host: Self.listenAddress, port: Self.dnsPort) } - for await result in group { - switch result { - case .success(): - continue - case .failure(let error): - log.error("API server task failed: \(error)") - } + for try await _ in group { + continue } } } catch { diff --git a/Sources/Helpers/APIServer/ContainerDNSHandler.swift b/Sources/Helpers/APIServer/ContainerDNSHandler.swift index 4fcf8a2c4..d63df836c 100644 --- a/Sources/Helpers/APIServer/ContainerDNSHandler.swift +++ b/Sources/Helpers/APIServer/ContainerDNSHandler.swift @@ -18,7 +18,6 @@ import ContainerAPIService import DNS import DNSServer -/// Handler that uses table lookup to resolve hostnames. struct ContainerDNSHandler: DNSHandler { private let networkService: NetworksService private let ttl: UInt32 @@ -37,10 +36,6 @@ struct ContainerDNSHandler: DNSHandler { case ResourceRecordType.host6: let result = try await answerHost6(question: question) if result.record == nil && result.hostnameExists { - // Return NODATA (noError with empty answers) when hostname exists but has no IPv6. - // This is required because musl libc has issues when A record exists but AAAA returns NXDOMAIN. - // musl treats NXDOMAIN on AAAA as "domain doesn't exist" and fails DNS resolution entirely. - // NODATA correctly indicates "no IPv6 address available, but domain exists". return Message( id: query.id, type: .response, @@ -91,7 +86,11 @@ struct ContainerDNSHandler: DNSHandler { } private func answerHost(question: Question) async throws -> ResourceRecord? { - guard let ipAllocation = try await networkService.lookup(hostname: question.name) else { + var hostname = question.name + if hostname.hasSuffix(".") { + hostname.removeLast() + } + guard let ipAllocation = try await networkService.lookup(hostname: hostname) else { return nil } let ipv4 = ipAllocation.ipv4Address.address.description @@ -103,7 +102,11 @@ struct ContainerDNSHandler: DNSHandler { } private func answerHost6(question: Question) async throws -> (record: ResourceRecord?, hostnameExists: Bool) { - guard let ipAllocation = try await networkService.lookup(hostname: question.name) else { + var hostname = question.name + if hostname.hasSuffix(".") { + hostname.removeLast() + } + guard let ipAllocation = try await networkService.lookup(hostname: hostname) else { return (nil, false) } guard let ipv6Address = ipAllocation.ipv6Address else { diff --git a/Tests/DNSServerTests/TCPHandleTest.swift b/Tests/DNSServerTests/TCPHandleTest.swift new file mode 100644 index 000000000..efb8f8374 --- /dev/null +++ b/Tests/DNSServerTests/TCPHandleTest.swift @@ -0,0 +1,344 @@ +//===----------------------------------------------------------------------===// +// Copyright © 2025-2026 Apple Inc. and the container project authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//===----------------------------------------------------------------------===// + +import DNS +import Foundation +import NIOCore +import NIOPosix +import Testing + +@testable import DNSServer + +struct ProcessRawTest { + @Test func testProcessRawReturnsSerializedResponse() async throws { + let handler = HostTableResolver(hosts4: ["myhost.": IPv4("10.0.0.1")!]) + let server = DNSServer(handler: StandardQueryValidator(handler: handler)) + + let query = Message( + id: 99, + type: .query, + questions: [Question(name: "myhost.", type: .host)] + ) + let responseData = try await server.processRaw(data: query.serialize()) + let response = try Message(deserialize: responseData) + + #expect(99 == response.id) + #expect(.response == response.type) + #expect(.noError == response.returnCode) + #expect(1 == response.answers.count) + let record = response.answers.first as? HostRecord + #expect(IPv4("10.0.0.1") == record?.ip) + } + + @Test func testProcessRawReturnsNxdomainForUnknownHost() async throws { + let handler = HostTableResolver(hosts4: ["known.": IPv4("10.0.0.1")!]) + let composite = CompositeResolver(handlers: [handler, NxDomainResolver()]) + let server = DNSServer(handler: StandardQueryValidator(handler: composite)) + + let query = Message( + id: 55, + type: .query, + questions: [Question(name: "unknown.", type: .host)] + ) + let responseData = try await server.processRaw(data: query.serialize()) + let response = try Message(deserialize: responseData) + + #expect(55 == response.id) + #expect(.nonExistentDomain == response.returnCode) + #expect(0 == response.answers.count) + } +} + +struct TCPHandleTest { + private func makeTCPFrame(hostname: String, id: UInt16 = 1) throws -> ByteBuffer { + let bytes = try Message( + id: id, + type: .query, + questions: [Question(name: hostname, type: .host)] + ).serialize() + var buf = ByteBuffer() + buf.writeInteger(UInt16(bytes.count)) + buf.writeBytes(bytes) + return buf + } + + @Test func testTCPRoundTrip() async throws { + let handler = HostTableResolver(hosts4: ["tcp-test.": IPv4("5.6.7.8")!]) + let server = DNSServer( + handler: StandardQueryValidator(handler: handler), + tcpIdleTimeout: .seconds(2) + ) + + let listener = try await ServerBootstrap(group: NIOSingletons.posixEventLoopGroup) + .serverChannelOption(.socketOption(.so_reuseaddr), value: 1) + .bind( + host: "127.0.0.1", + port: 0 + ) { channel in + channel.eventLoop.makeCompletedFuture { + try NIOAsyncChannel( + wrappingChannelSynchronously: channel, + configuration: .init( + inboundType: ByteBuffer.self, + outboundType: ByteBuffer.self + ) + ) + } + } + let port = listener.channel.localAddress!.port! + + async let serverDone: Void = listener.executeThenClose { inbound in + for try await child in inbound { + await server.handleTCP(channel: child) + break + } + } + + let client = try await ClientBootstrap(group: NIOSingletons.posixEventLoopGroup) + .connect( + host: "127.0.0.1", + port: port + ) { channel in + channel.eventLoop.makeCompletedFuture { + try NIOAsyncChannel( + wrappingChannelSynchronously: channel, + configuration: .init( + inboundType: ByteBuffer.self, + outboundType: ByteBuffer.self + ) + ) + } + } + + var response: Message? + try await client.executeThenClose { inbound, outbound in + try await outbound.write(try makeTCPFrame(hostname: "tcp-test.", id: 77)) + for try await var chunk in inbound { + // Safe in-process: the entire response will arrive in one chunk + guard let len = chunk.readInteger(as: UInt16.self) else { break } + guard let slice = chunk.readSlice(length: Int(len)) else { break } + response = try Message(deserialize: Data(slice.readableBytesView)) + break + } + } + + try? await listener.channel.close() + try? await serverDone + + #expect(77 == response?.id) + #expect(.noError == response?.returnCode) + #expect(1 == response?.answers.count) + let record = response?.answers.first as? HostRecord + #expect(IPv4("5.6.7.8") == record?.ip) + } + + @Test func testTCPPipelinedQueries() async throws { + let handler = HostTableResolver(hosts4: [ + "first.": IPv4("1.1.1.1")!, + "second.": IPv4("2.2.2.2")!, + ]) + let server = DNSServer( + handler: StandardQueryValidator(handler: handler), + tcpIdleTimeout: .seconds(2) + ) + + let listener = try await ServerBootstrap(group: NIOSingletons.posixEventLoopGroup) + .serverChannelOption(.socketOption(.so_reuseaddr), value: 1) + .bind( + host: "127.0.0.1", + port: 0 + ) { channel in + channel.eventLoop.makeCompletedFuture { + try NIOAsyncChannel( + wrappingChannelSynchronously: channel, + configuration: .init( + inboundType: ByteBuffer.self, + outboundType: ByteBuffer.self + ) + ) + } + } + let port = listener.channel.localAddress!.port! + + async let serverDone: Void = listener.executeThenClose { inbound in + for try await child in inbound { + await server.handleTCP(channel: child) + break + } + } + + let client = try await ClientBootstrap(group: NIOSingletons.posixEventLoopGroup) + .connect( + host: "127.0.0.1", + port: port + ) { channel in + channel.eventLoop.makeCompletedFuture { + try NIOAsyncChannel( + wrappingChannelSynchronously: channel, + configuration: .init( + inboundType: ByteBuffer.self, + outboundType: ByteBuffer.self + ) + ) + } + } + + var responses: [Message] = [] + try await client.executeThenClose { inbound, outbound in + var combined = try makeTCPFrame(hostname: "first.", id: 10) + combined.writeImmutableBuffer(try makeTCPFrame(hostname: "second.", id: 20)) + try await outbound.write(combined) + + var accumulator = ByteBuffer() + for try await chunk in inbound { + accumulator.writeImmutableBuffer(chunk) + while accumulator.readableBytes >= 2 { + guard let len = accumulator.getInteger(at: accumulator.readerIndex, as: UInt16.self) else { break } + guard accumulator.readableBytes >= 2 + Int(len) else { break } + accumulator.moveReaderIndex(forwardBy: 2) + guard let slice = accumulator.readSlice(length: Int(len)) else { break } + responses.append(try Message(deserialize: Data(slice.readableBytesView))) + } + if responses.count == 2 { break } + } + } + + try? await listener.channel.close() + try? await serverDone + + #expect(2 == responses.count) + #expect(10 == responses[0].id) + #expect(20 == responses[1].id) + let a1 = responses[0].answers.first as? HostRecord + let a2 = responses[1].answers.first as? HostRecord + #expect(IPv4("1.1.1.1") == a1?.ip) + #expect(IPv4("2.2.2.2") == a2?.ip) + } + + @Test func testTCPDropsOversizedFrame() async throws { + let handler = HostTableResolver(hosts4: ["oversize.": IPv4("1.1.1.1")!]) + let server = DNSServer( + handler: StandardQueryValidator(handler: handler), + tcpIdleTimeout: .seconds(2) + ) + + let listener = try await ServerBootstrap(group: NIOSingletons.posixEventLoopGroup) + .serverChannelOption(.socketOption(.so_reuseaddr), value: 1) + .bind(host: "127.0.0.1", port: 0) { channel in + channel.eventLoop.makeCompletedFuture { + try NIOAsyncChannel( + wrappingChannelSynchronously: channel, + configuration: .init(inboundType: ByteBuffer.self, outboundType: ByteBuffer.self) + ) + } + } + let port = listener.channel.localAddress!.port! + + async let serverDone: Void = listener.executeThenClose { inbound in + for try await child in inbound { + await server.handleTCP(channel: child) + break + } + } + + let client = try await ClientBootstrap(group: NIOSingletons.posixEventLoopGroup) + .connect(host: "127.0.0.1", port: port) { channel in + channel.eventLoop.makeCompletedFuture { + try NIOAsyncChannel( + wrappingChannelSynchronously: channel, + configuration: .init(inboundType: ByteBuffer.self, outboundType: ByteBuffer.self) + ) + } + } + + var receivedChunks = 0 + do { + try await client.executeThenClose { inbound, outbound in + var buf = ByteBuffer() + buf.writeInteger(UInt16(DNSServer.maxTCPMessageSize + 1)) + // The server inspects only the 2-byte length prefix before closing. + // The payload bytes that follow are never read. + buf.writeBytes([UInt8](repeating: 0, count: 10)) + try await outbound.write(buf) + + for try await _ in inbound { + receivedChunks += 1 + } + } + } catch { + print("testTCPDropsOversizedFrame: connection closed with error (expected): \(error)") + } + + try? await listener.channel.close() + try? await serverDone + + #expect(receivedChunks == 0, "Expected server to drop connection without responding to oversized frame") + } + + @Test func testTCPIdleTimeoutDropsConnection() async throws { + let handler = HostTableResolver(hosts4: ["idle.": IPv4("1.1.1.1")!]) + let server = DNSServer( + handler: StandardQueryValidator(handler: handler), + tcpIdleTimeout: .milliseconds(100) + ) + + let listener = try await ServerBootstrap(group: NIOSingletons.posixEventLoopGroup) + .serverChannelOption(.socketOption(.so_reuseaddr), value: 1) + .bind(host: "127.0.0.1", port: 0) { channel in + channel.eventLoop.makeCompletedFuture { + try NIOAsyncChannel( + wrappingChannelSynchronously: channel, + configuration: .init(inboundType: ByteBuffer.self, outboundType: ByteBuffer.self) + ) + } + } + let port = listener.channel.localAddress!.port! + + async let serverDone: Void = listener.executeThenClose { inbound in + for try await child in inbound { + await server.handleTCP(channel: child) + break + } + } + + let client = try await ClientBootstrap(group: NIOSingletons.posixEventLoopGroup) + .connect(host: "127.0.0.1", port: port) { channel in + channel.eventLoop.makeCompletedFuture { + try NIOAsyncChannel( + wrappingChannelSynchronously: channel, + configuration: .init(inboundType: ByteBuffer.self, outboundType: ByteBuffer.self) + ) + } + } + + var receivedChunks = 0 + do { + try await client.executeThenClose { inbound, outbound in + try await Task.sleep(for: .milliseconds(300)) + for try await _ in inbound { + receivedChunks += 1 + } + } + } catch { + print("testTCPIdleTimeoutDropsConnection: connection closed with error (expected): \(error)") + } + + try? await listener.channel.close() + try? await serverDone + + #expect(receivedChunks == 0, "Expected server to drop connection due to idle timeout") + } +}