Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 47 additions & 20 deletions pkg/channels/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,8 @@ type toolFeedbackMessageCleaner interface {
DismissToolFeedbackMessage(ctx context.Context, chatID string)
}

type toolFeedbackMessageTargetResolver interface {
ToolFeedbackMessageChatID(chatID string, outboundCtx *bus.InboundContext) string
type outboundTargetResolver interface {
ResolveOutboundChatID(chatID string, outboundCtx *bus.InboundContext) string
}

type toolFeedbackMessageContentPreparer interface {
Expand All @@ -155,6 +155,10 @@ func outboundMessageChatID(msg bus.OutboundMessage) string {
return msg.ChatID
}

func resolvedOutboundMessageChatID(ch Channel, msg bus.OutboundMessage) string {
return resolveOutboundChatID(ch, outboundMessageChatID(msg), &msg.Context)
}

func outboundMessageIsToolFeedback(msg bus.OutboundMessage) bool {
if len(msg.Context.Raw) == 0 {
return false
Expand All @@ -178,9 +182,22 @@ func outboundMediaChatID(msg bus.OutboundMediaMessage) string {
return msg.ChatID
}

func trackedToolFeedbackMessageChatID(ch Channel, chatID string, outboundCtx *bus.InboundContext) string {
if resolver, ok := ch.(toolFeedbackMessageTargetResolver); ok {
if resolved := strings.TrimSpace(resolver.ToolFeedbackMessageChatID(chatID, outboundCtx)); resolved != "" {
func resolvedOutboundMediaChatID(ch Channel, msg bus.OutboundMediaMessage) string {
return resolveOutboundChatID(ch, outboundMediaChatID(msg), &msg.Context)
}

func candidateChatIDs(raw, resolved string) []string {
raw = strings.TrimSpace(raw)
resolved = strings.TrimSpace(resolved)
if raw == "" || raw == resolved {
return []string{resolved}
}
return []string{resolved, raw}
}

func resolveOutboundChatID(ch Channel, chatID string, outboundCtx *bus.InboundContext) string {
if resolver, ok := ch.(outboundTargetResolver); ok {
if resolved := strings.TrimSpace(resolver.ResolveOutboundChatID(chatID, outboundCtx)); resolved != "" {
return resolved
}
}
Expand All @@ -193,7 +210,7 @@ func dismissTrackedToolFeedbackMessage(
chatID string,
outboundCtx *bus.InboundContext,
) {
trackedChatID := trackedToolFeedbackMessageChatID(ch, chatID, outboundCtx)
trackedChatID := resolveOutboundChatID(ch, chatID, outboundCtx)
if trackedChatID == "" {
return
}
Expand All @@ -211,7 +228,7 @@ func clearTrackedToolFeedbackMessage(
chatID string,
outboundCtx *bus.InboundContext,
) {
trackedChatID := trackedToolFeedbackMessageChatID(ch, chatID, outboundCtx)
trackedChatID := resolveOutboundChatID(ch, chatID, outboundCtx)
if trackedChatID == "" {
return
}
Expand Down Expand Up @@ -320,18 +337,23 @@ func (m *Manager) RecordReactionUndo(channel, chatID string, undo func()) {
func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMessage, ch Channel) ([]string, bool) {
chatID := outboundMessageChatID(msg)
key := name + ":" + chatID
cleanupChatIDs := candidateChatIDs(chatID, resolvedOutboundMessageChatID(ch, msg))

// 1. Stop typing
if v, loaded := m.typingStops.LoadAndDelete(key); loaded {
if entry, ok := v.(typingEntry); ok {
entry.stop() // idempotent, safe
for _, cleanupChatID := range cleanupChatIDs {
if v, loaded := m.typingStops.LoadAndDelete(name + ":" + cleanupChatID); loaded {
if entry, ok := v.(typingEntry); ok {
entry.stop() // idempotent, safe
}
}
}

// 2. Undo reaction
if v, loaded := m.reactionUndos.LoadAndDelete(key); loaded {
if entry, ok := v.(reactionEntry); ok {
entry.undo() // idempotent, safe
for _, cleanupChatID := range cleanupChatIDs {
if v, loaded := m.reactionUndos.LoadAndDelete(name + ":" + cleanupChatID); loaded {
if entry, ok := v.(reactionEntry); ok {
entry.undo() // idempotent, safe
}
}
}

Expand Down Expand Up @@ -397,7 +419,7 @@ func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMess
content = InitialAnimatedToolFeedbackContent(trackedContent)
}
if err := editor.EditMessage(ctx, chatID, entry.id, content); err == nil {
trackedChatID := trackedToolFeedbackMessageChatID(ch, chatID, &msg.Context)
trackedChatID := resolveOutboundChatID(ch, chatID, &msg.Context)
if tracker, ok := ch.(toolFeedbackMessageTracker); ok && isToolFeedback {
tracker.RecordToolFeedbackMessage(trackedChatID, entry.id, trackedContent)
} else if !isToolFeedback {
Expand All @@ -420,18 +442,23 @@ func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMess
func (m *Manager) preSendMedia(ctx context.Context, name string, msg bus.OutboundMediaMessage, ch Channel) {
chatID := outboundMediaChatID(msg)
key := name + ":" + chatID
cleanupChatIDs := candidateChatIDs(chatID, resolvedOutboundMediaChatID(ch, msg))

// 1. Stop typing
if v, loaded := m.typingStops.LoadAndDelete(key); loaded {
if entry, ok := v.(typingEntry); ok {
entry.stop() // idempotent, safe
for _, cleanupChatID := range cleanupChatIDs {
if v, loaded := m.typingStops.LoadAndDelete(name + ":" + cleanupChatID); loaded {
if entry, ok := v.(typingEntry); ok {
entry.stop() // idempotent, safe
}
}
}

// 2. Undo reaction
if v, loaded := m.reactionUndos.LoadAndDelete(key); loaded {
if entry, ok := v.(reactionEntry); ok {
entry.undo() // idempotent, safe
for _, cleanupChatID := range cleanupChatIDs {
if v, loaded := m.reactionUndos.LoadAndDelete(name + ":" + cleanupChatID); loaded {
if entry, ok := v.(reactionEntry); ok {
entry.undo() // idempotent, safe
}
}
}

Expand Down
38 changes: 36 additions & 2 deletions pkg/channels/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ func (m *mockStreamingChannel) BeginStream(context.Context, string) (Streamer, e
return m.streamer, nil
}

func (m *mockStreamingChannel) ToolFeedbackMessageChatID(
func (m *mockStreamingChannel) ResolveOutboundChatID(
chatID string,
outboundCtx *bus.InboundContext,
) string {
Expand Down Expand Up @@ -948,7 +948,7 @@ func (m *mockDeletingMessageEditor) DeleteMessage(_ context.Context, chatID, mes
return nil
}

func (m *mockResolvedToolFeedbackEditor) ToolFeedbackMessageChatID(
func (m *mockResolvedToolFeedbackEditor) ResolveOutboundChatID(
chatID string,
outboundCtx *bus.InboundContext,
) string {
Expand Down Expand Up @@ -1901,6 +1901,40 @@ func TestPreSend_TypingStopCalled(t *testing.T) {
}
}

func TestPreSend_TypingStopUsesResolvedChatID(t *testing.T) {
m := newTestManager()
var stopCalled bool

ch := &mockResolvedToolFeedbackEditor{
resolveChatIDFn: func(chatID string, outboundCtx *bus.InboundContext) string {
if outboundCtx == nil || outboundCtx.TopicID != "42" {
return chatID
}
return chatID + "/" + outboundCtx.TopicID
},
}

m.RecordTypingStop("test", "123/42", func() {
stopCalled = true
})

msg := testOutboundMessage(bus.OutboundMessage{
Channel: "test",
ChatID: "123",
Content: "hello",
Context: bus.InboundContext{
Channel: "test",
ChatID: "123",
TopicID: "42",
},
})
m.preSend(context.Background(), "test", msg, ch)

if !stopCalled {
t.Fatal("expected typing stop func to be called for resolved topic chat ID")
}
}

func TestPreSend_NoRegisteredState(t *testing.T) {
m := newTestManager()

Expand Down
12 changes: 6 additions & 6 deletions pkg/channels/telegram/telegram.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]
if isToolFeedback {
toolFeedbackContent = fitToolFeedbackForTelegram(msg.Content, useMarkdownV2, 4096)
}
trackedChatID := telegramToolFeedbackChatKey(msg.ChatID, &msg.Context)
trackedChatID := telegramResolvedChatID(msg.ChatID, &msg.Context)
if isToolFeedback {
if msgID, handled, err := c.progress.Update(ctx, trackedChatID, toolFeedbackContent); handled {
if err != nil {
Expand Down Expand Up @@ -553,7 +553,7 @@ func (c *TelegramChannel) FinalizeToolFeedbackMessage(ctx context.Context, msg b
if outboundMessageIsToolFeedback(msg) {
return nil, false
}
return c.finalizeToolFeedbackMessageForChat(ctx, telegramToolFeedbackChatKey(msg.ChatID, &msg.Context), msg)
return c.finalizeToolFeedbackMessageForChat(ctx, telegramResolvedChatID(msg.ChatID, &msg.Context), msg)
}

func (c *TelegramChannel) finalizeToolFeedbackMessageForChat(
Expand Down Expand Up @@ -595,7 +595,7 @@ func (c *TelegramChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMe
if !c.IsRunning() {
return nil, channels.ErrNotRunning
}
trackedChatID := telegramToolFeedbackChatKey(msg.ChatID, &msg.Context)
trackedChatID := telegramResolvedChatID(msg.ChatID, &msg.Context)
trackedMsgID, hasTrackedMsg := c.currentToolFeedbackMessage(trackedChatID)

chatID, threadID, err := resolveTelegramOutboundTarget(msg.ChatID, &msg.Context)
Expand Down Expand Up @@ -1122,16 +1122,16 @@ func (c *TelegramChannel) PrepareToolFeedbackMessageContent(content string) stri
return fitToolFeedbackForTelegram(content, c.tgCfg.UseMarkdownV2, 4096)
}

func telegramToolFeedbackChatKey(chatID string, outboundCtx *bus.InboundContext) string {
func telegramResolvedChatID(chatID string, outboundCtx *bus.InboundContext) string {
resolvedChatID, threadID, err := resolveTelegramOutboundTarget(chatID, outboundCtx)
if err != nil || threadID == 0 {
return strings.TrimSpace(chatID)
}
return fmt.Sprintf("%d/%d", resolvedChatID, threadID)
}

func (c *TelegramChannel) ToolFeedbackMessageChatID(chatID string, outboundCtx *bus.InboundContext) string {
return telegramToolFeedbackChatKey(chatID, outboundCtx)
func (c *TelegramChannel) ResolveOutboundChatID(chatID string, outboundCtx *bus.InboundContext) string {
return telegramResolvedChatID(chatID, outboundCtx)
}

// parseTelegramChatID splits "chatID/threadID" into its components.
Expand Down