Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 46 additions & 24 deletions src/providers/llm_history_manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import functools
import json
import logging
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, List, Optional, TypeVar, Union
Expand All @@ -24,9 +25,10 @@ class ChatMessage:


# Maps an action type (lowercased) to the format string used when the action
# is recorded in conversation history. Every entry is the bare "{}"
# placeholder so history reads as a plain transcript of what was said/done,
# with no "robot felt/said/moved" preamble.
#
# Fix: the previous literal carried duplicate keys ("emotion", "speak",
# "move" each appeared twice); the earlier values were dead because later
# duplicate keys win in a dict literal. Deduplicated to the effective form.
ACTION_MAP = {
    "emotion": "{}",
    "speak": "{}",
    "move": "{}",
    "greeting_conversation": "{}",
}


Expand All @@ -39,8 +41,17 @@ def __init__(
self,
config: LLMConfig,
client: Union[openai.AsyncClient, openai.OpenAI],
system_prompt: str = "You are a helpful assistant that summarizes a succession of events and interactions accurately and concisely. You are watching a robot named **** interact with people and the world. Your goal is to help **** remember what the robot felt, saw, and heard, and how the robot responded to those inputs.",
summary_command: str = "\nConsidering the new information, write an updated summary of the situation for ****. Emphasize information that **** needs to know to respond to people and situations in the best possible and most compelling way.",
system_prompt: str = (
"You are a concise assistant that tracks conversation history for a "
"robot named ****. Summarize ONLY what was said: what the user asked "
"and what **** replied. Do NOT elaborate, add analysis, or invent "
"details. Use plain short sentences, not tables or markdown."
),
summary_command: str = (
"\nWrite a brief summary of the conversation so far. List only what "
"the user said and what **** replied. Keep it under 100 words. Do not "
"repeat ****'s previous responses verbatim — just note the topic."
),
):
"""
Initialize the LLMHistoryManager.
Expand Down Expand Up @@ -183,7 +194,10 @@ async def summarize_messages(self, messages: List[ChatMessage]) -> ChatMessage:
return ChatMessage(
role="system", content="Error: Received empty summary from API"
)
return ChatMessage(role="assistant", content=f"Previously, {summary}")
return ChatMessage(
role="assistant",
content=f"[Conversation summary - do not repeat] Previously, {summary}",
)

except asyncio.TimeoutError:
logging.error(f"API request timed out after {timeout} seconds")
Expand Down Expand Up @@ -318,15 +332,15 @@ async def wrapper(self: Any, prompt: str, *args: Any, **kwargs: Any) -> R:
logging.debug(f"LLM Tasking cycle debug tracker: {cycle}")

current_tick = self.io_provider.tick_counter
formatted_inputs = f"{self.agent_name} sensed the following: "
parts = []
for input_type, input_info in self.io_provider.inputs.items():
if input_info.tick == current_tick:
logging.debug(f"LLM: {input_type} (tick #{input_info.tick})")
logging.debug(f"LLM: {input_info}")
formatted_inputs += f"{input_type}. {input_info.input} | "

formatted_inputs = formatted_inputs.replace("..", ".")
formatted_inputs = formatted_inputs.replace(" ", " ")
if input_info.input:
parts.append(input_info.input.strip())
formatted_inputs = (
"User: " + " ".join(parts) if parts else "User: (no input)"
)

inputs = ChatMessage(role="user", content=formatted_inputs)

Expand All @@ -341,20 +355,28 @@ async def wrapper(self: Any, prompt: str, *args: Any, **kwargs: Any) -> R:

if response is not None:

action_message = (
"Given that information, **** took these actions: "
+ (
" | ".join(
ACTION_MAP[action.type.lower()].format(
action.value if action.value else ""
)
for action in response.actions # type: ignore
if action.type.lower() in ACTION_MAP
)
def _extract_text(value: str) -> str:
"""Extract plain text from action value."""
try:
parsed = json.loads(value)
if isinstance(parsed, dict) and "response" in parsed:
return parsed["response"]
except (json.JSONDecodeError, TypeError):
pass
return value

actions_text = " | ".join(
ACTION_MAP[action.type.lower()].format(
_extract_text(action.value) if action.value else ""
)
for action in response.actions # type: ignore
if action.type.lower() in ACTION_MAP
)
action_message = (
f"{self.agent_name}: {actions_text}"
if actions_text
else f"{self.agent_name}: (no response)"
)

action_message = action_message.replace("****", self.agent_name)

self.history_manager.history.append(
ChatMessage(role="assistant", content=action_message)
Expand Down
142 changes: 134 additions & 8 deletions tests/providers/test_llm_history_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import openai
import pytest

from providers.llm_history_manager import ChatMessage, LLMHistoryManager
from providers.llm_history_manager import ACTION_MAP, ChatMessage, LLMHistoryManager


@dataclass
Expand Down Expand Up @@ -57,7 +57,9 @@ async def test_summarize_messages_success(history_manager):
# Test successful summarization
result = await history_manager.summarize_messages(messages)
assert result.role == "assistant"
assert "Previously, This is a test summary" == result.content
assert result.content == (
"[Conversation summary - do not repeat] Previously, This is a test summary"
)


@pytest.mark.asyncio
Expand Down Expand Up @@ -204,9 +206,8 @@ async def process(self, prompt: str, messages: list) -> MagicMock:
# First message should be the inputs message
inputs_msg = history_manager.history[0]
assert inputs_msg.role == "user"
assert "audio_new" in inputs_msg.content
assert inputs_msg.content.startswith("User: ")
assert "User said goodbye" in inputs_msg.content
assert "lidar" in inputs_msg.content
assert "Detected obstacle" in inputs_msg.content

assert "User said hello" not in inputs_msg.content
Expand Down Expand Up @@ -256,7 +257,7 @@ async def process(self, prompt: str, messages: list) -> MagicMock:
# First message should be the inputs message with just the preamble
inputs_msg = history_manager.history[0]
assert inputs_msg.role == "user"
assert "TestBot sensed the following:" in inputs_msg.content
assert inputs_msg.content == "User: (no input)"
# Old inputs should not be included
assert "Old audio" not in inputs_msg.content

Expand Down Expand Up @@ -294,7 +295,6 @@ async def process(self, prompt: str, messages: list) -> MagicMock:

# Verify only tick 0 data in first cycle
first_inputs = history_manager.history[0]
assert "input_tick0" in first_inputs.content
assert "Data at tick 0" in first_inputs.content

# Tick 1: Increment and add new inputs
Expand All @@ -304,7 +304,6 @@ async def process(self, prompt: str, messages: list) -> MagicMock:

# Find the second input message (should be at index 2)
second_inputs = history_manager.history[2]
assert "input_tick1" in second_inputs.content
assert "Data at tick 1" in second_inputs.content
# Should NOT include tick 0 data
assert "Data at tick 0" not in second_inputs.content
Expand All @@ -316,7 +315,6 @@ async def process(self, prompt: str, messages: list) -> MagicMock:

# Find the third input message (should be at index 4)
third_inputs = history_manager.history[4]
assert "input_tick2" in third_inputs.content
assert "Data at tick 2" in third_inputs.content
# Should NOT include previous tick data
assert "Data at tick 0" not in third_inputs.content
Expand Down Expand Up @@ -611,3 +609,131 @@ def test_get_messages_multiple_roles(history_manager):
assert result[0] == {"role": "system", "content": "You are a robot"}
assert result[1] == {"role": "user", "content": "Hello"}
assert result[2] == {"role": "assistant", "content": "Hi there"}


def test_action_map_includes_greeting_conversation():
    """Verify every expected action type is registered in ACTION_MAP."""
    expected_keys = ("greeting_conversation", "emotion", "speak", "move")
    for expected in expected_keys:
        assert expected in ACTION_MAP


def test_action_map_uses_simple_format():
    """Every ACTION_MAP template must be the plain '{}' placeholder."""
    items = list(ACTION_MAP.items())
    for key, fmt in items:
        assert fmt == "{}", f"Expected '{{}}' for {key}, got '{fmt}'"


@pytest.mark.asyncio
async def test_update_history_extracts_json_response():
    """Test that action values with JSON-wrapped response field are extracted."""
    # Minimal config stub: only the attributes the history manager reads
    # (model, history_length, agent_name) are given concrete values.
    config = MagicMock()
    config.model = "gpt-4o"
    config.history_length = 5
    config.agent_name = "TestBot"

    client = AsyncMock()
    history_manager = LLMHistoryManager(config, client)

    class MockLLMProvider:
        # Stand-in provider exposing the attributes the update_history
        # decorator expects to find on `self`.
        def __init__(self):
            self._config = config
            self._skip_state_management = False
            self.history_manager = history_manager
            self.io_provider = history_manager.io_provider
            self.agent_name = config.agent_name

        @LLMHistoryManager.update_history()
        async def process(self, prompt: str, messages: list) -> MagicMock:
            # Return a single action whose value is a JSON object carrying
            # a "response" field, mimicking a structured LLM reply.
            response = MagicMock()
            response.actions = [
                MockAction(
                    type="greeting_conversation",
                    value='{"response": "Hello there!"}',
                ),
            ]
            return response

    provider = MockLLMProvider()

    provider.io_provider.add_input("audio", "Hi", 1234.0)
    await provider.process("test prompt")

    # Expect exactly two history entries: the user input plus the
    # assistant action message appended by the decorator.
    assert len(history_manager.history) == 2

    action_msg = history_manager.history[1]
    assert action_msg.role == "assistant"
    # Should extract "Hello there!" from JSON, not show raw JSON
    assert "Hello there!" in action_msg.content
    assert "TestBot:" in action_msg.content


@pytest.mark.asyncio
async def test_update_history_agent_name_format():
    """Action messages must follow the '{agent_name}: {text}' layout."""
    cfg = MagicMock()
    cfg.model = "gpt-4o"
    cfg.history_length = 5
    cfg.agent_name = "Bits"

    manager = LLMHistoryManager(cfg, AsyncMock())

    class FakeProvider:
        # Bare-bones provider carrying the attributes the decorator reads.
        def __init__(self):
            self._config = cfg
            self._skip_state_management = False
            self.history_manager = manager
            self.io_provider = manager.io_provider
            self.agent_name = cfg.agent_name

        @LLMHistoryManager.update_history()
        async def process(self, prompt: str, messages: list) -> MagicMock:
            reply = MagicMock()
            reply.actions = [MockAction(type="speak", value="Welcome to GTC!")]
            return reply

    fake = FakeProvider()
    fake.io_provider.add_input("audio", "Hello", 1234.0)
    await fake.process("test prompt")

    recorded = manager.history[1]
    assert recorded.content == "Bits: Welcome to GTC!"


@pytest.mark.asyncio
async def test_update_history_skip_when_history_length_zero():
    """Test that history is skipped entirely when history_length is 0."""
    config = MagicMock()
    config.model = "gpt-4o"
    # history_length == 0 should disable history accumulation altogether.
    config.history_length = 0
    config.agent_name = "TestBot"

    client = AsyncMock()
    history_manager = LLMHistoryManager(config, client)

    class MockLLMProvider:
        # Minimal provider stub with the attributes the update_history
        # decorator reads from `self`.
        def __init__(self):
            self._config = config
            self._skip_state_management = False
            self.history_manager = history_manager
            self.io_provider = history_manager.io_provider
            self.agent_name = config.agent_name

        @LLMHistoryManager.update_history()
        async def process(self, prompt: str, messages: list) -> MagicMock:
            # messages should be empty list when history_length is 0
            assert messages == []
            response = MagicMock()
            response.actions = [MockAction(type="speak", value="Hello")]
            return response

    provider = MockLLMProvider()

    provider.io_provider.add_input("audio", "Test input", 1234.0)
    await provider.process("test prompt")

    # History should remain empty
    assert len(history_manager.history) == 0
Loading