diff --git a/gbn/config.go b/gbn/config.go index 1c5b68ec..a10ac9bc 100644 --- a/gbn/config.go +++ b/gbn/config.go @@ -78,6 +78,27 @@ func WithBoostPercent(boostPercent float32) TimeoutOptions { } } +// WithDynamicPongTimeout enables dynamic pong timeout based on observed RTT. +// When enabled, the pong timeout is computed as max(basePongTime, +// pongMultiplier * smoothedRTT), capped at maxPongTime. This is useful for +// relay-based transports where the round-trip time through the relay can vary +// significantly based on network conditions. +func WithDynamicPongTimeout(pongMultiplier int, + maxPongTime time.Duration) TimeoutOptions { + + return func(manager *TimeoutManager) { + manager.dynamicPongTime = true + + if pongMultiplier > 0 { + manager.pongMultiplier = pongMultiplier + } + + if maxPongTime > 0 { + manager.maxPongTime = maxPongTime + } + } +} + // config holds the configuration values for an instance of GoBackNConn. type config struct { // n is the window size. The sender can send a maximum of n packets diff --git a/gbn/gbn_conn.go b/gbn/gbn_conn.go index 7ff3e06e..8e4f57e5 100644 --- a/gbn/gbn_conn.go +++ b/gbn/gbn_conn.go @@ -424,8 +424,11 @@ func (g *GoBackNConn) sendPacketsForever() error { default: } - // Start the pong timer. - g.pongTicker.Reset() + // Start the pong timer. We use ResetWithInterval + // to pick up any dynamic pong timeout changes + // based on observed RTT. + pongTime := g.timeoutManager.GetPongTime() + g.pongTicker.ResetWithInterval(pongTime) g.pongTicker.Resume() // Also reset the ping timer. diff --git a/gbn/gbn_conn_test.go b/gbn/gbn_conn_test.go index 70b36853..79f14422 100644 --- a/gbn/gbn_conn_test.go +++ b/gbn/gbn_conn_test.go @@ -1180,6 +1180,265 @@ func TestPayloadSplitting(t *testing.T) { require.True(t, bytes.Equal(msg, payload1)) } +// TestBackwardsCompatMixedTimeouts ensures that a new client with increased +// ping/pong timeouts and dynamic pong can communicate with an old server using +// the original static timeout values. GBN timeouts are configured independently +// on each side and are not negotiated, so mixed versions should be compatible. +func TestBackwardsCompatMixedTimeouts(t *testing.T) { + s1Chan := make(chan []byte, 10) + s2Chan := make(chan []byte, 10) + + s1Read := func(ctx context.Context) ([]byte, error) { + select { + case val := <-s1Chan: + return val, nil + case <-ctx.Done(): + } + return nil, nil + } + + s1Write := func(ctx context.Context, b []byte) error { + select { + case s1Chan <- b: + return nil + case <-ctx.Done(): + } + return nil + } + + s2Read := func(ctx context.Context) ([]byte, error) { + select { + case val := <-s2Chan: + return val, nil + case <-ctx.Done(): + } + return nil, nil + } + + s2Write := func(ctx context.Context, b []byte) error { + select { + case s2Chan <- b: + return nil + case <-ctx.Done(): + } + return nil + } + + // Old server timeouts (pre-fix values). + oldServerOpts := []Option{ + WithTimeoutOptions( + WithKeepalivePing( + 5*time.Second, 3*time.Second, + ), + ), + } + + // New client timeouts (post-fix values with dynamic pong). + newClientOpts := []Option{ + WithTimeoutOptions( + WithKeepalivePing( + 10*time.Second, 5*time.Second, + ), + WithDynamicPongTimeout(3, 15*time.Second), + ), + } + + ctx := context.Background() + + var ( + server *GoBackNConn + wg sync.WaitGroup + srvErr error + ) + + wg.Add(1) + go func() { + defer wg.Done() + + server, srvErr = NewServerConn( + ctx, s1Write, s2Read, oldServerOpts..., + ) + }() + + time.Sleep(200 * time.Millisecond) + + client, err := NewClientConn( + ctx, 2, s2Write, s1Read, newClientOpts..., + ) + require.NoError(t, err) + + wg.Wait() + require.NoError(t, srvErr) + + defer func() { + client.Close() + server.Close() + }() + + // Verify bidirectional communication works with mixed timeouts. + payload1 := []byte("new client -> old server") + payload2 := []byte("old server -> new client") + + sendErrCh := make(chan error, 1) + go func() { + sendErrCh <- server.Send(payload2) + }() + + err = client.Send(payload1) + require.NoError(t, err) + + msg, err := server.Recv() + require.NoError(t, err) + require.True(t, bytes.Equal(msg, payload1)) + + msg, err = client.Recv() + require.NoError(t, err) + require.True(t, bytes.Equal(msg, payload2)) + + require.NoError(t, <-sendErrCh) + + // Send multiple messages to exercise the RTT tracking that feeds the + // dynamic pong timeout. + for i := 0; i < 5; i++ { + payload := []byte("round trip " + string(rune('0'+i))) + + sendErrCh := make(chan error, 1) + go func() { + sendErrCh <- server.Send(payload) + }() + + err = client.Send(payload) + require.NoError(t, err) + + msg, err = server.Recv() + require.NoError(t, err) + require.True(t, bytes.Equal(msg, payload)) + + msg, err = client.Recv() + require.NoError(t, err) + require.True(t, bytes.Equal(msg, payload)) + + require.NoError(t, <-sendErrCh) + } +} + +// TestBackwardsCompatOldClientNewServer tests the reverse direction: an old +// client with the original timeout values connecting to a new server with +// increased timeouts and dynamic pong. +func TestBackwardsCompatOldClientNewServer(t *testing.T) { + s1Chan := make(chan []byte, 10) + s2Chan := make(chan []byte, 10) + + s1Read := func(ctx context.Context) ([]byte, error) { + select { + case val := <-s1Chan: + return val, nil + case <-ctx.Done(): + } + return nil, nil + } + + s1Write := func(ctx context.Context, b []byte) error { + select { + case s1Chan <- b: + return nil + case <-ctx.Done(): + } + return nil + } + + s2Read := func(ctx context.Context) ([]byte, error) { + select { + case val := <-s2Chan: + return val, nil + case <-ctx.Done(): + } + return nil, nil + } + + s2Write := func(ctx context.Context, b []byte) error { + select { + case s2Chan <- b: + return nil + case <-ctx.Done(): + } + return nil + } + + // New server timeouts (post-fix values with dynamic pong). + newServerOpts := []Option{ + WithTimeoutOptions( + WithKeepalivePing( + 8*time.Second, 5*time.Second, + ), + WithDynamicPongTimeout(3, 15*time.Second), + ), + } + + // Old client timeouts (pre-fix values). + oldClientOpts := []Option{ + WithTimeoutOptions( + WithKeepalivePing( + 7*time.Second, 3*time.Second, + ), + ), + } + + ctx := context.Background() + + var ( + server *GoBackNConn + wg sync.WaitGroup + srvErr error + ) + + wg.Add(1) + go func() { + defer wg.Done() + + server, srvErr = NewServerConn( + ctx, s1Write, s2Read, newServerOpts..., + ) + }() + + time.Sleep(200 * time.Millisecond) + + client, err := NewClientConn( + ctx, 2, s2Write, s1Read, oldClientOpts..., + ) + require.NoError(t, err) + + wg.Wait() + require.NoError(t, srvErr) + + defer func() { + client.Close() + server.Close() + }() + + // Verify bidirectional communication works. + payload1 := []byte("old client -> new server") + payload2 := []byte("new server -> old client") + + sendErrCh := make(chan error, 1) + go func() { + sendErrCh <- server.Send(payload2) + }() + + err = client.Send(payload1) + require.NoError(t, err) + + msg, err := server.Recv() + require.NoError(t, err) + require.True(t, bytes.Equal(msg, payload1)) + + msg, err = client.Recv() + require.NoError(t, err) + require.True(t, bytes.Equal(msg, payload2)) + + require.NoError(t, <-sendErrCh) +} + func setUpClientServerConns(t *testing.T, n uint8, cRead, sRead func(ctx context.Context) ([]byte, error), cWrite, sWrite func(ctx context.Context, b []byte) error, diff --git a/gbn/messages.go b/gbn/messages.go index 8ba495e7..a1db19e8 100644 --- a/gbn/messages.go +++ b/gbn/messages.go @@ -159,7 +159,7 @@ func Deserialize(b []byte) (Message, error) { switch b[0] { case DATA: - if len(b) < 3 { + if len(b) < 4 { return nil, io.EOF } return &PacketData{ diff --git a/gbn/rapid_test.go b/gbn/rapid_test.go new file mode 100644 index 00000000..59b8585b --- /dev/null +++ b/gbn/rapid_test.go @@ -0,0 +1,532 @@ +package gbn + +import ( + "math" + "testing" + "time" + + "github.com/stretchr/testify/require" + "pgregory.net/rapid" +) + +// drawDuration draws a random time.Duration in the given range using rapid. +func drawDuration(t *rapid.T, min, max time.Duration, + label string) time.Duration { + + return time.Duration( + rapid.Int64Range(int64(min), int64(max)).Draw(t, label), + ) +} + +// TestRapidMessageRoundTrip checks that for any valid GBN message, the +// serialize-then-deserialize round trip is lossless. +func TestRapidMessageRoundTrip(t *testing.T) { + t.Parallel() + + rapid.Check(t, func(t *rapid.T) { + msg := genMessage(t) + + serialized, err := msg.Serialize() + require.NoError(t, err) + + deserialized, err := Deserialize(serialized) + require.NoError(t, err) + + require.Equal(t, msg, deserialized) + }) +} + +// TestRapidDeserializeNeverPanics checks that Deserialize never panics on +// arbitrary input bytes, only returning errors or valid messages. +func TestRapidDeserializeNeverPanics(t *testing.T) { + t.Parallel() + + rapid.Check(t, func(t *rapid.T) { + data := rapid.SliceOf(rapid.Byte()).Draw(t, "data") + + // This must not panic regardless of input. + msg, err := Deserialize(data) + if err != nil { + return + } + + // If deserialization succeeded, the message should be + // re-serializable. + _, err = msg.Serialize() + require.NoError(t, err) + }) +} + +// TestRapidContainsSequenceProperties verifies key invariants of the modular +// arithmetic sequence containment check used by the GBN queue. +func TestRapidContainsSequenceProperties(t *testing.T) { + t.Parallel() + + // Property: base is never contained when base == top (empty queue). + rapid.Check(t, func(t *rapid.T) { + base := rapid.Uint8().Draw(t, "base") + require.False(t, containsSequence(base, base, base)) + }) + + // Property: base is always contained when queue is non-empty. + rapid.Check(t, func(t *rapid.T) { + base := rapid.Uint8().Draw(t, "base") + top := rapid.Uint8().Draw(t, "top") + if base == top { + t.Skip() + } + require.True(t, containsSequence(base, top, base)) + }) + + // Property: top is never contained (half-open interval [base, top)). + rapid.Check(t, func(t *rapid.T) { + base := rapid.Uint8().Draw(t, "base") + top := rapid.Uint8().Draw(t, "top") + if base == top { + t.Skip() + } + require.False(t, containsSequence(base, top, top)) + }) + + // Property: (top-1) mod 256 is always contained when non-empty. + rapid.Check(t, func(t *rapid.T) { + base := rapid.Uint8().Draw(t, "base") + top := rapid.Uint8().Draw(t, "top") + if base == top { + t.Skip() + } + lastValid := top - 1 // uint8 wraps naturally + require.True(t, containsSequence(base, top, lastValid)) + }) +} + +// TestRapidContainsSequenceBruteForce validates containsSequence against a +// simple reference implementation that enumerates the half-open interval. +func TestRapidContainsSequenceBruteForce(t *testing.T) { + t.Parallel() + + rapid.Check(t, func(t *rapid.T) { + base := rapid.Uint8().Draw(t, "base") + top := rapid.Uint8().Draw(t, "top") + seq := rapid.Uint8().Draw(t, "seq") + + expected := refContains(base, top, seq) + got := containsSequence(base, top, seq) + require.Equal(t, expected, got, + "base=%d top=%d seq=%d", base, top, seq) + }) +} + +// refContains is a reference implementation of sequence containment that +// enumerates the half-open interval [base, top) in uint8 space, wrapping +// at 256 to match the production containsSequence behavior. +func refContains(base, top, seq uint8) bool { + if base == top { + return false + } + cur := base + for { + if cur == seq { + return true + } + cur++ // uint8 wraps at 256, matching containsSequence. + if cur == top { + return false + } + } +} + +// TestRapidDynamicPongTimeoutProperties verifies the key invariants of the +// dynamic pong timeout computation. +func TestRapidDynamicPongTimeoutProperties(t *testing.T) { + t.Parallel() + + // Property: pong timeout is always >= basePongTime when pingTime >= + // basePong (i.e., the pingTime cap doesn't interfere). + rapid.Check(t, func(t *rapid.T) { + basePong := drawDuration( + t, time.Millisecond, 10*time.Second, "basePong", + ) + pingTime := drawDuration( + t, basePong, 60*time.Second, "pingTime", + ) + multiplier := rapid.IntRange(1, 10).Draw(t, "mult") + maxPong := drawDuration( + t, basePong, 60*time.Second, "maxPong", + ) + rtt := drawDuration( + t, time.Millisecond, 20*time.Second, "rtt", + ) + + tm := NewTimeOutManager( + nil, + WithKeepalivePing(pingTime, basePong), + WithDynamicPongTimeout(multiplier, maxPong), + ) + + // Inject RTT. + tm.mu.Lock() + tm.smoothedRTT = rtt + tm.rttInitialized = true + tm.mu.Unlock() + + pong := tm.GetPongTime() + + require.GreaterOrEqual(t, int64(pong), int64(basePong), + "pong=%v must be >= basePong=%v (rtt=%v, ping=%v)", + pong, basePong, rtt, pingTime) + }) + + // Property: pong timeout is always <= maxPongTime when max is set. + rapid.Check(t, func(t *rapid.T) { + basePong := drawDuration( + t, time.Millisecond, 5*time.Second, "basePong", + ) + multiplier := rapid.IntRange(1, 10).Draw(t, "mult") + maxPong := drawDuration( + t, basePong, 60*time.Second, "maxPong", + ) + rtt := drawDuration( + t, time.Millisecond, 20*time.Second, "rtt", + ) + + tm := NewTimeOutManager( + nil, + WithKeepalivePing(10*time.Second, basePong), + WithDynamicPongTimeout(multiplier, maxPong), + ) + + tm.mu.Lock() + tm.smoothedRTT = rtt + tm.rttInitialized = true + tm.mu.Unlock() + + pong := tm.GetPongTime() + + require.LessOrEqual(t, int64(pong), int64(maxPong), + "pong=%v must be <= maxPong=%v (rtt=%v, mult=%d)", + pong, maxPong, rtt, multiplier) + }) + + // Property: pong timeout never exceeds ping interval. + rapid.Check(t, func(t *rapid.T) { + basePong := drawDuration( + t, time.Millisecond, 5*time.Second, "basePong", + ) + pingTime := drawDuration( + t, time.Millisecond, 30*time.Second, "pingTime", + ) + multiplier := rapid.IntRange(1, 10).Draw(t, "mult") + maxPong := drawDuration( + t, basePong, 60*time.Second, "maxPong", + ) + rtt := drawDuration( + t, time.Millisecond, 20*time.Second, "rtt", + ) + + tm := NewTimeOutManager( + nil, + WithKeepalivePing(pingTime, basePong), + WithDynamicPongTimeout(multiplier, maxPong), + ) + + tm.mu.Lock() + tm.smoothedRTT = rtt + tm.rttInitialized = true + tm.mu.Unlock() + + pong := tm.GetPongTime() + + require.LessOrEqual(t, int64(pong), int64(pingTime), + "pong=%v must be <= pingTime=%v (rtt=%v, mult=%d, "+ + "base=%v, max=%v)", + pong, pingTime, rtt, multiplier, basePong, maxPong) + }) + + // Property: with zero RTT, pong falls back to static base. + rapid.Check(t, func(t *rapid.T) { + basePong := drawDuration( + t, time.Millisecond, 30*time.Second, "basePong", + ) + + tm := NewTimeOutManager( + nil, + WithKeepalivePing(10*time.Second, basePong), + WithDynamicPongTimeout(3, 15*time.Second), + ) + + pong := tm.GetPongTime() + require.Equal(t, basePong, pong) + }) + + // Property: without dynamic enabled, pong is always the static base. + rapid.Check(t, func(t *rapid.T) { + basePong := drawDuration( + t, time.Millisecond, 30*time.Second, "basePong", + ) + rtt := drawDuration( + t, time.Millisecond, 20*time.Second, "rtt", + ) + + tm := NewTimeOutManager( + nil, + WithKeepalivePing(10*time.Second, basePong), + ) + + // Even with RTT injected, static mode ignores it. + tm.mu.Lock() + tm.smoothedRTT = rtt + tm.rttInitialized = true + tm.mu.Unlock() + + pong := tm.GetPongTime() + require.Equal(t, basePong, pong) + }) +} + +// TestRapidEWMAProperties verifies EWMA smoothed RTT invariants under random +// sample sequences. +func TestRapidEWMAProperties(t *testing.T) { + t.Parallel() + + // Property: smoothedRTT is always bounded by [min(samples), max(samples)] + // after at least one sample. + rapid.Check(t, func(t *rapid.T) { + nSamples := rapid.IntRange(1, 50).Draw(t, "nSamples") + + tm := NewTimeOutManager(nil) + + var minSample, maxSample time.Duration + + tm.mu.Lock() + for i := 0; i < nSamples; i++ { + sample := drawDuration( + t, time.Millisecond, 30*time.Second, + "sample", + ) + + if i == 0 { + minSample = sample + maxSample = sample + } else { + if sample < minSample { + minSample = sample + } + if sample > maxSample { + maxSample = sample + } + } + + tm.updateSmoothedRTT(sample) + } + + smoothed := tm.smoothedRTT + tm.mu.Unlock() + + require.GreaterOrEqual(t, int64(smoothed), int64(minSample), + "smoothedRTT=%v must be >= min sample=%v", + smoothed, minSample) + require.LessOrEqual(t, int64(smoothed), int64(maxSample), + "smoothedRTT=%v must be <= max sample=%v", + smoothed, maxSample) + }) + + // Property: constant samples converge to that constant. + rapid.Check(t, func(t *rapid.T) { + constant := drawDuration( + t, time.Millisecond, 30*time.Second, "constant", + ) + + tm := NewTimeOutManager(nil) + + tm.mu.Lock() + for i := 0; i < 100; i++ { + tm.updateSmoothedRTT(constant) + } + smoothed := tm.smoothedRTT + tm.mu.Unlock() + + require.Equal(t, constant, smoothed, + "100 identical samples should converge exactly") + }) +} + +// TestRapidQueueSizeInvariants uses a state machine approach to verify that +// the GBN queue's size never exceeds the window and that sequence numbers +// remain consistent after a series of add/ACK/NACK operations. +func TestRapidQueueSizeInvariants(t *testing.T) { + t.Parallel() + + rapid.Check(t, func(t *rapid.T) { + // Use a small window to increase the chance of wrapping. + n := rapid.Uint8Range(2, 8).Draw(t, "n") + s := n + 1 + + tm := NewTimeOutManager(nil) + + q := newQueue(&queueCfg{ + s: s, + log: log, + sendPkt: func(packet *PacketData) error { + return nil + }, + }, tm) + defer q.stop() + + // Track how many packets we've added and ACK'd to verify + // the queue size invariant. + added := 0 + acked := 0 + + numOps := rapid.IntRange(10, 100).Draw(t, "numOps") + + for i := 0; i < numOps; i++ { + currentSize := int(q.size()) + maxSize := int(n) + + if currentSize < maxSize { + op := rapid.IntRange(0, 2).Draw(t, "op") + + switch op { + case 0: + // Add a packet. + q.addPacket(&PacketData{ + Payload: []byte{byte(i)}, + }) + added++ + + case 1: + // ACK the base if queue is non-empty. + if currentSize > 0 { + q.baseMtx.RLock() + base := q.sequenceBase + q.baseMtx.RUnlock() + + if q.processACK(base) { + acked++ + } + } + + case 2: + // NACK some sequence. + if currentSize > 0 { + q.baseMtx.RLock() + base := q.sequenceBase + q.baseMtx.RUnlock() + + q.processNACK(base) + } + } + } else { + // Queue full, must ACK to make room. + q.baseMtx.RLock() + base := q.sequenceBase + q.baseMtx.RUnlock() + + if q.processACK(base) { + acked++ + } + } + + // Invariant: queue size must never exceed the + // window size n. + size := q.size() + require.LessOrEqual(t, size, n, + "queue size %d exceeds window %d after "+ + "op %d (added=%d, acked=%d)", + size, n, i, added, acked) + + // Invariant: queue size should equal + // (added - acked) mod s. + expectedSize := uint8((added - acked) % int(s)) + require.Equal(t, expectedSize, size, + "size mismatch: expected %d got %d "+ + "(added=%d acked=%d s=%d)", + expectedSize, size, added, acked, s) + } + }) +} + +// TestRapidTimeoutBoosterProperties verifies that the timeout booster always +// produces values >= the original timeout, and that reset returns to base. +func TestRapidTimeoutBoosterProperties(t *testing.T) { + t.Parallel() + + rapid.Check(t, func(t *rapid.T) { + originalTimeout := drawDuration( + t, time.Millisecond, 30*time.Second, + "originalTimeout", + ) + + boostPct := float32( + rapid.Float64Range(0.01, 2.0).Draw(t, "boostPct"), + ) + + booster := NewTimeoutBooster(originalTimeout, boostPct, false) + + numBoosts := rapid.IntRange(0, 20).Draw(t, "numBoosts") + for i := 0; i < numBoosts; i++ { + booster.Boost() + } + + current := booster.GetCurrentTimeout() + + // Property: boosted timeout is always >= original. + require.GreaterOrEqual(t, int64(current), + int64(originalTimeout), + "boosted timeout %v < original %v after %d boosts", + current, originalTimeout, numBoosts) + + // Property: after reset, timeout returns to the new base. + newBase := drawDuration( + t, time.Millisecond, 30*time.Second, "newBase", + ) + + booster.Reset(newBase) + require.Equal(t, newBase, booster.GetCurrentTimeout()) + }) +} + +// TestRapidPingPongZeroDisablesPing verifies that a zero ping time effectively +// disables keepalive by returning MaxInt64. +func TestRapidPingPongZeroDisablesPing(t *testing.T) { + t.Parallel() + + rapid.Check(t, func(t *rapid.T) { + tm := NewTimeOutManager(nil) + + pingTime := tm.GetPingTime() + require.Equal(t, time.Duration(math.MaxInt64), pingTime) + }) +} + +// genMessage generates a random valid GBN message. +func genMessage(t *rapid.T) Message { + msgType := rapid.IntRange(0, 5).Draw(t, "msgType") + + switch msgType { + case 0: + return &PacketData{ + Seq: rapid.Uint8().Draw(t, "seq"), + FinalChunk: rapid.Bool().Draw(t, "finalChunk"), + IsPing: rapid.Bool().Draw(t, "isPing"), + Payload: rapid.SliceOf(rapid.Byte()).Draw(t, "payload"), + } + case 1: + return &PacketACK{ + Seq: rapid.Uint8().Draw(t, "seq"), + } + case 2: + return &PacketNACK{ + Seq: rapid.Uint8().Draw(t, "seq"), + } + case 3: + return &PacketSYN{ + N: rapid.Uint8().Draw(t, "n"), + } + case 4: + return &PacketFIN{} + default: + return &PacketSYNACK{} + } +} diff --git a/gbn/testdata/rapid/TestRapidDeserializeNeverPanics/TestRapidDeserializeNeverPanics-20260305151902-96036.fail b/gbn/testdata/rapid/TestRapidDeserializeNeverPanics/TestRapidDeserializeNeverPanics-20260305151902-96036.fail new file mode 100644 index 00000000..0aaa37d9 --- /dev/null +++ b/gbn/testdata/rapid/TestRapidDeserializeNeverPanics/TestRapidDeserializeNeverPanics-20260305151902-96036.fail @@ -0,0 +1,13 @@ +# 2026/03/05 15:19:02.025148 [TestRapidDeserializeNeverPanics] [rapid] draw data: []byte{0x2, 0x0, 0x0} +# +v0.4.8#2188469523130279869 +0x5555555555555 +0x38e38e38e38e4 +0x2 +0x5555555555555 +0x0 +0x0 +0x5555555555555 +0x0 +0x0 +0x0 \ No newline at end of file diff --git a/gbn/timeout_manager.go b/gbn/timeout_manager.go index cac380e2..53220dd2 100644 --- a/gbn/timeout_manager.go +++ b/gbn/timeout_manager.go @@ -18,6 +18,14 @@ const ( defaultBoostPercent = 0.5 DefaultSendTimeout = math.MaxInt64 DefaultRecvTimeout = math.MaxInt64 + + // defaultPongMultiplier is the default multiplier applied to the + // observed RTT to compute the dynamic pong timeout. + defaultPongMultiplier = 3 + + // defaultMaxPongTime is the default upper bound for the dynamic pong + // timeout. + defaultMaxPongTime = 15 * time.Second ) // TimeoutBooster is used to boost a timeout by a given percentage value. @@ -191,11 +199,37 @@ type TimeoutManager struct { // counterparty if we've received no packet. pingTime time.Duration - // pongTime represents how long we will wait for the expect a pong - // response after we've sent a ping. If no response is received within - // the time limit, we will close the connection. + // pongTime represents the base pong timeout, i.e. the minimum time we + // will wait for a pong response after we've sent a ping. If no + // response is received within the time limit, we will close the + // connection. When dynamic pong timeout is enabled, the actual pong + // timeout may be larger than this value based on observed RTT. pongTime time.Duration + // dynamicPongTime indicates whether the pong timeout should be + // dynamically adjusted based on the observed RTT of the connection. + dynamicPongTime bool + + // pongMultiplier is the multiplier applied to the observed RTT when + // computing the dynamic pong timeout. A value of 3 means the pong + // timeout will be at least 3x the observed RTT. + pongMultiplier int + + // maxPongTime is the upper bound for the dynamic pong timeout. + maxPongTime time.Duration + + // smoothedRTT stores the exponentially weighted moving average of + // observed round-trip times. This is used to dynamically compute the + // pong timeout when dynamic pong time is enabled. Using an EWMA + // rather than a single sample prevents an unlucky low measurement + // from making the pong timeout too aggressive. + smoothedRTT time.Duration + + // rttInitialized indicates whether smoothedRTT has received its + // first sample. Before the first sample, GetPongTime falls back to + // the static base pong time. + rttInitialized bool + // responseCounter represents the current number of corresponding // responses received since last updating the resend timeout. responseCounter int @@ -242,6 +276,8 @@ func NewTimeOutManager(logger btclog.Logger, sendTimeout: DefaultSendTimeout, sentTimes: make(map[uint8]time.Time), timeoutUpdateFrequency: defaultTimeoutUpdateFrequency, + pongMultiplier: defaultPongMultiplier, + maxPongTime: defaultMaxPongTime, } for _, opt := range timeoutOpts { @@ -362,6 +398,7 @@ func (m *TimeoutManager) Received(msg Message) { m.latestSentSYNTimeMu.Unlock() + m.updateSmoothedRTT(responseTime) m.updateResendTimeoutUnsafe(responseTime) case *PacketACK: @@ -378,6 +415,12 @@ func (m *TimeoutManager) Received(msg Message) { m.sentTimesMu.Unlock() + responseTime := receivedAt.Sub(sentTime) + + // Always update the smoothed RTT on every ACK so the + // dynamic pong timeout has a stable, up-to-date signal. + m.updateSmoothedRTT(responseTime) + m.responseCounter++ reachedFrequency := m.responseCounter% @@ -390,7 +433,7 @@ func (m *TimeoutManager) Received(msg Message) { if !m.hasSetDynamicTimeout || reachedFrequency { m.responseCounter = 0 - m.updateResendTimeoutUnsafe(receivedAt.Sub(sentTime)) + m.updateResendTimeoutUnsafe(responseTime) } } } @@ -415,7 +458,8 @@ func (m *TimeoutManager) updateResendTimeoutUnsafe(responseTime time.Duration) { multipliedTimeout = minimumResendTimeout } - m.log.Debugf("Updating resendTimeout to %v", multipliedTimeout) + m.log.Debugf("Updating resendTimeout to %v (smoothedRTT=%v)", + multipliedTimeout, m.smoothedRTT) m.resendTimeout = multipliedTimeout @@ -483,8 +527,11 @@ func (m *TimeoutManager) GetPingTime() time.Duration { } // GetPongTime returns the pong timeout, representing how long we will wait for -// the expect a pong response after we've sent a ping. If no response is -// received within the time limit, we will close the connection. +// a pong response after we've sent a ping. If no response is received within +// the time limit, we will close the connection. When dynamic pong timeout is +// enabled and we have observed RTT data, the timeout is computed as +// max(basePongTime, pongMultiplier * smoothedRTT), capped at maxPongTime and +// pingTime. func (m *TimeoutManager) GetPongTime() time.Duration { m.mu.RLock() defer m.mu.RUnlock() @@ -493,7 +540,67 @@ func (m *TimeoutManager) GetPongTime() time.Duration { return time.Duration(math.MaxInt64) } - return m.pongTime + // If dynamic pong time is not enabled or we have no RTT data yet, + // return the static base pong time. + if !m.dynamicPongTime || !m.rttInitialized { + return m.pongTime + } + + // Compute the dynamic pong timeout as pongMultiplier * smoothedRTT. + dynamicPong := time.Duration(m.pongMultiplier) * m.smoothedRTT + + // Use the base pong time as a floor. + if dynamicPong < m.pongTime { + dynamicPong = m.pongTime + } + + // Cap at the maximum pong time. + if m.maxPongTime > 0 && dynamicPong > m.maxPongTime { + dynamicPong = m.maxPongTime + } + + // Ensure pong timeout never exceeds ping interval, otherwise the + // next ping would fire and reset the pong timer before it expires, + // preventing the connection from ever timing out. + if m.pingTime > 0 && dynamicPong > m.pingTime { + dynamicPong = m.pingTime + } + + return dynamicPong +} + +// GetSmoothedRTT returns the EWMA-smoothed round-trip time. +func (m *TimeoutManager) GetSmoothedRTT() time.Duration { + m.mu.RLock() + defer m.mu.RUnlock() + + return m.smoothedRTT +} + +// GetLatestRTT returns the EWMA-smoothed round-trip time. +// +// Deprecated: Use GetSmoothedRTT instead. +func (m *TimeoutManager) GetLatestRTT() time.Duration { + return m.GetSmoothedRTT() +} + +// updateSmoothedRTT updates the EWMA-smoothed RTT with a new sample. The +// first sample seeds the EWMA directly; subsequent samples blend in with +// alpha = 0.25. +// +// NOTE: The TimeoutManager mu must be held when calling this function. +func (m *TimeoutManager) updateSmoothedRTT(rtt time.Duration) { + const ewmaAlpha = 0.25 + + if !m.rttInitialized { + m.smoothedRTT = rtt + m.rttInitialized = true + } else { + m.smoothedRTT = time.Duration( + ewmaAlpha*float64(rtt) + + (1-ewmaAlpha)*float64(m.smoothedRTT), + ) + } } // SetSendTimeout sets the send timeout. diff --git a/gbn/timeout_manager_test.go b/gbn/timeout_manager_test.go index f200dd9b..bc4c8ec9 100644 --- a/gbn/timeout_manager_test.go +++ b/gbn/timeout_manager_test.go @@ -348,6 +348,310 @@ func TestStaticTimeout(t *testing.T) { require.Equal(t, staticTimeout, resendTimeout) } +// TestDynamicPongTimeout ensures that the pong timeout is dynamically adjusted +// based on the EWMA-smoothed RTT when dynamic pong timeout is enabled. +func TestDynamicPongTimeout(t *testing.T) { + t.Parallel() + + basePong := 500 * time.Millisecond + maxPong := 10 * time.Second + pongMultiplier := 3 + + // Use a large ping time so it doesn't cap the pong values under test. + pingTime := 30 * time.Second + + // Create a timeout manager with dynamic pong timeout enabled. + tm := NewTimeOutManager( + nil, + WithKeepalivePing(pingTime, basePong), + WithDynamicPongTimeout(pongMultiplier, maxPong), + ) + + // Initially, with no RTT data, the pong time should equal the base. + require.Equal(t, basePong, tm.GetPongTime()) + + // Simulate a SYN exchange with a 200ms RTT. This is the first sample + // so the EWMA seeds directly: smoothedRTT = 200ms. + // Dynamic pong = max(basePong, 3 * 200ms) = 600ms. + synMsg := &PacketSYN{N: 20} + sendAndReceiveWithDuration( + t, tm, 200*time.Millisecond, synMsg, synMsg, false, + ) + + pongTime := tm.GetPongTime() + expectedPong := time.Duration(pongMultiplier) * 200 * time.Millisecond + + // Allow some tolerance for timing jitter. + require.InDelta( + t, float64(expectedPong), float64(pongTime), + float64(100*time.Millisecond), + ) + + // Verify the pong time is above the base. + require.GreaterOrEqual(t, pongTime, basePong) + + // Now simulate a very fast RTT (50ms). The EWMA blends: + // smoothedRTT = 0.25*50 + 0.75*200 = 162.5ms. + // Dynamic pong = 3 * 162.5ms = 487.5ms ~ basePong. With timing + // jitter the smoothed RTT may be slightly above the theoretical + // value, so use a tolerance check. + sendAndReceiveWithDuration( + t, tm, 50*time.Millisecond, synMsg, synMsg, false, + ) + + pongTime = tm.GetPongTime() + require.InDelta( + t, float64(basePong), float64(pongTime), + float64(100*time.Millisecond), + ) + + // Directly inject a high smoothed RTT to verify the max cap without + // sleeping through many iterations. 5s smoothedRTT * 3 = 15s which + // exceeds maxPong (10s), so pong should be capped at maxPong. + tm.mu.Lock() + tm.smoothedRTT = 5 * time.Second + tm.mu.Unlock() + + pongTime = tm.GetPongTime() + require.Equal(t, maxPong, pongTime) +} + +// TestDynamicPongTimeoutDisabled ensures that the pong timeout is static when +// dynamic pong timeout is not enabled. +func TestDynamicPongTimeoutDisabled(t *testing.T) { + t.Parallel() + + basePong := 3 * time.Second + + // Create a timeout manager without dynamic pong timeout. + tm := NewTimeOutManager( + nil, + WithKeepalivePing(time.Second, basePong), + ) + + // The pong time should always be the base, regardless of RTT. + require.Equal(t, basePong, tm.GetPongTime()) + + // Simulate a SYN exchange with a high RTT. + synMsg := &PacketSYN{N: 20} + sendAndReceiveWithDuration( + t, tm, 2*time.Second, synMsg, synMsg, false, + ) + + // Pong time should still be static. + require.Equal(t, basePong, tm.GetPongTime()) +} + +// TestDefaultPongMultiplierAndMaxPongTime verifies that a TimeoutManager +// created without WithDynamicPongTimeout still has the default pongMultiplier +// and maxPongTime values set. This is a regression test for the constructor +// initialization fix: if dynamic mode were later enabled on such a manager +// (e.g. by a new code path), the defaults must produce sensible pong timeouts +// rather than zero-value degradation. +func TestDefaultPongMultiplierAndMaxPongTime(t *testing.T) { + t.Parallel() + + basePong := 500 * time.Millisecond + + // Create without WithDynamicPongTimeout — the constructor should + // still initialize pongMultiplier and maxPongTime to defaults. + tm := NewTimeOutManager( + nil, + WithKeepalivePing(30*time.Second, basePong), + ) + + // Manually enable dynamic mode to test the defaults take effect. + tm.mu.Lock() + tm.dynamicPongTime = true + tm.smoothedRTT = 2 * time.Second + tm.rttInitialized = true + tm.mu.Unlock() + + pongTime := tm.GetPongTime() + + // With defaults (multiplier=3, max=15s): 3 * 2s = 6s. + expectedPong := time.Duration(defaultPongMultiplier) * 2 * time.Second + require.Equal(t, expectedPong, pongTime, + "default pongMultiplier should produce correct dynamic pong") + + // With a very high RTT, should be capped at defaultMaxPongTime. + tm.mu.Lock() + tm.smoothedRTT = 10 * time.Second + tm.mu.Unlock() + + pongTime = tm.GetPongTime() + require.Equal(t, defaultMaxPongTime, pongTime, + "default maxPongTime should cap the dynamic pong") +} + +// TestDynamicPongTimeoutWithDataPackets ensures the dynamic pong timeout +// updates correctly when RTT is measured from data packet ACKs. +func TestDynamicPongTimeoutWithDataPackets(t *testing.T) { + t.Parallel() + + basePong := 500 * time.Millisecond + pongMultiplier := 3 + + tm := NewTimeOutManager( + nil, + WithKeepalivePing(time.Second, basePong), + WithDynamicPongTimeout(pongMultiplier, 15*time.Second), + WithTimeoutUpdateFrequency(1), + ) + + // Send a data packet and receive the ACK with ~300ms RTT. + msg := &PacketData{Seq: 1} + response := &PacketACK{Seq: 1} + + sendAndReceiveWithDuration( + t, tm, 300*time.Millisecond, msg, response, false, + ) + + // Dynamic pong should be ~900ms (3 * 300ms). + pongTime := tm.GetPongTime() + expectedPong := time.Duration(pongMultiplier) * 300 * time.Millisecond + + require.InDelta( + t, float64(expectedPong), float64(pongTime), + float64(100*time.Millisecond), + ) +} + +// TestDynamicPongCappedByPingTime verifies that the dynamic pong timeout never +// exceeds the ping interval, even when the RTT-based computation would produce +// a larger value. +func TestDynamicPongCappedByPingTime(t *testing.T) { + t.Parallel() + + basePong := 500 * time.Millisecond + pingTime := 2 * time.Second + maxPong := 30 * time.Second + pongMultiplier := 3 + + tm := NewTimeOutManager( + nil, + WithKeepalivePing(pingTime, basePong), + WithDynamicPongTimeout(pongMultiplier, maxPong), + ) + + // Inject a high smoothed RTT: 3 * 1s = 3s > pingTime (2s). + tm.mu.Lock() + tm.smoothedRTT = time.Second + tm.rttInitialized = true + tm.mu.Unlock() + + pongTime := tm.GetPongTime() + require.Equal(t, pingTime, pongTime, + "pong should be capped at pingTime") + + // Even with an extremely high RTT, pong must not exceed pingTime. + tm.mu.Lock() + tm.smoothedRTT = 10 * time.Second + tm.mu.Unlock() + + pongTime = tm.GetPongTime() + require.Equal(t, pingTime, pongTime, + "pong must never exceed pingTime regardless of RTT") + + // When the RTT-based value is below pingTime, it should be used. + tm.mu.Lock() + tm.smoothedRTT = 200 * time.Millisecond + tm.mu.Unlock() + + pongTime = tm.GetPongTime() + expectedPong := time.Duration(pongMultiplier) * 200 * time.Millisecond + require.Equal(t, expectedPong, pongTime, + "pong should use RTT-based value when below pingTime") +} + +// TestEWMASmoothing verifies that the EWMA-smoothed RTT converges correctly +// and is resistant to single-sample outliers. +func TestEWMASmoothing(t *testing.T) { + t.Parallel() + + tm := NewTimeOutManager( + nil, + WithTimeoutUpdateFrequency(1), + WithKeepalivePing(30*time.Second, 100*time.Millisecond), + WithDynamicPongTimeout(3, 30*time.Second), + ) + + // The first sample seeds the EWMA directly. + synMsg := &PacketSYN{N: 20} + sendAndReceiveWithDuration( + t, tm, 200*time.Millisecond, synMsg, synMsg, false, + ) + + rtt := tm.GetSmoothedRTT() + require.InDelta( + t, float64(200*time.Millisecond), float64(rtt), + float64(50*time.Millisecond), + "first sample should seed EWMA directly", + ) + + // Feed 10 stable samples at 200ms. The EWMA should stay near 200ms. + for i := 0; i < 10; i++ { + sendAndReceiveWithDuration( + t, tm, 200*time.Millisecond, synMsg, synMsg, false, + ) + } + + stableRTT := tm.GetSmoothedRTT() + require.InDelta( + t, float64(200*time.Millisecond), float64(stableRTT), + float64(50*time.Millisecond), + "EWMA should converge near stable RTT", + ) + + // Now inject a single outlier (50ms). The EWMA should NOT drop + // dramatically — it should resist the outlier due to smoothing. + sendAndReceiveWithDuration( + t, tm, 50*time.Millisecond, synMsg, synMsg, false, + ) + + afterOutlier := tm.GetSmoothedRTT() + + // EWMA with alpha=0.25: new = 0.25*50 + 0.75*~200 = ~162ms. + // It should still be well above the outlier value. + require.Greater(t, int64(afterOutlier), int64(100*time.Millisecond), + "EWMA should resist single low outlier") + require.Less(t, int64(afterOutlier), int64(stableRTT), + "EWMA should move slightly toward outlier") + + // Inject a single high outlier (2s). Should move up but not jump to 2s. + sendAndReceiveWithDuration( + t, tm, 2*time.Second, synMsg, synMsg, false, + ) + + afterHighOutlier := tm.GetSmoothedRTT() + require.Less(t, int64(afterHighOutlier), int64(time.Second), + "EWMA should resist single high outlier") + require.Greater(t, int64(afterHighOutlier), int64(afterOutlier), + "EWMA should move toward high outlier") +} + +// TestGetLatestRTT ensures GetLatestRTT returns the most recently measured RTT. +func TestGetLatestRTT(t *testing.T) { + t.Parallel() + + tm := NewTimeOutManager(nil, WithTimeoutUpdateFrequency(1)) + + // Initially zero. + require.Equal(t, time.Duration(0), tm.GetLatestRTT()) + + // After a SYN exchange, should reflect the response time. + synMsg := &PacketSYN{N: 20} + sendAndReceiveWithDuration( + t, tm, time.Second, synMsg, synMsg, false, + ) + + rtt := tm.GetLatestRTT() + require.InDelta( + t, float64(time.Second), float64(rtt), + float64(100*time.Millisecond), + ) +} + // sendAndReceive simulates that a SYN message has been sent for the passed the // timeout manager, and then waits for one second before a simulating the SYN // response. While waiting, the function asserts that the resend timeout hasn't diff --git a/mailbox/client_conn.go b/mailbox/client_conn.go index 56015341..57e2d739 100644 --- a/mailbox/client_conn.go +++ b/mailbox/client_conn.go @@ -66,23 +66,34 @@ const ( // set up the clients send stream cipher box. gbnHandshakeTimeout = 2000 * time.Millisecond - // gbnClientPingTimeout is the time after with the client will send the + // gbnClientPingTimeout is the time after which the client will send the // server a ping message if it has not received any packets from the // server. The client will close the connection if it then does not // receive an acknowledgement of the ping from the server. - gbnClientPingTimeout = 7 * time.Second + gbnClientPingTimeout = 10 * time.Second - // gbnServerTimeout is the time after with the server will send the - // client a ping message if it has not received any packets from the - // client. The server will close the connection if it then does not + // gbnServerPingTimeout is the time after which the server will send + // the client a ping message if it has not received any packets from + // the client. The server will close the connection if it then does not // receive an acknowledgement of the ping from the client. This timeout // is slightly shorter than the gbnClientPingTimeout to prevent both // sides from unnecessarily sending pings simultaneously. - gbnServerPingTimeout = 5 * time.Second + gbnServerPingTimeout = 8 * time.Second - // gbnPongTimout is the time after sending the pong message that we will - // timeout if we do not receive any message from our peer. - gbnPongTimeout = 3 * time.Second + // gbnPongTimeout is the base time after sending a ping that we will + // timeout if we do not receive any message from our peer. This serves + // as a floor for the dynamic pong timeout which adjusts based on + // observed RTT. + gbnPongTimeout = 5 * time.Second + + // gbnPongMultiplier is the multiplier applied to the observed RTT + // when computing the dynamic pong timeout. A value of 3 means the + // pong timeout will be at least 3x the observed round-trip time. + gbnPongMultiplier = 3 + + // gbnMaxPongTimeout is the upper bound for the dynamic pong timeout. + // Even with very high RTT, the pong timeout will not exceed this. + gbnMaxPongTimeout = 15 * time.Second // gbnBoostPercent is the percentage value that the resend and handshake // timeout will be boosted any time we need to resend a packet due to @@ -191,6 +202,9 @@ func NewClientConn(ctx context.Context, sid [64]byte, serverHost string, gbnClientPingTimeout, gbnPongTimeout, ), gbn.WithBoostPercent(gbnBoostPercent), + gbn.WithDynamicPongTimeout( + gbnPongMultiplier, gbnMaxPongTimeout, + ), ), gbn.WithOnFIN(func() { // We force the connection to set a new status after diff --git a/mailbox/server_conn.go b/mailbox/server_conn.go index 0c765d70..2ab11588 100644 --- a/mailbox/server_conn.go +++ b/mailbox/server_conn.go @@ -90,6 +90,10 @@ func NewServerConn(ctx context.Context, serverHost string, gbnServerPingTimeout, gbnPongTimeout, ), gbn.WithBoostPercent(gbnBoostPercent), + gbn.WithDynamicPongTimeout( + gbnPongMultiplier, + gbnMaxPongTimeout, + ), ), }, status: ServerStatusNotConnected,