Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class SelectorManagerState(BaseGroupChatManagerState):
"""State for :class:`~autogen_agentchat.teams.SelectorGroupChat` manager."""

previous_speaker: Optional[str] = Field(default=None)
previous_speakers: Optional[List[str]] = Field(default=None)
type: str = Field(default="SelectorManagerState")


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@

trace_logger = logging.getLogger(TRACE_LOGGER_NAME)

SyncSelectorFunc = Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], str | None]
AsyncSelectorFunc = Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], Awaitable[str | None]]
SyncSelectorFunc = Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], str | list[str] | None]
AsyncSelectorFunc = Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], Awaitable[str | list[str] | None]]
SelectorFuncType = Union[SyncSelectorFunc | AsyncSelectorFunc]

SyncCandidateFunc = Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], List[str]]
Expand Down Expand Up @@ -72,6 +72,7 @@ def __init__(
emit_team_events: bool,
model_context: ChatCompletionContext | None,
model_client_streaming: bool = False,
max_concurrent_speakers: int = 1,
) -> None:
super().__init__(
name,
Expand All @@ -88,8 +89,9 @@ def __init__(
)
self._model_client = model_client
self._selector_prompt = selector_prompt
self._previous_speaker: str | None = None
self._previous_speakers: list[str] = []
self._allow_repeated_speaker = allow_repeated_speaker
self._max_concurrent_speakers = max_concurrent_speakers
self._selector_func = selector_func
self._is_selector_func_async = iscoroutinefunction(self._selector_func)
self._max_selector_attempts = max_selector_attempts
Expand All @@ -111,13 +113,14 @@ async def reset(self) -> None:
await self._model_context.clear()
if self._termination_condition is not None:
await self._termination_condition.reset()
self._previous_speaker = None
self._previous_speakers = []

async def save_state(self) -> Mapping[str, Any]:
    """Serialize this manager's progress into a plain mapping.

    The snapshot captures the message thread, the current turn counter, and
    the speakers selected in the last round. The legacy singular
    ``previous_speaker`` field is filled from the first entry of
    ``_previous_speakers`` so snapshots stay readable by older versions.
    """
    speakers = self._previous_speakers
    snapshot = SelectorManagerState(
        message_thread=[message.dump() for message in self._message_thread],
        current_turn=self._current_turn,
        previous_speaker=speakers[0] if speakers else None,
        previous_speakers=speakers if speakers else None,
    )
    return snapshot.model_dump()

Expand All @@ -128,7 +131,12 @@ async def load_state(self, state: Mapping[str, Any]) -> None:
self._model_context, [msg for msg in self._message_thread if isinstance(msg, BaseChatMessage)]
)
self._current_turn = selector_state.current_turn
self._previous_speaker = selector_state.previous_speaker
if selector_state.previous_speakers is not None:
self._previous_speakers = selector_state.previous_speakers
elif selector_state.previous_speaker is not None:
self._previous_speakers = [selector_state.previous_speaker]
else:
self._previous_speakers = []

