diff --git a/modules/datamodels/datamodelWorkflow.py b/modules/datamodels/datamodelWorkflow.py index 0ff2dcca..686144c3 100644 --- a/modules/datamodels/datamodelWorkflow.py +++ b/modules/datamodels/datamodelWorkflow.py @@ -1,5 +1,6 @@ """Workflow-related base datamodels and step/task structures.""" +from enum import Enum from typing import List, Dict, Any, Optional from pydantic import BaseModel, Field from modules.shared.attributeUtils import register_model_labels, ModelMixin @@ -7,9 +8,12 @@ from modules.shared.attributeUtils import register_model_labels, ModelMixin class ActionDocument(BaseModel, ModelMixin): """Clear document structure for action results""" + documentName: str = Field(description="Name of the document") documentData: Any = Field(description="Content/data of the document") mimeType: str = Field(description="MIME type of the document") + + register_model_labels( "ActionDocument", {"en": "Action Document", "fr": "Document d'action"}, @@ -31,16 +35,25 @@ class ActionResult(BaseModel, ModelMixin): success: bool = Field(description="Whether execution succeeded") error: Optional[str] = Field(None, description="Error message if failed") - documents: List[ActionDocument] = Field(default_factory=list, description="Document outputs") - resultLabel: Optional[str] = Field(None, description="Label for document routing (set by action handler, not by action methods)") + documents: List[ActionDocument] = Field( + default_factory=list, description="Document outputs" + ) + resultLabel: Optional[str] = Field( + None, + description="Label for document routing (set by action handler, not by action methods)", + ) @classmethod def isSuccess(cls, documents: List[ActionDocument] = None) -> "ActionResult": return cls(success=True, documents=documents or []) @classmethod - def isFailure(cls, error: str, documents: List[ActionDocument] = None) -> "ActionResult": + def isFailure( + cls, error: str, documents: List[ActionDocument] = None + ) -> "ActionResult": return cls(success=False, documents=documents or [], error=error) + + register_model_labels( "ActionResult", {"en": "Action Result", "fr": "Résultat de l'action"}, @@ -55,7 +68,9 @@ register_model_labels( class ActionSelection(BaseModel, ModelMixin): method: str = Field(description="Method to execute (e.g., web, document, ai)") - name: str = Field(description="Action name within the method (e.g., search, extract)") + name: str = Field( + description="Action name within the method (e.g., search, extract)" + ) register_model_labels( @@ -69,7 +84,9 @@ register_model_labels( class ActionParameters(BaseModel, ModelMixin): - parameters: Dict[str, Any] = Field(default_factory=dict, description="Parameters to execute the selected action") + parameters: Dict[str, Any] = Field( + default_factory=dict, description="Parameters to execute the selected action" + ) register_model_labels( @@ -102,8 +119,12 @@ class Observation(BaseModel, ModelMixin): success: bool = Field(description="Action execution success flag") resultLabel: str = Field(description="Deterministic label for produced documents") documentsCount: int = Field(description="Number of produced documents") - previews: List[ObservationPreview] = Field(default_factory=list, description="Compact previews of outputs") - notes: List[str] = Field(default_factory=list, description="Short notes or key facts") + previews: List[ObservationPreview] = Field( + default_factory=list, description="Compact previews of outputs" + ) + notes: List[str] = Field( + default_factory=list, description="Short notes or key facts" + ) register_model_labels( @@ -119,7 +140,9 @@ register_model_labels( ) -class TaskStatus(str): +class TaskStatus(str, Enum): + """Task status enumeration.""" + PENDING = "pending" RUNNING = "running" COMPLETED = "completed" @@ -142,7 +165,9 @@ register_model_labels( class DocumentExchange(BaseModel, ModelMixin): documentsLabel: str = Field(description="Label for the set of documents") - documents: List[str] = Field(default_factory=list, description="List of document references") + documents: List[str] = Field( + default_factory=list, description="List of document references" + ) register_model_labels( @@ -159,16 +184,28 @@ class TaskAction(BaseModel, ModelMixin): id: str = Field(..., description="Action ID") execMethod: str = Field(..., description="Method to execute") execAction: str = Field(..., description="Action to perform") - execParameters: Dict[str, Any] = Field(default_factory=dict, description="Action parameters") - execResultLabel: Optional[str] = Field(None, description="Label for the set of result documents") - expectedDocumentFormats: Optional[List[Dict[str, str]]] = Field(None, description="Expected document formats (optional)") - userMessage: Optional[str] = Field(None, description="User-friendly message in user's language") + execParameters: Dict[str, Any] = Field( + default_factory=dict, description="Action parameters" + ) + execResultLabel: Optional[str] = Field( + None, description="Label for the set of result documents" + ) + expectedDocumentFormats: Optional[List[Dict[str, str]]] = Field( + None, description="Expected document formats (optional)" + ) + userMessage: Optional[str] = Field( + None, description="User-friendly message in user's language" + ) status: TaskStatus = Field(default=TaskStatus.PENDING, description="Action status") error: Optional[str] = Field(None, description="Error message if action failed") retryCount: int = Field(default=0, description="Number of retries attempted") retryMax: int = Field(default=3, description="Maximum number of retries") - processingTime: Optional[float] = Field(None, description="Processing time in seconds") - timestamp: float = Field(..., description="When the action was executed (UTC timestamp in seconds)") + processingTime: Optional[float] = Field( + None, description="Processing time in seconds" + ) + timestamp: float = Field( + ..., description="When the action was executed (UTC timestamp in seconds)" + ) result: Optional[str] = Field(None, description="Result of the action") @@ -181,7 +218,10 @@ register_model_labels( "execAction": {"en": "Action", "fr": "Action"}, "execParameters": {"en": "Parameters", "fr": "Paramètres"}, "execResultLabel": {"en": "Result Label", "fr": "Label du résultat"}, - "expectedDocumentFormats": {"en": "Expected Document Formats", "fr": "Formats de documents attendus"}, + "expectedDocumentFormats": { + "en": "Expected Document Formats", + "fr": "Formats de documents attendus", + }, "userMessage": {"en": "User Message", "fr": "Message utilisateur"}, "status": {"en": "Status", "fr": "Statut"}, "error": {"en": "Error", "fr": "Erreur"}, @@ -221,16 +261,30 @@ class TaskItem(BaseModel, ModelMixin): userInput: str = Field(..., description="User input that triggered the task") status: TaskStatus = Field(default=TaskStatus.PENDING, description="Task status") error: Optional[str] = Field(None, description="Error message if task failed") - startedAt: Optional[float] = Field(None, description="When the task started (UTC timestamp in seconds)") - finishedAt: Optional[float] = Field(None, description="When the task finished (UTC timestamp in seconds)") - actionList: List[TaskAction] = Field(default_factory=list, description="List of actions to execute") + startedAt: Optional[float] = Field( + None, description="When the task started (UTC timestamp in seconds)" + ) + finishedAt: Optional[float] = Field( + None, description="When the task finished (UTC timestamp in seconds)" + ) + actionList: List[TaskAction] = Field( + default_factory=list, description="List of actions to execute" + ) retryCount: int = Field(default=0, description="Number of retries attempted") retryMax: int = Field(default=3, description="Maximum number of retries") - rollbackOnFailure: bool = Field(default=True, description="Whether to rollback on failure") - dependencies: List[str] = Field(default_factory=list, description="List of task IDs this task depends on") + rollbackOnFailure: bool = Field( + default=True, description="Whether to rollback on failure" + ) + dependencies: List[str] = Field( + default_factory=list, description="List of task IDs this task depends on" + ) feedback: Optional[str] = Field(None, description="Task feedback message") - processingTime: Optional[float] = Field(None, description="Total processing time in seconds") - resultLabels: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Map of result labels to their values") + processingTime: Optional[float] = Field( + None, description="Total processing time in seconds" + ) + resultLabels: Optional[Dict[str, Any]] = Field( + default_factory=dict, description="Map of result labels to their values" + ) register_model_labels( @@ -258,7 +312,9 @@ class TaskStep(BaseModel, ModelMixin): dependencies: Optional[list[str]] = Field(default_factory=list) success_criteria: Optional[list[str]] = Field(default_factory=list) estimated_complexity: Optional[str] = None - userMessage: Optional[str] = Field(None, description="User-friendly message in user's language") + userMessage: Optional[str] = Field( + None, description="User-friendly message in user's language" + ) register_model_labels( @@ -269,7 +325,10 @@ register_model_labels( "objective": {"en": "Objective", "fr": "Objectif"}, "dependencies": {"en": "Dependencies", "fr": "Dépendances"}, "success_criteria": {"en": "Success Criteria", "fr": "Critères de succès"}, - "estimated_complexity": {"en": "Estimated Complexity", "fr": "Complexité estimée"}, + "estimated_complexity": { + "en": "Estimated Complexity", + "fr": "Complexité estimée", + }, "userMessage": {"en": "User Message", "fr": "Message utilisateur"}, }, ) @@ -278,15 +337,31 @@ register_model_labels( class TaskHandover(BaseModel, ModelMixin): taskId: str = Field(description="Target task ID") sourceTask: Optional[str] = Field(None, description="Source task ID") - inputDocuments: List[DocumentExchange] = Field(default_factory=list, description="Available input documents") - outputDocuments: List[DocumentExchange] = Field(default_factory=list, description="Produced output documents") + inputDocuments: List[DocumentExchange] = Field( + default_factory=list, description="Available input documents" + ) + outputDocuments: List[DocumentExchange] = Field( + default_factory=list, description="Produced output documents" + ) context: Dict[str, Any] = Field(default_factory=dict, description="Task context") - previousResults: List[str] = Field(default_factory=list, description="Previous result summaries") - improvements: List[str] = Field(default_factory=list, description="Improvement suggestions") - workflowSummary: Optional[str] = Field(None, description="Summarized workflow context") - messageHistory: List[str] = Field(default_factory=list, description="Key message summaries") - timestamp: float = Field(..., description="When the handover was created (UTC timestamp in seconds)") - handoverType: str = Field(default="task", description="Type of handover: task, phase, or workflow") + previousResults: List[str] = Field( + default_factory=list, description="Previous result summaries" + ) + improvements: List[str] = Field( + default_factory=list, description="Improvement suggestions" + ) + workflowSummary: Optional[str] = Field( + None, description="Summarized workflow context" + ) + messageHistory: List[str] = Field( + default_factory=list, description="Key message summaries" + ) + timestamp: float = Field( + ..., description="When the handover was created (UTC timestamp in seconds)" + ) + handoverType: str = Field( + default="task", description="Type of handover: task, phase, or workflow" + ) register_model_labels( @@ -310,7 +385,7 @@ register_model_labels( class TaskContext(BaseModel, ModelMixin): task_step: TaskStep - workflow: Optional['ChatWorkflow'] = None + workflow: Optional["ChatWorkflow"] = None workflow_id: Optional[str] = None available_documents: Optional[str] = "No documents available" available_connections: Optional[list[str]] = Field(default_factory=list) @@ -358,7 +433,9 @@ class ReviewResult(BaseModel, ModelMixin): met_criteria: Optional[list[str]] = Field(default_factory=list) unmet_criteria: Optional[list[str]] = Field(default_factory=list) confidence: Optional[float] = 0.5 - userMessage: Optional[str] = Field(None, description="User-friendly message in user's language") + userMessage: Optional[str] = Field( + None, description="User-friendly message in user's language" + ) register_model_labels( @@ -381,7 +458,9 @@ register_model_labels( class TaskPlan(BaseModel, ModelMixin): overview: str tasks: list[TaskStep] - userMessage: Optional[str] = Field(None, description="Overall user-friendly message for the task plan") + userMessage: Optional[str] = Field( + None, description="Overall user-friendly message for the task plan" + ) register_model_labels( @@ -393,7 +472,3 @@ register_model_labels( "userMessage": {"en": "User Message", "fr": "Message utilisateur"}, }, ) - - - - diff --git a/modules/features/chatBot/utils/toolRegistry.py b/modules/features/chatBot/utils/toolRegistry.py new file mode 100644 index 00000000..5f5d14d6 --- /dev/null +++ b/modules/features/chatBot/utils/toolRegistry.py @@ -0,0 +1,305 @@ +"""Tool registry for auto-discovering and managing chatbot tools. + +This module provides a central registry that automatically discovers all tools +in the chatbotTools directory structure and provides methods to query them. +The registry is built in-memory at startup and does not require a database. +""" + +import importlib +import inspect +import logging +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, List, Optional + +from langchain_core.tools import BaseTool + +logger = logging.getLogger(__name__) + + +@dataclass +class ToolMetadata: + """Metadata about a discovered chatbot tool. + + Attributes: + tool_id: Unique identifier (e.g., 'shared.tavily_search') + name: Function name of the tool + category: Category of the tool ('shared' or 'customer') + description: Tool description from docstring + tool_instance: The actual LangChain tool instance + module_path: Full Python module path + """ + + tool_id: str + name: str + category: str + description: str + tool_instance: BaseTool + module_path: str + + def __str__(self) -> str: + """Return a pretty-printed string representation for logging.""" + return ( + f"ToolMetadata(\n" + f" tool_id='{self.tool_id}',\n" + f" name='{self.name}',\n" + f" category='{self.category}',\n" + f" description='{self.description}',\n" + f" module_path='{self.module_path}'\n" + f")" + ) + + +class ToolRegistry: + """Central registry for all chatbot tools. + + This class discovers and catalogs all tools decorated with @tool in the + chatbotTools directory structure. Tools are automatically discovered at + initialization by scanning the filesystem. + + The registry provides methods to query tools by ID, category, or get all tools. + """ + + def __init__(self) -> None: + """Initialize an empty tool registry.""" + self._tools: Dict[str, ToolMetadata] = {} + self._initialized: bool = False + + def initialize(self) -> None: + """Discover and register all tools from the chatbotTools directory. + + This method scans both sharedTools and customerTools directories, + imports all tool*.py modules, and extracts functions decorated with @tool. + + This method is idempotent - calling it multiple times has no effect + after the first initialization. + """ + if self._initialized: + logger.debug("Tool registry already initialized, skipping") + return + + logger.info("Initializing tool registry...") + + # Get base path to chatbotTools directory + base_path = Path(__file__).parent.parent / "chatbotTools" + + if not base_path.exists(): + logger.warning(f"chatbotTools directory not found at {base_path}") + self._initialized = True + return + + # Discover tools in each category + self._discover_category( + category_path=base_path / "sharedTools", category="shared" + ) + self._discover_category( + category_path=base_path / "customerTools", category="customer" + ) + + self._initialized = True + logger.info(f"Tool registry initialized with {len(self._tools)} tools") + + def _discover_category(self, *, category_path: Path, category: str) -> None: + """Discover all tools in a specific category directory. + + Args: + category_path: Path to the category directory (sharedTools or customerTools) + category: Category name ('shared' or 'customer') + """ + if not category_path.exists(): + logger.warning(f"Category directory not found: {category_path}") + return + + logger.debug(f"Discovering tools in category: {category}") + + # Find all tool*.py files (excluding __init__.py) + tool_files = [ + f for f in category_path.glob("tool*.py") if f.name != "__init__.py" + ] + + for tool_file in tool_files: + self._import_and_register_tools( + tool_file=tool_file, category=category, category_path=category_path + ) + + logger.debug(f"Discovered {len(tool_files)} tool files in {category}") + + def _import_and_register_tools( + self, *, tool_file: Path, category: str, category_path: Path + ) -> None: + """Import a tool module and register all discovered tools. + + Args: + tool_file: Path to the tool Python file + category: Category name ('shared' or 'customer') + category_path: Path to the category directory + """ + # Construct module name + module_name = ( + f"modules.features.chatBot.chatbotTools.{category}Tools.{tool_file.stem}" + ) + + try: + # Import the module + module = importlib.import_module(module_name) + + # Find all BaseTool instances in the module + tools_found = 0 + for name, obj in inspect.getmembers(module): + if isinstance(obj, BaseTool): + self._register_tool( + tool_instance=obj, + name=name, + category=category, + module_path=module_name, + ) + tools_found += 1 + + if tools_found == 0: + logger.warning(f"No tools found in {module_name}") + else: + logger.debug(f"Loaded {tools_found} tool(s) from {module_name}") + + except ImportError as e: + logger.error( + f"Import error loading tools from {module_name}: {str(e)}. " + f"This tool will not be available." + ) + except Exception as e: + logger.error( + f"Unexpected error loading tools from {module_name}: {type(e).__name__}: {str(e)}" + ) + + def _register_tool( + self, *, tool_instance: BaseTool, name: str, category: str, module_path: str + ) -> None: + """Register a single tool in the registry. + + Args: + tool_instance: The LangChain tool instance + name: Function name of the tool + category: Category name ('shared' or 'customer') + module_path: Full Python module path + """ + tool_id = f"{category}.{name}" + + # Check for duplicate tool IDs + if tool_id in self._tools: + logger.warning(f"Duplicate tool ID detected: {tool_id}, overwriting") + + metadata = ToolMetadata( + tool_id=tool_id, + name=name, + category=category, + description=tool_instance.description or "", + tool_instance=tool_instance, + module_path=module_path, + ) + + self._tools[tool_id] = metadata + logger.debug(f"Registered tool: {tool_id}") + + def get_all_tools(self) -> List[ToolMetadata]: + """Get all registered tools. + + Returns: + List of all tool metadata objects + """ + return list(self._tools.values()) + + def get_tool(self, *, tool_id: str) -> Optional[ToolMetadata]: + """Get a specific tool by its ID. + + Args: + tool_id: The tool identifier (e.g., 'shared.tavily_search') + + Returns: + Tool metadata if found, None otherwise + """ + return self._tools.get(tool_id) + + def get_tools_by_category(self, *, category: str) -> List[ToolMetadata]: + """Get all tools in a specific category. + + Args: + category: Category name ('shared' or 'customer') + + Returns: + List of tool metadata for the specified category + """ + return [t for t in self._tools.values() if t.category == category] + + def list_tool_ids(self) -> List[str]: + """Get a list of all registered tool IDs. + + Returns: + List of tool ID strings + """ + return list(self._tools.keys()) + + def get_tool_instances(self, *, tool_ids: List[str]) -> List[BaseTool]: + """Get actual tool instances for a list of tool IDs. + + This is useful for filtering tools based on user permissions. + + Args: + tool_ids: List of tool IDs to retrieve + + Returns: + List of BaseTool instances for the specified IDs + """ + instances = [] + for tool_id in tool_ids: + metadata = self.get_tool(tool_id=tool_id) + if metadata: + instances.append(metadata.tool_instance) + else: + logger.warning(f"Tool ID not found in registry: {tool_id}") + return instances + + @property + def is_initialized(self) -> bool: + """Check if the registry has been initialized. + + Returns: + True if initialized, False otherwise + """ + return self._initialized + + +# Global registry instance +_registry: Optional[ToolRegistry] = None + + +def get_registry() -> ToolRegistry: + """Get the global tool registry instance. + + This function ensures the registry is initialized on first access. + Subsequent calls return the same instance. + + Returns: + The global ToolRegistry instance + """ + global _registry + + if _registry is None: + _registry = ToolRegistry() + + if not _registry.is_initialized: + _registry.initialize() + + return _registry + + +def reinitialize_registry() -> ToolRegistry: + """Force reinitialize the tool registry. + + This is useful for testing or when tools are added dynamically. + + Returns: + The reinitialized ToolRegistry instance + """ + global _registry + _registry = ToolRegistry() + _registry.initialize() + return _registry diff --git a/tests/features/chatBot/utils/test_toolRegistry.py b/tests/features/chatBot/utils/test_toolRegistry.py new file mode 100644 index 00000000..219752b7 --- /dev/null +++ b/tests/features/chatBot/utils/test_toolRegistry.py @@ -0,0 +1,198 @@ +"""Pytest tests for the tool registry. + +This module tests that the tool registry correctly discovers and catalogs +all tools in the chatbotTools directory. +""" + +import logging +import pytest +from modules.features.chatBot.utils.toolRegistry import ( + ToolMetadata, + ToolRegistry, + get_registry, + reinitialize_registry, +) +from langchain_core.tools import BaseTool + +logger = logging.getLogger(__name__) + + +class TestToolRegistry: + """Test suite for ToolRegistry class.""" + + @pytest.fixture + def registry(self) -> ToolRegistry: + """Provide a fresh registry instance for each test.""" + return reinitialize_registry() + + def test_registry_initialization(self, registry: ToolRegistry) -> None: + """Test that registry initializes correctly.""" + assert registry.is_initialized + assert isinstance(registry._tools, dict) + + def test_get_all_tools(self, registry: ToolRegistry) -> None: + """Test getting all registered tools.""" + all_tools = registry.get_all_tools() + assert isinstance(all_tools, list) + assert len(all_tools) > 0 + assert all(isinstance(tool, ToolMetadata) for tool in all_tools) + + # Log all discovered tools + logger.info(f"Found {len(all_tools)} tools in registry:") + for tool in all_tools: + logger.info(f"\n{tool}") + + def test_tool_metadata_structure(self, registry: ToolRegistry) -> None: + """Test that tool metadata has correct structure.""" + all_tools = registry.get_all_tools() + for tool in all_tools: + assert isinstance(tool.tool_id, str) + assert isinstance(tool.name, str) + assert isinstance(tool.category, str) + assert tool.category in ["shared", "customer"] + assert isinstance(tool.description, str) + assert isinstance(tool.tool_instance, BaseTool) + assert isinstance(tool.module_path, str) + + def test_list_tool_ids(self, registry: ToolRegistry) -> None: + """Test listing all tool IDs.""" + tool_ids = registry.list_tool_ids() + assert isinstance(tool_ids, list) + assert len(tool_ids) > 0 + assert all(isinstance(tool_id, str) for tool_id in tool_ids) + + # Check that tool IDs follow expected format + for tool_id in tool_ids: + assert "." in tool_id + category, name = tool_id.split(".", 1) + assert category in ["shared", "customer"] + + def test_get_specific_tool(self, registry: ToolRegistry) -> None: + """Test retrieving a specific tool by ID.""" + # Get all tool IDs first + tool_ids = registry.list_tool_ids() + if tool_ids: + # Test with first available tool + test_tool_id = tool_ids[0] + tool_metadata = registry.get_tool(tool_id=test_tool_id) + + assert tool_metadata is not None + assert isinstance(tool_metadata, ToolMetadata) + assert tool_metadata.tool_id == test_tool_id + + def test_get_nonexistent_tool(self, registry: ToolRegistry) -> None: + """Test retrieving a tool that doesn't exist.""" + tool_metadata = registry.get_tool(tool_id="nonexistent.tool") + assert tool_metadata is None + + def test_get_tools_by_category_shared(self, registry: ToolRegistry) -> None: + """Test getting all shared tools.""" + shared_tools = registry.get_tools_by_category(category="shared") + assert isinstance(shared_tools, list) + assert all(tool.category == "shared" for tool in shared_tools) + + def test_get_tools_by_category_customer(self, registry: ToolRegistry) -> None: + """Test getting all customer tools.""" + customer_tools = registry.get_tools_by_category(category="customer") + assert isinstance(customer_tools, list) + assert all(tool.category == "customer" for tool in customer_tools) + + def test_get_tool_instances(self, registry: ToolRegistry) -> None: + """Test getting tool instances by IDs.""" + tool_ids = registry.list_tool_ids() + if len(tool_ids) >= 2: + # Test with first two tools + test_ids = tool_ids[:2] + instances = registry.get_tool_instances(tool_ids=test_ids) + + assert isinstance(instances, list) + assert len(instances) == 2 + assert all(isinstance(inst, BaseTool) for inst in instances) + + def test_get_tool_instances_with_invalid_id(self, registry: ToolRegistry) -> None: + """Test getting tool instances with some invalid IDs.""" + tool_ids = registry.list_tool_ids() + if tool_ids: + # Mix valid and invalid IDs + test_ids = [tool_ids[0], "invalid.tool"] + instances = registry.get_tool_instances(tool_ids=test_ids) + + # Should only return the valid one + assert len(instances) == 1 + assert isinstance(instances[0], BaseTool) + + def test_global_registry_singleton(self) -> None: + """Test that get_registry returns same instance.""" + registry1 = get_registry() + registry2 = get_registry() + assert registry1 is registry2 + + def test_reinitialize_registry(self) -> None: + """Test that reinitialize creates new instance.""" + registry1 = get_registry() + registry2 = reinitialize_registry() + # Should be different instances after reinitialize + assert registry1 is not registry2 + assert registry2.is_initialized + + +class TestToolDiscovery: + """Test suite for tool discovery functionality.""" + + def test_discovers_at_least_one_tool(self) -> None: + """Test that at least one tool is discovered.""" + registry = get_registry() + tool_ids = registry.list_tool_ids() + + # At least one tool should be successfully loaded + assert len(tool_ids) >= 1, "Expected at least one tool to be discovered" + + def test_query_althaus_database_if_available(self) -> None: + """Test query_althaus_database tool if it was successfully loaded.""" + registry = get_registry() + tool = registry.get_tool(tool_id="customer.query_althaus_database") + + if tool is not None: + assert tool.name == "query_althaus_database" + assert tool.category == "customer" + assert "database" in tool.description.lower() + else: + # Tool may not have loaded due to import errors - log warning + import logging + + logging.warning( + "customer.query_althaus_database tool not found - " + "may have failed to import" + ) + + def test_tavily_search_if_available(self) -> None: + """Test tavily_search tool if it was successfully loaded.""" + registry = get_registry() + tool = registry.get_tool(tool_id="shared.tavily_search") + + if tool is not None: + assert tool.name == "tavily_search" + assert tool.category == "shared" + assert "search" in tool.description.lower() + else: + # Tool may not have loaded due to import errors - log warning + import logging + + logging.warning( + "shared.tavily_search tool not found - may have failed to import" + ) + + def test_tool_ids_have_correct_format(self) -> None: + """Test that all discovered tool IDs follow the expected format.""" + registry = get_registry() + tool_ids = registry.list_tool_ids() + + for tool_id in tool_ids: + # All tool IDs should have format: category.toolname + assert "." in tool_id, f"Tool ID {tool_id} missing category separator" + category, name = tool_id.split(".", 1) + assert category in [ + "shared", + "customer", + ], f"Tool {tool_id} has invalid category: {category}" + assert len(name) > 0, f"Tool {tool_id} has empty name"