From d4b846c5980d6448b0ef41ea4c6ecbf67a74f22f Mon Sep 17 00:00:00 2001 From: Christopher Gondek Date: Fri, 29 Aug 2025 15:35:14 +0200 Subject: [PATCH] chore: restructure web search w/ tests --- modules/__init__.py | 0 modules/connectors/connector_tavily.py | 70 +++++++++++++++++++++ modules/interfaces/interface_web_model.py | 49 +++++++++++++++ modules/interfaces/interface_web_objects.py | 24 +++++++ modules/methods/method_web.py | 43 +++++++++++++ pytest.ini | 11 ++++ requirements.txt | 4 ++ tests/__init__.py | 1 + tests/connectors/__init__.py | 0 tests/connectors/test_connector_tavily.py | 39 ++++++++++++ tests/methods/__init__.py | 0 tests/methods/test_method_web.py | 36 +++++++++++ 12 files changed, 277 insertions(+) create mode 100644 modules/__init__.py create mode 100644 modules/connectors/connector_tavily.py create mode 100644 modules/interfaces/interface_web_model.py create mode 100644 modules/interfaces/interface_web_objects.py create mode 100644 modules/methods/method_web.py create mode 100644 pytest.ini create mode 100644 tests/__init__.py create mode 100644 tests/connectors/__init__.py create mode 100644 tests/connectors/test_connector_tavily.py create mode 100644 tests/methods/__init__.py create mode 100644 tests/methods/test_method_web.py diff --git a/modules/__init__.py b/modules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modules/connectors/connector_tavily.py b/modules/connectors/connector_tavily.py new file mode 100644 index 00000000..4f57fb94 --- /dev/null +++ b/modules/connectors/connector_tavily.py @@ -0,0 +1,70 @@ +"""Tavily web search class.""" + +import os +from dataclasses import dataclass +from modules.interfaces.interface_web_model import ( + WebSearchBase, + WebSearchRequest, + WebSearchActionResult, + WebSearchActionDocument, + WebSearchDocumentData, +) + +# from modules.interfaces.interfaceChatModel import ActionResult, ActionDocument +from tavily import AsyncTavilyClient +from modules.shared.timezoneUtils 
from modules.shared.timezoneUtils import get_utc_timestamp


@dataclass
class ConnectorTavily(WebSearchBase):
    """Web search connector backed by the Tavily API.

    Instances are normally built with :meth:`create`, which supplies the
    client; the ``None`` default only lets the dataclass declare the field.
    """

    client: AsyncTavilyClient = None

    @classmethod
    async def create(cls) -> "ConnectorTavily":
        """Return a connector whose client is keyed by ``TAVILY_API_KEY``.

        The key is read from the environment at call time; if the variable is
        unset, ``None`` is handed to ``AsyncTavilyClient`` unchanged.
        """
        return cls(client=AsyncTavilyClient(api_key=os.getenv("TAVILY_API_KEY")))

    async def search_urls(self, request: WebSearchRequest) -> WebSearchActionResult:
        """Handle a web search request.

        Args:
            request: Query text and the maximum number of results wanted.

        Returns:
            A successful result with one document per hit, or a result with
            ``success=False`` and the exception text in ``error`` when the
            search or the result conversion raises.
        """
        # Both steps report failure the same way, so one handler covers them.
        try:
            search_results = await self._search(request.query, request.max_results)
            return self._build_action_result(search_results)
        except Exception as e:
            return WebSearchActionResult(success=False, error=str(e))

    async def _search(self, query: str, max_results: int) -> list:
        """Call the Tavily API and return the raw ``"results"`` entries.

        Raises:
            ValueError: If ``max_results`` is outside the range 0..20.
        """
        # Make sure max_results is within the allowed range
        if not 0 <= max_results <= 20:
            raise ValueError("max_results must be between 0 and 20")

        # Perform actual API call
        response = await self.client.search(query=query, max_results=max_results)
        return response["results"]

    def _build_action_result(self, search_results: list) -> WebSearchActionResult:
        """Build the action result from the raw search results.

        Each entry must provide ``"title"`` and ``"url"`` keys; a missing key
        raises ``KeyError``.
        """
        documents = [
            WebSearchActionDocument(
                # NOTE(review): the name ends in ".txt" while the MIME type is
                # JSON -- confirm which of the two is intended.
                documentName=f"web_search_{get_utc_timestamp()}.txt",
                documentData=WebSearchDocumentData(
                    title=result["title"], url=result["url"]
                ),
                mimeType="application/json",
            )
            for result in search_results
        ]

        return WebSearchActionResult(
            success=True, documents=documents, resultLabel="web_search_results"
        )
# ---- modules/interfaces/interface_web_model.py ----
"""Base class for web classes."""

from abc import ABC, abstractmethod
from modules.interfaces.interfaceChatModel import ActionDocument, ActionResult


from pydantic import BaseModel, Field
from typing import List


# --- Web search ---

# query -> list of URLs


class WebSearchRequest(BaseModel):
    """Input of a web search: the query text and a result limit."""

    query: str
    max_results: int


class WebSearchDocumentData(BaseModel):
    """Payload of a single search hit: its title and URL."""

    title: str
    url: str


class WebSearchActionDocument(ActionDocument):
    """Action document whose ``documentData`` is a search hit."""

    documentData: WebSearchDocumentData


class WebSearchActionResult(ActionResult):
    """Action result whose documents are search hits (empty by default)."""

    documents: List[WebSearchActionDocument] = Field(default_factory=list)


class WebSearchBase(ABC):
    """Interface every web search connector has to implement."""

    @abstractmethod
    async def search_urls(self, request: WebSearchRequest) -> WebSearchActionResult:
        """Run the search described by ``request`` and return its result."""
        ...


