Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 20 additions & 17 deletions pkg/agent/agent_outbound.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,22 +25,15 @@ func (al *AgentLoop) maybePublishError(ctx context.Context, channel, chatID, ses
return true
}

func (al *AgentLoop) publishResponseOrError(
ctx context.Context,
channel, chatID, sessionKey string,
response string,
err error,
) {
if err != nil {
if !al.maybePublishError(ctx, channel, chatID, sessionKey, err) {
return
}
response = ""
}
al.PublishResponseIfNeeded(ctx, channel, chatID, sessionKey, response)
func (al *AgentLoop) PublishResponseIfNeeded(ctx context.Context, channel, chatID, sessionKey, response string) {
al.publishResponseWithContextIfNeeded(ctx, channel, chatID, sessionKey, response, nil)
}

func (al *AgentLoop) PublishResponseIfNeeded(ctx context.Context, channel, chatID, sessionKey, response string) {
func (al *AgentLoop) publishResponseWithContextIfNeeded(
ctx context.Context,
channel, chatID, sessionKey, response string,
inboundCtx *bus.InboundContext,
) {
if response == "" {
return
}
Expand All @@ -64,18 +57,28 @@ func (al *AgentLoop) PublishResponseIfNeeded(ctx context.Context, channel, chatI
return
}

agent := al.agentForSession(sessionKey)
agentID := ""
if agent != nil {
agentID = agent.ID
}
msg := bus.OutboundMessage{
Context: bus.NewOutboundContext(channel, chatID, ""),
Content: response,
Channel: channel,
ChatID: chatID,
Context: outboundContextFromInbound(inboundCtx, channel, chatID, ""),
AgentID: agentID,
SessionKey: sessionKey,
Content: response,
}
if sessionKey != "" {
msg.ContextUsage = computeContextUsage(al.agentForSession(sessionKey), sessionKey)
msg.ContextUsage = computeContextUsage(agent, sessionKey)
}
al.bus.PublishOutbound(ctx, msg)
logger.InfoCF("agent", "Published outbound response",
map[string]any{
"channel": channel,
"chat_id": chatID,
"topic_id": msg.Context.TopicID,
"content_len": len(response),
})
}
Expand Down
17 changes: 15 additions & 2 deletions pkg/agent/agent_steering.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,13 @@ func (al *AgentLoop) processMessageSync(ctx context.Context, msg bus.InboundMess
}

response, err := al.processMessage(ctx, msg)
al.publishResponseOrError(ctx, msg.Channel, msg.ChatID, msg.SessionKey, response, err)
if err != nil {
if !al.maybePublishError(ctx, msg.Channel, msg.ChatID, msg.SessionKey, err) {
return
}
response = ""
}
al.publishResponseWithContextIfNeeded(ctx, msg.Channel, msg.ChatID, msg.SessionKey, response, &msg.Context)
}

func (al *AgentLoop) runTurnWithSteering(ctx context.Context, initialMsg bus.InboundMessage) {
Expand Down Expand Up @@ -58,7 +64,14 @@ func (al *AgentLoop) runTurnWithSteering(ctx context.Context, initialMsg bus.Inb

// Publish final response
if finalResponse != "" {
al.PublishResponseIfNeeded(ctx, target.Channel, target.ChatID, target.SessionKey, finalResponse)
al.publishResponseWithContextIfNeeded(
ctx,
target.Channel,
target.ChatID,
target.SessionKey,
finalResponse,
&initialMsg.Context,
)
}
}

Expand Down
51 changes: 51 additions & 0 deletions pkg/agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2518,6 +2518,57 @@ func TestProcessMessage_UsesRouteSessionKey(t *testing.T) {
}
}

func TestProcessMessageSync_PreservesInboundTopicOnFinalResponse(t *testing.T) {
tmpDir := t.TempDir()

cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
ModelName: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}

msgBus := bus.NewMessageBus()
provider := &simpleMockProvider{response: "topic response"}
al := NewAgentLoop(cfg, msgBus, provider)

msg := testInboundMessage(bus.InboundMessage{
Context: bus.InboundContext{
Channel: "telegram",
ChatID: "-1001234567890",
ChatType: "group",
TopicID: "42",
SenderID: "user1",
MessageID: "123",
},
Content: "hello topic",
})

al.processMessageSync(context.Background(), msg)

select {
case outbound := <-msgBus.OutboundChan():
if outbound.Content != "topic response" {
t.Fatalf("outbound content = %q, want topic response", outbound.Content)
}
if outbound.Channel != "telegram" || outbound.ChatID != "-1001234567890" {
t.Fatalf("outbound route = %s/%s, want telegram/-1001234567890", outbound.Channel, outbound.ChatID)
}
if outbound.Context.TopicID != "42" {
t.Fatalf("outbound topic = %q, want 42; context=%+v", outbound.Context.TopicID, outbound.Context)
}
if outbound.Context.MessageID != "123" {
t.Fatalf("outbound context message ID = %q, want 123", outbound.Context.MessageID)
}
case <-time.After(responseTimeout):
t.Fatal("timed out waiting for outbound response")
}
}

func TestProcessMessage_CommandOutcomes(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
Expand Down