chore: restructure web search w/ tests
This commit is contained in:
parent
b37cd502cd
commit
d4b846c598
12 changed files with 277 additions and 0 deletions
0
modules/__init__.py
Normal file
0
modules/__init__.py
Normal file
70
modules/connectors/connector_tavily.py
Normal file
70
modules/connectors/connector_tavily.py
Normal file
|
|
@ -0,0 +1,70 @@
|
||||||
|
"""Tavily web search class."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from modules.interfaces.interface_web_model import (
|
||||||
|
WebSearchBase,
|
||||||
|
WebSearchRequest,
|
||||||
|
WebSearchActionResult,
|
||||||
|
WebSearchActionDocument,
|
||||||
|
WebSearchDocumentData,
|
||||||
|
)
|
||||||
|
|
||||||
|
# from modules.interfaces.interfaceChatModel import ActionResult, ActionDocument
|
||||||
|
from tavily import AsyncTavilyClient
|
||||||
|
from modules.shared.timezoneUtils import get_utc_timestamp
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ConnectorTavily(WebSearchBase):
    """Web search connector that queries the Tavily API.

    Use :meth:`create` to obtain an instance whose ``client`` is an
    ``AsyncTavilyClient`` configured from the ``TAVILY_API_KEY`` environment
    variable.
    """

    # Tavily client used for all API calls; ``None`` only when the dataclass
    # is instantiated directly without passing a client.
    client: AsyncTavilyClient = None

    @classmethod
    async def create(cls) -> "ConnectorTavily":
        """Build a connector with a client keyed by ``TAVILY_API_KEY``."""
        return cls(client=AsyncTavilyClient(api_key=os.getenv("TAVILY_API_KEY")))

    async def search_urls(self, request: WebSearchRequest) -> WebSearchActionResult:
        """Run a web search and wrap the hits in a ``WebSearchActionResult``.

        Args:
            request: Query text and maximum number of results.

        Returns:
            A successful result holding one document per hit, or a result
            with ``success=False`` and ``error`` set to the exception text if
            the search or the result construction raised.
        """
        # Connector boundary: callers receive failures as a result object,
        # never as an exception, so both steps share one handler.
        try:
            search_results = await self._search(request.query, request.max_results)
            return self._build_action_result(search_results)
        except Exception as e:
            return WebSearchActionResult(success=False, error=str(e))

    async def _search(self, query: str, max_results: int) -> list:
        """Call the Tavily API and return the raw ``"results"`` entries.

        Args:
            query: Search query text.
            max_results: Number of results to request; must be in 0..20.

        Returns:
            The value stored under ``"results"`` in the API response, which
            ``_build_action_result`` iterates as a sequence of hits.

        Raises:
            ValueError: If ``max_results`` is outside 0..20.
        """
        # Make sure max_results is within the allowed range
        if not 0 <= max_results <= 20:
            raise ValueError("max_results must be between 0 and 20")

        # Perform actual API call
        response = await self.client.search(query=query, max_results=max_results)
        return response["results"]

    def _build_action_result(self, search_results: list) -> WebSearchActionResult:
        """Convert raw search hits into a successful ``WebSearchActionResult``.

        Args:
            search_results: Hits as returned by ``_search``; each one is
                indexed with ``"title"`` and ``"url"``.

        Returns:
            A result labelled ``"web_search_results"`` with one
            ``WebSearchActionDocument`` per hit.
        """
        documents = []
        for result in search_results:
            # NOTE(review): the name depends only on get_utc_timestamp(); if
            # that helper has coarse resolution, documents created in the same
            # batch may share a name - confirm whether names must be unique.
            document_name = f"web_search_{get_utc_timestamp()}.txt"
            document_data = WebSearchDocumentData(
                title=result["title"], url=result["url"]
            )
            # NOTE(review): the ".txt" name is paired with a JSON mime type -
            # confirm which one downstream consumers rely on.
            mime_type = "application/json"
            doc = WebSearchActionDocument(
                documentName=document_name,
                documentData=document_data,
                mimeType=mime_type,
            )
            documents.append(doc)

        return WebSearchActionResult(
            success=True, documents=documents, resultLabel="web_search_results"
        )
|
||||||
49
modules/interfaces/interface_web_model.py
Normal file
49
modules/interfaces/interface_web_model.py
Normal file
|
|
@ -0,0 +1,49 @@
|
||||||
|
"""Base class for web classes."""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from modules.interfaces.interfaceChatModel import ActionDocument, ActionResult
|
||||||
|
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
|
# --- Web search ---
|
||||||
|
|
||||||
|
# query -> list of URLs
|
||||||
|
|
||||||
|
|
||||||
|
class WebSearchRequest(BaseModel):
    """Input of a web search: the query text and the result-count limit."""

    # Search query text.
    query: str
    # Maximum number of results to request; no range check is declared here.
    max_results: int
|
||||||
|
|
||||||
|
|
||||||
|
class WebSearchDocumentData(BaseModel):
    """Payload of one web search hit."""

    # Title of the hit.
    title: str
    # URL of the hit.
    url: str
|
||||||
|
|
||||||
|
|
||||||
|
class WebSearchActionDocument(ActionDocument):
    """``ActionDocument`` whose ``documentData`` is a ``WebSearchDocumentData``."""

    # Title/URL payload of the search hit this document represents.
    documentData: WebSearchDocumentData
|
||||||
|
|
||||||
|
|
||||||
|
class WebSearchActionResult(ActionResult):
    """``ActionResult`` whose documents are ``WebSearchActionDocument`` items."""

    # One document per search hit; defaults to a fresh empty list.
    documents: List[WebSearchActionDocument] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class WebSearchBase(ABC):
    """Abstract interface that every web search connector implements."""

    @abstractmethod
    async def search_urls(self, request: WebSearchRequest) -> WebSearchActionResult:
        """Run the search described by ``request`` and return its result."""
        ...
|
||||||
|
|
||||||
|
|
||||||
|
# --- Web crawl ---
|
||||||
|
|
||||||
|
# list of URLs -> list of extracted HTML content
|
||||||
|
|
||||||
|
# TODO
|
||||||
|
|
||||||
|
# --- Web query ---
|
||||||
|
|
||||||
|
# query -> list of extracted text
|
||||||
|
|
||||||
|
# TODO
|
||||||
24
modules/interfaces/interface_web_objects.py
Normal file
24
modules/interfaces/interface_web_objects.py
Normal file
|
|
@ -0,0 +1,24 @@
|
||||||
|
from modules.interfaces.interface_web_model import (
|
||||||
|
WebSearchActionResult,
|
||||||
|
WebSearchRequest,
|
||||||
|
)
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from modules.connectors.connector_tavily import ConnectorTavily
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class WebInterface:
    """Entry point for web operations that delegates to the connectors."""

    # Tavily connector used by ``search``; populated by ``create``.
    connector_tavily: ConnectorTavily = None

    @classmethod
    async def create(cls) -> "WebInterface":
        """Build an interface with a freshly created Tavily connector.

        Returns:
            An instance of ``cls`` (so subclasses get their own type) whose
            ``connector_tavily`` comes from ``ConnectorTavily.create()``.
        """
        connector_tavily = await ConnectorTavily.create()

        return cls(connector_tavily=connector_tavily)

    async def search(
        self, web_search_request: WebSearchRequest
    ) -> WebSearchActionResult:
        """Forward the search request to the Tavily connector.

        Args:
            web_search_request: Query text and maximum number of results.

        Returns:
            Whatever ``ConnectorTavily.search_urls`` returns for the request.
        """
        # NOTE: Add connectors here
        return await self.connector_tavily.search_urls(web_search_request)
|
||||||
43
modules/methods/method_web.py
Normal file
43
modules/methods/method_web.py
Normal file
|
|
@ -0,0 +1,43 @@
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict
|
||||||
|
from modules.chat.methodBase import MethodBase, action
|
||||||
|
from modules.interfaces.interfaceChatModel import ActionResult
|
||||||
|
from modules.interfaces.interface_web_objects import WebInterface
|
||||||
|
from modules.interfaces.interface_web_model import WebSearchRequest
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MethodWeb(MethodBase):
    """Web method implementation for web operations."""

    def __init__(self, serviceCenter: Any):
        super().__init__(serviceCenter)

    @action
    async def search(self, parameters: Dict[str, Any]) -> ActionResult:
        """
        Perform a web search and output a .txt file with a plain list of URLs (one per line).

        Parameters:
            query (str): Search query to perform
            maxResults (int, optional): Maximum number of results (default: 10)
        """
        # TODO: Fix docstrings - do we need that format for parsing?
        # The docstring text above is kept verbatim until that is settled.

        try:
            # Prepare request data. Built inside the try block so that an
            # invalid or missing parameter is reported as a failed
            # ActionResult instead of propagating to the caller.
            web_search_request = WebSearchRequest(
                query=parameters.get("query"),
                max_results=parameters.get("maxResults", 10),
            )

            # Perform request
            web_interface = await WebInterface.create()
            web_search_result = await web_interface.search(web_search_request)

            return web_search_result

        except Exception as e:
            # Action boundary: record the traceback before flattening the
            # exception into an error string, otherwise the cause is lost.
            logger.exception("Web search failed")
            return ActionResult(success=False, error=str(e))
|
||||||
11
pytest.ini
Normal file
11
pytest.ini
Normal file
|
|
@ -0,0 +1,11 @@
|
||||||
|
[pytest]
testpaths = tests
# "pythonpath" is the built-in option (pytest >= 7.0). The former
# "python_paths" key belonged to the pytest-pythonpath plugin and is reported
# as an unknown config option by plain pytest.
pythonpath = .
addopts = -v --tb=short
python_files = test_*.py
python_classes = Test*
python_functions = test_*
log_file = logs/test_logs.log
log_file_level = INFO
log_file_format = %(asctime)s %(levelname)s %(message)s
log_file_date_format = %Y-%m-%d %H:%M:%S
|
||||||
|
|
@ -68,3 +68,7 @@ PyPDF2>=3.0.0
|
||||||
PyMuPDF>=1.20.0
|
PyMuPDF>=1.20.0
|
||||||
beautifulsoup4>=4.11.0
|
beautifulsoup4>=4.11.0
|
||||||
chardet>=4.0.0 # For encoding detection
|
chardet>=4.0.0 # For encoding detection
|
||||||
|
|
||||||
|
## Testing Dependencies
|
||||||
|
pytest>=8.0.0
|
||||||
|
pytest-asyncio>=0.21.0
|
||||||
|
|
|
||||||
1
tests/__init__.py
Normal file
1
tests/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
# noqa
|
||||||
0
tests/connectors/__init__.py
Normal file
0
tests/connectors/__init__.py
Normal file
39
tests/connectors/test_connector_tavily.py
Normal file
39
tests/connectors/test_connector_tavily.py
Normal file
|
|
@ -0,0 +1,39 @@
|
||||||
|
"""Tests for Tavily web search."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from modules.interfaces.interfaceChatModel import ActionResult
|
||||||
|
from modules.interfaces.interface_web_model import WebSearchRequest
|
||||||
|
from modules.connectors.connector_tavily import ConnectorTavily
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_tavily_connector_search_test_live_api():
    """Search through the Tavily connector against the live API.

    Logs the complete result first and only then asserts success, so the
    log file still shows the error when the search fails.
    """
    logger.info("Testing Tavliy connector with live API calls")

    # Test request
    request = WebSearchRequest(query="How old is the Earth?", max_results=5)

    # Tavily instance
    connector_tavily = await ConnectorTavily.create()

    # Search test
    action_result = await connector_tavily.search_urls(request=request)

    # Check results
    assert isinstance(action_result, ActionResult)

    logger.info("=" * 20)
    logger.info(f"Action result success status: {action_result.success}")
    logger.info(f"Action result error: {action_result.error}")
    logger.info(f"Action result label: {action_result.resultLabel}")

    logger.info("Documents:")
    for doc in action_result.documents:
        logger.info("-" * 10)
        logger.info(f" - Document Name: {doc.documentName}")
        logger.info(f" - Document Mime Type: {doc.mimeType}")
        logger.info(f" - Document Data: {doc.documentData}")

    # A type check alone lets a failed search pass; require a real result.
    assert action_result.success, action_result.error
    assert len(action_result.documents) > 0
