diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/state/_states.py b/python/packages/autogen-agentchat/src/autogen_agentchat/state/_states.py
index ecc7b5f7cae7..16fe7bee1e85 100644
--- a/python/packages/autogen-agentchat/src/autogen_agentchat/state/_states.py
+++ b/python/packages/autogen-agentchat/src/autogen_agentchat/state/_states.py
@@ -51,6 +51,7 @@ class SelectorManagerState(BaseGroupChatManagerState):
     """State for :class:`~autogen_agentchat.teams.SelectorGroupChat` manager."""
 
     previous_speaker: Optional[str] = Field(default=None)
+    previous_speakers: Optional[List[str]] = Field(default=None)
     type: str = Field(default="SelectorManagerState")
 
 
diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py
index 480dc6b71641..b35d4713570e 100644
--- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py
+++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py
@@ -38,8 +38,8 @@
 trace_logger = logging.getLogger(TRACE_LOGGER_NAME)
 
 
-SyncSelectorFunc = Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], str | None]
-AsyncSelectorFunc = Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], Awaitable[str | None]]
+SyncSelectorFunc = Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], str | list[str] | None]
+AsyncSelectorFunc = Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], Awaitable[str | list[str] | None]]
 SelectorFuncType = Union[SyncSelectorFunc | AsyncSelectorFunc]
 
 SyncCandidateFunc = Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], List[str]]
@@ -72,6 +72,7 @@ def __init__(
         emit_team_events: bool,
         model_context: ChatCompletionContext | None,
         model_client_streaming: bool = False,
+        max_concurrent_speakers: int = 1,
     ) -> None:
         super().__init__(
             name,
@@ -88,8 +89,9 @@ def __init__(
         )
         self._model_client = model_client
         self._selector_prompt = selector_prompt
-        self._previous_speaker: str | None = None
+        self._previous_speakers: list[str] = []
         self._allow_repeated_speaker = allow_repeated_speaker
+        self._max_concurrent_speakers = max_concurrent_speakers
         self._selector_func = selector_func
         self._is_selector_func_async = iscoroutinefunction(self._selector_func)
         self._max_selector_attempts = max_selector_attempts
@@ -111,13 +113,14 @@ async def reset(self) -> None:
         await self._model_context.clear()
         if self._termination_condition is not None:
             await self._termination_condition.reset()
-        self._previous_speaker = None
+        self._previous_speakers = []
 
     async def save_state(self) -> Mapping[str, Any]:
         state = SelectorManagerState(
             message_thread=[msg.dump() for msg in self._message_thread],
             current_turn=self._current_turn,
-            previous_speaker=self._previous_speaker,
+            previous_speaker=self._previous_speakers[0] if self._previous_speakers else None,
+            previous_speakers=self._previous_speakers if self._previous_speakers else None,
         )
         return state.model_dump()
 
@@ -128,7 +131,12 @@ async def load_state(self, state: Mapping[str, Any]) -> None:
             self._model_context, [msg for msg in self._message_thread if isinstance(msg, BaseChatMessage)]
         )
         self._current_turn = selector_state.current_turn
