"""Tavily web search class."""
import asyncio
import logging
from dataclasses import dataclass

from tavily import AsyncTavilyClient

from modules.interfaces.interfaceWebModel import (
    WebCrawlActionDocument,
    WebCrawlActionResult,
    WebCrawlBase,
    WebCrawlDocumentData,
    WebCrawlRequest,
    WebCrawlResultItem,
    WebScrapeActionDocument,
    WebScrapeActionResult,
    WebScrapeBase,
    WebScrapeDocumentData,
    WebScrapeRequest,
    WebScrapeResultItem,
    WebSearchActionDocument,
    WebSearchActionResult,
    WebSearchBase,
    WebSearchDocumentData,
    WebSearchRequest,
    WebSearchResultItem,
    get_web_search_max_results,
    get_web_search_min_results,
)
from modules.shared.configuration import APP_CONFIG
from modules.shared.timezoneUtils import get_utc_timestamp

logger = logging.getLogger(__name__)


@dataclass
class TavilySearchResult:
    """A single search hit returned by the Tavily search API."""

    title: str
    url: str


@dataclass
class TavilyCrawlResult:
    """Extracted page content for a single crawled URL."""

    url: str
    content: str


@dataclass
class ConnectorTavily(WebSearchBase, WebCrawlBase, WebScrapeBase):
    """Tavily-backed connector implementing web search, crawl, and scrape."""

    client: AsyncTavilyClient | None = None
    # Cached settings loaded at initialization time.
    crawl_timeout: int = 30
    crawl_max_retries: int = 3
    crawl_retry_delay: int = 2

    @classmethod
    async def create(cls):
        """Builds a connector from APP_CONFIG, caching crawl settings."""
        api_key = APP_CONFIG.get("Connector_WebTavily_API_KEY_SECRET")
        if not api_key:
            raise ValueError(
                "Tavily API key not configured. "
                "Please set Connector_WebTavily_API_KEY_SECRET in config.ini"
            )

        # Load and cache web-crawl-related configuration.
        crawl_timeout = int(APP_CONFIG.get("Web_Crawl_TIMEOUT", "30"))
        crawl_max_retries = int(APP_CONFIG.get("Web_Crawl_MAX_RETRIES", "3"))
        crawl_retry_delay = int(APP_CONFIG.get("Web_Crawl_RETRY_DELAY", "2"))

        return cls(
            client=AsyncTavilyClient(api_key=api_key),
            crawl_timeout=crawl_timeout,
            crawl_max_retries=crawl_max_retries,
            crawl_retry_delay=crawl_retry_delay,
        )
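
    # Usage sketch (hedged: assumes an async caller, a populated APP_CONFIG,
    # and that WebSearchRequest accepts these fields as keyword arguments):
    #
    #     connector = await ConnectorTavily.create()
    #     result = await connector.search_urls(
    #         WebSearchRequest(query="tavily python sdk", max_results=5)
    #     )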

    async def search_urls(self, request: WebSearchRequest) -> WebSearchActionResult:
        """Handles a web search request: takes a query, returns a list of URLs."""
        # Step 1: Search.
        try:
            search_results = await self._search(
                query=request.query,
                max_results=request.max_results,
                search_depth=request.search_depth,
                time_range=request.time_range,
                topic=request.topic,
                include_domains=request.include_domains,
                exclude_domains=request.exclude_domains,
                language=request.language,
                include_answer=request.include_answer,
                include_raw_content=request.include_raw_content,
            )
        except Exception as e:
            return WebSearchActionResult(success=False, error=str(e))

        # Step 2: Build the ActionResult.
        try:
            result = self._build_search_action_result(search_results, request.query)
        except Exception as e:
            return WebSearchActionResult(success=False, error=str(e))

        return result

    async def crawl_urls(self, request: WebCrawlRequest) -> WebCrawlActionResult:
        """Crawls the given URLs and returns the extracted text content."""
        # Step 1: Crawl.
        try:
            crawl_results = await self._crawl(request.urls)
        except Exception as e:
            return WebCrawlActionResult(success=False, error=str(e))

        # Step 2: Build the ActionResult.
        try:
            result = self._build_crawl_action_result(crawl_results, request.urls)
        except Exception as e:
            return WebCrawlActionResult(success=False, error=str(e))

        return result

    async def scrape(self, request: WebScrapeRequest) -> WebScrapeActionResult:
        """Turns a query into a list of URLs with extracted content."""
        # Step 1: Search.
        try:
            search_results = await self._search(
                query=request.query,
                max_results=request.max_results,
                search_depth=request.search_depth,
                time_range=request.time_range,
                topic=request.topic,
                include_domains=request.include_domains,
                exclude_domains=request.exclude_domains,
                language=request.language,
                include_answer=request.include_answer,
                include_raw_content=request.include_raw_content,
            )
        except Exception as e:
            return WebScrapeActionResult(success=False, error=str(e))

        # Step 2: Crawl the URLs found by the search.
        try:
            urls = [result.url for result in search_results]
            crawl_results = await self._crawl(
                urls,
                extract_depth=request.extract_depth,
                format=request.format,
            )
        except Exception as e:
            return WebScrapeActionResult(success=False, error=str(e))

        # Step 3: Build the ActionResult.
        try:
            result = self._build_scrape_action_result(crawl_results, request.query)
        except Exception as e:
            return WebScrapeActionResult(success=False, error=str(e))

        return result
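
    # Usage sketch (hedged: field names are taken from the request handling
    # above; the WebScrapeRequest constructor signature is an assumption):
    #
    #     result = await connector.scrape(
    #         WebScrapeRequest(query="site reliability", max_results=3,
    #                          extract_depth="advanced", format="text")
    #     )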

    async def _search(
        self,
        query: str,
        max_results: int,
        search_depth: str | None = None,
        time_range: str | None = None,
        topic: str | None = None,
        include_domains: list[str] | None = None,
        exclude_domains: list[str] | None = None,
        language: str | None = None,
        include_answer: bool | None = None,
        include_raw_content: bool | None = None,
    ) -> list[TavilySearchResult]:
        """Calls the Tavily API to perform a web search."""
        # Make sure max_results is within the allowed range.
        min_results = get_web_search_min_results()
        max_allowed_results = get_web_search_max_results()
        if max_results < min_results or max_results > max_allowed_results:
            raise ValueError(
                f"max_results must be between {min_results} and {max_allowed_results}"
            )

        # Build kwargs only for provided options to avoid API rejections.
        kwargs: dict = {"query": query, "max_results": max_results}
        if search_depth is not None:
            kwargs["search_depth"] = search_depth
        if time_range is not None:
            kwargs["time_range"] = time_range
        if topic is not None:
            kwargs["topic"] = topic
        if include_domains is not None:
            kwargs["include_domains"] = include_domains
        if exclude_domains is not None:
            kwargs["exclude_domains"] = exclude_domains
        if language is not None:
            kwargs["language"] = language
        if include_answer is not None:
            kwargs["include_answer"] = include_answer
        if include_raw_content is not None:
            kwargs["include_raw_content"] = include_raw_content

        # Perform the actual API call.
        response = await self.client.search(**kwargs)

        return [
            TavilySearchResult(title=result["title"], url=result["url"])
            for result in response["results"]
        ]
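
    # Expected response shape (inferred from the parsing above, not from a
    # pinned SDK version; other keys the API may return are ignored here):
    #
    #     {"results": [{"title": "...", "url": "https://..."}, ...]}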

    def _build_search_action_result(
        self, search_results: list[TavilySearchResult], query: str = ""
    ) -> WebSearchActionResult:
        """Builds the ActionResult from the search results."""
        # Convert to result items.
        result_items = [
            WebSearchResultItem(title=result.title, url=result.url)
            for result in search_results
        ]

        # Create document data with all results.
        document_data = WebSearchDocumentData(
            query=query, results=result_items, total_count=len(result_items)
        )

        # Create a single document wrapping all results.
        document = WebSearchActionDocument(
            documentName=f"web_search_results_{get_utc_timestamp()}.json",
            documentData=document_data,
            mimeType="application/json",
        )

        return WebSearchActionResult(
            success=True, documents=[document], resultLabel="web_search_results"
        )
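
    # Resulting document sketch (field names come from the constructors above;
    # the serialized JSON layout is an assumption of this illustration):
    #
    #     web_search_results_<timestamp>.json
    #     {"query": "...", "results": [{"title": "...", "url": "..."}], "total_count": 1}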

    async def _crawl(
        self,
        urls: list[str],
        extract_depth: str | None = None,
        format: str | None = None,
    ) -> list[TavilyCrawlResult]:
        """Calls the Tavily API to extract text content from URLs, with retries."""
        max_retries = self.crawl_max_retries
        retry_delay = self.crawl_retry_delay
        timeout = self.crawl_timeout

        for attempt in range(max_retries + 1):
            try:
                # Build kwargs for extract, falling back to the defaults.
                kwargs_extract: dict = {
                    "urls": urls,
                    "extract_depth": extract_depth or "advanced",
                    "format": format or "text",
                }

                # Bound each attempt with asyncio.wait_for.
                response = await asyncio.wait_for(
                    self.client.extract(**kwargs_extract),
                    timeout=timeout,
                )

                return [
                    TavilyCrawlResult(url=result["url"], content=result["raw_content"])
                    for result in response["results"]
                ]

            except asyncio.TimeoutError:
                logger.warning(f"Crawl attempt {attempt + 1} timed out after {timeout} seconds")
                if attempt < max_retries:
                    logger.info(f"Retrying in {retry_delay} seconds...")
                    await asyncio.sleep(retry_delay)
                else:
                    raise RuntimeError(f"Crawl failed after {max_retries + 1} attempts due to timeout")

            except Exception as e:
                logger.warning(f"Crawl attempt {attempt + 1} failed: {str(e)}")
                if attempt < max_retries:
                    logger.info(f"Retrying in {retry_delay} seconds...")
                    await asyncio.sleep(retry_delay)
                else:
                    raise RuntimeError(f"Crawl failed after {max_retries + 1} attempts: {str(e)}")
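
    # Worst-case wall time with the defaults (timeout=30, max_retries=3,
    # retry_delay=2): 4 attempts * 30 s + 3 sleeps * 2 s = 126 s before the
    # final error propagates to the caller.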

    def _build_crawl_action_result(
        self, crawl_results: list[TavilyCrawlResult], urls: list[str] | None = None
    ) -> WebCrawlActionResult:
        """Builds the ActionResult from the crawl results."""
        # Convert to result items.
        result_items = [
            WebCrawlResultItem(url=result.url, content=result.content)
            for result in crawl_results
        ]

        # Create document data with all results; fall back to the crawled URLs
        # when the caller did not pass the original request URLs.
        document_data = WebCrawlDocumentData(
            urls=urls or [result.url for result in crawl_results],
            results=result_items,
            total_count=len(result_items),
        )

        # Create a single document wrapping all results.
        document = WebCrawlActionDocument(
            documentName=f"web_crawl_results_{get_utc_timestamp()}.json",
            documentData=document_data,
            mimeType="application/json",
        )

        return WebCrawlActionResult(
            success=True, documents=[document], resultLabel="web_crawl_results"
        )

    def _build_scrape_action_result(
        self, crawl_results: list[TavilyCrawlResult], query: str = ""
    ) -> WebScrapeActionResult:
        """Builds the ActionResult from the scrape results."""
        # Convert to result items.
        result_items = [
            WebScrapeResultItem(url=result.url, content=result.content)
            for result in crawl_results
        ]

        # Create document data with all results.
        document_data = WebScrapeDocumentData(
            query=query,
            results=result_items,
            total_count=len(result_items),
        )

        # Create a single document wrapping all results.
        document = WebScrapeActionDocument(
            documentName=f"web_scrape_results_{get_utc_timestamp()}.json",
            documentData=document_data,
            mimeType="application/json",
        )

        return WebScrapeActionResult(
            success=True, documents=[document], resultLabel="web_scrape_results"
        )
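

if __name__ == "__main__":  # pragma: no cover
    # Minimal manual smoke test: a sketch, assuming a valid API key in
    # config.ini and that WebSearchRequest accepts these fields as keyword
    # arguments (both are assumptions of this example, not guarantees).
    async def _demo() -> None:
        connector = await ConnectorTavily.create()
        result = await connector.search_urls(
            WebSearchRequest(query="tavily extract api", max_results=5)
        )
        print(result)

    asyncio.run(_demo())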