@staticmethod
async def _add_messages_to_context(
Expand All @@ -150,12 +158,13 @@ async def update_message_thread(self, messages: Sequence[BaseAgentEvent | BaseCh
await self._add_messages_to_context(self._model_context, base_chat_messages)

async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str] | str:
"""Selects the next speaker in a group chat using a ChatCompletion client,
with the selector function as override if it returns a speaker name.
"""Selects the next speaker(s) in a group chat using a ChatCompletion client,
with the selector function as override if it returns speaker name(s).

.. note::

This method always returns a single speaker name.
When ``max_concurrent_speakers`` is 1 (default), this method returns a single speaker.
When ``max_concurrent_speakers`` > 1, it may return multiple speakers to run concurrently.

A key assumption is that the agent type is the same as the topic type, which we use as the agent name.
"""
Expand All @@ -168,13 +177,17 @@ async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage
sync_selector_func = cast(SyncSelectorFunc, self._selector_func)
speaker = sync_selector_func(thread)
if speaker is not None:
if speaker not in self._participant_names:
raise ValueError(
f"Selector function returned an invalid speaker name: {speaker}. "
f"Expected one of: {self._participant_names}."
)
# Normalize to list.
speakers = [speaker] if isinstance(speaker, str) else speaker
for s in speakers:
if s not in self._participant_names:
raise ValueError(
f"Selector function returned an invalid speaker name: {s}. "
f"Expected one of: {self._participant_names}."
)
# Skip the model based selection.
return [speaker]
self._previous_speakers = speakers
return speakers

# Use the candidate function to filter participants if provided
if self._candidate_func is not None:
Expand All @@ -192,27 +205,35 @@ async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage
f"Expected one of: {self._participant_names}."
)
else:
# Construct the candidate agent list to be selected from, skip the previous speaker if not allowed.
if self._previous_speaker is not None and not self._allow_repeated_speaker:
participants = [p for p in self._participant_names if p != self._previous_speaker]
# Construct the candidate agent list to be selected from, skip previous speakers if not allowed.
if self._previous_speakers and not self._allow_repeated_speaker:
participants = [p for p in self._participant_names if p not in self._previous_speakers]
else:
participants = list(self._participant_names)

assert len(participants) > 0

# Construct agent roles.
# Each agent sould appear on a single line.
# Each agent should appear on a single line.
roles = ""
for topic_type, description in zip(self._participant_names, self._participant_descriptions, strict=True):
roles += re.sub(r"\s+", " ", f"{topic_type}: {description}").strip() + "\n"
roles = roles.strip()

# Select the next speaker.
# Select the next speaker(s).
if len(participants) > 1:
agent_name = await self._select_speaker(roles, participants, self._max_selector_attempts)
if self._max_concurrent_speakers > 1:
agent_names = await self._select_multiple_speakers(
roles, participants, self._max_selector_attempts
)
self._previous_speakers = agent_names
trace_logger.debug(f"Selected speakers: {agent_names}")
return agent_names
else:
agent_name = await self._select_speaker(roles, participants, self._max_selector_attempts)
else:
agent_name = participants[0]
self._previous_speaker = agent_name
self._previous_speakers = [agent_name]
trace_logger.debug(f"Selected speaker: {agent_name}")
return [agent_name]

Expand Down Expand Up @@ -286,10 +307,10 @@ async def _select_speaker(self, roles: str, participants: List[str], max_attempt
agent_name = list(mentions.keys())[0]
if (
not self._allow_repeated_speaker
and self._previous_speaker is not None
and agent_name == self._previous_speaker
and self._previous_speakers
and agent_name in self._previous_speakers
):
trace_logger.debug(f"Model selected the previous speaker: {agent_name} (attempt {num_attempts})")
trace_logger.debug(f"Model selected a previous speaker: {agent_name} (attempt {num_attempts})")
feedback = (
f"Repeated speaker is not allowed, please select a different name from: {str(participants)}."
)
Expand All @@ -299,14 +320,101 @@ async def _select_speaker(self, roles: str, participants: List[str], max_attempt
trace_logger.debug(f"Model selected a valid name: {agent_name} (attempt {num_attempts})")
return agent_name

if self._previous_speaker is not None:
trace_logger.warning(f"Model failed to select a speaker after {max_attempts}, using the previous speaker.")
return self._previous_speaker
if self._previous_speakers:
trace_logger.warning(f"Model failed to select a speaker after {max_attempts}, using a previous speaker.")
return self._previous_speakers[0]
trace_logger.warning(
f"Model failed to select a speaker after {max_attempts} and there was no previous speaker, using the first participant."
)
return participants[0]

async def _select_multiple_speakers(self, roles: str, participants: List[str], max_attempts: int) -> List[str]:
    """Select multiple speakers concurrently using the model.

    Prompts the model to name up to ``self._max_concurrent_speakers`` roles
    from ``participants``, retrying with corrective feedback up to
    ``max_attempts`` times. Falls back to a previous speaker (or the first
    participant) if the model never produces a usable answer.

    Args:
        roles: Newline-separated "name: description" lines for all participants.
        participants: Candidate speaker names the model may choose from.
        max_attempts: Maximum number of model calls before falling back.

    Returns:
        A non-empty list of valid participant names.
    """
    # Build the selection prompt from the model context's message history.
    model_context_messages = await self._model_context.get_messages()
    model_context_history = self.construct_message_history(model_context_messages)

    # The instruction is stated both before and after the conversation so the
    # model sees it regardless of where it attends in a long history.
    concurrent_prompt = (
        f"You are in a role play game. The following roles are available:\n{roles}.\n"
        f"Read the following conversation. Then select up to {self._max_concurrent_speakers} roles "
        f"from {str(participants)} to play next. These roles will respond concurrently.\n"
        f"Return the selected role names separated by commas.\n\n"
        f"{model_context_history}\n\n"
        f"Read the above conversation. Then select up to {self._max_concurrent_speakers} roles "
        f"from {str(participants)} to respond concurrently. Return only the role names, separated by commas."
    )

    # OpenAI-family models take the prompt as a system message; others get it
    # as a user message (some providers reject system-only conversations).
    select_speaker_messages: List[SystemMessage | UserMessage | AssistantMessage]
    if ModelFamily.is_openai(self._model_client.model_info["family"]):
        select_speaker_messages = [SystemMessage(content=concurrent_prompt)]
    else:
        select_speaker_messages = [UserMessage(content=concurrent_prompt, source="user")]

    num_attempts = 0
    while num_attempts < max_attempts:
        num_attempts += 1
        if self._model_client_streaming:
            # Stream the response; string chunks are intermediate tokens and the
            # final CreateResult arrives as the last item of the stream.
            chunk: CreateResult | str = ""
            async for _chunk in self._model_client.create_stream(messages=select_speaker_messages):
                chunk = _chunk
                if self._emit_team_events:
                    if isinstance(chunk, str):
                        # Forward each token chunk to subscribers as it arrives.
                        await self._output_message_queue.put(
                            ModelClientStreamingChunkEvent(content=cast(str, _chunk), source=self._name)
                        )
                    else:
                        # Final result: emit the complete selector output.
                        assert isinstance(chunk, CreateResult)
                        assert isinstance(chunk.content, str)
                        await self._output_message_queue.put(
                            SelectorEvent(content=chunk.content, source=self._name)
                        )
            # The stream must end with a CreateResult; anything else is a client bug.
            assert isinstance(chunk, CreateResult)
            response = chunk
        else:
            response = await self._model_client.create(messages=select_speaker_messages)
        assert isinstance(response.content, str)
        # Keep the model's reply in the conversation so later feedback turns
        # have full context.
        select_speaker_messages.append(AssistantMessage(content=response.content, source="selector"))
        # NOTE(review): _mentioned_agents appears to return mentions keyed in
        # _participant_names order, not in the order the model listed them —
        # confirm whether truncation below should instead respect the model's
        # ranking.
        mentions = self._mentioned_agents(response.content, self._participant_names)
        if len(mentions) == 0:
            trace_logger.debug(f"Model failed to select valid names: {response.content} (attempt {num_attempts})")
            feedback = f"No valid name was mentioned. Please select from: {str(participants)}."
            select_speaker_messages.append(UserMessage(content=feedback, source="user"))
            continue

        # Filter to only valid candidates.
        selected = [name for name in mentions if name in participants]
        if not selected:
            trace_logger.debug(f"Model selected non-candidate names: {list(mentions.keys())} (attempt {num_attempts})")
            feedback = f"Please select from the candidates: {str(participants)}."
            select_speaker_messages.append(UserMessage(content=feedback, source="user"))
            continue

        # Filter out previous speakers if not allowed.
        if not self._allow_repeated_speaker and self._previous_speakers:
            selected = [s for s in selected if s not in self._previous_speakers]
            if not selected:
                trace_logger.debug(f"Model selected only previous speakers (attempt {num_attempts})")
                feedback = (
                    f"Repeated speakers are not allowed, please select different names from: {str(participants)}."
                )
                select_speaker_messages.append(UserMessage(content=feedback, source="user"))
                continue

        # Limit to max concurrent speakers.
        selected = selected[: self._max_concurrent_speakers]
        trace_logger.debug(f"Model selected valid names: {selected} (attempt {num_attempts})")
        return selected

    # Fallback: use first participant.
    # NOTE(review): this fallback can return a previous speaker even when
    # allow_repeated_speaker is False — mirrors the single-speaker path's
    # behavior; confirm that is intended.
    if self._previous_speakers:
        trace_logger.warning(
            f"Model failed to select speakers after {max_attempts}, using a previous speaker."
        )
        return [self._previous_speakers[0]]
    trace_logger.warning(
        f"Model failed to select speakers after {max_attempts}, using the first participant."
    )
    return [participants[0]]

def _mentioned_agents(self, message_content: str, agent_names: List[str]) -> Dict[str, int]:
"""Counts the number of times each agent is mentioned in the provided message content.
Agent names will match under any of the following conditions (all case-sensitive):
Expand Down Expand Up @@ -357,6 +465,7 @@ class SelectorGroupChatConfig(BaseModel):
emit_team_events: bool = False
model_client_streaming: bool = False
model_context: ComponentModel | None = None
max_concurrent_speakers: int = 1


class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]):
Expand Down Expand Up @@ -397,10 +506,11 @@ class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]):
max_selector_attempts (int, optional): The maximum number of attempts to select a speaker using the model. Defaults to 3.
If the model fails to select a speaker after the maximum number of attempts, the previous speaker will be used if available,
otherwise the first participant will be used.
selector_func (Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], str | None], Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], Awaitable[str | None]], optional): A custom selector
function that takes the conversation history and returns the name of the next speaker.
If provided, this function will be used to override the model to select the next speaker.
selector_func (Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], str | list[str] | None], Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], Awaitable[str | list[str] | None]], optional): A custom selector
function that takes the conversation history and returns the name of the next speaker(s).
If provided, this function will be used to override the model to select the next speaker(s).
If the function returns None, the model will be used to select the next speaker.
The function may return a list of speaker names to enable concurrent responses.
NOTE: `selector_func` is not serializable and will be ignored during serialization and deserialization process.
candidate_func (Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], List[str]], Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], Awaitable[List[str]]], optional):
A custom function that takes the conversation history and returns a filtered list of candidates for the next speaker
Expand All @@ -411,6 +521,10 @@ class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]):
Make sure your custom message types are subclasses of :class:`~autogen_agentchat.messages.BaseAgentEvent` or :class:`~autogen_agentchat.messages.BaseChatMessage`.
emit_team_events (bool, optional): Whether to emit team events through :meth:`BaseGroupChat.run_stream`. Defaults to False.
model_client_streaming (bool, optional): Whether to use streaming for the model client. (This is useful for reasoning models like QwQ). Defaults to False.
max_concurrent_speakers (int, optional): The maximum number of speakers that can respond concurrently in a single turn.
Defaults to 1 (sequential speaker selection). When set to a value greater than 1, the model or selector function
may select multiple speakers to respond concurrently. All selected speakers must respond before the next
selection round begins.
model_context (ChatCompletionContext | None, optional): The model context for storing and retrieving
:class:`~autogen_core.models.LLMMessage`. It can be preloaded with initial messages. Messages stored in model context will be used for speaker selection. The initial messages will be cleared when the team is reset.

