base for integration test
This commit is contained in:
parent
fe27f51ebb
commit
ecf23255d2
32 changed files with 1095 additions and 976 deletions
23
app.py
23
app.py
|
|
@ -13,19 +13,20 @@ import pathlib
|
|||
from modules.shared.configuration import APP_CONFIG
|
||||
|
||||
def initLogging():
|
||||
"""Initialize logging with configuration from APP_CONFIG"""
|
||||
# Get log level from config (default to INFO if not found)
|
||||
logLevelName = APP_CONFIG.get("APP_LOGGING_LOG_LEVEL", "WARNING")
|
||||
logLevel = getattr(logging, logLevelName)
|
||||
|
||||
# Create formatters
|
||||
# Create formatters - using single line format
|
||||
consoleFormatter = logging.Formatter(
|
||||
fmt=APP_CONFIG.get("APP_LOGGING_FORMAT", "%(asctime)s - %(levelname)s - %(name)s - %(message)s"),
|
||||
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt=APP_CONFIG.get("APP_LOGGING_DATE_FORMAT", "%Y-%m-%d %H:%M:%S")
|
||||
)
|
||||
|
||||
# File formatter with more detailed error information
|
||||
# File formatter with more detailed error information but still single line
|
||||
fileFormatter = logging.Formatter(
|
||||
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s - %(pathname)s:%(lineno)d\n%(funcName)s\n%(exc_info)s",
|
||||
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s - %(pathname)s:%(lineno)d - %(funcName)s",
|
||||
datefmt=APP_CONFIG.get("APP_LOGGING_DATE_FORMAT", "%Y-%m-%d %H:%M:%S")
|
||||
)
|
||||
|
||||
|
|
@ -66,17 +67,22 @@ def initLogging():
|
|||
# Configure the root logger
|
||||
logging.basicConfig(
|
||||
level=logLevel,
|
||||
format=APP_CONFIG.get("APP_LOGGING_FORMAT", "%(asctime)s - %(levelname)s - %(name)s - %(message)s"),
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt=APP_CONFIG.get("APP_LOGGING_DATE_FORMAT", "%Y-%m-%d %H:%M:%S"),
|
||||
handlers=handlers
|
||||
handlers=handlers,
|
||||
force=True # Force reconfiguration of the root logger
|
||||
)
|
||||
|
||||
|
||||
# Silence noisy third-party libraries - use the same level as the root logger
|
||||
noisyLoggers = ["httpx", "urllib3", "asyncio", "fastapi.security.oauth2"]
|
||||
for loggerName in noisyLoggers:
|
||||
logging.getLogger(loggerName).setLevel(logLevel)
|
||||
|
||||
# Log the current logging configuration
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info(f"Logging initialized with level {logLevelName}")
|
||||
logger.info(f"Log file: {logFile if APP_CONFIG.get('APP_LOGGING_FILE_ENABLED', True) else 'disabled'}")
|
||||
logger.info(f"Console logging: {'enabled' if APP_CONFIG.get('APP_LOGGING_CONSOLE_ENABLED', True) else 'disabled'}")
|
||||
|
||||
# Initialize logging
|
||||
initLogging()
|
||||
|
|
@ -143,6 +149,9 @@ app.include_router(promptRouter)
|
|||
from modules.routes.routeWorkflows import router as workflowRouter
|
||||
app.include_router(workflowRouter)
|
||||
|
||||
from modules.routes.routeSecurityLocal import router as localRouter
|
||||
app.include_router(localRouter)
|
||||
|
||||
from modules.routes.routeSecurityMsft import router as msftRouter
|
||||
app.include_router(msftRouter)
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ APP_API_URL = http://localhost:8000
|
|||
# Database Configuration for Application
|
||||
DB_APP_HOST=D:/Temp/_powerondb
|
||||
DB_APP_DATABASE=app
|
||||
DB_APPY_USER=dev_user
|
||||
DB_APP_USER=dev_user
|
||||
DB_APP_PASSWORD_SECRET=dev_password
|
||||
|
||||
# Database Configuration Chat
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ APP_API_URL = https://gateway.poweron-center.net
|
|||
# Database Configuration Application
|
||||
DB_APP_HOST=/home/_powerondb
|
||||
DB_APP_DATABASE=app
|
||||
DB_APPY_USER=dev_user
|
||||
DB_APP_USER=dev_user
|
||||
DB_APP_PASSWORD_SECRET=dev_password
|
||||
|
||||
# Database Configuration Chat
|
||||
|
|
|
|||
|
|
@ -7,8 +7,10 @@ import logging
|
|||
from typing import Dict, Any, List
|
||||
import json
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
|
||||
from modules.workflow.agentBase import AgentBase
|
||||
from modules.interfaces.serviceChatModel import Task, ChatDocument, ChatContent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -36,21 +38,21 @@ class AgentCoach(AgentBase):
|
|||
"""Set external dependencies for the agent."""
|
||||
self.setService(serviceBase)
|
||||
|
||||
async def processTask(self, task: Dict[str, Any]) -> Dict[str, Any]:
|
||||
async def processTask(self, task: Task) -> Dict[str, Any]:
|
||||
"""
|
||||
Process a task by directly using AI to provide answers or content based on extracted data.
|
||||
|
||||
Args:
|
||||
task: Task dictionary with prompt, inputDocuments, outputSpecifications
|
||||
task: Task object with prompt, inputDocuments, outputSpecifications
|
||||
|
||||
Returns:
|
||||
Dictionary with feedback and documents
|
||||
"""
|
||||
try:
|
||||
# Extract task information
|
||||
prompt = task.get("prompt", "")
|
||||
inputDocuments = task.get("inputDocuments", [])
|
||||
outputSpecs = task.get("outputSpecifications", [])
|
||||
prompt = task.prompt
|
||||
inputDocuments = task.filesInput
|
||||
outputSpecs = task.filesOutput
|
||||
|
||||
# Check AI service
|
||||
if not self.service or not self.service.base:
|
||||
|
|
@ -113,7 +115,7 @@ class AgentCoach(AgentBase):
|
|||
"documents": []
|
||||
}
|
||||
|
||||
def _collectExtractedData(self, documents: List[Dict[str, Any]]) -> str:
|
||||
def _collectExtractedData(self, documents: List[ChatDocument]) -> str:
|
||||
"""
|
||||
Collect extracted data from input documents.
|
||||
|
||||
|
|
@ -126,16 +128,16 @@ class AgentCoach(AgentBase):
|
|||
contextParts = []
|
||||
|
||||
for doc in documents:
|
||||
docName = doc.get("name", "unnamed")
|
||||
if doc.get("ext"):
|
||||
docName = f"{docName}.{doc.get('ext')}"
|
||||
docName = doc.name
|
||||
if doc.ext:
|
||||
docName = f"{docName}.{doc.ext}"
|
||||
|
||||
contextParts.append(f"\n\n--- {docName} ---\n")
|
||||
|
||||
# Process contents, focusing on dataExtracted field
|
||||
for content in doc.get("contents", []):
|
||||
if content.get("dataExtracted"):
|
||||
contextParts.append(content.get("dataExtracted", ""))
|
||||
for content in doc.contents:
|
||||
if content.data:
|
||||
contextParts.append(content.data)
|
||||
|
||||
return "\n".join(contextParts)
|
||||
|
||||
|
|
@ -208,7 +210,7 @@ class AgentCoach(AgentBase):
|
|||
}
|
||||
|
||||
async def _generateDocument(self, prompt: str, context: str, outputLabel: str,
|
||||
outputFormat: str, description: str, taskUnderstanding: Dict) -> Dict[str, Any]:
|
||||
outputFormat: str, description: str, taskUnderstanding: Dict) -> ChatDocument:
|
||||
"""
|
||||
Generate a document based on the request and extracted data.
|
||||
|
||||
|
|
@ -221,7 +223,7 @@ class AgentCoach(AgentBase):
|
|||
taskUnderstanding: Task understanding from analysis
|
||||
|
||||
Returns:
|
||||
Document object
|
||||
ChatDocument object
|
||||
"""
|
||||
# Determine content type based on format
|
||||
contentType = self._getContentType(outputFormat)
|
||||
|
|
@ -248,52 +250,52 @@ class AgentCoach(AgentBase):
|
|||
4. Be comprehensive but focused
|
||||
5. Include appropriate formatting, structure, and organization
|
||||
|
||||
Your response should be in valid {outputFormat} format without explanations or markdown formatting around it.
|
||||
Only return the content. No explanations or additional text.
|
||||
"""
|
||||
|
||||
try:
|
||||
# Build system prompt based on format
|
||||
systemPrompt = f"You create {outputFormat} format content based on requests and extracted data. Provide only the content in valid {outputFormat} format."
|
||||
|
||||
# Generate content with AI
|
||||
# Get content from AI
|
||||
content = await self.service.base.callAi([
|
||||
{"role": "system", "content": systemPrompt},
|
||||
{"role": "system", "content": f"You are a content generation expert. Create content in {outputFormat} format."},
|
||||
{"role": "user", "content": generationPrompt}
|
||||
])
|
||||
|
||||
# Process content based on format
|
||||
if outputFormat in ["json", "csv"]:
|
||||
# For structured formats, extract from code blocks if present
|
||||
# Extract content from code blocks if present
|
||||
content = self._extractFromCodeBlocks(content, outputFormat)
|
||||
|
||||
# Validate JSON if needed
|
||||
if outputFormat == "json":
|
||||
try:
|
||||
json.loads(content)
|
||||
except:
|
||||
logger.warning("Invalid JSON generated, attempting to fix")
|
||||
# Try to extract just the JSON portion
|
||||
jsonStart = content.find('{')
|
||||
jsonEnd = content.rfind('}') + 1
|
||||
if jsonStart >= 0 and jsonEnd > jsonStart:
|
||||
content = content[jsonStart:jsonEnd]
|
||||
|
||||
# Ensure proper structure for markdown/HTML
|
||||
if outputFormat in ["md", "markdown"] and not content.strip().startswith("#"):
|
||||
title = "Response"
|
||||
content = f"# {title}\n\n{content}"
|
||||
elif outputFormat == "html" and not "<html" in content.lower():
|
||||
title = "Response"
|
||||
content = f"<html><head><title>{title}</title></head><body><h1>{title}</h1>{content}</body></html>"
|
||||
|
||||
return self.formatAgentDocumentOutput(outputLabel, content, contentType)
|
||||
# Create document object
|
||||
return ChatDocument(
|
||||
id=str(uuid.uuid4()),
|
||||
name=outputLabel.split('.')[0],
|
||||
ext=outputFormat,
|
||||
data=content,
|
||||
contents=[
|
||||
ChatContent(
|
||||
name="main",
|
||||
data=content,
|
||||
summary=description,
|
||||
metadata={"format": outputFormat}
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating document: {str(e)}")
|
||||
|
||||
# Create error document
|
||||
errorContent = self._createErrorContent(str(e), outputFormat)
|
||||
return self.formatAgentDocumentOutput(outputLabel, errorContent, contentType)
|
||||
return ChatDocument(
|
||||
id=str(uuid.uuid4()),
|
||||
name=outputLabel.split('.')[0],
|
||||
ext=outputFormat,
|
||||
data=errorContent,
|
||||
contents=[
|
||||
ChatContent(
|
||||
name="error",
|
||||
data=errorContent,
|
||||
summary="Error generating content",
|
||||
metadata={"format": outputFormat, "error": str(e)}
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
def _getContentType(self, outputFormat: str) -> str:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ Provides code generation, execution, and improvement capabilities.
|
|||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, List, Tuple
|
||||
from typing import Dict, Any, List, Tuple, Optional
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
|
@ -14,9 +14,11 @@ import shutil
|
|||
import venv
|
||||
import importlib.util
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
|
||||
from modules.workflow.agentBase import AgentBase
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
from modules.interfaces.serviceChatModel import Task, ChatDocument, ChatContent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -46,7 +48,7 @@ class AgentCoder(AgentBase):
|
|||
"""Set external dependencies for the agent."""
|
||||
self.setService(serviceBase)
|
||||
|
||||
async def processTask(self, task: Dict[str, Any]) -> Dict[str, Any]:
|
||||
async def processTask(self, task: Task) -> Dict[str, Any]:
|
||||
"""
|
||||
Process a task and perform code development/execution.
|
||||
First checks if the task can be completed without code execution,
|
||||
|
|
@ -54,15 +56,15 @@ class AgentCoder(AgentBase):
|
|||
Enhanced to ensure all generated documents are included in output.
|
||||
|
||||
Args:
|
||||
task: Task dictionary with prompt, inputDocuments, outputSpecifications
|
||||
task: Task object with prompt, inputDocuments, outputSpecifications
|
||||
|
||||
Returns:
|
||||
Dictionary with feedback and documents
|
||||
"""
|
||||
# 1. Extract task information
|
||||
prompt = task.get("prompt", "")
|
||||
inputDocuments = task.get("inputDocuments", [])
|
||||
outputSpecs = task.get("outputSpecifications", [])
|
||||
prompt = task.prompt
|
||||
inputDocuments = task.filesInput
|
||||
outputSpecs = task.filesOutput
|
||||
|
||||
# Check if AI service is available
|
||||
if not self.service or not self.service.base:
|
||||
|
|
@ -79,39 +81,39 @@ class AgentCoder(AgentBase):
|
|||
|
||||
for doc in inputDocuments:
|
||||
# Create proper filename from name and ext
|
||||
filename = f"{doc.get('name')}.{doc.get('ext')}" if doc.get('ext') else doc.get('name')
|
||||
filename = f"{doc.name}.{doc.ext}" if doc.ext else doc.name
|
||||
|
||||
# Add main document data to documentData if it exists
|
||||
docData = doc.get('data', '')
|
||||
docData = doc.data
|
||||
if docData:
|
||||
isBase64 = True # Assume base64 encoded for document data
|
||||
documentData.append([filename, docData, isBase64])
|
||||
|
||||
# Process contents for different uses
|
||||
if doc.get('contents'):
|
||||
for content in doc.get('contents', []):
|
||||
contentName = content.get('name', 'unnamed')
|
||||
if doc.contents:
|
||||
for content in doc.contents:
|
||||
contentName = content.name
|
||||
|
||||
# For AI-extracted data (quick completion)
|
||||
if content.get('dataExtracted'):
|
||||
if content.data:
|
||||
contentExtraction.append({
|
||||
"filename": filename,
|
||||
"contentName": contentName,
|
||||
"contentData": content.get('dataExtracted', ''),
|
||||
"contentType": content.get('contentType', ''),
|
||||
"summary": content.get('summary', '')
|
||||
"contentData": content.data,
|
||||
"contentType": content.contentType,
|
||||
"summary": content.summary
|
||||
})
|
||||
|
||||
# For raw content data
|
||||
if content.get('data'):
|
||||
rawData = content.get('data', '')
|
||||
isBase64 = content.get('metadata', {}).get('base64Encoded', False)
|
||||
if content.data:
|
||||
rawData = content.data
|
||||
isBase64 = content.metadata.get('base64Encoded', False) if content.metadata else False
|
||||
contentData.append({
|
||||
"filename": filename,
|
||||
"contentName": contentName,
|
||||
"data": rawData,
|
||||
"isBase64": isBase64,
|
||||
"contentType": content.get('contentType', '')
|
||||
"contentType": content.contentType
|
||||
})
|
||||
|
||||
# Also add to documentData for code execution if not already added
|
||||
|
|
@ -239,7 +241,7 @@ class AgentCoder(AgentBase):
|
|||
# Override the base64Encoded flag with the value from the result
|
||||
# This is needed since formatAgentDocumentOutput might determine a different value
|
||||
if isinstance(base64Encoded, bool):
|
||||
doc["base64Encoded"] = base64Encoded
|
||||
doc.base64Encoded = base64Encoded
|
||||
|
||||
documents.append(doc)
|
||||
createdOutputs.add(finalLabel)
|
||||
|
|
@ -248,36 +250,14 @@ class AgentCoder(AgentBase):
|
|||
# Not properly structured - log warning
|
||||
logger.warning(f"Skipping improperly formatted result for '{label}'. Results must include 'content' field.")
|
||||
else:
|
||||
# No result dictionary found
|
||||
logger.warning("No valid result dictionary found or it's not properly formatted")
|
||||
|
||||
# If no valid documents were created from the result dictionary but we have output specifications
|
||||
if len(documents) <= 2 and outputSpecs: # Only code.py and history.json exist
|
||||
logger.warning("No valid documents created from result dictionary, using execution output for specifications")
|
||||
# Default to execution output
|
||||
output = executionResult.get("output", "")
|
||||
for spec in outputSpecs:
|
||||
label = spec.get("label", "output.txt")
|
||||
# Create basic document from output
|
||||
doc = self.formatAgentDocumentOutput(label, output, "text/plain")
|
||||
# Handle non-dictionary results
|
||||
logger.warning("Execution result is not a dictionary. Creating a single output document.")
|
||||
doc = self.formatAgentDocumentOutput("result.txt", str(resultData), "text/plain")
|
||||
documents.append(doc)
|
||||
logger.info(f"Created document from output specification: {label}")
|
||||
|
||||
if retryCount > 0:
|
||||
feedback = f"Code executed successfully after {retryCount + 1} attempts. Generated {len(documents) - 2} output files."
|
||||
else:
|
||||
feedback = f"Code executed successfully. Generated {len(documents) - 2} output files."
|
||||
else:
|
||||
# Execution failed
|
||||
error = executionResult.get("error", "Unknown error")
|
||||
documents.append(self.formatAgentDocumentOutput("execution_error.txt", f"Error executing code:\n\n{error}", "text/plain"))
|
||||
if retryCount > 0:
|
||||
feedback = f"Error during code execution after {retryCount + 1} attempts: {error}"
|
||||
else:
|
||||
feedback = f"Error during code execution: {error}"
|
||||
|
||||
# 8. Return results
|
||||
return {
|
||||
"feedback": feedback,
|
||||
"feedback": "Code execution completed successfully." if executionResult.get("success", False) else f"Code execution failed: {executionResult.get('error', 'Unknown error')}",
|
||||
"documents": documents
|
||||
}
|
||||
|
||||
|
|
@ -393,7 +373,7 @@ Return ONLY Python code without explanations or markdown.
|
|||
return None, []
|
||||
|
||||
|
||||
async def _checkQuickCompletion(self, prompt: str, contentExtraction: List[Dict], outputSpecs: List[Dict]) -> Dict:
|
||||
async def _checkQuickCompletion(self, prompt: str, contentExtraction: List[ChatDocument], outputSpecs: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
"""
|
||||
Check if the task can be completed without writing and executing code.
|
||||
|
||||
|
|
@ -411,7 +391,7 @@ Return ONLY Python code without explanations or markdown.
|
|||
|
||||
# Create a prompt for the AI to check if this can be completed directly
|
||||
specsJson = json.dumps(outputSpecs)
|
||||
dataJson = json.dumps(contentExtraction)
|
||||
dataJson = json.dumps([doc.dict() for doc in contentExtraction])
|
||||
|
||||
checkPrompt = f"""
|
||||
Analyze this task and determine if it can be completed directly without writing code.
|
||||
|
|
@ -478,7 +458,7 @@ Only return valid JSON. Your entire response must be parseable as JSON.
|
|||
# Default to requiring code execution
|
||||
return None
|
||||
|
||||
async def _generateCode(self, prompt: str, outputSpecs: List[Dict[str, Any]] = None) -> Tuple[str, List[str]]:
|
||||
async def _generateCode(self, prompt: str, outputSpecs: List[ChatDocument] = None) -> Tuple[str, List[str]]:
|
||||
"""
|
||||
Generate Python code from a prompt with the inputFiles placeholder.
|
||||
Enhanced to emphasize proper result output handling with correct document structure.
|
||||
|
|
@ -1019,6 +999,38 @@ Return ONLY Python code without explanations or markdown.
|
|||
cleanedCode = '\n'.join(lines[startIndex:endIndex])
|
||||
return cleanedCode.strip()
|
||||
|
||||
def formatAgentDocumentOutput(self, filename: str, content: str, contentType: str) -> ChatDocument:
|
||||
"""
|
||||
Format a document for agent output.
|
||||
|
||||
Args:
|
||||
filename: Output filename
|
||||
content: Document content
|
||||
contentType: MIME type of the content
|
||||
|
||||
Returns:
|
||||
ChatDocument object
|
||||
"""
|
||||
# Split filename into name and extension
|
||||
name, ext = os.path.splitext(filename)
|
||||
if ext.startswith('.'):
|
||||
ext = ext[1:]
|
||||
|
||||
# Create document object
|
||||
return ChatDocument(
|
||||
id=str(uuid.uuid4()),
|
||||
name=name,
|
||||
ext=ext,
|
||||
data=content,
|
||||
contents=[
|
||||
ChatContent(
|
||||
name="main",
|
||||
data=content,
|
||||
summary=f"Generated {filename}",
|
||||
metadata={"contentType": contentType}
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Factory function for the Coder agent
|
||||
def getAgentCoder():
|
||||
|
|
|
|||
|
|
@ -5,8 +5,12 @@ Handles email-related tasks using Microsoft Graph API.
|
|||
|
||||
import logging
|
||||
import json
|
||||
from typing import Dict, Any, List, Optional
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
import uuid
|
||||
import os
|
||||
|
||||
from modules.workflow.agentBase import AgentBase
|
||||
from modules.interfaces.serviceChatModel import Task, ChatDocument, ChatContent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -30,7 +34,7 @@ class AgentEmail(AgentBase):
|
|||
"""Set external dependencies for the agent."""
|
||||
self.serviceBase = serviceBase
|
||||
|
||||
async def processTask(self, task: Dict[str, Any]) -> Dict[str, Any]:
|
||||
async def processTask(self, task: Task) -> Dict[str, Any]:
|
||||
"""
|
||||
Process an email-related task.
|
||||
|
||||
|
|
@ -48,9 +52,9 @@ class AgentEmail(AgentBase):
|
|||
"""
|
||||
try:
|
||||
# Extract task information
|
||||
prompt = task.get("prompt", "")
|
||||
inputDocuments = task.get("inputDocuments", [])
|
||||
outputSpecs = task.get("outputSpecifications", [])
|
||||
prompt = task.prompt
|
||||
inputDocuments = task.filesInput
|
||||
outputSpecs = task.filesOutput
|
||||
|
||||
# Check AI service
|
||||
if not self.service.base:
|
||||
|
|
@ -148,26 +152,39 @@ class AgentEmail(AgentBase):
|
|||
"documents": []
|
||||
}
|
||||
|
||||
def _createFrontendAuthTriggerDocument(self) -> Dict[str, Any]:
|
||||
def _createFrontendAuthTriggerDocument(self) -> ChatDocument:
|
||||
"""Create a document that triggers Microsoft authentication in the frontend."""
|
||||
return {
|
||||
"name": "microsoft_auth",
|
||||
"ext": "html",
|
||||
"mimeType": "text/html",
|
||||
"data": """
|
||||
return ChatDocument(
|
||||
id=str(uuid.uuid4()),
|
||||
name="microsoft_auth",
|
||||
ext="html",
|
||||
data="""
|
||||
<div>
|
||||
<h2>Microsoft Authentication Required</h2>
|
||||
<p>Please click the button below to authenticate with Microsoft:</p>
|
||||
<button onclick="window.location.href='/api/auth/microsoft'">Authenticate with Microsoft</button>
|
||||
</div>
|
||||
""",
|
||||
"base64Encoded": False,
|
||||
"metadata": {
|
||||
contents=[
|
||||
ChatContent(
|
||||
name="main",
|
||||
data="""
|
||||
<div>
|
||||
<h2>Microsoft Authentication Required</h2>
|
||||
<p>Please click the button below to authenticate with Microsoft:</p>
|
||||
<button onclick="window.location.href='/api/auth/microsoft'">Authenticate with Microsoft</button>
|
||||
</div>
|
||||
""",
|
||||
summary="Microsoft authentication trigger page",
|
||||
metadata={
|
||||
"contentType": "text/html",
|
||||
"isText": True
|
||||
}
|
||||
}
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
def _processInputDocuments(self, input_docs: List[Dict[str, Any]]) -> tuple:
|
||||
def _processInputDocuments(self, input_docs: List[ChatDocument]) -> Tuple[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Process input documents to extract content and prepare attachments.
|
||||
|
||||
|
|
@ -181,22 +198,22 @@ class AgentEmail(AgentBase):
|
|||
attachments = []
|
||||
|
||||
for doc in input_docs:
|
||||
docName = doc.get("name", "unnamed")
|
||||
if doc.get("ext"):
|
||||
docName = f"{docName}.{doc.get('ext')}"
|
||||
docName = doc.name
|
||||
if doc.ext:
|
||||
docName = f"{docName}.{doc.ext}"
|
||||
|
||||
# Add document name to contents
|
||||
documentContents.append(f"\n\n--- {docName} ---\n")
|
||||
|
||||
# Process document data directly
|
||||
if doc.get("data"):
|
||||
if doc.data:
|
||||
# Add to attachments with proper metadata
|
||||
attachments.append({
|
||||
"name": docName,
|
||||
"document": {
|
||||
"data": doc["data"],
|
||||
"mimeType": doc.get("mimeType", "application/octet-stream"),
|
||||
"base64Encoded": doc.get("base64Encoded", False)
|
||||
"data": doc.data,
|
||||
"mimeType": doc.contents[0].metadata.get("contentType", "application/octet-stream") if doc.contents else "application/octet-stream",
|
||||
"base64Encoded": doc.contents[0].metadata.get("base64Encoded", False) if doc.contents else False
|
||||
}
|
||||
})
|
||||
documentContents.append(f"Document attached: {docName}")
|
||||
|
|
@ -205,6 +222,39 @@ class AgentEmail(AgentBase):
|
|||
|
||||
return "\n".join(documentContents), attachments
|
||||
|
||||
def formatAgentDocumentOutput(self, filename: str, content: str, contentType: str) -> ChatDocument:
|
||||
"""
|
||||
Format a document for agent output.
|
||||
|
||||
Args:
|
||||
filename: Output filename
|
||||
content: Document content
|
||||
contentType: MIME type of the content
|
||||
|
||||
Returns:
|
||||
ChatDocument object
|
||||
"""
|
||||
# Split filename into name and extension
|
||||
name, ext = os.path.splitext(filename)
|
||||
if ext.startswith('.'):
|
||||
ext = ext[1:]
|
||||
|
||||
# Create document object
|
||||
return ChatDocument(
|
||||
id=str(uuid.uuid4()),
|
||||
name=name,
|
||||
ext=ext,
|
||||
data=content,
|
||||
contents=[
|
||||
ChatContent(
|
||||
name="main",
|
||||
data=content,
|
||||
summary=f"Generated {filename}",
|
||||
metadata={"contentType": contentType}
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
async def _generateEmailTemplate(self, prompt: str, documentContents: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate email template using AI.
|
||||
|
|
|
|||
|
|
@ -409,11 +409,17 @@ class DatabaseConnector:
|
|||
with open(recordPath, 'w', encoding='utf-8') as f:
|
||||
json.dump(recordData, f, indent=2, ensure_ascii=False)
|
||||
|
||||
# Save metadata
|
||||
# Update metadata with new record ID
|
||||
if recordData["id"] not in metadata["recordIds"]:
|
||||
metadata["recordIds"].append(recordData["id"])
|
||||
metadata["recordIds"].sort()
|
||||
|
||||
# Save updated metadata
|
||||
if not self._saveTableMetadata(table, metadata):
|
||||
raise ValueError(f"Error saving metadata for table {table}")
|
||||
|
||||
# Update cache safely
|
||||
# Update both caches
|
||||
self._tableMetadataCache[table] = metadata
|
||||
if table in self._tablesCache:
|
||||
if isinstance(self._tablesCache[table], list):
|
||||
self._tablesCache[table].append(recordData)
|
||||
|
|
@ -526,9 +532,7 @@ class DatabaseConnector:
|
|||
"""Returns the initial ID for a table."""
|
||||
systemData = self._loadSystemTable()
|
||||
initialId = systemData.get(table)
|
||||
logger.debug(f"Database '{self.dbDatabase}': Table: {systemData}, Initial ID for table '{table}' is {initialId}")
|
||||
if initialId is None:
|
||||
logger.debug(f"No initial ID found for table {table}")
|
||||
logger.debug(f"Initial ID for table '{table}': {initialId}")
|
||||
return initialId
|
||||
|
||||
def getAllInitialIds(self) -> Dict[str, str]:
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ Access control for the Application.
|
|||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime
|
||||
from modules.interfaces.serviceAppModel import UserPrivilege, Session
|
||||
from modules.interfaces.serviceAppModel import UserPrivilege, Session, User
|
||||
|
||||
class AppAccess:
|
||||
"""
|
||||
|
|
@ -12,12 +12,12 @@ class AppAccess:
|
|||
Handles user access management and permission checks.
|
||||
"""
|
||||
|
||||
def __init__(self, currentUser: Dict[str, Any], db):
|
||||
def __init__(self, currentUser: User, db):
|
||||
"""Initialize with user context."""
|
||||
self.currentUser = currentUser
|
||||
self.mandateId = currentUser.get("mandateId")
|
||||
self.userId = currentUser.get("id")
|
||||
self.privilege = currentUser.get("privilege", UserPrivilege.USER)
|
||||
self.userId = currentUser.id
|
||||
self.mandateId = currentUser.mandateId
|
||||
self.privilege = currentUser.privilege
|
||||
|
||||
if not self.mandateId or not self.userId:
|
||||
raise ValueError("Invalid user context: mandateId and userId are required")
|
||||
|
|
|
|||
|
|
@ -37,12 +37,12 @@ class GatewayInterface:
|
|||
Manages users and mandates.
|
||||
"""
|
||||
|
||||
def __init__(self, currentUser: Dict[str, Any] = None):
|
||||
def __init__(self, currentUser: Optional[User] = None):
|
||||
"""Initializes the Gateway Interface."""
|
||||
# Initialize variables
|
||||
self.currentUser = currentUser
|
||||
self.userId = currentUser.get("id") if currentUser else None
|
||||
self.mandateId = currentUser.get("mandateId") if currentUser else None
|
||||
self.currentUser = currentUser # Store User object directly
|
||||
self.userId = currentUser.id if currentUser else None
|
||||
self.mandateId = currentUser.mandateId if currentUser else None
|
||||
self.access = None # Will be set when user context is provided
|
||||
|
||||
# Initialize database
|
||||
|
|
@ -55,24 +55,24 @@ class GatewayInterface:
|
|||
if currentUser:
|
||||
self.setUserContext(currentUser)
|
||||
|
||||
def setUserContext(self, currentUser: Dict[str, Any]):
|
||||
def setUserContext(self, currentUser: User):
|
||||
"""Sets the user context for the interface."""
|
||||
if not currentUser:
|
||||
logger.info("Initializing interface without user context")
|
||||
return
|
||||
|
||||
self.currentUser = currentUser
|
||||
self.userId = currentUser.get("id")
|
||||
self.mandateId = currentUser.get("mandateId")
|
||||
self.currentUser = currentUser # Store User object directly
|
||||
self.userId = currentUser.id
|
||||
self.mandateId = currentUser.mandateId
|
||||
|
||||
if not self.userId or not self.mandateId:
|
||||
raise ValueError("Invalid user context: id and mandateId are required")
|
||||
|
||||
# Add language settings
|
||||
self.userLanguage = currentUser.get("language", "en") # Default user language
|
||||
self.userLanguage = currentUser.language # Default user language
|
||||
|
||||
# Initialize access control with user context
|
||||
self.access = AppAccess(self.currentUser, self.db)
|
||||
self.access = AppAccess(self.currentUser, self.db) # Convert to dict only when needed
|
||||
|
||||
logger.debug(f"User context set: userId={self.userId}, mandateId={self.mandateId}")
|
||||
|
||||
|
|
@ -115,7 +115,7 @@ class GatewayInterface:
|
|||
name="Root",
|
||||
language="en"
|
||||
)
|
||||
createdMandate = self.db.recordCreate("mandates", rootMandate.model_dump())
|
||||
createdMandate = self.db.recordCreate("mandates", rootMandate.to_dict())
|
||||
logger.info(f"Root mandate created with ID {createdMandate['id']}")
|
||||
|
||||
# Register the initial ID
|
||||
|
|
@ -142,7 +142,7 @@ class GatewayInterface:
|
|||
hashedPassword=self._getPasswordHash("The 1st Poweron Admin"), # Use a secure password in production!
|
||||
connections=[]
|
||||
)
|
||||
createdUser = self.db.recordCreate("users", adminUser.model_dump())
|
||||
createdUser = self.db.recordCreate("users", adminUser.to_dict())
|
||||
logger.info(f"Admin user created with ID {createdUser['id']}")
|
||||
|
||||
# Register the initial ID
|
||||
|
|
@ -219,10 +219,10 @@ class GatewayInterface:
|
|||
return None
|
||||
|
||||
# Find user by username
|
||||
for user in users:
|
||||
if user.get("username") == username:
|
||||
for user_dict in users:
|
||||
if user_dict.get("username") == username:
|
||||
logger.info(f"Found user with username {username}")
|
||||
return User.from_dict(user)
|
||||
return User.from_dict(user_dict)
|
||||
|
||||
logger.info(f"No user found with username {username}")
|
||||
return None
|
||||
|
|
@ -270,7 +270,7 @@ class GatewayInterface:
|
|||
user.connections.append(connection)
|
||||
|
||||
# Update user record
|
||||
self.db.recordModify("users", userId, {"connections": [c.model_dump() for c in user.connections]})
|
||||
self.db.recordModify("users", userId, {"connections": [c.to_dict() for c in user.connections]})
|
||||
|
||||
return connection
|
||||
|
||||
|
|
@ -290,7 +290,7 @@ class GatewayInterface:
|
|||
user.connections = [c for c in user.connections if c.id != connectionId]
|
||||
|
||||
# Update user record
|
||||
self.db.recordModify("users", userId, {"connections": [c.model_dump() for c in user.connections]})
|
||||
self.db.recordModify("users", userId, {"connections": [c.to_dict() for c in user.connections]})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error removing user connection: {str(e)}")
|
||||
|
|
@ -317,8 +317,11 @@ class GatewayInterface:
|
|||
raise ValueError("User does not have local authentication enabled")
|
||||
|
||||
# Get the full user record with password hash for verification
|
||||
userWithPassword = UserInDB(**self.db.getRecordset("users", recordFilter={"id": user.id})[0])
|
||||
if not self._verifyPassword(password, userWithPassword.hashedPassword):
|
||||
userRecord = self.db.getRecordset("users", recordFilter={"id": user.id})[0]
|
||||
if not userRecord.get("hashedPassword"):
|
||||
raise ValueError("User has no password set")
|
||||
|
||||
if not self._verifyPassword(password, userRecord["hashedPassword"]):
|
||||
raise ValueError("Invalid password")
|
||||
|
||||
return user
|
||||
|
|
@ -331,6 +334,18 @@ class GatewayInterface:
|
|||
externalEmail: str = None) -> User:
|
||||
"""Create a new user with optional external connection"""
|
||||
try:
|
||||
# Ensure username is a string
|
||||
username = str(username).strip()
|
||||
|
||||
# Validate password for local authentication
|
||||
if authenticationAuthority == AuthAuthority.LOCAL:
|
||||
if not password:
|
||||
raise ValueError("Password is required for local authentication")
|
||||
if not isinstance(password, str):
|
||||
raise ValueError("Password must be a string")
|
||||
if not password.strip():
|
||||
raise ValueError("Password cannot be empty")
|
||||
|
||||
# Create user data using UserInDB model
|
||||
userData = UserInDB(
|
||||
username=username,
|
||||
|
|
@ -365,9 +380,11 @@ class GatewayInterface:
|
|||
if not createdUser or len(createdUser) == 0:
|
||||
raise ValueError("Failed to retrieve created user")
|
||||
|
||||
# Clear users table from cache
|
||||
# Clear both table and metadata caches
|
||||
if hasattr(self.db, '_tablesCache') and "users" in self.db._tablesCache:
|
||||
del self.db._tablesCache["users"]
|
||||
if hasattr(self.db, '_tableMetadataCache') and "users" in self.db._tableMetadataCache:
|
||||
del self.db._tableMetadataCache["users"]
|
||||
|
||||
return User.from_dict(createdUser[0])
|
||||
|
||||
|
|
@ -387,7 +404,7 @@ class GatewayInterface:
|
|||
raise ValueError(f"User {userId} not found")
|
||||
|
||||
# Update user data using model
|
||||
updatedData = user.model_dump()
|
||||
updatedData = user.to_dict()
|
||||
updatedData.update(updateData)
|
||||
updatedUser = User.from_dict(updatedData)
|
||||
|
||||
|
|
@ -488,7 +505,7 @@ class GatewayInterface:
|
|||
raise ValueError(f"Mandate {mandateId} not found")
|
||||
|
||||
# Update mandate data using model
|
||||
updatedData = mandate.model_dump()
|
||||
updatedData = mandate.to_dict()
|
||||
updatedData.update(updateData)
|
||||
updatedMandate = Mandate.from_dict(updatedData)
|
||||
|
||||
|
|
@ -529,20 +546,65 @@ class GatewayInterface:
|
|||
logger.error(f"Error deleting mandate: {str(e)}")
|
||||
raise ValueError(f"Failed to delete mandate: {str(e)}")
|
||||
|
||||
def _getInitialUser(self) -> Optional[Dict[str, Any]]:
|
||||
"""Get the initial user record directly from database without access control."""
|
||||
try:
|
||||
initialUserId = self.db.getInitialId("users")
|
||||
if not initialUserId:
|
||||
return None
|
||||
|
||||
users = self.db.getRecordset("users", recordFilter={"id": initialUserId})
|
||||
return users[0] if users else None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting initial user: {str(e)}")
|
||||
return None
|
||||
|
||||
def checkUsernameAvailability(self, checkData: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Checks if a username is available for registration."""
|
||||
try:
|
||||
username = checkData.get("username")
|
||||
authenticationAuthority = checkData.get("authenticationAuthority", "local")
|
||||
|
||||
if not username:
|
||||
return {
|
||||
"available": False,
|
||||
"message": "Username is required"
|
||||
}
|
||||
|
||||
# Get user by username
|
||||
user = self.getUserByUsername(username)
|
||||
|
||||
# Check if user exists (User model instance)
|
||||
if user is not None:
|
||||
return {
|
||||
"available": False,
|
||||
"message": "Username is already taken"
|
||||
}
|
||||
|
||||
return {
|
||||
"available": True,
|
||||
"message": "Username is available"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking username availability: {str(e)}")
|
||||
return {
|
||||
"available": False,
|
||||
"message": f"Error checking username availability: {str(e)}"
|
||||
}
|
||||
|
||||
# Public Methods
|
||||
|
||||
def getInterface(currentUser: Dict[str, Any]) -> GatewayInterface:
|
||||
def getInterface(currentUser: User) -> GatewayInterface:
|
||||
"""
|
||||
Returns a GatewayInterface instance for the current user.
|
||||
Handles initialization of database and records.
|
||||
"""
|
||||
mandateId = currentUser.get("mandateId")
|
||||
userId = currentUser.get("id")
|
||||
if not mandateId or not userId:
|
||||
raise ValueError("Invalid user context: mandateId and id are required")
|
||||
if not currentUser:
|
||||
raise ValueError("Invalid user context: user is required")
|
||||
|
||||
# Create context key
|
||||
contextKey = f"{mandateId}_{userId}"
|
||||
contextKey = f"{currentUser.mandateId}_{currentUser.id}"
|
||||
|
||||
# Create new instance if not exists
|
||||
if contextKey not in _gatewayInterfaces:
|
||||
|
|
@ -550,24 +612,27 @@ def getInterface(currentUser: Dict[str, Any]) -> GatewayInterface:
|
|||
|
||||
return _gatewayInterfaces[contextKey]
|
||||
|
||||
def getRootUser() -> Dict[str, Any]:
|
||||
def getRootUser() -> User:
|
||||
"""
|
||||
Returns the root user from the database.
|
||||
This is the user with the initial ID in the users table.
|
||||
"""
|
||||
try:
|
||||
readInterface = getInterface()
|
||||
# Get the initial user ID
|
||||
initialUserId = readInterface.db.getInitialId("users")
|
||||
# Create a temporary interface without user context
|
||||
tempInterface = GatewayInterface()
|
||||
|
||||
# Get the initial user directly
|
||||
initialUserId = tempInterface.db.getInitialId("users")
|
||||
if not initialUserId:
|
||||
raise ValueError("No initial user ID found in database")
|
||||
|
||||
# Get the user record
|
||||
users = readInterface.db.getRecordset("users", recordFilter={"id": initialUserId})
|
||||
users = tempInterface.db.getRecordset("users", recordFilter={"id": initialUserId})
|
||||
if not users:
|
||||
raise ValueError(f"Root user with ID {initialUserId} not found in database")
|
||||
raise ValueError("Initial user not found in database")
|
||||
|
||||
# Convert to User model and return the model instance
|
||||
return User.from_dict(users[0])
|
||||
|
||||
return users[0]
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting root user: {str(e)}")
|
||||
raise ValueError(f"Failed to get root user: {str(e)}")
|
||||
|
|
|
|||
|
|
@ -10,6 +10,17 @@ from enum import Enum
|
|||
|
||||
from modules.shared.attributeUtils import Label, BaseModelWithUI
|
||||
|
||||
class AttributeDefinition(BaseModel):
|
||||
"""Definition of an attribute for UI forms"""
|
||||
name: str = Field(..., description="Name of the attribute")
|
||||
label: str = Field(..., description="Display label for the attribute")
|
||||
type: str = Field(..., description="Type of the attribute (string, number, boolean, etc.)")
|
||||
required: bool = Field(default=False, description="Whether the attribute is required")
|
||||
placeholder: Optional[str] = Field(None, description="Placeholder text for the input")
|
||||
editable: bool = Field(default=True, description="Whether the attribute can be edited")
|
||||
visible: bool = Field(default=True, description="Whether the attribute should be visible in forms")
|
||||
order: int = Field(default=0, description="Order in which to display the attribute")
|
||||
|
||||
class AuthAuthority(str, Enum):
|
||||
"""Authentication authorities"""
|
||||
LOCAL = "local"
|
||||
|
|
@ -151,7 +162,7 @@ class User(BaseModelWithUI):
|
|||
disabled: bool = Field(default=False, description="Indicates whether the user is disabled")
|
||||
privilege: UserPrivilege = Field(default=UserPrivilege.USER, description="Permission level")
|
||||
authenticationAuthority: AuthAuthority = Field(default=AuthAuthority.LOCAL, description="Primary authentication authority")
|
||||
mandateId: str = Field(description="ID of the mandate this user belongs to")
|
||||
mandateId: Optional[str] = Field(None, description="ID of the mandate this user belongs to")
|
||||
connections: List[UserConnection] = Field(default_factory=list, description="List of external service connections")
|
||||
|
||||
label: Label = Field(
|
||||
|
|
|
|||
|
|
@ -25,28 +25,3 @@ class LocalToken(BaseModel):
|
|||
access_token: str
|
||||
token_type: str = "bearer"
|
||||
expires_at: float
|
||||
|
||||
# Token management functions
|
||||
def saveToken(interface, tokenType: str, tokenData: dict) -> bool:
|
||||
"""Save token data for a specific service"""
|
||||
try:
|
||||
return interface.saveToken(f"tokens{tokenType}", tokenData)
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving {tokenType} token: {str(e)}")
|
||||
return False
|
||||
|
||||
def getToken(interface, tokenType: str) -> Optional[dict]:
|
||||
"""Get token data for a specific service"""
|
||||
try:
|
||||
return interface.getToken(f"tokens{tokenType}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting {tokenType} token: {str(e)}")
|
||||
return None
|
||||
|
||||
def deleteToken(interface, tokenType: str) -> bool:
|
||||
"""Delete token data for a specific service"""
|
||||
try:
|
||||
return interface.deleteToken(f"tokens{tokenType}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting {tokenType} token: {str(e)}")
|
||||
return False
|
||||
|
|
@ -16,8 +16,9 @@ from modules.interfaces.serviceChatAccess import ChatAccess
|
|||
from modules.interfaces.serviceChatModel import (
|
||||
ChatContent, ChatDocument, ChatStat, ChatMessage,
|
||||
ChatLog, ChatWorkflow, Agent, AgentResponse,
|
||||
TaskItem, TaskPlan, UserInputRequest
|
||||
Task, TaskPlan, UserInputRequest
|
||||
)
|
||||
from modules.interfaces.serviceAppModel import User
|
||||
|
||||
# DYNAMIC PART: Connectors to the Interface
|
||||
from modules.connectors.connectorDbJson import DatabaseConnector
|
||||
|
|
@ -57,43 +58,41 @@ class ChatInterface:
|
|||
Uses the JSON connector for data access with added language support.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, currentUser: Optional[User] = None):
|
||||
"""Initializes the Chat Interface."""
|
||||
# Initialize variables
|
||||
self.currentUser = currentUser # Store User object directly
|
||||
self.userId = currentUser.id if currentUser else None
|
||||
self.mandateId = currentUser.mandateId if currentUser else None
|
||||
self.access = None # Will be set when user context is provided
|
||||
|
||||
# Initialize database
|
||||
self._initializeDatabase()
|
||||
|
||||
# Initialize standard records if needed
|
||||
self._initRecords()
|
||||
# Set user context if provided
|
||||
if currentUser:
|
||||
self.setUserContext(currentUser)
|
||||
|
||||
# Initialize variables
|
||||
self.currentUser = None
|
||||
self.userId = None
|
||||
self.mandateId = None
|
||||
self.access = None # Will be set when user context is provided
|
||||
self.aiService = None # Will be set when user context is provided
|
||||
|
||||
def setUserContext(self, currentUser: Dict[str, Any]):
|
||||
def setUserContext(self, currentUser: User):
|
||||
"""Sets the user context for the interface."""
|
||||
if not currentUser:
|
||||
logger.info("Initializing interface without user context")
|
||||
return
|
||||
|
||||
self.currentUser = currentUser
|
||||
self.userId = currentUser.get("id")
|
||||
self.mandateId = currentUser.get("mandateId")
|
||||
if not self.userId:
|
||||
raise ValueError("Invalid user context: id is required")
|
||||
self.currentUser = currentUser # Store User object directly
|
||||
self.userId = currentUser.id
|
||||
self.mandateId = currentUser.mandateId
|
||||
|
||||
if not self.userId or not self.mandateId:
|
||||
raise ValueError("Invalid user context: id and mandateId are required")
|
||||
|
||||
# Add language settings
|
||||
self.userLanguage = currentUser.get("language", "en") # Default user language
|
||||
self.userLanguage = currentUser.language # Default user language
|
||||
|
||||
# Initialize access control with user context
|
||||
self.access = ChatAccess(self.currentUser, self.db)
|
||||
self.access = ChatAccess(self.currentUser, self.db) # Convert to dict only when needed
|
||||
|
||||
# Initialize AI service
|
||||
self.aiService = ChatService()
|
||||
|
||||
logger.debug(f"User context set: userId={self.userId}")
|
||||
logger.debug(f"User context set: userId={self.userId}, mandateId={self.mandateId}")
|
||||
|
||||
def _initializeDatabase(self):
|
||||
"""Initializes the database connection."""
|
||||
|
|
@ -353,7 +352,7 @@ class ChatInterface:
|
|||
|
||||
# Workflow Messages
|
||||
|
||||
def getWorkflowMessages(self, workflowId: str) -> List[Dict[str, Any]]:
|
||||
def getWorkflowMessages(self, workflowId: str) -> List[ChatMessage]:
|
||||
"""Returns messages for a workflow if user has access to the workflow."""
|
||||
# Check workflow access first
|
||||
workflow = self.getWorkflow(workflowId)
|
||||
|
|
@ -362,7 +361,7 @@ class ChatInterface:
|
|||
|
||||
# Get messages for this workflow
|
||||
messages = self.db.getRecordset("workflowMessages", recordFilter={"workflowId": workflowId})
|
||||
return messages # No further filtering needed since workflow access is already checked
|
||||
return [ChatMessage(**msg) for msg in messages]
|
||||
|
||||
def createWorkflowMessage(self, messageData: Dict[str, Any]) -> ChatMessage:
|
||||
"""Creates a message for a workflow if user has access."""
|
||||
|
|
@ -639,7 +638,7 @@ class ChatInterface:
|
|||
|
||||
# Workflow Logs
|
||||
|
||||
def getWorkflowLogs(self, workflowId: str) -> List[Dict[str, Any]]:
|
||||
def getWorkflowLogs(self, workflowId: str) -> List[ChatLog]:
|
||||
"""Returns logs for a workflow if user has access to the workflow."""
|
||||
# Check workflow access first
|
||||
workflow = self.getWorkflow(workflowId)
|
||||
|
|
@ -647,7 +646,7 @@ class ChatInterface:
|
|||
return []
|
||||
|
||||
# Get logs for this workflow
|
||||
return self.db.getRecordset("workflowLogs", recordFilter={"workflowId": workflowId})
|
||||
return [ChatLog(**log) for log in self.db.getRecordset("workflowLogs", recordFilter={"workflowId": workflowId})]
|
||||
|
||||
def createWorkflowLog(self, logData: Dict[str, Any]) -> ChatLog:
|
||||
"""Creates a log entry for a workflow if user has access."""
|
||||
|
|
@ -695,17 +694,17 @@ class ChatInterface:
|
|||
return None
|
||||
|
||||
# Create log in database
|
||||
createdLog = self.db.recordCreate("workflowLogs", log_model.model_dump())
|
||||
createdLog = self.db.recordCreate("workflowLogs", log_model.to_dict())
|
||||
|
||||
# Return validated ChatLog instance
|
||||
return ChatLog(**createdLog)
|
||||
|
||||
# Workflow Management
|
||||
|
||||
def saveWorkflowState(self, workflow: Dict[str, Any], saveMessages: bool = True, saveLogs: bool = True) -> bool:
|
||||
def saveWorkflowState(self, workflow: ChatWorkflow, saveMessages: bool = True, saveLogs: bool = True) -> bool:
|
||||
"""Saves workflow state if user has access."""
|
||||
try:
|
||||
workflowId = workflow.get("id")
|
||||
workflowId = workflow.id
|
||||
if not workflowId:
|
||||
return False
|
||||
|
||||
|
|
@ -722,12 +721,12 @@ class ChatInterface:
|
|||
# Extract only the database-relevant workflow fields
|
||||
workflowDbData = {
|
||||
"id": workflowId,
|
||||
"mandateId": workflow.get("mandateId", self.currentUser.get("mandateId")),
|
||||
"name": workflow.get("name", f"Workflow {workflowId}"),
|
||||
"status": workflow.get("status", "completed"),
|
||||
"startedAt": workflow.get("startedAt", self._getCurrentTimestamp()),
|
||||
"lastActivity": workflow.get("lastActivity", self._getCurrentTimestamp()),
|
||||
"dataStats": workflow.get("dataStats", {})
|
||||
"mandateId": workflow.mandateId,
|
||||
"name": workflow.name,
|
||||
"status": workflow.status,
|
||||
"startedAt": workflow.startedAt,
|
||||
"lastActivity": workflow.lastActivity,
|
||||
"dataStats": workflow.stats.dict() if workflow.stats else {}
|
||||
}
|
||||
|
||||
# Check if workflow already exists
|
||||
|
|
@ -802,7 +801,7 @@ class ChatInterface:
|
|||
logger.error(f"Error saving workflow state: {str(e)}")
|
||||
return False
|
||||
|
||||
def loadWorkflowState(self, workflowId: str) -> Optional[Dict[str, Any]]:
|
||||
def loadWorkflowState(self, workflowId: str) -> Optional[ChatWorkflow]:
|
||||
"""Loads workflow state if user has access."""
|
||||
try:
|
||||
# Check workflow access
|
||||
|
|
@ -852,21 +851,19 @@ class ChatInterface:
|
|||
return None
|
||||
|
||||
|
||||
def getInterface(currentUser: Dict[str, Any] = None) -> 'ChatInterface':
|
||||
def getInterface(currentUser: Optional[User] = None) -> 'ChatInterface':
|
||||
"""
|
||||
Returns a ChatInterface instance.
|
||||
If currentUser is provided, initializes with user context.
|
||||
Otherwise, returns an instance with only database access.
|
||||
Returns a ChatInterface instance for the current user.
|
||||
Handles initialization of database and records.
|
||||
"""
|
||||
if not currentUser:
|
||||
raise ValueError("Invalid user context: user is required")
|
||||
|
||||
# Create context key
|
||||
contextKey = f"{currentUser.mandateId}_{currentUser.id}"
|
||||
|
||||
# Create new instance if not exists
|
||||
if "default" not in _chatInterfaces:
|
||||
_chatInterfaces["default"] = ChatInterface()
|
||||
if contextKey not in _chatInterfaces:
|
||||
_chatInterfaces[contextKey] = ChatInterface(currentUser)
|
||||
|
||||
interface = _chatInterfaces["default"]
|
||||
|
||||
if currentUser:
|
||||
interface.setUserContext(currentUser)
|
||||
else:
|
||||
logger.info("Returning interface without user context")
|
||||
|
||||
return interface
|
||||
return _chatInterfaces[contextKey]
|
||||
|
|
@ -67,6 +67,23 @@ class ChatMessage(BaseModelWithUI):
|
|||
stats: Optional[ChatStat] = Field(None, description="Statistics for this message")
|
||||
success: Optional[bool] = Field(None, description="Whether the message processing was successful")
|
||||
|
||||
class Task(BaseModelWithUI):
|
||||
"""Data model for a task"""
|
||||
id: str = Field(description="Primary key")
|
||||
workflowId: str = Field(description="Foreign key to workflow")
|
||||
agentName: str = Field(description="Name of the agent assigned to this task")
|
||||
status: str = Field(description="Current status of the task")
|
||||
progress: float = Field(description="Task progress (0-100)")
|
||||
prompt: str = Field(description="Prompt for the task")
|
||||
userLanguage: str = Field(description="User's preferred language")
|
||||
filesInput: List[str] = Field(default_factory=list, description="Input files")
|
||||
filesOutput: List[str] = Field(default_factory=list, description="Output files")
|
||||
result: Optional[ChatMessage] = Field(None, description="Task result message")
|
||||
error: Optional[str] = Field(None, description="Error message if failed")
|
||||
startedAt: str = Field(description="When the task started")
|
||||
finishedAt: Optional[str] = Field(None, description="When the task finished")
|
||||
performance: Optional[Dict[str, Any]] = Field(None, description="Performance metrics")
|
||||
|
||||
class ChatWorkflow(BaseModelWithUI):
|
||||
"""Data model for a chat workflow"""
|
||||
id: str = Field(description="Primary key")
|
||||
|
|
@ -79,7 +96,7 @@ class ChatWorkflow(BaseModelWithUI):
|
|||
logs: List[ChatLog] = Field(default_factory=list, description="Workflow logs")
|
||||
messages: List[ChatMessage] = Field(default_factory=list, description="Messages in the workflow")
|
||||
stats: Optional[ChatStat] = Field(None, description="Workflow statistics")
|
||||
tasks: List['Task'] = Field(default_factory=list, description="List of tasks in the workflow")
|
||||
tasks: List[Task] = Field(default_factory=list, description="List of tasks in the workflow")
|
||||
|
||||
label: Label = Field(
|
||||
default=Label(default="Chat Workflow", translations={"en": "Chat Workflow", "fr": "Flux de travail de chat"}),
|
||||
|
|
@ -117,23 +134,6 @@ class AgentResponse(BaseModelWithUI):
|
|||
performance: Dict[str, Any] = Field(default_factory=dict, description="Performance metrics")
|
||||
progress: float = Field(description="Task progress (0-100)")
|
||||
|
||||
class Task(BaseModelWithUI):
|
||||
"""Data model for a task"""
|
||||
id: str = Field(description="Primary key")
|
||||
workflowId: str = Field(description="Foreign key to workflow")
|
||||
agentName: str = Field(description="Name of the agent assigned to this task")
|
||||
status: str = Field(description="Current status of the task")
|
||||
progress: float = Field(description="Task progress (0-100)")
|
||||
prompt: str = Field(description="Prompt for the task")
|
||||
userLanguage: str = Field(description="User's preferred language")
|
||||
filesInput: List[str] = Field(default_factory=list, description="Input files")
|
||||
filesOutput: List[str] = Field(default_factory=list, description="Output files")
|
||||
result: Optional[ChatMessage] = Field(None, description="Task result message")
|
||||
error: Optional[str] = Field(None, description="Error message if failed")
|
||||
startedAt: str = Field(description="When the task started")
|
||||
finishedAt: Optional[str] = Field(None, description="When the task finished")
|
||||
performance: Optional[Dict[str, Any]] = Field(None, description="Performance metrics")
|
||||
|
||||
class TaskPlan(BaseModelWithUI):
|
||||
"""Data model for a task plan"""
|
||||
fileList: List[str] = Field(default_factory=list, description="List of files")
|
||||
|
|
@ -145,4 +145,14 @@ class UserInputRequest(BaseModelWithUI):
|
|||
"""Data model for a user input request"""
|
||||
prompt: str = Field(description="Prompt for the user")
|
||||
listFileId: List[int] = Field(default_factory=list, description="List of file IDs")
|
||||
userLanguage: str = Field(description="User's preferred language")
|
||||
userLanguage: str = Field(default="en", description="User's preferred language")
|
||||
|
||||
class AgentProfile(BaseModel):
|
||||
"""Model for agent profile information."""
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
capabilities: List[str] = Field(default_factory=list)
|
||||
isAvailable: bool = True
|
||||
lastActive: Optional[datetime] = None
|
||||
stats: Optional[Dict[str, Any]] = None
|
||||
|
|
@ -3,7 +3,12 @@ Access control module for Management interface.
|
|||
Handles user access management and permission checks.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, List, Optional
|
||||
from modules.interfaces.serviceAppModel import User
|
||||
|
||||
# Configure logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ManagementAccess:
|
||||
"""
|
||||
|
|
@ -11,15 +16,12 @@ class ManagementAccess:
|
|||
Handles user access management and permission checks.
|
||||
"""
|
||||
|
||||
def __init__(self, currentUser: Dict[str, Any], db):
|
||||
def __init__(self, currentUser: User, db):
|
||||
"""Initialize with user context."""
|
||||
self.currentUser = currentUser
|
||||
self.mandateId = currentUser.get("mandateId")
|
||||
self.userId = currentUser.get("id")
|
||||
|
||||
if not self.mandateId or not self.userId:
|
||||
raise ValueError("Invalid user context: mandateId and userId are required")
|
||||
|
||||
self.userId = currentUser.id
|
||||
self.mandateId = currentUser.mandateId
|
||||
self.privilege = currentUser.privilege
|
||||
self.db = db
|
||||
|
||||
def uam(self, table: str, recordset: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
|
|
@ -34,8 +36,8 @@ class ManagementAccess:
|
|||
Returns:
|
||||
Filtered recordset with access control attributes
|
||||
"""
|
||||
userPrivilege = self.currentUser.get("privilege", "user")
|
||||
print("DEBUG: User privilege:", userPrivilege, self.currentUser.get("username"),self.currentUser.get("email"))
|
||||
userPrivilege = self.privilege
|
||||
logger.debug(f"User privilege: {userPrivilege}, username: {self.currentUser.username}, email: {self.currentUser.email}")
|
||||
filtered_records = []
|
||||
|
||||
# Apply filtering based on privilege
|
||||
|
|
@ -98,7 +100,7 @@ class ManagementAccess:
|
|||
Returns:
|
||||
Boolean indicating permission
|
||||
"""
|
||||
userPrivilege = self.currentUser.get("privilege", "user")
|
||||
userPrivilege = self.privilege
|
||||
|
||||
# System admins can modify anything
|
||||
if userPrivilege == "sysadmin":
|
||||
|
|
|
|||
|
|
@ -76,7 +76,7 @@ class ServiceManagement:
|
|||
logger.info("Initializing interface without user context")
|
||||
return
|
||||
|
||||
self.currentUser = currentUser
|
||||
self.currentUser = currentUser # Store User object directly
|
||||
self.userId = currentUser.id
|
||||
|
||||
if not self.userId:
|
||||
|
|
@ -249,40 +249,32 @@ class ServiceManagement:
|
|||
filteredPrompts = self._uam("prompts", prompts)
|
||||
return Prompt.from_dict(filteredPrompts[0]) if filteredPrompts else None
|
||||
|
||||
def createPrompt(self, content: str, name: str) -> Prompt:
|
||||
def createPrompt(self, promptData: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Creates a new prompt if user has permission."""
|
||||
if not self._canModify("prompts"):
|
||||
raise PermissionError("No permission to create prompts")
|
||||
|
||||
promptData = Prompt(
|
||||
content=content,
|
||||
name=name,
|
||||
createdAt=self._getCurrentTimestamp()
|
||||
)
|
||||
|
||||
# Create prompt record
|
||||
createdRecord = self.db.recordCreate("prompts", promptData.to_dict())
|
||||
return Prompt.from_dict(createdRecord)
|
||||
if not createdRecord or not createdRecord.get("id"):
|
||||
raise ValueError("Failed to create prompt record")
|
||||
|
||||
def updatePrompt(self, promptId: str, content: str = None, name: str = None) -> Prompt:
|
||||
return createdRecord
|
||||
|
||||
def updatePrompt(self, promptId: str, updateData: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Updates a prompt if user has access."""
|
||||
# Check if the prompt exists and user has access
|
||||
try:
|
||||
# Get prompt
|
||||
prompt = self.getPrompt(promptId)
|
||||
if not prompt:
|
||||
raise ValueError(f"Prompt {promptId} not found")
|
||||
|
||||
if not self._canModify("prompts", promptId):
|
||||
raise PermissionError(f"No permission to update prompt {promptId}")
|
||||
|
||||
# Update prompt data using model
|
||||
updatedData = prompt.model_dump()
|
||||
if content is not None:
|
||||
updatedData["content"] = content
|
||||
if name is not None:
|
||||
updatedData["name"] = name
|
||||
|
||||
updatedData = prompt.to_dict()
|
||||
updatedData.update(updateData)
|
||||
updatedPrompt = Prompt.from_dict(updatedData)
|
||||
|
||||
# Update prompt
|
||||
# Update prompt record
|
||||
self.db.recordModify("prompts", promptId, updatedPrompt.to_dict())
|
||||
|
||||
# Get updated prompt
|
||||
|
|
@ -292,6 +284,10 @@ class ServiceManagement:
|
|||
|
||||
return updatedPrompt
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating prompt: {str(e)}")
|
||||
raise ValueError(f"Failed to update prompt: {str(e)}")
|
||||
|
||||
def deletePrompt(self, promptId: str) -> bool:
|
||||
"""Deletes a prompt if user has access."""
|
||||
# Check if the prompt exists and user has access
|
||||
|
|
|
|||
|
|
@ -1,13 +1,15 @@
|
|||
from fastapi import APIRouter, Response, Depends
|
||||
from fastapi import APIRouter, Response, Depends, Request
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
import os
|
||||
import logging
|
||||
from pathlib import Path as FilePath
|
||||
from typing import Dict, Any
|
||||
from typing import Dict, Any, List
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
from modules.security.auth import limiter, getCurrentUser
|
||||
from modules.interfaces.serviceAppModel import User
|
||||
|
||||
router = APIRouter(
|
||||
prefix="",
|
||||
|
|
@ -33,7 +35,7 @@ router.mount("/static", StaticFiles(directory=str(staticFolder), html=True), nam
|
|||
|
||||
@router.get("/")
|
||||
@limiter.limit("30/minute")
|
||||
async def root():
|
||||
async def root(request: Request) -> Dict[str, str]:
|
||||
"""API status endpoint"""
|
||||
return {
|
||||
"status": "online",
|
||||
|
|
@ -43,7 +45,7 @@ async def root():
|
|||
|
||||
@router.get("/api/environment")
|
||||
@limiter.limit("30/minute")
|
||||
async def get_environment(currentUser: Dict[str, Any] = Depends(getCurrentUser)):
|
||||
async def get_environment(request: Request, currentUser: Dict[str, Any] = Depends(getCurrentUser)) -> Dict[str, str]:
|
||||
"""Get environment configuration for frontend"""
|
||||
return {
|
||||
"apiBaseUrl": APP_CONFIG.get("APP_API_URL", ""),
|
||||
|
|
@ -54,10 +56,10 @@ async def get_environment(currentUser: Dict[str, Any] = Depends(getCurrentUser))
|
|||
|
||||
@router.options("/{fullPath:path}")
|
||||
@limiter.limit("60/minute")
|
||||
async def options_route(fullPath: str):
|
||||
async def options_route(request: Request, fullPath: str) -> Response:
|
||||
return Response(status_code=200)
|
||||
|
||||
@router.get("/favicon.ico")
|
||||
@limiter.limit("30/minute")
|
||||
async def favicon():
|
||||
async def favicon(request: Request) -> FileResponse:
|
||||
return FileResponse(str(staticFolder / "favicon.ico"), media_type="image/x-icon")
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from fastapi import APIRouter, HTTPException, Depends, Path, Response
|
||||
from fastapi import APIRouter, HTTPException, Depends, Path, Response, Request
|
||||
from typing import List, Dict, Any
|
||||
from fastapi import status
|
||||
import inspect
|
||||
|
|
@ -11,7 +11,7 @@ import logging
|
|||
from modules.security.auth import limiter, getCurrentUser
|
||||
|
||||
# Import the attribute definition and helper functions
|
||||
from modules.interfaces.serviceAppModel import AttributeDefinition
|
||||
from modules.interfaces.serviceAppModel import AttributeDefinition, User
|
||||
from modules.shared.attributeUtils import getModelClasses
|
||||
|
||||
# Configure logger
|
||||
|
|
@ -50,9 +50,9 @@ router = APIRouter(
|
|||
@router.get("/{entityType}", response_model=AttributeResponse)
|
||||
@limiter.limit("30/minute")
|
||||
async def get_entity_attributes(
|
||||
entityType: str = Path(..., description="Type of entity (e.g. prompt)"),
|
||||
currentUser: Dict[str, Any] = Depends(getCurrentUser)
|
||||
):
|
||||
request: Request,
|
||||
entityType: str = Path(..., description="Type of entity (e.g. prompt)")
|
||||
) -> AttributeResponse:
|
||||
"""
|
||||
Retrieves the attribute definitions for a specific entity.
|
||||
This can be used for dynamic form generation.
|
||||
|
|
@ -63,9 +63,6 @@ async def get_entity_attributes(
|
|||
Returns:
|
||||
- A list of attribute definitions that can be used to generate forms
|
||||
"""
|
||||
# Determine preferred language of the user
|
||||
userLanguage = currentUser.get("language", "en")
|
||||
|
||||
# Get model classes dynamically
|
||||
modelClasses = getModelClasses()
|
||||
|
||||
|
|
@ -80,13 +77,21 @@ async def get_entity_attributes(
|
|||
modelClass = modelClasses[entityType]
|
||||
attributes = modelClass.getModelAttributeDefinitions()
|
||||
|
||||
# Return only visible attributes
|
||||
return AttributeResponse(attributes=[attr for attr in attributes if attr.visible])
|
||||
# Convert dictionary attributes to AttributeDefinition objects
|
||||
attribute_definitions = []
|
||||
for attr in attributes:
|
||||
if isinstance(attr, dict) and attr.get('visible', True):
|
||||
attribute_definitions.append(AttributeDefinition(**attr))
|
||||
elif hasattr(attr, 'visible') and attr.visible:
|
||||
attribute_definitions.append(attr)
|
||||
|
||||
return AttributeResponse(attributes=attribute_definitions)
|
||||
|
||||
@router.options("/{entityType}")
|
||||
@limiter.limit("60/minute")
|
||||
async def options_entity_attributes(
|
||||
request: Request,
|
||||
entityType: str = Path(..., description="Type of entity (e.g. prompt)")
|
||||
):
|
||||
) -> Response:
|
||||
"""Handle OPTIONS request for CORS preflight"""
|
||||
return Response(status_code=200)
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
from fastapi import APIRouter, HTTPException, Depends, File, UploadFile, Form, Path, Request, status, Query, Response
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi import APIRouter, HTTPException, Depends, File, UploadFile, Form, Path, Request, status, Query, Response, Body
|
||||
from fastapi.responses import JSONResponse, FileResponse
|
||||
from typing import List, Dict, Any, Optional
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
|
@ -11,8 +11,9 @@ from modules.security.auth import limiter, getCurrentUser
|
|||
|
||||
# Import interfaces
|
||||
import modules.interfaces.serviceManagementClass as serviceManagementClass
|
||||
from modules.interfaces.serviceManagementModel import FileItem, getModelAttributeDefinitions
|
||||
from modules.interfaces.serviceAppModel import AttributeDefinition
|
||||
from modules.interfaces.serviceManagementModel import FileItem
|
||||
from modules.shared.attributeUtils import getModelAttributeDefinitions
|
||||
from modules.interfaces.serviceAppModel import AttributeDefinition, User
|
||||
|
||||
# Configure logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -33,10 +34,13 @@ router = APIRouter(
|
|||
}
|
||||
)
|
||||
|
||||
@router.get("", response_model=List[FileItem])
|
||||
@router.get("/list", response_model=List[FileItem])
|
||||
@limiter.limit("30/minute")
|
||||
async def get_files(currentUser: Dict[str, Any] = Depends(getCurrentUser)):
|
||||
"""Get all available files"""
|
||||
async def get_files(
|
||||
request: Request,
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> List[FileItem]:
|
||||
"""Get all files"""
|
||||
try:
|
||||
managementInterface = serviceManagementClass.getInterface(currentUser)
|
||||
|
||||
|
|
@ -44,19 +48,20 @@ async def get_files(currentUser: Dict[str, Any] = Depends(getCurrentUser)):
|
|||
files = managementInterface.getAllFiles()
|
||||
return [FileItem(**file) for file in files]
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving files: {str(e)}")
|
||||
logger.error(f"Error getting files: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error retrieving files: {str(e)}"
|
||||
detail=f"Failed to get files: {str(e)}"
|
||||
)
|
||||
|
||||
@router.post("/upload", status_code=status.HTTP_201_CREATED)
|
||||
@limiter.limit("10/minute")
|
||||
async def upload_file(
|
||||
request: Request,
|
||||
file: UploadFile = File(...),
|
||||
workflowId: Optional[str] = Form(None),
|
||||
currentUser: Dict[str, Any] = Depends(getCurrentUser)
|
||||
):
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> JSONResponse:
|
||||
"""Upload a file"""
|
||||
try:
|
||||
managementInterface = serviceManagementClass.getInterface(currentUser)
|
||||
|
|
@ -82,7 +87,10 @@ async def upload_file(
|
|||
fileMeta["workflowId"] = workflowId
|
||||
|
||||
# Successful response
|
||||
return fileMeta
|
||||
return JSONResponse({
|
||||
"message": "File uploaded successfully",
|
||||
"file": fileMeta
|
||||
})
|
||||
|
||||
except serviceManagementClass.FileStorageError as e:
|
||||
logger.error(f"Error during file upload (storage): {str(e)}")
|
||||
|
|
@ -97,29 +105,27 @@ async def upload_file(
|
|||
detail=f"Error during file upload: {str(e)}"
|
||||
)
|
||||
|
||||
@router.get("/{fileId}")
|
||||
@router.get("/{fileId}", response_model=FileItem)
|
||||
@limiter.limit("30/minute")
|
||||
async def get_file(
|
||||
fileId: str,
|
||||
currentUser: Dict[str, Any] = Depends(getCurrentUser)
|
||||
):
|
||||
"""Returns a file by its ID for download"""
|
||||
request: Request,
|
||||
fileId: str = Path(..., description="ID of the file"),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> FileItem:
|
||||
"""Get a file"""
|
||||
try:
|
||||
managementInterface = serviceManagementClass.getInterface(currentUser)
|
||||
|
||||
# Get file via LucyDOM interface from the database
|
||||
fileData = managementInterface.downloadFile(fileId)
|
||||
|
||||
# Return file
|
||||
headers = {
|
||||
"Content-Disposition": f'attachment; filename="{fileData["name"]}"'
|
||||
}
|
||||
return Response(
|
||||
content=fileData["content"],
|
||||
media_type=fileData["contentType"],
|
||||
headers=headers
|
||||
fileData = managementInterface.getFile(fileId)
|
||||
if not fileData:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"File with ID {fileId} not found"
|
||||
)
|
||||
|
||||
return FileItem(**fileData)
|
||||
|
||||
except serviceManagementClass.FileNotFoundError as e:
|
||||
logger.warning(f"File not found: {str(e)}")
|
||||
raise HTTPException(
|
||||
|
|
@ -145,53 +151,62 @@ async def get_file(
|
|||
detail=f"Error retrieving file: {str(e)}"
|
||||
)
|
||||
|
||||
@router.put("/{file_id}", response_model=FileItem)
|
||||
@router.put("/{fileId}", response_model=FileItem)
|
||||
@limiter.limit("10/minute")
|
||||
async def update_file(
|
||||
file_id: str,
|
||||
file_data: FileItem,
|
||||
currentUser: Dict[str, Any] = Depends(getCurrentUser)
|
||||
):
|
||||
"""
|
||||
Update file metadata
|
||||
"""
|
||||
request: Request,
|
||||
fileId: str = Path(..., description="ID of the file to update"),
|
||||
file_info: Dict[str, Any] = Body(...),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> FileItem:
|
||||
"""Update file info"""
|
||||
try:
|
||||
managementInterface = serviceManagementClass.getInterface(currentUser)
|
||||
|
||||
# Get the file from the database
|
||||
file = managementInterface.getFile(file_id)
|
||||
file = managementInterface.getFile(fileId)
|
||||
if not file:
|
||||
raise HTTPException(status_code=404, detail="File not found")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"File with ID {fileId} not found"
|
||||
)
|
||||
|
||||
# Check if user has access to the file
|
||||
if file.get("userId", 0) != currentUser.get("id", 0):
|
||||
raise HTTPException(status_code=403, detail="Not authorized to update this file")
|
||||
|
||||
# Convert FileItem to dict for interface
|
||||
update_data = file_data.model_dump()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Not authorized to update this file"
|
||||
)
|
||||
|
||||
# Update the file
|
||||
result = managementInterface.updateFile(file_id, update_data)
|
||||
result = managementInterface.updateFile(fileId, file_info)
|
||||
if not result:
|
||||
raise HTTPException(status_code=500, detail="Failed to update file")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to update file"
|
||||
)
|
||||
|
||||
# Get updated file and convert to FileItem
|
||||
updatedFile = managementInterface.getFile(file_id)
|
||||
updatedFile = managementInterface.getFile(fileId)
|
||||
return FileItem(**updatedFile)
|
||||
|
||||
except HTTPException as he:
|
||||
raise he
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating file: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=str(e)
|
||||
)
|
||||
|
||||
@router.delete("/{fileId}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
@limiter.limit("10/minute")
|
||||
async def delete_file(
|
||||
request: Request,
|
||||
fileId: str,
|
||||
currentUser: Dict[str, Any] = Depends(getCurrentUser)
|
||||
):
|
||||
"""Deletes a file by its ID from the database"""
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> JSONResponse:
|
||||
"""Delete a file"""
|
||||
try:
|
||||
managementInterface = serviceManagementClass.getInterface(currentUser)
|
||||
|
||||
|
|
@ -199,7 +214,9 @@ async def delete_file(
|
|||
managementInterface.deleteFile(fileId)
|
||||
|
||||
# Return successful deletion without content (204 No Content)
|
||||
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
||||
return JSONResponse({
|
||||
"message": "File deleted successfully"
|
||||
})
|
||||
|
||||
except serviceManagementClass.FileNotFoundError as e:
|
||||
logger.warning(f"File not found: {str(e)}")
|
||||
|
|
@ -229,8 +246,9 @@ async def delete_file(
|
|||
@router.get("/stats", response_model=Dict[str, Any])
|
||||
@limiter.limit("30/minute")
|
||||
async def get_file_stats(
|
||||
currentUser: Dict[str, Any] = Depends(getCurrentUser)
|
||||
):
|
||||
request: Request,
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> Dict[str, Any]:
|
||||
"""Returns statistics about the stored files"""
|
||||
try:
|
||||
managementInterface = serviceManagementClass.getInterface(currentUser)
|
||||
|
|
@ -266,8 +284,9 @@ async def get_file_stats(
|
|||
@router.get("/attributes", response_model=List[AttributeDefinition])
|
||||
@limiter.limit("30/minute")
|
||||
async def get_file_attributes(
|
||||
currentUser: Dict[str, Any] = Depends(getCurrentUser)
|
||||
):
|
||||
request: Request,
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> List[AttributeDefinition]:
|
||||
"""
|
||||
Retrieves the attribute definitions for files.
|
||||
This can be used for dynamic form generation.
|
||||
|
|
|
|||
|
|
@ -1,3 +1,8 @@
|
|||
"""
|
||||
Mandate routes for the backend API.
|
||||
Implements the endpoints for mandate management.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends, Body, Path, Request, Response
|
||||
from typing import List, Dict, Any, Optional
|
||||
from fastapi import status
|
||||
|
|
@ -8,10 +13,10 @@ from modules.security.auth import limiter, getCurrentUser
|
|||
|
||||
# Import interfaces
|
||||
import modules.interfaces.serviceManagementClass as serviceManagementClass
|
||||
from modules.interfaces.serviceManagementModel import Mandate, getModelAttributeDefinitions
|
||||
from modules.shared.attributeUtils import getModelAttributeDefinitions
|
||||
|
||||
# Import the model classes
|
||||
from modules.interfaces.serviceAppModel import AttributeDefinition
|
||||
from modules.interfaces.serviceAppModel import AttributeDefinition, Mandate, User
|
||||
|
||||
# Configure logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -26,13 +31,17 @@ router = APIRouter(
|
|||
responses={404: {"description": "Not found"}}
|
||||
)
|
||||
|
||||
@router.get("/", response_model=List[Dict[str, Any]], tags=["Mandates"])
|
||||
@router.get("/", response_model=List[Mandate])
|
||||
@limiter.limit("30/minute")
|
||||
async def get_mandates(currentUser: Dict[str, Any] = Depends(getCurrentUser)):
|
||||
async def get_mandates(
|
||||
request: Request,
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> List[Mandate]:
|
||||
"""Get all mandates"""
|
||||
try:
|
||||
appInterface = serviceManagementClass.getInterface(currentUser)
|
||||
return appInterface.getMandates()
|
||||
managementInterface = serviceManagementClass.getInterface(currentUser)
|
||||
mandates = managementInterface.getMandates()
|
||||
return [Mandate.from_dict(mandate) for mandate in mandates]
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting mandates: {str(e)}")
|
||||
raise HTTPException(
|
||||
|
|
@ -40,24 +49,25 @@ async def get_mandates(currentUser: Dict[str, Any] = Depends(getCurrentUser)):
|
|||
detail=f"Failed to get mandates: {str(e)}"
|
||||
)
|
||||
|
||||
@router.get("/{mandateId}", response_model=Dict[str, Any], tags=["Mandates"])
|
||||
@router.get("/{mandateId}", response_model=Mandate)
|
||||
@limiter.limit("30/minute")
|
||||
async def get_mandate(
|
||||
mandateId: str,
|
||||
currentUser: Dict[str, Any] = Depends(getCurrentUser)
|
||||
):
|
||||
request: Request,
|
||||
mandateId: str = Path(..., description="ID of the mandate"),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> Mandate:
|
||||
"""Get a specific mandate by ID"""
|
||||
try:
|
||||
appInterface = serviceManagementClass.getInterface(currentUser)
|
||||
mandate = appInterface.getMandateById(mandateId)
|
||||
managementInterface = serviceManagementClass.getInterface(currentUser)
|
||||
mandate = managementInterface.getMandate(mandateId)
|
||||
|
||||
if not mandate:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Mandate {mandateId} not found"
|
||||
detail=f"Mandate with ID {mandateId} not found"
|
||||
)
|
||||
|
||||
return mandate
|
||||
return Mandate.from_dict(mandate)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
|
|
@ -67,31 +77,30 @@ async def get_mandate(
|
|||
detail=f"Failed to get mandate: {str(e)}"
|
||||
)
|
||||
|
||||
@router.post("/", response_model=Mandate, tags=["Mandates"])
|
||||
@router.post("/", response_model=Mandate)
|
||||
@limiter.limit("10/minute")
|
||||
async def create_mandate(
|
||||
mandateData: Mandate,
|
||||
currentUser: Dict[str, Any] = Depends(getCurrentUser)
|
||||
):
|
||||
request: Request,
|
||||
mandateData: Mandate = Body(...),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> Mandate:
|
||||
"""Create a new mandate"""
|
||||
try:
|
||||
appInterface = serviceManagementClass.getInterface(currentUser)
|
||||
managementInterface = serviceManagementClass.getInterface(currentUser)
|
||||
|
||||
try:
|
||||
createdMandate = appInterface.createMandate(mandateData)
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e)
|
||||
)
|
||||
# Convert Mandate to dict for interface
|
||||
mandate_data = mandateData.to_dict()
|
||||
|
||||
if not createdMandate:
|
||||
# Create mandate
|
||||
newMandate = managementInterface.createMandate(mandate_data)
|
||||
|
||||
if not newMandate:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to create mandate"
|
||||
)
|
||||
|
||||
return createdMandate
|
||||
return Mandate.from_dict(newMandate)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
|
|
@ -101,33 +110,31 @@ async def create_mandate(
|
|||
detail=f"Failed to create mandate: {str(e)}"
|
||||
)
|
||||
|
||||
@router.put("/{mandateId}", response_model=Mandate, tags=["Mandates"])
|
||||
@router.put("/{mandateId}", response_model=Mandate)
|
||||
@limiter.limit("10/minute")
|
||||
async def update_mandate(
|
||||
mandateId: str,
|
||||
mandateData: Mandate,
|
||||
currentUser: Dict[str, Any] = Depends(getCurrentUser)
|
||||
):
|
||||
request: Request,
|
||||
mandateId: str = Path(..., description="ID of the mandate to update"),
|
||||
mandateData: Mandate = Body(...),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> Mandate:
|
||||
"""Update an existing mandate"""
|
||||
try:
|
||||
appInterface = serviceManagementClass.getInterface(currentUser)
|
||||
managementInterface = serviceManagementClass.getInterface(currentUser)
|
||||
|
||||
# Check if mandate exists
|
||||
existingMandate = appInterface.getMandateById(mandateId)
|
||||
existingMandate = managementInterface.getMandate(mandateId)
|
||||
if not existingMandate:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Mandate {mandateId} not found"
|
||||
detail=f"Mandate with ID {mandateId} not found"
|
||||
)
|
||||
|
||||
# Update mandate data
|
||||
try:
|
||||
updatedMandate = appInterface.updateMandate(mandateId, mandateData)
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e)
|
||||
)
|
||||
# Convert Mandate to dict for interface
|
||||
update_data = mandateData.to_dict()
|
||||
|
||||
# Update mandate
|
||||
updatedMandate = managementInterface.updateMandate(mandateId, update_data)
|
||||
|
||||
if not updatedMandate:
|
||||
raise HTTPException(
|
||||
|
|
@ -135,7 +142,7 @@ async def update_mandate(
|
|||
detail="Failed to update mandate"
|
||||
)
|
||||
|
||||
return updatedMandate
|
||||
return Mandate.from_dict(updatedMandate)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
|
|
@ -145,12 +152,13 @@ async def update_mandate(
|
|||
detail=f"Failed to update mandate: {str(e)}"
|
||||
)
|
||||
|
||||
@router.delete("/{mandateId}", response_model=Dict[str, Any], tags=["Mandates"])
|
||||
@router.delete("/{mandateId}", response_model=Dict[str, Any])
|
||||
@limiter.limit("10/minute")
|
||||
async def delete_mandate(
|
||||
mandateId: str,
|
||||
currentUser: Dict[str, Any] = Depends(getCurrentUser)
|
||||
):
|
||||
request: Request,
|
||||
mandateId: str = Path(..., description="ID of the mandate to delete"),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> Dict[str, Any]:
|
||||
"""Delete a mandate"""
|
||||
try:
|
||||
appInterface = serviceManagementClass.getInterface(currentUser)
|
||||
|
|
@ -185,8 +193,9 @@ async def delete_mandate(
|
|||
@router.get("/attributes", response_model=List[AttributeDefinition])
|
||||
@limiter.limit("30/minute")
|
||||
async def get_mandate_attributes(
|
||||
currentUser: Dict[str, Any] = Depends(getCurrentUser)
|
||||
):
|
||||
request: Request,
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> List[AttributeDefinition]:
|
||||
"""
|
||||
Retrieves the attribute definitions for mandates.
|
||||
This can be used for dynamic form generation.
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from modules.security.auth import limiter, getCurrentUser
|
|||
# Import interfaces
|
||||
import modules.interfaces.serviceManagementClass as serviceManagementClass
|
||||
from modules.interfaces.serviceManagementModel import Prompt
|
||||
from modules.interfaces.serviceAppModel import AttributeDefinition
|
||||
from modules.interfaces.serviceAppModel import AttributeDefinition, User
|
||||
|
||||
# Configure logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -25,24 +25,26 @@ router = APIRouter(
|
|||
@router.get("", response_model=List[Prompt])
|
||||
@limiter.limit("30/minute")
|
||||
async def get_prompts(
|
||||
currentUser: Dict[str, Any] = Depends(getCurrentUser)
|
||||
):
|
||||
request: Request,
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> List[Prompt]:
|
||||
"""Get all prompts"""
|
||||
managementInterface = serviceManagementClass.getInterface(currentUser)
|
||||
prompts = managementInterface.getAllPrompts()
|
||||
return [Prompt(**prompt) for prompt in prompts]
|
||||
return [Prompt.from_dict(prompt) for prompt in prompts]
|
||||
|
||||
@router.post("", response_model=Prompt)
|
||||
@limiter.limit("10/minute")
|
||||
async def create_prompt(
|
||||
request: Request,
|
||||
prompt: Prompt,
|
||||
currentUser: Dict[str, Any] = Depends(getCurrentUser)
|
||||
):
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> Prompt:
|
||||
"""Create a new prompt"""
|
||||
managementInterface = serviceManagementClass.getInterface(currentUser)
|
||||
|
||||
# Convert Prompt to dict for interface
|
||||
prompt_data = prompt.model_dump()
|
||||
prompt_data = prompt.to_dict()
|
||||
|
||||
# Create prompt
|
||||
newPrompt = managementInterface.createPrompt(prompt_data)
|
||||
|
|
@ -51,14 +53,15 @@ async def create_prompt(
|
|||
if "createdAt" in Prompt.getModelAttributeDefinitions() and hasattr(newPrompt, "createdAt"):
|
||||
newPrompt["createdAt"] = datetime.now().isoformat()
|
||||
|
||||
return Prompt(**newPrompt)
|
||||
return Prompt.from_dict(newPrompt)
|
||||
|
||||
@router.get("/{promptId}", response_model=Prompt)
|
||||
@limiter.limit("30/minute")
|
||||
async def get_prompt(
|
||||
request: Request,
|
||||
promptId: str = Path(..., description="ID of the prompt"),
|
||||
currentUser: Dict[str, Any] = Depends(getCurrentUser)
|
||||
):
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> Prompt:
|
||||
"""Get a specific prompt"""
|
||||
managementInterface = serviceManagementClass.getInterface(currentUser)
|
||||
|
||||
|
|
@ -70,15 +73,16 @@ async def get_prompt(
|
|||
detail=f"Prompt with ID {promptId} not found"
|
||||
)
|
||||
|
||||
return Prompt(**prompt)
|
||||
return Prompt.from_dict(prompt)
|
||||
|
||||
@router.put("/{promptId}", response_model=Prompt)
|
||||
@limiter.limit("10/minute")
|
||||
async def update_prompt(
|
||||
request: Request,
|
||||
promptId: str = Path(..., description="ID of the prompt to update"),
|
||||
promptData: Prompt = Body(...),
|
||||
currentUser: Dict[str, Any] = Depends(getCurrentUser)
|
||||
):
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> Prompt:
|
||||
"""Update an existing prompt"""
|
||||
managementInterface = serviceManagementClass.getInterface(currentUser)
|
||||
|
||||
|
|
@ -91,7 +95,7 @@ async def update_prompt(
|
|||
)
|
||||
|
||||
# Convert Prompt to dict for interface
|
||||
update_data = promptData.model_dump()
|
||||
update_data = promptData.to_dict()
|
||||
|
||||
# Update prompt
|
||||
updatedPrompt = managementInterface.updatePrompt(promptId, update_data)
|
||||
|
|
@ -102,14 +106,15 @@ async def update_prompt(
|
|||
detail="Error updating the prompt"
|
||||
)
|
||||
|
||||
return Prompt(**updatedPrompt)
|
||||
return Prompt.from_dict(updatedPrompt)
|
||||
|
||||
@router.delete("/{promptId}", response_model=Dict[str, Any])
|
||||
@limiter.limit("10/minute")
|
||||
async def delete_prompt(
|
||||
request: Request,
|
||||
promptId: str = Path(..., description="ID of the prompt to delete"),
|
||||
currentUser: Dict[str, Any] = Depends(getCurrentUser)
|
||||
):
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> Dict[str, Any]:
|
||||
"""Delete a prompt"""
|
||||
managementInterface = serviceManagementClass.getInterface(currentUser)
|
||||
|
||||
|
|
@ -133,8 +138,9 @@ async def delete_prompt(
|
|||
@router.get("/attributes", response_model=List[AttributeDefinition])
|
||||
@limiter.limit("30/minute")
|
||||
async def get_prompt_attributes(
|
||||
currentUser: Dict[str, Any] = Depends(getCurrentUser)
|
||||
):
|
||||
request: Request,
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> List[AttributeDefinition]:
|
||||
"""
|
||||
Retrieves the attribute definitions for prompts.
|
||||
This can be used for dynamic form generation.
|
||||
|
|
|
|||
|
|
@ -1,3 +1,8 @@
|
|||
"""
|
||||
User routes for the backend API.
|
||||
Implements the endpoints for user management.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends, Body, Path, Request, Response
|
||||
from typing import List, Dict, Any, Optional
|
||||
from fastapi import status
|
||||
|
|
@ -13,8 +18,8 @@ import modules.interfaces.serviceManagementClass as serviceManagementClass
|
|||
from modules.security.auth import getCurrentUser, limiter, getCurrentUser
|
||||
|
||||
# Import the attribute definition and helper functions
|
||||
from modules.interfaces.serviceManagementModel import User, AttributeDefinition, getModelAttributeDefinitions
|
||||
from modules.interfaces.serviceAppModel import AttributeDefinition as ServiceAppAttributeDefinition
|
||||
from modules.interfaces.serviceAppModel import User, AttributeDefinition as ServiceAppAttributeDefinition
|
||||
from modules.shared.attributeUtils import getModelAttributeDefinitions
|
||||
|
||||
# Configure logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -25,12 +30,17 @@ router = APIRouter(
|
|||
responses={404: {"description": "Not found"}}
|
||||
)
|
||||
|
||||
@router.get("/", response_model=List[Dict[str, Any]], tags=["Users"])
|
||||
async def get_users(currentUser: Dict[str, Any] = Depends(getCurrentUser)):
|
||||
@router.get("/", response_model=List[User])
|
||||
@limiter.limit("30/minute")
|
||||
async def get_users(
|
||||
request: Request,
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> List[User]:
|
||||
"""Get all users in the current mandate"""
|
||||
try:
|
||||
appInterface = serviceManagementClass.getInterface(currentUser)
|
||||
return appInterface.getUsers()
|
||||
managementInterface = serviceManagementClass.getInterface(currentUser)
|
||||
users = managementInterface.getUsers()
|
||||
return [User.from_dict(user) for user in users]
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting users: {str(e)}")
|
||||
raise HTTPException(
|
||||
|
|
@ -38,23 +48,25 @@ async def get_users(currentUser: Dict[str, Any] = Depends(getCurrentUser)):
|
|||
detail=f"Failed to get users: {str(e)}"
|
||||
)
|
||||
|
||||
@router.get("/{userId}", response_model=Dict[str, Any], tags=["Users"])
|
||||
@router.get("/{userId}", response_model=User)
|
||||
@limiter.limit("30/minute")
|
||||
async def get_user(
|
||||
userId: str,
|
||||
currentUser: Dict[str, Any] = Depends(getCurrentUser)
|
||||
):
|
||||
request: Request,
|
||||
userId: str = Path(..., description="ID of the user"),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> User:
|
||||
"""Get a specific user by ID"""
|
||||
try:
|
||||
appInterface = serviceManagementClass.getInterface(currentUser)
|
||||
user = appInterface.getUserById(userId)
|
||||
managementInterface = serviceManagementClass.getInterface(currentUser)
|
||||
user = managementInterface.getUser(userId)
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"User {userId} not found"
|
||||
detail=f"User with ID {userId} not found"
|
||||
)
|
||||
|
||||
return user
|
||||
return User.from_dict(user)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
|
|
@ -64,98 +76,68 @@ async def get_user(
|
|||
detail=f"Failed to get user: {str(e)}"
|
||||
)
|
||||
|
||||
@router.post("/", response_model=User, tags=["Users"])
|
||||
@router.post("", response_model=User)
|
||||
@limiter.limit("10/minute")
|
||||
async def create_user(
|
||||
userData: User,
|
||||
currentUser: Dict[str, Any] = Depends(getCurrentUser)
|
||||
):
|
||||
request: Request,
|
||||
user: User,
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> User:
|
||||
"""Create a new user"""
|
||||
try:
|
||||
# Get interface for user creation
|
||||
appInterface = serviceManagementClass.getInterface(currentUser)
|
||||
managementInterface = serviceManagementClass.getInterface(currentUser)
|
||||
|
||||
try:
|
||||
# Convert User model to dict and pass to createUser
|
||||
createdUser = appInterface.createUser(
|
||||
username=userData.username,
|
||||
email=userData.email,
|
||||
fullName=userData.fullName,
|
||||
language=userData.language,
|
||||
disabled=userData.disabled,
|
||||
privilege=userData.privilege,
|
||||
authenticationAuthority=userData.authenticationAuthority
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e)
|
||||
)
|
||||
# Convert User to dict for interface
|
||||
user_data = user.to_dict()
|
||||
|
||||
if not createdUser:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to create user"
|
||||
)
|
||||
# Create user
|
||||
newUser = managementInterface.createUser(user_data)
|
||||
|
||||
return createdUser
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating user: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to create user: {str(e)}"
|
||||
)
|
||||
# Set current time for createdAt if it exists in the model
|
||||
if "createdAt" in User.getModelAttributeDefinitions() and hasattr(newUser, "createdAt"):
|
||||
newUser["createdAt"] = datetime.now().isoformat()
|
||||
|
||||
@router.put("/{userId}", response_model=User, tags=["Users"])
|
||||
return User.from_dict(newUser)
|
||||
|
||||
@router.put("/{userId}", response_model=User)
|
||||
@limiter.limit("10/minute")
|
||||
async def update_user(
|
||||
userId: str,
|
||||
userData: User,
|
||||
currentUser: Dict[str, Any] = Depends(getCurrentUser)
|
||||
):
|
||||
request: Request,
|
||||
userId: str = Path(..., description="ID of the user to update"),
|
||||
userData: User = Body(...),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> User:
|
||||
"""Update an existing user"""
|
||||
try:
|
||||
# Get interface for user updates
|
||||
appInterface = serviceManagementClass.getInterface(currentUser)
|
||||
managementInterface = serviceManagementClass.getInterface(currentUser)
|
||||
|
||||
# Check if user exists
|
||||
existingUser = appInterface.getUserById(userId)
|
||||
# Check if the user exists
|
||||
existingUser = managementInterface.getUser(userId)
|
||||
if not existingUser:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"User {userId} not found"
|
||||
detail=f"User with ID {userId} not found"
|
||||
)
|
||||
|
||||
# Update user data
|
||||
try:
|
||||
updatedUser = appInterface.updateUser(userId, userData)
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e)
|
||||
)
|
||||
# Convert User to dict for interface
|
||||
update_data = userData.to_dict()
|
||||
|
||||
# Update user
|
||||
updatedUser = managementInterface.updateUser(userId, update_data)
|
||||
|
||||
if not updatedUser:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to update user"
|
||||
detail="Error updating the user"
|
||||
)
|
||||
|
||||
return updatedUser
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating user {userId}: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to update user: {str(e)}"
|
||||
)
|
||||
return User.from_dict(updatedUser)
|
||||
|
||||
@router.delete("/{userId}", response_model=Dict[str, Any], tags=["Users"])
|
||||
@router.delete("/{userId}", response_model=Dict[str, Any])
|
||||
@limiter.limit("10/minute")
|
||||
async def delete_user(
|
||||
userId: str,
|
||||
currentUser: Dict[str, Any] = Depends(getCurrentUser)
|
||||
):
|
||||
request: Request,
|
||||
userId: str = Path(..., description="ID of the user to delete"),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> Dict[str, Any]:
|
||||
"""Delete a user"""
|
||||
try:
|
||||
appInterface = serviceManagementClass.getInterface(currentUser)
|
||||
|
|
@ -176,8 +158,9 @@ async def delete_user(
|
|||
@router.get("/attributes", response_model=List[ServiceAppAttributeDefinition])
|
||||
@limiter.limit("30/minute")
|
||||
async def get_user_attributes(
|
||||
currentUser: Dict[str, Any] = Depends(getCurrentUser)
|
||||
):
|
||||
request: Request,
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> List[ServiceAppAttributeDefinition]:
|
||||
"""
|
||||
Retrieves the attribute definitions for users.
|
||||
This can be used for dynamic form generation.
|
||||
|
|
|
|||
|
|
@ -10,12 +10,12 @@ from typing import Dict, Any, Optional
|
|||
from datetime import datetime, timedelta
|
||||
from google.oauth2.credentials import Credentials
|
||||
from google_auth_oauthlib.flow import Flow
|
||||
from google.auth.transport.requests import Request
|
||||
from google.auth.transport.requests import Request as GoogleRequest
|
||||
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
from modules.interfaces.serviceAppClass import getInterface
|
||||
from modules.interfaces.serviceAppModel import AuthAuthority
|
||||
from modules.interfaces.serviceAppTokens import GoogleToken, saveToken
|
||||
from modules.interfaces.serviceAppClass import getInterface, getRootInterface
|
||||
from modules.interfaces.serviceAppModel import AuthAuthority, User
|
||||
from modules.interfaces.serviceAppTokens import GoogleToken
|
||||
from modules.security.auth import getCurrentUser, limiter
|
||||
|
||||
# Configure logger
|
||||
|
|
@ -46,7 +46,7 @@ SCOPES = [
|
|||
|
||||
@router.get("/login")
|
||||
@limiter.limit("5/minute")
|
||||
async def login():
|
||||
async def login(request: Request) -> RedirectResponse:
|
||||
"""Initiate Google login"""
|
||||
try:
|
||||
# Create OAuth flow
|
||||
|
|
@ -79,7 +79,7 @@ async def login():
|
|||
)
|
||||
|
||||
@router.get("/auth/callback")
|
||||
async def auth_callback(code: str, request: Request):
|
||||
async def auth_callback(code: str, request: Request) -> HTMLResponse:
|
||||
"""Handle Google OAuth callback"""
|
||||
try:
|
||||
# Create OAuth flow
|
||||
|
|
@ -111,7 +111,7 @@ async def auth_callback(code: str, request: Request):
|
|||
|
||||
# Save token data
|
||||
appInterface = getInterface()
|
||||
saveToken(appInterface, "Google", token_data)
|
||||
appInterface.saveToken("Google", token_data)
|
||||
|
||||
# Return success page with token data
|
||||
return HTMLResponse(
|
||||
|
|
@ -141,24 +141,36 @@ async def auth_callback(code: str, request: Request):
|
|||
detail=f"Authentication failed: {str(e)}"
|
||||
)
|
||||
|
||||
@router.post("/logout")
|
||||
@router.get("/me", response_model=User)
|
||||
@limiter.limit("30/minute")
|
||||
async def logout(currentUser: Dict[str, Any] = Depends(getCurrentUser)):
|
||||
"""Logout from Google"""
|
||||
async def get_current_user(
|
||||
request: Request,
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> User:
|
||||
"""Get current user information"""
|
||||
try:
|
||||
# Get user interface
|
||||
appInterface = getInterface()
|
||||
|
||||
# Revoke all sessions for the user
|
||||
appInterface.revokeAllUserSessions(currentUser.get("id"))
|
||||
|
||||
return JSONResponse({
|
||||
"message": "Successfully logged out from Google"
|
||||
})
|
||||
return currentUser
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting current user: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to get current user: {str(e)}"
|
||||
)
|
||||
|
||||
@router.post("/logout")
|
||||
@limiter.limit("10/minute")
|
||||
async def logout(
|
||||
request: Request,
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> Dict[str, Any]:
|
||||
"""Logout current user"""
|
||||
try:
|
||||
appInterface = getInterface(currentUser)
|
||||
appInterface.logout()
|
||||
return {"message": "Logged out successfully"}
|
||||
except Exception as e:
|
||||
logger.error(f"Error during logout: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Logout failed: {str(e)}"
|
||||
detail=f"Failed to logout: {str(e)}"
|
||||
)
|
||||
|
|
@ -2,18 +2,19 @@
|
|||
Routes for local security and authentication.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, status, Depends, Request
|
||||
from fastapi import APIRouter, HTTPException, status, Depends, Request, Response, Body
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
import logging
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import datetime, timedelta
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.responses import JSONResponse, HTMLResponse, RedirectResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Import auth modules
|
||||
from modules.security.auth import createAccessToken, getCurrentUser, limiter
|
||||
from modules.interfaces.serviceAppClass import getInterface
|
||||
from modules.interfaces.serviceAppModel import User, AuthAuthority
|
||||
from modules.interfaces.serviceAppTokens import LocalToken, saveToken
|
||||
from modules.interfaces.serviceAppClass import getInterface, getRootInterface
|
||||
from modules.interfaces.serviceAppModel import User, UserInDB, AuthAuthority, UserPrivilege
|
||||
from modules.interfaces.serviceAppTokens import LocalToken
|
||||
|
||||
# Configure logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -36,7 +37,7 @@ router = APIRouter(
|
|||
async def login(
|
||||
request: Request,
|
||||
formData: OAuth2PasswordRequestForm = Depends(),
|
||||
):
|
||||
) -> Dict[str, Any]:
|
||||
"""Get access token for local user authentication"""
|
||||
try:
|
||||
# Validate CSRF token
|
||||
|
|
@ -47,11 +48,22 @@ async def login(
|
|||
detail="CSRF token missing"
|
||||
)
|
||||
|
||||
# Get gateway interface
|
||||
appInterface = getInterface()
|
||||
# Get gateway interface with root privileges for authentication
|
||||
rootInterface = getRootInterface()
|
||||
|
||||
# Get default mandate ID
|
||||
defaultMandateId = rootInterface.getInitialId("mandates")
|
||||
if not defaultMandateId:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="No default mandate found"
|
||||
)
|
||||
|
||||
# Set the mandate ID on the interface
|
||||
rootInterface.mandateId = defaultMandateId
|
||||
|
||||
# Authenticate user
|
||||
user = appInterface.authenticateLocalUser(
|
||||
user = rootInterface.authenticateLocalUser(
|
||||
username=formData.username,
|
||||
password=formData.password
|
||||
)
|
||||
|
|
@ -79,13 +91,16 @@ async def login(
|
|||
detail="Failed to create access token"
|
||||
)
|
||||
|
||||
# Get user-specific interface for token operations
|
||||
userInterface = getInterface(user)
|
||||
|
||||
# Save token data
|
||||
token_data = {
|
||||
"access_token": access_token,
|
||||
"token_type": "bearer",
|
||||
"expires_at": expires_at.timestamp()
|
||||
}
|
||||
saveToken(appInterface, "Local", token_data)
|
||||
userInterface.saveToken("Local", token_data)
|
||||
|
||||
# Create response data
|
||||
response_data = {
|
||||
|
|
@ -115,11 +130,16 @@ async def login(
|
|||
)
|
||||
|
||||
@router.post("/register", response_model=User)
|
||||
async def register_user(userData: User):
|
||||
@limiter.limit("10/minute")
|
||||
async def register_user(
|
||||
request: Request,
|
||||
userData: User = Body(...),
|
||||
password: str = Body(..., embed=True)
|
||||
) -> User:
|
||||
"""Register a new local user."""
|
||||
try:
|
||||
# Get gateway interface
|
||||
appInterface = getInterface()
|
||||
# Get gateway interface with root privileges since this is a public endpoint
|
||||
appInterface = getRootInterface()
|
||||
|
||||
# Get default mandate ID
|
||||
defaultMandateId = appInterface.getInitialId("mandates")
|
||||
|
|
@ -129,22 +149,28 @@ async def register_user(userData: User):
|
|||
detail="No default mandate found"
|
||||
)
|
||||
|
||||
# Create user with default mandate
|
||||
user = appInterface.createUser(
|
||||
# Set the mandate ID on the interface
|
||||
appInterface.mandateId = defaultMandateId
|
||||
|
||||
# Create user with individual parameters
|
||||
newUser = appInterface.createUser(
|
||||
username=userData.username,
|
||||
password=userData.password,
|
||||
password=password, # Pass the plain text password - createUser will hash it
|
||||
email=userData.email,
|
||||
mandateId=defaultMandateId, # Use default mandate instead of userData.mandateId
|
||||
authenticationAuthority=AuthAuthority.LOCAL
|
||||
fullName=userData.fullName,
|
||||
language=userData.language,
|
||||
disabled=userData.disabled,
|
||||
privilege=userData.privilege,
|
||||
authenticationAuthority=userData.authenticationAuthority
|
||||
)
|
||||
|
||||
if not user:
|
||||
if not newUser:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Failed to register user"
|
||||
)
|
||||
|
||||
return user
|
||||
return newUser
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
|
|
@ -158,22 +184,32 @@ async def register_user(userData: User):
|
|||
detail=f"Failed to register user: {str(e)}"
|
||||
)
|
||||
|
||||
@router.get("/me", response_model=Dict[str, Any])
|
||||
@router.get("/me", response_model=User)
|
||||
@limiter.limit("30/minute")
|
||||
async def read_user_me(currentUser: Dict[str, Any] = Depends(getCurrentUser)):
|
||||
"""Get current user information"""
|
||||
async def read_user_me(
|
||||
request: Request,
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> User:
|
||||
"""Get current user info"""
|
||||
try:
|
||||
return currentUser
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user me: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to get current user: {str(e)}"
|
||||
)
|
||||
|
||||
@router.post("/logout")
|
||||
@limiter.limit("30/minute")
|
||||
async def logout(currentUser: Dict[str, Any] = Depends(getCurrentUser)):
|
||||
async def logout(request: Request, currentUser: User = Depends(getCurrentUser)) -> JSONResponse:
|
||||
"""Logout from local authentication"""
|
||||
try:
|
||||
# Get user interface
|
||||
appInterface = getInterface()
|
||||
# Get user interface with current user context
|
||||
appInterface = getInterface(currentUser)
|
||||
|
||||
# Revoke all sessions for the user
|
||||
appInterface.revokeAllUserSessions(currentUser.get("id"))
|
||||
appInterface.revokeAllUserSessions(currentUser.id)
|
||||
|
||||
return JSONResponse({
|
||||
"message": "Successfully logged out"
|
||||
|
|
@ -186,15 +222,31 @@ async def logout(currentUser: Dict[str, Any] = Depends(getCurrentUser)):
|
|||
detail=f"Logout failed: {str(e)}"
|
||||
)
|
||||
|
||||
@router.get("/available", response_model=Dict[str, Any])
|
||||
@router.get("/available")
|
||||
@limiter.limit("10/minute")
|
||||
async def check_username_availability(
|
||||
request: Request,
|
||||
username: str,
|
||||
authenticationAuthority: str = "local"
|
||||
):
|
||||
"""Check if a username is available for registration"""
|
||||
) -> Dict[str, Any]:
|
||||
"""Check if a username is available for registration."""
|
||||
try:
|
||||
interfaceRoot = getInterface()
|
||||
return interfaceRoot.checkUsernameAvailability(username, authenticationAuthority)
|
||||
# Get root interface
|
||||
appInterface = getRootInterface()
|
||||
|
||||
# Use the interface's method to check availability
|
||||
result = appInterface.checkUsernameAvailability({
|
||||
"username": username,
|
||||
"authenticationAuthority": authenticationAuthority
|
||||
})
|
||||
|
||||
return {
|
||||
"username": username,
|
||||
"authenticationAuthority": authenticationAuthority,
|
||||
"available": result["available"],
|
||||
"message": result["message"]
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking username availability: {str(e)}")
|
||||
raise HTTPException(
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
Routes for Microsoft authentication.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request, Response, status, Depends
|
||||
from fastapi import APIRouter, HTTPException, Request, Response, status, Depends, Body
|
||||
from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse
|
||||
import logging
|
||||
import json
|
||||
|
|
@ -11,9 +11,9 @@ from datetime import datetime, timedelta
|
|||
import msal
|
||||
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
from modules.interfaces.serviceAppClass import getInterface
|
||||
from modules.interfaces.serviceAppModel import AuthAuthority
|
||||
from modules.interfaces.serviceAppTokens import MsftToken, saveToken
|
||||
from modules.interfaces.serviceAppClass import getInterface, getRootInterface
|
||||
from modules.interfaces.serviceAppModel import AuthAuthority, User
|
||||
from modules.interfaces.serviceAppTokens import MsftToken
|
||||
from modules.security.auth import getCurrentUser, limiter
|
||||
|
||||
# Configure logger
|
||||
|
|
@ -42,7 +42,7 @@ SCOPES = ["Mail.ReadWrite", "User.Read"]
|
|||
|
||||
@router.get("/login")
|
||||
@limiter.limit("5/minute")
|
||||
async def login():
|
||||
async def login(request: Request) -> RedirectResponse:
|
||||
"""Initiate Microsoft login"""
|
||||
try:
|
||||
# Create MSAL app
|
||||
|
|
@ -68,7 +68,7 @@ async def login():
|
|||
)
|
||||
|
||||
@router.get("/auth/callback")
|
||||
async def auth_callback(code: str, request: Request):
|
||||
async def auth_callback(code: str, request: Request) -> HTMLResponse:
|
||||
"""Handle Microsoft OAuth callback"""
|
||||
try:
|
||||
# Create MSAL app
|
||||
|
|
@ -101,7 +101,7 @@ async def auth_callback(code: str, request: Request):
|
|||
|
||||
# Save token data
|
||||
appInterface = getInterface()
|
||||
saveToken(appInterface, "Msft", token_data)
|
||||
appInterface.saveToken("Msft", token_data)
|
||||
|
||||
# Return success page with token data
|
||||
return HTMLResponse(
|
||||
|
|
@ -131,24 +131,36 @@ async def auth_callback(code: str, request: Request):
|
|||
detail=f"Authentication failed: {str(e)}"
|
||||
)
|
||||
|
||||
@router.post("/logout")
|
||||
@router.get("/me", response_model=User)
|
||||
@limiter.limit("30/minute")
|
||||
async def logout(currentUser: Dict[str, Any] = Depends(getCurrentUser)):
|
||||
"""Logout from Microsoft"""
|
||||
async def get_current_user(
|
||||
request: Request,
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> User:
|
||||
"""Get current user information"""
|
||||
try:
|
||||
# Get user interface
|
||||
appInterface = getInterface()
|
||||
|
||||
# Revoke all sessions for the user
|
||||
appInterface.revokeAllUserSessions(currentUser.get("id"))
|
||||
|
||||
return JSONResponse({
|
||||
"message": "Successfully logged out from Microsoft"
|
||||
})
|
||||
return currentUser
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting current user: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to get current user: {str(e)}"
|
||||
)
|
||||
|
||||
@router.post("/logout")
|
||||
@limiter.limit("10/minute")
|
||||
async def logout(
|
||||
request: Request,
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> Dict[str, Any]:
|
||||
"""Logout current user"""
|
||||
try:
|
||||
appInterface = getInterface(currentUser)
|
||||
appInterface.logout()
|
||||
return {"message": "Logged out successfully"}
|
||||
except Exception as e:
|
||||
logger.error(f"Error during logout: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Logout failed: {str(e)}"
|
||||
detail=f"Failed to logout: {str(e)}"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -7,14 +7,16 @@ import os
|
|||
import json
|
||||
import logging
|
||||
from typing import List, Dict, Any, Optional
|
||||
from fastapi import APIRouter, HTTPException, Depends, Body, Path, Query, Response, status
|
||||
from datetime import datetime
|
||||
from fastapi import APIRouter, HTTPException, Depends, Body, Path, Query, Response, status, Request
|
||||
from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
# Import auth modules
|
||||
from modules.security.auth import limiter, getCurrentUser
|
||||
|
||||
# Import interfaces
|
||||
import modules.interfaces.serviceChatClass as serviceChatClass
|
||||
from modules.interfaces.serviceChatClass import getInterface
|
||||
|
||||
# Import workflow manager
|
||||
from modules.workflow.workflowManager import getWorkflowManager
|
||||
|
|
@ -26,10 +28,10 @@ from modules.interfaces.serviceChatModel import (
|
|||
ChatLog,
|
||||
ChatStat,
|
||||
ChatDocument,
|
||||
UserInputRequest,
|
||||
Workflow,
|
||||
getModelAttributeDefinitions
|
||||
UserInputRequest
|
||||
)
|
||||
from modules.shared.attributeUtils import getModelAttributeDefinitions
|
||||
from modules.interfaces.serviceAppModel import User
|
||||
|
||||
# Configure logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -58,34 +60,32 @@ def createServiceContainer(currentUser: Dict[str, Any]):
|
|||
return service
|
||||
|
||||
# API Endpoint for getting all workflows
|
||||
@router.get("", response_model=List[ChatWorkflow])
|
||||
@router.get("/list", response_model=List[ChatWorkflow])
|
||||
@limiter.limit("30/minute")
|
||||
async def list_workflows(
|
||||
currentUser: Dict[str, Any] = Depends(getCurrentUser)
|
||||
):
|
||||
request: Request,
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> List[ChatWorkflow]:
|
||||
"""List all workflows for the current user."""
|
||||
try:
|
||||
# Get service container
|
||||
service = createServiceContainer(currentUser)
|
||||
|
||||
# Retrieve workflows for the user
|
||||
workflows = service.base.getWorkflowsByUser(currentUser["id"])
|
||||
return [ChatWorkflow(**workflow) for workflow in workflows]
|
||||
appInterface = getInterface(currentUser)
|
||||
return appInterface.getAllWorkflows()
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing workflows: {str(e)}", exc_info=True)
|
||||
logger.error(f"Error listing workflows: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error listing workflows: {str(e)}"
|
||||
detail=f"Failed to list workflows: {str(e)}"
|
||||
)
|
||||
|
||||
# State 1: Workflow Initialization endpoint
|
||||
@router.post("/start", response_model=ChatWorkflow)
|
||||
@limiter.limit("10/minute")
|
||||
async def start_workflow(
|
||||
request: Request,
|
||||
workflowId: Optional[str] = Query(None, description="Optional ID of the workflow to continue"),
|
||||
userInput: UserInputRequest = Body(...),
|
||||
currentUser: Dict[str, Any] = Depends(getCurrentUser)
|
||||
):
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> ChatWorkflow:
|
||||
"""
|
||||
Starts a new workflow or continues an existing one.
|
||||
Corresponds to State 1 in the state machine documentation.
|
||||
|
|
@ -100,19 +100,23 @@ async def start_workflow(
|
|||
# Start or continue workflow
|
||||
workflow = await workflowManager.workflowStart(userInput, workflowId)
|
||||
|
||||
return workflow
|
||||
return ChatWorkflow(**workflow)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in start_workflow: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=str(e)
|
||||
)
|
||||
|
||||
# State 8: Workflow Stopped endpoint
|
||||
@router.post("/{workflowId}/stop", response_model=ChatWorkflow)
|
||||
@limiter.limit("10/minute")
|
||||
async def stop_workflow(
|
||||
request: Request,
|
||||
workflowId: str = Path(..., description="ID of the workflow to stop"),
|
||||
currentUser: Dict[str, Any] = Depends(getCurrentUser)
|
||||
):
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> ChatWorkflow:
|
||||
"""Stops a running workflow."""
|
||||
try:
|
||||
# Get service container
|
||||
|
|
@ -124,19 +128,23 @@ async def stop_workflow(
|
|||
# Stop workflow
|
||||
workflow = await workflowManager.workflowStop(workflowId)
|
||||
|
||||
return workflow
|
||||
return ChatWorkflow(**workflow)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in stop_workflow: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=str(e)
|
||||
)
|
||||
|
||||
# State 11: Workflow Reset/Deletion endpoint
|
||||
@router.delete("/{workflowId}", response_model=Dict[str, Any])
|
||||
@limiter.limit("10/minute")
|
||||
async def delete_workflow(
|
||||
request: Request,
|
||||
workflowId: str = Path(..., description="ID of the workflow to delete"),
|
||||
currentUser: Dict[str, Any] = Depends(getCurrentUser)
|
||||
):
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> Dict[str, Any]:
|
||||
"""Deletes a workflow and its associated data."""
|
||||
try:
|
||||
# Get service container
|
||||
|
|
@ -183,9 +191,10 @@ async def delete_workflow(
|
|||
@router.get("/{workflowId}/status", response_model=ChatWorkflow)
|
||||
@limiter.limit("30/minute")
|
||||
async def get_workflow_status(
|
||||
request: Request,
|
||||
workflowId: str = Path(..., description="ID of the workflow"),
|
||||
currentUser: Dict[str, Any] = Depends(getCurrentUser)
|
||||
):
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> ChatWorkflow:
|
||||
"""Get the current status of a workflow."""
|
||||
try:
|
||||
# Get service container
|
||||
|
|
@ -213,10 +222,11 @@ async def get_workflow_status(
|
|||
@router.get("/{workflowId}/logs", response_model=List[ChatLog])
|
||||
@limiter.limit("30/minute")
|
||||
async def get_workflow_logs(
|
||||
request: Request,
|
||||
workflowId: str = Path(..., description="ID of the workflow"),
|
||||
logId: Optional[str] = Query(None, description="Optional log ID to get only newer logs"),
|
||||
currentUser: Dict[str, Any] = Depends(getCurrentUser)
|
||||
):
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> List[ChatLog]:
|
||||
"""Get logs for a workflow with support for selective data transfer."""
|
||||
try:
|
||||
# Get service container
|
||||
|
|
@ -255,10 +265,11 @@ async def get_workflow_logs(
|
|||
@router.get("/{workflowId}/messages", response_model=List[ChatMessage])
|
||||
@limiter.limit("30/minute")
|
||||
async def get_workflow_messages(
|
||||
request: Request,
|
||||
workflowId: str = Path(..., description="ID of the workflow"),
|
||||
messageId: Optional[str] = Query(None, description="Optional message ID to get only newer messages"),
|
||||
currentUser: Dict[str, Any] = Depends(getCurrentUser)
|
||||
):
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> List[ChatMessage]:
|
||||
"""Get messages for a workflow with support for selective data transfer."""
|
||||
try:
|
||||
# Get service container
|
||||
|
|
@ -277,20 +288,12 @@ async def get_workflow_messages(
|
|||
|
||||
# Apply selective data transfer if messageId is provided
|
||||
if messageId:
|
||||
# Find the index of the specified message based on messageIds array
|
||||
messageIds = workflow.get("messageIds", [])
|
||||
if messageId in messageIds:
|
||||
messageIndex = messageIds.index(messageId)
|
||||
# Return messages from this index onwards based on the messageIds order
|
||||
filteredMessages = []
|
||||
for msgId in messageIds[messageIndex:]:
|
||||
message = next((msg for msg in allMessages if msg.get("id") == msgId), None)
|
||||
if message:
|
||||
filteredMessages.append(message)
|
||||
return [ChatMessage(**msg) for msg in filteredMessages]
|
||||
# Find the index of the message with the given ID
|
||||
messageIndex = next((i for i, msg in enumerate(allMessages) if msg.get("id") == messageId), -1)
|
||||
if messageIndex >= 0:
|
||||
# Return only messages after the specified message
|
||||
return [ChatMessage(**msg) for msg in allMessages[messageIndex + 1:]]
|
||||
|
||||
# Sort messages by sequenceNo
|
||||
allMessages.sort(key=lambda x: x.get("sequenceNo", 0))
|
||||
return [ChatMessage(**msg) for msg in allMessages]
|
||||
except HTTPException:
|
||||
raise
|
||||
|
|
@ -306,16 +309,17 @@ async def get_workflow_messages(
|
|||
@router.delete("/{workflowId}/messages/{messageId}", response_model=Dict[str, Any])
|
||||
@limiter.limit("10/minute")
|
||||
async def delete_workflow_message(
|
||||
request: Request,
|
||||
workflowId: str = Path(..., description="ID of the workflow"),
|
||||
messageId: str = Path(..., description="ID of the message to delete"),
|
||||
currentUser: Dict[str, Any] = Depends(getCurrentUser)
|
||||
):
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> Dict[str, Any]:
|
||||
"""Delete a message from a workflow."""
|
||||
try:
|
||||
# Get service container
|
||||
service = createServiceContainer(currentUser)
|
||||
|
||||
# Verify workflow exists and belongs to user
|
||||
# Verify workflow exists
|
||||
workflow = service.base.getWorkflow(workflowId)
|
||||
if not workflow:
|
||||
raise HTTPException(
|
||||
|
|
@ -355,17 +359,18 @@ async def delete_workflow_message(
|
|||
@router.delete("/{workflowId}/messages/{messageId}/files/{fileId}", response_model=Dict[str, Any])
|
||||
@limiter.limit("10/minute")
|
||||
async def delete_file_from_message(
|
||||
request: Request,
|
||||
workflowId: str = Path(..., description="ID of the workflow"),
|
||||
messageId: str = Path(..., description="ID of the message"),
|
||||
fileId: str = Path(..., description="ID of the file to delete"),
|
||||
currentUser: Dict[str, Any] = Depends(getCurrentUser)
|
||||
):
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> Dict[str, Any]:
|
||||
"""Delete a file reference from a message in a workflow."""
|
||||
try:
|
||||
# Get service container
|
||||
service = createServiceContainer(currentUser)
|
||||
|
||||
# Verify workflow exists and belongs to user
|
||||
# Verify workflow exists
|
||||
workflow = service.base.getWorkflow(workflowId)
|
||||
if not workflow:
|
||||
raise HTTPException(
|
||||
|
|
@ -402,89 +407,24 @@ async def delete_file_from_message(
|
|||
@router.get("/files/{fileId}/preview", response_model=ChatDocument)
|
||||
@limiter.limit("30/minute")
|
||||
async def preview_file(
|
||||
request: Request,
|
||||
fileId: str = Path(..., description="ID of the file to preview"),
|
||||
currentUser: Dict[str, Any] = Depends(getCurrentUser)
|
||||
):
|
||||
"""Get file metadata and a preview of the file content."""
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> ChatDocument:
|
||||
"""Preview a file's content."""
|
||||
try:
|
||||
# Get service container
|
||||
service = createServiceContainer(currentUser)
|
||||
|
||||
# Get file metadata
|
||||
file = service.base.getFile(fileId)
|
||||
if not file:
|
||||
# Get file document
|
||||
document = service.base.getFileDocument(fileId)
|
||||
if not document:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"File with ID {fileId} not found"
|
||||
)
|
||||
|
||||
# Get file data (limited for preview)
|
||||
fileData = service.base.getFileData(fileId)
|
||||
if fileData is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"File data not found for file ID {fileId}"
|
||||
)
|
||||
|
||||
# For text-based files, return a preview of the content
|
||||
mimeType = file.get("mimeType", "application/octet-stream")
|
||||
isText = mimeType.startswith("text/") or mimeType in [
|
||||
"application/json",
|
||||
"application/xml",
|
||||
"application/javascript"
|
||||
]
|
||||
|
||||
previewData = None
|
||||
|
||||
# Get base64Encoded flag from database
|
||||
fileDataEntries = service.base.db.getRecordset("fileData", recordFilter={"id": fileId})
|
||||
if fileDataEntries and "base64Encoded" in fileDataEntries[0]:
|
||||
# Use the flag from the database
|
||||
base64Encoded = fileDataEntries[0]["base64Encoded"]
|
||||
else:
|
||||
# Determine based on file type (fallback for older data)
|
||||
base64Encoded = not isText
|
||||
|
||||
if isText:
|
||||
# Convert to string without trim for preview
|
||||
if isinstance(fileData, bytes):
|
||||
try:
|
||||
filePreview = fileData.decode('utf-8')
|
||||
previewData = filePreview
|
||||
except UnicodeDecodeError:
|
||||
# Try other encodings
|
||||
for encoding in ['latin-1', 'cp1252', 'iso-8859-1']:
|
||||
try:
|
||||
filePreview = fileData.decode(encoding)
|
||||
previewData = filePreview
|
||||
break
|
||||
except UnicodeDecodeError:
|
||||
continue
|
||||
|
||||
# For images, return base64 encoded data
|
||||
if mimeType.startswith("image/"):
|
||||
import base64
|
||||
previewData = base64.b64encode(fileData).decode('utf-8')
|
||||
base64Encoded = True
|
||||
|
||||
# Create ChatDocument instance
|
||||
return ChatDocument(
|
||||
id=fileId,
|
||||
fileId=fileId,
|
||||
fileName=file.get("name"),
|
||||
fileSize=file.get("size"),
|
||||
mimeType=mimeType,
|
||||
contents=[{
|
||||
"sequenceNr": 1,
|
||||
"name": file.get("name"),
|
||||
"mimeType": mimeType,
|
||||
"data": previewData,
|
||||
"metadata": {
|
||||
"base64Encoded": base64Encoded,
|
||||
"isPreviewable": isText or mimeType.startswith("image/")
|
||||
}
|
||||
}]
|
||||
)
|
||||
return ChatDocument(**document)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
|
|
@ -497,9 +437,10 @@ async def preview_file(
|
|||
@router.get("/files/{fileId}/download")
|
||||
@limiter.limit("30/minute")
|
||||
async def download_file(
|
||||
request: Request,
|
||||
fileId: str = Path(..., description="ID of the file to download"),
|
||||
currentUser: Dict[str, Any] = Depends(getCurrentUser)
|
||||
):
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> Response:
|
||||
"""Download a file."""
|
||||
try:
|
||||
# Get service container
|
||||
|
|
@ -529,3 +470,54 @@ async def download_file(
|
|||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error downloading file: {str(e)}"
|
||||
)
|
||||
|
||||
@router.get("/workflows", response_model=List[ChatWorkflow])
|
||||
@limiter.limit("30/minute")
|
||||
async def get_workflows(
|
||||
request: Request,
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> List[ChatWorkflow]:
|
||||
"""Get all workflows for current user"""
|
||||
try:
|
||||
# Get workflow interface with current user context
|
||||
workflowInterface = getInterface(currentUser)
|
||||
|
||||
# Get workflows
|
||||
workflows = workflowInterface.getWorkflows()
|
||||
return workflows
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting workflows: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to get workflows: {str(e)}"
|
||||
)
|
||||
|
||||
@router.get("/workflows/{workflow_id}", response_model=ChatWorkflow)
|
||||
@limiter.limit("30/minute")
|
||||
async def get_workflow(
|
||||
request: Request,
|
||||
workflow_id: str,
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> ChatWorkflow:
|
||||
"""Get workflow by ID"""
|
||||
try:
|
||||
# Get workflow interface with current user context
|
||||
workflowInterface = getInterface(currentUser)
|
||||
|
||||
# Get workflow
|
||||
workflow = workflowInterface.getWorkflow(workflow_id)
|
||||
if not workflow:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Workflow not found"
|
||||
)
|
||||
|
||||
return workflow
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting workflow: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to get workflow: {str(e)}"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from slowapi.util import get_remote_address
|
|||
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
from modules.interfaces.serviceAppClass import getRootInterface
|
||||
from modules.interfaces.serviceAppModel import Session, AuthEvent, UserPrivilege
|
||||
from modules.interfaces.serviceAppModel import Session, AuthEvent, UserPrivilege, User
|
||||
|
||||
# Get Config Data
|
||||
SECRET_KEY = APP_CONFIG.get("APP_JWT_SECRET_SECRET")
|
||||
|
|
@ -72,7 +72,7 @@ def createRefreshToken(data: dict) -> Tuple[str, datetime]:
|
|||
|
||||
return encodedJwt, expire
|
||||
|
||||
def _getUserBase(token: str = Depends(oauth2Scheme)) -> Dict[str, Any]:
|
||||
def _getUserBase(token: str = Depends(oauth2Scheme)) -> User:
|
||||
"""
|
||||
Extracts and validates the current user from the JWT token.
|
||||
|
||||
|
|
@ -80,7 +80,7 @@ def _getUserBase(token: str = Depends(oauth2Scheme)) -> Dict[str, Any]:
|
|||
token: JWT Token from the Authorization header
|
||||
|
||||
Returns:
|
||||
User data
|
||||
User model instance
|
||||
|
||||
Raises:
|
||||
HTTPException: For invalid token or user
|
||||
|
|
@ -122,20 +122,20 @@ def _getUserBase(token: str = Depends(oauth2Scheme)) -> Dict[str, Any]:
|
|||
logger.warning(f"User {username} not found")
|
||||
raise credentialsException
|
||||
|
||||
if user.get("disabled", False):
|
||||
if user.disabled:
|
||||
logger.warning(f"User {username} is disabled")
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User is disabled")
|
||||
|
||||
# Ensure the user has the correct context
|
||||
if str(user.get("mandateId")) != str(mandateId) or str(user.get("id")) != str(userId):
|
||||
logger.error(f"User context mismatch: token(mandateId={mandateId}, userId={userId}) vs user(mandateId={user.get('mandateId')}, id={user.get('id')})")
|
||||
if str(user.mandateId) != str(mandateId) or str(user.id) != str(userId):
|
||||
logger.error(f"User context mismatch: token(mandateId={mandateId}, userId={userId}) vs user(mandateId={user.mandateId}, id={user.id})")
|
||||
raise credentialsException
|
||||
|
||||
return user
|
||||
|
||||
def getCurrentUser(currentUser: Dict[str, Any] = Depends(_getUserBase)) -> Dict[str, Any]:
|
||||
def getCurrentUser(currentUser: User = Depends(_getUserBase)) -> User:
|
||||
"""Get current active user with additional validation."""
|
||||
if currentUser.get("disabled", False):
|
||||
if currentUser.disabled:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="User is disabled"
|
||||
|
|
@ -155,7 +155,17 @@ def createUserSession(userId: str, tokenId: str, request: Request) -> Session:
|
|||
)
|
||||
|
||||
# Save session to database
|
||||
appInterface.db.recordCreate("sessions", session.model_dump())
|
||||
appInterface.db.recordCreate("sessions", session.to_dict())
|
||||
|
||||
# Log auth event
|
||||
event = AuthEvent(
|
||||
userId=userId,
|
||||
eventType="login",
|
||||
details={"method": "local"},
|
||||
ipAddress=request.client.host if request.client else None,
|
||||
userAgent=request.headers.get("user-agent")
|
||||
)
|
||||
appInterface.db.recordCreate("auth_events", event.to_dict())
|
||||
|
||||
return session
|
||||
|
||||
|
|
@ -172,7 +182,7 @@ def logAuthEvent(userId: str, eventType: str, details: Dict[str, Any], request:
|
|||
)
|
||||
|
||||
# Save event to database
|
||||
appInterface.db.recordCreate("auth_events", event.model_dump())
|
||||
appInterface.db.recordCreate("auth_events", event.to_dict())
|
||||
|
||||
def validateSession(sessionId: str) -> bool:
|
||||
"""Validate a user session."""
|
||||
|
|
|
|||
|
|
@ -21,7 +21,10 @@ class BaseModelWithUI(BaseModel):
|
|||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary with proper validation"""
|
||||
return self.model_dump()
|
||||
# Handle both Pydantic v1 and v2
|
||||
if hasattr(self, 'model_dump'):
|
||||
return self.model_dump() # Pydantic v2
|
||||
return self.dict() # Pydantic v1
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'BaseModelWithUI':
|
||||
|
|
@ -29,23 +32,47 @@ class BaseModelWithUI(BaseModel):
|
|||
return cls(**data)
|
||||
|
||||
@classmethod
|
||||
def getModelAttributeDefinitions(cls) -> Dict[str, Any]:
|
||||
def getModelAttributeDefinitions(cls) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get attribute definitions for this model class.
|
||||
Override this method in model classes to provide custom attribute definitions.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Dictionary of attribute definitions
|
||||
List[Dict[str, Any]]: List of attribute definitions
|
||||
"""
|
||||
return {
|
||||
name: {
|
||||
attributes = []
|
||||
|
||||
# Handle both Pydantic v1 and v2
|
||||
if hasattr(cls, 'model_fields'): # Pydantic v2
|
||||
fields = cls.model_fields
|
||||
for name, field in fields.items():
|
||||
attributes.append({
|
||||
"name": name,
|
||||
"type": field.annotation.__name__ if hasattr(field.annotation, "__name__") else str(field.annotation),
|
||||
"required": field.is_required() if hasattr(field, "is_required") else True,
|
||||
"description": field.description if hasattr(field, "description") else "",
|
||||
"label": cls.fieldLabels.get(name, Label(default=name)).getLabel() if hasattr(cls, "fieldLabels") else name
|
||||
}
|
||||
for name, field in cls.model_fields.items()
|
||||
}
|
||||
"label": cls.fieldLabels.get(name, Label(default=name)).getLabel() if hasattr(cls, "fieldLabels") else name,
|
||||
"placeholder": f"Please enter {name}",
|
||||
"editable": True,
|
||||
"visible": True,
|
||||
"order": len(attributes)
|
||||
})
|
||||
else: # Pydantic v1
|
||||
fields = cls.__fields__
|
||||
for name, field in fields.items():
|
||||
attributes.append({
|
||||
"name": name,
|
||||
"type": field.type_.__name__ if hasattr(field.type_, "__name__") else str(field.type_),
|
||||
"required": field.required,
|
||||
"description": field.field_info.description if hasattr(field.field_info, "description") else "",
|
||||
"label": cls.fieldLabels.get(name, Label(default=name)).getLabel() if hasattr(cls, "fieldLabels") else name,
|
||||
"placeholder": f"Please enter {name}",
|
||||
"editable": True,
|
||||
"visible": True,
|
||||
"order": len(attributes)
|
||||
})
|
||||
|
||||
return attributes
|
||||
|
||||
def getModelAttributes(modelClass):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -10,7 +10,14 @@ import logging
|
|||
from typing import Any, Dict, Optional
|
||||
from pathlib import Path
|
||||
|
||||
# Set up logging
|
||||
# Set up basic logging for configuration loading
|
||||
logging.basicConfig(
|
||||
level=logging.WARNING,
|
||||
format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||
handlers=[logging.StreamHandler()]
|
||||
)
|
||||
|
||||
# Configure logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class Configuration:
|
||||
|
|
@ -34,11 +41,8 @@ class Configuration:
|
|||
|
||||
def _loadConfig(self):
|
||||
"""Load configuration from config.ini file in flattened format"""
|
||||
# Find config.ini file (look in current directory and parent directory)
|
||||
configPath = Path('config.ini')
|
||||
if not configPath.exists():
|
||||
# Try in parent directory
|
||||
configPath = Path('../config.ini')
|
||||
# Find config.ini file in the gateway directory
|
||||
configPath = Path(__file__).parent.parent.parent / 'config.ini'
|
||||
if not configPath.exists():
|
||||
logger.warning(f"Configuration file not found at {configPath.absolute()}")
|
||||
return
|
||||
|
|
@ -75,11 +79,8 @@ class Configuration:
|
|||
|
||||
def _loadEnv(self):
|
||||
"""Load environment variables from .env file"""
|
||||
# Find .env file (look in current directory and parent directory)
|
||||
envPath = Path('.env')
|
||||
if not envPath.exists():
|
||||
# Try in parent directory
|
||||
envPath = Path('../.env')
|
||||
# Find .env file in the gateway directory
|
||||
envPath = Path(__file__).parent.parent.parent / 'env_dev.env'
|
||||
if not envPath.exists():
|
||||
logger.warning(f"Environment file not found at {envPath.absolute()}")
|
||||
return
|
||||
|
|
|
|||
|
|
@ -192,7 +192,7 @@ class AgentManager:
|
|||
performance={},
|
||||
progress=0.0
|
||||
),
|
||||
Task(**{**task.model_dump(), "status": "failed", "error": error_msg})
|
||||
Task(**{**task.to_dict(), "status": "failed", "error": error_msg})
|
||||
)
|
||||
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -1,155 +0,0 @@
|
|||
"""
|
||||
Agent Registry Module for managing and initializing agents.
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
import importlib
|
||||
from typing import Dict, Any, List, Optional
|
||||
from modules.workflow.agentBase import AgentBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class AgentRegistry:
|
||||
"""Central registry for all available agents in the system."""
|
||||
|
||||
_instance = None
|
||||
|
||||
@classmethod
|
||||
def getInstance(cls):
|
||||
"""Return a singleton instance of the agent registry."""
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the agent registry."""
|
||||
if AgentRegistry._instance is not None:
|
||||
raise RuntimeError("Singleton instance already exists - use getInstance()")
|
||||
|
||||
self.agents: Dict[str, AgentBase] = {}
|
||||
self._loadAgents()
|
||||
|
||||
def initialize(self, service=None):
|
||||
"""Initialize or update the registry with workflow manager and service references."""
|
||||
if service:
|
||||
# Validate required interfaces
|
||||
required_interfaces = ['base', 'msft', 'google']
|
||||
missing_interfaces = []
|
||||
for interface in required_interfaces:
|
||||
if not hasattr(service, interface):
|
||||
missing_interfaces.append(interface)
|
||||
|
||||
if missing_interfaces:
|
||||
logger.warning(f"Service container missing required interfaces: {', '.join(missing_interfaces)}")
|
||||
return False
|
||||
|
||||
# Initialize agents with service
|
||||
for agent in self.agents.values():
|
||||
if service and hasattr(agent, 'setService'):
|
||||
agent.setService(service)
|
||||
|
||||
return True
|
||||
|
||||
def _loadAgents(self):
|
||||
"""Load all available agents from modules."""
|
||||
logger.info("Loading agent modules...")
|
||||
|
||||
# List of agent modules to load
|
||||
agentModules = []
|
||||
agentDir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "agents")
|
||||
|
||||
# Search the directory for agent modules
|
||||
for filename in os.listdir(agentDir):
|
||||
if filename.startswith("agent") and filename.endswith(".py"):
|
||||
agentModules.append(filename[0:-3]) # Remove .py extension
|
||||
|
||||
if not agentModules:
|
||||
logger.warning("No agent modules found")
|
||||
return
|
||||
|
||||
logger.info(f"{len(agentModules)} agent modules found")
|
||||
|
||||
# Load each agent module
|
||||
for moduleName in agentModules:
|
||||
try:
|
||||
# Import the module
|
||||
module = importlib.import_module(f"modules.agents.{moduleName}")
|
||||
|
||||
# Look for agent class or get_*_agent function
|
||||
agentName = moduleName.split("agent")[-1]
|
||||
className = f"Agent{agentName}"
|
||||
getterName = f"getAgent{agentName}"
|
||||
|
||||
agent = None
|
||||
|
||||
# Try to get the agent via the get*Agent function
|
||||
if hasattr(module, getterName):
|
||||
getterFunc = getattr(module, getterName)
|
||||
agent = getterFunc()
|
||||
logger.info(f"Agent '{agent.name}' loaded via {getterName}()")
|
||||
|
||||
# Alternatively, try to instantiate the agent directly
|
||||
elif hasattr(module, className):
|
||||
agentClass = getattr(module, className)
|
||||
agent = agentClass()
|
||||
logger.info(f"Agent '{agent.name}' directly instantiated")
|
||||
|
||||
if agent:
|
||||
# Register the agent
|
||||
self.registerAgent(agent)
|
||||
else:
|
||||
logger.warning(f"No agent class or getter function found in module {moduleName}")
|
||||
|
||||
except ImportError as e:
|
||||
logger.error(f"Module {moduleName} could not be imported: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading agent from module {moduleName}: {e}")
|
||||
|
||||
def registerAgent(self, agent):
|
||||
"""
|
||||
Register an agent in the registry.
|
||||
|
||||
Args:
|
||||
agent: The agent to register
|
||||
"""
|
||||
agentId = getattr(agent, 'name', "unknown_agent")
|
||||
self.agents[agentId] = agent
|
||||
logger.debug(f"Agent '{agent.name}' registered")
|
||||
|
||||
def getAgent(self, agentIdentifier: str):
|
||||
"""
|
||||
Return an agent instance
|
||||
Args:
|
||||
agentIdentifier: ID or type of the desired agent
|
||||
Returns:
|
||||
Agent instance or None if not found
|
||||
"""
|
||||
if agentIdentifier in self.agents:
|
||||
return self.agents[agentIdentifier]
|
||||
logger.error(f"Agent with identifier '{agentIdentifier}' not found")
|
||||
return None
|
||||
|
||||
def getAllAgents(self) -> Dict[str, AgentBase]:
|
||||
"""
|
||||
Get all registered agents.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping agent names to agent instances
|
||||
"""
|
||||
return self.agents.copy()
|
||||
|
||||
def getAgentInfos(self) -> List[Dict[str, Any]]:
|
||||
"""Return information about all registered agents."""
|
||||
agentInfos = []
|
||||
seenAgents = set()
|
||||
for agent in self.agents.values():
|
||||
if agent not in seenAgents:
|
||||
agentInfos.append(agent.getAgentInfo())
|
||||
seenAgents.add(agent)
|
||||
return agentInfos
|
||||
|
||||
|
||||
# Singleton factory for the agent registry
|
||||
def getAgentRegistry():
|
||||
return AgentRegistry.getInstance()
|
||||
|
|
@ -20,7 +20,7 @@ from modules.workflow.taskManager import getTaskManager
|
|||
from modules.workflow.documentManager import getDocumentManager
|
||||
from modules.interfaces.serviceChatModel import (
|
||||
UserInputRequest, ChatWorkflow, ChatMessage, ChatLog,
|
||||
ChatDocument, ChatStat, Workflow, Task, AgentResponse
|
||||
ChatDocument, ChatStat, Task, AgentResponse, AgentProfile
|
||||
)
|
||||
|
||||
# Configure logger
|
||||
|
|
@ -360,8 +360,8 @@ class WorkflowManager:
|
|||
self.service.functions.updateWorkflow(workflow.id, {
|
||||
"status": workflow.status,
|
||||
"lastActivity": workflow.lastActivity,
|
||||
"stats": workflow.stats.model_dump(),
|
||||
"messages": [msg.model_dump() for msg in workflow.messages]
|
||||
"stats": workflow.stats.to_dict(),
|
||||
"messages": [msg.to_dict() for msg in workflow.messages]
|
||||
})
|
||||
|
||||
return workflow
|
||||
|
|
@ -380,7 +380,7 @@ class WorkflowManager:
|
|||
self.service.functions.updateWorkflow(workflow.id, {
|
||||
"status": "failed",
|
||||
"lastActivity": workflow.lastActivity,
|
||||
"stats": workflow.stats.model_dump()
|
||||
"stats": workflow.stats.to_dict()
|
||||
})
|
||||
|
||||
self.logAdd(workflow, f"Workflow failed: {str(e)}", level="error", progress=100)
|
||||
|
|
@ -421,7 +421,7 @@ class WorkflowManager:
|
|||
)
|
||||
|
||||
# Save to database - only the workflow metadata
|
||||
workflowDb = workflow.model_dump()
|
||||
workflowDb = workflow.to_dict()
|
||||
self.service.functions.createWorkflow(workflowDb)
|
||||
|
||||
self.logAdd(workflow, GLOBAL_WORKFLOW_LABELS["workflowStatusMessages"]["init"], level="info", progress=0)
|
||||
|
|
@ -456,7 +456,7 @@ class WorkflowManager:
|
|||
"status": workflow.status,
|
||||
"lastActivity": workflow.lastActivity,
|
||||
"currentRound": workflow.currentRound,
|
||||
"stats": workflow.stats.model_dump() # Include updated stats
|
||||
"stats": workflow.stats.to_dict() # Include updated stats
|
||||
}
|
||||
self.service.functions.updateWorkflow(workflowId, workflowUpdate)
|
||||
|
||||
|
|
@ -623,7 +623,7 @@ JSON_OUTPUT = {{
|
|||
logger.debug(f"PROJECT MANAGER Planning answer: {projectManagerOutput}")
|
||||
return self.parseJsonResponse(projectManagerOutput)
|
||||
|
||||
async def agentProcessing(self, task: Dict[str, Any], workflow: ChatWorkflow) -> List[Dict[str, Any]]:
|
||||
async def agentProcessing(self, task: Task, workflow: ChatWorkflow) -> List[ChatDocument]:
|
||||
"""
|
||||
Process a single agent task from the workflow (State 5: Agent Execution).
|
||||
Uses the new Task and AgentResponse models.
|
||||
|
|
@ -660,7 +660,7 @@ JSON_OUTPUT = {{
|
|||
|
||||
# Update in database
|
||||
self.service.functions.updateWorkflow(workflow.id, {
|
||||
"stats": workflow.stats.model_dump()
|
||||
"stats": workflow.stats.to_dict()
|
||||
})
|
||||
|
||||
# Log the agent response
|
||||
|
|
@ -685,7 +685,7 @@ JSON_OUTPUT = {{
|
|||
self.logAdd(workflow, errorMsg, level="error")
|
||||
return []
|
||||
|
||||
async def generateFinalMessage(self, objUserResponse: str, objFinalDocuments: List[str], objResults: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
async def generateFinalMessage(self, objUserResponse: str, objFinalDocuments: List[str], objResults: List[Dict[str, Any]]) -> ChatMessage:
|
||||
"""
|
||||
Creates the final response message with review of promised and delivered documents (State 6: Final Response Generation).
|
||||
|
||||
|
|
@ -857,7 +857,7 @@ filesDelivered = {self.parseJson2text(matchingDocuments)}
|
|||
|
||||
# Update workflow in database
|
||||
self.service.functions.updateWorkflow(workflow.id, {
|
||||
"messages": [msg.model_dump() for msg in workflow.messages]
|
||||
"messages": [msg.to_dict() for msg in workflow.messages]
|
||||
})
|
||||
|
||||
return messageObject
|
||||
|
|
@ -931,7 +931,7 @@ filesDelivered = {self.parseJson2text(matchingDocuments)}
|
|||
workflow.stats.tokensUsed += tokensUsed
|
||||
|
||||
# Create ChatMessage object
|
||||
chatMessage = ChatMessage(**message.model_dump())
|
||||
chatMessage = ChatMessage(**message.to_dict())
|
||||
|
||||
# Add message to workflow
|
||||
workflow.messages.append(chatMessage)
|
||||
|
|
@ -947,13 +947,13 @@ filesDelivered = {self.parseJson2text(matchingDocuments)}
|
|||
workflow.lastActivity = currentTime
|
||||
|
||||
# Save to database - first the message itself
|
||||
self.service.functions.createWorkflowMessage(chatMessage.model_dump())
|
||||
self.service.functions.createWorkflowMessage(chatMessage.to_dict())
|
||||
|
||||
# Then save the workflow with updated references and statistics
|
||||
workflowUpdate = {
|
||||
"lastActivity": currentTime,
|
||||
"messageIds": workflow.messageIds,
|
||||
"stats": workflow.stats.model_dump() # Include updated statistics
|
||||
"stats": workflow.stats.to_dict() # Include updated statistics
|
||||
}
|
||||
self.service.functions.updateWorkflow(workflow.id, workflowUpdate)
|
||||
|
||||
|
|
@ -1018,7 +1018,7 @@ filesDelivered = {self.parseJson2text(matchingDocuments)}
|
|||
logger.log(logLevel, f"[Workflow {workflow.id}] {message}")
|
||||
|
||||
# Save to database
|
||||
self.service.functions.saveWorkflowLog(workflow.id, logEntry.model_dump())
|
||||
self.service.functions.saveWorkflowLog(workflow.id, logEntry.to_dict())
|
||||
|
||||
return logId
|
||||
|
||||
|
|
@ -1109,7 +1109,7 @@ filesDelivered = {self.parseJson2text(matchingDocuments)}
|
|||
|
||||
return fileIds
|
||||
|
||||
def getAvailableDocuments(self, workflow: ChatWorkflow, messageUser: ChatMessage) -> List[Dict[str, Any]]:
|
||||
def getAvailableDocuments(self, workflow: ChatWorkflow, messageUser: ChatMessage) -> List[ChatDocument]:
|
||||
"""
|
||||
Determines all currently available documents from user input and already generated documents.
|
||||
|
||||
|
|
@ -1171,7 +1171,7 @@ filesDelivered = {self.parseJson2text(matchingDocuments)}
|
|||
logger.info(f"Available documents: {len(availableDocs)}")
|
||||
return availableDocs
|
||||
|
||||
def agentProfiles(self) -> List[Dict[str, Any]]:
|
||||
def agentProfiles(self) -> List[AgentProfile]:
|
||||
"""
|
||||
Gets information about all available agents.
|
||||
|
||||
|
|
@ -1247,7 +1247,7 @@ filesDelivered = {self.parseJson2text(matchingDocuments)}
|
|||
"userLanguage": "en"
|
||||
}
|
||||
|
||||
def _createWorkflowData(self, workflow: ChatWorkflow) -> Dict[str, Any]:
|
||||
def _createWorkflowData(self, workflow: ChatWorkflow) -> ChatWorkflow:
|
||||
"""Creates a workflow data structure."""
|
||||
return {
|
||||
"mandateId": self.functions.mandateId,
|
||||
|
|
@ -1256,7 +1256,7 @@ filesDelivered = {self.parseJson2text(matchingDocuments)}
|
|||
"status": workflow.status,
|
||||
"startedAt": workflow.startedAt,
|
||||
"lastActivity": workflow.lastActivity,
|
||||
"stats": workflow.stats.model_dump()
|
||||
"stats": workflow.stats.to_dict()
|
||||
}
|
||||
|
||||
def _checkFileAccess(self, fileId: int) -> bool:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,13 @@
|
|||
....................... TASKS
|
||||
|
||||
Agents and Manager:
|
||||
- To adapt prompts to match document handling, done by agents
|
||||
- agents to use service object and to work stepwise:
|
||||
1. to extract document content with prompts
|
||||
2. to run ai propmt with integrated content-data in the prompt, including document reference (name, id)
|
||||
3. to analyse success and to give back instruction to task manager
|
||||
4. task manager to add a task based on agents result and feedback
|
||||
- document extraction to have error handling for big documents. if document too large, then to get content in pieces - depending on document type
|
||||
|
||||
Walkthroughs:
|
||||
- register
|
||||
|
|
@ -8,6 +16,9 @@ Walkthroughs:
|
|||
- management pages
|
||||
- workflow
|
||||
|
||||
Install a Test environment with same prod_env
|
||||
- add CORS url names to prod_env
|
||||
-
|
||||
|
||||
----------------------- OPEN
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue