# gateway/modules/aicore/aicorePluginTavily.py

"""Tavily web search class.
"""
import asyncio
import logging
import re
from dataclasses import dataclass
from typing import List, Optional

from tavily import AsyncTavilyClient

from modules.aicore.aicoreBase import BaseConnectorAi
from modules.datamodels.datamodelAi import AiModel, ModelCapabilitiesEnum, PriorityEnum, ProcessingModeEnum, OperationTypeEnum
from modules.datamodels.datamodelWeb import (
    WebSearchActionResult,
    WebSearchActionDocument,
    WebSearchDocumentData,
    WebSearchResultItem,
    WebCrawlActionResult,
    WebCrawlActionDocument,
    WebCrawlDocumentData,
    WebCrawlResultItem,
    WebScrapeActionResult,
    WebScrapeActionDocument,
    WebSearchDocumentData as WebScrapeDocumentData,
    WebScrapeResultItem,
)
from modules.shared.configuration import APP_CONFIG
from modules.shared.timezoneUtils import get_utc_timestamp
logger = logging.getLogger(__name__)
@dataclass
class WebSearchResult:
    """One hit returned by a Tavily web search (see `_search`)."""
    # Page title as reported by the Tavily "results" payload.
    title: str
    # Result URL, cleaned by `_clean_url` to strip trailing junk.
    url: str
    # Full page text; only populated when the API response carries
    # "raw_content" (i.e. when include_raw_content was requested).
    raw_content: Optional[str] = None
@dataclass
class WebCrawlResult:
    """Extracted content for one URL from a Tavily extract call (see `_crawl`)."""
    # URL the content was extracted from.
    url: str
    # Extracted text; `_crawl` prefers "raw_content", falling back to
    # "content", then "" when the API returned neither.
    content: str
class ConnectorWeb(BaseConnectorAi):
"""Tavily web search connector."""
def __init__(self):
super().__init__()
self.client: Optional[AsyncTavilyClient] = None
# Cached settings loaded at initialization time
self.crawlTimeout: int = 30
self.crawlMaxRetries: int = 3
self.crawlRetryDelay: int = 2
# Cached web search constraints (camelCase per project style)
self.webSearchMinResults: int = 1
self.webSearchMaxResults: int = 20
def getConnectorType(self) -> str:
"""Get the connector type identifier."""
return "tavily"
    def getModels(self) -> List[AiModel]:
        """Get all available Tavily models.

        Token-related fields are zero because Tavily bills per request,
        not per token; each model's `calculatePriceUsd` lambda estimates
        USD cost from request parameters instead. The 0.008 factor appears
        to be the dollar value of one Tavily API credit — TODO confirm
        against current Tavily pricing.
        """
        return [
            # Web search: one credit for "basic" depth, two for deeper search.
            AiModel(
                name="tavily_search",
                displayName="Tavily Search",
                connectorType="tavily",
                maxTokens=0, # Web search doesn't use tokens
                contextLength=0,
                costPer1kTokensInput=0.0,
                costPer1kTokensOutput=0.0,
                speedRating=8,
                qualityRating=8,
                capabilities=[ModelCapabilitiesEnum.WEB_SEARCH, ModelCapabilitiesEnum.INFORMATION_RETRIEVAL, ModelCapabilitiesEnum.URL_DISCOVERY],
                functionCall=self.search,
                priority=PriorityEnum.BALANCED,
                processingMode=ProcessingModeEnum.BASIC,
                operationTypes=[OperationTypeEnum.WEB_RESEARCH],
                version="tavily-search",
                calculatePriceUsd=lambda processingTime, bytesSent, bytesReceived, searchDepth="basic", numRequests=1: numRequests * (1 if searchDepth == "basic" else 2) * 0.008
            ),
            # Content extraction: priced per 5 successfully extracted URLs.
            AiModel(
                name="tavily_extract",
                displayName="Tavily Extract",
                connectorType="tavily",
                maxTokens=0, # Web extraction doesn't use tokens
                contextLength=0,
                costPer1kTokensInput=0.0,
                costPer1kTokensOutput=0.0,
                speedRating=6,
                qualityRating=8,
                capabilities=[ModelCapabilitiesEnum.WEB_CRAWLING, ModelCapabilitiesEnum.CONTENT_EXTRACTION, ModelCapabilitiesEnum.TEXT_EXTRACTION],
                functionCall=self.crawl,
                priority=PriorityEnum.BALANCED,
                processingMode=ProcessingModeEnum.BASIC,
                operationTypes=[OperationTypeEnum.WEB_RESEARCH],
                version="tavily-extract",
                calculatePriceUsd=lambda processingTime, bytesSent, bytesReceived, extractionDepth="basic", numSuccessfulUrls=1: (numSuccessfulUrls / 5) * (1 if extractionDepth == "basic" else 2) * 0.008
            ),
            # Site crawl: mapping cost (per 10 pages, doubled with instructions)
            # plus extraction cost (per 5 successful extractions).
            # NOTE(review): functionCall binds self.crawl, same as tavily_extract —
            # presumably a placeholder until a dedicated crawl entry point exists;
            # confirm intended behavior.
            AiModel(
                name="tavily_crawl",
                displayName="Tavily Crawl",
                connectorType="tavily",
                maxTokens=0, # Web crawling doesn't use tokens
                contextLength=0,
                costPer1kTokensInput=0.0,
                costPer1kTokensOutput=0.0,
                speedRating=6,
                qualityRating=8,
                capabilities=[ModelCapabilitiesEnum.WEB_CRAWLING, ModelCapabilitiesEnum.CONTENT_EXTRACTION, ModelCapabilitiesEnum.MAPPING],
                functionCall=self.crawl,
                priority=PriorityEnum.BALANCED,
                processingMode=ProcessingModeEnum.BASIC,
                operationTypes=[OperationTypeEnum.WEB_RESEARCH],
                version="tavily-crawl",
                calculatePriceUsd=lambda processingTime, bytesSent, bytesReceived, numPages=10, extractionDepth="basic", withInstructions=False, numSuccessfulExtractions=10: ((numPages / 10) * (2 if withInstructions else 1) + (numSuccessfulExtractions / 5) * (1 if extractionDepth == "basic" else 2)) * 0.008
            ),
            # Scrape = search + extract pipeline; price is the sum of both parts.
            AiModel(
                name="tavily_scrape",
                displayName="Tavily Scrape",
                connectorType="tavily",
                maxTokens=0, # Web scraping doesn't use tokens
                contextLength=0,
                costPer1kTokensInput=0.0,
                costPer1kTokensOutput=0.0,
                speedRating=6,
                qualityRating=8,
                capabilities=[ModelCapabilitiesEnum.WEB_SEARCH, ModelCapabilitiesEnum.WEB_CRAWLING, ModelCapabilitiesEnum.CONTENT_EXTRACTION, ModelCapabilitiesEnum.INFORMATION_RETRIEVAL],
                functionCall=self.scrape,
                priority=PriorityEnum.BALANCED,
                processingMode=ProcessingModeEnum.BASIC,
                operationTypes=[OperationTypeEnum.WEB_RESEARCH],
                version="tavily-search-extract",
                calculatePriceUsd=lambda processingTime, bytesSent, bytesReceived, searchDepth="basic", numSuccessfulUrls=1, extractionDepth="basic": ((1 if searchDepth == "basic" else 2) + (numSuccessfulUrls / 5) * (1 if extractionDepth == "basic" else 2)) * 0.008
            )
        ]
