diff --git a/modules/connectors/connector_tavily.py b/modules/connectors/connector_tavily.py
index bcb38e3f..783fea8c 100644
--- a/modules/connectors/connector_tavily.py
+++ b/modules/connectors/connector_tavily.py
@@ -4,11 +4,16 @@ import logging
 import os
 from dataclasses import dataclass
 from modules.interfaces.interface_web_model import (
+    WebCrawlBase,
+    WebCrawlDocumentData,
+    WebCrawlRequest,
     WebSearchBase,
     WebSearchRequest,
     WebSearchActionResult,
     WebSearchActionDocument,
     WebSearchDocumentData,
+    WebCrawlActionDocument,
+    WebCrawlActionResult,
 )
 
 # from modules.interfaces.interfaceChatModel import ActionResult, ActionDocument
@@ -20,7 +25,7 @@ logger = logging.getLogger(__name__)
 
 
 @dataclass
-class ConnectorTavily(WebSearchBase):
+class ConnectorTavily(WebSearchBase, WebCrawlBase):
     client: AsyncTavilyClient = None
 
     @classmethod
@@ -28,7 +33,10 @@ class ConnectorTavily(WebSearchBase):
         return cls(client=AsyncTavilyClient(api_key=os.getenv("TAVILY_API_KEY")))
 
     async def search_urls(self, request: WebSearchRequest) -> WebSearchActionResult:
-        """Handles the web search request."""
+        """Handles the web search request.
+
+        Takes a query and returns a list of URLs.
+        """
         # Step 1: Search
         try:
             search_results = await self._search(request.query, request.max_results)
@@ -37,12 +45,28 @@
 
         # Step 2: Build ActionResult
         try:
-            result = self._build_action_result(search_results)
+            result = self._build_search_action_result(search_results)
         except Exception as e:
             return WebSearchActionResult(success=False, error=str(e))
 
         return result
 
+    async def crawl_urls(self, request: WebCrawlRequest) -> WebCrawlActionResult:
+        """Crawls the given URLs and returns the extracted text content."""
+        # Step 1: Crawl
+        try:
+            crawl_results = await self._crawl(request.urls)
+        except Exception as e:
+            return WebCrawlActionResult(success=False, error=str(e))
+
+        # Step 2: Build ActionResult
+        try:
+            result = self._build_crawl_action_result(crawl_results)
+        except Exception as e:
+            return WebCrawlActionResult(success=False, error=str(e))
+
+        return result
+
     async def _search(self, query: str, max_results: int) -> WebSearchActionResult:
         """Calls the Tavily API to perform a web search."""
         # Make sure max_results is within the allowed range
@@ -56,7 +80,9 @@
 
         return response["results"]
 
-    def _build_action_result(self, search_results: list) -> WebSearchActionResult:
+    def _build_search_action_result(
+        self, search_results: list
+    ) -> WebSearchActionResult:
         """Builds the ActionResult from the search results."""
         documents = []
         for result in search_results:
@@ -75,3 +101,30 @@
         return WebSearchActionResult(
             success=True, documents=documents, resultLabel="web_search_results"
         )
+
+    async def _crawl(self, urls: list) -> list[dict]:
+        """Calls the Tavily API to extract text content from URLs."""
+        response = await self.client.extract(
+            urls=urls, extract_depth="advanced", format="text"
+        )
+        return response["results"]
+
+    def _build_crawl_action_result(self, crawl_results: list) -> WebCrawlActionResult:
+        """Builds the ActionResult from the crawl results."""
+        documents = []
+        for result in crawl_results:
+            document_name = f"web_crawl_{get_utc_timestamp()}.txt"
+            doc_data = WebCrawlDocumentData(
+                url=result["url"], content=result["raw_content"]
+            )
+            mime_type = "application/json"
+            doc = WebCrawlActionDocument(
+                documentName=document_name,
+                documentData=doc_data,
+                mimeType=mime_type,
+            )
+            documents.append(doc)
+
+        return WebCrawlActionResult(
+            success=True, documents=documents, resultLabel="web_crawl_results"
+        )
diff --git a/modules/interfaces/interface_web_model.py b/modules/interfaces/interface_web_model.py
index 8dc01fc8..0a258623 100644
--- a/modules/interfaces/interface_web_model.py
+++ b/modules/interfaces/interface_web_model.py
@@ -40,10 +40,31 @@
 
 
 # list of URLs -> list of extracted HTML content
-# TODO
+
+class WebCrawlRequest(BaseModel):
+    urls: List[str]
+
+
+class WebCrawlDocumentData(BaseModel):
+    url: str
+    content: str
+
+
+class WebCrawlActionDocument(ActionDocument):
+    documentData: WebCrawlDocumentData = Field(
+        description="The data extracted from a single crawled URL"
+    )
+
+
+class WebCrawlActionResult(ActionResult):
+    documents: List[WebCrawlActionDocument] = Field(default_factory=list)
+
+
+class WebCrawlBase(ABC):
+    @abstractmethod
+    async def crawl_urls(self, request: WebCrawlRequest) -> WebCrawlActionResult: ...
+
 # --- Web query ---
-# query -> list of extracted text
-
-# TODO
+# query -> list of extracted text; combines web search and crawl in one step
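
A minimal usage sketch of the new crawl path. It assumes TAVILY_API_KEY is set in the environment, uses a placeholder URL, and assumes the factory classmethod (whose def line is elided between hunks above) is named create:

import asyncio

from modules.connectors.connector_tavily import ConnectorTavily
from modules.interfaces.interface_web_model import WebCrawlRequest


async def main() -> None:
    # Assumed factory name; the diff shows only its body, which builds
    # an AsyncTavilyClient from the TAVILY_API_KEY environment variable.
    connector = ConnectorTavily.create()

    # Each crawled URL becomes one WebCrawlActionDocument in the result.
    result = await connector.crawl_urls(
        WebCrawlRequest(urls=["https://example.com"])
    )

    if result.success:
        for doc in result.documents:
            print(doc.documentName, doc.documentData.url)
    else:
        print("crawl failed:", result.error)


asyncio.run(main())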