logical fixes

This commit is contained in:
patrick-motsch 2026-02-09 23:44:52 +01:00
parent 7532841d9d
commit d98c31a4d1
29 changed files with 588 additions and 449 deletions

4
app.py
View file

@ -315,7 +315,7 @@ async def lifespan(app: FastAPI):
logger.warning(f"Could not initialize feature containers: {e}") logger.warning(f"Could not initialize feature containers: {e}")
# --- Init Managers --- # --- Init Managers ---
await subAutomationSchedule.start(eventUser) # Automation scheduler subAutomationSchedule.start(eventUser) # Automation scheduler
eventManager.start() eventManager.start()
# Register audit log cleanup scheduler # Register audit log cleanup scheduler
@ -345,7 +345,7 @@ async def lifespan(app: FastAPI):
# --- Stop Managers --- # --- Stop Managers ---
eventManager.stop() eventManager.stop()
await subAutomationSchedule.stop(eventUser) # Automation scheduler subAutomationSchedule.stop(eventUser) # Automation scheduler
# --- Stop Feature Containers (Plug&Play) --- # --- Stop Feature Containers (Plug&Play) ---
try: try:

View file

@ -3,7 +3,7 @@
"""File-related datamodels: FileItem, FilePreview, FileData.""" """File-related datamodels: FileItem, FilePreview, FileData."""
from typing import Dict, Any, Optional, Union from typing import Dict, Any, Optional, Union
from pydantic import BaseModel, Field from pydantic import BaseModel, ConfigDict, Field
from modules.shared.attributeUtils import registerModelLabels from modules.shared.attributeUtils import registerModelLabels
from modules.shared.timeUtils import getUtcTimestamp from modules.shared.timeUtils import getUtcTimestamp
import uuid import uuid
@ -11,6 +11,7 @@ import base64
class FileItem(BaseModel): class FileItem(BaseModel):
model_config = ConfigDict(extra='allow') # Preserve system fields (_createdBy, _createdAt, etc.)
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False}) id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
mandateId: Optional[str] = Field(default="", description="ID of the mandate this file belongs to", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False}) mandateId: Optional[str] = Field(default="", description="ID of the mandate this file belongs to", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
featureInstanceId: Optional[str] = Field(default="", description="ID of the feature instance this file belongs to", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False}) featureInstanceId: Optional[str] = Field(default="", description="ID of the feature instance this file belongs to", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})

View file

