diff --git a/.changeset/fix-preemptive-tool-race.md b/.changeset/fix-preemptive-tool-race.md new file mode 100644 index 000000000..2f73187f9 --- /dev/null +++ b/.changeset/fix-preemptive-tool-race.md @@ -0,0 +1,5 @@ +--- +"@livekit/agents": patch +--- + +Fix LLM context corruption when user speaks during tool execution diff --git a/agents/src/voice/agent_activity.test.ts b/agents/src/voice/agent_activity.test.ts index 5f03fb164..d2c33cb12 100644 --- a/agents/src/voice/agent_activity.test.ts +++ b/agents/src/voice/agent_activity.test.ts @@ -16,8 +16,10 @@ */ import { Heap } from 'heap-js'; import { describe, expect, it, vi } from 'vitest'; +import { LLM } from '../llm/llm.js'; import { Future } from '../utils.js'; import { AgentActivity } from './agent_activity.js'; +import type { PreemptiveGenerationInfo } from './audio_recognition.js'; import { SpeechHandle } from './speech_handle.js'; // Break circular dependency: agent_activity.ts → agent.js → beta/workflows/task_group.ts @@ -81,7 +83,7 @@ function buildMainTaskRunner() { }, }; - const mainTask = (AgentActivity.prototype as Record).mainTask as ( + const mainTask = (AgentActivity.prototype as unknown as Record).mainTask as ( signal: AbortSignal, ) => Promise; @@ -93,6 +95,70 @@ function buildMainTaskRunner() { }; } +describe('AgentActivity - _toolExecutionInProgress guard', () => { + it('should block preemptive generation when tool execution is in progress', () => { + // onPreemptiveGeneration checks this._toolExecutionInProgress and early-returns. + // We verify the guard by calling the method on a minimal stub where all other + // guards pass but _toolExecutionInProgress is true. + const onPreemptiveGeneration = (AgentActivity.prototype as unknown as Record) + .onPreemptiveGeneration as (info: PreemptiveGenerationInfo) => void; + + const generateReplySpy = vi.fn(); + const fakeActivity = { + agentSession: { sessionOptions: { preemptiveGeneration: true } }, + schedulingPaused: false, + _currentSpeech: undefined, + _toolExecutionInProgress: true, + llm: Object.create(LLM.prototype), + _preemptiveGeneration: undefined, + cancelPreemptiveGeneration: vi.fn(), + generateReply: generateReplySpy, + agent: { chatCtx: { copy: () => ({ copy: () => ({}) }) } }, + tools: {}, + toolChoice: null, + logger: { info: vi.fn(), debug: vi.fn(), warn: vi.fn(), error: vi.fn() }, + }; + + onPreemptiveGeneration.call(fakeActivity, { + newTranscript: 'test transcript', + transcriptConfidence: 1.0, + } as PreemptiveGenerationInfo); + + expect(generateReplySpy).not.toHaveBeenCalled(); + expect(fakeActivity._preemptiveGeneration).toBeUndefined(); + }); + + it('should allow preemptive generation when no tool execution is in progress', () => { + const onPreemptiveGeneration = (AgentActivity.prototype as unknown as Record) + .onPreemptiveGeneration as (info: PreemptiveGenerationInfo) => void; + + const mockSpeechHandle = { id: 'test' }; + const generateReplySpy = vi.fn().mockReturnValue(mockSpeechHandle); + const fakeActivity = { + agentSession: { sessionOptions: { preemptiveGeneration: true } }, + schedulingPaused: false, + _currentSpeech: undefined, + _toolExecutionInProgress: false, + llm: Object.create(LLM.prototype), + _preemptiveGeneration: undefined, + cancelPreemptiveGeneration: vi.fn(), + generateReply: generateReplySpy, + agent: { chatCtx: { copy: () => ({ copy: () => ({}) }) } }, + tools: {}, + toolChoice: null, + logger: { info: vi.fn(), debug: vi.fn(), warn: vi.fn(), error: vi.fn() }, + }; + + onPreemptiveGeneration.call(fakeActivity, { + newTranscript: 'test transcript', + transcriptConfidence: 1.0, + } as PreemptiveGenerationInfo); + + expect(generateReplySpy).toHaveBeenCalledOnce(); + expect(fakeActivity._preemptiveGeneration).toBeDefined(); + }); +}); + describe('AgentActivity - mainTask', () => { it('should recover when speech handle is interrupted after authorization', async () => { const { fakeActivity, mainTask, speechQueue, q_updated } = buildMainTaskRunner(); diff --git a/agents/src/voice/agent_activity.ts b/agents/src/voice/agent_activity.ts index 8b0fa6b26..433e2a34d 100644 --- a/agents/src/voice/agent_activity.ts +++ b/agents/src/voice/agent_activity.ts @@ -129,6 +129,7 @@ export class AgentActivity implements RecognitionHooks { // default to null as None, which maps to the default provider tool choice value private toolChoice: ToolChoice | null = null; + private _toolExecutionInProgress = false; private _preemptiveGeneration?: PreemptiveGeneration; private interruptionDetector?: AdaptiveInterruptionDetector; private isInterruptionDetectionEnabled: boolean; @@ -1031,6 +1032,7 @@ export class AgentActivity implements RecognitionHooks { !this.agentSession.sessionOptions.preemptiveGeneration || this.schedulingPaused || (this._currentSpeech !== undefined && !this._currentSpeech.interrupted) || + this._toolExecutionInProgress || !(this.llm instanceof LLM) ) { return; @@ -1955,233 +1957,241 @@ export class AgentActivity implements RecognitionHooks { onToolExecutionStarted, onToolExecutionCompleted, }); + this._toolExecutionInProgress = true; + try { + await speechHandle.waitIfNotInterrupted(tasks.map((task) => task.result)); - await speechHandle.waitIfNotInterrupted(tasks.map((task) => task.result)); - - if (audioOutput) { - await speechHandle.waitIfNotInterrupted([audioOutput.waitForPlayout()]); - } + if (audioOutput) { + await speechHandle.waitIfNotInterrupted([audioOutput.waitForPlayout()]); + } - const agentStoppedSpeakingAt = Date.now(); - const assistantMetrics: MetricsReport = {}; + const agentStoppedSpeakingAt = Date.now(); + const assistantMetrics: MetricsReport = {}; - if (llmGenData.ttft !== undefined) { - assistantMetrics.llmNodeTtft = llmGenData.ttft; // already in seconds - } - if (ttsGenData?.ttfb !== undefined) { - assistantMetrics.ttsNodeTtfb = ttsGenData.ttfb; // already in seconds - } - if (agentStartedSpeakingAt !== undefined) { - assistantMetrics.startedSpeakingAt = agentStartedSpeakingAt / 1000; // ms -> seconds - assistantMetrics.stoppedSpeakingAt = agentStoppedSpeakingAt / 1000; // ms -> seconds + if (llmGenData.ttft !== undefined) { + assistantMetrics.llmNodeTtft = llmGenData.ttft; // already in seconds + } + if (ttsGenData?.ttfb !== undefined) { + assistantMetrics.ttsNodeTtfb = ttsGenData.ttfb; // already in seconds + } + if (agentStartedSpeakingAt !== undefined) { + assistantMetrics.startedSpeakingAt = agentStartedSpeakingAt / 1000; // ms -> seconds + assistantMetrics.stoppedSpeakingAt = agentStoppedSpeakingAt / 1000; // ms -> seconds - if (userMetrics?.stoppedSpeakingAt !== undefined) { - const e2eLatency = agentStartedSpeakingAt / 1000 - userMetrics.stoppedSpeakingAt; - assistantMetrics.e2eLatency = e2eLatency; - span.setAttribute(traceTypes.ATTR_E2E_LATENCY, e2eLatency); + if (userMetrics?.stoppedSpeakingAt !== undefined) { + const e2eLatency = agentStartedSpeakingAt / 1000 - userMetrics.stoppedSpeakingAt; + assistantMetrics.e2eLatency = e2eLatency; + span.setAttribute(traceTypes.ATTR_E2E_LATENCY, e2eLatency); + } } - } - span.setAttribute(traceTypes.ATTR_SPEECH_INTERRUPTED, speechHandle.interrupted); - let hasSpeechMessage = false; + span.setAttribute(traceTypes.ATTR_SPEECH_INTERRUPTED, speechHandle.interrupted); + let hasSpeechMessage = false; - // add the tools messages that triggers this reply to the chat context - if (toolsMessages) { - for (const msg of toolsMessages) { - msg.createdAt = replyStartedAt; - } - // Only insert FunctionCallOutput items into agent._chatCtx since FunctionCall items - // were already added by onToolExecutionStarted when the tool execution began. - // Inserting function_calls again would create duplicates that break provider APIs - // (e.g. Google's "function response parts != function call parts" error). - const toolCallOutputs = toolsMessages.filter( - (m): m is FunctionCallOutput => m.type === 'function_call_output', - ); - if (toolCallOutputs.length > 0) { - this.agent._chatCtx.insert(toolCallOutputs); - this.agentSession._toolItemsAdded(toolCallOutputs); + // add the tools messages that triggers this reply to the chat context + if (toolsMessages) { + for (const msg of toolsMessages) { + msg.createdAt = replyStartedAt; + } + // Only insert FunctionCallOutput items into agent._chatCtx since FunctionCall items + // were already added by onToolExecutionStarted when the tool execution began. + // Inserting function_calls again would create duplicates that break provider APIs + // (e.g. Google's "function response parts != function call parts" error). + const toolCallOutputs = toolsMessages.filter( + (m): m is FunctionCallOutput => m.type === 'function_call_output', + ); + if (toolCallOutputs.length > 0) { + this.agent._chatCtx.insert(toolCallOutputs); + this.agentSession._toolItemsAdded(toolCallOutputs); + } } - } - if (speechHandle.interrupted) { - this.logger.debug( - { speech_id: speechHandle.id }, - 'Aborting all pipeline reply tasks due to interruption', - ); + if (speechHandle.interrupted) { + this.logger.debug( + { speech_id: speechHandle.id }, + 'Aborting all pipeline reply tasks due to interruption', + ); - // Stop playout ASAP (don't wait for cancellations), otherwise the segment may finish and we - // will correctly (but undesirably) commit a long transcript even though the user said "stop". - if (audioOutput) { - audioOutput.clearBuffer(); - } + // Stop playout ASAP (don't wait for cancellations), otherwise the segment may finish and we + // will correctly (but undesirably) commit a long transcript even though the user said "stop". + if (audioOutput) { + audioOutput.clearBuffer(); + } - replyAbortController.abort(); - await cancelAndWait(tasks, AgentActivity.REPLY_TASK_CANCEL_TIMEOUT); + replyAbortController.abort(); + await cancelAndWait(tasks, AgentActivity.REPLY_TASK_CANCEL_TIMEOUT); - let forwardedText = textOut?.text || ''; + let forwardedText = textOut?.text || ''; - if (audioOutput) { - const playbackEv = await audioOutput.waitForPlayout(); - if (audioOut?.firstFrameFut.done && !audioOut.firstFrameFut.rejected) { - // playback EV is valid only if the first frame was already played - this.logger.info( - { speech_id: speechHandle.id, playbackPositionInS: playbackEv.playbackPosition }, - 'playout interrupted', - ); - if (playbackEv.synchronizedTranscript) { - forwardedText = playbackEv.synchronizedTranscript; + if (audioOutput) { + const playbackEv = await audioOutput.waitForPlayout(); + if (audioOut?.firstFrameFut.done && !audioOut.firstFrameFut.rejected) { + // playback EV is valid only if the first frame was already played + this.logger.info( + { speech_id: speechHandle.id, playbackPositionInS: playbackEv.playbackPosition }, + 'playout interrupted', + ); + if (playbackEv.synchronizedTranscript) { + forwardedText = playbackEv.synchronizedTranscript; + } + } else { + forwardedText = ''; + } + } + + if (forwardedText) { + hasSpeechMessage = true; + const message = ChatMessage.create({ + role: 'assistant', + content: forwardedText, + id: llmGenData.id, + interrupted: true, + createdAt: replyStartedAt, + metrics: assistantMetrics, + }); + chatCtx.insert(message); + this.agent._chatCtx.insert(message); + speechHandle._itemAdded([message]); + this.agentSession._conversationItemAdded(message); + span.setAttribute(traceTypes.ATTR_RESPONSE_TEXT, forwardedText); + } + + if (this.agentSession.agentState === 'speaking') { + this.agentSession._updateAgentState('listening'); + if (this.isInterruptionDetectionEnabled && this.audioRecognition) { + this.audioRecognition.onEndOfAgentSpeech(Date.now()); + this.restoreInterruptionByAudioActivity(); } - } else { - forwardedText = ''; } + + this.logger.info( + { speech_id: speechHandle.id, message: forwardedText }, + 'playout completed with interrupt', + ); + speechHandle._markGenerationDone(); + await executeToolsTask.cancelAndWait(AgentActivity.REPLY_TASK_CANCEL_TIMEOUT); + return; } - if (forwardedText) { + if (textOut && textOut.text) { hasSpeechMessage = true; const message = ChatMessage.create({ role: 'assistant', - content: forwardedText, id: llmGenData.id, - interrupted: true, + interrupted: false, createdAt: replyStartedAt, + content: textOut.text, metrics: assistantMetrics, }); chatCtx.insert(message); this.agent._chatCtx.insert(message); speechHandle._itemAdded([message]); this.agentSession._conversationItemAdded(message); - span.setAttribute(traceTypes.ATTR_RESPONSE_TEXT, forwardedText); + span.setAttribute(traceTypes.ATTR_RESPONSE_TEXT, textOut.text); + this.logger.info( + { speech_id: speechHandle.id, message: textOut.text }, + 'playout completed without interruption', + ); } - if (this.agentSession.agentState === 'speaking') { + if (toolOutput.output.length > 0) { + this.agentSession._updateAgentState('thinking'); + } else if (this.agentSession.agentState === 'speaking') { this.agentSession._updateAgentState('listening'); if (this.isInterruptionDetectionEnabled && this.audioRecognition) { - this.audioRecognition.onEndOfAgentSpeech(Date.now()); - this.restoreInterruptionByAudioActivity(); + { + this.audioRecognition.onEndOfAgentSpeech(Date.now()); + this.restoreInterruptionByAudioActivity(); + } } } - this.logger.info( - { speech_id: speechHandle.id, message: forwardedText }, - 'playout completed with interrupt', - ); + // mark the playout done before waiting for the tool execution speechHandle._markGenerationDone(); - await executeToolsTask.cancelAndWait(AgentActivity.REPLY_TASK_CANCEL_TIMEOUT); - return; - } + await executeToolsTask.result; - if (textOut && textOut.text) { - hasSpeechMessage = true; - const message = ChatMessage.create({ - role: 'assistant', - id: llmGenData.id, - interrupted: false, - createdAt: replyStartedAt, - content: textOut.text, - metrics: assistantMetrics, - }); - chatCtx.insert(message); - this.agent._chatCtx.insert(message); - speechHandle._itemAdded([message]); - this.agentSession._conversationItemAdded(message); - span.setAttribute(traceTypes.ATTR_RESPONSE_TEXT, textOut.text); - this.logger.info( - { speech_id: speechHandle.id, message: textOut.text }, - 'playout completed without interruption', - ); - } + if (toolOutput.output.length === 0) return; - if (toolOutput.output.length > 0) { - this.agentSession._updateAgentState('thinking'); - } else if (this.agentSession.agentState === 'speaking') { - this.agentSession._updateAgentState('listening'); - if (this.isInterruptionDetectionEnabled && this.audioRecognition) { - { - this.audioRecognition.onEndOfAgentSpeech(Date.now()); - this.restoreInterruptionByAudioActivity(); - } + // important: no agent output should be used after this point + const { maxToolSteps } = this.agentSession.sessionOptions; + if (speechHandle.numSteps >= maxToolSteps) { + this.logger.warn( + { speech_id: speechHandle.id, max_tool_steps: maxToolSteps }, + 'maximum number of function calls steps reached', + ); + return; } - } - - // mark the playout done before waiting for the tool execution - speechHandle._markGenerationDone(); - await executeToolsTask.result; - if (toolOutput.output.length === 0) return; + const { + functionToolsExecutedEvent, + shouldGenerateToolReply, + newAgentTask, + ignoreTaskSwitch, + } = this.summarizeToolExecutionOutput(toolOutput, speechHandle); - // important: no agent output should be used after this point - const { maxToolSteps } = this.agentSession.sessionOptions; - if (speechHandle.numSteps >= maxToolSteps) { - this.logger.warn( - { speech_id: speechHandle.id, max_tool_steps: maxToolSteps }, - 'maximum number of function calls steps reached', + this.agentSession.emit( + AgentSessionEventTypes.FunctionToolsExecuted, + functionToolsExecutedEvent, ); - return; - } - - const { functionToolsExecutedEvent, shouldGenerateToolReply, newAgentTask, ignoreTaskSwitch } = - this.summarizeToolExecutionOutput(toolOutput, speechHandle); - - this.agentSession.emit( - AgentSessionEventTypes.FunctionToolsExecuted, - functionToolsExecutedEvent, - ); - - let schedulingPaused = this.schedulingPaused; - if (!ignoreTaskSwitch && newAgentTask !== null) { - this.agentSession.updateAgent(newAgentTask); - schedulingPaused = true; - } - - const toolMessages = [ - ...functionToolsExecutedEvent.functionCalls, - ...functionToolsExecutedEvent.functionCallOutputs, - ] as ChatItem[]; - if (shouldGenerateToolReply) { - chatCtx.insert(toolMessages); - - // Increment step count on SAME handle (parity with Python agent_activity.py L2081) - speechHandle._numSteps += 1; - - // Avoid setting tool_choice to "required" or a specific function when - // passing tool response back to the LLM - const respondToolChoice = - schedulingPaused || modelSettings.toolChoice === 'none' ? 'none' : 'auto'; - // Reuse same speechHandle for tool response (parity with Python agent_activity.py L2122-2140) - const toolResponseTask = this.createSpeechTask({ - taskFn: () => - this.pipelineReplyTask( - speechHandle, - chatCtx, - toolCtx, - { toolChoice: respondToolChoice }, - replyAbortController, - instructions, - undefined, - toolMessages, - hasSpeechMessage ? undefined : userMetrics, - ), - ownedSpeechHandle: speechHandle, - name: 'AgentActivity.pipelineReply', - }); + let schedulingPaused = this.schedulingPaused; + if (!ignoreTaskSwitch && newAgentTask !== null) { + this.agentSession.updateAgent(newAgentTask); + schedulingPaused = true; + } + + const toolMessages = [ + ...functionToolsExecutedEvent.functionCalls, + ...functionToolsExecutedEvent.functionCallOutputs, + ] as ChatItem[]; + if (shouldGenerateToolReply) { + chatCtx.insert(toolMessages); + + // Increment step count on SAME handle (parity with Python agent_activity.py L2081) + speechHandle._numSteps += 1; + + // Avoid setting tool_choice to "required" or a specific function when + // passing tool response back to the LLM + const respondToolChoice = + schedulingPaused || modelSettings.toolChoice === 'none' ? 'none' : 'auto'; + + // Reuse same speechHandle for tool response (parity with Python agent_activity.py L2122-2140) + const toolResponseTask = this.createSpeechTask({ + taskFn: () => + this.pipelineReplyTask( + speechHandle, + chatCtx, + toolCtx, + { toolChoice: respondToolChoice }, + replyAbortController, + instructions, + undefined, + toolMessages, + hasSpeechMessage ? undefined : userMetrics, + ), + ownedSpeechHandle: speechHandle, + name: 'AgentActivity.pipelineReply', + }); - toolResponseTask.result.finally(() => this.onPipelineReplyDone()); + toolResponseTask.result.finally(() => this.onPipelineReplyDone()); - this.scheduleSpeech(speechHandle, SpeechHandle.SPEECH_PRIORITY_NORMAL, true); - } else if (functionToolsExecutedEvent.functionCallOutputs.length > 0) { - for (const msg of toolMessages) { - msg.createdAt = replyStartedAt; - } + this.scheduleSpeech(speechHandle, SpeechHandle.SPEECH_PRIORITY_NORMAL, true); + } else if (functionToolsExecutedEvent.functionCallOutputs.length > 0) { + for (const msg of toolMessages) { + msg.createdAt = replyStartedAt; + } - const toolCallOutputs = toolMessages.filter( - (m): m is FunctionCallOutput => m.type === 'function_call_output', - ); + const toolCallOutputs = toolMessages.filter( + (m): m is FunctionCallOutput => m.type === 'function_call_output', + ); - if (toolCallOutputs.length > 0) { - this.agent._chatCtx.insert(toolCallOutputs); - this.agentSession._toolItemsAdded(toolCallOutputs); + if (toolCallOutputs.length > 0) { + this.agent._chatCtx.insert(toolCallOutputs); + this.agentSession._toolItemsAdded(toolCallOutputs); + } } + } finally { + this._toolExecutionInProgress = false; } }; @@ -2471,201 +2481,209 @@ export class AgentActivity implements RecognitionHooks { onToolExecutionStarted, onToolExecutionCompleted, }); + this._toolExecutionInProgress = true; + try { + await speechHandle.waitIfNotInterrupted(tasks.map((task) => task.result)); - await speechHandle.waitIfNotInterrupted(tasks.map((task) => task.result)); - - // TODO(brian): add tracing span + // TODO(brian): add tracing span - if (audioOutput) { - await speechHandle.waitIfNotInterrupted([audioOutput.waitForPlayout()]); - } + if (audioOutput) { + await speechHandle.waitIfNotInterrupted([audioOutput.waitForPlayout()]); + } - if (speechHandle.interrupted) { - this.logger.debug( - { speech_id: speechHandle.id }, - 'Aborting all realtime generation tasks due to interruption', - ); - replyAbortController.abort(); - await cancelAndWait(tasks, AgentActivity.REPLY_TASK_CANCEL_TIMEOUT); + if (speechHandle.interrupted) { + this.logger.debug( + { speech_id: speechHandle.id }, + 'Aborting all realtime generation tasks due to interruption', + ); + replyAbortController.abort(); + await cancelAndWait(tasks, AgentActivity.REPLY_TASK_CANCEL_TIMEOUT); - if (messageOutputs.length > 0) { - // there should be only one message - const [msgId, textOut, audioOut, msgModalities] = messageOutputs[0]!; - let forwardedText = textOut?.text || ''; + if (messageOutputs.length > 0) { + // there should be only one message + const [msgId, textOut, audioOut, msgModalities] = messageOutputs[0]!; + let forwardedText = textOut?.text || ''; - if (audioOutput) { - audioOutput.clearBuffer(); - const playbackEv = await audioOutput.waitForPlayout(); - let playbackPositionInS = playbackEv.playbackPosition; - if (audioOut?.firstFrameFut.done && !audioOut.firstFrameFut.rejected) { - // playback EV is valid only if the first frame was already played - this.logger.info( - { speech_id: speechHandle.id, playbackPositionInS }, - 'playout interrupted', - ); - if (playbackEv.synchronizedTranscript) { - forwardedText = playbackEv.synchronizedTranscript; + if (audioOutput) { + audioOutput.clearBuffer(); + const playbackEv = await audioOutput.waitForPlayout(); + let playbackPositionInS = playbackEv.playbackPosition; + if (audioOut?.firstFrameFut.done && !audioOut.firstFrameFut.rejected) { + // playback EV is valid only if the first frame was already played + this.logger.info( + { speech_id: speechHandle.id, playbackPositionInS }, + 'playout interrupted', + ); + if (playbackEv.synchronizedTranscript) { + forwardedText = playbackEv.synchronizedTranscript; + } + } else { + forwardedText = ''; + playbackPositionInS = 0; } - } else { - forwardedText = ''; - playbackPositionInS = 0; + + // truncate server-side message + this.realtimeSession.truncate({ + messageId: msgId, + audioEndMs: Math.floor(playbackPositionInS * 1000), + modalities: msgModalities, + audioTranscript: forwardedText, + }); } - // truncate server-side message - this.realtimeSession.truncate({ - messageId: msgId, - audioEndMs: Math.floor(playbackPositionInS * 1000), - modalities: msgModalities, - audioTranscript: forwardedText, - }); + if (forwardedText) { + const message = ChatMessage.create({ + role: 'assistant', + content: forwardedText, + id: msgId, + interrupted: true, + }); + this.agent._chatCtx.insert(message); + speechHandle._itemAdded([message]); + this.agentSession._conversationItemAdded(message); + + // TODO(brian): add tracing span + } + this.logger.info( + { speech_id: speechHandle.id, message: forwardedText }, + 'playout completed with interrupt', + ); } + speechHandle._markGenerationDone(); + await executeToolsTask.cancelAndWait(AgentActivity.REPLY_TASK_CANCEL_TIMEOUT); - if (forwardedText) { - const message = ChatMessage.create({ - role: 'assistant', - content: forwardedText, - id: msgId, - interrupted: true, - }); - this.agent._chatCtx.insert(message); - speechHandle._itemAdded([message]); - this.agentSession._conversationItemAdded(message); + // TODO(brian): close tees + return; + } - // TODO(brian): add tracing span - } - this.logger.info( - { speech_id: speechHandle.id, message: forwardedText }, - 'playout completed with interrupt', - ); + if (messageOutputs.length > 0) { + // there should be only one message + const [msgId, textOut, _, __] = messageOutputs[0]!; + const message = ChatMessage.create({ + role: 'assistant', + content: textOut?.text || '', + id: msgId, + interrupted: false, + }); + this.agent._chatCtx.insert(message); + speechHandle._itemAdded([message]); + this.agentSession._conversationItemAdded(message); // mark the playout done before waiting for the tool execution\ + // TODO(brian): add tracing span } - speechHandle._markGenerationDone(); - await executeToolsTask.cancelAndWait(AgentActivity.REPLY_TASK_CANCEL_TIMEOUT); + // mark the playout done before waiting for the tool execution + speechHandle._markGenerationDone(); // TODO(brian): close tees - return; - } - if (messageOutputs.length > 0) { - // there should be only one message - const [msgId, textOut, _, __] = messageOutputs[0]!; - const message = ChatMessage.create({ - role: 'assistant', - content: textOut?.text || '', - id: msgId, - interrupted: false, - }); - this.agent._chatCtx.insert(message); - speechHandle._itemAdded([message]); - this.agentSession._conversationItemAdded(message); // mark the playout done before waiting for the tool execution\ - // TODO(brian): add tracing span - } + await executeToolsTask.result; - // mark the playout done before waiting for the tool execution - speechHandle._markGenerationDone(); - // TODO(brian): close tees + if (toolOutput.output.length > 0) { + this.agentSession._updateAgentState('thinking'); + } else if (this.agentSession.agentState === 'speaking') { + this.agentSession._updateAgentState('listening'); + } - await executeToolsTask.result; + if (toolOutput.output.length === 0) { + return; + } - if (toolOutput.output.length > 0) { - this.agentSession._updateAgentState('thinking'); - } else if (this.agentSession.agentState === 'speaking') { - this.agentSession._updateAgentState('listening'); - } + // important: no agent ouput should be used after this point + const { maxToolSteps } = this.agentSession.sessionOptions; + if (speechHandle.numSteps >= maxToolSteps) { + this.logger.warn( + { speech_id: speechHandle.id, max_tool_steps: maxToolSteps }, + 'maximum number of function calls steps reached', + ); + return; + } - if (toolOutput.output.length === 0) { - return; - } + const { + functionToolsExecutedEvent, + shouldGenerateToolReply, + newAgentTask, + ignoreTaskSwitch, + } = this.summarizeToolExecutionOutput(toolOutput, speechHandle); - // important: no agent ouput should be used after this point - const { maxToolSteps } = this.agentSession.sessionOptions; - if (speechHandle.numSteps >= maxToolSteps) { - this.logger.warn( - { speech_id: speechHandle.id, max_tool_steps: maxToolSteps }, - 'maximum number of function calls steps reached', + this.agentSession.emit( + AgentSessionEventTypes.FunctionToolsExecuted, + functionToolsExecutedEvent, ); - return; - } - const { functionToolsExecutedEvent, shouldGenerateToolReply, newAgentTask, ignoreTaskSwitch } = - this.summarizeToolExecutionOutput(toolOutput, speechHandle); + let schedulingPaused = this.schedulingPaused; + if (!ignoreTaskSwitch && newAgentTask !== null) { + this.agentSession.updateAgent(newAgentTask); + schedulingPaused = true; + } + + if (functionToolsExecutedEvent.functionCallOutputs.length > 0) { + // wait all speeches played before updating the tool output and generating the response + // most realtime models dont support generating multiple responses at the same time + while (this.currentSpeech || this.speechQueue.size() > 0) { + if ( + this.currentSpeech && + !this.currentSpeech.done() && + this.currentSpeech !== speechHandle + ) { + await this.currentSpeech.waitForPlayout(); + } else { + // Don't block the event loop + await new Promise((resolve) => setImmediate(resolve)); + } + } + const chatCtx = this.realtimeSession.chatCtx.copy(); + chatCtx.items.push(...functionToolsExecutedEvent.functionCallOutputs); - this.agentSession.emit( - AgentSessionEventTypes.FunctionToolsExecuted, - functionToolsExecutedEvent, - ); + this.agentSession._toolItemsAdded( + functionToolsExecutedEvent.functionCallOutputs as FunctionCallOutput[], + ); - let schedulingPaused = this.schedulingPaused; - if (!ignoreTaskSwitch && newAgentTask !== null) { - this.agentSession.updateAgent(newAgentTask); - schedulingPaused = true; - } - - if (functionToolsExecutedEvent.functionCallOutputs.length > 0) { - // wait all speeches played before updating the tool output and generating the response - // most realtime models dont support generating multiple responses at the same time - while (this.currentSpeech || this.speechQueue.size() > 0) { - if ( - this.currentSpeech && - !this.currentSpeech.done() && - this.currentSpeech !== speechHandle - ) { - await this.currentSpeech.waitForPlayout(); - } else { - // Don't block the event loop - await new Promise((resolve) => setImmediate(resolve)); + try { + await this.realtimeSession.updateChatCtx(chatCtx); + } catch (error) { + this.logger.warn( + { error }, + 'failed to update chat context before generating the function calls results', + ); } } - const chatCtx = this.realtimeSession.chatCtx.copy(); - chatCtx.items.push(...functionToolsExecutedEvent.functionCallOutputs); - - this.agentSession._toolItemsAdded( - functionToolsExecutedEvent.functionCallOutputs as FunctionCallOutput[], - ); - try { - await this.realtimeSession.updateChatCtx(chatCtx); - } catch (error) { - this.logger.warn( - { error }, - 'failed to update chat context before generating the function calls results', - ); + // skip realtime reply if not required or auto-generated + if (!shouldGenerateToolReply || this.llm.capabilities.autoToolReplyGeneration) { + return; } - } - - // skip realtime reply if not required or auto-generated - if (!shouldGenerateToolReply || this.llm.capabilities.autoToolReplyGeneration) { - return; - } - this.realtimeSession.interrupt(); - - const replySpeechHandle = SpeechHandle.create({ - allowInterruptions: speechHandle.allowInterruptions, - stepIndex: speechHandle.numSteps + 1, - parent: speechHandle, - }); - this.agentSession.emit( - AgentSessionEventTypes.SpeechCreated, - createSpeechCreatedEvent({ - userInitiated: false, - source: 'tool_response', - speechHandle: replySpeechHandle, - }), - ); + this.realtimeSession.interrupt(); - const toolChoice = schedulingPaused || modelSettings.toolChoice === 'none' ? 'none' : 'auto'; - this.createSpeechTask({ - taskFn: (abortController: AbortController) => - this.realtimeReplyTask({ + const replySpeechHandle = SpeechHandle.create({ + allowInterruptions: speechHandle.allowInterruptions, + stepIndex: speechHandle.numSteps + 1, + parent: speechHandle, + }); + this.agentSession.emit( + AgentSessionEventTypes.SpeechCreated, + createSpeechCreatedEvent({ + userInitiated: false, + source: 'tool_response', speechHandle: replySpeechHandle, - modelSettings: { toolChoice }, - abortController, }), - ownedSpeechHandle: replySpeechHandle, - name: 'AgentActivity.realtime_reply', - }); + ); - this.scheduleSpeech(replySpeechHandle, SpeechHandle.SPEECH_PRIORITY_NORMAL, true); + const toolChoice = schedulingPaused || modelSettings.toolChoice === 'none' ? 'none' : 'auto'; + this.createSpeechTask({ + taskFn: (abortController: AbortController) => + this.realtimeReplyTask({ + speechHandle: replySpeechHandle, + modelSettings: { toolChoice }, + abortController, + }), + ownedSpeechHandle: replySpeechHandle, + name: 'AgentActivity.realtime_reply', + }); + + this.scheduleSpeech(replySpeechHandle, SpeechHandle.SPEECH_PRIORITY_NORMAL, true); + } finally { + this._toolExecutionInProgress = false; + } } private summarizeToolExecutionOutput(toolOutput: ToolOutput, speechHandle: SpeechHandle) {