-        self._previous_speaker = selector_state.previous_speaker
+        if selector_state.previous_speakers is not None:
+            self._previous_speakers = list(selector_state.previous_speakers)
+        elif selector_state.previous_speaker is not None:
+            self._previous_speakers = [selector_state.previous_speaker]
+        else:
+            self._previous_speakers = []
 
     @staticmethod
     async def _add_messages_to_context(
@@ -150,12 +158,14 @@ async def update_message_thread(self, messages: Sequence[BaseAgentEvent | BaseCh
         await self._add_messages_to_context(self._model_context, base_chat_messages)
 
     async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str] | str:
-        """Selects the next speaker in a group chat using a ChatCompletion client,
-        with the selector function as override if it returns a speaker name.
+        """Selects the next speaker(s) in a group chat using a ChatCompletion client,
+        with the selector function as override if it returns speaker name(s).
 
         .. note::
 
-            This method always returns a single speaker name.
+            When ``max_concurrent_speakers`` is 1 (default), model-based selection returns a single speaker.
+            When ``max_concurrent_speakers`` > 1, it may return multiple speakers to run concurrently.
+            A selector function may return multiple speakers regardless of ``max_concurrent_speakers``.
 
         A key assumption is that the agent type is the same as the topic type, which we use as the agent name.
         """
@@ -168,13 +178,19 @@ async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage
                 sync_selector_func = cast(SyncSelectorFunc, self._selector_func)
                 speaker = sync_selector_func(thread)
             if speaker is not None:
-                if speaker not in self._participant_names:
-                    raise ValueError(
-                        f"Selector function returned an invalid speaker name: {speaker}. "
-                        f"Expected one of: {self._participant_names}."
-                    )
+                # Normalize to a fresh, order-preserving, de-duplicated list; never alias the caller's list.
+                speakers = [speaker] if isinstance(speaker, str) else list(dict.fromkeys(speaker))
+                if not speakers:
+                    raise ValueError("Selector function returned an empty list of speakers.")
+                for s in speakers:
+                    if s not in self._participant_names:
+                        raise ValueError(
+                            f"Selector function returned an invalid speaker name: {s}. "
+                            f"Expected one of: {self._participant_names}."
+                        )
                 # Skip the model based selection.
-                return [speaker]
+                self._previous_speakers = speakers
+                return speakers
 
         # Use the candidate function to filter participants if provided
         if self._candidate_func is not None:
@@ -192,27 +208,39 @@ async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage
                     f"Expected one of: {self._participant_names}."
                 )
         else:
-            # Construct the candidate agent list to be selected from, skip the previous speaker if not allowed.
-            if self._previous_speaker is not None and not self._allow_repeated_speaker:
-                participants = [p for p in self._participant_names if p != self._previous_speaker]
+            # Construct the candidate agent list to be selected from, skip previous speakers if not allowed.
+            if self._previous_speakers and not self._allow_repeated_speaker:
+                participants = [p for p in self._participant_names if p not in self._previous_speakers]
+                if not participants:
+                    # Every participant spoke in the previous round (possible with concurrent speakers);
+                    # fall back to the full roster instead of tripping the assertion below.
+                    participants = list(self._participant_names)
             else:
                 participants = list(self._participant_names)
 
         assert len(participants) > 0
 
         # Construct agent roles.
-        # Each agent sould appear on a single line.
+        # Each agent should appear on a single line.
         roles = ""
         for topic_type, description in zip(self._participant_names, self._participant_descriptions, strict=True):
             roles += re.sub(r"\s+", " ", f"{topic_type}: {description}").strip() + "\n"
         roles = roles.strip()
 
-        # Select the next speaker.
+        # Select the next speaker(s).
         if len(participants) > 1:
-            agent_name = await self._select_speaker(roles, participants, self._max_selector_attempts)
+            if self._max_concurrent_speakers > 1:
+                agent_names = await self._select_multiple_speakers(
+                    roles, participants, self._max_selector_attempts
+                )
+                self._previous_speakers = agent_names
+                trace_logger.debug(f"Selected speakers: {agent_names}")
+                return agent_names
+            else:
+                agent_name = await self._select_speaker(roles, participants, self._max_selector_attempts)
         else:
             agent_name = participants[0]
-        self._previous_speaker = agent_name
+        self._previous_speakers = [agent_name]
         trace_logger.debug(f"Selected speaker: {agent_name}")
         return [agent_name]
 
@@ -286,10 +314,11 @@ async def _select_speaker(self, roles: str, participants: List[str], max_attempt
             agent_name = list(mentions.keys())[0]
             if (
                 not self._allow_repeated_speaker
-                and self._previous_speaker is not None
-                and agent_name == self._previous_speaker
+                and self._previous_speakers
+                and agent_name in self._previous_speakers
+                and any(p not in self._previous_speakers for p in participants)
             ):
-                trace_logger.debug(f"Model selected the previous speaker: {agent_name} (attempt {num_attempts})")
+                trace_logger.debug(f"Model selected a previous speaker: {agent_name} (attempt {num_attempts})")
                 feedback = (
                     f"Repeated speaker is not allowed, please select a different name from: {str(participants)}."
                 )
@@ -299,14 +328,120 @@ async def _select_speaker(self, roles: str, participants: List[str], max_attempt
             trace_logger.debug(f"Model selected a valid name: {agent_name} (attempt {num_attempts})")
             return agent_name
 
-        if self._previous_speaker is not None:
-            trace_logger.warning(f"Model failed to select a speaker after {max_attempts}, using the previous speaker.")
-            return self._previous_speaker
+        if self._previous_speakers:
+            trace_logger.warning(f"Model failed to select a speaker after {max_attempts}, using a previous speaker.")
+            return self._previous_speakers[0]
         trace_logger.warning(
             f"Model failed to select a speaker after {max_attempts} and there was no previous speaker, using the first participant."
         )
         return participants[0]
 
+    async def _select_multiple_speakers(self, roles: str, participants: List[str], max_attempts: int) -> List[str]:
+        """Select up to ``max_concurrent_speakers`` speakers for the next turn using the model.
+
+        Args:
+            roles: One ``name: description`` line per participant, as built by :meth:`select_speaker`.
+            participants: Candidate names the model is asked to choose from.
+            max_attempts: Maximum number of model calls before falling back.
+
+        Returns:
+            A non-empty list of at most ``max_concurrent_speakers`` candidate names. If no attempt
+            yields a usable selection, the first previous speaker is returned when there is one,
+            otherwise the first candidate.
+        """
+        model_context_messages = await self._model_context.get_messages()
+        model_context_history = self.construct_message_history(model_context_messages)
+
+        concurrent_prompt = (
+            f"You are in a role play game. The following roles are available:\n{roles}.\n"
+            f"Read the following conversation. Then select up to {self._max_concurrent_speakers} roles "
+            f"from {str(participants)} to play next. These roles will respond concurrently.\n"
+            f"Return the selected role names separated by commas.\n\n"
+            f"{model_context_history}\n\n"
+            f"Read the above conversation. Then select up to {self._max_concurrent_speakers} roles "
+            f"from {str(participants)} to respond concurrently. Return only the role names, separated by commas."
+        )
+
+        select_speaker_messages: List[SystemMessage | UserMessage | AssistantMessage]
+        if ModelFamily.is_openai(self._model_client.model_info["family"]):
+            select_speaker_messages = [SystemMessage(content=concurrent_prompt)]
+        else:
+            select_speaker_messages = [UserMessage(content=concurrent_prompt, source="user")]
+
+        # Previous speakers are only off limits while some other candidate exists; otherwise
+        # every attempt would be rejected and the retries wasted.
+        exclude_previous = (
+            not self._allow_repeated_speaker
+            and bool(self._previous_speakers)
+            and any(p not in self._previous_speakers for p in participants)
+        )
+
+        num_attempts = 0
+        while num_attempts < max_attempts:
+            num_attempts += 1
+            if self._model_client_streaming:
+                chunk: CreateResult | str = ""
+                async for _chunk in self._model_client.create_stream(messages=select_speaker_messages):
+                    chunk = _chunk
+                    if self._emit_team_events:
+                        if isinstance(chunk, str):
+                            await self._output_message_queue.put(
+                                ModelClientStreamingChunkEvent(content=cast(str, _chunk), source=self._name)
+                            )
+                        else:
+                            assert isinstance(chunk, CreateResult)
+                            assert isinstance(chunk.content, str)
+                            await self._output_message_queue.put(
+                                SelectorEvent(content=chunk.content, source=self._name)
+                            )
+                assert isinstance(chunk, CreateResult)
+                response = chunk
+            else:
+                response = await self._model_client.create(messages=select_speaker_messages)
+            assert isinstance(response.content, str)
+            select_speaker_messages.append(AssistantMessage(content=response.content, source="selector"))
+            mentions = self._mentioned_agents(response.content, self._participant_names)
+            if len(mentions) == 0:
+                trace_logger.debug(f"Model failed to select valid names: {response.content} (attempt {num_attempts})")
+                feedback = f"No valid name was mentioned. Please select from: {str(participants)}."
+                select_speaker_messages.append(UserMessage(content=feedback, source="user"))
+                continue
+
+            # Filter to only valid candidates.
+            selected = [name for name in mentions if name in participants]
+            if not selected:
+                trace_logger.debug(f"Model selected non-candidate names: {list(mentions.keys())} (attempt {num_attempts})")
+                feedback = f"Please select from the candidates: {str(participants)}."
+                select_speaker_messages.append(UserMessage(content=feedback, source="user"))
+                continue
+
+            # Filter out previous speakers if not allowed.
+            if exclude_previous:
+                selected = [s for s in selected if s not in self._previous_speakers]
+                if not selected:
+                    trace_logger.debug(f"Model selected only previous speakers (attempt {num_attempts})")
+                    feedback = (
+                        f"Repeated speakers are not allowed, please select different names from: {str(participants)}."
+                    )
+                    select_speaker_messages.append(UserMessage(content=feedback, source="user"))
+                    continue
+
+            # Limit to max concurrent speakers.
+            selected = selected[: self._max_concurrent_speakers]
+            trace_logger.debug(f"Model selected valid names: {selected} (attempt {num_attempts})")
+            return selected
+
+        # Fallback: prefer a previous speaker, otherwise the first candidate.
+        if self._previous_speakers:
+            trace_logger.warning(
+                f"Model failed to select speakers after {max_attempts}, using a previous speaker."
+            )
+            return [self._previous_speakers[0]]
+        trace_logger.warning(
+            f"Model failed to select speakers after {max_attempts}, using the first participant."
+        )
+        return [participants[0]]
+
     def _mentioned_agents(self, message_content: str, agent_names: List[str]) -> Dict[str, int]:
         """Counts the number of times each agent is mentioned in the provided message content.
         Agent names will match under any of the following conditions (all case-sensitive):
@@ -357,6 +492,7 @@ class SelectorGroupChatConfig(BaseModel):
     emit_team_events: bool = False
     model_client_streaming: bool = False
     model_context: ComponentModel | None = None
+    max_concurrent_speakers: int = 1
 
 
 class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]):
