diff --git a/internal/connmgr/connmanager_test.go b/internal/connmgr/connmanager_test.go index 8a499d340..2153fff81 100644 --- a/internal/connmgr/connmanager_test.go +++ b/internal/connmgr/connmanager_test.go @@ -199,8 +199,10 @@ func TestConnectMode(t *testing.T) { // configuration option by waiting until all connections are established and // ensuring they are the only connections made. func TestTargetOutbound(t *testing.T) { - targetOutbound := uint32(10) - connected := make(chan *ConnReq) + const targetOutbound = 10 + var numConnections atomic.Uint32 + hitTargetConns := make(chan struct{}) + extraConns := make(chan *ConnReq) cmgr, err := New(&Config{ TargetOutbound: targetOutbound, Dial: mockDialer, @@ -211,7 +213,14 @@ func TestTargetOutbound(t *testing.T) { }, nil }, OnConnection: func(c *ConnReq, conn net.Conn) { - connected <- c + totalConnections := numConnections.Add(1) + if totalConnections == targetOutbound { + close(hitTargetConns) + return + } + if totalConnections > targetOutbound { + extraConns <- c + } }, }) if err != nil { @@ -220,13 +229,15 @@ func TestTargetOutbound(t *testing.T) { _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) // Wait for the expected number of target outbound conns to be established. - for i := uint32(0); i < targetOutbound; i++ { - <-connected + select { + case <-hitTargetConns: + case <-time.After(20 * time.Millisecond): + t.Fatal("did not reach target number of conns before timeout") } // Ensure no additional connections are made. select { - case c := <-connected: + case c := <-extraConns: t.Fatalf("target outbound: got unexpected connection - %v", c.Addr) case <-time.After(time.Millisecond * 5): break @@ -241,7 +252,11 @@ func TestTargetOutbound(t *testing.T) { // any address object returned by GetNewAddress will be correctly passed along // to DialAddr to be used for connecting to a host. func TestPassAddrAlongDialAddr(t *testing.T) { - connected := make(chan *ConnReq) + dailedAddr := make(chan net.Addr) + detectDialer := func(ctx context.Context, addr net.Addr) (net.Conn, error) { + dailedAddr <- addr + return nil, errors.New("error") + } // targetAddr will be the specific address we'll use to connect. It _could_ // be carrying more info than a standard (tcp/udp) network address, so it @@ -253,13 +268,10 @@ func TestPassAddrAlongDialAddr(t *testing.T) { cmgr, err := New(&Config{ TargetOutbound: 1, - DialAddr: mockDialerAddr, + DialAddr: detectDialer, GetNewAddress: func() (net.Addr, error) { return targetAddr, nil }, - OnConnection: func(c *ConnReq, conn net.Conn) { - connected <- c - }, }) if err != nil { t.Fatalf("New error: %v", err) @@ -267,8 +279,8 @@ func TestPassAddrAlongDialAddr(t *testing.T) { _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) select { - case c := <-connected: - receivedMock, isMockAddr := c.Addr.(mockAddr) + case addr := <-dailedAddr: + receivedMock, isMockAddr := addr.(mockAddr) if !isMockAddr { t.Fatal("connected to an address that was not a mockAddr") }