feat: add web search abstraction
This commit is contained in:
parent
bd23b41134
commit
dfd76c7d11
2 changed files with 101 additions and 0 deletions
31
modules/methods/web/web_search/web_search_base.py
Normal file
31
modules/methods/web/web_search/web_search_base.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
"""Base class for web search classes."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from modules.interfaces.interfaceChatModel import ActionDocument, ActionResult
|
||||
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List
|
||||
|
||||
|
||||
class WebSearchRequest(BaseModel):
    """Input for a single web search call.

    Both fields are declared without defaults, so both must be supplied.
    """

    # Text of the search query.
    query: str

    # Number of results requested; this model itself enforces no range.
    max_results: int
|
||||
|
||||
|
||||
class WebSearchDocumentData(BaseModel):
    """Title and URL describing one web search hit.

    Both fields are required strings.
    """

    # Title of the hit.
    title: str

    # Address of the hit.
    url: str
|
||||
|
||||
|
||||
class WebSearchActionDocument(ActionDocument):
    """``ActionDocument`` whose payload is a list of web search hits."""

    # Declared as a list, so callers must pass a list of
    # WebSearchDocumentData rather than a single item.
    documentData: List[WebSearchDocumentData]
|
||||
|
||||
|
||||
class WebSearchActionResult(ActionResult):
    """``ActionResult`` whose documents are ``WebSearchActionDocument`` items."""

    # default_factory builds a fresh empty list for each instance when no
    # documents are supplied (e.g. for a failed search).
    documents: List[WebSearchActionDocument] = Field(default_factory=list)
|
||||
|
||||
|
||||
class WebSearchBase(ABC):
    """Abstract interface for web search implementations.

    A concrete subclass must provide an asynchronous ``__call__`` that takes
    a ``WebSearchRequest`` and returns a ``WebSearchActionResult``; the class
    cannot be instantiated until that method is overridden.
    """

    @abstractmethod
    async def __call__(self, request: WebSearchRequest) -> WebSearchActionResult:
        """Run the search described by *request* and return its result."""
|
||||
70
modules/methods/web/web_search/web_search_tavily.py
Normal file
70
modules/methods/web/web_search/web_search_tavily.py
Normal file
|
|
@ -0,0 +1,70 @@
|
|||
"""Tavily web search class."""
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from web_search_base import (
|
||||
WebSearchBase,
|
||||
WebSearchRequest,
|
||||
WebSearchActionResult,
|
||||
WebSearchActionDocument,
|
||||
WebSearchDocumentData,
|
||||
)
|
||||
|
||||
# from modules.interfaces.interfaceChatModel import ActionResult, ActionDocument
|
||||
from tavily import AsyncTavilyClient
|
||||
from modules.shared.timezoneUtils import get_utc_timestamp
|
||||
|
||||
|
||||
@dataclass
class WebSearchTavily(WebSearchBase):
    """Web search implementation backed by the Tavily API.

    Instances are callable: ``await instance(request)`` runs the search and
    returns a ``WebSearchActionResult``. Failures while searching or while
    building the result are reported through the returned object
    (``success=False`` plus ``error``) instead of being raised.
    """

    # Tavily client used for every search. It defaults to None, so build
    # instances with ``await WebSearchTavily.create()`` or pass a client.
    client: AsyncTavilyClient = None

    @classmethod
    async def create(cls):
        """Return an instance whose client uses the ``TAVILY_API_KEY`` variable.

        The key is read with ``os.getenv``, so the client receives ``None``
        when the environment variable is not set.
        """
        return cls(client=AsyncTavilyClient(api_key=os.getenv("TAVILY_API_KEY")))

    async def __call__(self, request: WebSearchRequest) -> WebSearchActionResult:
        """Handle a web search request.

        Args:
            request: Query text and the maximum number of results.

        Returns:
            A successful result holding one document per search hit, or a
            result with ``success=False`` and the exception text in ``error``.
        """
        # Step 1: Search
        try:
            search_results = await self._search(request.query, request.max_results)
        except Exception as e:
            # Deliberate boundary: callers get a failed result, not an exception.
            return WebSearchActionResult(success=False, error=str(e))

        # Step 2: Build ActionResult
        try:
            return self._build_action_result(search_results)
        except Exception as e:
            return WebSearchActionResult(success=False, error=str(e))

    async def _search(self, query: str, max_results: int) -> list:
        """Call the Tavily API and return the raw list of search hits.

        Args:
            query: Text to search for.
            max_results: Number of hits to request; must be within 0..20.

        Returns:
            The value stored under ``"results"`` in the Tavily response.

        Raises:
            RuntimeError: If no client has been configured.
            ValueError: If ``max_results`` is outside the allowed range.
        """
        # Fail with a clear message instead of an AttributeError on None.
        if self.client is None:
            raise RuntimeError(
                "Tavily client is not initialised; use WebSearchTavily.create()"
            )

        # Make sure max_results is within the allowed range
        if not 0 <= max_results <= 20:
            raise ValueError("max_results must be between 0 and 20")

        # Perform actual API call
        response = await self.client.search(query=query, max_results=max_results)
        return response["results"]

    def _build_action_result(self, search_results: list) -> WebSearchActionResult:
        """Build the ActionResult from the raw search hits.

        Args:
            search_results: Hits as returned by ``_search``; each one is
                indexed by ``"title"`` and ``"url"``.

        Returns:
            A successful ``WebSearchActionResult`` with one document per hit.
        """
        documents = []
        for result in search_results:
            # NOTE(review): the resolution of get_utc_timestamp() is not
            # visible here; if it is coarse, documents built in the same call
            # may share a name - confirm whether names must be unique.
            document_name = f"web_search_{get_utc_timestamp()}.txt"
            document_data = WebSearchDocumentData(
                title=result["title"], url=result["url"]
            )
            documents.append(
                WebSearchActionDocument(
                    documentName=document_name,
                    # The field is declared List[WebSearchDocumentData], so
                    # the single hit has to be wrapped in a list.
                    documentData=[document_data],
                    mimeType="text/plain",
                )
            )

        return WebSearchActionResult(
            success=True, documents=documents, resultLabel="web_search_results"
        )
|
||||
Loading…
Reference in a new issue