refactory complete to review and test
This commit is contained in:
parent
e6ca2bad17
commit
628cca0ed4
26 changed files with 1981 additions and 1179 deletions
3
app.py
3
app.py
|
|
@ -119,3 +119,6 @@ app.include_router(workflowRouter)
|
|||
|
||||
from modules.routes.routeMsft import router as msftRouter
|
||||
app.include_router(msftRouter)
|
||||
|
||||
from modules.routes.routeGoogle import router as googleRouter
|
||||
app.include_router(googleRouter)
|
||||
|
|
|
|||
|
|
@ -50,3 +50,8 @@ Agent_Coder_EXECUTION_RETRY = 5
|
|||
Service_MSFT_CLIENT_ID = c7e7112d-61dc-4f3a-8cd3-08cc4cd7504c
|
||||
Service_MSFT_CLIENT_SECRET = Kxf8Q~2lJIteZ~JaI32kMf1lfaWKATqxXiNiFbzV
|
||||
Service_MSFT_TENANT_ID = common
|
||||
|
||||
# Google Service configuration
|
||||
Service_GOOGLE_CLIENT_ID = your-google-client-id
|
||||
Service_GOOGLE_CLIENT_SECRET = your-google-client-secret
|
||||
Service_GOOGLE_REDIRECT_URI = http://localhost:8000/api/google/auth/callback
|
||||
|
|
|
|||
|
|
@ -40,5 +40,6 @@ APP_LOGGING_FILE_ENABLED = True
|
|||
APP_LOGGING_ROTATION_SIZE = 10485760
|
||||
APP_LOGGING_BACKUP_COUNT = 5
|
||||
|
||||
# Agent Mail
|
||||
# Service Redirects
|
||||
Service_MSFT_REDIRECT_URI = http://localhost:8000/api/msft/auth/callback
|
||||
Service_GOOGLE_REDIRECT_URI = http://localhost:8000/api/google/auth/callback
|
||||
|
|
@ -40,5 +40,6 @@ APP_LOGGING_FILE_ENABLED = True
|
|||
APP_LOGGING_ROTATION_SIZE = 10485760
|
||||
APP_LOGGING_BACKUP_COUNT = 5
|
||||
|
||||
# Service MSFT
|
||||
# Service Redirects
|
||||
Service_MSFT_REDIRECT_URI = https://gateway.poweron-center.net/api/msft/auth/callback
|
||||
Service_GOOGLE_REDIRECT_URI = http://gateway.poweron-center.net/api/google/auth/callback
|
||||
|
|
|
|||
|
|
@ -7,12 +7,15 @@ import logging
|
|||
import json
|
||||
import io
|
||||
import base64
|
||||
from typing import Dict, Any, List
|
||||
import os
|
||||
import time
|
||||
from typing import Dict, Any, List, Optional
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
|
||||
from modules.workflow.agentBase import AgentBase
|
||||
from modules.interfaces.lucydomModel import ChatContent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -634,97 +637,33 @@ class AgentAnalyst(AgentBase):
|
|||
return self.formatAgentDocumentOutput(outputLabel, imgData, f"image/{formatType}")
|
||||
|
||||
async def _createDataDocument(self, datasets: Dict, prompt: str, outputLabel: str,
|
||||
analysisPlan: Dict, description: str) -> Dict:
|
||||
analysisPlan: Dict, description: str) -> ChatContent:
|
||||
"""
|
||||
Create a data document (e.g., CSV, JSON) based on analysis.
|
||||
Create a data document (CSV, JSON, Excel) from analysis results.
|
||||
|
||||
Args:
|
||||
datasets: Dictionary of datasets
|
||||
prompt: Original task prompt
|
||||
outputLabel: Output filename
|
||||
analysisPlan: Analysis plan from AI
|
||||
analysisPlan: Analysis plan
|
||||
description: Output description
|
||||
|
||||
Returns:
|
||||
Data document
|
||||
ChatContent object
|
||||
"""
|
||||
# Determine format from filename
|
||||
formatType = outputLabel.split('.')[-1].lower()
|
||||
|
||||
# If no datasets available, return error message
|
||||
if not datasets:
|
||||
return {
|
||||
"label": outputLabel,
|
||||
"content": f"No data available for processing into {formatType} format.",
|
||||
"metadata": {
|
||||
"contentType": "text/plain"
|
||||
}
|
||||
}
|
||||
|
||||
# Generate data processing instructions
|
||||
dataPrompt = f"""
|
||||
Create Python code to process datasets and generate a {formatType} file for:
|
||||
|
||||
TASK: {prompt}
|
||||
|
||||
OUTPUT REQUIREMENTS:
|
||||
- Format: {formatType}
|
||||
- Filename: {outputLabel}
|
||||
- Description: {description}
|
||||
|
||||
ANALYSIS CONTEXT:
|
||||
{json.dumps(analysisPlan, indent=2)}
|
||||
|
||||
AVAILABLE DATASETS:
|
||||
"""
|
||||
|
||||
# Add dataset info
|
||||
for name, df in datasets.items():
|
||||
dataPrompt += f"\nDataset '{name}':\n"
|
||||
dataPrompt += f"- Shape: {df.shape}\n"
|
||||
dataPrompt += f"- Columns: {df.columns.tolist()}\n"
|
||||
dataPrompt += f"- Sample data: {df.head(3).to_dict(orient='records')}\n"
|
||||
|
||||
dataPrompt += """
|
||||
Generate Python code that:
|
||||
1. Processes the available dataset(s)
|
||||
2. Performs necessary transformations, aggregations, or calculations
|
||||
3. Outputs the result in the requested format
|
||||
4. Returns the content as a string variable named 'result'
|
||||
|
||||
Return ONLY executable Python code, no explanations or markdown.
|
||||
"""
|
||||
|
||||
try:
|
||||
# Get data processing code from AI
|
||||
dataCode = await self.service.base.callAi([
|
||||
{"role": "system", "content": "You are a data processing expert. Provide only executable Python code."},
|
||||
{"role": "user", "content": dataPrompt}
|
||||
], produceUserAnswer = True)
|
||||
# Determine format from filename
|
||||
formatType = outputLabel.split('.')[-1].lower() if '.' in outputLabel else "csv"
|
||||
|
||||
# Clean code
|
||||
dataCode = dataCode.replace("```python", "").replace("```", "").strip()
|
||||
|
||||
# Setup execution environment
|
||||
localVars = {"pd": pd, "np": __import__('numpy'), "io": io}
|
||||
|
||||
# Add datasets to local variables
|
||||
for name, df in datasets.items():
|
||||
# Create a sanitized variable name
|
||||
varName = ''.join(c if c.isalnum() else '_' for c in name)
|
||||
localVars[varName] = df
|
||||
|
||||
# Also add with standard names for simpler code
|
||||
if "df" not in localVars:
|
||||
localVars["df"] = df
|
||||
elif "df2" not in localVars:
|
||||
localVars["df2"] = df
|
||||
|
||||
# Execute the code
|
||||
exec(dataCode, globals(), localVars)
|
||||
|
||||
# Get the result
|
||||
result = localVars.get("result", "No output was generated.")
|
||||
# Process data based on format
|
||||
if formatType == "csv":
|
||||
result = self._convertToCsv(datasets)
|
||||
elif formatType == "json":
|
||||
result = json.dumps(datasets, indent=2)
|
||||
elif formatType == "xlsx":
|
||||
result = self._convertToExcel(datasets)
|
||||
else:
|
||||
result = str(datasets)
|
||||
|
||||
# Determine content type
|
||||
contentType = "text/csv" if formatType == "csv" else \
|
||||
|
|
@ -734,58 +673,41 @@ class AgentAnalyst(AgentBase):
|
|||
|
||||
return self.formatAgentDocumentOutput(outputLabel, result, contentType)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating data document: {str(e)}", exc_info=True)
|
||||
|
||||
return {
|
||||
"label": outputLabel,
|
||||
"content": f"Error generating {formatType} document: {str(e)}",
|
||||
"metadata": {
|
||||
"contentType": "text/plain"
|
||||
}
|
||||
}
|
||||
errorContent = f"Error generating {formatType} document: {str(e)}"
|
||||
return self.formatAgentDocumentOutput(outputLabel, errorContent, "text/plain")
|
||||
|
||||
async def _createTextDocument(self, datasets: Dict, context: str, prompt: str,
|
||||
outputLabel: str, formatType: str,
|
||||
analysisPlan: Dict, description: str) -> Dict:
|
||||
analysisPlan: Dict, description: str) -> ChatContent:
|
||||
"""
|
||||
Create a text document (report, analysis, etc.) based on analysis.
|
||||
Create a text document (markdown, HTML, text) from analysis results.
|
||||
|
||||
Args:
|
||||
datasets: Dictionary of datasets
|
||||
context: Document context text
|
||||
context: Document context
|
||||
prompt: Original task prompt
|
||||
outputLabel: Output filename
|
||||
formatType: Output format type
|
||||
analysisPlan: Analysis plan from AI
|
||||
formatType: Output format
|
||||
analysisPlan: Analysis plan
|
||||
description: Output description
|
||||
|
||||
Returns:
|
||||
Text document
|
||||
ChatContent object
|
||||
"""
|
||||
# Create dataset summaries
|
||||
try:
|
||||
# Generate dataset summaries
|
||||
datasetSummaries = []
|
||||
for name, df in datasets.items():
|
||||
summary = f"Dataset: {name}\n"
|
||||
summary += f"- Shape: {df.shape[0]} rows, {df.shape[1]} columns\n"
|
||||
summary += f"- Columns: {', '.join(df.columns.tolist())}\n"
|
||||
|
||||
# Basic statistics for numeric columns
|
||||
numericCols = df.select_dtypes(include=['number']).columns
|
||||
if len(numericCols) > 0:
|
||||
summary += "- Numeric Columns Stats:\n"
|
||||
for col in numericCols[:3]: # Limit to first 3
|
||||
stats = df[col].describe()
|
||||
summary += f" - {col}: min={stats['min']:.2f}, max={stats['max']:.2f}, mean={stats['mean']:.2f}\n"
|
||||
|
||||
summary = f"\nDataset: {name}\n"
|
||||
summary += f"Shape: {df.shape}\n"
|
||||
summary += f"Columns: {', '.join(df.columns)}\n"
|
||||
if not df.empty:
|
||||
summary += f"Sample data:\n{df.head(3).to_string()}\n"
|
||||
datasetSummaries.append(summary)
|
||||
|
||||
# Determine content type based on format
|
||||
contentType = "text/markdown" if formatType in ["md", "markdown"] else \
|
||||
"text/html" if formatType == "html" else \
|
||||
"text/plain"
|
||||
|
||||
# Generate analysis prompt
|
||||
analysisPrompt = f"""
|
||||
Create a detailed {formatType} document for:
|
||||
|
|
@ -816,7 +738,6 @@ class AgentAnalyst(AgentBase):
|
|||
Your response should be the complete document content in the specified format.
|
||||
"""
|
||||
|
||||
try:
|
||||
# Get document content from AI
|
||||
documentContent = await self.service.base.callAi([
|
||||
{"role": "system", "content": f"You are a data analysis expert creating a {formatType} document."},
|
||||
|
|
@ -829,6 +750,11 @@ class AgentAnalyst(AgentBase):
|
|||
elif formatType == "html" and not "<html" in documentContent.lower():
|
||||
documentContent = f"<html><body>{documentContent}</body></html>"
|
||||
|
||||
# Determine content type
|
||||
contentType = "text/markdown" if formatType in ["md", "markdown"] else \
|
||||
"text/html" if formatType == "html" else \
|
||||
"text/plain"
|
||||
|
||||
return self.formatAgentDocumentOutput(outputLabel, documentContent, contentType)
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -842,13 +768,7 @@ class AgentAnalyst(AgentBase):
|
|||
else:
|
||||
content = f"Error in Analysis\n\nThere was an error generating the analysis: {str(e)}"
|
||||
|
||||
return {
|
||||
"label": outputLabel,
|
||||
"content": content,
|
||||
"metadata": {
|
||||
"contentType": contentType
|
||||
}
|
||||
}
|
||||
return self.formatAgentDocumentOutput(outputLabel, content, contentType)
|
||||
|
||||
def _getImageBase64(self, formatType: str = 'png') -> str:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -8,8 +8,10 @@ from typing import Dict, Any, List
|
|||
import json
|
||||
import re
|
||||
from datetime import datetime
|
||||
import os
|
||||
|
||||
from modules.workflow.agentBase import AgentBase
|
||||
from modules.interfaces.lucydomModel import ChatContent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -292,7 +294,7 @@ class AgentDocumentation(AgentBase):
|
|||
}
|
||||
|
||||
async def _createDocumentMultiStep(self, prompt: str, context: str, outputLabel: str,
|
||||
outputDescription: str, documentationPlan: Dict) -> Dict:
|
||||
outputDescription: str, documentationPlan: Dict) -> ChatContent:
|
||||
"""
|
||||
Create a document using a multi-step approach with separate AI calls for each section.
|
||||
|
||||
|
|
@ -304,8 +306,9 @@ class AgentDocumentation(AgentBase):
|
|||
documentationPlan: Documentation plan from AI
|
||||
|
||||
Returns:
|
||||
Document object
|
||||
ChatContent object
|
||||
"""
|
||||
try:
|
||||
# Determine format from filename
|
||||
formatType = outputLabel.split('.')[-1].lower() if '.' in outputLabel else "md"
|
||||
|
||||
|
|
@ -332,56 +335,8 @@ class AgentDocumentation(AgentBase):
|
|||
|
||||
# Get the detailed structure
|
||||
detailedStructure = documentationPlan.get("detailedStructure", [])
|
||||
if not detailedStructure:
|
||||
# Fallback structure if none provided
|
||||
detailedStructure = [
|
||||
{
|
||||
"title": "Introduction (Default)",
|
||||
"keyPoints": ["Purpose", "Scope"],
|
||||
"importance": "high"
|
||||
},
|
||||
{
|
||||
"title": "Main Content (Default)",
|
||||
"keyPoints": ["Core Information"],
|
||||
"importance": "high"
|
||||
},
|
||||
{
|
||||
"title": "Conclusion (Default)",
|
||||
"keyPoints": ["Summary", "Next Steps"],
|
||||
"importance": "medium"
|
||||
}
|
||||
]
|
||||
|
||||
try:
|
||||
# Step 1: Generate document introduction
|
||||
introPrompt = f"""
|
||||
Create the introduction for a {documentType} titled "{title}".
|
||||
|
||||
DOCUMENT OVERVIEW:
|
||||
- Type: {documentType}
|
||||
- Audience: {audience}
|
||||
- Tone: {tone}
|
||||
- Key Topics: {', '.join(keyTopics)}
|
||||
- Format: {formatType}
|
||||
|
||||
TASK CONTEXT: {prompt}
|
||||
|
||||
This introduction should:
|
||||
1. Clearly state the purpose and scope of the document
|
||||
2. Provide context and background information
|
||||
3. Outline what the reader will find in the document
|
||||
4. Set the appropriate tone for the {audience} audience
|
||||
|
||||
The introduction should be professional and engaging, but short and precise, formatted according to {formatType} standards. do not add details, which are not requested by the Task Context.
|
||||
"""
|
||||
|
||||
introduction = await self.service.base.callAi([
|
||||
{"role": "system", "content": f"You are a documentation expert creating an introduction in {formatType} format."},
|
||||
{"role": "user", "content": introPrompt}
|
||||
], produceUserAnswer = True)
|
||||
|
||||
# Step 2: Generate executive summary (if applicable)
|
||||
if documentType in ["report", "whitepaper", "case study"]:
|
||||
# Step 1: Generate executive summary
|
||||
summaryPrompt = f"""
|
||||
Create an executive summary for a {documentType} titled "{title}".
|
||||
|
||||
|
|
@ -392,67 +347,82 @@ The introduction should be professional and engaging, but short and precise, for
|
|||
|
||||
TASK CONTEXT: {prompt}
|
||||
|
||||
This executive summary should:
|
||||
1. Provide a concise overview of the entire document
|
||||
2. Highlight key findings, recommendations, or conclusions
|
||||
3. Be suitable for executives or busy readers who may only read this section
|
||||
4. Be professionally formatted according to {formatType} standards
|
||||
The executive summary should:
|
||||
1. Provide a concise overview of the document's purpose
|
||||
2. Highlight key points and findings
|
||||
3. Be clear and engaging for the target audience
|
||||
4. Set expectations for the document's content
|
||||
|
||||
Keep the summary focused and impactful, approximately 200-300 words.
|
||||
Keep the summary brief but comprehensive.
|
||||
"""
|
||||
|
||||
executiveSummary = await self.service.base.callAi([
|
||||
{"role": "system", "content": f"You are a documentation expert creating an executive summary in {formatType} format."},
|
||||
{"role": "user", "content": summaryPrompt}
|
||||
], produceUserAnswer = True)
|
||||
else:
|
||||
executiveSummary = ""
|
||||
|
||||
# Step 3: Generate each section
|
||||
# Step 2: Generate introduction
|
||||
introPrompt = f"""
|
||||
Create an introduction for a {documentType} titled "{title}".
|
||||
|
||||
DOCUMENT OVERVIEW:
|
||||
- Type: {documentType}
|
||||
- Audience: {audience}
|
||||
- Key Topics: {', '.join(keyTopics)}
|
||||
|
||||
TASK CONTEXT: {prompt}
|
||||
|
||||
The introduction should:
|
||||
1. Set the context and purpose of the document
|
||||
2. Outline the scope and objectives
|
||||
3. Preview the main topics to be covered
|
||||
4. Engage the reader's interest
|
||||
|
||||
Format the introduction according to {formatType} standards.
|
||||
"""
|
||||
|
||||
introduction = await self.service.base.callAi([
|
||||
{"role": "system", "content": f"You are a documentation expert creating an introduction in {formatType} format."},
|
||||
{"role": "user", "content": introPrompt}
|
||||
], produceUserAnswer = True)
|
||||
|
||||
# Step 3: Generate main sections
|
||||
sections = []
|
||||
|
||||
for section in detailedStructure:
|
||||
sectionTitle = section.get("title", "Section")
|
||||
keyPoints = section.get("keyPoints", [])
|
||||
subsections = section.get("subsections", [])
|
||||
importance = section.get("importance", "medium")
|
||||
|
||||
# Adjust depth based on importance
|
||||
detailLevel = "high" if importance == "high" else "medium"
|
||||
estimatedLength = section.get("estimatedLength", "medium")
|
||||
|
||||
sectionPrompt = f"""
|
||||
Create the "{sectionTitle}" section for a {documentType} titled "{title}".
|
||||
Create the {sectionTitle} section for a {documentType} titled "{title}".
|
||||
|
||||
SECTION DETAILS:
|
||||
- Title: {sectionTitle}
|
||||
- Key Points to Cover: {', '.join(keyPoints)}
|
||||
- Key Points: {', '.join(keyPoints)}
|
||||
- Subsections: {', '.join(subsections)}
|
||||
- Detail Level: {detailLevel}
|
||||
- Importance: {importance}
|
||||
- Estimated Length: {estimatedLength}
|
||||
|
||||
DOCUMENT CONTEXT:
|
||||
- Type: {documentType}
|
||||
- Audience: {audience}
|
||||
- Tone: {tone}
|
||||
- Format: {formatType}
|
||||
- Key Topics: {', '.join(keyTopics)}
|
||||
|
||||
TASK CONTEXT: {prompt}
|
||||
|
||||
AVAILABLE INFORMATION:
|
||||
{context[:500]}... (truncated)
|
||||
The section should:
|
||||
1. Cover all key points thoroughly
|
||||
2. Include relevant subsections
|
||||
3. Maintain appropriate depth based on importance
|
||||
4. Follow the document's tone and style
|
||||
|
||||
This section should:
|
||||
1. Be comprehensive and well-structured
|
||||
2. Cover all the key points listed
|
||||
3. Include the specified subsections with appropriate headings
|
||||
4. Maintain a {tone} tone suitable for the {audience} audience
|
||||
5. Be properly formatted according to {formatType} standards
|
||||
6. Include specific examples, data, or evidence where appropriate
|
||||
|
||||
Be thorough in your coverage of this section, providing substantive content focussing on the Task content.
|
||||
Format the section according to {formatType} standards.
|
||||
"""
|
||||
|
||||
sectionContent = await self.service.base.callAi([
|
||||
{"role": "system", "content": f"You are a documentation expert creating detailed content for the {sectionTitle} section."},
|
||||
{"role": "system", "content": f"You are a documentation expert creating a section in {formatType} format."},
|
||||
{"role": "user", "content": sectionPrompt}
|
||||
], produceUserAnswer = True)
|
||||
|
||||
|
|
@ -548,13 +518,7 @@ Be thorough in your coverage of this section, providing substantive content focu
|
|||
else:
|
||||
content = f"Error in Documentation\n\nThere was an error generating the documentation: {str(e)}"
|
||||
|
||||
return {
|
||||
"label": outputLabel,
|
||||
"content": content,
|
||||
"metadata": {
|
||||
"contentType": contentType
|
||||
}
|
||||
}
|
||||
return self.formatAgentDocumentOutput(outputLabel, content, contentType)
|
||||
|
||||
|
||||
# Factory function for the Documentation agent
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from passlib.context import CryptContext
|
|||
from modules.connectors.connectorDbJson import DatabaseConnector
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
from modules.interfaces.gatewayAccess import GatewayAccess
|
||||
from modules.interfaces.gatewayModel import User, Mandate, UserInDB
|
||||
from modules.interfaces.gatewayModel import User, Mandate, UserInDB, UserConnection
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -22,6 +22,9 @@ logger = logging.getLogger(__name__)
|
|||
# Singleton factory for GatewayInterface instances per context
|
||||
_gatewayInterfaces = {}
|
||||
|
||||
# Root interface instance
|
||||
_rootGatewayInterface = None
|
||||
|
||||
# Password-Hashing
|
||||
pwdContext = CryptContext(schemes=["argon2"], deprecated="auto")
|
||||
|
||||
|
|
@ -32,18 +35,22 @@ class GatewayInterface:
|
|||
Manages users and mandates.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, currentUser: Dict[str, Any] = None):
|
||||
"""Initializes the Gateway Interface."""
|
||||
# Initialize variables
|
||||
self.currentUser = currentUser
|
||||
self.userId = currentUser.get("id") if currentUser else None
|
||||
self.access = None # Will be set when user context is provided
|
||||
|
||||
# Initialize database
|
||||
self._initializeDatabase()
|
||||
|
||||
# Initialize standard records if needed
|
||||
self._initRecords()
|
||||
|
||||
# Initialize variables
|
||||
self.currentUser = None
|
||||
self.userId = None
|
||||
self.access = None # Will be set when user context is provided
|
||||
# Set user context if provided
|
||||
if currentUser:
|
||||
self.setUserContext(currentUser)
|
||||
|
||||
def setUserContext(self, currentUser: Dict[str, Any]):
|
||||
"""Sets the user context for the interface."""
|
||||
|
|
@ -338,10 +345,117 @@ class GatewayInterface:
|
|||
|
||||
return User(**user)
|
||||
|
||||
def addUserConnection(self, userId: str, authority: str, externalId: str, externalUsername: str, externalEmail: Optional[str] = None) -> UserConnection:
|
||||
"""Add a new connection to an external service for a user"""
|
||||
try:
|
||||
# Get user
|
||||
user = self.getUser(userId)
|
||||
if not user:
|
||||
raise ValueError(f"User {userId} not found")
|
||||
|
||||
# Check if connection already exists
|
||||
for conn in user.connections:
|
||||
if conn.authority == authority and conn.externalId == externalId:
|
||||
raise ValueError(f"Connection to {authority} already exists for user {userId}")
|
||||
|
||||
# Create new connection
|
||||
connection = UserConnection(
|
||||
authority=authority,
|
||||
externalId=externalId,
|
||||
externalUsername=externalUsername,
|
||||
externalEmail=externalEmail
|
||||
)
|
||||
|
||||
# Add connection to user
|
||||
user.connections.append(connection)
|
||||
|
||||
# Update user record
|
||||
self.db.recordModify("users", userId, {"connections": [c.model_dump() for c in user.connections]})
|
||||
|
||||
return connection
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding user connection: {str(e)}")
|
||||
raise ValueError(f"Failed to add user connection: {str(e)}")
|
||||
|
||||
def removeUserConnection(self, userId: str, connectionId: str) -> None:
|
||||
"""Remove a connection to an external service for a user"""
|
||||
try:
|
||||
# Get user
|
||||
user = self.getUser(userId)
|
||||
if not user:
|
||||
raise ValueError(f"User {userId} not found")
|
||||
|
||||
# Find and remove connection
|
||||
user.connections = [c for c in user.connections if c.id != connectionId]
|
||||
|
||||
# Update user record
|
||||
self.db.recordModify("users", userId, {"connections": [c.model_dump() for c in user.connections]})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error removing user connection: {str(e)}")
|
||||
raise ValueError(f"Failed to remove user connection: {str(e)}")
|
||||
|
||||
def authenticateUser(self, username: str, password: str = None, authority: str = "local", external_token: str = None) -> Optional[User]:
|
||||
"""Authenticates a user by username and password or external authority."""
|
||||
# Clear the users table from cache and reload it
|
||||
if "users" in self.db._tablesCache:
|
||||
del self.db._tablesCache["users"]
|
||||
|
||||
# Get user by username
|
||||
user = self.getUserByUsername(username)
|
||||
|
||||
if not user:
|
||||
raise ValueError("User not found")
|
||||
|
||||
# Check if the user is disabled
|
||||
if user.disabled:
|
||||
raise ValueError("User is disabled")
|
||||
|
||||
# Handle authentication based on authority
|
||||
if authority == "local":
|
||||
if not password:
|
||||
raise ValueError("Password is required for local authentication")
|
||||
# Get the full user record with password hash for verification
|
||||
userWithPassword = UserInDB(**self.db.getRecordset("users", recordFilter={"id": user.id})[0])
|
||||
if not self._verifyPassword(password, userWithPassword.hashedPassword):
|
||||
raise ValueError("Invalid password")
|
||||
elif authority in ["microsoft", "google"]: # Support for multiple external auth providers
|
||||
# Verify that the user has the correct authentication authority
|
||||
if user.authenticationAuthority != authority:
|
||||
raise ValueError(f"User does not have {authority} authentication enabled")
|
||||
|
||||
# Verify that the user has a valid connection for this authority
|
||||
if not any(conn.authority == authority for conn in user.connections):
|
||||
raise ValueError(f"User does not have a valid {authority} connection")
|
||||
|
||||
# Verify the external token
|
||||
if not external_token:
|
||||
raise ValueError(f"External token is required for {authority} authentication")
|
||||
|
||||
# Get the appropriate auth service
|
||||
if authority == "microsoft":
|
||||
from .msftInterface import getInterface as getMsftInterface
|
||||
auth_service = getMsftInterface({"_mandateId": user._mandateId, "id": user.id})
|
||||
elif authority == "google":
|
||||
from .googleInterface import getInterface as getGoogleInterface
|
||||
auth_service = getGoogleInterface({"_mandateId": user._mandateId, "id": user.id})
|
||||
else:
|
||||
raise ValueError(f"Unsupported authentication authority: {authority}")
|
||||
|
||||
# Verify the token
|
||||
if not auth_service.verifyToken(external_token):
|
||||
raise ValueError(f"Invalid or expired {authority} token")
|
||||
else:
|
||||
raise ValueError(f"Unknown authentication authority: {authority}")
|
||||
|
||||
return user
|
||||
|
||||
def createUser(self, username: str, password: str = None, email: str = None, fullName: str = None,
|
||||
language: str = "en", disabled: bool = False,
|
||||
privilege: str = "user", authenticationAuthority: str = "local") -> User:
|
||||
"""Create a new user"""
|
||||
privilege: str = "user", authenticationAuthority: str = "local",
|
||||
externalId: str = None, externalUsername: str = None, externalEmail: str = None) -> User:
|
||||
"""Create a new user with optional external connection"""
|
||||
try:
|
||||
# Validate username
|
||||
if not username:
|
||||
|
|
@ -369,7 +483,8 @@ class GatewayInterface:
|
|||
disabled=disabled,
|
||||
privilege=privilege,
|
||||
authenticationAuthority=authenticationAuthority,
|
||||
hashedPassword=self._getPasswordHash(password) if authenticationAuthority == "local" else None
|
||||
hashedPassword=self._getPasswordHash(password) if authenticationAuthority == "local" else None,
|
||||
connections=[]
|
||||
)
|
||||
|
||||
# Create user record
|
||||
|
|
@ -377,6 +492,16 @@ class GatewayInterface:
|
|||
if not createdRecord or not createdRecord.get("id"):
|
||||
raise ValueError("Failed to create user record")
|
||||
|
||||
# Add external connection if provided
|
||||
if externalId and externalUsername:
|
||||
self.addUserConnection(
|
||||
createdRecord["id"],
|
||||
authenticationAuthority,
|
||||
externalId,
|
||||
externalUsername,
|
||||
externalEmail
|
||||
)
|
||||
|
||||
# Get created user using the returned ID
|
||||
createdUser = self.db.getRecordset("users", recordFilter={"id": createdRecord["id"]})
|
||||
if not createdUser or len(createdUser) == 0:
|
||||
|
|
@ -398,41 +523,6 @@ class GatewayInterface:
|
|||
logger.error(f"Unexpected error creating user: {str(e)}")
|
||||
raise ValueError(f"Failed to create user: {str(e)}")
|
||||
|
||||
def authenticateUser(self, username: str, password: str = None) -> Optional[User]:
|
||||
"""Authenticates a user by username and password."""
|
||||
# Clear the users table from cache and reload it
|
||||
if "users" in self.db._tablesCache:
|
||||
del self.db._tablesCache["users"]
|
||||
|
||||
# Get user by username
|
||||
user = self.getUserByUsername(username)
|
||||
|
||||
if not user:
|
||||
raise ValueError("Benutzer nicht gefunden")
|
||||
|
||||
# Check if the user is disabled
|
||||
if user.disabled:
|
||||
raise ValueError("Benutzer ist deaktiviert")
|
||||
|
||||
# Handle authentication based on authority
|
||||
auth_authority = user.authenticationAuthority
|
||||
|
||||
if auth_authority == "local":
|
||||
if not password:
|
||||
raise ValueError("Passwort ist erforderlich")
|
||||
# Get the full user record with password hash for verification
|
||||
userWithPassword = UserInDB(**self.db.getRecordset("users", recordFilter={"id": user.id})[0])
|
||||
if not self._verifyPassword(password, userWithPassword.hashedPassword):
|
||||
raise ValueError("Falsches Passwort")
|
||||
elif auth_authority == "microsoft":
|
||||
# For Microsoft users, we don't verify the password here
|
||||
# The authentication is handled by the Microsoft OAuth flow
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f"Unbekannte Authentifizierungsmethode: {auth_authority}")
|
||||
|
||||
return user
|
||||
|
||||
def updateUser(self, userId: str, userData: Dict[str, Any]) -> User:
|
||||
"""Updates a user if current user has permission."""
|
||||
# Check if the user exists and current user has access
|
||||
|
|
@ -525,21 +615,83 @@ class GatewayInterface:
|
|||
|
||||
return success
|
||||
|
||||
def getInterface(currentUser: Dict[str, Any] = None) -> 'GatewayInterface':
|
||||
def setupLocalAuth(self, userId: str, password: str) -> User:
|
||||
"""Set up local authentication for a user who registered with Microsoft"""
|
||||
try:
|
||||
# Get user
|
||||
user = self.getUser(userId)
|
||||
if not user:
|
||||
raise ValueError(f"User {userId} not found")
|
||||
|
||||
# Validate password
|
||||
if not password:
|
||||
raise ValueError("Password is required")
|
||||
if len(password) < 8:
|
||||
raise ValueError("Password must be at least 8 characters long")
|
||||
|
||||
# Update user with local password
|
||||
userData = {
|
||||
"hashedPassword": self._getPasswordHash(password),
|
||||
"authenticationAuthority": "local" # Change to local auth
|
||||
}
|
||||
|
||||
return self.updateUser(userId, userData)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error setting up local authentication: {str(e)}")
|
||||
raise ValueError(f"Failed to set up local authentication: {str(e)}")
|
||||
|
||||
|
||||
def getInterface(currentUser: Dict[str, Any]) -> GatewayInterface:
|
||||
"""
|
||||
Returns a GatewayInterface instance.
|
||||
If currentUser is provided, initializes with user context.
|
||||
Otherwise, returns an instance with only database access.
|
||||
Returns a GatewayInterface instance for the current user.
|
||||
Handles initialization of database and records.
|
||||
"""
|
||||
mandateId = currentUser.get("mandateId")
|
||||
userId = currentUser.get("id")
|
||||
if not mandateId or not userId:
|
||||
raise ValueError("Invalid user context: mandateId and id are required")
|
||||
|
||||
# Create context key
|
||||
contextKey = f"{mandateId}_{userId}"
|
||||
|
||||
# Create new instance if not exists
|
||||
if "default" not in _gatewayInterfaces:
|
||||
_gatewayInterfaces["default"] = GatewayInterface()
|
||||
if contextKey not in _gatewayInterfaces:
|
||||
_gatewayInterfaces[contextKey] = GatewayInterface(currentUser)
|
||||
|
||||
interface = _gatewayInterfaces["default"]
|
||||
return _gatewayInterfaces[contextKey]
|
||||
|
||||
if currentUser:
|
||||
interface.setUserContext(currentUser)
|
||||
else:
|
||||
logger.info("Returning interface without user context")
|
||||
def getRootUser() -> Dict[str, Any]:
|
||||
"""
|
||||
Returns the root user from the database.
|
||||
This is the user with the initial ID in the users table.
|
||||
"""
|
||||
try:
|
||||
readInterface = getInterface()
|
||||
# Get the initial user ID
|
||||
initialUserId = readInterface.db.getInitialId("users")
|
||||
if not initialUserId:
|
||||
raise ValueError("No initial user ID found in database")
|
||||
|
||||
return interface
|
||||
# Get the user record
|
||||
users = readInterface.db.getRecordset("users", recordFilter={"id": initialUserId})
|
||||
if not users:
|
||||
raise ValueError(f"Root user with ID {initialUserId} not found in database")
|
||||
|
||||
return users[0]
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting root user: {str(e)}")
|
||||
raise ValueError(f"Failed to get root user: {str(e)}")
|
||||
|
||||
def getRootInterface() -> GatewayInterface:
|
||||
"""
|
||||
Returns a GatewayInterface instance with root privileges.
|
||||
This is used for initial setup and user creation.
|
||||
"""
|
||||
global _rootGatewayInterface
|
||||
|
||||
if _rootGatewayInterface is None:
|
||||
rootUser = getRootUser()
|
||||
_rootGatewayInterface = GatewayInterface(rootUser)
|
||||
|
||||
return _rootGatewayInterface
|
||||
|
|
|
|||
|
|
@ -56,6 +56,30 @@ class Mandate(BaseModel):
|
|||
"language": Label(default="Language", translations={"en": "Language", "fr": "Langue"})
|
||||
}
|
||||
|
||||
class UserConnection(BaseModel):
|
||||
"""Data model for a user's connection to an external service"""
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Unique ID of the connection")
|
||||
authority: str = Field(description="Authentication authority (microsoft, google, etc.)")
|
||||
externalId: str = Field(description="User ID in the external system")
|
||||
externalUsername: str = Field(description="Username in the external system")
|
||||
externalEmail: Optional[str] = Field(None, description="Email in the external system")
|
||||
connectedAt: datetime = Field(default_factory=datetime.now, description="When the connection was established")
|
||||
|
||||
label: Label = Field(
|
||||
default=Label(default="User Connection", translations={"en": "User Connection", "fr": "Connexion utilisateur"}),
|
||||
description="Label for the class"
|
||||
)
|
||||
|
||||
# Labels for attributes
|
||||
fieldLabels: Dict[str, Label] = {
|
||||
"id": Label(default="ID", translations={}),
|
||||
"authority": Label(default="Authority", translations={"en": "Authority", "fr": "Autorité"}),
|
||||
"externalId": Label(default="External ID", translations={"en": "External ID", "fr": "ID externe"}),
|
||||
"externalUsername": Label(default="External Username", translations={"en": "External Username", "fr": "Nom d'utilisateur externe"}),
|
||||
"externalEmail": Label(default="External Email", translations={"en": "External Email", "fr": "Email externe"}),
|
||||
"connectedAt": Label(default="Connected At", translations={"en": "Connected At", "fr": "Connecté le"})
|
||||
}
|
||||
|
||||
class User(BaseModel):
|
||||
"""Data model for a user"""
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Unique ID of the user")
|
||||
|
|
@ -65,8 +89,9 @@ class User(BaseModel):
|
|||
language: str = Field(description="Preferred language of the user")
|
||||
disabled: Optional[bool] = Field(False, description="Indicates whether the user is disabled")
|
||||
privilege: str = Field(description="Permission level") #sysadmin,admin,user
|
||||
authenticationAuthority: str = Field(default="local", description="Authentication authority (local, microsoft)")
|
||||
authenticationAuthority: str = Field(default="local", description="Primary authentication authority (local, microsoft)")
|
||||
mandateId: str = Field(description="ID of the mandate this user belongs to")
|
||||
connections: List[UserConnection] = Field(default_factory=list, description="List of external service connections")
|
||||
|
||||
label: Label = Field(
|
||||
default=Label(default="User", translations={"en": "User", "fr": "Utilisateur"}),
|
||||
|
|
@ -83,7 +108,8 @@ class User(BaseModel):
|
|||
"language": Label(default="Language", translations={"en": "Language", "fr": "Langue"}),
|
||||
"disabled": Label(default="Disabled", translations={"en": "Disabled", "fr": "Désactivé"}),
|
||||
"privilege": Label(default="Permission level", translations={"en": "Access level", "fr": "Niveau d'accès"}),
|
||||
"authenticationAuthority": Label(default="Authentication Authority", translations={"en": "Authentication Authority", "fr": "Autorité d'authentification"})
|
||||
"authenticationAuthority": Label(default="Authentication Authority", translations={"en": "Authentication Authority", "fr": "Autorité d'authentification"}),
|
||||
"connections": Label(default="External Connections", translations={"en": "External Connections", "fr": "Connexions externes"})
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
113
modules/interfaces/googleAccess.py
Normal file
113
modules/interfaces/googleAccess.py
Normal file
|
|
@ -0,0 +1,113 @@
|
|||
"""
|
||||
Access control module for Google interface.
|
||||
Handles user access management and permission checks for Google tokens.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
class GoogleAccess:
|
||||
"""
|
||||
Access control class for Google interface.
|
||||
Handles user access management and permission checks for Google tokens.
|
||||
"""
|
||||
|
||||
def __init__(self, currentUser: Dict[str, Any], db):
|
||||
"""Initialize with user context."""
|
||||
self.currentUser = currentUser
|
||||
self._mandateId = currentUser.get("_mandateId")
|
||||
self._userId = currentUser.get("id")
|
||||
|
||||
if not self._mandateId or not self._userId:
|
||||
raise ValueError("Invalid user context: _mandateId and id are required")
|
||||
|
||||
self.db = db
|
||||
|
||||
def uam(self, table: str, recordset: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Unified user access management function that filters data based on user privileges
|
||||
and adds access control attributes.
|
||||
|
||||
Args:
|
||||
table: Name of the table
|
||||
recordset: Recordset to filter based on access rules
|
||||
|
||||
Returns:
|
||||
Filtered recordset with access control attributes
|
||||
"""
|
||||
userPrivilege = self.currentUser.get("privilege", "user")
|
||||
filtered_records = []
|
||||
|
||||
# Apply filtering based on privilege
|
||||
if userPrivilege == "sysadmin":
|
||||
filtered_records = recordset # System admins see all records
|
||||
elif userPrivilege == "admin":
|
||||
# Admins see records in their mandate
|
||||
filtered_records = [r for r in recordset if r.get("_mandateId") == self._mandateId]
|
||||
else: # Regular users
|
||||
# Users only see their own Google tokens
|
||||
filtered_records = [r for r in recordset
|
||||
if r.get("_mandateId") == self._mandateId and r.get("_userId") == self._userId]
|
||||
|
||||
# Add access control attributes to each record
|
||||
for record in filtered_records:
|
||||
record_id = record.get("id")
|
||||
|
||||
# Set access control flags based on user permissions
|
||||
if table == "googleTokens":
|
||||
record["_hideView"] = False # Everyone can view their own tokens
|
||||
record["_hideEdit"] = not self.canModify("googleTokens", record_id)
|
||||
record["_hideDelete"] = not self.canModify("googleTokens", record_id)
|
||||
else:
|
||||
# Default access control for other tables
|
||||
record["_hideView"] = False
|
||||
record["_hideEdit"] = not self.canModify(table, record_id)
|
||||
record["_hideDelete"] = not self.canModify(table, record_id)
|
||||
|
||||
return filtered_records
|
||||
|
||||
def canModify(self, table: str, recordId: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Checks if the current user can modify (create/update/delete) records in a table.
|
||||
|
||||
Args:
|
||||
table: Name of the table
|
||||
recordId: Optional record ID for specific record check
|
||||
|
||||
Returns:
|
||||
Boolean indicating permission
|
||||
"""
|
||||
userPrivilege = self.currentUser.get("privilege", "user")
|
||||
|
||||
# System admins can modify anything
|
||||
if userPrivilege == "sysadmin":
|
||||
return True
|
||||
|
||||
# Check specific record permissions
|
||||
if recordId is not None:
|
||||
# Get the record to check ownership
|
||||
records = self.db.getRecordset(table, recordFilter={"id": recordId})
|
||||
if not records:
|
||||
return False
|
||||
|
||||
record = records[0]
|
||||
|
||||
# Admins can modify anything in their mandate
|
||||
if userPrivilege == "admin" and record.get("_mandateId") == self._mandateId:
|
||||
return True
|
||||
|
||||
# Users can only modify their own Google tokens
|
||||
if (record.get("_mandateId") == self._mandateId and
|
||||
record.get("_userId") == self._userId):
|
||||
return True
|
||||
|
||||
return False
|
||||
else:
|
||||
# For general table modify permission (e.g., create)
|
||||
# Admins can create anything in their mandate
|
||||
if userPrivilege == "admin":
|
||||
return True
|
||||
|
||||
# Regular users can create their own Google tokens
|
||||
if table == "googleTokens":
|
||||
return True
|
||||
return False
|
||||
287
modules/interfaces/googleInterface.py
Normal file
287
modules/interfaces/googleInterface.py
Normal file
|
|
@ -0,0 +1,287 @@
|
|||
"""
|
||||
Google interface for handling Google authentication and API operations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import requests
|
||||
from typing import Dict, Any, Optional, Tuple
|
||||
from datetime import datetime
|
||||
import secrets
|
||||
from google.oauth2.credentials import Credentials
|
||||
from google_auth_oauthlib.flow import Flow
|
||||
from google.auth.transport.requests import Request
|
||||
import os
|
||||
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
from modules.interfaces.googleModel import GoogleToken, GoogleUserInfo, GoogleConfig
|
||||
from modules.connectors.connectorDbJson import DatabaseConnector
|
||||
from modules.interfaces.googleAccess import GoogleAccess
|
||||
from modules.interfaces.gatewayInterface import getRootUser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Singleton factory for GoogleInterface instances per context
|
||||
_googleInterfaces = {}
|
||||
|
||||
# Root interface instance
|
||||
_rootGoogleInterface = None
|
||||
|
||||
class GoogleInterface:
|
||||
"""Interface for Google authentication and API operations"""
|
||||
|
||||
def __init__(self, currentUser: Dict[str, Any] = None):
|
||||
"""Initialize the Google interface"""
|
||||
# Initialize variables
|
||||
self.currentUser = currentUser
|
||||
self.mandateId = currentUser.get("mandateId") if currentUser else None
|
||||
self.userId = currentUser.get("id") if currentUser else None
|
||||
self.access = None # Will be set when user context is provided
|
||||
|
||||
# Initialize configuration
|
||||
self.clientId = APP_CONFIG.get("Service_GOOGLE_CLIENT_ID")
|
||||
self.clientSecret = APP_CONFIG.get("Service_GOOGLE_CLIENT_SECRET")
|
||||
self.redirectUri = APP_CONFIG.get("Service_GOOGLE_REDIRECT_URI")
|
||||
self.authorityUrl = "https://accounts.google.com"
|
||||
self.tokenUrl = "https://oauth2.googleapis.com/token"
|
||||
self.userInfoUrl = "https://www.googleapis.com/oauth2/v3/userinfo"
|
||||
self.scopes = ["openid", "profile", "email"]
|
||||
|
||||
# Initialize database
|
||||
self._initializeDatabase()
|
||||
|
||||
# Initialize OAuth2 flow
|
||||
self.flow = Flow.from_client_config(
|
||||
{
|
||||
"web": {
|
||||
"client_id": self.clientId,
|
||||
"client_secret": self.clientSecret,
|
||||
"auth_uri": f"{self.authorityUrl}/o/oauth2/auth",
|
||||
"token_uri": self.tokenUrl,
|
||||
"redirect_uris": [self.redirectUri]
|
||||
}
|
||||
},
|
||||
scopes=self.scopes
|
||||
)
|
||||
|
||||
# Set user context if provided
|
||||
if currentUser:
|
||||
self.setUserContext(currentUser)
|
||||
|
||||
def _initializeDatabase(self):
|
||||
"""Initializes the database connection."""
|
||||
try:
|
||||
# Get configuration values with defaults
|
||||
dbHost = APP_CONFIG.get("DB_GOOGLE_HOST", "data")
|
||||
dbDatabase = APP_CONFIG.get("DB_GOOGLE_DATABASE", "google")
|
||||
dbUser = APP_CONFIG.get("DB_GOOGLE_USER")
|
||||
dbPassword = APP_CONFIG.get("DB_GOOGLE_PASSWORD_SECRET")
|
||||
|
||||
# Ensure the database directory exists
|
||||
os.makedirs(dbHost, exist_ok=True)
|
||||
|
||||
self.db = DatabaseConnector(
|
||||
dbHost=dbHost,
|
||||
dbDatabase=dbDatabase,
|
||||
dbUser=dbUser,
|
||||
dbPassword=dbPassword,
|
||||
mandateId=self.mandateId,
|
||||
userId=self.userId
|
||||
)
|
||||
|
||||
# Set context
|
||||
self.db.updateContext(self.mandateId, self.userId)
|
||||
|
||||
logger.info("Database initialized successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize database: {str(e)}")
|
||||
raise
|
||||
|
||||
def initiateLogin(self) -> str:
|
||||
"""Initiate Google login flow"""
|
||||
try:
|
||||
# Generate auth URL
|
||||
auth_url, _ = self.flow.authorization_url(
|
||||
access_type="offline",
|
||||
include_granted_scopes="true",
|
||||
state=self._generateState()
|
||||
)
|
||||
return auth_url
|
||||
except Exception as e:
|
||||
logger.error(f"Error initiating Google login: {str(e)}")
|
||||
return None
|
||||
|
||||
def handleAuthCallback(self, code: str) -> Optional[GoogleToken]:
|
||||
"""Handle Google OAuth callback"""
|
||||
try:
|
||||
# Exchange code for token
|
||||
self.flow.fetch_token(code=code)
|
||||
credentials = self.flow.credentials
|
||||
|
||||
# Get user info
|
||||
user_info = self.getUserInfoFromToken(credentials.token)
|
||||
if not user_info:
|
||||
return None
|
||||
|
||||
# Create token model
|
||||
token = GoogleToken(
|
||||
access_token=credentials.token,
|
||||
refresh_token=credentials.refresh_token,
|
||||
expires_in=credentials.expiry.timestamp() - datetime.now().timestamp(),
|
||||
token_type=credentials.token_type,
|
||||
expires_at=credentials.expiry.timestamp(),
|
||||
user_info=user_info.model_dump(),
|
||||
mandateId=self.mandateId,
|
||||
userId=self.userId
|
||||
)
|
||||
|
||||
return token
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling auth callback: {str(e)}")
|
||||
return None
|
||||
|
||||
def verifyToken(self, token: str) -> bool:
|
||||
"""Verify Google token"""
|
||||
try:
|
||||
# Get user info from token
|
||||
user_info = self.getUserInfoFromToken(token)
|
||||
if not user_info:
|
||||
return False
|
||||
|
||||
# Get current user's Google connection
|
||||
user = self.db.getRecordset("users", recordFilter={"id": self.userId})[0]
|
||||
google_connection = next((conn for conn in user.get("connections", [])
|
||||
if conn.get("authority") == "google"), None)
|
||||
|
||||
if not google_connection:
|
||||
return False
|
||||
|
||||
# Verify the token belongs to this user
|
||||
return user_info.id == google_connection.get("externalId")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error verifying Google token: {str(e)}")
|
||||
return False
|
||||
|
||||
def getUserInfoFromToken(self, token: str) -> Optional[GoogleUserInfo]:
|
||||
"""Get user info from Google API"""
|
||||
try:
|
||||
# Call Google API
|
||||
response = requests.get(
|
||||
self.userInfoUrl,
|
||||
headers={"Authorization": f"Bearer {token}"}
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"Failed to get user info: {response.text}")
|
||||
return None
|
||||
|
||||
data = response.json()
|
||||
|
||||
# Create user info model
|
||||
return GoogleUserInfo(
|
||||
id=data["sub"], # Google uses 'sub' as the unique identifier
|
||||
email=data["email"],
|
||||
name=data.get("name", ""),
|
||||
picture=data.get("picture") # Google provides profile picture URL
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user info: {str(e)}")
|
||||
return None
|
||||
|
||||
def refreshToken(self, refresh_token: str) -> Optional[GoogleToken]:
|
||||
"""Refresh Google token"""
|
||||
try:
|
||||
# Create credentials object
|
||||
credentials = Credentials(
|
||||
None, # No access token
|
||||
refresh_token=refresh_token,
|
||||
token_uri=self.tokenUrl,
|
||||
client_id=self.clientId,
|
||||
client_secret=self.clientSecret
|
||||
)
|
||||
|
||||
# Refresh token
|
||||
credentials.refresh(Request())
|
||||
|
||||
# Get user info
|
||||
user_info = self.getUserInfoFromToken(credentials.token)
|
||||
if not user_info:
|
||||
return None
|
||||
|
||||
# Create token model
|
||||
token = GoogleToken(
|
||||
access_token=credentials.token,
|
||||
refresh_token=credentials.refresh_token or refresh_token,
|
||||
expires_in=credentials.expiry.timestamp() - datetime.now().timestamp(),
|
||||
token_type=credentials.token_type,
|
||||
expires_at=credentials.expiry.timestamp(),
|
||||
user_info=user_info.model_dump(),
|
||||
mandateId=self.mandateId,
|
||||
userId=self.userId
|
||||
)
|
||||
|
||||
return token
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error refreshing token: {str(e)}")
|
||||
return None
|
||||
|
||||
def _generateState(self) -> str:
|
||||
"""Generate secure state token"""
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
def setUserContext(self, currentUser: Dict[str, Any]):
|
||||
"""Set user context for the interface"""
|
||||
if not currentUser:
|
||||
logger.info("Initializing interface without user context")
|
||||
return
|
||||
|
||||
self.currentUser = currentUser
|
||||
self.mandateId = currentUser.get("mandateId")
|
||||
self.userId = currentUser.get("id")
|
||||
|
||||
if not self.mandateId or not self.userId:
|
||||
raise ValueError("Invalid user context: mandateId and id are required")
|
||||
|
||||
# Initialize access control with user context
|
||||
self.access = GoogleAccess(self.currentUser, self.db)
|
||||
|
||||
# Update database context
|
||||
self.db.updateContext(self.mandateId, self.userId)
|
||||
|
||||
logger.debug(f"User context set: userId={self.userId}")
|
||||
|
||||
def getRootInterface() -> GoogleInterface:
|
||||
"""
|
||||
Returns a GoogleInterface instance with root privileges.
|
||||
This is used for initial setup and user creation.
|
||||
"""
|
||||
global _rootGoogleInterface
|
||||
|
||||
if _rootGoogleInterface is None:
|
||||
# Get root user from gateway
|
||||
rootUser = getRootUser()
|
||||
_rootGoogleInterface = GoogleInterface(rootUser)
|
||||
|
||||
return _rootGoogleInterface
|
||||
|
||||
def getInterface(currentUser: Dict[str, Any] = None) -> GoogleInterface:
|
||||
"""
|
||||
Returns a GoogleInterface instance.
|
||||
If currentUser is provided, initializes with user context.
|
||||
Otherwise, returns an instance with only database access.
|
||||
"""
|
||||
# Create new instance if not exists
|
||||
if "default" not in _googleInterfaces:
|
||||
_googleInterfaces["default"] = GoogleInterface(currentUser or {})
|
||||
|
||||
interface = _googleInterfaces["default"]
|
||||
|
||||
if currentUser:
|
||||
interface.setUserContext(currentUser)
|
||||
else:
|
||||
logger.info("Returning interface without user context")
|
||||
|
||||
return interface
|
||||
35
modules/interfaces/googleModel.py
Normal file
35
modules/interfaces/googleModel.py
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
"""
|
||||
Models for Google authentication and API operations.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
class GoogleToken(BaseModel):
|
||||
"""Model for Google OAuth tokens"""
|
||||
access_token: str
|
||||
refresh_token: Optional[str] = None
|
||||
expires_in: int
|
||||
token_type: str = "bearer"
|
||||
expires_at: float
|
||||
user_info: Dict[str, Any]
|
||||
mandateId: str
|
||||
userId: str
|
||||
|
||||
class GoogleUserInfo(BaseModel):
|
||||
"""Model for Google user information"""
|
||||
id: str # Google uses 'sub' as the unique identifier
|
||||
email: str
|
||||
name: str
|
||||
picture: Optional[str] = None # Google provides profile picture URL
|
||||
|
||||
class GoogleConfig(BaseModel):
|
||||
"""Configuration for Google authentication service"""
|
||||
client_id: str
|
||||
client_secret: str
|
||||
redirect_uri: str
|
||||
scopes: list[str]
|
||||
authority_url: str = "https://accounts.google.com"
|
||||
token_url: str = "https://oauth2.googleapis.com/token"
|
||||
user_info_url: str = "https://www.googleapis.com/oauth2/v3/userinfo"
|
||||
|
|
@ -13,48 +13,52 @@ import secrets
|
|||
import os
|
||||
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
from modules.interfaces.msftModel import MsftToken, MsftUserInfo
|
||||
from .msftModel import MsftToken, MsftUserInfo, MsftConfig
|
||||
from modules.connectors.connectorDbJson import DatabaseConnector
|
||||
from modules.interfaces.msftAccess import MsftAccess
|
||||
from .msftAccess import MsftAccess
|
||||
from modules.interfaces.gatewayInterface import getRootUser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Singleton factory for MsftInterface instances per context
|
||||
_msftInterfaces = {}
|
||||
|
||||
# Root interface instance
|
||||
_rootMsftInterface = None
|
||||
|
||||
class MsftInterface:
|
||||
"""Interface for Microsoft authentication and Graph API operations"""
|
||||
|
||||
def __init__(self, currentUser: Dict[str, Any]):
|
||||
def __init__(self, currentUser: Dict[str, Any] = None):
|
||||
"""Initialize the Microsoft interface"""
|
||||
# Initialize variables
|
||||
self.currentUser = currentUser
|
||||
self._mandateId = currentUser.get("_mandateId")
|
||||
self._userId = currentUser.get("id")
|
||||
|
||||
if not self._mandateId or not self._userId:
|
||||
raise ValueError("Invalid user context: _mandateId and id are required")
|
||||
self.mandateId = currentUser.get("mandateId") if currentUser else None
|
||||
self.userId = currentUser.get("id") if currentUser else None
|
||||
self.access = None # Will be set when user context is provided
|
||||
|
||||
# Initialize configuration
|
||||
self.client_id = APP_CONFIG.get("Service_MSFT_CLIENT_ID")
|
||||
self.client_secret = APP_CONFIG.get("Service_MSFT_CLIENT_SECRET")
|
||||
self.tenant_id = APP_CONFIG.get("Service_MSFT_TENANT_ID", "common")
|
||||
self.redirect_uri = APP_CONFIG.get("Service_MSFT_REDIRECT_URI")
|
||||
self.authority = f"https://login.microsoftonline.com/{self.tenant_id}"
|
||||
self.clientId = APP_CONFIG.get("Service_MSFT_CLIENT_ID")
|
||||
self.clientSecret = APP_CONFIG.get("Service_MSFT_CLIENT_SECRET")
|
||||
self.tenantId = APP_CONFIG.get("Service_MSFT_TENANT_ID", "common")
|
||||
self.redirectUri = APP_CONFIG.get("Service_MSFT_REDIRECT_URI")
|
||||
self.authority = f"https://login.microsoftonline.com/{self.tenantId}"
|
||||
self.scopes = ["Mail.ReadWrite", "User.Read"]
|
||||
|
||||
# Initialize database
|
||||
self._initializeDatabase()
|
||||
|
||||
# Initialize access control
|
||||
self.access = MsftAccess(self.currentUser, self.db)
|
||||
|
||||
# Initialize MSAL application
|
||||
self.msal_app = msal.ConfidentialClientApplication(
|
||||
self.client_id,
|
||||
self.clientId,
|
||||
authority=self.authority,
|
||||
client_credential=self.client_secret
|
||||
client_credential=self.clientSecret
|
||||
)
|
||||
|
||||
# Set user context if provided
|
||||
if currentUser:
|
||||
self.setUserContext(currentUser)
|
||||
|
||||
def _initializeDatabase(self):
|
||||
"""Initializes the database connection."""
|
||||
try:
|
||||
|
|
@ -72,12 +76,12 @@ class MsftInterface:
|
|||
dbDatabase=dbDatabase,
|
||||
dbUser=dbUser,
|
||||
dbPassword=dbPassword,
|
||||
_mandateId=self._mandateId,
|
||||
_userId=self._userId
|
||||
mandateId=self.mandateId,
|
||||
userId=self.userId
|
||||
)
|
||||
|
||||
# Set context
|
||||
self.db.updateContext(self._mandateId, self._userId)
|
||||
self.db.updateContext(self.mandateId, self.userId)
|
||||
|
||||
logger.info("Database initialized successfully")
|
||||
except Exception as e:
|
||||
|
|
@ -111,156 +115,147 @@ class MsftInterface:
|
|||
"""
|
||||
return self.access.canModify(table, recordId)
|
||||
|
||||
def getMsftToken(self) -> Optional[MsftToken]:
|
||||
"""Get Microsoft token for current user"""
|
||||
def initiateLogin(self) -> str:
|
||||
"""Initiate Microsoft login flow"""
|
||||
try:
|
||||
tokens = self.db.getRecordset("msftTokens", recordFilter={
|
||||
"_mandateId": self._mandateId,
|
||||
"_userId": self._userId
|
||||
})
|
||||
if not tokens:
|
||||
# Generate auth URL
|
||||
auth_url = self.msal_app.get_authorization_request_url(
|
||||
scopes=self.scopes,
|
||||
redirect_uri=self.redirectUri,
|
||||
state=self._generateState()
|
||||
)
|
||||
return auth_url
|
||||
except Exception as e:
|
||||
logger.error(f"Error initiating Microsoft login: {str(e)}")
|
||||
return None
|
||||
|
||||
# Apply access control
|
||||
filtered_tokens = self._uam("msftTokens", tokens)
|
||||
if not filtered_tokens:
|
||||
return None
|
||||
|
||||
return MsftToken(**filtered_tokens[0])
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting Microsoft token: {str(e)}")
|
||||
return None
|
||||
|
||||
def saveMsftToken(self, token_data: Dict[str, Any]) -> bool:
|
||||
"""Save Microsoft token data"""
|
||||
def handleAuthCallback(self, code: str) -> Optional[MsftToken]:
|
||||
"""Handle Microsoft OAuth callback"""
|
||||
try:
|
||||
# Check if user can modify tokens
|
||||
if not self._canModify("msftTokens"):
|
||||
raise PermissionError("No permission to save Microsoft token")
|
||||
|
||||
# Add user and mandate IDs to token data
|
||||
token_data["_mandateId"] = self._mandateId
|
||||
token_data["_userId"] = self._userId
|
||||
|
||||
# Validate token data using Pydantic model
|
||||
try:
|
||||
token = MsftToken(**token_data)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid token data: {str(e)}")
|
||||
|
||||
# Check if token already exists
|
||||
existing_token = self.getMsftToken()
|
||||
|
||||
if existing_token:
|
||||
# Update existing token
|
||||
return self.db.recordModify("msftTokens", existing_token.id, token.model_dump())
|
||||
else:
|
||||
# Create new token record
|
||||
return self.db.recordCreate("msftTokens", token.model_dump())
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving Microsoft token: {str(e)}")
|
||||
return False
|
||||
|
||||
def deleteMsftToken(self) -> bool:
|
||||
"""Delete Microsoft token for current user"""
|
||||
try:
|
||||
if not self._canModify("msftTokens"):
|
||||
raise PermissionError("No permission to delete Microsoft token")
|
||||
|
||||
existing_token = self.getMsftToken()
|
||||
if existing_token:
|
||||
return self.db.recordDelete("msftTokens", existing_token.id)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting Microsoft token: {str(e)}")
|
||||
return False
|
||||
|
||||
def getCurrentUserToken(self) -> Tuple[Optional[MsftUserInfo], Optional[str]]:
|
||||
"""Get current user's Microsoft token and info"""
|
||||
try:
|
||||
token_data = self.getMsftToken()
|
||||
if not token_data:
|
||||
return None, None
|
||||
|
||||
# Verify token is still valid
|
||||
if not self.verifyToken(token_data.access_token):
|
||||
if not self.refreshToken(token_data):
|
||||
return None, None
|
||||
token_data = self.getMsftToken()
|
||||
|
||||
user_info = token_data.user_info
|
||||
if not user_info:
|
||||
user_info = self.getUserInfoFromToken(token_data.access_token)
|
||||
if user_info:
|
||||
token_data.user_info = user_info
|
||||
self.saveMsftToken(token_data.model_dump())
|
||||
|
||||
return MsftUserInfo(**user_info) if user_info else None, token_data.access_token
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting current user token: {str(e)}")
|
||||
return None, None
|
||||
|
||||
def verifyToken(self, token: str) -> bool:
|
||||
"""Verify the access token is valid"""
|
||||
try:
|
||||
headers = {
|
||||
'Authorization': f'Bearer {token}',
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
response = requests.get('https://graph.microsoft.com/v1.0/me', headers=headers)
|
||||
return response.status_code == 200
|
||||
except Exception as e:
|
||||
logger.error(f"Error verifying token: {str(e)}")
|
||||
return False
|
||||
|
||||
def refreshToken(self, token_data: MsftToken) -> bool:
|
||||
"""Refresh the access token using the stored refresh token"""
|
||||
try:
|
||||
if not token_data or not token_data.refresh_token:
|
||||
return False
|
||||
|
||||
result = self.msal_app.acquire_token_by_refresh_token(
|
||||
token_data.refresh_token,
|
||||
scopes=self.scopes
|
||||
# Get token from code
|
||||
token_response = self.msal_app.acquire_token_by_authorization_code(
|
||||
code,
|
||||
scopes=self.scopes,
|
||||
redirect_uri=self.redirectUri
|
||||
)
|
||||
|
||||
if "error" in result:
|
||||
logger.error(f"Error refreshing token: {result.get('error')}")
|
||||
return False
|
||||
if "error" in token_response:
|
||||
logger.error(f"Token acquisition failed: {token_response['error']}")
|
||||
return None
|
||||
|
||||
# Update token data
|
||||
token_data.access_token = result["access_token"]
|
||||
if "refresh_token" in result:
|
||||
token_data.refresh_token = result["refresh_token"]
|
||||
# Get user info
|
||||
user_info = self.getUserInfoFromToken(token_response["access_token"])
|
||||
if not user_info:
|
||||
return None
|
||||
|
||||
return self.saveMsftToken(token_data.model_dump())
|
||||
# Create token model
|
||||
token = MsftToken(
|
||||
access_token=token_response["access_token"],
|
||||
refresh_token=token_response.get("refresh_token", ""),
|
||||
expires_in=token_response.get("expires_in", 0),
|
||||
token_type=token_response.get("token_type", "bearer"),
|
||||
expires_at=datetime.now().timestamp() + token_response.get("expires_in", 0),
|
||||
user_info=user_info.model_dump(),
|
||||
mandateId=self.mandateId,
|
||||
userId=self.userId
|
||||
)
|
||||
|
||||
return token
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error refreshing token: {str(e)}")
|
||||
logger.error(f"Error handling auth callback: {str(e)}")
|
||||
return None
|
||||
|
||||
def verifyToken(self, token: str) -> bool:
|
||||
"""Verify Microsoft token"""
|
||||
try:
|
||||
# Get user info from token
|
||||
user_info = self.getUserInfoFromToken(token)
|
||||
if not user_info:
|
||||
return False
|
||||
|
||||
def getUserInfoFromToken(self, access_token: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get user information using the access token"""
|
||||
# Get current user's Microsoft connection
|
||||
user = self.db.getRecordset("users", recordFilter={"id": self.userId})[0]
|
||||
msft_connection = next((conn for conn in user.get("connections", [])
|
||||
if conn.get("authority") == "microsoft"), None)
|
||||
|
||||
if not msft_connection:
|
||||
return False
|
||||
|
||||
# Verify the token belongs to this user
|
||||
return user_info.id == msft_connection.get("externalId")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error verifying Microsoft token: {str(e)}")
|
||||
return False
|
||||
|
||||
def getUserInfoFromToken(self, token: str) -> Optional[MsftUserInfo]:
|
||||
"""Get user info from Microsoft Graph"""
|
||||
try:
|
||||
headers = {
|
||||
'Authorization': f'Bearer {access_token}',
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
response = requests.get('https://graph.microsoft.com/v1.0/me', headers=headers)
|
||||
if response.status_code == 200:
|
||||
user_data = response.json()
|
||||
return {
|
||||
"name": user_data.get("displayName", ""),
|
||||
"email": user_data.get("userPrincipalName", ""),
|
||||
"id": user_data.get("id", "")
|
||||
}
|
||||
# Call Microsoft Graph API
|
||||
response = requests.get(
|
||||
"https://graph.microsoft.com/v1.0/me",
|
||||
headers={"Authorization": f"Bearer {token}"}
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"Failed to get user info: {response.text}")
|
||||
return None
|
||||
|
||||
data = response.json()
|
||||
|
||||
# Create user info model
|
||||
return MsftUserInfo(
|
||||
id=data["id"],
|
||||
email=data.get("mail") or data.get("userPrincipalName"),
|
||||
name=data.get("displayName", ""),
|
||||
picture=None # Microsoft Graph doesn't provide profile picture by default
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user info: {str(e)}")
|
||||
return None
|
||||
|
||||
def refreshToken(self, refresh_token: str) -> Optional[MsftToken]:
|
||||
"""Refresh Microsoft token"""
|
||||
try:
|
||||
# Refresh token
|
||||
token_response = self.msal_app.acquire_token_by_refresh_token(
|
||||
refresh_token,
|
||||
scopes=self.scopes
|
||||
)
|
||||
|
||||
if "error" in token_response:
|
||||
logger.error(f"Token refresh failed: {token_response['error']}")
|
||||
return None
|
||||
|
||||
# Get user info
|
||||
user_info = self.getUserInfoFromToken(token_response["access_token"])
|
||||
if not user_info:
|
||||
return None
|
||||
|
||||
# Create token model
|
||||
token = MsftToken(
|
||||
access_token=token_response["access_token"],
|
||||
refresh_token=token_response.get("refresh_token", refresh_token),
|
||||
expires_in=token_response.get("expires_in", 0),
|
||||
token_type=token_response.get("token_type", "bearer"),
|
||||
expires_at=datetime.now().timestamp() + token_response.get("expires_in", 0),
|
||||
user_info=user_info.model_dump(),
|
||||
mandateId=self.mandateId,
|
||||
userId=self.userId
|
||||
)
|
||||
|
||||
return token
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error refreshing token: {str(e)}")
|
||||
return None
|
||||
|
||||
def _generateState(self) -> str:
|
||||
"""Generate secure state token"""
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
def createDraftEmail(self, recipient: str, subject: str, body: str, attachments: List[Dict[str, Any]] = None) -> bool:
|
||||
"""Create a draft email using Microsoft Graph API"""
|
||||
try:
|
||||
|
|
@ -340,70 +335,186 @@ class MsftInterface:
|
|||
logger.error(f"Error creating draft email: {str(e)}")
|
||||
return False
|
||||
|
||||
def initiateLogin(self) -> str:
|
||||
"""Initiate Microsoft login flow"""
|
||||
def saveMsftToken(self, token_data: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
Save Microsoft token data to the database.
|
||||
|
||||
Args:
|
||||
token_data: Token data to save
|
||||
|
||||
Returns:
|
||||
bool: True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
state = secrets.token_urlsafe(32)
|
||||
auth_url = self.msal_app.get_authorization_request_url(
|
||||
self.scopes,
|
||||
state=state,
|
||||
redirect_uri=self.redirect_uri
|
||||
)
|
||||
return auth_url
|
||||
except Exception as e:
|
||||
logger.error(f"Error initiating login: {str(e)}")
|
||||
return None
|
||||
|
||||
def handleAuthCallback(self, code: str) -> Optional[Dict[str, Any]]:
|
||||
"""Handle Microsoft OAuth callback"""
|
||||
try:
|
||||
token_response = self.msal_app.acquire_token_by_authorization_code(
|
||||
code,
|
||||
self.scopes,
|
||||
redirect_uri=self.redirect_uri
|
||||
# Get existing token if any
|
||||
existing_tokens = self.db.getRecordset(
|
||||
"msftTokens",
|
||||
recordFilter={
|
||||
"mandateId": self.mandateId,
|
||||
"userId": self.userId
|
||||
}
|
||||
)
|
||||
|
||||
if "error" in token_response:
|
||||
logger.error(f"Token acquisition failed: {token_response['error']}")
|
||||
return None
|
||||
|
||||
user_info = self.getUserInfoFromToken(token_response["access_token"])
|
||||
if not user_info:
|
||||
return None
|
||||
|
||||
# Create MsftToken instance
|
||||
token_data = MsftToken(
|
||||
access_token=token_response["access_token"],
|
||||
refresh_token=token_response.get("refresh_token", ""),
|
||||
expires_in=token_response.get("expires_in", 0),
|
||||
token_type=token_response.get("token_type", "bearer"),
|
||||
expires_at=datetime.now().timestamp() + token_response.get("expires_in", 0),
|
||||
user_info=user_info,
|
||||
_mandateId=self._mandateId,
|
||||
_userId=self._userId
|
||||
if existing_tokens:
|
||||
# Update existing token
|
||||
token_id = existing_tokens[0]["id"]
|
||||
success = self.db.updateRecord(
|
||||
"msftTokens",
|
||||
token_id,
|
||||
token_data
|
||||
)
|
||||
else:
|
||||
# Create new token record
|
||||
success = self.db.createRecord(
|
||||
"msftTokens",
|
||||
token_data
|
||||
)
|
||||
|
||||
return token_data.model_dump()
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling auth callback: {str(e)}")
|
||||
logger.error(f"Error saving Microsoft token: {str(e)}")
|
||||
return False
|
||||
|
||||
def getMsftToken(self) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get Microsoft token data for current user.
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: Token data if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
tokens = self.db.getRecordset(
|
||||
"msftTokens",
|
||||
recordFilter={
|
||||
"mandateId": self.mandateId,
|
||||
"userId": self.userId
|
||||
}
|
||||
)
|
||||
|
||||
if not tokens:
|
||||
return None
|
||||
|
||||
def getInterface(currentUser: Dict[str, Any]) -> MsftInterface:
|
||||
"""
|
||||
Returns a MsftInterface instance for the current user.
|
||||
Handles initialization of database and records.
|
||||
"""
|
||||
mandateId = currentUser.get("_mandateId")
|
||||
userId = currentUser.get("id")
|
||||
if not mandateId or not userId:
|
||||
raise ValueError("Invalid user context: _mandateId and id are required")
|
||||
return tokens[0]
|
||||
|
||||
# Create context key
|
||||
contextKey = f"{mandateId}_{userId}"
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting Microsoft token: {str(e)}")
|
||||
return None
|
||||
|
||||
def getCurrentUserToken(self) -> Tuple[Optional[Dict[str, Any]], Optional[str]]:
|
||||
"""
|
||||
Get current user's Microsoft token and user info.
|
||||
|
||||
Returns:
|
||||
Tuple[Optional[Dict[str, Any]], Optional[str]]: User info and access token
|
||||
"""
|
||||
try:
|
||||
token_data = self.getMsftToken()
|
||||
if not token_data:
|
||||
return None, None
|
||||
|
||||
# Check if token needs refresh
|
||||
if datetime.now().timestamp() >= token_data["expires_at"]:
|
||||
if not token_data.get("refresh_token"):
|
||||
return None, None
|
||||
|
||||
# Refresh token
|
||||
new_token = self.refreshToken(token_data["refresh_token"])
|
||||
if not new_token:
|
||||
return None, None
|
||||
|
||||
# Save new token
|
||||
self.saveMsftToken(new_token.model_dump())
|
||||
token_data = new_token.model_dump()
|
||||
|
||||
return token_data["user_info"], token_data["access_token"]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting current user token: {str(e)}")
|
||||
return None, None
|
||||
|
||||
def deleteMsftToken(self) -> bool:
|
||||
"""
|
||||
Delete Microsoft token for current user.
|
||||
|
||||
Returns:
|
||||
bool: True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Get existing token
|
||||
existing_tokens = self.db.getRecordset(
|
||||
"msftTokens",
|
||||
recordFilter={
|
||||
"mandateId": self.mandateId,
|
||||
"userId": self.userId
|
||||
}
|
||||
)
|
||||
|
||||
if not existing_tokens:
|
||||
return True # No token to delete
|
||||
|
||||
# Delete token
|
||||
success = self.db.deleteRecord(
|
||||
"msftTokens",
|
||||
existing_tokens[0]["id"]
|
||||
)
|
||||
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting Microsoft token: {str(e)}")
|
||||
return False
|
||||
|
||||
def setUserContext(self, currentUser: Dict[str, Any]):
|
||||
"""Set user context for the interface"""
|
||||
if not currentUser:
|
||||
logger.info("Initializing interface without user context")
|
||||
return
|
||||
|
||||
self.currentUser = currentUser
|
||||
self.mandateId = currentUser.get("mandateId")
|
||||
self.userId = currentUser.get("id")
|
||||
|
||||
if not self.mandateId or not self.userId:
|
||||
raise ValueError("Invalid user context: mandateId and id are required")
|
||||
|
||||
# Initialize access control with user context
|
||||
self.access = MsftAccess(self.currentUser, self.db)
|
||||
|
||||
# Update database context
|
||||
self.db.updateContext(self.mandateId, self.userId)
|
||||
|
||||
logger.debug(f"User context set: userId={self.userId}")
|
||||
|
||||
def getRootInterface() -> MsftInterface:
|
||||
"""
|
||||
Returns a MsftInterface instance with root privileges.
|
||||
This is used for initial setup and user creation.
|
||||
"""
|
||||
global _rootMsftInterface
|
||||
|
||||
if _rootMsftInterface is None:
|
||||
# Get root user from gateway
|
||||
rootUser = getRootUser()
|
||||
_rootMsftInterface = MsftInterface(rootUser)
|
||||
|
||||
return _rootMsftInterface
|
||||
|
||||
def getInterface(currentUser: Dict[str, Any] = None) -> MsftInterface:
|
||||
"""
|
||||
Returns a MsftInterface instance.
|
||||
If currentUser is provided, initializes with user context.
|
||||
Otherwise, returns an instance with only database access.
|
||||
"""
|
||||
# Create new instance if not exists
|
||||
if contextKey not in _msftInterfaces:
|
||||
_msftInterfaces[contextKey] = MsftInterface(currentUser)
|
||||
if "default" not in _msftInterfaces:
|
||||
_msftInterfaces["default"] = MsftInterface(currentUser or {})
|
||||
|
||||
return _msftInterfaces[contextKey]
|
||||
interface = _msftInterfaces["default"]
|
||||
|
||||
if currentUser:
|
||||
interface.setUserContext(currentUser)
|
||||
else:
|
||||
logger.info("Returning interface without user context")
|
||||
|
||||
return interface
|
||||
|
|
@ -1,10 +1,38 @@
|
|||
"""
|
||||
Data models for Microsoft integration.
|
||||
Models for Microsoft authentication and Graph API operations.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict, Any, Optional
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
|
||||
class MsftToken(BaseModel):
|
||||
"""Model for Microsoft OAuth tokens"""
|
||||
access_token: str
|
||||
refresh_token: Optional[str] = None
|
||||
expires_in: int
|
||||
token_type: str = "bearer"
|
||||
expires_at: float
|
||||
user_info: Dict[str, Any]
|
||||
mandateId: str
|
||||
userId: str
|
||||
|
||||
class MsftUserInfo(BaseModel):
|
||||
"""Model for Microsoft user information"""
|
||||
id: str
|
||||
email: str
|
||||
name: str
|
||||
picture: Optional[str] = None # Microsoft Graph doesn't provide profile picture by default
|
||||
|
||||
class MsftConfig(BaseModel):
|
||||
"""Configuration for Microsoft authentication service"""
|
||||
client_id: str
|
||||
client_secret: str
|
||||
redirect_uri: str
|
||||
scopes: list[str]
|
||||
authority_url: str = "https://login.microsoftonline.com/common"
|
||||
token_url: str = "https://login.microsoftonline.com/common/oauth2/v2.0/token"
|
||||
user_info_url: str = "https://graph.microsoft.com/v1.0/me"
|
||||
|
||||
# Get all attributes of the model
|
||||
def getModelAttributes(modelClass):
|
||||
|
|
@ -24,54 +52,6 @@ class Label(BaseModel):
|
|||
return self.translations[language]
|
||||
return self.default
|
||||
|
||||
class MsftToken(BaseModel):
|
||||
"""Data model for Microsoft authentication token"""
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Unique ID of the token")
|
||||
access_token: str = Field(description="Microsoft access token")
|
||||
refresh_token: str = Field(description="Microsoft refresh token")
|
||||
expires_in: int = Field(description="Token expiration time in seconds")
|
||||
token_type: str = Field(description="Type of token (usually 'bearer')")
|
||||
expires_at: float = Field(description="Timestamp when token expires")
|
||||
user_info: Optional[Dict[str, Any]] = Field(None, description="User information from Microsoft")
|
||||
_mandateId: str = Field(description="Mandate ID associated with the token")
|
||||
_userId: str = Field(description="User ID associated with the token")
|
||||
|
||||
label: Label = Field(
|
||||
default=Label(default="Microsoft Token", translations={"en": "Microsoft Token", "fr": "Jeton Microsoft"}),
|
||||
description="Label for the class"
|
||||
)
|
||||
|
||||
# Labels for attributes
|
||||
fieldLabels: Dict[str, Label] = {
|
||||
"id": Label(default="ID", translations={}),
|
||||
"access_token": Label(default="Access Token", translations={"en": "Access Token", "fr": "Jeton d'accès"}),
|
||||
"refresh_token": Label(default="Refresh Token", translations={"en": "Refresh Token", "fr": "Jeton de rafraîchissement"}),
|
||||
"expires_in": Label(default="Expires In", translations={"en": "Expires In", "fr": "Expire dans"}),
|
||||
"token_type": Label(default="Token Type", translations={"en": "Token Type", "fr": "Type de jeton"}),
|
||||
"expires_at": Label(default="Expires At", translations={"en": "Expires At", "fr": "Expire à"}),
|
||||
"user_info": Label(default="User Info", translations={"en": "User Info", "fr": "Info utilisateur"}),
|
||||
"_mandateId": Label(default="Mandate ID", translations={"en": "Mandate ID", "fr": "ID de mandat"}),
|
||||
"_userId": Label(default="User ID", translations={"en": "User ID", "fr": "ID utilisateur"})
|
||||
}
|
||||
|
||||
class MsftUserInfo(BaseModel):
|
||||
"""Data model for Microsoft user information"""
|
||||
name: str = Field(description="User's display name")
|
||||
email: str = Field(description="User's email address")
|
||||
id: str = Field(description="User's Microsoft ID")
|
||||
|
||||
label: Label = Field(
|
||||
default=Label(default="Microsoft User Info", translations={"en": "Microsoft User Info", "fr": "Info utilisateur Microsoft"}),
|
||||
description="Label for the class"
|
||||
)
|
||||
|
||||
# Labels for attributes
|
||||
fieldLabels: Dict[str, Label] = {
|
||||
"name": Label(default="Name", translations={"en": "Name", "fr": "Nom"}),
|
||||
"email": Label(default="Email", translations={"en": "Email", "fr": "E-mail"}),
|
||||
"id": Label(default="ID", translations={})
|
||||
}
|
||||
|
||||
# Response models for Microsoft routes
|
||||
class MsftAuthStatus(BaseModel):
|
||||
"""Response model for Microsoft authentication status"""
|
||||
|
|
|
|||
|
|
@ -145,13 +145,13 @@ async def get_file(
|
|||
async def update_file(
|
||||
file_id: str,
|
||||
file_data: FileItem,
|
||||
current_user: Dict[str, Any] = Depends(auth.getCurrentActiveUser)
|
||||
currentUser: Dict[str, Any] = Depends(auth.getCurrentActiveUser)
|
||||
):
|
||||
"""
|
||||
Update file metadata
|
||||
"""
|
||||
try:
|
||||
interfaceLucydom = lucydomInterface.getInterface(current_user)
|
||||
interfaceLucydom = lucydomInterface.getInterface(currentUser)
|
||||
|
||||
# Get the file from the database
|
||||
file = interfaceLucydom.getFile(file_id)
|
||||
|
|
@ -159,7 +159,7 @@ async def update_file(
|
|||
raise HTTPException(status_code=404, detail="File not found")
|
||||
|
||||
# Check if user has access to the file
|
||||
if file.get("userId", 0) != current_user.get("id", 0):
|
||||
if file.get("userId", 0) != currentUser.get("id", 0):
|
||||
raise HTTPException(status_code=403, detail="Not authorized to update this file")
|
||||
|
||||
# Convert FileItem to dict for interface
|
||||
|
|
@ -171,8 +171,8 @@ async def update_file(
|
|||
raise HTTPException(status_code=500, detail="Failed to update file")
|
||||
|
||||
# Get updated file and convert to FileItem
|
||||
updated_file = interfaceLucydom.getFile(file_id)
|
||||
return FileItem(**updated_file)
|
||||
updatedFile = interfaceLucydom.getFile(file_id)
|
||||
return FileItem(**updatedFile)
|
||||
|
||||
except HTTPException as he:
|
||||
raise he
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
from fastapi import APIRouter, HTTPException, Depends, Body, status, Response
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.responses import FileResponse, HTMLResponse, RedirectResponse, JSONResponse
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from typing import Dict, Any
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import timedelta
|
||||
import pathlib
|
||||
import os
|
||||
|
|
@ -11,7 +11,6 @@ from pathlib import Path as FilePath
|
|||
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
import modules.security.auth as auth
|
||||
import modules.interfaces.gatewayModel as gatewayModel
|
||||
import modules.interfaces.gatewayInterface as gatewayInterface
|
||||
|
||||
router = APIRouter(
|
||||
|
|
@ -54,16 +53,23 @@ async def options_route(fullPath: str):
|
|||
return Response(status_code=200)
|
||||
|
||||
@router.post("/api/token", response_model=gatewayModel.Token, tags=["General"])
|
||||
async def login_for_access_token(formData: OAuth2PasswordRequestForm = Depends()):
|
||||
async def login_for_access_token(
|
||||
formData: OAuth2PasswordRequestForm = Depends(),
|
||||
authority: str = "local",
|
||||
external_token: Optional[str] = None
|
||||
):
|
||||
"""Get access token for user authentication"""
|
||||
# Create a new gateway interface instance with admin context
|
||||
interfaceRoot = auth.getRootInterface()
|
||||
interfaceRoot = gatewayInterface.getRootInterface()
|
||||
|
||||
try:
|
||||
# Authenticate user
|
||||
user = interfaceRoot.authenticateUser(formData.username, formData.password)
|
||||
|
||||
# Authenticate user and get token
|
||||
token = interfaceRoot.authenticateAndGetToken(formData.username, formData.password)
|
||||
# Get token directly
|
||||
token = interfaceRoot.authenticateAndGetToken(
|
||||
username=formData.username,
|
||||
password=formData.password,
|
||||
authority=authority,
|
||||
external_token=external_token
|
||||
)
|
||||
return token
|
||||
except ValueError as e:
|
||||
# Handle authentication errors
|
||||
|
|
@ -91,7 +97,7 @@ async def read_user_me(currentUser: Dict[str, Any] = Depends(auth.getCurrentActi
|
|||
async def register_user(userData: gatewayModel.User):
|
||||
"""Register a new user."""
|
||||
try:
|
||||
interfaceRoot = auth.getRootInterface()
|
||||
interfaceRoot = gatewayInterface.getRootInterface()
|
||||
return interfaceRoot.registerUser(userData.model_dump())
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
|
|
@ -112,7 +118,7 @@ async def check_username_availability(
|
|||
):
|
||||
"""Check if a username is available for registration"""
|
||||
try:
|
||||
interfaceRoot = auth.getRootInterface()
|
||||
interfaceRoot = gatewayInterface.getRootInterface()
|
||||
return interfaceRoot.checkUsernameAvailability(username, authenticationAuthority)
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking username availability: {str(e)}")
|
||||
|
|
|
|||
322
modules/routes/routeGoogle.py
Normal file
322
modules/routes/routeGoogle.py
Normal file
|
|
@ -0,0 +1,322 @@
|
|||
"""
|
||||
Routes for Google authentication.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends, Request, Response, status, Cookie, Body
|
||||
from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse
|
||||
import logging
|
||||
import json
|
||||
from typing import Dict, Any, Optional, List
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
# Import auth module
|
||||
import modules.security.auth as auth
|
||||
|
||||
# Import interfaces
|
||||
import modules.interfaces.googleInterface as googleInterface
|
||||
import modules.interfaces.gatewayInterface as gatewayInterface
|
||||
from modules.interfaces.googleModel import (
|
||||
GoogleToken,
|
||||
GoogleUserInfo,
|
||||
GoogleAuthStatus,
|
||||
GoogleTokenResponse,
|
||||
GoogleSaveTokenResponse
|
||||
)
|
||||
|
||||
# Configure logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Create router for Google Auth endpoints
|
||||
router = APIRouter(
|
||||
prefix="/api/google",
|
||||
tags=["Google"],
|
||||
responses={
|
||||
404: {"description": "Not found"},
|
||||
400: {"description": "Bad request"},
|
||||
401: {"description": "Unauthorized"},
|
||||
403: {"description": "Forbidden"},
|
||||
500: {"description": "Internal server error"}
|
||||
}
|
||||
)
|
||||
|
||||
@router.get("/login")
|
||||
async def login():
|
||||
"""Initiate Google login for the current user"""
|
||||
try:
|
||||
# Get Google interface with root context for initial setup
|
||||
google = googleInterface.getRootInterface()
|
||||
|
||||
# Get login URL
|
||||
auth_url = google.initiateLogin()
|
||||
if not auth_url:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to initiate Google login"
|
||||
)
|
||||
|
||||
logger.info("Redirecting to Google login")
|
||||
return RedirectResponse(auth_url)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error initiating Google login: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to initiate Google login: {str(e)}"
|
||||
)
|
||||
|
||||
@router.get("/auth/callback")
|
||||
async def auth_callback(code: str, state: str, request: Request):
|
||||
"""Handle Google OAuth callback"""
|
||||
try:
|
||||
# Get Google interface with root context for initial setup
|
||||
google = googleInterface.getRootInterface()
|
||||
|
||||
# Handle auth callback
|
||||
token_response = google.handleAuthCallback(code)
|
||||
if not token_response:
|
||||
return HTMLResponse(
|
||||
content="""
|
||||
<html>
|
||||
<head>
|
||||
<title>Authentication Failed</title>
|
||||
<style>
|
||||
body { font-family: Arial, sans-serif; text-align: center; margin-top: 50px; }
|
||||
.error { color: red; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1 class="error">Authentication Failed</h1>
|
||||
<p>Could not acquire access token.</p>
|
||||
<script>
|
||||
setTimeout(() => window.close(), 3000);
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
""",
|
||||
status_code=400
|
||||
)
|
||||
|
||||
# Get gateway interface for user operations
|
||||
gateway = gatewayInterface.getRootInterface()
|
||||
|
||||
# Check if user exists
|
||||
user = gateway.getUserByUsername(token_response.user_info["email"])
|
||||
|
||||
# If user doesn't exist, create a new user in the default mandate
|
||||
if not user:
|
||||
try:
|
||||
# Get the root mandate ID
|
||||
rootMandateId = gateway.getInitialId("mandates")
|
||||
if not rootMandateId:
|
||||
raise ValueError("Root mandate not found")
|
||||
|
||||
# Create new user with Google authentication
|
||||
user = gateway.createUser(
|
||||
username=token_response.user_info["email"],
|
||||
email=token_response.user_info["email"],
|
||||
fullName=token_response.user_info.get("name", token_response.user_info["email"]),
|
||||
mandateId=rootMandateId,
|
||||
authenticationAuthority="google"
|
||||
)
|
||||
logger.info(f"Created new user for Google account: {token_response.user_info['email']}")
|
||||
|
||||
# Verify user was created by retrieving it
|
||||
user = gateway.getUserByUsername(token_response.user_info["email"])
|
||||
if not user:
|
||||
raise ValueError("Failed to retrieve created user")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create user for Google account: {str(e)}")
|
||||
return HTMLResponse(
|
||||
content="""
|
||||
<html>
|
||||
<head>
|
||||
<title>Registration Failed</title>
|
||||
<style>
|
||||
body { font-family: Arial, sans-serif; text-align: center; margin-top: 50px; }
|
||||
.error { color: red; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1 class="error">Registration Failed</h1>
|
||||
<p>Could not create user account.</p>
|
||||
<script>
|
||||
setTimeout(() => window.close(), 3000);
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
""",
|
||||
status_code=400
|
||||
)
|
||||
|
||||
# Create backend token
|
||||
access_token_expires = timedelta(minutes=auth.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
access_token = auth.createAccessToken(
|
||||
data={
|
||||
"sub": user["username"],
|
||||
"mandateId": str(user["mandateId"]),
|
||||
"userId": str(user["id"]),
|
||||
"authenticationAuthority": "google"
|
||||
},
|
||||
expiresDelta=access_token_expires
|
||||
)
|
||||
|
||||
# Store tokens in session storage for the frontend to pick up
|
||||
response = HTMLResponse(
|
||||
content=f"""
|
||||
<html>
|
||||
<head>
|
||||
<title>Authentication Successful</title>
|
||||
<style>
|
||||
body {{ font-family: Arial, sans-serif; text-align: center; margin-top: 50px; }}
|
||||
.success {{ color: green; }}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1 class="success">Authentication Successful</h1>
|
||||
<p>Welcome, {token_response.user_info.get('name', 'User')}!</p>
|
||||
<p>This window will close automatically.</p>
|
||||
<script>
|
||||
// Store token data in session storage
|
||||
sessionStorage.setItem('google_token_data', JSON.stringify({json.dumps(token_response.model_dump())}));
|
||||
|
||||
// Notify parent window of success
|
||||
if (window.opener) {{
|
||||
window.opener.postMessage({{
|
||||
type: 'google_auth_success',
|
||||
user: {json.dumps(token_response.user_info)},
|
||||
token_data: {json.dumps(token_response.model_dump())},
|
||||
access_token: "{access_token}"
|
||||
}}, '*');
|
||||
}}
|
||||
// Close window after 3 seconds
|
||||
setTimeout(() => window.close(), 3000);
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in auth callback: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Authentication failed: {str(e)}"
|
||||
)
|
||||
|
||||
@router.get("/status", response_model=GoogleAuthStatus)
|
||||
async def auth_status(currentUser: Dict[str, Any] = Depends(auth.getCurrentActiveUser)):
|
||||
"""Check Google authentication status"""
|
||||
try:
|
||||
# For authenticated endpoints, use the current user's context
|
||||
google = googleInterface.getInterface(currentUser)
|
||||
|
||||
# Get current user token and info
|
||||
user_info, access_token = google.getCurrentUserToken()
|
||||
|
||||
if not user_info or not access_token:
|
||||
return GoogleAuthStatus(
|
||||
authenticated=False,
|
||||
message="Not authenticated with Google"
|
||||
)
|
||||
|
||||
# Convert user_info to GoogleUserInfo model
|
||||
user_info_model = GoogleUserInfo(**user_info)
|
||||
|
||||
return GoogleAuthStatus(
|
||||
authenticated=True,
|
||||
user=user_info_model
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking authentication status: {str(e)}")
|
||||
return GoogleAuthStatus(
|
||||
authenticated=False,
|
||||
message=f"Error checking authentication status: {str(e)}"
|
||||
)
|
||||
|
||||
@router.get("/token", response_model=GoogleTokenResponse)
|
||||
async def get_token(currentUser: Dict[str, Any] = Depends(auth.getCurrentActiveUser)):
|
||||
"""Get Google token for current user."""
|
||||
try:
|
||||
# For authenticated endpoints, use the current user's context
|
||||
google = googleInterface.getInterface(currentUser)
|
||||
|
||||
# Get token
|
||||
token_data = google.getGoogleToken()
|
||||
if not token_data:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="No token found"
|
||||
)
|
||||
|
||||
# Convert to GoogleToken model
|
||||
token = GoogleToken(**token_data)
|
||||
return GoogleTokenResponse(token=token)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting token: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=str(e)
|
||||
)
|
||||
|
||||
@router.post("/save-token", response_model=GoogleSaveTokenResponse)
|
||||
async def save_token(
|
||||
token_data: GoogleToken,
|
||||
currentUser: Dict[str, Any] = Depends(auth.getCurrentActiveUser)
|
||||
):
|
||||
"""Save Google token data from frontend"""
|
||||
try:
|
||||
# For authenticated endpoints, use the current user's context
|
||||
google = googleInterface.getInterface(currentUser)
|
||||
|
||||
# Save token
|
||||
success = google.saveGoogleToken(token_data.model_dump())
|
||||
|
||||
if success:
|
||||
return GoogleSaveTokenResponse(
|
||||
success=True,
|
||||
message="Token saved successfully",
|
||||
token=token_data
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to save token"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving token: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error saving token: {str(e)}"
|
||||
)
|
||||
|
||||
@router.post("/logout")
|
||||
async def logout(currentUser: Dict[str, Any] = Depends(auth.getCurrentActiveUser)):
|
||||
"""Logout from Google"""
|
||||
try:
|
||||
# For authenticated endpoints, use the current user's context
|
||||
google = googleInterface.getInterface(currentUser)
|
||||
|
||||
# Delete token
|
||||
success = google.deleteGoogleToken()
|
||||
|
||||
if success:
|
||||
return JSONResponse({
|
||||
"message": "Successfully logged out from Google"
|
||||
})
|
||||
else:
|
||||
return JSONResponse({
|
||||
"message": "Failed to logout from Google"
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during logout: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Logout failed: {str(e)}"
|
||||
)
|
||||
|
|
@ -1,3 +1,7 @@
|
|||
"""
|
||||
Routes for Microsoft authentication.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends, Request, Response, status, Cookie, Body
|
||||
from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse
|
||||
import logging
|
||||
|
|
@ -10,6 +14,7 @@ import modules.security.auth as auth
|
|||
|
||||
# Import interfaces
|
||||
import modules.interfaces.msftInterface as msftInterface
|
||||
import modules.interfaces.gatewayInterface as gatewayInterface
|
||||
from modules.interfaces.msftModel import (
|
||||
MsftToken,
|
||||
MsftUserInfo,
|
||||
|
|
@ -38,8 +43,8 @@ router = APIRouter(
|
|||
async def login():
|
||||
"""Initiate Microsoft login for the current user"""
|
||||
try:
|
||||
# Get Microsoft interface
|
||||
msft = msftInterface.getInterface({"_mandateId": "root", "id": "root"})
|
||||
# Get Microsoft interface with root context for initial setup
|
||||
msft = msftInterface.getRootInterface()
|
||||
|
||||
# Get login URL
|
||||
auth_url = msft.initiateLogin()
|
||||
|
|
@ -63,8 +68,8 @@ async def login():
|
|||
async def auth_callback(code: str, state: str, request: Request):
|
||||
"""Handle Microsoft OAuth callback"""
|
||||
try:
|
||||
# Get Microsoft interface
|
||||
msft = msftInterface.getInterface({"_mandateId": "root", "id": "root"})
|
||||
# Get Microsoft interface with root context for initial setup
|
||||
msft = msftInterface.getRootInterface()
|
||||
|
||||
# Handle auth callback
|
||||
token_response = msft.handleAuthCallback(code)
|
||||
|
|
@ -92,10 +97,10 @@ async def auth_callback(code: str, state: str, request: Request):
|
|||
)
|
||||
|
||||
# Get gateway interface for user operations
|
||||
gateway = auth.getRootInterface()
|
||||
gateway = gatewayInterface.getRootInterface()
|
||||
|
||||
# Check if user exists
|
||||
user = gateway.getUserByUsername(token_response["user_info"]["email"])
|
||||
user = gateway.getUserByUsername(token_response.user_info["email"])
|
||||
|
||||
# If user doesn't exist, create a new user in the default mandate
|
||||
if not user:
|
||||
|
|
@ -107,16 +112,16 @@ async def auth_callback(code: str, state: str, request: Request):
|
|||
|
||||
# Create new user with Microsoft authentication
|
||||
user = gateway.createUser(
|
||||
username=token_response["user_info"]["email"],
|
||||
email=token_response["user_info"]["email"],
|
||||
fullName=token_response["user_info"].get("name", token_response["user_info"]["email"]),
|
||||
_mandateId=rootMandateId,
|
||||
username=token_response.user_info["email"],
|
||||
email=token_response.user_info["email"],
|
||||
fullName=token_response.user_info.get("name", token_response.user_info["email"]),
|
||||
mandateId=rootMandateId,
|
||||
authenticationAuthority="microsoft"
|
||||
)
|
||||
logger.info(f"Created new user for Microsoft account: {token_response['user_info']['email']}")
|
||||
logger.info(f"Created new user for Microsoft account: {token_response.user_info['email']}")
|
||||
|
||||
# Verify user was created by retrieving it
|
||||
user = gateway.getUserByUsername(token_response["user_info"]["email"])
|
||||
user = gateway.getUserByUsername(token_response.user_info["email"])
|
||||
if not user:
|
||||
raise ValueError("Failed to retrieve created user")
|
||||
|
||||
|
|
@ -149,8 +154,8 @@ async def auth_callback(code: str, state: str, request: Request):
|
|||
access_token = auth.createAccessToken(
|
||||
data={
|
||||
"sub": user["username"],
|
||||
"_mandateId": str(user["_mandateId"]),
|
||||
"_userId": str(user["id"]),
|
||||
"mandateId": str(user["mandateId"]),
|
||||
"userId": str(user["id"]),
|
||||
"authenticationAuthority": "microsoft"
|
||||
},
|
||||
expiresDelta=access_token_expires
|
||||
|
|
@ -169,18 +174,18 @@ async def auth_callback(code: str, state: str, request: Request):
|
|||
</head>
|
||||
<body>
|
||||
<h1 class="success">Authentication Successful</h1>
|
||||
<p>Welcome, {token_response['user_info'].get('name', 'User')}!</p>
|
||||
<p>Welcome, {token_response.user_info.get('name', 'User')}!</p>
|
||||
<p>This window will close automatically.</p>
|
||||
<script>
|
||||
// Store token data in session storage
|
||||
sessionStorage.setItem('msft_token_data', JSON.stringify({json.dumps(token_response)}));
|
||||
sessionStorage.setItem('msft_token_data', JSON.stringify({json.dumps(token_response.model_dump())}));
|
||||
|
||||
// Notify parent window of success
|
||||
if (window.opener) {{
|
||||
window.opener.postMessage({{
|
||||
type: 'msft_auth_success',
|
||||
user: {json.dumps(token_response['user_info'])},
|
||||
token_data: {json.dumps(token_response)},
|
||||
user: {json.dumps(token_response.user_info)},
|
||||
token_data: {json.dumps(token_response.model_dump())},
|
||||
access_token: "{access_token}"
|
||||
}}, '*');
|
||||
}}
|
||||
|
|
@ -205,7 +210,7 @@ async def auth_callback(code: str, state: str, request: Request):
|
|||
async def auth_status(currentUser: Dict[str, Any] = Depends(auth.getCurrentActiveUser)):
|
||||
"""Check Microsoft authentication status"""
|
||||
try:
|
||||
# Get Microsoft interface
|
||||
# For authenticated endpoints, use the current user's context
|
||||
msft = msftInterface.getInterface(currentUser)
|
||||
|
||||
# Get current user token and info
|
||||
|
|
@ -236,17 +241,21 @@ async def auth_status(currentUser: Dict[str, Any] = Depends(auth.getCurrentActiv
|
|||
async def get_token(currentUser: Dict[str, Any] = Depends(auth.getCurrentActiveUser)):
|
||||
"""Get Microsoft token for current user."""
|
||||
try:
|
||||
# Get Microsoft interface
|
||||
# For authenticated endpoints, use the current user's context
|
||||
msft = msftInterface.getInterface(currentUser)
|
||||
|
||||
# Get token
|
||||
token_data = msft.getMsftToken()
|
||||
if token_data:
|
||||
return MsftTokenResponse(token=token_data)
|
||||
if not token_data:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="No token found"
|
||||
)
|
||||
|
||||
# Convert to MsftToken model
|
||||
token = MsftToken(**token_data)
|
||||
return MsftTokenResponse(token=token)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting token: {str(e)}")
|
||||
raise HTTPException(
|
||||
|
|
@ -261,7 +270,7 @@ async def save_token(
|
|||
):
|
||||
"""Save Microsoft token data from frontend"""
|
||||
try:
|
||||
# Get Microsoft interface
|
||||
# For authenticated endpoints, use the current user's context
|
||||
msft = msftInterface.getInterface(currentUser)
|
||||
|
||||
# Save token
|
||||
|
|
@ -290,11 +299,11 @@ async def save_token(
|
|||
async def logout(currentUser: Dict[str, Any] = Depends(auth.getCurrentActiveUser)):
|
||||
"""Logout from Microsoft"""
|
||||
try:
|
||||
# Get Microsoft interface
|
||||
# For authenticated endpoints, use the current user's context
|
||||
msft = msftInterface.getInterface(currentUser)
|
||||
|
||||
# Delete token
|
||||
success = msft.db.deleteToken(currentUser["id"])
|
||||
success = msft.deleteMsftToken()
|
||||
|
||||
if success:
|
||||
return JSONResponse({
|
||||
|
|
|
|||
|
|
@ -70,7 +70,7 @@ async def create_user(
|
|||
"""Create a new user"""
|
||||
try:
|
||||
# Get admin user for user creation
|
||||
interfaceRoot = auth.getRootInterface()
|
||||
interfaceRoot = gatewayInterface.getRootInterface()
|
||||
|
||||
try:
|
||||
# Convert User model to dict and pass to createUser
|
||||
|
|
@ -114,10 +114,10 @@ async def update_user(
|
|||
"""Update an existing user"""
|
||||
try:
|
||||
# Get admin user for user updates
|
||||
interfaceRoot = auth.getRootInterface()
|
||||
interfaceGateway = gatewayInterface.getInterface(currentUser)
|
||||
|
||||
# Check if user exists
|
||||
existingUser = interfaceRoot.getUserById(userId)
|
||||
existingUser = interfaceGateway.getUserById(userId)
|
||||
if not existingUser:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
|
|
@ -126,7 +126,7 @@ async def update_user(
|
|||
|
||||
# Update user data
|
||||
try:
|
||||
updatedUser = interfaceRoot.updateUser(userId, userData)
|
||||
updatedUser = interfaceGateway.updateUser(userId, userData)
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ import modules.security.auth as auth
|
|||
# Import interfaces
|
||||
import modules.interfaces.lucydomInterface as lucydomInterface
|
||||
import modules.interfaces.msftInterface as msftInterface
|
||||
import modules.interfaces.googleInterface as googleInterface
|
||||
|
||||
# Import workflow manager
|
||||
from modules.workflow.workflowManager import getWorkflowManager
|
||||
|
|
@ -44,6 +45,21 @@ router = APIRouter(
|
|||
responses={404: {"description": "Not found"}}
|
||||
)
|
||||
|
||||
def createServiceContainer(currentUser: Dict[str, Any]):
|
||||
"""Create a service container with all required interfaces."""
|
||||
# Get all interfaces
|
||||
interfaceBase = lucydomInterface.getInterface(currentUser)
|
||||
interfaceMsft = msftInterface.getInterface(currentUser)
|
||||
interfaceGoogle = googleInterface.getInterface(currentUser)
|
||||
|
||||
# Create service container
|
||||
service = type('ServiceContainer', (), {
|
||||
'base': interfaceBase,
|
||||
'msft': interfaceMsft,
|
||||
'google': interfaceGoogle
|
||||
})
|
||||
|
||||
return service
|
||||
|
||||
# API Endpoint for getting all workflows
|
||||
@router.get("", response_model=List[ChatWorkflow])
|
||||
|
|
@ -52,11 +68,11 @@ async def list_workflows(
|
|||
):
|
||||
"""List all workflows for the current user."""
|
||||
try:
|
||||
# Get interface with current user context
|
||||
interfaceLucydom = lucydomInterface.getInterface(currentUser)
|
||||
# Get service container
|
||||
service = createServiceContainer(currentUser)
|
||||
|
||||
# Retrieve workflows for the user
|
||||
workflows = interfaceLucydom.getWorkflowsByUser(currentUser["id"])
|
||||
workflows = service.base.getWorkflowsByUser(currentUser["id"])
|
||||
return [ChatWorkflow(**workflow) for workflow in workflows]
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing workflows: {str(e)}", exc_info=True)
|
||||
|
|
@ -77,24 +93,20 @@ async def start_workflow(
|
|||
Corresponds to State 1 in the state machine documentation.
|
||||
"""
|
||||
try:
|
||||
# Get interface with current user context
|
||||
interfaceLucydom = lucydomInterface.getInterface(currentUser)
|
||||
interfaceMsft = msftInterface.getInterface(currentUser)
|
||||
# Get service container
|
||||
service = createServiceContainer(currentUser)
|
||||
|
||||
# Get workflow manager with interface
|
||||
workflowManager = await getWorkflowManager(interfaceLucydom, interfaceMsft)
|
||||
# Get workflow manager
|
||||
workflowManager = await getWorkflowManager(service)
|
||||
|
||||
# Start or continue workflow using the workflow manager
|
||||
workflow = await workflowManager.workflowStart(userInput.dict(), workflowId)
|
||||
logger.info("User Input received. Answer:", workflow)
|
||||
# Start or continue workflow
|
||||
workflow = await workflowManager.workflowStart(userInput, workflowId)
|
||||
|
||||
return workflow
|
||||
|
||||
return ChatWorkflow(**workflow)
|
||||
except Exception as e:
|
||||
logger.error(f"Error starting workflow: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error starting workflow: {str(e)}"
|
||||
)
|
||||
logger.error(f"Error in start_workflow: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
# State 8: Workflow Stopped endpoint
|
||||
@router.post("/{workflowId}/stop", response_model=ChatWorkflow)
|
||||
|
|
@ -104,37 +116,20 @@ async def stop_workflow(
|
|||
):
|
||||
"""Stops a running workflow."""
|
||||
try:
|
||||
# Get interface with current user context
|
||||
interfaceLucydom = lucydomInterface.getInterface(currentUser)
|
||||
interfaceMsft = msftInterface.getInterface(currentUser)
|
||||
# Get service container
|
||||
service = createServiceContainer(currentUser)
|
||||
|
||||
# Verify workflow exists and belongs to user
|
||||
workflow = interfaceLucydom.getWorkflow(workflowId)
|
||||
if not workflow:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Workflow with ID {workflowId} not found"
|
||||
)
|
||||
# Get workflow manager
|
||||
workflowManager = await getWorkflowManager(service)
|
||||
|
||||
if workflow.get("_userId") != currentUser["id"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You don't have permission to stop this workflow"
|
||||
)
|
||||
# Stop workflow
|
||||
workflow = await workflowManager.workflowStop(workflowId)
|
||||
|
||||
# Stop the workflow
|
||||
workflowManager = await getWorkflowManager(interfaceLucydom, interfaceMsft)
|
||||
stoppedWorkflow = await workflowManager.workflowStop(workflowId)
|
||||
return workflow
|
||||
|
||||
return ChatWorkflow(**stoppedWorkflow)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping workflow: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error stopping workflow: {str(e)}"
|
||||
)
|
||||
logger.error(f"Error in stop_workflow: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
# State 11: Workflow Reset/Deletion endpoint
|
||||
@router.delete("/{workflowId}", response_model=Dict[str, Any])
|
||||
|
|
@ -144,11 +139,11 @@ async def delete_workflow(
|
|||
):
|
||||
"""Deletes a workflow and its associated data."""
|
||||
try:
|
||||
# Get interface with current user context
|
||||
interfaceLucydom = lucydomInterface.getInterface(currentUser)
|
||||
# Get service container
|
||||
service = createServiceContainer(currentUser)
|
||||
|
||||
# Verify workflow exists
|
||||
workflow = interfaceLucydom.getWorkflow(workflowId)
|
||||
workflow = service.base.getWorkflow(workflowId)
|
||||
if not workflow:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
|
|
@ -163,7 +158,7 @@ async def delete_workflow(
|
|||
)
|
||||
|
||||
# Delete workflow
|
||||
success = interfaceLucydom.deleteWorkflow(workflowId)
|
||||
success = service.base.deleteWorkflow(workflowId)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
|
|
@ -192,11 +187,11 @@ async def get_workflow_status(
|
|||
):
|
||||
"""Get the current status of a workflow."""
|
||||
try:
|
||||
# Get interface with current user context
|
||||
interfaceLucydom = lucydomInterface.getInterface(currentUser)
|
||||
# Get service container
|
||||
service = createServiceContainer(currentUser)
|
||||
|
||||
# Retrieve workflow
|
||||
workflow = interfaceLucydom.getWorkflow(workflowId)
|
||||
workflow = service.base.getWorkflow(workflowId)
|
||||
if not workflow:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
|
|
@ -222,11 +217,11 @@ async def get_workflow_logs(
|
|||
):
|
||||
"""Get logs for a workflow with support for selective data transfer."""
|
||||
try:
|
||||
# Get interface with current user context
|
||||
interfaceLucydom = lucydomInterface.getInterface(currentUser)
|
||||
# Get service container
|
||||
service = createServiceContainer(currentUser)
|
||||
|
||||
# Verify workflow exists
|
||||
workflow = interfaceLucydom.getWorkflow(workflowId)
|
||||
workflow = service.base.getWorkflow(workflowId)
|
||||
if not workflow:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
|
|
@ -234,7 +229,7 @@ async def get_workflow_logs(
|
|||
)
|
||||
|
||||
# Get all logs
|
||||
allLogs = interfaceLucydom.getWorkflowLogs(workflowId)
|
||||
allLogs = service.base.getWorkflowLogs(workflowId)
|
||||
|
||||
# Apply selective data transfer if logId is provided
|
||||
if logId:
|
||||
|
|
@ -263,11 +258,11 @@ async def get_workflow_messages(
|
|||
):
|
||||
"""Get messages for a workflow with support for selective data transfer."""
|
||||
try:
|
||||
# Get admin user for workflow operations
|
||||
interfaceLucydom = lucydomInterface.getInterface(currentUser)
|
||||
# Get service container
|
||||
service = createServiceContainer(currentUser)
|
||||
|
||||
# Verify workflow exists
|
||||
workflow = interfaceLucydom.getWorkflow(workflowId)
|
||||
workflow = service.base.getWorkflow(workflowId)
|
||||
if not workflow:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
|
|
@ -275,7 +270,7 @@ async def get_workflow_messages(
|
|||
)
|
||||
|
||||
# Get all messages
|
||||
allMessages = interfaceLucydom.getWorkflowMessages(workflowId)
|
||||
allMessages = service.base.getWorkflowMessages(workflowId)
|
||||
|
||||
# Apply selective data transfer if messageId is provided
|
||||
if messageId:
|
||||
|
|
@ -313,11 +308,11 @@ async def delete_workflow_message(
|
|||
):
|
||||
"""Delete a message from a workflow."""
|
||||
try:
|
||||
# Get admin user for workflow operations
|
||||
interfaceLucydom = lucydomInterface.getInterface(currentUser)
|
||||
# Get service container
|
||||
service = createServiceContainer(currentUser)
|
||||
|
||||
# Verify workflow exists and belongs to user
|
||||
workflow = interfaceLucydom.getWorkflow(workflowId)
|
||||
workflow = service.base.getWorkflow(workflowId)
|
||||
if not workflow:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
|
|
@ -325,7 +320,7 @@ async def delete_workflow_message(
|
|||
)
|
||||
|
||||
# Delete the message
|
||||
success = interfaceLucydom.deleteWorkflowMessage(workflowId, messageId)
|
||||
success = service.base.deleteWorkflowMessage(workflowId, messageId)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
|
|
@ -337,7 +332,7 @@ async def delete_workflow_message(
|
|||
messageIds = workflow.get("messageIds", [])
|
||||
if messageId in messageIds:
|
||||
messageIds.remove(messageId)
|
||||
interfaceLucydom.updateWorkflow(workflowId, {"messageIds": messageIds})
|
||||
service.base.updateWorkflow(workflowId, {"messageIds": messageIds})
|
||||
|
||||
return {
|
||||
"workflowId": workflowId,
|
||||
|
|
@ -362,11 +357,11 @@ async def delete_file_from_message(
|
|||
):
|
||||
"""Delete a file reference from a message in a workflow."""
|
||||
try:
|
||||
# Get admin user for workflow operations
|
||||
interfaceLucydom = lucydomInterface.getInterface(currentUser)
|
||||
# Get service container
|
||||
service = createServiceContainer(currentUser)
|
||||
|
||||
# Verify workflow exists and belongs to user
|
||||
workflow = interfaceLucydom.getWorkflow(workflowId)
|
||||
workflow = service.base.getWorkflow(workflowId)
|
||||
if not workflow:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
|
|
@ -374,7 +369,7 @@ async def delete_file_from_message(
|
|||
)
|
||||
|
||||
# Delete file reference from message
|
||||
success = interfaceLucydom.deleteFileFromMessage(workflowId, messageId, fileId)
|
||||
success = service.base.deleteFileFromMessage(workflowId, messageId, fileId)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
|
|
@ -406,11 +401,11 @@ async def preview_file(
|
|||
):
|
||||
"""Get file metadata and a preview of the file content."""
|
||||
try:
|
||||
# Get admin user for workflow operations
|
||||
interfaceLucydom = lucydomInterface.getInterface(currentUser)
|
||||
# Get service container
|
||||
service = createServiceContainer(currentUser)
|
||||
|
||||
# Get file metadata
|
||||
file = interfaceLucydom.getFile(fileId)
|
||||
file = service.base.getFile(fileId)
|
||||
if not file:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
|
|
@ -418,7 +413,7 @@ async def preview_file(
|
|||
)
|
||||
|
||||
# Get file data (limited for preview)
|
||||
fileData = interfaceLucydom.getFileData(fileId)
|
||||
fileData = service.base.getFileData(fileId)
|
||||
if fileData is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
|
|
@ -436,7 +431,7 @@ async def preview_file(
|
|||
previewData = None
|
||||
|
||||
# Get base64Encoded flag from database
|
||||
fileDataEntries = interfaceLucydom.db.getRecordset("fileData", recordFilter={"id": fileId})
|
||||
fileDataEntries = service.base.db.getRecordset("fileData", recordFilter={"id": fileId})
|
||||
if fileDataEntries and "base64Encoded" in fileDataEntries[0]:
|
||||
# Use the flag from the database
|
||||
base64Encoded = fileDataEntries[0]["base64Encoded"]
|
||||
|
|
@ -500,11 +495,11 @@ async def download_file(
|
|||
):
|
||||
"""Download a file."""
|
||||
try:
|
||||
# Get admin user for workflow operations
|
||||
interfaceLucydom = lucydomInterface.getInterface(currentUser)
|
||||
# Get service container
|
||||
service = createServiceContainer(currentUser)
|
||||
|
||||
# Get file data
|
||||
fileInfo = interfaceLucydom.downloadFile(fileId)
|
||||
fileInfo = service.base.downloadFile(fileId)
|
||||
if not fileInfo:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ from fastapi.security import OAuth2PasswordBearer
|
|||
from jose import JWTError, jwt
|
||||
import logging
|
||||
|
||||
from modules.interfaces.gatewayInterface import getInterface
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
|
||||
# Get Config Data
|
||||
|
|
@ -126,27 +125,3 @@ def getCurrentActiveUser(currentUser: Dict[str, Any] = Depends(_getCurrentUser))
|
|||
)
|
||||
|
||||
return currentUser
|
||||
|
||||
def getRootInterface() -> Dict[str, Any]:
|
||||
try:
|
||||
# Get the initial user ID from the database
|
||||
gateway = getInterface() # Initialize without user context
|
||||
initialUserId = gateway.getInitialId("users")
|
||||
|
||||
if not initialUserId:
|
||||
raise ValueError("No initial user ID found in database")
|
||||
|
||||
# Get the actual user record
|
||||
gateway.setUserContext(initialUserId)
|
||||
rootUser = gateway.getUser(initialUserId)
|
||||
if not rootUser:
|
||||
raise ValueError(f"Root user with ID {initialUserId} not found in database")
|
||||
|
||||
return getInterface(currentUser=rootUser)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting root access: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to get root access: {str(e)}"
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ import uuid
|
|||
from datetime import datetime
|
||||
from typing import Dict, Any, List, Optional
|
||||
from modules.shared.mimeUtils import isTextMimeType, determineContentEncoding
|
||||
from modules.interfaces.lucydomModel import ChatContent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -29,15 +30,47 @@ class AgentBase:
|
|||
self.service = None
|
||||
|
||||
def setWorkflowManager(self, workflowManager):
|
||||
"""Set the workflow manager reference."""
|
||||
"""
|
||||
Set the workflow manager reference and validate service container.
|
||||
|
||||
Args:
|
||||
workflowManager: The workflow manager instance
|
||||
"""
|
||||
if not workflowManager:
|
||||
logger.warning("Attempted to set null workflow manager")
|
||||
return False
|
||||
|
||||
self.workflowManager = workflowManager
|
||||
# Also set service reference from workflow manager
|
||||
if workflowManager and hasattr(workflowManager, 'service'):
|
||||
self.service = workflowManager.service
|
||||
|
||||
# Set service reference from workflow manager if available
|
||||
if hasattr(workflowManager, 'service'):
|
||||
return self.setService(workflowManager.service)
|
||||
return False
|
||||
|
||||
def setService(self, service):
|
||||
"""Set the service container reference."""
|
||||
"""
|
||||
Set the service container reference and validate required interfaces.
|
||||
|
||||
Args:
|
||||
service: The service container with required interfaces
|
||||
"""
|
||||
if not service:
|
||||
logger.warning("Attempted to set null service container")
|
||||
return False
|
||||
|
||||
# Validate required interfaces
|
||||
required_interfaces = ['base', 'msft', 'google']
|
||||
missing_interfaces = []
|
||||
for interface in required_interfaces:
|
||||
if not hasattr(service, interface):
|
||||
missing_interfaces.append(interface)
|
||||
|
||||
if missing_interfaces:
|
||||
logger.warning(f"Service container missing required interfaces: {', '.join(missing_interfaces)}")
|
||||
return False
|
||||
|
||||
self.service = service
|
||||
return True
|
||||
|
||||
def getAgentInfo(self) -> Dict[str, Any]:
|
||||
"""
|
||||
|
|
@ -48,6 +81,7 @@ class AgentBase:
|
|||
"""
|
||||
return {
|
||||
"name": self.name,
|
||||
"label": self.label,
|
||||
"description": self.description,
|
||||
"capabilities": self.capabilities
|
||||
}
|
||||
|
|
@ -77,6 +111,21 @@ class AgentBase:
|
|||
- documents: List of document objects created by the agent,
|
||||
each containing a "base64Encoded" flag in addition to "label" and "content"
|
||||
"""
|
||||
# Validate service and workflow manager
|
||||
if not self.service:
|
||||
logger.error("Service container not initialized")
|
||||
return {
|
||||
"feedback": "Error: Service container not initialized",
|
||||
"documents": []
|
||||
}
|
||||
|
||||
if not self.workflowManager:
|
||||
logger.error("Workflow manager not initialized")
|
||||
return {
|
||||
"feedback": "Error: Workflow manager not initialized",
|
||||
"documents": []
|
||||
}
|
||||
|
||||
# Base implementation - should be overridden by specialized agents
|
||||
logger.warning(f"Agent {self.name} is using the default implementation of processTask")
|
||||
return {
|
||||
|
|
@ -85,51 +134,53 @@ class AgentBase:
|
|||
}
|
||||
|
||||
def determineBase64EncodingFlag(self, filename: str, content: Any, mimeType: str = None) -> bool:
|
||||
"""Wrapper for the utility function"""
|
||||
"""
|
||||
Determine if content should be base64 encoded.
|
||||
|
||||
Args:
|
||||
filename: Name of the file
|
||||
content: Content to check
|
||||
mimeType: Optional MIME type
|
||||
|
||||
Returns:
|
||||
Boolean indicating if content should be base64 encoded
|
||||
"""
|
||||
return determineContentEncoding(filename, content, mimeType)
|
||||
|
||||
def isTextMimeType(self, mimeType: str) -> bool:
|
||||
"""Wrapper for the utility function"""
|
||||
return isTextMimeType(mimeType)
|
||||
|
||||
def formatAgentDocumentOutput(self, label: str, content: Any, mimeType: str = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Format agent output as a document.
|
||||
Check if MIME type is text-based.
|
||||
|
||||
Args:
|
||||
label: Label for the document
|
||||
content: Content of the document
|
||||
mimeType: Optional MIME type for the document
|
||||
mimeType: MIME type to check
|
||||
|
||||
Returns:
|
||||
Boolean indicating if MIME type is text-based
|
||||
"""
|
||||
# Create document structure
|
||||
doc = {
|
||||
"id": str(uuid.uuid4()),
|
||||
"name": label,
|
||||
"ext": "txt", # Default extension
|
||||
"data": content,
|
||||
"base64Encoded": False,
|
||||
"metadata": {
|
||||
"isText": True
|
||||
}
|
||||
}
|
||||
return isTextMimeType(mimeType)
|
||||
|
||||
# Set MIME type if provided
|
||||
if mimeType:
|
||||
doc["mimeType"] = mimeType
|
||||
# Update extension based on MIME type
|
||||
if mimeType == "text/markdown":
|
||||
doc["ext"] = "md"
|
||||
elif mimeType == "text/html":
|
||||
doc["ext"] = "html"
|
||||
elif mimeType == "text/csv":
|
||||
doc["ext"] = "csv"
|
||||
elif mimeType == "application/json":
|
||||
doc["ext"] = "json"
|
||||
elif mimeType.startswith("image/"):
|
||||
doc["ext"] = mimeType.split("/")[1]
|
||||
doc["metadata"]["isText"] = False
|
||||
elif mimeType == "application/pdf":
|
||||
doc["ext"] = "pdf"
|
||||
doc["metadata"]["isText"] = False
|
||||
def formatAgentDocumentOutput(self, label: str, content: str, contentType: str, base64Encoded: bool = False) -> ChatContent:
|
||||
"""
|
||||
Format agent document output using ChatContent model.
|
||||
|
||||
return doc
|
||||
Args:
|
||||
label: Document label/filename
|
||||
content: Document content
|
||||
contentType: MIME type of content
|
||||
base64Encoded: Whether content is base64 encoded
|
||||
|
||||
Returns:
|
||||
ChatContent object with the following attributes:
|
||||
- sequenceNr: Sequence number (defaults to 1)
|
||||
- name: Document label/filename
|
||||
- mimeType: MIME type of content
|
||||
- data: Actual content
|
||||
- metadata: Additional metadata including base64Encoded flag
|
||||
"""
|
||||
return ChatContent(
|
||||
sequenceNr=1,
|
||||
name=label,
|
||||
mimeType=contentType,
|
||||
data=content,
|
||||
metadata={"base64Encoded": base64Encoded}
|
||||
)
|
||||
|
|
@ -1,7 +1,5 @@
|
|||
"""
|
||||
Agent Registry Module.
|
||||
Provides a central registry system for all available agents.
|
||||
Optimized for the standardized task processing pattern.
|
||||
Agent Registry Module for managing and initializing agents.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
|
@ -34,12 +32,27 @@ class AgentRegistry:
|
|||
|
||||
def initialize(self, service=None, workflowManager=None):
|
||||
"""Initialize or update the registry with workflow manager and service references."""
|
||||
if service:
|
||||
# Validate required interfaces
|
||||
required_interfaces = ['base', 'msft', 'google']
|
||||
missing_interfaces = []
|
||||
for interface in required_interfaces:
|
||||
if not hasattr(service, interface):
|
||||
missing_interfaces.append(interface)
|
||||
|
||||
if missing_interfaces:
|
||||
logger.warning(f"Service container missing required interfaces: {', '.join(missing_interfaces)}")
|
||||
return False
|
||||
|
||||
# Initialize agents with service and workflow manager
|
||||
for agent in self.agents.values():
|
||||
if workflowManager and hasattr(agent, 'setWorkflowManager'):
|
||||
agent.setWorkflowManager(workflowManager)
|
||||
if service and hasattr(agent, 'setService'):
|
||||
agent.setService(service)
|
||||
|
||||
return True
|
||||
|
||||
def _loadAgents(self):
|
||||
"""Load all available agents from modules."""
|
||||
logger.info("Loading agent modules...")
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import os
|
|||
import io
|
||||
from typing import Dict, Any, List, Optional, Union, Tuple
|
||||
import base64
|
||||
from modules.interfaces.lucydomModel import ChatContent
|
||||
|
||||
# Configure logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -21,7 +22,7 @@ class FileProcessingError(Exception):
|
|||
"""Custom exception for file processing errors."""
|
||||
pass
|
||||
|
||||
def getDocumentContents(fileMetadata: Dict[str, Any], fileContent: bytes) -> List[Dict[str, Any]]:
|
||||
def getDocumentContents(fileMetadata: Dict[str, Any], fileContent: bytes) -> List[ChatContent]:
|
||||
"""
|
||||
Main function for extracting content from a file based on its MIME type.
|
||||
Delegates to specialized extraction functions.
|
||||
|
|
@ -31,7 +32,7 @@ def getDocumentContents(fileMetadata: Dict[str, Any], fileContent: bytes) -> Lis
|
|||
fileContent: Binary data of the file
|
||||
|
||||
Returns:
|
||||
List of Document-Content objects with metadata and base64Encoded flag
|
||||
List of ChatContent objects with metadata and base64Encoded flag
|
||||
"""
|
||||
try:
|
||||
mimeType = fileMetadata.get("mimeType", "application/octet-stream")
|
||||
|
|
@ -142,36 +143,35 @@ def getDocumentContents(fileMetadata: Dict[str, Any], fileContent: bytes) -> Lis
|
|||
# Convert binary content to base64
|
||||
encoded_data = base64.b64encode(fileContent).decode('utf-8')
|
||||
|
||||
contents.append({
|
||||
"sequenceNr": 1,
|
||||
"name": '1_undefined',
|
||||
"ext": os.path.splitext(fileName)[1][1:] if os.path.splitext(fileName)[1] else "bin",
|
||||
"mimeType": mimeType,
|
||||
"data": encoded_data,
|
||||
"base64Encoded": True,
|
||||
"metadata": {
|
||||
"isText": False
|
||||
contents.append(ChatContent(
|
||||
sequenceNr=1,
|
||||
name='1_undefined',
|
||||
mimeType=mimeType,
|
||||
data=encoded_data,
|
||||
metadata={
|
||||
"isText": False,
|
||||
"base64Encoded": True
|
||||
}
|
||||
})
|
||||
))
|
||||
|
||||
# Add generic attributes for all documents
|
||||
for content in contents:
|
||||
# Make sure all content items have the base64Encoded flag
|
||||
if "base64Encoded" not in content:
|
||||
if isinstance(content.get("data"), bytes):
|
||||
if not hasattr(content, "base64Encoded"):
|
||||
if isinstance(content.data, bytes):
|
||||
# Convert bytes to base64
|
||||
content["data"] = base64.b64encode(content["data"]).decode('utf-8')
|
||||
content["base64Encoded"] = True
|
||||
content.data = base64.b64encode(content.data).decode('utf-8')
|
||||
content.base64Encoded = True
|
||||
else:
|
||||
# Assume text content if not explicitly marked
|
||||
content["base64Encoded"] = False
|
||||
content.base64Encoded = False
|
||||
|
||||
# Maintain backward compatibility with old "base64Encoded" flag in metadata
|
||||
if "metadata" not in content:
|
||||
content["metadata"] = {}
|
||||
if not content.metadata:
|
||||
content.metadata = {}
|
||||
|
||||
# Set base64Encoded in metadata for backward compatibility
|
||||
content["metadata"]["base64Encoded"] = content["base64Encoded"]
|
||||
content.metadata["base64Encoded"] = content.base64Encoded
|
||||
|
||||
logger.info(f"Successfully extracted {len(contents)} content items from file '{fileName}'")
|
||||
return contents
|
||||
|
|
@ -179,18 +179,16 @@ def getDocumentContents(fileMetadata: Dict[str, Any], fileContent: bytes) -> Lis
|
|||
except Exception as e:
|
||||
logger.error(f"Error during content extraction for file {fileMetadata.get('name', 'unknown')}: {str(e)}", exc_info=True)
|
||||
# Fallback on error - return original data
|
||||
return [{
|
||||
"sequenceNr": 1,
|
||||
"name": fileMetadata.get("name", "unknown"),
|
||||
"ext": os.path.splitext(fileMetadata.get("name", ""))[1][1:] if os.path.splitext(fileMetadata.get("name", ""))[1] else "bin",
|
||||
"mimeType": fileMetadata.get("mimeType", "application/octet-stream"),
|
||||
"data": base64.b64encode(fileContent).decode('utf-8'),
|
||||
"base64Encoded": True,
|
||||
"metadata": {
|
||||
return [ChatContent(
|
||||
sequenceNr=1,
|
||||
name=fileMetadata.get("name", "unknown"),
|
||||
mimeType=fileMetadata.get("mimeType", "application/octet-stream"),
|
||||
data=base64.b64encode(fileContent).decode('utf-8'),
|
||||
metadata={
|
||||
"isText": False,
|
||||
"base64Encoded": True # For backward compatibility
|
||||
"base64Encoded": True
|
||||
}
|
||||
}]
|
||||
)]
|
||||
|
||||
|
||||
def _loadPdfExtractor():
|
||||
|
|
@ -979,131 +977,7 @@ def extractBinaryContent(fileName: str, fileContent: bytes, mimeType: str) -> Li
|
|||
}
|
||||
}]
|
||||
|
||||
def processFile(self, fileContent: bytes, fileName: str, fileMetadata: Dict[str, Any] = None) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Process a file and return its contents as a list of documents.
|
||||
|
||||
Args:
|
||||
fileContent: Binary content of the file
|
||||
fileName: Name of the file
|
||||
fileMetadata: Optional metadata about the file
|
||||
|
||||
Returns:
|
||||
List of document dictionaries
|
||||
"""
|
||||
try:
|
||||
# Get file extension and MIME type
|
||||
fileExtension = os.path.splitext(fileName)[1].lower()[1:]
|
||||
mimeType = fileMetadata.get("mimeType", self.serviceBase.getMimeType(fileName)) if fileMetadata else self.serviceBase.getMimeType(fileName)
|
||||
|
||||
# Process based on file type
|
||||
if mimeType.startswith("image/"):
|
||||
return self._processImageFile(fileContent, fileName, fileExtension, mimeType, fileMetadata)
|
||||
elif mimeType == "application/pdf":
|
||||
return self._processPdfFile(fileContent, fileName, fileMetadata)
|
||||
elif mimeType == "text/csv":
|
||||
return self._processCsvFile(fileContent, fileName, fileMetadata)
|
||||
elif mimeType == "text/plain":
|
||||
return self._processTextFile(fileContent, fileName, fileMetadata)
|
||||
else:
|
||||
# Default binary file handling
|
||||
return [{
|
||||
"name": fileName,
|
||||
"ext": fileExtension,
|
||||
"mimeType": mimeType,
|
||||
"data": base64.b64encode(fileContent).decode('utf-8'),
|
||||
"base64Encoded": True,
|
||||
"metadata": {
|
||||
"isText": False
|
||||
}
|
||||
}]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing file {fileName}: {str(e)}")
|
||||
raise FileProcessingError(f"Error processing file: {str(e)}")
|
||||
|
||||
def _processImageFile(self, fileContent: bytes, fileName: str, fileExtension: str, mimeType: str, fileMetadata: Dict[str, Any] = None) -> List[Dict[str, Any]]:
|
||||
"""Process an image file."""
|
||||
try:
|
||||
# Create image document
|
||||
imageDoc = {
|
||||
"name": fileName,
|
||||
"ext": fileExtension,
|
||||
"mimeType": mimeType,
|
||||
"data": base64.b64encode(fileContent).decode('utf-8'),
|
||||
"base64Encoded": True,
|
||||
"metadata": {
|
||||
"isText": False,
|
||||
"isImage": True,
|
||||
"format": fileExtension
|
||||
}
|
||||
}
|
||||
|
||||
# Add image description if available
|
||||
if fileMetadata and "description" in fileMetadata:
|
||||
imageDoc["metadata"]["description"] = fileMetadata["description"]
|
||||
|
||||
return [imageDoc]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing image file {fileName}: {str(e)}")
|
||||
raise FileProcessingError(f"Error processing image file: {str(e)}")
|
||||
|
||||
def _processPdfFile(self, fileContent: bytes, fileName: str, fileMetadata: Dict[str, Any] = None) -> List[Dict[str, Any]]:
|
||||
"""Process a PDF file."""
|
||||
try:
|
||||
# Create PDF document
|
||||
pdfDoc = {
|
||||
"name": fileName,
|
||||
"ext": "pdf",
|
||||
"mimeType": "application/pdf",
|
||||
"data": base64.b64encode(fileContent).decode('utf-8'),
|
||||
"base64Encoded": True,
|
||||
"metadata": {
|
||||
"isText": False,
|
||||
"isPdf": True
|
||||
}
|
||||
}
|
||||
|
||||
return [pdfDoc]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing PDF file {fileName}: {str(e)}")
|
||||
raise FileProcessingError(f"Error processing PDF file: {str(e)}")
|
||||
|
||||
def _processCsvFile(self, fileContent: bytes, fileName: str, fileMetadata: Dict[str, Any] = None) -> List[Dict[str, Any]]:
|
||||
"""Process a CSV file."""
|
||||
try:
|
||||
# Try to decode as text first
|
||||
try:
|
||||
csvContent = fileContent.decode('utf-8')
|
||||
base64Encoded = False
|
||||
except UnicodeDecodeError:
|
||||
# If not valid UTF-8, encode as base64
|
||||
csvContent = base64.b64encode(fileContent).decode('utf-8')
|
||||
base64Encoded = True
|
||||
|
||||
# Create CSV document
|
||||
csvDoc = {
|
||||
"name": fileName,
|
||||
"ext": "csv",
|
||||
"mimeType": "text/csv",
|
||||
"data": csvContent,
|
||||
"base64Encoded": base64Encoded,
|
||||
"metadata": {
|
||||
"isText": True,
|
||||
"isCsv": True,
|
||||
"format": "csv"
|
||||
}
|
||||
}
|
||||
|
||||
return [csvDoc]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing CSV file {fileName}: {str(e)}")
|
||||
raise FileProcessingError(f"Error processing CSV file: {str(e)}")
|
||||
|
||||
def _processTextFile(self, fileContent: bytes, fileName: str, fileMetadata: Dict[str, Any] = None) -> List[Dict[str, Any]]:
|
||||
"""Process a text file."""
|
||||
try:
|
||||
# Try to decode as text
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from modules.shared.mimeUtils import isTextMimeType
|
|||
# Required imports
|
||||
from modules.workflow.agentRegistry import getAgentRegistry
|
||||
from modules.workflow.documentProcessor import getDocumentContents
|
||||
from modules.interfaces.lucydomInterface import UserInputRequest
|
||||
from modules.interfaces.lucydomModel import UserInputRequest, ChatWorkflow, ChatMessage, ChatLog
|
||||
|
||||
# Configure logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -41,16 +41,13 @@ class WorkflowStoppedException(Exception):
|
|||
class WorkflowManager:
|
||||
"""Manages the execution of workflows and their associated agents."""
|
||||
|
||||
def __init__(self, interfaceBase, interfaceMsft):
|
||||
"""Initialize the workflow manager with interface."""
|
||||
# Create service container
|
||||
self.service = type('ServiceContainer', (), {
|
||||
'base': interfaceBase,
|
||||
'msft': interfaceMsft
|
||||
})
|
||||
def __init__(self, service):
|
||||
"""Initialize the workflow manager with service container."""
|
||||
# Store service container
|
||||
self.service = service
|
||||
|
||||
self._mandateId = interfaceBase._mandateId
|
||||
self._userId = interfaceBase._userId
|
||||
self._mandateId = service.base._mandateId
|
||||
self._userId = service.base._userId
|
||||
self.agentRegistry = getAgentRegistry()
|
||||
self.agentRegistry.initialize(service=self.service, workflowManager=self)
|
||||
|
||||
|
|
@ -90,7 +87,7 @@ class WorkflowManager:
|
|||
|
||||
### Workflow State Machine Implementation
|
||||
|
||||
async def workflowStart(self, userInput: UserInputRequest, workflowId: Optional[str] = None) -> Dict[str, Any]:
|
||||
async def workflowStart(self, userInput: UserInputRequest, workflowId: Optional[str] = None) -> ChatWorkflow:
|
||||
"""Starts a new workflow or continues an existing one."""
|
||||
try:
|
||||
# Convert UserInputRequest to dict for processing
|
||||
|
|
@ -113,59 +110,38 @@ class WorkflowManager:
|
|||
|
||||
### Forces exit
|
||||
|
||||
def checkExitCriteria(self, workflow: Dict[str, Any]):
|
||||
current_workflow = self.service.base.loadWorkflowState(workflow["id"])
|
||||
if current_workflow["status"] in ["stopped", "failed"]:
|
||||
self.logAdd(workflow, f"Workflow processing terminated due to status: {current_workflow['status']}", level="info")
|
||||
# Raise an exception to stop execution
|
||||
raise WorkflowStoppedException(f"Workflow execution stopped due to status: {current_workflow['status']}")
|
||||
|
||||
async def workflowStop(self, workflowId: str) -> Dict[str, Any]:
|
||||
def checkExitCriteria(self, workflow: ChatWorkflow) -> None:
|
||||
"""
|
||||
Stops a running workflow (State 8: Workflow Stopped).
|
||||
Sets status to "stopped" and adds a log entry.
|
||||
Check if the workflow should exit based on the current state.
|
||||
Raises WorkflowStoppedException if workflow should stop.
|
||||
|
||||
Args:
|
||||
workflowId: ID of the workflow to stop
|
||||
|
||||
Returns:
|
||||
Updated workflow with status="stopped"
|
||||
workflow: ChatWorkflow object to check
|
||||
"""
|
||||
workflow = self.service.base.loadWorkflowState(workflowId)
|
||||
if not workflow:
|
||||
return {"error": "Workflow not found", "status": "failed"}
|
||||
current_workflow = self.service.base.loadWorkflowState(workflow.id)
|
||||
if current_workflow["state"] in ["stopped", "failed"]:
|
||||
self.logAdd(workflow, f"Workflow processing terminated due to state: {current_workflow['state']}", level="info")
|
||||
# Raise an exception to stop execution
|
||||
raise WorkflowStoppedException(f"Workflow execution stopped due to state: {current_workflow['state']}")
|
||||
|
||||
# Update status to stopped
|
||||
workflow["status"] = "stopped"
|
||||
workflow["lastActivity"] = datetime.now().isoformat()
|
||||
|
||||
# Update in database
|
||||
self.service.base.updateWorkflow(workflowId, {
|
||||
"status": workflow["status"],
|
||||
"lastActivity": workflow["lastActivity"]
|
||||
})
|
||||
|
||||
self.logAdd(workflow, GLOBAL_WORKFLOW_LABELS["workflowStatusMessages"]["stopped"], level="info", progress=100)
|
||||
return workflow
|
||||
|
||||
async def workflowProcess(self, userInput: Dict[str, Any], workflow: Dict[str, Any]) -> Dict[str, Any]:
|
||||
async def workflowProcess(self, userInput: Dict[str, Any], workflow: ChatWorkflow) -> ChatWorkflow:
|
||||
"""
|
||||
Main processing function that implements the workflow state machine.
|
||||
Handles the complete workflow process from user input to final response.
|
||||
|
||||
Args:
|
||||
userInput: User input with prompt and optional file list
|
||||
workflow: Current workflow object
|
||||
workflow: Current ChatWorkflow object
|
||||
|
||||
Returns:
|
||||
Updated workflow with processing results
|
||||
Updated ChatWorkflow object with processing results
|
||||
"""
|
||||
startTime = time.time()
|
||||
try:
|
||||
# State 3: User Message Processing
|
||||
self.checkExitCriteria(workflow)
|
||||
messageUser = await self.chatMessageToWorkflow("user", None, userInput, workflow)
|
||||
messageUser["status"] = "first" # For first message
|
||||
messageUser.status = "first" # For first message
|
||||
|
||||
# State 4: Project Manager Analysis
|
||||
self.checkExitCriteria(workflow)
|
||||
|
|
@ -182,13 +158,13 @@ class WorkflowManager:
|
|||
|
||||
# Save the response as a message in the workflow and add log entries
|
||||
self.checkExitCriteria(workflow)
|
||||
responseMessage = {
|
||||
"role": "assistant",
|
||||
"agentName": "Project Manager",
|
||||
"content": objUserResponse,
|
||||
"status": "step" # As per state machine specification
|
||||
}
|
||||
self.messageAdd(workflow, responseMessage)
|
||||
responseMessage = ChatMessage(
|
||||
role="assistant",
|
||||
agentName="Project Manager",
|
||||
content=objUserResponse,
|
||||
status="step" # As per state machine specification
|
||||
)
|
||||
self.messageAdd(workflow, responseMessage.model_dump())
|
||||
|
||||
# Add detailed log entry about the task plan
|
||||
taskPlanLog = "Input: "
|
||||
|
|
@ -255,8 +231,8 @@ class WorkflowManager:
|
|||
self.checkExitCriteria(workflow)
|
||||
self.logAdd(workflow, "Creating final response", level="info", progress=90)
|
||||
finalMessage = await self.generateFinalMessage(objUserResponse, objFinalDocuments, objResults)
|
||||
finalMessage["status"] = "last" # As per state machine specification
|
||||
self.messageAdd(workflow, finalMessage)
|
||||
finalMessage.status = "last" # As per state machine specification
|
||||
self.messageAdd(workflow, finalMessage.model_dump())
|
||||
|
||||
# State 7: Workflow Completion
|
||||
self.checkExitCriteria(workflow)
|
||||
|
|
@ -264,31 +240,31 @@ class WorkflowManager:
|
|||
|
||||
# Update processing time
|
||||
endTime = time.time()
|
||||
workflow["dataStats"]["processingTime"] = endTime - startTime
|
||||
workflow.dataStats["processingTime"] = endTime - startTime
|
||||
|
||||
return workflow
|
||||
|
||||
except Exception as e:
|
||||
# State 2: Workflow Exception
|
||||
logger.error(f"Workflow processing error: {str(e)}", exc_info=True)
|
||||
workflow["status"] = "failed"
|
||||
workflow["lastActivity"] = datetime.now().isoformat()
|
||||
workflow.state = "failed"
|
||||
workflow.lastActivity = datetime.now().isoformat()
|
||||
|
||||
# Update processing time even on error
|
||||
endTime = time.time()
|
||||
workflow["dataStats"]["processingTime"] = endTime - startTime
|
||||
workflow.dataStats["processingTime"] = endTime - startTime
|
||||
|
||||
# Update in database
|
||||
self.service.base.updateWorkflow(workflow["id"], {
|
||||
"status": "failed",
|
||||
"lastActivity": workflow["lastActivity"],
|
||||
"dataStats": workflow["dataStats"]
|
||||
self.service.base.updateWorkflow(workflow.id, {
|
||||
"state": "failed",
|
||||
"lastActivity": workflow.lastActivity,
|
||||
"dataStats": workflow.dataStats
|
||||
})
|
||||
|
||||
self.logAdd(workflow, f"Workflow failed: {str(e)}", level="error", progress=100)
|
||||
return workflow
|
||||
|
||||
def workflowInit(self, workflowId: Optional[str] = None) -> Dict[str, Any]:
|
||||
def workflowInit(self, workflowId: Optional[str] = None) -> ChatWorkflow:
|
||||
"""
|
||||
Initializes a workflow or loads an existing one with round counting (State 1: Workflow Initialization).
|
||||
|
||||
|
|
@ -296,7 +272,7 @@ class WorkflowManager:
|
|||
workflowId: Optional - ID of the workflow to load
|
||||
|
||||
Returns:
|
||||
Initialized workflow object
|
||||
Initialized ChatWorkflow object
|
||||
"""
|
||||
currentTime = datetime.now().isoformat()
|
||||
|
||||
|
|
@ -304,39 +280,28 @@ class WorkflowManager:
|
|||
if workflowId is None or not workflowExist:
|
||||
# Create new workflow
|
||||
newWorkflowId = str(uuid.uuid4()) if workflowId is None else workflowId
|
||||
workflow = {
|
||||
"id": newWorkflowId,
|
||||
"_mandateId": self._mandateId,
|
||||
"_userId": self._userId,
|
||||
"name": f"Workflow {newWorkflowId[:8]}",
|
||||
"startedAt": currentTime,
|
||||
"messages": [], # Empty list - will be filled with references
|
||||
"messageIds": [], # Initialize empty messageIds list
|
||||
"logs": [],
|
||||
"dataStats": {
|
||||
workflow = ChatWorkflow(
|
||||
id=newWorkflowId,
|
||||
_mandateId=self._mandateId,
|
||||
_userId=self._userId,
|
||||
name=f"Workflow {newWorkflowId[:8]}",
|
||||
startedAt=currentTime,
|
||||
messages=[], # Empty list - will be filled with references
|
||||
messageIds=[], # Initialize empty messageIds list
|
||||
logs=[],
|
||||
dataStats={
|
||||
"bytesSent": 0,
|
||||
"bytesReceived": 0,
|
||||
"tokensUsed": 0,
|
||||
"processingTime": 0.0
|
||||
},
|
||||
"currentRound": 1,
|
||||
"status": "running",
|
||||
"lastActivity": currentTime,
|
||||
}
|
||||
currentRound=1,
|
||||
state="running",
|
||||
lastActivity=currentTime,
|
||||
)
|
||||
|
||||
# Save to database - only the workflow metadata
|
||||
workflowDb = {
|
||||
"id": workflow["id"],
|
||||
"_mandateId": workflow["_mandateId"],
|
||||
"_userId": workflow["_userId"],
|
||||
"name": workflow["name"],
|
||||
"startedAt": workflow["startedAt"],
|
||||
"status": workflow["status"],
|
||||
"dataStats": workflow["dataStats"],
|
||||
"currentRound": workflow["currentRound"],
|
||||
"lastActivity": workflow["lastActivity"],
|
||||
"messageIds": workflow["messageIds"] # Include messageIds
|
||||
}
|
||||
workflowDb = workflow.model_dump()
|
||||
self.service.base.createWorkflow(workflowDb)
|
||||
|
||||
self.logAdd(workflow, GLOBAL_WORKFLOW_LABELS["workflowStatusMessages"]["init"], level="info", progress=0)
|
||||
|
|
@ -345,43 +310,41 @@ class WorkflowManager:
|
|||
else:
|
||||
# State 10: Workflow Resumption - Load existing workflow
|
||||
workflow = self.service.base.loadWorkflowState(workflowId)
|
||||
workflow = ChatWorkflow(**workflow)
|
||||
|
||||
# Ensure messageIds exists
|
||||
if "messageIds" not in workflow:
|
||||
if not workflow.messageIds:
|
||||
# Initialize from existing messages
|
||||
workflow["messageIds"] = [msg["id"] for msg in workflow.get("messages", [])]
|
||||
workflow.messageIds = [msg["id"] for msg in workflow.messages]
|
||||
|
||||
# Update in database
|
||||
self.service.base.updateWorkflow(workflowId, {"messageIds": workflow["messageIds"]})
|
||||
self.service.base.updateWorkflow(workflowId, {"messageIds": workflow.messageIds})
|
||||
|
||||
# Update status and increment round counter
|
||||
workflow["status"] = "running"
|
||||
workflow["lastActivity"] = currentTime
|
||||
workflow.state = "running"
|
||||
workflow.lastActivity = currentTime
|
||||
|
||||
# Increment currentRound if it exists, otherwise set it to 1
|
||||
if "currentRound" in workflow:
|
||||
workflow["currentRound"] += 1
|
||||
else:
|
||||
workflow["currentRound"] = 1
|
||||
workflow.currentRound = (workflow.currentRound or 0) + 1
|
||||
|
||||
# Ensure dataStats exists with correct field names
|
||||
if "dataStats" not in workflow:
|
||||
workflow["dataStats"] = {
|
||||
if not workflow.dataStats:
|
||||
workflow.dataStats = {
|
||||
"bytesSent": 0,
|
||||
"bytesReceived": 0,
|
||||
"tokensUsed": 0,
|
||||
"processingTime": 0.0
|
||||
}
|
||||
elif "tokenCount" in workflow["dataStats"]:
|
||||
elif "tokenCount" in workflow.dataStats:
|
||||
# Convert old tokenCount to tokensUsed if needed
|
||||
workflow["dataStats"]["tokensUsed"] = workflow["dataStats"].pop("tokenCount", 0)
|
||||
workflow.dataStats["tokensUsed"] = workflow.dataStats.pop("tokenCount", 0)
|
||||
|
||||
# Update in database - only the relevant workflow fields
|
||||
workflowUpdate = {
|
||||
"status": workflow["status"],
|
||||
"lastActivity": workflow["lastActivity"],
|
||||
"currentRound": workflow["currentRound"],
|
||||
"dataStats": workflow["dataStats"] # Include updated dataStats
|
||||
"state": workflow.state,
|
||||
"lastActivity": workflow.lastActivity,
|
||||
"currentRound": workflow.currentRound,
|
||||
"dataStats": workflow.dataStats # Include updated dataStats
|
||||
}
|
||||
self.service.base.updateWorkflow(workflowId, workflowUpdate)
|
||||
|
||||
|
|
@ -1128,36 +1091,32 @@ filesDelivered = {self.parseJson2text(matchingDocuments)}
|
|||
logger.error(f"Error processing content: {str(e)}")
|
||||
return f"Error processing content: {str(e)}"
|
||||
|
||||
def messageAdd(self, workflow: Dict[str, Any], message: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def messageAdd(self, workflow: ChatWorkflow, message: Dict[str, Any]) -> ChatMessage:
|
||||
"""
|
||||
Adds a message to the workflow and updates lastActivity.
|
||||
Saves the message in the database and updates the workflow with references.
|
||||
Also updates statistics for the message.
|
||||
|
||||
Args:
|
||||
workflow: Workflow object
|
||||
message: Message to be saved
|
||||
workflow: ChatWorkflow object
|
||||
message: Message data to be saved
|
||||
|
||||
Returns:
|
||||
Added message
|
||||
Added ChatMessage object
|
||||
"""
|
||||
currentTime = datetime.now().isoformat()
|
||||
|
||||
# Ensure messages list exists
|
||||
if "messages" not in workflow:
|
||||
workflow["messages"] = []
|
||||
|
||||
# Generate new message ID if not present
|
||||
if "id" not in message:
|
||||
message["id"] = f"msg_{str(uuid.uuid4())}"
|
||||
|
||||
# Add workflow ID and timestamps
|
||||
message["workflowId"] = workflow["id"]
|
||||
message["workflowId"] = workflow.id
|
||||
message["startedAt"] = currentTime
|
||||
message["finishedAt"] = currentTime
|
||||
|
||||
# Set sequence number
|
||||
message["sequenceNo"] = len(workflow["messages"]) + 1
|
||||
message["sequenceNo"] = len(workflow.messages) + 1
|
||||
|
||||
# Ensure required fields are present
|
||||
if "role" not in message:
|
||||
|
|
@ -1184,8 +1143,8 @@ filesDelivered = {self.parseJson2text(matchingDocuments)}
|
|||
tokensUsed = bytesSent
|
||||
|
||||
# Update workflow statistics
|
||||
if "dataStats" not in workflow:
|
||||
workflow["dataStats"] = {
|
||||
if not workflow.dataStats:
|
||||
workflow.dataStats = {
|
||||
"bytesSent": 0,
|
||||
"bytesReceived": 0,
|
||||
"tokensUsed": 0,
|
||||
|
|
@ -1194,37 +1153,40 @@ filesDelivered = {self.parseJson2text(matchingDocuments)}
|
|||
|
||||
# Update statistics based on message role
|
||||
if message["role"] == "user":
|
||||
workflow["dataStats"]["bytesSent"] += bytesSent
|
||||
workflow["dataStats"]["tokensUsed"] += tokensUsed
|
||||
workflow.dataStats["bytesSent"] += bytesSent
|
||||
workflow.dataStats["tokensUsed"] += tokensUsed
|
||||
else: # assistant messages
|
||||
workflow["dataStats"]["bytesReceived"] += bytesSent
|
||||
workflow["dataStats"]["tokensUsed"] += tokensUsed
|
||||
workflow.dataStats["bytesReceived"] += bytesSent
|
||||
workflow.dataStats["tokensUsed"] += tokensUsed
|
||||
|
||||
# Create ChatMessage object
|
||||
chatMessage = ChatMessage(**message)
|
||||
|
||||
# Add message to workflow
|
||||
workflow["messages"].append(message)
|
||||
workflow.messages.append(chatMessage)
|
||||
|
||||
# Ensure messageIds list exists
|
||||
if "messageIds" not in workflow:
|
||||
workflow["messageIds"] = []
|
||||
if not workflow.messageIds:
|
||||
workflow.messageIds = []
|
||||
|
||||
# Add message ID to the messageIds list
|
||||
workflow["messageIds"].append(message["id"])
|
||||
workflow.messageIds.append(chatMessage.id)
|
||||
|
||||
# Update workflow status
|
||||
workflow["lastActivity"] = currentTime
|
||||
workflow.lastActivity = currentTime
|
||||
|
||||
# Save to database - first the message itself
|
||||
self.service.base.createWorkflowMessage(message)
|
||||
self.service.base.createWorkflowMessage(chatMessage.model_dump())
|
||||
|
||||
# Then save the workflow with updated references and statistics
|
||||
workflowUpdate = {
|
||||
"lastActivity": currentTime,
|
||||
"messageIds": workflow["messageIds"],
|
||||
"dataStats": workflow["dataStats"] # Include updated statistics
|
||||
"messageIds": workflow.messageIds,
|
||||
"dataStats": workflow.dataStats # Include updated statistics
|
||||
}
|
||||
self.service.base.updateWorkflow(workflow["id"], workflowUpdate)
|
||||
self.service.base.updateWorkflow(workflow.id, workflowUpdate)
|
||||
|
||||
return message
|
||||
return chatMessage
|
||||
|
||||
def _trimDataInJson(self, jsonObj: Any) -> Any:
|
||||
"""
|
||||
|
|
@ -1249,14 +1211,14 @@ filesDelivered = {self.parseJson2text(matchingDocuments)}
|
|||
return result
|
||||
return jsonObj
|
||||
|
||||
def logAdd(self, workflow: Dict[str, Any], message: str, level: str = "info",
|
||||
def logAdd(self, workflow: ChatWorkflow, message: str, level: str = "info",
|
||||
progress: Optional[int] = None) -> str:
|
||||
"""
|
||||
Adds a log entry to the workflow and also logs it in the logger.
|
||||
Enhanced with standardized formatting and workflow status tracking.
|
||||
|
||||
Args:
|
||||
workflow: Workflow object
|
||||
workflow: ChatWorkflow object
|
||||
message: Log message
|
||||
level: Log level (info, warning, error)
|
||||
progress: Optional - Progress value (0-100)
|
||||
|
|
@ -1264,15 +1226,11 @@ filesDelivered = {self.parseJson2text(matchingDocuments)}
|
|||
Returns:
|
||||
ID of the created log entry
|
||||
"""
|
||||
# Ensure logs list exists
|
||||
if "logs" not in workflow:
|
||||
workflow["logs"] = []
|
||||
|
||||
# Generate log ID
|
||||
logId = f"log_{str(uuid.uuid4())}"
|
||||
|
||||
# Get workflow status
|
||||
workflowStatus = workflow.get("status", "running")
|
||||
workflowStatus = workflow.state
|
||||
|
||||
# Set agentName from global settings
|
||||
agentName = GLOBAL_WORKFLOW_LABELS.get("systemName", "unknown")
|
||||
|
|
@ -1291,33 +1249,33 @@ filesDelivered = {self.parseJson2text(matchingDocuments)}
|
|||
pass
|
||||
|
||||
# Create log entry
|
||||
logEntry = {
|
||||
"id": logId,
|
||||
"workflowId": workflow["id"],
|
||||
"message": processedMessage,
|
||||
"type": level,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"agentName": agentName,
|
||||
"status": workflowStatus
|
||||
}
|
||||
logEntry = ChatLog(
|
||||
id=logId,
|
||||
workflowId=workflow.id,
|
||||
message=processedMessage,
|
||||
type=level,
|
||||
timestamp=datetime.now().isoformat(),
|
||||
agentName=agentName,
|
||||
status=workflowStatus
|
||||
)
|
||||
|
||||
# Add progress if provided
|
||||
if progress is not None:
|
||||
logEntry["progress"] = progress
|
||||
logEntry.progress = progress
|
||||
|
||||
# Add log to workflow
|
||||
workflow["logs"].append(logEntry)
|
||||
workflow.logs.append(logEntry)
|
||||
|
||||
# Save in database
|
||||
self.service.base.createWorkflowLog(logEntry)
|
||||
self.service.base.createWorkflowLog(logEntry.model_dump())
|
||||
|
||||
# Also log in logger
|
||||
if level == "info":
|
||||
logger.info(f"Workflow {workflow['id']}: {processedMessage}")
|
||||
logger.info(f"Workflow {workflow.id}: {processedMessage}")
|
||||
elif level == "warning":
|
||||
logger.warning(f"Workflow {workflow['id']}: {processedMessage}")
|
||||
logger.warning(f"Workflow {workflow.id}: {processedMessage}")
|
||||
elif level == "error":
|
||||
logger.error(f"Workflow {workflow['id']}: {processedMessage}")
|
||||
logger.error(f"Workflow {workflow.id}: {processedMessage}")
|
||||
|
||||
return logId
|
||||
|
||||
|
|
@ -1571,9 +1529,9 @@ filesDelivered = {self.parseJson2text(matchingDocuments)}
|
|||
_workflowManagers = {}
|
||||
_workflowManagerLastAccess = {} # Track last access time for cleanup
|
||||
|
||||
async def getWorkflowManager(interfaceBase, interfaceMsft) -> WorkflowManager:
|
||||
async def getWorkflowManager(service) -> WorkflowManager:
|
||||
"""Get or create a workflow manager instance."""
|
||||
contextKey = f"{interfaceBase._mandateId}_{interfaceBase._userId}"
|
||||
contextKey = f"{service.base._mandateId}_{service.base._userId}"
|
||||
|
||||
# Check if we have a cached instance
|
||||
if contextKey in _workflowManagers:
|
||||
|
|
@ -1581,7 +1539,7 @@ async def getWorkflowManager(interfaceBase, interfaceMsft) -> WorkflowManager:
|
|||
return _workflowManagers[contextKey]
|
||||
|
||||
# Create new instance
|
||||
manager = WorkflowManager(interfaceBase, interfaceMsft)
|
||||
manager = WorkflowManager(service)
|
||||
|
||||
# Cache the instance
|
||||
_workflowManagers[contextKey] = manager
|
||||
|
|
|
|||
|
|
@ -1,40 +1,9 @@
|
|||
....................... TASKS
|
||||
|
||||
|
||||
TODO: Frontend to adapt
|
||||
|
||||
#####################
|
||||
|
||||
CROSS-CHECK Wrkflow set
|
||||
|
||||
|
||||
ERROR --- > when user logs in with "local" managed account and then logs in to msft account with "msft" authority, the userid is switched to the microsoft instance.
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
TODO: routeGeneral: To add User Model for user creation - or to pass to interface. to check!
|
||||
|
||||
TODO: All routes not to use "*interface.py" modules for checks or data handling. Full data handling, access control, uam to be in the "*Interface.py" modules. This to adapt.
|
||||
|
||||
TODO: Assign model classes for "create" and "update" functions. not to pass specific attributes to functions or routes
|
||||
|
||||
TODO: Interface assignment overall to adapt
|
||||
|
||||
TODO: Implement userid,mandateid change overall
|
||||
|
||||
|
||||
TODO: Workflow-sub modules and agents to include Models and adaptions
|
||||
|
||||
###################
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
! function callAI() to ask with userPrompt,systemPrompt optional), not with json
|
||||
! in the taskplan to refer files always in context of user/mandate
|
||||
! userinput to handle with object AgentQuery --> when received in frontend to enhance for full object
|
||||
|
|
@ -94,6 +63,36 @@ Tools to transfer incl funds:
|
|||
|
||||
----------------------- DONE
|
||||
|
||||
|
||||
We have to correct the following wrong user access management.
|
||||
|
||||
Issue is: when user logs in with "local" managed account and then logs in to msft account with "msft" authority, the userid is switched to the microsoft instance in the workflow. this must not happen.
|
||||
Objective: The correct logic is, that a user logs in with an account (managed by "local" or other authority). Once logged in, his login does not change, also if he connects to microsoft account afterwards.
|
||||
|
||||
Problem: We have a mix between user-login (creating currentUser profile) and user-connections (attaching user to a service, like "msft" - and future other services in parallel).
|
||||
|
||||
Concept: We need to separate user-login and user-connections:
|
||||
1. the ui login and register modules produce a user-login, resulting in a currentUser profile in the backend to be used for workflow and other activities. the user gets a token (from "local" or "msft" or furthers). this token has to be checked when user logs in. ALWAYS a check is required by the according registration authority.
|
||||
those use cases:
|
||||
- if user registers with a "local" profile, a new user is created, a local token is produced
|
||||
- if user logs in with a "local" profile for an existing user, a local token is produced
|
||||
- if user logs in with a "local" profile for a non-existing user, login is denied (no user)
|
||||
- if user logs in with a "msft" profile (or other foreign profile) for an existing user, a local token AND a token in "msft" database (or other foreign system) is produced
|
||||
- if user logs in with a "msft" profile (or other foreign profile) for a non-existing user, a local profile is generated based on information from foreign account, then a local token AND a token in "msft" database (or other foreign system) is produced
|
||||
|
||||
2. the ui navigation buttons for "Login MSFT" or future other buttons to connect to services (like e.g. google account or github account or microsoft "msft" account, etc.) does NOT generate a user-login, only a user-connection to a service.
|
||||
|
||||
soloution:
|
||||
So there must be a mechanism, which manages user-login and user-connection. Following proposition: User has a user profile to login and a list of profiles for user-connections.
|
||||
Examples:
|
||||
- user registers with "local" profile --> he gets local profile with 0 user-connections
|
||||
- user registers with "msft" profile --> he gets local profile with 1 user-connections to "msft". Then he connects to another "msft" profile. Now he gets local profile 2 user-connections "msft"
|
||||
- user registers with "google" profile (future) --> he gets local profile and 1 user-connections to "google". Then he connects to another "msft" profile. Now he gets local profile and 1 user-connections "msft" and 1 user-connection "google".
|
||||
|
||||
can you tell me, how you would implement this adapted model into the pydantic model and into the code modules in a structured and maintainable way?
|
||||
|
||||
|
||||
|
||||
i want to refactor the user management in the backend through the user journey. currrently we have two problems: we always pass _userid and _mandate or id with _mandate from function to function, which blocks scaling. this is too complicated and non-logic.
|
||||
|
||||
to adapt the following:
|
||||
|
|
|
|||
|
|
@ -9,6 +9,8 @@ pydantic==1.10.13 # Ältere Version ohne Rust-Abhängigkeit
|
|||
python-jose==3.3.0
|
||||
passlib==1.7.4
|
||||
argon2-cffi>=21.3.0 # Für Passwort-Hashing in gateway_interface.py
|
||||
google-auth-oauthlib==1.2.0 # Für Google OAuth
|
||||
google-auth==2.27.0 # Für Google Authentication
|
||||
|
||||
## Database
|
||||
mysql-connector-python==8.1.0
|
||||
|
|
|
|||
Loading…
Reference in a new issue