# Provenance (from the original patch header):
#   Commit:  63dba85b7a30927d194f24c13c57f77f5cd79af2
#   Author:  Christopher Gondek
#   Date:    Fri, 3 Oct 2025 10:29:37 +0200
#   Subject: feat: mock althaus db query tool
#   Path:    modules/features/chatBot/chatbotTools/customerTools/toolQueryAlthausDatabase.py
"""Althaus Database Query Tool for LangGraph.

This tool provides database query capabilities for the Althaus database
via an external REST API. Only SELECT queries are allowed.
"""

import asyncio
import logging
import re
from typing import Annotated, Any

from langchain_core.tools import tool

logger = logging.getLogger(__name__)

# Simulated round-trip time of the mocked REST call, in seconds.
_MOCK_LATENCY_SECONDS = 0.5

# Keywords that must never appear in a query. Order matters: the first keyword
# (in this order) found in the query is the one reported in the error message.
_FORBIDDEN_KEYWORDS = (
    "insert",
    "update",
    "delete",
    "drop",
    "create",
    "alter",
    "truncate",
    "grant",
    "revoke",
    "exec",
    "execute",
    # Statements/clauses that write data although the query may start with SELECT.
    "merge",
    "into",
)

# Whole-word patterns, compiled once instead of on every validation call.
_FORBIDDEN_KEYWORD_PATTERNS = tuple(
    (keyword, re.compile(rf"\b{keyword}\b")) for keyword in _FORBIDDEN_KEYWORDS
)

# The statement must begin with the whole word SELECT ("selectfoo" is not SQL).
_SELECT_PREFIX_PATTERN = re.compile(r"select\b")


async def _mock_api_call(*, sql_query: str) -> dict[str, Any]:
    """Mock the external REST API call to Althaus database.

    The canned response is chosen by case-insensitive substring match on the
    query text, checked in the order "users", "products", "orders"; any other
    query gets a generic two-row response. A fresh dictionary is built on every
    call, so callers may mutate the result freely.

    Args:
        sql_query: The SQL SELECT query to execute

    Returns:
        A dictionary containing the query results with the keys "columns",
        "rows" and "row_count"
    """
    # Simulate network delay
    await asyncio.sleep(_MOCK_LATENCY_SECONDS)

    lowered_query = sql_query.lower()

    # Mock response data based on common query patterns
    if "users" in lowered_query:
        return {
            "columns": ["id", "username", "email", "created_at"],
            "rows": [
                [1, "john_doe", "john@example.com", "2024-01-15"],
                [2, "jane_smith", "jane@example.com", "2024-02-20"],
                [3, "bob_wilson", "bob@example.com", "2024-03-10"],
            ],
            "row_count": 3,
        }
    if "products" in lowered_query:
        return {
            "columns": ["product_id", "name", "price", "stock"],
            "rows": [
                [101, "Widget A", 29.99, 150],
                [102, "Widget B", 39.99, 75],
                [103, "Widget C", 19.99, 200],
            ],
            "row_count": 3,
        }
    if "orders" in lowered_query:
        return {
            "columns": ["order_id", "customer_id", "total", "status"],
            "rows": [
                [5001, 1, 129.99, "completed"],
                [5002, 2, 89.50, "pending"],
                [5003, 1, 199.99, "shipped"],
            ],
            "row_count": 3,
        }

    # Generic response for other queries
    return {
        "columns": ["id", "value", "description"],
        "rows": [
            [1, "Sample 1", "First sample entry"],
            [2, "Sample 2", "Second sample entry"],
        ],
        "row_count": 2,
    }


def _validate_select_query(*, sql_query: str) -> tuple[bool, str]:
    """Validate that the query is a single SELECT statement only.

    The check is deliberately conservative: keywords and semicolons are
    searched in the whole query text, including string literals, so a harmless
    query such as ``SELECT ... WHERE status = 'delete'`` is rejected as well.
    Rejecting too much is preferred over letting a write statement through.

    Args:
        sql_query: The SQL query to validate

    Returns:
        A tuple of (is_valid, error_message); error_message is "" when valid
    """
    # Remove leading/trailing whitespace and convert to lowercase for checking
    normalized_query = sql_query.strip().lower()

    # The first token must be the whole word SELECT
    if not _SELECT_PREFIX_PATTERN.match(normalized_query):
        return False, "Query must be a SELECT statement"

    # Check for dangerous keywords that should not be in a SELECT query
    for keyword, pattern in _FORBIDDEN_KEYWORD_PATTERNS:
        if pattern.search(normalized_query):
            return False, f"Query contains forbidden keyword: {keyword.upper()}"

    # Trailing semicolons are tolerated; any other semicolon separates a
    # second statement, which would bypass the "starts with SELECT" check.
    if ";" in normalized_query.rstrip("; \t\r\n"):
        return False, "Query must contain a single SELECT statement"

    return True, ""


def _format_results(*, columns: list[str], rows: list[list], row_count: int) -> str:
    """Format query results into a readable string.

    Rows that carry more cells than there are column names are still rendered:
    the surplus cells get their own (unnamed) columns instead of raising.

    Args:
        columns: List of column names
        rows: List of row data
        row_count: Total number of rows

    Returns:
        Formatted string representation of the results
    """
    if row_count == 0:
        return "Query executed successfully but returned no results."

    # Calculate column widths, growing the table for rows wider than the header
    col_widths = [len(str(col)) for col in columns]
    for row in rows:
        for index, cell in enumerate(row):
            cell_width = len(str(cell))
            if index < len(col_widths):
                col_widths[index] = max(col_widths[index], cell_width)
            else:
                col_widths.append(cell_width)

    # Build header; unnamed surplus columns get a blank title
    header_cells = list(columns) + [""] * (len(col_widths) - len(columns))
    header = " | ".join(
        str(col).ljust(width) for col, width in zip(header_cells, col_widths)
    )
    separator = "-" * len(header)

    # Build rows
    row_lines = [
        " | ".join(str(cell).ljust(width) for cell, width in zip(row, col_widths))
        for row in rows
    ]

    # Combine all parts
    result_parts = [
        f"Query returned {row_count} row(s):\n",
        header,
        separator,
        "\n".join(row_lines),
    ]

    return "\n".join(result_parts)


# NOTE: LangChain's @tool exposes the docstring below to the model as the tool
# description, so its wording is runtime behavior - edit it deliberately.
@tool
async def query_althaus_database(
    sql_query: Annotated[
        str, "The SQL SELECT query to execute against the Althaus database"
    ],
) -> str:
    """Execute a SELECT query against the Althaus database via REST API.

    Use this tool to query data from the Althaus database. Only SELECT statements
    are allowed for security reasons. The query will be forwarded to an external
    REST API and the results will be returned in a formatted table.

    Args:
        sql_query: The SQL SELECT query to execute (e.g., "SELECT * FROM users WHERE id = 1")

    Returns:
        A formatted string containing the query results with columns and rows
    """
    # Tool boundary: every failure is reported back to the caller as text
    # rather than raised, so the broad except below is intentional.
    try:
        # Validate the query
        is_valid, error_msg = _validate_select_query(sql_query=sql_query)
        if not is_valid:
            logger.warning("Invalid query attempt: %s...", sql_query[:100])
            return f"Error: {error_msg}"

        logger.info("Executing Althaus database query: %s...", sql_query[:100])

        # TODO: replace the mock with the real REST call. The intended request
        # is a POST to https://api.althaus.example.com/query with the JSON body
        # {"query": sql_query} and a Bearer token; use the HTTP client as an
        # async context manager so the connection is closed.
        result = await _mock_api_call(sql_query=sql_query)

        # Format and return results
        formatted_output = _format_results(
            columns=result["columns"],
            rows=result["rows"],
            row_count=result["row_count"],
        )

        logger.info(
            "Query completed successfully, returned %s row(s)", result["row_count"]
        )
        return formatted_output

    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception("Error in query_althaus_database tool: %s", e)
        return f"Error executing query: {e}"