Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 12 additions & 19 deletions Sources/DNSServer/DNSServer+Handle.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,29 +19,27 @@ import NIOCore
import NIOPosix

extension DNSServer {
/// Handles the DNS request.
/// - Parameters:
/// - outbound: The NIOAsyncChannelOutboundWriter for which to respond.
/// - packet: The request packet.
/// Handles a UDP DNS request and writes the response back to the sender.
func handle(
outbound: NIOAsyncChannelOutboundWriter<AddressedEnvelope<ByteBuffer>>,
packet: inout AddressedEnvelope<ByteBuffer>
) async throws {
let chunkSize = 512
var data = Data()
let data = Data(packet.data.readableBytesView)

self.log?.debug("reading data")
while packet.data.readableBytes > 0 {
if let chunk = packet.data.readBytes(length: min(chunkSize, packet.data.readableBytes)) {
data.append(contentsOf: chunk)
}
}
self.log?.debug("sending response for request")
let responseData = try await self.processRaw(data: data)
let rData = ByteBuffer(bytes: responseData)
try? await outbound.write(AddressedEnvelope(remoteAddress: packet.remoteAddress, data: rData))
self.log?.debug("processing done")
}

/// Deserializes a raw DNS query, runs it through the handler chain, and returns
/// the serialized response bytes. Used by both UDP and TCP transports.
func processRaw(data: Data) async throws -> Data {
self.log?.debug("deserializing message")
let query = try Message(deserialize: data)
self.log?.debug("processing query: \(query.questions)")

// always send response
let responseData: Data
do {
self.log?.debug("awaiting processing")
Expand Down Expand Up @@ -76,11 +74,6 @@ extension DNSServer {
responseData = try response.serialize()
}

self.log?.debug("sending response for \(query.id)")
let rData = ByteBuffer(bytes: responseData)
try? await outbound.write(AddressedEnvelope(remoteAddress: packet.remoteAddress, data: rData))

self.log?.debug("processing done")

return responseData
}
}
110 changes: 110 additions & 0 deletions Sources/DNSServer/DNSServer+TCPHandle.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
//===----------------------------------------------------------------------===//
// Copyright © 2025-2026 Apple Inc. and the container project authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//===----------------------------------------------------------------------===//

import Foundation
import NIOCore
import NIOPosix
import Synchronization

/// Tracks the activity timestamps for a specific TCP connection.
/// Used to enforce idle timeouts and disconnect inactive clients to prevent resource exhaustion.
private actor ConnectionActivity {
private var lastSeen: ContinuousClock.Instant

init() {
lastSeen = ContinuousClock.now
}

func ping() {
lastSeen = ContinuousClock.now
}

func idle(after duration: Duration) -> Bool {
ContinuousClock.now - lastSeen >= duration
}
}

extension DNSServer {
/// Handles a single active TCP connection from an inbound client.
///
/// This method manages the lifecycle of the connection, reading length-prefixed DNS queries
/// iteratively and executing the underlying `processRaw` logic for each query concurrently
/// using Swift Concurrency. It enforces strict idle timeouts to prevent stale clients
/// from holding connections open indefinitely.
///
/// - Parameter channel: The connected asynchronous TCP channel containing the message buffer streams.
func handleTCP(channel: NIOAsyncChannel<ByteBuffer, ByteBuffer>) async {
do {
try await channel.executeThenClose { inbound, outbound in
let activity = ConnectionActivity()
try await withThrowingTaskGroup(of: Void.self) { group in
group.addTask {
let pollInterval = min(Duration.seconds(1), self.tcpIdleTimeout)
while true {
try await Task.sleep(for: pollInterval)
if await activity.idle(after: self.tcpIdleTimeout) { break }
}
self.log?.debug("TCP DNS: idle timeout, closing connection")
throw CancellationError()
}

group.addTask {
var buffer = ByteBuffer()
for try await chunk in inbound {
buffer.writeImmutableBuffer(chunk)

while buffer.readableBytes >= 2 {
guard
let msgLen = buffer.getInteger(
at: buffer.readerIndex, as: UInt16.self
)
else { break }

guard msgLen > 0, msgLen <= Self.maxTCPMessageSize else {
self.log?.error(
"TCP DNS: unexpected frame size \(msgLen) bytes, closing connection")
return
}

let needed = 2 + Int(msgLen)
guard buffer.readableBytes >= needed else { break }

buffer.moveReaderIndex(forwardBy: 2)
guard let msgSlice = buffer.readSlice(length: Int(msgLen)) else { break }
let msgData = Data(msgSlice.readableBytesView)

let responseData = try await self.processRaw(data: msgData)

await activity.ping()

var out = ByteBuffer()
out.writeInteger(UInt16(responseData.count))
out.writeBytes(responseData)
try await outbound.write(out)
}
}
}

try await group.next()
group.cancelAll()
}
}
} catch is CancellationError {
} catch {
log?.warning("TCP DNS: connection error: \(error)")
}
}
}
115 changes: 107 additions & 8 deletions Sources/DNSServer/DNSServer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,25 +18,50 @@ import Foundation
import Logging
import NIOCore
import NIOPosix
import Synchronization

/// Provides a DNS server.
/// - Parameters:
/// - host: The host address on which to listen.
/// - port: The port for the server to listen.
public struct DNSServer {
private final class ConnectionCounter: Sendable {
private let storage: Mutex<Int> = .init(0)

func tryIncrement(limit: Int) -> Bool {
storage.withLock { count in
guard count < limit else { return false }
count += 1
return true
}
}

func decrement() {
storage.withLock { $0 -= 1 }
}
}

public struct DNSServer: @unchecked Sendable {
public var handler: DNSHandler
let log: Logger?

static let maxConcurrentConnections = 128

static let maxTCPMessageSize: UInt16 = 4096

let tcpIdleTimeout: Duration

private let connections: ConnectionCounter

public init(
handler: DNSHandler,
log: Logger? = nil
log: Logger? = nil,
tcpIdleTimeout: Duration = .seconds(30)
) {
self.handler = handler
self.log = log
self.tcpIdleTimeout = tcpIdleTimeout
self.connections = ConnectionCounter()
}

// MARK: - UDP

public func run(host: String, port: Int) async throws {
// TODO: TCP server
let srv = try await DatagramBootstrap(group: NIOSingletons.posixEventLoopGroup)
.channelOption(.socketOption(.so_reuseaddr), value: 1)
.bind(host: host, port: port)
Expand All @@ -59,7 +84,6 @@ public struct DNSServer {
}

public func run(socketPath: String) async throws {
// TODO: TCP server
let srv = try await DatagramBootstrap(group: NIOSingletons.posixEventLoopGroup)
.bind(unixDomainSocketPath: socketPath, cleanupExistingSocketFile: true)
.flatMapThrowing { channel in
Expand All @@ -82,5 +106,80 @@ public struct DNSServer {
}
}

// MARK: - TCP

public func runTCP(host: String, port: Int) async throws {
let server = try await ServerBootstrap(group: NIOSingletons.posixEventLoopGroup)
.serverChannelOption(.socketOption(.so_reuseaddr), value: 1)
.bind(
host: host,
port: port
) { channel in
channel.eventLoop.makeCompletedFuture {
try NIOAsyncChannel(
wrappingChannelSynchronously: channel,
configuration: .init(
inboundType: ByteBuffer.self,
outboundType: ByteBuffer.self
)
)
}
}

try await server.executeThenClose { inbound in
try await withThrowingDiscardingTaskGroup { group in
for try await child in inbound {
guard connections.tryIncrement(limit: Self.maxConcurrentConnections) else {
log?.warning(
"TCP DNS: connection limit (\(Self.maxConcurrentConnections)) reached, dropping connection")
try? await child.channel.close()
continue
}

group.addTask {
defer { self.connections.decrement() }
await self.handleTCP(channel: child)
}
}
}
}
}

public func runTCP(socketPath: String) async throws {
try? FileManager.default.removeItem(atPath: socketPath)

let address = try SocketAddress(unixDomainSocketPath: socketPath)
let server = try await ServerBootstrap(group: NIOSingletons.posixEventLoopGroup)
.bind(to: address) { channel in
channel.eventLoop.makeCompletedFuture {
try NIOAsyncChannel(
wrappingChannelSynchronously: channel,
configuration: .init(
inboundType: ByteBuffer.self,
outboundType: ByteBuffer.self
)
)
}
}

try await server.executeThenClose { inbound in
try await withThrowingDiscardingTaskGroup { group in
for try await child in inbound {
guard connections.tryIncrement(limit: Self.maxConcurrentConnections) else {
log?.warning(
"TCP DNS: connection limit (\(Self.maxConcurrentConnections)) reached, dropping connection")
try? await child.channel.close()
continue
}

group.addTask {
defer { self.connections.decrement() }
await self.handleTCP(channel: child)
}
}
}
}
}

public func stop() async throws {}
}
Loading