wiki/z-archive/implementation/Chatbot/legacy/sqlitetool.py

134 lines
4.9 KiB
Python

"""Tool that allows the chatbot to interact with a remote SQLite database via API (read-only)."""
import logging
import re
from dataclasses import dataclass
from typing import Optional, List, Dict, Any
from typing import Callable

import httpx
from langchain_core.tools import tool
logger = logging.getLogger(__name__)
@dataclass
class SQLiteTool:
    """Read-only client for a remote SQLite database exposed over an HTTP API.

    Queries are sent as a JSON body ``{"query": ...}`` in a POST request to
    ``base_url``, authenticated with the ``X-DB-API-Key`` header.

    Attributes:
        api_key: API key sent in the ``X-DB-API-Key`` header.
        base_url: URL the query is POSTed to.
    """

    api_key: str
    base_url: str

    # Deliberately unannotated: ``@dataclass`` only turns *annotated* names
    # into fields, so these remain plain class-level constants.
    _MAX_ROWS = 50
    _TIMEOUT_SECONDS = 30.0
    _FORBIDDEN_KEYWORDS = (
        "DROP",
        "CREATE",
        "ALTER",
        "INSERT",
        "UPDATE",
        "DELETE",
        "PRAGMA",
        "ATTACH",
        "DETACH",
    )
    # Whole-word match. A plain substring test would reject harmless
    # identifiers such as ``updated_at`` or ``created_at``; ``\b`` treats
    # letters, digits and ``_`` as word characters, so only a keyword that
    # stands alone as a token is flagged.
    _FORBIDDEN_KEYWORD_RE = re.compile(
        r"\b(?:" + "|".join(_FORBIDDEN_KEYWORDS) + r")\b"
    )

    @classmethod
    async def create(cls, *, api_key: str, base_url: str) -> "SQLiteTool":
        """Build a ``SQLiteTool`` instance.

        Args:
            api_key: API key for authentication.
            base_url: Base URL of the preprocessing query API.

        Returns:
            A new ``SQLiteTool`` holding the given credentials.
        """
        return cls(api_key=api_key, base_url=base_url)

    @classmethod
    def _validation_error(cls, sql_query: str) -> Optional[str]:
        """Return the rejection message for a non-read-only query, else None.

        The check is conservative: a forbidden keyword inside a string
        literal or quoted identifier is still rejected.
        """
        query_upper = sql_query.strip().upper()
        if not query_upper.startswith("SELECT"):
            return "Error: Only SELECT queries are allowed. No INSERT, UPDATE, DELETE, or DDL operations permitted."
        if cls._FORBIDDEN_KEYWORD_RE.search(query_upper):
            return "Error: Query contains forbidden keywords. Only SELECT queries are allowed."
        return None

    @classmethod
    def _format_result(cls, result: Dict[str, Any]) -> str:
        """Render a decoded API response as the text handed back by the tool.

        Reads the ``success``, ``message``, ``data`` and ``row_count`` keys.
        At most ``_MAX_ROWS`` rows are rendered, one ``str(row)`` per line.
        """
        if not result.get("success"):
            error_msg = result.get("message", "Unknown error")
            return f"Query failed: {error_msg}"

        rows = result.get("data") or []
        row_count = result.get("row_count")
        if row_count is None:
            # Without this fallback a response that carries rows but no
            # count would be reported as empty.
            row_count = len(rows)
        if row_count == 0:
            return "Query executed successfully but returned no results."

        shown = [str(row) for row in rows[: cls._MAX_ROWS]]
        return (
            f"Query executed successfully. Returned {row_count} rows (showing first {len(shown)}):\n"
            + "\n".join(shown)
        )

    async def _post_query(self, query: str) -> Any:
        """POST ``query`` to the API and return the decoded JSON body.

        Raises:
            httpx.HTTPStatusError: The API answered with a 4xx/5xx status.
            httpx.RequestError: The request could not be completed.
        """
        async with httpx.AsyncClient(timeout=self._TIMEOUT_SECONDS) as client:
            response = await client.post(
                self.base_url,
                json={"query": query},
                headers={"X-DB-API-Key": self.api_key},
            )
            response.raise_for_status()
            return response.json()

    def get_tool(self) -> Callable[[str], str]:
        """Get the configured LangChain tool.

        Returns:
            The ``execute_sql`` tool, bound to this instance's URL and key.
        """

        # NOTE: the docstring below is exposed by the ``@tool`` decorator as
        # the tool description shown to the model; treat its wording as
        # runtime text, not as a code comment.
        @tool("execute_sql")
        async def execute_sql_query(sql_query: str) -> str:
            """Execute a read-only SELECT query on the remote article database.

            Only SELECT statements are allowed. No PRAGMA, INSERT, UPDATE, DELETE, or DDL operations permitted.
            The database contains one table named "Data" with article information.
            Your query must reference this table explicitly (e.g., SELECT * FROM Data WHERE ...).
            Results are limited to 50 rows.

            Args:
                sql_query: SQLite SELECT query to execute (read-only operations only)

            Returns:
                The result of the query execution or an error message.
            """
            logger.info("Executing SQL query via API: %s", sql_query)
            # Every failure is turned into a message string rather than
            # raised, so the calling agent always receives text.
            try:
                rejection = self._validation_error(sql_query)
                if rejection is not None:
                    return rejection
                result = await self._post_query(sql_query)
                return self._format_result(result)
            except httpx.HTTPStatusError as e:
                return f"API error: HTTP {e.response.status_code} - {e.response.text}"
            except httpx.RequestError as e:
                return f"Network error: {e}"
            except Exception as e:
                return f"Error executing query: {e}"

        return execute_sql_query

    async def execute_query(self, query: str) -> Dict[str, Any]:
        """Execute a raw SQL query via the remote API.

        Unlike the ``execute_sql`` tool, this method performs no read-only
        validation and does no formatting of the response.

        Args:
            query: SQL query string.

        Returns:
            Dictionary with query results from the API.

        Raises:
            Exception: On any failure; the original error is chained as
                ``__cause__``.
        """
        try:
            return await self._post_query(query)
        except Exception as e:
            raise Exception(f"Error executing query: {e}") from e