"""Althaus Database Query Tool for LangGraph.

This tool provides database query capabilities for the Althaus database
via an external REST API. Only SELECT queries are allowed.
"""

import logging
import asyncio
import re
from typing import Annotated

from langchain_core.tools import tool

logger = logging.getLogger(__name__)

# Keywords that may not appear (as whole words) in an accepted query.
# The first eleven are the original write/DDL/privilege/exec set, kept in
# the same order so the error message reported for them is unchanged.
# "merge" is another data-modifying statement, "call" is the standard
# counterpart of EXEC/EXECUTE, and "into" blocks ``SELECT ... INTO <table>``
# and ``SELECT ... INTO OUTFILE`` -- forms that write data even though the
# statement starts with SELECT.
_FORBIDDEN_KEYWORDS = (
    "insert",
    "update",
    "delete",
    "drop",
    "create",
    "alter",
    "truncate",
    "grant",
    "revoke",
    "exec",
    "execute",
    "merge",
    "call",
    "into",
)

# Word-boundary patterns compiled once at import instead of on every call.
# \b keeps identifiers such as "created_at" or "updated_by" from matching.
_FORBIDDEN_KEYWORD_PATTERNS = tuple(
    (keyword, re.compile(rf"\b{keyword}\b")) for keyword in _FORBIDDEN_KEYWORDS
)

# "select" must be a whole leading word, so e.g. "selectfoo ..." is rejected.
_SELECT_PREFIX_PATTERN = re.compile(r"select\b")


async def _mock_api_call(*, sql_query: str) -> dict:
    """Mock the external REST API call to the Althaus database.

    The canned response is chosen by a case-insensitive substring match on
    the query ("users", then "products", then "orders", else a generic
    table). A fresh dict is built on every call, so callers may mutate it.

    Args:
        sql_query: The SQL SELECT query to execute.

    Returns:
        A dictionary with "columns" (list of names), "rows" (list of row
        lists) and "row_count" (int).
    """
    # Simulate network delay.
    await asyncio.sleep(0.5)

    lowered_query = sql_query.lower()

    if "users" in lowered_query:
        return {
            "columns": ["id", "username", "email", "created_at"],
            "rows": [
                [1, "john_doe", "john@example.com", "2024-01-15"],
                [2, "jane_smith", "jane@example.com", "2024-02-20"],
                [3, "bob_wilson", "bob@example.com", "2024-03-10"],
            ],
            "row_count": 3,
        }
    if "products" in lowered_query:
        return {
            "columns": ["product_id", "name", "price", "stock"],
            "rows": [
                [101, "Widget A", 29.99, 150],
                [102, "Widget B", 39.99, 75],
                [103, "Widget C", 19.99, 200],
            ],
            "row_count": 3,
        }
    if "orders" in lowered_query:
        return {
            "columns": ["order_id", "customer_id", "total", "status"],
            "rows": [
                [5001, 1, 129.99, "completed"],
                [5002, 2, 89.50, "pending"],
                [5003, 1, 199.99, "shipped"],
            ],
            "row_count": 3,
        }
    # Generic response for any other query.
    return {
        "columns": ["id", "value", "description"],
        "rows": [
            [1, "Sample 1", "First sample entry"],
            [2, "Sample 2", "Second sample entry"],
        ],
        "row_count": 2,
    }


def _validate_select_query(*, sql_query: str) -> tuple[bool, str]:
    """Validate that the query is a single, read-only SELECT statement.

    This is a conservative lexical check, not an SQL parser. Forbidden words
    or semicolons inside string literals also cause a rejection.

    Args:
        sql_query: The SQL query to validate.

    Returns:
        A tuple of (is_valid, error_message); error_message is "" when valid.
    """
    # Checks below are case-insensitive and ignore surrounding whitespace.
    normalized_query = sql_query.strip().lower()

    if not _SELECT_PREFIX_PATTERN.match(normalized_query):
        return False, "Query must be a SELECT statement"

    for keyword, pattern in _FORBIDDEN_KEYWORD_PATTERNS:
        if pattern.search(normalized_query):
            return False, f"Query contains forbidden keyword: {keyword.upper()}"

    # Trailing semicolons are harmless. Any other ";" could chain a second
    # statement (e.g. "SELECT 1; PRAGMA ...") that the keyword list misses.
    # This check runs after the keyword check so queries that were already
    # rejected keep their original error message.
    if ";" in normalized_query.rstrip(";"):
        return False, "Query must contain a single SELECT statement"

    return True, ""


def _format_results(*, columns: list[str], rows: list[list], row_count: int) -> str:
    """Format query results as a pipe-separated, left-aligned text table.

    Args:
        columns: Column names, in display order.
        rows: Row data; each cell is rendered with ``str()``. Cells beyond
            ``len(columns)`` are not shown.
        row_count: Total number of rows reported by the API; 0 produces a
            "no results" message instead of a table.

    Returns:
        Formatted string representation of the results.
    """
    if row_count == 0:
        return "Query executed successfully but returned no results."

    # Each column is as wide as its widest value, header included. Only
    # cells that belong to a declared column are measured. Rendering (zip
    # below) drops extra cells too, so an over-long row no longer raises
    # IndexError here.
    col_widths = [len(str(col)) for col in columns]
    for row in rows:
        for i, cell in enumerate(row[: len(columns)]):
            col_widths[i] = max(col_widths[i], len(str(cell)))

    def _render_line(cells) -> str:
        """Join cells into one padded, pipe-separated line."""
        return " | ".join(
            str(cell).ljust(width) for cell, width in zip(cells, col_widths)
        )

    header = _render_line(columns)
    separator = "-" * len(header)
    row_lines = [_render_line(row) for row in rows]

    return "\n".join(
        [
            f"Query returned {row_count} row(s):\n",
            header,
            separator,
            "\n".join(row_lines),
        ]
    )


@tool
async def query_althaus_database(
    sql_query: Annotated[
        str, "The SQL SELECT query to execute against the Althaus database"
    ],
) -> str:
    """Execute a SELECT query against the Althaus database via REST API.

    Use this tool to query data from the Althaus database. Only SELECT statements
    are allowed for security reasons. The query will be forwarded to an external
    REST API and the results will be returned in a formatted table.

    Args:
        sql_query: The SQL SELECT query to execute (e.g., "SELECT * FROM users WHERE id = 1")

    Returns:
        A formatted string containing the query results with columns and rows
    """
    # NOTE: the docstring above is the tool description that @tool exposes
    # to the LLM, so implementation notes live in comments instead.
    try:
        is_valid, error_msg = _validate_select_query(sql_query=sql_query)
        if not is_valid:
            logger.warning("Invalid query attempt: %s...", sql_query[:100])
            return f"Error: {error_msg}"

        logger.info("Executing Althaus database query: %s...", sql_query[:100])

        # TODO: replace the mock with the real REST call, i.e. POST
        # {"query": sql_query} to the Althaus query endpoint with a bearer
        # token, using an AsyncClient inside ``async with`` so it is closed.
        result = await _mock_api_call(sql_query=sql_query)

        formatted_output = _format_results(
            columns=result["columns"],
            rows=result["rows"],
            row_count=result["row_count"],
        )

        logger.info(
            "Query completed successfully, returned %s row(s)", result["row_count"]
        )
        return formatted_output

    except Exception as e:
        # Tool boundary: report the failure to the agent as text instead of
        # raising. logger.exception keeps the traceback that logger.error
        # dropped.
        logger.exception("Error in query_althaus_database tool: %s", e)
        return f"Error executing query: {str(e)}"