Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 30 additions & 20 deletions src/matilda_brain/tools/builtins/code.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,18 @@
"""

import ast
import asyncio
import math
import operator
import os
import shutil
import subprocess
import tempfile
from typing import Any, Callable, Optional

from matilda_brain.tools import tool

from .config import _get_code_timeout, _get_timeout_bounds, _safe_execute
from .config import _get_code_timeout, _get_timeout_bounds, _safe_execute_async

# Allowed math functions and constants
ALLOWED_MATH_NAMES = {
Expand Down Expand Up @@ -160,7 +162,7 @@ def visit_Call(self, node: ast.Call) -> Any:


@tool(category="code", description="Execute Python code safely in a sandboxed environment")
def run_python(code: str, timeout: Optional[int] = None) -> str:
async def run_python(code: str, timeout: Optional[int] = None) -> str:
"""Execute Python code safely.

Args:
Expand All @@ -171,7 +173,7 @@ def run_python(code: str, timeout: Optional[int] = None) -> str:
Output of the code execution or error message
"""

def _run_python_impl(code: str, timeout: Optional[int] = None) -> str:
async def _run_python_impl(code: str, timeout: Optional[int] = None) -> str:
if timeout is None:
timeout = _get_code_timeout()

Expand All @@ -191,34 +193,42 @@ def _run_python_impl(code: str, timeout: Optional[int] = None) -> str:
try:
# Run code in subprocess with timeout
# Try python3 first, then python
python_cmd = (
"python3" if subprocess.run(["which", "python3"], capture_output=True).returncode == 0 else "python"
)
python_cmd = shutil.which("python3") or "python"

result = subprocess.run(
[python_cmd, temp_file],
capture_output=True,
text=True,
timeout=timeout,
check=False,
proc = await asyncio.create_subprocess_exec(
python_cmd,
temp_file,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)

try:
stdout_data, stderr_data = await asyncio.wait_for(proc.communicate(), timeout=timeout)
except asyncio.TimeoutError:
proc.kill()
await proc.communicate()
stdout_data, stderr_data = b"", b"Execution timed out"

stdout = stdout_data.decode()
stderr = stderr_data.decode()

output = []
if result.stdout:
output.append(result.stdout)
if result.stderr:
output.append(f"Errors:\n{result.stderr}")
if stdout:
output.append(stdout)
if stderr:
output.append(f"Errors:\n{stderr}")

if result.returncode != 0:
output.append(f"Exit code: {result.returncode}")
if proc.returncode is not None and proc.returncode != 0:
output.append(f"Exit code: {proc.returncode}")

return "\n".join(output) if output else "Code executed successfully (no output)"

finally:
# Clean up
os.unlink(temp_file)
if os.path.exists(temp_file):
os.unlink(temp_file)

return _safe_execute("run_python", _run_python_impl, code=code, timeout=timeout)
return await _safe_execute_async("run_python", _run_python_impl, code=code, timeout=timeout)


@tool(category="math", description="Perform mathematical calculations safely")
Expand Down
106 changes: 67 additions & 39 deletions src/matilda_brain/tools/builtins/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
used across all built-in tools.
"""

from typing import Any, Callable, Dict
from typing import Any, Callable, Dict, Optional, Tuple

from matilda_brain.config.schema import get_config

Expand Down Expand Up @@ -89,53 +89,80 @@ def _get_timeout_bounds() -> tuple:
return (1, 30) # Fallback to constants values


def _sanitize_kwargs(kwargs: Dict[str, Any]) -> Tuple[Dict[str, Any], Optional[str]]:
"""Sanitize keyword arguments."""
sanitized_kwargs = {}
for key, value in kwargs.items():
if key in ["file_path", "path"] and isinstance(value, str):
try:
sanitized_kwargs[key] = str(InputSanitizer.sanitize_path(value))
except ValueError as e:
return {}, f"Error: Invalid path '{value}': {e}"
elif key in ["url"] and isinstance(value, str):
try:
sanitized_kwargs[key] = InputSanitizer.sanitize_url(value)
except ValueError as e:
return {}, f"Error: Invalid URL '{value}': {e}"
elif key in ["query", "code", "expression", "content"] and isinstance(value, str):
try:
# Allow code for these contexts
allow_code = key in ["code", "expression"]
sanitized_kwargs[key] = InputSanitizer.sanitize_string(value, allow_code=allow_code)
except ValueError as e:
return {}, f"Error: Invalid input '{key}': {e}"
else:
sanitized_kwargs[key] = value
return sanitized_kwargs, None


def _handle_error(func_name: str, e: Exception) -> str:
"""Handle exceptions with error recovery system."""
# Classify error and provide helpful message
error_pattern = recovery_system.classify_error(str(e))

