157 lines
5.5 KiB
Python
157 lines
5.5 KiB
Python
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
|
|
"""Tool registry and dispatcher for the CodeEditor agent loop.
|
|
Defines available tools and executes them against the file context manager."""
|
|
|
|
import logging
|
|
import time
|
|
import fnmatch
|
|
from typing import Dict, Any, List
|
|
|
|
from modules.features.codeeditor.datamodelCodeeditor import ToolResult
|
|
|
|
logger = logging.getLogger(__name__)


# Tool catalogue exposed to the model. formatToolDefinitions() renders it into
# the system prompt and dispatch() switches on each "name", so the two must be
# kept in sync. "parameters" maps argument name -> free-text type/usage spec.
TOOL_DEFINITIONS = [
    {
        "name": "read_file",
        "description": (
            "Read the full content of a single file by its fileId. "
            "Each returned line is prefixed with its 1-based line number."
        ),
        "parameters": {"fileId": "string (required)"},
    },
    {
        "name": "list_files",
        # The listing includes each file's id; say so, because read_file needs it.
        "description": (
            "List all available text files with metadata (name, id, size, mimeType); "
            "the id is the fileId expected by read_file. Optionally filter by glob pattern."
        ),
        "parameters": {"filter": "string (optional, glob pattern e.g. '*.py')"},
    },
    {
        "name": "search_files",
        "description": (
            "Search all file contents for a text query (case-insensitive). "
            "Returns matching lines with file name and line number."
        ),
        "parameters": {"query": "string (required)", "fileType": "string (optional, extension e.g. 'py')"},
    },
]
|
|
|
|
|
|
async def dispatch(toolName: str, toolArgs: Dict[str, Any], dbManagement) -> ToolResult:
    """Execute a tool by name and wrap its text output in a ToolResult.

    Args:
        toolName: Name of the tool to run ("read_file", "list_files" or "search_files").
        toolArgs: Arguments supplied for that tool; passed through unchanged.
        dbManagement: File store handed to the tool implementation.

    Returns:
        A ToolResult carrying the tool's text output, a success flag and the
        elapsed time in seconds. Unknown tools, tool-level validation errors
        (results starting with "Error:") and raised exceptions are all reported
        with success=False. Exceptions are logged and never propagate.
    """
    # perf_counter is monotonic: the measured duration cannot go negative or
    # jump when the system clock is adjusted during the call (time.time can).
    startTime = time.perf_counter()
    try:
        if toolName == "read_file":
            result = await _toolReadFile(toolArgs, dbManagement)
        elif toolName == "list_files":
            # The only synchronous tool implementation, so no await.
            result = _toolListFiles(toolArgs, dbManagement)
        elif toolName == "search_files":
            result = await _toolSearchFiles(toolArgs, dbManagement)
        else:
            return ToolResult(toolName=toolName, result=f"Unknown tool: {toolName}", success=False,
                              executionTime=time.perf_counter() - startTime)

        # The tool helpers report validation failures (missing argument, file not
        # found, undecodable data, ...) as "Error: ..." strings instead of raising.
        # Flag those as failures too, so success means the same as on the exception path.
        success = not result.startswith("Error:")
        return ToolResult(toolName=toolName, result=result, success=success,
                          executionTime=time.perf_counter() - startTime)
    except Exception as e:
        # Top-level boundary of the agent loop: log with traceback and turn the
        # exception into a failed result instead of propagating it.
        logger.exception("Tool %s failed: %s", toolName, e)
        return ToolResult(toolName=toolName, result=f"Error: {str(e)}", success=False,
                          executionTime=time.perf_counter() - startTime)