@@ -397,10 +533,12 @@ class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]):
         max_selector_attempts (int, optional): The maximum number of attempts to select a speaker using the model. Defaults to 3.
             If the model fails to select a speaker after the maximum number of attempts, the previous speaker will be used if available,
             otherwise the first participant will be used.
-        selector_func (Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], str | None], Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], Awaitable[str | None]], optional): A custom selector
-            function that takes the conversation history and returns the name of the next speaker.
-            If provided, this function will be used to override the model to select the next speaker.
+        selector_func (Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], str | list[str] | None], Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], Awaitable[str | list[str] | None]], optional): A custom selector
+            function that takes the conversation history and returns the name(s) of the next speaker(s).
+            If provided, this function will be used to override the model to select the next speaker(s).
             If the function returns None, the model will be used to select the next speaker.
+            The function may return a list of speaker names to enable concurrent responses; the list must not be
+            empty and is not limited by ``max_concurrent_speakers``.
             NOTE: `selector_func` is not serializable and will be ignored during serialization and deserialization process.
         candidate_func (Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], List[str]], Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], Awaitable[List[str]]], optional): A custom function that takes the
             conversation history and returns a filtered list of candidates for the next speaker
@@ -411,6 +549,10 @@ class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]):
             Make sure your custom message types are subclasses of :class:`~autogen_agentchat.messages.BaseAgentEvent` or :class:`~autogen_agentchat.messages.BaseChatMessage`.
         emit_team_events (bool, optional): Whether to emit team events through :meth:`BaseGroupChat.run_stream`. Defaults to False.
         model_client_streaming (bool, optional): Whether to use streaming for the model client. (This is useful for reasoning models like QwQ). Defaults to False.
