gateway/modules/connectors/connector_tavily.py

"""Tavily web search class."""
import logging
import os
from dataclasses import dataclass
from modules.interfaces.interface_web_model import (
WebCrawlBase,
WebCrawlDocumentData,
WebCrawlRequest,
WebSearchBase,
WebSearchRequest,
WebSearchActionResult,
WebSearchActionDocument,
WebSearchDocumentData,
WebCrawlActionDocument,
WebCrawlActionResult,
)
# from modules.interfaces.interfaceChatModel import ActionResult, ActionDocument
from tavily import AsyncTavilyClient
from modules.shared.timezoneUtils import get_utc_timestamp
logger = logging.getLogger(__name__)


@dataclass
class ConnectorTavily(WebSearchBase, WebCrawlBase):
    client: AsyncTavilyClient | None = None

    @classmethod
    async def create(cls):
        return cls(client=AsyncTavilyClient(api_key=os.getenv("TAVILY_API_KEY")))
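
    # Note: the API key comes from the TAVILY_API_KEY environment variable.
    # os.getenv returns None when it is unset; depending on the tavily-python
    # version, a missing key then errors at construction or at first request.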

    async def search_urls(self, request: WebSearchRequest) -> WebSearchActionResult:
        """Handle a web search request.

        Takes a query and returns an action result whose documents carry the
        matching URLs.
        """
        # Step 1: Search
        try:
            search_results = await self._search(request.query, request.max_results)
        except Exception as e:
            return WebSearchActionResult(success=False, error=str(e))
        # Step 2: Build ActionResult
        try:
            result = self._build_search_action_result(search_results)
        except Exception as e:
            return WebSearchActionResult(success=False, error=str(e))
        return result

    async def crawl_urls(self, request: WebCrawlRequest) -> WebCrawlActionResult:
        """Crawls the given URLs and returns the extracted text content."""
        # Step 1: Crawl
        try:
            crawl_results = await self._crawl(request.urls)
        except Exception as e:
            return WebCrawlActionResult(success=False, error=str(e))
        # Step 2: Build ActionResult
        try:
            result = self._build_crawl_action_result(crawl_results)
        except Exception as e:
            return WebCrawlActionResult(success=False, error=str(e))
        return result
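
    # Shape note: the Tavily search response is expected to be a dict with a
    # "results" list whose entries expose at least "title" and "url" (the only
    # fields consumed when documents are built below).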

    async def _search(self, query: str, max_results: int) -> list[dict]:
        """Calls the Tavily API to perform a web search."""
        # Make sure max_results is within the range accepted by Tavily
        if max_results < 0 or max_results > 20:
            raise ValueError("max_results must be between 0 and 20")
        # Perform the actual API call
        response = await self.client.search(query=query, max_results=max_results)
        logger.info(f"Tavily API response:\n{response}")
        return response["results"]

    def _build_search_action_result(
        self, search_results: list
    ) -> WebSearchActionResult:
        """Builds the ActionResult from the search results."""
        documents = []
        for result in search_results:
            document_name = f"web_search_{get_utc_timestamp()}.txt"
            document_data = WebSearchDocumentData(
                title=result["title"], url=result["url"]
            )
            mime_type = "application/json"
            doc = WebSearchActionDocument(
                documentName=document_name,
                documentData=document_data,
                mimeType=mime_type,
            )
            documents.append(doc)
        return WebSearchActionResult(
            success=True, documents=documents, resultLabel="web_search_results"
        )
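
    # Parameter note: extract_depth="advanced" requests Tavily's deeper
    # (slower) extraction pass and format="text" plain-text output. Each entry
    # in the returned "results" list is expected to expose "url" and
    # "raw_content" (the fields consumed when documents are built below).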

    async def _crawl(self, urls: list) -> list[dict]:
        """Calls the Tavily API to extract text content from URLs."""
        response = await self.client.extract(
            urls=urls, extract_depth="advanced", format="text"
        )
        return response["results"]

    def _build_crawl_action_result(self, crawl_results: list) -> WebCrawlActionResult:
        """Builds the ActionResult from the crawl results."""
        documents = []
        for result in crawl_results:
            document_name = f"web_crawl_{get_utc_timestamp()}.txt"
            doc_data = WebCrawlDocumentData(
                url=result["url"], content=result["raw_content"]
            )
            mime_type = "application/json"
            doc = WebCrawlActionDocument(
                documentName=document_name,
                documentData=doc_data,
                mimeType=mime_type,
            )
            documents.append(doc)
        return WebCrawlActionResult(
            success=True, documents=documents, resultLabel="web_crawl_results"
        )
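

# A minimal end-to-end sketch, not part of the connector itself. It assumes
# TAVILY_API_KEY is exported and that the request/result models expose the
# fields used below (inferred from their use in the connector above).
if __name__ == "__main__":
    import asyncio

    async def _demo():
        connector = await ConnectorTavily.create()
        search = await connector.search_urls(
            WebSearchRequest(query="python asyncio tutorial", max_results=3)
        )
        if not search.success:
            print(f"search failed: {search.error}")
            return
        urls = [doc.documentData.url for doc in search.documents]
        crawl = await connector.crawl_urls(WebCrawlRequest(urls=urls))
        print(f"crawled {len(crawl.documents or [])} document(s)")

    asyncio.run(_demo())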