diff --git a/.changeset/tasty-pillows-serve.md b/.changeset/tasty-pillows-serve.md new file mode 100644 index 000000000..d7c1f0d31 --- /dev/null +++ b/.changeset/tasty-pillows-serve.md @@ -0,0 +1,5 @@ +--- +"@livekit/agents": patch +--- + +Reuse STT Pipeline Across Agent Handoff diff --git a/.gitignore b/.gitignore index e2f0e071a..07011cc16 100644 --- a/.gitignore +++ b/.gitignore @@ -203,3 +203,4 @@ examples/src/test_*.ts # OpenTelemetry trace test output .traces/ *.traces.json +.worktrees/ diff --git a/agents/src/utils.ts b/agents/src/utils.ts index a3d828599..414f3ac96 100644 --- a/agents/src/utils.ts +++ b/agents/src/utils.ts @@ -962,6 +962,41 @@ export async function waitForTrackPublication({ } } +/** + * Yields values from a ReadableStream until the stream ends or the signal is aborted. + * Handles reader cleanup and stream-release errors internally. + */ +export async function* readStream( + stream: ReadableStream, + signal?: AbortSignal, +): AsyncGenerator { + const reader = stream.getReader(); + try { + if (signal) { + const abortPromise = waitForAbort(signal); + while (true) { + const result = await Promise.race([reader.read(), abortPromise]); + if (!result) break; + const { done, value } = result; + if (done) break; + yield value; + } + } else { + while (true) { + const { done, value } = await reader.read(); + if (done) break; + yield value; + } + } + } finally { + try { + reader.releaseLock(); + } catch { + // stream cleanup errors are expected (releasing reader, controller closed, etc.) + } + } +} + export async function waitForAbort(signal: AbortSignal) { const abortFuture = new Future(); const handler = () => { diff --git a/agents/src/voice/agent_activity.ts b/agents/src/voice/agent_activity.ts index bedebf3e2..e672b9e3f 100644 --- a/agents/src/voice/agent_activity.ts +++ b/agents/src/voice/agent_activity.ts @@ -61,6 +61,7 @@ import { type EndOfTurnInfo, type PreemptiveGenerationInfo, type RecognitionHooks, + type STTPipeline, } from './audio_recognition.js'; import { AgentSessionEventTypes, @@ -292,19 +293,27 @@ export class AgentActivity implements RecognitionHooks { this.isDefaultInterruptionByAudioActivityEnabled = this.isInterruptionByAudioActivityEnabled; } - async start(): Promise { + async start(options?: { reuseSttPipeline?: STTPipeline }): Promise { const unlock = await this.lock.lock(); try { - await this._startSession({ spanName: 'start_agent_activity', runOnEnter: true }); + await this._startSession({ + spanName: 'start_agent_activity', + runOnEnter: true, + reuseSttPipeline: options?.reuseSttPipeline, + }); } finally { unlock(); } } - async resume(): Promise { + async resume(options?: { reuseSttPipeline?: STTPipeline }): Promise { const unlock = await this.lock.lock(); try { - await this._startSession({ spanName: 'resume_agent_activity', runOnEnter: false }); + await this._startSession({ + spanName: 'resume_agent_activity', + runOnEnter: false, + reuseSttPipeline: options?.reuseSttPipeline, + }); } finally { unlock(); } @@ -313,8 +322,9 @@ export class AgentActivity implements RecognitionHooks { private async _startSession(options: { spanName: 'start_agent_activity' | 'resume_agent_activity'; runOnEnter: boolean; + reuseSttPipeline?: STTPipeline; }): Promise { - const { spanName, runOnEnter } = options; + const { spanName, runOnEnter, reuseSttPipeline } = options; const startSpan = tracer.startSpan({ name: spanName, attributes: { [traceTypes.ATTR_AGENT_LABEL]: this.agent.id }, @@ -415,9 +425,15 @@ export class AgentActivity implements RecognitionHooks { sttProvider: this.getSttProvider(), getLinkedParticipant: () => this.agentSession._roomIO?.linkedParticipant, }); - this.audioRecognition.start(); - this.started = true; + if (reuseSttPipeline) { + this.logger.debug('Reusing STT pipeline from previous activity'); + await this.audioRecognition.start({ sttPipeline: reuseSttPipeline }); + } else { + await this.audioRecognition.start(); + } + + this.started = true; this._resumeSchedulingTask(); if (runOnEnter) { @@ -438,6 +454,30 @@ export class AgentActivity implements RecognitionHooks { startSpan.end(); } + async _detachSttPipelineIfReusable(newActivity: AgentActivity): Promise { + const hasAudioRecognition = !!this.audioRecognition; + const hasSttOld = !!this.stt; + const hasSttNew = !!newActivity.stt; + const sameSttInstance = this.stt === newActivity.stt; + const sameSttNode = + Object.getPrototypeOf(this.agent).sttNode === + Object.getPrototypeOf(newActivity.agent).sttNode; + + if (!hasAudioRecognition || !hasSttOld || !hasSttNew) { + return undefined; + } + + if (!sameSttInstance) { + return undefined; + } + + if (!sameSttNode) { + return undefined; + } + + return await this.audioRecognition!.detachSttPipeline(); + } + get currentSpeech(): SpeechHandle | undefined { return this._currentSpeech; } diff --git a/agents/src/voice/agent_activity_handoff.test.ts b/agents/src/voice/agent_activity_handoff.test.ts new file mode 100644 index 000000000..e27fc507a --- /dev/null +++ b/agents/src/voice/agent_activity_handoff.test.ts @@ -0,0 +1,123 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +import type { AudioFrame } from '@livekit/rtc-node'; +import { describe, expect, it, vi } from 'vitest'; +import { type SpeechEvent } from '../stt/stt.js'; +import { Agent } from './agent.js'; +import { AgentActivity } from './agent_activity.js'; + +type FakeActivity = { + agent: Agent; + audioRecognition: { detachSttPipeline: ReturnType } | undefined; + stt: unknown; +}; + +function createFakeActivity(agent: Agent, stt: unknown) { + const detachedPipeline = { id: Symbol('pipeline') }; + const activity = { + agent, + audioRecognition: { + detachSttPipeline: vi.fn(async () => detachedPipeline), + }, + stt, + } as FakeActivity; + + return { activity, detachedPipeline }; +} + +async function detachIfReusable(oldActivity: FakeActivity, newActivity: FakeActivity) { + return await (AgentActivity.prototype as any)._detachSttPipelineIfReusable.call( + oldActivity, + newActivity, + ); +} + +describe('AgentActivity STT handoff reuse eligibility', () => { + it('reuses the pipeline when both activities share the same STT instance and sttNode', async () => { + const sharedStt = { id: 'shared-stt' }; + const oldActivity = createFakeActivity(new Agent({ instructions: 'a' }), sharedStt); + const newActivity = createFakeActivity(new Agent({ instructions: 'b' }), sharedStt); + + const result = await detachIfReusable(oldActivity.activity, newActivity.activity); + + expect(result).toBe(oldActivity.detachedPipeline); + expect(oldActivity.activity.audioRecognition?.detachSttPipeline).toHaveBeenCalledTimes(1); + }); + + it('does not reuse when the STT instances differ', async () => { + const oldActivity = createFakeActivity(new Agent({ instructions: 'a' }), { id: 'stt-a' }); + const newActivity = createFakeActivity(new Agent({ instructions: 'b' }), { id: 'stt-b' }); + + const result = await detachIfReusable(oldActivity.activity, newActivity.activity); + + expect(result).toBeUndefined(); + expect(oldActivity.activity.audioRecognition?.detachSttPipeline).not.toHaveBeenCalled(); + }); + + it('does not reuse when either activity has no STT', async () => { + const sharedStt = { id: 'shared-stt' }; + const oldActivity = createFakeActivity(new Agent({ instructions: 'a' }), undefined); + const newActivity = createFakeActivity(new Agent({ instructions: 'b' }), sharedStt); + + const result = await detachIfReusable(oldActivity.activity, newActivity.activity); + + expect(result).toBeUndefined(); + expect(oldActivity.activity.audioRecognition?.detachSttPipeline).not.toHaveBeenCalled(); + }); + + it('does not reuse when the agents override sttNode differently', async () => { + const sharedStt = { id: 'shared-stt' }; + + class AgentA extends Agent { + async sttNode(_audio: ReadableStream, _modelSettings: any) { + return null as ReadableStream | null; + } + } + + class AgentB extends Agent { + async sttNode(_audio: ReadableStream, _modelSettings: any) { + return null as ReadableStream | null; + } + } + + const oldActivity = createFakeActivity(new AgentA({ instructions: 'a' }), sharedStt); + const newActivity = createFakeActivity(new AgentB({ instructions: 'b' }), sharedStt); + + const result = await detachIfReusable(oldActivity.activity, newActivity.activity); + + expect(result).toBeUndefined(); + expect(oldActivity.activity.audioRecognition?.detachSttPipeline).not.toHaveBeenCalled(); + }); + + it('reuses when the new agent inherits the same sttNode implementation', async () => { + const sharedStt = { id: 'shared-stt' }; + + class AgentA extends Agent { + async sttNode(_audio: ReadableStream, _modelSettings: any) { + return null as ReadableStream | null; + } + } + + class AgentB extends AgentA {} + + const oldActivity = createFakeActivity(new AgentA({ instructions: 'a' }), sharedStt); + const newActivity = createFakeActivity(new AgentB({ instructions: 'b' }), sharedStt); + + const result = await detachIfReusable(oldActivity.activity, newActivity.activity); + + expect(result).toBe(oldActivity.detachedPipeline); + expect(oldActivity.activity.audioRecognition?.detachSttPipeline).toHaveBeenCalledTimes(1); + }); + + it('does not reuse when the old activity has no audioRecognition', async () => { + const sharedStt = { id: 'shared-stt' }; + const oldActivity = createFakeActivity(new Agent({ instructions: 'a' }), sharedStt); + const newActivity = createFakeActivity(new Agent({ instructions: 'b' }), sharedStt); + oldActivity.activity.audioRecognition = undefined; + + const result = await detachIfReusable(oldActivity.activity, newActivity.activity); + + expect(result).toBeUndefined(); + }); +}); diff --git a/agents/src/voice/agent_session.ts b/agents/src/voice/agent_session.ts index 572a3b9fa..3cc11c79a 100644 --- a/agents/src/voice/agent_session.ts +++ b/agents/src/voice/agent_session.ts @@ -40,7 +40,7 @@ import { Task } from '../utils.js'; import type { VAD } from '../vad.js'; import type { Agent } from './agent.js'; import { AgentActivity } from './agent_activity.js'; -import type { _TurnDetector } from './audio_recognition.js'; +import type { STTPipeline, _TurnDetector } from './audio_recognition.js'; import { type AgentEvent, AgentSessionEventTypes, @@ -223,6 +223,7 @@ export class AgentSession< private _input: AgentInput; private _output: AgentOutput; + private closing = false; private closingTask: Promise | null = null; private userAwayTimer: NodeJS.Timeout | null = null; @@ -515,6 +516,7 @@ export class AgentSession< return; } + this.closing = false; this._usageCollector = new ModelUsageCollector(); let ctx: JobContext | undefined = undefined; @@ -760,6 +762,7 @@ export class AgentSession< const runWithContext = async () => { const unlock = await this.activityLock.lock(); let onEnterTask: Task | undefined; + let reusedSttPipeline: STTPipeline | undefined; try { this.agent = agent; @@ -782,6 +785,10 @@ export class AgentSession< this.nextActivity = agent._agentActivity; } + if (prevActivityObj && this.nextActivity && prevActivityObj !== this.nextActivity) { + reusedSttPipeline = await prevActivityObj._detachSttPipelineIfReusable(this.nextActivity); + } + if (prevActivityObj && prevActivityObj !== this.nextActivity) { if (previousActivity === 'pause') { await prevActivityObj.pause({ blockedTasks }); @@ -791,6 +798,18 @@ export class AgentSession< } } + if (this.closing && newActivity === 'start') { + this.logger.warn( + { agentId: this.nextActivity?.agent.id }, + 'Session is closing, skipping start of next activity', + ); + await reusedSttPipeline?.close(); + reusedSttPipeline = undefined; + this.nextActivity = undefined; + this.activity = undefined; + return; + } + this.activity = this.nextActivity; this.nextActivity = undefined; @@ -815,16 +834,22 @@ export class AgentSession< ); if (newActivity === 'start') { - await this.activity!.start(); + await this.activity!.start({ reuseSttPipeline: reusedSttPipeline }); } else { - await this.activity!.resume(); + await this.activity!.resume({ reuseSttPipeline: reusedSttPipeline }); } + reusedSttPipeline = undefined; onEnterTask = this.activity!._onEnterTask; if (this._input.audio) { this.activity!.attachAudioInput(this._input.audio.stream); } + } catch (error) { + // JS safeguard: session cleanup owns the detached pipeline until the next activity + // starts successfully, preventing leaks when handoff fails mid-transition. + await reusedSttPipeline?.close(); + throw error; } finally { unlock(); } @@ -1130,6 +1155,7 @@ export class AgentSession< return; } + this.closing = true; this._cancelUserAwayTimer(); this._onAecWarmupExpired(); this.off(AgentSessionEventTypes.UserInputTranscribed, this._onUserInputTranscribed); diff --git a/agents/src/voice/agent_session_handoff.test.ts b/agents/src/voice/agent_session_handoff.test.ts new file mode 100644 index 000000000..d3e054f3c --- /dev/null +++ b/agents/src/voice/agent_session_handoff.test.ts @@ -0,0 +1,228 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +import { describe, expect, it, vi } from 'vitest'; +import { ChatContext } from '../llm/chat_context.js'; +import { Agent } from './agent.js'; +import { AgentActivity } from './agent_activity.js'; +import { AgentSession } from './agent_session.js'; + +function createFakeLock() { + return { + lock: vi.fn(async () => () => {}), + }; +} + +function createFakeSession() { + return { + activityLock: createFakeLock(), + rootSpanContext: undefined, + agent: undefined, + activity: undefined, + nextActivity: undefined, + _globalRunState: undefined, + _chatCtx: ChatContext.empty(), + logger: { + debug: vi.fn(), + warn: vi.fn(), + }, + sessionOptions: { + turnHandling: { + interruption: { + enabled: true, + minDuration: 0, + minWords: 0, + }, + endpointing: { + minDelay: 0, + maxDelay: 0, + }, + }, + }, + interruptionDetection: undefined, + turnDetection: undefined, + vad: undefined, + stt: undefined, + llm: undefined, + tts: undefined, + useTtsAlignedTranscript: false, + _input: {}, + } as unknown as AgentSession; +} + +describe('AgentSession STT pipeline handoff', () => { + it('passes a detached STT pipeline into the next resumed activity', async () => { + const pipeline = { + close: vi.fn(async () => {}), + }; + const previousAgent = new Agent({ instructions: 'old' }); + const nextAgent = new Agent({ instructions: 'new' }); + const previousActivity = { + agent: previousAgent, + _detachSttPipelineIfReusable: vi.fn(async () => pipeline), + drain: vi.fn(async () => {}), + close: vi.fn(async () => {}), + pause: vi.fn(async () => {}), + }; + const nextActivity = { + agent: nextAgent, + resume: vi.fn(async () => {}), + start: vi.fn(async () => {}), + attachAudioInput: vi.fn(), + _onEnterTask: undefined, + }; + nextAgent._agentActivity = nextActivity as any; + + const session = createFakeSession(); + (session as any).activity = previousActivity as any; + + await AgentSession.prototype._updateActivity.call(session, nextAgent, { + newActivity: 'resume', + waitOnEnter: false, + }); + + expect(previousActivity._detachSttPipelineIfReusable).toHaveBeenCalledWith(nextActivity); + expect(nextActivity.resume).toHaveBeenCalledWith({ reuseSttPipeline: pipeline }); + expect(pipeline.close).not.toHaveBeenCalled(); + }); + + it('closes the detached pipeline if the next activity fails to start', async () => { + const pipeline = { + close: vi.fn(async () => {}), + }; + const previousAgent = new Agent({ instructions: 'old' }); + const nextAgent = new Agent({ instructions: 'new' }); + const previousActivity = { + agent: previousAgent, + _detachSttPipelineIfReusable: vi.fn(async () => pipeline), + drain: vi.fn(async () => {}), + close: vi.fn(async () => {}), + pause: vi.fn(async () => {}), + }; + const nextActivity = { + agent: nextAgent, + resume: vi.fn(async () => { + throw new Error('resume failed'); + }), + start: vi.fn(async () => {}), + attachAudioInput: vi.fn(), + _onEnterTask: undefined, + }; + nextAgent._agentActivity = nextActivity as any; + + const session = createFakeSession(); + (session as any).activity = previousActivity as any; + + await expect( + AgentSession.prototype._updateActivity.call(session, nextAgent, { + newActivity: 'resume', + waitOnEnter: false, + }), + ).rejects.toThrow('resume failed'); + + expect(pipeline.close).toHaveBeenCalledTimes(1); + }); + + it('does not close the adopted pipeline after the next activity starts successfully', async () => { + const pipeline = { + close: vi.fn(async () => {}), + }; + const previousAgent = new Agent({ instructions: 'old' }); + const nextAgent = new Agent({ instructions: 'new' }); + const previousActivity = { + agent: previousAgent, + _detachSttPipelineIfReusable: vi.fn(async () => pipeline), + drain: vi.fn(async () => {}), + close: vi.fn(async () => {}), + pause: vi.fn(async () => {}), + }; + const nextActivity = { + agent: nextAgent, + resume: vi.fn(async () => {}), + start: vi.fn(async () => {}), + attachAudioInput: vi.fn(() => { + throw new Error('attach failed'); + }), + _onEnterTask: undefined, + }; + nextAgent._agentActivity = nextActivity as any; + + const session = createFakeSession(); + (session as any).activity = previousActivity as any; + (session as any)._input = { audio: { stream: {} } } as any; + + await expect( + AgentSession.prototype._updateActivity.call(session, nextAgent, { + newActivity: 'resume', + waitOnEnter: false, + }), + ).rejects.toThrow('attach failed'); + + expect(nextActivity.resume).toHaveBeenCalledWith({ reuseSttPipeline: pipeline }); + expect(pipeline.close).not.toHaveBeenCalled(); + }); + + it('skips STT detach when the same activity object is reused', async () => { + const agent = new Agent({ instructions: 'same' }); + const activity = { + agent, + _detachSttPipelineIfReusable: vi.fn(async () => undefined), + drain: vi.fn(async () => {}), + close: vi.fn(async () => {}), + pause: vi.fn(async () => {}), + resume: vi.fn(async () => {}), + start: vi.fn(async () => {}), + attachAudioInput: vi.fn(), + _onEnterTask: undefined, + }; + agent._agentActivity = activity as any; + + const session = createFakeSession(); + (session as any).activity = activity as any; + + await AgentSession.prototype._updateActivity.call(session, agent, { + newActivity: 'resume', + waitOnEnter: false, + }); + + expect(activity._detachSttPipelineIfReusable).not.toHaveBeenCalled(); + expect(activity.resume).toHaveBeenCalledWith({ reuseSttPipeline: undefined }); + }); + + it('skips starting a new activity while the session is closing and closes the detached pipeline', async () => { + const pipeline = { + close: vi.fn(async () => {}), + }; + const previousAgent = new Agent({ instructions: 'old' }); + const nextAgent = new Agent({ instructions: 'new' }); + const previousActivity = { + agent: previousAgent, + _detachSttPipelineIfReusable: vi.fn(async () => pipeline), + drain: vi.fn(async () => {}), + close: vi.fn(async () => {}), + pause: vi.fn(async () => {}), + }; + + const startSpy = vi.spyOn(AgentActivity.prototype, 'start').mockResolvedValue(undefined); + + try { + const session = createFakeSession(); + (session as any).activity = previousActivity as any; + (session as any).closing = true; + + await AgentSession.prototype._updateActivity.call(session, nextAgent, { + newActivity: 'start', + waitOnEnter: false, + }); + + expect(previousActivity._detachSttPipelineIfReusable).toHaveBeenCalledTimes(1); + expect(previousActivity.close).toHaveBeenCalledTimes(1); + expect(pipeline.close).toHaveBeenCalledTimes(1); + expect(startSpy).not.toHaveBeenCalled(); + expect((session as any).activity).toBeUndefined(); + expect((session as any).nextActivity).toBeUndefined(); + } finally { + startSpy.mockRestore(); + } + }); +}); diff --git a/agents/src/voice/audio_recognition.ts b/agents/src/voice/audio_recognition.ts index 81a778780..5390fa984 100644 --- a/agents/src/voice/audio_recognition.ts +++ b/agents/src/voice/audio_recognition.ts @@ -1,6 +1,7 @@ // SPDX-FileCopyrightText: 2025 LiveKit, Inc. // // SPDX-License-Identifier: Apache-2.0 +import { Mutex } from '@livekit/mutex'; import type { ParticipantKind } from '@livekit/rtc-node'; import { AudioFrame } from '@livekit/rtc-node'; import { @@ -24,13 +25,13 @@ import { import type { LanguageCode } from '../language.js'; import { type ChatContext } from '../llm/chat_context.js'; import { log } from '../log.js'; -import { DeferredReadableStream, isStreamReaderReleaseError } from '../stream/deferred_stream.js'; +import { DeferredReadableStream } from '../stream/deferred_stream.js'; import { IdentityTransform } from '../stream/identity_transform.js'; import { mergeReadableStreams } from '../stream/merge_readable_streams.js'; import { type StreamChannel, createStreamChannel } from '../stream/stream_channel.js'; import { type SpeechEvent, SpeechEventType } from '../stt/stt.js'; import { traceTypes, tracer } from '../telemetry/index.js'; -import { Task, delay, waitForAbort } from '../utils.js'; +import { Task, cancelAndWait, delay, readStream, waitForAbort } from '../utils.js'; import { type VAD, type VADEvent, VADEventType } from '../vad.js'; import type { TurnDetectionMode } from './agent_session.js'; import type { STTNode } from './io.js'; @@ -69,6 +70,49 @@ export interface RecognitionHooks { retrieveChatCtx: () => ChatContext; } +export class STTPipeline { + static readonly PUMP_TASK_CANCEL_TIMEOUT = 5000; + + private sttNode: STTNode; + private _audioChannel: StreamChannel = createStreamChannel(); + private _eventChannel: StreamChannel = createStreamChannel(); + private _pumpTask: Task; + + constructor(sttNode: STTNode) { + this.sttNode = sttNode; + this._pumpTask = Task.from(({ signal }) => this.sttPump(signal)); + this._pumpTask.addDoneCallback(() => this._eventChannel.close()); + } + + get audioChannel() { + return this._audioChannel; + } + + get eventChannel() { + return this._eventChannel; + } + + private async sttPump(signal: AbortSignal): Promise { + const node = await this.sttNode(this._audioChannel.stream(), {}); + if (node === null) return; + + try { + for await (const value of readStream(node, signal)) { + if (typeof value === 'string') { + throw new Error(`STT node must yield SpeechEvent, got: ${typeof value}`); + } + await this._eventChannel.write(value); + } + } finally { + await node.cancel().catch(() => {}); + } + } + + async close(): Promise { + await cancelAndWait([this._pumpTask], STTPipeline.PUMP_TASK_CANCEL_TIMEOUT); + } +} + export interface _TurnDetector { /** The model name used by this turn detector. */ readonly model: string; @@ -119,6 +163,7 @@ export interface ParticipantLike { export class AudioRecognition { private hooks: RecognitionHooks; private stt?: STTNode; + private sttPipeline?: STTPipeline; private vad?: VAD; private turnDetector?: _TurnDetector; private turnDetectionMode?: TurnDetectionMode; @@ -149,12 +194,15 @@ export class AudioRecognition { private sttInputStream: ReadableStream; private silenceAudioTransform = new IdentityTransform(); private silenceAudioWriter: WritableStreamDefaultWriter; + private sttOwnershipTransferred = false; + private readonly sttLifecycleLock = new Mutex(); // all cancellable tasks private bounceEOUTask?: Task; private commitUserTurnTask?: Task; + private sttForwardTask?: Task; private vadTask?: Task; - private sttTask?: Task; + private sttConsumerTask?: Task; private interruptionTask?: Task; // interruption detection @@ -228,17 +276,14 @@ export class AudioRecognition { this.turnDetectionMode = options.turnDetection; } - async start() { + async start(options?: { sttPipeline?: STTPipeline }) { + this.startSttTasks(options?.sttPipeline); + this.vadTask = Task.from(({ signal }) => this.createVadTask(this.vad, signal)); this.vadTask.result.catch((err) => { this.logger.error(`Error running VAD task: ${err}`); }); - this.sttTask = Task.from(({ signal }) => this.createSttTask(this.stt, signal)); - this.sttTask.result.catch((err) => { - this.logger.error(`Error running STT task: ${err}`); - }); - this.interruptionTask = Task.from(({ signal }) => this.createInterruptionTask(this.interruptionDetection, signal), ); @@ -248,7 +293,8 @@ export class AudioRecognition { } async stop() { - await this.sttTask?.cancelAndWait(); + await this.sttConsumerTask?.cancelAndWait(); + await this.sttForwardTask?.cancelAndWait(); await this.vadTask?.cancelAndWait(); await this.interruptionTask?.cancelAndWait(); } @@ -887,56 +933,61 @@ export class AudioRecognition { }); } - private async createSttTask(stt: STTNode | undefined, signal: AbortSignal) { - if (!stt) return; + private startSttTasks(reusePipeline?: STTPipeline) { + if (!this.stt) return; - this.logger.debug('createSttTask: create stt stream from stt node'); + this.sttPipeline = reusePipeline ?? new STTPipeline(this.stt); - const sttStream = await stt(this.sttInputStream, {}); + this.transcriptBuffer = []; + this.ignoreUserTranscriptUntil = undefined; + this._inputStartedAt = undefined; + this.sttOwnershipTransferred = false; - if (signal.aborted || sttStream === null) return; + const pipeline = this.sttPipeline; - if (sttStream instanceof ReadableStream) { - const reader = sttStream.getReader(); + this.sttForwardTask = Task.from(({ signal }) => this.forwardInputAudioToStt(pipeline, signal)); + this.sttForwardTask.result.catch((err) => { + this.logger.error(`Error forwarding audio to STT pipeline: ${err}`); + }); - signal.addEventListener('abort', async () => { - try { - reader.releaseLock(); - await sttStream?.cancel(); - } catch (e) { - this.logger.debug('createSttTask: error during abort handler:', e); - } - }); + this.sttConsumerTask = Task.from(({ signal }) => this.consumeSttEvents(pipeline, signal)); + this.sttConsumerTask.result.catch((err) => { + this.logger.error(`Error running STT task: ${err}`); + }); + } - try { - while (true) { - if (signal.aborted) break; + private async stopSttTasks() { + await this.sttConsumerTask?.cancelAndWait(); + this.sttConsumerTask = undefined; + await this.sttForwardTask?.cancelAndWait(); + this.sttForwardTask = undefined; + } - const { done, value: ev } = await reader.read(); - if (done) break; + async detachSttPipeline(): Promise { + const unlock = await this.sttLifecycleLock.lock(); + try { + const pipeline = this.sttPipeline; + this.sttPipeline = undefined; + this.sttOwnershipTransferred = pipeline !== undefined; - if (typeof ev === 'string') { - throw new Error('STT node must yield SpeechEvent'); - } else { - await this.onSTTEvent(ev); - } - } - } catch (e) { - if (isStreamReaderReleaseError(e)) { - return; - } - this.logger.error({ error: e }, 'createSttTask: error reading sttStream'); - } finally { - reader.releaseLock(); - try { - await sttStream.cancel(); - } catch (e) { - this.logger.debug( - 'createSttTask: error cancelling sttStream (may already be cancelled):', - e, - ); - } - } + await this.sttConsumerTask?.cancelAndWait(); + this.sttConsumerTask = undefined; + + return pipeline; + } finally { + unlock(); + } + } + + private async forwardInputAudioToStt(pipeline: STTPipeline, signal: AbortSignal) { + for await (const frame of readStream(this.sttInputStream, signal)) { + await pipeline.audioChannel.write(frame); + } + } + + private async consumeSttEvents(pipeline: STTPipeline, signal: AbortSignal) { + for await (const ev of readStream(pipeline.eventChannel.stream(), signal)) { + await this.onSTTEvent(ev); } } @@ -1161,11 +1212,29 @@ export class AudioRecognition { this.finalTranscriptConfidence = []; this.userTurnCommitted = false; - this.sttTask?.cancelAndWait().finally(() => { - this.sttTask = Task.from(({ signal }) => this.createSttTask(this.stt, signal)); - this.sttTask.result.catch((err) => { - this.logger.error(`Error running STT task: ${err}`); - }); + const restartStt = async () => { + const unlock = await this.sttLifecycleLock.lock(); + try { + if (!this.stt || this.sttOwnershipTransferred) { + return; + } + + await this.stopSttTasks(); + await this.sttPipeline?.close(); + this.sttPipeline = undefined; + + if (this.sttOwnershipTransferred) { + return; + } + + this.startSttTasks(); + } finally { + unlock(); + } + }; + + void restartStt().catch((err) => { + this.logger.error(`Error resetting STT task: ${err}`); }); } @@ -1237,7 +1306,13 @@ export class AudioRecognition { this.detachInputAudioStream(); this.silenceAudioWriter.releaseLock(); await this.commitUserTurnTask?.cancelAndWait(); - await this.sttTask?.cancelAndWait(); + await this.stopSttTasks(); + + if (this.sttPipeline) { + await this.sttPipeline.close(); + this.sttPipeline = undefined; + } + await this.vadTask?.cancelAndWait(); await this.bounceEOUTask?.cancelAndWait(); await this.interruptionTask?.cancelAndWait(); diff --git a/agents/src/voice/audio_recognition_handoff.test.ts b/agents/src/voice/audio_recognition_handoff.test.ts new file mode 100644 index 000000000..80a715463 --- /dev/null +++ b/agents/src/voice/audio_recognition_handoff.test.ts @@ -0,0 +1,230 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +import { ReadableStreamDefaultController } from 'node:stream/web'; +import { describe, expect, it, vi } from 'vitest'; +import { ChatContext } from '../llm/chat_context.js'; +import { initializeLogger } from '../log.js'; +import { type SpeechEvent, SpeechEventType } from '../stt/stt.js'; +import { AudioRecognition, type RecognitionHooks, STTPipeline } from './audio_recognition.js'; +import type { STTNode } from './io.js'; + +function createHooks() { + const hooks: RecognitionHooks = { + onInterruption: vi.fn(), + onStartOfSpeech: vi.fn(), + onVADInferenceDone: vi.fn(), + onEndOfSpeech: vi.fn(), + onInterimTranscript: vi.fn(), + onFinalTranscript: vi.fn(), + onEndOfTurn: vi.fn(async () => true), + onPreemptiveGeneration: vi.fn(), + retrieveChatCtx: () => ChatContext.empty(), + }; + + return hooks; +} + +async function flushTasks() { + await new Promise((resolve) => setTimeout(resolve, 0)); +} + +async function waitFor(check: () => boolean, timeoutMs = 200) { + const startedAt = Date.now(); + while (!check()) { + if (Date.now() - startedAt > timeoutMs) { + throw new Error('timed out waiting for condition'); + } + await flushTasks(); + } +} + +function createRecognition(sttNode: STTNode, hooks = createHooks()) { + return { + hooks, + recognition: new AudioRecognition({ + recognitionHooks: hooks, + stt: sttNode, + minEndpointingDelay: 0, + maxEndpointingDelay: 0, + }), + }; +} + +describe('AudioRecognition STT pipeline handoff', () => { + initializeLogger({ pretty: false, level: 'silent' }); + + it('reuses an injected STT pipeline instead of opening a second STT stream', async () => { + let sttNodeCalls = 0; + + const sttNode: STTNode = async () => { + sttNodeCalls += 1; + return new ReadableStream({ + start() {}, + }); + }; + + const pipeline = new STTPipeline(sttNode); + const { recognition } = createRecognition(sttNode); + + try { + await recognition.start({ sttPipeline: pipeline }); + await waitFor(() => sttNodeCalls === 1); + + expect(sttNodeCalls).toBe(1); + } finally { + await recognition.close(); + await pipeline.close(); + } + }); + + it('detaches the pipeline so a new consumer can receive subsequent STT events', async () => { + let controller: ReadableStreamDefaultController | undefined; + + const sttNode: STTNode = async () => + new ReadableStream({ + start(ctrl) { + controller = ctrl; + }, + }); + + const pipeline = new STTPipeline(sttNode); + const first = createRecognition(sttNode); + const second = createRecognition(sttNode); + + try { + await first.recognition.start({ sttPipeline: pipeline }); + await waitFor(() => controller !== undefined); + + const detachedPipeline = await (first.recognition as any).detachSttPipeline(); + await first.recognition.close(); + + await second.recognition.start({ sttPipeline: detachedPipeline }); + await flushTasks(); + + controller?.enqueue({ + type: SpeechEventType.FINAL_TRANSCRIPT, + alternatives: [{ text: 'reused pipeline', confidence: 0.9 }], + }); + await waitFor(() => second.hooks.onFinalTranscript.mock.calls.length === 1); + + expect(first.hooks.onFinalTranscript).not.toHaveBeenCalled(); + expect(second.hooks.onFinalTranscript).toHaveBeenCalledTimes(1); + } finally { + controller?.close(); + await first.recognition.close(); + await second.recognition.close(); + await pipeline.close(); + } + }); + + it('resets handoff-sensitive STT state when attaching a pipeline', async () => { + const sttNode: STTNode = async () => + new ReadableStream({ + start() {}, + }); + + const pipeline = new STTPipeline(sttNode); + const { recognition } = createRecognition(sttNode); + + (recognition as any).transcriptBuffer = [ + { type: SpeechEventType.FINAL_TRANSCRIPT, alternatives: [{ text: 'stale transcript' }] }, + ]; + (recognition as any).ignoreUserTranscriptUntil = Date.now(); + (recognition as any)._inputStartedAt = Date.now(); + + try { + await recognition.start({ sttPipeline: pipeline }); + + expect((recognition as any).transcriptBuffer).toEqual([]); + expect((recognition as any).ignoreUserTranscriptUntil).toBeUndefined(); + expect((recognition as any)._inputStartedAt).toBeUndefined(); + } finally { + await recognition.close(); + await pipeline.close(); + } + }); + + it('recreates the owned STT pipeline when clearing the user turn', async () => { + let sttNodeCalls = 0; + + const sttNode: STTNode = async () => { + sttNodeCalls += 1; + return new ReadableStream({ + start() {}, + }); + }; + + const { recognition } = createRecognition(sttNode); + + try { + await recognition.start(); + await waitFor(() => sttNodeCalls === 1); + + recognition.clearUserTurn(); + + await waitFor(() => sttNodeCalls === 2); + } finally { + await recognition.close(); + } + }); + + it('keeps an STT pipeline alive across overlapping clearUserTurn calls', async () => { + let sttNodeCalls = 0; + + const sttNode: STTNode = async () => { + sttNodeCalls += 1; + return new ReadableStream({ + start() {}, + }); + }; + + const { recognition } = createRecognition(sttNode); + + try { + await recognition.start(); + await waitFor(() => sttNodeCalls === 1); + + recognition.clearUserTurn(); + recognition.clearUserTurn(); + + await flushTasks(); + await flushTasks(); + await flushTasks(); + + expect((recognition as any).sttPipeline).toBeDefined(); + } finally { + await recognition.close(); + } + }); + + it('does not recreate a new pipeline after ownership was detached for handoff', async () => { + let sttNodeCalls = 0; + + const sttNode: STTNode = async () => { + sttNodeCalls += 1; + return new ReadableStream({ + start() {}, + }); + }; + + const { recognition } = createRecognition(sttNode); + + try { + await recognition.start(); + await waitFor(() => sttNodeCalls === 1); + + await (recognition as any).detachSttPipeline(); + recognition.clearUserTurn(); + + await flushTasks(); + await flushTasks(); + await flushTasks(); + + expect(sttNodeCalls).toBe(1); + expect((recognition as any).sttPipeline).toBeUndefined(); + } finally { + await recognition.close(); + } + }); +}); diff --git a/examples/src/restaurant_agent.ts b/examples/src/restaurant_agent.ts index 34d357378..dd3d233bb 100644 --- a/examples/src/restaurant_agent.ts +++ b/examples/src/restaurant_agent.ts @@ -7,38 +7,20 @@ import { ServerOptions, cli, defineAgent, + inference, llm, voice, } from '@livekit/agents'; -import * as deepgram from '@livekit/agents-plugin-deepgram'; -import * as elevenlabs from '@livekit/agents-plugin-elevenlabs'; -import * as livekit from '@livekit/agents-plugin-livekit'; -import * as openai from '@livekit/agents-plugin-openai'; import * as silero from '@livekit/agents-plugin-silero'; import { fileURLToPath } from 'node:url'; import { z } from 'zod'; +// Ref: python examples/voice_agents/restaurant_agent.py - 29-34 lines const voices = { - greeter: { - id: '9BWtsMINqrJLrRacOk9x', // Aria - calm, professional female voice - name: 'Aria', - category: 'premade', - }, - reservation: { - id: 'EXAVITQu4vr4xnSDxMaL', // Sarah - warm, reassuring professional tone - name: 'Sarah', - category: 'premade', - }, - takeaway: { - id: 'CwhRBWXzGAHq8TQ4Fs17', // Roger - confident middle-aged male - name: 'Roger', - category: 'premade', - }, - checkout: { - id: '5Q0t7uMcjvnagumLfvZi', // Paul - authoritative middle-aged male - name: 'Paul', - category: 'premade', - }, + greeter: 'e07c00bc-4134-4eae-9ea4-1a55fb45746b', + reservation: '9626c31c-bec5-4cca-baa8-f8ba9e84c8bc', + takeaway: '5ee9feff-1265-424a-9d7f-8e4d431a12c7', + checkout: 'a167e0f3-df7e-4d52-a9c3-f949145efdab', }; type UserData = { @@ -189,8 +171,8 @@ function createGreeterAgent(menu: string) { const greeter = new BaseAgent({ name: 'greeter', instructions: `You are a friendly restaurant receptionist. The menu is: ${menu}\nYour jobs are to greet the caller and understand if they want to make a reservation or order takeaway. Guide them to the right agent using tools.`, - // TODO(brian): support parallel tool calls - tts: new elevenlabs.TTS({ voice: voices.greeter }), + llm: new inference.LLM({ model: 'openai/gpt-4.1-mini' }), + tts: new inference.TTS({ model: 'cartesia/sonic-3', voice: voices.greeter }), tools: { toReservation: llm.tool({ description: `Called when user wants to make or update a reservation. @@ -225,7 +207,7 @@ function createReservationAgent() { const reservation = new BaseAgent({ name: 'reservation', instructions: `You are a reservation agent at a restaurant. Your jobs are to ask for the reservation time, then customer's name, and phone number. Then confirm the reservation details with the customer.`, - tts: new elevenlabs.TTS({ voice: voices.reservation }), + tts: new inference.TTS({ model: 'cartesia/sonic-3', voice: voices.reservation }), tools: { updateName, updatePhone, @@ -267,7 +249,7 @@ function createTakeawayAgent(menu: string) { const takeaway = new BaseAgent({ name: 'takeaway', instructions: `Your are a takeaway agent that takes orders from the customer. Our menu is: ${menu}\nClarify special requests and confirm the order with the customer.`, - tts: new elevenlabs.TTS({ voice: voices.takeaway }), + tts: new inference.TTS({ model: 'cartesia/sonic-3', voice: voices.takeaway }), tools: { toGreeter, updateOrder: llm.tool({ @@ -303,7 +285,7 @@ function createCheckoutAgent(menu: string) { const checkout = new BaseAgent({ name: 'checkout', instructions: `You are a checkout agent at a restaurant. The menu is: ${menu}\nYour are responsible for confirming the expense of the order and then collecting customer's name, phone number and credit card information, including the card number, expiry date, and CVV step by step.`, - tts: new elevenlabs.TTS({ voice: voices.checkout }), + tts: new inference.TTS({ model: 'cartesia/sonic-3', voice: voices.checkout }), tools: { updateName, updatePhone, @@ -383,12 +365,11 @@ export default defineAgent({ const vad = ctx.proc.userData.vad! as silero.VAD; const session = new voice.AgentSession({ vad, - stt: new deepgram.STT(), - tts: new elevenlabs.TTS(), - llm: new openai.LLM(), + stt: new inference.STT({ model: 'deepgram/nova-3' }), + llm: new inference.LLM({ model: 'openai/gpt-4.1-mini' }), + tts: new inference.TTS({ model: 'cartesia/sonic-3' }), // to use realtime model, replace the stt, llm, tts and vad with the following - // llm: new openai.realtime.RealtimeModel(), - turnDetection: new livekit.turnDetector.EnglishModel(), + // llm: new openai.realtime.RealtimeModel({ voice: 'alloy' }), userData, voiceOptions: { maxToolSteps: 5,