gateway/modules/features/chatBot/chatbotTools/customerTools/toolQueryAlthausDatabase.py
2025-10-03 10:29:37 +02:00

208 lines
6.3 KiB
Python

"""Althaus Database Query Tool for LangGraph.
This tool provides database query capabilities for the Althaus database
via an external REST API. Only SELECT queries are allowed.
"""
import logging
import asyncio
import re
from typing import Annotated
from langchain_core.tools import tool
logger = logging.getLogger(__name__)
async def _mock_api_call(*, sql_query: str, delay_seconds: float = 0.5) -> dict:
    """Mock the external REST API call to Althaus database.

    The canned response is selected by a case-insensitive substring match on
    the query text, checked in this order: "users", "products", "orders".
    Any other query receives a generic two-row sample. A fresh dictionary is
    built on every call, so callers may mutate the result freely.

    Args:
        sql_query: The SQL SELECT query to execute
        delay_seconds: Simulated network latency in seconds. Defaults to 0.5;
            pass 0 to skip the wait (useful in tests).

    Returns:
        A dictionary containing the query results under the keys
        "columns", "rows" and "row_count"
    """
    # Simulate network delay
    await asyncio.sleep(delay_seconds)

    # Lower-case once; every branch below matches against the same text.
    lowered_query = sql_query.lower()

    # Mock response data based on common query patterns
    if "users" in lowered_query:
        return {
            "columns": ["id", "username", "email", "created_at"],
            "rows": [
                [1, "john_doe", "john@example.com", "2024-01-15"],
                [2, "jane_smith", "jane@example.com", "2024-02-20"],
                [3, "bob_wilson", "bob@example.com", "2024-03-10"],
            ],
            "row_count": 3,
        }
    if "products" in lowered_query:
        return {
            "columns": ["product_id", "name", "price", "stock"],
            "rows": [
                [101, "Widget A", 29.99, 150],
                [102, "Widget B", 39.99, 75],
                [103, "Widget C", 19.99, 200],
            ],
            "row_count": 3,
        }
    if "orders" in lowered_query:
        return {
            "columns": ["order_id", "customer_id", "total", "status"],
            "rows": [
                [5001, 1, 129.99, "completed"],
                [5002, 2, 89.50, "pending"],
                [5003, 1, 199.99, "shipped"],
            ],
            "row_count": 3,
        }

    # Generic response for other queries
    return {
        "columns": ["id", "value", "description"],
        "rows": [
            [1, "Sample 1", "First sample entry"],
            [2, "Sample 2", "Second sample entry"],
        ],
        "row_count": 2,
    }
def _validate_select_query(*, sql_query: str) -> tuple[bool, str]:
    """Validate that the query is a single, read-only SELECT statement.

    This is a deliberately conservative text scan, not a SQL parser. Keywords
    and semicolons are matched anywhere in the text, including inside string
    literals, so a harmless query such as ``WHERE status = 'delete'`` is
    rejected. Literals are intentionally NOT stripped before scanning: quoting
    and escaping rules differ between SQL dialects, and guessing them wrong
    could hide a dangerous statement from the check.

    Args:
        sql_query: The SQL query to validate

    Returns:
        A tuple of (is_valid, error_message); error_message is an empty
        string when the query is accepted
    """
    # Remove leading/trailing whitespace and convert to lowercase for checking
    normalized_query = sql_query.strip().lower()

    # Require SELECT as a whole leading word: a plain prefix test would also
    # accept text that merely begins with those letters (e.g. "selection ...").
    if not re.match(r"select\b", normalized_query):
        return False, "Query must be a SELECT statement"

    # Keywords that must not appear in a read-only query. Order matters only
    # for which keyword is named in the error message (first match wins).
    dangerous_keywords = (
        "insert",
        "update",
        "delete",
        "drop",
        "create",
        "alter",
        "truncate",
        "grant",
        "revoke",
        "exec",
        "execute",
        # SELECT ... INTO writes a new table or a file in several dialects.
        "into",
    )
    for keyword in dangerous_keywords:
        # Use word boundary to match whole words only
        if re.search(rf"\b{keyword}\b", normalized_query):
            return False, f"Query contains forbidden keyword: {keyword.upper()}"

    # A semicolon anywhere except as a trailing terminator means a second
    # statement could follow (e.g. "SELECT 1; CALL proc()"), which the keyword
    # list alone cannot catch.
    if ";" in normalized_query.rstrip("; \t\r\n\f\v"):
        return False, "Multiple statements are not allowed"

    return True, ""
