198 lines
7.7 KiB
Python
198 lines
7.7 KiB
Python
"""Pytest tests for the tool registry.
|
|
|
|
This module tests that the tool registry correctly discovers and catalogs
|
|
all tools in the chatbotTools directory.
|
|
"""
|
|
|
|
import logging
|
|
import pytest
|
|
from modules.features.chatBot.utils.toolRegistry import (
|
|
ToolMetadata,
|
|
ToolRegistry,
|
|
get_registry,
|
|
reinitialize_registry,
|
|
)
|
|
from langchain_core.tools import BaseTool
|
|
|
|
# Module-level logger; the tests below use it to report what the registry found.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class TestToolRegistry:
    """Test suite for the ToolRegistry class and its module-level accessors."""

    # Categories a tool's ``category`` field and tool-ID prefix may take.
    VALID_CATEGORIES = ("shared", "customer")

    @pytest.fixture
    def registry(self) -> ToolRegistry:
        """Provide a freshly reinitialized registry for each test."""
        return reinitialize_registry()

    def test_registry_initialization(self, registry: ToolRegistry) -> None:
        """Test that the registry is initialized and stores tools in a dict."""
        assert registry.is_initialized
        assert isinstance(registry._tools, dict)

    def test_get_all_tools(self, registry: ToolRegistry) -> None:
        """Test that get_all_tools returns a non-empty list of ToolMetadata."""
        all_tools = registry.get_all_tools()
        assert isinstance(all_tools, list)
        assert len(all_tools) > 0
        assert all(isinstance(tool, ToolMetadata) for tool in all_tools)

        # Log every discovered tool to aid debugging; %-args defer formatting
        # to the logger instead of building the string eagerly.
        logger.info("Found %d tools in registry:", len(all_tools))
        for tool in all_tools:
            logger.info("\n%s", tool)

    def test_tool_metadata_structure(self, registry: ToolRegistry) -> None:
        """Test that every tool's metadata fields have the expected types."""
        for tool in registry.get_all_tools():
            assert isinstance(tool.tool_id, str)
            assert isinstance(tool.name, str)
            assert isinstance(tool.category, str)
            assert tool.category in self.VALID_CATEGORIES, (
                f"Tool {tool.tool_id} has invalid category: {tool.category}"
            )
            assert isinstance(tool.description, str)
            assert isinstance(tool.tool_instance, BaseTool)
            assert isinstance(tool.module_path, str)

    def test_list_tool_ids(self, registry: ToolRegistry) -> None:
        """Test that tool IDs are strings of the form ``category.name``."""
        tool_ids = registry.list_tool_ids()
        assert isinstance(tool_ids, list)
        assert len(tool_ids) > 0
        assert all(isinstance(tool_id, str) for tool_id in tool_ids)

        for tool_id in tool_ids:
            # partition splits on the first "." only, like split(".", 1).
            category, separator, name = tool_id.partition(".")
            assert separator, f"Tool ID {tool_id} missing category separator"
            assert category in self.VALID_CATEGORIES, (
                f"Tool {tool_id} has invalid category: {category}"
            )
            assert name, f"Tool {tool_id} has empty name"

    def test_get_specific_tool(self, registry: ToolRegistry) -> None:
        """Test that every listed tool ID can be retrieved with get_tool."""
        tool_ids = registry.list_tool_ids()
        if not tool_ids:
            # Report the missing precondition instead of passing vacuously.
            pytest.skip("No tools registered; nothing to look up")

        for tool_id in tool_ids:
            tool_metadata = registry.get_tool(tool_id=tool_id)
            assert tool_metadata is not None, f"get_tool returned None for {tool_id}"
            assert isinstance(tool_metadata, ToolMetadata)
            assert tool_metadata.tool_id == tool_id

    def test_get_nonexistent_tool(self, registry: ToolRegistry) -> None:
        """Test that looking up an unknown tool ID returns None."""
        tool_metadata = registry.get_tool(tool_id="nonexistent.tool")
        assert tool_metadata is None

    def _assert_category_filter(self, registry: ToolRegistry, category: str) -> None:
        """Assert get_tools_by_category returns a list holding only *category* tools."""
        tools = registry.get_tools_by_category(category=category)
        assert isinstance(tools, list)
        assert all(tool.category == category for tool in tools)

    def test_get_tools_by_category_shared(self, registry: ToolRegistry) -> None:
        """Test getting all shared tools."""
        self._assert_category_filter(registry, "shared")

    def test_get_tools_by_category_customer(self, registry: ToolRegistry) -> None:
        """Test getting all customer tools."""
        self._assert_category_filter(registry, "customer")

    def test_get_tool_instances(self, registry: ToolRegistry) -> None:
        """Test that get_tool_instances returns one BaseTool per valid ID."""
        tool_ids = registry.list_tool_ids()
        if len(tool_ids) < 2:
            pytest.skip("Need at least two registered tools")

        test_ids = tool_ids[:2]
        instances = registry.get_tool_instances(tool_ids=test_ids)

        assert isinstance(instances, list)
        assert len(instances) == len(test_ids)
        assert all(isinstance(inst, BaseTool) for inst in instances)

    def test_get_tool_instances_with_invalid_id(self, registry: ToolRegistry) -> None:
        """Test that unknown IDs are dropped from get_tool_instances' result."""
        tool_ids = registry.list_tool_ids()
        if not tool_ids:
            pytest.skip("No tools registered; cannot mix valid and invalid IDs")

        # Mix one valid ID with one that cannot exist.
        test_ids = [tool_ids[0], "invalid.tool"]
        instances = registry.get_tool_instances(tool_ids=test_ids)

        # Only the valid ID should produce an instance.
        assert len(instances) == 1
        assert isinstance(instances[0], BaseTool)

    def test_global_registry_singleton(self) -> None:
        """Test that repeated get_registry calls return the same instance."""
        registry1 = get_registry()
        registry2 = get_registry()
        assert registry1 is registry2

    def test_reinitialize_registry(self) -> None:
        """Test that reinitialize_registry returns a new, initialized instance."""
        registry1 = get_registry()
        registry2 = reinitialize_registry()
        assert registry1 is not registry2
        assert registry2.is_initialized