|
|
|
|
|
|
async def _toolReadFile(args: Dict[str, Any], dbManagement) -> str:
|
|
"""Read a single file's content."""
|
|
fileId = args.get("fileId", "")
|
|
if not fileId:
|
|
return "Error: fileId is required"
|
|
|
|
fileItem = dbManagement.getFile(fileId)
|
|
if not fileItem:
|
|
return f"Error: File {fileId} not found"
|
|
|
|
fileData = dbManagement.getFileData(fileId)
|
|
if not fileData:
|
|
return f"Error: No data for file {fileId}"
|
|
|
|
try:
|
|
content = fileData.decode("utf-8")
|
|
except UnicodeDecodeError:
|
|
return f"Error: File {fileItem.fileName} is not valid UTF-8"
|
|
|
|
lines = content.split("\n")
|
|
numbered = "\n".join([f"{i + 1}|{line}" for i, line in enumerate(lines)])
|
|
return f"--- FILE: {fileItem.fileName} (id: {fileId}) ---\n{numbered}\n--- END FILE ---"
|
|
|
|
|
|
def _toolListFiles(args: Dict[str, Any], dbManagement) -> str:
    """Describe the stored text files, one bullet line per file.

    Args:
        args: Tool arguments. The optional ``filter`` is a glob applied to file
            names with ``fnmatch.fnmatch``, so case sensitivity follows the
            host OS's normcase rules (e.g. ``"*.py"``).
        dbManagement: File store whose ``getAllFiles()`` yields items with
            ``id``, ``fileName``, ``fileSize`` and ``mimeType``.

    Returns:
        A header with the number of listed files followed by one line per file,
        or a short message when the store is empty or nothing matches.
    """
    from modules.features.codeeditor.datamodelCodeeditor import isTextFile

    pattern = args.get("filter", "")
    candidates = dbManagement.getAllFiles()
    if not candidates:
        return "No files found."

    def _wanted(item) -> bool:
        # Binary files are skipped first; the glob only applies when one was given.
        return isTextFile(item.mimeType, item.fileName) and (
            not pattern or fnmatch.fnmatch(item.fileName, pattern)
        )

    entries = [
        f"- {item.fileName} (id: {item.id}, size: {item.fileSize}B, type: {item.mimeType})"
        for item in candidates
        if _wanted(item)
    ]
    if entries:
        return f"Available files ({len(entries)}):\n" + "\n".join(entries)
    return "No matching text files found."
|
|
|
|
|
|
async def _toolSearchFiles(args: Dict[str, Any], dbManagement) -> str:
    """Case-insensitively search every text file for a substring.

    Args:
        args: Tool arguments. ``query`` (required) is the text to look for.
            ``fileType`` (optional) restricts the search to one file extension
            and may be given with or without a leading dot (``"py"`` or ``".py"``).
        dbManagement: File store providing ``getAllFiles()`` (items with ``id``,
            ``fileName`` and ``mimeType``) and ``getFileData(id)`` (raw bytes).

    Returns:
        Up to ``maxHits`` lines formatted as "<fileName>:<lineNumber>: <line>".
        A truncation note is appended only when further matches were left out.
        Otherwise returns a message when the query is missing, there are no
        files, or nothing matched.
    """
    from modules.features.codeeditor.datamodelCodeeditor import isTextFile

    query = args.get("query", "")
    if not query:
        return "Error: query is required"

    # Accept both "py" and ".py": the previous suffix test turned ".py" into
    # "..py" and silently matched nothing. `or ""` keeps a null fileType meaning
    # "no filter".
    fileType = (args.get("fileType") or "").lstrip(".")
    suffix = f".{fileType}" if fileType else ""

    allFiles = dbManagement.getAllFiles()
    if not allFiles:
        return "No files to search."

    maxHits = 50
    # casefold() is Python's Unicode-aware caseless form (e.g. "ß" matches "SS"),
    # stricter than lower() for case-insensitive comparison.
    needle = query.casefold()
    hits: List[str] = []
    truncated = False

    for f in allFiles:
        if not isTextFile(f.mimeType, f.fileName):
            continue
        if suffix and not f.fileName.endswith(suffix):
            continue

        fileData = dbManagement.getFileData(f.id)
        if not fileData:
            continue
        try:
            # "utf-8-sig" drops a leading BOM so it does not end up in the first hit.
            content = fileData.decode("utf-8-sig")
        except UnicodeDecodeError:
            continue  # not decodable as text after all; skip rather than fail the search

        for lineNum, line in enumerate(content.split("\n"), 1):
            if needle not in line.casefold():
                continue
            if len(hits) >= maxHits:
                # A match beyond the cap exists, so the output really is truncated.
                truncated = True
                break
            hits.append(f"{f.fileName}:{lineNum}: {line.strip()}")
        if truncated:
            break

    if not hits:
        return f"No matches found for '{query}'."
    result = f"Search results for '{query}' ({len(hits)} matches):\n" + "\n".join(hits)
    if truncated:
        result += f"\n... (truncated at {maxHits} matches)"
    return result
|
|
|
|
|
|
def formatToolDefinitions() -> str:
    """Render TOOL_DEFINITIONS as a Markdown bullet list for the system prompt.

    Each tool becomes "- **<name>**: <description>" followed by a line of the
    form " Parameters: {<arg>: <spec>, ...}". Tools are separated by newlines,
    and an empty catalogue yields an empty string.
    """
    def _render(tool: Dict[str, Any]) -> str:
        # The doubled braces emit literal "{" / "}" around the parameter list.
        paramText = ", ".join(f"{argName}: {spec}" for argName, spec in tool["parameters"].items())
        return f"- **{tool['name']}**: {tool['description']}\n Parameters: {{{paramText}}}"

    return "\n".join(_render(tool) for tool in TOOL_DEFINITIONS)
|