def _format_results(*, columns: list[str], rows: list[list], row_count: int) -> str:
    """Render a query result as a plain-text table.

    Every column is padded to the width of its longest entry (header
    included), cells are joined with " | ", and a dashed rule as long as the
    header line separates the header from the data rows.

    Args:
        columns: List of column names
        rows: List of row data; each row is a list of cell values
        row_count: Total number of rows, as reported alongside the data

    Returns:
        The formatted table preceded by a row-count line, or a fixed notice
        when row_count is 0
    """
    if row_count == 0:
        return "Query executed successfully but returned no results."

    # Start from the header widths, then widen for any longer cell.
    widths = [len(str(name)) for name in columns]
    for record in rows:
        for position, value in enumerate(record):
            cell_length = len(str(value))
            if widths[position] < cell_length:
                widths[position] = cell_length

    header_line = " | ".join(
        str(name).ljust(width) for name, width in zip(columns, widths)
    )
    rule = "-" * len(header_line)

    table_body = "\n".join(
        " | ".join(str(value).ljust(width) for value, width in zip(record, widths))
        for record in rows
    )

    # The summary line ends in its own newline, leaving a blank line before
    # the header once the pieces are joined.
    pieces = [f"Query returned {row_count} row(s):\n", header_line, rule, table_body]
    return "\n".join(pieces)
@tool
async def query_althaus_database(
    sql_query: Annotated[
        str, "The SQL SELECT query to execute against the Althaus database"
    ],
) -> str:
    """Execute a SELECT query against the Althaus database via REST API.
    Use this tool to query data from the Althaus database. Only SELECT statements
    are allowed for security reasons. The query will be forwarded to an external
    REST API and the results will be returned in a formatted table.
    Args:
        sql_query: The SQL SELECT query to execute (e.g., "SELECT * FROM users WHERE id = 1")
    Returns:
        A formatted string containing the query results with columns and rows
    """
    # NOTE(review): per the LangChain docs, @tool uses the docstring above as
    # the tool description given to the model, so its wording is effectively
    # runtime text and is left unchanged here.
    #
    # Failures are returned as "Error..." strings rather than raised, so the
    # calling agent always receives a textual tool result.
    try:
        # Validate the query
        is_valid, error_msg = _validate_select_query(sql_query=sql_query)
        if not is_valid:
            # Only the first 100 characters are logged to keep log lines short.
            logger.warning("Invalid query attempt: %s...", sql_query[:100])
            return f"Error: {error_msg}"

        logger.info("Executing Althaus database query: %s...", sql_query[:100])

        # Mock the external REST API call
        # In production, this would be replaced with actual REST API call:
        # response = await httpx.AsyncClient().post(
        #     "https://api.althaus.example.com/query",
        #     json={"query": sql_query},
        #     headers={"Authorization": f"Bearer {api_key}"}
        # )
        # result = response.json()
        result = await _mock_api_call(sql_query=sql_query)

        # Format and return results
        formatted_output = _format_results(
            columns=result["columns"],
            rows=result["rows"],
            row_count=result["row_count"],
        )
        logger.info(
            "Query completed successfully, returned %s row(s)", result["row_count"]
        )
        return formatted_output
    except Exception as e:
        # Tool boundary: catch everything so the agent gets a message instead
        # of a crash, but log with the traceback so the cause is not lost.
        logger.exception("Error in query_althaus_database tool: %s", e)
        return f"Error executing query: {str(e)}"