diff --git a/conn.go b/conn.go index 5ea804c..bf10477 100644 --- a/conn.go +++ b/conn.go @@ -1,6 +1,7 @@ package netlink import ( + "iter" "math/rand" "sync" "sync/atomic" @@ -53,6 +54,7 @@ type Socket interface { Send(m Message) error SendMessages(m []Message) error Receive() ([]Message, error) + ReceiveIter() iter.Seq2[Message, error] } // Dial dials a connection to netlink, using the specified netlink family. @@ -232,81 +234,118 @@ func (c *Conn) Receive() ([]Message, error) { return c.lockedReceive() } +// ReceiveIter returns an iterator which can be used to receive messages from +// netlink. Just like Receive, multi-part messages are handled transparently and +// netlink errors are returned as errors from the iterator. +// +// If the iteration is stopped before all messages have been read and the +// response is multi-part, the remaining messages will be discarded. +func (c *Conn) ReceiveIter() iter.Seq2[Message, error] { + return func(yield func(Message, error) bool) { + c.mu.Lock() + defer c.mu.Unlock() + + for msg, err := range c.lockedReceiveIter() { + if err != nil { + c.debug(func(d *debugger) { + d.debugf(1, "recv: err: %v", err) + }) + yield(Message{}, err) + return + } + + c.debug(func(d *debugger) { + d.debugf(1, "recv: %+v", msg) + }) + if !yield(msg, nil) { + return + } + } + } +} + // lockedReceive implements Receive, but must be called with c.mu acquired for reading. // We rely on the kernel to deal with concurrent reads and writes to the netlink // socket itself. 
func (c *Conn) lockedReceive() ([]Message, error) { - msgs, err := c.receive() - if err != nil { - c.debug(func(d *debugger) { - d.debugf(1, "recv: err: %v", err) - }) - - return nil, err - } + var msgs []Message - c.debug(func(d *debugger) { - for _, m := range msgs { - d.debugf(1, "recv: %+v", m) + for m, err := range c.lockedReceiveIter() { + if err != nil { + c.debug(func(d *debugger) { + d.debugf(1, "recv: err: %v", err) + }) + return nil, err } - }) - // When using nltest, it's possible for zero messages to be returned by receive. - if len(msgs) == 0 { - return msgs, nil - } + c.debug(func(d *debugger) { + d.debugf(1, "recv: %+v", m) + }) - // Trim the final message with multi-part done indicator if - // present. - if m := msgs[len(msgs)-1]; m.Header.Flags&Multi != 0 && m.Header.Type == Done { - return msgs[:len(msgs)-1], nil + msgs = append(msgs, m) } return msgs, nil } -// receive is the internal implementation of Conn.Receive, which can be called -// recursively to handle multi-part messages. -func (c *Conn) receive() ([]Message, error) { - // NB: All non-nil errors returned from this function *must* be of type - // OpError in order to maintain the appropriate contract with callers of - // this package. - // - // This contract also applies to functions called within this function, - // such as checkMessage. - - var res []Message - for { - msgs, err := c.sock.Receive() - if err != nil { - return nil, newOpError("receive", err) - } - - // If this message is multi-part, we will need to continue looping to - // drain all the messages from the socket. - var multi bool - - for _, m := range msgs { - if err := checkMessage(m); err != nil { - return nil, err +// lockedReceiveIter returns an iterator which can be used to receive messages +// from netlink, but must be called with c.mu acquired for the duration of the +// iteration. 
+func (c *Conn) lockedReceiveIter() iter.Seq2[Message, error] { + return func(yield func(Message, error) bool) { + // NB: All non-nil errors yielded by this iterator *must* be of type + // OpError in order to maintain the appropriate contract with callers of + // this package. + // + // This contract also applies to functions called within this function, + // such as checkMessage. + + var more, stopped bool + // send yields to the caller until it stops iterating; after that, messages + // are read but not yielded so that a multi-part response is still drained. + var send = func(m Message, err error) { + if stopped { + return } - - // Does this message indicate a multi-part message? - if m.Header.Flags&Multi == 0 { - // No, check the next messages. - continue + if !yield(m, err) { + stopped = true } - - // Does this message indicate the last message in a series of - // multi-part messages from a single read? - multi = m.Header.Type != Done } - res = append(res, msgs...) + for { + for m, err := range c.sock.ReceiveIter() { + if err != nil { + send(Message{}, newOpError("receive", err)) + return + } + + if err := checkMessage(m); err != nil { + send(Message{}, err) + return + } + + // Exit early if we encounter a multi-part done message. + // This should be safe to do since messages of type Done should always + // be the last message in a datagram. + if m.Header.Type == Done && m.Header.Flags&Multi != 0 { + return + } + + if m.Header.Flags&Multi != 0 { + more = true + } + + send(m, nil) + if stopped && !more { + // The user has stopped iterating and there are no more messages + // to read. + return + } + } - if !multi { - // No more messages coming. - return res, nil + if !more { + return + } } } } diff --git a/conn_linux.go b/conn_linux.go index 5541071..93e3ace 100644 --- a/conn_linux.go +++ b/conn_linux.go @@ -5,6 +5,7 @@ package netlink import ( "context" + "iter" "os" "syscall" "time" @@ -121,42 +122,64 @@ func (c *conn) Send(m Message) error { // Receive receives one or more Messages from netlink. 
func (c *conn) Receive() ([]Message, error) { - // Peek at the buffer to see how many bytes are available. - // - // TODO(mdlayher): deal with OOB message data if available, such as - // when PacketInfo ConnOption is true. - n, _, _, _, err := c.s.Recvmsg(context.Background(), nil, nil, unix.MSG_PEEK|unix.MSG_TRUNC) - if err != nil { - return nil, err - } - // Request buffer for the expected size. - b := make([]byte, n) - - // Read out all available messages - n, _, flags, _, err := c.s.Recvmsg(context.Background(), b, nil, 0) - if err != nil { - return nil, err - } - - if flags&unix.MSG_TRUNC != 0 { - // Our buffer was too small to read the entire message, - // this should not happen since we peeked above, but if it does, - // return an error. - return nil, unix.ENOSPC - } - var msgs []Message - for msg, err := range parseMessagesIter(b[:nlmsgAlign(n)]) { + for msg, err := range c.ReceiveIter() { if err != nil { return nil, err } - msgs = append(msgs, msg) } return msgs, nil } +// getMsgBufferSize peeks at the upcoming message to determine the size of the +// buffer needed to read it. +func (c *conn) getMsgBufferSize() (int, error) { + n, _, _, _, err := c.s.Recvmsg(context.Background(), nil, nil, unix.MSG_PEEK|unix.MSG_TRUNC) + return n, err +} + +// ReceiveIter returns an iterator over Messages received from netlink. +func (c *conn) ReceiveIter() iter.Seq2[Message, error] { + return func(yield func(Message, error) bool) { + n, err := c.getMsgBufferSize() + if err != nil { + yield(Message{}, err) + return + } + b := make([]byte, n) + + // Read out all available messages + // TODO(mdlayher): deal with OOB message data if available, such as + // when PacketInfo ConnOption is true. 
+ n, _, flags, _, err := c.s.Recvmsg(context.Background(), b, nil, 0) + if err != nil { + yield(Message{}, err) + return + } + + if flags&unix.MSG_TRUNC != 0 { + // Our buffer was too small to read the entire message, + // this should not happen since we peeked above, but if it does, + // yield an error. + yield(Message{}, unix.ENOSPC) + return + } + + for msg, err := range parseMessagesIter(b[:nlmsgAlign(n)]) { + if err != nil { + yield(Message{}, err) + return + } + + if !yield(msg, nil) { + return + } + } + } +} + // Close closes the connection. func (c *conn) Close() error { return c.s.Close() } diff --git a/conn_linux_integration_test.go b/conn_linux_integration_test.go index fb52474..6aeb6da 100644 --- a/conn_linux_integration_test.go +++ b/conn_linux_integration_test.go @@ -415,6 +415,128 @@ func TestIntegrationConnConcurrentSerializeReceive(t *testing.T) { } } +// TestIntegrationConnConcurrentSerializeReceiveIter verifies that concurrent calls +// to ReceiveIter are serialized correctly, and that a concurrent ReceiveIter +// call cannot steal multipart message fragments mid-iteration. +func TestIntegrationConnConcurrentSerializeReceiveIter(t *testing.T) { + t.Parallel() + + c, err := netlink.Dial(unix.NETLINK_GENERIC, nil) + if err != nil { + t.Fatalf("failed to dial: %v", err) + } + defer c.Close() + + const ( + GENL_ID_CTRL = 0x10 + CTRL_CMD_GETFAMILY = 0x03 + workers = 2 + iterations = 100 + ) + + // Request a dump to trigger a multipart response, which will require multiple + // recvmsg calls on the socket. 
+ req := netlink.Message{ + Header: netlink.Header{ + Type: GENL_ID_CTRL, + Flags: netlink.Request | netlink.Dump, + }, + Data: []byte{CTRL_CMD_GETFAMILY, 1, 0, 0}, + } + + msgs, err := c.Execute(req) + if err != nil { + t.Fatalf("failed to execute request: %v", err) + } + want := len(msgs) + + for range iterations { + if _, err := c.Send(req); err != nil { + t.Fatalf("failed to send request: %v", err) + } + + var wg sync.WaitGroup + wg.Add(workers) + + for w := range workers { + // Each worker will try to receive the entire multipart message, but only + // one should succeed and the other should time out. + go func(worker int) { + defer wg.Done() + + if err := c.SetReadDeadline(time.Now().Add(10 * time.Millisecond)); err != nil { + panicf("failed to set deadline: %v", err) + } + + var msgs []netlink.Message + for m, err := range c.ReceiveIter() { + if errors.Is(err, os.ErrDeadlineExceeded) { + // Timed out waiting for data. This is expected when the other + // worker consumed the entire multipart message, leaving nothing + // for this worker to read. + return + } + if err != nil { + panicf("failed to receive: %v", err) + } + msgs = append(msgs, m) + } + + if diff := cmp.Diff(want, len(msgs)); diff != "" { + panicf("unexpected message count in worker %d (-want +got):\n%s", worker, diff) + } + }(w) + } + + wg.Wait() + } +} + +func TestReceiveIter(t *testing.T) { + t.Parallel() + c, err := netlink.Dial(unix.NETLINK_GENERIC, nil) + if err != nil { + t.Fatalf("failed to dial: %v", err) + } + defer c.Close() + + const ( + GENL_ID_CTRL = 0x10 + CTRL_CMD_GETFAMILY = 0x03 + ) + + // Request a dump to trigger a multipart response, which will require multiple + // recvmsg calls on the socket. 
+ req := netlink.Message{ + Header: netlink.Header{ + Type: GENL_ID_CTRL, + Flags: netlink.Request | netlink.Dump, + }, + Data: []byte{CTRL_CMD_GETFAMILY, 1, 0, 0}, + } + + want, err := c.Execute(req) + if err != nil { + t.Fatalf("failed to execute request: %v", err) + } + var got []netlink.Message + + if _, err := c.Send(req); err != nil { + t.Fatalf("failed to send request: %v", err) + } + for m, err := range c.ReceiveIter() { + if err != nil { + t.Fatalf("failed to receive message: %v", err) + } + m.Header.Sequence -= 1 + got = append(got, m) + } + + if diff := cmp.Diff(want, got); diff != "" { + t.Fatalf("unexpected messages (-want +got):\n%s", diff) + } +} + func TestIntegrationConnSetBuffersSyscallConn(t *testing.T) { tests := []struct { name string diff --git a/conn_others.go b/conn_others.go index 4c5e739..5eb24c0 100644 --- a/conn_others.go +++ b/conn_others.go @@ -5,6 +5,7 @@ package netlink import ( "fmt" + "iter" "runtime" ) @@ -28,3 +29,8 @@ func (c *conn) Send(_ Message) error { return errUnimplemented } func (c *conn) SendMessages(_ []Message) error { return errUnimplemented } func (c *conn) Receive() ([]Message, error) { return nil, errUnimplemented } func (c *conn) Close() error { return errUnimplemented } +func (c *conn) ReceiveIter() iter.Seq2[Message, error] { + return func(yield func(Message, error) bool) { + yield(Message{}, errUnimplemented) + } +} diff --git a/conn_test.go b/conn_test.go index 7fff133..62e086f 100644 --- a/conn_test.go +++ b/conn_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" "github.com/mdlayher/netlink" "github.com/mdlayher/netlink/nltest" ) @@ -286,3 +287,56 @@ func TestConnSyscallConnUnsupported(t *testing.T) { t.Fatalf("unexpected error: %v", err) } } + +func TestConnReceiveIterMultipaEarlyExit(t *testing.T) { + msgs := []netlink.Message{ + { + Data: []byte{0x00, 0x00, 0x00, 0x01}, + }, + { + Data: []byte{0x00, 0x00, 0x00, 0x02}, + }, + { + Data: []byte{0x00, 0x00, 0x00, 0x03}, + }, + 
{ + Data: []byte{0x00, 0x00, 0x00, 0x04}, + }, + {}, + } + + responded := false + c := nltest.Dial(func(_ []netlink.Message) ([]netlink.Message, error) { + if responded { + return nil, io.EOF + } + responded = true + return nltest.Multipart(msgs) + }) + defer c.Close() + + // Send a message to trigger the multipart response. + if _, err := c.Send(netlink.Message{}); err != nil { + t.Fatalf("failed to send request: %v", err) + } + + for msg, err := range c.ReceiveIter() { + if err != nil { + t.Fatalf("failed to receive messages: %v", err) + } + if diff := cmp.Diff(msgs[0], msg); diff != "" { + t.Fatalf("unexpected message received (-want +got):\n%s", diff) + } + break + } + + got, err := c.Receive() + if err != nil { + t.Fatalf("failed to receive messages: %v", err) + } + + // Early exit should have drained the buffer. + if diff := cmp.Diff([]netlink.Message(nil), got); diff != "" { + t.Fatalf("unexpected messages after early exit from multipart response (-want +got):\n%s", diff) + } +} diff --git a/nltest/nltest.go b/nltest/nltest.go index b79dc40..904f722 100644 --- a/nltest/nltest.go +++ b/nltest/nltest.go @@ -4,6 +4,7 @@ package nltest import ( "fmt" "io" + "iter" "os" "github.com/mdlayher/netlink" @@ -149,57 +150,103 @@ func (c *socket) Send(m netlink.Message) error { } func (c *socket) Receive() ([]netlink.Message, error) { - // No messages set by Send means that we are emulating a - // multicast response or an error occurred. - if len(c.msgs) == 0 { - switch c.err { - case nil: - // No error, simulate multicast, but also return EOF to simulate - // no replies if needed. 
- msgs, err := c.fn(nil) - if err == io.EOF { - err = nil + var msgs []netlink.Message + for msg, err := range c.ReceiveIter() { + if err != nil { + return nil, err + } + msgs = append(msgs, msg) + } + + return msgs, nil +} + +func (c *socket) ReceiveIter() iter.Seq2[netlink.Message, error] { + return func(yield func(netlink.Message, error) bool) { + // No messages set by Send means that we are emulating a + // multicast response or an error occurred. + if len(c.msgs) == 0 { + switch c.err { + case nil: + // No error, simulate multicast, but also return EOF to simulate + // no replies if needed. + msgs, err := c.fn(nil) + if err == io.EOF { + err = nil + return + } + + if err != nil { + yield(netlink.Message{}, err) + return + } + + for _, m := range msgs { + if !yield(m, nil) { + return + } + } + return + case io.EOF: + // EOF, simulate no replies in multi-part message. + return + } + + // If the error is a system call error, wrap it in os.NewSyscallError + // to simulate what the Linux netlink.Conn does. + if isSyscallError(c.err) { + err := c.err + c.err = nil + yield(netlink.Message{}, os.NewSyscallError("recvmsg", err)) + return } - return msgs, err - case io.EOF: - // EOF, simulate no replies in multi-part message. - return nil, nil + // Some generic error occurred and should be passed to the caller. + err := c.err + c.err = nil + yield(netlink.Message{}, err) + return } - // If the error is a system call error, wrap it in os.NewSyscallError - // to simulate what the Linux netlink.Conn does. - if isSyscallError(c.err) { - return nil, os.NewSyscallError("recvmsg", c.err) + // Detect multi-part messages. + var multi bool + for _, m := range c.msgs { + if m.Header.Flags&netlink.Multi != 0 && m.Header.Type != netlink.Done { + multi = true + } } - // Some generic error occurred and should be passed to the caller. 
- return nil, c.err - } + // When a multi-part message is detected, the messages are returned in + // batches of half the total messages, so that multiple calls to Receive or + // ReceiveIter from netlink.Conn are needed to drain all messages. + if multi { + batchSize := (len(c.msgs) + 1) / 2 + batch := c.msgs[:batchSize] + c.msgs = c.msgs[batchSize:] + + for _, m := range batch { + if !yield(m, nil) { + return + } + } - // Detect multi-part messages. - var multi bool - for _, m := range c.msgs { - if m.Header.Flags&netlink.Multi != 0 && m.Header.Type != netlink.Done { - multi = true + return } - } - // When a multi-part message is detected, return messages in batches, so that - // multiple calls to Receive from netlink.Conn are needed to receive all - // messages. - if multi { - batchSize := (len(c.msgs) + 1) / 2 - ret := c.msgs[:batchSize] - c.msgs = c.msgs[batchSize:] - - return ret, c.err - } + msgs, err := c.msgs, c.err + c.msgs, c.err = nil, nil - msgs, err := c.msgs, c.err - c.msgs, c.err = nil, nil + if err != nil { + yield(netlink.Message{}, err) + return + } - return msgs, err + for _, m := range msgs { + if !yield(m, nil) { + return + } + } + } } func panicf(format string, a ...interface{}) { diff --git a/nltest/nltest_test.go b/nltest/nltest_test.go index f2821aa..aa80e54 100644 --- a/nltest/nltest_test.go +++ b/nltest/nltest_test.go @@ -179,6 +179,87 @@ func TestConnReceiveMultipart(t *testing.T) { } } +func TestConnReceiveIterMultipart(t *testing.T) { + msgs := []netlink.Message{ + { + Data: []byte{0x00, 0x00, 0x00, 0x01}, + Header: netlink.Header{ + Flags: netlink.Multi, + }, + }, + { + Data: []byte{0x00, 0x00, 0x00, 0x02}, + Header: netlink.Header{ + Flags: netlink.Multi, + }, + }, + { + Data: []byte{0x00, 0x00, 0x00, 0x03}, + Header: netlink.Header{ + Flags: netlink.Multi, + }, + }, + { + Data: []byte{0x00, 0x00, 0x00, 0x04}, + Header: netlink.Header{ + Flags: netlink.Multi, + }, + }, + { + Header: netlink.Header{ + Type: netlink.Done, + Flags: 
netlink.Multi, + }, + }, + } + + responded := false + c := nltest.Dial(func(_ []netlink.Message) ([]netlink.Message, error) { + // This is necessary so that nltest does not assume that the subsequent call + // to Receive is a multicast response. This would cause it to rerun this + // callback and return the same messages again, instead of simulating no + // more messages coming. + // TODO: Is there a better way to handle multicast responses from nltest? + if !responded { + responded = true + return nltest.Multipart(msgs) + } + return nil, io.EOF + }) + defer c.Close() + + // Send an empty request to trigger the multipart response. + if _, err := c.Send(netlink.Message{}); err != nil { + t.Fatalf("failed to send request: %v", err) + } + + var got []netlink.Message + for msg, err := range c.ReceiveIter() { + if err != nil { + t.Fatalf("failed to receive messages: %v", err) + } + got = append(got, msg) + } + + // Expect all messages but the one with the Done type. + want := msgs[:len(msgs)-1] + if diff := cmp.Diff(want, got); diff != "" { + t.Fatalf("unexpected multipart messages (-want +got):\n%s", diff) + } + + // Any subsequent call to Receive should return no messages, since they're + // all drained from the previous call. + got, err := c.Receive() + if err != nil { + t.Fatalf("failed to receive messages: %v", err) + } + + want = nil + if diff := cmp.Diff(want, got); diff != "" { + t.Fatalf("unexpected messages after multipart response (-want +got):\n%s", diff) + } +} + func TestConnExecuteOK(t *testing.T) { req := netlink.Message{ Header: netlink.Header{