@classmethod
async def create(cls):
api_key = APP_CONFIG.get("Connector_WebTavily_API_KEY_SECRET")
if not api_key:
raise ValueError("Tavily API key not configured. Please set Connector_WebTavily_API_KEY_SECRET in config.ini")
# Load and cache web crawl related configuration
crawlTimeout = int(APP_CONFIG.get("Web_Crawl_TIMEOUT", "30"))
crawlMaxRetries = int(APP_CONFIG.get("Web_Crawl_MAX_RETRIES", "3"))
crawlRetryDelay = int(APP_CONFIG.get("Web_Crawl_RETRY_DELAY", "2"))
return cls(
client=AsyncTavilyClient(api_key=api_key),
crawlTimeout=crawlTimeout,
crawlMaxRetries=crawlMaxRetries,
crawlRetryDelay=crawlRetryDelay,
webSearchMinResults=int(APP_CONFIG.get("Web_Search_MIN_RESULTS", "1")),
webSearchMaxResults=int(APP_CONFIG.get("Web_Search_MAX_RESULTS", "20")),
)
# Standardized methods returning ActionResults for the interface to consume
async def search(self, request) -> "WebSearchActionResult":
try:
raw_results = await self._search(
query=request.query,
max_results=request.max_results,
search_depth=request.search_depth,
time_range=request.time_range,
topic=request.topic,
include_domains=request.include_domains,
exclude_domains=request.exclude_domains,
language=request.language,
include_answer=request.include_answer,
include_raw_content=request.include_raw_content,
)
except Exception as e:
return WebSearchActionResult(success=False, error=str(e))
result_items = [
WebSearchResultItem(
title=result.title,
url=result.url,
raw_content=getattr(result, 'raw_content', None)
)
for result in raw_results
]
document_data = WebSearchDocumentData(
query=request.query,
results=result_items,
total_count=len(result_items),
)
document = WebSearchActionDocument(
documentName=f"web_search_results_{get_utc_timestamp()}.json",
documentData=document_data,
mimeType="application/json",
)
return WebSearchActionResult(
success=True, documents=[document], resultLabel="web_search_results"
)
async def crawl(self, request) -> "WebCrawlActionResult":
try:
raw_results = await self._crawl(
[str(u) for u in request.urls],
extract_depth=request.extract_depth,
format=request.format,
)
except Exception as e:
return WebCrawlActionResult(success=False, error=str(e))
result_items = [
WebCrawlResultItem(url=result.url, content=result.content)
for result in raw_results
]
document_data = WebCrawlDocumentData(
urls=[str(u) for u in request.urls],
results=result_items,
total_count=len(result_items),
)
document = WebCrawlActionDocument(
documentName=f"web_crawl_results_{get_utc_timestamp()}.json",
documentData=document_data,
mimeType="application/json",
)
return WebCrawlActionResult(
success=True, documents=[document], resultLabel="web_crawl_results"
)
async def scrape(self, request) -> "WebScrapeActionResult":
try:
search_results = await self._search(
query=request.query,
max_results=request.max_results,
search_depth=request.search_depth,
time_range=request.time_range,
topic=request.topic,
include_domains=request.include_domains,
exclude_domains=request.exclude_domains,
language=request.language,
include_answer=request.include_answer,
include_raw_content=request.include_raw_content,
)
except Exception as e:
return WebScrapeActionResult(success=False, error=str(e))
try:
urls = [result.url for result in search_results]
crawl_results = await self._crawl(
urls,
extract_depth=request.extract_depth,
format=request.format,
)
except Exception as e:
return WebScrapeActionResult(success=False, error=str(e))
result_items = [
WebScrapeResultItem(url=result.url, content=result.content)
for result in crawl_results
]
document_data = WebScrapeDocumentData(
query=request.query,
results=result_items,
total_count=len(result_items),
)
document = WebScrapeActionDocument(
documentName=f"web_scrape_results_{get_utc_timestamp()}.json",
documentData=document_data,
mimeType="application/json",
)
return WebScrapeActionResult(
success=True, documents=[document], resultLabel="web_scrape_results"
)
async def _search_urls_raw(self,
*,
query: str,
max_results: int,
search_depth: str | None = None,
time_range: str | None = None,
topic: str | None = None,
include_domains: list[str] | None = None,
exclude_domains: list[str] | None = None,
language: str | None = None,
include_answer: bool | None = None,
include_raw_content: bool | None = None,
) -> list["WebSearchResult"]:
return await self._search(
query=query,
max_results=max_results,
search_depth=search_depth,
time_range=time_range,
topic=topic,
include_domains=include_domains,
exclude_domains=exclude_domains,
language=language,
include_answer=include_answer,
include_raw_content=include_raw_content,
)
async def _crawl_urls_raw(self,
*,
urls: list[str],
extract_depth: str | None = None,
format: str | None = None,
) -> list["WebCrawlResult"]:
return await self._crawl(urls, extract_depth=extract_depth, format=format)
async def _scrape_raw(self,
*,
query: str,
max_results: int,
search_depth: str | None = None,
time_range: str | None = None,
topic: str | None = None,
include_domains: list[str] | None = None,
exclude_domains: list[str] | None = None,
language: str | None = None,
include_answer: bool | None = None,
include_raw_content: bool | None = None,
extract_depth: str | None = None,
format: str | None = None,
) -> list["WebCrawlResult"]:
search_results = await self._search(
query=query,
max_results=max_results,
search_depth=search_depth,
time_range=time_range,
topic=topic,
include_domains=include_domains,
exclude_domains=exclude_domains,
language=language,
include_answer=include_answer,
include_raw_content=include_raw_content,
)
urls = [result.url for result in search_results]
return await self._crawl(urls, extract_depth=extract_depth, format=format)
def _clean_url(self, url: str) -> str:
"""Clean URL by removing extra text that might be appended."""
import re
# Extract just the URL part, removing any extra text after it
url_match = re.match(r'(https?://[^\s,]+)', url)
if url_match:
return url_match.group(1)
return url
async def _search(
self,
query: str,
max_results: int,
search_depth: str | None = None,
time_range: str | None = None,
topic: str | None = None,
include_domains: list[str] | None = None,
exclude_domains: list[str] | None = None,
language: str | None = None,
country: str | None = None,
include_answer: bool | None = None,
include_raw_content: bool | None = None,
) -> list[WebSearchResult]:
"""Calls the Tavily API to perform a web search."""
# Make sure max_results is within the allowed range (use cached values)
minResults = self.webSearchMinResults
maxAllowedResults = self.webSearchMaxResults
if max_results < minResults or max_results > maxAllowedResults:
raise ValueError(f"max_results must be between {minResults} and {maxAllowedResults}")
# Perform actual API call
# Build kwargs only for provided options to avoid API rejections
kwargs: dict = {"query": query, "max_results": max_results}
if search_depth is not None:
kwargs["search_depth"] = search_depth
if time_range is not None:
kwargs["time_range"] = time_range
if topic is not None:
kwargs["topic"] = topic
if include_domains is not None and len(include_domains) > 0:
kwargs["include_domains"] = include_domains
if exclude_domains is not None:
kwargs["exclude_domains"] = exclude_domains
if language is not None:
kwargs["language"] = language
if country is not None:
kwargs["country"] = country
if include_answer is not None:
kwargs["include_answer"] = include_answer
if include_raw_content is not None:
kwargs["include_raw_content"] = include_raw_content
logger.debug(f"Tavily.search kwargs: {kwargs}")
response = await self.client.search(**kwargs)
return [
WebSearchResult(
title=result["title"],
url=self._clean_url(result["url"]),
raw_content=result.get("raw_content")
)
for result in response["results"]
]
async def _crawl(
self,
urls: list,
extract_depth: str | None = None,
format: str | None = None,
) -> list[WebCrawlResult]:
"""Calls the Tavily API to extract text content from URLs with retry logic."""
maxRetries = self.crawlMaxRetries
retryDelay = self.crawlRetryDelay
timeout = self.crawlTimeout
logger.debug(f"Starting crawl of {len(urls)} URLs: {urls}")
logger.debug(f"Crawl settings: extract_depth={extract_depth}, format={format}, timeout={timeout}s")
for attempt in range(maxRetries + 1):
try:
logger.debug(f"Crawl attempt {attempt + 1}/{maxRetries + 1}")
# Use asyncio.wait_for for timeout
# Build kwargs for extract
kwargs_extract: dict = {"urls": urls}
kwargs_extract["extract_depth"] = extract_depth or "advanced"
kwargs_extract["format"] = format or "markdown" # Use markdown to get HTML structure
logger.debug(f"Sending request to Tavily with kwargs: {kwargs_extract}")
response = await asyncio.wait_for(
self.client.extract(**kwargs_extract),
timeout=timeout
)
logger.debug(f"Tavily response received: {list(response.keys())}")
# Debug: Log what Tavily actually returns
if "results" in response and response["results"]:
logger.debug(f"Tavily returned {len(response['results'])} results")
logger.debug(f"First result keys: {list(response['results'][0].keys())}")
logger.debug(f"First result has raw_content: {'raw_content' in response['results'][0]}")
# Log each result
for i, result in enumerate(response["results"]):
logger.debug(f"Result {i+1}: URL={result.get('url', 'N/A')}, content_length={len(result.get('raw_content', result.get('content', '')))}")
else:
logger.warning(f"Tavily returned no results in response: {response}")
results = [
WebCrawlResult(
url=result["url"],
content=result.get("raw_content", result.get("content", "")) # Try raw_content first, fallback to content
)
for result in response["results"]
]
logger.debug(f"Crawl successful: extracted {len(results)} results")
return results
except asyncio.TimeoutError:
logger.warning(f"Crawl attempt {attempt + 1} timed out after {timeout} seconds for URLs: {urls}")
if attempt < maxRetries:
logger.info(f"Retrying in {retryDelay} seconds...")
await asyncio.sleep(retryDelay)
else:
raise Exception(f"Crawl failed after {maxRetries + 1} attempts due to timeout")
except Exception as e:
logger.warning(f"Crawl attempt {attempt + 1} failed for URLs {urls}: {str(e)}")
logger.debug(f"Full error details: {type(e).__name__}: {str(e)}")
# Check if it's a validation error and log more details
if "validation" in str(e).lower():
logger.debug(f"URL validation failed. Checking URL format:")
for i, url in enumerate(urls):
logger.debug(f" URL {i+1}: '{url}' (length: {len(url)})")
# Check for common URL issues
if ' ' in url:
logger.debug(f" WARNING: URL contains spaces!")
if not url.startswith(('http://', 'https://')):
logger.debug(f" WARNING: URL doesn't start with http/https!")
if len(url) > 2000:
logger.debug(f" WARNING: URL is very long ({len(url)} chars)")
if attempt < maxRetries:
logger.info(f"Retrying in {retryDelay} seconds...")
await asyncio.sleep(retryDelay)
else:
raise Exception(f"Crawl failed after {maxRetries + 1} attempts: {str(e)}")