Merge pull request #110 from valueonag/int

Int
commit bb0941ffa4
Patrick Motsch, 2026-03-19 13:43:36 +01:00 (committed via GitHub)
233 changed files with 15586 additions and 41731 deletions


@@ -39,12 +39,11 @@ Service_GOOGLE_REDIRECT_URI = http://localhost:8000/api/google/auth/callback
STRIPE_SECRET_KEY_SECRET = DEV_ENC:Z0FBQUFBQnBudkpGWDkxSldfM0NCZ3dmbHY5cS1nQlI3UWZ4ZWRrNVdUdEFKa25RckRiQWY0c1E5MjVsZzlfRkZEU0VFU2tNQ01qZnRNQ0pZVU9hVFN6OEU0RXhwdTl3algzLWJlSXRhYmZlMHltSC1XejlGWEU5TDF1LUlYNEh1aG9tRFI4YmlCYzUyei02U1dabWoyb0N2dVFSb1RhWTNnQjBCZkFjV0FfOWdYdDVpX1k5R2pYM1R6SHRiaE10V1l1dnQybjVHWDRiQUJLM0UxRDZnczhJZGFsc3JhOU82QT09
STRIPE_WEBHOOK_SECRET = DEV_ENC:Z0FBQUFBQnBudkpGcHNWTWpBWkFHRExtdU01N3RyZzNsMjhUS3NiVTNCZmMwN2NEcFZ6UkQ1a2I0aUkyNU4wR2dUdHJXYmtkaEFRUnFpcThObHBEQmJkdEFnT1FXeUxOTlU3UDFNRzl6LWdpRFpYdExvY3FTTG9MTkswdEhrVkNKQVFucnBjSnhLNm4=
STRIPE_API_VERSION = 2026-01-28.clover
APP_FRONTEND_URL = http://localhost:5176
# AI configuration
Connector_AiOpenai_API_SECRET = DEV_ENC:Z0FBQUFBQnBaSnM4TWFRRmxVQmNQblVIYmc1Y0Q3aW9zZUtDWlNWdGZjbFpncGp2NHN2QjkxMWxibUJnZDBId252MWk5TXN3Yk14ajFIdi1CTkx2ZWx2QzF5OFR6LUx5azQ3dnNLaXJBOHNxc0tlWmtZcTFVelF4eXBSM2JkbHd2eTM0VHNXdHNtVUprZWtPVzctNlJsZHNmM20tU1N6Q1Q2cHFYSi1tNlhZNDNabTVuaEVGWmIydEhadTcyMlBURmw2aUJxOF9GTzR0dTZiNGZfOFlHaVpPZ1A1LXhhOEFtN1J5TEVNNWtMcGpyNkMzSl8xRnZsaTF1WTZrOUZmb0cxVURjSGFLS2dIYTQyZEJtTm90bEYxVWxNNXVPdTVjaVhYbXhxT3JsVDM5VjZMVFZKSE1tZnM9
Connector_AiAnthropic_API_SECRET = DEV_ENC:Z0FBQUFBQm8xSUpENmFBWG16STFQUVZxNzZZRzRLYTA4X3lRanF1VkF4cU45OExNMzlsQmdISGFxTUxud1dXODBKcFhMVG9KNjdWVnlTTFFROVc3NDlsdlNHLUJXeG41NDBHaXhHR0VHVWl5UW9RNkVWbmlhakRKVW5pM0R4VHk0LUw0TV9LdkljNHdBLXJua21NQkl2b3l4UkVkMGN1YjBrMmJEeWtMay1jbmxrYWJNbUV0aktCXzU1djR2d2RSQXZORTNwcG92ZUVvVGMtQzQzTTVncEZTRGRtZUFIZWQ0dz09
Connector_AiPerplexity_API_SECRET = DEV_ENC:Z0FBQUFBQm82Mzk2Q1MwZ0dNcUVBcUtuRDJIcTZkMXVvYnpjM3JEMzJiT1NKSHljX282ZDIyZTJYc09VSTdVNXAtOWU2UXp5S193NTk5dHJsWlFjRjhWektFOG1DVGY4ZUhHTXMzS0RPN1lNcF9nSlVWbW5BZ1hkZDVTejl6bVZNRFVvX29xamJidWRFMmtjQmkyRUQ2RUh6UTN1aWNPSUJBPT0=
Connector_AiPerplexity_API_SECRET = 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
Connector_AiTavily_API_SECRET = DEV_ENC:Z0FBQUFBQm8xSUpEQTdnUHMwd2pIaXNtMmtCTFREd0pyQXRKb1F5eGtHSnkyOGZiUnlBOFc0b3Vzcndrc3ViRm1nMDJIOEZKYWxqdWNkZGh5N0Z4R0JlQmxXSG5pVnJUR2VYckZhMWNMZ1FNeXJ3enJLVlpiblhOZTNleUg3ZzZyUzRZanFSeDlVMkI=
Connector_AiPrivateLlm_API_SECRET = DEV_ENC:Z0FBQUFBQnBudkpGRHM5eFdUVmVZU1R1cHBwN1RlMUx4T0NlLTJLUFFVX3J2OElDWFpuZmJHVmp4Z3BNNWMwZUVVZUd2TFhRSjVmVkVlcFlVRWtybXh0ZHloZ01ZcnVvX195YjdlWVdEcjZSWFFTTlNBWUlaTlNoLWhqVFBIb0thVlBiaWhjYjFQOFY=
Connector_AiMistral_API_SECRET = DEV_ENC:Z0FBQUFBQnBudkpGeEQxYUIxOHhia0JlQWpWQ2dWQWZzY3l6SWwyUnJoR1hRQWloX2lxb2lGNkc4UnA4U2tWNjJaYzB1d1hvNG9fWUp1N3V4OW9FMGhaWVhjSlVwWEc1X2loVDBSZDEtdHdfcTA5QkcxQTR4OHc4RkRzclJrU2d1RFZpNDJkRDRURlE=


@@ -36,15 +36,14 @@ Service_MSFT_REDIRECT_URI = https://gateway-int.poweron-center.net/api/msft/auth
Service_GOOGLE_REDIRECT_URI = https://gateway-int.poweron-center.net/api/google/auth/callback
# Stripe Billing (both end with _SECRET for encryption script)
STRIPE_SECRET_KEY_SECRET = INT_ENC:Z0FBQUFBQnBudkpGOTVvaGhuSTVRTW5Fa0x3akFwZktQX21DZnFGQWgteDJwUWpnbTV5enRmbmZyWXlpY2lKVVNINkFManhNMnFQZ1VnOUxSQ0FTZFBVNmdhSGhwWFBHaDNnSzZXVnIxRmNUcnBQN0c3R19Xb2g1QnBxVXpiSXRRTk5NOGtzcU5HcUNiWDNvdmhYbGFkWkRCR25iVEJKTmwzcGRBZjNjaVNiWDJDaWlhLWpfdkdXYlQyUWk2NndKYW5lYXBaTkRzMWZsZjlFb3JOX1NzbkM4NWFyQU9MajZlZz09
STRIPE_SECRET_KEY_SECRET = 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
STRIPE_WEBHOOK_SECRET = INT_ENC:Z0FBQUFBQnBudkpGamJBNW91VUdEaThWRTFiTWpyb3NqSDJJcGtjNkhUVVZqVElxUWExY05KcllSYVk1SkRuS1NjYWpZUk1uU29nb2pzdXUxRzBsOEgyRWtmUEw3dUF4ejFIXzNwTVZRM1R1bVVhTUs4ZHJMT0V4Xy1pcHVfWlBaQV9wVXo5MGlQYXA=
STRIPE_API_VERSION = 2026-01-28.clover
APP_FRONTEND_URL = https://nyla-int.poweron-center.net
# AI configuration
Connector_AiOpenai_API_SECRET = INT_ENC:Z0FBQUFBQnBaSnM4MENkQ2xJVmE5WFZKUkh2SHJFby1YVXN3ZmVxRkptS3ZWRmlwdU93ZEJjSjlMV2NGbU5mS3NCdmFfcmFYTEJNZXFIQ3ozTWE4ZC1pemlQNk9wbjU1d3BPS0ZCTTZfOF8yWmVXMWx0TU1DamlJLVFhSTJXclZsY3hMVWlPcXVqQWtMdER4T252NHZUWEhUOTdIN1VGR3ltazEweXFqQ0lvb0hYWmxQQnpxb0JwcFNhRDNGWXdoRTVJWm9FalZpTUF5b1RqZlRaYnVKYkp0NWR5Vko1WWJ0Wmg2VWJzYXZ0Z3Q4UkpsTldDX2dsekhKMmM4YjRoa2RwemMwYVQwM2cyMFlvaU5mOTVTWGlROU8xY2ZVRXlxZzJqWkxURWlGZGI2STZNb0NpdEtWUnM9
Connector_AiAnthropic_API_SECRET = INT_ENC:Z0FBQUFBQm8xSVRjT1ZlRWVJdVZMT3ljSFJDcFdxRFBRVkZhS204NnN5RDBlQ0tpenhTM0FFVktuWW9mWHNwRWx2dHB0eDBSZ0JFQnZKWlp6c01pVGREWHd1eGpERnU0Q2xhaks1clQ1ZXVsdnd2ZzhpNXNQS1BhY3FjSkdkVEhHalNaRGR4emhpakZncnpDQUVxOHVXQzVUWmtQc0FsYmFwTF9TSG5FOUFtWk5Ick1NcHFvY2s1T1c2WXlRUFFJZnh6TWhuaVpMYmppcDR0QUx0a0R6RXlwbGRYb1R4dzJkUT09
Connector_AiPerplexity_API_SECRET = INT_ENC:Z0FBQUFBQm82Mzk2UWZJdUFhSW8yc3RKc0tKRXphd0xWMkZOVlFpSGZ4SGhFWnk0cTF5VjlKQVZjdS1QSWdkS0pUSWw4OFU5MjUxdTVQel9aeWVIZTZ5TXRuVmFkZG0zWEdTOGdHMHpsTzI0TGlWYURKU1Q0VVpKTlhxUk5FTmN6SUJScDZ3ZldIaUJZcWpaQVRiSEpyQm9tRTNDWk9KTnZBPT0=
Connector_AiPerplexity_API_SECRET = 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
Connector_AiTavily_API_SECRET = INT_ENC:Z0FBQUFBQm8xSVRkdkJMTDY0akhXNzZDWHVYSEt1cDZoOWEzSktneHZEV2JndTNmWlNSMV9KbFNIZmQzeVlrNE5qUEIwcUlBSGM1a0hOZ3J6djIyOVhnZzI3M1dIUkdicl9FVXF3RGktMmlEYmhnaHJfWTdGUkktSXVUSGdQMC1vSEV6VE8zR2F1SVk=
Connector_AiPrivateLlm_API_SECRET = INT_ENC:Z0FBQUFBQnBudkpGSjZ1NWh0aWc1R3Z4MHNaeS1HamtUbndhcUZFZDlqUDhjSmg5eHFfdlVkU0RsVkJ2UVRaMWs3aWhraG5jSlc0YkxNWHVmR2JoSW5ENFFCdkJBM0VienlKSnhzNnBKbTJOUTFKczRfWlQ3bWpmUkRTT1I1OGNUSTlQdExacGRpeXg=
Connector_AiMistral_API_SECRET = INT_ENC:Z0FBQUFBQnBudkpGZTNtZ1E4TWIxSEU1OUlreUpxZkJIR0Vxcm9xRHRUbnBxbTQ1cXlkbnltWkJVdTdMYWZ4c3Fsam42TERWUTVhNzZFMU9xVjdyRGFCYml6bmZsZFd2YmJzemlrSWN6Q3o3X0NXX2xXNUQteTNONHdKYzJ5YVpLLWdhU2JhSTJQZnI=


@@ -36,15 +36,14 @@ Service_MSFT_REDIRECT_URI = https://gateway-prod.poweron-center.net/api/msft/aut
Service_GOOGLE_REDIRECT_URI = https://gateway-prod.poweron-center.net/api/google/auth/callback
# Stripe Billing (both end with _SECRET for encryption script)
STRIPE_SECRET_KEY_SECRET = PROD_ENC:Z0FBQUFBQnBudkpGNmx3N3Q4QWdlcXVBSlJuYzVJX2hZRW8wbklJYUI0Rzh2YWRPcWY5dC1rMjhfUDRTOE91TlZyLTBEZkY1N015dmg5akEta1d0M0NpNk9oNDZpQTlMUGlLalV6aVowbl9Jc2hKMVlxbE9aaTZNRUxDQ3VGSnJxN040VERUMDFiekhITXdTR0N4aUxwWGxtcHdlU2NtOVNsSlVpOE0xTkRSdGhnN09UWGxuLURUaFdfQWJ4ZEw3R0c0bVRQaTA1NURhVEZudHY4d2gtTzItOF9TcmMwajFmZz09
STRIPE_SECRET_KEY_SECRET = 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
STRIPE_WEBHOOK_SECRET = PROD_ENC:Z0FBQUFBQnBudkpGNUpTWldsakYydFhFelBrR1lSaWxYT3kyMENOMUljZTJUZHBWcEhhdWVCMzYxZXQ5b3VlTFVRalFiTVdsbGxrdUx0RDFwSEpsOC1sTDJRTEJNQlA3S3ZaQzBtV1h6bWp5VnlMZUgwUlF3cXYxcnljZVE5SWdzLVg3V0syOWRYS08=
STRIPE_API_VERSION = 2026-01-28.clover
APP_FRONTEND_URL = https://nyla.poweron-center.net
# AI configuration
Connector_AiOpenai_API_SECRET = PROD_ENC:Z0FBQUFBQnBaSnM4TWJOVm4xVkx6azRlNDdxN3UxLUdwY2hhdGYxRGp4VFJqYXZIcmkxM1ZyOWV2M0Z4MHdFNkVYQ0ROb1d6LUZFUEdvMHhLMEtXYVBCRzM5TlYyY3ROYWtJRk41cDZxd0tYYi00MjVqMTh4QVcyTXl0bmVocEFHbXQwREpwNi1vODdBNmwzazE5bkpNelE2WXpvblIzWlQwbGdEelI2WXFqT1RibXVHcjNWbVhwYzBOM25XTzNmTDAwUjRvYk4yNjIyZHc5c2RSZzREQUFCdUwyb0ZuOXN1dzI2c2FKdXI4NGxEbk92czZWamJXU3ZSbUlLejZjRklRRk4tLV9aVUFZekI2bTU4OHYxNTUybDg3RVo0ZTh6dXNKRW5GNXVackZvcm9laGI0X3R6V3M9
Connector_AiAnthropic_API_SECRET = PROD_ENC:Z0FBQUFBQnBDM1Z3TnhYdlhSLW5RbXJyMHFXX0V0bHhuTDlTaFJsRDl2dTdIUTFtVFAwTE8tY3hLbzNSMnVTLXd3RUZualN3MGNzc1kwOTIxVUN2WW1rYi1TendFRVVBSVNqRFVjckEzNExyTGNaUkJLMmozazUwemI1cnhrcEtZVXJrWkdaVFFramp3MWZ6RmY2aGlRMXVEYjM2M3ZlbmxMdnNCRDM1QWR0Wmd6MWVnS1I1c01nV3hRLXg3d2NTZXVfTi1Wdm16UnRyNGsyRTZ0bG9TQ1g1OFB5Z002bmQ3QT09
Connector_AiPerplexity_API_SECRET = PROD_ENC:Z0FBQUFBQm82Mzk2Q1FGRkJEUkI4LXlQbHYzT2RkdVJEcmM4WGdZTWpJTEhoeUF1NW5LUVpJdDBYN3k1WFN4a2FQSWJSQmd0U0xJbzZDTmFFN05FcXl0Z3V1OEpsZjYydV94TXVjVjVXRTRYSWdLMkd5XzZIbFV6emRCZHpuOUpQeThadE5xcDNDVGV1RHJrUEN0c1BBYXctZFNWcFRuVXhRPT0=
Connector_AiPerplexity_API_SECRET = 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
Connector_AiTavily_API_SECRET = PROD_ENC:Z0FBQUFBQnBDM1Z3NmItcDh6V0JpcE5Jc0NlUWZqcmllRHB5eDlNZmVnUlNVenhNTm5xWExzbjJqdE1GZ0hTSUYtb2dvdWNhTnlQNmVWQ2NGVDgwZ0MwMWZBMlNKWEhzdlF3TlZzTXhCZWM4Z1Uwb18tSTRoU1JBVTVkSkJHOTJwX291b3dPaVphVFg=
Connector_AiPrivateLlm_API_SECRET = PROD_ENC:Z0FBQUFBQnBudkpGanZ6U3pzZWkwXzVPWGtIQ040XzFrTXc5QWRnazdEeEktaUJ0akJmNnEzbWUzNHczLTJfc2dIdzBDY0FTaXZYcDhxNFdNbTNtbEJTb2VRZ0ZYd05hdlNLR1h6SUFzVml2Z1FLY1BjTl90UWozUGxtak1URnhhZmNDRWFTb0dKVUo=
Connector_AiMistral_API_SECRET = PROD_ENC:Z0FBQUFBQnBudkpGc2tQc2lvMk1YZk01Q1dob1U5cnR0dG03WWE3WkpoOWo0SEpvLU9Rc2lCNDExdy1wZExaN3lpT2FEQkxnaHRmWmZUUUZUUUJmblZreGlpaFpOdnFhbzlEd1RsVVJtX216cmhxTm5BcTN2eUZ2T054cDE5bmlEamJ3NGR6MVpFQnA=


@@ -11,9 +11,34 @@ IMPORTANT: Model Registration Requirements
- If duplicate displayNames are detected during registration, an error will be raised
"""
import re as _re
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional
from modules.datamodels.datamodelAi import AiModel
from typing import List, Dict, Any, Optional, AsyncGenerator, Union
from modules.datamodels.datamodelAi import AiModel, AiModelCall, AiModelResponse
_RETRY_AFTER_PATTERN = _re.compile(r"try again in (\d+(?:\.\d+)?)\s*s", _re.IGNORECASE)
def _parseRetryAfterSeconds(message: str) -> float:
"""Extract retry-after seconds from provider error messages like 'Please try again in 6.558s'."""
match = _RETRY_AFTER_PATTERN.search(message)
return float(match.group(1)) if match else 0.0
class RateLimitExceededException(Exception):
"""Raised when a provider's rate limit (TPM / RPM) is exceeded."""
def __init__(self, message: str = "Rate limit exceeded", retryAfterSeconds: float = 0.0):
super().__init__(message)
if retryAfterSeconds <= 0:
retryAfterSeconds = _parseRetryAfterSeconds(message)
self.retryAfterSeconds = retryAfterSeconds
class ContextLengthExceededException(Exception):
"""Raised when the input exceeds a model's context window."""
pass
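# A minimal sketch of how the retry-after parsing above behaves; the message
# strings are invented stand-ins for provider error payloads:
#
#     exc = RateLimitExceededException("429: Please try again in 6.558s")
#     exc.retryAfterSeconds   # -> 6.558, parsed out of the message
#
#     exc = RateLimitExceededException("Rate limit exceeded", retryAfterSeconds=30.0)
#     exc.retryAfterSeconds   # -> 30.0, an explicit value skips parsing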
class BaseConnectorAi(ABC):
@@ -102,3 +127,24 @@ class BaseConnectorAi(ABC):
"""Get only available models."""
models = self.getCachedModels()
return [model for model in models if model.isAvailable]
async def callAiBasicStream(self, modelCall: AiModelCall) -> AsyncGenerator[Union[str, AiModelResponse], None]:
"""Stream AI response. Yields str deltas during generation, then final AiModelResponse.
Default implementation: falls back to non-streaming callAiBasic.
Override in connectors that support streaming.
"""
response = await self.callAiBasic(modelCall)
if response.content:
yield response.content
yield response
async def callEmbedding(self, modelCall: AiModelCall) -> AiModelResponse:
"""Generate embeddings for input texts. Override in connectors that support embeddings.
Reads texts from modelCall.embeddingInput.
Returns AiModelResponse with metadata["embeddings"] containing the vectors.
"""
raise NotImplementedError(
f"{self.__class__.__name__} does not support embeddings"
)
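# How a caller is expected to consume the streaming contract above; a sketch
# assuming a concrete connector instance and a prepared AiModelCall:
#
# async def consumeStream(connector: BaseConnectorAi, modelCall: AiModelCall) -> AiModelResponse:
#     finalResponse = None
#     async for item in connector.callAiBasicStream(modelCall):
#         if isinstance(item, str):
#             print(item, end="", flush=True)   # incremental text delta
#         else:
#             finalResponse = item              # trailing AiModelResponse
#     return finalResponse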


@@ -19,26 +19,29 @@ class ModelSelector:
"""Model selector with priority scoring and recent-failure cooldown."""
def __init__(self):
self._failureLog: Dict[str, float] = {}
self._failureLog: Dict[str, Tuple[float, float]] = {}
logger.info("ModelSelector initialized with failure cooldown support")
def reportFailure(self, modelName: str):
def reportFailure(self, modelName: str, cooldownSeconds: float = 0.0):
"""Record that a model just failed (rate limit, error, etc.).
The model will be deprioritized for COOLDOWN_DURATION seconds."""
self._failureLog[modelName] = time.time()
logger.info(f"ModelSelector: Recorded failure for {modelName}, cooldown {_COOLDOWN_DURATION}s")
The model will be deprioritized for *cooldownSeconds* (default: _COOLDOWN_DURATION)."""
if cooldownSeconds <= 0:
cooldownSeconds = _COOLDOWN_DURATION
self._failureLog[modelName] = (time.time(), cooldownSeconds)
logger.info(f"ModelSelector: Recorded failure for {modelName}, cooldown {cooldownSeconds:.1f}s")
def _getCooldownPenalty(self, modelName: str) -> float:
"""Return a score penalty (0.0 = no penalty, large negative = recently failed)."""
failedAt = self._failureLog.get(modelName)
if failedAt is None:
entry = self._failureLog.get(modelName)
if entry is None:
return 0.0
failedAt, cooldown = entry
elapsed = time.time() - failedAt
if elapsed > _COOLDOWN_DURATION:
if elapsed > cooldown:
del self._failureLog[modelName]
return 0.0
remaining = _COOLDOWN_DURATION - elapsed
return -(remaining / _COOLDOWN_DURATION) * 5000.0
remaining = cooldown - elapsed
return -(remaining / cooldown) * 5000.0
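# A worked example of the penalty math above, with illustrative numbers:
#
#     selector = ModelSelector()
#     selector.reportFailure("some-model", cooldownSeconds=60.0)
#     # 10 seconds later: remaining = 50.0, so the penalty is
#     #   -(50.0 / 60.0) * 5000.0 ~= -4166.7,
#     # large enough to push the model behind healthy alternatives.
#     # Once 60 seconds have elapsed the entry is dropped and the penalty is 0.0.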
def selectModel(self,
prompt: str,


@@ -1,12 +1,13 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
import json
import logging
import httpx
import os
from typing import Dict, Any, List
from typing import Dict, Any, List, AsyncGenerator, Optional, Union
from fastapi import HTTPException
from modules.shared.configuration import APP_CONFIG
from .aicoreBase import BaseConnectorAi
from .aicoreBase import BaseConnectorAi, RateLimitExceededException
from modules.datamodels.datamodelAi import AiModel, PriorityEnum, ProcessingModeEnum, OperationTypeEnum, AiModelCall, AiModelResponse, createOperationTypeRatings
# Configure logger
@@ -61,13 +62,15 @@ class AiAnthropic(BaseConnectorAi):
speedRating=6, # Slower due to high-quality processing
qualityRating=10, # Best quality available
functionCall=self.callAiBasic,
functionCallStream=self.callAiBasicStream,
priority=PriorityEnum.QUALITY,
processingMode=ProcessingModeEnum.DETAILED,
operationTypes=createOperationTypeRatings(
(OperationTypeEnum.PLAN, 9),
(OperationTypeEnum.DATA_ANALYSE, 9),
(OperationTypeEnum.DATA_GENERATE, 9),
(OperationTypeEnum.DATA_EXTRACT, 8)
(OperationTypeEnum.DATA_EXTRACT, 8),
(OperationTypeEnum.AGENT, 9),
),
version="claude-sonnet-4-5-20250929",
calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.003 + (bytesReceived / 4 / 1000) * 0.015
@@ -85,13 +88,15 @@ class AiAnthropic(BaseConnectorAi):
speedRating=9, # Very fast, lightweight model
qualityRating=8, # Good quality, cost-efficient
functionCall=self.callAiBasic,
functionCallStream=self.callAiBasicStream,
priority=PriorityEnum.SPEED,
processingMode=ProcessingModeEnum.BASIC,
operationTypes=createOperationTypeRatings(
(OperationTypeEnum.PLAN, 8),
(OperationTypeEnum.DATA_ANALYSE, 8),
(OperationTypeEnum.DATA_GENERATE, 8),
(OperationTypeEnum.DATA_EXTRACT, 7)
(OperationTypeEnum.DATA_EXTRACT, 7),
(OperationTypeEnum.AGENT, 7),
),
version="claude-haiku-4-5-20251001",
calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.001 + (bytesReceived / 4 / 1000) * 0.005
@@ -109,13 +114,15 @@ class AiAnthropic(BaseConnectorAi):
speedRating=5, # Moderate latency, most capable
qualityRating=10, # Top-tier intelligence
functionCall=self.callAiBasic,
functionCallStream=self.callAiBasicStream,
priority=PriorityEnum.QUALITY,
processingMode=ProcessingModeEnum.DETAILED,
operationTypes=createOperationTypeRatings(
(OperationTypeEnum.PLAN, 10),
(OperationTypeEnum.DATA_ANALYSE, 10),
(OperationTypeEnum.DATA_ANALYSE, 8),
(OperationTypeEnum.DATA_GENERATE, 10),
(OperationTypeEnum.DATA_EXTRACT, 9)
(OperationTypeEnum.DATA_EXTRACT, 9),
(OperationTypeEnum.AGENT, 10),
),
version="claude-opus-4-6",
calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.005 + (bytesReceived / 4 / 1000) * 0.025
@@ -158,8 +165,6 @@ class AiAnthropic(BaseConnectorAi):
HTTPException: For errors in API communication
"""
try:
# Extract parameters from modelCall
messages = modelCall.messages
model = modelCall.model
options = modelCall.options
temperature = getattr(options, "temperature", None)
@@ -167,44 +172,8 @@ class AiAnthropic(BaseConnectorAi):
temperature = model.temperature
maxTokens = model.maxTokens
# Transform OpenAI-style messages to Anthropic format:
# - Move any 'system' role content to top-level 'system'
# - Keep only 'user'/'assistant' messages in the list
system_contents: List[str] = []
converted_messages: List[Dict[str, Any]] = []
for m in messages:
role = m.get("role")
content = m.get("content", "")
if role == "system":
# Collect system content; Anthropic expects top-level 'system'
if isinstance(content, list):
# Join text parts if provided as blocks
joined = "\n\n".join(
[
(part.get("text") if isinstance(part, dict) else str(part))
for part in content
]
)
system_contents.append(joined)
else:
system_contents.append(str(content))
continue
# For Anthropic, content can be a string; pass through strings, collapse blocks
if isinstance(content, list):
# Collapse to text if blocks are provided
collapsed = "\n\n".join(
[
(part.get("text") if isinstance(part, dict) else str(part))
for part in content
]
)
converted_messages.append({"role": role, "content": collapsed})
else:
converted_messages.append({"role": role, "content": content})
converted_messages, system_prompt = _convertMessagesForAnthropic(modelCall.messages)
system_prompt = "\n\n".join([s for s in system_contents if s]) if system_contents else None
# Create Anthropic API payload
payload: Dict[str, Any] = {
"model": model.name,
"messages": converted_messages,
@@ -218,6 +187,13 @@ class AiAnthropic(BaseConnectorAi):
if system_prompt:
payload["system"] = system_prompt
if modelCall.tools:
payload["tools"] = _convertToolsToAnthropicFormat(modelCall.tools)
if modelCall.toolChoice:
payload["tool_choice"] = modelCall.toolChoice
else:
payload["tool_choice"] = {"type": "auto"}
response = await self.httpClient.post(
model.apiUrl,
json=payload
@@ -227,11 +203,12 @@ class AiAnthropic(BaseConnectorAi):
error_detail = f"Anthropic API error: {response.status_code} - {response.text}"
logger.error(error_detail)
# Provide more specific error messages based on status code
if response.status_code == 429:
raise RateLimitExceededException(
f"Rate limit exceeded for {model.name}: {response.text}"
)
if response.status_code == 529:
error_message = "Anthropic API is currently overloaded. Please try again in a few minutes."
elif response.status_code == 429:
error_message = "Rate limit exceeded. Please wait before making another request."
elif response.status_code == 401:
error_message = "Invalid API key. Please check your Anthropic API configuration."
elif response.status_code == 400:
@@ -244,31 +221,43 @@ class AiAnthropic(BaseConnectorAi):
# Parse response
anthropicResponse = response.json()
# Extract content from response
# Extract content and tool_use blocks from response
content = ""
toolCalls = []
if "content" in anthropicResponse:
if isinstance(anthropicResponse["content"], list):
# Content is a list of parts (in newer API versions)
for part in anthropicResponse["content"]:
if part.get("type") == "text":
content += part.get("text", "")
elif part.get("type") == "tool_use":
toolCalls.append({
"id": part.get("id", ""),
"type": "function",
"function": {
"name": part.get("name", ""),
"arguments": json.dumps(part.get("input", {})) if isinstance(part.get("input"), dict) else str(part.get("input", "{}"))
}
})
else:
# Direct content as string (in older API versions)
content = anthropicResponse["content"]
# Debug logging for empty responses
if not content or content.strip() == "":
if not content and not toolCalls:
logger.warning(f"Anthropic API returned empty content. Full response: {anthropicResponse}")
content = "[Anthropic API returned empty response]"
# Return standardized response
metadata = {"response_id": anthropicResponse.get("id", "")}
if toolCalls:
metadata["toolCalls"] = toolCalls
return AiModelResponse(
content=content,
success=True,
modelId=model.name,
metadata={"response_id": anthropicResponse.get("id", "")}
metadata=metadata
)
except (RateLimitExceededException, HTTPException):
raise
except Exception as e:
error_msg = str(e) if str(e) else f"{type(e).__name__}"
error_detail = f"Error calling Anthropic API: {error_msg}"
@@ -279,6 +268,128 @@ class AiAnthropic(BaseConnectorAi):
logger.error(error_detail, exc_info=True)
raise HTTPException(status_code=500, detail=error_detail)
async def callAiBasicStream(self, modelCall: AiModelCall) -> AsyncGenerator[Union[str, AiModelResponse], None]:
"""Stream Anthropic response. Yields str deltas, then final AiModelResponse."""
try:
model = modelCall.model
options = modelCall.options
temperature = getattr(options, "temperature", None)
if temperature is None:
temperature = model.temperature
converted, system_prompt = _convertMessagesForAnthropic(modelCall.messages)
payload: Dict[str, Any] = {
"model": model.name,
"messages": converted,
"temperature": temperature,
"max_tokens": model.maxTokens,
"stream": True,
}
if system_prompt:
payload["system"] = system_prompt
if modelCall.tools:
payload["tools"] = _convertToolsToAnthropicFormat(modelCall.tools)
payload["tool_choice"] = modelCall.toolChoice or {"type": "auto"}
fullContent = ""
toolUseBlocks: Dict[int, Dict[str, Any]] = {}
currentToolIdx = -1
stopReason: Optional[str] = None
async with self.httpClient.stream("POST", model.apiUrl, json=payload) as response:
if response.status_code != 200:
body = await response.aread()
bodyStr = body.decode()
if response.status_code == 429:
raise RateLimitExceededException(
f"Rate limit exceeded for {model.name}: {bodyStr}"
)
raise HTTPException(status_code=500, detail=f"Anthropic stream error: {response.status_code} - {bodyStr}")
async for line in response.aiter_lines():
if not line.startswith("data: "):
continue
try:
event = json.loads(line[6:])
except json.JSONDecodeError:
continue
eventType = event.get("type", "")
if eventType == "error":
errDetail = event.get("error", {})
errMsg = errDetail.get("message", str(errDetail))
errType = errDetail.get("type", "unknown")
logger.error(f"Anthropic stream error event: type={errType}, message={errMsg}")
if "overloaded" in errMsg.lower() or "overloaded" in errType.lower():
raise HTTPException(status_code=500, detail=f"Anthropic API is currently overloaded. Please try again in a few minutes.")
raise HTTPException(status_code=500, detail=f"Anthropic stream error: [{errType}] {errMsg}")
elif eventType == "content_block_start":
block = event.get("content_block", {})
idx = event.get("index", 0)
if block.get("type") == "tool_use":
currentToolIdx = idx
toolUseBlocks[idx] = {
"id": block.get("id", ""),
"name": block.get("name", ""),
"arguments": "",
}
elif eventType == "content_block_delta":
delta = event.get("delta", {})
if delta.get("type") == "text_delta":
text = delta.get("text", "")
fullContent += text
yield text
elif delta.get("type") == "input_json_delta":
idx = event.get("index", currentToolIdx)
if idx in toolUseBlocks:
toolUseBlocks[idx]["arguments"] += delta.get("partial_json", "")
elif eventType == "message_delta":
delta = event.get("delta", {})
stopReason = delta.get("stop_reason", stopReason)
elif eventType == "message_stop":
break
if not fullContent and not toolUseBlocks:
logger.warning(
f"Anthropic stream returned empty response: model={model.name}, "
f"stopReason={stopReason}"
)
metadata: Dict[str, Any] = {}
if stopReason:
metadata["stopReason"] = stopReason
if toolUseBlocks:
metadata["toolCalls"] = [
{
"id": tb["id"],
"type": "function",
"function": {
"name": tb["name"],
"arguments": tb["arguments"],
},
}
for tb in toolUseBlocks.values()
]
yield AiModelResponse(
content=fullContent,
success=True,
modelId=model.name,
metadata=metadata,
)
except (RateLimitExceededException, HTTPException):
raise
except Exception as e:
logger.error(f"Error streaming Anthropic API: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Error streaming Anthropic API: {e}")
async def callAiImage(self, modelCall: AiModelCall) -> AiModelResponse:
"""
Analyzes an image with Anthropic's vision capabilities, using the standardized pattern.
@@ -331,6 +442,20 @@ class AiAnthropic(BaseConnectorAi):
mimeType = parts[0].replace("data:", "")
base64Data = parts[1]
import base64 as _b64
try:
rawHead = _b64.b64decode(base64Data[:32])
if rawHead[:3] == b"\xff\xd8\xff":
mimeType = "image/jpeg"
elif rawHead[:8] == b"\x89PNG\r\n\x1a\n":
mimeType = "image/png"
elif rawHead[:4] == b"GIF8":
mimeType = "image/gif"
elif rawHead[:4] == b"RIFF" and rawHead[8:12] == b"WEBP":
mimeType = "image/webp"
except Exception:
pass
# Convert to Anthropic's vision format
anthropicMessages = [{
"role": "user",
@@ -425,3 +550,100 @@ class AiAnthropic(BaseConnectorAi):
success=False,
error=f"Error during image analysis: {str(e)}"
)
def _convertMessagesForAnthropic(messages: List[Dict[str, Any]]):
"""Convert OpenAI-style messages to Anthropic format. Returns (messages, system_prompt)."""
system_contents: List[str] = []
converted_messages: List[Dict[str, Any]] = []
pendingToolResults: List[Dict[str, Any]] = []
def _flush():
if not pendingToolResults:
return
converted_messages.append({"role": "user", "content": list(pendingToolResults)})
pendingToolResults.clear()
def _collapse(content):
if isinstance(content, list):
return "\n\n".join(
(part.get("text") if isinstance(part, dict) else str(part))
for part in content
)
return str(content) if content else ""
for m in messages:
role = m.get("role")
content = m.get("content", "")
if role == "system":
system_contents.append(_collapse(content))
continue
if role == "tool":
pendingToolResults.append({
"type": "tool_result",
"tool_use_id": m.get("tool_call_id", ""),
"content": str(content) if content else "",
})
continue
_flush()
if role == "assistant" and m.get("tool_calls"):
contentBlocks = []
textPart = _collapse(content)
if textPart:
contentBlocks.append({"type": "text", "text": textPart})
for tc in m["tool_calls"]:
fn = tc.get("function", {})
inputData = fn.get("arguments", "{}")
if isinstance(inputData, str):
try:
inputData = json.loads(inputData)
except (json.JSONDecodeError, ValueError):
inputData = {}
contentBlocks.append({
"type": "tool_use",
"id": tc.get("id", ""),
"name": fn.get("name", ""),
"input": inputData,
})
converted_messages.append({"role": "assistant", "content": contentBlocks})
continue
converted_messages.append({"role": role, "content": _collapse(content)})
_flush()
merged: List[Dict[str, Any]] = []
for msg in converted_messages:
if merged and merged[-1]["role"] == msg["role"]:
prev = merged[-1]
pc, nc = prev["content"], msg["content"]
if isinstance(pc, str) and isinstance(nc, str):
prev["content"] = pc + "\n\n" + nc
elif isinstance(pc, list) and isinstance(nc, list):
prev["content"] = pc + nc
elif isinstance(pc, str) and isinstance(nc, list):
prev["content"] = [{"type": "text", "text": pc}] + nc
elif isinstance(pc, list) and isinstance(nc, str):
prev["content"] = pc + [{"type": "text", "text": nc}]
else:
merged.append(msg)
system_prompt = "\n\n".join([s for s in system_contents if s]) if system_contents else None
return merged, system_prompt
def _convertToolsToAnthropicFormat(openaiTools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Convert OpenAI-style tool definitions to Anthropic format."""
anthropicTools = []
for tool in openaiTools:
if tool.get("type") == "function":
fn = tool["function"]
anthropicTools.append({
"name": fn["name"],
"description": fn.get("description", ""),
"input_schema": fn.get("parameters", {"type": "object", "properties": {}})
})
return anthropicTools
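# An illustration of what _convertMessagesForAnthropic produces for a tool-call
# round trip; all values are invented:
#
#     messages = [
#         {"role": "system", "content": "You are terse."},
#         {"role": "user", "content": "What is 2+2?"},
#         {"role": "assistant", "content": None, "tool_calls": [
#             {"id": "tc1", "function": {"name": "calc", "arguments": '{"expr": "2+2"}'}},
#         ]},
#         {"role": "tool", "tool_call_id": "tc1", "content": "4"},
#     ]
#     converted, system_prompt = _convertMessagesForAnthropic(messages)
#     # system_prompt == "You are terse."
#     # converted == [
#     #     {"role": "user", "content": "What is 2+2?"},
#     #     {"role": "assistant", "content": [
#     #         {"type": "tool_use", "id": "tc1", "name": "calc", "input": {"expr": "2+2"}}]},
#     #     {"role": "user", "content": [
#     #         {"type": "tool_result", "tool_use_id": "tc1", "content": "4"}]},
#     # ]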


@@ -1,24 +1,16 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
import logging
import json as _json
import httpx
from typing import List
from typing import List, Dict, Any, AsyncGenerator, Union
from fastapi import HTTPException
from modules.shared.configuration import APP_CONFIG
from .aicoreBase import BaseConnectorAi
from .aicoreBase import BaseConnectorAi, RateLimitExceededException, ContextLengthExceededException
from modules.datamodels.datamodelAi import AiModel, PriorityEnum, ProcessingModeEnum, OperationTypeEnum, AiModelCall, AiModelResponse, createOperationTypeRatings
# Configure logger
logger = logging.getLogger(__name__)
class ContextLengthExceededException(Exception):
"""Exception raised when the context length exceeds the model's limit"""
pass
class RateLimitExceededException(Exception):
"""Exception raised when the provider's rate limit (TPM) is exceeded"""
pass
def loadConfigData():
"""Load configuration data for Mistral connector"""
return {
@@ -66,13 +58,15 @@ class AiMistral(BaseConnectorAi):
speedRating=8, # Good speed for complex tasks
qualityRating=9, # High quality
functionCall=self.callAiBasic,
functionCallStream=self.callAiBasicStream,
priority=PriorityEnum.BALANCED,
processingMode=ProcessingModeEnum.ADVANCED,
operationTypes=createOperationTypeRatings(
(OperationTypeEnum.PLAN, 9),
(OperationTypeEnum.DATA_ANALYSE, 9),
(OperationTypeEnum.DATA_GENERATE, 9),
(OperationTypeEnum.DATA_EXTRACT, 8)
(OperationTypeEnum.DATA_EXTRACT, 8),
(OperationTypeEnum.AGENT, 8),
),
version="mistral-large-latest",
calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.0005 + (bytesReceived / 4 / 1000) * 0.0015
@@ -90,17 +84,40 @@ class AiMistral(BaseConnectorAi):
speedRating=9, # Very fast, lightweight model
qualityRating=7, # Good quality, cost-efficient
functionCall=self.callAiBasic,
functionCallStream=self.callAiBasicStream,
priority=PriorityEnum.SPEED,
processingMode=ProcessingModeEnum.BASIC,
operationTypes=createOperationTypeRatings(
(OperationTypeEnum.PLAN, 7),
(OperationTypeEnum.DATA_ANALYSE, 7),
(OperationTypeEnum.DATA_GENERATE, 8),
(OperationTypeEnum.DATA_EXTRACT, 7)
(OperationTypeEnum.DATA_EXTRACT, 7),
(OperationTypeEnum.AGENT, 6),
),
version="mistral-small-latest",
calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.00006 + (bytesReceived / 4 / 1000) * 0.00018
),
AiModel(
name="mistral-embed",
displayName="Mistral Embed",
connectorType="mistral",
apiUrl="https://api.mistral.ai/v1/embeddings",
temperature=0.0,
maxTokens=0,
contextLength=8192,
costPer1kTokensInput=0.0001, # $0.10/M tokens
costPer1kTokensOutput=0.0,
speedRating=10,
qualityRating=7,
functionCall=self.callEmbedding,
priority=PriorityEnum.COST,
processingMode=ProcessingModeEnum.BASIC,
operationTypes=createOperationTypeRatings(
(OperationTypeEnum.EMBEDDING, 8)
),
version="mistral-embed",
calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.0001
),
AiModel(
name="mistral-large-latest",
displayName="Mistral Large 3 Vision",
@@ -158,6 +175,10 @@ class AiMistral(BaseConnectorAi):
"max_tokens": maxTokens
}
if modelCall.tools:
payload["tools"] = modelCall.tools
payload["tool_choice"] = modelCall.toolChoice or "auto"
response = await self.httpClient.post(
model.apiUrl,
json=payload
@@ -197,13 +218,18 @@ class AiMistral(BaseConnectorAi):
raise HTTPException(status_code=500, detail=error_message)
responseJson = response.json()
content = responseJson["choices"][0]["message"]["content"]
choiceMessage = responseJson["choices"][0]["message"]
content = choiceMessage.get("content") or ""
metadata = {"response_id": responseJson.get("id", "")}
if choiceMessage.get("tool_calls"):
metadata["toolCalls"] = choiceMessage["tool_calls"]
return AiModelResponse(
content=content,
success=True,
modelId=model.name,
metadata={"response_id": responseJson.get("id", "")}
metadata=metadata,
)
except ContextLengthExceededException:
@@ -216,6 +242,147 @@ class AiMistral(BaseConnectorAi):
logger.error(f"Error calling Mistral API: {str(e)}")
raise HTTPException(status_code=500, detail=f"Error calling Mistral API: {str(e)}")
async def callAiBasicStream(self, modelCall: AiModelCall) -> AsyncGenerator[Union[str, AiModelResponse], None]:
"""Stream Mistral response. Yields str deltas, then final AiModelResponse."""
try:
model = modelCall.model
options = modelCall.options
temperature = getattr(options, "temperature", None)
if temperature is None:
temperature = model.temperature
payload: Dict[str, Any] = {
"model": model.name,
"messages": modelCall.messages,
"temperature": temperature,
"max_tokens": model.maxTokens,
"stream": True,
}
if modelCall.tools:
payload["tools"] = modelCall.tools
payload["tool_choice"] = modelCall.toolChoice or "auto"
fullContent = ""
toolCallsAccum: Dict[int, Dict[str, Any]] = {}
async with self.httpClient.stream("POST", model.apiUrl, json=payload) as response:
if response.status_code != 200:
body = await response.aread()
bodyStr = body.decode()
if response.status_code == 429:
try:
errorMsg = _json.loads(bodyStr).get("error", {}).get("message", "Rate limit exceeded")
except (ValueError, KeyError):
errorMsg = f"Rate limit exceeded for {model.name}"
raise RateLimitExceededException(f"Rate limit exceeded for {model.name}: {errorMsg}")
raise HTTPException(status_code=500, detail=f"Mistral stream error: {response.status_code} - {bodyStr}")
async for line in response.aiter_lines():
if not line.startswith("data: "):
continue
data = line[6:]
if data.strip() == "[DONE]":
break
try:
chunk = _json.loads(data)
except _json.JSONDecodeError:
continue
delta = chunk.get("choices", [{}])[0].get("delta", {})
if "content" in delta and delta["content"]:
fullContent += delta["content"]
yield delta["content"]
for tcDelta in delta.get("tool_calls", []):
idx = tcDelta.get("index", 0)
if idx not in toolCallsAccum:
toolCallsAccum[idx] = {
"id": tcDelta.get("id", ""),
"type": "function",
"function": {"name": "", "arguments": ""},
}
if tcDelta.get("id"):
toolCallsAccum[idx]["id"] = tcDelta["id"]
fn = tcDelta.get("function", {})
if fn.get("name"):
toolCallsAccum[idx]["function"]["name"] = fn["name"]
if fn.get("arguments"):
toolCallsAccum[idx]["function"]["arguments"] += fn["arguments"]
metadata: Dict[str, Any] = {}
if toolCallsAccum:
metadata["toolCalls"] = [toolCallsAccum[i] for i in sorted(toolCallsAccum)]
yield AiModelResponse(
content=fullContent,
success=True,
modelId=model.name,
metadata=metadata,
)
except (RateLimitExceededException, ContextLengthExceededException, HTTPException):
raise
except Exception as e:
logger.error(f"Error streaming Mistral API: {e}")
raise HTTPException(status_code=500, detail=f"Error streaming Mistral API: {e}")
async def callEmbedding(self, modelCall: AiModelCall) -> AiModelResponse:
"""Generate embeddings via the Mistral Embeddings API.
Reads texts from modelCall.embeddingInput.
Returns vectors in metadata["embeddings"].
"""
try:
model = modelCall.model
texts = modelCall.embeddingInput or []
if not texts:
return AiModelResponse(
content="", success=False, error="No embeddingInput provided"
)
payload = {"model": model.name, "input": texts}
response = await self.httpClient.post(model.apiUrl, json=payload)
if response.status_code != 200:
errorMessage = f"Mistral Embedding API error: {response.status_code} - {response.text}"
logger.error(errorMessage)
if response.status_code == 429:
raise RateLimitExceededException(f"Rate limit exceeded for {model.name}")
if response.status_code == 400:
try:
errorData = response.json()
errMsg = errorData.get("error", {}).get("message", "").lower()
errCode = errorData.get("error", {}).get("code", "")
if errCode == "context_length_exceeded" or "too many tokens" in errMsg or "maximum context length" in errMsg:
raise ContextLengthExceededException(
f"Embedding context length exceeded for {model.name}: {errorData.get('error', {}).get('message', '')}"
)
except (ValueError, KeyError):
pass
raise HTTPException(status_code=500, detail=errorMessage)
responseJson = response.json()
embeddings = [item["embedding"] for item in responseJson["data"]]
usage = responseJson.get("usage", {})
return AiModelResponse(
content="",
success=True,
modelId=model.name,
tokensUsed={
"input": usage.get("prompt_tokens", 0),
"output": 0,
"total": usage.get("total_tokens", 0),
},
metadata={"embeddings": embeddings},
)
except (RateLimitExceededException, ContextLengthExceededException):
raise
except Exception as e:
logger.error(f"Error calling Mistral Embedding API: {str(e)}")
raise HTTPException(status_code=500, detail=f"Error calling Mistral Embedding API: {str(e)}")
async def callAiImage(self, modelCall: AiModelCall) -> AiModelResponse:
"""
Analyzes an image with the Mistral Vision API, using the standardized pattern.


@@ -1,24 +1,16 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
import logging
import json as _json
import httpx
from typing import List
from typing import List, Dict, Any, AsyncGenerator, Union
from fastapi import HTTPException
from modules.shared.configuration import APP_CONFIG
from .aicoreBase import BaseConnectorAi
from .aicoreBase import BaseConnectorAi, RateLimitExceededException, ContextLengthExceededException
from modules.datamodels.datamodelAi import AiModel, PriorityEnum, ProcessingModeEnum, OperationTypeEnum, AiModelCall, AiModelResponse, createOperationTypeRatings, AiCallPromptImage
# Configure logger
logger = logging.getLogger(__name__)
class ContextLengthExceededException(Exception):
"""Exception raised when the context length exceeds the model's limit"""
pass
class RateLimitExceededException(Exception):
"""Exception raised when the provider's rate limit (TPM) is exceeded"""
pass
def loadConfigData():
"""Load configuration data for OpenAI connector"""
return {
@@ -67,13 +59,15 @@ class AiOpenai(BaseConnectorAi):
speedRating=8, # Good speed for complex tasks
qualityRating=10, # High quality
functionCall=self.callAiBasic,
functionCallStream=self.callAiBasicStream,
priority=PriorityEnum.BALANCED,
processingMode=ProcessingModeEnum.ADVANCED,
operationTypes=createOperationTypeRatings(
(OperationTypeEnum.PLAN, 9),
(OperationTypeEnum.DATA_ANALYSE, 10),
(OperationTypeEnum.DATA_GENERATE, 10),
(OperationTypeEnum.DATA_EXTRACT, 7)
(OperationTypeEnum.DATA_EXTRACT, 7),
(OperationTypeEnum.AGENT, 9),
),
version="gpt-4o",
calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.0025 + (bytesReceived / 4 / 1000) * 0.01
@@ -92,13 +86,15 @@ class AiOpenai(BaseConnectorAi):
speedRating=9, # Very fast
qualityRating=8, # Good quality, replaces gpt-3.5-turbo
functionCall=self.callAiBasic,
functionCallStream=self.callAiBasicStream,
priority=PriorityEnum.SPEED,
processingMode=ProcessingModeEnum.BASIC,
operationTypes=createOperationTypeRatings(
(OperationTypeEnum.PLAN, 8),
(OperationTypeEnum.DATA_ANALYSE, 8),
(OperationTypeEnum.DATA_GENERATE, 9),
(OperationTypeEnum.DATA_EXTRACT, 7)
(OperationTypeEnum.DATA_EXTRACT, 7),
(OperationTypeEnum.AGENT, 8),
),
version="gpt-4o-mini",
calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.00015 + (bytesReceived / 4 / 1000) * 0.0006
@@ -125,6 +121,48 @@ class AiOpenai(BaseConnectorAi):
version="gpt-4o",
calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.0025 + (bytesReceived / 4 / 1000) * 0.01
),
AiModel(
name="text-embedding-3-small",
displayName="OpenAI Embedding Small",
connectorType="openai",
apiUrl="https://api.openai.com/v1/embeddings",
temperature=0.0,
maxTokens=0,
contextLength=8191,
costPer1kTokensInput=0.00002, # $0.02/M tokens
costPer1kTokensOutput=0.0,
speedRating=10,
qualityRating=8,
functionCall=self.callEmbedding,
priority=PriorityEnum.COST,
processingMode=ProcessingModeEnum.BASIC,
operationTypes=createOperationTypeRatings(
(OperationTypeEnum.EMBEDDING, 10)
),
version="text-embedding-3-small",
calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.00002
),
AiModel(
name="text-embedding-3-large",
displayName="OpenAI Embedding Large",
connectorType="openai",
apiUrl="https://api.openai.com/v1/embeddings",
temperature=0.0,
maxTokens=0,
contextLength=8191,
costPer1kTokensInput=0.00013, # $0.13/M tokens
costPer1kTokensOutput=0.0,
speedRating=9,
qualityRating=10,
functionCall=self.callEmbedding,
priority=PriorityEnum.QUALITY,
processingMode=ProcessingModeEnum.ADVANCED,
operationTypes=createOperationTypeRatings(
(OperationTypeEnum.EMBEDDING, 10)
),
version="text-embedding-3-large",
calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.00013
),
AiModel(
name="dall-e-3",
displayName="OpenAI DALL-E 3",
@@ -179,6 +217,10 @@ class AiOpenai(BaseConnectorAi):
"max_tokens": maxTokens
}
if modelCall.tools:
payload["tools"] = modelCall.tools
payload["tool_choice"] = modelCall.toolChoice or "auto"
response = await self.httpClient.post(
model.apiUrl,
json=payload
@@ -218,22 +260,168 @@ class AiOpenai(BaseConnectorAi):
raise HTTPException(status_code=500, detail=error_message)
responseJson = response.json()
content = responseJson["choices"][0]["message"]["content"]
choiceMessage = responseJson["choices"][0]["message"]
content = choiceMessage.get("content") or ""
metadata = {"response_id": responseJson.get("id", "")}
if choiceMessage.get("tool_calls"):
metadata["toolCalls"] = choiceMessage["tool_calls"]
return AiModelResponse(
content=content,
success=True,
modelId=model.name,
metadata={"response_id": responseJson.get("id", "")}
metadata=metadata
)
except ContextLengthExceededException:
# Re-raise context length exceptions without wrapping
raise
except Exception as e:
logger.error(f"Error calling OpenAI API: {str(e)}")
raise HTTPException(status_code=500, detail=f"Error calling OpenAI API: {str(e)}")
async def callAiBasicStream(self, modelCall: AiModelCall) -> AsyncGenerator[Union[str, AiModelResponse], None]:
"""Stream OpenAI response. Yields str deltas, then final AiModelResponse."""
try:
messages = modelCall.messages
model = modelCall.model
options = modelCall.options
temperature = getattr(options, "temperature", None)
if temperature is None:
temperature = model.temperature
payload: Dict[str, Any] = {
"model": model.name,
"messages": messages,
"temperature": temperature,
"max_tokens": model.maxTokens,
"stream": True,
}
if modelCall.tools:
payload["tools"] = modelCall.tools
payload["tool_choice"] = modelCall.toolChoice or "auto"
fullContent = ""
toolCallsAccum: Dict[int, Dict[str, Any]] = {}
async with self.httpClient.stream("POST", model.apiUrl, json=payload) as response:
if response.status_code != 200:
body = await response.aread()
bodyStr = body.decode()
if response.status_code == 429:
try:
errorMsg = _json.loads(bodyStr).get("error", {}).get("message", "Rate limit exceeded")
except (ValueError, KeyError):
errorMsg = f"Rate limit exceeded for {model.name}"
raise RateLimitExceededException(f"Rate limit exceeded for {model.name}: {errorMsg}")
raise HTTPException(status_code=500, detail=f"OpenAI stream error: {response.status_code} - {bodyStr}")
async for line in response.aiter_lines():
if not line.startswith("data: "):
continue
data = line[6:]
if data.strip() == "[DONE]":
break
try:
chunk = _json.loads(data)
except _json.JSONDecodeError:
continue
delta = chunk.get("choices", [{}])[0].get("delta", {})
if "content" in delta and delta["content"]:
fullContent += delta["content"]
yield delta["content"]
for tcDelta in delta.get("tool_calls", []):
idx = tcDelta.get("index", 0)
if idx not in toolCallsAccum:
toolCallsAccum[idx] = {
"id": tcDelta.get("id", ""),
"type": "function",
"function": {"name": "", "arguments": ""},
}
if tcDelta.get("id"):
toolCallsAccum[idx]["id"] = tcDelta["id"]
fn = tcDelta.get("function", {})
if fn.get("name"):
toolCallsAccum[idx]["function"]["name"] = fn["name"]
if fn.get("arguments"):
toolCallsAccum[idx]["function"]["arguments"] += fn["arguments"]
metadata: Dict[str, Any] = {}
if toolCallsAccum:
metadata["toolCalls"] = [toolCallsAccum[i] for i in sorted(toolCallsAccum)]
yield AiModelResponse(
content=fullContent,
success=True,
modelId=model.name,
metadata=metadata,
)
except (RateLimitExceededException, ContextLengthExceededException, HTTPException):
raise
except Exception as e:
logger.error(f"Error streaming OpenAI API: {e}")
raise HTTPException(status_code=500, detail=f"Error streaming OpenAI API: {e}")
async def callEmbedding(self, modelCall: AiModelCall) -> AiModelResponse:
"""Generate embeddings via the OpenAI Embeddings API.
Reads texts from modelCall.embeddingInput.
Returns vectors in metadata["embeddings"].
"""
try:
model = modelCall.model
texts = modelCall.embeddingInput or []
if not texts:
return AiModelResponse(
content="", success=False, error="No embeddingInput provided"
)
payload = {"model": model.name, "input": texts}
response = await self.httpClient.post(model.apiUrl, json=payload)
if response.status_code != 200:
errorMessage = f"OpenAI Embedding API error: {response.status_code} - {response.text}"
logger.error(errorMessage)
if response.status_code == 429:
raise RateLimitExceededException(f"Rate limit exceeded for {model.name}")
if response.status_code == 400:
try:
errorData = response.json()
errMsg = errorData.get("error", {}).get("message", "").lower()
errCode = errorData.get("error", {}).get("code", "")
if errCode == "context_length_exceeded" or "too many tokens" in errMsg or "maximum context length" in errMsg:
raise ContextLengthExceededException(
f"Embedding context length exceeded for {model.name}: {errorData.get('error', {}).get('message', '')}"
)
except (ValueError, KeyError):
pass
raise HTTPException(status_code=500, detail=errorMessage)
responseJson = response.json()
embeddings = [item["embedding"] for item in responseJson["data"]]
usage = responseJson.get("usage", {})
return AiModelResponse(
content="",
success=True,
modelId=model.name,
tokensUsed={
"input": usage.get("prompt_tokens", 0),
"output": 0,
"total": usage.get("total_tokens", 0),
},
metadata={"embeddings": embeddings},
)
except (RateLimitExceededException, ContextLengthExceededException):
raise
except Exception as e:
logger.error(f"Error calling OpenAI Embedding API: {str(e)}")
raise HTTPException(status_code=500, detail=f"Error calling OpenAI Embedding API: {str(e)}")
async def callAiImage(self, modelCall: AiModelCall) -> AiModelResponse:
"""
Analyzes an image with the OpenAI Vision API, using the standardized pattern.


@@ -288,7 +288,16 @@ class AiTavily(BaseConnectorAi):
if maxResults < minResults or maxResults > maxAllowedResults:
raise ValueError(f"maxResults must be between {minResults} and {maxAllowedResults}")
# Perform actual API call
# Tavily enforces a 400-character query limit
TAVILY_MAX_QUERY_LENGTH = 400
if len(query) > TAVILY_MAX_QUERY_LENGTH:
truncated = query[:TAVILY_MAX_QUERY_LENGTH]
lastSpace = truncated.rfind(' ')
if lastSpace > TAVILY_MAX_QUERY_LENGTH // 2:
truncated = truncated[:lastSpace]
logger.warning(f"Tavily query truncated from {len(query)} to {len(truncated)} chars")
query = truncated
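# A worked example of the truncation rule, with illustrative lengths:
# a query of "word " * 90 (450 chars) is cut to 400, which here lands exactly
# on a space, so rfind(' ') returns 399 (> 200) and the query is trimmed to
# 399 chars, ending on a whole word. A 450-char query with no spaces keeps
# the hard cut at 400.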
# Build kwargs only for provided options to avoid API rejections
kwargs: dict = {"query": query, "max_results": maxResults}
if searchDepth is not None:


@@ -41,6 +41,11 @@ class SystemTable(BaseModel):
)
def _isVectorType(sqlType: str) -> bool:
"""Check if a SQL type string represents a pgvector column."""
return sqlType.upper().startswith("VECTOR")
def _isJsonbType(fieldType) -> bool:
"""Check if a type should be stored as JSONB in PostgreSQL."""
# Direct dict or list
@@ -70,20 +75,26 @@ def _isJsonbType(fieldType) -> bool:
def _get_model_fields(model_class) -> Dict[str, str]:
"""Get all fields from Pydantic model and map to SQL types."""
# Pydantic v2
"""Get all fields from Pydantic model and map to SQL types.
Supports explicit db_type override via json_schema_extra={"db_type": "vector(1536)"}.
This enables pgvector columns without special-casing field names.
"""
model_fields = model_class.model_fields
fields = {}
for field_name, field_info in model_fields.items():
# Pydantic v2
field_type = field_info.annotation
# Explicit db_type override (e.g. vector columns)
extra = field_info.json_schema_extra
if extra and isinstance(extra, dict) and "db_type" in extra:
fields[field_name] = extra["db_type"]
continue
# Check for JSONB fields (Dict, List, or complex types)
# Purely type-based detection - no hardcoded field names
if _isJsonbType(field_type):
fields[field_name] = "JSONB"
# Simple type mapping
elif field_type in (str, type(None)) or (
get_origin(field_type) is Union and type(None) in get_args(field_type)
):
@@ -95,11 +106,45 @@ def _get_model_fields(model_class) -> Dict[str, str]:
elif field_type == bool:
fields[field_name] = "BOOLEAN"
else:
fields[field_name] = "TEXT" # Default to TEXT
fields[field_name] = "TEXT"
return fields
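# A model sketch using the db_type override described above; the model and
# field names are invented for illustration:
#
#     from pydantic import BaseModel, Field
#     from typing import List, Optional
#
#     class DocumentChunk(BaseModel):
#         id: str
#         text: str
#         embedding: Optional[List[float]] = Field(
#             default=None,
#             json_schema_extra={"db_type": "vector(1536)"},  # explicit pgvector column
#         )
#         tags: List[str] = []  # Dict/List types are detected as JSONB
#
# _get_model_fields(DocumentChunk) maps embedding -> "vector(1536)" (the override
# wins before any type-based detection) and tags -> "JSONB".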
def _parseRecordFields(record: Dict[str, Any], fields: Dict[str, str], context: str = "") -> None:
"""Parse record fields in-place: numeric typing, vector parsing, JSONB deserialization."""
import json as _json
for fieldName, fieldType in fields.items():
if fieldName not in record:
continue
value = record[fieldName]
if fieldType in ("DOUBLE PRECISION", "INTEGER") and value is not None:
try:
record[fieldName] = float(value) if fieldType == "DOUBLE PRECISION" else int(value)
except (ValueError, TypeError):
logger.warning(f"Could not convert {fieldName} to {fieldType} ({context}): {value}")
elif _isVectorType(fieldType) and value is not None:
if isinstance(value, str):
try:
record[fieldName] = [float(v) for v in value.strip("[]").split(",")]
except (ValueError, TypeError):
logger.warning(f"Could not parse vector field {fieldName} ({context})")
elif isinstance(value, list):
pass # already a list
elif fieldType == "JSONB" and value is not None:
try:
if isinstance(value, str):
record[fieldName] = _json.loads(value)
elif not isinstance(value, (dict, list)):
record[fieldName] = _json.loads(str(value))
except (_json.JSONDecodeError, TypeError, ValueError):
logger.warning(f"Could not parse JSONB field {fieldName}, keeping as string ({context})")
# Cache connectors by (host, database, port) to avoid duplicate inits for same database.
# Thread safety: _connector_cache_lock protects cache access. userId is request-scoped via
# contextvars to avoid races when concurrent requests share the same connector.
@@ -132,7 +177,7 @@ def _get_cached_connector(
oldest_key = _connector_cache_order.pop(0)
if oldest_key in _connector_cache:
try:
_connector_cache[oldest_key].close()
_connector_cache[oldest_key].close(forceClose=True)
except Exception as e:
logger.warning(f"Error closing evicted connector: {e}")
del _connector_cache[oldest_key]
@@ -144,6 +189,7 @@ def _get_cached_connector(
dbPort=dbPort,
userId=userId,
)
_connector_cache[key]._isCachedShared = True
_connector_cache_order.append(key)
conn = _connector_cache[key]
# Set request-scoped userId via contextvar (avoids mutating shared connector)
@@ -180,6 +226,7 @@ class DatabaseConnector:
# Initialize database system first (creates database if needed)
self.connection = None
self._isCachedShared = False
self.initDbSystem()
# No caching needed with proper database - PostgreSQL handles performance
@@ -187,6 +234,9 @@ class DatabaseConnector:
# Thread safety
self._lock = threading.Lock()
# pgvector extension state
self._vectorExtensionEnabled = False
# Initialize system table
self._systemTableName = "_system"
self._initializeSystemTable()
@@ -500,10 +550,32 @@ class DatabaseConnector:
self.connection.rollback()
return False
def _ensureVectorExtension(self) -> bool:
"""Enable pgvector extension if not already enabled. Called lazily on first vector table."""
if self._vectorExtensionEnabled:
return True
try:
self._ensure_connection()
with self.connection.cursor() as cursor:
cursor.execute("CREATE EXTENSION IF NOT EXISTS vector")
self.connection.commit()
self._vectorExtensionEnabled = True
logger.info("pgvector extension enabled")
return True
except Exception as e:
logger.error(f"Failed to enable pgvector extension: {e}")
if hasattr(self, "connection") and self.connection:
self.connection.rollback()
return False
def _create_table_from_model(self, cursor, table: str, model_class: type) -> None:
"""Create table with columns matching Pydantic model fields."""
fields = _get_model_fields(model_class)
# Enable pgvector if any field uses vector type
if any(_isVectorType(sqlType) for sqlType in fields.values()):
self._ensureVectorExtension()
# Build column definitions with quoted identifiers to preserve exact case
columns = ['"id" VARCHAR(255) PRIMARY KEY']
for field_name, sql_type in fields.items():
@@ -576,28 +648,25 @@ class DatabaseConnector:
elif hasattr(value, "value"):
value = value.value
# Handle vector fields (pgvector) - convert List[float] to string
elif col in fields and _isVectorType(fields[col]) and value is not None:
if isinstance(value, list):
value = f"[{','.join(str(v) for v in value)}]"
# Handle JSONB fields - ensure proper JSON format for PostgreSQL
elif col in fields and fields[col] == "JSONB" and value is not None:
import json
if isinstance(value, (dict, list)):
# Convert Python objects to JSON string for PostgreSQL JSONB
value = json.dumps(value)
elif isinstance(value, str):
# Validate that it's valid JSON, if not, try to parse and re-serialize
try:
# Test if it's already valid JSON
json.loads(value)
# If successful, keep as is
pass
except (json.JSONDecodeError, TypeError):
# If not valid JSON, convert to JSON string
value = json.dumps(value)
elif hasattr(value, 'model_dump'):
# Handle Pydantic models
value = json.dumps(value.model_dump())
else:
# Convert other types to JSON
value = json.dumps(value)
values.append(value)
@@ -635,46 +704,7 @@ class DatabaseConnector:
record = dict(row)
fields = _get_model_fields(model_class)
# Ensure numeric fields are properly typed and parse JSONB fields
for field_name, field_type in fields.items():
# Ensure numeric fields (float/int) are properly typed
# psycopg2 may return them as strings in some environments (e.g., Azure PostgreSQL)
if field_type in ("DOUBLE PRECISION", "INTEGER") and field_name in record:
value = record[field_name]
if value is not None:
try:
if field_type == "DOUBLE PRECISION":
record[field_name] = float(value)
elif field_type == "INTEGER":
record[field_name] = int(value)
except (ValueError, TypeError):
# If conversion fails, log warning but keep original value
logger.warning(
f"Could not convert {field_name} to {field_type} for record {recordId}: {value}"
)
elif (
field_type == "JSONB"
and field_name in record
and record[field_name] is not None
):
import json
try:
if isinstance(record[field_name], str):
# Parse JSON string back to Python object
record[field_name] = json.loads(record[field_name])
elif isinstance(record[field_name], (dict, list)):
# Already a Python object, keep as is
pass
else:
# Try to parse as JSON
record[field_name] = json.loads(str(record[field_name]))
except (json.JSONDecodeError, TypeError, ValueError):
# If parsing fails, keep as string
logger.warning(
f"Could not parse JSONB field {field_name}, keeping as string: {record[field_name]}"
)
pass
_parseRecordFields(record, fields, f"record {recordId}")
return record
except Exception as e:
@@ -737,55 +767,24 @@ class DatabaseConnector:
cursor.execute(f'SELECT * FROM "{table}" ORDER BY "id"')
records = [dict(row) for row in cursor.fetchall()]
# Handle JSONB fields for all records
fields = _get_model_fields(model_class)
model_fields = model_class.model_fields # Get Pydantic model fields
modelFields = model_class.model_fields
for record in records:
for field_name, field_type in fields.items():
if field_type == "JSONB" and field_name in record:
if record[field_name] is None:
# Generic type-based default: List types -> [], Dict types -> {}
# Interfaces handle domain-specific defaults
field_info = model_fields.get(field_name)
if field_info:
field_annotation = field_info.annotation
# Check if it's a List type
if (field_annotation == list or
(hasattr(field_annotation, "__origin__") and
field_annotation.__origin__ is list)):
record[field_name] = []
# Check if it's a Dict type
elif (field_annotation == dict or
(hasattr(field_annotation, "__origin__") and
field_annotation.__origin__ is dict)):
record[field_name] = {}
else:
record[field_name] = None
else:
record[field_name] = None
else:
import json
try:
if isinstance(record[field_name], str):
# Parse JSON string back to Python object
record[field_name] = json.loads(
record[field_name]
)
elif isinstance(record[field_name], (dict, list)):
# Already a Python object, keep as is
pass
else:
# Try to parse as JSON
record[field_name] = json.loads(
str(record[field_name])
)
except (json.JSONDecodeError, TypeError, ValueError):
# If parsing fails, keep as string
logger.warning(
f"Could not parse JSONB field {field_name}, keeping as string: {record[field_name]}"
)
pass
_parseRecordFields(record, fields, f"table {table}")
# Set type-aware defaults for NULL JSONB fields
for fieldName, fieldType in fields.items():
if fieldType == "JSONB" and fieldName in record and record[fieldName] is None:
fieldInfo = modelFields.get(fieldName)
if fieldInfo:
fieldAnnotation = fieldInfo.annotation
if (fieldAnnotation == list or
(hasattr(fieldAnnotation, "__origin__") and
fieldAnnotation.__origin__ is list)):
record[fieldName] = []
elif (fieldAnnotation == dict or
(hasattr(fieldAnnotation, "__origin__") and
fieldAnnotation.__origin__ is dict)):
record[fieldName] = {}
return records
except Exception as e:
@@ -936,70 +935,23 @@ class DatabaseConnector:
cursor.execute(query, where_values)
records = [dict(row) for row in cursor.fetchall()]
# Handle JSONB fields and ensure numeric types are correct
fields = _get_model_fields(model_class)
model_fields = model_class.model_fields # Get Pydantic model fields
modelFields = model_class.model_fields
for record in records:
for field_name, field_type in fields.items():
# Ensure numeric fields (float/int) are properly typed
# psycopg2 may return them as strings in some environments (e.g., Azure PostgreSQL)
if field_type in ("DOUBLE PRECISION", "INTEGER") and field_name in record:
value = record[field_name]
if value is not None:
try:
if field_type == "DOUBLE PRECISION":
record[field_name] = float(value)
elif field_type == "INTEGER":
record[field_name] = int(value)
except (ValueError, TypeError):
# If conversion fails, log warning but keep original value
logger.warning(
f"Could not convert {field_name} to {field_type} for record {record.get('id', 'unknown')}: {value}"
)
elif field_type == "JSONB" and field_name in record:
if record[field_name] is None:
# Generic type-based default: List types -> [], Dict types -> {}
# Interfaces handle domain-specific defaults
field_info = model_fields.get(field_name)
if field_info:
field_annotation = field_info.annotation
# Check if it's a List type
if (field_annotation == list or
(hasattr(field_annotation, "__origin__") and
field_annotation.__origin__ is list)):
record[field_name] = []
# Check if it's a Dict type
elif (field_annotation == dict or
(hasattr(field_annotation, "__origin__") and
field_annotation.__origin__ is dict)):
record[field_name] = {}
else:
record[field_name] = None
else:
record[field_name] = None
else:
import json
try:
if isinstance(record[field_name], str):
# Parse JSON string back to Python object
record[field_name] = json.loads(
record[field_name]
)
elif isinstance(record[field_name], (dict, list)):
# Already a Python object, keep as is
pass
else:
# Try to parse as JSON
record[field_name] = json.loads(
str(record[field_name])
)
except (json.JSONDecodeError, TypeError, ValueError):
# If parsing fails, keep as string
logger.warning(
f"Could not parse JSONB field {field_name}, keeping as string: {record[field_name]}"
)
pass
_parseRecordFields(record, fields, f"table {table}")
for fieldName, fieldType in fields.items():
if fieldType == "JSONB" and fieldName in record and record[fieldName] is None:
fieldInfo = modelFields.get(fieldName)
if fieldInfo:
fieldAnnotation = fieldInfo.annotation
if (fieldAnnotation == list or
(hasattr(fieldAnnotation, "__origin__") and
fieldAnnotation.__origin__ is list)):
record[fieldName] = []
elif (fieldAnnotation == dict or
(hasattr(fieldAnnotation, "__origin__") and
fieldAnnotation.__origin__ is dict)):
record[fieldName] = {}
# If fieldFilter is available, reduce the fields
if fieldFilter and isinstance(fieldFilter, list):
@ -1080,7 +1032,10 @@ class DatabaseConnector:
existingRecord.update(record)
# Save updated record
self._saveRecord(model_class, recordId, existingRecord)
saved = self._saveRecord(model_class, recordId, existingRecord)
if not saved:
table = model_class.__name__
raise ValueError(f"Failed to save record {recordId} to table {table}")
return existingRecord
def recordDelete(self, model_class: type, recordId: str) -> bool:
@ -1127,8 +1082,94 @@ class DatabaseConnector:
initialId = systemData.get(table)
return initialId
def close(self):
"""Close the database connection."""
def semanticSearch(
self,
modelClass: type,
vectorColumn: str,
queryVector: List[float],
limit: int = 10,
recordFilter: Dict[str, Any] = None,
minScore: float = None,
) -> List[Dict[str, Any]]:
"""Semantic search using pgvector cosine distance.
Args:
modelClass: Pydantic model class for the table.
vectorColumn: Name of the vector column to search.
queryVector: Query vector as List[float].
limit: Maximum number of results.
recordFilter: Additional WHERE filters (field: value).
minScore: Minimum cosine similarity (0.0 - 1.0).
Returns:
List of records with an added '_score' field (cosine similarity),
sorted by similarity descending.
"""
table = modelClass.__name__
try:
if not self._ensureTableExists(modelClass):
return []
vectorStr = f"[{','.join(str(v) for v in queryVector)}]"
whereConditions = []
whereValues = []
if recordFilter:
for field, value in recordFilter.items():
if value is None:
whereConditions.append(f'"{field}" IS NULL')
elif isinstance(value, (list, tuple)):
if not value:
whereConditions.append("1 = 0")
else:
whereConditions.append(f'"{field}" = ANY(%s)')
whereValues.append(list(value))
else:
whereConditions.append(f'"{field}" = %s')
whereValues.append(value)
if minScore is not None:
whereConditions.append(
f'1 - ("{vectorColumn}" <=> %s::vector) >= %s'
)
whereValues.extend([vectorStr, minScore])
whereClause = ""
if whereConditions:
whereClause = " WHERE " + " AND ".join(whereConditions)
query = (
f'SELECT *, 1 - ("{vectorColumn}" <=> %s::vector) AS "_score" '
f'FROM "{table}"{whereClause} '
f'ORDER BY "{vectorColumn}" <=> %s::vector '
f'LIMIT %s'
)
params = [vectorStr] + whereValues + [vectorStr, limit]
with self.connection.cursor() as cursor:
cursor.execute(query, params)
records = [dict(row) for row in cursor.fetchall()]
fields = _get_model_fields(modelClass)
for record in records:
_parseRecordFields(record, fields, f"semanticSearch {table}")
return records
except Exception as e:
logger.error(f"Error in semantic search on {table}: {e}")
return []
def close(self, forceClose: bool = False):
"""Close the database connection.
Shared cached connectors are intentionally kept open unless forceClose=True.
This prevents accidental shutdown from interface __del__ methods while
other requests are still using the same cached connector instance.
"""
if self._isCachedShared and not forceClose:
return
if (
hasattr(self, "connection")
and self.connection
@ -1141,5 +1182,4 @@ class DatabaseConnector:
try:
self.close()
except Exception:
# Ignore errors during cleanup
pass
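# Minimal usage sketch for semanticSearch above, assuming a hypothetical
# EmbeddingRecord model plus getConnector()/embedText() helpers that are not
# part of this commit:
def findSimilarRecords(queryText: str, limit: int = 5):
    connector = getConnector()                  # assumed DatabaseConnector factory
    queryVector = embedText(queryText)          # assumed embedding helper -> List[float]
    results = connector.semanticSearch(
        modelClass=EmbeddingRecord,             # assumed Pydantic model with a vector column
        vectorColumn="embedding",
        queryVector=queryVector,
        limit=limit,
        minScore=0.75,                          # cosine similarity floor (0.0 - 1.0)
    )
    # '_score' is added by the query itself; results arrive sorted by similarity
    return [(record["_score"], record["id"]) for record in results]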

View file

@ -0,0 +1,63 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Abstract base classes for the Provider-Connector architecture (1:n).
One ProviderConnector per vendor (e.g. MsftConnector, GoogleConnector).
Each ProviderConnector exposes n ServiceAdapters (e.g. SharepointAdapter, OutlookAdapter).
All ServiceAdapters share the same access token from the UserConnection.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import List, Optional, Union
@dataclass
class DownloadResult:
"""Rich return type for ServiceAdapter.download() when metadata is available."""
data: bytes = field(default=b"", repr=False)
fileName: str = ""
mimeType: str = ""
class ServiceAdapter(ABC):
"""Standardized operations for a single service of a provider."""
@abstractmethod
async def browse(self, path: str, filter: Optional[str] = None) -> list:
"""List items (files/folders) at the given path."""
...
@abstractmethod
async def download(self, path: str) -> Union[bytes, DownloadResult]:
"""Download a file. Return bytes or DownloadResult with metadata."""
...
@abstractmethod
async def upload(self, path: str, data: bytes, fileName: str) -> dict:
"""Upload a file to the given path. Returns metadata of the created entry."""
...
@abstractmethod
async def search(self, query: str, path: Optional[str] = None) -> list:
"""Search for items matching the query."""
...
class ProviderConnector(ABC):
"""One connector per provider. Manages a UserConnection + token.
Provides access to n services of the provider."""
def __init__(self, connection, accessToken: str):
self.connection = connection
self.accessToken = accessToken
@abstractmethod
def getAvailableServices(self) -> List[str]:
"""Which services does this provider offer?"""
...
@abstractmethod
def getServiceAdapter(self, service: str) -> ServiceAdapter:
"""Return the ServiceAdapter for a specific service."""
...
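# A minimal concrete provider, sketched only to illustrate the ABC contract
# above; EchoConnector/EchoAdapter are assumed names, not part of this commit.
class EchoAdapter(ServiceAdapter):
    """Trivial adapter that satisfies the ServiceAdapter interface."""
    async def browse(self, path: str, filter: Optional[str] = None) -> list:
        return []
    async def download(self, path: str) -> bytes:
        return b""
    async def upload(self, path: str, data: bytes, fileName: str) -> dict:
        return {"name": fileName, "path": path}
    async def search(self, query: str, path: Optional[str] = None) -> list:
        return []
class EchoConnector(ProviderConnector):
    def getAvailableServices(self) -> List[str]:
        return ["echo"]
    def getServiceAdapter(self, service: str) -> ServiceAdapter:
        if service != "echo":
            raise ValueError(f"Unknown service: {service}")
        return EchoAdapter()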

View file

@ -0,0 +1,94 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""ConnectorResolver -- resolves a connectionId to the correct ProviderConnector and ServiceAdapter.
Registry maps authority values to ProviderConnector classes.
The resolver loads the UserConnection, obtains a fresh token via SecurityService,
and instantiates the appropriate connector.
"""
import logging
from typing import Dict, Any, Type, Optional
from modules.connectors.connectorProviderBase import ProviderConnector, ServiceAdapter
logger = logging.getLogger(__name__)
class ConnectorResolver:
"""Resolves connectionId → ProviderConnector (with fresh token) → ServiceAdapter."""
_providerRegistry: Dict[str, Type[ProviderConnector]] = {}
def __init__(self, securityService, dbInterface):
"""
Args:
securityService: SecurityService instance (for getFreshToken)
dbInterface: DB interface with getUserConnection(connectionId)
"""
self._security = securityService
self._db = dbInterface
self._ensureRegistered()
def _ensureRegistered(self):
"""Lazy-register known providers on first instantiation."""
if ConnectorResolver._providerRegistry:
return
try:
from modules.connectors.providerMsft.connectorMsft import MsftConnector
ConnectorResolver._providerRegistry["msft"] = MsftConnector
except ImportError:
logger.warning("MsftConnector not available")
try:
from modules.connectors.providerGoogle.connectorGoogle import GoogleConnector
ConnectorResolver._providerRegistry["google"] = GoogleConnector
except ImportError:
logger.debug("GoogleConnector not available (stub)")
try:
from modules.connectors.providerFtp.connectorFtp import FtpConnector
ConnectorResolver._providerRegistry["local:ftp"] = FtpConnector
except ImportError:
logger.debug("FtpConnector not available (stub)")
async def resolve(self, connectionId: str) -> ProviderConnector:
"""Resolve connectionId to a ProviderConnector with a fresh access token."""
connection = await self._loadConnection(connectionId)
if not connection:
raise ValueError(f"UserConnection not found: {connectionId}")
authority = getattr(connection, "authority", None)
if not authority:
raise ValueError(f"Connection {connectionId} has no authority")
authorityStr = authority.value if hasattr(authority, "value") else str(authority)
providerClass = self._providerRegistry.get(authorityStr)
if not providerClass:
raise ValueError(f"No ProviderConnector registered for authority: {authorityStr}")
token = self._security.getFreshToken(connectionId)
if not token or not token.tokenAccess:
raise ValueError(f"No valid token for connection {connectionId}")
return providerClass(connection, token.tokenAccess)
async def resolveService(self, connectionId: str, service: str) -> ServiceAdapter:
"""Resolve connectionId + service name to a concrete ServiceAdapter."""
provider = await self.resolve(connectionId)
available = provider.getAvailableServices()
if service not in available:
raise ValueError(f"Service '{service}' not available. Options: {available}")
return provider.getServiceAdapter(service)
async def _loadConnection(self, connectionId: str) -> Optional[Any]:
"""Load UserConnection from DB."""
try:
if hasattr(self._db, "getUserConnection"):
return self._db.getUserConnection(connectionId)
if hasattr(self._db, "loadRecord"):
from modules.datamodels.datamodelUam import UserConnection
return self._db.loadRecord(UserConnection, connectionId)
except Exception as e:
logger.error(f"Failed to load connection {connectionId}: {e}")
return None
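# Usage sketch, assuming an async caller that already holds a SecurityService
# and DB interface; "sharepoint" matches the MsftConnector registry entry above.
async def listSharepointRoot(securityService, dbInterface, connectionId: str):
    resolver = ConnectorResolver(securityService, dbInterface)
    adapter = await resolver.resolveService(connectionId, "sharepoint")
    return await adapter.browse("/")  # root path triggers site discovery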

View file

@ -9,7 +9,8 @@ import json
import html
import asyncio
import logging
from typing import Dict, Optional, Any, List
import time
from typing import AsyncGenerator, Dict, Optional, Any, List, Tuple
from google.cloud import speech
from google.cloud import translate_v2 as translate
from google.cloud import texttospeech
@ -403,6 +404,155 @@ class ConnectorGoogleSpeech:
"error": str(e)
}
async def streamingRecognize(
self,
audioQueue: asyncio.Queue,
language: str = "de-DE",
phraseHints: Optional[list] = None,
) -> AsyncGenerator[Dict[str, Any], None]:
"""
Stream audio chunks to Google Cloud Speech-to-Text Streaming API.
Google handles silence/endpoint detection natively.
Args:
audioQueue: Queue of (bytes, bool) tuples. bytes=audio data, bool=isLast.
Send (b"", True) to signal end of stream.
language: Language code
phraseHints: Optional boost phrases
Yields:
Dicts with keys: isFinal, transcript, confidence, stabilityScore, audioDurationSec (plus reconnectRequired when the 5-minute stream limit nears, or error on failure)
"""
STREAM_LIMIT_SEC = 290
streamStartTs = time.time()
totalAudioBytes = 0
configParams = {
"encoding": speech.RecognitionConfig.AudioEncoding.WEBM_OPUS,
"sample_rate_hertz": 48000,
"audio_channel_count": 1,
"language_code": language,
"enable_automatic_punctuation": True,
"model": "latest_long",
"use_enhanced": True,
}
if phraseHints:
configParams["speech_contexts"] = [speech.SpeechContext(phrases=phraseHints, boost=15.0)]
recognitionConfig = speech.RecognitionConfig(**configParams)
streamingConfig = speech.StreamingRecognitionConfig(
config=recognitionConfig,
interim_results=True,
single_utterance=False,
)
import queue as threadQueue
audioInQ: threadQueue.Queue = threadQueue.Queue()
resultOutQ: asyncio.Queue = asyncio.Queue()
async def _pumpAudioToThread():
try:
while True:
item = await audioQueue.get()
audioInQ.put(item)
if item[1]:
return
except asyncio.CancelledError:
audioInQ.put((b"", True))
def _requestGenerator():
nonlocal totalAudioBytes
while True:
try:
chunk, isLast = audioInQ.get(timeout=30.0)
except threadQueue.Empty:
return
if isLast or not chunk:
return
totalAudioBytes += len(chunk)
yield speech.StreamingRecognizeRequest(audio_content=chunk)
def _runStreamingInThread():
try:
responseStream = self.speech_client.streaming_recognize(
config=streamingConfig,
requests=_requestGenerator(),
)
for response in responseStream:
elapsed = time.time() - streamStartTs
# Rough estimate: treats the compressed Opus byte count as 16-bit mono PCM at 48 kHz
estimatedDurationSec = totalAudioBytes / (48000 * 1 * 2) if totalAudioBytes else 0
finalTexts = []
interimTexts = []
lastFinalConfidence = 0.0
for result in response.results:
alt = result.alternatives[0] if result.alternatives else None
if not alt or not alt.transcript.strip():
continue
if result.is_final:
finalTexts.append(alt.transcript.strip())
lastFinalConfidence = alt.confidence
else:
interimTexts.append(alt.transcript.strip())
for ft in finalTexts:
asyncio.run_coroutine_threadsafe(resultOutQ.put({
"isFinal": True,
"transcript": ft,
"confidence": lastFinalConfidence,
"stabilityScore": 0.0,
"audioDurationSec": estimatedDurationSec,
}), loop)
if interimTexts:
combined = " ".join(interimTexts)
asyncio.run_coroutine_threadsafe(resultOutQ.put({
"isFinal": False,
"transcript": combined,
"confidence": 0.0,
"stabilityScore": 0.0,
"audioDurationSec": estimatedDurationSec,
}), loop)
if elapsed >= STREAM_LIMIT_SEC:
logger.info("Streaming STT approaching 5-min limit, client should reconnect")
asyncio.run_coroutine_threadsafe(resultOutQ.put({
"isFinal": False, "transcript": "", "confidence": 0.0,
"reconnectRequired": True, "audioDurationSec": 0,
}), loop)
return
except Exception as e:
logger.error(f"Google Streaming STT error: {e}")
asyncio.run_coroutine_threadsafe(resultOutQ.put({
"error": str(e),
}), loop)
finally:
asyncio.run_coroutine_threadsafe(resultOutQ.put(None), loop)
loop = asyncio.get_running_loop()
pumpTask = asyncio.ensure_future(_pumpAudioToThread())
streamFuture = loop.run_in_executor(None, _runStreamingInThread)
try:
while True:
item = await resultOutQ.get()
if item is None:
break
if "error" in item:
raise RuntimeError(item["error"])
yield item
finally:
pumpTask.cancel()
await asyncio.shield(streamFuture)
def calculateSttCostCHF(self, audioDurationSec: float) -> float:
"""Google STT cost: ~$0.016/min (standard model)."""
return round((audioDurationSec / 60.0) * 0.016, 8)
def calculateTtsCostCHF(self, characterCount: int) -> float:
"""Google TTS WaveNet cost: ~$0.000004/char."""
return round(characterCount * 0.000004, 8)
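# Worked example for the two helpers above, assuming the USD list prices are
# billed 1:1 as CHF (which is what the formulas do): a 10-minute recording
# costs (600 / 60) * 0.016 = 0.16, and a 5,000-character TTS response costs
# 5000 * 0.000004 = 0.02.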
async def translateText(self, text: str, targetLanguage: str = "en",
sourceLanguage: str = "de") -> Dict:
"""

View file

@ -1,4 +1,3 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""FTP/SFTP Provider Connector stub."""

View file

@ -0,0 +1,48 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""FTP/SFTP ProviderConnector stub.
Implements the ProviderConnector interface for FTP/SFTP file access.
The full implementation will follow once FTP integration is prioritized.
"""
import logging
from typing import List, Optional
from modules.connectors.connectorProviderBase import ProviderConnector, ServiceAdapter
from modules.datamodels.datamodelDataSource import ExternalEntry
logger = logging.getLogger(__name__)
class FtpFilesAdapter(ServiceAdapter):
"""FTP files ServiceAdapter (stub)."""
def __init__(self, accessToken: str):
self._accessToken = accessToken
async def browse(self, path: str, filter: Optional[str] = None) -> List[ExternalEntry]:
logger.info(f"FTP browse stub: {path}")
return []
async def download(self, path: str) -> bytes:
logger.info(f"FTP download stub: {path}")
return b""
async def upload(self, path: str, data: bytes, fileName: str) -> dict:
return {"error": "FTP upload not yet implemented"}
async def search(self, query: str, path: Optional[str] = None) -> List[ExternalEntry]:
return []
class FtpConnector(ProviderConnector):
"""FTP ProviderConnector -- 1 connection -> files."""
def getAvailableServices(self) -> List[str]:
return ["files"]
def getServiceAdapter(self, service: str) -> ServiceAdapter:
if service != "files":
raise ValueError(f"FTP only supports 'files' service, got '{service}'")
return FtpFilesAdapter(self.accessToken)

View file

@ -0,0 +1,3 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Google Provider Connector -- 1 Connection : n Services (Drive, Gmail)."""

View file

@ -0,0 +1,265 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Google ProviderConnector -- Drive and Gmail via Google OAuth."""
import logging
from typing import Any, Dict, List, Optional
import aiohttp
from modules.connectors.connectorProviderBase import ProviderConnector, ServiceAdapter, DownloadResult
from modules.datamodels.datamodelDataSource import ExternalEntry
logger = logging.getLogger(__name__)
_DRIVE_BASE = "https://www.googleapis.com/drive/v3"
_GMAIL_BASE = "https://gmail.googleapis.com/gmail/v1"
async def _googleGet(token: str, url: str) -> Dict[str, Any]:
headers = {"Authorization": f"Bearer {token}"}
timeout = aiohttp.ClientTimeout(total=20)
try:
async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.get(url, headers=headers) as resp:
if resp.status in (200, 201):
return await resp.json()
errorText = await resp.text()
logger.warning(f"Google API {resp.status}: {errorText[:300]}")
return {"error": f"{resp.status}: {errorText[:200]}"}
except Exception as e:
return {"error": str(e)}
class DriveAdapter(ServiceAdapter):
"""Google Drive ServiceAdapter -- browse files and folders."""
def __init__(self, accessToken: str):
self._token = accessToken
async def browse(self, path: str, filter: Optional[str] = None) -> List[ExternalEntry]:
folderId = (path or "").strip("/") or "root"
query = f"'{folderId}' in parents and trashed=false"
fields = "files(id,name,mimeType,size,modifiedTime,parents)"
url = f"{_DRIVE_BASE}/files?q={query}&fields={fields}&pageSize=100&orderBy=folder,name"
result = await _googleGet(self._token, url)
if "error" in result:
logger.warning(f"Google Drive browse failed: {result['error']}")
return []
entries = []
for f in result.get("files", []):
isFolder = f.get("mimeType") == "application/vnd.google-apps.folder"
entries.append(ExternalEntry(
name=f.get("name", ""),
path=f"/{f.get('id', '')}",
isFolder=isFolder,
size=int(f.get("size", 0)) if f.get("size") else None,
mimeType=f.get("mimeType") if not isFolder else None,
metadata={"id": f.get("id"), "modifiedTime": f.get("modifiedTime")},
))
return entries
_EXPORT_MIME_MAP = {
"application/vnd.google-apps.document": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"application/vnd.google-apps.spreadsheet": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
"application/vnd.google-apps.presentation": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
"application/vnd.google-apps.drawing": "application/pdf",
}
async def download(self, path: str) -> bytes:
fileId = (path or "").strip("/")
if not fileId:
return b""
headers = {"Authorization": f"Bearer {self._token}"}
timeout = aiohttp.ClientTimeout(total=60)
try:
async with aiohttp.ClientSession(timeout=timeout) as session:
# Try direct download first
url = f"{_DRIVE_BASE}/files/{fileId}?alt=media"
async with session.get(url, headers=headers) as resp:
if resp.status == 200:
return await resp.read()
logger.debug(f"Google Drive direct download returned {resp.status} for {fileId}")
# If 403/404, check if it's a native Google file that needs export
metaUrl = f"{_DRIVE_BASE}/files/{fileId}?fields=mimeType,name"
async with session.get(metaUrl, headers=headers) as metaResp:
if metaResp.status != 200:
logger.warning(f"Google Drive metadata fetch failed ({metaResp.status}) for {fileId}")
return b""
meta = await metaResp.json()
fileMime = meta.get("mimeType", "")
fileName = meta.get("name", fileId)
exportMime = self._EXPORT_MIME_MAP.get(fileMime)
if not exportMime:
logger.warning(f"Google Drive: unsupported mimeType '{fileMime}' for file '{fileName}' ({fileId})")
return b""
exportUrl = f"{_DRIVE_BASE}/files/{fileId}/export?mimeType={exportMime}"
logger.info(f"Google Drive: exporting '{fileName}' as {exportMime}")
async with session.get(exportUrl, headers=headers) as exportResp:
if exportResp.status == 200:
return await exportResp.read()
logger.warning(f"Google Drive export failed ({exportResp.status}) for '{fileName}'")
except Exception as e:
logger.error(f"Google Drive download failed for {fileId}: {e}")
return b""
async def upload(self, path: str, data: bytes, fileName: str) -> dict:
return {"error": "Google Drive upload not yet implemented"}
async def search(self, query: str, path: Optional[str] = None) -> List[ExternalEntry]:
safeQuery = query.replace("'", "\\'")
folderId = (path or "").strip("/")
qParts = [f"name contains '{safeQuery}'", "trashed=false"]
if folderId:
qParts.append(f"'{folderId}' in parents")
qStr = " and ".join(qParts)
url = f"{_DRIVE_BASE}/files?q={qStr}&fields=files(id,name,mimeType,size)&pageSize=25"
logger.debug(f"Google Drive search: q={qStr}")
result = await _googleGet(self._token, url)
if "error" in result:
return []
return [
ExternalEntry(
name=f.get("name", ""),
path=f"/{f.get('id', '')}",
isFolder=f.get("mimeType") == "application/vnd.google-apps.folder",
size=int(f.get("size", 0)) if f.get("size") else None,
)
for f in result.get("files", [])
]
class GmailAdapter(ServiceAdapter):
"""Gmail ServiceAdapter -- browse labels and messages."""
def __init__(self, accessToken: str):
self._token = accessToken
async def browse(self, path: str, filter: Optional[str] = None) -> list:
cleanPath = (path or "").strip("/")
if not cleanPath:
url = f"{_GMAIL_BASE}/users/me/labels"
result = await _googleGet(self._token, url)
if "error" in result:
logger.warning(f"Gmail labels failed: {result['error']}")
return []
_SYSTEM_LABELS = {"INBOX", "SENT", "DRAFT", "TRASH", "SPAM", "STARRED", "IMPORTANT"}
labels = []
for lbl in result.get("labels", []):
labelId = lbl.get("id", "")
labelName = lbl.get("name", labelId)
if lbl.get("type") == "system" and labelId not in _SYSTEM_LABELS:
continue
labels.append(ExternalEntry(
name=labelName,
path=f"/{labelId}",
isFolder=True,
metadata={"id": labelId, "type": lbl.get("type", "")},
))
labels.sort(key=lambda e: (0 if e.metadata.get("type") == "system" else 1, e.name))
return labels
url = f"{_GMAIL_BASE}/users/me/messages?labelIds={cleanPath}&maxResults=25"
result = await _googleGet(self._token, url)
if "error" in result:
return []
entries = []
for msg in result.get("messages", [])[:25]:
msgId = msg.get("id", "")
detailUrl = f"{_GMAIL_BASE}/users/me/messages/{msgId}?format=metadata&metadataHeaders=Subject&metadataHeaders=From&metadataHeaders=Date"
detail = await _googleGet(self._token, detailUrl)
if "error" in detail:
entries.append(ExternalEntry(name=f"Message {msgId}", path=f"/{cleanPath}/{msgId}", isFolder=False))
continue
headers = {h.get("name", ""): h.get("value", "") for h in detail.get("payload", {}).get("headers", [])}
entries.append(ExternalEntry(
name=headers.get("Subject", "(no subject)"),
path=f"/{cleanPath}/{msgId}",
isFolder=False,
metadata={
"id": msgId,
"from": headers.get("From", ""),
"date": headers.get("Date", ""),
"snippet": detail.get("snippet", ""),
},
))
return entries
async def download(self, path: str) -> DownloadResult:
"""Download a Gmail message as RFC 822 EML via format=raw."""
import base64
import re
cleanPath = (path or "").strip("/")
msgId = cleanPath.split("/")[-1] if cleanPath else ""
if not msgId:
return DownloadResult()
url = f"{_GMAIL_BASE}/users/me/messages/{msgId}?format=raw"
result = await _googleGet(self._token, url)
if "error" in result:
return DownloadResult()
rawB64 = result.get("raw", "")
if not rawB64:
return DownloadResult()
emlBytes = base64.urlsafe_b64decode(rawB64)
metaUrl = f"{_GMAIL_BASE}/users/me/messages/{msgId}?format=metadata&metadataHeaders=Subject"
meta = await _googleGet(self._token, metaUrl)
subject = msgId
if "error" not in meta:
for h in meta.get("payload", {}).get("headers", []):
if h.get("name", "").lower() == "subject":
subject = h.get("value", msgId)
break
safeName = re.sub(r'[<>:"/\\|?*\x00-\x1f]', "_", subject)[:80].strip(". ") or "email"
return DownloadResult(
data=emlBytes,
fileName=f"{safeName}.eml",
mimeType="message/rfc822",
)
async def upload(self, path: str, data: bytes, fileName: str) -> dict:
return {"error": "Gmail upload not applicable"}
async def search(self, query: str, path: Optional[str] = None) -> list:
url = f"{_GMAIL_BASE}/users/me/messages?q={query}&maxResults=10"
result = await _googleGet(self._token, url)
if "error" in result:
return []
return [
ExternalEntry(
name=f"Message {m.get('id', '')}",
path=f"/{m.get('id', '')}",
isFolder=False,
metadata={"id": m.get("id")},
)
for m in result.get("messages", [])
]
class GoogleConnector(ProviderConnector):
"""Google ProviderConnector -- 1 connection -> Drive + Gmail."""
_SERVICE_MAP = {
"drive": DriveAdapter,
"gmail": GmailAdapter,
}
def getAvailableServices(self) -> List[str]:
return list(self._SERVICE_MAP.keys())
def getServiceAdapter(self, service: str) -> ServiceAdapter:
adapterClass = self._SERVICE_MAP.get(service)
if not adapterClass:
raise ValueError(f"Unknown Google service: {service}. Available: {list(self._SERVICE_MAP.keys())}")
return adapterClass(self.accessToken)
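# Usage sketch, assuming a loaded UserConnection and a fresh OAuth token:
# one GoogleConnector hands out both Drive and Gmail adapters.
async def demoGoogle(connection, accessToken: str):
    google = GoogleConnector(connection, accessToken)
    drive = google.getServiceAdapter("drive")
    rootEntries = await drive.browse("/")      # files/folders under Drive root
    gmail = google.getServiceAdapter("gmail")
    matches = await gmail.search("invoice")    # message stubs matching the query
    return rootEntries, matches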

View file

@ -0,0 +1,3 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Microsoft Provider Connector -- 1 Connection : n Services (SharePoint, Outlook, Teams, OneDrive)."""

View file

@ -0,0 +1,469 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Microsoft ProviderConnector -- one MSFT connection serves SharePoint, Outlook, Teams, OneDrive.
All ServiceAdapters share the same OAuth access token obtained from the
UserConnection (authority=msft).
"""
import logging
import aiohttp
import asyncio
from typing import Dict, Any, List, Optional
from modules.connectors.connectorProviderBase import ProviderConnector, ServiceAdapter, DownloadResult
from modules.datamodels.datamodelDataSource import ExternalEntry
logger = logging.getLogger(__name__)
_GRAPH_BASE = "https://graph.microsoft.com/v1.0"
class _GraphApiMixin:
"""Shared Graph API call logic for all MSFT service adapters."""
def __init__(self, accessToken: str):
self._accessToken = accessToken
async def _graphGet(self, endpoint: str) -> Dict[str, Any]:
return await _makeGraphCall(self._accessToken, endpoint, "GET")
async def _graphPost(self, endpoint: str, data: Any = None) -> Dict[str, Any]:
return await _makeGraphCall(self._accessToken, endpoint, "POST", data)
async def _graphPut(self, endpoint: str, data: bytes = None) -> Dict[str, Any]:
return await _makeGraphCall(self._accessToken, endpoint, "PUT", data)
async def _graphDelete(self, endpoint: str) -> Dict[str, Any]:
return await _makeGraphCall(self._accessToken, endpoint, "DELETE")
async def _graphDownload(self, endpoint: str) -> Optional[bytes]:
"""Download binary content from Graph API."""
headers = {"Authorization": f"Bearer {self._accessToken}"}
timeout = aiohttp.ClientTimeout(total=60)
url = f"{_GRAPH_BASE}/{endpoint.lstrip('/')}"
try:
async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.get(url, headers=headers) as resp:
if resp.status == 200:
return await resp.read()
logger.error(f"Download failed {resp.status}: {await resp.text()}")
return None
except Exception as e:
logger.error(f"Graph download error: {e}")
return None
async def _makeGraphCall(
token: str, endpoint: str, method: str = "GET", data: Any = None
) -> Dict[str, Any]:
"""Execute a single Microsoft Graph API call."""
url = f"{_GRAPH_BASE}/{endpoint.lstrip('/')}"
contentType = "application/json"
if method == "PUT" and isinstance(data, bytes):
contentType = "application/octet-stream"
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": contentType,
}
timeout = aiohttp.ClientTimeout(total=30)
try:
async with aiohttp.ClientSession(timeout=timeout) as session:
kwargs: Dict[str, Any] = {"headers": headers}
if data is not None:
kwargs["data"] = data
if method == "GET":
async with session.get(url, **kwargs) as resp:
return await _handleResponse(resp)
elif method == "POST":
async with session.post(url, **kwargs) as resp:
return await _handleResponse(resp)
elif method == "PUT":
async with session.put(url, **kwargs) as resp:
return await _handleResponse(resp)
elif method == "DELETE":
async with session.delete(url, **kwargs) as resp:
if resp.status in (200, 204):
return {}
return await _handleResponse(resp)
except asyncio.TimeoutError:
return {"error": f"Graph API timeout: {endpoint}"}
except Exception as e:
return {"error": f"Graph API error: {e}"}
return {"error": f"Unsupported method: {method}"}
async def _handleResponse(resp: aiohttp.ClientResponse) -> Dict[str, Any]:
if resp.status in (200, 201):
return await resp.json()
errorText = await resp.text()
logger.error(f"Graph API {resp.status}: {errorText}")
return {"error": f"{resp.status}: {errorText}"}
def _graphItemToExternalEntry(item: Dict[str, Any], basePath: str = "") -> ExternalEntry:
isFolder = "folder" in item
return ExternalEntry(
name=item.get("name", ""),
path=f"{basePath}/{item.get('name', '')}" if basePath else item.get("name", ""),
isFolder=isFolder,
size=item.get("size"),
mimeType=item.get("file", {}).get("mimeType") if not isFolder else None,
lastModified=None,
metadata={
"id": item.get("id"),
"webUrl": item.get("webUrl"),
"childCount": item.get("folder", {}).get("childCount") if isFolder else None,
},
)
# ---------------------------------------------------------------------------
# SharePoint Adapter
# ---------------------------------------------------------------------------
class SharepointAdapter(_GraphApiMixin, ServiceAdapter):
"""ServiceAdapter for SharePoint (files, sites) via Microsoft Graph."""
async def browse(self, path: str, filter: Optional[str] = None) -> List[ExternalEntry]:
"""List items in a SharePoint folder.
Path format: /sites/<SiteName>/<FolderPath>
Root "/" lists available sites via discovery.
"""
if not path or path == "/":
return await self._discoverSites()
siteId, folderPath = _parseSharepointPath(path)
if not siteId:
return await self._discoverSites()
if not folderPath or folderPath == "/":
endpoint = f"sites/{siteId}/drive/root/children"
else:
cleanPath = folderPath.lstrip("/")
endpoint = f"sites/{siteId}/drive/root:/{cleanPath}:/children"
result = await self._graphGet(endpoint)
if "error" in result:
logger.warning(f"SharePoint browse failed: {result['error']}")
return []
entries = [_graphItemToExternalEntry(item, path) for item in result.get("value", [])]
if filter:
entries = [e for e in entries if _matchFilter(e, filter)]
return entries
async def _discoverSites(self) -> List[ExternalEntry]:
"""Discover accessible SharePoint sites."""
result = await self._graphGet("sites?search=*&$top=50")
if "error" in result:
logger.warning(f"SharePoint site discovery failed: {result['error']}")
return []
return [
ExternalEntry(
name=s.get("displayName") or s.get("name", ""),
path=f"/sites/{s.get('id', '')}",
isFolder=True,
metadata={
"id": s.get("id"),
"webUrl": s.get("webUrl"),
"description": s.get("description", ""),
},
)
for s in result.get("value", [])
if s.get("displayName")
]
async def download(self, path: str) -> bytes:
siteId, filePath = _parseSharepointPath(path)
if not siteId or not filePath:
return b""
cleanPath = filePath.strip("/")
endpoint = f"sites/{siteId}/drive/root:/{cleanPath}:/content"
data = await self._graphDownload(endpoint)
return data or b""
async def upload(self, path: str, data: bytes, fileName: str) -> dict:
siteId, folderPath = _parseSharepointPath(path)
if not siteId:
return {"error": "Invalid SharePoint path"}
cleanFolder = (folderPath or "").strip("/")
uploadPath = f"{cleanFolder}/{fileName}" if cleanFolder else fileName
endpoint = f"sites/{siteId}/drive/root:/{uploadPath}:/content"
result = await self._graphPut(endpoint, data)
return result
async def search(self, query: str, path: Optional[str] = None) -> List[ExternalEntry]:
siteId, _ = _parseSharepointPath(path or "")
if not siteId:
return []
safeQuery = query.replace("'", "''")
endpoint = f"sites/{siteId}/drive/root/search(q='{safeQuery}')"
result = await self._graphGet(endpoint)
if "error" in result:
return []
return [_graphItemToExternalEntry(item) for item in result.get("value", [])]
# ---------------------------------------------------------------------------
# Outlook Adapter
# ---------------------------------------------------------------------------
class OutlookAdapter(_GraphApiMixin, ServiceAdapter):
"""ServiceAdapter for Outlook (mail, calendar) via Microsoft Graph."""
async def browse(self, path: str, filter: Optional[str] = None) -> List[ExternalEntry]:
"""List mail folders or messages.
path = "" or "/" list mail folders
path = "/Inbox" list messages in Inbox
"""
if not path or path == "/":
result = await self._graphGet("me/mailFolders")
if "error" in result:
return []
return [
ExternalEntry(
name=f.get("displayName", ""),
path=f"/{f.get('id', '')}",
isFolder=True,
metadata={"id": f.get("id"), "totalItemCount": f.get("totalItemCount")},
)
for f in result.get("value", [])
]
folderId = path.strip("/")
endpoint = f"me/mailFolders/{folderId}/messages?$top=25&$orderby=receivedDateTime desc"
result = await self._graphGet(endpoint)
if "error" in result:
return []
return [
ExternalEntry(
name=m.get("subject", "(no subject)"),
path=f"{path}/{m.get('id', '')}",
isFolder=False,
metadata={
"id": m.get("id"),
"from": m.get("from", {}).get("emailAddress", {}).get("address"),
"receivedDateTime": m.get("receivedDateTime"),
"hasAttachments": m.get("hasAttachments", False),
},
)
for m in result.get("value", [])
]
async def download(self, path: str) -> DownloadResult:
"""Download a mail message as RFC 822 EML via Graph API $value endpoint."""
import re
messageId = path.strip("/").split("/")[-1]
meta = await self._graphGet(f"me/messages/{messageId}?$select=subject")
subject = meta.get("subject", messageId) if "error" not in meta else messageId
safeName = re.sub(r'[<>:"/\\|?*\x00-\x1f]', "_", subject)[:80].strip(". ") or "email"
emlBytes = await self._graphDownload(f"me/messages/{messageId}/$value")
if not emlBytes:
return DownloadResult()
return DownloadResult(
data=emlBytes,
fileName=f"{safeName}.eml",
mimeType="message/rfc822",
)
async def upload(self, path: str, data: bytes, fileName: str) -> dict:
"""Not applicable for Outlook in the file sense."""
return {"error": "Upload not supported for Outlook"}
async def search(self, query: str, path: Optional[str] = None) -> List[ExternalEntry]:
safeQuery = query.replace("'", "''")
endpoint = f"me/messages?$search=\"{safeQuery}\"&$top=25"
result = await self._graphGet(endpoint)
if "error" in result:
return []
return [
ExternalEntry(
name=m.get("subject", "(no subject)"),
path=f"/search/{m.get('id', '')}",
isFolder=False,
metadata={
"id": m.get("id"),
"from": m.get("from", {}).get("emailAddress", {}).get("address"),
"receivedDateTime": m.get("receivedDateTime"),
},
)
for m in result.get("value", [])
]
async def sendMail(
self, to: List[str], subject: str, body: str,
cc: Optional[List[str]] = None, attachments: Optional[List[Dict]] = None
) -> Dict[str, Any]:
"""Send an email via Microsoft Graph."""
import json
message: Dict[str, Any] = {
"subject": subject,
"body": {"contentType": "Text", "content": body},
"toRecipients": [{"emailAddress": {"address": addr}} for addr in to],
}
if cc:
message["ccRecipients"] = [{"emailAddress": {"address": addr}} for addr in cc]
payload = json.dumps({"message": message, "saveToSentItems": True}).encode("utf-8")
result = await self._graphPost("me/sendMail", payload)
if "error" in result:
return result
return {"success": True}
# ---------------------------------------------------------------------------
# Teams Adapter (Stub)
# ---------------------------------------------------------------------------
class TeamsAdapter(_GraphApiMixin, ServiceAdapter):
"""ServiceAdapter for Microsoft Teams -- browse joined teams and channels."""
async def browse(self, path: str, filter: Optional[str] = None) -> list:
cleanPath = (path or "").strip("/")
if not cleanPath:
result = await self._graphGet("me/joinedTeams")
if "error" in result:
logger.warning(f"Teams browse failed: {result['error']}")
return []
return [
ExternalEntry(
name=t.get("displayName", ""),
path=f"/{t.get('id', '')}",
isFolder=True,
metadata={"id": t.get("id"), "description": t.get("description", "")},
)
for t in result.get("value", [])
]
parts = cleanPath.split("/", 1)
teamId = parts[0]
if len(parts) == 1:
result = await self._graphGet(f"teams/{teamId}/channels")
if "error" in result:
return []
return [
ExternalEntry(
name=ch.get("displayName", ""),
path=f"/{teamId}/{ch.get('id', '')}",
isFolder=True,
metadata={"id": ch.get("id"), "membershipType": ch.get("membershipType", "")},
)
for ch in result.get("value", [])
]
return []
async def download(self, path: str) -> bytes:
return b""
async def upload(self, path: str, data: bytes, fileName: str) -> dict:
return {"error": "Teams upload not implemented"}
async def search(self, query: str, path: Optional[str] = None) -> list:
return []
# ---------------------------------------------------------------------------
# OneDrive Adapter (personal drive -- mirrors the SharePoint adapter against me/drive)
# ---------------------------------------------------------------------------
class OneDriveAdapter(_GraphApiMixin, ServiceAdapter):
"""ServiceAdapter stub for OneDrive (personal drive)."""
async def browse(self, path: str, filter: Optional[str] = None) -> List[ExternalEntry]:
cleanPath = (path or "").strip("/")
if not cleanPath:
endpoint = "me/drive/root/children"
else:
endpoint = f"me/drive/root:/{cleanPath}:/children"
result = await self._graphGet(endpoint)
if "error" in result:
return []
entries = [_graphItemToExternalEntry(item, path) for item in result.get("value", [])]
if filter:
entries = [e for e in entries if _matchFilter(e, filter)]
return entries
async def download(self, path: str) -> bytes:
cleanPath = (path or "").strip("/")
if not cleanPath:
return b""
data = await self._graphDownload(f"me/drive/root:/{cleanPath}:/content")
return data or b""
async def upload(self, path: str, data: bytes, fileName: str) -> dict:
cleanPath = (path or "").strip("/")
uploadPath = f"{cleanPath}/{fileName}" if cleanPath else fileName
endpoint = f"me/drive/root:/{uploadPath}:/content"
return await self._graphPut(endpoint, data)
async def search(self, query: str, path: Optional[str] = None) -> List[ExternalEntry]:
safeQuery = query.replace("'", "''")
endpoint = f"me/drive/root/search(q='{safeQuery}')"
result = await self._graphGet(endpoint)
if "error" in result:
return []
return [_graphItemToExternalEntry(item) for item in result.get("value", [])]
# ---------------------------------------------------------------------------
# MsftConnector (1:n)
# ---------------------------------------------------------------------------
class MsftConnector(ProviderConnector):
"""Microsoft ProviderConnector -- 1 connection → n services."""
_SERVICE_MAP = {
"sharepoint": SharepointAdapter,
"outlook": OutlookAdapter,
"teams": TeamsAdapter,
"onedrive": OneDriveAdapter,
}
def getAvailableServices(self) -> List[str]:
return list(self._SERVICE_MAP.keys())
def getServiceAdapter(self, service: str) -> ServiceAdapter:
adapterClass = self._SERVICE_MAP.get(service)
if not adapterClass:
raise ValueError(f"Unknown MSFT service: {service}. Available: {list(self._SERVICE_MAP.keys())}")
return adapterClass(self.accessToken)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _parseSharepointPath(path: str) -> tuple:
"""Parse a SharePoint path into (siteId, innerPath).
Expected format: /sites/<siteId>/<innerPath>
Also accepts bare siteId if no /sites/ prefix.
"""
if not path:
return ("", "")
clean = path.strip("/")
if clean.startswith("sites/"):
parts = clean.split("/", 2)
siteId = parts[1] if len(parts) > 1 else ""
innerPath = parts[2] if len(parts) > 2 else ""
return (siteId, innerPath)
parts = clean.split("/", 1)
return (parts[0], parts[1] if len(parts) > 1 else "")
def _matchFilter(entry: ExternalEntry, pattern: str) -> bool:
"""Simple glob-like filter (supports * wildcard)."""
import fnmatch
return fnmatch.fnmatch(entry.name.lower(), pattern.lower())
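# Quick illustration of the splits _parseSharepointPath produces on the
# documented path format (the site id shown is a made-up example):
assert _parseSharepointPath("/sites/contoso,abc-123/Docs/Reports") == ("contoso,abc-123", "Docs/Reports")
assert _parseSharepointPath("contoso,abc-123") == ("contoso,abc-123", "")
assert _parseSharepointPath("") == ("", "")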

View file

@ -26,6 +26,12 @@ class OperationTypeEnum(str, Enum):
WEB_SEARCH_DATA = "webSearch" # Returns list of URLs only
WEB_CRAWL = "webCrawl" # Web crawl for a given URL
# Agent Operations
AGENT = "agent" # Agent loop: reasoning + tool use
# Embedding Operations
EMBEDDING = "embedding" # Text → vector conversion for semantic search
# Speech Operations (dedicated pipeline, bypasses standard model selection)
SPEECH_TEAMS = "speechTeams" # Teams Meeting AI analysis: decide if/how to respond
@ -102,6 +108,7 @@ class AiModel(BaseModel):
# Function reference (not serialized)
functionCall: Optional[Callable] = Field(default=None, exclude=True, description="Function to call for this model")
functionCallStream: Optional[Callable] = Field(default=None, exclude=True, description="Streaming function: yields str deltas, then final AiModelResponse")
calculatepriceCHF: Optional[Callable] = Field(default=None, exclude=True, description="Function to calculate price in CHF")
# Selection criteria - capabilities with ratings
@ -155,10 +162,12 @@ class AiCallOptions(BaseModel):
class AiCallRequest(BaseModel):
"""Centralized AI call request payload for interface use."""
prompt: str = Field(description="The user prompt")
prompt: str = Field(default="", description="The user prompt")
context: Optional[str] = Field(default=None, description="Optional external context (e.g., extracted docs)")
options: AiCallOptions = Field(default_factory=AiCallOptions)
contentParts: Optional[List['ContentPart']] = None # NEW: Content parts for model-aware chunking
contentParts: Optional[List['ContentPart']] = None # Content parts for model-aware chunking
messages: Optional[List[Dict[str, Any]]] = Field(default=None, description="OpenAI-style messages for multi-turn agent conversations")
tools: Optional[List[Dict[str, Any]]] = Field(default=None, description="Tool definitions for native function calling")
class AiCallResponse(BaseModel):
@ -172,14 +181,19 @@ class AiCallResponse(BaseModel):
bytesSent: int = Field(default=0, description="Input data size in bytes")
bytesReceived: int = Field(default=0, description="Output data size in bytes")
errorCount: int = Field(default=0, description="0 for success, 1+ for errors")
toolCalls: Optional[List[Dict[str, Any]]] = Field(default=None, description="Tool calls from native function calling")
metadata: Optional[Dict[str, Any]] = Field(default=None, description="Additional response metadata (e.g. embeddings vectors)")
class AiModelCall(BaseModel):
"""Standardized input for AI model calls."""
messages: List[Dict[str, Any]] = Field(description="Messages in OpenAI format (role, content)")
messages: List[Dict[str, Any]] = Field(default_factory=list, description="Messages in OpenAI format (role, content)")
model: Optional[AiModel] = Field(default=None, description="The AI model being called")
options: AiCallOptions = Field(default_factory=AiCallOptions, description="Additional model-specific options")
tools: Optional[List[Dict[str, Any]]] = Field(default=None, description="Tool definitions for native function calling")
toolChoice: Optional[Any] = Field(default=None, description="Tool choice: 'auto', 'none', or specific tool")
embeddingInput: Optional[List[str]] = Field(default=None, description="Input texts for embedding models (used instead of messages)")
model_config = ConfigDict(arbitrary_types_allowed=True)
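# Construction sketch for the new embedding path: embeddingInput replaces
# messages entirely. The metadata shape named below is an assumption.
embeddingCall = AiModelCall(
    embeddingInput=["first passage to embed", "second passage to embed"],
)
# A connector handling OperationTypeEnum.EMBEDDING would return the vectors in
# AiCallResponse.metadata, e.g. {"embeddings": [[...], [...]]}.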

View file

@ -119,11 +119,17 @@ class BillingTransaction(BaseModel):
# Context for workflow transactions
workflowId: Optional[str] = Field(None, description="Workflow ID (for WORKFLOW transactions)")
featureInstanceId: Optional[str] = Field(None, description="Feature instance ID")
featureCode: Optional[str] = Field(None, description="Feature code (e.g., chatplayground, automation)")
featureCode: Optional[str] = Field(None, description="Feature code (e.g., automation)")
aicoreProvider: Optional[str] = Field(None, description="AICore provider (anthropic, openai, etc.)")
aicoreModel: Optional[str] = Field(None, description="AICore model name (e.g., claude-4-sonnet, gpt-4o)")
createdByUserId: Optional[str] = Field(None, description="User who created/caused this transaction")
# AI call metadata (for per-call analytics)
processingTime: Optional[float] = Field(None, description="Processing time in seconds")
bytesSent: Optional[int] = Field(None, description="Bytes sent to AI model")
bytesReceived: Optional[int] = Field(None, description="Bytes received from AI model")
errorCount: Optional[int] = Field(None, description="Number of errors in this call")
registerModelLabels(
"BillingTransaction",
@ -218,7 +224,7 @@ class UsageStatistics(BaseModel):
# Breakdown by feature
costByFeature: Dict[str, float] = Field(
default_factory=dict,
description="Cost breakdown by feature (e.g., {'chatplayground': 15.00, 'automation': 5.80})"
description="Cost breakdown by feature (e.g., {'automation': 5.80, 'workspace': 3.20})"
)

View file

@ -1,6 +1,6 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Chat models: ChatWorkflow, ChatMessage, ChatLog, ChatStat, ChatDocument."""
"""Chat models: ChatWorkflow, ChatMessage, ChatLog, ChatDocument."""
from typing import List, Dict, Any, Optional
from enum import Enum
@ -10,44 +10,6 @@ from modules.shared.timeUtils import getUtcTimestamp
import uuid
class ChatStat(BaseModel):
"""Statistics for chat operations. User-owned, no mandate context."""
model_config = {"populate_by_name": True, "extra": "allow"} # Allow DB system fields
id: str = Field(
default_factory=lambda: str(uuid.uuid4()), description="Primary key"
)
workflowId: Optional[str] = Field(
None, description="Foreign key to workflow (for workflow stats)"
)
processingTime: Optional[float] = Field(
None, description="Processing time in seconds"
)
bytesSent: Optional[int] = Field(None, description="Number of bytes sent")
bytesReceived: Optional[int] = Field(None, description="Number of bytes received")
errorCount: Optional[int] = Field(None, description="Number of errors encountered")
process: Optional[str] = Field(None, description="The process that delivers the stats data (e.g. 'action.outlook.readMails', 'ai.process.document.name')")
engine: Optional[str] = Field(None, description="The engine used (e.g. 'ai.anthropic.35', 'ai.tavily.basic', 'renderer.docx')")
priceCHF: Optional[float] = Field(None, description="Calculated price in CHF for the operation")
registerModelLabels(
"ChatStat",
{"en": "Chat Statistics", "fr": "Statistiques de chat"},
{
"id": {"en": "ID", "fr": "ID"},
"workflowId": {"en": "Workflow ID", "fr": "ID du workflow"},
"processingTime": {"en": "Processing Time", "fr": "Temps de traitement"},
"bytesSent": {"en": "Bytes Sent", "fr": "Octets envoyés"},
"bytesReceived": {"en": "Bytes Received", "fr": "Octets reçus"},
"errorCount": {"en": "Error Count", "fr": "Nombre d'erreurs"},
"process": {"en": "Process", "fr": "Processus"},
"engine": {"en": "Engine", "fr": "Moteur"},
"priceCHF": {"en": "Price CHF", "fr": "Prix CHF"},
},
)
class ChatLog(BaseModel):
"""Log entries for chat workflows. User-owned, no mandate context."""
id: str = Field(
@ -285,8 +247,6 @@ class WorkflowModeEnum(str, Enum):
WORKFLOW_DYNAMIC = "Dynamic"
WORKFLOW_AUTOMATION = "Automation"
WORKFLOW_CHATBOT = "Chatbot"
WORKFLOW_CODEEDITOR = "CodeEditor"
WORKFLOW_REACT = "React" # Legacy mode - kept for backward compatibility
registerModelLabels(
@ -296,8 +256,6 @@ registerModelLabels(
"WORKFLOW_DYNAMIC": {"en": "Dynamic", "fr": "Dynamique"},
"WORKFLOW_AUTOMATION": {"en": "Automation", "fr": "Automatisation"},
"WORKFLOW_CHATBOT": {"en": "Chatbot", "fr": "Chatbot"},
"WORKFLOW_CODEEDITOR": {"en": "Code Editor", "fr": "Éditeur de code"},
"WORKFLOW_REACT": {"en": "React (Legacy)", "fr": "React (Hérité)"},
},
)
@ -322,7 +280,6 @@ class ChatWorkflow(BaseModel):
startedAt: float = Field(default_factory=getUtcTimestamp, description="When the workflow started (UTC timestamp in seconds)", json_schema_extra={"frontend_type": "timestamp", "frontend_readonly": True, "frontend_required": False})
logs: List[ChatLog] = Field(default_factory=list, description="Workflow logs", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
messages: List[ChatMessage] = Field(default_factory=list, description="Messages in the workflow", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
stats: List[ChatStat] = Field(default_factory=list, description="Workflow statistics list", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
tasks: list = Field(default_factory=list, description="List of tasks in the workflow", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
workflowMode: WorkflowModeEnum = Field(default=WorkflowModeEnum.WORKFLOW_DYNAMIC, description="Workflow mode selector", json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": False, "frontend_options": [
{
@ -337,10 +294,6 @@ class ChatWorkflow(BaseModel):
"value": WorkflowModeEnum.WORKFLOW_CHATBOT.value,
"label": {"en": "Chatbot", "fr": "Chatbot"},
},
{
"value": WorkflowModeEnum.WORKFLOW_REACT.value,
"label": {"en": "React (Legacy)", "fr": "React (Hérité)"},
},
]})
maxSteps: int = Field(default=10, description="Maximum number of iterations in dynamic mode", json_schema_extra={"frontend_type": "integer", "frontend_readonly": False, "frontend_required": False})
expectedFormats: Optional[List[str]] = Field(None, description="List of expected file format extensions from user request (e.g., ['xlsx', 'pdf']). Extracted during intent analysis.", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})

View file

@ -0,0 +1,58 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Content Object data models for the container and content extraction pipeline.
Physical layer: Container hierarchy (ZIP, Folder, File)
Logical layer: Scalar content objects (text, image, videostream, audiostream, other)
The entire extraction pipeline up to ContentObjects runs without AI.
"""
from typing import Dict, Any, List, Optional
from pydantic import BaseModel, Field
import uuid
class ContainerLimitError(Exception):
"""Raised when container extraction exceeds safety limits (size, depth, file count)."""
pass
class ContentContextRef(BaseModel):
"""Reference to the origin context within a container/file."""
containerPath: str = Field(default="", description="e.g. 'archiv.zip/folder-a/report.pdf'")
location: str = Field(default="", description="e.g. 'page:5/region:bottomLeft'")
label: Optional[str] = Field(default=None, description="e.g. 'Figure 3: Overview'")
pageIndex: Optional[int] = Field(default=None, description="Page number (PDF, DOCX)")
sectionId: Optional[str] = Field(default=None, description="Section/Heading ID")
sheetName: Optional[str] = Field(default=None, description="Sheet name (XLSX)")
slideIndex: Optional[int] = Field(default=None, description="Slide number (PPTX)")
class ContentObject(BaseModel):
"""Scalar content object extracted from a file. No AI involved."""
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
fileId: str = Field(description="FK to the physical file")
contentType: str = Field(description="text, image, videostream, audiostream, other")
data: str = Field(default="", description="Content data (text, base64, URL)")
contextRef: ContentContextRef = Field(default_factory=ContentContextRef)
metadata: Dict[str, Any] = Field(default_factory=dict)
sequence: int = Field(default=0, description="Order within the context")
class ContentObjectSummary(BaseModel):
"""Compact description of a content object for the FileContentIndex."""
id: str = Field(description="Content object ID")
contentType: str = Field(description="text, image, videostream, audiostream, other")
contextRef: ContentContextRef = Field(default_factory=ContentContextRef)
charCount: Optional[int] = Field(default=None, description="Only for text")
dimensions: Optional[str] = Field(default=None, description="Only for image/video (e.g. '1920x1080')")
duration: Optional[float] = Field(default=None, description="Only for audio/video (seconds)")
class FileEntry(BaseModel):
"""A file extracted from a container (ZIP, TAR, Folder)."""
path: str = Field(description="Relative path within the container")
data: bytes = Field(description="File content bytes")
mimeType: str = Field(description="Detected MIME type")
size: int = Field(description="File size in bytes")
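# Construction sketch: a text ContentObject as the no-AI extraction pipeline
# might emit it (all values are illustrative).
textBlock = ContentObject(
    fileId="file-123",                          # assumed physical file key
    contentType="text",
    data="Quarterly revenue grew by 12 percent.",
    contextRef=ContentContextRef(
        containerPath="archiv.zip/folder-a/report.pdf",
        location="page:5/region:bottomLeft",
        pageIndex=5,
    ),
    sequence=0,
)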

View file

@ -0,0 +1,58 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""DataSource and ExternalEntry models for external data integration.
DataSource links a UserConnection to an external path (SharePoint folder,
Google Drive folder, FTP directory, etc.) for agent-accessible data containers.
"""
from typing import Dict, Any, Optional
from pydantic import BaseModel, Field
from modules.shared.attributeUtils import registerModelLabels
from modules.shared.timeUtils import getUtcTimestamp
import uuid
class DataSource(BaseModel):
"""Configured external data source linked to a UserConnection."""
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key")
connectionId: str = Field(description="FK to UserConnection")
sourceType: str = Field(description="sharepointFolder, googleDriveFolder, outlookFolder, ftpFolder")
path: str = Field(description="External path (e.g. '/sites/MySite/Documents/Reports')")
label: str = Field(description="User-visible label")
featureInstanceId: Optional[str] = Field(default=None, description="Scoped to feature instance")
mandateId: Optional[str] = Field(default=None, description="Mandate scope")
userId: str = Field(default="", description="Owner user ID")
autoSync: bool = Field(default=False, description="Automatically sync on schedule")
lastSynced: Optional[float] = Field(default=None, description="Last sync timestamp")
createdAt: float = Field(default_factory=getUtcTimestamp, description="Creation timestamp")
registerModelLabels(
"DataSource",
{"en": "Data Source", "de": "Datenquelle", "fr": "Source de données"},
{
"id": {"en": "ID", "de": "ID", "fr": "ID"},
"connectionId": {"en": "Connection ID", "de": "Verbindungs-ID", "fr": "ID de connexion"},
"sourceType": {"en": "Source Type", "de": "Quellentyp", "fr": "Type de source"},
"path": {"en": "Path", "de": "Pfad", "fr": "Chemin"},
"label": {"en": "Label", "de": "Bezeichnung", "fr": "Libellé"},
"featureInstanceId": {"en": "Feature Instance", "de": "Feature-Instanz", "fr": "Instance de fonctionnalité"},
"mandateId": {"en": "Mandate ID", "de": "Mandanten-ID", "fr": "ID du mandat"},
"userId": {"en": "User ID", "de": "Benutzer-ID", "fr": "ID utilisateur"},
"autoSync": {"en": "Auto Sync", "de": "Auto-Sync", "fr": "Synchro auto"},
"lastSynced": {"en": "Last Synced", "de": "Letzter Sync", "fr": "Dernier sync"},
"createdAt": {"en": "Created At", "de": "Erstellt am", "fr": "Créé le"},
},
)
class ExternalEntry(BaseModel):
"""An item (file or folder) from an external data source."""
name: str = Field(description="Item name")
path: str = Field(description="Full path within the source")
isFolder: bool = Field(default=False, description="True if directory/folder")
size: Optional[int] = Field(default=None, description="File size in bytes")
mimeType: Optional[str] = Field(default=None, description="MIME type (files only)")
lastModified: Optional[float] = Field(default=None, description="Last modification timestamp")
metadata: Dict[str, Any] = Field(default_factory=dict, description="Provider-specific metadata")
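# Construction sketch: registering a SharePoint folder as an agent-accessible
# source (all ids are placeholders).
reportsSource = DataSource(
    connectionId="conn-msft-001",               # assumed UserConnection id
    sourceType="sharepointFolder",
    path="/sites/MySite/Documents/Reports",
    label="Monthly Reports",
    autoSync=True,
)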

View file

@ -73,7 +73,7 @@ class ExtractionOptions(BaseModel):
"""Options for document extraction and processing with clear data structures."""
# Core extraction parameters
prompt: str = Field(description="Extraction prompt for AI processing")
prompt: str = Field(default="", description="Extraction prompt for AI processing")
processDocumentsIndividually: bool = Field(default=True, description="Process each document separately")
# Image processing parameters
@ -81,7 +81,7 @@ class ExtractionOptions(BaseModel):
imageQuality: int = Field(default=85, ge=1, le=100, description="Image quality (1-100)")
# Merging strategy
mergeStrategy: MergeStrategy = Field(description="Strategy for merging extraction results")
mergeStrategy: MergeStrategy = Field(default_factory=MergeStrategy, description="Strategy for merging extraction results")
# Optional chunking parameters (for backward compatibility)
chunkAllowed: Optional[bool] = Field(default=None, description="Whether chunking is allowed")
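
With both prompt and mergeStrategy now defaulted, ExtractionOptions constructs without arguments; a quick sketch (assuming MergeStrategy itself defaults all of its fields, which default_factory=MergeStrategy requires):
options = ExtractionOptions()  # valid after this change; previously both fields were required
assert options.prompt == ""
assert options.processDocumentsIndividually is True
custom = ExtractionOptions(prompt="Extract invoice totals", imageQuality=70)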

View file

@ -0,0 +1,45 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""FeatureDataSource model for exposing feature instance data to the AI workspace.
A FeatureDataSource links a FeatureInstance table (DATA_OBJECT) to a workspace
so the agent can query structured feature data (e.g. TrusteePosition rows).
"""
from typing import Optional
from pydantic import BaseModel, Field
from modules.shared.attributeUtils import registerModelLabels
from modules.shared.timeUtils import getUtcTimestamp
import uuid
class FeatureDataSource(BaseModel):
"""A feature-instance table attached as data source in the AI workspace."""
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key")
featureInstanceId: str = Field(description="FK to FeatureInstance")
featureCode: str = Field(description="Feature code (e.g. trustee, commcoach)")
tableName: str = Field(description="Table name from DATA_OBJECTS meta (e.g. TrusteePosition)")
objectKey: str = Field(description="RBAC object key (e.g. data.feature.trustee.TrusteePosition)")
label: str = Field(description="User-visible label")
mandateId: str = Field(default="", description="Mandate scope")
userId: str = Field(default="", description="Owner user ID")
workspaceInstanceId: str = Field(description="Workspace instance where this source is used")
createdAt: float = Field(default_factory=getUtcTimestamp, description="Creation timestamp")
registerModelLabels(
"FeatureDataSource",
{"en": "Feature Data Source", "de": "Feature-Datenquelle", "fr": "Source de données fonctionnalité"},
{
"id": {"en": "ID", "de": "ID", "fr": "ID"},
"featureInstanceId": {"en": "Feature Instance", "de": "Feature-Instanz", "fr": "Instance"},
"featureCode": {"en": "Feature", "de": "Feature", "fr": "Fonctionnalité"},
"tableName": {"en": "Table", "de": "Tabelle", "fr": "Table"},
"objectKey": {"en": "Object Key", "de": "Objekt-Schlüssel", "fr": "Clé objet"},
"label": {"en": "Label", "de": "Bezeichnung", "fr": "Libellé"},
"mandateId": {"en": "Mandate", "de": "Mandant", "fr": "Mandat"},
"userId": {"en": "User", "de": "Benutzer", "fr": "Utilisateur"},
"workspaceInstanceId": {"en": "Workspace", "de": "Workspace", "fr": "Espace de travail"},
"createdAt": {"en": "Created At", "de": "Erstellt am", "fr": "Créé le"},
},
)
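
A minimal usage sketch matching the field examples given in the descriptions above (the ids are illustrative):
fds = FeatureDataSource(
    featureInstanceId="fi-42",    # hypothetical ids
    featureCode="trustee",
    tableName="TrusteePosition",
    objectKey="data.feature.trustee.TrusteePosition",
    label="Trustee Positions",
    workspaceInstanceId="ws-7",
)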

View file

@ -0,0 +1,32 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""FileFolder: hierarchical folder structure for file organization."""
from typing import Optional
from pydantic import BaseModel, Field
from modules.shared.attributeUtils import registerModelLabels
from modules.shared.timeUtils import getUtcTimestamp
import uuid
class FileFolder(BaseModel):
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
name: str = Field(description="Folder name", json_schema_extra={"frontend_type": "text", "frontend_readonly": False, "frontend_required": True})
parentId: Optional[str] = Field(default=None, description="Parent folder ID (null = root)", json_schema_extra={"frontend_type": "text", "frontend_readonly": False, "frontend_required": False})
mandateId: Optional[str] = Field(default=None, description="Mandate context", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
featureInstanceId: Optional[str] = Field(default=None, description="Feature instance context", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
createdAt: float = Field(default_factory=getUtcTimestamp, description="Creation timestamp", json_schema_extra={"frontend_type": "timestamp", "frontend_readonly": True, "frontend_required": False})
registerModelLabels(
"FileFolder",
{"en": "File Folder", "fr": "Dossier de fichiers"},
{
"id": {"en": "ID", "fr": "ID"},
"name": {"en": "Name", "fr": "Nom"},
"parentId": {"en": "Parent Folder", "fr": "Dossier parent"},
"mandateId": {"en": "Mandate ID", "fr": "ID du mandat"},
"featureInstanceId": {"en": "Feature Instance ID", "fr": "ID de l'instance"},
"createdAt": {"en": "Created At", "fr": "Créé le"},
},
)
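
Since folders reference their parent by id, resolving a display path means walking parentId up to the root; a small sketch (foldersById is assumed to be preloaded by the caller):
from typing import Dict

def folderPath(folder: FileFolder, foldersById: Dict[str, FileFolder]) -> str:
    """Build '/a/b/c' by walking parentId until a root folder (parentId=None)."""
    parts = [folder.name]
    current = folder
    while current.parentId is not None:
        current = foldersById[current.parentId]  # KeyError signals a dangling parentId
        parts.append(current.name)
    return "/" + "/".join(reversed(parts))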

View file

@ -2,7 +2,7 @@
# All rights reserved.
"""File-related datamodels: FileItem, FilePreview, FileData."""
from typing import Dict, Any, Optional, Union
from typing import Dict, Any, List, Optional, Union
from pydantic import BaseModel, ConfigDict, Field
from modules.shared.attributeUtils import registerModelLabels
from modules.shared.timeUtils import getUtcTimestamp
@ -14,12 +14,16 @@ class FileItem(BaseModel):
model_config = ConfigDict(extra='allow') # Preserve system fields (_createdBy, _createdAt, etc.)
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
mandateId: Optional[str] = Field(default="", description="ID of the mandate this file belongs to", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
featureInstanceId: Optional[str] = Field(default="", description="ID of the feature instance this file belongs to", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
featureInstanceId: Optional[str] = Field(default="", description="ID of the feature instance this file belongs to", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False, "frontend_fk_source": "/api/features/instances", "frontend_fk_display_field": "label"})
fileName: str = Field(description="Name of the file", json_schema_extra={"frontend_type": "text", "frontend_readonly": False, "frontend_required": True})
mimeType: str = Field(description="MIME type of the file", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
fileHash: str = Field(description="Hash of the file", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
fileSize: int = Field(description="Size of the file in bytes", json_schema_extra={"frontend_type": "integer", "frontend_readonly": True, "frontend_required": False})
creationDate: float = Field(default_factory=getUtcTimestamp, description="Date when the file was created (UTC timestamp in seconds)", json_schema_extra={"frontend_type": "timestamp", "frontend_readonly": True, "frontend_required": False})
tags: Optional[List[str]] = Field(default=None, description="Tags for categorization and search", json_schema_extra={"frontend_type": "tags", "frontend_readonly": False, "frontend_required": False})
folderId: Optional[str] = Field(default=None, description="ID of the parent folder", json_schema_extra={"frontend_type": "text", "frontend_readonly": False, "frontend_required": False})
description: Optional[str] = Field(default=None, description="User-provided description of the file", json_schema_extra={"frontend_type": "textarea", "frontend_readonly": False, "frontend_required": False})
status: Optional[str] = Field(default=None, description="Processing status: pending, extracted, embedding, indexed, failed", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
registerModelLabels(
"FileItem",
@ -27,12 +31,16 @@ registerModelLabels(
{
"id": {"en": "ID", "fr": "ID"},
"mandateId": {"en": "Mandate ID", "fr": "ID du mandat"},
"featureInstanceId": {"en": "Feature Instance ID", "fr": "ID de l'instance de fonctionnalité"},
"featureInstanceId": {"en": "Feature Instance", "fr": "Instance de fonctionnalité"},
"fileName": {"en": "fileName", "fr": "Nom de fichier"},
"mimeType": {"en": "MIME Type", "fr": "Type MIME"},
"fileHash": {"en": "File Hash", "fr": "Hash du fichier"},
"fileSize": {"en": "File Size", "fr": "Taille du fichier"},
"creationDate": {"en": "Creation Date", "fr": "Date de création"},
"tags": {"en": "Tags", "fr": "Tags"},
"folderId": {"en": "Folder ID", "fr": "ID du dossier"},
"description": {"en": "Description", "fr": "Description"},
"status": {"en": "Status", "fr": "Statut"},
},
)
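
A sketch of a FileItem using the newly added fields (the hash and ids are illustrative values):
item = FileItem(
    fileName="contract.pdf",
    mimeType="application/pdf",
    fileHash="d41d8cd98f00b204e9800998ecf8427e",  # illustrative value
    fileSize=48213,
    tags=["legal", "2025"],
    folderId="folder-1",                          # hypothetical FileFolder id
    description="Signed supplier contract",
    status="pending",
)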

View file

@ -0,0 +1,130 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Knowledge Store data models: FileContentIndex, ContentChunk, WorkflowMemory.
These models support the 3-tier RAG architecture:
- Shared Layer: mandateId-scoped, isShared=True
- Instance Layer: userId + featureInstanceId-scoped
- Workflow Layer: workflowId-scoped (WorkflowMemory)
Vector fields use json_schema_extra={"db_type": "vector(1536)"} for pgvector.
"""
from typing import Dict, Any, List, Optional
from pydantic import BaseModel, Field
from modules.shared.attributeUtils import registerModelLabels
from modules.shared.timeUtils import getUtcTimestamp
import uuid
class FileContentIndex(BaseModel):
"""Structural index of a file's content objects. Created without AI.
Lives in the Instance Layer; optionally promoted to Shared Layer via isShared."""
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key (typically = fileId)")
userId: str = Field(description="Owner user ID")
featureInstanceId: str = Field(default="", description="Feature instance scope")
mandateId: str = Field(default="", description="Mandate scope")
isShared: bool = Field(default=False, description="Visible in Shared Layer for all mandate users")
fileName: str = Field(description="Original file name")
mimeType: str = Field(description="MIME type of the file")
containerPath: Optional[str] = Field(default=None, description="Path within a container (e.g. 'archive.zip/folder/report.pdf')")
totalObjects: int = Field(default=0, description="Total number of content objects extracted")
totalSize: int = Field(default=0, description="Total size of all content objects in bytes")
structure: Dict[str, Any] = Field(default_factory=dict, description="Structural overview (pages, sections, hierarchy)")
objectSummary: List[Dict[str, Any]] = Field(default_factory=list, description="Compact summary per content object")
extractedAt: float = Field(default_factory=getUtcTimestamp, description="Extraction timestamp")
status: str = Field(default="pending", description="Processing status: pending, extracted, embedding, indexed, failed")
registerModelLabels(
"FileContentIndex",
{"en": "File Content Index", "fr": "Index du contenu de fichier"},
{
"id": {"en": "ID", "fr": "ID"},
"userId": {"en": "User ID", "fr": "ID utilisateur"},
"featureInstanceId": {"en": "Feature Instance ID", "fr": "ID de l'instance"},
"mandateId": {"en": "Mandate ID", "fr": "ID du mandat"},
"isShared": {"en": "Shared", "fr": "Partagé"},
"fileName": {"en": "File Name", "fr": "Nom de fichier"},
"mimeType": {"en": "MIME Type", "fr": "Type MIME"},
"containerPath": {"en": "Container Path", "fr": "Chemin du conteneur"},
"totalObjects": {"en": "Total Objects", "fr": "Nombre total d'objets"},
"totalSize": {"en": "Total Size", "fr": "Taille totale"},
"structure": {"en": "Structure", "fr": "Structure"},
"objectSummary": {"en": "Object Summary", "fr": "Résumé des objets"},
"extractedAt": {"en": "Extracted At", "fr": "Extrait le"},
"status": {"en": "Status", "fr": "Statut"},
},
)
class ContentChunk(BaseModel):
"""Persisted content chunk with embedding vector. Reusable across workflows.
Scalar content object (or chunk thereof) with pgvector embedding."""
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key")
contentObjectId: str = Field(description="Reference to the content object within FileContentIndex")
fileId: str = Field(description="FK to the source file")
userId: str = Field(description="Owner user ID")
featureInstanceId: str = Field(default="", description="Feature instance scope")
contentType: str = Field(description="Content type: text, image, videostream, audiostream, other")
data: str = Field(description="Content data (text, base64, URL)")
contextRef: Dict[str, Any] = Field(default_factory=dict, description="Context reference (page, position, label)")
summary: Optional[str] = Field(default=None, description="AI-generated summary (on demand)")
chunkMetadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata")
embedding: Optional[List[float]] = Field(
default=None, description="pgvector embedding (NOT NULL for text chunks)",
json_schema_extra={"db_type": "vector(1536)"}
)
registerModelLabels(
"ContentChunk",
{"en": "Content Chunk", "fr": "Fragment de contenu"},
{
"id": {"en": "ID", "fr": "ID"},
"contentObjectId": {"en": "Content Object ID", "fr": "ID de l'objet de contenu"},
"fileId": {"en": "File ID", "fr": "ID du fichier"},
"userId": {"en": "User ID", "fr": "ID utilisateur"},
"featureInstanceId": {"en": "Feature Instance ID", "fr": "ID de l'instance"},
"contentType": {"en": "Content Type", "fr": "Type de contenu"},
"data": {"en": "Data", "fr": "Données"},
"contextRef": {"en": "Context Reference", "fr": "Référence contextuelle"},
"summary": {"en": "Summary", "fr": "Résumé"},
"chunkMetadata": {"en": "Metadata", "fr": "Métadonnées"},
"embedding": {"en": "Embedding", "fr": "Vecteur d'embedding"},
},
)
class WorkflowMemory(BaseModel):
"""Workflow-scoped key-value cache for entities and facts.
Extracted during agent rounds, persisted for cross-round and cross-workflow reuse."""
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key")
workflowId: str = Field(description="FK to the workflow")
userId: str = Field(description="Owner user ID")
featureInstanceId: str = Field(default="", description="Feature instance scope")
key: str = Field(description="Key identifier (e.g. 'entity:companyName')")
value: str = Field(description="Extracted value")
source: str = Field(default="extraction", description="Origin: extraction, tool, conversation, summary")
createdAt: float = Field(default_factory=getUtcTimestamp, description="Creation timestamp")
embedding: Optional[List[float]] = Field(
default=None, description="Optional embedding for semantic lookup",
json_schema_extra={"db_type": "vector(1536)"}
)
registerModelLabels(
"WorkflowMemory",
{"en": "Workflow Memory", "fr": "Mémoire de workflow"},
{
"id": {"en": "ID", "fr": "ID"},
"workflowId": {"en": "Workflow ID", "fr": "ID du workflow"},
"userId": {"en": "User ID", "fr": "ID utilisateur"},
"featureInstanceId": {"en": "Feature Instance ID", "fr": "ID de l'instance"},
"key": {"en": "Key", "fr": "Clé"},
"value": {"en": "Value", "fr": "Valeur"},
"source": {"en": "Source", "fr": "Source"},
"createdAt": {"en": "Created At", "fr": "Créé le"},
"embedding": {"en": "Embedding", "fr": "Vecteur d'embedding"},
},
)
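
Retrieval over these chunks would typically go through pgvector's distance operators; a hedged raw-SQL sketch, where the table and column names assume a direct mapping of the ContentChunk model (the actual schema is not shown in this diff) and <=> is pgvector's cosine-distance operator:
# Assumption: ContentChunk persists to a "contentchunk" table with matching columns.
similarityQuery = """
    SELECT id, data, embedding <=> %(queryVec)s::vector AS distance
    FROM contentchunk
    WHERE "userId" = %(userId)s AND "contentType" = 'text'
    ORDER BY embedding <=> %(queryVec)s::vector
    LIMIT 5
"""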

View file

@ -2,6 +2,7 @@
# All rights reserved.
"""Voice settings datamodel."""
from typing import Dict, Any, Optional
from pydantic import BaseModel, Field
from modules.shared.attributeUtils import registerModelLabels
from modules.shared.timeUtils import getUtcTimestamp
@ -16,6 +17,7 @@ class VoiceSettings(BaseModel):
sttLanguage: str = Field(default="de-DE", description="Speech-to-Text language", json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": True})
ttsLanguage: str = Field(default="de-DE", description="Text-to-Speech language", json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": True})
ttsVoice: str = Field(default="de-DE-KatjaNeural", description="Text-to-Speech voice", json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": True})
ttsVoiceMap: Dict[str, Any] = Field(default_factory=dict, description="Per-language voice mapping, e.g. {'de-DE': {'voiceName': 'de-DE-Wavenet-A'}, 'en-US': {'voiceName': 'en-US-Wavenet-C'}}", json_schema_extra={"frontend_type": "json", "frontend_readonly": False, "frontend_required": False})
translationEnabled: bool = Field(default=True, description="Whether translation is enabled", json_schema_extra={"frontend_type": "checkbox", "frontend_readonly": False, "frontend_required": False})
targetLanguage: str = Field(default="en-US", description="Target language for translation", json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": False})
creationDate: float = Field(default_factory=getUtcTimestamp, description="Date when the settings were created (UTC timestamp in seconds)", json_schema_extra={"frontend_type": "timestamp", "frontend_readonly": True, "frontend_required": False})
@ -33,6 +35,7 @@ registerModelLabels(
"sttLanguage": {"en": "STT Language", "fr": "Langue STT"},
"ttsLanguage": {"en": "TTS Language", "fr": "Langue TTS"},
"ttsVoice": {"en": "TTS Voice", "fr": "Voix TTS"},
"ttsVoiceMap": {"en": "TTS Voice Map", "fr": "Carte des voix TTS"},
"translationEnabled": {"en": "Translation Enabled", "fr": "Traduction activée"},
"targetLanguage": {"en": "Target Language", "fr": "Langue cible"},
"creationDate": {"en": "Creation Date", "fr": "Date de création"},

View file

@ -180,7 +180,7 @@ def getAutomationServices(
for spec in REQUIRED_SERVICES:
key = spec["serviceKey"]
try:
svc = getService(key, ctx, legacy_hub=None)
svc = getService(key, ctx)
setattr(hub, key, svc)
except Exception as e:
logger.warning(f"Could not resolve service '{key}' for automation: {e}")

View file

@ -17,10 +17,11 @@ from modules.features.automation.interfaceFeatureAutomation import getInterface
from modules.features.automation.mainAutomation import getAutomationServices
from modules.auth import limiter, getRequestContext, RequestContext
from modules.features.automation.datamodelFeatureAutomation import AutomationDefinition, AutomationTemplate
from modules.datamodels.datamodelChat import ChatWorkflow, ChatMessage, ChatLog
from modules.datamodels.datamodelChat import ChatWorkflow, ChatMessage, ChatLog, UserInputRequest, WorkflowModeEnum
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata, normalize_pagination_dict
from modules.shared.attributeUtils import getModelAttributeDefinitions
from modules.interfaces import interfaceDbChat
from modules.interfaces.interfaceDbBilling import getInterface as _getBillingInterface
# Configure logger
logger = logging.getLogger(__name__)
@ -234,7 +235,7 @@ def get_available_actions(
# -----------------------------------------------------------------------------
# Workflow routes under /{instanceId}/workflows/ (instance-scoped, same as chatplayground)
# Workflow routes under /{instanceId}/workflows/ (instance-scoped)
# -----------------------------------------------------------------------------
def _validateAutomationInstanceAccess(instanceId: str, context: RequestContext) -> Optional[str]:
@ -682,7 +683,9 @@ def get_automation_workflow_chat_data(
workflow = chatInterface.getWorkflow(workflowId)
if not workflow:
raise HTTPException(status_code=404, detail=f"Workflow {workflowId} not found")
return chatInterface.getUnifiedChatData(workflowId, afterTimestamp)
billingInterface = _getBillingInterface(context.user, context.mandateId)
workflowCost = billingInterface.getWorkflowCost(workflowId)
return chatInterface.getUnifiedChatData(workflowId, afterTimestamp, workflowCost=workflowCost)
except HTTPException:
raise
except Exception as e:
@ -851,6 +854,46 @@ def delete_automation(
detail=f"Error deleting automation: {str(e)}"
)
@router.post("/{instanceId}/start", response_model=ChatWorkflow)
@limiter.limit("120/minute")
async def start_automation_workflow(
request: Request,
instanceId: str = Path(..., description="Feature instance ID"),
workflowId: Optional[str] = Query(None, description="Optional ID of the workflow to continue"),
workflowMode: WorkflowModeEnum = Query(..., description="Workflow mode: 'Dynamic' or 'Automation' (mandatory)"),
userInput: UserInputRequest = Body(...),
context: RequestContext = Depends(getRequestContext)
) -> ChatWorkflow:
"""Start a new workflow or continue an existing one."""
try:
from modules.workflows.automation import chatStart
mandateId = _validateAutomationInstanceAccess(instanceId, context)
services = getAutomationServices(
context.user,
mandateId=mandateId,
featureInstanceId=instanceId,
)
services.featureCode = "automation"
if hasattr(userInput, 'allowedProviders') and userInput.allowedProviders:
services.allowedProviders = userInput.allowedProviders
workflow = await chatStart(
context.user,
userInput,
workflowMode,
workflowId,
mandateId=mandateId,
featureInstanceId=instanceId,
featureCode="automation",
services=services,
)
return workflow
except HTTPException:
raise
except Exception as e:
logger.error(f"Error in start_automation_workflow: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
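
A hedged client-side sketch of the new endpoint; the /api/automation prefix mirrors the chatplayground routes, and the request-body field name is an assumption since UserInputRequest's schema is not shown in this diff:
import requests

resp = requests.post(
    "http://localhost:8000/api/automation/fi-42/start",     # prefix assumed
    params={"workflowMode": "Dynamic"},
    json={"userInput": "Summarize last month's reports"},   # field name assumed
    headers={"Authorization": "Bearer <token>"},
)
resp.raise_for_status()
workflow = resp.json()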
@router.post("/{automationId}/execute", response_model=ChatWorkflow)
@limiter.limit("5/minute")
async def execute_automation_route(

View file

@ -1291,17 +1291,6 @@ class ChatObjects:
logger.error(f"Error updating message {messageId}: {str(e)}", exc_info=True)
raise ValueError(f"Error updating message {messageId}: {str(e)}")
def createStat(self, statData: Dict[str, Any]):
"""Create stat record. Compatibility with ChatService; stats may not be persisted in chatbot schema."""
from modules.datamodels.datamodelChat import ChatStat
stat = ChatStat(**statData)
try:
created = self.db.recordCreate(ChatStat, statData)
return ChatStat(**created)
except Exception as e:
logger.debug(f"createStat: not persisting (chatbot schema): {e}")
return stat
def deleteMessage(self, conversationId: str, messageId: str) -> bool:
"""Deletes a conversation message and related data if user has access."""
try:

View file

@ -179,7 +179,7 @@ def getChatbotServices(
for spec in REQUIRED_SERVICES:
key = spec["serviceKey"]
try:
svc = getService(key, ctx, legacy_hub=None)
svc = getService(key, ctx)
setattr(hub, key, svc)
except Exception as e:
logger.warning(f"Could not resolve service '{key}' for chatbot: {e}")
@ -197,7 +197,7 @@ def getChatStreamingHelper():
from modules.serviceCenter.context import ServiceCenterContext
# Minimal context - streaming service only needs it for resolver
ctx = ServiceCenterContext(user=__get_placeholder_user(), mandate_id=None, feature_instance_id=None)
streaming = getService("streaming", ctx, legacy_hub=None)
streaming = getService("streaming", ctx)
return streaming.getChatStreamingHelper() if streaming else None
@ -219,7 +219,7 @@ def getEventManager(user, mandateId: Optional[str] = None, featureInstanceId: Op
mandate_id=mandateId,
feature_instance_id=featureInstanceId,
)
streaming = getService("streaming", ctx, legacy_hub=None)
streaming = getService("streaming", ctx)
return streaming.getEventManager()
@ -306,12 +306,12 @@ def getChatbotServices(
Uses interfaceFeatureChatbot (ChatObjects) for interfaceDbChat to avoid
duplicate DB init - chatProcess reuses hub.interfaceDbChat.
"""
from modules.services import PublicService
from modules.serviceHub import PublicService
from modules.interfaces.interfaceDbApp import getInterface as getAppInterface
from modules.features.chatbot.interfaceFeatureChatbot import getInterface as getChatbotInterface
from modules.services.serviceChat.mainServiceChat import ChatService
from modules.services.serviceAi.mainServiceAi import AiService
from modules.services.serviceStreaming.mainServiceStreaming import StreamingService
from modules.serviceCenter.services.serviceChat.mainServiceChat import ChatService
from modules.serviceCenter.services.serviceAi.mainServiceAi import AiService
from modules.serviceCenter.core.serviceStreaming.mainServiceStreaming import StreamingService
hub = _ChatbotServiceHub()
hub.user = user
@ -344,7 +344,7 @@ def getChatbotServices(
feature_instance_id=featureInstanceId,
workflow=_workflow,
)
hub.billing = getService("billing", ctx, legacy_hub=None)
hub.billing = getService("billing", ctx)
except Exception as e:
logger.warning(f"Could not resolve billing service for chatbot: {e}")
hub.billing = None

View file

@ -1,6 +0,0 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Chat Playground Feature Container.
Provides workflow-based chat playground functionality.
"""

View file

@ -1,145 +0,0 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Chat Playground Feature Interface.
Wrapper around interfaceDbChat with feature instance context.
"""
import logging
from typing import Dict, Any, List, Optional
from modules.datamodels.datamodelUam import User
from modules.interfaces import interfaceDbChat
logger = logging.getLogger(__name__)
# Feature code constant
FEATURE_CODE = "chatplayground"
# Singleton instances cache
_instances: Dict[str, "ChatPlaygroundObjects"] = {}
def getInterface(currentUser: User, mandateId: str = None, featureInstanceId: str = None) -> "ChatPlaygroundObjects":
"""
Factory function to get or create a ChatPlaygroundObjects instance.
Uses singleton pattern per user context.
Args:
currentUser: Current user object
mandateId: Mandate ID
featureInstanceId: Feature instance ID
Returns:
ChatPlaygroundObjects instance
"""
cacheKey = f"{currentUser.id}_{mandateId}_{featureInstanceId}"
if cacheKey not in _instances:
_instances[cacheKey] = ChatPlaygroundObjects(currentUser, mandateId, featureInstanceId)
else:
# Update context if needed
_instances[cacheKey].setUserContext(currentUser, mandateId, featureInstanceId)
return _instances[cacheKey]
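
The cache key is the (user, mandate, instance) triple, so repeated calls in the same context return the same object with a refreshed context; a usage sketch, given some User `user`:
a = getInterface(user, mandateId="m1", featureInstanceId="fi-1")
b = getInterface(user, mandateId="m1", featureInstanceId="fi-1")
assert a is b       # cached singleton for this context
c = getInterface(user, mandateId="m2", featureInstanceId="fi-1")
assert c is not a   # different mandate, different instance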
class ChatPlaygroundObjects:
"""
Chat Playground feature interface.
Wraps the shared interfaceDbChat with feature instance context.
"""
FEATURE_CODE = FEATURE_CODE
def __init__(self, currentUser: User, mandateId: str = None, featureInstanceId: str = None):
"""
Initialize the Chat Playground interface.
Args:
currentUser: Current user object
mandateId: Mandate ID
featureInstanceId: Feature instance ID
"""
self.currentUser = currentUser
self.mandateId = mandateId
self.featureInstanceId = featureInstanceId
# Get the underlying chat interface
self._chatInterface = interfaceDbChat.getInterface(
currentUser,
mandateId=mandateId,
featureInstanceId=featureInstanceId
)
def setUserContext(self, currentUser: User, mandateId: str = None, featureInstanceId: str = None):
"""
Update the user context.
Args:
currentUser: Current user object
mandateId: Mandate ID
featureInstanceId: Feature instance ID
"""
self.currentUser = currentUser
self.mandateId = mandateId
self.featureInstanceId = featureInstanceId
# Update underlying interface
self._chatInterface = interfaceDbChat.getInterface(
currentUser,
mandateId=mandateId,
featureInstanceId=featureInstanceId
)
# =========================================================================
# Delegated methods from interfaceDbChat
# =========================================================================
def getWorkflow(self, workflowId: str) -> Optional[Dict[str, Any]]:
"""Get a workflow by ID."""
return self._chatInterface.getWorkflow(workflowId)
def getWorkflows(self, pagination=None) -> Dict[str, Any]:
"""Get all workflows with pagination."""
return self._chatInterface.getWorkflows(pagination=pagination)
def getUnifiedChatData(self, workflowId: str, afterTimestamp: float = None) -> Dict[str, Any]:
"""Get unified chat data for a workflow."""
return self._chatInterface.getUnifiedChatData(workflowId, afterTimestamp)
def createWorkflow(self, workflow) -> Dict[str, Any]:
"""Create a new workflow."""
return self._chatInterface.createWorkflow(workflow)
def updateWorkflow(self, workflowId: str, updates: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""Update a workflow."""
return self._chatInterface.updateWorkflow(workflowId, updates)
def deleteWorkflow(self, workflowId: str) -> bool:
"""Delete a workflow."""
return self._chatInterface.deleteWorkflow(workflowId)
def getMessages(self, workflowId: str) -> List[Dict[str, Any]]:
"""Get messages for a workflow."""
return self._chatInterface.getMessages(workflowId)
def createMessage(self, message) -> Dict[str, Any]:
"""Create a new message."""
return self._chatInterface.createMessage(message)
def getLogs(self, workflowId: str) -> List[Dict[str, Any]]:
"""Get logs for a workflow."""
return self._chatInterface.getLogs(workflowId)
def createLog(self, log) -> Dict[str, Any]:
"""Create a new log entry."""
return self._chatInterface.createLog(log)
def getStats(self, workflowId: str) -> List[Dict[str, Any]]:
"""Get stats for a workflow."""
return self._chatInterface.getStats(workflowId)
def createStat(self, stat) -> Dict[str, Any]:
"""Create a new stat entry."""
return self._chatInterface.createStat(stat)

View file

@ -1,381 +0,0 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Chat Playground Feature Container - Main Module.
Handles feature initialization and RBAC catalog registration.
"""
import logging
from typing import Dict, List, Any, Optional
logger = logging.getLogger(__name__)
# Feature metadata
FEATURE_CODE = "chatplayground"
FEATURE_LABEL = {"en": "Chat Playground", "de": "Chat Playground", "fr": "Chat Playground"}
FEATURE_ICON = "mdi-message-text"
# UI Objects for RBAC catalog
UI_OBJECTS = [
{
"objectKey": "ui.feature.chatplayground.playground",
"label": {"en": "Playground", "de": "Playground", "fr": "Playground"},
"meta": {"area": "playground"}
},
{
"objectKey": "ui.feature.chatplayground.workflows",
"label": {"en": "Workflows", "de": "Workflows", "fr": "Workflows"},
"meta": {"area": "workflows"}
},
]
# Resource Objects for RBAC catalog
RESOURCE_OBJECTS = [
{
"objectKey": "resource.feature.chatplayground.start",
"label": {"en": "Start Workflow", "de": "Workflow starten", "fr": "Démarrer workflow"},
"meta": {"endpoint": "/api/chatplayground/{instanceId}/start", "method": "POST"}
},
{
"objectKey": "resource.feature.chatplayground.stop",
"label": {"en": "Stop Workflow", "de": "Workflow stoppen", "fr": "Arrêter workflow"},
"meta": {"endpoint": "/api/chatplayground/{instanceId}/workflows/{workflowId}/stop", "method": "POST"}
},
{
"objectKey": "resource.feature.chatplayground.chatData",
"label": {"en": "Get Chat Data", "de": "Chat-Daten abrufen", "fr": "Récupérer données chat"},
"meta": {"endpoint": "/api/chatplayground/{instanceId}/workflows/{workflowId}/chatData", "method": "GET"}
},
]
# Service requirements - services this feature needs from the service center
# Same as automation: chatplayground runs the same WorkflowManager and workflow methods
REQUIRED_SERVICES = [
{"serviceKey": "chat", "meta": {"usage": "Workflow CRUD, messages, logs"}},
{"serviceKey": "ai", "meta": {"usage": "AI planning for workflow execution"}},
{"serviceKey": "utils", "meta": {"usage": "Timestamps, utilities"}},
{"serviceKey": "billing", "meta": {"usage": "AI call billing"}},
{"serviceKey": "extraction", "meta": {"usage": "Workflow method actions"}},
{"serviceKey": "sharepoint", "meta": {"usage": "SharePoint actions (listDocuments, uploadDocument, etc.)"}},
{"serviceKey": "generation", "meta": {"usage": "Action completion messages, document creation from results"}},
]
# Template roles for this feature
# Role names MUST follow convention: {featureCode}-{roleName}
TEMPLATE_ROLES = [
{
"roleLabel": "chatplayground-viewer",
"description": {
"en": "Chat Playground Viewer - View chat playground (read-only)",
"de": "Chat Playground Betrachter - Chat Playground ansehen (nur lesen)",
"fr": "Visualiseur Chat Playground - Consulter le chat playground (lecture seule)"
},
"accessRules": [
# UI: only playground view, NO workflows
{"context": "UI", "item": "ui.feature.chatplayground.playground", "view": True},
# RESOURCE: NO access (viewer cannot start/stop/access chat data)
# DATA access (own records, read-only)
{"context": "DATA", "item": None, "view": True, "read": "m", "create": "n", "update": "n", "delete": "n"},
]
},
{
"roleLabel": "chatplayground-user",
"description": {
"en": "Chat Playground User - Use chat playground and workflows",
"de": "Chat Playground Benutzer - Chat Playground und Workflows nutzen",
"fr": "Utilisateur Chat Playground - Utiliser le chat playground et les workflows"
},
"accessRules": [
# UI: full access to all views
{"context": "UI", "item": "ui.feature.chatplayground.playground", "view": True},
{"context": "UI", "item": "ui.feature.chatplayground.workflows", "view": True},
# Resource access: can start/stop workflows and access chat data
{"context": "RESOURCE", "item": "resource.feature.chatplayground.start", "view": True},
{"context": "RESOURCE", "item": "resource.feature.chatplayground.stop", "view": True},
{"context": "RESOURCE", "item": "resource.feature.chatplayground.chatData", "view": True},
# DATA access (own records)
{"context": "DATA", "item": None, "view": True, "read": "m", "create": "m", "update": "m", "delete": "m"},
]
},
{
"roleLabel": "chatplayground-admin",
"description": {
"en": "Chat Playground Admin - Full access to chat playground",
"de": "Chat Playground Admin - Vollzugriff auf Chat Playground",
"fr": "Administrateur Chat Playground - Accès complet au chat playground"
},
"accessRules": [
# Full UI access
{"context": "UI", "item": None, "view": True},
# Full resource access
{"context": "RESOURCE", "item": None, "view": True},
# Full DATA access
{"context": "DATA", "item": None, "view": True, "read": "a", "create": "a", "update": "a", "delete": "a"},
]
},
]
def getRequiredServiceKeys() -> List[str]:
"""Return list of service keys this feature requires."""
return [s["serviceKey"] for s in REQUIRED_SERVICES]
def getChatplaygroundServices(
user,
mandateId: Optional[str] = None,
featureInstanceId: Optional[str] = None,
workflow=None,
) -> "_ChatplaygroundServiceHub":
"""
Get a service hub for the chatplayground feature using the service center.
Resolves only the services declared in REQUIRED_SERVICES.
No legacy fallback - service center only.
Returns a hub-like object with: chat, ai, utils, billing, extraction,
sharepoint, rbac, interfaceDbApp, interfaceDbComponent, interfaceDbChat.
"""
from modules.serviceCenter import getService
from modules.serviceCenter.context import ServiceCenterContext
_workflow = workflow
if _workflow is None:
_workflow = type("_Placeholder", (), {"featureCode": FEATURE_CODE})()
ctx = ServiceCenterContext(
user=user,
mandate_id=mandateId,
feature_instance_id=featureInstanceId,
workflow=_workflow,
)
hub = _ChatplaygroundServiceHub()
hub.user = user
hub.mandateId = mandateId
hub.featureInstanceId = featureInstanceId
hub.workflow = workflow
hub.featureCode = FEATURE_CODE
hub.allowedProviders = None
for spec in REQUIRED_SERVICES:
key = spec["serviceKey"]
try:
svc = getService(key, ctx, legacy_hub=None)
setattr(hub, key, svc)
except Exception as e:
logger.warning(f"Could not resolve service '{key}' for chatplayground: {e}")
setattr(hub, key, None)
# Copy interfaces from chat service for WorkflowManager compatibility
if hub.chat:
hub.interfaceDbApp = getattr(hub.chat, "interfaceDbApp", None)
hub.interfaceDbComponent = getattr(hub.chat, "interfaceDbComponent", None)
hub.interfaceDbChat = getattr(hub.chat, "interfaceDbChat", None)
# RBAC for MethodBase action permission checks (workflow methods)
hub.rbac = getattr(hub.interfaceDbApp, "rbac", None) if hub.interfaceDbApp else None
return hub
class _ChatplaygroundServiceHub:
"""Lightweight hub exposing only services required by the chatplayground feature."""
user = None
mandateId = None
featureInstanceId = None
workflow = None
featureCode = "chatplayground"
allowedProviders = None
interfaceDbApp = None
interfaceDbComponent = None
interfaceDbChat = None
rbac = None
chat = None
ai = None
utils = None
billing = None
extraction = None
sharepoint = None
def getFeatureDefinition() -> Dict[str, Any]:
"""Return the feature definition for registration."""
return {
"code": FEATURE_CODE,
"label": FEATURE_LABEL,
"icon": FEATURE_ICON,
"autoCreateInstance": True, # Automatically create instance in root mandate during bootstrap
}
def getUiObjects() -> List[Dict[str, Any]]:
"""Return UI objects for RBAC catalog registration."""
return UI_OBJECTS
def getResourceObjects() -> List[Dict[str, Any]]:
"""Return resource objects for RBAC catalog registration."""
return RESOURCE_OBJECTS
def getTemplateRoles() -> List[Dict[str, Any]]:
"""Return template roles for this feature."""
return TEMPLATE_ROLES
def registerFeature(catalogService) -> bool:
"""
Register this feature's RBAC objects in the catalog.
Args:
catalogService: The RBAC catalog service instance
Returns:
True if registration was successful
"""
try:
# Register UI objects
for uiObj in UI_OBJECTS:
catalogService.registerUiObject(
featureCode=FEATURE_CODE,
objectKey=uiObj["objectKey"],
label=uiObj["label"],
meta=uiObj.get("meta")
)
# Register Resource objects
for resObj in RESOURCE_OBJECTS:
catalogService.registerResourceObject(
featureCode=FEATURE_CODE,
objectKey=resObj["objectKey"],
label=resObj["label"],
meta=resObj.get("meta")
)
# Sync template roles to database
_syncTemplateRolesToDb()
logger.info(f"Feature '{FEATURE_CODE}' registered {len(UI_OBJECTS)} UI objects and {len(RESOURCE_OBJECTS)} resource objects")
return True
except Exception as e:
logger.error(f"Failed to register feature '{FEATURE_CODE}': {e}")
return False
def _syncTemplateRolesToDb() -> int:
"""
Sync template roles and their AccessRules to the database.
Creates global template roles (mandateId=None) if they don't exist.
Returns:
Number of roles created/updated
"""
try:
from modules.interfaces.interfaceDbApp import getRootInterface
from modules.datamodels.datamodelRbac import Role, AccessRule, AccessRuleContext
rootInterface = getRootInterface()
# Get existing template roles for this feature (Pydantic models)
existingRoles = rootInterface.getRolesByFeatureCode(FEATURE_CODE)
# Filter to template roles (mandateId is None)
templateRoles = [r for r in existingRoles if r.mandateId is None]
existingRoleLabels = {r.roleLabel: str(r.id) for r in templateRoles}
createdCount = 0
for roleTemplate in TEMPLATE_ROLES:
roleLabel = roleTemplate["roleLabel"]
if roleLabel in existingRoleLabels:
roleId = existingRoleLabels[roleLabel]
# Ensure AccessRules exist for this role
_ensureAccessRulesForRole(rootInterface, roleId, roleTemplate.get("accessRules", []))
else:
# Create new template role
newRole = Role(
roleLabel=roleLabel,
description=roleTemplate.get("description", {}),
featureCode=FEATURE_CODE,
mandateId=None, # Global template
featureInstanceId=None,
isSystemRole=False
)
createdRole = rootInterface.db.recordCreate(Role, newRole.model_dump())
roleId = createdRole.get("id")
# Create AccessRules for this role
_ensureAccessRulesForRole(rootInterface, roleId, roleTemplate.get("accessRules", []))
logger.info(f"Created template role '{roleLabel}' with ID {roleId}")
createdCount += 1
if createdCount > 0:
logger.info(f"Feature '{FEATURE_CODE}': Created {createdCount} template roles")
return createdCount
except Exception as e:
logger.error(f"Error syncing template roles for feature '{FEATURE_CODE}': {e}")
return 0
def _ensureAccessRulesForRole(rootInterface, roleId: str, ruleTemplates: List[Dict[str, Any]]) -> int:
"""
Ensure AccessRules exist for a role based on templates.
Args:
rootInterface: Root interface instance
roleId: Role ID
ruleTemplates: List of rule templates
Returns:
Number of rules created
"""
from modules.datamodels.datamodelRbac import AccessRule, AccessRuleContext
# Get existing rules for this role (Pydantic models)
existingRules = rootInterface.getAccessRulesByRole(roleId)
# Create a set of existing rule signatures to avoid duplicates
# IMPORTANT: Use .value for enum comparison, not str() which gives "AccessRuleContext.DATA" in Python 3.11+
existingSignatures = set()
for rule in existingRules:
sig = (rule.context.value if rule.context else None, rule.item)
existingSignatures.add(sig)
createdCount = 0
for template in ruleTemplates:
context = template.get("context", "UI")
item = template.get("item")
sig = (context, item)
if sig in existingSignatures:
continue
# Map context string to enum
if context == "UI":
contextEnum = AccessRuleContext.UI
elif context == "DATA":
contextEnum = AccessRuleContext.DATA
elif context == "RESOURCE":
contextEnum = AccessRuleContext.RESOURCE
else:
contextEnum = context
newRule = AccessRule(
roleId=roleId,
context=contextEnum,
item=item,
view=template.get("view", False),
read=template.get("read"),
create=template.get("create"),
update=template.get("update"),
delete=template.get("delete"),
)
rootInterface.db.recordCreate(AccessRule, newRule.model_dump())
createdCount += 1
if createdCount > 0:
logger.debug(f"Created {createdCount} AccessRules for role {roleId}")
return createdCount
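
The .value comparison called out above matters because str() on a mixin Enum member includes the class name; a self-contained demonstration:
from enum import Enum

class Ctx(str, Enum):       # stand-in for AccessRuleContext
    DATA = "DATA"

print(str(Ctx.DATA))        # "Ctx.DATA"  - would never match a template's "DATA"
print(Ctx.DATA.value)       # "DATA"      - matches templates like {"context": "DATA"}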

View file

@ -1,719 +0,0 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Chat Playground Feature Routes.
Implements the endpoints for chat playground workflow management as a feature.
"""
import json
import logging
from typing import Optional, Dict, Any
from fastapi import APIRouter, HTTPException, Depends, Body, Path, Query, Request, status
# Import auth modules
from modules.auth import limiter, getRequestContext, RequestContext
# Import interfaces
from modules.interfaces import interfaceDbChat
# Import models
from modules.datamodels.datamodelChat import (
ChatWorkflow,
ChatMessage,
ChatLog,
UserInputRequest,
WorkflowModeEnum,
)
from modules.datamodels.datamodelPagination import (
PaginationParams,
PaginatedResponse,
PaginationMetadata,
normalize_pagination_dict,
)
# Import workflow control functions
from modules.workflows.automation import chatStart, chatStop
from modules.features.chatplayground.mainChatplayground import getChatplaygroundServices
from modules.shared.attributeUtils import getModelAttributeDefinitions
# Configure logger
logger = logging.getLogger(__name__)
# Model attributes for ChatWorkflow (workflow attributes endpoint)
workflowAttributes = getModelAttributeDefinitions(ChatWorkflow)
# Create router for chat playground feature endpoints
router = APIRouter(
prefix="/api/chatplayground",
tags=["Chat Playground Feature"],
responses={404: {"description": "Not found"}}
)
def _getServiceChat(context: RequestContext, featureInstanceId: str = None):
"""Get chat interface with feature instance context."""
return interfaceDbChat.getInterface(
context.user,
mandateId=str(context.mandateId) if context.mandateId else None,
featureInstanceId=featureInstanceId
)
def _validateInstanceAccess(instanceId: str, context: RequestContext) -> str:
"""
Validate that user has access to the feature instance.
Args:
instanceId: Feature instance ID
context: Request context
Returns:
mandateId for the instance
Raises:
HTTPException if access is denied
"""
from modules.interfaces.interfaceDbApp import getRootInterface
rootInterface = getRootInterface()
# Get feature instance (Pydantic model)
instance = rootInterface.getFeatureInstance(instanceId)
if not instance:
raise HTTPException(status_code=404, detail=f"Feature instance {instanceId} not found")
# Check user has access to this instance using interface method
featureAccess = rootInterface.getFeatureAccess(str(context.user.id), instanceId)
if not featureAccess or not featureAccess.enabled:
raise HTTPException(status_code=403, detail="Access denied to this feature instance")
return str(instance.mandateId) if instance.mandateId else None
# Workflow start endpoint
@router.post("/{instanceId}/start", response_model=ChatWorkflow)
@limiter.limit("120/minute")
async def start_workflow(
request: Request,
instanceId: str = Path(..., description="Feature instance ID"),
workflowId: Optional[str] = Query(None, description="Optional ID of the workflow to continue"),
workflowMode: WorkflowModeEnum = Query(..., description="Workflow mode: 'Dynamic' or 'Automation' (mandatory)"),
userInput: UserInputRequest = Body(...),
context: RequestContext = Depends(getRequestContext)
) -> ChatWorkflow:
"""
Starts a new workflow or continues an existing one.
Args:
instanceId: Feature instance ID
workflowMode: "Dynamic" for iterative dynamic-style processing, "Automation" for automated workflow execution
"""
try:
# Validate access and get mandate ID
mandateId = _validateInstanceAccess(instanceId, context)
# Get chatplayground services from service center (not automation)
services = getChatplaygroundServices(
context.user,
mandateId=mandateId,
featureInstanceId=instanceId,
)
services.featureCode = 'chatplayground'
if hasattr(userInput, 'allowedProviders') and userInput.allowedProviders:
services.allowedProviders = userInput.allowedProviders
# Start or continue workflow
workflow = await chatStart(
context.user,
userInput,
workflowMode,
workflowId,
mandateId=mandateId,
featureInstanceId=instanceId,
featureCode='chatplayground',
services=services,
)
return workflow
except HTTPException:
raise
except Exception as e:
logger.error(f"Error in start_workflow: {str(e)}")
raise HTTPException(
status_code=500,
detail=str(e)
)
# Stop workflow endpoint (under /workflows/{workflowId}/ for consistency)
@router.post("/{instanceId}/workflows/{workflowId}/stop", response_model=ChatWorkflow)
@limiter.limit("120/minute")
async def stop_workflow(
request: Request,
instanceId: str = Path(..., description="Feature instance ID"),
workflowId: str = Path(..., description="ID of the workflow to stop"),
context: RequestContext = Depends(getRequestContext)
) -> ChatWorkflow:
"""Stops a running workflow."""
try:
# Validate access and get mandate ID
mandateId = _validateInstanceAccess(instanceId, context)
# Get chatplayground services from service center (not automation)
services = getChatplaygroundServices(
context.user,
mandateId=mandateId,
featureInstanceId=instanceId,
)
services.featureCode = 'chatplayground'
# Stop workflow (pass featureInstanceId for proper RBAC filtering)
workflow = await chatStop(
context.user,
workflowId,
mandateId=mandateId,
featureInstanceId=instanceId,
featureCode='chatplayground',
services=services,
)
return workflow
except HTTPException:
raise
except Exception as e:
logger.error(f"Error in stop_workflow: {str(e)}")
raise HTTPException(
status_code=500,
detail=str(e)
)
# Unified Chat Data Endpoint for Polling (under /workflows/{workflowId}/ for consistency)
@router.get("/{instanceId}/workflows/{workflowId}/chatData")
@limiter.limit("120/minute")
def get_workflow_chat_data(
request: Request,
instanceId: str = Path(..., description="Feature instance ID"),
workflowId: str = Path(..., description="ID of the workflow"),
afterTimestamp: Optional[float] = Query(None, description="Unix timestamp to get data after"),
context: RequestContext = Depends(getRequestContext)
) -> Dict[str, Any]:
"""
Get unified chat data (messages, logs, stats) for a workflow with timestamp-based selective data transfer.
Returns all data types in chronological order based on _createdAt timestamp.
"""
try:
# Validate access
_validateInstanceAccess(instanceId, context)
# Get service with feature instance context
chatInterface = _getServiceChat(context, featureInstanceId=instanceId)
# Verify workflow exists
workflow = chatInterface.getWorkflow(workflowId)
if not workflow:
raise HTTPException(
status_code=404,
detail=f"Workflow with ID {workflowId} not found"
)
# Get unified chat data
chatData = chatInterface.getUnifiedChatData(workflowId, afterTimestamp)
return chatData
except HTTPException:
raise
except Exception as e:
logger.error(f"Error getting unified chat data: {str(e)}", exc_info=True)
raise HTTPException(
status_code=500,
detail=f"Error getting unified chat data: {str(e)}"
)
# Get workflow attributes (ChatWorkflow model)
@router.get("/{instanceId}/workflows/attributes", response_model=Dict[str, Any])
@limiter.limit("120/minute")
def get_workflow_attributes(
request: Request,
instanceId: str = Path(..., description="Feature instance ID"),
context: RequestContext = Depends(getRequestContext)
) -> Dict[str, Any]:
"""Get attribute definitions for ChatWorkflow model."""
_validateInstanceAccess(instanceId, context)
return {"attributes": workflowAttributes}
# Get workflows for this instance
@router.get("/{instanceId}/workflows", response_model=PaginatedResponse[ChatWorkflow])
@limiter.limit("120/minute")
def get_workflows(
request: Request,
instanceId: str = Path(..., description="Feature instance ID"),
pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"),
page: int = Query(1, ge=1, description="Page number (legacy)"),
pageSize: int = Query(20, ge=1, le=100, description="Items per page (legacy)"),
context: RequestContext = Depends(getRequestContext)
) -> PaginatedResponse[ChatWorkflow]:
"""
Get all workflows for this feature instance with optional pagination.
"""
try:
# Validate access
_validateInstanceAccess(instanceId, context)
# Get service with feature instance context
chatInterface = _getServiceChat(context, featureInstanceId=instanceId)
# Parse pagination parameter
paginationParams = None
if pagination:
try:
paginationDict = json.loads(pagination)
if paginationDict:
paginationDict = normalize_pagination_dict(paginationDict)
paginationParams = PaginationParams(**paginationDict)
except (json.JSONDecodeError, ValueError) as e:
raise HTTPException(
status_code=400,
detail=f"Invalid pagination parameter: {str(e)}"
)
else:
paginationParams = PaginationParams(page=page, pageSize=pageSize)
result = chatInterface.getWorkflows(pagination=paginationParams)
if paginationParams:
return PaginatedResponse(
items=result.items,
pagination=PaginationMetadata(
currentPage=paginationParams.page,
pageSize=paginationParams.pageSize,
totalItems=result.totalItems,
totalPages=result.totalPages,
sort=paginationParams.sort,
filters=paginationParams.filters
)
)
else:
return PaginatedResponse(items=result, pagination=None)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error getting workflows: {str(e)}", exc_info=True)
raise HTTPException(
status_code=500,
detail=f"Error getting workflows: {str(e)}"
)
# Action Discovery Endpoints (must be before /{workflowId} to avoid path conflict)
@router.get("/{instanceId}/workflows/actions", response_model=Dict[str, Any])
@limiter.limit("120/minute")
def get_all_workflow_actions(
request: Request,
instanceId: str = Path(..., description="Feature instance ID"),
context: RequestContext = Depends(getRequestContext)
) -> Dict[str, Any]:
"""Get all available workflow actions for the current user (filtered by RBAC)."""
try:
mandateId = _validateInstanceAccess(instanceId, context)
services = getChatplaygroundServices(
context.user,
mandateId=mandateId,
featureInstanceId=instanceId,
)
from modules.workflows.processing.shared.methodDiscovery import discoverMethods, methods
discoverMethods(services)
allActions = []
for methodName, methodInfo in methods.items():
if methodName.startswith('Method'):
continue
methodInstance = methodInfo['instance']
methodActions = methodInstance.actions
for actionName, actionInfo in methodActions.items():
actionResponse = {
"module": methodInstance.name,
"actionId": f"{methodInstance.name}.{actionName}",
"name": actionName,
"description": actionInfo.get('description', ''),
"parameters": actionInfo.get('parameters', {})
}
allActions.append(actionResponse)
return {"actions": allActions}
except HTTPException:
raise
except Exception as e:
logger.error(f"Error getting all actions: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to get actions: {str(e)}")
@router.get("/{instanceId}/workflows/actions/{method}", response_model=Dict[str, Any])
@limiter.limit("120/minute")
def get_method_workflow_actions(
request: Request,
instanceId: str = Path(..., description="Feature instance ID"),
method: str = Path(..., description="Method name (e.g., 'outlook', 'sharepoint')"),
context: RequestContext = Depends(getRequestContext)
) -> Dict[str, Any]:
"""Get all available actions for a specific method."""
try:
mandateId = _validateInstanceAccess(instanceId, context)
services = getChatplaygroundServices(
context.user,
mandateId=mandateId,
featureInstanceId=instanceId,
)
from modules.workflows.processing.shared.methodDiscovery import discoverMethods, methods
discoverMethods(services)
methodInstance = None
for methodName, methodInfo in methods.items():
if methodInfo['instance'].name == method:
methodInstance = methodInfo['instance']
break
if not methodInstance:
raise HTTPException(status_code=404, detail=f"Method '{method}' not found")
actions = []
for actionName, actionInfo in methodInstance.actions.items():
actionResponse = {
"actionId": f"{methodInstance.name}.{actionName}",
"name": actionName,
"description": actionInfo.get('description', ''),
"parameters": actionInfo.get('parameters', {})
}
actions.append(actionResponse)
return {"module": methodInstance.name, "description": methodInstance.description, "actions": actions}
except HTTPException:
raise
except Exception as e:
logger.error(f"Error getting actions for method {method}: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to get actions for method {method}: {str(e)}")
@router.get("/{instanceId}/workflows/actions/{method}/{action}", response_model=Dict[str, Any])
@limiter.limit("120/minute")
def get_action_schema(
request: Request,
instanceId: str = Path(..., description="Feature instance ID"),
method: str = Path(..., description="Method name (e.g., 'outlook', 'sharepoint')"),
action: str = Path(..., description="Action name (e.g., 'readEmails', 'uploadDocument')"),
context: RequestContext = Depends(getRequestContext)
) -> Dict[str, Any]:
"""Get action schema with parameter definitions for a specific action."""
try:
mandateId = _validateInstanceAccess(instanceId, context)
services = getChatplaygroundServices(
context.user,
mandateId=mandateId,
featureInstanceId=instanceId,
)
from modules.workflows.processing.shared.methodDiscovery import discoverMethods, methods
discoverMethods(services)
methodInstance = None
for methodName, methodInfo in methods.items():
if methodInfo['instance'].name == method:
methodInstance = methodInfo['instance']
break
if not methodInstance:
raise HTTPException(status_code=404, detail=f"Method '{method}' not found")
methodActions = methodInstance.actions
if action not in methodActions:
raise HTTPException(status_code=404, detail=f"Action '{action}' not found in method '{method}'")
actionInfo = methodActions[action]
return {
"method": methodInstance.name,
"action": action,
"actionId": f"{methodInstance.name}.{action}",
"description": actionInfo.get('description', ''),
"parameters": actionInfo.get('parameters', {})
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Error getting action schema for {method}.{action}: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to get action schema: {str(e)}")
# Get single workflow by ID
@router.get("/{instanceId}/workflows/{workflowId}", response_model=ChatWorkflow)
@limiter.limit("120/minute")
def get_workflow(
request: Request,
instanceId: str = Path(..., description="Feature instance ID"),
workflowId: str = Path(..., description="ID of the workflow"),
context: RequestContext = Depends(getRequestContext)
) -> ChatWorkflow:
"""Get workflow by ID."""
try:
_validateInstanceAccess(instanceId, context)
chatInterface = _getServiceChat(context, featureInstanceId=instanceId)
workflow = chatInterface.getWorkflow(workflowId)
if not workflow:
raise HTTPException(status_code=404, detail="Workflow not found")
return workflow
except HTTPException:
raise
except Exception as e:
logger.error(f"Error getting workflow: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to get workflow: {str(e)}")
# Update workflow
@router.put("/{instanceId}/workflows/{workflowId}", response_model=ChatWorkflow)
@limiter.limit("120/minute")
def update_workflow(
request: Request,
instanceId: str = Path(..., description="Feature instance ID"),
workflowId: str = Path(..., description="ID of the workflow to update"),
workflowData: Dict[str, Any] = Body(...),
context: RequestContext = Depends(getRequestContext)
) -> ChatWorkflow:
"""Update workflow by ID."""
try:
_validateInstanceAccess(instanceId, context)
chatInterface = _getServiceChat(context, featureInstanceId=instanceId)
workflow = chatInterface.getWorkflow(workflowId)
if not workflow:
raise HTTPException(status_code=404, detail="Workflow not found")
if not chatInterface.checkRbacPermission(ChatWorkflow, "update", workflowId):
raise HTTPException(status_code=403, detail="You don't have permission to update this workflow")
updatedWorkflow = chatInterface.updateWorkflow(workflowId, workflowData)
if not updatedWorkflow:
raise HTTPException(status_code=500, detail="Failed to update workflow")
return updatedWorkflow
except HTTPException:
raise
except Exception as e:
logger.error(f"Error updating workflow: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to update workflow: {str(e)}")
# Delete workflow
@router.delete("/{instanceId}/workflows/{workflowId}")
@limiter.limit("120/minute")
def delete_workflow(
request: Request,
instanceId: str = Path(..., description="Feature instance ID"),
workflowId: str = Path(..., description="ID of the workflow to delete"),
context: RequestContext = Depends(getRequestContext)
) -> Dict[str, Any]:
"""Deletes a workflow and its associated data."""
try:
_validateInstanceAccess(instanceId, context)
chatInterface = _getServiceChat(context, featureInstanceId=instanceId)
workflow = chatInterface.getWorkflow(workflowId)
if not workflow:
raise HTTPException(status_code=404, detail=f"Workflow with ID {workflowId} not found")
if not chatInterface.checkRbacPermission(ChatWorkflow, "delete", workflowId):
raise HTTPException(status_code=403, detail="You don't have permission to delete this workflow")
success = chatInterface.deleteWorkflow(workflowId)
if not success:
raise HTTPException(status_code=500, detail="Failed to delete workflow")
return {"id": workflowId, "message": "Workflow and associated data deleted successfully"}
except HTTPException:
raise
except Exception as e:
logger.error(f"Error deleting workflow: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Error deleting workflow: {str(e)}")
# Get workflow status
@router.get("/{instanceId}/workflows/{workflowId}/status", response_model=ChatWorkflow)
@limiter.limit("120/minute")
def get_workflow_status(
request: Request,
instanceId: str = Path(..., description="Feature instance ID"),
workflowId: str = Path(..., description="ID of the workflow"),
context: RequestContext = Depends(getRequestContext)
) -> ChatWorkflow:
"""Get the current status of a workflow."""
try:
_validateInstanceAccess(instanceId, context)
chatInterface = _getServiceChat(context, featureInstanceId=instanceId)
workflow = chatInterface.getWorkflow(workflowId)
if not workflow:
raise HTTPException(status_code=404, detail=f"Workflow with ID {workflowId} not found")
return workflow
except HTTPException:
raise
except Exception as e:
logger.error(f"Error getting workflow status: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Error getting workflow status: {str(e)}")
# Get workflow logs
@router.get("/{instanceId}/workflows/{workflowId}/logs", response_model=PaginatedResponse[ChatLog])
@limiter.limit("120/minute")
def get_workflow_logs(
request: Request,
instanceId: str = Path(..., description="Feature instance ID"),
workflowId: str = Path(..., description="ID of the workflow"),
logId: Optional[str] = Query(None, description="Optional log ID for selective data transfer"),
pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"),
context: RequestContext = Depends(getRequestContext)
) -> PaginatedResponse[ChatLog]:
"""Get logs for a workflow with optional pagination."""
try:
_validateInstanceAccess(instanceId, context)
chatInterface = _getServiceChat(context, featureInstanceId=instanceId)
workflow = chatInterface.getWorkflow(workflowId)
if not workflow:
raise HTTPException(status_code=404, detail=f"Workflow with ID {workflowId} not found")
paginationParams = None
if pagination:
try:
paginationDict = json.loads(pagination)
if paginationDict:
paginationDict = normalize_pagination_dict(paginationDict)
paginationParams = PaginationParams(**paginationDict)
except (json.JSONDecodeError, ValueError) as e:
raise HTTPException(status_code=400, detail=f"Invalid pagination parameter: {str(e)}")
result = chatInterface.getLogs(workflowId, pagination=paginationParams)
if logId:
allLogs = result.items if paginationParams else result
logIndex = next((i for i, log in enumerate(allLogs) if log.id == logId), -1)
if logIndex >= 0:
filteredLogs = allLogs[logIndex + 1:]
return PaginatedResponse(items=filteredLogs, pagination=None)
if paginationParams:
return PaginatedResponse(
items=result.items,
pagination=PaginationMetadata(
currentPage=paginationParams.page,
pageSize=paginationParams.pageSize,
totalItems=result.totalItems,
totalPages=result.totalPages,
sort=paginationParams.sort,
filters=paginationParams.filters
)
)
return PaginatedResponse(items=result, pagination=None)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error getting workflow logs: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Error getting workflow logs: {str(e)}")
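# Illustrative request (a sketch; IDs are hypothetical). The `pagination` query
# parameter is the URL-encoded JSON form of PaginationParams, e.g.
#   GET .../workflows/<workflowId>/logs?pagination={"page": 1, "pageSize": 20}
# Omitting it returns the full, unpaginated log list, per the final return above.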
# Get workflow messages
@router.get("/{instanceId}/workflows/{workflowId}/messages", response_model=PaginatedResponse[ChatMessage])
@limiter.limit("120/minute")
def get_workflow_messages(
request: Request,
instanceId: str = Path(..., description="Feature instance ID"),
workflowId: str = Path(..., description="ID of the workflow"),
messageId: Optional[str] = Query(None, description="Optional message ID for selective data transfer"),
pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"),
context: RequestContext = Depends(getRequestContext)
) -> PaginatedResponse[ChatMessage]:
"""Get messages for a workflow with optional pagination."""
try:
_validateInstanceAccess(instanceId, context)
chatInterface = _getServiceChat(context, featureInstanceId=instanceId)
workflow = chatInterface.getWorkflow(workflowId)
if not workflow:
raise HTTPException(status_code=404, detail=f"Workflow with ID {workflowId} not found")
paginationParams = None
if pagination:
try:
paginationDict = json.loads(pagination)
if paginationDict:
paginationDict = normalize_pagination_dict(paginationDict)
paginationParams = PaginationParams(**paginationDict)
except (json.JSONDecodeError, ValueError) as e:
raise HTTPException(status_code=400, detail=f"Invalid pagination parameter: {str(e)}")
result = chatInterface.getMessages(workflowId, pagination=paginationParams)
if messageId:
allMessages = result.items if paginationParams else result
messageIndex = next((i for i, msg in enumerate(allMessages) if msg.id == messageId), -1)
if messageIndex >= 0:
filteredMessages = allMessages[messageIndex + 1:]
return PaginatedResponse(items=filteredMessages, pagination=None)
if paginationParams:
return PaginatedResponse(
items=result.items,
pagination=PaginationMetadata(
currentPage=paginationParams.page,
pageSize=paginationParams.pageSize,
totalItems=result.totalItems,
totalPages=result.totalPages,
sort=paginationParams.sort,
filters=paginationParams.filters
)
)
return PaginatedResponse(items=result, pagination=None)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error getting workflow messages: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Error getting workflow messages: {str(e)}")
# Delete message from workflow
@router.delete("/{instanceId}/workflows/{workflowId}/messages/{messageId}")
@limiter.limit("120/minute")
def delete_workflow_message(
request: Request,
instanceId: str = Path(..., description="Feature instance ID"),
workflowId: str = Path(..., description="ID of the workflow"),
messageId: str = Path(..., description="ID of the message to delete"),
context: RequestContext = Depends(getRequestContext)
) -> Dict[str, Any]:
"""Delete a message from a workflow."""
try:
_validateInstanceAccess(instanceId, context)
chatInterface = _getServiceChat(context, featureInstanceId=instanceId)
workflow = chatInterface.getWorkflow(workflowId)
if not workflow:
raise HTTPException(status_code=404, detail=f"Workflow with ID {workflowId} not found")
success = chatInterface.deleteMessage(workflowId, messageId)
if not success:
raise HTTPException(status_code=404, detail=f"Message with ID {messageId} not found in workflow {workflowId}")
return {"workflowId": workflowId, "messageId": messageId, "message": "Message deleted successfully"}
except HTTPException:
raise
except Exception as e:
logger.error(f"Error deleting message: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Error deleting message: {str(e)}")
# Delete file from message
@router.delete("/{instanceId}/workflows/{workflowId}/messages/{messageId}/files/{fileId}")
@limiter.limit("120/minute")
def delete_file_from_message(
request: Request,
instanceId: str = Path(..., description="Feature instance ID"),
workflowId: str = Path(..., description="ID of the workflow"),
messageId: str = Path(..., description="ID of the message"),
fileId: str = Path(..., description="ID of the file to delete"),
context: RequestContext = Depends(getRequestContext)
) -> Dict[str, Any]:
"""Delete a file reference from a message in a workflow."""
try:
_validateInstanceAccess(instanceId, context)
chatInterface = _getServiceChat(context, featureInstanceId=instanceId)
workflow = chatInterface.getWorkflow(workflowId)
if not workflow:
raise HTTPException(status_code=404, detail=f"Workflow with ID {workflowId} not found")
success = chatInterface.deleteFileFromMessage(workflowId, messageId, fileId)
if not success:
raise HTTPException(status_code=404, detail=f"File with ID {fileId} not found in message {messageId}")
return {"workflowId": workflowId, "messageId": messageId, "fileId": fileId, "message": "File reference deleted successfully"}
except HTTPException:
raise
except Exception as e:
logger.error(f"Error deleting file reference: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Error deleting file reference: {str(e)}")

View file

@@ -1 +0,0 @@
"""CodeEditor Feature - Cursor-style AI file editing via chat interface."""

View file

@@ -1,280 +0,0 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""CodeEditor processor -- single-shot (Phase 1) and agent loop (Phase 2).
Orchestrates file loading, prompt building, AI calls, response parsing, and SSE emission."""
import logging
from typing import List, Dict, Any
from modules.features.codeeditor import fileContextManager, promptAssembly, responseParser
from modules.features.codeeditor.datamodelCodeeditor import (
FileEditProposal, SegmentTypeEnum, AgentState
)
from modules.features.codeeditor import toolRegistry
from modules.shared.timeUtils import getUtcTimestamp
logger = logging.getLogger(__name__)
async def processMessage(
workflowId: str,
userPrompt: str,
selectedFileIds: List[str],
dbManagement,
interfaceAi,
chatInterface,
eventManager,
agentMode: bool = False
):
"""Process a user message. Dispatches to single-shot or agent loop based on mode."""
if agentMode:
await _processAgentMessage(
workflowId, userPrompt, dbManagement, interfaceAi, chatInterface, eventManager
)
else:
await _processSingleShot(
workflowId, userPrompt, selectedFileIds, dbManagement, interfaceAi, chatInterface, eventManager
)
async def _processSingleShot(
workflowId, userPrompt, selectedFileIds, dbManagement, interfaceAi, chatInterface, eventManager
):
"""Phase 1: Single AI call with pre-loaded file context."""
try:
await _emitStatus(eventManager, workflowId, "Loading files...")
fileContexts = await fileContextManager.loadFileContexts(dbManagement, selectedFileIds)
await _emitStatus(eventManager, workflowId, "Building prompt...")
chatHistory = _loadChatHistory(chatInterface, workflowId)
aiRequest = promptAssembly.buildRequest(userPrompt, fileContexts, chatHistory)
await _emitStatus(eventManager, workflowId, "AI is processing...")
aiResponse = await interfaceAi.callWithTextContext(aiRequest)
if aiResponse.errorCount > 0:
await _emitError(eventManager, workflowId, aiResponse.content)
return
segments = responseParser.parseResponse(aiResponse.content)
await _emitSegments(eventManager, workflowId, segments, fileContexts)
_logAiStats(aiResponse, workflowId)
await eventManager.emit_event(workflowId, "complete", {
"workflowId": workflowId,
"modelName": aiResponse.modelName,
"priceCHF": aiResponse.priceCHF,
"processingTime": aiResponse.processingTime
})
except Exception as e:
logger.error(f"CodeEditor single-shot failed for {workflowId}: {e}", exc_info=True)
await eventManager.emit_event(workflowId, "error", {
"workflowId": workflowId, "error": str(e)
})
async def _processAgentMessage(
workflowId, userPrompt, dbManagement, interfaceAi, chatInterface, eventManager
):
"""Phase 2: Agent loop -- multiple AI calls with tool execution until done."""
state = AgentState(workflowId=workflowId)
try:
await _emitStatus(eventManager, workflowId, "Agent: Scanning available files...")
fileListContext = fileContextManager.buildFileListContext(dbManagement)
state.conversationHistory.append({"role": "user", "content": userPrompt})
aiRequest = promptAssembly.buildAgentRequest(
userPrompt=userPrompt,
fileListContext=fileListContext,
conversationHistory=[]
)
while state.status == "running" and state.currentRound < state.maxRounds:
state.currentRound += 1
state.totalAiCalls += 1
await _emitStatus(eventManager, workflowId,
f"Agent round {state.currentRound}: AI is thinking...")
await eventManager.emit_event(workflowId, "chatdata", {
"type": "agent_progress",
"item": {
"round": state.currentRound,
"totalAiCalls": state.totalAiCalls,
"totalToolCalls": state.totalToolCalls,
"costCHF": round(state.totalCostCHF, 4),
}
})
aiResponse = await interfaceAi.callWithTextContext(aiRequest)
state.totalCostCHF += aiResponse.priceCHF
state.totalProcessingTime += aiResponse.processingTime
if aiResponse.errorCount > 0:
logger.error(f"Agent AI call failed in round {state.currentRound}: {aiResponse.content}")
await _emitError(eventManager, workflowId, aiResponse.content)
state.status = "error"
break
_logAiStats(aiResponse, workflowId)
state.conversationHistory.append({"role": "assistant", "content": aiResponse.content})
segments = responseParser.parseResponse(aiResponse.content)
textAndEditSegments = [s for s in segments if s.type != SegmentTypeEnum.TOOL_CALL]
if textAndEditSegments:
await _emitSegments(eventManager, workflowId, textAndEditSegments, [])
toolCallSegments = [s for s in segments if s.type == SegmentTypeEnum.TOOL_CALL]
if not toolCallSegments:
state.status = "completed"
break
toolResultTexts = []
for tc in toolCallSegments:
state.totalToolCalls += 1
await _emitStatus(eventManager, workflowId,
f"Agent: Running {tc.toolName}...")
result = await toolRegistry.dispatch(tc.toolName, tc.toolArgs or {}, dbManagement)
toolResultTexts.append(f"[{tc.toolName}] (success={result.success}):\n{result.result}")
logger.info(f"Agent tool {tc.toolName}: success={result.success}, time={result.executionTime:.2f}s")
combinedResults = "\n\n".join(toolResultTexts)
state.conversationHistory.append({
"role": "tool_result",
"content": combinedResults,
"toolName": "batch"
})
aiRequest = promptAssembly.buildAgentRequest(
userPrompt=None,
fileListContext=fileListContext,
conversationHistory=state.conversationHistory
)
if state.currentRound >= state.maxRounds and state.status == "running":
state.status = "max_rounds"
await eventManager.emit_event(workflowId, "chatdata", {
"type": "message",
"item": {
"role": "system",
"content": f"Agent stopped: maximum rounds ({state.maxRounds}) reached.",
"createdAt": getUtcTimestamp()
}
})
await eventManager.emit_event(workflowId, "chatdata", {
"type": "agent_summary",
"item": {
"rounds": state.currentRound,
"totalAiCalls": state.totalAiCalls,
"totalToolCalls": state.totalToolCalls,
"costCHF": round(state.totalCostCHF, 4),
"processingTime": round(state.totalProcessingTime, 1),
"status": state.status,
}
})
await eventManager.emit_event(workflowId, "complete", {
"workflowId": workflowId,
"agentRounds": state.currentRound,
"totalCostCHF": round(state.totalCostCHF, 4),
"processingTime": round(state.totalProcessingTime, 1)
})
except Exception as e:
logger.error(f"CodeEditor agent loop failed for {workflowId}: {e}", exc_info=True)
await eventManager.emit_event(workflowId, "error", {
"workflowId": workflowId, "error": str(e)
})
# ---------------------------------------------------------------------------
# Shared helpers
# ---------------------------------------------------------------------------
async def _emitStatus(eventManager, workflowId: str, label: str):
await eventManager.emit_event(workflowId, "chatdata", {
"type": "status", "label": label
})
async def _emitError(eventManager, workflowId: str, errorMsg: str):
await eventManager.emit_event(workflowId, "chatdata", {
"type": "message",
"item": {"role": "assistant", "content": f"Error: {errorMsg}"}
})
await eventManager.emit_event(workflowId, "error", {
"workflowId": workflowId, "error": errorMsg
})
async def _emitSegments(eventManager, workflowId: str, segments, fileContexts):
"""Emit parsed segments as SSE events."""
for segment in segments:
messageData = {
"role": "assistant",
"content": segment.content,
"type": segment.type.value,
"createdAt": getUtcTimestamp()
}
await eventManager.emit_event(workflowId, "chatdata", {
"type": "message", "item": messageData
})
if segment.type == SegmentTypeEnum.FILE_EDIT:
proposal = FileEditProposal(
workflowId=workflowId,
fileId=_resolveFileId(segment.fileName, fileContexts),
fileName=segment.fileName,
operation="edit",
oldContent=segment.oldContent,
newContent=segment.newContent
)
await eventManager.emit_event(workflowId, "chatdata", {
"type": "file_edit_proposal", "item": proposal.model_dump()
})
def _loadChatHistory(chatInterface, workflowId: str) -> List[Dict[str, Any]]:
"""Load recent chat messages for multi-turn context."""
try:
messages = chatInterface.getMessages(workflowId)
if not messages:
return []
history = []
for msg in messages:
role = msg.get("role", "unknown") if isinstance(msg, dict) else getattr(msg, "role", "unknown")
content = msg.get("content", "") if isinstance(msg, dict) else getattr(msg, "content", "")
history.append({"role": role, "content": content})
return history
except Exception as e:
logger.warning(f"Could not load chat history for {workflowId}: {e}")
return []
def _resolveFileId(fileName: str, fileContexts) -> str:
"""Resolve a fileName to its fileId from the loaded contexts."""
for fc in fileContexts:
if fc.fileName == fileName:
return fc.fileId
return f"unknown-{fileName}"
def _logAiStats(aiResponse, workflowId: str):
"""Log AI call statistics."""
logger.info(
f"CodeEditor AI call for {workflowId}: "
f"model={aiResponse.modelName}, "
f"provider={aiResponse.provider}, "
f"cost={aiResponse.priceCHF:.4f} CHF, "
f"time={aiResponse.processingTime:.1f}s, "
f"sent={aiResponse.bytesSent}B, received={aiResponse.bytesReceived}B"
)

View file

@@ -1,122 +0,0 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Data models for the CodeEditor feature."""
from typing import List, Dict, Any, Optional
from enum import Enum
from pydantic import BaseModel, Field
from modules.shared.timeUtils import getUtcTimestamp
import uuid
class SegmentTypeEnum(str, Enum):
TEXT = "text"
CODE_BLOCK = "code_block"
FILE_EDIT = "file_edit"
TOOL_CALL = "tool_call"
class EditStatusEnum(str, Enum):
PENDING = "pending"
ACCEPTED = "accepted"
REJECTED = "rejected"
class FileContext(BaseModel):
"""A text file loaded as context for the AI."""
fileId: str
fileName: str
content: Optional[str] = None
mimeType: str
sizeBytes: int = 0
modifiedAt: Optional[float] = None
tags: List[str] = Field(default_factory=list)
class ResponseSegment(BaseModel):
"""A parsed segment from the AI response."""
type: SegmentTypeEnum
content: str
language: Optional[str] = None
fileId: Optional[str] = None
fileName: Optional[str] = None
oldContent: Optional[str] = None
newContent: Optional[str] = None
toolName: Optional[str] = None
toolArgs: Optional[Dict[str, Any]] = None
class FileEditProposal(BaseModel):
"""A proposed file edit from the AI, awaiting user accept/reject."""
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
workflowId: str
fileId: str
fileName: str
operation: str = "edit"
oldContent: Optional[str] = None
newContent: str
diffSummary: Optional[str] = None
status: EditStatusEnum = EditStatusEnum.PENDING
createdAt: float = Field(default_factory=getUtcTimestamp)
class FileVersion(BaseModel):
"""A new version of a file created after accepting an edit proposal."""
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
sourceFileId: str
editProposalId: str
newFileId: str
createdAt: float = Field(default_factory=getUtcTimestamp)
class AgentState(BaseModel):
"""Tracks state across an agent loop execution."""
workflowId: str
currentRound: int = 0
maxRounds: int = 50
totalAiCalls: int = 0
totalToolCalls: int = 0
totalCostCHF: float = 0.0
totalProcessingTime: float = 0.0
conversationHistory: List[Dict[str, Any]] = Field(default_factory=list)
status: str = "running"
class ToolResult(BaseModel):
"""Result from executing a tool."""
toolName: str
result: str
success: bool = True
executionTime: float = 0.0
TEXT_MIME_TYPES = {
"text/plain", "text/markdown", "text/html", "text/css", "text/csv",
"text/xml", "text/yaml", "text/x-python", "text/x-java",
"text/javascript", "text/x-typescript", "text/x-sql",
"application/json", "application/xml", "application/yaml",
"application/x-yaml", "application/javascript",
}
TEXT_EXTENSIONS = {
".md", ".txt", ".json", ".yaml", ".yml", ".xml", ".csv",
".py", ".js", ".ts", ".tsx", ".jsx", ".html", ".htm", ".css", ".scss",
".sql", ".sh", ".bash", ".zsh", ".ps1", ".bat",
".toml", ".ini", ".cfg", ".conf", ".env", ".gitignore",
".dockerfile", ".docker-compose", ".makefile",
".java", ".kt", ".go", ".rs", ".rb", ".php", ".swift", ".c", ".cpp", ".h",
".r", ".lua", ".dart", ".vue", ".svelte",
}
def isTextFile(mimeType: Optional[str], fileName: Optional[str] = None) -> bool:
"""Check if a file is a text-based file suitable for the editor."""
if mimeType and mimeType.lower() in TEXT_MIME_TYPES:
return True
if mimeType and mimeType.lower().startswith("text/"):
return True
if fileName:
ext = "." + fileName.rsplit(".", 1)[-1].lower() if "." in fileName else ""
if ext in TEXT_EXTENSIONS:
return True
return False
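# Hedged usage sketch (expected outcomes given the sets above; names are examples):
#   isTextFile("text/x-python", "main.py")  -> True   (MIME match)
#   isTextFile(None, "notes.md")            -> True   (extension fallback)
#   isTextFile("image/png", "logo.png")     -> False  (matches neither list)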

View file

@@ -1,84 +0,0 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""File context manager for CodeEditor feature.
Loads text files from the database and provides them as context for AI calls."""
import logging
from typing import List, Optional
from modules.features.codeeditor.datamodelCodeeditor import FileContext, isTextFile
logger = logging.getLogger(__name__)
async def loadFileContexts(dbManagement, fileIds: List[str]) -> List[FileContext]:
"""Load text files from DB and return as FileContext list.
Args:
dbManagement: interfaceDbManagement instance with user context set
fileIds: list of file IDs to load
"""
contexts = []
for fileId in fileIds:
fileItem = dbManagement.getFile(fileId)
if not fileItem:
logger.warning(f"File {fileId} not found or no access")
continue
if not isTextFile(fileItem.mimeType, fileItem.fileName):
logger.warning(f"File {fileItem.fileName} ({fileItem.mimeType}) is not a text file, skipping")
continue
fileData = dbManagement.getFileData(fileId)
if not fileData:
logger.warning(f"No data for file {fileId}")
continue
try:
content = fileData.decode("utf-8")
except UnicodeDecodeError:
logger.warning(f"File {fileItem.fileName} is not valid UTF-8, skipping")
continue
contexts.append(FileContext(
fileId=fileId,
fileName=fileItem.fileName,
content=content,
mimeType=fileItem.mimeType,
sizeBytes=fileItem.fileSize
))
logger.info(f"Loaded {len(contexts)} file contexts from {len(fileIds)} requested")
return contexts
def listTextFiles(dbManagement) -> List[FileContext]:
"""List all text files accessible to the user (metadata only, no content)."""
allFiles = dbManagement.getAllFiles()
textFiles = []
if not allFiles:
return textFiles
for fileItem in allFiles:
if isTextFile(fileItem.mimeType, fileItem.fileName):
modifiedAt = getattr(fileItem, "_modifiedAt", None) or getattr(fileItem, "creationDate", None)
textFiles.append(FileContext(
fileId=fileItem.id,
fileName=fileItem.fileName,
content=None,
mimeType=fileItem.mimeType,
sizeBytes=fileItem.fileSize,
modifiedAt=modifiedAt
))
return textFiles
def buildFileListContext(dbManagement) -> str:
"""Build a compact file list string for the agent prompt (no content, just metadata)."""
textFiles = listTextFiles(dbManagement)
if not textFiles:
return "No text files available."
lines = [f"- {f.fileName} (id: {f.fileId}, size: {f.sizeBytes}B)" for f in textFiles]
return f"Total: {len(lines)} text files\n" + "\n".join(lines)

View file

@@ -1,183 +0,0 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Prompt assembly for the CodeEditor feature.
Builds Cursor-style system prompts with file context and format instructions."""
import logging
from typing import List, Optional, Dict, Any
from modules.datamodels.datamodelAi import AiCallRequest, AiCallOptions, OperationTypeEnum
from modules.features.codeeditor.datamodelCodeeditor import FileContext
logger = logging.getLogger(__name__)
SYSTEM_PROMPT = """You are an AI assistant for text and code file editing. You receive files as context and can suggest changes.
## Rules for file edits
- Use ```file_edit``` blocks for file changes
- Each file_edit block must contain: fileName, oldContent (exact text to replace), newContent (replacement text)
- Explain changes in normal text before or after the block
- oldContent must EXACTLY match existing content (including whitespace and indentation)
- You may propose edits to multiple files in one response
## Response format
Normal text is displayed as explanation.
File changes must use this format:
```file_edit
fileName: <filename>
oldContent: |
<exact existing content to replace>
newContent: |
<new replacement content>
```
Code examples (without edits) use standard markdown code blocks:
```language
code here
```
## Important
- Only edit files that are provided in context
- Make minimal, targeted changes
- Preserve existing formatting and style
- If a task is unclear, ask for clarification instead of guessing"""
def buildRequest(
userPrompt: str,
fileContexts: List[FileContext],
chatHistory: Optional[List[Dict[str, Any]]] = None
) -> AiCallRequest:
"""Build an AiCallRequest with system prompt, file context, and user prompt."""
systemPart = SYSTEM_PROMPT
fileContextPart = _buildFileContext(fileContexts)
historyPart = _buildChatHistory(chatHistory) if chatHistory else ""
fullPrompt = systemPart
if historyPart:
fullPrompt += f"\n\n## Previous conversation\n{historyPart}"
fullPrompt += f"\n\n## User request\n{userPrompt}"
return AiCallRequest(
prompt=fullPrompt,
context=fileContextPart if fileContextPart else None,
options=AiCallOptions(
operationType=OperationTypeEnum.DATA_ANALYSE,
temperature=0.0,
compressPrompt=False,
compressContext=False,
resultFormat="txt"
)
)
def _buildFileContext(fileContexts: List[FileContext]) -> str:
"""Build the file context string with line numbers."""
if not fileContexts:
return ""
parts = []
for fc in fileContexts:
if not fc.content:
continue
lines = fc.content.split("\n")
numberedLines = [f"{i + 1}|{line}" for i, line in enumerate(lines)]
numbered = "\n".join(numberedLines)
parts.append(f"--- FILE: {fc.fileName} ---\n{numbered}\n--- END FILE ---")
return "\n\n".join(parts)
def buildAgentRequest(
userPrompt: Optional[str],
fileListContext: str,
conversationHistory: List[Dict[str, Any]]
) -> AiCallRequest:
"""Build an AiCallRequest for agent mode with tool definitions and conversation history."""
from modules.features.codeeditor.toolRegistry import formatToolDefinitions
systemPrompt = _AGENT_SYSTEM_PROMPT.replace("{{TOOL_DEFINITIONS}}", formatToolDefinitions())
if not conversationHistory:
fullPrompt = systemPrompt
context = f"## Available files\n{fileListContext}\n\n## Task\n{userPrompt}"
else:
fullPrompt = systemPrompt
historyText = _buildConversationHistory(conversationHistory)
context = f"## Available files\n{fileListContext}\n\n## Conversation\n{historyText}"
return AiCallRequest(
prompt=fullPrompt,
context=context,
options=AiCallOptions(
operationType=OperationTypeEnum.DATA_ANALYSE,
temperature=0.0,
compressPrompt=False,
compressContext=False,
resultFormat="txt"
)
)
_AGENT_SYSTEM_PROMPT = """You are an AI agent for file analysis and editing. You work autonomously by using tools to read files, search content, and propose edits.
## Available tools
{{TOOL_DEFINITIONS}}
## How to call tools
Use this exact format for each tool call:
```tool_call
tool: <tool_name>
args: {"param": "value"}
```
## Rules
- Read files ONE AT A TIME with read_file, never assume file contents
- First create a plan, then execute it step by step
- Use search_files to find relevant files before reading them
- Use list_files to discover what files are available
- For file changes, use ```file_edit``` blocks (same format as before)
- You may combine text explanations, tool calls, and file edits in one response
- When you are DONE and need no more tool calls, simply respond with text only (no tool_call blocks)
- Keep responses focused and efficient
## file_edit format (for changes)
```file_edit
fileName: <filename>
oldContent: |
<exact existing content>
newContent: |
<replacement content>
```"""
def _buildConversationHistory(history: List[Dict[str, Any]]) -> str:
"""Build the full conversation history for agent multi-turn context."""
parts = []
for msg in history:
role = msg.get("role", "unknown")
content = msg.get("content", "")
if role == "tool_result":
toolName = msg.get("toolName", "")
parts.append(f"[Tool Result - {toolName}]:\n{content}")
else:
parts.append(f"[{role}]:\n{content}")
return "\n\n".join(parts)
def _buildChatHistory(chatHistory: List[Dict[str, Any]]) -> str:
"""Build a condensed chat history string for multi-turn context."""
if not chatHistory:
return ""
parts = []
for msg in chatHistory[-10:]:
role = msg.get("role", "unknown")
content = msg.get("content", "")
if len(content) > 500:
content = content[:500] + "..."
parts.append(f"[{role}]: {content}")
return "\n".join(parts)

View file

@@ -1,184 +0,0 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Response parser for the CodeEditor feature.
Parses AI responses into typed segments (text, code_block, file_edit, tool_call)."""
import logging
import json
import re
from typing import List, Optional
from modules.features.codeeditor.datamodelCodeeditor import ResponseSegment, SegmentTypeEnum
logger = logging.getLogger(__name__)
_FENCE_PATTERN = re.compile(r"^```(\w*)\s*$", re.MULTILINE)
def parseResponse(rawContent: str) -> List[ResponseSegment]:
"""Parse an AI response into typed segments."""
if not rawContent or not rawContent.strip():
return []
segments = []
lines = rawContent.split("\n")
i = 0
textBuffer = []
while i < len(lines):
line = lines[i]
match = _FENCE_PATTERN.match(line)
if match:
if textBuffer:
_flushTextBuffer(textBuffer, segments)
textBuffer = []
lang = match.group(1).strip()
blockLines, endIdx = _collectBlock(lines, i + 1)
blockContent = "\n".join(blockLines)
if lang == "file_edit":
segment = _parseFileEditBlock(blockContent)
if segment:
segments.append(segment)
else:
segments.append(ResponseSegment(
type=SegmentTypeEnum.CODE_BLOCK,
content=blockContent,
language="text"
))
elif lang == "tool_call":
segment = _parseToolCallBlock(blockContent)
if segment:
segments.append(segment)
else:
segments.append(ResponseSegment(
type=SegmentTypeEnum.CODE_BLOCK,
content=blockContent,
language="text"
))
else:
segments.append(ResponseSegment(
type=SegmentTypeEnum.CODE_BLOCK,
content=blockContent,
language=lang or "text"
))
i = endIdx + 1
else:
textBuffer.append(line)
i += 1
if textBuffer:
_flushTextBuffer(textBuffer, segments)
return segments
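# Minimal round-trip sketch (the raw string is illustrative, not a fixture):
#   raw = 'Here is the plan.\n```tool_call\ntool: read_file\nargs: {"fileId": "f1"}\n```'
#   parseResponse(raw)
#   -> [ResponseSegment(type=TEXT, content="Here is the plan."),
#       ResponseSegment(type=TOOL_CALL, toolName="read_file", toolArgs={"fileId": "f1"})]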
def hasToolCalls(segments: List[ResponseSegment]) -> bool:
"""Check if any segments contain tool calls."""
return any(s.type == SegmentTypeEnum.TOOL_CALL for s in segments)
def _collectBlock(lines: List[str], startIdx: int) -> tuple:
"""Collect lines inside a fenced code block until closing ```."""
blockLines = []
idx = startIdx
while idx < len(lines):
if lines[idx].strip() == "```":
return blockLines, idx
blockLines.append(lines[idx])
idx += 1
return blockLines, idx
def _flushTextBuffer(buffer: List[str], segments: List[ResponseSegment]):
"""Flush accumulated text lines into a text segment."""
text = "\n".join(buffer).strip()
buffer.clear()
if text:
segments.append(ResponseSegment(
type=SegmentTypeEnum.TEXT,
content=text
))
def _parseFileEditBlock(blockContent: str) -> Optional[ResponseSegment]:
"""Parse a file_edit block into a ResponseSegment with fileName, oldContent, newContent."""
fields = {"fileName": None, "oldContent": None, "newContent": None}
currentField = None
currentLines = []
for line in blockContent.split("\n"):
stripped = line.strip()
newField = None
for key in ("fileName", "oldContent", "newContent"):
if stripped.startswith(f"{key}:"):
newField = key
break
if newField:
if currentField and currentLines:
fields[currentField] = "\n".join(currentLines)
currentField = newField
value = stripped[len(f"{newField}:"):].strip()
if newField == "fileName":
fields["fileName"] = value if value else None
currentField = None
currentLines = []
else:
currentLines = [value] if value and value != "|" else []
else:
if currentField in ("oldContent", "newContent"):
dedented = line[2:] if line.startswith("  ") else line
currentLines.append(dedented)
if currentField and currentLines:
fields[currentField] = "\n".join(currentLines)
if not fields["fileName"]:
logger.warning("file_edit block missing fileName")
return None
if fields["newContent"] is None:
logger.warning(f"file_edit block for {fields['fileName']} missing newContent")
return None
return ResponseSegment(
type=SegmentTypeEnum.FILE_EDIT,
content=f"Edit: {fields['fileName']}",
fileName=fields["fileName"],
oldContent=fields["oldContent"],
newContent=fields["newContent"]
)
def _parseToolCallBlock(blockContent: str) -> Optional[ResponseSegment]:
"""Parse a tool_call block into a ResponseSegment with toolName and toolArgs."""
toolName = None
toolArgs = {}
for line in blockContent.split("\n"):
stripped = line.strip()
if stripped.startswith("tool:"):
toolName = stripped[len("tool:"):].strip()
elif stripped.startswith("args:"):
argsStr = stripped[len("args:"):].strip()
try:
toolArgs = json.loads(argsStr)
except json.JSONDecodeError:
logger.warning(f"Could not parse tool args as JSON: {argsStr}")
toolArgs = {"raw": argsStr}
if not toolName:
logger.warning("tool_call block missing tool name")
return None
return ResponseSegment(
type=SegmentTypeEnum.TOOL_CALL,
content=f"Tool: {toolName}",
toolName=toolName,
toolArgs=toolArgs
)

View file

@@ -1,395 +0,0 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
CodeEditor Feature Routes.
SSE-based endpoints for Cursor-style AI file editing.
"""
import logging
import json
import asyncio
from typing import Optional, Dict, Any, List
from fastapi import APIRouter, HTTPException, Depends, Body, Path, Query, Request
from fastapi.responses import StreamingResponse
from modules.auth import limiter, getRequestContext, RequestContext
from modules.interfaces import interfaceDbChat, interfaceDbManagement
from modules.interfaces.interfaceAiObjects import AiObjects
from modules.datamodels.datamodelChat import UserInputRequest
from modules.services.serviceStreaming import get_event_manager
from modules.features.codeeditor import codeEditorProcessor, fileContextManager
from modules.features.codeeditor.datamodelCodeeditor import FileEditProposal, EditStatusEnum
logger = logging.getLogger(__name__)
router = APIRouter(
prefix="/api/codeeditor",
tags=["Code Editor Feature"],
responses={404: {"description": "Not found"}}
)
_aiObjects: Optional[AiObjects] = None
async def _getAiObjects() -> AiObjects:
"""Lazy-init singleton for AiObjects."""
global _aiObjects
if _aiObjects is None:
_aiObjects = await AiObjects.create()
return _aiObjects
def _getServiceChat(context: RequestContext, featureInstanceId: str = None):
"""Get chat interface with feature instance context."""
return interfaceDbChat.getInterface(
context.user,
mandateId=str(context.mandateId) if context.mandateId else None,
featureInstanceId=featureInstanceId
)
def _getDbManagement(context: RequestContext, featureInstanceId: str = None):
"""Get management interface with user context for file access."""
return interfaceDbManagement.getInterface(
context.user,
mandateId=str(context.mandateId) if context.mandateId else None,
featureInstanceId=featureInstanceId
)
def _validateInstanceAccess(instanceId: str, context: RequestContext) -> str:
"""Validate user has access to the feature instance. Returns mandateId."""
from modules.interfaces.interfaceDbApp import getRootInterface
rootInterface = getRootInterface()
instance = rootInterface.getFeatureInstance(instanceId)
if not instance:
raise HTTPException(status_code=404, detail=f"Feature instance {instanceId} not found")
featureAccess = rootInterface.getFeatureAccess(str(context.user.id), instanceId)
if not featureAccess or not featureAccess.enabled:
raise HTTPException(status_code=403, detail="Access denied to this feature instance")
return str(instance.mandateId) if instance.mandateId else None
@router.post("/{instanceId}/start/stream")
@limiter.limit("60/minute")
async def streamCodeeditorStart(
request: Request,
instanceId: str = Path(..., description="Feature instance ID"),
workflowId: Optional[str] = Query(None, description="Optional workflow ID to continue"),
mode: str = Query("simple", description="Processing mode: 'simple' (single AI call) or 'agent' (multi-step with tools)"),
userInput: UserInputRequest = Body(...),
context: RequestContext = Depends(getRequestContext)
):
"""Start or continue a CodeEditor workflow with SSE streaming. Supports simple and agent mode."""
try:
mandateId = _validateInstanceAccess(instanceId, context)
chatInterface = _getServiceChat(context, featureInstanceId=instanceId)
dbManagement = _getDbManagement(context, featureInstanceId=instanceId)
aiObjects = await _getAiObjects()
eventManager = get_event_manager()
if workflowId:
workflow = chatInterface.getWorkflow(workflowId)
if not workflow:
raise HTTPException(status_code=404, detail=f"Workflow {workflowId} not found")
else:
workflow = chatInterface.createWorkflow({
"workflowMode": "CodeEditor",
"status": "running",
"label": userInput.prompt[:80] if userInput.prompt else "CodeEditor Session",
})
workflowId = workflow.get("id") if isinstance(workflow, dict) else workflow.id
queue = eventManager.create_queue(workflowId)
userMessage = {
"role": "user",
"content": userInput.prompt,
"selectedFiles": userInput.listFileId or []
}
await eventManager.emit_event(workflowId, "chatdata", {
"type": "message", "item": userMessage
})
selectedFileIds = userInput.listFileId or []
agentMode = mode.lower() == "agent"
asyncio.create_task(
codeEditorProcessor.processMessage(
workflowId=workflowId,
userPrompt=userInput.prompt,
selectedFileIds=selectedFileIds,
dbManagement=dbManagement,
interfaceAi=aiObjects,
chatInterface=chatInterface,
eventManager=eventManager,
agentMode=agentMode
)
)
async def _eventStream():
streamTimeout = 300
lastActivity = asyncio.get_event_loop().time()
while True:
now = asyncio.get_event_loop().time()
if now - lastActivity > streamTimeout:
yield f"data: {json.dumps({'type': 'error', 'error': 'Stream timeout'})}\n\n"
break
if await request.is_disconnected():
logger.info(f"Client disconnected for workflow {workflowId}")
break
try:
event = await asyncio.wait_for(queue.get(), timeout=1.0)
lastActivity = asyncio.get_event_loop().time()
eventType = event.get("type", "")
if eventType == "chatdata":
yield f"data: {json.dumps(event.get('data', {}))}\n\n"
elif eventType in ("complete", "stopped", "error"):
yield f"data: {json.dumps({'type': eventType, **event.get('data', {})})}\n\n"
break
else:
yield f"data: {json.dumps(event)}\n\n"
except asyncio.TimeoutError:
yield ": keepalive\n\n"
await eventManager.cleanup(workflowId)
return StreamingResponse(
_eventStream(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no"
}
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error in streamCodeeditorStart: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
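# Illustrative client call (a sketch; host, instance ID, and file ID are
# hypothetical, and auth headers are omitted):
#   curl -N -X POST "https://<host>/api/codeeditor/<instanceId>/start/stream?mode=agent" \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Refactor utils.py", "listFileId": ["<fileId>"]}'
# The response is a text/event-stream of "data: {...}" frames that ends with a
# complete, stopped, or error event, as emitted above.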
@router.post("/{instanceId}/{workflowId}/stop")
@limiter.limit("120/minute")
async def stopWorkflow(
request: Request,
instanceId: str = Path(..., description="Feature instance ID"),
workflowId: str = Path(..., description="Workflow ID"),
context: RequestContext = Depends(getRequestContext)
):
"""Stop a running CodeEditor workflow."""
try:
_validateInstanceAccess(instanceId, context)
eventManager = get_event_manager()
await eventManager.emit_event(workflowId, "stopped", {
"workflowId": workflowId
}, event_category="workflow", message="Workflow stopped by user")
return {"status": "stopped", "workflowId": workflowId}
except HTTPException:
raise
except Exception as e:
logger.error(f"Error stopping workflow: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/{instanceId}/{workflowId}/chatData")
@limiter.limit("120/minute")
def getWorkflowChatData(
request: Request,
instanceId: str = Path(..., description="Feature instance ID"),
workflowId: str = Path(..., description="Workflow ID"),
afterTimestamp: Optional[float] = Query(None, description="Unix timestamp for incremental fetch"),
context: RequestContext = Depends(getRequestContext)
) -> Dict[str, Any]:
"""Get chat data for a workflow (polling fallback)."""
try:
_validateInstanceAccess(instanceId, context)
chatInterface = _getServiceChat(context, featureInstanceId=instanceId)
workflow = chatInterface.getWorkflow(workflowId)
if not workflow:
raise HTTPException(status_code=404, detail=f"Workflow {workflowId} not found")
return chatInterface.getUnifiedChatData(workflowId, afterTimestamp)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error getting chat data: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
@router.get("/{instanceId}/workflows")
@limiter.limit("120/minute")
def getWorkflows(
request: Request,
instanceId: str = Path(..., description="Feature instance ID"),
page: int = Query(1, ge=1),
pageSize: int = Query(20, ge=1, le=100),
context: RequestContext = Depends(getRequestContext)
) -> Dict[str, Any]:
"""List workflows for this feature instance."""
try:
_validateInstanceAccess(instanceId, context)
chatInterface = _getServiceChat(context, featureInstanceId=instanceId)
from modules.datamodels.datamodelPagination import PaginationParams
pagination = PaginationParams(page=page, pageSize=pageSize)
return chatInterface.getWorkflows(pagination=pagination)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error getting workflows: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
@router.get("/{instanceId}/files")
@limiter.limit("120/minute")
def getFiles(
request: Request,
instanceId: str = Path(..., description="Feature instance ID"),
context: RequestContext = Depends(getRequestContext)
) -> Dict[str, Any]:
"""List all text files accessible to the user."""
try:
_validateInstanceAccess(instanceId, context)
dbManagement = _getDbManagement(context, featureInstanceId=instanceId)
textFiles = fileContextManager.listTextFiles(dbManagement)
return {
"files": [f.model_dump(exclude={"content"}) for f in textFiles],
"count": len(textFiles)
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Error listing files: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
@router.get("/{instanceId}/files/{fileId}/content")
@limiter.limit("120/minute")
def getFileContent(
request: Request,
instanceId: str = Path(..., description="Feature instance ID"),
fileId: str = Path(..., description="File ID"),
context: RequestContext = Depends(getRequestContext)
) -> Dict[str, Any]:
"""Get the text content of a file."""
try:
_validateInstanceAccess(instanceId, context)
dbManagement = _getDbManagement(context, featureInstanceId=instanceId)
fileItem = dbManagement.getFile(fileId)
if not fileItem:
raise HTTPException(status_code=404, detail=f"File {fileId} not found")
fileData = dbManagement.getFileData(fileId)
if not fileData:
raise HTTPException(status_code=404, detail=f"No data for file {fileId}")
try:
content = fileData.decode("utf-8")
except UnicodeDecodeError:
raise HTTPException(status_code=400, detail="File is not valid UTF-8 text")
return {
"fileId": fileId,
"fileName": fileItem.fileName,
"mimeType": fileItem.mimeType,
"content": content
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Error getting file content: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
@router.post("/{instanceId}/{workflowId}/apply")
@limiter.limit("60/minute")
async def applyEdit(
request: Request,
instanceId: str = Path(..., description="Feature instance ID"),
workflowId: str = Path(..., description="Workflow ID"),
proposalData: Dict[str, Any] = Body(...),
context: RequestContext = Depends(getRequestContext)
) -> Dict[str, Any]:
"""Accept a file edit proposal. Updates existing file or creates new one."""
try:
_validateInstanceAccess(instanceId, context)
dbManagement = _getDbManagement(context, featureInstanceId=instanceId)
fileId = proposalData.get("fileId", "")
newContent = proposalData.get("newContent")
fileName = proposalData.get("fileName", "")
if newContent is None:
raise HTTPException(status_code=400, detail="newContent is required")
contentBytes = newContent.encode("utf-8")
isNewFile = not fileId or fileId.startswith("unknown-")
if isNewFile:
mimeType = _guessMimeType(fileName)
fileItem = dbManagement.createFile(fileName, mimeType, contentBytes)
resultFileId = fileItem.id
resultFileName = fileItem.fileName
else:
fileItem = dbManagement.getFile(fileId)
if not fileItem:
raise HTTPException(status_code=404, detail=f"File {fileId} not found")
success = dbManagement.createFileData(fileId, contentBytes)
if not success:
raise HTTPException(status_code=500, detail="Failed to store updated file content")
resultFileId = fileId
resultFileName = fileName or fileItem.fileName
eventManager = get_event_manager()
await eventManager.emit_event(workflowId, "chatdata", {
"type": "file_version",
"item": {
"fileId": resultFileId,
"fileName": resultFileName,
"status": "accepted",
"isNew": isNewFile
}
})
return {
"status": "accepted",
"fileId": resultFileId,
"fileName": resultFileName,
"isNew": isNewFile
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Error applying edit: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
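# Illustrative request body (field names as read above; values are hypothetical):
#   POST /api/codeeditor/<instanceId>/<workflowId>/apply
#   {"fileId": "f1", "fileName": "utils.py", "newContent": "print('updated')\n"}
# Passing no fileId (or an "unknown-..." id) takes the isNewFile branch above and
# creates a new file instead of updating an existing one.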
_MIME_MAP = {
".md": "text/markdown", ".txt": "text/plain", ".json": "application/json",
".yaml": "application/yaml", ".yml": "application/yaml", ".xml": "application/xml",
".csv": "text/csv", ".py": "text/x-python", ".js": "text/javascript",
".ts": "text/x-typescript", ".html": "text/html", ".css": "text/css",
".sql": "text/x-sql", ".sh": "text/x-shellscript",
}
def _guessMimeType(fileName: str) -> str:
"""Guess MIME type from file extension."""
if not fileName or "." not in fileName:
return "text/plain"
ext = "." + fileName.rsplit(".", 1)[-1].lower()
return _MIME_MAP.get(ext, "text/plain")
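# Hedged examples of the mapping above (file names are illustrative):
#   _guessMimeType("report.md")  -> "text/markdown"
#   _guessMimeType("data.csv")   -> "text/csv"
#   _guessMimeType("Makefile")   -> "text/plain"  (no extension -> default)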

View file

@@ -1,157 +0,0 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Tool registry and dispatcher for the CodeEditor agent loop.
Defines available tools and executes them against the file context manager."""
import logging
import time
import fnmatch
from typing import Dict, Any, List
from modules.features.codeeditor.datamodelCodeeditor import ToolResult
logger = logging.getLogger(__name__)
TOOL_DEFINITIONS = [
{
"name": "read_file",
"description": "Read the full content of a single file by its fileId.",
"parameters": {"fileId": "string (required)"}
},
{
"name": "list_files",
"description": "List all available text files with metadata (name, size, mimeType). Optionally filter by glob pattern.",
"parameters": {"filter": "string (optional, glob pattern e.g. '*.py')"}
},
{
"name": "search_files",
"description": "Search all file contents for a text query. Returns matching lines with file name and line number.",
"parameters": {"query": "string (required)", "fileType": "string (optional, extension e.g. 'py')"}
},
]
async def dispatch(toolName: str, toolArgs: Dict[str, Any], dbManagement) -> ToolResult:
"""Execute a tool and return the result."""
startTime = time.time()
try:
if toolName == "read_file":
result = await _toolReadFile(toolArgs, dbManagement)
elif toolName == "list_files":
result = _toolListFiles(toolArgs, dbManagement)
elif toolName == "search_files":
result = await _toolSearchFiles(toolArgs, dbManagement)
else:
result = f"Unknown tool: {toolName}"
return ToolResult(toolName=toolName, result=result, success=False,
executionTime=time.time() - startTime)
return ToolResult(toolName=toolName, result=result, success=True,
executionTime=time.time() - startTime)
except Exception as e:
logger.error(f"Tool {toolName} failed: {e}", exc_info=True)
return ToolResult(toolName=toolName, result=f"Error: {str(e)}", success=False,
executionTime=time.time() - startTime)
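# Hedged usage sketch (dbManagement is whichever management interface the caller
# already holds; the fileId is hypothetical):
#   result = await dispatch("read_file", {"fileId": "f1"}, dbManagement)
#   if result.success:
#       print(result.result)                    # numbered file content
#   logger.debug(f"took {result.executionTime:.2f}s")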
async def _toolReadFile(args: Dict[str, Any], dbManagement) -> str:
"""Read a single file's content."""
fileId = args.get("fileId", "")
if not fileId:
return "Error: fileId is required"
fileItem = dbManagement.getFile(fileId)
if not fileItem:
return f"Error: File {fileId} not found"
fileData = dbManagement.getFileData(fileId)
if not fileData:
return f"Error: No data for file {fileId}"
try:
content = fileData.decode("utf-8")
except UnicodeDecodeError:
return f"Error: File {fileItem.fileName} is not valid UTF-8"
lines = content.split("\n")
numbered = "\n".join([f"{i + 1}|{line}" for i, line in enumerate(lines)])
return f"--- FILE: {fileItem.fileName} (id: {fileId}) ---\n{numbered}\n--- END FILE ---"
def _toolListFiles(args: Dict[str, Any], dbManagement) -> str:
"""List all text files, optionally filtered by glob pattern."""
from modules.features.codeeditor.datamodelCodeeditor import isTextFile
filterPattern = args.get("filter", "")
allFiles = dbManagement.getAllFiles()
if not allFiles:
return "No files found."
lines = []
for f in allFiles:
if not isTextFile(f.mimeType, f.fileName):
continue
if filterPattern and not fnmatch.fnmatch(f.fileName, filterPattern):
continue
lines.append(f"- {f.fileName} (id: {f.id}, size: {f.fileSize}B, type: {f.mimeType})")
if not lines:
return "No matching text files found."
return f"Available files ({len(lines)}):\n" + "\n".join(lines)
async def _toolSearchFiles(args: Dict[str, Any], dbManagement) -> str:
"""Search file contents for a query string."""
from modules.features.codeeditor.datamodelCodeeditor import isTextFile
query = args.get("query", "")
if not query:
return "Error: query is required"
fileType = args.get("fileType", "")
allFiles = dbManagement.getAllFiles()
if not allFiles:
return "No files to search."
hits = []
maxHits = 50
queryLower = query.lower()
for f in allFiles:
if not isTextFile(f.mimeType, f.fileName):
continue
if fileType and not f.fileName.endswith(f".{fileType}"):
continue
fileData = dbManagement.getFileData(f.id)
if not fileData:
continue
try:
content = fileData.decode("utf-8")
except UnicodeDecodeError:
continue
for lineNum, line in enumerate(content.split("\n"), 1):
if queryLower in line.lower():
hits.append(f"{f.fileName}:{lineNum}: {line.strip()}")
if len(hits) >= maxHits:
break
if len(hits) >= maxHits:
break
if not hits:
return f"No matches found for '{query}'."
result = f"Search results for '{query}' ({len(hits)} matches):\n" + "\n".join(hits)
if len(hits) >= maxHits:
result += f"\n... (truncated at {maxHits} matches)"
return result
def formatToolDefinitions() -> str:
"""Format tool definitions for inclusion in the system prompt."""
parts = []
for tool in TOOL_DEFINITIONS:
params = ", ".join([f"{k}: {v}" for k, v in tool["parameters"].items()])
parts.append(f"- **{tool['name']}**: {tool['description']}\n Parameters: {{{params}}}")
return "\n".join(parts)

View file

@@ -10,8 +10,10 @@ import json
import asyncio
import base64
import uuid
from typing import Optional
from fastapi import APIRouter, HTTPException, Depends, Request
from fastapi import APIRouter, HTTPException, Depends, Request, WebSocket, WebSocketDisconnect, Query
from fastapi.responses import StreamingResponse, Response
from modules.auth import limiter, getRequestContext, RequestContext
@@ -31,7 +33,6 @@ from .datamodelCommcoach import (
StartSessionRequest, CreatePersonaRequest, UpdatePersonaRequest,
)
from .serviceCommcoach import CommcoachService, emitSessionEvent, getSessionEventQueue, cleanupSessionEvents
logger = logging.getLogger(__name__)
_activeProcessTasks: dict = {}

View file

@@ -1011,15 +1011,15 @@ class CommcoachService:
async def _callAi(self, systemPrompt: str, userPrompt: str):
"""Call the AI service with the given prompts."""
from modules.services.serviceAi.mainServiceAi import AiService
from modules.serviceCenter import getService
from modules.serviceCenter.context import ServiceCenterContext
serviceContext = type('Ctx', (), {
'user': self.currentUser,
'mandateId': self.mandateId,
'featureInstanceId': self.instanceId,
'featureCode': 'commcoach',
})()
aiService = AiService(serviceCenter=serviceContext)
serviceContext = ServiceCenterContext(
user=self.currentUser,
mandate_id=self.mandateId,
feature_instance_id=self.instanceId,
)
aiService = getService("ai", serviceContext)
await aiService.ensureAiObjectsInitialized()
aiRequest = AiCallRequest(

View file

@@ -7,7 +7,7 @@ from urllib.parse import urlparse, unquote
from modules.datamodels.datamodelUam import User
from .datamodelFeatureNeutralizer import DataNeutralizerAttributes, DataNeutraliserConfig
from modules.services import getInterface as getServices
from modules.serviceHub import getInterface as getServices
logger = logging.getLogger(__name__)
@@ -205,7 +205,7 @@ class NeutralizationPlayground:
async def processSharepointFiles(self, sourcePath: str, targetPath: str) -> Dict[str, Any]:
"""Process files from SharePoint source path and store neutralized files in target path"""
from modules.services.serviceSharepoint.mainServiceSharepoint import SharepointService
from modules.serviceCenter.services.serviceSharepoint.mainServiceSharepoint import SharepointService
processor = SharepointProcessor(self.currentUser, self.services)
return await processor.processSharepointFiles(sourcePath, targetPath)

View file

@@ -262,8 +262,8 @@ class NeutralizationService:
fileId: Optional[str]
) -> Dict[str, Any]:
"""Extract -> neutralize -> adapt -> generate for PDF/DOCX/XLSX/PPTX."""
from modules.services.serviceExtraction.mainServiceExtraction import ExtractionService
from modules.services.serviceExtraction.subPipeline import runExtraction
from modules.serviceCenter.services.serviceExtraction.mainServiceExtraction import ExtractionService
from modules.serviceCenter.services.serviceExtraction.subPipeline import runExtraction
from modules.datamodels.datamodelExtraction import ExtractionOptions, MergeStrategy
# Ensure registries exist
@@ -405,10 +405,10 @@ class NeutralizationService:
def _getRendererForMime(self, mimeType: str):
"""Get renderer instance and output mime for the given input MIME type."""
from modules.services.serviceGeneration.renderers.rendererPdf import RendererPdf
from modules.services.serviceGeneration.renderers.rendererDocx import RendererDocx
from modules.services.serviceGeneration.renderers.rendererXlsx import RendererXlsx
from modules.services.serviceGeneration.renderers.rendererPptx import RendererPptx
from modules.serviceCenter.services.serviceGeneration.renderers.rendererPdf import RendererPdf
from modules.serviceCenter.services.serviceGeneration.renderers.rendererDocx import RendererDocx
from modules.serviceCenter.services.serviceGeneration.renderers.rendererXlsx import RendererXlsx
from modules.serviceCenter.services.serviceGeneration.renderers.rendererPptx import RendererPptx
mime_map = {
"application/pdf": (RendererPdf, "application/pdf"),

View file

@@ -20,6 +20,8 @@ _NEUTRALIZATION_BLACKLIST = frozenset({
"Leistungen", "Basis", "Benefits", # Section labels
"Start", "Beginn", "Ende", "End", "trip", # Contract labels (Start of trip, End of trip, etc.)
"incomplete", "Application", "Complete", "Pending", # Form/status labels, not addresses
"Marketing", "Verkaufsstrategien", "Qualitätsmanagement", # Business terms, not addresses
"Ausbildungsstätte", "Realschule", # Institution types, not city names
# Ambiguous substrings match in Zurich, CHF, UID-Nr., websites, etc.
"CH", "DE", "FR", "IT", "Nr", "Nr.", "Nr:", "No", "No.", "No:",
"www", ".ch", ".com", ".org", ".net", "CHF",
@@ -95,6 +97,19 @@ class StringParser:
if not (m[0] == "address" and any(overlaps(m[2], m[3], ds, de) for ds, de in date_ranges))
]
# For name matches: resolve overlaps, keeping only the longest match, to avoid multiple placeholders for one name
# (e.g. "Ida", "Dittrich", "Ida Dittrich" → keep only "Ida Dittrich" with one UUID)
name_matches = [(m, m[3] - m[2]) for m in patternMatches if m[0] == "name"]
name_spans = [(m[2], m[3]) for m, _ in name_matches]
patternMatches = [
m for m in patternMatches
if m[0] != "name"
or not any(
overlaps(m[2], m[3], ns, ne) and (ne - ns) > (m[3] - m[2])
for (ns, ne) in name_spans
)
]
# Process from right to left to avoid position shifts
for patternName, matchedText, start, end in reversed(patternMatches):
# Skip if already a placeholder
@ -157,24 +172,72 @@ class StringParser:
expanded.add(f"{n1} {n2}")
expanded.add(f"{n2} {n1}")
# Process longest first so "Ida Dittrich" replaces before "Ida" or "Dittrich"
# One UUID per person: composites and their parts share same UUID
# Also align with DataPatterns mapping (step 1 may have already replaced "Ida Dittrich")
name_to_uuid: Dict[str, str] = {}
for composite in sorted(expanded, key=len, reverse=True):
if " " not in composite:
continue
parts = composite.split()
parts_set = frozenset(parts)
existing_uuid = next((name_to_uuid[p] for p in parts_set if p in name_to_uuid), None)
if existing_uuid is None:
existing_uuid = next(
(self.mapping[k] for k in (composite, *parts_set) if k in self.mapping),
None
)
if existing_uuid is None:
existing_uuid = f"[name.{uuid.uuid4()}]"
for p in parts_set:
name_to_uuid[p] = existing_uuid
name_to_uuid[composite] = existing_uuid
if len(parts) == 2:
name_to_uuid[f"{parts[1]} {parts[0]}"] = existing_uuid
for n in names:
if n not in name_to_uuid:
name_to_uuid[n] = self.mapping.get(n) or f"[name.{uuid.uuid4()}]"
self.mapping.update({k: v for k, v in name_to_uuid.items() if k not in self.mapping})
# Collect ALL matches from all name patterns, then keep only longest per span to avoid
# triple replacement ("Ida" + "Dittrich" + "Ida Dittrich" -> only "Ida Dittrich")
all_matches: List[Tuple[str, int, int]] = []
for name in sorted(expanded, key=len, reverse=True):
# Composite: flexible whitespace (space, newline); single: word boundaries
if " " in name:
parts = name.split()
pattern_str = r"\b" + r"\s+".join(re.escape(p) for p in parts) + r"\b"
else:
pattern_str = r"\b" + re.escape(name) + r"\b"
pattern = re.compile(pattern_str, re.IGNORECASE)
for m in pattern.finditer(text):
all_matches.append((m.group(), m.start(), m.end()))
matches = list(pattern.finditer(text))
for match in reversed(matches):
matchedText = match.group()
if matchedText not in self.mapping:
placeholderId = str(uuid.uuid4())
self.mapping[matchedText] = f"[name.{placeholderId}]"
replacement = self.mapping[matchedText]
start, end = match.span()
# Remove matches that overlap with a longer match (keep longest per span)
def _overlaps(s1, e1, s2, e2):
return s1 < e2 and s2 < e1
def _contained_in_longer(matched_text: str, start: int, end: int) -> bool:
for other_text, os, oe in all_matches:
if (os, oe) == (start, end):
continue
if _overlaps(start, end, os, oe) and (oe - os) > (end - start):
return True
return False
to_replace = [(t, s, e) for t, s, e in all_matches if not _contained_in_longer(t, s, e)]
to_replace = list({(s, e): (t, s, e) for t, s, e in to_replace}.values())
# Replace from right to left to avoid position shift
for matched_text, start, end in sorted(to_replace, key=lambda x: -x[1]):
normalized = " ".join(matched_text.split())
replacement = (
self.mapping.get(matched_text)
or self.mapping.get(normalized)
or next((v for k, v in self.mapping.items() if " ".join(k.split()) == normalized), None)
or next((v for k, v in self.mapping.items() if k.lower() == matched_text.lower()), None)
)
if not replacement:
replacement = f"[name.{uuid.uuid4()}]"
self.mapping[matched_text] = replacement
text = text[:start] + replacement + text[end:]
return text
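A minimal standalone sketch of the longest-span rule used above (hypothetical helper, not the module's API; the real code additionally shares one UUID between a composite name and its parts):
import re, uuid
def pick_longest_spans(text, names):
    # Collect all candidate matches, then drop any match contained in a longer overlapping one.
    matches = []
    for name in names:
        for m in re.finditer(r"\b" + re.escape(name) + r"\b", text, re.IGNORECASE):
            matches.append((m.group(), m.start(), m.end()))
    def shadowed(s, e):
        return any(os < e and s < oe and (oe - os) > (e - s) for _, os, oe in matches)
    kept = [(t, s, e) for t, s, e in matches if not shadowed(s, e)]
    # Replace from right to left so earlier spans keep their positions.
    mapping = {}
    for t, s, e in sorted(kept, key=lambda x: -x[1]):
        ph = mapping.setdefault(t, f"[name.{uuid.uuid4()}]")
        text = text[:s] + ph + text[e:]
    return text
print(pick_longest_spans("Brief an Ida Dittrich", ["Ida", "Dittrich", "Ida Dittrich"]))
# -> "Brief an [name.<uuid>]" (one placeholder, not three)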

View file

@ -307,14 +307,17 @@ class DataPatterns:
name="address",
patterns=[
# Full address block: company, street, postfach, postal+city (stop before domain like , AXA.ch)
r'\b[^,\n]+(?:,\s*[^,\n]+)*,\s*\d{4}\s+[A-Za-zäöüßÄÖÜ]+\s*(?=,\s*[a-zA-Z0-9.-]+\.(?:ch|com|org|net)\b|$)',
# Street + house number (standalone)
r'\b(?:[A-Za-zäöüßÄÖÜ]+(?:-[A-Za-zäöüßÄÖÜ]+)*(?:strasse|str\.|gasse|weg|platz|allee|boulevard|avenue|via|strada|rue|chemin|route))\s+\d{1,4}(?:/\d{1,4})?(?:[a-z])?\b',
# Supports Swiss PLZ (4 digits) and German PLZ (5 digits)
r'\b[^,\n]+(?:,\s*[^,\n]+)*,\s*\d{4,5}\s+[A-Za-zäöüßÄÖÜ]+\s*(?=,\s*[a-zA-Z0-9.-]+\.(?:ch|com|org|net)\b|$)',
# Street + house number (standalone); includes "straße" for German
r'\b(?:[A-Za-zäöüßÄÖÜ]+(?:-[A-Za-zäöüßÄÖÜ]+)*(?:straße|strasse|str\.|gasse|weg|platz|allee|boulevard|avenue|via|strada|rue|chemin|route))\s+\d{1,4}(?:/\d{1,4})?(?:[a-z])?\b',
# Postfach / PO Box (standalone)
r'\b(?:Postfach|Postbox|P\.?O\.?\s*Box|Case\s+postale|Casella\s+postale|Boîte\s+postale)\s+\d{1,6}\b',
# Postal code + city (standalone); exclude year+non-city and common non-city words
# (?<!\d{2}\.\d{2}\.) = not part of date DD.MM.YYYY (e.g. 27.01.2026)
r'(?<!\d{2}\.\d{2}\.)\b\d{4}\s+(?!den|der|die|das|dem|des|und|oder|für|bei|mit|Version|Versand|Vertrag|Verfügung|Verschickung|Versicherung|erhalten|Schreiben|Jahr|Jahres|incomplete|Application|Complete|Pending|Matrikel|Student|Studien|Kontakt|Telefon|Rechnung|Invoice)[A-Za-zäöüßÄÖÜ]+\b(?!\s*:)'
# Exclude business terms (Marketing, Qualitätsmanagement, etc.) that often follow years
# Swiss PLZ (4 digits) and German PLZ (5 digits)
r'(?<!\d{2}\.\d{2}\.)\b\d{4,5}\s+(?!den|der|die|das|dem|des|und|oder|für|bei|mit|Version|Versand|Vertrag|Verfügung|Verschickung|Versicherung|erhalten|Schreiben|Jahr|Jahres|incomplete|Application|Complete|Pending|Matrikel|Student|Studien|Kontakt|Telefon|Rechnung|Invoice|Marketing|Verkaufsstrategien|Qualitätsmanagement|Management|Strategien|Projektmanagement|Vertrieb|Vertriebsstrategien|Ausbildungsstätte|Realschule)[A-Za-zäöüßÄÖÜ]+\b(?!\s*:)'
],
replacement_template="[ADDRESS_{}]"
),
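A quick sanity check of the PLZ+city pattern above (sketch with a shortened exclusion list; the shipped regex carries the full alternation):
import re
plz_city = re.compile(r'(?<!\d{2}\.\d{2}\.)\b\d{4,5}\s+(?!Marketing|Jahr)[A-Za-zäöüßÄÖÜ]+\b(?!\s*:)')
print(bool(plz_city.search("8000 Zürich")))              # True: Swiss 4-digit PLZ + city
print(bool(plz_city.search("80331 München")))            # True: German 5-digit PLZ + city
print(bool(plz_city.search("am 27.01.2026 Marketing")))  # False: date prefix and business term excluded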

View file

@ -56,7 +56,16 @@ def neutralize_pdf_in_place(
logger.error(f"Failed to open PDF: {e}")
return None
sorted_items = sorted(mapping.items(), key=lambda x: -len(x[0]))
# For the same placeholder, search only the longest original_text to avoid a triple overlay
# (e.g. "Ida Dittrich", "Ida", "Dittrich" all map to [name.x] → only search "Ida Dittrich")
placeholder_to_longest: Dict[str, str] = {}
for orig, ph in mapping.items():
if not orig or not ph:
continue
if ph not in placeholder_to_longest or len(orig) > len(placeholder_to_longest[ph]):
placeholder_to_longest[ph] = orig
filtered = [(orig, ph) for ph, orig in placeholder_to_longest.items()]
sorted_items = sorted(filtered, key=lambda x: -len(x[0]))
fill_color = (1, 1, 1)
text_color = (0, 0, 0)
fontname = "helv"
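A standalone sketch of the dedup step above (mapping values are made up):
mapping = {"Ida Dittrich": "[name.x]", "Ida": "[name.x]", "Dittrich": "[name.x]", "Bern": "[ADDRESS_1]"}
placeholder_to_longest = {}
for orig, ph in mapping.items():
    if ph not in placeholder_to_longest or len(orig) > len(placeholder_to_longest[ph]):
        placeholder_to_longest[ph] = orig
print(placeholder_to_longest)
# -> {'[name.x]': 'Ida Dittrich', '[ADDRESS_1]': 'Bern'}  (only the longest alias is searched in the PDF)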

View file

@ -284,7 +284,7 @@ from .datamodelFeatureRealEstate import (
Land,
DokumentTyp,
)
from modules.services import getInterface as getServices
from modules.serviceHub import getInterface as getServices
from .interfaceFeatureRealEstate import getInterface as getRealEstateInterface
from modules.interfaces.interfaceDbManagement import getInterface as getComponentInterface
from modules.connectors.connectorSwissTopoMapServer import SwissTopoMapServerConnector

View file

@ -843,7 +843,7 @@ async def testVoice(
):
"""Test TTS voice with AI-generated sample text in the correct language."""
from modules.interfaces.interfaceVoiceObjects import getVoiceInterface
from modules.services.serviceAi.mainServiceAi import AiService
from modules.serviceCenter.services.serviceAi.mainServiceAi import AiService
from modules.datamodels.datamodelAi import AiCallRequest, AiCallOptions, OperationTypeEnum, PriorityEnum
mandateId = _validateInstanceAccess(instanceId, context)

View file

@ -1062,7 +1062,7 @@ class TeamsbotService:
# Call SPEECH_TEAMS
try:
from modules.services.serviceAi.mainServiceAi import AiService
from modules.serviceCenter.services.serviceAi.mainServiceAi import AiService
# Create minimal service context for AI billing
serviceContext = _ServiceContext(self.currentUser, self.mandateId, self.instanceId)
@ -1684,7 +1684,7 @@ class TeamsbotService:
"""Summarize a long user-provided session context to its essential points.
This reduces token usage in every subsequent AI call."""
try:
from modules.services.serviceAi.mainServiceAi import AiService
from modules.serviceCenter.services.serviceAi.mainServiceAi import AiService
from modules.datamodels.datamodelAi import AiCallRequest, AiCallOptions, OperationTypeEnum, PriorityEnum
serviceContext = _ServiceContext(self.currentUser, self.mandateId, self.instanceId)
@ -1738,7 +1738,7 @@ class TeamsbotService:
lines.append(f"[{speaker}]: {text}")
textToSummarize = "\n".join(lines)
from modules.services.serviceAi.mainServiceAi import AiService
from modules.serviceCenter.services.serviceAi.mainServiceAi import AiService
from modules.datamodels.datamodelAi import AiCallRequest, AiCallOptions, OperationTypeEnum, PriorityEnum
serviceContext = _ServiceContext(self.currentUser, self.mandateId, self.instanceId)
@ -1783,7 +1783,7 @@ class TeamsbotService:
for t in transcripts
)
from modules.services.serviceAi.mainServiceAi import AiService
from modules.serviceCenter.services.serviceAi.mainServiceAi import AiService
serviceContext = _ServiceContext(self.currentUser, self.mandateId, self.instanceId)
aiService = AiService(serviceCenter=serviceContext)

View file

@ -118,6 +118,13 @@ class BaseAccountingConnector(ABC):
"""Load the vendor list. Override in connectors that support it."""
return []
async def getJournalEntries(self, config: Dict[str, Any], dateFrom: Optional[str] = None, dateTo: Optional[str] = None, accountNumbers: Optional[List[str]] = None) -> List[Dict[str, Any]]:
"""Read journal entries from the external system. Each entry should contain:
- externalId, bookingDate, reference, description, currency, totalAmount
- lines: list of {accountNumber, debitAmount, creditAmount, currency, taxCode, costCenter, description}
accountNumbers: pre-fetched account numbers (avoids redundant API call). Override in connectors that support it."""
return []
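An illustrative return value for the contract documented above (all figures invented):
[{
    "externalId": "42", "bookingDate": "2025-01-31", "reference": "RE-2025-001",
    "description": "Monatsmiete Januar", "currency": "CHF", "totalAmount": 2500.0,
    "lines": [
        {"accountNumber": "6000", "debitAmount": 2500.0, "creditAmount": 0.0,
         "currency": "CHF", "taxCode": None, "costCenter": None, "description": "Miete"},
        {"accountNumber": "1020", "debitAmount": 0.0, "creditAmount": 2500.0,
         "currency": "CHF", "taxCode": None, "costCenter": None, "description": "Miete"},
    ],
}]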
async def uploadDocument(
self,
config: Dict[str, Any],

View file

@ -0,0 +1,306 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Orchestrates importing accounting data from external systems into TrusteeData* tables.
Flow: load config → resolve connector → fetch data → clear old records → write new records → compute balances.
"""
import logging
import time
from collections import defaultdict
from typing import Dict, Any, Optional
from .accountingConnectorBase import BaseAccountingConnector
from .accountingRegistry import _getAccountingRegistry
logger = logging.getLogger(__name__)
class AccountingDataSync:
"""Imports accounting data (read-only) from an external system into local TrusteeData* tables."""
def __init__(self, trusteeInterface):
self._if = trusteeInterface
self._registry = _getAccountingRegistry()
async def importData(
self,
featureInstanceId: str,
mandateId: str,
dateFrom: Optional[str] = None,
dateTo: Optional[str] = None,
) -> Dict[str, Any]:
"""Run a full data import for a feature instance.
Returns a summary dict with counts per entity and any errors.
"""
from modules.features.trustee.datamodelFeatureTrustee import (
TrusteeAccountingConfig,
TrusteeDataAccount,
TrusteeDataJournalEntry,
TrusteeDataJournalLine,
TrusteeDataContact,
TrusteeDataAccountBalance,
)
from modules.shared.configuration import decryptValue
summary: Dict[str, Any] = {
"accounts": 0,
"journalEntries": 0,
"journalLines": 0,
"contacts": 0,
"accountBalances": 0,
"errors": [],
"startedAt": time.time(),
}
cfgRecords = self._if.db.getRecordset(
TrusteeAccountingConfig,
recordFilter={"featureInstanceId": featureInstanceId, "isActive": True},
)
if not cfgRecords:
summary["errors"].append("No active accounting configuration found")
return summary
cfgRecord = cfgRecords[0]
connectorType = cfgRecord.get("connectorType", "")
encryptedConfig = cfgRecord.get("encryptedConfig", "")
try:
import json
plainJson = decryptValue(encryptedConfig)
connConfig = json.loads(plainJson) if plainJson else {}
except Exception as e:
summary["errors"].append(f"Failed to decrypt config: {e}")
return summary
connector = self._registry.getConnector(connectorType)
if not connector:
summary["errors"].append(f"Unknown connector type: {connectorType}")
return summary
scope = {"featureInstanceId": featureInstanceId, "mandateId": mandateId}
logger.info(f"AccountingDataSync starting for {featureInstanceId}, connector={connectorType}, dateFrom={dateFrom}, dateTo={dateTo}")
fetchedAccountNumbers: list = []
# 1) Chart of accounts
try:
charts = await connector.getChartOfAccounts(connConfig)
fetchedAccountNumbers = [acc.accountNumber for acc in charts if acc.accountNumber]
self._clearTable(TrusteeDataAccount, featureInstanceId)
for acc in charts:
self._if.db.recordCreate(TrusteeDataAccount, {
"accountNumber": acc.accountNumber,
"label": acc.label,
"accountType": acc.accountType or "",
"currency": "CHF",
"isActive": True,
**scope,
})
summary["accounts"] = len(charts)
except Exception as e:
logger.error(f"Import accounts failed: {e}", exc_info=True)
summary["errors"].append(f"Accounts: {e}")
# 2) Journal entries + lines (pass already-fetched chart to avoid redundant API call)
try:
rawEntries = await connector.getJournalEntries(connConfig, dateFrom=dateFrom, dateTo=dateTo, accountNumbers=fetchedAccountNumbers or None)
self._clearTable(TrusteeDataJournalEntry, featureInstanceId)
self._clearTable(TrusteeDataJournalLine, featureInstanceId)
lineCount = 0
import uuid
for raw in rawEntries:
entryId = str(uuid.uuid4())
self._if.db.recordCreate(TrusteeDataJournalEntry, {
"id": entryId,
"externalId": raw.get("externalId"),
"bookingDate": raw.get("bookingDate"),
"reference": raw.get("reference"),
"description": raw.get("description", ""),
"currency": raw.get("currency", "CHF"),
"totalAmount": float(raw.get("totalAmount", 0)),
**scope,
})
for line in (raw.get("lines") or []):
self._if.db.recordCreate(TrusteeDataJournalLine, {
"journalEntryId": entryId,
"accountNumber": line.get("accountNumber", ""),
"debitAmount": float(line.get("debitAmount", 0)),
"creditAmount": float(line.get("creditAmount", 0)),
"currency": line.get("currency", "CHF"),
"taxCode": line.get("taxCode"),
"costCenter": line.get("costCenter"),
"description": line.get("description", ""),
**scope,
})
lineCount += 1
summary["journalEntries"] = len(rawEntries)
summary["journalLines"] = lineCount
except Exception as e:
logger.error(f"Import journal entries failed: {e}")
summary["errors"].append(f"Journal entries: {e}")
# 3) Contacts (customers + vendors)
try:
self._clearTable(TrusteeDataContact, featureInstanceId)
contactCount = 0
customers = await connector.getCustomers(connConfig)
for c in customers:
self._if.db.recordCreate(TrusteeDataContact, self._mapContact(c, "customer", scope))
contactCount += 1
vendors = await connector.getVendors(connConfig)
for v in vendors:
self._if.db.recordCreate(TrusteeDataContact, self._mapContact(v, "vendor", scope))
contactCount += 1
summary["contacts"] = contactCount
except Exception as e:
logger.error(f"Import contacts failed: {e}", exc_info=True)
summary["errors"].append(f"Contacts: {e}")
# 4) Compute account balances from journal lines
try:
self._clearTable(TrusteeDataAccountBalance, featureInstanceId)
balanceCount = self._computeBalances(featureInstanceId, mandateId)
summary["accountBalances"] = balanceCount
except Exception as e:
logger.error(f"Compute balances failed: {e}")
summary["errors"].append(f"Balances: {e}")
# Update config with last import timestamp
try:
cfgId = cfgRecord.get("id")
if cfgId:
self._if.db.recordModify(TrusteeAccountingConfig, cfgId, {
"lastSyncAt": time.time(),
"lastSyncStatus": "success" if not summary["errors"] else "partial",
"lastSyncErrorMessage": "; ".join(summary["errors"])[:500] if summary["errors"] else None,
})
except Exception:
pass
summary["finishedAt"] = time.time()
summary["durationSeconds"] = round(summary["finishedAt"] - summary["startedAt"], 1)
logger.info(
f"AccountingDataSync completed for {featureInstanceId}: "
f"{summary['accounts']} accounts, {summary['journalEntries']} entries, "
f"{summary['journalLines']} lines, {summary['contacts']} contacts, "
f"{summary['accountBalances']} balances, {len(summary['errors'])} errors, "
f"{summary['durationSeconds']}s"
)
return summary
@staticmethod
def _safeStr(val: Any) -> str:
"""Convert a value to a safe string for DB storage, collapsing nested dicts/lists."""
if val is None:
return ""
if isinstance(val, (dict, list)):
return ""
return str(val)
def _mapContact(self, raw: Dict[str, Any], contactType: str, scope: Dict[str, Any]) -> Dict[str, Any]:
"""Extract contact fields from a raw API dict, handling varying field names across connectors."""
s = self._safeStr
return {
"externalId": s(raw.get("id") or raw.get("Id") or raw.get("customer_nr") or raw.get("vendor_nr") or ""),
"contactType": contactType,
"contactNumber": s(
raw.get("customernumber") or raw.get("customer_nr")
or raw.get("vendornumber") or raw.get("vendor_nr")
or raw.get("nr") or raw.get("ContactNumber")
or raw.get("id") or ""
),
"name": s(raw.get("name") or raw.get("Name") or raw.get("name_1") or ""),
"address": s(raw.get("addr1") or raw.get("address") or raw.get("Address") or ""),
"zip": s(raw.get("zipcode") or raw.get("postcode") or raw.get("Zip") or raw.get("zip") or ""),
"city": s(raw.get("city") or raw.get("City") or ""),
"country": s(raw.get("country") or raw.get("country_id") or raw.get("Country") or ""),
"email": s(raw.get("email") or raw.get("mail") or raw.get("Email") or ""),
"phone": s(raw.get("phone") or raw.get("phone_fixed") or raw.get("Phone") or ""),
"vatNumber": s(raw.get("vat_identifier") or raw.get("vatNumber") or ""),
**scope,
}
def _clearTable(self, model, featureInstanceId: str):
"""Delete all records for this feature instance from a TrusteeData* table."""
records = self._if.db.getRecordset(model, recordFilter={"featureInstanceId": featureInstanceId})
for r in (records or []):
rid = r.get("id") if isinstance(r, dict) else getattr(r, "id", None)
if rid:
try:
self._if.db.recordDelete(model, rid)
except Exception:
pass
def _computeBalances(self, featureInstanceId: str, mandateId: str) -> int:
"""Aggregate journal lines into monthly + annual account balances."""
from modules.features.trustee.datamodelFeatureTrustee import (
TrusteeDataJournalEntry,
TrusteeDataJournalLine,
TrusteeDataAccountBalance,
)
entries = self._if.db.getRecordset(
TrusteeDataJournalEntry,
recordFilter={"featureInstanceId": featureInstanceId},
) or []
entryDates = {}
for e in entries:
eid = e.get("id") if isinstance(e, dict) else getattr(e, "id", None)
bdate = e.get("bookingDate") if isinstance(e, dict) else getattr(e, "bookingDate", None)
if eid and bdate:
entryDates[eid] = bdate
lines = self._if.db.getRecordset(
TrusteeDataJournalLine,
recordFilter={"featureInstanceId": featureInstanceId},
) or []
# key: (accountNumber, year, month)
buckets: Dict[tuple, Dict[str, float]] = defaultdict(lambda: {"debit": 0.0, "credit": 0.0})
for ln in lines:
if isinstance(ln, dict):
jeid = ln.get("journalEntryId", "")
accNo = ln.get("accountNumber", "")
debit = float(ln.get("debitAmount", 0))
credit = float(ln.get("creditAmount", 0))
else:
jeid = getattr(ln, "journalEntryId", "")
accNo = getattr(ln, "accountNumber", "")
debit = float(getattr(ln, "debitAmount", 0))
credit = float(getattr(ln, "creditAmount", 0))
bdate = entryDates.get(jeid, "")
if not accNo or not bdate:
continue
parts = bdate.split("-")
if len(parts) < 2:
continue
year = int(parts[0])
month = int(parts[1])
buckets[(accNo, year, month)]["debit"] += debit
buckets[(accNo, year, month)]["credit"] += credit
buckets[(accNo, year, 0)]["debit"] += debit
buckets[(accNo, year, 0)]["credit"] += credit
count = 0
scope = {"featureInstanceId": featureInstanceId, "mandateId": mandateId}
for (accNo, year, month), totals in buckets.items():
closing = totals["debit"] - totals["credit"]
self._if.db.recordCreate(TrusteeDataAccountBalance, {
"accountNumber": accNo,
"periodYear": year,
"periodMonth": month,
"openingBalance": 0.0,
"debitTotal": round(totals["debit"], 2),
"creditTotal": round(totals["credit"], 2),
"closingBalance": round(closing, 2),
"currency": "CHF",
**scope,
})
count += 1
return count
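A hypothetical call site; how `trusteeInterface` is obtained depends on the feature container and is assumed here:
sync = AccountingDataSync(trusteeInterface)
summary = await sync.importData(
    featureInstanceId="fi-123",   # assumed IDs
    mandateId="m-1",
    dateFrom="2025-01-01",
    dateTo="2025-12-31",
)
print(summary["accounts"], summary["journalEntries"], summary["errors"])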

View file

@ -40,15 +40,15 @@ class AccountingConnectorAbacus(BaseAccountingConnector):
def getRequiredConfigFields(self) -> List[ConnectorConfigField]:
return [
ConnectorConfigField(
key="abacusHost",
label={"en": "Abacus Host URL", "de": "Abacus Host-URL", "fr": "URL Hôte Abacus"},
key="apiBaseUrl",
label={"en": "API Base URL", "de": "API Base URL", "fr": "URL de base API"},
fieldType="text",
secret=False,
placeholder="e.g. abacus.meinefirma.ch",
placeholder="e.g. https://abacus.meinefirma.ch/api/entity/v1/",
),
ConnectorConfigField(
key="mandant",
label={"en": "Mandant Number", "de": "Mandantennummer", "fr": "Numéro de mandant"},
key="clientName",
label={"en": "Client Name", "de": "Mandantenname", "fr": "Nom du client"},
fieldType="text",
secret=False,
placeholder="e.g. 7777",
@ -68,22 +68,37 @@ class AccountingConnectorAbacus(BaseAccountingConnector):
]
def _buildBaseUrl(self, config: Dict[str, Any]) -> str:
host = config["abacusHost"].rstrip("/")
if not host.startswith("http"):
host = f"https://{host}"
return host
apiBaseUrl = str(config.get("apiBaseUrl") or "").strip()
if not apiBaseUrl:
raise ValueError("Missing required config: apiBaseUrl")
if not apiBaseUrl.startswith("http"):
apiBaseUrl = f"https://{apiBaseUrl}"
return apiBaseUrl.rstrip("/")
def _buildAuthBaseUrl(self, config: Dict[str, Any]) -> str:
apiBaseUrl = str(config.get("apiBaseUrl") or "").strip()
if not apiBaseUrl:
raise ValueError("Missing required config: apiBaseUrl")
if not apiBaseUrl.startswith("http"):
apiBaseUrl = f"https://{apiBaseUrl}"
apiBaseUrl = apiBaseUrl.rstrip("/")
if "/api/entity/v1" in apiBaseUrl:
return apiBaseUrl.split("/api/entity/v1", 1)[0]
if "/api/" in apiBaseUrl:
return apiBaseUrl.split("/api/", 1)[0]
return apiBaseUrl
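Worked examples for the two URL builders above (hypothetical connector instance):
# _buildBaseUrl({"apiBaseUrl": "abacus.meinefirma.ch/api/entity/v1/"})
#   -> "https://abacus.meinefirma.ch/api/entity/v1"
# _buildAuthBaseUrl({"apiBaseUrl": "https://abacus.meinefirma.ch/api/entity/v1/"})
#   -> "https://abacus.meinefirma.ch"   (entity path stripped for the OAuth token endpoint)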
async def _getAccessToken(self, config: Dict[str, Any]) -> Optional[str]:
"""Obtain an OAuth access token using client_credentials grant.
Tokens are cached and refreshed when expired (default 600s).
"""
cacheKey = f"{config.get('abacusHost')}_{config.get('clientId')}"
cacheKey = f"{config.get('apiBaseUrl')}_{config.get('clientName')}_{config.get('clientId')}"
cached = self._tokenCache.get(cacheKey)
if cached and cached.get("expiresAt", 0) > time.time() + 30:
return cached["accessToken"]
baseUrl = self._buildBaseUrl(config)
baseUrl = self._buildAuthBaseUrl(config)
try:
async with aiohttp.ClientSession() as session:
@ -120,8 +135,10 @@ class AccountingConnectorAbacus(BaseAccountingConnector):
def _buildEntityUrl(self, config: Dict[str, Any], entity: str) -> str:
baseUrl = self._buildBaseUrl(config)
mandant = config["mandant"]
return f"{baseUrl}/api/entity/v1/{mandant}/{entity}"
clientName = config.get("clientName")
if not clientName:
raise ValueError("Missing required config: clientName")
return f"{baseUrl}/{clientName}/{entity}"
async def _buildAuthHeaders(self, config: Dict[str, Any]) -> Optional[Dict[str, str]]:
token = await self._getAccessToken(config)
@ -130,6 +147,19 @@ class AccountingConnectorAbacus(BaseAccountingConnector):
return {"Authorization": f"Bearer {token}", "Accept": "application/json", "Content-Type": "application/json"}
async def testConnection(self, config: Dict[str, Any]) -> SyncResult:
apiBaseUrl = str(config.get("apiBaseUrl") or "")
clientName = str(config.get("clientName") or "")
clientId = str(config.get("clientId") or "")
clientSecret = str(config.get("clientSecret") or "")
if not apiBaseUrl or not clientName or not clientId or not clientSecret:
return SyncResult(
success=False,
errorMessage=(
f"Missing credentials: apiBaseUrl={bool(apiBaseUrl)}, "
f"clientName={bool(clientName)}, clientId={bool(clientId)}, "
f"clientSecret={bool(clientSecret)}"
),
)
headers = await self._buildAuthHeaders(config)
if not headers:
return SyncResult(success=False, errorMessage="Failed to obtain access token")
@ -225,6 +255,60 @@ class AccountingConnectorAbacus(BaseAccountingConnector):
except Exception as e:
return SyncResult(success=False, errorMessage=str(e))
async def getJournalEntries(self, config: Dict[str, Any], dateFrom: Optional[str] = None, dateTo: Optional[str] = None, accountNumbers: Optional[List[str]] = None) -> List[Dict[str, Any]]:
"""Read GeneralJournalEntries from Abacus (OData V4, paginated)."""
headers = await self._buildAuthHeaders(config)
if not headers:
return []
filterParts = []
if dateFrom:
filterParts.append(f"JournalDate ge {dateFrom}")
if dateTo:
filterParts.append(f"JournalDate le {dateTo}")
queryParams = ""
if filterParts:
queryParams = "?$filter=" + " and ".join(filterParts)
entries: List[Dict[str, Any]] = []
url: Optional[str] = self._buildEntityUrl(config, f"GeneralJournalEntries{queryParams}")
try:
async with aiohttp.ClientSession() as session:
while url:
async with session.get(url, headers=headers, timeout=aiohttp.ClientTimeout(total=60)) as resp:
if resp.status != 200:
break
data = await resp.json()
for item in data.get("value", []):
lines = []
totalAmt = 0.0
for line in (item.get("Lines") or []):
debit = float(line.get("DebitAmount", 0))
credit = float(line.get("CreditAmount", 0))
lines.append({
"accountNumber": str(line.get("AccountId", "")),
"debitAmount": debit,
"creditAmount": credit,
"description": line.get("Text", ""),
"taxCode": line.get("TaxCode"),
"costCenter": line.get("CostCenterId"),
})
totalAmt += max(debit, credit)
entries.append({
"externalId": str(item.get("Id", "")),
"bookingDate": str(item.get("JournalDate", "")).split("T")[0],
"reference": item.get("Reference", ""),
"description": item.get("Text", ""),
"currency": "CHF",
"totalAmount": totalAmt,
"lines": lines,
})
url = data.get("@odata.nextLink")
except Exception as e:
logger.error(f"Abacus getJournalEntries error: {e}")
return entries
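For reference, the first-page URL produced above (clientName "7777" and both date bounds assumed):
# https://.../api/entity/v1/7777/GeneralJournalEntries?$filter=JournalDate ge 2025-01-01 and JournalDate le 2025-12-31
# Subsequent pages follow @odata.nextLink until the field is absent.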
async def getCustomers(self, config: Dict[str, Any]) -> List[Dict[str, Any]]:
headers = await self._buildAuthHeaders(config)
if not headers:

View file

@ -24,7 +24,7 @@ from ..accountingConnectorBase import (
logger = logging.getLogger(__name__)
_BASE_URL = "https://api.bexio.com"
_DEFAULT_API_BASE_URL = "https://api.bexio.com/"
class AccountingConnectorBexio(BaseAccountingConnector):
@ -40,6 +40,20 @@ class AccountingConnectorBexio(BaseAccountingConnector):
def getRequiredConfigFields(self) -> List[ConnectorConfigField]:
return [
ConnectorConfigField(
key="apiBaseUrl",
label={"en": "API Base URL", "de": "API Base URL", "fr": "URL de base API"},
fieldType="text",
secret=False,
placeholder="https://api.bexio.com/",
),
ConnectorConfigField(
key="clientName",
label={"en": "Client Name", "de": "Mandantenname", "fr": "Nom du client"},
fieldType="text",
secret=False,
placeholder="e.g. poweronag",
),
ConnectorConfigField(
key="accessToken",
label={"en": "Personal Access Token", "de": "Persönlicher Zugriffstoken", "fr": "Jeton d'accès personnel"},
@ -49,6 +63,14 @@ class AccountingConnectorBexio(BaseAccountingConnector):
),
]
def _buildUrl(self, config: Dict[str, Any], resource: str) -> str:
apiBaseUrl = str(config.get("apiBaseUrl") or "").strip()
if not apiBaseUrl:
raise ValueError("Missing required config: apiBaseUrl")
apiBaseUrl = apiBaseUrl.rstrip("/")
resourcePath = resource.lstrip("/")
return f"{apiBaseUrl}/{resourcePath}"
def _buildHeaders(self, config: Dict[str, Any]) -> Dict[str, str]:
return {
"Authorization": f"Bearer {config['accessToken']}",
@ -57,9 +79,20 @@ class AccountingConnectorBexio(BaseAccountingConnector):
}
async def testConnection(self, config: Dict[str, Any]) -> SyncResult:
apiBaseUrl = str(config.get("apiBaseUrl") or "")
clientName = str(config.get("clientName") or "")
accessToken = str(config.get("accessToken") or "")
if not apiBaseUrl or not clientName or not accessToken:
return SyncResult(
success=False,
errorMessage=(
f"Missing credentials: apiBaseUrl={bool(apiBaseUrl)}, "
f"clientName={bool(clientName)}, accessToken={bool(accessToken)}"
),
)
try:
async with aiohttp.ClientSession() as session:
async with session.get(f"{_BASE_URL}/3.0/users/me", headers=self._buildHeaders(config), timeout=aiohttp.ClientTimeout(total=15)) as resp:
async with session.get(self._buildUrl(config, "3.0/users/me"), headers=self._buildHeaders(config), timeout=aiohttp.ClientTimeout(total=15)) as resp:
if resp.status == 200:
return SyncResult(success=True)
body = await resp.text()
@ -75,7 +108,7 @@ class AccountingConnectorBexio(BaseAccountingConnector):
try:
async with aiohttp.ClientSession() as session:
async with session.get(f"{_BASE_URL}/2.0/accounts", headers=self._buildHeaders(config), timeout=aiohttp.ClientTimeout(total=30)) as resp:
async with session.get(self._buildUrl(config, "2.0/accounts"), headers=self._buildHeaders(config), timeout=aiohttp.ClientTimeout(total=30)) as resp:
if resp.status != 200:
return []
accounts = await resp.json()
@ -139,7 +172,7 @@ class AccountingConnectorBexio(BaseAccountingConnector):
}
async with aiohttp.ClientSession() as session:
url = f"{_BASE_URL}/3.0/accounting/manual-entries"
url = self._buildUrl(config, "3.0/accounting/manual-entries")
async with session.post(url, headers=self._buildHeaders(config), json=payload, timeout=aiohttp.ClientTimeout(total=30)) as resp:
body = await resp.json() if resp.content_type == "application/json" else {"raw": await resp.text()}
if resp.status in (200, 201):
@ -152,7 +185,7 @@ class AccountingConnectorBexio(BaseAccountingConnector):
async def getBookingStatus(self, config: Dict[str, Any], externalId: str) -> SyncResult:
try:
async with aiohttp.ClientSession() as session:
url = f"{_BASE_URL}/3.0/accounting/manual-entries/{externalId}"
url = self._buildUrl(config, f"3.0/accounting/manual-entries/{externalId}")
async with session.get(url, headers=self._buildHeaders(config), timeout=aiohttp.ClientTimeout(total=15)) as resp:
if resp.status == 200:
return SyncResult(success=True, externalId=externalId)
@ -160,10 +193,66 @@ class AccountingConnectorBexio(BaseAccountingConnector):
except Exception as e:
return SyncResult(success=False, errorMessage=str(e))
async def getJournalEntries(self, config: Dict[str, Any], dateFrom: Optional[str] = None, dateTo: Optional[str] = None, accountNumbers: Optional[List[str]] = None) -> List[Dict[str, Any]]:
"""Read manual entries from Bexio. API: GET 3.0/accounting/manual-entries"""
try:
accounts = await self._loadRawAccounts(config)
accMap = {acc.get("id"): str(acc.get("account_no", "")) for acc in accounts}
async with aiohttp.ClientSession() as session:
url = self._buildUrl(config, "3.0/accounting/manual-entries")
params: Dict[str, str] = {}
if dateFrom:
params["date_from"] = dateFrom
if dateTo:
params["date_to"] = dateTo
async with session.get(url, headers=self._buildHeaders(config), params=params, timeout=aiohttp.ClientTimeout(total=60)) as resp:
if resp.status != 200:
logger.error(f"Bexio getJournalEntries failed: HTTP {resp.status}")
return []
items = await resp.json()
entries = []
for item in (items if isinstance(items, list) else []):
lines = []
totalAmt = 0.0
for e in (item.get("entries") or []):
amt = float(e.get("amount", 0))
debitAccId = e.get("debit_account_id")
creditAccId = e.get("credit_account_id")
lines.append({
"accountNumber": accMap.get(debitAccId, str(debitAccId or "")),
"debitAmount": amt,
"creditAmount": 0.0,
"description": e.get("description", ""),
"taxCode": str(e.get("tax_id", "")) if e.get("tax_id") else None,
})
if creditAccId and creditAccId != debitAccId:
lines.append({
"accountNumber": accMap.get(creditAccId, str(creditAccId or "")),
"debitAmount": 0.0,
"creditAmount": amt,
"description": e.get("description", ""),
})
totalAmt += amt
entries.append({
"externalId": str(item.get("id", "")),
"bookingDate": item.get("date", ""),
"reference": item.get("reference_nr", ""),
"description": item.get("text", ""),
"currency": "CHF",
"totalAmount": totalAmt,
"lines": lines,
})
return entries
except Exception as e:
logger.error(f"Bexio getJournalEntries error: {e}")
return []
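Worked example of the debit/credit split above (account IDs and amounts invented):
# Input line:  {"amount": 150.0, "debit_account_id": 7, "credit_account_id": 3}
# With accMap {7: "6000", 3: "1020"} this becomes two journal lines:
#   {"accountNumber": "6000", "debitAmount": 150.0, "creditAmount": 0.0, ...}
#   {"accountNumber": "1020", "debitAmount": 0.0, "creditAmount": 150.0, ...}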
async def getCustomers(self, config: Dict[str, Any]) -> List[Dict[str, Any]]:
try:
async with aiohttp.ClientSession() as session:
async with session.get(f"{_BASE_URL}/2.0/contact", headers=self._buildHeaders(config), timeout=aiohttp.ClientTimeout(total=30)) as resp:
async with session.get(self._buildUrl(config, "2.0/contact"), headers=self._buildHeaders(config), timeout=aiohttp.ClientTimeout(total=30)) as resp:
if resp.status != 200:
return []
return await resp.json()

View file

@ -3,7 +3,8 @@
"""Run My Accounts (Infoniqa) accounting connector.
API docs: https://runmyaccountsag.github.io/runmyaccounts-rest-api/
Auth: API key (incl. ``pat_`` tokens since Sep 2025) via ``X-RMA-KEY`` request header.
Auth: PAT tokens (``pat_...``) via ``Authorization: Bearer``.
Fallback for legacy API keys via ``X-RMA-KEY``.
Base URL: https://service.runmyaccounts.com/api/latest/clients/{clientName}/
"""
@ -26,7 +27,7 @@ from ..accountingConnectorBase import (
logger = logging.getLogger(__name__)
_BASE_URL = "https://service.runmyaccounts.com/api/latest/clients"
_DEFAULT_API_BASE_URL = "https://service.runmyaccounts.com/api/latest/clients/"
class AccountingConnectorRma(BaseAccountingConnector):
@ -39,6 +40,13 @@ class AccountingConnectorRma(BaseAccountingConnector):
def getRequiredConfigFields(self) -> List[ConnectorConfigField]:
return [
ConnectorConfigField(
key="apiBaseUrl",
label={"en": "API Base URL", "de": "API Base URL", "fr": "URL de base API"},
fieldType="text",
secret=False,
placeholder="https://service.runmyaccounts.com/api/latest/clients/",
),
ConnectorConfigField(
key="clientName",
label={"en": "Client Name", "de": "Mandantenname", "fr": "Nom du client"},
@ -55,33 +63,55 @@ class AccountingConnectorRma(BaseAccountingConnector):
]
def _buildUrl(self, config: Dict[str, Any], resource: str) -> str:
clientName = config.get("clientName", "")
return f"{_BASE_URL}/{clientName}/{resource}"
apiBaseUrl = str(config.get("apiBaseUrl") or "").strip()
if not apiBaseUrl:
raise ValueError("Missing required config: apiBaseUrl")
apiBaseUrl = apiBaseUrl.rstrip("/") + "/"
clientName = str(config.get("clientName") or "").strip()
if not clientName:
raise ValueError("Missing required config: clientName")
return f"{apiBaseUrl}{clientName}/{resource}"
def _buildHeaders(self, config: Dict[str, Any]) -> Dict[str, str]:
apiKey = config.get("apiKey", "")
return {
"X-RMA-KEY": apiKey,
headers = {
"Accept": "application/json, application/xml, */*",
"Content-Type": "application/json",
}
if str(apiKey).startswith("pat_"):
headers["Authorization"] = f"Bearer {apiKey}"
else:
headers["X-RMA-KEY"] = apiKey
return headers
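Header selection per the code above (keys invented):
# apiKey = "pat_abc123"  -> {"Authorization": "Bearer pat_abc123", "Accept": ..., "Content-Type": ...}
# apiKey = "legacy-key"  -> {"X-RMA-KEY": "legacy-key", "Accept": ..., "Content-Type": ...}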
async def testConnection(self, config: Dict[str, Any]) -> SyncResult:
clientName = config.get("clientName", "")
apiKey = config.get("apiKey", "")
if not clientName or not apiKey:
return SyncResult(success=False, errorMessage=f"Missing credentials: clientName={bool(clientName)}, apiKey={bool(apiKey)}")
apiBaseUrl = str(config.get("apiBaseUrl") or "")
if not clientName or not apiKey or not apiBaseUrl:
return SyncResult(
success=False,
errorMessage=(
f"Missing credentials: apiBaseUrl={bool(apiBaseUrl)}, "
f"clientName={bool(clientName)}, apiKey={bool(apiKey)}"
),
)
url = self._buildUrl(config, "customers")
headers = self._buildHeaders(config)
logger.info("RMA testConnection: url=%s, clientName=%s, apiKey=%s...", url, clientName, apiKey[:6] if len(apiKey) > 6 else "***")
authMethod = "Bearer" if str(apiKey).startswith("pat_") else "X-RMA-KEY"
logger.info(
"RMA testConnection: url=%s, clientName=%s, apiKey=%s..., auth=%s",
url, clientName, apiKey[:6] if len(apiKey) > 6 else "***", authMethod,
)
try:
async with aiohttp.ClientSession() as session:
async with session.get(url, headers=headers, timeout=aiohttp.ClientTimeout(total=15)) as resp:
if resp.status == 200:
logger.info("RMA connection successful")
return SyncResult(success=True)
logger.info("RMA connection successful with auth method: %s", authMethod)
return SyncResult(success=True, rawResponse={"authMethod": authMethod})
body = await resp.text()
logger.warning("RMA testConnection failed: status=%s, url=%s, body=%s", resp.status, url, body[:500])
return SyncResult(success=False, errorMessage=f"HTTP {resp.status}: {body[:300]}")
@ -120,11 +150,11 @@ class AccountingConnectorRma(BaseAccountingConnector):
charts = []
items = data if isinstance(data, list) else data.get("chart", data.get("row", []))
if not isinstance(items, list):
items = []
items = [items] if isinstance(items, dict) else []
for item in items:
if isinstance(item, dict):
accNo = str(item.get("accno", item.get("account_number", "")))
label = str(item.get("description", item.get("label", "")))
accNo = str(item.get("accno") or item.get("account_number") or item.get("number") or item.get("@accno") or "")
label = str(item.get("description") or item.get("label") or item.get("@description") or "")
rmaLink = item.get("link") or ""
chartType = item.get("charttype") or item.get("category") or ""
if not chartType and rmaLink:
@ -308,6 +338,169 @@ class AccountingConnectorRma(BaseAccountingConnector):
logger.debug("RMA isBookingSynced error: %s trust local", e)
return SyncResult(success=True)
async def getJournalEntries(self, config: Dict[str, Any], dateFrom: Optional[str] = None, dateTo: Optional[str] = None, accountNumbers: Optional[List[str]] = None) -> List[Dict[str, Any]]:
"""Read GL entries from RMA.
Strategy: first try GET /gl (bulk), then fall back to iterating
account transactions. Uses pre-fetched accountNumbers if provided.
"""
try:
params: Dict[str, str] = {}
if dateFrom:
params["from_date"] = dateFrom
if dateTo:
params["to_date"] = dateTo
# Try bulk GL endpoint first
bulkEntries = await self._fetchGlBulk(config, params)
if bulkEntries:
return bulkEntries
# Fallback: iterate accounts and fetch transactions
if accountNumbers:
accNums = accountNumbers
else:
chart = await self.getChartOfAccounts(config)
accNums = [acc.accountNumber for acc in chart if acc.accountNumber]
if not accNums:
return []
entriesByRef: Dict[str, Dict[str, Any]] = {}
fetchedCount = 0
emptyCount = 0
errorCount = 0
async with aiohttp.ClientSession() as session:
for accNo in accNums:
url = self._buildUrl(config, f"charts/{accNo}/transactions")
try:
async with session.get(url, headers=self._buildHeaders(config), params=params, timeout=aiohttp.ClientTimeout(total=10)) as resp:
if resp.status != 200:
emptyCount += 1
continue
body = await resp.text()
if not body.strip():
emptyCount += 1
continue
try:
data = json.loads(body)
except Exception:
errorCount += 1
continue
except Exception:  # also covers asyncio.TimeoutError
errorCount += 1
continue
fetchedCount += 1
if isinstance(data, dict):
transactions = data.get("transaction") or data.get("@transaction")
else:
transactions = data
if isinstance(transactions, dict):
transactions = [transactions]
if not isinstance(transactions, list):
continue
for t in transactions:
if not isinstance(t, dict):
continue
ref = t.get("reference") or t.get("@reference") or t.get("batch_number") or str(t.get("id") or "")
transDate = str(t.get("transdate") or t.get("@transdate") or "").split("T")[0]
desc = t.get("description") or t.get("memo") or t.get("@description") or ""
rawAmount = float(t.get("amount") or t.get("@amount") or 0)
debit = rawAmount if rawAmount > 0 else 0.0
credit = abs(rawAmount) if rawAmount < 0 else 0.0
if ref not in entriesByRef:
entriesByRef[ref] = {
"externalId": str(t.get("id") or t.get("@id") or ref),
"bookingDate": transDate,
"reference": ref,
"description": desc,
"currency": "CHF",
"totalAmount": 0.0,
"lines": [],
}
entry = entriesByRef[ref]
entry["lines"].append({
"accountNumber": accNo,
"debitAmount": debit,
"creditAmount": credit,
"description": desc,
})
entry["totalAmount"] += max(debit, credit)
return list(entriesByRef.values())
except Exception as e:
logger.error(f"RMA getJournalEntries error: {e}", exc_info=True)
return []
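Grouping sketch for the per-account fallback above (reference and amounts invented):
# Two transactions sharing reference "B-7", (accno 6000, amount 250.0) and (accno 1020, amount -250.0),
# merge into one entry with a debit line on 6000 and a credit line on 1020;
# totalAmount accumulates max(debit, credit) per line, so the balanced pair yields 500.0.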
async def _fetchGlBulk(self, config: Dict[str, Any], params: Dict[str, str]) -> List[Dict[str, Any]]:
"""Try GET /gl to fetch journal entries in bulk (not all RMA versions support this)."""
try:
async with aiohttp.ClientSession() as session:
url = self._buildUrl(config, "gl")
async with session.get(url, headers=self._buildHeaders(config), params=params, timeout=aiohttp.ClientTimeout(total=60)) as resp:
if resp.status != 200:
return []
body = await resp.text()
if not body.strip():
return []
try:
data = json.loads(body)
except Exception:
return []
items = data if isinstance(data, list) else (data.get("gl_batch") or data.get("gl") or data.get("items") or [])
if isinstance(items, dict):
items = [items]
if not isinstance(items, list):
return []
entries = []
for batch in items:
if not isinstance(batch, dict):
continue
transDate = str(batch.get("transdate") or batch.get("date") or "").split("T")[0]
ref = batch.get("batch_number") or batch.get("reference") or str(batch.get("id", ""))
desc = batch.get("description") or batch.get("notes") or ""
rawTxns = batch.get("gl_transactions", {})
txnList = rawTxns.get("gl_transaction") if isinstance(rawTxns, dict) else rawTxns
if isinstance(txnList, dict):
txnList = [txnList]
if not isinstance(txnList, list):
txnList = []
lines = []
totalAmt = 0.0
for t in txnList:
if not isinstance(t, dict):
continue
debit = float(t.get("debit_amount") or 0)
credit = float(t.get("credit_amount") or 0)
lines.append({
"accountNumber": str(t.get("accno", "")),
"debitAmount": debit,
"creditAmount": credit,
"description": t.get("memo", ""),
})
totalAmt += max(debit, credit)
entries.append({
"externalId": str(batch.get("id", ref)),
"bookingDate": transDate,
"reference": ref,
"description": desc,
"currency": batch.get("currency", "CHF"),
"totalAmount": totalAmt,
"lines": lines,
})
return entries
except Exception as e:
logger.debug(f"RMA _fetchGlBulk not available: {e}")
return []
async def pushInvoice(self, config: Dict[str, Any], invoice: Dict[str, Any]) -> SyncResult:
try:
async with aiohttp.ClientSession() as session:
@ -327,8 +520,8 @@ class AccountingConnectorRma(BaseAccountingConnector):
async with session.get(url, headers=self._buildHeaders(config), timeout=aiohttp.ClientTimeout(total=30)) as resp:
if resp.status != 200:
return []
data = await resp.json()
return data if isinstance(data, list) else data.get("customer", [])
data = await self._parseJsonOrXmlList(resp, "customer")
return data
except Exception as e:
logger.error(f"RMA getCustomers error: {e}")
return []
@ -340,12 +533,39 @@ class AccountingConnectorRma(BaseAccountingConnector):
async with session.get(url, headers=self._buildHeaders(config), timeout=aiohttp.ClientTimeout(total=30)) as resp:
if resp.status != 200:
return []
data = await resp.json()
return data if isinstance(data, list) else data.get("vendor", [])
data = await self._parseJsonOrXmlList(resp, "vendor")
return data
except Exception as e:
logger.error(f"RMA getVendors error: {e}")
return []
async def _parseJsonOrXmlList(self, resp: aiohttp.ClientResponse, itemKey: str) -> List[Dict[str, Any]]:
"""Parse RMA response that may be JSON or XML. Returns list of dicts."""
body = await resp.text()
if not body or not body.strip():
return []
try:
data = json.loads(body)
if isinstance(data, list):
return data
if isinstance(data, dict):
items = data.get(itemKey) or data.get("items") or data.get("row") or []
if isinstance(items, dict):
return [items]
return items if isinstance(items, list) else []
return []
except (json.JSONDecodeError, ValueError):
pass
result: List[Dict[str, Any]] = []
ids = re.findall(r"<id>([^<]+)</id>", body)
names = re.findall(r"<name>([^<]+)</name>", body)
for i, rid in enumerate(ids):
entry: Dict[str, Any] = {"id": rid.strip()}
if i < len(names):
entry["name"] = names[i].strip()
result.append(entry)
return result
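Fallback behaviour on an XML body (sketch; only <id> and <name> tags are recovered):
# body = "<customers><customer><id>7</id><name>ACME AG</name></customer></customers>"
# _parseJsonOrXmlList(resp, "customer") -> [{"id": "7", "name": "ACME AG"}]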
async def _findBelegByFilename(self, config: Dict[str, Any], session: aiohttp.ClientSession, fileName: str) -> Optional[str]:
"""Try GET /belege (undocumented) to find an existing beleg by filename."""
try:

View file

@ -736,6 +736,177 @@ registerModelLabels(
)
# ── TrusteeData* tables (synced from external accounting apps for analysis) ──
class TrusteeDataAccount(BaseModel):
"""Chart of accounts synced from external accounting system."""
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
accountNumber: str = Field(description="Account number (e.g. '1020')")
label: str = Field(default="", description="Account name")
accountType: Optional[str] = Field(default=None, description="asset / liability / equity / revenue / expense")
accountGroup: Optional[str] = Field(default=None, description="Account group/category")
currency: str = Field(default="CHF", description="Account currency")
isActive: bool = Field(default=True)
mandateId: Optional[str] = Field(default=None)
featureInstanceId: Optional[str] = Field(default=None)
registerModelLabels(
"TrusteeDataAccount",
{"en": "Account (Synced)", "de": "Konto (Sync)", "fr": "Compte (Sync)"},
{
"id": {"en": "ID", "de": "ID", "fr": "ID"},
"accountNumber": {"en": "Account Number", "de": "Kontonummer", "fr": "Numéro de compte"},
"label": {"en": "Name", "de": "Bezeichnung", "fr": "Libellé"},
"accountType": {"en": "Type", "de": "Typ", "fr": "Type"},
"accountGroup": {"en": "Group", "de": "Gruppe", "fr": "Groupe"},
"currency": {"en": "Currency", "de": "Währung", "fr": "Devise"},
"isActive": {"en": "Active", "de": "Aktiv", "fr": "Actif"},
"mandateId": {"en": "Mandate", "de": "Mandat", "fr": "Mandat"},
"featureInstanceId": {"en": "Feature Instance", "de": "Feature-Instanz", "fr": "Instance"},
},
)
class TrusteeDataJournalEntry(BaseModel):
"""Journal entry header synced from external accounting system."""
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
externalId: Optional[str] = Field(default=None, description="ID in the source system")
bookingDate: Optional[str] = Field(default=None, description="Booking date (YYYY-MM-DD)")
reference: Optional[str] = Field(default=None, description="Booking reference / voucher number")
description: str = Field(default="", description="Booking text")
currency: str = Field(default="CHF")
totalAmount: float = Field(default=0.0, description="Total amount of entry")
mandateId: Optional[str] = Field(default=None)
featureInstanceId: Optional[str] = Field(default=None)
registerModelLabels(
"TrusteeDataJournalEntry",
{"en": "Journal Entry (Synced)", "de": "Buchung (Sync)", "fr": "Écriture (Sync)"},
{
"id": {"en": "ID", "de": "ID", "fr": "ID"},
"externalId": {"en": "External ID", "de": "Externe ID", "fr": "ID externe"},
"bookingDate": {"en": "Date", "de": "Datum", "fr": "Date"},
"reference": {"en": "Reference", "de": "Referenz", "fr": "Référence"},
"description": {"en": "Description", "de": "Beschreibung", "fr": "Description"},
"currency": {"en": "Currency", "de": "Währung", "fr": "Devise"},
"totalAmount": {"en": "Amount", "de": "Betrag", "fr": "Montant"},
"mandateId": {"en": "Mandate", "de": "Mandat", "fr": "Mandat"},
"featureInstanceId": {"en": "Feature Instance", "de": "Feature-Instanz", "fr": "Instance"},
},
)
class TrusteeDataJournalLine(BaseModel):
"""Journal entry line (debit/credit) synced from external accounting system."""
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
journalEntryId: str = Field(description="FK → TrusteeDataJournalEntry.id")
accountNumber: str = Field(description="Account number")
debitAmount: float = Field(default=0.0)
creditAmount: float = Field(default=0.0)
currency: str = Field(default="CHF")
taxCode: Optional[str] = Field(default=None)
costCenter: Optional[str] = Field(default=None)
description: str = Field(default="")
mandateId: Optional[str] = Field(default=None)
featureInstanceId: Optional[str] = Field(default=None)
registerModelLabels(
"TrusteeDataJournalLine",
{"en": "Journal Line (Synced)", "de": "Buchungszeile (Sync)", "fr": "Ligne écriture (Sync)"},
{
"id": {"en": "ID", "de": "ID", "fr": "ID"},
"journalEntryId": {"en": "Journal Entry", "de": "Buchung", "fr": "Écriture"},
"accountNumber": {"en": "Account", "de": "Konto", "fr": "Compte"},
"debitAmount": {"en": "Debit", "de": "Soll", "fr": "Débit"},
"creditAmount": {"en": "Credit", "de": "Haben", "fr": "Crédit"},
"currency": {"en": "Currency", "de": "Währung", "fr": "Devise"},
"taxCode": {"en": "Tax Code", "de": "Steuercode", "fr": "Code TVA"},
"costCenter": {"en": "Cost Center", "de": "Kostenstelle", "fr": "Centre de coûts"},
"description": {"en": "Description", "de": "Beschreibung", "fr": "Description"},
"mandateId": {"en": "Mandate", "de": "Mandat", "fr": "Mandat"},
"featureInstanceId": {"en": "Feature Instance", "de": "Feature-Instanz", "fr": "Instance"},
},
)
class TrusteeDataContact(BaseModel):
"""Customer or vendor synced from external accounting system."""
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
externalId: Optional[str] = Field(default=None, description="ID in the source system")
contactType: str = Field(default="customer", description="customer / vendor / both")
contactNumber: Optional[str] = Field(default=None, description="Customer/vendor number")
name: str = Field(default="", description="Name / company")
address: Optional[str] = Field(default=None)
zip: Optional[str] = Field(default=None)
city: Optional[str] = Field(default=None)
country: Optional[str] = Field(default=None)
email: Optional[str] = Field(default=None)
phone: Optional[str] = Field(default=None)
vatNumber: Optional[str] = Field(default=None)
mandateId: Optional[str] = Field(default=None)
featureInstanceId: Optional[str] = Field(default=None)
registerModelLabels(
"TrusteeDataContact",
{"en": "Contact (Synced)", "de": "Kontakt (Sync)", "fr": "Contact (Sync)"},
{
"id": {"en": "ID", "de": "ID", "fr": "ID"},
"externalId": {"en": "External ID", "de": "Externe ID", "fr": "ID externe"},
"contactType": {"en": "Type", "de": "Typ", "fr": "Type"},
"contactNumber": {"en": "Number", "de": "Nummer", "fr": "Numéro"},
"name": {"en": "Name", "de": "Name", "fr": "Nom"},
"address": {"en": "Address", "de": "Adresse", "fr": "Adresse"},
"zip": {"en": "ZIP", "de": "PLZ", "fr": "NPA"},
"city": {"en": "City", "de": "Ort", "fr": "Ville"},
"country": {"en": "Country", "de": "Land", "fr": "Pays"},
"email": {"en": "Email", "de": "E-Mail", "fr": "E-mail"},
"phone": {"en": "Phone", "de": "Telefon", "fr": "Téléphone"},
"vatNumber": {"en": "VAT Number", "de": "MWST-Nr.", "fr": "N° TVA"},
"mandateId": {"en": "Mandate", "de": "Mandat", "fr": "Mandat"},
"featureInstanceId": {"en": "Feature Instance", "de": "Feature-Instanz", "fr": "Instance"},
},
)
class TrusteeDataAccountBalance(BaseModel):
"""Account balance per period, derived from journal lines or directly from accounting system."""
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
accountNumber: str = Field(description="Account number")
periodYear: int = Field(description="Fiscal year")
periodMonth: int = Field(default=0, description="Month (1-12); 0 = annual total")
openingBalance: float = Field(default=0.0)
debitTotal: float = Field(default=0.0)
creditTotal: float = Field(default=0.0)
closingBalance: float = Field(default=0.0)
currency: str = Field(default="CHF")
mandateId: Optional[str] = Field(default=None)
featureInstanceId: Optional[str] = Field(default=None)
registerModelLabels(
"TrusteeDataAccountBalance",
{"en": "Account Balance (Synced)", "de": "Kontosaldo (Sync)", "fr": "Solde compte (Sync)"},
{
"id": {"en": "ID", "de": "ID", "fr": "ID"},
"accountNumber": {"en": "Account", "de": "Konto", "fr": "Compte"},
"periodYear": {"en": "Year", "de": "Jahr", "fr": "Année"},
"periodMonth": {"en": "Month", "de": "Monat", "fr": "Mois"},
"openingBalance": {"en": "Opening Balance", "de": "Eröffnungssaldo", "fr": "Solde d'ouverture"},
"debitTotal": {"en": "Debit Total", "de": "Soll-Umsatz", "fr": "Total débit"},
"creditTotal": {"en": "Credit Total", "de": "Haben-Umsatz", "fr": "Total crédit"},
"closingBalance": {"en": "Closing Balance", "de": "Schlusssaldo", "fr": "Solde de clôture"},
"currency": {"en": "Currency", "de": "Währung", "fr": "Devise"},
"mandateId": {"en": "Mandate", "de": "Mandat", "fr": "Mandat"},
"featureInstanceId": {"en": "Feature Instance", "de": "Feature-Instanz", "fr": "Instance"},
},
)
class TrusteeAccountingConfig(BaseModel):
"""Per-instance accounting system configuration with encrypted credentials.

View file

@ -78,6 +78,31 @@ DATA_OBJECTS = [
"label": {"en": "Accounting Sync", "de": "Buchhaltungs-Synchronisation", "fr": "Sync. comptable"},
"meta": {"table": "TrusteeAccountingSync", "fields": ["id", "positionId", "syncStatus", "externalId"]}
},
{
"objectKey": "data.feature.trustee.TrusteeDataAccount",
"label": {"en": "Accounts (Synced)", "de": "Kontenplan (Sync)", "fr": "Plan comptable (Sync)"},
"meta": {"table": "TrusteeDataAccount", "fields": ["id", "accountNumber", "label", "accountType", "accountGroup", "currency", "isActive"]}
},
{
"objectKey": "data.feature.trustee.TrusteeDataJournalEntry",
"label": {"en": "Journal Entries (Synced)", "de": "Buchungen (Sync)", "fr": "Écritures (Sync)"},
"meta": {"table": "TrusteeDataJournalEntry", "fields": ["id", "externalId", "bookingDate", "reference", "description", "currency", "totalAmount"]}
},
{
"objectKey": "data.feature.trustee.TrusteeDataJournalLine",
"label": {"en": "Journal Lines (Synced)", "de": "Buchungszeilen (Sync)", "fr": "Lignes écriture (Sync)"},
"meta": {"table": "TrusteeDataJournalLine", "fields": ["id", "journalEntryId", "accountNumber", "debitAmount", "creditAmount", "currency", "taxCode", "costCenter", "description"]}
},
{
"objectKey": "data.feature.trustee.TrusteeDataContact",
"label": {"en": "Contacts (Synced)", "de": "Kontakte (Sync)", "fr": "Contacts (Sync)"},
"meta": {"table": "TrusteeDataContact", "fields": ["id", "externalId", "contactType", "contactNumber", "name", "address", "zip", "city", "country", "email", "phone", "vatNumber"]}
},
{
"objectKey": "data.feature.trustee.TrusteeDataAccountBalance",
"label": {"en": "Account Balances (Synced)", "de": "Kontosalden (Sync)", "fr": "Soldes comptes (Sync)"},
"meta": {"table": "TrusteeDataAccountBalance", "fields": ["id", "accountNumber", "periodYear", "periodMonth", "openingBalance", "debitTotal", "creditTotal", "closingBalance", "currency"]}
},
{
"objectKey": "data.feature.trustee.*",
"label": {"en": "All Trustee Data", "de": "Alle Treuhand-Daten", "fr": "Toutes les données fiduciaires"},

View file

@ -188,7 +188,7 @@ def get_mime_type_options(
"""Get supported MIME types from the document extraction service.
Returns: [{ value: "mime/type", label: "Description" }]
"""
from modules.services.serviceExtraction.subRegistry import ExtractorRegistry
from modules.serviceCenter.services.serviceExtraction.subRegistry import ExtractorRegistry
registry = ExtractorRegistry()
formats = registry.getSupportedFormats()
@ -1481,6 +1481,63 @@ def get_position_sync_status(
return {"items": items}
# ===== Accounting Data Import =====
@router.post("/{instanceId}/accounting/import-data")
@limiter.limit("3/minute")
async def import_accounting_data(
request: Request,
instanceId: str = Path(..., description="Feature Instance ID"),
data: Dict[str, Any] = Body(default={}),
context: RequestContext = Depends(getRequestContext)
) -> Dict[str, Any]:
"""Import accounting data (chart, journal entries, contacts) from the external system into TrusteeData* tables."""
mandateId = _validateInstanceAccess(instanceId, context)
interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId)
from .accounting.accountingDataSync import AccountingDataSync
sync = AccountingDataSync(interface)
dateFrom = data.get("dateFrom")
dateTo = data.get("dateTo")
result = await sync.importData(
featureInstanceId=instanceId,
mandateId=mandateId,
dateFrom=dateFrom,
dateTo=dateTo,
)
return result
@router.get("/{instanceId}/accounting/import-status")
@limiter.limit("30/minute")
def get_import_status(
request: Request,
instanceId: str = Path(..., description="Feature Instance ID"),
context: RequestContext = Depends(getRequestContext)
) -> Dict[str, Any]:
"""Get counts of imported TrusteeData* records for this instance."""
mandateId = _validateInstanceAccess(instanceId, context)
interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId)
from .datamodelFeatureTrustee import (
TrusteeDataAccount, TrusteeDataJournalEntry, TrusteeDataJournalLine,
TrusteeDataContact, TrusteeDataAccountBalance, TrusteeAccountingConfig,
)
filt = {"featureInstanceId": instanceId}
counts = {
"accounts": len(interface.db.getRecordset(TrusteeDataAccount, recordFilter=filt) or []),
"journalEntries": len(interface.db.getRecordset(TrusteeDataJournalEntry, recordFilter=filt) or []),
"journalLines": len(interface.db.getRecordset(TrusteeDataJournalLine, recordFilter=filt) or []),
"contacts": len(interface.db.getRecordset(TrusteeDataContact, recordFilter=filt) or []),
"accountBalances": len(interface.db.getRecordset(TrusteeDataAccountBalance, recordFilter=filt) or []),
}
cfgRecords = interface.db.getRecordset(TrusteeAccountingConfig, recordFilter={"featureInstanceId": instanceId, "isActive": True})
if cfgRecords:
cfg = cfgRecords[0]
counts["lastSyncAt"] = cfg.get("lastSyncAt")
counts["lastSyncStatus"] = cfg.get("lastSyncStatus")
counts["lastSyncErrorMessage"] = cfg.get("lastSyncErrorMessage")
return counts
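Hypothetical client calls (router prefix assumed; request/response shapes follow the handlers above):
# POST {prefix}/{instanceId}/accounting/import-data   body: {"dateFrom": "2025-01-01", "dateTo": "2025-12-31"}
#   -> summary dict from AccountingDataSync.importData
# GET  {prefix}/{instanceId}/accounting/import-status
#   -> {"accounts": 120, "journalEntries": 3401, ..., "lastSyncStatus": "success"}  (counts invented)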
# ===== Position-Document Query =====
@router.get("/{instanceId}/positions/document/{documentId}", response_model=List[TrusteePosition])

View file

@@ -1,2 +1,3 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Unified AI Workspace feature."""

View file

@@ -1,9 +1,9 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
CodeEditor Feature Container - Main Module.
Workspace Feature Container - Main Module.
Handles feature initialization and RBAC catalog registration.
Cursor-style AI file editing via chat interface.
Unified AI Workspace feature.
"""
import logging
@@ -11,88 +11,108 @@ from typing import Dict, List, Any
logger = logging.getLogger(__name__)
FEATURE_CODE = "codeeditor"
FEATURE_LABEL = {"en": "Code Editor", "de": "Code Editor", "fr": "Code Editor"}
FEATURE_ICON = "mdi-file-document-edit"
FEATURE_CODE = "workspace"
FEATURE_LABEL = {"en": "AI Workspace", "de": "AI Workspace", "fr": "AI Workspace"}
FEATURE_ICON = "mdi-brain"
UI_OBJECTS = [
{
"objectKey": "ui.feature.codeeditor.editor",
"objectKey": "ui.feature.workspace.dashboard",
"label": {"en": "Dashboard", "de": "Dashboard", "fr": "Tableau de bord"},
"meta": {"area": "dashboard"}
},
{
"objectKey": "ui.feature.workspace.editor",
"label": {"en": "Editor", "de": "Editor", "fr": "Editeur"},
"meta": {"area": "editor"}
},
{
"objectKey": "ui.feature.codeeditor.workflows",
"label": {"en": "Workflows", "de": "Workflows", "fr": "Workflows"},
"meta": {"area": "workflows"}
"objectKey": "ui.feature.workspace.settings",
"label": {"en": "Settings", "de": "Einstellungen", "fr": "Parametres"},
"meta": {"area": "settings"}
},
]
RESOURCE_OBJECTS = [
{
"objectKey": "resource.feature.codeeditor.start",
"label": {"en": "Start Workflow", "de": "Workflow starten", "fr": "Demarrer workflow"},
"meta": {"endpoint": "/api/codeeditor/{instanceId}/start/stream", "method": "POST"}
"objectKey": "resource.feature.workspace.start",
"label": {"en": "Start Agent", "de": "Agent starten", "fr": "Demarrer agent"},
"meta": {"endpoint": "/api/workspace/{instanceId}/start/stream", "method": "POST"}
},
{
"objectKey": "resource.feature.codeeditor.stop",
"label": {"en": "Stop Workflow", "de": "Workflow stoppen", "fr": "Arreter workflow"},
"meta": {"endpoint": "/api/codeeditor/{instanceId}/{workflowId}/stop", "method": "POST"}
"objectKey": "resource.feature.workspace.stop",
"label": {"en": "Stop Agent", "de": "Agent stoppen", "fr": "Arreter agent"},
"meta": {"endpoint": "/api/workspace/{instanceId}/{workflowId}/stop", "method": "POST"}
},
{
"objectKey": "resource.feature.codeeditor.chatData",
"label": {"en": "Get Chat Data", "de": "Chat-Daten abrufen", "fr": "Recuperer donnees chat"},
"meta": {"endpoint": "/api/codeeditor/{instanceId}/{workflowId}/chatData", "method": "GET"}
},
{
"objectKey": "resource.feature.codeeditor.files",
"objectKey": "resource.feature.workspace.files",
"label": {"en": "Manage Files", "de": "Dateien verwalten", "fr": "Gerer fichiers"},
"meta": {"endpoint": "/api/codeeditor/{instanceId}/files", "method": "GET"}
"meta": {"endpoint": "/api/workspace/{instanceId}/files", "method": "GET"}
},
{
"objectKey": "resource.feature.codeeditor.apply",
"label": {"en": "Apply Edit", "de": "Aenderung anwenden", "fr": "Appliquer modification"},
"meta": {"endpoint": "/api/codeeditor/{instanceId}/{workflowId}/apply", "method": "POST"}
"objectKey": "resource.feature.workspace.folders",
"label": {"en": "Manage Folders", "de": "Ordner verwalten", "fr": "Gerer dossiers"},
"meta": {"endpoint": "/api/workspace/{instanceId}/folders", "method": "GET"}
},
{
"objectKey": "resource.feature.workspace.datasources",
"label": {"en": "Data Sources", "de": "Datenquellen", "fr": "Sources de donnees"},
"meta": {"endpoint": "/api/workspace/{instanceId}/datasources", "method": "GET"}
},
{
"objectKey": "resource.feature.workspace.voice",
"label": {"en": "Voice Input/Output", "de": "Spracheingabe/-ausgabe", "fr": "Entree/sortie vocale"},
"meta": {"endpoint": "/api/workspace/{instanceId}/voice/*", "method": "POST"}
},
{
"objectKey": "resource.feature.workspace.edits",
"label": {"en": "Review File Edits", "de": "Datei-Aenderungen pruefen", "fr": "Verifier les modifications de fichiers"},
"meta": {"endpoint": "/api/workspace/{instanceId}/edit/*", "method": "POST"}
},
]
TEMPLATE_ROLES = [
{
"roleLabel": "codeeditor-viewer",
"roleLabel": "workspace-viewer",
"description": {
"en": "Code Editor Viewer - View editor (read-only)",
"de": "Code Editor Betrachter - Editor ansehen (nur lesen)",
"fr": "Visualiseur Code Editor - Consulter l'editeur (lecture seule)"
"en": "Workspace Viewer - View workspace (read-only)",
"de": "Workspace Betrachter - Workspace ansehen (nur lesen)",
"fr": "Visualiseur Workspace - Consulter le workspace (lecture seule)"
},
"accessRules": [
{"context": "UI", "item": "ui.feature.codeeditor.editor", "view": True},
{"context": "UI", "item": "ui.feature.workspace.dashboard", "view": True},
{"context": "UI", "item": "ui.feature.workspace.editor", "view": True},
{"context": "UI", "item": "ui.feature.workspace.settings", "view": True},
{"context": "DATA", "item": None, "view": True, "read": "m", "create": "n", "update": "n", "delete": "n"},
]
},
{
"roleLabel": "codeeditor-user",
"roleLabel": "workspace-user",
"description": {
"en": "Code Editor User - Use editor and workflows",
"de": "Code Editor Benutzer - Editor und Workflows nutzen",
"fr": "Utilisateur Code Editor - Utiliser l'editeur et les workflows"
"en": "Workspace User - Use AI workspace and tools",
"de": "Workspace Benutzer - AI Workspace und Tools nutzen",
"fr": "Utilisateur Workspace - Utiliser l'espace de travail AI et les outils"
},
"accessRules": [
{"context": "UI", "item": "ui.feature.codeeditor.editor", "view": True},
{"context": "UI", "item": "ui.feature.codeeditor.workflows", "view": True},
{"context": "RESOURCE", "item": "resource.feature.codeeditor.start", "view": True},
{"context": "RESOURCE", "item": "resource.feature.codeeditor.stop", "view": True},
{"context": "RESOURCE", "item": "resource.feature.codeeditor.chatData", "view": True},
{"context": "RESOURCE", "item": "resource.feature.codeeditor.files", "view": True},
{"context": "RESOURCE", "item": "resource.feature.codeeditor.apply", "view": True},
{"context": "UI", "item": "ui.feature.workspace.dashboard", "view": True},
{"context": "UI", "item": "ui.feature.workspace.editor", "view": True},
{"context": "UI", "item": "ui.feature.workspace.settings", "view": True},
{"context": "RESOURCE", "item": "resource.feature.workspace.start", "view": True},
{"context": "RESOURCE", "item": "resource.feature.workspace.stop", "view": True},
{"context": "RESOURCE", "item": "resource.feature.workspace.files", "view": True},
{"context": "RESOURCE", "item": "resource.feature.workspace.folders", "view": True},
{"context": "RESOURCE", "item": "resource.feature.workspace.datasources", "view": True},
{"context": "RESOURCE", "item": "resource.feature.workspace.voice", "view": True},
{"context": "RESOURCE", "item": "resource.feature.workspace.edits", "view": True},
{"context": "DATA", "item": None, "view": True, "read": "m", "create": "m", "update": "m", "delete": "m"},
]
},
{
"roleLabel": "codeeditor-admin",
"roleLabel": "workspace-admin",
"description": {
"en": "Code Editor Admin - Full access to code editor",
"de": "Code Editor Admin - Vollzugriff auf Code Editor",
"fr": "Administrateur Code Editor - Acces complet au code editor"
"en": "Workspace Admin - Full access to AI workspace",
"de": "Workspace Admin - Vollzugriff auf AI Workspace",
"fr": "Administrateur Workspace - Acces complet au workspace AI"
},
"accessRules": [
{"context": "UI", "item": None, "view": True},

File diff suppressed because it is too large

View file

@@ -4,7 +4,7 @@ import logging
import asyncio
import uuid
import base64
from typing import Dict, Any, List, Union, Tuple, Optional, Callable
from typing import Dict, Any, List, Union, Tuple, Optional, Callable, AsyncGenerator
from dataclasses import dataclass, field
import time
@@ -12,6 +12,7 @@ logger = logging.getLogger(__name__)
from modules.aicore.aicoreModelRegistry import modelRegistry
from modules.aicore.aicoreModelSelector import modelSelector
from modules.aicore.aicoreBase import RateLimitExceededException
from modules.datamodels.datamodelAi import (
AiModel,
AiCallOptions,
@@ -84,27 +85,31 @@ class AiObjects:
# AI for Extraction, Processing, Generation
async def callWithTextContext(self, request: AiCallRequest) -> AiCallResponse:
"""Call AI model for traditional text/context calls with fallback mechanism."""
"""Call AI model for traditional text/context calls with fallback mechanism.
Supports two modes:
- Legacy: prompt + context; messages are constructed internally
- Agent: request.messages is provided and passed through directly
"""
prompt = request.prompt
context = request.context or ""
options = request.options
# Input bytes will be calculated inside _callWithModel
# Generation parameters are handled inside _callWithModel
# Get failover models for this operation type
availableModels = modelRegistry.getAvailableModels()
# Filter by allowedProviders if specified (from workflow config)
allowedProviders = getattr(options, 'allowedProviders', None) if options else None
if allowedProviders:
filteredModels = [m for m in availableModels if m.connectorType in allowedProviders]
if filteredModels:
logger.info(f"Filtered models by allowedProviders {allowedProviders}: {len(filteredModels)} models (from {len(availableModels)})")
availableModels = filteredModels
else:
logger.warning(f"No models match allowedProviders {allowedProviders}, using all {len(availableModels)} available models")
errorMsg = f"No models match allowedProviders {allowedProviders} for operation {options.operationType}"
logger.error(errorMsg)
return AiCallResponse(
content=errorMsg, modelName="error", priceCHF=0.0,
processingTime=0.0, bytesSent=0, bytesReceived=0, errorCount=1,
)
failoverModelList = modelSelector.getFailoverModelList(prompt, context, options, availableModels)
@@ -121,18 +126,46 @@ class AiObjects:
errorCount=1
)
# Try each model in failover sequence
_MAX_SHORT_RETRY = 15.0
lastError = None
for attempt, model in enumerate(failoverModelList):
try:
logger.info(f"Attempting AI call with model: {model.name} (attempt {attempt + 1}/{len(failoverModelList)})")
# Call the model directly - no truncation or compression here
if request.messages:
response = await self._callWithMessages(model, request.messages, options, request.tools)
else:
response = await self._callWithModel(model, prompt, context, options)
logger.info(f"AI call successful with model: {model.name}")
logger.info(f"AI call successful with model: {model.name}")
return response
except RateLimitExceededException as rle:
retryAfter = rle.retryAfterSeconds
lastError = rle
if 0 < retryAfter <= _MAX_SHORT_RETRY:
logger.info(f"Rate limit on {model.name}, waiting {retryAfter:.1f}s before retry")
await asyncio.sleep(retryAfter + 0.5)
try:
if request.messages:
response = await self._callWithMessages(model, request.messages, options, request.tools)
else:
response = await self._callWithModel(model, prompt, context, options)
logger.info(f"AI call successful with {model.name} after rate-limit retry")
return response
except Exception as retryErr:
lastError = retryErr
logger.warning(f"Retry after rate-limit wait also failed for {model.name}: {retryErr}")
else:
logger.warning(f"Rate limit on {model.name} (retryAfter={retryAfter:.1f}s), failing over")
cooldown = max(retryAfter, 10.0) if retryAfter > 0 else 0.0
modelSelector.reportFailure(model.name, cooldownSeconds=cooldown)
if attempt < len(failoverModelList) - 1:
continue
logger.error(f"All {len(failoverModelList)} models failed for operation {options.operationType}")
break
except Exception as e:
lastError = e
logger.warning(f"AI call failed with model {model.name}: {str(e)}")
@@ -142,8 +175,7 @@ class AiObjects:
logger.info(f"Trying next failover model...")
continue
else:
# All models failed
logger.error(f"💥 All {len(failoverModelList)} models failed for operation {options.operationType}")
logger.error(f"All {len(failoverModelList)} models failed for operation {options.operationType}")
break
# All failover attempts failed - return error response
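# Standalone sketch of the short-retry-vs-failover pattern used above, with
# stand-in types; the real logic lives in callWithTextContext.
import asyncio

class RateLimited(Exception):   # stand-in for RateLimitExceededException
    def __init__(self, retryAfterSeconds: float):
        super().__init__("rate limited")
        self.retryAfterSeconds = retryAfterSeconds

async def callWithFailover(models, doCall, reportFailure, maxShortRetry=15.0):
    lastError = None
    for model in models:
        try:
            return await doCall(model)
        except RateLimited as rle:
            lastError = rle
            if 0 < rle.retryAfterSeconds <= maxShortRetry:
                # Short wait: retry the same model once before failing over.
                await asyncio.sleep(rle.retryAfterSeconds + 0.5)
                try:
                    return await doCall(model)
                except Exception as err:
                    lastError = err
            # Longer waits: put the model on cooldown and move to the next one.
            reportFailure(model, max(rle.retryAfterSeconds, 10.0))
        except Exception as err:
            lastError = err
            reportFailure(model, 0.0)
    raise lastError or RuntimeError("no models available")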
@@ -254,6 +286,321 @@ class AiObjects:
return response
async def _callWithMessages(self, model: AiModel, messages: List[Dict[str, Any]],
options: AiCallOptions = None,
tools: List[Dict[str, Any]] = None) -> AiCallResponse:
"""Call a model with pre-built messages (agent mode). Supports tools for native function calling."""
inputBytes = sum(len(str(m.get("content", "")).encode("utf-8")) for m in messages)
startTime = time.time()
if not model.functionCall:
raise ValueError(f"Model {model.name} has no function call defined")
modelCall = AiModelCall(
messages=messages,
model=model,
options=options or {},
tools=tools
)
modelResponse = await model.functionCall(modelCall)
if not modelResponse.success:
raise ValueError(f"Model call failed: {modelResponse.error}")
endTime = time.time()
processingTime = endTime - startTime
content = modelResponse.content
outputBytes = len(content.encode("utf-8"))
priceCHF = model.calculatepriceCHF(processingTime, inputBytes, outputBytes)
# Extract tool calls from metadata if present (native function calling)
responseToolCalls = None
if modelResponse.metadata:
responseToolCalls = modelResponse.metadata.get("toolCalls")
response = AiCallResponse(
content=content,
modelName=model.name,
provider=model.connectorType,
priceCHF=priceCHF,
processingTime=processingTime,
bytesSent=inputBytes,
bytesReceived=outputBytes,
errorCount=0,
toolCalls=responseToolCalls
)
if self.billingCallback:
try:
self.billingCallback(response)
except Exception as e:
logger.error(f"BILLING: Failed to record billing for model {model.name}: {e}")
return response
async def callWithTextContextStream(
self, request: AiCallRequest
) -> AsyncGenerator[Union[str, AiCallResponse], None]:
"""Streaming variant of callWithTextContext. Yields str deltas, then final AiCallResponse."""
options = request.options
availableModels = modelRegistry.getAvailableModels()
allowedProviders = getattr(options, 'allowedProviders', None) if options else None
if allowedProviders:
filtered = [m for m in availableModels if m.connectorType in allowedProviders]
if filtered:
availableModels = filtered
else:
yield AiCallResponse(
content=f"No models match allowedProviders {allowedProviders} for operation {options.operationType}",
modelName="error", priceCHF=0.0, processingTime=0.0,
bytesSent=0, bytesReceived=0, errorCount=1,
)
return
failoverModelList = modelSelector.getFailoverModelList(
request.prompt, request.context or "", options, availableModels
)
if not failoverModelList:
yield AiCallResponse(
content=f"No suitable models found for operation {options.operationType}",
modelName="error", priceCHF=0.0, processingTime=0.0,
bytesSent=0, bytesReceived=0, errorCount=1,
)
return
_MAX_SHORT_RETRY = 15.0
lastError = None
for attempt, model in enumerate(failoverModelList):
try:
logger.info(f"Streaming AI call with model: {model.name} (attempt {attempt + 1})")
async for chunk in self._callWithMessagesStream(model, request.messages, options, request.tools):
yield chunk
return
except RateLimitExceededException as rle:
retryAfter = rle.retryAfterSeconds
lastError = rle
if 0 < retryAfter <= _MAX_SHORT_RETRY:
logger.info(f"Rate limit on {model.name}, waiting {retryAfter:.1f}s before retry")
await asyncio.sleep(retryAfter + 0.5)
try:
async for chunk in self._callWithMessagesStream(model, request.messages, options, request.tools):
yield chunk
return
except Exception as retryErr:
lastError = retryErr
logger.warning(f"Retry after rate-limit wait also failed for {model.name}: {retryErr}")
else:
logger.warning(f"Rate limit on {model.name} (retryAfter={retryAfter:.1f}s), failing over")
cooldown = max(retryAfter, 10.0) if retryAfter > 0 else 0.0
modelSelector.reportFailure(model.name, cooldownSeconds=cooldown)
if attempt < len(failoverModelList) - 1:
continue
break
except Exception as e:
lastError = e
logger.warning(f"Streaming AI call failed with {model.name}: {e}")
modelSelector.reportFailure(model.name)
if attempt < len(failoverModelList) - 1:
continue
break
yield AiCallResponse(
content=f"All models failed (stream). Last error: {lastError}",
modelName="error", priceCHF=0.0, processingTime=0.0,
bytesSent=0, bytesReceived=0, errorCount=1,
)
async def _callWithMessagesStream(
self, model: AiModel, messages: List[Dict[str, Any]],
options: AiCallOptions = None, tools: List[Dict[str, Any]] = None,
) -> AsyncGenerator[Union[str, AiCallResponse], None]:
"""Stream a model call. Yields str deltas, then final AiCallResponse with billing."""
from modules.datamodels.datamodelAi import AiModelCall, AiModelResponse
inputBytes = sum(len(str(m.get("content", "")).encode("utf-8")) for m in messages)
startTime = time.time()
if not model.functionCallStream:
response = await self._callWithMessages(model, messages, options, tools)
if response.content:
yield response.content
yield response
return
modelCall = AiModelCall(
messages=messages, model=model,
options=options or {}, tools=tools,
)
finalModelResponse = None
async for item in model.functionCallStream(modelCall):
if isinstance(item, AiModelResponse):
finalModelResponse = item
else:
yield item
if not finalModelResponse:
raise ValueError(f"Stream from {model.name} produced no final AiModelResponse")
endTime = time.time()
processingTime = endTime - startTime
content = finalModelResponse.content
outputBytes = len(content.encode("utf-8"))
priceCHF = model.calculatepriceCHF(processingTime, inputBytes, outputBytes)
responseToolCalls = None
if finalModelResponse.metadata:
responseToolCalls = finalModelResponse.metadata.get("toolCalls")
response = AiCallResponse(
content=content,
modelName=model.name,
provider=model.connectorType,
priceCHF=priceCHF,
processingTime=processingTime,
bytesSent=inputBytes,
bytesReceived=outputBytes,
errorCount=0,
toolCalls=responseToolCalls,
)
if self.billingCallback:
try:
self.billingCallback(response)
except Exception as e:
logger.error(f"BILLING: Failed to record stream billing for {model.name}: {e}")
yield response
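# Sketch: consuming the streaming generator above. Deltas arrive as str
# chunks; the final item is the AiCallResponse that carries billing data.
# `ai` and `request` are assumed AiObjects / AiCallRequest instances.
async def streamToConsole(ai, request):
    final = None
    async for item in ai.callWithTextContextStream(request):
        if isinstance(item, str):
            print(item, end="", flush=True)    # incremental text delta
        else:
            final = item                       # terminal AiCallResponse
    if final and final.errorCount:
        raise RuntimeError(f"stream failed: {final.content}")
    return final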
async def callEmbedding(self, texts: List[str], options: AiCallOptions = None) -> AiCallResponse:
"""Generate embeddings for a list of texts using the best available embedding model.
Token-aware batching: splits the texts list into batches that respect the
model's contextLength (with 10% safety margin). Each batch is sent as a
separate API call; the resulting embeddings are merged in order.
Failover across providers (OpenAI -> Mistral) works identically to chat models,
but ContextLengthExceededException is NOT retried via failover (same limits).
Returns:
AiCallResponse with metadata["embeddings"] containing the vectors.
"""
from modules.aicore.aicoreBase import ContextLengthExceededException as _CtxExc
if options is None:
options = AiCallOptions(operationType=OperationTypeEnum.EMBEDDING)
else:
options.operationType = OperationTypeEnum.EMBEDDING
combinedText = " ".join(texts[:3])[:500]
availableModels = modelRegistry.getAvailableModels()
allowedProviders = getattr(options, 'allowedProviders', None) if options else None
if allowedProviders:
filtered = [m for m in availableModels if m.connectorType in allowedProviders]
if filtered:
availableModels = filtered
else:
logger.warning(f"No embedding models match allowedProviders {allowedProviders}")
failoverModelList = modelSelector.getFailoverModelList(
combinedText, "", options, availableModels
)
if not failoverModelList:
return AiCallResponse(
content="", modelName="error", priceCHF=0.0,
processingTime=0.0, bytesSent=0, bytesReceived=0, errorCount=1
)
lastError = None
for attempt, model in enumerate(failoverModelList):
try:
logger.info(f"Embedding call with {model.name} (attempt {attempt + 1}/{len(failoverModelList)})")
inputBytes = sum(len(t.encode("utf-8")) for t in texts)
startTime = time.time()
batches = _buildEmbeddingBatches(texts, model.contextLength)
logger.info(
f"Embedding: {len(texts)} texts -> {len(batches)} batch(es), "
f"model contextLength={model.contextLength}"
)
allEmbeddings: List[List[float]] = []
totalPriceCHF = 0.0
for batchIdx, batch in enumerate(batches):
modelCall = AiModelCall(
model=model, options=options, embeddingInput=batch
)
modelResponse = await model.functionCall(modelCall)
if not modelResponse.success:
raise ValueError(f"Embedding batch {batchIdx + 1} failed: {modelResponse.error}")
batchEmbeddings = (modelResponse.metadata or {}).get("embeddings", [])
allEmbeddings.extend(batchEmbeddings)
batchBytes = sum(len(t.encode("utf-8")) for t in batch)
totalPriceCHF += model.calculatepriceCHF(0, batchBytes, 0)
processingTime = time.time() - startTime
if totalPriceCHF == 0.0:
totalPriceCHF = model.calculatepriceCHF(processingTime, inputBytes, 0)
response = AiCallResponse(
content="", modelName=model.name, provider=model.connectorType,
priceCHF=totalPriceCHF, processingTime=processingTime,
bytesSent=inputBytes, bytesReceived=0, errorCount=0,
metadata={"embeddings": allEmbeddings}
)
if self.billingCallback:
try:
self.billingCallback(response)
except Exception as e:
logger.error(f"BILLING: Failed to record billing for embedding {model.name}: {e}")
return response
except _CtxExc as e:
logger.error(f"ContextLengthExceeded for {model.name} despite batching aborting failover: {e}")
return AiCallResponse(
content=str(e), modelName=model.name, priceCHF=0.0,
processingTime=0.0, bytesSent=0, bytesReceived=0, errorCount=1
)
except RateLimitExceededException as rle:
retryAfter = rle.retryAfterSeconds
lastError = rle
cooldown = max(retryAfter, 10.0) if retryAfter > 0 else 0.0
logger.warning(f"Rate limit on {model.name} during embedding (retryAfter={retryAfter:.1f}s)")
modelSelector.reportFailure(model.name, cooldownSeconds=cooldown)
if attempt < len(failoverModelList) - 1:
continue
break
except Exception as e:
lastError = e
logger.warning(f"Embedding call failed with {model.name}: {str(e)}")
modelSelector.reportFailure(model.name)
if attempt < len(failoverModelList) - 1:
continue
break
errorMsg = f"All embedding models failed. Last error: {str(lastError)}"
logger.error(errorMsg)
return AiCallResponse(
content=errorMsg, modelName="error", priceCHF=0.0,
processingTime=0.0, bytesSent=0, bytesReceived=0, errorCount=1
)
# Utility methods
async def listAvailableModels(self, connectorType: str = None) -> List[Dict[str, Any]]:
@@ -276,4 +623,50 @@ class AiObjects:
return [model.displayName for model in models]
# =============================================================================
# Internal helpers
# =============================================================================
_CHARS_PER_TOKEN = 4
_SAFETY_MARGIN = 0.90
def _estimateTokens(text: str) -> int:
"""Rough token estimate: 1 token ~ 4 characters."""
return max(1, len(text) // _CHARS_PER_TOKEN)
def _buildEmbeddingBatches(texts: List[str], contextLength: int) -> List[List[str]]:
"""Split a list of texts into batches whose total estimated token count
stays within the model's contextLength (with safety margin).
Each individual text is assumed to already be within limits (enforced by
the chunking layer). If a single text exceeds the budget, it is placed
in its own batch as a last resort.
"""
if not texts:
return []
if contextLength <= 0:
return [texts]
maxTokensPerBatch = int(contextLength * _SAFETY_MARGIN)
batches: List[List[str]] = []
currentBatch: List[str] = []
currentTokens = 0
for text in texts:
textTokens = _estimateTokens(text)
if currentBatch and (currentTokens + textTokens) > maxTokensPerBatch:
batches.append(currentBatch)
currentBatch = []
currentTokens = 0
currentBatch.append(text)
currentTokens += textTokens
if currentBatch:
batches.append(currentBatch)
return batches
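# Sketch: batching behavior under the 4-chars-per-token estimate and the 10%
# safety margin (illustrative sizes, not real embedding inputs).
texts = ["a" * 4000, "b" * 4000, "c" * 4000]      # ~1000 estimated tokens each
batches = _buildEmbeddingBatches(texts, contextLength=2500)
# maxTokensPerBatch = int(2500 * 0.90) = 2250, so at most two texts per batch:
assert [len(b) for b in batches] == [2, 1]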

View file

@@ -96,6 +96,9 @@ def initBootstrap(db: DatabaseConnector) -> None:
if mandateId:
initRootMandateFeatures(db, mandateId)
# Remove feature instances for features that no longer exist in the codebase
_cleanupRemovedFeatureInstances(db)
# Initialize billing settings for root mandate
if mandateId:
initRootMandateBilling(mandateId)
@@ -257,6 +260,33 @@ def initRootMandateFeatures(db: DatabaseConnector, mandateId: str) -> None:
logger.info("Root mandate features initialization completed")
def _cleanupRemovedFeatureInstances(db: DatabaseConnector) -> None:
"""Remove feature instances whose featureCode no longer exists in the codebase."""
from modules.datamodels.datamodelFeatures import FeatureInstance
from modules.system.registry import loadFeatureMainModules
mainModules = loadFeatureMainModules()
activeCodes = set()
for featureName, module in mainModules.items():
if hasattr(module, "getFeatureDefinition"):
try:
featureDef = module.getFeatureDefinition()
activeCodes.add(featureDef.get("code", featureName))
except Exception:
pass
allInstances = db.getRecordset(FeatureInstance)
for inst in allInstances:
code = inst.get("featureCode") if isinstance(inst, dict) else getattr(inst, "featureCode", None)
instId = inst.get("id") if isinstance(inst, dict) else getattr(inst, "id", None)
if code and code not in activeCodes:
try:
db.recordDelete(FeatureInstance, str(instId))
logger.info(f"Removed orphaned feature instance '{instId}' (featureCode='{code}')")
except Exception as e:
logger.warning(f"Could not remove orphaned feature instance '{instId}': {e}")
def initRootMandate(db: DatabaseConnector) -> Optional[str]:
"""
Creates the Root mandate if it doesn't exist.
@@ -443,7 +473,7 @@ def initRoles(db: DatabaseConnector) -> None:
# Check specifically for system template roles:
# mandateId=NULL, isSystemRole=True, featureCode=NULL
# Feature templates (e.g. chatplayground admin) share the same labels but have featureCode set!
# Feature templates (e.g. automation admin) share the same labels but have featureCode set!
allTemplates = db.getRecordset(
Role,
recordFilter={"mandateId": None, "isSystemRole": True}
@@ -475,7 +505,7 @@ def _deduplicateRoles(db: DatabaseConnector) -> None:
# Group by (roleLabel, mandateId, featureInstanceId, featureCode)
# featureCode is essential: system template ('admin', None, None, None)
# must NOT be grouped with feature template ('admin', None, None, 'chatplayground')
# must NOT be grouped with feature template ('admin', None, None, 'automation')
groups: dict = {}
for role in allRoles:
key = (role.get("roleLabel"), role.get("mandateId"), role.get("featureInstanceId"), role.get("featureCode"))
@@ -1931,8 +1961,6 @@ def _createStoreResourceRules(db: DatabaseConnector) -> None:
"""
storeResources = [
"resource.store.automation",
"resource.store.chatplayground",
"resource.store.codeeditor",
"resource.store.teamsbot",
]

View file

@@ -680,6 +680,29 @@ class BillingObjects:
record = StripeWebhookEvent(event_id=event_id)
return self.db.recordCreate(StripeWebhookEvent, record.model_dump())
def getPaymentTransactionByReferenceId(self, referenceId: str) -> Optional[Dict[str, Any]]:
"""
Find an existing Stripe payment credit transaction by Checkout Session ID.
Args:
referenceId: Stripe Checkout Session ID (cs_xxx)
Returns:
Transaction record if found, else None
"""
try:
results = self.db.getRecordset(
BillingTransaction,
recordFilter={
"referenceType": ReferenceTypeEnum.PAYMENT.value,
"referenceId": referenceId,
}
)
return results[0] if results else None
except Exception as e:
logger.error(f"Error checking Stripe payment transaction by referenceId: {e}")
return None
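# Sketch: idempotent webhook handling built on the lookup above. Stripe
# retries webhooks, so the session is checked before booking a credit;
# handleCheckoutCompleted and creditPayment are hypothetical names.
def handleCheckoutCompleted(billing, session):
    existing = billing.getPaymentTransactionByReferenceId(session["id"])
    if existing:
        return existing                       # credit already booked once
    return billing.creditPayment(session)     # hypothetical booking call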
# =========================================================================
# Balance Check Operations
# =========================================================================
@@ -764,7 +787,11 @@ class BillingObjects:
featureCode: str = None,
aicoreProvider: str = None,
aicoreModel: str = None,
description: str = "AI Usage"
description: str = "AI Usage",
processingTime: float = None,
bytesSent: int = None,
bytesReceived: int = None,
errorCount: int = None
) -> Optional[Dict[str, Any]]:
"""
Record usage cost as a billing transaction.
@@ -774,20 +801,6 @@ class BillingObjects:
- PREPAY_USER: deduct from user's own balance
- PREPAY_MANDATE: deduct from mandate pool balance
- CREDIT_POSTPAY: deduct from mandate pool balance
Args:
mandateId: Mandate ID
userId: User ID
priceCHF: Cost in CHF
workflowId: Optional workflow ID
featureInstanceId: Optional feature instance ID
featureCode: Optional feature code
aicoreProvider: AICore provider name (e.g., 'anthropic', 'openai')
aicoreModel: AICore model name (e.g., 'claude-4-sonnet', 'gpt-4o')
description: Transaction description
Returns:
Created transaction dict or None
"""
if priceCHF <= 0:
return None
@@ -816,7 +829,11 @@ class BillingObjects:
featureCode=featureCode,
aicoreProvider=aicoreProvider,
aicoreModel=aicoreModel,
createdByUserId=userId
createdByUserId=userId,
processingTime=processingTime,
bytesSent=bytesSent,
bytesReceived=bytesReceived,
errorCount=errorCount
)
# Determine where to deduct balance
@@ -828,6 +845,20 @@ class BillingObjects:
poolAccount = self.getOrCreateMandateAccount(mandateId)
return self.createTransaction(transaction, balanceAccountId=poolAccount["id"])
# =========================================================================
# Workflow Cost Query
# =========================================================================
def getWorkflowCost(self, workflowId: str) -> float:
"""Sum of all transaction amounts for a workflow."""
if not workflowId:
return 0.0
transactions = self.db.getRecordset(
BillingTransaction,
recordFilter={"workflowId": workflowId}
)
return sum(t.get("amount", 0.0) for t in transactions)
# =========================================================================
# Billing Model Switch Operations
# =========================================================================

View file

@@ -18,7 +18,6 @@ from modules.datamodels.datamodelUam import AccessLevel
from modules.datamodels.datamodelChat import (
ChatDocument,
ChatStat,
ChatLog,
ChatMessage,
ChatWorkflow,
@@ -663,10 +662,8 @@ class ChatObjects:
workflow = workflows[0]
try:
# Load related data from normalized tables
logs = self.getLogs(workflowId)
messages = self.getMessages(workflowId)
stats = self.getStats(workflowId)
# Validate workflow data against ChatWorkflow model
# Explicit type coercion: DB may store numeric fields as TEXT on some platforms
@@ -694,8 +691,7 @@ class ChatObjects:
lastActivity=_toFloat(workflow.get("lastActivity")),
startedAt=_toFloat(workflow.get("startedAt")),
logs=logs,
messages=messages,
stats=stats
messages=messages
)
except Exception as e:
logger.error(f"Error validating workflow data: {str(e)}")
@@ -731,7 +727,7 @@ class ChatObjects:
except Exception as e:
logger.warning(f"Could not get Root mandate: {e}")
# Note: ChatWorkflow has featureInstanceId for multi-tenancy isolation.
# Child tables (ChatMessage, ChatLog, ChatStat, ChatDocument) are user-owned
# Child tables (ChatMessage, ChatLog, ChatDocument) are user-owned
# and do NOT store featureInstanceId - they inherit isolation from ChatWorkflow.
# Ensure featureInstanceId is set from context if not already in workflowData
if "featureInstanceId" not in workflowData or not workflowData.get("featureInstanceId"):
@@ -760,7 +756,7 @@ class ChatObjects:
logs=[],
messages=[],
stats=[],
workflowMode=created["workflowMode"],
workflowMode=created.get("workflowMode", "Dynamic"),
maxSteps=created.get("maxSteps", 1)
)
@@ -789,23 +785,20 @@ class ChatObjects:
# Load fresh data from normalized tables
logs = self.getLogs(workflowId)
messages = self.getMessages(workflowId)
stats = self.getStats(workflowId)
# Convert to ChatWorkflow model
return ChatWorkflow(
id=updated["id"],
status=updated.get("status", workflow.status),
name=updated.get("name", workflow.name),
currentRound=updated.get("currentRound", workflow.currentRound),
currentTask=updated.get("currentTask", workflow.currentTask),
currentAction=updated.get("currentAction", workflow.currentAction),
totalTasks=updated.get("totalTasks", workflow.totalTasks),
totalActions=updated.get("totalActions", workflow.totalActions),
currentRound=updated.get("currentRound") or getattr(workflow, "currentRound", 0) or 0,
currentTask=updated.get("currentTask") or getattr(workflow, "currentTask", 0) or 0,
currentAction=updated.get("currentAction") or getattr(workflow, "currentAction", 0) or 0,
totalTasks=updated.get("totalTasks") or getattr(workflow, "totalTasks", 0) or 0,
totalActions=updated.get("totalActions") or getattr(workflow, "totalActions", 0) or 0,
lastActivity=updated.get("lastActivity", workflow.lastActivity),
startedAt=updated.get("startedAt", workflow.startedAt),
logs=logs,
messages=messages,
stats=stats
messages=messages
)
def deleteWorkflow(self, workflowId: str) -> bool:
@@ -827,7 +820,6 @@ class ChatObjects:
messageId = message.id
if messageId:
# Delete message documents (but NOT the files!)
# Note: ChatStat does NOT have messageId - stats are only at workflow level
try:
existing_docs = self._getRecordset(ChatDocument, recordFilter={"messageId": messageId})
for doc in existing_docs:
@@ -839,11 +831,7 @@ class ChatObjects:
self.db.recordDelete(ChatMessage, messageId)
# 2. Delete workflow stats
existing_stats = self._getRecordset(ChatStat, recordFilter={"workflowId": workflowId})
for stat in existing_stats:
self.db.recordDelete(ChatStat, stat["id"])
# 3. Delete workflow logs
# 2. Delete workflow logs
existing_logs = self._getRecordset(ChatLog, recordFilter={"workflowId": workflowId})
for log in existing_logs:
self.db.recordDelete(ChatLog, log["id"])
@@ -1270,7 +1258,6 @@ class ChatObjects:
self.db.recordDelete(ChatDocument, doc["id"])
# 2. Finally delete the message itself
# Note: ChatStat has no messageId field -- stats are workflow-level, not message-level
success = self.db.recordDelete(ChatMessage, messageId)
return success
@@ -1517,74 +1504,10 @@ class ChatObjects:
# Return validated ChatLog instance
return ChatLog(**createdLog)
# Stats methods
def getStats(self, workflowId: str) -> List[ChatStat]:
"""Returns list of statistics for a workflow if user has access."""
# Check workflow access first (without calling getWorkflow to avoid circular reference)
# Use RBAC filtering
workflows = self._getRecordset(ChatWorkflow, recordFilter={"id": workflowId})
if not workflows:
return []
# Get stats for this workflow from normalized table
stats = self._getRecordset(ChatStat, recordFilter={"workflowId": workflowId})
if not stats:
return []
# Return all stats records sorted by creation time.
# Use parseTimestamp to tolerate mixed DB types (float/string) on INT.
# DB uses _createdAt (camelCase system field).
stats.sort(key=lambda x: parseTimestamp(x.get("_createdAt"), default=0))
# Convert to ChatStat objects, preserving _createdAt via extra="allow"
result = []
for stat in stats:
chat_stat = ChatStat(**stat)
# Explicitly preserve _createdAt from raw DB record
if "_createdAt" in stat:
setattr(chat_stat, '_createdAt', stat["_createdAt"])
result.append(chat_stat)
return result
def createStat(self, statData: Dict[str, Any]) -> ChatStat:
"""Creates a new stats record and returns it."""
try:
# Ensure workflowId is present in statData
if "workflowId" not in statData:
raise ValueError("workflowId is required in statData")
# Note: Chat data is user-owned, no mandate/featureInstance context stored
# mandateId/featureInstanceId removed from ChatStat model
# Validate the stat data against ChatStat model
stat = ChatStat(**statData)
logger.debug(f"Creating stat for workflow {statData.get('workflowId')}: "
f"process={statData.get('process')}, "
f"priceCHF={statData.get('priceCHF', 0):.4f}, "
f"processingTime={statData.get('processingTime', 0):.2f}s")
# Create the stat record in the database
created = self.db.recordCreate(ChatStat, stat)
logger.info(f"Created stat {created.get('id')} for workflow {statData.get('workflowId')}")
# Return the created ChatStat
return ChatStat(**created)
except Exception as e:
logger.error(f"Error creating workflow stat: {str(e)}")
raise
def getUnifiedChatData(self, workflowId: str, afterTimestamp: Optional[float] = None) -> Dict[str, Any]:
def getUnifiedChatData(self, workflowId: str, afterTimestamp: Optional[float] = None, workflowCost: float = 0.0) -> Dict[str, Any]:
"""
Returns unified chat data (messages, logs, stats) for a workflow in chronological order.
Uses timestamp-based selective data transfer for efficient polling.
Returns unified chat data (messages, logs) for a workflow in chronological order,
plus workflowCost from billing transactions (single source of truth).
"""
# Check workflow access first
# Use RBAC filtering
@@ -1652,29 +1575,10 @@ class ChatObjects:
"item": chatLog
})
# Get stats - ChatStat model supports _createdAt via model_config extra="allow"
stats = self.getStats(workflowId)
for stat in stats:
# Apply timestamp filtering in Python
# Use _createdAt (system field from DB, preserved via model_config extra="allow")
stat_timestamp = getattr(stat, '_createdAt', None) or getUtcTimestamp()
if afterTimestamp is not None and stat_timestamp <= afterTimestamp:
continue
# Convert to dict and include _createdAt for frontend
stat_dict = stat.model_dump() if hasattr(stat, 'model_dump') else stat.dict()
stat_dict['_createdAt'] = stat_timestamp
items.append({
"type": "stat",
"createdAt": stat_timestamp,
"item": stat_dict
})
# Sort all items by createdAt timestamp for chronological order
items.sort(key=lambda x: parseTimestamp(x.get("createdAt"), default=0))
return {"items": items}
return {"items": items, "workflowCost": workflowCost}
def getInterface(currentUser: Optional[User] = None, mandateId: Optional[str] = None, featureInstanceId: Optional[str] = None) -> 'ChatObjects':

View file

@@ -0,0 +1,247 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Interface to the Knowledge Store database (poweron_knowledge).
Provides CRUD for FileContentIndex, ContentChunk, WorkflowMemory
and semantic search via pgvector.
"""
import logging
from typing import Dict, Any, List, Optional
from modules.connectors.connectorDbPostgre import _get_cached_connector
from modules.datamodels.datamodelKnowledge import FileContentIndex, ContentChunk, WorkflowMemory
from modules.datamodels.datamodelUam import User
from modules.shared.configuration import APP_CONFIG
from modules.shared.timeUtils import getUtcTimestamp
logger = logging.getLogger(__name__)
_instances: Dict[str, "KnowledgeObjects"] = {}
class KnowledgeObjects:
"""Interface to the Knowledge Store database.
Manages FileContentIndex, ContentChunk, and WorkflowMemory with semantic search."""
def __init__(self):
self.currentUser: Optional[User] = None
self.userId: Optional[str] = None
self._initializeDatabase()
def _initializeDatabase(self):
dbHost = APP_CONFIG.get("DB_HOST", "_no_config_default_data")
dbDatabase = "poweron_knowledge"
dbUser = APP_CONFIG.get("DB_USER")
dbPassword = APP_CONFIG.get("DB_PASSWORD_SECRET")
dbPort = int(APP_CONFIG.get("DB_PORT", 5432))
self.db = _get_cached_connector(
dbHost=dbHost,
dbDatabase=dbDatabase,
dbUser=dbUser,
dbPassword=dbPassword,
dbPort=dbPort,
userId=self.userId,
)
logger.info("Knowledge Store database initialized")
def setUserContext(self, user: User):
self.currentUser = user
self.userId = user.id if user else None
if self.userId:
self.db.updateContext(self.userId)
# =========================================================================
# FileContentIndex CRUD
# =========================================================================
def upsertFileContentIndex(self, index: FileContentIndex) -> Dict[str, Any]:
"""Create or update a FileContentIndex entry."""
data = index.model_dump()
existing = self.db._loadRecord(FileContentIndex, index.id)
if existing:
return self.db.recordModify(FileContentIndex, index.id, data)
return self.db.recordCreate(FileContentIndex, data)
def getFileContentIndex(self, fileId: str) -> Optional[Dict[str, Any]]:
"""Get a FileContentIndex by file ID."""
return self.db._loadRecord(FileContentIndex, fileId)
def getFileContentIndexByUser(
self, userId: str, featureInstanceId: str = None
) -> List[Dict[str, Any]]:
"""Get all FileContentIndex entries for a user."""
recordFilter = {"userId": userId}
if featureInstanceId:
recordFilter["featureInstanceId"] = featureInstanceId
return self.db.getRecordset(FileContentIndex, recordFilter=recordFilter)
def updateFileStatus(self, fileId: str, status: str) -> bool:
"""Update the processing status of a FileContentIndex."""
existing = self.db._loadRecord(FileContentIndex, fileId)
if not existing:
return False
self.db.recordModify(FileContentIndex, fileId, {"status": status})
return True
def deleteFileContentIndex(self, fileId: str) -> bool:
"""Delete a FileContentIndex and all associated ContentChunks."""
chunks = self.db.getRecordset(ContentChunk, recordFilter={"fileId": fileId})
for chunk in chunks:
self.db.recordDelete(ContentChunk, chunk["id"])
return self.db.recordDelete(FileContentIndex, fileId)
# =========================================================================
# ContentChunk CRUD
# =========================================================================
def upsertContentChunk(self, chunk: ContentChunk) -> Dict[str, Any]:
"""Create or update a ContentChunk."""
data = chunk.model_dump()
existing = self.db._loadRecord(ContentChunk, chunk.id)
if existing:
return self.db.recordModify(ContentChunk, chunk.id, data)
return self.db.recordCreate(ContentChunk, data)
def upsertContentChunks(self, chunks: List[ContentChunk]) -> int:
"""Batch upsert multiple ContentChunks. Returns count of upserted chunks."""
count = 0
for chunk in chunks:
self.upsertContentChunk(chunk)
count += 1
return count
def getContentChunks(self, fileId: str) -> List[Dict[str, Any]]:
"""Get all ContentChunks for a file."""
return self.db.getRecordset(ContentChunk, recordFilter={"fileId": fileId})
def deleteContentChunks(self, fileId: str) -> int:
"""Delete all ContentChunks for a file. Returns count of deleted chunks."""
chunks = self.db.getRecordset(ContentChunk, recordFilter={"fileId": fileId})
count = 0
for chunk in chunks:
if self.db.recordDelete(ContentChunk, chunk["id"]):
count += 1
return count
# =========================================================================
# WorkflowMemory CRUD
# =========================================================================
def upsertWorkflowMemory(self, memory: WorkflowMemory) -> Dict[str, Any]:
"""Create or update a WorkflowMemory entry."""
data = memory.model_dump()
existing = self.db._loadRecord(WorkflowMemory, memory.id)
if existing:
return self.db.recordModify(WorkflowMemory, memory.id, data)
return self.db.recordCreate(WorkflowMemory, data)
def getWorkflowEntities(self, workflowId: str) -> List[Dict[str, Any]]:
"""Get all WorkflowMemory entries for a workflow."""
return self.db.getRecordset(WorkflowMemory, recordFilter={"workflowId": workflowId})
def getWorkflowEntity(self, workflowId: str, key: str) -> Optional[Dict[str, Any]]:
"""Get a specific WorkflowMemory entry by workflow and key."""
results = self.db.getRecordset(
WorkflowMemory, recordFilter={"workflowId": workflowId, "key": key}
)
return results[0] if results else None
def deleteWorkflowMemory(self, workflowId: str) -> int:
"""Delete all WorkflowMemory entries for a workflow. Returns count."""
entries = self.db.getRecordset(WorkflowMemory, recordFilter={"workflowId": workflowId})
count = 0
for entry in entries:
if self.db.recordDelete(WorkflowMemory, entry["id"]):
count += 1
return count
# =========================================================================
# Semantic Search
# =========================================================================
def semanticSearch(
self,
queryVector: List[float],
userId: str = None,
featureInstanceId: str = None,
mandateId: str = None,
isShared: bool = None,
limit: int = 10,
minScore: float = None,
contentType: str = None,
) -> List[Dict[str, Any]]:
"""Semantic search across ContentChunks using pgvector cosine similarity.
Args:
queryVector: Query embedding vector.
userId: Filter by user (Instance Layer).
featureInstanceId: Filter by feature instance.
mandateId: Filter by mandate (for Shared Layer lookups).
isShared: If True, search Shared Layer via FileContentIndex join.
limit: Max results.
minScore: Minimum cosine similarity (0.0 - 1.0).
contentType: Filter by content type (text, image, etc.).
Returns:
List of ContentChunk records with _score field, sorted by relevance.
"""
recordFilter = {}
if userId:
recordFilter["userId"] = userId
if featureInstanceId:
recordFilter["featureInstanceId"] = featureInstanceId
if contentType:
recordFilter["contentType"] = contentType
if isShared and mandateId:
sharedIndexes = self.db.getRecordset(
FileContentIndex,
recordFilter={"mandateId": mandateId, "isShared": True},
)
sharedFileIds = [idx.get("id") if isinstance(idx, dict) else getattr(idx, "id", None) for idx in sharedIndexes]
sharedFileIds = [fid for fid in sharedFileIds if fid]
if not sharedFileIds:
return []
recordFilter.pop("userId", None)
recordFilter.pop("featureInstanceId", None)
recordFilter["fileId"] = sharedFileIds
return self.db.semanticSearch(
modelClass=ContentChunk,
vectorColumn="embedding",
queryVector=queryVector,
limit=limit,
recordFilter=recordFilter if recordFilter else None,
minScore=minScore,
)
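# Sketch: embedding a query and searching the Instance Layer for one user.
# `ai` is assumed to be an AiObjects instance; the "embeddings" metadata key
# matches the callEmbedding contract shown earlier.
async def searchUserChunks(ai, knowledge, userId, featureInstanceId, query):
    embedResp = await ai.callEmbedding([query])
    queryVector = embedResp.metadata["embeddings"][0]
    return knowledge.semanticSearch(
        queryVector=queryVector,
        userId=userId,
        featureInstanceId=featureInstanceId,
        limit=5,
        minScore=0.7,                 # drop weakly similar chunks
    )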
def semanticSearchWorkflowMemory(
self,
queryVector: List[float],
workflowId: str,
limit: int = 5,
minScore: float = None,
) -> List[Dict[str, Any]]:
"""Semantic search across WorkflowMemory entries."""
return self.db.semanticSearch(
modelClass=WorkflowMemory,
vectorColumn="embedding",
queryVector=queryVector,
limit=limit,
recordFilter={"workflowId": workflowId},
minScore=minScore,
)
def getInterface(currentUser: Optional[User] = None) -> KnowledgeObjects:
"""Get or create a KnowledgeObjects singleton."""
if "default" not in _instances:
_instances["default"] = KnowledgeObjects()
interface = _instances["default"]
if currentUser:
interface.setUserContext(currentUser)
return interface

View file

@@ -10,6 +10,7 @@ import logging
import base64
import hashlib
import math
import mimetypes
from typing import Dict, Any, List, Optional, Union
from modules.connectors.connectorDbPostgre import DatabaseConnector, _get_cached_connector
@@ -18,6 +19,7 @@ from modules.security.rbac import RbacClass
from modules.datamodels.datamodelRbac import AccessRuleContext
from modules.datamodels.datamodelUam import AccessLevel
from modules.datamodels.datamodelFiles import FilePreview, FileItem, FileData
from modules.datamodels.datamodelFileFolder import FileFolder
from modules.datamodels.datamodelUtils import Prompt
from modules.datamodels.datamodelVoice import VoiceSettings
from modules.datamodels.datamodelMessaging import (
@@ -851,7 +853,9 @@ class ComponentObjects:
"svg": "image/svg+xml",
"py": "text/x-python",
"js": "application/javascript",
"css": "text/css"
"css": "text/css",
"eml": "message/rfc822",
"msg": "application/vnd.ms-outlook",
}
return extensionToMime.get(ext.lower(), "application/octet-stream")
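# Sketch: since mimetypes is now imported, the stdlib lookup could back the
# table above for unmapped extensions (an option, not the current behavior):
import mimetypes

def guessMime(ext: str) -> str:
    guessed, _ = mimetypes.guess_type(f"f.{ext.lower()}")
    return guessed or "application/octet-stream"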
@@ -1143,6 +1147,350 @@ class ComponentObjects:
logger.error(f"Error deleting file {fileId}: {str(e)}")
raise FileDeletionError(f"Error deleting file: {str(e)}")
def deleteFilesBatch(self, fileIds: List[str]) -> Dict[str, Any]:
"""Delete multiple files in a single SQL batch call."""
uniqueIds = [str(fid) for fid in dict.fromkeys(fileIds or []) if fid]
if not uniqueIds:
return {"deletedFiles": 0}
try:
self.db._ensure_connection()
with self.db.connection.cursor() as cursor:
cursor.execute(
'SELECT "id" FROM "FileItem" WHERE "id" = ANY(%s) AND "_createdBy" = %s',
(uniqueIds, self.userId or ""),
)
accessibleIds = [row["id"] for row in cursor.fetchall()]
if len(accessibleIds) != len(uniqueIds):
missingIds = sorted(set(uniqueIds) - set(accessibleIds))
raise FileNotFoundError(f"Files not found or not accessible: {missingIds}")
cursor.execute('DELETE FROM "FileData" WHERE "id" = ANY(%s)', (accessibleIds,))
cursor.execute(
'DELETE FROM "FileItem" WHERE "id" = ANY(%s) AND "_createdBy" = %s',
(accessibleIds, self.userId or ""),
)
deletedFiles = cursor.rowcount
self.db.connection.commit()
return {"deletedFiles": deletedFiles}
except Exception as e:
logger.error(f"Error deleting files in batch: {e}")
self.db.connection.rollback()
raise FileDeletionError(f"Error deleting files in batch: {str(e)}")
# ---- Folder methods ----
_RESERVED_FOLDER_NAMES = {"(Global)"}
def _validateFolderName(self, name: str, parentId: Optional[str], excludeFolderId: Optional[str] = None):
"""Ensures folder name is not reserved and is unique within parent."""
if name in self._RESERVED_FOLDER_NAMES:
raise ValueError(f"Folder name '{name}' is reserved")
if not name or not name.strip():
raise ValueError("Folder name cannot be empty")
existingFolders = self.db.getRecordset(FileFolder, recordFilter={"parentId": parentId or ""})
for f in existingFolders:
if f.get("name") == name and f.get("id") != excludeFolderId:
raise ValueError(f"Folder '{name}' already exists in this directory")
def _isDescendantOf(self, folderId: str, ancestorId: str) -> bool:
"""Checks if folderId is a descendant of ancestorId (circular reference check)."""
visited = set()
currentId = folderId
while currentId:
if currentId == ancestorId:
return True
if currentId in visited:
break
visited.add(currentId)
folders = self.db.getRecordset(FileFolder, recordFilter={"id": currentId})
if not folders:
break
currentId = folders[0].get("parentId")
return False
def getFolder(self, folderId: str) -> Optional[Dict[str, Any]]:
"""Returns a folder by ID if it belongs to the current user."""
folders = self.db.getRecordset(FileFolder, recordFilter={"id": folderId, "_createdBy": self.userId or ""})
return folders[0] if folders else None
def listFolders(self, parentId: Optional[str] = None) -> List[Dict[str, Any]]:
"""List folders for current user, optionally filtered by parentId."""
recordFilter = {"_createdBy": self.userId or ""}
if parentId is not None:
recordFilter["parentId"] = parentId
return self.db.getRecordset(FileFolder, recordFilter=recordFilter)
def createFolder(self, name: str, parentId: Optional[str] = None) -> Dict[str, Any]:
"""Create a new folder with unique name validation."""
self._validateFolderName(name, parentId)
folder = FileFolder(
name=name,
parentId=parentId,
mandateId=self.mandateId or "",
featureInstanceId=self.featureInstanceId or "",
)
return self.db.recordCreate(FileFolder, folder)
def renameFolder(self, folderId: str, newName: str) -> bool:
"""Rename a folder with unique name validation."""
folder = self.getFolder(folderId)
if not folder:
raise FileNotFoundError(f"Folder {folderId} not found")
self._validateFolderName(newName, folder.get("parentId"), excludeFolderId=folderId)
return self.db.recordModify(FileFolder, folderId, {"name": newName})
def moveFolder(self, folderId: str, targetParentId: Optional[str] = None) -> bool:
"""Move a folder to a new parent, with circular reference and unique name checks."""
folder = self.getFolder(folderId)
if not folder:
raise FileNotFoundError(f"Folder {folderId} not found")
if targetParentId and self._isDescendantOf(targetParentId, folderId):
raise ValueError("Cannot move folder into its own subtree")
self._validateFolderName(folder.get("name", ""), targetParentId, excludeFolderId=folderId)
return self.db.recordModify(FileFolder, folderId, {"parentId": targetParentId})
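# Sketch: the two moveFolder guards in action (IDs are illustrative and
# `components` is assumed to be a ComponentObjects instance).
reports = components.createFolder("Reports")
year = components.createFolder("2025", parentId=reports["id"])
try:
    components.moveFolder(reports["id"], targetParentId=year["id"])
except ValueError as e:
    print(e)    # "Cannot move folder into its own subtree"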
def moveFilesBatch(self, fileIds: List[str], targetFolderId: Optional[str] = None) -> Dict[str, Any]:
"""Move multiple files with one SQL update."""
uniqueIds = [str(fid) for fid in dict.fromkeys(fileIds or []) if fid]
if not uniqueIds:
return {"movedFiles": 0}
if targetFolderId:
targetFolder = self.getFolder(targetFolderId)
if not targetFolder:
raise FileNotFoundError(f"Target folder {targetFolderId} not found")
try:
self.db._ensure_connection()
with self.db.connection.cursor() as cursor:
cursor.execute(
'SELECT "id" FROM "FileItem" WHERE "id" = ANY(%s) AND "_createdBy" = %s',
(uniqueIds, self.userId or ""),
)
accessibleIds = [row["id"] for row in cursor.fetchall()]
if len(accessibleIds) != len(uniqueIds):
missingIds = sorted(set(uniqueIds) - set(accessibleIds))
raise FileNotFoundError(f"Files not found or not accessible: {missingIds}")
cursor.execute(
'UPDATE "FileItem" SET "folderId" = %s, "_modifiedAt" = %s, "_modifiedBy" = %s '
'WHERE "id" = ANY(%s) AND "_createdBy" = %s',
(targetFolderId, getUtcTimestamp(), self.userId or "", accessibleIds, self.userId or ""),
)
movedFiles = cursor.rowcount
self.db.connection.commit()
return {"movedFiles": movedFiles}
except Exception as e:
logger.error(f"Error moving files in batch: {e}")
self.db.connection.rollback()
raise FileError(f"Error moving files in batch: {str(e)}")
def moveFoldersBatch(self, folderIds: List[str], targetParentId: Optional[str] = None) -> Dict[str, Any]:
"""Move multiple folders with one SQL update after validation."""
uniqueIds = [str(fid) for fid in dict.fromkeys(folderIds or []) if fid]
if not uniqueIds:
return {"movedFolders": 0}
foldersToMove: List[Dict[str, Any]] = []
for folderId in uniqueIds:
folder = self.getFolder(folderId)
if not folder:
raise FileNotFoundError(f"Folder {folderId} not found")
if targetParentId and self._isDescendantOf(targetParentId, folderId):
raise ValueError("Cannot move folder into its own subtree")
foldersToMove.append(folder)
existingInTarget = self.db.getRecordset(
FileFolder,
recordFilter={"parentId": targetParentId or "", "_createdBy": self.userId or ""},
)
existingNames = {f.get("name"): f.get("id") for f in existingInTarget}
movingNames: Dict[str, str] = {}
movingIds = set(uniqueIds)
for folder in foldersToMove:
name = folder.get("name", "")
folderId = folder.get("id")
if name in movingNames and movingNames[name] != folderId:
raise ValueError(f"Folder '{name}' already exists in this move batch")
movingNames[name] = folderId
existingId = existingNames.get(name)
if existingId and existingId not in movingIds:
raise ValueError(f"Folder '{name}' already exists in target directory")
try:
self.db._ensure_connection()
with self.db.connection.cursor() as cursor:
cursor.execute(
'UPDATE "FileFolder" SET "parentId" = %s, "_modifiedAt" = %s, "_modifiedBy" = %s '
'WHERE "id" = ANY(%s) AND "_createdBy" = %s',
(targetParentId, getUtcTimestamp(), self.userId or "", uniqueIds, self.userId or ""),
)
movedFolders = cursor.rowcount
self.db.connection.commit()
return {"movedFolders": movedFolders}
except Exception as e:
logger.error(f"Error moving folders in batch: {e}")
self.db.connection.rollback()
raise FileError(f"Error moving folders in batch: {str(e)}")
def deleteFolder(self, folderId: str, recursive: bool = False) -> Dict[str, Any]:
"""Delete a folder. If recursive, deletes all contents. Returns summary of deletions."""
folder = self.getFolder(folderId)
if not folder:
raise FileNotFoundError(f"Folder {folderId} not found")
childFolders = self.db.getRecordset(FileFolder, recordFilter={"parentId": folderId, "_createdBy": self.userId or ""})
childFiles = self._getFilesByCurrentUser(recordFilter={"folderId": folderId})
if not recursive and (childFolders or childFiles):
raise ValueError(
f"Folder '{folder.get('name')}' is not empty "
f"({len(childFiles)} files, {len(childFolders)} subfolders). "
f"Use recursive=true to delete contents."
)
deletedFiles = 0
deletedFolders = 0
if recursive:
for subFolder in childFolders:
subResult = self.deleteFolder(subFolder["id"], recursive=True)
deletedFiles += subResult.get("deletedFiles", 0)
deletedFolders += subResult.get("deletedFolders", 0)
for childFile in childFiles:
try:
self.deleteFile(childFile["id"])
deletedFiles += 1
except Exception as e:
logger.warning(f"Failed to delete file {childFile['id']} during folder deletion: {e}")
self.db.recordDelete(FileFolder, folderId)
deletedFolders += 1
return {"deletedFiles": deletedFiles, "deletedFolders": deletedFolders}
def deleteFoldersBatch(self, folderIds: List[str], recursive: bool = True) -> Dict[str, Any]:
"""Delete multiple folders and their content in batched SQL calls."""
uniqueIds = [str(fid) for fid in dict.fromkeys(folderIds or []) if fid]
if not uniqueIds:
return {"deletedFiles": 0, "deletedFolders": 0}
if not recursive:
deletedFiles = 0
deletedFolders = 0
for folderId in uniqueIds:
result = self.deleteFolder(folderId, recursive=False)
deletedFiles += result.get("deletedFiles", 0)
deletedFolders += result.get("deletedFolders", 0)
return {"deletedFiles": deletedFiles, "deletedFolders": deletedFolders}
try:
self.db._ensure_connection()
with self.db.connection.cursor() as cursor:
cursor.execute(
'SELECT "id" FROM "FileFolder" WHERE "id" = ANY(%s) AND "_createdBy" = %s',
(uniqueIds, self.userId or ""),
)
rootAccessibleIds = [row["id"] for row in cursor.fetchall()]
if len(rootAccessibleIds) != len(uniqueIds):
missingIds = sorted(set(uniqueIds) - set(rootAccessibleIds))
raise FileNotFoundError(f"Folders not found or not accessible: {missingIds}")
cursor.execute(
"""
WITH RECURSIVE folder_tree AS (
SELECT "id"
FROM "FileFolder"
WHERE "id" = ANY(%s) AND "_createdBy" = %s
UNION ALL
SELECT child."id"
FROM "FileFolder" child
INNER JOIN folder_tree ft ON child."parentId" = ft."id"
WHERE child."_createdBy" = %s
)
SELECT DISTINCT "id" FROM folder_tree
""",
(rootAccessibleIds, self.userId or "", self.userId or ""),
)
allFolderIds = [row["id"] for row in cursor.fetchall()]
cursor.execute(
'SELECT "id" FROM "FileItem" WHERE "folderId" = ANY(%s) AND "_createdBy" = %s',
(allFolderIds, self.userId or ""),
)
allFileIds = [row["id"] for row in cursor.fetchall()]
if allFileIds:
cursor.execute('DELETE FROM "FileData" WHERE "id" = ANY(%s)', (allFileIds,))
cursor.execute(
'DELETE FROM "FileItem" WHERE "id" = ANY(%s) AND "_createdBy" = %s',
(allFileIds, self.userId or ""),
)
deletedFiles = cursor.rowcount
else:
deletedFiles = 0
cursor.execute(
'DELETE FROM "FileFolder" WHERE "id" = ANY(%s) AND "_createdBy" = %s',
(allFolderIds, self.userId or ""),
)
deletedFolders = cursor.rowcount
self.db.connection.commit()
return {"deletedFiles": deletedFiles, "deletedFolders": deletedFolders}
except Exception as e:
logger.error(f"Error deleting folders in batch: {e}")
self.db.connection.rollback()
raise FileDeletionError(f"Error deleting folders in batch: {str(e)}")
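# Illustrative sketch (not part of the diff): the descendant-collection pattern
# used by deleteFoldersBatch above, isolated as a standalone helper. Table and
# column names mirror the query above; the cursor is assumed to be a
# psycopg2 RealDictCursor so rows behave like dicts.
def collectDescendantFolderIds(cursor, rootIds, ownerId):
    """Return rootIds plus all transitive child folder ids owned by ownerId."""
    cursor.execute(
        """
        WITH RECURSIVE folder_tree AS (
            SELECT "id" FROM "FileFolder"
            WHERE "id" = ANY(%s) AND "_createdBy" = %s
            UNION ALL
            SELECT child."id"
            FROM "FileFolder" child
            INNER JOIN folder_tree ft ON child."parentId" = ft."id"
            WHERE child."_createdBy" = %s
        )
        SELECT DISTINCT "id" FROM folder_tree
        """,
        (rootIds, ownerId, ownerId),
    )
    return [row["id"] for row in cursor.fetchall()]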
def copyFile(self, sourceFileId: str, targetFolderId: Optional[str] = None, newFileName: Optional[str] = None) -> FileItem:
"""Create a full duplicate of a file (FileItem + FileData)."""
sourceFile = self.getFile(sourceFileId)
if not sourceFile:
raise FileNotFoundError(f"File {sourceFileId} not found")
sourceData = self.getFileData(sourceFileId)
if sourceData is None:
raise FileStorageError(f"No data found for file {sourceFileId}")
fileName = newFileName or sourceFile.fileName
copiedFile = self.createFile(fileName, sourceFile.mimeType, sourceData)
if targetFolderId:
self.updateFile(copiedFile.id, {"folderId": targetFolderId})
elif sourceFile.folderId:
self.updateFile(copiedFile.id, {"folderId": sourceFile.folderId})
self.createFileData(copiedFile.id, sourceData)
return copiedFile
def updateFileData(self, fileId: str, data: bytes) -> bool:
"""Replace existing file data (delete + create). Updates FileItem metadata."""
file = self.getFile(fileId)
if not file:
raise FileNotFoundError(f"File {fileId} not found")
try:
self.db.recordDelete(FileData, fileId)
logger.debug(f"Deleted existing FileData for {fileId}")
except Exception as e:
logger.debug(f"No existing FileData to delete for {fileId}: {e}")
success = self.createFileData(fileId, data)
if success:
newSize = len(data)
newHash = hashlib.sha256(data).hexdigest()
self.db.recordModify(FileItem, fileId, {"fileSize": newSize, "fileHash": newHash})
logger.info(f"Updated file data for {fileId} ({newSize} bytes)")
return success
# FileData methods - data operations
def createFileData(self, fileId: str, data: bytes) -> bool:
@@ -1395,11 +1743,15 @@ class ComponentObjects:
logger.error("No user ID provided for voice settings")
return None
# Get voice settings for the user, filtered by RBAC
recordFilter: Dict[str, Any] = {"userId": targetUserId}
if self.featureInstanceId:
recordFilter["featureInstanceId"] = self.featureInstanceId
# Get voice settings for the user (scoped to current feature instance if available), filtered by RBAC
filteredSettings = getRecordsetWithRBAC(self.db,
VoiceSettings,
self.currentUser,
recordFilter={"userId": targetUserId},
recordFilter=recordFilter,
mandateId=self.mandateId
)

View file

@@ -58,7 +58,6 @@ TABLE_NAMESPACE = {
"ChatWorkflow": "chat",
"ChatMessage": "chat",
"ChatLog": "chat",
"ChatStat": "chat",
"ChatDocument": "chat",
"Prompt": "chat",
# Chatbot (poweron_chatbot) - per feature-instance isolation
@@ -69,13 +68,20 @@ TABLE_NAMESPACE = {
# Files - user-owned
"FileItem": "files",
"FileData": "files",
"FileFolder": "files",
# Automation - user-owned
"AutomationDefinition": "automation",
"AutomationTemplate": "automation",
# Knowledge Store - user-owned
"FileContentIndex": "knowledge",
"ContentChunk": "knowledge",
"WorkflowMemory": "knowledge",
# Data Sources - user-owned
"DataSource": "datasource",
}
# Namespaces without mandate context - GROUP is mapped to MY
USER_OWNED_NAMESPACES = {"chat", "chatbot", "files", "automation"}
USER_OWNED_NAMESPACES = {"chat", "chatbot", "files", "automation", "knowledge", "datasource"}
def buildDataObjectKey(tableName: str, featureCode: Optional[str] = None) -> str:
@@ -175,7 +181,7 @@ def getRecordsetWithRBAC(
whereValues = []
# CRITICAL: Only pass featureInstanceId to WHERE clause if the model actually has
# this column. Chat child tables (ChatMessage, ChatLog, ChatStat, ChatDocument)
# this column. Chat child tables (ChatMessage, ChatLog, ChatDocument)
# are user-owned and do NOT have featureInstanceId - only ChatWorkflow does.
# Without this check, the SQL query would reference a non-existent column,
# causing a silent error that returns empty results.

View file

@@ -6,8 +6,9 @@ Provides a generic interface layer between routes and voice connectors.
Handles voice operations including speech-to-text, text-to-speech, and translation.
"""
import asyncio
import logging
from typing import Dict, Any, Optional, List
from typing import AsyncGenerator, Callable, Dict, Any, Optional, List
from modules.connectors.connectorVoiceGoogle import ConnectorGoogleSpeech
from modules.datamodels.datamodelVoice import VoiceSettings
@@ -30,6 +31,7 @@ class VoiceObjects:
self.currentUser: Optional[User] = None
self.userId: Optional[str] = None
self._google_speech_connector: Optional[ConnectorGoogleSpeech] = None
self.billingCallback: Optional[Callable[[Dict[str, Any]], None]] = None
def setUserContext(self, currentUser: User, mandateId: Optional[str] = None):
"""Set the user context for the interface.
@@ -115,6 +117,32 @@ class VoiceObjects:
"error": str(e)
}
async def streamingSpeechToText(
self,
audioQueue: asyncio.Queue,
language: str = "de-DE",
phraseHints: Optional[list] = None,
) -> AsyncGenerator[Dict[str, Any], None]:
"""
Stream audio to Google Streaming STT and yield interim/final results.
Billing is recorded for each final result.
"""
connector = self._getGoogleSpeechConnector()
async for event in connector.streamingRecognize(audioQueue, language, phraseHints):
if event.get("isFinal") and self.billingCallback:
durationSec = event.get("audioDurationSec", 0)
priceCHF = connector.calculateSttCostCHF(durationSec)
if priceCHF > 0:
try:
self.billingCallback({
"operation": "stt-streaming",
"priceCHF": priceCHF,
"audioDurationSec": durationSec,
})
except Exception as e:
logger.warning(f"Voice STT billing callback failed: {e}")
yield event
# Translation Operations
async def detectLanguage(self, text: str) -> Dict[str, Any]:
@@ -277,7 +305,18 @@ class VoiceObjects:
if result["success"]:
logger.info(f"✅ Text-to-Speech successful: {len(result['audio_content'])} bytes")
# Map connector snake_case keys to camelCase for consistent API
if self.billingCallback:
connector = self._getGoogleSpeechConnector()
priceCHF = connector.calculateTtsCostCHF(len(text))
if priceCHF > 0:
try:
self.billingCallback({
"operation": "tts-wavenet",
"priceCHF": priceCHF,
"characterCount": len(text),
})
except Exception as e:
logger.warning(f"Voice TTS billing callback failed: {e}")
return {
"success": True,
"audioContent": result["audio_content"],

View file

@@ -247,19 +247,13 @@ def _getInstancePermissions(rootInterface, userId: str, instanceId: str) -> Dict
# Get FeatureAccess for this user and instance (Pydantic model)
featureAccess = rootInterface.getFeatureAccess(userId, instanceId)
logger.debug(f"_getInstancePermissions: userId={userId}, instanceId={instanceId}, featureAccess={featureAccess is not None}")
if not featureAccess:
logger.debug(f"_getInstancePermissions: No FeatureAccess found for user {userId} and instance {instanceId}")
return permissions
# Get role IDs via interface method
roleIds = rootInterface.getRoleIdsForFeatureAccess(str(featureAccess.id))
logger.debug(f"_getInstancePermissions: featureAccessId={featureAccess.id}, roleIds={roleIds}")
if not roleIds:
logger.debug(f"_getInstancePermissions: No roles found for FeatureAccess {featureAccess.id}")
return permissions
# Check if user has admin role
@@ -274,8 +268,6 @@ def _getInstancePermissions(rootInterface, userId: str, instanceId: str) -> Dict
# Get all rules for this role (returns Pydantic models)
accessRules = rootInterface.getAccessRules(roleId=roleId)
logger.debug(f"_getInstancePermissions: roleId={roleId}, accessRules={len(accessRules) if accessRules else 0}")
for rule in accessRules:
context = rule.context
item = rule.item or ""

View file

@@ -21,7 +21,7 @@ from modules.auth import limiter, requireSysAdminRole, getRequestContext, Reques
# Import billing components
from modules.interfaces.interfaceDbBilling import getInterface as getBillingInterface, _getRootInterface
from modules.services.serviceBilling.mainServiceBilling import getService as getBillingService
from modules.serviceCenter.services.serviceBilling.mainServiceBilling import getService as getBillingService
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata, normalize_pagination_dict
from modules.routes.routeDataUsers import _applyFiltersAndSort
from modules.datamodels.datamodelBilling import (
@@ -162,6 +162,23 @@ def _isAdminOfMandate(ctx: RequestContext, targetMandateId: str) -> bool:
return False
def _isMemberOfMandate(ctx: RequestContext, targetMandateId: str) -> bool:
"""Check if user has any enabled membership in the specified mandate."""
try:
from modules.interfaces.interfaceDbApp import getRootInterface
rootInterface = getRootInterface()
userMandates = rootInterface.getUserMandates(str(ctx.user.id))
for um in userMandates:
if str(getattr(um, 'mandateId', None)) != str(targetMandateId):
continue
if not getattr(um, 'enabled', True):
continue
return True
return False
except Exception:
return False
def _filterTransactionsByScope(transactions: list, scope: BillingDataScope) -> list:
"""
Filter a list of transaction dicts based on the user's BillingDataScope.
@@ -219,6 +236,7 @@ class CheckoutCreateRequest(BaseModel):
"""Request model for creating Stripe Checkout Session."""
userId: Optional[str] = Field(None, description="Target user ID (for PREPAY_USER model)")
amount: float = Field(..., gt=0, description="Amount to pay in CHF (must be in allowed presets)")
returnUrl: str = Field(..., min_length=1, description="Absolute frontend URL used for Stripe success/cancel redirects")
class CheckoutCreateResponse(BaseModel):
@@ -226,6 +244,20 @@ class CheckoutCreateResponse(BaseModel):
redirectUrl: str = Field(..., description="Stripe Checkout URL for redirect")
class CheckoutConfirmRequest(BaseModel):
"""Request model for confirming Stripe Checkout after redirect."""
sessionId: str = Field(..., min_length=1, description="Stripe Checkout Session ID (cs_xxx)")
class CheckoutConfirmResponse(BaseModel):
"""Response model for Stripe Checkout confirmation."""
credited: bool = Field(..., description="True if a new billing credit was created")
alreadyCredited: bool = Field(..., description="True if session was already credited before")
sessionId: str = Field(..., description="Stripe Checkout Session ID")
mandateId: str = Field(..., description="Mandate ID from Stripe metadata")
amountChf: float = Field(..., description="Credited amount in CHF")
class BillingSettingsUpdate(BaseModel):
"""Request model for updating billing settings."""
billingModel: Optional[BillingModelEnum] = None
@@ -328,6 +360,107 @@ class UserTransactionResponse(BaseModel):
userName: Optional[str] = None
def _getStripeClient():
"""Initialize and return configured Stripe SDK module."""
import stripe
from modules.shared.configuration import APP_CONFIG
api_version = APP_CONFIG.get("STRIPE_API_VERSION")
if api_version:
stripe.api_version = api_version
secret_key = APP_CONFIG.get("STRIPE_SECRET_KEY_SECRET") or APP_CONFIG.get("STRIPE_SECRET_KEY")
if not secret_key:
raise ValueError("STRIPE_SECRET_KEY_SECRET not configured")
stripe.api_key = secret_key
return stripe
def _creditStripeSessionIfNeeded(
billingInterface,
session: Dict[str, Any],
eventId: Optional[str] = None,
) -> CheckoutConfirmResponse:
"""
Credit balance from Stripe Checkout session if not already credited.
Uses Checkout session ID for idempotency across webhook + manual confirmation flows.
"""
from modules.serviceCenter.services.serviceBilling.stripeCheckout import ALLOWED_AMOUNTS_CHF
session_id = session.get("id")
metadata = session.get("metadata") or {}
mandate_id = metadata.get("mandateId")
user_id = metadata.get("userId") or None
amount_chf_str = metadata.get("amountChf", "0")
if not session_id:
raise HTTPException(status_code=400, detail="Stripe session id missing")
if not mandate_id:
raise HTTPException(status_code=400, detail="Invalid session metadata: mandateId missing")
existing_payment_tx = billingInterface.getPaymentTransactionByReferenceId(session_id)
if existing_payment_tx:
if eventId and not billingInterface.getStripeWebhookEventByEventId(eventId):
billingInterface.createStripeWebhookEvent(eventId)
return CheckoutConfirmResponse(
credited=False,
alreadyCredited=True,
sessionId=session_id,
mandateId=mandate_id,
amountChf=float(existing_payment_tx.get("amount", 0.0)),
)
try:
amount_chf = float(amount_chf_str)
except (TypeError, ValueError):
amount_chf = None
if amount_chf is None or amount_chf not in ALLOWED_AMOUNTS_CHF:
amount_total = session.get("amount_total")
if amount_total is not None:
amount_chf = amount_total / 100.0
else:
raise HTTPException(status_code=400, detail="Invalid amount in Stripe session")
settings = billingInterface.getSettings(mandate_id)
if not settings:
raise HTTPException(status_code=404, detail="Billing settings not found")
billing_model = BillingModelEnum(settings.get("billingModel", BillingModelEnum.UNLIMITED.value))
if billing_model == BillingModelEnum.PREPAY_USER:
if not user_id:
raise HTTPException(status_code=400, detail="userId required for PREPAY_USER")
account = billingInterface.getOrCreateUserAccount(mandate_id, user_id, initialBalance=0.0)
elif billing_model in [BillingModelEnum.PREPAY_MANDATE, BillingModelEnum.CREDIT_POSTPAY]:
account = billingInterface.getOrCreateMandateAccount(mandate_id, initialBalance=0.0)
else:
raise HTTPException(status_code=400, detail=f"Cannot add credit to {billing_model.value}")
transaction = BillingTransaction(
accountId=account["id"],
transactionType=TransactionTypeEnum.CREDIT,
amount=amount_chf,
description="Stripe-Zahlung",
referenceType=ReferenceTypeEnum.PAYMENT,
referenceId=session_id,
createdByUserId=user_id,
)
billingInterface.createTransaction(transaction)
if eventId and not billingInterface.getStripeWebhookEventByEventId(eventId):
billingInterface.createStripeWebhookEvent(eventId)
logger.info(f"Stripe credit applied: {amount_chf} CHF for session {session_id} on mandate {mandate_id}")
return CheckoutConfirmResponse(
credited=True,
alreadyCredited=False,
sessionId=session_id,
mandateId=mandate_id,
amountChf=amount_chf,
)
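# Illustrative double-delivery scenario for the helper above (not part of the
# diff): the webhook and the manual confirm endpoint may both see the same
# Checkout session; the second call becomes a no-op keyed on the session id.
# billingInterface, session_dict and "evt_123" are placeholders.
first = _creditStripeSessionIfNeeded(billingInterface, session_dict, eventId="evt_123")
second = _creditStripeSessionIfNeeded(billingInterface, session_dict, eventId=None)
assert first.credited and not first.alreadyCredited
assert second.alreadyCredited and not second.credited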
# =============================================================================
# Router Setup
# =============================================================================
@@ -720,11 +853,11 @@ def createCheckoutSession(
targetMandateId: str = Path(..., description="Mandate ID"),
checkoutRequest: CheckoutCreateRequest = Body(...),
ctx: RequestContext = Depends(getRequestContext),
_admin = Depends(requireSysAdminRole)
):
"""
Create Stripe Checkout Session for credit top-up. Returns redirect URL.
SysAdmin only. Amount is validated server-side against allowed presets.
RBAC: PREPAY_USER requires mandate membership (user loads own account),
PREPAY_MANDATE requires mandate admin role.
"""
try:
billingInterface = getBillingInterface(ctx.user, targetMandateId)
@@ -738,14 +871,22 @@ def createCheckoutSession(
if billingModel == BillingModelEnum.PREPAY_USER:
if not checkoutRequest.userId:
raise HTTPException(status_code=400, detail="userId is required for PREPAY_USER model")
elif billingModel not in [BillingModelEnum.PREPAY_MANDATE, BillingModelEnum.CREDIT_POSTPAY]:
if str(checkoutRequest.userId) != str(ctx.user.id):
raise HTTPException(status_code=403, detail="Users can only load credit to their own account")
if not _isMemberOfMandate(ctx, targetMandateId):
raise HTTPException(status_code=403, detail="User is not a member of this mandate")
elif billingModel in [BillingModelEnum.PREPAY_MANDATE, BillingModelEnum.CREDIT_POSTPAY]:
if not _isAdminOfMandate(ctx, targetMandateId):
raise HTTPException(status_code=403, detail="Mandate admin role required to load mandate credit")
else:
raise HTTPException(status_code=400, detail=f"Cannot add credit to {billingModel.value} billing model")
from modules.services.serviceBilling.stripeCheckout import create_checkout_session
from modules.serviceCenter.services.serviceBilling.stripeCheckout import create_checkout_session
redirect_url = create_checkout_session(
mandate_id=targetMandateId,
user_id=checkoutRequest.userId,
amount_chf=checkoutRequest.amount
amount_chf=checkoutRequest.amount,
return_url=checkoutRequest.returnUrl
)
return CheckoutCreateResponse(redirectUrl=redirect_url)
@@ -758,6 +899,65 @@ def createCheckoutSession(
raise HTTPException(status_code=500, detail=str(e))
@router.post("/checkout/confirm", response_model=CheckoutConfirmResponse)
@limiter.limit("20/minute")
def confirmCheckoutSession(
request: Request,
confirmRequest: CheckoutConfirmRequest = Body(...),
ctx: RequestContext = Depends(getRequestContext),
):
"""
Confirm Stripe Checkout success by session ID and apply credit idempotently.
This is a fallback/reconciliation path in addition to webhook processing.
"""
try:
stripe = _getStripeClient()
session = stripe.checkout.Session.retrieve(confirmRequest.sessionId)
if not session:
raise HTTPException(status_code=404, detail="Stripe Checkout Session not found")
session_dict = session.to_dict_recursive() if hasattr(session, "to_dict_recursive") else dict(session)
metadata = session_dict.get("metadata") or {}
mandate_id = metadata.get("mandateId")
user_id = metadata.get("userId") or None
if not mandate_id:
raise HTTPException(status_code=400, detail="Invalid session metadata: mandateId missing")
payment_status = session_dict.get("payment_status")
if payment_status != "paid":
raise HTTPException(status_code=409, detail=f"Payment not completed yet (payment_status={payment_status})")
billingInterface = getBillingInterface(ctx.user, mandate_id)
settings = billingInterface.getSettings(mandate_id)
if not settings:
raise HTTPException(status_code=404, detail="Billing settings not found")
billing_model = BillingModelEnum(settings.get("billingModel", BillingModelEnum.UNLIMITED.value))
if billing_model == BillingModelEnum.PREPAY_USER:
if not user_id:
raise HTTPException(status_code=400, detail="userId required for PREPAY_USER")
if str(user_id) != str(ctx.user.id):
raise HTTPException(status_code=403, detail="Users can only confirm their own payment sessions")
if not _isMemberOfMandate(ctx, mandate_id):
raise HTTPException(status_code=403, detail="User is not a member of this mandate")
elif billing_model in [BillingModelEnum.PREPAY_MANDATE, BillingModelEnum.CREDIT_POSTPAY]:
if not _isAdminOfMandate(ctx, mandate_id):
raise HTTPException(status_code=403, detail="Mandate admin role required")
else:
raise HTTPException(status_code=400, detail=f"Cannot add credit to {billing_model.value}")
root_billing_interface = _getRootInterface()
return _creditStripeSessionIfNeeded(root_billing_interface, session_dict, eventId=None)
except HTTPException:
raise
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(f"Error confirming checkout session {confirmRequest.sessionId}: {e}")
raise HTTPException(status_code=500, detail=str(e))
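# Illustrative client-side flow for the two endpoints above (a sketch, not the
# shipped frontend): the base URL, route prefix, bearer token and mandate id
# are placeholders for whatever this billing router is mounted under.
import requests

def topUpCredit(baseUrl: str, token: str, mandateId: str, amount: float, returnUrl: str) -> str:
    """Create a Stripe Checkout session and return the redirect URL."""
    resp = requests.post(
        f"{baseUrl}/mandates/{mandateId}/checkout",  # hypothetical route path
        json={"amount": amount, "returnUrl": returnUrl},
        headers={"Authorization": f"Bearer {token}"},
    )
    resp.raise_for_status()
    return resp.json()["redirectUrl"]

def confirmAfterRedirect(baseUrl: str, token: str, sessionId: str) -> dict:
    """Fallback confirmation after the Stripe redirect; idempotent server-side."""
    resp = requests.post(
        f"{baseUrl}/checkout/confirm",
        json={"sessionId": sessionId},
        headers={"Authorization": f"Bearer {token}"},
    )
    resp.raise_for_status()
    return resp.json()  # credited / alreadyCredited / amountChf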
@router.post("/webhook/stripe")
async def stripeWebhook(
request: Request,
@@ -768,7 +968,6 @@ async def stripeWebhook(
No JWT auth - Stripe authenticates via Stripe-Signature header.
"""
from modules.shared.configuration import APP_CONFIG
from modules.services.serviceBilling.stripeCheckout import ALLOWED_AMOUNTS_CHF
webhook_secret = APP_CONFIG.get("STRIPE_WEBHOOK_SECRET")
if not webhook_secret:
@@ -792,71 +991,26 @@ async def stripeWebhook(
logger.warning(f"Stripe webhook signature verification failed: {e}")
raise HTTPException(status_code=400, detail="Invalid signature")
if event.type != "checkout.session.completed":
logger.info(f"Stripe webhook received: event={event.id}, type={event.type}")
accepted_event_types = {"checkout.session.completed", "checkout.session.async_payment_succeeded"}
if event.type not in accepted_event_types:
return {"received": True}
session = event.data.object
event_id = event.id
session_id = session.id
billingInterface = _getRootInterface()
if billingInterface.getStripeWebhookEventByEventId(event_id):
logger.info(f"Stripe event {event_id} already processed, skipping")
return {"received": True}
metadata = session.get("metadata") or {}
mandate_id = metadata.get("mandateId")
user_id = metadata.get("userId") or None
amount_chf_str = metadata.get("amountChf", "0")
if not mandate_id:
logger.error(f"Stripe webhook missing mandateId in session {session_id}")
raise HTTPException(status_code=400, detail="Invalid session metadata")
try:
amount_chf = float(amount_chf_str)
except (TypeError, ValueError):
amount_chf = None
if amount_chf is None or amount_chf not in ALLOWED_AMOUNTS_CHF:
amount_total = session.get("amount_total")
if amount_total is not None:
amount_chf = amount_total / 100.0
else:
logger.error(f"Stripe webhook invalid amount for session {session_id}")
raise HTTPException(status_code=400, detail="Invalid amount")
settings = billingInterface.getSettings(mandate_id)
if not settings:
logger.error(f"Stripe webhook: billing settings not found for mandate {mandate_id}")
raise HTTPException(status_code=404, detail="Billing settings not found")
billing_model = BillingModelEnum(settings.get("billingModel", BillingModelEnum.UNLIMITED.value))
if billing_model == BillingModelEnum.PREPAY_USER:
if not user_id:
logger.error(f"Stripe webhook: userId required for PREPAY_USER mandate {mandate_id}")
raise HTTPException(status_code=400, detail="userId required")
account = billingInterface.getOrCreateUserAccount(mandate_id, user_id, initialBalance=0.0)
elif billing_model in [BillingModelEnum.PREPAY_MANDATE, BillingModelEnum.CREDIT_POSTPAY]:
account = billingInterface.getOrCreateMandateAccount(mandate_id, initialBalance=0.0)
else:
logger.error(f"Stripe webhook: cannot credit mandate {mandate_id} with model {billing_model}")
raise HTTPException(status_code=400, detail=f"Cannot add credit to {billing_model.value}")
transaction = BillingTransaction(
accountId=account["id"],
transactionType=TransactionTypeEnum.CREDIT,
amount=amount_chf,
description="Stripe-Zahlung",
referenceType=ReferenceTypeEnum.PAYMENT,
referenceId=session_id
session_dict = session.to_dict_recursive() if hasattr(session, "to_dict_recursive") else dict(session)
result = _creditStripeSessionIfNeeded(billingInterface, session_dict, eventId=event_id)
logger.info(
f"Stripe webhook processed session {result.sessionId}: "
f"credited={result.credited}, alreadyCredited={result.alreadyCredited}"
)
billingInterface.createTransaction(transaction)
billingInterface.createStripeWebhookEvent(event_id)
logger.info(f"Stripe webhook: credited {amount_chf} CHF to account {account['id']} (session {session_id})")
return {"received": True}

View file

@@ -12,6 +12,7 @@ from modules.auth import limiter, getCurrentUser, getRequestContext, RequestCont
# Import interfaces
import modules.interfaces.interfaceDbManagement as interfaceDbManagement
from modules.datamodels.datamodelFiles import FileItem, FilePreview
from modules.datamodels.datamodelFileFolder import FileFolder
from modules.shared.attributeUtils import getModelAttributeDefinitions
from modules.datamodels.datamodelUam import User
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata, normalize_pagination_dict
@@ -19,6 +20,114 @@ from modules.datamodels.datamodelPagination import PaginationParams, PaginatedRe
# Configure logger
logger = logging.getLogger(__name__)
async def _autoIndexFile(fileId: str, fileName: str, mimeType: str, user):
"""Background task: pre-scan + extraction + knowledge indexing.
Step 1: Structure Pre-Scan (AI-free) -> FileContentIndex (persisted)
Step 2: Content extraction via runExtraction -> ContentParts
Step 3: KnowledgeService.indexFile -> chunking + embedding -> Knowledge Store"""
userId = user.id if hasattr(user, "id") else str(user)
try:
mgmtInterface = interfaceDbManagement.getInterface(user)
mgmtInterface.updateFile(fileId, {"status": "processing"})
rawBytes = mgmtInterface.getFileData(fileId)
if not rawBytes:
logger.warning(f"Auto-index: no file data for {fileId}, skipping")
mgmtInterface.updateFile(fileId, {"status": "active"})
return
logger.info(f"Auto-index starting for {fileName} ({len(rawBytes)} bytes, {mimeType})")
# Step 1: Structure Pre-Scan (AI-free)
from modules.serviceCenter.services.serviceKnowledge.subPreScan import preScanDocument
contentIndex = await preScanDocument(
fileData=rawBytes,
mimeType=mimeType,
fileId=fileId,
fileName=fileName,
userId=userId,
)
logger.info(
f"Pre-scan complete for {fileName}: "
f"{contentIndex.totalObjects} objects"
)
# Persist FileContentIndex immediately
from modules.interfaces.interfaceDbKnowledge import getInterface as getKnowledgeInterface
knowledgeDb = getKnowledgeInterface()
knowledgeDb.upsertFileContentIndex(contentIndex)
# Step 2: Content extraction (AI-free, produces ContentParts)
from modules.serviceCenter.services.serviceExtraction.subRegistry import ExtractorRegistry, ChunkerRegistry
from modules.serviceCenter.services.serviceExtraction.subPipeline import runExtraction
from modules.datamodels.datamodelExtraction import ExtractionOptions
extractorRegistry = ExtractorRegistry()
chunkerRegistry = ChunkerRegistry()
options = ExtractionOptions()
extracted = runExtraction(
extractorRegistry, chunkerRegistry,
rawBytes, fileName, mimeType, options,
)
contentObjects = []
for part in extracted.parts:
contentType = "text"
if part.typeGroup == "image":
contentType = "image"
elif part.typeGroup in ("binary", "container"):
contentType = "other"
if not part.data or not part.data.strip():
continue
contentObjects.append({
"contentObjectId": part.id,
"contentType": contentType,
"data": part.data,
"contextRef": {
"containerPath": fileName,
"location": part.label or "file",
**(part.metadata or {}),
},
})
logger.info(f"Extracted {len(contentObjects)} content objects from {fileName}")
if not contentObjects:
knowledgeDb.updateFileStatus(fileId, "indexed")
mgmtInterface.updateFile(fileId, {"status": "active"})
return
# Step 3: Knowledge indexing (chunking + embedding)
from modules.serviceCenter import getService
from modules.serviceCenter.context import ServiceCenterContext
ctx = ServiceCenterContext(user=user, mandate_id="", feature_instance_id="")
knowledgeService = getService("knowledge", ctx)
await knowledgeService.indexFile(
fileId=fileId,
fileName=fileName,
mimeType=mimeType,
userId=userId,
contentObjects=contentObjects,
structure=contentIndex.structure,
)
mgmtInterface.updateFile(fileId, {"status": "active"})
logger.info(f"Auto-index complete for file {fileId} ({fileName})")
except Exception as e:
logger.error(f"Auto-index failed for file {fileId}: {e}", exc_info=True)
try:
errMgmt = interfaceDbManagement.getInterface(user)
errMgmt.updateFile(fileId, {"status": "active"})
except Exception:
pass
# Model attributes for FileItem
fileAttributes = getModelAttributeDefinitions(FileItem)
@@ -111,6 +220,7 @@ async def upload_file(
request: Request,
file: UploadFile = File(...),
workflowId: Optional[str] = Form(None),
featureInstanceId: Optional[str] = Form(None),
currentUser: User = Depends(getCurrentUser)
) -> JSONResponse:
# Add fileName property to UploadFile for consistency with backend model
@@ -133,6 +243,10 @@ async def upload_file(
# Save file via LucyDOM interface in the database
fileItem, duplicateType = managementInterface.saveUploadedFile(fileContent, file.filename)
if featureInstanceId and not fileItem.featureInstanceId:
managementInterface.updateFile(fileItem.id, {"featureInstanceId": featureInstanceId})
fileItem.featureInstanceId = featureInstanceId
# Determine response message based on duplicate type
if duplicateType == "exact_duplicate":
message = f"File '{file.filename}' already exists with identical content. Reusing existing file."
@@ -148,6 +262,32 @@ async def upload_file(
if workflowId:
fileMeta["workflowId"] = workflowId
# Trigger background auto-index pipeline (non-blocking)
# Also runs for duplicates in case the original was never successfully indexed
shouldIndex = duplicateType == "new_file"
if not shouldIndex:
try:
from modules.interfaces.interfaceDbKnowledge import getInterface as _getKnowledgeInterface
_kDb = _getKnowledgeInterface()
_existingIndex = _kDb.getFileContentIndex(fileItem.id)
if not _existingIndex:
shouldIndex = True
logger.info(f"Re-triggering auto-index for duplicate {fileItem.id} (not yet indexed)")
except Exception:
shouldIndex = True
if shouldIndex:
try:
import asyncio
asyncio.ensure_future(_autoIndexFile(
fileId=fileItem.id,
fileName=fileItem.fileName,
mimeType=fileItem.mimeType,
user=currentUser,
))
except Exception as indexErr:
logger.warning(f"Auto-index trigger failed (non-blocking): {indexErr}")
# Response with duplicate information
return JSONResponse({
"message": message,
@@ -171,6 +311,288 @@ async def upload_file(
detail=f"Error during file upload: {str(e)}"
)
# ── Folder endpoints (MUST be before /{fileId} catch-all) ─────────────────────
@router.get("/folders", response_model=List[Dict[str, Any]])
@limiter.limit("30/minute")
def list_folders(
request: Request,
parentId: Optional[str] = Query(None, description="Parent folder ID (omit for all folders)"),
currentUser: User = Depends(getCurrentUser),
context: RequestContext = Depends(getRequestContext)
) -> List[Dict[str, Any]]:
"""List folders for the current user."""
try:
mgmt = interfaceDbManagement.getInterface(
currentUser,
mandateId=str(context.mandateId) if context.mandateId else None,
featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None,
)
if parentId is not None:
return mgmt.listFolders(parentId=parentId)
return mgmt.listFolders()
except Exception as e:
logger.error(f"Error listing folders: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/folders", status_code=status.HTTP_201_CREATED)
@limiter.limit("10/minute")
def create_folder(
request: Request,
body: Dict[str, Any] = Body(...),
currentUser: User = Depends(getCurrentUser),
context: RequestContext = Depends(getRequestContext)
) -> Dict[str, Any]:
"""Create a new folder."""
name = body.get("name", "")
parentId = body.get("parentId")
if not name:
raise HTTPException(status_code=400, detail="name is required")
try:
mgmt = interfaceDbManagement.getInterface(
currentUser,
mandateId=str(context.mandateId) if context.mandateId else None,
featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None,
)
return mgmt.createFolder(name=name, parentId=parentId)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(f"Error creating folder: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.put("/folders/{folderId}")
@limiter.limit("10/minute")
def rename_folder(
request: Request,
folderId: str = Path(...),
body: Dict[str, Any] = Body(...),
currentUser: User = Depends(getCurrentUser),
context: RequestContext = Depends(getRequestContext)
) -> Dict[str, Any]:
"""Rename a folder."""
newName = body.get("name", "")
if not newName:
raise HTTPException(status_code=400, detail="name is required")
try:
mgmt = interfaceDbManagement.getInterface(
currentUser,
mandateId=str(context.mandateId) if context.mandateId else None,
featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None,
)
mgmt.renameFolder(folderId, newName)
return {"success": True, "folderId": folderId, "name": newName}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(f"Error renaming folder: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.delete("/folders/{folderId}")
@limiter.limit("10/minute")
def delete_folder(
request: Request,
folderId: str = Path(...),
recursive: bool = Query(False, description="Delete folder contents recursively"),
currentUser: User = Depends(getCurrentUser),
context: RequestContext = Depends(getRequestContext)
) -> Dict[str, Any]:
"""Delete a folder. Use recursive=true to delete non-empty folders."""
try:
mgmt = interfaceDbManagement.getInterface(
currentUser,
mandateId=str(context.mandateId) if context.mandateId else None,
featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None,
)
return mgmt.deleteFolder(folderId, recursive=recursive)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(f"Error deleting folder: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/folders/{folderId}/move")
@limiter.limit("10/minute")
def move_folder(
request: Request,
folderId: str = Path(...),
body: Dict[str, Any] = Body(...),
currentUser: User = Depends(getCurrentUser),
context: RequestContext = Depends(getRequestContext)
) -> Dict[str, Any]:
"""Move a folder to a new parent."""
targetParentId = body.get("targetParentId")
try:
mgmt = interfaceDbManagement.getInterface(
currentUser,
mandateId=str(context.mandateId) if context.mandateId else None,
featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None,
)
mgmt.moveFolder(folderId, targetParentId)
return {"success": True, "folderId": folderId, "parentId": targetParentId}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(f"Error moving folder: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/folders/{folderId}/download")
@limiter.limit("10/minute")
def download_folder(
request: Request,
folderId: str = Path(..., description="ID of the folder to download as ZIP"),
currentUser: User = Depends(getCurrentUser),
context: RequestContext = Depends(getRequestContext)
) -> Response:
"""Download a folder (including subfolders) as a ZIP archive."""
import io
import zipfile
import urllib.parse
try:
mgmt = interfaceDbManagement.getInterface(
currentUser,
mandateId=str(context.mandateId) if context.mandateId else None,
featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None,
)
folder = mgmt.getFolder(folderId)
if not folder:
raise HTTPException(status_code=404, detail=f"Folder {folderId} not found")
folderName = folder.get("name", "download")
def _collectFiles(parentId: str, pathPrefix: str):
"""Recursively collect (zipPath, fileId) tuples."""
entries = []
for f in mgmt._getFilesByCurrentUser(recordFilter={"folderId": parentId}):
fname = f.get("fileName") or f.get("name") or f.get("id", "file")
entries.append((f"{pathPrefix}{fname}", f["id"]))
for sub in mgmt.listFolders(parentId=parentId):
subName = sub.get("name", sub["id"])
entries.extend(_collectFiles(sub["id"], f"{pathPrefix}{subName}/"))
return entries
fileEntries = _collectFiles(folderId, "")
if not fileEntries:
raise HTTPException(status_code=404, detail="Folder is empty")
buf = io.BytesIO()
with zipfile.ZipFile(buf, "w", zipfile.ZIP_DEFLATED) as zf:
for zipPath, fileId in fileEntries:
data = mgmt.getFileData(fileId)
if data:
zf.writestr(zipPath, data)
buf.seek(0)
zipBytes = buf.getvalue()
encodedName = urllib.parse.quote(f"{folderName}.zip")
return Response(
content=zipBytes,
media_type="application/zip",
headers={
"Content-Disposition": f"attachment; filename*=UTF-8''{encodedName}"
}
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error downloading folder as ZIP: {e}")
raise HTTPException(status_code=500, detail=f"Error downloading folder: {str(e)}")
@router.post("/batch-delete")
@limiter.limit("10/minute")
def batch_delete_items(
request: Request,
body: Dict[str, Any] = Body(...),
currentUser: User = Depends(getCurrentUser),
context: RequestContext = Depends(getRequestContext)
) -> Dict[str, Any]:
"""Batch delete files/folders with a single SQL-backed operation per type."""
fileIds = body.get("fileIds") or []
folderIds = body.get("folderIds") or []
recursiveFolders = bool(body.get("recursiveFolders", True))
if not isinstance(fileIds, list) or not isinstance(folderIds, list):
raise HTTPException(status_code=400, detail="fileIds and folderIds must be arrays")
try:
mgmt = interfaceDbManagement.getInterface(
currentUser,
mandateId=str(context.mandateId) if context.mandateId else None,
featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None,
)
result = {"deletedFiles": 0, "deletedFolders": 0}
if fileIds:
fileResult = mgmt.deleteFilesBatch(fileIds)
result["deletedFiles"] += fileResult.get("deletedFiles", 0)
if folderIds:
folderResult = mgmt.deleteFoldersBatch(folderIds, recursive=recursiveFolders)
result["deletedFiles"] += folderResult.get("deletedFiles", 0)
result["deletedFolders"] += folderResult.get("deletedFolders", 0)
return result
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(f"Error in batch delete: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/batch-move")
@limiter.limit("10/minute")
def batch_move_items(
request: Request,
body: Dict[str, Any] = Body(...),
currentUser: User = Depends(getCurrentUser),
context: RequestContext = Depends(getRequestContext)
) -> Dict[str, Any]:
"""Batch move files/folders with a single SQL-backed operation per type."""
fileIds = body.get("fileIds") or []
folderIds = body.get("folderIds") or []
targetFolderId = body.get("targetFolderId")
targetParentId = body.get("targetParentId")
if not isinstance(fileIds, list) or not isinstance(folderIds, list):
raise HTTPException(status_code=400, detail="fileIds and folderIds must be arrays")
try:
mgmt = interfaceDbManagement.getInterface(
currentUser,
mandateId=str(context.mandateId) if context.mandateId else None,
featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None,
)
result = {"movedFiles": 0, "movedFolders": 0}
if fileIds:
fileResult = mgmt.moveFilesBatch(fileIds, targetFolderId=targetFolderId)
result["movedFiles"] += fileResult.get("movedFiles", 0)
if folderIds:
folderResult = mgmt.moveFoldersBatch(folderIds, targetParentId=targetParentId)
result["movedFolders"] += folderResult.get("movedFolders", 0)
return result
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(f"Error in batch move: {e}")
raise HTTPException(status_code=500, detail=str(e))
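# Illustrative request bodies for the two batch endpoints above; the base URL,
# bearer token and ids below are placeholders, and the mount point is assumed.
import requests

baseUrl = "https://example.test/api/files"  # hypothetical mount point
headers = {"Authorization": "Bearer <token>"}

# One SQL-backed round-trip per type: two files plus one folder tree.
requests.post(
    f"{baseUrl}/batch-delete",
    json={"fileIds": ["f1", "f2"], "folderIds": ["dir1"], "recursiveFolders": True},
    headers=headers,
)

# Move files into a folder and reparent a folder (None = move to root).
requests.post(
    f"{baseUrl}/batch-move",
    json={"fileIds": ["f3"], "targetFolderId": "dir2",
          "folderIds": ["dir1"], "targetParentId": None},
    headers=headers,
)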
# ── File endpoints with path parameters (catch-all /{fileId}) ─────────────────
@router.get("/{fileId}", response_model=FileItem)
@limiter.limit("30/minute")
def get_file(
@@ -418,3 +840,25 @@ def preview_file(
)
@router.post("/{fileId}/move")
@limiter.limit("10/minute")
def move_file(
request: Request,
fileId: str = Path(...),
body: Dict[str, Any] = Body(...),
currentUser: User = Depends(getCurrentUser),
context: RequestContext = Depends(getRequestContext)
) -> Dict[str, Any]:
"""Move a file to a different folder."""
targetFolderId = body.get("targetFolderId")
try:
mgmt = interfaceDbManagement.getInterface(
currentUser,
mandateId=str(context.mandateId) if context.mandateId else None,
featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None,
)
mgmt.updateFile(fileId, {"folderId": targetFolderId})
return {"success": True, "fileId": fileId, "folderId": targetFolderId}
except Exception as e:
logger.error(f"Error moving file: {e}")
raise HTTPException(status_code=500, detail=str(e))

View file

@@ -764,7 +764,7 @@ def send_password_link(
expiryHours = int(APP_CONFIG.get("Auth_RESET_TOKEN_EXPIRY_HOURS", "24"))
try:
from modules.services import Services
from modules.serviceHub import Services
services = Services(targetUser)
emailSubject = "PowerOn - Passwort setzen"

View file

@@ -395,7 +395,7 @@ def trigger_subscription(
)
# Get messaging service from request app state
from modules.services import getInterface as getServicesInterface
from modules.serviceHub import getInterface as getServicesInterface
services = getServicesInterface(context.user, None, mandateId=str(context.mandateId))
# Convert dict to a Pydantic model

View file

@@ -87,9 +87,10 @@ CLIENT_SECRET = APP_CONFIG.get("Service_GOOGLE_CLIENT_SECRET")
REDIRECT_URI = APP_CONFIG.get("Service_GOOGLE_REDIRECT_URI")
SCOPES = [
"https://www.googleapis.com/auth/gmail.readonly",
"https://www.googleapis.com/auth/drive.readonly",
"https://www.googleapis.com/auth/userinfo.profile",
"https://www.googleapis.com/auth/userinfo.email",
"openid"
"openid",
]
@router.get("/config")
@@ -488,7 +489,7 @@ async def auth_callback(code: str, state: str, request: Request, response: Respo
connection.externalUsername = user_info.get("email")
connection.externalEmail = user_info.get("email")
# Store actually granted scopes for this connection
granted_scopes_list = granted_scopes.split(" ") if granted_scopes else SCOPES
granted_scopes_list = granted_scopes if isinstance(granted_scopes, list) else (granted_scopes.split(" ") if granted_scopes else SCOPES)
connection.grantedScopes = granted_scopes_list
logger.info(f"Storing granted scopes for connection {connection_id}: {granted_scopes_list}")

View file

@@ -59,6 +59,7 @@ SCOPES = [
"Mail.Send", # Send mail
"Files.ReadWrite.All", # Read and write files (SharePoint/OneDrive)
"Sites.ReadWrite.All", # Read and write SharePoint sites
"Team.ReadBasic.All", # List joined teams and channels
# Teams Bot: Meeting and chat access (requires admin consent)
"OnlineMeetings.Read", # Read user's Teams meeting details (delegated scope)
"Chat.ReadWrite", # Read and write Teams chat messages

View file

@@ -12,7 +12,7 @@ from fastapi import APIRouter, HTTPException, Depends, Path, Query, Request, sta
from modules.auth import limiter, getCurrentUser
from modules.datamodels.datamodelUam import User, UserConnection
from modules.interfaces.interfaceDbApp import getInterface
from modules.services import getInterface as getServices
from modules.serviceHub import getInterface as getServices
logger = logging.getLogger(__name__)

View file

@@ -102,12 +102,6 @@ def _getFeatureUiObjects(featureCode: str) -> List[Dict[str, Any]]:
elif featureCode == "realestate":
from modules.features.realEstate.mainRealEstate import UI_OBJECTS
return UI_OBJECTS
elif featureCode == "chatplayground":
from modules.features.chatplayground.mainChatplayground import UI_OBJECTS
return UI_OBJECTS
elif featureCode == "codeeditor":
from modules.features.codeeditor.mainCodeeditor import UI_OBJECTS
return UI_OBJECTS
elif featureCode == "automation":
from modules.features.automation.mainAutomation import UI_OBJECTS
return UI_OBJECTS
@@ -123,8 +117,11 @@ def _getFeatureUiObjects(featureCode: str) -> List[Dict[str, Any]]:
elif featureCode == "commcoach":
from modules.features.commcoach.mainCommcoach import UI_OBJECTS
return UI_OBJECTS
elif featureCode == "workspace":
from modules.features.workspace.mainWorkspace import UI_OBJECTS
return UI_OBJECTS
else:
logger.warning(f"Unknown feature code: {featureCode}")
logger.debug(f"Skipping removed feature code: {featureCode}")
return []
except ImportError as e:
logger.error(f"Failed to import UI_OBJECTS for feature {featureCode}: {e}")

View file

@@ -6,13 +6,16 @@ Replaces Azure voice services with Google Cloud Speech-to-Text and Translation
Includes WebSocket support for real-time voice streaming
"""
import asyncio
import logging
import json
import base64
from fastapi import APIRouter, File, Form, UploadFile, Depends, HTTPException, Body, WebSocket, WebSocketDisconnect
import secrets
import time
from fastapi import APIRouter, File, Form, UploadFile, Depends, HTTPException, Body, Query, Request, WebSocket, WebSocketDisconnect
from fastapi.responses import Response
from typing import Optional, Dict, Any, List
from modules.auth import getCurrentUser
from modules.auth import getCurrentUser, getRequestContext, RequestContext, limiter
from modules.datamodels.datamodelUam import User
from modules.interfaces.interfaceVoiceObjects import getVoiceInterface, VoiceObjects
@@ -290,10 +293,11 @@ async def realtime_interpreter(
@router.post("/text-to-speech")
async def text_to_speech(
request: Request,
text: str = Form(...),
language: str = Form("de-DE"),
voice: str = Form(None),
currentUser: User = Depends(getCurrentUser)
context: RequestContext = Depends(getRequestContext),
):
"""Convert text to speech using Google Cloud Text-to-Speech."""
try:
@@ -305,7 +309,20 @@ async def text_to_speech(
detail="Empty text provided for text-to-speech"
)
voiceInterface = _getVoiceInterface(currentUser)
mandateId = str(getattr(context, "mandateId", "") or "")
voiceInterface = getVoiceInterface(context.user, mandateId)
try:
from modules.serviceCenter.services.serviceBilling.mainServiceBilling import getService as getBillingService
billingService = getBillingService(context.user, mandateId)
def _billingCb(data):
priceCHF = data.get("priceCHF", 0.0)
operation = data.get("operation", "voice")
if priceCHF > 0:
billingService.recordUsage(priceCHF=priceCHF, aicoreProvider="google-voice", aicoreModel=operation, description=f"Voice {operation}")
voiceInterface.billingCallback = _billingCb
except Exception as e:
logger.warning(f"TTS billing setup skipped: {e}")
result = await voiceInterface.textToSpeech(
text=text,
languageCode=language,
@@ -314,12 +331,12 @@ async def text_to_speech(
if result["success"]:
return Response(
content=result["audio_content"],
content=result["audioContent"],
media_type="audio/mpeg",
headers={
"Content-Disposition": "attachment; filename=speech.mp3",
"X-Voice-Name": result["voice_name"],
"X-Language-Code": result["language_code"]
"X-Voice-Name": result.get("voiceName", ""),
"X-Language-Code": result.get("languageCode", language),
}
)
else:
@@ -533,189 +550,192 @@ async def save_voice_settings(
detail=f"Failed to save voice settings: {str(e)}"
)
# WebSocket endpoints for real-time voice streaming
# =========================================================================
# STT Streaming WebSocket — generic, used by all features
# =========================================================================
@router.websocket("/ws/realtime-interpreter")
async def websocket_realtime_interpreter(
websocket: WebSocket,
userId: str = "default",
fromLanguage: str = "de-DE",
toLanguage: str = "en-US"
_sttTokens: Dict[str, Dict[str, Any]] = {}
_STT_TOKEN_TTL = 45
def _cleanupSttTokens():
now = time.time()
expired = [t for t, p in _sttTokens.items() if p.get("expiresAt", 0) <= now]
for t in expired:
_sttTokens.pop(t, None)
@router.post("/stt/token")
@limiter.limit("60/minute")
async def createSttToken(
request: Request,
context: RequestContext = Depends(getRequestContext),
):
"""WebSocket endpoint for real-time voice interpretation"""
connectionId = f"realtime_{userId}_{fromLanguage}_{toLanguage}"
"""Issue a short-lived single-use token for the STT streaming WebSocket."""
_cleanupSttTokens()
token = secrets.token_urlsafe(32)
_sttTokens[token] = {
"userId": str(context.user.id),
"mandateId": str(getattr(context, "mandateId", "") or ""),
"expiresAt": time.time() + _STT_TOKEN_TTL,
}
return {"wsToken": token, "expiresInSeconds": _STT_TOKEN_TTL}
@router.websocket("/stt/stream")
async def sttStream(
websocket: WebSocket,
wsToken: Optional[str] = Query(None),
):
"""
Generic STT streaming WebSocket.
Protocol:
Client sends JSON:
{"type": "open", "language": "de-DE"}
{"type": "audio", "chunk": "<base64>"}
{"type": "close"}
Server sends JSON:
{"type": "interim", "text": "..."}
{"type": "final", "text": "...", "confidence": 0.95}
{"type": "error", "message": "..."}
{"type": "closed"}
"""
await websocket.accept()
# --- authenticate via wsToken ---
if not wsToken:
await websocket.send_json({"type": "error", "code": "ws_token_required", "message": "wsToken query param required"})
await websocket.close(code=1008)
return
_cleanupSttTokens()
tokenPayload = _sttTokens.pop(wsToken, None)
if not tokenPayload:
await websocket.send_json({"type": "error", "code": "ws_token_invalid", "message": "Invalid or expired wsToken"})
await websocket.close(code=1008)
return
tokenUserId = tokenPayload["userId"]
tokenMandateId = tokenPayload.get("mandateId", "")
# Resolve real user for billing
from modules.interfaces.interfaceDbApp import getRootInterface
rootInterface = getRootInterface()
currentUser = rootInterface.getUser(tokenUserId)
if not currentUser:
await websocket.send_json({"type": "error", "code": "user_not_found", "message": "User not found"})
await websocket.close(code=1008)
return
# --- billing pre-flight ---
billingService = None
try:
await manager.connect(websocket, connectionId)
# Send connection confirmation
await manager.sendPersonalMessage({
"type": "connected",
"connection_id": connectionId,
"message": "Connected to real-time interpreter"
}, websocket)
# Initialize voice interface
voiceInterface = _getVoiceInterface(User(id=userId))
while True:
# Receive message from client
data = await websocket.receive_text()
message = json.loads(data)
if message["type"] == "audio_chunk":
# Process audio chunk
try:
# Decode base64 audio data
audioData = base64.b64decode(message["data"])
# For now, just acknowledge receipt
# In a full implementation, this would:
# 1. Buffer audio chunks
# 2. Process with Google Cloud Speech-to-Text streaming
# 3. Send partial results back
# 4. Handle translation
await manager.sendPersonalMessage({
"type": "audio_received",
"chunk_size": len(audioData),
"timestamp": message.get("timestamp")
}, websocket)
from modules.serviceCenter.services.serviceBilling.mainServiceBilling import getService as getBillingService
billingService = getBillingService(currentUser, tokenMandateId)
billingCheck = billingService.checkBalance(0.0)
if not billingCheck.allowed:
await websocket.send_json({"type": "error", "code": "billing_insufficient", "message": "Insufficient balance for voice services"})
await websocket.close(code=1008)
return
except Exception as e:
logger.error(f"Error processing audio chunk: {e}")
await manager.send_personal_message({
"type": "error",
"error": f"Failed to process audio: {str(e)}"
}, websocket)
logger.warning(f"STT billing pre-flight skipped: {e}")
elif message["type"] == "ping":
# Respond to ping
await manager.sendPersonalMessage({
"type": "pong",
"timestamp": message.get("timestamp")
}, websocket)
audioQueue: asyncio.Queue = asyncio.Queue()
language = "de-DE"
streamingTask: Optional[asyncio.Task] = None
voiceInterface: Optional[VoiceObjects] = None
async def _sendJson(payload: Dict[str, Any]) -> bool:
try:
await websocket.send_json(payload)
return True
except Exception:
return False
async def _runStreaming():
nonlocal voiceInterface
voiceInterface = getVoiceInterface(currentUser, tokenMandateId)
if billingService:
def _billingCb(data):
priceCHF = data.get("priceCHF", 0.0)
operation = data.get("operation", "voice")
if priceCHF > 0:
billingService.recordUsage(
priceCHF=priceCHF,
aicoreProvider="google-voice",
aicoreModel=operation,
description=f"Voice {operation}",
)
voiceInterface.billingCallback = _billingCb
try:
async for event in voiceInterface.streamingSpeechToText(audioQueue, language):
if event.get("reconnectRequired"):
await _sendJson({"type": "reconnect_required"})
return
if event.get("isFinal"):
if event.get("transcript"):
await _sendJson({"type": "final", "text": event["transcript"], "confidence": event.get("confidence", 0.0)})
else:
logger.warning(f"Unknown message type: {message['type']}")
except WebSocketDisconnect:
manager.disconnect(websocket, connectionId)
logger.info(f"Client disconnected: {connectionId}")
if event.get("transcript"):
await _sendJson({"type": "interim", "text": event["transcript"]})
except asyncio.CancelledError:
pass
except Exception as e:
logger.error(f"WebSocket error: {e}")
manager.disconnect(websocket, connectionId)
@router.websocket("/ws/speech-to-text")
async def websocket_speech_to_text(
websocket: WebSocket,
userId: str = "default",
language: str = "de-DE"
):
"""WebSocket endpoint for real-time speech-to-text"""
connectionId = f"stt_{userId}_{language}"
logger.error(f"STT streaming error: {e}")
await _sendJson({"type": "error", "message": str(e)})
try:
await manager.connect(websocket, connectionId)
await manager.sendPersonalMessage({
"type": "connected",
"connection_id": connectionId,
"message": "Connected to speech-to-text"
}, websocket)
# Initialize voice interface
voiceInterface = _getVoiceInterface(User(id=userId))
await _sendJson({"type": "status", "label": "STT stream connected"})
while True:
data = await websocket.receive_text()
message = json.loads(data)
raw = await websocket.receive_text()
msg = json.loads(raw)
msgType = (msg.get("type") or "").strip()
if message["type"] == "audio_chunk":
if msgType == "open":
language = msg.get("language") or "de-DE"
if streamingTask and not streamingTask.done():
await audioQueue.put((b"", True))
streamingTask.cancel()
audioQueue = asyncio.Queue()
streamingTask = asyncio.create_task(_runStreaming())
await _sendJson({"type": "status", "label": "Listening..."})
elif msgType == "audio":
chunkB64 = msg.get("chunk")
if not chunkB64:
continue
chunkBytes = base64.b64decode(chunkB64)
if len(chunkBytes) > 400_000:
await _sendJson({"type": "error", "code": "chunk_too_large", "message": "Audio chunk too large"})
continue
await audioQueue.put((chunkBytes, False))
elif msgType == "close":
await audioQueue.put((b"", True))
if streamingTask:
try:
audioData = base64.b64decode(message["data"])
await asyncio.wait_for(streamingTask, timeout=10.0)
except (asyncio.TimeoutError, asyncio.CancelledError):
pass
await _sendJson({"type": "closed"})
await websocket.close()
break
# Process audio chunk
# This would integrate with Google Cloud Speech-to-Text streaming API
await manager.sendPersonalMessage({
"type": "transcription_result",
"text": "Audio chunk received", # Placeholder
"confidence": 0.95,
"is_final": False
}, websocket)
except Exception as e:
logger.error(f"Error processing audio: {e}")
await manager.sendPersonalMessage({
"type": "error",
"error": f"Failed to process audio: {str(e)}"
}, websocket)
elif message["type"] == "ping":
await manager.sendPersonalMessage({
"type": "pong",
"timestamp": message.get("timestamp")
}, websocket)
elif msgType == "ping":
await _sendJson({"type": "pong"})
except WebSocketDisconnect:
manager.disconnect(websocket, connectionId)
logger.info(f"STT WebSocket disconnected: userId={tokenUserId}")
except Exception as e:
logger.error(f"WebSocket error: {e}")
manager.disconnect(websocket, connectionId)
@router.websocket("/ws/text-to-speech")
async def websocket_text_to_speech(
websocket: WebSocket,
userId: str = "default",
language: str = "de-DE",
voice: str = "de-DE-Wavenet-A"
):
"""WebSocket endpoint for real-time text-to-speech"""
connectionId = f"tts_{userId}_{language}_{voice}"
logger.error(f"STT WebSocket error: {e}", exc_info=True)
try:
await manager.connect(websocket, connectionId)
await manager.sendPersonalMessage({
"type": "connected",
"connection_id": connectionId,
"message": "Connected to text-to-speech"
}, websocket)
while True:
data = await websocket.receive_text()
message = json.loads(data)
if message["type"] == "text_to_speak":
try:
text = message["text"]
# Process text-to-speech
# This would integrate with Google Cloud Text-to-Speech API
# For now, send a placeholder response
await manager.sendPersonalMessage({
"type": "audio_data",
"audio": "base64_encoded_audio_here", # Placeholder
"format": "mp3"
}, websocket)
except Exception as e:
logger.error(f"Error processing text-to-speech: {e}")
await manager.sendPersonalMessage({
"type": "error",
"error": f"Failed to process text: {str(e)}"
}, websocket)
elif message["type"] == "ping":
await manager.sendPersonalMessage({
"type": "pong",
"timestamp": message.get("timestamp")
}, websocket)
except WebSocketDisconnect:
manager.disconnect(websocket, connectionId)
except Exception as e:
logger.error(f"WebSocket error: {e}")
manager.disconnect(websocket, connectionId)
await websocket.send_json({"type": "error", "message": str(e)})
except Exception:
pass
finally:
await audioQueue.put((b"", True))
if streamingTask and not streamingTask.done():
streamingTask.cancel()
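# Illustrative client for the token + streaming protocol above (a sketch, not
# shipped frontend code): the HTTP/WS base URLs and router prefix are
# placeholders, and the third-party httpx and websockets packages are assumed.
import asyncio, base64, json
import httpx
import websockets

async def streamAudio(baseHttp: str, baseWs: str, token: str, chunks):
    # 1. Fetch a short-lived single-use WebSocket token.
    async with httpx.AsyncClient() as client:
        resp = await client.post(f"{baseHttp}/stt/token",
                                 headers={"Authorization": f"Bearer {token}"})
        wsToken = resp.json()["wsToken"]
    # 2. Open the stream and speak the documented open/audio/close protocol.
    async with websockets.connect(f"{baseWs}/stt/stream?wsToken={wsToken}") as ws:
        await ws.send(json.dumps({"type": "open", "language": "de-DE"}))
        for chunk in chunks:  # raw audio bytes, kept well under the 400 kB limit
            await ws.send(json.dumps({"type": "audio",
                                      "chunk": base64.b64encode(chunk).decode()}))
        await ws.send(json.dumps({"type": "close"}))
        # 3. Drain interim/final results until the server confirms closure.
        async for raw in ws:
            msg = json.loads(raw)
            if msg["type"] == "final":
                print("final:", msg["text"], msg.get("confidence"))
            elif msg["type"] == "closed":
                break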

View file

@@ -120,6 +120,49 @@ class RbacCatalogService:
return [obj for obj in self._dataObjects.values() if obj["featureCode"] == featureCode]
return list(self._dataObjects.values())
def getAccessibleDataObjects(
self,
featureCode: str,
rbacInstance,
user,
mandateId: str,
featureInstanceId: str,
) -> List[Dict[str, Any]]:
"""Get DATA objects filtered by RBAC read permission for the user.
Args:
featureCode: Feature code to filter by
rbacInstance: RbacClass instance for permission checks
user: User object
mandateId: Mandate scope
featureInstanceId: Feature instance scope
"""
from modules.datamodels.datamodelRbac import AccessRuleContext
allObjects = self.getDataObjects(featureCode)
accessible = []
for obj in allObjects:
objectKey = obj.get("objectKey", "")
try:
perms = rbacInstance.getUserPermissions(
user=user,
context=AccessRuleContext.DATA,
item=objectKey,
mandateId=mandateId,
featureInstanceId=featureInstanceId,
)
if perms.view or perms.read.value != "n":
accessible.append(obj)
except Exception:
pass
return accessible
def getFeaturesWithDataObjects(self) -> List[str]:
"""Get feature codes that have at least one registered DATA object."""
codes = set()
for obj in self._dataObjects.values():
codes.add(obj["featureCode"])
return list(codes)
def getAllObjects(self, featureCode: Optional[str] = None) -> List[Dict[str, Any]]:
"""Get all RBAC objects (UI + RESOURCE + DATA), optionally filtered by feature."""
return self.getUiObjects(featureCode) + self.getResourceObjects(featureCode) + self.getDataObjects(featureCode)
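
Editor's note: a minimal usage sketch of the new catalog filtering, assuming the getCatalogService() accessor used elsewhere in this PR and an RbacClass instance from the request context; the user/mandate variable names are illustrative only.

# Hypothetical caller: list the DATA objects the current user may read (sketch).
from modules.security.rbacCatalog import getCatalogService

catalog = getCatalogService()
readable = catalog.getAccessibleDataObjects(
    featureCode="trustee",            # feature codes seen in this PR (trustee, commcoach, ...)
    rbacInstance=rbac,                # assumed RbacClass instance
    user=currentUser,                 # assumed authenticated user object
    mandateId=mandateId,
    featureInstanceId=featureInstanceId,
)
objectKeys = [obj["objectKey"] for obj in readable]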

View file

@ -26,7 +26,6 @@ logger = logging.getLogger(__name__)
def getService(
key: str,
context: ServiceCenterContext,
legacy_hub: Optional[Any] = None,
) -> Any:
"""
Get a service instance by key for the given context.
@ -34,14 +33,13 @@ def getService(
Args:
key: Service key (e.g., "web", "extraction", "utils")
context: ServiceCenterContext with user, mandate_id, feature_instance_id, workflow
legacy_hub: Optional legacy Services instance for fallback when service not yet migrated
Returns:
Service instance
"""
cache = get_resolution_cache()
resolving = set()
return resolve(key, context, cache, resolving, legacy_hub=legacy_hub)
return resolve(key, context, cache, resolving)
def preWarm(service_keys: Optional[List[str]] = None) -> None:
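
Editor's note: a sketch of a call site after the legacy_hub removal; the context fields mirror the docstring above, and the variable names (and any import path) are placeholders, not part of this diff.

# Callers now resolve services without a legacy fallback hub (sketch).
context = ServiceCenterContext(
    user=currentUser,                     # placeholder user object
    mandate_id=mandateId,
    feature_instance_id=featureInstanceId,
    workflow=workflow,
)
webService = getService("web", context)   # raises KeyError for unknown service keys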

View file

@ -22,6 +22,8 @@ class EventManager:
"""Initialize the event manager."""
self._queues: Dict[str, asyncio.Queue] = {}
self._cleanup_tasks: Dict[str, asyncio.Task] = {}
self._agent_tasks: Dict[str, asyncio.Task] = {}
self._cancelled: Dict[str, bool] = {}
def create_queue(self, workflow_id: str) -> asyncio.Queue:
"""
@ -33,9 +35,22 @@ class EventManager:
Returns:
Async queue for events
"""
if workflow_id in self._cleanup_tasks:
self._cleanup_tasks[workflow_id].cancel()
del self._cleanup_tasks[workflow_id]
logger.debug(f"Cancelled pending cleanup for workflow {workflow_id}")
if workflow_id not in self._queues:
self._queues[workflow_id] = asyncio.Queue()
logger.debug(f"Created event queue for workflow {workflow_id}")
else:
old = self._queues[workflow_id]
while not old.empty():
try:
old.get_nowait()
except asyncio.QueueEmpty:
break
logger.debug(f"Reusing event queue for workflow {workflow_id} (drained stale events)")
return self._queues[workflow_id]
def get_queue(self, workflow_id: str) -> Optional[asyncio.Queue]:
@ -62,6 +77,31 @@ class EventManager:
"""
return workflow_id in self._queues
def register_agent_task(self, workflow_id: str, task: asyncio.Task) -> None:
"""Register the asyncio Task running the agent for a workflow."""
self._agent_tasks[workflow_id] = task
self._cancelled.pop(workflow_id, None)
def is_cancelled(self, workflow_id: str) -> bool:
"""Check if a workflow has been cancelled."""
return self._cancelled.get(workflow_id, False)
async def cancel_agent(self, workflow_id: str) -> bool:
"""Cancel the running agent task for a workflow. Returns True if cancelled."""
self._cancelled[workflow_id] = True
task = self._agent_tasks.pop(workflow_id, None)
if task and not task.done():
task.cancel()
logger.info(f"Cancelled agent task for workflow {workflow_id}")
return True
logger.debug(f"No running agent task found for workflow {workflow_id}")
return False
def _unregister_agent_task(self, workflow_id: str) -> None:
"""Remove the agent task reference after completion."""
self._agent_tasks.pop(workflow_id, None)
self._cancelled.pop(workflow_id, None)
async def emit_event(
self,
context_id: str,
@ -97,6 +137,7 @@ class EventManager:
try:
await queue.put(event)
if event_type not in ("chunk",):
logger.debug(f"Emitted {event_type} event for workflow {context_id}")
except Exception as e:
logger.error(f"Error emitting event for workflow {context_id}: {e}", exc_info=True)

View file

@ -98,6 +98,20 @@ IMPORTABLE_SERVICES: Dict[str, Dict[str, Any]] = {
"objectKey": "service.neutralization",
"label": {"en": "Neutralization", "de": "Neutralisierung", "fr": "Neutralisation"},
},
"agent": {
"module": "modules.serviceCenter.services.serviceAgent.mainServiceAgent",
"class": "AgentService",
"dependencies": ["ai", "chat", "utils", "extraction", "billing", "streaming", "knowledge"],
"objectKey": "service.agent",
"label": {"en": "Agent", "de": "Agent", "fr": "Agent"},
},
"knowledge": {
"module": "modules.serviceCenter.services.serviceKnowledge.mainServiceKnowledge",
"class": "KnowledgeService",
"dependencies": ["ai"],
"objectKey": "service.knowledge",
"label": {"en": "Knowledge Store", "de": "Wissensspeicher", "fr": "Base de connaissances"},
},
}
# RBAC objects for service-level access control (for catalog registration)

View file

@ -2,7 +2,7 @@
# All rights reserved.
"""
Service Center Resolver.
Resolution logic, dependency injection, and optional legacy fallback.
Resolution logic and dependency injection for service instantiation.
"""
import importlib
@ -14,7 +14,6 @@ from modules.serviceCenter.registry import CORE_SERVICES, IMPORTABLE_SERVICES
logger = logging.getLogger(__name__)
# Type for get_service callable passed to services
GetServiceFunc = Callable[[str], Any]
@ -29,50 +28,15 @@ def _load_service_class(module_path: str, class_name: str):
return getattr(module, class_name)
def _create_legacy_hub(ctx: ServiceCenterContext) -> Any:
"""Create legacy Services instance for fallback when service not yet migrated."""
from modules.services import getInterface
return getInterface(
ctx.user,
workflow=ctx.workflow,
mandateId=ctx.mandate_id,
featureInstanceId=ctx.feature_instance_id,
)
def _get_from_legacy(legacy_hub: Any, key: str) -> Any:
"""Map service key to legacy hub attribute (for fallback when service center module fails)."""
key_to_attr = {
"utils": "utils",
"security": "security",
"streaming": "streaming",
"ticket": "ticket",
"messaging": "messaging",
"billing": "billing",
"sharepoint": "sharepoint",
"chat": "chat",
"extraction": "extraction",
"generation": "generation",
"ai": "ai",
"web": "web",
"neutralization": "neutralization",
}
attr = key_to_attr.get(key)
if attr and hasattr(legacy_hub, attr):
return getattr(legacy_hub, attr)
return None
def resolve(
key: str,
context: ServiceCenterContext,
cache: Dict[str, Any],
resolving: Set[str],
legacy_hub: Optional[Any] = None,
) -> Any:
"""
Resolve a service by key. Uses cache, resolves dependencies recursively.
Falls back to legacy_hub if service module cannot be loaded.
Raises KeyError if the service is not registered.
"""
cache_key = f"{_make_context_id(context)}_{key}"
if cache_key in cache:
@ -82,12 +46,10 @@ def resolve(
raise RuntimeError(f"Circular dependency detected for service: {key}")
def get_service(dep_key: str) -> Any:
return resolve(dep_key, context, cache, resolving, legacy_hub)
return resolve(dep_key, context, cache, resolving)
# Try core first
if key in CORE_SERVICES:
spec = CORE_SERVICES[key]
try:
spec = CORE_SERVICES.get(key) or IMPORTABLE_SERVICES.get(key)
if spec:
cls = _load_service_class(spec["module"], spec["class"])
resolving.add(key)
try:
@ -98,43 +60,6 @@ def resolve(
instance = cls(context, get_service)
cache[cache_key] = instance
return instance
except (ImportError, ModuleNotFoundError, AttributeError) as e:
logger.debug(f"Could not load core service '{key}' from service center: {e}")
if legacy_hub:
fallback = _get_from_legacy(legacy_hub, key)
if fallback is not None:
cache[cache_key] = fallback
return fallback
raise
# Try importable
if key in IMPORTABLE_SERVICES:
spec = IMPORTABLE_SERVICES[key]
try:
cls = _load_service_class(spec["module"], spec["class"])
resolving.add(key)
try:
for dep in spec.get("dependencies", []):
get_service(dep)
finally:
resolving.discard(key)
instance = cls(context, get_service)
cache[cache_key] = instance
return instance
except (ImportError, ModuleNotFoundError, AttributeError) as e:
logger.debug(f"Could not load importable service '{key}' from service center: {e}")
if legacy_hub:
fallback = _get_from_legacy(legacy_hub, key)
if fallback is not None:
cache[cache_key] = fallback
return fallback
raise
if legacy_hub:
fallback = _get_from_legacy(legacy_hub, key)
if fallback is not None:
cache[cache_key] = fallback
return fallback
raise KeyError(f"Unknown service: {key}")

View file

@ -0,0 +1,3 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""serviceAgent: AI Agent with ReAct loop and native function calling."""

View file

@ -0,0 +1,162 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""ActionToolAdapter: wraps existing workflow actions (dynamicMode=True) as agent tools."""
import logging
from typing import Dict, Any, List, Optional
from modules.serviceCenter.services.serviceAgent.datamodelAgent import (
ToolDefinition, ToolResult
)
from modules.serviceCenter.services.serviceAgent.toolRegistry import ToolRegistry
logger = logging.getLogger(__name__)
class ActionToolAdapter:
"""Wraps existing Workflow-Actions as Agent-Tools.
Iterates over discovered methods, finds actions with dynamicMode=True,
and registers them in the ToolRegistry with a compound name (method.action).
"""
def __init__(self, actionExecutor):
self._actionExecutor = actionExecutor
self._registeredTools: List[str] = []
def registerAll(self, toolRegistry: ToolRegistry):
"""Discover and register all dynamicMode actions as agent tools."""
from modules.workflows.processing.shared.methodDiscovery import methods
registered = 0
for methodName, methodInfo in methods.items():
if not methodName[0].isupper():
continue
shortName = methodName.replace("Method", "").lower()
methodInstance = methodInfo["instance"]
for actionName, actionInfo in methodInfo["actions"].items():
actionDef = methodInstance._actions.get(actionName)
if not actionDef or not getattr(actionDef, "dynamicMode", False):
continue
compoundName = f"{shortName}_{actionName}"
toolDef = _buildToolDefinition(compoundName, actionDef, actionInfo)
handler = _createDispatchHandler(self._actionExecutor, shortName, actionName)
toolRegistry.registerFromDefinition(toolDef, handler)
self._registeredTools.append(compoundName)
registered += 1
logger.info(f"ActionToolAdapter: registered {registered} tools from workflow actions")
@property
def registeredTools(self) -> List[str]:
"""Names of all tools registered by this adapter."""
return list(self._registeredTools)
def _buildToolDefinition(compoundName: str, actionDef, actionInfo: Dict[str, Any]) -> ToolDefinition:
"""Build a ToolDefinition from a WorkflowActionDefinition."""
parameters = _convertParameterSchema(actionInfo.get("parameters", {}))
return ToolDefinition(
name=compoundName,
description=actionDef.description or actionInfo.get("description", ""),
parameters=parameters,
readOnly=False
)
def _convertParameterSchema(actionParams: Dict[str, Any]) -> Dict[str, Any]:
"""Convert workflow action parameter schema to JSON Schema for tool definitions."""
properties = {}
required = []
for paramName, paramInfo in actionParams.items():
paramType = paramInfo.get("type", "str") if isinstance(paramInfo, dict) else "str"
paramDesc = paramInfo.get("description", "") if isinstance(paramInfo, dict) else ""
paramRequired = paramInfo.get("required", False) if isinstance(paramInfo, dict) else False
jsonType = _pythonTypeToJsonType(paramType)
properties[paramName] = {
"type": jsonType,
"description": paramDesc
}
if paramRequired:
required.append(paramName)
return {
"type": "object",
"properties": properties,
"required": required
}
def _pythonTypeToJsonType(pythonType: str) -> str:
"""Map Python type strings to JSON Schema types."""
mapping = {
"str": "string",
"int": "integer",
"float": "number",
"bool": "boolean",
"list": "array",
"dict": "object",
"List[str]": "array",
"List[int]": "array",
"List[dict]": "array",
"Dict[str, Any]": "object",
}
return mapping.get(pythonType, "string")
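
Editor's note: to illustrate the schema conversion above, here is an input/output pair (the parameter names are invented; the shapes follow the code).

# Illustrative input/output pair for _convertParameterSchema:
actionParams = {
    "query": {"type": "str", "description": "Search text", "required": True},
    "maxHits": {"type": "int", "description": "Result cap"},
}
schema = _convertParameterSchema(actionParams)
# schema == {"type": "object",
#            "properties": {"query": {"type": "string", "description": "Search text"},
#                           "maxHits": {"type": "integer", "description": "Result cap"}},
#            "required": ["query"]}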
def _createDispatchHandler(actionExecutor, methodName: str, actionName: str):
"""Create an async handler that dispatches to the ActionExecutor."""
async def _handler(args: Dict[str, Any], context: Dict[str, Any]) -> ToolResult:
try:
result = await actionExecutor.executeAction(methodName, actionName, args)
data = _formatActionResult(result)
return ToolResult(
toolCallId="",
toolName=f"{methodName}_{actionName}",
success=result.success,
data=data,
error=result.error
)
except Exception as e:
logger.error(f"ActionToolAdapter dispatch failed for {methodName}_{actionName}: {e}")
return ToolResult(
toolCallId="",
toolName=f"{methodName}_{actionName}",
success=False,
error=str(e)
)
return _handler
def _formatActionResult(result) -> str:
"""Format an ActionResult into a text representation for the agent."""
parts = []
if result.resultLabel:
parts.append(f"Result: {result.resultLabel}")
if result.error:
parts.append(f"Error: {result.error}")
if result.documents:
parts.append(f"Documents ({len(result.documents)}):")
for doc in result.documents:
docName = getattr(doc, "documentName", "unnamed")
docType = getattr(doc, "mimeType", "unknown")
parts.append(f" - {docName} ({docType})")
docData = getattr(doc, "documentData", None)
if docData and isinstance(docData, str) and len(docData) < 2000:
parts.append(f" Content: {docData[:2000]}")
if not parts:
parts.append("Action completed successfully." if result.success else "Action failed.")
return "\n".join(parts)

View file

@ -0,0 +1,507 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Agent loop: ReAct pattern with native function calling, budget control, and error handling."""
import asyncio
import logging
import time
import json
import re
from typing import List, Dict, Any, Optional, AsyncGenerator, Callable, Awaitable
from modules.datamodels.datamodelAi import (
AiCallRequest, AiCallOptions, AiCallResponse, OperationTypeEnum
)
from modules.serviceCenter.services.serviceAgent.datamodelAgent import (
AgentState, AgentStatusEnum, AgentConfig, AgentEvent, AgentEventTypeEnum,
ToolCallRequest, ToolResult, ToolCallLog, AgentRoundLog, AgentTrace
)
from modules.serviceCenter.services.serviceAgent.toolRegistry import ToolRegistry
from modules.serviceCenter.services.serviceAgent.conversationManager import (
ConversationManager, buildSystemPrompt
)
from modules.shared.timeUtils import getUtcTimestamp
from modules.shared.jsonUtils import closeJsonStructures
logger = logging.getLogger(__name__)
async def runAgentLoop(
prompt: str,
toolRegistry: ToolRegistry,
config: AgentConfig,
aiCallFn: Callable[[AiCallRequest], Awaitable[AiCallResponse]],
getWorkflowCostFn: Callable[[], Awaitable[float]],
workflowId: str,
userId: str = "",
featureInstanceId: str = "",
buildRagContextFn: Callable[..., Awaitable[str]] = None,
mandateId: str = "",
aiCallStreamFn: Callable = None,
userLanguage: str = "",
conversationHistory: List[Dict[str, Any]] = None,
) -> AsyncGenerator[AgentEvent, None]:
"""Run the agent loop. Yields AgentEvent for each step (SSE-ready).
Args:
prompt: User prompt
toolRegistry: Registry with available tools
config: Agent configuration (maxRounds, maxCostCHF, etc.)
aiCallFn: Function to call the AI (wraps serviceAi.callAi with billing)
getWorkflowCostFn: Function to get current workflow cost
workflowId: Workflow ID for tracking
userId: User ID for tracing
featureInstanceId: Feature instance ID for tracing
buildRagContextFn: Optional async function to build RAG context before each round
mandateId: Mandate ID for RAG scoping
aiCallStreamFn: Optional streaming AI call; yields text chunks, then the final AiCallResponse
userLanguage: ISO 639-1 language code for agent responses
conversationHistory: Prior messages [{role, content/message}] for follow-up context
"""
state = AgentState(workflowId=workflowId, maxRounds=config.maxRounds)
trace = AgentTrace(
workflowId=workflowId, userId=userId,
featureInstanceId=featureInstanceId
)
tools = toolRegistry.getTools()
toolDefinitions = toolRegistry.formatToolsForFunctionCalling()
# Text-based tool descriptions are ONLY used as fallback when native function
# calling is unavailable. Including both creates conflicting instructions
# (text ```tool_call format vs native tool_use blocks) and can cause the model
# to respond with plain text instead of actual tool calls.
toolsText = "" if toolDefinitions else toolRegistry.formatToolsForPrompt()
systemPrompt = buildSystemPrompt(tools, toolsText, userLanguage=userLanguage)
conversation = ConversationManager(systemPrompt)
if conversationHistory:
conversation.loadHistory(conversationHistory)
conversation.addUserMessage(prompt)
while state.status == AgentStatusEnum.RUNNING and state.currentRound < state.maxRounds:
await asyncio.sleep(0)
state.currentRound += 1
roundStartTime = time.time()
roundLog = AgentRoundLog(roundNumber=state.currentRound)
# RAG context injection (before each round for fresh relevance)
if buildRagContextFn:
try:
latestUserMsg = ""
for msg in reversed(conversation.messages):
if msg.get("role") == "user":
latestUserMsg = msg.get("content", "")
break
ragContext = await buildRagContextFn(
currentPrompt=latestUserMsg or prompt,
workflowId=workflowId,
userId=userId,
featureInstanceId=featureInstanceId,
mandateId=mandateId,
)
if ragContext:
conversation.injectRagContext(ragContext)
except Exception as ragErr:
logger.warning(f"RAG context injection failed (non-blocking): {ragErr}")
# Budget check
budgetExceeded = await _checkBudget(config, getWorkflowCostFn)
if budgetExceeded:
state.status = AgentStatusEnum.BUDGET_EXCEEDED
state.abortReason = "Workflow cost budget exceeded"
yield AgentEvent(
type=AgentEventTypeEnum.FINAL,
content=_buildProgressSummary(state, "Budget exceeded. Here is the progress so far.")
)
break
logger.info(f"Agent round {state.currentRound}/{state.maxRounds} for workflow {workflowId} (tools={state.totalToolCalls}, cost={state.totalCostCHF:.4f})")
yield AgentEvent(
type=AgentEventTypeEnum.AGENT_PROGRESS,
data={
"round": state.currentRound,
"maxRounds": state.maxRounds,
"totalAiCalls": state.totalAiCalls,
"totalToolCalls": state.totalToolCalls,
"costCHF": state.totalCostCHF
}
)
# Progressive summarization
if conversation.needsSummarization(state.currentRound):
async def _summarizeCall(summaryPrompt: str) -> str:
req = AiCallRequest(
prompt=summaryPrompt,
options=AiCallOptions(operationType=OperationTypeEnum.DATA_ANALYSE)
)
resp = await aiCallFn(req)
state.totalCostCHF += resp.priceCHF
state.totalAiCalls += 1
return resp.content
await conversation.summarize(state.currentRound, _summarizeCall)
# AI call
aiRequest = AiCallRequest(
prompt="",
options=AiCallOptions(
operationType=OperationTypeEnum.AGENT,
temperature=config.temperature
),
messages=conversation.messages,
tools=toolDefinitions
)
try:
aiResponse = None
streamedText = ""
isFirstChunkOfRound = True
if aiCallStreamFn:
async for chunk in aiCallStreamFn(aiRequest):
if isinstance(chunk, str):
if isFirstChunkOfRound and state.currentRound > 1:
chunk = "\n\n" + chunk
isFirstChunkOfRound = False
elif isFirstChunkOfRound:
isFirstChunkOfRound = False
streamedText += chunk
yield AgentEvent(type=AgentEventTypeEnum.CHUNK, content=chunk)
else:
aiResponse = chunk
if aiResponse is None:
raise RuntimeError("Stream ended without final AiCallResponse")
else:
aiResponse = await aiCallFn(aiRequest)
except Exception as e:
logger.error(f"AI call failed in round {state.currentRound}: {e}", exc_info=True)
state.status = AgentStatusEnum.ERROR
state.abortReason = f"AI call error: {e}"
yield AgentEvent(type=AgentEventTypeEnum.ERROR, content=str(e))
break
state.totalAiCalls += 1
state.totalCostCHF += aiResponse.priceCHF
state.totalProcessingTime += aiResponse.processingTime
roundLog.aiModel = aiResponse.modelName
roundLog.costCHF = aiResponse.priceCHF
if aiResponse.errorCount > 0:
state.status = AgentStatusEnum.ERROR
state.abortReason = f"AI returned error: {aiResponse.content}"
yield AgentEvent(type=AgentEventTypeEnum.ERROR, content=aiResponse.content)
break
# Parse response for tool calls
toolCalls = _parseToolCalls(aiResponse)
textContent = _extractTextContent(aiResponse)
logger.debug(
f"Round {state.currentRound} AI response: model={aiResponse.modelName}, "
f"toolCalls={len(toolCalls)}, nativeToolCalls={'yes' if aiResponse.toolCalls else 'no'}, "
f"contentLen={len(aiResponse.content)}, streamedLen={len(streamedText)}"
)
# Empty response (no content, no tool calls) = model returned nothing useful.
# Burn the round but let the loop continue so the next iteration can retry
# (the failover mechanism in the AI layer will try alternative models).
if not toolCalls and not textContent and not streamedText:
logger.warning(
f"Round {state.currentRound}: AI returned empty response "
f"(model={aiResponse.modelName}). Retrying next round."
)
conversation.addUserMessage(
"Your previous response was empty. Please use the available tools "
"to accomplish the task. Start by planning the steps, then call the "
"appropriate tools."
)
roundLog.durationMs = int((time.time() - roundStartTime) * 1000)
trace.rounds.append(roundLog)
continue
if textContent and not streamedText:
yield AgentEvent(type=AgentEventTypeEnum.MESSAGE, content=textContent)
if not toolCalls:
state.status = AgentStatusEnum.COMPLETED
conversation.addAssistantMessage(aiResponse.content)
roundLog.durationMs = int((time.time() - roundStartTime) * 1000)
trace.rounds.append(roundLog)
yield AgentEvent(type=AgentEventTypeEnum.FINAL, content=textContent or aiResponse.content)
break
# Add assistant message with tool calls to conversation
assistantToolCalls = _formatAssistantToolCalls(toolCalls)
conversation.addAssistantMessage(textContent or "", assistantToolCalls)
# Execute tool calls
for tc in toolCalls:
yield AgentEvent(
type=AgentEventTypeEnum.TOOL_CALL,
data={"toolName": tc.name, "args": tc.args}
)
results = await _executeToolCalls(toolCalls, toolRegistry, {
"workflowId": workflowId,
"userId": userId,
"featureInstanceId": featureInstanceId,
"mandateId": mandateId,
})
state.totalToolCalls += len(results)
for result in results:
roundLog.toolCalls.append(ToolCallLog(
toolName=result.toolName,
args=next((tc.args for tc in toolCalls if tc.id == result.toolCallId), {}),
success=result.success,
durationMs=result.durationMs,
error=result.error,
resultData=result.data[:300] if result.data else "",
))
if not result.success:
logger.warning(f"Tool '{result.toolName}' failed: {result.error}")
yield AgentEvent(
type=AgentEventTypeEnum.TOOL_RESULT,
data={
"toolName": result.toolName,
"success": result.success,
"data": result.data[:500] if result.data else "",
"error": result.error
}
)
if result.sideEvents:
for sideEvt in result.sideEvents:
evtType = sideEvt.get("type", "")
try:
evtEnum = AgentEventTypeEnum(evtType)
except (ValueError, KeyError):
continue
yield AgentEvent(
type=evtEnum,
data=sideEvt.get("data"),
content=sideEvt.get("content"),
)
# Add tool results to conversation
toolResultMessages = [
{"toolCallId": r.toolCallId, "toolName": r.toolName,
"content": r.data if r.success else f"Error: {r.error}"}
for r in results
]
conversation.addToolResults(toolResultMessages)
roundLog.durationMs = int((time.time() - roundStartTime) * 1000)
trace.rounds.append(roundLog)
# maxRounds reached
if state.currentRound >= state.maxRounds and state.status == AgentStatusEnum.RUNNING:
state.status = AgentStatusEnum.MAX_ROUNDS_REACHED
state.abortReason = f"Maximum rounds ({state.maxRounds}) reached"
yield AgentEvent(
type=AgentEventTypeEnum.FINAL,
content=_buildProgressSummary(state, "Maximum rounds reached.")
)
# Agent summary
trace.completedAt = getUtcTimestamp()
trace.status = state.status
trace.totalRounds = state.currentRound
trace.totalToolCalls = state.totalToolCalls
trace.totalCostCHF = state.totalCostCHF
trace.abortReason = state.abortReason
artifactSummary = _buildArtifactSummary(trace.rounds)
yield AgentEvent(
type=AgentEventTypeEnum.AGENT_SUMMARY,
data={
"rounds": state.currentRound,
"totalAiCalls": state.totalAiCalls,
"totalToolCalls": state.totalToolCalls,
"costCHF": round(state.totalCostCHF, 4),
"processingTime": round(state.totalProcessingTime, 2),
"status": state.status.value,
"abortReason": state.abortReason,
"artifacts": artifactSummary,
}
)
async def _checkBudget(config: AgentConfig,
getWorkflowCostFn: Callable[[], Awaitable[float]]) -> bool:
"""Check if workflow budget is exceeded. Returns True if exceeded."""
if config.maxCostCHF is None:
return False
try:
currentCost = await getWorkflowCostFn()
return currentCost > config.maxCostCHF
except Exception as e:
logger.warning(f"Could not check workflow cost: {e}")
return False
async def _executeToolCalls(toolCalls: List[ToolCallRequest],
toolRegistry: ToolRegistry,
context: Dict[str, Any]) -> List[ToolResult]:
"""Execute tool calls: readOnly tools in parallel, others sequentially.
Tool calls with _parseError (truncated JSON from LLM) are short-circuited
with an error result so the agent can retry.
"""
results: Dict[str, ToolResult] = {}
for tc in toolCalls:
if "_parseError" in tc.args:
results[tc.id] = ToolResult(
toolCallId=tc.id,
toolName=tc.name,
success=False,
data="",
error=tc.args["_parseError"],
durationMs=0,
)
activeCalls = [tc for tc in toolCalls if tc.id not in results]
activeReadOnly = [tc for tc in activeCalls if toolRegistry.isReadOnly(tc.name)]
activeWrite = [tc for tc in activeCalls if not toolRegistry.isReadOnly(tc.name)]
if activeReadOnly:
readResults = await asyncio.gather(*[
toolRegistry.dispatch(tc, context) for tc in activeReadOnly
])
for tc, result in zip(activeReadOnly, readResults):
results[tc.id] = result
for tc in activeWrite:
results[tc.id] = await toolRegistry.dispatch(tc, context)
return [results[tc.id] for tc in toolCalls]
def _repairTruncatedJson(raw: str) -> Optional[Dict[str, Any]]:
"""Repair truncated JSON using the shared jsonUtils toolbox.
Uses closeJsonStructures which handles open strings, brackets, braces,
and trailing commas with stack-based structure tracking.
Returns parsed dict on success, None if unrecoverable.
"""
if not raw or not raw.strip().startswith("{"):
return None
try:
closed = closeJsonStructures(raw)
return json.loads(closed)
except Exception:
return None
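
Editor's note: for intuition, assuming closeJsonStructures behaves as described here (closing open strings, brackets, and braces), a truncated tool-argument payload repairs like this.

# Illustrative repair of truncated tool arguments (assumed closeJsonStructures behavior):
#   raw    = '{"fileName": "report.md", "content": "# Title ... some tex'
#   closed = closeJsonStructures(raw)  # -> raw + '"}' (open string and object closed)
#   json.loads(closed)                 # parses; only the cut-off tail is lost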
def _parseToolCalls(aiResponse: AiCallResponse) -> List[ToolCallRequest]:
"""Parse tool calls from AI response. Supports native function calling and text-based fallback."""
toolCalls = []
# Native function calling: check response metadata
if hasattr(aiResponse, 'toolCalls') and aiResponse.toolCalls:
for tc in aiResponse.toolCalls:
rawArgs = tc["function"]["arguments"]
if isinstance(rawArgs, str):
rawArgs = rawArgs.strip()
try:
parsedArgs = json.loads(rawArgs) if rawArgs else {}
except json.JSONDecodeError:
parsedArgs = _repairTruncatedJson(rawArgs)
if parsedArgs is None:
logger.warning(f"Unrecoverable truncated JSON for '{tc['function']['name']}': {rawArgs[:200]}")
parsedArgs = {"_parseError": (
"Your tool call arguments were truncated (output cut off by token limit). "
"The content is too large for a single tool call. Strategies:\n"
"1. For new files: use writeFile(mode='create') with the first part, "
"then writeFile(fileId=..., mode='append') for subsequent parts (~8000 chars each).\n"
"2. For editing existing files: use replaceInFile to change only the specific parts.\n"
"3. For documentation: split into multiple smaller files."
)}
else:
logger.info(f"Repaired truncated JSON for '{tc['function']['name']}'")
else:
parsedArgs = rawArgs if rawArgs else {}
toolCalls.append(ToolCallRequest(
id=tc.get("id", str(len(toolCalls))),
name=tc["function"]["name"],
args=parsedArgs,
))
return toolCalls
# Text-based fallback: parse ```tool_call blocks
content = aiResponse.content or ""
pattern = r"```tool_call\s*\n\s*tool:\s*(\S+)\s*\n\s*args:\s*(\{.*?\})\s*\n\s*```"
matches = re.finditer(pattern, content, re.DOTALL)
for match in matches:
toolName = match.group(1).strip()
argsStr = match.group(2).strip()
try:
args = json.loads(argsStr)
except json.JSONDecodeError:
logger.warning(f"Failed to parse tool args for '{toolName}': {argsStr}")
args = {}
toolCalls.append(ToolCallRequest(name=toolName, args=args))
return toolCalls
def _extractTextContent(aiResponse: AiCallResponse) -> str:
"""Extract text content from AI response, removing tool_call blocks."""
content = aiResponse.content or ""
cleaned = re.sub(r"```tool_call\s*\n.*?\n\s*```", "", content, flags=re.DOTALL)
return cleaned.strip()
def _formatAssistantToolCalls(toolCalls: List[ToolCallRequest]) -> List[Dict[str, Any]]:
"""Format tool calls for the conversation history (OpenAI tool_calls format)."""
return [
{
"id": tc.id,
"type": "function",
"function": {
"name": tc.name,
"arguments": json.dumps(tc.args)
}
}
for tc in toolCalls
]
def _buildProgressSummary(state: AgentState, reason: str) -> str:
"""Build a human-readable summary of agent progress for graceful termination."""
return (
f"{reason}\n\n"
f"Progress after {state.currentRound} rounds:\n"
f"- AI calls: {state.totalAiCalls}\n"
f"- Tool calls: {state.totalToolCalls}\n"
f"- Cost: {state.totalCostCHF:.4f} CHF\n"
f"- Processing time: {state.totalProcessingTime:.1f}s"
)
_ARTIFACT_TOOLS = {"writeFile", "replaceInFile", "deleteFile", "renameFile", "copyFile",
"createFolder", "deleteFolder", "renderDocument", "generateImage"}
def _buildArtifactSummary(roundLogs: List[AgentRoundLog]) -> str:
"""Extract file operations and key results from all agent rounds.
Produces a concise summary persisted as _workflowArtifacts so
follow-up rounds have immediate context (file IDs, names, actions).
"""
ops = []
for log in roundLogs:
for tc in log.toolCalls:
if tc.toolName not in _ARTIFACT_TOOLS or not tc.success:
continue
ops.append(f"- {tc.resultData}" if tc.resultData else f"- {tc.toolName}")
if not ops:
return ""
return "File operations in this run:\n" + "\n".join(ops)

View file

@ -0,0 +1,331 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Conversation manager for the Agent service.
Handles message history, context window management, and progressive summarization."""
import logging
from typing import List, Dict, Any, Optional
from modules.serviceCenter.services.serviceAgent.datamodelAgent import ToolDefinition
logger = logging.getLogger(__name__)
FIRST_SUMMARY_ROUND = 4
META_SUMMARY_ROUND = 7
KEEP_RECENT_MESSAGES = 4
MAX_ESTIMATED_TOKENS = 60000
_MAX_HISTORY_MESSAGES = 40
_MAX_HISTORY_MSG_CHARS = 12000
class ConversationManager:
"""Manages the conversation history and context window for agent runs.
Progressive summarization strategy:
- Rounds 1-3: full conversation retained
- Round 4+: older messages compressed into a running summary
- Round 7+: meta-summary replaces prior summaries
Supports RAG context injection before each round via injectRagContext."""
def __init__(self, systemPrompt: str):
self._messages: List[Dict[str, Any]] = [
{"role": "system", "content": systemPrompt}
]
self._summaries: List[Dict[str, Any]] = []
self._lastSummarizedRound: int = 0
self._ragContextInjected: bool = False
def loadHistory(self, messages: List[Dict[str, Any]]):
"""Load prior conversation messages for follow-up context.
Accepts messages with {role, content/message} format (as stored in DB).
Truncates long messages and limits total count to keep the context window
manageable. Must be called BEFORE addUserMessage with the current prompt.
"""
if not messages:
return
recent = messages[-_MAX_HISTORY_MESSAGES:]
loaded = 0
for msg in recent:
role = msg.get("role", "")
content = msg.get("content", "") or msg.get("message", "") or ""
if role not in ("user", "assistant"):
continue
if not content.strip():
continue
if len(content) > _MAX_HISTORY_MSG_CHARS:
content = content[:_MAX_HISTORY_MSG_CHARS] + "…"
self._messages.append({"role": role, "content": content})
loaded += 1
if loaded:
logger.info(f"Loaded {loaded} history messages into conversation context")
@property
def messages(self) -> List[Dict[str, Any]]:
"""Current messages for the next AI call (internal markers stripped)."""
return [
{k: v for k, v in msg.items() if not k.startswith("_")}
for msg in self._messages
]
def addUserMessage(self, content: str):
"""Add a user message."""
self._messages.append({"role": "user", "content": content})
def addAssistantMessage(self, content: str, toolCalls: List[Dict[str, Any]] = None):
"""Add an assistant message, optionally with tool calls."""
msg: Dict[str, Any] = {"role": "assistant", "content": content}
if toolCalls:
msg["tool_calls"] = toolCalls
self._messages.append(msg)
def addToolResults(self, results: List[Dict[str, Any]]):
"""Add tool results to the conversation.
Each result: {toolCallId, toolName, content}."""
for result in results:
self._messages.append({
"role": "tool",
"tool_call_id": result["toolCallId"],
"content": result["content"]
})
def addToolResultsAsText(self, resultText: str):
"""Add combined tool results as a user message (text-based fallback)."""
self._messages.append({
"role": "user",
"content": f"Tool Results:\n{resultText}"
})
def injectRagContext(self, ragContext: str):
"""Inject RAG context as a system message right after the main system prompt.
Called before each agent round by the agent loop if KnowledgeService is available.
Replaces any previously injected RAG context to keep the context fresh."""
if not ragContext:
return
ragMessage = {
"role": "system",
"content": f"Relevant Knowledge (from indexed documents and workflow context):\n{ragContext}",
"_isRagContext": True,
}
# Replace existing RAG message if present, otherwise insert after system prompt
for i, msg in enumerate(self._messages):
if msg.get("_isRagContext"):
self._messages[i] = ragMessage
self._ragContextInjected = True
return
# Insert after the first system prompt
self._messages.insert(1, ragMessage)
self._ragContextInjected = True
def getMessageCount(self) -> int:
"""Get the number of messages (excluding system prompt)."""
return len(self._messages) - 1
def estimateTokenCount(self) -> int:
"""Rough estimate of total tokens in the conversation (4 chars ≈ 1 token)."""
totalChars = sum(len(str(m.get("content", ""))) for m in self._messages)
return totalChars // 4
def needsSummarization(self, currentRound: int) -> bool:
"""Check if progressive summarization should be triggered.
Triggers:
- At round FIRST_SUMMARY_ROUND (4) if not yet summarized
- At round META_SUMMARY_ROUND (7) for meta-summary
- Every 5 rounds after that
- When estimated token count exceeds MAX_ESTIMATED_TOKENS
"""
if currentRound >= FIRST_SUMMARY_ROUND and self._lastSummarizedRound < currentRound:
if currentRound == FIRST_SUMMARY_ROUND or currentRound == META_SUMMARY_ROUND:
return True
if (currentRound - META_SUMMARY_ROUND) % 5 == 0 and currentRound > META_SUMMARY_ROUND:
return True
if self.estimateTokenCount() > MAX_ESTIMATED_TOKENS:
return True
return False
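
Editor's note: concretely, with the constants above the summarization schedule works out as follows (token-pressure trigger aside).

# Rounds that trigger summarization with FIRST_SUMMARY_ROUND=4, META_SUMMARY_ROUND=7:
#   round 4             -> first summary
#   round 7             -> meta-summary
#   rounds 12, 17, 22…  -> every 5 rounds after the meta-summary
#   any round >= 4 with estimateTokenCount() > 60000 -> forced summary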
async def summarize(self, currentRound: int, aiCallFn) -> Optional[str]:
"""Perform progressive summarization of older messages.
Rounds 1-3: full history retained, no summarization.
Round 4+: compress older messages into a running summary.
Round 7+: meta-summary that consolidates prior summaries.
"""
if currentRound < FIRST_SUMMARY_ROUND and self.estimateTokenCount() <= MAX_ESTIMATED_TOKENS:
return None
systemMsgs = [m for m in self._messages if m.get("role") == "system"]
nonSystemMessages = [m for m in self._messages if m.get("role") != "system"]
keepRecent = min(KEEP_RECENT_MESSAGES, len(nonSystemMessages))
if len(nonSystemMessages) <= keepRecent + 1:
return None
splitIdx = len(nonSystemMessages) - keepRecent
# Ensure the split doesn't orphan tool messages from their assistant.
# Walk backwards from splitIdx: if we're landing in the middle of a
# tool-call sequence (assistant+tool_calls → tool → tool …), include
# the entire sequence in recentMessages.
while splitIdx > 0 and nonSystemMessages[splitIdx].get("role") == "tool":
splitIdx -= 1
# Also include the assistant message that triggered the tool calls.
if splitIdx > 0 and splitIdx < len(nonSystemMessages) and \
nonSystemMessages[splitIdx].get("role") == "assistant" and \
nonSystemMessages[splitIdx].get("tool_calls"):
pass # splitIdx already points at the assistant; keep it in recent
elif splitIdx == 0:
return None # nothing to summarize
messagesToSummarize = nonSystemMessages[:splitIdx]
recentMessages = nonSystemMessages[splitIdx:]
summaryInput = _formatMessagesForSummary(messagesToSummarize)
previousSummary = self._summaries[-1]["content"] if self._summaries else ""
isMetaSummary = currentRound >= META_SUMMARY_ROUND and len(self._summaries) >= 2
summaryPrompt = _buildSummaryPrompt(summaryInput, previousSummary, isMetaSummary)
try:
summaryText = await aiCallFn(summaryPrompt)
except Exception as e:
logger.error(f"Progressive summarization failed: {e}")
return None
self._summaries.append({
"round": currentRound,
"content": summaryText,
"isMeta": isMetaSummary,
})
self._lastSummarizedRound = currentRound
mainSystem = systemMsgs[0] if systemMsgs else {"role": "system", "content": ""}
ragMessages = [m for m in systemMsgs if m.get("_isRagContext")]
self._messages = [
mainSystem,
*ragMessages,
{"role": "system", "content": f"Conversation Summary (rounds 1-{currentRound - keepRecent}):\n{summaryText}"},
*recentMessages,
]
logger.info(
f"Progressive summarization at round {currentRound}: "
f"compressed {len(messagesToSummarize)} messages into "
f"{'meta-' if isMetaSummary else ''}summary"
)
return summaryText
def _formatMessagesForSummary(messages: List[Dict[str, Any]]) -> str:
"""Format messages into a text block for summarization."""
parts = []
for msg in messages:
role = msg.get("role", "unknown")
content = msg.get("content", "")
if role == "tool":
toolName = msg.get("tool_call_id", "tool")
parts.append(f"[Tool Result ({toolName})]:\n{content}")
elif role == "assistant" and msg.get("tool_calls"):
calls = msg["tool_calls"]
callNames = [c.get("function", {}).get("name", "?") for c in calls]
parts.append(f"[Assistant → Tool Calls: {', '.join(callNames)}]")
if content:
parts.append(f"[Assistant]: {content}")
else:
parts.append(f"[{role.capitalize()}]: {content}")
return "\n\n".join(parts)
def _buildSummaryPrompt(messagesText: str, previousSummary: str, isMetaSummary: bool = False) -> str:
"""Build the prompt for progressive summarization."""
if isMetaSummary:
prompt = (
"Create a comprehensive meta-summary consolidating the previous summary "
"and the new messages. Preserve all key facts, decisions, entities (names, "
"numbers, dates), tool results, and action outcomes. Be concise but complete.\n\n"
)
else:
prompt = (
"Summarize the following conversation concisely. Preserve all key facts, "
"decisions, entities (names, numbers, dates), and tool results. "
"Do not lose any important information.\n\n"
)
if previousSummary:
prompt += f"Previous Summary:\n{previousSummary}\n\n"
prompt += f"New Messages to Summarize:\n{messagesText}\n\nProvide a concise, factual summary:"
return prompt
_LANGUAGE_NAMES = {
"de": "German", "en": "English", "fr": "French", "it": "Italian",
"es": "Spanish", "pt": "Portuguese", "nl": "Dutch", "ja": "Japanese",
"zh": "Chinese", "ko": "Korean", "ar": "Arabic", "ru": "Russian",
}
def buildSystemPrompt(
tools: List[ToolDefinition],
toolsFormatted: str = None,
userLanguage: str = "",
) -> str:
"""Build the system prompt for the agent.
Args:
tools: Available tool definitions.
toolsFormatted: Pre-formatted tool descriptions for text-based fallback.
userLanguage: ISO 639-1 language code (e.g. "de", "en"). The agent will
respond in this language.
"""
langName = _LANGUAGE_NAMES.get(userLanguage, "")
langInstruction = (
f"IMPORTANT: Always respond in {langName} ({userLanguage}). "
f"The user's language is {langName}. All your messages, explanations, "
f"and summaries MUST be in {langName}. "
f"Only use English for tool call arguments and technical identifiers.\n\n"
) if langName else ""
prompt = (
f"{langInstruction}"
"You are an AI agent with access to tools. "
"Use the provided tools to accomplish the user's task. "
"Think step by step. Call tools when you need information or need to perform actions. "
"When you have enough information to answer, respond directly without calling tools.\n\n"
)
prompt += (
"## Working Guidelines\n\n"
"### Workflow Context\n"
"When continuing a workflow (follow-up message), the Relevant Knowledge section contains "
"artifacts from previous rounds (file IDs, operations). Use this context instead of "
"re-searching or re-listing files.\n\n"
"### Efficient File Editing\n"
"- Use readFile with offset/limit to read specific line ranges of large files.\n"
"- Use searchInFileContent to find text before editing.\n"
"- Use replaceInFile for targeted edits (preferred over rewriting entire files).\n"
"- Use writeFile(mode='overwrite') only when the entire content must change.\n\n"
"### Large Content Strategy\n"
"- For content larger than ~8000 characters: use writeFile(mode='create') for the first "
"part, then writeFile(fileId=..., mode='append') for subsequent parts.\n"
"- Split large documentation into multiple focused files rather than one huge document.\n"
"- Structure outputs so files reference each other (e.g. index.md linking to sections).\n\n"
"### Code Generation\n"
"- Prefer modular file structures over monolithic files.\n"
"- When generating applications, create separate files for logical components.\n"
"- Always plan the structure before writing code.\n\n"
)
if toolsFormatted:
prompt += f"Available Tools:\n{toolsFormatted}\n\n"
prompt += (
"To call a tool, use this format:\n"
"```tool_call\n"
"tool: <tool_name>\n"
'args: {"param": "value"}\n'
"```\n\n"
)
return prompt

View file

@ -0,0 +1,153 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Data models for the Agent service."""
from typing import List, Dict, Any, Optional
from enum import Enum
from pydantic import BaseModel, Field
from modules.shared.timeUtils import getUtcTimestamp
import uuid
class AgentStatusEnum(str, Enum):
RUNNING = "running"
COMPLETED = "completed"
MAX_ROUNDS_REACHED = "maxRoundsReached"
BUDGET_EXCEEDED = "budgetExceeded"
ERROR = "error"
STOPPED = "stopped"
class AgentEventTypeEnum(str, Enum):
MESSAGE = "message"
CHUNK = "chunk"
TOOL_CALL = "toolCall"
TOOL_RESULT = "toolResult"
AGENT_PROGRESS = "agentProgress"
AGENT_SUMMARY = "agentSummary"
FILE_CREATED = "fileCreated"
FILE_UPDATED = "fileUpdated"
FILE_EDIT_PROPOSAL = "fileEditProposal"
FILE_VERSION = "fileVersion"
FILE_EDIT_REJECTED = "fileEditRejected"
DATA_SOURCE_ACCESS = "dataSourceAccess"
VOICE_RESPONSE = "voiceResponse"
FINAL = "final"
ERROR = "error"
class ToolDefinition(BaseModel):
"""Schema for a tool available to the agent."""
name: str = Field(description="Unique tool name")
description: str = Field(description="What this tool does")
parameters: Dict[str, Any] = Field(
default_factory=dict,
description="JSON Schema for tool parameters"
)
readOnly: bool = Field(
default=False,
description="If True, tool can run in parallel with other readOnly tools"
)
featureType: Optional[str] = Field(
default=None,
description="Feature scope for this tool (None = available to all)"
)
toolSet: Optional[str] = Field(
default=None,
description="Tool-set scope (None = available to all sets, e.g. 'core', 'workspace')"
)
class ToolCallRequest(BaseModel):
"""A tool call requested by the AI model."""
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
name: str
args: Dict[str, Any] = Field(default_factory=dict)
class ToolResult(BaseModel):
"""Result from executing a tool."""
toolCallId: str
toolName: str
success: bool = True
data: str = ""
error: Optional[str] = None
durationMs: int = 0
sideEvents: Optional[List[Dict[str, Any]]] = None
class AgentEvent(BaseModel):
"""Event emitted during agent execution for SSE streaming."""
type: AgentEventTypeEnum
content: Optional[str] = None
data: Optional[Dict[str, Any]] = None
class AgentConfig(BaseModel):
"""Configuration for an agent run."""
maxRounds: int = Field(default=25, ge=1, le=100)
maxCostCHF: Optional[float] = Field(default=None, ge=0.0)
toolSet: str = Field(default="core")
temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0)
class AgentState(BaseModel):
"""Tracks state across an agent loop execution."""
workflowId: str
currentRound: int = 0
maxRounds: int = 25
totalAiCalls: int = 0
totalToolCalls: int = 0
totalCostCHF: float = 0.0
totalProcessingTime: float = 0.0
status: AgentStatusEnum = AgentStatusEnum.RUNNING
abortReason: Optional[str] = None
class ToolCallLog(BaseModel):
"""Log of a single tool call for observability."""
toolName: str
args: Dict[str, Any] = Field(default_factory=dict)
success: bool = True
durationMs: int = 0
error: Optional[str] = None
resultData: str = Field(default="", description="Short result summary for artifact tracking")
class AgentRoundLog(BaseModel):
"""Log of a single agent round for observability."""
roundNumber: int
aiModel: str = ""
inputTokens: int = 0
outputTokens: int = 0
costCHF: float = 0.0
toolCalls: List[ToolCallLog] = Field(default_factory=list)
durationMs: int = 0
class AgentTrace(BaseModel):
"""Full trace of an agent workflow for observability."""
workflowId: str
userId: str = ""
featureInstanceId: str = ""
startedAt: float = Field(default_factory=getUtcTimestamp)
completedAt: Optional[float] = None
status: AgentStatusEnum = AgentStatusEnum.RUNNING
totalRounds: int = 0
totalToolCalls: int = 0
totalCostCHF: float = 0.0
abortReason: Optional[str] = None
rounds: List[AgentRoundLog] = Field(default_factory=list)
class PendingFileEdit(BaseModel):
"""A proposed file edit awaiting user approval."""
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
fileId: str
fileName: str
mimeType: str = ""
oldContent: str = ""
newContent: str = ""
status: str = Field(default="pending", description="pending | accepted | rejected")
toolCallId: str = ""
workflowId: str = ""

View file

@ -0,0 +1,253 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Feature Data Sub-Agent.
Specialized mini-agent that queries feature-instance data tables. Receives
schema context (fields, descriptions) for the selected tables and has two
tools: browseTable and queryTable. Runs its own agent loop (max 5 rounds,
low budget) and returns structured results back to the main agent.
"""
import json
import logging
from typing import Any, Callable, Awaitable, Dict, List, Optional
from modules.datamodels.datamodelAi import (
AiCallRequest, AiCallOptions, AiCallResponse, OperationTypeEnum,
)
from modules.serviceCenter.services.serviceAgent.agentLoop import runAgentLoop
from modules.serviceCenter.services.serviceAgent.datamodelAgent import (
AgentConfig, AgentEvent, AgentEventTypeEnum, ToolResult,
)
from modules.serviceCenter.services.serviceAgent.toolRegistry import ToolRegistry
from modules.serviceCenter.services.serviceAgent.featureDataProvider import FeatureDataProvider
logger = logging.getLogger(__name__)
_MAX_ROUNDS = 5
_MAX_COST_CHF = 0.10
async def runFeatureDataAgent(
question: str,
featureInstanceId: str,
featureCode: str,
selectedTables: List[Dict[str, Any]],
mandateId: str,
userId: str,
aiCallFn: Callable[[AiCallRequest], Awaitable[AiCallResponse]],
dbConnector,
instanceLabel: str = "",
) -> str:
"""Run the feature data sub-agent and return the textual result.
Args:
question: The user/main-agent question to answer using feature data.
featureInstanceId: Feature instance to scope queries.
featureCode: Feature code (trustee, commcoach, ...).
selectedTables: List of DATA_OBJECT dicts the user selected.
mandateId: Mandate scope.
userId: Calling user ID.
aiCallFn: AI call function (with billing).
dbConnector: DatabaseConnector for queries.
instanceLabel: Human-readable instance name for context.
Returns:
Plain-text answer produced by the sub-agent.
"""
provider = FeatureDataProvider(dbConnector)
registry = _buildSubAgentTools(provider, featureInstanceId, mandateId)
for tbl in selectedTables:
meta = tbl.get("meta", {})
tableName = meta.get("table", "")
if tableName:
realCols = provider.getActualColumns(tableName)
if realCols:
meta["fields"] = realCols
schemaContext = _buildSchemaContext(featureCode, instanceLabel, selectedTables)
prompt = f"{schemaContext}\n\nUser question:\n{question}"
config = AgentConfig(maxRounds=_MAX_ROUNDS, maxCostCHF=_MAX_COST_CHF)
async def _getWorkflowCost() -> float:
return 0.0
result = ""
async for event in runAgentLoop(
prompt=prompt,
toolRegistry=registry,
config=config,
aiCallFn=aiCallFn,
getWorkflowCostFn=_getWorkflowCost,
workflowId=f"fda-{featureInstanceId[:8]}",
userId=userId,
featureInstanceId=featureInstanceId,
mandateId=mandateId,
):
if event.type == AgentEventTypeEnum.FINAL and event.content:
result = event.content
elif event.type == AgentEventTypeEnum.MESSAGE and event.content:
result += event.content
return result or "(no data returned by feature agent)"
# ------------------------------------------------------------------
# tool registration
# ------------------------------------------------------------------
def _buildSubAgentTools(
provider: FeatureDataProvider,
featureInstanceId: str,
mandateId: str,
) -> ToolRegistry:
"""Register browseTable and queryTable as sub-agent tools."""
registry = ToolRegistry()
async def _browseTable(args: Dict[str, Any], context: Dict[str, Any]):
tableName = args.get("tableName", "")
limit = args.get("limit", 50)
offset = args.get("offset", 0)
fields = args.get("fields")
if not tableName:
return ToolResult(toolCallId="", toolName="browseTable", success=False, error="tableName required")
result = provider.browseTable(
tableName=tableName,
featureInstanceId=featureInstanceId,
mandateId=mandateId,
fields=fields,
limit=min(limit, 200),
offset=offset,
)
return ToolResult(
toolCallId="", toolName="browseTable",
success="error" not in result,
data=json.dumps(result, default=str, ensure_ascii=False)[:30000],
error=result.get("error"),
)
async def _queryTable(args: Dict[str, Any], context: Dict[str, Any]):
tableName = args.get("tableName", "")
filters = args.get("filters", [])
fields = args.get("fields")
orderBy = args.get("orderBy")
limit = args.get("limit", 50)
offset = args.get("offset", 0)
if not tableName:
return ToolResult(toolCallId="", toolName="queryTable", success=False, error="tableName required")
result = provider.queryTable(
tableName=tableName,
featureInstanceId=featureInstanceId,
mandateId=mandateId,
filters=filters,
fields=fields,
orderBy=orderBy,
limit=min(limit, 200),
offset=offset,
)
return ToolResult(
toolCallId="", toolName="queryTable",
success="error" not in result,
data=json.dumps(result, default=str, ensure_ascii=False)[:30000],
error=result.get("error"),
)
registry.register(
"browseTable", _browseTable,
description="List rows from a feature data table with pagination.",
parameters={
"type": "object",
"properties": {
"tableName": {"type": "string", "description": "Name of the table to browse"},
"fields": {
"type": "array", "items": {"type": "string"},
"description": "Optional list of fields to return (default: all)",
},
"limit": {"type": "integer", "description": "Max rows to return (default 50, max 200)"},
"offset": {"type": "integer", "description": "Row offset for pagination"},
},
"required": ["tableName"],
},
readOnly=True,
)
registry.register(
"queryTable", _queryTable,
description=(
"Query a feature data table with filters, field selection, and ordering. "
"Filters: [{\"field\": \"status\", \"op\": \"=\", \"value\": \"active\"}]. "
"Operators: =, !=, >, <, >=, <=, LIKE, ILIKE, IS NULL, IS NOT NULL."
),
parameters={
"type": "object",
"properties": {
"tableName": {"type": "string", "description": "Name of the table to query"},
"filters": {
"type": "array",
"items": {
"type": "object",
"properties": {
"field": {"type": "string"},
"op": {"type": "string"},
"value": {},
},
},
"description": "Filter conditions",
},
"fields": {
"type": "array", "items": {"type": "string"},
"description": "Optional list of fields to return",
},
"orderBy": {"type": "string", "description": "Field name to order by"},
"limit": {"type": "integer", "description": "Max rows (default 50, max 200)"},
"offset": {"type": "integer", "description": "Row offset"},
},
"required": ["tableName"],
},
readOnly=True,
)
return registry
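
Editor's note: for reference, arguments the sub-agent might emit for queryTable; the table and field names are invented.

# Illustrative queryTable arguments (as produced by the model):
exampleArgs = {
    "tableName": "trusteeCases",
    "filters": [
        {"field": "status", "op": "=", "value": "active"},
        {"field": "dueDate", "op": "<", "value": "2026-01-01"},
    ],
    "fields": ["id", "status", "dueDate"],
    "orderBy": "dueDate",
    "limit": 20,
}
# _queryTable caps limit at 200 and always adds the featureInstanceId/mandateId scope.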
# ------------------------------------------------------------------
# context building
# ------------------------------------------------------------------
def _buildSchemaContext(
featureCode: str,
instanceLabel: str,
selectedTables: List[Dict[str, Any]],
) -> str:
"""Build a system-level context block describing available tables."""
parts = [
f"You are a data query assistant for the '{featureCode}' feature",
]
if instanceLabel:
parts[0] += f' (instance: "{instanceLabel}")'
parts[0] += "."
parts.append(
"You have access to the following data tables. "
"Use browseTable to list rows and queryTable to filter/search."
)
parts.append("")
for obj in selectedTables:
meta = obj.get("meta", {})
tbl = meta.get("table", "?")
fields = meta.get("fields", [])
label = obj.get("label", {})
labelStr = label.get("en") or label.get("de") or tbl
parts.append(f"Table: {tbl} ({labelStr})")
if fields:
parts.append(f" Fields: {', '.join(fields)}")
parts.append("")
parts.append(
"Answer the user's question using the data from these tables. "
"Be precise, cite row counts, and format data clearly."
)
return "\n".join(parts)

View file

@ -0,0 +1,215 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Generic data provider for querying feature-instance tables.
Uses the RBAC catalog's DATA_OBJECTS metadata (table name, fields) and the
DB connector to execute scoped, read-only queries against any registered
feature table. All queries are automatically filtered by featureInstanceId
and mandateId so data isolation is guaranteed.
"""
import logging
import json
from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)
_ALLOWED_OPERATORS = {"=", "!=", ">", "<", ">=", "<=", "LIKE", "ILIKE", "IS NULL", "IS NOT NULL"}
class FeatureDataProvider:
"""Reads feature-instance data from the DB using DATA_OBJECTS metadata."""
def __init__(self, dbConnector):
"""
Args:
dbConnector: A connectorDbPostgre.DatabaseConnector with an open connection.
"""
self._db = dbConnector
# ------------------------------------------------------------------
# public API (called by FeatureDataAgent tools)
# ------------------------------------------------------------------
def getAvailableTables(self, featureCode: str) -> List[Dict[str, Any]]:
"""Return DATA_OBJECTS registered for *featureCode*."""
from modules.security.rbacCatalog import getCatalogService
catalog = getCatalogService()
return catalog.getDataObjects(featureCode)
def getTableSchema(self, featureCode: str, tableName: str) -> Optional[Dict[str, Any]]:
"""Return the DATA_OBJECT entry for a specific table."""
for obj in self.getAvailableTables(featureCode):
if obj.get("meta", {}).get("table") == tableName:
return obj
return None
def getActualColumns(self, tableName: str) -> List[str]:
"""Read real column names from PostgreSQL information_schema."""
try:
conn = self._db.connection
with conn.cursor() as cur:
cur.execute(
"SELECT column_name FROM information_schema.columns "
"WHERE table_schema = 'public' AND LOWER(table_name) = LOWER(%s) "
"ORDER BY ordinal_position",
[tableName],
)
cols = [row["column_name"] for row in cur.fetchall()]
return [c for c in cols if not c.startswith("_")]
except Exception as e:
logger.warning(f"getActualColumns({tableName}) failed: {e}")
return []
def browseTable(
self,
tableName: str,
featureInstanceId: str,
mandateId: str,
fields: List[str] = None,
limit: int = 50,
offset: int = 0,
) -> Dict[str, Any]:
"""List rows from a feature table with pagination.
Returns ``{"rows": [...], "total": N, "limit": L, "offset": O}``.
"""
_validateTableName(tableName)
scopeFilter = _buildScopeFilter(tableName, featureInstanceId, mandateId)
try:
conn = self._db.connection
with conn.cursor() as cur:
countSql = f'SELECT COUNT(*) FROM "{tableName}" WHERE {scopeFilter["where"]}'
cur.execute(countSql, scopeFilter["params"])
total = cur.fetchone()["count"] if cur.rowcount else 0
selectCols = ", ".join(f'"{f}"' for f in fields) if fields else "*"
dataSql = (
f'SELECT {selectCols} FROM "{tableName}" '
f'WHERE {scopeFilter["where"]} '
f'ORDER BY "id" LIMIT %s OFFSET %s'
)
cur.execute(dataSql, scopeFilter["params"] + [limit, offset])
rows = [_serializeRow(dict(r)) for r in cur.fetchall()]
return {"rows": rows, "total": total, "limit": limit, "offset": offset}
except Exception as e:
logger.error(f"browseTable({tableName}) failed: {e}")
return {"rows": [], "total": 0, "limit": limit, "offset": offset, "error": str(e)}
def queryTable(
self,
tableName: str,
featureInstanceId: str,
mandateId: str,
filters: List[Dict[str, Any]] = None,
fields: List[str] = None,
orderBy: str = None,
limit: int = 50,
offset: int = 0,
) -> Dict[str, Any]:
"""Query a feature table with optional filters.
``filters`` is a list of ``{"field": "x", "op": "=", "value": "y"}``.
"""
_validateTableName(tableName)
scopeFilter = _buildScopeFilter(tableName, featureInstanceId, mandateId)
extraWhere, extraParams = _buildFilterClauses(filters)
fullWhere = scopeFilter["where"]
allParams = list(scopeFilter["params"])
if extraWhere:
fullWhere += " AND " + extraWhere
allParams.extend(extraParams)
try:
conn = self._db.connection
with conn.cursor() as cur:
countSql = f'SELECT COUNT(*) FROM "{tableName}" WHERE {fullWhere}'
cur.execute(countSql, allParams)
total = cur.fetchone()["count"] if cur.rowcount else 0
                safeFields = [f for f in (fields or []) if _isValidIdentifier(f)]
                selectCols = ", ".join(f'"{f}"' for f in safeFields) if safeFields else "*"
orderClause = f'ORDER BY "{orderBy}"' if orderBy and _isValidIdentifier(orderBy) else 'ORDER BY "id"'
dataSql = (
f'SELECT {selectCols} FROM "{tableName}" '
f'WHERE {fullWhere} {orderClause} LIMIT %s OFFSET %s'
)
cur.execute(dataSql, allParams + [limit, offset])
rows = [_serializeRow(dict(r)) for r in cur.fetchall()]
return {"rows": rows, "total": total, "limit": limit, "offset": offset}
except Exception as e:
logger.error(f"queryTable({tableName}) failed: {e}")
return {"rows": [], "total": 0, "limit": limit, "offset": offset, "error": str(e)}
# ------------------------------------------------------------------
# helpers
# ------------------------------------------------------------------
def _validateTableName(tableName: str):
if not tableName or not _isValidIdentifier(tableName):
raise ValueError(f"Invalid table name: {tableName}")
def _isValidIdentifier(name: str) -> bool:
    """Only allow ASCII letters, digits, and underscore to prevent SQL injection."""
    return bool(re.fullmatch(r"[A-Za-z_][A-Za-z0-9_]*", name))
def _buildScopeFilter(tableName: str, featureInstanceId: str, mandateId: str) -> Dict[str, Any]:
"""Build the mandatory WHERE clause that scopes rows to the feature instance.
Feature tables usually have either ``featureInstanceId`` or a combination
of ``mandateId`` + an org/context FK. We try ``featureInstanceId`` first,
then fall back to ``mandateId``.
"""
conditions = []
params = []
conditions.append('"featureInstanceId" = %s')
params.append(featureInstanceId)
if mandateId:
conditions.append('"mandateId" = %s')
params.append(mandateId)
return {"where": " AND ".join(conditions), "params": params}
def _buildFilterClauses(filters: Optional[List[Dict[str, Any]]]) -> tuple:
"""Convert agent-provided filter dicts into safe SQL."""
if not filters:
return "", []
parts = []
params = []
for f in filters:
field = f.get("field", "")
op = (f.get("op") or "=").upper()
value = f.get("value")
if not field or not _isValidIdentifier(field):
continue
if op not in _ALLOWED_OPERATORS:
continue
if op in ("IS NULL", "IS NOT NULL"):
parts.append(f'"{field}" {op}')
else:
parts.append(f'"{field}" {op} %s')
params.append(value)
return " AND ".join(parts), params
def _serializeRow(row: Dict[str, Any]) -> Dict[str, Any]:
"""Ensure all values are JSON-serializable."""
for k, v in row.items():
if isinstance(v, (bytes, bytearray)):
row[k] = f"<binary {len(v)} bytes>"
elif hasattr(v, "isoformat"):
row[k] = v.isoformat()
return row
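# --- usage sketch (illustrative names, not part of the module) ------------
#   provider = FeatureDataProvider(dbConnector)  # open DatabaseConnector
#   page = provider.browseTable("featureTodoItems",
#                               featureInstanceId="fi-123",
#                               mandateId="m-456", limit=20)
#   page["rows"], page["total"]  # scoped, paginated, JSON-serializable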

File diff suppressed because it is too large

View file

@ -0,0 +1,108 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Sandboxed code execution for the AI agent executeCode tool."""
import logging
import sys
import io
import builtins
import traceback
from typing import Dict, Any
logger = logging.getLogger(__name__)
_PYTHON_ALLOWED_MODULES = {
"math", "statistics", "json", "csv", "re", "datetime",
"collections", "itertools", "functools", "decimal", "fractions",
"random", "string", "textwrap", "operator", "copy",
}
_PYTHON_BLOCKED_BUILTINS = {
"open", "exec", "eval", "compile", "__import__", "globals", "locals",
"getattr", "setattr", "delattr", "breakpoint", "exit", "quit",
"input", "memoryview", "type",
}
_MAX_EXECUTION_TIME_S = 30
_MAX_OUTPUT_CHARS = 50000
def _safeImport(name, *args, **kwargs):
"""Restricted import that only allows whitelisted modules."""
if name not in _PYTHON_ALLOWED_MODULES:
raise ImportError(f"Module '{name}' is not allowed. Permitted: {', '.join(sorted(_PYTHON_ALLOWED_MODULES))}")
    return builtins.__import__(name, *args, **kwargs)
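# Example: _safeImport("math") returns the math module, while
# _safeImport("os") raises ImportError listing the permitted modules.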
def _buildRestrictedGlobals() -> Dict[str, Any]:
"""Build a restricted globals dict for exec()."""
safeBuiltins = {}
for name in dir(builtins):
if name.startswith("_"):
continue
if name in _PYTHON_BLOCKED_BUILTINS:
continue
safeBuiltins[name] = getattr(builtins, name)
safeBuiltins["__import__"] = _safeImport
safeBuiltins["__name__"] = "__sandbox__"
safeBuiltins["__builtins__"] = safeBuiltins
for modName in _PYTHON_ALLOWED_MODULES:
try:
safeBuiltins[modName] = __import__(modName)
except ImportError:
pass
return {"__builtins__": safeBuiltins}
async def executePython(code: str) -> Dict[str, Any]:
"""Execute Python code in a restricted sandbox. Returns {success, output, error}."""
import asyncio
def _run():
restrictedGlobals = _buildRestrictedGlobals()
capturedOutput = io.StringIO()
oldStdout = sys.stdout
oldStderr = sys.stderr
try:
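            # Redirecting sys.stdout/sys.stderr is process-wide, not thread-local,
            # so concurrent sandbox runs would interleave their captured output.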
sys.stdout = capturedOutput
sys.stderr = capturedOutput
            # NOTE: SIGALRM cannot be used here: signal.signal() raises ValueError
            # outside the main thread, and _run executes in an executor worker
            # thread. The wall-clock limit is enforced by the asyncio.wait_for()
            # wrapper around run_in_executor() below.
            exec(compile(code, "<sandbox>", "exec"), restrictedGlobals)
output = capturedOutput.getvalue()
if len(output) > _MAX_OUTPUT_CHARS:
output = output[:_MAX_OUTPUT_CHARS] + f"\n... (truncated at {_MAX_OUTPUT_CHARS} chars)"
return {"success": True, "output": output}
except Exception as e:
tb = traceback.format_exc()
return {"success": False, "error": f"{type(e).__name__}: {e}", "traceback": tb}
finally:
sys.stdout = oldStdout
sys.stderr = oldStderr
    loop = asyncio.get_running_loop()
try:
result = await asyncio.wait_for(
loop.run_in_executor(None, _run),
timeout=_MAX_EXECUTION_TIME_S + 5,
)
return result
except asyncio.TimeoutError:
return {"success": False, "error": f"Execution timed out after {_MAX_EXECUTION_TIME_S}s"}

View file

@ -0,0 +1,154 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Tool registry for the Agent service. Manages tool definitions and dispatch."""
import logging
import time
from typing import Dict, List, Any, Optional, Callable, Awaitable
from modules.serviceCenter.services.serviceAgent.datamodelAgent import (
ToolDefinition, ToolCallRequest, ToolResult
)
logger = logging.getLogger(__name__)
class ToolRegistry:
"""Registry for agent tools. Handles registration, lookup, and dispatch."""
def __init__(self):
self._tools: Dict[str, ToolDefinition] = {}
self._handlers: Dict[str, Callable[..., Awaitable[ToolResult]]] = {}
    def register(self, name: str, handler: Callable[..., Awaitable[ToolResult]],
                 description: str = "", parameters: Optional[Dict[str, Any]] = None,
                 readOnly: bool = False, featureType: Optional[str] = None,
                 toolSet: Optional[str] = None):
"""Register a tool with its handler function."""
if name in self._tools:
logger.warning(f"Tool '{name}' already registered, overwriting")
self._tools[name] = ToolDefinition(
name=name,
description=description,
parameters=parameters or {},
readOnly=readOnly,
featureType=featureType,
toolSet=toolSet,
)
self._handlers[name] = handler
logger.debug(f"Registered tool: {name} (readOnly={readOnly}, toolSet={toolSet})")
def registerFromDefinition(self, definition: ToolDefinition,
handler: Callable[..., Awaitable[ToolResult]]):
"""Register a tool from a pre-built ToolDefinition."""
self._tools[definition.name] = definition
self._handlers[definition.name] = handler
logger.debug(f"Registered tool: {definition.name} (readOnly={definition.readOnly})")
def unregister(self, name: str):
"""Remove a tool from the registry."""
self._tools.pop(name, None)
self._handlers.pop(name, None)
    def getTools(self, toolSet: Optional[str] = None, featureType: Optional[str] = None) -> List[ToolDefinition]:
"""Get available tools, optionally filtered by toolSet or featureType."""
tools = list(self._tools.values())
if featureType:
tools = [t for t in tools if t.featureType is None or t.featureType == featureType]
if toolSet:
tools = [t for t in tools if t.toolSet is None or t.toolSet == toolSet]
return tools
def getToolNames(self) -> List[str]:
"""Get names of all registered tools."""
return list(self._tools.keys())
def getTool(self, name: str) -> Optional[ToolDefinition]:
"""Get a single tool definition by name."""
return self._tools.get(name)
def isReadOnly(self, name: str) -> bool:
"""Check if a tool is marked as readOnly."""
tool = self._tools.get(name)
return tool.readOnly if tool else False
def isValidTool(self, name: str) -> bool:
"""Check if a tool name is valid (registered)."""
return name in self._tools
    async def dispatch(self, toolCall: ToolCallRequest, context: Optional[Dict[str, Any]] = None) -> ToolResult:
"""Execute a tool call and return the result."""
startTime = time.time()
if not self.isValidTool(toolCall.name):
return ToolResult(
toolCallId=toolCall.id,
toolName=toolCall.name,
success=False,
error=f"Unknown tool: '{toolCall.name}'. Available: {', '.join(self.getToolNames())}"
)
handler = self._handlers[toolCall.name]
argsSummary = ", ".join(f"{k}={str(v)[:80]}" for k, v in (toolCall.args or {}).items())
logger.info(f"Tool dispatch: {toolCall.name}({argsSummary})")
try:
result = await handler(toolCall.args, context or {})
durationMs = int((time.time() - startTime) * 1000)
if isinstance(result, ToolResult):
result.toolCallId = toolCall.id
result.durationMs = durationMs
dataSummary = (result.data[:200] + "...") if result.data and len(result.data) > 200 else (result.data or "")
if result.success:
logger.info(f"Tool result: {toolCall.name} OK ({durationMs}ms) → {dataSummary}")
else:
logger.warning(f"Tool result: {toolCall.name} FAILED ({durationMs}ms) → {result.error}")
return result
return ToolResult(
toolCallId=toolCall.id,
toolName=toolCall.name,
success=True,
data=str(result),
durationMs=durationMs
)
except Exception as e:
durationMs = int((time.time() - startTime) * 1000)
logger.error(f"Tool '{toolCall.name}' failed: {e}", exc_info=True)
return ToolResult(
toolCallId=toolCall.id,
toolName=toolCall.name,
success=False,
error=str(e),
durationMs=durationMs
)
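    # Example dispatch (sketch; field values are illustrative):
    #   call = ToolCallRequest(id="tc-1", name="listTables", args={})
    #   result = await registry.dispatch(call, context={"mandateId": "m-456"})
    #   result.success, result.durationMs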
def formatToolsForPrompt(self) -> str:
"""Format all tools as text for system prompt (text-based fallback)."""
parts = []
for tool in self._tools.values():
paramStr = ", ".join(
f"{k}: {v}" for k, v in tool.parameters.items()
) if tool.parameters else "none"
parts.append(f"- **{tool.name}**: {tool.description}\n Parameters: {{{paramStr}}}")
return "\n".join(parts)
def formatToolsForFunctionCalling(self) -> List[Dict[str, Any]]:
"""Format all tools as OpenAI-compatible function definitions for native function calling."""
functions = []
for tool in self._tools.values():
functions.append({
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
"parameters": tool.parameters if tool.parameters else {
"type": "object",
"properties": {},
"required": []
}
}
})
return functions
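    # Output sketch (parameter schemas come from registration; shape varies):
    #   [{"type": "function",
    #     "function": {"name": "listTables",
    #                  "description": "List feature tables",
    #                  "parameters": {"type": "object", "properties": {}, "required": []}}}]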

Some files were not shown because too many files have changed in this diff