Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 25 additions & 13 deletions internal/connmgr/connmanager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,10 @@ func TestConnectMode(t *testing.T) {
// configuration option by waiting until all connections are established and
// ensuring they are the only connections made.
func TestTargetOutbound(t *testing.T) {
targetOutbound := uint32(10)
connected := make(chan *ConnReq)
const targetOutbound = 10
var numConnections atomic.Uint32
hitTargetConns := make(chan struct{})
extraConns := make(chan *ConnReq)
cmgr, err := New(&Config{
TargetOutbound: targetOutbound,
Dial: mockDialer,
Expand All @@ -211,7 +213,14 @@ func TestTargetOutbound(t *testing.T) {
}, nil
},
OnConnection: func(c *ConnReq, conn net.Conn) {
connected <- c
totalConnections := numConnections.Add(1)
if totalConnections == targetOutbound {
close(hitTargetConns)
return
}
if totalConnections > targetOutbound {
extraConns <- c
}
},
})
if err != nil {
Expand All @@ -220,13 +229,15 @@ func TestTargetOutbound(t *testing.T) {
_, shutdown, wg := runConnMgrAsync(context.Background(), cmgr)

// Wait for the expected number of target outbound conns to be established.
for i := uint32(0); i < targetOutbound; i++ {
<-connected
select {
case <-hitTargetConns:
case <-time.After(20 * time.Millisecond):
t.Fatal("did not reach target number of conns before timeout")
}

// Ensure no additional connections are made.
select {
case c := <-connected:
case c := <-extraConns:
t.Fatalf("target outbound: got unexpected connection - %v", c.Addr)
case <-time.After(time.Millisecond * 5):
break
Expand All @@ -241,7 +252,11 @@ func TestTargetOutbound(t *testing.T) {
// any address object returned by GetNewAddress will be correctly passed along
// to DialAddr to be used for connecting to a host.
func TestPassAddrAlongDialAddr(t *testing.T) {
connected := make(chan *ConnReq)
dailedAddr := make(chan net.Addr)
detectDialer := func(ctx context.Context, addr net.Addr) (net.Conn, error) {
dailedAddr <- addr
return nil, errors.New("error")
}

// targetAddr will be the specific address we'll use to connect. It _could_
// be carrying more info than a standard (tcp/udp) network address, so it
Expand All @@ -253,22 +268,19 @@ func TestPassAddrAlongDialAddr(t *testing.T) {

cmgr, err := New(&Config{
TargetOutbound: 1,
DialAddr: mockDialerAddr,
DialAddr: detectDialer,
GetNewAddress: func() (net.Addr, error) {
return targetAddr, nil
},
OnConnection: func(c *ConnReq, conn net.Conn) {
connected <- c
},
})
if err != nil {
t.Fatalf("New error: %v", err)
}
_, shutdown, wg := runConnMgrAsync(context.Background(), cmgr)

select {
case c := <-connected:
receivedMock, isMockAddr := c.Addr.(mockAddr)
case addr := <-dailedAddr:
receivedMock, isMockAddr := addr.(mockAddr)
if !isMockAddr {
t.Fatal("connected to an address that was not a mockAddr")
}
Expand Down