Expand Down Expand Up @@ -620,7 +734,10 @@ def __init__(
emit_team_events: bool = False,
model_client_streaming: bool = False,
model_context: ChatCompletionContext | None = None,
max_concurrent_speakers: int = 1,
):
if max_concurrent_speakers < 1:
raise ValueError("max_concurrent_speakers must be at least 1.")
super().__init__(
name=name or self.DEFAULT_NAME,
description=description or self.DEFAULT_DESCRIPTION,
Expand All @@ -644,6 +761,7 @@ def __init__(
self._candidate_func = candidate_func
self._model_client_streaming = model_client_streaming
self._model_context = model_context
self._max_concurrent_speakers = max_concurrent_speakers

def _create_group_chat_manager_factory(
self,
Expand Down Expand Up @@ -678,6 +796,7 @@ def _create_group_chat_manager_factory(
self._emit_team_events,
self._model_context,
self._model_client_streaming,
self._max_concurrent_speakers,
)

def _to_config(self) -> SelectorGroupChatConfig:
Expand All @@ -695,6 +814,7 @@ def _to_config(self) -> SelectorGroupChatConfig:
emit_team_events=self._emit_team_events,
model_client_streaming=self._model_client_streaming,
model_context=self._model_context.dump_component() if self._model_context else None,
max_concurrent_speakers=self._max_concurrent_speakers,
)

@classmethod
Expand Down Expand Up @@ -727,4 +847,5 @@ def _from_config(cls, config: SelectorGroupChatConfig) -> Self:
emit_team_events=config.emit_team_events,
model_client_streaming=config.model_client_streaming,
model_context=ChatCompletionContext.load_component(config.model_context) if config.model_context else None,
max_concurrent_speakers=config.max_concurrent_speakers,
)
Loading