|
||||||
0
tests/methods/__init__.py
Normal file
0
tests/methods/__init__.py
Normal file
36
tests/methods/test_method_web.py
Normal file
36
tests/methods/test_method_web.py
Normal file
|
|
@ -0,0 +1,36 @@
|
||||||
|
"""Tests for method web.py"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from modules.methods.method_web import MethodWeb
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_method_web_search_live():
    """Tests method web search with live API calls."""

    method_web = MethodWeb(serviceCenter=None)

    # Actual request
    action_result = await method_web.search(
        {"query": "How old is the earth", "maxResults": 5}
    )

    # Evaluate results. The assertions run before anything is logged, so the
    # reported error is attached as the assertion message to keep the failure
    # reason visible in the pytest output.
    assert action_result.success, action_result.error
    assert len(action_result.documents) > 0, "search returned no documents"

    logger.info("=" * 20)
    logger.info(f"Action result success status: {action_result.success}")
    logger.info(f"Action result error: {action_result.error}")
    logger.info(f"Action result label: {action_result.resultLabel}")

    logger.info("Documents:")
    for doc in action_result.documents:
        logger.info("-" * 10)
        logger.info(f" - Document Name: {doc.documentName}")
        logger.info(f" - Document Mime Type: {doc.mimeType}")
        logger.info(f" - Document Data: {doc.documentData}")
|
||||||
Loading…
Reference in a new issue