102 lines
3.5 KiB
Python
102 lines
3.5 KiB
Python
# Copyright (c) 2025 Patrick Motsch
|
|
# All rights reserved.
|
|
"""Sandboxed code execution for the AI agent executeCode tool."""
|
|
|
|
import logging
|
|
import sys
|
|
import io
|
|
import traceback
|
|
from typing import Dict, Any
|
|
|
|
logger = logging.getLogger(__name__)

# NOTE(security): these lists only shape the sandbox namespace; they are not an
# isolation boundary (attribute access is unrestricted in sandboxed code). Run
# untrusted code inside an OS-level sandbox as well.

# Modules sandboxed code may import. Names are matched exactly by _safeImport and
# are also pre-imported into the sandbox builtins by _buildRestrictedGlobals.
_PYTHON_ALLOWED_MODULES = frozenset({
    "math", "statistics", "json", "csv", "re", "datetime",
    "collections", "itertools", "functools", "decimal", "fractions",
    "random", "string", "textwrap", "operator", "copy",
})

# Builtins removed from the sandbox namespace (frozen so they cannot be mutated at runtime).
_PYTHON_BLOCKED_BUILTINS = frozenset({
    "open", "exec", "eval", "compile", "__import__", "globals", "locals",
    "getattr", "setattr", "delattr", "breakpoint", "exit", "quit",
    "input", "memoryview", "type",
    # vars() without an argument is equivalent to locals(), which is blocked above.
    "vars",
    # Interactive site helpers: help("modules") imports arbitrary installed packages
    # (bypassing the import allowlist), and help()/license() may block reading stdin.
    "help", "license", "credits", "copyright",
})

# Wall-clock budget for one executePython() call, in seconds.
_MAX_EXECUTION_TIME_S = 30
# Captured output longer than this many characters is cut and marked as truncated.
_MAX_OUTPUT_CHARS = 50000
|
|
|
|
|
|
def _safeImport(name, *args, **kwargs):
    """Restricted replacement for ``__import__`` installed in the sandbox builtins.

    Accepts the same arguments as the builtin ``__import__``; the ``import``
    statement passes ``name, globals, locals, fromlist, level`` positionally.
    Only absolute imports whose ``name`` is exactly one of
    ``_PYTHON_ALLOWED_MODULES`` are permitted. As a result, ``import collections.abc``
    is rejected, while ``from collections import abc`` (name ``"collections"``) is allowed.

    Raises:
        ImportError: for relative imports, or when ``name`` is not on the allowlist.
    """
    import builtins

    # `level` is __import__'s 4th optional positional argument (or a keyword).
    # A non-zero value means a relative import such as `from . import x`. Sandbox
    # globals carry no __name__/__package__, so such an import can never resolve.
    # Reject it up front with a clear message instead of letting the import
    # machinery fail with a confusing KeyError.
    level = kwargs.get("level", args[3] if len(args) > 3 else 0)
    if level:
        raise ImportError("Relative imports are not allowed in the sandbox")
    if name not in _PYTHON_ALLOWED_MODULES:
        raise ImportError(f"Module '{name}' is not allowed. Permitted: {', '.join(sorted(_PYTHON_ALLOWED_MODULES))}")
    # Delegate to the real import implementation. This is the same object the
    # original reached through __builtins__, whether that was a dict or a module.
    return builtins.__import__(name, *args, **kwargs)
|
|
|
|
|
|
def _buildRestrictedGlobals() -> Dict[str, Any]:
    """Build a fresh globals dict for exec() with a reduced builtins namespace.

    The builtins namespace is built as follows:

    * It keeps every public builtin except those in ``_PYTHON_BLOCKED_BUILTINS``.
      Names starting with ``_`` are dropped.
    * ``__import__`` is replaced by ``_safeImport`` so ``import`` statements honour
      the module allowlist.
    * ``__build_class__`` is re-added so ``class`` statements work.
    * ``__name__`` is set to ``"__sandbox__"``.
    * Every module in ``_PYTHON_ALLOWED_MODULES`` is pre-imported under its own
      name, so sandboxed code can use e.g. ``math`` without importing it.

    A new dict is returned on every call, so mutations by one run do not leak into the next.

    NOTE(security): this is a convenience restriction, not an isolation boundary.
    Attribute access is unrestricted, e.g. ``().__class__.__base__.__subclasses__()``,
    and ``__import__.__globals__`` exposes this module's globals, including ``sys``.
    Untrusted code must additionally be isolated at the OS level (separate process or container).
    """
    import builtins

    safeBuiltins: Dict[str, Any] = {
        name: getattr(builtins, name)
        for name in dir(builtins)
        if not name.startswith("_") and name not in _PYTHON_BLOCKED_BUILTINS
    }

    safeBuiltins["__import__"] = _safeImport
    # A `class` statement compiles to a call of builtins.__build_class__. The "_"
    # filter above drops it, which made every class definition in sandboxed code
    # fail with "NameError: __build_class__ not found".
    safeBuiltins["__build_class__"] = builtins.__build_class__
    safeBuiltins["__name__"] = "__sandbox__"
    safeBuiltins["__builtins__"] = safeBuiltins

    for modName in _PYTHON_ALLOWED_MODULES:
        try:
            safeBuiltins[modName] = __import__(modName)
        except ImportError:
            # Best-effort: skip a module missing from this interpreter, but leave a trace.
            logger.debug("Sandbox module %r is unavailable; not pre-importing it", modName)

    return {"__builtins__": safeBuiltins}
