diff --git a/command.go b/command.go index d17081f56..e253607a2 100644 --- a/command.go +++ b/command.go @@ -7370,6 +7370,7 @@ type MonitorCmd struct { baseCmd ch chan string status MonitorStatus + closed bool mu sync.Mutex } @@ -7382,6 +7383,7 @@ func newMonitorCmd(ctx context.Context, ch chan string) *MonitorCmd { }, ch: ch, status: monitorStatusIdle, + closed: false, mu: sync.Mutex{}, } } @@ -7411,23 +7413,58 @@ func (cmd *MonitorCmd) readReply(rd *proto.Reader) error { func (cmd *MonitorCmd) readMonitor(rd *proto.Reader, cancel context.CancelFunc) error { for { + // Check if context is done first + select { + case <-cmd.ctx.Done(): + cmd.closeChannel() + cancel() + return cmd.ctx.Err() + default: + } + cmd.mu.Lock() st := cmd.status - pk, _ := rd.Peek(1) cmd.mu.Unlock() - if len(pk) != 0 && st == monitorStatusStart { - cmd.mu.Lock() - line, err := rd.ReadString() - cmd.mu.Unlock() - if err != nil { - return err - } - cmd.ch <- line - } + if st == monitorStatusStop { + cmd.closeChannel() cancel() break } + + if st == monitorStatusStart { + cmd.mu.Lock() + pk, peekErr := rd.Peek(1) + cmd.mu.Unlock() + + if peekErr != nil { + // Check if it's a timeout error - if so, ignore and continue + if isTimeout, _ := isTimeoutError(peekErr); isTimeout { + continue + } + // For non-timeout errors, close channel and return + cmd.closeChannel() + cancel() + return peekErr + } + + if len(pk) != 0 { + cmd.mu.Lock() + line, err := rd.ReadString() + cmd.mu.Unlock() + if err != nil { + // Check if it's a timeout error - if so, ignore and continue + if isTimeout, _ := isTimeoutError(err); isTimeout { + continue + } + // For non-timeout errors, close channel and return + cmd.closeChannel() + cancel() + return err + } + cmd.ch <- line + } + } } return nil } @@ -7444,6 +7481,16 @@ func (cmd *MonitorCmd) Stop() { cmd.status = monitorStatusStop } +// closeChannel safely closes the channel if it hasn't been closed yet. +func (cmd *MonitorCmd) closeChannel() { + cmd.mu.Lock() + defer cmd.mu.Unlock() + if !cmd.closed { + close(cmd.ch) + cmd.closed = true + } +} + type VectorScoreSliceCmd struct { baseCmd diff --git a/monitor_test.go b/monitor_test.go index ebb784853..ac56a42f0 100644 --- a/monitor_test.go +++ b/monitor_test.go @@ -107,3 +107,136 @@ func TestMonitorCommand(t *testing.T) { func containsSubstring(s, substr string) bool { return strings.Contains(s, substr) } + +func TestMonitorWithTimeout(t *testing.T) { + if os.Getenv("RUN_MONITOR_TEST") != "true" { + t.Skip("Skipping Monitor command test. Set RUN_MONITOR_TEST=true to run it.") + } + + ctx := context.TODO() + // Create a client with a very short ReadTimeout (100ms) + client := redis.NewClient(&redis.Options{ + Addr: redisPort, + ReadTimeout: 100 * time.Millisecond, + }) + if err := client.FlushDB(ctx).Err(); err != nil { + t.Fatalf("FlushDB failed: %v", err) + } + + defer func() { + if err := client.Close(); err != nil { + t.Fatalf("Close failed: %v", err) + } + }() + + ress := make(chan string, 10) + // Create a separate client for executing commands + commandClient := redis.NewClient(&redis.Options{Addr: redisPort}) + defer commandClient.Close() + + mn := client.Monitor(ctx, ress) + mn.Start() + + // Wait for the Redis server to be in monitoring mode. + time.Sleep(100 * time.Millisecond) + + // Wait longer than ReadTimeout to ensure timeouts occur + t.Log("Waiting for timeout to occur...") + time.Sleep(300 * time.Millisecond) + + // Execute commands after timeout should have occurred + t.Log("Executing commands after timeout...") + commandClient.Set(ctx, "key1", "value1", 0) + commandClient.Set(ctx, "key2", "value2", 0) + + // Give some time for messages to arrive + time.Sleep(100 * time.Millisecond) + + // Try to read messages - should still work despite timeouts + var lst []string + timeout := time.After(2 * time.Second) + for i := 0; i < 3; i++ { + select { + case s := <-ress: + lst = append(lst, s) + t.Logf("Received message %d: %s", i, s) + case <-timeout: + t.Fatalf("Timed out waiting for messages. Got %d messages so far", len(lst)) + } + } + + // Stop monitoring + mn.Stop() + + // Verify we got at least the OK message and the SET commands + if len(lst) < 3 { + t.Errorf("Expected at least 3 messages, got %d", len(lst)) + } + + found := false + for _, msg := range lst { + if containsSubstring(msg, `"set" "key1" "value1"`) { + found = true + break + } + } + if !found { + t.Errorf("Expected to find 'set key1 value1' in messages, got: %v", lst) + } +} + +func TestMonitorWithContextCancellation(t *testing.T) { + if os.Getenv("RUN_MONITOR_TEST") != "true" { + t.Skip("Skipping Monitor command test. Set RUN_MONITOR_TEST=true to run it.") + } + + ctx, cancel := context.WithCancel(context.Background()) + client := redis.NewClient(&redis.Options{Addr: redisPort}) + if err := client.FlushDB(ctx).Err(); err != nil { + t.Fatalf("FlushDB failed: %v", err) + } + + defer func() { + if err := client.Close(); err != nil { + t.Fatalf("Close failed: %v", err) + } + }() + + ress := make(chan string, 10) + mn := client.Monitor(ctx, ress) + mn.Start() + + // Wait for the Redis server to be in monitoring mode. + time.Sleep(100 * time.Millisecond) + + // Execute a command + client.Set(ctx, "test", "value", 0) + + // Wait a bit for the message + time.Sleep(100 * time.Millisecond) + + // Cancel the context + cancel() + + // Wait a bit for cleanup + time.Sleep(100 * time.Millisecond) + + // Try to read from channel - should eventually close + timeout := time.After(2 * time.Second) + channelClosed := false + for !channelClosed { + select { + case _, ok := <-ress: + if !ok { + channelClosed = true + t.Log("Channel was properly closed") + } + case <-timeout: + break + } + } + + if !channelClosed { + t.Log("Note: Channel may not close immediately, but should not block forever") + } +}