157 lines
6.3 KiB
Python
157 lines
6.3 KiB
Python
# Copyright (c) 2025 Patrick Motsch
|
|
# All rights reserved.
|
|
|
|
"""
Query Database action for Chatbot operations.
Executes SQL queries via the preprocessor connector.
"""
|
|
|
|
import logging
|
|
import json
|
|
import time
|
|
from typing import Dict, Any
|
|
from modules.workflows.methods.methodBase import action
|
|
from modules.datamodels.datamodelChat import ActionResult, ActionDocument
|
|
from modules.connectors.connectorPreprocessor import PreprocessorConnector
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# queryDatabase: run a read-only SQL query through PreprocessorConnector and
# return the result as a single text/plain ActionDocument.
#
# Inputs read from `parameters`:
#   - "sqlQuery":          the query text. When missing/empty, the query is
#                          looked up in the documents referenced by
#                          "documentList" (first JSON document that carries a
#                          non-empty string under the "sqlQuery" key wins).
#   - "documentList":      a single reference string, a list of reference
#                          strings, or an already-built reference list object.
#   - "parentOperationId": optional, forwarded to progressLogStart.
#
# Returns ActionResult.isSuccess(documents=[...]) on success and
# ActionResult.isFailure(error=...) for every failure; exceptions are not
# propagated to the caller.
#
# NOTE(review): the docstring below is kept exactly as it was because the
# @action decorator may consume it at runtime - confirm before extending it.
@action
async def queryDatabase(self, parameters: Dict[str, Any]) -> ActionResult:
    """
    Execute a SQL query via the preprocessor connector.

    Parameters:
    - sqlQuery (str, required): SQL SELECT query to execute. Can also be extracted from analysis_result document if provided in documentList.
    """
    # Function-scope import, same style as the DocumentReferenceList import below.
    import re

    # Bound before the try so the failure paths can tell whether a progress
    # operation id was ever created.
    operationId = None

    def closeProgressAsFailed() -> None:
        # Best-effort: mark the progress operation as failed. A problem in the
        # progress log must never replace the error that is reported to the caller.
        if operationId is None:
            return
        try:
            self.services.chat.progressLogFinish(operationId, False)
        except Exception as progressError:
            logger.debug(f"Could not finish progress log {operationId}: {progressError}")

    try:
        # Init progress logger
        workflowId = self.services.workflow.id if self.services.workflow else f"no-workflow-{int(time.time())}"
        operationId = f"chatbot_query_db_{workflowId}_{int(time.time())}"

        # Start progress tracking
        parentOperationId = parameters.get('parentOperationId')
        self.services.chat.progressLogStart(
            operationId,
            "Database Query",
            "Executing SQL Query",
            "Preprocessing API",
            parentOperationId=parentOperationId
        )

        # Get SQL query from parameters or extract from documentList
        sqlQuery = parameters.get("sqlQuery")

        # If sqlQuery not provided, try to extract from documentList (analysis_result)
        if not sqlQuery:
            documentListParam = parameters.get("documentList")
            if documentListParam:
                from modules.datamodels.datamodelDocref import DocumentReferenceList

                if isinstance(documentListParam, str):
                    docList = DocumentReferenceList.from_string_list([documentListParam])
                elif isinstance(documentListParam, list):
                    docList = DocumentReferenceList.from_string_list(documentListParam)
                else:
                    # Anything else is passed through unchanged.
                    docList = documentListParam

                documents = self.services.chat.getChatDocumentsFromDocumentList(docList)

                for doc in documents:
                    try:
                        # Only documents with a fileId have stored file data to load.
                        if not getattr(doc, 'fileId', None):
                            continue
                        fileData = self.services.interfaceDbComponent.getFileData(doc.fileId)
                        if not fileData:
                            continue

                        if isinstance(fileData, bytes):
                            docData = fileData.decode('utf-8')
                        else:
                            docData = str(fileData)

                        candidate = json.loads(docData).get("sqlQuery")
                    except (json.JSONDecodeError, UnicodeDecodeError, AttributeError, KeyError, TypeError) as e:
                        # UnicodeDecodeError is included so that a binary /
                        # non-UTF-8 document is skipped instead of aborting
                        # the whole action.
                        logger.debug(f"Could not parse document as JSON: {e}")
                        continue

                    # Accept only a non-empty string; other JSON values under
                    # "sqlQuery" are ignored so they cannot leak into validation.
                    if isinstance(candidate, str) and candidate:
                        sqlQuery = candidate
                        logger.info(f"Extracted SQL query from analysis_result document: {sqlQuery[:100]}...")
                        break

        if not sqlQuery or not isinstance(sqlQuery, str):
            closeProgressAsFailed()
            return ActionResult.isFailure(error="SQL query is required. Provide sqlQuery parameter or analysis_result document with sqlQuery field.")

        # Update progress
        self.services.chat.progressLogUpdate(operationId, 0.3, "Validating query")

        # Validate: only SELECT queries allowed
        sqlNormalized = sqlQuery.strip().upper()
        if not sqlNormalized.startswith("SELECT"):
            closeProgressAsFailed()
            return ActionResult.isFailure(error="Only SELECT queries are allowed")

        # Word-boundary match so a keyword is found regardless of the delimiter
        # around it (newline, tab, ';', parentheses), e.g. "SELECT 1;DROP TABLE x".
        # '_' counts as a word character, so identifiers such as UPDATE_DATE do
        # not match. The check is deliberately conservative: a keyword inside a
        # quoted literal is rejected as well.
        forbiddenKeywords = ["INSERT", "UPDATE", "DELETE", "DROP", "ALTER", "CREATE", "TRUNCATE", "EXEC", "EXECUTE"]
        for kw in forbiddenKeywords:
            if re.search(rf"\b{kw}\b", sqlNormalized):
                closeProgressAsFailed()
                return ActionResult.isFailure(error=f"Forbidden SQL keyword detected: {kw}")

        # Initialize connector; the finally below guarantees it is closed on
        # every path once it exists.
        connector = PreprocessorConnector()
        try:
            # Update progress
            self.services.chat.progressLogUpdate(operationId, 0.5, "Executing query")

            result = await connector.executeQuery(sqlQuery)

            # Update progress
            self.services.chat.progressLogUpdate(operationId, 0.8, "Formatting results")

            # Generate meaningful filename
            meaningful_name = self._generateMeaningfulFileName(
                base_name="database_query",
                extension="txt",
                action_name="queryDatabase"
            )

            # Create validation metadata
            # NOTE(review): len(result) assumes the connector returns a sized
            # value (presumably text, given the text/plain mime type) - confirm.
            validationMetadata = self._createValidationMetadata(
                "queryDatabase",
                sqlQuery=sqlQuery[:200],  # Truncate for metadata
                resultLength=len(result)
            )

            # Create action document
            document = ActionDocument(
                documentName=meaningful_name,
                documentData=result,
                mimeType="text/plain",
                validationMetadata=validationMetadata
            )
        finally:
            # Close connector
            await connector.close()

        # Complete progress tracking only after the connector closed cleanly,
        # so the operation is never reported as both succeeded and failed.
        self.services.chat.progressLogFinish(operationId, True)

        return ActionResult.isSuccess(documents=[document])

    except Exception as e:
        # Top-level boundary of the action: every error becomes a failure result.
        logger.error(f"Error executing database query: {str(e)}")
        closeProgressAsFailed()
        return ActionResult.isFailure(
            error=str(e)
        )
|
|
|