Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 96 additions & 57 deletions conn.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package netlink

import (
"iter"
"math/rand"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -53,6 +54,7 @@ type Socket interface {
Send(m Message) error
SendMessages(m []Message) error
Receive() ([]Message, error)
ReceiveIter() iter.Seq2[Message, error]
}

// Dial dials a connection to netlink, using the specified netlink family.
Expand Down Expand Up @@ -232,81 +234,118 @@ func (c *Conn) Receive() ([]Message, error) {
return c.lockedReceive()
}

// ReceiveIter returns an iterator which can be used to receive messages from
// netlink. Just like Receive, multi-part messages are handled transparently and
// netlink errors are returned as errors from the iterator.
//
// If the iteration is stopped before all messages have been read and the
// response is multi-part, the remaining messages will be discarded.
func (c *Conn) ReceiveIter() iter.Seq2[Message, error] {
return func(yield func(Message, error) bool) {
c.mu.Lock()
defer c.mu.Unlock()

for msg, err := range c.lockedReceiveIter() {
if err != nil {
c.debug(func(d *debugger) {
d.debugf(1, "recv: err: %v", err)
})
yield(Message{}, err)
return
}

c.debug(func(d *debugger) {
d.debugf(1, "recv: %+v", msg)
})
if !yield(msg, nil) {
return
}
}
}
}

// lockedReceive implements Receive, but must be called with c.mu acquired for reading.
// We rely on the kernel to deal with concurrent reads and writes to the netlink
// socket itself.
func (c *Conn) lockedReceive() ([]Message, error) {
msgs, err := c.receive()
if err != nil {
c.debug(func(d *debugger) {
d.debugf(1, "recv: err: %v", err)
})

return nil, err
}
var msgs []Message

c.debug(func(d *debugger) {
for _, m := range msgs {
d.debugf(1, "recv: %+v", m)
for m, err := range c.lockedReceiveIter() {
if err != nil {
c.debug(func(d *debugger) {
d.debugf(1, "recv: err: %v", err)
})
return nil, err
}
})

// When using nltest, it's possible for zero messages to be returned by receive.
if len(msgs) == 0 {
return msgs, nil
}
c.debug(func(d *debugger) {
d.debugf(1, "recv: %+v", m)
})

// Trim the final message with multi-part done indicator if
// present.
if m := msgs[len(msgs)-1]; m.Header.Flags&Multi != 0 && m.Header.Type == Done {
return msgs[:len(msgs)-1], nil
msgs = append(msgs, m)
}

return msgs, nil
}

// receive is the internal implementation of Conn.Receive, which can be called
// recursively to handle multi-part messages.
func (c *Conn) receive() ([]Message, error) {
// NB: All non-nil errors returned from this function *must* be of type
// OpError in order to maintain the appropriate contract with callers of
// this package.
//
// This contract also applies to functions called within this function,
// such as checkMessage.

var res []Message
for {
msgs, err := c.sock.Receive()
if err != nil {
return nil, newOpError("receive", err)
}

// If this message is multi-part, we will need to continue looping to
// drain all the messages from the socket.
var multi bool

for _, m := range msgs {
if err := checkMessage(m); err != nil {
return nil, err
// lockedReceiveIter returns an iterator which can be used to receive messages
// from netlink, but must be called with c.mu acquired for the duration of the
// iteration.
func (c *Conn) lockedReceiveIter() iter.Seq2[Message, error] {
return func(yield func(Message, error) bool) {
// NB: All non-nil errors returned from this function *must* be of type
// OpError in order to maintain the appropriate contract with callers of
// this package.
//
// This contract also applies to functions called within this function,
// such as checkMessage.

var more, stopped bool
// send is a helper function to prevent yielding messages after the user
// has stopped iterating
var send = func(m Message, err error) {
if stopped {
return
}

// Does this message indicate a multi-part message?
if m.Header.Flags&Multi == 0 {
// No, check the next messages.
continue
if !yield(m, err) {
stopped = true
}

// Does this message indicate the last message in a series of
// multi-part messages from a single read?
multi = m.Header.Type != Done
}

res = append(res, msgs...)
for {
for m, err := range c.sock.ReceiveIter() {
if err != nil {
send(Message{}, newOpError("receive", err))
return
}

if err := checkMessage(m); err != nil {
send(Message{}, err)
return
}

// Exit early if we encounter a multi-part done message.
// This should be safe to do since messages of type Done should always
// be the last message in a datagram.
if m.Header.Type == Done && m.Header.Flags&Multi != 0 {
return
}

if m.Header.Flags&Multi != 0 {
more = true
}

send(m, nil)
if stopped && !more {
// The user has stopped iterating and there are no more messages
// to read.
return
}
}

if !multi {
// No more messages coming.
return res, nil
if !more {
return
}
}
}
}
Expand Down
75 changes: 49 additions & 26 deletions conn_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package netlink

import (
"context"
"iter"
"os"
"syscall"
"time"
Expand Down Expand Up @@ -121,42 +122,64 @@ func (c *conn) Send(m Message) error {

// Receive receives one or more Messages from netlink.
func (c *conn) Receive() ([]Message, error) {
// Peek at the buffer to see how many bytes are available.
//
// TODO(mdlayher): deal with OOB message data if available, such as
// when PacketInfo ConnOption is true.
n, _, _, _, err := c.s.Recvmsg(context.Background(), nil, nil, unix.MSG_PEEK|unix.MSG_TRUNC)
if err != nil {
return nil, err
}
// Request buffer for the expected size.
b := make([]byte, n)

// Read out all available messages
n, _, flags, _, err := c.s.Recvmsg(context.Background(), b, nil, 0)
if err != nil {
return nil, err
}

if flags&unix.MSG_TRUNC != 0 {
// Our buffer was too small to read the entire message,
// this should not happen since we peeked above, but if it does,
// return an error.
return nil, unix.ENOSPC
}

var msgs []Message
for msg, err := range parseMessagesIter(b[:nlmsgAlign(n)]) {
for msg, err := range c.ReceiveIter() {
if err != nil {
return nil, err
}

msgs = append(msgs, msg)
}

return msgs, nil
}

// getMsgBufferSize peeks at the upcoming message to determine the size of the
// buffer needed to read it.
func (c *conn) getMsgBufferSize() (int, error) {
n, _, _, _, err := c.s.Recvmsg(context.Background(), nil, nil, unix.MSG_PEEK|unix.MSG_TRUNC)
return n, err
}

// ReceiveIter returns an iterator over Messages received from netlink.
func (c *conn) ReceiveIter() iter.Seq2[Message, error] {
return func(yield func(Message, error) bool) {
n, err := c.getMsgBufferSize()
if err != nil {
yield(Message{}, err)
return
}
b := make([]byte, n)

// Read out all available messages
// TODO(mdlayher): deal with OOB message data if available, such as
// when PacketInfo ConnOption is true.
n, _, flags, _, err := c.s.Recvmsg(context.Background(), b, nil, 0)
if err != nil {
yield(Message{}, err)
return
}

if flags&unix.MSG_TRUNC != 0 {
// Our buffer was too small to read the entire message,
// this should not happen since we peeked above, but if it does,
// return an error.
yield(Message{}, unix.ENOSPC)
return
}

for msg, err := range parseMessagesIter(b[:nlmsgAlign(n)]) {
if err != nil {
yield(Message{}, err)
return
}

if !yield(msg, nil) {
return
}
}
}
}

// Close closes the connection.
func (c *conn) Close() error { return c.s.Close() }

Expand Down
Loading
Loading