+        max_concurrent_speakers (int, optional): The maximum number of speakers the model may select to respond
+            concurrently in a single turn. Must be at least 1. Defaults to 1 (sequential speaker selection).
+            When set to a value greater than 1, model-based selection may pick multiple speakers to respond
+            concurrently. All selected speakers must respond before the next selection round begins.
         model_context (ChatCompletionContext | None, optional): The model context for storing and retrieving
             :class:`~autogen_core.models.LLMMessage`. It can be preloaded with initial messages. Messages stored in model context will be used for speaker selection.
             The initial messages will be cleared when the team is reset.
@@ -620,7 +762,10 @@ def __init__(
         emit_team_events: bool = False,
         model_client_streaming: bool = False,
         model_context: ChatCompletionContext | None = None,
+        max_concurrent_speakers: int = 1,
     ):
+        if max_concurrent_speakers < 1:
+            raise ValueError("max_concurrent_speakers must be at least 1.")
         super().__init__(
             name=name or self.DEFAULT_NAME,
             description=description or self.DEFAULT_DESCRIPTION,
@@ -644,6 +789,7 @@ def __init__(
         self._candidate_func = candidate_func
         self._model_client_streaming = model_client_streaming
         self._model_context = model_context
+        self._max_concurrent_speakers = max_concurrent_speakers
 
     def _create_group_chat_manager_factory(
         self,
@@ -678,6 +824,7 @@ def _create_group_chat_manager_factory(
             self._emit_team_events,
             self._model_context,
             self._model_client_streaming,
+            self._max_concurrent_speakers,
         )
 
     def _to_config(self) -> SelectorGroupChatConfig:
@@ -695,6 +842,7 @@ def _to_config(self) -> SelectorGroupChatConfig:
             emit_team_events=self._emit_team_events,
             model_client_streaming=self._model_client_streaming,
             model_context=self._model_context.dump_component() if self._model_context else None,
+            max_concurrent_speakers=self._max_concurrent_speakers,
         )
 
     @classmethod
@@ -727,4 +875,5 @@ def _from_config(cls, config: SelectorGroupChatConfig) -> Self:
             emit_team_events=config.emit_team_events,
             model_client_streaming=config.model_client_streaming,
             model_context=ChatCompletionContext.load_component(config.model_context) if config.model_context else None,
+            max_concurrent_speakers=config.max_concurrent_speakers,
         )
diff --git a/python/packages/autogen-agentchat/tests/test_group_chat.py b/python/packages/autogen-agentchat/tests/test_group_chat.py
index 3ded2e0c2e60..d01a99834f96 100644
--- a/python/packages/autogen-agentchat/tests/test_group_chat.py
+++ b/python/packages/autogen-agentchat/tests/test_group_chat.py
@@ -995,7 +995,7 @@ async def test_selector_group_chat_state(task: TaskType, runtime: AgentRuntime |
         SelectorGroupChatManager,  # pyright: ignore
     )  # pyright: ignore
     assert manager_1._message_thread == manager_2._message_thread  # pyright: ignore
-    assert manager_1._previous_speaker == manager_2._previous_speaker  # pyright: ignore
+    assert manager_1._previous_speakers == manager_2._previous_speakers  # pyright: ignore
 
 
 @pytest.mark.asyncio
@@ -1239,6 +1239,74 @@ def _candidate_func(messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> Lis
     )
 
 
+@pytest.mark.asyncio
+async def test_selector_group_chat_concurrent_speakers_with_selector_func(runtime: AgentRuntime | None) -> None:
+    """Test concurrent speakers using a selector function that returns multiple speaker names."""
+    model_client = ReplayChatCompletionClient(["agent1"])
+    agent1 = _EchoAgent("agent1", description="echo agent 1")
+    agent2 = _EchoAgent("agent2", description="echo agent 2")
+    agent3 = _EchoAgent("agent3", description="echo agent 3")
+
+    def _select_concurrent(messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> list[str] | str | None:
+        if len(messages) <= 1:
+            # First turn (only task message in thread): select agent1 and agent2 concurrently.
+            return ["agent1", "agent2"]
+        # After concurrent responses, select agent3 alone.
+        return "agent3"
+
+    termination = MaxMessageTermination(5)
+    team = SelectorGroupChat(
+        participants=[agent1, agent2, agent3],
+        model_client=model_client,
+        selector_func=_select_concurrent,
+        termination_condition=termination,
+        runtime=runtime,
+    )
+    result = await team.run(task="concurrent task")
+    # Messages: task, agent1, agent2 (concurrent), then agent3, then again based on selector.
+    assert len(result.messages) >= 4
+    assert isinstance(result.messages[0], TextMessage)
+    assert result.messages[0].content == "concurrent task"
+    # agent1 and agent2 should both appear in the first round (order may vary).
+    first_round_sources = {result.messages[1].source, result.messages[2].source}
+    assert first_round_sources == {"agent1", "agent2"}
+    # Next speaker should be agent3.
+    assert result.messages[3].source == "agent3"
+
+
+@pytest.mark.asyncio
+async def test_selector_group_chat_concurrent_speakers_with_model(runtime: AgentRuntime | None) -> None:
+    """Test concurrent speakers using model-based selection with max_concurrent_speakers > 1."""
+    # Model returns two names in first selection, then one name in subsequent selections.
+    model_client = ReplayChatCompletionClient(
+        ["agent1, agent2", "agent3", "agent1"],
+    )
+    agent1 = _EchoAgent("agent1", description="echo agent 1")
+    agent2 = _EchoAgent("agent2", description="echo agent 2")
+    agent3 = _EchoAgent("agent3", description="echo agent 3")
+
+    # MaxMessageTermination(3): task(1) + agent2_delta(2) + agent3_delta(3) = terminates
+    # Result will have 4 messages: task, agent1, agent2, agent3
+    termination = MaxMessageTermination(3)
+    team = SelectorGroupChat(
+        participants=[agent1, agent2, agent3],
+        model_client=model_client,
+        termination_condition=termination,
+        runtime=runtime,
+        max_concurrent_speakers=2,
+        allow_repeated_speaker=True,
+    )
+    result = await team.run(task="concurrent task")
+    assert len(result.messages) == 4
+    assert isinstance(result.messages[0], TextMessage)
+    assert result.messages[0].content == "concurrent task"
+    # First round: agent1 and agent2 concurrently.
+    first_round_sources = {result.messages[1].source, result.messages[2].source}
+    assert first_round_sources == {"agent1", "agent2"}
+    # Second round: agent3.
+    assert result.messages[3].source == "agent3"
+
+
 class _HandOffAgent(BaseChatAgent):
     def __init__(self, name: str, description: str, next_agent: str) -> None:
         super().__init__(name, description)