579 lines
23 KiB
Python
579 lines
23 KiB
Python
"""Tavily web search class.
|
|
"""
|
|
|
|
import logging
|
|
import asyncio
|
|
import re
|
|
from dataclasses import dataclass
|
|
from typing import Optional, List
|
|
from tavily import AsyncTavilyClient
|
|
from modules.shared.configuration import APP_CONFIG
|
|
from modules.shared.timezoneUtils import get_utc_timestamp
|
|
from modules.aicore.aicoreBase import BaseConnectorAi
|
|
from modules.datamodels.datamodelAi import AiModel, PriorityEnum, ProcessingModeEnum, OperationTypeEnum, AiModelResponse, createOperationTypeRatings
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
@dataclass
|
|
class WebSearchResult:
|
|
title: str
|
|
url: str
|
|
raw_content: Optional[str] = None
|
|
|
|
@dataclass
|
|
class WebCrawlResult:
|
|
url: str
|
|
content: str
|
|
|
|
@dataclass
|
|
class WebResearchRequest:
|
|
"""Ultra-simplified web research request"""
|
|
user_prompt: str
|
|
urls: Optional[List[str]] = None
|
|
max_results: int = 5
|
|
max_pages: int = 10
|
|
search_depth: str = "basic"
|
|
extract_depth: str = "advanced"
|
|
format: str = "markdown"
|
|
country: Optional[str] = None
|
|
time_range: Optional[str] = None
|
|
topic: Optional[str] = None
|
|
language: Optional[str] = None
|
|
|
|
@dataclass
|
|
class WebResearchResult:
|
|
"""Ultra-simplified web research result - just success/error + documents"""
|
|
success: bool = True
|
|
error: Optional[str] = None
|
|
documents: List[dict] = None # Simple dict instead of ActionDocument
|
|
|
|
def __post_init__(self):
|
|
if self.documents is None:
|
|
self.documents = []
|
|
|
|
class ConnectorWeb(BaseConnectorAi):
|
|
"""Tavily web search connector."""
|
|
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.client: Optional[AsyncTavilyClient] = None
|
|
# Cached settings loaded at initialization time
|
|
self.crawlTimeout: int = 30
|
|
self.crawlMaxRetries: int = 3
|
|
self.crawlRetryDelay: int = 2
|
|
# Cached web search constraints (camelCase per project style)
|
|
self.webSearchMinResults: int = 1
|
|
self.webSearchMaxResults: int = 20
|
|
|
|
def getConnectorType(self) -> str:
|
|
"""Get the connector type identifier."""
|
|
return "tavily"
|
|
|
|
def _extractUrlsFromPrompt(self, prompt: str) -> List[str]:
|
|
"""Extract URLs from a text prompt using regex."""
|
|
if not prompt:
|
|
return []
|
|
|
|
# URL regex pattern - matches http/https URLs
|
|
url_pattern = r'https?://(?:[-\w.])+(?:[:\d]+)?(?:/(?:[\w/_.])*(?:\?(?:[\w&=%.])*)?(?:#(?:[\w.])*)?)?'
|
|
urls = re.findall(url_pattern, prompt)
|
|
|
|
# Remove duplicates while preserving order
|
|
seen = set()
|
|
unique_urls = []
|
|
for url in urls:
|
|
if url not in seen:
|
|
seen.add(url)
|
|
unique_urls.append(url)
|
|
|
|
return unique_urls
|
|
|
|
def getModels(self) -> List[AiModel]:
|
|
"""Get all available Tavily models."""
|
|
return [
|
|
AiModel(
|
|
name="tavily-search",
|
|
displayName="Tavily Search",
|
|
connectorType="tavily",
|
|
apiUrl="https://api.tavily.com/search",
|
|
temperature=0.0, # Web search doesn't use temperature
|
|
maxTokens=0, # Web search doesn't use tokens
|
|
contextLength=0,
|
|
costPer1kTokensInput=0.0,
|
|
costPer1kTokensOutput=0.0,
|
|
speedRating=9, # Very fast for URL discovery
|
|
qualityRating=9, # Excellent URL discovery quality
|
|
# capabilities removed (not used in business logic)
|
|
functionCall=self.search,
|
|
priority=PriorityEnum.BALANCED,
|
|
processingMode=ProcessingModeEnum.BASIC,
|
|
operationTypes=createOperationTypeRatings(
|
|
(OperationTypeEnum.WEB_SEARCH, 10),
|
|
(OperationTypeEnum.WEB_RESEARCH, 3),
|
|
(OperationTypeEnum.WEB_CRAWL, 2),
|
|
(OperationTypeEnum.WEB_NEWS, 3),
|
|
(OperationTypeEnum.WEB_QUESTIONS, 2)
|
|
),
|
|
version="tavily-search",
|
|
calculatePriceUsd=lambda processingTime, bytesSent, bytesReceived, searchDepth="basic", numRequests=1: numRequests * (1 if searchDepth == "basic" else 2) * 0.008
|
|
),
|
|
AiModel(
|
|
name="tavily-extract",
|
|
displayName="Tavily Extract",
|
|
connectorType="tavily",
|
|
apiUrl="https://api.tavily.com/extract",
|
|
temperature=0.0, # Web crawling doesn't use temperature
|
|
maxTokens=0, # Web crawling doesn't use tokens
|
|
contextLength=0,
|
|
costPer1kTokensInput=0.0,
|
|
costPer1kTokensOutput=0.0,
|
|
speedRating=7, # Good for content extraction
|
|
qualityRating=9, # Excellent content extraction quality
|
|
# capabilities removed (not used in business logic)
|
|
functionCall=self.crawl,
|
|
priority=PriorityEnum.BALANCED,
|
|
processingMode=ProcessingModeEnum.BASIC,
|
|
operationTypes=createOperationTypeRatings(
|
|
(OperationTypeEnum.WEB_RESEARCH, 3),
|
|
(OperationTypeEnum.WEB_CRAWL, 10),
|
|
(OperationTypeEnum.WEB_NEWS, 3),
|
|
(OperationTypeEnum.WEB_QUESTIONS, 2)
|
|
),
|
|
version="tavily-extract",
|
|
calculatePriceUsd=lambda processingTime, bytesSent, bytesReceived, numPages=10, extractionDepth="basic", withInstructions=False, numSuccessfulExtractions=10: ((numPages / 10) * (2 if withInstructions else 1) + (numSuccessfulExtractions / 5) * (1 if extractionDepth == "basic" else 2)) * 0.008
|
|
),
|
|
AiModel(
|
|
name="tavily-search-extract",
|
|
displayName="Tavily Search & Extract",
|
|
connectorType="tavily",
|
|
apiUrl="https://api.tavily.com/search",
|
|
temperature=0.0, # Web scraping doesn't use temperature
|
|
maxTokens=0, # Web scraping doesn't use tokens
|
|
contextLength=0,
|
|
costPer1kTokensInput=0.0,
|
|
costPer1kTokensOutput=0.0,
|
|
speedRating=7, # Good for combined search+extract
|
|
qualityRating=8, # Good quality for structured data
|
|
# capabilities removed (not used in business logic)
|
|
functionCall=self.scrape,
|
|
priority=PriorityEnum.BALANCED,
|
|
processingMode=ProcessingModeEnum.BASIC,
|
|
operationTypes=createOperationTypeRatings(
|
|
(OperationTypeEnum.WEB_RESEARCH, 8),
|
|
(OperationTypeEnum.WEB_SEARCH, 6),
|
|
(OperationTypeEnum.WEB_CRAWL, 6),
|
|
(OperationTypeEnum.WEB_NEWS, 5),
|
|
(OperationTypeEnum.WEB_QUESTIONS, 5)
|
|
),
|
|
version="tavily-search-extract",
|
|
calculatePriceUsd=lambda processingTime, bytesSent, bytesReceived, searchDepth="basic", numSuccessfulUrls=1, extractionDepth="basic": ((1 if searchDepth == "basic" else 2) + (numSuccessfulUrls / 5) * (1 if extractionDepth == "basic" else 2)) * 0.008
|
|
)
|
|
]
|
|
|
|
@classmethod
|
|
async def create(cls):
|
|
api_key = APP_CONFIG.get("Connector_AiTavily_API_SECRET")
|
|
if not api_key:
|
|
raise ValueError("Tavily API key not configured. Please set Connector_AiTavily_API_SECRET in config.ini")
|
|
# Load and cache web crawl related configuration
|
|
crawlTimeout = int(APP_CONFIG.get("Web_Crawl_TIMEOUT", "30"))
|
|
crawlMaxRetries = int(APP_CONFIG.get("Web_Crawl_MAX_RETRIES", "3"))
|
|
crawlRetryDelay = int(APP_CONFIG.get("Web_Crawl_RETRY_DELAY", "2"))
|
|
return cls(
|
|
client=AsyncTavilyClient(api_key=api_key),
|
|
crawlTimeout=crawlTimeout,
|
|
crawlMaxRetries=crawlMaxRetries,
|
|
crawlRetryDelay=crawlRetryDelay,
|
|
webSearchMinResults=int(APP_CONFIG.get("Web_Search_MIN_RESULTS", "1")),
|
|
webSearchMaxResults=int(APP_CONFIG.get("Web_Search_MAX_RESULTS", "20")),
|
|
)
|
|
|
|
# Standardized method using AiModelCall/AiModelResponse pattern
|
|
|
|
async def search(self, modelCall) -> "AiModelResponse":
|
|
"""Search using standardized AiModelCall/AiModelResponse pattern"""
|
|
try:
|
|
# Extract parameters from modelCall
|
|
query = modelCall.messages[0]["content"] if modelCall.messages else ""
|
|
options = modelCall.options
|
|
|
|
raw_results = await self._search(
|
|
query=query,
|
|
max_results=options.get("max_results", 5),
|
|
search_depth=options.get("search_depth"),
|
|
time_range=options.get("time_range"),
|
|
topic=options.get("topic"),
|
|
include_domains=options.get("include_domains"),
|
|
exclude_domains=options.get("exclude_domains"),
|
|
language=options.get("language"),
|
|
include_answer=options.get("include_answer"),
|
|
include_raw_content=options.get("include_raw_content"),
|
|
)
|
|
|
|
# Convert to JSON string
|
|
results_json = {
|
|
"query": query,
|
|
"results": [
|
|
{
|
|
"title": result.title,
|
|
"url": result.url,
|
|
"content": getattr(result, 'raw_content', None)
|
|
}
|
|
for result in raw_results
|
|
],
|
|
"total_count": len(raw_results)
|
|
}
|
|
|
|
import json
|
|
content = json.dumps(results_json, indent=2)
|
|
|
|
return AiModelResponse(
|
|
content=content,
|
|
success=True,
|
|
metadata={
|
|
"total_count": len(raw_results),
|
|
"search_depth": options.get("search_depth", "basic")
|
|
}
|
|
)
|
|
|
|
except Exception as e:
|
|
return AiModelResponse(
|
|
content="",
|
|
success=False,
|
|
error=str(e)
|
|
)
|
|
|
|
async def crawl(self, modelCall) -> "AiModelResponse":
|
|
"""Crawl using standardized AiModelCall/AiModelResponse pattern"""
|
|
try:
|
|
# Extract parameters from modelCall
|
|
options = modelCall.options
|
|
urls = options.get("urls", [])
|
|
|
|
# If no URLs provided, try to extract URLs from the prompt
|
|
if not urls and modelCall.messages:
|
|
prompt = modelCall.messages[0]["content"] if modelCall.messages else ""
|
|
urls = self._extractUrlsFromPrompt(prompt)
|
|
|
|
if not urls:
|
|
return AiModelResponse(
|
|
content="No URLs provided for crawling",
|
|
success=False,
|
|
error="No URLs found in options or prompt"
|
|
)
|
|
|
|
raw_results = await self._crawl(
|
|
urls,
|
|
extract_depth=options.get("extract_depth"),
|
|
format=options.get("format"),
|
|
)
|
|
|
|
# Convert to JSON string
|
|
results_json = {
|
|
"urls": urls,
|
|
"results": [
|
|
{
|
|
"url": result.url,
|
|
"content": result.content
|
|
}
|
|
for result in raw_results
|
|
],
|
|
"total_count": len(raw_results)
|
|
}
|
|
|
|
import json
|
|
content = json.dumps(results_json, indent=2)
|
|
|
|
return AiModelResponse(
|
|
content=content,
|
|
success=True,
|
|
metadata={
|
|
"total_count": len(raw_results),
|
|
"extract_depth": options.get("extract_depth", "basic")
|
|
}
|
|
)
|
|
|
|
except Exception as e:
|
|
return AiModelResponse(
|
|
content="",
|
|
success=False,
|
|
error=str(e)
|
|
)
|
|
|
|
async def scrape(self, modelCall) -> "AiModelResponse":
|
|
"""Scrape using standardized AiModelCall/AiModelResponse pattern"""
|
|
try:
|
|
# Extract parameters from modelCall
|
|
query = modelCall.messages[0]["content"] if modelCall.messages else ""
|
|
options = modelCall.options
|
|
|
|
search_results = await self._search(
|
|
query=query,
|
|
max_results=options.get("max_results", 5),
|
|
search_depth=options.get("search_depth"),
|
|
time_range=options.get("time_range"),
|
|
topic=options.get("topic"),
|
|
include_domains=options.get("include_domains"),
|
|
exclude_domains=options.get("exclude_domains"),
|
|
language=options.get("language"),
|
|
include_answer=options.get("include_answer"),
|
|
include_raw_content=options.get("include_raw_content"),
|
|
)
|
|
|
|
urls = [result.url for result in search_results]
|
|
crawl_results = await self._crawl(
|
|
urls,
|
|
extract_depth=options.get("extract_depth"),
|
|
format=options.get("format"),
|
|
)
|
|
|
|
# Convert to JSON string
|
|
results_json = {
|
|
"query": query,
|
|
"results": [
|
|
{
|
|
"url": result.url,
|
|
"content": result.content
|
|
}
|
|
for result in crawl_results
|
|
],
|
|
"total_count": len(crawl_results)
|
|
}
|
|
|
|
import json
|
|
content = json.dumps(results_json, indent=2)
|
|
|
|
return AiModelResponse(
|
|
content=content,
|
|
success=True,
|
|
metadata={
|
|
"total_count": len(crawl_results),
|
|
"search_depth": options.get("search_depth", "basic"),
|
|
"extract_depth": options.get("extract_depth", "basic")
|
|
}
|
|
)
|
|
|
|
except Exception as e:
|
|
return AiModelResponse(
|
|
content="",
|
|
success=False,
|
|
error=str(e)
|
|
)
|
|
|
|
# Helper Functions
|
|
|
|
async def _search_urls_raw(self,
|
|
*,
|
|
query: str,
|
|
max_results: int,
|
|
search_depth: str | None = None,
|
|
time_range: str | None = None,
|
|
topic: str | None = None,
|
|
include_domains: list[str] | None = None,
|
|
exclude_domains: list[str] | None = None,
|
|
language: str | None = None,
|
|
include_answer: bool | None = None,
|
|
include_raw_content: bool | None = None,
|
|
) -> list["WebSearchResult"]:
|
|
return await self._search(
|
|
query=query,
|
|
max_results=max_results,
|
|
search_depth=search_depth,
|
|
time_range=time_range,
|
|
topic=topic,
|
|
include_domains=include_domains,
|
|
exclude_domains=exclude_domains,
|
|
language=language,
|
|
include_answer=include_answer,
|
|
include_raw_content=include_raw_content,
|
|
)
|
|
|
|
async def _crawl_urls_raw(self,
|
|
*,
|
|
urls: list[str],
|
|
extract_depth: str | None = None,
|
|
format: str | None = None,
|
|
) -> list["WebCrawlResult"]:
|
|
return await self._crawl(urls, extract_depth=extract_depth, format=format)
|
|
|
|
async def _scrape_raw(self,
|
|
*,
|
|
query: str,
|
|
max_results: int,
|
|
search_depth: str | None = None,
|
|
time_range: str | None = None,
|
|
topic: str | None = None,
|
|
include_domains: list[str] | None = None,
|
|
exclude_domains: list[str] | None = None,
|
|
language: str | None = None,
|
|
include_answer: bool | None = None,
|
|
include_raw_content: bool | None = None,
|
|
extract_depth: str | None = None,
|
|
format: str | None = None,
|
|
) -> list["WebCrawlResult"]:
|
|
search_results = await self._search(
|
|
query=query,
|
|
max_results=max_results,
|
|
search_depth=search_depth,
|
|
time_range=time_range,
|
|
topic=topic,
|
|
include_domains=include_domains,
|
|
exclude_domains=exclude_domains,
|
|
language=language,
|
|
include_answer=include_answer,
|
|
include_raw_content=include_raw_content,
|
|
)
|
|
urls = [result.url for result in search_results]
|
|
return await self._crawl(urls, extract_depth=extract_depth, format=format)
|
|
|
|
def _clean_url(self, url: str) -> str:
|
|
"""Clean URL by removing extra text that might be appended."""
|
|
import re
|
|
# Extract just the URL part, removing any extra text after it
|
|
url_match = re.match(r'(https?://[^\s,]+)', url)
|
|
if url_match:
|
|
return url_match.group(1)
|
|
return url
|
|
|
|
async def _search(
|
|
self,
|
|
query: str,
|
|
max_results: int,
|
|
search_depth: str | None = None,
|
|
time_range: str | None = None,
|
|
topic: str | None = None,
|
|
include_domains: list[str] | None = None,
|
|
exclude_domains: list[str] | None = None,
|
|
language: str | None = None,
|
|
country: str | None = None,
|
|
include_answer: bool | None = None,
|
|
include_raw_content: bool | None = None,
|
|
) -> list[WebSearchResult]:
|
|
"""Calls the Tavily API to perform a web search."""
|
|
# Make sure max_results is within the allowed range (use cached values)
|
|
minResults = self.webSearchMinResults
|
|
maxAllowedResults = self.webSearchMaxResults
|
|
if max_results < minResults or max_results > maxAllowedResults:
|
|
raise ValueError(f"max_results must be between {minResults} and {maxAllowedResults}")
|
|
|
|
# Perform actual API call
|
|
# Build kwargs only for provided options to avoid API rejections
|
|
kwargs: dict = {"query": query, "max_results": max_results}
|
|
if search_depth is not None:
|
|
kwargs["search_depth"] = search_depth
|
|
if time_range is not None:
|
|
kwargs["time_range"] = time_range
|
|
if topic is not None:
|
|
kwargs["topic"] = topic
|
|
if include_domains is not None and len(include_domains) > 0:
|
|
kwargs["include_domains"] = include_domains
|
|
if exclude_domains is not None:
|
|
kwargs["exclude_domains"] = exclude_domains
|
|
if language is not None:
|
|
kwargs["language"] = language
|
|
if country is not None:
|
|
kwargs["country"] = country
|
|
if include_answer is not None:
|
|
kwargs["include_answer"] = include_answer
|
|
if include_raw_content is not None:
|
|
kwargs["include_raw_content"] = include_raw_content
|
|
|
|
logger.debug(f"Tavily.search kwargs: {kwargs}")
|
|
response = await self.client.search(**kwargs)
|
|
|
|
return [
|
|
WebSearchResult(
|
|
title=result["title"],
|
|
url=self._clean_url(result["url"]),
|
|
raw_content=result.get("raw_content")
|
|
)
|
|
for result in response["results"]
|
|
]
|
|
|
|
async def _crawl(
|
|
self,
|
|
urls: list,
|
|
extract_depth: str | None = None,
|
|
format: str | None = None,
|
|
) -> list[WebCrawlResult]:
|
|
"""Calls the Tavily API to extract text content from URLs with retry logic."""
|
|
maxRetries = self.crawlMaxRetries
|
|
retryDelay = self.crawlRetryDelay
|
|
timeout = self.crawlTimeout
|
|
|
|
logger.debug(f"Starting crawl of {len(urls)} URLs: {urls}")
|
|
logger.debug(f"Crawl settings: extract_depth={extract_depth}, format={format}, timeout={timeout}s")
|
|
|
|
for attempt in range(maxRetries + 1):
|
|
try:
|
|
logger.debug(f"Crawl attempt {attempt + 1}/{maxRetries + 1}")
|
|
|
|
# Use asyncio.wait_for for timeout
|
|
# Build kwargs for extract
|
|
kwargs_extract: dict = {"urls": urls}
|
|
kwargs_extract["extract_depth"] = extract_depth or "advanced"
|
|
kwargs_extract["format"] = format or "markdown" # Use markdown to get HTML structure
|
|
|
|
logger.debug(f"Sending request to Tavily with kwargs: {kwargs_extract}")
|
|
|
|
response = await asyncio.wait_for(
|
|
self.client.extract(**kwargs_extract),
|
|
timeout=timeout
|
|
)
|
|
|
|
logger.debug(f"Tavily response received: {list(response.keys())}")
|
|
|
|
# Debug: Log what Tavily actually returns
|
|
if "results" in response and response["results"]:
|
|
logger.debug(f"Tavily returned {len(response['results'])} results")
|
|
logger.debug(f"First result keys: {list(response['results'][0].keys())}")
|
|
logger.debug(f"First result has raw_content: {'raw_content' in response['results'][0]}")
|
|
|
|
# Log each result
|
|
for i, result in enumerate(response["results"]):
|
|
logger.debug(f"Result {i+1}: URL={result.get('url', 'N/A')}, content_length={len(result.get('raw_content', result.get('content', '')))}")
|
|
else:
|
|
logger.warning(f"Tavily returned no results in response: {response}")
|
|
|
|
results = [
|
|
WebCrawlResult(
|
|
url=result["url"],
|
|
content=result.get("raw_content", result.get("content", "")) # Try raw_content first, fallback to content
|
|
)
|
|
for result in response["results"]
|
|
]
|
|
|
|
logger.debug(f"Crawl successful: extracted {len(results)} results")
|
|
return results
|
|
|
|
except asyncio.TimeoutError:
|
|
logger.warning(f"Crawl attempt {attempt + 1} timed out after {timeout} seconds for URLs: {urls}")
|
|
if attempt < maxRetries:
|
|
logger.info(f"Retrying in {retryDelay} seconds...")
|
|
await asyncio.sleep(retryDelay)
|
|
else:
|
|
raise Exception(f"Crawl failed after {maxRetries + 1} attempts due to timeout")
|
|
|
|
except Exception as e:
|
|
logger.warning(f"Crawl attempt {attempt + 1} failed for URLs {urls}: {str(e)}")
|
|
logger.debug(f"Full error details: {type(e).__name__}: {str(e)}")
|
|
|
|
# Check if it's a validation error and log more details
|
|
if "validation" in str(e).lower():
|
|
logger.debug(f"URL validation failed. Checking URL format:")
|
|
for i, url in enumerate(urls):
|
|
logger.debug(f" URL {i+1}: '{url}' (length: {len(url)})")
|
|
# Check for common URL issues
|
|
if ' ' in url:
|
|
logger.debug(f" WARNING: URL contains spaces!")
|
|
if not url.startswith(('http://', 'https://')):
|
|
logger.debug(f" WARNING: URL doesn't start with http/https!")
|
|
if len(url) > 2000:
|
|
logger.debug(f" WARNING: URL is very long ({len(url)} chars)")
|
|
|
|
if attempt < maxRetries:
|
|
logger.info(f"Retrying in {retryDelay} seconds...")
|
|
await asyncio.sleep(retryDelay)
|
|
else:
|
|
raise Exception(f"Crawl failed after {maxRetries + 1} attempts: {str(e)}")
|