feat: add tools registry
This commit is contained in:
parent
63dba85b7a
commit
2158b90748
3 changed files with 618 additions and 40 deletions
|
|
@ -1,5 +1,6 @@
|
|||
"""Workflow-related base datamodels and step/task structures."""
|
||||
|
||||
from enum import Enum
|
||||
from typing import List, Dict, Any, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
from modules.shared.attributeUtils import register_model_labels, ModelMixin
|
||||
|
|
@ -7,9 +8,12 @@ from modules.shared.attributeUtils import register_model_labels, ModelMixin
|
|||
|
||||
class ActionDocument(BaseModel, ModelMixin):
|
||||
"""Clear document structure for action results"""
|
||||
|
||||
documentName: str = Field(description="Name of the document")
|
||||
documentData: Any = Field(description="Content/data of the document")
|
||||
mimeType: str = Field(description="MIME type of the document")
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"ActionDocument",
|
||||
{"en": "Action Document", "fr": "Document d'action"},
|
||||
|
|
@ -31,16 +35,25 @@ class ActionResult(BaseModel, ModelMixin):
|
|||
|
||||
success: bool = Field(description="Whether execution succeeded")
|
||||
error: Optional[str] = Field(None, description="Error message if failed")
|
||||
documents: List[ActionDocument] = Field(default_factory=list, description="Document outputs")
|
||||
resultLabel: Optional[str] = Field(None, description="Label for document routing (set by action handler, not by action methods)")
|
||||
documents: List[ActionDocument] = Field(
|
||||
default_factory=list, description="Document outputs"
|
||||
)
|
||||
resultLabel: Optional[str] = Field(
|
||||
None,
|
||||
description="Label for document routing (set by action handler, not by action methods)",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def isSuccess(cls, documents: List[ActionDocument] = None) -> "ActionResult":
|
||||
return cls(success=True, documents=documents or [])
|
||||
|
||||
@classmethod
|
||||
def isFailure(cls, error: str, documents: List[ActionDocument] = None) -> "ActionResult":
|
||||
def isFailure(
|
||||
cls, error: str, documents: List[ActionDocument] = None
|
||||
) -> "ActionResult":
|
||||
return cls(success=False, documents=documents or [], error=error)
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"ActionResult",
|
||||
{"en": "Action Result", "fr": "Résultat de l'action"},
|
||||
|
|
@ -55,7 +68,9 @@ register_model_labels(
|
|||
|
||||
class ActionSelection(BaseModel, ModelMixin):
|
||||
method: str = Field(description="Method to execute (e.g., web, document, ai)")
|
||||
name: str = Field(description="Action name within the method (e.g., search, extract)")
|
||||
name: str = Field(
|
||||
description="Action name within the method (e.g., search, extract)"
|
||||
)
|
||||
|
||||
|
||||
register_model_labels(
|
||||
|
|
@ -69,7 +84,9 @@ register_model_labels(
|
|||
|
||||
|
||||
class ActionParameters(BaseModel, ModelMixin):
|
||||
parameters: Dict[str, Any] = Field(default_factory=dict, description="Parameters to execute the selected action")
|
||||
parameters: Dict[str, Any] = Field(
|
||||
default_factory=dict, description="Parameters to execute the selected action"
|
||||
)
|
||||
|
||||
|
||||
register_model_labels(
|
||||
|
|
@ -102,8 +119,12 @@ class Observation(BaseModel, ModelMixin):
|
|||
success: bool = Field(description="Action execution success flag")
|
||||
resultLabel: str = Field(description="Deterministic label for produced documents")
|
||||
documentsCount: int = Field(description="Number of produced documents")
|
||||
previews: List[ObservationPreview] = Field(default_factory=list, description="Compact previews of outputs")
|
||||
notes: List[str] = Field(default_factory=list, description="Short notes or key facts")
|
||||
previews: List[ObservationPreview] = Field(
|
||||
default_factory=list, description="Compact previews of outputs"
|
||||
)
|
||||
notes: List[str] = Field(
|
||||
default_factory=list, description="Short notes or key facts"
|
||||
)
|
||||
|
||||
|
||||
register_model_labels(
|
||||
|
|
@ -119,7 +140,9 @@ register_model_labels(
|
|||
)
|
||||
|
||||
|
||||
class TaskStatus(str):
|
||||
class TaskStatus(str, Enum):
|
||||
"""Task status enumeration."""
|
||||
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
|
|
@ -142,7 +165,9 @@ register_model_labels(
|
|||
|
||||
class DocumentExchange(BaseModel, ModelMixin):
|
||||
documentsLabel: str = Field(description="Label for the set of documents")
|
||||
documents: List[str] = Field(default_factory=list, description="List of document references")
|
||||
documents: List[str] = Field(
|
||||
default_factory=list, description="List of document references"
|
||||
)
|
||||
|
||||
|
||||
register_model_labels(
|
||||
|
|
@ -159,16 +184,28 @@ class TaskAction(BaseModel, ModelMixin):
|
|||
id: str = Field(..., description="Action ID")
|
||||
execMethod: str = Field(..., description="Method to execute")
|
||||
execAction: str = Field(..., description="Action to perform")
|
||||
execParameters: Dict[str, Any] = Field(default_factory=dict, description="Action parameters")
|
||||
execResultLabel: Optional[str] = Field(None, description="Label for the set of result documents")
|
||||
expectedDocumentFormats: Optional[List[Dict[str, str]]] = Field(None, description="Expected document formats (optional)")
|
||||
userMessage: Optional[str] = Field(None, description="User-friendly message in user's language")
|
||||
execParameters: Dict[str, Any] = Field(
|
||||
default_factory=dict, description="Action parameters"
|
||||
)
|
||||
execResultLabel: Optional[str] = Field(
|
||||
None, description="Label for the set of result documents"
|
||||
)
|
||||
expectedDocumentFormats: Optional[List[Dict[str, str]]] = Field(
|
||||
None, description="Expected document formats (optional)"
|
||||
)
|
||||
userMessage: Optional[str] = Field(
|
||||
None, description="User-friendly message in user's language"
|
||||
)
|
||||
status: TaskStatus = Field(default=TaskStatus.PENDING, description="Action status")
|
||||
error: Optional[str] = Field(None, description="Error message if action failed")
|
||||
retryCount: int = Field(default=0, description="Number of retries attempted")
|
||||
retryMax: int = Field(default=3, description="Maximum number of retries")
|
||||
processingTime: Optional[float] = Field(None, description="Processing time in seconds")
|
||||
timestamp: float = Field(..., description="When the action was executed (UTC timestamp in seconds)")
|
||||
processingTime: Optional[float] = Field(
|
||||
None, description="Processing time in seconds"
|
||||
)
|
||||
timestamp: float = Field(
|
||||
..., description="When the action was executed (UTC timestamp in seconds)"
|
||||
)
|
||||
result: Optional[str] = Field(None, description="Result of the action")
|
||||
|
||||
|
||||
|
|
@ -181,7 +218,10 @@ register_model_labels(
|
|||
"execAction": {"en": "Action", "fr": "Action"},
|
||||
"execParameters": {"en": "Parameters", "fr": "Paramètres"},
|
||||
"execResultLabel": {"en": "Result Label", "fr": "Label du résultat"},
|
||||
"expectedDocumentFormats": {"en": "Expected Document Formats", "fr": "Formats de documents attendus"},
|
||||
"expectedDocumentFormats": {
|
||||
"en": "Expected Document Formats",
|
||||
"fr": "Formats de documents attendus",
|
||||
},
|
||||
"userMessage": {"en": "User Message", "fr": "Message utilisateur"},
|
||||
"status": {"en": "Status", "fr": "Statut"},
|
||||
"error": {"en": "Error", "fr": "Erreur"},
|
||||
|
|
@ -221,16 +261,30 @@ class TaskItem(BaseModel, ModelMixin):
|
|||
userInput: str = Field(..., description="User input that triggered the task")
|
||||
status: TaskStatus = Field(default=TaskStatus.PENDING, description="Task status")
|
||||
error: Optional[str] = Field(None, description="Error message if task failed")
|
||||
startedAt: Optional[float] = Field(None, description="When the task started (UTC timestamp in seconds)")
|
||||
finishedAt: Optional[float] = Field(None, description="When the task finished (UTC timestamp in seconds)")
|
||||
actionList: List[TaskAction] = Field(default_factory=list, description="List of actions to execute")
|
||||
startedAt: Optional[float] = Field(
|
||||
None, description="When the task started (UTC timestamp in seconds)"
|
||||
)
|
||||
finishedAt: Optional[float] = Field(
|
||||
None, description="When the task finished (UTC timestamp in seconds)"
|
||||
)
|
||||
actionList: List[TaskAction] = Field(
|
||||
default_factory=list, description="List of actions to execute"
|
||||
)
|
||||
retryCount: int = Field(default=0, description="Number of retries attempted")
|
||||
retryMax: int = Field(default=3, description="Maximum number of retries")
|
||||
rollbackOnFailure: bool = Field(default=True, description="Whether to rollback on failure")
|
||||
dependencies: List[str] = Field(default_factory=list, description="List of task IDs this task depends on")
|
||||
rollbackOnFailure: bool = Field(
|
||||
default=True, description="Whether to rollback on failure"
|
||||
)
|
||||
dependencies: List[str] = Field(
|
||||
default_factory=list, description="List of task IDs this task depends on"
|
||||
)
|
||||
feedback: Optional[str] = Field(None, description="Task feedback message")
|
||||
processingTime: Optional[float] = Field(None, description="Total processing time in seconds")
|
||||
resultLabels: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Map of result labels to their values")
|
||||
processingTime: Optional[float] = Field(
|
||||
None, description="Total processing time in seconds"
|
||||
)
|
||||
resultLabels: Optional[Dict[str, Any]] = Field(
|
||||
default_factory=dict, description="Map of result labels to their values"
|
||||
)
|
||||
|
||||
|
||||
register_model_labels(
|
||||
|
|
@ -258,7 +312,9 @@ class TaskStep(BaseModel, ModelMixin):
|
|||
dependencies: Optional[list[str]] = Field(default_factory=list)
|
||||
success_criteria: Optional[list[str]] = Field(default_factory=list)
|
||||
estimated_complexity: Optional[str] = None
|
||||
userMessage: Optional[str] = Field(None, description="User-friendly message in user's language")
|
||||
userMessage: Optional[str] = Field(
|
||||
None, description="User-friendly message in user's language"
|
||||
)
|
||||
|
||||
|
||||
register_model_labels(
|
||||
|
|
@ -269,7 +325,10 @@ register_model_labels(
|
|||
"objective": {"en": "Objective", "fr": "Objectif"},
|
||||
"dependencies": {"en": "Dependencies", "fr": "Dépendances"},
|
||||
"success_criteria": {"en": "Success Criteria", "fr": "Critères de succès"},
|
||||
"estimated_complexity": {"en": "Estimated Complexity", "fr": "Complexité estimée"},
|
||||
"estimated_complexity": {
|
||||
"en": "Estimated Complexity",
|
||||
"fr": "Complexité estimée",
|
||||
},
|
||||
"userMessage": {"en": "User Message", "fr": "Message utilisateur"},
|
||||
},
|
||||
)
|
||||
|
|
@ -278,15 +337,31 @@ register_model_labels(
|
|||
class TaskHandover(BaseModel, ModelMixin):
|
||||
taskId: str = Field(description="Target task ID")
|
||||
sourceTask: Optional[str] = Field(None, description="Source task ID")
|
||||
inputDocuments: List[DocumentExchange] = Field(default_factory=list, description="Available input documents")
|
||||
outputDocuments: List[DocumentExchange] = Field(default_factory=list, description="Produced output documents")
|
||||
inputDocuments: List[DocumentExchange] = Field(
|
||||
default_factory=list, description="Available input documents"
|
||||
)
|
||||
outputDocuments: List[DocumentExchange] = Field(
|
||||
default_factory=list, description="Produced output documents"
|
||||
)
|
||||
context: Dict[str, Any] = Field(default_factory=dict, description="Task context")
|
||||
previousResults: List[str] = Field(default_factory=list, description="Previous result summaries")
|
||||
improvements: List[str] = Field(default_factory=list, description="Improvement suggestions")
|
||||
workflowSummary: Optional[str] = Field(None, description="Summarized workflow context")
|
||||
messageHistory: List[str] = Field(default_factory=list, description="Key message summaries")
|
||||
timestamp: float = Field(..., description="When the handover was created (UTC timestamp in seconds)")
|
||||
handoverType: str = Field(default="task", description="Type of handover: task, phase, or workflow")
|
||||
previousResults: List[str] = Field(
|
||||
default_factory=list, description="Previous result summaries"
|
||||
)
|
||||
improvements: List[str] = Field(
|
||||
default_factory=list, description="Improvement suggestions"
|
||||
)
|
||||
workflowSummary: Optional[str] = Field(
|
||||
None, description="Summarized workflow context"
|
||||
)
|
||||
messageHistory: List[str] = Field(
|
||||
default_factory=list, description="Key message summaries"
|
||||
)
|
||||
timestamp: float = Field(
|
||||
..., description="When the handover was created (UTC timestamp in seconds)"
|
||||
)
|
||||
handoverType: str = Field(
|
||||
default="task", description="Type of handover: task, phase, or workflow"
|
||||
)
|
||||
|
||||
|
||||
register_model_labels(
|
||||
|
|
@ -310,7 +385,7 @@ register_model_labels(
|
|||
|
||||
class TaskContext(BaseModel, ModelMixin):
|
||||
task_step: TaskStep
|
||||
workflow: Optional['ChatWorkflow'] = None
|
||||
workflow: Optional["ChatWorkflow"] = None
|
||||
workflow_id: Optional[str] = None
|
||||
available_documents: Optional[str] = "No documents available"
|
||||
available_connections: Optional[list[str]] = Field(default_factory=list)
|
||||
|
|
@ -358,7 +433,9 @@ class ReviewResult(BaseModel, ModelMixin):
|
|||
met_criteria: Optional[list[str]] = Field(default_factory=list)
|
||||
unmet_criteria: Optional[list[str]] = Field(default_factory=list)
|
||||
confidence: Optional[float] = 0.5
|
||||
userMessage: Optional[str] = Field(None, description="User-friendly message in user's language")
|
||||
userMessage: Optional[str] = Field(
|
||||
None, description="User-friendly message in user's language"
|
||||
)
|
||||
|
||||
|
||||
register_model_labels(
|
||||
|
|
@ -381,7 +458,9 @@ register_model_labels(
|
|||
class TaskPlan(BaseModel, ModelMixin):
|
||||
overview: str
|
||||
tasks: list[TaskStep]
|
||||
userMessage: Optional[str] = Field(None, description="Overall user-friendly message for the task plan")
|
||||
userMessage: Optional[str] = Field(
|
||||
None, description="Overall user-friendly message for the task plan"
|
||||
)
|
||||
|
||||
|
||||
register_model_labels(
|
||||
|
|
@ -393,7 +472,3 @@ register_model_labels(
|
|||
"userMessage": {"en": "User Message", "fr": "Message utilisateur"},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
305
modules/features/chatBot/utils/toolRegistry.py
Normal file
305
modules/features/chatBot/utils/toolRegistry.py
Normal file
|
|
@ -0,0 +1,305 @@
|
|||
"""Tool registry for auto-discovering and managing chatbot tools.
|
||||
|
||||
This module provides a central registry that automatically discovers all tools
|
||||
in the chatbotTools directory structure and provides methods to query them.
|
||||
The registry is built in-memory at startup and does not require a database.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolMetadata:
|
||||
"""Metadata about a discovered chatbot tool.
|
||||
|
||||
Attributes:
|
||||
tool_id: Unique identifier (e.g., 'shared.tavily_search')
|
||||
name: Function name of the tool
|
||||
category: Category of the tool ('shared' or 'customer')
|
||||
description: Tool description from docstring
|
||||
tool_instance: The actual LangChain tool instance
|
||||
module_path: Full Python module path
|
||||
"""
|
||||
|
||||
tool_id: str
|
||||
name: str
|
||||
category: str
|
||||
description: str
|
||||
tool_instance: BaseTool
|
||||
module_path: str
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return a pretty-printed string representation for logging."""
|
||||
return (
|
||||
f"ToolMetadata(\n"
|
||||
f" tool_id='{self.tool_id}',\n"
|
||||
f" name='{self.name}',\n"
|
||||
f" category='{self.category}',\n"
|
||||
f" description='{self.description}',\n"
|
||||
f" module_path='{self.module_path}'\n"
|
||||
f")"
|
||||
)
|
||||
|
||||
|
||||
class ToolRegistry:
|
||||
"""Central registry for all chatbot tools.
|
||||
|
||||
This class discovers and catalogs all tools decorated with @tool in the
|
||||
chatbotTools directory structure. Tools are automatically discovered at
|
||||
initialization by scanning the filesystem.
|
||||
|
||||
The registry provides methods to query tools by ID, category, or get all tools.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize an empty tool registry."""
|
||||
self._tools: Dict[str, ToolMetadata] = {}
|
||||
self._initialized: bool = False
|
||||
|
||||
def initialize(self) -> None:
|
||||
"""Discover and register all tools from the chatbotTools directory.
|
||||
|
||||
This method scans both sharedTools and customerTools directories,
|
||||
imports all tool*.py modules, and extracts functions decorated with @tool.
|
||||
|
||||
This method is idempotent - calling it multiple times has no effect
|
||||
after the first initialization.
|
||||
"""
|
||||
if self._initialized:
|
||||
logger.debug("Tool registry already initialized, skipping")
|
||||
return
|
||||
|
||||
logger.info("Initializing tool registry...")
|
||||
|
||||
# Get base path to chatbotTools directory
|
||||
base_path = Path(__file__).parent.parent / "chatbotTools"
|
||||
|
||||
if not base_path.exists():
|
||||
logger.warning(f"chatbotTools directory not found at {base_path}")
|
||||
self._initialized = True
|
||||
return
|
||||
|
||||
# Discover tools in each category
|
||||
self._discover_category(
|
||||
category_path=base_path / "sharedTools", category="shared"
|
||||
)
|
||||
self._discover_category(
|
||||
category_path=base_path / "customerTools", category="customer"
|
||||
)
|
||||
|
||||
self._initialized = True
|
||||
logger.info(f"Tool registry initialized with {len(self._tools)} tools")
|
||||
|
||||
def _discover_category(self, *, category_path: Path, category: str) -> None:
|
||||
"""Discover all tools in a specific category directory.
|
||||
|
||||
Args:
|
||||
category_path: Path to the category directory (sharedTools or customerTools)
|
||||
category: Category name ('shared' or 'customer')
|
||||
"""
|
||||
if not category_path.exists():
|
||||
logger.warning(f"Category directory not found: {category_path}")
|
||||
return
|
||||
|
||||
logger.debug(f"Discovering tools in category: {category}")
|
||||
|
||||
# Find all tool*.py files (excluding __init__.py)
|
||||
tool_files = [
|
||||
f for f in category_path.glob("tool*.py") if f.name != "__init__.py"
|
||||
]
|
||||
|
||||
for tool_file in tool_files:
|
||||
self._import_and_register_tools(
|
||||
tool_file=tool_file, category=category, category_path=category_path
|
||||
)
|
||||
|
||||
logger.debug(f"Discovered {len(tool_files)} tool files in {category}")
|
||||
|
||||
def _import_and_register_tools(
|
||||
self, *, tool_file: Path, category: str, category_path: Path
|
||||
) -> None:
|
||||
"""Import a tool module and register all discovered tools.
|
||||
|
||||
Args:
|
||||
tool_file: Path to the tool Python file
|
||||
category: Category name ('shared' or 'customer')
|
||||
category_path: Path to the category directory
|
||||
"""
|
||||
# Construct module name
|
||||
module_name = (
|
||||
f"modules.features.chatBot.chatbotTools.{category}Tools.{tool_file.stem}"
|
||||
)
|
||||
|
||||
try:
|
||||
# Import the module
|
||||
module = importlib.import_module(module_name)
|
||||
|
||||
# Find all BaseTool instances in the module
|
||||
tools_found = 0
|
||||
for name, obj in inspect.getmembers(module):
|
||||
if isinstance(obj, BaseTool):
|
||||
self._register_tool(
|
||||
tool_instance=obj,
|
||||
name=name,
|
||||
category=category,
|
||||
module_path=module_name,
|
||||
)
|
||||
tools_found += 1
|
||||
|
||||
if tools_found == 0:
|
||||
logger.warning(f"No tools found in {module_name}")
|
||||
else:
|
||||
logger.debug(f"Loaded {tools_found} tool(s) from {module_name}")
|
||||
|
||||
except ImportError as e:
|
||||
logger.error(
|
||||
f"Import error loading tools from {module_name}: {str(e)}. "
|
||||
f"This tool will not be available."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Unexpected error loading tools from {module_name}: {type(e).__name__}: {str(e)}"
|
||||
)
|
||||
|
||||
def _register_tool(
|
||||
self, *, tool_instance: BaseTool, name: str, category: str, module_path: str
|
||||
) -> None:
|
||||
"""Register a single tool in the registry.
|
||||
|
||||
Args:
|
||||
tool_instance: The LangChain tool instance
|
||||
name: Function name of the tool
|
||||
category: Category name ('shared' or 'customer')
|
||||
module_path: Full Python module path
|
||||
"""
|
||||
tool_id = f"{category}.{name}"
|
||||
|
||||
# Check for duplicate tool IDs
|
||||
if tool_id in self._tools:
|
||||
logger.warning(f"Duplicate tool ID detected: {tool_id}, overwriting")
|
||||
|
||||
metadata = ToolMetadata(
|
||||
tool_id=tool_id,
|
||||
name=name,
|
||||
category=category,
|
||||
description=tool_instance.description or "",
|
||||
tool_instance=tool_instance,
|
||||
module_path=module_path,
|
||||
)
|
||||
|
||||
self._tools[tool_id] = metadata
|
||||
logger.debug(f"Registered tool: {tool_id}")
|
||||
|
||||
def get_all_tools(self) -> List[ToolMetadata]:
|
||||
"""Get all registered tools.
|
||||
|
||||
Returns:
|
||||
List of all tool metadata objects
|
||||
"""
|
||||
return list(self._tools.values())
|
||||
|
||||
def get_tool(self, *, tool_id: str) -> Optional[ToolMetadata]:
|
||||
"""Get a specific tool by its ID.
|
||||
|
||||
Args:
|
||||
tool_id: The tool identifier (e.g., 'shared.tavily_search')
|
||||
|
||||
Returns:
|
||||
Tool metadata if found, None otherwise
|
||||
"""
|
||||
return self._tools.get(tool_id)
|
||||
|
||||
def get_tools_by_category(self, *, category: str) -> List[ToolMetadata]:
|
||||
"""Get all tools in a specific category.
|
||||
|
||||
Args:
|
||||
category: Category name ('shared' or 'customer')
|
||||
|
||||
Returns:
|
||||
List of tool metadata for the specified category
|
||||
"""
|
||||
return [t for t in self._tools.values() if t.category == category]
|
||||
|
||||
def list_tool_ids(self) -> List[str]:
|
||||
"""Get a list of all registered tool IDs.
|
||||
|
||||
Returns:
|
||||
List of tool ID strings
|
||||
"""
|
||||
return list(self._tools.keys())
|
||||
|
||||
def get_tool_instances(self, *, tool_ids: List[str]) -> List[BaseTool]:
|
||||
"""Get actual tool instances for a list of tool IDs.
|
||||
|
||||
This is useful for filtering tools based on user permissions.
|
||||
|
||||
Args:
|
||||
tool_ids: List of tool IDs to retrieve
|
||||
|
||||
Returns:
|
||||
List of BaseTool instances for the specified IDs
|
||||
"""
|
||||
instances = []
|
||||
for tool_id in tool_ids:
|
||||
metadata = self.get_tool(tool_id=tool_id)
|
||||
if metadata:
|
||||
instances.append(metadata.tool_instance)
|
||||
else:
|
||||
logger.warning(f"Tool ID not found in registry: {tool_id}")
|
||||
return instances
|
||||
|
||||
@property
|
||||
def is_initialized(self) -> bool:
|
||||
"""Check if the registry has been initialized.
|
||||
|
||||
Returns:
|
||||
True if initialized, False otherwise
|
||||
"""
|
||||
return self._initialized
|
||||
|
||||
|
||||
# Global registry instance
|
||||
_registry: Optional[ToolRegistry] = None
|
||||
|
||||
|
||||
def get_registry() -> ToolRegistry:
|
||||
"""Get the global tool registry instance.
|
||||
|
||||
This function ensures the registry is initialized on first access.
|
||||
Subsequent calls return the same instance.
|
||||
|
||||
Returns:
|
||||
The global ToolRegistry instance
|
||||
"""
|
||||
global _registry
|
||||
|
||||
if _registry is None:
|
||||
_registry = ToolRegistry()
|
||||
|
||||
if not _registry.is_initialized:
|
||||
_registry.initialize()
|
||||
|
||||
return _registry
|
||||
|
||||
|
||||
def reinitialize_registry() -> ToolRegistry:
|
||||
"""Force reinitialize the tool registry.
|
||||
|
||||
This is useful for testing or when tools are added dynamically.
|
||||
|
||||
Returns:
|
||||
The reinitialized ToolRegistry instance
|
||||
"""
|
||||
global _registry
|
||||
_registry = ToolRegistry()
|
||||
_registry.initialize()
|
||||
return _registry
|
||||
198
tests/features/chatBot/utils/test_toolRegistry.py
Normal file
198
tests/features/chatBot/utils/test_toolRegistry.py
Normal file
|
|
@ -0,0 +1,198 @@
|
|||
"""Pytest tests for the tool registry.
|
||||
|
||||
This module tests that the tool registry correctly discovers and catalogs
|
||||
all tools in the chatbotTools directory.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import pytest
|
||||
from modules.features.chatBot.utils.toolRegistry import (
|
||||
ToolMetadata,
|
||||
ToolRegistry,
|
||||
get_registry,
|
||||
reinitialize_registry,
|
||||
)
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TestToolRegistry:
|
||||
"""Test suite for ToolRegistry class."""
|
||||
|
||||
@pytest.fixture
|
||||
def registry(self) -> ToolRegistry:
|
||||
"""Provide a fresh registry instance for each test."""
|
||||
return reinitialize_registry()
|
||||
|
||||
def test_registry_initialization(self, registry: ToolRegistry) -> None:
|
||||
"""Test that registry initializes correctly."""
|
||||
assert registry.is_initialized
|
||||
assert isinstance(registry._tools, dict)
|
||||
|
||||
def test_get_all_tools(self, registry: ToolRegistry) -> None:
|
||||
"""Test getting all registered tools."""
|
||||
all_tools = registry.get_all_tools()
|
||||
assert isinstance(all_tools, list)
|
||||
assert len(all_tools) > 0
|
||||
assert all(isinstance(tool, ToolMetadata) for tool in all_tools)
|
||||
|
||||
# Log all discovered tools
|
||||
logger.info(f"Found {len(all_tools)} tools in registry:")
|
||||
for tool in all_tools:
|
||||
logger.info(f"\n{tool}")
|
||||
|
||||
def test_tool_metadata_structure(self, registry: ToolRegistry) -> None:
|
||||
"""Test that tool metadata has correct structure."""
|
||||
all_tools = registry.get_all_tools()
|
||||
for tool in all_tools:
|
||||
assert isinstance(tool.tool_id, str)
|
||||
assert isinstance(tool.name, str)
|
||||
assert isinstance(tool.category, str)
|
||||
assert tool.category in ["shared", "customer"]
|
||||
assert isinstance(tool.description, str)
|
||||
assert isinstance(tool.tool_instance, BaseTool)
|
||||
assert isinstance(tool.module_path, str)
|
||||
|
||||
def test_list_tool_ids(self, registry: ToolRegistry) -> None:
|
||||
"""Test listing all tool IDs."""
|
||||
tool_ids = registry.list_tool_ids()
|
||||
assert isinstance(tool_ids, list)
|
||||
assert len(tool_ids) > 0
|
||||
assert all(isinstance(tool_id, str) for tool_id in tool_ids)
|
||||
|
||||
# Check that tool IDs follow expected format
|
||||
for tool_id in tool_ids:
|
||||
assert "." in tool_id
|
||||
category, name = tool_id.split(".", 1)
|
||||
assert category in ["shared", "customer"]
|
||||
|
||||
def test_get_specific_tool(self, registry: ToolRegistry) -> None:
|
||||
"""Test retrieving a specific tool by ID."""
|
||||
# Get all tool IDs first
|
||||
tool_ids = registry.list_tool_ids()
|
||||
if tool_ids:
|
||||
# Test with first available tool
|
||||
test_tool_id = tool_ids[0]
|
||||
tool_metadata = registry.get_tool(tool_id=test_tool_id)
|
||||
|
||||
assert tool_metadata is not None
|
||||
assert isinstance(tool_metadata, ToolMetadata)
|
||||
assert tool_metadata.tool_id == test_tool_id
|
||||
|
||||
def test_get_nonexistent_tool(self, registry: ToolRegistry) -> None:
|
||||
"""Test retrieving a tool that doesn't exist."""
|
||||
tool_metadata = registry.get_tool(tool_id="nonexistent.tool")
|
||||
assert tool_metadata is None
|
||||
|
||||
def test_get_tools_by_category_shared(self, registry: ToolRegistry) -> None:
|
||||
"""Test getting all shared tools."""
|
||||
shared_tools = registry.get_tools_by_category(category="shared")
|
||||
assert isinstance(shared_tools, list)
|
||||
assert all(tool.category == "shared" for tool in shared_tools)
|
||||
|
||||
def test_get_tools_by_category_customer(self, registry: ToolRegistry) -> None:
|
||||
"""Test getting all customer tools."""
|
||||
customer_tools = registry.get_tools_by_category(category="customer")
|
||||
assert isinstance(customer_tools, list)
|
||||
assert all(tool.category == "customer" for tool in customer_tools)
|
||||
|
||||
def test_get_tool_instances(self, registry: ToolRegistry) -> None:
|
||||
"""Test getting tool instances by IDs."""
|
||||
tool_ids = registry.list_tool_ids()
|
||||
if len(tool_ids) >= 2:
|
||||
# Test with first two tools
|
||||
test_ids = tool_ids[:2]
|
||||
instances = registry.get_tool_instances(tool_ids=test_ids)
|
||||
|
||||
assert isinstance(instances, list)
|
||||
assert len(instances) == 2
|
||||
assert all(isinstance(inst, BaseTool) for inst in instances)
|
||||
|
||||
def test_get_tool_instances_with_invalid_id(self, registry: ToolRegistry) -> None:
|
||||
"""Test getting tool instances with some invalid IDs."""
|
||||
tool_ids = registry.list_tool_ids()
|
||||
if tool_ids:
|
||||
# Mix valid and invalid IDs
|
||||
test_ids = [tool_ids[0], "invalid.tool"]
|
||||
instances = registry.get_tool_instances(tool_ids=test_ids)
|
||||
|
||||
# Should only return the valid one
|
||||
assert len(instances) == 1
|
||||
assert isinstance(instances[0], BaseTool)
|
||||
|
||||
def test_global_registry_singleton(self) -> None:
|
||||
"""Test that get_registry returns same instance."""
|
||||
registry1 = get_registry()
|
||||
registry2 = get_registry()
|
||||
assert registry1 is registry2
|
||||
|
||||
def test_reinitialize_registry(self) -> None:
|
||||
"""Test that reinitialize creates new instance."""
|
||||
registry1 = get_registry()
|
||||
registry2 = reinitialize_registry()
|
||||
# Should be different instances after reinitialize
|
||||
assert registry1 is not registry2
|
||||
assert registry2.is_initialized
|
||||
|
||||
|
||||
class TestToolDiscovery:
|
||||
"""Test suite for tool discovery functionality."""
|
||||
|
||||
def test_discovers_at_least_one_tool(self) -> None:
|
||||
"""Test that at least one tool is discovered."""
|
||||
registry = get_registry()
|
||||
tool_ids = registry.list_tool_ids()
|
||||
|
||||
# At least one tool should be successfully loaded
|
||||
assert len(tool_ids) >= 1, "Expected at least one tool to be discovered"
|
||||
|
||||
def test_query_althaus_database_if_available(self) -> None:
|
||||
"""Test query_althaus_database tool if it was successfully loaded."""
|
||||
registry = get_registry()
|
||||
tool = registry.get_tool(tool_id="customer.query_althaus_database")
|
||||
|
||||
if tool is not None:
|
||||
assert tool.name == "query_althaus_database"
|
||||
assert tool.category == "customer"
|
||||
assert "database" in tool.description.lower()
|
||||
else:
|
||||
# Tool may not have loaded due to import errors - log warning
|
||||
import logging
|
||||
|
||||
logging.warning(
|
||||
"customer.query_althaus_database tool not found - "
|
||||
"may have failed to import"
|
||||
)
|
||||
|
||||
def test_tavily_search_if_available(self) -> None:
|
||||
"""Test tavily_search tool if it was successfully loaded."""
|
||||
registry = get_registry()
|
||||
tool = registry.get_tool(tool_id="shared.tavily_search")
|
||||
|
||||
if tool is not None:
|
||||
assert tool.name == "tavily_search"
|
||||
assert tool.category == "shared"
|
||||
assert "search" in tool.description.lower()
|
||||
else:
|
||||
# Tool may not have loaded due to import errors - log warning
|
||||
import logging
|
||||
|
||||
logging.warning(
|
||||
"shared.tavily_search tool not found - may have failed to import"
|
||||
)
|
||||
|
||||
def test_tool_ids_have_correct_format(self) -> None:
|
||||
"""Test that all discovered tool IDs follow the expected format."""
|
||||
registry = get_registry()
|
||||
tool_ids = registry.list_tool_ids()
|
||||
|
||||
for tool_id in tool_ids:
|
||||
# All tool IDs should have format: category.toolname
|
||||
assert "." in tool_id, f"Tool ID {tool_id} missing category separator"
|
||||
category, name = tool_id.split(".", 1)
|
||||
assert category in [
|
||||
"shared",
|
||||
"customer",
|
||||
], f"Tool {tool_id} has invalid category: {category}"
|
||||
assert len(name) > 0, f"Tool {tool_id} has empty name"
|
||||
Loading…
Reference in a new issue