@ -3,22 +3,33 @@
"""Utility datamodels: Prompt, TextMultilingual.""" """Utility datamodels: Prompt, TextMultilingual."""
from typing import Dict, Optional from typing import Dict, Optional
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, ConfigDict, Field, field_validator
from modules.shared.attributeUtils import registerModelLabels from modules.shared.attributeUtils import registerModelLabels
import uuid import uuid
class Prompt(BaseModel): class Prompt(BaseModel):
model_config = ConfigDict(extra='allow') # Preserve system fields (_createdBy, _createdAt, etc.)
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False}) id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
mandateId: str = Field(default="", description="ID of the mandate this prompt belongs to", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False}) mandateId: str = Field(default="", description="ID of the mandate this prompt belongs to", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
isSystem: bool = Field(default=False, description="System prompt visible to all users (read-only for non-SysAdmin)", json_schema_extra={"frontend_type": "boolean", "frontend_readonly": True, "frontend_required": False})
content: str = Field(description="Content of the prompt", json_schema_extra={"frontend_type": "textarea", "frontend_readonly": False, "frontend_required": True}) content: str = Field(description="Content of the prompt", json_schema_extra={"frontend_type": "textarea", "frontend_readonly": False, "frontend_required": True})
name: str = Field(description="Name of the prompt", json_schema_extra={"frontend_type": "text", "frontend_readonly": False, "frontend_required": True}) name: str = Field(description="Name of the prompt", json_schema_extra={"frontend_type": "text", "frontend_readonly": False, "frontend_required": True})
@field_validator('isSystem', mode='before')
@classmethod
def _coerceIsSystem(cls, v):
"""Existing records may have isSystem=None (field didn't exist). Treat None as False."""
if v is None:
return False
return v
registerModelLabels( registerModelLabels(
"Prompt", "Prompt",
{"en": "Prompt", "fr": "Invite"}, {"en": "Prompt", "fr": "Invite"},
{ {
"id": {"en": "ID", "fr": "ID"}, "id": {"en": "ID", "fr": "ID"},
"mandateId": {"en": "Mandate ID", "fr": "ID du mandat"}, "mandateId": {"en": "Mandate ID", "fr": "ID du mandat"},
"isSystem": {"en": "System", "fr": "Système"},
"content": {"en": "Content", "fr": "Contenu"}, "content": {"en": "Content", "fr": "Contenu"},
"name": {"en": "Name", "fr": "Nom"}, "name": {"en": "Name", "fr": "Nom"},
}, },

View file

@ -8,7 +8,6 @@ Uses the PostgreSQL connector for data access with user/mandate filtering.
import logging import logging
import uuid import uuid
import math import math
import asyncio
from typing import Dict, Any, List, Optional, Union from typing import Dict, Any, List, Optional, Union
from modules.security.rbac import RbacClass from modules.security.rbac import RbacClass
@ -99,7 +98,7 @@ class AutomationObjects:
return True return True
elif accessLevel == AccessLevel.MY: elif accessLevel == AccessLevel.MY:
if recordId: if recordId:
record = self.db.getRecordset(model, {"id": recordId}) record = self.db.getRecordset(model, recordFilter={"id": recordId})
if record: if record:
return record[0].get("_createdBy") == self.userId return record[0].get("_createdBy") == self.userId
else: else:
@ -118,16 +117,17 @@ class AutomationObjects:
def _enrichAutomationsWithUserAndMandate(self, automations: List[Dict[str, Any]]) -> List[Dict[str, Any]]: def _enrichAutomationsWithUserAndMandate(self, automations: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
""" """
Batch enrich automations with user names and mandate names for display. Batch enrich automations with user names, mandate names and feature instance labels.
Uses direct DB lookup (no RBAC) because this is purely cosmetic enrichment Uses direct DB lookup (no RBAC) because this is purely cosmetic enrichment
the user already has RBAC-verified access to the automations themselves. the user already has RBAC-verified access to the automations themselves.
""" """
if not automations: if not automations:
return automations return automations
# Collect all unique user IDs and mandate IDs # Collect all unique IDs
userIds = set() userIds = set()
mandateIds = set() mandateIds = set()
featureInstanceIds = set()
for automation in automations: for automation in automations:
createdBy = automation.get("_createdBy") createdBy = automation.get("_createdBy")
@ -137,48 +137,63 @@ class AutomationObjects:
mandateId = automation.get("mandateId") mandateId = automation.get("mandateId")
if mandateId: if mandateId:
mandateIds.add(mandateId) mandateIds.add(mandateId)
featureInstanceId = automation.get("featureInstanceId")
if featureInstanceId:
featureInstanceIds.add(featureInstanceId)
# Use root DB connector for display-only lookups (no RBAC needed) # Use root DB connector for display-only lookups (no RBAC needed)
usersMap = {}
mandatesMap = {}
featureInstancesMap = {}
try: try:
from modules.datamodels.datamodelUam import UserInDB, Mandate from modules.datamodels.datamodelUam import UserInDB, Mandate
from modules.datamodels.datamodelFeatures import FeatureInstance
from modules.security.rootAccess import getRootDbAppConnector from modules.security.rootAccess import getRootDbAppConnector
dbAppConn = getRootDbAppConnector() dbAppConn = getRootDbAppConnector()
# Batch fetch user display names # Batch fetch user display names
usersMap = {}
if userIds: if userIds:
for userId in userIds: for userId in userIds:
users = dbAppConn.getRecordset(UserInDB, {"id": userId}) users = dbAppConn.getRecordset(UserInDB, recordFilter={"id": userId})
if users: if users:
user = users[0] user = users[0]
fullName = f"{user.get('firstName', '')} {user.get('lastName', '')}".strip() displayName = user.get("fullName") or user.get("username") or user.get("email") or None
usersMap[userId] = fullName or user.get("email") or user.get("username") or userId if displayName:
usersMap[userId] = displayName
# Batch fetch mandate display names # Batch fetch mandate display names
mandatesMap = {}
if mandateIds: if mandateIds:
for mandateId in mandateIds: for mandateId in mandateIds:
mandates = dbAppConn.getRecordset(Mandate, {"id": mandateId}) mandates = dbAppConn.getRecordset(Mandate, recordFilter={"id": mandateId})
if mandates: if mandates:
mandatesMap[mandateId] = mandates[0].get("name") or mandateId label = mandates[0].get("label") or mandates[0].get("name") or None
if label:
mandatesMap[mandateId] = label
# Batch fetch feature instance labels
if featureInstanceIds:
for fiId in featureInstanceIds:
instances = dbAppConn.getRecordset(FeatureInstance, recordFilter={"id": fiId})
if instances:
fi = instances[0]
label = fi.get("label") or fi.get("featureCode") or None
if label:
featureInstancesMap[fiId] = label
except Exception as e: except Exception as e:
logger.warning(f"Could not enrich automations with user/mandate names: {e}") logger.warning(f"Could not enrich automations with display names: {e}")
usersMap = {}
mandatesMap = {}
# Enrich each automation with the fetched data # Enrich each automation with the fetched data
# SECURITY: Never show a fallback name — if lookup fails, show empty string
for automation in automations: for automation in automations:
createdBy = automation.get("_createdBy") createdBy = automation.get("_createdBy")
if createdBy: automation["_createdByUserName"] = usersMap.get(createdBy, "") if createdBy else ""
automation["_createdByUserName"] = usersMap.get(createdBy, createdBy)
else:
automation["_createdByUserName"] = "-"
mandateId = automation.get("mandateId") mandateId = automation.get("mandateId")
if mandateId: automation["mandateName"] = mandatesMap.get(mandateId, "") if mandateId else ""
automation["mandateName"] = mandatesMap.get(mandateId, mandateId)
else: featureInstanceId = automation.get("featureInstanceId")
automation["mandateName"] = "-" automation["featureInstanceName"] = featureInstancesMap.get(featureInstanceId, "") if featureInstanceId else ""
return automations return automations
@ -195,11 +210,13 @@ class AutomationObjects:
Supports optional pagination, sorting, and filtering. Supports optional pagination, sorting, and filtering.
Computes status field for each automation. Computes status field for each automation.
""" """
# Use RBAC filtering # AutomationDefinitions can belong to any feature instance within a mandate.
# Filter by mandateId only — not by featureInstanceId — to show all definitions across features.
filteredAutomations = getRecordsetWithRBAC( filteredAutomations = getRecordsetWithRBAC(
self.db, self.db,
AutomationDefinition, AutomationDefinition,
self.currentUser self.currentUser,
mandateId=self.mandateId
) )
# Compute status for each automation and normalize executionLogs # Compute status for each automation and normalize executionLogs
@ -282,12 +299,14 @@ class AutomationObjects:
If False (default), returns Pydantic model without system fields. If False (default), returns Pydantic model without system fields.
""" """
try: try:
# Use RBAC filtering # AutomationDefinitions can belong to any feature instance within a mandate.
# Filter by mandateId only — not by featureInstanceId.
filtered = getRecordsetWithRBAC( filtered = getRecordsetWithRBAC(
self.db, self.db,
AutomationDefinition, AutomationDefinition,
self.currentUser, self.currentUser,
recordFilter={"id": automationId} recordFilter={"id": automationId},
mandateId=self.mandateId
) )
if not filtered: if not filtered:
@ -363,8 +382,8 @@ class AutomationObjects:
if createdAutomation.get("executionLogs") is None: if createdAutomation.get("executionLogs") is None:
createdAutomation["executionLogs"] = [] createdAutomation["executionLogs"] = []
# Trigger automation change callback (async, don't wait) # Trigger automation change callback
asyncio.create_task(self._notifyAutomationChanged()) self._notifyAutomationChanged()
# Clean metadata fields and return Pydantic model # Clean metadata fields and return Pydantic model
cleanedRecord = {k: v for k, v in createdAutomation.items() if not k.startswith("_")} cleanedRecord = {k: v for k, v in createdAutomation.items() if not k.startswith("_")}
@ -408,8 +427,8 @@ class AutomationObjects:
if updatedAutomation.get("executionLogs") is None: if updatedAutomation.get("executionLogs") is None:
updatedAutomation["executionLogs"] = [] updatedAutomation["executionLogs"] = []
# Trigger automation change callback (async, don't wait) # Trigger automation change callback
asyncio.create_task(self._notifyAutomationChanged()) self._notifyAutomationChanged()
# Clean metadata fields and return Pydantic model # Clean metadata fields and return Pydantic model
cleanedRecord = {k: v for k, v in updatedAutomation.items() if not k.startswith("_")} cleanedRecord = {k: v for k, v in updatedAutomation.items() if not k.startswith("_")}
@ -432,8 +451,8 @@ class AutomationObjects:
# Delete automation from database # Delete automation from database
self.db.recordDelete(AutomationDefinition, automationId) self.db.recordDelete(AutomationDefinition, automationId)
# Trigger automation change callback (async, don't wait) # Trigger automation change callback
asyncio.create_task(self._notifyAutomationChanged()) self._notifyAutomationChanged()
return True return True
except Exception as e: except Exception as e:
@ -454,7 +473,9 @@ class AutomationObjects:
return getRecordsetWithRBAC( return getRecordsetWithRBAC(
self.db, self.db,
AutomationDefinition, AutomationDefinition,
user user,
mandateId=self.mandateId,
featureInstanceId=self.featureInstanceId
) )
# ========================================================================= # =========================================================================
@ -466,7 +487,7 @@ class AutomationObjects:
Returns automation templates filtered by RBAC (MY = own templates). Returns automation templates filtered by RBAC (MY = own templates).
Supports optional pagination, sorting, and filtering. Supports optional pagination, sorting, and filtering.
""" """
# Use RBAC filtering # Templates are global (not mandate/feature-instance scoped) — no mandateId/featureInstanceId filter
filteredTemplates = getRecordsetWithRBAC( filteredTemplates = getRecordsetWithRBAC(
self.db, self.db,
AutomationTemplate, AutomationTemplate,
@ -526,23 +547,24 @@ class AutomationObjects:
userNameMap = {} userNameMap = {}
for userId in userIds: for userId in userIds:
users = dbAppConn.getRecordset(UserInDB, {"id": userId}) users = dbAppConn.getRecordset(UserInDB, recordFilter={"id": userId})
if users: if users:
user = users[0] user = users[0]
fullName = f"{user.get('firstName', '')} {user.get('lastName', '')}".strip() displayName = user.get("fullName") or user.get("username") or user.get("email") or None
userNameMap[userId] = fullName or user.get("email", "Unknown") if displayName:
userNameMap[userId] = displayName
# Apply to templates # Apply to templates — SECURITY: no fallback, empty if not found
for template in templates: for template in templates:
createdBy = template.get("_createdBy") createdBy = template.get("_createdBy")
if createdBy and createdBy in userNameMap: template["_createdByUserName"] = userNameMap.get(createdBy, "") if createdBy else ""
template["_createdByUserName"] = userNameMap[createdBy]
except Exception as e: except Exception as e:
logger.warning(f"Could not enrich templates with user names: {e}") logger.warning(f"Could not enrich templates with user names: {e}")
def getAutomationTemplate(self, templateId: str) -> Optional[Dict[str, Any]]: def getAutomationTemplate(self, templateId: str) -> Optional[Dict[str, Any]]:
"""Returns an automation template by ID if user has access.""" """Returns an automation template by ID if user has access."""
try: try:
# Templates are global — no mandateId/featureInstanceId filter
filtered = getRecordsetWithRBAC( filtered = getRecordsetWithRBAC(
self.db, self.db,
AutomationTemplate, AutomationTemplate,
@ -645,12 +667,13 @@ class AutomationObjects:
logger.error(f"Error deleting automation template: {str(e)}") logger.error(f"Error deleting automation template: {str(e)}")
raise raise
async def _notifyAutomationChanged(self): def _notifyAutomationChanged(self):
"""Notify registered callbacks about automation changes (decoupled from features).""" """Notify registered callbacks about automation changes (decoupled from features).
Sync-safe: works from both sync and async contexts."""
try: try:
from modules.shared.callbackRegistry import callbackRegistry from modules.shared.callbackRegistry import callbackRegistry
# Trigger callbacks without knowing which features are listening # Trigger callbacks without knowing which features are listening
await callbackRegistry.trigger('automation.changed', self) callbackRegistry.trigger('automation.changed', self)
except Exception as e: except Exception as e:
logger.error(f"Error notifying automation change: {str(e)}") logger.error(f"Error notifying automation change: {str(e)}")

View file

@ -66,7 +66,9 @@ def get_automations(
detail=f"Invalid pagination parameter: {str(e)}" detail=f"Invalid pagination parameter: {str(e)}"
) )
chatInterface = getAutomationInterface(context.user, mandateId=str(context.mandateId) if context.mandateId else None, featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None) # AutomationDefinitions can belong to ANY feature instance within a mandate.
# The list endpoint must show all definitions for the user's mandate, not filter by a specific featureInstanceId.
chatInterface = getAutomationInterface(context.user, mandateId=str(context.mandateId) if context.mandateId else None)
result = chatInterface.getAllAutomationDefinitions(pagination=paginationParams) result = chatInterface.getAllAutomationDefinitions(pagination=paginationParams)
# If pagination was requested, result is PaginatedResult # If pagination was requested, result is PaginatedResult
@ -150,7 +152,7 @@ def get_available_actions(
# Ensure methods are discovered (need a service center for discovery) # Ensure methods are discovered (need a service center for discovery)
if not methods: if not methods:
# Create a lightweight service center for method discovery # Create a lightweight service center for method discovery
services = getServices(context.user, context.mandateId) services = getServices(context.user, mandateId=context.mandateId)
discoverMethods(services) discoverMethods(services)
actionsList = [] actionsList = []
@ -235,7 +237,7 @@ def get_automation(
) -> AutomationDefinition: ) -> AutomationDefinition:
"""Get a single automation definition by ID""" """Get a single automation definition by ID"""
try: try:
chatInterface = getAutomationInterface(context.user, mandateId=str(context.mandateId) if context.mandateId else None, featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None) chatInterface = getAutomationInterface(context.user, mandateId=str(context.mandateId) if context.mandateId else None)
automation = chatInterface.getAutomationDefinition(automationId) automation = chatInterface.getAutomationDefinition(automationId)
if not automation: if not automation:
raise HTTPException( raise HTTPException(
@ -263,7 +265,7 @@ def update_automation(
) -> AutomationDefinition: ) -> AutomationDefinition:
"""Update an automation definition""" """Update an automation definition"""
try: try:
chatInterface = getAutomationInterface(context.user, mandateId=str(context.mandateId) if context.mandateId else None, featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None) chatInterface = getAutomationInterface(context.user, mandateId=str(context.mandateId) if context.mandateId else None)
automationData = automation.model_dump() automationData = automation.model_dump()
updated = chatInterface.updateAutomationDefinition(automationId, automationData) updated = chatInterface.updateAutomationDefinition(automationId, automationData)
return updated return updated
@ -291,7 +293,7 @@ def update_automation_status(
) -> AutomationDefinition: ) -> AutomationDefinition:
"""Update only the active status of an automation definition""" """Update only the active status of an automation definition"""
try: try:
chatInterface = getAutomationInterface(context.user, mandateId=str(context.mandateId) if context.mandateId else None, featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None) chatInterface = getAutomationInterface(context.user, mandateId=str(context.mandateId) if context.mandateId else None)
# Get existing automation # Get existing automation
automation = chatInterface.getAutomationDefinition(automationId) automation = chatInterface.getAutomationDefinition(automationId)
@ -331,7 +333,7 @@ def delete_automation(
) -> Response: ) -> Response:
"""Delete an automation definition""" """Delete an automation definition"""
try: try:
chatInterface = getAutomationInterface(context.user, mandateId=str(context.mandateId) if context.mandateId else None, featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None) chatInterface = getAutomationInterface(context.user, mandateId=str(context.mandateId) if context.mandateId else None)
success = chatInterface.deleteAutomationDefinition(automationId) success = chatInterface.deleteAutomationDefinition(automationId)
if success: if success:
return Response(status_code=204) return Response(status_code=204)
@ -364,13 +366,15 @@ async def execute_automation_route(
"""Execute an automation immediately (test mode)""" """Execute an automation immediately (test mode)"""
try: try:
from modules.services import getInterface as getServices from modules.services import getInterface as getServices
services = getServices(context.user, context.mandateId) services = getServices(context.user, mandateId=context.mandateId, featureInstanceId=context.featureInstanceId)
# Propagate feature context for billing
if context.featureInstanceId: # Load automation with current user's context (user has RBAC permissions via UI)
services.featureInstanceId = str(context.featureInstanceId) automation = services.interfaceDbAutomation.getAutomationDefinition(automationId, includeSystemFields=True)
services.featureCode = 'automation' if not automation:
raise ValueError(f"Automation {automationId} not found")
from modules.workflows.automation import executeAutomation from modules.workflows.automation import executeAutomation
workflow = await executeAutomation(automationId, services) workflow = await executeAutomation(automationId, automation, context.user, services)
return workflow return workflow
except HTTPException: except HTTPException:
raise raise

View file

@ -360,10 +360,12 @@ class ChatObjects:
return False return False
tableName = modelClass.__name__ tableName = modelClass.__name__
from modules.interfaces.interfaceRbac import buildDataObjectKey
objectKey = buildDataObjectKey(tableName, featureCode=self.featureCode if hasattr(self, 'featureCode') else None)
permissions = self.rbac.getUserPermissions( permissions = self.rbac.getUserPermissions(
self.currentUser, self.currentUser,
AccessRuleContext.DATA, AccessRuleContext.DATA,
tableName, objectKey,
mandateId=self.mandateId, mandateId=self.mandateId,
featureInstanceId=self.featureInstanceId featureInstanceId=self.featureInstanceId
) )

View file

@ -91,12 +91,8 @@ async def chatProcess(
ChatWorkflow instance ChatWorkflow instance
""" """
try: try:
# Get services with mandate context # Get services with mandate and feature instance context
services = getServices(currentUser, mandateId) services = getServices(currentUser, mandateId=mandateId, featureInstanceId=featureInstanceId)
# Set feature context for billing
if featureInstanceId:
services.featureInstanceId = featureInstanceId
services.featureCode = 'chatbot' services.featureCode = 'chatbot'
interfaceDbChat = services.interfaceDbChat interfaceDbChat = services.interfaceDbChat

View file

@ -49,10 +49,10 @@ RESOURCE_OBJECTS = [
] ]
# Template roles for this feature # Template roles for this feature
# IMPORTANT: "viewer" role is required for automatic user assignment! # Role names MUST follow convention: {featureCode}-{roleName}
TEMPLATE_ROLES = [ TEMPLATE_ROLES = [
{ {
"roleLabel": "viewer", "roleLabel": "chatplayground-viewer",
"description": { "description": {
"en": "Chat Playground Viewer - View chat playground (read-only)", "en": "Chat Playground Viewer - View chat playground (read-only)",
"de": "Chat Playground Betrachter - Chat Playground ansehen (nur lesen)", "de": "Chat Playground Betrachter - Chat Playground ansehen (nur lesen)",
@ -67,7 +67,7 @@ TEMPLATE_ROLES = [
] ]
}, },
{ {
"roleLabel": "user", "roleLabel": "chatplayground-user",
"description": { "description": {
"en": "Chat Playground User - Use chat playground and workflows", "en": "Chat Playground User - Use chat playground and workflows",
"de": "Chat Playground Benutzer - Chat Playground und Workflows nutzen", "de": "Chat Playground Benutzer - Chat Playground und Workflows nutzen",
@ -86,7 +86,7 @@ TEMPLATE_ROLES = [
] ]
}, },
{ {
"roleLabel": "admin", "roleLabel": "chatplayground-admin",
"description": { "description": {
"en": "Chat Playground Admin - Full access to chat playground", "en": "Chat Playground Admin - Full access to chat playground",
"de": "Chat Playground Admin - Vollzugriff auf Chat Playground", "de": "Chat Playground Admin - Vollzugriff auf Chat Playground",

View file

@ -749,10 +749,12 @@ class RealEstateObjects:
return False return False
tableName = modelClass.__name__ tableName = modelClass.__name__
from modules.interfaces.interfaceRbac import buildDataObjectKey
objectKey = buildDataObjectKey(tableName, featureCode=self.featureCode if hasattr(self, 'featureCode') else None)
permissions = self.rbac.getUserPermissions( permissions = self.rbac.getUserPermissions(
self.currentUser, self.currentUser,
AccessRuleContext.DATA, AccessRuleContext.DATA,
tableName, objectKey,
mandateId=self.mandateId, mandateId=self.mandateId,
featureInstanceId=self.featureInstanceId featureInstanceId=self.featureInstanceId
) )

View file

@ -171,10 +171,12 @@ class TrusteeObjects:
return False return False
tableName = modelClass.__name__ tableName = modelClass.__name__
from modules.interfaces.interfaceRbac import buildDataObjectKey
objectKey = buildDataObjectKey(tableName, featureCode=self.featureCode if hasattr(self, 'featureCode') else None)
permissions = self.rbac.getUserPermissions( permissions = self.rbac.getUserPermissions(
self.currentUser, self.currentUser,
AccessRuleContext.DATA, AccessRuleContext.DATA,
tableName, objectKey,
mandateId=self.mandateId, mandateId=self.mandateId,
featureInstanceId=self.featureInstanceId featureInstanceId=self.featureInstanceId
) )
@ -198,10 +200,12 @@ class TrusteeObjects:
return AccessLevel.NONE return AccessLevel.NONE
tableName = modelClass.__name__ tableName = modelClass.__name__
from modules.interfaces.interfaceRbac import buildDataObjectKey
objectKey = buildDataObjectKey(tableName, featureCode=self.featureCode if hasattr(self, 'featureCode') else None)
permissions = self.rbac.getUserPermissions( permissions = self.rbac.getUserPermissions(
self.currentUser, self.currentUser,
AccessRuleContext.DATA, AccessRuleContext.DATA,
tableName, objectKey,
mandateId=self.mandateId, mandateId=self.mandateId,
featureInstanceId=self.featureInstanceId featureInstanceId=self.featureInstanceId
) )
@ -1470,7 +1474,7 @@ class TrusteeObjects:
def getAllUserAccess(self, userId: str) -> List[Dict[str, Any]]: def getAllUserAccess(self, userId: str) -> List[Dict[str, Any]]:
"""Get all access records for a user across all organisations.""" """Get all access records for a user across all organisations."""
return self.db.getRecordset(TrusteeAccess, {"userId": userId}) return self.db.getRecordset(TrusteeAccess, recordFilter={"userId": userId})
def getUserTrusteeRoles(self, userId: str, organisationId: str, contractId: Optional[str] = None) -> List[str]: def getUserTrusteeRoles(self, userId: str, organisationId: str, contractId: Optional[str] = None) -> List[str]:
""" """

View file

@ -129,7 +129,7 @@ def initAutomationTemplates(dbApp: DatabaseConnector, adminUserId: Optional[str]
# Get admin user ID if not provided (from poweron_app) # Get admin user ID if not provided (from poweron_app)
if not adminUserId: if not adminUserId:
adminUsers = dbApp.getRecordset(UserInDB, {"email": APP_CONFIG.ADMIN_EMAIL}) adminUsers = dbApp.getRecordset(UserInDB, recordFilter={"email": APP_CONFIG.ADMIN_EMAIL})
adminUserId = adminUsers[0]["id"] if adminUsers else None adminUserId = adminUsers[0]["id"] if adminUsers else None
# Update context with admin user # Update context with admin user
if adminUserId: if adminUserId:

View file

@ -245,10 +245,13 @@ class AppObjects:
return False return False
tableName = modelClass.__name__ tableName = modelClass.__name__
# Use buildDataObjectKey for semantic namespace lookup
from modules.interfaces.interfaceRbac import buildDataObjectKey
objectKey = buildDataObjectKey(tableName)
permissions = self.rbac.getUserPermissions( permissions = self.rbac.getUserPermissions(
self.currentUser, self.currentUser,
AccessRuleContext.DATA, AccessRuleContext.DATA,
tableName, objectKey,
mandateId=self.mandateId mandateId=self.mandateId
) )

View file

@ -339,6 +339,18 @@ class ChatObjects:
pass pass
def _getRecordset(self, modelClass, recordFilter=None, **kwargs):
"""Wrapper for getRecordsetWithRBAC that automatically includes mandateId/featureInstanceId."""
return getRecordsetWithRBAC(
self.db,
modelClass,
self.currentUser,
recordFilter=recordFilter,
mandateId=self.mandateId,
featureInstanceId=self.featureInstanceId,
**kwargs
)
def checkRbacPermission( def checkRbacPermission(
self, self,
modelClass: type, modelClass: type,
@ -610,12 +622,7 @@ class ChatObjects:
If pagination is provided: PaginatedResult with items and metadata If pagination is provided: PaginatedResult with items and metadata
""" """
# Use RBAC filtering with featureInstanceId for instance-level isolation # Use RBAC filtering with featureInstanceId for instance-level isolation
filteredWorkflows = getRecordsetWithRBAC(self.db, filteredWorkflows = self._getRecordset(ChatWorkflow)
ChatWorkflow,
self.currentUser,
mandateId=self.mandateId,
featureInstanceId=self.featureInstanceId
)
# If no pagination requested, return all items (no sorting - frontend handles it) # If no pagination requested, return all items (no sorting - frontend handles it)
if pagination is None: if pagination is None:
@ -647,13 +654,7 @@ class ChatObjects:
def getWorkflow(self, workflowId: str) -> Optional[ChatWorkflow]: def getWorkflow(self, workflowId: str) -> Optional[ChatWorkflow]:
"""Returns a workflow by ID if user has access.""" """Returns a workflow by ID if user has access."""
# Use RBAC filtering with featureInstanceId for instance-level isolation # Use RBAC filtering with featureInstanceId for instance-level isolation
workflows = getRecordsetWithRBAC(self.db, workflows = self._getRecordset(ChatWorkflow, recordFilter={"id": workflowId})
ChatWorkflow,
self.currentUser,
recordFilter={"id": workflowId},
mandateId=self.mandateId,
featureInstanceId=self.featureInstanceId
)
if not workflows: if not workflows:
return None return None
@ -809,7 +810,7 @@ class ChatObjects:
# Delete message documents (but NOT the files!) # Delete message documents (but NOT the files!)
# Note: ChatStat does NOT have messageId - stats are only at workflow level # Note: ChatStat does NOT have messageId - stats are only at workflow level
try: try:
existing_docs = getRecordsetWithRBAC(self.db, ChatDocument, self.currentUser, recordFilter={"messageId": messageId}) existing_docs = self._getRecordset(ChatDocument, recordFilter={"messageId": messageId})
for doc in existing_docs: for doc in existing_docs:
self.db.recordDelete(ChatDocument, doc["id"]) self.db.recordDelete(ChatDocument, doc["id"])
except Exception as e: except Exception as e:
@ -819,12 +820,12 @@ class ChatObjects:
self.db.recordDelete(ChatMessage, messageId) self.db.recordDelete(ChatMessage, messageId)
# 2. Delete workflow stats # 2. Delete workflow stats
existing_stats = getRecordsetWithRBAC(self.db, ChatStat, self.currentUser, recordFilter={"workflowId": workflowId}) existing_stats = self._getRecordset(ChatStat, recordFilter={"workflowId": workflowId})
for stat in existing_stats: for stat in existing_stats:
self.db.recordDelete(ChatStat, stat["id"]) self.db.recordDelete(ChatStat, stat["id"])
# 3. Delete workflow logs # 3. Delete workflow logs
existing_logs = getRecordsetWithRBAC(self.db, ChatLog, self.currentUser, recordFilter={"workflowId": workflowId}) existing_logs = self._getRecordset(ChatLog, recordFilter={"workflowId": workflowId})
for log in existing_logs: for log in existing_logs:
self.db.recordDelete(ChatLog, log["id"]) self.db.recordDelete(ChatLog, log["id"])
@ -855,11 +856,7 @@ class ChatObjects:
""" """
# Check workflow access first (without calling getWorkflow to avoid circular reference) # Check workflow access first (without calling getWorkflow to avoid circular reference)
# Use RBAC filtering # Use RBAC filtering
workflows = getRecordsetWithRBAC(self.db, workflows = self._getRecordset(ChatWorkflow, recordFilter={"id": workflowId})
ChatWorkflow,
self.currentUser,
recordFilter={"id": workflowId}
)
if not workflows: if not workflows:
if pagination is None: if pagination is None:
@ -867,7 +864,7 @@ class ChatObjects:
return PaginatedResult(items=[], totalItems=0, totalPages=0) return PaginatedResult(items=[], totalItems=0, totalPages=0)
# Get messages for this workflow from normalized table # Get messages for this workflow from normalized table
messages = getRecordsetWithRBAC(self.db, ChatMessage, self.currentUser, recordFilter={"workflowId": workflowId}) messages = self._getRecordset(ChatMessage, recordFilter={"workflowId": workflowId})
# Convert raw messages to dict format for sorting/filtering # Convert raw messages to dict format for sorting/filtering
messageDicts = [] messageDicts = []
@ -1143,7 +1140,7 @@ class ChatObjects:
raise ValueError("messageId cannot be empty") raise ValueError("messageId cannot be empty")
# Check if message exists in database # Check if message exists in database
messages = getRecordsetWithRBAC(self.db, ChatMessage, self.currentUser, recordFilter={"id": messageId}) messages = self._getRecordset(ChatMessage, recordFilter={"id": messageId})
if not messages: if not messages:
logger.warning(f"Message with ID {messageId} does not exist in database") logger.warning(f"Message with ID {messageId} does not exist in database")
@ -1250,12 +1247,12 @@ class ChatObjects:
# CASCADE DELETE: Delete all related data first # CASCADE DELETE: Delete all related data first
# 1. Delete message stats # 1. Delete message stats
existing_stats = getRecordsetWithRBAC(self.db, ChatStat, self.currentUser, recordFilter={"messageId": messageId}) existing_stats = self._getRecordset(ChatStat, recordFilter={"messageId": messageId})
for stat in existing_stats: for stat in existing_stats:
self.db.recordDelete(ChatStat, stat["id"]) self.db.recordDelete(ChatStat, stat["id"])
# 2. Delete message documents (but NOT the files!) # 2. Delete message documents (but NOT the files!)
existing_docs = getRecordsetWithRBAC(self.db, ChatDocument, self.currentUser, recordFilter={"messageId": messageId}) existing_docs = self._getRecordset(ChatDocument, recordFilter={"messageId": messageId})
for doc in existing_docs: for doc in existing_docs:
self.db.recordDelete(ChatDocument, doc["id"]) self.db.recordDelete(ChatDocument, doc["id"])
@ -1282,7 +1279,7 @@ class ChatObjects:
# Get documents for this message from normalized table # Get documents for this message from normalized table
documents = getRecordsetWithRBAC(self.db, ChatDocument, self.currentUser, recordFilter={"messageId": messageId}) documents = self._getRecordset(ChatDocument, recordFilter={"messageId": messageId})
if not documents: if not documents:
logger.warning(f"No documents found for message {messageId}") logger.warning(f"No documents found for message {messageId}")
@ -1323,7 +1320,7 @@ class ChatObjects:
def getDocuments(self, messageId: str) -> List[ChatDocument]: def getDocuments(self, messageId: str) -> List[ChatDocument]:
"""Returns documents for a message from normalized table.""" """Returns documents for a message from normalized table."""
try: try:
documents = getRecordsetWithRBAC(self.db, ChatDocument, self.currentUser, recordFilter={"messageId": messageId}) documents = self._getRecordset(ChatDocument, recordFilter={"messageId": messageId})
return [ChatDocument(**doc) for doc in documents] return [ChatDocument(**doc) for doc in documents]
except Exception as e: except Exception as e:
logger.error(f"Error getting message documents: {str(e)}") logger.error(f"Error getting message documents: {str(e)}")
@ -1369,11 +1366,7 @@ class ChatObjects:
""" """
# Check workflow access first (without calling getWorkflow to avoid circular reference) # Check workflow access first (without calling getWorkflow to avoid circular reference)
# Use RBAC filtering # Use RBAC filtering
workflows = getRecordsetWithRBAC(self.db, workflows = self._getRecordset(ChatWorkflow, recordFilter={"id": workflowId})
ChatWorkflow,
self.currentUser,
recordFilter={"id": workflowId}
)
if not workflows: if not workflows:
if pagination is None: if pagination is None:
@ -1381,7 +1374,7 @@ class ChatObjects:
return PaginatedResult(items=[], totalItems=0, totalPages=0) return PaginatedResult(items=[], totalItems=0, totalPages=0)
# Get logs for this workflow from normalized table # Get logs for this workflow from normalized table
logs = getRecordsetWithRBAC(self.db, ChatLog, self.currentUser, recordFilter={"workflowId": workflowId}) logs = self._getRecordset(ChatLog, recordFilter={"workflowId": workflowId})
# Convert raw logs to dict format for sorting/filtering # Convert raw logs to dict format for sorting/filtering
logDicts = [] logDicts = []
@ -1513,17 +1506,13 @@ class ChatObjects:
"""Returns list of statistics for a workflow if user has access.""" """Returns list of statistics for a workflow if user has access."""
# Check workflow access first (without calling getWorkflow to avoid circular reference) # Check workflow access first (without calling getWorkflow to avoid circular reference)
# Use RBAC filtering # Use RBAC filtering
workflows = getRecordsetWithRBAC(self.db, workflows = self._getRecordset(ChatWorkflow, recordFilter={"id": workflowId})
ChatWorkflow,
self.currentUser,
recordFilter={"id": workflowId}
)
if not workflows: if not workflows:
return [] return []
# Get stats for this workflow from normalized table # Get stats for this workflow from normalized table
stats = getRecordsetWithRBAC(self.db, ChatStat, self.currentUser, recordFilter={"workflowId": workflowId}) stats = self._getRecordset(ChatStat, recordFilter={"workflowId": workflowId})
if not stats: if not stats:
return [] return []
@ -1581,11 +1570,7 @@ class ChatObjects:
""" """
# Check workflow access first # Check workflow access first
# Use RBAC filtering # Use RBAC filtering
workflows = getRecordsetWithRBAC(self.db, workflows = self._getRecordset(ChatWorkflow, recordFilter={"id": workflowId})
ChatWorkflow,
self.currentUser,
recordFilter={"id": workflowId}
)
if not workflows: if not workflows:
return {"items": []} return {"items": []}
@ -1594,7 +1579,7 @@ class ChatObjects:
items = [] items = []
# Get messages # Get messages
messages = getRecordsetWithRBAC(self.db, ChatMessage, self.currentUser, recordFilter={"workflowId": workflowId}) messages = self._getRecordset(ChatMessage, recordFilter={"workflowId": workflowId})
for msg in messages: for msg in messages:
# Apply timestamp filtering in Python # Apply timestamp filtering in Python
msgTimestamp = parseTimestamp(msg.get("publishedAt"), default=getUtcTimestamp()) msgTimestamp = parseTimestamp(msg.get("publishedAt"), default=getUtcTimestamp())
@ -1635,7 +1620,7 @@ class ChatObjects:
}) })
# Get logs - return all logs with roundNumber if available # Get logs - return all logs with roundNumber if available
logs = getRecordsetWithRBAC(self.db, ChatLog, self.currentUser, recordFilter={"workflowId": workflowId}) logs = self._getRecordset(ChatLog, recordFilter={"workflowId": workflowId})
for log in logs: for log in logs:
# Apply timestamp filtering in Python # Apply timestamp filtering in Python
logTimestamp = parseTimestamp(log.get("timestamp"), default=getUtcTimestamp()) logTimestamp = parseTimestamp(log.get("timestamp"), default=getUtcTimestamp())

View file

@ -313,10 +313,12 @@ class ComponentObjects:
return False return False
tableName = modelClass.__name__ tableName = modelClass.__name__
from modules.interfaces.interfaceRbac import buildDataObjectKey
objectKey = buildDataObjectKey(tableName)
permissions = self.rbac.getUserPermissions( permissions = self.rbac.getUserPermissions(
self.currentUser, self.currentUser,
AccessRuleContext.DATA, AccessRuleContext.DATA,
tableName, objectKey,
mandateId=self.mandateId, mandateId=self.mandateId,
featureInstanceId=self.featureInstanceId featureInstanceId=self.featureInstanceId
) )
@ -590,10 +592,58 @@ class ComponentObjects:
# Prompt methods # Prompt methods
def _isSysAdmin(self) -> bool:
"""Check if the current user is a SysAdmin."""
return hasattr(self.currentUser, 'isSysAdmin') and self.currentUser.isSysAdmin
def _enrichPromptsWithPermissions(self, prompts: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Enrich prompts with row-level _permissions based on ownership and isSystem flag.
- SysAdmin: canUpdate=True, canDelete=True on all prompts
- Regular user on own prompts: canUpdate=True, canDelete=True
- Regular user on system prompts: canUpdate=False, canDelete=False (read-only)
"""
isSysAdmin = self._isSysAdmin()
for prompt in prompts:
isOwner = prompt.get("_createdBy") == self.userId
prompt["_permissions"] = {
"canUpdate": isOwner or isSysAdmin,
"canDelete": isOwner or isSysAdmin
}
return prompts
def _getPromptsForUser(self) -> List[Dict[str, Any]]:
"""Returns prompts visible to the current user.
Visibility rules:
- SysAdmin: ALL prompts
- Regular user: own prompts (_createdBy) + system prompts (isSystem=True)
"""
if self._isSysAdmin():
return self.db.getRecordset(Prompt)
# Get own prompts
ownPrompts = self.db.getRecordset(Prompt, recordFilter={"_createdBy": self.userId})
# Get system prompts
systemPrompts = self.db.getRecordset(Prompt, recordFilter={"isSystem": True})
# Merge and deduplicate (a user's own prompt could also be isSystem)
seen = {}
for p in ownPrompts:
seen[p["id"]] = p
for p in systemPrompts:
if p["id"] not in seen:
seen[p["id"]] = p
return list(seen.values())
def getAllPrompts(self, pagination: Optional[PaginationParams] = None) -> Union[List[Prompt], PaginatedResult]: def getAllPrompts(self, pagination: Optional[PaginationParams] = None) -> Union[List[Prompt], PaginatedResult]:
""" """
Returns prompts based on user access level. Returns prompts with visibility rules:
Supports optional pagination, sorting, and filtering. - SysAdmin: sees ALL prompts, can CRUD all
- Regular user: sees own prompts + system prompts (isSystem=True), can only CRUD own
- Row-level _permissions control edit/delete buttons in the UI
Args: Args:
pagination: Optional pagination parameters. If None, returns all items. pagination: Optional pagination parameters. If None, returns all items.
@ -603,11 +653,11 @@ class ComponentObjects:
If pagination is provided: PaginatedResult with items and metadata If pagination is provided: PaginatedResult with items and metadata
""" """
try: try:
# Use RBAC filtering # Get prompts based on user role (own + system for regular, all for SysAdmin)
filteredPrompts = getRecordsetWithRBAC(self.db, filteredPrompts = self._getPromptsForUser()
Prompt,
self.currentUser # Enrich with row-level permissions (_permissions: canUpdate, canDelete)
) filteredPrompts = self._enrichPromptsWithPermissions(filteredPrompts)
# If no pagination requested, return all items # If no pagination requested, return all items
if pagination is None: if pagination is None:
@ -630,7 +680,7 @@ class ComponentObjects:
endIdx = startIdx + pagination.pageSize endIdx = startIdx + pagination.pageSize
pagedPrompts = filteredPrompts[startIdx:endIdx] pagedPrompts = filteredPrompts[startIdx:endIdx]
# Convert to model objects # Convert to model objects (extra='allow' on Prompt preserves system fields)
items = [Prompt(**prompt) for prompt in pagedPrompts] items = [Prompt(**prompt) for prompt in pagedPrompts]
return PaginatedResult( return PaginatedResult(
@ -646,15 +696,24 @@ class ComponentObjects:
return PaginatedResult(items=[], totalItems=0, totalPages=0) return PaginatedResult(items=[], totalItems=0, totalPages=0)
def getPrompt(self, promptId: str) -> Optional[Prompt]: def getPrompt(self, promptId: str) -> Optional[Prompt]:
"""Returns a prompt by ID if user has access.""" """Returns a prompt by ID if the user has visibility.
# Use RBAC filtering
filteredPrompts = getRecordsetWithRBAC(self.db,
Prompt,
self.currentUser,
recordFilter={"id": promptId}
)
return Prompt(**filteredPrompts[0]) if filteredPrompts else None Visibility: SysAdmin sees all, regular user sees own + system prompts.
"""
filteredPrompts = self.db.getRecordset(Prompt, recordFilter={"id": promptId})
if not filteredPrompts:
return None
prompt = filteredPrompts[0]
# Visibility check for non-SysAdmin: must be owner or system prompt
if not self._isSysAdmin():
isOwner = prompt.get("_createdBy") == self.userId
isSystem = prompt.get("isSystem", False)
if not isOwner and not isSystem:
return None
return Prompt(**prompt)
def createPrompt(self, promptData: Dict[str, Any]) -> Dict[str, Any]: def createPrompt(self, promptData: Dict[str, Any]) -> Dict[str, Any]:
"""Creates a new prompt if user has permission.""" """Creates a new prompt if user has permission."""
@ -669,13 +728,25 @@ class ComponentObjects:
return createdRecord return createdRecord
def updatePrompt(self, promptId: str, updateData: Dict[str, Any]) -> Dict[str, Any]: def updatePrompt(self, promptId: str, updateData: Dict[str, Any]) -> Dict[str, Any]:
"""Updates a prompt if user has access.""" """Updates a prompt. Rules:
- SysAdmin: can update any prompt (including system prompts)
- Regular user: can only update own prompts (not system prompts)
"""
try: try:
# Get prompt # Get prompt (visibility-checked)
prompt = self.getPrompt(promptId) prompt = self.getPrompt(promptId)
if not prompt: if not prompt:
raise ValueError(f"Prompt {promptId} not found") raise ValueError(f"Prompt {promptId} not found")
# Permission check: owner or SysAdmin
isOwner = (getattr(prompt, '_createdBy', None) == self.userId)
if not self._isSysAdmin() and not isOwner:
raise PermissionError(f"No permission to update prompt {promptId}")
# Regular users cannot set isSystem flag
if not self._isSysAdmin() and 'isSystem' in updateData:
del updateData['isSystem']
# Update prompt record directly with the update data # Update prompt record directly with the update data
self.db.recordModify(Prompt, promptId, updateData) self.db.recordModify(Prompt, promptId, updateData)
@ -688,77 +759,69 @@ class ComponentObjects:
return updatedPrompt.model_dump() return updatedPrompt.model_dump()
except PermissionError:
raise
except Exception as e: except Exception as e:
logger.error(f"Error updating prompt: {str(e)}") logger.error(f"Error updating prompt: {str(e)}")
raise ValueError(f"Failed to update prompt: {str(e)}") raise ValueError(f"Failed to update prompt: {str(e)}")
def deletePrompt(self, promptId: str) -> bool: def deletePrompt(self, promptId: str) -> bool:
"""Deletes a prompt if user has access.""" """Deletes a prompt. Rules:
# Check if the prompt exists and user has access - SysAdmin: can delete any prompt (including system prompts)
- Regular user: can only delete own prompts (not system prompts)
"""
# Get prompt (visibility-checked)
prompt = self.getPrompt(promptId) prompt = self.getPrompt(promptId)
if not prompt: if not prompt:
return False return False
if not self.checkRbacPermission(Prompt, "update", promptId): # Permission check: owner or SysAdmin
isOwner = (getattr(prompt, '_createdBy', None) == self.userId)
if not self._isSysAdmin() and not isOwner:
raise PermissionError(f"No permission to delete prompt {promptId}") raise PermissionError(f"No permission to delete prompt {promptId}")
# Delete prompt # Delete prompt
success = self.db.recordDelete(Prompt, promptId) success = self.db.recordDelete(Prompt, promptId)
return success return success
# File Utilities # File Utilities
def checkForDuplicateFile(self, fileHash: str, fileName: str = None) -> Optional[FileItem]: def checkForDuplicateFile(self, fileHash: str, fileName: str) -> Optional[FileItem]:
"""Checks if a file with the same hash already exists for the current user and mandate. """Checks if a file with the same hash AND fileName already exists for the current user.
If fileName is provided, also checks for exact name+hash match.
Only returns files the current user has access to.""" Duplicate = same user (_createdBy) + same fileHash + same fileName.
# Get files with the hash, filtered by RBAC Same hash with different name is allowed (intentional copy by user).
accessibleFiles = getRecordsetWithRBAC(self.db, Uses direct DB query (not RBAC) because files are isolated per user.
"""
if not self.userId:
return None
# Direct DB query: find files with matching hash + name + user
matchingFiles = self.db.getRecordset(
FileItem, FileItem,
self.currentUser, recordFilter={
recordFilter={"fileHash": fileHash} "_createdBy": self.userId,
"fileHash": fileHash,
"fileName": fileName
}
) )
if not accessibleFiles: if not matchingFiles:
return None return None
# If fileName is provided, check for exact name+hash match first
if fileName:
for file in accessibleFiles:
# Skip files without fileName key or with None/empty fileName
if "fileName" not in file or not file["fileName"]:
continue
if file["fileName"] == fileName:
return FileItem(
id=file["id"],
mandateId=file["mandateId"],
fileName=file["fileName"],
mimeType=file["mimeType"],
fileHash=file["fileHash"],
fileSize=file["fileSize"],
creationDate=file["creationDate"]
)
# Return first valid file with matching hash (for general duplicate detection) # Return first match
for file in accessibleFiles: file = matchingFiles[0]
# Skip files without fileName key or with None/empty fileName return FileItem(
if "fileName" not in file or not file["fileName"]: id=file["id"],
continue mandateId=file.get("mandateId", ""),
# Use first valid file featureInstanceId=file.get("featureInstanceId", ""),
return FileItem( fileName=file["fileName"],
id=file["id"], mimeType=file["mimeType"],
mandateId=file["mandateId"], fileHash=file["fileHash"],
fileName=file["fileName"], fileSize=file["fileSize"],
mimeType=file["mimeType"], creationDate=file["creationDate"]
fileHash=file["fileHash"], )
fileSize=file["fileSize"],
creationDate=file["creationDate"]
)
# If no valid files found, return None
return None
def getMimeType(self, fileName: str) -> str: def getMimeType(self, fileName: str) -> str:
"""Determines the MIME type based on the file extension.""" """Determines the MIME type based on the file extension."""
@ -832,9 +895,18 @@ class ComponentObjects:
# File methods - metadata-based operations # File methods - metadata-based operations
def _getFilesByCurrentUser(self, recordFilter: Dict[str, Any] = None) -> List[Dict[str, Any]]:
"""Files are always user-scoped. Returns only files owned by the current user,
regardless of role (including SysAdmin). This bypasses RBAC intentionally."""
filterDict = {"_createdBy": self.userId}
if recordFilter:
filterDict.update(recordFilter)
return self.db.getRecordset(FileItem, recordFilter=filterDict)
def getAllFiles(self, pagination: Optional[PaginationParams] = None) -> Union[List[FileItem], PaginatedResult]: def getAllFiles(self, pagination: Optional[PaginationParams] = None) -> Union[List[FileItem], PaginatedResult]:
""" """
Returns files based on user access level. Returns files owned by the current user (user-scoped, not RBAC-based).
Every user (including SysAdmin) only sees their own files.
Supports optional pagination, sorting, and filtering. Supports optional pagination, sorting, and filtering.
Args: Args:
@ -844,13 +916,10 @@ class ComponentObjects:
If pagination is None: List[FileItem] If pagination is None: List[FileItem]
If pagination is provided: PaginatedResult with items and metadata If pagination is provided: PaginatedResult with items and metadata
""" """
# Use RBAC filtering # Files are always user-scoped: filter by _createdBy (bypasses RBAC SysAdmin override)
filteredFiles = getRecordsetWithRBAC(self.db, filteredFiles = self._getFilesByCurrentUser()
FileItem,
self.currentUser
)
# Convert database records to FileItem instances (for both paginated and non-paginated) # Convert database records to FileItem instances (extra='allow' preserves system fields like _createdBy)
def convertFileItems(files): def convertFileItems(files):
fileItems = [] fileItems = []
for file in files: for file in files:
@ -858,21 +927,14 @@ class ComponentObjects:
# Ensure proper values, use defaults for invalid data # Ensure proper values, use defaults for invalid data
creationDate = file.get("creationDate") creationDate = file.get("creationDate")
if creationDate is None or not isinstance(creationDate, (int, float)) or creationDate <= 0: if creationDate is None or not isinstance(creationDate, (int, float)) or creationDate <= 0:
creationDate = getUtcTimestamp() file["creationDate"] = getUtcTimestamp()
fileName = file.get("fileName") fileName = file.get("fileName")
if not fileName or fileName == "None": if not fileName or fileName == "None":
continue # Skip records with invalid fileName continue # Skip records with invalid fileName
fileItem = FileItem( # Use **file to pass all fields including system fields (_createdBy, etc.)
id=file.get("id"), fileItem = FileItem(**file)
mandateId=file.get("mandateId"),
fileName=fileName,
mimeType=file.get("mimeType"),
fileHash=file.get("fileHash"),
fileSize=file.get("fileSize"),
creationDate=creationDate
)
fileItems.append(fileItem) fileItems.append(fileItem)
except Exception as e: except Exception as e:
logger.warning(f"Skipping invalid file record: {str(e)}") logger.warning(f"Skipping invalid file record: {str(e)}")
@ -900,7 +962,7 @@ class ComponentObjects:
endIdx = startIdx + pagination.pageSize endIdx = startIdx + pagination.pageSize
pagedFiles = filteredFiles[startIdx:endIdx] pagedFiles = filteredFiles[startIdx:endIdx]
# Convert to model objects # Convert to model objects (extra='allow' on FileItem preserves system fields)
items = convertFileItems(pagedFiles) items = convertFileItems(pagedFiles)
return PaginatedResult( return PaginatedResult(
@ -910,13 +972,9 @@ class ComponentObjects:
) )
def getFile(self, fileId: str) -> Optional[FileItem]: def getFile(self, fileId: str) -> Optional[FileItem]:
"""Returns a file by ID if user has access.""" """Returns a file by ID if it belongs to the current user (user-scoped)."""
# Use RBAC filtering # Files are always user-scoped: filter by _createdBy (bypasses RBAC SysAdmin override)
filteredFiles = getRecordsetWithRBAC(self.db, filteredFiles = self._getFilesByCurrentUser(recordFilter={"id": fileId})
FileItem,
self.currentUser,
recordFilter={"id": fileId}
)
if not filteredFiles: if not filteredFiles:
return None return None
@ -976,17 +1034,28 @@ class ComponentObjects:
counter += 1 counter += 1
def createFile(self, name: str, mimeType: str, content: bytes) -> FileItem: def createFile(self, name: str, mimeType: str, content: bytes) -> FileItem:
"""Creates a new file entry if user has permission. Computes fileHash and fileSize from content.""" """Creates a new file entry if user has permission. Computes fileHash and fileSize from content.
Duplicate check: if a file with the same user + fileHash + fileName already exists,
the existing file is returned instead of creating a new one.
Same hash with different name is allowed (intentional copy by user).
"""
if not self.checkRbacPermission(FileItem, "create"): if not self.checkRbacPermission(FileItem, "create"):
raise PermissionError("No permission to create files") raise PermissionError("No permission to create files")
# Ensure fileName is unique
uniqueName = self._generateUniquefileName(name)
# Compute file size and hash # Compute file size and hash
fileSize = len(content) fileSize = len(content)
fileHash = hashlib.sha256(content).hexdigest() fileHash = hashlib.sha256(content).hexdigest()
# Duplicate check: same user + same hash + same fileName → return existing
existingFile = self.checkForDuplicateFile(fileHash, name)
if existingFile:
logger.info(f"Duplicate file detected in createFile: '{name}' (hash={fileHash[:12]}...) for user {self.userId} — returning existing file {existingFile.id}")
return existingFile
# Ensure fileName is unique
uniqueName = self._generateUniquefileName(name)
# Use mandateId and featureInstanceId from context for proper data isolation # Use mandateId and featureInstanceId from context for proper data isolation
# Convert None to empty string to satisfy Pydantic validation # Convert None to empty string to satisfy Pydantic validation
mandateId = self.mandateId or "" mandateId = self.mandateId or ""
@ -1005,7 +1074,6 @@ class ComponentObjects:
# Store in database # Store in database
self.db.recordCreate(FileItem, fileItem) self.db.recordCreate(FileItem, fileItem)
return fileItem return fileItem
def updateFile(self, fileId: str, updateData: Dict[str, Any]) -> Dict[str, Any]: def updateFile(self, fileId: str, updateData: Dict[str, Any]) -> Dict[str, Any]:
@ -1040,20 +1108,16 @@ class ComponentObjects:
if not self.checkRbacPermission(FileItem, "update", fileId): if not self.checkRbacPermission(FileItem, "update", fileId):
raise PermissionError(f"No permission to delete file {fileId}") raise PermissionError(f"No permission to delete file {fileId}")
# Check for other references to this file (by hash) - use RBAC to only check files user has access to # Check for other references to this file (by hash) - user-scoped check
fileHash = file.fileHash fileHash = file.fileHash
if fileHash: if fileHash:
allReferences = getRecordsetWithRBAC(self.db, allReferences = self._getFilesByCurrentUser(recordFilter={"fileHash": fileHash})
FileItem,
self.currentUser,
recordFilter={"fileHash": fileHash}
)
otherReferences = [f for f in allReferences if f["id"] != fileId] otherReferences = [f for f in allReferences if f["id"] != fileId]
# Only delete associated fileData if no other references exist # Only delete associated fileData if no other references exist
if not otherReferences: if not otherReferences:
try: try:
fileDataEntries = getRecordsetWithRBAC(self.db, FileData, self.currentUser, recordFilter={"id": fileId}) fileDataEntries = self.db.getRecordset(FileData, recordFilter={"id": fileId})
if fileDataEntries: if fileDataEntries:
self.db.recordDelete(FileData, fileId) self.db.recordDelete(FileData, fileId)
logger.debug(f"FileData for file {fileId} deleted") logger.debug(f"FileData for file {fileId} deleted")
@ -1113,6 +1177,12 @@ class ComponentObjects:
base64Encoded = True base64Encoded = True
logger.debug(f"Stored file {fileId} as base64") logger.debug(f"Stored file {fileId} as base64")
# Check if file data already exists (e.g., when createFile returned a duplicate)
existingData = self.db.getRecordset(FileData, recordFilter={"id": fileId})
if existingData:
logger.debug(f"File data already exists for {fileId} — skipping duplicate storage")
return True
# Create the fileData record with data and encoding flag # Create the fileData record with data and encoding flag
fileDataObj = { fileDataObj = {
"id": fileId, "id": fileId,
@ -1245,25 +1315,21 @@ class ComponentObjects:
logger.error(f"Invalid fileContent type: {type(fileContent)}") logger.error(f"Invalid fileContent type: {type(fileContent)}")
raise ValueError(f"fileContent must be bytes, got {type(fileContent)}") raise ValueError(f"fileContent must be bytes, got {type(fileContent)}")
# Compute file hash first to check for duplicates # Compute file hash to check for duplicates before any DB writes
fileHash = hashlib.sha256(fileContent).hexdigest() fileHash = hashlib.sha256(fileContent).hexdigest()
# Check for exact name+hash match first (same name + same content) # Duplicate check: same user + same fileHash + same fileName → return existing file
# Same hash with different name is allowed (intentional copy by user)
existingFile = self.checkForDuplicateFile(fileHash, fileName) existingFile = self.checkForDuplicateFile(fileHash, fileName)
if existingFile: if existingFile:
logger.info(f"Exact duplicate detected: {fileName} with same hash. Returning existing file reference.") logger.info(f"Duplicate detected for user {self.userId}: '{fileName}' with hash {fileHash[:12]}... — returning existing file {existingFile.id}")
return existingFile, "exact_duplicate" return existingFile, "exact_duplicate"
# Check for hash-only match (same content, different name)
existingFileWithSameHash = self.checkForDuplicateFile(fileHash)
if existingFileWithSameHash:
logger.info(f"Content duplicate detected: {fileName} has same content as {existingFileWithSameHash.fileName}")
# Continue with upload - filename will be made unique if needed
# Determine MIME type # Determine MIME type
mimeType = self.getMimeType(fileName) mimeType = self.getMimeType(fileName)
# Save metadata and file (hash/size computed inside createFile) # createFile handles its own duplicate check (for calls from other code paths)
# Here we already checked, so this will create a new file
logger.debug(f"Saving file metadata to database for file: {fileName}") logger.debug(f"Saving file metadata to database for file: {fileName}")
fileItem = self.createFile( fileItem = self.createFile(
name=fileName, name=fileName,

View file

@ -163,7 +163,7 @@ def getRecordsetWithRBAC(
# Check view permission first # Check view permission first
if not permissions.view: if not permissions.view:
logger.debug(f"User {currentUser.id} has no view permission for {objectKey}") logger.debug(f"User {currentUser.id} has no view permission for {objectKey} (mandateId={effectiveMandateId}, featureInstanceId={featureInstanceId})")
return [] return []
# Build WHERE clause with RBAC filtering # Build WHERE clause with RBAC filtering

View file

@ -90,7 +90,7 @@ async def sync_all_automation_events(
from modules.services import getInterface as getServices from modules.services import getInterface as getServices
services = getServices(currentUser, None) services = getServices(currentUser, None)
result = await syncAutomationEvents(services, eventUser) result = syncAutomationEvents(services, eventUser)
return { return {
"success": True, "success": True,
"synced": result.get("synced", 0), "synced": result.get("synced", 0),

View file

@ -78,11 +78,18 @@ def get_permissions(
) )
# MULTI-TENANT: Get permissions using context (mandateId/featureInstanceId) # MULTI-TENANT: Get permissions using context (mandateId/featureInstanceId)
# For DATA context, resolve short model names to full objectKeys
# e.g., "ChatWorkflow" → "data.chat.ChatWorkflow"
resolvedItem = item or ""
if accessContext == AccessRuleContext.DATA and resolvedItem and "." not in resolvedItem:
from modules.interfaces.interfaceRbac import buildDataObjectKey
resolvedItem = buildDataObjectKey(resolvedItem)
# Pass mandateId and featureInstanceId to load Feature-Instance roles # Pass mandateId and featureInstanceId to load Feature-Instance roles
permissions = interface.rbac.getUserPermissions( permissions = interface.rbac.getUserPermissions(
reqContext.user, reqContext.user,
accessContext, accessContext,
item or "", resolvedItem,
mandateId=reqContext.mandateId, mandateId=reqContext.mandateId,
featureInstanceId=reqContext.featureInstanceId featureInstanceId=reqContext.featureInstanceId
) )

View file

@ -121,10 +121,10 @@ def get_prompt(
def update_prompt( def update_prompt(
request: Request, request: Request,
promptId: str = Path(..., description="ID of the prompt to update"), promptId: str = Path(..., description="ID of the prompt to update"),
promptData: Prompt = Body(...), promptData: Dict[str, Any] = Body(...),
currentUser: User = Depends(getCurrentUser) currentUser: User = Depends(getCurrentUser)
) -> Prompt: ) -> Prompt:
"""Update an existing prompt""" """Update an existing prompt (supports partial updates for inline editing)"""
managementInterface = interfaceDbManagement.getInterface(currentUser) managementInterface = interfaceDbManagement.getInterface(currentUser)
# Check if the prompt exists # Check if the prompt exists
@ -135,14 +135,17 @@ def update_prompt(
detail=f"Prompt with ID {promptId} not found" detail=f"Prompt with ID {promptId} not found"
) )
# Convert Prompt to dict for interface, excluding the id field # Remove id from update data if present
if hasattr(promptData, "model_dump"): update_data = {k: v for k, v in promptData.items() if k != "id"}
update_data = promptData.model_dump(exclude={"id"})
else:
update_data = promptData.model_dump(exclude={"id"})
# Update prompt # Update prompt (ownership check happens in interface)
updatedPrompt = managementInterface.updatePrompt(promptId, update_data) try:
updatedPrompt = managementInterface.updatePrompt(promptId, update_data)
except PermissionError as e:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=str(e)
)
if not updatedPrompt: if not updatedPrompt:
raise HTTPException( raise HTTPException(
@ -170,7 +173,14 @@ def delete_prompt(
detail=f"Prompt with ID {promptId} not found" detail=f"Prompt with ID {promptId} not found"
) )
success = managementInterface.deletePrompt(promptId) try:
success = managementInterface.deletePrompt(promptId)
except PermissionError as e:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=str(e)
)
if not success: if not success:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,

View file

@ -62,7 +62,7 @@ class RbacClass:
Multi-Tenant Design: Multi-Tenant Design:
- Lädt Rollen aus UserMandate + UserMandateRole wenn mandateId gegeben - Lädt Rollen aus UserMandate + UserMandateRole wenn mandateId gegeben
- isSysAdmin gibt vollen Zugriff auf System-Level (kein mandateId) - isSysAdmin gibt vollen Zugriff, unabhängig vom Kontext
Args: Args:
user: User object user: User object
@ -82,8 +82,8 @@ class RbacClass:
delete=AccessLevel.NONE delete=AccessLevel.NONE
) )
# SysAdmin auf System-Level (kein Mandant) hat vollen Zugriff # SysAdmin hat vollen Zugriff - unabhängig vom Kontext (Mandant/Feature)
if hasattr(user, 'isSysAdmin') and user.isSysAdmin and not mandateId: if hasattr(user, 'isSysAdmin') and user.isSysAdmin:
return UserPermissions( return UserPermissions(
view=True, view=True,
read=AccessLevel.ALL, read=AccessLevel.ALL,
@ -96,6 +96,7 @@ class RbacClass:
roleIds = self._getRoleIdsForUser(user, mandateId, featureInstanceId) roleIds = self._getRoleIdsForUser(user, mandateId, featureInstanceId)
if not roleIds: if not roleIds:
logger.debug(f"getUserPermissions: NO roles found for user={user.id}, mandateId={mandateId}, featureInstanceId={featureInstanceId}, item={item}")
return permissions return permissions
# Lade alle relevanten Regeln für alle Rollen # Lade alle relevanten Regeln für alle Rollen

View file

@ -63,10 +63,11 @@ class Services:
- Feature-specific Services are loaded dynamically via filename discovery - Feature-specific Services are loaded dynamically via filename discovery
""" """
def __init__(self, user: User, workflow: "ChatWorkflow" = None, mandateId: Optional[str] = None): def __init__(self, user: User, workflow: "ChatWorkflow" = None, mandateId: Optional[str] = None, featureInstanceId: Optional[str] = None):
self.user: User = user self.user: User = user
self.workflow = workflow self.workflow = workflow
self.mandateId: Optional[str] = mandateId self.mandateId: Optional[str] = mandateId
self.featureInstanceId: Optional[str] = featureInstanceId
self.currentUserPrompt: str = "" self.currentUserPrompt: str = ""
self.rawUserPrompt: str = "" self.rawUserPrompt: str = ""
@ -83,7 +84,7 @@ class Services:
# CENTRAL INTERFACE (Chat/Workflow) # CENTRAL INTERFACE (Chat/Workflow)
# ============================================================ # ============================================================
from modules.interfaces.interfaceDbChat import getInterface as getChatInterface from modules.interfaces.interfaceDbChat import getInterface as getChatInterface
self.interfaceDbChat = getChatInterface(user, mandateId=mandateId) self.interfaceDbChat = getChatInterface(user, mandateId=mandateId, featureInstanceId=featureInstanceId)
# ============================================================ # ============================================================
# SHARED SERVICES (from modules/services/) # SHARED SERVICES (from modules/services/)
@ -143,7 +144,7 @@ class Services:
# Get interface via getInterface() # Get interface via getInterface()
if hasattr(module, "getInterface"): if hasattr(module, "getInterface"):
interface = module.getInterface(self.user, mandateId=self.mandateId) interface = module.getInterface(self.user, mandateId=self.mandateId, featureInstanceId=self.featureInstanceId)
# Derive attribute name: interfaceFeatureAiChat -> interfaceDbChat # Derive attribute name: interfaceFeatureAiChat -> interfaceDbChat
attrName = filename.replace("interfaceFeature", "interfaceDb") attrName = filename.replace("interfaceFeature", "interfaceDb")
setattr(self, attrName, interface) setattr(self, attrName, interface)
@ -191,6 +192,6 @@ class Services:
logger.debug(f"Could not load service from {filepath}: {e}") logger.debug(f"Could not load service from {filepath}: {e}")
def getInterface(user: User, workflow: "ChatWorkflow" = None, mandateId: Optional[str] = None) -> Services: def getInterface(user: User, workflow: "ChatWorkflow" = None, mandateId: Optional[str] = None, featureInstanceId: Optional[str] = None) -> Services:
"""Get Services instance for the given user and mandate context.""" """Get Services instance for the given user, mandate, and feature instance context."""
return Services(user, workflow, mandateId=mandateId) return Services(user, workflow, mandateId=mandateId, featureInstanceId=featureInstanceId)

View file

@ -2574,8 +2574,8 @@ CRITICAL:
""" """
from modules.services.serviceGeneration.renderers.registry import getRenderer from modules.services.serviceGeneration.renderers.registry import getRenderer
# Get renderer for this format - NO FALLBACK # Get document renderer for this format (structure filling is document generation path)
renderer = getRenderer(outputFormat, self.services) renderer = getRenderer(outputFormat, self.services, outputStyle='document')
if not renderer: if not renderer:
raise ValueError(f"No renderer found for output format '{outputFormat}'. Check renderer registry.") raise ValueError(f"No renderer found for output format '{outputFormat}'. Check renderer registry.")

View file

@ -556,10 +556,10 @@ class GenerationService:
def _getFormatRenderer(self, output_format: str): def _getFormatRenderer(self, output_format: str):
"""Get the appropriate renderer for the specified format using auto-discovery.""" """Get the appropriate document renderer for the specified format."""
try: try:
from .renderers.registry import getRenderer, getSupportedFormats from .renderers.registry import getRenderer, getSupportedFormats
renderer = getRenderer(output_format, services=self.services) renderer = getRenderer(output_format, services=self.services, outputStyle='document')
if renderer: if renderer:
return renderer return renderer
@ -573,7 +573,7 @@ class GenerationService:
# Fallback to text renderer if no specific renderer found # Fallback to text renderer if no specific renderer found
logger.warning(f"Falling back to text renderer for format {output_format}") logger.warning(f"Falling back to text renderer for format {output_format}")
fallbackRenderer = getRenderer('text', services=self.services) fallbackRenderer = getRenderer('text', services=self.services, outputStyle='document')
if fallbackRenderer: if fallbackRenderer:
return fallbackRenderer return fallbackRenderer

View file

@ -922,7 +922,7 @@ CRITICAL:
"""Get code renderer for file type.""" """Get code renderer for file type."""
from modules.services.serviceGeneration.renderers.registry import getRenderer from modules.services.serviceGeneration.renderers.registry import getRenderer
# Map file types to renderer formats # Map file types to renderer formats (code path)
formatMap = { formatMap = {
'json': 'json', 'json': 'json',
'csv': 'csv', 'csv': 'csv',
@ -931,7 +931,7 @@ CRITICAL:
rendererFormat = formatMap.get(fileType.lower()) rendererFormat = formatMap.get(fileType.lower())
if rendererFormat: if rendererFormat:
renderer = getRenderer(rendererFormat, self.services) renderer = getRenderer(rendererFormat, self.services, outputStyle='code')
# Check if renderer supports code rendering # Check if renderer supports code rendering
if renderer and hasattr(renderer, 'renderCodeFiles'): if renderer and hasattr(renderer, 'renderCodeFiles'):
return renderer return renderer

View file

@ -2,20 +2,30 @@
# All rights reserved. # All rights reserved.
""" """
Renderer registry for automatic discovery and registration of renderers. Renderer registry for automatic discovery and registration of renderers.
Renderers are indexed by (format, outputStyle) so that document generation
and code generation each get the correct renderer for the same format.
""" """
import logging import logging
import importlib import importlib
from typing import Dict, Type, List, Optional from typing import Dict, Type, List, Optional, Tuple
from .documentRendererBaseTemplate import BaseRenderer from .documentRendererBaseTemplate import BaseRenderer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class RendererRegistry: class RendererRegistry:
"""Registry for automatic renderer discovery and management.""" """Registry for automatic renderer discovery and management.
Maintains separate renderer mappings per outputStyle ('document', 'code', etc.)
so that document-generation and code-generation paths each resolve to the
correct renderer, even when both support the same format (e.g. 'csv').
"""
def __init__(self): def __init__(self):
self._renderers: Dict[str, Type[BaseRenderer]] = {} # Key: (formatName, outputStyle) -> rendererClass
self._renderers: Dict[Tuple[str, str], Type[BaseRenderer]] = {}
self._format_mappings: Dict[str, str] = {} self._format_mappings: Dict[str, str] = {}
self._discovered = False self._discovered = False
@ -25,39 +35,27 @@ class RendererRegistry:
return return
try: try:
import os
import sys
from pathlib import Path from pathlib import Path
# Get the directory containing this registry file
currentDir = Path(__file__).parent currentDir = Path(__file__).parent
renderersDir = currentDir
# Get the package name dynamically
packageName = __name__.rsplit('.', 1)[0] packageName = __name__.rsplit('.', 1)[0]
# Scan all Python files in the renderers directory for filePath in currentDir.glob("*.py"):
for filePath in renderersDir.glob("*.py"): if filePath.name in ['registry.py', 'documentRendererBaseTemplate.py', 'codeRendererBaseTemplate.py', '__init__.py']:
if filePath.name in ['registry.py', 'documentRendererBaseTemplate.py', '__init__.py']:
continue continue
# Extract module name from filename
moduleName = filePath.stem moduleName = filePath.stem
try: try:
# Import the module dynamically
fullModuleName = f"{packageName}.{moduleName}" fullModuleName = f"{packageName}.{moduleName}"
module = importlib.import_module(fullModuleName) module = importlib.import_module(fullModuleName)
# Look for renderer classes in the module
for attrName in dir(module): for attrName in dir(module):
attr = getattr(module, attrName) attr = getattr(module, attrName)
if (isinstance(attr, type) and if (isinstance(attr, type) and
issubclass(attr, BaseRenderer) and issubclass(attr, BaseRenderer) and
attr != BaseRenderer and attr != BaseRenderer and
hasattr(attr, 'getSupportedFormats')): hasattr(attr, 'getSupportedFormats')):
# Register the renderer
self._registerRendererClass(attr) self._registerRendererClass(attr)
except Exception as e: except Exception as e:
@ -68,60 +66,75 @@ class RendererRegistry:
except Exception as e: except Exception as e:
logger.error(f"Error during renderer discovery: {str(e)}") logger.error(f"Error during renderer discovery: {str(e)}")
self._discovered = True # Mark as discovered to avoid repeated attempts self._discovered = True
def _registerRendererClass(self, rendererClass: Type[BaseRenderer]) -> None: def _registerRendererClass(self, rendererClass: Type[BaseRenderer]) -> None:
"""Register a renderer class with its supported formats.""" """Register a renderer class keyed by (format, outputStyle)."""
try: try:
# Get supported formats from the renderer class
supportedFormats = rendererClass.getSupportedFormats() supportedFormats = rendererClass.getSupportedFormats()
outputStyle = rendererClass.getOutputStyle() if hasattr(rendererClass, 'getOutputStyle') else 'document'
# Get priority (default to 0 if not specified)
priority = rendererClass.getPriority() if hasattr(rendererClass, 'getPriority') else 0 priority = rendererClass.getPriority() if hasattr(rendererClass, 'getPriority') else 0
for formatName in supportedFormats: for formatName in supportedFormats:
formatKey = formatName.lower() formatKey = formatName.lower()
registryKey = (formatKey, outputStyle)
# Check if format already registered - use priority to decide if registryKey in self._renderers:
if formatKey in self._renderers: existingRenderer = self._renderers[registryKey]
existingRenderer = self._renderers[formatKey]
existingPriority = existingRenderer.getPriority() if hasattr(existingRenderer, 'getPriority') else 0 existingPriority = existingRenderer.getPriority() if hasattr(existingRenderer, 'getPriority') else 0
# Only replace if new renderer has higher priority
if priority > existingPriority: if priority > existingPriority:
logger.debug(f"Replacing {existingRenderer.__name__} with {rendererClass.__name__} for format '{formatName}' (priority {priority} > {existingPriority})") logger.debug(f"Replacing {existingRenderer.__name__} with {rendererClass.__name__} for ({formatKey}, {outputStyle}) (priority {priority} > {existingPriority})")
self._renderers[formatKey] = rendererClass self._renderers[registryKey] = rendererClass
else: else:
logger.debug(f"Keeping {existingRenderer.__name__} for format '{formatName}' (priority {existingPriority} >= {priority})") logger.debug(f"Keeping {existingRenderer.__name__} for ({formatKey}, {outputStyle}) (priority {existingPriority} >= {priority})")
else: else:
# Register primary format self._renderers[registryKey] = rendererClass
self._renderers[formatKey] = rendererClass
# Register aliases if any # Register aliases
if hasattr(rendererClass, 'getFormatAliases'): if hasattr(rendererClass, 'getFormatAliases'):
aliases = rendererClass.getFormatAliases() aliases = rendererClass.getFormatAliases()
for alias in aliases: for alias in aliases:
self._format_mappings[alias.lower()] = formatName.lower() self._format_mappings[alias.lower()] = formatKey
logger.debug(f"Registered {rendererClass.__name__} for formats: {supportedFormats} (priority: {priority})") logger.debug(f"Registered {rendererClass.__name__} for formats={supportedFormats}, style={outputStyle}, priority={priority}")
except Exception as e: except Exception as e:
logger.error(f"Error registering renderer {rendererClass.__name__}: {str(e)}") logger.error(f"Error registering renderer {rendererClass.__name__}: {str(e)}")
def getRenderer(self, outputFormat: str, services=None) -> Optional[BaseRenderer]: def getRenderer(self, outputFormat: str, services=None, outputStyle: str = None) -> Optional[BaseRenderer]:
"""Get a renderer instance for the specified format.""" """Get a renderer instance for the specified format and style.
Args:
outputFormat: Format name (e.g. 'csv', 'json', 'pdf')
services: Services instance passed to renderer constructor
outputStyle: 'document' or 'code'. If None, returns the first match
with preference: document > code (most callers are document path).
"""
if not self._discovered: if not self._discovered:
self.discoverRenderers() self.discoverRenderers()
# Normalize format name
formatName = outputFormat.lower().strip() formatName = outputFormat.lower().strip()
# Check for aliases first
if formatName in self._format_mappings: if formatName in self._format_mappings:
formatName = self._format_mappings[formatName] formatName = self._format_mappings[formatName]
# Get renderer class rendererClass = None
rendererClass = self._renderers.get(formatName)
if outputStyle:
# Exact match by style
rendererClass = self._renderers.get((formatName, outputStyle))
else:
# No style specified — prefer 'document', then 'code', then any
for style in ['document', 'code']:
rendererClass = self._renderers.get((formatName, style))
if rendererClass:
break
# Fallback: check any registered style
if not rendererClass:
for key, cls in self._renderers.items():
if key[0] == formatName:
rendererClass = cls
break
if rendererClass: if rendererClass:
try: try:
@ -130,7 +143,7 @@ class RendererRegistry:
logger.error(f"Error creating renderer instance for {formatName}: {str(e)}") logger.error(f"Error creating renderer instance for {formatName}: {str(e)}")
return None return None
logger.warning(f"No renderer found for format: {outputFormat}") logger.warning(f"No renderer found for format={outputFormat}, style={outputStyle}")
return None return None
def getSupportedFormats(self) -> List[str]: def getSupportedFormats(self) -> List[str]:
@ -138,9 +151,11 @@ class RendererRegistry:
if not self._discovered: if not self._discovered:
self.discoverRenderers() self.discoverRenderers()
formats = list(self._renderers.keys()) formats = set()
formats.extend(self._format_mappings.keys()) for (fmt, _style) in self._renderers.keys():
return sorted(set(formats)) formats.add(fmt)
formats.update(self._format_mappings.keys())
return sorted(formats)
def getRendererInfo(self) -> Dict[str, Dict[str, str]]: def getRendererInfo(self) -> Dict[str, Dict[str, str]]:
"""Get information about all registered renderers.""" """Get information about all registered renderers."""
@ -148,10 +163,12 @@ class RendererRegistry:
self.discoverRenderers() self.discoverRenderers()
info = {} info = {}
for formatName, rendererClass in self._renderers.items(): for (formatName, style), rendererClass in self._renderers.items():
info[formatName] = { key = f"{formatName}:{style}"
info[key] = {
'class_name': rendererClass.__name__, 'class_name': rendererClass.__name__,
'module': rendererClass.__module__, 'module': rendererClass.__module__,
'outputStyle': style,
'description': getattr(rendererClass, '__doc__', 'No description').strip().split('\n')[0] if rendererClass.__doc__ else 'No description' 'description': getattr(rendererClass, '__doc__', 'No description').strip().split('\n')[0] if rendererClass.__doc__ else 'No description'
} }
@ -160,44 +177,62 @@ class RendererRegistry:
def getOutputStyle(self, outputFormat: str) -> Optional[str]: def getOutputStyle(self, outputFormat: str) -> Optional[str]:
""" """
Get the output style classification for a given format. Get the output style classification for a given format.
Returns: 'code', 'document', 'image', or other (e.g., 'video' for future use) When both 'document' and 'code' renderers exist for a format,
returns the default ('document') since this is called during document generation.
""" """
if not self._discovered: if not self._discovered:
self.discoverRenderers() self.discoverRenderers()
# Normalize format name
formatName = outputFormat.lower().strip() formatName = outputFormat.lower().strip()
# Check for aliases first
if formatName in self._format_mappings: if formatName in self._format_mappings:
formatName = self._format_mappings[formatName] formatName = self._format_mappings[formatName]
# Get renderer class and call getOutputStyle (all renderers have same signature) # Check document first, then code
rendererClass = self._renderers.get(formatName) for style in ['document', 'code']:
try: rendererClass = self._renderers.get((formatName, style))
return rendererClass.getOutputStyle(formatName) if rendererClass:
except (AttributeError, TypeError) as e: try:
logger.warning(f"No renderer found for format: {outputFormat}, cannot determine output style") return rendererClass.getOutputStyle(formatName)
return None except Exception:
except Exception as e: pass
logger.warning(f"Error getting output style for {outputFormat}: {str(e)}")
return None # Fallback: any style
for key, rendererClass in self._renderers.items():
if key[0] == formatName:
try:
return rendererClass.getOutputStyle(formatName)
except Exception:
pass
logger.warning(f"No renderer found for format: {outputFormat}, cannot determine output style")
return None
# Global registry instance # Global registry instance
_registry = RendererRegistry() _registry = RendererRegistry()
def getRenderer(outputFormat: str, services=None) -> Optional[BaseRenderer]:
"""Get a renderer instance for the specified format.""" def getRenderer(outputFormat: str, services=None, outputStyle: str = None) -> Optional[BaseRenderer]:
return _registry.getRenderer(outputFormat, services) """Get a renderer instance for the specified format and style.
Args:
outputFormat: Format name (e.g. 'csv', 'json', 'pdf')
services: Services instance
outputStyle: 'document' or 'code'. If None, prefers document renderer.
"""
return _registry.getRenderer(outputFormat, services, outputStyle=outputStyle)
def getSupportedFormats() -> List[str]: def getSupportedFormats() -> List[str]:
"""Get list of all supported formats.""" """Get list of all supported formats."""
return _registry.getSupportedFormats() return _registry.getSupportedFormats()
def getRendererInfo() -> Dict[str, Dict[str, str]]: def getRendererInfo() -> Dict[str, Dict[str, str]]:
"""Get information about all registered renderers.""" """Get information about all registered renderers."""
return _registry.getRendererInfo() return _registry.getRendererInfo()
def getOutputStyle(outputFormat: str) -> Optional[str]: def getOutputStyle(outputFormat: str) -> Optional[str]:
"""Get the output style classification for a given format.""" """Get the output style classification for a given format."""
return _registry.getOutputStyle(outputFormat) return _registry.getOutputStyle(outputFormat)

View file

@ -35,9 +35,9 @@ class RendererCsv(BaseRenderer):
def getAcceptedSectionTypes(cls, formatName: Optional[str] = None) -> List[str]: def getAcceptedSectionTypes(cls, formatName: Optional[str] = None) -> List[str]:
""" """
Return list of section content types that CSV renderer accepts. Return list of section content types that CSV renderer accepts.
CSV renderer only accepts table sections. CSV renderer accepts table sections and code_block sections (for raw CSV content).
""" """
return ["table"] return ["table", "code_block"]
async def render(self, extractedContent: Dict[str, Any], title: str, userPrompt: str = None, aiService=None) -> List[RenderedDocument]: async def render(self, extractedContent: Dict[str, Any], title: str, userPrompt: str = None, aiService=None) -> List[RenderedDocument]:
"""Render extracted JSON content to CSV format. Produces one CSV file per table section.""" """Render extracted JSON content to CSV format. Produces one CSV file per table section."""
@ -62,16 +62,24 @@ class RendererCsv(BaseRenderer):
if baseFilename.endswith('.csv'): if baseFilename.endswith('.csv'):
baseFilename = baseFilename[:-4] baseFilename = baseFilename[:-4]
# Find all table sections # Collect CSV-producing sections: table sections AND code_block sections with CSV language
tableSections = [] tableSections = []
codeBlockCsvSections = []
for section in sections: for section in sections:
sectionType = section.get("content_type", "paragraph") sectionType = section.get("content_type", "paragraph")
if sectionType == "table": if sectionType == "table":
tableSections.append(section) tableSections.append(section)
elif sectionType == "code_block":
# Check if any element is a code_block with language "csv"
for element in section.get("elements", []):
content = element.get("content", {})
if isinstance(content, dict) and content.get("language", "").lower() == "csv":
codeBlockCsvSections.append(section)
break
# If no table sections found, return empty CSV # If no usable sections found, return empty CSV
if not tableSections: if not tableSections and not codeBlockCsvSections:
self.logger.warning("No table sections found in CSV document - returning empty CSV") self.logger.warning("No table or CSV code_block sections found in CSV document - returning empty CSV")
emptyCsv = self._convertRowsToCsv([["No table data available"]]) emptyCsv = self._convertRowsToCsv([["No table data available"]])
return [ return [
RenderedDocument( RenderedDocument(
@ -83,45 +91,52 @@ class RendererCsv(BaseRenderer):
) )
] ]
# Generate one CSV file per table section allCsvSections = tableSections + codeBlockCsvSections
# Generate one CSV file per section
renderedDocuments = [] renderedDocuments = []
for i, tableSection in enumerate(tableSections): for i, csvSection in enumerate(allCsvSections):
# Generate CSV content for this table section sectionType = csvSection.get("content_type", "paragraph")
csvRows = [] sectionTitle = csvSection.get("title")
csvContent = ""
# Add section title if available if sectionType == "code_block":
sectionTitle = tableSection.get("title") # Extract raw CSV content directly from code_block elements
if sectionTitle: rawCsvParts = []
csvRows.append([sectionTitle]) for element in csvSection.get("elements", []):
csvRows.append([]) # Empty row after title content = element.get("content", {})
if isinstance(content, dict) and content.get("language", "").lower() == "csv":
code = content.get("code", "")
if code:
rawCsvParts.append(code)
csvContent = "\n".join(rawCsvParts)
else:
# Table section — render via table logic
csvRows = []
if sectionTitle:
csvRows.append([sectionTitle])
csvRows.append([]) # Empty row after title
elements = csvSection.get("elements", [])
for element in elements:
tableRows = self._renderJsonTableToCsv(element)
if tableRows:
csvRows.extend(tableRows)
csvContent = self._convertRowsToCsv(csvRows)
# Render table from section elements # Determine filename
elements = tableSection.get("elements", []) if len(allCsvSections) == 1:
for element in elements:
tableRows = self._renderJsonTableToCsv(element)
if tableRows:
csvRows.extend(tableRows)
# Convert to CSV string
csvContent = self._convertRowsToCsv(csvRows)
# Determine filename for this table
if len(tableSections) == 1:
# Single table - use base filename
filename = f"{baseFilename}.csv" filename = f"{baseFilename}.csv"
else: else:
# Multiple tables - add index or section title to filename sectionId = csvSection.get("id", f"csv_{i+1}")
sectionId = tableSection.get("id", f"table_{i+1}")
# Use section title if available, otherwise use section ID
if sectionTitle: if sectionTitle:
# Sanitize section title for filename
safeTitle = "".join(c for c in sectionTitle if c.isalnum() or c in (' ', '-', '_')).strip() safeTitle = "".join(c for c in sectionTitle if c.isalnum() or c in (' ', '-', '_')).strip()
safeTitle = safeTitle.replace(' ', '_')[:30] # Limit length safeTitle = safeTitle.replace(' ', '_')[:30]
filename = f"{baseFilename}_{safeTitle}.csv" filename = f"{baseFilename}_{safeTitle}.csv"
else: else:
filename = f"{baseFilename}_{sectionId}.csv" filename = f"{baseFilename}_{sectionId}.csv"
# Extract document type from metadata
documentType = metadata.get("documentType") if isinstance(metadata, dict) else None documentType = metadata.get("documentType") if isinstance(metadata, dict) else None
renderedDocuments.append( renderedDocuments.append(

View file

@ -9,7 +9,6 @@ Features can register callbacks to be notified when automations change.
import logging import logging
from typing import Callable, List, Dict, Any from typing import Callable, List, Dict, Any
import asyncio
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -25,7 +24,7 @@ class CallbackRegistry:
Args: Args:
event_type: Type of event (e.g., 'automation.changed') event_type: Type of event (e.g., 'automation.changed')
callback: Async or sync callback function callback: Sync callback function
""" """
if event_type not in self._callbacks: if event_type not in self._callbacks:
self._callbacks[event_type] = [] self._callbacks[event_type] = []
@ -41,8 +40,8 @@ class CallbackRegistry:
except ValueError: except ValueError:
logger.warning(f"Callback not found for event type: {event_type}") logger.warning(f"Callback not found for event type: {event_type}")
async def trigger(self, event_type: str, *args, **kwargs): def trigger(self, event_type: str, *args, **kwargs):
"""Trigger all callbacks registered for an event type. """Trigger all registered callbacks for an event type.
Args: Args:
event_type: Type of event to trigger event_type: Type of event to trigger
@ -55,18 +54,14 @@ class CallbackRegistry:
for callback in callbacks: for callback in callbacks:
try: try:
if asyncio.iscoroutinefunction(callback): callback(*args, **kwargs)
await callback(*args, **kwargs)
else:
callback(*args, **kwargs)
except Exception as e: except Exception as e:
logger.error(f"Error executing callback for {event_type}: {str(e)}", exc_info=True) logger.error(f"Error executing callback for {event_type}: {str(e)}", exc_info=True)
def has_callbacks(self, event_type: str) -> bool: def hasCallbacks(self, event_type: str) -> bool:
"""Check if there are any callbacks registered for an event type.""" """Check if there are any callbacks registered for an event type."""
return event_type in self._callbacks and len(self._callbacks[event_type]) > 0 return event_type in self._callbacks and len(self._callbacks[event_type]) > 0
# Global singleton instance # Global singleton instance
callbackRegistry = CallbackRegistry() callbackRegistry = CallbackRegistry()

View file

@ -38,16 +38,14 @@ async def chatStart(currentUser: User, userInput: UserInputRequest, workflowMode
featureCode: Feature code (e.g., 'chatplayground', 'automation') featureCode: Feature code (e.g., 'chatplayground', 'automation')
""" """
try: try:
services = getServices(currentUser, mandateId=mandateId) services = getServices(currentUser, mandateId=mandateId, featureInstanceId=featureInstanceId)
# Store allowedProviders in services context for model selection # Store allowedProviders in services context for model selection
if hasattr(userInput, 'allowedProviders') and userInput.allowedProviders: if hasattr(userInput, 'allowedProviders') and userInput.allowedProviders:
services.allowedProviders = userInput.allowedProviders services.allowedProviders = userInput.allowedProviders
logger.info(f"AI provider filter active: {userInput.allowedProviders}") logger.info(f"AI provider filter active: {userInput.allowedProviders}")
# Store feature context in services (for billing and RBAC) # Store feature code in services (for billing)
if featureInstanceId:
services.featureInstanceId = featureInstanceId
if featureCode: if featureCode:
services.featureCode = featureCode services.featureCode = featureCode
@ -61,10 +59,8 @@ async def chatStart(currentUser: User, userInput: UserInputRequest, workflowMode
async def chatStop(currentUser: User, workflowId: str, mandateId: Optional[str] = None, featureInstanceId: Optional[str] = None) -> ChatWorkflow: async def chatStop(currentUser: User, workflowId: str, mandateId: Optional[str] = None, featureInstanceId: Optional[str] = None) -> ChatWorkflow:
"""Stops a running chat.""" """Stops a running chat."""
try: try:
services = getServices(currentUser, mandateId=mandateId) services = getServices(currentUser, mandateId=mandateId, featureInstanceId=featureInstanceId)
# Store feature instance ID in services context for proper RBAC filtering
if featureInstanceId: if featureInstanceId:
services.featureInstanceId = featureInstanceId
services.featureCode = 'chatplayground' services.featureCode = 'chatplayground'
workflowManager = WorkflowManager(services) workflowManager = WorkflowManager(services)
return await workflowManager.workflowStop(workflowId) return await workflowManager.workflowStop(workflowId)
@ -73,12 +69,17 @@ async def chatStop(currentUser: User, workflowId: str, mandateId: Optional[str]
raise raise
async def executeAutomation(automationId: str, services) -> ChatWorkflow: async def executeAutomation(automationId: str, automation, creatorUser: User, services) -> ChatWorkflow:
"""Execute automation workflow immediately (test mode) with placeholder replacement. """Execute automation workflow with the creator user's context.
The automation object and creatorUser are resolved by the caller (handler)
using the SysAdmin eventUser. This function does NOT re-load them.
Args: Args:
automationId: ID of automation to execute automationId: ID of automation to execute
services: Services instance for data access automation: Pre-loaded automation object (with system fields like _createdBy)
creatorUser: The user who created the automation (workflow runs in this context)
services: Services instance (used for interfaceDbApp etc.)
Returns: Returns:
ChatWorkflow instance created by automation execution ChatWorkflow instance created by automation execution
@ -92,11 +93,6 @@ async def executeAutomation(automationId: str, services) -> ChatWorkflow:
} }
try: try:
# 1. Load automation definition (with system fields for _createdBy access)
automation = services.interfaceDbAutomation.getAutomationDefinition(automationId, includeSystemFields=True)
if not automation:
raise ValueError(f"Automation {automationId} not found")
executionLog["messages"].append(f"Started execution at {executionStartTime}") executionLog["messages"].append(f"Started execution at {executionStartTime}")
# Store allowed providers from automation in services context # Store allowed providers from automation in services context
@ -105,12 +101,12 @@ async def executeAutomation(automationId: str, services) -> ChatWorkflow:
logger.debug(f"Automation {automationId} restricted to providers: {automation.allowedProviders}") logger.debug(f"Automation {automationId} restricted to providers: {automation.allowedProviders}")
# Context comes EXCLUSIVELY from the automation definition # Context comes EXCLUSIVELY from the automation definition
services.mandateId = str(automation.mandateId) automationMandateId = str(automation.mandateId)
services.featureInstanceId = str(automation.featureInstanceId) automationFeatureInstanceId = str(automation.featureInstanceId)
services.featureCode = 'automation'
featureInstanceId = services.featureInstanceId
# 2. Replace placeholders in template to generate plan logger.info(f"Executing automation {automationId} as user {creatorUser.id} with mandateId={automationMandateId}, featureInstanceId={automationFeatureInstanceId}")
# 1. Replace placeholders in template to generate plan
template = automation.template or "" template = automation.template or ""
placeholders = automation.placeholders or {} placeholders = automation.placeholders or {}
planJson = replacePlaceholders(template, placeholders) planJson = replacePlaceholders(template, placeholders)
@ -128,24 +124,9 @@ async def executeAutomation(automationId: str, services) -> ChatWorkflow:
logger.error(f"Context around error: ...{planJson[start:end]}...") logger.error(f"Context around error: ...{planJson[start:end]}...")
raise ValueError(f"Invalid JSON after placeholder replacement: {str(e)}") raise ValueError(f"Invalid JSON after placeholder replacement: {str(e)}")
executionLog["messages"].append("Template placeholders replaced successfully") executionLog["messages"].append("Template placeholders replaced successfully")
executionLog["messages"].append(f"Using creator user: {creatorUser.id}")
# 3. Get user who created automation # 2. Create UserInputRequest from plan
creatorUserId = getattr(automation, "_createdBy", None)
# _createdBy is a system attribute - must be present
if not creatorUserId:
errorMsg = f"Automation {automationId} has no creator user (_createdBy field missing). Cannot execute automation."
logger.error(errorMsg)
executionLog["messages"].append(errorMsg)
raise ValueError(errorMsg)
# Get creator user from database
creatorUser = services.interfaceDbApp.getUser(creatorUserId)
if not creatorUser:
raise ValueError(f"Creator user {creatorUserId} not found")
executionLog["messages"].append(f"Using creator user: {creatorUserId}")
# 4. Create UserInputRequest from plan
# Embed plan JSON in prompt for TemplateMode to extract # Embed plan JSON in prompt for TemplateMode to extract
promptText = planToPrompt(plan) promptText = planToPrompt(plan)
planJsonStr = json.dumps(plan) planJsonStr = json.dumps(plan)
@ -160,16 +141,15 @@ async def executeAutomation(automationId: str, services) -> ChatWorkflow:
executionLog["messages"].append("Starting workflow execution") executionLog["messages"].append("Starting workflow execution")
# 5. Start workflow using chatStart # 3. Start workflow using chatStart with creator's context
# Pass mandateId, featureInstanceId, and featureCode from original services context # mandateId and featureInstanceId come from the automation definition
# so billing is recorded correctly with full feature context
workflow = await chatStart( workflow = await chatStart(
currentUser=creatorUser, currentUser=creatorUser,
userInput=userInput, userInput=userInput,
workflowMode=WorkflowModeEnum.WORKFLOW_AUTOMATION, workflowMode=WorkflowModeEnum.WORKFLOW_AUTOMATION,
workflowId=None, workflowId=None,
mandateId=services.mandateId, mandateId=automationMandateId,
featureInstanceId=featureInstanceId, featureInstanceId=automationFeatureInstanceId,
featureCode='automation' featureCode='automation'
) )
@ -200,22 +180,22 @@ async def executeAutomation(automationId: str, services) -> ChatWorkflow:
executionLog["messages"].append(f"Error: {str(e)}") executionLog["messages"].append(f"Error: {str(e)}")
# Save execution log even on error (bypasses RBAC — system operation) # Save execution log even on error (bypasses RBAC — system operation)
# Use the automation object already passed in (no re-load needed)
try: try:
automation = services.interfaceDbAutomation.getAutomationDefinition(automationId) executionLogs = list(getattr(automation, 'executionLogs', None) or [])
if automation: executionLogs.append(executionLog)
executionLogs = list(automation.executionLogs or []) if len(executionLogs) > 50:
executionLogs.append(executionLog) executionLogs = executionLogs[-50:]
if len(executionLogs) > 50: services.interfaceDbAutomation._saveExecutionLog(automationId, executionLogs)
executionLogs = executionLogs[-50:]
services.interfaceDbAutomation._saveExecutionLog(automationId, executionLogs)
except Exception as logError: except Exception as logError:
logger.error(f"Error saving execution log: {str(logError)}") logger.error(f"Error saving execution log: {str(logError)}")
raise raise
async def syncAutomationEvents(services, eventUser) -> Dict[str, Any]: def syncAutomationEvents(services, eventUser) -> Dict[str, Any]:
"""Automation event handler - syncs scheduler with all active automations. """Sync scheduler with all active automations.
All operations (DB reads, scheduler registration) are synchronous.
Args: Args:
services: Services instance for data access services: Services instance for data access
@ -316,37 +296,28 @@ def createAutomationEventHandler(automationId: str, eventUser):
logger.error("Event user not available for automation execution") logger.error("Event user not available for automation execution")
return return
# Get services for event user (provides access to interfaces) # Load automation using SysAdmin eventUser (has unrestricted access)
eventServices = getServices(eventUser, None) eventServices = getServices(eventUser, None)
# Load automation using event user context (with system fields for _createdBy access)
automation = eventServices.interfaceDbAutomation.getAutomationDefinition(automationId, includeSystemFields=True) automation = eventServices.interfaceDbAutomation.getAutomationDefinition(automationId, includeSystemFields=True)
if not automation or not getattr(automation, "active", False): if not automation or not getattr(automation, "active", False):
logger.warning(f"Automation {automationId} not found or not active, skipping execution") logger.warning(f"Automation {automationId} not found or not active, skipping execution")
return return
# Get creator user # Get creator user ID from automation's _createdBy system field
creatorUserId = getattr(automation, "_createdBy", None) creatorUserId = getattr(automation, "_createdBy", None)
if not creatorUserId: if not creatorUserId:
logger.error(f"Automation {automationId} has no creator user") logger.error(f"Automation {automationId} has no creator user (_createdBy missing)")
return return
# Get mandate context from automation definition # Get creator user from database (using SysAdmin access)
automationMandateId = getattr(automation, "mandateId", None)
# Get creator user from database using services
eventServices = getServices(eventUser, None)
creatorUser = eventServices.interfaceDbApp.getUser(creatorUserId) creatorUser = eventServices.interfaceDbApp.getUser(creatorUserId)
if not creatorUser: if not creatorUser:
logger.error(f"Creator user {creatorUserId} not found for automation {automationId}") logger.error(f"Creator user {creatorUserId} not found for automation {automationId}")
return return
# Get services for creator user WITH mandate context from automation # Execute automation — pass automation object and creatorUser directly
creatorServices = getServices(creatorUser, automationMandateId) # No re-load needed in executeAutomation
await executeAutomation(automationId, automation, creatorUser, eventServices)
# Execute automation with creator user's context and mandate
# executeAutomation is in same module, so we can call it directly
await executeAutomation(automationId, creatorServices)
logger.info(f"Successfully executed automation {automationId} as user {creatorUserId}") logger.info(f"Successfully executed automation {automationId} as user {creatorUserId}")
except Exception as e: except Exception as e:
logger.error(f"Error executing automation {automationId}: {str(e)}") logger.error(f"Error executing automation {automationId}: {str(e)}")

View file

@ -14,9 +14,10 @@ from modules.services import getInterface as getServices
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
async def start(eventUser) -> None: def start(eventUser) -> bool:
""" """
Start automation scheduler and sync scheduled events. Start automation scheduler and sync scheduled events.
All operations are synchronous (DB access, scheduler registration).
Args: Args:
eventUser: System-level event user for background operations (provided by app.py) eventUser: System-level event user for background operations (provided by app.py)
@ -33,16 +34,16 @@ async def start(eventUser) -> None:
services = getServices(eventUser, None) services = getServices(eventUser, None)
# Register callback for automation changes # Register callback for automation changes
async def onAutomationChanged(chatInterface): def onAutomationChanged(chatInterface):
"""Callback triggered when automations are created/updated/deleted.""" """Callback triggered when automations are created/updated/deleted."""
eventServices = getServices(eventUser, None) eventServices = getServices(eventUser, None)
await syncAutomationEvents(eventServices, eventUser) syncAutomationEvents(eventServices, eventUser)
callbackRegistry.register('automation.changed', onAutomationChanged) callbackRegistry.register('automation.changed', onAutomationChanged)
logger.info("Automation: Registered change callback") logger.info("Automation: Registered change callback")
# Initial sync on startup # Initial sync on startup
await syncAutomationEvents(services, eventUser) syncAutomationEvents(services, eventUser)
logger.info("Automation: Scheduled events synced on startup") logger.info("Automation: Scheduled events synced on startup")
except Exception as e: except Exception as e:
@ -52,7 +53,7 @@ async def start(eventUser) -> None:
return True return True
async def stop(eventUser) -> None: def stop(eventUser) -> bool:
""" """
Stop automation scheduler. Stop automation scheduler.

View file

@ -139,11 +139,16 @@ class MethodBase:
return False return False
# RBAC-Check: RESOURCE context, item = actionId # RBAC-Check: RESOURCE context, item = actionId
# mandateId/featureInstanceId from services context needed to resolve user roles
try: try:
mandateId = getattr(self.services, 'mandateId', None)
featureInstanceId = getattr(self.services, 'featureInstanceId', None)
permissions = self.services.rbac.getUserPermissions( permissions = self.services.rbac.getUserPermissions(
user=currentUser, user=currentUser,
context=AccessRuleContext.RESOURCE, context=AccessRuleContext.RESOURCE,
item=actionId item=actionId,
mandateId=str(mandateId) if mandateId else None,
featureInstanceId=str(featureInstanceId) if featureInstanceId else None
) )
hasPermission = permissions.view hasPermission = permissions.view
if not hasPermission: if not hasPermission:
@ -151,8 +156,9 @@ class MethodBase:
userRoles = getattr(currentUser, 'roleLabels', []) or [] userRoles = getattr(currentUser, 'roleLabels', []) or []
self.logger.warning( self.logger.warning(
f"RBAC denied action {actionId} for user {currentUser.id}. " f"RBAC denied action {actionId} for user {currentUser.id}. "
f"User roles: {userRoles}, " f"User roles: {userRoles}, mandateId={mandateId}, "
f"Permissions: view={permissions.view}, edit={permissions.edit}, delete={permissions.delete}. " f"Permissions: view={permissions.view}, read={permissions.read}, "
f"create={permissions.create}, update={permissions.update}, delete={permissions.delete}. "
f"No matching RBAC rule found for context=RESOURCE, item={actionId}" f"No matching RBAC rule found for context=RESOURCE, item={actionId}"
) )
return hasPermission return hasPermission