Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 112 additions & 25 deletions camel/models/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2026 @ CAMEL-AI.org. All Rights Reserved. =========
import copy
import re
import textwrap
from typing import Type
from typing import Any, Dict, List, Optional, Type

from pydantic import BaseModel

Expand Down Expand Up @@ -64,39 +65,125 @@ def extract_thinking_from_content(
return content, reasoning_content


def try_modify_message_with_format(
message: OpenAIMessage,
def with_response_format_system_message(
messages: List[OpenAIMessage],
response_format: Type[BaseModel] | None,
) -> None:
r"""Modifies the content of the message to include the instruction of using
the response format.
) -> List[OpenAIMessage]:
r"""Return a request-scoped copy with format instructions in system.

The message will not be modified in the following cases:
- response_format is None
- message content is not a string
- message role is assistant

Args:
response_format (Type[BaseModel] | None): The Pydantic model class.
message (OpenAIMessage): The message to be modified.
The JSON-format instruction is treated as runtime policy rather than user
content, so it is merged into the first text system message when present,
or prepended as a new system message otherwise.
"""
if response_format is None:
return

if not isinstance(message["content"], str):
return
return messages

request_messages = copy.deepcopy(messages)
instruction = _format_instruction_for_response_format(response_format)

for message in request_messages:
content = message.get("content")
if message.get("role") == "system" and isinstance(content, str):
content = content.rstrip()
message["content"] = (
f"{content}\n\n{instruction}" if content else instruction
)
return request_messages

request_messages.insert(
0,
{
"role": "system",
"content": instruction,
},
)
return request_messages

if message["role"] == "assistant":
return

def _format_instruction_for_response_format(
response_format: Type[BaseModel],
) -> str:
r"""Build the text instruction used for prompt-based structured output."""
json_schema = response_format.model_json_schema()
updated_prompt = textwrap.dedent(
return textwrap.dedent(
f"""\
{message["content"]}

Please generate a JSON response adhering to the following JSON schema:
{json_schema}
Make sure the JSON response is valid and matches the EXACT structure defined in the schema. Your result should ONLY be a valid json object, WITHOUT ANY OTHER TEXT OR COMMENTS.
""" # noqa: E501
)
message["content"] = updated_prompt
).strip()


def pydantic_to_json_schema_response_format(
    response_format: Type[BaseModel],
) -> Dict[str, Any]:
    r"""Build a ``json_schema`` response_format payload from a Pydantic model.

    The result is shaped for ``chat.completions.create()``::

        {
            "type": "json_schema",
            "json_schema": {
                "name": "<ModelClassName>",
                "schema": { ... }
            }
        }

    Args:
        response_format (Type[BaseModel]): The Pydantic model class.

    Returns:
        Dict[str, Any]: The response_format dict for the API call.
    """
    schema = response_format.model_json_schema()
    # Strict backends reject object schemas missing additionalProperties,
    # so normalize the whole schema tree in place before sending it.
    _enforce_object_additional_properties_false(schema)
    payload: Dict[str, Any] = {
        "name": response_format.__name__,
        "schema": schema,
    }
    return {"type": "json_schema", "json_schema": payload}


def _enforce_object_additional_properties_false(schema: Any) -> None:
    r"""Recursively set ``additionalProperties: False`` on object schemas.

    OpenAI-compatible structured-output backends frequently reject object
    schemas that omit ``additionalProperties``. Walking every nested dict
    and list keeps the json_schema fallback usable for nested Pydantic
    models.
    """
    # Iterative traversal over nested dicts/lists; equivalent to the
    # recursive walk but bounded by an explicit work stack.
    pending = [schema]
    while pending:
        node = pending.pop()
        if isinstance(node, dict):
            if node.get("type") == "object":
                # Only fills the key in when absent (never overwrites).
                node.setdefault("additionalProperties", False)
            pending.extend(node.values())
        elif isinstance(node, list):
            pending.extend(node)


def parse_json_response_to_pydantic(
    content: Optional[str],
    response_format: Type[BaseModel],
) -> Optional[BaseModel]:
    r"""Validate a model-produced JSON string against a Pydantic class.

    Args:
        content (Optional[str]): The raw JSON string from the model response.
        response_format (Type[BaseModel]): The Pydantic model class to
            validate against.

    Returns:
        Optional[BaseModel]: The validated Pydantic instance, or ``None``
            if *content* is ``None`` or empty.
    """
    # Empty or missing content short-circuits to None instead of raising
    # a validation error.
    if content:
        return response_format.model_validate_json(content)
    return None
12 changes: 6 additions & 6 deletions camel/models/azure_openai_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class AzureOpenAIModel(BaseModelBackend):
creating a new one. Useful for RL frameworks like AReaL or rLLM
that provide Azure OpenAI-compatible clients. The client should
implement the AzureOpenAI client interface with
`.chat.completions.create()` and `.beta.chat.completions.parse()`
`.chat.completions.create()` and `.chat.completions.parse()`
methods. (default: :obj:`None`)
async_client (Optional[Any], optional): A custom asynchronous
AzureOpenAI client instance. If provided, this client will be
Expand Down Expand Up @@ -380,7 +380,7 @@ def _request_parse(
request_config.pop("stream", None)

return self._call_client(
self._client.beta.chat.completions.parse,
self._client.chat.completions.parse,
messages=messages,
model=str(self.model_type),
**request_config,
Expand All @@ -399,7 +399,7 @@ async def _arequest_parse(
request_config.pop("stream", None)

return await self._acall_client(
self._async_client.beta.chat.completions.parse,
self._async_client.chat.completions.parse,
messages=messages,
model=str(self.model_type),
**request_config,
Expand All @@ -413,7 +413,7 @@ def _request_stream_parse(
) -> ChatCompletionStreamManager[BaseModel]:
r"""Request streaming structured output parsing.

Note: This uses OpenAI's beta streaming API for structured outputs.
Note: This uses OpenAI's streaming API for structured outputs.
"""
request_config = self._prepare_request_config(tools)
# Remove stream from config as it's handled by the stream method
Expand All @@ -436,14 +436,14 @@ async def _arequest_stream_parse(
) -> AsyncChatCompletionStreamManager[BaseModel]:
r"""Request async streaming structured output parsing.

Note: This uses OpenAI's beta streaming API for structured outputs.
Note: This uses OpenAI's streaming API for structured outputs.
"""
request_config = self._prepare_request_config(tools)
# Remove stream from config as it's handled by the stream method
request_config.pop("stream", None)

# Use the beta streaming API for structured outputs
return self._call_client(
return await self._acall_client(
self._async_client.beta.chat.completions.stream,
messages=messages,
model=str(self.model_type),
Expand Down
5 changes: 4 additions & 1 deletion camel/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,10 @@ async def _acall_client(self, call: Any, *args: Any, **kwargs: Any) -> Any:
)
if self._should_sync_request_log(normalized_kwargs):
self._sync_request_log_with_client_kwargs(normalized_kwargs)
return await call(*args, **kwargs)
result = call(*args, **kwargs)
if inspect.isawaitable(result):
return await result
return result

def _should_sync_request_log(self, kwargs: Dict[str, Any]) -> bool:
r"""Check whether a client call should sync request logs."""
Expand Down
25 changes: 16 additions & 9 deletions camel/models/cohere_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from camel.configs import CohereConfig
from camel.messages import OpenAIMessage
from camel.models import BaseModelBackend
from camel.models._utils import try_modify_message_with_format
from camel.models._utils import with_response_format_system_message
from camel.types import ChatCompletion, ModelType
from camel.utils import (
BaseTokenCounter,
Expand Down Expand Up @@ -273,10 +273,11 @@ def _prepare_request(
messages: List[OpenAIMessage],
response_format: Optional[Type[BaseModel]] = None,
tools: Optional[List[Dict[str, Any]]] = None,
) -> Dict[str, Any]:
) -> tuple[List[OpenAIMessage], Dict[str, Any]]:
import copy

request_config = copy.deepcopy(self.model_config_dict)
request_messages = messages
# Remove strict from each tool's function parameters since Cohere does
# not support them
if tools:
Expand All @@ -285,10 +286,16 @@ def _prepare_request(
function_dict.pop("strict", None)
request_config["tools"] = tools
elif response_format:
try_modify_message_with_format(messages[-1], response_format)
request_config["response_format"] = {"type": "json_object"}
request_messages = with_response_format_system_message(
messages, response_format
)
schema = response_format.model_json_schema()
request_config["response_format"] = {
"type": "json_object",
"schema": schema,
}

return request_config
return request_messages, request_config

@observe(as_type="generation")
def _run(
Expand Down Expand Up @@ -317,11 +324,11 @@ def _run(

from cohere.core.api_error import ApiError

request_config = self._prepare_request(
request_messages, request_config = self._prepare_request(
messages, response_format, tools
)

cohere_messages = self._to_cohere_chatmessage(messages)
cohere_messages = self._to_cohere_chatmessage(request_messages)

try:
response = self._call_client(
Expand Down Expand Up @@ -387,11 +394,11 @@ async def _arun(

from cohere.core.api_error import ApiError

request_config = self._prepare_request(
request_messages, request_config = self._prepare_request(
messages, response_format, tools
)

cohere_messages = self._to_cohere_chatmessage(messages)
cohere_messages = self._to_cohere_chatmessage(request_messages)

try:
response = await self._acall_client(
Expand Down
Loading
Loading