diff --git a/pkg/agent/agent_init.go b/pkg/agent/agent_init.go index 76f12fa65f..64ec23a020 100644 --- a/pkg/agent/agent_init.go +++ b/pkg/agent/agent_init.go @@ -5,6 +5,7 @@ package agent import ( "context" "fmt" + "strings" "time" "github.com/sipeed/picoclaw/pkg/agent/interfaces" @@ -157,6 +158,7 @@ func registerSharedTools( tools.ToolSessionKey(ctx), tools.ToolSessionScope(ctx), ) + inheritToolTopic(ctx, &outboundCtx, channel, chatID, outboundScope) return msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{ Context: outboundCtx, AgentID: outboundAgentID, @@ -339,3 +341,29 @@ func registerSharedTools( } } } + +func inheritToolTopic( + ctx context.Context, + outboundCtx *bus.InboundContext, + channel, chatID string, + scope *bus.OutboundScope, +) { + if outboundCtx == nil || strings.TrimSpace(outboundCtx.TopicID) != "" { + return + } + if strings.TrimSpace(channel) != strings.TrimSpace(tools.ToolChannel(ctx)) || + strings.TrimSpace(chatID) != strings.TrimSpace(tools.ToolChatID(ctx)) { + return + } + if scope == nil || scope.Values == nil { + return + } + if topic := strings.TrimPrefix(strings.TrimSpace(scope.Values["topic"]), "topic:"); topic != "" { + outboundCtx.TopicID = topic + return + } + chatScope := strings.TrimSpace(scope.Values["chat"]) + if idx := strings.LastIndex(chatScope, "/"); idx >= 0 && idx+1 < len(chatScope) { + outboundCtx.TopicID = strings.TrimSpace(chatScope[idx+1:]) + } +} diff --git a/pkg/agent/agent_test.go b/pkg/agent/agent_test.go index a75919912f..5d8f0c51a2 100644 --- a/pkg/agent/agent_test.go +++ b/pkg/agent/agent_test.go @@ -1721,6 +1721,40 @@ func (m *messageToolProvider) GetDefaultModel() string { return "message-tool-model" } +type explicitChatMessageToolProvider struct { + calls int +} + +func (m *explicitChatMessageToolProvider) Chat( + ctx context.Context, + messages []providers.Message, + tools []providers.ToolDefinition, + model string, + opts map[string]any, +) (*providers.LLMResponse, error) { + m.calls++ + if m.calls == 1 { + return &providers.LLMResponse{ + Content: "", + ToolCalls: []providers.ToolCall{{ + ID: "call_message", + Type: "function", + Name: "message", + Arguments: map[string]any{ + "channel": "telegram", + "chat_id": "-1001234567890", + "content": "topic tool message", + }, + }}, + }, nil + } + return &providers.LLMResponse{}, nil +} + +func (m *explicitChatMessageToolProvider) GetDefaultModel() string { + return "message-tool-model" +} + type reasoningVisibleToolProvider struct { filePath string calls int @@ -4444,6 +4478,52 @@ func TestProcessMessage_MessageToolPublishesOutboundWithTurnMetadata(t *testing. } } +func TestProcessMessage_MessageToolInheritsTelegramTopicWithExplicitChatID(t *testing.T) { + cfg := config.DefaultConfig() + cfg.Agents.Defaults.Workspace = t.TempDir() + cfg.Agents.Defaults.ModelName = "test-model" + cfg.Agents.Defaults.MaxTokens = 4096 + cfg.Agents.Defaults.MaxToolIterations = 10 + cfg.Session.Dimensions = []string{"chat"} + + msgBus := bus.NewMessageBus() + provider := &explicitChatMessageToolProvider{} + al := NewAgentLoop(cfg, msgBus, provider) + + response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{ + Context: bus.InboundContext{ + Channel: "telegram", + ChatID: "-1001234567890", + ChatType: "group", + TopicID: "6", + SenderID: "user-1", + MessageID: "475", + }, + Content: "send an interim message", + })) + if err != nil { + t.Fatalf("processMessage() error = %v", err) + } + if response == "" { + t.Fatal("expected processMessage() to return a final loop response") + } + + select { + case outbound := <-msgBus.OutboundChan(): + if outbound.Content != "topic tool message" { + t.Fatalf("outbound content = %q, want topic tool message", outbound.Content) + } + if outbound.Context.Channel != "telegram" || outbound.Context.ChatID != "-1001234567890" { + t.Fatalf("unexpected message tool outbound context: %+v", outbound.Context) + } + if outbound.Context.TopicID != "6" { + t.Fatalf("outbound topic = %q, want 6; context=%+v scope=%+v", outbound.Context.TopicID, outbound.Context, outbound.Scope) + } + case <-time.After(2 * time.Second): + t.Fatal("expected message tool outbound") + } +} + func TestRun_PicoPublishesAssistantContentDuringToolCallsWithoutFinalDuplicate(t *testing.T) { tmpDir := t.TempDir()