diff --git a/modules/methods/web/web_search/web_search_base.py b/modules/methods/web/web_search/web_search_base.py
new file mode 100644
index 00000000..d655bfd3
--- /dev/null
+++ b/modules/methods/web/web_search/web_search_base.py
@@ -0,0 +1,42 @@
+"""Base classes and data models shared by web search implementations."""
+
+from abc import ABC, abstractmethod
+from typing import List
+
+from pydantic import BaseModel, Field
+
+from modules.interfaces.interfaceChatModel import ActionDocument, ActionResult
+
+
+class WebSearchRequest(BaseModel):
+    """Parameters of a single web search."""
+
+    query: str
+    max_results: int
+
+
+class WebSearchDocumentData(BaseModel):
+    """Title and URL of one search hit."""
+
+    title: str
+    url: str
+
+
+class WebSearchActionDocument(ActionDocument):
+    """Action document whose payload is a list of search hits."""
+
+    documentData: List[WebSearchDocumentData]
+
+
+class WebSearchActionResult(ActionResult):
+    """Action result carrying web search documents (empty by default)."""
+
+    documents: List[WebSearchActionDocument] = Field(default_factory=list)
+
+
+class WebSearchBase(ABC):
+    """Interface for web search backends: an async callable on a request."""
+
+    @abstractmethod
+    async def __call__(self, request: WebSearchRequest) -> WebSearchActionResult:
+        """Run the search described by ``request`` and return its result."""
diff --git a/modules/methods/web/web_search/web_search_tavily.py b/modules/methods/web/web_search/web_search_tavily.py
new file mode 100644
index 00000000..dcbea35c
--- /dev/null
+++ b/modules/methods/web/web_search/web_search_tavily.py
@@ -0,0 +1,82 @@
+"""Tavily web search class."""
+
+import os
+from dataclasses import dataclass
+from typing import List, Optional
+
+from tavily import AsyncTavilyClient
+
+from modules.methods.web.web_search.web_search_base import (
+    WebSearchActionDocument,
+    WebSearchActionResult,
+    WebSearchBase,
+    WebSearchDocumentData,
+    WebSearchRequest,
+)
+from modules.shared.timezoneUtils import get_utc_timestamp
+
+
+@dataclass
+class WebSearchTavily(WebSearchBase):
+    """Web search backend that queries the Tavily API."""
+
+    # Tavily client; build instances via ``create()`` so that it is set.
+    client: Optional[AsyncTavilyClient] = None
+
+    @classmethod
+    async def create(cls):
+        """Return an instance whose client uses the TAVILY_API_KEY env var."""
+        return cls(client=AsyncTavilyClient(api_key=os.getenv("TAVILY_API_KEY")))
+
+    async def __call__(self, request: WebSearchRequest) -> WebSearchActionResult:
+        """Run the search and wrap the hits in a WebSearchActionResult.
+
+        Any exception raised while searching or building the result is
+        reported as ``success=False`` with the exception text in ``error``
+        instead of being propagated.
+        """
+        try:
+            search_results = await self._search(request.query, request.max_results)
+            return self._build_action_result(search_results)
+        except Exception as e:
+            return WebSearchActionResult(success=False, error=str(e))
+
+    async def _search(self, query: str, max_results: int) -> List[dict]:
+        """Call the Tavily API and return the ``results`` entry of its response.
+
+        Raises:
+            ValueError: If ``max_results`` is outside the range 0..20.
+            RuntimeError: If no client has been configured.
+        """
+        # Make sure max_results is within the allowed range
+        if max_results < 0 or max_results > 20:
+            raise ValueError("max_results must be between 0 and 20")
+        if self.client is None:
+            raise RuntimeError("Tavily client is not initialized; use create()")
+
+        response = await self.client.search(query=query, max_results=max_results)
+        return response["results"]
+
+    def _build_action_result(self, search_results: List[dict]) -> WebSearchActionResult:
+        """Build the action result, one document per search hit.
+
+        Each entry of ``search_results`` must provide ``"title"`` and ``"url"``.
+        """
+        documents = []
+        for result in search_results:
+            document_data = WebSearchDocumentData(
+                title=result["title"], url=result["url"]
+            )
+            doc = WebSearchActionDocument(
+                # NOTE(review): names repeat if get_utc_timestamp() returns the
+                # same value for several hits - confirm that is acceptable.
+                documentName=f"web_search_{get_utc_timestamp()}.txt",
+                # documentData is declared as a list, so wrap the single hit.
+                documentData=[document_data],
+                mimeType="text/plain",
+            )
+            documents.append(doc)
+
+        return WebSearchActionResult(
+            success=True, documents=documents, resultLabel="web_search_results"
+        )