|
|
|
|
|
|
class _CappedOutput:
    """Write-only text sink that keeps at most ``limit`` characters.

    It is the default target of the sandbox's ``print``. Characters past the limit
    are dropped, so a runaway print loop cannot grow memory without bound. This
    matters because a timed-out worker thread keeps running.
    """

    def __init__(self, limit: int) -> None:
        self._parts = []  # kept chunks of text, in write order
        self._limit = limit
        self._remaining = limit
        self.truncated = False

    def write(self, text: str) -> int:
        if len(text) > self._remaining:
            self.truncated = True
        if self._remaining > 0:
            kept = text[: self._remaining]
            self._parts.append(kept)
            self._remaining -= len(kept)
        # Report the full length, as a real stream would, so print() is unaffected.
        return len(text)

    def flush(self) -> None:
        """No-op; present because print(..., flush=True) calls it."""

    def render(self) -> str:
        """Return the kept text, plus a truncation marker if anything was dropped."""
        text = "".join(self._parts)
        if self.truncated:
            text += f"\n... (truncated at {self._limit} chars)"
        return text


async def executePython(code: str) -> Dict[str, Any]:
    """Execute Python code in a restricted namespace on a worker thread.

    ``print`` output is captured in a per-call buffer capped at ``_MAX_OUTPUT_CHARS``.
    The process-wide ``sys.stdout``/``sys.stderr`` are not touched, so concurrent
    calls and other threads' output are unaffected.

    Args:
        code: Python source to execute as a module body.

    Returns:
        On success: ``{"success": True, "output": str}``.
        If compiling or running ``code`` raises anything, including ``SystemExit``:
        ``{"success": False, "error": "<Type>: <message>", "traceback": str, "output": str}``.
        Here ``output`` is whatever was printed before the failure.
        On timeout: ``{"success": False, "error": "Execution timed out after <N>s"}``.

    NOTE: Python threads cannot be killed. After a timeout the worker keeps running,
    and keeps occupying a default-executor slot, until the code finishes on its own.
    """
    import asyncio

    def _run() -> Dict[str, Any]:
        sink = _CappedOutput(_MAX_OUTPUT_CHARS)

        def sandboxPrint(*args, sep=None, end=None, file=None, flush=False):
            # Send print() to this call's sink unless the caller passed a file explicitly.
            print(*args, sep=sep, end=end, file=sink if file is None else file, flush=flush)

        restrictedGlobals = _buildRestrictedGlobals()
        restrictedGlobals["__builtins__"]["print"] = sandboxPrint

        # signal.SIGALRM cannot be used here: this runs in a thread-pool worker and
        # signal handlers only work on the main thread. The wall-clock limit is
        # enforced by asyncio.wait_for below.
        try:
            exec(compile(code, "<sandbox>", "exec"), restrictedGlobals)
        except BaseException as e:
            # BaseException on purpose: in a worker thread it can only come from the
            # sandboxed code (e.g. `raise SystemExit`), and it must be reported rather
            # than re-raised into the awaiting coroutine.
            return {
                "success": False,
                "error": f"{type(e).__name__}: {e}",
                "traceback": traceback.format_exc(),
                "output": sink.render(),
            }
        return {"success": True, "output": sink.render()}

    loop = asyncio.get_running_loop()
    try:
        # The timeout equals the limit quoted in the error message below.
        return await asyncio.wait_for(
            loop.run_in_executor(None, _run),
            timeout=float(_MAX_EXECUTION_TIME_S),
        )
    except asyncio.TimeoutError:
        return {"success": False, "error": f"Execution timed out after {_MAX_EXECUTION_TIME_S}s"}
|