143 lines
5.4 KiB
Python
143 lines
5.4 KiB
Python
# Copyright (c) 2025 Patrick Motsch
|
|
# All rights reserved.
|
|
"""Tool registry for the Agent service. Manages tool definitions and dispatch."""
|
|
|
|
import logging
|
|
import time
|
|
from typing import Dict, List, Any, Optional, Callable, Awaitable
|
|
|
|
from modules.serviceCenter.services.serviceAgent.datamodelAgent import (
|
|
ToolDefinition, ToolCallRequest, ToolResult
|
|
)
|
|
|
|
# Module-level logger; used by ToolRegistry for registration and dispatch-failure messages.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ToolRegistry:
    """Registry for agent tools. Handles registration, lookup, and dispatch.

    Each tool is stored as a ``ToolDefinition`` (metadata shown to the model)
    plus an async handler, both keyed by the tool's name.
    """

    def __init__(self):
        # Tool name -> definition metadata (description, parameters, flags).
        self._tools: Dict[str, ToolDefinition] = {}
        # Tool name -> async callable, invoked by dispatch() as handler(args, context).
        self._handlers: Dict[str, Callable[..., Awaitable[ToolResult]]] = {}

    def register(self, name: str, handler: Callable[..., Awaitable[ToolResult]],
                 description: str = "", parameters: Optional[Dict[str, Any]] = None,
                 readOnly: bool = False, featureType: Optional[str] = None):
        """Register a tool with its handler function.

        Builds a ``ToolDefinition`` from the given fields and stores it alongside
        ``handler``. Re-registering an existing name overwrites it and logs a warning.

        Args:
            name: Tool name; also the key used by ``dispatch``.
            handler: Async callable invoked as ``handler(args, context)``.
            description: Human-readable description of the tool.
            parameters: Parameter specification; ``None`` is stored as ``{}``.
            readOnly: Value reported by ``isReadOnly``.
            featureType: Optional feature scope. ``getTools`` treats ``None``
                as "available for every feature type".
        """
        if name in self._tools:
            logger.warning(f"Tool '{name}' already registered, overwriting")

        self._tools[name] = ToolDefinition(
            name=name,
            description=description,
            parameters=parameters or {},
            readOnly=readOnly,
            featureType=featureType
        )
        self._handlers[name] = handler
        logger.debug(f"Registered tool: {name} (readOnly={readOnly})")

    def registerFromDefinition(self, definition: ToolDefinition,
                               handler: Callable[..., Awaitable[ToolResult]]):
        """Register a tool from a pre-built ToolDefinition.

        Args:
            definition: Complete tool definition; ``definition.name`` is the key.
            handler: Async callable invoked as ``handler(args, context)``.
        """
        # Warn on overwrite, consistent with register(); previously this was silent.
        if definition.name in self._tools:
            logger.warning(f"Tool '{definition.name}' already registered, overwriting")

        self._tools[definition.name] = definition
        self._handlers[definition.name] = handler
        logger.debug(f"Registered tool: {definition.name} (readOnly={definition.readOnly})")

    def unregister(self, name: str):
        """Remove a tool from the registry. Unknown names are ignored."""
        self._tools.pop(name, None)
        self._handlers.pop(name, None)

    def getTools(self, toolSet: Optional[str] = None,
                 featureType: Optional[str] = None) -> List[ToolDefinition]:
        """Get available tools, optionally filtered by featureType.

        Args:
            toolSet: Accepted for interface compatibility but currently not used.
            featureType: If given, keep only tools whose ``featureType`` is
                ``None`` (global tools) or equal to this value.

        Returns:
            A new list of matching tool definitions, in registration order.
        """
        # NOTE(review): toolSet filtering is not implemented; confirm intended semantics.
        if not featureType:
            return list(self._tools.values())
        return [
            t for t in self._tools.values()
            if t.featureType is None or t.featureType == featureType
        ]

    def getToolNames(self) -> List[str]:
        """Get names of all registered tools, in registration order."""
        return list(self._tools.keys())

    def getTool(self, name: str) -> Optional[ToolDefinition]:
        """Get a single tool definition by name, or ``None`` if not registered."""
        return self._tools.get(name)

    def isReadOnly(self, name: str) -> bool:
        """Check if a tool is marked as readOnly; unknown tools report ``False``."""
        tool = self._tools.get(name)
        return tool.readOnly if tool else False

    def isValidTool(self, name: str) -> bool:
        """Check if a tool name is valid (registered)."""
        return name in self._tools

    async def dispatch(self, toolCall: ToolCallRequest,
                       context: Optional[Dict[str, Any]] = None) -> ToolResult:
        """Execute a tool call and return the result.

        The handler is awaited as ``handler(toolCall.args, context or {})``.
        A ``ToolResult`` returned by the handler is stamped with the call id and
        measured duration. Any other return value is wrapped as a successful
        ``ToolResult`` with ``data=str(value)``.

        Args:
            toolCall: The requested call (name, id, args).
            context: Optional context dict passed through to the handler.

        Returns:
            A ``ToolResult``. Unknown tools and exceptions raised by the handler
            are reported as ``success=False`` results rather than raised.
        """
        # perf_counter is monotonic, so durations cannot go negative or jump
        # when the system wall clock is adjusted (unlike time.time()).
        startTime = time.perf_counter()

        if not self.isValidTool(toolCall.name):
            return ToolResult(
                toolCallId=toolCall.id,
                toolName=toolCall.name,
                success=False,
                error=f"Unknown tool: '{toolCall.name}'. Available: {', '.join(self.getToolNames())}"
            )

        handler = self._handlers[toolCall.name]
        try:
            result = await handler(toolCall.args, context or {})
            durationMs = int((time.perf_counter() - startTime) * 1000)

            if isinstance(result, ToolResult):
                result.toolCallId = toolCall.id
                result.durationMs = durationMs
                return result

            return ToolResult(
                toolCallId=toolCall.id,
                toolName=toolCall.name,
                success=True,
                data=str(result),
                durationMs=durationMs
            )

        except Exception as e:
            # Tool failures are converted into error results so the agent loop
            # can continue; the full traceback is still logged.
            durationMs = int((time.perf_counter() - startTime) * 1000)
            logger.exception(f"Tool '{toolCall.name}' failed: {e}")
            return ToolResult(
                toolCallId=toolCall.id,
                toolName=toolCall.name,
                success=False,
                error=str(e),
                durationMs=durationMs
            )

    def formatToolsForPrompt(self) -> str:
        """Format all tools as text for system prompt (text-based fallback).

        Produces one markdown bullet per tool, followed by a line listing its
        parameters as ``key: value`` pairs (or ``none`` when empty).
        """
        parts = []
        for tool in self._tools.values():
            paramStr = ", ".join(
                f"{k}: {v}" for k, v in tool.parameters.items()
            ) if tool.parameters else "none"
            parts.append(f"- **{tool.name}**: {tool.description}\n Parameters: {{{paramStr}}}")
        return "\n".join(parts)

    def formatToolsForFunctionCalling(self) -> List[Dict[str, Any]]:
        """Format all tools as OpenAI-compatible function definitions for native function calling.

        A tool's ``parameters`` are passed through unchanged. Tools without
        parameters get an empty object schema.
        """
        functions = []
        for tool in self._tools.values():
            functions.append({
                "type": "function",
                "function": {
                    "name": tool.name,
                    "description": tool.description,
                    # NOTE(review): assumes stored parameters are already a JSON-schema
                    # object; confirm for tools registered with flat dicts.
                    "parameters": tool.parameters if tool.parameters else {
                        "type": "object",
                        "properties": {},
                        "required": []
                    }
                }
            })
        return functions
|