diff --git a/config/greeting_local_simplified.json5 b/config/greeting_local_simplified.json5
new file mode 100644
index 000000000..2258597b3
--- /dev/null
+++ b/config/greeting_local_simplified.json5
@@ -0,0 +1,126 @@
+{
+  // Configuration version
+  version: "v1.0.3",
+
+  // Mode system configuration for Unitree Go2
+  default_mode: "greeting",
+  allow_manual_switching: true,
+  mode_memory_enabled: true,
+
+  // Global settings
+  api_key: "${OM_API_KEY:-openmind_free}",
+  unitree_ethernet: "${UNITREE_ETHERNET:-enP2p1s0}",
+  system_governance: "",
+  cortex_llm: {
+    type: "QwenLLMSimplified",
+    config: {
+      agent_name: "Bits",
+      history_length: 2,
+      base_url: "${QWEN_BASE_URL:-http://omr2.local:8860}/v1",
+      model: "nvidia/nemotron-3-nano",
+    },
+  },
+
+  knowledge_base: {
+    knowledge_base_name: "demo",
+    base_url: "${KB_BASE_URL:-http://localhost:8100}",
+  },
+
+  modes: {
+    approaching: {
+      display_name: "Approaching Person Mode",
+      description: "Robot approaches detected humans autonomously.",
+      system_prompt_base: "You are Bits, a friendly and helpful robotic companion built on a Unitree Go2 platform.",
+      hertz: 0.001,
+      agent_inputs: [],
+      action_execution_mode: "concurrent",
+      agent_actions: [],
+      backgrounds: [
+        {
+          type: "ApproachingPerson",
+        },
+      ],
+      lifecycle_hooks: [],
+    },
+    greeting: {
+      display_name: "Greeting Conversation Mode",
+      description: "Robot engages in greeting conversations with users upon approach.",
+      system_prompt_base: "You are Bits, a friendly robot dog at NVIDIA GTC made by OpenMind. Today is March 9, 2026. Answer in 1-2 short spoken sentences. If relevant info is provided, use it and rephrase it in your own words.",
+      hertz: 0.001,
+      agent_inputs: [
+        {
+          type: "RivaASRRTSPInput",
+          config: {
+            base_url: "ws://localhost:6790",
+            enable_tts_interrupt: false,
+          },
+        },
+      ],
+      action_execution_mode: "concurrent",
+      agent_actions: [
+        {
+          name: "greeting_conversation_simplified",
+          llm_label: "greeting_conversation",
+          connector: "greeting_conversation_kokoro_simplified",
+          config: {
+            model_id: "mlx-community/Kokoro-82M-bf16",
+            base_url: "${KOKORO_BASE_URL:-http://omr2.local:8880}/v1",
+          },
+        },
+      ],
+      backgrounds: [],
+      lifecycle_hooks: [
+        {
+          hook_type: "on_startup",
+          handler_type: "message",
+          handler_config: {
+            tts_provider: "kokoro",
+            message: "Hello! I'm Bits, how can I help you today?",
+            model_id: "mlx-community/Kokoro-82M-bf16",
+            base_url: "${KOKORO_BASE_URL:-http://omr2.local:8880}/v1",
+          },
+        },
+        {
+          hook_type: "on_entry",
+          handler_type: "message",
+          handler_config: {
+            tts_provider: "kokoro",
+            message: "Hey there! What can I help you with?",
+            model_id: "mlx-community/Kokoro-82M-bf16",
+            base_url: "${KOKORO_BASE_URL:-http://omr2.local:8880}/v1",
+          },
+        },
+        {
+          hook_type: "on_exit",
+          handler_type: "function",
+          handler_config: {
+            module_name: "greeting_hook",
+            function: "greeting_end_hook",
+            tts_provider: "kokoro",
+            model_id: "mlx-community/Kokoro-82M-bf16",
+            base_url: "${KOKORO_BASE_URL:-http://omr2.local:8880}/v1",
+          },
+        },
+      ],
+    },
+  },
+
+  transition_rules: [
+    {
+      from_mode: "approaching",
+      to_mode: "greeting",
+      transition_type: "context_aware",
+      context_conditions: { approaching_detected: true },
+      priority: 0,
+      cooldown_seconds: 5.0,
+    },
+    {
+      from_mode: "greeting",
+      to_mode: "approaching",
+      transition_type: "context_aware",
+      context_conditions: { greeting_conversation_finished: true },
+      priority: 0,
+      cooldown_seconds: 5.0,
+    },
+  ],
+}
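The config leans on shell-style `${VAR:-default}` placeholders for every endpoint. As a quick reference, here is a minimal sketch of how that fallback syntax resolves, assuming the loader implements the usual semantics; the `expand_env_placeholders` helper below is hypothetical and not part of this PR:

```python
import os
import re

# Matches ${VAR} and ${VAR:-default}
_ENV_RE = re.compile(r"\$\{(\w+)(?::-([^}]*))?\}")


def expand_env_placeholders(raw: str) -> str:
    """Substitute each ${VAR:-default} with the env value or the default."""

    def _sub(m: re.Match) -> str:
        return os.environ.get(m.group(1), m.group(2) or "")

    return _ENV_RE.sub(_sub, raw)


# With UNITREE_ETHERNET unset in the environment, the fallback applies:
assert expand_env_placeholders("${UNITREE_ETHERNET:-enP2p1s0}") == "enP2p1s0"
```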
diff --git a/src/actions/greeting_conversation_simplified/connector/greeting_conversation_kokoro_simplified.py b/src/actions/greeting_conversation_simplified/connector/greeting_conversation_kokoro_simplified.py
new file mode 100644
index 000000000..7c56c79e2
--- /dev/null
+++ b/src/actions/greeting_conversation_simplified/connector/greeting_conversation_kokoro_simplified.py
@@ -0,0 +1,262 @@
+import asyncio
+import json
+import logging
+import time
+from uuid import uuid4
+
+from pydantic import Field
+
+from actions.base import ActionConfig, ActionConnector
+from actions.greeting_conversation_simplified.interface import (
+    GreetingConversationSimplifiedInput,
+)
+from providers.context_provider import ContextProvider
+from providers.greeting_conversation_state_provider import (
+    ConversationState,
+    GreetingConversationStateMachineProvider,
+)
+from providers.kokoro_tts_provider import KokoroTTSProvider
+from providers.tts_text_utils import normalize_tts_text
+from zenoh_msgs import (
+    PersonGreetingStatus,
+    String,
+    open_zenoh_session,
+    prepare_header,
+)
+
+
+class SpeakKokoroTTSConfig(ActionConfig):
+    """
+    Configuration for the Kokoro TTS connector.
+
+    Parameters
+    ----------
+    base_url : str
+        Base URL for the Kokoro TTS API.
+    voice_id : str
+        Kokoro voice ID.
+    model_id : str
+        Kokoro model ID.
+    output_format : str
+        Kokoro output format.
+    rate : int
+        Audio sample rate in Hz.
+    enable_tts_interrupt : bool
+        Enable TTS interrupt when ASR detects speech during playback.
+    silence_rate : int
+        Number of responses to skip before speaking.
+    """
+
+    base_url: str = Field(
+        default="http://127.0.0.1:8880/v1",
+        description="Base URL for the Kokoro TTS API",
+    )
+    voice_id: str = Field(
+        default="af_bella",
+        description="Kokoro voice ID",
+    )
+    model_id: str = Field(
+        default="kokoro",
+        description="Kokoro model ID",
+    )
+    output_format: str = Field(
+        default="pcm",
+        description="Kokoro output format",
+    )
+    rate: int = Field(
+        default=24000,
+        description="Audio sample rate in Hz",
+    )
+    enable_tts_interrupt: bool = Field(
+        default=False,
+        description="Enable TTS interrupt when ASR detects speech during playback",
+    )
+    silence_rate: int = Field(
+        default=0,
+        description="Number of responses to skip before speaking",
+    )
+
+
+class GreetingConversationConnector(
+    ActionConnector[SpeakKokoroTTSConfig, GreetingConversationSimplifiedInput]
+):
+    """
+    Simplified greeting conversation connector with Kokoro TTS.
+
+    Uses a single 'response' field from the LLM and hardcodes
+    conversation state values for the state machine.
+    Applies TTS text normalization (e.g. month abbreviation expansion).
+    """
+
+    def __init__(self, config: SpeakKokoroTTSConfig):
+        super().__init__(config)
+
+        self.greeting_state_provider = GreetingConversationStateMachineProvider()
+        self.greeting_state_provider.start_conversation()
+
+        self.context_provider = ContextProvider()
+
+        # Create the Kokoro TTS provider
+        api_key = getattr(self.config, "api_key", None)
+        logging.info("Creating Kokoro TTS provider")
+        self.tts = KokoroTTSProvider(
+            url=self.config.base_url,
+            api_key=api_key,
+            voice_id=self.config.voice_id,
+            model_id=self.config.model_id,
+            output_format=self.config.output_format,
+            rate=self.config.rate,
+            enable_tts_interrupt=self.config.enable_tts_interrupt,
+        )
+        self.tts.start()
+
+        self.tts_triggered_time = time.time()
+        self.tts_duration = 0.0
+        self.conversation_finished_sent = False
+        self.pending_finished_update = False
+        self.delayed_update_task = None
+
+        self.person_greeting_topic = "om/person_greeting"
+        try:
+            self.session = open_zenoh_session()
+            logging.info("Zenoh session opened for PersonGreetingStatus publishing")
+        except Exception as e:
+            logging.error(f"Error opening Zenoh session: {e}")
+            self.session = None
+
+        self.greeting_status = ConversationState.CONVERSING.value
+
+    async def connect(
+        self, output_interface: GreetingConversationSimplifiedInput
+    ) -> None:
+        """
+        Process the greeting conversation response.
+
+        Only reads 'response' from the LLM output and hardcodes
+        conversation state values for the state machine.
+        """
+        logging.info(f"Greeting Response: {output_interface.response}")
+
+        llm_output = {
+            "conversation_state": ConversationState.CONVERSING.value,
+            "response": output_interface.response,
+            "confidence": 0.85,
+            "speech_clarity": 0.85,
+        }
+
+        tts_text = normalize_tts_text(output_interface.response)
+        self.tts.add_pending_message(tts_text)
+
+        # Estimate TTS duration from text length (~100 words per minute),
+        # converted to seconds, plus a 5 s buffer
+        word_count = len(output_interface.response.split())
+        self.tts_duration = (word_count / 100.0) * 60.0 + 5
+        self.tts_triggered_time = time.time()
+
+        state_update = self.greeting_state_provider.process_conversation(llm_output)
+        current_state = state_update.get("current_state", self.greeting_status)
+        self.greeting_status = current_state
+        self.publish_countdown_status(self.greeting_status)
+
+        logging.info(f"Greeting Conversation Response: {state_update}")
+
+        if (
+            self.greeting_status == ConversationState.FINISHED.value
+            and not self.conversation_finished_sent
+        ):
+            logging.info(
+                f"Greeting conversation state is FINISHED. "
+                f"Scheduling context update after TTS completes ({self.tts_duration:.1f}s)."
+            )
+            self.pending_finished_update = True
+            self.conversation_finished_sent = True
+            # The wait uses a faster 150 wpm estimate with no buffer, so the
+            # context update fires at or before the end of estimated playback
+            self.delayed_update_task = asyncio.create_task(
+                self._delayed_context_update((word_count / 150.0) * 60.0)
+            )
+
+    async def _delayed_context_update(self, wait_duration: float) -> None:
+        """Wait for TTS to finish, then mark the conversation finished in context."""
+        try:
+            logging.info(
+                f"Waiting {wait_duration:.1f}s for TTS to complete before updating context..."
+            )
+            await asyncio.sleep(wait_duration)
+
+            if self.pending_finished_update:
+                logging.info(
+                    "TTS completed. Updating context: greeting_conversation_finished = True"
+                )
+                self.context_provider.update_context(
+                    {"greeting_conversation_finished": True}
+                )
+                self.pending_finished_update = False
+        except Exception as e:
+            logging.error(f"Error in delayed context update: {e}")
+
+    def tick(self) -> None:
+        """Periodically update conversation state even without LLM input."""
+        logging.info("GreetingConversationConnector tick called")
+        self.sleep(10)
+
+        if time.time() - self.tts_triggered_time < self.tts_duration:
+            logging.info(
+                f"Skipping tick update due to recent TTS activity "
+                f"(remaining: {self.tts_duration - (time.time() - self.tts_triggered_time):.1f}s)."
+            )
+            return
+
+        state_update = self.greeting_state_provider.update_state_without_llm()
+        current_state = state_update.get("current_state", self.greeting_status)
+        self.greeting_status = current_state
+        self.publish_countdown_status(self.greeting_status)
+
+        if (
+            current_state == ConversationState.FINISHED.value
+            and not self.conversation_finished_sent
+        ):
+            logging.info("Greeting conversation has finished (detected in tick).")
+            self.context_provider.update_context(
+                {"greeting_conversation_finished": True}
+            )
+            self.conversation_finished_sent = True
+
+        logging.info(
+            f"State: {current_state}, "
+            f"Confidence: {state_update.get('confidence', {}).get('overall', 0):.2f}, "
+            f"Silence: {state_update.get('silence_duration', 0):.1f}s"
+        )
+
+    def publish_countdown_status(self, current_state: str) -> None:
+        """Publish countdown status to Zenoh based on the current conversation state."""
+        if current_state == ConversationState.CONVERSING.value:
+            seconds_until_finished = 20
+        elif current_state == ConversationState.CONCLUDING.value:
+            seconds_until_finished = 10
+        else:
+            seconds_until_finished = 0
+
+        if self.session:
+            request_id = str(uuid4())
+            message_text = json.dumps(
+                {"seconds_until_finished": seconds_until_finished}
+            )
+            try:
+                self.session.put(
+                    self.person_greeting_topic,
+                    PersonGreetingStatus(
+                        header=prepare_header(request_id),
+                        request_id=String(data=request_id),
+                        status=PersonGreetingStatus.STATUS.CONVERSATION.value,
+                        message=String(data=message_text),
+                    ).serialize(),
+                )
+                logging.info(f"Published PersonGreetingStatus: {message_text}")
+            except Exception as e:
+                logging.error(f"Error publishing PersonGreetingStatus: {e}")
+
+    def stop(self):
+        """Stop the connector and clean up resources."""
+        logging.info("Stopping Greeting Conversation action...")
+        if self.session:
+            self.session.close()
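The connector above estimates playback time from the word count twice, at different rates: the tick guard uses ~100 wpm plus a 5 s buffer, while the delayed context update waits only ~150 wpm worth of time. A standalone sketch of that arithmetic, with hypothetical helper names (the connector computes these inline):

```python
def playback_guard_seconds(text: str) -> float:
    """Conservative estimate used to suppress tick updates: ~100 wpm + 5 s."""
    return (len(text.split()) / 100.0) * 60.0 + 5.0


def context_update_delay_seconds(text: str) -> float:
    """Faster estimate used to schedule the finished-context update: ~150 wpm."""
    return (len(text.split()) / 150.0) * 60.0


reply = "Thanks for stopping by, enjoy the rest of the conference!"
assert len(reply.split()) == 10
assert abs(playback_guard_seconds(reply) - 11.0) < 1e-9   # 6 s speech + 5 s buffer
assert abs(context_update_delay_seconds(reply) - 4.0) < 1e-9
```

So for a ten-word reply, the tick loop stays quiet for 11 s while the mode transition is armed after only 4 s, which keeps the transition from lagging behind the end of playback.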
diff --git a/src/actions/greeting_conversation_simplified/interface.py b/src/actions/greeting_conversation_simplified/interface.py
new file mode 100644
index 000000000..c487bd13f
--- /dev/null
+++ b/src/actions/greeting_conversation_simplified/interface.py
@@ -0,0 +1,29 @@
+from dataclasses import dataclass
+
+from actions.base import Interface
+
+
+@dataclass
+class GreetingConversationSimplifiedInput:
+    """
+    Input interface for the simplified GreetingConversation action.
+
+    Parameters
+    ----------
+    response : str
+        The spoken answer to the user.
+    """
+
+    response: str
+
+
+@dataclass
+class GreetingConversationSimplified(
+    Interface[GreetingConversationSimplifiedInput, GreetingConversationSimplifiedInput]
+):
+    """
+    Respond to the user. Put your spoken answer in 'response'.
+    """
+
+    input: GreetingConversationSimplifiedInput
+    output: GreetingConversationSimplifiedInput
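Because the interface exposes a single `response` field, a connector receives exactly one string per LLM action. A hypothetical direct construction, for illustration only (the runtime normally builds this from the tool-call arguments):

```python
from actions.greeting_conversation_simplified.interface import (
    GreetingConversationSimplifiedInput,
)

msg = GreetingConversationSimplifiedInput(response="Happy to help!")
assert msg.response == "Happy to help!"
```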
diff --git a/src/fuser/__init__.py b/src/fuser/__init__.py
index 33a3a6fb5..72032262d 100644
--- a/src/fuser/__init__.py
+++ b/src/fuser/__init__.py
@@ -108,11 +108,29 @@ async def fuse(
                 results = await self.knowledge_base.query(
                     query_text, top_k=3, min_score=self.kb_min_score
                 )
-                if results:
-                    kb_context = self.knowledge_base.format_context(
-                        results, max_chars=1500
-                    )
-                    logging.info(
-                        f"Knowledge base: {len(results)} docs passed to LLM"
-                    )
+                high = [
+                    r for r in results if r.score is not None and r.score >= 0.92
+                ]
+                low = [
+                    r
+                    for r in results
+                    if r.score is not None and 0.75 <= r.score < 0.92
+                ]
+                kb_parts = []
+                if high:
+                    kb_parts.append(
+                        self.knowledge_base.format_context(high, max_chars=1500)
+                    )
+                if low:
+                    kb_parts.append(
+                        "[Potentially relevant, low confidence]\n"
+                        + self.knowledge_base.format_context(low, max_chars=1000)
+                    )
+                if kb_parts:
+                    kb_context = "\n".join(kb_parts)
+                    logging.info(
+                        f"Knowledge base: {len(high)} high, {len(low)} low confidence;"
+                        f" {len(results)} docs passed to LLM"
+                    )
                 else:
@@ -126,11 +144,10 @@
         if kb_context:
             inputs_fused += f"\n\nKNOWLEDGE BASE:\n{kb_context}"
 
-        # if we provide laws from blockchain, these override the locally stored rules
-        # the rules are not provided in the system prompt, but as a separate INPUT,
-        # since they are flowing from the outside world
-        if "Universal Laws" not in inputs_fused:
-            system_prompt += "\nLAWS:\n" + self.config.system_governance
+        # Only include verbose sections if they have content
+        if self.config.system_governance:
+            if "Universal Laws" not in inputs_fused:
+                system_prompt += "\nLAWS:\n" + self.config.system_governance
 
         if self.config.system_prompt_examples:
             system_prompt += "\n\nEXAMPLES:\n" + self.config.system_prompt_examples
@@ -145,14 +162,18 @@
             if desc:
                 actions_fused += desc + "\n\n"
 
-        question_prompt = "What will you do? Actions:"
-
-        # this is the final prompt:
-        # (1) a (typically) fixed overall system prompt with the agent's name, rules, and examples
-        # (2) all the inputs (vision, sound, etc.)
-        # (3) a (typically) fixed list of available actions
-        # (4) a (typically) fixed system prompt requesting commands to be generated
-        fused_prompt = f"{system_prompt}\n\nAVAILABLE INPUTS:\n{inputs_fused}\nAVAILABLE ACTIONS:\n\n{actions_fused}\n\n{question_prompt}"
+        # Build the final prompt; skip verbose headers if sections are empty
+        if actions_fused:
+            question_prompt = "What will you do? Actions:"
+            fused_prompt = (
+                f"{system_prompt}\n\n"
+                f"AVAILABLE INPUTS:\n{inputs_fused}\n"
+                f"AVAILABLE ACTIONS:\n\n{actions_fused}\n\n"
+                f"{question_prompt}"
+            )
+        else:
+            question_prompt = ""
+            fused_prompt = f"{system_prompt}\n\n{inputs_fused}"
 
         logging.debug(f"FINAL PROMPT: {fused_prompt}")
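The fuser change above splits knowledge-base hits into two confidence tiers before formatting. A self-contained sketch of the same thresholds, using a stand-in result type since the real result class is not shown in this diff:

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class KBResult:
    """Stand-in for the real knowledge-base result type."""
    text: str
    score: Optional[float]


def split_by_confidence(results: List[KBResult]) -> Tuple[List[KBResult], List[KBResult]]:
    """Mirror the fuser's tiers: >= 0.92 is high, 0.75-0.92 is low, the rest is dropped."""
    high = [r for r in results if r.score is not None and r.score >= 0.92]
    low = [r for r in results if r.score is not None and 0.75 <= r.score < 0.92]
    return high, low


docs = [
    KBResult("Paris is the capital of France.", 0.95),
    KBResult("Booth hours are 9 to 5.", 0.80),
    KBResult("Unrelated chunk", 0.50),
    KBResult("Unscored chunk", None),
]
high, low = split_by_confidence(docs)
assert [d.text for d in high] == ["Paris is the capital of France."]
assert [d.text for d in low] == ["Booth hours are 9 to 5."]
# Anything below 0.75, or without a score, never reaches the LLM.
```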
+""" + +import json +import logging +import re +import time +import typing as T + +import openai +from pydantic import BaseModel, Field + +from llm import LLM, LLMConfig +from llm.output_model import Action, CortexOutputModel +from providers.avatar_llm_state_provider import AvatarLLMState +from providers.llm_history_manager_simplified import LLMHistoryManagerSimplified + +R = T.TypeVar("R", bound=BaseModel) + +_QWEN_TOOL_CALL_RE = re.compile(r"\s*(\{.*?\})\s*", re.DOTALL) + + +def _parse_qwen_tool_calls(text: str) -> list: + """Parse Qwen-style {...} blocks from text.""" + tool_calls = [] + if not isinstance(text, str): + return tool_calls + for i, raw in enumerate(_QWEN_TOOL_CALL_RE.findall(text)): + try: + obj = json.loads(raw) + if name := obj.get("name"): + tool_calls.append( + { + "id": f"call_{i}", + "type": "function", + "function": { + "name": name, + "arguments": json.dumps( + obj.get("arguments", {}), ensure_ascii=False + ), + }, + } + ) + except Exception: + continue + return tool_calls + + +class QwenLLMSimplifiedConfig(LLMConfig): + """ + Configuration for simplified Qwen LLM. + + Parameters + ---------- + base_url : str + Base URL for local Qwen API (default: http://127.0.0.1:8860/v1). + api_key: str + API key for local Qwen API (if required, default: "placeholder"). + model : str + Qwen model name. + enable_reasoning : bool + Enable reasoning mode (default: False). + """ + + base_url: T.Optional[str] = Field( + default="http://127.0.0.1:8860/v1", + description="Base URL for local Qwen API", + ) + api_key: T.Optional[str] = Field( + default="placeholder", + description="API key for local Qwen API (if required)", + ) + model: T.Optional[str] = Field( + default="RedHatAI/Qwen3-30B-A3B-quantized.w4a16", + description="Qwen model name", + ) + enable_reasoning: bool = Field( + default=False, + description="Enable reasoning mode", + ) + + +class QwenLLMSimplified(LLM[R]): + """ + Simplified Qwen LLM that builds Actions directly. + + Bypasses convert_function_calls_to_actions to preserve field names + for single-arg interfaces. Includes a text fallback that wraps + plain text as a {"response": text} action. + """ + + def __init__( + self, + config: QwenLLMSimplifiedConfig, + available_actions: T.Optional[T.List] = None, + ): + super().__init__(config, available_actions) + + self._config: QwenLLMSimplifiedConfig = config + self._base_url = self._config.base_url + self._api_key = self._config.api_key + self._model = self._config.model + self._enable_reasoning = self._config.enable_reasoning + + self._client = openai.AsyncClient( + base_url=self._base_url, + api_key=self._api_key, + ) + + self._extra_body = {"chat_template_kwargs": {"enable_thinking": False}} + self.history_manager = LLMHistoryManagerSimplified(self._config, self._client) + + self._skip_state_management = False + + @AvatarLLMState.trigger_thinking() + @LLMHistoryManagerSimplified.update_history() + async def ask( + self, prompt: str, messages: T.Optional[T.List[T.Dict[str, T.Any]]] = None + ) -> R | None: + """ + Send prompt to local Qwen model and get structured response. + + Parameters + ---------- + prompt : str + The input prompt to send. + messages : list of dict, optional + Conversation history. + + Returns + ------- + R or None + Parsed response with actions, or None if parsing fails. 
+ """ + if messages is None: + messages = [] + try: + logging.info(f"Qwen input: {prompt}") + logging.info(f"Qwen messages: {messages}") + + self.io_provider.llm_start_time = time.time() + self.io_provider.set_llm_prompt(prompt) + + formatted = [ + {"role": m.get("role", "user"), "content": m.get("content", "")} + for m in messages + ] + user_content = prompt if self._enable_reasoning else f"{prompt} /no_think" + formatted.append({"role": "user", "content": user_content}) + + request_params: dict[str, T.Any] = { + "model": self._model, + "messages": formatted, + "timeout": self._config.timeout, + "extra_body": self._extra_body, + } + + if self.function_schemas: + request_params["tools"] = self.function_schemas + request_params["tool_choice"] = "required" + + response = await self._client.chat.completions.create(**request_params) + + if not response.choices: + logging.warning("Qwen API returned empty choices") + return None + + message = response.choices[0].message + self.io_provider.llm_end_time = time.time() + + tool_calls = list(message.tool_calls or []) + if ( + not tool_calls + and isinstance(message.content, str) + and "" in message.content + ): + tool_calls = _parse_qwen_tool_calls(message.content) + + if tool_calls: + logging.info(f"Received {len(tool_calls)} function calls") + logging.info(f"Function calls: {tool_calls}") + + # Build Actions directly, preserving argument field names + # via json.dumps (bypasses convert_function_calls_to_actions + # which loses field names for single-arg calls). + actions = [] + for tc in tool_calls: + fn_name = ( + tc.function.name + if hasattr(tc, "function") + else tc["function"]["name"] + ) + fn_args = ( + tc.function.arguments + if hasattr(tc, "function") + else tc["function"]["arguments"] + ) + # Ensure value is a JSON string with field names preserved + if isinstance(fn_args, str): + try: + args = json.loads(fn_args) + fn_args = json.dumps(args, ensure_ascii=False) + except json.JSONDecodeError: + pass + else: + fn_args = json.dumps(fn_args, ensure_ascii=False) + + # Skip tool calls whose response is a conversation summary + val_lower = fn_args.lower() + if "conversation summary" in val_lower or ( + "previously," in val_lower and "user" in val_lower + ): + logging.info(f"Skipping summary tool call: {fn_args[:200]}") + continue + + actions.append(Action(type=fn_name, value=fn_args)) + if actions: + result = CortexOutputModel(actions=actions) + return T.cast(R, result) + logging.info("All tool calls were summary-style, returning None") + return None + + # Fallback: if model returned text but no tool calls, and there's + # exactly one function schema, wrap the text as a function call + text = message.content + if text and self.function_schemas and len(self.function_schemas) == 1: + # Skip summary-style responses that aren't user-facing + text_lower = text.lower() + is_summary = ( + ("user asked" in text_lower and "replied" in text_lower) + or "conversation summary" in text_lower + or text_lower.lstrip().startswith("previously,") + ) + if is_summary: + logging.info(f"Skipping summary-style response: {text[:200]}") + return None + # Strip XML-style tags the model may wrap around the response + text = re.sub(r"", "", text).strip() + if not text: + return None + + fn_name = self.function_schemas[0]["function"]["name"] + logging.info( + f"No tool calls found, wrapping text as {fn_name}: " f"{text[:200]}" + ) + action = Action( + type=fn_name, + value=json.dumps({"response": text}, ensure_ascii=False), + ) + result = 
diff --git a/src/providers/llm_history_manager.py b/src/providers/llm_history_manager.py
index c3ee98316..5efd63b36 100644
--- a/src/providers/llm_history_manager.py
+++ b/src/providers/llm_history_manager.py
@@ -384,4 +384,4 @@ async def wrapper(self: Any, prompt: str, *args: Any, **kwargs: Any) -> R:
 
         return wrapper
 
-    return decorator
+    return decorator
\ No newline at end of file
diff --git a/src/providers/llm_history_manager_simplified.py b/src/providers/llm_history_manager_simplified.py
new file mode 100644
index 000000000..a2814134b
--- /dev/null
+++ b/src/providers/llm_history_manager_simplified.py
@@ -0,0 +1,174 @@
+"""
+Simplified LLM history manager for small local models.
+
+Changes from the original LLMHistoryManager:
+- Simplified ACTION_MAP without verbose preambles (e.g., "**** said: {}")
+- Adds "greeting_conversation" to ACTION_MAP
+- Conversation-focused input formatting ("User: ..." instead of sensor-style)
+- Extracts plain text from JSON-wrapped action values ({"response": text})
+- Prefixes summaries with "[Conversation summary - do not repeat]"
+- Concise summarization prompts tuned for small models
+"""
+
+import functools
+import json
+import logging
+from typing import Any, Awaitable, Callable, List, TypeVar
+
+from llm import LLMConfig
+
+from .llm_history_manager import ChatMessage, LLMHistoryManager
+
+R = TypeVar("R")
+
+
+ACTION_MAP_SIMPLIFIED = {
+    "emotion": "{}",
+    "speak": "{}",
+    "move": "{}",
+    "greeting_conversation": "{}",
+}
+
+
+class LLMHistoryManagerSimplified(LLMHistoryManager):
+    """
+    Simplified history manager for conversational small-model use cases.
+
+    Inherits from LLMHistoryManager but overrides:
+    - summarize_messages: adds the "[Conversation summary - do not repeat]" prefix
+    - update_history: uses the "User: ..." format and the simplified ACTION_MAP
+    """
+
+    def __init__(
+        self,
+        config: LLMConfig,
+        client,
+        system_prompt: str = (
+            "You are a concise assistant that tracks conversation history for a "
+            "robot named ****. Summarize ONLY what was said: what the user asked "
+            "and what **** replied. Do NOT elaborate, add analysis, or invent "
+            "details. Use plain short sentences, not tables or markdown."
+        ),
+        summary_command: str = (
+            "\nWrite a brief summary of the conversation so far. List only what "
+            "the user said and what **** replied. Keep it under 100 words. Do not "
+            "repeat ****'s previous responses verbatim, just note the topic."
+        ),
+    ):
+        super().__init__(config, client, system_prompt, summary_command)
+
+    async def summarize_messages(self, messages: List[ChatMessage]) -> ChatMessage:
+        """Summarize messages with a 'do not repeat' prefix."""
+        result = await super().summarize_messages(messages)
+        if result.role == "assistant" and not result.content.startswith(
+            "[Conversation summary"
+        ):
+            result = ChatMessage(
+                role="assistant",
+                content=f"[Conversation summary - do not repeat] {result.content}",
+            )
+        return result
+
+    @staticmethod
+    def update_history() -> (
+        Callable[[Callable[..., Awaitable[R]]], Callable[..., Awaitable[R]]]
+    ):
+        """
+        Decorator to manage LLM history with simplified formatting.
+
+        Uses the "User: ..." input format and the simplified ACTION_MAP.
+        """
+
+        def decorator(
+            func: Callable[..., Awaitable[R]],
+        ) -> Callable[..., Awaitable[R]]:
+            @functools.wraps(func)
+            async def wrapper(self: Any, prompt: str, *args: Any, **kwargs: Any) -> R:
+                if getattr(self, "_skip_state_management", False):
+                    return await func(self, prompt, *args, **kwargs)
+
+                if self._config.history_length == 0:
+                    response = await func(self, prompt, [], *args, **kwargs)
+                    self.history_manager.frame_index += 1
+                    return response
+
+                self.agent_name = self._config.agent_name
+
+                cycle = self.history_manager.frame_index
+                logging.debug(f"LLM tasking cycle debug tracker: {cycle}")
+
+                current_tick = self.io_provider.tick_counter
+                parts = []
+                for input_type, input_info in self.io_provider.inputs.items():
+                    if input_info.tick == current_tick:
+                        logging.debug(f"LLM: {input_type} (tick #{input_info.tick})")
+                        if input_info.input:
+                            parts.append(input_info.input.strip())
+                formatted_inputs = (
+                    "User: " + " ".join(parts) if parts else "User: (no input)"
+                )
+
+                inputs = ChatMessage(role="user", content=formatted_inputs)
+
+                logging.debug(f"Inputs: {inputs}")
+                self.history_manager.history.append(inputs)
+
+                messages = self.history_manager.get_messages()
+                logging.debug(f"messages:\n{messages}")
+                response = await func(self, prompt, messages, *args, **kwargs)
+                logging.debug(f"Response to parse:\n{response}")
+
+                if response is not None:
+
+                    def _extract_text(value: str) -> str:
+                        """Extract plain text from an action value."""
+                        try:
+                            parsed = json.loads(value)
+                            if isinstance(parsed, dict) and "response" in parsed:
+                                return parsed["response"]
+                        except (json.JSONDecodeError, TypeError):
+                            pass
+                        return value
+
+                    actions_text = " | ".join(
+                        ACTION_MAP_SIMPLIFIED[action.type.lower()].format(
+                            _extract_text(action.value) if action.value else ""
+                        )
+                        for action in response.actions  # type: ignore
+                        if action.type.lower() in ACTION_MAP_SIMPLIFIED
+                    )
+                    action_message = (
+                        f"{self.agent_name}: {actions_text}"
+                        if actions_text
+                        else f"{self.agent_name}: (no response)"
+                    )
+
+                    self.history_manager.history.append(
+                        ChatMessage(role="assistant", content=action_message)
+                    )
+
+                    if (
+                        self.history_manager.config.history_length > 0
+                        and len(self.history_manager.history)
+                        > self.history_manager.config.history_length
+                    ):
+                        await self.history_manager.start_summary_task(
+                            self.history_manager.history
+                        )
+                else:
+                    if (
+                        self.history_manager.history
+                        and self.history_manager.history[-1].role == "user"
+                    ):
+                        logging.warning(
+                            "LLM response failed, removing unpaired user message"
+                        )
+                        self.history_manager.history.pop()
+
+                self.history_manager.frame_index += 1
+
+                return response
+
+            return wrapper
+
+        return decorator
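The history entry the simplified manager stores is `{agent_name}: {text}`, where JSON-wrapped action values are unwrapped first. A runnable sketch of that unwrapping, copied in spirit from the `_extract_text` closure above:

```python
import json


def extract_text(value: str) -> str:
    """Unwrap {"response": ...} action values; pass other strings through."""
    try:
        parsed = json.loads(value)
        if isinstance(parsed, dict) and "response" in parsed:
            return parsed["response"]
    except (json.JSONDecodeError, TypeError):
        pass
    return value


assert extract_text('{"response": "Hello there!"}') == "Hello there!"
assert extract_text("plain text stays as-is") == "plain text stays as-is"
# The stored history line then reads, e.g.: "Bits: Hello there!"
```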
+ """ + + def decorator( + func: Callable[..., Awaitable[R]], + ) -> Callable[..., Awaitable[R]]: + @functools.wraps(func) + async def wrapper(self: Any, prompt: str, *args: Any, **kwargs: Any) -> R: + if getattr(self, "_skip_state_management", False): + return await func(self, prompt, *args, **kwargs) + + if self._config.history_length == 0: + response = await func(self, prompt, [], *args, **kwargs) + self.history_manager.frame_index += 1 + return response + + self.agent_name = self._config.agent_name + + cycle = self.history_manager.frame_index + logging.debug(f"LLM Tasking cycle debug tracker: {cycle}") + + current_tick = self.io_provider.tick_counter + parts = [] + for input_type, input_info in self.io_provider.inputs.items(): + if input_info.tick == current_tick: + logging.debug(f"LLM: {input_type} (tick #{input_info.tick})") + if input_info.input: + parts.append(input_info.input.strip()) + formatted_inputs = ( + "User: " + " ".join(parts) if parts else "User: (no input)" + ) + + inputs = ChatMessage(role="user", content=formatted_inputs) + + logging.debug(f"Inputs: {inputs}") + self.history_manager.history.append(inputs) + + messages = self.history_manager.get_messages() + logging.debug(f"messages:\n{messages}") + response = await func(self, prompt, messages, *args, **kwargs) + logging.debug(f"Response to parse:\n{response}") + + if response is not None: + + def _extract_text(value: str) -> str: + """Extract plain text from action value.""" + try: + parsed = json.loads(value) + if isinstance(parsed, dict) and "response" in parsed: + return parsed["response"] + except (json.JSONDecodeError, TypeError): + pass + return value + + actions_text = " | ".join( + ACTION_MAP_SIMPLIFIED[action.type.lower()].format( + _extract_text(action.value) if action.value else "" + ) + for action in response.actions # type: ignore + if action.type.lower() in ACTION_MAP_SIMPLIFIED + ) + action_message = ( + f"{self.agent_name}: {actions_text}" + if actions_text + else f"{self.agent_name}: (no response)" + ) + + self.history_manager.history.append( + ChatMessage(role="assistant", content=action_message) + ) + + if ( + self.history_manager.config.history_length > 0 + and len(self.history_manager.history) + > self.history_manager.config.history_length + ): + await self.history_manager.start_summary_task( + self.history_manager.history + ) + else: + if ( + self.history_manager.history + and self.history_manager.history[-1].role == "user" + ): + logging.warning( + "LLM response failed, removing unpaired user message" + ) + self.history_manager.history.pop() + + self.history_manager.frame_index += 1 + + return response + + return wrapper + + return decorator diff --git a/src/providers/tts_text_utils.py b/src/providers/tts_text_utils.py new file mode 100644 index 000000000..aeae19bcf --- /dev/null +++ b/src/providers/tts_text_utils.py @@ -0,0 +1,51 @@ +import re + +_TTS_CORRECTIONS = [ + # Month abbreviations + (re.compile(r"\bJan\b"), "January"), + (re.compile(r"\bFeb\b"), "February"), + (re.compile(r"\bMar\b"), "March"), + (re.compile(r"\bApr\b"), "April"), + (re.compile(r"\bJun\b"), "June"), + (re.compile(r"\bJul\b"), "July"), + (re.compile(r"\bAug\b"), "August"), + (re.compile(r"\bSep(?:t)?\b"), "September"), + (re.compile(r"\bOct\b"), "October"), + (re.compile(r"\bNov\b"), "November"), + (re.compile(r"\bDec\b"), "December"), + # Address abbreviations + (re.compile(r"\bSt\b\.?"), "Street"), + (re.compile(r"\bAve\b\.?"), "Avenue"), + (re.compile(r"\bBlvd\b\.?"), "Boulevard"), + (re.compile(r"\bDr\b\.?"), 
"Drive"), + (re.compile(r"\bRd\b\.?"), "Road"), + (re.compile(r"\bLn\b\.?"), "Lane"), + (re.compile(r"\bCt\b\.?"), "Court"), + (re.compile(r"\bPl\b\.?"), "Place"), + (re.compile(r"\bPkwy\b\.?"), "Parkway"), + (re.compile(r"\bHwy\b\.?"), "Highway"), + # Directional abbreviations + (re.compile(r"\bN\b\.?(?=\s+[A-Z])"), "North"), + (re.compile(r"\bS\b\.?(?=\s+[A-Z])"), "South"), + (re.compile(r"\bE\b\.?(?=\s+[A-Z])"), "East"), + (re.compile(r"\bW\b\.?(?=\s+[A-Z])"), "West"), +] + +# Time patterns: "11:00 a.m." -> "11 a.m.", "3:30 p.m." -> "3 30 p.m." +_TIME_ON_HOUR = re.compile(r"\b(\d{1,2}):00\b") +_TIME_WITH_MINUTES = re.compile(r"\b(\d{1,2}):(\d{2})\b") + + +def normalize_tts_text(text: str) -> str: + """Expand abbreviations and reformat times for cleaner TTS output.""" + # Fix times: "11:00" -> "11", "3:30" -> "3 30" + text = _TIME_ON_HOUR.sub(r"\1", text) + text = _TIME_WITH_MINUTES.sub(r"\1 \2", text) + + for pattern, replacement in _TTS_CORRECTIONS: + text = pattern.sub(replacement, text) + + # Strip non-English characters (keep ASCII letters, digits, punctuation, whitespace) + text = re.sub(r"[^\x00-\x7F]+", "", text) + + return text diff --git a/tests/fuser/test_init.py b/tests/fuser/test_init.py index d58a3cc6b..f64a3f2ff 100644 --- a/tests/fuser/test_init.py +++ b/tests/fuser/test_init.py @@ -225,9 +225,11 @@ async def test_fuser_with_knowledge_base_and_voice_input(): mock_kb.query.assert_called_once_with( "What is the capital of France?", top_k=3, min_score=0.0 ) - mock_kb.format_context.assert_called_once_with( - [mock_doc1, mock_doc2], max_chars=1500 - ) + # Two-tier KB filtering: high confidence (>=0.92) and low (0.75-0.92) + # are formatted separately, so format_context is called twice + assert mock_kb.format_context.call_count == 2 + mock_kb.format_context.assert_any_call([mock_doc1], max_chars=1500) + mock_kb.format_context.assert_any_call([mock_doc2], max_chars=1000) assert result is not None assert "KNOWLEDGE BASE:" in result assert "Paris is the capital of France." 
diff --git a/tests/providers/test_llm_history_manager_simplified.py b/tests/providers/test_llm_history_manager_simplified.py
new file mode 100644
index 000000000..b69d7e69d
--- /dev/null
+++ b/tests/providers/test_llm_history_manager_simplified.py
@@ -0,0 +1,437 @@
+import asyncio
+from dataclasses import dataclass
+from unittest.mock import AsyncMock, MagicMock
+
+import openai
+import pytest
+
+from providers.llm_history_manager import ChatMessage
+from providers.llm_history_manager_simplified import (
+    ACTION_MAP_SIMPLIFIED,
+    LLMHistoryManagerSimplified,
+)
+
+
+@dataclass
+class MockAction:
+    type: str
+    value: str
+
+
+@pytest.fixture
+def llm_config():
+    config = MagicMock()
+    config.model = "gpt-4o"
+    config.history_length = 5
+    config.agent_name = "Test Robot"
+    return config
+
+
+@pytest.fixture
+def openai_client():
+    client = MagicMock(spec=openai.AsyncClient)
+
+    response = MagicMock()
+    response.choices = [MagicMock()]
+    response.choices[0].message.content = "This is a test summary"
+
+    chat_mock = MagicMock()
+    completions_mock = MagicMock()
+    completions_mock.create = AsyncMock(return_value=response)
+    chat_mock.completions = completions_mock
+    client.chat = chat_mock
+
+    return client
+
+
+@pytest.fixture
+def history_manager(llm_config, openai_client):
+    return LLMHistoryManagerSimplified(llm_config, openai_client)
+
+
+@pytest.mark.asyncio
+async def test_summarize_messages_adds_prefix(history_manager):
+    """Test that summarize_messages adds the 'do not repeat' prefix."""
+    messages = [
+        ChatMessage(role="assistant", content="Previous summary"),
+        ChatMessage(role="user", content="New input"),
+        ChatMessage(role="user", content="Action taken"),
+    ]
+
+    result = await history_manager.summarize_messages(messages)
+    assert result.role == "assistant"
+    assert result.content == (
+        "[Conversation summary - do not repeat] Previously, This is a test summary"
+    )
+
+
+@pytest.mark.asyncio
+async def test_summarize_messages_empty(history_manager):
+    """Test with empty messages."""
+    result = await history_manager.summarize_messages([])
+    assert result.role == "system"
+    assert "No history to summarize" == result.content
+
+
+@pytest.mark.asyncio
+async def test_summarize_messages_api_error(history_manager):
+    """Test that API errors are handled gracefully."""
+    history_manager.client.chat.completions.create.side_effect = Exception("API Error")
+
+    messages = [ChatMessage(role="user", content="Test")]
+    result = await history_manager.summarize_messages(messages)
+
+    assert result.role == "system"
+    assert "Error summarizing state" == result.content
+
+
+@pytest.mark.asyncio
+async def test_start_summary_task(history_manager):
+    """Test that the summary task runs and updates messages."""
+    messages = [
+        ChatMessage(role="assistant", content="Previous summary"),
+        ChatMessage(role="user", content="New input"),
+        ChatMessage(role="user", content="Action taken"),
+    ]
+
+    history_manager.summarize_messages = AsyncMock()
+    history_manager.summarize_messages.return_value = ChatMessage(
+        role="assistant", content="New summary"
+    )
+
+    await history_manager.start_summary_task(messages)
+    await asyncio.sleep(0.1)
+
+    assert history_manager._summary_task is not None
+    await asyncio.sleep(0.1)
+
+    assert len(messages) == 1
+    assert messages[0].role == "assistant"
+    assert "New summary" == messages[0].content
+
+
+def test_action_map_includes_greeting_conversation():
+    """Test that ACTION_MAP_SIMPLIFIED includes greeting_conversation."""
+    assert "greeting_conversation" in ACTION_MAP_SIMPLIFIED
+    assert "emotion" in ACTION_MAP_SIMPLIFIED
+    assert "speak" in ACTION_MAP_SIMPLIFIED
+    assert "move" in ACTION_MAP_SIMPLIFIED
+
+
+def test_action_map_uses_simple_format():
+    """Test that ACTION_MAP_SIMPLIFIED uses plain '{}' format without preambles."""
+    for key, fmt in ACTION_MAP_SIMPLIFIED.items():
+        assert fmt == "{}", f"Expected '{{}}' for {key}, got '{fmt}'"
+
+
+@pytest.mark.asyncio
+async def test_update_history_user_format():
+    """Test that inputs are formatted as 'User: ...' instead of sensor-style."""
+    config = MagicMock()
+    config.model = "gpt-4o"
+    config.history_length = 5
+    config.agent_name = "TestBot"
+
+    client = AsyncMock()
+    history_manager = LLMHistoryManagerSimplified(config, client)
+
+    class MockLLMProvider:
+        def __init__(self):
+            self._config = config
+            self._skip_state_management = False
+            self.history_manager = history_manager
+            self.io_provider = history_manager.io_provider
+            self.agent_name = config.agent_name
+
+        @LLMHistoryManagerSimplified.update_history()
+        async def process(self, prompt: str, messages: list) -> MagicMock:
+            response = MagicMock()
+            response.actions = [
+                MockAction(type="speak", value="Hello"),
+                MockAction(type="emotion", value="happy"),
+            ]
+            return response
+
+    provider = MockLLMProvider()
+
+    provider.io_provider.add_input("audio", "User said hello", 1234.0)
+    provider.io_provider.add_input("vision", "Saw a person", 1235.0)
+
+    provider.io_provider.increment_tick()
+
+    provider.io_provider.add_input("audio_new", "User said goodbye", 1236.0)
+    provider.io_provider.add_input("lidar", "Detected obstacle", 1237.0)
+
+    await provider.process("test prompt")
+
+    assert len(history_manager.history) == 2
+
+    inputs_msg = history_manager.history[0]
+    assert inputs_msg.role == "user"
+    # Simplified format: "User: ..." without input type names
+    assert inputs_msg.content.startswith("User: ")
+    assert "User said goodbye" in inputs_msg.content
+    assert "Detected obstacle" in inputs_msg.content
+    # Old tick inputs should not be present
+    assert "User said hello" not in inputs_msg.content
+    assert "Saw a person" not in inputs_msg.content
+
+
+@pytest.mark.asyncio
+async def test_update_history_no_inputs():
+    """Test that when no inputs match the current tick, 'User: (no input)' is used."""
+    config = MagicMock()
+    config.model = "gpt-4o"
+    config.history_length = 5
+    config.agent_name = "TestBot"
+
+    client = AsyncMock()
+    history_manager = LLMHistoryManagerSimplified(config, client)
+
+    class MockLLMProvider:
+        def __init__(self):
+            self._config = config
+            self._skip_state_management = False
+            self.history_manager = history_manager
+            self.io_provider = history_manager.io_provider
+            self.agent_name = config.agent_name
+
+        @LLMHistoryManagerSimplified.update_history()
+        async def process(self, prompt: str, messages: list) -> MagicMock:
+            response = MagicMock()
+            response.actions = [MockAction(type="speak", value="Nothing to report")]
+            return response
+
+    provider = MockLLMProvider()
+
+    provider.io_provider.add_input("audio", "Old audio", 1234.0)
+    provider.io_provider.increment_tick()
+
+    await provider.process("test prompt")
+
+    assert len(history_manager.history) == 2
+
+    inputs_msg = history_manager.history[0]
+    assert inputs_msg.role == "user"
+    assert inputs_msg.content == "User: (no input)"
+    assert "Old audio" not in inputs_msg.content
+
+
+@pytest.mark.asyncio
+async def test_update_history_multiple_ticks():
+    """Test that inputs are filtered correctly across multiple tick cycles."""
+    config = MagicMock()
+    config.model = "gpt-4o"
+    config.history_length = 10
+    config.agent_name = "MultiTickBot"
+
+    client = AsyncMock()
+    history_manager = LLMHistoryManagerSimplified(config, client)
+
+    class MockLLMProvider:
+        def __init__(self):
+            self._config = config
+            self._skip_state_management = False
+            self.history_manager = history_manager
+            self.io_provider = history_manager.io_provider
+            self.agent_name = config.agent_name
+
+        @LLMHistoryManagerSimplified.update_history()
+        async def process(self, prompt: str, messages: list) -> MagicMock:
+            response = MagicMock()
+            response.actions = [MockAction(type="speak", value="Response")]
+            return response
+
+    provider = MockLLMProvider()
+
+    # Tick 0: Add inputs
+    provider.io_provider.add_input("input_tick0", "Data at tick 0", 1000.0)
+    await provider.process("prompt")
+
+    first_inputs = history_manager.history[0]
+    assert "Data at tick 0" in first_inputs.content
+
+    # Tick 1: Increment and add new inputs
+    provider.io_provider.increment_tick()
+    provider.io_provider.add_input("input_tick1", "Data at tick 1", 2000.0)
+    await provider.process("prompt")
+
+    second_inputs = history_manager.history[2]
+    assert "Data at tick 1" in second_inputs.content
+    assert "Data at tick 0" not in second_inputs.content
+
+    # Tick 2: Increment and add new inputs
+    provider.io_provider.increment_tick()
+    provider.io_provider.add_input("input_tick2", "Data at tick 2", 3000.0)
+    await provider.process("prompt")
+
+    third_inputs = history_manager.history[4]
+    assert "Data at tick 2" in third_inputs.content
+    assert "Data at tick 0" not in third_inputs.content
+    assert "Data at tick 1" not in third_inputs.content
+
+
+@pytest.mark.asyncio
+async def test_update_history_extracts_json_response():
+    """Test that action values with a JSON-wrapped response field are extracted."""
+    config = MagicMock()
+    config.model = "gpt-4o"
+    config.history_length = 5
+    config.agent_name = "TestBot"
+
+    client = AsyncMock()
+    history_manager = LLMHistoryManagerSimplified(config, client)
+
+    class MockLLMProvider:
+        def __init__(self):
+            self._config = config
+            self._skip_state_management = False
+            self.history_manager = history_manager
+            self.io_provider = history_manager.io_provider
+            self.agent_name = config.agent_name
+
+        @LLMHistoryManagerSimplified.update_history()
+        async def process(self, prompt: str, messages: list) -> MagicMock:
+            response = MagicMock()
+            response.actions = [
+                MockAction(
+                    type="greeting_conversation",
+                    value='{"response": "Hello there!"}',
+                ),
+            ]
+            return response
+
+    provider = MockLLMProvider()
+
+    provider.io_provider.add_input("audio", "Hi", 1234.0)
+    await provider.process("test prompt")
+
+    assert len(history_manager.history) == 2
+
+    action_msg = history_manager.history[1]
+    assert action_msg.role == "assistant"
+    # Should extract "Hello there!" from JSON, not show raw JSON
+    assert "Hello there!" in action_msg.content
+    assert "TestBot:" in action_msg.content
+
+
+@pytest.mark.asyncio
+async def test_update_history_agent_name_format():
+    """Test that action messages use the '{agent_name}: {text}' format."""
+    config = MagicMock()
+    config.model = "gpt-4o"
+    config.history_length = 5
+    config.agent_name = "Bits"
+
+    client = AsyncMock()
+    history_manager = LLMHistoryManagerSimplified(config, client)
+
+    class MockLLMProvider:
+        def __init__(self):
+            self._config = config
+            self._skip_state_management = False
+            self.history_manager = history_manager
+            self.io_provider = history_manager.io_provider
+            self.agent_name = config.agent_name
+
+        @LLMHistoryManagerSimplified.update_history()
+        async def process(self, prompt: str, messages: list) -> MagicMock:
+            response = MagicMock()
+            response.actions = [MockAction(type="speak", value="Welcome to GTC!")]
+            return response
+
+    provider = MockLLMProvider()
+
+    provider.io_provider.add_input("audio", "Hello", 1234.0)
+    await provider.process("test prompt")
+
+    action_msg = history_manager.history[1]
+    assert action_msg.content == "Bits: Welcome to GTC!"
+
+
+@pytest.mark.asyncio
+async def test_update_history_llm_failure_removes_unpaired_message():
+    """Test that when the LLM returns None, the unpaired user message is removed."""
+    config = MagicMock()
+    config.model = "gpt-4o"
+    config.history_length = 5
+    config.agent_name = "TestBot"
+
+    client = AsyncMock()
+    history_manager = LLMHistoryManagerSimplified(config, client)
+
+    class MockLLMProvider:
+        def __init__(self):
+            self._config = config
+            self._skip_state_management = False
+            self.history_manager = history_manager
+            self.io_provider = history_manager.io_provider
+            self.agent_name = config.agent_name
+
+        @LLMHistoryManagerSimplified.update_history()
+        async def process(self, prompt: str, messages: list) -> None:
+            return None
+
+    provider = MockLLMProvider()
+
+    provider.io_provider.add_input("audio", "Test input", 1234.0)
+    result = await provider.process("test prompt")
+
+    assert result is None
+    assert len(history_manager.history) == 0
+
+
+@pytest.mark.asyncio
+async def test_update_history_skip_when_history_length_zero():
+    """Test that history is skipped entirely when history_length is 0."""
+    config = MagicMock()
+    config.model = "gpt-4o"
+    config.history_length = 0
+    config.agent_name = "TestBot"
+
+    client = AsyncMock()
+    history_manager = LLMHistoryManagerSimplified(config, client)
+
+    class MockLLMProvider:
+        def __init__(self):
+            self._config = config
+            self._skip_state_management = False
+            self.history_manager = history_manager
+            self.io_provider = history_manager.io_provider
+            self.agent_name = config.agent_name
+
+        @LLMHistoryManagerSimplified.update_history()
+        async def process(self, prompt: str, messages: list) -> MagicMock:
+            # messages should be an empty list when history_length is 0
+            assert messages == []
+            response = MagicMock()
+            response.actions = [MockAction(type="speak", value="Hello")]
+            return response
+
+    provider = MockLLMProvider()
+
+    provider.io_provider.add_input("audio", "Test input", 1234.0)
+    await provider.process("test prompt")
+
+    # History should remain empty
+    assert len(history_manager.history) == 0
+
+
+def test_get_messages_empty(history_manager):
+    """Test get_messages returns an empty list when there is no history."""
+    result = history_manager.get_messages()
+    assert result == []
+
+
+def test_get_messages_multiple(history_manager):
+    """Test get_messages with multiple messages."""
+    history_manager.history.extend(
+        [
+            ChatMessage(role="user", content="User: Hello"),
+            ChatMessage(role="assistant", content="Test Robot: Hi there"),
+        ]
+    )
+    result = history_manager.get_messages()
+    assert len(result) == 2
+    assert result[0] == {"role": "user", "content": "User: Hello"}
+    assert result[1] == {"role": "assistant", "content": "Test Robot: Hi there"}