Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
256 changes: 182 additions & 74 deletions camel/societies/workforce/single_agent_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
Dict,
List,
Optional,
Tuple,
)

from colorama import Fore
Expand All @@ -42,7 +43,6 @@
WorkflowMemoryManager,
)
from camel.tasks.task import Task, TaskState, is_task_result_insufficient
from camel.utils import consume_response_content_async
from camel.utils.context_utils import ContextUtility

if TYPE_CHECKING:
Expand Down Expand Up @@ -151,6 +151,7 @@ async def return_agent(self, agent: ChatAgent) -> None:
# Only add back to pool if under max size
if len(self._available_agents) < self.max_size:
agent.reset()
agent._execution_context = {}
self._agent_last_used[agent_id] = time.time()
self._available_agents.append(agent)
# Notify one waiting coroutine that an agent is available
Expand Down Expand Up @@ -346,6 +347,154 @@ def _get_workflow_manager(self) -> WorkflowMemoryManager:
)
return self._workflow_manager

@staticmethod
def _normalize_stream_contents(contents: List[str]) -> str:
if not contents:
return ""
if len(contents) == 1:
return contents[0]
is_cumulative = all(
len(contents[index + 1]) > len(contents[index])
and contents[index + 1].startswith(contents[index])
for index in range(len(contents) - 1)
)
if is_cumulative:
return contents[-1]
return "".join(contents)

async def _consume_worker_response(
    self,
    response: Any,
    stream_callback: Optional[
        Callable[["ChatAgentResponse"], Optional[Awaitable[None]]]
    ] = None,
) -> Tuple[Any, str, Optional[TaskResult]]:
    """Normalize a worker reply into (final_response, content, parsed).

    Handles both streaming and non-streaming replies:

    * For an ``AsyncStreamingChatAgentResponse``, every chunk is
      drained; ``stream_callback`` (sync or async) is invoked per
      chunk, chunk contents are collected and collapsed via
      ``_normalize_stream_contents``, and the parsed payload from the
      latest chunk that carried one wins.
    * For a regular response, content and parsed payload are read
      directly off ``response.msg``.

    Args:
        response: Object returned by ``worker_agent.astep``.
        stream_callback: Optional per-chunk hook; may return an
            awaitable, which is awaited before the next chunk.

    Returns:
        Tuple of (final response object, normalized text content,
        ``TaskResult`` if a parsed payload was found, else ``None``).
    """
    if isinstance(response, AsyncStreamingChatAgentResponse):
        final_response = None
        parsed_result: Optional[TaskResult] = None
        contents: List[str] = []

        async for chunk in response:
            # The last chunk seen becomes the final response snapshot.
            final_response = chunk
            if stream_callback:
                maybe = stream_callback(chunk)
                if asyncio.iscoroutine(maybe):
                    await maybe
            if chunk.msg:
                if chunk.msg.content is not None:
                    contents.append(chunk.msg.content)
                if chunk.msg.parsed is not None:
                    parsed = chunk.msg.parsed
                    if isinstance(parsed, TaskResult):
                        parsed_result = parsed
                    else:
                        # Coerce non-TaskResult payloads (e.g. dicts)
                        # through pydantic validation.
                        parsed_result = TaskResult.model_validate(parsed)

        if final_response is None:
            # Stream yielded no chunks: synthesize an empty response so
            # callers always receive a ChatAgentResponse-shaped object.
            from camel.responses import ChatAgentResponse

            final_response = ChatAgentResponse(
                msgs=[], terminated=False, info={}
            )
        return (
            final_response,
            self._normalize_stream_contents(contents),
            parsed_result,
        )

    # Non-streaming path: read content/parsed straight off the message.
    content = response.msg.content if response.msg else ""
    parsed_result = None
    if response.msg and response.msg.parsed is not None:
        parsed = response.msg.parsed
        if isinstance(parsed, TaskResult):
            parsed_result = parsed
        else:
            parsed_result = TaskResult.model_validate(parsed)
    return response, content, parsed_result

def _build_execution_context(
self, task: Task, worker_agent: ChatAgent
) -> Dict[str, Any]:
existing_attempts = (
len(task.additional_info.get("worker_attempts", []))
if task.additional_info
else 0
)
return {
"task_id": task.id,
"parent_task_id": task.parent.id if task.parent else None,
"worker_node_id": self.node_id,
"worker_description": self.description,
"original_worker_id": getattr(
self.worker, "agent_id", self.worker.role_name
),
"borrowed_agent_id": getattr(
worker_agent, "agent_id", worker_agent.role_name
),
"agent_id": getattr(
worker_agent, "agent_id", worker_agent.role_name
),
"attempt": existing_attempts + 1,
}

def _attach_execution_context(
self, task: Task, worker_agent: ChatAgent
) -> Dict[str, Any]:
execution_context = self._build_execution_context(task, worker_agent)
worker_agent._execution_context = execution_context.copy()

if task.additional_info is None:
task.additional_info = {}
task.additional_info["execution_context"] = execution_context.copy()
return execution_context

def _record_worker_attempt(
self,
task: Task,
worker_agent: ChatAgent,
response_content: str,
final_response: Any,
token_usage: Dict[str, Any],
execution_context: Dict[str, Any],
) -> None:
if task.additional_info is None:
task.additional_info = {}

tool_calls = (
final_response.info.get("tool_calls", [])
if hasattr(final_response, "info")
else []
)
tool_calls_summary = str(tool_calls)[:200]
worker_attempt_details = {
"attempt": execution_context["attempt"],
"agent_id": getattr(
worker_agent, "agent_id", worker_agent.role_name
),
"original_worker_id": getattr(
self.worker, "agent_id", self.worker.role_name
),
"worker_node_id": self.node_id,
"timestamp": str(datetime.datetime.now()),
"description": f"Attempt by "
f"{getattr(worker_agent, 'agent_id', worker_agent.role_name)} "
f"(from pool/clone of "
f"{getattr(self.worker, 'agent_id', self.worker.role_name)}) "
f"to process task: {task.content}",
"execution_context": execution_context.copy(),
"response_content_summary": response_content[:200],
"tool_calls_summary": tool_calls_summary,
"tool_call_count": (
len(tool_calls) if isinstance(tool_calls, list) else 0
),
"token_usage": token_usage.copy(),
}

task.additional_info.setdefault("worker_attempts", []).append(
worker_attempt_details
)
task.additional_info["token_usage"] = token_usage.copy()

async def _process_task(
self,
task: Task,
Expand Down Expand Up @@ -375,6 +524,7 @@ async def _process_task(
"""
# Get agent efficiently (from pool or by cloning)
worker_agent = await self._get_worker_agent()
execution_context = self._attach_execution_context(task, worker_agent)
response_content = ""

try:
Expand Down Expand Up @@ -410,13 +560,16 @@ async def _process_task(
)
response = await worker_agent.astep(enhanced_prompt)

# Handle streaming response
# Normalize streamed vs non-streamed content for logging and
# structured parsing.
(
response,
final_response,
response_content,
) = await consume_response_content_async(
_,
) = await self._consume_worker_response(
response, stream_callback
)
response = final_response
task_result = (
self.structured_handler.parse_structured_response(
response_text=response_content,
Expand All @@ -432,46 +585,27 @@ async def _process_task(
response = await worker_agent.astep(
prompt, response_format=TaskResult
)
(
final_response,
response_content,
parsed_result,
) = await self._consume_worker_response(
response, stream_callback
)
response = final_response

# Handle streaming response for native output
if isinstance(response, AsyncStreamingChatAgentResponse):
task_result = None
async for chunk in response:
if stream_callback:
maybe = stream_callback(chunk)
if asyncio.iscoroutine(maybe):
await maybe
if chunk.msg and chunk.msg.parsed:
task_result = chunk.msg.parsed
response_content = chunk.msg.content
# If no parsed result found in streaming, create fallback
if task_result is None:
task_result = TaskResult(
content="Failed to parse streaming response",
failed=True,
)
else:
# Regular ChatAgentResponse
task_result = response.msg.parsed
response_content = (
response.msg.content if response.msg else ""
if parsed_result is None:
task_result = TaskResult(
content="Failed to parse streaming response",
failed=True,
)
else:
task_result = parsed_result

# Get token usage from the response
if isinstance(response, AsyncStreamingChatAgentResponse):
# For streaming responses, get the final response info
final_response = await response
usage_info = final_response.info.get(
"usage"
) or final_response.info.get("token_usage")
else:
final_response = response
usage_info = response.info.get("usage") or response.info.get(
"token_usage"
)
total_tokens = (
usage_info.get("total_tokens", 0) if usage_info else 0
usage_info = response.info.get("usage") or response.info.get(
"token_usage"
)
token_usage = usage_info or {"total_tokens": 0}

# collect conversation from working agent to
# accumulator for workflow memory
Expand Down Expand Up @@ -511,40 +645,14 @@ async def _process_task(
# Return agent to pool or let it be garbage collected
await self._return_worker_agent(worker_agent)

# Populate additional_info with worker attempt details
if task.additional_info is None:
task.additional_info = {}

# Create worker attempt details with descriptive keys
worker_attempt_details = {
"agent_id": getattr(
worker_agent, "agent_id", worker_agent.role_name
),
"original_worker_id": getattr(
self.worker, "agent_id", self.worker.role_name
),
"timestamp": str(datetime.datetime.now()),
"description": f"Attempt by "
f"{getattr(worker_agent, 'agent_id', worker_agent.role_name)} "
f"(from pool/clone of "
f"{getattr(self.worker, 'agent_id', self.worker.role_name)}) "
f"to process task: {task.content}",
"response_content": response_content[:50],
"tool_calls": str(
final_response.info.get("tool_calls")
if isinstance(response, AsyncStreamingChatAgentResponse)
else response.info.get("tool_calls")
)[:50],
"total_tokens": total_tokens,
}

# Store the worker attempt in additional_info
if "worker_attempts" not in task.additional_info:
task.additional_info["worker_attempts"] = []
task.additional_info["worker_attempts"].append(worker_attempt_details)

# Store the actual token usage for this specific task
task.additional_info["token_usage"] = {"total_tokens": total_tokens}
self._record_worker_attempt(
task=task,
worker_agent=worker_agent,
response_content=response_content,
final_response=response,
token_usage=token_usage,
execution_context=execution_context,
)

print(f"======\n{Fore.GREEN}Response from {self}:{Fore.RESET}")
logger.info(f"Response from {self}:")
Expand Down
Loading
Loading