# gateway/modules/aicore/aicorePluginTavily.py
"""Tavily web search class.
"""
import logging
import asyncio
import json
import re
from dataclasses import dataclass, field
from typing import Optional, List

from tavily import AsyncTavilyClient

from modules.shared.configuration import APP_CONFIG
from modules.shared.timezoneUtils import get_utc_timestamp
from modules.aicore.aicoreBase import BaseConnectorAi
from modules.datamodels.datamodelAi import (
    AiModel,
    PriorityEnum,
    ProcessingModeEnum,
    OperationTypeEnum,
    AiModelResponse,
    createOperationTypeRatings,
)

logger = logging.getLogger(__name__)


@dataclass
class WebSearchResult:
    title: str
    url: str
    raw_content: Optional[str] = None


@dataclass
class WebCrawlResult:
    url: str
    content: str


@dataclass
class WebResearchRequest:
    """Ultra-simplified web research request."""
    user_prompt: str
    urls: Optional[List[str]] = None
    max_results: int = 5
    max_pages: int = 10
    search_depth: str = "basic"
    extract_depth: str = "advanced"
    format: str = "markdown"
    country: Optional[str] = None
    time_range: Optional[str] = None
    topic: Optional[str] = None
    language: Optional[str] = None


@dataclass
class WebResearchResult:
    """Ultra-simplified web research result - just success/error + documents."""
    success: bool = True
    error: Optional[str] = None
    # Simple dicts instead of ActionDocument; default_factory avoids a mutable/None default
    documents: List[dict] = field(default_factory=list)


class ConnectorWeb(BaseConnectorAi):
    """Tavily web search connector."""

    def __init__(self):
        super().__init__()
        self.client: Optional[AsyncTavilyClient] = None
        # Cached settings loaded at initialization time
        self.crawlTimeout: int = 30
        self.crawlMaxRetries: int = 3
        self.crawlRetryDelay: int = 2
        # Cached web search constraints (camelCase per project style)
        self.webSearchMinResults: int = 1
        self.webSearchMaxResults: int = 20

    def getConnectorType(self) -> str:
        """Get the connector type identifier."""
        return "tavily"

    def _extractUrlsFromPrompt(self, prompt: str) -> List[str]:
        """Extract URLs from a text prompt using regex."""
        if not prompt:
            return []
        # URL regex pattern - matches http/https URLs
        url_pattern = r'https?://(?:[-\w.])+(?:[:\d]+)?(?:/(?:[\w/_.])*(?:\?(?:[\w&=%.])*)?(?:#(?:[\w.])*)?)?'
        urls = re.findall(url_pattern, prompt)
        # Remove duplicates while preserving order
        seen = set()
        unique_urls = []
        for url in urls:
            if url not in seen:
                seen.add(url)
                unique_urls.append(url)
        return unique_urls
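    # Illustrative example (not executed): duplicate URLs collapse, order is kept.
    #   _extractUrlsFromPrompt("See https://example.com/a then https://example.com/a")
    #   -> ["https://example.com/a"]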

    def getModels(self) -> List[AiModel]:
        """Get all available Tavily models."""
        return [
            AiModel(
                name="tavily-search",
                displayName="Tavily Search",
                connectorType="tavily",
                apiUrl="https://api.tavily.com/search",
                temperature=0.0,  # Web search doesn't use temperature
                maxTokens=0,  # Web search doesn't use tokens
                contextLength=0,
                costPer1kTokensInput=0.0,
                costPer1kTokensOutput=0.0,
                speedRating=9,  # Very fast for URL discovery
                qualityRating=9,  # Excellent URL discovery quality
                # capabilities removed (not used in business logic)
                functionCall=self.search,
                priority=PriorityEnum.BALANCED,
                processingMode=ProcessingModeEnum.BASIC,
                operationTypes=createOperationTypeRatings(
                    (OperationTypeEnum.WEB_SEARCH, 10),
                    (OperationTypeEnum.WEB_RESEARCH, 3),
                    (OperationTypeEnum.WEB_CRAWL, 2),
                    (OperationTypeEnum.WEB_NEWS, 3),
                    (OperationTypeEnum.WEB_QUESTIONS, 2)
                ),
                version="tavily-search",
                calculatePriceUsd=lambda processingTime, bytesSent, bytesReceived, searchDepth="basic", numRequests=1:
                    numRequests * (1 if searchDepth == "basic" else 2) * 0.008
            ),
            AiModel(
                name="tavily-extract",
                displayName="Tavily Extract",
                connectorType="tavily",
                apiUrl="https://api.tavily.com/extract",
                temperature=0.0,  # Web crawling doesn't use temperature
                maxTokens=0,  # Web crawling doesn't use tokens
                contextLength=0,
                costPer1kTokensInput=0.0,
                costPer1kTokensOutput=0.0,
                speedRating=7,  # Good for content extraction
                qualityRating=9,  # Excellent content extraction quality
                # capabilities removed (not used in business logic)
                functionCall=self.crawl,
                priority=PriorityEnum.BALANCED,
                processingMode=ProcessingModeEnum.BASIC,
                operationTypes=createOperationTypeRatings(
                    (OperationTypeEnum.WEB_RESEARCH, 3),
                    (OperationTypeEnum.WEB_CRAWL, 10),
                    (OperationTypeEnum.WEB_NEWS, 3),
                    (OperationTypeEnum.WEB_QUESTIONS, 2)
                ),
                version="tavily-extract",
                calculatePriceUsd=lambda processingTime, bytesSent, bytesReceived, numPages=10, extractionDepth="basic", withInstructions=False, numSuccessfulExtractions=10:
                    ((numPages / 10) * (2 if withInstructions else 1) + (numSuccessfulExtractions / 5) * (1 if extractionDepth == "basic" else 2)) * 0.008
            ),
            AiModel(
                name="tavily-search-extract",
                displayName="Tavily Search & Extract",
                connectorType="tavily",
                apiUrl="https://api.tavily.com/search",
                temperature=0.0,  # Web scraping doesn't use temperature
                maxTokens=0,  # Web scraping doesn't use tokens
                contextLength=0,
                costPer1kTokensInput=0.0,
                costPer1kTokensOutput=0.0,
                speedRating=7,  # Good for combined search+extract
                qualityRating=8,  # Good quality for structured data
                # capabilities removed (not used in business logic)
                functionCall=self.scrape,
                priority=PriorityEnum.BALANCED,
                processingMode=ProcessingModeEnum.BASIC,
                operationTypes=createOperationTypeRatings(
                    (OperationTypeEnum.WEB_RESEARCH, 8),
                    (OperationTypeEnum.WEB_SEARCH, 6),
                    (OperationTypeEnum.WEB_CRAWL, 6),
                    (OperationTypeEnum.WEB_NEWS, 5),
                    (OperationTypeEnum.WEB_QUESTIONS, 5)
                ),
                version="tavily-search-extract",
                calculatePriceUsd=lambda processingTime, bytesSent, bytesReceived, searchDepth="basic", numSuccessfulUrls=1, extractionDepth="basic":
                    ((1 if searchDepth == "basic" else 2) + (numSuccessfulUrls / 5) * (1 if extractionDepth == "basic" else 2)) * 0.008
            )
        ]
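    # Worked pricing example (derived from the lambdas above): a single basic
    # search costs 1 * 1 * 0.008 = 0.008 USD; an advanced search doubles the
    # multiplier, giving 0.016 USD per request. An extract of 10 pages at
    # basic depth with 10 successful extractions, no instructions, costs
    # (10/10 * 1 + 10/5 * 1) * 0.008 = 0.024 USD.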

    @classmethod
    async def create(cls) -> "ConnectorWeb":
        """Create a connector instance with settings loaded from APP_CONFIG."""
        api_key = APP_CONFIG.get("Connector_AiTavily_API_SECRET")
        if not api_key:
            raise ValueError("Tavily API key not configured. Please set Connector_AiTavily_API_SECRET in config.ini")
        # Load and cache web crawl and web search configuration on the instance
        # (__init__ takes no arguments, so attributes are assigned after construction)
        instance = cls()
        instance.client = AsyncTavilyClient(api_key=api_key)
        instance.crawlTimeout = int(APP_CONFIG.get("Web_Crawl_TIMEOUT", "30"))
        instance.crawlMaxRetries = int(APP_CONFIG.get("Web_Crawl_MAX_RETRIES", "3"))
        instance.crawlRetryDelay = int(APP_CONFIG.get("Web_Crawl_RETRY_DELAY", "2"))
        instance.webSearchMinResults = int(APP_CONFIG.get("Web_Search_MIN_RESULTS", "1"))
        instance.webSearchMaxResults = int(APP_CONFIG.get("Web_Search_MAX_RESULTS", "20"))
        return instance
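    # Usage sketch (illustrative):
    #   connector = await ConnectorWeb.create()
    #   models = connector.getModels()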

    # Standardized methods using the AiModelCall/AiModelResponse pattern
    async def search(self, modelCall) -> "AiModelResponse":
        """Search using the standardized AiModelCall/AiModelResponse pattern."""
        try:
            # Extract parameters from modelCall
            query = modelCall.messages[0]["content"] if modelCall.messages else ""
            options = modelCall.options
            raw_results = await self._search(
                query=query,
                max_results=options.get("max_results", 5),
                search_depth=options.get("search_depth"),
                time_range=options.get("time_range"),
                topic=options.get("topic"),
                include_domains=options.get("include_domains"),
                exclude_domains=options.get("exclude_domains"),
                language=options.get("language"),
                include_answer=options.get("include_answer"),
                include_raw_content=options.get("include_raw_content"),
            )
            # Serialize the results to a JSON string
            results_json = {
                "query": query,
                "results": [
                    {
                        "title": result.title,
                        "url": result.url,
                        "content": result.raw_content
                    }
                    for result in raw_results
                ],
                "total_count": len(raw_results)
            }
            content = json.dumps(results_json, indent=2)
            return AiModelResponse(
                content=content,
                success=True,
                metadata={
                    "total_count": len(raw_results),
                    "search_depth": options.get("search_depth", "basic")
                }
            )
        except Exception as e:
            return AiModelResponse(
                content="",
                success=False,
                error=str(e)
            )
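    # Response content sketch (illustrative) for a successful search:
    #   {
    #     "query": "...",
    #     "results": [{"title": "...", "url": "...", "content": "..."}],
    #     "total_count": 1
    #   }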

    async def crawl(self, modelCall) -> "AiModelResponse":
        """Crawl using the standardized AiModelCall/AiModelResponse pattern."""
        try:
            # Extract parameters from modelCall
            options = modelCall.options
            urls = options.get("urls", [])
            # If no URLs were provided, try to extract URLs from the prompt
            if not urls and modelCall.messages:
                prompt = modelCall.messages[0]["content"] if modelCall.messages else ""
                urls = self._extractUrlsFromPrompt(prompt)
            if not urls:
                return AiModelResponse(
                    content="No URLs provided for crawling",
                    success=False,
                    error="No URLs found in options or prompt"
                )
            raw_results = await self._crawl(
                urls,
                extract_depth=options.get("extract_depth"),
                format=options.get("format"),
            )
            # Serialize the results to a JSON string
            results_json = {
                "urls": urls,
                "results": [
                    {
                        "url": result.url,
                        "content": result.content
                    }
                    for result in raw_results
                ],
                "total_count": len(raw_results)
            }
            content = json.dumps(results_json, indent=2)
            return AiModelResponse(
                content=content,
                success=True,
                metadata={
                    "total_count": len(raw_results),
                    "extract_depth": options.get("extract_depth", "basic")
                }
            )
        except Exception as e:
            return AiModelResponse(
                content="",
                success=False,
                error=str(e)
            )

    async def scrape(self, modelCall) -> "AiModelResponse":
        """Scrape (search, then extract) using the standardized AiModelCall/AiModelResponse pattern."""
        try:
            # Extract parameters from modelCall
            query = modelCall.messages[0]["content"] if modelCall.messages else ""
            options = modelCall.options
            search_results = await self._search(
                query=query,
                max_results=options.get("max_results", 5),
                search_depth=options.get("search_depth"),
                time_range=options.get("time_range"),
                topic=options.get("topic"),
                include_domains=options.get("include_domains"),
                exclude_domains=options.get("exclude_domains"),
                language=options.get("language"),
                include_answer=options.get("include_answer"),
                include_raw_content=options.get("include_raw_content"),
            )
            urls = [result.url for result in search_results]
            crawl_results = await self._crawl(
                urls,
                extract_depth=options.get("extract_depth"),
                format=options.get("format"),
            )
            # Serialize the results to a JSON string
            results_json = {
                "query": query,
                "results": [
                    {
                        "url": result.url,
                        "content": result.content
                    }
                    for result in crawl_results
                ],
                "total_count": len(crawl_results)
            }
            content = json.dumps(results_json, indent=2)
            return AiModelResponse(
                content=content,
                success=True,
                metadata={
                    "total_count": len(crawl_results),
                    "search_depth": options.get("search_depth", "basic"),
                    "extract_depth": options.get("extract_depth", "basic")
                }
            )
        except Exception as e:
            return AiModelResponse(
                content="",
                success=False,
                error=str(e)
            )
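    # Pipeline note: scrape chains the two primitives above. _search discovers
    # URLs for the query, then _crawl extracts full content from exactly those
    # URLs, so total_count reflects extraction results rather than search hits.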

    # Helper functions
    async def _search_urls_raw(
        self,
        *,
        query: str,
        max_results: int,
        search_depth: str | None = None,
        time_range: str | None = None,
        topic: str | None = None,
        include_domains: list[str] | None = None,
        exclude_domains: list[str] | None = None,
        language: str | None = None,
        include_answer: bool | None = None,
        include_raw_content: bool | None = None,
    ) -> list["WebSearchResult"]:
        return await self._search(
            query=query,
            max_results=max_results,
            search_depth=search_depth,
            time_range=time_range,
            topic=topic,
            include_domains=include_domains,
            exclude_domains=exclude_domains,
            language=language,
            include_answer=include_answer,
            include_raw_content=include_raw_content,
        )

    async def _crawl_urls_raw(
        self,
        *,
        urls: list[str],
        extract_depth: str | None = None,
        format: str | None = None,
    ) -> list["WebCrawlResult"]:
        return await self._crawl(urls, extract_depth=extract_depth, format=format)

    async def _scrape_raw(
        self,
        *,
        query: str,
        max_results: int,
        search_depth: str | None = None,
        time_range: str | None = None,
        topic: str | None = None,
        include_domains: list[str] | None = None,
        exclude_domains: list[str] | None = None,
        language: str | None = None,
        include_answer: bool | None = None,
        include_raw_content: bool | None = None,
        extract_depth: str | None = None,
        format: str | None = None,
    ) -> list["WebCrawlResult"]:
        search_results = await self._search(
            query=query,
            max_results=max_results,
            search_depth=search_depth,
            time_range=time_range,
            topic=topic,
            include_domains=include_domains,
            exclude_domains=exclude_domains,
            language=language,
            include_answer=include_answer,
            include_raw_content=include_raw_content,
        )
        urls = [result.url for result in search_results]
        return await self._crawl(urls, extract_depth=extract_depth, format=format)

    def _clean_url(self, url: str) -> str:
        """Clean a URL by removing extra text that might be appended."""
        # Extract just the URL part, dropping any trailing text after it
        url_match = re.match(r'(https?://[^\s,]+)', url)
        if url_match:
            return url_match.group(1)
        return url
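    # Illustrative example (not executed):
    #   _clean_url("https://example.com/page, plus trailing note")
    #   -> "https://example.com/page"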

    async def _search(
        self,
        query: str,
        max_results: int,
        search_depth: str | None = None,
        time_range: str | None = None,
        topic: str | None = None,
        include_domains: list[str] | None = None,
        exclude_domains: list[str] | None = None,
        language: str | None = None,
        country: str | None = None,
        include_answer: bool | None = None,
        include_raw_content: bool | None = None,
    ) -> list[WebSearchResult]:
        """Calls the Tavily API to perform a web search."""
        # Make sure max_results is within the allowed range (use cached values)
        minResults = self.webSearchMinResults
        maxAllowedResults = self.webSearchMaxResults
        if max_results < minResults or max_results > maxAllowedResults:
            raise ValueError(f"max_results must be between {minResults} and {maxAllowedResults}")
        # Perform the actual API call.
        # Build kwargs only for provided options to avoid API rejections
        kwargs: dict = {"query": query, "max_results": max_results}
        if search_depth is not None:
            kwargs["search_depth"] = search_depth
        if time_range is not None:
            kwargs["time_range"] = time_range
        if topic is not None:
            kwargs["topic"] = topic
        if include_domains is not None and len(include_domains) > 0:
            kwargs["include_domains"] = include_domains
        if exclude_domains is not None:
            kwargs["exclude_domains"] = exclude_domains
        if language is not None:
            kwargs["language"] = language
        if country is not None:
            kwargs["country"] = country
        if include_answer is not None:
            kwargs["include_answer"] = include_answer
        if include_raw_content is not None:
            kwargs["include_raw_content"] = include_raw_content
        logger.debug(f"Tavily.search kwargs: {kwargs}")
        response = await self.client.search(**kwargs)
        return [
            WebSearchResult(
                title=result["title"],
                url=self._clean_url(result["url"]),
                raw_content=result.get("raw_content")
            )
            for result in response["results"]
        ]
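    # Shape of the Tavily /search response fields consumed above (other
    # fields in the payload are ignored here):
    #   {"results": [{"title": "...", "url": "...", "raw_content": "..." or null}]}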

    async def _crawl(
        self,
        urls: list,
        extract_depth: str | None = None,
        format: str | None = None,
    ) -> list[WebCrawlResult]:
        """Calls the Tavily API to extract text content from URLs, with retry logic."""
        maxRetries = self.crawlMaxRetries
        retryDelay = self.crawlRetryDelay
        timeout = self.crawlTimeout
        logger.debug(f"Starting crawl of {len(urls)} URLs: {urls}")
        logger.debug(f"Crawl settings: extract_depth={extract_depth}, format={format}, timeout={timeout}s")
        for attempt in range(maxRetries + 1):
            try:
                logger.debug(f"Crawl attempt {attempt + 1}/{maxRetries + 1}")
                # Build kwargs for extract; asyncio.wait_for enforces the timeout
                kwargs_extract: dict = {"urls": urls}
                kwargs_extract["extract_depth"] = extract_depth or "advanced"
                kwargs_extract["format"] = format or "markdown"  # Use markdown to keep document structure
                logger.debug(f"Sending request to Tavily with kwargs: {kwargs_extract}")
                response = await asyncio.wait_for(
                    self.client.extract(**kwargs_extract),
                    timeout=timeout
                )
                logger.debug(f"Tavily response received: {list(response.keys())}")
                # Debug: log what Tavily actually returned
                if "results" in response and response["results"]:
                    logger.debug(f"Tavily returned {len(response['results'])} results")
                    logger.debug(f"First result keys: {list(response['results'][0].keys())}")
                    logger.debug(f"First result has raw_content: {'raw_content' in response['results'][0]}")
                    # Log each result
                    for i, result in enumerate(response["results"]):
                        content_length = len(result.get("raw_content") or result.get("content") or "")
                        logger.debug(f"Result {i+1}: URL={result.get('url', 'N/A')}, content_length={content_length}")
                else:
                    logger.warning(f"Tavily returned no results in response: {response}")
                results = [
                    WebCrawlResult(
                        url=result["url"],
                        # Prefer raw_content, fall back to content; "or" also guards against explicit None values
                        content=result.get("raw_content") or result.get("content") or ""
                    )
                    for result in response["results"]
                ]
                logger.debug(f"Crawl successful: extracted {len(results)} results")
                return results
            except asyncio.TimeoutError:
                logger.warning(f"Crawl attempt {attempt + 1} timed out after {timeout} seconds for URLs: {urls}")
                if attempt < maxRetries:
                    logger.info(f"Retrying in {retryDelay} seconds...")
                    await asyncio.sleep(retryDelay)
                else:
                    raise Exception(f"Crawl failed after {maxRetries + 1} attempts due to timeout")
            except Exception as e:
                logger.warning(f"Crawl attempt {attempt + 1} failed for URLs {urls}: {str(e)}")
                logger.debug(f"Full error details: {type(e).__name__}: {str(e)}")
                # If it looks like a validation error, log more details about the URLs
                if "validation" in str(e).lower():
                    logger.debug("URL validation failed. Checking URL format:")
                    for i, url in enumerate(urls):
                        logger.debug(f"  URL {i+1}: '{url}' (length: {len(url)})")
                        # Check for common URL issues
                        if ' ' in url:
                            logger.debug("  WARNING: URL contains spaces!")
                        if not url.startswith(('http://', 'https://')):
                            logger.debug("  WARNING: URL doesn't start with http/https!")
                        if len(url) > 2000:
                            logger.debug(f"  WARNING: URL is very long ({len(url)} chars)")
                if attempt < maxRetries:
                    logger.info(f"Retrying in {retryDelay} seconds...")
                    await asyncio.sleep(retryDelay)
                else:
                    raise Exception(f"Crawl failed after {maxRetries + 1} attempts: {str(e)}")
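

# Illustrative usage sketch (an addition, not part of the connector's public
# surface). The SimpleNamespace below is a hypothetical stand-in for the
# project's AiModelCall object, which this module only reads .messages and
# .options from; a real Tavily API key must be configured for this to run.
if __name__ == "__main__":
    from types import SimpleNamespace

    async def _demo():
        connector = await ConnectorWeb.create()
        call = SimpleNamespace(
            messages=[{"role": "user", "content": "latest Tavily API changes"}],
            options={"max_results": 3, "search_depth": "basic"},
        )
        response = await connector.search(call)
        print(response.content if response.success else response.error)

    asyncio.run(_demo())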