feat: mock althaus db query tool
This commit is contained in:
parent
98b258ae53
commit
63dba85b7a
1 changed files with 208 additions and 0 deletions
|
|
@ -0,0 +1,208 @@
|
|||
"""Althaus Database Query Tool for LangGraph.
|
||||
|
||||
This tool provides database query capabilities for the Althaus database
|
||||
via an external REST API. Only SELECT queries are allowed.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
import re
|
||||
from typing import Annotated
|
||||
from langchain_core.tools import tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _mock_api_call(*, sql_query: str) -> dict:
|
||||
"""Mock the external REST API call to Althaus database.
|
||||
|
||||
Args:
|
||||
sql_query: The SQL SELECT query to execute
|
||||
|
||||
Returns:
|
||||
A dictionary containing the query results with columns and rows
|
||||
"""
|
||||
# Simulate network delay
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Mock response data based on common query patterns
|
||||
if "users" in sql_query.lower():
|
||||
return {
|
||||
"columns": ["id", "username", "email", "created_at"],
|
||||
"rows": [
|
||||
[1, "john_doe", "john@example.com", "2024-01-15"],
|
||||
[2, "jane_smith", "jane@example.com", "2024-02-20"],
|
||||
[3, "bob_wilson", "bob@example.com", "2024-03-10"],
|
||||
],
|
||||
"row_count": 3,
|
||||
}
|
||||
elif "products" in sql_query.lower():
|
||||
return {
|
||||
"columns": ["product_id", "name", "price", "stock"],
|
||||
"rows": [
|
||||
[101, "Widget A", 29.99, 150],
|
||||
[102, "Widget B", 39.99, 75],
|
||||
[103, "Widget C", 19.99, 200],
|
||||
],
|
||||
"row_count": 3,
|
||||
}
|
||||
elif "orders" in sql_query.lower():
|
||||
return {
|
||||
"columns": ["order_id", "customer_id", "total", "status"],
|
||||
"rows": [
|
||||
[5001, 1, 129.99, "completed"],
|
||||
[5002, 2, 89.50, "pending"],
|
||||
[5003, 1, 199.99, "shipped"],
|
||||
],
|
||||
"row_count": 3,
|
||||
}
|
||||
else:
|
||||
# Generic response for other queries
|
||||
return {
|
||||
"columns": ["id", "value", "description"],
|
||||
"rows": [
|
||||
[1, "Sample 1", "First sample entry"],
|
||||
[2, "Sample 2", "Second sample entry"],
|
||||
],
|
||||
"row_count": 2,
|
||||
}
|
||||
|
||||
|
||||
def _validate_select_query(*, sql_query: str) -> tuple[bool, str]:
|
||||
"""Validate that the query is a SELECT statement only.
|
||||
|
||||
Args:
|
||||
sql_query: The SQL query to validate
|
||||
|
||||
Returns:
|
||||
A tuple of (is_valid, error_message)
|
||||
"""
|
||||
# Remove leading/trailing whitespace and convert to lowercase for checking
|
||||
normalized_query = sql_query.strip().lower()
|
||||
|
||||
# Check if query starts with SELECT
|
||||
if not normalized_query.startswith("select"):
|
||||
return False, "Query must be a SELECT statement"
|
||||
|
||||
# Check for dangerous keywords that should not be in a SELECT query
|
||||
dangerous_keywords = [
|
||||
"insert",
|
||||
"update",
|
||||
"delete",
|
||||
"drop",
|
||||
"create",
|
||||
"alter",
|
||||
"truncate",
|
||||
"grant",
|
||||
"revoke",
|
||||
"exec",
|
||||
"execute",
|
||||
]
|
||||
|
||||
for keyword in dangerous_keywords:
|
||||
# Use word boundary to match whole words only
|
||||
if re.search(rf"\b{keyword}\b", normalized_query):
|
||||
return False, f"Query contains forbidden keyword: {keyword.upper()}"
|
||||
|
||||
return True, ""
|
||||
|
||||
|
||||
def _format_results(*, columns: list[str], rows: list[list], row_count: int) -> str:
|
||||
"""Format query results into a readable string.
|
||||
|
||||
Args:
|
||||
columns: List of column names
|
||||
rows: List of row data
|
||||
row_count: Total number of rows
|
||||
|
||||
Returns:
|
||||
Formatted string representation of the results
|
||||
"""
|
||||
if row_count == 0:
|
||||
return "Query executed successfully but returned no results."
|
||||
|
||||
# Calculate column widths
|
||||
col_widths = [len(str(col)) for col in columns]
|
||||
for row in rows:
|
||||
for i, cell in enumerate(row):
|
||||
col_widths[i] = max(col_widths[i], len(str(cell)))
|
||||
|
||||
# Build header
|
||||
header_parts = []
|
||||
for col, width in zip(columns, col_widths):
|
||||
header_parts.append(str(col).ljust(width))
|
||||
header = " | ".join(header_parts)
|
||||
separator = "-" * len(header)
|
||||
|
||||
# Build rows
|
||||
row_lines = []
|
||||
for row in rows:
|
||||
row_parts = []
|
||||
for cell, width in zip(row, col_widths):
|
||||
row_parts.append(str(cell).ljust(width))
|
||||
row_lines.append(" | ".join(row_parts))
|
||||
|
||||
# Combine all parts
|
||||
result_parts = [
|
||||
f"Query returned {row_count} row(s):\n",
|
||||
header,
|
||||
separator,
|
||||
"\n".join(row_lines),
|
||||
]
|
||||
|
||||
return "\n".join(result_parts)
|
||||
|
||||
|
||||
@tool
|
||||
async def query_althaus_database(
|
||||
sql_query: Annotated[
|
||||
str, "The SQL SELECT query to execute against the Althaus database"
|
||||
],
|
||||
) -> str:
|
||||
"""Execute a SELECT query against the Althaus database via REST API.
|
||||
|
||||
Use this tool to query data from the Althaus database. Only SELECT statements
|
||||
are allowed for security reasons. The query will be forwarded to an external
|
||||
REST API and the results will be returned in a formatted table.
|
||||
|
||||
Args:
|
||||
sql_query: The SQL SELECT query to execute (e.g., "SELECT * FROM users WHERE id = 1")
|
||||
|
||||
Returns:
|
||||
A formatted string containing the query results with columns and rows
|
||||
"""
|
||||
try:
|
||||
# Validate the query
|
||||
is_valid, error_msg = _validate_select_query(sql_query=sql_query)
|
||||
if not is_valid:
|
||||
logger.warning(f"Invalid query attempt: {sql_query[:100]}...")
|
||||
return f"Error: {error_msg}"
|
||||
|
||||
logger.info(f"Executing Althaus database query: {sql_query[:100]}...")
|
||||
|
||||
# Mock the external REST API call
|
||||
# In production, this would be replaced with actual REST API call:
|
||||
# response = await httpx.AsyncClient().post(
|
||||
# "https://api.althaus.example.com/query",
|
||||
# json={"query": sql_query},
|
||||
# headers={"Authorization": f"Bearer {api_key}"}
|
||||
# )
|
||||
# result = response.json()
|
||||
|
||||
result = await _mock_api_call(sql_query=sql_query)
|
||||
|
||||
# Format and return results
|
||||
formatted_output = _format_results(
|
||||
columns=result["columns"],
|
||||
rows=result["rows"],
|
||||
row_count=result["row_count"],
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Query completed successfully, returned {result['row_count']} row(s)"
|
||||
)
|
||||
return formatted_output
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in query_althaus_database tool: {str(e)}")
|
||||
return f"Error executing query: {str(e)}"
|
||||
Loading…
Reference in a new issue