logical fixes
This commit is contained in:
parent
7532841d9d
commit
d98c31a4d1
29 changed files with 588 additions and 449 deletions
4
app.py
4
app.py
|
|
@ -315,7 +315,7 @@ async def lifespan(app: FastAPI):
|
|||
logger.warning(f"Could not initialize feature containers: {e}")
|
||||
|
||||
# --- Init Managers ---
|
||||
await subAutomationSchedule.start(eventUser) # Automation scheduler
|
||||
subAutomationSchedule.start(eventUser) # Automation scheduler
|
||||
eventManager.start()
|
||||
|
||||
# Register audit log cleanup scheduler
|
||||
|
|
@ -345,7 +345,7 @@ async def lifespan(app: FastAPI):
|
|||
|
||||
# --- Stop Managers ---
|
||||
eventManager.stop()
|
||||
await subAutomationSchedule.stop(eventUser) # Automation scheduler
|
||||
subAutomationSchedule.stop(eventUser) # Automation scheduler
|
||||
|
||||
# --- Stop Feature Containers (Plug&Play) ---
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
"""File-related datamodels: FileItem, FilePreview, FileData."""
|
||||
|
||||
from typing import Dict, Any, Optional, Union
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from modules.shared.attributeUtils import registerModelLabels
|
||||
from modules.shared.timeUtils import getUtcTimestamp
|
||||
import uuid
|
||||
|
|
@ -11,6 +11,7 @@ import base64
|
|||
|
||||
|
||||
class FileItem(BaseModel):
|
||||
model_config = ConfigDict(extra='allow') # Preserve system fields (_createdBy, _createdAt, etc.)
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
|
||||
mandateId: Optional[str] = Field(default="", description="ID of the mandate this file belongs to", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
|
||||
featureInstanceId: Optional[str] = Field(default="", description="ID of the feature instance this file belongs to", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
|
||||
|
|
|
|||
|
|
@ -3,22 +3,33 @@
|
|||
"""Utility datamodels: Prompt, TextMultilingual."""
|
||||
|
||||
from typing import Dict, Optional
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
from modules.shared.attributeUtils import registerModelLabels
|
||||
import uuid
|
||||
|
||||
|
||||
class Prompt(BaseModel):
|
||||
model_config = ConfigDict(extra='allow') # Preserve system fields (_createdBy, _createdAt, etc.)
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
|
||||
mandateId: str = Field(default="", description="ID of the mandate this prompt belongs to", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
|
||||
isSystem: bool = Field(default=False, description="System prompt visible to all users (read-only for non-SysAdmin)", json_schema_extra={"frontend_type": "boolean", "frontend_readonly": True, "frontend_required": False})
|
||||
content: str = Field(description="Content of the prompt", json_schema_extra={"frontend_type": "textarea", "frontend_readonly": False, "frontend_required": True})
|
||||
name: str = Field(description="Name of the prompt", json_schema_extra={"frontend_type": "text", "frontend_readonly": False, "frontend_required": True})
|
||||
|
||||
@field_validator('isSystem', mode='before')
|
||||
@classmethod
|
||||
def _coerceIsSystem(cls, v):
|
||||
"""Existing records may have isSystem=None (field didn't exist). Treat None as False."""
|
||||
if v is None:
|
||||
return False
|
||||
return v
|
||||
registerModelLabels(
|
||||
"Prompt",
|
||||
{"en": "Prompt", "fr": "Invite"},
|
||||
{
|
||||
"id": {"en": "ID", "fr": "ID"},
|
||||
"mandateId": {"en": "Mandate ID", "fr": "ID du mandat"},
|
||||
"isSystem": {"en": "System", "fr": "Système"},
|
||||
"content": {"en": "Content", "fr": "Contenu"},
|
||||
"name": {"en": "Name", "fr": "Nom"},
|
||||
},
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ Uses the PostgreSQL connector for data access with user/mandate filtering.
|
|||
import logging
|
||||
import uuid
|
||||
import math
|
||||
import asyncio
|
||||
from typing import Dict, Any, List, Optional, Union
|
||||
|
||||
from modules.security.rbac import RbacClass
|
||||
|
|
@ -99,7 +98,7 @@ class AutomationObjects:
|
|||
return True
|
||||
elif accessLevel == AccessLevel.MY:
|
||||
if recordId:
|
||||
record = self.db.getRecordset(model, {"id": recordId})
|
||||
record = self.db.getRecordset(model, recordFilter={"id": recordId})
|
||||
if record:
|
||||
return record[0].get("_createdBy") == self.userId
|
||||
else:
|
||||
|
|
@ -118,16 +117,17 @@ class AutomationObjects:
|
|||
|
||||
def _enrichAutomationsWithUserAndMandate(self, automations: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Batch enrich automations with user names and mandate names for display.
|
||||
Batch enrich automations with user names, mandate names and feature instance labels.
|
||||
Uses direct DB lookup (no RBAC) because this is purely cosmetic enrichment —
|
||||
the user already has RBAC-verified access to the automations themselves.
|
||||
"""
|
||||
if not automations:
|
||||
return automations
|
||||
|
||||
# Collect all unique user IDs and mandate IDs
|
||||
# Collect all unique IDs
|
||||
userIds = set()
|
||||
mandateIds = set()
|
||||
featureInstanceIds = set()
|
||||
|
||||
for automation in automations:
|
||||
createdBy = automation.get("_createdBy")
|
||||
|
|
@ -138,47 +138,62 @@ class AutomationObjects:
|
|||
if mandateId:
|
||||
mandateIds.add(mandateId)
|
||||
|
||||
featureInstanceId = automation.get("featureInstanceId")
|
||||
if featureInstanceId:
|
||||
featureInstanceIds.add(featureInstanceId)
|
||||
|
||||
# Use root DB connector for display-only lookups (no RBAC needed)
|
||||
usersMap = {}
|
||||
mandatesMap = {}
|
||||
featureInstancesMap = {}
|
||||
try:
|
||||
from modules.datamodels.datamodelUam import UserInDB, Mandate
|
||||
from modules.datamodels.datamodelFeatures import FeatureInstance
|
||||
from modules.security.rootAccess import getRootDbAppConnector
|
||||
dbAppConn = getRootDbAppConnector()
|
||||
|
||||
# Batch fetch user display names
|
||||
usersMap = {}
|
||||
if userIds:
|
||||
for userId in userIds:
|
||||
users = dbAppConn.getRecordset(UserInDB, {"id": userId})
|
||||
users = dbAppConn.getRecordset(UserInDB, recordFilter={"id": userId})
|
||||
if users:
|
||||
user = users[0]
|
||||
fullName = f"{user.get('firstName', '')} {user.get('lastName', '')}".strip()
|
||||
usersMap[userId] = fullName or user.get("email") or user.get("username") or userId
|
||||
displayName = user.get("fullName") or user.get("username") or user.get("email") or None
|
||||
if displayName:
|
||||
usersMap[userId] = displayName
|
||||
|
||||
# Batch fetch mandate display names
|
||||
mandatesMap = {}
|
||||
if mandateIds:
|
||||
for mandateId in mandateIds:
|
||||
mandates = dbAppConn.getRecordset(Mandate, {"id": mandateId})
|
||||
mandates = dbAppConn.getRecordset(Mandate, recordFilter={"id": mandateId})
|
||||
if mandates:
|
||||
mandatesMap[mandateId] = mandates[0].get("name") or mandateId
|
||||
label = mandates[0].get("label") or mandates[0].get("name") or None
|
||||
if label:
|
||||
mandatesMap[mandateId] = label
|
||||
|
||||
# Batch fetch feature instance labels
|
||||
if featureInstanceIds:
|
||||
for fiId in featureInstanceIds:
|
||||
instances = dbAppConn.getRecordset(FeatureInstance, recordFilter={"id": fiId})
|
||||
if instances:
|
||||
fi = instances[0]
|
||||
label = fi.get("label") or fi.get("featureCode") or None
|
||||
if label:
|
||||
featureInstancesMap[fiId] = label
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not enrich automations with user/mandate names: {e}")
|
||||
usersMap = {}
|
||||
mandatesMap = {}
|
||||
logger.warning(f"Could not enrich automations with display names: {e}")
|
||||
|
||||
# Enrich each automation with the fetched data
|
||||
# SECURITY: Never show a fallback name — if lookup fails, show empty string
|
||||
for automation in automations:
|
||||
createdBy = automation.get("_createdBy")
|
||||
if createdBy:
|
||||
automation["_createdByUserName"] = usersMap.get(createdBy, createdBy)
|
||||
else:
|
||||
automation["_createdByUserName"] = "-"
|
||||
automation["_createdByUserName"] = usersMap.get(createdBy, "") if createdBy else ""
|
||||
|
||||
mandateId = automation.get("mandateId")
|
||||
if mandateId:
|
||||
automation["mandateName"] = mandatesMap.get(mandateId, mandateId)
|
||||
else:
|
||||
automation["mandateName"] = "-"
|
||||
automation["mandateName"] = mandatesMap.get(mandateId, "") if mandateId else ""
|
||||
|
||||
featureInstanceId = automation.get("featureInstanceId")
|
||||
automation["featureInstanceName"] = featureInstancesMap.get(featureInstanceId, "") if featureInstanceId else ""
|
||||
|
||||
return automations
|
||||
|
||||
|
|
@ -195,11 +210,13 @@ class AutomationObjects:
|
|||
Supports optional pagination, sorting, and filtering.
|
||||
Computes status field for each automation.
|
||||
"""
|
||||
# Use RBAC filtering
|
||||
# AutomationDefinitions can belong to any feature instance within a mandate.
|
||||
# Filter by mandateId only — not by featureInstanceId — to show all definitions across features.
|
||||
filteredAutomations = getRecordsetWithRBAC(
|
||||
self.db,
|
||||
AutomationDefinition,
|
||||
self.currentUser
|
||||
self.currentUser,
|
||||
mandateId=self.mandateId
|
||||
)
|
||||
|
||||
# Compute status for each automation and normalize executionLogs
|
||||
|
|
@ -282,12 +299,14 @@ class AutomationObjects:
|
|||
If False (default), returns Pydantic model without system fields.
|
||||
"""
|
||||
try:
|
||||
# Use RBAC filtering
|
||||
# AutomationDefinitions can belong to any feature instance within a mandate.
|
||||
# Filter by mandateId only — not by featureInstanceId.
|
||||
filtered = getRecordsetWithRBAC(
|
||||
self.db,
|
||||
AutomationDefinition,
|
||||
self.currentUser,
|
||||
recordFilter={"id": automationId}
|
||||
recordFilter={"id": automationId},
|
||||
mandateId=self.mandateId
|
||||
)
|
||||
|
||||
if not filtered:
|
||||
|
|
@ -363,8 +382,8 @@ class AutomationObjects:
|
|||
if createdAutomation.get("executionLogs") is None:
|
||||
createdAutomation["executionLogs"] = []
|
||||
|
||||
# Trigger automation change callback (async, don't wait)
|
||||
asyncio.create_task(self._notifyAutomationChanged())
|
||||
# Trigger automation change callback
|
||||
self._notifyAutomationChanged()
|
||||
|
||||
# Clean metadata fields and return Pydantic model
|
||||
cleanedRecord = {k: v for k, v in createdAutomation.items() if not k.startswith("_")}
|
||||
|
|
@ -408,8 +427,8 @@ class AutomationObjects:
|
|||
if updatedAutomation.get("executionLogs") is None:
|
||||
updatedAutomation["executionLogs"] = []
|
||||
|
||||
# Trigger automation change callback (async, don't wait)
|
||||
asyncio.create_task(self._notifyAutomationChanged())
|
||||
# Trigger automation change callback
|
||||
self._notifyAutomationChanged()
|
||||
|
||||
# Clean metadata fields and return Pydantic model
|
||||
cleanedRecord = {k: v for k, v in updatedAutomation.items() if not k.startswith("_")}
|
||||
|
|
@ -432,8 +451,8 @@ class AutomationObjects:
|
|||
# Delete automation from database
|
||||
self.db.recordDelete(AutomationDefinition, automationId)
|
||||
|
||||
# Trigger automation change callback (async, don't wait)
|
||||
asyncio.create_task(self._notifyAutomationChanged())
|
||||
# Trigger automation change callback
|
||||
self._notifyAutomationChanged()
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
|
|
@ -454,7 +473,9 @@ class AutomationObjects:
|
|||
return getRecordsetWithRBAC(
|
||||
self.db,
|
||||
AutomationDefinition,
|
||||
user
|
||||
user,
|
||||
mandateId=self.mandateId,
|
||||
featureInstanceId=self.featureInstanceId
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
|
|
@ -466,7 +487,7 @@ class AutomationObjects:
|
|||
Returns automation templates filtered by RBAC (MY = own templates).
|
||||
Supports optional pagination, sorting, and filtering.
|
||||
"""
|
||||
# Use RBAC filtering
|
||||
# Templates are global (not mandate/feature-instance scoped) — no mandateId/featureInstanceId filter
|
||||
filteredTemplates = getRecordsetWithRBAC(
|
||||
self.db,
|
||||
AutomationTemplate,
|
||||
|
|
@ -526,23 +547,24 @@ class AutomationObjects:
|
|||
|
||||
userNameMap = {}
|
||||
for userId in userIds:
|
||||
users = dbAppConn.getRecordset(UserInDB, {"id": userId})
|
||||
users = dbAppConn.getRecordset(UserInDB, recordFilter={"id": userId})
|
||||
if users:
|
||||
user = users[0]
|
||||
fullName = f"{user.get('firstName', '')} {user.get('lastName', '')}".strip()
|
||||
userNameMap[userId] = fullName or user.get("email", "Unknown")
|
||||
displayName = user.get("fullName") or user.get("username") or user.get("email") or None
|
||||
if displayName:
|
||||
userNameMap[userId] = displayName
|
||||
|
||||
# Apply to templates
|
||||
# Apply to templates — SECURITY: no fallback, empty if not found
|
||||
for template in templates:
|
||||
createdBy = template.get("_createdBy")
|
||||
if createdBy and createdBy in userNameMap:
|
||||
template["_createdByUserName"] = userNameMap[createdBy]
|
||||
template["_createdByUserName"] = userNameMap.get(createdBy, "") if createdBy else ""
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not enrich templates with user names: {e}")
|
||||
|
||||
def getAutomationTemplate(self, templateId: str) -> Optional[Dict[str, Any]]:
|
||||
"""Returns an automation template by ID if user has access."""
|
||||
try:
|
||||
# Templates are global — no mandateId/featureInstanceId filter
|
||||
filtered = getRecordsetWithRBAC(
|
||||
self.db,
|
||||
AutomationTemplate,
|
||||
|
|
@ -645,12 +667,13 @@ class AutomationObjects:
|
|||
logger.error(f"Error deleting automation template: {str(e)}")
|
||||
raise
|
||||
|
||||
async def _notifyAutomationChanged(self):
|
||||
"""Notify registered callbacks about automation changes (decoupled from features)."""
|
||||
def _notifyAutomationChanged(self):
|
||||
"""Notify registered callbacks about automation changes (decoupled from features).
|
||||
Sync-safe: works from both sync and async contexts."""
|
||||
try:
|
||||
from modules.shared.callbackRegistry import callbackRegistry
|
||||
# Trigger callbacks without knowing which features are listening
|
||||
await callbackRegistry.trigger('automation.changed', self)
|
||||
callbackRegistry.trigger('automation.changed', self)
|
||||
except Exception as e:
|
||||
logger.error(f"Error notifying automation change: {str(e)}")
|
||||
|
||||
|
|
|
|||
|
|
@ -66,7 +66,9 @@ def get_automations(
|
|||
detail=f"Invalid pagination parameter: {str(e)}"
|
||||
)
|
||||
|
||||
chatInterface = getAutomationInterface(context.user, mandateId=str(context.mandateId) if context.mandateId else None, featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None)
|
||||
# AutomationDefinitions can belong to ANY feature instance within a mandate.
|
||||
# The list endpoint must show all definitions for the user's mandate, not filter by a specific featureInstanceId.
|
||||
chatInterface = getAutomationInterface(context.user, mandateId=str(context.mandateId) if context.mandateId else None)
|
||||
result = chatInterface.getAllAutomationDefinitions(pagination=paginationParams)
|
||||
|
||||
# If pagination was requested, result is PaginatedResult
|
||||
|
|
@ -150,7 +152,7 @@ def get_available_actions(
|
|||
# Ensure methods are discovered (need a service center for discovery)
|
||||
if not methods:
|
||||
# Create a lightweight service center for method discovery
|
||||
services = getServices(context.user, context.mandateId)
|
||||
services = getServices(context.user, mandateId=context.mandateId)
|
||||
discoverMethods(services)
|
||||
|
||||
actionsList = []
|
||||
|
|
@ -235,7 +237,7 @@ def get_automation(
|
|||
) -> AutomationDefinition:
|
||||
"""Get a single automation definition by ID"""
|
||||
try:
|
||||
chatInterface = getAutomationInterface(context.user, mandateId=str(context.mandateId) if context.mandateId else None, featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None)
|
||||
chatInterface = getAutomationInterface(context.user, mandateId=str(context.mandateId) if context.mandateId else None)
|
||||
automation = chatInterface.getAutomationDefinition(automationId)
|
||||
if not automation:
|
||||
raise HTTPException(
|
||||
|
|
@ -263,7 +265,7 @@ def update_automation(
|
|||
) -> AutomationDefinition:
|
||||
"""Update an automation definition"""
|
||||
try:
|
||||
chatInterface = getAutomationInterface(context.user, mandateId=str(context.mandateId) if context.mandateId else None, featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None)
|
||||
chatInterface = getAutomationInterface(context.user, mandateId=str(context.mandateId) if context.mandateId else None)
|
||||
automationData = automation.model_dump()
|
||||
updated = chatInterface.updateAutomationDefinition(automationId, automationData)
|
||||
return updated
|
||||
|
|
@ -291,7 +293,7 @@ def update_automation_status(
|
|||
) -> AutomationDefinition:
|
||||
"""Update only the active status of an automation definition"""
|
||||
try:
|
||||
chatInterface = getAutomationInterface(context.user, mandateId=str(context.mandateId) if context.mandateId else None, featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None)
|
||||
chatInterface = getAutomationInterface(context.user, mandateId=str(context.mandateId) if context.mandateId else None)
|
||||
|
||||
# Get existing automation
|
||||
automation = chatInterface.getAutomationDefinition(automationId)
|
||||
|
|
@ -331,7 +333,7 @@ def delete_automation(
|
|||
) -> Response:
|
||||
"""Delete an automation definition"""
|
||||
try:
|
||||
chatInterface = getAutomationInterface(context.user, mandateId=str(context.mandateId) if context.mandateId else None, featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None)
|
||||
chatInterface = getAutomationInterface(context.user, mandateId=str(context.mandateId) if context.mandateId else None)
|
||||
success = chatInterface.deleteAutomationDefinition(automationId)
|
||||
if success:
|
||||
return Response(status_code=204)
|
||||
|
|
@ -364,13 +366,15 @@ async def execute_automation_route(
|
|||
"""Execute an automation immediately (test mode)"""
|
||||
try:
|
||||
from modules.services import getInterface as getServices
|
||||
services = getServices(context.user, context.mandateId)
|
||||
# Propagate feature context for billing
|
||||
if context.featureInstanceId:
|
||||
services.featureInstanceId = str(context.featureInstanceId)
|
||||
services.featureCode = 'automation'
|
||||
services = getServices(context.user, mandateId=context.mandateId, featureInstanceId=context.featureInstanceId)
|
||||
|
||||
# Load automation with current user's context (user has RBAC permissions via UI)
|
||||
automation = services.interfaceDbAutomation.getAutomationDefinition(automationId, includeSystemFields=True)
|
||||
if not automation:
|
||||
raise ValueError(f"Automation {automationId} not found")
|
||||
|
||||
from modules.workflows.automation import executeAutomation
|
||||
workflow = await executeAutomation(automationId, services)
|
||||
workflow = await executeAutomation(automationId, automation, context.user, services)
|
||||
return workflow
|
||||
except HTTPException:
|
||||
raise
|
||||
|
|
|
|||
|
|
@ -360,10 +360,12 @@ class ChatObjects:
|
|||
return False
|
||||
|
||||
tableName = modelClass.__name__
|
||||
from modules.interfaces.interfaceRbac import buildDataObjectKey
|
||||
objectKey = buildDataObjectKey(tableName, featureCode=self.featureCode if hasattr(self, 'featureCode') else None)
|
||||
permissions = self.rbac.getUserPermissions(
|
||||
self.currentUser,
|
||||
AccessRuleContext.DATA,
|
||||
tableName,
|
||||
objectKey,
|
||||
mandateId=self.mandateId,
|
||||
featureInstanceId=self.featureInstanceId
|
||||
)
|
||||
|
|
|
|||
|
|
@ -91,12 +91,8 @@ async def chatProcess(
|
|||
ChatWorkflow instance
|
||||
"""
|
||||
try:
|
||||
# Get services with mandate context
|
||||
services = getServices(currentUser, mandateId)
|
||||
|
||||
# Set feature context for billing
|
||||
if featureInstanceId:
|
||||
services.featureInstanceId = featureInstanceId
|
||||
# Get services with mandate and feature instance context
|
||||
services = getServices(currentUser, mandateId=mandateId, featureInstanceId=featureInstanceId)
|
||||
services.featureCode = 'chatbot'
|
||||
|
||||
interfaceDbChat = services.interfaceDbChat
|
||||
|
|
|
|||
|
|
@ -49,10 +49,10 @@ RESOURCE_OBJECTS = [
|
|||
]
|
||||
|
||||
# Template roles for this feature
|
||||
# IMPORTANT: "viewer" role is required for automatic user assignment!
|
||||
# Role names MUST follow convention: {featureCode}-{roleName}
|
||||
TEMPLATE_ROLES = [
|
||||
{
|
||||
"roleLabel": "viewer",
|
||||
"roleLabel": "chatplayground-viewer",
|
||||
"description": {
|
||||
"en": "Chat Playground Viewer - View chat playground (read-only)",
|
||||
"de": "Chat Playground Betrachter - Chat Playground ansehen (nur lesen)",
|
||||
|
|
@ -67,7 +67,7 @@ TEMPLATE_ROLES = [
|
|||
]
|
||||
},
|
||||
{
|
||||
"roleLabel": "user",
|
||||
"roleLabel": "chatplayground-user",
|
||||
"description": {
|
||||
"en": "Chat Playground User - Use chat playground and workflows",
|
||||
"de": "Chat Playground Benutzer - Chat Playground und Workflows nutzen",
|
||||
|
|
@ -86,7 +86,7 @@ TEMPLATE_ROLES = [
|
|||
]
|
||||
},
|
||||
{
|
||||
"roleLabel": "admin",
|
||||
"roleLabel": "chatplayground-admin",
|
||||
"description": {
|
||||
"en": "Chat Playground Admin - Full access to chat playground",
|
||||
"de": "Chat Playground Admin - Vollzugriff auf Chat Playground",
|
||||
|
|
|
|||
|
|
@ -749,10 +749,12 @@ class RealEstateObjects:
|
|||
return False
|
||||
|
||||
tableName = modelClass.__name__
|
||||
from modules.interfaces.interfaceRbac import buildDataObjectKey
|
||||
objectKey = buildDataObjectKey(tableName, featureCode=self.featureCode if hasattr(self, 'featureCode') else None)
|
||||
permissions = self.rbac.getUserPermissions(
|
||||
self.currentUser,
|
||||
AccessRuleContext.DATA,
|
||||
tableName,
|
||||
objectKey,
|
||||
mandateId=self.mandateId,
|
||||
featureInstanceId=self.featureInstanceId
|
||||
)
|
||||
|
|
|
|||
|
|
@ -171,10 +171,12 @@ class TrusteeObjects:
|
|||
return False
|
||||
|
||||
tableName = modelClass.__name__
|
||||
from modules.interfaces.interfaceRbac import buildDataObjectKey
|
||||
objectKey = buildDataObjectKey(tableName, featureCode=self.featureCode if hasattr(self, 'featureCode') else None)
|
||||
permissions = self.rbac.getUserPermissions(
|
||||
self.currentUser,
|
||||
AccessRuleContext.DATA,
|
||||
tableName,
|
||||
objectKey,
|
||||
mandateId=self.mandateId,
|
||||
featureInstanceId=self.featureInstanceId
|
||||
)
|
||||
|
|
@ -198,10 +200,12 @@ class TrusteeObjects:
|
|||
return AccessLevel.NONE
|
||||
|
||||
tableName = modelClass.__name__
|
||||
from modules.interfaces.interfaceRbac import buildDataObjectKey
|
||||
objectKey = buildDataObjectKey(tableName, featureCode=self.featureCode if hasattr(self, 'featureCode') else None)
|
||||
permissions = self.rbac.getUserPermissions(
|
||||
self.currentUser,
|
||||
AccessRuleContext.DATA,
|
||||
tableName,
|
||||
objectKey,
|
||||
mandateId=self.mandateId,
|
||||
featureInstanceId=self.featureInstanceId
|
||||
)
|
||||
|
|
@ -1470,7 +1474,7 @@ class TrusteeObjects:
|
|||
|
||||
def getAllUserAccess(self, userId: str) -> List[Dict[str, Any]]:
|
||||
"""Get all access records for a user across all organisations."""
|
||||
return self.db.getRecordset(TrusteeAccess, {"userId": userId})
|
||||
return self.db.getRecordset(TrusteeAccess, recordFilter={"userId": userId})
|
||||
|
||||
def getUserTrusteeRoles(self, userId: str, organisationId: str, contractId: Optional[str] = None) -> List[str]:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -129,7 +129,7 @@ def initAutomationTemplates(dbApp: DatabaseConnector, adminUserId: Optional[str]
|
|||
|
||||
# Get admin user ID if not provided (from poweron_app)
|
||||
if not adminUserId:
|
||||
adminUsers = dbApp.getRecordset(UserInDB, {"email": APP_CONFIG.ADMIN_EMAIL})
|
||||
adminUsers = dbApp.getRecordset(UserInDB, recordFilter={"email": APP_CONFIG.ADMIN_EMAIL})
|
||||
adminUserId = adminUsers[0]["id"] if adminUsers else None
|
||||
# Update context with admin user
|
||||
if adminUserId:
|
||||
|
|
|
|||
|
|
@ -245,10 +245,13 @@ class AppObjects:
|
|||
return False
|
||||
|
||||
tableName = modelClass.__name__
|
||||
# Use buildDataObjectKey for semantic namespace lookup
|
||||
from modules.interfaces.interfaceRbac import buildDataObjectKey
|
||||
objectKey = buildDataObjectKey(tableName)
|
||||
permissions = self.rbac.getUserPermissions(
|
||||
self.currentUser,
|
||||
AccessRuleContext.DATA,
|
||||
tableName,
|
||||
objectKey,
|
||||
mandateId=self.mandateId
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -339,6 +339,18 @@ class ChatObjects:
|
|||
pass
|
||||
|
||||
|
||||
def _getRecordset(self, modelClass, recordFilter=None, **kwargs):
|
||||
"""Wrapper for getRecordsetWithRBAC that automatically includes mandateId/featureInstanceId."""
|
||||
return getRecordsetWithRBAC(
|
||||
self.db,
|
||||
modelClass,
|
||||
self.currentUser,
|
||||
recordFilter=recordFilter,
|
||||
mandateId=self.mandateId,
|
||||
featureInstanceId=self.featureInstanceId,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def checkRbacPermission(
|
||||
self,
|
||||
modelClass: type,
|
||||
|
|
@ -610,12 +622,7 @@ class ChatObjects:
|
|||
If pagination is provided: PaginatedResult with items and metadata
|
||||
"""
|
||||
# Use RBAC filtering with featureInstanceId for instance-level isolation
|
||||
filteredWorkflows = getRecordsetWithRBAC(self.db,
|
||||
ChatWorkflow,
|
||||
self.currentUser,
|
||||
mandateId=self.mandateId,
|
||||
featureInstanceId=self.featureInstanceId
|
||||
)
|
||||
filteredWorkflows = self._getRecordset(ChatWorkflow)
|
||||
|
||||
# If no pagination requested, return all items (no sorting - frontend handles it)
|
||||
if pagination is None:
|
||||
|
|
@ -647,13 +654,7 @@ class ChatObjects:
|
|||
def getWorkflow(self, workflowId: str) -> Optional[ChatWorkflow]:
|
||||
"""Returns a workflow by ID if user has access."""
|
||||
# Use RBAC filtering with featureInstanceId for instance-level isolation
|
||||
workflows = getRecordsetWithRBAC(self.db,
|
||||
ChatWorkflow,
|
||||
self.currentUser,
|
||||
recordFilter={"id": workflowId},
|
||||
mandateId=self.mandateId,
|
||||
featureInstanceId=self.featureInstanceId
|
||||
)
|
||||
workflows = self._getRecordset(ChatWorkflow, recordFilter={"id": workflowId})
|
||||
|
||||
if not workflows:
|
||||
return None
|
||||
|
|
@ -809,7 +810,7 @@ class ChatObjects:
|
|||
# Delete message documents (but NOT the files!)
|
||||
# Note: ChatStat does NOT have messageId - stats are only at workflow level
|
||||
try:
|
||||
existing_docs = getRecordsetWithRBAC(self.db, ChatDocument, self.currentUser, recordFilter={"messageId": messageId})
|
||||
existing_docs = self._getRecordset(ChatDocument, recordFilter={"messageId": messageId})
|
||||
for doc in existing_docs:
|
||||
self.db.recordDelete(ChatDocument, doc["id"])
|
||||
except Exception as e:
|
||||
|
|
@ -819,12 +820,12 @@ class ChatObjects:
|
|||
self.db.recordDelete(ChatMessage, messageId)
|
||||
|
||||
# 2. Delete workflow stats
|
||||
existing_stats = getRecordsetWithRBAC(self.db, ChatStat, self.currentUser, recordFilter={"workflowId": workflowId})
|
||||
existing_stats = self._getRecordset(ChatStat, recordFilter={"workflowId": workflowId})
|
||||
for stat in existing_stats:
|
||||
self.db.recordDelete(ChatStat, stat["id"])
|
||||
|
||||
# 3. Delete workflow logs
|
||||
existing_logs = getRecordsetWithRBAC(self.db, ChatLog, self.currentUser, recordFilter={"workflowId": workflowId})
|
||||
existing_logs = self._getRecordset(ChatLog, recordFilter={"workflowId": workflowId})
|
||||
for log in existing_logs:
|
||||
self.db.recordDelete(ChatLog, log["id"])
|
||||
|
||||
|
|
@ -855,11 +856,7 @@ class ChatObjects:
|
|||
"""
|
||||
# Check workflow access first (without calling getWorkflow to avoid circular reference)
|
||||
# Use RBAC filtering
|
||||
workflows = getRecordsetWithRBAC(self.db,
|
||||
ChatWorkflow,
|
||||
self.currentUser,
|
||||
recordFilter={"id": workflowId}
|
||||
)
|
||||
workflows = self._getRecordset(ChatWorkflow, recordFilter={"id": workflowId})
|
||||
|
||||
if not workflows:
|
||||
if pagination is None:
|
||||
|
|
@ -867,7 +864,7 @@ class ChatObjects:
|
|||
return PaginatedResult(items=[], totalItems=0, totalPages=0)
|
||||
|
||||
# Get messages for this workflow from normalized table
|
||||
messages = getRecordsetWithRBAC(self.db, ChatMessage, self.currentUser, recordFilter={"workflowId": workflowId})
|
||||
messages = self._getRecordset(ChatMessage, recordFilter={"workflowId": workflowId})
|
||||
|
||||
# Convert raw messages to dict format for sorting/filtering
|
||||
messageDicts = []
|
||||
|
|
@ -1143,7 +1140,7 @@ class ChatObjects:
|
|||
raise ValueError("messageId cannot be empty")
|
||||
|
||||
# Check if message exists in database
|
||||
messages = getRecordsetWithRBAC(self.db, ChatMessage, self.currentUser, recordFilter={"id": messageId})
|
||||
messages = self._getRecordset(ChatMessage, recordFilter={"id": messageId})
|
||||
if not messages:
|
||||
logger.warning(f"Message with ID {messageId} does not exist in database")
|
||||
|
||||
|
|
@ -1250,12 +1247,12 @@ class ChatObjects:
|
|||
# CASCADE DELETE: Delete all related data first
|
||||
|
||||
# 1. Delete message stats
|
||||
existing_stats = getRecordsetWithRBAC(self.db, ChatStat, self.currentUser, recordFilter={"messageId": messageId})
|
||||
existing_stats = self._getRecordset(ChatStat, recordFilter={"messageId": messageId})
|
||||
for stat in existing_stats:
|
||||
self.db.recordDelete(ChatStat, stat["id"])
|
||||
|
||||
# 2. Delete message documents (but NOT the files!)
|
||||
existing_docs = getRecordsetWithRBAC(self.db, ChatDocument, self.currentUser, recordFilter={"messageId": messageId})
|
||||
existing_docs = self._getRecordset(ChatDocument, recordFilter={"messageId": messageId})
|
||||
for doc in existing_docs:
|
||||
self.db.recordDelete(ChatDocument, doc["id"])
|
||||
|
||||
|
|
@ -1282,7 +1279,7 @@ class ChatObjects:
|
|||
|
||||
|
||||
# Get documents for this message from normalized table
|
||||
documents = getRecordsetWithRBAC(self.db, ChatDocument, self.currentUser, recordFilter={"messageId": messageId})
|
||||
documents = self._getRecordset(ChatDocument, recordFilter={"messageId": messageId})
|
||||
|
||||
if not documents:
|
||||
logger.warning(f"No documents found for message {messageId}")
|
||||
|
|
@ -1323,7 +1320,7 @@ class ChatObjects:
|
|||
def getDocuments(self, messageId: str) -> List[ChatDocument]:
|
||||
"""Returns documents for a message from normalized table."""
|
||||
try:
|
||||
documents = getRecordsetWithRBAC(self.db, ChatDocument, self.currentUser, recordFilter={"messageId": messageId})
|
||||
documents = self._getRecordset(ChatDocument, recordFilter={"messageId": messageId})
|
||||
return [ChatDocument(**doc) for doc in documents]
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting message documents: {str(e)}")
|
||||
|
|
@ -1369,11 +1366,7 @@ class ChatObjects:
|
|||
"""
|
||||
# Check workflow access first (without calling getWorkflow to avoid circular reference)
|
||||
# Use RBAC filtering
|
||||
workflows = getRecordsetWithRBAC(self.db,
|
||||
ChatWorkflow,
|
||||
self.currentUser,
|
||||
recordFilter={"id": workflowId}
|
||||
)
|
||||
workflows = self._getRecordset(ChatWorkflow, recordFilter={"id": workflowId})
|
||||
|
||||
if not workflows:
|
||||
if pagination is None:
|
||||
|
|
@ -1381,7 +1374,7 @@ class ChatObjects:
|
|||
return PaginatedResult(items=[], totalItems=0, totalPages=0)
|
||||
|
||||
# Get logs for this workflow from normalized table
|
||||
logs = getRecordsetWithRBAC(self.db, ChatLog, self.currentUser, recordFilter={"workflowId": workflowId})
|
||||
logs = self._getRecordset(ChatLog, recordFilter={"workflowId": workflowId})
|
||||
|
||||
# Convert raw logs to dict format for sorting/filtering
|
||||
logDicts = []
|
||||
|
|
@ -1513,17 +1506,13 @@ class ChatObjects:
|
|||
"""Returns list of statistics for a workflow if user has access."""
|
||||
# Check workflow access first (without calling getWorkflow to avoid circular reference)
|
||||
# Use RBAC filtering
|
||||
workflows = getRecordsetWithRBAC(self.db,
|
||||
ChatWorkflow,
|
||||
self.currentUser,
|
||||
recordFilter={"id": workflowId}
|
||||
)
|
||||
workflows = self._getRecordset(ChatWorkflow, recordFilter={"id": workflowId})
|
||||
|
||||
if not workflows:
|
||||
return []
|
||||
|
||||
# Get stats for this workflow from normalized table
|
||||
stats = getRecordsetWithRBAC(self.db, ChatStat, self.currentUser, recordFilter={"workflowId": workflowId})
|
||||
stats = self._getRecordset(ChatStat, recordFilter={"workflowId": workflowId})
|
||||
|
||||
if not stats:
|
||||
return []
|
||||
|
|
@ -1581,11 +1570,7 @@ class ChatObjects:
|
|||
"""
|
||||
# Check workflow access first
|
||||
# Use RBAC filtering
|
||||
workflows = getRecordsetWithRBAC(self.db,
|
||||
ChatWorkflow,
|
||||
self.currentUser,
|
||||
recordFilter={"id": workflowId}
|
||||
)
|
||||
workflows = self._getRecordset(ChatWorkflow, recordFilter={"id": workflowId})
|
||||
|
||||
if not workflows:
|
||||
return {"items": []}
|
||||
|
|
@ -1594,7 +1579,7 @@ class ChatObjects:
|
|||
items = []
|
||||
|
||||
# Get messages
|
||||
messages = getRecordsetWithRBAC(self.db, ChatMessage, self.currentUser, recordFilter={"workflowId": workflowId})
|
||||
messages = self._getRecordset(ChatMessage, recordFilter={"workflowId": workflowId})
|
||||
for msg in messages:
|
||||
# Apply timestamp filtering in Python
|
||||
msgTimestamp = parseTimestamp(msg.get("publishedAt"), default=getUtcTimestamp())
|
||||
|
|
@ -1635,7 +1620,7 @@ class ChatObjects:
|
|||
})
|
||||
|
||||
# Get logs - return all logs with roundNumber if available
|
||||
logs = getRecordsetWithRBAC(self.db, ChatLog, self.currentUser, recordFilter={"workflowId": workflowId})
|
||||
logs = self._getRecordset(ChatLog, recordFilter={"workflowId": workflowId})
|
||||
for log in logs:
|
||||
# Apply timestamp filtering in Python
|
||||
logTimestamp = parseTimestamp(log.get("timestamp"), default=getUtcTimestamp())
|
||||
|
|
|
|||
|
|
@ -313,10 +313,12 @@ class ComponentObjects:
|
|||
return False
|
||||
|
||||
tableName = modelClass.__name__
|
||||
from modules.interfaces.interfaceRbac import buildDataObjectKey
|
||||
objectKey = buildDataObjectKey(tableName)
|
||||
permissions = self.rbac.getUserPermissions(
|
||||
self.currentUser,
|
||||
AccessRuleContext.DATA,
|
||||
tableName,
|
||||
objectKey,
|
||||
mandateId=self.mandateId,
|
||||
featureInstanceId=self.featureInstanceId
|
||||
)
|
||||
|
|
@ -590,10 +592,58 @@ class ComponentObjects:
|
|||
|
||||
# Prompt methods
|
||||
|
||||
def _isSysAdmin(self) -> bool:
|
||||
"""Check if the current user is a SysAdmin."""
|
||||
return hasattr(self.currentUser, 'isSysAdmin') and self.currentUser.isSysAdmin
|
||||
|
||||
def _enrichPromptsWithPermissions(self, prompts: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""Enrich prompts with row-level _permissions based on ownership and isSystem flag.
|
||||
|
||||
- SysAdmin: canUpdate=True, canDelete=True on all prompts
|
||||
- Regular user on own prompts: canUpdate=True, canDelete=True
|
||||
- Regular user on system prompts: canUpdate=False, canDelete=False (read-only)
|
||||
"""
|
||||
isSysAdmin = self._isSysAdmin()
|
||||
for prompt in prompts:
|
||||
isOwner = prompt.get("_createdBy") == self.userId
|
||||
prompt["_permissions"] = {
|
||||
"canUpdate": isOwner or isSysAdmin,
|
||||
"canDelete": isOwner or isSysAdmin
|
||||
}
|
||||
return prompts
|
||||
|
||||
def _getPromptsForUser(self) -> List[Dict[str, Any]]:
|
||||
"""Returns prompts visible to the current user.
|
||||
|
||||
Visibility rules:
|
||||
- SysAdmin: ALL prompts
|
||||
- Regular user: own prompts (_createdBy) + system prompts (isSystem=True)
|
||||
"""
|
||||
if self._isSysAdmin():
|
||||
return self.db.getRecordset(Prompt)
|
||||
|
||||
# Get own prompts
|
||||
ownPrompts = self.db.getRecordset(Prompt, recordFilter={"_createdBy": self.userId})
|
||||
|
||||
# Get system prompts
|
||||
systemPrompts = self.db.getRecordset(Prompt, recordFilter={"isSystem": True})
|
||||
|
||||
# Merge and deduplicate (a user's own prompt could also be isSystem)
|
||||
seen = {}
|
||||
for p in ownPrompts:
|
||||
seen[p["id"]] = p
|
||||
for p in systemPrompts:
|
||||
if p["id"] not in seen:
|
||||
seen[p["id"]] = p
|
||||
|
||||
return list(seen.values())
|
||||
|
||||
def getAllPrompts(self, pagination: Optional[PaginationParams] = None) -> Union[List[Prompt], PaginatedResult]:
|
||||
"""
|
||||
Returns prompts based on user access level.
|
||||
Supports optional pagination, sorting, and filtering.
|
||||
Returns prompts with visibility rules:
|
||||
- SysAdmin: sees ALL prompts, can CRUD all
|
||||
- Regular user: sees own prompts + system prompts (isSystem=True), can only CRUD own
|
||||
- Row-level _permissions control edit/delete buttons in the UI
|
||||
|
||||
Args:
|
||||
pagination: Optional pagination parameters. If None, returns all items.
|
||||
|
|
@ -603,11 +653,11 @@ class ComponentObjects:
|
|||
If pagination is provided: PaginatedResult with items and metadata
|
||||
"""
|
||||
try:
|
||||
# Use RBAC filtering
|
||||
filteredPrompts = getRecordsetWithRBAC(self.db,
|
||||
Prompt,
|
||||
self.currentUser
|
||||
)
|
||||
# Get prompts based on user role (own + system for regular, all for SysAdmin)
|
||||
filteredPrompts = self._getPromptsForUser()
|
||||
|
||||
# Enrich with row-level permissions (_permissions: canUpdate, canDelete)
|
||||
filteredPrompts = self._enrichPromptsWithPermissions(filteredPrompts)
|
||||
|
||||
# If no pagination requested, return all items
|
||||
if pagination is None:
|
||||
|
|
@ -630,7 +680,7 @@ class ComponentObjects:
|
|||
endIdx = startIdx + pagination.pageSize
|
||||
pagedPrompts = filteredPrompts[startIdx:endIdx]
|
||||
|
||||
# Convert to model objects
|
||||
# Convert to model objects (extra='allow' on Prompt preserves system fields)
|
||||
items = [Prompt(**prompt) for prompt in pagedPrompts]
|
||||
|
||||
return PaginatedResult(
|
||||
|
|
@ -646,15 +696,24 @@ class ComponentObjects:
|
|||
return PaginatedResult(items=[], totalItems=0, totalPages=0)
|
||||
|
||||
def getPrompt(self, promptId: str) -> Optional[Prompt]:
|
||||
"""Returns a prompt by ID if user has access."""
|
||||
# Use RBAC filtering
|
||||
filteredPrompts = getRecordsetWithRBAC(self.db,
|
||||
Prompt,
|
||||
self.currentUser,
|
||||
recordFilter={"id": promptId}
|
||||
)
|
||||
"""Returns a prompt by ID if the user has visibility.
|
||||
|
||||
return Prompt(**filteredPrompts[0]) if filteredPrompts else None
|
||||
Visibility: SysAdmin sees all, regular user sees own + system prompts.
|
||||
"""
|
||||
filteredPrompts = self.db.getRecordset(Prompt, recordFilter={"id": promptId})
|
||||
if not filteredPrompts:
|
||||
return None
|
||||
|
||||
prompt = filteredPrompts[0]
|
||||
|
||||
# Visibility check for non-SysAdmin: must be owner or system prompt
|
||||
if not self._isSysAdmin():
|
||||
isOwner = prompt.get("_createdBy") == self.userId
|
||||
isSystem = prompt.get("isSystem", False)
|
||||
if not isOwner and not isSystem:
|
||||
return None
|
||||
|
||||
return Prompt(**prompt)
|
||||
|
||||
def createPrompt(self, promptData: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Creates a new prompt if user has permission."""
|
||||
|
|
@ -669,13 +728,25 @@ class ComponentObjects:
|
|||
return createdRecord
|
||||
|
||||
def updatePrompt(self, promptId: str, updateData: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Updates a prompt if user has access."""
|
||||
"""Updates a prompt. Rules:
|
||||
- SysAdmin: can update any prompt (including system prompts)
|
||||
- Regular user: can only update own prompts (not system prompts)
|
||||
"""
|
||||
try:
|
||||
# Get prompt
|
||||
# Get prompt (visibility-checked)
|
||||
prompt = self.getPrompt(promptId)
|
||||
if not prompt:
|
||||
raise ValueError(f"Prompt {promptId} not found")
|
||||
|
||||
# Permission check: owner or SysAdmin
|
||||
isOwner = (getattr(prompt, '_createdBy', None) == self.userId)
|
||||
if not self._isSysAdmin() and not isOwner:
|
||||
raise PermissionError(f"No permission to update prompt {promptId}")
|
||||
|
||||
# Regular users cannot set isSystem flag
|
||||
if not self._isSysAdmin() and 'isSystem' in updateData:
|
||||
del updateData['isSystem']
|
||||
|
||||
# Update prompt record directly with the update data
|
||||
self.db.recordModify(Prompt, promptId, updateData)
|
||||
|
||||
|
|
@ -688,77 +759,69 @@ class ComponentObjects:
|
|||
|
||||
return updatedPrompt.model_dump()
|
||||
|
||||
except PermissionError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating prompt: {str(e)}")
|
||||
raise ValueError(f"Failed to update prompt: {str(e)}")
|
||||
|
||||
def deletePrompt(self, promptId: str) -> bool:
|
||||
"""Deletes a prompt if user has access."""
|
||||
# Check if the prompt exists and user has access
|
||||
"""Deletes a prompt. Rules:
|
||||
- SysAdmin: can delete any prompt (including system prompts)
|
||||
- Regular user: can only delete own prompts (not system prompts)
|
||||
"""
|
||||
# Get prompt (visibility-checked)
|
||||
prompt = self.getPrompt(promptId)
|
||||
if not prompt:
|
||||
return False
|
||||
|
||||
if not self.checkRbacPermission(Prompt, "update", promptId):
|
||||
# Permission check: owner or SysAdmin
|
||||
isOwner = (getattr(prompt, '_createdBy', None) == self.userId)
|
||||
if not self._isSysAdmin() and not isOwner:
|
||||
raise PermissionError(f"No permission to delete prompt {promptId}")
|
||||
|
||||
# Delete prompt
|
||||
success = self.db.recordDelete(Prompt, promptId)
|
||||
|
||||
|
||||
return success
|
||||
|
||||
# File Utilities
|
||||
|
||||
def checkForDuplicateFile(self, fileHash: str, fileName: str = None) -> Optional[FileItem]:
|
||||
"""Checks if a file with the same hash already exists for the current user and mandate.
|
||||
If fileName is provided, also checks for exact name+hash match.
|
||||
Only returns files the current user has access to."""
|
||||
# Get files with the hash, filtered by RBAC
|
||||
accessibleFiles = getRecordsetWithRBAC(self.db,
|
||||
FileItem,
|
||||
self.currentUser,
|
||||
recordFilter={"fileHash": fileHash}
|
||||
)
|
||||
def checkForDuplicateFile(self, fileHash: str, fileName: str) -> Optional[FileItem]:
|
||||
"""Checks if a file with the same hash AND fileName already exists for the current user.
|
||||
|
||||
if not accessibleFiles:
|
||||
Duplicate = same user (_createdBy) + same fileHash + same fileName.
|
||||
Same hash with different name is allowed (intentional copy by user).
|
||||
Uses direct DB query (not RBAC) because files are isolated per user.
|
||||
"""
|
||||
if not self.userId:
|
||||
return None
|
||||
|
||||
# If fileName is provided, check for exact name+hash match first
|
||||
if fileName:
|
||||
for file in accessibleFiles:
|
||||
# Skip files without fileName key or with None/empty fileName
|
||||
if "fileName" not in file or not file["fileName"]:
|
||||
continue
|
||||
if file["fileName"] == fileName:
|
||||
return FileItem(
|
||||
id=file["id"],
|
||||
mandateId=file["mandateId"],
|
||||
fileName=file["fileName"],
|
||||
mimeType=file["mimeType"],
|
||||
fileHash=file["fileHash"],
|
||||
fileSize=file["fileSize"],
|
||||
creationDate=file["creationDate"]
|
||||
)
|
||||
# Direct DB query: find files with matching hash + name + user
|
||||
matchingFiles = self.db.getRecordset(
|
||||
FileItem,
|
||||
recordFilter={
|
||||
"_createdBy": self.userId,
|
||||
"fileHash": fileHash,
|
||||
"fileName": fileName
|
||||
}
|
||||
)
|
||||
|
||||
# Return first valid file with matching hash (for general duplicate detection)
|
||||
for file in accessibleFiles:
|
||||
# Skip files without fileName key or with None/empty fileName
|
||||
if "fileName" not in file or not file["fileName"]:
|
||||
continue
|
||||
# Use first valid file
|
||||
return FileItem(
|
||||
id=file["id"],
|
||||
mandateId=file["mandateId"],
|
||||
fileName=file["fileName"],
|
||||
mimeType=file["mimeType"],
|
||||
fileHash=file["fileHash"],
|
||||
fileSize=file["fileSize"],
|
||||
creationDate=file["creationDate"]
|
||||
)
|
||||
if not matchingFiles:
|
||||
return None
|
||||
|
||||
# If no valid files found, return None
|
||||
return None
|
||||
# Return first match
|
||||
file = matchingFiles[0]
|
||||
return FileItem(
|
||||
id=file["id"],
|
||||
mandateId=file.get("mandateId", ""),
|
||||
featureInstanceId=file.get("featureInstanceId", ""),
|
||||
fileName=file["fileName"],
|
||||
mimeType=file["mimeType"],
|
||||
fileHash=file["fileHash"],
|
||||
fileSize=file["fileSize"],
|
||||
creationDate=file["creationDate"]
|
||||
)
|
||||
|
||||
def getMimeType(self, fileName: str) -> str:
|
||||
"""Determines the MIME type based on the file extension."""
|
||||
|
|
@ -832,9 +895,18 @@ class ComponentObjects:
|
|||
|
||||
# File methods - metadata-based operations
|
||||
|
||||
def _getFilesByCurrentUser(self, recordFilter: Dict[str, Any] = None) -> List[Dict[str, Any]]:
|
||||
"""Files are always user-scoped. Returns only files owned by the current user,
|
||||
regardless of role (including SysAdmin). This bypasses RBAC intentionally."""
|
||||
filterDict = {"_createdBy": self.userId}
|
||||
if recordFilter:
|
||||
filterDict.update(recordFilter)
|
||||
return self.db.getRecordset(FileItem, recordFilter=filterDict)
|
||||
|
||||
def getAllFiles(self, pagination: Optional[PaginationParams] = None) -> Union[List[FileItem], PaginatedResult]:
|
||||
"""
|
||||
Returns files based on user access level.
|
||||
Returns files owned by the current user (user-scoped, not RBAC-based).
|
||||
Every user (including SysAdmin) only sees their own files.
|
||||
Supports optional pagination, sorting, and filtering.
|
||||
|
||||
Args:
|
||||
|
|
@ -844,13 +916,10 @@ class ComponentObjects:
|
|||
If pagination is None: List[FileItem]
|
||||
If pagination is provided: PaginatedResult with items and metadata
|
||||
"""
|
||||
# Use RBAC filtering
|
||||
filteredFiles = getRecordsetWithRBAC(self.db,
|
||||
FileItem,
|
||||
self.currentUser
|
||||
)
|
||||
# Files are always user-scoped: filter by _createdBy (bypasses RBAC SysAdmin override)
|
||||
filteredFiles = self._getFilesByCurrentUser()
|
||||
|
||||
# Convert database records to FileItem instances (for both paginated and non-paginated)
|
||||
# Convert database records to FileItem instances (extra='allow' preserves system fields like _createdBy)
|
||||
def convertFileItems(files):
|
||||
fileItems = []
|
||||
for file in files:
|
||||
|
|
@ -858,21 +927,14 @@ class ComponentObjects:
|
|||
# Ensure proper values, use defaults for invalid data
|
||||
creationDate = file.get("creationDate")
|
||||
if creationDate is None or not isinstance(creationDate, (int, float)) or creationDate <= 0:
|
||||
creationDate = getUtcTimestamp()
|
||||
file["creationDate"] = getUtcTimestamp()
|
||||
|
||||
fileName = file.get("fileName")
|
||||
if not fileName or fileName == "None":
|
||||
continue # Skip records with invalid fileName
|
||||
|
||||
fileItem = FileItem(
|
||||
id=file.get("id"),
|
||||
mandateId=file.get("mandateId"),
|
||||
fileName=fileName,
|
||||
mimeType=file.get("mimeType"),
|
||||
fileHash=file.get("fileHash"),
|
||||
fileSize=file.get("fileSize"),
|
||||
creationDate=creationDate
|
||||
)
|
||||
# Use **file to pass all fields including system fields (_createdBy, etc.)
|
||||
fileItem = FileItem(**file)
|
||||
fileItems.append(fileItem)
|
||||
except Exception as e:
|
||||
logger.warning(f"Skipping invalid file record: {str(e)}")
|
||||
|
|
@ -900,7 +962,7 @@ class ComponentObjects:
|
|||
endIdx = startIdx + pagination.pageSize
|
||||
pagedFiles = filteredFiles[startIdx:endIdx]
|
||||
|
||||
# Convert to model objects
|
||||
# Convert to model objects (extra='allow' on FileItem preserves system fields)
|
||||
items = convertFileItems(pagedFiles)
|
||||
|
||||
return PaginatedResult(
|
||||
|
|
@ -910,13 +972,9 @@ class ComponentObjects:
|
|||
)
|
||||
|
||||
def getFile(self, fileId: str) -> Optional[FileItem]:
|
||||
"""Returns a file by ID if user has access."""
|
||||
# Use RBAC filtering
|
||||
filteredFiles = getRecordsetWithRBAC(self.db,
|
||||
FileItem,
|
||||
self.currentUser,
|
||||
recordFilter={"id": fileId}
|
||||
)
|
||||
"""Returns a file by ID if it belongs to the current user (user-scoped)."""
|
||||
# Files are always user-scoped: filter by _createdBy (bypasses RBAC SysAdmin override)
|
||||
filteredFiles = self._getFilesByCurrentUser(recordFilter={"id": fileId})
|
||||
|
||||
if not filteredFiles:
|
||||
return None
|
||||
|
|
@ -976,17 +1034,28 @@ class ComponentObjects:
|
|||
counter += 1
|
||||
|
||||
def createFile(self, name: str, mimeType: str, content: bytes) -> FileItem:
|
||||
"""Creates a new file entry if user has permission. Computes fileHash and fileSize from content."""
|
||||
"""Creates a new file entry if user has permission. Computes fileHash and fileSize from content.
|
||||
|
||||
Duplicate check: if a file with the same user + fileHash + fileName already exists,
|
||||
the existing file is returned instead of creating a new one.
|
||||
Same hash with different name is allowed (intentional copy by user).
|
||||
"""
|
||||
if not self.checkRbacPermission(FileItem, "create"):
|
||||
raise PermissionError("No permission to create files")
|
||||
|
||||
# Ensure fileName is unique
|
||||
uniqueName = self._generateUniquefileName(name)
|
||||
|
||||
# Compute file size and hash
|
||||
fileSize = len(content)
|
||||
fileHash = hashlib.sha256(content).hexdigest()
|
||||
|
||||
# Duplicate check: same user + same hash + same fileName → return existing
|
||||
existingFile = self.checkForDuplicateFile(fileHash, name)
|
||||
if existingFile:
|
||||
logger.info(f"Duplicate file detected in createFile: '{name}' (hash={fileHash[:12]}...) for user {self.userId} — returning existing file {existingFile.id}")
|
||||
return existingFile
|
||||
|
||||
# Ensure fileName is unique
|
||||
uniqueName = self._generateUniquefileName(name)
|
||||
|
||||
# Use mandateId and featureInstanceId from context for proper data isolation
|
||||
# Convert None to empty string to satisfy Pydantic validation
|
||||
mandateId = self.mandateId or ""
|
||||
|
|
@ -1005,7 +1074,6 @@ class ComponentObjects:
|
|||
# Store in database
|
||||
self.db.recordCreate(FileItem, fileItem)
|
||||
|
||||
|
||||
return fileItem
|
||||
|
||||
def updateFile(self, fileId: str, updateData: Dict[str, Any]) -> Dict[str, Any]:
|
||||
|
|
@ -1040,20 +1108,16 @@ class ComponentObjects:
|
|||
if not self.checkRbacPermission(FileItem, "update", fileId):
|
||||
raise PermissionError(f"No permission to delete file {fileId}")
|
||||
|
||||
# Check for other references to this file (by hash) - use RBAC to only check files user has access to
|
||||
# Check for other references to this file (by hash) - user-scoped check
|
||||
fileHash = file.fileHash
|
||||
if fileHash:
|
||||
allReferences = getRecordsetWithRBAC(self.db,
|
||||
FileItem,
|
||||
self.currentUser,
|
||||
recordFilter={"fileHash": fileHash}
|
||||
)
|
||||
allReferences = self._getFilesByCurrentUser(recordFilter={"fileHash": fileHash})
|
||||
otherReferences = [f for f in allReferences if f["id"] != fileId]
|
||||
|
||||
# Only delete associated fileData if no other references exist
|
||||
if not otherReferences:
|
||||
try:
|
||||
fileDataEntries = getRecordsetWithRBAC(self.db, FileData, self.currentUser, recordFilter={"id": fileId})
|
||||
fileDataEntries = self.db.getRecordset(FileData, recordFilter={"id": fileId})
|
||||
if fileDataEntries:
|
||||
self.db.recordDelete(FileData, fileId)
|
||||
logger.debug(f"FileData for file {fileId} deleted")
|
||||
|
|
@ -1113,6 +1177,12 @@ class ComponentObjects:
|
|||
base64Encoded = True
|
||||
logger.debug(f"Stored file {fileId} as base64")
|
||||
|
||||
# Check if file data already exists (e.g., when createFile returned a duplicate)
|
||||
existingData = self.db.getRecordset(FileData, recordFilter={"id": fileId})
|
||||
if existingData:
|
||||
logger.debug(f"File data already exists for {fileId} — skipping duplicate storage")
|
||||
return True
|
||||
|
||||
# Create the fileData record with data and encoding flag
|
||||
fileDataObj = {
|
||||
"id": fileId,
|
||||
|
|
@ -1245,25 +1315,21 @@ class ComponentObjects:
|
|||
logger.error(f"Invalid fileContent type: {type(fileContent)}")
|
||||
raise ValueError(f"fileContent must be bytes, got {type(fileContent)}")
|
||||
|
||||
# Compute file hash first to check for duplicates
|
||||
# Compute file hash to check for duplicates before any DB writes
|
||||
fileHash = hashlib.sha256(fileContent).hexdigest()
|
||||
|
||||
# Check for exact name+hash match first (same name + same content)
|
||||
# Duplicate check: same user + same fileHash + same fileName → return existing file
|
||||
# Same hash with different name is allowed (intentional copy by user)
|
||||
existingFile = self.checkForDuplicateFile(fileHash, fileName)
|
||||
if existingFile:
|
||||
logger.info(f"Exact duplicate detected: {fileName} with same hash. Returning existing file reference.")
|
||||
logger.info(f"Duplicate detected for user {self.userId}: '{fileName}' with hash {fileHash[:12]}... — returning existing file {existingFile.id}")
|
||||
return existingFile, "exact_duplicate"
|
||||
|
||||
# Check for hash-only match (same content, different name)
|
||||
existingFileWithSameHash = self.checkForDuplicateFile(fileHash)
|
||||
if existingFileWithSameHash:
|
||||
logger.info(f"Content duplicate detected: {fileName} has same content as {existingFileWithSameHash.fileName}")
|
||||
# Continue with upload - filename will be made unique if needed
|
||||
|
||||
# Determine MIME type
|
||||
mimeType = self.getMimeType(fileName)
|
||||
|
||||
# Save metadata and file (hash/size computed inside createFile)
|
||||
# createFile handles its own duplicate check (for calls from other code paths)
|
||||
# Here we already checked, so this will create a new file
|
||||
logger.debug(f"Saving file metadata to database for file: {fileName}")
|
||||
fileItem = self.createFile(
|
||||
name=fileName,
|
||||
|
|
|
|||
|
|
@ -163,7 +163,7 @@ def getRecordsetWithRBAC(
|
|||
|
||||
# Check view permission first
|
||||
if not permissions.view:
|
||||
logger.debug(f"User {currentUser.id} has no view permission for {objectKey}")
|
||||
logger.debug(f"User {currentUser.id} has no view permission for {objectKey} (mandateId={effectiveMandateId}, featureInstanceId={featureInstanceId})")
|
||||
return []
|
||||
|
||||
# Build WHERE clause with RBAC filtering
|
||||
|
|
|
|||
|
|
@ -90,7 +90,7 @@ async def sync_all_automation_events(
|
|||
|
||||
from modules.services import getInterface as getServices
|
||||
services = getServices(currentUser, None)
|
||||
result = await syncAutomationEvents(services, eventUser)
|
||||
result = syncAutomationEvents(services, eventUser)
|
||||
return {
|
||||
"success": True,
|
||||
"synced": result.get("synced", 0),
|
||||
|
|
|
|||
|
|
@ -78,11 +78,18 @@ def get_permissions(
|
|||
)
|
||||
|
||||
# MULTI-TENANT: Get permissions using context (mandateId/featureInstanceId)
|
||||
# For DATA context, resolve short model names to full objectKeys
|
||||
# e.g., "ChatWorkflow" → "data.chat.ChatWorkflow"
|
||||
resolvedItem = item or ""
|
||||
if accessContext == AccessRuleContext.DATA and resolvedItem and "." not in resolvedItem:
|
||||
from modules.interfaces.interfaceRbac import buildDataObjectKey
|
||||
resolvedItem = buildDataObjectKey(resolvedItem)
|
||||
|
||||
# Pass mandateId and featureInstanceId to load Feature-Instance roles
|
||||
permissions = interface.rbac.getUserPermissions(
|
||||
reqContext.user,
|
||||
accessContext,
|
||||
item or "",
|
||||
resolvedItem,
|
||||
mandateId=reqContext.mandateId,
|
||||
featureInstanceId=reqContext.featureInstanceId
|
||||
)
|
||||
|
|
|
|||
|
|
@ -121,10 +121,10 @@ def get_prompt(
|
|||
def update_prompt(
|
||||
request: Request,
|
||||
promptId: str = Path(..., description="ID of the prompt to update"),
|
||||
promptData: Prompt = Body(...),
|
||||
promptData: Dict[str, Any] = Body(...),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> Prompt:
|
||||
"""Update an existing prompt"""
|
||||
"""Update an existing prompt (supports partial updates for inline editing)"""
|
||||
managementInterface = interfaceDbManagement.getInterface(currentUser)
|
||||
|
||||
# Check if the prompt exists
|
||||
|
|
@ -135,14 +135,17 @@ def update_prompt(
|
|||
detail=f"Prompt with ID {promptId} not found"
|
||||
)
|
||||
|
||||
# Convert Prompt to dict for interface, excluding the id field
|
||||
if hasattr(promptData, "model_dump"):
|
||||
update_data = promptData.model_dump(exclude={"id"})
|
||||
else:
|
||||
update_data = promptData.model_dump(exclude={"id"})
|
||||
# Remove id from update data if present
|
||||
update_data = {k: v for k, v in promptData.items() if k != "id"}
|
||||
|
||||
# Update prompt
|
||||
updatedPrompt = managementInterface.updatePrompt(promptId, update_data)
|
||||
# Update prompt (ownership check happens in interface)
|
||||
try:
|
||||
updatedPrompt = managementInterface.updatePrompt(promptId, update_data)
|
||||
except PermissionError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(e)
|
||||
)
|
||||
|
||||
if not updatedPrompt:
|
||||
raise HTTPException(
|
||||
|
|
@ -170,7 +173,14 @@ def delete_prompt(
|
|||
detail=f"Prompt with ID {promptId} not found"
|
||||
)
|
||||
|
||||
success = managementInterface.deletePrompt(promptId)
|
||||
try:
|
||||
success = managementInterface.deletePrompt(promptId)
|
||||
except PermissionError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(e)
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@ class RbacClass:
|
|||
|
||||
Multi-Tenant Design:
|
||||
- Lädt Rollen aus UserMandate + UserMandateRole wenn mandateId gegeben
|
||||
- isSysAdmin gibt vollen Zugriff auf System-Level (kein mandateId)
|
||||
- isSysAdmin gibt vollen Zugriff, unabhängig vom Kontext
|
||||
|
||||
Args:
|
||||
user: User object
|
||||
|
|
@ -82,8 +82,8 @@ class RbacClass:
|
|||
delete=AccessLevel.NONE
|
||||
)
|
||||
|
||||
# SysAdmin auf System-Level (kein Mandant) hat vollen Zugriff
|
||||
if hasattr(user, 'isSysAdmin') and user.isSysAdmin and not mandateId:
|
||||
# SysAdmin hat vollen Zugriff - unabhängig vom Kontext (Mandant/Feature)
|
||||
if hasattr(user, 'isSysAdmin') and user.isSysAdmin:
|
||||
return UserPermissions(
|
||||
view=True,
|
||||
read=AccessLevel.ALL,
|
||||
|
|
@ -96,6 +96,7 @@ class RbacClass:
|
|||
roleIds = self._getRoleIdsForUser(user, mandateId, featureInstanceId)
|
||||
|
||||
if not roleIds:
|
||||
logger.debug(f"getUserPermissions: NO roles found for user={user.id}, mandateId={mandateId}, featureInstanceId={featureInstanceId}, item={item}")
|
||||
return permissions
|
||||
|
||||
# Lade alle relevanten Regeln für alle Rollen
|
||||
|
|
|
|||
|
|
@ -63,10 +63,11 @@ class Services:
|
|||
- Feature-specific Services are loaded dynamically via filename discovery
|
||||
"""
|
||||
|
||||
def __init__(self, user: User, workflow: "ChatWorkflow" = None, mandateId: Optional[str] = None):
|
||||
def __init__(self, user: User, workflow: "ChatWorkflow" = None, mandateId: Optional[str] = None, featureInstanceId: Optional[str] = None):
|
||||
self.user: User = user
|
||||
self.workflow = workflow
|
||||
self.mandateId: Optional[str] = mandateId
|
||||
self.featureInstanceId: Optional[str] = featureInstanceId
|
||||
self.currentUserPrompt: str = ""
|
||||
self.rawUserPrompt: str = ""
|
||||
|
||||
|
|
@ -83,7 +84,7 @@ class Services:
|
|||
# CENTRAL INTERFACE (Chat/Workflow)
|
||||
# ============================================================
|
||||
from modules.interfaces.interfaceDbChat import getInterface as getChatInterface
|
||||
self.interfaceDbChat = getChatInterface(user, mandateId=mandateId)
|
||||
self.interfaceDbChat = getChatInterface(user, mandateId=mandateId, featureInstanceId=featureInstanceId)
|
||||
|
||||
# ============================================================
|
||||
# SHARED SERVICES (from modules/services/)
|
||||
|
|
@ -143,7 +144,7 @@ class Services:
|
|||
|
||||
# Get interface via getInterface()
|
||||
if hasattr(module, "getInterface"):
|
||||
interface = module.getInterface(self.user, mandateId=self.mandateId)
|
||||
interface = module.getInterface(self.user, mandateId=self.mandateId, featureInstanceId=self.featureInstanceId)
|
||||
# Derive attribute name: interfaceFeatureAiChat -> interfaceDbChat
|
||||
attrName = filename.replace("interfaceFeature", "interfaceDb")
|
||||
setattr(self, attrName, interface)
|
||||
|
|
@ -191,6 +192,6 @@ class Services:
|
|||
logger.debug(f"Could not load service from {filepath}: {e}")
|
||||
|
||||
|
||||
def getInterface(user: User, workflow: "ChatWorkflow" = None, mandateId: Optional[str] = None) -> Services:
|
||||
"""Get Services instance for the given user and mandate context."""
|
||||
return Services(user, workflow, mandateId=mandateId)
|
||||
def getInterface(user: User, workflow: "ChatWorkflow" = None, mandateId: Optional[str] = None, featureInstanceId: Optional[str] = None) -> Services:
|
||||
"""Get Services instance for the given user, mandate, and feature instance context."""
|
||||
return Services(user, workflow, mandateId=mandateId, featureInstanceId=featureInstanceId)
|
||||
|
|
|
|||
|
|
@ -2574,8 +2574,8 @@ CRITICAL:
|
|||
"""
|
||||
from modules.services.serviceGeneration.renderers.registry import getRenderer
|
||||
|
||||
# Get renderer for this format - NO FALLBACK
|
||||
renderer = getRenderer(outputFormat, self.services)
|
||||
# Get document renderer for this format (structure filling is document generation path)
|
||||
renderer = getRenderer(outputFormat, self.services, outputStyle='document')
|
||||
|
||||
if not renderer:
|
||||
raise ValueError(f"No renderer found for output format '{outputFormat}'. Check renderer registry.")
|
||||
|
|
|
|||
|
|
@ -556,10 +556,10 @@ class GenerationService:
|
|||
|
||||
|
||||
def _getFormatRenderer(self, output_format: str):
|
||||
"""Get the appropriate renderer for the specified format using auto-discovery."""
|
||||
"""Get the appropriate document renderer for the specified format."""
|
||||
try:
|
||||
from .renderers.registry import getRenderer, getSupportedFormats
|
||||
renderer = getRenderer(output_format, services=self.services)
|
||||
renderer = getRenderer(output_format, services=self.services, outputStyle='document')
|
||||
|
||||
if renderer:
|
||||
return renderer
|
||||
|
|
@ -573,7 +573,7 @@ class GenerationService:
|
|||
|
||||
# Fallback to text renderer if no specific renderer found
|
||||
logger.warning(f"Falling back to text renderer for format {output_format}")
|
||||
fallbackRenderer = getRenderer('text', services=self.services)
|
||||
fallbackRenderer = getRenderer('text', services=self.services, outputStyle='document')
|
||||
if fallbackRenderer:
|
||||
return fallbackRenderer
|
||||
|
||||
|
|
|
|||
|
|
@ -922,7 +922,7 @@ CRITICAL:
|
|||
"""Get code renderer for file type."""
|
||||
from modules.services.serviceGeneration.renderers.registry import getRenderer
|
||||
|
||||
# Map file types to renderer formats
|
||||
# Map file types to renderer formats (code path)
|
||||
formatMap = {
|
||||
'json': 'json',
|
||||
'csv': 'csv',
|
||||
|
|
@ -931,7 +931,7 @@ CRITICAL:
|
|||
|
||||
rendererFormat = formatMap.get(fileType.lower())
|
||||
if rendererFormat:
|
||||
renderer = getRenderer(rendererFormat, self.services)
|
||||
renderer = getRenderer(rendererFormat, self.services, outputStyle='code')
|
||||
# Check if renderer supports code rendering
|
||||
if renderer and hasattr(renderer, 'renderCodeFiles'):
|
||||
return renderer
|
||||
|
|
|
|||
|
|
@ -2,20 +2,30 @@
|
|||
# All rights reserved.
|
||||
"""
|
||||
Renderer registry for automatic discovery and registration of renderers.
|
||||
|
||||
Renderers are indexed by (format, outputStyle) so that document generation
|
||||
and code generation each get the correct renderer for the same format.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import importlib
|
||||
from typing import Dict, Type, List, Optional
|
||||
from typing import Dict, Type, List, Optional, Tuple
|
||||
from .documentRendererBaseTemplate import BaseRenderer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RendererRegistry:
|
||||
"""Registry for automatic renderer discovery and management."""
|
||||
"""Registry for automatic renderer discovery and management.
|
||||
|
||||
Maintains separate renderer mappings per outputStyle ('document', 'code', etc.)
|
||||
so that document-generation and code-generation paths each resolve to the
|
||||
correct renderer, even when both support the same format (e.g. 'csv').
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._renderers: Dict[str, Type[BaseRenderer]] = {}
|
||||
# Key: (formatName, outputStyle) -> rendererClass
|
||||
self._renderers: Dict[Tuple[str, str], Type[BaseRenderer]] = {}
|
||||
self._format_mappings: Dict[str, str] = {}
|
||||
self._discovered = False
|
||||
|
||||
|
|
@ -25,39 +35,27 @@ class RendererRegistry:
|
|||
return
|
||||
|
||||
try:
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Get the directory containing this registry file
|
||||
currentDir = Path(__file__).parent
|
||||
renderersDir = currentDir
|
||||
|
||||
# Get the package name dynamically
|
||||
packageName = __name__.rsplit('.', 1)[0]
|
||||
|
||||
# Scan all Python files in the renderers directory
|
||||
for filePath in renderersDir.glob("*.py"):
|
||||
if filePath.name in ['registry.py', 'documentRendererBaseTemplate.py', '__init__.py']:
|
||||
for filePath in currentDir.glob("*.py"):
|
||||
if filePath.name in ['registry.py', 'documentRendererBaseTemplate.py', 'codeRendererBaseTemplate.py', '__init__.py']:
|
||||
continue
|
||||
|
||||
# Extract module name from filename
|
||||
moduleName = filePath.stem
|
||||
|
||||
try:
|
||||
# Import the module dynamically
|
||||
fullModuleName = f"{packageName}.{moduleName}"
|
||||
module = importlib.import_module(fullModuleName)
|
||||
|
||||
# Look for renderer classes in the module
|
||||
for attrName in dir(module):
|
||||
attr = getattr(module, attrName)
|
||||
if (isinstance(attr, type) and
|
||||
issubclass(attr, BaseRenderer) and
|
||||
attr != BaseRenderer and
|
||||
hasattr(attr, 'getSupportedFormats')):
|
||||
|
||||
# Register the renderer
|
||||
self._registerRendererClass(attr)
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -68,60 +66,75 @@ class RendererRegistry:
|
|||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during renderer discovery: {str(e)}")
|
||||
self._discovered = True # Mark as discovered to avoid repeated attempts
|
||||
self._discovered = True
|
||||
|
||||
def _registerRendererClass(self, rendererClass: Type[BaseRenderer]) -> None:
|
||||
"""Register a renderer class with its supported formats."""
|
||||
"""Register a renderer class keyed by (format, outputStyle)."""
|
||||
try:
|
||||
# Get supported formats from the renderer class
|
||||
supportedFormats = rendererClass.getSupportedFormats()
|
||||
|
||||
# Get priority (default to 0 if not specified)
|
||||
outputStyle = rendererClass.getOutputStyle() if hasattr(rendererClass, 'getOutputStyle') else 'document'
|
||||
priority = rendererClass.getPriority() if hasattr(rendererClass, 'getPriority') else 0
|
||||
|
||||
for formatName in supportedFormats:
|
||||
formatKey = formatName.lower()
|
||||
registryKey = (formatKey, outputStyle)
|
||||
|
||||
# Check if format already registered - use priority to decide
|
||||
if formatKey in self._renderers:
|
||||
existingRenderer = self._renderers[formatKey]
|
||||
if registryKey in self._renderers:
|
||||
existingRenderer = self._renderers[registryKey]
|
||||
existingPriority = existingRenderer.getPriority() if hasattr(existingRenderer, 'getPriority') else 0
|
||||
|
||||
# Only replace if new renderer has higher priority
|
||||
if priority > existingPriority:
|
||||
logger.debug(f"Replacing {existingRenderer.__name__} with {rendererClass.__name__} for format '{formatName}' (priority {priority} > {existingPriority})")
|
||||
self._renderers[formatKey] = rendererClass
|
||||
logger.debug(f"Replacing {existingRenderer.__name__} with {rendererClass.__name__} for ({formatKey}, {outputStyle}) (priority {priority} > {existingPriority})")
|
||||
self._renderers[registryKey] = rendererClass
|
||||
else:
|
||||
logger.debug(f"Keeping {existingRenderer.__name__} for format '{formatName}' (priority {existingPriority} >= {priority})")
|
||||
logger.debug(f"Keeping {existingRenderer.__name__} for ({formatKey}, {outputStyle}) (priority {existingPriority} >= {priority})")
|
||||
else:
|
||||
# Register primary format
|
||||
self._renderers[formatKey] = rendererClass
|
||||
self._renderers[registryKey] = rendererClass
|
||||
|
||||
# Register aliases if any
|
||||
# Register aliases
|
||||
if hasattr(rendererClass, 'getFormatAliases'):
|
||||
aliases = rendererClass.getFormatAliases()
|
||||
for alias in aliases:
|
||||
self._format_mappings[alias.lower()] = formatName.lower()
|
||||
self._format_mappings[alias.lower()] = formatKey
|
||||
|
||||
logger.debug(f"Registered {rendererClass.__name__} for formats: {supportedFormats} (priority: {priority})")
|
||||
logger.debug(f"Registered {rendererClass.__name__} for formats={supportedFormats}, style={outputStyle}, priority={priority}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error registering renderer {rendererClass.__name__}: {str(e)}")
|
||||
|
||||
def getRenderer(self, outputFormat: str, services=None) -> Optional[BaseRenderer]:
|
||||
"""Get a renderer instance for the specified format."""
|
||||
def getRenderer(self, outputFormat: str, services=None, outputStyle: str = None) -> Optional[BaseRenderer]:
|
||||
"""Get a renderer instance for the specified format and style.
|
||||
|
||||
Args:
|
||||
outputFormat: Format name (e.g. 'csv', 'json', 'pdf')
|
||||
services: Services instance passed to renderer constructor
|
||||
outputStyle: 'document' or 'code'. If None, returns the first match
|
||||
with preference: document > code (most callers are document path).
|
||||
"""
|
||||
if not self._discovered:
|
||||
self.discoverRenderers()
|
||||
|
||||
# Normalize format name
|
||||
formatName = outputFormat.lower().strip()
|
||||
|
||||
# Check for aliases first
|
||||
if formatName in self._format_mappings:
|
||||
formatName = self._format_mappings[formatName]
|
||||
|
||||
# Get renderer class
|
||||
rendererClass = self._renderers.get(formatName)
|
||||
rendererClass = None
|
||||
|
||||
if outputStyle:
|
||||
# Exact match by style
|
||||
rendererClass = self._renderers.get((formatName, outputStyle))
|
||||
else:
|
||||
# No style specified — prefer 'document', then 'code', then any
|
||||
for style in ['document', 'code']:
|
||||
rendererClass = self._renderers.get((formatName, style))
|
||||
if rendererClass:
|
||||
break
|
||||
# Fallback: check any registered style
|
||||
if not rendererClass:
|
||||
for key, cls in self._renderers.items():
|
||||
if key[0] == formatName:
|
||||
rendererClass = cls
|
||||
break
|
||||
|
||||
if rendererClass:
|
||||
try:
|
||||
|
|
@ -130,7 +143,7 @@ class RendererRegistry:
|
|||
logger.error(f"Error creating renderer instance for {formatName}: {str(e)}")
|
||||
return None
|
||||
|
||||
logger.warning(f"No renderer found for format: {outputFormat}")
|
||||
logger.warning(f"No renderer found for format={outputFormat}, style={outputStyle}")
|
||||
return None
|
||||
|
||||
def getSupportedFormats(self) -> List[str]:
|
||||
|
|
@ -138,9 +151,11 @@ class RendererRegistry:
|
|||
if not self._discovered:
|
||||
self.discoverRenderers()
|
||||
|
||||
formats = list(self._renderers.keys())
|
||||
formats.extend(self._format_mappings.keys())
|
||||
return sorted(set(formats))
|
||||
formats = set()
|
||||
for (fmt, _style) in self._renderers.keys():
|
||||
formats.add(fmt)
|
||||
formats.update(self._format_mappings.keys())
|
||||
return sorted(formats)
|
||||
|
||||
def getRendererInfo(self) -> Dict[str, Dict[str, str]]:
|
||||
"""Get information about all registered renderers."""
|
||||
|
|
@ -148,10 +163,12 @@ class RendererRegistry:
|
|||
self.discoverRenderers()
|
||||
|
||||
info = {}
|
||||
for formatName, rendererClass in self._renderers.items():
|
||||
info[formatName] = {
|
||||
for (formatName, style), rendererClass in self._renderers.items():
|
||||
key = f"{formatName}:{style}"
|
||||
info[key] = {
|
||||
'class_name': rendererClass.__name__,
|
||||
'module': rendererClass.__module__,
|
||||
'outputStyle': style,
|
||||
'description': getattr(rendererClass, '__doc__', 'No description').strip().split('\n')[0] if rendererClass.__doc__ else 'No description'
|
||||
}
|
||||
|
||||
|
|
@ -160,44 +177,62 @@ class RendererRegistry:
|
|||
def getOutputStyle(self, outputFormat: str) -> Optional[str]:
|
||||
"""
|
||||
Get the output style classification for a given format.
|
||||
Returns: 'code', 'document', 'image', or other (e.g., 'video' for future use)
|
||||
When both 'document' and 'code' renderers exist for a format,
|
||||
returns the default ('document') since this is called during document generation.
|
||||
"""
|
||||
if not self._discovered:
|
||||
self.discoverRenderers()
|
||||
|
||||
# Normalize format name
|
||||
formatName = outputFormat.lower().strip()
|
||||
|
||||
# Check for aliases first
|
||||
if formatName in self._format_mappings:
|
||||
formatName = self._format_mappings[formatName]
|
||||
|
||||
# Get renderer class and call getOutputStyle (all renderers have same signature)
|
||||
rendererClass = self._renderers.get(formatName)
|
||||
try:
|
||||
return rendererClass.getOutputStyle(formatName)
|
||||
except (AttributeError, TypeError) as e:
|
||||
logger.warning(f"No renderer found for format: {outputFormat}, cannot determine output style")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning(f"Error getting output style for {outputFormat}: {str(e)}")
|
||||
return None
|
||||
# Check document first, then code
|
||||
for style in ['document', 'code']:
|
||||
rendererClass = self._renderers.get((formatName, style))
|
||||
if rendererClass:
|
||||
try:
|
||||
return rendererClass.getOutputStyle(formatName)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Fallback: any style
|
||||
for key, rendererClass in self._renderers.items():
|
||||
if key[0] == formatName:
|
||||
try:
|
||||
return rendererClass.getOutputStyle(formatName)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.warning(f"No renderer found for format: {outputFormat}, cannot determine output style")
|
||||
return None
|
||||
|
||||
|
||||
# Global registry instance
|
||||
_registry = RendererRegistry()
|
||||
|
||||
def getRenderer(outputFormat: str, services=None) -> Optional[BaseRenderer]:
|
||||
"""Get a renderer instance for the specified format."""
|
||||
return _registry.getRenderer(outputFormat, services)
|
||||
|
||||
def getRenderer(outputFormat: str, services=None, outputStyle: str = None) -> Optional[BaseRenderer]:
|
||||
"""Get a renderer instance for the specified format and style.
|
||||
|
||||
Args:
|
||||
outputFormat: Format name (e.g. 'csv', 'json', 'pdf')
|
||||
services: Services instance
|
||||
outputStyle: 'document' or 'code'. If None, prefers document renderer.
|
||||
"""
|
||||
return _registry.getRenderer(outputFormat, services, outputStyle=outputStyle)
|
||||
|
||||
|
||||
def getSupportedFormats() -> List[str]:
|
||||
"""Get list of all supported formats."""
|
||||
return _registry.getSupportedFormats()
|
||||
|
||||
|
||||
def getRendererInfo() -> Dict[str, Dict[str, str]]:
|
||||
"""Get information about all registered renderers."""
|
||||
return _registry.getRendererInfo()
|
||||
|
||||
|
||||
def getOutputStyle(outputFormat: str) -> Optional[str]:
|
||||
"""Get the output style classification for a given format."""
|
||||
return _registry.getOutputStyle(outputFormat)
|
||||
|
|
|
|||
|
|
@ -35,9 +35,9 @@ class RendererCsv(BaseRenderer):
|
|||
def getAcceptedSectionTypes(cls, formatName: Optional[str] = None) -> List[str]:
|
||||
"""
|
||||
Return list of section content types that CSV renderer accepts.
|
||||
CSV renderer only accepts table sections.
|
||||
CSV renderer accepts table sections and code_block sections (for raw CSV content).
|
||||
"""
|
||||
return ["table"]
|
||||
return ["table", "code_block"]
|
||||
|
||||
async def render(self, extractedContent: Dict[str, Any], title: str, userPrompt: str = None, aiService=None) -> List[RenderedDocument]:
|
||||
"""Render extracted JSON content to CSV format. Produces one CSV file per table section."""
|
||||
|
|
@ -62,16 +62,24 @@ class RendererCsv(BaseRenderer):
|
|||
if baseFilename.endswith('.csv'):
|
||||
baseFilename = baseFilename[:-4]
|
||||
|
||||
# Find all table sections
|
||||
# Collect CSV-producing sections: table sections AND code_block sections with CSV language
|
||||
tableSections = []
|
||||
codeBlockCsvSections = []
|
||||
for section in sections:
|
||||
sectionType = section.get("content_type", "paragraph")
|
||||
if sectionType == "table":
|
||||
tableSections.append(section)
|
||||
elif sectionType == "code_block":
|
||||
# Check if any element is a code_block with language "csv"
|
||||
for element in section.get("elements", []):
|
||||
content = element.get("content", {})
|
||||
if isinstance(content, dict) and content.get("language", "").lower() == "csv":
|
||||
codeBlockCsvSections.append(section)
|
||||
break
|
||||
|
||||
# If no table sections found, return empty CSV
|
||||
if not tableSections:
|
||||
self.logger.warning("No table sections found in CSV document - returning empty CSV")
|
||||
# If no usable sections found, return empty CSV
|
||||
if not tableSections and not codeBlockCsvSections:
|
||||
self.logger.warning("No table or CSV code_block sections found in CSV document - returning empty CSV")
|
||||
emptyCsv = self._convertRowsToCsv([["No table data available"]])
|
||||
return [
|
||||
RenderedDocument(
|
||||
|
|
@ -83,45 +91,52 @@ class RendererCsv(BaseRenderer):
|
|||
)
|
||||
]
|
||||
|
||||
# Generate one CSV file per table section
|
||||
allCsvSections = tableSections + codeBlockCsvSections
|
||||
|
||||
# Generate one CSV file per section
|
||||
renderedDocuments = []
|
||||
for i, tableSection in enumerate(tableSections):
|
||||
# Generate CSV content for this table section
|
||||
csvRows = []
|
||||
for i, csvSection in enumerate(allCsvSections):
|
||||
sectionType = csvSection.get("content_type", "paragraph")
|
||||
sectionTitle = csvSection.get("title")
|
||||
csvContent = ""
|
||||
|
||||
# Add section title if available
|
||||
sectionTitle = tableSection.get("title")
|
||||
if sectionTitle:
|
||||
csvRows.append([sectionTitle])
|
||||
csvRows.append([]) # Empty row after title
|
||||
if sectionType == "code_block":
|
||||
# Extract raw CSV content directly from code_block elements
|
||||
rawCsvParts = []
|
||||
for element in csvSection.get("elements", []):
|
||||
content = element.get("content", {})
|
||||
if isinstance(content, dict) and content.get("language", "").lower() == "csv":
|
||||
code = content.get("code", "")
|
||||
if code:
|
||||
rawCsvParts.append(code)
|
||||
csvContent = "\n".join(rawCsvParts)
|
||||
else:
|
||||
# Table section — render via table logic
|
||||
csvRows = []
|
||||
if sectionTitle:
|
||||
csvRows.append([sectionTitle])
|
||||
csvRows.append([]) # Empty row after title
|
||||
|
||||
# Render table from section elements
|
||||
elements = tableSection.get("elements", [])
|
||||
for element in elements:
|
||||
tableRows = self._renderJsonTableToCsv(element)
|
||||
if tableRows:
|
||||
csvRows.extend(tableRows)
|
||||
elements = csvSection.get("elements", [])
|
||||
for element in elements:
|
||||
tableRows = self._renderJsonTableToCsv(element)
|
||||
if tableRows:
|
||||
csvRows.extend(tableRows)
|
||||
|
||||
# Convert to CSV string
|
||||
csvContent = self._convertRowsToCsv(csvRows)
|
||||
csvContent = self._convertRowsToCsv(csvRows)
|
||||
|
||||
# Determine filename for this table
|
||||
if len(tableSections) == 1:
|
||||
# Single table - use base filename
|
||||
# Determine filename
|
||||
if len(allCsvSections) == 1:
|
||||
filename = f"{baseFilename}.csv"
|
||||
else:
|
||||
# Multiple tables - add index or section title to filename
|
||||
sectionId = tableSection.get("id", f"table_{i+1}")
|
||||
# Use section title if available, otherwise use section ID
|
||||
sectionId = csvSection.get("id", f"csv_{i+1}")
|
||||
if sectionTitle:
|
||||
# Sanitize section title for filename
|
||||
safeTitle = "".join(c for c in sectionTitle if c.isalnum() or c in (' ', '-', '_')).strip()
|
||||
safeTitle = safeTitle.replace(' ', '_')[:30] # Limit length
|
||||
safeTitle = safeTitle.replace(' ', '_')[:30]
|
||||
filename = f"{baseFilename}_{safeTitle}.csv"
|
||||
else:
|
||||
filename = f"{baseFilename}_{sectionId}.csv"
|
||||
|
||||
# Extract document type from metadata
|
||||
documentType = metadata.get("documentType") if isinstance(metadata, dict) else None
|
||||
|
||||
renderedDocuments.append(
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@ Features can register callbacks to be notified when automations change.
|
|||
|
||||
import logging
|
||||
from typing import Callable, List, Dict, Any
|
||||
import asyncio
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -25,7 +24,7 @@ class CallbackRegistry:
|
|||
|
||||
Args:
|
||||
event_type: Type of event (e.g., 'automation.changed')
|
||||
callback: Async or sync callback function
|
||||
callback: Sync callback function
|
||||
"""
|
||||
if event_type not in self._callbacks:
|
||||
self._callbacks[event_type] = []
|
||||
|
|
@ -41,8 +40,8 @@ class CallbackRegistry:
|
|||
except ValueError:
|
||||
logger.warning(f"Callback not found for event type: {event_type}")
|
||||
|
||||
async def trigger(self, event_type: str, *args, **kwargs):
|
||||
"""Trigger all callbacks registered for an event type.
|
||||
def trigger(self, event_type: str, *args, **kwargs):
|
||||
"""Trigger all registered callbacks for an event type.
|
||||
|
||||
Args:
|
||||
event_type: Type of event to trigger
|
||||
|
|
@ -55,18 +54,14 @@ class CallbackRegistry:
|
|||
|
||||
for callback in callbacks:
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(callback):
|
||||
await callback(*args, **kwargs)
|
||||
else:
|
||||
callback(*args, **kwargs)
|
||||
callback(*args, **kwargs)
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing callback for {event_type}: {str(e)}", exc_info=True)
|
||||
|
||||
def has_callbacks(self, event_type: str) -> bool:
|
||||
def hasCallbacks(self, event_type: str) -> bool:
|
||||
"""Check if there are any callbacks registered for an event type."""
|
||||
return event_type in self._callbacks and len(self._callbacks[event_type]) > 0
|
||||
|
||||
|
||||
# Global singleton instance
|
||||
callbackRegistry = CallbackRegistry()
|
||||
|
||||
|
|
|
|||
|
|
@ -38,16 +38,14 @@ async def chatStart(currentUser: User, userInput: UserInputRequest, workflowMode
|
|||
featureCode: Feature code (e.g., 'chatplayground', 'automation')
|
||||
"""
|
||||
try:
|
||||
services = getServices(currentUser, mandateId=mandateId)
|
||||
services = getServices(currentUser, mandateId=mandateId, featureInstanceId=featureInstanceId)
|
||||
|
||||
# Store allowedProviders in services context for model selection
|
||||
if hasattr(userInput, 'allowedProviders') and userInput.allowedProviders:
|
||||
services.allowedProviders = userInput.allowedProviders
|
||||
logger.info(f"AI provider filter active: {userInput.allowedProviders}")
|
||||
|
||||
# Store feature context in services (for billing and RBAC)
|
||||
if featureInstanceId:
|
||||
services.featureInstanceId = featureInstanceId
|
||||
# Store feature code in services (for billing)
|
||||
if featureCode:
|
||||
services.featureCode = featureCode
|
||||
|
||||
|
|
@ -61,10 +59,8 @@ async def chatStart(currentUser: User, userInput: UserInputRequest, workflowMode
|
|||
async def chatStop(currentUser: User, workflowId: str, mandateId: Optional[str] = None, featureInstanceId: Optional[str] = None) -> ChatWorkflow:
|
||||
"""Stops a running chat."""
|
||||
try:
|
||||
services = getServices(currentUser, mandateId=mandateId)
|
||||
# Store feature instance ID in services context for proper RBAC filtering
|
||||
services = getServices(currentUser, mandateId=mandateId, featureInstanceId=featureInstanceId)
|
||||
if featureInstanceId:
|
||||
services.featureInstanceId = featureInstanceId
|
||||
services.featureCode = 'chatplayground'
|
||||
workflowManager = WorkflowManager(services)
|
||||
return await workflowManager.workflowStop(workflowId)
|
||||
|
|
@ -73,12 +69,17 @@ async def chatStop(currentUser: User, workflowId: str, mandateId: Optional[str]
|
|||
raise
|
||||
|
||||
|
||||
async def executeAutomation(automationId: str, services) -> ChatWorkflow:
|
||||
"""Execute automation workflow immediately (test mode) with placeholder replacement.
|
||||
async def executeAutomation(automationId: str, automation, creatorUser: User, services) -> ChatWorkflow:
|
||||
"""Execute automation workflow with the creator user's context.
|
||||
|
||||
The automation object and creatorUser are resolved by the caller (handler)
|
||||
using the SysAdmin eventUser. This function does NOT re-load them.
|
||||
|
||||
Args:
|
||||
automationId: ID of automation to execute
|
||||
services: Services instance for data access
|
||||
automation: Pre-loaded automation object (with system fields like _createdBy)
|
||||
creatorUser: The user who created the automation (workflow runs in this context)
|
||||
services: Services instance (used for interfaceDbApp etc.)
|
||||
|
||||
Returns:
|
||||
ChatWorkflow instance created by automation execution
|
||||
|
|
@ -92,11 +93,6 @@ async def executeAutomation(automationId: str, services) -> ChatWorkflow:
|
|||
}
|
||||
|
||||
try:
|
||||
# 1. Load automation definition (with system fields for _createdBy access)
|
||||
automation = services.interfaceDbAutomation.getAutomationDefinition(automationId, includeSystemFields=True)
|
||||
if not automation:
|
||||
raise ValueError(f"Automation {automationId} not found")
|
||||
|
||||
executionLog["messages"].append(f"Started execution at {executionStartTime}")
|
||||
|
||||
# Store allowed providers from automation in services context
|
||||
|
|
@ -105,12 +101,12 @@ async def executeAutomation(automationId: str, services) -> ChatWorkflow:
|
|||
logger.debug(f"Automation {automationId} restricted to providers: {automation.allowedProviders}")
|
||||
|
||||
# Context comes EXCLUSIVELY from the automation definition
|
||||
services.mandateId = str(automation.mandateId)
|
||||
services.featureInstanceId = str(automation.featureInstanceId)
|
||||
services.featureCode = 'automation'
|
||||
featureInstanceId = services.featureInstanceId
|
||||
automationMandateId = str(automation.mandateId)
|
||||
automationFeatureInstanceId = str(automation.featureInstanceId)
|
||||
|
||||
# 2. Replace placeholders in template to generate plan
|
||||
logger.info(f"Executing automation {automationId} as user {creatorUser.id} with mandateId={automationMandateId}, featureInstanceId={automationFeatureInstanceId}")
|
||||
|
||||
# 1. Replace placeholders in template to generate plan
|
||||
template = automation.template or ""
|
||||
placeholders = automation.placeholders or {}
|
||||
planJson = replacePlaceholders(template, placeholders)
|
||||
|
|
@ -128,24 +124,9 @@ async def executeAutomation(automationId: str, services) -> ChatWorkflow:
|
|||
logger.error(f"Context around error: ...{planJson[start:end]}...")
|
||||
raise ValueError(f"Invalid JSON after placeholder replacement: {str(e)}")
|
||||
executionLog["messages"].append("Template placeholders replaced successfully")
|
||||
executionLog["messages"].append(f"Using creator user: {creatorUser.id}")
|
||||
|
||||
# 3. Get user who created automation
|
||||
creatorUserId = getattr(automation, "_createdBy", None)
|
||||
|
||||
# _createdBy is a system attribute - must be present
|
||||
if not creatorUserId:
|
||||
errorMsg = f"Automation {automationId} has no creator user (_createdBy field missing). Cannot execute automation."
|
||||
logger.error(errorMsg)
|
||||
executionLog["messages"].append(errorMsg)
|
||||
raise ValueError(errorMsg)
|
||||
|
||||
# Get creator user from database
|
||||
creatorUser = services.interfaceDbApp.getUser(creatorUserId)
|
||||
if not creatorUser:
|
||||
raise ValueError(f"Creator user {creatorUserId} not found")
|
||||
executionLog["messages"].append(f"Using creator user: {creatorUserId}")
|
||||
|
||||
# 4. Create UserInputRequest from plan
|
||||
# 2. Create UserInputRequest from plan
|
||||
# Embed plan JSON in prompt for TemplateMode to extract
|
||||
promptText = planToPrompt(plan)
|
||||
planJsonStr = json.dumps(plan)
|
||||
|
|
@ -160,16 +141,15 @@ async def executeAutomation(automationId: str, services) -> ChatWorkflow:
|
|||
|
||||
executionLog["messages"].append("Starting workflow execution")
|
||||
|
||||
# 5. Start workflow using chatStart
|
||||
# Pass mandateId, featureInstanceId, and featureCode from original services context
|
||||
# so billing is recorded correctly with full feature context
|
||||
# 3. Start workflow using chatStart with creator's context
|
||||
# mandateId and featureInstanceId come from the automation definition
|
||||
workflow = await chatStart(
|
||||
currentUser=creatorUser,
|
||||
userInput=userInput,
|
||||
workflowMode=WorkflowModeEnum.WORKFLOW_AUTOMATION,
|
||||
workflowId=None,
|
||||
mandateId=services.mandateId,
|
||||
featureInstanceId=featureInstanceId,
|
||||
mandateId=automationMandateId,
|
||||
featureInstanceId=automationFeatureInstanceId,
|
||||
featureCode='automation'
|
||||
)
|
||||
|
||||
|
|
@ -200,22 +180,22 @@ async def executeAutomation(automationId: str, services) -> ChatWorkflow:
|
|||
executionLog["messages"].append(f"Error: {str(e)}")
|
||||
|
||||
# Save execution log even on error (bypasses RBAC — system operation)
|
||||
# Use the automation object already passed in (no re-load needed)
|
||||
try:
|
||||
automation = services.interfaceDbAutomation.getAutomationDefinition(automationId)
|
||||
if automation:
|
||||
executionLogs = list(automation.executionLogs or [])
|
||||
executionLogs.append(executionLog)
|
||||
if len(executionLogs) > 50:
|
||||
executionLogs = executionLogs[-50:]
|
||||
services.interfaceDbAutomation._saveExecutionLog(automationId, executionLogs)
|
||||
executionLogs = list(getattr(automation, 'executionLogs', None) or [])
|
||||
executionLogs.append(executionLog)
|
||||
if len(executionLogs) > 50:
|
||||
executionLogs = executionLogs[-50:]
|
||||
services.interfaceDbAutomation._saveExecutionLog(automationId, executionLogs)
|
||||
except Exception as logError:
|
||||
logger.error(f"Error saving execution log: {str(logError)}")
|
||||
|
||||
raise
|
||||
|
||||
|
||||
async def syncAutomationEvents(services, eventUser) -> Dict[str, Any]:
|
||||
"""Automation event handler - syncs scheduler with all active automations.
|
||||
def syncAutomationEvents(services, eventUser) -> Dict[str, Any]:
|
||||
"""Sync scheduler with all active automations.
|
||||
All operations (DB reads, scheduler registration) are synchronous.
|
||||
|
||||
Args:
|
||||
services: Services instance for data access
|
||||
|
|
@ -316,37 +296,28 @@ def createAutomationEventHandler(automationId: str, eventUser):
|
|||
logger.error("Event user not available for automation execution")
|
||||
return
|
||||
|
||||
# Get services for event user (provides access to interfaces)
|
||||
# Load automation using SysAdmin eventUser (has unrestricted access)
|
||||
eventServices = getServices(eventUser, None)
|
||||
|
||||
# Load automation using event user context (with system fields for _createdBy access)
|
||||
automation = eventServices.interfaceDbAutomation.getAutomationDefinition(automationId, includeSystemFields=True)
|
||||
if not automation or not getattr(automation, "active", False):
|
||||
logger.warning(f"Automation {automationId} not found or not active, skipping execution")
|
||||
return
|
||||
|
||||
# Get creator user
|
||||
# Get creator user ID from automation's _createdBy system field
|
||||
creatorUserId = getattr(automation, "_createdBy", None)
|
||||
if not creatorUserId:
|
||||
logger.error(f"Automation {automationId} has no creator user")
|
||||
logger.error(f"Automation {automationId} has no creator user (_createdBy missing)")
|
||||
return
|
||||
|
||||
# Get mandate context from automation definition
|
||||
automationMandateId = getattr(automation, "mandateId", None)
|
||||
|
||||
# Get creator user from database using services
|
||||
eventServices = getServices(eventUser, None)
|
||||
# Get creator user from database (using SysAdmin access)
|
||||
creatorUser = eventServices.interfaceDbApp.getUser(creatorUserId)
|
||||
if not creatorUser:
|
||||
logger.error(f"Creator user {creatorUserId} not found for automation {automationId}")
|
||||
return
|
||||
|
||||
# Get services for creator user WITH mandate context from automation
|
||||
creatorServices = getServices(creatorUser, automationMandateId)
|
||||
|
||||
# Execute automation with creator user's context and mandate
|
||||
# executeAutomation is in same module, so we can call it directly
|
||||
await executeAutomation(automationId, creatorServices)
|
||||
# Execute automation — pass automation object and creatorUser directly
|
||||
# No re-load needed in executeAutomation
|
||||
await executeAutomation(automationId, automation, creatorUser, eventServices)
|
||||
logger.info(f"Successfully executed automation {automationId} as user {creatorUserId}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing automation {automationId}: {str(e)}")
|
||||
|
|
|
|||
|
|
@ -14,9 +14,10 @@ from modules.services import getInterface as getServices
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def start(eventUser) -> None:
|
||||
def start(eventUser) -> bool:
|
||||
"""
|
||||
Start automation scheduler and sync scheduled events.
|
||||
All operations are synchronous (DB access, scheduler registration).
|
||||
|
||||
Args:
|
||||
eventUser: System-level event user for background operations (provided by app.py)
|
||||
|
|
@ -33,16 +34,16 @@ async def start(eventUser) -> None:
|
|||
services = getServices(eventUser, None)
|
||||
|
||||
# Register callback for automation changes
|
||||
async def onAutomationChanged(chatInterface):
|
||||
def onAutomationChanged(chatInterface):
|
||||
"""Callback triggered when automations are created/updated/deleted."""
|
||||
eventServices = getServices(eventUser, None)
|
||||
await syncAutomationEvents(eventServices, eventUser)
|
||||
syncAutomationEvents(eventServices, eventUser)
|
||||
|
||||
callbackRegistry.register('automation.changed', onAutomationChanged)
|
||||
logger.info("Automation: Registered change callback")
|
||||
|
||||
# Initial sync on startup
|
||||
await syncAutomationEvents(services, eventUser)
|
||||
syncAutomationEvents(services, eventUser)
|
||||
logger.info("Automation: Scheduled events synced on startup")
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -52,7 +53,7 @@ async def start(eventUser) -> None:
|
|||
return True
|
||||
|
||||
|
||||
async def stop(eventUser) -> None:
|
||||
def stop(eventUser) -> bool:
|
||||
"""
|
||||
Stop automation scheduler.
|
||||
|
||||
|
|
|
|||
|
|
@ -139,11 +139,16 @@ class MethodBase:
|
|||
return False
|
||||
|
||||
# RBAC-Check: RESOURCE context, item = actionId
|
||||
# mandateId/featureInstanceId from services context needed to resolve user roles
|
||||
try:
|
||||
mandateId = getattr(self.services, 'mandateId', None)
|
||||
featureInstanceId = getattr(self.services, 'featureInstanceId', None)
|
||||
permissions = self.services.rbac.getUserPermissions(
|
||||
user=currentUser,
|
||||
context=AccessRuleContext.RESOURCE,
|
||||
item=actionId
|
||||
item=actionId,
|
||||
mandateId=str(mandateId) if mandateId else None,
|
||||
featureInstanceId=str(featureInstanceId) if featureInstanceId else None
|
||||
)
|
||||
hasPermission = permissions.view
|
||||
if not hasPermission:
|
||||
|
|
@ -151,8 +156,9 @@ class MethodBase:
|
|||
userRoles = getattr(currentUser, 'roleLabels', []) or []
|
||||
self.logger.warning(
|
||||
f"RBAC denied action {actionId} for user {currentUser.id}. "
|
||||
f"User roles: {userRoles}, "
|
||||
f"Permissions: view={permissions.view}, edit={permissions.edit}, delete={permissions.delete}. "
|
||||
f"User roles: {userRoles}, mandateId={mandateId}, "
|
||||
f"Permissions: view={permissions.view}, read={permissions.read}, "
|
||||
f"create={permissions.create}, update={permissions.update}, delete={permissions.delete}. "
|
||||
f"No matching RBAC rule found for context=RESOURCE, item={actionId}"
|
||||
)
|
||||
return hasPermission
|
||||
|
|
|
|||
Loading…
Reference in a new issue