134 lines
4.9 KiB
Python
134 lines
4.9 KiB
Python
"""Tool that allows the chatbot to interact with a remote SQLite database via API (read-only)."""
|
|
|
|
import logging
import re
from dataclasses import dataclass
from typing import Optional, List, Dict, Any
from typing import Callable

import httpx
from langchain_core.tools import tool
|
|
|
|
|
|
# Module-level logger named after this module; used to record each SQL query sent to the API.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class SQLiteTool:
    """Remote SQLite database tool for searching articles via API.

    Queries are sent as JSON (``{"query": ...}``) in a POST request to
    ``base_url``, authenticated with the ``X-DB-API-Key`` header.

    Attributes:
        api_key: API key sent in the ``X-DB-API-Key`` request header.
        base_url: URL of the query endpoint the SQL is posted to.
    """

    api_key: str
    base_url: str

    # Un-annotated class attributes: shared constants, NOT dataclass fields.
    _MAX_ROWS = 50
    _TIMEOUT_SECONDS = 30.0
    # Whole-word match so identifiers such as ``updated_at`` or ``created_at``
    # are not mistaken for the UPDATE / CREATE statements ("_" counts as a
    # word character, so no boundary exists inside such identifiers).
    _FORBIDDEN_KEYWORDS = re.compile(
        r"\b(?:DROP|CREATE|ALTER|INSERT|UPDATE|DELETE|PRAGMA|ATTACH|DETACH)\b"
    )

    @classmethod
    async def create(cls, *, api_key: str, base_url: str) -> "SQLiteTool":
        """Factory method to create SQLiteTool instance.

        Kept as a coroutine so existing ``await SQLiteTool.create(...)``
        callers keep working, even though no I/O is performed here.

        Args:
            api_key: API key for authentication
            base_url: Base URL of the preprocessing query API

        Returns:
            SQLiteTool instance
        """
        return cls(api_key=api_key, base_url=base_url)

    @classmethod
    def _validate_query(cls, sql_query: str) -> Optional[str]:
        """Return an error message if the query is not a plain SELECT, else None."""
        normalized = sql_query.strip().upper()
        if not normalized.startswith("SELECT"):
            return "Error: Only SELECT queries are allowed. No INSERT, UPDATE, DELETE, or DDL operations permitted."
        # Deliberately conservative: a forbidden word inside a string literal
        # is still rejected, but substrings of longer identifiers are not.
        if cls._FORBIDDEN_KEYWORDS.search(normalized):
            return "Error: Query contains forbidden keywords. Only SELECT queries are allowed."
        return None

    @classmethod
    def _format_result(cls, result: Dict[str, Any]) -> str:
        """Render the decoded API response as the text returned to the chatbot."""
        if not result.get("success"):
            error_msg = result.get("message", "Unknown error")
            return f"Query failed: {error_msg}"

        data = result.get("data") or []
        row_count = result.get("row_count")
        if row_count is None:
            # Fall back to the rows actually received when the API omits the count.
            row_count = len(data)

        if row_count == 0:
            return "Query executed successfully but returned no results."

        # Cap the rendered rows for readability; report how many are really shown.
        shown = [str(row) for row in data[: cls._MAX_ROWS]]
        return (
            f"Query executed successfully. Returned {row_count} rows (showing first {len(shown)}):\n"
            + "\n".join(shown)
        )

    async def _post_query(self, query: str) -> Dict[str, Any]:
        """POST the query to the API and return the decoded JSON body.

        Raises:
            httpx.HTTPStatusError: If the API answers with a 4xx/5xx status.
            httpx.RequestError: If the request itself fails.
        """
        async with httpx.AsyncClient(timeout=self._TIMEOUT_SECONDS) as client:
            response = await client.post(
                self.base_url,
                json={"query": query},
                headers={"X-DB-API-Key": self.api_key},
            )
            response.raise_for_status()
            return response.json()

    def get_tool(self) -> Callable[[str], str]:
        """Get the configured LangChain tool.

        Returns:
            The ``execute_sql`` tool produced by applying the ``@tool``
            decorator to the inner coroutine. The tool never raises: every
            failure is reported as a returned error string.
        """

        @tool("execute_sql")
        async def execute_sql_query(sql_query: str) -> str:
            """Execute a read-only SELECT query on the remote article database.

            Only SELECT statements are allowed. No PRAGMA, INSERT, UPDATE, DELETE, or DDL operations permitted.
            The database contains one table named "Data" with article information.
            Your query must reference this table explicitly (e.g., SELECT * FROM Data WHERE ...).
            Results are limited to 50 rows.

            Args:
                sql_query: SQLite SELECT query to execute (read-only operations only)

            Returns:
                The result of the query execution or an error message.
            """
            logger.info("Executing SQL query via API: %s", sql_query)
            try:
                rejection = self._validate_query(sql_query)
                if rejection is not None:
                    return rejection

                result = await self._post_query(sql_query)
                return self._format_result(result)

            except httpx.HTTPStatusError as e:
                return f"API error: HTTP {e.response.status_code} - {e.response.text}"
            except httpx.RequestError as e:
                return f"Network error: {str(e)}"
            except Exception as e:
                # Boundary of the tool: report the failure as text, but keep
                # the traceback in the logs instead of silently dropping it.
                logger.exception("Unexpected error while executing SQL query")
                return f"Error executing query: {str(e)}"

        return execute_sql_query

    async def execute_query(self, query: str) -> Dict[str, Any]:
        """Execute a raw SQL query via the remote API.

        No read-only validation is applied here; the query is sent as given.

        Args:
            query: SQL query string

        Returns:
            Dictionary with query results from the API

        Raises:
            Exception: If the request or response decoding fails; the
                original error is chained as ``__cause__``.
        """
        try:
            return await self._post_query(query)
        except Exception as e:
            raise Exception(f"Error executing query: {str(e)}") from e
|