# --- Web crawl ---

# list of URLs -> list of extracted HTML content

# TODO

# --- Web query ---

# query -> list of extracted text

# TODO


# ---- modules/interfaces/interface_web_objects.py ----
from modules.interfaces.interface_web_model import (
    WebSearchActionResult,
    WebSearchRequest,
)

from dataclasses import dataclass
from modules.connectors.connector_tavily import ConnectorTavily


@dataclass
class WebInterface:
    """Entry point for web operations; currently delegates to Tavily only.

    Build instances with :meth:`create` so that the connector is populated;
    the ``None`` default only lets the dataclass declare the field.
    """

    connector_tavily: ConnectorTavily = None

    @classmethod
    async def create(cls) -> "WebInterface":
        """Create the Tavily connector and return an interface wrapping it."""
        connector_tavily = await ConnectorTavily.create()

        # Use cls rather than the hard-coded class name so that a subclass
        # calling create() gets an instance of its own type.
        return cls(connector_tavily=connector_tavily)

    async def search(
        self, web_search_request: WebSearchRequest
    ) -> WebSearchActionResult:
        """Forward the request to the Tavily connector and return its result."""
        # NOTE: Add connectors here
        return await self.connector_tavily.search_urls(web_search_request)
import logging
from typing import Any, Dict
from modules.chat.methodBase import MethodBase, action
from modules.interfaces.interfaceChatModel import ActionResult
from modules.interfaces.interface_web_objects import WebInterface
from modules.interfaces.interface_web_model import WebSearchRequest


logger = logging.getLogger(__name__)


class MethodWeb(MethodBase):
    """Web method implementation for web operations."""

    def __init__(self, serviceCenter: Any):
        super().__init__(serviceCenter)

    @action
    async def search(self, parameters: Dict[str, Any]) -> ActionResult:
        """
        Perform a web search and return one document per hit (title and URL).

        Parameters:
            query (str): Search query to perform
            maxResults (int, optional): Maximum number of results (default: 10)
        """
        # TODO: Fix docstrings - do we need that format for parsing?

        # Reject a missing or blank query up front so the caller gets a clear
        # message instead of a model validation dump.
        query = parameters.get("query")
        if not isinstance(query, str) or not query.strip():
            return ActionResult(
                success=False, error="Parameter 'query' is required"
            )

        try:
            # Prepare request data
            web_search_request = WebSearchRequest(
                query=query,
                max_results=parameters.get("maxResults", 10),
            )

            # Perform request
            web_interface = await WebInterface.create()
            return await web_interface.search(web_search_request)

        except Exception as e:
            # Failures are reported through the result; log the traceback so
            # the cause is not lost.
            logger.exception("Web search failed")
            return ActionResult(success=False, error=str(e))
# ---- tests/connectors/test_connector_tavily.py ----
"""Tests for Tavily web search."""

import logging
import os

import pytest

from modules.interfaces.interfaceChatModel import ActionResult
from modules.interfaces.interface_web_model import WebSearchRequest
from modules.connectors.connector_tavily import ConnectorTavily

logger = logging.getLogger(__name__)


@pytest.mark.asyncio
async def test_tavily_connector_search_test_live_api():
    """Run one live Tavily search through the connector and log the outcome."""
    # Live test: without credentials there is nothing meaningful to check.
    if not os.getenv("TAVILY_API_KEY"):
        pytest.skip("TAVILY_API_KEY is not set; skipping live Tavily test")

    logger.info("Testing Tavily connector with live API calls")

    # Test request
    request = WebSearchRequest(query="How old is the Earth?", max_results=5)

    # Tavily instance
    connector_tavily = await ConnectorTavily.create()

    # Search test
    action_result = await connector_tavily.search_urls(request=request)

    # Check results. The connector reports errors through the result instead
    # of raising, so the type check alone would pass on a failed search.
    assert isinstance(action_result, ActionResult)
    assert action_result.success, action_result.error

    logger.info("=" * 20)
    logger.info(f"Action result success status: {action_result.success}")
    logger.info(f"Action result error: {action_result.error}")
    logger.info(f"Action result label: {action_result.resultLabel}")

    logger.info("Documents:")
    for doc in action_result.documents:
        logger.info("-" * 10)
        logger.info(f" - Document Name: {doc.documentName}")
        logger.info(f" - Document Mime Type: {doc.mimeType}")
        logger.info(f" - Document Data: {doc.documentData}")


# ---- tests/methods/test_method_web.py ----
# Tests for method web.py

import logging
import os

import pytest
from modules.methods.method_web import MethodWeb

logger = logging.getLogger(__name__)


@pytest.mark.asyncio
async def test_method_web_search_live():
    """Tests method web search with live API calls."""
    # Live test: without credentials there is nothing meaningful to check.
    if not os.getenv("TAVILY_API_KEY"):
        pytest.skip("TAVILY_API_KEY is not set; skipping live web search test")

    method_web = MethodWeb(serviceCenter=None)

    # Actual request
    action_result = await method_web.search(
        {"query": "How old is the earth", "maxResults": 5}
    )

    # Evaluate results
    assert action_result.success, action_result.error
    assert len(action_result.documents) > 0

    logger.info("=" * 20)
    logger.info(f"Action result success status: {action_result.success}")
    logger.info(f"Action result error: {action_result.error}")
    logger.info(f"Action result label: {action_result.resultLabel}")

    logger.info("Documents:")
    for doc in action_result.documents:
        logger.info("-" * 10)
        logger.info(f" - Document Name: {doc.documentName}")
        logger.info(f" - Document Mime Type: {doc.mimeType}")
        logger.info(f" - Document Data: {doc.documentData}")