diff --git a/agents/src/audio.ts b/agents/src/audio.ts index 4d39abb80..a7ee0e06a 100644 --- a/agents/src/audio.ts +++ b/agents/src/audio.ts @@ -4,9 +4,8 @@ import ffmpegInstaller from '@ffmpeg-installer/ffmpeg'; import { AudioFrame } from '@livekit/rtc-node'; import ffmpeg from 'fluent-ffmpeg'; -import type { ReadableStream } from 'node:stream/web'; import { log } from './log.js'; -import { createStreamChannel } from './stream/stream_channel.js'; +import { Chan, ChanClosed } from './stream/chan.js'; import { type AudioBuffer, isFfmpegTeardownError } from './utils.js'; ffmpeg.setFfmpegPath(ffmpegInstaller.path); @@ -110,12 +109,12 @@ export class AudioByteStream { export function audioFramesFromFile( filePath: string, options: AudioDecodeOptions = {}, -): ReadableStream { +): AsyncIterable { const sampleRate = options.sampleRate ?? 48000; const numChannels = options.numChannels ?? 1; const audioStream = new AudioByteStream(sampleRate, numChannels); - const channel = createStreamChannel(); + const chan = new Chan(); const logger = log(); // TODO (Brian): decode WAV using a custom decoder instead of FFmpeg @@ -139,7 +138,7 @@ export function audioFramesFromFile( const onClose = () => { logger.debug('Audio file playback aborted'); - channel.close(); + chan.close(); if (commandRunning) { commandRunning = false; command.kill('SIGKILL'); @@ -168,17 +167,27 @@ export function audioFramesFromFile( const frames = audioStream.write(arrayBuffer); for (const frame of frames) { - channel.write(frame); + try { + chan.sendNowait(frame); + } catch (e) { + if (e instanceof ChanClosed) return; + throw e; + } } }); outputStream.on('end', () => { const frames = audioStream.flush(); for (const frame of frames) { - channel.write(frame); + try { + chan.sendNowait(frame); + } catch (e) { + if (e instanceof ChanClosed) return; + throw e; + } } commandRunning = false; - channel.close(); + chan.close(); }); outputStream.on('error', (err: Error) => { @@ -187,7 +196,7 @@ export function audioFramesFromFile( onClose(); }); - return channel.stream(); + return chan; } /** diff --git a/agents/src/inference/interruption/http_transport.ts b/agents/src/inference/interruption/http_transport.ts index b28a9a6dc..e248d3db9 100644 --- a/agents/src/inference/interruption/http_transport.ts +++ b/agents/src/inference/interruption/http_transport.ts @@ -3,7 +3,6 @@ // SPDX-License-Identifier: Apache-2.0 import type { Throws } from '@livekit/throws-transformer/throws'; import { FetchError, ofetch } from 'ofetch'; -import { TransformStream } from 'stream/web'; import { z } from 'zod'; import { APIConnectionError, APIError, APIStatusError, isAPIError } from '../../_exceptions.js'; import { log } from '../../log.js'; @@ -113,8 +112,12 @@ export interface HttpTransportState { cache: BoundedCache; } +export type TransportFn = ( + source: AsyncIterable, +) => AsyncIterable; + /** - * Creates an HTTP transport TransformStream for interruption detection. + * Creates an HTTP transport async generator for interruption detection. * * This transport receives Int16Array audio slices and outputs InterruptionEvents. * Each audio slice triggers an HTTP POST request. @@ -128,80 +131,72 @@ export function createHttpTransport( setState: (partial: Partial) => void, updateUserSpeakingSpan?: (entry: InterruptionCacheEntry) => void, getAndResetNumRequests?: () => number, -): TransformStream { +): TransportFn { const logger = log(); - return new TransformStream( - { - async transform(chunk, controller) { - if (!(chunk instanceof Int16Array)) { - controller.enqueue(chunk); - return; - } + return async function* (source) { + for await (const chunk of source) { + if (!(chunk instanceof Int16Array)) { + yield chunk; + continue; + } - const state = getState(); - const overlapSpeechStartedAt = state.overlapSpeechStartedAt; - if (overlapSpeechStartedAt === undefined || !state.overlapSpeechStarted) return; - - try { - const resp = await predictHTTP( - chunk, - { threshold: options.threshold, minFrames: options.minFrames }, - { - baseUrl: options.baseUrl, - timeout: options.timeout, - maxRetries: options.maxRetries, - token: await createAccessToken(options.apiKey, options.apiSecret), - }, - ); - - const { createdAt, isBargein, probabilities, predictionDurationInS } = resp; - const entry = state.cache.setOrUpdate( - createdAt, - () => new InterruptionCacheEntry({ createdAt }), - { - probabilities, - isInterruption: isBargein, - speechInput: chunk, - totalDurationInS: (performance.now() - createdAt) / 1000, - detectionDelayInS: (Date.now() - overlapSpeechStartedAt) / 1000, - predictionDurationInS, - }, - ); - - if (state.overlapSpeechStarted && entry.isInterruption) { - if (updateUserSpeakingSpan) { - updateUserSpeakingSpan(entry); - } - const event: OverlappingSpeechEvent = { - type: 'overlapping_speech', - detectedAt: Date.now(), - overlapStartedAt: overlapSpeechStartedAt, - isInterruption: entry.isInterruption, - speechInput: entry.speechInput, - probabilities: entry.probabilities, - totalDurationInS: entry.totalDurationInS, - predictionDurationInS: entry.predictionDurationInS, - detectionDelayInS: entry.detectionDelayInS, - probability: entry.probability, - numRequests: getAndResetNumRequests?.() ?? 0, - }; - logger.debug( - { - detectionDelayInS: entry.detectionDelayInS, - totalDurationInS: entry.totalDurationInS, - }, - 'interruption detected', - ); - setState({ overlapSpeechStarted: false }); - controller.enqueue(event); - } - } catch (err) { - controller.error(err); + const state = getState(); + const overlapSpeechStartedAt = state.overlapSpeechStartedAt; + if (overlapSpeechStartedAt === undefined || !state.overlapSpeechStarted) continue; + + const resp = await predictHTTP( + chunk, + { threshold: options.threshold, minFrames: options.minFrames }, + { + baseUrl: options.baseUrl, + timeout: options.timeout, + maxRetries: options.maxRetries, + token: await createAccessToken(options.apiKey, options.apiSecret), + }, + ); + + const { createdAt, isBargein, probabilities, predictionDurationInS } = resp; + const entry = state.cache.setOrUpdate( + createdAt, + () => new InterruptionCacheEntry({ createdAt }), + { + probabilities, + isInterruption: isBargein, + speechInput: chunk, + totalDurationInS: (performance.now() - createdAt) / 1000, + detectionDelayInS: (Date.now() - overlapSpeechStartedAt) / 1000, + predictionDurationInS, + }, + ); + + if (state.overlapSpeechStarted && entry.isInterruption) { + if (updateUserSpeakingSpan) { + updateUserSpeakingSpan(entry); } - }, - }, - { highWaterMark: 2 }, - { highWaterMark: 2 }, - ); + const event: OverlappingSpeechEvent = { + type: 'overlapping_speech', + detectedAt: Date.now(), + overlapStartedAt: overlapSpeechStartedAt, + isInterruption: entry.isInterruption, + speechInput: entry.speechInput, + probabilities: entry.probabilities, + totalDurationInS: entry.totalDurationInS, + predictionDurationInS: entry.predictionDurationInS, + detectionDelayInS: entry.detectionDelayInS, + probability: entry.probability, + numRequests: getAndResetNumRequests?.() ?? 0, + }; + logger.debug( + { + detectionDelayInS: entry.detectionDelayInS, + totalDurationInS: entry.totalDurationInS, + }, + 'interruption detected', + ); + setState({ overlapSpeechStarted: false }); + yield event; + } + } + }; } diff --git a/agents/src/inference/interruption/interruption_stream.ts b/agents/src/inference/interruption/interruption_stream.ts index df6162aae..646711d0b 100644 --- a/agents/src/inference/interruption/interruption_stream.ts +++ b/agents/src/inference/interruption/interruption_stream.ts @@ -3,13 +3,11 @@ // SPDX-License-Identifier: Apache-2.0 import { AudioFrame, AudioResampler } from '@livekit/rtc-node'; import type { Span } from '@opentelemetry/api'; -import { type ReadableStream, TransformStream } from 'stream/web'; import { log } from '../../log.js'; import type { InterruptionMetrics } from '../../metrics/base.js'; -import { type StreamChannel, createStreamChannel } from '../../stream/stream_channel.js'; +import { Chan } from '../../stream/chan.js'; import { traceTypes } from '../../telemetry/index.js'; import { FRAMES_PER_SECOND, apiConnectDefaults } from './defaults.js'; -import type { InterruptionDetectionError } from './errors.js'; import { createHttpTransport } from './http_transport.js'; import { InterruptionCacheEntry } from './interruption_cache_entry.js'; import type { AdaptiveInterruptionDetector } from './interruption_detector.js'; @@ -76,9 +74,9 @@ function updateUserSpeakingSpan(span: Span, entry: InterruptionCacheEntry) { } export class InterruptionStreamBase { - private inputStream: StreamChannel; + private inputChan: Chan; - private eventStream: ReadableStream; + private eventStream: AsyncIterable; private resampler?: AudioResampler; @@ -112,10 +110,7 @@ export class InterruptionStreamBase { }; constructor(model: AdaptiveInterruptionDetector, apiOptions: Partial) { - this.inputStream = createStreamChannel< - InterruptionSentinel | AudioFrame, - InterruptionDetectionError - >(); + this.inputChan = new Chan(); this.model = model; this.options = { ...model.options }; @@ -133,7 +128,7 @@ export class InterruptionStreamBase { maxRetries: this.apiOptions.maxRetries, }; - this.eventStream = this.setupTransform(); + this.eventStream = this.setupPipeline(); } /** @@ -158,7 +153,9 @@ export class InterruptionStreamBase { } } - private setupTransform(): ReadableStream { + private setupPipeline(): AsyncIterable { + // eslint-disable-next-line @typescript-eslint/no-this-alias + const self = this; let agentSpeechStarted = false; let startIdx = 0; let accumulatedSamples = 0; @@ -198,119 +195,115 @@ export class InterruptionStreamBase { return n; }; - // First transform: process input frames/sentinels and output audio slices or events - const audioTransformer = new TransformStream< - InterruptionSentinel | AudioFrame, - Int16Array | OverlappingSpeechEvent - >( - { - transform: (chunk, controller) => { - if (chunk instanceof AudioFrame) { - if (!agentSpeechStarted) { - return; - } - if (this.options.sampleRate !== chunk.sampleRate) { - controller.error('the sample rate of the input frames must be consistent'); - this.logger.error('the sample rate of the input frames must be consistent'); - return; - } - const result = writeToInferenceS16Data( - chunk, - startIdx, - inferenceS16Data, - this.options.maxAudioDurationInS, + // Audio transform: process input frames/sentinels and output audio slices or events + const audioTransform = async function* ( + source: AsyncIterable, + ): AsyncIterable { + for await (const chunk of source) { + if (chunk instanceof AudioFrame) { + if (!agentSpeechStarted) { + continue; + } + if (self.options.sampleRate !== chunk.sampleRate) { + self.logger.error('the sample rate of the input frames must be consistent'); + throw new Error('the sample rate of the input frames must be consistent'); + } + const result = writeToInferenceS16Data( + chunk, + startIdx, + inferenceS16Data, + self.options.maxAudioDurationInS, + ); + startIdx = result.startIdx; + accumulatedSamples += result.samplesWritten; + + if ( + accumulatedSamples >= + Math.floor(self.options.detectionIntervalInS * self.options.sampleRate) && + overlapSpeechStarted + ) { + const audioSlice = inferenceS16Data.slice(0, startIdx); + accumulatedSamples = 0; + yield audioSlice; + } + } else if (chunk.type === 'agent-speech-started') { + self.logger.debug('agent speech started'); + agentSpeechStarted = true; + overlapSpeechStarted = false; + self.overlapSpeechStartedAt = undefined; + accumulatedSamples = 0; + overlapCount = 0; + startIdx = 0; + self.numRequests = 0; + cache.clear(); + } else if (chunk.type === 'agent-speech-ended') { + self.logger.debug('agent speech ended'); + agentSpeechStarted = false; + overlapSpeechStarted = false; + self.overlapSpeechStartedAt = undefined; + accumulatedSamples = 0; + overlapCount = 0; + startIdx = 0; + self.numRequests = 0; + cache.clear(); + } else if (chunk.type === 'overlap-speech-started' && agentSpeechStarted) { + self.overlapSpeechStartedAt = chunk.startedAt; + self.userSpeakingSpan = chunk.userSpeakingSpan; + self.logger.debug('overlap speech started, starting interruption inference'); + overlapSpeechStarted = true; + accumulatedSamples = 0; + overlapCount += 1; + if (overlapCount <= 1) { + const keepSize = + Math.round((chunk.speechDuration / 1000) * self.options.sampleRate) + + Math.round(self.options.audioPrefixDurationInS * self.options.sampleRate); + const shiftCount = Math.max(0, startIdx - keepSize); + inferenceS16Data.copyWithin(0, shiftCount, startIdx); + startIdx -= shiftCount; + } + cache.clear(); + } else if (chunk.type === 'overlap-speech-ended') { + self.logger.debug('overlap speech ended'); + if (overlapSpeechStarted) { + self.userSpeakingSpan = undefined; + let latestEntry = cache.pop( + (entry) => entry.totalDurationInS !== undefined && entry.totalDurationInS > 0, ); - startIdx = result.startIdx; - accumulatedSamples += result.samplesWritten; - - if ( - accumulatedSamples >= - Math.floor(this.options.detectionIntervalInS * this.options.sampleRate) && - overlapSpeechStarted - ) { - const audioSlice = inferenceS16Data.slice(0, startIdx); - accumulatedSamples = 0; - controller.enqueue(audioSlice); + if (!latestEntry) { + self.logger.debug('no request made for overlap speech'); + latestEntry = InterruptionCacheEntry.default(); } - } else if (chunk.type === 'agent-speech-started') { - this.logger.debug('agent speech started'); - agentSpeechStarted = true; - overlapSpeechStarted = false; - this.overlapSpeechStartedAt = undefined; - accumulatedSamples = 0; - overlapCount = 0; - startIdx = 0; - this.numRequests = 0; - cache.clear(); - } else if (chunk.type === 'agent-speech-ended') { - this.logger.debug('agent speech ended'); - agentSpeechStarted = false; + const e = latestEntry ?? InterruptionCacheEntry.default(); + const event: OverlappingSpeechEvent = { + type: 'overlapping_speech', + detectedAt: chunk.endedAt, + isInterruption: false, + overlapStartedAt: self.overlapSpeechStartedAt, + speechInput: e.speechInput, + probabilities: e.probabilities, + totalDurationInS: e.totalDurationInS, + detectionDelayInS: e.detectionDelayInS, + predictionDurationInS: e.predictionDurationInS, + probability: e.probability, + numRequests: getAndResetNumRequests(), + }; + yield event; overlapSpeechStarted = false; - this.overlapSpeechStartedAt = undefined; accumulatedSamples = 0; - overlapCount = 0; - startIdx = 0; - this.numRequests = 0; - cache.clear(); - } else if (chunk.type === 'overlap-speech-started' && agentSpeechStarted) { - this.overlapSpeechStartedAt = chunk.startedAt; - this.userSpeakingSpan = chunk.userSpeakingSpan; - this.logger.debug('overlap speech started, starting interruption inference'); - overlapSpeechStarted = true; - accumulatedSamples = 0; - overlapCount += 1; - if (overlapCount <= 1) { - const keepSize = - Math.round((chunk.speechDuration / 1000) * this.options.sampleRate) + - Math.round(this.options.audioPrefixDurationInS * this.options.sampleRate); - const shiftCount = Math.max(0, startIdx - keepSize); - inferenceS16Data.copyWithin(0, shiftCount, startIdx); - startIdx -= shiftCount; - } - cache.clear(); - } else if (chunk.type === 'overlap-speech-ended') { - this.logger.debug('overlap speech ended'); - if (overlapSpeechStarted) { - this.userSpeakingSpan = undefined; - let latestEntry = cache.pop( - (entry) => entry.totalDurationInS !== undefined && entry.totalDurationInS > 0, - ); - if (!latestEntry) { - this.logger.debug('no request made for overlap speech'); - latestEntry = InterruptionCacheEntry.default(); - } - const e = latestEntry ?? InterruptionCacheEntry.default(); - const event: OverlappingSpeechEvent = { - type: 'overlapping_speech', - detectedAt: chunk.endedAt, - isInterruption: false, - overlapStartedAt: this.overlapSpeechStartedAt, - speechInput: e.speechInput, - probabilities: e.probabilities, - totalDurationInS: e.totalDurationInS, - detectionDelayInS: e.detectionDelayInS, - predictionDurationInS: e.predictionDurationInS, - probability: e.probability, - numRequests: getAndResetNumRequests(), - }; - controller.enqueue(event); - overlapSpeechStarted = false; - accumulatedSamples = 0; - } - this.overlapSpeechStartedAt = undefined; - } else if (chunk.type === 'flush') { - // no-op } - }, - }, - { highWaterMark: 32 }, - { highWaterMark: 32 }, - ); + self.overlapSpeechStartedAt = undefined; + } else if (chunk.type === 'flush') { + // no-op + } + } + }; - // Second transform: transport layer (HTTP or WebSocket based on useProxy) + // Transport layer (HTTP or WebSocket based on useProxy) const transportOptions = this.transportOptions; - let transport: TransformStream; + let transportFn: ( + source: AsyncIterable, + ) => AsyncIterable; if (this.options.useProxy) { const wsResult = createWsTransport( transportOptions, @@ -320,10 +313,10 @@ export class InterruptionStreamBase { onRequestSent, getAndResetNumRequests, ); - transport = wsResult.transport; + transportFn = wsResult.transport; this.wsReconnect = wsResult.reconnect; } else { - transport = createHttpTransport( + transportFn = createHttpTransport( transportOptions, getState, setState, @@ -332,40 +325,39 @@ export class InterruptionStreamBase { ); } - const eventEmitter = new TransformStream({ - transform: (chunk, controller) => { - this.model.emit('overlapping_speech', chunk); + // Event emitter: emit model events and metrics for each overlapping speech event + const eventEmit = async function* ( + source: AsyncIterable, + ): AsyncIterable { + for await (const event of source) { + self.model.emit('overlapping_speech', event); const metrics: InterruptionMetrics = { type: 'interruption_metrics', - timestamp: chunk.detectedAt, - totalDuration: chunk.totalDurationInS * 1000, - predictionDuration: chunk.predictionDurationInS * 1000, - detectionDelay: chunk.detectionDelayInS * 1000, - numInterruptions: chunk.isInterruption ? 1 : 0, - numBackchannels: chunk.isInterruption ? 0 : 1, - numRequests: chunk.numRequests, + timestamp: event.detectedAt, + totalDuration: event.totalDurationInS * 1000, + predictionDuration: event.predictionDurationInS * 1000, + detectionDelay: event.detectionDelayInS * 1000, + numInterruptions: event.isInterruption ? 1 : 0, + numBackchannels: event.isInterruption ? 0 : 1, + numRequests: event.numRequests, metadata: { - modelProvider: this.model.provider, - modelName: this.model.model, + modelProvider: self.model.provider, + modelName: self.model.model, }, }; - this.model.emit('metrics_collected', metrics); + self.model.emit('metrics_collected', metrics); - controller.enqueue(chunk); - }, - }); + yield event; + } + }; - // Pipeline: input -> audioTransformer -> transport -> eventEmitter -> eventStream - return this.inputStream - .stream() - .pipeThrough(audioTransformer) - .pipeThrough(transport) - .pipeThrough(eventEmitter); + // Pipeline: inputChan -> audioTransform -> transport -> eventEmit + return eventEmit(transportFn(audioTransform(this.inputChan))); } private ensureInputNotEnded() { - if (this.inputStream.closed) { + if (this.inputChan.closed) { throw new Error('input stream is closed'); } } @@ -381,39 +373,39 @@ export class InterruptionStreamBase { return this.resampler; } - stream(): ReadableStream { + stream(): AsyncIterable { return this.eventStream; } async pushFrame(frame: InterruptionSentinel | AudioFrame): Promise { this.ensureStreamsNotEnded(); if (!(frame instanceof AudioFrame)) { - return this.inputStream.write(frame); + await this.inputChan.send(frame); } else if (this.options.sampleRate !== frame.sampleRate) { const resampler = this.getResamplerFor(frame.sampleRate); if (resampler.inputRate !== frame.sampleRate) { throw new Error('the sample rate of the input frames must be consistent'); } for (const resampledFrame of resampler.push(frame)) { - await this.inputStream.write(resampledFrame); + await this.inputChan.send(resampledFrame); } } else { - await this.inputStream.write(frame); + await this.inputChan.send(frame); } } async flush(): Promise { this.ensureStreamsNotEnded(); - await this.inputStream.write(InterruptionStreamSentinel.flush()); + await this.inputChan.send(InterruptionStreamSentinel.flush()); } async endInput(): Promise { await this.flush(); - await this.inputStream.close(); + this.inputChan.close(); } async close(): Promise { - if (!this.inputStream.closed) await this.inputStream.close(); + if (!this.inputChan.closed) this.inputChan.close(); this.model.removeStream(this); } } diff --git a/agents/src/inference/interruption/ws_transport.ts b/agents/src/inference/interruption/ws_transport.ts index 8a7316e73..2a3b34d4e 100644 --- a/agents/src/inference/interruption/ws_transport.ts +++ b/agents/src/inference/interruption/ws_transport.ts @@ -2,13 +2,14 @@ // // SPDX-License-Identifier: Apache-2.0 import type { Throws } from '@livekit/throws-transformer/throws'; -import { TransformStream } from 'stream/web'; import WebSocket from 'ws'; import { z } from 'zod'; import { APIConnectionError, APIStatusError, APITimeoutError } from '../../_exceptions.js'; import { log } from '../../log.js'; +import { Chan } from '../../stream/chan.js'; import TypedPromise from '../../typed_promise.js'; import { createAccessToken } from '../utils.js'; +import type { TransportFn } from './http_transport.js'; import { InterruptionCacheEntry } from './interruption_cache_entry.js'; import type { OverlappingSpeechEvent } from './types.js'; import type { BoundedCache } from './utils.js'; @@ -121,7 +122,7 @@ async function connectWebSocket( } export interface WsTransportResult { - transport: TransformStream; + transport: TransportFn; reconnect: () => Promise; } @@ -142,7 +143,8 @@ export function createWsTransport( ): WsTransportResult { const logger = log(); let ws: WebSocket | null = null; - let outputController: TransformStreamDefaultController | null = null; + let outputChan: Chan | null = null; + let transportError: unknown = null; function setupMessageHandler(socket: WebSocket): void { socket.on('message', (data: WebSocket.Data) => { @@ -155,9 +157,8 @@ export function createWsTransport( }); socket.on('error', (err: Error) => { - outputController?.error( - new APIConnectionError({ message: `WebSocket error: ${err.message}` }), - ); + transportError = new APIConnectionError({ message: `WebSocket error: ${err.message}` }); + outputChan?.close(); }); socket.on('close', (code: number, reason: Buffer) => { @@ -247,7 +248,11 @@ export function createWsTransport( numRequests: getAndResetNumRequests?.() ?? 0, }; - outputController?.enqueue(event); + try { + outputChan?.sendNowait(event); + } catch { + // Chan closed + } setState({ overlapSpeechStarted: false }); } break; @@ -292,12 +297,11 @@ export function createWsTransport( break; case MSG_ERROR: - outputController?.error( - new APIStatusError({ - message: `LiveKit Adaptive Interruption error: ${message.message}`, - options: { statusCode: message.code ?? -1 }, - }), - ); + transportError = new APIStatusError({ + message: `LiveKit Adaptive Interruption error: ${message.message}`, + options: { statusCode: message.code ?? -1 }, + }); + outputChan?.close(); break; } } @@ -358,59 +362,70 @@ export function createWsTransport( close(); } - const transport = new TransformStream< - Int16Array | OverlappingSpeechEvent, - OverlappingSpeechEvent - >( - { - async start(controller) { - outputController = controller; - await ensureConnection().catch((e) => { - controller.error(e); - }); - }, + const transport: TransportFn = async function* (source) { + outputChan = new Chan(); + transportError = null; - transform(chunk, controller) { - if (!(chunk instanceof Int16Array)) { - controller.enqueue(chunk); - return; - } + await ensureConnection(); + + // Pump source in background: consume input, send audio to WS, passthrough events + const pump = (async () => { + try { + for await (const chunk of source) { + if (!(chunk instanceof Int16Array)) { + try { + outputChan!.sendNowait(chunk); + } catch { + break; + } + continue; + } + + // Only forwards buffered audio while overlap speech is actively on. + const state = getState(); + if (!state.overlapSpeechStartedAt || !state.overlapSpeechStarted) continue; - // Only forwards buffered audio while overlap speech is actively on. - const state = getState(); - if (!state.overlapSpeechStartedAt || !state.overlapSpeechStarted) return; - - if (options.timeout > 0) { - const now = performance.now(); - for (const [, entry] of state.cache.entries()) { - if (entry.totalDurationInS !== 0) continue; - if (now - entry.createdAt > options.timeout) { - controller.error( - new APIStatusError({ + if (options.timeout > 0) { + const now = performance.now(); + for (const [, entry] of state.cache.entries()) { + if (entry.totalDurationInS !== 0) continue; + if (now - entry.createdAt > options.timeout) { + transportError = new APIStatusError({ message: `interruption inference timed out after ${((now - entry.createdAt) / 1000).toFixed(1)}s (ws)`, options: { statusCode: 408, retryable: false }, - }), - ); - return; + }); + outputChan!.close(); + return; + } + break; } - break; } - } - try { - sendAudioData(chunk); - } catch (err) { - controller.error(err); + try { + sendAudioData(chunk); + } catch (err) { + transportError = err; + outputChan!.close(); + return; + } } - }, - - flush() { + } finally { close(); - }, - }, - { highWaterMark: 2 }, - { highWaterMark: 2 }, - ); + outputChan!.close(); + } + })(); + + try { + for await (const event of outputChan) { + yield event; + } + if (transportError) { + throw transportError; + } + } finally { + await pump; + } + }; return { transport, reconnect }; } diff --git a/agents/src/inference/stt.ts b/agents/src/inference/stt.ts index 48688318f..b93a70093 100644 --- a/agents/src/inference/stt.ts +++ b/agents/src/inference/stt.ts @@ -7,7 +7,7 @@ import { APIError, APIStatusError } from '../_exceptions.js'; import { AudioByteStream } from '../audio.js'; import { type LanguageCode, areLanguagesEquivalent, normalizeLanguage } from '../language.js'; import { log } from '../log.js'; -import { createStreamChannel } from '../stream/stream_channel.js'; +import { Chan, ChanClosed } from '../stream/chan.js'; import { STT as BaseSTT, SpeechStream as BaseSpeechStream, @@ -390,7 +390,7 @@ export class SpeechStream extends BaseSpeechStream { let closing = false; let finalReceived = false; - const eventChannel = createStreamChannel(); + const eventChannel = new Chan(); const resourceCleanup = () => { if (closing) return; @@ -411,7 +411,11 @@ export class SpeechStream extends BaseSpeechStream { ws.on('message', (data) => { const json = JSON.parse(data.toString()) as SttServerEvent; - eventChannel.write(json); + try { + eventChannel.sendNowait(json); + } catch { + // Chan closed + } }); ws.on('error', (e) => { @@ -489,20 +493,15 @@ export class SpeechStream extends BaseSpeechStream { }; const recv = async (signal: AbortSignal) => { - const serverEventStream = eventChannel.stream(); - const reader = serverEventStream.getReader(); - try { - while (!this.closed && !signal.aborted) { - const result = await reader.read(); - if (signal.aborted) return; - if (result.done) return; + for await (const value of eventChannel) { + if (this.closed || signal.aborted) return; // Parse and validate with Zod schema - const parseResult = await sttServerEventSchema.safeParseAsync(result.value); + const parseResult = await sttServerEventSchema.safeParseAsync(value); if (!parseResult.success) { this.#logger.warn( - { error: parseResult.error, rawData: result.value }, + { error: parseResult.error, rawData: value }, 'Failed to parse STT server event', ); continue; @@ -530,13 +529,8 @@ export class SpeechStream extends BaseSpeechStream { throw new APIError(`LiveKit STT returned error: ${JSON.stringify(event)}`); } } - } finally { - reader.releaseLock(); - try { - await serverEventStream.cancel(); - } catch (e) { - this.#logger.debug('Error cancelling serverEventStream (may already be cancelled):', e); - } + } catch (e) { + if (!(e instanceof ChanClosed)) throw e; } }; diff --git a/agents/src/inference/tts.ts b/agents/src/inference/tts.ts index 3b91041f1..a2b0111cc 100644 --- a/agents/src/inference/tts.ts +++ b/agents/src/inference/tts.ts @@ -8,7 +8,7 @@ import { AudioByteStream } from '../audio.js'; import { ConnectionPool } from '../connection_pool.js'; import { type LanguageCode, normalizeLanguage } from '../language.js'; import { log } from '../log.js'; -import { createStreamChannel } from '../stream/stream_channel.js'; +import { Chan, ChanClosed } from '../stream/chan.js'; import { basic as tokenizeBasic } from '../tokenize/index.js'; import type { ChunkedStream } from '../tts/index.js'; import { SynthesizeStream as BaseSynthesizeStream, TTS as BaseTTS } from '../tts/index.js'; @@ -434,7 +434,7 @@ export class SynthesizeStream extends BaseSynthesizeSt let lastFrame: AudioFrame | undefined; const sendTokenizerStream = new tokenizeBasic.SentenceTokenizer().stream(); - const eventChannel = createStreamChannel(); + const eventChannel = new Chan(); const requestId = shortuuid('tts_request_'); const inputSentEvent = new Event(); @@ -445,8 +445,7 @@ export class SynthesizeStream extends BaseSynthesizeSt if (closing) return; closing = true; sendTokenizerStream.close(); - // close() returns a promise; don't leak it - await eventChannel.close(); + eventChannel.close(); }; const sendClientEvent = async (event: TtsClientEvent, ws: WebSocket, signal: AbortSignal) => { @@ -510,13 +509,13 @@ export class SynthesizeStream extends BaseSynthesizeSt try { const eventJson = JSON.parse(data.toString()) as Record; const validatedEvent = ttsServerEventSchema.parse(eventJson); - // writer.write returns a promise; avoid unhandled rejections if stream is closed - void eventChannel.write(validatedEvent).catch((error) => { - this.#logger.debug( - { error }, - 'Failed writing TTS event to stream channel (likely closed)', - ); - }); + try { + eventChannel.sendNowait(validatedEvent); + } catch (error) { + if (!(error instanceof ChanClosed)) { + this.#logger.debug({ error }, 'Failed writing TTS event to channel (likely closed)'); + } + } } catch (e) { this.#logger.error({ error: e }, 'Error parsing WebSocket message'); } @@ -589,15 +588,14 @@ export class SynthesizeStream extends BaseSynthesizeSt const recvTimeoutMs = this.connOptions.timeoutMs; const bstream = new AudioByteStream(this.opts.sampleRate, NUM_CHANNELS); - const serverEventStream = eventChannel.stream(); - const reader = serverEventStream.getReader(); + const iter = eventChannel[Symbol.asyncIterator](); try { await inputSentEvent.wait(); while (!this.closed && !signal.aborted) { const result = await waitUntilTimeout( - reader.read(), + iter.next(), recvTimeoutMs, () => new APITimeoutError({ message: 'TTS recv idle timeout' }), ); @@ -647,6 +645,7 @@ export class SynthesizeStream extends BaseSynthesizeSt } } } catch (e) { + if (e instanceof ChanClosed) return; if (e instanceof APITimeoutError) { this.#logger.warn('TTS recv task timed out waiting for server message'); await resourceCleanup(); @@ -654,13 +653,6 @@ export class SynthesizeStream extends BaseSynthesizeSt return; } throw e; - } finally { - reader.releaseLock(); - try { - await serverEventStream.cancel(); - } catch (e) { - this.#logger.debug('Error cancelling serverEventStream (may already be cancelled):', e); - } } }; diff --git a/agents/src/llm/realtime.ts b/agents/src/llm/realtime.ts index 864e25d2d..4c351a15e 100644 --- a/agents/src/llm/realtime.ts +++ b/agents/src/llm/realtime.ts @@ -3,8 +3,7 @@ // SPDX-License-Identifier: Apache-2.0 import type { AudioFrame } from '@livekit/rtc-node'; import { EventEmitter } from 'events'; -import type { ReadableStream } from 'node:stream/web'; -import { DeferredReadableStream } from '../stream/deferred_stream.js'; +import { Chan, ChanClosed } from '../stream/chan.js'; import { Task } from '../utils.js'; import type { TimedString } from '../voice/io.js'; import type { ChatContext, FunctionCall } from './chat_context.js'; @@ -21,14 +20,14 @@ export interface MessageGeneration { /** * Text stream that may contain plain strings or TimedString objects with timestamps. */ - textStream: ReadableStream; - audioStream: ReadableStream; + textStream: AsyncIterable; + audioStream: AsyncIterable; modalities?: Promise<('text' | 'audio')[]>; } export interface GenerationCreatedEvent { - messageStream: ReadableStream; - functionStream: ReadableStream; + messageStream: AsyncIterable; + functionStream: AsyncIterable; userInitiated: boolean; /** Response ID for correlating metrics with spans */ responseId?: string; @@ -84,7 +83,8 @@ export abstract class RealtimeModel { export abstract class RealtimeSession extends EventEmitter { protected _realtimeModel: RealtimeModel; - private deferredInputStream = new DeferredReadableStream(); + private inputChan = new Chan(); + private _pumpAbort: AbortController | null = null; private _mainTask: Task; constructor(realtimeModel: RealtimeModel) { @@ -156,17 +156,32 @@ export abstract class RealtimeSession extends EventEmitter { } private async _mainTaskImpl(signal: AbortSignal): Promise { - const reader = this.deferredInputStream.stream.getReader(); - while (true) { - const { done, value } = await reader.read(); - if (done || signal.aborted) { + for await (const value of this.inputChan) { + if (signal.aborted) { break; } this.pushAudio(value); } } - setInputAudioStream(audioStream: ReadableStream): void { - this.deferredInputStream.setSource(audioStream); + setInputAudioStream(audioStream: AsyncIterable): void { + this._pumpAbort?.abort(); + const abort = new AbortController(); + this._pumpAbort = abort; + (async () => { + try { + for await (const frame of audioStream) { + if (abort.signal.aborted) break; + try { + this.inputChan.sendNowait(frame); + } catch (e) { + if (e instanceof ChanClosed) break; + throw e; + } + } + } catch { + // Source errors are silently consumed + } + })(); } } diff --git a/agents/src/stream/adapters.ts b/agents/src/stream/adapters.ts new file mode 100644 index 000000000..42aeb2c86 --- /dev/null +++ b/agents/src/stream/adapters.ts @@ -0,0 +1,168 @@ +// SPDX-FileCopyrightText: 2025 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +import { ReadableStream } from 'node:stream/web'; +import { IdleTimeoutError } from '../utils.js'; +import { Chan, ChanClosed } from './chan.js'; + +/** + * Convert a ReadableStream into an AsyncIterable backed by a Chan. + * + * This is an adapter for interop with external APIs (e.g., AudioStream from rtc-node) + * that still expose ReadableStream. The returned AsyncIterable can be used with + * `for await...of` and integrates cleanly with the Chan-based architecture. + * + * @param stream - The ReadableStream to convert + * @param signal - Optional AbortSignal to stop reading early + * @returns An AsyncIterable that yields all values from the stream + */ +export function fromReadableStream( + stream: ReadableStream, + signal?: AbortSignal, +): AsyncIterable { + const ch = new Chan(); + + // Pump the ReadableStream into the channel in the background + (async () => { + const reader = stream.getReader(); + try { + while (true) { + if (signal?.aborted) break; + const { done, value } = await reader.read(); + if (done) break; + try { + ch.sendNowait(value); + } catch (e) { + if (e instanceof ChanClosed) break; + throw e; + } + } + } catch { + // Stream errors are silently consumed; the channel will close + } finally { + reader.releaseLock(); + ch.close(); + } + })(); + + return signal ? ch.iter(signal) : ch; +} + +/** + * Convert an AsyncIterable into a ReadableStream. + * + * This is an adapter for interop with APIs that require ReadableStream + * (e.g., external libraries, WebRTC tracks). It consumes the async iterable + * and enqueues each value into the ReadableStream. + * + * @param iterable - The AsyncIterable to convert + * @param signal - Optional AbortSignal to stop iteration early + * @returns A ReadableStream that yields all values from the iterable + */ +export function toReadableStream( + iterable: AsyncIterable, + signal?: AbortSignal, +): ReadableStream { + return new ReadableStream({ + async start(controller) { + try { + for await (const value of iterable) { + if (signal?.aborted) break; + controller.enqueue(value); + } + controller.close(); + } catch (e) { + controller.error(e); + } + }, + }); +} + +/** + * Merge multiple AsyncIterables into a single AsyncIterable. + * + * All sources are consumed concurrently. Values are yielded in the order + * they arrive (interleaved). The output closes when all sources are exhausted. + * + * @param sources - The AsyncIterables to merge + * @returns A single AsyncIterable yielding values from all sources + */ +export function mergeAsyncIterables(...sources: AsyncIterable[]): AsyncIterable { + const ch = new Chan(); + + let remaining = sources.length; + if (remaining === 0) { + ch.close(); + return ch; + } + + for (const source of sources) { + (async () => { + try { + for await (const value of source) { + try { + ch.sendNowait(value); + } catch (e) { + if (e instanceof ChanClosed) return; + throw e; + } + } + } catch { + // Source errors are silently consumed + } finally { + remaining--; + if (remaining === 0) { + ch.close(); + } + } + })(); + } + + return ch; +} + +/** + * Wrap an AsyncIterable with an idle timeout on each `.next()` call. + * + * If the source does not yield a value within `timeoutMs` milliseconds, + * the iteration throws {@link IdleTimeoutError}. The timer resets after + * every successfully received value. + * + * @param source - The AsyncIterable to wrap + * @param timeoutMs - Maximum idle time between values in milliseconds + * @returns An AsyncIterable that throws IdleTimeoutError on stall + */ +export async function* withIdleTimeout( + source: AsyncIterable, + timeoutMs: number, +): AsyncGenerator { + const iter = source[Symbol.asyncIterator](); + let timedOut = false; + try { + while (true) { + let timer: ReturnType | undefined; + const result = await Promise.race([ + iter.next(), + new Promise((_, reject) => { + timer = setTimeout(() => reject(new IdleTimeoutError()), timeoutMs); + }), + ]).finally(() => clearTimeout(timer)); + if (result.done) break; + yield result.value; + } + } catch (e) { + if (e instanceof IdleTimeoutError) { + timedOut = true; + } + throw e; + } finally { + if (!timedOut) { + // Only attempt cleanup if we didn't time out — a timed-out iterator + // may be stuck on a never-resolving promise, so calling return() would hang. + await iter.return?.(undefined); + } else { + // Fire-and-forget: allow GC to clean up the stalled iterator + iter.return?.(undefined)?.catch?.(() => {}); + } + } +} diff --git a/agents/src/stream/index.ts b/agents/src/stream/index.ts index 076764fb3..a6ee0ed34 100644 --- a/agents/src/stream/index.ts +++ b/agents/src/stream/index.ts @@ -1,8 +1,14 @@ // SPDX-FileCopyrightText: 2024 LiveKit, Inc. // // SPDX-License-Identifier: Apache-2.0 +export { + fromReadableStream, + mergeAsyncIterables, + toReadableStream, + withIdleTimeout, +} from './adapters.js'; export { Chan, ChanClosed, ChanEmpty, ChanFull } from './chan.js'; -export { DeferredReadableStream } from './deferred_stream.js'; +export { DeferredReadableStream, isStreamReaderReleaseError } from './deferred_stream.js'; export { IdentityTransform } from './identity_transform.js'; export { mergeReadableStreams } from './merge_readable_streams.js'; export { MultiInputStream } from './multi_input_stream.js'; diff --git a/agents/src/stt/stt.ts b/agents/src/stt/stt.ts index cbe774630..53dc93086 100644 --- a/agents/src/stt/stt.ts +++ b/agents/src/stt/stt.ts @@ -4,13 +4,12 @@ import { type AudioFrame, AudioResampler } from '@livekit/rtc-node'; import type { TypedEventEmitter as TypedEmitter } from '@livekit/typed-emitter'; import { EventEmitter } from 'node:events'; -import type { ReadableStream } from 'node:stream/web'; import { APIConnectionError, APIError } from '../_exceptions.js'; import { calculateAudioDurationSeconds } from '../audio.js'; import type { LanguageCode } from '../language.js'; import { log } from '../log.js'; import type { STTMetrics } from '../metrics/base.js'; -import { DeferredReadableStream } from '../stream/deferred_stream.js'; +import { Chan, ChanClosed } from '../stream/chan.js'; import { type APIConnectOptions, DEFAULT_API_CONNECT_OPTIONS, intervalForRetry } from '../types.js'; import type { AudioBuffer } from '../utils.js'; import { AsyncIterableQueue, delay, startSoon, toError } from '../utils.js'; @@ -222,7 +221,8 @@ export abstract class SpeechStream implements AsyncIterableIterator abstract label: string; protected closed = false; #stt: STT; - private deferredInputStream: DeferredReadableStream; + private inputChan = new Chan(); + private _pumpAbort: AbortController | null = null; private logger = log(); private _connOptions: APIConnectOptions; private _startTimeOffset: number = 0; @@ -236,7 +236,6 @@ export abstract class SpeechStream implements AsyncIterableIterator ) { this.#stt = stt; this._connOptions = connectionOptions; - this.deferredInputStream = new DeferredReadableStream(); this.neededSampleRate = sampleRate; this.monitorMetrics(); this.pumpInput(); @@ -304,20 +303,12 @@ export abstract class SpeechStream implements AsyncIterableIterator } protected async pumpInput() { - // TODO(AJS-35): Implement STT with webstreams API - const inputStream = this.deferredInputStream.stream; - const reader = inputStream.getReader(); - try { - while (true) { - const { done, value } = await reader.read(); - if (done) break; + for await (const value of this.inputChan) { this.pushFrame(value); } } catch (error) { this.logger.error('Error in STTStream mainTask:', error); - } finally { - reader.releaseLock(); } } @@ -375,12 +366,30 @@ export abstract class SpeechStream implements AsyncIterableIterator this._startTimeOffset = value; } - updateInputStream(audioStream: ReadableStream) { - this.deferredInputStream.setSource(audioStream); + updateInputStream(audioStream: AsyncIterable) { + this._pumpAbort?.abort(); + const abort = new AbortController(); + this._pumpAbort = abort; + (async () => { + try { + for await (const frame of audioStream) { + if (abort.signal.aborted) break; + try { + this.inputChan.sendNowait(frame); + } catch (e) { + if (e instanceof ChanClosed) break; + throw e; + } + } + } catch { + // Source errors are silently consumed + } + })(); } detachInputStream() { - this.deferredInputStream.detachSource(); + this._pumpAbort?.abort(); + this._pumpAbort = null; } /** Push an audio frame to the STT */ diff --git a/agents/src/tts/tts.ts b/agents/src/tts/tts.ts index ab0477144..83bbc964d 100644 --- a/agents/src/tts/tts.ts +++ b/agents/src/tts/tts.ts @@ -5,11 +5,10 @@ import type { AudioFrame } from '@livekit/rtc-node'; import type { TypedEventEmitter as TypedEmitter } from '@livekit/typed-emitter'; import type { Span } from '@opentelemetry/api'; import { EventEmitter } from 'node:events'; -import type { ReadableStream } from 'node:stream/web'; import { APIConnectionError, APIError } from '../_exceptions.js'; import { log } from '../log.js'; import type { TTSMetrics } from '../metrics/base.js'; -import { DeferredReadableStream } from '../stream/deferred_stream.js'; +import { Chan, ChanClosed } from '../stream/chan.js'; import { recordException, traceTypes, tracer } from '../telemetry/index.js'; import { type APIConnectOptions, DEFAULT_API_CONNECT_OPTIONS, intervalForRetry } from '../types.js'; import { AsyncIterableQueue, delay, mergeFrames, startSoon, toError } from '../utils.js'; @@ -171,9 +170,8 @@ export abstract class SynthesizeStream protected connOptions: APIConnectOptions; protected abortController = new AbortController(); - private deferredInputStream: DeferredReadableStream< - string | typeof SynthesizeStream.FLUSH_SENTINEL - >; + private inputChan = new Chan(); + private _pumpAbort: AbortController | null = null; private logger = log(); abstract label: string; @@ -189,12 +187,11 @@ export abstract class SynthesizeStream constructor(tts: TTS, connOptions: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS) { this.#tts = tts; this.connOptions = connOptions; - this.deferredInputStream = new DeferredReadableStream(); this.pumpInput(); this.abortController.signal.addEventListener('abort', () => { - this.deferredInputStream.detachSource(); - // TODO (AJS-36) clean this up when we refactor with streams + this._pumpAbort?.abort(); + this.inputChan.close(); if (!this.input.closed) this.input.close(); if (!this.output.closed) this.output.close(); this.closed = true; @@ -277,31 +274,18 @@ export abstract class SynthesizeStream }); } - // NOTE(AJS-37): The implementation below uses an AsyncIterableQueue (`this.input`) - // bridged from a DeferredReadableStream (`this.deferredInputStream`) rather than - // consuming the stream directly. - // - // A full refactor to native Web Streams was considered but is currently deferred. - // The primary reason is to maintain architectural parity with the Python SDK, - // which is a key design goal for the project. This ensures a consistent developer - // experience across both platforms. - // - // For more context, see the discussion in GitHub issue # 844. protected async pumpInput() { - const reader = this.deferredInputStream.stream.getReader(); try { - while (true) { - const { done, value } = await reader.read(); - if (done || value === SynthesizeStream.FLUSH_SENTINEL) { + for await (const value of this.inputChan) { + if (value === SynthesizeStream.FLUSH_SENTINEL) { break; } this.pushText(value); } this.endInput(); } catch (error) { - this.logger.error(error, 'Error reading deferred input stream'); + this.logger.error(error, 'Error reading input channel'); } finally { - reader.releaseLock(); // Ensure output is closed when the stream ends if (!this.#monitorMetricsTask) { // No text was received, close the output directly @@ -391,8 +375,32 @@ export abstract class SynthesizeStream protected abstract run(): Promise; - updateInputStream(text: ReadableStream) { - this.deferredInputStream.setSource(text); + updateInputStream(text: AsyncIterable) { + this._pumpAbort?.abort(); + const abort = new AbortController(); + this._pumpAbort = abort; + (async () => { + try { + for await (const value of text) { + if (abort.signal.aborted) break; + try { + this.inputChan.sendNowait(value); + } catch (e) { + if (e instanceof ChanClosed) break; + throw e; + } + } + } catch { + // Source errors are silently consumed + } finally { + // Only close the channel if this pump was NOT replaced by a new one. + // If abort fired because updateInputStream was called again, the new + // pump owns the channel and we must not close it. + if (!abort.signal.aborted) { + this.inputChan.close(); + } + } + })(); } /** Push a string of text to the TTS */ diff --git a/agents/src/utils.test.ts b/agents/src/utils.test.ts index a44678d08..4f533f534 100644 --- a/agents/src/utils.test.ts +++ b/agents/src/utils.test.ts @@ -2,7 +2,6 @@ // // SPDX-License-Identifier: Apache-2.0 import { AudioFrame } from '@livekit/rtc-node'; -import { ReadableStream } from 'node:stream/web'; import { describe, expect, it } from 'vitest'; import { initializeLogger } from '../src/log.js'; import { Event, Task, TaskResult, delay, isPending, resampleStream } from '../src/utils.js'; @@ -637,17 +636,10 @@ describe('utils', () => { return new AudioFrame(data, sampleRate, channels, samples); }; - const streamToArray = async (stream: ReadableStream): Promise => { - const reader = stream.getReader(); + const iterableToArray = async (source: AsyncIterable): Promise => { const chunks: AudioFrame[] = []; - try { - while (true) { - const { done, value } = await reader.read(); - if (done) break; - chunks.push(value); - } - } finally { - reader.releaseLock(); + for await (const value of source) { + chunks.push(value); } return chunks; }; @@ -657,15 +649,12 @@ describe('utils', () => { const outputRate = 16000; const inputFrame = createAudioFrame(inputRate, 960); // 20ms at 48kHz - const inputStream = new ReadableStream({ - start(controller) { - controller.enqueue(inputFrame); - controller.close(); - }, - }); + const inputStream = (async function* () { + yield inputFrame; + })(); const outputStream = resampleStream({ stream: inputStream, outputRate }); - const outputFrames = await streamToArray(outputStream); + const outputFrames = await iterableToArray(outputStream); expect(outputFrames.length).toBeGreaterThan(0); @@ -679,15 +668,12 @@ describe('utils', () => { const sampleRate = 44100; const inputFrame = createAudioFrame(sampleRate, 1024); - const inputStream = new ReadableStream({ - start(controller) { - controller.enqueue(inputFrame); - controller.close(); - }, - }); + const inputStream = (async function* () { + yield inputFrame; + })(); const outputStream = resampleStream({ stream: inputStream, outputRate: sampleRate }); - const outputFrames = await streamToArray(outputStream); + const outputFrames = await iterableToArray(outputStream); expect(outputFrames.length).toBeGreaterThan(0); @@ -703,16 +689,13 @@ describe('utils', () => { const frame1 = createAudioFrame(inputRate, 640); const frame2 = createAudioFrame(inputRate, 640); - const inputStream = new ReadableStream({ - start(controller) { - controller.enqueue(frame1); - controller.enqueue(frame2); - controller.close(); - }, - }); + const inputStream = (async function* () { + yield frame1; + yield frame2; + })(); const outputStream = resampleStream({ stream: inputStream, outputRate }); - const outputFrames = await streamToArray(outputStream); + const outputFrames = await iterableToArray(outputStream); expect(outputFrames.length).toBeGreaterThan(0); @@ -723,14 +706,12 @@ describe('utils', () => { }); it('should handle empty stream', async () => { - const inputStream = new ReadableStream({ - start(controller) { - controller.close(); - }, - }); + const inputStream = (async function* (): AsyncIterable { + // empty stream + })(); const outputStream = resampleStream({ stream: inputStream, outputRate: 44100 }); - const outputFrames = await streamToArray(outputStream); + const outputFrames = await iterableToArray(outputStream); expect(outputFrames).toEqual([]); }); diff --git a/agents/src/utils.ts b/agents/src/utils.ts index 2dc5d4ee1..663e1e28d 100644 --- a/agents/src/utils.ts +++ b/agents/src/utils.ts @@ -12,8 +12,6 @@ import { AudioFrame, AudioResampler, RoomEvent } from '@livekit/rtc-node'; import type { Throws } from '@livekit/throws-transformer/throws'; import { AsyncLocalStorage } from 'node:async_hooks'; import { EventEmitter, once } from 'node:events'; -import type { ReadableStream } from 'node:stream/web'; -import { TransformStream, type TransformStreamDefaultController } from 'node:stream/web'; import { v4 as uuidv4 } from 'uuid'; import { log } from './log.js'; @@ -701,32 +699,33 @@ export function resampleStream({ stream, outputRate, }: { - stream: ReadableStream; + stream: AsyncIterable; outputRate: number; -}): ReadableStream { - let resampler: AudioResampler | null = null; - const transformStream = new TransformStream({ - transform(chunk: AudioFrame, controller: TransformStreamDefaultController) { - if (chunk.samplesPerChannel === 0) { - controller.enqueue(chunk); - return; - } - if (!resampler) { - resampler = new AudioResampler(chunk.sampleRate, outputRate); - } - for (const frame of resampler.push(chunk)) { - controller.enqueue(frame); +}): AsyncIterable { + return (async function* () { + let resampler: AudioResampler | null = null; + try { + for await (const chunk of stream) { + if (chunk.samplesPerChannel === 0) { + yield chunk; + continue; + } + if (!resampler) { + resampler = new AudioResampler(chunk.sampleRate, outputRate); + } + for (const frame of resampler.push(chunk)) { + yield frame; + } } - }, - flush(controller) { if (resampler) { for (const frame of resampler.flush()) { - controller.enqueue(frame); + yield frame; } } - }, - }); - return stream.pipeThrough(transformStream); + } finally { + // resampler cleanup happens via GC + } + })(); } export class InvalidErrorType extends Error { diff --git a/agents/src/vad.ts b/agents/src/vad.ts index 422c2654c..e85f87524 100644 --- a/agents/src/vad.ts +++ b/agents/src/vad.ts @@ -4,15 +4,10 @@ import type { AudioFrame } from '@livekit/rtc-node'; import type { TypedEventEmitter as TypedEmitter } from '@livekit/typed-emitter'; import { EventEmitter } from 'node:events'; -import type { - ReadableStream, - ReadableStreamDefaultReader, - WritableStreamDefaultWriter, -} from 'node:stream/web'; import { log } from './log.js'; import type { VADMetrics } from './metrics/base.js'; -import { DeferredReadableStream } from './stream/deferred_stream.js'; -import { IdentityTransform } from './stream/identity_transform.js'; +import { Chan, ChanClosed } from './stream/chan.js'; +import { tee } from './stream/tee.js'; export enum VADEventType { START_OF_SPEECH, @@ -89,67 +84,38 @@ export abstract class VAD extends (EventEmitter as new () => TypedEmitter { protected static readonly FLUSH_SENTINEL = Symbol('FLUSH_SENTINEL'); - protected input = new IdentityTransform(); - protected output = new IdentityTransform(); - protected inputWriter: WritableStreamDefaultWriter; - protected inputReader: ReadableStreamDefaultReader; - protected outputWriter: WritableStreamDefaultWriter; - protected outputReader: ReadableStreamDefaultReader; + protected inputChan = new Chan(); + protected outputChan = new Chan(); protected closed = false; protected inputClosed = false; protected vad: VAD; protected lastActivityTime = BigInt(0); protected logger; - protected deferredInputStream: DeferredReadableStream; + private _pumpAbort: AbortController | null = null; + + private outputTee: ReturnType> | null = null; + private outputIter: AsyncIterableIterator | null = null; + private metricsIter: AsyncIterableIterator | null = null; - private metricsStream: ReadableStream; constructor(vad: VAD) { this.logger = log(); this.vad = vad; - this.deferredInputStream = new DeferredReadableStream(); - - this.inputWriter = this.input.writable.getWriter(); - this.inputReader = this.input.readable.getReader(); - this.outputWriter = this.output.writable.getWriter(); - const [outputStream, metricsStream] = this.output.readable.tee(); - this.metricsStream = metricsStream; - this.outputReader = outputStream.getReader(); + // Tee the output channel into two iterators: one for consumer, one for metrics + this.outputTee = tee(this.outputChan, 2); + this.outputIter = this.outputTee.get(0)[Symbol.asyncIterator](); + this.metricsIter = this.outputTee.get(1)[Symbol.asyncIterator](); - this.pumpDeferredStream(); this.monitorMetrics(); } - /** - * Reads from the deferred input stream and forwards chunks to the input writer. - * - * Note: we can't just do this.deferredInputStream.stream.pipeTo(this.input.writable) - * because the inputWriter locks the this.input.writable stream. All writes must go through - * the inputWriter. - */ - private async pumpDeferredStream() { - const reader = this.deferredInputStream.stream.getReader(); - try { - while (true) { - const { done, value } = await reader.read(); - if (done) break; - await this.inputWriter.write(value); - } - } catch (e) { - this.logger.error(`Error pumping deferred stream: ${e}`); - throw e; - } finally { - reader.releaseLock(); - } - } - protected async monitorMetrics() { let inferenceDurationTotalMs = 0; let inferenceCount = 0; - const metricsReader = this.metricsStream.getReader(); + if (!this.metricsIter) return; while (true) { - const { done, value } = await metricsReader.read(); + const { done, value } = await this.metricsIter.next(); if (done) { break; } @@ -184,8 +150,8 @@ export abstract class VADStream implements AsyncIterableIterator { } /** - * Safely send a VAD event to the output stream, handling writer release errors during shutdown. - * @returns true if the event was sent, false if the stream is closing + * Safely send a VAD event to the output channel, handling close errors during shutdown. + * @returns true if the event was sent, false if the channel is closing * @throws Error if an unexpected error occurs */ protected sendVADEvent(event: VADEvent): boolean { @@ -194,19 +160,38 @@ export abstract class VADStream implements AsyncIterableIterator { } try { - this.outputWriter.write(event); + this.outputChan.sendNowait(event); return true; } catch (e) { + if (e instanceof ChanClosed) return false; throw e; } } - updateInputStream(audioStream: ReadableStream) { - this.deferredInputStream.setSource(audioStream); + updateInputStream(audioStream: AsyncIterable) { + this._pumpAbort?.abort(); + const abort = new AbortController(); + this._pumpAbort = abort; + (async () => { + try { + for await (const frame of audioStream) { + if (abort.signal.aborted) break; + try { + this.inputChan.sendNowait(frame); + } catch (e) { + if (e instanceof ChanClosed) break; + throw e; + } + } + } catch { + // Source errors are silently consumed + } + })(); } detachInputStream() { - this.deferredInputStream.detachSource(); + this._pumpAbort?.abort(); + this._pumpAbort = null; } /** @deprecated Use `updateInputStream` instead */ @@ -218,7 +203,12 @@ export abstract class VADStream implements AsyncIterableIterator { if (this.closed) { throw new Error('Stream is closed'); } - this.inputWriter.write(frame); + try { + this.inputChan.sendNowait(frame); + } catch (e) { + if (e instanceof ChanClosed) return; + throw e; + } } flush() { @@ -228,7 +218,12 @@ export abstract class VADStream implements AsyncIterableIterator { if (this.closed) { throw new Error('Stream is closed'); } - this.inputWriter.write(VADStream.FLUSH_SENTINEL); + try { + this.inputChan.sendNowait(VADStream.FLUSH_SENTINEL); + } catch (e) { + if (e instanceof ChanClosed) return; + throw e; + } } endInput() { @@ -239,22 +234,23 @@ export abstract class VADStream implements AsyncIterableIterator { throw new Error('Stream is closed'); } this.inputClosed = true; - this.input.writable.close(); + this.inputChan.close(); } async next(): Promise> { - return this.outputReader.read().then(({ done, value }) => { - if (done) { - return { done: true, value: undefined }; - } - return { done: false, value }; - }); + if (!this.outputIter) { + return { done: true, value: undefined }; + } + return this.outputIter.next(); } close() { - this.outputWriter.releaseLock(); - this.outputReader.cancel(); - this.output.writable.close(); + this._pumpAbort?.abort(); + this.inputChan.close(); + this.outputChan.close(); + if (this.outputTee) { + this.outputTee.aclose(); + } this.closed = true; } diff --git a/agents/src/voice/agent.ts b/agents/src/voice/agent.ts index 3f83aee32..f96a4db15 100644 --- a/agents/src/voice/agent.ts +++ b/agents/src/voice/agent.ts @@ -3,7 +3,6 @@ // SPDX-License-Identifier: Apache-2.0 import type { AudioFrame } from '@livekit/rtc-node'; import { AsyncLocalStorage } from 'node:async_hooks'; -import { ReadableStream } from 'node:stream/web'; import { LLM as InferenceLLM, STT as InferenceSTT, @@ -272,18 +271,18 @@ export class Agent { async onExit(): Promise {} async transcriptionNode( - text: ReadableStream, + text: AsyncIterable, modelSettings: ModelSettings, - ): Promise | null> { + ): Promise | null> { return Agent.default.transcriptionNode(this, text, modelSettings); } async onUserTurnCompleted(_chatCtx: ChatContext, _newMessage: ChatMessage): Promise {} async sttNode( - audio: ReadableStream, + audio: AsyncIterable, modelSettings: ModelSettings, - ): Promise | null> { + ): Promise | null> { return Agent.default.sttNode(this, audio, modelSettings); } @@ -291,21 +290,21 @@ export class Agent { chatCtx: ChatContext, toolCtx: ToolContext, modelSettings: ModelSettings, - ): Promise | null> { + ): Promise | null> { return Agent.default.llmNode(this, chatCtx, toolCtx, modelSettings); } async ttsNode( - text: ReadableStream, + text: AsyncIterable, modelSettings: ModelSettings, - ): Promise | null> { + ): Promise | null> { return Agent.default.ttsNode(this, text, modelSettings); } async realtimeAudioOutputNode( - audio: ReadableStream, + audio: AsyncIterable, modelSettings: ModelSettings, - ): Promise | null> { + ): Promise | null> { return Agent.default.realtimeAudioOutputNode(this, audio, modelSettings); } @@ -341,9 +340,9 @@ export class Agent { static default = { async sttNode( agent: Agent, - audio: ReadableStream, + audio: AsyncIterable, _modelSettings: ModelSettings, - ): Promise | null> { + ): Promise | null> { const activity = agent.getActivityOrThrow(); if (!activity.stt) { throw new Error('sttNode called but no STT node is available'); @@ -383,22 +382,15 @@ export class Agent { stream.close(); }; - return new ReadableStream({ - async start(controller) { - try { - for await (const event of stream) { - controller.enqueue(event); - } - controller.close(); - } finally { - // Always clean up the STT stream, whether it ends naturally or is cancelled - cleanup(); + return (async function* () { + try { + for await (const event of stream) { + yield event; } - }, - cancel() { + } finally { cleanup(); - }, - }); + } + })(); }, async llmNode( @@ -406,7 +398,7 @@ export class Agent { chatCtx: ChatContext, toolCtx: ToolContext, modelSettings: ModelSettings, - ): Promise | null> { + ): Promise | null> { const activity = agent.getActivityOrThrow(); if (!activity.llm) { throw new Error('llmNode called but no LLM node is available'); @@ -437,28 +429,22 @@ export class Agent { stream.close(); }; - return new ReadableStream({ - async start(controller) { - try { - for await (const chunk of stream) { - controller.enqueue(chunk); - } - controller.close(); - } finally { - cleanup(); + return (async function* () { + try { + for await (const chunk of stream) { + yield chunk; } - }, - cancel() { + } finally { cleanup(); - }, - }); + } + })(); }, async ttsNode( agent: Agent, - text: ReadableStream, + text: AsyncIterable, _modelSettings: ModelSettings, - ): Promise | null> { + ): Promise | null> { const activity = agent.getActivityOrThrow(); if (!activity.tts) { throw new Error('ttsNode called but no TTS node is available'); @@ -481,43 +467,37 @@ export class Agent { stream.close(); }; - return new ReadableStream({ - async start(controller) { - try { - for await (const chunk of stream) { - if (chunk === SynthesizeStream.END_OF_STREAM) { - break; - } - // Attach timed transcripts to frame.userdata - if (chunk.timedTranscripts && chunk.timedTranscripts.length > 0) { - chunk.frame.userdata[USERDATA_TIMED_TRANSCRIPT] = chunk.timedTranscripts; - } - controller.enqueue(chunk.frame); + return (async function* () { + try { + for await (const chunk of stream) { + if (chunk === SynthesizeStream.END_OF_STREAM) { + break; } - controller.close(); - } finally { - cleanup(); + // Attach timed transcripts to frame.userdata + if (chunk.timedTranscripts && chunk.timedTranscripts.length > 0) { + chunk.frame.userdata[USERDATA_TIMED_TRANSCRIPT] = chunk.timedTranscripts; + } + yield chunk.frame; } - }, - cancel() { + } finally { cleanup(); - }, - }); + } + })(); }, async transcriptionNode( - agent: Agent, - text: ReadableStream, + _agent: Agent, + text: AsyncIterable, _modelSettings: ModelSettings, - ): Promise | null> { + ): Promise | null> { return text; }, async realtimeAudioOutputNode( _agent: Agent, - audio: ReadableStream, + audio: AsyncIterable, _modelSettings: ModelSettings, - ): Promise | null> { + ): Promise | null> { return audio; }, }; diff --git a/agents/src/voice/agent_activity.ts b/agents/src/voice/agent_activity.ts index bedebf3e2..f0b3736bb 100644 --- a/agents/src/voice/agent_activity.ts +++ b/agents/src/voice/agent_activity.ts @@ -7,7 +7,6 @@ import type { Span } from '@opentelemetry/api'; import { ROOT_CONTEXT, context as otelContext, trace } from '@opentelemetry/api'; import { Heap } from 'heap-js'; import { AsyncLocalStorage } from 'node:async_hooks'; -import { ReadableStream, TransformStream } from 'node:stream/web'; import type { InterruptionDetectionError } from '../inference/interruption/errors.js'; import { AdaptiveInterruptionDetector } from '../inference/interruption/interruption_detector.js'; import type { OverlappingSpeechEvent } from '../inference/interruption/types.js'; @@ -40,7 +39,8 @@ import type { TTSMetrics, VADMetrics, } from '../metrics/base.js'; -import { MultiInputStream } from '../stream/multi_input_stream.js'; +import { Chan, ChanClosed } from '../stream/chan.js'; +import { tee } from '../stream/tee.js'; import { STT, type STTError, type SpeechEvent } from '../stt/stt.js'; import { recordRealtimeMetrics, traceTypes, tracer } from '../telemetry/index.js'; import { splitWords } from '../tokenize/basic/word.js'; @@ -125,8 +125,8 @@ export class AgentActivity implements RecognitionHooks { private q_updated: Future; private speechTasks: Set> = new Set(); private lock = new Mutex(); - private audioStream = new MultiInputStream(); - private audioStreamId?: string; + private audioChan = new Chan(); + private _audioPumpAbort: AbortController | null = null; // default to null as None, which maps to the default provider tool choice value private toolChoice: ToolChoice | null = null; @@ -587,51 +587,58 @@ export class AgentActivity implements RecognitionHooks { } } - attachAudioInput(audioStream: ReadableStream): void { - void this.audioStream.close(); - this.audioStream = new MultiInputStream(); + attachAudioInput(audioStream: AsyncIterable): void { + // Close previous pump and channel, create fresh ones + this._audioPumpAbort?.abort(); + this.audioChan.close(); + this.audioChan = new Chan(); - // Filter is applied on this.audioStream.stream (downstream of MultiInputStream) rather - // than on the source audioStream via pipeThrough. pipeThrough locks its source stream, so - // if it were applied directly on audioStream, that lock would survive MultiInputStream.close() - // and make audioStream permanently locked for subsequent attachAudioInput calls (e.g. handoff). - const aecWarmupAudioFilter = new TransformStream({ - transform: (frame, controller) => { + // AEC warmup filter as an async generator + const filteredStream = async function* (this: AgentActivity) { + for await (const frame of this.audioChan) { const shouldDiscardForAecWarmup = this.agentSession.agentState === 'speaking' && this.agentSession._aecWarmupRemaining > 0; if (!shouldDiscardForAecWarmup) { - controller.enqueue(frame); + yield frame; } - }, - }); + } + }.call(this); - this.audioStreamId = this.audioStream.addInputStream(audioStream); + // Pump source into audioChan + const abort = new AbortController(); + this._audioPumpAbort = abort; + (async () => { + try { + for await (const frame of audioStream) { + if (abort.signal.aborted) break; + try { + this.audioChan.sendNowait(frame); + } catch (e) { + if (e instanceof ChanClosed) break; + throw e; + } + } + } catch { + // Source errors silently consumed + } + })(); if (this.realtimeSession && this.audioRecognition) { - const [realtimeAudioStream, recognitionAudioStream] = this.audioStream.stream - .pipeThrough(aecWarmupAudioFilter) - .tee(); - this.realtimeSession.setInputAudioStream(realtimeAudioStream); - this.audioRecognition.setInputAudioStream(recognitionAudioStream); + const teed = tee(filteredStream, 2); + this.realtimeSession.setInputAudioStream(teed.get(0)); + this.audioRecognition.setInputAudioStream(teed.get(1)); } else if (this.realtimeSession) { - this.realtimeSession.setInputAudioStream( - this.audioStream.stream.pipeThrough(aecWarmupAudioFilter), - ); + this.realtimeSession.setInputAudioStream(filteredStream); } else if (this.audioRecognition) { - this.audioRecognition.setInputAudioStream( - this.audioStream.stream.pipeThrough(aecWarmupAudioFilter), - ); + this.audioRecognition.setInputAudioStream(filteredStream); } } detachAudioInput(): void { - if (this.audioStreamId === undefined) { - return; - } - - void this.audioStream.close(); - this.audioStream = new MultiInputStream(); - this.audioStreamId = undefined; + this._audioPumpAbort?.abort(); + this._audioPumpAbort = null; + this.audioChan.close(); + this.audioChan = new Chan(); } commitUserTurn( @@ -657,9 +664,9 @@ export class AgentActivity implements RecognitionHooks { } say( - text: string | ReadableStream, + text: string | AsyncIterable, options?: { - audio?: ReadableStream; + audio?: AsyncIterable; allowInterruptions?: boolean; addToChatCtx?: boolean; }, @@ -1597,11 +1604,11 @@ export class AgentActivity implements RecognitionHooks { private async ttsTask( speechHandle: SpeechHandle, - text: string | ReadableStream, + text: string | AsyncIterable, addToChatCtx: boolean, modelSettings: ModelSettings, replyAbortController: AbortController, - audio?: ReadableStream | null, + audio?: AsyncIterable | null, ): Promise { speechHandle._agentTurnContext = otelContext.active(); @@ -1621,19 +1628,18 @@ export class AgentActivity implements RecognitionHooks { return; } - let baseStream: ReadableStream; - if (text instanceof ReadableStream) { - baseStream = text; + let baseIterable: AsyncIterable; + if (typeof text === 'string') { + baseIterable = (async function* () { + yield text; + })(); } else { - baseStream = new ReadableStream({ - start(controller) { - controller.enqueue(text); - controller.close(); - }, - }); + baseIterable = text; } - const [textSource, audioSource] = baseStream.tee(); + const teed = tee(baseIterable, 2); + const textSource = teed.get(0); + const audioSource = teed.get(1); const tasks: Array> = []; @@ -1833,12 +1839,13 @@ export class AgentActivity implements RecognitionHooks { let ttsTask: Task | null = null; let ttsGenData: _TTSGenerationData | null = null; - let llmOutput: ReadableStream; + let llmOutput: AsyncIterable; if (audioOutput) { // Only tee the stream when we need TTS - const [ttsTextInput, textOutput] = llmGenData.textStream.tee(); - llmOutput = textOutput; + const llmTee = tee(llmGenData.textStream, 2); + const ttsTextInput = llmTee.get(0); + llmOutput = llmTee.get(1); [ttsTask, ttsGenData] = performTTSInference( (...args) => this.agent.ttsNode(...args), ttsTextInput, @@ -1877,7 +1884,7 @@ export class AgentActivity implements RecognitionHooks { const replyStartedAt = Date.now(); // Determine the transcription input source - let transcriptionInput: ReadableStream = llmOutput; + let transcriptionInput: AsyncIterable = llmOutput; // Check if we should use TTS aligned transcripts if (this.useTtsAlignedTranscript && this.tts?.capabilities.alignedTranscript && ttsGenData) { @@ -2332,8 +2339,8 @@ export class AgentActivity implements RecognitionHooks { } const msgModalities = msg.modalities ? await msg.modalities : undefined; - let ttsTextInput: ReadableStream | null = null; - let trTextInput: ReadableStream; + let ttsTextInput: AsyncIterable | null = null; + let trTextInput: AsyncIterable; if (msgModalities && !msgModalities.includes('audio') && this.tts) { if (this.llm instanceof RealtimeModel && this.llm.capabilities.audioOutput) { @@ -2341,9 +2348,9 @@ export class AgentActivity implements RecognitionHooks { 'text response received from realtime API, falling back to use a TTS model.', ); } - const [_ttsTextInput, _trTextInput] = msg.textStream.tee(); - ttsTextInput = _ttsTextInput; - trTextInput = _trTextInput; + const msgTee = tee(msg.textStream, 2); + ttsTextInput = msgTee.get(0); + trTextInput = msgTee.get(1); } else { trTextInput = msg.textStream; } @@ -2362,7 +2369,7 @@ export class AgentActivity implements RecognitionHooks { let audioOut: _AudioOut | null = null; if (audioOutput) { - let realtimeAudioResult: ReadableStream | null = null; + let realtimeAudioResult: AsyncIterable | null = null; if (ttsTextInput) { const [ttsTask, ttsGenData] = performTTSInference( @@ -2430,25 +2437,25 @@ export class AgentActivity implements RecognitionHooks { ), ]; - const [toolCallStream, toolCallStreamForTracing] = ev.functionStream.tee(); + const toolTee = tee(ev.functionStream, 2); + const toolCallStream = toolTee.get(0); + const toolCallStreamForTracing = toolTee.get(1); // TODO(brian): append to tracing tees const toolCalls: FunctionCall[] = []; const readToolStreamTask = async ( controller: AbortController, - stream: ReadableStream, + stream: AsyncIterable, ) => { - const reader = stream.getReader(); try { - while (!controller.signal.aborted) { - const { done, value } = await reader.read(); - if (done) break; + for await (const value of stream) { + if (controller.signal.aborted) break; this.logger.debug({ tool_call: value }, 'received tool call from the realtime API'); toolCalls.push(value); } - } finally { - reader.releaseLock(); + } catch { + // Stream closed or source error } }; diff --git a/agents/src/voice/agent_session.ts b/agents/src/voice/agent_session.ts index 572a3b9fa..06447a59a 100644 --- a/agents/src/voice/agent_session.ts +++ b/agents/src/voice/agent_session.ts @@ -7,7 +7,6 @@ import type { TypedEventEmitter as TypedEmitter } from '@livekit/typed-emitter'; import type { Context, Span } from '@opentelemetry/api'; import { ROOT_CONTEXT, context as otelContext, trace } from '@opentelemetry/api'; import { EventEmitter } from 'node:events'; -import type { ReadableStream } from 'node:stream/web'; import type { z } from 'zod'; import { LLM as InferenceLLM, @@ -601,9 +600,9 @@ export class AgentSession< } say( - text: string | ReadableStream, + text: string | AsyncIterable, options?: { - audio?: ReadableStream; + audio?: AsyncIterable; allowInterruptions?: boolean; addToChatCtx?: boolean; }, diff --git a/agents/src/voice/audio_recognition.ts b/agents/src/voice/audio_recognition.ts index 81a778780..15b7b2d0c 100644 --- a/agents/src/voice/audio_recognition.ts +++ b/agents/src/voice/audio_recognition.ts @@ -10,8 +10,6 @@ import { context as otelContext, trace, } from '@opentelemetry/api'; -import type { WritableStreamDefaultWriter } from 'node:stream/web'; -import { ReadableStream } from 'node:stream/web'; import { isAPIError } from '../_exceptions.js'; import { apiConnectDefaults, intervalForRetry } from '../inference/interruption/defaults.js'; import { InterruptionDetectionError } from '../inference/interruption/errors.js'; @@ -24,10 +22,9 @@ import { import type { LanguageCode } from '../language.js'; import { type ChatContext } from '../llm/chat_context.js'; import { log } from '../log.js'; -import { DeferredReadableStream, isStreamReaderReleaseError } from '../stream/deferred_stream.js'; -import { IdentityTransform } from '../stream/identity_transform.js'; -import { mergeReadableStreams } from '../stream/merge_readable_streams.js'; -import { type StreamChannel, createStreamChannel } from '../stream/stream_channel.js'; +import { mergeAsyncIterables } from '../stream/adapters.js'; +import { Chan, ChanClosed } from '../stream/chan.js'; +import { tee } from '../stream/tee.js'; import { type SpeechEvent, SpeechEventType } from '../stt/stt.js'; import { traceTypes, tracer } from '../telemetry/index.js'; import { Task, delay, waitForAbort } from '../utils.js'; @@ -130,7 +127,8 @@ export class AudioRecognition { private sttProvider?: string; private getLinkedParticipant?: () => ParticipantLike | undefined; - private deferredInputStream: DeferredReadableStream; + private inputChan = new Chan(); + private _pumpAbort: AbortController | null = null; private logger = log(); private lastFinalTranscriptTime = 0; private audioTranscript = ''; @@ -145,10 +143,9 @@ export class AudioRecognition { private userTurnSpan?: Span; - private vadInputStream: ReadableStream; - private sttInputStream: ReadableStream; - private silenceAudioTransform = new IdentityTransform(); - private silenceAudioWriter: WritableStreamDefaultWriter; + private vadInputStream: AsyncIterable; + private sttInputStream: AsyncIterable; + private silenceAudioChan = new Chan(); // all cancellable tasks private bounceEOUTask?: Task; @@ -164,7 +161,7 @@ export class AudioRecognition { private transcriptBuffer: SpeechEvent[]; private isInterruptionEnabled: boolean; private isAgentSpeaking: boolean; - private interruptionStreamChannel?: StreamChannel; + private interruptionChan?: Chan; private closed = false; constructor(opts: AudioRecognitionOptions) { @@ -181,31 +178,36 @@ export class AudioRecognition { this.sttProvider = opts.sttProvider; this.getLinkedParticipant = opts.getLinkedParticipant; - this.deferredInputStream = new DeferredReadableStream(); this.interruptionDetection = opts.interruptionDetection; this.transcriptBuffer = []; this.isInterruptionEnabled = !!(opts.interruptionDetection && opts.vad); this.isAgentSpeaking = false; if (opts.interruptionDetection) { - const [vadInputStream, teedInput] = this.deferredInputStream.stream.tee(); - const [inputStream, sttInputStream] = teedInput.tee(); - this.vadInputStream = vadInputStream; - this.sttInputStream = mergeReadableStreams( - sttInputStream, - this.silenceAudioTransform.readable, - ); - this.interruptionStreamChannel = createStreamChannel(); - this.interruptionStreamChannel.addStreamInput(inputStream); + const teed = tee(this.inputChan, 3); + this.vadInputStream = teed.get(0); + this.sttInputStream = mergeAsyncIterables(teed.get(1), this.silenceAudioChan); + this.interruptionChan = new Chan(); + // Pump teed[2] into interruptionChan + (async () => { + try { + for await (const frame of teed.get(2)) { + try { + this.interruptionChan!.sendNowait(frame); + } catch (e) { + if (e instanceof ChanClosed) break; + throw e; + } + } + } catch { + // Source errors silently consumed + } + })(); } else { - const [vadInputStream, sttInputStream] = this.deferredInputStream.stream.tee(); - this.vadInputStream = vadInputStream; - this.sttInputStream = mergeReadableStreams( - sttInputStream, - this.silenceAudioTransform.readable, - ); + const teed = tee(this.inputChan, 2); + this.vadInputStream = teed.get(0); + this.sttInputStream = mergeAsyncIterables(teed.get(1), this.silenceAudioChan); } - this.silenceAudioWriter = this.silenceAudioTransform.writable.getWriter(); } /** @@ -258,8 +260,8 @@ export class AudioRecognition { this.interruptionDetection = undefined; await this.interruptionTask?.cancelAndWait(); this.interruptionTask = undefined; - await this.interruptionStreamChannel?.close(); - this.interruptionStreamChannel = undefined; + this.interruptionChan?.close(); + this.interruptionChan = undefined; } async onStartOfAgentSpeech() { @@ -437,15 +439,12 @@ export class AudioRecognition { private async trySendInterruptionSentinel( frame: AudioFrame | InterruptionSentinel, ): Promise { - if ( - this.isInterruptionEnabled && - this.interruptionStreamChannel && - !this.interruptionStreamChannel.closed - ) { + if (this.isInterruptionEnabled && this.interruptionChan) { try { - await this.interruptionStreamChannel.write(frame); + this.interruptionChan.sendNowait(frame); return true; } catch (e: unknown) { + if (e instanceof ChanClosed) return false; this.logger.warn( `could not forward interruption sentinel: ${e instanceof Error ? e.message : String(e)}`, ); @@ -896,47 +895,18 @@ export class AudioRecognition { if (signal.aborted || sttStream === null) return; - if (sttStream instanceof ReadableStream) { - const reader = sttStream.getReader(); - - signal.addEventListener('abort', async () => { - try { - reader.releaseLock(); - await sttStream?.cancel(); - } catch (e) { - this.logger.debug('createSttTask: error during abort handler:', e); - } - }); - - try { - while (true) { - if (signal.aborted) break; - - const { done, value: ev } = await reader.read(); - if (done) break; + try { + for await (const ev of sttStream) { + if (signal.aborted) break; - if (typeof ev === 'string') { - throw new Error('STT node must yield SpeechEvent'); - } else { - await this.onSTTEvent(ev); - } - } - } catch (e) { - if (isStreamReaderReleaseError(e)) { - return; - } - this.logger.error({ error: e }, 'createSttTask: error reading sttStream'); - } finally { - reader.releaseLock(); - try { - await sttStream.cancel(); - } catch (e) { - this.logger.debug( - 'createSttTask: error cancelling sttStream (may already be cancelled):', - e, - ); + if (typeof ev === 'string') { + throw new Error('STT node must yield SpeechEvent'); + } else { + await this.onSTTEvent(ev); } } + } catch (e) { + this.logger.error({ error: e }, 'createSttTask: error reading sttStream'); } } @@ -1020,19 +990,17 @@ export class AudioRecognition { interruptionDetection: AdaptiveInterruptionDetector | undefined, signal: AbortSignal, ) { - if (!interruptionDetection || !this.interruptionStreamChannel) return; + if (!interruptionDetection || !this.interruptionChan) return; let numRetries = 0; const maxRetries = apiConnectDefaults.maxRetries; while (!signal.aborted) { const stream = interruptionDetection.createStream(); - const eventReader = stream.stream().getReader(); const cleanup = async () => { try { signal.removeEventListener('abort', cleanup); - eventReader.releaseLock(); await stream.close(); } catch (e) { this.logger.debug('createInterruptionTask: error during cleanup:', e); @@ -1052,16 +1020,10 @@ export class AudioRecognition { } forwardTask = (async () => { - const inputReader = this.interruptionStreamChannel!.stream().getReader(); - const abortPromise = waitForAbort(signal); - + if (!this.interruptionChan) return; try { - while (!signal.aborted) { - const res = await Promise.race([inputReader.read(), abortPromise]); - if (!res) break; - - const { value, done } = res; - if (done) break; + for await (const value of this.interruptionChan) { + if (signal.aborted) break; if (value instanceof AudioFrame) { const frameDurationMs = (value.samplesPerChannel / value.sampleRate) * 1000; @@ -1072,18 +1034,13 @@ export class AudioRecognition { await stream.pushFrame(value); } - } finally { - inputReader.releaseLock(); + } catch { + // Channel closed or source error } })(); - const abortPromise = waitForAbort(signal); - - while (!signal.aborted) { - const res = await Promise.race([eventReader.read(), abortPromise]); - if (!res) break; - const { done, value: ev } = res; - if (done) break; + for await (const ev of stream.stream()) { + if (signal.aborted) break; this.onOverlapSpeechEvent(ev); } break; @@ -1146,12 +1103,30 @@ export class AudioRecognition { this.logger.debug('Interruption task closed'); } - setInputAudioStream(audioStream: ReadableStream) { - this.deferredInputStream.setSource(audioStream); + setInputAudioStream(audioStream: AsyncIterable) { + this._pumpAbort?.abort(); + const abort = new AbortController(); + this._pumpAbort = abort; + (async () => { + try { + for await (const frame of audioStream) { + if (abort.signal.aborted) break; + try { + this.inputChan.sendNowait(frame); + } catch (e) { + if (e instanceof ChanClosed) break; + throw e; + } + } + } catch { + // Source errors are silently consumed + } + })(); } detachInputAudioStream() { - this.deferredInputStream.detachSource(); + this._pumpAbort?.abort(); + this._pumpAbort = null; } clearUserTurn() { @@ -1196,7 +1171,11 @@ export class AudioRecognition { const numSamples = Math.floor(this.sampleRate * 0.5); const silence = new Int16Array(numSamples * 2); const silenceFrame = new AudioFrame(silence, this.sampleRate, 1, numSamples); - this.silenceAudioWriter.write(silenceFrame); + try { + this.silenceAudioChan.sendNowait(silenceFrame); + } catch (e) { + if (!(e instanceof ChanClosed)) throw e; + } } // wait for the final transcript to be available @@ -1235,13 +1214,14 @@ export class AudioRecognition { async close() { this.closed = true; this.detachInputAudioStream(); - this.silenceAudioWriter.releaseLock(); + this.silenceAudioChan.close(); + this.inputChan.close(); + this.interruptionChan?.close(); await this.commitUserTurnTask?.cancelAndWait(); await this.sttTask?.cancelAndWait(); await this.vadTask?.cancelAndWait(); await this.bounceEOUTask?.cancelAndWait(); await this.interruptionTask?.cancelAndWait(); - await this.interruptionStreamChannel?.close(); } private _endUserTurnSpan({ diff --git a/agents/src/voice/audio_recognition_span.test.ts b/agents/src/voice/audio_recognition_span.test.ts index cfe92a821..ae6543b8b 100644 --- a/agents/src/voice/audio_recognition_span.test.ts +++ b/agents/src/voice/audio_recognition_span.test.ts @@ -9,7 +9,6 @@ import { SimpleSpanProcessor, } from '@opentelemetry/sdk-trace-base'; import { NodeTracerProvider } from '@opentelemetry/sdk-trace-node'; -import { ReadableStream } from 'node:stream/web'; import { describe, expect, it, vi } from 'vitest'; import { ChatContext } from '../llm/chat_context.js'; import { initializeLogger } from '../log.js'; @@ -131,13 +130,9 @@ describe('AudioRecognition user_turn span parity', () => { { type: SpeechEventType.END_OF_SPEECH }, ]; - const sttNode: STTNode = async () => - new ReadableStream({ - start(controller) { - for (const ev of sttEvents) controller.enqueue(ev); - controller.close(); - }, - }); + const sttNode: STTNode = async function* () { + for (const ev of sttEvents) yield ev; + }; const ar = new AudioRecognition({ recognitionHooks: hooks, @@ -240,13 +235,9 @@ describe('AudioRecognition user_turn span parity', () => { }, ]; - const sttNode: STTNode = async () => - new ReadableStream({ - start(controller) { - for (const ev of sttEvents) controller.enqueue(ev); - controller.close(); - }, - }); + const sttNode: STTNode = async function* () { + for (const ev of sttEvents) yield ev; + }; const ar = new AudioRecognition({ recognitionHooks: hooks, diff --git a/agents/src/voice/generation.ts b/agents/src/voice/generation.ts index d2a7e32a9..efd3094fd 100644 --- a/agents/src/voice/generation.ts +++ b/agents/src/voice/generation.ts @@ -5,7 +5,6 @@ import type { AudioFrame } from '@livekit/rtc-node'; import { AudioResampler } from '@livekit/rtc-node'; import type { Span } from '@opentelemetry/api'; import { context as otelContext } from '@opentelemetry/api'; -import type { ReadableStream, ReadableStreamDefaultReader } from 'stream/web'; import { type ChatContext, ChatMessage, @@ -22,7 +21,8 @@ import { } from '../llm/tool_context.js'; import { isZodSchema, parseZodSchema } from '../llm/zod-utils.js'; import { log } from '../log.js'; -import { IdentityTransform } from '../stream/identity_transform.js'; +import { withIdleTimeout } from '../stream/adapters.js'; +import { Chan, ChanClosed } from '../stream/chan.js'; import { traceTypes, tracer } from '../telemetry/index.js'; import { USERDATA_TIMED_TRANSCRIPT } from '../types.js'; import { @@ -64,8 +64,8 @@ export class _LLMGenerationData { ttft?: number; constructor( - public readonly textStream: ReadableStream, - public readonly toolCallStream: ReadableStream, + public readonly textStream: AsyncIterable, + public readonly toolCallStream: AsyncIterable, ) { this.id = shortuuid('item_'); this.generatedToolCalls = []; @@ -78,11 +78,11 @@ export class _LLMGenerationData { */ export interface _TTSGenerationData { /** Audio frame stream from TTS */ - audioStream: ReadableStream; + audioStream: AsyncIterable; /** * Future that resolves to a stream of timed transcripts, or null if TTS doesn't support it. */ - timedTextsFut: Future | null>; + timedTextsFut: Future | null>; /** Time to first byte (set when first audio frame is received) */ ttfb?: number; } @@ -430,12 +430,10 @@ export function performLLMInference( model?: string, provider?: string, ): [Task, _LLMGenerationData] { - const textStream = new IdentityTransform(); - const toolCallStream = new IdentityTransform(); + const textChan = new Chan(); + const toolCallChan = new Chan(); - const textWriter = textStream.writable.getWriter(); - const toolCallWriter = toolCallStream.writable.getWriter(); - const data = new _LLMGenerationData(textStream.readable, toolCallStream.readable); + const data = new _LLMGenerationData(textChan, toolCallChan); const _performLLMInferenceImpl = async (signal: AbortSignal, span: Span) => { span.setAttribute( @@ -451,32 +449,22 @@ export function performLLMInference( span.setAttribute(traceTypes.ATTR_GEN_AI_PROVIDER_NAME, provider); } - let llmStreamReader: ReadableStreamDefaultReader | null = null; - let llmStream: ReadableStream | null = null; + let llmStream: AsyncIterable | null = null; const startTime = performance.now() / 1000; // Convert to seconds let firstTokenReceived = false; try { llmStream = await node(chatCtx, toolCtx, modelSettings); if (llmStream === null) { - await textWriter.close(); + textChan.close(); return; } - const abortPromise = waitForAbort(signal); - // TODO(brian): add support for dynamic tools - llmStreamReader = llmStream.getReader(); - while (true) { + for await (const chunk of llmStream) { if (signal.aborted) break; - const result = await Promise.race([llmStreamReader.read(), abortPromise]); - if (result === undefined) break; - - const { done, value: chunk } = result; - if (done) break; - if (!firstTokenReceived) { firstTokenReceived = true; data.ttft = performance.now() / 1000 - startTime; @@ -484,7 +472,12 @@ export function performLLMInference( if (typeof chunk === 'string') { data.generatedText += chunk; - await textWriter.write(chunk); + try { + textChan.sendNowait(chunk); + } catch (e) { + if (e instanceof ChanClosed) break; + throw e; + } // TODO(shubhra): better way to check?? } else { if (chunk.delta === undefined) { @@ -505,13 +498,23 @@ export function performLLMInference( }); data.generatedToolCalls.push(toolCall); - await toolCallWriter.write(toolCall); + try { + toolCallChan.sendNowait(toolCall); + } catch (e) { + if (e instanceof ChanClosed) break; + throw e; + } } } if (chunk.delta.content) { data.generatedText += chunk.delta.content; - await textWriter.write(chunk.delta.content); + try { + textChan.sendNowait(chunk.delta.content); + } catch (e) { + if (e instanceof ChanClosed) break; + throw e; + } } } @@ -530,10 +533,8 @@ export function performLLMInference( } throw error; } finally { - llmStreamReader?.releaseLock(); - await llmStream?.cancel(); - await textWriter.close(); - await toolCallWriter.close(); + textChan.close(); + toolCallChan.close(); } }; @@ -554,47 +555,42 @@ export function performLLMInference( export function performTTSInference( node: TTSNode, - text: ReadableStream, + text: AsyncIterable, modelSettings: ModelSettings, controller: AbortController, model?: string, provider?: string, ): [Task, _TTSGenerationData] { const logger = log(); - const audioStream = new IdentityTransform(); - const outputWriter = audioStream.writable.getWriter(); - const audioOutputStream = audioStream.readable; + const audioChan = new Chan(); + const timedTextsChan = new Chan(); - const timedTextsFut = new Future | null>(); - const timedTextsStream = new IdentityTransform(); - const timedTextsWriter = timedTextsStream.writable.getWriter(); + const timedTextsFut = new Future | null>(); - // Transform stream to extract text from TimedString objects - const textOnlyStream = new IdentityTransform(); - const textOnlyWriter = textOnlyStream.writable.getWriter(); + // Transform iterable to extract text from TimedString objects + const textOnlyChan = new Chan(); (async () => { - const reader = text.getReader(); try { - while (true) { - const { done, value } = await reader.read(); - if (done) { - break; - } + for await (const value of text) { const textValue = typeof value === 'string' ? value : value.text; - await textOnlyWriter.write(textValue); + try { + textOnlyChan.sendNowait(textValue); + } catch (e) { + if (e instanceof ChanClosed) break; + throw e; + } } - await textOnlyWriter.close(); - } catch (e) { - await textOnlyWriter.abort(e as Error); + } catch { + // Source errors are silently consumed } finally { - reader.releaseLock(); + textOnlyChan.close(); } })(); let ttfb: number | undefined; const genData: _TTSGenerationData = { - audioStream: audioOutputStream, + audioStream: audioChan, timedTextsFut, ttfb: undefined, }; @@ -607,47 +603,36 @@ export function performTTSInference( span.setAttribute(traceTypes.ATTR_GEN_AI_PROVIDER_NAME, provider); } - let ttsStreamReader: ReadableStreamDefaultReader | null = null; - let ttsStream: ReadableStream | null = null; + let ttsStream: AsyncIterable | null = null; let pushedDuration = 0; const startTime = performance.now() / 1000; // Convert to seconds let firstByteReceived = false; try { - ttsStream = await node(textOnlyStream.readable, modelSettings); + ttsStream = await node(textOnlyChan, modelSettings); if (ttsStream === null) { timedTextsFut.resolve(null); - await outputWriter.close(); - await timedTextsWriter.close(); + audioChan.close(); + timedTextsChan.close(); return; } // This is critical: the future must be resolved with the channel/stream before the loop // so that agent_activity can start reading while we write if (!timedTextsFut.done) { - timedTextsFut.resolve(timedTextsStream.readable); + timedTextsFut.resolve(timedTextsChan); } - ttsStreamReader = ttsStream.getReader(); - // In Python, perform_tts_inference has a while loop processing multiple input segments // (separated by FlushSentinel), with pushed_duration accumulating across segments. // JS currently only does single inference, so initialPushedDuration is always 0. // TODO: Add FlushSentinel + multi-segment loop const initialPushedDuration = pushedDuration; - while (true) { + for await (const frame of withIdleTimeout(ttsStream, TTS_READ_IDLE_TIMEOUT_MS)) { if (signal.aborted) { break; } - const { done, value: frame } = await waitUntilTimeout( - ttsStreamReader.read(), - TTS_READ_IDLE_TIMEOUT_MS, - ); - if (done) { - break; - } - if (!firstByteReceived) { firstByteReceived = true; ttfb = performance.now() / 1000 - startTime; @@ -655,8 +640,13 @@ export function performTTSInference( span.setAttribute(traceTypes.ATTR_RESPONSE_TTFB, ttfb); } - // Write the audio frame to the output stream - await outputWriter.write(frame); + // Write the audio frame to the output channel + try { + audioChan.sendNowait(frame); + } catch (e) { + if (e instanceof ChanClosed) break; + throw e; + } const timedTranscripts = frame.userdata[USERDATA_TIMED_TRANSCRIPT] as | TimedString[] @@ -677,7 +667,12 @@ export function performTTSInference( confidence: timedText.confidence, startTimeOffset: timedText.startTimeOffset, }); - await timedTextsWriter.write(adjustedTimedText); + try { + timedTextsChan.sendNowait(adjustedTimedText); + } catch (e) { + if (e instanceof ChanClosed) break; + throw e; + } } } @@ -696,10 +691,8 @@ export function performTTSInference( if (!timedTextsFut.done) { timedTextsFut.resolve(null); } - ttsStreamReader?.releaseLock(); - await ttsStream?.cancel(); - await outputWriter.close(); - await timedTextsWriter.close(); + audioChan.close(); + timedTextsChan.close(); } }; @@ -724,19 +717,16 @@ export interface _TextOut { } async function forwardText( - source: ReadableStream, + source: AsyncIterable, out: _TextOut, signal: AbortSignal, textOutput: TextOutput | null, ): Promise { - const reader = source.getReader(); try { - while (true) { + for await (const delta of source) { if (signal.aborted) { break; } - const { done, value: delta } = await reader.read(); - if (done) break; const deltaIsTimedString = isTimedString(delta); const textDelta = deltaIsTimedString ? delta.text : delta; @@ -754,12 +744,11 @@ async function forwardText( if (textOutput !== null) { textOutput.flush(); } - reader?.releaseLock(); } } export function performTextForwarding( - source: ReadableStream, + source: AsyncIterable, controller: AbortController, textOutput: TextOutput | null, ): [Task, _TextOut] { @@ -783,17 +772,14 @@ export interface _AudioOut { } async function forwardAudio( - ttsStream: ReadableStream, + ttsStream: AsyncIterable, audioOutput: AudioOutput, out: _AudioOut, signal?: AbortSignal, ): Promise { const logger = log(); - const reader = ttsStream.getReader(); let resampler: AudioResampler | null = null; - const FORWARD_AUDIO_IDLE_TIMEOUT_MS = 10_000; - const onPlaybackStarted = (ev: { createdAt: number }) => { if (!out.firstFrameFut.done) { out.firstFrameFut.resolve(ev.createdAt); @@ -804,17 +790,11 @@ async function forwardAudio( audioOutput.on(AudioOutput.EVENT_PLAYBACK_STARTED, onPlaybackStarted); audioOutput.resume(); - while (true) { + for await (const frame of withIdleTimeout(ttsStream, TTS_READ_IDLE_TIMEOUT_MS)) { if (signal?.aborted) { break; } - const { done, value: frame } = await waitUntilTimeout( - reader.read(), - FORWARD_AUDIO_IDLE_TIMEOUT_MS, - ); - if (done) break; - out.audio.push(frame); if ( @@ -853,13 +833,12 @@ async function forwardAudio( out.firstFrameFut.reject(new Error('audio forwarding cancelled before playback started')); } - reader?.releaseLock(); audioOutput.flush(); } } export function performAudioForwarding( - ttsStream: ReadableStream, + ttsStream: AsyncIterable, audioOutput: AudioOutput, controller: AbortController, ): [Task, _AudioOut] { @@ -892,7 +871,7 @@ export function performToolExecutions({ speechHandle: SpeechHandle; toolCtx: ToolContext; toolChoice?: ToolChoice; - toolCallStream: ReadableStream; + toolCallStream: AsyncIterable; onToolExecutionStarted?: (toolCall: FunctionCall) => void; onToolExecutionCompleted?: (toolExecutionOutput: ToolExecutionOutput) => void; controller: AbortController; @@ -910,13 +889,10 @@ export function performToolExecutions({ const executeToolsTask = async (controller: AbortController) => { const signal = controller.signal; - const reader = toolCallStream.getReader(); const tasks: Task[] = []; - while (!signal.aborted) { - const { done, value: toolCall } = await reader.read(); + for await (const toolCall of toolCallStream) { if (signal.aborted) break; - if (done) break; if (toolChoice === 'none') { logger.error( diff --git a/agents/src/voice/generation_tools.test.ts b/agents/src/voice/generation_tools.test.ts index d53e12196..d0e222858 100644 --- a/agents/src/voice/generation_tools.test.ts +++ b/agents/src/voice/generation_tools.test.ts @@ -1,7 +1,6 @@ // SPDX-FileCopyrightText: 2025 LiveKit, Inc. // // SPDX-License-Identifier: Apache-2.0 -import { ReadableStream as NodeReadableStream } from 'stream/web'; import { describe, expect, it } from 'vitest'; import { z } from 'zod'; import { FunctionCall, tool } from '../llm/index.js'; @@ -10,38 +9,25 @@ import type { Task } from '../utils.js'; import { cancelAndWait, delay } from '../utils.js'; import { type _TextOut, performTextForwarding, performToolExecutions } from './generation.js'; -function createStringStream(chunks: string[], delayMs: number = 0): NodeReadableStream { - return new NodeReadableStream({ - async start(controller) { - for (const c of chunks) { - if (delayMs > 0) { - await delay(delayMs); - } - controller.enqueue(c); - } - controller.close(); - }, - }); +async function* createStringStream(chunks: string[], delayMs: number = 0): AsyncIterable { + for (const c of chunks) { + if (delayMs > 0) { + await delay(delayMs); + } + yield c; + } } -function createFunctionCallStream(fc: FunctionCall): NodeReadableStream { - return new NodeReadableStream({ - start(controller) { - controller.enqueue(fc); - controller.close(); - }, - }); +async function* createFunctionCallStream(fc: FunctionCall): AsyncIterable { + yield fc; } -function createFunctionCallStreamFromArray(fcs: FunctionCall[]): NodeReadableStream { - return new NodeReadableStream({ - start(controller) { - for (const fc of fcs) { - controller.enqueue(fc); - } - controller.close(); - }, - }); +async function* createFunctionCallStreamFromArray( + fcs: FunctionCall[], +): AsyncIterable { + for (const fc of fcs) { + yield fc; + } } describe('Generation + Tool Execution', () => { diff --git a/agents/src/voice/generation_tts_timeout.test.ts b/agents/src/voice/generation_tts_timeout.test.ts index 65fb68246..98681713d 100644 --- a/agents/src/voice/generation_tts_timeout.test.ts +++ b/agents/src/voice/generation_tts_timeout.test.ts @@ -2,7 +2,6 @@ // // SPDX-License-Identifier: Apache-2.0 import { AudioFrame } from '@livekit/rtc-node'; -import { ReadableStream } from 'stream/web'; import { describe, expect, it, vi } from 'vitest'; import { initializeLogger } from '../log.js'; import { performAudioForwarding, performTTSInference } from './generation.js'; @@ -36,12 +35,12 @@ describe('TTS stream idle timeout', () => { initializeLogger({ pretty: false, level: 'silent' }); it('forwardAudio completes when TTS stream stalls after producing frames', async () => { - const stalledStream = new ReadableStream({ - start(controller) { - controller.enqueue(createSilentFrame()); - controller.enqueue(createSilentFrame()); - }, - }); + const stalledStream = (async function* () { + yield createSilentFrame(); + yield createSilentFrame(); + // stall: never close + await new Promise(() => {}); + })(); const audioOutput = new MockAudioOutput(); const controller = new AbortController(); @@ -61,14 +60,11 @@ describe('TTS stream idle timeout', () => { }, 10_000); it('forwardAudio completes normally when TTS stream closes properly', async () => { - const normalStream = new ReadableStream({ - start(controller) { - controller.enqueue(createSilentFrame()); - controller.enqueue(createSilentFrame()); - controller.enqueue(createSilentFrame()); - controller.close(); - }, - }); + const normalStream = (async function* () { + yield createSilentFrame(); + yield createSilentFrame(); + yield createSilentFrame(); + })(); const audioOutput = new MockAudioOutput(); const controller = new AbortController(); @@ -82,19 +78,16 @@ describe('TTS stream idle timeout', () => { }); it('performTTSInference completes when TTS node returns stalled stream', async () => { - const stalledTtsStream = new ReadableStream({ - start(controller) { - controller.enqueue(createSilentFrame()); - }, - }); + const stalledTtsStream = (async function* () { + yield createSilentFrame(); + // stall: never close + await new Promise(() => {}); + })(); const ttsNode = async () => stalledTtsStream; - const textInput = new ReadableStream({ - start(controller) { - controller.enqueue('Hello world'); - controller.close(); - }, - }); + const textInput = (async function* () { + yield 'Hello world'; + })(); const controller = new AbortController(); const [task, genData] = performTTSInference(ttsNode, textInput, {}, controller); diff --git a/agents/src/voice/io.ts b/agents/src/voice/io.ts index ff5d8a8b1..8b8055ff3 100644 --- a/agents/src/voice/io.ts +++ b/agents/src/voice/io.ts @@ -3,31 +3,30 @@ // SPDX-License-Identifier: Apache-2.0 import type { AudioFrame } from '@livekit/rtc-node'; import { EventEmitter } from 'node:events'; -import type { ReadableStream } from 'node:stream/web'; import type { ChatContext } from '../llm/chat_context.js'; import type { ChatChunk } from '../llm/llm.js'; import type { ToolContext } from '../llm/tool_context.js'; import { log } from '../log.js'; -import { MultiInputStream } from '../stream/multi_input_stream.js'; +import { Chan, ChanClosed } from '../stream/chan.js'; import type { SpeechEvent } from '../stt/stt.js'; import { Future } from '../utils.js'; import type { ModelSettings } from './agent.js'; export type STTNode = ( - audio: ReadableStream, + audio: AsyncIterable, modelSettings: ModelSettings, -) => Promise | null>; +) => Promise | null>; export type LLMNode = ( chatCtx: ChatContext, toolCtx: ToolContext, modelSettings: ModelSettings, -) => Promise | null>; +) => Promise | null>; export type TTSNode = ( - text: ReadableStream, + text: AsyncIterable, modelSettings: ModelSettings, -) => Promise | null>; +) => Promise | null>; /** * Symbol used to identify TimedString objects. @@ -84,14 +83,43 @@ export interface AudioOutputCapabilities { } export abstract class AudioInput { - protected multiStream: MultiInputStream = new MultiInputStream(); + protected inputChan: Chan = new Chan(); + protected _pumpAbort: AbortController | null = null; - get stream(): ReadableStream { - return this.multiStream.stream; + get stream(): AsyncIterable { + return this.inputChan; + } + + /** + * Add an input source. Values from the source are pumped into the internal channel. + * Returns an id that can be passed to removeInputStream. + */ + addInputStream(source: AsyncIterable): string { + const id = `input-${Date.now()}-${Math.random().toString(36).slice(2, 8)}`; + this._pumpAbort?.abort(); + const abort = new AbortController(); + this._pumpAbort = abort; + (async () => { + try { + for await (const frame of source) { + if (abort.signal.aborted) break; + try { + this.inputChan.sendNowait(frame); + } catch (e) { + if (e instanceof ChanClosed) break; + throw e; + } + } + } catch { + // Source errors are silently consumed + } + })(); + return id; } async close(): Promise { - await this.multiStream.close(); + this._pumpAbort?.abort(); + this.inputChan.close(); } onAttached(): void {} diff --git a/agents/src/voice/recorder_io/recorder_io.ts b/agents/src/voice/recorder_io/recorder_io.ts index 8f5987a55..4f69c0cca 100644 --- a/agents/src/voice/recorder_io/recorder_io.ts +++ b/agents/src/voice/recorder_io/recorder_io.ts @@ -8,11 +8,8 @@ import ffmpeg from 'fluent-ffmpeg'; import fs from 'node:fs'; import path from 'node:path'; import { PassThrough } from 'node:stream'; -import type { ReadableStream } from 'node:stream/web'; -import { TransformStream } from 'node:stream/web'; import { log } from '../../log.js'; -import { isStreamReaderReleaseError } from '../../stream/deferred_stream.js'; -import { type StreamChannel, createStreamChannel } from '../../stream/stream_channel.js'; +import { Chan, ChanClosed } from '../../stream/chan.js'; import { Future, Task, cancelAndWait, delay, isFfmpegTeardownError } from '../../utils.js'; import type { AgentSession } from '../agent_session.js'; import { AudioInput, AudioOutput, type PlaybackFinishedEvent } from '../io.js'; @@ -37,8 +34,8 @@ export class RecorderIO { private inRecord?: RecorderAudioInput; private outRecord?: RecorderAudioOutput; - private inChan: StreamChannel = createStreamChannel(); - private outChan: StreamChannel = createStreamChannel(); + private inChan: Chan = new Chan(); + private outChan: Chan = new Chan(); private session: AgentSession; private sampleRate: number; @@ -101,8 +98,8 @@ export class RecorderIO { try { if (!this.started) return; - await this.inChan.close(); - await this.outChan.close(); + this.inChan.close(); + this.outChan.close(); await this.closeFuture.await; await cancelAndWait([this.forwardTask!, this.encodeTask!]); await this.inRecord?.close(); @@ -125,8 +122,12 @@ export class RecorderIO { private writeCb(buf: AudioFrame[]): void { const inputBuf = this.inRecord!.takeBuf(this.outRecord?._lastSpeechEndTime); - this.inChan.write(inputBuf); - this.outChan.write(buf); + try { + this.inChan.sendNowait(inputBuf); + this.outChan.sendNowait(buf); + } catch (e) { + if (!(e instanceof ChanClosed)) throw e; + } } get recording(): boolean { @@ -171,12 +172,18 @@ export class RecorderIO { // Flush input buffer const inputBuf = this.inRecord!.takeBuf(this.outRecord!._lastSpeechEndTime); - this.inChan - .write(inputBuf) - .catch((err) => this.logger.error({ err }, 'Error writing RecorderIO input buffer')); - this.outChan - .write([]) - .catch((err) => this.logger.error({ err }, 'Error writing RecorderIO output buffer')); + try { + this.inChan.sendNowait(inputBuf); + } catch (err) { + if (!(err instanceof ChanClosed)) + this.logger.error({ err }, 'Error writing RecorderIO input buffer'); + } + try { + this.outChan.sendNowait([]); + } catch (err) { + if (!(err instanceof ChanClosed)) + this.logger.error({ err }, 'Error writing RecorderIO output buffer'); + } } } @@ -314,12 +321,12 @@ export class RecorderIO { private async encode(): Promise { if (!this._outputPath) return; - const inReader = this.inChan.stream().getReader(); - const outReader = this.outChan.stream().getReader(); + const inIter = this.inChan[Symbol.asyncIterator](); + const outIter = this.outChan[Symbol.asyncIterator](); try { while (true) { - const [inResult, outResult] = await Promise.all([inReader.read(), outReader.read()]); + const [inResult, outResult] = await Promise.all([inIter.next(), outIter.next()]); if (inResult.done || outResult.done) { break; @@ -348,11 +355,10 @@ export class RecorderIO { await this.ffmpegPromise; } } catch (err) { - this.logger.error({ err }, 'Error in encode task'); + if (!(err instanceof ChanClosed)) { + this.logger.error({ err }, 'Error in encode task'); + } } finally { - inReader.releaseLock(); - outReader.releaseLock(); - if (!this.closeFuture.done) { this.closeFuture.resolve(); } @@ -374,7 +380,7 @@ class RecorderAudioInput extends AudioInput { this.source = source; // Set up the intercepting stream - this.multiStream.addInputStream(this.createInterceptingStream()); + this.addInputStream(this.createInterceptingStream()); } /** @@ -430,59 +436,27 @@ class RecorderAudioInput extends AudioInput { } /** - * Creates a stream that intercepts frames from the source, + * Creates an async iterable that intercepts frames from the source, * accumulates them when recording, and passes them through unchanged. */ - private createInterceptingStream(): ReadableStream { - const sourceStream = this.source.stream; - const reader = sourceStream.getReader(); - - const transform = new TransformStream({ - transform: (frame, controller) => { - // Accumulate frames when recording is active - if (this.recorderIO.recording) { - if (this._startedWallTime === undefined) { - this._startedWallTime = Date.now(); - } - this.accFrames.push(frame); - } - - controller.enqueue(frame); - }, - }); - - const pump = async () => { - const writer = transform.writable.getWriter(); - let sourceError: unknown; - + private createInterceptingStream(): AsyncIterable { + const self = this; + return (async function* () { try { - while (true) { - const { done, value } = await reader.read(); - if (done) break; - await writer.write(value); - } - } catch (e) { - if (isStreamReaderReleaseError(e)) return; - sourceError = e; - } finally { - if (sourceError) { - writer.abort(sourceError); - return; - } - - writer.releaseLock(); - - try { - await transform.writable.close(); - } catch { - // ignore "WritableStream is closed" errors + for await (const frame of self.source.stream) { + // Accumulate frames when recording is active + if (self.recorderIO.recording) { + if (self._startedWallTime === undefined) { + self._startedWallTime = Date.now(); + } + self.accFrames.push(frame); + } + yield frame; } + } catch { + // Source errors silently consumed } - }; - - pump(); - - return transform.readable; + })(); } onAttached(): void { diff --git a/agents/src/voice/room_io/_input.ts b/agents/src/voice/room_io/_input.ts index 6ede89e2f..9ba19eb47 100644 --- a/agents/src/voice/room_io/_input.ts +++ b/agents/src/voice/room_io/_input.ts @@ -13,8 +13,8 @@ import { RoomEvent, TrackSource, } from '@livekit/rtc-node'; -import type { ReadableStream } from 'node:stream/web'; import { log } from '../../log.js'; +import { fromReadableStream } from '../../stream/adapters.js'; import { resampleStream } from '../../utils.js'; import { AudioInput } from '../io.js'; @@ -125,7 +125,7 @@ export class ParticipantAudioInputStream extends AudioInput { private closeStream() { if (this.currentInputId) { - void this.multiStream.removeInputStream(this.currentInputId); + this._pumpAbort?.abort(); this.currentInputId = null; } @@ -147,7 +147,7 @@ export class ParticipantAudioInputStream extends AudioInput { } this.closeStream(); this.publication = publication; - this.currentInputId = this.multiStream.addInputStream( + this.currentInputId = this.addInputStream( resampleStream({ stream: this.createStream(track), outputRate: this.sampleRate, @@ -174,13 +174,15 @@ export class ParticipantAudioInputStream extends AudioInput { } }; - private createStream(track: RemoteTrack): ReadableStream { - return new AudioStream(track, { - sampleRate: this.sampleRate, - numChannels: this.numChannels, - noiseCancellation: this.frameProcessor || this.noiseCancellation, - // TODO(AJS-269): resolve compatibility issue with node-sdk to remove the forced type casting - }) as unknown as ReadableStream; + private createStream(track: RemoteTrack): AsyncIterable { + return fromReadableStream( + new AudioStream(track, { + sampleRate: this.sampleRate, + numChannels: this.numChannels, + noiseCancellation: this.frameProcessor || this.noiseCancellation, + // TODO(AJS-269): resolve compatibility issue with node-sdk to remove the forced type casting + }) as unknown as import('node:stream/web').ReadableStream, + ); } override async close() { diff --git a/agents/src/voice/transcription/synchronizer.ts b/agents/src/voice/transcription/synchronizer.ts index 75b3de4ba..06cf75f71 100644 --- a/agents/src/voice/transcription/synchronizer.ts +++ b/agents/src/voice/transcription/synchronizer.ts @@ -2,9 +2,8 @@ // // SPDX-License-Identifier: Apache-2.0 import type { AudioFrame } from '@livekit/rtc-node'; -import type { ReadableStream, WritableStreamDefaultWriter } from 'node:stream/web'; import { log } from '../../log.js'; -import { IdentityTransform } from '../../stream/identity_transform.js'; +import { Chan, ChanClosed } from '../../stream/chan.js'; import type { SentenceStream, SentenceTokenizer } from '../../tokenize/index.js'; import { basic } from '../../tokenize/index.js'; import { Future, Task, delay } from '../../utils.js'; @@ -143,8 +142,7 @@ class SegmentSynchronizerImpl { private textData: TextData; private audioData: AudioData; private speed: number; - private outputStream: IdentityTransform; - private outputStreamWriter: WritableStreamDefaultWriter; + private outputChan: Chan; private captureTask: Promise; private startWallTime?: number; @@ -171,12 +169,11 @@ class SegmentSynchronizerImpl { done: false, annotatedRate: null, }; - this.outputStream = new IdentityTransform(); - this.outputStreamWriter = this.outputStream.writable.getWriter(); + this.outputChan = new Chan(); this.mainTask() .then(() => { - this.outputStreamWriter.close(); + this.outputChan.close(); }) .catch((error) => { this.logger.error({ error }, 'mainTask SegmentSynchronizerImpl'); @@ -200,8 +197,8 @@ class SegmentSynchronizerImpl { return this.textData.pushedText.length > this.textData.forwardedText.length; } - get readable(): ReadableStream { - return this.outputStream.readable; + get readable(): AsyncIterable { + return this.outputChan; } pushAudio(frame: AudioFrame) { @@ -305,15 +302,13 @@ class SegmentSynchronizerImpl { // Don't use a for-await loop here, because exiting the loop will close the writer in the // outputStream, which will cause an error in the mainTask.then method. // NOTE: forwardedText is updated in mainTask, NOT here - const reader = this.outputStream.readable.getReader(); - while (true) { - const { done, value: text } = await reader.read(); - if (done) { - break; + try { + for await (const text of this.outputChan) { + await this.nextInChain.captureText(text); } - await this.nextInChain.captureText(text); + } catch (e) { + if (!(e instanceof ChanClosed)) throw e; } - reader.releaseLock(); this.nextInChain.flush(); } @@ -342,7 +337,7 @@ class SegmentSynchronizerImpl { } if (this.playbackCompleted) { - this.outputStreamWriter.write(sentence.slice(textCursor, endPos)); + this.outputChan.sendNowait(sentence.slice(textCursor, endPos)); textCursor = endPos; continue; } @@ -379,7 +374,7 @@ class SegmentSynchronizerImpl { await this.sleepIfNotClosed(delayTime / 2); const forwardedWord = sentence.slice(textCursor, endPos); - this.outputStreamWriter.write(forwardedWord); + this.outputChan.sendNowait(forwardedWord); await this.sleepIfNotClosed(delayTime / 2); @@ -390,7 +385,7 @@ class SegmentSynchronizerImpl { if (textCursor < sentence.length) { const remaining = sentence.slice(textCursor); - this.outputStreamWriter.write(remaining); + this.outputChan.sendNowait(remaining); } } } diff --git a/plugins/silero/src/vad.ts b/plugins/silero/src/vad.ts index 41a7aa31e..5d768972a 100644 --- a/plugins/silero/src/vad.ts +++ b/plugins/silero/src/vad.ts @@ -157,15 +157,13 @@ export class VADStream extends baseStream { // used to avoid drift when the sampleRate ratio is not an integer let inputCopyRemainingFrac = 0.0; - while (!this.closed) { - const { done, value: frame } = await this.inputReader.read(); - if (done) { - break; - } + for await (const inputItem of this.inputChan) { + if (this.closed) break; - if (typeof frame === 'symbol') { + if (typeof inputItem === 'symbol') { continue; // ignore flush sentinel for now } + const frame = inputItem; if (!this.#inputSampleRate || !this.#speechBuffer) { this.#inputSampleRate = frame.sampleRate;