|
|
|
|
|
|
class TestToolDiscovery:
    """Test suite for tool discovery through the global registry."""

    # Categories a discovered tool ID may be prefixed with.
    VALID_CATEGORIES = ("shared", "customer")

    def test_discovers_at_least_one_tool(self) -> None:
        """Test that at least one tool is discovered."""
        tool_ids = get_registry().list_tool_ids()
        # At least one tool should be successfully loaded.
        assert len(tool_ids) >= 1, "Expected at least one tool to be discovered"

    def test_query_althaus_database_if_available(self) -> None:
        """Test the query_althaus_database tool, skipping if it did not load."""
        tool = get_registry().get_tool(tool_id="customer.query_althaus_database")
        if tool is None:
            # The tool may have failed to import. Skipping (rather than logging,
            # which pytest only shows for failing tests) makes this visible in
            # the test summary.
            pytest.skip(
                "customer.query_althaus_database tool not found - "
                "may have failed to import"
            )

        assert tool.name == "query_althaus_database"
        assert tool.category == "customer"
        assert "database" in tool.description.lower()

    def test_tavily_search_if_available(self) -> None:
        """Test the tavily_search tool, skipping if it did not load."""
        tool = get_registry().get_tool(tool_id="shared.tavily_search")
        if tool is None:
            # Same reasoning as above: surface the missing tool as a skip.
            pytest.skip("shared.tavily_search tool not found - may have failed to import")

        assert tool.name == "tavily_search"
        assert tool.category == "shared"
        assert "search" in tool.description.lower()

    def test_tool_ids_have_correct_format(self) -> None:
        """Test that all discovered tool IDs follow the ``category.name`` format."""
        for tool_id in get_registry().list_tool_ids():
            assert "." in tool_id, f"Tool ID {tool_id} missing category separator"
            category, name = tool_id.split(".", 1)
            assert category in self.VALID_CATEGORIES, (
                f"Tool {tool_id} has invalid category: {category}"
            )
            assert len(name) > 0, f"Tool {tool_id} has empty name"
|