Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Literal, Optional
from pydantic import BaseModel, Field
from ten_ai_base.utils import encrypt

Expand All @@ -16,12 +16,29 @@ class AssemblyAIASRConfig(BaseModel):
sample_rate: int = 16000
encoding: str = "pcm_s16le"

# Model selection
# u3-rt-pro enables language_detection by default and supports prompt/vad_threshold
speech_model: Literal[
"universal-streaming-english",
"universal-streaming-multilingual",
"u3-rt-pro",
] = "u3-rt-pro"

# Real-time transcription settings
end_of_turn_confidence_threshold: Optional[float] = 0.4
end_of_turn_confidence_threshold: Optional[float] = None
format_turns: bool = True
keyterms_prompt: Optional[List[str]] = Field(default_factory=list)
min_end_of_turn_silence_when_confident: Optional[int] = 160
max_turn_silence: Optional[int] = 400
# Deprecated: use min_turn_silence instead
min_end_of_turn_silence_when_confident: Optional[int] = None
min_turn_silence: Optional[int] = None
max_turn_silence: Optional[int] = None

# u3-rt-pro specific settings
language_detection: Optional[bool] = None
prompt: Optional[str] = None
vad_threshold: Optional[float] = None
speaker_labels: Optional[bool] = None
max_speakers: Optional[int] = None

