feat: add web crawl connector; interface
This commit is contained in:
parent
181f55359b
commit
31177063de
2 changed files with 82 additions and 8 deletions
|
|
@ -4,11 +4,16 @@ import logging
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from modules.interfaces.interface_web_model import (
|
from modules.interfaces.interface_web_model import (
|
||||||
|
WebCrawlBase,
|
||||||
|
WebCrawlDocumentData,
|
||||||
|
WebCrawlRequest,
|
||||||
WebSearchBase,
|
WebSearchBase,
|
||||||
WebSearchRequest,
|
WebSearchRequest,
|
||||||
WebSearchActionResult,
|
WebSearchActionResult,
|
||||||
WebSearchActionDocument,
|
WebSearchActionDocument,
|
||||||
WebSearchDocumentData,
|
WebSearchDocumentData,
|
||||||
|
WebCrawlActionDocument,
|
||||||
|
WebCrawlActionResult,
|
||||||
)
|
)
|
||||||
|
|
||||||
# from modules.interfaces.interfaceChatModel import ActionResult, ActionDocument
|
# from modules.interfaces.interfaceChatModel import ActionResult, ActionDocument
|
||||||
|
|
@ -20,7 +25,7 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ConnectorTavily(WebSearchBase):
|
class ConnectorTavily(WebSearchBase, WebCrawlBase):
|
||||||
client: AsyncTavilyClient = None
|
client: AsyncTavilyClient = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
@ -28,7 +33,10 @@ class ConnectorTavily(WebSearchBase):
|
||||||
return cls(client=AsyncTavilyClient(api_key=os.getenv("TAVILY_API_KEY")))
|
return cls(client=AsyncTavilyClient(api_key=os.getenv("TAVILY_API_KEY")))
|
||||||
|
|
||||||
async def search_urls(self, request: WebSearchRequest) -> WebSearchActionResult:
|
async def search_urls(self, request: WebSearchRequest) -> WebSearchActionResult:
|
||||||
"""Handles the web search request."""
|
"""Handles the web search request.
|
||||||
|
|
||||||
|
Takes a query and returns a list of URLs.
|
||||||
|
"""
|
||||||
# Step 1: Search
|
# Step 1: Search
|
||||||
try:
|
try:
|
||||||
search_results = await self._search(request.query, request.max_results)
|
search_results = await self._search(request.query, request.max_results)
|
||||||
|
|
@ -37,12 +45,28 @@ class ConnectorTavily(WebSearchBase):
|
||||||
|
|
||||||
# Step 2: Build ActionResult
|
# Step 2: Build ActionResult
|
||||||
try:
|
try:
|
||||||
result = self._build_action_result(search_results)
|
result = self._build_search_action_result(search_results)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return WebSearchActionResult(success=False, error=str(e))
|
return WebSearchActionResult(success=False, error=str(e))
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
async def crawl_urls(self, request: WebCrawlRequest) -> WebCrawlActionResult:
|
||||||
|
"""Crawls the given URLs and returns the extracted text content."""
|
||||||
|
# Step 1: Crawl
|
||||||
|
try:
|
||||||
|
crawl_results = await self._crawl(request.urls)
|
||||||
|
except Exception as e:
|
||||||
|
return WebCrawlActionResult(success=False, error=str(e))
|
||||||
|
|
||||||
|
# Step 2: Build ActionResult
|
||||||
|
try:
|
||||||
|
result = self._build_crawl_action_result(crawl_results)
|
||||||
|
except Exception as e:
|
||||||
|
return WebCrawlActionResult(success=False, error=str(e))
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
async def _search(self, query: str, max_results: int) -> WebSearchActionResult:
|
async def _search(self, query: str, max_results: int) -> WebSearchActionResult:
|
||||||
"""Calls the Tavily API to perform a web search."""
|
"""Calls the Tavily API to perform a web search."""
|
||||||
# Make sure max_results is within the allowed range
|
# Make sure max_results is within the allowed range
|
||||||
|
|
@ -56,7 +80,9 @@ class ConnectorTavily(WebSearchBase):
|
||||||
|
|
||||||
return response["results"]
|
return response["results"]
|
||||||
|
|
||||||
def _build_action_result(self, search_results: list) -> WebSearchActionResult:
|
def _build_search_action_result(
|
||||||
|
self, search_results: list
|
||||||
|
) -> WebSearchActionResult:
|
||||||
"""Builds the ActionResult from the search results."""
|
"""Builds the ActionResult from the search results."""
|
||||||
documents = []
|
documents = []
|
||||||
for result in search_results:
|
for result in search_results:
|
||||||
|
|
@ -75,3 +101,30 @@ class ConnectorTavily(WebSearchBase):
|
||||||
return WebSearchActionResult(
|
return WebSearchActionResult(
|
||||||
success=True, documents=documents, resultLabel="web_search_results"
|
success=True, documents=documents, resultLabel="web_search_results"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _crawl(self, urls: list) -> list[str]:
|
||||||
|
"""Calls the Tavily API to extract text content from URLs."""
|
||||||
|
response = await self.client.extract(
|
||||||
|
urls=urls, extract_depth="advanced", format="text"
|
||||||
|
)
|
||||||
|
return response["results"]
|
||||||
|
|
||||||
|
def _build_crawl_action_result(self, crawl_results: list) -> WebCrawlActionResult:
|
||||||
|
"""Builds the ActionResult from the crawl results."""
|
||||||
|
documents = []
|
||||||
|
for result in crawl_results:
|
||||||
|
document_name = f"web_crawl_{get_utc_timestamp()}.txt"
|
||||||
|
doc_data = WebCrawlDocumentData(
|
||||||
|
url=result["url"], content=result["raw_content"]
|
||||||
|
)
|
||||||
|
mime_type = "application/json"
|
||||||
|
doc = WebCrawlActionDocument(
|
||||||
|
documentName=document_name,
|
||||||
|
documentData=doc_data,
|
||||||
|
mimeType=mime_type,
|
||||||
|
)
|
||||||
|
documents.append(doc)
|
||||||
|
|
||||||
|
return WebCrawlActionResult(
|
||||||
|
success=True, documents=documents, resultLabel="web_crawl_results"
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -40,10 +40,31 @@ class WebSearchBase(ABC):
|
||||||
|
|
||||||
# list of URLs -> list of extracted HTML content
|
# list of URLs -> list of extracted HTML content
|
||||||
|
|
||||||
# TODO
|
|
||||||
|
class WebCrawlRequest(BaseModel):
|
||||||
|
urls: List[str]
|
||||||
|
|
||||||
|
|
||||||
|
class WebCrawlDocumentData(BaseModel):
|
||||||
|
url: str
|
||||||
|
content: str
|
||||||
|
|
||||||
|
|
||||||
|
class WebCrawlActionDocument(ActionDocument):
|
||||||
|
documentData: WebCrawlDocumentData = Field(
|
||||||
|
description="The data extracted from a single crawled URL"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class WebCrawlActionResult(ActionResult):
|
||||||
|
documents: List[WebCrawlActionDocument] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class WebCrawlBase(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
async def crawl_urls(self, request: WebCrawlRequest) -> WebCrawlActionResult: ...
|
||||||
|
|
||||||
|
|
||||||
# --- Web query ---
|
# --- Web query ---
|
||||||
|
|
||||||
# query -> list of extracted text
|
# query -> list of extracted text; combines web search and crawl in one step
|
||||||
|
|
||||||
# TODO
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue