diff --git a/peer/README.md b/peer/README.md index 82631da5c..fac986936 100644 --- a/peer/README.md +++ b/peer/README.md @@ -24,8 +24,8 @@ A quick overview of the major features peer provides are as follows: - Provides a basic concurrent safe Decred peer for handling Decred communications via the peer-to-peer protocol - Full duplex reading and writing of Decred protocol messages - - Automatic handling of the initial handshake process including protocol - version negotiation + - Separate synchronous method for the initial handshake process which includes + protocol version negotiation - Asynchronous message queueing of outbound messages with optional channel for notification when the message is actually sent - Flexible peer configuration @@ -53,12 +53,12 @@ A quick overview of the major features peer provides are as follows: ## Installation and Updating -This package is part of the `github.com/decred/dcrd/peer/v2` module. Use the +This package is part of the `github.com/decred/dcrd/peer/v3` module. Use the standard go tooling for working with modules to incorporate it. ## Examples -* [New Outbound Peer Example](https://pkg.go.dev/github.com/decred/dcrd/peer/v2#example-package-NewOutboundPeer) +* [New Outbound Peer Example](https://pkg.go.dev/github.com/decred/dcrd/peer/v3#example-package-NewOutboundPeer) Demonstrates the basic process for initializing and creating an outbound peer. Peers negotiate by exchanging version and verack messages. For demonstration, a simple handler for the version message is attached to the peer. diff --git a/peer/doc.go b/peer/doc.go index 90b46a39e..ca5703007 100644 --- a/peer/doc.go +++ b/peer/doc.go @@ -1,5 +1,5 @@ // Copyright (c) 2015-2016 The btcsuite developers -// Copyright (c) 2016-2022 The Decred developers +// Copyright (c) 2016-2026 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. @@ -61,27 +61,31 @@ This provides high flexibility for things such as connecting via proxies, acting as a proxy, creating bridge peers, choosing whether to listen for inbound peers, etc. -NewOutboundPeer and NewInboundPeer functions must be followed by calling Connect -with a net.Conn instance to the peer. This will start all async I/O goroutines -and initiate the protocol negotiation process. Once finished with the peer call -Disconnect to disconnect from the peer and clean up all resources. -WaitForDisconnect can be used to block until peer disconnection and resource -cleanup has completed. +[NewOutboundPeer] and [NewInboundPeer] must be followed by calling +[Peer.Handshake] on the returned instance to perform the initial protocol +negotiation handshake process and finally [Peer.Start] to start all async I/O +goroutines. + +[Peer.WaitForDisconnect] can be used to block until peer disconnection and +resource cleanup has completed. + +When finished with the peer call [Peer.Disconnect] to close the connection and +clean up all resources. # Callbacks In order to do anything useful with a peer, it is necessary to react to decred -messages. This is accomplished by creating an instance of the MessageListeners -struct with the callbacks to be invoke specified and setting the Listeners field -of the Config struct specified when creating a peer to it. +messages. This is accomplished by creating an instance of the [MessageListeners] +struct with the callbacks to be invoke specified and setting [Config.Listeners] +in the [Config] struct specified when creating a peer. For convenience, a callback hook for all of the currently supported decred messages is exposed which receives the peer instance and the concrete message -type. In addition, a hook for OnRead is provided so even custom messages types -for which this package does not directly provide a hook, as long as they -implement the wire.Message interface, can be used. Finally, the OnWrite hook -is provided, which in conjunction with OnRead, can be used to track server-wide -byte counts. +type. In addition, a [MessageListeners.OnRead] hook is provided so even custom +messages types for which this package does not directly provide a hook, as long +as they implement the wire.Message interface, can be used. Finally, the +[MessageListeners.OnWrite] hook is provided, which in conjunction with +[MessageListeners.OnRead], can be used to track server-wide byte counts. It is often useful to use closures which encapsulate state when specifying the callback handlers. This provides a clean method for accessing that state when @@ -89,52 +93,54 @@ callbacks are invoked. # Queuing Messages and Inventory -The QueueMessage function provides the fundamental means to send messages to the -remote peer. As the name implies, this employs a non-blocking queue. A done -channel which will be notified when the message is actually sent can optionally -be specified. There are certain message types which are better sent using other -functions which provide additional functionality. +The [Peer.QueueMessage] function provides the fundamental means to send messages +to the remote peer. As the name implies, this employs a non-blocking queue. A +done channel which will be notified when the message is actually sent can +optionally be specified. There are certain message types which are better sent +using other functions which provide additional functionality. Of special interest are inventory messages. Rather than manually sending MsgInv -messages via Queuemessage, the inventory vectors should be queued using the -QueueInventory function. It employs batching and trickling along with -intelligent known remote peer inventory detection and avoidance through the use -of a most-recently used algorithm. +messages via [Peer.QueueMessage], the inventory vectors should be queued using +the [Peer.QueueInventory] function. It employs batching and trickling along +with intelligent known remote peer inventory detection and avoidance through the +use of a most-recently used algorithm. # Message Sending Helper Functions -In addition to the bare QueueMessage function previously described, the -PushAddrMsg, PushGetBlocksMsg, and PushGetHeadersMsg functions are provided as a -convenience. While it is of course possible to create and send these messages -manually via QueueMessage, these helper functions provided additional useful -functionality that is typically desired. +In addition to the bare [Peer.QueueMessage] function previously described, the +[Peer.PushAddrMsg], [Peer.PushGetBlocksMsg], and [Peer.PushGetHeadersMsg] +functions are provided as a convenience. While it is of course possible to +create and send these messages manually via [Peer.QueueMessage], these helper +functions provided additional useful functionality that is typically desired. -For example, the PushAddrMsg function automatically limits the addresses to the +For example, [Peer.PushAddrMsg] automatically limits the addresses to the maximum number allowed by the message and randomizes the chosen addresses when there are too many. This allows the caller to simply provide a slice of known addresses, such as that returned by the addrmgr package, without having to worry about the details. -Finally, the PushGetBlocksMsg and PushGetHeadersMsg functions will construct +Finally, [Peer.PushGetBlocksMsg] and [Peer.PushGetHeadersMsg] will construct proper messages using a block locator and ignore back to back duplicate requests. # Peer Statistics -A snapshot of the current peer statistics can be obtained with the StatsSnapshot -function. This includes statistics such as the total number of bytes read and -written, the remote address, user agent, and negotiated protocol version. +A snapshot of the current peer statistics can be obtained with +[Peer.StatsSnapshot]. This includes statistics such as the total number of +bytes read and written, the remote address, user agent, and negotiated protocol +version. # Logging -This package provides extensive logging capabilities through the UseLogger +This package provides extensive logging capabilities through the [UseLogger] function which allows a slog.Logger to be specified. For example, logging at the debug level provides summaries of every message sent and received, and logging at the trace level provides full dumps of parsed messages as well as the raw message bytes using a format similar to hexdump -C. -# Improvement Proposals +# Decred Change Proposals -This package supports all improvement proposals supported by the wire package. +This package supports all Decred Change Proposals (DCPs) supported by the wire +package. */ package peer diff --git a/peer/error.go b/peer/error.go new file mode 100644 index 000000000..526197fe5 --- /dev/null +++ b/peer/error.go @@ -0,0 +1,60 @@ +// Copyright (c) 2026 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package peer + +// ErrorKind identifies a kind of error. It has full support for errors.Is +// and errors.As, so the caller can directly check against an error kind +// when determining the reason for an error. +type ErrorKind string + +// These constants are used to identify a specific [Error]. +const ( + // ErrNotVersionMessage indicates the first message received from a remote + // peer is not the required version message. + ErrNotVersionMessage = ErrorKind("ErrNotVersionMessage") + + // ErrSelfConnection indicates a peer attempted to connect to itself. + ErrSelfConnection = ErrorKind("ErrSelfConnection") + + // ErrProtocolVerTooOld indicates a protocol version is older than the + // minimum required version. + ErrProtocolVerTooOld = ErrorKind("ErrProtocolVerTooOld") + + // ErrNotVerAckMessage indicates the second message received from a remote + // peer is not the required verack message. + ErrNotVerAckMessage = ErrorKind("ErrNotVerAckMessage") + + // ErrHandshakeTimeout indicates the initial handshake timed out before + // completing. + ErrHandshakeTimeout = ErrorKind("ErrHandshakeTimeout") +) + +// Error satisfies the error interface and prints human-readable errors. +func (e ErrorKind) Error() string { + return string(e) +} + +// Error identifies an address manager error. It has full support for +// errors.Is and errors.As, so the caller can ascertain the specific reason +// for the error by checking the underlying error. +type Error struct { + Err error + Description string +} + +// Error satisfies the error interface and prints human-readable errors. +func (e Error) Error() string { + return e.Description +} + +// Unwrap returns the underlying wrapped error. +func (e Error) Unwrap() error { + return e.Err +} + +// makeError creates an [Error] given a set of arguments. +func makeError(kind ErrorKind, desc string) Error { + return Error{Err: kind, Description: desc} +} diff --git a/peer/error_test.go b/peer/error_test.go new file mode 100644 index 000000000..db40a41cb --- /dev/null +++ b/peer/error_test.go @@ -0,0 +1,129 @@ +// Copyright (c) 2026 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package peer + +import ( + "errors" + "testing" +) + +// TestErrorKindStringer tests the stringized output for the [ErrorKind] type. +func TestErrorKindStringer(t *testing.T) { + t.Parallel() + + tests := []struct { + in ErrorKind + want string + }{ + {ErrNotVersionMessage, "ErrNotVersionMessage"}, + {ErrSelfConnection, "ErrSelfConnection"}, + {ErrProtocolVerTooOld, "ErrProtocolVerTooOld"}, + {ErrNotVerAckMessage, "ErrNotVerAckMessage"}, + {ErrHandshakeTimeout, "ErrHandshakeTimeout"}, + } + + for i, test := range tests { + result := test.in.Error() + if result != test.want { + t.Errorf("#%d: got: %s want: %s", i, result, test.want) + continue + } + } +} + +// TestError tests the error output for the [Error] type. +func TestError(t *testing.T) { + t.Parallel() + + tests := []struct { + in Error + want string + }{{ + Error{Description: "some error"}, + "some error", + }, { + Error{Description: "human-readable error"}, + "human-readable error", + }} + + for i, test := range tests { + result := test.in.Error() + if result != test.want { + t.Errorf("#%d: got: %s want: %s", i, result, test.want) + continue + } + } +} + +// TestErrorKindIsAs ensures both [ErrorKind] and [Error] can be identified as +// being a specific error kind via [errors.Is] and unwrapped via [errors.As]. +func TestErrorKindIsAs(t *testing.T) { + tests := []struct { + name string + err error + target error + wantMatch bool + wantAs ErrorKind + }{{ + name: "ErrNotVersionMessage == ErrNotVersionMessage", + err: ErrNotVersionMessage, + target: ErrNotVersionMessage, + wantMatch: true, + wantAs: ErrNotVersionMessage, + }, { + name: "Error.ErrNotVersionMessage == ErrNotVersionMessage", + err: makeError(ErrNotVersionMessage, ""), + target: ErrNotVersionMessage, + wantMatch: true, + wantAs: ErrNotVersionMessage, + }, { + name: "ErrNotVersionMessage != ErrSelfConnection", + err: ErrNotVersionMessage, + target: ErrSelfConnection, + wantMatch: false, + wantAs: ErrNotVersionMessage, + }, { + name: "Error.ErrNotVersionMessage != ErrSelfConnection", + err: makeError(ErrNotVersionMessage, ""), + target: ErrSelfConnection, + wantMatch: false, + wantAs: ErrNotVersionMessage, + }, { + name: "ErrNotVersionMessage != Error.ErrSelfConnection", + err: ErrNotVersionMessage, + target: makeError(ErrSelfConnection, ""), + wantMatch: false, + wantAs: ErrNotVersionMessage, + }, { + name: "Error.ErrNotVersionMessage != Error.ErrSelfConnection", + err: makeError(ErrNotVersionMessage, ""), + target: makeError(ErrSelfConnection, ""), + wantMatch: false, + wantAs: ErrNotVersionMessage, + }} + + for _, test := range tests { + // Ensure the error matches or not depending on the expected result. + result := errors.Is(test.err, test.target) + if result != test.wantMatch { + t.Errorf("%s: incorrect error identification -- got %v, want %v", + test.name, result, test.wantMatch) + continue + } + + // Ensure the underlying error kind can be unwrapped and is the + // expected kind. + var kind ErrorKind + if !errors.As(test.err, &kind) { + t.Errorf("%s: unable to unwrap to error kind", test.name) + continue + } + if kind != test.wantAs { + t.Errorf("%s: unexpected unwrapped error kind -- got %v, want %v", + test.name, kind, test.wantAs) + continue + } + } +} diff --git a/peer/example_test.go b/peer/example_test.go index bfe5d3f0b..a4cead5ce 100644 --- a/peer/example_test.go +++ b/peer/example_test.go @@ -1,15 +1,17 @@ // Copyright (c) 2015-2016 The btcsuite developers -// Copyright (c) 2016-2021 The Decred developers +// Copyright (c) 2016-2026 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. package peer_test import ( + "context" "fmt" "net" "time" + "github.com/decred/dcrd/crypto/rand" "github.com/decred/dcrd/peer/v4" "github.com/decred/dcrd/wire" ) @@ -39,8 +41,14 @@ func mockRemotePeer(listenAddr string) (net.Listener, error) { } // Create and start the inbound peer. - p := peer.NewInboundPeer(peerCfg) - p.AssociateConnection(conn) + go func() { + p := peer.NewInboundPeer(peerCfg, conn) + if err := p.Handshake(context.Background(), nil); err != nil { + fmt.Printf("inbound handshake error: %v\n", err) + return + } + p.Start() + }() }() return listener, nil @@ -62,45 +70,55 @@ func Example_newOutboundPeer() { } defer listener.Close() + // Establish the connection to the peer address. + conn, err := net.Dial("tcp", listener.Addr().String()) + if err != nil { + fmt.Printf("net.Dial: error %v\n", err) + return + } + // Create an outbound peer that is configured to act as a simnet node - // that offers no services and has listeners for the version and verack - // messages. The verack listener is used here to signal the code below - // when the handshake has been finished by signalling a channel. - verack := make(chan struct{}) + // that offers no services and has a listener for the pong message. + // + // Then perform the initial handshake and start the async I/O handling. + // + // The pong listener is used here to signal the code below when it arrives + // in response to an example ping. + pong := make(chan struct{}) peerCfg := &peer.Config{ UserAgentName: "peer", // User agent name to advertise. UserAgentVersion: "1.0.0", // User agent version to advertise. Net: wire.SimNet, Services: 0, Listeners: peer.MessageListeners{ - OnVersion: func(p *peer.Peer, msg *wire.MsgVersion) { - fmt.Println("outbound: received version") - }, - OnVerAck: func(p *peer.Peer, msg *wire.MsgVerAck) { - verack <- struct{}{} + // This uses a simple channel for the purposes of the example, but + // callers will typically find it much more ergonomic to create a + // type that houses additional state and exposes methods for the + // desired listeners. Then the listeners may be set to a concrete + // instance of that type so that they close over the additional + // state. + OnPong: func(p *peer.Peer, msg *wire.MsgPong) { + pong <- struct{}{} }, }, IdleTimeout: time.Second * 120, } - p, err := peer.NewOutboundPeer(peerCfg, listener.Addr()) - if err != nil { - fmt.Printf("NewOutboundPeer: error %v\n", err) + p := peer.NewOutboundPeer(peerCfg, conn.RemoteAddr(), conn) + if err := p.Handshake(context.Background(), nil); err != nil { + fmt.Printf("outbound peer handshake error: %v\n", err) return } + p.Start() - // Establish the connection to the peer address and mark it connected. - conn, err := net.Dial("tcp", p.Addr()) - if err != nil { - fmt.Printf("net.Dial: error %v\n", err) - return - } - p.AssociateConnection(conn) + // Ping the remote peer aysnchronously. + p.QueueMessage(wire.NewMsgPing(rand.Uint64()), nil) - // Wait for the verack message or timeout in case of failure. + // Wait for the pong message or timeout in case of failure. select { - case <-verack: + case <-pong: + fmt.Println("outbound: received pong") case <-time.After(time.Second * 1): - fmt.Printf("Example_peerConnection: verack timeout") + fmt.Printf("Example_newOutboundPeer: pong timeout") } // Disconnect the peer. @@ -108,5 +126,5 @@ func Example_newOutboundPeer() { p.WaitForDisconnect() // Output: - // outbound: received version + // outbound: received pong } diff --git a/peer/peer.go b/peer/peer.go index 75ec208b7..c227f1cf3 100644 --- a/peer/peer.go +++ b/peer/peer.go @@ -7,6 +7,7 @@ package peer import ( "bytes" + "context" "errors" "fmt" "hash" @@ -192,12 +193,6 @@ type MessageListeners struct { // OnFeeFilter is invoked when a peer receives a feefilter wire message. OnFeeFilter func(p *Peer, msg *wire.MsgFeeFilter) - // OnVersion is invoked when a peer receives a version wire message. - OnVersion func(p *Peer, msg *wire.MsgVersion) - - // OnVerAck is invoked when a peer receives a verack wire message. - OnVerAck func(p *Peer, msg *wire.MsgVerAck) - // OnSendHeaders is invoked when a peer receives a sendheaders wire // message. OnSendHeaders func(p *Peer, msg *wire.MsgSendHeaders) @@ -424,6 +419,13 @@ type AddrFunc func(remoteAddr *wire.NetAddress) *wire.NetAddress type HostToNetAddrFunc func(host string, port uint16, services wire.ServiceFlag) (*wire.NetAddressV2, error) +// delayedHandshakeMsg stores a message and buffer received before the verack +// during the handshake on old protocol versions. +type delayedHandshakeMsg struct { + msg wire.Message + buf []byte +} + // NOTE: The overall data flow of a peer is split into 3 goroutines. Inbound // messages are read via the inHandler goroutine and generally dispatched to // their own handler. For inbound data-related messages such as blocks, @@ -476,12 +478,16 @@ type Peer struct { id int32 userAgent string remoteServices wire.ServiceFlag - versionKnown bool - handshakeDone bool advertisedProtoVer uint32 // protocol version advertised by remote protocolVersion uint32 // negotiated protocol version sendHeadersPreferred bool // peer sent a sendheaders message - verAckReceived bool + + // These fields are used to delay messages that arrive during the handshake + // on older protocol versions. The associated logic can eventually be + // removed once a consensus upgrade forces a new protocol version as the + // minimum. + delayedHandshakeMsgsMtx sync.Mutex + delayedHandshakeMsgs []delayedHandshakeMsg knownInventory *lru.Set[wire.InvVect] prevGetBlocksMtx sync.Mutex @@ -673,42 +679,6 @@ func (p *Peer) LastPingMicros() int64 { return lastPingMicros } -// VersionKnown returns the whether or not the version of a peer is known -// locally. -// -// This function is safe for concurrent access. -func (p *Peer) VersionKnown() bool { - p.flagsMtx.Lock() - versionKnown := p.versionKnown - p.flagsMtx.Unlock() - - return versionKnown -} - -// HandshakeDone returns whether initial version messages were sent and -// received. -// -// This function is safe for concurrent access. -func (p *Peer) HandshakeDone() bool { - p.flagsMtx.Lock() - handshakeDone := p.handshakeDone - p.flagsMtx.Unlock() - - return handshakeDone -} - -// VerAckReceived returns whether or not a verack message was received by the -// peer. -// -// This function is safe for concurrent access. -func (p *Peer) VerAckReceived() bool { - p.flagsMtx.Lock() - verAckReceived := p.verAckReceived - p.flagsMtx.Unlock() - - return verAckReceived -} - // ProtocolVersion returns the negotiated peer protocol version. // // This function is safe for concurrent access. @@ -1346,264 +1316,279 @@ cleanup: log.Tracef("Peer stall handler done for %s", p) } -// inHandler handles all incoming messages for the peer. It must be run as a -// goroutine. -func (p *Peer) inHandler() { -out: - for atomic.LoadInt32(&p.disconnect) == 0 { - // Read a message and stop the idle timer as soon as the read - // is done. The timer is reset below for the next iteration if - // needed. - rmsg, buf, err := p.readMessage() - if err != nil { - // Only log the error if the local peer is not forcibly - // disconnecting and the remote peer has not disconnected. - if p.shouldHandleReadError(err) { - log.Errorf("Can't read message from %s: %v", p, err) - } +// processInboundMessage processes an inbound message and associated buffer. It +// returns true if the peer should be disconnected. +// +// In addition to checking requirements and updating peer state, it invokes any +// configured message handlers. +func (p *Peer) processInboundMessage(rmsg wire.Message, buf []byte) bool { + switch msg := rmsg.(type) { + case *wire.MsgVersion: + // Limit to one version message per peer. + log.Debugf("Already received 'version' from peer %s -- disconnecting", p) + return true - var nErr net.Error - if errors.As(err, &nErr) && nErr.Timeout() { - log.Warnf("Peer %s no answer for %s -- disconnecting", - p, p.cfg.IdleTimeout) - } + case *wire.MsgVerAck: + // Limit to one verack message per peer. + log.Debugf("Already received 'verack' from peer %s -- disconnecting", p) + return true - break out + case *wire.MsgGetAddr: + if p.cfg.Listeners.OnGetAddr != nil { + p.cfg.Listeners.OnGetAddr(p, msg) } - atomic.StoreInt64(&p.lastRecv, time.Now().Unix()) - select { - case p.stallControl <- stallControlMsg{sccReceiveMessage, rmsg}: - case <-p.quit: - break out + + case *wire.MsgAddr: + if p.cfg.Listeners.OnAddr != nil { + p.cfg.Listeners.OnAddr(p, msg) } - // Handle each supported message type. - select { - case p.stallControl <- stallControlMsg{sccHandlerStart, rmsg}: - case <-p.quit: - break out + case *wire.MsgAddrV2: + if p.cfg.Listeners.OnAddrV2 != nil { + p.cfg.Listeners.OnAddrV2(p, msg) } - switch msg := rmsg.(type) { - case *wire.MsgVersion: - // Limit to one version message per peer. - log.Debugf("Already received 'version' from peer %s -- "+ - "disconnecting", p) - break out - case *wire.MsgVerAck: - // No read lock is necessary because verAckReceived is not written - // to in any other goroutine. - if p.verAckReceived { - log.Infof("Already received 'verack' from peer %s -- "+ - "disconnecting", p) - break out - } - p.flagsMtx.Lock() - p.verAckReceived = true - p.flagsMtx.Unlock() - if p.cfg.Listeners.OnVerAck != nil { - p.cfg.Listeners.OnVerAck(p, msg) - } + case *wire.MsgPing: + p.handlePingMsg(msg) + if p.cfg.Listeners.OnPing != nil { + p.cfg.Listeners.OnPing(p, msg) + } - case *wire.MsgGetAddr: - if p.cfg.Listeners.OnGetAddr != nil { - p.cfg.Listeners.OnGetAddr(p, msg) - } + case *wire.MsgPong: + p.handlePongMsg(msg) + if p.cfg.Listeners.OnPong != nil { + p.cfg.Listeners.OnPong(p, msg) + } - case *wire.MsgAddr: - if p.cfg.Listeners.OnAddr != nil { - p.cfg.Listeners.OnAddr(p, msg) - } + case *wire.MsgMemPool: + if p.cfg.Listeners.OnMemPool != nil { + p.cfg.Listeners.OnMemPool(p, msg) + } - case *wire.MsgAddrV2: - if p.cfg.Listeners.OnAddrV2 != nil { - p.cfg.Listeners.OnAddrV2(p, msg) - } + case *wire.MsgGetMiningState: + if p.cfg.Listeners.OnGetMiningState != nil { + p.cfg.Listeners.OnGetMiningState(p, msg) + } - case *wire.MsgPing: - p.handlePingMsg(msg) - if p.cfg.Listeners.OnPing != nil { - p.cfg.Listeners.OnPing(p, msg) - } + case *wire.MsgMiningState: + if p.cfg.Listeners.OnMiningState != nil { + p.cfg.Listeners.OnMiningState(p, msg) + } - case *wire.MsgPong: - p.handlePongMsg(msg) - if p.cfg.Listeners.OnPong != nil { - p.cfg.Listeners.OnPong(p, msg) - } + case *wire.MsgTx: + if p.cfg.Listeners.OnTx != nil { + p.cfg.Listeners.OnTx(p, msg) + } - case *wire.MsgMemPool: - if p.cfg.Listeners.OnMemPool != nil { - p.cfg.Listeners.OnMemPool(p, msg) - } + case *wire.MsgBlock: + if p.cfg.Listeners.OnBlock != nil { + p.cfg.Listeners.OnBlock(p, msg, buf) + } - case *wire.MsgGetMiningState: - if p.cfg.Listeners.OnGetMiningState != nil { - p.cfg.Listeners.OnGetMiningState(p, msg) - } + case *wire.MsgInv: + if p.cfg.Listeners.OnInv != nil { + p.cfg.Listeners.OnInv(p, msg) + } - case *wire.MsgMiningState: - if p.cfg.Listeners.OnMiningState != nil { - p.cfg.Listeners.OnMiningState(p, msg) - } + case *wire.MsgHeaders: + if p.cfg.Listeners.OnHeaders != nil { + p.cfg.Listeners.OnHeaders(p, msg) + } - case *wire.MsgTx: - if p.cfg.Listeners.OnTx != nil { - p.cfg.Listeners.OnTx(p, msg) - } + case *wire.MsgNotFound: + if p.cfg.Listeners.OnNotFound != nil { + p.cfg.Listeners.OnNotFound(p, msg) + } - case *wire.MsgBlock: - if p.cfg.Listeners.OnBlock != nil { - p.cfg.Listeners.OnBlock(p, msg, buf) - } + case *wire.MsgGetData: + if p.cfg.Listeners.OnGetData != nil { + p.cfg.Listeners.OnGetData(p, msg) + } - case *wire.MsgInv: - if p.cfg.Listeners.OnInv != nil { - p.cfg.Listeners.OnInv(p, msg) - } + case *wire.MsgGetBlocks: + if p.cfg.Listeners.OnGetBlocks != nil { + p.cfg.Listeners.OnGetBlocks(p, msg) + } - case *wire.MsgHeaders: - if p.cfg.Listeners.OnHeaders != nil { - p.cfg.Listeners.OnHeaders(p, msg) - } + case *wire.MsgGetHeaders: + if p.cfg.Listeners.OnGetHeaders != nil { + p.cfg.Listeners.OnGetHeaders(p, msg) + } - case *wire.MsgNotFound: - if p.cfg.Listeners.OnNotFound != nil { - p.cfg.Listeners.OnNotFound(p, msg) - } + case *wire.MsgGetCFilter: + if p.cfg.Listeners.OnGetCFilter != nil { + p.cfg.Listeners.OnGetCFilter(p, msg) + } - case *wire.MsgGetData: - if p.cfg.Listeners.OnGetData != nil { - p.cfg.Listeners.OnGetData(p, msg) - } + case *wire.MsgGetCFHeaders: + if p.cfg.Listeners.OnGetCFHeaders != nil { + p.cfg.Listeners.OnGetCFHeaders(p, msg) + } - case *wire.MsgGetBlocks: - if p.cfg.Listeners.OnGetBlocks != nil { - p.cfg.Listeners.OnGetBlocks(p, msg) - } + case *wire.MsgGetCFTypes: + if p.cfg.Listeners.OnGetCFTypes != nil { + p.cfg.Listeners.OnGetCFTypes(p, msg) + } - case *wire.MsgGetHeaders: - if p.cfg.Listeners.OnGetHeaders != nil { - p.cfg.Listeners.OnGetHeaders(p, msg) - } + case *wire.MsgCFilter: + if p.cfg.Listeners.OnCFilter != nil { + p.cfg.Listeners.OnCFilter(p, msg) + } - case *wire.MsgGetCFilter: - if p.cfg.Listeners.OnGetCFilter != nil { - p.cfg.Listeners.OnGetCFilter(p, msg) - } + case *wire.MsgCFHeaders: + if p.cfg.Listeners.OnCFHeaders != nil { + p.cfg.Listeners.OnCFHeaders(p, msg) + } - case *wire.MsgGetCFHeaders: - if p.cfg.Listeners.OnGetCFHeaders != nil { - p.cfg.Listeners.OnGetCFHeaders(p, msg) - } + case *wire.MsgCFTypes: + if p.cfg.Listeners.OnCFTypes != nil { + p.cfg.Listeners.OnCFTypes(p, msg) + } - case *wire.MsgGetCFTypes: - if p.cfg.Listeners.OnGetCFTypes != nil { - p.cfg.Listeners.OnGetCFTypes(p, msg) - } + case *wire.MsgFeeFilter: + if p.cfg.Listeners.OnFeeFilter != nil { + p.cfg.Listeners.OnFeeFilter(p, msg) + } - case *wire.MsgCFilter: - if p.cfg.Listeners.OnCFilter != nil { - p.cfg.Listeners.OnCFilter(p, msg) - } + case *wire.MsgSendHeaders: + p.flagsMtx.Lock() + p.sendHeadersPreferred = true + p.flagsMtx.Unlock() - case *wire.MsgCFHeaders: - if p.cfg.Listeners.OnCFHeaders != nil { - p.cfg.Listeners.OnCFHeaders(p, msg) - } + if p.cfg.Listeners.OnSendHeaders != nil { + p.cfg.Listeners.OnSendHeaders(p, msg) + } - case *wire.MsgCFTypes: - if p.cfg.Listeners.OnCFTypes != nil { - p.cfg.Listeners.OnCFTypes(p, msg) - } + case *wire.MsgGetCFilterV2: + if p.cfg.Listeners.OnGetCFilterV2 != nil { + p.cfg.Listeners.OnGetCFilterV2(p, msg) + } - case *wire.MsgFeeFilter: - if p.cfg.Listeners.OnFeeFilter != nil { - p.cfg.Listeners.OnFeeFilter(p, msg) - } + case *wire.MsgCFilterV2: + if p.cfg.Listeners.OnCFilterV2 != nil { + p.cfg.Listeners.OnCFilterV2(p, msg) + } - case *wire.MsgSendHeaders: - p.flagsMtx.Lock() - p.sendHeadersPreferred = true - p.flagsMtx.Unlock() + case *wire.MsgGetCFsV2: + if p.cfg.Listeners.OnGetCFiltersV2 != nil { + p.cfg.Listeners.OnGetCFiltersV2(p, msg) + } - if p.cfg.Listeners.OnSendHeaders != nil { - p.cfg.Listeners.OnSendHeaders(p, msg) - } + case *wire.MsgCFiltersV2: + if p.cfg.Listeners.OnCFiltersV2 != nil { + p.cfg.Listeners.OnCFiltersV2(p, msg) + } - case *wire.MsgGetCFilterV2: - if p.cfg.Listeners.OnGetCFilterV2 != nil { - p.cfg.Listeners.OnGetCFilterV2(p, msg) - } + case *wire.MsgGetInitState: + if p.cfg.Listeners.OnGetInitState != nil { + p.cfg.Listeners.OnGetInitState(p, msg) + } - case *wire.MsgCFilterV2: - if p.cfg.Listeners.OnCFilterV2 != nil { - p.cfg.Listeners.OnCFilterV2(p, msg) - } + case *wire.MsgInitState: + if p.cfg.Listeners.OnInitState != nil { + p.cfg.Listeners.OnInitState(p, msg) + } - case *wire.MsgGetCFsV2: - if p.cfg.Listeners.OnGetCFiltersV2 != nil { - p.cfg.Listeners.OnGetCFiltersV2(p, msg) - } + case *wire.MsgMixPairReq: + if p.cfg.Listeners.OnMixPairReq != nil { + p.cfg.Listeners.OnMixPairReq(p, msg) + } - case *wire.MsgCFiltersV2: - if p.cfg.Listeners.OnCFiltersV2 != nil { - p.cfg.Listeners.OnCFiltersV2(p, msg) - } + case *wire.MsgMixKeyExchange: + if p.cfg.Listeners.OnMixKeyExchange != nil { + p.cfg.Listeners.OnMixKeyExchange(p, msg) + } - case *wire.MsgGetInitState: - if p.cfg.Listeners.OnGetInitState != nil { - p.cfg.Listeners.OnGetInitState(p, msg) - } + case *wire.MsgMixCiphertexts: + if p.cfg.Listeners.OnMixCiphertexts != nil { + p.cfg.Listeners.OnMixCiphertexts(p, msg) + } - case *wire.MsgInitState: - if p.cfg.Listeners.OnInitState != nil { - p.cfg.Listeners.OnInitState(p, msg) - } + case *wire.MsgMixSlotReserve: + if p.cfg.Listeners.OnMixSlotReserve != nil { + p.cfg.Listeners.OnMixSlotReserve(p, msg) + } - case *wire.MsgMixPairReq: - if p.cfg.Listeners.OnMixPairReq != nil { - p.cfg.Listeners.OnMixPairReq(p, msg) - } + case *wire.MsgMixDCNet: + if p.cfg.Listeners.OnMixDCNet != nil { + p.cfg.Listeners.OnMixDCNet(p, msg) + } - case *wire.MsgMixKeyExchange: - if p.cfg.Listeners.OnMixKeyExchange != nil { - p.cfg.Listeners.OnMixKeyExchange(p, msg) - } + case *wire.MsgMixConfirm: + if p.cfg.Listeners.OnMixConfirm != nil { + p.cfg.Listeners.OnMixConfirm(p, msg) + } - case *wire.MsgMixCiphertexts: - if p.cfg.Listeners.OnMixCiphertexts != nil { - p.cfg.Listeners.OnMixCiphertexts(p, msg) - } + case *wire.MsgMixFactoredPoly: + if p.cfg.Listeners.OnMixFactoredPoly != nil { + p.cfg.Listeners.OnMixFactoredPoly(p, msg) + } - case *wire.MsgMixSlotReserve: - if p.cfg.Listeners.OnMixSlotReserve != nil { - p.cfg.Listeners.OnMixSlotReserve(p, msg) - } + case *wire.MsgMixSecrets: + if p.cfg.Listeners.OnMixSecrets != nil { + p.cfg.Listeners.OnMixSecrets(p, msg) + } - case *wire.MsgMixDCNet: - if p.cfg.Listeners.OnMixDCNet != nil { - p.cfg.Listeners.OnMixDCNet(p, msg) - } + default: + log.Debugf("Received unhandled message of type %v from %v", + rmsg.Command(), p) + } + + return false +} - case *wire.MsgMixConfirm: - if p.cfg.Listeners.OnMixConfirm != nil { - p.cfg.Listeners.OnMixConfirm(p, msg) +// inHandler handles all incoming messages for the peer. It must be run as a +// goroutine. +func (p *Peer) inHandler() { + // Process messages that arrived out of order during the handshake on older + // protocol versions. + if p.ProtocolVersion() < wire.AddrV2Version { + p.delayedHandshakeMsgsMtx.Lock() + for _, msg := range p.delayedHandshakeMsgs { + if disconn := p.processInboundMessage(msg.msg, msg.buf); disconn { + p.Disconnect() + break } + } + p.delayedHandshakeMsgs = nil + p.delayedHandshakeMsgsMtx.Unlock() + } - case *wire.MsgMixFactoredPoly: - if p.cfg.Listeners.OnMixFactoredPoly != nil { - p.cfg.Listeners.OnMixFactoredPoly(p, msg) +out: + for atomic.LoadInt32(&p.disconnect) == 0 { + // Read a message and stop the idle timer as soon as the read + // is done. The timer is reset below for the next iteration if + // needed. + rmsg, buf, err := p.readMessage() + if err != nil { + // Only log the error if the local peer is not forcibly + // disconnecting and the remote peer has not disconnected. + if p.shouldHandleReadError(err) { + log.Errorf("Can't read message from %s: %v", p, err) } - case *wire.MsgMixSecrets: - if p.cfg.Listeners.OnMixSecrets != nil { - p.cfg.Listeners.OnMixSecrets(p, msg) + var nErr net.Error + if errors.As(err, &nErr) && nErr.Timeout() { + log.Warnf("Peer %s no answer for %s -- disconnecting", + p, p.cfg.IdleTimeout) } - default: - log.Debugf("Received unhandled message of type %v "+ - "from %v", rmsg.Command(), p) + break out + } + atomic.StoreInt64(&p.lastRecv, time.Now().Unix()) + select { + case p.stallControl <- stallControlMsg{sccReceiveMessage, rmsg}: + case <-p.quit: + break out + } + // Handle each supported message type. + select { + case p.stallControl <- stallControlMsg{sccHandlerStart, rmsg}: + case <-p.quit: + break out + } + if disconnect := p.processInboundMessage(rmsg, buf); disconnect { + break out } select { case p.stallControl <- stallControlMsg{sccHandlerDone, rmsg}: @@ -1678,10 +1663,7 @@ out: p.sendQueue <- next case iv := <-p.outputInvChan: - // No handshake? They'll find out soon enough. - if p.HandshakeDone() { - invSendQueue = append(invSendQueue, iv) - } + invSendQueue = append(invSendQueue, iv) case <-trickleTimer.C: // Don't send anything if we're disconnecting or there @@ -1968,25 +1950,32 @@ func (p *Peer) Disconnect() { } } +// OnVersionCallback is an optional callback function that a caller may provide +// to receive the remote version message during the handshake process. See +// [Peer.Handshake] for details. +type OnVersionCallback func(*wire.MsgVersion) error + // readRemoteVersionMsg waits for the next message to arrive from the remote // peer. If the next message is not a version message or the version is not // acceptable then return an error. -func (p *Peer) readRemoteVersionMsg() error { +func (p *Peer) readRemoteVersionMsg(onVersion OnVersionCallback) error { // Read their version message. remoteMsg, _, err := p.readMessage() if err != nil { return err } - // Disconnect clients if the first message is not a version message. + // Disconnect client if the first message is not a version message. msg, ok := remoteMsg.(*wire.MsgVersion) if !ok { - return errors.New("a version message must precede all others") + const str = "a version message must precede all others" + return makeError(ErrNotVersionMessage, str) } // Detect self connections. if !allowSelfConns && sentNonces.Contains(msg.Nonce) { - return errors.New("disconnecting peer connected to self") + const str = "disconnecting peer connected to self" + return makeError(ErrSelfConnection, str) } // Negotiate the protocol version and set the services to what the remote @@ -1994,11 +1983,10 @@ func (p *Peer) readRemoteVersionMsg() error { p.flagsMtx.Lock() p.advertisedProtoVer = uint32(msg.ProtocolVersion) p.protocolVersion = minUint32(p.protocolVersion, p.advertisedProtoVer) - p.versionKnown = true p.remoteServices = msg.Services p.flagsMtx.Unlock() - log.Debugf("Negotiated protocol version %d for peer %s", - p.protocolVersion, p) + log.Debugf("Negotiated protocol version %d for peer %s", p.protocolVersion, + p) // Updating a bunch of stats. p.statsMtx.Lock() @@ -2016,15 +2004,73 @@ func (p *Peer) readRemoteVersionMsg() error { p.flagsMtx.Unlock() // Invoke the callback if specified. - if p.cfg.Listeners.OnVersion != nil { - p.cfg.Listeners.OnVersion(p, msg) + if onVersion != nil { + if err := onVersion(msg); err != nil { + return err + } } // Disconnect clients that have a protocol version that is too old. const reqProtocolVersion = int32(wire.RemoveRejectVersion) if msg.ProtocolVersion < reqProtocolVersion { - return fmt.Errorf("protocol version must be %d or greater", + str := fmt.Sprintf("protocol version must be %d or greater", reqProtocolVersion) + return makeError(ErrProtocolVerTooOld, str) + } + + return nil +} + +// readRemoteVerAckMsgLegacy reads and stores up to 3 messages until the verack +// is received to be processed after the handshake completes. +// +// Unfortunately, older protocol versions sometimes send other messages before +// the verack due to async processes being started before the handshake +// completes. +func (p *Peer) readRemoteVerAckMsgLegacy() error { + var verAckReceived bool + const maxNonVerAcks = 3 + for i := 0; i < maxNonVerAcks; i++ { + msg, buf, err := p.readMessage() + if err != nil { + return err + } + + _, ok := msg.(*wire.MsgVerAck) + if ok { + verAckReceived = true + break + } + + delayedMsg := delayedHandshakeMsg{msg, buf} + p.delayedHandshakeMsgsMtx.Lock() + p.delayedHandshakeMsgs = append(p.delayedHandshakeMsgs, delayedMsg) + p.delayedHandshakeMsgsMtx.Unlock() + } + if !verAckReceived { + str := fmt.Sprintf("the verack message must follow the version "+ + "message within %d messages", maxNonVerAcks) + return makeError(ErrNotVerAckMessage, str) + } + + return nil +} + +// readRemoteVerAckMsg waits for the next message to arrive from the remote +// peer and errors if it is not a verack message. +func (p *Peer) readRemoteVerAckMsg() error { + // Read their verack message. + remoteMsg, _, err := p.readMessage() + if err != nil { + return err + } + + // Disconnect clients if the second message is not a verack message. + _, ok := remoteMsg.(*wire.MsgVerAck) + if !ok { + const str = "the verack message must follow the version message " + + "and precede all others" + return makeError(ErrNotVerAckMessage, str) } return nil @@ -2136,11 +2182,17 @@ func (p *Peer) writeLocalVersionMsg() error { return nil } -// negotiateInboundProtocol waits to receive a version message from the peer -// then sends our version message. If the events do not occur in that order then -// it returns an error. -func (p *Peer) negotiateInboundProtocol() error { - if err := p.readRemoteVersionMsg(); err != nil { +// inboundHandshake waits to receive a version message from the remote peer then +// sends a local version message followed by a verack message to signal the +// remote version message was received and acceptable. Finally, it waits to +// receive the remote verack. No verack is sent if the version is deemed +// unacceptable. +// +// An error is returned when the events do not occur in the described order. +func (p *Peer) inboundHandshake(onVersion OnVersionCallback) error { + // Outbound peers are required to send the version message first, so inbound + // peers must expect to read it first. + if err := p.readRemoteVersionMsg(onVersion); err != nil { return err } @@ -2148,57 +2200,129 @@ func (p *Peer) negotiateInboundProtocol() error { return err } - p.flagsMtx.Lock() - p.handshakeDone = true - p.flagsMtx.Unlock() + if err := p.writeMessage(wire.NewMsgVerAck()); err != nil { + return err + } + + var readRemoteVerAckMsgFn = p.readRemoteVerAckMsg + if p.ProtocolVersion() < wire.AddrV2Version { + readRemoteVerAckMsgFn = p.readRemoteVerAckMsgLegacy + } + if err := readRemoteVerAckMsgFn(); err != nil { + return err + } return nil } -// negotiateOutboundProtocol sends our version message then waits to receive a -// version message from the peer. If the events do not occur in that order then -// it returns an error. -func (p *Peer) negotiateOutboundProtocol() error { +// outboundHandshake sends a local version message then waits to receive a +// version message from the remote peer followed by a verack from the remote +// peer that signals the local version message was received and deemed +// acceptable. Finally, it sends a verack message to signal the remote version +// was received and acceptable. No verack is sent if the version is deemed +// unacceptable. +// +// An error is returned when the events do not occur in the described order. +func (p *Peer) outboundHandshake(onVersion OnVersionCallback) error { + // Outbound peers are required to send the version message first. if err := p.writeLocalVersionMsg(); err != nil { return err } - if err := p.readRemoteVersionMsg(); err != nil { + if err := p.readRemoteVersionMsg(onVersion); err != nil { return err } - p.flagsMtx.Lock() - p.handshakeDone = true - p.flagsMtx.Unlock() + var readRemoteVerAckMsgFn = p.readRemoteVerAckMsg + if p.ProtocolVersion() < wire.AddrV2Version { + readRemoteVerAckMsgFn = p.readRemoteVerAckMsgLegacy + } + if err := readRemoteVerAckMsgFn(); err != nil { + return err + } + + if err := p.writeMessage(wire.NewMsgVerAck()); err != nil { + return err + } return nil } -// start begins processing input and output messages. -func (p *Peer) start() error { - log.Tracef("Starting peer %s", p) +// errHandshakeTimeout indicates the handshake process timed out. +var errHandshakeTimeout = makeError(ErrHandshakeTimeout, + "protocol handshake timeout") - negotiateErr := make(chan error, 1) +// Handshake performs the intitial handshake with a remote peer and returns an +// error if the handshake does not complete for any reason. It blocks until the +// handshake successfully completes or an error occurs. +// +// The peer will be disconnected when a non-nil error is returned. +// +// The caller may optionally provide a callback that will be invoked with the +// version message from the remote peer. Any errors returned from the callback +// will cause the handshake process to fail and will be returned from this +// function. Effectively, it provides the caller with an easy mechanism to +// reject peers based on whatever additional criteria they deem fit. +// +// NOTE: The code in the callback must be careful to avoid using exported +// methods that reference details that are not yet established. Generally +// speaking, anything that is set as a result of the information in the version +// message, with the exception of the negotiated protocol version, must not be +// relied on. +// +// More specifically, the callback must only rely on the following: +// - Anything set in the [NewInboundPeer] or [NewOutboundPeer] constructors +// such as [Peer.Addr] and [Peer.Inbound] +// - [Peer.ProtocolVersion] is the negotiated protocol version +// - [Peer.ID] +// +// On the other hand, the callback must NOT rely on any other methods such as: +// - [Peer.UserAgent] (part of the version message) +// - [Peer.LastBlock] (part of the version message) +// - [Peer.Services] (part of the version message) +// - [Peer.StartingHeight] (derived from version message) +// - [Peer.StatsSnapshot] (contains many of the aforementioned details) +// +// No callback for the verack message is provided because a successful handshake +// guarantees the verack was received. Thus, anything that would be invoked in +// response to verack, can be done when this function returns without error. +// +// This should only be called once when the peer is first connected. +// +// The caller MUST only start the async I/O processing with [Peer.Start] after +// this function returns without error. +func (p *Peer) Handshake(ctx context.Context, onVersion OnVersionCallback) error { + handshakeErr := make(chan error, 1) go func() { if p.inbound { - negotiateErr <- p.negotiateInboundProtocol() + handshakeErr <- p.inboundHandshake(onVersion) } else { - negotiateErr <- p.negotiateOutboundProtocol() + handshakeErr <- p.outboundHandshake(onVersion) } }() // Negotiate the protocol within the specified negotiateTimeout. select { - case err := <-negotiateErr: + case err := <-handshakeErr: if err != nil { p.Disconnect() return err } case <-time.After(negotiateTimeout): p.Disconnect() - return errors.New("protocol negotiation timeout") + return errHandshakeTimeout + case <-ctx.Done(): + p.Disconnect() + return errHandshakeTimeout } - log.Debugf("Connected to %s", p.Addr()) + + return nil +} + +// Start begins processing input and output messages. Callers MUST only call +// this after [Peer.Handshake] completes without error. +func (p *Peer) Start() { + log.Tracef("Starting peer %s", p) // The protocol has been negotiated successfully so start processing input // and output messages. @@ -2206,41 +2330,6 @@ func (p *Peer) start() error { go p.inHandler() go p.queueHandler() go p.outHandler() - - // Send our verack message now that the IO processing machinery has started. - p.QueueMessage(wire.NewMsgVerAck(), nil) - return nil -} - -// AssociateConnection associates the given conn to the peer. -// Calling this function when the peer is already connected will -// have no effect. -func (p *Peer) AssociateConnection(conn net.Conn) { - p.connMtx.Lock() - - // Already connected? - if p.conn != nil { - p.connMtx.Unlock() - return - } - - p.conn = conn - p.connMtx.Unlock() - - p.statsMtx.Lock() - p.timeConnected = time.Now() - p.statsMtx.Unlock() - - if p.inbound { - p.remoteAddr = p.conn.RemoteAddr() - } - - go func(peer *Peer) { - if err := peer.start(); err != nil { - log.Debugf("Cannot start peer %v: %v", peer, err) - peer.Disconnect() - } - }(p) } // WaitForDisconnect waits until the peer has completely disconnected and all @@ -2252,9 +2341,9 @@ func (p *Peer) WaitForDisconnect() { } // newPeerBase returns a new base Decred peer based on the inbound flag. This -// is used by the NewInboundPeer and NewOutboundPeer functions to perform base -// setup needed by both types of peers. -func newPeerBase(cfgOrig *Config, inbound bool) *Peer { +// is used by the [NewInboundPeer] and [NewOutboundPeer] functions to perform +// base setup needed by both types of peers. +func newPeerBase(cfgOrig *Config, conn net.Conn, inbound bool) *Peer { // Copy to avoid mutating the caller and so the caller can't mutate. cfg := *cfgOrig @@ -2278,9 +2367,11 @@ func newPeerBase(cfgOrig *Config, inbound bool) *Peer { p := Peer{ blake256Hasher: blake256.New(), + conn: conn, inbound: inbound, knownInventory: lru.NewSetWithDefaultTTL[wire.InvVect]( maxKnownInventory, maxKnownInventoryTTL), + timeConnected: time.Now(), stallControl: make(chan stallControlMsg, 1), // nonblocking sync outputQueue: make(chan outMsg, outputBufferSize), sendQueue: make(chan outMsg, 1), // nonblocking sync @@ -2296,15 +2387,20 @@ func newPeerBase(cfgOrig *Config, inbound bool) *Peer { return &p } -// NewInboundPeer returns a new inbound Decred peer. Use Start to begin -// processing incoming and outgoing messages. -func NewInboundPeer(cfg *Config) *Peer { - return newPeerBase(cfg, true) +// NewInboundPeer returns a new inbound Decred peer. Use [Peer.Handshake] to +// perform the initial version negotiation and then [Peer.Start] to begin +// processing incoming and outgoing messages when the handshake is successful. +func NewInboundPeer(cfg *Config, conn net.Conn) *Peer { + p := newPeerBase(cfg, conn, true) + p.remoteAddr = p.conn.RemoteAddr() + return p } -// NewOutboundPeer returns a new outbound Decred peer. -func NewOutboundPeer(cfg *Config, addr net.Addr) (*Peer, error) { - p := newPeerBase(cfg, false) +// NewOutboundPeer returns a new outbound Decred peer. Use [Peer.Handshake] to +// perform the initial version negotiation and then [Peer.Start] to begin +// processing incoming and outgoing messages when the handshake is successful. +func NewOutboundPeer(cfg *Config, addr net.Addr, conn net.Conn) *Peer { + p := newPeerBase(cfg, conn, false) p.remoteAddr = addr - return p, nil + return p } diff --git a/peer/peer_test.go b/peer/peer_test.go index 9277c03d9..2daf67d01 100644 --- a/peer/peer_test.go +++ b/peer/peer_test.go @@ -6,10 +6,13 @@ package peer import ( + "context" "encoding/binary" "errors" + "fmt" "io" "net" + "reflect" "strconv" "sync" "testing" @@ -80,26 +83,31 @@ func (m addr) String() string { return m.address } // pipe turns two mock connections into a full-duplex connection similar to // net.Pipe to allow pipe's with (fake) addresses. -func pipe(c1, c2 *conn) (*conn, *conn) { +func pipe(inAddr, outAddr string) (*conn, *conn) { r1, w1 := io.Pipe() r2, w2 := io.Pipe() - c1.WriteCloser = w1 - c2.ReadCloser = r1 - c1.ReadCloser = r2 - c2.WriteCloser = w2 - + c1 := &conn{laddr: inAddr, raddr: outAddr, WriteCloser: w1, ReadCloser: r2} + c2 := &conn{laddr: outAddr, raddr: inAddr, WriteCloser: w2, ReadCloser: r1} return c1, c2 } +// mockPeerConfig returns a base mock peer config to use throughout the tests. +func mockPeerConfig() *Config { + return &Config{ + UserAgentName: "peer", + UserAgentVersion: "1.0", + Net: wire.MainNet, + Services: wire.SFNodeNetwork, + } +} + // peerStats holds the expected peer stats used for testing peer. type peerStats struct { wantUserAgent string wantServices wire.ServiceFlag wantProtocolVersion uint32 wantConnected bool - wantVersionKnown bool - wantVerAckReceived bool wantLastBlock int64 wantStartingHeight int64 wantLastPingTime time.Time @@ -110,136 +118,142 @@ type peerStats struct { wantBytesReceived uint64 } -// testPeer tests the given peer's flags and stats. -func testPeer(t *testing.T, p *Peer, s peerStats) { - if p.UserAgent() != s.wantUserAgent { - t.Errorf("testPeer: wrong UserAgent - got %v, want %v", p.UserAgent(), s.wantUserAgent) - return +// runPeersAsync invokes the [Peer.Start] method on the passed peers in separate +// goroutines and returns a cancelable context and wait group the caller can use +// to shutdown the peers and wait for clean shutdown. +func runPeersAsync(peers ...*Peer) (context.CancelFunc, *sync.WaitGroup) { + ctx, cancel := context.WithCancel(context.Background()) + var wg sync.WaitGroup + wg.Add(len(peers)) + for _, peer := range peers { + go func(peer *Peer) { + peer.Start() + select { + case <-ctx.Done(): + peer.Disconnect() + case <-peer.quit: + } + wg.Done() + }(peer) } + return cancel, &wg +} - if p.Services() != s.wantServices { - t.Errorf("testPeer: wrong Services - got %v, want %v", p.Services(), s.wantServices) - return +// runHandshakes performs the initial handshake for all of the passed peers +// as a group in separate goroutines and returns when they all finish. If any +// handshake fails, all of them are stopped and the first non-nil error is +// returned. +func runHandshakes(peers ...*Peer) error { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + var wg sync.WaitGroup + wg.Add(len(peers)) + var errsMtx sync.Mutex + errs := make([]error, 0, len(peers)) + for _, peer := range peers { + go func(peer *Peer) { + defer wg.Done() + + negotiateCtx, negotiateCancel := context.WithTimeout(ctx, time.Second) + defer negotiateCancel() + + err := peer.Handshake(negotiateCtx, nil) + if err != nil { + cancel() + errsMtx.Lock() + errs = append(errs, err) + errsMtx.Unlock() + } + }(peer) } + wg.Wait() + return errors.Join(errs...) +} + +// testPeerState ensures the flags and state of the provided peer match the +// given stats. +func testPeerState(t *testing.T, p *Peer, s peerStats) { + t.Helper() - if !p.LastPingTime().Equal(s.wantLastPingTime) { - t.Errorf("testPeer: wrong LastPingTime - got %v, want %v", p.LastPingTime(), s.wantLastPingTime) - return + if got := p.UserAgent(); got != s.wantUserAgent { + t.Fatalf("wrong UserAgent - got %v, want %v", got, s.wantUserAgent) } - if p.LastPingNonce() != s.wantLastPingNonce { - t.Errorf("testPeer: wrong LastPingNonce - got %v, want %v", p.LastPingNonce(), s.wantLastPingNonce) - return + if got := p.Services(); got != s.wantServices { + t.Fatalf("wrong Services - got %v, want %v", got, s.wantServices) } - if p.LastPingMicros() != s.wantLastPingMicros { - t.Errorf("testPeer: wrong LastPingMicros - got %v, want %v", p.LastPingMicros(), s.wantLastPingMicros) - return + if got := p.LastPingTime(); !got.Equal(s.wantLastPingTime) { + t.Fatalf("wrong LastPingTime - got %v, want %v", got, s.wantLastPingTime) } - if p.VerAckReceived() != s.wantVerAckReceived { - t.Errorf("testPeer: wrong VerAckReceived - got %v, want %v", p.VerAckReceived(), s.wantVerAckReceived) - return + if got := p.LastPingNonce(); got != s.wantLastPingNonce { + t.Fatalf("wrong LastPingNonce - got %v, want %v", got, s.wantLastPingNonce) } - if p.VersionKnown() != s.wantVersionKnown { - t.Errorf("testPeer: wrong VersionKnown - got %v, want %v", p.VersionKnown(), s.wantVersionKnown) - return + if got := p.LastPingMicros(); got != s.wantLastPingMicros { + t.Fatalf("wrong LastPingMicros - got %v, want %v", got, s.wantLastPingMicros) } - if p.ProtocolVersion() != s.wantProtocolVersion { - t.Errorf("testPeer: wrong ProtocolVersion - got %v, want %v", p.ProtocolVersion(), s.wantProtocolVersion) - return + if got := p.ProtocolVersion(); got != s.wantProtocolVersion { + t.Fatalf("wrong ProtocolVersion - got %v, want %v", got, s.wantProtocolVersion) } - if p.LastBlock() != s.wantLastBlock { - t.Errorf("testPeer: wrong LastBlock - got %v, want %v", p.LastBlock(), s.wantLastBlock) - return + if got := p.LastBlock(); got != s.wantLastBlock { + t.Fatalf("wrong LastBlock - got %v, want %v", got, s.wantLastBlock) } // Allow for a deviation of 1s, as the second may tick when the message is // in transit and the protocol doesn't support any further precision. - if p.TimeOffset() != s.wantTimeOffset && p.TimeOffset() != s.wantTimeOffset-1 { - t.Errorf("testPeer: wrong TimeOffset - got %v, want %v or %v", p.TimeOffset(), + if got := p.TimeOffset(); got != s.wantTimeOffset && got != s.wantTimeOffset-1 { + t.Fatalf("wrong TimeOffset - got %v, want %v or %v", got, s.wantTimeOffset, s.wantTimeOffset-1) - return } - if p.BytesSent() != s.wantBytesSent { - t.Errorf("testPeer: wrong BytesSent - got %v, want %v", p.BytesSent(), s.wantBytesSent) - return + if got := p.BytesSent(); got != s.wantBytesSent { + t.Fatalf("wrong BytesSent - got %v, want %v", got, s.wantBytesSent) } - if p.BytesReceived() != s.wantBytesReceived { - t.Errorf("testPeer: wrong BytesReceived - got %v, want %v", p.BytesReceived(), s.wantBytesReceived) - return + if got := p.BytesReceived(); got != s.wantBytesReceived { + t.Fatalf("wrong BytesReceived - got %v, want %v", got, s.wantBytesReceived) } - if p.StartingHeight() != s.wantStartingHeight { - t.Errorf("testPeer: wrong StartingHeight - got %v, want %v", p.StartingHeight(), s.wantStartingHeight) - return + if got := p.StartingHeight(); got != s.wantStartingHeight { + t.Fatalf("wrong StartingHeight - got %v, want %v", got, s.wantStartingHeight) } - if p.Connected() != s.wantConnected { - t.Errorf("testPeer: wrong Connected - got %v, want %v", p.Connected(), s.wantConnected) - return + if got := p.Connected(); got != s.wantConnected { + t.Fatalf("wrong Connected - got %v, want %v", got, s.wantConnected) } stats := p.StatsSnapshot() - if p.ID() != stats.ID { - t.Errorf("testPeer: wrong ID - got %v, want %v", p.ID(), stats.ID) - return + if got := p.ID(); got != stats.ID { + t.Fatalf("wrong ID - got %v, want %v", got, stats.ID) } - if p.Addr() != stats.Addr { - t.Errorf("testPeer: wrong Addr - got %v, want %v", p.Addr(), stats.Addr) - return + if got := p.Addr(); got != stats.Addr { + t.Fatalf("wrong Addr - got %v, want %v", got, stats.Addr) } - if p.LastSend() != stats.LastSend { - t.Errorf("testPeer: wrong LastSend - got %v, want %v", p.LastSend(), stats.LastSend) - return + if got := p.LastSend(); got != stats.LastSend { + t.Fatalf("wrong LastSend - got %v, want %v", got, stats.LastSend) } - if p.LastRecv() != stats.LastRecv { - t.Errorf("testPeer: wrong LastRecv - got %v, want %v", p.LastRecv(), stats.LastRecv) - return + if got := p.LastRecv(); got != stats.LastRecv { + t.Fatalf("wrong LastRecv - got %v, want %v", got, stats.LastRecv) } } -// TestPeerConnection tests connection between inbound and outbound peers. -func TestPeerConnection(t *testing.T) { - var pause sync.Mutex - verack := make(chan struct{}) - peerCfg := &Config{ - Listeners: MessageListeners{ - OnVerAck: func(p *Peer, msg *wire.MsgVerAck) { - verack <- struct{}{} - }, - OnWrite: func(p *Peer, bytesWritten int, msg wire.Message, - err error) { - if _, ok := msg.(*wire.MsgVerAck); ok { - verack <- struct{}{} - } - pause.Lock() - // Needed to squash empty critical section lint errors. - _ = p - pause.Unlock() - }, - }, - UserAgentName: "peer", - UserAgentVersion: "1.0", - Net: wire.MainNet, - Services: 0, - } +// TestPeerHandshake tests the handshake between inbound and outbound peers. +func TestPeerHandshake(t *testing.T) { + peerCfg := mockPeerConfig() + peerCfg.Services = 0 wantStats := peerStats{ wantUserAgent: wire.DefaultUserAgent + "peer:1.0/", wantServices: 0, wantProtocolVersion: MaxProtocolVersion, wantConnected: true, - wantVersionKnown: true, - wantVerAckReceived: true, wantLastPingTime: time.Time{}, wantLastPingNonce: uint64(0), wantLastPingMicros: int64(0), @@ -249,215 +263,207 @@ func TestPeerConnection(t *testing.T) { } tests := []struct { name string - setup func() (*Peer, *Peer, error) + setup func() (*Peer, *Peer) }{{ "basic handshake", - func() (*Peer, *Peer, error) { - inConn, outConn := pipe( - &conn{raddr: "10.0.0.1:8333"}, - &conn{raddr: "10.0.0.2:8333"}, - ) - inPeer := NewInboundPeer(peerCfg) - inPeer.AssociateConnection(inConn) - - outPeer, err := NewOutboundPeer(peerCfg, outConn.RemoteAddr()) - if err != nil { - return nil, nil, err - } - outPeer.AssociateConnection(outConn) - - for i := 0; i < 4; i++ { - select { - case <-verack: - case <-time.After(time.Second): - return nil, nil, errors.New("verack timeout") - } - } - return inPeer, outPeer, nil + func() (*Peer, *Peer) { + inConn, outConn := pipe("10.0.0.1:9108", "10.0.0.2:9108") + inPeer := NewInboundPeer(peerCfg, inConn) + outPeer := NewOutboundPeer(peerCfg, outConn.RemoteAddr(), outConn) + return inPeer, outPeer }, }, { "socks proxy", - func() (*Peer, *Peer, error) { - inConn, outConn := pipe( - &conn{raddr: "10.0.0.1:8333", proxy: true}, - &conn{raddr: "10.0.0.2:8333"}, - ) - inPeer := NewInboundPeer(peerCfg) - inPeer.AssociateConnection(inConn) - - outPeer, err := NewOutboundPeer(peerCfg, outConn.RemoteAddr()) - if err != nil { - return nil, nil, err - } - outPeer.AssociateConnection(outConn) - - for i := 0; i < 4; i++ { - select { - case <-verack: - case <-time.After(time.Second): - return nil, nil, errors.New("verack timeout") - } - } - return inPeer, outPeer, nil + func() (*Peer, *Peer) { + inConn, outConn := pipe("10.0.0.1:9108", "10.0.0.2:9108") + inConn.proxy = true + inPeer := NewInboundPeer(peerCfg, inConn) + outPeer := NewOutboundPeer(peerCfg, outConn.RemoteAddr(), outConn) + return inPeer, outPeer }, }} t.Logf("Running %d tests", len(tests)) - for i, test := range tests { - inPeer, outPeer, err := test.setup() - if err != nil { - t.Errorf("TestPeerConnection setup #%d: unexpected err %v", i, err) - return + for _, test := range tests { + inPeer, outPeer := test.setup() + if err := runHandshakes(inPeer, outPeer); err != nil { + t.Fatalf("%q: failed to perform handshake: %v", test.name, err) } - - pause.Lock() - testPeer(t, inPeer, wantStats) - testPeer(t, outPeer, wantStats) - pause.Unlock() - - inPeer.Disconnect() - outPeer.Disconnect() - inPeer.WaitForDisconnect() - outPeer.WaitForDisconnect() + testPeerState(t, inPeer, wantStats) + testPeerState(t, outPeer, wantStats) } } -// TestPeerListeners tests that the peer listeners are called as expected. -func TestPeerListeners(t *testing.T) { - verack := make(chan struct{}, 1) - ok := make(chan wire.Message, 20) +// TestPeerHandshakeCallback ensures the handshake callback is invoked properly +// and results in a handshake failure when it returns an error. +func TestPeerHandshakeCallback(t *testing.T) { peerCfg := &Config{ - Listeners: MessageListeners{ - OnGetAddr: func(p *Peer, msg *wire.MsgGetAddr) { - ok <- msg - }, - OnAddr: func(p *Peer, msg *wire.MsgAddr) { - ok <- msg - }, - OnAddrV2: func(p *Peer, msg *wire.MsgAddrV2) { - ok <- msg - }, - OnPing: func(p *Peer, msg *wire.MsgPing) { - ok <- msg - }, - OnPong: func(p *Peer, msg *wire.MsgPong) { - ok <- msg - }, - OnMemPool: func(p *Peer, msg *wire.MsgMemPool) { - ok <- msg - }, - OnTx: func(p *Peer, msg *wire.MsgTx) { - ok <- msg - }, - OnBlock: func(p *Peer, msg *wire.MsgBlock, buf []byte) { - ok <- msg - }, - OnInv: func(p *Peer, msg *wire.MsgInv) { - ok <- msg - }, - OnHeaders: func(p *Peer, msg *wire.MsgHeaders) { - ok <- msg - }, - OnNotFound: func(p *Peer, msg *wire.MsgNotFound) { - ok <- msg - }, - OnGetData: func(p *Peer, msg *wire.MsgGetData) { - ok <- msg - }, - OnGetBlocks: func(p *Peer, msg *wire.MsgGetBlocks) { - ok <- msg - }, - OnGetHeaders: func(p *Peer, msg *wire.MsgGetHeaders) { - ok <- msg - }, - OnGetCFilter: func(p *Peer, msg *wire.MsgGetCFilter) { - ok <- msg - }, - OnGetCFHeaders: func(p *Peer, msg *wire.MsgGetCFHeaders) { - ok <- msg - }, - OnGetCFTypes: func(p *Peer, msg *wire.MsgGetCFTypes) { - ok <- msg - }, - OnCFilter: func(p *Peer, msg *wire.MsgCFilter) { - ok <- msg - }, - OnCFHeaders: func(p *Peer, msg *wire.MsgCFHeaders) { - ok <- msg - }, - OnCFTypes: func(p *Peer, msg *wire.MsgCFTypes) { - ok <- msg - }, - OnFeeFilter: func(p *Peer, msg *wire.MsgFeeFilter) { - ok <- msg - }, - OnVersion: func(p *Peer, msg *wire.MsgVersion) { - ok <- msg - }, - OnVerAck: func(p *Peer, msg *wire.MsgVerAck) { - verack <- struct{}{} - }, - OnSendHeaders: func(p *Peer, msg *wire.MsgSendHeaders) { - ok <- msg - }, - OnGetCFilterV2: func(p *Peer, msg *wire.MsgGetCFilterV2) { - ok <- msg - }, - OnCFilterV2: func(p *Peer, msg *wire.MsgCFilterV2) { - ok <- msg - }, - OnGetInitState: func(p *Peer, msg *wire.MsgGetInitState) { - ok <- msg - }, - OnInitState: func(p *Peer, msg *wire.MsgInitState) { - ok <- msg - }, - OnGetCFiltersV2: func(p *Peer, msg *wire.MsgGetCFsV2) { - ok <- msg - }, - OnCFiltersV2: func(p *Peer, msg *wire.MsgCFiltersV2) { - ok <- msg - }, - }, UserAgentName: "peer", UserAgentVersion: "1.0", Net: wire.MainNet, - Services: wire.SFNodeBloom, + Services: 0, } - inConn, outConn := pipe( - &conn{raddr: "10.0.0.1:8333"}, - &conn{raddr: "10.0.0.2:8333"}, - ) - inPeer := NewInboundPeer(peerCfg) - inPeer.AssociateConnection(inConn) + inConn, outConn := pipe("10.0.0.1:9108", "10.0.0.2:9108") + outPeer := NewOutboundPeer(peerCfg, outConn.RemoteAddr(), outConn) + inPeer := NewInboundPeer(peerCfg, inConn) - peerCfg.Listeners = MessageListeners{ - OnVerAck: func(p *Peer, msg *wire.MsgVerAck) { - verack <- struct{}{} - }, - } - outPeer, err := NewOutboundPeer(peerCfg, outConn.RemoteAddr()) - if err != nil { - t.Errorf("NewOutboundPeer: unexpected err %v\n", err) - return + // Ensure the handshake version callback is invoked. + inErr := make(chan error, 1) + version := make(chan struct{}) + go func() { + ctx := context.Background() + err := inPeer.Handshake(ctx, func(msg *wire.MsgVersion) error { + close(version) + return nil + }) + inErr <- err + }() + if err := runHandshakes(outPeer); err != nil { + t.Fatalf("failed to perform handshake: %v", err) } - outPeer.AssociateConnection(outConn) - for i := 0; i < 2; i++ { - select { - case <-verack: - case <-time.After(time.Second * 1): - t.Error("TestPeerListeners: verack timeout\n") - return + select { + case err := <-inErr: + if err != nil { + t.Fatalf("failed to perform inPeer handshake: %v", err) } + case <-time.After(time.Second * 1): + t.Fatal("inPeer handshake error timeout") } select { - case <-ok: + case <-version: case <-time.After(time.Second * 1): - t.Error("TestPeerListeners: version timeout") - return + t.Fatal("version timeout") + } + + // Ensure returning an error from the handshake version callback results in + // handshake failure. + rejectHandshakeErr := fmt.Errorf("rejected in handshake callback") + outPeer = NewOutboundPeer(peerCfg, outConn.RemoteAddr(), outConn) + inPeer = NewInboundPeer(peerCfg, inConn) + go func() { + ctx := context.Background() + err := inPeer.Handshake(ctx, func(msg *wire.MsgVersion) error { + return rejectHandshakeErr + }) + inErr <- err + }() + if err := runHandshakes(outPeer); !errors.Is(err, io.EOF) { + t.Fatalf("did not receive expected err, got: %v, want: %v", err, io.EOF) } + select { + case err := <-inErr: + if !errors.Is(err, rejectHandshakeErr) { + t.Fatalf("did not receive expected err, got: %v, want: %v", err, + rejectHandshakeErr) + } + case <-time.After(time.Second * 1): + t.Fatal("inPeer handshake error timeout") + } +} + +// TestPeerListeners tests that the peer listeners are called as expected. +func TestPeerListeners(t *testing.T) { + ok := make(chan wire.Message, 20) + peerCfg := mockPeerConfig() + peerCfg.Listeners = MessageListeners{ + OnGetAddr: func(p *Peer, msg *wire.MsgGetAddr) { + ok <- msg + }, + OnAddr: func(p *Peer, msg *wire.MsgAddr) { + ok <- msg + }, + OnAddrV2: func(p *Peer, msg *wire.MsgAddrV2) { + ok <- msg + }, + OnPing: func(p *Peer, msg *wire.MsgPing) { + ok <- msg + }, + OnPong: func(p *Peer, msg *wire.MsgPong) { + ok <- msg + }, + OnMemPool: func(p *Peer, msg *wire.MsgMemPool) { + ok <- msg + }, + OnTx: func(p *Peer, msg *wire.MsgTx) { + ok <- msg + }, + OnBlock: func(p *Peer, msg *wire.MsgBlock, buf []byte) { + ok <- msg + }, + OnInv: func(p *Peer, msg *wire.MsgInv) { + ok <- msg + }, + OnHeaders: func(p *Peer, msg *wire.MsgHeaders) { + ok <- msg + }, + OnNotFound: func(p *Peer, msg *wire.MsgNotFound) { + ok <- msg + }, + OnGetData: func(p *Peer, msg *wire.MsgGetData) { + ok <- msg + }, + OnGetBlocks: func(p *Peer, msg *wire.MsgGetBlocks) { + ok <- msg + }, + OnGetHeaders: func(p *Peer, msg *wire.MsgGetHeaders) { + ok <- msg + }, + OnGetCFilter: func(p *Peer, msg *wire.MsgGetCFilter) { + ok <- msg + }, + OnGetCFHeaders: func(p *Peer, msg *wire.MsgGetCFHeaders) { + ok <- msg + }, + OnGetCFTypes: func(p *Peer, msg *wire.MsgGetCFTypes) { + ok <- msg + }, + OnCFilter: func(p *Peer, msg *wire.MsgCFilter) { + ok <- msg + }, + OnCFHeaders: func(p *Peer, msg *wire.MsgCFHeaders) { + ok <- msg + }, + OnCFTypes: func(p *Peer, msg *wire.MsgCFTypes) { + ok <- msg + }, + OnFeeFilter: func(p *Peer, msg *wire.MsgFeeFilter) { + ok <- msg + }, + OnSendHeaders: func(p *Peer, msg *wire.MsgSendHeaders) { + ok <- msg + }, + OnGetCFilterV2: func(p *Peer, msg *wire.MsgGetCFilterV2) { + ok <- msg + }, + OnCFilterV2: func(p *Peer, msg *wire.MsgCFilterV2) { + ok <- msg + }, + OnGetInitState: func(p *Peer, msg *wire.MsgGetInitState) { + ok <- msg + }, + OnInitState: func(p *Peer, msg *wire.MsgInitState) { + ok <- msg + }, + OnGetCFiltersV2: func(p *Peer, msg *wire.MsgGetCFsV2) { + ok <- msg + }, + OnCFiltersV2: func(p *Peer, msg *wire.MsgCFiltersV2) { + ok <- msg + }, + } + inConn, outConn := pipe("10.0.0.1:9108", "10.0.0.2:9108") + inPeer := NewInboundPeer(peerCfg, inConn) + peerCfg.Listeners = MessageListeners{} + outPeer := NewOutboundPeer(peerCfg, outConn.RemoteAddr(), outConn) + if err := runHandshakes(inPeer, outPeer); err != nil { + t.Fatalf("failed to perform handshake: %v", err) + } + cancel, wg := runPeersAsync(inPeer, outPeer) + defer wg.Wait() + defer cancel() + const pver = wire.ProtocolVersion tests := []struct { listener string @@ -619,65 +625,34 @@ func TestPeerListeners(t *testing.T) { // Queue the test message outPeer.QueueMessage(test.msg, nil) select { - case <-ok: + case got := <-ok: + if reflect.TypeOf(got) != reflect.TypeOf(test.msg) { + t.Fatalf("wrong message type: got %T, want %T", got, test.msg) + } case <-time.After(time.Second * 1): - t.Errorf("TestPeerListeners: %s timeout", test.listener) - return + t.Fatalf("%s timeout", test.listener) } } - inPeer.Disconnect() - outPeer.Disconnect() } // TestOldProtocolVersion ensures that peers with protocol versions older than // the minimum required version are disconnected. func TestOldProtocolVersion(t *testing.T) { - version := make(chan wire.Message, 1) - verack := make(chan struct{}, 1) - peerCfg := &Config{ - ProtocolVersion: wire.RemoveRejectVersion - 1, - Listeners: MessageListeners{ - OnVersion: func(p *Peer, msg *wire.MsgVersion) { - version <- msg - }, - OnVerAck: func(p *Peer, msg *wire.MsgVerAck) { - verack <- struct{}{} - }, - }, - UserAgentName: "peer", - UserAgentVersion: "1.0", - Net: wire.MainNet, - Services: wire.SFNodeNetwork, - } - inConn, outConn := pipe( - &conn{raddr: "10.0.0.1:8333"}, - &conn{raddr: "10.0.0.2:8333"}, - ) - inPeer := NewInboundPeer(peerCfg) - inPeer.AssociateConnection(inConn) - defer inPeer.Disconnect() - + const minVer = wire.RemoveRejectVersion + peerCfg := mockPeerConfig() + peerCfg.ProtocolVersion = minVer - 1 + inConn, outConn := pipe("10.0.0.1:9108", "10.0.0.2:9108") + inPeer := NewInboundPeer(peerCfg, inConn) peerCfg.Listeners = MessageListeners{} - outPeer, err := NewOutboundPeer(peerCfg, outConn.RemoteAddr()) - if err != nil { - t.Errorf("NewOutboundPeer: unexpected err %v", err) - return + outPeer := NewOutboundPeer(peerCfg, outConn.RemoteAddr(), outConn) + wantErr := ErrProtocolVerTooOld + if err := runHandshakes(inPeer, outPeer); !errors.Is(err, wantErr) { + t.Fatalf("unexpected handshake error -- got: %v, want %v", err, wantErr) } - outPeer.AssociateConnection(outConn) - defer outPeer.Disconnect() - select { - case <-version: - case <-time.After(time.Second * 1): - t.Fatal("version timeout") - } - - // Ensure the inbound peer is disconnected and does not receive a verack - // from the outbound side. + // Ensure the inbound peer is disconnected. select { case <-inPeer.quit: - case <-verack: - t.Fatal("unexpected verack from outbound peer") case <-time.After(time.Second * 1): t.Fatal("inbound peer disconnect timeout") } @@ -690,128 +665,67 @@ func TestOldProtocolVersion(t *testing.T) { } } -// TestOutboundPeer tests that the outbound peer works as expected. -func TestOutboundPeer(t *testing.T) { - peerCfg := &Config{ - NewestBlock: func() (*chainhash.Hash, int64, error) { - return nil, 0, errors.New("newest block not found") - }, - UserAgentName: "peer", - UserAgentVersion: "1.0", - Net: wire.MainNet, - Services: 0, - } - - r, w := io.Pipe() - c := &conn{raddr: "10.0.0.1:8333", WriteCloser: w, ReadCloser: r} - - p, err := NewOutboundPeer(peerCfg, c.RemoteAddr()) - if err != nil { - t.Errorf("NewOutboundPeer: unexpected err - %v\n", err) - return - } - - // Test trying to connect twice. - p.AssociateConnection(c) - p.AssociateConnection(c) - +// TestNoNewestBlock ensures peers are disconnected due to an error returned by +// the caller-provided newest block callback. +func TestNoNewestBlock(t *testing.T) { + // Create a pair of peers that connect to each other using a fake conn such + // that the inbound peer has a newest block callback that errors. + errNoNewestBlock := errors.New("newest block not found") + peerCfg := mockPeerConfig() + inConn, outConn := pipe("10.0.0.1:9108", "10.0.0.1:9108") + outPeer := NewOutboundPeer(peerCfg, outConn.RemoteAddr(), outConn) + peerCfg.NewestBlock = func() (*chainhash.Hash, int64, error) { + return nil, 0, errNoNewestBlock + } + inPeer := NewInboundPeer(peerCfg, inConn) + wantErr := errNoNewestBlock + if err := runHandshakes(inPeer, outPeer); !errors.Is(err, wantErr) { + t.Fatalf("unexpected handshake error -- got: %v, want %v", err, wantErr) + } + + // Ensure the inbound peer disconnects due to the error. disconnected := make(chan struct{}) go func() { - p.WaitForDisconnect() + inPeer.WaitForDisconnect() disconnected <- struct{}{} }() select { case <-disconnected: - close(disconnected) case <-time.After(time.Second): - t.Fatal("Peer did not automatically disconnect.") + t.Fatal("peer did not automatically disconnect") } - if p.Connected() { - t.Fatalf("Should not be connected as NewestBlock produces error.") + if inPeer.Connected() { + t.Fatal("inbound peer should not be connected") } - // Test Queue Inv - fakeBlockHash := &chainhash.Hash{0: 0x00, 1: 0x01} - fakeInv := wire.NewInvVect(wire.InvTypeBlock, fakeBlockHash) - - // Should be noops as the peer could not connect. - p.QueueInventory(fakeInv) - p.AddKnownInventory(fakeInv) - p.QueueInventory(fakeInv) - - fakeMsg := wire.NewMsgVerAck() - p.QueueMessage(fakeMsg, nil) - done := make(chan struct{}) - p.QueueMessage(fakeMsg, done) - <-done - p.Disconnect() - - // Test NewestBlock - newestBlock := func() (*chainhash.Hash, int64, error) { - hashStr := "14a0810ac680a3eb3f82edc878cea25ec41d6b790744e5daeef" - hash, err := chainhash.NewHashFromStr(hashStr) - if err != nil { - return nil, 0, err - } - return hash, 234439, nil + // Repeat, but in the other direction so the outbound peer has the error. + inConn, outConn = pipe("10.0.0.1:9108", "10.0.0.1:9108") + outPeer = NewOutboundPeer(peerCfg, outConn.RemoteAddr(), outConn) + peerCfg.NewestBlock = nil + inPeer = NewInboundPeer(peerCfg, inConn) + if err := runHandshakes(inPeer, outPeer); !errors.Is(err, wantErr) { + t.Fatalf("unexpected handshake error -- got: %v, want %v", err, wantErr) } - peerCfg.NewestBlock = newestBlock - r1, w1 := io.Pipe() - c1 := &conn{raddr: "10.0.0.1:8333", WriteCloser: w1, ReadCloser: r1} - p1, err := NewOutboundPeer(peerCfg, c1.RemoteAddr()) - if err != nil { - t.Errorf("NewOutboundPeer: unexpected err - %v\n", err) - return - } - p1.AssociateConnection(c1) - - // Test Queue Inv after connection - p1.QueueInventory(fakeInv) - p1.Disconnect() + // Ensure the outbound peer disconnects due to the error. + disconnected = make(chan struct{}) + go func() { + outPeer.WaitForDisconnect() + disconnected <- struct{}{} + }() - // Test testnet - peerCfg.Net = wire.TestNet3 - peerCfg.Services = wire.SFNodeBloom - r2, w2 := io.Pipe() - c2 := &conn{raddr: "10.0.0.1:8333", WriteCloser: w2, ReadCloser: r2} - p2, err := NewOutboundPeer(peerCfg, c2.RemoteAddr()) - if err != nil { - t.Errorf("NewOutboundPeer: unexpected err - %v\n", err) - return + select { + case <-disconnected: + close(disconnected) + case <-time.After(time.Second): + t.Fatal("peer did not automatically disconnect") } - p2.AssociateConnection(c2) - // Test PushXXX - var addrs []*wire.NetAddress - for i := 0; i < 5; i++ { - na := wire.NetAddress{} - addrs = append(addrs, &na) - } - if _, err := p2.PushAddrMsg(addrs); err != nil { - t.Errorf("PushAddrMsg: unexpected err %v\n", err) - return - } - if err := p2.PushGetBlocksMsg(nil, &chainhash.Hash{}); err != nil { - t.Errorf("PushGetBlocksMsg: unexpected err %v\n", err) - return - } - if err := p2.PushGetHeadersMsg(nil, &chainhash.Hash{}); err != nil { - t.Errorf("PushGetHeadersMsg: unexpected err %v\n", err) - return + if outPeer.Connected() { + t.Fatal("outbound peer should not be connected") } - - // Test Queue Messages - p2.QueueMessage(wire.NewMsgGetAddr(), nil) - p2.QueueMessage(wire.NewMsgPing(1), nil) - p2.QueueMessage(wire.NewMsgMemPool(), nil) - p2.QueueMessage(wire.NewMsgGetData(), nil) - p2.QueueMessage(wire.NewMsgGetHeaders(), nil) - p2.QueueMessage(wire.NewMsgFeeFilter(20000), nil) - - p2.Disconnect() } // TestDuplicateVersionMsg ensures that receiving a version message after one @@ -819,38 +733,16 @@ func TestOutboundPeer(t *testing.T) { func TestDuplicateVersionMsg(t *testing.T) { // Create a pair of peers that are connected to each other using a fake // connection. - verack := make(chan struct{}) - peerCfg := &Config{ - Listeners: MessageListeners{ - OnVerAck: func(p *Peer, msg *wire.MsgVerAck) { - verack <- struct{}{} - }, - }, - UserAgentName: "peer", - UserAgentVersion: "1.0", - Net: wire.MainNet, - Services: 0, - } - inConn, outConn := pipe( - &conn{laddr: "10.0.0.1:9108", raddr: "10.0.0.2:9108"}, - &conn{laddr: "10.0.0.2:9108", raddr: "10.0.0.1:9108"}, - ) - outPeer, err := NewOutboundPeer(peerCfg, outConn.RemoteAddr()) - if err != nil { - t.Fatalf("NewOutboundPeer: unexpected err: %v\n", err) - } - outPeer.AssociateConnection(outConn) - inPeer := NewInboundPeer(peerCfg) - inPeer.AssociateConnection(inConn) - - // Wait for the veracks from the initial protocol version negotiation. - for i := 0; i < 2; i++ { - select { - case <-verack: - case <-time.After(time.Second): - t.Fatal("verack timeout") - } + peerCfg := mockPeerConfig() + inConn, outConn := pipe("10.0.0.1:9108", "10.0.0.1:9108") + outPeer := NewOutboundPeer(peerCfg, outConn.RemoteAddr(), outConn) + inPeer := NewInboundPeer(peerCfg, inConn) + if err := runHandshakes(inPeer, outPeer); err != nil { + t.Fatalf("failed to perform handshake: %v", err) } + cancel, wg := runPeersAsync(inPeer, outPeer) + defer wg.Wait() + defer cancel() // Queue a duplicate version message from the outbound peer and wait until // it is sent. @@ -867,7 +759,7 @@ func TestDuplicateVersionMsg(t *testing.T) { disconnected := make(chan struct{}, 1) go func() { inPeer.WaitForDisconnect() - disconnected <- struct{}{} + close(disconnected) }() select { case <-disconnected: @@ -887,16 +779,19 @@ func TestNetFallback(t *testing.T) { UserAgentVersion: "1.0", Services: 0, } + inConn, outConn := pipe("10.0.0.1:9108", "10.0.0.2:9108") + defer inConn.Close() + defer outConn.Close() // Ensure testnet is used when the network is not specified. - p := NewInboundPeer(&cfg) + p := NewInboundPeer(&cfg, inConn) if p.cfg.Net != wire.TestNet3 { t.Fatalf("default network is %v instead of testnet3", p.cfg.Net) } // Ensure network is set to the explicitly specified value. cfg.Net = wire.SimNet - p = NewInboundPeer(&cfg) + p = NewInboundPeer(&cfg, inConn) if p.cfg.Net != wire.SimNet { t.Fatalf("explicit network is %v instead of %v", p.cfg.Net, wire.SimNet) } @@ -909,13 +804,7 @@ func TestUpdateLastBlockHeight(t *testing.T) { // Create a pair of peers that are connected to each other using a fake // connection and the remote peer starting at height 100. const remotePeerHeight = 100 - verack := make(chan struct{}) peerCfg := Config{ - Listeners: MessageListeners{ - OnVerAck: func(p *Peer, msg *wire.MsgVerAck) { - verack <- struct{}{} - }, - }, UserAgentName: "peer", UserAgentVersion: "1.0", Net: wire.MainNet, @@ -925,26 +814,15 @@ func TestUpdateLastBlockHeight(t *testing.T) { remotePeerCfg.NewestBlock = func() (*chainhash.Hash, int64, error) { return &chainhash.Hash{}, remotePeerHeight, nil } - inConn, outConn := pipe( - &conn{laddr: "10.0.0.1:9108", raddr: "10.0.0.2:9108"}, - &conn{laddr: "10.0.0.2:9108", raddr: "10.0.0.1:9108"}, - ) - localPeer, err := NewOutboundPeer(&peerCfg, outConn.RemoteAddr()) - if err != nil { - t.Fatalf("NewOutboundPeer: unexpected err: %v\n", err) - } - localPeer.AssociateConnection(outConn) - inPeer := NewInboundPeer(&remotePeerCfg) - inPeer.AssociateConnection(inConn) - - // Wait for the veracks from the initial protocol version negotiation. - for i := 0; i < 2; i++ { - select { - case <-verack: - case <-time.After(time.Second): - t.Fatal("verack timeout") - } + inConn, outConn := pipe("10.0.0.1:9108", "10.0.0.2:9108") + localPeer := NewOutboundPeer(&peerCfg, outConn.RemoteAddr(), outConn) + inPeer := NewInboundPeer(&remotePeerCfg, inConn) + if err := runHandshakes(localPeer, inPeer); err != nil { + t.Fatalf("failed to perform handshake: %v", err) } + cancel, wg := runPeersAsync(localPeer, inPeer) + defer wg.Wait() + defer cancel() // Ensure the latest block height starts at the value reported by the remote // peer via its version message. @@ -1013,22 +891,22 @@ func TestPushAddrV2Msg(t *testing.T) { }} // Create a mock connection. - _, outConn := pipe( - &conn{laddr: "10.0.0.1:9108", raddr: "10.0.0.2:9108"}, - &conn{laddr: "10.0.0.2:9108", raddr: "10.0.0.1:9108"}, - ) + inConn, outConn := pipe("10.0.0.1:9108", "10.0.0.2:9108") // Create a peer with the connection. - cfg := &Config{} - peer, err := NewOutboundPeer(cfg, outConn.RemoteAddr()) - if err != nil { - t.Fatalf("NewOutboundPeer: unexpected err: %v", err) + peerCfg := mockPeerConfig() + inPeer := NewInboundPeer(peerCfg, inConn) + outPeer := NewOutboundPeer(peerCfg, outConn.RemoteAddr(), outConn) + if err := runHandshakes(inPeer, outPeer); err != nil { + t.Fatalf("failed to perform handshake: %v", err) } - peer.AssociateConnection(outConn) + cancel, wg := runPeersAsync(inPeer, outPeer) + defer wg.Wait() + defer cancel() for _, test := range tests { // Test the PushAddrV2Msg function. - sent := peer.PushAddrV2Msg(test.addrs) + sent := outPeer.PushAddrV2Msg(test.addrs) // Check the number of addresses sent. if got := len(sent); got != test.wantSentLen { @@ -1036,9 +914,6 @@ func TestPushAddrV2Msg(t *testing.T) { test.wantSentLen, got) } } - - peer.Disconnect() - peer.WaitForDisconnect() } func init() { diff --git a/server.go b/server.go index e85438f60..403919bce 100644 --- a/server.go +++ b/server.go @@ -668,9 +668,6 @@ type serverPeer struct { isWhitelisted bool quit chan struct{} - handshakeDone chan struct{} - closeHandshakeDoneOnce sync.Once - // syncMgrPeer houses the network sync manager peer instance that wraps the // underlying peer similar to the way this server peer itself wraps it. syncMgrPeer *netsync.Peer @@ -727,7 +724,6 @@ func newServerPeer(s *server, remoteAddr *addrmgr.NetAddress, isPersistent bool) persistent: isPersistent, knownAddresses: apbf.NewFilter(maxKnownAddrsPerPeer, knownAddrsFPRate), quit: make(chan struct{}), - handshakeDone: make(chan struct{}), getDataQueue: make(chan []*wire.InvVect, maxConcurrentGetDataReqs), } } @@ -914,7 +910,7 @@ func (sp *serverPeer) serveGetData() { // the peer has disconnected and performs other associated cleanup such as // evicting any remaining orphans sent by the peer and shutting down all // goroutines. -func (sp *serverPeer) Run() { +func (sp *serverPeer) Run(ctx context.Context) { var wg sync.WaitGroup wg.Add(1) go func() { @@ -922,6 +918,13 @@ func (sp *serverPeer) Run() { wg.Done() }() + // Start processing async I/O. + sp.Start() + + // Request all block annoucements via full headers instead of the inv + // message. + sp.QueueMessage(wire.NewMsgSendHeaders(), nil) + // Add valid peer to the server. sp.server.AddPeer(sp) @@ -932,14 +935,11 @@ func (sp *serverPeer) Run() { srvr.DonePeer(sp) srvr.syncManager.OnPeerDisconnected(sp.syncMgrPeer) - if sp.VersionKnown() { - // Evict any remaining mempool orphans that were sent by the peer. - numEvicted := srvr.txMemPool.RemoveOrphansByTag(mempool.Tag(sp.ID())) - if numEvicted > 0 { - srvrLog.Debugf("Evicted %d mempool %s from peer %v (id %d)", - numEvicted, pickNoun(numEvicted, "orphan", "orphans"), sp, - sp.ID()) - } + // Evict any remaining mempool orphans that were sent by the peer. + numEvicted := srvr.txMemPool.RemoveOrphansByTag(mempool.Tag(sp.ID())) + if numEvicted > 0 { + srvrLog.Debugf("Evicted %d mempool %s from peer %v (id %d)", numEvicted, + pickNoun(numEvicted, "orphan", "orphans"), sp, sp.ID()) } // Shutdown remaining peer goroutines. @@ -1178,7 +1178,7 @@ func natfSupported(pver uint32) addrmgr.NetAddressTypeFilter { // OnVersion is invoked when a peer receives a version wire message and is used // to negotiate the protocol version details as well as kick start the // communications. -func (sp *serverPeer) OnVersion(_ *peer.Peer, msg *wire.MsgVersion) { +func (sp *serverPeer) OnVersion(msg *wire.MsgVersion) error { // Update the address manager with the advertised services for outbound // connections in case they have changed. This is not done for inbound // connections to help prevent malicious behavior and is skipped when @@ -1201,11 +1201,8 @@ func (sp *serverPeer) OnVersion(_ *peer.Peer, msg *wire.MsgVersion) { // Reject peers that have a protocol version that is too old. const reqProtocolVersion = int32(wire.RemoveRejectVersion) if msg.ProtocolVersion < reqProtocolVersion { - srvrLog.Debugf("Rejecting peer %s with protocol version %d prior to "+ - "the required version %d", sp, msg.ProtocolVersion, - reqProtocolVersion) - sp.Disconnect() - return + return fmt.Errorf("rejecting protocol version %d prior to the required "+ + "version %d", msg.ProtocolVersion, reqProtocolVersion) } // Maintain a minimum desired number of outbound peers capable of supporting @@ -1231,11 +1228,10 @@ func (sp *serverPeer) OnVersion(_ *peer.Peer, msg *wire.MsgVersion) { needsMoreMixCapable := !hasMinMixCapableOuts && numOutbound+wantMixCapableOutbound >= sp.server.targetOutbound if needsMoreMixCapable { - srvrLog.Debugf("Rejecting outbound peer %s with protocol version "+ + return fmt.Errorf("rejecting outbound peer with protocol version "+ "%d in favor of a peer with minimum version %d (have: %d, "+ - "target: %d)", sp, msg.ProtocolVersion, wire.MixVersion, + "target: %d)", msg.ProtocolVersion, wire.MixVersion, numMixCapableOutbound, wantMixCapableOutbound) - sp.Disconnect() } } @@ -1243,10 +1239,8 @@ func (sp *serverPeer) OnVersion(_ *peer.Peer, msg *wire.MsgVersion) { wantServices := wire.SFNodeNetwork if !isInbound && !hasServices(msg.Services, wantServices) { missingServices := wantServices & ^msg.Services - srvrLog.Debugf("Rejecting peer %s with services %v due to not "+ - "providing desired services %v", sp, msg.Services, missingServices) - sp.Disconnect() - return + return fmt.Errorf("rejecting peer with services %v due to not "+ + "providing desired services %v", msg.Services, missingServices) } // Update the address manager and request known addresses from the @@ -1293,14 +1287,7 @@ func (sp *serverPeer) OnVersion(_ *peer.Peer, msg *wire.MsgVersion) { // Add the remote peer time as a sample for creating an offset against // the local clock to keep the network time in sync. sp.server.timeSource.AddTimeSample(sp.Addr(), msg.Timestamp) -} - -// OnVerAck is invoked when a peer receives a verack wire message. It creates -// and sends a sendheaders message to request all block annoucements are made -// via full headers instead of the inv message. -func (sp *serverPeer) OnVerAck(_ *peer.Peer, msg *wire.MsgVerAck) { - sp.closeHandshakeDoneOnce.Do(func() { close(sp.handshakeDone) }) - sp.QueueMessage(wire.NewMsgSendHeaders(), nil) + return nil } // OnMemPool is invoked when a peer receives a mempool wire message. It creates @@ -2479,8 +2466,6 @@ func newPeerConfig(sp *serverPeer) *peer.Config { return &peer.Config{ Listeners: peer.MessageListeners{ - OnVersion: sp.OnVersion, - OnVerAck: sp.OnVerAck, OnMemPool: sp.OnMemPool, OnGetMiningState: sp.OnGetMiningState, OnMiningState: sp.OnMiningState, @@ -2538,7 +2523,7 @@ func newPeerConfig(sp *serverPeer) *peer.Config { // connection is established. It initializes a new inbound server peer // instance, associates it with the connection, and starts all additional server // peer processing goroutines. -func (s *server) inboundPeerConnected(conn net.Conn) { +func (s *server) inboundPeerConnected(ctx context.Context, conn net.Conn) { remoteNetAddr, err := connToNetAddr(conn) if err != nil { srvrLog.Debugf("Unable to create inbound peer for address %s: %v", @@ -2554,16 +2539,15 @@ func (s *server) inboundPeerConnected(conn net.Conn) { sp := newServerPeer(s, remoteNetAddr, false) sp.isWhitelisted = isWhitelisted(conn.RemoteAddr()) - sp.Peer = peer.NewInboundPeer(newPeerConfig(sp)) - sp.AssociateConnection(conn) - select { - case <-sp.handshakeDone: - case <-time.After(30 * time.Second): - srvrLog.Debugf("Handshake timeout for inbound peer %s", conn.RemoteAddr()) + sp.Peer = peer.NewInboundPeer(newPeerConfig(sp), conn) + if err := sp.Handshake(ctx, sp.OnVersion); err != nil { + srvrLog.Debugf("Failed handshake for inbound peer %s: %v", + remoteNetAddr, err) + conn.Close() return } sp.syncMgrPeer = netsync.NewPeer(sp.Peer) - go sp.Run() + sp.Run(ctx) } // outboundPeerConnected is invoked by the connection manager when a new @@ -2571,41 +2555,35 @@ func (s *server) inboundPeerConnected(conn net.Conn) { // peer instance, associates it with the relevant state such as the connection // request instance and the connection itself, and start all additional server // peer processing goroutines. -func (s *server) outboundPeerConnected(c *connmgr.ConnReq, conn net.Conn) { +func (s *server) outboundPeerConnected(ctx context.Context, c *connmgr.ConnReq, conn net.Conn) { remoteNetAddr, err := connToNetAddr(conn) if err != nil { srvrLog.Debugf("Unable to create outbound peer for address %s: %v", conn.RemoteAddr(), err) conn.Close() + s.connManager.Disconnect(c.ID()) } // Disconnect banned connections. Ideally we would never connect to a // banned peer, but the connection manager is currently unaware of banned // addresses, so this is needed. if disconnected := s.handleBannedConn(conn); disconnected { + s.connManager.Disconnect(c.ID()) return } sp := newServerPeer(s, remoteNetAddr, c.Permanent) - p, err := peer.NewOutboundPeer(newPeerConfig(sp), c.Addr) - if err != nil { - srvrLog.Debugf("Cannot create outbound peer %s: %v", c.Addr, err) - s.connManager.Disconnect(c.ID()) - return - } + p := peer.NewOutboundPeer(newPeerConfig(sp), c.Addr, conn) sp.Peer = p sp.connReq.Store(c) sp.isWhitelisted = isWhitelisted(conn.RemoteAddr()) - sp.AssociateConnection(conn) - select { - case <-sp.handshakeDone: - case <-time.After(30 * time.Second): - srvrLog.Debugf("Handshake timeout from outbound peer %s", c.Addr) + if err := sp.Handshake(ctx, sp.OnVersion); err != nil { + srvrLog.Debugf("Failed handshake for outbound peer %s: %v", c.Addr, err) s.connManager.Disconnect(c.ID()) return } sp.syncMgrPeer = netsync.NewPeer(sp.Peer) - go sp.Run() + sp.Run(ctx) } // peerHandler is used to handle peer operations such as inventory relay and @@ -2709,7 +2687,6 @@ func (s *server) handleAddPeer(sp *serverPeer) bool { na := sp.peerNa.Load() // Add the new peer. - srvrLog.Debugf("New peer %s", sp) if sp.Inbound() { state.inboundPeers[sp.ID()] = sp @@ -2835,10 +2812,8 @@ func (s *server) DonePeer(sp *serverPeer) { list = state.outboundPeers } if _, ok := list[sp.ID()]; ok { - if !sp.Inbound() && sp.VersionKnown() { - state.outboundGroups[sp.remoteAddr.GroupKey()]-- - } if !sp.Inbound() { + state.outboundGroups[sp.remoteAddr.GroupKey()]-- connReq := sp.connReq.Load() if connReq != nil { s.connManager.Disconnect(connReq.ID()) @@ -2854,12 +2829,11 @@ func (s *server) DonePeer(sp *serverPeer) { s.connManager.Disconnect(connReq.ID()) } - // Update the address manager with the last seen time when the peer has - // acknowledged our version and has sent us its version as well. This is - // skipped when running on the simulation and regression test networks since - // they are only intended to connect to specified peers and actively avoid + // Update the address manager with the last seen time. This is skipped when + // running on the simulation and regression test networks since they are + // only intended to connect to specified peers and actively avoid // advertising and connecting to discovered peers. - if !cfg.SimNet && !cfg.RegNet && sp.VerAckReceived() && sp.VersionKnown() { + if !cfg.SimNet && !cfg.RegNet { err := s.addrManager.Connected(sp.remoteAddr) if err != nil { srvrLog.Errorf("Marking address as connected failed: %v", err) @@ -4506,14 +4480,18 @@ func newServer(ctx context.Context, profiler *profileServer, s.targetOutbound = uint32(cfg.MaxPeers) } cmgr, err := connmgr.New(&connmgr.Config{ - Listeners: listeners, - OnAccept: s.inboundPeerConnected, + Listeners: listeners, + OnAccept: func(conn net.Conn) { + s.inboundPeerConnected(ctx, conn) + }, RetryDuration: connectionRetryInterval, TargetOutbound: s.targetOutbound, Dial: s.attemptDcrdDial, Timeout: cfg.DialTimeout, - OnConnection: s.outboundPeerConnected, - GetNewAddress: newAddressFunc, + OnConnection: func(c *connmgr.ConnReq, conn net.Conn) { + s.outboundPeerConnected(ctx, c, conn) + }, + GetNewAddress: newAddressFunc, }) if err != nil { return nil, err