feat: add tools registry
This commit is contained in:
parent
63dba85b7a
commit
2158b90748
3 changed files with 618 additions and 40 deletions
|
|
@ -1,5 +1,6 @@
|
||||||
"""Workflow-related base datamodels and step/task structures."""
|
"""Workflow-related base datamodels and step/task structures."""
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
from typing import List, Dict, Any, Optional
|
from typing import List, Dict, Any, Optional
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from modules.shared.attributeUtils import register_model_labels, ModelMixin
|
from modules.shared.attributeUtils import register_model_labels, ModelMixin
|
||||||
|
|
@ -7,9 +8,12 @@ from modules.shared.attributeUtils import register_model_labels, ModelMixin
|
||||||
|
|
||||||
class ActionDocument(BaseModel, ModelMixin):
|
class ActionDocument(BaseModel, ModelMixin):
|
||||||
"""Clear document structure for action results"""
|
"""Clear document structure for action results"""
|
||||||
|
|
||||||
documentName: str = Field(description="Name of the document")
|
documentName: str = Field(description="Name of the document")
|
||||||
documentData: Any = Field(description="Content/data of the document")
|
documentData: Any = Field(description="Content/data of the document")
|
||||||
mimeType: str = Field(description="MIME type of the document")
|
mimeType: str = Field(description="MIME type of the document")
|
||||||
|
|
||||||
|
|
||||||
register_model_labels(
|
register_model_labels(
|
||||||
"ActionDocument",
|
"ActionDocument",
|
||||||
{"en": "Action Document", "fr": "Document d'action"},
|
{"en": "Action Document", "fr": "Document d'action"},
|
||||||
|
|
@ -31,16 +35,25 @@ class ActionResult(BaseModel, ModelMixin):
|
||||||
|
|
||||||
success: bool = Field(description="Whether execution succeeded")
|
success: bool = Field(description="Whether execution succeeded")
|
||||||
error: Optional[str] = Field(None, description="Error message if failed")
|
error: Optional[str] = Field(None, description="Error message if failed")
|
||||||
documents: List[ActionDocument] = Field(default_factory=list, description="Document outputs")
|
documents: List[ActionDocument] = Field(
|
||||||
resultLabel: Optional[str] = Field(None, description="Label for document routing (set by action handler, not by action methods)")
|
default_factory=list, description="Document outputs"
|
||||||
|
)
|
||||||
|
resultLabel: Optional[str] = Field(
|
||||||
|
None,
|
||||||
|
description="Label for document routing (set by action handler, not by action methods)",
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def isSuccess(cls, documents: List[ActionDocument] = None) -> "ActionResult":
|
def isSuccess(cls, documents: List[ActionDocument] = None) -> "ActionResult":
|
||||||
return cls(success=True, documents=documents or [])
|
return cls(success=True, documents=documents or [])
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def isFailure(cls, error: str, documents: List[ActionDocument] = None) -> "ActionResult":
|
def isFailure(
|
||||||
|
cls, error: str, documents: List[ActionDocument] = None
|
||||||
|
) -> "ActionResult":
|
||||||
return cls(success=False, documents=documents or [], error=error)
|
return cls(success=False, documents=documents or [], error=error)
|
||||||
|
|
||||||
|
|
||||||
register_model_labels(
|
register_model_labels(
|
||||||
"ActionResult",
|
"ActionResult",
|
||||||
{"en": "Action Result", "fr": "Résultat de l'action"},
|
{"en": "Action Result", "fr": "Résultat de l'action"},
|
||||||
|
|
@ -55,7 +68,9 @@ register_model_labels(
|
||||||
|
|
||||||
class ActionSelection(BaseModel, ModelMixin):
|
class ActionSelection(BaseModel, ModelMixin):
|
||||||
method: str = Field(description="Method to execute (e.g., web, document, ai)")
|
method: str = Field(description="Method to execute (e.g., web, document, ai)")
|
||||||
name: str = Field(description="Action name within the method (e.g., search, extract)")
|
name: str = Field(
|
||||||
|
description="Action name within the method (e.g., search, extract)"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
register_model_labels(
|
register_model_labels(
|
||||||
|
|
@ -69,7 +84,9 @@ register_model_labels(
|
||||||
|
|
||||||
|
|
||||||
class ActionParameters(BaseModel, ModelMixin):
|
class ActionParameters(BaseModel, ModelMixin):
|
||||||
parameters: Dict[str, Any] = Field(default_factory=dict, description="Parameters to execute the selected action")
|
parameters: Dict[str, Any] = Field(
|
||||||
|
default_factory=dict, description="Parameters to execute the selected action"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
register_model_labels(
|
register_model_labels(
|
||||||
|
|
@ -102,8 +119,12 @@ class Observation(BaseModel, ModelMixin):
|
||||||
success: bool = Field(description="Action execution success flag")
|
success: bool = Field(description="Action execution success flag")
|
||||||
resultLabel: str = Field(description="Deterministic label for produced documents")
|
resultLabel: str = Field(description="Deterministic label for produced documents")
|
||||||
documentsCount: int = Field(description="Number of produced documents")
|
documentsCount: int = Field(description="Number of produced documents")
|
||||||
previews: List[ObservationPreview] = Field(default_factory=list, description="Compact previews of outputs")
|
previews: List[ObservationPreview] = Field(
|
||||||
notes: List[str] = Field(default_factory=list, description="Short notes or key facts")
|
default_factory=list, description="Compact previews of outputs"
|
||||||
|
)
|
||||||
|
notes: List[str] = Field(
|
||||||
|
default_factory=list, description="Short notes or key facts"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
register_model_labels(
|
register_model_labels(
|
||||||
|
|
@ -119,7 +140,9 @@ register_model_labels(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TaskStatus(str):
|
class TaskStatus(str, Enum):
|
||||||
|
"""Task status enumeration."""
|
||||||
|
|
||||||
PENDING = "pending"
|
PENDING = "pending"
|
||||||
RUNNING = "running"
|
RUNNING = "running"
|
||||||
COMPLETED = "completed"
|
COMPLETED = "completed"
|
||||||
|
|
@ -142,7 +165,9 @@ register_model_labels(
|
||||||
|
|
||||||
class DocumentExchange(BaseModel, ModelMixin):
|
class DocumentExchange(BaseModel, ModelMixin):
|
||||||
documentsLabel: str = Field(description="Label for the set of documents")
|
documentsLabel: str = Field(description="Label for the set of documents")
|
||||||
documents: List[str] = Field(default_factory=list, description="List of document references")
|
documents: List[str] = Field(
|
||||||
|
default_factory=list, description="List of document references"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
register_model_labels(
|
register_model_labels(
|
||||||
|
|
@ -159,16 +184,28 @@ class TaskAction(BaseModel, ModelMixin):
|
||||||
id: str = Field(..., description="Action ID")
|
id: str = Field(..., description="Action ID")
|
||||||
execMethod: str = Field(..., description="Method to execute")
|
execMethod: str = Field(..., description="Method to execute")
|
||||||
execAction: str = Field(..., description="Action to perform")
|
execAction: str = Field(..., description="Action to perform")
|
||||||
execParameters: Dict[str, Any] = Field(default_factory=dict, description="Action parameters")
|
execParameters: Dict[str, Any] = Field(
|
||||||
execResultLabel: Optional[str] = Field(None, description="Label for the set of result documents")
|
default_factory=dict, description="Action parameters"
|
||||||
expectedDocumentFormats: Optional[List[Dict[str, str]]] = Field(None, description="Expected document formats (optional)")
|
)
|
||||||
userMessage: Optional[str] = Field(None, description="User-friendly message in user's language")
|
execResultLabel: Optional[str] = Field(
|
||||||
|
None, description="Label for the set of result documents"
|
||||||
|
)
|
||||||
|
expectedDocumentFormats: Optional[List[Dict[str, str]]] = Field(
|
||||||
|
None, description="Expected document formats (optional)"
|
||||||
|
)
|
||||||
|
userMessage: Optional[str] = Field(
|
||||||
|
None, description="User-friendly message in user's language"
|
||||||
|
)
|
||||||
status: TaskStatus = Field(default=TaskStatus.PENDING, description="Action status")
|
status: TaskStatus = Field(default=TaskStatus.PENDING, description="Action status")
|
||||||
error: Optional[str] = Field(None, description="Error message if action failed")
|
error: Optional[str] = Field(None, description="Error message if action failed")
|
||||||
retryCount: int = Field(default=0, description="Number of retries attempted")
|
retryCount: int = Field(default=0, description="Number of retries attempted")
|
||||||
retryMax: int = Field(default=3, description="Maximum number of retries")
|
retryMax: int = Field(default=3, description="Maximum number of retries")
|
||||||
processingTime: Optional[float] = Field(None, description="Processing time in seconds")
|
processingTime: Optional[float] = Field(
|
||||||
timestamp: float = Field(..., description="When the action was executed (UTC timestamp in seconds)")
|
None, description="Processing time in seconds"
|
||||||
|
)
|
||||||
|
timestamp: float = Field(
|
||||||
|
..., description="When the action was executed (UTC timestamp in seconds)"
|
||||||
|
)
|
||||||
result: Optional[str] = Field(None, description="Result of the action")
|
result: Optional[str] = Field(None, description="Result of the action")
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -181,7 +218,10 @@ register_model_labels(
|
||||||
"execAction": {"en": "Action", "fr": "Action"},
|
"execAction": {"en": "Action", "fr": "Action"},
|
||||||
"execParameters": {"en": "Parameters", "fr": "Paramètres"},
|
"execParameters": {"en": "Parameters", "fr": "Paramètres"},
|
||||||
"execResultLabel": {"en": "Result Label", "fr": "Label du résultat"},
|
"execResultLabel": {"en": "Result Label", "fr": "Label du résultat"},
|
||||||
"expectedDocumentFormats": {"en": "Expected Document Formats", "fr": "Formats de documents attendus"},
|
"expectedDocumentFormats": {
|
||||||
|
"en": "Expected Document Formats",
|
||||||
|
"fr": "Formats de documents attendus",
|
||||||
|
},
|
||||||
"userMessage": {"en": "User Message", "fr": "Message utilisateur"},
|
"userMessage": {"en": "User Message", "fr": "Message utilisateur"},
|
||||||
"status": {"en": "Status", "fr": "Statut"},
|
"status": {"en": "Status", "fr": "Statut"},
|
||||||
"error": {"en": "Error", "fr": "Erreur"},
|
"error": {"en": "Error", "fr": "Erreur"},
|
||||||
|
|
@ -221,16 +261,30 @@ class TaskItem(BaseModel, ModelMixin):
|
||||||
userInput: str = Field(..., description="User input that triggered the task")
|
userInput: str = Field(..., description="User input that triggered the task")
|
||||||
status: TaskStatus = Field(default=TaskStatus.PENDING, description="Task status")
|
status: TaskStatus = Field(default=TaskStatus.PENDING, description="Task status")
|
||||||
error: Optional[str] = Field(None, description="Error message if task failed")
|
error: Optional[str] = Field(None, description="Error message if task failed")
|
||||||
startedAt: Optional[float] = Field(None, description="When the task started (UTC timestamp in seconds)")
|
startedAt: Optional[float] = Field(
|
||||||
finishedAt: Optional[float] = Field(None, description="When the task finished (UTC timestamp in seconds)")
|
None, description="When the task started (UTC timestamp in seconds)"
|
||||||
actionList: List[TaskAction] = Field(default_factory=list, description="List of actions to execute")
|
)
|
||||||
|
finishedAt: Optional[float] = Field(
|
||||||
|
None, description="When the task finished (UTC timestamp in seconds)"
|
||||||
|
)
|
||||||
|
actionList: List[TaskAction] = Field(
|
||||||
|
default_factory=list, description="List of actions to execute"
|
||||||
|
)
|
||||||
retryCount: int = Field(default=0, description="Number of retries attempted")
|
retryCount: int = Field(default=0, description="Number of retries attempted")
|
||||||
retryMax: int = Field(default=3, description="Maximum number of retries")
|
retryMax: int = Field(default=3, description="Maximum number of retries")
|
||||||
rollbackOnFailure: bool = Field(default=True, description="Whether to rollback on failure")
|
rollbackOnFailure: bool = Field(
|
||||||
dependencies: List[str] = Field(default_factory=list, description="List of task IDs this task depends on")
|
default=True, description="Whether to rollback on failure"
|
||||||
|
)
|
||||||
|
dependencies: List[str] = Field(
|
||||||
|
default_factory=list, description="List of task IDs this task depends on"
|
||||||
|
)
|
||||||
feedback: Optional[str] = Field(None, description="Task feedback message")
|
feedback: Optional[str] = Field(None, description="Task feedback message")
|
||||||
processingTime: Optional[float] = Field(None, description="Total processing time in seconds")
|
processingTime: Optional[float] = Field(
|
||||||
resultLabels: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Map of result labels to their values")
|
None, description="Total processing time in seconds"
|
||||||
|
)
|
||||||
|
resultLabels: Optional[Dict[str, Any]] = Field(
|
||||||
|
default_factory=dict, description="Map of result labels to their values"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
register_model_labels(
|
register_model_labels(
|
||||||
|
|
@ -258,7 +312,9 @@ class TaskStep(BaseModel, ModelMixin):
|
||||||
dependencies: Optional[list[str]] = Field(default_factory=list)
|
dependencies: Optional[list[str]] = Field(default_factory=list)
|
||||||
success_criteria: Optional[list[str]] = Field(default_factory=list)
|
success_criteria: Optional[list[str]] = Field(default_factory=list)
|
||||||
estimated_complexity: Optional[str] = None
|
estimated_complexity: Optional[str] = None
|
||||||
userMessage: Optional[str] = Field(None, description="User-friendly message in user's language")
|
userMessage: Optional[str] = Field(
|
||||||
|
None, description="User-friendly message in user's language"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
register_model_labels(
|
register_model_labels(
|
||||||
|
|
@ -269,7 +325,10 @@ register_model_labels(
|
||||||
"objective": {"en": "Objective", "fr": "Objectif"},
|
"objective": {"en": "Objective", "fr": "Objectif"},
|
||||||
"dependencies": {"en": "Dependencies", "fr": "Dépendances"},
|
"dependencies": {"en": "Dependencies", "fr": "Dépendances"},
|
||||||
"success_criteria": {"en": "Success Criteria", "fr": "Critères de succès"},
|
"success_criteria": {"en": "Success Criteria", "fr": "Critères de succès"},
|
||||||
"estimated_complexity": {"en": "Estimated Complexity", "fr": "Complexité estimée"},
|
"estimated_complexity": {
|
||||||
|
"en": "Estimated Complexity",
|
||||||
|
"fr": "Complexité estimée",
|
||||||
|
},
|
||||||
"userMessage": {"en": "User Message", "fr": "Message utilisateur"},
|
"userMessage": {"en": "User Message", "fr": "Message utilisateur"},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
@ -278,15 +337,31 @@ register_model_labels(
|
||||||
class TaskHandover(BaseModel, ModelMixin):
|
class TaskHandover(BaseModel, ModelMixin):
|
||||||
taskId: str = Field(description="Target task ID")
|
taskId: str = Field(description="Target task ID")
|
||||||
sourceTask: Optional[str] = Field(None, description="Source task ID")
|
sourceTask: Optional[str] = Field(None, description="Source task ID")
|
||||||
inputDocuments: List[DocumentExchange] = Field(default_factory=list, description="Available input documents")
|
inputDocuments: List[DocumentExchange] = Field(
|
||||||
outputDocuments: List[DocumentExchange] = Field(default_factory=list, description="Produced output documents")
|
default_factory=list, description="Available input documents"
|
||||||
|
)
|
||||||
|
outputDocuments: List[DocumentExchange] = Field(
|
||||||
|
default_factory=list, description="Produced output documents"
|
||||||
|
)
|
||||||
context: Dict[str, Any] = Field(default_factory=dict, description="Task context")
|
context: Dict[str, Any] = Field(default_factory=dict, description="Task context")
|
||||||
previousResults: List[str] = Field(default_factory=list, description="Previous result summaries")
|
previousResults: List[str] = Field(
|
||||||
improvements: List[str] = Field(default_factory=list, description="Improvement suggestions")
|
default_factory=list, description="Previous result summaries"
|
||||||
workflowSummary: Optional[str] = Field(None, description="Summarized workflow context")
|
)
|
||||||
messageHistory: List[str] = Field(default_factory=list, description="Key message summaries")
|
improvements: List[str] = Field(
|
||||||
timestamp: float = Field(..., description="When the handover was created (UTC timestamp in seconds)")
|
default_factory=list, description="Improvement suggestions"
|
||||||
handoverType: str = Field(default="task", description="Type of handover: task, phase, or workflow")
|
)
|
||||||
|
workflowSummary: Optional[str] = Field(
|
||||||
|
None, description="Summarized workflow context"
|
||||||
|
)
|
||||||
|
messageHistory: List[str] = Field(
|
||||||
|
default_factory=list, description="Key message summaries"
|
||||||
|
)
|
||||||
|
timestamp: float = Field(
|
||||||
|
..., description="When the handover was created (UTC timestamp in seconds)"
|
||||||
|
)
|
||||||
|
handoverType: str = Field(
|
||||||
|
default="task", description="Type of handover: task, phase, or workflow"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
register_model_labels(
|
register_model_labels(
|
||||||
|
|
@ -310,7 +385,7 @@ register_model_labels(
|
||||||
|
|
||||||
class TaskContext(BaseModel, ModelMixin):
|
class TaskContext(BaseModel, ModelMixin):
|
||||||
task_step: TaskStep
|
task_step: TaskStep
|
||||||
workflow: Optional['ChatWorkflow'] = None
|
workflow: Optional["ChatWorkflow"] = None
|
||||||
workflow_id: Optional[str] = None
|
workflow_id: Optional[str] = None
|
||||||
available_documents: Optional[str] = "No documents available"
|
available_documents: Optional[str] = "No documents available"
|
||||||
available_connections: Optional[list[str]] = Field(default_factory=list)
|
available_connections: Optional[list[str]] = Field(default_factory=list)
|
||||||
|
|
@ -358,7 +433,9 @@ class ReviewResult(BaseModel, ModelMixin):
|
||||||
met_criteria: Optional[list[str]] = Field(default_factory=list)
|
met_criteria: Optional[list[str]] = Field(default_factory=list)
|
||||||
unmet_criteria: Optional[list[str]] = Field(default_factory=list)
|
unmet_criteria: Optional[list[str]] = Field(default_factory=list)
|
||||||
confidence: Optional[float] = 0.5
|
confidence: Optional[float] = 0.5
|
||||||
userMessage: Optional[str] = Field(None, description="User-friendly message in user's language")
|
userMessage: Optional[str] = Field(
|
||||||
|
None, description="User-friendly message in user's language"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
register_model_labels(
|
register_model_labels(
|
||||||
|
|
@ -381,7 +458,9 @@ register_model_labels(
|
||||||
class TaskPlan(BaseModel, ModelMixin):
|
class TaskPlan(BaseModel, ModelMixin):
|
||||||
overview: str
|
overview: str
|
||||||
tasks: list[TaskStep]
|
tasks: list[TaskStep]
|
||||||
userMessage: Optional[str] = Field(None, description="Overall user-friendly message for the task plan")
|
userMessage: Optional[str] = Field(
|
||||||
|
None, description="Overall user-friendly message for the task plan"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
register_model_labels(
|
register_model_labels(
|
||||||
|
|
@ -393,7 +472,3 @@ register_model_labels(
|
||||||
"userMessage": {"en": "User Message", "fr": "Message utilisateur"},
|
"userMessage": {"en": "User Message", "fr": "Message utilisateur"},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
305
modules/features/chatBot/utils/toolRegistry.py
Normal file
305
modules/features/chatBot/utils/toolRegistry.py
Normal file
|
|
@ -0,0 +1,305 @@
|
||||||
|
"""Tool registry for auto-discovering and managing chatbot tools.
|
||||||
|
|
||||||
|
This module provides a central registry that automatically discovers all tools
|
||||||
|
in the chatbotTools directory structure and provides methods to query them.
|
||||||
|
The registry is built in-memory at startup and does not require a database.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import inspect
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
from langchain_core.tools import BaseTool
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ToolMetadata:
|
||||||
|
"""Metadata about a discovered chatbot tool.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
tool_id: Unique identifier (e.g., 'shared.tavily_search')
|
||||||
|
name: Function name of the tool
|
||||||
|
category: Category of the tool ('shared' or 'customer')
|
||||||
|
description: Tool description from docstring
|
||||||
|
tool_instance: The actual LangChain tool instance
|
||||||
|
module_path: Full Python module path
|
||||||
|
"""
|
||||||
|
|
||||||
|
tool_id: str
|
||||||
|
name: str
|
||||||
|
category: str
|
||||||
|
description: str
|
||||||
|
tool_instance: BaseTool
|
||||||
|
module_path: str
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
"""Return a pretty-printed string representation for logging."""
|
||||||
|
return (
|
||||||
|
f"ToolMetadata(\n"
|
||||||
|
f" tool_id='{self.tool_id}',\n"
|
||||||
|
f" name='{self.name}',\n"
|
||||||
|
f" category='{self.category}',\n"
|
||||||
|
f" description='{self.description}',\n"
|
||||||
|
f" module_path='{self.module_path}'\n"
|
||||||
|
f")"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ToolRegistry:
|
||||||
|
"""Central registry for all chatbot tools.
|
||||||
|
|
||||||
|
This class discovers and catalogs all tools decorated with @tool in the
|
||||||
|
chatbotTools directory structure. Tools are automatically discovered at
|
||||||
|
initialization by scanning the filesystem.
|
||||||
|
|
||||||
|
The registry provides methods to query tools by ID, category, or get all tools.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
"""Initialize an empty tool registry."""
|
||||||
|
self._tools: Dict[str, ToolMetadata] = {}
|
||||||
|
self._initialized: bool = False
|
||||||
|
|
||||||
|
def initialize(self) -> None:
|
||||||
|
"""Discover and register all tools from the chatbotTools directory.
|
||||||
|
|
||||||
|
This method scans both sharedTools and customerTools directories,
|
||||||
|
imports all tool*.py modules, and extracts functions decorated with @tool.
|
||||||
|
|
||||||
|
This method is idempotent - calling it multiple times has no effect
|
||||||
|
after the first initialization.
|
||||||
|
"""
|
||||||
|
if self._initialized:
|
||||||
|
logger.debug("Tool registry already initialized, skipping")
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info("Initializing tool registry...")
|
||||||
|
|
||||||
|
# Get base path to chatbotTools directory
|
||||||
|
base_path = Path(__file__).parent.parent / "chatbotTools"
|
||||||
|
|
||||||
|
if not base_path.exists():
|
||||||
|
logger.warning(f"chatbotTools directory not found at {base_path}")
|
||||||
|
self._initialized = True
|
||||||
|
return
|
||||||
|
|
||||||
|
# Discover tools in each category
|
||||||
|
self._discover_category(
|
||||||
|
category_path=base_path / "sharedTools", category="shared"
|
||||||
|
)
|
||||||
|
self._discover_category(
|
||||||
|
category_path=base_path / "customerTools", category="customer"
|
||||||
|
)
|
||||||
|
|
||||||
|
self._initialized = True
|
||||||
|
logger.info(f"Tool registry initialized with {len(self._tools)} tools")
|
||||||
|
|
||||||
|
def _discover_category(self, *, category_path: Path, category: str) -> None:
|
||||||
|
"""Discover all tools in a specific category directory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
category_path: Path to the category directory (sharedTools or customerTools)
|
||||||
|
category: Category name ('shared' or 'customer')
|
||||||
|
"""
|
||||||
|
if not category_path.exists():
|
||||||
|
logger.warning(f"Category directory not found: {category_path}")
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.debug(f"Discovering tools in category: {category}")
|
||||||
|
|
||||||
|
# Find all tool*.py files (excluding __init__.py)
|
||||||
|
tool_files = [
|
||||||
|
f for f in category_path.glob("tool*.py") if f.name != "__init__.py"
|
||||||
|
]
|
||||||
|
|
||||||
|
for tool_file in tool_files:
|
||||||
|
self._import_and_register_tools(
|
||||||
|
tool_file=tool_file, category=category, category_path=category_path
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(f"Discovered {len(tool_files)} tool files in {category}")
|
||||||
|
|
||||||
|
def _import_and_register_tools(
|
||||||
|
self, *, tool_file: Path, category: str, category_path: Path
|
||||||
|
) -> None:
|
||||||
|
"""Import a tool module and register all discovered tools.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_file: Path to the tool Python file
|
||||||
|
category: Category name ('shared' or 'customer')
|
||||||
|
category_path: Path to the category directory
|
||||||
|
"""
|
||||||
|
# Construct module name
|
||||||
|
module_name = (
|
||||||
|
f"modules.features.chatBot.chatbotTools.{category}Tools.{tool_file.stem}"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Import the module
|
||||||
|
module = importlib.import_module(module_name)
|
||||||
|
|
||||||
|
# Find all BaseTool instances in the module
|
||||||
|
tools_found = 0
|
||||||
|
for name, obj in inspect.getmembers(module):
|
||||||
|
if isinstance(obj, BaseTool):
|
||||||
|
self._register_tool(
|
||||||
|
tool_instance=obj,
|
||||||
|
name=name,
|
||||||
|
category=category,
|
||||||
|
module_path=module_name,
|
||||||
|
)
|
||||||
|
tools_found += 1
|
||||||
|
|
||||||
|
if tools_found == 0:
|
||||||
|
logger.warning(f"No tools found in {module_name}")
|
||||||
|
else:
|
||||||
|
logger.debug(f"Loaded {tools_found} tool(s) from {module_name}")
|
||||||
|
|
||||||
|
except ImportError as e:
|
||||||
|
logger.error(
|
||||||
|
f"Import error loading tools from {module_name}: {str(e)}. "
|
||||||
|
f"This tool will not be available."
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Unexpected error loading tools from {module_name}: {type(e).__name__}: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _register_tool(
|
||||||
|
self, *, tool_instance: BaseTool, name: str, category: str, module_path: str
|
||||||
|
) -> None:
|
||||||
|
"""Register a single tool in the registry.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_instance: The LangChain tool instance
|
||||||
|
name: Function name of the tool
|
||||||
|
category: Category name ('shared' or 'customer')
|
||||||
|
module_path: Full Python module path
|
||||||
|
"""
|
||||||
|
tool_id = f"{category}.{name}"
|
||||||
|
|
||||||
|
# Check for duplicate tool IDs
|
||||||
|
if tool_id in self._tools:
|
||||||
|
logger.warning(f"Duplicate tool ID detected: {tool_id}, overwriting")
|
||||||
|
|
||||||
|
metadata = ToolMetadata(
|
||||||
|
tool_id=tool_id,
|
||||||
|
name=name,
|
||||||
|
category=category,
|
||||||
|
description=tool_instance.description or "",
|
||||||
|
tool_instance=tool_instance,
|
||||||
|
module_path=module_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._tools[tool_id] = metadata
|
||||||
|
logger.debug(f"Registered tool: {tool_id}")
|
||||||
|
|
||||||
|
def get_all_tools(self) -> List[ToolMetadata]:
|
||||||
|
"""Get all registered tools.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of all tool metadata objects
|
||||||
|
"""
|
||||||
|
return list(self._tools.values())
|
||||||
|
|
||||||
|
def get_tool(self, *, tool_id: str) -> Optional[ToolMetadata]:
|
||||||
|
"""Get a specific tool by its ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_id: The tool identifier (e.g., 'shared.tavily_search')
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tool metadata if found, None otherwise
|
||||||
|
"""
|
||||||
|
return self._tools.get(tool_id)
|
||||||
|
|
||||||
|
def get_tools_by_category(self, *, category: str) -> List[ToolMetadata]:
|
||||||
|
"""Get all tools in a specific category.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
category: Category name ('shared' or 'customer')
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of tool metadata for the specified category
|
||||||
|
"""
|
||||||
|
return [t for t in self._tools.values() if t.category == category]
|
||||||
|
|
||||||
|
def list_tool_ids(self) -> List[str]:
|
||||||
|
"""Get a list of all registered tool IDs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of tool ID strings
|
||||||
|
"""
|
||||||
|
return list(self._tools.keys())
|
||||||
|
|
||||||
|
def get_tool_instances(self, *, tool_ids: List[str]) -> List[BaseTool]:
|
||||||
|
"""Get actual tool instances for a list of tool IDs.
|
||||||
|
|
||||||
|
This is useful for filtering tools based on user permissions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_ids: List of tool IDs to retrieve
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of BaseTool instances for the specified IDs
|
||||||
|
"""
|
||||||
|
instances = []
|
||||||
|
for tool_id in tool_ids:
|
||||||
|
metadata = self.get_tool(tool_id=tool_id)
|
||||||
|
if metadata:
|
||||||
|
instances.append(metadata.tool_instance)
|
||||||
|
else:
|
||||||
|
logger.warning(f"Tool ID not found in registry: {tool_id}")
|
||||||
|
return instances
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_initialized(self) -> bool:
|
||||||
|
"""Check if the registry has been initialized.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if initialized, False otherwise
|
||||||
|
"""
|
||||||
|
return self._initialized
|
||||||
|
|
||||||
|
|
||||||
|
# Global registry instance
|
||||||
|
_registry: Optional[ToolRegistry] = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_registry() -> ToolRegistry:
|
||||||
|
"""Get the global tool registry instance.
|
||||||
|
|
||||||
|
This function ensures the registry is initialized on first access.
|
||||||
|
Subsequent calls return the same instance.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The global ToolRegistry instance
|
||||||
|
"""
|
||||||
|
global _registry
|
||||||
|
|
||||||
|
if _registry is None:
|
||||||
|
_registry = ToolRegistry()
|
||||||
|
|
||||||
|
if not _registry.is_initialized:
|
||||||
|
_registry.initialize()
|
||||||
|
|
||||||
|
return _registry
|
||||||
|
|
||||||
|
|
||||||
|
def reinitialize_registry() -> ToolRegistry:
|
||||||
|
"""Force reinitialize the tool registry.
|
||||||
|
|
||||||
|
This is useful for testing or when tools are added dynamically.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The reinitialized ToolRegistry instance
|
||||||
|
"""
|
||||||
|
global _registry
|
||||||
|
_registry = ToolRegistry()
|
||||||
|
_registry.initialize()
|
||||||
|
return _registry
|
||||||
198
tests/features/chatBot/utils/test_toolRegistry.py
Normal file
198
tests/features/chatBot/utils/test_toolRegistry.py
Normal file
|
|
@ -0,0 +1,198 @@
|
||||||
|
"""Pytest tests for the tool registry.
|
||||||
|
|
||||||
|
This module tests that the tool registry correctly discovers and catalogs
|
||||||
|
all tools in the chatbotTools directory.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import pytest
|
||||||
|
from modules.features.chatBot.utils.toolRegistry import (
|
||||||
|
ToolMetadata,
|
||||||
|
ToolRegistry,
|
||||||
|
get_registry,
|
||||||
|
reinitialize_registry,
|
||||||
|
)
|
||||||
|
from langchain_core.tools import BaseTool
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TestToolRegistry:
|
||||||
|
"""Test suite for ToolRegistry class."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def registry(self) -> ToolRegistry:
|
||||||
|
"""Provide a fresh registry instance for each test."""
|
||||||
|
return reinitialize_registry()
|
||||||
|
|
||||||
|
def test_registry_initialization(self, registry: ToolRegistry) -> None:
|
||||||
|
"""Test that registry initializes correctly."""
|
||||||
|
assert registry.is_initialized
|
||||||
|
assert isinstance(registry._tools, dict)
|
||||||
|
|
||||||
|
def test_get_all_tools(self, registry: ToolRegistry) -> None:
|
||||||
|
"""Test getting all registered tools."""
|
||||||
|
all_tools = registry.get_all_tools()
|
||||||
|
assert isinstance(all_tools, list)
|
||||||
|
assert len(all_tools) > 0
|
||||||
|
assert all(isinstance(tool, ToolMetadata) for tool in all_tools)
|
||||||
|
|
||||||
|
# Log all discovered tools
|
||||||
|
logger.info(f"Found {len(all_tools)} tools in registry:")
|
||||||
|
for tool in all_tools:
|
||||||
|
logger.info(f"\n{tool}")
|
||||||
|
|
||||||
|
def test_tool_metadata_structure(self, registry: ToolRegistry) -> None:
|
||||||
|
"""Test that tool metadata has correct structure."""
|
||||||
|
all_tools = registry.get_all_tools()
|
||||||
|
for tool in all_tools:
|
||||||
|
assert isinstance(tool.tool_id, str)
|
||||||
|
assert isinstance(tool.name, str)
|
||||||
|
assert isinstance(tool.category, str)
|
||||||
|
assert tool.category in ["shared", "customer"]
|
||||||
|
assert isinstance(tool.description, str)
|
||||||
|
assert isinstance(tool.tool_instance, BaseTool)
|
||||||
|
assert isinstance(tool.module_path, str)
|
||||||
|
|
||||||
|
def test_list_tool_ids(self, registry: ToolRegistry) -> None:
|
||||||
|
"""Test listing all tool IDs."""
|
||||||
|
tool_ids = registry.list_tool_ids()
|
||||||
|
assert isinstance(tool_ids, list)
|
||||||
|
assert len(tool_ids) > 0
|
||||||
|
assert all(isinstance(tool_id, str) for tool_id in tool_ids)
|
||||||
|
|
||||||
|
# Check that tool IDs follow expected format
|
||||||
|
for tool_id in tool_ids:
|
||||||
|
assert "." in tool_id
|
||||||
|
category, name = tool_id.split(".", 1)
|
||||||
|
assert category in ["shared", "customer"]
|
||||||
|
|
||||||
|
def test_get_specific_tool(self, registry: ToolRegistry) -> None:
|
||||||
|
"""Test retrieving a specific tool by ID."""
|
||||||
|
# Get all tool IDs first
|
||||||
|
tool_ids = registry.list_tool_ids()
|
||||||
|
if tool_ids:
|
||||||
|
# Test with first available tool
|
||||||
|
test_tool_id = tool_ids[0]
|
||||||
|
tool_metadata = registry.get_tool(tool_id=test_tool_id)
|
||||||
|
|
||||||
|
assert tool_metadata is not None
|
||||||
|
assert isinstance(tool_metadata, ToolMetadata)
|
||||||
|
assert tool_metadata.tool_id == test_tool_id
|
||||||
|
|
||||||
|
def test_get_nonexistent_tool(self, registry: ToolRegistry) -> None:
|
||||||
|
"""Test retrieving a tool that doesn't exist."""
|
||||||
|
tool_metadata = registry.get_tool(tool_id="nonexistent.tool")
|
||||||
|
assert tool_metadata is None
|
||||||
|
|
||||||
|
def test_get_tools_by_category_shared(self, registry: ToolRegistry) -> None:
|
||||||
|
"""Test getting all shared tools."""
|
||||||
|
shared_tools = registry.get_tools_by_category(category="shared")
|
||||||
|
assert isinstance(shared_tools, list)
|
||||||
|
assert all(tool.category == "shared" for tool in shared_tools)
|
||||||
|
|
||||||
|
def test_get_tools_by_category_customer(self, registry: ToolRegistry) -> None:
|
||||||
|
"""Test getting all customer tools."""
|
||||||
|
customer_tools = registry.get_tools_by_category(category="customer")
|
||||||
|
assert isinstance(customer_tools, list)
|
||||||
|
assert all(tool.category == "customer" for tool in customer_tools)
|
||||||
|
|
||||||
|
def test_get_tool_instances(self, registry: ToolRegistry) -> None:
|
||||||
|
"""Test getting tool instances by IDs."""
|
||||||
|
tool_ids = registry.list_tool_ids()
|
||||||
|
if len(tool_ids) >= 2:
|
||||||
|
# Test with first two tools
|
||||||
|
test_ids = tool_ids[:2]
|
||||||
|
instances = registry.get_tool_instances(tool_ids=test_ids)
|
||||||
|
|
||||||
|
assert isinstance(instances, list)
|
||||||
|
assert len(instances) == 2
|
||||||
|
assert all(isinstance(inst, BaseTool) for inst in instances)
|
||||||
|
|
||||||
|
def test_get_tool_instances_with_invalid_id(self, registry: ToolRegistry) -> None:
|
||||||
|
"""Test getting tool instances with some invalid IDs."""
|
||||||
|
tool_ids = registry.list_tool_ids()
|
||||||
|
if tool_ids:
|
||||||
|
# Mix valid and invalid IDs
|
||||||
|
test_ids = [tool_ids[0], "invalid.tool"]
|
||||||
|
instances = registry.get_tool_instances(tool_ids=test_ids)
|
||||||
|
|
||||||
|
# Should only return the valid one
|
||||||
|
assert len(instances) == 1
|
||||||
|
assert isinstance(instances[0], BaseTool)
|
||||||
|
|
||||||
|
def test_global_registry_singleton(self) -> None:
|
||||||
|
"""Test that get_registry returns same instance."""
|
||||||
|
registry1 = get_registry()
|
||||||
|
registry2 = get_registry()
|
||||||
|
assert registry1 is registry2
|
||||||
|
|
||||||
|
def test_reinitialize_registry(self) -> None:
|
||||||
|
"""Test that reinitialize creates new instance."""
|
||||||
|
registry1 = get_registry()
|
||||||
|
registry2 = reinitialize_registry()
|
||||||
|
# Should be different instances after reinitialize
|
||||||
|
assert registry1 is not registry2
|
||||||
|
assert registry2.is_initialized
|
||||||
|
|
||||||
|
|
||||||
|
class TestToolDiscovery:
|
||||||
|
"""Test suite for tool discovery functionality."""
|
||||||
|
|
||||||
|
def test_discovers_at_least_one_tool(self) -> None:
|
||||||
|
"""Test that at least one tool is discovered."""
|
||||||
|
registry = get_registry()
|
||||||
|
tool_ids = registry.list_tool_ids()
|
||||||
|
|
||||||
|
# At least one tool should be successfully loaded
|
||||||
|
assert len(tool_ids) >= 1, "Expected at least one tool to be discovered"
|
||||||
|
|
||||||
|
def test_query_althaus_database_if_available(self) -> None:
|
||||||
|
"""Test query_althaus_database tool if it was successfully loaded."""
|
||||||
|
registry = get_registry()
|
||||||
|
tool = registry.get_tool(tool_id="customer.query_althaus_database")
|
||||||
|
|
||||||
|
if tool is not None:
|
||||||
|
assert tool.name == "query_althaus_database"
|
||||||
|
assert tool.category == "customer"
|
||||||
|
assert "database" in tool.description.lower()
|
||||||
|
else:
|
||||||
|
# Tool may not have loaded due to import errors - log warning
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logging.warning(
|
||||||
|
"customer.query_althaus_database tool not found - "
|
||||||
|
"may have failed to import"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_tavily_search_if_available(self) -> None:
|
||||||
|
"""Test tavily_search tool if it was successfully loaded."""
|
||||||
|
registry = get_registry()
|
||||||
|
tool = registry.get_tool(tool_id="shared.tavily_search")
|
||||||
|
|
||||||
|
if tool is not None:
|
||||||
|
assert tool.name == "tavily_search"
|
||||||
|
assert tool.category == "shared"
|
||||||
|
assert "search" in tool.description.lower()
|
||||||
|
else:
|
||||||
|
# Tool may not have loaded due to import errors - log warning
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logging.warning(
|
||||||
|
"shared.tavily_search tool not found - may have failed to import"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_tool_ids_have_correct_format(self) -> None:
|
||||||
|
"""Test that all discovered tool IDs follow the expected format."""
|
||||||
|
registry = get_registry()
|
||||||
|
tool_ids = registry.list_tool_ids()
|
||||||
|
|
||||||
|
for tool_id in tool_ids:
|
||||||
|
# All tool IDs should have format: category.toolname
|
||||||
|
assert "." in tool_id, f"Tool ID {tool_id} missing category separator"
|
||||||
|
category, name = tool_id.split(".", 1)
|
||||||
|
assert category in [
|
||||||
|
"shared",
|
||||||
|
"customer",
|
||||||
|
], f"Tool {tool_id} has invalid category: {category}"
|
||||||
|
assert len(name) > 0, f"Tool {tool_id} has empty name"
|
||||||
Loading…
Reference in a new issue