# Language settings
language: str = "en-US"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,11 +134,19 @@ async def start_connection(self) -> None:
assemblyai_config = {
"sample_rate": self.config.sample_rate,
"encoding": self.config.encoding,
"speech_model": self.config.speech_model,
"end_of_turn_confidence_threshold": self.config.end_of_turn_confidence_threshold,
"format_turns": self.config.format_turns,
"keyterms_prompt": self.config.keyterms_prompt,
# min_turn_silence supersedes the deprecated min_end_of_turn_silence_when_confident
"min_turn_silence": self.config.min_turn_silence,
"min_end_of_turn_silence_when_confident": self.config.min_end_of_turn_silence_when_confident,
"max_turn_silence": self.config.max_turn_silence,
"language_detection": self.config.language_detection,
"prompt": self.config.prompt,
"vad_threshold": self.config.vad_threshold,
"speaker_labels": self.config.speaker_labels,
"max_speakers": self.config.max_speakers,
}
self.ten_env.log_info(f"AssemblyAI ASR config: {assemblyai_config}")
self.recognition = AssemblyAIWSRecognition(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,32 @@
"type": "string"
}
},
"speech_model": {
"type": "string"
},
"language_detection": {
"type": "bool"
},
"min_turn_silence": {
"type": "int64"
},
"min_end_of_turn_silence_when_confident": {
"type": "int64"
},
"max_turn_silence": {
"type": "int64"
},
"prompt": {
"type": "string"
},
"vad_threshold": {
"type": "float64"
},
"speaker_labels": {
"type": "bool"
},
"max_speakers": {
"type": "int64"
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,9 @@
"ws_url": "wss://streaming.assemblyai.com/v3/ws",
"sample_rate": 16000,
"encoding": "pcm_s16le",
"speech_model": "u3-rt-pro",
"language": "en-US",
"end_of_turn_confidence_threshold": 0.4,
"format_turns": true,
"keyterms_prompt": [],
"min_end_of_turn_silence_when_confident": 400,
"max_turn_silence": 1280
"keyterms_prompt": []
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,17 @@ def _build_websocket_url(self) -> str:

sample_rate = self.config.get("sample_rate", 16000)
params.append(f"sample_rate={sample_rate}")

encoding = self.config.get("encoding", "pcm_s16le")
if encoding:
params.append(f"encoding={encoding}")

speech_model = self.config.get("speech_model")
if speech_model:
params.append(f"speech_model={speech_model}")

is_u3_rt_pro = speech_model == "u3-rt-pro"

end_of_turn_confidence_threshold = self.config.get(
"end_of_turn_confidence_threshold"
)
Expand All @@ -198,21 +205,51 @@ def _build_websocket_url(self) -> str:

keyterms_prompt = self.config.get("keyterms_prompt", [])
if keyterms_prompt:
keyterms_str = ",".join(keyterms_prompt)
keyterms_str = json.dumps(keyterms_prompt)
params.append(f"keyterms_prompt={keyterms_str}")

min_end_of_turn_silence_when_confident = self.config.get(
"min_end_of_turn_silence_when_confident"
)
if min_end_of_turn_silence_when_confident is not None:
params.append(
f"min_end_of_turn_silence_when_confident={min_end_of_turn_silence_when_confident}"
)
# min_turn_silence replaces the deprecated min_end_of_turn_silence_when_confident.
# If neither is set, omit the param and let the API use its default.
min_turn_silence = self.config.get("min_turn_silence")
if min_turn_silence is None:
min_turn_silence = self.config.get("min_end_of_turn_silence_when_confident")
if min_turn_silence is not None:
params.append(f"min_turn_silence={min_turn_silence}")

max_turn_silence = self.config.get("max_turn_silence")
if max_turn_silence is not None:
params.append(f"max_turn_silence={max_turn_silence}")

# language_detection: when unset, defaults to True for multilingual models or u3-rt-pro
language_detection = self.config.get("language_detection")
if language_detection is None:
if is_u3_rt_pro or (speech_model and "multilingual" in speech_model):
language_detection = True
if language_detection is not None:
params.append(f"language_detection={str(language_detection).lower()}")

# prompt is only supported with u3-rt-pro
prompt = self.config.get("prompt")
if prompt is not None:
if is_u3_rt_pro:
params.append(f"prompt={prompt}")
else:
self.ten_env.log_warn(
"[AssemblyAI] 'prompt' is only supported with u3-rt-pro; ignoring"
)

vad_threshold = self.config.get("vad_threshold")
if vad_threshold is not None:
params.append(f"vad_threshold={vad_threshold}")

speaker_labels = self.config.get("speaker_labels")
if speaker_labels is not None:
params.append(f"speaker_labels={str(speaker_labels).lower()}")

max_speakers = self.config.get("max_speakers")
if max_speakers is not None:
params.append(f"max_speakers={max_speakers}")

self.ten_env.log_info(
f"[AssemblyAI] Building websocket url with params: {params}"
)
Expand All @@ -234,14 +271,19 @@ async def start(self, timeout: int = 10) -> bool:
return True

ws_url = self._build_websocket_url()
headers = {"Authorization": self.api_key}
headers = {
"Authorization": self.api_key,
"User-Agent": "AssemblyAI/1.0 (integration=TEN-Framework)",
}

self.ten_env.log_info(
f"[AssemblyAI] Connecting to AssemblyAI: {ws_url}"
)

self.websocket = await websockets.connect(
ws_url, additional_headers=headers, open_timeout=timeout
ws_url,
additional_headers=headers,
open_timeout=timeout,
)
self._message_task = asyncio.create_task(self._message_handler())
self._consumer_task = asyncio.create_task(self._consume_and_send())
Expand Down Expand Up @@ -315,7 +357,7 @@ async def send_update_configuration(self, config_update: Dict[str, Any]):
return

try:
message = {"type": "updateConfiguration", **config_update}
message = {"type": "UpdateConfiguration", **config_update}
await self.websocket.send(json.dumps(message))
self.ten_env.log_info(
f"[AssemblyAI] Sent configuration update: {config_update}"
Expand All @@ -340,7 +382,7 @@ async def force_endpoint(self):
return

try:
message = {"type": "forceEndpoint"}
message = {"type": "ForceEndpoint"}
await self.websocket.send(json.dumps(message))
self.ten_env.log_info("[AssemblyAI] Sent force endpoint signal")

Expand Down
Loading