208 lines
6.3 KiB
Python
208 lines
6.3 KiB
Python
"""Althaus Database Query Tool for LangGraph.

This module provides a LangChain tool that runs read-only queries against the
Althaus database via an external REST API. Only SELECT queries are allowed.
The REST API call is currently mocked.
"""
|
|
|
|
import logging
|
|
import asyncio
|
|
import re
|
|
from typing import Annotated
|
|
from langchain_core.tools import tool
|
|
|
|
# Module-level logger for query validation failures, query execution and errors.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
async def _mock_api_call(*, sql_query: str, delay: float = 0.5) -> dict:
    """Mock the external REST API call to the Althaus database.

    The canned response is chosen by a case-insensitive substring match on the
    query text. The checks run in the order ``users``, ``products``, ``orders``,
    and the first match wins. Any other query gets a generic two-row result.

    Args:
        sql_query: The SQL SELECT query to "execute".
        delay: Simulated network latency in seconds (default 0.5). Pass 0 to
            skip the wait, e.g. in tests.

    Returns:
        A freshly built dictionary with ``columns`` (list of column names),
        ``rows`` (list of row lists) and ``row_count`` (number of rows).
    """
    # Simulate network delay.
    await asyncio.sleep(delay)

    lowered_query = sql_query.lower()

    # Ordered like the original if/elif chain: the first table name found in
    # the query wins. The literals are evaluated on every call, so callers may
    # mutate the returned lists without affecting later responses.
    canned_tables = (
        (
            "users",
            ["id", "username", "email", "created_at"],
            [
                [1, "john_doe", "john@example.com", "2024-01-15"],
                [2, "jane_smith", "jane@example.com", "2024-02-20"],
                [3, "bob_wilson", "bob@example.com", "2024-03-10"],
            ],
        ),
        (
            "products",
            ["product_id", "name", "price", "stock"],
            [
                [101, "Widget A", 29.99, 150],
                [102, "Widget B", 39.99, 75],
                [103, "Widget C", 19.99, 200],
            ],
        ),
        (
            "orders",
            ["order_id", "customer_id", "total", "status"],
            [
                [5001, 1, 129.99, "completed"],
                [5002, 2, 89.50, "pending"],
                [5003, 1, 199.99, "shipped"],
            ],
        ),
    )
    for table_name, columns, rows in canned_tables:
        if table_name in lowered_query:
            return {"columns": columns, "rows": rows, "row_count": len(rows)}

    # Generic response for other queries.
    generic_rows = [
        [1, "Sample 1", "First sample entry"],
        [2, "Sample 2", "Second sample entry"],
    ]
    return {
        "columns": ["id", "value", "description"],
        "rows": generic_rows,
        "row_count": len(generic_rows),
    }
|
|
|
|
|
|
def _validate_select_query(*, sql_query: str) -> tuple[bool, str]:
    """Validate that the query is a single, read-only SELECT statement.

    This is a conservative lexical screen, not a SQL parser:

    * the query must begin with the word ``SELECT`` (case-insensitive);
    * it must not contain any forbidden keyword as a whole word anywhere in
      the text. String literals are included, so e.g. ``WHERE a = 'drop'`` is
      rejected;
    * it must be a single statement. Trailing semicolons are allowed.

    Args:
        sql_query: The SQL query to validate.

    Returns:
        A tuple ``(is_valid, error_message)``. ``error_message`` is ``""`` when
        the query is valid.
    """
    # Remove leading/trailing whitespace and lowercase for case-insensitive checks.
    normalized_query = sql_query.strip().lower()

    # The word boundary stops e.g. "selection ..." from passing as a SELECT.
    if not re.match(r"select\b", normalized_query):
        return False, "Query must be a SELECT statement"

    forbidden_keywords = (
        "insert",
        "update",
        "delete",
        "drop",
        "create",
        "alter",
        "truncate",
        "grant",
        "revoke",
        "exec",
        "execute",
        # SELECT ... INTO creates a table (PostgreSQL, SQL Server) or writes a
        # file (MySQL INTO OUTFILE / INTO DUMPFILE).
        "into",
    )
    # Checked in list order, so the reported keyword does not depend on where
    # it appears in the query.
    for keyword in forbidden_keywords:
        # Word boundaries: identifiers such as "created_at" or "update_time"
        # do not match.
        if re.search(rf"\b{keyword}\b", normalized_query):
            return False, f"Query contains forbidden keyword: {keyword.upper()}"

    # Reject stacked statements such as "SELECT 1; COPY t TO PROGRAM '...'".
    # The second statement may use a keyword not listed above. Trailing
    # semicolons and whitespace are ignored.
    if ";" in normalized_query.rstrip("; \t\r\n\f\v"):
        return False, "Query must contain a single SQL statement"

    return True, ""
