Merge pull request #92 from valueonag/feat/cost-control
Feat/cost control
This commit is contained in:
commit
18cdbe4c5c
105 changed files with 9311 additions and 2560 deletions
46
app.py
46
app.py
|
|
@ -286,6 +286,15 @@ instanceLabel = APP_CONFIG.get("APP_ENV_LABEL")
|
|||
async def lifespan(app: FastAPI):
|
||||
logger.info("Application is starting up")
|
||||
|
||||
# --- Register RBAC catalog for features (moved here from loadFeatureRouters for single-pass loading) ---
|
||||
try:
|
||||
from modules.security.rbacCatalog import getCatalogService
|
||||
from modules.system.registry import registerAllFeaturesInCatalog
|
||||
catalogService = getCatalogService()
|
||||
registerAllFeaturesInCatalog(catalogService)
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not register feature RBAC catalog: {e}")
|
||||
|
||||
# Get event user for feature lifecycle (system-level user for background operations)
|
||||
rootInterface = getRootInterface()
|
||||
eventUser = rootInterface.getUserByUsername("event")
|
||||
|
|
@ -306,18 +315,37 @@ async def lifespan(app: FastAPI):
|
|||
logger.warning(f"Could not initialize feature containers: {e}")
|
||||
|
||||
# --- Init Managers ---
|
||||
await subAutomationSchedule.start(eventUser) # Automation scheduler
|
||||
subAutomationSchedule.start(eventUser) # Automation scheduler
|
||||
eventManager.start()
|
||||
|
||||
# Register audit log cleanup scheduler
|
||||
from modules.shared.auditLogger import registerAuditLogCleanupScheduler
|
||||
registerAuditLogCleanupScheduler()
|
||||
|
||||
# Ensure billing settings and accounts exist for all mandates
|
||||
try:
|
||||
from modules.interfaces.interfaceDbBilling import _getRootInterface as getBillingRootInterface
|
||||
|
||||
billingInterface = getBillingRootInterface()
|
||||
|
||||
# Step 1: Ensure all mandates have billing settings (creates defaults if missing)
|
||||
settingsCreated = billingInterface.ensureAllMandateSettingsExist()
|
||||
if settingsCreated > 0:
|
||||
logger.info(f"Billing startup: Created {settingsCreated} missing mandate billing settings")
|
||||
|
||||
# Step 2: Ensure all users have billing accounts (for PREPAY_USER mandates)
|
||||
accountsCreated = billingInterface.ensureAllUserAccountsExist()
|
||||
if accountsCreated > 0:
|
||||
logger.info(f"Billing startup: Created {accountsCreated} missing user accounts")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to ensure billing settings/accounts (non-critical): {e}")
|
||||
|
||||
yield
|
||||
|
||||
# --- Stop Managers ---
|
||||
eventManager.stop()
|
||||
await subAutomationSchedule.stop(eventUser) # Automation scheduler
|
||||
subAutomationSchedule.stop(eventUser) # Automation scheduler
|
||||
|
||||
# --- Stop Feature Containers (Plug&Play) ---
|
||||
try:
|
||||
|
|
@ -404,10 +432,16 @@ def getAllowedOrigins():
|
|||
return origins
|
||||
|
||||
|
||||
# CORS origin regex pattern for wildcard subdomain support
|
||||
# Matches all subdomains of poweron.swiss and poweron-center.net
|
||||
CORS_ORIGIN_REGEX = r"https://.*\.(poweron\.swiss|poweron-center\.net)"
|
||||
|
||||
|
||||
# CORS configuration using environment variables
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=getAllowedOrigins(),
|
||||
allow_origin_regex=CORS_ORIGIN_REGEX,
|
||||
allow_credentials=True,
|
||||
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
|
||||
allow_headers=["*"],
|
||||
|
|
@ -485,7 +519,6 @@ app.include_router(rbacAdminRulesRouter)
|
|||
from modules.routes.routeMessaging import router as messagingRouter
|
||||
app.include_router(messagingRouter)
|
||||
|
||||
# Phase 8: New Feature Routes
|
||||
from modules.routes.routeAdminFeatures import router as featuresAdminRouter
|
||||
app.include_router(featuresAdminRouter)
|
||||
|
||||
|
|
@ -504,11 +537,8 @@ app.include_router(userAccessOverviewRouter)
|
|||
from modules.routes.routeGdpr import router as gdprRouter
|
||||
app.include_router(gdprRouter)
|
||||
|
||||
from modules.routes.routeChat import router as chatRouter
|
||||
app.include_router(chatRouter)
|
||||
|
||||
from modules.features.chatbot.routeFeatureChatbot import router as chatbotFeatureRouter
|
||||
app.include_router(chatbotFeatureRouter)
|
||||
from modules.routes.routeBilling import router as billingRouter
|
||||
app.include_router(billingRouter)
|
||||
|
||||
# ============================================================================
|
||||
# SYSTEM ROUTES (Navigation, etc.)
|
||||
|
|
|
|||
|
|
@ -40,6 +40,7 @@ Connector_AiOpenai_API_SECRET = DEV_ENC:Z0FBQUFBQnBaSnM4TWFRRmxVQmNQblVIYmc1Y0Q3
|
|||
Connector_AiAnthropic_API_SECRET = DEV_ENC:Z0FBQUFBQm8xSUpENmFBWG16STFQUVZxNzZZRzRLYTA4X3lRanF1VkF4cU45OExNMzlsQmdISGFxTUxud1dXODBKcFhMVG9KNjdWVnlTTFFROVc3NDlsdlNHLUJXeG41NDBHaXhHR0VHVWl5UW9RNkVWbmlhakRKVW5pM0R4VHk0LUw0TV9LdkljNHdBLXJua21NQkl2b3l4UkVkMGN1YjBrMmJEeWtMay1jbmxrYWJNbUV0aktCXzU1djR2d2RSQXZORTNwcG92ZUVvVGMtQzQzTTVncEZTRGRtZUFIZWQ0dz09
|
||||
Connector_AiPerplexity_API_SECRET = DEV_ENC:Z0FBQUFBQm82Mzk2Q1MwZ0dNcUVBcUtuRDJIcTZkMXVvYnpjM3JEMzJiT1NKSHljX282ZDIyZTJYc09VSTdVNXAtOWU2UXp5S193NTk5dHJsWlFjRjhWektFOG1DVGY4ZUhHTXMzS0RPN1lNcF9nSlVWbW5BZ1hkZDVTejl6bVZNRFVvX29xamJidWRFMmtjQmkyRUQ2RUh6UTN1aWNPSUJBPT0=
|
||||
Connector_AiTavily_API_SECRET = DEV_ENC:Z0FBQUFBQm8xSUpEQTdnUHMwd2pIaXNtMmtCTFREd0pyQXRKb1F5eGtHSnkyOGZiUnlBOFc0b3Vzcndrc3ViRm1nMDJIOEZKYWxqdWNkZGh5N0Z4R0JlQmxXSG5pVnJUR2VYckZhMWNMZ1FNeXJ3enJLVlpiblhOZTNleUg3ZzZyUzRZanFSeDlVMkI=
|
||||
Connector_AiPrivateLlm_API_SECRET = jL4vyNfh_tv4rxoRaHKW88sVWNHbj32GsxuKE2A8bf0
|
||||
|
||||
# Microsoft Service Configuration
|
||||
Service_MSFT_CLIENT_ID = c7e7112d-61dc-4f3a-8cd3-08cc4cd7504c
|
||||
|
|
|
|||
|
|
@ -40,6 +40,7 @@ Connector_AiOpenai_API_SECRET = INT_ENC:Z0FBQUFBQnBaSnM4MENkQ2xJVmE5WFZKUkh2SHJF
|
|||
Connector_AiAnthropic_API_SECRET = INT_ENC:Z0FBQUFBQm8xSVRjT1ZlRWVJdVZMT3ljSFJDcFdxRFBRVkZhS204NnN5RDBlQ0tpenhTM0FFVktuWW9mWHNwRWx2dHB0eDBSZ0JFQnZKWlp6c01pVGREWHd1eGpERnU0Q2xhaks1clQ1ZXVsdnd2ZzhpNXNQS1BhY3FjSkdkVEhHalNaRGR4emhpakZncnpDQUVxOHVXQzVUWmtQc0FsYmFwTF9TSG5FOUFtWk5Ick1NcHFvY2s1T1c2WXlRUFFJZnh6TWhuaVpMYmppcDR0QUx0a0R6RXlwbGRYb1R4dzJkUT09
|
||||
Connector_AiPerplexity_API_SECRET = INT_ENC:Z0FBQUFBQm82Mzk2UWZJdUFhSW8yc3RKc0tKRXphd0xWMkZOVlFpSGZ4SGhFWnk0cTF5VjlKQVZjdS1QSWdkS0pUSWw4OFU5MjUxdTVQel9aeWVIZTZ5TXRuVmFkZG0zWEdTOGdHMHpsTzI0TGlWYURKU1Q0VVpKTlhxUk5FTmN6SUJScDZ3ZldIaUJZcWpaQVRiSEpyQm9tRTNDWk9KTnZBPT0=
|
||||
Connector_AiTavily_API_SECRET = INT_ENC:Z0FBQUFBQm8xSVRkdkJMTDY0akhXNzZDWHVYSEt1cDZoOWEzSktneHZEV2JndTNmWlNSMV9KbFNIZmQzeVlrNE5qUEIwcUlBSGM1a0hOZ3J6djIyOVhnZzI3M1dIUkdicl9FVXF3RGktMmlEYmhnaHJfWTdGUkktSXVUSGdQMC1vSEV6VE8zR2F1SVk=
|
||||
Connector_AiPrivateLlm_API_SECRET = jL4vyNfh_tv4rxoRaHKW88sVWNHbj32GsxuKE2A8bf0
|
||||
|
||||
# Microsoft Service Configuration
|
||||
Service_MSFT_CLIENT_ID = c7e7112d-61dc-4f3a-8cd3-08cc4cd7504c
|
||||
|
|
|
|||
|
|
@ -40,6 +40,7 @@ Connector_AiOpenai_API_SECRET = PROD_ENC:Z0FBQUFBQnBaSnM4TWJOVm4xVkx6azRlNDdxN3U
|
|||
Connector_AiAnthropic_API_SECRET = PROD_ENC:Z0FBQUFBQnBDM1Z3TnhYdlhSLW5RbXJyMHFXX0V0bHhuTDlTaFJsRDl2dTdIUTFtVFAwTE8tY3hLbzNSMnVTLXd3RUZualN3MGNzc1kwOTIxVUN2WW1rYi1TendFRVVBSVNqRFVjckEzNExyTGNaUkJLMmozazUwemI1cnhrcEtZVXJrWkdaVFFramp3MWZ6RmY2aGlRMXVEYjM2M3ZlbmxMdnNCRDM1QWR0Wmd6MWVnS1I1c01nV3hRLXg3d2NTZXVfTi1Wdm16UnRyNGsyRTZ0bG9TQ1g1OFB5Z002bmQ3QT09
|
||||
Connector_AiPerplexity_API_SECRET = PROD_ENC:Z0FBQUFBQm82Mzk2Q1FGRkJEUkI4LXlQbHYzT2RkdVJEcmM4WGdZTWpJTEhoeUF1NW5LUVpJdDBYN3k1WFN4a2FQSWJSQmd0U0xJbzZDTmFFN05FcXl0Z3V1OEpsZjYydV94TXVjVjVXRTRYSWdLMkd5XzZIbFV6emRCZHpuOUpQeThadE5xcDNDVGV1RHJrUEN0c1BBYXctZFNWcFRuVXhRPT0=
|
||||
Connector_AiTavily_API_SECRET = PROD_ENC:Z0FBQUFBQnBDM1Z3NmItcDh6V0JpcE5Jc0NlUWZqcmllRHB5eDlNZmVnUlNVenhNTm5xWExzbjJqdE1GZ0hTSUYtb2dvdWNhTnlQNmVWQ2NGVDgwZ0MwMWZBMlNKWEhzdlF3TlZzTXhCZWM4Z1Uwb18tSTRoU1JBVTVkSkJHOTJwX291b3dPaVphVFg=
|
||||
Connector_AiPrivateLlm_API_SECRET = jL4vyNfh_tv4rxoRaHKW88sVWNHbj32GsxuKE2A8bf0
|
||||
|
||||
# Microsoft Service Configuration
|
||||
Service_MSFT_CLIENT_ID = c7e7112d-61dc-4f3a-8cd3-08cc4cd7504c
|
||||
|
|
|
|||
|
|
@ -73,12 +73,14 @@ class ModelSelector:
|
|||
contextSize = len(context.encode("utf-8"))
|
||||
totalSize = promptSize + contextSize
|
||||
# Convert bytes to approximate tokens
|
||||
# Conservative estimate: 1 token ≈ 2 bytes (for safety margin)
|
||||
# Balanced estimate: 1 token ≈ 3 bytes
|
||||
# Note: Actual tokenization varies by content type and model
|
||||
# - English text: ~4 bytes/token
|
||||
# - Structured data/JSON: ~2-3 bytes/token
|
||||
# - German/European text: ~3.5 bytes/token
|
||||
# - Structured data/JSON: ~2.5-3 bytes/token
|
||||
# - Base64/encoded data: ~1.5-2 bytes/token
|
||||
bytesPerToken = 2 # Conservative estimate for mixed content
|
||||
# Using 3 as balanced estimate (previously 2 which overestimated by ~2x)
|
||||
bytesPerToken = 3 # Balanced estimate for mixed content
|
||||
promptTokens = promptSize / bytesPerToken
|
||||
contextTokens = contextSize / bytesPerToken
|
||||
totalTokens = totalSize / bytesPerToken
|
||||
|
|
@ -98,9 +100,16 @@ class ModelSelector:
|
|||
logger.debug(f"Models with {options.operationType.value}: {[m.name for m in operationFiltered]}")
|
||||
|
||||
# Step 2: Filter by prompt size (MUST be <= 80% of context size)
|
||||
# AND by maxInputTokensPerRequest (provider rate limit / TPM)
|
||||
# Note: contextLength is in tokens, so we need to compare tokens with tokens
|
||||
promptFiltered = []
|
||||
for model in operationFiltered:
|
||||
# Check provider rate limit first (maxInputTokensPerRequest)
|
||||
maxRequestTokens = getattr(model, 'maxInputTokensPerRequest', None)
|
||||
if maxRequestTokens and maxRequestTokens > 0 and totalTokens > maxRequestTokens:
|
||||
logger.debug(f"Model {model.name} filtered out: totalTokens={totalTokens:.0f} > maxInputTokensPerRequest={maxRequestTokens} (provider rate limit)")
|
||||
continue
|
||||
|
||||
if model.contextLength == 0:
|
||||
# No context length limit - always pass
|
||||
promptFiltered.append(model)
|
||||
|
|
|
|||
|
|
@ -46,7 +46,6 @@ class AiAnthropic(BaseConnectorAi):
|
|||
return "anthropic"
|
||||
|
||||
def getModels(self) -> List[AiModel]:
|
||||
# return [] # TODO: DEBUG TO TURN ON AFTER TESTING
|
||||
# Get all available Anthropic models.
|
||||
return [
|
||||
AiModel(
|
||||
|
|
@ -57,11 +56,10 @@ class AiAnthropic(BaseConnectorAi):
|
|||
temperature=0.2,
|
||||
maxTokens=8192,
|
||||
contextLength=200000,
|
||||
costPer1kTokensInput=0.015,
|
||||
costPer1kTokensOutput=0.075,
|
||||
costPer1kTokensInput=0.003, # $3/M tokens (updated 2026-02)
|
||||
costPer1kTokensOutput=0.015, # $15/M tokens (updated 2026-02)
|
||||
speedRating=6, # Slower due to high-quality processing
|
||||
qualityRating=10, # Best quality available
|
||||
# capabilities removed (not used in business logic)
|
||||
functionCall=self.callAiBasic,
|
||||
priority=PriorityEnum.QUALITY,
|
||||
processingMode=ProcessingModeEnum.DETAILED,
|
||||
|
|
@ -72,7 +70,55 @@ class AiAnthropic(BaseConnectorAi):
|
|||
(OperationTypeEnum.DATA_EXTRACT, 8)
|
||||
),
|
||||
version="claude-sonnet-4-5-20250929",
|
||||
calculatePriceUsd=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.015 + (bytesReceived / 4 / 1000) * 0.075
|
||||
calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.003 + (bytesReceived / 4 / 1000) * 0.015
|
||||
),
|
||||
AiModel(
|
||||
name="claude-haiku-4-5-20251001",
|
||||
displayName="Anthropic Claude Haiku 4.5",
|
||||
connectorType="anthropic",
|
||||
apiUrl="https://api.anthropic.com/v1/messages",
|
||||
temperature=0.2,
|
||||
maxTokens=8192,
|
||||
contextLength=200000,
|
||||
costPer1kTokensInput=0.001, # $1/M tokens (updated 2026-02)
|
||||
costPer1kTokensOutput=0.005, # $5/M tokens (updated 2026-02)
|
||||
speedRating=9, # Very fast, lightweight model
|
||||
qualityRating=8, # Good quality, cost-efficient
|
||||
functionCall=self.callAiBasic,
|
||||
priority=PriorityEnum.SPEED,
|
||||
processingMode=ProcessingModeEnum.BASIC,
|
||||
operationTypes=createOperationTypeRatings(
|
||||
(OperationTypeEnum.PLAN, 8),
|
||||
(OperationTypeEnum.DATA_ANALYSE, 8),
|
||||
(OperationTypeEnum.DATA_GENERATE, 8),
|
||||
(OperationTypeEnum.DATA_EXTRACT, 7)
|
||||
),
|
||||
version="claude-haiku-4-5-20251001",
|
||||
calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.001 + (bytesReceived / 4 / 1000) * 0.005
|
||||
),
|
||||
AiModel(
|
||||
name="claude-opus-4-6",
|
||||
displayName="Anthropic Claude Opus 4.6",
|
||||
connectorType="anthropic",
|
||||
apiUrl="https://api.anthropic.com/v1/messages",
|
||||
temperature=0.2,
|
||||
maxTokens=8192,
|
||||
contextLength=200000,
|
||||
costPer1kTokensInput=0.005, # $5/M tokens (updated 2026-02)
|
||||
costPer1kTokensOutput=0.025, # $25/M tokens (updated 2026-02)
|
||||
speedRating=5, # Moderate latency, most capable
|
||||
qualityRating=10, # Top-tier intelligence
|
||||
functionCall=self.callAiBasic,
|
||||
priority=PriorityEnum.QUALITY,
|
||||
processingMode=ProcessingModeEnum.DETAILED,
|
||||
operationTypes=createOperationTypeRatings(
|
||||
(OperationTypeEnum.PLAN, 10),
|
||||
(OperationTypeEnum.DATA_ANALYSE, 10),
|
||||
(OperationTypeEnum.DATA_GENERATE, 10),
|
||||
(OperationTypeEnum.DATA_EXTRACT, 9)
|
||||
),
|
||||
version="claude-opus-4-6",
|
||||
calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.005 + (bytesReceived / 4 / 1000) * 0.025
|
||||
),
|
||||
AiModel(
|
||||
name="claude-sonnet-4-5-20250929",
|
||||
|
|
@ -82,8 +128,8 @@ class AiAnthropic(BaseConnectorAi):
|
|||
temperature=0.2,
|
||||
maxTokens=8192,
|
||||
contextLength=200000,
|
||||
costPer1kTokensInput=0.015,
|
||||
costPer1kTokensOutput=0.075,
|
||||
costPer1kTokensInput=0.003, # $3/M tokens (updated 2026-02)
|
||||
costPer1kTokensOutput=0.015, # $15/M tokens (updated 2026-02)
|
||||
speedRating=6,
|
||||
qualityRating=10,
|
||||
functionCall=self.callAiImage,
|
||||
|
|
@ -93,7 +139,7 @@ class AiAnthropic(BaseConnectorAi):
|
|||
(OperationTypeEnum.IMAGE_ANALYSE, 10)
|
||||
),
|
||||
version="claude-sonnet-4-5-20250929",
|
||||
calculatePriceUsd=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.015 + (bytesReceived / 4 / 1000) * 0.075
|
||||
calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.003 + (bytesReceived / 4 / 1000) * 0.015
|
||||
)
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ class AiInternal(BaseConnectorAi):
|
|||
processingMode=ProcessingModeEnum.BASIC,
|
||||
operationTypes=createOperationTypeRatings(),
|
||||
version="internal-extractor-v1",
|
||||
calculatePriceUsd=lambda processingTime, bytesSent, bytesReceived: 0.001 + (bytesSent + bytesReceived) / (1024 * 1024) * 0.01
|
||||
calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: 0.001 + (bytesSent + bytesReceived) / (1024 * 1024) * 0.01
|
||||
),
|
||||
AiModel(
|
||||
name="internal-generator",
|
||||
|
|
@ -60,7 +60,7 @@ class AiInternal(BaseConnectorAi):
|
|||
processingMode=ProcessingModeEnum.BASIC,
|
||||
operationTypes=createOperationTypeRatings(),
|
||||
version="internal-generator-v1",
|
||||
calculatePriceUsd=lambda processingTime, bytesSent, bytesReceived: 0.002 + (bytesReceived / (1024 * 1024)) * 0.005
|
||||
calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: 0.002 + (bytesReceived / (1024 * 1024)) * 0.005
|
||||
),
|
||||
AiModel(
|
||||
name="internal-renderer",
|
||||
|
|
@ -80,7 +80,7 @@ class AiInternal(BaseConnectorAi):
|
|||
processingMode=ProcessingModeEnum.DETAILED,
|
||||
operationTypes=createOperationTypeRatings(),
|
||||
version="internal-renderer-v1",
|
||||
calculatePriceUsd=lambda processingTime, bytesSent, bytesReceived: 0.003 + (bytesReceived / (1024 * 1024)) * 0.008
|
||||
calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: 0.003 + (bytesReceived / (1024 * 1024)) * 0.008
|
||||
)
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from typing import List
|
|||
from fastapi import HTTPException
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
from .aicoreBase import BaseConnectorAi
|
||||
from modules.datamodels.datamodelAi import AiModel, PriorityEnum, ProcessingModeEnum, OperationTypeEnum, AiModelCall, AiModelResponse, createOperationTypeRatings
|
||||
from modules.datamodels.datamodelAi import AiModel, PriorityEnum, ProcessingModeEnum, OperationTypeEnum, AiModelCall, AiModelResponse, createOperationTypeRatings, AiCallPromptImage
|
||||
|
||||
# Configure logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -15,6 +15,10 @@ class ContextLengthExceededException(Exception):
|
|||
"""Exception raised when the context length exceeds the model's limit"""
|
||||
pass
|
||||
|
||||
class RateLimitExceededException(Exception):
|
||||
"""Exception raised when the provider's rate limit (TPM) is exceeded"""
|
||||
pass
|
||||
|
||||
def loadConfigData():
|
||||
"""Load configuration data for OpenAI connector"""
|
||||
return {
|
||||
|
|
@ -57,11 +61,11 @@ class AiOpenai(BaseConnectorAi):
|
|||
temperature=0.2,
|
||||
maxTokens=16384,
|
||||
contextLength=128000,
|
||||
costPer1kTokensInput=0.03,
|
||||
costPer1kTokensOutput=0.06,
|
||||
maxInputTokensPerRequest=25000, # OpenAI org TPM limit is 30K, keep 5K buffer
|
||||
costPer1kTokensInput=0.0025, # $2.50/M tokens (updated 2026-02)
|
||||
costPer1kTokensOutput=0.01, # $10.00/M tokens (updated 2026-02)
|
||||
speedRating=8, # Good speed for complex tasks
|
||||
qualityRating=10, # High quality
|
||||
# capabilities removed (not used in business logic)
|
||||
functionCall=self.callAiBasic,
|
||||
priority=PriorityEnum.BALANCED,
|
||||
processingMode=ProcessingModeEnum.ADVANCED,
|
||||
|
|
@ -72,43 +76,44 @@ class AiOpenai(BaseConnectorAi):
|
|||
(OperationTypeEnum.DATA_EXTRACT, 7)
|
||||
),
|
||||
version="gpt-4o",
|
||||
calculatePriceUsd=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.03 + (bytesReceived / 4 / 1000) * 0.06
|
||||
calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.0025 + (bytesReceived / 4 / 1000) * 0.01
|
||||
),
|
||||
AiModel(
|
||||
name="gpt-3.5-turbo",
|
||||
displayName="OpenAI GPT-3.5 Turbo",
|
||||
connectorType="openai",
|
||||
apiUrl="https://api.openai.com/v1/chat/completions",
|
||||
temperature=0.2,
|
||||
maxTokens=4096,
|
||||
contextLength=16000,
|
||||
costPer1kTokensInput=0.0015,
|
||||
costPer1kTokensOutput=0.002,
|
||||
speedRating=9, # Very fast
|
||||
qualityRating=7, # Good but not premium
|
||||
# capabilities removed (not used in business logic)
|
||||
functionCall=self.callAiBasic,
|
||||
priority=PriorityEnum.SPEED,
|
||||
processingMode=ProcessingModeEnum.BASIC,
|
||||
operationTypes=createOperationTypeRatings(
|
||||
(OperationTypeEnum.PLAN, 7),
|
||||
(OperationTypeEnum.DATA_ANALYSE, 8),
|
||||
(OperationTypeEnum.DATA_GENERATE, 8)
|
||||
# Note: GPT-3.5-turbo does NOT support vision/image operations
|
||||
),
|
||||
version="gpt-3.5-turbo",
|
||||
calculatePriceUsd=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.0015 + (bytesReceived / 4 / 1000) * 0.002
|
||||
),
|
||||
AiModel(
|
||||
name="gpt-4o",
|
||||
displayName="OpenAI GPT-4o Instance Vision",
|
||||
name="gpt-4o-mini",
|
||||
displayName="OpenAI GPT-4o Mini",
|
||||
connectorType="openai",
|
||||
apiUrl="https://api.openai.com/v1/chat/completions",
|
||||
temperature=0.2,
|
||||
maxTokens=16384,
|
||||
contextLength=128000,
|
||||
costPer1kTokensInput=0.03,
|
||||
costPer1kTokensOutput=0.06,
|
||||
maxInputTokensPerRequest=25000, # OpenAI org TPM limit, keep buffer
|
||||
costPer1kTokensInput=0.00015, # $0.15/M tokens (updated 2026-02)
|
||||
costPer1kTokensOutput=0.0006, # $0.60/M tokens (updated 2026-02)
|
||||
speedRating=9, # Very fast
|
||||
qualityRating=8, # Good quality, replaces gpt-3.5-turbo
|
||||
functionCall=self.callAiBasic,
|
||||
priority=PriorityEnum.SPEED,
|
||||
processingMode=ProcessingModeEnum.BASIC,
|
||||
operationTypes=createOperationTypeRatings(
|
||||
(OperationTypeEnum.PLAN, 8),
|
||||
(OperationTypeEnum.DATA_ANALYSE, 8),
|
||||
(OperationTypeEnum.DATA_GENERATE, 9),
|
||||
(OperationTypeEnum.DATA_EXTRACT, 7)
|
||||
),
|
||||
version="gpt-4o-mini",
|
||||
calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.00015 + (bytesReceived / 4 / 1000) * 0.0006
|
||||
),
|
||||
AiModel(
|
||||
name="gpt-4o",
|
||||
displayName="OpenAI GPT-4o Vision",
|
||||
connectorType="openai",
|
||||
apiUrl="https://api.openai.com/v1/chat/completions",
|
||||
temperature=0.2,
|
||||
maxTokens=16384,
|
||||
contextLength=128000,
|
||||
maxInputTokensPerRequest=25000, # OpenAI org TPM limit is 30K, keep 5K buffer
|
||||
costPer1kTokensInput=0.0025, # $2.50/M tokens (updated 2026-02)
|
||||
costPer1kTokensOutput=0.01, # $10.00/M tokens (updated 2026-02)
|
||||
speedRating=6, # Slower for vision tasks
|
||||
qualityRating=9, # High quality vision
|
||||
functionCall=self.callAiImage,
|
||||
|
|
@ -118,7 +123,7 @@ class AiOpenai(BaseConnectorAi):
|
|||
(OperationTypeEnum.IMAGE_ANALYSE, 9)
|
||||
),
|
||||
version="gpt-4o",
|
||||
calculatePriceUsd=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.03 + (bytesReceived / 4 / 1000) * 0.06
|
||||
calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.0025 + (bytesReceived / 4 / 1000) * 0.01
|
||||
),
|
||||
AiModel(
|
||||
name="dall-e-3",
|
||||
|
|
@ -140,7 +145,7 @@ class AiOpenai(BaseConnectorAi):
|
|||
(OperationTypeEnum.IMAGE_GENERATE, 10)
|
||||
),
|
||||
version="dall-e-3",
|
||||
calculatePriceUsd=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.04
|
||||
calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.04
|
||||
)
|
||||
]
|
||||
|
||||
|
|
@ -183,6 +188,19 @@ class AiOpenai(BaseConnectorAi):
|
|||
error_message = f"OpenAI API error: {response.status_code} - {response.text}"
|
||||
logger.error(error_message)
|
||||
|
||||
# Check for rate limit exceeded (429 TPM)
|
||||
if response.status_code == 429:
|
||||
try:
|
||||
error_data = response.json()
|
||||
error_msg = error_data.get("error", {}).get("message", "Rate limit exceeded")
|
||||
raise RateLimitExceededException(
|
||||
f"Rate limit exceeded for {model.name}: {error_msg}"
|
||||
)
|
||||
except (ValueError, KeyError):
|
||||
raise RateLimitExceededException(
|
||||
f"Rate limit exceeded for {model.name}"
|
||||
)
|
||||
|
||||
# Check for context length exceeded error
|
||||
if response.status_code == 400:
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -59,13 +59,12 @@ class AiPerplexity(BaseConnectorAi):
|
|||
connectorType="perplexity",
|
||||
apiUrl="https://api.perplexity.ai/chat/completions",
|
||||
temperature=0.2,
|
||||
maxTokens=24000, # Increased for detailed web crawl responses (Perplexity supports up to 25k)
|
||||
contextLength=32000,
|
||||
costPer1kTokensInput=0.005,
|
||||
costPer1kTokensOutput=0.005,
|
||||
maxTokens=24000,
|
||||
contextLength=127000, # 127K context window (updated 2026-02)
|
||||
costPer1kTokensInput=0.001, # $1/M tokens (updated 2026-02)
|
||||
costPer1kTokensOutput=0.001, # $1/M tokens (updated 2026-02)
|
||||
speedRating=8,
|
||||
qualityRating=8,
|
||||
# capabilities removed (not used in business logic)
|
||||
functionCall=self._routeWebOperation,
|
||||
priority=PriorityEnum.BALANCED,
|
||||
processingMode=ProcessingModeEnum.ADVANCED,
|
||||
|
|
@ -74,7 +73,7 @@ class AiPerplexity(BaseConnectorAi):
|
|||
(OperationTypeEnum.WEB_CRAWL, 7)
|
||||
),
|
||||
version="sonar",
|
||||
calculatePriceUsd=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.005 + (bytesReceived / 4 / 1000) * 0.005
|
||||
calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.001 + (bytesReceived / 4 / 1000) * 0.001
|
||||
),
|
||||
AiModel(
|
||||
name="sonar-pro",
|
||||
|
|
@ -82,13 +81,12 @@ class AiPerplexity(BaseConnectorAi):
|
|||
connectorType="perplexity",
|
||||
apiUrl="https://api.perplexity.ai/chat/completions",
|
||||
temperature=0.2,
|
||||
maxTokens=24000, # Increased for detailed web crawl responses (Perplexity supports up to 25k)
|
||||
contextLength=32000,
|
||||
costPer1kTokensInput=0.01,
|
||||
costPer1kTokensOutput=0.01,
|
||||
maxTokens=24000,
|
||||
contextLength=200000, # 200K context window (updated 2026-02)
|
||||
costPer1kTokensInput=0.003, # $3/M tokens (updated 2026-02)
|
||||
costPer1kTokensOutput=0.015, # $15/M tokens (updated 2026-02)
|
||||
speedRating=6, # Slower due to AI analysis
|
||||
qualityRating=9, # Best AI analysis quality
|
||||
# capabilities removed (not used in business logic)
|
||||
functionCall=self._routeWebOperation,
|
||||
priority=PriorityEnum.QUALITY,
|
||||
processingMode=ProcessingModeEnum.DETAILED,
|
||||
|
|
@ -97,7 +95,7 @@ class AiPerplexity(BaseConnectorAi):
|
|||
(OperationTypeEnum.WEB_CRAWL, 8)
|
||||
),
|
||||
version="sonar-pro",
|
||||
calculatePriceUsd=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.01 + (bytesReceived / 4 / 1000) * 0.01
|
||||
calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.003 + (bytesReceived / 4 / 1000) * 0.015
|
||||
)
|
||||
]
|
||||
|
||||
|
|
|
|||
497
modules/aicore/aicorePluginPrivateLlm.py
Normal file
497
modules/aicore/aicorePluginPrivateLlm.py
Normal file
|
|
@ -0,0 +1,497 @@
|
|||
# Copyright (c) 2025 Patrick Motsch
|
||||
# All rights reserved.
|
||||
"""
|
||||
AI Connector for PowerOn Private-LLM Service.
|
||||
|
||||
Connects to the private-llm service running on-premise with Ollama backend.
|
||||
Provides OCR and Vision capabilities via local AI models.
|
||||
|
||||
Models:
|
||||
- poweron-ocr-general: Text extraction and OCR (deepseek backend)
|
||||
- poweron-vision-general: General vision tasks (qwen2.5vl backend)
|
||||
- poweron-vision-deep: Deep vision analysis (granite3.2 backend)
|
||||
|
||||
Pricing (CHF per call):
|
||||
- Text models: CHF 0.010
|
||||
- Vision models: CHF 0.100
|
||||
"""
|
||||
|
||||
import logging
|
||||
import httpx
|
||||
import time
|
||||
from typing import List, Optional, Dict, Any
|
||||
from fastapi import HTTPException
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
from .aicoreBase import BaseConnectorAi
|
||||
from modules.datamodels.datamodelAi import (
|
||||
AiModel,
|
||||
PriorityEnum,
|
||||
ProcessingModeEnum,
|
||||
OperationTypeEnum,
|
||||
AiModelCall,
|
||||
AiModelResponse,
|
||||
createOperationTypeRatings
|
||||
)
|
||||
|
||||
# Configure logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Pricing constants (CHF)
|
||||
PRICE_TEXT_PER_CALL = 0.01 # CHF 0.010 per text model call
|
||||
PRICE_VISION_PER_CALL = 0.10 # CHF 0.100 per vision model call
|
||||
|
||||
|
||||
# Private-LLM Service URL (fix, nicht via env konfigurierbar)
|
||||
PRIVATE_LLM_BASE_URL = "https://llm.poweron.swiss:8000"
|
||||
|
||||
|
||||
def _loadConfigData():
|
||||
"""Load configuration data for Private-LLM connector."""
|
||||
return {
|
||||
"apiKey": APP_CONFIG.get("Connector_AiPrivateLlm_API_SECRET"),
|
||||
"baseUrl": PRIVATE_LLM_BASE_URL,
|
||||
}
|
||||
|
||||
|
||||
class AiPrivateLlm(BaseConnectorAi):
|
||||
"""Connector for communication with the PowerOn Private-LLM Service."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# Load configuration
|
||||
self.config = _loadConfigData()
|
||||
self.apiKey = self.config["apiKey"]
|
||||
self.baseUrl = self.config["baseUrl"]
|
||||
|
||||
# HTTP client for API calls
|
||||
# Timeout set to 3600 seconds (60 minutes) for large model processing
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if self.apiKey:
|
||||
headers["X-API-Key"] = self.apiKey
|
||||
|
||||
self.httpClient = httpx.AsyncClient(
|
||||
timeout=3600.0,
|
||||
headers=headers
|
||||
)
|
||||
|
||||
# Cache for service availability check
|
||||
self._serviceAvailable: Optional[bool] = None
|
||||
self._availableOllamaModels: Optional[List[str]] = None
|
||||
self._lastAvailabilityCheck: float = 0
|
||||
self._availabilityCacheTtl: float = 60.0 # 60 seconds cache
|
||||
|
||||
logger.info(f"Private-LLM Connector initialized (URL: {self.baseUrl})")
|
||||
|
||||
def getConnectorType(self) -> str:
|
||||
"""Get the connector type identifier."""
|
||||
return "privatellm"
|
||||
|
||||
def _checkServiceAvailability(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Check if the Private-LLM service is available and which Ollama models are installed.
|
||||
Uses caching to avoid excessive health checks.
|
||||
|
||||
Returns:
|
||||
Dict with 'serviceAvailable', 'ollamaConnected', 'availableModels'
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
currentTime = time.time()
|
||||
|
||||
# Return cached result if still valid
|
||||
if (self._serviceAvailable is not None and
|
||||
currentTime - self._lastAvailabilityCheck < self._availabilityCacheTtl):
|
||||
return {
|
||||
"serviceAvailable": self._serviceAvailable,
|
||||
"ollamaConnected": self._serviceAvailable,
|
||||
"availableModels": self._availableOllamaModels or []
|
||||
}
|
||||
|
||||
# Perform availability check
|
||||
try:
|
||||
# Use synchronous client for blocking check during initialization
|
||||
with httpx.Client(timeout=5.0) as client:
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if self.apiKey:
|
||||
headers["X-API-Key"] = self.apiKey
|
||||
|
||||
# Check health endpoint
|
||||
healthResponse = client.get(
|
||||
f"{self.baseUrl}/api/health",
|
||||
headers=headers
|
||||
)
|
||||
|
||||
if healthResponse.status_code != 200:
|
||||
logger.warning(f"Private-LLM service not available: HTTP {healthResponse.status_code}")
|
||||
self._serviceAvailable = False
|
||||
self._availableOllamaModels = []
|
||||
self._lastAvailabilityCheck = currentTime
|
||||
return {"serviceAvailable": False, "ollamaConnected": False, "availableModels": []}
|
||||
|
||||
healthData = healthResponse.json()
|
||||
ollamaConnected = healthData.get("ollamaConnected", False)
|
||||
|
||||
if not ollamaConnected:
|
||||
logger.warning("Private-LLM service available but Ollama not connected")
|
||||
self._serviceAvailable = True
|
||||
self._availableOllamaModels = []
|
||||
self._lastAvailabilityCheck = currentTime
|
||||
return {"serviceAvailable": True, "ollamaConnected": False, "availableModels": []}
|
||||
|
||||
# Check Ollama status for available models
|
||||
statusResponse = client.get(
|
||||
f"{self.baseUrl}/api/ollama/status",
|
||||
headers=headers
|
||||
)
|
||||
|
||||
if statusResponse.status_code == 200:
|
||||
statusData = statusResponse.json()
|
||||
self._availableOllamaModels = statusData.get("models", [])
|
||||
else:
|
||||
self._availableOllamaModels = []
|
||||
|
||||
self._serviceAvailable = True
|
||||
self._lastAvailabilityCheck = currentTime
|
||||
|
||||
logger.info(f"Private-LLM availability check: service=OK, ollama=OK, models={len(self._availableOllamaModels)}")
|
||||
|
||||
return {
|
||||
"serviceAvailable": True,
|
||||
"ollamaConnected": True,
|
||||
"availableModels": self._availableOllamaModels
|
||||
}
|
||||
|
||||
except httpx.ConnectError:
|
||||
logger.warning(f"Private-LLM service not reachable at {self.baseUrl}")
|
||||
self._serviceAvailable = False
|
||||
self._availableOllamaModels = []
|
||||
self._lastAvailabilityCheck = currentTime
|
||||
return {"serviceAvailable": False, "ollamaConnected": False, "availableModels": []}
|
||||
except Exception as e:
|
||||
logger.warning(f"Error checking Private-LLM availability: {e}")
|
||||
self._serviceAvailable = False
|
||||
self._availableOllamaModels = []
|
||||
self._lastAvailabilityCheck = currentTime
|
||||
return {"serviceAvailable": False, "ollamaConnected": False, "availableModels": []}
|
||||
|
||||
def _isModelAvailableInOllama(self, ollamaModelName: str, availableModels: List[str]) -> bool:
|
||||
"""
|
||||
Check if a model is available in Ollama.
|
||||
Handles model name variations (with/without tags).
|
||||
"""
|
||||
if not availableModels:
|
||||
return False
|
||||
|
||||
# Direct match
|
||||
if ollamaModelName in availableModels:
|
||||
return True
|
||||
|
||||
# Check without tag (e.g., "qwen2.5vl:72b" -> "qwen2.5vl")
|
||||
baseModelName = ollamaModelName.split(":")[0]
|
||||
for availModel in availableModels:
|
||||
availBase = availModel.split(":")[0]
|
||||
if baseModelName == availBase:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def getModels(self) -> List[AiModel]:
|
||||
"""
|
||||
Get all available Private-LLM models.
|
||||
|
||||
Checks service availability and returns only models that are actually available
|
||||
in the connected Ollama instance. Returns empty list if service is not reachable.
|
||||
"""
|
||||
# Check service availability
|
||||
availability = self._checkServiceAvailability()
|
||||
|
||||
if not availability["serviceAvailable"]:
|
||||
logger.warning("Private-LLM service not available - no models returned")
|
||||
return []
|
||||
|
||||
if not availability["ollamaConnected"]:
|
||||
logger.warning("Private-LLM service available but Ollama not connected - no models returned")
|
||||
return []
|
||||
|
||||
availableOllamaModels = availability.get("availableModels", [])
|
||||
|
||||
# Define all models with their Ollama backend names
|
||||
# Actual model specs (for 31GB RAM + 22GB GPU server):
|
||||
# Context sizes reduced to fit in available RAM
|
||||
# - qwen2.5:7b: 7.6B params, ~4.7GB RAM (Text) - 8K context
|
||||
# - qwen2.5vl:7b: 8.29B params, ~6GB RAM (Vision) - 4K context
|
||||
# - granite3.2-vision: 2B params, ~2.4GB RAM (Vision) - 4K context
|
||||
# - deepseek-ocr: ~6.7GB RAM (OCR) - 4K context
|
||||
modelDefinitions = [
|
||||
# Text Model (qwen2.5:7b: 7.6B)
|
||||
{
|
||||
"model": AiModel(
|
||||
name="poweron-text-general",
|
||||
displayName="PowerOn Text General",
|
||||
connectorType="privatellm",
|
||||
apiUrl=f"{self.baseUrl}/api/analyze",
|
||||
temperature=0.1,
|
||||
maxTokens=4096,
|
||||
contextLength=8192, # Reduced for RAM constraints
|
||||
costPer1kTokensInput=0.0, # Flat rate pricing
|
||||
costPer1kTokensOutput=0.0, # Flat rate pricing
|
||||
speedRating=8, # Fast and efficient
|
||||
qualityRating=9, # High quality text model
|
||||
functionCall=self.callAiText,
|
||||
priority=PriorityEnum.COST,
|
||||
processingMode=ProcessingModeEnum.BASIC,
|
||||
operationTypes=createOperationTypeRatings(
|
||||
(OperationTypeEnum.PLAN, 7),
|
||||
(OperationTypeEnum.DATA_ANALYSE, 8),
|
||||
(OperationTypeEnum.DATA_GENERATE, 8),
|
||||
(OperationTypeEnum.DATA_EXTRACT, 8),
|
||||
),
|
||||
version="qwen2.5:7b",
|
||||
calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: PRICE_TEXT_PER_CALL
|
||||
),
|
||||
"ollamaModel": "qwen2.5:7b"
|
||||
},
|
||||
# Vision General Model (qwen2.5vl:7b: 8.29B)
|
||||
{
|
||||
"model": AiModel(
|
||||
name="poweron-vision-general",
|
||||
displayName="PowerOn Vision General",
|
||||
connectorType="privatellm",
|
||||
apiUrl=f"{self.baseUrl}/api/analyze",
|
||||
temperature=0.2,
|
||||
maxTokens=2048,
|
||||
contextLength=4096, # Reduced for RAM constraints (vision needs more)
|
||||
costPer1kTokensInput=0.0, # Flat rate pricing
|
||||
costPer1kTokensOutput=0.0, # Flat rate pricing
|
||||
speedRating=7,
|
||||
qualityRating=9,
|
||||
functionCall=self.callAiVision,
|
||||
priority=PriorityEnum.BALANCED,
|
||||
processingMode=ProcessingModeEnum.ADVANCED,
|
||||
operationTypes=createOperationTypeRatings(
|
||||
(OperationTypeEnum.IMAGE_ANALYSE, 9),
|
||||
),
|
||||
version="qwen2.5vl:7b",
|
||||
calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: PRICE_VISION_PER_CALL
|
||||
),
|
||||
"ollamaModel": "qwen2.5vl:7b"
|
||||
},
|
||||
# Vision Deep Model (granite3.2-vision: 2B)
|
||||
{
|
||||
"model": AiModel(
|
||||
name="poweron-vision-deep",
|
||||
displayName="PowerOn Vision Deep",
|
||||
connectorType="privatellm",
|
||||
apiUrl=f"{self.baseUrl}/api/analyze",
|
||||
temperature=0.1,
|
||||
maxTokens=2048,
|
||||
contextLength=4096, # Reduced for RAM constraints
|
||||
costPer1kTokensInput=0.0, # Flat rate pricing
|
||||
costPer1kTokensOutput=0.0, # Flat rate pricing
|
||||
speedRating=9, # Fast due to small 2B model
|
||||
qualityRating=8, # Good for document understanding
|
||||
functionCall=self.callAiVision,
|
||||
priority=PriorityEnum.QUALITY,
|
||||
processingMode=ProcessingModeEnum.DETAILED,
|
||||
operationTypes=createOperationTypeRatings(
|
||||
(OperationTypeEnum.IMAGE_ANALYSE, 9),
|
||||
),
|
||||
version="granite3.2-vision",
|
||||
calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: PRICE_VISION_PER_CALL
|
||||
),
|
||||
"ollamaModel": "granite3.2-vision"
|
||||
},
|
||||
]
|
||||
|
||||
# Filter models by Ollama availability
|
||||
availableModels = []
|
||||
unavailableModels = []
|
||||
|
||||
for modelDef in modelDefinitions:
|
||||
ollamaModelName = modelDef["ollamaModel"]
|
||||
if self._isModelAvailableInOllama(ollamaModelName, availableOllamaModels):
|
||||
availableModels.append(modelDef["model"])
|
||||
else:
|
||||
unavailableModels.append(modelDef["model"].name)
|
||||
|
||||
if unavailableModels:
|
||||
logger.warning(
|
||||
f"Private-LLM: {len(unavailableModels)} models not available in Ollama: {', '.join(unavailableModels)}. "
|
||||
f"Install with: ollama pull <model-name>"
|
||||
)
|
||||
|
||||
if availableModels:
|
||||
logger.info(f"Private-LLM: {len(availableModels)} models available")
|
||||
else:
|
||||
logger.warning("Private-LLM: No models available. Check Ollama installation.")
|
||||
|
||||
return availableModels
|
||||
|
||||
async def callAiText(self, modelCall: AiModelCall) -> AiModelResponse:
|
||||
"""
|
||||
Call the Private-LLM API for text-based analysis.
|
||||
|
||||
Args:
|
||||
modelCall: AiModelCall with messages
|
||||
|
||||
Returns:
|
||||
AiModelResponse with content and metadata
|
||||
"""
|
||||
try:
|
||||
messages = modelCall.messages
|
||||
model = modelCall.model
|
||||
|
||||
# Extract prompt from messages
|
||||
prompt = ""
|
||||
for msg in messages:
|
||||
content = msg.get("content", "")
|
||||
if isinstance(content, str):
|
||||
prompt += content + "\n"
|
||||
elif isinstance(content, list):
|
||||
for part in content:
|
||||
if isinstance(part, dict) and part.get("type") == "text":
|
||||
prompt += part.get("text", "") + "\n"
|
||||
|
||||
payload = {
|
||||
"modelName": model.name,
|
||||
"prompt": prompt.strip(),
|
||||
"imageBase64": None
|
||||
}
|
||||
|
||||
logger.debug(f"Calling Private-LLM text API with model {model.name}")
|
||||
|
||||
response = await self.httpClient.post(
|
||||
model.apiUrl,
|
||||
json=payload
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
errorMessage = f"Private-LLM API error: {response.status_code} - {response.text}"
|
||||
logger.error(errorMessage)
|
||||
raise HTTPException(status_code=500, detail=errorMessage)
|
||||
|
||||
responseJson = response.json()
|
||||
|
||||
if not responseJson.get("success", False):
|
||||
errorMsg = responseJson.get("error", "Unknown error")
|
||||
logger.error(f"Private-LLM returned error: {errorMsg}")
|
||||
return AiModelResponse(
|
||||
content="",
|
||||
success=False,
|
||||
error=errorMsg
|
||||
)
|
||||
|
||||
# Extract content from response
|
||||
data = responseJson.get("data", {})
|
||||
rawResponse = responseJson.get("rawResponse", "")
|
||||
|
||||
# Prefer rawResponse for full content, fall back to data
|
||||
content = rawResponse if rawResponse else str(data.get("response", data))
|
||||
|
||||
return AiModelResponse(
|
||||
content=content,
|
||||
success=True,
|
||||
modelId=model.name,
|
||||
metadata={"data": data}
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error calling Private-LLM text API: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Error calling Private-LLM API: {str(e)}")
|
||||
|
||||
async def callAiVision(self, modelCall: AiModelCall) -> AiModelResponse:
|
||||
"""
|
||||
Call the Private-LLM API for vision-based analysis.
|
||||
|
||||
Args:
|
||||
modelCall: AiModelCall with messages containing image data
|
||||
|
||||
Returns:
|
||||
AiModelResponse with analysis content
|
||||
"""
|
||||
try:
|
||||
messages = modelCall.messages
|
||||
model = modelCall.model
|
||||
|
||||
# Extract prompt and image from messages
|
||||
prompt = ""
|
||||
imageBase64 = None
|
||||
|
||||
for msg in messages:
|
||||
content = msg.get("content", "")
|
||||
|
||||
if isinstance(content, str):
|
||||
prompt += content + "\n"
|
||||
elif isinstance(content, list):
|
||||
for part in content:
|
||||
if isinstance(part, dict):
|
||||
if part.get("type") == "text":
|
||||
prompt += part.get("text", "") + "\n"
|
||||
elif part.get("type") == "image_url":
|
||||
imageUrl = part.get("image_url", {}).get("url", "")
|
||||
# Extract base64 from data URL
|
||||
if imageUrl.startswith("data:"):
|
||||
# Format: data:image/png;base64,<base64data>
|
||||
parts = imageUrl.split(",", 1)
|
||||
if len(parts) == 2:
|
||||
imageBase64 = parts[1]
|
||||
else:
|
||||
imageBase64 = imageUrl
|
||||
|
||||
if not imageBase64:
|
||||
logger.warning("No image provided for vision model call")
|
||||
|
||||
payload = {
|
||||
"modelName": model.name,
|
||||
"prompt": prompt.strip(),
|
||||
"imageBase64": imageBase64
|
||||
}
|
||||
|
||||
logger.debug(f"Calling Private-LLM vision API with model {model.name}")
|
||||
|
||||
response = await self.httpClient.post(
|
||||
model.apiUrl,
|
||||
json=payload
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
errorMessage = f"Private-LLM API error: {response.status_code} - {response.text}"
|
||||
logger.error(errorMessage)
|
||||
raise HTTPException(status_code=500, detail=errorMessage)
|
||||
|
||||
responseJson = response.json()
|
||||
|
||||
if not responseJson.get("success", False):
|
||||
errorMsg = responseJson.get("error", "Unknown error")
|
||||
logger.error(f"Private-LLM returned error: {errorMsg}")
|
||||
return AiModelResponse(
|
||||
content="",
|
||||
success=False,
|
||||
error=errorMsg
|
||||
)
|
||||
|
||||
# Extract content from response
|
||||
data = responseJson.get("data", {})
|
||||
rawResponse = responseJson.get("rawResponse", "")
|
||||
|
||||
# Prefer rawResponse for full content
|
||||
content = rawResponse if rawResponse else str(data.get("response", data))
|
||||
|
||||
return AiModelResponse(
|
||||
content=content,
|
||||
success=True,
|
||||
modelId=model.name,
|
||||
metadata={"data": data}
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error calling Private-LLM vision API: {str(e)}", exc_info=True)
|
||||
return AiModelResponse(
|
||||
content="",
|
||||
success=False,
|
||||
error=f"Error during vision analysis: {str(e)}"
|
||||
)
|
||||
|
|
@ -71,7 +71,7 @@ class AiTavily(BaseConnectorAi):
|
|||
(OperationTypeEnum.WEB_CRAWL, 10)
|
||||
),
|
||||
version="tavily-search",
|
||||
calculatePriceUsd=lambda processingTime, bytesSent, bytesReceived: 0.008 # Simple flat rate
|
||||
calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: 0.008 # Simple flat rate
|
||||
)
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -87,6 +87,7 @@ class AiModel(BaseModel):
|
|||
# Token and context limits
|
||||
maxTokens: int = Field(description="Maximum tokens this model can generate")
|
||||
contextLength: int = Field(description="Maximum context length this model can handle")
|
||||
maxInputTokensPerRequest: Optional[int] = Field(default=None, description="Max input tokens per single request (provider rate limit / TPM). If set, model selector filters requests exceeding this limit.")
|
||||
|
||||
# Cost information
|
||||
costPer1kTokensInput: float = Field(default=0.0, description="Cost per 1000 input tokens")
|
||||
|
|
@ -98,7 +99,7 @@ class AiModel(BaseModel):
|
|||
|
||||
# Function reference (not serialized)
|
||||
functionCall: Optional[Callable] = Field(default=None, exclude=True, description="Function to call for this model")
|
||||
calculatePriceUsd: Optional[Callable] = Field(default=None, exclude=True, description="Function to calculate price in USD")
|
||||
calculatepriceCHF: Optional[Callable] = Field(default=None, exclude=True, description="Function to calculate price in USD")
|
||||
|
||||
# Selection criteria - capabilities with ratings
|
||||
priority: PriorityEnum = Field(default=PriorityEnum.BALANCED, description="Default priority for this model. See PriorityEnum for available values.")
|
||||
|
|
@ -144,6 +145,9 @@ class AiCallOptions(BaseModel):
|
|||
temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0, description="Temperature for response generation (0.0-2.0, lower = more consistent)")
|
||||
maxParts: Optional[int] = Field(default=1000, ge=1, le=1000, description="Maximum number of continuation parts to fetch")
|
||||
|
||||
# Provider filtering (from UI multiselect or automation config)
|
||||
allowedProviders: Optional[List[str]] = Field(default=None, description="List of allowed AI providers to use (empty = all RBAC-permitted)")
|
||||
|
||||
|
||||
class AiCallRequest(BaseModel):
|
||||
"""Centralized AI call request payload for interface use."""
|
||||
|
|
@ -159,7 +163,8 @@ class AiCallResponse(BaseModel):
|
|||
|
||||
content: str = Field(description="AI response content")
|
||||
modelName: str = Field(description="Selected model name")
|
||||
priceUsd: float = Field(default=0.0, description="Calculated price in USD")
|
||||
provider: str = Field(default="unknown", description="AI provider / connectorType (anthropic, openai, perplexity, etc.)")
|
||||
priceCHF: float = Field(default=0.0, description="Calculated price in USD")
|
||||
processingTime: float = Field(default=0.0, description="Duration in seconds")
|
||||
bytesSent: int = Field(default=0, description="Input data size in bytes")
|
||||
bytesReceived: int = Field(default=0, description="Output data size in bytes")
|
||||
|
|
|
|||
269
modules/datamodels/datamodelBilling.py
Normal file
269
modules/datamodels/datamodelBilling.py
Normal file
|
|
@ -0,0 +1,269 @@
|
|||
# Copyright (c) 2025 Patrick Motsch
|
||||
# All rights reserved.
|
||||
"""Billing models: BillingAccount, BillingTransaction, BillingSettings, UsageStatistics."""
|
||||
|
||||
from typing import List, Dict, Any, Optional
|
||||
from enum import Enum
|
||||
from datetime import date, datetime
|
||||
from pydantic import BaseModel, Field
|
||||
from modules.shared.attributeUtils import registerModelLabels
|
||||
import uuid
|
||||
|
||||
|
||||
class BillingModelEnum(str, Enum):
|
||||
"""Billing model types."""
|
||||
PREPAY_MANDATE = "PREPAY_MANDATE" # Prepaid budget shared by all users in mandate
|
||||
PREPAY_USER = "PREPAY_USER" # Prepaid budget per user within mandate
|
||||
CREDIT_POSTPAY = "CREDIT_POSTPAY" # Credit with monthly invoice (requires billing address)
|
||||
UNLIMITED = "UNLIMITED" # No cost limitation (internal mandates only)
|
||||
|
||||
|
||||
class AccountTypeEnum(str, Enum):
|
||||
"""Account type for billing accounts."""
|
||||
MANDATE = "MANDATE" # Account for entire mandate
|
||||
USER = "USER" # Account for specific user within mandate
|
||||
|
||||
|
||||
class TransactionTypeEnum(str, Enum):
|
||||
"""Transaction types for billing."""
|
||||
CREDIT = "CREDIT" # Credit/top-up (positive)
|
||||
DEBIT = "DEBIT" # Debit/usage (positive amount, reduces balance)
|
||||
ADJUSTMENT = "ADJUSTMENT" # Manual adjustment by admin
|
||||
|
||||
|
||||
class ReferenceTypeEnum(str, Enum):
|
||||
"""Reference types for transactions."""
|
||||
WORKFLOW = "WORKFLOW" # AI workflow usage
|
||||
PAYMENT = "PAYMENT" # Payment/top-up
|
||||
ADMIN = "ADMIN" # Admin adjustment
|
||||
SYSTEM = "SYSTEM" # System credit (e.g., initial credit)
|
||||
|
||||
|
||||
class PeriodTypeEnum(str, Enum):
|
||||
"""Period types for usage statistics."""
|
||||
DAY = "DAY"
|
||||
MONTH = "MONTH"
|
||||
YEAR = "YEAR"
|
||||
|
||||
|
||||
class BillingAddress(BaseModel):
|
||||
"""Billing address for CREDIT_POSTPAY mandates."""
|
||||
company: str = Field(..., description="Company name")
|
||||
street: str = Field(..., description="Street and number")
|
||||
zip: str = Field(..., description="Postal code")
|
||||
city: str = Field(..., description="City")
|
||||
country: str = Field(default="CH", description="Country code")
|
||||
vatNumber: Optional[str] = Field(None, description="VAT number (optional)")
|
||||
|
||||
|
||||
registerModelLabels(
|
||||
"BillingAddress",
|
||||
{"en": "Billing Address", "de": "Rechnungsadresse"},
|
||||
{
|
||||
"company": {"en": "Company", "de": "Firma"},
|
||||
"street": {"en": "Street", "de": "Strasse"},
|
||||
"zip": {"en": "ZIP", "de": "PLZ"},
|
||||
"city": {"en": "City", "de": "Ort"},
|
||||
"country": {"en": "Country", "de": "Land"},
|
||||
"vatNumber": {"en": "VAT Number", "de": "MwSt-Nummer"},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class BillingAccount(BaseModel):
|
||||
"""Billing account for mandate or user-mandate combination."""
|
||||
id: str = Field(
|
||||
default_factory=lambda: str(uuid.uuid4()), description="Primary key"
|
||||
)
|
||||
mandateId: str = Field(..., description="Foreign key to Mandate")
|
||||
userId: Optional[str] = Field(None, description="Foreign key to User (only for PREPAY_USER)")
|
||||
accountType: AccountTypeEnum = Field(..., description="Account type: MANDATE or USER")
|
||||
balance: float = Field(default=0.0, description="Current balance in CHF")
|
||||
creditLimit: Optional[float] = Field(None, description="Credit limit in CHF (only for CREDIT_POSTPAY)")
|
||||
warningThreshold: float = Field(default=0.0, description="Warning threshold in CHF")
|
||||
lastWarningAt: Optional[datetime] = Field(None, description="Last warning sent timestamp")
|
||||
enabled: bool = Field(default=True, description="Account is active")
|
||||
|
||||
|
||||
registerModelLabels(
|
||||
"BillingAccount",
|
||||
{"en": "Billing Account", "de": "Abrechnungskonto"},
|
||||
{
|
||||
"id": {"en": "ID", "de": "ID"},
|
||||
"mandateId": {"en": "Mandate ID", "de": "Mandanten-ID"},
|
||||
"userId": {"en": "User ID", "de": "Benutzer-ID"},
|
||||
"accountType": {"en": "Account Type", "de": "Kontotyp"},
|
||||
"balance": {"en": "Balance (CHF)", "de": "Guthaben (CHF)"},
|
||||
"creditLimit": {"en": "Credit Limit (CHF)", "de": "Kreditlimit (CHF)"},
|
||||
"warningThreshold": {"en": "Warning Threshold (CHF)", "de": "Warnschwelle (CHF)"},
|
||||
"lastWarningAt": {"en": "Last Warning", "de": "Letzte Warnung"},
|
||||
"enabled": {"en": "Enabled", "de": "Aktiv"},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class BillingTransaction(BaseModel):
|
||||
"""Single billing transaction (credit, debit, adjustment)."""
|
||||
id: str = Field(
|
||||
default_factory=lambda: str(uuid.uuid4()), description="Primary key"
|
||||
)
|
||||
accountId: str = Field(..., description="Foreign key to BillingAccount")
|
||||
transactionType: TransactionTypeEnum = Field(..., description="Transaction type")
|
||||
amount: float = Field(..., description="Amount in CHF (always positive)")
|
||||
description: str = Field(..., description="Transaction description")
|
||||
|
||||
# Reference to source
|
||||
referenceType: Optional[ReferenceTypeEnum] = Field(None, description="Reference type")
|
||||
referenceId: Optional[str] = Field(None, description="Reference ID")
|
||||
|
||||
# Context for workflow transactions
|
||||
workflowId: Optional[str] = Field(None, description="Workflow ID (for WORKFLOW transactions)")
|
||||
featureInstanceId: Optional[str] = Field(None, description="Feature instance ID")
|
||||
featureCode: Optional[str] = Field(None, description="Feature code (e.g., chatplayground, automation)")
|
||||
aicoreProvider: Optional[str] = Field(None, description="AICore provider (anthropic, openai, etc.)")
|
||||
aicoreModel: Optional[str] = Field(None, description="AICore model name (e.g., claude-4-sonnet, gpt-4o)")
|
||||
createdByUserId: Optional[str] = Field(None, description="User who created/caused this transaction")
|
||||
|
||||
|
||||
registerModelLabels(
|
||||
"BillingTransaction",
|
||||
{"en": "Billing Transaction", "de": "Transaktion"},
|
||||
{
|
||||
"id": {"en": "ID", "de": "ID"},
|
||||
"accountId": {"en": "Account ID", "de": "Konto-ID"},
|
||||
"transactionType": {"en": "Type", "de": "Typ"},
|
||||
"amount": {"en": "Amount (CHF)", "de": "Betrag (CHF)"},
|
||||
"description": {"en": "Description", "de": "Beschreibung"},
|
||||
"referenceType": {"en": "Reference Type", "de": "Referenztyp"},
|
||||
"referenceId": {"en": "Reference ID", "de": "Referenz-ID"},
|
||||
"workflowId": {"en": "Workflow ID", "de": "Workflow-ID"},
|
||||
"featureInstanceId": {"en": "Feature Instance ID", "de": "Feature-Instanz-ID"},
|
||||
"featureCode": {"en": "Feature Code", "de": "Feature-Code"},
|
||||
"aicoreProvider": {"en": "AI Provider", "de": "AI-Anbieter"},
|
||||
"aicoreModel": {"en": "AI Model", "de": "AI-Modell"},
|
||||
"createdByUserId": {"en": "Created By User", "de": "Erstellt von Benutzer"},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class BillingSettings(BaseModel):
|
||||
"""Billing settings per mandate."""
|
||||
id: str = Field(
|
||||
default_factory=lambda: str(uuid.uuid4()), description="Primary key"
|
||||
)
|
||||
mandateId: str = Field(..., description="Foreign key to Mandate (UNIQUE)")
|
||||
billingModel: BillingModelEnum = Field(..., description="Billing model")
|
||||
|
||||
# Configuration
|
||||
defaultUserCredit: float = Field(default=10.0, description="Initial credit in CHF for new users (PREPAY_USER)")
|
||||
warningThresholdPercent: float = Field(default=10.0, description="Warning threshold as percentage")
|
||||
blockOnZeroBalance: bool = Field(default=True, description="Block AI features when balance is zero")
|
||||
|
||||
# Billing address (required for CREDIT_POSTPAY)
|
||||
billingAddress: Optional[BillingAddress] = Field(None, description="Billing address")
|
||||
|
||||
# Notifications
|
||||
notifyEmails: List[str] = Field(default_factory=list, description="Email addresses for billing notifications")
|
||||
notifyOnWarning: bool = Field(default=True, description="Send email when warning threshold is reached")
|
||||
|
||||
|
||||
registerModelLabels(
|
||||
"BillingSettings",
|
||||
{"en": "Billing Settings", "de": "Abrechnungseinstellungen"},
|
||||
{
|
||||
"id": {"en": "ID", "de": "ID"},
|
||||
"mandateId": {"en": "Mandate ID", "de": "Mandanten-ID"},
|
||||
"billingModel": {"en": "Billing Model", "de": "Abrechnungsmodell"},
|
||||
"defaultUserCredit": {"en": "Default User Credit (CHF)", "de": "Standard-Startguthaben (CHF)"},
|
||||
"warningThresholdPercent": {"en": "Warning Threshold (%)", "de": "Warnschwelle (%)"},
|
||||
"blockOnZeroBalance": {"en": "Block on Zero Balance", "de": "Bei 0 blockieren"},
|
||||
"billingAddress": {"en": "Billing Address", "de": "Rechnungsadresse"},
|
||||
"notifyEmails": {"en": "Notification Emails", "de": "Benachrichtigungs-Emails"},
|
||||
"notifyOnWarning": {"en": "Notify on Warning", "de": "Bei Warnung benachrichtigen"},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class UsageStatistics(BaseModel):
|
||||
"""Aggregated usage statistics for quick retrieval."""
|
||||
id: str = Field(
|
||||
default_factory=lambda: str(uuid.uuid4()), description="Primary key"
|
||||
)
|
||||
accountId: str = Field(..., description="Foreign key to BillingAccount")
|
||||
periodType: PeriodTypeEnum = Field(..., description="Period type")
|
||||
periodStart: date = Field(..., description="Period start date")
|
||||
|
||||
# Aggregated values
|
||||
totalCostCHF: float = Field(default=0.0, description="Total cost in CHF")
|
||||
transactionCount: int = Field(default=0, description="Number of transactions")
|
||||
|
||||
# Breakdown by provider
|
||||
costByProvider: Dict[str, float] = Field(
|
||||
default_factory=dict,
|
||||
description="Cost breakdown by provider (e.g., {'anthropic': 12.50, 'openai': 8.30})"
|
||||
)
|
||||
|
||||
# Breakdown by feature
|
||||
costByFeature: Dict[str, float] = Field(
|
||||
default_factory=dict,
|
||||
description="Cost breakdown by feature (e.g., {'chatplayground': 15.00, 'automation': 5.80})"
|
||||
)
|
||||
|
||||
|
||||
registerModelLabels(
|
||||
"UsageStatistics",
|
||||
{"en": "Usage Statistics", "de": "Nutzungsstatistik"},
|
||||
{
|
||||
"id": {"en": "ID", "de": "ID"},
|
||||
"accountId": {"en": "Account ID", "de": "Konto-ID"},
|
||||
"periodType": {"en": "Period Type", "de": "Periodentyp"},
|
||||
"periodStart": {"en": "Period Start", "de": "Periodenbeginn"},
|
||||
"totalCostCHF": {"en": "Total Cost (CHF)", "de": "Gesamtkosten (CHF)"},
|
||||
"transactionCount": {"en": "Transaction Count", "de": "Anzahl Transaktionen"},
|
||||
"costByProvider": {"en": "Cost by Provider", "de": "Kosten nach Anbieter"},
|
||||
"costByFeature": {"en": "Cost by Feature", "de": "Kosten nach Feature"},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Response Models for API
|
||||
# ============================================================================
|
||||
|
||||
class BillingBalanceResponse(BaseModel):
|
||||
"""Response model for balance endpoint."""
|
||||
mandateId: str
|
||||
mandateName: str
|
||||
billingModel: BillingModelEnum
|
||||
balance: float
|
||||
currency: str = "CHF"
|
||||
warningThreshold: float
|
||||
isWarning: bool
|
||||
creditLimit: Optional[float] = None
|
||||
|
||||
|
||||
class BillingStatisticsChartData(BaseModel):
|
||||
"""Chart data point for statistics."""
|
||||
label: str
|
||||
totalCost: float
|
||||
byProvider: Dict[str, float]
|
||||
|
||||
|
||||
class BillingStatisticsResponse(BaseModel):
|
||||
"""Response model for statistics endpoint."""
|
||||
mandateId: str
|
||||
period: PeriodTypeEnum
|
||||
year: int
|
||||
month: Optional[int] = None
|
||||
currency: str = "CHF"
|
||||
data: List[BillingStatisticsChartData]
|
||||
totals: Dict[str, Any]
|
||||
|
||||
|
||||
class BillingCheckResult(BaseModel):
|
||||
"""Result of a billing balance check."""
|
||||
allowed: bool
|
||||
reason: Optional[str] = None
|
||||
currentBalance: Optional[float] = None
|
||||
requiredAmount: Optional[float] = None
|
||||
billingModel: Optional[BillingModelEnum] = None
|
||||
|
|
@ -12,6 +12,8 @@ import uuid
|
|||
|
||||
class ChatStat(BaseModel):
|
||||
"""Statistics for chat operations. User-owned, no mandate context."""
|
||||
model_config = {"populate_by_name": True, "extra": "allow"} # Allow DB system fields
|
||||
|
||||
id: str = Field(
|
||||
default_factory=lambda: str(uuid.uuid4()), description="Primary key"
|
||||
)
|
||||
|
|
@ -26,7 +28,7 @@ class ChatStat(BaseModel):
|
|||
errorCount: Optional[int] = Field(None, description="Number of errors encountered")
|
||||
process: Optional[str] = Field(None, description="The process that delivers the stats data (e.g. 'action.outlook.readMails', 'ai.process.document.name')")
|
||||
engine: Optional[str] = Field(None, description="The engine used (e.g. 'ai.anthropic.35', 'ai.tavily.basic', 'renderer.docx')")
|
||||
priceUsd: Optional[float] = Field(None, description="Calculated price in USD for the operation")
|
||||
priceCHF: Optional[float] = Field(None, description="Calculated price in USD for the operation")
|
||||
|
||||
|
||||
registerModelLabels(
|
||||
|
|
@ -41,7 +43,7 @@ registerModelLabels(
|
|||
"errorCount": {"en": "Error Count", "fr": "Nombre d'erreurs"},
|
||||
"process": {"en": "Process", "fr": "Processus"},
|
||||
"engine": {"en": "Engine", "fr": "Moteur"},
|
||||
"priceUsd": {"en": "Price USD", "fr": "Prix USD"},
|
||||
"priceCHF": {"en": "Price CHF", "fr": "Prix CHF"},
|
||||
},
|
||||
)
|
||||
|
||||
|
|
@ -301,6 +303,7 @@ registerModelLabels(
|
|||
class ChatWorkflow(BaseModel):
|
||||
"""Chat workflow container. User-owned, no mandate context."""
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
|
||||
featureInstanceId: Optional[str] = Field(None, description="Feature instance ID for multi-tenancy isolation", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
|
||||
status: str = Field(default="running", description="Current status of the workflow", json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": False, "frontend_options": [
|
||||
{"value": "running", "label": {"en": "Running", "fr": "En cours"}},
|
||||
{"value": "completed", "label": {"en": "Completed", "fr": "Terminé"}},
|
||||
|
|
@ -374,6 +377,7 @@ registerModelLabels(
|
|||
{"en": "Chat Workflow", "fr": "Flux de travail de chat"},
|
||||
{
|
||||
"id": {"en": "ID", "fr": "ID"},
|
||||
"featureInstanceId": {"en": "Feature Instance ID", "fr": "ID de l'instance de fonctionnalité"},
|
||||
"status": {"en": "Status", "fr": "Statut"},
|
||||
"name": {"en": "Name", "fr": "Nom"},
|
||||
"currentRound": {"en": "Current Round", "fr": "Tour actuel"},
|
||||
|
|
@ -399,6 +403,7 @@ class UserInputRequest(BaseModel):
|
|||
listFileId: List[str] = Field(default_factory=list, description="List of file IDs")
|
||||
userLanguage: str = Field(default="en", description="User's preferred language")
|
||||
workflowId: Optional[str] = Field(None, description="Optional ID of the workflow to continue")
|
||||
allowedProviders: Optional[List[str]] = Field(None, description="List of allowed AI providers (multiselect)")
|
||||
|
||||
|
||||
registerModelLabels(
|
||||
|
|
@ -408,6 +413,7 @@ registerModelLabels(
|
|||
"prompt": {"en": "Prompt", "fr": "Invite"},
|
||||
"listFileId": {"en": "File IDs", "fr": "IDs des fichiers"},
|
||||
"userLanguage": {"en": "User Language", "fr": "Langue de l'utilisateur"},
|
||||
"preferredProvider": {"en": "Preferred Provider", "fr": "Fournisseur préféré"},
|
||||
},
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
"""File-related datamodels: FileItem, FilePreview, FileData."""
|
||||
|
||||
from typing import Dict, Any, Optional, Union
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from modules.shared.attributeUtils import registerModelLabels
|
||||
from modules.shared.timeUtils import getUtcTimestamp
|
||||
import uuid
|
||||
|
|
@ -11,6 +11,7 @@ import base64
|
|||
|
||||
|
||||
class FileItem(BaseModel):
|
||||
model_config = ConfigDict(extra='allow') # Preserve system fields (_createdBy, _createdAt, etc.)
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
|
||||
mandateId: Optional[str] = Field(default="", description="ID of the mandate this file belongs to", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
|
||||
featureInstanceId: Optional[str] = Field(default="", description="ID of the feature instance this file belongs to", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ class UserMandate(BaseModel):
|
|||
)
|
||||
mandateId: str = Field(
|
||||
description="FK → Mandate.id (CASCADE DELETE)",
|
||||
json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": True, "frontend_fk_source": "/api/mandates/", "frontend_fk_display_field": "name"}
|
||||
json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": True, "frontend_fk_source": "/api/mandates/", "frontend_fk_display_field": "label"}
|
||||
)
|
||||
enabled: bool = Field(
|
||||
default=True,
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@ class Role(BaseModel):
|
|||
mandateId: Optional[str] = Field(
|
||||
default=None,
|
||||
description="FK → Mandate.id (CASCADE DELETE). Null = Global/Template role.",
|
||||
json_schema_extra={"frontend_type": "select", "frontend_readonly": True, "frontend_visible": True, "frontend_required": False, "frontend_fk_source": "/api/mandates/", "frontend_fk_display_field": "name"}
|
||||
json_schema_extra={"frontend_type": "select", "frontend_readonly": True, "frontend_visible": True, "frontend_required": False, "frontend_fk_source": "/api/mandates/", "frontend_fk_display_field": "label"}
|
||||
)
|
||||
featureInstanceId: Optional[str] = Field(
|
||||
default=None,
|
||||
|
|
|
|||
|
|
@ -73,16 +73,29 @@ class Mandate(BaseModel):
|
|||
description="Name of the mandate",
|
||||
json_schema_extra={"frontend_type": "text", "frontend_readonly": False, "frontend_required": True}
|
||||
)
|
||||
description: Optional[str] = Field(
|
||||
label: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Description of the mandate",
|
||||
json_schema_extra={"frontend_type": "textarea", "frontend_readonly": False, "frontend_required": False}
|
||||
description="Display label of the mandate",
|
||||
json_schema_extra={"frontend_type": "text", "frontend_readonly": False, "frontend_required": False}
|
||||
)
|
||||
enabled: bool = Field(
|
||||
default=True,
|
||||
description="Indicates whether the mandate is enabled",
|
||||
json_schema_extra={"frontend_type": "checkbox", "frontend_readonly": False, "frontend_required": False}
|
||||
)
|
||||
isSystem: bool = Field(
|
||||
default=False,
|
||||
description="Whether this is a system mandate (e.g. root mandate). Cannot be deleted.",
|
||||
json_schema_extra={"frontend_type": "checkbox", "frontend_readonly": True, "frontend_required": False}
|
||||
)
|
||||
|
||||
@field_validator('isSystem', mode='before')
|
||||
@classmethod
|
||||
def _coerceIsSystem(cls, v):
|
||||
"""Coerce None to False (for existing DB records without isSystem field)."""
|
||||
if v is None:
|
||||
return False
|
||||
return v
|
||||
|
||||
|
||||
registerModelLabels(
|
||||
|
|
@ -91,8 +104,9 @@ registerModelLabels(
|
|||
{
|
||||
"id": {"en": "ID", "de": "ID", "fr": "ID"},
|
||||
"name": {"en": "Name", "de": "Name", "fr": "Nom"},
|
||||
"description": {"en": "Description", "de": "Beschreibung", "fr": "Description"},
|
||||
"label": {"en": "Label", "de": "Label", "fr": "Libellé"},
|
||||
"enabled": {"en": "Enabled", "de": "Aktiviert", "fr": "Activé"},
|
||||
"isSystem": {"en": "System Mandate", "de": "System-Mandant", "fr": "Mandat système"},
|
||||
},
|
||||
)
|
||||
|
||||
|
|
@ -114,6 +128,7 @@ class UserConnection(BaseModel):
|
|||
{"value": "none", "label": {"en": "None", "fr": "Aucun"}},
|
||||
]})
|
||||
tokenExpiresAt: Optional[float] = Field(None, description="When the current token expires (UTC timestamp in seconds)", json_schema_extra={"frontend_type": "timestamp", "frontend_readonly": True, "frontend_required": False})
|
||||
grantedScopes: Optional[List[str]] = Field(None, description="OAuth scopes granted for this connection", json_schema_extra={"frontend_type": "list", "frontend_readonly": True, "frontend_required": False})
|
||||
|
||||
@computed_field
|
||||
@computed_field
|
||||
|
|
@ -146,6 +161,7 @@ registerModelLabels(
|
|||
"expiresAt": {"en": "Expires At", "de": "Läuft ab am", "fr": "Expire le"},
|
||||
"tokenStatus": {"en": "Connection Status", "de": "Verbindungsstatus", "fr": "Statut de connexion"},
|
||||
"tokenExpiresAt": {"en": "Expires At", "de": "Läuft ab am", "fr": "Expire le"},
|
||||
"grantedScopes": {"en": "Granted Scopes", "de": "Gewährte Berechtigungen", "fr": "Autorisations accordées"},
|
||||
"connectionReference": {"en": "Connection Reference", "de": "Verbindungsreferenz", "fr": "Référence de connexion"},
|
||||
"displayLabel": {"en": "Display Label", "de": "Anzeigebezeichnung", "fr": "Libellé d'affichage"},
|
||||
},
|
||||
|
|
|
|||
|
|
@ -3,22 +3,33 @@
|
|||
"""Utility datamodels: Prompt, TextMultilingual."""
|
||||
|
||||
from typing import Dict, Optional
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
from modules.shared.attributeUtils import registerModelLabels
|
||||
import uuid
|
||||
|
||||
|
||||
class Prompt(BaseModel):
|
||||
model_config = ConfigDict(extra='allow') # Preserve system fields (_createdBy, _createdAt, etc.)
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
|
||||
mandateId: str = Field(default="", description="ID of the mandate this prompt belongs to", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
|
||||
isSystem: bool = Field(default=False, description="System prompt visible to all users (read-only for non-SysAdmin)", json_schema_extra={"frontend_type": "boolean", "frontend_readonly": True, "frontend_required": False})
|
||||
content: str = Field(description="Content of the prompt", json_schema_extra={"frontend_type": "textarea", "frontend_readonly": False, "frontend_required": True})
|
||||
name: str = Field(description="Name of the prompt", json_schema_extra={"frontend_type": "text", "frontend_readonly": False, "frontend_required": True})
|
||||
|
||||
@field_validator('isSystem', mode='before')
|
||||
@classmethod
|
||||
def _coerceIsSystem(cls, v):
|
||||
"""Existing records may have isSystem=None (field didn't exist). Treat None as False."""
|
||||
if v is None:
|
||||
return False
|
||||
return v
|
||||
registerModelLabels(
|
||||
"Prompt",
|
||||
{"en": "Prompt", "fr": "Invite"},
|
||||
{
|
||||
"id": {"en": "ID", "fr": "ID"},
|
||||
"mandateId": {"en": "Mandate ID", "fr": "ID du mandat"},
|
||||
"isSystem": {"en": "System", "fr": "Système"},
|
||||
"content": {"en": "Content", "fr": "Contenu"},
|
||||
"name": {"en": "Name", "fr": "Nom"},
|
||||
},
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ class AutomationDefinition(BaseModel):
|
|||
eventId: Optional[str] = Field(None, description="Event ID from event management (None if not registered)", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
|
||||
status: Optional[str] = Field(None, description="Status: 'active' if event is registered, 'inactive' if not (computed, readonly)", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
|
||||
executionLogs: List[Dict[str, Any]] = Field(default_factory=list, description="List of execution logs, each containing timestamp, workflowId, status, and messages", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
|
||||
allowedProviders: List[str] = Field(default_factory=list, description="List of allowed AICore providers (e.g., 'anthropic', 'openai'). Empty means all RBAC-permitted providers are allowed.", json_schema_extra={"frontend_type": "multiselect", "frontend_readonly": False, "frontend_required": False})
|
||||
|
||||
|
||||
registerModelLabels(
|
||||
|
|
@ -42,6 +43,7 @@ registerModelLabels(
|
|||
"eventId": {"en": "Event ID", "ge": "Event-ID", "fr": "ID de l'événement"},
|
||||
"status": {"en": "Status", "ge": "Status", "fr": "Statut"},
|
||||
"executionLogs": {"en": "Execution Logs", "ge": "Ausführungsprotokolle", "fr": "Journaux d'exécution"},
|
||||
"allowedProviders": {"en": "Allowed Providers", "ge": "Erlaubte Provider", "fr": "Fournisseurs autorisés"},
|
||||
},
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ Uses the PostgreSQL connector for data access with user/mandate filtering.
|
|||
import logging
|
||||
import uuid
|
||||
import math
|
||||
import asyncio
|
||||
from typing import Dict, Any, List, Optional, Union
|
||||
|
||||
from modules.security.rbac import RbacClass
|
||||
|
|
@ -69,8 +68,6 @@ class AutomationObjects:
|
|||
userId=self.userId,
|
||||
)
|
||||
|
||||
# Initialize database system
|
||||
self.db.initDbSystem()
|
||||
logger.debug(f"Automation database initialized for user {self.userId}")
|
||||
|
||||
def setUserContext(self, currentUser: User, mandateId: Optional[str] = None, featureInstanceId: Optional[str] = None):
|
||||
|
|
@ -88,7 +85,9 @@ class AutomationObjects:
|
|||
permissions = self.rbac.getUserPermissions(
|
||||
user=self.currentUser,
|
||||
context=AccessRuleContext.DATA,
|
||||
item=objectKey
|
||||
item=objectKey,
|
||||
mandateId=self.mandateId,
|
||||
featureInstanceId=self.featureInstanceId
|
||||
)
|
||||
|
||||
accessLevel = getattr(permissions, action, AccessLevel.NONE)
|
||||
|
|
@ -99,7 +98,7 @@ class AutomationObjects:
|
|||
return True
|
||||
elif accessLevel == AccessLevel.MY:
|
||||
if recordId:
|
||||
record = self.db.getRecordset(model, {"id": recordId})
|
||||
record = self.db.getRecordset(model, recordFilter={"id": recordId})
|
||||
if record:
|
||||
return record[0].get("_createdBy") == self.userId
|
||||
else:
|
||||
|
|
@ -118,17 +117,17 @@ class AutomationObjects:
|
|||
|
||||
def _enrichAutomationsWithUserAndMandate(self, automations: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Batch enrich automations with user names and mandate names for display.
|
||||
Uses AppObjects interface to fetch users and mandates with proper access control.
|
||||
Batch enrich automations with user names, mandate names and feature instance labels.
|
||||
Uses direct DB lookup (no RBAC) because this is purely cosmetic enrichment —
|
||||
the user already has RBAC-verified access to the automations themselves.
|
||||
"""
|
||||
if not automations:
|
||||
return automations
|
||||
|
||||
from modules.interfaces.interfaceDbApp import getInterface as getAppInterface
|
||||
|
||||
# Collect all unique user IDs and mandate IDs
|
||||
# Collect all unique IDs
|
||||
userIds = set()
|
||||
mandateIds = set()
|
||||
featureInstanceIds = set()
|
||||
|
||||
for automation in automations:
|
||||
createdBy = automation.get("_createdBy")
|
||||
|
|
@ -138,37 +137,63 @@ class AutomationObjects:
|
|||
mandateId = automation.get("mandateId")
|
||||
if mandateId:
|
||||
mandateIds.add(mandateId)
|
||||
|
||||
featureInstanceId = automation.get("featureInstanceId")
|
||||
if featureInstanceId:
|
||||
featureInstanceIds.add(featureInstanceId)
|
||||
|
||||
# Use AppObjects interface to fetch users (respects access control)
|
||||
appInterface = getAppInterface(self.currentUser)
|
||||
# Use root DB connector for display-only lookups (no RBAC needed)
|
||||
usersMap = {}
|
||||
if userIds:
|
||||
for userId in userIds:
|
||||
user = appInterface.getUser(userId)
|
||||
if user:
|
||||
usersMap[userId] = user.username or user.email or userId
|
||||
|
||||
# Use AppObjects interface to fetch mandates (respects access control)
|
||||
mandatesMap = {}
|
||||
if mandateIds:
|
||||
for mandateId in mandateIds:
|
||||
mandate = appInterface.getMandate(mandateId)
|
||||
if mandate:
|
||||
mandatesMap[mandateId] = mandate.name or mandateId
|
||||
featureInstancesMap = {}
|
||||
try:
|
||||
from modules.datamodels.datamodelUam import UserInDB, Mandate
|
||||
from modules.datamodels.datamodelFeatures import FeatureInstance
|
||||
from modules.security.rootAccess import getRootDbAppConnector
|
||||
dbAppConn = getRootDbAppConnector()
|
||||
|
||||
# Batch fetch user display names
|
||||
if userIds:
|
||||
for userId in userIds:
|
||||
users = dbAppConn.getRecordset(UserInDB, recordFilter={"id": userId})
|
||||
if users:
|
||||
user = users[0]
|
||||
displayName = user.get("fullName") or user.get("username") or user.get("email") or None
|
||||
if displayName:
|
||||
usersMap[userId] = displayName
|
||||
|
||||
# Batch fetch mandate display names
|
||||
if mandateIds:
|
||||
for mandateId in mandateIds:
|
||||
mandates = dbAppConn.getRecordset(Mandate, recordFilter={"id": mandateId})
|
||||
if mandates:
|
||||
label = mandates[0].get("label") or mandates[0].get("name") or None
|
||||
if label:
|
||||
mandatesMap[mandateId] = label
|
||||
|
||||
# Batch fetch feature instance labels
|
||||
if featureInstanceIds:
|
||||
for fiId in featureInstanceIds:
|
||||
instances = dbAppConn.getRecordset(FeatureInstance, recordFilter={"id": fiId})
|
||||
if instances:
|
||||
fi = instances[0]
|
||||
label = fi.get("label") or fi.get("featureCode") or None
|
||||
if label:
|
||||
featureInstancesMap[fiId] = label
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not enrich automations with display names: {e}")
|
||||
|
||||
# Enrich each automation with the fetched data
|
||||
# SECURITY: Never show a fallback name — if lookup fails, show empty string
|
||||
for automation in automations:
|
||||
createdBy = automation.get("_createdBy")
|
||||
if createdBy:
|
||||
automation["_createdByUserName"] = usersMap.get(createdBy, createdBy)
|
||||
else:
|
||||
automation["_createdByUserName"] = "-"
|
||||
automation["_createdByUserName"] = usersMap.get(createdBy, "") if createdBy else ""
|
||||
|
||||
mandateId = automation.get("mandateId")
|
||||
if mandateId:
|
||||
automation["mandateName"] = mandatesMap.get(mandateId, mandateId)
|
||||
else:
|
||||
automation["mandateName"] = "-"
|
||||
automation["mandateName"] = mandatesMap.get(mandateId, "") if mandateId else ""
|
||||
|
||||
featureInstanceId = automation.get("featureInstanceId")
|
||||
automation["featureInstanceName"] = featureInstancesMap.get(featureInstanceId, "") if featureInstanceId else ""
|
||||
|
||||
return automations
|
||||
|
||||
|
|
@ -185,11 +210,13 @@ class AutomationObjects:
|
|||
Supports optional pagination, sorting, and filtering.
|
||||
Computes status field for each automation.
|
||||
"""
|
||||
# Use RBAC filtering
|
||||
# AutomationDefinitions can belong to any feature instance within a mandate.
|
||||
# Filter by mandateId only — not by featureInstanceId — to show all definitions across features.
|
||||
filteredAutomations = getRecordsetWithRBAC(
|
||||
self.db,
|
||||
AutomationDefinition,
|
||||
self.currentUser
|
||||
self.currentUser,
|
||||
mandateId=self.mandateId
|
||||
)
|
||||
|
||||
# Compute status for each automation and normalize executionLogs
|
||||
|
|
@ -272,12 +299,14 @@ class AutomationObjects:
|
|||
If False (default), returns Pydantic model without system fields.
|
||||
"""
|
||||
try:
|
||||
# Use RBAC filtering
|
||||
# AutomationDefinitions can belong to any feature instance within a mandate.
|
||||
# Filter by mandateId only — not by featureInstanceId.
|
||||
filtered = getRecordsetWithRBAC(
|
||||
self.db,
|
||||
AutomationDefinition,
|
||||
self.currentUser,
|
||||
recordFilter={"id": automationId}
|
||||
recordFilter={"id": automationId},
|
||||
mandateId=self.mandateId
|
||||
)
|
||||
|
||||
if not filtered:
|
||||
|
|
@ -353,8 +382,8 @@ class AutomationObjects:
|
|||
if createdAutomation.get("executionLogs") is None:
|
||||
createdAutomation["executionLogs"] = []
|
||||
|
||||
# Trigger automation change callback (async, don't wait)
|
||||
asyncio.create_task(self._notifyAutomationChanged())
|
||||
# Trigger automation change callback
|
||||
self._notifyAutomationChanged()
|
||||
|
||||
# Clean metadata fields and return Pydantic model
|
||||
cleanedRecord = {k: v for k, v in createdAutomation.items() if not k.startswith("_")}
|
||||
|
|
@ -363,6 +392,21 @@ class AutomationObjects:
|
|||
logger.error(f"Error creating automation definition: {str(e)}")
|
||||
raise
|
||||
|
||||
def _saveExecutionLog(self, automationId: str, executionLogs: List[Dict[str, Any]]) -> None:
|
||||
"""
|
||||
Save execution logs to an automation definition WITHOUT RBAC check.
|
||||
|
||||
This is a system-level operation: when a user executes an automation,
|
||||
the execution log must be saved regardless of whether the user has
|
||||
'update' permission on the AutomationDefinition. The user already
|
||||
proved they have execute/read access by loading the automation.
|
||||
"""
|
||||
try:
|
||||
self.db.recordModify(AutomationDefinition, automationId, {"executionLogs": executionLogs})
|
||||
logger.debug(f"Saved execution log for automation {automationId}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not save execution log for automation {automationId}: {e}")
|
||||
|
||||
def updateAutomationDefinition(self, automationId: str, automationData: Dict[str, Any]) -> AutomationDefinition:
|
||||
"""Updates an automation definition, then triggers sync."""
|
||||
try:
|
||||
|
|
@ -383,8 +427,8 @@ class AutomationObjects:
|
|||
if updatedAutomation.get("executionLogs") is None:
|
||||
updatedAutomation["executionLogs"] = []
|
||||
|
||||
# Trigger automation change callback (async, don't wait)
|
||||
asyncio.create_task(self._notifyAutomationChanged())
|
||||
# Trigger automation change callback
|
||||
self._notifyAutomationChanged()
|
||||
|
||||
# Clean metadata fields and return Pydantic model
|
||||
cleanedRecord = {k: v for k, v in updatedAutomation.items() if not k.startswith("_")}
|
||||
|
|
@ -407,8 +451,8 @@ class AutomationObjects:
|
|||
# Delete automation from database
|
||||
self.db.recordDelete(AutomationDefinition, automationId)
|
||||
|
||||
# Trigger automation change callback (async, don't wait)
|
||||
asyncio.create_task(self._notifyAutomationChanged())
|
||||
# Trigger automation change callback
|
||||
self._notifyAutomationChanged()
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
|
|
@ -429,7 +473,9 @@ class AutomationObjects:
|
|||
return getRecordsetWithRBAC(
|
||||
self.db,
|
||||
AutomationDefinition,
|
||||
user
|
||||
user,
|
||||
mandateId=self.mandateId,
|
||||
featureInstanceId=self.featureInstanceId
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
|
|
@ -441,7 +487,7 @@ class AutomationObjects:
|
|||
Returns automation templates filtered by RBAC (MY = own templates).
|
||||
Supports optional pagination, sorting, and filtering.
|
||||
"""
|
||||
# Use RBAC filtering
|
||||
# Templates are global (not mandate/feature-instance scoped) — no mandateId/featureInstanceId filter
|
||||
filteredTemplates = getRecordsetWithRBAC(
|
||||
self.db,
|
||||
AutomationTemplate,
|
||||
|
|
@ -501,23 +547,24 @@ class AutomationObjects:
|
|||
|
||||
userNameMap = {}
|
||||
for userId in userIds:
|
||||
users = dbAppConn.getRecordset(UserInDB, {"id": userId})
|
||||
users = dbAppConn.getRecordset(UserInDB, recordFilter={"id": userId})
|
||||
if users:
|
||||
user = users[0]
|
||||
fullName = f"{user.get('firstName', '')} {user.get('lastName', '')}".strip()
|
||||
userNameMap[userId] = fullName or user.get("email", "Unknown")
|
||||
displayName = user.get("fullName") or user.get("username") or user.get("email") or None
|
||||
if displayName:
|
||||
userNameMap[userId] = displayName
|
||||
|
||||
# Apply to templates
|
||||
# Apply to templates — SECURITY: no fallback, empty if not found
|
||||
for template in templates:
|
||||
createdBy = template.get("_createdBy")
|
||||
if createdBy and createdBy in userNameMap:
|
||||
template["_createdByUserName"] = userNameMap[createdBy]
|
||||
template["_createdByUserName"] = userNameMap.get(createdBy, "") if createdBy else ""
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not enrich templates with user names: {e}")
|
||||
|
||||
def getAutomationTemplate(self, templateId: str) -> Optional[Dict[str, Any]]:
|
||||
"""Returns an automation template by ID if user has access."""
|
||||
try:
|
||||
# Templates are global — no mandateId/featureInstanceId filter
|
||||
filtered = getRecordsetWithRBAC(
|
||||
self.db,
|
||||
AutomationTemplate,
|
||||
|
|
@ -620,12 +667,13 @@ class AutomationObjects:
|
|||
logger.error(f"Error deleting automation template: {str(e)}")
|
||||
raise
|
||||
|
||||
async def _notifyAutomationChanged(self):
|
||||
"""Notify registered callbacks about automation changes (decoupled from features)."""
|
||||
def _notifyAutomationChanged(self):
|
||||
"""Notify registered callbacks about automation changes (decoupled from features).
|
||||
Sync-safe: works from both sync and async contexts."""
|
||||
try:
|
||||
from modules.shared.callbackRegistry import callbackRegistry
|
||||
# Trigger callbacks without knowing which features are listening
|
||||
await callbackRegistry.trigger('automation.changed', self)
|
||||
callbackRegistry.trigger('automation.changed', self)
|
||||
except Exception as e:
|
||||
logger.error(f"Error notifying automation change: {str(e)}")
|
||||
|
||||
|
|
|
|||
|
|
@ -98,7 +98,7 @@ TEMPLATE_ROLES = [
|
|||
"fr": "Visualiseur automatisation - Consulter les automatisations et résultats"
|
||||
},
|
||||
"accessRules": [
|
||||
# UI access to view only - vollqualifizierte ObjectKeys
|
||||
# UI access to view only
|
||||
{"context": "UI", "item": "ui.feature.automation.definitions", "view": True},
|
||||
{"context": "UI", "item": "ui.feature.automation.logs", "view": True},
|
||||
# Read-only DATA access (my level)
|
||||
|
|
@ -113,7 +113,8 @@ def getFeatureDefinition() -> Dict[str, Any]:
|
|||
return {
|
||||
"code": FEATURE_CODE,
|
||||
"label": FEATURE_LABEL,
|
||||
"icon": FEATURE_ICON
|
||||
"icon": FEATURE_ICON,
|
||||
"autoCreateInstance": True, # Automatically create instance in root mandate during bootstrap
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -161,9 +162,131 @@ def registerFeature(catalogService) -> bool:
|
|||
meta=resObj.get("meta")
|
||||
)
|
||||
|
||||
# Sync template roles to database
|
||||
_syncTemplateRolesToDb()
|
||||
|
||||
logger.info(f"Feature '{FEATURE_CODE}' registered {len(UI_OBJECTS)} UI objects and {len(RESOURCE_OBJECTS)} resource objects")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to register feature '{FEATURE_CODE}': {e}")
|
||||
return False
|
||||
|
||||
|
||||
def _syncTemplateRolesToDb() -> int:
|
||||
"""
|
||||
Sync template roles and their AccessRules to the database.
|
||||
Creates global template roles (mandateId=None) if they don't exist.
|
||||
|
||||
Returns:
|
||||
Number of roles created/updated
|
||||
"""
|
||||
try:
|
||||
from modules.interfaces.interfaceDbApp import getRootInterface
|
||||
from modules.datamodels.datamodelRbac import Role, AccessRule, AccessRuleContext
|
||||
|
||||
rootInterface = getRootInterface()
|
||||
|
||||
# Get existing template roles for this feature (Pydantic models)
|
||||
existingRoles = rootInterface.getRolesByFeatureCode(FEATURE_CODE)
|
||||
# Filter to template roles (mandateId is None)
|
||||
templateRoles = [r for r in existingRoles if r.mandateId is None]
|
||||
existingRoleLabels = {r.roleLabel: str(r.id) for r in templateRoles}
|
||||
|
||||
createdCount = 0
|
||||
for roleTemplate in TEMPLATE_ROLES:
|
||||
roleLabel = roleTemplate["roleLabel"]
|
||||
|
||||
if roleLabel in existingRoleLabels:
|
||||
roleId = existingRoleLabels[roleLabel]
|
||||
# Ensure AccessRules exist for this role
|
||||
_ensureAccessRulesForRole(rootInterface, roleId, roleTemplate.get("accessRules", []))
|
||||
else:
|
||||
# Create new template role
|
||||
newRole = Role(
|
||||
roleLabel=roleLabel,
|
||||
description=roleTemplate.get("description", {}),
|
||||
featureCode=FEATURE_CODE,
|
||||
mandateId=None, # Global template
|
||||
featureInstanceId=None,
|
||||
isSystemRole=False
|
||||
)
|
||||
createdRole = rootInterface.db.recordCreate(Role, newRole.model_dump())
|
||||
roleId = createdRole.get("id")
|
||||
|
||||
# Create AccessRules for this role
|
||||
_ensureAccessRulesForRole(rootInterface, roleId, roleTemplate.get("accessRules", []))
|
||||
|
||||
logger.info(f"Created template role '{roleLabel}' with ID {roleId}")
|
||||
createdCount += 1
|
||||
|
||||
if createdCount > 0:
|
||||
logger.info(f"Feature '{FEATURE_CODE}': Created {createdCount} template roles")
|
||||
|
||||
return createdCount
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error syncing template roles for feature '{FEATURE_CODE}': {e}")
|
||||
return 0
|
||||
|
||||
|
||||
def _ensureAccessRulesForRole(rootInterface, roleId: str, ruleTemplates: List[Dict[str, Any]]) -> int:
|
||||
"""
|
||||
Ensure AccessRules exist for a role based on templates.
|
||||
|
||||
Args:
|
||||
rootInterface: Root interface instance
|
||||
roleId: Role ID
|
||||
ruleTemplates: List of rule templates
|
||||
|
||||
Returns:
|
||||
Number of rules created
|
||||
"""
|
||||
from modules.datamodels.datamodelRbac import AccessRule, AccessRuleContext
|
||||
|
||||
# Get existing rules for this role (Pydantic models)
|
||||
existingRules = rootInterface.getAccessRulesByRole(roleId)
|
||||
|
||||
# Create a set of existing rule signatures to avoid duplicates
|
||||
# IMPORTANT: Use .value for enum comparison, not str() which gives "AccessRuleContext.DATA" in Python 3.11+
|
||||
existingSignatures = set()
|
||||
for rule in existingRules:
|
||||
sig = (rule.context.value if rule.context else None, rule.item)
|
||||
existingSignatures.add(sig)
|
||||
|
||||
createdCount = 0
|
||||
for template in ruleTemplates:
|
||||
context = template.get("context", "UI")
|
||||
item = template.get("item")
|
||||
sig = (context, item)
|
||||
|
||||
if sig in existingSignatures:
|
||||
continue
|
||||
|
||||
# Map context string to enum
|
||||
if context == "UI":
|
||||
contextEnum = AccessRuleContext.UI
|
||||
elif context == "DATA":
|
||||
contextEnum = AccessRuleContext.DATA
|
||||
elif context == "RESOURCE":
|
||||
contextEnum = AccessRuleContext.RESOURCE
|
||||
else:
|
||||
contextEnum = context
|
||||
|
||||
newRule = AccessRule(
|
||||
roleId=roleId,
|
||||
context=contextEnum,
|
||||
item=item,
|
||||
view=template.get("view", False),
|
||||
read=template.get("read"),
|
||||
create=template.get("create"),
|
||||
update=template.get("update"),
|
||||
delete=template.get("delete"),
|
||||
)
|
||||
rootInterface.db.recordCreate(AccessRule, newRule.model_dump())
|
||||
createdCount += 1
|
||||
|
||||
if createdCount > 0:
|
||||
logger.debug(f"Created {createdCount} AccessRules for role {roleId}")
|
||||
|
||||
return createdCount
|
||||
|
|
|
|||
|
|
@ -19,8 +19,6 @@ from modules.features.automation.datamodelFeatureAutomation import AutomationDef
|
|||
from modules.datamodels.datamodelChat import ChatWorkflow
|
||||
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata, normalize_pagination_dict
|
||||
from modules.shared.attributeUtils import getModelAttributeDefinitions
|
||||
from modules.workflows.automation import executeAutomation
|
||||
|
||||
# Configure logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -42,7 +40,7 @@ router = APIRouter(
|
|||
|
||||
@router.get("", response_model=PaginatedResponse[AutomationDefinition])
|
||||
@limiter.limit("30/minute")
|
||||
async def get_automations(
|
||||
def get_automations(
|
||||
request: Request,
|
||||
pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
|
|
@ -68,7 +66,9 @@ async def get_automations(
|
|||
detail=f"Invalid pagination parameter: {str(e)}"
|
||||
)
|
||||
|
||||
chatInterface = getAutomationInterface(context.user, mandateId=str(context.mandateId) if context.mandateId else None, featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None)
|
||||
# AutomationDefinitions can belong to ANY feature instance within a mandate.
|
||||
# The list endpoint must show all definitions for the user's mandate, not filter by a specific featureInstanceId.
|
||||
chatInterface = getAutomationInterface(context.user, mandateId=str(context.mandateId) if context.mandateId else None)
|
||||
result = chatInterface.getAllAutomationDefinitions(pagination=paginationParams)
|
||||
|
||||
# If pagination was requested, result is PaginatedResult
|
||||
|
|
@ -107,7 +107,7 @@ async def get_automations(
|
|||
|
||||
@router.post("", response_model=AutomationDefinition)
|
||||
@limiter.limit("10/minute")
|
||||
async def create_automation(
|
||||
def create_automation(
|
||||
request: Request,
|
||||
automation: AutomationDefinition,
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
|
|
@ -128,7 +128,7 @@ async def create_automation(
|
|||
)
|
||||
|
||||
@router.get("/attributes", response_model=Dict[str, Any])
|
||||
async def get_automation_attributes(
|
||||
def get_automation_attributes(
|
||||
request: Request
|
||||
) -> Dict[str, Any]:
|
||||
"""Get attribute definitions for AutomationDefinition model"""
|
||||
|
|
@ -137,7 +137,7 @@ async def get_automation_attributes(
|
|||
|
||||
@router.get("/actions")
|
||||
@limiter.limit("30/minute")
|
||||
async def get_available_actions(
|
||||
def get_available_actions(
|
||||
request: Request,
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
) -> JSONResponse:
|
||||
|
|
@ -152,7 +152,7 @@ async def get_available_actions(
|
|||
# Ensure methods are discovered (need a service center for discovery)
|
||||
if not methods:
|
||||
# Create a lightweight service center for method discovery
|
||||
services = getServices(context.user, context.mandateId)
|
||||
services = getServices(context.user, mandateId=context.mandateId)
|
||||
discoverMethods(services)
|
||||
|
||||
actionsList = []
|
||||
|
|
@ -230,14 +230,14 @@ async def get_available_actions(
|
|||
|
||||
@router.get("/{automationId}", response_model=AutomationDefinition)
|
||||
@limiter.limit("30/minute")
|
||||
async def get_automation(
|
||||
def get_automation(
|
||||
request: Request,
|
||||
automationId: str = Path(..., description="Automation ID"),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
) -> AutomationDefinition:
|
||||
"""Get a single automation definition by ID"""
|
||||
try:
|
||||
chatInterface = getAutomationInterface(context.user, mandateId=str(context.mandateId) if context.mandateId else None, featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None)
|
||||
chatInterface = getAutomationInterface(context.user, mandateId=str(context.mandateId) if context.mandateId else None)
|
||||
automation = chatInterface.getAutomationDefinition(automationId)
|
||||
if not automation:
|
||||
raise HTTPException(
|
||||
|
|
@ -257,7 +257,7 @@ async def get_automation(
|
|||
|
||||
@router.put("/{automationId}", response_model=AutomationDefinition)
|
||||
@limiter.limit("10/minute")
|
||||
async def update_automation(
|
||||
def update_automation(
|
||||
request: Request,
|
||||
automationId: str = Path(..., description="Automation ID"),
|
||||
automation: AutomationDefinition = Body(...),
|
||||
|
|
@ -265,7 +265,7 @@ async def update_automation(
|
|||
) -> AutomationDefinition:
|
||||
"""Update an automation definition"""
|
||||
try:
|
||||
chatInterface = getAutomationInterface(context.user, mandateId=str(context.mandateId) if context.mandateId else None, featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None)
|
||||
chatInterface = getAutomationInterface(context.user, mandateId=str(context.mandateId) if context.mandateId else None)
|
||||
automationData = automation.model_dump()
|
||||
updated = chatInterface.updateAutomationDefinition(automationId, automationData)
|
||||
return updated
|
||||
|
|
@ -285,7 +285,7 @@ async def update_automation(
|
|||
|
||||
@router.patch("/{automationId}/status")
|
||||
@limiter.limit("30/minute")
|
||||
async def update_automation_status(
|
||||
def update_automation_status(
|
||||
request: Request,
|
||||
automationId: str = Path(..., description="Automation ID"),
|
||||
active: bool = Body(..., embed=True),
|
||||
|
|
@ -293,7 +293,7 @@ async def update_automation_status(
|
|||
) -> AutomationDefinition:
|
||||
"""Update only the active status of an automation definition"""
|
||||
try:
|
||||
chatInterface = getAutomationInterface(context.user, mandateId=str(context.mandateId) if context.mandateId else None, featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None)
|
||||
chatInterface = getAutomationInterface(context.user, mandateId=str(context.mandateId) if context.mandateId else None)
|
||||
|
||||
# Get existing automation
|
||||
automation = chatInterface.getAutomationDefinition(automationId)
|
||||
|
|
@ -326,14 +326,14 @@ async def update_automation_status(
|
|||
|
||||
@router.delete("/{automationId}")
|
||||
@limiter.limit("10/minute")
|
||||
async def delete_automation(
|
||||
def delete_automation(
|
||||
request: Request,
|
||||
automationId: str = Path(..., description="Automation ID"),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
) -> Response:
|
||||
"""Delete an automation definition"""
|
||||
try:
|
||||
chatInterface = getAutomationInterface(context.user, mandateId=str(context.mandateId) if context.mandateId else None, featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None)
|
||||
chatInterface = getAutomationInterface(context.user, mandateId=str(context.mandateId) if context.mandateId else None)
|
||||
success = chatInterface.deleteAutomationDefinition(automationId)
|
||||
if success:
|
||||
return Response(status_code=204)
|
||||
|
|
@ -366,8 +366,15 @@ async def execute_automation_route(
|
|||
"""Execute an automation immediately (test mode)"""
|
||||
try:
|
||||
from modules.services import getInterface as getServices
|
||||
services = getServices(context.user, context.mandateId)
|
||||
workflow = await executeAutomation(automationId, services)
|
||||
services = getServices(context.user, mandateId=context.mandateId, featureInstanceId=context.featureInstanceId)
|
||||
|
||||
# Load automation with current user's context (user has RBAC permissions via UI)
|
||||
automation = services.interfaceDbAutomation.getAutomationDefinition(automationId, includeSystemFields=True)
|
||||
if not automation:
|
||||
raise ValueError(f"Automation {automationId} not found")
|
||||
|
||||
from modules.workflows.automation import executeAutomation
|
||||
workflow = await executeAutomation(automationId, automation, context.user, services)
|
||||
return workflow
|
||||
except HTTPException:
|
||||
raise
|
||||
|
|
@ -407,7 +414,7 @@ templateAttributes = getModelAttributeDefinitions(AutomationTemplate)
|
|||
|
||||
@templateRouter.get("", response_model=PaginatedResponse[AutomationTemplate])
|
||||
@limiter.limit("30/minute")
|
||||
async def get_db_templates(
|
||||
def get_db_templates(
|
||||
request: Request,
|
||||
pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
|
|
@ -470,7 +477,7 @@ async def get_db_templates(
|
|||
|
||||
|
||||
@templateRouter.get("/attributes", response_model=Dict[str, Any])
|
||||
async def get_template_attributes(
|
||||
def get_template_attributes(
|
||||
request: Request
|
||||
) -> Dict[str, Any]:
|
||||
"""Get attribute definitions for AutomationTemplate model"""
|
||||
|
|
@ -479,7 +486,7 @@ async def get_template_attributes(
|
|||
|
||||
@templateRouter.get("/{templateId}")
|
||||
@limiter.limit("30/minute")
|
||||
async def get_db_template(
|
||||
def get_db_template(
|
||||
request: Request,
|
||||
templateId: str = Path(..., description="Template ID"),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
|
|
@ -511,7 +518,7 @@ async def get_db_template(
|
|||
|
||||
@templateRouter.post("")
|
||||
@limiter.limit("10/minute")
|
||||
async def create_db_template(
|
||||
def create_db_template(
|
||||
request: Request,
|
||||
templateData: Dict[str, Any] = Body(...),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
|
|
@ -542,7 +549,7 @@ async def create_db_template(
|
|||
|
||||
@templateRouter.put("/{templateId}")
|
||||
@limiter.limit("10/minute")
|
||||
async def update_db_template(
|
||||
def update_db_template(
|
||||
request: Request,
|
||||
templateId: str = Path(..., description="Template ID"),
|
||||
templateData: Dict[str, Any] = Body(...),
|
||||
|
|
@ -574,7 +581,7 @@ async def update_db_template(
|
|||
|
||||
@templateRouter.delete("/{templateId}")
|
||||
@limiter.limit("10/minute")
|
||||
async def delete_db_template(
|
||||
def delete_db_template(
|
||||
request: Request,
|
||||
templateId: str = Path(..., description="Template ID"),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
|
|
|
|||
|
|
@ -2,8 +2,14 @@
|
|||
# All rights reserved.
|
||||
"""
|
||||
Chatbot feature - LangGraph-based chatbot implementation.
|
||||
Lazy-loaded to avoid importing langgraph/langchain at boot time.
|
||||
"""
|
||||
|
||||
from .service import chatProcess
|
||||
|
||||
async def chatProcess(*args, **kwargs):
|
||||
"""Lazy wrapper - imports the real chatProcess on first call to defer langgraph loading."""
|
||||
from .service import chatProcess as _chatProcess
|
||||
return await _chatProcess(*args, **kwargs)
|
||||
|
||||
|
||||
__all__ = ['chatProcess']
|
||||
|
|
|
|||
|
|
@ -329,9 +329,6 @@ class ChatObjects:
|
|||
userId=self.userId
|
||||
)
|
||||
|
||||
# Initialize database system
|
||||
self.db.initDbSystem()
|
||||
|
||||
logger.info("Database initialized successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize database: {str(e)}")
|
||||
|
|
@ -363,10 +360,12 @@ class ChatObjects:
|
|||
return False
|
||||
|
||||
tableName = modelClass.__name__
|
||||
from modules.interfaces.interfaceRbac import buildDataObjectKey
|
||||
objectKey = buildDataObjectKey(tableName, featureCode=self.featureCode if hasattr(self, 'featureCode') else None)
|
||||
permissions = self.rbac.getUserPermissions(
|
||||
self.currentUser,
|
||||
AccessRuleContext.DATA,
|
||||
tableName,
|
||||
objectKey,
|
||||
mandateId=self.mandateId,
|
||||
featureInstanceId=self.featureInstanceId
|
||||
)
|
||||
|
|
@ -1116,7 +1115,7 @@ class ChatObjects:
|
|||
|
||||
# Emit message event for streaming (if event manager is available)
|
||||
try:
|
||||
from modules.features.chatbot.eventManager import get_event_manager
|
||||
from modules.features.chatbot.eventManager import get_event_manager # type: ignore
|
||||
event_manager = get_event_manager()
|
||||
message_timestamp = parseTimestamp(chat_message.publishedAt, default=getUtcTimestamp())
|
||||
# Emit message event in exact chatData format: {type, createdAt, item}
|
||||
|
|
@ -1514,7 +1513,7 @@ class ChatObjects:
|
|||
# Only emit events for chatbot workflows, not for automation or dynamic workflows
|
||||
if workflow.workflowMode == WorkflowModeEnum.WORKFLOW_CHATBOT:
|
||||
try:
|
||||
from modules.features.chatbot.eventManager import get_event_manager
|
||||
from modules.features.chatbot.eventManager import get_event_manager # type: ignore
|
||||
event_manager = get_event_manager()
|
||||
log_timestamp = parseTimestamp(createdLog.get("timestamp"), default=getUtcTimestamp())
|
||||
# Emit log event in exact chatData format: {type, createdAt, item}
|
||||
|
|
@ -1563,8 +1562,8 @@ class ChatObjects:
|
|||
if not stats:
|
||||
return []
|
||||
|
||||
# Return all stats records sorted by creation time
|
||||
stats.sort(key=lambda x: x.get("created_at", ""))
|
||||
# Return all stats records sorted by _createdAt (system field from DB)
|
||||
stats.sort(key=lambda x: x.get("_createdAt", 0))
|
||||
# Ensure mandateId and featureInstanceId are set for each stat
|
||||
return [ChatStat(**{**stat, "mandateId": stat.get("mandateId") or self.mandateId or "", "featureInstanceId": stat.get("featureInstanceId") or self.featureInstanceId or ""}) for stat in stats]
|
||||
|
||||
|
|
@ -1680,11 +1679,12 @@ class ChatObjects:
|
|||
"item": chatLog
|
||||
})
|
||||
|
||||
# Get stats list
|
||||
# Get stats - ChatStat model now supports _createdAt via extra="allow"
|
||||
stats = self.getStats(workflowId)
|
||||
for stat in stats:
|
||||
# Apply timestamp filtering in Python
|
||||
stat_timestamp = stat.createdAt if hasattr(stat, 'createdAt') else getUtcTimestamp()
|
||||
# Use _createdAt (system field from DB, preserved via model_config extra="allow")
|
||||
stat_timestamp = getattr(stat, '_createdAt', None) or getUtcTimestamp()
|
||||
if afterTimestamp is not None and stat_timestamp <= afterTimestamp:
|
||||
continue
|
||||
|
||||
|
|
|
|||
|
|
@ -32,9 +32,6 @@ from modules.datamodels.datamodelPagination import PaginationParams, PaginatedRe
|
|||
from modules.features.chatbot import chatProcess
|
||||
from modules.features.chatbot.streaming.events import get_event_manager
|
||||
|
||||
# Import workflow control functions
|
||||
from modules.workflows.automation import chatStop
|
||||
|
||||
# Configure logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -55,7 +52,7 @@ def _getServiceChat(context: RequestContext, instanceId: Optional[str] = None):
|
|||
)
|
||||
|
||||
|
||||
async def _validateInstanceAccess(instanceId: str, context: RequestContext) -> str:
|
||||
def _validateInstanceAccess(instanceId: str, context: RequestContext) -> str:
|
||||
"""
|
||||
Validate that the user has access to the feature instance.
|
||||
Returns the mandateId for the instance.
|
||||
|
|
@ -124,7 +121,7 @@ async def stream_chatbot_start(
|
|||
- Query parameter takes precedence if both are provided
|
||||
"""
|
||||
# Validate instance access
|
||||
mandateId = await _validateInstanceAccess(instanceId, context)
|
||||
mandateId = _validateInstanceAccess(instanceId, context)
|
||||
|
||||
event_manager = get_event_manager()
|
||||
|
||||
|
|
@ -323,7 +320,7 @@ async def stop_chatbot(
|
|||
) -> ChatWorkflow:
|
||||
"""Stops a running chatbot workflow."""
|
||||
# Validate instance access
|
||||
await _validateInstanceAccess(instanceId, context)
|
||||
_validateInstanceAccess(instanceId, context)
|
||||
|
||||
try:
|
||||
# Get chatbot interface with instance context
|
||||
|
|
@ -392,7 +389,7 @@ async def stop_chatbot(
|
|||
# to prevent "threads" from being matched as a workflowId
|
||||
@router.get("/{instanceId}/threads")
|
||||
@limiter.limit("120/minute")
|
||||
async def get_chatbot_threads(
|
||||
def get_chatbot_threads(
|
||||
request: Request,
|
||||
instanceId: str = Path(..., description="Feature Instance ID"),
|
||||
workflowId: Optional[str] = Query(None, description="Optional workflow ID to get details and chat data for a specific thread"),
|
||||
|
|
@ -406,7 +403,7 @@ async def get_chatbot_threads(
|
|||
- If workflowId is not provided: Returns a paginated list of all workflows
|
||||
"""
|
||||
# Validate instance access
|
||||
mandateId = await _validateInstanceAccess(instanceId, context)
|
||||
mandateId = _validateInstanceAccess(instanceId, context)
|
||||
|
||||
try:
|
||||
interfaceDbChat = _getServiceChat(context, instanceId)
|
||||
|
|
@ -523,7 +520,7 @@ async def get_chatbot_threads(
|
|||
# NOTE: This catch-all route MUST be defined AFTER more specific routes like /threads
|
||||
@router.delete("/{instanceId}/{workflowId}", response_model=Dict[str, Any])
|
||||
@limiter.limit("120/minute")
|
||||
async def delete_chatbot(
|
||||
def delete_chatbot(
|
||||
request: Request,
|
||||
instanceId: str = Path(..., description="Feature Instance ID"),
|
||||
workflowId: str = Path(..., description="ID of the workflow to delete"),
|
||||
|
|
@ -531,7 +528,7 @@ async def delete_chatbot(
|
|||
) -> Dict[str, Any]:
|
||||
"""Deletes a chatbot workflow and its associated data."""
|
||||
# Validate instance access - if user has access to instance, they can delete their workflows
|
||||
mandateId = await _validateInstanceAccess(instanceId, context)
|
||||
mandateId = _validateInstanceAccess(instanceId, context)
|
||||
|
||||
try:
|
||||
# Get service center
|
||||
|
|
|
|||
|
|
@ -91,8 +91,10 @@ async def chatProcess(
|
|||
ChatWorkflow instance
|
||||
"""
|
||||
try:
|
||||
# Get services with mandate context
|
||||
services = getServices(currentUser, mandateId)
|
||||
# Get services with mandate and feature instance context
|
||||
services = getServices(currentUser, mandateId=mandateId, featureInstanceId=featureInstanceId)
|
||||
services.featureCode = 'chatbot'
|
||||
|
||||
interfaceDbChat = services.interfaceDbChat
|
||||
|
||||
# Get event manager and create queue if needed
|
||||
|
|
@ -698,7 +700,7 @@ async def _convert_file_ids_to_document_references(
|
|||
# Search database if not found in messages
|
||||
if not document_id:
|
||||
try:
|
||||
from modules.shared.databaseUtils import getRecordsetWithRBAC
|
||||
from modules.interfaces.interfaceRbac import getRecordsetWithRBAC
|
||||
documents = getRecordsetWithRBAC(
|
||||
services.interfaceDbChat.db,
|
||||
ChatDocument,
|
||||
|
|
|
|||
6
modules/features/chatplayground/__init__.py
Normal file
6
modules/features/chatplayground/__init__.py
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
# Copyright (c) 2025 Patrick Motsch
|
||||
# All rights reserved.
|
||||
"""
|
||||
Chat Playground Feature Container.
|
||||
Provides workflow-based chat playground functionality.
|
||||
"""
|
||||
|
|
@ -0,0 +1,145 @@
|
|||
# Copyright (c) 2025 Patrick Motsch
|
||||
# All rights reserved.
|
||||
"""
|
||||
Chat Playground Feature Interface.
|
||||
Wrapper around interfaceDbChat with feature instance context.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
from modules.datamodels.datamodelUam import User
|
||||
from modules.interfaces import interfaceDbChat
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Feature code constant
|
||||
FEATURE_CODE = "chatplayground"
|
||||
|
||||
# Singleton instances cache
|
||||
_instances: Dict[str, "ChatPlaygroundObjects"] = {}
|
||||
|
||||
|
||||
def getInterface(currentUser: User, mandateId: str = None, featureInstanceId: str = None) -> "ChatPlaygroundObjects":
|
||||
"""
|
||||
Factory function to get or create a ChatPlaygroundObjects instance.
|
||||
Uses singleton pattern per user context.
|
||||
|
||||
Args:
|
||||
currentUser: Current user object
|
||||
mandateId: Mandate ID
|
||||
featureInstanceId: Feature instance ID
|
||||
|
||||
Returns:
|
||||
ChatPlaygroundObjects instance
|
||||
"""
|
||||
cacheKey = f"{currentUser.id}_{mandateId}_{featureInstanceId}"
|
||||
|
||||
if cacheKey not in _instances:
|
||||
_instances[cacheKey] = ChatPlaygroundObjects(currentUser, mandateId, featureInstanceId)
|
||||
else:
|
||||
# Update context if needed
|
||||
_instances[cacheKey].setUserContext(currentUser, mandateId, featureInstanceId)
|
||||
|
||||
return _instances[cacheKey]
|
||||
|
||||
|
||||
class ChatPlaygroundObjects:
|
||||
"""
|
||||
Chat Playground feature interface.
|
||||
Wraps the shared interfaceDbChat with feature instance context.
|
||||
"""
|
||||
|
||||
FEATURE_CODE = FEATURE_CODE
|
||||
|
||||
def __init__(self, currentUser: User, mandateId: str = None, featureInstanceId: str = None):
|
||||
"""
|
||||
Initialize the Chat Playground interface.
|
||||
|
||||
Args:
|
||||
currentUser: Current user object
|
||||
mandateId: Mandate ID
|
||||
featureInstanceId: Feature instance ID
|
||||
"""
|
||||
self.currentUser = currentUser
|
||||
self.mandateId = mandateId
|
||||
self.featureInstanceId = featureInstanceId
|
||||
|
||||
# Get the underlying chat interface
|
||||
self._chatInterface = interfaceDbChat.getInterface(
|
||||
currentUser,
|
||||
mandateId=mandateId,
|
||||
featureInstanceId=featureInstanceId
|
||||
)
|
||||
|
||||
def setUserContext(self, currentUser: User, mandateId: str = None, featureInstanceId: str = None):
|
||||
"""
|
||||
Update the user context.
|
||||
|
||||
Args:
|
||||
currentUser: Current user object
|
||||
mandateId: Mandate ID
|
||||
featureInstanceId: Feature instance ID
|
||||
"""
|
||||
self.currentUser = currentUser
|
||||
self.mandateId = mandateId
|
||||
self.featureInstanceId = featureInstanceId
|
||||
|
||||
# Update underlying interface
|
||||
self._chatInterface = interfaceDbChat.getInterface(
|
||||
currentUser,
|
||||
mandateId=mandateId,
|
||||
featureInstanceId=featureInstanceId
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# Delegated methods from interfaceDbChat
|
||||
# =========================================================================
|
||||
|
||||
def getWorkflow(self, workflowId: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get a workflow by ID."""
|
||||
return self._chatInterface.getWorkflow(workflowId)
|
||||
|
||||
def getWorkflows(self, pagination=None) -> Dict[str, Any]:
|
||||
"""Get all workflows with pagination."""
|
||||
return self._chatInterface.getWorkflows(pagination=pagination)
|
||||
|
||||
def getUnifiedChatData(self, workflowId: str, afterTimestamp: float = None) -> Dict[str, Any]:
|
||||
"""Get unified chat data for a workflow."""
|
||||
return self._chatInterface.getUnifiedChatData(workflowId, afterTimestamp)
|
||||
|
||||
def createWorkflow(self, workflow) -> Dict[str, Any]:
|
||||
"""Create a new workflow."""
|
||||
return self._chatInterface.createWorkflow(workflow)
|
||||
|
||||
def updateWorkflow(self, workflowId: str, updates: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""Update a workflow."""
|
||||
return self._chatInterface.updateWorkflow(workflowId, updates)
|
||||
|
||||
def deleteWorkflow(self, workflowId: str) -> bool:
|
||||
"""Delete a workflow."""
|
||||
return self._chatInterface.deleteWorkflow(workflowId)
|
||||
|
||||
def getMessages(self, workflowId: str) -> List[Dict[str, Any]]:
|
||||
"""Get messages for a workflow."""
|
||||
return self._chatInterface.getMessages(workflowId)
|
||||
|
||||
def createMessage(self, message) -> Dict[str, Any]:
|
||||
"""Create a new message."""
|
||||
return self._chatInterface.createMessage(message)
|
||||
|
||||
def getLogs(self, workflowId: str) -> List[Dict[str, Any]]:
|
||||
"""Get logs for a workflow."""
|
||||
return self._chatInterface.getLogs(workflowId)
|
||||
|
||||
def createLog(self, log) -> Dict[str, Any]:
|
||||
"""Create a new log entry."""
|
||||
return self._chatInterface.createLog(log)
|
||||
|
||||
def getStats(self, workflowId: str) -> List[Dict[str, Any]]:
|
||||
"""Get stats for a workflow."""
|
||||
return self._chatInterface.getStats(workflowId)
|
||||
|
||||
def createStat(self, stat) -> Dict[str, Any]:
|
||||
"""Create a new stat entry."""
|
||||
return self._chatInterface.createStat(stat)
|
||||
288
modules/features/chatplayground/mainChatplayground.py
Normal file
288
modules/features/chatplayground/mainChatplayground.py
Normal file
|
|
@ -0,0 +1,288 @@
|
|||
# Copyright (c) 2025 Patrick Motsch
|
||||
# All rights reserved.
|
||||
"""
|
||||
Chat Playground Feature Container - Main Module.
|
||||
Handles feature initialization and RBAC catalog registration.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, List, Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Feature metadata
|
||||
FEATURE_CODE = "chatplayground"
|
||||
FEATURE_LABEL = {"en": "Chat Playground", "de": "Chat Playground", "fr": "Chat Playground"}
|
||||
FEATURE_ICON = "mdi-message-text"
|
||||
|
||||
# UI Objects for RBAC catalog
|
||||
UI_OBJECTS = [
|
||||
{
|
||||
"objectKey": "ui.feature.chatplayground.playground",
|
||||
"label": {"en": "Playground", "de": "Playground", "fr": "Playground"},
|
||||
"meta": {"area": "playground"}
|
||||
},
|
||||
{
|
||||
"objectKey": "ui.feature.chatplayground.workflows",
|
||||
"label": {"en": "Workflows", "de": "Workflows", "fr": "Workflows"},
|
||||
"meta": {"area": "workflows"}
|
||||
},
|
||||
]
|
||||
|
||||
# Resource Objects for RBAC catalog
|
||||
RESOURCE_OBJECTS = [
|
||||
{
|
||||
"objectKey": "resource.feature.chatplayground.start",
|
||||
"label": {"en": "Start Workflow", "de": "Workflow starten", "fr": "Démarrer workflow"},
|
||||
"meta": {"endpoint": "/api/chatplayground/{instanceId}/start", "method": "POST"}
|
||||
},
|
||||
{
|
||||
"objectKey": "resource.feature.chatplayground.stop",
|
||||
"label": {"en": "Stop Workflow", "de": "Workflow stoppen", "fr": "Arrêter workflow"},
|
||||
"meta": {"endpoint": "/api/chatplayground/{instanceId}/{workflowId}/stop", "method": "POST"}
|
||||
},
|
||||
{
|
||||
"objectKey": "resource.feature.chatplayground.chatData",
|
||||
"label": {"en": "Get Chat Data", "de": "Chat-Daten abrufen", "fr": "Récupérer données chat"},
|
||||
"meta": {"endpoint": "/api/chatplayground/{instanceId}/{workflowId}/chatData", "method": "GET"}
|
||||
},
|
||||
]
|
||||
|
||||
# Template roles for this feature
|
||||
# Role names MUST follow convention: {featureCode}-{roleName}
|
||||
TEMPLATE_ROLES = [
|
||||
{
|
||||
"roleLabel": "chatplayground-viewer",
|
||||
"description": {
|
||||
"en": "Chat Playground Viewer - View chat playground (read-only)",
|
||||
"de": "Chat Playground Betrachter - Chat Playground ansehen (nur lesen)",
|
||||
"fr": "Visualiseur Chat Playground - Consulter le chat playground (lecture seule)"
|
||||
},
|
||||
"accessRules": [
|
||||
# UI: only playground view, NO workflows
|
||||
{"context": "UI", "item": "ui.feature.chatplayground.playground", "view": True},
|
||||
# RESOURCE: NO access (viewer cannot start/stop/access chat data)
|
||||
# DATA access (own records, read-only)
|
||||
{"context": "DATA", "item": None, "view": True, "read": "m", "create": "n", "update": "n", "delete": "n"},
|
||||
]
|
||||
},
|
||||
{
|
||||
"roleLabel": "chatplayground-user",
|
||||
"description": {
|
||||
"en": "Chat Playground User - Use chat playground and workflows",
|
||||
"de": "Chat Playground Benutzer - Chat Playground und Workflows nutzen",
|
||||
"fr": "Utilisateur Chat Playground - Utiliser le chat playground et les workflows"
|
||||
},
|
||||
"accessRules": [
|
||||
# UI: full access to all views
|
||||
{"context": "UI", "item": "ui.feature.chatplayground.playground", "view": True},
|
||||
{"context": "UI", "item": "ui.feature.chatplayground.workflows", "view": True},
|
||||
# Resource access: can start/stop workflows and access chat data
|
||||
{"context": "RESOURCE", "item": "resource.feature.chatplayground.start", "view": True},
|
||||
{"context": "RESOURCE", "item": "resource.feature.chatplayground.stop", "view": True},
|
||||
{"context": "RESOURCE", "item": "resource.feature.chatplayground.chatData", "view": True},
|
||||
# DATA access (own records)
|
||||
{"context": "DATA", "item": None, "view": True, "read": "m", "create": "m", "update": "m", "delete": "m"},
|
||||
]
|
||||
},
|
||||
{
|
||||
"roleLabel": "chatplayground-admin",
|
||||
"description": {
|
||||
"en": "Chat Playground Admin - Full access to chat playground",
|
||||
"de": "Chat Playground Admin - Vollzugriff auf Chat Playground",
|
||||
"fr": "Administrateur Chat Playground - Accès complet au chat playground"
|
||||
},
|
||||
"accessRules": [
|
||||
# Full UI access
|
||||
{"context": "UI", "item": None, "view": True},
|
||||
# Full resource access
|
||||
{"context": "RESOURCE", "item": None, "view": True},
|
||||
# Full DATA access
|
||||
{"context": "DATA", "item": None, "view": True, "read": "a", "create": "a", "update": "a", "delete": "a"},
|
||||
]
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def getFeatureDefinition() -> Dict[str, Any]:
|
||||
"""Return the feature definition for registration."""
|
||||
return {
|
||||
"code": FEATURE_CODE,
|
||||
"label": FEATURE_LABEL,
|
||||
"icon": FEATURE_ICON,
|
||||
"autoCreateInstance": True, # Automatically create instance in root mandate during bootstrap
|
||||
}
|
||||
|
||||
|
||||
def getUiObjects() -> List[Dict[str, Any]]:
|
||||
"""Return UI objects for RBAC catalog registration."""
|
||||
return UI_OBJECTS
|
||||
|
||||
|
||||
def getResourceObjects() -> List[Dict[str, Any]]:
|
||||
"""Return resource objects for RBAC catalog registration."""
|
||||
return RESOURCE_OBJECTS
|
||||
|
||||
|
||||
def getTemplateRoles() -> List[Dict[str, Any]]:
|
||||
"""Return template roles for this feature."""
|
||||
return TEMPLATE_ROLES
|
||||
|
||||
|
||||
def registerFeature(catalogService) -> bool:
|
||||
"""
|
||||
Register this feature's RBAC objects in the catalog.
|
||||
|
||||
Args:
|
||||
catalogService: The RBAC catalog service instance
|
||||
|
||||
Returns:
|
||||
True if registration was successful
|
||||
"""
|
||||
try:
|
||||
# Register UI objects
|
||||
for uiObj in UI_OBJECTS:
|
||||
catalogService.registerUiObject(
|
||||
featureCode=FEATURE_CODE,
|
||||
objectKey=uiObj["objectKey"],
|
||||
label=uiObj["label"],
|
||||
meta=uiObj.get("meta")
|
||||
)
|
||||
|
||||
# Register Resource objects
|
||||
for resObj in RESOURCE_OBJECTS:
|
||||
catalogService.registerResourceObject(
|
||||
featureCode=FEATURE_CODE,
|
||||
objectKey=resObj["objectKey"],
|
||||
label=resObj["label"],
|
||||
meta=resObj.get("meta")
|
||||
)
|
||||
|
||||
# Sync template roles to database
|
||||
_syncTemplateRolesToDb()
|
||||
|
||||
logger.info(f"Feature '{FEATURE_CODE}' registered {len(UI_OBJECTS)} UI objects and {len(RESOURCE_OBJECTS)} resource objects")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to register feature '{FEATURE_CODE}': {e}")
|
||||
return False
|
||||
|
||||
|
||||
def _syncTemplateRolesToDb() -> int:
|
||||
"""
|
||||
Sync template roles and their AccessRules to the database.
|
||||
Creates global template roles (mandateId=None) if they don't exist.
|
||||
|
||||
Returns:
|
||||
Number of roles created/updated
|
||||
"""
|
||||
try:
|
||||
from modules.interfaces.interfaceDbApp import getRootInterface
|
||||
from modules.datamodels.datamodelRbac import Role, AccessRule, AccessRuleContext
|
||||
|
||||
rootInterface = getRootInterface()
|
||||
|
||||
# Get existing template roles for this feature (Pydantic models)
|
||||
existingRoles = rootInterface.getRolesByFeatureCode(FEATURE_CODE)
|
||||
# Filter to template roles (mandateId is None)
|
||||
templateRoles = [r for r in existingRoles if r.mandateId is None]
|
||||
existingRoleLabels = {r.roleLabel: str(r.id) for r in templateRoles}
|
||||
|
||||
createdCount = 0
|
||||
for roleTemplate in TEMPLATE_ROLES:
|
||||
roleLabel = roleTemplate["roleLabel"]
|
||||
|
||||
if roleLabel in existingRoleLabels:
|
||||
roleId = existingRoleLabels[roleLabel]
|
||||
# Ensure AccessRules exist for this role
|
||||
_ensureAccessRulesForRole(rootInterface, roleId, roleTemplate.get("accessRules", []))
|
||||
else:
|
||||
# Create new template role
|
||||
newRole = Role(
|
||||
roleLabel=roleLabel,
|
||||
description=roleTemplate.get("description", {}),
|
||||
featureCode=FEATURE_CODE,
|
||||
mandateId=None, # Global template
|
||||
featureInstanceId=None,
|
||||
isSystemRole=False
|
||||
)
|
||||
createdRole = rootInterface.db.recordCreate(Role, newRole.model_dump())
|
||||
roleId = createdRole.get("id")
|
||||
|
||||
# Create AccessRules for this role
|
||||
_ensureAccessRulesForRole(rootInterface, roleId, roleTemplate.get("accessRules", []))
|
||||
|
||||
logger.info(f"Created template role '{roleLabel}' with ID {roleId}")
|
||||
createdCount += 1
|
||||
|
||||
if createdCount > 0:
|
||||
logger.info(f"Feature '{FEATURE_CODE}': Created {createdCount} template roles")
|
||||
|
||||
return createdCount
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error syncing template roles for feature '{FEATURE_CODE}': {e}")
|
||||
return 0
|
||||
|
||||
|
||||
def _ensureAccessRulesForRole(rootInterface, roleId: str, ruleTemplates: List[Dict[str, Any]]) -> int:
|
||||
"""
|
||||
Ensure AccessRules exist for a role based on templates.
|
||||
|
||||
Args:
|
||||
rootInterface: Root interface instance
|
||||
roleId: Role ID
|
||||
ruleTemplates: List of rule templates
|
||||
|
||||
Returns:
|
||||
Number of rules created
|
||||
"""
|
||||
from modules.datamodels.datamodelRbac import AccessRule, AccessRuleContext
|
||||
|
||||
# Get existing rules for this role (Pydantic models)
|
||||
existingRules = rootInterface.getAccessRulesByRole(roleId)
|
||||
|
||||
# Create a set of existing rule signatures to avoid duplicates
|
||||
# IMPORTANT: Use .value for enum comparison, not str() which gives "AccessRuleContext.DATA" in Python 3.11+
|
||||
existingSignatures = set()
|
||||
for rule in existingRules:
|
||||
sig = (rule.context.value if rule.context else None, rule.item)
|
||||
existingSignatures.add(sig)
|
||||
|
||||
createdCount = 0
|
||||
for template in ruleTemplates:
|
||||
context = template.get("context", "UI")
|
||||
item = template.get("item")
|
||||
sig = (context, item)
|
||||
|
||||
if sig in existingSignatures:
|
||||
continue
|
||||
|
||||
# Map context string to enum
|
||||
if context == "UI":
|
||||
contextEnum = AccessRuleContext.UI
|
||||
elif context == "DATA":
|
||||
contextEnum = AccessRuleContext.DATA
|
||||
elif context == "RESOURCE":
|
||||
contextEnum = AccessRuleContext.RESOURCE
|
||||
else:
|
||||
contextEnum = context
|
||||
|
||||
newRule = AccessRule(
|
||||
roleId=roleId,
|
||||
context=contextEnum,
|
||||
item=item,
|
||||
view=template.get("view", False),
|
||||
read=template.get("read"),
|
||||
create=template.get("create"),
|
||||
update=template.get("update"),
|
||||
delete=template.get("delete"),
|
||||
)
|
||||
rootInterface.db.recordCreate(AccessRule, newRule.model_dump())
|
||||
createdCount += 1
|
||||
|
||||
if createdCount > 0:
|
||||
logger.debug(f"Created {createdCount} AccessRules for role {roleId}")
|
||||
|
||||
return createdCount
|
||||
234
modules/features/chatplayground/routeFeatureChatplayground.py
Normal file
234
modules/features/chatplayground/routeFeatureChatplayground.py
Normal file
|
|
@ -0,0 +1,234 @@
|
|||
# Copyright (c) 2025 Patrick Motsch
|
||||
# All rights reserved.
|
||||
"""
|
||||
Chat Playground Feature Routes.
|
||||
Implements the endpoints for chat playground workflow management as a feature.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
from fastapi import APIRouter, HTTPException, Depends, Body, Path, Query, Request
|
||||
|
||||
# Import auth modules
|
||||
from modules.auth import limiter, getRequestContext, RequestContext
|
||||
|
||||
# Import interfaces
|
||||
from modules.interfaces import interfaceDbChat
|
||||
|
||||
# Import models
|
||||
from modules.datamodels.datamodelChat import ChatWorkflow, UserInputRequest, WorkflowModeEnum
|
||||
|
||||
# Import workflow control functions
|
||||
from modules.workflows.automation import chatStart, chatStop
|
||||
|
||||
# Configure logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Create router for chat playground feature endpoints
|
||||
router = APIRouter(
|
||||
prefix="/api/chatplayground",
|
||||
tags=["Chat Playground Feature"],
|
||||
responses={404: {"description": "Not found"}}
|
||||
)
|
||||
|
||||
|
||||
def _getServiceChat(context: RequestContext, featureInstanceId: str = None):
|
||||
"""Get chat interface with feature instance context."""
|
||||
return interfaceDbChat.getInterface(
|
||||
context.user,
|
||||
mandateId=str(context.mandateId) if context.mandateId else None,
|
||||
featureInstanceId=featureInstanceId
|
||||
)
|
||||
|
||||
|
||||
def _validateInstanceAccess(instanceId: str, context: RequestContext) -> str:
|
||||
"""
|
||||
Validate that user has access to the feature instance.
|
||||
|
||||
Args:
|
||||
instanceId: Feature instance ID
|
||||
context: Request context
|
||||
|
||||
Returns:
|
||||
mandateId for the instance
|
||||
|
||||
Raises:
|
||||
HTTPException if access is denied
|
||||
"""
|
||||
from modules.interfaces.interfaceDbApp import getRootInterface
|
||||
|
||||
rootInterface = getRootInterface()
|
||||
|
||||
# Get feature instance (Pydantic model)
|
||||
instance = rootInterface.getFeatureInstance(instanceId)
|
||||
if not instance:
|
||||
raise HTTPException(status_code=404, detail=f"Feature instance {instanceId} not found")
|
||||
|
||||
# Check user has access to this instance using interface method
|
||||
featureAccess = rootInterface.getFeatureAccess(str(context.user.id), instanceId)
|
||||
|
||||
if not featureAccess or not featureAccess.enabled:
|
||||
raise HTTPException(status_code=403, detail="Access denied to this feature instance")
|
||||
|
||||
return str(instance.mandateId) if instance.mandateId else None
|
||||
|
||||
|
||||
# Workflow start endpoint
|
||||
@router.post("/{instanceId}/start", response_model=ChatWorkflow)
|
||||
@limiter.limit("120/minute")
|
||||
async def start_workflow(
|
||||
request: Request,
|
||||
instanceId: str = Path(..., description="Feature instance ID"),
|
||||
workflowId: Optional[str] = Query(None, description="Optional ID of the workflow to continue"),
|
||||
workflowMode: WorkflowModeEnum = Query(..., description="Workflow mode: 'Dynamic' or 'Automation' (mandatory)"),
|
||||
userInput: UserInputRequest = Body(...),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
) -> ChatWorkflow:
|
||||
"""
|
||||
Starts a new workflow or continues an existing one.
|
||||
|
||||
Args:
|
||||
instanceId: Feature instance ID
|
||||
workflowMode: "Dynamic" for iterative dynamic-style processing, "Automation" for automated workflow execution
|
||||
"""
|
||||
try:
|
||||
# Validate access and get mandate ID
|
||||
mandateId = _validateInstanceAccess(instanceId, context)
|
||||
|
||||
# Start or continue workflow
|
||||
workflow = await chatStart(
|
||||
context.user,
|
||||
userInput,
|
||||
workflowMode,
|
||||
workflowId,
|
||||
mandateId=mandateId,
|
||||
featureInstanceId=instanceId,
|
||||
featureCode='chatplayground'
|
||||
)
|
||||
|
||||
return workflow
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error in start_workflow: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=str(e)
|
||||
)
|
||||
|
||||
|
||||
# Stop workflow endpoint
|
||||
@router.post("/{instanceId}/{workflowId}/stop", response_model=ChatWorkflow)
|
||||
@limiter.limit("120/minute")
|
||||
async def stop_workflow(
|
||||
request: Request,
|
||||
instanceId: str = Path(..., description="Feature instance ID"),
|
||||
workflowId: str = Path(..., description="ID of the workflow to stop"),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
) -> ChatWorkflow:
|
||||
"""Stops a running workflow."""
|
||||
try:
|
||||
# Validate access and get mandate ID
|
||||
mandateId = _validateInstanceAccess(instanceId, context)
|
||||
|
||||
# Stop workflow (pass featureInstanceId for proper RBAC filtering)
|
||||
workflow = await chatStop(
|
||||
context.user,
|
||||
workflowId,
|
||||
mandateId=mandateId,
|
||||
featureInstanceId=instanceId
|
||||
)
|
||||
|
||||
return workflow
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error in stop_workflow: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=str(e)
|
||||
)
|
||||
|
||||
|
||||
# Unified Chat Data Endpoint for Polling
|
||||
@router.get("/{instanceId}/{workflowId}/chatData")
|
||||
@limiter.limit("120/minute")
|
||||
def get_workflow_chat_data(
|
||||
request: Request,
|
||||
instanceId: str = Path(..., description="Feature instance ID"),
|
||||
workflowId: str = Path(..., description="ID of the workflow"),
|
||||
afterTimestamp: Optional[float] = Query(None, description="Unix timestamp to get data after"),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get unified chat data (messages, logs, stats) for a workflow with timestamp-based selective data transfer.
|
||||
Returns all data types in chronological order based on _createdAt timestamp.
|
||||
"""
|
||||
try:
|
||||
# Validate access
|
||||
_validateInstanceAccess(instanceId, context)
|
||||
|
||||
# Get service with feature instance context
|
||||
chatInterface = _getServiceChat(context, featureInstanceId=instanceId)
|
||||
|
||||
# Verify workflow exists
|
||||
workflow = chatInterface.getWorkflow(workflowId)
|
||||
if not workflow:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Workflow with ID {workflowId} not found"
|
||||
)
|
||||
|
||||
# Get unified chat data
|
||||
chatData = chatInterface.getUnifiedChatData(workflowId, afterTimestamp)
|
||||
|
||||
return chatData
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting unified chat data: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Error getting unified chat data: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
# Get workflows for this instance
|
||||
@router.get("/{instanceId}/workflows")
|
||||
@limiter.limit("120/minute")
|
||||
def get_workflows(
|
||||
request: Request,
|
||||
instanceId: str = Path(..., description="Feature instance ID"),
|
||||
page: int = Query(1, ge=1, description="Page number"),
|
||||
pageSize: int = Query(20, ge=1, le=100, description="Items per page"),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get all workflows for this feature instance.
|
||||
"""
|
||||
try:
|
||||
# Validate access
|
||||
_validateInstanceAccess(instanceId, context)
|
||||
|
||||
# Get service with feature instance context
|
||||
chatInterface = _getServiceChat(context, featureInstanceId=instanceId)
|
||||
|
||||
# Get workflows with pagination
|
||||
from modules.datamodels.datamodelPagination import PaginationParams
|
||||
pagination = PaginationParams(page=page, pageSize=pageSize)
|
||||
|
||||
result = chatInterface.getWorkflows(pagination=pagination)
|
||||
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting workflows: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Error getting workflows: {str(e)}"
|
||||
)
|
||||
|
|
@ -66,7 +66,6 @@ class InterfaceFeatureNeutralizer:
|
|||
dbPort=dbPort,
|
||||
userId=self.userId,
|
||||
)
|
||||
self.db.initDbSystem()
|
||||
logger.debug("Neutralizer database initialized successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing Neutralizer database: {str(e)}")
|
||||
|
|
|
|||
|
|
@ -182,20 +182,18 @@ class SharepointProcessor:
|
|||
|
||||
async def _getSharepointConnection(self, sharepointPath: str = None):
|
||||
try:
|
||||
connections = self.services.interfaceDbApp.db.getRecordset(
|
||||
UserConnection,
|
||||
recordFilter={"userId": self.services.interfaceDbApp.userId}
|
||||
)
|
||||
msftConnections = [c for c in connections if c.get('authority') == 'msft']
|
||||
# Use interface method to get user connections
|
||||
connections = self.services.interfaceDbApp.getUserConnections(self.services.interfaceDbApp.userId)
|
||||
msftConnections = [c for c in connections if c.authority == 'msft']
|
||||
if not msftConnections:
|
||||
logger.warning('No Microsoft connections found for user')
|
||||
return None
|
||||
if len(msftConnections) == 1:
|
||||
logger.info(f"Found single Microsoft connection: {msftConnections[0].get('id')}")
|
||||
logger.info(f"Found single Microsoft connection: {msftConnections[0].id}")
|
||||
return msftConnections[0]
|
||||
if sharepointPath:
|
||||
return await self._matchConnectionToPath(msftConnections, sharepointPath)
|
||||
logger.info(f"Multiple Microsoft connections found, using first one: {msftConnections[0].get('id')}")
|
||||
logger.info(f"Multiple Microsoft connections found, using first one: {msftConnections[0].id}")
|
||||
return msftConnections[0]
|
||||
except Exception:
|
||||
logger.error('Error getting SharePoint connection')
|
||||
|
|
@ -9,7 +9,7 @@ from modules.auth import limiter, getRequestContext, RequestContext
|
|||
|
||||
# Import interfaces
|
||||
from .datamodelFeatureNeutralizer import DataNeutraliserConfig, DataNeutralizerAttributes
|
||||
from .mainNeutralizePlayground import NeutralizationPlayground
|
||||
from .neutralizePlayground import NeutralizationPlayground
|
||||
|
||||
# Configure logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -29,7 +29,7 @@ router = APIRouter(
|
|||
|
||||
@router.get("/config", response_model=DataNeutraliserConfig)
|
||||
@limiter.limit("30/minute")
|
||||
async def get_neutralization_config(
|
||||
def get_neutralization_config(
|
||||
request: Request,
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
) -> DataNeutraliserConfig:
|
||||
|
|
@ -62,7 +62,7 @@ async def get_neutralization_config(
|
|||
|
||||
@router.post("/config", response_model=DataNeutraliserConfig)
|
||||
@limiter.limit("10/minute")
|
||||
async def save_neutralization_config(
|
||||
def save_neutralization_config(
|
||||
request: Request,
|
||||
config_data: Dict[str, Any] = Body(...),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
|
|
@ -83,7 +83,7 @@ async def save_neutralization_config(
|
|||
|
||||
@router.post("/neutralize-text", response_model=Dict[str, Any])
|
||||
@limiter.limit("20/minute")
|
||||
async def neutralize_text(
|
||||
def neutralize_text(
|
||||
request: Request,
|
||||
text_data: Dict[str, Any] = Body(...),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
|
|
@ -115,7 +115,7 @@ async def neutralize_text(
|
|||
|
||||
@router.post("/resolve-text", response_model=Dict[str, str])
|
||||
@limiter.limit("20/minute")
|
||||
async def resolve_text(
|
||||
def resolve_text(
|
||||
request: Request,
|
||||
text_data: Dict[str, str] = Body(...),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
|
|
@ -146,7 +146,7 @@ async def resolve_text(
|
|||
|
||||
@router.get("/attributes", response_model=List[DataNeutralizerAttributes])
|
||||
@limiter.limit("30/minute")
|
||||
async def get_neutralization_attributes(
|
||||
def get_neutralization_attributes(
|
||||
request: Request,
|
||||
fileId: Optional[str] = Query(None, description="Filter by file ID"),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
|
|
@ -199,7 +199,7 @@ async def process_sharepoint_files(
|
|||
|
||||
@router.post("/batch-process", response_model=Dict[str, Any])
|
||||
@limiter.limit("10/minute")
|
||||
async def batch_process_files(
|
||||
def batch_process_files(
|
||||
request: Request,
|
||||
files_data: List[Dict[str, Any]] = Body(...),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
|
|
@ -228,7 +228,7 @@ async def batch_process_files(
|
|||
|
||||
@router.get("/stats", response_model=Dict[str, Any])
|
||||
@limiter.limit("30/minute")
|
||||
async def get_neutralization_stats(
|
||||
def get_neutralization_stats(
|
||||
request: Request,
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
) -> Dict[str, Any]:
|
||||
|
|
@ -248,7 +248,7 @@ async def get_neutralization_stats(
|
|||
|
||||
@router.delete("/attributes/{fileId}", response_model=Dict[str, str])
|
||||
@limiter.limit("10/minute")
|
||||
async def cleanup_file_attributes(
|
||||
def cleanup_file_attributes(
|
||||
request: Request,
|
||||
fileId: str = Path(..., description="File ID to cleanup attributes for"),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
|
|
|
|||
|
|
@ -85,11 +85,6 @@ class RealEstateObjects:
|
|||
userId=self.userId if self.userId else None,
|
||||
)
|
||||
|
||||
# Initialize database system (creates database and system table if needed)
|
||||
# Note: This is also called in DatabaseConnector.__init__, but we call it explicitly
|
||||
# for consistency with other interfaces and to ensure proper initialization
|
||||
self.db.initDbSystem()
|
||||
|
||||
# Ensure all supporting tables are created (Land, Kanton, Gemeinde, Dokument)
|
||||
# These tables are needed for foreign key relationships
|
||||
self._ensureSupportingTablesExist()
|
||||
|
|
@ -754,10 +749,12 @@ class RealEstateObjects:
|
|||
return False
|
||||
|
||||
tableName = modelClass.__name__
|
||||
from modules.interfaces.interfaceRbac import buildDataObjectKey
|
||||
objectKey = buildDataObjectKey(tableName, featureCode=self.featureCode if hasattr(self, 'featureCode') else None)
|
||||
permissions = self.rbac.getUserPermissions(
|
||||
self.currentUser,
|
||||
AccessRuleContext.DATA,
|
||||
tableName,
|
||||
objectKey,
|
||||
mandateId=self.mandateId,
|
||||
featureInstanceId=self.featureInstanceId
|
||||
)
|
||||
|
|
|
|||
|
|
@ -165,13 +165,11 @@ def _syncTemplateRolesToDb() -> int:
|
|||
from modules.datamodels.datamodelRbac import Role, AccessRule, AccessRuleContext
|
||||
|
||||
rootInterface = getRootInterface()
|
||||
db = rootInterface.db
|
||||
|
||||
existingRoles = db.getRecordset(
|
||||
Role,
|
||||
recordFilter={"featureCode": FEATURE_CODE, "mandateId": None}
|
||||
)
|
||||
existingRoleLabels = {r.get("roleLabel"): r.get("id") for r in existingRoles}
|
||||
# Get existing template roles (Pydantic models)
|
||||
existingRoles = rootInterface.getRolesByFeatureCode(FEATURE_CODE)
|
||||
templateRoles = [r for r in existingRoles if r.mandateId is None]
|
||||
existingRoleLabels = {r.roleLabel: str(r.id) for r in templateRoles}
|
||||
|
||||
createdCount = 0
|
||||
for roleTemplate in TEMPLATE_ROLES:
|
||||
|
|
@ -179,7 +177,7 @@ def _syncTemplateRolesToDb() -> int:
|
|||
|
||||
if roleLabel in existingRoleLabels:
|
||||
roleId = existingRoleLabels[roleLabel]
|
||||
_ensureAccessRulesForRole(db, roleId, roleTemplate.get("accessRules", []))
|
||||
_ensureAccessRulesForRole(rootInterface, roleId, roleTemplate.get("accessRules", []))
|
||||
else:
|
||||
newRole = Role(
|
||||
roleLabel=roleLabel,
|
||||
|
|
@ -189,65 +187,66 @@ def _syncTemplateRolesToDb() -> int:
|
|||
featureInstanceId=None,
|
||||
isSystemRole=False
|
||||
)
|
||||
createdRole = db.recordCreate(Role, newRole.model_dump())
|
||||
createdRole = rootInterface.db.recordCreate(Role, newRole.model_dump())
|
||||
roleId = createdRole.get("id")
|
||||
existingRoleLabels[roleLabel] = roleId
|
||||
_ensureAccessRulesForRole(db, roleId, roleTemplate.get("accessRules", []))
|
||||
_ensureAccessRulesForRole(rootInterface, roleId, roleTemplate.get("accessRules", []))
|
||||
logging.getLogger(__name__).info(f"Created template role '{roleLabel}' with ID {roleId}")
|
||||
createdCount += 1
|
||||
|
||||
if createdCount > 0:
|
||||
logging.getLogger(__name__).info(f"Feature '{FEATURE_CODE}': Created {createdCount} template roles")
|
||||
|
||||
_repairInstanceRolesAccessRules(db, existingRoleLabels)
|
||||
_repairInstanceRolesAccessRules(rootInterface, existingRoleLabels)
|
||||
return createdCount
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).error(f"Error syncing template roles for feature '{FEATURE_CODE}': {e}")
|
||||
return 0
|
||||
|
||||
|
||||
def _repairInstanceRolesAccessRules(db, templateRoleLabels: dict) -> int:
|
||||
def _repairInstanceRolesAccessRules(rootInterface, templateRoleLabels: dict) -> int:
|
||||
"""Repair instance-specific roles by copying AccessRules from their template roles."""
|
||||
from modules.datamodels.datamodelRbac import Role, AccessRule
|
||||
|
||||
repairedCount = 0
|
||||
allRoles = db.getRecordset(Role, recordFilter={"featureCode": FEATURE_CODE})
|
||||
instanceRoles = [r for r in allRoles if r.get("mandateId") is not None]
|
||||
allRoles = rootInterface.getRolesByFeatureCode(FEATURE_CODE)
|
||||
instanceRoles = [r for r in allRoles if r.mandateId is not None]
|
||||
|
||||
for instanceRole in instanceRoles:
|
||||
roleLabel = instanceRole.get("roleLabel")
|
||||
instanceRoleId = instanceRole.get("id")
|
||||
roleLabel = instanceRole.roleLabel
|
||||
instanceRoleId = str(instanceRole.id)
|
||||
templateRoleId = templateRoleLabels.get(roleLabel)
|
||||
if not templateRoleId:
|
||||
continue
|
||||
existingRules = db.getRecordset(AccessRule, recordFilter={"roleId": instanceRoleId})
|
||||
existingRules = rootInterface.getAccessRulesByRole(instanceRoleId)
|
||||
if existingRules:
|
||||
continue
|
||||
templateRules = db.getRecordset(AccessRule, recordFilter={"roleId": templateRoleId})
|
||||
templateRules = rootInterface.getAccessRulesByRole(templateRoleId)
|
||||
if not templateRules:
|
||||
continue
|
||||
for rule in templateRules:
|
||||
newRule = AccessRule(
|
||||
roleId=instanceRoleId,
|
||||
context=rule.get("context"),
|
||||
item=rule.get("item"),
|
||||
view=rule.get("view", False),
|
||||
read=rule.get("read"),
|
||||
create=rule.get("create"),
|
||||
update=rule.get("update"),
|
||||
delete=rule.get("delete"),
|
||||
context=rule.context,
|
||||
item=rule.item,
|
||||
view=rule.view if rule.view else False,
|
||||
read=rule.read,
|
||||
create=rule.create,
|
||||
update=rule.update,
|
||||
delete=rule.delete,
|
||||
)
|
||||
db.recordCreate(AccessRule, newRule.model_dump())
|
||||
rootInterface.db.recordCreate(AccessRule, newRule.model_dump())
|
||||
repairedCount += 1
|
||||
return repairedCount
|
||||
|
||||
|
||||
def _ensureAccessRulesForRole(db, roleId: str, ruleTemplates: list) -> int:
|
||||
def _ensureAccessRulesForRole(rootInterface, roleId: str, ruleTemplates: list) -> int:
|
||||
"""Ensure AccessRules exist for a role based on templates."""
|
||||
from modules.datamodels.datamodelRbac import AccessRule, AccessRuleContext
|
||||
|
||||
existingRules = db.getRecordset(AccessRule, recordFilter={"roleId": roleId})
|
||||
existingSignatures = {(r.get("context"), r.get("item")) for r in existingRules}
|
||||
existingRules = rootInterface.getAccessRulesByRole(roleId)
|
||||
# IMPORTANT: Use .value for enum comparison, not str() which gives "AccessRuleContext.DATA" in Python 3.11+
|
||||
existingSignatures = {(r.context.value if r.context else None, r.item) for r in existingRules}
|
||||
createdCount = 0
|
||||
|
||||
for template in ruleTemplates or []:
|
||||
|
|
@ -273,7 +272,7 @@ def _ensureAccessRulesForRole(db, roleId: str, ruleTemplates: list) -> int:
|
|||
update=template.get("update"),
|
||||
delete=template.get("delete"),
|
||||
)
|
||||
db.recordCreate(AccessRule, newRule.model_dump())
|
||||
rootInterface.db.recordCreate(AccessRule, newRule.model_dump())
|
||||
createdCount += 1
|
||||
existingSignatures.add((context, item))
|
||||
return createdCount
|
||||
|
|
|
|||
|
|
@ -83,7 +83,7 @@ def _parsePagination(pagination: Optional[str]) -> Optional[PaginationParams]:
|
|||
return None
|
||||
|
||||
|
||||
async def _validateInstanceAccess(instanceId: str, context: RequestContext) -> str:
|
||||
def _validateInstanceAccess(instanceId: str, context: RequestContext) -> str:
|
||||
"""
|
||||
Validate that the user has access to the feature instance.
|
||||
Returns the mandateId for the instance.
|
||||
|
|
@ -132,14 +132,14 @@ _REALESTATE_ENTITY_MODELS = {
|
|||
|
||||
@router.get("/{instanceId}/attributes/{entityType}", response_model=Dict[str, Any])
|
||||
@limiter.limit("30/minute")
|
||||
async def get_entity_attributes(
|
||||
def get_entity_attributes(
|
||||
request: Request,
|
||||
instanceId: str = Path(..., description="Feature Instance ID"),
|
||||
entityType: str = Path(..., description="Entity type (e.g., Projekt, Parzelle)"),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
) -> Dict[str, Any]:
|
||||
"""Get attribute definitions for a Real Estate entity. Used by FormGeneratorTable."""
|
||||
await _validateInstanceAccess(instanceId, context)
|
||||
_validateInstanceAccess(instanceId, context)
|
||||
if entityType not in _REALESTATE_ENTITY_MODELS:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
|
|
@ -163,13 +163,13 @@ async def get_entity_attributes(
|
|||
|
||||
@router.get("/{instanceId}/projects/options", response_model=List[Dict[str, Any]])
|
||||
@limiter.limit("60/minute")
|
||||
async def get_project_options(
|
||||
def get_project_options(
|
||||
request: Request,
|
||||
instanceId: str = Path(..., description="Feature Instance ID"),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Get project options for select dropdowns. Returns: [{ value, label }]"""
|
||||
mandateId = await _validateInstanceAccess(instanceId, context)
|
||||
mandateId = _validateInstanceAccess(instanceId, context)
|
||||
interface = getRealEstateInterface(
|
||||
context.user, mandateId=mandateId, featureInstanceId=instanceId
|
||||
)
|
||||
|
|
@ -179,13 +179,13 @@ async def get_project_options(
|
|||
|
||||
@router.get("/{instanceId}/parcels/options", response_model=List[Dict[str, Any]])
|
||||
@limiter.limit("60/minute")
|
||||
async def get_parcel_options(
|
||||
def get_parcel_options(
|
||||
request: Request,
|
||||
instanceId: str = Path(..., description="Feature Instance ID"),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Get parcel options for select dropdowns. Returns: [{ value, label }]"""
|
||||
mandateId = await _validateInstanceAccess(instanceId, context)
|
||||
mandateId = _validateInstanceAccess(instanceId, context)
|
||||
interface = getRealEstateInterface(
|
||||
context.user, mandateId=mandateId, featureInstanceId=instanceId
|
||||
)
|
||||
|
|
@ -197,14 +197,14 @@ async def get_parcel_options(
|
|||
|
||||
@router.get("/{instanceId}/projects", response_model=PaginatedResponse[Projekt])
|
||||
@limiter.limit("30/minute")
|
||||
async def get_projects(
|
||||
def get_projects(
|
||||
request: Request,
|
||||
instanceId: str = Path(..., description="Feature Instance ID"),
|
||||
pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams"),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
) -> PaginatedResponse[Projekt]:
|
||||
"""Get all projects for a feature instance with optional pagination."""
|
||||
mandateId = await _validateInstanceAccess(instanceId, context)
|
||||
mandateId = _validateInstanceAccess(instanceId, context)
|
||||
interface = getRealEstateInterface(
|
||||
context.user, mandateId=mandateId, featureInstanceId=instanceId
|
||||
)
|
||||
|
|
@ -241,14 +241,14 @@ async def get_projects(
|
|||
|
||||
@router.get("/{instanceId}/projects/{projectId}", response_model=Projekt)
|
||||
@limiter.limit("30/minute")
|
||||
async def get_project_by_id(
|
||||
def get_project_by_id(
|
||||
request: Request,
|
||||
instanceId: str = Path(..., description="Feature Instance ID"),
|
||||
projectId: str = Path(..., description="Project ID"),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
) -> Projekt:
|
||||
"""Get a single project by ID."""
|
||||
mandateId = await _validateInstanceAccess(instanceId, context)
|
||||
mandateId = _validateInstanceAccess(instanceId, context)
|
||||
interface = getRealEstateInterface(
|
||||
context.user, mandateId=mandateId, featureInstanceId=instanceId
|
||||
)
|
||||
|
|
@ -260,14 +260,14 @@ async def get_project_by_id(
|
|||
|
||||
@router.post("/{instanceId}/projects", response_model=Projekt)
|
||||
@limiter.limit("30/minute")
|
||||
async def create_project(
|
||||
def create_project(
|
||||
request: Request,
|
||||
instanceId: str = Path(..., description="Feature Instance ID"),
|
||||
data: Dict[str, Any] = Body(...),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
) -> Projekt:
|
||||
"""Create a new project."""
|
||||
mandateId = await _validateInstanceAccess(instanceId, context)
|
||||
mandateId = _validateInstanceAccess(instanceId, context)
|
||||
interface = getRealEstateInterface(
|
||||
context.user, mandateId=mandateId, featureInstanceId=instanceId
|
||||
)
|
||||
|
|
@ -284,7 +284,7 @@ async def create_project(
|
|||
|
||||
@router.put("/{instanceId}/projects/{projectId}", response_model=Projekt)
|
||||
@limiter.limit("30/minute")
|
||||
async def update_project(
|
||||
def update_project(
|
||||
request: Request,
|
||||
instanceId: str = Path(..., description="Feature Instance ID"),
|
||||
projectId: str = Path(..., description="Project ID"),
|
||||
|
|
@ -292,7 +292,7 @@ async def update_project(
|
|||
context: RequestContext = Depends(getRequestContext)
|
||||
) -> Projekt:
|
||||
"""Update a project."""
|
||||
mandateId = await _validateInstanceAccess(instanceId, context)
|
||||
mandateId = _validateInstanceAccess(instanceId, context)
|
||||
interface = getRealEstateInterface(
|
||||
context.user, mandateId=mandateId, featureInstanceId=instanceId
|
||||
)
|
||||
|
|
@ -307,14 +307,14 @@ async def update_project(
|
|||
|
||||
@router.delete("/{instanceId}/projects/{projectId}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
@limiter.limit("30/minute")
|
||||
async def delete_project(
|
||||
def delete_project(
|
||||
request: Request,
|
||||
instanceId: str = Path(..., description="Feature Instance ID"),
|
||||
projectId: str = Path(..., description="Project ID"),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
) -> None:
|
||||
"""Delete a project."""
|
||||
mandateId = await _validateInstanceAccess(instanceId, context)
|
||||
mandateId = _validateInstanceAccess(instanceId, context)
|
||||
interface = getRealEstateInterface(
|
||||
context.user, mandateId=mandateId, featureInstanceId=instanceId
|
||||
)
|
||||
|
|
@ -329,14 +329,14 @@ async def delete_project(
|
|||
|
||||
@router.get("/{instanceId}/parcels", response_model=PaginatedResponse[Parzelle])
|
||||
@limiter.limit("30/minute")
|
||||
async def get_parcels(
|
||||
def get_parcels(
|
||||
request: Request,
|
||||
instanceId: str = Path(..., description="Feature Instance ID"),
|
||||
pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams"),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
) -> PaginatedResponse[Parzelle]:
|
||||
"""Get all parcels for a feature instance with optional pagination."""
|
||||
mandateId = await _validateInstanceAccess(instanceId, context)
|
||||
mandateId = _validateInstanceAccess(instanceId, context)
|
||||
interface = getRealEstateInterface(
|
||||
context.user, mandateId=mandateId, featureInstanceId=instanceId
|
||||
)
|
||||
|
|
@ -373,14 +373,14 @@ async def get_parcels(
|
|||
|
||||
@router.get("/{instanceId}/parcels/{parcelId}", response_model=Parzelle)
|
||||
@limiter.limit("30/minute")
|
||||
async def get_parcel_by_id(
|
||||
def get_parcel_by_id(
|
||||
request: Request,
|
||||
instanceId: str = Path(..., description="Feature Instance ID"),
|
||||
parcelId: str = Path(..., description="Parcel ID"),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
) -> Parzelle:
|
||||
"""Get a single parcel by ID."""
|
||||
mandateId = await _validateInstanceAccess(instanceId, context)
|
||||
mandateId = _validateInstanceAccess(instanceId, context)
|
||||
interface = getRealEstateInterface(
|
||||
context.user, mandateId=mandateId, featureInstanceId=instanceId
|
||||
)
|
||||
|
|
@ -392,14 +392,14 @@ async def get_parcel_by_id(
|
|||
|
||||
@router.post("/{instanceId}/parcels", response_model=Parzelle)
|
||||
@limiter.limit("30/minute")
|
||||
async def create_parcel(
|
||||
def create_parcel(
|
||||
request: Request,
|
||||
instanceId: str = Path(..., description="Feature Instance ID"),
|
||||
data: Dict[str, Any] = Body(...),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
) -> Parzelle:
|
||||
"""Create a new parcel."""
|
||||
mandateId = await _validateInstanceAccess(instanceId, context)
|
||||
mandateId = _validateInstanceAccess(instanceId, context)
|
||||
interface = getRealEstateInterface(
|
||||
context.user, mandateId=mandateId, featureInstanceId=instanceId
|
||||
)
|
||||
|
|
@ -416,7 +416,7 @@ async def create_parcel(
|
|||
|
||||
@router.put("/{instanceId}/parcels/{parcelId}", response_model=Parzelle)
|
||||
@limiter.limit("30/minute")
|
||||
async def update_parcel(
|
||||
def update_parcel(
|
||||
request: Request,
|
||||
instanceId: str = Path(..., description="Feature Instance ID"),
|
||||
parcelId: str = Path(..., description="Parcel ID"),
|
||||
|
|
@ -424,7 +424,7 @@ async def update_parcel(
|
|||
context: RequestContext = Depends(getRequestContext)
|
||||
) -> Parzelle:
|
||||
"""Update a parcel."""
|
||||
mandateId = await _validateInstanceAccess(instanceId, context)
|
||||
mandateId = _validateInstanceAccess(instanceId, context)
|
||||
interface = getRealEstateInterface(
|
||||
context.user, mandateId=mandateId, featureInstanceId=instanceId
|
||||
)
|
||||
|
|
@ -439,14 +439,14 @@ async def update_parcel(
|
|||
|
||||
@router.delete("/{instanceId}/parcels/{parcelId}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
@limiter.limit("30/minute")
|
||||
async def delete_parcel(
|
||||
def delete_parcel(
|
||||
request: Request,
|
||||
instanceId: str = Path(..., description="Feature Instance ID"),
|
||||
parcelId: str = Path(..., description="Parcel ID"),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
) -> None:
|
||||
"""Delete a parcel."""
|
||||
mandateId = await _validateInstanceAccess(instanceId, context)
|
||||
mandateId = _validateInstanceAccess(instanceId, context)
|
||||
interface = getRealEstateInterface(
|
||||
context.user, mandateId=mandateId, featureInstanceId=instanceId
|
||||
)
|
||||
|
|
@ -549,7 +549,7 @@ async def process_command(
|
|||
|
||||
@router.get("/tables", response_model=Dict[str, Any])
|
||||
@limiter.limit("120/minute")
|
||||
async def get_available_tables(
|
||||
def get_available_tables(
|
||||
request: Request,
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
) -> Dict[str, Any]:
|
||||
|
|
@ -645,7 +645,7 @@ async def get_available_tables(
|
|||
|
||||
@router.get("/table/{table}", response_model=PaginatedResponse[Any])
|
||||
@limiter.limit("120/minute")
|
||||
async def get_table_data(
|
||||
def get_table_data(
|
||||
request: Request,
|
||||
table: str = Path(..., description="Table name (Projekt, Parzelle, Dokument, Gemeinde, Kanton, Land)"),
|
||||
pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"),
|
||||
|
|
|
|||
|
|
@ -155,7 +155,6 @@ class TrusteeObjects:
|
|||
userId=self.userId,
|
||||
)
|
||||
|
||||
self.db.initDbSystem()
|
||||
logger.info(f"Trustee database initialized successfully for user {self.userId}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize Trustee database: {str(e)}")
|
||||
|
|
@ -172,10 +171,12 @@ class TrusteeObjects:
|
|||
return False
|
||||
|
||||
tableName = modelClass.__name__
|
||||
from modules.interfaces.interfaceRbac import buildDataObjectKey
|
||||
objectKey = buildDataObjectKey(tableName, featureCode=self.featureCode if hasattr(self, 'featureCode') else None)
|
||||
permissions = self.rbac.getUserPermissions(
|
||||
self.currentUser,
|
||||
AccessRuleContext.DATA,
|
||||
tableName,
|
||||
objectKey,
|
||||
mandateId=self.mandateId,
|
||||
featureInstanceId=self.featureInstanceId
|
||||
)
|
||||
|
|
@ -199,10 +200,12 @@ class TrusteeObjects:
|
|||
return AccessLevel.NONE
|
||||
|
||||
tableName = modelClass.__name__
|
||||
from modules.interfaces.interfaceRbac import buildDataObjectKey
|
||||
objectKey = buildDataObjectKey(tableName, featureCode=self.featureCode if hasattr(self, 'featureCode') else None)
|
||||
permissions = self.rbac.getUserPermissions(
|
||||
self.currentUser,
|
||||
AccessRuleContext.DATA,
|
||||
tableName,
|
||||
objectKey,
|
||||
mandateId=self.mandateId,
|
||||
featureInstanceId=self.featureInstanceId
|
||||
)
|
||||
|
|
@ -1471,7 +1474,7 @@ class TrusteeObjects:
|
|||
|
||||
def getAllUserAccess(self, userId: str) -> List[Dict[str, Any]]:
|
||||
"""Get all access records for a user across all organisations."""
|
||||
return self.db.getRecordset(TrusteeAccess, {"userId": userId})
|
||||
return self.db.getRecordset(TrusteeAccess, recordFilter={"userId": userId})
|
||||
|
||||
def getUserTrusteeRoles(self, userId: str, organisationId: str, contractId: Optional[str] = None) -> List[str]:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -144,12 +144,11 @@ TEMPLATE_ROLES = [
|
|||
"fr": "Comptable fiduciaire - Gérer les données comptables et financières"
|
||||
},
|
||||
"accessRules": [
|
||||
# UI access to main views (not admin views) - vollqualifizierte ObjectKeys
|
||||
# UI access to main views (not admin views, not expense-import) - vollqualifizierte ObjectKeys
|
||||
{"context": "UI", "item": "ui.feature.trustee.dashboard", "view": True},
|
||||
{"context": "UI", "item": "ui.feature.trustee.positions", "view": True},
|
||||
{"context": "UI", "item": "ui.feature.trustee.documents", "view": True},
|
||||
{"context": "UI", "item": "ui.feature.trustee.position-documents", "view": True},
|
||||
{"context": "UI", "item": "ui.feature.trustee.expense-import", "view": True},
|
||||
# Group-level DATA access
|
||||
{"context": "DATA", "item": None, "view": True, "read": "g", "create": "g", "update": "g", "delete": "g"},
|
||||
]
|
||||
|
|
@ -162,11 +161,12 @@ TEMPLATE_ROLES = [
|
|||
"fr": "Client fiduciaire - Consulter ses propres données comptables et documents"
|
||||
},
|
||||
"accessRules": [
|
||||
# UI access to main views only (read-only focus) - vollqualifizierte ObjectKeys
|
||||
# UI access to main views + expense-import - vollqualifizierte ObjectKeys
|
||||
{"context": "UI", "item": "ui.feature.trustee.dashboard", "view": True},
|
||||
{"context": "UI", "item": "ui.feature.trustee.positions", "view": True},
|
||||
{"context": "UI", "item": "ui.feature.trustee.documents", "view": True},
|
||||
{"context": "UI", "item": "ui.feature.trustee.position-documents", "view": True},
|
||||
{"context": "UI", "item": "ui.feature.trustee.expense-import", "view": True},
|
||||
# Own records only (MY level) - explizite Regeln pro Tabelle
|
||||
{"context": "DATA", "item": "data.feature.trustee.TrusteePosition", "view": True, "read": "m", "create": "m", "update": "m", "delete": "n"},
|
||||
{"context": "DATA", "item": "data.feature.trustee.TrusteeDocument", "view": True, "read": "m", "create": "m", "update": "m", "delete": "n"},
|
||||
|
|
@ -267,14 +267,11 @@ def _syncTemplateRolesToDb() -> int:
|
|||
from modules.datamodels.datamodelRbac import Role, AccessRule, AccessRuleContext
|
||||
|
||||
rootInterface = getRootInterface()
|
||||
db = rootInterface.db
|
||||
|
||||
# Get existing template roles for this feature
|
||||
existingRoles = db.getRecordset(
|
||||
Role,
|
||||
recordFilter={"featureCode": FEATURE_CODE, "mandateId": None}
|
||||
)
|
||||
existingRoleLabels = {r.get("roleLabel"): r.get("id") for r in existingRoles}
|
||||
# Get existing template roles for this feature (Pydantic models)
|
||||
existingRoles = rootInterface.getRolesByFeatureCode(FEATURE_CODE)
|
||||
templateRoles = [r for r in existingRoles if r.mandateId is None]
|
||||
existingRoleLabels = {r.roleLabel: str(r.id) for r in templateRoles}
|
||||
|
||||
createdCount = 0
|
||||
for roleTemplate in TEMPLATE_ROLES:
|
||||
|
|
@ -282,10 +279,8 @@ def _syncTemplateRolesToDb() -> int:
|
|||
|
||||
if roleLabel in existingRoleLabels:
|
||||
roleId = existingRoleLabels[roleLabel]
|
||||
logger.debug(f"Template role '{roleLabel}' already exists with ID {roleId}")
|
||||
|
||||
# Ensure AccessRules exist for this role
|
||||
_ensureAccessRulesForRole(db, roleId, roleTemplate.get("accessRules", []))
|
||||
_ensureAccessRulesForRole(rootInterface, roleId, roleTemplate.get("accessRules", []))
|
||||
else:
|
||||
# Create new template role
|
||||
newRole = Role(
|
||||
|
|
@ -296,11 +291,11 @@ def _syncTemplateRolesToDb() -> int:
|
|||
featureInstanceId=None,
|
||||
isSystemRole=False
|
||||
)
|
||||
createdRole = db.recordCreate(Role, newRole.model_dump())
|
||||
createdRole = rootInterface.db.recordCreate(Role, newRole.model_dump())
|
||||
roleId = createdRole.get("id")
|
||||
|
||||
# Create AccessRules for this role
|
||||
_ensureAccessRulesForRole(db, roleId, roleTemplate.get("accessRules", []))
|
||||
_ensureAccessRulesForRole(rootInterface, roleId, roleTemplate.get("accessRules", []))
|
||||
|
||||
logger.info(f"Created template role '{roleLabel}' with ID {roleId}")
|
||||
createdCount += 1
|
||||
|
|
@ -309,7 +304,7 @@ def _syncTemplateRolesToDb() -> int:
|
|||
logger.info(f"Feature '{FEATURE_CODE}': Created {createdCount} template roles")
|
||||
|
||||
# Repair instance-specific roles that are missing AccessRules
|
||||
_repairInstanceRolesAccessRules(db, existingRoleLabels)
|
||||
_repairInstanceRolesAccessRules(rootInterface, existingRoleLabels)
|
||||
|
||||
return createdCount
|
||||
|
||||
|
|
@ -318,13 +313,13 @@ def _syncTemplateRolesToDb() -> int:
|
|||
return 0
|
||||
|
||||
|
||||
def _repairInstanceRolesAccessRules(db, templateRoleLabels: Dict[str, str]) -> int:
|
||||
def _repairInstanceRolesAccessRules(rootInterface, templateRoleLabels: Dict[str, str]) -> int:
|
||||
"""
|
||||
Repair instance-specific roles by copying AccessRules from their template roles.
|
||||
This ensures instance roles created before AccessRules were defined get updated.
|
||||
|
||||
Args:
|
||||
db: Database connector
|
||||
rootInterface: Root interface instance
|
||||
templateRoleLabels: Dict mapping roleLabel to template role ID
|
||||
|
||||
Returns:
|
||||
|
|
@ -334,41 +329,41 @@ def _repairInstanceRolesAccessRules(db, templateRoleLabels: Dict[str, str]) -> i
|
|||
|
||||
repairedCount = 0
|
||||
|
||||
# Get all instance-specific roles for this feature (mandateId is NOT None)
|
||||
allRoles = db.getRecordset(Role, recordFilter={"featureCode": FEATURE_CODE})
|
||||
instanceRoles = [r for r in allRoles if r.get("mandateId") is not None]
|
||||
# Get all instance-specific roles for this feature (Pydantic models)
|
||||
allRoles = rootInterface.getRolesByFeatureCode(FEATURE_CODE)
|
||||
instanceRoles = [r for r in allRoles if r.mandateId is not None]
|
||||
|
||||
for instanceRole in instanceRoles:
|
||||
roleLabel = instanceRole.get("roleLabel")
|
||||
instanceRoleId = instanceRole.get("id")
|
||||
roleLabel = instanceRole.roleLabel
|
||||
instanceRoleId = str(instanceRole.id)
|
||||
|
||||
# Find matching template role
|
||||
templateRoleId = templateRoleLabels.get(roleLabel)
|
||||
if not templateRoleId:
|
||||
continue
|
||||
|
||||
# Check if instance role has AccessRules
|
||||
existingRules = db.getRecordset(AccessRule, recordFilter={"roleId": instanceRoleId})
|
||||
# Check if instance role has AccessRules (Pydantic models)
|
||||
existingRules = rootInterface.getAccessRulesByRole(instanceRoleId)
|
||||
if existingRules:
|
||||
continue # Already has rules, skip
|
||||
|
||||
# Copy AccessRules from template role
|
||||
templateRules = db.getRecordset(AccessRule, recordFilter={"roleId": templateRoleId})
|
||||
# Copy AccessRules from template role (Pydantic models)
|
||||
templateRules = rootInterface.getAccessRulesByRole(templateRoleId)
|
||||
if not templateRules:
|
||||
continue # Template has no rules
|
||||
|
||||
for rule in templateRules:
|
||||
newRule = AccessRule(
|
||||
roleId=instanceRoleId,
|
||||
context=rule.get("context"),
|
||||
item=rule.get("item"),
|
||||
view=rule.get("view", False),
|
||||
read=rule.get("read"),
|
||||
create=rule.get("create"),
|
||||
update=rule.get("update"),
|
||||
delete=rule.get("delete"),
|
||||
context=rule.context,
|
||||
item=rule.item,
|
||||
view=rule.view if rule.view else False,
|
||||
read=rule.read,
|
||||
create=rule.create,
|
||||
update=rule.update,
|
||||
delete=rule.delete,
|
||||
)
|
||||
db.recordCreate(AccessRule, newRule.model_dump())
|
||||
rootInterface.db.recordCreate(AccessRule, newRule.model_dump())
|
||||
|
||||
logger.info(f"Repaired instance role '{roleLabel}' (ID: {instanceRoleId}): copied {len(templateRules)} AccessRules from template")
|
||||
repairedCount += 1
|
||||
|
|
@ -379,12 +374,12 @@ def _repairInstanceRolesAccessRules(db, templateRoleLabels: Dict[str, str]) -> i
|
|||
return repairedCount
|
||||
|
||||
|
||||
def _ensureAccessRulesForRole(db, roleId: str, ruleTemplates: List[Dict[str, Any]]) -> int:
|
||||
def _ensureAccessRulesForRole(rootInterface, roleId: str, ruleTemplates: List[Dict[str, Any]]) -> int:
|
||||
"""
|
||||
Ensure AccessRules exist for a role based on templates.
|
||||
|
||||
Args:
|
||||
db: Database connector
|
||||
rootInterface: Root interface instance
|
||||
roleId: Role ID
|
||||
ruleTemplates: List of rule templates
|
||||
|
||||
|
|
@ -393,13 +388,14 @@ def _ensureAccessRulesForRole(db, roleId: str, ruleTemplates: List[Dict[str, Any
|
|||
"""
|
||||
from modules.datamodels.datamodelRbac import AccessRule, AccessRuleContext
|
||||
|
||||
# Get existing rules for this role
|
||||
existingRules = db.getRecordset(AccessRule, recordFilter={"roleId": roleId})
|
||||
# Get existing rules for this role (Pydantic models)
|
||||
existingRules = rootInterface.getAccessRulesByRole(roleId)
|
||||
|
||||
# Create a set of existing rule signatures to avoid duplicates
|
||||
# IMPORTANT: Use .value for enum comparison, not str() which gives "AccessRuleContext.DATA" in Python 3.11+
|
||||
existingSignatures = set()
|
||||
for rule in existingRules:
|
||||
sig = (rule.get("context"), rule.get("item"))
|
||||
sig = (rule.context.value if rule.context else None, rule.item)
|
||||
existingSignatures.add(sig)
|
||||
|
||||
createdCount = 0
|
||||
|
|
@ -431,7 +427,7 @@ def _ensureAccessRulesForRole(db, roleId: str, ruleTemplates: List[Dict[str, Any
|
|||
update=template.get("update"),
|
||||
delete=template.get("delete"),
|
||||
)
|
||||
db.recordCreate(AccessRule, newRule.model_dump())
|
||||
rootInterface.db.recordCreate(AccessRule, newRule.model_dump())
|
||||
createdCount += 1
|
||||
|
||||
if createdCount > 0:
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -4,8 +4,8 @@ import logging
|
|||
import asyncio
|
||||
import uuid
|
||||
import base64
|
||||
from typing import Dict, Any, List, Union, Tuple, Optional
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Any, List, Union, Tuple, Optional, Callable
|
||||
from dataclasses import dataclass, field
|
||||
import time
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -29,7 +29,13 @@ from modules.datamodels.datamodelExtraction import ContentPart, MergeStrategy
|
|||
|
||||
@dataclass(slots=True)
|
||||
class AiObjects:
|
||||
"""Centralized AI interface: dynamically discovers and uses AI models. Includes web functionality."""
|
||||
"""Centralized AI interface: dynamically discovers and uses AI models.
|
||||
|
||||
billingCallback: Set by serviceAi before AI calls. Called after EVERY individual
|
||||
model call with the AiCallResponse. This ensures per-model-call billing with
|
||||
exact provider + model name. The callback handles billing recording.
|
||||
"""
|
||||
billingCallback: Optional[Callable] = field(default=None, repr=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Auto-discover and register all available connectors
|
||||
|
|
@ -89,6 +95,17 @@ class AiObjects:
|
|||
|
||||
# Get failover models for this operation type
|
||||
availableModels = modelRegistry.getAvailableModels()
|
||||
|
||||
# Filter by allowedProviders if specified (from workflow config)
|
||||
allowedProviders = getattr(options, 'allowedProviders', None) if options else None
|
||||
if allowedProviders:
|
||||
filteredModels = [m for m in availableModels if m.connectorType in allowedProviders]
|
||||
if filteredModels:
|
||||
logger.info(f"Filtered models by allowedProviders {allowedProviders}: {len(filteredModels)} models (from {len(availableModels)})")
|
||||
availableModels = filteredModels
|
||||
else:
|
||||
logger.warning(f"No models match allowedProviders {allowedProviders}, using all {len(availableModels)} available models")
|
||||
|
||||
failoverModelList = modelSelector.getFailoverModelList(prompt, context, options, availableModels)
|
||||
|
||||
if not failoverModelList:
|
||||
|
|
@ -97,7 +114,7 @@ class AiObjects:
|
|||
return AiCallResponse(
|
||||
content=errorMsg,
|
||||
modelName="error",
|
||||
priceUsd=0.0,
|
||||
priceCHF=0.0,
|
||||
processingTime=0.0,
|
||||
bytesSent=0,
|
||||
bytesReceived=0,
|
||||
|
|
@ -135,7 +152,7 @@ class AiObjects:
|
|||
return AiCallResponse(
|
||||
content=errorMsg,
|
||||
modelName="error",
|
||||
priceUsd=0.0,
|
||||
priceCHF=0.0,
|
||||
processingTime=0.0,
|
||||
bytesSent=0,
|
||||
bytesReceived=0,
|
||||
|
|
@ -147,7 +164,7 @@ class AiObjects:
|
|||
return AiCallResponse(
|
||||
content=errorMsg,
|
||||
modelName="error",
|
||||
priceUsd=0.0,
|
||||
priceCHF=0.0,
|
||||
processingTime=0.0,
|
||||
bytesSent=inputBytes,
|
||||
bytesReceived=outputBytes,
|
||||
|
|
@ -213,17 +230,29 @@ class AiObjects:
|
|||
outputBytes = len(content.encode("utf-8"))
|
||||
|
||||
# Calculate price using model's own price calculation method
|
||||
priceUsd = model.calculatePriceUsd(processingTime, inputBytes, outputBytes)
|
||||
priceCHF = model.calculatepriceCHF(processingTime, inputBytes, outputBytes)
|
||||
|
||||
return AiCallResponse(
|
||||
response = AiCallResponse(
|
||||
content=content,
|
||||
modelName=model.name,
|
||||
priceUsd=priceUsd,
|
||||
provider=model.connectorType,
|
||||
priceCHF=priceCHF,
|
||||
processingTime=processingTime,
|
||||
bytesSent=inputBytes,
|
||||
bytesReceived=outputBytes,
|
||||
errorCount=0
|
||||
)
|
||||
|
||||
# BILLING: Record billing for THIS specific model call
|
||||
# billingCallback is set by serviceAi and records one billing transaction
|
||||
# per model call with exact provider + model name
|
||||
if self.billingCallback:
|
||||
try:
|
||||
self.billingCallback(response)
|
||||
except Exception as e:
|
||||
logger.error(f"BILLING: Failed to record billing for model {model.name}: {e}")
|
||||
|
||||
return response
|
||||
|
||||
|
||||
# Utility methods
|
||||
|
|
|
|||
|
|
@ -51,12 +51,19 @@ def initBootstrap(db: DatabaseConnector) -> None:
|
|||
# Initialize root mandate
|
||||
mandateId = initRootMandate(db)
|
||||
|
||||
# Initialize roles FIRST (needed for AccessRules)
|
||||
# Migrate existing mandate records: description -> label
|
||||
_migrateMandateDescriptionToLabel(db)
|
||||
|
||||
# Initialize system role TEMPLATES (mandateId=None, isSystemRole=True)
|
||||
initRoles(db)
|
||||
|
||||
# Initialize RBAC rules (uses roleIds from roles)
|
||||
# Initialize RBAC rules for template roles
|
||||
initRbacRules(db)
|
||||
|
||||
# Copy system template roles to ALL mandates as mandate-instance roles
|
||||
# This also serves as migration for existing mandates that don't have instance roles yet
|
||||
_ensureAllMandatesHaveSystemRoles(db)
|
||||
|
||||
# Initialize admin user
|
||||
adminUserId = initAdminUser(db, mandateId)
|
||||
|
||||
|
|
@ -64,6 +71,7 @@ def initBootstrap(db: DatabaseConnector) -> None:
|
|||
eventUserId = initEventUser(db, mandateId)
|
||||
|
||||
# Assign initial user memberships (via UserMandate + UserMandateRole)
|
||||
# Uses mandate-instance roles (not template roles)
|
||||
if adminUserId and eventUserId and mandateId:
|
||||
assignInitialUserMemberships(db, mandateId, adminUserId, eventUserId)
|
||||
|
||||
|
|
@ -72,6 +80,14 @@ def initBootstrap(db: DatabaseConnector) -> None:
|
|||
|
||||
# Seed automation templates (after admin user exists)
|
||||
initAutomationTemplates(db, adminUserId)
|
||||
|
||||
# Initialize feature instances for root mandate
|
||||
if mandateId:
|
||||
initRootMandateFeatures(db, mandateId)
|
||||
|
||||
# Initialize billing settings for root mandate
|
||||
if mandateId:
|
||||
initRootMandateBilling(mandateId)
|
||||
|
||||
|
||||
def initAutomationTemplates(dbApp: DatabaseConnector, adminUserId: Optional[str] = None) -> None:
|
||||
|
|
@ -116,7 +132,7 @@ def initAutomationTemplates(dbApp: DatabaseConnector, adminUserId: Optional[str]
|
|||
|
||||
# Get admin user ID if not provided (from poweron_app)
|
||||
if not adminUserId:
|
||||
adminUsers = dbApp.getRecordset(UserInDB, {"email": APP_CONFIG.ADMIN_EMAIL})
|
||||
adminUsers = dbApp.getRecordset(UserInDB, recordFilter={"email": APP_CONFIG.ADMIN_EMAIL})
|
||||
adminUserId = adminUsers[0]["id"] if adminUsers else None
|
||||
# Update context with admin user
|
||||
if adminUserId:
|
||||
|
|
@ -153,9 +169,86 @@ def initAutomationTemplates(dbApp: DatabaseConnector, adminUserId: Optional[str]
|
|||
logger.info("System bootstrap completed")
|
||||
|
||||
|
||||
def initRootMandateFeatures(db: DatabaseConnector, mandateId: str) -> None:
|
||||
"""
|
||||
Create feature instances for root mandate.
|
||||
Dynamically discovers all feature modules with autoCreateInstance=True.
|
||||
|
||||
Args:
|
||||
db: Database connector instance
|
||||
mandateId: Root mandate ID
|
||||
"""
|
||||
from modules.datamodels.datamodelFeatures import FeatureInstance
|
||||
from modules.interfaces.interfaceFeatures import getFeatureInterface
|
||||
from modules.system.registry import loadFeatureMainModules
|
||||
|
||||
logger.info("Initializing root mandate features")
|
||||
|
||||
# Dynamically discover features with autoCreateInstance=True
|
||||
featuresToCreate = []
|
||||
mainModules = loadFeatureMainModules()
|
||||
|
||||
for featureName, module in mainModules.items():
|
||||
if hasattr(module, "getFeatureDefinition"):
|
||||
try:
|
||||
featureDef = module.getFeatureDefinition()
|
||||
if featureDef.get("autoCreateInstance", False):
|
||||
featureCode = featureDef.get("code", featureName)
|
||||
featureLabel = featureDef.get("label", {}).get("en", featureName)
|
||||
featuresToCreate.append({"code": featureCode, "label": featureLabel})
|
||||
logger.debug(f"Feature '{featureCode}' marked for auto-creation in root mandate")
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not read feature definition for '{featureName}': {e}")
|
||||
|
||||
if not featuresToCreate:
|
||||
logger.info("No features marked for auto-creation in root mandate")
|
||||
return
|
||||
|
||||
featureInterface = getFeatureInterface(db)
|
||||
|
||||
for featureConfig in featuresToCreate:
|
||||
featureCode = featureConfig["code"]
|
||||
featureLabel = featureConfig["label"]
|
||||
|
||||
try:
|
||||
# Check if instance already exists
|
||||
existingInstances = db.getRecordset(
|
||||
FeatureInstance,
|
||||
recordFilter={
|
||||
"mandateId": mandateId,
|
||||
"featureCode": featureCode
|
||||
}
|
||||
)
|
||||
|
||||
if existingInstances:
|
||||
logger.info(f"Feature instance for '{featureCode}' already exists in root mandate")
|
||||
continue
|
||||
|
||||
# Create feature instance with template roles copied
|
||||
instance = featureInterface.createFeatureInstance(
|
||||
featureCode=featureCode,
|
||||
mandateId=mandateId,
|
||||
label=featureLabel,
|
||||
enabled=True,
|
||||
copyTemplateRoles=True
|
||||
)
|
||||
|
||||
if instance:
|
||||
instanceId = instance.get("id") if isinstance(instance, dict) else instance.id
|
||||
logger.info(f"Created feature instance '{instanceId}' for '{featureCode}' in root mandate")
|
||||
else:
|
||||
logger.warning(f"Failed to create feature instance for '{featureCode}'")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating feature instance for '{featureCode}': {e}")
|
||||
|
||||
logger.info("Root mandate features initialization completed")
|
||||
|
||||
|
||||
def initRootMandate(db: DatabaseConnector) -> Optional[str]:
|
||||
"""
|
||||
Creates the Root mandate if it doesn't exist.
|
||||
Root mandate is identified by name='root' AND isSystem=True.
|
||||
|
||||
Args:
|
||||
db: Database connector instance
|
||||
|
|
@ -163,20 +256,55 @@ def initRootMandate(db: DatabaseConnector) -> Optional[str]:
|
|||
Returns:
|
||||
Mandate ID if created or found, None otherwise
|
||||
"""
|
||||
existingMandates = db.getRecordset(Mandate)
|
||||
# Find existing root mandate by name AND isSystem flag
|
||||
existingMandates = db.getRecordset(Mandate, recordFilter={"name": "root", "isSystem": True})
|
||||
if existingMandates:
|
||||
mandateId = existingMandates[0].get("id")
|
||||
logger.info(f"Root mandate already exists with ID {mandateId}")
|
||||
return mandateId
|
||||
|
||||
# Check for legacy root mandates (name="Root" without isSystem flag) and migrate
|
||||
legacyMandates = db.getRecordset(Mandate, recordFilter={"name": "Root"})
|
||||
if legacyMandates:
|
||||
mandateId = legacyMandates[0].get("id")
|
||||
logger.info(f"Migrating legacy Root mandate {mandateId}: setting name='root', isSystem=True")
|
||||
db.recordModify(Mandate, mandateId, {"name": "root", "isSystem": True})
|
||||
return mandateId
|
||||
|
||||
logger.info("Creating Root mandate")
|
||||
rootMandate = Mandate(name="Root", enabled=True)
|
||||
rootMandate = Mandate(name="root", isSystem=True, enabled=True)
|
||||
createdMandate = db.recordCreate(Mandate, rootMandate)
|
||||
mandateId = createdMandate.get("id")
|
||||
logger.info(f"Root mandate created with ID {mandateId}")
|
||||
return mandateId
|
||||
|
||||
|
||||
def _migrateMandateDescriptionToLabel(db: DatabaseConnector) -> None:
|
||||
"""
|
||||
Migration: Rename 'description' field to 'label' in all Mandate records.
|
||||
Copies existing 'description' values to 'label' and removes the old field.
|
||||
Safe to run multiple times (idempotent).
|
||||
"""
|
||||
allMandates = db.getRecordset(Mandate)
|
||||
migratedCount = 0
|
||||
for mandateRecord in allMandates:
|
||||
mandateId = mandateRecord.get("id")
|
||||
hasDescription = "description" in mandateRecord and mandateRecord.get("description") is not None
|
||||
hasLabel = "label" in mandateRecord and mandateRecord.get("label") is not None
|
||||
|
||||
if hasDescription and not hasLabel:
|
||||
# Copy description to label
|
||||
updateData = {"label": mandateRecord["description"]}
|
||||
db.recordModify(Mandate, mandateId, updateData)
|
||||
migratedCount += 1
|
||||
logger.info(f"Migrated mandate {mandateId}: description -> label")
|
||||
|
||||
if migratedCount > 0:
|
||||
logger.info(f"Migrated {migratedCount} mandate(s) from description to label")
|
||||
else:
|
||||
logger.debug("No mandate description->label migration needed")
|
||||
|
||||
|
||||
def initAdminUser(db: DatabaseConnector, mandateId: Optional[str]) -> Optional[str]:
|
||||
"""
|
||||
Creates the Admin user if it doesn't exist.
|
||||
|
|
@ -314,11 +442,113 @@ def initRoles(db: DatabaseConnector) -> None:
|
|||
logger.warning(f"Error creating role {role.roleLabel}: {e}")
|
||||
else:
|
||||
_roleIdCache[role.roleLabel] = existingRoleLabels[role.roleLabel]
|
||||
logger.debug(f"Role {role.roleLabel} already exists with ID {existingRoleLabels[role.roleLabel]}")
|
||||
|
||||
logger.info("Roles initialization completed")
|
||||
|
||||
|
||||
def _ensureAllMandatesHaveSystemRoles(db: DatabaseConnector) -> None:
|
||||
"""
|
||||
Ensure all existing mandates have system-instance roles.
|
||||
Serves as both initial setup and migration for existing mandates.
|
||||
"""
|
||||
allMandates = db.getRecordset(Mandate)
|
||||
if not allMandates:
|
||||
return
|
||||
|
||||
for mandate in allMandates:
|
||||
mandateId = mandate.get("id")
|
||||
copySystemRolesToMandate(db, mandateId)
|
||||
|
||||
|
||||
def copySystemRolesToMandate(db: DatabaseConnector, mandateId: str) -> int:
|
||||
"""
|
||||
Copy system template roles (mandateId=None, isSystemRole=True) to a mandate
|
||||
as mandate-instance roles. Also copies all AccessRules for each role.
|
||||
|
||||
This is analogous to how feature template roles are copied to feature instances.
|
||||
Each mandate gets its own instances of admin/user/viewer with their AccessRules.
|
||||
|
||||
Args:
|
||||
db: Database connector instance
|
||||
mandateId: Target mandate ID
|
||||
|
||||
Returns:
|
||||
Number of roles copied
|
||||
"""
|
||||
import uuid as _uuid
|
||||
|
||||
# Find system template roles (global, no mandateId)
|
||||
templateRoles = db.getRecordset(
|
||||
Role,
|
||||
recordFilter={"isSystemRole": True, "mandateId": None}
|
||||
)
|
||||
|
||||
if not templateRoles:
|
||||
logger.debug("No system template roles found to copy")
|
||||
return 0
|
||||
|
||||
# Check which roles already exist for this mandate
|
||||
existingMandateRoles = db.getRecordset(
|
||||
Role,
|
||||
recordFilter={"mandateId": mandateId, "featureInstanceId": None}
|
||||
)
|
||||
existingLabels = {r.get("roleLabel") for r in existingMandateRoles}
|
||||
|
||||
# Load all AccessRules for template roles
|
||||
templateRoleIds = [r.get("id") for r in templateRoles]
|
||||
rulesByRoleId = {}
|
||||
for roleId in templateRoleIds:
|
||||
rules = db.getRecordset(AccessRule, recordFilter={"roleId": roleId})
|
||||
rulesByRoleId[roleId] = rules
|
||||
|
||||
copiedCount = 0
|
||||
for templateRole in templateRoles:
|
||||
roleLabel = templateRole.get("roleLabel")
|
||||
|
||||
# Skip if mandate already has this role
|
||||
if roleLabel in existingLabels:
|
||||
logger.debug(f"Mandate {mandateId} already has role '{roleLabel}', skipping")
|
||||
continue
|
||||
|
||||
newRoleId = str(_uuid.uuid4())
|
||||
|
||||
# Create mandate-instance role
|
||||
newRole = Role(
|
||||
id=newRoleId,
|
||||
roleLabel=roleLabel,
|
||||
description=templateRole.get("description", {}),
|
||||
mandateId=mandateId,
|
||||
featureInstanceId=None,
|
||||
featureCode=None,
|
||||
isSystemRole=True # Still a system role, but bound to this mandate
|
||||
)
|
||||
db.recordCreate(Role, newRole.model_dump())
|
||||
|
||||
# Copy AccessRules
|
||||
templateRules = rulesByRoleId.get(templateRole.get("id"), [])
|
||||
for rule in templateRules:
|
||||
newRule = AccessRule(
|
||||
id=str(_uuid.uuid4()),
|
||||
roleId=newRoleId,
|
||||
context=rule.get("context"),
|
||||
item=rule.get("item"),
|
||||
view=rule.get("view", False),
|
||||
read=rule.get("read"),
|
||||
create=rule.get("create"),
|
||||
update=rule.get("update"),
|
||||
delete=rule.get("delete")
|
||||
)
|
||||
db.recordCreate(AccessRule, newRule.model_dump())
|
||||
|
||||
copiedCount += 1
|
||||
logger.info(f"Copied system role '{roleLabel}' to mandate {mandateId} with {len(templateRules)} AccessRules")
|
||||
|
||||
if copiedCount > 0:
|
||||
logger.info(f"Copied {copiedCount} system roles to mandate {mandateId}")
|
||||
|
||||
return copiedCount
|
||||
|
||||
|
||||
def _getRoleId(db: DatabaseConnector, roleLabel: str) -> Optional[str]:
|
||||
"""
|
||||
Get role ID by label, using cache or database lookup.
|
||||
|
|
@ -792,6 +1022,117 @@ def _createTableSpecificRules(db: DatabaseConnector) -> None:
|
|||
delete=AccessLevel.NONE,
|
||||
))
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Billing Namespace - Billing accounts and transactions
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
# BillingAccount: User sees own accounts (MY), Admin sees all in mandate (GROUP)
|
||||
# Each user must see all billing accounts assigned to them
|
||||
if adminId:
|
||||
tableRules.append(AccessRule(
|
||||
roleId=adminId,
|
||||
context=AccessRuleContext.DATA,
|
||||
item="data.billing.BillingAccount",
|
||||
view=True,
|
||||
read=AccessLevel.GROUP,
|
||||
create=AccessLevel.NONE,
|
||||
update=AccessLevel.NONE,
|
||||
delete=AccessLevel.NONE,
|
||||
))
|
||||
if userId:
|
||||
tableRules.append(AccessRule(
|
||||
roleId=userId,
|
||||
context=AccessRuleContext.DATA,
|
||||
item="data.billing.BillingAccount",
|
||||
view=True,
|
||||
read=AccessLevel.MY,
|
||||
create=AccessLevel.NONE,
|
||||
update=AccessLevel.NONE,
|
||||
delete=AccessLevel.NONE,
|
||||
))
|
||||
if viewerId:
|
||||
tableRules.append(AccessRule(
|
||||
roleId=viewerId,
|
||||
context=AccessRuleContext.DATA,
|
||||
item="data.billing.BillingAccount",
|
||||
view=True,
|
||||
read=AccessLevel.MY,
|
||||
create=AccessLevel.NONE,
|
||||
update=AccessLevel.NONE,
|
||||
delete=AccessLevel.NONE,
|
||||
))
|
||||
|
||||
# BillingTransaction: User sees own transactions (MY), Admin sees all in mandate (GROUP)
|
||||
if adminId:
|
||||
tableRules.append(AccessRule(
|
||||
roleId=adminId,
|
||||
context=AccessRuleContext.DATA,
|
||||
item="data.billing.BillingTransaction",
|
||||
view=True,
|
||||
read=AccessLevel.GROUP,
|
||||
create=AccessLevel.NONE,
|
||||
update=AccessLevel.NONE,
|
||||
delete=AccessLevel.NONE,
|
||||
))
|
||||
if userId:
|
||||
tableRules.append(AccessRule(
|
||||
roleId=userId,
|
||||
context=AccessRuleContext.DATA,
|
||||
item="data.billing.BillingTransaction",
|
||||
view=True,
|
||||
read=AccessLevel.MY,
|
||||
create=AccessLevel.NONE,
|
||||
update=AccessLevel.NONE,
|
||||
delete=AccessLevel.NONE,
|
||||
))
|
||||
if viewerId:
|
||||
tableRules.append(AccessRule(
|
||||
roleId=viewerId,
|
||||
context=AccessRuleContext.DATA,
|
||||
item="data.billing.BillingTransaction",
|
||||
view=True,
|
||||
read=AccessLevel.MY,
|
||||
create=AccessLevel.NONE,
|
||||
update=AccessLevel.NONE,
|
||||
delete=AccessLevel.NONE,
|
||||
))
|
||||
|
||||
# BillingSettings: Only admin can view mandate settings (read-only)
|
||||
# SysAdmin (flag) manages settings, roles only read
|
||||
if adminId:
|
||||
tableRules.append(AccessRule(
|
||||
roleId=adminId,
|
||||
context=AccessRuleContext.DATA,
|
||||
item="data.billing.BillingSettings",
|
||||
view=True,
|
||||
read=AccessLevel.GROUP,
|
||||
create=AccessLevel.NONE,
|
||||
update=AccessLevel.NONE,
|
||||
delete=AccessLevel.NONE,
|
||||
))
|
||||
if userId:
|
||||
tableRules.append(AccessRule(
|
||||
roleId=userId,
|
||||
context=AccessRuleContext.DATA,
|
||||
item="data.billing.BillingSettings",
|
||||
view=False,
|
||||
read=AccessLevel.NONE,
|
||||
create=AccessLevel.NONE,
|
||||
update=AccessLevel.NONE,
|
||||
delete=AccessLevel.NONE,
|
||||
))
|
||||
if viewerId:
|
||||
tableRules.append(AccessRule(
|
||||
roleId=viewerId,
|
||||
context=AccessRuleContext.DATA,
|
||||
item="data.billing.BillingSettings",
|
||||
view=False,
|
||||
read=AccessLevel.NONE,
|
||||
create=AccessLevel.NONE,
|
||||
update=AccessLevel.NONE,
|
||||
delete=AccessLevel.NONE,
|
||||
))
|
||||
|
||||
# Create all table-specific rules
|
||||
for rule in tableRules:
|
||||
db.recordCreate(AccessRule, rule)
|
||||
|
|
@ -923,8 +1264,7 @@ def _ensureUiContextRules(db: DatabaseConnector) -> None:
|
|||
for rule in missingRules:
|
||||
db.recordCreate(AccessRule, rule)
|
||||
logger.info(f"Created {len(missingRules)} missing UI context rules")
|
||||
else:
|
||||
logger.debug("All UI context rules already exist")
|
||||
# All UI context rules already exist (nothing to create)
|
||||
|
||||
|
||||
def _ensureDataContextRules(db: DatabaseConnector) -> None:
|
||||
|
|
@ -965,6 +1305,13 @@ def _ensureDataContextRules(db: DatabaseConnector) -> None:
|
|||
"data.automation.AutomationTemplate",
|
||||
]
|
||||
|
||||
# Billing tables: read-only for all roles, scoped by role level
|
||||
# Users see their own accounts/transactions (MY), Admins see mandate-wide (GROUP)
|
||||
billingReadOnlyTables = [
|
||||
"data.billing.BillingAccount",
|
||||
"data.billing.BillingTransaction",
|
||||
]
|
||||
|
||||
missingRules = []
|
||||
|
||||
# MY-level rules for user-owned tables
|
||||
|
|
@ -1008,9 +1355,9 @@ def _ensureDataContextRules(db: DatabaseConnector) -> None:
|
|||
delete=AccessLevel.NONE,
|
||||
))
|
||||
|
||||
# ALL-level rules for admin on system templates
|
||||
# Admin rules for system templates (read ALL, write GROUP-scoped)
|
||||
for objectKey in tablesNeedingAllRulesForAdmin:
|
||||
# Admin: ALL-level access (sees all templates)
|
||||
# Admin: read ALL templates, create/update/delete within GROUP (mandate-scoped)
|
||||
if adminId and (adminId, objectKey) not in existingCombinations:
|
||||
missingRules.append(AccessRule(
|
||||
roleId=adminId,
|
||||
|
|
@ -1018,9 +1365,9 @@ def _ensureDataContextRules(db: DatabaseConnector) -> None:
|
|||
item=objectKey,
|
||||
view=True,
|
||||
read=AccessLevel.ALL,
|
||||
create=AccessLevel.ALL,
|
||||
update=AccessLevel.ALL,
|
||||
delete=AccessLevel.ALL,
|
||||
create=AccessLevel.GROUP,
|
||||
update=AccessLevel.GROUP,
|
||||
delete=AccessLevel.GROUP,
|
||||
))
|
||||
|
||||
# User: MY-level access
|
||||
|
|
@ -1049,13 +1396,89 @@ def _ensureDataContextRules(db: DatabaseConnector) -> None:
|
|||
delete=AccessLevel.NONE,
|
||||
))
|
||||
|
||||
# Billing read-only rules: Admin=GROUP, User/Viewer=MY (own accounts/transactions)
|
||||
for objectKey in billingReadOnlyTables:
|
||||
# Admin: GROUP-level read (sees all accounts in their mandates)
|
||||
if adminId and (adminId, objectKey) not in existingCombinations:
|
||||
missingRules.append(AccessRule(
|
||||
roleId=adminId,
|
||||
context=AccessRuleContext.DATA,
|
||||
item=objectKey,
|
||||
view=True,
|
||||
read=AccessLevel.GROUP,
|
||||
create=AccessLevel.NONE,
|
||||
update=AccessLevel.NONE,
|
||||
delete=AccessLevel.NONE,
|
||||
))
|
||||
|
||||
# User: MY-level read (sees only own billing accounts/transactions)
|
||||
if userId and (userId, objectKey) not in existingCombinations:
|
||||
missingRules.append(AccessRule(
|
||||
roleId=userId,
|
||||
context=AccessRuleContext.DATA,
|
||||
item=objectKey,
|
||||
view=True,
|
||||
read=AccessLevel.MY,
|
||||
create=AccessLevel.NONE,
|
||||
update=AccessLevel.NONE,
|
||||
delete=AccessLevel.NONE,
|
||||
))
|
||||
|
||||
# Viewer: MY-level read-only (sees only own billing accounts/transactions)
|
||||
if viewerId and (viewerId, objectKey) not in existingCombinations:
|
||||
missingRules.append(AccessRule(
|
||||
roleId=viewerId,
|
||||
context=AccessRuleContext.DATA,
|
||||
item=objectKey,
|
||||
view=True,
|
||||
read=AccessLevel.MY,
|
||||
create=AccessLevel.NONE,
|
||||
update=AccessLevel.NONE,
|
||||
delete=AccessLevel.NONE,
|
||||
))
|
||||
|
||||
# BillingSettings: Admin can view (GROUP), User/Viewer have no access
|
||||
billingSettingsKey = "data.billing.BillingSettings"
|
||||
if adminId and (adminId, billingSettingsKey) not in existingCombinations:
|
||||
missingRules.append(AccessRule(
|
||||
roleId=adminId,
|
||||
context=AccessRuleContext.DATA,
|
||||
item=billingSettingsKey,
|
||||
view=True,
|
||||
read=AccessLevel.GROUP,
|
||||
create=AccessLevel.NONE,
|
||||
update=AccessLevel.NONE,
|
||||
delete=AccessLevel.NONE,
|
||||
))
|
||||
if userId and (userId, billingSettingsKey) not in existingCombinations:
|
||||
missingRules.append(AccessRule(
|
||||
roleId=userId,
|
||||
context=AccessRuleContext.DATA,
|
||||
item=billingSettingsKey,
|
||||
view=False,
|
||||
read=AccessLevel.NONE,
|
||||
create=AccessLevel.NONE,
|
||||
update=AccessLevel.NONE,
|
||||
delete=AccessLevel.NONE,
|
||||
))
|
||||
if viewerId and (viewerId, billingSettingsKey) not in existingCombinations:
|
||||
missingRules.append(AccessRule(
|
||||
roleId=viewerId,
|
||||
context=AccessRuleContext.DATA,
|
||||
item=billingSettingsKey,
|
||||
view=False,
|
||||
read=AccessLevel.NONE,
|
||||
create=AccessLevel.NONE,
|
||||
update=AccessLevel.NONE,
|
||||
delete=AccessLevel.NONE,
|
||||
))
|
||||
|
||||
# Create missing rules
|
||||
if missingRules:
|
||||
for rule in missingRules:
|
||||
db.recordCreate(AccessRule, rule)
|
||||
logger.info(f"Created {len(missingRules)} missing DATA context rules")
|
||||
else:
|
||||
logger.debug("All DATA context rules already exist")
|
||||
# All DATA context rules already exist (nothing to create)
|
||||
|
||||
# Update existing AutomationTemplate rules for admin/viewer to ALL access
|
||||
_updateAutomationTemplateRulesToAll(db, adminId, viewerId)
|
||||
|
|
@ -1063,8 +1486,9 @@ def _ensureDataContextRules(db: DatabaseConnector) -> None:
|
|||
|
||||
def _updateAutomationTemplateRulesToAll(db: DatabaseConnector, adminId: Optional[str], viewerId: Optional[str]) -> None:
|
||||
"""
|
||||
Update existing AutomationTemplate RBAC rules from MY to ALL for admin and viewer.
|
||||
This ensures sysadmins can see all templates (including system-seeded ones).
|
||||
Update existing AutomationTemplate RBAC rules to correct levels.
|
||||
- Admin: read=ALL, create/update/delete=GROUP (mandate-scoped writes)
|
||||
- Viewer: read=ALL (read-only)
|
||||
"""
|
||||
if not adminId and not viewerId:
|
||||
return
|
||||
|
|
@ -1086,14 +1510,29 @@ def _updateAutomationTemplateRulesToAll(db: DatabaseConnector, adminId: Optional
|
|||
roleId = rule.get("roleId")
|
||||
currentReadLevel = rule.get("read")
|
||||
|
||||
# Update admin and viewer rules from MY to ALL
|
||||
if roleId in [adminId, viewerId] and currentReadLevel == AccessLevel.MY.value:
|
||||
if roleId == adminId:
|
||||
# Admin: read ALL, write GROUP
|
||||
updates = {}
|
||||
if currentReadLevel != AccessLevel.ALL.value:
|
||||
updates["read"] = AccessLevel.ALL.value
|
||||
currentCreate = rule.get("create")
|
||||
if currentCreate == AccessLevel.ALL.value:
|
||||
updates["create"] = AccessLevel.GROUP.value
|
||||
updates["update"] = AccessLevel.GROUP.value
|
||||
updates["delete"] = AccessLevel.GROUP.value
|
||||
if updates:
|
||||
db.recordModify(AccessRule, ruleId, updates)
|
||||
updatedCount += 1
|
||||
logger.debug(f"Updated AutomationTemplate rule {ruleId} for admin to read=ALL, write=GROUP")
|
||||
|
||||
elif roleId == viewerId and currentReadLevel == AccessLevel.MY.value:
|
||||
# Viewer: read ALL (read-only)
|
||||
db.recordModify(AccessRule, ruleId, {"read": AccessLevel.ALL.value})
|
||||
updatedCount += 1
|
||||
logger.debug(f"Updated AutomationTemplate rule {ruleId} for role {roleId} to ALL access")
|
||||
logger.debug(f"Updated AutomationTemplate rule {ruleId} for viewer to read=ALL")
|
||||
|
||||
if updatedCount > 0:
|
||||
logger.info(f"Updated {updatedCount} AutomationTemplate RBAC rules to ALL access")
|
||||
logger.info(f"Updated {updatedCount} AutomationTemplate RBAC rules")
|
||||
|
||||
|
||||
def _createResourceContextRules(db: DatabaseConnector) -> None:
|
||||
|
|
@ -1108,8 +1547,8 @@ def _createResourceContextRules(db: DatabaseConnector) -> None:
|
|||
"""
|
||||
resourceRules = []
|
||||
|
||||
# All roles get full resource access by default (no sysadmin - that's a flag)
|
||||
for roleLabel in ["admin", "user", "viewer"]:
|
||||
# Admin and User get default resource access; Viewer gets NO resource access
|
||||
for roleLabel in ["admin", "user"]:
|
||||
roleId = _getRoleId(db, roleLabel)
|
||||
if roleId:
|
||||
resourceRules.append(AccessRule(
|
||||
|
|
@ -1123,10 +1562,170 @@ def _createResourceContextRules(db: DatabaseConnector) -> None:
|
|||
delete=None,
|
||||
))
|
||||
|
||||
# Viewer: no default RESOURCE access (viewer cannot use system resources)
|
||||
|
||||
for rule in resourceRules:
|
||||
db.recordCreate(AccessRule, rule)
|
||||
|
||||
logger.info(f"Created {len(resourceRules)} RESOURCE context rules")
|
||||
|
||||
# Create AICore provider RBAC rules
|
||||
_createAicoreProviderRules(db)
|
||||
|
||||
|
||||
def _createAicoreProviderRules(db: DatabaseConnector) -> None:
|
||||
"""
|
||||
Create RBAC rules for AICore providers (resource.aicore.{provider}).
|
||||
|
||||
Provider access per role:
|
||||
- admin: all providers allowed
|
||||
- user: all providers EXCEPT anthropic (view=False)
|
||||
- viewer: NO provider access (viewer has no RESOURCE permissions)
|
||||
|
||||
NOTE: Provider list is dynamically discovered from AICore model registry.
|
||||
|
||||
Args:
|
||||
db: Database connector instance
|
||||
"""
|
||||
try:
|
||||
from modules.aicore.aicoreModelRegistry import modelRegistry
|
||||
|
||||
# Discover available connectors dynamically
|
||||
connectors = modelRegistry.discoverConnectors()
|
||||
providers = [c.getConnectorType() for c in connectors]
|
||||
|
||||
if not providers:
|
||||
logger.warning("No AICore providers discovered, skipping provider RBAC rules")
|
||||
return
|
||||
|
||||
logger.info(f"Creating RBAC rules for AICore providers: {providers}")
|
||||
|
||||
providerRules = []
|
||||
|
||||
# Admin: access to ALL providers
|
||||
adminId = _getRoleId(db, "admin")
|
||||
if adminId:
|
||||
for provider in providers:
|
||||
resourceKey = f"resource.aicore.{provider}"
|
||||
existingRules = db.getRecordset(
|
||||
AccessRule,
|
||||
recordFilter={
|
||||
"roleId": adminId,
|
||||
"context": AccessRuleContext.RESOURCE.value,
|
||||
"item": resourceKey
|
||||
}
|
||||
)
|
||||
if not existingRules:
|
||||
providerRules.append(AccessRule(
|
||||
roleId=adminId,
|
||||
context=AccessRuleContext.RESOURCE,
|
||||
item=resourceKey,
|
||||
view=True,
|
||||
read=None, create=None, update=None, delete=None,
|
||||
))
|
||||
|
||||
# User: access to all providers EXCEPT anthropic
|
||||
userId = _getRoleId(db, "user")
|
||||
if userId:
|
||||
for provider in providers:
|
||||
resourceKey = f"resource.aicore.{provider}"
|
||||
existingRules = db.getRecordset(
|
||||
AccessRule,
|
||||
recordFilter={
|
||||
"roleId": userId,
|
||||
"context": AccessRuleContext.RESOURCE.value,
|
||||
"item": resourceKey
|
||||
}
|
||||
)
|
||||
if not existingRules:
|
||||
# Anthropic is not allowed for user role
|
||||
isAllowed = provider != "anthropic"
|
||||
providerRules.append(AccessRule(
|
||||
roleId=userId,
|
||||
context=AccessRuleContext.RESOURCE,
|
||||
item=resourceKey,
|
||||
view=isAllowed,
|
||||
read=None, create=None, update=None, delete=None,
|
||||
))
|
||||
|
||||
# Viewer: NO provider access (viewer has no RESOURCE permissions at all)
|
||||
|
||||
for rule in providerRules:
|
||||
db.recordCreate(AccessRule, rule)
|
||||
|
||||
if providerRules:
|
||||
logger.info(f"Created {len(providerRules)} AICore provider RBAC rules")
|
||||
else:
|
||||
logger.debug("All AICore provider RBAC rules already exist")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to create AICore provider RBAC rules: {e}")
|
||||
|
||||
|
||||
def initRootMandateBilling(mandateId: str) -> None:
|
||||
"""
|
||||
Initialize billing settings for root mandate.
|
||||
Root mandate uses PREPAY_USER model with 10 CHF initial credit per user.
|
||||
Creates billing accounts for ALL users regardless of billing model (for audit trail).
|
||||
|
||||
Args:
|
||||
mandateId: Root mandate ID
|
||||
"""
|
||||
try:
|
||||
from modules.interfaces.interfaceDbBilling import _getRootInterface
|
||||
from modules.interfaces.interfaceDbApp import getRootInterface as getAppRootInterface
|
||||
from modules.datamodels.datamodelBilling import BillingSettings, BillingModelEnum
|
||||
|
||||
billingInterface = _getRootInterface()
|
||||
appInterface = getAppRootInterface()
|
||||
|
||||
# Check if settings already exist
|
||||
existingSettings = billingInterface.getSettings(mandateId)
|
||||
if existingSettings:
|
||||
logger.info("Billing settings for root mandate already exist")
|
||||
else:
|
||||
settings = BillingSettings(
|
||||
mandateId=mandateId,
|
||||
billingModel=BillingModelEnum.PREPAY_USER,
|
||||
defaultUserCredit=10.0,
|
||||
warningThresholdPercent=10.0,
|
||||
blockOnZeroBalance=True,
|
||||
notifyOnWarning=True
|
||||
)
|
||||
|
||||
billingInterface.createSettings(settings)
|
||||
logger.info(f"Created billing settings for root mandate: PREPAY_USER with 10 CHF default credit")
|
||||
existingSettings = billingInterface.getSettings(mandateId)
|
||||
|
||||
# Always create user accounts for all users (audit trail)
|
||||
if existingSettings:
|
||||
billingModel = existingSettings.get("billingModel", "UNLIMITED")
|
||||
if billingModel == BillingModelEnum.UNLIMITED.value:
|
||||
return # No accounts needed for UNLIMITED
|
||||
|
||||
# Initial balance depends on billing model
|
||||
if billingModel == BillingModelEnum.PREPAY_USER.value:
|
||||
initialBalance = existingSettings.get("defaultUserCredit", 10.0)
|
||||
else:
|
||||
initialBalance = 0.0 # PREPAY_MANDATE / CREDIT_POSTPAY: budget on pool
|
||||
|
||||
userMandates = appInterface.getUserMandatesByMandate(mandateId)
|
||||
accountsCreated = 0
|
||||
|
||||
for um in userMandates:
|
||||
userId = um.get("userId") if isinstance(um, dict) else getattr(um, "userId", None)
|
||||
if userId:
|
||||
existingAccount = billingInterface.getUserAccount(mandateId, userId)
|
||||
if not existingAccount:
|
||||
billingInterface.getOrCreateUserAccount(mandateId, userId, initialBalance=initialBalance)
|
||||
accountsCreated += 1
|
||||
logger.debug(f"Created billing account for user {userId}")
|
||||
|
||||
if accountsCreated > 0:
|
||||
logger.info(f"Created {accountsCreated} billing accounts for root mandate users with {initialBalance} CHF each")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to initialize root mandate billing (non-critical): {e}")
|
||||
|
||||
|
||||
def assignInitialUserMemberships(
|
||||
|
|
@ -1148,10 +1747,14 @@ def assignInitialUserMemberships(
|
|||
adminUserId: Admin user ID
|
||||
eventUserId: Event user ID
|
||||
"""
|
||||
# Use "admin" role for mandate membership (SysAdmin is a flag, not a role!)
|
||||
adminRoleId = _getRoleId(db, "admin")
|
||||
# Use mandate-instance "admin" role (not the global template)
|
||||
mandateAdminRoles = db.getRecordset(
|
||||
Role,
|
||||
recordFilter={"roleLabel": "admin", "mandateId": mandateId, "featureInstanceId": None}
|
||||
)
|
||||
adminRoleId = mandateAdminRoles[0].get("id") if mandateAdminRoles else None
|
||||
if not adminRoleId:
|
||||
logger.warning("Admin role not found, skipping membership assignment")
|
||||
logger.warning(f"Admin role not found for mandate {mandateId}, skipping membership assignment")
|
||||
return
|
||||
|
||||
for userId, userName in [(adminUserId, "admin"), (eventUserId, "event")]:
|
||||
|
|
@ -1163,7 +1766,6 @@ def assignInitialUserMemberships(
|
|||
|
||||
if existingMemberships:
|
||||
userMandateId = existingMemberships[0].get("id")
|
||||
logger.debug(f"UserMandate already exists for {userName} user")
|
||||
else:
|
||||
# Create UserMandate
|
||||
userMandate = UserMandate(
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@ from modules.datamodels.datamodelMembership import (
|
|||
)
|
||||
from modules.datamodels.datamodelFeatures import Feature, FeatureInstance
|
||||
from modules.datamodels.datamodelInvitation import Invitation
|
||||
from modules.datamodels.datamodelNotification import UserNotification
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -54,6 +55,9 @@ _gatewayInterfaces = {}
|
|||
# Root interface instance
|
||||
_rootAppObjects = None
|
||||
|
||||
# Bootstrap completion flag - ensures bootstrap runs only ONCE per application lifecycle
|
||||
_bootstrapCompleted = False
|
||||
|
||||
# Password-Hashing
|
||||
pwdContext = CryptContext(schemes=["argon2"], deprecated="auto")
|
||||
|
||||
|
|
@ -149,9 +153,6 @@ class AppObjects:
|
|||
userId=self.userId,
|
||||
)
|
||||
|
||||
# Initialize database system
|
||||
self.db.initDbSystem()
|
||||
|
||||
logger.info(f"Database initialized successfully for user {self.userId}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize database: {str(e)}")
|
||||
|
|
@ -199,8 +200,28 @@ class AppObjects:
|
|||
return simpleFields, objectFields
|
||||
|
||||
def _initRecords(self):
|
||||
"""Initialize standard records if they don't exist."""
|
||||
initBootstrap(self.db)
|
||||
"""Initialize standard records if they don't exist.
|
||||
|
||||
Uses a global flag to ensure bootstrap only runs ONCE per application lifecycle.
|
||||
The flag is set BEFORE calling bootstrap to prevent recursive calls during bootstrap.
|
||||
"""
|
||||
global _bootstrapCompleted
|
||||
|
||||
if _bootstrapCompleted:
|
||||
return
|
||||
|
||||
# Set flag BEFORE bootstrap to prevent recursive calls during bootstrap
|
||||
_bootstrapCompleted = True
|
||||
logger.info("Starting bootstrap (will only run once)")
|
||||
|
||||
try:
|
||||
initBootstrap(self.db)
|
||||
logger.info("Bootstrap completed successfully")
|
||||
except Exception as e:
|
||||
# Reset flag on failure so bootstrap can be retried
|
||||
_bootstrapCompleted = False
|
||||
logger.error(f"Bootstrap failed: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def checkRbacPermission(
|
||||
|
|
@ -224,10 +245,13 @@ class AppObjects:
|
|||
return False
|
||||
|
||||
tableName = modelClass.__name__
|
||||
# Use buildDataObjectKey for semantic namespace lookup
|
||||
from modules.interfaces.interfaceRbac import buildDataObjectKey
|
||||
objectKey = buildDataObjectKey(tableName)
|
||||
permissions = self.rbac.getUserPermissions(
|
||||
self.currentUser,
|
||||
AccessRuleContext.DATA,
|
||||
tableName,
|
||||
objectKey,
|
||||
mandateId=self.mandateId
|
||||
)
|
||||
|
||||
|
|
@ -458,17 +482,12 @@ class AppObjects:
|
|||
"""Returns the initial ID for a table."""
|
||||
return self.db.getInitialId(model_class)
|
||||
|
||||
def _getDefaultMandateId(self) -> str:
|
||||
"""Get the default mandate ID, creating it if necessary."""
|
||||
defaultMandateId = self.getInitialId(Mandate)
|
||||
if not defaultMandateId:
|
||||
# If no default mandate exists, create one
|
||||
logger.warning("No default mandate found, creating Root mandate")
|
||||
self._initRootMandate()
|
||||
defaultMandateId = self.getInitialId(Mandate)
|
||||
if not defaultMandateId:
|
||||
raise ValueError("Failed to get or create default mandate")
|
||||
return defaultMandateId
|
||||
def _getRootMandateId(self) -> Optional[str]:
|
||||
"""Get the root mandate ID (name='root', isSystem=True)."""
|
||||
rootMandates = self.db.getRecordset(Mandate, recordFilter={"name": "root", "isSystem": True})
|
||||
if rootMandates:
|
||||
return rootMandates[0].get("id")
|
||||
return None
|
||||
|
||||
def _getPasswordHash(self, password: str) -> str:
|
||||
"""Creates a hash for a password."""
|
||||
|
|
@ -733,6 +752,10 @@ class AppObjects:
|
|||
|
||||
# Clear cache to ensure fresh data (already done above)
|
||||
|
||||
# Assign new user to the root mandate with system 'viewer' role
|
||||
userId = createdUser[0]["id"]
|
||||
self._assignUserToRootMandate(userId)
|
||||
|
||||
return User(**createdUser[0])
|
||||
|
||||
except ValueError as e:
|
||||
|
|
@ -796,6 +819,48 @@ class AppObjects:
|
|||
logger.error(f"Error updating user: {str(e)}")
|
||||
raise ValueError(f"Failed to update user: {str(e)}")
|
||||
|
||||
def _assignUserToRootMandate(self, userId: str) -> None:
|
||||
"""
|
||||
Assign a new user to the root mandate with the mandate-instance 'viewer' role.
|
||||
This ensures every user has a base membership in the system mandate.
|
||||
|
||||
Uses the mandate-instance role (mandateId=rootMandateId), not the global template.
|
||||
Feature instance access is NOT granted here - it is managed separately
|
||||
via invitations or admin assignment.
|
||||
|
||||
Args:
|
||||
userId: User ID to assign
|
||||
"""
|
||||
try:
|
||||
from modules.datamodels.datamodelRbac import Role
|
||||
|
||||
rootMandateId = self._getRootMandateId()
|
||||
if not rootMandateId:
|
||||
logger.warning("No root mandate found, skipping root mandate assignment")
|
||||
return
|
||||
|
||||
# Check if user already has a mandate membership
|
||||
existing = self.getUserMandate(userId, rootMandateId)
|
||||
if existing:
|
||||
logger.debug(f"User {userId} already assigned to root mandate")
|
||||
return
|
||||
|
||||
# Find the mandate-instance 'viewer' role (bound to this mandate, not a global template)
|
||||
mandateViewerRoles = self.db.getRecordset(
|
||||
Role,
|
||||
recordFilter={"roleLabel": "viewer", "mandateId": rootMandateId, "featureInstanceId": None}
|
||||
)
|
||||
viewerRoleId = mandateViewerRoles[0].get("id") if mandateViewerRoles else None
|
||||
|
||||
roleIds = [viewerRoleId] if viewerRoleId else []
|
||||
|
||||
self.createUserMandate(userId, rootMandateId, roleIds)
|
||||
logger.info(f"Assigned user {userId} to root mandate with viewer role")
|
||||
|
||||
except Exception as e:
|
||||
# Log but don't fail user creation
|
||||
logger.error(f"Error assigning user {userId} to root mandate: {e}")
|
||||
|
||||
def disableUser(self, userId: str) -> User:
|
||||
"""Disables a user if current user has permission."""
|
||||
return self.updateUser(userId, {"enabled": False})
|
||||
|
|
@ -1209,6 +1274,31 @@ class AppObjects:
|
|||
logger.error(f"Error getting user connections: {str(e)}")
|
||||
return []
|
||||
|
||||
def getUserConnectionById(self, connectionId: str) -> Optional[UserConnection]:
|
||||
"""Get a single UserConnection by ID."""
|
||||
try:
|
||||
connections = self.db.getRecordset(
|
||||
UserConnection, recordFilter={"id": connectionId}
|
||||
)
|
||||
if connections:
|
||||
conn_dict = connections[0]
|
||||
return UserConnection(
|
||||
id=conn_dict["id"],
|
||||
userId=conn_dict["userId"],
|
||||
authority=conn_dict.get("authority"),
|
||||
externalId=conn_dict.get("externalId", ""),
|
||||
externalUsername=conn_dict.get("externalUsername", ""),
|
||||
externalEmail=conn_dict.get("externalEmail"),
|
||||
status=conn_dict.get("status", "pending"),
|
||||
connectedAt=conn_dict.get("connectedAt"),
|
||||
lastChecked=conn_dict.get("lastChecked"),
|
||||
expiresAt=conn_dict.get("expiresAt"),
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user connection by ID: {str(e)}")
|
||||
return None
|
||||
|
||||
def addUserConnection(
|
||||
self,
|
||||
userId: str,
|
||||
|
|
@ -1354,19 +1444,32 @@ class AppObjects:
|
|||
|
||||
return Mandate(**filteredMandates[0])
|
||||
|
||||
def createMandate(self, name: str, description: str = None, enabled: bool = True) -> Mandate:
|
||||
"""Creates a new mandate if user has permission."""
|
||||
def createMandate(self, name: str, label: str = None, enabled: bool = True) -> Mandate:
|
||||
"""
|
||||
Creates a new mandate if user has permission.
|
||||
Automatically copies system template roles (admin, user, viewer) to the new mandate.
|
||||
"""
|
||||
if not self.checkRbacPermission(Mandate, "create"):
|
||||
raise PermissionError("No permission to create mandates")
|
||||
|
||||
# Create mandate data using model
|
||||
mandateData = Mandate(name=name, description=description, enabled=enabled)
|
||||
mandateData = Mandate(name=name, label=label, enabled=enabled)
|
||||
|
||||
# Create mandate record
|
||||
createdRecord = self.db.recordCreate(Mandate, mandateData)
|
||||
if not createdRecord or not createdRecord.get("id"):
|
||||
raise ValueError("Failed to create mandate record")
|
||||
|
||||
mandateId = createdRecord.get("id")
|
||||
|
||||
# Copy system template roles to new mandate (admin, user, viewer + AccessRules)
|
||||
try:
|
||||
from modules.interfaces.interfaceBootstrap import copySystemRolesToMandate
|
||||
copiedCount = copySystemRolesToMandate(self.db, mandateId)
|
||||
logger.info(f"Copied {copiedCount} system roles to new mandate {mandateId}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error copying system roles to mandate {mandateId}: {e}")
|
||||
|
||||
return Mandate(**createdRecord)
|
||||
|
||||
def updateMandate(self, mandateId: str, updateData: Dict[str, Any]) -> Mandate:
|
||||
|
|
@ -1381,9 +1484,13 @@ class AppObjects:
|
|||
if not mandate:
|
||||
raise ValueError(f"Mandate {mandateId} not found")
|
||||
|
||||
# Strip immutable/protected fields from update data
|
||||
_protectedFields = {"id", "isSystem"}
|
||||
_sanitizedData = {k: v for k, v in updateData.items() if k not in _protectedFields}
|
||||
|
||||
# Update mandate data using model
|
||||
updatedData = mandate.model_dump()
|
||||
updatedData.update(updateData)
|
||||
updatedData.update(_sanitizedData)
|
||||
updatedMandate = Mandate(**updatedData)
|
||||
|
||||
# Update mandate record
|
||||
|
|
@ -1403,13 +1510,17 @@ class AppObjects:
|
|||
raise ValueError(f"Failed to update mandate: {str(e)}")
|
||||
|
||||
def deleteMandate(self, mandateId: str) -> bool:
|
||||
"""Deletes a mandate if user has access."""
|
||||
"""Deletes a mandate if user has access. System mandates cannot be deleted."""
|
||||
try:
|
||||
# Check if mandate exists and user has access
|
||||
mandate = self.getMandate(mandateId)
|
||||
if not mandate:
|
||||
return False
|
||||
|
||||
# System mandates (isSystem=True) cannot be deleted
|
||||
if getattr(mandate, "isSystem", False):
|
||||
raise ValueError(f"System mandate '{mandate.name}' cannot be deleted")
|
||||
|
||||
if not self.checkRbacPermission(Mandate, "delete", mandateId):
|
||||
raise PermissionError(f"No permission to delete mandate {mandateId}")
|
||||
|
||||
|
|
@ -1486,6 +1597,7 @@ class AppObjects:
|
|||
def createUserMandate(self, userId: str, mandateId: str, roleIds: List[str] = None) -> UserMandate:
|
||||
"""
|
||||
Create a UserMandate record (add user to mandate).
|
||||
Also creates a billing account for the user if billing is configured for PREPAY_USER.
|
||||
|
||||
Args:
|
||||
userId: User ID
|
||||
|
|
@ -1519,11 +1631,52 @@ class AppObjects:
|
|||
)
|
||||
self.db.recordCreate(UserMandateRole, userMandateRole.model_dump())
|
||||
|
||||
# Create billing account for user if billing is configured
|
||||
self._ensureUserBillingAccount(userId, mandateId)
|
||||
|
||||
cleanedRecord = {k: v for k, v in createdRecord.items() if not k.startswith("_")}
|
||||
return UserMandate(**cleanedRecord)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating UserMandate: {e}")
|
||||
raise ValueError(f"Failed to create UserMandate: {e}")
|
||||
|
||||
def _ensureUserBillingAccount(self, userId: str, mandateId: str) -> None:
|
||||
"""
|
||||
Ensure a user has a billing account for the mandate if billing is configured.
|
||||
User accounts are always created for all billing models (for audit trail).
|
||||
Initial balance depends on billing model:
|
||||
- PREPAY_USER: defaultUserCredit from settings
|
||||
- PREPAY_MANDATE / CREDIT_POSTPAY: 0.0 (budget is on mandate pool)
|
||||
|
||||
Args:
|
||||
userId: User ID
|
||||
mandateId: Mandate ID
|
||||
"""
|
||||
try:
|
||||
from modules.interfaces.interfaceDbBilling import _getRootInterface as getBillingRootInterface
|
||||
from modules.datamodels.datamodelBilling import BillingModelEnum
|
||||
|
||||
billingInterface = getBillingRootInterface()
|
||||
settings = billingInterface.getSettings(mandateId)
|
||||
|
||||
if not settings:
|
||||
return # No billing configured for this mandate
|
||||
|
||||
billingModel = settings.get("billingModel", "UNLIMITED")
|
||||
if billingModel == BillingModelEnum.UNLIMITED.value:
|
||||
return # No accounts needed for UNLIMITED
|
||||
|
||||
# Initial balance depends on billing model
|
||||
if billingModel == BillingModelEnum.PREPAY_USER.value:
|
||||
initialBalance = settings.get("defaultUserCredit", 10.0)
|
||||
else:
|
||||
initialBalance = 0.0 # PREPAY_MANDATE / CREDIT_POSTPAY: budget is on pool
|
||||
|
||||
billingInterface.getOrCreateUserAccount(mandateId, userId, initialBalance=initialBalance)
|
||||
logger.info(f"Ensured billing account for user {userId} in mandate {mandateId} (model={billingModel}, initial={initialBalance} CHF)")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to create billing account for user {userId} (non-critical): {e}")
|
||||
|
||||
def deleteUserMandate(self, userId: str, mandateId: str) -> bool:
|
||||
"""
|
||||
|
|
@ -1547,6 +1700,106 @@ class AppObjects:
|
|||
logger.error(f"Error deleting UserMandate: {e}")
|
||||
raise ValueError(f"Failed to delete UserMandate: {e}")
|
||||
|
||||
def getUserMandatesByMandate(self, mandateId: str) -> List[UserMandate]:
|
||||
"""
|
||||
Get all UserMandate records for a specific mandate.
|
||||
|
||||
Args:
|
||||
mandateId: Mandate ID
|
||||
|
||||
Returns:
|
||||
List of UserMandate objects
|
||||
"""
|
||||
try:
|
||||
records = self.db.getRecordset(
|
||||
UserMandate,
|
||||
recordFilter={"mandateId": mandateId}
|
||||
)
|
||||
result = []
|
||||
for record in records:
|
||||
cleanedRecord = {k: v for k, v in record.items() if not k.startswith("_")}
|
||||
result.append(UserMandate(**cleanedRecord))
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting UserMandates for mandate {mandateId}: {e}")
|
||||
return []
|
||||
|
||||
def getUserMandateRoles(self, userMandateId: str) -> List[UserMandateRole]:
|
||||
"""
|
||||
Get all UserMandateRole records for a UserMandate.
|
||||
|
||||
Args:
|
||||
userMandateId: UserMandate ID
|
||||
|
||||
Returns:
|
||||
List of UserMandateRole objects
|
||||
"""
|
||||
try:
|
||||
records = self.db.getRecordset(
|
||||
UserMandateRole,
|
||||
recordFilter={"userMandateId": userMandateId}
|
||||
)
|
||||
result = []
|
||||
for record in records:
|
||||
cleanedRecord = {k: v for k, v in record.items() if not k.startswith("_")}
|
||||
result.append(UserMandateRole(**cleanedRecord))
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting UserMandateRoles: {e}")
|
||||
return []
|
||||
|
||||
def deleteUserMandateRoles(self, userMandateId: str) -> int:
|
||||
"""
|
||||
Delete all role assignments for a UserMandate.
|
||||
|
||||
Args:
|
||||
userMandateId: UserMandate ID
|
||||
|
||||
Returns:
|
||||
Number of deleted role assignments
|
||||
"""
|
||||
try:
|
||||
records = self.db.getRecordset(
|
||||
UserMandateRole,
|
||||
recordFilter={"userMandateId": userMandateId}
|
||||
)
|
||||
deletedCount = 0
|
||||
for record in records:
|
||||
if self.db.recordDelete(UserMandateRole, record.get("id")):
|
||||
deletedCount += 1
|
||||
return deletedCount
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting UserMandateRoles: {e}")
|
||||
return 0
|
||||
|
||||
def validateRoleForMandate(self, roleId: str, mandateId: str) -> Role:
|
||||
"""
|
||||
Validate a role exists and belongs to the specified mandate (or is global).
|
||||
|
||||
Args:
|
||||
roleId: Role ID to validate
|
||||
mandateId: Mandate ID for context validation
|
||||
|
||||
Returns:
|
||||
Role object if valid
|
||||
|
||||
Raises:
|
||||
ValueError: If role not found or belongs to different mandate
|
||||
"""
|
||||
role = self.getRole(roleId)
|
||||
if not role:
|
||||
raise ValueError(f"Role {roleId} not found")
|
||||
|
||||
# Check mandate scope
|
||||
if role.mandateId and str(role.mandateId) != str(mandateId):
|
||||
raise ValueError(f"Role {roleId} belongs to a different mandate")
|
||||
|
||||
# Check feature-instance scope (not allowed at mandate level)
|
||||
if role.featureInstanceId:
|
||||
raise ValueError(f"Role {roleId} is a feature-instance role and cannot be assigned at mandate level")
|
||||
|
||||
return role
|
||||
|
||||
def getRoleIdsForUserMandate(self, userMandateId: str) -> List[str]:
|
||||
"""
|
||||
Get all role IDs assigned to a UserMandate.
|
||||
|
|
@ -1688,6 +1941,30 @@ class AppObjects:
|
|||
logger.error(f"Error getting FeatureAccesses: {e}")
|
||||
return []
|
||||
|
||||
def getFeatureAccessesByInstance(self, featureInstanceId: str) -> List[FeatureAccess]:
|
||||
"""
|
||||
Get all FeatureAccess records for a specific feature instance.
|
||||
|
||||
Args:
|
||||
featureInstanceId: FeatureInstance ID
|
||||
|
||||
Returns:
|
||||
List of FeatureAccess objects
|
||||
"""
|
||||
try:
|
||||
records = self.db.getRecordset(
|
||||
FeatureAccess,
|
||||
recordFilter={"featureInstanceId": featureInstanceId}
|
||||
)
|
||||
result = []
|
||||
for record in records:
|
||||
cleanedRecord = {k: v for k, v in record.items() if not k.startswith("_")}
|
||||
result.append(FeatureAccess(**cleanedRecord))
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting FeatureAccesses for instance {featureInstanceId}: {e}")
|
||||
return []
|
||||
|
||||
def createFeatureAccess(self, userId: str, featureInstanceId: str, roleIds: List[str] = None) -> FeatureAccess:
|
||||
"""
|
||||
Create a FeatureAccess record (grant user access to feature instance).
|
||||
|
|
@ -1750,6 +2027,445 @@ class AppObjects:
|
|||
logger.error(f"Error getting role IDs for FeatureAccess: {e}")
|
||||
return []
|
||||
|
||||
def deleteFeatureAccessRoles(self, featureAccessId: str) -> int:
|
||||
"""
|
||||
Delete all FeatureAccessRole records for a FeatureAccess.
|
||||
|
||||
Args:
|
||||
featureAccessId: FeatureAccess ID
|
||||
|
||||
Returns:
|
||||
Number of records deleted
|
||||
"""
|
||||
try:
|
||||
records = self.db.getRecordset(
|
||||
FeatureAccessRole,
|
||||
recordFilter={"featureAccessId": featureAccessId}
|
||||
)
|
||||
count = 0
|
||||
for record in records:
|
||||
recordId = record.get("id")
|
||||
if recordId:
|
||||
self.db.recordDelete(FeatureAccessRole, recordId)
|
||||
count += 1
|
||||
return count
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting FeatureAccessRoles for {featureAccessId}: {e}")
|
||||
return 0
|
||||
|
||||
# ============================================
|
||||
# Invitation Methods
|
||||
# ============================================
|
||||
|
||||
def getInvitation(self, invitationId: str) -> Optional[Invitation]:
|
||||
"""
|
||||
Get an invitation by ID.
|
||||
|
||||
Args:
|
||||
invitationId: Invitation ID
|
||||
|
||||
Returns:
|
||||
Invitation object if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
records = self.db.getRecordset(Invitation, recordFilter={"id": invitationId})
|
||||
if records:
|
||||
cleanedRecord = {k: v for k, v in records[0].items() if not k.startswith("_")}
|
||||
return Invitation(**cleanedRecord)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting invitation {invitationId}: {e}")
|
||||
return None
|
||||
|
||||
def getInvitationByToken(self, token: str) -> Optional[Invitation]:
|
||||
"""
|
||||
Get an invitation by token.
|
||||
|
||||
Args:
|
||||
token: Invitation token
|
||||
|
||||
Returns:
|
||||
Invitation object if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
records = self.db.getRecordset(Invitation, recordFilter={"token": token})
|
||||
if records:
|
||||
cleanedRecord = {k: v for k, v in records[0].items() if not k.startswith("_")}
|
||||
return Invitation(**cleanedRecord)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting invitation by token: {e}")
|
||||
return None
|
||||
|
||||
def getInvitationsByMandate(self, mandateId: str) -> List[Invitation]:
|
||||
"""
|
||||
Get all invitations for a mandate.
|
||||
|
||||
Args:
|
||||
mandateId: Mandate ID
|
||||
|
||||
Returns:
|
||||
List of Invitation objects
|
||||
"""
|
||||
try:
|
||||
records = self.db.getRecordset(Invitation, recordFilter={"mandateId": mandateId})
|
||||
result = []
|
||||
for record in records:
|
||||
cleanedRecord = {k: v for k, v in record.items() if not k.startswith("_")}
|
||||
result.append(Invitation(**cleanedRecord))
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting invitations for mandate {mandateId}: {e}")
|
||||
return []
|
||||
|
||||
def getInvitationsByCreator(self, creatorId: str) -> List[Invitation]:
|
||||
"""
|
||||
Get all invitations created by a user.
|
||||
|
||||
Args:
|
||||
creatorId: User ID who created the invitations
|
||||
|
||||
Returns:
|
||||
List of Invitation objects
|
||||
"""
|
||||
try:
|
||||
records = self.db.getRecordset(Invitation, recordFilter={"createdBy": creatorId})
|
||||
result = []
|
||||
for record in records:
|
||||
cleanedRecord = {k: v for k, v in record.items() if not k.startswith("_")}
|
||||
result.append(Invitation(**cleanedRecord))
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting invitations by creator {creatorId}: {e}")
|
||||
return []
|
||||
|
||||
def getInvitationsByUsedBy(self, usedById: str) -> List[Invitation]:
|
||||
"""
|
||||
Get all invitations used by a user.
|
||||
|
||||
Args:
|
||||
usedById: User ID who used the invitations
|
||||
|
||||
Returns:
|
||||
List of Invitation objects
|
||||
"""
|
||||
try:
|
||||
records = self.db.getRecordset(Invitation, recordFilter={"usedBy": usedById})
|
||||
result = []
|
||||
for record in records:
|
||||
cleanedRecord = {k: v for k, v in record.items() if not k.startswith("_")}
|
||||
result.append(Invitation(**cleanedRecord))
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting invitations used by {usedById}: {e}")
|
||||
return []
|
||||
|
||||
def getInvitationsByTargetUsername(self, targetUsername: str) -> List[Invitation]:
|
||||
"""
|
||||
Get all invitations for a target username.
|
||||
|
||||
Args:
|
||||
targetUsername: Target username for the invitations
|
||||
|
||||
Returns:
|
||||
List of Invitation objects
|
||||
"""
|
||||
try:
|
||||
records = self.db.getRecordset(Invitation, recordFilter={"targetUsername": targetUsername})
|
||||
result = []
|
||||
for record in records:
|
||||
cleanedRecord = {k: v for k, v in record.items() if not k.startswith("_")}
|
||||
result.append(Invitation(**cleanedRecord))
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting invitations for target username {targetUsername}: {e}")
|
||||
return []
|
||||
|
||||
# ============================================
|
||||
# Additional Helper Methods
|
||||
# ============================================
|
||||
|
||||
def getAllUsers(self) -> List[User]:
|
||||
"""
|
||||
Get all users (for SysAdmin only).
|
||||
|
||||
Returns:
|
||||
List of User objects (without sensitive fields)
|
||||
"""
|
||||
try:
|
||||
records = self.db.getRecordset(UserInDB)
|
||||
result = []
|
||||
for record in records:
|
||||
# Filter out sensitive and internal fields
|
||||
cleanedRecord = {
|
||||
k: v for k, v in record.items()
|
||||
if not k.startswith("_") and k not in ["hashedPassword", "resetToken", "resetTokenExpires"]
|
||||
}
|
||||
# Ensure roleLabels is a list
|
||||
if cleanedRecord.get("roleLabels") is None:
|
||||
cleanedRecord["roleLabels"] = []
|
||||
result.append(User(**cleanedRecord))
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting all users: {e}")
|
||||
return []
|
||||
|
||||
def getUserMandateById(self, userMandateId: str) -> Optional[UserMandate]:
|
||||
"""
|
||||
Get a UserMandate by its ID.
|
||||
|
||||
Args:
|
||||
userMandateId: UserMandate ID
|
||||
|
||||
Returns:
|
||||
UserMandate object if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
records = self.db.getRecordset(UserMandate, recordFilter={"id": userMandateId})
|
||||
if records:
|
||||
cleanedRecord = {k: v for k, v in records[0].items() if not k.startswith("_")}
|
||||
return UserMandate(**cleanedRecord)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting UserMandate {userMandateId}: {e}")
|
||||
return None
|
||||
|
||||
def getUserMandateRolesByRole(self, roleId: str) -> List[UserMandateRole]:
|
||||
"""
|
||||
Get all UserMandateRole records for a specific role.
|
||||
|
||||
Args:
|
||||
roleId: Role ID
|
||||
|
||||
Returns:
|
||||
List of UserMandateRole objects
|
||||
"""
|
||||
try:
|
||||
records = self.db.getRecordset(UserMandateRole, recordFilter={"roleId": roleId})
|
||||
result = []
|
||||
for record in records:
|
||||
cleanedRecord = {k: v for k, v in record.items() if not k.startswith("_")}
|
||||
result.append(UserMandateRole(**cleanedRecord))
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting UserMandateRoles for role {roleId}: {e}")
|
||||
return []
|
||||
|
||||
def getFeatureInstance(self, instanceId: str):
|
||||
"""
|
||||
Get a FeatureInstance by ID.
|
||||
|
||||
Args:
|
||||
instanceId: FeatureInstance ID
|
||||
|
||||
Returns:
|
||||
FeatureInstance object if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
records = self.db.getRecordset(FeatureInstance, recordFilter={"id": instanceId})
|
||||
if records:
|
||||
cleanedRecord = {k: v for k, v in records[0].items() if not k.startswith("_")}
|
||||
return FeatureInstance(**cleanedRecord)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting FeatureInstance {instanceId}: {e}")
|
||||
return None
|
||||
|
||||
def getFeatureByCode(self, featureCode: str) -> Optional[Feature]:
|
||||
"""
|
||||
Get a Feature by its code.
|
||||
|
||||
Args:
|
||||
featureCode: Feature code
|
||||
|
||||
Returns:
|
||||
Feature object if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
records = self.db.getRecordset(Feature, recordFilter={"code": featureCode})
|
||||
if records:
|
||||
cleanedRecord = {k: v for k, v in records[0].items() if not k.startswith("_")}
|
||||
return Feature(**cleanedRecord)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting Feature by code {featureCode}: {e}")
|
||||
return None
|
||||
|
||||
def getFeatureInstancesByMandate(self, mandateId: str, enabledOnly: bool = False) -> List[FeatureInstance]:
|
||||
"""
|
||||
Get all FeatureInstances for a mandate.
|
||||
|
||||
Args:
|
||||
mandateId: Mandate ID
|
||||
enabledOnly: If True, only return enabled instances
|
||||
|
||||
Returns:
|
||||
List of FeatureInstance objects
|
||||
"""
|
||||
try:
|
||||
recordFilter = {"mandateId": mandateId}
|
||||
if enabledOnly:
|
||||
recordFilter["enabled"] = True
|
||||
records = self.db.getRecordset(FeatureInstance, recordFilter=recordFilter)
|
||||
result = []
|
||||
for record in records:
|
||||
cleanedRecord = {k: v for k, v in record.items() if not k.startswith("_")}
|
||||
result.append(FeatureInstance(**cleanedRecord))
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting FeatureInstances for mandate {mandateId}: {e}")
|
||||
return []
|
||||
|
||||
# ============================================
|
||||
# Notification Methods
|
||||
# ============================================
|
||||
|
||||
def getNotification(self, notificationId: str) -> Optional[UserNotification]:
|
||||
"""
|
||||
Get a notification by ID.
|
||||
|
||||
Args:
|
||||
notificationId: Notification ID
|
||||
|
||||
Returns:
|
||||
UserNotification object if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
records = self.db.getRecordset(UserNotification, recordFilter={"id": notificationId})
|
||||
if records:
|
||||
cleanedRecord = {k: v for k, v in records[0].items() if not k.startswith("_")}
|
||||
return UserNotification(**cleanedRecord)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting notification {notificationId}: {e}")
|
||||
return None
|
||||
|
||||
def getNotificationsByUser(
|
||||
self,
|
||||
userId: str,
|
||||
status: Optional[str] = None,
|
||||
limit: Optional[int] = None
|
||||
) -> List[UserNotification]:
|
||||
"""
|
||||
Get notifications for a user.
|
||||
|
||||
Args:
|
||||
userId: User ID
|
||||
status: Optional status filter (e.g., 'unread')
|
||||
limit: Optional limit on number of results
|
||||
|
||||
Returns:
|
||||
List of UserNotification objects
|
||||
"""
|
||||
try:
|
||||
recordFilter = {"userId": userId}
|
||||
if status:
|
||||
recordFilter["status"] = status
|
||||
records = self.db.getRecordset(UserNotification, recordFilter=recordFilter)
|
||||
result = []
|
||||
for record in records:
|
||||
cleanedRecord = {k: v for k, v in record.items() if not k.startswith("_")}
|
||||
result.append(UserNotification(**cleanedRecord))
|
||||
# Sort by createdAt descending
|
||||
result.sort(key=lambda x: x.createdAt or 0, reverse=True)
|
||||
if limit:
|
||||
result = result[:limit]
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting notifications for user {userId}: {e}")
|
||||
return []
|
||||
|
||||
# ============================================
|
||||
# AccessRule Methods
|
||||
# ============================================
|
||||
|
||||
def getAccessRule(self, ruleId: str) -> Optional[AccessRule]:
|
||||
"""
|
||||
Get an AccessRule by ID.
|
||||
|
||||
Args:
|
||||
ruleId: AccessRule ID
|
||||
|
||||
Returns:
|
||||
AccessRule object if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
records = self.db.getRecordset(AccessRule, recordFilter={"id": ruleId})
|
||||
if records:
|
||||
cleanedRecord = {k: v for k, v in records[0].items() if not k.startswith("_")}
|
||||
return AccessRule(**cleanedRecord)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting AccessRule {ruleId}: {e}")
|
||||
return None
|
||||
|
||||
def getAccessRulesByRole(self, roleId: str) -> List[AccessRule]:
|
||||
"""
|
||||
Get all AccessRules for a role.
|
||||
|
||||
Args:
|
||||
roleId: Role ID
|
||||
|
||||
Returns:
|
||||
List of AccessRule objects
|
||||
"""
|
||||
try:
|
||||
records = self.db.getRecordset(AccessRule, recordFilter={"roleId": roleId})
|
||||
result = []
|
||||
for record in records:
|
||||
cleanedRecord = {k: v for k, v in record.items() if not k.startswith("_")}
|
||||
result.append(AccessRule(**cleanedRecord))
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting AccessRules for role {roleId}: {e}")
|
||||
return []
|
||||
|
||||
def getRolesByFeatureInstance(self, featureInstanceId: str) -> List[Role]:
|
||||
"""
|
||||
Get all roles for a feature instance.
|
||||
|
||||
Args:
|
||||
featureInstanceId: FeatureInstance ID
|
||||
|
||||
Returns:
|
||||
List of Role objects
|
||||
"""
|
||||
try:
|
||||
records = self.db.getRecordset(Role, recordFilter={"featureInstanceId": featureInstanceId})
|
||||
result = []
|
||||
for record in records:
|
||||
cleanedRecord = {k: v for k, v in record.items() if not k.startswith("_")}
|
||||
result.append(Role(**cleanedRecord))
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting roles for feature instance {featureInstanceId}: {e}")
|
||||
return []
|
||||
|
||||
def getRolesByFeatureCode(self, featureCode: str, featureInstanceId: Optional[str] = None) -> List[Role]:
|
||||
"""
|
||||
Get all roles for a feature code, optionally filtered by instance.
|
||||
|
||||
Args:
|
||||
featureCode: Feature code
|
||||
featureInstanceId: Optional FeatureInstance ID filter
|
||||
|
||||
Returns:
|
||||
List of Role objects
|
||||
"""
|
||||
try:
|
||||
recordFilter = {"featureCode": featureCode}
|
||||
if featureInstanceId:
|
||||
recordFilter["featureInstanceId"] = featureInstanceId
|
||||
records = self.db.getRecordset(Role, recordFilter=recordFilter)
|
||||
result = []
|
||||
for record in records:
|
||||
cleanedRecord = {k: v for k, v in record.items() if not k.startswith("_")}
|
||||
result.append(Role(**cleanedRecord))
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting roles for feature code {featureCode}: {e}")
|
||||
return []
|
||||
|
||||
# Token methods
|
||||
|
||||
def saveAccessToken(self, token: Token, replace_existing: bool = True) -> None:
|
||||
|
|
@ -1908,6 +2624,56 @@ class AppObjects:
|
|||
)
|
||||
return None
|
||||
|
||||
def getTokensByConnectionIdAndAuthority(
|
||||
self, connectionId: str, authority: AuthAuthority
|
||||
) -> List[Token]:
|
||||
"""Get tokens for a connection with specific authority."""
|
||||
try:
|
||||
tokens = self.db.getRecordset(
|
||||
Token, recordFilter={
|
||||
"connectionId": connectionId,
|
||||
"authority": authority.value if hasattr(authority, 'value') else str(authority)
|
||||
}
|
||||
)
|
||||
result = []
|
||||
for token_dict in tokens:
|
||||
cleanedRecord = {k: v for k, v in token_dict.items() if not k.startswith("_")}
|
||||
result.append(Token(**cleanedRecord))
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting tokens by connection and authority: {str(e)}")
|
||||
return []
|
||||
|
||||
def getTokensByUserIdNoConnection(
|
||||
self, userId: str, authority: AuthAuthority
|
||||
) -> List[Token]:
|
||||
"""Get tokens for a user without a connection (access tokens)."""
|
||||
try:
|
||||
tokens = self.db.getRecordset(
|
||||
Token, recordFilter={
|
||||
"userId": userId,
|
||||
"connectionId": None,
|
||||
"authority": authority.value if hasattr(authority, 'value') else str(authority)
|
||||
}
|
||||
)
|
||||
result = []
|
||||
for token_dict in tokens:
|
||||
cleanedRecord = {k: v for k, v in token_dict.items() if not k.startswith("_")}
|
||||
result.append(Token(**cleanedRecord))
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting tokens by user and authority: {str(e)}")
|
||||
return []
|
||||
|
||||
def getAllTokens(self, recordFilter: dict = None) -> List[dict]:
|
||||
"""Get all tokens with optional filtering (returns raw dicts)."""
|
||||
try:
|
||||
tokens = self.db.getRecordset(Token, recordFilter=recordFilter or {})
|
||||
return tokens
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting all tokens: {str(e)}")
|
||||
return []
|
||||
|
||||
def findActiveTokenById(
|
||||
self,
|
||||
tokenId: str,
|
||||
|
|
@ -2340,6 +3106,42 @@ class AppObjects:
|
|||
logger.error(f"Error getting role by label {roleLabel}: {str(e)}")
|
||||
return None
|
||||
|
||||
def getRoleByLabelAndScope(
|
||||
self,
|
||||
roleLabel: str,
|
||||
mandateId: Optional[str] = None,
|
||||
featureInstanceId: Optional[str] = None,
|
||||
featureCode: Optional[str] = None
|
||||
) -> Optional[Role]:
|
||||
"""
|
||||
Get a role by label with scope filtering.
|
||||
|
||||
Args:
|
||||
roleLabel: Role label
|
||||
mandateId: Mandate ID (use None for global roles)
|
||||
featureInstanceId: Feature instance ID
|
||||
featureCode: Feature code
|
||||
|
||||
Returns:
|
||||
Role object if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
recordFilter = {"roleLabel": roleLabel}
|
||||
if mandateId is not None:
|
||||
recordFilter["mandateId"] = mandateId
|
||||
if featureInstanceId is not None:
|
||||
recordFilter["featureInstanceId"] = featureInstanceId
|
||||
if featureCode is not None:
|
||||
recordFilter["featureCode"] = featureCode
|
||||
|
||||
roles = self.db.getRecordset(Role, recordFilter=recordFilter)
|
||||
if roles:
|
||||
return Role(**roles[0])
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting role by label and scope {roleLabel}: {str(e)}")
|
||||
return None
|
||||
|
||||
def getAllRoles(self, pagination: Optional[PaginationParams] = None) -> Union[List[Role], PaginatedResult]:
|
||||
"""
|
||||
Get all roles with optional pagination, sorting, and filtering.
|
||||
|
|
|
|||
1389
modules/interfaces/interfaceDbBilling.py
Normal file
1389
modules/interfaces/interfaceDbBilling.py
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -329,9 +329,6 @@ class ChatObjects:
|
|||
userId=self.userId
|
||||
)
|
||||
|
||||
# Initialize database system
|
||||
self.db.initDbSystem()
|
||||
|
||||
logger.info("Database initialized successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize database: {str(e)}")
|
||||
|
|
@ -342,6 +339,18 @@ class ChatObjects:
|
|||
pass
|
||||
|
||||
|
||||
def _getRecordset(self, modelClass, recordFilter=None, **kwargs):
|
||||
"""Wrapper for getRecordsetWithRBAC that automatically includes mandateId/featureInstanceId."""
|
||||
return getRecordsetWithRBAC(
|
||||
self.db,
|
||||
modelClass,
|
||||
self.currentUser,
|
||||
recordFilter=recordFilter,
|
||||
mandateId=self.mandateId,
|
||||
featureInstanceId=self.featureInstanceId,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def checkRbacPermission(
|
||||
self,
|
||||
modelClass: type,
|
||||
|
|
@ -613,12 +622,7 @@ class ChatObjects:
|
|||
If pagination is provided: PaginatedResult with items and metadata
|
||||
"""
|
||||
# Use RBAC filtering with featureInstanceId for instance-level isolation
|
||||
filteredWorkflows = getRecordsetWithRBAC(self.db,
|
||||
ChatWorkflow,
|
||||
self.currentUser,
|
||||
mandateId=self.mandateId,
|
||||
featureInstanceId=self.featureInstanceId
|
||||
)
|
||||
filteredWorkflows = self._getRecordset(ChatWorkflow)
|
||||
|
||||
# If no pagination requested, return all items (no sorting - frontend handles it)
|
||||
if pagination is None:
|
||||
|
|
@ -650,13 +654,7 @@ class ChatObjects:
|
|||
def getWorkflow(self, workflowId: str) -> Optional[ChatWorkflow]:
|
||||
"""Returns a workflow by ID if user has access."""
|
||||
# Use RBAC filtering with featureInstanceId for instance-level isolation
|
||||
workflows = getRecordsetWithRBAC(self.db,
|
||||
ChatWorkflow,
|
||||
self.currentUser,
|
||||
recordFilter={"id": workflowId},
|
||||
mandateId=self.mandateId,
|
||||
featureInstanceId=self.featureInstanceId
|
||||
)
|
||||
workflows = self._getRecordset(ChatWorkflow, recordFilter={"id": workflowId})
|
||||
|
||||
if not workflows:
|
||||
return None
|
||||
|
|
@ -812,7 +810,7 @@ class ChatObjects:
|
|||
# Delete message documents (but NOT the files!)
|
||||
# Note: ChatStat does NOT have messageId - stats are only at workflow level
|
||||
try:
|
||||
existing_docs = getRecordsetWithRBAC(self.db, ChatDocument, self.currentUser, recordFilter={"messageId": messageId})
|
||||
existing_docs = self._getRecordset(ChatDocument, recordFilter={"messageId": messageId})
|
||||
for doc in existing_docs:
|
||||
self.db.recordDelete(ChatDocument, doc["id"])
|
||||
except Exception as e:
|
||||
|
|
@ -822,12 +820,12 @@ class ChatObjects:
|
|||
self.db.recordDelete(ChatMessage, messageId)
|
||||
|
||||
# 2. Delete workflow stats
|
||||
existing_stats = getRecordsetWithRBAC(self.db, ChatStat, self.currentUser, recordFilter={"workflowId": workflowId})
|
||||
existing_stats = self._getRecordset(ChatStat, recordFilter={"workflowId": workflowId})
|
||||
for stat in existing_stats:
|
||||
self.db.recordDelete(ChatStat, stat["id"])
|
||||
|
||||
# 3. Delete workflow logs
|
||||
existing_logs = getRecordsetWithRBAC(self.db, ChatLog, self.currentUser, recordFilter={"workflowId": workflowId})
|
||||
existing_logs = self._getRecordset(ChatLog, recordFilter={"workflowId": workflowId})
|
||||
for log in existing_logs:
|
||||
self.db.recordDelete(ChatLog, log["id"])
|
||||
|
||||
|
|
@ -858,11 +856,7 @@ class ChatObjects:
|
|||
"""
|
||||
# Check workflow access first (without calling getWorkflow to avoid circular reference)
|
||||
# Use RBAC filtering
|
||||
workflows = getRecordsetWithRBAC(self.db,
|
||||
ChatWorkflow,
|
||||
self.currentUser,
|
||||
recordFilter={"id": workflowId}
|
||||
)
|
||||
workflows = self._getRecordset(ChatWorkflow, recordFilter={"id": workflowId})
|
||||
|
||||
if not workflows:
|
||||
if pagination is None:
|
||||
|
|
@ -870,7 +864,7 @@ class ChatObjects:
|
|||
return PaginatedResult(items=[], totalItems=0, totalPages=0)
|
||||
|
||||
# Get messages for this workflow from normalized table
|
||||
messages = getRecordsetWithRBAC(self.db, ChatMessage, self.currentUser, recordFilter={"workflowId": workflowId})
|
||||
messages = self._getRecordset(ChatMessage, recordFilter={"workflowId": workflowId})
|
||||
|
||||
# Convert raw messages to dict format for sorting/filtering
|
||||
messageDicts = []
|
||||
|
|
@ -1146,7 +1140,7 @@ class ChatObjects:
|
|||
raise ValueError("messageId cannot be empty")
|
||||
|
||||
# Check if message exists in database
|
||||
messages = getRecordsetWithRBAC(self.db, ChatMessage, self.currentUser, recordFilter={"id": messageId})
|
||||
messages = self._getRecordset(ChatMessage, recordFilter={"id": messageId})
|
||||
if not messages:
|
||||
logger.warning(f"Message with ID {messageId} does not exist in database")
|
||||
|
||||
|
|
@ -1253,12 +1247,12 @@ class ChatObjects:
|
|||
# CASCADE DELETE: Delete all related data first
|
||||
|
||||
# 1. Delete message stats
|
||||
existing_stats = getRecordsetWithRBAC(self.db, ChatStat, self.currentUser, recordFilter={"messageId": messageId})
|
||||
existing_stats = self._getRecordset(ChatStat, recordFilter={"messageId": messageId})
|
||||
for stat in existing_stats:
|
||||
self.db.recordDelete(ChatStat, stat["id"])
|
||||
|
||||
# 2. Delete message documents (but NOT the files!)
|
||||
existing_docs = getRecordsetWithRBAC(self.db, ChatDocument, self.currentUser, recordFilter={"messageId": messageId})
|
||||
existing_docs = self._getRecordset(ChatDocument, recordFilter={"messageId": messageId})
|
||||
for doc in existing_docs:
|
||||
self.db.recordDelete(ChatDocument, doc["id"])
|
||||
|
||||
|
|
@ -1285,7 +1279,7 @@ class ChatObjects:
|
|||
|
||||
|
||||
# Get documents for this message from normalized table
|
||||
documents = getRecordsetWithRBAC(self.db, ChatDocument, self.currentUser, recordFilter={"messageId": messageId})
|
||||
documents = self._getRecordset(ChatDocument, recordFilter={"messageId": messageId})
|
||||
|
||||
if not documents:
|
||||
logger.warning(f"No documents found for message {messageId}")
|
||||
|
|
@ -1326,7 +1320,7 @@ class ChatObjects:
|
|||
def getDocuments(self, messageId: str) -> List[ChatDocument]:
|
||||
"""Returns documents for a message from normalized table."""
|
||||
try:
|
||||
documents = getRecordsetWithRBAC(self.db, ChatDocument, self.currentUser, recordFilter={"messageId": messageId})
|
||||
documents = self._getRecordset(ChatDocument, recordFilter={"messageId": messageId})
|
||||
return [ChatDocument(**doc) for doc in documents]
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting message documents: {str(e)}")
|
||||
|
|
@ -1372,11 +1366,7 @@ class ChatObjects:
|
|||
"""
|
||||
# Check workflow access first (without calling getWorkflow to avoid circular reference)
|
||||
# Use RBAC filtering
|
||||
workflows = getRecordsetWithRBAC(self.db,
|
||||
ChatWorkflow,
|
||||
self.currentUser,
|
||||
recordFilter={"id": workflowId}
|
||||
)
|
||||
workflows = self._getRecordset(ChatWorkflow, recordFilter={"id": workflowId})
|
||||
|
||||
if not workflows:
|
||||
if pagination is None:
|
||||
|
|
@ -1384,7 +1374,7 @@ class ChatObjects:
|
|||
return PaginatedResult(items=[], totalItems=0, totalPages=0)
|
||||
|
||||
# Get logs for this workflow from normalized table
|
||||
logs = getRecordsetWithRBAC(self.db, ChatLog, self.currentUser, recordFilter={"workflowId": workflowId})
|
||||
logs = self._getRecordset(ChatLog, recordFilter={"workflowId": workflowId})
|
||||
|
||||
# Convert raw logs to dict format for sorting/filtering
|
||||
logDicts = []
|
||||
|
|
@ -1516,24 +1506,31 @@ class ChatObjects:
|
|||
"""Returns list of statistics for a workflow if user has access."""
|
||||
# Check workflow access first (without calling getWorkflow to avoid circular reference)
|
||||
# Use RBAC filtering
|
||||
workflows = getRecordsetWithRBAC(self.db,
|
||||
ChatWorkflow,
|
||||
self.currentUser,
|
||||
recordFilter={"id": workflowId}
|
||||
)
|
||||
workflows = self._getRecordset(ChatWorkflow, recordFilter={"id": workflowId})
|
||||
|
||||
if not workflows:
|
||||
return []
|
||||
|
||||
# Get stats for this workflow from normalized table
|
||||
stats = getRecordsetWithRBAC(self.db, ChatStat, self.currentUser, recordFilter={"workflowId": workflowId})
|
||||
stats = self._getRecordset(ChatStat, recordFilter={"workflowId": workflowId})
|
||||
|
||||
if not stats:
|
||||
return []
|
||||
|
||||
# Return all stats records sorted by creation time
|
||||
stats.sort(key=lambda x: x.get("created_at", ""))
|
||||
return [ChatStat(**stat) for stat in stats]
|
||||
# DB uses _createdAt (camelCase system field)
|
||||
stats.sort(key=lambda x: x.get("_createdAt", 0))
|
||||
|
||||
# Convert to ChatStat objects, preserving _createdAt via extra="allow"
|
||||
result = []
|
||||
for stat in stats:
|
||||
chat_stat = ChatStat(**stat)
|
||||
# Explicitly preserve _createdAt from raw DB record
|
||||
if "_createdAt" in stat:
|
||||
setattr(chat_stat, '_createdAt', stat["_createdAt"])
|
||||
result.append(chat_stat)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def createStat(self, statData: Dict[str, Any]) -> ChatStat:
|
||||
|
|
@ -1549,9 +1546,16 @@ class ChatObjects:
|
|||
# Validate the stat data against ChatStat model
|
||||
stat = ChatStat(**statData)
|
||||
|
||||
logger.debug(f"Creating stat for workflow {statData.get('workflowId')}: "
|
||||
f"process={statData.get('process')}, "
|
||||
f"priceCHF={statData.get('priceCHF', 0):.4f}, "
|
||||
f"processingTime={statData.get('processingTime', 0):.2f}s")
|
||||
|
||||
# Create the stat record in the database
|
||||
created = self.db.recordCreate(ChatStat, stat)
|
||||
|
||||
logger.info(f"Created stat {created.get('id')} for workflow {statData.get('workflowId')}")
|
||||
|
||||
# Return the created ChatStat
|
||||
return ChatStat(**created)
|
||||
except Exception as e:
|
||||
|
|
@ -1566,11 +1570,7 @@ class ChatObjects:
|
|||
"""
|
||||
# Check workflow access first
|
||||
# Use RBAC filtering
|
||||
workflows = getRecordsetWithRBAC(self.db,
|
||||
ChatWorkflow,
|
||||
self.currentUser,
|
||||
recordFilter={"id": workflowId}
|
||||
)
|
||||
workflows = self._getRecordset(ChatWorkflow, recordFilter={"id": workflowId})
|
||||
|
||||
if not workflows:
|
||||
return {"items": []}
|
||||
|
|
@ -1579,7 +1579,7 @@ class ChatObjects:
|
|||
items = []
|
||||
|
||||
# Get messages
|
||||
messages = getRecordsetWithRBAC(self.db, ChatMessage, self.currentUser, recordFilter={"workflowId": workflowId})
|
||||
messages = self._getRecordset(ChatMessage, recordFilter={"workflowId": workflowId})
|
||||
for msg in messages:
|
||||
# Apply timestamp filtering in Python
|
||||
msgTimestamp = parseTimestamp(msg.get("publishedAt"), default=getUtcTimestamp())
|
||||
|
|
@ -1620,7 +1620,7 @@ class ChatObjects:
|
|||
})
|
||||
|
||||
# Get logs - return all logs with roundNumber if available
|
||||
logs = getRecordsetWithRBAC(self.db, ChatLog, self.currentUser, recordFilter={"workflowId": workflowId})
|
||||
logs = self._getRecordset(ChatLog, recordFilter={"workflowId": workflowId})
|
||||
for log in logs:
|
||||
# Apply timestamp filtering in Python
|
||||
logTimestamp = parseTimestamp(log.get("timestamp"), default=getUtcTimestamp())
|
||||
|
|
@ -1634,18 +1634,23 @@ class ChatObjects:
|
|||
"item": chatLog
|
||||
})
|
||||
|
||||
# Get stats list
|
||||
# Get stats - ChatStat model supports _createdAt via model_config extra="allow"
|
||||
stats = self.getStats(workflowId)
|
||||
for stat in stats:
|
||||
# Apply timestamp filtering in Python
|
||||
stat_timestamp = stat.createdAt if hasattr(stat, 'createdAt') else getUtcTimestamp()
|
||||
# Use _createdAt (system field from DB, preserved via model_config extra="allow")
|
||||
stat_timestamp = getattr(stat, '_createdAt', None) or getUtcTimestamp()
|
||||
if afterTimestamp is not None and stat_timestamp <= afterTimestamp:
|
||||
continue
|
||||
|
||||
# Convert to dict and include _createdAt for frontend
|
||||
stat_dict = stat.model_dump() if hasattr(stat, 'model_dump') else stat.dict()
|
||||
stat_dict['_createdAt'] = stat_timestamp
|
||||
|
||||
items.append({
|
||||
"type": "stat",
|
||||
"createdAt": stat_timestamp,
|
||||
"item": stat
|
||||
"item": stat_dict
|
||||
})
|
||||
|
||||
# Sort all items by createdAt timestamp for chronological order
|
||||
|
|
|
|||
|
|
@ -141,9 +141,6 @@ class ComponentObjects:
|
|||
userId=self.userId if hasattr(self, 'userId') else None
|
||||
)
|
||||
|
||||
# Initialize database system
|
||||
self.db.initDbSystem()
|
||||
|
||||
logger.info("Database initialized successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize database: {str(e)}")
|
||||
|
|
@ -316,10 +313,12 @@ class ComponentObjects:
|
|||
return False
|
||||
|
||||
tableName = modelClass.__name__
|
||||
from modules.interfaces.interfaceRbac import buildDataObjectKey
|
||||
objectKey = buildDataObjectKey(tableName)
|
||||
permissions = self.rbac.getUserPermissions(
|
||||
self.currentUser,
|
||||
AccessRuleContext.DATA,
|
||||
tableName,
|
||||
objectKey,
|
||||
mandateId=self.mandateId,
|
||||
featureInstanceId=self.featureInstanceId
|
||||
)
|
||||
|
|
@ -593,10 +592,58 @@ class ComponentObjects:
|
|||
|
||||
# Prompt methods
|
||||
|
||||
def _isSysAdmin(self) -> bool:
|
||||
"""Check if the current user is a SysAdmin."""
|
||||
return hasattr(self.currentUser, 'isSysAdmin') and self.currentUser.isSysAdmin
|
||||
|
||||
def _enrichPromptsWithPermissions(self, prompts: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""Enrich prompts with row-level _permissions based on ownership and isSystem flag.
|
||||
|
||||
- SysAdmin: canUpdate=True, canDelete=True on all prompts
|
||||
- Regular user on own prompts: canUpdate=True, canDelete=True
|
||||
- Regular user on system prompts: canUpdate=False, canDelete=False (read-only)
|
||||
"""
|
||||
isSysAdmin = self._isSysAdmin()
|
||||
for prompt in prompts:
|
||||
isOwner = prompt.get("_createdBy") == self.userId
|
||||
prompt["_permissions"] = {
|
||||
"canUpdate": isOwner or isSysAdmin,
|
||||
"canDelete": isOwner or isSysAdmin
|
||||
}
|
||||
return prompts
|
||||
|
||||
def _getPromptsForUser(self) -> List[Dict[str, Any]]:
|
||||
"""Returns prompts visible to the current user.
|
||||
|
||||
Visibility rules:
|
||||
- SysAdmin: ALL prompts
|
||||
- Regular user: own prompts (_createdBy) + system prompts (isSystem=True)
|
||||
"""
|
||||
if self._isSysAdmin():
|
||||
return self.db.getRecordset(Prompt)
|
||||
|
||||
# Get own prompts
|
||||
ownPrompts = self.db.getRecordset(Prompt, recordFilter={"_createdBy": self.userId})
|
||||
|
||||
# Get system prompts
|
||||
systemPrompts = self.db.getRecordset(Prompt, recordFilter={"isSystem": True})
|
||||
|
||||
# Merge and deduplicate (a user's own prompt could also be isSystem)
|
||||
seen = {}
|
||||
for p in ownPrompts:
|
||||
seen[p["id"]] = p
|
||||
for p in systemPrompts:
|
||||
if p["id"] not in seen:
|
||||
seen[p["id"]] = p
|
||||
|
||||
return list(seen.values())
|
||||
|
||||
def getAllPrompts(self, pagination: Optional[PaginationParams] = None) -> Union[List[Prompt], PaginatedResult]:
|
||||
"""
|
||||
Returns prompts based on user access level.
|
||||
Supports optional pagination, sorting, and filtering.
|
||||
Returns prompts with visibility rules:
|
||||
- SysAdmin: sees ALL prompts, can CRUD all
|
||||
- Regular user: sees own prompts + system prompts (isSystem=True), can only CRUD own
|
||||
- Row-level _permissions control edit/delete buttons in the UI
|
||||
|
||||
Args:
|
||||
pagination: Optional pagination parameters. If None, returns all items.
|
||||
|
|
@ -606,11 +653,11 @@ class ComponentObjects:
|
|||
If pagination is provided: PaginatedResult with items and metadata
|
||||
"""
|
||||
try:
|
||||
# Use RBAC filtering
|
||||
filteredPrompts = getRecordsetWithRBAC(self.db,
|
||||
Prompt,
|
||||
self.currentUser
|
||||
)
|
||||
# Get prompts based on user role (own + system for regular, all for SysAdmin)
|
||||
filteredPrompts = self._getPromptsForUser()
|
||||
|
||||
# Enrich with row-level permissions (_permissions: canUpdate, canDelete)
|
||||
filteredPrompts = self._enrichPromptsWithPermissions(filteredPrompts)
|
||||
|
||||
# If no pagination requested, return all items
|
||||
if pagination is None:
|
||||
|
|
@ -633,7 +680,7 @@ class ComponentObjects:
|
|||
endIdx = startIdx + pagination.pageSize
|
||||
pagedPrompts = filteredPrompts[startIdx:endIdx]
|
||||
|
||||
# Convert to model objects
|
||||
# Convert to model objects (extra='allow' on Prompt preserves system fields)
|
||||
items = [Prompt(**prompt) for prompt in pagedPrompts]
|
||||
|
||||
return PaginatedResult(
|
||||
|
|
@ -649,15 +696,24 @@ class ComponentObjects:
|
|||
return PaginatedResult(items=[], totalItems=0, totalPages=0)
|
||||
|
||||
def getPrompt(self, promptId: str) -> Optional[Prompt]:
|
||||
"""Returns a prompt by ID if user has access."""
|
||||
# Use RBAC filtering
|
||||
filteredPrompts = getRecordsetWithRBAC(self.db,
|
||||
Prompt,
|
||||
self.currentUser,
|
||||
recordFilter={"id": promptId}
|
||||
)
|
||||
"""Returns a prompt by ID if the user has visibility.
|
||||
|
||||
return Prompt(**filteredPrompts[0]) if filteredPrompts else None
|
||||
Visibility: SysAdmin sees all, regular user sees own + system prompts.
|
||||
"""
|
||||
filteredPrompts = self.db.getRecordset(Prompt, recordFilter={"id": promptId})
|
||||
if not filteredPrompts:
|
||||
return None
|
||||
|
||||
prompt = filteredPrompts[0]
|
||||
|
||||
# Visibility check for non-SysAdmin: must be owner or system prompt
|
||||
if not self._isSysAdmin():
|
||||
isOwner = prompt.get("_createdBy") == self.userId
|
||||
isSystem = prompt.get("isSystem", False)
|
||||
if not isOwner and not isSystem:
|
||||
return None
|
||||
|
||||
return Prompt(**prompt)
|
||||
|
||||
def createPrompt(self, promptData: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Creates a new prompt if user has permission."""
|
||||
|
|
@ -672,13 +728,25 @@ class ComponentObjects:
|
|||
return createdRecord
|
||||
|
||||
def updatePrompt(self, promptId: str, updateData: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Updates a prompt if user has access."""
|
||||
"""Updates a prompt. Rules:
|
||||
- SysAdmin: can update any prompt (including system prompts)
|
||||
- Regular user: can only update own prompts (not system prompts)
|
||||
"""
|
||||
try:
|
||||
# Get prompt
|
||||
# Get prompt (visibility-checked)
|
||||
prompt = self.getPrompt(promptId)
|
||||
if not prompt:
|
||||
raise ValueError(f"Prompt {promptId} not found")
|
||||
|
||||
# Permission check: owner or SysAdmin
|
||||
isOwner = (getattr(prompt, '_createdBy', None) == self.userId)
|
||||
if not self._isSysAdmin() and not isOwner:
|
||||
raise PermissionError(f"No permission to update prompt {promptId}")
|
||||
|
||||
# Regular users cannot set isSystem flag
|
||||
if not self._isSysAdmin() and 'isSystem' in updateData:
|
||||
del updateData['isSystem']
|
||||
|
||||
# Update prompt record directly with the update data
|
||||
self.db.recordModify(Prompt, promptId, updateData)
|
||||
|
||||
|
|
@ -691,77 +759,69 @@ class ComponentObjects:
|
|||
|
||||
return updatedPrompt.model_dump()
|
||||
|
||||
except PermissionError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating prompt: {str(e)}")
|
||||
raise ValueError(f"Failed to update prompt: {str(e)}")
|
||||
|
||||
def deletePrompt(self, promptId: str) -> bool:
|
||||
"""Deletes a prompt if user has access."""
|
||||
# Check if the prompt exists and user has access
|
||||
"""Deletes a prompt. Rules:
|
||||
- SysAdmin: can delete any prompt (including system prompts)
|
||||
- Regular user: can only delete own prompts (not system prompts)
|
||||
"""
|
||||
# Get prompt (visibility-checked)
|
||||
prompt = self.getPrompt(promptId)
|
||||
if not prompt:
|
||||
return False
|
||||
|
||||
if not self.checkRbacPermission(Prompt, "update", promptId):
|
||||
|
||||
# Permission check: owner or SysAdmin
|
||||
isOwner = (getattr(prompt, '_createdBy', None) == self.userId)
|
||||
if not self._isSysAdmin() and not isOwner:
|
||||
raise PermissionError(f"No permission to delete prompt {promptId}")
|
||||
|
||||
# Delete prompt
|
||||
success = self.db.recordDelete(Prompt, promptId)
|
||||
|
||||
|
||||
return success
|
||||
|
||||
# File Utilities
|
||||
|
||||
def checkForDuplicateFile(self, fileHash: str, fileName: str = None) -> Optional[FileItem]:
|
||||
"""Checks if a file with the same hash already exists for the current user and mandate.
|
||||
If fileName is provided, also checks for exact name+hash match.
|
||||
Only returns files the current user has access to."""
|
||||
# Get files with the hash, filtered by RBAC
|
||||
accessibleFiles = getRecordsetWithRBAC(self.db,
|
||||
def checkForDuplicateFile(self, fileHash: str, fileName: str) -> Optional[FileItem]:
|
||||
"""Checks if a file with the same hash AND fileName already exists for the current user.
|
||||
|
||||
Duplicate = same user (_createdBy) + same fileHash + same fileName.
|
||||
Same hash with different name is allowed (intentional copy by user).
|
||||
Uses direct DB query (not RBAC) because files are isolated per user.
|
||||
"""
|
||||
if not self.userId:
|
||||
return None
|
||||
|
||||
# Direct DB query: find files with matching hash + name + user
|
||||
matchingFiles = self.db.getRecordset(
|
||||
FileItem,
|
||||
self.currentUser,
|
||||
recordFilter={"fileHash": fileHash}
|
||||
recordFilter={
|
||||
"_createdBy": self.userId,
|
||||
"fileHash": fileHash,
|
||||
"fileName": fileName
|
||||
}
|
||||
)
|
||||
|
||||
if not accessibleFiles:
|
||||
if not matchingFiles:
|
||||
return None
|
||||
|
||||
# If fileName is provided, check for exact name+hash match first
|
||||
if fileName:
|
||||
for file in accessibleFiles:
|
||||
# Skip files without fileName key or with None/empty fileName
|
||||
if "fileName" not in file or not file["fileName"]:
|
||||
continue
|
||||
if file["fileName"] == fileName:
|
||||
return FileItem(
|
||||
id=file["id"],
|
||||
mandateId=file["mandateId"],
|
||||
fileName=file["fileName"],
|
||||
mimeType=file["mimeType"],
|
||||
fileHash=file["fileHash"],
|
||||
fileSize=file["fileSize"],
|
||||
creationDate=file["creationDate"]
|
||||
)
|
||||
|
||||
# Return first valid file with matching hash (for general duplicate detection)
|
||||
for file in accessibleFiles:
|
||||
# Skip files without fileName key or with None/empty fileName
|
||||
if "fileName" not in file or not file["fileName"]:
|
||||
continue
|
||||
# Use first valid file
|
||||
return FileItem(
|
||||
id=file["id"],
|
||||
mandateId=file["mandateId"],
|
||||
fileName=file["fileName"],
|
||||
mimeType=file["mimeType"],
|
||||
fileHash=file["fileHash"],
|
||||
fileSize=file["fileSize"],
|
||||
creationDate=file["creationDate"]
|
||||
)
|
||||
|
||||
# If no valid files found, return None
|
||||
return None
|
||||
# Return first match
|
||||
file = matchingFiles[0]
|
||||
return FileItem(
|
||||
id=file["id"],
|
||||
mandateId=file.get("mandateId", ""),
|
||||
featureInstanceId=file.get("featureInstanceId", ""),
|
||||
fileName=file["fileName"],
|
||||
mimeType=file["mimeType"],
|
||||
fileHash=file["fileHash"],
|
||||
fileSize=file["fileSize"],
|
||||
creationDate=file["creationDate"]
|
||||
)
|
||||
|
||||
def getMimeType(self, fileName: str) -> str:
|
||||
"""Determines the MIME type based on the file extension."""
|
||||
|
|
@ -835,9 +895,18 @@ class ComponentObjects:
|
|||
|
||||
# File methods - metadata-based operations
|
||||
|
||||
def _getFilesByCurrentUser(self, recordFilter: Dict[str, Any] = None) -> List[Dict[str, Any]]:
|
||||
"""Files are always user-scoped. Returns only files owned by the current user,
|
||||
regardless of role (including SysAdmin). This bypasses RBAC intentionally."""
|
||||
filterDict = {"_createdBy": self.userId}
|
||||
if recordFilter:
|
||||
filterDict.update(recordFilter)
|
||||
return self.db.getRecordset(FileItem, recordFilter=filterDict)
|
||||
|
||||
def getAllFiles(self, pagination: Optional[PaginationParams] = None) -> Union[List[FileItem], PaginatedResult]:
|
||||
"""
|
||||
Returns files based on user access level.
|
||||
Returns files owned by the current user (user-scoped, not RBAC-based).
|
||||
Every user (including SysAdmin) only sees their own files.
|
||||
Supports optional pagination, sorting, and filtering.
|
||||
|
||||
Args:
|
||||
|
|
@ -847,13 +916,10 @@ class ComponentObjects:
|
|||
If pagination is None: List[FileItem]
|
||||
If pagination is provided: PaginatedResult with items and metadata
|
||||
"""
|
||||
# Use RBAC filtering
|
||||
filteredFiles = getRecordsetWithRBAC(self.db,
|
||||
FileItem,
|
||||
self.currentUser
|
||||
)
|
||||
# Files are always user-scoped: filter by _createdBy (bypasses RBAC SysAdmin override)
|
||||
filteredFiles = self._getFilesByCurrentUser()
|
||||
|
||||
# Convert database records to FileItem instances (for both paginated and non-paginated)
|
||||
# Convert database records to FileItem instances (extra='allow' preserves system fields like _createdBy)
|
||||
def convertFileItems(files):
|
||||
fileItems = []
|
||||
for file in files:
|
||||
|
|
@ -861,21 +927,14 @@ class ComponentObjects:
|
|||
# Ensure proper values, use defaults for invalid data
|
||||
creationDate = file.get("creationDate")
|
||||
if creationDate is None or not isinstance(creationDate, (int, float)) or creationDate <= 0:
|
||||
creationDate = getUtcTimestamp()
|
||||
file["creationDate"] = getUtcTimestamp()
|
||||
|
||||
fileName = file.get("fileName")
|
||||
if not fileName or fileName == "None":
|
||||
continue # Skip records with invalid fileName
|
||||
|
||||
fileItem = FileItem(
|
||||
id=file.get("id"),
|
||||
mandateId=file.get("mandateId"),
|
||||
fileName=fileName,
|
||||
mimeType=file.get("mimeType"),
|
||||
fileHash=file.get("fileHash"),
|
||||
fileSize=file.get("fileSize"),
|
||||
creationDate=creationDate
|
||||
)
|
||||
# Use **file to pass all fields including system fields (_createdBy, etc.)
|
||||
fileItem = FileItem(**file)
|
||||
fileItems.append(fileItem)
|
||||
except Exception as e:
|
||||
logger.warning(f"Skipping invalid file record: {str(e)}")
|
||||
|
|
@ -903,7 +962,7 @@ class ComponentObjects:
|
|||
endIdx = startIdx + pagination.pageSize
|
||||
pagedFiles = filteredFiles[startIdx:endIdx]
|
||||
|
||||
# Convert to model objects
|
||||
# Convert to model objects (extra='allow' on FileItem preserves system fields)
|
||||
items = convertFileItems(pagedFiles)
|
||||
|
||||
return PaginatedResult(
|
||||
|
|
@ -913,13 +972,9 @@ class ComponentObjects:
|
|||
)
|
||||
|
||||
def getFile(self, fileId: str) -> Optional[FileItem]:
|
||||
"""Returns a file by ID if user has access."""
|
||||
# Use RBAC filtering
|
||||
filteredFiles = getRecordsetWithRBAC(self.db,
|
||||
FileItem,
|
||||
self.currentUser,
|
||||
recordFilter={"id": fileId}
|
||||
)
|
||||
"""Returns a file by ID if it belongs to the current user (user-scoped)."""
|
||||
# Files are always user-scoped: filter by _createdBy (bypasses RBAC SysAdmin override)
|
||||
filteredFiles = self._getFilesByCurrentUser(recordFilter={"id": fileId})
|
||||
|
||||
if not filteredFiles:
|
||||
return None
|
||||
|
|
@ -979,17 +1034,28 @@ class ComponentObjects:
|
|||
counter += 1
|
||||
|
||||
def createFile(self, name: str, mimeType: str, content: bytes) -> FileItem:
|
||||
"""Creates a new file entry if user has permission. Computes fileHash and fileSize from content."""
|
||||
"""Creates a new file entry if user has permission. Computes fileHash and fileSize from content.
|
||||
|
||||
Duplicate check: if a file with the same user + fileHash + fileName already exists,
|
||||
the existing file is returned instead of creating a new one.
|
||||
Same hash with different name is allowed (intentional copy by user).
|
||||
"""
|
||||
if not self.checkRbacPermission(FileItem, "create"):
|
||||
raise PermissionError("No permission to create files")
|
||||
|
||||
# Ensure fileName is unique
|
||||
uniqueName = self._generateUniquefileName(name)
|
||||
|
||||
# Compute file size and hash
|
||||
fileSize = len(content)
|
||||
fileHash = hashlib.sha256(content).hexdigest()
|
||||
|
||||
# Duplicate check: same user + same hash + same fileName → return existing
|
||||
existingFile = self.checkForDuplicateFile(fileHash, name)
|
||||
if existingFile:
|
||||
logger.info(f"Duplicate file detected in createFile: '{name}' (hash={fileHash[:12]}...) for user {self.userId} — returning existing file {existingFile.id}")
|
||||
return existingFile
|
||||
|
||||
# Ensure fileName is unique
|
||||
uniqueName = self._generateUniquefileName(name)
|
||||
|
||||
# Use mandateId and featureInstanceId from context for proper data isolation
|
||||
# Convert None to empty string to satisfy Pydantic validation
|
||||
mandateId = self.mandateId or ""
|
||||
|
|
@ -1008,7 +1074,6 @@ class ComponentObjects:
|
|||
# Store in database
|
||||
self.db.recordCreate(FileItem, fileItem)
|
||||
|
||||
|
||||
return fileItem
|
||||
|
||||
def updateFile(self, fileId: str, updateData: Dict[str, Any]) -> Dict[str, Any]:
|
||||
|
|
@ -1043,20 +1108,16 @@ class ComponentObjects:
|
|||
if not self.checkRbacPermission(FileItem, "update", fileId):
|
||||
raise PermissionError(f"No permission to delete file {fileId}")
|
||||
|
||||
# Check for other references to this file (by hash) - use RBAC to only check files user has access to
|
||||
# Check for other references to this file (by hash) - user-scoped check
|
||||
fileHash = file.fileHash
|
||||
if fileHash:
|
||||
allReferences = getRecordsetWithRBAC(self.db,
|
||||
FileItem,
|
||||
self.currentUser,
|
||||
recordFilter={"fileHash": fileHash}
|
||||
)
|
||||
allReferences = self._getFilesByCurrentUser(recordFilter={"fileHash": fileHash})
|
||||
otherReferences = [f for f in allReferences if f["id"] != fileId]
|
||||
|
||||
# Only delete associated fileData if no other references exist
|
||||
if not otherReferences:
|
||||
try:
|
||||
fileDataEntries = getRecordsetWithRBAC(self.db, FileData, self.currentUser, recordFilter={"id": fileId})
|
||||
fileDataEntries = self.db.getRecordset(FileData, recordFilter={"id": fileId})
|
||||
if fileDataEntries:
|
||||
self.db.recordDelete(FileData, fileId)
|
||||
logger.debug(f"FileData for file {fileId} deleted")
|
||||
|
|
@ -1116,6 +1177,12 @@ class ComponentObjects:
|
|||
base64Encoded = True
|
||||
logger.debug(f"Stored file {fileId} as base64")
|
||||
|
||||
# Check if file data already exists (e.g., when createFile returned a duplicate)
|
||||
existingData = self.db.getRecordset(FileData, recordFilter={"id": fileId})
|
||||
if existingData:
|
||||
logger.debug(f"File data already exists for {fileId} — skipping duplicate storage")
|
||||
return True
|
||||
|
||||
# Create the fileData record with data and encoding flag
|
||||
fileDataObj = {
|
||||
"id": fileId,
|
||||
|
|
@ -1248,25 +1315,21 @@ class ComponentObjects:
|
|||
logger.error(f"Invalid fileContent type: {type(fileContent)}")
|
||||
raise ValueError(f"fileContent must be bytes, got {type(fileContent)}")
|
||||
|
||||
# Compute file hash first to check for duplicates
|
||||
# Compute file hash to check for duplicates before any DB writes
|
||||
fileHash = hashlib.sha256(fileContent).hexdigest()
|
||||
|
||||
# Check for exact name+hash match first (same name + same content)
|
||||
# Duplicate check: same user + same fileHash + same fileName → return existing file
|
||||
# Same hash with different name is allowed (intentional copy by user)
|
||||
existingFile = self.checkForDuplicateFile(fileHash, fileName)
|
||||
if existingFile:
|
||||
logger.info(f"Exact duplicate detected: {fileName} with same hash. Returning existing file reference.")
|
||||
logger.info(f"Duplicate detected for user {self.userId}: '{fileName}' with hash {fileHash[:12]}... — returning existing file {existingFile.id}")
|
||||
return existingFile, "exact_duplicate"
|
||||
|
||||
# Check for hash-only match (same content, different name)
|
||||
existingFileWithSameHash = self.checkForDuplicateFile(fileHash)
|
||||
if existingFileWithSameHash:
|
||||
logger.info(f"Content duplicate detected: {fileName} has same content as {existingFileWithSameHash.fileName}")
|
||||
# Continue with upload - filename will be made unique if needed
|
||||
|
||||
# Determine MIME type
|
||||
mimeType = self.getMimeType(fileName)
|
||||
|
||||
# Save metadata and file (hash/size computed inside createFile)
|
||||
# createFile handles its own duplicate check (for calls from other code paths)
|
||||
# Here we already checked, so this will create a new file
|
||||
logger.debug(f"Saving file metadata to database for file: {fileName}")
|
||||
fileItem = self.createFile(
|
||||
name=fileName,
|
||||
|
|
|
|||
|
|
@ -163,7 +163,7 @@ def getRecordsetWithRBAC(
|
|||
|
||||
# Check view permission first
|
||||
if not permissions.view:
|
||||
logger.debug(f"User {currentUser.id} has no view permission for {objectKey}")
|
||||
logger.debug(f"User {currentUser.id} has no view permission for {objectKey} (mandateId={effectiveMandateId}, featureInstanceId={featureInstanceId})")
|
||||
return []
|
||||
|
||||
# Build WHERE clause with RBAC filtering
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ router.mount(
|
|||
|
||||
@router.get("/")
|
||||
@limiter.limit("30/minute")
|
||||
async def root(request: Request) -> Dict[str, str]:
|
||||
def root(request: Request) -> Dict[str, str]:
|
||||
"""API status endpoint"""
|
||||
# Validate required configuration values
|
||||
allowedOrigins = APP_CONFIG.get("APP_ALLOWED_ORIGINS")
|
||||
|
|
@ -51,7 +51,7 @@ async def root(request: Request) -> Dict[str, str]:
|
|||
|
||||
@router.get("/api/environment")
|
||||
@limiter.limit("30/minute")
|
||||
async def get_environment(
|
||||
def get_environment(
|
||||
request: Request, currentUser: Dict[str, Any] = Depends(getCurrentUser)
|
||||
) -> Dict[str, str]:
|
||||
"""Get environment configuration for frontend"""
|
||||
|
|
@ -82,13 +82,13 @@ async def get_environment(
|
|||
|
||||
@router.options("/{fullPath:path}")
|
||||
@limiter.limit("60/minute")
|
||||
async def options_route(request: Request, fullPath: str) -> Response:
|
||||
def options_route(request: Request, fullPath: str) -> Response:
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@router.get("/favicon.ico")
|
||||
@limiter.limit("30/minute")
|
||||
async def favicon(request: Request) -> FileResponse:
|
||||
def favicon(request: Request) -> FileResponse:
|
||||
favicon_path = staticFolder / "favicon.ico"
|
||||
if not favicon_path.exists():
|
||||
raise HTTPException(status_code=404, detail="Favicon not found")
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ router = APIRouter(
|
|||
|
||||
@router.get("")
|
||||
@limiter.limit("30/minute")
|
||||
async def get_all_automation_events(
|
||||
def get_all_automation_events(
|
||||
request: Request,
|
||||
currentUser: User = Depends(requireSysAdmin)
|
||||
) -> List[Dict[str, Any]]:
|
||||
|
|
@ -90,7 +90,7 @@ async def sync_all_automation_events(
|
|||
|
||||
from modules.services import getInterface as getServices
|
||||
services = getServices(currentUser, None)
|
||||
result = await syncAutomationEvents(services, eventUser)
|
||||
result = syncAutomationEvents(services, eventUser)
|
||||
return {
|
||||
"success": True,
|
||||
"synced": result.get("synced", 0),
|
||||
|
|
@ -107,7 +107,7 @@ async def sync_all_automation_events(
|
|||
|
||||
@router.post("/{eventId}/remove")
|
||||
@limiter.limit("10/minute")
|
||||
async def remove_event(
|
||||
def remove_event(
|
||||
request: Request,
|
||||
eventId: str = Path(..., description="Event ID to remove"),
|
||||
currentUser: User = Depends(requireSysAdmin)
|
||||
|
|
|
|||
|
|
@ -67,7 +67,7 @@ class SyncRolesResult(BaseModel):
|
|||
|
||||
@router.get("/", response_model=List[Dict[str, Any]])
|
||||
@limiter.limit("60/minute")
|
||||
async def list_features(
|
||||
def list_features(
|
||||
request: Request,
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
) -> List[Dict[str, Any]]:
|
||||
|
|
@ -105,7 +105,7 @@ class FeaturesMyResponse(BaseModel):
|
|||
|
||||
@router.get("/my", response_model=FeaturesMyResponse)
|
||||
@limiter.limit("60/minute")
|
||||
async def get_my_feature_instances(
|
||||
def get_my_feature_instances(
|
||||
request: Request,
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
) -> FeaturesMyResponse:
|
||||
|
|
@ -204,38 +204,26 @@ async def get_my_feature_instances(
|
|||
def _getUserRolesInInstance(rootInterface, userId: str, instanceId: str) -> List[str]:
|
||||
"""Get all role labels for a user in a feature instance."""
|
||||
try:
|
||||
from modules.datamodels.datamodelRbac import Role
|
||||
from modules.datamodels.datamodelMembership import FeatureAccess, FeatureAccessRole
|
||||
# Get FeatureAccess for this user and instance (Pydantic model)
|
||||
featureAccess = rootInterface.getFeatureAccess(userId, instanceId)
|
||||
|
||||
# Get FeatureAccess for this user and instance
|
||||
featureAccesses = rootInterface.db.getRecordset(
|
||||
FeatureAccess,
|
||||
recordFilter={"userId": userId, "featureInstanceId": instanceId}
|
||||
)
|
||||
|
||||
if featureAccesses:
|
||||
featureAccessId = featureAccesses[0].get("id")
|
||||
if featureAccess:
|
||||
# Get role IDs via interface method
|
||||
roleIds = rootInterface.getRoleIdsForFeatureAccess(str(featureAccess.id))
|
||||
|
||||
# Get role IDs via FeatureAccessRole junction table
|
||||
featureAccessRoles = rootInterface.db.getRecordset(
|
||||
FeatureAccessRole,
|
||||
recordFilter={"featureAccessId": featureAccessId}
|
||||
)
|
||||
|
||||
if featureAccessRoles:
|
||||
# Get ALL roles, not just the first one
|
||||
if roleIds:
|
||||
# Get ALL roles and extract labels
|
||||
roleLabels = []
|
||||
for far in featureAccessRoles:
|
||||
roleId = far.get("roleId")
|
||||
roles = rootInterface.db.getRecordset(Role, recordFilter={"id": roleId})
|
||||
if roles:
|
||||
roleLabels.append(roles[0].get("roleLabel", "user"))
|
||||
for roleId in roleIds:
|
||||
role = rootInterface.getRole(roleId)
|
||||
if role:
|
||||
roleLabels.append(role.roleLabel)
|
||||
return roleLabels if roleLabels else ["user"]
|
||||
|
||||
return ["user"] # Default
|
||||
return ["user"] # Default - no access means basic user level
|
||||
except Exception as e:
|
||||
logger.debug(f"Error getting user roles: {e}")
|
||||
return ["user"]
|
||||
return ["user"] # Fail-safe: default to basic user
|
||||
|
||||
|
||||
def _getInstancePermissions(rootInterface, userId: str, instanceId: str) -> Dict[str, Any]:
|
||||
|
|
@ -249,66 +237,53 @@ def _getInstancePermissions(rootInterface, userId: str, instanceId: str) -> Dict
|
|||
}
|
||||
|
||||
try:
|
||||
from modules.datamodels.datamodelRbac import AccessRule, AccessRuleContext, Role
|
||||
from modules.datamodels.datamodelMembership import FeatureAccess, FeatureAccessRole
|
||||
from modules.datamodels.datamodelRbac import AccessRuleContext
|
||||
|
||||
# Get FeatureAccess for this user and instance
|
||||
featureAccesses = rootInterface.db.getRecordset(
|
||||
FeatureAccess,
|
||||
recordFilter={"userId": userId, "featureInstanceId": instanceId}
|
||||
)
|
||||
# Get FeatureAccess for this user and instance (Pydantic model)
|
||||
featureAccess = rootInterface.getFeatureAccess(userId, instanceId)
|
||||
|
||||
logger.debug(f"_getInstancePermissions: userId={userId}, instanceId={instanceId}, featureAccesses={len(featureAccesses) if featureAccesses else 0}")
|
||||
logger.debug(f"_getInstancePermissions: userId={userId}, instanceId={instanceId}, featureAccess={featureAccess is not None}")
|
||||
|
||||
if not featureAccesses:
|
||||
if not featureAccess:
|
||||
logger.debug(f"_getInstancePermissions: No FeatureAccess found for user {userId} and instance {instanceId}")
|
||||
return permissions
|
||||
|
||||
# Get role IDs via FeatureAccessRole junction table
|
||||
featureAccessId = featureAccesses[0].get("id")
|
||||
featureAccessRoles = rootInterface.db.getRecordset(
|
||||
FeatureAccessRole,
|
||||
recordFilter={"featureAccessId": featureAccessId}
|
||||
)
|
||||
roleIds = [far.get("roleId") for far in featureAccessRoles]
|
||||
# Get role IDs via interface method
|
||||
roleIds = rootInterface.getRoleIdsForFeatureAccess(str(featureAccess.id))
|
||||
|
||||
logger.debug(f"_getInstancePermissions: featureAccessId={featureAccessId}, roleIds={roleIds}")
|
||||
logger.debug(f"_getInstancePermissions: featureAccessId={featureAccess.id}, roleIds={roleIds}")
|
||||
|
||||
if not roleIds:
|
||||
logger.debug(f"_getInstancePermissions: No roles found for FeatureAccess {featureAccessId}")
|
||||
logger.debug(f"_getInstancePermissions: No roles found for FeatureAccess {featureAccess.id}")
|
||||
return permissions
|
||||
|
||||
# Check if user has admin role
|
||||
for roleId in roleIds:
|
||||
roles = rootInterface.db.getRecordset(Role, recordFilter={"id": roleId})
|
||||
if roles:
|
||||
roleLabel = roles[0].get("roleLabel", "").lower()
|
||||
if "admin" in roleLabel:
|
||||
permissions["isAdmin"] = True
|
||||
break
|
||||
role = rootInterface.getRole(roleId)
|
||||
if role and "admin" in role.roleLabel.lower():
|
||||
permissions["isAdmin"] = True
|
||||
break
|
||||
|
||||
# Get permissions (AccessRules) for all roles
|
||||
for roleId in roleIds:
|
||||
accessRules = rootInterface.db.getRecordset(
|
||||
AccessRule,
|
||||
recordFilter={"roleId": roleId}
|
||||
)
|
||||
# Get all rules for this role (returns Pydantic models)
|
||||
accessRules = rootInterface.getAccessRules(roleId=roleId)
|
||||
|
||||
logger.debug(f"_getInstancePermissions: roleId={roleId}, accessRules={len(accessRules) if accessRules else 0}")
|
||||
|
||||
for rule in accessRules:
|
||||
context = rule.get("context", "")
|
||||
item = rule.get("item", "")
|
||||
context = rule.context
|
||||
item = rule.item or ""
|
||||
|
||||
# Handle DATA context (tables/fields)
|
||||
if context == "DATA" or context == AccessRuleContext.DATA:
|
||||
if context == AccessRuleContext.DATA or context == "DATA":
|
||||
if item:
|
||||
# Check if it's a field (table.field) or table
|
||||
if "." in item:
|
||||
tableName, fieldName = item.split(".", 1)
|
||||
if fieldName not in permissions["fields"]:
|
||||
permissions["fields"][fieldName] = {"view": False}
|
||||
permissions["fields"][fieldName]["view"] = permissions["fields"][fieldName]["view"] or rule.get("view", False)
|
||||
permissions["fields"][fieldName]["view"] = permissions["fields"][fieldName]["view"] or rule.view
|
||||
else:
|
||||
tableName = item
|
||||
if tableName not in permissions["tables"]:
|
||||
|
|
@ -322,20 +297,18 @@ def _getInstancePermissions(rootInterface, userId: str, instanceId: str) -> Dict
|
|||
|
||||
# Merge permissions (highest wins)
|
||||
current = permissions["tables"][tableName]
|
||||
current["view"] = current["view"] or rule.get("view", False)
|
||||
current["read"] = _mergeAccessLevel(current["read"], rule.get("read") or "n")
|
||||
current["create"] = _mergeAccessLevel(current["create"], rule.get("create") or "n")
|
||||
current["update"] = _mergeAccessLevel(current["update"], rule.get("update") or "n")
|
||||
current["delete"] = _mergeAccessLevel(current["delete"], rule.get("delete") or "n")
|
||||
current["view"] = current["view"] or rule.view
|
||||
current["read"] = _mergeAccessLevel(current["read"], rule.read or "n")
|
||||
current["create"] = _mergeAccessLevel(current["create"], rule.create or "n")
|
||||
current["update"] = _mergeAccessLevel(current["update"], rule.update or "n")
|
||||
current["delete"] = _mergeAccessLevel(current["delete"], rule.delete or "n")
|
||||
|
||||
# Handle UI context (views)
|
||||
# Views are stored with full objectKey (e.g., ui.feature.trustee.dashboard)
|
||||
elif context == "UI" or context == AccessRuleContext.UI:
|
||||
ruleView = rule.get("view", False)
|
||||
elif context == AccessRuleContext.UI or context == "UI":
|
||||
if item:
|
||||
# Store with full objectKey as per Navigation-API-Konzept
|
||||
permissions["views"][item] = permissions["views"].get(item, False) or ruleView
|
||||
elif ruleView:
|
||||
permissions["views"][item] = permissions["views"].get(item, False) or rule.view
|
||||
elif rule.view:
|
||||
# item=None means all views - set a wildcard flag
|
||||
permissions["views"]["_all"] = True
|
||||
|
||||
|
|
@ -343,7 +316,7 @@ def _getInstancePermissions(rootInterface, userId: str, instanceId: str) -> Dict
|
|||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error getting instance permissions: {e}")
|
||||
return permissions
|
||||
return permissions # Fail-safe: no permissions on error
|
||||
|
||||
|
||||
def _mergeAccessLevel(current: str, new: str) -> str:
|
||||
|
|
@ -359,7 +332,7 @@ def _mergeAccessLevel(current: str, new: str) -> str:
|
|||
|
||||
@router.post("/", response_model=Dict[str, Any])
|
||||
@limiter.limit("10/minute")
|
||||
async def create_feature(
|
||||
def create_feature(
|
||||
request: Request,
|
||||
code: str = Query(..., description="Unique feature code"),
|
||||
label: Dict[str, str] = None,
|
||||
|
|
@ -414,7 +387,7 @@ async def create_feature(
|
|||
|
||||
@router.get("/instances", response_model=List[Dict[str, Any]])
|
||||
@limiter.limit("60/minute")
|
||||
async def list_feature_instances(
|
||||
def list_feature_instances(
|
||||
request: Request,
|
||||
featureCode: Optional[str] = Query(None, description="Filter by feature code"),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
|
|
@ -456,7 +429,7 @@ async def list_feature_instances(
|
|||
|
||||
@router.get("/instances/{instanceId}", response_model=Dict[str, Any])
|
||||
@limiter.limit("60/minute")
|
||||
async def get_feature_instance(
|
||||
def get_feature_instance(
|
||||
request: Request,
|
||||
instanceId: str,
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
|
|
@ -500,7 +473,7 @@ async def get_feature_instance(
|
|||
|
||||
@router.post("/instances", response_model=Dict[str, Any])
|
||||
@limiter.limit("10/minute")
|
||||
async def create_feature_instance(
|
||||
def create_feature_instance(
|
||||
request: Request,
|
||||
data: FeatureInstanceCreate,
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
|
|
@ -567,7 +540,7 @@ async def create_feature_instance(
|
|||
|
||||
@router.delete("/instances/{instanceId}", response_model=Dict[str, str])
|
||||
@limiter.limit("10/minute")
|
||||
async def delete_feature_instance(
|
||||
def delete_feature_instance(
|
||||
request: Request,
|
||||
instanceId: str,
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
|
|
@ -632,7 +605,7 @@ class FeatureInstanceUpdate(BaseModel):
|
|||
|
||||
@router.put("/instances/{instanceId}", response_model=Dict[str, Any])
|
||||
@limiter.limit("30/minute")
|
||||
async def updateFeatureInstance(
|
||||
def updateFeatureInstance(
|
||||
request: Request,
|
||||
instanceId: str,
|
||||
data: FeatureInstanceUpdate,
|
||||
|
|
@ -709,7 +682,7 @@ async def updateFeatureInstance(
|
|||
|
||||
@router.post("/instances/{instanceId}/sync-roles", response_model=SyncRolesResult)
|
||||
@limiter.limit("10/minute")
|
||||
async def sync_instance_roles(
|
||||
def sync_instance_roles(
|
||||
request: Request,
|
||||
instanceId: str,
|
||||
addOnly: bool = Query(True, description="Only add missing roles, don't remove extras"),
|
||||
|
|
@ -776,7 +749,7 @@ async def sync_instance_roles(
|
|||
|
||||
@router.get("/templates/roles", response_model=List[Dict[str, Any]])
|
||||
@limiter.limit("60/minute")
|
||||
async def list_template_roles(
|
||||
def list_template_roles(
|
||||
request: Request,
|
||||
featureCode: Optional[str] = Query(None, description="Filter by feature code"),
|
||||
sysAdmin: User = Depends(requireSysAdmin)
|
||||
|
|
@ -806,7 +779,7 @@ async def list_template_roles(
|
|||
|
||||
@router.post("/templates/roles", response_model=Dict[str, Any])
|
||||
@limiter.limit("10/minute")
|
||||
async def create_template_role(
|
||||
def create_template_role(
|
||||
request: Request,
|
||||
roleLabel: str = Query(..., description="Role label (e.g., 'admin', 'viewer')"),
|
||||
featureCode: str = Query(..., description="Feature code this role belongs to"),
|
||||
|
|
@ -891,7 +864,7 @@ class FeatureInstanceUserUpdate(BaseModel):
|
|||
|
||||
@router.get("/instances/{instanceId}/users", response_model=List[FeatureInstanceUserResponse])
|
||||
@limiter.limit("60/minute")
|
||||
async def list_feature_instance_users(
|
||||
def list_feature_instance_users(
|
||||
request: Request,
|
||||
instanceId: str,
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
|
|
@ -924,49 +897,35 @@ async def list_feature_instance_users(
|
|||
detail="Access denied to this feature instance"
|
||||
)
|
||||
|
||||
# Get all FeatureAccess records for this instance
|
||||
from modules.datamodels.datamodelMembership import FeatureAccess, FeatureAccessRole
|
||||
from modules.datamodels.datamodelRbac import Role
|
||||
|
||||
featureAccesses = rootInterface.db.getRecordset(
|
||||
FeatureAccess,
|
||||
recordFilter={"featureInstanceId": instanceId}
|
||||
)
|
||||
# Get all FeatureAccess records for this instance (Pydantic models)
|
||||
featureAccesses = rootInterface.getFeatureAccessesByInstance(instanceId)
|
||||
|
||||
result = []
|
||||
for fa in featureAccesses:
|
||||
userId = fa.get("userId")
|
||||
featureAccessId = fa.get("id")
|
||||
|
||||
# Get user info
|
||||
users = rootInterface.db.getRecordset(UserInDB, recordFilter={"id": userId})
|
||||
if not users:
|
||||
# Get user info (Pydantic model)
|
||||
user = rootInterface.getUser(str(fa.userId))
|
||||
if not user:
|
||||
continue
|
||||
user = users[0]
|
||||
|
||||
# Get role IDs via FeatureAccessRole junction table
|
||||
featureAccessRoles = rootInterface.db.getRecordset(
|
||||
FeatureAccessRole,
|
||||
recordFilter={"featureAccessId": featureAccessId}
|
||||
)
|
||||
roleIds = [far.get("roleId") for far in featureAccessRoles]
|
||||
# Get role IDs via interface method
|
||||
roleIds = rootInterface.getRoleIdsForFeatureAccess(str(fa.id))
|
||||
|
||||
# Get role labels
|
||||
roleLabels = []
|
||||
for roleId in roleIds:
|
||||
roles = rootInterface.db.getRecordset(Role, recordFilter={"id": roleId})
|
||||
if roles:
|
||||
roleLabels.append(roles[0].get("roleLabel", ""))
|
||||
role = rootInterface.getRole(roleId)
|
||||
if role:
|
||||
roleLabels.append(role.roleLabel)
|
||||
|
||||
result.append(FeatureInstanceUserResponse(
|
||||
id=featureAccessId, # FeatureAccess ID as primary key
|
||||
userId=userId,
|
||||
username=user.get("username", ""),
|
||||
email=user.get("email"),
|
||||
fullName=user.get("fullName"),
|
||||
id=str(fa.id), # FeatureAccess ID as primary key
|
||||
userId=str(fa.userId),
|
||||
username=user.username,
|
||||
email=user.email,
|
||||
fullName=user.fullName,
|
||||
roleIds=roleIds,
|
||||
roleLabels=roleLabels,
|
||||
enabled=fa.get("enabled", True)
|
||||
enabled=fa.enabled
|
||||
))
|
||||
|
||||
return result
|
||||
|
|
@ -983,7 +942,7 @@ async def list_feature_instance_users(
|
|||
|
||||
@router.post("/instances/{instanceId}/users", response_model=Dict[str, Any])
|
||||
@limiter.limit("30/minute")
|
||||
async def add_user_to_feature_instance(
|
||||
def add_user_to_feature_instance(
|
||||
request: Request,
|
||||
instanceId: str,
|
||||
data: FeatureInstanceUserCreate,
|
||||
|
|
@ -1026,8 +985,8 @@ async def add_user_to_feature_instance(
|
|||
)
|
||||
|
||||
# Verify user exists
|
||||
users = rootInterface.db.getRecordset(UserInDB, recordFilter={"id": data.userId})
|
||||
if not users:
|
||||
user = rootInterface.getUser(data.userId)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"User '{data.userId}' not found"
|
||||
|
|
@ -1035,10 +994,7 @@ async def add_user_to_feature_instance(
|
|||
|
||||
# Check if user already has access
|
||||
from modules.datamodels.datamodelMembership import FeatureAccess, FeatureAccessRole
|
||||
existingAccess = rootInterface.db.getRecordset(
|
||||
FeatureAccess,
|
||||
recordFilter={"userId": data.userId, "featureInstanceId": instanceId}
|
||||
)
|
||||
existingAccess = rootInterface.getFeatureAccess(data.userId, instanceId)
|
||||
if existingAccess:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
|
|
@ -1087,7 +1043,7 @@ async def add_user_to_feature_instance(
|
|||
|
||||
@router.delete("/instances/{instanceId}/users/{userId}", response_model=Dict[str, str])
|
||||
@limiter.limit("30/minute")
|
||||
async def remove_user_from_feature_instance(
|
||||
def remove_user_from_feature_instance(
|
||||
request: Request,
|
||||
instanceId: str,
|
||||
userId: str,
|
||||
|
|
@ -1131,17 +1087,14 @@ async def remove_user_from_feature_instance(
|
|||
|
||||
# Find FeatureAccess record
|
||||
from modules.datamodels.datamodelMembership import FeatureAccess
|
||||
existingAccess = rootInterface.db.getRecordset(
|
||||
FeatureAccess,
|
||||
recordFilter={"userId": userId, "featureInstanceId": instanceId}
|
||||
)
|
||||
existingAccess = rootInterface.getFeatureAccess(userId, instanceId)
|
||||
if not existingAccess:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User does not have access to this feature instance"
|
||||
)
|
||||
|
||||
featureAccessId = existingAccess[0].get("id")
|
||||
featureAccessId = str(existingAccess.id)
|
||||
|
||||
# Delete FeatureAccess (CASCADE will delete FeatureAccessRole records)
|
||||
rootInterface.db.recordDelete(FeatureAccess, featureAccessId)
|
||||
|
|
@ -1168,7 +1121,7 @@ async def remove_user_from_feature_instance(
|
|||
|
||||
@router.put("/instances/{instanceId}/users/{userId}/roles", response_model=Dict[str, Any])
|
||||
@limiter.limit("30/minute")
|
||||
async def update_feature_instance_user_roles(
|
||||
def update_feature_instance_user_roles(
|
||||
request: Request,
|
||||
instanceId: str,
|
||||
userId: str,
|
||||
|
|
@ -1215,29 +1168,21 @@ async def update_feature_instance_user_roles(
|
|||
|
||||
# Find FeatureAccess record
|
||||
from modules.datamodels.datamodelMembership import FeatureAccess, FeatureAccessRole
|
||||
existingAccess = rootInterface.db.getRecordset(
|
||||
FeatureAccess,
|
||||
recordFilter={"userId": userId, "featureInstanceId": instanceId}
|
||||
)
|
||||
existingAccess = rootInterface.getFeatureAccess(userId, instanceId)
|
||||
if not existingAccess:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User does not have access to this feature instance"
|
||||
)
|
||||
|
||||
featureAccessId = existingAccess[0].get("id")
|
||||
featureAccessId = str(existingAccess.id)
|
||||
|
||||
# Update enabled flag if provided
|
||||
if data.enabled is not None:
|
||||
rootInterface.db.recordModify(FeatureAccess, featureAccessId, {"enabled": data.enabled})
|
||||
|
||||
# Delete existing FeatureAccessRole records
|
||||
existingRoles = rootInterface.db.getRecordset(
|
||||
FeatureAccessRole,
|
||||
recordFilter={"featureAccessId": featureAccessId}
|
||||
)
|
||||
for role in existingRoles:
|
||||
rootInterface.db.recordDelete(FeatureAccessRole, role.get("id"))
|
||||
# Delete existing FeatureAccessRole records via interface method
|
||||
rootInterface.deleteFeatureAccessRoles(featureAccessId)
|
||||
|
||||
# Create new FeatureAccessRole records
|
||||
for roleId in data.roleIds:
|
||||
|
|
@ -1271,7 +1216,7 @@ async def update_feature_instance_user_roles(
|
|||
|
||||
@router.get("/instances/{instanceId}/available-roles", response_model=List[Dict[str, Any]])
|
||||
@limiter.limit("60/minute")
|
||||
async def get_feature_instance_available_roles(
|
||||
def get_feature_instance_available_roles(
|
||||
request: Request,
|
||||
instanceId: str,
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
|
|
@ -1304,21 +1249,17 @@ async def get_feature_instance_available_roles(
|
|||
detail="Access denied to this feature instance"
|
||||
)
|
||||
|
||||
# Get roles for this instance
|
||||
from modules.datamodels.datamodelRbac import Role
|
||||
instanceRoles = rootInterface.db.getRecordset(
|
||||
Role,
|
||||
recordFilter={"featureInstanceId": instanceId}
|
||||
)
|
||||
# Get roles for this instance using interface method
|
||||
instanceRoles = rootInterface.getRolesByFeatureInstance(instanceId)
|
||||
|
||||
result = []
|
||||
for role in instanceRoles:
|
||||
result.append({
|
||||
"id": role.get("id"),
|
||||
"roleLabel": role.get("roleLabel"),
|
||||
"description": role.get("description", {}),
|
||||
"featureCode": role.get("featureCode"),
|
||||
"isSystemRole": role.get("isSystemRole", False)
|
||||
"id": role.id,
|
||||
"roleLabel": role.roleLabel,
|
||||
"description": role.description or {},
|
||||
"featureCode": role.featureCode,
|
||||
"isSystemRole": role.isSystemRole
|
||||
})
|
||||
|
||||
return result
|
||||
|
|
@ -1339,7 +1280,7 @@ async def get_feature_instance_available_roles(
|
|||
|
||||
@router.get("/{featureCode}", response_model=Dict[str, Any])
|
||||
@limiter.limit("60/minute")
|
||||
async def get_feature(
|
||||
def get_feature(
|
||||
request: Request,
|
||||
featureCode: str,
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
|
|
@ -1394,15 +1335,13 @@ def _hasMandateAdminRole(context: RequestContext) -> bool:
|
|||
# Check if any of the user's roles is an admin role
|
||||
try:
|
||||
rootInterface = getRootInterface()
|
||||
from modules.datamodels.datamodelRbac import Role
|
||||
|
||||
for roleId in context.roleIds:
|
||||
roleRecords = rootInterface.db.getRecordset(Role, recordFilter={"id": roleId})
|
||||
if roleRecords:
|
||||
role = roleRecords[0]
|
||||
roleLabel = role.get("roleLabel", "")
|
||||
role = rootInterface.getRole(roleId)
|
||||
if role:
|
||||
roleLabel = role.roleLabel
|
||||
# Admin role at mandate level (not feature-instance level)
|
||||
if roleLabel == "admin" and role.get("mandateId") and not role.get("featureInstanceId"):
|
||||
if roleLabel == "admin" and role.mandateId and not role.featureInstanceId:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
|
|
|||
|
|
@ -72,7 +72,7 @@ class RbacImportResult(BaseModel):
|
|||
|
||||
@router.get("/export/global", response_model=RbacExportData)
|
||||
@limiter.limit("10/minute")
|
||||
async def export_global_rbac(
|
||||
def export_global_rbac(
|
||||
request: Request,
|
||||
sysAdmin: User = Depends(requireSysAdmin)
|
||||
) -> RbacExportData:
|
||||
|
|
@ -85,34 +85,31 @@ async def export_global_rbac(
|
|||
try:
|
||||
rootInterface = getRootInterface()
|
||||
|
||||
# Get all global template roles (mandateId is NULL)
|
||||
allRoles = rootInterface.db.getRecordset(Role)
|
||||
globalRoles = [r for r in allRoles if r.get("mandateId") is None]
|
||||
# Get all global template roles (mandateId is NULL) using interface method
|
||||
allRoles = rootInterface.getAllRoles()
|
||||
globalRoles = [r for r in allRoles if r.mandateId is None]
|
||||
|
||||
exportRoles = []
|
||||
for role in globalRoles:
|
||||
roleId = role.get("id")
|
||||
roleId = role.id
|
||||
|
||||
# Get access rules for this role
|
||||
accessRules = rootInterface.db.getRecordset(
|
||||
AccessRule,
|
||||
recordFilter={"roleId": roleId}
|
||||
)
|
||||
# Get access rules for this role using interface method
|
||||
accessRules = rootInterface.getAccessRulesByRole(roleId)
|
||||
|
||||
exportRoles.append(RoleExport(
|
||||
roleLabel=role.get("roleLabel"),
|
||||
description=role.get("description", {}),
|
||||
featureCode=role.get("featureCode"),
|
||||
isSystemRole=role.get("isSystemRole", False),
|
||||
roleLabel=role.roleLabel,
|
||||
description=role.description or {},
|
||||
featureCode=role.featureCode,
|
||||
isSystemRole=role.isSystemRole,
|
||||
accessRules=[
|
||||
{
|
||||
"context": r.get("context"),
|
||||
"item": r.get("item"),
|
||||
"view": r.get("view", False),
|
||||
"read": r.get("read"),
|
||||
"create": r.get("create"),
|
||||
"update": r.get("update"),
|
||||
"delete": r.get("delete")
|
||||
"context": r.context,
|
||||
"item": r.item,
|
||||
"view": r.view if r.view is not None else False,
|
||||
"read": r.read,
|
||||
"create": r.create,
|
||||
"update": r.update,
|
||||
"delete": r.delete
|
||||
}
|
||||
for r in accessRules
|
||||
]
|
||||
|
|
@ -191,21 +188,20 @@ async def import_global_rbac(
|
|||
result.rolesSkipped += 1
|
||||
continue
|
||||
|
||||
# Check if role exists (global role with same label and featureCode)
|
||||
existingRoles = rootInterface.db.getRecordset(
|
||||
Role,
|
||||
recordFilter={
|
||||
"roleLabel": roleLabel,
|
||||
"mandateId": None,
|
||||
"featureCode": featureCode
|
||||
}
|
||||
)
|
||||
# Check if role exists (global role with same label and featureCode) using interface method
|
||||
allRoles = rootInterface.getAllRoles()
|
||||
existingRoles = [
|
||||
r for r in allRoles
|
||||
if r.roleLabel == roleLabel
|
||||
and r.mandateId is None
|
||||
and r.featureCode == featureCode
|
||||
]
|
||||
|
||||
if existingRoles:
|
||||
if updateExisting:
|
||||
# Update existing role
|
||||
existingRole = existingRoles[0]
|
||||
roleId = existingRole.get("id")
|
||||
roleId = existingRole.id
|
||||
|
||||
rootInterface.db.recordModify(
|
||||
Role,
|
||||
|
|
@ -285,7 +281,7 @@ async def import_global_rbac(
|
|||
|
||||
@router.get("/export/mandate", response_model=RbacExportData)
|
||||
@limiter.limit("10/minute")
|
||||
async def export_mandate_rbac(
|
||||
def export_mandate_rbac(
|
||||
request: Request,
|
||||
includeFeatureInstances: bool = True,
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
|
|
@ -315,41 +311,38 @@ async def export_mandate_rbac(
|
|||
try:
|
||||
rootInterface = getRootInterface()
|
||||
|
||||
# Get mandate-level roles
|
||||
allRoles = rootInterface.db.getRecordset(Role)
|
||||
# Get mandate-level roles using interface method
|
||||
allRoles = rootInterface.getAllRoles()
|
||||
mandateRoles = [
|
||||
r for r in allRoles
|
||||
if str(r.get("mandateId")) == str(context.mandateId)
|
||||
if str(r.mandateId) == str(context.mandateId)
|
||||
]
|
||||
|
||||
# Filter by feature instance if not including them
|
||||
if not includeFeatureInstances:
|
||||
mandateRoles = [r for r in mandateRoles if not r.get("featureInstanceId")]
|
||||
mandateRoles = [r for r in mandateRoles if not r.featureInstanceId]
|
||||
|
||||
exportRoles = []
|
||||
for role in mandateRoles:
|
||||
roleId = role.get("id")
|
||||
roleId = role.id
|
||||
|
||||
# Get access rules for this role
|
||||
accessRules = rootInterface.db.getRecordset(
|
||||
AccessRule,
|
||||
recordFilter={"roleId": roleId}
|
||||
)
|
||||
# Get access rules for this role using interface method
|
||||
accessRules = rootInterface.getAccessRulesByRole(roleId)
|
||||
|
||||
exportRoles.append(RoleExport(
|
||||
roleLabel=role.get("roleLabel"),
|
||||
description=role.get("description", {}),
|
||||
featureCode=role.get("featureCode"),
|
||||
isSystemRole=role.get("isSystemRole", False),
|
||||
roleLabel=role.roleLabel,
|
||||
description=role.description or {},
|
||||
featureCode=role.featureCode,
|
||||
isSystemRole=role.isSystemRole,
|
||||
accessRules=[
|
||||
{
|
||||
"context": r.get("context"),
|
||||
"item": r.get("item"),
|
||||
"view": r.get("view", False),
|
||||
"read": r.get("read"),
|
||||
"create": r.get("create"),
|
||||
"update": r.get("update"),
|
||||
"delete": r.get("delete")
|
||||
"context": r.context,
|
||||
"item": r.item,
|
||||
"view": r.view if r.view is not None else False,
|
||||
"read": r.read,
|
||||
"create": r.create,
|
||||
"update": r.update,
|
||||
"delete": r.delete
|
||||
}
|
||||
for r in accessRules
|
||||
]
|
||||
|
|
@ -453,21 +446,20 @@ async def import_mandate_rbac(
|
|||
result.rolesSkipped += 1
|
||||
continue
|
||||
|
||||
# Check if role exists (mandate role with same label)
|
||||
existingRoles = rootInterface.db.getRecordset(
|
||||
Role,
|
||||
recordFilter={
|
||||
"roleLabel": roleLabel,
|
||||
"mandateId": str(context.mandateId),
|
||||
"featureInstanceId": None # Only mandate-level roles
|
||||
}
|
||||
)
|
||||
# Check if role exists (mandate role with same label) using interface method
|
||||
allRoles = rootInterface.getAllRoles()
|
||||
existingRoles = [
|
||||
r for r in allRoles
|
||||
if r.roleLabel == roleLabel
|
||||
and str(r.mandateId) == str(context.mandateId)
|
||||
and r.featureInstanceId is None # Only mandate-level roles
|
||||
]
|
||||
|
||||
if existingRoles:
|
||||
if updateExisting:
|
||||
# Update existing role
|
||||
existingRole = existingRoles[0]
|
||||
roleId = existingRole.get("id")
|
||||
roleId = existingRole.id
|
||||
|
||||
rootInterface.db.recordModify(
|
||||
Role,
|
||||
|
|
@ -556,12 +548,11 @@ def _hasMandateAdminRole(context: RequestContext) -> bool:
|
|||
rootInterface = getRootInterface()
|
||||
|
||||
for roleId in context.roleIds:
|
||||
roleRecords = rootInterface.db.getRecordset(Role, recordFilter={"id": roleId})
|
||||
if roleRecords:
|
||||
role = roleRecords[0]
|
||||
roleLabel = role.get("roleLabel", "")
|
||||
role = rootInterface.getRole(roleId)
|
||||
if role:
|
||||
roleLabel = role.roleLabel
|
||||
# Admin role at mandate level
|
||||
if roleLabel == "admin" and role.get("mandateId") and not role.get("featureInstanceId"):
|
||||
if roleLabel == "admin" and role.mandateId and not role.featureInstanceId:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
|
@ -580,10 +571,10 @@ def _updateAccessRules(interface, roleId: str, newRules: List[Dict[str, Any]]) -
|
|||
Number of rules created/updated
|
||||
"""
|
||||
try:
|
||||
# Delete existing rules for this role
|
||||
existingRules = interface.db.getRecordset(AccessRule, recordFilter={"roleId": roleId})
|
||||
# Delete existing rules for this role using interface method
|
||||
existingRules = interface.getAccessRulesByRole(roleId)
|
||||
for rule in existingRules:
|
||||
interface.db.recordDelete(AccessRule, rule.get("id"))
|
||||
interface.db.recordDelete(AccessRule, rule.id)
|
||||
|
||||
# Create new rules
|
||||
count = 0
|
||||
|
|
|
|||
|
|
@ -36,25 +36,17 @@ def _getUserRoleLabels(interface, userId: str) -> List[str]:
|
|||
"""
|
||||
roleLabels: Set[str] = set()
|
||||
|
||||
# Get all UserMandate records for this user
|
||||
userMandates = interface.db.getRecordset(UserMandate, recordFilter={"userId": userId})
|
||||
# Get all UserMandate records for this user (Pydantic models)
|
||||
userMandates = interface.getUserMandates(userId)
|
||||
|
||||
for um in userMandates:
|
||||
userMandateId = um.get("id")
|
||||
if not userMandateId:
|
||||
continue
|
||||
|
||||
# Get all UserMandateRole records for this membership
|
||||
userMandateRoles = interface.db.getRecordset(
|
||||
UserMandateRole,
|
||||
recordFilter={"userMandateId": str(userMandateId)}
|
||||
)
|
||||
# Get all UserMandateRole records for this membership (Pydantic models)
|
||||
userMandateRoles = interface.getUserMandateRoles(str(um.id))
|
||||
|
||||
for umr in userMandateRoles:
|
||||
roleId = umr.get("roleId")
|
||||
if roleId:
|
||||
if umr.roleId:
|
||||
# Get role by ID to get roleLabel
|
||||
role = interface.getRole(str(roleId))
|
||||
role = interface.getRole(str(umr.roleId))
|
||||
if role:
|
||||
roleLabels.add(role.roleLabel)
|
||||
|
||||
|
|
@ -76,7 +68,7 @@ router = APIRouter(
|
|||
|
||||
@router.get("/", response_model=List[Dict[str, Any]])
|
||||
@limiter.limit("60/minute")
|
||||
async def list_roles(
|
||||
def list_roles(
|
||||
request: Request,
|
||||
currentUser: User = Depends(requireSysAdmin)
|
||||
) -> List[Dict[str, Any]]:
|
||||
|
|
@ -121,7 +113,7 @@ async def list_roles(
|
|||
|
||||
@router.get("/options", response_model=List[Dict[str, Any]])
|
||||
@limiter.limit("60/minute")
|
||||
async def get_role_options(
|
||||
def get_role_options(
|
||||
request: Request,
|
||||
currentUser: User = Depends(requireSysAdmin)
|
||||
) -> List[Dict[str, Any]]:
|
||||
|
|
@ -162,7 +154,7 @@ async def get_role_options(
|
|||
|
||||
@router.post("/", response_model=Dict[str, Any])
|
||||
@limiter.limit("30/minute")
|
||||
async def create_role(
|
||||
def create_role(
|
||||
request: Request,
|
||||
role: Role = Body(...),
|
||||
currentUser: User = Depends(requireSysAdmin)
|
||||
|
|
@ -206,7 +198,7 @@ async def create_role(
|
|||
|
||||
@router.get("/{roleId}", response_model=Dict[str, Any])
|
||||
@limiter.limit("60/minute")
|
||||
async def get_role(
|
||||
def get_role(
|
||||
request: Request,
|
||||
roleId: str = Path(..., description="Role ID"),
|
||||
currentUser: User = Depends(requireSysAdmin)
|
||||
|
|
@ -250,7 +242,7 @@ async def get_role(
|
|||
|
||||
@router.put("/{roleId}", response_model=Dict[str, Any])
|
||||
@limiter.limit("30/minute")
|
||||
async def update_role(
|
||||
def update_role(
|
||||
request: Request,
|
||||
roleId: str = Path(..., description="Role ID"),
|
||||
role: Role = Body(...),
|
||||
|
|
@ -298,7 +290,7 @@ async def update_role(
|
|||
|
||||
@router.delete("/{roleId}", response_model=Dict[str, str])
|
||||
@limiter.limit("30/minute")
|
||||
async def delete_role(
|
||||
def delete_role(
|
||||
request: Request,
|
||||
roleId: str = Path(..., description="Role ID"),
|
||||
currentUser: User = Depends(requireSysAdmin)
|
||||
|
|
@ -342,7 +334,7 @@ async def delete_role(
|
|||
|
||||
@router.get("/users", response_model=List[Dict[str, Any]])
|
||||
@limiter.limit("60/minute")
|
||||
async def list_users_with_roles(
|
||||
def list_users_with_roles(
|
||||
request: Request,
|
||||
roleLabel: Optional[str] = Query(None, description="Filter by role label"),
|
||||
mandateId: Optional[str] = Query(None, description="Filter by mandate ID (via UserMandate)"),
|
||||
|
|
@ -362,21 +354,13 @@ async def list_users_with_roles(
|
|||
try:
|
||||
interface = getRootInterface()
|
||||
|
||||
# Get all users (SysAdmin sees all)
|
||||
# Use db.getRecordset with UserInDB (the actual database model)
|
||||
allUsersData = interface.db.getRecordset(UserInDB)
|
||||
# Convert to User objects, filtering out sensitive fields
|
||||
users = []
|
||||
for u in allUsersData:
|
||||
cleanedUser = {k: v for k, v in u.items() if not k.startswith("_") and k != "hashedPassword" and k != "resetToken" and k != "resetTokenExpires"}
|
||||
if cleanedUser.get("roleLabels") is None:
|
||||
cleanedUser["roleLabels"] = []
|
||||
users.append(User(**cleanedUser))
|
||||
# Get all users via interface method (Pydantic models)
|
||||
users = interface.getAllUsers()
|
||||
|
||||
# Filter by mandate if specified (via UserMandate table)
|
||||
if mandateId:
|
||||
userMandates = interface.db.getRecordset(UserMandate, recordFilter={"mandateId": mandateId})
|
||||
mandateUserIds = {str(um["userId"]) for um in userMandates}
|
||||
userMandates = interface.getUserMandatesByMandate(mandateId)
|
||||
mandateUserIds = {str(um.userId) for um in userMandates}
|
||||
users = [u for u in users if str(u.id) in mandateUserIds]
|
||||
|
||||
# Filter by role if specified (via UserMandateRole)
|
||||
|
|
@ -412,7 +396,7 @@ async def list_users_with_roles(
|
|||
|
||||
@router.get("/users/{userId}", response_model=Dict[str, Any])
|
||||
@limiter.limit("60/minute")
|
||||
async def get_user_roles(
|
||||
def get_user_roles(
|
||||
request: Request,
|
||||
userId: str = Path(..., description="User ID"),
|
||||
currentUser: User = Depends(requireSysAdmin)
|
||||
|
|
@ -462,7 +446,7 @@ async def get_user_roles(
|
|||
|
||||
@router.put("/users/{userId}/roles", response_model=Dict[str, Any])
|
||||
@limiter.limit("30/minute")
|
||||
async def update_user_roles(
|
||||
def update_user_roles(
|
||||
request: Request,
|
||||
userId: str = Path(..., description="User ID"),
|
||||
newRoleLabels: List[str] = Body(..., description="List of role labels to assign"),
|
||||
|
|
@ -499,21 +483,18 @@ async def update_user_roles(
|
|||
logger.warning(f"Non-standard role label assigned: {roleLabel}")
|
||||
|
||||
# Get user's first mandate (for role assignment)
|
||||
userMandates = interface.db.getRecordset(UserMandate, recordFilter={"userId": userId})
|
||||
userMandates = interface.getUserMandates(userId)
|
||||
if not userMandates:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"User {userId} has no mandate memberships. Add to mandate first."
|
||||
)
|
||||
|
||||
userMandateId = str(userMandates[0].get("id"))
|
||||
userMandateId = str(userMandates[0].id)
|
||||
|
||||
# Get current roles for this mandate
|
||||
existingRoles = interface.db.getRecordset(
|
||||
UserMandateRole,
|
||||
recordFilter={"userMandateId": userMandateId}
|
||||
)
|
||||
existingRoleIds = {str(r.get("roleId")) for r in existingRoles}
|
||||
# Get current roles for this mandate (Pydantic models)
|
||||
existingRoles = interface.getUserMandateRoles(userMandateId)
|
||||
existingRoleIds = {str(r.roleId) for r in existingRoles}
|
||||
|
||||
# Convert roleLabels to roleIds
|
||||
newRoleIds = set()
|
||||
|
|
@ -524,8 +505,8 @@ async def update_user_roles(
|
|||
|
||||
# Remove roles that are no longer needed
|
||||
for existingRole in existingRoles:
|
||||
if str(existingRole.get("roleId")) not in newRoleIds:
|
||||
interface.db.recordDelete(UserMandateRole, str(existingRole.get("id")))
|
||||
if str(existingRole.roleId) not in newRoleIds:
|
||||
interface.removeRoleFromUserMandate(userMandateId, str(existingRole.roleId))
|
||||
|
||||
# Add new roles
|
||||
for roleId in newRoleIds:
|
||||
|
|
@ -559,7 +540,7 @@ async def update_user_roles(
|
|||
|
||||
@router.post("/users/{userId}/roles/{roleLabel}", response_model=Dict[str, Any])
|
||||
@limiter.limit("30/minute")
|
||||
async def add_user_role(
|
||||
def add_user_role(
|
||||
request: Request,
|
||||
userId: str = Path(..., description="User ID"),
|
||||
roleLabel: str = Path(..., description="Role label to add"),
|
||||
|
|
@ -596,25 +577,22 @@ async def add_user_role(
|
|||
)
|
||||
|
||||
# Get user's first mandate
|
||||
userMandates = interface.db.getRecordset(UserMandate, recordFilter={"userId": userId})
|
||||
userMandates = interface.getUserMandates(userId)
|
||||
if not userMandates:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"User {userId} has no mandate memberships. Add to mandate first."
|
||||
)
|
||||
|
||||
userMandateId = str(userMandates[0].get("id"))
|
||||
userMandateId = str(userMandates[0].id)
|
||||
|
||||
# Check if role is already assigned
|
||||
existingAssignment = interface.db.getRecordset(
|
||||
UserMandateRole,
|
||||
recordFilter={"userMandateId": userMandateId, "roleId": str(role.id)}
|
||||
)
|
||||
# Check if role is already assigned - use interface method
|
||||
existingRoles = interface.getUserMandateRoles(userMandateId)
|
||||
roleAlreadyAssigned = any(str(r.roleId) == str(role.id) for r in existingRoles)
|
||||
|
||||
if not existingAssignment:
|
||||
# Add the role
|
||||
newRole = UserMandateRole(userMandateId=userMandateId, roleId=str(role.id))
|
||||
interface.db.recordCreate(UserMandateRole, newRole.model_dump())
|
||||
if not roleAlreadyAssigned:
|
||||
# Add the role via interface method
|
||||
interface.addRoleToUserMandate(userMandateId, str(role.id))
|
||||
logger.info(f"Added role {roleLabel} to user {userId} by SysAdmin {currentUser.id}")
|
||||
|
||||
userRoleLabels = _getUserRoleLabels(interface, userId)
|
||||
|
|
@ -641,7 +619,7 @@ async def add_user_role(
|
|||
|
||||
@router.delete("/users/{userId}/roles/{roleLabel}", response_model=Dict[str, Any])
|
||||
@limiter.limit("30/minute")
|
||||
async def remove_user_role(
|
||||
def remove_user_role(
|
||||
request: Request,
|
||||
userId: str = Path(..., description="User ID"),
|
||||
roleLabel: str = Path(..., description="Role label to remove"),
|
||||
|
|
@ -678,20 +656,14 @@ async def remove_user_role(
|
|||
)
|
||||
|
||||
# Remove role from all user's mandates
|
||||
userMandates = interface.db.getRecordset(UserMandate, recordFilter={"userId": userId})
|
||||
userMandates = interface.getUserMandates(userId)
|
||||
roleRemoved = False
|
||||
|
||||
for um in userMandates:
|
||||
userMandateId = str(um.get("id"))
|
||||
userMandateId = str(um.id)
|
||||
|
||||
# Find and delete the role assignment
|
||||
assignments = interface.db.getRecordset(
|
||||
UserMandateRole,
|
||||
recordFilter={"userMandateId": userMandateId, "roleId": str(role.id)}
|
||||
)
|
||||
|
||||
for assignment in assignments:
|
||||
interface.db.recordDelete(UserMandateRole, str(assignment.get("id")))
|
||||
# Remove role via interface method
|
||||
if interface.removeRoleFromUserMandate(userMandateId, str(role.id)):
|
||||
roleRemoved = True
|
||||
|
||||
if roleRemoved:
|
||||
|
|
@ -721,7 +693,7 @@ async def remove_user_role(
|
|||
|
||||
@router.get("/roles/{roleLabel}/users", response_model=List[Dict[str, Any]])
|
||||
@limiter.limit("60/minute")
|
||||
async def get_users_with_role(
|
||||
def get_users_with_role(
|
||||
request: Request,
|
||||
roleLabel: str = Path(..., description="Role label"),
|
||||
mandateId: Optional[str] = Query(None, description="Filter by mandate ID (via UserMandate)"),
|
||||
|
|
@ -751,25 +723,21 @@ async def get_users_with_role(
|
|||
detail=f"Role '{roleLabel}' not found"
|
||||
)
|
||||
|
||||
# Get all UserMandateRole assignments for this role
|
||||
roleAssignments = interface.db.getRecordset(
|
||||
UserMandateRole,
|
||||
recordFilter={"roleId": str(role.id)}
|
||||
)
|
||||
# Get all UserMandateRole assignments for this role (Pydantic models)
|
||||
roleAssignments = interface.getUserMandateRolesByRole(str(role.id))
|
||||
|
||||
# Get unique userMandateIds
|
||||
userMandateIds = {str(ra.get("userMandateId")) for ra in roleAssignments}
|
||||
userMandateIds = {str(ra.userMandateId) for ra in roleAssignments}
|
||||
|
||||
# Get userIds from UserMandate records
|
||||
userIds: Set[str] = set()
|
||||
for userMandateId in userMandateIds:
|
||||
umRecords = interface.db.getRecordset(UserMandate, recordFilter={"id": userMandateId})
|
||||
if umRecords:
|
||||
um = umRecords[0]
|
||||
um = interface.getUserMandateById(userMandateId)
|
||||
if um:
|
||||
# Filter by mandate if specified
|
||||
if mandateId and str(um.get("mandateId")) != mandateId:
|
||||
if mandateId and str(um.mandateId) != mandateId:
|
||||
continue
|
||||
userIds.add(str(um.get("userId")))
|
||||
userIds.add(str(um.userId))
|
||||
|
||||
# Get users and format response
|
||||
result = []
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ router = APIRouter(
|
|||
|
||||
@router.get("/permissions", response_model=UserPermissions)
|
||||
@limiter.limit("300/minute") # Raised from 60 - sidebar checks many pages individually
|
||||
async def get_permissions(
|
||||
def get_permissions(
|
||||
request: Request,
|
||||
context: str = Query(..., description="Context type: DATA, UI, or RESOURCE"),
|
||||
item: Optional[str] = Query(None, description="Item identifier (table name, UI path, or resource path)"),
|
||||
|
|
@ -78,11 +78,18 @@ async def get_permissions(
|
|||
)
|
||||
|
||||
# MULTI-TENANT: Get permissions using context (mandateId/featureInstanceId)
|
||||
# For DATA context, resolve short model names to full objectKeys
|
||||
# e.g., "ChatWorkflow" → "data.chat.ChatWorkflow"
|
||||
resolvedItem = item or ""
|
||||
if accessContext == AccessRuleContext.DATA and resolvedItem and "." not in resolvedItem:
|
||||
from modules.interfaces.interfaceRbac import buildDataObjectKey
|
||||
resolvedItem = buildDataObjectKey(resolvedItem)
|
||||
|
||||
# Pass mandateId and featureInstanceId to load Feature-Instance roles
|
||||
permissions = interface.rbac.getUserPermissions(
|
||||
reqContext.user,
|
||||
accessContext,
|
||||
item or "",
|
||||
resolvedItem,
|
||||
mandateId=reqContext.mandateId,
|
||||
featureInstanceId=reqContext.featureInstanceId
|
||||
)
|
||||
|
|
@ -101,7 +108,7 @@ async def get_permissions(
|
|||
|
||||
@router.get("/permissions/all", response_model=Dict[str, Any])
|
||||
@limiter.limit("120/minute") # Raised from 30 - optimized endpoint for bulk permission fetch
|
||||
async def get_all_permissions(
|
||||
def get_all_permissions(
|
||||
request: Request,
|
||||
context: Optional[str] = Query(None, description="Context type: UI or RESOURCE (if not provided, returns both)"),
|
||||
reqContext: RequestContext = Depends(getRequestContext)
|
||||
|
|
@ -179,17 +186,15 @@ async def get_all_permissions(
|
|||
|
||||
# For UI/RESOURCE: Load system roles the user has across ALL their mandates
|
||||
# This allows users to access system UI elements without needing a specific mandate header
|
||||
userMandates = rootInterface.db.getRecordset(
|
||||
UserMandate,
|
||||
recordFilter={"userId": str(reqContext.user.id), "enabled": True}
|
||||
)
|
||||
allUserMandates = rootInterface.getUserMandates(str(reqContext.user.id))
|
||||
userMandates = [um for um in allUserMandates if um.enabled]
|
||||
|
||||
logger.debug(f"UI/RESOURCE permissions: Found {len(userMandates)} UserMandates for user {reqContext.user.id}")
|
||||
|
||||
# Collect all role IDs the user has across all mandates
|
||||
for userMandate in userMandates:
|
||||
mandateRoleIds = rootInterface.getRoleIdsForUserMandate(userMandate.get("id"))
|
||||
logger.debug(f"UI/RESOURCE permissions: UserMandate {userMandate.get('id')} (mandate {userMandate.get('mandateId')}) has {len(mandateRoleIds)} roles: {mandateRoleIds}")
|
||||
mandateRoleIds = rootInterface.getRoleIdsForUserMandate(userMandate.id)
|
||||
logger.debug(f"UI/RESOURCE permissions: UserMandate {userMandate.id} (mandate {userMandate.mandateId}) has {len(mandateRoleIds)} roles: {mandateRoleIds}")
|
||||
for rid in mandateRoleIds:
|
||||
if rid not in roleIds:
|
||||
roleIds.append(rid)
|
||||
|
|
@ -210,14 +215,11 @@ async def get_all_permissions(
|
|||
allRules[ctx] = []
|
||||
# Get all rules for user's roles - bypass RBAC filtering
|
||||
for roleId in roleIds:
|
||||
ruleRecords = rootInterface.db.getRecordset(
|
||||
AccessRule,
|
||||
recordFilter={"roleId": str(roleId), "context": ctx.value}
|
||||
)
|
||||
for ruleRecord in ruleRecords:
|
||||
# Convert dict to AccessRule object
|
||||
cleanedRule = {k: v for k, v in ruleRecord.items() if not k.startswith("_")}
|
||||
allRules[ctx].append(AccessRule(**cleanedRule))
|
||||
# Use interface method and filter by context
|
||||
rules = rootInterface.getAccessRulesByRole(str(roleId))
|
||||
for rule in rules:
|
||||
if rule.context == ctx.value:
|
||||
allRules[ctx].append(rule)
|
||||
|
||||
# Build result: for each context, collect all unique items and calculate permissions
|
||||
for ctx in contextsToFetch:
|
||||
|
|
@ -298,7 +300,7 @@ async def get_all_permissions(
|
|||
|
||||
@router.get("/rules", response_model=PaginatedResponse)
|
||||
@limiter.limit("30/minute")
|
||||
async def get_access_rules(
|
||||
def get_access_rules(
|
||||
request: Request,
|
||||
roleLabel: Optional[str] = Query(None, description="Filter by role label"),
|
||||
context: Optional[str] = Query(None, description="Filter by context (DATA, UI, RESOURCE)"),
|
||||
|
|
@ -387,7 +389,7 @@ async def get_access_rules(
|
|||
|
||||
@router.get("/rules/by-role/{roleId}", response_model=PaginatedResponse)
|
||||
@limiter.limit("30/minute")
|
||||
async def get_access_rules_by_role(
|
||||
def get_access_rules_by_role(
|
||||
request: Request,
|
||||
roleId: str = Path(..., description="Role ID to get rules for"),
|
||||
currentUser: User = Depends(requireSysAdmin)
|
||||
|
|
@ -405,14 +407,8 @@ async def get_access_rules_by_role(
|
|||
try:
|
||||
interface = getRootInterface()
|
||||
|
||||
# Build filter for roleId
|
||||
recordFilter = {"roleId": roleId}
|
||||
|
||||
# Get rules from database
|
||||
rules = interface.db.getRecordset(AccessRule, recordFilter=recordFilter)
|
||||
|
||||
# Convert to AccessRule objects
|
||||
ruleObjects = [AccessRule(**rule) for rule in rules]
|
||||
# Get rules from database using interface method
|
||||
ruleObjects = interface.getAccessRulesByRole(roleId)
|
||||
|
||||
return PaginatedResponse(
|
||||
items=[rule.model_dump() for rule in ruleObjects],
|
||||
|
|
@ -431,7 +427,7 @@ async def get_access_rules_by_role(
|
|||
|
||||
@router.get("/rules/{ruleId}", response_model=dict)
|
||||
@limiter.limit("30/minute")
|
||||
async def get_access_rule(
|
||||
def get_access_rule(
|
||||
request: Request,
|
||||
ruleId: str = Path(..., description="Access rule ID"),
|
||||
currentUser: User = Depends(requireSysAdmin)
|
||||
|
|
@ -473,7 +469,7 @@ async def get_access_rule(
|
|||
|
||||
@router.post("/rules", response_model=dict)
|
||||
@limiter.limit("30/minute")
|
||||
async def create_access_rule(
|
||||
def create_access_rule(
|
||||
request: Request,
|
||||
accessRuleData: dict = Body(..., description="Access rule data"),
|
||||
currentUser: User = Depends(requireSysAdmin)
|
||||
|
|
@ -539,7 +535,7 @@ async def create_access_rule(
|
|||
|
||||
@router.put("/rules/{ruleId}", response_model=dict)
|
||||
@limiter.limit("30/minute")
|
||||
async def update_access_rule(
|
||||
def update_access_rule(
|
||||
request: Request,
|
||||
ruleId: str = Path(..., description="Access rule ID"),
|
||||
accessRuleData: dict = Body(..., description="Updated access rule data"),
|
||||
|
|
@ -622,7 +618,7 @@ async def update_access_rule(
|
|||
|
||||
@router.delete("/rules/{ruleId}")
|
||||
@limiter.limit("30/minute")
|
||||
async def delete_access_rule(
|
||||
def delete_access_rule(
|
||||
request: Request,
|
||||
ruleId: str = Path(..., description="Access rule ID"),
|
||||
currentUser: User = Depends(requireSysAdmin)
|
||||
|
|
@ -680,7 +676,7 @@ async def delete_access_rule(
|
|||
|
||||
@router.get("/roles", response_model=PaginatedResponse)
|
||||
@limiter.limit("60/minute")
|
||||
async def list_roles(
|
||||
def list_roles(
|
||||
request: Request,
|
||||
pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"),
|
||||
includeTemplates: bool = Query(False, description="Include feature template roles"),
|
||||
|
|
@ -849,7 +845,7 @@ async def list_roles(
|
|||
|
||||
@router.get("/roles/options", response_model=List[Dict[str, Any]])
|
||||
@limiter.limit("60/minute")
|
||||
async def get_role_options(
|
||||
def get_role_options(
|
||||
request: Request,
|
||||
currentUser: User = Depends(requireSysAdmin)
|
||||
) -> List[Dict[str, Any]]:
|
||||
|
|
@ -890,7 +886,7 @@ async def get_role_options(
|
|||
|
||||
@router.post("/roles", response_model=Dict[str, Any])
|
||||
@limiter.limit("30/minute")
|
||||
async def create_role(
|
||||
def create_role(
|
||||
request: Request,
|
||||
role: Role = Body(...),
|
||||
currentUser: User = Depends(requireSysAdmin)
|
||||
|
|
@ -939,7 +935,7 @@ async def create_role(
|
|||
|
||||
@router.get("/roles/{roleId}", response_model=Dict[str, Any])
|
||||
@limiter.limit("60/minute")
|
||||
async def get_role(
|
||||
def get_role(
|
||||
request: Request,
|
||||
roleId: str = Path(..., description="Role ID"),
|
||||
currentUser: User = Depends(requireSysAdmin)
|
||||
|
|
@ -986,7 +982,7 @@ async def get_role(
|
|||
|
||||
@router.put("/roles/{roleId}", response_model=Dict[str, Any])
|
||||
@limiter.limit("30/minute")
|
||||
async def update_role(
|
||||
def update_role(
|
||||
request: Request,
|
||||
roleId: str = Path(..., description="Role ID"),
|
||||
role: Role = Body(...),
|
||||
|
|
@ -1039,7 +1035,7 @@ async def update_role(
|
|||
|
||||
@router.delete("/roles/{roleId}", response_model=Dict[str, str])
|
||||
@limiter.limit("30/minute")
|
||||
async def delete_role(
|
||||
def delete_role(
|
||||
request: Request,
|
||||
roleId: str = Path(..., description="Role ID"),
|
||||
currentUser: User = Depends(requireSysAdmin)
|
||||
|
|
@ -1089,7 +1085,7 @@ async def delete_role(
|
|||
|
||||
@router.get("/catalog/objects", response_model=Dict[str, Any])
|
||||
@limiter.limit("60/minute")
|
||||
async def getCatalogObjects(
|
||||
def getCatalogObjects(
|
||||
request: Request,
|
||||
context: Optional[str] = Query(None, description="Filter by context (DATA, UI, RESOURCE)"),
|
||||
featureCode: Optional[str] = Query(None, description="Filter by feature code"),
|
||||
|
|
@ -1128,13 +1124,9 @@ async def getCatalogObjects(
|
|||
if mandateId:
|
||||
try:
|
||||
interface = getRootInterface()
|
||||
# Get all feature instances for this mandate
|
||||
from modules.datamodels.datamodelFeatures import FeatureInstance
|
||||
instances = interface.db.getRecordset(
|
||||
FeatureInstance,
|
||||
recordFilter={"mandateId": mandateId, "enabled": True}
|
||||
)
|
||||
activeFeatures = set(inst.get("featureCode") for inst in instances)
|
||||
# Get all feature instances for this mandate using interface method
|
||||
instances = interface.getFeatureInstancesByMandate(mandateId, enabledOnly=True)
|
||||
activeFeatures = set(inst.featureCode for inst in instances)
|
||||
# Always include "system" feature
|
||||
activeFeatures.add("system")
|
||||
except Exception as e:
|
||||
|
|
@ -1185,7 +1177,7 @@ async def getCatalogObjects(
|
|||
|
||||
@router.get("/catalog/stats", response_model=Dict[str, Any])
|
||||
@limiter.limit("60/minute")
|
||||
async def getCatalogStats(
|
||||
def getCatalogStats(
|
||||
request: Request,
|
||||
currentUser: User = Depends(requireSysAdmin)
|
||||
) -> Dict[str, Any]:
|
||||
|
|
@ -1207,3 +1199,101 @@ async def getCatalogStats(
|
|||
status_code=500,
|
||||
detail=f"Failed to get catalog stats: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# CLEANUP: Remove duplicate AccessRules
|
||||
# =============================================================================
|
||||
|
||||
@router.post("/cleanup/duplicate-rules", response_model=dict)
|
||||
@limiter.limit("5/minute")
|
||||
def cleanup_duplicate_access_rules(
|
||||
request: Request,
|
||||
dryRun: bool = Query(True, description="If true, only report duplicates without deleting"),
|
||||
currentUser: User = Depends(requireSysAdmin)
|
||||
) -> dict:
|
||||
"""
|
||||
Find and remove duplicate AccessRules.
|
||||
|
||||
Duplicates are rules with the same (roleId, context, item) signature.
|
||||
Only the first rule (oldest) is kept, all others are deleted.
|
||||
|
||||
Query Parameters:
|
||||
- dryRun: If true (default), only report what would be deleted. Set to false to actually delete.
|
||||
|
||||
Returns:
|
||||
- Summary with counts and details of duplicates found/removed
|
||||
"""
|
||||
try:
|
||||
rootInterface = getRootInterface()
|
||||
|
||||
# Get ALL AccessRules from DB
|
||||
allRules = rootInterface.db.getRecordset(AccessRule)
|
||||
|
||||
# Group by signature (roleId, context, item)
|
||||
rulesBySignature: Dict[tuple, list] = {}
|
||||
for rule in allRules:
|
||||
context = rule.get("context", "")
|
||||
# Normalize context enum value
|
||||
if hasattr(context, 'value'):
|
||||
context = context.value
|
||||
sig = (rule.get("roleId"), str(context), rule.get("item"))
|
||||
if sig not in rulesBySignature:
|
||||
rulesBySignature[sig] = []
|
||||
rulesBySignature[sig].append(rule)
|
||||
|
||||
# Find duplicates and collect IDs to delete
|
||||
duplicateGroups = []
|
||||
idsToDelete = []
|
||||
|
||||
for sig, rules in rulesBySignature.items():
|
||||
if len(rules) > 1:
|
||||
# Sort by creation time (keep oldest)
|
||||
rules.sort(key=lambda r: r.get("_createdAt", 0))
|
||||
keepRule = rules[0]
|
||||
deleteRules = rules[1:]
|
||||
|
||||
duplicateGroups.append({
|
||||
"roleId": sig[0],
|
||||
"context": sig[1],
|
||||
"item": sig[2] or "(global)",
|
||||
"totalCount": len(rules),
|
||||
"keepId": keepRule.get("id"),
|
||||
"deleteCount": len(deleteRules),
|
||||
"deleteIds": [r.get("id") for r in deleteRules]
|
||||
})
|
||||
|
||||
idsToDelete.extend([r.get("id") for r in deleteRules])
|
||||
|
||||
# Perform deletion if not dry run
|
||||
deletedCount = 0
|
||||
if not dryRun and idsToDelete:
|
||||
for ruleId in idsToDelete:
|
||||
try:
|
||||
rootInterface.db.recordDelete(AccessRule, ruleId)
|
||||
deletedCount += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete rule {ruleId}: {e}")
|
||||
|
||||
result = {
|
||||
"dryRun": dryRun,
|
||||
"totalRules": len(allRules),
|
||||
"uniqueSignatures": len(rulesBySignature),
|
||||
"duplicateGroups": len(duplicateGroups),
|
||||
"duplicateRulesToDelete": len(idsToDelete),
|
||||
"deletedCount": deletedCount,
|
||||
"details": duplicateGroups[:50] # Limit details to 50 groups
|
||||
}
|
||||
|
||||
logger.info(f"AccessRule cleanup: dryRun={dryRun}, total={len(allRules)}, "
|
||||
f"duplicateGroups={len(duplicateGroups)}, toDelete={len(idsToDelete)}, "
|
||||
f"deleted={deletedCount}")
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during AccessRule cleanup: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to cleanup duplicate rules: {str(e)}"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -47,11 +47,15 @@ def _getAccessLevelLabel(level: Optional[str]) -> str:
|
|||
return labels.get(level, "-")
|
||||
|
||||
|
||||
def _getRoleScope(role: Dict[str, Any]) -> str:
|
||||
"""Determine the scope of a role."""
|
||||
if role.get("featureInstanceId"):
|
||||
def _getRoleScope(role) -> str:
|
||||
"""Determine the scope of a role. Accepts Role object or dict."""
|
||||
# Support both Pydantic models and dicts
|
||||
featureInstanceId = getattr(role, 'featureInstanceId', None) or (role.get("featureInstanceId") if isinstance(role, dict) else None)
|
||||
mandateId = getattr(role, 'mandateId', None) or (role.get("mandateId") if isinstance(role, dict) else None)
|
||||
|
||||
if featureInstanceId:
|
||||
return "instance"
|
||||
elif role.get("mandateId"):
|
||||
elif mandateId:
|
||||
return "mandate"
|
||||
else:
|
||||
return "global"
|
||||
|
|
@ -65,7 +69,7 @@ def _getRoleScopePriority(scope: str) -> int:
|
|||
|
||||
@router.get("/users", response_model=List[Dict[str, Any]])
|
||||
@limiter.limit("60/minute")
|
||||
async def listUsersForOverview(
|
||||
def listUsersForOverview(
|
||||
request: Request,
|
||||
currentUser: User = Depends(requireSysAdmin)
|
||||
) -> List[Dict[str, Any]]:
|
||||
|
|
@ -79,18 +83,18 @@ async def listUsersForOverview(
|
|||
try:
|
||||
interface = getRootInterface()
|
||||
|
||||
# Get all users
|
||||
allUsersData = interface.db.getRecordset(UserInDB)
|
||||
# Get all users using interface method
|
||||
allUsers = interface.getAllUsers()
|
||||
|
||||
result = []
|
||||
for u in allUsersData:
|
||||
for u in allUsers:
|
||||
result.append({
|
||||
"id": u.get("id"),
|
||||
"username": u.get("username"),
|
||||
"email": u.get("email"),
|
||||
"fullName": u.get("fullName"),
|
||||
"isSysAdmin": u.get("isSysAdmin", False),
|
||||
"enabled": u.get("enabled", True),
|
||||
"id": u.id,
|
||||
"username": u.username,
|
||||
"email": u.email,
|
||||
"fullName": u.fullName,
|
||||
"isSysAdmin": u.isSysAdmin,
|
||||
"enabled": u.enabled,
|
||||
})
|
||||
|
||||
# Sort by username
|
||||
|
|
@ -108,7 +112,7 @@ async def listUsersForOverview(
|
|||
|
||||
@router.get("/{userId}", response_model=Dict[str, Any])
|
||||
@limiter.limit("60/minute")
|
||||
async def getUserAccessOverview(
|
||||
def getUserAccessOverview(
|
||||
request: Request,
|
||||
userId: str = Path(..., description="User ID to get access overview for"),
|
||||
mandateId: Optional[str] = Query(None, description="Filter by mandate ID"),
|
||||
|
|
@ -172,47 +176,43 @@ async def getUserAccessOverview(
|
|||
allRoles = []
|
||||
roleIdToInfo = {} # Map roleId to role info for later reference
|
||||
|
||||
# Get mandates for this user
|
||||
mandateFilter = {"userId": userId, "enabled": True}
|
||||
# Get mandates for this user using interface method
|
||||
allUserMandates = interface.getUserMandates(userId)
|
||||
# Filter by enabled and optionally mandateId
|
||||
userMandates = [um for um in allUserMandates if um.enabled]
|
||||
if mandateId:
|
||||
mandateFilter["mandateId"] = mandateId
|
||||
|
||||
userMandates = interface.db.getRecordset(UserMandate, recordFilter=mandateFilter)
|
||||
userMandates = [um for um in userMandates if um.mandateId == mandateId]
|
||||
|
||||
mandatesInfo = []
|
||||
for um in userMandates:
|
||||
umId = um.get("id")
|
||||
umMandateId = um.get("mandateId")
|
||||
umId = um.id
|
||||
umMandateId = um.mandateId
|
||||
|
||||
# Get mandate name
|
||||
mandate = interface.getMandate(umMandateId)
|
||||
mandateName = mandate.name if mandate else umMandateId
|
||||
|
||||
# Get roles for this UserMandate
|
||||
umRoles = interface.db.getRecordset(
|
||||
UserMandateRole,
|
||||
recordFilter={"userMandateId": umId}
|
||||
)
|
||||
# Get roles for this UserMandate using interface method
|
||||
umRoles = interface.getUserMandateRoles(umId)
|
||||
|
||||
mandateRoleIds = []
|
||||
for umr in umRoles:
|
||||
roleId = umr.get("roleId")
|
||||
roleId = umr.roleId
|
||||
if roleId:
|
||||
mandateRoleIds.append(roleId)
|
||||
|
||||
# Get role details
|
||||
roleRecords = interface.db.getRecordset(Role, recordFilter={"id": roleId})
|
||||
if roleRecords:
|
||||
role = roleRecords[0]
|
||||
# Get role details using interface method
|
||||
role = interface.getRole(roleId)
|
||||
if role:
|
||||
scope = _getRoleScope(role)
|
||||
roleInfo = {
|
||||
"id": roleId,
|
||||
"roleLabel": role.get("roleLabel"),
|
||||
"description": role.get("description", {}),
|
||||
"roleLabel": role.roleLabel,
|
||||
"description": role.description or {},
|
||||
"scope": scope,
|
||||
"scopePriority": _getRoleScopePriority(scope),
|
||||
"mandateId": role.get("mandateId"),
|
||||
"featureInstanceId": role.get("featureInstanceId"),
|
||||
"mandateId": role.mandateId,
|
||||
"featureInstanceId": role.featureInstanceId,
|
||||
"source": "mandate",
|
||||
"sourceMandateId": umMandateId,
|
||||
"sourceMandateName": mandateName,
|
||||
|
|
@ -220,69 +220,59 @@ async def getUserAccessOverview(
|
|||
allRoles.append(roleInfo)
|
||||
roleIdToInfo[roleId] = roleInfo
|
||||
|
||||
# Get feature instances for this mandate
|
||||
featureInstanceFilter = {"userId": userId, "enabled": True}
|
||||
featureAccesses = interface.db.getRecordset(FeatureAccess, recordFilter=featureInstanceFilter)
|
||||
# Get feature instances for this mandate using interface method
|
||||
allFeatureAccesses = interface.getFeatureAccessesForUser(userId)
|
||||
featureAccesses = [fa for fa in allFeatureAccesses if fa.enabled]
|
||||
|
||||
featureInstancesInfo = []
|
||||
for fa in featureAccesses:
|
||||
faId = fa.get("id")
|
||||
faInstanceId = fa.get("featureInstanceId")
|
||||
faId = fa.id
|
||||
faInstanceId = fa.featureInstanceId
|
||||
|
||||
# Check if instance belongs to this mandate
|
||||
instance = interface.db.getRecordset(FeatureInstance, recordFilter={"id": faInstanceId})
|
||||
# Check if instance belongs to this mandate using interface method
|
||||
instance = interface.getFeatureInstance(faInstanceId)
|
||||
if not instance:
|
||||
continue
|
||||
instance = instance[0]
|
||||
|
||||
if instance.get("mandateId") != umMandateId:
|
||||
if instance.mandateId != umMandateId:
|
||||
continue
|
||||
|
||||
# Filter by featureInstanceId if specified
|
||||
if featureInstanceId and faInstanceId != featureInstanceId:
|
||||
continue
|
||||
|
||||
# Get feature info
|
||||
featureCode = instance.get("featureCode")
|
||||
featureRecords = interface.db.getRecordset(Feature, recordFilter={"code": featureCode})
|
||||
featureLabel = featureRecords[0].get("label", {}) if featureRecords else {}
|
||||
# Get feature info using interface method
|
||||
featureCode = instance.featureCode
|
||||
feature = interface.getFeatureByCode(featureCode)
|
||||
featureLabel = feature.label if feature else {}
|
||||
|
||||
# Get roles for this FeatureAccess
|
||||
faRoles = interface.db.getRecordset(
|
||||
FeatureAccessRole,
|
||||
recordFilter={"featureAccessId": faId}
|
||||
)
|
||||
# Get roles for this FeatureAccess using interface method
|
||||
instanceRoleIds = interface.getRoleIdsForFeatureAccess(faId)
|
||||
|
||||
instanceRoleIds = []
|
||||
for far in faRoles:
|
||||
roleId = far.get("roleId")
|
||||
if roleId:
|
||||
instanceRoleIds.append(roleId)
|
||||
|
||||
# Get role details (if not already added)
|
||||
if roleId not in roleIdToInfo:
|
||||
roleRecords = interface.db.getRecordset(Role, recordFilter={"id": roleId})
|
||||
if roleRecords:
|
||||
role = roleRecords[0]
|
||||
scope = _getRoleScope(role)
|
||||
roleInfo = {
|
||||
"id": roleId,
|
||||
"roleLabel": role.get("roleLabel"),
|
||||
"description": role.get("description", {}),
|
||||
"scope": scope,
|
||||
"scopePriority": _getRoleScopePriority(scope),
|
||||
"mandateId": role.get("mandateId"),
|
||||
"featureInstanceId": role.get("featureInstanceId"),
|
||||
"source": "featureInstance",
|
||||
"sourceInstanceId": faInstanceId,
|
||||
"sourceInstanceLabel": instance.get("label"),
|
||||
}
|
||||
allRoles.append(roleInfo)
|
||||
roleIdToInfo[roleId] = roleInfo
|
||||
for roleId in instanceRoleIds:
|
||||
# Get role details (if not already added)
|
||||
if roleId not in roleIdToInfo:
|
||||
role = interface.getRole(roleId)
|
||||
if role:
|
||||
scope = _getRoleScope(role)
|
||||
roleInfo = {
|
||||
"id": roleId,
|
||||
"roleLabel": role.roleLabel,
|
||||
"description": role.description or {},
|
||||
"scope": scope,
|
||||
"scopePriority": _getRoleScopePriority(scope),
|
||||
"mandateId": role.mandateId,
|
||||
"featureInstanceId": role.featureInstanceId,
|
||||
"source": "featureInstance",
|
||||
"sourceInstanceId": faInstanceId,
|
||||
"sourceInstanceLabel": instance.label,
|
||||
}
|
||||
allRoles.append(roleInfo)
|
||||
roleIdToInfo[roleId] = roleInfo
|
||||
|
||||
featureInstancesInfo.append({
|
||||
"id": faInstanceId,
|
||||
"label": instance.get("label"),
|
||||
"label": instance.label,
|
||||
"featureCode": featureCode,
|
||||
"featureLabel": featureLabel,
|
||||
"roleIds": instanceRoleIds,
|
||||
|
|
@ -317,12 +307,12 @@ async def getUserAccessOverview(
|
|||
roleLabel = roleInfo.get("roleLabel", "unknown")
|
||||
roleScope = roleInfo.get("scope", "unknown")
|
||||
|
||||
# Get all rules for this role
|
||||
rules = interface.db.getRecordset(AccessRule, recordFilter={"roleId": roleId})
|
||||
# Get all rules for this role using interface method
|
||||
rules = interface.getAccessRulesByRole(roleId)
|
||||
|
||||
for rule in rules:
|
||||
context = rule.get("context")
|
||||
item = rule.get("item")
|
||||
context = rule.context
|
||||
item = rule.item
|
||||
|
||||
accessEntry = {
|
||||
"item": item or "(all)",
|
||||
|
|
@ -333,20 +323,20 @@ async def getUserAccessOverview(
|
|||
}
|
||||
|
||||
if context == "UI":
|
||||
accessEntry["view"] = rule.get("view", False)
|
||||
accessEntry["view"] = rule.view if rule.view is not None else False
|
||||
if accessEntry["view"]:
|
||||
uiAccess.append(accessEntry)
|
||||
|
||||
elif context == "DATA":
|
||||
accessEntry["view"] = rule.get("view", False)
|
||||
accessEntry["read"] = _getAccessLevelLabel(rule.get("read"))
|
||||
accessEntry["create"] = _getAccessLevelLabel(rule.get("create"))
|
||||
accessEntry["update"] = _getAccessLevelLabel(rule.get("update"))
|
||||
accessEntry["delete"] = _getAccessLevelLabel(rule.get("delete"))
|
||||
accessEntry["view"] = rule.view if rule.view is not None else False
|
||||
accessEntry["read"] = _getAccessLevelLabel(rule.read)
|
||||
accessEntry["create"] = _getAccessLevelLabel(rule.create)
|
||||
accessEntry["update"] = _getAccessLevelLabel(rule.update)
|
||||
accessEntry["delete"] = _getAccessLevelLabel(rule.delete)
|
||||
dataAccess.append(accessEntry)
|
||||
|
||||
elif context == "RESOURCE":
|
||||
accessEntry["view"] = rule.get("view", False)
|
||||
accessEntry["view"] = rule.view if rule.view is not None else False
|
||||
if accessEntry["view"]:
|
||||
resourceAccess.append(accessEntry)
|
||||
|
||||
|
|
@ -420,7 +410,7 @@ async def getUserAccessOverview(
|
|||
|
||||
@router.get("/{userId}/effective-permissions", response_model=Dict[str, Any])
|
||||
@limiter.limit("60/minute")
|
||||
async def getEffectivePermissions(
|
||||
def getEffectivePermissions(
|
||||
request: Request,
|
||||
userId: str = Path(..., description="User ID"),
|
||||
mandateId: str = Query(..., description="Mandate ID context"),
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ router = APIRouter(
|
|||
|
||||
@router.get("/{entityType}", response_model=AttributeResponse)
|
||||
@limiter.limit("30/minute")
|
||||
async def get_entity_attributes(
|
||||
def get_entity_attributes(
|
||||
request: Request,
|
||||
entityType: str = Path(..., description="Type of entity (e.g. prompt)")
|
||||
) -> AttributeResponse:
|
||||
|
|
@ -76,7 +76,7 @@ async def get_entity_attributes(
|
|||
|
||||
@router.options("/{entityType}")
|
||||
@limiter.limit("60/minute")
|
||||
async def options_entity_attributes(
|
||||
def options_entity_attributes(
|
||||
request: Request,
|
||||
entityType: str = Path(..., description="Type of entity (e.g. prompt)")
|
||||
) -> Response:
|
||||
|
|
|
|||
1257
modules/routes/routeBilling.py
Normal file
1257
modules/routes/routeBilling.py
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -1,128 +0,0 @@
|
|||
# Copyright (c) 2025 Patrick Motsch
|
||||
# All rights reserved.
|
||||
"""
|
||||
Chat Playground routes for the backend API.
|
||||
Implements the endpoints for chat playground workflow management.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
from fastapi import APIRouter, HTTPException, Depends, Body, Path, Query, Request
|
||||
|
||||
# Import auth modules
|
||||
from modules.auth import limiter, getRequestContext, RequestContext
|
||||
|
||||
# Import interfaces
|
||||
from modules.interfaces import interfaceDbChat
|
||||
|
||||
# Import models
|
||||
from modules.datamodels.datamodelChat import ChatWorkflow, UserInputRequest, WorkflowModeEnum
|
||||
|
||||
# Import workflow control functions
|
||||
from modules.workflows.automation import chatStart, chatStop
|
||||
|
||||
# Configure logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Create router for chat playground endpoints
|
||||
router = APIRouter(
|
||||
prefix="/api/chat/playground",
|
||||
tags=["Chat Playground"],
|
||||
responses={404: {"description": "Not found"}}
|
||||
)
|
||||
|
||||
def _getServiceChat(context: RequestContext):
|
||||
return interfaceDbChat.getInterface(context.user, mandateId=str(context.mandateId) if context.mandateId else None)
|
||||
|
||||
# Workflow start endpoint
|
||||
@router.post("/start", response_model=ChatWorkflow)
|
||||
@limiter.limit("120/minute")
|
||||
async def start_workflow(
|
||||
request: Request,
|
||||
workflowId: Optional[str] = Query(None, description="Optional ID of the workflow to continue"),
|
||||
workflowMode: WorkflowModeEnum = Query(..., description="Workflow mode: 'Dynamic' or 'Automation' (mandatory)"),
|
||||
userInput: UserInputRequest = Body(...),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
) -> ChatWorkflow:
|
||||
"""
|
||||
Starts a new workflow or continues an existing one.
|
||||
Corresponds to State 1 in the state machine documentation.
|
||||
|
||||
Args:
|
||||
workflowMode: "Dynamic" for iterative dynamic-style processing, "Automation" for automated workflow execution
|
||||
"""
|
||||
try:
|
||||
# Start or continue workflow using playground controller
|
||||
mandateId = str(context.mandateId) if context.mandateId else None
|
||||
workflow = await chatStart(context.user, userInput, workflowMode, workflowId, mandateId=mandateId)
|
||||
|
||||
return workflow
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in start_workflow: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=str(e)
|
||||
)
|
||||
|
||||
# State 8: Workflow Stopped endpoint
|
||||
@router.post("/{workflowId}/stop", response_model=ChatWorkflow)
|
||||
@limiter.limit("120/minute")
|
||||
async def stop_workflow(
|
||||
request: Request,
|
||||
workflowId: str = Path(..., description="ID of the workflow to stop"),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
) -> ChatWorkflow:
|
||||
"""Stops a running workflow."""
|
||||
try:
|
||||
# Stop workflow using playground controller
|
||||
mandateId = str(context.mandateId) if context.mandateId else None
|
||||
workflow = await chatStop(context.user, workflowId, mandateId=mandateId)
|
||||
|
||||
return workflow
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in stop_workflow: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=str(e)
|
||||
)
|
||||
|
||||
# Unified Chat Data Endpoint for Polling
|
||||
@router.get("/{workflowId}/chatData")
|
||||
@limiter.limit("120/minute")
|
||||
async def get_workflow_chat_data(
|
||||
request: Request,
|
||||
workflowId: str = Path(..., description="ID of the workflow"),
|
||||
afterTimestamp: Optional[float] = Query(None, description="Unix timestamp to get data after"),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get unified chat data (messages, logs, stats) for a workflow with timestamp-based selective data transfer.
|
||||
Returns all data types in chronological order based on _createdAt timestamp.
|
||||
"""
|
||||
try:
|
||||
# Get service center
|
||||
interfaceDbChat = _getServiceChat(context)
|
||||
|
||||
# Verify workflow exists
|
||||
workflow = interfaceDbChat.getWorkflow(workflowId)
|
||||
if not workflow:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Workflow with ID {workflowId} not found"
|
||||
)
|
||||
|
||||
# Get unified chat data using the new method
|
||||
chatData = interfaceDbChat.getUnifiedChatData(workflowId, afterTimestamp)
|
||||
|
||||
return chatData
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting unified chat data: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Error getting unified chat data: {str(e)}"
|
||||
)
|
||||
|
|
@ -43,30 +43,14 @@ def getTokenStatusForConnection(interface, connectionId: str) -> tuple[str, Opti
|
|||
- tokenExpiresAt: UTC timestamp or None
|
||||
"""
|
||||
try:
|
||||
# Query tokens table for the latest token for this connection
|
||||
tokens = interface.db.getRecordset(
|
||||
Token,
|
||||
recordFilter={"connectionId": connectionId}
|
||||
)
|
||||
|
||||
if not tokens:
|
||||
return "none", None
|
||||
|
||||
# Find the most recent token (highest createdAt timestamp)
|
||||
latestToken = None
|
||||
latestCreatedAt = 0
|
||||
|
||||
for tokenData in tokens:
|
||||
createdAt = parseTimestamp(tokenData.get("createdAt"), default=0)
|
||||
if createdAt > latestCreatedAt:
|
||||
latestCreatedAt = createdAt
|
||||
latestToken = tokenData
|
||||
# Query tokens table for the latest token for this connection using interface method
|
||||
latestToken = interface.getConnectionToken(connectionId)
|
||||
|
||||
if not latestToken:
|
||||
return "none", None
|
||||
|
||||
# Check if token is expired
|
||||
expiresAt = parseTimestamp(latestToken.get("expiresAt"))
|
||||
expiresAt = parseTimestamp(latestToken.expiresAt)
|
||||
if not expiresAt:
|
||||
return "none", None
|
||||
|
||||
|
|
@ -100,7 +84,7 @@ router = APIRouter(
|
|||
|
||||
@router.get("/statuses/options", response_model=List[Dict[str, Any]])
|
||||
@limiter.limit("60/minute")
|
||||
async def get_connection_status_options(
|
||||
def get_connection_status_options(
|
||||
request: Request,
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> List[Dict[str, Any]]:
|
||||
|
|
@ -116,7 +100,7 @@ async def get_connection_status_options(
|
|||
|
||||
@router.get("/authorities/options", response_model=List[Dict[str, Any]])
|
||||
@limiter.limit("60/minute")
|
||||
async def get_auth_authority_options(
|
||||
def get_auth_authority_options(
|
||||
request: Request,
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> List[Dict[str, Any]]:
|
||||
|
|
@ -304,7 +288,7 @@ async def get_connections(
|
|||
|
||||
@router.post("/", response_model=UserConnection)
|
||||
@limiter.limit("10/minute")
|
||||
async def create_connection(
|
||||
def create_connection(
|
||||
request: Request,
|
||||
connection_data: Dict[str, Any] = Body(...),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
|
|
@ -360,7 +344,7 @@ async def create_connection(
|
|||
|
||||
@router.put("/{connectionId}", response_model=UserConnection)
|
||||
@limiter.limit("10/minute")
|
||||
async def update_connection(
|
||||
def update_connection(
|
||||
request: Request,
|
||||
connectionId: str = Path(..., description="The ID of the connection to update"),
|
||||
connection_data: Dict[str, Any] = Body(...),
|
||||
|
|
@ -432,7 +416,7 @@ async def update_connection(
|
|||
|
||||
@router.post("/{connectionId}/connect")
|
||||
@limiter.limit("10/minute")
|
||||
async def connect_service(
|
||||
def connect_service(
|
||||
request: Request,
|
||||
connectionId: str = Path(..., description="The ID of the connection to connect"),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
|
|
@ -498,7 +482,7 @@ async def connect_service(
|
|||
|
||||
@router.post("/{connectionId}/disconnect")
|
||||
@limiter.limit("10/minute")
|
||||
async def disconnect_service(
|
||||
def disconnect_service(
|
||||
request: Request,
|
||||
connectionId: str = Path(..., description="The ID of the connection to disconnect"),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
|
|
@ -548,7 +532,7 @@ async def disconnect_service(
|
|||
|
||||
@router.delete("/{connectionId}")
|
||||
@limiter.limit("10/minute")
|
||||
async def delete_connection(
|
||||
def delete_connection(
|
||||
request: Request,
|
||||
connectionId: str = Path(..., description="The ID of the connection to delete"),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ router = APIRouter(
|
|||
|
||||
@router.get("/list", response_model=PaginatedResponse[FileItem])
|
||||
@limiter.limit("30/minute")
|
||||
async def get_files(
|
||||
def get_files(
|
||||
request: Request,
|
||||
pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
|
|
@ -168,7 +168,7 @@ async def upload_file(
|
|||
|
||||
@router.get("/{fileId}", response_model=FileItem)
|
||||
@limiter.limit("30/minute")
|
||||
async def get_file(
|
||||
def get_file(
|
||||
request: Request,
|
||||
fileId: str = Path(..., description="ID of the file"),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
|
|
@ -214,7 +214,7 @@ async def get_file(
|
|||
|
||||
@router.put("/{fileId}", response_model=FileItem)
|
||||
@limiter.limit("10/minute")
|
||||
async def update_file(
|
||||
def update_file(
|
||||
request: Request,
|
||||
fileId: str = Path(..., description="ID of the file to update"),
|
||||
file_info: Dict[str, Any] = Body(...),
|
||||
|
|
@ -262,7 +262,7 @@ async def update_file(
|
|||
|
||||
@router.delete("/{fileId}", response_model=Dict[str, Any])
|
||||
@limiter.limit("10/minute")
|
||||
async def delete_file(
|
||||
def delete_file(
|
||||
request: Request,
|
||||
fileId: str = Path(..., description="ID of the file to delete"),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
|
|
@ -289,7 +289,7 @@ async def delete_file(
|
|||
|
||||
@router.get("/stats", response_model=Dict[str, Any])
|
||||
@limiter.limit("30/minute")
|
||||
async def get_file_stats(
|
||||
def get_file_stats(
|
||||
request: Request,
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> Dict[str, Any]:
|
||||
|
|
@ -327,7 +327,7 @@ async def get_file_stats(
|
|||
|
||||
@router.get("/{fileId}/download")
|
||||
@limiter.limit("30/minute")
|
||||
async def download_file(
|
||||
def download_file(
|
||||
request: Request,
|
||||
fileId: str = Path(..., description="ID of the file to download"),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
|
|
@ -375,7 +375,7 @@ async def download_file(
|
|||
|
||||
@router.get("/{fileId}/preview", response_model=FilePreview)
|
||||
@limiter.limit("30/minute")
|
||||
async def preview_file(
|
||||
def preview_file(
|
||||
request: Request,
|
||||
fileId: str = Path(..., description="ID of the file to preview"),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
|
|
|
|||
|
|
@ -76,7 +76,7 @@ router = APIRouter(
|
|||
|
||||
@router.get("/", response_model=PaginatedResponse[Mandate])
|
||||
@limiter.limit("30/minute")
|
||||
async def get_mandates(
|
||||
def get_mandates(
|
||||
request: Request,
|
||||
pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"),
|
||||
currentUser: User = Depends(requireSysAdmin)
|
||||
|
|
@ -140,7 +140,7 @@ async def get_mandates(
|
|||
|
||||
@router.get("/{mandateId}", response_model=Mandate)
|
||||
@limiter.limit("30/minute")
|
||||
async def get_mandate(
|
||||
def get_mandate(
|
||||
request: Request,
|
||||
mandateId: str = Path(..., description="ID of the mandate"),
|
||||
currentUser: User = Depends(requireSysAdmin)
|
||||
|
|
@ -171,7 +171,7 @@ async def get_mandate(
|
|||
|
||||
@router.post("/", response_model=Mandate)
|
||||
@limiter.limit("10/minute")
|
||||
async def create_mandate(
|
||||
def create_mandate(
|
||||
request: Request,
|
||||
mandateData: dict = Body(..., description="Mandate data with at least 'name' field"),
|
||||
currentUser: User = Depends(requireSysAdmin)
|
||||
|
|
@ -192,7 +192,7 @@ async def create_mandate(
|
|||
)
|
||||
|
||||
# Get optional fields with defaults
|
||||
description = mandateData.get('description')
|
||||
label = mandateData.get('label')
|
||||
enabled = mandateData.get('enabled', True)
|
||||
|
||||
appInterface = interfaceDbApp.getRootInterface()
|
||||
|
|
@ -200,7 +200,7 @@ async def create_mandate(
|
|||
# Create mandate
|
||||
newMandate = appInterface.createMandate(
|
||||
name=name,
|
||||
description=description,
|
||||
label=label,
|
||||
enabled=enabled
|
||||
)
|
||||
|
||||
|
|
@ -224,7 +224,7 @@ async def create_mandate(
|
|||
|
||||
@router.put("/{mandateId}", response_model=Mandate)
|
||||
@limiter.limit("10/minute")
|
||||
async def update_mandate(
|
||||
def update_mandate(
|
||||
request: Request,
|
||||
mandateId: str = Path(..., description="ID of the mandate to update"),
|
||||
mandateData: dict = Body(..., description="Mandate update data"),
|
||||
|
|
@ -270,7 +270,7 @@ async def update_mandate(
|
|||
|
||||
@router.delete("/{mandateId}", response_model=Dict[str, Any])
|
||||
@limiter.limit("10/minute")
|
||||
async def delete_mandate(
|
||||
def delete_mandate(
|
||||
request: Request,
|
||||
mandateId: str = Path(..., description="ID of the mandate to delete"),
|
||||
currentUser: User = Depends(requireSysAdmin)
|
||||
|
|
@ -291,9 +291,9 @@ async def delete_mandate(
|
|||
)
|
||||
|
||||
# MULTI-TENANT: Delete all UserMandate entries for this mandate first
|
||||
userMandates = appInterface.db.getRecordset(UserMandate, recordFilter={"mandateId": mandateId})
|
||||
userMandates = appInterface.getUserMandatesByMandate(mandateId)
|
||||
for um in userMandates:
|
||||
appInterface.db.deleteRecord(UserMandate, um["id"])
|
||||
appInterface.deleteUserMandate(str(um.userId), mandateId)
|
||||
logger.info(f"Deleted {len(userMandates)} UserMandate entries for mandate {mandateId}")
|
||||
|
||||
# Delete mandate
|
||||
|
|
@ -324,7 +324,7 @@ async def delete_mandate(
|
|||
|
||||
@router.get("/{targetMandateId}/users")
|
||||
@limiter.limit("60/minute")
|
||||
async def list_mandate_users(
|
||||
def list_mandate_users(
|
||||
request: Request,
|
||||
targetMandateId: str = Path(..., description="ID of the mandate"),
|
||||
pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"),
|
||||
|
|
@ -377,39 +377,46 @@ async def list_mandate_users(
|
|||
)
|
||||
|
||||
# Get all UserMandate entries for this mandate
|
||||
userMandates = rootInterface.db.getRecordset(
|
||||
UserMandate,
|
||||
recordFilter={"mandateId": targetMandateId}
|
||||
)
|
||||
userMandates = rootInterface.getUserMandatesByMandate(targetMandateId)
|
||||
|
||||
result = []
|
||||
for um in userMandates:
|
||||
# Get user info
|
||||
user = rootInterface.getUser(um.get("userId"))
|
||||
user = rootInterface.getUser(str(um.userId))
|
||||
if not user:
|
||||
continue
|
||||
|
||||
# Get roles for this membership
|
||||
roleIds = rootInterface.getRoleIdsForUserMandate(um.get("id"))
|
||||
roleIds = rootInterface.getRoleIdsForUserMandate(str(um.id))
|
||||
|
||||
# Resolve role labels for display
|
||||
# Resolve role labels for display (only mandate-level roles, deduplicated)
|
||||
roleLabels = []
|
||||
filteredRoleIds = []
|
||||
seenLabels = set()
|
||||
for roleId in roleIds:
|
||||
role = rootInterface.getRole(roleId)
|
||||
if role:
|
||||
roleLabels.append(role.roleLabel)
|
||||
# Skip feature-instance roles - they don't belong in mandate membership
|
||||
if role.featureInstanceId:
|
||||
continue
|
||||
filteredRoleIds.append(roleId)
|
||||
if role.roleLabel not in seenLabels:
|
||||
roleLabels.append(role.roleLabel)
|
||||
seenLabels.add(role.roleLabel)
|
||||
else:
|
||||
roleLabels.append(roleId) # Fallback to ID if not found
|
||||
# Role not found - fail-safe: skip (no access)
|
||||
logger.warning(f"Role {roleId} not found, skipping")
|
||||
continue
|
||||
|
||||
result.append({
|
||||
"id": um.get("id"), # UserMandate ID as primary key
|
||||
"id": str(um.id), # UserMandate ID as primary key
|
||||
"userId": str(user.id),
|
||||
"username": user.username,
|
||||
"email": user.email,
|
||||
"fullName": user.fullName,
|
||||
"roleIds": roleIds,
|
||||
"roleIds": filteredRoleIds,
|
||||
"roleLabels": roleLabels,
|
||||
"enabled": um.get("enabled", True)
|
||||
"enabled": um.enabled
|
||||
})
|
||||
|
||||
# Apply search, filtering, and sorting if pagination requested
|
||||
|
|
@ -486,7 +493,7 @@ async def list_mandate_users(
|
|||
|
||||
@router.post("/{targetMandateId}/users", response_model=UserMandateResponse)
|
||||
@limiter.limit("30/minute")
|
||||
async def add_user_to_mandate(
|
||||
def add_user_to_mandate(
|
||||
request: Request,
|
||||
targetMandateId: str = Path(..., description="ID of the mandate"),
|
||||
data: UserMandateCreate = Body(...),
|
||||
|
|
@ -545,18 +552,12 @@ async def add_user_to_mandate(
|
|||
|
||||
# 6. Validate roles (must exist and belong to this mandate or be global)
|
||||
for roleId in data.roleIds:
|
||||
roleRecords = rootInterface.db.getRecordset(Role, recordFilter={"id": roleId})
|
||||
if not roleRecords:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Role {roleId} not found"
|
||||
)
|
||||
role = roleRecords[0]
|
||||
roleMandateId = role.get("mandateId")
|
||||
if roleMandateId and str(roleMandateId) != str(targetMandateId):
|
||||
try:
|
||||
rootInterface.validateRoleForMandate(roleId, targetMandateId)
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Role {roleId} belongs to a different mandate"
|
||||
detail=str(e)
|
||||
)
|
||||
|
||||
# 7. Create UserMandate
|
||||
|
|
@ -602,7 +603,7 @@ async def add_user_to_mandate(
|
|||
|
||||
@router.delete("/{targetMandateId}/users/{targetUserId}", response_model=Dict[str, str])
|
||||
@limiter.limit("30/minute")
|
||||
async def remove_user_from_mandate(
|
||||
def remove_user_from_mandate(
|
||||
request: Request,
|
||||
targetMandateId: str = Path(..., description="ID of the mandate"),
|
||||
targetUserId: str = Path(..., description="ID of the user to remove"),
|
||||
|
|
@ -680,7 +681,7 @@ async def remove_user_from_mandate(
|
|||
|
||||
@router.put("/{targetMandateId}/users/{targetUserId}/roles", response_model=UserMandateResponse)
|
||||
@limiter.limit("30/minute")
|
||||
async def update_user_roles_in_mandate(
|
||||
def update_user_roles_in_mandate(
|
||||
request: Request,
|
||||
targetMandateId: str = Path(..., description="ID of the mandate"),
|
||||
targetUserId: str = Path(..., description="ID of the user"),
|
||||
|
|
@ -718,18 +719,12 @@ async def update_user_roles_in_mandate(
|
|||
|
||||
# Validate new roles
|
||||
for roleId in roleIds:
|
||||
roleRecords = rootInterface.db.getRecordset(Role, recordFilter={"id": roleId})
|
||||
if not roleRecords:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Role {roleId} not found"
|
||||
)
|
||||
role = roleRecords[0]
|
||||
roleMandateId = role.get("mandateId")
|
||||
if roleMandateId and str(roleMandateId) != str(targetMandateId):
|
||||
try:
|
||||
rootInterface.validateRoleForMandate(roleId, targetMandateId)
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Role {roleId} belongs to a different mandate"
|
||||
detail=str(e)
|
||||
)
|
||||
|
||||
# Check if removing admin role would leave mandate without admins
|
||||
|
|
@ -745,12 +740,7 @@ async def update_user_roles_in_mandate(
|
|||
)
|
||||
|
||||
# Remove existing role assignments
|
||||
existingRoles = rootInterface.db.getRecordset(
|
||||
UserMandateRole,
|
||||
recordFilter={"userMandateId": str(membership.id)}
|
||||
)
|
||||
for er in existingRoles:
|
||||
rootInterface.db.recordDelete(UserMandateRole, er.get("id"))
|
||||
rootInterface.deleteUserMandateRoles(str(membership.id))
|
||||
|
||||
# Add new role assignments
|
||||
for roleId in roleIds:
|
||||
|
|
@ -812,19 +802,17 @@ def _hasMandateAdminRole(context: RequestContext, mandateId: str) -> bool:
|
|||
rootInterface = interfaceDbApp.getRootInterface()
|
||||
|
||||
for roleId in context.roleIds:
|
||||
roleRecords = rootInterface.db.getRecordset(Role, recordFilter={"id": roleId})
|
||||
if roleRecords:
|
||||
role = roleRecords[0]
|
||||
roleLabel = role.get("roleLabel", "")
|
||||
role = rootInterface.getRole(roleId)
|
||||
if role:
|
||||
# Admin role at mandate level (not feature-instance level)
|
||||
if roleLabel == "admin" and role.get("mandateId") and not role.get("featureInstanceId"):
|
||||
if role.roleLabel == "admin" and role.mandateId and not role.featureInstanceId:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking mandate admin role: {e}")
|
||||
return False
|
||||
return False # Fail-safe: no access on error
|
||||
|
||||
|
||||
def _isLastMandateAdmin(interface, mandateId: str, excludeUserId: str) -> bool:
|
||||
|
|
@ -832,19 +820,17 @@ def _isLastMandateAdmin(interface, mandateId: str, excludeUserId: str) -> bool:
|
|||
Check if excluding this user would leave the mandate without any admins.
|
||||
"""
|
||||
try:
|
||||
# Get all UserMandates for this mandate
|
||||
userMandates = interface.db.getRecordset(
|
||||
UserMandate,
|
||||
recordFilter={"mandateId": mandateId, "enabled": True}
|
||||
)
|
||||
# Get all UserMandates for this mandate (Pydantic models)
|
||||
allMandates = interface.getUserMandatesByMandate(mandateId)
|
||||
userMandates = [um for um in allMandates if um.enabled]
|
||||
|
||||
adminCount = 0
|
||||
for um in userMandates:
|
||||
if str(um.get("userId")) == str(excludeUserId):
|
||||
if str(um.userId) == str(excludeUserId):
|
||||
continue
|
||||
|
||||
# Check if this user has admin role
|
||||
roleIds = interface.getRoleIdsForUserMandate(um.get("id"))
|
||||
roleIds = interface.getRoleIdsForUserMandate(str(um.id))
|
||||
if _hasAdminRoleInList(interface, roleIds, mandateId):
|
||||
adminCount += 1
|
||||
|
||||
|
|
@ -852,7 +838,7 @@ def _isLastMandateAdmin(interface, mandateId: str, excludeUserId: str) -> bool:
|
|||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking last admin: {e}")
|
||||
return True # Fail-safe: assume they're the last admin
|
||||
return True # Fail-safe: assume they're the last admin (prevents deletion)
|
||||
|
||||
|
||||
def _hasAdminRoleInList(interface, roleIds: List[str], mandateId: str) -> bool:
|
||||
|
|
@ -860,13 +846,10 @@ def _hasAdminRoleInList(interface, roleIds: List[str], mandateId: str) -> bool:
|
|||
Check if any of the role IDs is an admin role for the mandate.
|
||||
"""
|
||||
for roleId in roleIds:
|
||||
roleRecords = interface.db.getRecordset(Role, recordFilter={"id": roleId})
|
||||
if roleRecords:
|
||||
role = roleRecords[0]
|
||||
roleLabel = role.get("roleLabel", "")
|
||||
roleMandateId = role.get("mandateId")
|
||||
# Admin role at mandate level
|
||||
if roleLabel == "admin" and (not roleMandateId or str(roleMandateId) == str(mandateId)):
|
||||
if not role.get("featureInstanceId"):
|
||||
role = interface.getRole(roleId)
|
||||
if role:
|
||||
# Admin role at mandate level (global or mandate-specific, not feature-instance)
|
||||
if role.roleLabel == "admin" and not role.featureInstanceId:
|
||||
if not role.mandateId or str(role.mandateId) == str(mandateId):
|
||||
return True
|
||||
return False
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ router = APIRouter(
|
|||
|
||||
@router.get("", response_model=PaginatedResponse[Prompt])
|
||||
@limiter.limit("30/minute")
|
||||
async def get_prompts(
|
||||
def get_prompts(
|
||||
request: Request,
|
||||
pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
|
|
@ -83,7 +83,7 @@ async def get_prompts(
|
|||
|
||||
@router.post("", response_model=Prompt)
|
||||
@limiter.limit("10/minute")
|
||||
async def create_prompt(
|
||||
def create_prompt(
|
||||
request: Request,
|
||||
prompt: Prompt,
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
|
|
@ -98,7 +98,7 @@ async def create_prompt(
|
|||
|
||||
@router.get("/{promptId}", response_model=Prompt)
|
||||
@limiter.limit("30/minute")
|
||||
async def get_prompt(
|
||||
def get_prompt(
|
||||
request: Request,
|
||||
promptId: str = Path(..., description="ID of the prompt"),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
|
|
@ -118,13 +118,13 @@ async def get_prompt(
|
|||
|
||||
@router.put("/{promptId}", response_model=Prompt)
|
||||
@limiter.limit("10/minute")
|
||||
async def update_prompt(
|
||||
def update_prompt(
|
||||
request: Request,
|
||||
promptId: str = Path(..., description="ID of the prompt to update"),
|
||||
promptData: Prompt = Body(...),
|
||||
promptData: Dict[str, Any] = Body(...),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> Prompt:
|
||||
"""Update an existing prompt"""
|
||||
"""Update an existing prompt (supports partial updates for inline editing)"""
|
||||
managementInterface = interfaceDbManagement.getInterface(currentUser)
|
||||
|
||||
# Check if the prompt exists
|
||||
|
|
@ -135,14 +135,17 @@ async def update_prompt(
|
|||
detail=f"Prompt with ID {promptId} not found"
|
||||
)
|
||||
|
||||
# Convert Prompt to dict for interface, excluding the id field
|
||||
if hasattr(promptData, "model_dump"):
|
||||
update_data = promptData.model_dump(exclude={"id"})
|
||||
else:
|
||||
update_data = promptData.model_dump(exclude={"id"})
|
||||
# Remove id from update data if present
|
||||
update_data = {k: v for k, v in promptData.items() if k != "id"}
|
||||
|
||||
# Update prompt
|
||||
updatedPrompt = managementInterface.updatePrompt(promptId, update_data)
|
||||
# Update prompt (ownership check happens in interface)
|
||||
try:
|
||||
updatedPrompt = managementInterface.updatePrompt(promptId, update_data)
|
||||
except PermissionError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(e)
|
||||
)
|
||||
|
||||
if not updatedPrompt:
|
||||
raise HTTPException(
|
||||
|
|
@ -154,7 +157,7 @@ async def update_prompt(
|
|||
|
||||
@router.delete("/{promptId}", response_model=Dict[str, Any])
|
||||
@limiter.limit("10/minute")
|
||||
async def delete_prompt(
|
||||
def delete_prompt(
|
||||
request: Request,
|
||||
promptId: str = Path(..., description="ID of the prompt to delete"),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
|
|
@ -170,7 +173,14 @@ async def delete_prompt(
|
|||
detail=f"Prompt with ID {promptId} not found"
|
||||
)
|
||||
|
||||
success = managementInterface.deletePrompt(promptId)
|
||||
try:
|
||||
success = managementInterface.deletePrompt(promptId)
|
||||
except PermissionError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(e)
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
|
|
|
|||
|
|
@ -21,7 +21,8 @@ import modules.interfaces.interfaceDbApp as interfaceDbApp
|
|||
from modules.auth import limiter, getRequestContext, RequestContext
|
||||
|
||||
# Import the attribute definition and helper functions
|
||||
from modules.datamodels.datamodelUam import User, UserInDB
|
||||
from modules.datamodels.datamodelUam import User, UserInDB, AuthAuthority
|
||||
from modules.interfaces.interfaceDbApp import getRootInterface
|
||||
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata, normalize_pagination_dict
|
||||
|
||||
# Configure logger
|
||||
|
|
@ -152,7 +153,7 @@ router = APIRouter(
|
|||
|
||||
@router.get("/options", response_model=List[Dict[str, Any]])
|
||||
@limiter.limit("60/minute")
|
||||
async def get_user_options(
|
||||
def get_user_options(
|
||||
request: Request,
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
) -> List[Dict[str, Any]]:
|
||||
|
|
@ -189,7 +190,7 @@ async def get_user_options(
|
|||
|
||||
@router.get("/", response_model=PaginatedResponse[User])
|
||||
@limiter.limit("30/minute")
|
||||
async def get_users(
|
||||
def get_users(
|
||||
request: Request,
|
||||
pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
|
|
@ -251,16 +252,10 @@ async def get_users(
|
|||
)
|
||||
elif context.isSysAdmin:
|
||||
# SysAdmin without mandateId sees all users
|
||||
# Get all users directly from database using UserInDB (the actual database model)
|
||||
allUsers = appInterface.db.getRecordset(UserInDB)
|
||||
# Convert to cleaned dictionaries first for filtering
|
||||
cleanedUsers = []
|
||||
for u in allUsers:
|
||||
cleanedUser = {k: v for k, v in u.items() if not k.startswith("_") and k != "hashedPassword" and k != "resetToken" and k != "resetTokenExpires"}
|
||||
# Ensure roleLabels is always a list
|
||||
if cleanedUser.get("roleLabels") is None:
|
||||
cleanedUser["roleLabels"] = []
|
||||
cleanedUsers.append(cleanedUser)
|
||||
# Get all users via interface method (returns Pydantic User models)
|
||||
allUserModels = appInterface.getAllUsers()
|
||||
# Convert to dictionaries for filtering/sorting
|
||||
cleanedUsers = [u.model_dump() for u in allUserModels]
|
||||
|
||||
# Apply server-side filtering and sorting
|
||||
filteredUsers = _applyFiltersAndSort(cleanedUsers, paginationParams)
|
||||
|
|
@ -309,7 +304,7 @@ async def get_users(
|
|||
|
||||
@router.get("/{userId}", response_model=User)
|
||||
@limiter.limit("30/minute")
|
||||
async def get_user(
|
||||
def get_user(
|
||||
request: Request,
|
||||
userId: str = Path(..., description="ID of the user"),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
|
|
@ -331,11 +326,7 @@ async def get_user(
|
|||
|
||||
# MULTI-TENANT: Verify user is in the same mandate (unless SysAdmin)
|
||||
if context.mandateId and not context.isSysAdmin:
|
||||
from modules.datamodels.datamodelMembership import UserMandate
|
||||
userMandate = appInterface.db.getRecordset(UserMandate, recordFilter={
|
||||
"userId": userId,
|
||||
"mandateId": str(context.mandateId)
|
||||
})
|
||||
userMandate = appInterface.getUserMandate(userId, str(context.mandateId))
|
||||
if not userMandate:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
|
|
@ -365,7 +356,7 @@ class CreateUserRequest(BaseModel):
|
|||
|
||||
@router.post("", response_model=User)
|
||||
@limiter.limit("10/minute")
|
||||
async def create_user(
|
||||
def create_user(
|
||||
request: Request,
|
||||
userData: CreateUserRequest = Body(...),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
|
|
@ -405,7 +396,7 @@ async def create_user(
|
|||
|
||||
@router.put("/{userId}", response_model=User)
|
||||
@limiter.limit("10/minute")
|
||||
async def update_user(
|
||||
def update_user(
|
||||
request: Request,
|
||||
userId: str = Path(..., description="ID of the user to update"),
|
||||
userData: User = Body(...),
|
||||
|
|
@ -427,11 +418,7 @@ async def update_user(
|
|||
|
||||
# MULTI-TENANT: Verify user is in the same mandate (unless SysAdmin)
|
||||
if context.mandateId and not context.isSysAdmin:
|
||||
from modules.datamodels.datamodelMembership import UserMandate
|
||||
userMandate = appInterface.db.getRecordset(UserMandate, recordFilter={
|
||||
"userId": userId,
|
||||
"mandateId": str(context.mandateId)
|
||||
})
|
||||
userMandate = appInterface.getUserMandate(userId, str(context.mandateId))
|
||||
if not userMandate:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
|
|
@ -451,7 +438,7 @@ async def update_user(
|
|||
|
||||
@router.post("/{userId}/reset-password")
|
||||
@limiter.limit("5/minute")
|
||||
async def reset_user_password(
|
||||
def reset_user_password(
|
||||
request: Request,
|
||||
userId: str = Path(..., description="ID of the user to reset password for"),
|
||||
newPassword: str = Body(..., embed=True),
|
||||
|
|
@ -482,11 +469,7 @@ async def reset_user_password(
|
|||
|
||||
# MULTI-TENANT: Verify user is in the same mandate (unless SysAdmin)
|
||||
if context.mandateId and not context.isSysAdmin:
|
||||
from modules.datamodels.datamodelMembership import UserMandate
|
||||
userMandate = appInterface.db.getRecordset(UserMandate, recordFilter={
|
||||
"userId": userId,
|
||||
"mandateId": str(context.mandateId)
|
||||
})
|
||||
userMandate = appInterface.getUserMandate(userId, str(context.mandateId))
|
||||
if not userMandate:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
|
|
@ -552,7 +535,7 @@ async def reset_user_password(
|
|||
|
||||
@router.post("/change-password")
|
||||
@limiter.limit("5/minute")
|
||||
async def change_password(
|
||||
def change_password(
|
||||
request: Request,
|
||||
currentPassword: str = Body(..., embed=True),
|
||||
newPassword: str = Body(..., embed=True),
|
||||
|
|
@ -631,7 +614,7 @@ async def change_password(
|
|||
|
||||
@router.post("/{userId}/send-password-link")
|
||||
@limiter.limit("10/minute")
|
||||
async def send_password_link(
|
||||
def send_password_link(
|
||||
request: Request,
|
||||
userId: str = Path(..., description="ID of the user to send password setup link"),
|
||||
frontendUrl: str = Body(..., embed=True),
|
||||
|
|
@ -664,11 +647,7 @@ async def send_password_link(
|
|||
|
||||
# MULTI-TENANT: Verify user is in the same mandate (unless SysAdmin)
|
||||
if context.mandateId and not context.isSysAdmin:
|
||||
from modules.datamodels.datamodelMembership import UserMandate
|
||||
userMandate = appInterface.db.getRecordset(UserMandate, recordFilter={
|
||||
"userId": userId,
|
||||
"mandateId": str(context.mandateId)
|
||||
})
|
||||
userMandate = appInterface.getUserMandate(userId, str(context.mandateId))
|
||||
if not userMandate:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
|
|
@ -770,7 +749,7 @@ Falls Sie diese Anforderung nicht erwartet haben, kontaktieren Sie bitte Ihren A
|
|||
|
||||
@router.delete("/{userId}", response_model=Dict[str, Any])
|
||||
@limiter.limit("10/minute")
|
||||
async def delete_user(
|
||||
def delete_user(
|
||||
request: Request,
|
||||
userId: str = Path(..., description="ID of the user to delete"),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
|
|
@ -791,11 +770,7 @@ async def delete_user(
|
|||
|
||||
# MULTI-TENANT: Verify user is in the same mandate (unless SysAdmin)
|
||||
if context.mandateId and not context.isSysAdmin:
|
||||
from modules.datamodels.datamodelMembership import UserMandate
|
||||
userMandate = appInterface.db.getRecordset(UserMandate, recordFilter={
|
||||
"userId": userId,
|
||||
"mandateId": str(context.mandateId)
|
||||
})
|
||||
userMandate = appInterface.getUserMandate(userId, str(context.mandateId))
|
||||
if not userMandate:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
|
|
@ -803,10 +778,9 @@ async def delete_user(
|
|||
)
|
||||
|
||||
# Delete UserMandate entries for this user first
|
||||
from modules.datamodels.datamodelMembership import UserMandate
|
||||
userMandates = appInterface.db.getRecordset(UserMandate, recordFilter={"userId": userId})
|
||||
userMandates = appInterface.getUserMandates(userId)
|
||||
for um in userMandates:
|
||||
appInterface.db.deleteRecord(UserMandate, um["id"])
|
||||
appInterface.deleteUserMandate(userId, str(um.mandateId))
|
||||
|
||||
success = appInterface.deleteUser(userId)
|
||||
if not success:
|
||||
|
|
|
|||
|
|
@ -50,7 +50,7 @@ def getServiceChat(currentUser: User):
|
|||
# Consolidated endpoint for getting all workflows
|
||||
@router.get("/", response_model=PaginatedResponse[ChatWorkflow])
|
||||
@limiter.limit("120/minute")
|
||||
async def get_workflows(
|
||||
def get_workflows(
|
||||
request: Request,
|
||||
pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
|
|
@ -123,7 +123,7 @@ async def get_workflows(
|
|||
|
||||
@router.get("/{workflowId}", response_model=ChatWorkflow)
|
||||
@limiter.limit("120/minute")
|
||||
async def get_workflow(
|
||||
def get_workflow(
|
||||
request: Request,
|
||||
workflowId: str = Path(..., description="ID of the workflow"),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
|
|
@ -152,7 +152,7 @@ async def get_workflow(
|
|||
|
||||
@router.put("/{workflowId}", response_model=ChatWorkflow)
|
||||
@limiter.limit("120/minute")
|
||||
async def update_workflow(
|
||||
def update_workflow(
|
||||
request: Request,
|
||||
workflowId: str = Path(..., description="ID of the workflow to update"),
|
||||
workflowData: Dict[str, Any] = Body(...),
|
||||
|
|
@ -163,16 +163,14 @@ async def update_workflow(
|
|||
# Get workflow interface with current user context
|
||||
workflowInterface = getInterface(currentUser)
|
||||
|
||||
# Get raw workflow data from database to check permissions
|
||||
workflows = workflowInterface.db.getRecordset(ChatWorkflow, recordFilter={"id": workflowId})
|
||||
if not workflows:
|
||||
# Get workflow using interface method to check permissions
|
||||
workflow = workflowInterface.getWorkflow(workflowId)
|
||||
if not workflow:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Workflow not found"
|
||||
)
|
||||
|
||||
workflow_data = workflows[0]
|
||||
|
||||
# Check if user has permission to update using RBAC
|
||||
if not workflowInterface.checkRbacPermission(ChatWorkflow, "update", workflowId):
|
||||
raise HTTPException(
|
||||
|
|
@ -202,7 +200,7 @@ async def update_workflow(
|
|||
# API Endpoint for workflow status
|
||||
@router.get("/{workflowId}/status", response_model=ChatWorkflow)
|
||||
@limiter.limit("120/minute")
|
||||
async def get_workflow_status(
|
||||
def get_workflow_status(
|
||||
request: Request,
|
||||
workflowId: str = Path(..., description="ID of the workflow"),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
|
|
@ -230,10 +228,53 @@ async def get_workflow_status(
|
|||
detail=f"Error getting workflow status: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
# API Endpoint for stopping a workflow
|
||||
@router.post("/{workflowId}/stop", response_model=ChatWorkflow)
|
||||
@limiter.limit("120/minute")
|
||||
async def stop_workflow(
|
||||
request: Request,
|
||||
workflowId: str = Path(..., description="ID of the workflow to stop"),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> ChatWorkflow:
|
||||
"""
|
||||
Stop a running workflow.
|
||||
This is a general endpoint that can be used by any feature to stop a workflow.
|
||||
"""
|
||||
try:
|
||||
from modules.workflows.automation import chatStop
|
||||
|
||||
# Get the workflow first to get mandateId
|
||||
interfaceChatDb = getServiceChat(currentUser)
|
||||
workflow = interfaceChatDb.getWorkflow(workflowId)
|
||||
|
||||
if not workflow:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Workflow with ID {workflowId} not found"
|
||||
)
|
||||
|
||||
mandateId = workflow.get("mandateId") if isinstance(workflow, dict) else getattr(workflow, "mandateId", None)
|
||||
|
||||
# Stop the workflow
|
||||
stoppedWorkflow = await chatStop(currentUser, workflowId, mandateId=mandateId)
|
||||
|
||||
return stoppedWorkflow
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping workflow: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error stopping workflow: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
# API Endpoint for workflow logs with selective data transfer
|
||||
@router.get("/{workflowId}/logs", response_model=PaginatedResponse[ChatLog])
|
||||
@limiter.limit("120/minute")
|
||||
async def get_workflow_logs(
|
||||
def get_workflow_logs(
|
||||
request: Request,
|
||||
workflowId: str = Path(..., description="ID of the workflow"),
|
||||
logId: Optional[str] = Query(None, description="Optional log ID to get only newer logs (legacy selective data transfer)"),
|
||||
|
|
@ -324,7 +365,7 @@ async def get_workflow_logs(
|
|||
# API Endpoint for workflow messages with selective data transfer
|
||||
@router.get("/{workflowId}/messages", response_model=PaginatedResponse[ChatMessage])
|
||||
@limiter.limit("120/minute")
|
||||
async def get_workflow_messages(
|
||||
def get_workflow_messages(
|
||||
request: Request,
|
||||
workflowId: str = Path(..., description="ID of the workflow"),
|
||||
messageId: Optional[str] = Query(None, description="Optional message ID to get only newer messages (legacy selective data transfer)"),
|
||||
|
|
@ -416,7 +457,7 @@ async def get_workflow_messages(
|
|||
# State 11: Workflow Reset/Deletion endpoint
|
||||
@router.delete("/{workflowId}", response_model=Dict[str, Any])
|
||||
@limiter.limit("120/minute")
|
||||
async def delete_workflow(
|
||||
def delete_workflow(
|
||||
request: Request,
|
||||
workflowId: str = Path(..., description="ID of the workflow to delete"),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
|
|
@ -475,7 +516,7 @@ async def delete_workflow(
|
|||
|
||||
@router.delete("/{workflowId}/messages/{messageId}", response_model=Dict[str, Any])
|
||||
@limiter.limit("120/minute")
|
||||
async def delete_workflow_message(
|
||||
def delete_workflow_message(
|
||||
request: Request,
|
||||
workflowId: str = Path(..., description="ID of the workflow"),
|
||||
messageId: str = Path(..., description="ID of the message to delete"),
|
||||
|
|
@ -525,7 +566,7 @@ async def delete_workflow_message(
|
|||
|
||||
@router.delete("/{workflowId}/messages/{messageId}/files/{fileId}", response_model=Dict[str, Any])
|
||||
@limiter.limit("120/minute")
|
||||
async def delete_file_from_message(
|
||||
def delete_file_from_message(
|
||||
request: Request,
|
||||
workflowId: str = Path(..., description="ID of the workflow"),
|
||||
messageId: str = Path(..., description="ID of the message"),
|
||||
|
|
@ -574,7 +615,7 @@ async def delete_file_from_message(
|
|||
|
||||
@router.get("/actions", response_model=Dict[str, Any])
|
||||
@limiter.limit("120/minute")
|
||||
async def get_all_actions(
|
||||
def get_all_actions(
|
||||
request: Request,
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> Dict[str, Any]:
|
||||
|
|
@ -644,7 +685,7 @@ async def get_all_actions(
|
|||
|
||||
@router.get("/actions/{method}", response_model=Dict[str, Any])
|
||||
@limiter.limit("120/minute")
|
||||
async def get_method_actions(
|
||||
def get_method_actions(
|
||||
request: Request,
|
||||
method: str = Path(..., description="Method name (e.g., 'outlook', 'sharepoint')"),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
|
|
@ -727,7 +768,7 @@ async def get_method_actions(
|
|||
|
||||
@router.get("/actions/{method}/{action}", response_model=Dict[str, Any])
|
||||
@limiter.limit("120/minute")
|
||||
async def get_action_schema(
|
||||
def get_action_schema(
|
||||
request: Request,
|
||||
method: str = Path(..., description="Method name (e.g., 'outlook', 'sharepoint')"),
|
||||
action: str = Path(..., description="Action name (e.g., 'readEmails', 'uploadDocument')"),
|
||||
|
|
|
|||
|
|
@ -74,7 +74,7 @@ class DeletionResult(BaseModel):
|
|||
|
||||
@router.get("/data-export", response_model=DataExportResponse)
|
||||
@limiter.limit("5/minute")
|
||||
async def export_user_data(
|
||||
def export_user_data(
|
||||
request: Request,
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> DataExportResponse:
|
||||
|
|
@ -109,96 +109,73 @@ async def export_user_data(
|
|||
"authenticationAuthority": str(getattr(currentUser, "authenticationAuthority", ""))
|
||||
}
|
||||
|
||||
# Mandate memberships
|
||||
from modules.datamodels.datamodelMembership import UserMandate
|
||||
userMandates = rootInterface.db.getRecordset(
|
||||
UserMandate,
|
||||
recordFilter={"userId": str(currentUser.id)}
|
||||
)
|
||||
# Mandate memberships using interface method
|
||||
userMandates = rootInterface.getUserMandates(str(currentUser.id))
|
||||
|
||||
mandates = []
|
||||
for um in userMandates:
|
||||
mandateId = um.get("mandateId")
|
||||
mandateId = um.mandateId
|
||||
|
||||
# Get mandate details
|
||||
mandateRecords = rootInterface.db.getRecordset(
|
||||
Mandate,
|
||||
recordFilter={"id": mandateId}
|
||||
)
|
||||
mandateName = mandateRecords[0].get("name") if mandateRecords else "Unknown"
|
||||
# Get mandate details using interface method
|
||||
mandate = rootInterface.getMandate(mandateId)
|
||||
mandateName = mandate.name if mandate else "Unknown"
|
||||
|
||||
# Get roles for this membership
|
||||
roleIds = rootInterface.getRoleIdsForUserMandate(um.get("id"))
|
||||
roleIds = rootInterface.getRoleIdsForUserMandate(um.id)
|
||||
|
||||
mandates.append({
|
||||
"userMandateId": um.get("id"),
|
||||
"userMandateId": um.id,
|
||||
"mandateId": mandateId,
|
||||
"mandateName": mandateName,
|
||||
"enabled": um.get("enabled", True),
|
||||
"enabled": um.enabled,
|
||||
"roleIds": roleIds,
|
||||
"joinedAt": um.get("createdAt")
|
||||
"joinedAt": um.createdAt
|
||||
})
|
||||
|
||||
# Feature access records
|
||||
from modules.datamodels.datamodelMembership import FeatureAccess
|
||||
featureAccesses = rootInterface.db.getRecordset(
|
||||
FeatureAccess,
|
||||
recordFilter={"userId": str(currentUser.id)}
|
||||
)
|
||||
# Feature access records using interface method
|
||||
featureAccesses = rootInterface.getFeatureAccessesForUser(str(currentUser.id))
|
||||
|
||||
featureAccessList = []
|
||||
for fa in featureAccesses:
|
||||
instanceId = fa.get("featureInstanceId")
|
||||
instanceId = fa.featureInstanceId
|
||||
|
||||
# Get instance details
|
||||
from modules.datamodels.datamodelFeatures import FeatureInstance
|
||||
instanceRecords = rootInterface.db.getRecordset(
|
||||
FeatureInstance,
|
||||
recordFilter={"id": instanceId}
|
||||
)
|
||||
# Get instance details using interface method
|
||||
instance = rootInterface.getFeatureInstance(instanceId)
|
||||
|
||||
instanceInfo = instanceRecords[0] if instanceRecords else {}
|
||||
roleIds = rootInterface.getRoleIdsForFeatureAccess(fa.get("id"))
|
||||
roleIds = rootInterface.getRoleIdsForFeatureAccess(fa.id)
|
||||
|
||||
featureAccessList.append({
|
||||
"featureAccessId": fa.get("id"),
|
||||
"featureAccessId": fa.id,
|
||||
"featureInstanceId": instanceId,
|
||||
"featureCode": instanceInfo.get("featureCode"),
|
||||
"instanceLabel": instanceInfo.get("label"),
|
||||
"enabled": fa.get("enabled", True),
|
||||
"featureCode": instance.featureCode if instance else None,
|
||||
"instanceLabel": instance.label if instance else None,
|
||||
"enabled": fa.enabled,
|
||||
"roleIds": roleIds
|
||||
})
|
||||
|
||||
# Invitations created by user
|
||||
from modules.datamodels.datamodelInvitation import Invitation
|
||||
invitationsCreated = rootInterface.db.getRecordset(
|
||||
Invitation,
|
||||
recordFilter={"createdBy": str(currentUser.id)}
|
||||
)
|
||||
# Invitations created by user using interface method
|
||||
invitationsCreated = rootInterface.getInvitationsByCreator(str(currentUser.id))
|
||||
|
||||
invitationsCreatedList = [
|
||||
{
|
||||
"id": inv.get("id"),
|
||||
"mandateId": inv.get("mandateId"),
|
||||
"createdAt": inv.get("createdAt"),
|
||||
"expiresAt": inv.get("expiresAt"),
|
||||
"maxUses": inv.get("maxUses"),
|
||||
"currentUses": inv.get("currentUses")
|
||||
"id": inv.id,
|
||||
"mandateId": inv.mandateId,
|
||||
"createdAt": inv.createdAt,
|
||||
"expiresAt": inv.expiresAt,
|
||||
"maxUses": inv.maxUses,
|
||||
"currentUses": inv.currentUses
|
||||
}
|
||||
for inv in invitationsCreated
|
||||
]
|
||||
|
||||
# Invitations used by user
|
||||
invitationsUsed = rootInterface.db.getRecordset(
|
||||
Invitation,
|
||||
recordFilter={"usedBy": str(currentUser.id)}
|
||||
)
|
||||
# Invitations used by user using interface method
|
||||
invitationsUsed = rootInterface.getInvitationsByUsedBy(str(currentUser.id))
|
||||
|
||||
invitationsUsedList = [
|
||||
{
|
||||
"id": inv.get("id"),
|
||||
"mandateId": inv.get("mandateId"),
|
||||
"usedAt": inv.get("usedAt")
|
||||
"id": inv.id,
|
||||
"mandateId": inv.mandateId,
|
||||
"usedAt": inv.usedAt
|
||||
}
|
||||
for inv in invitationsUsed
|
||||
]
|
||||
|
|
@ -238,7 +215,7 @@ async def export_user_data(
|
|||
|
||||
@router.get("/data-portability")
|
||||
@limiter.limit("5/minute")
|
||||
async def export_portable_data(
|
||||
def export_portable_data(
|
||||
request: Request,
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> JSONResponse:
|
||||
|
|
@ -262,26 +239,18 @@ async def export_portable_data(
|
|||
"additionalProperty": []
|
||||
}
|
||||
|
||||
# Add mandate memberships as organization affiliations
|
||||
from modules.datamodels.datamodelMembership import UserMandate
|
||||
userMandates = rootInterface.db.getRecordset(
|
||||
UserMandate,
|
||||
recordFilter={"userId": str(currentUser.id)}
|
||||
)
|
||||
# Add mandate memberships as organization affiliations using interface method
|
||||
userMandates = rootInterface.getUserMandates(str(currentUser.id))
|
||||
|
||||
affiliations = []
|
||||
for um in userMandates:
|
||||
mandateRecords = rootInterface.db.getRecordset(
|
||||
Mandate,
|
||||
recordFilter={"id": um.get("mandateId")}
|
||||
)
|
||||
if mandateRecords:
|
||||
mandate = mandateRecords[0]
|
||||
mandate = rootInterface.getMandate(um.mandateId)
|
||||
if mandate:
|
||||
affiliations.append({
|
||||
"@type": "Organization",
|
||||
"identifier": um.get("mandateId"),
|
||||
"name": mandate.get("name"),
|
||||
"membershipActive": um.get("enabled", True)
|
||||
"identifier": um.mandateId,
|
||||
"name": mandate.name,
|
||||
"membershipActive": um.enabled
|
||||
})
|
||||
|
||||
if affiliations:
|
||||
|
|
@ -327,7 +296,7 @@ async def export_portable_data(
|
|||
|
||||
@router.delete("/", response_model=DeletionResult)
|
||||
@limiter.limit("1/hour")
|
||||
async def delete_account(
|
||||
def delete_account(
|
||||
request: Request,
|
||||
confirmDeletion: bool = False,
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
|
|
@ -370,15 +339,12 @@ async def delete_account(
|
|||
# Step 2: Revoke invitations BEFORE generic deletion (business logic)
|
||||
rootInterface = getRootInterface()
|
||||
from modules.datamodels.datamodelInvitation import Invitation
|
||||
userInvitations = rootInterface.db.getRecordset(
|
||||
Invitation,
|
||||
recordFilter={"createdBy": str(currentUser.id)}
|
||||
)
|
||||
userInvitations = rootInterface.getInvitationsByCreator(str(currentUser.id))
|
||||
|
||||
for inv in userInvitations:
|
||||
rootInterface.db.recordModify(
|
||||
Invitation,
|
||||
inv.get("id"),
|
||||
inv.id,
|
||||
{"revokedAt": getUtcTimestamp()}
|
||||
)
|
||||
|
||||
|
|
@ -425,7 +391,7 @@ async def delete_account(
|
|||
|
||||
@router.get("/consent-info", response_model=Dict[str, Any])
|
||||
@limiter.limit("30/minute")
|
||||
async def get_consent_info(
|
||||
def get_consent_info(
|
||||
request: Request,
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> Dict[str, Any]:
|
||||
|
|
|
|||
|
|
@ -41,6 +41,7 @@ class InvitationCreate(BaseModel):
|
|||
email: Optional[str] = Field(None, description="Email address to send invitation link (optional)")
|
||||
roleIds: List[str] = Field(..., description="Role IDs to assign to the invited user")
|
||||
featureInstanceId: Optional[str] = Field(None, description="Optional feature instance access")
|
||||
frontendUrl: str = Field(..., description="Frontend URL for building the invite link (provided by frontend)")
|
||||
expiresInHours: int = Field(
|
||||
72,
|
||||
ge=1,
|
||||
|
|
@ -94,7 +95,7 @@ class InvitationValidation(BaseModel):
|
|||
|
||||
@router.post("/", response_model=InvitationResponse)
|
||||
@limiter.limit("30/minute")
|
||||
async def create_invitation(
|
||||
def create_invitation(
|
||||
request: Request,
|
||||
data: InvitationCreate,
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
|
|
@ -131,17 +132,14 @@ async def create_invitation(
|
|||
|
||||
# Validate role IDs exist and belong to this mandate or are global
|
||||
for roleId in data.roleIds:
|
||||
from modules.datamodels.datamodelRbac import Role
|
||||
roleRecords = rootInterface.db.getRecordset(Role, recordFilter={"id": roleId})
|
||||
if not roleRecords:
|
||||
role = rootInterface.getRole(roleId)
|
||||
if not role:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Role '{roleId}' not found"
|
||||
)
|
||||
role = roleRecords[0]
|
||||
# Role must be global or belong to this mandate
|
||||
roleMandateId = role.get("mandateId")
|
||||
if roleMandateId and str(roleMandateId) != str(context.mandateId):
|
||||
if role.mandateId and str(role.mandateId) != str(context.mandateId):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Role '{roleId}' belongs to a different mandate"
|
||||
|
|
@ -149,18 +147,13 @@ async def create_invitation(
|
|||
|
||||
# Validate feature instance if provided
|
||||
if data.featureInstanceId:
|
||||
from modules.datamodels.datamodelFeatures import FeatureInstance
|
||||
instanceRecords = rootInterface.db.getRecordset(
|
||||
FeatureInstance,
|
||||
recordFilter={"id": data.featureInstanceId}
|
||||
)
|
||||
if not instanceRecords:
|
||||
instance = rootInterface.getFeatureInstance(data.featureInstanceId)
|
||||
if not instance:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Feature instance '{data.featureInstanceId}' not found"
|
||||
)
|
||||
instance = instanceRecords[0]
|
||||
if str(instance.get("mandateId")) != str(context.mandateId):
|
||||
if str(instance.mandateId) != str(context.mandateId):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Feature instance belongs to a different mandate"
|
||||
|
|
@ -186,24 +179,18 @@ async def create_invitation(
|
|||
if not createdRecord:
|
||||
raise ValueError("Failed to create invitation record")
|
||||
|
||||
# Build invite URL
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
frontendUrl = APP_CONFIG.get("APP_FRONTEND_URL", "http://localhost:8080")
|
||||
inviteUrl = f"{frontendUrl}/invite/{invitation.token}"
|
||||
# Build invite URL using frontend URL provided by the caller
|
||||
baseUrl = data.frontendUrl.rstrip("/")
|
||||
inviteUrl = f"{baseUrl}/invite/{invitation.token}"
|
||||
|
||||
# Send email if email address is provided
|
||||
emailSent = False
|
||||
if data.email:
|
||||
try:
|
||||
from modules.connectors.connectorMessagingEmail import ConnectorMessagingEmail
|
||||
from modules.datamodels.datamodelUam import Mandate
|
||||
|
||||
# Get mandate name for the email
|
||||
mandateRecords = rootInterface.db.getRecordset(
|
||||
Mandate,
|
||||
recordFilter={"id": str(context.mandateId)}
|
||||
)
|
||||
mandateName = mandateRecords[0].get("name", "PowerOn") if mandateRecords else "PowerOn"
|
||||
mandate = rootInterface.getMandate(str(context.mandateId))
|
||||
mandateName = (mandate.label or mandate.name) if mandate else "PowerOn"
|
||||
|
||||
emailConnector = ConnectorMessagingEmail()
|
||||
emailSubject = f"Einladung zu {mandateName}"
|
||||
|
|
@ -259,14 +246,10 @@ async def create_invitation(
|
|||
existingUser = rootInterface.getUserByUsername(data.targetUsername)
|
||||
if existingUser:
|
||||
from modules.routes.routeNotifications import createInvitationNotification
|
||||
from modules.datamodels.datamodelUam import Mandate
|
||||
|
||||
# Get mandate name for notification
|
||||
mandateRecords = rootInterface.db.getRecordset(
|
||||
Mandate,
|
||||
recordFilter={"id": str(context.mandateId)}
|
||||
)
|
||||
mandateName = mandateRecords[0].get("mandateLabel", "PowerOn") if mandateRecords else "PowerOn"
|
||||
mandate = rootInterface.getMandate(str(context.mandateId))
|
||||
mandateName = (mandate.label or mandate.name) if mandate else "PowerOn"
|
||||
inviterName = context.user.fullName or context.user.username
|
||||
|
||||
createInvitationNotification(
|
||||
|
|
@ -317,8 +300,9 @@ async def create_invitation(
|
|||
|
||||
@router.get("/", response_model=List[Dict[str, Any]])
|
||||
@limiter.limit("60/minute")
|
||||
async def list_invitations(
|
||||
def list_invitations(
|
||||
request: Request,
|
||||
frontendUrl: str = Query(..., description="Frontend URL for building invite links (provided by frontend)"),
|
||||
includeUsed: bool = Query(False, description="Include already used invitations"),
|
||||
includeExpired: bool = Query(False, description="Include expired invitations"),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
|
|
@ -348,38 +332,37 @@ async def list_invitations(
|
|||
try:
|
||||
rootInterface = getRootInterface()
|
||||
|
||||
# Get all invitations for this mandate
|
||||
allInvitations = rootInterface.db.getRecordset(
|
||||
Invitation,
|
||||
recordFilter={"mandateId": str(context.mandateId)}
|
||||
)
|
||||
# Get all invitations for this mandate (Pydantic models)
|
||||
allInvitations = rootInterface.getInvitationsByMandate(str(context.mandateId))
|
||||
|
||||
currentTime = getUtcTimestamp()
|
||||
result = []
|
||||
|
||||
for inv in allInvitations:
|
||||
# Skip revoked invitations
|
||||
if inv.get("revokedAt"):
|
||||
if inv.revokedAt:
|
||||
continue
|
||||
|
||||
# Filter by usage
|
||||
if not includeUsed and inv.get("currentUses", 0) >= inv.get("maxUses", 1):
|
||||
currentUses = inv.currentUses or 0
|
||||
maxUses = inv.maxUses or 1
|
||||
if not includeUsed and currentUses >= maxUses:
|
||||
continue
|
||||
|
||||
# Filter by expiration
|
||||
if not includeExpired and inv.get("expiresAt", 0) < currentTime:
|
||||
expiresAt = inv.expiresAt or 0
|
||||
if not includeExpired and expiresAt < currentTime:
|
||||
continue
|
||||
|
||||
# Build invite URL
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
frontendUrl = APP_CONFIG.get("APP_FRONTEND_URL", "http://localhost:8080")
|
||||
inviteUrl = f"{frontendUrl}/invite/{inv.get('token')}"
|
||||
# Build invite URL using frontend URL provided by the caller
|
||||
baseUrl = frontendUrl.rstrip("/")
|
||||
inviteUrl = f"{baseUrl}/invite/{inv.token}"
|
||||
|
||||
result.append({
|
||||
**{k: v for k, v in inv.items() if not k.startswith("_")},
|
||||
**inv.model_dump(),
|
||||
"inviteUrl": inviteUrl,
|
||||
"isExpired": inv.get("expiresAt", 0) < currentTime,
|
||||
"isUsedUp": inv.get("currentUses", 0) >= inv.get("maxUses", 1)
|
||||
"isExpired": expiresAt < currentTime,
|
||||
"isUsedUp": currentUses >= maxUses
|
||||
})
|
||||
|
||||
return result
|
||||
|
|
@ -396,7 +379,7 @@ async def list_invitations(
|
|||
|
||||
@router.delete("/{invitationId}", response_model=Dict[str, str])
|
||||
@limiter.limit("30/minute")
|
||||
async def revoke_invitation(
|
||||
def revoke_invitation(
|
||||
request: Request,
|
||||
invitationId: str,
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
|
|
@ -425,29 +408,24 @@ async def revoke_invitation(
|
|||
try:
|
||||
rootInterface = getRootInterface()
|
||||
|
||||
# Get invitation
|
||||
invitationRecords = rootInterface.db.getRecordset(
|
||||
Invitation,
|
||||
recordFilter={"id": invitationId}
|
||||
)
|
||||
# Get invitation (Pydantic model)
|
||||
invitation = rootInterface.getInvitation(invitationId)
|
||||
|
||||
if not invitationRecords:
|
||||
if not invitation:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Invitation '{invitationId}' not found"
|
||||
)
|
||||
|
||||
invitation = invitationRecords[0]
|
||||
|
||||
# Verify mandate access
|
||||
if str(invitation.get("mandateId")) != str(context.mandateId):
|
||||
if str(invitation.mandateId) != str(context.mandateId):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied to this invitation"
|
||||
)
|
||||
|
||||
# Already revoked?
|
||||
if invitation.get("revokedAt"):
|
||||
if invitation.revokedAt:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invitation is already revoked"
|
||||
|
|
@ -480,7 +458,7 @@ async def revoke_invitation(
|
|||
|
||||
@router.get("/validate/{token}", response_model=InvitationValidation)
|
||||
@limiter.limit("30/minute")
|
||||
async def validate_invitation(
|
||||
def validate_invitation(
|
||||
request: Request,
|
||||
token: str
|
||||
) -> InvitationValidation:
|
||||
|
|
@ -496,13 +474,10 @@ async def validate_invitation(
|
|||
try:
|
||||
rootInterface = getRootInterface()
|
||||
|
||||
# Find invitation by token
|
||||
invitationRecords = rootInterface.db.getRecordset(
|
||||
Invitation,
|
||||
recordFilter={"token": token}
|
||||
)
|
||||
# Find invitation by token (Pydantic model)
|
||||
invitation = rootInterface.getInvitationByToken(token)
|
||||
|
||||
if not invitationRecords:
|
||||
if not invitation:
|
||||
return InvitationValidation(
|
||||
valid=False,
|
||||
reason="Invitation not found",
|
||||
|
|
@ -511,10 +486,8 @@ async def validate_invitation(
|
|||
roleIds=[]
|
||||
)
|
||||
|
||||
invitation = invitationRecords[0]
|
||||
|
||||
# Check if revoked
|
||||
if invitation.get("revokedAt"):
|
||||
if invitation.revokedAt:
|
||||
return InvitationValidation(
|
||||
valid=False,
|
||||
reason="Invitation has been revoked",
|
||||
|
|
@ -525,7 +498,8 @@ async def validate_invitation(
|
|||
|
||||
# Check if expired
|
||||
currentTime = getUtcTimestamp()
|
||||
if invitation.get("expiresAt", 0) < currentTime:
|
||||
expiresAt = invitation.expiresAt or 0
|
||||
if expiresAt < currentTime:
|
||||
return InvitationValidation(
|
||||
valid=False,
|
||||
reason="Invitation has expired",
|
||||
|
|
@ -535,7 +509,9 @@ async def validate_invitation(
|
|||
)
|
||||
|
||||
# Check if used up
|
||||
if invitation.get("currentUses", 0) >= invitation.get("maxUses", 1):
|
||||
currentUses = invitation.currentUses or 0
|
||||
maxUses = invitation.maxUses or 1
|
||||
if currentUses >= maxUses:
|
||||
return InvitationValidation(
|
||||
valid=False,
|
||||
reason="Invitation has reached maximum uses",
|
||||
|
|
@ -545,34 +521,29 @@ async def validate_invitation(
|
|||
)
|
||||
|
||||
# Get additional info for display
|
||||
mandateId = invitation.get("mandateId")
|
||||
mandateId = invitation.mandateId
|
||||
mandateName = None
|
||||
roleLabels = []
|
||||
targetUsername = invitation.get("targetUsername")
|
||||
targetUsername = invitation.targetUsername
|
||||
|
||||
# Get mandate name
|
||||
from modules.datamodels.datamodelUam import Mandate
|
||||
mandateRecords = rootInterface.db.getRecordset(
|
||||
Mandate,
|
||||
recordFilter={"id": mandateId}
|
||||
)
|
||||
if mandateRecords:
|
||||
mandateName = mandateRecords[0].get("name")
|
||||
mandate = rootInterface.getMandate(str(mandateId)) if mandateId else None
|
||||
if mandate:
|
||||
mandateName = mandate.label or mandate.name
|
||||
|
||||
# Get role names
|
||||
roleIds = invitation.get("roleIds", [])
|
||||
from modules.datamodels.datamodelRbac import Role
|
||||
roleIds = invitation.roleIds or []
|
||||
for roleId in roleIds:
|
||||
roleRecords = rootInterface.db.getRecordset(Role, recordFilter={"id": roleId})
|
||||
if roleRecords:
|
||||
roleLabels.append(roleRecords[0].get("roleLabel", roleId))
|
||||
role = rootInterface.getRole(roleId)
|
||||
if role:
|
||||
roleLabels.append(role.roleLabel)
|
||||
|
||||
return InvitationValidation(
|
||||
valid=True,
|
||||
reason=None,
|
||||
mandateId=mandateId,
|
||||
mandateId=str(mandateId) if mandateId else None,
|
||||
mandateName=mandateName,
|
||||
featureInstanceId=invitation.get("featureInstanceId"),
|
||||
featureInstanceId=str(invitation.featureInstanceId) if invitation.featureInstanceId else None,
|
||||
roleIds=roleIds,
|
||||
roleLabels=roleLabels,
|
||||
targetUsername=targetUsername
|
||||
|
|
@ -591,7 +562,7 @@ async def validate_invitation(
|
|||
|
||||
@router.post("/accept/{token}", response_model=Dict[str, Any])
|
||||
@limiter.limit("10/minute")
|
||||
async def accept_invitation(
|
||||
def accept_invitation(
|
||||
request: Request,
|
||||
token: str,
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
|
|
@ -608,42 +579,40 @@ async def accept_invitation(
|
|||
try:
|
||||
rootInterface = getRootInterface()
|
||||
|
||||
# Find invitation by token
|
||||
invitationRecords = rootInterface.db.getRecordset(
|
||||
Invitation,
|
||||
recordFilter={"token": token}
|
||||
)
|
||||
# Find invitation by token (Pydantic model)
|
||||
invitation = rootInterface.getInvitationByToken(token)
|
||||
|
||||
if not invitationRecords:
|
||||
if not invitation:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Invitation not found"
|
||||
)
|
||||
|
||||
invitation = invitationRecords[0]
|
||||
|
||||
# Validate invitation
|
||||
if invitation.get("revokedAt"):
|
||||
if invitation.revokedAt:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invitation has been revoked"
|
||||
)
|
||||
|
||||
currentTime = getUtcTimestamp()
|
||||
if invitation.get("expiresAt", 0) < currentTime:
|
||||
expiresAt = invitation.expiresAt or 0
|
||||
if expiresAt < currentTime:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invitation has expired"
|
||||
)
|
||||
|
||||
if invitation.get("currentUses", 0) >= invitation.get("maxUses", 1):
|
||||
currentUses = invitation.currentUses or 0
|
||||
maxUses = invitation.maxUses or 1
|
||||
if currentUses >= maxUses:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invitation has reached maximum uses"
|
||||
)
|
||||
|
||||
# Validate username matches - the invitation is bound to a specific user
|
||||
targetUsername = invitation.get("targetUsername")
|
||||
targetUsername = invitation.targetUsername
|
||||
if targetUsername and currentUser.username != targetUsername:
|
||||
logger.warning(
|
||||
f"User {currentUser.username} tried to accept invitation meant for {targetUsername}"
|
||||
|
|
@ -653,9 +622,9 @@ async def accept_invitation(
|
|||
detail=f"Diese Einladung ist für Benutzer '{targetUsername}' bestimmt"
|
||||
)
|
||||
|
||||
mandateId = invitation.get("mandateId")
|
||||
roleIds = invitation.get("roleIds", [])
|
||||
featureInstanceId = invitation.get("featureInstanceId")
|
||||
mandateId = str(invitation.mandateId) if invitation.mandateId else None
|
||||
roleIds = invitation.roleIds or []
|
||||
featureInstanceId = str(invitation.featureInstanceId) if invitation.featureInstanceId else None
|
||||
|
||||
# Check if user is already a member
|
||||
existingMembership = rootInterface.getUserMandate(str(currentUser.id), mandateId)
|
||||
|
|
@ -744,22 +713,19 @@ def _hasMandateAdminRole(context: RequestContext) -> bool:
|
|||
|
||||
try:
|
||||
rootInterface = getRootInterface()
|
||||
from modules.datamodels.datamodelRbac import Role
|
||||
|
||||
for roleId in context.roleIds:
|
||||
roleRecords = rootInterface.db.getRecordset(Role, recordFilter={"id": roleId})
|
||||
if roleRecords:
|
||||
role = roleRecords[0]
|
||||
roleLabel = role.get("roleLabel", "")
|
||||
# Admin role at mandate level
|
||||
if roleLabel == "admin" and role.get("mandateId") and not role.get("featureInstanceId"):
|
||||
role = rootInterface.getRole(roleId)
|
||||
if role:
|
||||
# Admin role at mandate level (not feature-instance level)
|
||||
if role.roleLabel == "admin" and role.mandateId and not role.featureInstanceId:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking mandate admin role: {e}")
|
||||
return False
|
||||
return False # Fail-safe: no access on error
|
||||
|
||||
|
||||
def _isInstanceRole(interface, roleId: str, featureInstanceId: str) -> bool:
|
||||
|
|
@ -767,11 +733,9 @@ def _isInstanceRole(interface, roleId: str, featureInstanceId: str) -> bool:
|
|||
Check if a role belongs to a specific feature instance.
|
||||
"""
|
||||
try:
|
||||
from modules.datamodels.datamodelRbac import Role
|
||||
roleRecords = interface.db.getRecordset(Role, recordFilter={"id": roleId})
|
||||
if roleRecords:
|
||||
role = roleRecords[0]
|
||||
return str(role.get("featureInstanceId", "")) == str(featureInstanceId)
|
||||
role = interface.getRole(roleId)
|
||||
if role:
|
||||
return str(role.featureInstanceId or "") == str(featureInstanceId)
|
||||
return False
|
||||
except Exception:
|
||||
return False
|
||||
return False # Fail-safe: assume not instance role on error
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ router = APIRouter(
|
|||
|
||||
@router.get("/subscriptions", response_model=PaginatedResponse[MessagingSubscription])
|
||||
@limiter.limit("60/minute")
|
||||
async def get_subscriptions(
|
||||
def get_subscriptions(
|
||||
request: Request,
|
||||
pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
|
|
@ -79,7 +79,7 @@ async def get_subscriptions(
|
|||
|
||||
@router.post("/subscriptions", response_model=MessagingSubscription)
|
||||
@limiter.limit("60/minute")
|
||||
async def create_subscription(
|
||||
def create_subscription(
|
||||
request: Request,
|
||||
subscription: MessagingSubscription,
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
|
|
@ -95,7 +95,7 @@ async def create_subscription(
|
|||
|
||||
@router.get("/subscriptions/{subscriptionId}", response_model=MessagingSubscription)
|
||||
@limiter.limit("60/minute")
|
||||
async def get_subscription(
|
||||
def get_subscription(
|
||||
request: Request,
|
||||
subscriptionId: str = Path(..., description="ID of the subscription"),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
|
|
@ -115,7 +115,7 @@ async def get_subscription(
|
|||
|
||||
@router.put("/subscriptions/{subscriptionId}", response_model=MessagingSubscription)
|
||||
@limiter.limit("60/minute")
|
||||
async def update_subscription(
|
||||
def update_subscription(
|
||||
request: Request,
|
||||
subscriptionId: str = Path(..., description="ID of the subscription to update"),
|
||||
subscriptionData: MessagingSubscription = Body(...),
|
||||
|
|
@ -145,7 +145,7 @@ async def update_subscription(
|
|||
|
||||
@router.delete("/subscriptions/{subscriptionId}", response_model=Dict[str, Any])
|
||||
@limiter.limit("60/minute")
|
||||
async def delete_subscription(
|
||||
def delete_subscription(
|
||||
request: Request,
|
||||
subscriptionId: str = Path(..., description="ID of the subscription to delete"),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
|
|
@ -174,7 +174,7 @@ async def delete_subscription(
|
|||
|
||||
@router.get("/subscriptions/{subscriptionId}/registrations", response_model=PaginatedResponse[MessagingSubscriptionRegistration])
|
||||
@limiter.limit("60/minute")
|
||||
async def get_subscription_registrations(
|
||||
def get_subscription_registrations(
|
||||
request: Request,
|
||||
subscriptionId: str = Path(..., description="ID of the subscription"),
|
||||
pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"),
|
||||
|
|
@ -219,7 +219,7 @@ async def get_subscription_registrations(
|
|||
|
||||
@router.post("/subscriptions/{subscriptionId}/subscribe", response_model=MessagingSubscriptionRegistration)
|
||||
@limiter.limit("60/minute")
|
||||
async def subscribe_user(
|
||||
def subscribe_user(
|
||||
request: Request,
|
||||
subscriptionId: str = Path(..., description="ID of the subscription"),
|
||||
channel: MessagingChannel = Body(..., embed=True),
|
||||
|
|
@ -241,7 +241,7 @@ async def subscribe_user(
|
|||
|
||||
@router.delete("/subscriptions/{subscriptionId}/unsubscribe", response_model=Dict[str, Any])
|
||||
@limiter.limit("60/minute")
|
||||
async def unsubscribe_user(
|
||||
def unsubscribe_user(
|
||||
request: Request,
|
||||
subscriptionId: str = Path(..., description="ID of the subscription"),
|
||||
channel: MessagingChannel = Body(..., embed=True),
|
||||
|
|
@ -267,7 +267,7 @@ async def unsubscribe_user(
|
|||
|
||||
@router.get("/registrations", response_model=PaginatedResponse[MessagingSubscriptionRegistration])
|
||||
@limiter.limit("60/minute")
|
||||
async def get_my_registrations(
|
||||
def get_my_registrations(
|
||||
request: Request,
|
||||
pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
|
|
@ -311,7 +311,7 @@ async def get_my_registrations(
|
|||
|
||||
@router.put("/registrations/{registrationId}", response_model=MessagingSubscriptionRegistration)
|
||||
@limiter.limit("60/minute")
|
||||
async def update_registration(
|
||||
def update_registration(
|
||||
request: Request,
|
||||
registrationId: str = Path(..., description="ID of the registration to update"),
|
||||
registrationData: MessagingSubscriptionRegistration = Body(...),
|
||||
|
|
@ -341,7 +341,7 @@ async def update_registration(
|
|||
|
||||
@router.delete("/registrations/{registrationId}", response_model=Dict[str, Any])
|
||||
@limiter.limit("60/minute")
|
||||
async def delete_registration(
|
||||
def delete_registration(
|
||||
request: Request,
|
||||
registrationId: str = Path(..., description="ID of the registration to delete"),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
|
|
@ -376,7 +376,7 @@ def _getTriggerKey(request: Request) -> str:
|
|||
|
||||
@router.post("/trigger/{subscriptionId}", response_model=MessagingSubscriptionExecutionResult)
|
||||
@limiter.limit("60/minute", key_func=_getTriggerKey)
|
||||
async def trigger_subscription(
|
||||
def trigger_subscription(
|
||||
request: Request,
|
||||
subscriptionId: str = Path(..., description="ID of the subscription to trigger"),
|
||||
eventParameters: Dict[str, Any] = Body(...),
|
||||
|
|
@ -421,10 +421,9 @@ def _hasTriggerPermission(context: RequestContext) -> bool:
|
|||
rootInterface = getRootInterface()
|
||||
|
||||
for roleId in context.roleIds:
|
||||
roleRecords = rootInterface.db.getRecordset(Role, recordFilter={"id": roleId})
|
||||
if roleRecords:
|
||||
role = roleRecords[0]
|
||||
roleLabel = role.get("roleLabel", "")
|
||||
role = rootInterface.getRole(roleId)
|
||||
if role:
|
||||
roleLabel = role.roleLabel
|
||||
# Admin role at mandate level or system admin
|
||||
if roleLabel in ("admin", "sysadmin"):
|
||||
return True
|
||||
|
|
@ -440,7 +439,7 @@ def _hasTriggerPermission(context: RequestContext) -> bool:
|
|||
|
||||
@router.get("/deliveries", response_model=PaginatedResponse[MessagingDelivery])
|
||||
@limiter.limit("60/minute")
|
||||
async def get_deliveries(
|
||||
def get_deliveries(
|
||||
request: Request,
|
||||
subscriptionId: Optional[str] = Query(None, description="Filter by subscription ID"),
|
||||
pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"),
|
||||
|
|
@ -486,7 +485,7 @@ async def get_deliveries(
|
|||
|
||||
@router.get("/deliveries/{deliveryId}", response_model=MessagingDelivery)
|
||||
@limiter.limit("60/minute")
|
||||
async def get_delivery(
|
||||
def get_delivery(
|
||||
request: Request,
|
||||
deliveryId: str = Path(..., description="ID of the delivery"),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
|
|
|
|||
|
|
@ -120,7 +120,7 @@ def createInvitationNotification(
|
|||
|
||||
@router.get("", response_model=List[Dict[str, Any]])
|
||||
@limiter.limit("60/minute")
|
||||
async def getNotifications(
|
||||
def getNotifications(
|
||||
request: Request,
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
status: Optional[str] = None,
|
||||
|
|
@ -137,23 +137,19 @@ async def getNotifications(
|
|||
|
||||
# Build filter
|
||||
recordFilter = {"userId": str(currentUser.id)}
|
||||
if status:
|
||||
recordFilter["status"] = status
|
||||
if type:
|
||||
recordFilter["type"] = type
|
||||
|
||||
# Get notifications
|
||||
notifications = rootInterface.db.getRecordset(
|
||||
model_class=UserNotification,
|
||||
recordFilter=recordFilter
|
||||
# Get notifications (Pydantic models, sorted and limited)
|
||||
notifications = rootInterface.getNotificationsByUser(
|
||||
userId=str(currentUser.id),
|
||||
status=status,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
# Sort by creation date (newest first) and limit
|
||||
notifications = sorted(notifications, key=lambda x: x.get("createdAt", 0), reverse=True)
|
||||
if limit:
|
||||
notifications = notifications[:limit]
|
||||
# Apply type filter if needed (not common, so filter post-fetch)
|
||||
if type:
|
||||
notifications = [n for n in notifications if n.type == type]
|
||||
|
||||
return notifications
|
||||
# Convert to dicts for response
|
||||
return [n.model_dump() for n in notifications]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting notifications: {e}")
|
||||
|
|
@ -165,7 +161,7 @@ async def getNotifications(
|
|||
|
||||
@router.get("/unread-count", response_model=UnreadCountResponse)
|
||||
@limiter.limit("120/minute")
|
||||
async def getUnreadCount(
|
||||
def getUnreadCount(
|
||||
request: Request,
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> UnreadCountResponse:
|
||||
|
|
@ -176,12 +172,10 @@ async def getUnreadCount(
|
|||
try:
|
||||
rootInterface = getRootInterface()
|
||||
|
||||
notifications = rootInterface.db.getRecordset(
|
||||
model_class=UserNotification,
|
||||
recordFilter={
|
||||
"userId": str(currentUser.id),
|
||||
"status": NotificationStatus.UNREAD.value
|
||||
}
|
||||
# Get unread notifications (Pydantic models)
|
||||
notifications = rootInterface.getNotificationsByUser(
|
||||
userId=str(currentUser.id),
|
||||
status=NotificationStatus.UNREAD.value
|
||||
)
|
||||
|
||||
return UnreadCountResponse(count=len(notifications))
|
||||
|
|
@ -196,7 +190,7 @@ async def getUnreadCount(
|
|||
|
||||
@router.put("/{notificationId}/read", response_model=Dict[str, Any])
|
||||
@limiter.limit("60/minute")
|
||||
async def markAsRead(
|
||||
def markAsRead(
|
||||
request: Request,
|
||||
notificationId: str,
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
|
|
@ -207,22 +201,17 @@ async def markAsRead(
|
|||
try:
|
||||
rootInterface = getRootInterface()
|
||||
|
||||
# Get the notification
|
||||
notifications = rootInterface.db.getRecordset(
|
||||
model_class=UserNotification,
|
||||
recordFilter={"id": notificationId}
|
||||
)
|
||||
# Get the notification (Pydantic model)
|
||||
notification = rootInterface.getNotification(notificationId)
|
||||
|
||||
if not notifications:
|
||||
if not notification:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Notification not found"
|
||||
)
|
||||
|
||||
notification = notifications[0]
|
||||
|
||||
# Verify ownership
|
||||
if notification.get("userId") != currentUser.id:
|
||||
if str(notification.userId) != str(currentUser.id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Not authorized to access this notification"
|
||||
|
|
@ -252,7 +241,7 @@ async def markAsRead(
|
|||
|
||||
@router.put("/mark-all-read", response_model=Dict[str, Any])
|
||||
@limiter.limit("10/minute")
|
||||
async def markAllAsRead(
|
||||
def markAllAsRead(
|
||||
request: Request,
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> Dict[str, Any]:
|
||||
|
|
@ -262,13 +251,10 @@ async def markAllAsRead(
|
|||
try:
|
||||
rootInterface = getRootInterface()
|
||||
|
||||
# Get all unread notifications
|
||||
notifications = rootInterface.db.getRecordset(
|
||||
model_class=UserNotification,
|
||||
recordFilter={
|
||||
"userId": currentUser.id,
|
||||
"status": NotificationStatus.UNREAD.value
|
||||
}
|
||||
# Get all unread notifications (Pydantic models)
|
||||
notifications = rootInterface.getNotificationsByUser(
|
||||
userId=str(currentUser.id),
|
||||
status=NotificationStatus.UNREAD.value
|
||||
)
|
||||
|
||||
currentTime = getUtcTimestamp()
|
||||
|
|
@ -277,7 +263,7 @@ async def markAllAsRead(
|
|||
for notification in notifications:
|
||||
rootInterface.db.recordModify(
|
||||
model_class=UserNotification,
|
||||
recordId=notification.get("id"),
|
||||
recordId=str(notification.id),
|
||||
record={
|
||||
"status": NotificationStatus.READ.value,
|
||||
"readAt": currentTime
|
||||
|
|
@ -297,7 +283,7 @@ async def markAllAsRead(
|
|||
|
||||
@router.post("/{notificationId}/action", response_model=Dict[str, Any])
|
||||
@limiter.limit("30/minute")
|
||||
async def executeAction(
|
||||
def executeAction(
|
||||
request: Request,
|
||||
notificationId: str,
|
||||
actionRequest: NotificationActionRequest,
|
||||
|
|
@ -309,37 +295,32 @@ async def executeAction(
|
|||
try:
|
||||
rootInterface = getRootInterface()
|
||||
|
||||
# Get the notification
|
||||
notifications = rootInterface.db.getRecordset(
|
||||
model_class=UserNotification,
|
||||
recordFilter={"id": notificationId}
|
||||
)
|
||||
# Get the notification (Pydantic model)
|
||||
notification = rootInterface.getNotification(notificationId)
|
||||
|
||||
if not notifications:
|
||||
if not notification:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Notification not found"
|
||||
)
|
||||
|
||||
notification = notifications[0]
|
||||
|
||||
# Verify ownership
|
||||
if notification.get("userId") != currentUser.id:
|
||||
if str(notification.userId) != str(currentUser.id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Not authorized to access this notification"
|
||||
)
|
||||
|
||||
# Check if already actioned
|
||||
if notification.get("status") == NotificationStatus.ACTIONED.value:
|
||||
if notification.status == NotificationStatus.ACTIONED.value:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Notification has already been actioned"
|
||||
)
|
||||
|
||||
# Validate action exists
|
||||
actions = notification.get("actions", [])
|
||||
validActionIds = [a.get("actionId") if isinstance(a, dict) else a.actionId for a in (actions or [])]
|
||||
actions = notification.actions or []
|
||||
validActionIds = [a.get("actionId") if isinstance(a, dict) else a.actionId for a in actions]
|
||||
|
||||
if actionRequest.actionId not in validActionIds:
|
||||
raise HTTPException(
|
||||
|
|
@ -351,7 +332,7 @@ async def executeAction(
|
|||
actionResult = None
|
||||
|
||||
if notification.get("type") == NotificationType.INVITATION.value:
|
||||
actionResult = await _handleInvitationAction(
|
||||
actionResult = _handleInvitationAction(
|
||||
notification=notification,
|
||||
actionId=actionRequest.actionId,
|
||||
currentUser=currentUser,
|
||||
|
|
@ -389,7 +370,7 @@ async def executeAction(
|
|||
)
|
||||
|
||||
|
||||
async def _handleInvitationAction(
|
||||
def _handleInvitationAction(
|
||||
notification: Dict[str, Any],
|
||||
actionId: str,
|
||||
currentUser: User,
|
||||
|
|
@ -407,22 +388,17 @@ async def _handleInvitationAction(
|
|||
detail="No invitation reference found"
|
||||
)
|
||||
|
||||
# Get the invitation
|
||||
invitations = rootInterface.db.getRecordset(
|
||||
model_class=Invitation,
|
||||
recordFilter={"id": invitationId}
|
||||
)
|
||||
# Get the invitation (Pydantic model)
|
||||
invitation = rootInterface.getInvitation(invitationId)
|
||||
|
||||
if not invitations:
|
||||
if not invitation:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Invitation not found"
|
||||
)
|
||||
|
||||
invitation = invitations[0]
|
||||
|
||||
# Verify username matches
|
||||
if invitation.get("targetUsername") != currentUser.username:
|
||||
if invitation.targetUsername != currentUser.username:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="This invitation is for a different user"
|
||||
|
|
@ -430,19 +406,22 @@ async def _handleInvitationAction(
|
|||
|
||||
# Check if invitation is still valid
|
||||
currentTime = getUtcTimestamp()
|
||||
if invitation.get("expiresAt", 0) < currentTime:
|
||||
expiresAt = invitation.expiresAt or 0
|
||||
if expiresAt < currentTime:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invitation has expired"
|
||||
)
|
||||
|
||||
if invitation.get("revokedAt"):
|
||||
if invitation.revokedAt:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invitation has been revoked"
|
||||
)
|
||||
|
||||
if invitation.get("currentUses", 0) >= invitation.get("maxUses", 1):
|
||||
currentUses = invitation.currentUses or 0
|
||||
maxUses = invitation.maxUses or 1
|
||||
if currentUses >= maxUses:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invitation has reached maximum uses"
|
||||
|
|
@ -450,59 +429,34 @@ async def _handleInvitationAction(
|
|||
|
||||
if actionId == "accept":
|
||||
# Accept the invitation - assign roles and mandate access
|
||||
mandateId = invitation.get("mandateId")
|
||||
roleIds = invitation.get("roleIds", [])
|
||||
mandateId = str(invitation.mandateId) if invitation.mandateId else None
|
||||
roleIds = list(invitation.roleIds or [])
|
||||
|
||||
# Ensure user gets the system "user" role for access to public UI elements (e.g. playground)
|
||||
userRoles = rootInterface.db.getRecordset(
|
||||
model_class=Role,
|
||||
recordFilter={"roleLabel": "user"}
|
||||
)
|
||||
if userRoles:
|
||||
userRoleId = userRoles[0].get("id")
|
||||
userRole = rootInterface.getRoleByLabel("user")
|
||||
if userRole:
|
||||
userRoleId = str(userRole.id)
|
||||
if userRoleId and userRoleId not in roleIds:
|
||||
roleIds = roleIds + [userRoleId]
|
||||
logger.debug(f"Added system 'user' role {userRoleId} to invitation roles")
|
||||
|
||||
# Get mandate name for result message
|
||||
mandates = rootInterface.db.getRecordset(
|
||||
model_class=Mandate,
|
||||
recordFilter={"id": mandateId}
|
||||
)
|
||||
mandateName = mandates[0].get("mandateLabel", mandateId) if mandates else mandateId
|
||||
mandate = rootInterface.getMandate(mandateId) if mandateId else None
|
||||
mandateName = mandate.mandateLabel if mandate and mandate.mandateLabel else mandateId
|
||||
|
||||
# Check if user already has this mandate
|
||||
existingMemberships = rootInterface.db.getRecordset(
|
||||
model_class=UserMandate,
|
||||
recordFilter={
|
||||
"userId": currentUser.id,
|
||||
"mandateId": mandateId
|
||||
}
|
||||
)
|
||||
existingMembership = rootInterface.getUserMandate(str(currentUser.id), mandateId) if mandateId else None
|
||||
|
||||
if existingMemberships:
|
||||
# Update existing membership with new roles
|
||||
existingMembership = existingMemberships[0]
|
||||
existingRoles = existingMembership.get("roleIds", [])
|
||||
mergedRoles = list(set(existingRoles + roleIds))
|
||||
|
||||
rootInterface.db.recordModify(
|
||||
model_class=UserMandate,
|
||||
recordId=existingMembership.get("id"),
|
||||
record={"roleIds": mergedRoles}
|
||||
)
|
||||
logger.info(f"Updated UserMandate for user {currentUser.id} in mandate {mandateId}")
|
||||
if existingMembership:
|
||||
# Update existing membership with new roles via interface
|
||||
# Note: roleIds on UserMandate is deprecated - roles should be assigned via UserMandateRole
|
||||
logger.info(f"User {currentUser.id} already has membership in mandate {mandateId}, adding roles via UserMandateRole")
|
||||
# Add roles via junction table
|
||||
for roleId in roleIds:
|
||||
rootInterface.addRoleToUserMandate(str(existingMembership.id), roleId)
|
||||
else:
|
||||
# Create new user-mandate relationship
|
||||
userMandate = UserMandate(
|
||||
userId=currentUser.id,
|
||||
mandateId=mandateId,
|
||||
roleIds=roleIds
|
||||
)
|
||||
rootInterface.db.recordCreate(
|
||||
model_class=UserMandate,
|
||||
record=userMandate.model_dump()
|
||||
)
|
||||
# Create new user-mandate relationship via interface
|
||||
rootInterface.createUserMandate(str(currentUser.id), mandateId, roleIds)
|
||||
logger.info(f"Created UserMandate for user {currentUser.id} in mandate {mandateId}")
|
||||
|
||||
# Mark invitation as used
|
||||
|
|
@ -510,9 +464,9 @@ async def _handleInvitationAction(
|
|||
model_class=Invitation,
|
||||
recordId=invitationId,
|
||||
record={
|
||||
"usedBy": currentUser.id,
|
||||
"usedBy": str(currentUser.id),
|
||||
"usedAt": currentTime,
|
||||
"currentUses": invitation.get("currentUses", 0) + 1
|
||||
"currentUses": currentUses + 1
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -534,7 +488,7 @@ async def _handleInvitationAction(
|
|||
|
||||
@router.delete("/{notificationId}", response_model=Dict[str, Any])
|
||||
@limiter.limit("30/minute")
|
||||
async def deleteNotification(
|
||||
def deleteNotification(
|
||||
request: Request,
|
||||
notificationId: str,
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
|
|
@ -545,22 +499,17 @@ async def deleteNotification(
|
|||
try:
|
||||
rootInterface = getRootInterface()
|
||||
|
||||
# Get the notification
|
||||
notifications = rootInterface.db.getRecordset(
|
||||
model_class=UserNotification,
|
||||
recordFilter={"id": notificationId}
|
||||
)
|
||||
# Get the notification (Pydantic model)
|
||||
notification = rootInterface.getNotification(notificationId)
|
||||
|
||||
if not notifications:
|
||||
if not notification:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Notification not found"
|
||||
)
|
||||
|
||||
notification = notifications[0]
|
||||
|
||||
# Verify ownership
|
||||
if notification.get("userId") != currentUser.id:
|
||||
if str(notification.userId) != str(currentUser.id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Not authorized to delete this notification"
|
||||
|
|
|
|||
|
|
@ -97,7 +97,7 @@ def _getDatabaseConnector(databaseName: str, userId: str = None) -> DatabaseConn
|
|||
|
||||
@router.get("/tokens")
|
||||
@limiter.limit("30/minute")
|
||||
async def list_tokens(
|
||||
def list_tokens(
|
||||
request: Request,
|
||||
currentUser: User = Depends(requireSysAdmin),
|
||||
userId: Optional[str] = None,
|
||||
|
|
@ -125,8 +125,8 @@ async def list_tokens(
|
|||
if statusFilter:
|
||||
recordFilter["status"] = statusFilter
|
||||
# MULTI-TENANT: SysAdmin sees ALL tokens (no mandate filter)
|
||||
|
||||
tokens = appInterface.db.getRecordset(Token, recordFilter=recordFilter)
|
||||
# Use interface method to get tokens with flexible filtering
|
||||
tokens = appInterface.getAllTokens(recordFilter=recordFilter)
|
||||
return tokens
|
||||
except HTTPException:
|
||||
raise
|
||||
|
|
@ -137,7 +137,7 @@ async def list_tokens(
|
|||
|
||||
@router.post("/tokens/revoke/user")
|
||||
@limiter.limit("30/minute")
|
||||
async def revoke_tokens_by_user(
|
||||
def revoke_tokens_by_user(
|
||||
request: Request,
|
||||
currentUser: User = Depends(requireSysAdmin),
|
||||
payload: Dict[str, Any] = Body(...)
|
||||
|
|
@ -172,7 +172,7 @@ async def revoke_tokens_by_user(
|
|||
|
||||
@router.post("/tokens/revoke/session")
|
||||
@limiter.limit("30/minute")
|
||||
async def revoke_tokens_by_session(
|
||||
def revoke_tokens_by_session(
|
||||
request: Request,
|
||||
currentUser: User = Depends(requireSysAdmin),
|
||||
payload: Dict[str, Any] = Body(...)
|
||||
|
|
@ -208,7 +208,7 @@ async def revoke_tokens_by_session(
|
|||
|
||||
@router.post("/tokens/revoke/id")
|
||||
@limiter.limit("30/minute")
|
||||
async def revoke_token_by_id(
|
||||
def revoke_token_by_id(
|
||||
request: Request,
|
||||
currentUser: User = Depends(requireSysAdmin),
|
||||
payload: Dict[str, Any] = Body(...)
|
||||
|
|
@ -235,7 +235,7 @@ async def revoke_token_by_id(
|
|||
|
||||
@router.post("/tokens/revoke/mandate")
|
||||
@limiter.limit("10/minute")
|
||||
async def revoke_tokens_by_mandate(
|
||||
def revoke_tokens_by_mandate(
|
||||
request: Request,
|
||||
currentUser: User = Depends(requireSysAdmin),
|
||||
payload: Dict[str, Any] = Body(...)
|
||||
|
|
@ -254,15 +254,13 @@ async def revoke_tokens_by_mandate(
|
|||
# MULTI-TENANT: SysAdmin can revoke tokens for any mandate
|
||||
appInterface = getRootInterface()
|
||||
|
||||
# Get all UserMandate entries for this mandate to find users
|
||||
# Note: In new model, users are linked via UserMandate, not User.mandateId
|
||||
from modules.datamodels.datamodelMembership import UserMandate
|
||||
userMandates = appInterface.db.getRecordset(UserMandate, recordFilter={"mandateId": mandateId})
|
||||
# Get all UserMandate entries for this mandate to find users using interface method
|
||||
userMandates = appInterface.getUserMandatesByMandate(mandateId)
|
||||
|
||||
total = 0
|
||||
for um in userMandates:
|
||||
total += appInterface.revokeTokensByUser(
|
||||
userId=um["userId"],
|
||||
userId=um.userId,
|
||||
authority=AuthAuthority(authority) if authority else None,
|
||||
mandateId=None, # Revoke all tokens for user
|
||||
revokedBy=currentUser.id,
|
||||
|
|
@ -282,7 +280,7 @@ async def revoke_tokens_by_mandate(
|
|||
|
||||
@router.get("/logs/{log_name}")
|
||||
@limiter.limit("60/minute")
|
||||
async def download_log(
|
||||
def download_log(
|
||||
request: Request,
|
||||
currentUser: User = Depends(requireSysAdmin),
|
||||
log_name: str = "poweron"
|
||||
|
|
@ -311,7 +309,7 @@ async def download_log(
|
|||
|
||||
@router.get("/databases")
|
||||
@limiter.limit("10/minute")
|
||||
async def list_databases(
|
||||
def list_databases(
|
||||
request: Request,
|
||||
currentUser: User = Depends(requireSysAdmin)
|
||||
) -> Dict[str, Any]:
|
||||
|
|
@ -329,7 +327,7 @@ async def list_databases(
|
|||
|
||||
@router.get("/databases/{database_name}/tables")
|
||||
@limiter.limit("30/minute")
|
||||
async def get_database_tables(
|
||||
def get_database_tables(
|
||||
request: Request,
|
||||
database_name: str,
|
||||
currentUser: User = Depends(requireSysAdmin)
|
||||
|
|
@ -358,7 +356,7 @@ async def get_database_tables(
|
|||
|
||||
@router.post("/databases/{database_name}/tables/{table_name}/drop")
|
||||
@limiter.limit("10/minute")
|
||||
async def drop_table(
|
||||
def drop_table(
|
||||
request: Request,
|
||||
database_name: str,
|
||||
table_name: str,
|
||||
|
|
@ -406,7 +404,7 @@ async def drop_table(
|
|||
|
||||
@router.post("/databases/drop")
|
||||
@limiter.limit("5/minute")
|
||||
async def drop_database(
|
||||
def drop_database(
|
||||
request: Request,
|
||||
currentUser: User = Depends(requireSysAdmin),
|
||||
payload: Dict[str, Any] = Body(...)
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ import httpx
|
|||
from modules.shared.configuration import APP_CONFIG
|
||||
from modules.interfaces.interfaceDbApp import getInterface, getRootInterface
|
||||
from modules.datamodels.datamodelUam import AuthAuthority, User, ConnectionStatus, UserConnection
|
||||
from modules.auth import getCurrentUser, limiter
|
||||
from modules.auth import getCurrentUser, limiter, SECRET_KEY, ALGORITHM
|
||||
from modules.auth import createAccessToken, setAccessTokenCookie, createRefreshToken, setRefreshTokenCookie
|
||||
from modules.auth.tokenManager import TokenManager
|
||||
from modules.shared.timeUtils import createExpirationTimestamp, getUtcTimestamp, parseTimestamp
|
||||
|
|
@ -93,7 +93,7 @@ SCOPES = [
|
|||
]
|
||||
|
||||
@router.get("/config")
|
||||
async def get_config():
|
||||
def get_config():
|
||||
"""Debug endpoint to check Google OAuth configuration"""
|
||||
return {
|
||||
"client_id": CLIENT_ID,
|
||||
|
|
@ -109,7 +109,7 @@ async def get_config():
|
|||
|
||||
@router.get("/login")
|
||||
@limiter.limit("5/minute")
|
||||
async def login(
|
||||
def login(
|
||||
request: Request,
|
||||
state: str = Query("login", description="State parameter to distinguish between login and connection flows"),
|
||||
connectionId: Optional[str] = Query(None, description="Connection ID for connection flow")
|
||||
|
|
@ -171,10 +171,9 @@ async def login(
|
|||
try:
|
||||
if connectionId:
|
||||
rootInterface = getRootInterface()
|
||||
records = rootInterface.db.getRecordset(UserConnection, recordFilter={"id": connectionId})
|
||||
if records:
|
||||
record = records[0]
|
||||
login_hint = record.get("externalEmail") or record.get("externalUsername")
|
||||
connection = rootInterface.getUserConnectionById(connectionId)
|
||||
if connection:
|
||||
login_hint = connection.externalEmail or connection.externalUsername
|
||||
if login_hint:
|
||||
extra_params["login_hint"] = login_hint
|
||||
if "@" in login_hint:
|
||||
|
|
@ -260,23 +259,20 @@ async def auth_callback(code: str, state: str, request: Request, response: Respo
|
|||
rootInterface = getRootInterface()
|
||||
# Prefer connection flow reuse; fallback to user access token
|
||||
if connection_id:
|
||||
existing_tokens = rootInterface.db.getRecordset(Token, recordFilter={
|
||||
"connectionId": connection_id,
|
||||
"authority": AuthAuthority.GOOGLE
|
||||
})
|
||||
existing_tokens = rootInterface.getTokensByConnectionIdAndAuthority(
|
||||
connection_id, AuthAuthority.GOOGLE
|
||||
)
|
||||
if existing_tokens:
|
||||
# Use most recent by createdAt
|
||||
existing_tokens.sort(key=lambda x: parseTimestamp(x.get("createdAt"), default=0), reverse=True)
|
||||
token_response["refresh_token"] = existing_tokens[0].get("tokenRefresh", "")
|
||||
existing_tokens.sort(key=lambda x: parseTimestamp(x.createdAt, default=0), reverse=True)
|
||||
token_response["refresh_token"] = existing_tokens[0].tokenRefresh or ""
|
||||
if not token_response.get("refresh_token") and user_id:
|
||||
existing_access_tokens = rootInterface.db.getRecordset(Token, recordFilter={
|
||||
"userId": user_id,
|
||||
"connectionId": None,
|
||||
"authority": AuthAuthority.GOOGLE
|
||||
})
|
||||
existing_access_tokens = rootInterface.getTokensByUserIdNoConnection(
|
||||
user_id, AuthAuthority.GOOGLE
|
||||
)
|
||||
if existing_access_tokens:
|
||||
existing_access_tokens.sort(key=lambda x: parseTimestamp(x.get("createdAt"), default=0), reverse=True)
|
||||
token_response["refresh_token"] = existing_access_tokens[0].get("tokenRefresh", "")
|
||||
existing_access_tokens.sort(key=lambda x: parseTimestamp(x.createdAt, default=0), reverse=True)
|
||||
token_response["refresh_token"] = existing_access_tokens[0].tokenRefresh or ""
|
||||
except Exception:
|
||||
# Non-fatal; continue without refresh token
|
||||
pass
|
||||
|
|
@ -491,6 +487,10 @@ async def auth_callback(code: str, state: str, request: Request, response: Respo
|
|||
connection.externalId = user_info.get("id")
|
||||
connection.externalUsername = user_info.get("email")
|
||||
connection.externalEmail = user_info.get("email")
|
||||
# Store actually granted scopes for this connection
|
||||
granted_scopes_list = granted_scopes.split(" ") if granted_scopes else SCOPES
|
||||
connection.grantedScopes = granted_scopes_list
|
||||
logger.info(f"Storing granted scopes for connection {connection_id}: {granted_scopes_list}")
|
||||
|
||||
# Update connection record directly
|
||||
rootInterface.db.recordModify(UserConnection, connection_id, connection.model_dump())
|
||||
|
|
@ -589,7 +589,7 @@ async def auth_callback(code: str, state: str, request: Request, response: Respo
|
|||
|
||||
@router.get("/me", response_model=User)
|
||||
@limiter.limit("30/minute")
|
||||
async def get_current_user(
|
||||
def get_current_user(
|
||||
request: Request,
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> User:
|
||||
|
|
@ -605,7 +605,7 @@ async def get_current_user(
|
|||
|
||||
@router.post("/logout")
|
||||
@limiter.limit("10/minute")
|
||||
async def logout(
|
||||
def logout(
|
||||
request: Request,
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> Dict[str, Any]:
|
||||
|
|
|
|||
|
|
@ -89,7 +89,7 @@ router = APIRouter(
|
|||
|
||||
@router.post("/login")
|
||||
@limiter.limit("30/minute")
|
||||
async def login(
|
||||
def login(
|
||||
request: Request,
|
||||
response: Response,
|
||||
formData: OAuth2PasswordRequestForm = Depends(),
|
||||
|
|
@ -242,7 +242,7 @@ async def login(
|
|||
|
||||
@router.post("/register")
|
||||
@limiter.limit("10/minute")
|
||||
async def register_user(
|
||||
def register_user(
|
||||
request: Request,
|
||||
userData: User = Body(...),
|
||||
frontendUrl: str = Body(..., embed=True)
|
||||
|
|
@ -330,40 +330,34 @@ Falls Sie sich nicht registriert haben, können Sie diese E-Mail ignorieren."""
|
|||
from modules.datamodels.datamodelUam import Mandate
|
||||
|
||||
currentTime = getUtcTimestamp()
|
||||
pendingInvitations = appInterface.db.getRecordset(
|
||||
model_class=Invitation,
|
||||
recordFilter={"targetUsername": userData.username}
|
||||
)
|
||||
pendingInvitations = appInterface.getInvitationsByTargetUsername(userData.username)
|
||||
|
||||
for invitation in pendingInvitations:
|
||||
# Skip expired, revoked, or fully used invitations
|
||||
if invitation.get("expiresAt", 0) < currentTime:
|
||||
if (invitation.expiresAt or 0) < currentTime:
|
||||
continue
|
||||
if invitation.get("revokedAt"):
|
||||
if invitation.revokedAt:
|
||||
continue
|
||||
if invitation.get("currentUses", 0) >= invitation.get("maxUses", 1):
|
||||
if (invitation.currentUses or 0) >= (invitation.maxUses or 1):
|
||||
continue
|
||||
|
||||
# Get mandate name for notification
|
||||
mandateId = invitation.get("mandateId")
|
||||
mandateRecords = appInterface.db.getRecordset(
|
||||
Mandate,
|
||||
recordFilter={"id": mandateId}
|
||||
)
|
||||
mandateName = mandateRecords[0].get("mandateLabel", "PowerOn") if mandateRecords else "PowerOn"
|
||||
# Get mandate name for notification using interface method
|
||||
mandateId = invitation.mandateId
|
||||
mandate = appInterface.getMandate(mandateId)
|
||||
mandateName = mandate.mandateLabel if mandate else "PowerOn"
|
||||
|
||||
# Get inviter name
|
||||
inviterId = invitation.get("createdBy")
|
||||
inviterId = invitation.createdBy
|
||||
inviter = appInterface.getUserById(inviterId) if inviterId else None
|
||||
inviterName = (inviter.fullName or inviter.username) if inviter else "PowerOn"
|
||||
|
||||
createInvitationNotification(
|
||||
userId=str(user.id),
|
||||
invitationId=str(invitation.get("id")),
|
||||
invitationId=str(invitation.id),
|
||||
mandateName=mandateName,
|
||||
inviterName=inviterName
|
||||
)
|
||||
logger.info(f"Created notification for new user {userData.username} for invitation {invitation.get('id')}")
|
||||
logger.info(f"Created notification for new user {userData.username} for invitation {invitation.id}")
|
||||
|
||||
except Exception as notifErr:
|
||||
logger.warning(f"Failed to create notifications for pending invitations: {notifErr}")
|
||||
|
|
@ -387,7 +381,7 @@ Falls Sie sich nicht registriert haben, können Sie diese E-Mail ignorieren."""
|
|||
|
||||
@router.get("/me", response_model=User)
|
||||
@limiter.limit("30/minute")
|
||||
async def read_user_me(
|
||||
def read_user_me(
|
||||
request: Request,
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> User:
|
||||
|
|
@ -403,7 +397,7 @@ async def read_user_me(
|
|||
|
||||
@router.post("/refresh")
|
||||
@limiter.limit("60/minute")
|
||||
async def refresh_token(
|
||||
def refresh_token(
|
||||
request: Request,
|
||||
response: Response
|
||||
) -> Dict[str, Any]:
|
||||
|
|
@ -478,7 +472,7 @@ async def refresh_token(
|
|||
|
||||
@router.post("/logout")
|
||||
@limiter.limit("30/minute")
|
||||
async def logout(request: Request, response: Response, currentUser: User = Depends(getCurrentUser)) -> JSONResponse:
|
||||
def logout(request: Request, response: Response, currentUser: User = Depends(getCurrentUser)) -> JSONResponse:
|
||||
"""Logout from local authentication"""
|
||||
try:
|
||||
# Get user interface with current user context
|
||||
|
|
@ -547,7 +541,7 @@ async def logout(request: Request, response: Response, currentUser: User = Depen
|
|||
|
||||
@router.get("/available")
|
||||
@limiter.limit("10/minute")
|
||||
async def check_username_availability(
|
||||
def check_username_availability(
|
||||
request: Request,
|
||||
username: str,
|
||||
authenticationAuthority: str = "local"
|
||||
|
|
@ -579,7 +573,7 @@ async def check_username_availability(
|
|||
|
||||
@router.post("/password-reset-request")
|
||||
@limiter.limit("5/minute")
|
||||
async def password_reset_request(
|
||||
def password_reset_request(
|
||||
request: Request,
|
||||
username: str = Body(..., embed=True),
|
||||
frontendUrl: str = Body(..., embed=True)
|
||||
|
|
@ -659,7 +653,7 @@ Falls Sie diese Anforderung nicht gestellt haben, können Sie diese E-Mail ignor
|
|||
|
||||
@router.post("/password-reset")
|
||||
@limiter.limit("10/minute")
|
||||
async def password_reset(
|
||||
def password_reset(
|
||||
request: Request,
|
||||
token: str = Body(..., embed=True),
|
||||
password: str = Body(..., embed=True)
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from modules.shared.configuration import APP_CONFIG
|
|||
from modules.interfaces.interfaceDbApp import getInterface, getRootInterface
|
||||
from modules.datamodels.datamodelUam import AuthAuthority, User, ConnectionStatus, UserConnection
|
||||
from modules.datamodels.datamodelSecurity import Token
|
||||
from modules.auth import getCurrentUser, limiter
|
||||
from modules.auth import getCurrentUser, limiter, SECRET_KEY, ALGORITHM
|
||||
from modules.auth import createAccessToken, setAccessTokenCookie, createRefreshToken, setRefreshTokenCookie
|
||||
from modules.auth.tokenManager import TokenManager
|
||||
from modules.shared.timeUtils import createExpirationTimestamp, getUtcTimestamp, parseTimestamp
|
||||
|
|
@ -66,7 +66,7 @@ SCOPES = [
|
|||
|
||||
@router.get("/login")
|
||||
@limiter.limit("5/minute")
|
||||
async def login(
|
||||
def login(
|
||||
request: Request,
|
||||
state: str = Query("login", description="State parameter to distinguish between login and connection flows"),
|
||||
connectionId: Optional[str] = Query(None, description="Connection ID for connection flow")
|
||||
|
|
@ -97,11 +97,10 @@ async def login(
|
|||
if connectionId:
|
||||
try:
|
||||
rootInterface = getRootInterface()
|
||||
# Fetch the connection by ID directly
|
||||
records = rootInterface.db.getRecordset(UserConnection, recordFilter={"id": connectionId})
|
||||
if records:
|
||||
record = records[0]
|
||||
login_hint = record.get("externalEmail") or record.get("externalUsername")
|
||||
# Fetch the connection by ID directly using interface method
|
||||
connection = rootInterface.getUserConnectionById(connectionId)
|
||||
if connection:
|
||||
login_hint = connection.externalEmail or connection.externalUsername
|
||||
if login_hint:
|
||||
login_kwargs["login_hint"] = login_hint
|
||||
# Derive domain hint from email/UPN
|
||||
|
|
@ -139,7 +138,7 @@ async def login(
|
|||
|
||||
@router.get("/adminconsent")
|
||||
@limiter.limit("5/minute")
|
||||
async def adminconsent(request: Request) -> RedirectResponse:
|
||||
def adminconsent(request: Request) -> RedirectResponse:
|
||||
"""Initiate Microsoft Admin Consent flow.
|
||||
|
||||
An Azure AD admin must visit this URL once to grant consent for the entire tenant.
|
||||
|
|
@ -162,7 +161,7 @@ async def adminconsent(request: Request) -> RedirectResponse:
|
|||
)
|
||||
|
||||
@router.get("/adminconsent/callback")
|
||||
async def adminconsent_callback(
|
||||
def adminconsent_callback(
|
||||
admin_consent: Optional[str] = Query(None),
|
||||
tenant: Optional[str] = Query(None),
|
||||
error: Optional[str] = Query(None),
|
||||
|
|
@ -499,6 +498,9 @@ async def auth_callback(code: str, state: str, request: Request, response: Respo
|
|||
connection.externalId = user_info.get("id")
|
||||
connection.externalUsername = user_info.get("userPrincipalName")
|
||||
connection.externalEmail = user_info.get("mail")
|
||||
# Store granted scopes for this connection
|
||||
connection.grantedScopes = SCOPES
|
||||
logger.info(f"Storing granted scopes for connection {connection_id}: {SCOPES}")
|
||||
|
||||
# Update connection record directly
|
||||
rootInterface.db.recordModify(UserConnection, connection_id, connection.model_dump())
|
||||
|
|
@ -601,7 +603,7 @@ async def auth_callback(code: str, state: str, request: Request, response: Respo
|
|||
|
||||
@router.get("/me", response_model=User)
|
||||
@limiter.limit("30/minute")
|
||||
async def get_current_user(
|
||||
def get_current_user(
|
||||
request: Request,
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> User:
|
||||
|
|
@ -617,7 +619,7 @@ async def get_current_user(
|
|||
|
||||
@router.post("/logout")
|
||||
@limiter.limit("10/minute")
|
||||
async def logout(
|
||||
def logout(
|
||||
request: Request,
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> Dict[str, Any]:
|
||||
|
|
@ -653,7 +655,7 @@ async def logout(
|
|||
|
||||
@router.post("/cleanup")
|
||||
@limiter.limit("5/minute")
|
||||
async def cleanup_expired_tokens(
|
||||
def cleanup_expired_tokens(
|
||||
request: Request,
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> Dict[str, Any]:
|
||||
|
|
|
|||
|
|
@ -38,13 +38,13 @@ def _getUserRoleIds(userId: str) -> List[str]:
|
|||
rootInterface = getRootInterface()
|
||||
roleIds = []
|
||||
|
||||
userMandates = rootInterface.db.getRecordset(
|
||||
UserMandate,
|
||||
recordFilter={"userId": userId, "enabled": True}
|
||||
)
|
||||
# Get UserMandates as Pydantic models
|
||||
userMandates = rootInterface.getUserMandates(userId)
|
||||
|
||||
for um in userMandates:
|
||||
mandateRoleIds = rootInterface.getRoleIdsForUserMandate(um.get("id"))
|
||||
if not um.enabled:
|
||||
continue
|
||||
mandateRoleIds = rootInterface.getRoleIdsForUserMandate(str(um.id))
|
||||
for rid in mandateRoleIds:
|
||||
if rid not in roleIds:
|
||||
roleIds.append(rid)
|
||||
|
|
@ -60,30 +60,24 @@ def _checkUiPermission(roleIds: List[str], objectKey: str) -> bool:
|
|||
rootInterface = getRootInterface()
|
||||
|
||||
for roleId in roleIds:
|
||||
# Get UI rules for this role
|
||||
rules = rootInterface.db.getRecordset(
|
||||
AccessRule,
|
||||
recordFilter={"roleId": roleId, "context": "UI"}
|
||||
)
|
||||
# Get UI rules for this role (returns Pydantic AccessRule models)
|
||||
rules = rootInterface.getAccessRules(roleId=roleId, context=AccessRuleContext.UI)
|
||||
|
||||
for rule in rules:
|
||||
ruleItem = rule.get("item")
|
||||
ruleView = rule.get("view", False)
|
||||
|
||||
if not ruleView:
|
||||
if not rule.view:
|
||||
continue
|
||||
|
||||
# Global rule (item=None) grants access to all UI
|
||||
if ruleItem is None:
|
||||
if rule.item is None:
|
||||
return True
|
||||
|
||||
# Exact match
|
||||
if ruleItem == objectKey:
|
||||
if rule.item == objectKey:
|
||||
return True
|
||||
|
||||
# Wildcard match (e.g., ui.system.* matches ui.system.playground)
|
||||
if ruleItem.endswith(".*"):
|
||||
prefix = ruleItem[:-2]
|
||||
if rule.item.endswith(".*"):
|
||||
prefix = rule.item[:-2]
|
||||
if objectKey.startswith(prefix):
|
||||
return True
|
||||
|
||||
|
|
@ -108,6 +102,12 @@ def _getFeatureUiObjects(featureCode: str) -> List[Dict[str, Any]]:
|
|||
elif featureCode == "realestate":
|
||||
from modules.features.realestate.mainRealEstate import UI_OBJECTS
|
||||
return UI_OBJECTS
|
||||
elif featureCode == "chatplayground":
|
||||
from modules.features.chatplayground.mainChatplayground import UI_OBJECTS
|
||||
return UI_OBJECTS
|
||||
elif featureCode == "automation":
|
||||
from modules.features.automation.mainAutomation import UI_OBJECTS
|
||||
return UI_OBJECTS
|
||||
else:
|
||||
logger.warning(f"Unknown feature code: {featureCode}")
|
||||
return []
|
||||
|
|
@ -153,7 +153,7 @@ def _buildDynamicBlock(
|
|||
mandateId = str(instance.mandateId)
|
||||
if mandateId not in mandatesMap:
|
||||
mandate = rootInterface.getMandate(mandateId)
|
||||
mandateName = mandate.name if mandate and hasattr(mandate, 'name') else mandateId
|
||||
mandateName = (mandate.label or mandate.name) if mandate else mandateId
|
||||
mandatesMap[mandateId] = {
|
||||
"id": mandateId,
|
||||
"uiLabel": mandateName,
|
||||
|
|
@ -287,67 +287,50 @@ def _getInstanceViewPermissions(
|
|||
permissions = {"_all": False, "isAdmin": False}
|
||||
|
||||
try:
|
||||
from modules.datamodels.datamodelRbac import AccessRule, AccessRuleContext, Role
|
||||
# Get FeatureAccess for this user and instance (Pydantic model)
|
||||
featureAccess = rootInterface.getFeatureAccess(userId, instanceId)
|
||||
|
||||
# Get FeatureAccess for this user and instance
|
||||
featureAccesses = rootInterface.db.getRecordset(
|
||||
FeatureAccess,
|
||||
recordFilter={"userId": userId, "featureInstanceId": instanceId}
|
||||
)
|
||||
|
||||
if not featureAccesses:
|
||||
if not featureAccess:
|
||||
return permissions
|
||||
|
||||
# Get role IDs via FeatureAccessRole junction table
|
||||
featureAccessId = featureAccesses[0].get("id")
|
||||
featureAccessRoles = rootInterface.db.getRecordset(
|
||||
FeatureAccessRole,
|
||||
recordFilter={"featureAccessId": featureAccessId}
|
||||
)
|
||||
roleIds = [far.get("roleId") for far in featureAccessRoles]
|
||||
# Get role IDs via interface method
|
||||
roleIds = rootInterface.getRoleIdsForFeatureAccess(str(featureAccess.id))
|
||||
|
||||
if not roleIds:
|
||||
return permissions
|
||||
|
||||
# Check if user has admin role
|
||||
for roleId in roleIds:
|
||||
roles = rootInterface.db.getRecordset(Role, recordFilter={"id": roleId})
|
||||
if roles:
|
||||
roleLabel = roles[0].get("roleLabel", "").lower()
|
||||
if "admin" in roleLabel:
|
||||
permissions["isAdmin"] = True
|
||||
break
|
||||
role = rootInterface.getRole(roleId)
|
||||
if role and "admin" in role.roleLabel.lower():
|
||||
permissions["isAdmin"] = True
|
||||
break
|
||||
|
||||
# Get UI permissions from AccessRules
|
||||
# Permissions are stored with full objectKey (e.g., ui.feature.trustee.dashboard)
|
||||
# Get UI permissions from AccessRules (Pydantic models)
|
||||
for roleId in roleIds:
|
||||
accessRules = rootInterface.db.getRecordset(
|
||||
AccessRule,
|
||||
recordFilter={"roleId": roleId, "context": "UI"}
|
||||
)
|
||||
accessRules = rootInterface.getAccessRules(roleId=roleId, context=AccessRuleContext.UI)
|
||||
|
||||
logger.debug(f"_getInstanceViewPermissions: roleId={roleId}, UI rules count={len(accessRules)}")
|
||||
|
||||
for rule in accessRules:
|
||||
if not rule.get("view", False):
|
||||
if not rule.view:
|
||||
continue
|
||||
|
||||
item = rule.get("item")
|
||||
logger.debug(f"_getInstanceViewPermissions: rule item={item}, view={rule.get('view')}")
|
||||
logger.debug(f"_getInstanceViewPermissions: rule item={rule.item}, view={rule.view}")
|
||||
|
||||
if item is None:
|
||||
if rule.item is None:
|
||||
# item=None means all views
|
||||
permissions["_all"] = True
|
||||
else:
|
||||
# Store full objectKey as per Navigation-API-Konzept
|
||||
permissions[item] = True
|
||||
permissions[rule.item] = True
|
||||
|
||||
logger.debug(f"_getInstanceViewPermissions: final permissions={permissions}")
|
||||
return permissions
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error getting instance view permissions: {e}")
|
||||
return permissions
|
||||
return permissions # Fail-safe: no permissions on error
|
||||
|
||||
|
||||
def _buildStaticBlocks(
|
||||
|
|
@ -426,7 +409,7 @@ def _formatBlockItem(item: Dict[str, Any], language: str) -> Dict[str, Any]:
|
|||
|
||||
@navigationRouter.get("/navigation")
|
||||
@limiter.limit("60/minute")
|
||||
async def get_navigation(
|
||||
def get_navigation(
|
||||
request: Request,
|
||||
language: str = Query("de", description="Language for labels (en, de, fr)"),
|
||||
reqContext: RequestContext = Depends(getRequestContext)
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ Multi-Tenant Design:
|
|||
import logging
|
||||
from typing import List, Optional, TYPE_CHECKING
|
||||
from modules.datamodels.datamodelRbac import AccessRule, AccessRuleContext, Role
|
||||
from modules.datamodels.datamodelUam import User, UserPermissions, AccessLevel, Mandate
|
||||
from modules.datamodels.datamodelUam import User, UserPermissions, AccessLevel
|
||||
from modules.datamodels.datamodelMembership import (
|
||||
UserMandate,
|
||||
UserMandateRole,
|
||||
|
|
@ -62,7 +62,7 @@ class RbacClass:
|
|||
|
||||
Multi-Tenant Design:
|
||||
- Lädt Rollen aus UserMandate + UserMandateRole wenn mandateId gegeben
|
||||
- isSysAdmin gibt vollen Zugriff auf System-Level (kein mandateId)
|
||||
- isSysAdmin gibt vollen Zugriff, unabhängig vom Kontext
|
||||
|
||||
Args:
|
||||
user: User object
|
||||
|
|
@ -82,8 +82,8 @@ class RbacClass:
|
|||
delete=AccessLevel.NONE
|
||||
)
|
||||
|
||||
# SysAdmin auf System-Level (kein Mandant) hat vollen Zugriff
|
||||
if hasattr(user, 'isSysAdmin') and user.isSysAdmin and not mandateId:
|
||||
# SysAdmin hat vollen Zugriff - unabhängig vom Kontext (Mandant/Feature)
|
||||
if hasattr(user, 'isSysAdmin') and user.isSysAdmin:
|
||||
return UserPermissions(
|
||||
view=True,
|
||||
read=AccessLevel.ALL,
|
||||
|
|
@ -96,6 +96,7 @@ class RbacClass:
|
|||
roleIds = self._getRoleIdsForUser(user, mandateId, featureInstanceId)
|
||||
|
||||
if not roleIds:
|
||||
logger.debug(f"getUserPermissions: NO roles found for user={user.id}, mandateId={mandateId}, featureInstanceId={featureInstanceId}, item={item}")
|
||||
return permissions
|
||||
|
||||
# Lade alle relevanten Regeln für alle Rollen
|
||||
|
|
@ -155,10 +156,16 @@ class RbacClass:
|
|||
) -> List[str]:
|
||||
"""
|
||||
Get all role IDs for a user in the given context.
|
||||
Uses UserMandate + UserMandateRole for the new multi-tenant model.
|
||||
Uses UserMandate + UserMandateRole for the multi-tenant model.
|
||||
|
||||
Also includes roles from the Root mandate (first mandate) if different
|
||||
from the requested mandate, so system-level permissions are always available.
|
||||
Each mandate has its own instances of system roles (admin, user, viewer)
|
||||
which are copied from the global templates during mandate creation.
|
||||
Therefore, only the requested mandate's roles are loaded - no need to
|
||||
load root mandate roles separately.
|
||||
|
||||
Loads roles from:
|
||||
1. The requested mandate (if provided) - includes mandate-instance system roles
|
||||
2. Feature instance roles (if featureInstanceId provided)
|
||||
|
||||
Args:
|
||||
user: User object
|
||||
|
|
@ -171,36 +178,23 @@ class RbacClass:
|
|||
roleIds = set() # Use set to avoid duplicates
|
||||
|
||||
try:
|
||||
# Get Root mandate ID (first mandate in system)
|
||||
allMandates = self.dbApp.getRecordset(Mandate)
|
||||
rootMandateId = allMandates[0].get("id") if allMandates else None
|
||||
|
||||
# Collect mandates to check:
|
||||
# - If mandateId provided: current mandate + Root mandate (if different)
|
||||
# - If no mandateId: just Root mandate (for system-level access)
|
||||
mandatesToCheck = []
|
||||
# Load roles from the requested mandate
|
||||
if mandateId:
|
||||
mandatesToCheck.append(mandateId)
|
||||
if rootMandateId and rootMandateId not in mandatesToCheck:
|
||||
mandatesToCheck.append(rootMandateId)
|
||||
|
||||
# Load roles from each mandate
|
||||
for checkMandateId in mandatesToCheck:
|
||||
userMandates = self.dbApp.getRecordset(
|
||||
userMandateRecords = self.dbApp.getRecordset(
|
||||
UserMandate,
|
||||
recordFilter={"userId": user.id, "mandateId": checkMandateId, "enabled": True}
|
||||
recordFilter={"userId": user.id, "mandateId": mandateId, "enabled": True}
|
||||
)
|
||||
|
||||
if userMandates:
|
||||
userMandateId = userMandates[0].get("id")
|
||||
if userMandateRecords:
|
||||
userMandateId = userMandateRecords[0]["id"]
|
||||
|
||||
# Lade UserMandateRoles (Mandate-level roles)
|
||||
userMandateRoles = self.dbApp.getRecordset(
|
||||
userMandateRoleRecords = self.dbApp.getRecordset(
|
||||
UserMandateRole,
|
||||
recordFilter={"userMandateId": userMandateId}
|
||||
)
|
||||
|
||||
foundRoles = [r.get("roleId") for r in userMandateRoles if r.get("roleId")]
|
||||
foundRoles = [r["roleId"] for r in userMandateRoleRecords if r.get("roleId")]
|
||||
roleIds.update(foundRoles)
|
||||
|
||||
# Load FeatureAccess + FeatureAccessRole (Instance-level roles)
|
||||
|
|
@ -215,14 +209,14 @@ class RbacClass:
|
|||
)
|
||||
|
||||
if featureAccessRecords:
|
||||
featureAccessId = featureAccessRecords[0].get("id")
|
||||
featureAccessId = featureAccessRecords[0]["id"]
|
||||
|
||||
featureAccessRoles = self.dbApp.getRecordset(
|
||||
featureAccessRoleRecords = self.dbApp.getRecordset(
|
||||
FeatureAccessRole,
|
||||
recordFilter={"featureAccessId": featureAccessId}
|
||||
)
|
||||
|
||||
roleIds.update([r.get("roleId") for r in featureAccessRoles if r.get("roleId")])
|
||||
roleIds.update([r["roleId"] for r in featureAccessRoleRecords if r.get("roleId")])
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading role IDs for user {user.id}: {e}")
|
||||
|
|
@ -377,12 +371,14 @@ class RbacClass:
|
|||
if not roleRecords:
|
||||
continue
|
||||
|
||||
role = roleRecords[0]
|
||||
# Convert to Pydantic model for type-safe access
|
||||
roleDict = {k: v for k, v in roleRecords[0].items() if not k.startswith("_")}
|
||||
role = Role(**roleDict)
|
||||
|
||||
# Bestimme Priorität basierend auf Role-Scope
|
||||
if role.get("featureInstanceId"):
|
||||
if role.featureInstanceId:
|
||||
priority = 3 # Instance-specific
|
||||
elif role.get("mandateId"):
|
||||
elif role.mandateId:
|
||||
priority = 2 # Mandate-specific
|
||||
else:
|
||||
priority = 1 # Global
|
||||
|
|
|
|||
|
|
@ -63,10 +63,11 @@ class Services:
|
|||
- Feature-specific Services are loaded dynamically via filename discovery
|
||||
"""
|
||||
|
||||
def __init__(self, user: User, workflow: "ChatWorkflow" = None, mandateId: Optional[str] = None):
|
||||
def __init__(self, user: User, workflow: "ChatWorkflow" = None, mandateId: Optional[str] = None, featureInstanceId: Optional[str] = None):
|
||||
self.user: User = user
|
||||
self.workflow = workflow
|
||||
self.mandateId: Optional[str] = mandateId
|
||||
self.featureInstanceId: Optional[str] = featureInstanceId
|
||||
self.currentUserPrompt: str = ""
|
||||
self.rawUserPrompt: str = ""
|
||||
|
||||
|
|
@ -83,7 +84,7 @@ class Services:
|
|||
# CENTRAL INTERFACE (Chat/Workflow)
|
||||
# ============================================================
|
||||
from modules.interfaces.interfaceDbChat import getInterface as getChatInterface
|
||||
self.interfaceDbChat = getChatInterface(user, mandateId=mandateId)
|
||||
self.interfaceDbChat = getChatInterface(user, mandateId=mandateId, featureInstanceId=featureInstanceId)
|
||||
|
||||
# ============================================================
|
||||
# SHARED SERVICES (from modules/services/)
|
||||
|
|
@ -143,7 +144,7 @@ class Services:
|
|||
|
||||
# Get interface via getInterface()
|
||||
if hasattr(module, "getInterface"):
|
||||
interface = module.getInterface(self.user, mandateId=self.mandateId)
|
||||
interface = module.getInterface(self.user, mandateId=self.mandateId, featureInstanceId=self.featureInstanceId)
|
||||
# Derive attribute name: interfaceFeatureAiChat -> interfaceDbChat
|
||||
attrName = filename.replace("interfaceFeature", "interfaceDb")
|
||||
setattr(self, attrName, interface)
|
||||
|
|
@ -191,6 +192,6 @@ class Services:
|
|||
logger.debug(f"Could not load service from {filepath}: {e}")
|
||||
|
||||
|
||||
def getInterface(user: User, workflow: "ChatWorkflow" = None, mandateId: Optional[str] = None) -> Services:
|
||||
"""Get Services instance for the given user and mandate context."""
|
||||
return Services(user, workflow, mandateId=mandateId)
|
||||
def getInterface(user: User, workflow: "ChatWorkflow" = None, mandateId: Optional[str] = None, featureInstanceId: Optional[str] = None) -> Services:
|
||||
"""Get Services instance for the given user, mandate, and feature instance context."""
|
||||
return Services(user, workflow, mandateId=mandateId, featureInstanceId=featureInstanceId)
|
||||
|
|
|
|||
|
|
@ -18,6 +18,12 @@ from modules.shared.jsonUtils import (
|
|||
)
|
||||
from .subJsonResponseHandling import JsonResponseHandler
|
||||
from modules.datamodels.datamodelAi import JsonAccumulationState
|
||||
from modules.services.serviceBilling.mainServiceBilling import (
|
||||
getService as getBillingService,
|
||||
InsufficientBalanceException,
|
||||
ProviderNotAllowedException,
|
||||
BillingContextError
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -83,14 +89,325 @@ class AiService:
|
|||
async def callAi(self, request: AiCallRequest, progressCallback=None):
|
||||
"""Router: handles content parts via extractionService, text context via interface.
|
||||
|
||||
Replaces direct calls to self.aiObjects.call() to route content parts processing
|
||||
through serviceExtraction layer.
|
||||
FAIL-SAFE BILLING at the source:
|
||||
1. Pre-flight check: validates billing context is complete (RAISES if not)
|
||||
2. Balance & provider check before AI call
|
||||
3. billingCallback on aiObjects: records one billing transaction per model call
|
||||
with exact provider + model name (set before AI call, invoked by _callWithModel)
|
||||
"""
|
||||
if hasattr(request, 'contentParts') and request.contentParts:
|
||||
return await self.extractionService.processContentPartsWithAi(
|
||||
request, self.aiObjects, progressCallback
|
||||
# FAIL-SAFE: Pre-flight billing validation (like 0 CHF credit card check)
|
||||
self._preflightBillingCheck()
|
||||
|
||||
# Balance & provider permission checks
|
||||
await self._checkBillingBeforeAiCall()
|
||||
|
||||
# Calculate effective allowedProviders: RBAC ∩ Workflow
|
||||
effectiveProviders = self._calculateEffectiveProviders()
|
||||
if effectiveProviders and request.options:
|
||||
request.options = request.options.model_copy(update={'allowedProviders': effectiveProviders})
|
||||
logger.debug(f"Effective allowedProviders for AI request: {effectiveProviders}")
|
||||
|
||||
# Set billing callback on aiObjects BEFORE the AI call
|
||||
# This callback is invoked by _callWithModel() after EVERY individual model call
|
||||
# For parallel content parts (e.g., 200 MB doc), each model call creates its own transaction
|
||||
self.aiObjects.billingCallback = self._createBillingCallback()
|
||||
|
||||
try:
|
||||
if hasattr(request, 'contentParts') and request.contentParts:
|
||||
response = await self.extractionService.processContentPartsWithAi(
|
||||
request, self.aiObjects, progressCallback
|
||||
)
|
||||
else:
|
||||
response = await self.aiObjects.callWithTextContext(request)
|
||||
finally:
|
||||
# Clear callback after call completes
|
||||
self.aiObjects.billingCallback = None
|
||||
|
||||
# Store workflow stats for analytics
|
||||
self._storeAiCallStats(response, request)
|
||||
|
||||
return response
|
||||
|
||||
def _preflightBillingCheck(self) -> None:
|
||||
"""
|
||||
Pre-flight billing validation - like a 0 CHF credit card authorization check.
|
||||
|
||||
Validates that ALL required billing context is present and that a billing
|
||||
transaction CAN be recorded. This dry-run check catches missing context
|
||||
BEFORE an expensive AI call starts.
|
||||
|
||||
FAIL-SAFE: This method RAISES if billing context is incomplete.
|
||||
An AI call without billing context MUST NOT proceed.
|
||||
|
||||
Raises:
|
||||
BillingContextError: If billing context is incomplete or invalid
|
||||
"""
|
||||
if not self.services:
|
||||
raise BillingContextError("No service context available - cannot bill AI call")
|
||||
|
||||
user = getattr(self.services, 'user', None)
|
||||
if not user:
|
||||
raise BillingContextError("No user context - cannot bill AI call")
|
||||
|
||||
mandateId = getattr(self.services, 'mandateId', None)
|
||||
if not mandateId:
|
||||
raise BillingContextError(
|
||||
f"No mandateId in service context for user {user.id} - cannot bill AI call. "
|
||||
"Every AI call MUST have a mandate context for billing."
|
||||
)
|
||||
return await self.aiObjects.callWithTextContext(request)
|
||||
|
||||
# Validate billing service can be created
|
||||
featureInstanceId = getattr(self.services, 'featureInstanceId', None)
|
||||
featureCode = getattr(self.services, 'featureCode', None)
|
||||
|
||||
try:
|
||||
billingService = getBillingService(user, mandateId, featureInstanceId, featureCode)
|
||||
except Exception as e:
|
||||
raise BillingContextError(
|
||||
f"Cannot create billing service for user {user.id}, mandate {mandateId}: {e}"
|
||||
)
|
||||
|
||||
# Dry-run: verify billing service can check balance (DB accessible)
|
||||
try:
|
||||
billingService.checkBalance(0.0)
|
||||
except Exception as e:
|
||||
raise BillingContextError(
|
||||
f"Billing system not accessible for mandate {mandateId}: {e}"
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Pre-flight billing check PASSED: user={user.id}, mandate={mandateId}, "
|
||||
f"feature={featureCode or 'none'}, instance={featureInstanceId or 'none'}"
|
||||
)
|
||||
|
||||
async def _checkBillingBeforeAiCall(self) -> None:
|
||||
"""
|
||||
Check billing status before making an AI call.
|
||||
|
||||
FAIL-SAFE: Context validation is done in _preflightBillingCheck() which is
|
||||
called first. This method handles balance and provider permission checks.
|
||||
|
||||
Verifies:
|
||||
1. User has sufficient balance (for prepay models)
|
||||
2. Provider is allowed for the user (via RBAC)
|
||||
|
||||
Raises:
|
||||
InsufficientBalanceException: If balance is insufficient
|
||||
ProviderNotAllowedException: If provider is not allowed
|
||||
BillingContextError: If billing check fails unexpectedly
|
||||
"""
|
||||
# Context is already validated by _preflightBillingCheck()
|
||||
user = self.services.user
|
||||
mandateId = self.services.mandateId
|
||||
featureInstanceId = getattr(self.services, 'featureInstanceId', None)
|
||||
featureCode = getattr(self.services, 'featureCode', None)
|
||||
|
||||
try:
|
||||
# Get billing service
|
||||
billingService = getBillingService(user, mandateId, featureInstanceId, featureCode)
|
||||
|
||||
# Check balance (estimate typical AI call cost)
|
||||
estimatedCost = 0.01 # ~1 cent CHF minimum
|
||||
balanceCheck = billingService.checkBalance(estimatedCost)
|
||||
|
||||
if not balanceCheck.allowed:
|
||||
logger.warning(
|
||||
f"Billing check failed for user {user.id}: "
|
||||
f"Balance {balanceCheck.currentBalance:.2f} CHF, "
|
||||
f"Reason: {balanceCheck.reason}"
|
||||
)
|
||||
raise InsufficientBalanceException(
|
||||
currentBalance=balanceCheck.currentBalance or 0.0,
|
||||
requiredAmount=estimatedCost,
|
||||
message=f"Ungenugendes Guthaben. Aktuell: CHF {balanceCheck.currentBalance:.2f}"
|
||||
)
|
||||
|
||||
logger.debug(f"Billing check passed: Balance {balanceCheck.currentBalance:.2f} CHF")
|
||||
|
||||
# Check if at least one provider is allowed (RBAC check)
|
||||
rbacAllowedProviders = billingService.getallowedProviders()
|
||||
if not rbacAllowedProviders:
|
||||
logger.warning(f"No AI providers allowed for user {user.id} in mandate {mandateId}")
|
||||
raise ProviderNotAllowedException(
|
||||
provider="any",
|
||||
message="Keine AI-Provider fuer Ihre Rolle freigegeben. Kontaktieren Sie Ihren Administrator."
|
||||
)
|
||||
|
||||
# Check automation-level allowedProviders restriction
|
||||
automationAllowedProviders = getattr(self.services, 'allowedProviders', None)
|
||||
if automationAllowedProviders:
|
||||
effectiveProviders = [p for p in automationAllowedProviders if p in rbacAllowedProviders]
|
||||
if not effectiveProviders:
|
||||
logger.warning(f"No providers available after automation restriction. "
|
||||
f"Automation allows: {automationAllowedProviders}, "
|
||||
f"RBAC allows: {rbacAllowedProviders}")
|
||||
raise ProviderNotAllowedException(
|
||||
provider="any",
|
||||
message="Die konfigurierten AI-Provider dieser Automation sind fuer Ihre Rolle nicht freigegeben."
|
||||
)
|
||||
logger.debug(f"Automation provider check passed: {effectiveProviders}")
|
||||
|
||||
# Check if preferred providers (from UI multiselect) are allowed
|
||||
preferredProviders = getattr(self.services, 'preferredProviders', None)
|
||||
if preferredProviders:
|
||||
for provider in preferredProviders:
|
||||
if provider not in rbacAllowedProviders:
|
||||
logger.warning(f"Preferred provider {provider} not allowed for user {user.id}")
|
||||
raise ProviderNotAllowedException(
|
||||
provider=provider,
|
||||
message=f"Der gewaehlte Provider '{provider}' ist fuer Ihre Rolle nicht freigegeben."
|
||||
)
|
||||
logger.debug(f"All preferred providers are allowed: {preferredProviders}")
|
||||
|
||||
logger.debug(f"Provider check passed: {len(rbacAllowedProviders)} providers allowed")
|
||||
|
||||
except InsufficientBalanceException:
|
||||
raise
|
||||
except ProviderNotAllowedException:
|
||||
raise
|
||||
except BillingContextError:
|
||||
raise
|
||||
except Exception as e:
|
||||
# FAIL-SAFE: Don't silently swallow errors - log at ERROR level
|
||||
logger.error(f"BILLING FAIL-SAFE: Billing check failed with unexpected error: {e}")
|
||||
raise BillingContextError(f"Billing check failed: {e}")
|
||||
|
||||
def _createBillingCallback(self):
|
||||
"""
|
||||
Create a billing callback for interfaceAiObjects._callWithModel().
|
||||
|
||||
Returns a function that records one billing transaction per individual model call.
|
||||
Each transaction contains the exact provider name AND model name.
|
||||
|
||||
For a 200 MB document processed with N parallel AI calls (possibly different models),
|
||||
this creates N separate billing transactions - one per model call.
|
||||
"""
|
||||
user = self.services.user
|
||||
mandateId = self.services.mandateId
|
||||
featureInstanceId = getattr(self.services, 'featureInstanceId', None)
|
||||
featureCode = getattr(self.services, 'featureCode', None)
|
||||
|
||||
# Get workflow ID if available
|
||||
workflowId = None
|
||||
workflow = getattr(self.services, 'workflow', None)
|
||||
if workflow and hasattr(workflow, 'id'):
|
||||
workflowId = workflow.id
|
||||
|
||||
billingService = getBillingService(user, mandateId, featureInstanceId, featureCode)
|
||||
|
||||
def _billingCallback(response) -> None:
|
||||
"""Record billing for a single AI model call."""
|
||||
if not response or getattr(response, 'errorCount', 0) > 0:
|
||||
return
|
||||
|
||||
priceCHF = getattr(response, 'priceCHF', 0.0)
|
||||
if not priceCHF or priceCHF <= 0:
|
||||
return
|
||||
|
||||
provider = getattr(response, 'provider', None) or 'unknown'
|
||||
modelName = getattr(response, 'modelName', None) or 'unknown'
|
||||
|
||||
try:
|
||||
billingService.recordUsage(
|
||||
priceCHF=priceCHF,
|
||||
workflowId=workflowId,
|
||||
aicoreProvider=provider,
|
||||
aicoreModel=modelName,
|
||||
description=f"AI: {modelName}"
|
||||
)
|
||||
logger.debug(
|
||||
f"Billed model call: {priceCHF:.4f} CHF, "
|
||||
f"provider={provider}, model={modelName}, mandate={mandateId}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"BILLING: Failed to record transaction! "
|
||||
f"Cost={priceCHF:.4f} CHF, user={user.id}, mandate={mandateId}, "
|
||||
f"provider={provider}, model={modelName}, error={e}"
|
||||
)
|
||||
|
||||
return _billingCallback
|
||||
|
||||
def _calculateEffectiveProviders(self) -> Optional[List[str]]:
|
||||
"""
|
||||
Calculate effective allowed providers: RBAC ∩ Workflow.
|
||||
|
||||
RBAC is master - only RBAC-permitted providers can ever be used.
|
||||
If workflow specifies allowedProviders, intersect with RBAC.
|
||||
If no workflow providers, use all RBAC-permitted providers.
|
||||
|
||||
Returns:
|
||||
List of effective allowed providers, or None if no filtering needed
|
||||
"""
|
||||
try:
|
||||
user = getattr(self.services, 'user', None)
|
||||
mandateId = getattr(self.services, 'mandateId', None)
|
||||
|
||||
if not user or not mandateId:
|
||||
return None
|
||||
|
||||
# Get RBAC-permitted providers (master list)
|
||||
# Note: getBillingService is imported at module level from mainServiceBilling
|
||||
featureInstanceId = getattr(self.services, 'featureInstanceId', None)
|
||||
featureCode = getattr(self.services, 'featureCode', None)
|
||||
billingService = getBillingService(user, mandateId, featureInstanceId, featureCode)
|
||||
rbacProviders = billingService.getallowedProviders()
|
||||
|
||||
if not rbacProviders:
|
||||
logger.warning("No RBAC-permitted providers found")
|
||||
return None
|
||||
|
||||
# Get workflow-specified providers (optional filter)
|
||||
workflowProviders = getattr(self.services, 'allowedProviders', None)
|
||||
|
||||
if workflowProviders:
|
||||
# Intersect: only providers that are both RBAC-permitted AND workflow-allowed
|
||||
effectiveProviders = [p for p in workflowProviders if p in rbacProviders]
|
||||
logger.debug(f"Provider filter: RBAC={rbacProviders}, Workflow={workflowProviders}, Effective={effectiveProviders}")
|
||||
else:
|
||||
# No workflow filter - use all RBAC-permitted providers
|
||||
effectiveProviders = rbacProviders
|
||||
logger.debug(f"Provider filter: RBAC={rbacProviders} (no workflow filter)")
|
||||
|
||||
return effectiveProviders if effectiveProviders else None
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error calculating effective providers: {e}")
|
||||
return None
|
||||
|
||||
def _storeAiCallStats(self, response, request: AiCallRequest) -> None:
|
||||
"""Store workflow stats after an AI call.
|
||||
|
||||
This method stores the AI call statistics (cost, processing time, bytes)
|
||||
to the workflow stats collection for tracking and billing purposes.
|
||||
|
||||
Args:
|
||||
response: AiCallResponse with cost/timing data
|
||||
request: Original AiCallRequest for context
|
||||
"""
|
||||
try:
|
||||
# Skip if no workflow context
|
||||
workflow = getattr(self.services, 'workflow', None)
|
||||
if not workflow or not hasattr(workflow, 'id') or not workflow.id:
|
||||
logger.debug("No workflow context - skipping stats storage")
|
||||
return
|
||||
|
||||
# Skip if response is an error
|
||||
if not response or getattr(response, 'errorCount', 0) > 0:
|
||||
logger.debug("Error response - skipping stats storage")
|
||||
return
|
||||
|
||||
# Determine process name from operation type
|
||||
opType = getattr(request.options, 'operationType', 'unknown') if request.options else 'unknown'
|
||||
process = f"ai.call.{opType}"
|
||||
|
||||
# Store the stat
|
||||
self.services.chat.storeWorkflowStat(workflow, response, process)
|
||||
logger.debug(f"Stored AI call stat: {process}, cost={getattr(response, 'priceCHF', 0):.4f} CHF")
|
||||
|
||||
except Exception as e:
|
||||
# Log but don't fail - stats storage is not critical
|
||||
logger.debug(f"Could not store AI call stat: {str(e)}")
|
||||
|
||||
async def ensureAiObjectsInitialized(self):
|
||||
"""Ensure aiObjects is initialized and submodules are ready."""
|
||||
|
|
@ -314,7 +631,7 @@ Respond with ONLY a JSON object in this exact format:
|
|||
# Debug: persist prompt/response for analysis with context-specific naming
|
||||
debugPrefix = debugType if debugType else "plan"
|
||||
self.services.utils.writeDebugFile(fullPrompt, f"{debugPrefix}_prompt")
|
||||
response = await self.aiObjects.callWithTextContext(request)
|
||||
response = await self.callAi(request) # Use callAi to ensure stats are stored
|
||||
result = response.content or ""
|
||||
self.services.utils.writeDebugFile(result, f"{debugPrefix}_response")
|
||||
return result
|
||||
|
|
@ -371,16 +688,7 @@ Respond with ONLY a JSON object in this exact format:
|
|||
operationType=opType.value
|
||||
)
|
||||
|
||||
# Try to store workflow stats, but don't fail if workflow is None (e.g., in chatbot context)
|
||||
try:
|
||||
self.services.chat.storeWorkflowStat(
|
||||
self.services.workflow,
|
||||
response,
|
||||
f"ai.{opType.name.lower()}"
|
||||
)
|
||||
except Exception as e:
|
||||
# Log but don't fail - workflow might be None in some contexts (e.g., chatbot)
|
||||
logger.debug(f"Could not store workflow stat (workflow may be None): {str(e)}")
|
||||
# Note: Stats are now stored centrally in callAi() - no need to duplicate here
|
||||
|
||||
self.services.chat.progressLogUpdate(aiOperationId, 0.9, f"{opType.name} completed")
|
||||
self.services.chat.progressLogFinish(aiOperationId, True)
|
||||
|
|
|
|||
|
|
@ -269,17 +269,7 @@ class AiCallLooper:
|
|||
# Document generation - save all iteration responses
|
||||
self.services.utils.writeDebugFile(result, f"{debugPrefix}_response_iteration_{iteration}")
|
||||
|
||||
# Emit stats for this iteration (only if workflow exists and has id)
|
||||
if self.services.workflow and hasattr(self.services.workflow, 'id') and self.services.workflow.id:
|
||||
try:
|
||||
self.services.chat.storeWorkflowStat(
|
||||
self.services.workflow,
|
||||
response,
|
||||
f"ai.call.{debugPrefix}.iteration_{iteration}"
|
||||
)
|
||||
except Exception as statError:
|
||||
# Don't break the main loop if stat storage fails
|
||||
logger.warning(f"Failed to store workflow stat: {str(statError)}")
|
||||
# Note: Stats are now stored centrally in callAi() - no need to duplicate here
|
||||
|
||||
# Check for error response using generic error detection (errorCount > 0 or modelName == "error")
|
||||
if hasattr(response, 'errorCount') and response.errorCount > 0:
|
||||
|
|
|
|||
|
|
@ -2574,8 +2574,8 @@ CRITICAL:
|
|||
"""
|
||||
from modules.services.serviceGeneration.renderers.registry import getRenderer
|
||||
|
||||
# Get renderer for this format - NO FALLBACK
|
||||
renderer = getRenderer(outputFormat, self.services)
|
||||
# Get document renderer for this format (structure filling is document generation path)
|
||||
renderer = getRenderer(outputFormat, self.services, outputStyle='document')
|
||||
|
||||
if not renderer:
|
||||
raise ValueError(f"No renderer found for output format '{outputFormat}'. Check renderer registry.")
|
||||
|
|
|
|||
7
modules/services/serviceBilling/__init__.py
Normal file
7
modules/services/serviceBilling/__init__.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
# Copyright (c) 2025 Patrick Motsch
|
||||
# All rights reserved.
|
||||
"""Billing service module."""
|
||||
|
||||
from .mainServiceBilling import BillingService, getService
|
||||
|
||||
__all__ = ["BillingService", "getService"]
|
||||
417
modules/services/serviceBilling/mainServiceBilling.py
Normal file
417
modules/services/serviceBilling/mainServiceBilling.py
Normal file
|
|
@ -0,0 +1,417 @@
|
|||
# Copyright (c) 2025 Patrick Motsch
|
||||
# All rights reserved.
|
||||
"""
|
||||
Billing Service - Central service for billing operations.
|
||||
|
||||
Handles:
|
||||
- Balance checks before AI operations
|
||||
- Cost recording after AI operations
|
||||
- Provider permission checks via RBAC
|
||||
- Price calculation with markup
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from modules.datamodels.datamodelUam import User
|
||||
from modules.datamodels.datamodelBilling import (
|
||||
BillingModelEnum,
|
||||
BillingCheckResult,
|
||||
TransactionTypeEnum,
|
||||
ReferenceTypeEnum,
|
||||
BillingTransaction,
|
||||
BillingBalanceResponse,
|
||||
)
|
||||
from modules.interfaces.interfaceDbBilling import getInterface as getBillingInterface
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Markup percentage for internal pricing (+50% für Infrastruktur und Platform Service + 50% für Währungsrisiko ==> Faktor 2.0)
|
||||
BILLING_MARKUP_PERCENT = 100
|
||||
|
||||
# Singleton cache
|
||||
_billingServices: Dict[str, "BillingService"] = {}
|
||||
|
||||
|
||||
def getService(currentUser: User, mandateId: str, featureInstanceId: str = None, featureCode: str = None) -> "BillingService":
|
||||
"""
|
||||
Factory function to get or create a BillingService instance.
|
||||
|
||||
Args:
|
||||
currentUser: Current user object
|
||||
mandateId: Mandate ID for context
|
||||
featureInstanceId: Optional feature instance ID
|
||||
featureCode: Optional feature code (e.g., 'chatplayground', 'automation')
|
||||
|
||||
Returns:
|
||||
BillingService instance
|
||||
"""
|
||||
cacheKey = f"{currentUser.id}_{mandateId}_{featureInstanceId}"
|
||||
|
||||
if cacheKey not in _billingServices:
|
||||
_billingServices[cacheKey] = BillingService(currentUser, mandateId, featureInstanceId, featureCode)
|
||||
else:
|
||||
_billingServices[cacheKey].setContext(currentUser, mandateId, featureInstanceId, featureCode)
|
||||
|
||||
return _billingServices[cacheKey]
|
||||
|
||||
|
||||
class BillingService:
|
||||
"""
|
||||
Central billing service for AI operations.
|
||||
|
||||
Responsibilities:
|
||||
- Check balance before operations
|
||||
- Record usage costs
|
||||
- Apply pricing markup
|
||||
- Check provider permissions via RBAC
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
currentUser: User,
|
||||
mandateId: str,
|
||||
featureInstanceId: str = None,
|
||||
featureCode: str = None
|
||||
):
|
||||
"""
|
||||
Initialize the billing service.
|
||||
|
||||
Args:
|
||||
currentUser: Current user object
|
||||
mandateId: Mandate ID
|
||||
featureInstanceId: Optional feature instance ID
|
||||
featureCode: Optional feature code
|
||||
"""
|
||||
self.currentUser = currentUser
|
||||
self.mandateId = mandateId
|
||||
self.featureInstanceId = featureInstanceId
|
||||
self.featureCode = featureCode
|
||||
|
||||
# Get billing interface
|
||||
self._billingInterface = getBillingInterface(currentUser, mandateId)
|
||||
|
||||
# Cache settings
|
||||
self._settingsCache = None
|
||||
|
||||
def setContext(
|
||||
self,
|
||||
currentUser: User,
|
||||
mandateId: str,
|
||||
featureInstanceId: str = None,
|
||||
featureCode: str = None
|
||||
):
|
||||
"""Update service context."""
|
||||
self.currentUser = currentUser
|
||||
self.mandateId = mandateId
|
||||
self.featureInstanceId = featureInstanceId
|
||||
self.featureCode = featureCode
|
||||
self._billingInterface = getBillingInterface(currentUser, mandateId)
|
||||
self._settingsCache = None
|
||||
|
||||
def _getSettings(self) -> Optional[Dict[str, Any]]:
|
||||
"""Get billing settings with caching."""
|
||||
if self._settingsCache is None:
|
||||
self._settingsCache = self._billingInterface.getSettings(self.mandateId)
|
||||
return self._settingsCache
|
||||
|
||||
# =========================================================================
|
||||
# Price Calculation
|
||||
# =========================================================================
|
||||
|
||||
def calculatePriceWithMarkup(self, basePriceCHF: float) -> float:
|
||||
"""
|
||||
Calculate final price with markup.
|
||||
|
||||
The AICore plugins return prices in their original currency (USD).
|
||||
This method applies the configured markup percentage.
|
||||
|
||||
Args:
|
||||
basePriceCHF: Base price from AI model (actually USD from provider)
|
||||
|
||||
Returns:
|
||||
Final price in CHF with markup applied
|
||||
"""
|
||||
if basePriceCHF <= 0:
|
||||
return 0.0
|
||||
|
||||
# Apply markup (50% = multiply by 1.5)
|
||||
markup_multiplier = 1 + (BILLING_MARKUP_PERCENT / 100)
|
||||
return round(basePriceCHF * markup_multiplier, 6)
|
||||
|
||||
# =========================================================================
|
||||
# Balance Operations
|
||||
# =========================================================================
|
||||
|
||||
def checkBalance(self, estimatedCost: float = 0.0) -> BillingCheckResult:
|
||||
"""
|
||||
Check if the current user/mandate has sufficient balance.
|
||||
|
||||
Args:
|
||||
estimatedCost: Estimated cost of the operation (with markup applied)
|
||||
|
||||
Returns:
|
||||
BillingCheckResult indicating if operation is allowed
|
||||
"""
|
||||
return self._billingInterface.checkBalance(
|
||||
self.mandateId,
|
||||
self.currentUser.id,
|
||||
estimatedCost
|
||||
)
|
||||
|
||||
def hasBalance(self, estimatedCost: float = 0.0) -> bool:
|
||||
"""
|
||||
Quick check if balance is sufficient.
|
||||
|
||||
Args:
|
||||
estimatedCost: Estimated cost with markup
|
||||
|
||||
Returns:
|
||||
True if operation is allowed
|
||||
"""
|
||||
result = self.checkBalance(estimatedCost)
|
||||
return result.allowed
|
||||
|
||||
def getCurrentBalance(self) -> float:
|
||||
"""
|
||||
Get current balance for the user/mandate.
|
||||
|
||||
Returns:
|
||||
Current balance in CHF
|
||||
"""
|
||||
result = self.checkBalance(0.0)
|
||||
return result.currentBalance or 0.0
|
||||
|
||||
# =========================================================================
|
||||
# Usage Recording
|
||||
# =========================================================================
|
||||
|
||||
def recordUsage(
|
||||
self,
|
||||
priceCHF: float,
|
||||
workflowId: str = None,
|
||||
aicoreProvider: str = None,
|
||||
aicoreModel: str = None,
|
||||
description: str = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Record AI usage cost as a billing transaction.
|
||||
|
||||
This method:
|
||||
1. Applies the pricing markup
|
||||
2. Creates a DEBIT transaction
|
||||
3. Updates the account balance
|
||||
|
||||
Args:
|
||||
priceCHF: Base price from AI model (before markup)
|
||||
workflowId: Optional workflow ID
|
||||
aicoreProvider: AICore provider name (e.g., 'anthropic', 'openai')
|
||||
aicoreModel: AICore model name (e.g., 'claude-4-sonnet', 'gpt-4o')
|
||||
description: Optional description
|
||||
|
||||
Returns:
|
||||
Created transaction dict or None if not recorded
|
||||
"""
|
||||
if priceCHF <= 0:
|
||||
return None
|
||||
|
||||
# Apply markup
|
||||
finalPrice = self.calculatePriceWithMarkup(priceCHF)
|
||||
|
||||
if finalPrice <= 0:
|
||||
return None
|
||||
|
||||
# Build description
|
||||
if not description:
|
||||
description = f"AI Usage: {aicoreModel or aicoreProvider or 'unknown'}"
|
||||
|
||||
return self._billingInterface.recordUsage(
|
||||
mandateId=self.mandateId,
|
||||
userId=self.currentUser.id,
|
||||
priceCHF=finalPrice,
|
||||
workflowId=workflowId,
|
||||
featureInstanceId=self.featureInstanceId,
|
||||
featureCode=self.featureCode,
|
||||
aicoreProvider=aicoreProvider,
|
||||
aicoreModel=aicoreModel,
|
||||
description=description
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# Provider Permission Check (via RBAC)
|
||||
# =========================================================================
|
||||
|
||||
def isProviderAllowed(self, provider: str) -> bool:
|
||||
"""
|
||||
Check if the user has permission to use an AICore provider.
|
||||
|
||||
Uses RBAC to check for resource permission:
|
||||
resource.aicore.{provider}
|
||||
|
||||
Args:
|
||||
provider: Provider name (e.g., 'anthropic', 'openai')
|
||||
|
||||
Returns:
|
||||
True if provider is allowed
|
||||
"""
|
||||
try:
|
||||
from modules.security.rbac import RbacClass
|
||||
from modules.datamodels.datamodelRbac import AccessRuleContext
|
||||
from modules.security.rootAccess import getRootDbAppConnector
|
||||
|
||||
# Get database connector via established pattern
|
||||
dbApp = getRootDbAppConnector()
|
||||
|
||||
rbac = RbacClass(dbApp, dbApp)
|
||||
resourceKey = f"resource.aicore.{provider}"
|
||||
|
||||
# Check if user has view permission for this resource (view = use for RESOURCE context)
|
||||
permissions = rbac.getUserPermissions(
|
||||
self.currentUser,
|
||||
AccessRuleContext.RESOURCE,
|
||||
resourceKey,
|
||||
mandateId=self.mandateId
|
||||
)
|
||||
|
||||
return permissions.view
|
||||
except Exception as e:
|
||||
logger.warning(f"Error checking provider permission: {e}")
|
||||
# Default to allowed if RBAC check fails
|
||||
return True
|
||||
|
||||
def getallowedProviders(self) -> List[str]:
|
||||
"""
|
||||
Get list of AICore providers the user is allowed to use.
|
||||
|
||||
Returns:
|
||||
List of allowed provider names
|
||||
"""
|
||||
try:
|
||||
from modules.aicore.aicoreModelRegistry import modelRegistry
|
||||
|
||||
# Get all available providers
|
||||
connectors = modelRegistry.discoverConnectors()
|
||||
allProviders = [c.getConnectorType() for c in connectors]
|
||||
|
||||
# Filter by RBAC permissions
|
||||
return [p for p in allProviders if self.isProviderAllowed(p)]
|
||||
except Exception as e:
|
||||
logger.warning(f"Error getting allowed providers: {e}")
|
||||
return []
|
||||
|
||||
# =========================================================================
|
||||
# Admin Operations
|
||||
# =========================================================================
|
||||
|
||||
def addCredit(
|
||||
self,
|
||||
amount: float,
|
||||
description: str = "Manual credit",
|
||||
referenceType: ReferenceTypeEnum = ReferenceTypeEnum.ADMIN
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Add credit to the account (admin operation).
|
||||
|
||||
Args:
|
||||
amount: Amount to credit (positive)
|
||||
description: Transaction description
|
||||
referenceType: Reference type (ADMIN, PAYMENT, SYSTEM)
|
||||
|
||||
Returns:
|
||||
Created transaction dict or None
|
||||
"""
|
||||
if amount <= 0:
|
||||
return None
|
||||
|
||||
settings = self._getSettings()
|
||||
if not settings:
|
||||
logger.warning(f"No billing settings for mandate {self.mandateId}")
|
||||
return None
|
||||
|
||||
billingModel = BillingModelEnum(settings.get("billingModel", BillingModelEnum.UNLIMITED.value))
|
||||
|
||||
# Get or create account
|
||||
if billingModel == BillingModelEnum.PREPAY_USER:
|
||||
account = self._billingInterface.getOrCreateUserAccount(
|
||||
self.mandateId,
|
||||
self.currentUser.id,
|
||||
initialBalance=0.0
|
||||
)
|
||||
else:
|
||||
account = self._billingInterface.getOrCreateMandateAccount(
|
||||
self.mandateId,
|
||||
initialBalance=0.0
|
||||
)
|
||||
|
||||
# Create credit transaction
|
||||
transaction = BillingTransaction(
|
||||
accountId=account["id"],
|
||||
transactionType=TransactionTypeEnum.CREDIT,
|
||||
amount=amount,
|
||||
description=description,
|
||||
referenceType=referenceType
|
||||
)
|
||||
|
||||
return self._billingInterface.createTransaction(transaction)
|
||||
|
||||
# =========================================================================
|
||||
# Statistics & Reporting
|
||||
# =========================================================================
|
||||
|
||||
def getBalancesForUser(self) -> List[BillingBalanceResponse]:
|
||||
"""
|
||||
Get all billing balances for the current user.
|
||||
|
||||
Returns:
|
||||
List of balance responses for each mandate
|
||||
"""
|
||||
return self._billingInterface.getBalancesForUser(self.currentUser.id)
|
||||
|
||||
def getTransactionHistory(self, limit: int = 100) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get transaction history for the user across all mandates.
|
||||
|
||||
Args:
|
||||
limit: Maximum number of transactions
|
||||
|
||||
Returns:
|
||||
List of transactions
|
||||
"""
|
||||
return self._billingInterface.getTransactionsForUser(self.currentUser.id, limit=limit)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Exception Classes
|
||||
# ============================================================================
|
||||
|
||||
class InsufficientBalanceException(Exception):
|
||||
"""Raised when there's insufficient balance for an operation."""
|
||||
|
||||
def __init__(self, currentBalance: float, requiredAmount: float, message: str = None):
|
||||
self.currentBalance = currentBalance
|
||||
self.requiredAmount = requiredAmount
|
||||
self.message = message or f"Insufficient balance. Current: {currentBalance:.2f} CHF, Required: {requiredAmount:.2f} CHF"
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
class ProviderNotAllowedException(Exception):
|
||||
"""Raised when a user doesn't have permission to use an AI provider."""
|
||||
|
||||
def __init__(self, provider: str, message: str = None):
|
||||
self.provider = provider
|
||||
self.message = message or f"Provider '{provider}' is not allowed for your role"
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
class BillingContextError(Exception):
|
||||
"""Raised when billing context is incomplete (missing mandateId, user, etc.).
|
||||
|
||||
This is a FAIL-SAFE error: AI calls MUST NOT proceed without valid billing context.
|
||||
Acts like a 0 CHF credit card pre-authorization check - validates that billing
|
||||
CAN be recorded before any expensive AI operation starts.
|
||||
"""
|
||||
|
||||
def __init__(self, message: str = None):
|
||||
self.message = message or "Billing context incomplete - AI call blocked"
|
||||
super().__init__(self.message)
|
||||
|
|
@ -674,24 +674,25 @@ class ChatService:
|
|||
return chatLog
|
||||
|
||||
def storeWorkflowStat(self, workflow: Any, aiResponse: Any, process: str) -> ChatStat:
|
||||
"""Persist workflow-level ChatStat from AiCallResponse and append to workflow stats list."""
|
||||
"""Persist workflow-level ChatStat from AiCallResponse and append to workflow stats list.
|
||||
|
||||
Billing is handled at the AI call source (interfaceAiObjects._callWithModel)
|
||||
via billingCallback - not here. This method only handles workflow stats.
|
||||
"""
|
||||
try:
|
||||
# Create ChatStat from AiCallResponse data
|
||||
statData = {
|
||||
"workflowId": workflow.id,
|
||||
"process": process,
|
||||
"engine": aiResponse.modelName,
|
||||
"priceUsd": aiResponse.priceUsd,
|
||||
"priceCHF": aiResponse.priceCHF,
|
||||
"processingTime": aiResponse.processingTime,
|
||||
"bytesSent": aiResponse.bytesSent,
|
||||
"bytesReceived": aiResponse.bytesReceived,
|
||||
"errorCount": aiResponse.errorCount
|
||||
}
|
||||
|
||||
# Create the stat record in the database
|
||||
stat = self.interfaceDbChat.createStat(statData)
|
||||
|
||||
# Append to workflow stats list in memory
|
||||
if not hasattr(workflow, 'stats') or workflow.stats is None:
|
||||
workflow.stats = []
|
||||
workflow.stats.append(stat)
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ class ExtractionService:
|
|||
# Verify required internal model is available (used for pricing in extractContent)
|
||||
modelDisplayName = "Internal Document Extractor"
|
||||
model = modelRegistry.getModel(modelDisplayName)
|
||||
if model is None or model.calculatePriceUsd is None:
|
||||
if model is None or model.calculatepriceCHF is None:
|
||||
raise RuntimeError(f"FATAL: Required internal model '{modelDisplayName}' is not available. Check connector registration.")
|
||||
|
||||
def extractContent(
|
||||
|
|
@ -218,18 +218,19 @@ class ExtractionService:
|
|||
modelDisplayName = "Internal Document Extractor"
|
||||
model = modelRegistry.getModel(modelDisplayName)
|
||||
# Hard fail if model is missing; caller must ensure connectors are registered
|
||||
if model is None or model.calculatePriceUsd is None:
|
||||
if model is None or model.calculatepriceCHF is None:
|
||||
if docOperationId:
|
||||
self.services.chat.progressLogFinish(docOperationId, False)
|
||||
raise RuntimeError(f"Pricing model not available: {modelDisplayName}")
|
||||
priceUsd = model.calculatePriceUsd(processingTime, bytesSent, bytesReceived)
|
||||
priceCHF = model.calculatepriceCHF(processingTime, bytesSent, bytesReceived)
|
||||
|
||||
# Create AiCallResponse with real calculation
|
||||
# Use model.name for the response (API identifier), not displayName
|
||||
aiResponse = AiCallResponse(
|
||||
content="", # No content for extraction stats needed
|
||||
modelName=model.name,
|
||||
priceUsd=priceUsd,
|
||||
provider=model.connectorType,
|
||||
priceCHF=priceCHF,
|
||||
processingTime=processingTime,
|
||||
bytesSent=bytesSent,
|
||||
bytesReceived=bytesReceived,
|
||||
|
|
@ -478,7 +479,7 @@ class ExtractionService:
|
|||
"resultSize": len(response.content),
|
||||
"typeGroup": part.typeGroup,
|
||||
"modelName": response.modelName,
|
||||
"priceUsd": response.priceUsd
|
||||
"priceCHF": response.priceCHF
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -606,7 +607,7 @@ class ExtractionService:
|
|||
"originalIndex": i, # Phase 7: Explicit order index
|
||||
"processingOrder": i, # Phase 7: Processing order
|
||||
"modelName": result.modelName,
|
||||
"priceUsd": result.priceUsd,
|
||||
"priceCHF": result.priceCHF,
|
||||
"processingTime": result.processingTime,
|
||||
"bytesSent": result.bytesSent,
|
||||
"bytesReceived": result.bytesReceived
|
||||
|
|
@ -1311,7 +1312,8 @@ class ExtractionService:
|
|||
return AiCallResponse(
|
||||
content=modelResponse.content,
|
||||
modelName=model.name,
|
||||
priceUsd=0.0,
|
||||
provider=model.connectorType,
|
||||
priceCHF=0.0,
|
||||
processingTime=processingTime,
|
||||
bytesSent=0,
|
||||
bytesReceived=0,
|
||||
|
|
@ -1416,7 +1418,8 @@ class ExtractionService:
|
|||
return AiCallResponse(
|
||||
content=mergedContent,
|
||||
modelName=model.name,
|
||||
priceUsd=sum(r.priceUsd for r in chunkResults),
|
||||
provider=model.connectorType,
|
||||
priceCHF=sum(r.priceCHF for r in chunkResults),
|
||||
processingTime=sum(r.processingTime for r in chunkResults),
|
||||
bytesSent=sum(r.bytesSent for r in chunkResults),
|
||||
bytesReceived=sum(r.bytesReceived for r in chunkResults),
|
||||
|
|
@ -1428,49 +1431,6 @@ class ExtractionService:
|
|||
response = await aiObjects._callWithModel(model, prompt, contentPart.data, options)
|
||||
logger.info(f"✅ Content part processed successfully with model: {model.name}")
|
||||
return response
|
||||
chunks = await self.chunkContentPartForAi(contentPart, model, options, prompt)
|
||||
if not chunks:
|
||||
raise ValueError(f"Failed to chunk content part for model {model.name}")
|
||||
|
||||
logger.info(f"Starting to process {len(chunks)} chunks with model {model.name}")
|
||||
|
||||
if progressCallback:
|
||||
progressCallback(0.0, f"Starting to process {len(chunks)} chunks")
|
||||
|
||||
chunkResults = []
|
||||
for idx, chunk in enumerate(chunks):
|
||||
chunkNum = idx + 1
|
||||
chunkData = chunk.get('data', '')
|
||||
logger.info(f"Processing chunk {chunkNum}/{len(chunks)} with model {model.name}")
|
||||
|
||||
if progressCallback:
|
||||
progressCallback(chunkNum / len(chunks), f"Processing chunk {chunkNum}/{len(chunks)}")
|
||||
|
||||
try:
|
||||
chunkResponse = await aiObjects._callWithModel(model, prompt, chunkData, options)
|
||||
chunkResults.append(chunkResponse)
|
||||
logger.info(f"✅ Chunk {chunkNum}/{len(chunks)} processed successfully")
|
||||
|
||||
if progressCallback:
|
||||
progressCallback(chunkNum / len(chunks), f"Chunk {chunkNum}/{len(chunks)} processed")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error processing chunk {chunkNum}/{len(chunks)}: {str(e)}")
|
||||
raise
|
||||
|
||||
# Merge chunk results using unified mergePartResults
|
||||
# Pass original contentPart to preserve typeGroup for all chunks (one-to-many: 1 part -> N chunks)
|
||||
mergedContent = self.mergePartResults(chunkResults, options, [contentPart])
|
||||
|
||||
logger.info(f"✅ Content part chunked and processed with model: {model.name} ({len(chunks)} chunks)")
|
||||
return AiCallResponse(
|
||||
content=mergedContent,
|
||||
modelName=model.name,
|
||||
priceUsd=sum(r.priceUsd for r in chunkResults),
|
||||
processingTime=sum(r.processingTime for r in chunkResults),
|
||||
bytesSent=sum(r.bytesSent for r in chunkResults),
|
||||
bytesReceived=sum(r.bytesReceived for r in chunkResults),
|
||||
errorCount=sum(r.errorCount for r in chunkResults)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
lastError = e
|
||||
|
|
@ -1492,7 +1452,7 @@ class ExtractionService:
|
|||
return AiCallResponse(
|
||||
content=errorMsg,
|
||||
modelName="error",
|
||||
priceUsd=0.0,
|
||||
priceCHF=0.0,
|
||||
processingTime=0.0,
|
||||
bytesSent=inputBytes,
|
||||
bytesReceived=outputBytes,
|
||||
|
|
@ -1622,7 +1582,7 @@ class ExtractionService:
|
|||
return AiCallResponse(
|
||||
content=mergedContent,
|
||||
modelName="multiple",
|
||||
priceUsd=sum(r.priceUsd for r in allResults),
|
||||
priceCHF=sum(r.priceCHF for r in allResults),
|
||||
processingTime=sum(r.processingTime for r in allResults),
|
||||
bytesSent=sum(r.bytesSent for r in allResults),
|
||||
bytesReceived=sum(r.bytesReceived for r in allResults),
|
||||
|
|
|
|||
|
|
@ -74,6 +74,14 @@ class GenerationService:
|
|||
document_data_dict = document_data.dict()
|
||||
elif isinstance(document_data, dict):
|
||||
document_data_dict = document_data
|
||||
elif isinstance(document_data, str):
|
||||
# JSON-String: parsen und als dict speichern (z.B. von outlook.composeAndDraftEmailWithContext)
|
||||
import json
|
||||
try:
|
||||
document_data_dict = json.loads(document_data)
|
||||
except json.JSONDecodeError:
|
||||
# Kein valides JSON - als plain text speichern
|
||||
document_data_dict = {"data": document_data}
|
||||
else:
|
||||
document_data_dict = {"data": str(document_data)}
|
||||
|
||||
|
|
@ -548,10 +556,10 @@ class GenerationService:
|
|||
|
||||
|
||||
def _getFormatRenderer(self, output_format: str):
|
||||
"""Get the appropriate renderer for the specified format using auto-discovery."""
|
||||
"""Get the appropriate document renderer for the specified format."""
|
||||
try:
|
||||
from .renderers.registry import getRenderer, getSupportedFormats
|
||||
renderer = getRenderer(output_format, services=self.services)
|
||||
renderer = getRenderer(output_format, services=self.services, outputStyle='document')
|
||||
|
||||
if renderer:
|
||||
return renderer
|
||||
|
|
@ -565,7 +573,7 @@ class GenerationService:
|
|||
|
||||
# Fallback to text renderer if no specific renderer found
|
||||
logger.warning(f"Falling back to text renderer for format {output_format}")
|
||||
fallbackRenderer = getRenderer('text', services=self.services)
|
||||
fallbackRenderer = getRenderer('text', services=self.services, outputStyle='document')
|
||||
if fallbackRenderer:
|
||||
return fallbackRenderer
|
||||
|
||||
|
|
|
|||
|
|
@ -922,7 +922,7 @@ CRITICAL:
|
|||
"""Get code renderer for file type."""
|
||||
from modules.services.serviceGeneration.renderers.registry import getRenderer
|
||||
|
||||
# Map file types to renderer formats
|
||||
# Map file types to renderer formats (code path)
|
||||
formatMap = {
|
||||
'json': 'json',
|
||||
'csv': 'csv',
|
||||
|
|
@ -931,7 +931,7 @@ CRITICAL:
|
|||
|
||||
rendererFormat = formatMap.get(fileType.lower())
|
||||
if rendererFormat:
|
||||
renderer = getRenderer(rendererFormat, self.services)
|
||||
renderer = getRenderer(rendererFormat, self.services, outputStyle='code')
|
||||
# Check if renderer supports code rendering
|
||||
if renderer and hasattr(renderer, 'renderCodeFiles'):
|
||||
return renderer
|
||||
|
|
|
|||
|
|
@ -101,11 +101,7 @@ class ImageGenerationPath:
|
|||
operationType=OperationTypeEnum.IMAGE_GENERATE.value
|
||||
)
|
||||
|
||||
self.services.chat.storeWorkflowStat(
|
||||
self.services.workflow,
|
||||
response,
|
||||
"ai.generate.image"
|
||||
)
|
||||
# Note: Stats are now stored centrally in callAi() - no need to duplicate here
|
||||
|
||||
self.services.chat.progressLogUpdate(imageOperationId, 0.9, "Image generated")
|
||||
self.services.chat.progressLogFinish(imageOperationId, True)
|
||||
|
|
|
|||
|
|
@ -2,20 +2,30 @@
|
|||
# All rights reserved.
|
||||
"""
|
||||
Renderer registry for automatic discovery and registration of renderers.
|
||||
|
||||
Renderers are indexed by (format, outputStyle) so that document generation
|
||||
and code generation each get the correct renderer for the same format.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import importlib
|
||||
from typing import Dict, Type, List, Optional
|
||||
from typing import Dict, Type, List, Optional, Tuple
|
||||
from .documentRendererBaseTemplate import BaseRenderer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RendererRegistry:
|
||||
"""Registry for automatic renderer discovery and management."""
|
||||
"""Registry for automatic renderer discovery and management.
|
||||
|
||||
Maintains separate renderer mappings per outputStyle ('document', 'code', etc.)
|
||||
so that document-generation and code-generation paths each resolve to the
|
||||
correct renderer, even when both support the same format (e.g. 'csv').
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._renderers: Dict[str, Type[BaseRenderer]] = {}
|
||||
# Key: (formatName, outputStyle) -> rendererClass
|
||||
self._renderers: Dict[Tuple[str, str], Type[BaseRenderer]] = {}
|
||||
self._format_mappings: Dict[str, str] = {}
|
||||
self._discovered = False
|
||||
|
||||
|
|
@ -25,39 +35,27 @@ class RendererRegistry:
|
|||
return
|
||||
|
||||
try:
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Get the directory containing this registry file
|
||||
currentDir = Path(__file__).parent
|
||||
renderersDir = currentDir
|
||||
|
||||
# Get the package name dynamically
|
||||
packageName = __name__.rsplit('.', 1)[0]
|
||||
|
||||
# Scan all Python files in the renderers directory
|
||||
for filePath in renderersDir.glob("*.py"):
|
||||
if filePath.name in ['registry.py', 'documentRendererBaseTemplate.py', '__init__.py']:
|
||||
for filePath in currentDir.glob("*.py"):
|
||||
if filePath.name in ['registry.py', 'documentRendererBaseTemplate.py', 'codeRendererBaseTemplate.py', '__init__.py']:
|
||||
continue
|
||||
|
||||
# Extract module name from filename
|
||||
moduleName = filePath.stem
|
||||
|
||||
try:
|
||||
# Import the module dynamically
|
||||
fullModuleName = f"{packageName}.{moduleName}"
|
||||
module = importlib.import_module(fullModuleName)
|
||||
|
||||
# Look for renderer classes in the module
|
||||
for attrName in dir(module):
|
||||
attr = getattr(module, attrName)
|
||||
if (isinstance(attr, type) and
|
||||
issubclass(attr, BaseRenderer) and
|
||||
attr != BaseRenderer and
|
||||
hasattr(attr, 'getSupportedFormats')):
|
||||
|
||||
# Register the renderer
|
||||
self._registerRendererClass(attr)
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -68,60 +66,75 @@ class RendererRegistry:
|
|||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during renderer discovery: {str(e)}")
|
||||
self._discovered = True # Mark as discovered to avoid repeated attempts
|
||||
self._discovered = True
|
||||
|
||||
def _registerRendererClass(self, rendererClass: Type[BaseRenderer]) -> None:
|
||||
"""Register a renderer class with its supported formats."""
|
||||
"""Register a renderer class keyed by (format, outputStyle)."""
|
||||
try:
|
||||
# Get supported formats from the renderer class
|
||||
supportedFormats = rendererClass.getSupportedFormats()
|
||||
|
||||
# Get priority (default to 0 if not specified)
|
||||
outputStyle = rendererClass.getOutputStyle() if hasattr(rendererClass, 'getOutputStyle') else 'document'
|
||||
priority = rendererClass.getPriority() if hasattr(rendererClass, 'getPriority') else 0
|
||||
|
||||
for formatName in supportedFormats:
|
||||
formatKey = formatName.lower()
|
||||
registryKey = (formatKey, outputStyle)
|
||||
|
||||
# Check if format already registered - use priority to decide
|
||||
if formatKey in self._renderers:
|
||||
existingRenderer = self._renderers[formatKey]
|
||||
if registryKey in self._renderers:
|
||||
existingRenderer = self._renderers[registryKey]
|
||||
existingPriority = existingRenderer.getPriority() if hasattr(existingRenderer, 'getPriority') else 0
|
||||
|
||||
# Only replace if new renderer has higher priority
|
||||
if priority > existingPriority:
|
||||
logger.debug(f"Replacing {existingRenderer.__name__} with {rendererClass.__name__} for format '{formatName}' (priority {priority} > {existingPriority})")
|
||||
self._renderers[formatKey] = rendererClass
|
||||
logger.debug(f"Replacing {existingRenderer.__name__} with {rendererClass.__name__} for ({formatKey}, {outputStyle}) (priority {priority} > {existingPriority})")
|
||||
self._renderers[registryKey] = rendererClass
|
||||
else:
|
||||
logger.debug(f"Keeping {existingRenderer.__name__} for format '{formatName}' (priority {existingPriority} >= {priority})")
|
||||
logger.debug(f"Keeping {existingRenderer.__name__} for ({formatKey}, {outputStyle}) (priority {existingPriority} >= {priority})")
|
||||
else:
|
||||
# Register primary format
|
||||
self._renderers[formatKey] = rendererClass
|
||||
self._renderers[registryKey] = rendererClass
|
||||
|
||||
# Register aliases if any
|
||||
# Register aliases
|
||||
if hasattr(rendererClass, 'getFormatAliases'):
|
||||
aliases = rendererClass.getFormatAliases()
|
||||
for alias in aliases:
|
||||
self._format_mappings[alias.lower()] = formatName.lower()
|
||||
self._format_mappings[alias.lower()] = formatKey
|
||||
|
||||
logger.debug(f"Registered {rendererClass.__name__} for formats: {supportedFormats} (priority: {priority})")
|
||||
logger.debug(f"Registered {rendererClass.__name__} for formats={supportedFormats}, style={outputStyle}, priority={priority}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error registering renderer {rendererClass.__name__}: {str(e)}")
|
||||
|
||||
def getRenderer(self, outputFormat: str, services=None) -> Optional[BaseRenderer]:
|
||||
"""Get a renderer instance for the specified format."""
|
||||
def getRenderer(self, outputFormat: str, services=None, outputStyle: str = None) -> Optional[BaseRenderer]:
|
||||
"""Get a renderer instance for the specified format and style.
|
||||
|
||||
Args:
|
||||
outputFormat: Format name (e.g. 'csv', 'json', 'pdf')
|
||||
services: Services instance passed to renderer constructor
|
||||
outputStyle: 'document' or 'code'. If None, returns the first match
|
||||
with preference: document > code (most callers are document path).
|
||||
"""
|
||||
if not self._discovered:
|
||||
self.discoverRenderers()
|
||||
|
||||
# Normalize format name
|
||||
formatName = outputFormat.lower().strip()
|
||||
|
||||
# Check for aliases first
|
||||
if formatName in self._format_mappings:
|
||||
formatName = self._format_mappings[formatName]
|
||||
|
||||
# Get renderer class
|
||||
rendererClass = self._renderers.get(formatName)
|
||||
rendererClass = None
|
||||
|
||||
if outputStyle:
|
||||
# Exact match by style
|
||||
rendererClass = self._renderers.get((formatName, outputStyle))
|
||||
else:
|
||||
# No style specified — prefer 'document', then 'code', then any
|
||||
for style in ['document', 'code']:
|
||||
rendererClass = self._renderers.get((formatName, style))
|
||||
if rendererClass:
|
||||
break
|
||||
# Fallback: check any registered style
|
||||
if not rendererClass:
|
||||
for key, cls in self._renderers.items():
|
||||
if key[0] == formatName:
|
||||
rendererClass = cls
|
||||
break
|
||||
|
||||
if rendererClass:
|
||||
try:
|
||||
|
|
@ -130,7 +143,7 @@ class RendererRegistry:
|
|||
logger.error(f"Error creating renderer instance for {formatName}: {str(e)}")
|
||||
return None
|
||||
|
||||
logger.warning(f"No renderer found for format: {outputFormat}")
|
||||
logger.warning(f"No renderer found for format={outputFormat}, style={outputStyle}")
|
||||
return None
|
||||
|
||||
def getSupportedFormats(self) -> List[str]:
|
||||
|
|
@ -138,9 +151,11 @@ class RendererRegistry:
|
|||
if not self._discovered:
|
||||
self.discoverRenderers()
|
||||
|
||||
formats = list(self._renderers.keys())
|
||||
formats.extend(self._format_mappings.keys())
|
||||
return sorted(set(formats))
|
||||
formats = set()
|
||||
for (fmt, _style) in self._renderers.keys():
|
||||
formats.add(fmt)
|
||||
formats.update(self._format_mappings.keys())
|
||||
return sorted(formats)
|
||||
|
||||
def getRendererInfo(self) -> Dict[str, Dict[str, str]]:
|
||||
"""Get information about all registered renderers."""
|
||||
|
|
@ -148,10 +163,12 @@ class RendererRegistry:
|
|||
self.discoverRenderers()
|
||||
|
||||
info = {}
|
||||
for formatName, rendererClass in self._renderers.items():
|
||||
info[formatName] = {
|
||||
for (formatName, style), rendererClass in self._renderers.items():
|
||||
key = f"{formatName}:{style}"
|
||||
info[key] = {
|
||||
'class_name': rendererClass.__name__,
|
||||
'module': rendererClass.__module__,
|
||||
'outputStyle': style,
|
||||
'description': getattr(rendererClass, '__doc__', 'No description').strip().split('\n')[0] if rendererClass.__doc__ else 'No description'
|
||||
}
|
||||
|
||||
|
|
@ -160,44 +177,62 @@ class RendererRegistry:
|
|||
def getOutputStyle(self, outputFormat: str) -> Optional[str]:
|
||||
"""
|
||||
Get the output style classification for a given format.
|
||||
Returns: 'code', 'document', 'image', or other (e.g., 'video' for future use)
|
||||
When both 'document' and 'code' renderers exist for a format,
|
||||
returns the default ('document') since this is called during document generation.
|
||||
"""
|
||||
if not self._discovered:
|
||||
self.discoverRenderers()
|
||||
|
||||
# Normalize format name
|
||||
formatName = outputFormat.lower().strip()
|
||||
|
||||
# Check for aliases first
|
||||
if formatName in self._format_mappings:
|
||||
formatName = self._format_mappings[formatName]
|
||||
|
||||
# Get renderer class and call getOutputStyle (all renderers have same signature)
|
||||
rendererClass = self._renderers.get(formatName)
|
||||
try:
|
||||
return rendererClass.getOutputStyle(formatName)
|
||||
except (AttributeError, TypeError) as e:
|
||||
logger.warning(f"No renderer found for format: {outputFormat}, cannot determine output style")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning(f"Error getting output style for {outputFormat}: {str(e)}")
|
||||
return None
|
||||
# Check document first, then code
|
||||
for style in ['document', 'code']:
|
||||
rendererClass = self._renderers.get((formatName, style))
|
||||
if rendererClass:
|
||||
try:
|
||||
return rendererClass.getOutputStyle(formatName)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Fallback: any style
|
||||
for key, rendererClass in self._renderers.items():
|
||||
if key[0] == formatName:
|
||||
try:
|
||||
return rendererClass.getOutputStyle(formatName)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.warning(f"No renderer found for format: {outputFormat}, cannot determine output style")
|
||||
return None
|
||||
|
||||
|
||||
# Global registry instance
|
||||
_registry = RendererRegistry()
|
||||
|
||||
def getRenderer(outputFormat: str, services=None) -> Optional[BaseRenderer]:
|
||||
"""Get a renderer instance for the specified format."""
|
||||
return _registry.getRenderer(outputFormat, services)
|
||||
|
||||
def getRenderer(outputFormat: str, services=None, outputStyle: str = None) -> Optional[BaseRenderer]:
|
||||
"""Get a renderer instance for the specified format and style.
|
||||
|
||||
Args:
|
||||
outputFormat: Format name (e.g. 'csv', 'json', 'pdf')
|
||||
services: Services instance
|
||||
outputStyle: 'document' or 'code'. If None, prefers document renderer.
|
||||
"""
|
||||
return _registry.getRenderer(outputFormat, services, outputStyle=outputStyle)
|
||||
|
||||
|
||||
def getSupportedFormats() -> List[str]:
|
||||
"""Get list of all supported formats."""
|
||||
return _registry.getSupportedFormats()
|
||||
|
||||
|
||||
def getRendererInfo() -> Dict[str, Dict[str, str]]:
|
||||
"""Get information about all registered renderers."""
|
||||
return _registry.getRendererInfo()
|
||||
|
||||
|
||||
def getOutputStyle(outputFormat: str) -> Optional[str]:
|
||||
"""Get the output style classification for a given format."""
|
||||
return _registry.getOutputStyle(outputFormat)
|
||||
|
|
|
|||
|
|
@ -35,9 +35,9 @@ class RendererCsv(BaseRenderer):
|
|||
def getAcceptedSectionTypes(cls, formatName: Optional[str] = None) -> List[str]:
|
||||
"""
|
||||
Return list of section content types that CSV renderer accepts.
|
||||
CSV renderer only accepts table sections.
|
||||
CSV renderer accepts table sections and code_block sections (for raw CSV content).
|
||||
"""
|
||||
return ["table"]
|
||||
return ["table", "code_block"]
|
||||
|
||||
async def render(self, extractedContent: Dict[str, Any], title: str, userPrompt: str = None, aiService=None) -> List[RenderedDocument]:
|
||||
"""Render extracted JSON content to CSV format. Produces one CSV file per table section."""
|
||||
|
|
@ -62,16 +62,24 @@ class RendererCsv(BaseRenderer):
|
|||
if baseFilename.endswith('.csv'):
|
||||
baseFilename = baseFilename[:-4]
|
||||
|
||||
# Find all table sections
|
||||
# Collect CSV-producing sections: table sections AND code_block sections with CSV language
|
||||
tableSections = []
|
||||
codeBlockCsvSections = []
|
||||
for section in sections:
|
||||
sectionType = section.get("content_type", "paragraph")
|
||||
if sectionType == "table":
|
||||
tableSections.append(section)
|
||||
elif sectionType == "code_block":
|
||||
# Check if any element is a code_block with language "csv"
|
||||
for element in section.get("elements", []):
|
||||
content = element.get("content", {})
|
||||
if isinstance(content, dict) and content.get("language", "").lower() == "csv":
|
||||
codeBlockCsvSections.append(section)
|
||||
break
|
||||
|
||||
# If no table sections found, return empty CSV
|
||||
if not tableSections:
|
||||
self.logger.warning("No table sections found in CSV document - returning empty CSV")
|
||||
# If no usable sections found, return empty CSV
|
||||
if not tableSections and not codeBlockCsvSections:
|
||||
self.logger.warning("No table or CSV code_block sections found in CSV document - returning empty CSV")
|
||||
emptyCsv = self._convertRowsToCsv([["No table data available"]])
|
||||
return [
|
||||
RenderedDocument(
|
||||
|
|
@ -83,45 +91,52 @@ class RendererCsv(BaseRenderer):
|
|||
)
|
||||
]
|
||||
|
||||
# Generate one CSV file per table section
|
||||
allCsvSections = tableSections + codeBlockCsvSections
|
||||
|
||||
# Generate one CSV file per section
|
||||
renderedDocuments = []
|
||||
for i, tableSection in enumerate(tableSections):
|
||||
# Generate CSV content for this table section
|
||||
csvRows = []
|
||||
for i, csvSection in enumerate(allCsvSections):
|
||||
sectionType = csvSection.get("content_type", "paragraph")
|
||||
sectionTitle = csvSection.get("title")
|
||||
csvContent = ""
|
||||
|
||||
# Add section title if available
|
||||
sectionTitle = tableSection.get("title")
|
||||
if sectionTitle:
|
||||
csvRows.append([sectionTitle])
|
||||
csvRows.append([]) # Empty row after title
|
||||
if sectionType == "code_block":
|
||||
# Extract raw CSV content directly from code_block elements
|
||||
rawCsvParts = []
|
||||
for element in csvSection.get("elements", []):
|
||||
content = element.get("content", {})
|
||||
if isinstance(content, dict) and content.get("language", "").lower() == "csv":
|
||||
code = content.get("code", "")
|
||||
if code:
|
||||
rawCsvParts.append(code)
|
||||
csvContent = "\n".join(rawCsvParts)
|
||||
else:
|
||||
# Table section — render via table logic
|
||||
csvRows = []
|
||||
if sectionTitle:
|
||||
csvRows.append([sectionTitle])
|
||||
csvRows.append([]) # Empty row after title
|
||||
|
||||
elements = csvSection.get("elements", [])
|
||||
for element in elements:
|
||||
tableRows = self._renderJsonTableToCsv(element)
|
||||
if tableRows:
|
||||
csvRows.extend(tableRows)
|
||||
|
||||
csvContent = self._convertRowsToCsv(csvRows)
|
||||
|
||||
# Render table from section elements
|
||||
elements = tableSection.get("elements", [])
|
||||
for element in elements:
|
||||
tableRows = self._renderJsonTableToCsv(element)
|
||||
if tableRows:
|
||||
csvRows.extend(tableRows)
|
||||
|
||||
# Convert to CSV string
|
||||
csvContent = self._convertRowsToCsv(csvRows)
|
||||
|
||||
# Determine filename for this table
|
||||
if len(tableSections) == 1:
|
||||
# Single table - use base filename
|
||||
# Determine filename
|
||||
if len(allCsvSections) == 1:
|
||||
filename = f"{baseFilename}.csv"
|
||||
else:
|
||||
# Multiple tables - add index or section title to filename
|
||||
sectionId = tableSection.get("id", f"table_{i+1}")
|
||||
# Use section title if available, otherwise use section ID
|
||||
sectionId = csvSection.get("id", f"csv_{i+1}")
|
||||
if sectionTitle:
|
||||
# Sanitize section title for filename
|
||||
safeTitle = "".join(c for c in sectionTitle if c.isalnum() or c in (' ', '-', '_')).strip()
|
||||
safeTitle = safeTitle.replace(' ', '_')[:30] # Limit length
|
||||
safeTitle = safeTitle.replace(' ', '_')[:30]
|
||||
filename = f"{baseFilename}_{safeTitle}.csv"
|
||||
else:
|
||||
filename = f"{baseFilename}_{sectionId}.csv"
|
||||
|
||||
# Extract document type from metadata
|
||||
documentType = metadata.get("documentType") if isinstance(metadata, dict) else None
|
||||
|
||||
renderedDocuments.append(
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@ Features can register callbacks to be notified when automations change.
|
|||
|
||||
import logging
|
||||
from typing import Callable, List, Dict, Any
|
||||
import asyncio
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -25,7 +24,7 @@ class CallbackRegistry:
|
|||
|
||||
Args:
|
||||
event_type: Type of event (e.g., 'automation.changed')
|
||||
callback: Async or sync callback function
|
||||
callback: Sync callback function
|
||||
"""
|
||||
if event_type not in self._callbacks:
|
||||
self._callbacks[event_type] = []
|
||||
|
|
@ -41,8 +40,8 @@ class CallbackRegistry:
|
|||
except ValueError:
|
||||
logger.warning(f"Callback not found for event type: {event_type}")
|
||||
|
||||
async def trigger(self, event_type: str, *args, **kwargs):
|
||||
"""Trigger all callbacks registered for an event type.
|
||||
def trigger(self, event_type: str, *args, **kwargs):
|
||||
"""Trigger all registered callbacks for an event type.
|
||||
|
||||
Args:
|
||||
event_type: Type of event to trigger
|
||||
|
|
@ -55,18 +54,14 @@ class CallbackRegistry:
|
|||
|
||||
for callback in callbacks:
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(callback):
|
||||
await callback(*args, **kwargs)
|
||||
else:
|
||||
callback(*args, **kwargs)
|
||||
callback(*args, **kwargs)
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing callback for {event_type}: {str(e)}", exc_info=True)
|
||||
|
||||
def has_callbacks(self, event_type: str) -> bool:
|
||||
def hasCallbacks(self, event_type: str) -> bool:
|
||||
"""Check if there are any callbacks registered for an event type."""
|
||||
return event_type in self._callbacks and len(self._callbacks[event_type]) > 0
|
||||
|
||||
|
||||
# Global singleton instance
|
||||
callbackRegistry = CallbackRegistry()
|
||||
|
||||
|
|
|
|||
|
|
@ -113,7 +113,7 @@ class EventManagement:
|
|||
self.scheduler.remove_job(jobId)
|
||||
logger.info(f"Removed job '{jobId}'")
|
||||
except Exception as exc:
|
||||
logger.warning(f"Could not remove job '{jobId}': {exc}")
|
||||
logger.debug(f"Could not remove job '{jobId}': {exc}")
|
||||
|
||||
|
||||
# Singleton instance for easy import and reuse
|
||||
|
|
|
|||
|
|
@ -576,22 +576,16 @@ def _deleteUserDataFromFeatureDatabases(userId: str, currentUser) -> Dict[str, A
|
|||
|
||||
rootInterface = getRootInterface()
|
||||
|
||||
# Get all feature accesses for this user
|
||||
featureAccesses = rootInterface.db.getRecordset(
|
||||
FeatureAccess,
|
||||
recordFilter={"userId": str(userId)}
|
||||
)
|
||||
# Get all feature accesses for this user using interface method
|
||||
featureAccesses = rootInterface.getFeatureAccessesForUser(str(userId))
|
||||
|
||||
# Collect unique feature codes
|
||||
featureCodes: Set[str] = set()
|
||||
for fa in featureAccesses:
|
||||
instanceId = fa.get("featureInstanceId")
|
||||
instanceRecords = rootInterface.db.getRecordset(
|
||||
FeatureInstance,
|
||||
recordFilter={"id": instanceId}
|
||||
)
|
||||
if instanceRecords:
|
||||
featureCode = instanceRecords[0].get("featureCode")
|
||||
instanceId = fa.featureInstanceId
|
||||
instance = rootInterface.getFeatureInstance(instanceId)
|
||||
if instance:
|
||||
featureCode = instance.featureCode
|
||||
if featureCode:
|
||||
featureCodes.add(featureCode)
|
||||
|
||||
|
|
|
|||
|
|
@ -25,11 +25,11 @@ FEATURE_ICON = "mdi-cog"
|
|||
# Block Order (gemäss Navigation-API-Konzept):
|
||||
# - System: 10
|
||||
# - <dynamic/features>: 15 (wird in routeSystem.py eingefügt)
|
||||
# - Workflows: 20
|
||||
# - Basisdaten: 30
|
||||
# - Migrate: 40
|
||||
# - Administration: 200
|
||||
#
|
||||
# NOTE: Workflows and Migrate sections removed - now handled as features
|
||||
#
|
||||
# Item Order: Default-Abstand 10 pro Item
|
||||
# uiComponent: Abgeleitet von objectKey (ui.system.home -> page.system.home)
|
||||
# icon: Wird intern gehalten aber NICHT in der API Response zurückgegeben
|
||||
|
|
@ -60,49 +60,6 @@ NAVIGATION_SECTIONS = [
|
|||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": "workflows",
|
||||
"title": {"en": "WORKFLOWS", "de": "WORKFLOWS", "fr": "WORKFLOWS"},
|
||||
"order": 20,
|
||||
"items": [
|
||||
{
|
||||
"id": "playground",
|
||||
"objectKey": "ui.system.playground",
|
||||
"label": {"en": "Chat Playground", "de": "Chat Playground", "fr": "Chat Playground"},
|
||||
"icon": "FaPlay",
|
||||
"path": "/workflows/playground",
|
||||
"order": 10,
|
||||
"public": True,
|
||||
},
|
||||
{
|
||||
"id": "chats",
|
||||
"objectKey": "ui.system.chats",
|
||||
"label": {"en": "Chats", "de": "Chats", "fr": "Chats"},
|
||||
"icon": "FaListAlt",
|
||||
"path": "/workflows/list",
|
||||
"order": 20,
|
||||
"public": True,
|
||||
},
|
||||
{
|
||||
"id": "automations",
|
||||
"objectKey": "ui.system.automations",
|
||||
"label": {"en": "Automations", "de": "Automatisierungen", "fr": "Automatisations"},
|
||||
"icon": "FaCogs",
|
||||
"path": "/workflows/automations",
|
||||
"order": 30,
|
||||
"public": True,
|
||||
},
|
||||
{
|
||||
"id": "automation-templates",
|
||||
"objectKey": "ui.system.automation-templates",
|
||||
"label": {"en": "Templates", "de": "Vorlagen", "fr": "Modèles"},
|
||||
"icon": "FaFileAlt",
|
||||
"path": "/workflows/automation-templates",
|
||||
"order": 35,
|
||||
"public": True,
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": "basedata",
|
||||
"title": {"en": "BASE DATA", "de": "BASISDATEN", "fr": "DONNÉES DE BASE"},
|
||||
|
|
@ -135,37 +92,17 @@ NAVIGATION_SECTIONS = [
|
|||
],
|
||||
},
|
||||
{
|
||||
"id": "migrate",
|
||||
"title": {"en": "MIGRATE TO FEATURES", "de": "MIGRATE TO FEATURES", "fr": "MIGRER VERS FEATURES"},
|
||||
"order": 40,
|
||||
"deprecated": True,
|
||||
"id": "billing",
|
||||
"title": {"en": "BILLING", "de": "BILLING", "fr": "FACTURATION"},
|
||||
"order": 35,
|
||||
"items": [
|
||||
{
|
||||
"id": "chatbot",
|
||||
"objectKey": "ui.system.chatbot",
|
||||
"label": {"en": "Chatbot", "de": "Chatbot", "fr": "Chatbot"},
|
||||
"icon": "FaComments",
|
||||
"path": "/chatbot",
|
||||
"id": "billing-transactions",
|
||||
"objectKey": "ui.billing.transactions",
|
||||
"label": {"en": "Billing", "de": "Billing", "fr": "Facturation"},
|
||||
"icon": "FaWallet",
|
||||
"path": "/billing/transactions",
|
||||
"order": 10,
|
||||
"deprecated": True,
|
||||
},
|
||||
{
|
||||
"id": "pek",
|
||||
"objectKey": "ui.system.pek",
|
||||
"label": {"en": "PEK", "de": "PEK", "fr": "PEK"},
|
||||
"icon": "FaChartBar",
|
||||
"path": "/pek",
|
||||
"order": 20,
|
||||
"deprecated": True,
|
||||
},
|
||||
{
|
||||
"id": "speech",
|
||||
"objectKey": "ui.system.speech",
|
||||
"label": {"en": "Speech", "de": "Sprache", "fr": "Parole"},
|
||||
"icon": "FaMicrophone",
|
||||
"path": "/speech",
|
||||
"order": 30,
|
||||
"deprecated": True,
|
||||
},
|
||||
],
|
||||
},
|
||||
|
|
@ -175,13 +112,49 @@ NAVIGATION_SECTIONS = [
|
|||
"order": 200,
|
||||
"adminOnly": True,
|
||||
"items": [
|
||||
{
|
||||
"id": "admin-users",
|
||||
"objectKey": "ui.admin.users",
|
||||
"label": {"en": "Users", "de": "Benutzer", "fr": "Utilisateurs"},
|
||||
"icon": "FaUsers",
|
||||
"path": "/admin/users",
|
||||
"order": 10,
|
||||
"adminOnly": True,
|
||||
},
|
||||
{
|
||||
"id": "admin-invitations",
|
||||
"objectKey": "ui.admin.invitations",
|
||||
"label": {"en": "User Invitations", "de": "Benutzer-Einladungen", "fr": "Invitations utilisateurs"},
|
||||
"icon": "FaEnvelopeOpenText",
|
||||
"path": "/admin/invitations",
|
||||
"order": 12,
|
||||
"adminOnly": True,
|
||||
},
|
||||
{
|
||||
"id": "admin-user-access-overview",
|
||||
"objectKey": "ui.admin.userAccessOverview",
|
||||
"label": {"en": "User Access Overview", "de": "Benutzer-Zugriffsübersicht", "fr": "Aperçu des accès utilisateur"},
|
||||
"icon": "FaClipboardList",
|
||||
"path": "/admin/user-access-overview",
|
||||
"order": 14,
|
||||
"adminOnly": True,
|
||||
},
|
||||
{
|
||||
"id": "admin-mandates",
|
||||
"objectKey": "ui.admin.mandates",
|
||||
"label": {"en": "Mandates", "de": "Mandanten", "fr": "Mandats"},
|
||||
"icon": "FaBuilding",
|
||||
"path": "/admin/mandates",
|
||||
"order": 3,
|
||||
"order": 20,
|
||||
"adminOnly": True,
|
||||
},
|
||||
{
|
||||
"id": "admin-user-mandates",
|
||||
"objectKey": "ui.admin.userMandates",
|
||||
"label": {"en": "Mandate Members", "de": "Mandanten-Mitglieder", "fr": "Membres du mandat"},
|
||||
"icon": "FaUserFriends",
|
||||
"path": "/admin/user-mandates",
|
||||
"order": 25,
|
||||
"adminOnly": True,
|
||||
},
|
||||
{
|
||||
|
|
@ -190,27 +163,54 @@ NAVIGATION_SECTIONS = [
|
|||
"label": {"en": "Access Management", "de": "Zugriffsverwaltung", "fr": "Gestion des accès"},
|
||||
"icon": "FaBuilding",
|
||||
"path": "/admin/access",
|
||||
"order": 5,
|
||||
"adminOnly": True,
|
||||
},
|
||||
{
|
||||
"id": "admin-users",
|
||||
"objectKey": "ui.admin.users",
|
||||
"label": {"en": "Users & Invitations", "de": "Benutzer & Einladungen", "fr": "Utilisateurs et invitations"},
|
||||
"icon": "FaUsers",
|
||||
"path": "/admin/users",
|
||||
"order": 10,
|
||||
"order": 30,
|
||||
"adminOnly": True,
|
||||
},
|
||||
{
|
||||
"id": "admin-roles",
|
||||
"objectKey": "ui.admin.roles",
|
||||
"label": {"en": "Roles & Permissions", "de": "Rollen & Berechtigungen", "fr": "Rôles et permissions"},
|
||||
"icon": "FaKey",
|
||||
"label": {"en": "Roles", "de": "Rollen", "fr": "Rôles"},
|
||||
"icon": "FaUserTag",
|
||||
"path": "/admin/mandate-roles",
|
||||
"order": 40,
|
||||
"adminOnly": True,
|
||||
},
|
||||
{
|
||||
"id": "admin-mandate-role-permissions",
|
||||
"objectKey": "ui.admin.mandateRolePermissions",
|
||||
"label": {"en": "Role Permissions", "de": "Rollen-Berechtigungen", "fr": "Permissions des rôles"},
|
||||
"icon": "FaKey",
|
||||
"path": "/admin/mandate-role-permissions",
|
||||
"order": 45,
|
||||
"adminOnly": True,
|
||||
},
|
||||
{
|
||||
"id": "admin-feature-instances",
|
||||
"objectKey": "ui.admin.featureInstances",
|
||||
"label": {"en": "Feature Instances", "de": "Feature-Instanzen", "fr": "Instances de features"},
|
||||
"icon": "FaCubes",
|
||||
"path": "/admin/feature-instances",
|
||||
"order": 48,
|
||||
"adminOnly": True,
|
||||
},
|
||||
{
|
||||
"id": "admin-feature-roles",
|
||||
"objectKey": "ui.admin.featureRoles",
|
||||
"label": {"en": "Feature Roles & Permissions", "de": "Features Rollen & Rechte", "fr": "Rôles et droits des features"},
|
||||
"icon": "FaShieldAlt",
|
||||
"path": "/admin/feature-roles",
|
||||
"order": 50,
|
||||
"adminOnly": True,
|
||||
},
|
||||
{
|
||||
"id": "admin-billing",
|
||||
"objectKey": "ui.admin.billing",
|
||||
"label": {"en": "Billing Administration", "de": "Billing-Verwaltung", "fr": "Administration de facturation"},
|
||||
"icon": "FaMoneyBillAlt",
|
||||
"path": "/admin/billing",
|
||||
"order": 60,
|
||||
"adminOnly": True,
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
|
@ -363,6 +363,43 @@ RESOURCE_OBJECTS = [
|
|||
]
|
||||
|
||||
|
||||
def _discoverAicoreProviderObjects() -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Dynamically discover AICore provider resources for the RBAC catalog.
|
||||
Providers are discovered from the model registry at startup.
|
||||
"""
|
||||
providerLabels = {
|
||||
"anthropic": {"en": "Anthropic (Claude)", "de": "Anthropic (Claude)", "fr": "Anthropic (Claude)"},
|
||||
"openai": {"en": "OpenAI (GPT)", "de": "OpenAI (GPT)", "fr": "OpenAI (GPT)"},
|
||||
"perplexity": {"en": "Perplexity", "de": "Perplexity", "fr": "Perplexity"},
|
||||
"tavily": {"en": "Tavily (Web Search)", "de": "Tavily (Websuche)", "fr": "Tavily (Recherche Web)"},
|
||||
"privatellm": {"en": "Private LLM", "de": "Private LLM", "fr": "LLM Privé"},
|
||||
"internal": {"en": "Internal", "de": "Intern", "fr": "Interne"},
|
||||
}
|
||||
|
||||
try:
|
||||
from modules.aicore.aicoreModelRegistry import modelRegistry
|
||||
connectors = modelRegistry.discoverConnectors()
|
||||
providers = [c.getConnectorType() for c in connectors]
|
||||
|
||||
objects = []
|
||||
for provider in providers:
|
||||
label = providerLabels.get(provider, {"en": provider, "de": provider, "fr": provider})
|
||||
objects.append({
|
||||
"objectKey": f"resource.aicore.{provider}",
|
||||
"label": label,
|
||||
"meta": {"provider": provider, "category": "aicore"}
|
||||
})
|
||||
|
||||
if objects:
|
||||
logger.info(f"Discovered {len(objects)} AICore provider catalog objects: {providers}")
|
||||
return objects
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to discover AICore providers for catalog: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def registerFeature(catalogService) -> bool:
|
||||
"""
|
||||
Register system RBAC objects in the catalog.
|
||||
|
|
@ -401,6 +438,16 @@ def registerFeature(catalogService) -> bool:
|
|||
meta=resObj.get("meta")
|
||||
)
|
||||
|
||||
# Register dynamically discovered AICore provider resources
|
||||
aicoreObjects = _discoverAicoreProviderObjects()
|
||||
for aicoreObj in aicoreObjects:
|
||||
catalogService.registerResourceObject(
|
||||
featureCode=FEATURE_CODE,
|
||||
objectKey=aicoreObj["objectKey"],
|
||||
label=aicoreObj["label"],
|
||||
meta=aicoreObj.get("meta")
|
||||
)
|
||||
|
||||
# Register feature definition
|
||||
catalogService.registerFeatureDefinition(
|
||||
featureCode=FEATURE_CODE,
|
||||
|
|
|
|||
|
|
@ -86,22 +86,20 @@ def loadFeatureRouters(app: FastAPI) -> Dict[str, Any]:
|
|||
logger.error(f"Failed to load router from {featureDir}: {e}")
|
||||
results[featureDir] = {"status": "error", "error": str(e)}
|
||||
|
||||
# Register features in RBAC catalog and sync template roles to database
|
||||
from modules.security.rbacCatalog import getCatalogService
|
||||
catalogService = getCatalogService()
|
||||
registrationResults = registerAllFeaturesInCatalog(catalogService)
|
||||
|
||||
for featureName, success in registrationResults.items():
|
||||
if featureName in results:
|
||||
results[featureName]["rbac_registered"] = success
|
||||
|
||||
return results
|
||||
|
||||
|
||||
_cachedMainModules = None
|
||||
|
||||
def loadFeatureMainModules() -> Dict[str, Any]:
|
||||
"""
|
||||
Dynamically load main modules from all discovered feature containers.
|
||||
Results are cached after the first call.
|
||||
"""
|
||||
global _cachedMainModules
|
||||
if _cachedMainModules is not None:
|
||||
return _cachedMainModules
|
||||
|
||||
mainModules = {}
|
||||
pattern = os.path.join(FEATURES_DIR, "*", "main*.py")
|
||||
|
||||
|
|
@ -113,6 +111,10 @@ def loadFeatureMainModules() -> Dict[str, Any]:
|
|||
featureDir = os.path.basename(os.path.dirname(filepath))
|
||||
if featureDir.startswith("_"):
|
||||
continue
|
||||
|
||||
# Skip if this feature already has a main module loaded (avoid duplicates)
|
||||
if featureDir in mainModules:
|
||||
continue
|
||||
|
||||
mainFile = filename[:-3] # Remove .py
|
||||
|
||||
|
|
@ -124,6 +126,7 @@ def loadFeatureMainModules() -> Dict[str, Any]:
|
|||
except Exception as e:
|
||||
logger.error(f"Failed to load main module from {featureDir}: {e}")
|
||||
|
||||
_cachedMainModules = mainModules
|
||||
return mainModules
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ from .subAutomationUtils import parseScheduleToCron, planToPrompt, replacePlaceh
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def chatStart(currentUser: User, userInput: UserInputRequest, workflowMode: WorkflowModeEnum, workflowId: Optional[str] = None, mandateId: Optional[str] = None) -> ChatWorkflow:
|
||||
async def chatStart(currentUser: User, userInput: UserInputRequest, workflowMode: WorkflowModeEnum, workflowId: Optional[str] = None, mandateId: Optional[str] = None, featureInstanceId: Optional[str] = None, featureCode: Optional[str] = None) -> ChatWorkflow:
|
||||
"""
|
||||
Starts a new chat or continues an existing one, then launches processing asynchronously.
|
||||
|
||||
|
|
@ -32,14 +32,23 @@ async def chatStart(currentUser: User, userInput: UserInputRequest, workflowMode
|
|||
currentUser: Current user
|
||||
userInput: User input request
|
||||
workflowId: Optional workflow ID to continue existing workflow
|
||||
workflowMode: "Dynamic" for iterative dynamic-style processing, "Automation" for automated workflow execution
|
||||
mandateId: Mandate ID from request context (required for proper data isolation)
|
||||
|
||||
Example usage for Dynamic mode:
|
||||
workflow = await chatStart(currentUser, userInput, workflowMode=WorkflowModeEnum.WORKFLOW_DYNAMIC, mandateId=mandateId)
|
||||
workflowMode: Workflow mode (Dynamic, Automation, etc.)
|
||||
mandateId: Mandate ID (required for billing)
|
||||
featureInstanceId: Feature instance ID (required for billing)
|
||||
featureCode: Feature code (e.g., 'chatplayground', 'automation')
|
||||
"""
|
||||
try:
|
||||
services = getServices(currentUser, mandateId=mandateId)
|
||||
services = getServices(currentUser, mandateId=mandateId, featureInstanceId=featureInstanceId)
|
||||
|
||||
# Store allowedProviders in services context for model selection
|
||||
if hasattr(userInput, 'allowedProviders') and userInput.allowedProviders:
|
||||
services.allowedProviders = userInput.allowedProviders
|
||||
logger.info(f"AI provider filter active: {userInput.allowedProviders}")
|
||||
|
||||
# Store feature code in services (for billing)
|
||||
if featureCode:
|
||||
services.featureCode = featureCode
|
||||
|
||||
workflowManager = WorkflowManager(services)
|
||||
workflow = await workflowManager.workflowStart(userInput, workflowMode, workflowId)
|
||||
return workflow
|
||||
|
|
@ -47,10 +56,12 @@ async def chatStart(currentUser: User, userInput: UserInputRequest, workflowMode
|
|||
logger.error(f"Error starting chat: {str(e)}")
|
||||
raise
|
||||
|
||||
async def chatStop(currentUser: User, workflowId: str, mandateId: Optional[str] = None) -> ChatWorkflow:
|
||||
async def chatStop(currentUser: User, workflowId: str, mandateId: Optional[str] = None, featureInstanceId: Optional[str] = None) -> ChatWorkflow:
|
||||
"""Stops a running chat."""
|
||||
try:
|
||||
services = getServices(currentUser, mandateId=mandateId)
|
||||
services = getServices(currentUser, mandateId=mandateId, featureInstanceId=featureInstanceId)
|
||||
if featureInstanceId:
|
||||
services.featureCode = 'chatplayground'
|
||||
workflowManager = WorkflowManager(services)
|
||||
return await workflowManager.workflowStop(workflowId)
|
||||
except Exception as e:
|
||||
|
|
@ -58,12 +69,17 @@ async def chatStop(currentUser: User, workflowId: str, mandateId: Optional[str]
|
|||
raise
|
||||
|
||||
|
||||
async def executeAutomation(automationId: str, services) -> ChatWorkflow:
|
||||
"""Execute automation workflow immediately (test mode) with placeholder replacement.
|
||||
async def executeAutomation(automationId: str, automation, creatorUser: User, services) -> ChatWorkflow:
|
||||
"""Execute automation workflow with the creator user's context.
|
||||
|
||||
The automation object and creatorUser are resolved by the caller (handler)
|
||||
using the SysAdmin eventUser. This function does NOT re-load them.
|
||||
|
||||
Args:
|
||||
automationId: ID of automation to execute
|
||||
services: Services instance for data access
|
||||
automation: Pre-loaded automation object (with system fields like _createdBy)
|
||||
creatorUser: The user who created the automation (workflow runs in this context)
|
||||
services: Services instance (used for interfaceDbApp etc.)
|
||||
|
||||
Returns:
|
||||
ChatWorkflow instance created by automation execution
|
||||
|
|
@ -77,14 +93,20 @@ async def executeAutomation(automationId: str, services) -> ChatWorkflow:
|
|||
}
|
||||
|
||||
try:
|
||||
# 1. Load automation definition (with system fields for _createdBy access)
|
||||
automation = services.interfaceDbAutomation.getAutomationDefinition(automationId, includeSystemFields=True)
|
||||
if not automation:
|
||||
raise ValueError(f"Automation {automationId} not found")
|
||||
|
||||
executionLog["messages"].append(f"Started execution at {executionStartTime}")
|
||||
|
||||
# 2. Replace placeholders in template to generate plan
|
||||
# Store allowed providers from automation in services context
|
||||
if hasattr(automation, 'allowedProviders') and automation.allowedProviders:
|
||||
services.allowedProviders = automation.allowedProviders
|
||||
logger.debug(f"Automation {automationId} restricted to providers: {automation.allowedProviders}")
|
||||
|
||||
# Context comes EXCLUSIVELY from the automation definition
|
||||
automationMandateId = str(automation.mandateId)
|
||||
automationFeatureInstanceId = str(automation.featureInstanceId)
|
||||
|
||||
logger.info(f"Executing automation {automationId} as user {creatorUser.id} with mandateId={automationMandateId}, featureInstanceId={automationFeatureInstanceId}")
|
||||
|
||||
# 1. Replace placeholders in template to generate plan
|
||||
template = automation.template or ""
|
||||
placeholders = automation.placeholders or {}
|
||||
planJson = replacePlaceholders(template, placeholders)
|
||||
|
|
@ -102,24 +124,9 @@ async def executeAutomation(automationId: str, services) -> ChatWorkflow:
|
|||
logger.error(f"Context around error: ...{planJson[start:end]}...")
|
||||
raise ValueError(f"Invalid JSON after placeholder replacement: {str(e)}")
|
||||
executionLog["messages"].append("Template placeholders replaced successfully")
|
||||
executionLog["messages"].append(f"Using creator user: {creatorUser.id}")
|
||||
|
||||
# 3. Get user who created automation
|
||||
creatorUserId = getattr(automation, "_createdBy", None)
|
||||
|
||||
# _createdBy is a system attribute - must be present
|
||||
if not creatorUserId:
|
||||
errorMsg = f"Automation {automationId} has no creator user (_createdBy field missing). Cannot execute automation."
|
||||
logger.error(errorMsg)
|
||||
executionLog["messages"].append(errorMsg)
|
||||
raise ValueError(errorMsg)
|
||||
|
||||
# Get creator user from database
|
||||
creatorUser = services.interfaceDbApp.getUser(creatorUserId)
|
||||
if not creatorUser:
|
||||
raise ValueError(f"Creator user {creatorUserId} not found")
|
||||
executionLog["messages"].append(f"Using creator user: {creatorUserId}")
|
||||
|
||||
# 4. Create UserInputRequest from plan
|
||||
# 2. Create UserInputRequest from plan
|
||||
# Embed plan JSON in prompt for TemplateMode to extract
|
||||
promptText = planToPrompt(plan)
|
||||
planJsonStr = json.dumps(plan)
|
||||
|
|
@ -134,12 +141,16 @@ async def executeAutomation(automationId: str, services) -> ChatWorkflow:
|
|||
|
||||
executionLog["messages"].append("Starting workflow execution")
|
||||
|
||||
# 5. Start workflow using chatStart
|
||||
# 3. Start workflow using chatStart with creator's context
|
||||
# mandateId and featureInstanceId come from the automation definition
|
||||
workflow = await chatStart(
|
||||
currentUser=creatorUser,
|
||||
userInput=userInput,
|
||||
workflowMode=WorkflowModeEnum.WORKFLOW_AUTOMATION,
|
||||
workflowId=None
|
||||
workflowId=None,
|
||||
mandateId=automationMandateId,
|
||||
featureInstanceId=automationFeatureInstanceId,
|
||||
featureCode='automation'
|
||||
)
|
||||
|
||||
executionLog["workflowId"] = workflow.id
|
||||
|
|
@ -153,17 +164,14 @@ async def executeAutomation(automationId: str, services) -> ChatWorkflow:
|
|||
workflow = services.interfaceDbChat.updateWorkflow(workflow.id, {"name": workflowName})
|
||||
logger.info(f"Set workflow {workflow.id} name to: {workflowName}")
|
||||
|
||||
# Update automation with execution log
|
||||
# Save execution log (bypasses RBAC — system operation, not a user edit)
|
||||
executionLogs = list(automation.executionLogs or [])
|
||||
executionLogs.append(executionLog)
|
||||
# Keep only last 50 executions
|
||||
if len(executionLogs) > 50:
|
||||
executionLogs = executionLogs[-50:]
|
||||
|
||||
services.interfaceDbAutomation.updateAutomationDefinition(
|
||||
automationId,
|
||||
{"executionLogs": executionLogs}
|
||||
)
|
||||
services.interfaceDbAutomation._saveExecutionLog(automationId, executionLogs)
|
||||
|
||||
return workflow
|
||||
except Exception as e:
|
||||
|
|
@ -171,26 +179,23 @@ async def executeAutomation(automationId: str, services) -> ChatWorkflow:
|
|||
executionLog["status"] = "error"
|
||||
executionLog["messages"].append(f"Error: {str(e)}")
|
||||
|
||||
# Update automation with execution log even on error
|
||||
# Save execution log even on error (bypasses RBAC — system operation)
|
||||
# Use the automation object already passed in (no re-load needed)
|
||||
try:
|
||||
automation = services.interfaceDbAutomation.getAutomationDefinition(automationId)
|
||||
if automation:
|
||||
executionLogs = list(automation.executionLogs or [])
|
||||
executionLogs.append(executionLog)
|
||||
if len(executionLogs) > 50:
|
||||
executionLogs = executionLogs[-50:]
|
||||
services.interfaceDbAutomation.updateAutomationDefinition(
|
||||
automationId,
|
||||
{"executionLogs": executionLogs}
|
||||
)
|
||||
executionLogs = list(getattr(automation, 'executionLogs', None) or [])
|
||||
executionLogs.append(executionLog)
|
||||
if len(executionLogs) > 50:
|
||||
executionLogs = executionLogs[-50:]
|
||||
services.interfaceDbAutomation._saveExecutionLog(automationId, executionLogs)
|
||||
except Exception as logError:
|
||||
logger.error(f"Error saving execution log: {str(logError)}")
|
||||
|
||||
raise
|
||||
|
||||
|
||||
async def syncAutomationEvents(services, eventUser) -> Dict[str, Any]:
|
||||
"""Automation event handler - syncs scheduler with all active automations.
|
||||
def syncAutomationEvents(services, eventUser) -> Dict[str, Any]:
|
||||
"""Sync scheduler with all active automations.
|
||||
All operations (DB reads, scheduler registration) are synchronous.
|
||||
|
||||
Args:
|
||||
services: Services instance for data access
|
||||
|
|
@ -291,37 +296,28 @@ def createAutomationEventHandler(automationId: str, eventUser):
|
|||
logger.error("Event user not available for automation execution")
|
||||
return
|
||||
|
||||
# Get services for event user (provides access to interfaces)
|
||||
# Load automation using SysAdmin eventUser (has unrestricted access)
|
||||
eventServices = getServices(eventUser, None)
|
||||
|
||||
# Load automation using event user context (with system fields for _createdBy access)
|
||||
automation = eventServices.interfaceDbAutomation.getAutomationDefinition(automationId, includeSystemFields=True)
|
||||
if not automation or not getattr(automation, "active", False):
|
||||
logger.warning(f"Automation {automationId} not found or not active, skipping execution")
|
||||
return
|
||||
|
||||
# Get creator user
|
||||
# Get creator user ID from automation's _createdBy system field
|
||||
creatorUserId = getattr(automation, "_createdBy", None)
|
||||
if not creatorUserId:
|
||||
logger.error(f"Automation {automationId} has no creator user")
|
||||
logger.error(f"Automation {automationId} has no creator user (_createdBy missing)")
|
||||
return
|
||||
|
||||
# Get mandate context from automation definition
|
||||
automationMandateId = getattr(automation, "mandateId", None)
|
||||
|
||||
# Get creator user from database using services
|
||||
eventServices = getServices(eventUser, None)
|
||||
# Get creator user from database (using SysAdmin access)
|
||||
creatorUser = eventServices.interfaceDbApp.getUser(creatorUserId)
|
||||
if not creatorUser:
|
||||
logger.error(f"Creator user {creatorUserId} not found for automation {automationId}")
|
||||
return
|
||||
|
||||
# Get services for creator user WITH mandate context from automation
|
||||
creatorServices = getServices(creatorUser, automationMandateId)
|
||||
|
||||
# Execute automation with creator user's context and mandate
|
||||
# executeAutomation is in same module, so we can call it directly
|
||||
await executeAutomation(automationId, creatorServices)
|
||||
# Execute automation — pass automation object and creatorUser directly
|
||||
# No re-load needed in executeAutomation
|
||||
await executeAutomation(automationId, automation, creatorUser, eventServices)
|
||||
logger.info(f"Successfully executed automation {automationId} as user {creatorUserId}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing automation {automationId}: {str(e)}")
|
||||
|
|
|
|||
|
|
@ -14,9 +14,10 @@ from modules.services import getInterface as getServices
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def start(eventUser) -> None:
|
||||
def start(eventUser) -> bool:
|
||||
"""
|
||||
Start automation scheduler and sync scheduled events.
|
||||
All operations are synchronous (DB access, scheduler registration).
|
||||
|
||||
Args:
|
||||
eventUser: System-level event user for background operations (provided by app.py)
|
||||
|
|
@ -33,16 +34,16 @@ async def start(eventUser) -> None:
|
|||
services = getServices(eventUser, None)
|
||||
|
||||
# Register callback for automation changes
|
||||
async def onAutomationChanged(chatInterface):
|
||||
def onAutomationChanged(chatInterface):
|
||||
"""Callback triggered when automations are created/updated/deleted."""
|
||||
eventServices = getServices(eventUser, None)
|
||||
await syncAutomationEvents(eventServices, eventUser)
|
||||
syncAutomationEvents(eventServices, eventUser)
|
||||
|
||||
callbackRegistry.register('automation.changed', onAutomationChanged)
|
||||
logger.info("Automation: Registered change callback")
|
||||
|
||||
# Initial sync on startup
|
||||
await syncAutomationEvents(services, eventUser)
|
||||
syncAutomationEvents(services, eventUser)
|
||||
logger.info("Automation: Scheduled events synced on startup")
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -52,7 +53,7 @@ async def start(eventUser) -> None:
|
|||
return True
|
||||
|
||||
|
||||
async def stop(eventUser) -> None:
|
||||
def stop(eventUser) -> bool:
|
||||
"""
|
||||
Stop automation scheduler.
|
||||
|
||||
|
|
|
|||
|
|
@ -153,7 +153,7 @@ async def process(self, parameters: Dict[str, Any]) -> ActionResult:
|
|||
metadata=AiResponseMetadata(
|
||||
additionalData={
|
||||
"modelName": aiResponse_obj.modelName,
|
||||
"priceUsd": aiResponse_obj.priceUsd,
|
||||
"priceCHF": aiResponse_obj.priceCHF,
|
||||
"processingTime": aiResponse_obj.processingTime,
|
||||
"bytesSent": aiResponse_obj.bytesSent,
|
||||
"bytesReceived": aiResponse_obj.bytesReceived,
|
||||
|
|
|
|||
|
|
@ -139,11 +139,16 @@ class MethodBase:
|
|||
return False
|
||||
|
||||
# RBAC-Check: RESOURCE context, item = actionId
|
||||
# mandateId/featureInstanceId from services context needed to resolve user roles
|
||||
try:
|
||||
mandateId = getattr(self.services, 'mandateId', None)
|
||||
featureInstanceId = getattr(self.services, 'featureInstanceId', None)
|
||||
permissions = self.services.rbac.getUserPermissions(
|
||||
user=currentUser,
|
||||
context=AccessRuleContext.RESOURCE,
|
||||
item=actionId
|
||||
item=actionId,
|
||||
mandateId=str(mandateId) if mandateId else None,
|
||||
featureInstanceId=str(featureInstanceId) if featureInstanceId else None
|
||||
)
|
||||
hasPermission = permissions.view
|
||||
if not hasPermission:
|
||||
|
|
@ -151,8 +156,9 @@ class MethodBase:
|
|||
userRoles = getattr(currentUser, 'roleLabels', []) or []
|
||||
self.logger.warning(
|
||||
f"RBAC denied action {actionId} for user {currentUser.id}. "
|
||||
f"User roles: {userRoles}, "
|
||||
f"Permissions: view={permissions.view}, edit={permissions.edit}, delete={permissions.delete}. "
|
||||
f"User roles: {userRoles}, mandateId={mandateId}, "
|
||||
f"Permissions: view={permissions.view}, read={permissions.read}, "
|
||||
f"create={permissions.create}, update={permissions.update}, delete={permissions.delete}. "
|
||||
f"No matching RBAC rule found for context=RESOURCE, item={actionId}"
|
||||
)
|
||||
return hasPermission
|
||||
|
|
|
|||
|
|
@ -13,16 +13,17 @@ logger = logging.getLogger(__name__)
|
|||
async def composeAndDraftEmailWithContext(self, parameters: Dict[str, Any]) -> ActionResult:
|
||||
try:
|
||||
connectionReference = parameters.get("connectionReference")
|
||||
to = parameters.get("to")
|
||||
to = parameters.get("to") or [] # Optional for drafts - can save draft without recipients
|
||||
context = parameters.get("context")
|
||||
documentList = parameters.get("documentList", [])
|
||||
cc = parameters.get("cc", [])
|
||||
bcc = parameters.get("bcc", [])
|
||||
emailStyle = parameters.get("emailStyle", "business")
|
||||
maxLength = parameters.get("maxLength", 1000)
|
||||
documentList = parameters.get("documentList") or []
|
||||
cc = parameters.get("cc") or []
|
||||
bcc = parameters.get("bcc") or []
|
||||
emailStyle = parameters.get("emailStyle") or "business"
|
||||
maxLength = parameters.get("maxLength") or 1000
|
||||
|
||||
if not connectionReference or not to or not context:
|
||||
return ActionResult.isFailure(error="connectionReference, to, and context are required")
|
||||
# Only connectionReference and context are required - to is optional for drafts
|
||||
if not connectionReference or not context:
|
||||
return ActionResult.isFailure(error="connectionReference and context are required")
|
||||
|
||||
# Convert single values to lists for all recipient parameters
|
||||
if isinstance(to, str):
|
||||
|
|
@ -82,12 +83,15 @@ async def composeAndDraftEmailWithContext(self, parameters: Dict[str, Any]) -> A
|
|||
# Escape only the user-controlled context to prevent prompt injection
|
||||
escaped_context = context.replace('"', '\\"').replace('\n', '\\n').replace('\r', '\\r')
|
||||
|
||||
# Build recipients text for prompt
|
||||
recipients_text = f"Recipients: {to}" if to else "Recipients: (not specified - this is a draft)"
|
||||
|
||||
ai_prompt = f"""Compose an email based on this context:
|
||||
-------
|
||||
{escaped_context}
|
||||
-------
|
||||
|
||||
Recipients: {to}
|
||||
{recipients_text}
|
||||
Style: {emailStyle}
|
||||
Max length: {maxLength} characters
|
||||
{doc_list_text}
|
||||
|
|
|
|||
|
|
@ -90,15 +90,20 @@ async def sendDraftEmail(self, parameters: Dict[str, Any]) -> ActionResult:
|
|||
else:
|
||||
jsonContent = str(fileData)
|
||||
|
||||
# Parse JSON - handle both direct JSON and JSON wrapped in documentData
|
||||
# Parse JSON - handle ActionDocument format with validationMetadata wrapper
|
||||
try:
|
||||
draftEmailData = json.loads(jsonContent)
|
||||
|
||||
# If the JSON contains a 'documentData' field, extract it
|
||||
# ActionDocument format: { "validationMetadata": {...}, "documentData": {...} }
|
||||
# Extract documentData which contains the actual draft email data
|
||||
if isinstance(draftEmailData, dict) and 'documentData' in draftEmailData:
|
||||
documentDataStr = draftEmailData['documentData']
|
||||
if isinstance(documentDataStr, str):
|
||||
draftEmailData = json.loads(documentDataStr)
|
||||
documentDataContent = draftEmailData['documentData']
|
||||
# documentData should be a dict (parsed from JSON by processSingleDocument)
|
||||
if isinstance(documentDataContent, dict):
|
||||
draftEmailData = documentDataContent
|
||||
elif isinstance(documentDataContent, str):
|
||||
# Legacy/fallback: parse if still a string
|
||||
draftEmailData = json.loads(documentDataContent)
|
||||
|
||||
# Validate draft email structure
|
||||
if not isinstance(draftEmailData, dict):
|
||||
|
|
|
|||
|
|
@ -84,6 +84,14 @@ class ConnectionHelper:
|
|||
elif response.status_code == 403:
|
||||
logger.error("Permission denied - connection lacks necessary mail permissions")
|
||||
logger.error("Required scopes: Mail.ReadWrite, Mail.Send, Mail.ReadWrite.Shared")
|
||||
logger.error("Solution: User must reconnect and grant mail permissions")
|
||||
return False
|
||||
elif response.status_code == 404:
|
||||
# 404 on /me/mailFolders typically means the token lacks mail scopes
|
||||
# This happens when the connection was created without mail permissions
|
||||
logger.error("Mail API not accessible (404) - token likely lacks mail scopes")
|
||||
logger.error("This usually means the connection was created without Mail.ReadWrite permission")
|
||||
logger.error("Solution: User must delete the connection and reconnect, granting mail permissions")
|
||||
return False
|
||||
else:
|
||||
logger.warning(f"Permission check returned status {response.status_code}")
|
||||
|
|
|
|||
|
|
@ -150,8 +150,8 @@ class MethodOutlook(MethodBase):
|
|||
name="to",
|
||||
type="List[str]",
|
||||
frontendType=FrontendType.MULTISELECT,
|
||||
required=True,
|
||||
description="Recipient email addresses"
|
||||
required=False,
|
||||
description="Recipient email addresses (optional for drafts)"
|
||||
),
|
||||
"context": WorkflowActionParameter(
|
||||
name="context",
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Reference in a new issue