# Create user-friendly error message
if error_pattern.error_type.value == "network_error":
return f"Network Error: {error_pattern.message}\n{error_pattern.suggested_action}"
elif error_pattern.error_type.value == "permission_error":
return f"Permission Error: {error_pattern.message}\n{error_pattern.suggested_action}"
elif error_pattern.error_type.value == "resource_error":
return f"Resource Error: {error_pattern.message}\n{error_pattern.suggested_action}"
elif error_pattern.error_type.value == "timeout_error":
return f"Timeout Error: {error_pattern.message}\n{error_pattern.suggested_action}"
elif error_pattern.error_type.value == "validation_error":
return f"Validation Error: {error_pattern.message}\n{error_pattern.suggested_action}"
else:
return f"Error in {func_name}: {str(e)}\n{error_pattern.suggested_action}"


def _safe_execute(func_name: str, func: Callable[..., Any], **kwargs: Any) -> str:
"""Execute a function with error recovery and input sanitization."""
try:
# Sanitize arguments
sanitized_kwargs = {}
for key, value in kwargs.items():
if key in ["file_path", "path"] and isinstance(value, str):
try:
sanitized_kwargs[key] = str(InputSanitizer.sanitize_path(value))
except ValueError as e:
return f"Error: Invalid path '{value}': {e}"
elif key in ["url"] and isinstance(value, str):
try:
sanitized_kwargs[key] = InputSanitizer.sanitize_url(value)
except ValueError as e:
return f"Error: Invalid URL '{value}': {e}"
elif key in ["query", "code", "expression", "content"] and isinstance(value, str):
try:
# Allow code for these contexts
allow_code = key in ["code", "expression"]
sanitized_kwargs[key] = InputSanitizer.sanitize_string(value, allow_code=allow_code)
except ValueError as e:
return f"Error: Invalid input '{key}': {e}"
else:
sanitized_kwargs[key] = value
sanitized_kwargs, error = _sanitize_kwargs(kwargs)
if error:
return error

# Execute with enhanced error handling
result = func(**sanitized_kwargs)
return str(result)

except Exception as e:
# Classify error and provide helpful message
error_pattern = recovery_system.classify_error(str(e))

# Create user-friendly error message
if error_pattern.error_type.value == "network_error":
return f"Network Error: {error_pattern.message}\n{error_pattern.suggested_action}"
elif error_pattern.error_type.value == "permission_error":
return f"Permission Error: {error_pattern.message}\n{error_pattern.suggested_action}"
elif error_pattern.error_type.value == "resource_error":
return f"Resource Error: {error_pattern.message}\n{error_pattern.suggested_action}"
elif error_pattern.error_type.value == "timeout_error":
return f"Timeout Error: {error_pattern.message}\n{error_pattern.suggested_action}"
elif error_pattern.error_type.value == "validation_error":
return f"Validation Error: {error_pattern.message}\n{error_pattern.suggested_action}"
else:
return f"Error in {func_name}: {str(e)}\n{error_pattern.suggested_action}"
return _handle_error(func_name, e)


async def _safe_execute_async(func_name: str, func: Callable[..., Any], **kwargs: Any) -> str:
"""Execute an async function with error recovery and input sanitization."""
try:
sanitized_kwargs, error = _sanitize_kwargs(kwargs)
if error:
return error

# Execute with enhanced error handling
result = await func(**sanitized_kwargs)
return str(result)

except Exception as e:
return _handle_error(func_name, e)


__all__ = [
Expand All @@ -144,5 +171,6 @@ def _safe_execute(func_name: str, func: Callable[..., Any], **kwargs: Any) -> st
"_get_web_timeout",
"_get_timeout_bounds",
"_safe_execute",
"_safe_execute_async",
"recovery_system",
]
4 changes: 4 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,14 @@
import time
from contextlib import contextmanager
from pathlib import Path
from unittest.mock import MagicMock

import pytest
from dotenv import load_dotenv

# Mock matilda_transport
sys.modules["matilda_transport"] = MagicMock()

# Add the parent directory and src directory to Python path for imports
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root / "src"))
Expand Down
20 changes: 12 additions & 8 deletions tests/test_tools_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,34 +192,38 @@ class TestCodeExecution:
"""Test code execution tool."""

@pytest.mark.unit
def test_run_python_executes_code_and_captures_output(self):
@pytest.mark.asyncio
async def test_run_python_executes_code_and_captures_output(self):
"""Test successful Python code execution."""
code = "print('Hello, World!')\nprint(2 + 2)"
result = run_python(code)
result = await run_python(code)

assert "Hello, World!" in result
assert "4" in result

@pytest.mark.unit
def test_run_python_error(self):
@pytest.mark.asyncio
async def test_run_python_error(self):
"""Test Python code with error."""
code = "print(undefined_variable)"
result = run_python(code)
result = await run_python(code)

assert "Error" in result.lower() or "NameError" in result

@pytest.mark.unit
def test_run_python_timeout(self):
@pytest.mark.asyncio
async def test_run_python_timeout(self):
"""Test Python code timeout."""
code = "import time\ntime.sleep(10)"
result = run_python(code, timeout=1)
result = await run_python(code, timeout=1)

assert "timed out" in result

@pytest.mark.unit
def test_run_python_empty_code(self):
@pytest.mark.asyncio
async def test_run_python_empty_code(self):
"""Test empty code."""
result = run_python("")
result = await run_python("")
assert "Code cannot be empty" in result


Expand Down