|
|
|
|
|
|
def _format_results(*, columns: list[str], rows: list[list], row_count: int) -> str:
    """Format query results as a plain-text table.

    Layout, top to bottom:

    * a "Query returned N row(s):" line, followed by a blank line;
    * the header (column names joined by `` | ``);
    * a dashed separator as wide as the header;
    * one line per row.

    Every cell is ``str()``-ed and left-justified to the widest value in its
    column. Cells beyond the number of columns are ignored. A row with fewer
    cells than columns is rendered with only the cells it has.

    Args:
        columns: Column names.
        rows: Row data, one sequence of cell values per row.
        row_count: Row count reported by the API. When 0, a fixed
            "no results" message is returned instead of a table.

    Returns:
        The formatted string.
    """
    if row_count == 0:
        return "Query executed successfully but returned no results."

    # Column widths: the longest of the header and every cell in that column.
    col_widths = [len(str(col)) for col in columns]
    for row in rows:
        # Bound by the column count, like the rendering below. Positional
        # indexing would raise IndexError for a row wider than the header.
        for i, cell in zip(range(len(col_widths)), row):
            col_widths[i] = max(col_widths[i], len(str(cell)))

    def _render_line(cells) -> str:
        # Pad each cell to its column width and join with the column divider.
        return " | ".join(
            str(cell).ljust(width) for cell, width in zip(cells, col_widths)
        )

    header = _render_line(columns)
    separator = "-" * len(header)
    body = "\n".join(_render_line(row) for row in rows)

    return "\n".join(
        [f"Query returned {row_count} row(s):\n", header, separator, body]
    )
|
|
|
|
|
|
@tool
async def query_althaus_database(
    sql_query: Annotated[
        str, "The SQL SELECT query to execute against the Althaus database"
    ],
) -> str:
    """Execute a SELECT query against the Althaus database via REST API.

    Use this tool to query data from the Althaus database. Only SELECT statements
    are allowed for security reasons. The query will be forwarded to an external
    REST API and the results will be returned in a formatted table.

    Args:
        sql_query: The SQL SELECT query to execute (e.g., "SELECT * FROM users WHERE id = 1")

    Returns:
        A formatted string containing the query results with columns and rows
    """
    # NOTE(review): the docstring above is presumably what @tool sends to the
    # model as the tool description, so reword it deliberately.
    #
    # Failures never raise out of the tool. Rejected queries return
    # "Error: <reason>", and any unexpected exception returns
    # "Error executing query: <message>".
    try:
        # Log at most 100 characters of the query; "..." marks real truncation.
        query_preview = sql_query[:100] + ("..." if len(sql_query) > 100 else "")

        is_valid, error_msg = _validate_select_query(sql_query=sql_query)
        if not is_valid:
            logger.warning("Invalid query attempt: %s", query_preview)
            return f"Error: {error_msg}"

        logger.info("Executing Althaus database query: %s", query_preview)

        # TODO: replace the mock with the real REST call, closing the client:
        #     async with httpx.AsyncClient() as client:
        #         response = await client.post(
        #             "https://api.althaus.example.com/query",
        #             json={"query": sql_query},
        #             headers={"Authorization": f"Bearer {api_key}"},
        #         )
        #         response.raise_for_status()
        #         result = response.json()
        result = await _mock_api_call(sql_query=sql_query)

        formatted_output = _format_results(
            columns=result["columns"],
            rows=result["rows"],
            row_count=result["row_count"],
        )

        logger.info(
            "Query completed successfully, returned %s row(s)", result["row_count"]
        )
        return formatted_output

    except Exception as e:
        # Tool boundary: report the failure to the caller as text. The
        # traceback is kept in the log (logger.error dropped it).
        logger.exception("Error in query_althaus_database tool: %s", e)
        return f"Error executing query: {str(e)}"
|