Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 18 additions & 9 deletions agents/src/audio.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
import ffmpegInstaller from '@ffmpeg-installer/ffmpeg';
import { AudioFrame } from '@livekit/rtc-node';
import ffmpeg from 'fluent-ffmpeg';
import type { ReadableStream } from 'node:stream/web';
import { log } from './log.js';
import { createStreamChannel } from './stream/stream_channel.js';
import { Chan, ChanClosed } from './stream/chan.js';
import { type AudioBuffer, isFfmpegTeardownError } from './utils.js';

ffmpeg.setFfmpegPath(ffmpegInstaller.path);
Expand Down Expand Up @@ -110,12 +109,12 @@ export class AudioByteStream {
export function audioFramesFromFile(
filePath: string,
options: AudioDecodeOptions = {},
): ReadableStream<AudioFrame> {
): AsyncIterable<AudioFrame> {
const sampleRate = options.sampleRate ?? 48000;
const numChannels = options.numChannels ?? 1;

const audioStream = new AudioByteStream(sampleRate, numChannels);
const channel = createStreamChannel<AudioFrame>();
const chan = new Chan<AudioFrame>();
const logger = log();

// TODO (Brian): decode WAV using a custom decoder instead of FFmpeg
Expand All @@ -139,7 +138,7 @@ export function audioFramesFromFile(
const onClose = () => {
logger.debug('Audio file playback aborted');

channel.close();
chan.close();
if (commandRunning) {
commandRunning = false;
command.kill('SIGKILL');
Expand Down Expand Up @@ -168,17 +167,27 @@ export function audioFramesFromFile(

const frames = audioStream.write(arrayBuffer);
for (const frame of frames) {
channel.write(frame);
try {
chan.sendNowait(frame);
} catch (e) {
if (e instanceof ChanClosed) return;
throw e;
}
}
});

outputStream.on('end', () => {
const frames = audioStream.flush();
for (const frame of frames) {
channel.write(frame);
try {
chan.sendNowait(frame);
} catch (e) {
if (e instanceof ChanClosed) return;
throw e;
}
}
commandRunning = false;
channel.close();
chan.close();
});

outputStream.on('error', (err: Error) => {
Expand All @@ -187,7 +196,7 @@ export function audioFramesFromFile(
onClose();
});

return channel.stream();
return chan;
}

/**
Expand Down
143 changes: 69 additions & 74 deletions agents/src/inference/interruption/http_transport.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
// SPDX-License-Identifier: Apache-2.0
import type { Throws } from '@livekit/throws-transformer/throws';
import { FetchError, ofetch } from 'ofetch';
import { TransformStream } from 'stream/web';
import { z } from 'zod';
import { APIConnectionError, APIError, APIStatusError, isAPIError } from '../../_exceptions.js';
import { log } from '../../log.js';
Expand Down Expand Up @@ -113,8 +112,12 @@ export interface HttpTransportState {
cache: BoundedCache<number, InterruptionCacheEntry>;
}

export type TransportFn = (
source: AsyncIterable<Int16Array | OverlappingSpeechEvent>,
) => AsyncIterable<OverlappingSpeechEvent>;

/**
* Creates an HTTP transport TransformStream for interruption detection.
* Creates an HTTP transport async generator for interruption detection.
*
* This transport receives Int16Array audio slices and outputs InterruptionEvents.
* Each audio slice triggers an HTTP POST request.
Expand All @@ -128,80 +131,72 @@ export function createHttpTransport(
setState: (partial: Partial<HttpTransportState>) => void,
updateUserSpeakingSpan?: (entry: InterruptionCacheEntry) => void,
getAndResetNumRequests?: () => number,
): TransformStream<Int16Array | OverlappingSpeechEvent, OverlappingSpeechEvent> {
): TransportFn {
const logger = log();

return new TransformStream<Int16Array | OverlappingSpeechEvent, OverlappingSpeechEvent>(
{
async transform(chunk, controller) {
if (!(chunk instanceof Int16Array)) {
controller.enqueue(chunk);
return;
}
return async function* (source) {
for await (const chunk of source) {
if (!(chunk instanceof Int16Array)) {
yield chunk;
continue;
}

const state = getState();
const overlapSpeechStartedAt = state.overlapSpeechStartedAt;
if (overlapSpeechStartedAt === undefined || !state.overlapSpeechStarted) return;

try {
const resp = await predictHTTP(
chunk,
{ threshold: options.threshold, minFrames: options.minFrames },
{
baseUrl: options.baseUrl,
timeout: options.timeout,
maxRetries: options.maxRetries,
token: await createAccessToken(options.apiKey, options.apiSecret),
},
);

const { createdAt, isBargein, probabilities, predictionDurationInS } = resp;
const entry = state.cache.setOrUpdate(
createdAt,
() => new InterruptionCacheEntry({ createdAt }),
{
probabilities,
isInterruption: isBargein,
speechInput: chunk,
totalDurationInS: (performance.now() - createdAt) / 1000,
detectionDelayInS: (Date.now() - overlapSpeechStartedAt) / 1000,
predictionDurationInS,
},
);

if (state.overlapSpeechStarted && entry.isInterruption) {
if (updateUserSpeakingSpan) {
updateUserSpeakingSpan(entry);
}
const event: OverlappingSpeechEvent = {
type: 'overlapping_speech',
detectedAt: Date.now(),
overlapStartedAt: overlapSpeechStartedAt,
isInterruption: entry.isInterruption,
speechInput: entry.speechInput,
probabilities: entry.probabilities,
totalDurationInS: entry.totalDurationInS,
predictionDurationInS: entry.predictionDurationInS,
detectionDelayInS: entry.detectionDelayInS,
probability: entry.probability,
numRequests: getAndResetNumRequests?.() ?? 0,
};
logger.debug(
{
detectionDelayInS: entry.detectionDelayInS,
totalDurationInS: entry.totalDurationInS,
},
'interruption detected',
);
setState({ overlapSpeechStarted: false });
controller.enqueue(event);
}
} catch (err) {
controller.error(err);
const state = getState();
const overlapSpeechStartedAt = state.overlapSpeechStartedAt;
if (overlapSpeechStartedAt === undefined || !state.overlapSpeechStarted) continue;

const resp = await predictHTTP(
chunk,
{ threshold: options.threshold, minFrames: options.minFrames },
{
baseUrl: options.baseUrl,
timeout: options.timeout,
maxRetries: options.maxRetries,
token: await createAccessToken(options.apiKey, options.apiSecret),
},
);

const { createdAt, isBargein, probabilities, predictionDurationInS } = resp;
const entry = state.cache.setOrUpdate(
createdAt,
() => new InterruptionCacheEntry({ createdAt }),
{
probabilities,
isInterruption: isBargein,
speechInput: chunk,
totalDurationInS: (performance.now() - createdAt) / 1000,
detectionDelayInS: (Date.now() - overlapSpeechStartedAt) / 1000,
predictionDurationInS,
},
);

if (state.overlapSpeechStarted && entry.isInterruption) {
if (updateUserSpeakingSpan) {
updateUserSpeakingSpan(entry);
}
},
},
{ highWaterMark: 2 },
{ highWaterMark: 2 },
);
const event: OverlappingSpeechEvent = {
type: 'overlapping_speech',
detectedAt: Date.now(),
overlapStartedAt: overlapSpeechStartedAt,
isInterruption: entry.isInterruption,
speechInput: entry.speechInput,
probabilities: entry.probabilities,
totalDurationInS: entry.totalDurationInS,
predictionDurationInS: entry.predictionDurationInS,
detectionDelayInS: entry.detectionDelayInS,
probability: entry.probability,
numRequests: getAndResetNumRequests?.() ?? 0,
};
logger.debug(
{
detectionDelayInS: entry.detectionDelayInS,
totalDurationInS: entry.totalDurationInS,
},
'interruption detected',
);
setState({ overlapSpeechStarted: false });
yield event;
}
}
};
}
Loading