commit bb0941ffa4
233 changed files with 15586 additions and 41731 deletions
@@ -39,12 +39,11 @@ Service_GOOGLE_REDIRECT_URI = http://localhost:8000/api/google/auth/callback
STRIPE_SECRET_KEY_SECRET = DEV_ENC:Z0FBQUFBQnBudkpGWDkxSldfM0NCZ3dmbHY5cS1nQlI3UWZ4ZWRrNVdUdEFKa25RckRiQWY0c1E5MjVsZzlfRkZEU0VFU2tNQ01qZnRNQ0pZVU9hVFN6OEU0RXhwdTl3algzLWJlSXRhYmZlMHltSC1XejlGWEU5TDF1LUlYNEh1aG9tRFI4YmlCYzUyei02U1dabWoyb0N2dVFSb1RhWTNnQjBCZkFjV0FfOWdYdDVpX1k5R2pYM1R6SHRiaE10V1l1dnQybjVHWDRiQUJLM0UxRDZnczhJZGFsc3JhOU82QT09
STRIPE_WEBHOOK_SECRET = DEV_ENC:Z0FBQUFBQnBudkpGcHNWTWpBWkFHRExtdU01N3RyZzNsMjhUS3NiVTNCZmMwN2NEcFZ6UkQ1a2I0aUkyNU4wR2dUdHJXYmtkaEFRUnFpcThObHBEQmJkdEFnT1FXeUxOTlU3UDFNRzl6LWdpRFpYdExvY3FTTG9MTkswdEhrVkNKQVFucnBjSnhLNm4=
STRIPE_API_VERSION = 2026-01-28.clover
APP_FRONTEND_URL = http://localhost:5176
# AI configuration
Connector_AiOpenai_API_SECRET = DEV_ENC:Z0FBQUFBQnBaSnM4TWFRRmxVQmNQblVIYmc1Y0Q3aW9zZUtDWlNWdGZjbFpncGp2NHN2QjkxMWxibUJnZDBId252MWk5TXN3Yk14ajFIdi1CTkx2ZWx2QzF5OFR6LUx5azQ3dnNLaXJBOHNxc0tlWmtZcTFVelF4eXBSM2JkbHd2eTM0VHNXdHNtVUprZWtPVzctNlJsZHNmM20tU1N6Q1Q2cHFYSi1tNlhZNDNabTVuaEVGWmIydEhadTcyMlBURmw2aUJxOF9GTzR0dTZiNGZfOFlHaVpPZ1A1LXhhOEFtN1J5TEVNNWtMcGpyNkMzSl8xRnZsaTF1WTZrOUZmb0cxVURjSGFLS2dIYTQyZEJtTm90bEYxVWxNNXVPdTVjaVhYbXhxT3JsVDM5VjZMVFZKSE1tZnM9
Connector_AiAnthropic_API_SECRET = DEV_ENC:Z0FBQUFBQm8xSUpENmFBWG16STFQUVZxNzZZRzRLYTA4X3lRanF1VkF4cU45OExNMzlsQmdISGFxTUxud1dXODBKcFhMVG9KNjdWVnlTTFFROVc3NDlsdlNHLUJXeG41NDBHaXhHR0VHVWl5UW9RNkVWbmlhakRKVW5pM0R4VHk0LUw0TV9LdkljNHdBLXJua21NQkl2b3l4UkVkMGN1YjBrMmJEeWtMay1jbmxrYWJNbUV0aktCXzU1djR2d2RSQXZORTNwcG92ZUVvVGMtQzQzTTVncEZTRGRtZUFIZWQ0dz09
Connector_AiPerplexity_API_SECRET = DEV_ENC:Z0FBQUFBQm82Mzk2Q1MwZ0dNcUVBcUtuRDJIcTZkMXVvYnpjM3JEMzJiT1NKSHljX282ZDIyZTJYc09VSTdVNXAtOWU2UXp5S193NTk5dHJsWlFjRjhWektFOG1DVGY4ZUhHTXMzS0RPN1lNcF9nSlVWbW5BZ1hkZDVTejl6bVZNRFVvX29xamJidWRFMmtjQmkyRUQ2RUh6UTN1aWNPSUJBPT0=
Connector_AiPerplexity_API_SECRET = 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
Connector_AiTavily_API_SECRET = DEV_ENC:Z0FBQUFBQm8xSUpEQTdnUHMwd2pIaXNtMmtCTFREd0pyQXRKb1F5eGtHSnkyOGZiUnlBOFc0b3Vzcndrc3ViRm1nMDJIOEZKYWxqdWNkZGh5N0Z4R0JlQmxXSG5pVnJUR2VYckZhMWNMZ1FNeXJ3enJLVlpiblhOZTNleUg3ZzZyUzRZanFSeDlVMkI=
Connector_AiPrivateLlm_API_SECRET = DEV_ENC:Z0FBQUFBQnBudkpGRHM5eFdUVmVZU1R1cHBwN1RlMUx4T0NlLTJLUFFVX3J2OElDWFpuZmJHVmp4Z3BNNWMwZUVVZUd2TFhRSjVmVkVlcFlVRWtybXh0ZHloZ01ZcnVvX195YjdlWVdEcjZSWFFTTlNBWUlaTlNoLWhqVFBIb0thVlBiaWhjYjFQOFY=
Connector_AiMistral_API_SECRET = DEV_ENC:Z0FBQUFBQnBudkpGeEQxYUIxOHhia0JlQWpWQ2dWQWZzY3l6SWwyUnJoR1hRQWloX2lxb2lGNkc4UnA4U2tWNjJaYzB1d1hvNG9fWUp1N3V4OW9FMGhaWVhjSlVwWEc1X2loVDBSZDEtdHdfcTA5QkcxQTR4OHc4RkRzclJrU2d1RFZpNDJkRDRURlE=
@@ -36,15 +36,14 @@ Service_MSFT_REDIRECT_URI = https://gateway-int.poweron-center.net/api/msft/auth
Service_GOOGLE_REDIRECT_URI = https://gateway-int.poweron-center.net/api/google/auth/callback
# Stripe Billing (both end with _SECRET for encryption script)
STRIPE_SECRET_KEY_SECRET = INT_ENC:Z0FBQUFBQnBudkpGOTVvaGhuSTVRTW5Fa0x3akFwZktQX21DZnFGQWgteDJwUWpnbTV5enRmbmZyWXlpY2lKVVNINkFManhNMnFQZ1VnOUxSQ0FTZFBVNmdhSGhwWFBHaDNnSzZXVnIxRmNUcnBQN0c3R19Xb2g1QnBxVXpiSXRRTk5NOGtzcU5HcUNiWDNvdmhYbGFkWkRCR25iVEJKTmwzcGRBZjNjaVNiWDJDaWlhLWpfdkdXYlQyUWk2NndKYW5lYXBaTkRzMWZsZjlFb3JOX1NzbkM4NWFyQU9MajZlZz09
STRIPE_SECRET_KEY_SECRET = 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
STRIPE_WEBHOOK_SECRET = INT_ENC:Z0FBQUFBQnBudkpGamJBNW91VUdEaThWRTFiTWpyb3NqSDJJcGtjNkhUVVZqVElxUWExY05KcllSYVk1SkRuS1NjYWpZUk1uU29nb2pzdXUxRzBsOEgyRWtmUEw3dUF4ejFIXzNwTVZRM1R1bVVhTUs4ZHJMT0V4Xy1pcHVfWlBaQV9wVXo5MGlQYXA=
STRIPE_API_VERSION = 2026-01-28.clover
APP_FRONTEND_URL = https://nyla-int.poweron-center.net
# AI configuration
Connector_AiOpenai_API_SECRET = INT_ENC:Z0FBQUFBQnBaSnM4MENkQ2xJVmE5WFZKUkh2SHJFby1YVXN3ZmVxRkptS3ZWRmlwdU93ZEJjSjlMV2NGbU5mS3NCdmFfcmFYTEJNZXFIQ3ozTWE4ZC1pemlQNk9wbjU1d3BPS0ZCTTZfOF8yWmVXMWx0TU1DamlJLVFhSTJXclZsY3hMVWlPcXVqQWtMdER4T252NHZUWEhUOTdIN1VGR3ltazEweXFqQ0lvb0hYWmxQQnpxb0JwcFNhRDNGWXdoRTVJWm9FalZpTUF5b1RqZlRaYnVKYkp0NWR5Vko1WWJ0Wmg2VWJzYXZ0Z3Q4UkpsTldDX2dsekhKMmM4YjRoa2RwemMwYVQwM2cyMFlvaU5mOTVTWGlROU8xY2ZVRXlxZzJqWkxURWlGZGI2STZNb0NpdEtWUnM9
Connector_AiAnthropic_API_SECRET = INT_ENC:Z0FBQUFBQm8xSVRjT1ZlRWVJdVZMT3ljSFJDcFdxRFBRVkZhS204NnN5RDBlQ0tpenhTM0FFVktuWW9mWHNwRWx2dHB0eDBSZ0JFQnZKWlp6c01pVGREWHd1eGpERnU0Q2xhaks1clQ1ZXVsdnd2ZzhpNXNQS1BhY3FjSkdkVEhHalNaRGR4emhpakZncnpDQUVxOHVXQzVUWmtQc0FsYmFwTF9TSG5FOUFtWk5Ick1NcHFvY2s1T1c2WXlRUFFJZnh6TWhuaVpMYmppcDR0QUx0a0R6RXlwbGRYb1R4dzJkUT09
Connector_AiPerplexity_API_SECRET = INT_ENC:Z0FBQUFBQm82Mzk2UWZJdUFhSW8yc3RKc0tKRXphd0xWMkZOVlFpSGZ4SGhFWnk0cTF5VjlKQVZjdS1QSWdkS0pUSWw4OFU5MjUxdTVQel9aeWVIZTZ5TXRuVmFkZG0zWEdTOGdHMHpsTzI0TGlWYURKU1Q0VVpKTlhxUk5FTmN6SUJScDZ3ZldIaUJZcWpaQVRiSEpyQm9tRTNDWk9KTnZBPT0=
Connector_AiPerplexity_API_SECRET = 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
Connector_AiTavily_API_SECRET = INT_ENC:Z0FBQUFBQm8xSVRkdkJMTDY0akhXNzZDWHVYSEt1cDZoOWEzSktneHZEV2JndTNmWlNSMV9KbFNIZmQzeVlrNE5qUEIwcUlBSGM1a0hOZ3J6djIyOVhnZzI3M1dIUkdicl9FVXF3RGktMmlEYmhnaHJfWTdGUkktSXVUSGdQMC1vSEV6VE8zR2F1SVk=
Connector_AiPrivateLlm_API_SECRET = INT_ENC:Z0FBQUFBQnBudkpGSjZ1NWh0aWc1R3Z4MHNaeS1HamtUbndhcUZFZDlqUDhjSmg5eHFfdlVkU0RsVkJ2UVRaMWs3aWhraG5jSlc0YkxNWHVmR2JoSW5ENFFCdkJBM0VienlKSnhzNnBKbTJOUTFKczRfWlQ3bWpmUkRTT1I1OGNUSTlQdExacGRpeXg=
Connector_AiMistral_API_SECRET = INT_ENC:Z0FBQUFBQnBudkpGZTNtZ1E4TWIxSEU1OUlreUpxZkJIR0Vxcm9xRHRUbnBxbTQ1cXlkbnltWkJVdTdMYWZ4c3Fsam42TERWUTVhNzZFMU9xVjdyRGFCYml6bmZsZFd2YmJzemlrSWN6Q3o3X0NXX2xXNUQteTNONHdKYzJ5YVpLLWdhU2JhSTJQZnI=
@@ -36,15 +36,14 @@ Service_MSFT_REDIRECT_URI = https://gateway-prod.poweron-center.net/api/msft/aut
Service_GOOGLE_REDIRECT_URI = https://gateway-prod.poweron-center.net/api/google/auth/callback
# Stripe Billing (both end with _SECRET for encryption script)
STRIPE_SECRET_KEY_SECRET = PROD_ENC:Z0FBQUFBQnBudkpGNmx3N3Q4QWdlcXVBSlJuYzVJX2hZRW8wbklJYUI0Rzh2YWRPcWY5dC1rMjhfUDRTOE91TlZyLTBEZkY1N015dmg5akEta1d0M0NpNk9oNDZpQTlMUGlLalV6aVowbl9Jc2hKMVlxbE9aaTZNRUxDQ3VGSnJxN040VERUMDFiekhITXdTR0N4aUxwWGxtcHdlU2NtOVNsSlVpOE0xTkRSdGhnN09UWGxuLURUaFdfQWJ4ZEw3R0c0bVRQaTA1NURhVEZudHY4d2gtTzItOF9TcmMwajFmZz09
STRIPE_SECRET_KEY_SECRET = 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
STRIPE_WEBHOOK_SECRET = PROD_ENC:Z0FBQUFBQnBudkpGNUpTWldsakYydFhFelBrR1lSaWxYT3kyMENOMUljZTJUZHBWcEhhdWVCMzYxZXQ5b3VlTFVRalFiTVdsbGxrdUx0RDFwSEpsOC1sTDJRTEJNQlA3S3ZaQzBtV1h6bWp5VnlMZUgwUlF3cXYxcnljZVE5SWdzLVg3V0syOWRYS08=
STRIPE_API_VERSION = 2026-01-28.clover
APP_FRONTEND_URL = https://nyla.poweron-center.net
# AI configuration
Connector_AiOpenai_API_SECRET = PROD_ENC:Z0FBQUFBQnBaSnM4TWJOVm4xVkx6azRlNDdxN3UxLUdwY2hhdGYxRGp4VFJqYXZIcmkxM1ZyOWV2M0Z4MHdFNkVYQ0ROb1d6LUZFUEdvMHhLMEtXYVBCRzM5TlYyY3ROYWtJRk41cDZxd0tYYi00MjVqMTh4QVcyTXl0bmVocEFHbXQwREpwNi1vODdBNmwzazE5bkpNelE2WXpvblIzWlQwbGdEelI2WXFqT1RibXVHcjNWbVhwYzBOM25XTzNmTDAwUjRvYk4yNjIyZHc5c2RSZzREQUFCdUwyb0ZuOXN1dzI2c2FKdXI4NGxEbk92czZWamJXU3ZSbUlLejZjRklRRk4tLV9aVUFZekI2bTU4OHYxNTUybDg3RVo0ZTh6dXNKRW5GNXVackZvcm9laGI0X3R6V3M9
Connector_AiAnthropic_API_SECRET = PROD_ENC:Z0FBQUFBQnBDM1Z3TnhYdlhSLW5RbXJyMHFXX0V0bHhuTDlTaFJsRDl2dTdIUTFtVFAwTE8tY3hLbzNSMnVTLXd3RUZualN3MGNzc1kwOTIxVUN2WW1rYi1TendFRVVBSVNqRFVjckEzNExyTGNaUkJLMmozazUwemI1cnhrcEtZVXJrWkdaVFFramp3MWZ6RmY2aGlRMXVEYjM2M3ZlbmxMdnNCRDM1QWR0Wmd6MWVnS1I1c01nV3hRLXg3d2NTZXVfTi1Wdm16UnRyNGsyRTZ0bG9TQ1g1OFB5Z002bmQ3QT09
Connector_AiPerplexity_API_SECRET = PROD_ENC:Z0FBQUFBQm82Mzk2Q1FGRkJEUkI4LXlQbHYzT2RkdVJEcmM4WGdZTWpJTEhoeUF1NW5LUVpJdDBYN3k1WFN4a2FQSWJSQmd0U0xJbzZDTmFFN05FcXl0Z3V1OEpsZjYydV94TXVjVjVXRTRYSWdLMkd5XzZIbFV6emRCZHpuOUpQeThadE5xcDNDVGV1RHJrUEN0c1BBYXctZFNWcFRuVXhRPT0=
Connector_AiPerplexity_API_SECRET = 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
Connector_AiTavily_API_SECRET = PROD_ENC:Z0FBQUFBQnBDM1Z3NmItcDh6V0JpcE5Jc0NlUWZqcmllRHB5eDlNZmVnUlNVenhNTm5xWExzbjJqdE1GZ0hTSUYtb2dvdWNhTnlQNmVWQ2NGVDgwZ0MwMWZBMlNKWEhzdlF3TlZzTXhCZWM4Z1Uwb18tSTRoU1JBVTVkSkJHOTJwX291b3dPaVphVFg=
Connector_AiPrivateLlm_API_SECRET = PROD_ENC:Z0FBQUFBQnBudkpGanZ6U3pzZWkwXzVPWGtIQ040XzFrTXc5QWRnazdEeEktaUJ0akJmNnEzbWUzNHczLTJfc2dIdzBDY0FTaXZYcDhxNFdNbTNtbEJTb2VRZ0ZYd05hdlNLR1h6SUFzVml2Z1FLY1BjTl90UWozUGxtak1URnhhZmNDRWFTb0dKVUo=
Connector_AiMistral_API_SECRET = PROD_ENC:Z0FBQUFBQnBudkpGc2tQc2lvMk1YZk01Q1dob1U5cnR0dG03WWE3WkpoOWo0SEpvLU9Rc2lCNDExdy1wZExaN3lpT2FEQkxnaHRmWmZUUUZUUUJmblZreGlpaFpOdnFhbzlEd1RsVVJtX216cmhxTm5BcTN2eUZ2T054cDE5bmlEamJ3NGR6MVpFQnA=
@@ -11,9 +11,34 @@ IMPORTANT: Model Registration Requirements
- If duplicate displayNames are detected during registration, an error will be raised
"""

import re as _re

from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional
from modules.datamodels.datamodelAi import AiModel
from typing import List, Dict, Any, Optional, AsyncGenerator, Union
from modules.datamodels.datamodelAi import AiModel, AiModelCall, AiModelResponse


_RETRY_AFTER_PATTERN = _re.compile(r"try again in (\d+(?:\.\d+)?)\s*s", _re.IGNORECASE)


def _parseRetryAfterSeconds(message: str) -> float:
    """Extract retry-after seconds from provider error messages like 'Please try again in 6.558s'."""
    match = _RETRY_AFTER_PATTERN.search(message)
    return float(match.group(1)) if match else 0.0


class RateLimitExceededException(Exception):
    """Raised when a provider's rate limit (TPM / RPM) is exceeded."""
    def __init__(self, message: str = "Rate limit exceeded", retryAfterSeconds: float = 0.0):
        super().__init__(message)
        if retryAfterSeconds <= 0:
            retryAfterSeconds = _parseRetryAfterSeconds(message)
        self.retryAfterSeconds = retryAfterSeconds


class ContextLengthExceededException(Exception):
    """Raised when the input exceeds a model's context window."""
    pass


class BaseConnectorAi(ABC):
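The cooldown hint travels inside the exception itself; a minimal sketch of the parsing path above (the message text is illustrative, not a captured provider response):

    # Hypothetical provider message; the pattern above captures the seconds value.
    exc = RateLimitExceededException("429: Please try again in 6.558s.")
    assert exc.retryAfterSeconds == 6.558  # parsed because no explicit value was passed
    exc = RateLimitExceededException("rate limited", retryAfterSeconds=30.0)
    assert exc.retryAfterSeconds == 30.0   # an explicit value wins over parsing
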
@@ -102,3 +127,24 @@ class BaseConnectorAi(ABC):
        """Get only available models."""
        models = self.getCachedModels()
        return [model for model in models if model.isAvailable]

    async def callAiBasicStream(self, modelCall: AiModelCall) -> AsyncGenerator[Union[str, AiModelResponse], None]:
        """Stream AI response. Yields str deltas during generation, then final AiModelResponse.

        Default implementation: falls back to non-streaming callAiBasic.
        Override in connectors that support streaming.
        """
        response = await self.callAiBasic(modelCall)
        if response.content:
            yield response.content
        yield response

    async def callEmbedding(self, modelCall: AiModelCall) -> AiModelResponse:
        """Generate embeddings for input texts. Override in connectors that support embeddings.

        Reads texts from modelCall.embeddingInput.
        Returns AiModelResponse with metadata["embeddings"] containing the vectors.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} does not support embeddings"
        )
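A sketch of how a caller can consume the streaming contract above (connector and modelCall are assumed to exist; only the str-then-AiModelResponse protocol comes from the code itself):

    async def collectStream(connector, modelCall):
        # Deltas arrive as str; the final yielded item is the full AiModelResponse.
        finalResponse = None
        async for item in connector.callAiBasicStream(modelCall):
            if isinstance(item, str):
                print(item, end="")   # forward the delta, e.g. to an SSE client
            else:
                finalResponse = item  # closing AiModelResponse with metadata
        return finalResponse
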
@@ -19,26 +19,29 @@ class ModelSelector:
    """Model selector with priority scoring and recent-failure cooldown."""

    def __init__(self):
        self._failureLog: Dict[str, float] = {}
        self._failureLog: Dict[str, Tuple[float, float]] = {}
        logger.info("ModelSelector initialized with failure cooldown support")

    def reportFailure(self, modelName: str):
    def reportFailure(self, modelName: str, cooldownSeconds: float = 0.0):
        """Record that a model just failed (rate limit, error, etc.).
        The model will be deprioritized for COOLDOWN_DURATION seconds."""
        self._failureLog[modelName] = time.time()
        logger.info(f"ModelSelector: Recorded failure for {modelName}, cooldown {_COOLDOWN_DURATION}s")
        The model will be deprioritized for *cooldownSeconds* (default: _COOLDOWN_DURATION)."""
        if cooldownSeconds <= 0:
            cooldownSeconds = _COOLDOWN_DURATION
        self._failureLog[modelName] = (time.time(), cooldownSeconds)
        logger.info(f"ModelSelector: Recorded failure for {modelName}, cooldown {cooldownSeconds:.1f}s")

    def _getCooldownPenalty(self, modelName: str) -> float:
        """Return a score penalty (0.0 = no penalty, large negative = recently failed)."""
        failedAt = self._failureLog.get(modelName)
        if failedAt is None:
        entry = self._failureLog.get(modelName)
        if entry is None:
            return 0.0
        failedAt, cooldown = entry
        elapsed = time.time() - failedAt
        if elapsed > _COOLDOWN_DURATION:
        if elapsed > cooldown:
            del self._failureLog[modelName]
            return 0.0
        remaining = _COOLDOWN_DURATION - elapsed
        return -(remaining / _COOLDOWN_DURATION) * 5000.0
        remaining = cooldown - elapsed
        return -(remaining / cooldown) * 5000.0

    def selectModel(self,
                    prompt: str,
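A worked pass through the penalty math above, assuming the module-level _COOLDOWN_DURATION default and the 5000-point scale shown:

    selector = ModelSelector()
    selector.reportFailure("claude-haiku-4-5", cooldownSeconds=6.6)  # e.g. from retryAfterSeconds
    # Immediately afterwards: remaining ~= 6.6, so the penalty is -(6.6 / 6.6) * 5000.0 = -5000.0.
    # Halfway through the cooldown it has decayed linearly to about -2500; once 6.6s have
    # elapsed it is 0.0 and the entry is dropped from _failureLog on the next lookup.
    penalty = selector._getCooldownPenalty("claude-haiku-4-5")
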
@@ -1,12 +1,13 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
import json
import logging
import httpx
import os
from typing import Dict, Any, List
from typing import Dict, Any, List, AsyncGenerator, Optional, Union
from fastapi import HTTPException
from modules.shared.configuration import APP_CONFIG
from .aicoreBase import BaseConnectorAi
from .aicoreBase import BaseConnectorAi, RateLimitExceededException
from modules.datamodels.datamodelAi import AiModel, PriorityEnum, ProcessingModeEnum, OperationTypeEnum, AiModelCall, AiModelResponse, createOperationTypeRatings

# Configure logger
@@ -61,13 +62,15 @@ class AiAnthropic(BaseConnectorAi):
                speedRating=6, # Slower due to high-quality processing
                qualityRating=10, # Best quality available
                functionCall=self.callAiBasic,
                functionCallStream=self.callAiBasicStream,
                priority=PriorityEnum.QUALITY,
                processingMode=ProcessingModeEnum.DETAILED,
                operationTypes=createOperationTypeRatings(
                    (OperationTypeEnum.PLAN, 9),
                    (OperationTypeEnum.DATA_ANALYSE, 9),
                    (OperationTypeEnum.DATA_GENERATE, 9),
                    (OperationTypeEnum.DATA_EXTRACT, 8)
                    (OperationTypeEnum.DATA_EXTRACT, 8),
                    (OperationTypeEnum.AGENT, 9),
                ),
                version="claude-sonnet-4-5-20250929",
                calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.003 + (bytesReceived / 4 / 1000) * 0.015
@@ -85,13 +88,15 @@
                speedRating=9, # Very fast, lightweight model
                qualityRating=8, # Good quality, cost-efficient
                functionCall=self.callAiBasic,
                functionCallStream=self.callAiBasicStream,
                priority=PriorityEnum.SPEED,
                processingMode=ProcessingModeEnum.BASIC,
                operationTypes=createOperationTypeRatings(
                    (OperationTypeEnum.PLAN, 8),
                    (OperationTypeEnum.DATA_ANALYSE, 8),
                    (OperationTypeEnum.DATA_GENERATE, 8),
                    (OperationTypeEnum.DATA_EXTRACT, 7)
                    (OperationTypeEnum.DATA_EXTRACT, 7),
                    (OperationTypeEnum.AGENT, 7),
                ),
                version="claude-haiku-4-5-20251001",
                calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.001 + (bytesReceived / 4 / 1000) * 0.005
@@ -109,13 +114,15 @@
                speedRating=5, # Moderate latency, most capable
                qualityRating=10, # Top-tier intelligence
                functionCall=self.callAiBasic,
                functionCallStream=self.callAiBasicStream,
                priority=PriorityEnum.QUALITY,
                processingMode=ProcessingModeEnum.DETAILED,
                operationTypes=createOperationTypeRatings(
                    (OperationTypeEnum.PLAN, 10),
                    (OperationTypeEnum.DATA_ANALYSE, 10),
                    (OperationTypeEnum.DATA_ANALYSE, 8),
                    (OperationTypeEnum.DATA_GENERATE, 10),
                    (OperationTypeEnum.DATA_EXTRACT, 9)
                    (OperationTypeEnum.DATA_EXTRACT, 9),
                    (OperationTypeEnum.AGENT, 10),
                ),
                version="claude-opus-4-6",
                calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.005 + (bytesReceived / 4 / 1000) * 0.025
@@ -158,8 +165,6 @@
            HTTPException: For errors in API communication
        """
        try:
            # Extract parameters from modelCall
            messages = modelCall.messages
            model = modelCall.model
            options = modelCall.options
            temperature = getattr(options, "temperature", None)
@@ -167,44 +172,8 @@
                temperature = model.temperature
            maxTokens = model.maxTokens

            # Transform OpenAI-style messages to Anthropic format:
            # - Move any 'system' role content to top-level 'system'
            # - Keep only 'user'/'assistant' messages in the list
            system_contents: List[str] = []
            converted_messages: List[Dict[str, Any]] = []
            for m in messages:
                role = m.get("role")
                content = m.get("content", "")
                if role == "system":
                    # Collect system content; Anthropic expects top-level 'system'
                    if isinstance(content, list):
                        # Join text parts if provided as blocks
                        joined = "\n\n".join(
                            [
                                (part.get("text") if isinstance(part, dict) else str(part))
                                for part in content
                            ]
                        )
                        system_contents.append(joined)
                    else:
                        system_contents.append(str(content))
                    continue
                # For Anthropic, content can be a string; pass through strings, collapse blocks
                if isinstance(content, list):
                    # Collapse to text if blocks are provided
                    collapsed = "\n\n".join(
                        [
                            (part.get("text") if isinstance(part, dict) else str(part))
                            for part in content
                        ]
                    )
                    converted_messages.append({"role": role, "content": collapsed})
                else:
                    converted_messages.append({"role": role, "content": content})
            converted_messages, system_prompt = _convertMessagesForAnthropic(modelCall.messages)

            system_prompt = "\n\n".join([s for s in system_contents if s]) if system_contents else None

            # Create Anthropic API payload
            payload: Dict[str, Any] = {
                "model": model.name,
                "messages": converted_messages,
@@ -218,6 +187,13 @@
            if system_prompt:
                payload["system"] = system_prompt

            if modelCall.tools:
                payload["tools"] = _convertToolsToAnthropicFormat(modelCall.tools)
                if modelCall.toolChoice:
                    payload["tool_choice"] = modelCall.toolChoice
                else:
                    payload["tool_choice"] = {"type": "auto"}

            response = await self.httpClient.post(
                model.apiUrl,
                json=payload
@@ -227,11 +203,12 @@
                error_detail = f"Anthropic API error: {response.status_code} - {response.text}"
                logger.error(error_detail)

                # Provide more specific error messages based on status code
                if response.status_code == 429:
                    raise RateLimitExceededException(
                        f"Rate limit exceeded for {model.name}: {response.text}"
                    )
                if response.status_code == 529:
                    error_message = "Anthropic API is currently overloaded. Please try again in a few minutes."
                elif response.status_code == 429:
                    error_message = "Rate limit exceeded. Please wait before making another request."
                elif response.status_code == 401:
                    error_message = "Invalid API key. Please check your Anthropic API configuration."
                elif response.status_code == 400:
@@ -244,31 +221,43 @@
            # Parse response
            anthropicResponse = response.json()

            # Extract content from response
            # Extract content and tool_use blocks from response
            content = ""
            toolCalls = []
            if "content" in anthropicResponse:
                if isinstance(anthropicResponse["content"], list):
                    # Content is a list of parts (in newer API versions)
                    for part in anthropicResponse["content"]:
                        if part.get("type") == "text":
                            content += part.get("text", "")
                        elif part.get("type") == "tool_use":
                            toolCalls.append({
                                "id": part.get("id", ""),
                                "type": "function",
                                "function": {
                                    "name": part.get("name", ""),
                                    "arguments": json.dumps(part.get("input", {})) if isinstance(part.get("input"), dict) else str(part.get("input", "{}"))
                                }
                            })
                else:
                    # Direct content as string (in older API versions)
                    content = anthropicResponse["content"]

            # Debug logging for empty responses
            if not content or content.strip() == "":
            if not content and not toolCalls:
                logger.warning(f"Anthropic API returned empty content. Full response: {anthropicResponse}")
                content = "[Anthropic API returned empty response]"

            # Return standardized response
            metadata = {"response_id": anthropicResponse.get("id", "")}
            if toolCalls:
                metadata["toolCalls"] = toolCalls

            return AiModelResponse(
                content=content,
                success=True,
                modelId=model.name,
                metadata={"response_id": anthropicResponse.get("id", "")}
                metadata=metadata
            )

        except (RateLimitExceededException, HTTPException):
            raise
        except Exception as e:
            error_msg = str(e) if str(e) else f"{type(e).__name__}"
            error_detail = f"Error calling Anthropic API: {error_msg}"
@@ -279,6 +268,128 @@
            logger.error(error_detail, exc_info=True)
            raise HTTPException(status_code=500, detail=error_detail)

    async def callAiBasicStream(self, modelCall: AiModelCall) -> AsyncGenerator[Union[str, AiModelResponse], None]:
        """Stream Anthropic response. Yields str deltas, then final AiModelResponse."""
        try:
            model = modelCall.model
            options = modelCall.options
            temperature = getattr(options, "temperature", None)
            if temperature is None:
                temperature = model.temperature

            converted, system_prompt = _convertMessagesForAnthropic(modelCall.messages)

            payload: Dict[str, Any] = {
                "model": model.name,
                "messages": converted,
                "temperature": temperature,
                "max_tokens": model.maxTokens,
                "stream": True,
            }
            if system_prompt:
                payload["system"] = system_prompt
            if modelCall.tools:
                payload["tools"] = _convertToolsToAnthropicFormat(modelCall.tools)
                payload["tool_choice"] = modelCall.toolChoice or {"type": "auto"}

            fullContent = ""
            toolUseBlocks: Dict[int, Dict[str, Any]] = {}
            currentToolIdx = -1
            stopReason: Optional[str] = None

            async with self.httpClient.stream("POST", model.apiUrl, json=payload) as response:
                if response.status_code != 200:
                    body = await response.aread()
                    bodyStr = body.decode()
                    if response.status_code == 429:
                        raise RateLimitExceededException(
                            f"Rate limit exceeded for {model.name}: {bodyStr}"
                        )
                    raise HTTPException(status_code=500, detail=f"Anthropic stream error: {response.status_code} - {bodyStr}")

                async for line in response.aiter_lines():
                    if not line.startswith("data: "):
                        continue
                    try:
                        event = json.loads(line[6:])
                    except json.JSONDecodeError:
                        continue

                    eventType = event.get("type", "")

                    if eventType == "error":
                        errDetail = event.get("error", {})
                        errMsg = errDetail.get("message", str(errDetail))
                        errType = errDetail.get("type", "unknown")
                        logger.error(f"Anthropic stream error event: type={errType}, message={errMsg}")
                        if "overloaded" in errMsg.lower() or "overloaded" in errType.lower():
                            raise HTTPException(status_code=500, detail=f"Anthropic API is currently overloaded. Please try again in a few minutes.")
                        raise HTTPException(status_code=500, detail=f"Anthropic stream error: [{errType}] {errMsg}")

                    elif eventType == "content_block_start":
                        block = event.get("content_block", {})
                        idx = event.get("index", 0)
                        if block.get("type") == "tool_use":
                            currentToolIdx = idx
                            toolUseBlocks[idx] = {
                                "id": block.get("id", ""),
                                "name": block.get("name", ""),
                                "arguments": "",
                            }

                    elif eventType == "content_block_delta":
                        delta = event.get("delta", {})
                        if delta.get("type") == "text_delta":
                            text = delta.get("text", "")
                            fullContent += text
                            yield text
                        elif delta.get("type") == "input_json_delta":
                            idx = event.get("index", currentToolIdx)
                            if idx in toolUseBlocks:
                                toolUseBlocks[idx]["arguments"] += delta.get("partial_json", "")

                    elif eventType == "message_delta":
                        delta = event.get("delta", {})
                        stopReason = delta.get("stop_reason", stopReason)

                    elif eventType == "message_stop":
                        break

            if not fullContent and not toolUseBlocks:
                logger.warning(
                    f"Anthropic stream returned empty response: model={model.name}, "
                    f"stopReason={stopReason}"
                )

            metadata: Dict[str, Any] = {}
            if stopReason:
                metadata["stopReason"] = stopReason
            if toolUseBlocks:
                metadata["toolCalls"] = [
                    {
                        "id": tb["id"],
                        "type": "function",
                        "function": {
                            "name": tb["name"],
                            "arguments": tb["arguments"],
                        },
                    }
                    for tb in toolUseBlocks.values()
                ]

            yield AiModelResponse(
                content=fullContent,
                success=True,
                modelId=model.name,
                metadata=metadata,
            )

        except (RateLimitExceededException, HTTPException):
            raise
        except Exception as e:
            logger.error(f"Error streaming Anthropic API: {e}", exc_info=True)
            raise HTTPException(status_code=500, detail=f"Error streaming Anthropic API: {e}")

    async def callAiImage(self, modelCall: AiModelCall) -> AiModelResponse:
        """
        Analyzes an image using Anthropic's vision capabilities using standardized pattern.
@@ -331,6 +442,20 @@
            mimeType = parts[0].replace("data:", "")
            base64Data = parts[1]

            import base64 as _b64
            try:
                rawHead = _b64.b64decode(base64Data[:32])
                if rawHead[:3] == b"\xff\xd8\xff":
                    mimeType = "image/jpeg"
                elif rawHead[:8] == b"\x89PNG\r\n\x1a\n":
                    mimeType = "image/png"
                elif rawHead[:4] == b"GIF8":
                    mimeType = "image/gif"
                elif rawHead[:4] == b"RIFF" and rawHead[8:12] == b"WEBP":
                    mimeType = "image/webp"
            except Exception:
                pass

            # Convert to Anthropic's vision format
            anthropicMessages = [{
                "role": "user",
@@ -425,3 +550,100 @@
                success=False,
                error=f"Error during image analysis: {str(e)}"
            )


def _convertMessagesForAnthropic(messages: List[Dict[str, Any]]):
    """Convert OpenAI-style messages to Anthropic format. Returns (messages, system_prompt)."""
    system_contents: List[str] = []
    converted_messages: List[Dict[str, Any]] = []
    pendingToolResults: List[Dict[str, Any]] = []

    def _flush():
        if not pendingToolResults:
            return
        converted_messages.append({"role": "user", "content": list(pendingToolResults)})
        pendingToolResults.clear()

    def _collapse(content):
        if isinstance(content, list):
            return "\n\n".join(
                (part.get("text") if isinstance(part, dict) else str(part))
                for part in content
            )
        return str(content) if content else ""

    for m in messages:
        role = m.get("role")
        content = m.get("content", "")

        if role == "system":
            system_contents.append(_collapse(content))
            continue
        if role == "tool":
            pendingToolResults.append({
                "type": "tool_result",
                "tool_use_id": m.get("tool_call_id", ""),
                "content": str(content) if content else "",
            })
            continue

        _flush()

        if role == "assistant" and m.get("tool_calls"):
            contentBlocks = []
            textPart = _collapse(content)
            if textPart:
                contentBlocks.append({"type": "text", "text": textPart})
            for tc in m["tool_calls"]:
                fn = tc.get("function", {})
                inputData = fn.get("arguments", "{}")
                if isinstance(inputData, str):
                    try:
                        inputData = json.loads(inputData)
                    except (json.JSONDecodeError, ValueError):
                        inputData = {}
                contentBlocks.append({
                    "type": "tool_use",
                    "id": tc.get("id", ""),
                    "name": fn.get("name", ""),
                    "input": inputData,
                })
            converted_messages.append({"role": "assistant", "content": contentBlocks})
            continue

        converted_messages.append({"role": role, "content": _collapse(content)})

    _flush()

    merged: List[Dict[str, Any]] = []
    for msg in converted_messages:
        if merged and merged[-1]["role"] == msg["role"]:
            prev = merged[-1]
            pc, nc = prev["content"], msg["content"]
            if isinstance(pc, str) and isinstance(nc, str):
                prev["content"] = pc + "\n\n" + nc
            elif isinstance(pc, list) and isinstance(nc, list):
                prev["content"] = pc + nc
            elif isinstance(pc, str) and isinstance(nc, list):
                prev["content"] = [{"type": "text", "text": pc}] + nc
            elif isinstance(pc, list) and isinstance(nc, str):
                prev["content"] = pc + [{"type": "text", "text": nc}]
        else:
            merged.append(msg)

    system_prompt = "\n\n".join([s for s in system_contents if s]) if system_contents else None
    return merged, system_prompt


def _convertToolsToAnthropicFormat(openaiTools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Convert OpenAI-style tool definitions to Anthropic format."""
    anthropicTools = []
    for tool in openaiTools:
        if tool.get("type") == "function":
            fn = tool["function"]
            anthropicTools.append({
                "name": fn["name"],
                "description": fn.get("description", ""),
                "input_schema": fn.get("parameters", {"type": "object", "properties": {}})
            })
    return anthropicTools
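To make the round-trip concrete, a short sketch of what _convertMessagesForAnthropic produces for a tool-call exchange (message contents are illustrative):

    messages = [
        {"role": "system", "content": "You are terse."},
        {"role": "assistant", "tool_calls": [{
            "id": "call_1",
            "function": {"name": "search", "arguments": "{\"q\": \"pgvector\"}"},
        }]},
        {"role": "tool", "tool_call_id": "call_1", "content": "3 results"},
    ]
    converted, system_prompt = _convertMessagesForAnthropic(messages)
    # system_prompt == "You are terse."
    # converted[0]: assistant turn whose content is a tool_use block with input {"q": "pgvector"}
    # converted[1]: user turn wrapping the matching tool_result block (Anthropic's required shape)
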
@@ -1,24 +1,16 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
import logging
import json as _json
import httpx
from typing import List
from typing import List, Dict, Any, AsyncGenerator, Union
from fastapi import HTTPException
from modules.shared.configuration import APP_CONFIG
from .aicoreBase import BaseConnectorAi
from .aicoreBase import BaseConnectorAi, RateLimitExceededException, ContextLengthExceededException
from modules.datamodels.datamodelAi import AiModel, PriorityEnum, ProcessingModeEnum, OperationTypeEnum, AiModelCall, AiModelResponse, createOperationTypeRatings

# Configure logger
logger = logging.getLogger(__name__)

class ContextLengthExceededException(Exception):
    """Exception raised when the context length exceeds the model's limit"""
    pass

class RateLimitExceededException(Exception):
    """Exception raised when the provider's rate limit (TPM) is exceeded"""
    pass

def loadConfigData():
    """Load configuration data for Mistral connector"""
    return {
@@ -66,13 +58,15 @@ class AiMistral(BaseConnectorAi):
                speedRating=8, # Good speed for complex tasks
                qualityRating=9, # High quality
                functionCall=self.callAiBasic,
                functionCallStream=self.callAiBasicStream,
                priority=PriorityEnum.BALANCED,
                processingMode=ProcessingModeEnum.ADVANCED,
                operationTypes=createOperationTypeRatings(
                    (OperationTypeEnum.PLAN, 9),
                    (OperationTypeEnum.DATA_ANALYSE, 9),
                    (OperationTypeEnum.DATA_GENERATE, 9),
                    (OperationTypeEnum.DATA_EXTRACT, 8)
                    (OperationTypeEnum.DATA_EXTRACT, 8),
                    (OperationTypeEnum.AGENT, 8),
                ),
                version="mistral-large-latest",
                calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.0005 + (bytesReceived / 4 / 1000) * 0.0015
@@ -90,17 +84,40 @@
                speedRating=9, # Very fast, lightweight model
                qualityRating=7, # Good quality, cost-efficient
                functionCall=self.callAiBasic,
                functionCallStream=self.callAiBasicStream,
                priority=PriorityEnum.SPEED,
                processingMode=ProcessingModeEnum.BASIC,
                operationTypes=createOperationTypeRatings(
                    (OperationTypeEnum.PLAN, 7),
                    (OperationTypeEnum.DATA_ANALYSE, 7),
                    (OperationTypeEnum.DATA_GENERATE, 8),
                    (OperationTypeEnum.DATA_EXTRACT, 7)
                    (OperationTypeEnum.DATA_EXTRACT, 7),
                    (OperationTypeEnum.AGENT, 6),
                ),
                version="mistral-small-latest",
                calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.00006 + (bytesReceived / 4 / 1000) * 0.00018
            ),
            AiModel(
                name="mistral-embed",
                displayName="Mistral Embed",
                connectorType="mistral",
                apiUrl="https://api.mistral.ai/v1/embeddings",
                temperature=0.0,
                maxTokens=0,
                contextLength=8192,
                costPer1kTokensInput=0.0001, # $0.10/M tokens
                costPer1kTokensOutput=0.0,
                speedRating=10,
                qualityRating=7,
                functionCall=self.callEmbedding,
                priority=PriorityEnum.COST,
                processingMode=ProcessingModeEnum.BASIC,
                operationTypes=createOperationTypeRatings(
                    (OperationTypeEnum.EMBEDDING, 8)
                ),
                version="mistral-embed",
                calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.0001
            ),
            AiModel(
                name="mistral-large-latest",
                displayName="Mistral Large 3 Vision",
@@ -158,6 +175,10 @@
                "max_tokens": maxTokens
            }

            if modelCall.tools:
                payload["tools"] = modelCall.tools
                payload["tool_choice"] = modelCall.toolChoice or "auto"

            response = await self.httpClient.post(
                model.apiUrl,
                json=payload
@@ -197,13 +218,18 @@
                raise HTTPException(status_code=500, detail=error_message)

            responseJson = response.json()
            content = responseJson["choices"][0]["message"]["content"]
            choiceMessage = responseJson["choices"][0]["message"]
            content = choiceMessage.get("content") or ""

            metadata = {"response_id": responseJson.get("id", "")}
            if choiceMessage.get("tool_calls"):
                metadata["toolCalls"] = choiceMessage["tool_calls"]

            return AiModelResponse(
                content=content,
                success=True,
                modelId=model.name,
                metadata={"response_id": responseJson.get("id", "")}
                metadata=metadata,
            )

        except ContextLengthExceededException:
@@ -216,6 +242,147 @@
            logger.error(f"Error calling Mistral API: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Error calling Mistral API: {str(e)}")

    async def callAiBasicStream(self, modelCall: AiModelCall) -> AsyncGenerator[Union[str, AiModelResponse], None]:
        """Stream Mistral response. Yields str deltas, then final AiModelResponse."""
        try:
            model = modelCall.model
            options = modelCall.options
            temperature = getattr(options, "temperature", None)
            if temperature is None:
                temperature = model.temperature

            payload: Dict[str, Any] = {
                "model": model.name,
                "messages": modelCall.messages,
                "temperature": temperature,
                "max_tokens": model.maxTokens,
                "stream": True,
            }

            if modelCall.tools:
                payload["tools"] = modelCall.tools
                payload["tool_choice"] = modelCall.toolChoice or "auto"

            fullContent = ""
            toolCallsAccum: Dict[int, Dict[str, Any]] = {}

            async with self.httpClient.stream("POST", model.apiUrl, json=payload) as response:
                if response.status_code != 200:
                    body = await response.aread()
                    bodyStr = body.decode()
                    if response.status_code == 429:
                        try:
                            errorMsg = _json.loads(bodyStr).get("error", {}).get("message", "Rate limit exceeded")
                        except (ValueError, KeyError):
                            errorMsg = f"Rate limit exceeded for {model.name}"
                        raise RateLimitExceededException(f"Rate limit exceeded for {model.name}: {errorMsg}")
                    raise HTTPException(status_code=500, detail=f"Mistral stream error: {response.status_code} - {bodyStr}")

                async for line in response.aiter_lines():
                    if not line.startswith("data: "):
                        continue
                    data = line[6:]
                    if data.strip() == "[DONE]":
                        break
                    try:
                        chunk = _json.loads(data)
                    except _json.JSONDecodeError:
                        continue

                    delta = chunk.get("choices", [{}])[0].get("delta", {})
                    if "content" in delta and delta["content"]:
                        fullContent += delta["content"]
                        yield delta["content"]

                    for tcDelta in delta.get("tool_calls", []):
                        idx = tcDelta.get("index", 0)
                        if idx not in toolCallsAccum:
                            toolCallsAccum[idx] = {
                                "id": tcDelta.get("id", ""),
                                "type": "function",
                                "function": {"name": "", "arguments": ""},
                            }
                        if tcDelta.get("id"):
                            toolCallsAccum[idx]["id"] = tcDelta["id"]
                        fn = tcDelta.get("function", {})
                        if fn.get("name"):
                            toolCallsAccum[idx]["function"]["name"] = fn["name"]
                        if fn.get("arguments"):
                            toolCallsAccum[idx]["function"]["arguments"] += fn["arguments"]

            metadata: Dict[str, Any] = {}
            if toolCallsAccum:
                metadata["toolCalls"] = [toolCallsAccum[i] for i in sorted(toolCallsAccum)]

            yield AiModelResponse(
                content=fullContent,
                success=True,
                modelId=model.name,
                metadata=metadata,
            )

        except (RateLimitExceededException, ContextLengthExceededException, HTTPException):
            raise
        except Exception as e:
            logger.error(f"Error streaming Mistral API: {e}")
            raise HTTPException(status_code=500, detail=f"Error streaming Mistral API: {e}")

    async def callEmbedding(self, modelCall: AiModelCall) -> AiModelResponse:
        """Generate embeddings via the Mistral Embeddings API.

        Reads texts from modelCall.embeddingInput.
        Returns vectors in metadata["embeddings"].
        """
        try:
            model = modelCall.model
            texts = modelCall.embeddingInput or []
            if not texts:
                return AiModelResponse(
                    content="", success=False, error="No embeddingInput provided"
                )

            payload = {"model": model.name, "input": texts}
            response = await self.httpClient.post(model.apiUrl, json=payload)

            if response.status_code != 200:
                errorMessage = f"Mistral Embedding API error: {response.status_code} - {response.text}"
                logger.error(errorMessage)
                if response.status_code == 429:
                    raise RateLimitExceededException(f"Rate limit exceeded for {model.name}")
                if response.status_code == 400:
                    try:
                        errorData = response.json()
                        errMsg = errorData.get("error", {}).get("message", "").lower()
                        errCode = errorData.get("error", {}).get("code", "")
                        if errCode == "context_length_exceeded" or "too many tokens" in errMsg or "maximum context length" in errMsg:
                            raise ContextLengthExceededException(
                                f"Embedding context length exceeded for {model.name}: {errorData.get('error', {}).get('message', '')}"
                            )
                    except (ValueError, KeyError):
                        pass
                raise HTTPException(status_code=500, detail=errorMessage)

            responseJson = response.json()
            embeddings = [item["embedding"] for item in responseJson["data"]]
            usage = responseJson.get("usage", {})

            return AiModelResponse(
                content="",
                success=True,
                modelId=model.name,
                tokensUsed={
                    "input": usage.get("prompt_tokens", 0),
                    "output": 0,
                    "total": usage.get("total_tokens", 0),
                },
                metadata={"embeddings": embeddings},
            )
        except (RateLimitExceededException, ContextLengthExceededException):
            raise
        except Exception as e:
            logger.error(f"Error calling Mistral Embedding API: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Error calling Mistral Embedding API: {str(e)}")

    async def callAiImage(self, modelCall: AiModelCall) -> AiModelResponse:
        """
        Analyzes an image with the Mistral Vision API using standardized pattern.
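A hedged sketch of exercising the embedding path above (the AiModelCall construction and variable names are assumed; the method itself reads only model and embeddingInput):

    call = AiModelCall(
        model=mistralEmbedModel,  # the "mistral-embed" AiModel registered above
        embeddingInput=["first text", "second text"],
    )
    resp = await connector.callEmbedding(call)
    vectors = resp.metadata["embeddings"]  # one list[float] per input text
    inputTokens = resp.tokensUsed["input"]
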
@@ -1,24 +1,16 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
import logging
import json as _json
import httpx
from typing import List
from typing import List, Dict, Any, AsyncGenerator, Union
from fastapi import HTTPException
from modules.shared.configuration import APP_CONFIG
from .aicoreBase import BaseConnectorAi
from .aicoreBase import BaseConnectorAi, RateLimitExceededException, ContextLengthExceededException
from modules.datamodels.datamodelAi import AiModel, PriorityEnum, ProcessingModeEnum, OperationTypeEnum, AiModelCall, AiModelResponse, createOperationTypeRatings, AiCallPromptImage

# Configure logger
logger = logging.getLogger(__name__)

class ContextLengthExceededException(Exception):
    """Exception raised when the context length exceeds the model's limit"""
    pass

class RateLimitExceededException(Exception):
    """Exception raised when the provider's rate limit (TPM) is exceeded"""
    pass

def loadConfigData():
    """Load configuration data for OpenAI connector"""
    return {
@@ -67,13 +59,15 @@ class AiOpenai(BaseConnectorAi):
                speedRating=8, # Good speed for complex tasks
                qualityRating=10, # High quality
                functionCall=self.callAiBasic,
                functionCallStream=self.callAiBasicStream,
                priority=PriorityEnum.BALANCED,
                processingMode=ProcessingModeEnum.ADVANCED,
                operationTypes=createOperationTypeRatings(
                    (OperationTypeEnum.PLAN, 9),
                    (OperationTypeEnum.DATA_ANALYSE, 10),
                    (OperationTypeEnum.DATA_GENERATE, 10),
                    (OperationTypeEnum.DATA_EXTRACT, 7)
                    (OperationTypeEnum.DATA_EXTRACT, 7),
                    (OperationTypeEnum.AGENT, 9),
                ),
                version="gpt-4o",
                calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.0025 + (bytesReceived / 4 / 1000) * 0.01
@@ -92,13 +86,15 @@
                speedRating=9, # Very fast
                qualityRating=8, # Good quality, replaces gpt-3.5-turbo
                functionCall=self.callAiBasic,
                functionCallStream=self.callAiBasicStream,
                priority=PriorityEnum.SPEED,
                processingMode=ProcessingModeEnum.BASIC,
                operationTypes=createOperationTypeRatings(
                    (OperationTypeEnum.PLAN, 8),
                    (OperationTypeEnum.DATA_ANALYSE, 8),
                    (OperationTypeEnum.DATA_GENERATE, 9),
                    (OperationTypeEnum.DATA_EXTRACT, 7)
                    (OperationTypeEnum.DATA_EXTRACT, 7),
                    (OperationTypeEnum.AGENT, 8),
                ),
                version="gpt-4o-mini",
                calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.00015 + (bytesReceived / 4 / 1000) * 0.0006
@@ -125,6 +121,48 @@
                version="gpt-4o",
                calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.0025 + (bytesReceived / 4 / 1000) * 0.01
            ),
            AiModel(
                name="text-embedding-3-small",
                displayName="OpenAI Embedding Small",
                connectorType="openai",
                apiUrl="https://api.openai.com/v1/embeddings",
                temperature=0.0,
                maxTokens=0,
                contextLength=8191,
                costPer1kTokensInput=0.00002, # $0.02/M tokens
                costPer1kTokensOutput=0.0,
                speedRating=10,
                qualityRating=8,
                functionCall=self.callEmbedding,
                priority=PriorityEnum.COST,
                processingMode=ProcessingModeEnum.BASIC,
                operationTypes=createOperationTypeRatings(
                    (OperationTypeEnum.EMBEDDING, 10)
                ),
                version="text-embedding-3-small",
                calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.00002
            ),
            AiModel(
                name="text-embedding-3-large",
                displayName="OpenAI Embedding Large",
                connectorType="openai",
                apiUrl="https://api.openai.com/v1/embeddings",
                temperature=0.0,
                maxTokens=0,
                contextLength=8191,
                costPer1kTokensInput=0.00013, # $0.13/M tokens
                costPer1kTokensOutput=0.0,
                speedRating=9,
                qualityRating=10,
                functionCall=self.callEmbedding,
                priority=PriorityEnum.QUALITY,
                processingMode=ProcessingModeEnum.ADVANCED,
                operationTypes=createOperationTypeRatings(
                    (OperationTypeEnum.EMBEDDING, 10)
                ),
                version="text-embedding-3-large",
                calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.00013
            ),
            AiModel(
                name="dall-e-3",
                displayName="OpenAI DALL-E 3",
@@ -179,6 +217,10 @@
                "max_tokens": maxTokens
            }

            if modelCall.tools:
                payload["tools"] = modelCall.tools
                payload["tool_choice"] = modelCall.toolChoice or "auto"

            response = await self.httpClient.post(
                model.apiUrl,
                json=payload
@@ -218,22 +260,168 @@
                raise HTTPException(status_code=500, detail=error_message)

            responseJson = response.json()
            content = responseJson["choices"][0]["message"]["content"]
            choiceMessage = responseJson["choices"][0]["message"]
            content = choiceMessage.get("content") or ""

            metadata = {"response_id": responseJson.get("id", "")}
            if choiceMessage.get("tool_calls"):
                metadata["toolCalls"] = choiceMessage["tool_calls"]

            return AiModelResponse(
                content=content,
                success=True,
                modelId=model.name,
                metadata={"response_id": responseJson.get("id", "")}
                metadata=metadata
            )

        except ContextLengthExceededException:
            # Re-raise context length exceptions without wrapping
            raise
        except Exception as e:
            logger.error(f"Error calling OpenAI API: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Error calling OpenAI API: {str(e)}")

    async def callAiBasicStream(self, modelCall: AiModelCall) -> AsyncGenerator[Union[str, AiModelResponse], None]:
        """Stream OpenAI response. Yields str deltas, then final AiModelResponse."""
        try:
            messages = modelCall.messages
            model = modelCall.model
            options = modelCall.options
            temperature = getattr(options, "temperature", None)
            if temperature is None:
                temperature = model.temperature

            payload: Dict[str, Any] = {
                "model": model.name,
                "messages": messages,
                "temperature": temperature,
                "max_tokens": model.maxTokens,
                "stream": True,
            }
            if modelCall.tools:
                payload["tools"] = modelCall.tools
                payload["tool_choice"] = modelCall.toolChoice or "auto"

            fullContent = ""
            toolCallsAccum: Dict[int, Dict[str, Any]] = {}

            async with self.httpClient.stream("POST", model.apiUrl, json=payload) as response:
                if response.status_code != 200:
                    body = await response.aread()
                    bodyStr = body.decode()
                    if response.status_code == 429:
                        try:
                            errorMsg = _json.loads(bodyStr).get("error", {}).get("message", "Rate limit exceeded")
                        except (ValueError, KeyError):
                            errorMsg = f"Rate limit exceeded for {model.name}"
                        raise RateLimitExceededException(f"Rate limit exceeded for {model.name}: {errorMsg}")
                    raise HTTPException(status_code=500, detail=f"OpenAI stream error: {response.status_code} - {bodyStr}")

                async for line in response.aiter_lines():
                    if not line.startswith("data: "):
                        continue
                    data = line[6:]
                    if data.strip() == "[DONE]":
                        break
                    try:
                        chunk = _json.loads(data)
                    except _json.JSONDecodeError:
                        continue

                    delta = chunk.get("choices", [{}])[0].get("delta", {})

                    if "content" in delta and delta["content"]:
                        fullContent += delta["content"]
                        yield delta["content"]

                    for tcDelta in delta.get("tool_calls", []):
                        idx = tcDelta.get("index", 0)
                        if idx not in toolCallsAccum:
                            toolCallsAccum[idx] = {
                                "id": tcDelta.get("id", ""),
                                "type": "function",
                                "function": {"name": "", "arguments": ""},
                            }
                        if tcDelta.get("id"):
                            toolCallsAccum[idx]["id"] = tcDelta["id"]
                        fn = tcDelta.get("function", {})
                        if fn.get("name"):
                            toolCallsAccum[idx]["function"]["name"] = fn["name"]
                        if fn.get("arguments"):
                            toolCallsAccum[idx]["function"]["arguments"] += fn["arguments"]

            metadata: Dict[str, Any] = {}
            if toolCallsAccum:
                metadata["toolCalls"] = [toolCallsAccum[i] for i in sorted(toolCallsAccum)]

            yield AiModelResponse(
                content=fullContent,
                success=True,
                modelId=model.name,
                metadata=metadata,
            )

        except (RateLimitExceededException, ContextLengthExceededException, HTTPException):
            raise
        except Exception as e:
            logger.error(f"Error streaming OpenAI API: {e}")
            raise HTTPException(status_code=500, detail=f"Error streaming OpenAI API: {e}")

    async def callEmbedding(self, modelCall: AiModelCall) -> AiModelResponse:
        """Generate embeddings via the OpenAI Embeddings API.

        Reads texts from modelCall.embeddingInput.
        Returns vectors in metadata["embeddings"].
        """
        try:
            model = modelCall.model
            texts = modelCall.embeddingInput or []
            if not texts:
                return AiModelResponse(
                    content="", success=False, error="No embeddingInput provided"
                )

            payload = {"model": model.name, "input": texts}
            response = await self.httpClient.post(model.apiUrl, json=payload)

            if response.status_code != 200:
                errorMessage = f"OpenAI Embedding API error: {response.status_code} - {response.text}"
                logger.error(errorMessage)
                if response.status_code == 429:
                    raise RateLimitExceededException(f"Rate limit exceeded for {model.name}")
                if response.status_code == 400:
                    try:
                        errorData = response.json()
                        errMsg = errorData.get("error", {}).get("message", "").lower()
                        errCode = errorData.get("error", {}).get("code", "")
                        if errCode == "context_length_exceeded" or "too many tokens" in errMsg or "maximum context length" in errMsg:
                            raise ContextLengthExceededException(
                                f"Embedding context length exceeded for {model.name}: {errorData.get('error', {}).get('message', '')}"
                            )
                    except (ValueError, KeyError):
                        pass
                raise HTTPException(status_code=500, detail=errorMessage)

            responseJson = response.json()
            embeddings = [item["embedding"] for item in responseJson["data"]]
            usage = responseJson.get("usage", {})

            return AiModelResponse(
                content="",
                success=True,
                modelId=model.name,
                tokensUsed={
                    "input": usage.get("prompt_tokens", 0),
                    "output": 0,
                    "total": usage.get("total_tokens", 0),
                },
                metadata={"embeddings": embeddings},
            )
        except (RateLimitExceededException, ContextLengthExceededException):
            raise
        except Exception as e:
            logger.error(f"Error calling OpenAI Embedding API: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Error calling OpenAI Embedding API: {str(e)}")

    async def callAiImage(self, modelCall: AiModelCall) -> AiModelResponse:
        """
        Analyzes an image with the OpenAI Vision API using standardized pattern.
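The accumulator above exists because OpenAI-style SSE fragments a single tool call across many deltas; a minimal illustration of the chunks it stitches together (shapes follow the code above, values invented):

    deltas = [
        {"index": 0, "id": "call_9", "function": {"name": "lookup", "arguments": "{\"ci"}},
        {"index": 0, "function": {"arguments": "ty\": \"Bern\"}"}},
    ]
    # After the loop, toolCallsAccum[0]["function"]["arguments"] == '{"city": "Bern"}'
    # and the final AiModelResponse carries it in metadata["toolCalls"].
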
@@ -288,7 +288,16 @@ class AiTavily(BaseConnectorAi):
        if maxResults < minResults or maxResults > maxAllowedResults:
            raise ValueError(f"maxResults must be between {minResults} and {maxAllowedResults}")

        # Perform actual API call
        # Tavily enforces a 400-character query limit
        TAVILY_MAX_QUERY_LENGTH = 400
        if len(query) > TAVILY_MAX_QUERY_LENGTH:
            truncated = query[:TAVILY_MAX_QUERY_LENGTH]
            lastSpace = truncated.rfind(' ')
            if lastSpace > TAVILY_MAX_QUERY_LENGTH // 2:
                truncated = truncated[:lastSpace]
            logger.warning(f"Tavily query truncated from {len(query)} to {len(truncated)} chars")
            query = truncated

        # Build kwargs only for provided options to avoid API rejections
        kwargs: dict = {"query": query, "max_results": maxResults}
        if searchDepth is not None:
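A quick check of the truncation rule above: the cut falls back to the last space only when that space lies past the midpoint, so a single giant token is still hard-cut at 400:

    query = "pgvector similarity search " * 20  # 540 chars, plenty of spaces
    truncated = query[:400]
    lastSpace = truncated.rfind(' ')            # well past 400 // 2
    if lastSpace > 400 // 2:
        truncated = truncated[:lastSpace]       # trim to the word boundary
    assert len(truncated) <= 400 and not truncated.endswith(' ')
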
@@ -41,6 +41,11 @@ class SystemTable(BaseModel):
    )


def _isVectorType(sqlType: str) -> bool:
    """Check if a SQL type string represents a pgvector column."""
    return sqlType.upper().startswith("VECTOR")


def _isJsonbType(fieldType) -> bool:
    """Check if a type should be stored as JSONB in PostgreSQL."""
    # Direct dict or list
@@ -70,20 +75,26 @@


def _get_model_fields(model_class) -> Dict[str, str]:
    """Get all fields from Pydantic model and map to SQL types."""
    # Pydantic v2
    """Get all fields from Pydantic model and map to SQL types.

    Supports explicit db_type override via json_schema_extra={"db_type": "vector(1536)"}.
    This enables pgvector columns without special-casing field names.
    """
    model_fields = model_class.model_fields

    fields = {}
    for field_name, field_info in model_fields.items():
        # Pydantic v2
        field_type = field_info.annotation

        # Explicit db_type override (e.g. vector columns)
        extra = field_info.json_schema_extra
        if extra and isinstance(extra, dict) and "db_type" in extra:
            fields[field_name] = extra["db_type"]
            continue

        # Check for JSONB fields (Dict, List, or complex types)
        # Purely type-based detection - no hardcoded field names
        if _isJsonbType(field_type):
            fields[field_name] = "JSONB"
        # Simple type mapping
        elif field_type in (str, type(None)) or (
            get_origin(field_type) is Union and type(None) in get_args(field_type)
        ):
@ -95,11 +106,45 @@ def _get_model_fields(model_class) -> Dict[str, str]:
        elif field_type == bool:
            fields[field_name] = "BOOLEAN"
        else:
            fields[field_name] = "TEXT"  # Default to TEXT
            fields[field_name] = "TEXT"

    return fields


def _parseRecordFields(record: Dict[str, Any], fields: Dict[str, str], context: str = "") -> None:
    """Parse record fields in-place: numeric typing, vector parsing, JSONB deserialization."""
    import json as _json

    for fieldName, fieldType in fields.items():
        if fieldName not in record:
            continue
        value = record[fieldName]

        if fieldType in ("DOUBLE PRECISION", "INTEGER") and value is not None:
            try:
                record[fieldName] = float(value) if fieldType == "DOUBLE PRECISION" else int(value)
            except (ValueError, TypeError):
                logger.warning(f"Could not convert {fieldName} to {fieldType} ({context}): {value}")

        elif _isVectorType(fieldType) and value is not None:
            if isinstance(value, str):
                try:
                    record[fieldName] = [float(v) for v in value.strip("[]").split(",")]
                except (ValueError, TypeError):
                    logger.warning(f"Could not parse vector field {fieldName} ({context})")
            elif isinstance(value, list):
                pass  # already a list

        elif fieldType == "JSONB" and value is not None:
            try:
                if isinstance(value, str):
                    record[fieldName] = _json.loads(value)
                elif not isinstance(value, (dict, list)):
                    record[fieldName] = _json.loads(str(value))
            except (_json.JSONDecodeError, TypeError, ValueError):
                logger.warning(f"Could not parse JSONB field {fieldName}, keeping as string ({context})")


# Cache connectors by (host, database, port) to avoid duplicate inits for same database.
# Thread safety: _connector_cache_lock protects cache access. userId is request-scoped via
# contextvars to avoid races when concurrent requests share the same connector.
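To see what `_parseRecordFields` does in practice, a small sketch (row values invented): psycopg2 may hand numerics, vectors, and JSONB back as strings on some platforms, and this normalizes them in place.

```python
# Hypothetical row as psycopg2 might return it (everything stringly typed).
record = {"score": "0.87", "count": "3", "embedding": "[0.1,0.2]", "meta": '{"k": 1}'}
fields = {"score": "DOUBLE PRECISION", "count": "INTEGER",
          "embedding": "vector(2)", "meta": "JSONB"}

_parseRecordFields(record, fields, context="demo")
# record is now:
# {"score": 0.87, "count": 3, "embedding": [0.1, 0.2], "meta": {"k": 1}}
```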
@ -132,7 +177,7 @@ def _get_cached_connector(
            oldest_key = _connector_cache_order.pop(0)
            if oldest_key in _connector_cache:
                try:
                    _connector_cache[oldest_key].close()
                    _connector_cache[oldest_key].close(forceClose=True)
                except Exception as e:
                    logger.warning(f"Error closing evicted connector: {e}")
                del _connector_cache[oldest_key]
@ -144,6 +189,7 @@ def _get_cached_connector(
            dbPort=dbPort,
            userId=userId,
        )
        _connector_cache[key]._isCachedShared = True
        _connector_cache_order.append(key)
    conn = _connector_cache[key]
    # Set request-scoped userId via contextvar (avoids mutating shared connector)
@ -180,6 +226,7 @@ class DatabaseConnector:
        # Initialize database system first (creates database if needed)
        self.connection = None
        self._isCachedShared = False
        self.initDbSystem()

        # No caching needed with proper database - PostgreSQL handles performance
@ -187,6 +234,9 @@ class DatabaseConnector:
        # Thread safety
        self._lock = threading.Lock()

        # pgvector extension state
        self._vectorExtensionEnabled = False

        # Initialize system table
        self._systemTableName = "_system"
        self._initializeSystemTable()
@ -500,10 +550,32 @@ class DatabaseConnector:
                self.connection.rollback()
            return False

    def _ensureVectorExtension(self) -> bool:
        """Enable pgvector extension if not already enabled. Called lazily on first vector table."""
        if self._vectorExtensionEnabled:
            return True
        try:
            self._ensure_connection()
            with self.connection.cursor() as cursor:
                cursor.execute("CREATE EXTENSION IF NOT EXISTS vector")
            self.connection.commit()
            self._vectorExtensionEnabled = True
            logger.info("pgvector extension enabled")
            return True
        except Exception as e:
            logger.error(f"Failed to enable pgvector extension: {e}")
            if hasattr(self, "connection") and self.connection:
                self.connection.rollback()
            return False

    def _create_table_from_model(self, cursor, table: str, model_class: type) -> None:
        """Create table with columns matching Pydantic model fields."""
        fields = _get_model_fields(model_class)

        # Enable pgvector if any field uses vector type
        if any(_isVectorType(sqlType) for sqlType in fields.values()):
            self._ensureVectorExtension()

        # Build column definitions with quoted identifiers to preserve exact case
        columns = ['"id" VARCHAR(255) PRIMARY KEY']
        for field_name, sql_type in fields.items():
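Under the assumption that the loop above emits one quoted column per field (as the preceding lines suggest), a model like the hypothetical `DocumentChunk` from earlier would yield DDL along these lines; this is a sketch of the expected output, not a verbatim trace of `_create_table_from_model`:

```python
fields = {"text": "TEXT", "embedding": "vector(1536)"}
columns = ['"id" VARCHAR(255) PRIMARY KEY']
columns += [f'"{name}" {sqlType}' for name, sqlType in fields.items()]
ddl = f'CREATE TABLE IF NOT EXISTS "DocumentChunk" ({", ".join(columns)})'
# CREATE TABLE IF NOT EXISTS "DocumentChunk"
#   ("id" VARCHAR(255) PRIMARY KEY, "text" TEXT, "embedding" vector(1536))
```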
@ -576,28 +648,25 @@ class DatabaseConnector:
            elif hasattr(value, "value"):
                value = value.value

            # Handle vector fields (pgvector) - convert List[float] to string
            elif col in fields and _isVectorType(fields[col]) and value is not None:
                if isinstance(value, list):
                    value = f"[{','.join(str(v) for v in value)}]"

            # Handle JSONB fields - ensure proper JSON format for PostgreSQL
            elif col in fields and fields[col] == "JSONB" and value is not None:
                import json

                if isinstance(value, (dict, list)):
                    # Convert Python objects to JSON string for PostgreSQL JSONB
                    value = json.dumps(value)
                elif isinstance(value, str):
                    # Validate that it's valid JSON, if not, try to parse and re-serialize
                    try:
                        # Test if it's already valid JSON
                        json.loads(value)
                        # If successful, keep as is
                        pass
                    except (json.JSONDecodeError, TypeError):
                        # If not valid JSON, convert to JSON string
                        value = json.dumps(value)
                elif hasattr(value, 'model_dump'):
                    # Handle Pydantic models
                    value = json.dumps(value.model_dump())
                else:
                    # Convert other types to JSON
                    value = json.dumps(value)

            values.append(value)
@ -635,46 +704,7 @@ class DatabaseConnector:
            record = dict(row)
            fields = _get_model_fields(model_class)

            # Ensure numeric fields are properly typed and parse JSONB fields
            for field_name, field_type in fields.items():
                # Ensure numeric fields (float/int) are properly typed
                # psycopg2 may return them as strings in some environments (e.g., Azure PostgreSQL)
                if field_type in ("DOUBLE PRECISION", "INTEGER") and field_name in record:
                    value = record[field_name]
                    if value is not None:
                        try:
                            if field_type == "DOUBLE PRECISION":
                                record[field_name] = float(value)
                            elif field_type == "INTEGER":
                                record[field_name] = int(value)
                        except (ValueError, TypeError):
                            # If conversion fails, log warning but keep original value
                            logger.warning(
                                f"Could not convert {field_name} to {field_type} for record {recordId}: {value}"
                            )
                elif (
                    field_type == "JSONB"
                    and field_name in record
                    and record[field_name] is not None
                ):
                    import json

                    try:
                        if isinstance(record[field_name], str):
                            # Parse JSON string back to Python object
                            record[field_name] = json.loads(record[field_name])
                        elif isinstance(record[field_name], (dict, list)):
                            # Already a Python object, keep as is
                            pass
                        else:
                            # Try to parse as JSON
                            record[field_name] = json.loads(str(record[field_name]))
                    except (json.JSONDecodeError, TypeError, ValueError):
                        # If parsing fails, keep as string
                        logger.warning(
                            f"Could not parse JSONB field {field_name}, keeping as string: {record[field_name]}"
                        )
                        pass
            _parseRecordFields(record, fields, f"record {recordId}")

            return record
        except Exception as e:
@ -737,55 +767,24 @@ class DatabaseConnector:
                cursor.execute(f'SELECT * FROM "{table}" ORDER BY "id"')
                records = [dict(row) for row in cursor.fetchall()]

            # Handle JSONB fields for all records
            fields = _get_model_fields(model_class)
            model_fields = model_class.model_fields  # Get Pydantic model fields
            modelFields = model_class.model_fields
            for record in records:
                for field_name, field_type in fields.items():
                    if field_type == "JSONB" and field_name in record:
                        if record[field_name] is None:
                            # Generic type-based default: List types -> [], Dict types -> {}
                            # Interfaces handle domain-specific defaults
                            field_info = model_fields.get(field_name)
                            if field_info:
                                field_annotation = field_info.annotation
                                # Check if it's a List type
                                if (field_annotation == list or
                                        (hasattr(field_annotation, "__origin__") and
                                         field_annotation.__origin__ is list)):
                                    record[field_name] = []
                                # Check if it's a Dict type
                                elif (field_annotation == dict or
                                        (hasattr(field_annotation, "__origin__") and
                                         field_annotation.__origin__ is dict)):
                                    record[field_name] = {}
                                else:
                                    record[field_name] = None
                            else:
                                record[field_name] = None
                        else:
                            import json

                            try:
                                if isinstance(record[field_name], str):
                                    # Parse JSON string back to Python object
                                    record[field_name] = json.loads(
                                        record[field_name]
                                    )
                                elif isinstance(record[field_name], (dict, list)):
                                    # Already a Python object, keep as is
                                    pass
                                else:
                                    # Try to parse as JSON
                                    record[field_name] = json.loads(
                                        str(record[field_name])
                                    )
                            except (json.JSONDecodeError, TypeError, ValueError):
                                # If parsing fails, keep as string
                                logger.warning(
                                    f"Could not parse JSONB field {field_name}, keeping as string: {record[field_name]}"
                                )
                                pass
                _parseRecordFields(record, fields, f"table {table}")
                # Set type-aware defaults for NULL JSONB fields
                for fieldName, fieldType in fields.items():
                    if fieldType == "JSONB" and fieldName in record and record[fieldName] is None:
                        fieldInfo = modelFields.get(fieldName)
                        if fieldInfo:
                            fieldAnnotation = fieldInfo.annotation
                            if (fieldAnnotation == list or
                                    (hasattr(fieldAnnotation, "__origin__") and
                                     fieldAnnotation.__origin__ is list)):
                                record[fieldName] = []
                            elif (fieldAnnotation == dict or
                                    (hasattr(fieldAnnotation, "__origin__") and
                                     fieldAnnotation.__origin__ is dict)):
                                record[fieldName] = {}

            return records
        except Exception as e:
@ -936,70 +935,23 @@ class DatabaseConnector:
                cursor.execute(query, where_values)
                records = [dict(row) for row in cursor.fetchall()]

            # Handle JSONB fields and ensure numeric types are correct
            fields = _get_model_fields(model_class)
            model_fields = model_class.model_fields  # Get Pydantic model fields
            modelFields = model_class.model_fields
            for record in records:
                for field_name, field_type in fields.items():
                    # Ensure numeric fields (float/int) are properly typed
                    # psycopg2 may return them as strings in some environments (e.g., Azure PostgreSQL)
                    if field_type in ("DOUBLE PRECISION", "INTEGER") and field_name in record:
                        value = record[field_name]
                        if value is not None:
                            try:
                                if field_type == "DOUBLE PRECISION":
                                    record[field_name] = float(value)
                                elif field_type == "INTEGER":
                                    record[field_name] = int(value)
                            except (ValueError, TypeError):
                                # If conversion fails, log warning but keep original value
                                logger.warning(
                                    f"Could not convert {field_name} to {field_type} for record {record.get('id', 'unknown')}: {value}"
                                )
                    elif field_type == "JSONB" and field_name in record:
                        if record[field_name] is None:
                            # Generic type-based default: List types -> [], Dict types -> {}
                            # Interfaces handle domain-specific defaults
                            field_info = model_fields.get(field_name)
                            if field_info:
                                field_annotation = field_info.annotation
                                # Check if it's a List type
                                if (field_annotation == list or
                                        (hasattr(field_annotation, "__origin__") and
                                         field_annotation.__origin__ is list)):
                                    record[field_name] = []
                                # Check if it's a Dict type
                                elif (field_annotation == dict or
                                        (hasattr(field_annotation, "__origin__") and
                                         field_annotation.__origin__ is dict)):
                                    record[field_name] = {}
                                else:
                                    record[field_name] = None
                            else:
                                record[field_name] = None
                        else:
                            import json

                            try:
                                if isinstance(record[field_name], str):
                                    # Parse JSON string back to Python object
                                    record[field_name] = json.loads(
                                        record[field_name]
                                    )
                                elif isinstance(record[field_name], (dict, list)):
                                    # Already a Python object, keep as is
                                    pass
                                else:
                                    # Try to parse as JSON
                                    record[field_name] = json.loads(
                                        str(record[field_name])
                                    )
                            except (json.JSONDecodeError, TypeError, ValueError):
                                # If parsing fails, keep as string
                                logger.warning(
                                    f"Could not parse JSONB field {field_name}, keeping as string: {record[field_name]}"
                                )
                                pass
                _parseRecordFields(record, fields, f"table {table}")
                for fieldName, fieldType in fields.items():
                    if fieldType == "JSONB" and fieldName in record and record[fieldName] is None:
                        fieldInfo = modelFields.get(fieldName)
                        if fieldInfo:
                            fieldAnnotation = fieldInfo.annotation
                            if (fieldAnnotation == list or
                                    (hasattr(fieldAnnotation, "__origin__") and
                                     fieldAnnotation.__origin__ is list)):
                                record[fieldName] = []
                            elif (fieldAnnotation == dict or
                                    (hasattr(fieldAnnotation, "__origin__") and
                                     fieldAnnotation.__origin__ is dict)):
                                record[fieldName] = {}

            # If fieldFilter is available, reduce the fields
            if fieldFilter and isinstance(fieldFilter, list):
@ -1080,7 +1032,10 @@ class DatabaseConnector:
            existingRecord.update(record)

            # Save updated record
            self._saveRecord(model_class, recordId, existingRecord)
            saved = self._saveRecord(model_class, recordId, existingRecord)
            if not saved:
                table = model_class.__name__
                raise ValueError(f"Failed to save record {recordId} to table {table}")
            return existingRecord

    def recordDelete(self, model_class: type, recordId: str) -> bool:
@ -1127,8 +1082,94 @@ class DatabaseConnector:
        initialId = systemData.get(table)
        return initialId

    def close(self):
        """Close the database connection."""
    def semanticSearch(
        self,
        modelClass: type,
        vectorColumn: str,
        queryVector: List[float],
        limit: int = 10,
        recordFilter: Dict[str, Any] = None,
        minScore: float = None,
    ) -> List[Dict[str, Any]]:
        """Semantic search using pgvector cosine distance.

        Args:
            modelClass: Pydantic model class for the table.
            vectorColumn: Name of the vector column to search.
            queryVector: Query vector as List[float].
            limit: Maximum number of results.
            recordFilter: Additional WHERE filters (field: value).
            minScore: Minimum cosine similarity (0.0 - 1.0).

        Returns:
            List of records with an added '_score' field (cosine similarity),
            sorted by similarity descending.
        """
        table = modelClass.__name__

        try:
            if not self._ensureTableExists(modelClass):
                return []

            vectorStr = f"[{','.join(str(v) for v in queryVector)}]"

            whereConditions = []
            whereValues = []

            if recordFilter:
                for field, value in recordFilter.items():
                    if value is None:
                        whereConditions.append(f'"{field}" IS NULL')
                    elif isinstance(value, (list, tuple)):
                        if not value:
                            whereConditions.append("1 = 0")
                        else:
                            whereConditions.append(f'"{field}" = ANY(%s)')
                            whereValues.append(list(value))
                    else:
                        whereConditions.append(f'"{field}" = %s')
                        whereValues.append(value)

            if minScore is not None:
                whereConditions.append(
                    f'1 - ("{vectorColumn}" <=> %s::vector) >= %s'
                )
                whereValues.extend([vectorStr, minScore])

            whereClause = ""
            if whereConditions:
                whereClause = " WHERE " + " AND ".join(whereConditions)

            query = (
                f'SELECT *, 1 - ("{vectorColumn}" <=> %s::vector) AS "_score" '
                f'FROM "{table}"{whereClause} '
                f'ORDER BY "{vectorColumn}" <=> %s::vector '
                f'LIMIT %s'
            )
            params = [vectorStr] + whereValues + [vectorStr, limit]

            with self.connection.cursor() as cursor:
                cursor.execute(query, params)
                records = [dict(row) for row in cursor.fetchall()]

            fields = _get_model_fields(modelClass)
            for record in records:
                _parseRecordFields(record, fields, f"semanticSearch {table}")

            return records
        except Exception as e:
            logger.error(f"Error in semantic search on {table}: {e}")
            return []
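A minimal usage sketch for `semanticSearch`, assuming a `DatabaseConnector` instance and the hypothetical `DocumentChunk` model from earlier (the embedding values, user id, and thresholds are invented):

```python
# queryVector would come from an embedding call; three floats stand in for 1536 here.
queryVector = [0.12, -0.03, 0.88]

hits = connector.semanticSearch(
    modelClass=DocumentChunk,          # hypothetical model with a vector column
    vectorColumn="embedding",
    queryVector=queryVector,
    limit=5,
    recordFilter={"userId": "u-123"},  # plain equality filters, ANDed into WHERE
    minScore=0.75,                     # drop matches below 0.75 cosine similarity
)
for hit in hits:
    print(hit["_score"], hit["id"])    # '_score' is added by the query, sorted descending
```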
    def close(self, forceClose: bool = False):
        """Close the database connection.

        Shared cached connectors are intentionally kept open unless forceClose=True.
        This prevents accidental shutdown from interface __del__ methods while
        other requests are still using the same cached connector instance.
        """
        if self._isCachedShared and not forceClose:
            return
        if (
            hasattr(self, "connection")
            and self.connection
@ -1141,5 +1182,4 @@ class DatabaseConnector:
        try:
            self.close()
        except Exception:
            # Ignore errors during cleanup
            pass
63  modules/connectors/connectorProviderBase.py  Normal file
@ -0,0 +1,63 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Abstract base classes for the Provider-Connector architecture (1:n).

One ProviderConnector per vendor (e.g. MsftConnector, GoogleConnector).
Each ProviderConnector exposes n ServiceAdapters (e.g. SharepointAdapter, OutlookAdapter).
All ServiceAdapters share the same access token from the UserConnection.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import List, Optional, Union


@dataclass
class DownloadResult:
    """Rich return type for ServiceAdapter.download() when metadata is available."""
    data: bytes = field(default=b"", repr=False)
    fileName: str = ""
    mimeType: str = ""


class ServiceAdapter(ABC):
    """Standardized operations for a single service of a provider."""

    @abstractmethod
    async def browse(self, path: str, filter: Optional[str] = None) -> list:
        """List items (files/folders) at the given path."""
        ...

    @abstractmethod
    async def download(self, path: str) -> Union[bytes, DownloadResult]:
        """Download a file. Return bytes or DownloadResult with metadata."""
        ...

    @abstractmethod
    async def upload(self, path: str, data: bytes, fileName: str) -> dict:
        """Upload a file to the given path. Returns metadata of the created entry."""
        ...

    @abstractmethod
    async def search(self, query: str, path: Optional[str] = None) -> list:
        """Search for items matching the query."""
        ...


class ProviderConnector(ABC):
    """One connector per provider. Manages a UserConnection + token.
    Provides access to n services of the provider."""

    def __init__(self, connection, accessToken: str):
        self.connection = connection
        self.accessToken = accessToken

    @abstractmethod
    def getAvailableServices(self) -> List[str]:
        """Which services does this provider offer?"""
        ...

    @abstractmethod
    def getServiceAdapter(self, service: str) -> ServiceAdapter:
        """Return the ServiceAdapter for a specific service."""
        ...
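Because `download()` may return either raw `bytes` or a `DownloadResult`, callers need a small normalization step. A hedged sketch (the `saveDownload` helper and fallback names are invented):

```python
from modules.connectors.connectorProviderBase import DownloadResult

async def saveDownload(adapter, path: str) -> tuple:  # hypothetical caller-side helper
    result = await adapter.download(path)
    if isinstance(result, DownloadResult):
        # Adapter supplied metadata (e.g. the Gmail/Outlook EML downloads below).
        return result.data, result.fileName or "download.bin", result.mimeType
    # Plain bytes: caller must derive name/type itself, e.g. from the path.
    return result, path.rsplit("/", 1)[-1] or "download.bin", "application/octet-stream"
```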
94  modules/connectors/connectorResolver.py  Normal file
@ -0,0 +1,94 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""ConnectorResolver -- resolves a connectionId to the correct ProviderConnector and ServiceAdapter.

Registry maps authority values to ProviderConnector classes.
The resolver loads the UserConnection, obtains a fresh token via SecurityService,
and instantiates the appropriate connector.
"""

import logging
from typing import Dict, Any, Type, Optional

from modules.connectors.connectorProviderBase import ProviderConnector, ServiceAdapter

logger = logging.getLogger(__name__)


class ConnectorResolver:
    """Resolves connectionId → ProviderConnector (with fresh token) → ServiceAdapter."""

    _providerRegistry: Dict[str, Type[ProviderConnector]] = {}

    def __init__(self, securityService, dbInterface):
        """
        Args:
            securityService: SecurityService instance (for getFreshToken)
            dbInterface: DB interface with getUserConnection(connectionId)
        """
        self._security = securityService
        self._db = dbInterface
        self._ensureRegistered()

    def _ensureRegistered(self):
        """Lazy-register known providers on first instantiation."""
        if ConnectorResolver._providerRegistry:
            return
        try:
            from modules.connectors.providerMsft.connectorMsft import MsftConnector
            ConnectorResolver._providerRegistry["msft"] = MsftConnector
        except ImportError:
            logger.warning("MsftConnector not available")

        try:
            from modules.connectors.providerGoogle.connectorGoogle import GoogleConnector
            ConnectorResolver._providerRegistry["google"] = GoogleConnector
        except ImportError:
            logger.debug("GoogleConnector not available (stub)")

        try:
            from modules.connectors.providerFtp.connectorFtp import FtpConnector
            ConnectorResolver._providerRegistry["local:ftp"] = FtpConnector
        except ImportError:
            logger.debug("FtpConnector not available (stub)")

    async def resolve(self, connectionId: str) -> ProviderConnector:
        """Resolve connectionId to a ProviderConnector with a fresh access token."""
        connection = await self._loadConnection(connectionId)
        if not connection:
            raise ValueError(f"UserConnection not found: {connectionId}")

        authority = getattr(connection, "authority", None)
        if not authority:
            raise ValueError(f"Connection {connectionId} has no authority")

        authorityStr = authority.value if hasattr(authority, "value") else str(authority)
        providerClass = self._providerRegistry.get(authorityStr)
        if not providerClass:
            raise ValueError(f"No ProviderConnector registered for authority: {authorityStr}")

        token = self._security.getFreshToken(connectionId)
        if not token or not token.tokenAccess:
            raise ValueError(f"No valid token for connection {connectionId}")

        return providerClass(connection, token.tokenAccess)

    async def resolveService(self, connectionId: str, service: str) -> ServiceAdapter:
        """Resolve connectionId + service name to a concrete ServiceAdapter."""
        provider = await self.resolve(connectionId)
        available = provider.getAvailableServices()
        if service not in available:
            raise ValueError(f"Service '{service}' not available. Options: {available}")
        return provider.getServiceAdapter(service)

    async def _loadConnection(self, connectionId: str) -> Optional[Any]:
        """Load UserConnection from DB."""
        try:
            if hasattr(self._db, "getUserConnection"):
                return self._db.getUserConnection(connectionId)
            if hasattr(self._db, "loadRecord"):
                from modules.datamodels.datamodelUam import UserConnection
                return self._db.loadRecord(UserConnection, connectionId)
        except Exception as e:
            logger.error(f"Failed to load connection {connectionId}: {e}")
        return None
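A hedged end-to-end sketch of how a route handler might use the resolver (the `securityService`/`dbInterface` instances and the connection id are placeholders):

```python
from modules.connectors.connectorResolver import ConnectorResolver

async def listSharepointRoot(securityService, dbInterface, connectionId: str):
    resolver = ConnectorResolver(securityService, dbInterface)
    # One call resolves connection -> provider -> adapter, with a fresh token.
    adapter = await resolver.resolveService(connectionId, "sharepoint")
    return await adapter.browse("/")  # root triggers site discovery, per SharepointAdapter below
```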
@ -9,7 +9,8 @@ import json
import html
import asyncio
import logging
from typing import Dict, Optional, Any, List
import time
from typing import AsyncGenerator, Dict, Optional, Any, List, Tuple
from google.cloud import speech
from google.cloud import translate_v2 as translate
from google.cloud import texttospeech
@ -403,6 +404,155 @@ class ConnectorGoogleSpeech:
                "error": str(e)
            }

    async def streamingRecognize(
        self,
        audioQueue: asyncio.Queue,
        language: str = "de-DE",
        phraseHints: Optional[list] = None,
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """
        Stream audio chunks to Google Cloud Speech-to-Text Streaming API.
        Google handles silence/endpoint detection natively.

        Args:
            audioQueue: Queue of (bytes, bool) tuples. bytes=audio data, bool=isLast.
                        Send (b"", True) to signal end of stream.
            language: Language code
            phraseHints: Optional boost phrases

        Yields:
            Dicts with keys: isFinal, transcript, confidence, stabilityScore, audioDurationSec
        """
        STREAM_LIMIT_SEC = 290
        streamStartTs = time.time()
        totalAudioBytes = 0

        configParams = {
            "encoding": speech.RecognitionConfig.AudioEncoding.WEBM_OPUS,
            "sample_rate_hertz": 48000,
            "audio_channel_count": 1,
            "language_code": language,
            "enable_automatic_punctuation": True,
            "model": "latest_long",
            "use_enhanced": True,
        }
        if phraseHints:
            configParams["speech_contexts"] = [speech.SpeechContext(phrases=phraseHints, boost=15.0)]

        recognitionConfig = speech.RecognitionConfig(**configParams)
        streamingConfig = speech.StreamingRecognitionConfig(
            config=recognitionConfig,
            interim_results=True,
            single_utterance=False,
        )

        import queue as threadQueue
        audioInQ: threadQueue.Queue = threadQueue.Queue()
        resultOutQ: asyncio.Queue = asyncio.Queue()

        async def _pumpAudioToThread():
            try:
                while True:
                    item = await audioQueue.get()
                    audioInQ.put(item)
                    if item[1]:
                        return
            except asyncio.CancelledError:
                audioInQ.put((b"", True))

        def _requestGenerator():
            nonlocal totalAudioBytes
            while True:
                try:
                    chunk, isLast = audioInQ.get(timeout=30.0)
                except threadQueue.Empty:
                    return
                if isLast or not chunk:
                    return
                totalAudioBytes += len(chunk)
                yield speech.StreamingRecognizeRequest(audio_content=chunk)

        def _runStreamingInThread():
            try:
                responseStream = self.speech_client.streaming_recognize(
                    config=streamingConfig,
                    requests=_requestGenerator(),
                )
                for response in responseStream:
                    elapsed = time.time() - streamStartTs
                    estimatedDurationSec = totalAudioBytes / (48000 * 1 * 2) if totalAudioBytes else 0

                    finalTexts = []
                    interimTexts = []
                    lastFinalConfidence = 0.0

                    for result in response.results:
                        alt = result.alternatives[0] if result.alternatives else None
                        if not alt or not alt.transcript.strip():
                            continue
                        if result.is_final:
                            finalTexts.append(alt.transcript.strip())
                            lastFinalConfidence = alt.confidence
                        else:
                            interimTexts.append(alt.transcript.strip())

                    for ft in finalTexts:
                        asyncio.run_coroutine_threadsafe(resultOutQ.put({
                            "isFinal": True,
                            "transcript": ft,
                            "confidence": lastFinalConfidence,
                            "stabilityScore": 0.0,
                            "audioDurationSec": estimatedDurationSec,
                        }), loop)

                    if interimTexts:
                        combined = " ".join(interimTexts)
                        asyncio.run_coroutine_threadsafe(resultOutQ.put({
                            "isFinal": False,
                            "transcript": combined,
                            "confidence": 0.0,
                            "stabilityScore": 0.0,
                            "audioDurationSec": estimatedDurationSec,
                        }), loop)
                    if elapsed >= STREAM_LIMIT_SEC:
                        logger.info("Streaming STT approaching 5-min limit, client should reconnect")
                        asyncio.run_coroutine_threadsafe(resultOutQ.put({
                            "isFinal": False, "transcript": "", "confidence": 0.0,
                            "reconnectRequired": True, "audioDurationSec": 0,
                        }), loop)
                        return
            except Exception as e:
                logger.error(f"Google Streaming STT error: {e}")
                asyncio.run_coroutine_threadsafe(resultOutQ.put({
                    "error": str(e),
                }), loop)
            finally:
                asyncio.run_coroutine_threadsafe(resultOutQ.put(None), loop)

        loop = asyncio.get_running_loop()
        pumpTask = asyncio.ensure_future(_pumpAudioToThread())
        streamFuture = loop.run_in_executor(None, _runStreamingInThread)

        try:
            while True:
                item = await resultOutQ.get()
                if item is None:
                    break
                if "error" in item:
                    raise RuntimeError(item["error"])
                yield item
        finally:
            pumpTask.cancel()
            await asyncio.shield(streamFuture)

    def calculateSttCostCHF(self, audioDurationSec: float) -> float:
        """Google STT cost: ~$0.016/min (standard model)."""
        return round((audioDurationSec / 60.0) * 0.016, 8)

    def calculateTtsCostCHF(self, characterCount: int) -> float:
        """Google TTS WaveNet cost: ~$0.000004/char."""
        return round(characterCount * 0.000004, 8)

    async def translateText(self, text: str, targetLanguage: str = "en",
                            sourceLanguage: str = "de") -> Dict:
        """
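A hedged consumer sketch for `streamingRecognize` (the chunk source is a placeholder; in the real flow audio would arrive incrementally from a WebSocket as WebM/Opus):

```python
import asyncio

async def transcribe(connector, chunks: list):
    audioQueue: asyncio.Queue = asyncio.Queue()
    for chunk in chunks:                   # placeholder: bytes of WebM/Opus audio
        await audioQueue.put((chunk, False))
    await audioQueue.put((b"", True))      # (b"", True) signals end of stream

    async for result in connector.streamingRecognize(audioQueue, language="de-DE"):
        if result.get("reconnectRequired"):
            break                          # ~5-min API limit reached; reopen a new stream
        if result["isFinal"]:
            print("final:", result["transcript"], result["confidence"])
        else:
            print("interim:", result["transcript"])
```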
@ -1,4 +1,3 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.

"""FTP/SFTP Provider Connector stub."""
48  modules/connectors/providerFtp/connectorFtp.py  Normal file
@ -0,0 +1,48 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""FTP/SFTP ProviderConnector stub.

Implements the ProviderConnector interface for FTP/SFTP file access.
Full implementation follows when FTP integration is prioritized.
"""

import logging
from typing import List, Optional

from modules.connectors.connectorProviderBase import ProviderConnector, ServiceAdapter
from modules.datamodels.datamodelDataSource import ExternalEntry

logger = logging.getLogger(__name__)


class FtpFilesAdapter(ServiceAdapter):
    """FTP files ServiceAdapter (stub)."""

    def __init__(self, accessToken: str):
        self._accessToken = accessToken

    async def browse(self, path: str, filter: Optional[str] = None) -> List[ExternalEntry]:
        logger.info(f"FTP browse stub: {path}")
        return []

    async def download(self, path: str) -> bytes:
        logger.info(f"FTP download stub: {path}")
        return b""

    async def upload(self, path: str, data: bytes, fileName: str) -> dict:
        return {"error": "FTP upload not yet implemented"}

    async def search(self, query: str, path: Optional[str] = None) -> List[ExternalEntry]:
        return []


class FtpConnector(ProviderConnector):
    """FTP ProviderConnector -- 1 connection -> files."""

    def getAvailableServices(self) -> List[str]:
        return ["files"]

    def getServiceAdapter(self, service: str) -> ServiceAdapter:
        if service != "files":
            raise ValueError(f"FTP only supports 'files' service, got '{service}'")
        return FtpFilesAdapter(self.accessToken)
3  modules/connectors/providerGoogle/__init__.py  Normal file
@ -0,0 +1,3 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Google Provider Connector -- 1 Connection : n Services (Drive, Gmail)."""
265  modules/connectors/providerGoogle/connectorGoogle.py  Normal file
@ -0,0 +1,265 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Google ProviderConnector -- Drive and Gmail via Google OAuth."""

import logging
from typing import Any, Dict, List, Optional

import aiohttp

from modules.connectors.connectorProviderBase import ProviderConnector, ServiceAdapter, DownloadResult
from modules.datamodels.datamodelDataSource import ExternalEntry

logger = logging.getLogger(__name__)

_DRIVE_BASE = "https://www.googleapis.com/drive/v3"
_GMAIL_BASE = "https://gmail.googleapis.com/gmail/v1"


async def _googleGet(token: str, url: str) -> Dict[str, Any]:
    headers = {"Authorization": f"Bearer {token}"}
    timeout = aiohttp.ClientTimeout(total=20)
    try:
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.get(url, headers=headers) as resp:
                if resp.status in (200, 201):
                    return await resp.json()
                errorText = await resp.text()
                logger.warning(f"Google API {resp.status}: {errorText[:300]}")
                return {"error": f"{resp.status}: {errorText[:200]}"}
    except Exception as e:
        return {"error": str(e)}


class DriveAdapter(ServiceAdapter):
    """Google Drive ServiceAdapter -- browse files and folders."""

    def __init__(self, accessToken: str):
        self._token = accessToken

    async def browse(self, path: str, filter: Optional[str] = None) -> List[ExternalEntry]:
        folderId = (path or "").strip("/") or "root"
        query = f"'{folderId}' in parents and trashed=false"
        fields = "files(id,name,mimeType,size,modifiedTime,parents)"
        url = f"{_DRIVE_BASE}/files?q={query}&fields={fields}&pageSize=100&orderBy=folder,name"

        result = await _googleGet(self._token, url)
        if "error" in result:
            logger.warning(f"Google Drive browse failed: {result['error']}")
            return []

        entries = []
        for f in result.get("files", []):
            isFolder = f.get("mimeType") == "application/vnd.google-apps.folder"
            entries.append(ExternalEntry(
                name=f.get("name", ""),
                path=f"/{f.get('id', '')}",
                isFolder=isFolder,
                size=int(f.get("size", 0)) if f.get("size") else None,
                mimeType=f.get("mimeType") if not isFolder else None,
                metadata={"id": f.get("id"), "modifiedTime": f.get("modifiedTime")},
            ))
        return entries

    _EXPORT_MIME_MAP = {
        "application/vnd.google-apps.document": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        "application/vnd.google-apps.spreadsheet": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        "application/vnd.google-apps.presentation": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
        "application/vnd.google-apps.drawing": "application/pdf",
    }

    async def download(self, path: str) -> bytes:
        fileId = (path or "").strip("/")
        if not fileId:
            return b""
        headers = {"Authorization": f"Bearer {self._token}"}
        timeout = aiohttp.ClientTimeout(total=60)
        try:
            async with aiohttp.ClientSession(timeout=timeout) as session:
                # Try direct download first
                url = f"{_DRIVE_BASE}/files/{fileId}?alt=media"
                async with session.get(url, headers=headers) as resp:
                    if resp.status == 200:
                        return await resp.read()
                    logger.debug(f"Google Drive direct download returned {resp.status} for {fileId}")

                # If 403/404, check if it's a native Google file that needs export
                metaUrl = f"{_DRIVE_BASE}/files/{fileId}?fields=mimeType,name"
                async with session.get(metaUrl, headers=headers) as metaResp:
                    if metaResp.status != 200:
                        logger.warning(f"Google Drive metadata fetch failed ({metaResp.status}) for {fileId}")
                        return b""
                    meta = await metaResp.json()
                fileMime = meta.get("mimeType", "")
                fileName = meta.get("name", fileId)

                exportMime = self._EXPORT_MIME_MAP.get(fileMime)
                if not exportMime:
                    logger.warning(f"Google Drive: unsupported mimeType '{fileMime}' for file '{fileName}' ({fileId})")
                    return b""

                exportUrl = f"{_DRIVE_BASE}/files/{fileId}/export?mimeType={exportMime}"
                logger.info(f"Google Drive: exporting '{fileName}' as {exportMime}")
                async with session.get(exportUrl, headers=headers) as exportResp:
                    if exportResp.status == 200:
                        return await exportResp.read()
                    logger.warning(f"Google Drive export failed ({exportResp.status}) for '{fileName}'")
        except Exception as e:
            logger.error(f"Google Drive download failed for {fileId}: {e}")
        return b""

    async def upload(self, path: str, data: bytes, fileName: str) -> dict:
        return {"error": "Google Drive upload not yet implemented"}

    async def search(self, query: str, path: Optional[str] = None) -> List[ExternalEntry]:
        safeQuery = query.replace("'", "\\'")
        folderId = (path or "").strip("/")
        qParts = [f"name contains '{safeQuery}'", "trashed=false"]
        if folderId:
            qParts.append(f"'{folderId}' in parents")
        qStr = " and ".join(qParts)
        url = f"{_DRIVE_BASE}/files?q={qStr}&fields=files(id,name,mimeType,size)&pageSize=25"
        logger.debug(f"Google Drive search: q={qStr}")
        result = await _googleGet(self._token, url)
        if "error" in result:
            return []
        return [
            ExternalEntry(
                name=f.get("name", ""),
                path=f"/{f.get('id', '')}",
                isFolder=f.get("mimeType") == "application/vnd.google-apps.folder",
                size=int(f.get("size", 0)) if f.get("size") else None,
            )
            for f in result.get("files", [])
        ]


class GmailAdapter(ServiceAdapter):
    """Gmail ServiceAdapter -- browse labels and messages."""

    def __init__(self, accessToken: str):
        self._token = accessToken

    async def browse(self, path: str, filter: Optional[str] = None) -> list:
        cleanPath = (path or "").strip("/")

        if not cleanPath:
            url = f"{_GMAIL_BASE}/users/me/labels"
            result = await _googleGet(self._token, url)
            if "error" in result:
                logger.warning(f"Gmail labels failed: {result['error']}")
                return []
            _SYSTEM_LABELS = {"INBOX", "SENT", "DRAFT", "TRASH", "SPAM", "STARRED", "IMPORTANT"}
            labels = []
            for lbl in result.get("labels", []):
                labelId = lbl.get("id", "")
                labelName = lbl.get("name", labelId)
                if lbl.get("type") == "system" and labelId not in _SYSTEM_LABELS:
                    continue
                labels.append(ExternalEntry(
                    name=labelName,
                    path=f"/{labelId}",
                    isFolder=True,
                    metadata={"id": labelId, "type": lbl.get("type", "")},
                ))
            labels.sort(key=lambda e: (0 if e.metadata.get("type") == "system" else 1, e.name))
            return labels

        url = f"{_GMAIL_BASE}/users/me/messages?labelIds={cleanPath}&maxResults=25"
        result = await _googleGet(self._token, url)
        if "error" in result:
            return []

        entries = []
        for msg in result.get("messages", [])[:25]:
            msgId = msg.get("id", "")
            detailUrl = f"{_GMAIL_BASE}/users/me/messages/{msgId}?format=metadata&metadataHeaders=Subject&metadataHeaders=From&metadataHeaders=Date"
            detail = await _googleGet(self._token, detailUrl)
            if "error" in detail:
                entries.append(ExternalEntry(name=f"Message {msgId}", path=f"/{cleanPath}/{msgId}", isFolder=False))
                continue
            headers = {h.get("name", ""): h.get("value", "") for h in detail.get("payload", {}).get("headers", [])}
            entries.append(ExternalEntry(
                name=headers.get("Subject", "(no subject)"),
                path=f"/{cleanPath}/{msgId}",
                isFolder=False,
                metadata={
                    "id": msgId,
                    "from": headers.get("From", ""),
                    "date": headers.get("Date", ""),
                    "snippet": detail.get("snippet", ""),
                },
            ))
        return entries

    async def download(self, path: str) -> DownloadResult:
        """Download a Gmail message as RFC 822 EML via format=raw."""
        import base64
        import re
        cleanPath = (path or "").strip("/")
        msgId = cleanPath.split("/")[-1] if cleanPath else ""
        if not msgId:
            return DownloadResult()

        url = f"{_GMAIL_BASE}/users/me/messages/{msgId}?format=raw"
        result = await _googleGet(self._token, url)
        if "error" in result:
            return DownloadResult()

        rawB64 = result.get("raw", "")
        if not rawB64:
            return DownloadResult()

        emlBytes = base64.urlsafe_b64decode(rawB64)

        metaUrl = f"{_GMAIL_BASE}/users/me/messages/{msgId}?format=metadata&metadataHeaders=Subject"
        meta = await _googleGet(self._token, metaUrl)
        subject = msgId
        if "error" not in meta:
            for h in meta.get("payload", {}).get("headers", []):
                if h.get("name", "").lower() == "subject":
                    subject = h.get("value", msgId)
                    break
        safeName = re.sub(r'[<>:"/\\|?*\x00-\x1f]', "_", subject)[:80].strip(". ") or "email"

        return DownloadResult(
            data=emlBytes,
            fileName=f"{safeName}.eml",
            mimeType="message/rfc822",
        )

    async def upload(self, path: str, data: bytes, fileName: str) -> dict:
        return {"error": "Gmail upload not applicable"}

    async def search(self, query: str, path: Optional[str] = None) -> list:
        url = f"{_GMAIL_BASE}/users/me/messages?q={query}&maxResults=10"
        result = await _googleGet(self._token, url)
        if "error" in result:
            return []
        return [
            ExternalEntry(
                name=f"Message {m.get('id', '')}",
                path=f"/{m.get('id', '')}",
                isFolder=False,
                metadata={"id": m.get("id")},
            )
            for m in result.get("messages", [])
        ]


class GoogleConnector(ProviderConnector):
    """Google ProviderConnector -- 1 connection -> Drive + Gmail."""

    _SERVICE_MAP = {
        "drive": DriveAdapter,
        "gmail": GmailAdapter,
    }

    def getAvailableServices(self) -> List[str]:
        return list(self._SERVICE_MAP.keys())

    def getServiceAdapter(self, service: str) -> ServiceAdapter:
        adapterClass = self._SERVICE_MAP.get(service)
        if not adapterClass:
            raise ValueError(f"Unknown Google service: {service}. Available: {list(self._SERVICE_MAP.keys())}")
        return adapterClass(self.accessToken)
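A hedged usage sketch for the Google provider (the connection object and access token are placeholders obtained from the OAuth flow):

```python
async def demoGoogleDrive(connection, accessToken: str):
    connector = GoogleConnector(connection, accessToken)
    drive = connector.getServiceAdapter("drive")

    rootEntries = await drive.browse("/")          # "/" maps to the 'root' folder id
    docs = await drive.search("quarterly report")  # name-contains search, invented query

    if docs:
        # Native Google Docs/Sheets/Slides are transparently exported via
        # _EXPORT_MIME_MAP when direct download is not possible.
        content = await drive.download(docs[0].path)
        print(docs[0].name, len(content), "bytes")
```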
3  modules/connectors/providerMsft/__init__.py  Normal file
@ -0,0 +1,3 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Microsoft Provider Connector -- 1 Connection : n Services (SharePoint, Outlook, Teams, OneDrive)."""
469  modules/connectors/providerMsft/connectorMsft.py  Normal file
@ -0,0 +1,469 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Microsoft ProviderConnector -- one MSFT connection serves SharePoint, Outlook, Teams, OneDrive.

All ServiceAdapters share the same OAuth access token obtained from the
UserConnection (authority=msft).
"""

import logging
import aiohttp
import asyncio
from typing import Dict, Any, List, Optional

from modules.connectors.connectorProviderBase import ProviderConnector, ServiceAdapter, DownloadResult
from modules.datamodels.datamodelDataSource import ExternalEntry

logger = logging.getLogger(__name__)

_GRAPH_BASE = "https://graph.microsoft.com/v1.0"


class _GraphApiMixin:
    """Shared Graph API call logic for all MSFT service adapters."""

    def __init__(self, accessToken: str):
        self._accessToken = accessToken

    async def _graphGet(self, endpoint: str) -> Dict[str, Any]:
        return await _makeGraphCall(self._accessToken, endpoint, "GET")

    async def _graphPost(self, endpoint: str, data: Any = None) -> Dict[str, Any]:
        return await _makeGraphCall(self._accessToken, endpoint, "POST", data)

    async def _graphPut(self, endpoint: str, data: bytes = None) -> Dict[str, Any]:
        return await _makeGraphCall(self._accessToken, endpoint, "PUT", data)

    async def _graphDelete(self, endpoint: str) -> Dict[str, Any]:
        return await _makeGraphCall(self._accessToken, endpoint, "DELETE")

    async def _graphDownload(self, endpoint: str) -> Optional[bytes]:
        """Download binary content from Graph API."""
        headers = {"Authorization": f"Bearer {self._accessToken}"}
        timeout = aiohttp.ClientTimeout(total=60)
        url = f"{_GRAPH_BASE}/{endpoint.lstrip('/')}"
        try:
            async with aiohttp.ClientSession(timeout=timeout) as session:
                async with session.get(url, headers=headers) as resp:
                    if resp.status == 200:
                        return await resp.read()
                    logger.error(f"Download failed {resp.status}: {await resp.text()}")
                    return None
        except Exception as e:
            logger.error(f"Graph download error: {e}")
            return None


async def _makeGraphCall(
    token: str, endpoint: str, method: str = "GET", data: Any = None
) -> Dict[str, Any]:
    """Execute a single Microsoft Graph API call."""
    url = f"{_GRAPH_BASE}/{endpoint.lstrip('/')}"
    contentType = "application/json"
    if method == "PUT" and isinstance(data, bytes):
        contentType = "application/octet-stream"
    headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": contentType,
    }
    timeout = aiohttp.ClientTimeout(total=30)
    try:
        async with aiohttp.ClientSession(timeout=timeout) as session:
            kwargs: Dict[str, Any] = {"headers": headers}
            if data is not None:
                kwargs["data"] = data

            if method == "GET":
                async with session.get(url, **kwargs) as resp:
                    return await _handleResponse(resp)
            elif method == "POST":
                async with session.post(url, **kwargs) as resp:
                    return await _handleResponse(resp)
            elif method == "PUT":
                async with session.put(url, **kwargs) as resp:
                    return await _handleResponse(resp)
            elif method == "DELETE":
                async with session.delete(url, **kwargs) as resp:
                    if resp.status in (200, 204):
                        return {}
                    return await _handleResponse(resp)

    except asyncio.TimeoutError:
        return {"error": f"Graph API timeout: {endpoint}"}
    except Exception as e:
        return {"error": f"Graph API error: {e}"}

    return {"error": f"Unsupported method: {method}"}


async def _handleResponse(resp: aiohttp.ClientResponse) -> Dict[str, Any]:
    if resp.status in (200, 201):
        return await resp.json()
    errorText = await resp.text()
    logger.error(f"Graph API {resp.status}: {errorText}")
    return {"error": f"{resp.status}: {errorText}"}


def _graphItemToExternalEntry(item: Dict[str, Any], basePath: str = "") -> ExternalEntry:
    isFolder = "folder" in item
    return ExternalEntry(
        name=item.get("name", ""),
        path=f"{basePath}/{item.get('name', '')}" if basePath else item.get("name", ""),
        isFolder=isFolder,
        size=item.get("size"),
        mimeType=item.get("file", {}).get("mimeType") if not isFolder else None,
        lastModified=None,
        metadata={
            "id": item.get("id"),
            "webUrl": item.get("webUrl"),
            "childCount": item.get("folder", {}).get("childCount") if isFolder else None,
        },
    )
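To make the Graph-to-ExternalEntry mapping concrete, a small sketch with an invented driveItem payload, shaped like the fields the mapper reads:

```python
item = {
    "id": "01ABC",                 # invented Graph item id
    "name": "Reports",
    "size": 0,
    "webUrl": "https://contoso.sharepoint.com/...",
    "folder": {"childCount": 4},   # presence of "folder" marks it as a folder
}
entry = _graphItemToExternalEntry(item, basePath="/sites/contoso")
# entry.name == "Reports", entry.isFolder is True,
# entry.path == "/sites/contoso/Reports", entry.metadata["childCount"] == 4
```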
# ---------------------------------------------------------------------------
# SharePoint Adapter
# ---------------------------------------------------------------------------
class SharepointAdapter(_GraphApiMixin, ServiceAdapter):
    """ServiceAdapter for SharePoint (files, sites) via Microsoft Graph."""

    async def browse(self, path: str, filter: Optional[str] = None) -> List[ExternalEntry]:
        """List items in a SharePoint folder.

        Path format: /sites/<SiteName>/<FolderPath>
        Root "/" lists available sites via discovery.
        """
        if not path or path == "/":
            return await self._discoverSites()

        siteId, folderPath = _parseSharepointPath(path)
        if not siteId:
            return await self._discoverSites()

        if not folderPath or folderPath == "/":
            endpoint = f"sites/{siteId}/drive/root/children"
        else:
            cleanPath = folderPath.lstrip("/")
            endpoint = f"sites/{siteId}/drive/root:/{cleanPath}:/children"

        result = await self._graphGet(endpoint)
        if "error" in result:
            logger.warning(f"SharePoint browse failed: {result['error']}")
            return []

        entries = [_graphItemToExternalEntry(item, path) for item in result.get("value", [])]
        if filter:
            entries = [e for e in entries if _matchFilter(e, filter)]
        return entries

    async def _discoverSites(self) -> List[ExternalEntry]:
        """Discover accessible SharePoint sites."""
        result = await self._graphGet("sites?search=*&$top=50")
        if "error" in result:
            logger.warning(f"SharePoint site discovery failed: {result['error']}")
            return []
        return [
            ExternalEntry(
                name=s.get("displayName") or s.get("name", ""),
                path=f"/sites/{s.get('id', '')}",
                isFolder=True,
                metadata={
                    "id": s.get("id"),
                    "webUrl": s.get("webUrl"),
                    "description": s.get("description", ""),
                },
            )
            for s in result.get("value", [])
            if s.get("displayName")
        ]

    async def download(self, path: str) -> bytes:
        siteId, filePath = _parseSharepointPath(path)
        if not siteId or not filePath:
            return b""
        cleanPath = filePath.strip("/")
        endpoint = f"sites/{siteId}/drive/root:/{cleanPath}:/content"
        data = await self._graphDownload(endpoint)
        return data or b""

    async def upload(self, path: str, data: bytes, fileName: str) -> dict:
        siteId, folderPath = _parseSharepointPath(path)
        if not siteId:
            return {"error": "Invalid SharePoint path"}
        cleanFolder = (folderPath or "").strip("/")
        uploadPath = f"{cleanFolder}/{fileName}" if cleanFolder else fileName
        endpoint = f"sites/{siteId}/drive/root:/{uploadPath}:/content"
        result = await self._graphPut(endpoint, data)
        return result

    async def search(self, query: str, path: Optional[str] = None) -> List[ExternalEntry]:
        siteId, _ = _parseSharepointPath(path or "")
        if not siteId:
            return []
        safeQuery = query.replace("'", "''")
        endpoint = f"sites/{siteId}/drive/root/search(q='{safeQuery}')"
        result = await self._graphGet(endpoint)
        if "error" in result:
            return []
        return [_graphItemToExternalEntry(item) for item in result.get("value", [])]


# ---------------------------------------------------------------------------
# Outlook Adapter
# ---------------------------------------------------------------------------

class OutlookAdapter(_GraphApiMixin, ServiceAdapter):
    """ServiceAdapter for Outlook (mail, calendar) via Microsoft Graph."""

    async def browse(self, path: str, filter: Optional[str] = None) -> List[ExternalEntry]:
        """List mail folders or messages.

        path = "" or "/" → list mail folders
        path = "/Inbox"  → list messages in Inbox
        """
        if not path or path == "/":
            result = await self._graphGet("me/mailFolders")
            if "error" in result:
                return []
            return [
                ExternalEntry(
                    name=f.get("displayName", ""),
                    path=f"/{f.get('id', '')}",
                    isFolder=True,
                    metadata={"id": f.get("id"), "totalItemCount": f.get("totalItemCount")},
                )
                for f in result.get("value", [])
            ]

        folderId = path.strip("/")
        endpoint = f"me/mailFolders/{folderId}/messages?$top=25&$orderby=receivedDateTime desc"
        result = await self._graphGet(endpoint)
        if "error" in result:
            return []
        return [
            ExternalEntry(
                name=m.get("subject", "(no subject)"),
                path=f"{path}/{m.get('id', '')}",
                isFolder=False,
                metadata={
                    "id": m.get("id"),
                    "from": m.get("from", {}).get("emailAddress", {}).get("address"),
                    "receivedDateTime": m.get("receivedDateTime"),
                    "hasAttachments": m.get("hasAttachments", False),
                },
            )
            for m in result.get("value", [])
        ]

    async def download(self, path: str) -> DownloadResult:
        """Download a mail message as RFC 822 EML via Graph API $value endpoint."""
        import re
        messageId = path.strip("/").split("/")[-1]

        meta = await self._graphGet(f"me/messages/{messageId}?$select=subject")
        subject = meta.get("subject", messageId) if "error" not in meta else messageId
        safeName = re.sub(r'[<>:"/\\|?*\x00-\x1f]', "_", subject)[:80].strip(". ") or "email"

        emlBytes = await self._graphDownload(f"me/messages/{messageId}/$value")
        if not emlBytes:
            return DownloadResult()

        return DownloadResult(
            data=emlBytes,
            fileName=f"{safeName}.eml",
            mimeType="message/rfc822",
        )

    async def upload(self, path: str, data: bytes, fileName: str) -> dict:
        """Not applicable for Outlook in the file sense."""
        return {"error": "Upload not supported for Outlook"}

    async def search(self, query: str, path: Optional[str] = None) -> List[ExternalEntry]:
        safeQuery = query.replace("'", "''")
        endpoint = f"me/messages?$search=\"{safeQuery}\"&$top=25"
        result = await self._graphGet(endpoint)
        if "error" in result:
            return []
        return [
            ExternalEntry(
                name=m.get("subject", "(no subject)"),
                path=f"/search/{m.get('id', '')}",
                isFolder=False,
                metadata={
                    "id": m.get("id"),
                    "from": m.get("from", {}).get("emailAddress", {}).get("address"),
                    "receivedDateTime": m.get("receivedDateTime"),
                },
            )
            for m in result.get("value", [])
        ]

    async def sendMail(
        self, to: List[str], subject: str, body: str,
        cc: Optional[List[str]] = None, attachments: Optional[List[Dict]] = None
    ) -> Dict[str, Any]:
        """Send an email via Microsoft Graph."""
        import json
        message: Dict[str, Any] = {
            "subject": subject,
            "body": {"contentType": "Text", "content": body},
            "toRecipients": [{"emailAddress": {"address": addr}} for addr in to],
        }
        if cc:
            message["ccRecipients"] = [{"emailAddress": {"address": addr}} for addr in cc]

        payload = json.dumps({"message": message, "saveToSentItems": True}).encode("utf-8")
        result = await self._graphPost("me/sendMail", payload)
        if "error" in result:
            return result
        return {"success": True}
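A hedged `sendMail` sketch (addresses and texts are invented; the adapter instance would come from `MsftConnector.getServiceAdapter("outlook")` below):

```python
async def notifyTeam(outlook: OutlookAdapter):
    result = await outlook.sendMail(
        to=["a.example@contoso.com"],            # invented recipients
        subject="Pipeline finished",
        body="All documents were indexed.",
        cc=["lead.example@contoso.com"],
    )
    if "error" in result:
        logger.warning(f"sendMail failed: {result['error']}")
```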
# ---------------------------------------------------------------------------
|
||||
# Teams Adapter (Stub)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TeamsAdapter(_GraphApiMixin, ServiceAdapter):
|
||||
"""ServiceAdapter for Microsoft Teams -- browse joined teams and channels."""
|
||||
|
||||
async def browse(self, path: str, filter: Optional[str] = None) -> list:
|
||||
cleanPath = (path or "").strip("/")
|
||||
|
||||
if not cleanPath:
|
||||
result = await self._graphGet("me/joinedTeams")
|
||||
if "error" in result:
|
||||
logger.warning(f"Teams browse failed: {result['error']}")
|
||||
return []
|
||||
return [
|
||||
ExternalEntry(
|
||||
name=t.get("displayName", ""),
|
||||
path=f"/{t.get('id', '')}",
|
||||
isFolder=True,
|
||||
metadata={"id": t.get("id"), "description": t.get("description", "")},
|
||||
)
|
||||
for t in result.get("value", [])
|
||||
]
|
||||
|
||||
parts = cleanPath.split("/", 1)
|
||||
teamId = parts[0]
|
||||
if len(parts) == 1:
|
||||
result = await self._graphGet(f"teams/{teamId}/channels")
|
||||
if "error" in result:
|
||||
return []
|
||||
return [
|
||||
ExternalEntry(
|
||||
name=ch.get("displayName", ""),
|
||||
path=f"/{teamId}/{ch.get('id', '')}",
|
||||
isFolder=True,
|
||||
metadata={"id": ch.get("id"), "membershipType": ch.get("membershipType", "")},
|
||||
)
|
||||
for ch in result.get("value", [])
|
||||
]
|
||||
|
||||
return []
|
||||
|
||||
async def download(self, path: str) -> bytes:
|
||||
return b""
|
||||
|
||||
async def upload(self, path: str, data: bytes, fileName: str) -> dict:
|
||||
return {"error": "Teams upload not implemented"}
|
||||
|
||||
async def search(self, query: str, path: Optional[str] = None) -> list:
|
||||
return []
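
# Illustrative browse sketch for the path convention above: "" lists joined
# teams, "/{teamId}" lists that team's channels. The token is a placeholder.
async def _exampleListChannels(accessToken: str) -> None:
    adapter = TeamsAdapter(accessToken)
    teams = await adapter.browse("")                    # top level: joined teams
    if teams:
        channels = await adapter.browse(teams[0].path)  # "/{teamId}"
        for ch in channels:
            print(ch.name, ch.metadata.get("membershipType"))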


# ---------------------------------------------------------------------------
# OneDrive Adapter (Stub -- similar to SharePoint but personal drive)
# ---------------------------------------------------------------------------

class OneDriveAdapter(_GraphApiMixin, ServiceAdapter):
    """ServiceAdapter stub for OneDrive (personal drive)."""

    async def browse(self, path: str, filter: Optional[str] = None) -> List[ExternalEntry]:
        cleanPath = (path or "").strip("/")
        if not cleanPath:
            endpoint = "me/drive/root/children"
        else:
            endpoint = f"me/drive/root:/{cleanPath}:/children"

        result = await self._graphGet(endpoint)
        if "error" in result:
            return []
        entries = [_graphItemToExternalEntry(item, path) for item in result.get("value", [])]
        if filter:
            entries = [e for e in entries if _matchFilter(e, filter)]
        return entries

    async def download(self, path: str) -> bytes:
        cleanPath = (path or "").strip("/")
        if not cleanPath:
            return b""
        data = await self._graphDownload(f"me/drive/root:/{cleanPath}:/content")
        return data or b""

    async def upload(self, path: str, data: bytes, fileName: str) -> dict:
        cleanPath = (path or "").strip("/")
        uploadPath = f"{cleanPath}/{fileName}" if cleanPath else fileName
        endpoint = f"me/drive/root:/{uploadPath}:/content"
        return await self._graphPut(endpoint, data)

    async def search(self, query: str, path: Optional[str] = None) -> List[ExternalEntry]:
        safeQuery = query.replace("'", "''")
        endpoint = f"me/drive/root/search(q='{safeQuery}')"
        result = await self._graphGet(endpoint)
        if "error" in result:
            return []
        return [_graphItemToExternalEntry(item) for item in result.get("value", [])]
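
# Illustrative round trip through OneDriveAdapter: upload into a folder, then
# find the file again via search. Token, folder, and file name are placeholders.
async def _exampleOneDriveRoundTrip(accessToken: str) -> None:
    adapter = OneDriveAdapter(accessToken)
    await adapter.upload("Reports", b"hello", "notes.txt")  # PUT me/drive/root:/Reports/notes.txt:/content
    hits = await adapter.search("notes")
    print([e.path for e in hits])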


# ---------------------------------------------------------------------------
# MsftConnector (1:n)
# ---------------------------------------------------------------------------

class MsftConnector(ProviderConnector):
    """Microsoft ProviderConnector -- 1 connection → n services."""

    _SERVICE_MAP = {
        "sharepoint": SharepointAdapter,
        "outlook": OutlookAdapter,
        "teams": TeamsAdapter,
        "onedrive": OneDriveAdapter,
    }

    def getAvailableServices(self) -> List[str]:
        return list(self._SERVICE_MAP.keys())

    def getServiceAdapter(self, service: str) -> ServiceAdapter:
        adapterClass = self._SERVICE_MAP.get(service)
        if not adapterClass:
            raise ValueError(f"Unknown MSFT service: {service}. Available: {list(self._SERVICE_MAP.keys())}")
        return adapterClass(self.accessToken)
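
# Illustrative sketch of the 1:n pattern above: one authenticated connection
# hands out an independent adapter per service.
def _exampleAdapters(connector: MsftConnector) -> None:
    for service in connector.getAvailableServices():  # sharepoint, outlook, teams, onedrive
        adapter = connector.getServiceAdapter(service)
        print(service, type(adapter).__name__)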


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _parseSharepointPath(path: str) -> tuple:
    """Parse a SharePoint path into (siteId, innerPath).

    Expected format: /sites/<siteId>/<innerPath>
    Also accepts bare siteId if no /sites/ prefix.
    """
    if not path:
        return ("", "")
    clean = path.strip("/")
    if clean.startswith("sites/"):
        parts = clean.split("/", 2)
        siteId = parts[1] if len(parts) > 1 else ""
        innerPath = parts[2] if len(parts) > 2 else ""
        return (siteId, innerPath)
    parts = clean.split("/", 1)
    return (parts[0], parts[1] if len(parts) > 1 else "")
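
# Expected behaviour of _parseSharepointPath, as executable examples:
def _exampleParse() -> None:
    assert _parseSharepointPath("/sites/abc123/Docs/Reports") == ("abc123", "Docs/Reports")
    assert _parseSharepointPath("abc123/Reports") == ("abc123", "Reports")  # bare siteId form
    assert _parseSharepointPath("") == ("", "")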


def _matchFilter(entry: ExternalEntry, pattern: str) -> bool:
    """Simple glob-like filter (supports * wildcard)."""
    import fnmatch
    return fnmatch.fnmatch(entry.name.lower(), pattern.lower())

@ -26,6 +26,12 @@ class OperationTypeEnum(str, Enum):
    WEB_SEARCH_DATA = "webSearch"  # Returns list of URLs only
    WEB_CRAWL = "webCrawl"  # Web crawl for a given URL

    # Agent Operations
    AGENT = "agent"  # Agent loop: reasoning + tool use

    # Embedding Operations
    EMBEDDING = "embedding"  # Text → vector conversion for semantic search

    # Speech Operations (dedicated pipeline, bypasses standard model selection)
    SPEECH_TEAMS = "speechTeams"  # Teams Meeting AI analysis: decide if/how to respond

@ -102,6 +108,7 @@ class AiModel(BaseModel):

    # Function reference (not serialized)
    functionCall: Optional[Callable] = Field(default=None, exclude=True, description="Function to call for this model")
    functionCallStream: Optional[Callable] = Field(default=None, exclude=True, description="Streaming function: yields str deltas, then final AiModelResponse")
    calculatepriceCHF: Optional[Callable] = Field(default=None, exclude=True, description="Function to calculate price in CHF")

    # Selection criteria - capabilities with ratings

@ -155,10 +162,12 @@ class AiCallOptions(BaseModel):
class AiCallRequest(BaseModel):
    """Centralized AI call request payload for interface use."""

    prompt: str = Field(description="The user prompt")
    prompt: str = Field(default="", description="The user prompt")
    context: Optional[str] = Field(default=None, description="Optional external context (e.g., extracted docs)")
    options: AiCallOptions = Field(default_factory=AiCallOptions)
    contentParts: Optional[List['ContentPart']] = None  # NEW: Content parts for model-aware chunking
    contentParts: Optional[List['ContentPart']] = None  # Content parts for model-aware chunking
    messages: Optional[List[Dict[str, Any]]] = Field(default=None, description="OpenAI-style messages for multi-turn agent conversations")
    tools: Optional[List[Dict[str, Any]]] = Field(default=None, description="Tool definitions for native function calling")


class AiCallResponse(BaseModel):

@ -172,14 +181,19 @@ class AiCallResponse(BaseModel):
    bytesSent: int = Field(default=0, description="Input data size in bytes")
    bytesReceived: int = Field(default=0, description="Output data size in bytes")
    errorCount: int = Field(default=0, description="0 for success, 1+ for errors")
    toolCalls: Optional[List[Dict[str, Any]]] = Field(default=None, description="Tool calls from native function calling")
    metadata: Optional[Dict[str, Any]] = Field(default=None, description="Additional response metadata (e.g. embeddings vectors)")


class AiModelCall(BaseModel):
    """Standardized input for AI model calls."""

    messages: List[Dict[str, Any]] = Field(description="Messages in OpenAI format (role, content)")
    messages: List[Dict[str, Any]] = Field(default_factory=list, description="Messages in OpenAI format (role, content)")
    model: Optional[AiModel] = Field(default=None, description="The AI model being called")
    options: AiCallOptions = Field(default_factory=AiCallOptions, description="Additional model-specific options")
    tools: Optional[List[Dict[str, Any]]] = Field(default=None, description="Tool definitions for native function calling")
    toolChoice: Optional[Any] = Field(default=None, description="Tool choice: 'auto', 'none', or specific tool")
    embeddingInput: Optional[List[str]] = Field(default=None, description="Input texts for embedding models (used instead of messages)")

    model_config = ConfigDict(arbitrary_types_allowed=True)
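
# Illustrative sketch: an agent-style request using the fields above. The tool
# schema content is a made-up example, not part of this change.
_exampleRequest = AiCallRequest(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is 2 + 2?"},
    ],
    tools=[{
        "type": "function",
        "function": {
            "name": "calculator",
            "description": "Evaluate a simple arithmetic expression",
            "parameters": {"type": "object", "properties": {"expr": {"type": "string"}}},
        },
    }],
)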

@ -119,11 +119,17 @@ class BillingTransaction(BaseModel):
    # Context for workflow transactions
    workflowId: Optional[str] = Field(None, description="Workflow ID (for WORKFLOW transactions)")
    featureInstanceId: Optional[str] = Field(None, description="Feature instance ID")
    featureCode: Optional[str] = Field(None, description="Feature code (e.g., chatplayground, automation)")
    featureCode: Optional[str] = Field(None, description="Feature code (e.g., automation)")
    aicoreProvider: Optional[str] = Field(None, description="AICore provider (anthropic, openai, etc.)")
    aicoreModel: Optional[str] = Field(None, description="AICore model name (e.g., claude-4-sonnet, gpt-4o)")
    createdByUserId: Optional[str] = Field(None, description="User who created/caused this transaction")

    # AI call metadata (for per-call analytics)
    processingTime: Optional[float] = Field(None, description="Processing time in seconds")
    bytesSent: Optional[int] = Field(None, description="Bytes sent to AI model")
    bytesReceived: Optional[int] = Field(None, description="Bytes received from AI model")
    errorCount: Optional[int] = Field(None, description="Number of errors in this call")


registerModelLabels(
    "BillingTransaction",

@ -218,7 +224,7 @@ class UsageStatistics(BaseModel):
    # Breakdown by feature
    costByFeature: Dict[str, float] = Field(
        default_factory=dict,
        description="Cost breakdown by feature (e.g., {'chatplayground': 15.00, 'automation': 5.80})"
        description="Cost breakdown by feature (e.g., {'automation': 5.80, 'workspace': 3.20})"
    )

@ -1,6 +1,6 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Chat models: ChatWorkflow, ChatMessage, ChatLog, ChatStat, ChatDocument."""
"""Chat models: ChatWorkflow, ChatMessage, ChatLog, ChatDocument."""

from typing import List, Dict, Any, Optional
from enum import Enum

@ -10,44 +10,6 @@ from modules.shared.timeUtils import getUtcTimestamp
import uuid


class ChatStat(BaseModel):
    """Statistics for chat operations. User-owned, no mandate context."""
    model_config = {"populate_by_name": True, "extra": "allow"}  # Allow DB system fields

    id: str = Field(
        default_factory=lambda: str(uuid.uuid4()), description="Primary key"
    )
    workflowId: Optional[str] = Field(
        None, description="Foreign key to workflow (for workflow stats)"
    )
    processingTime: Optional[float] = Field(
        None, description="Processing time in seconds"
    )
    bytesSent: Optional[int] = Field(None, description="Number of bytes sent")
    bytesReceived: Optional[int] = Field(None, description="Number of bytes received")
    errorCount: Optional[int] = Field(None, description="Number of errors encountered")
    process: Optional[str] = Field(None, description="The process that delivers the stats data (e.g. 'action.outlook.readMails', 'ai.process.document.name')")
    engine: Optional[str] = Field(None, description="The engine used (e.g. 'ai.anthropic.35', 'ai.tavily.basic', 'renderer.docx')")
    priceCHF: Optional[float] = Field(None, description="Calculated price in USD for the operation")


registerModelLabels(
    "ChatStat",
    {"en": "Chat Statistics", "fr": "Statistiques de chat"},
    {
        "id": {"en": "ID", "fr": "ID"},
        "workflowId": {"en": "Workflow ID", "fr": "ID du workflow"},
        "processingTime": {"en": "Processing Time", "fr": "Temps de traitement"},
        "bytesSent": {"en": "Bytes Sent", "fr": "Octets envoyés"},
        "bytesReceived": {"en": "Bytes Received", "fr": "Octets reçus"},
        "errorCount": {"en": "Error Count", "fr": "Nombre d'erreurs"},
        "process": {"en": "Process", "fr": "Processus"},
        "engine": {"en": "Engine", "fr": "Moteur"},
        "priceCHF": {"en": "Price CHF", "fr": "Prix CHF"},
    },
)


class ChatLog(BaseModel):
    """Log entries for chat workflows. User-owned, no mandate context."""
    id: str = Field(

@ -285,8 +247,6 @@ class WorkflowModeEnum(str, Enum):
    WORKFLOW_DYNAMIC = "Dynamic"
    WORKFLOW_AUTOMATION = "Automation"
    WORKFLOW_CHATBOT = "Chatbot"
    WORKFLOW_CODEEDITOR = "CodeEditor"
    WORKFLOW_REACT = "React"  # Legacy mode - kept for backward compatibility


registerModelLabels(

@ -296,8 +256,6 @@ registerModelLabels(
        "WORKFLOW_DYNAMIC": {"en": "Dynamic", "fr": "Dynamique"},
        "WORKFLOW_AUTOMATION": {"en": "Automation", "fr": "Automatisation"},
        "WORKFLOW_CHATBOT": {"en": "Chatbot", "fr": "Chatbot"},
        "WORKFLOW_CODEEDITOR": {"en": "Code Editor", "fr": "Éditeur de code"},
        "WORKFLOW_REACT": {"en": "React (Legacy)", "fr": "React (Hérité)"},
    },
)

@ -322,7 +280,6 @@ class ChatWorkflow(BaseModel):
    startedAt: float = Field(default_factory=getUtcTimestamp, description="When the workflow started (UTC timestamp in seconds)", json_schema_extra={"frontend_type": "timestamp", "frontend_readonly": True, "frontend_required": False})
    logs: List[ChatLog] = Field(default_factory=list, description="Workflow logs", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
    messages: List[ChatMessage] = Field(default_factory=list, description="Messages in the workflow", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
    stats: List[ChatStat] = Field(default_factory=list, description="Workflow statistics list", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
    tasks: list = Field(default_factory=list, description="List of tasks in the workflow", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
    workflowMode: WorkflowModeEnum = Field(default=WorkflowModeEnum.WORKFLOW_DYNAMIC, description="Workflow mode selector", json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": False, "frontend_options": [
        {

@ -337,10 +294,6 @@ class ChatWorkflow(BaseModel):
            "value": WorkflowModeEnum.WORKFLOW_CHATBOT.value,
            "label": {"en": "Chatbot", "fr": "Chatbot"},
        },
        {
            "value": WorkflowModeEnum.WORKFLOW_REACT.value,
            "label": {"en": "React (Legacy)", "fr": "React (Hérité)"},
        },
    ]})
    maxSteps: int = Field(default=10, description="Maximum number of iterations in dynamic mode", json_schema_extra={"frontend_type": "integer", "frontend_readonly": False, "frontend_required": False})
    expectedFormats: Optional[List[str]] = Field(None, description="List of expected file format extensions from user request (e.g., ['xlsx', 'pdf']). Extracted during intent analysis.", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})

58
modules/datamodels/datamodelContent.py
Normal file
@ -0,0 +1,58 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Content Object data models for the container and content extraction pipeline.

Physical layer: Container hierarchy (ZIP, Folder, File)
Logical layer: Scalar content objects (text, image, videostream, audiostream, other)

The entire extraction pipeline up to ContentObjects runs without AI.
"""

from typing import Dict, Any, List, Optional
from pydantic import BaseModel, Field
import uuid


class ContainerLimitError(Exception):
    """Raised when container extraction exceeds safety limits (size, depth, file count)."""
    pass


class ContentContextRef(BaseModel):
    """Reference to the origin context within a container/file."""
    containerPath: str = Field(default="", description="e.g. 'archive.zip/folder-a/report.pdf'")  # default "" keeps bare ContentContextRef() valid as a default_factory below
    location: str = Field(default="", description="e.g. 'page:5/region:bottomLeft'")
    label: Optional[str] = Field(default=None, description="e.g. 'Figure 3: Overview'")
    pageIndex: Optional[int] = Field(default=None, description="Page number (PDF, DOCX)")
    sectionId: Optional[str] = Field(default=None, description="Section/Heading ID")
    sheetName: Optional[str] = Field(default=None, description="Sheet name (XLSX)")
    slideIndex: Optional[int] = Field(default=None, description="Slide number (PPTX)")


class ContentObject(BaseModel):
    """Scalar content object extracted from a file. No AI involved."""
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    fileId: str = Field(description="FK to the physical file")
    contentType: str = Field(description="text, image, videostream, audiostream, other")
    data: str = Field(default="", description="Content data (text, base64, URL)")
    contextRef: ContentContextRef = Field(default_factory=ContentContextRef)
    metadata: Dict[str, Any] = Field(default_factory=dict)
    sequence: int = Field(default=0, description="Order within the context")


class ContentObjectSummary(BaseModel):
    """Compact description of a content object for the FileContentIndex."""
    id: str = Field(description="Content object ID")
    contentType: str = Field(description="text, image, videostream, audiostream, other")
    contextRef: ContentContextRef = Field(default_factory=ContentContextRef)
    charCount: Optional[int] = Field(default=None, description="Only for text")
    dimensions: Optional[str] = Field(default=None, description="Only for image/video (e.g. '1920x1080')")
    duration: Optional[float] = Field(default=None, description="Only for audio/video (seconds)")


class FileEntry(BaseModel):
    """A file extracted from a container (ZIP, TAR, Folder)."""
    path: str = Field(description="Relative path within the container")
    data: bytes = Field(description="File content bytes")
    mimeType: str = Field(description="Detected MIME type")
    size: int = Field(description="File size in bytes")
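
# Illustrative sketch of the physical → logical layering: a FileEntry (physical)
# yields ContentObjects (logical), each pointing back via a ContentContextRef.
# All values below are made-up examples.
_exampleEntry = FileEntry(path="folder-a/report.pdf", data=b"%PDF-", mimeType="application/pdf", size=1024)
_exampleObject = ContentObject(
    fileId="file-123",
    contentType="text",
    data="Quarterly revenue grew by 12%.",
    contextRef=ContentContextRef(
        containerPath="archive.zip/" + _exampleEntry.path,
        location="page:5/region:bottomLeft",
        pageIndex=5,
    ),
)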

58
modules/datamodels/datamodelDataSource.py
Normal file
@ -0,0 +1,58 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""DataSource and ExternalEntry models for external data integration.

DataSource links a UserConnection to an external path (SharePoint folder,
Google Drive folder, FTP directory, etc.) for agent-accessible data containers.
"""

from typing import Dict, Any, Optional
from pydantic import BaseModel, Field
from modules.shared.attributeUtils import registerModelLabels
from modules.shared.timeUtils import getUtcTimestamp
import uuid


class DataSource(BaseModel):
    """Configured external data source linked to a UserConnection."""
    id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key")
    connectionId: str = Field(description="FK to UserConnection")
    sourceType: str = Field(description="sharepointFolder, googleDriveFolder, outlookFolder, ftpFolder")
    path: str = Field(description="External path (e.g. '/sites/MySite/Documents/Reports')")
    label: str = Field(description="User-visible label")
    featureInstanceId: Optional[str] = Field(default=None, description="Scoped to feature instance")
    mandateId: Optional[str] = Field(default=None, description="Mandate scope")
    userId: str = Field(default="", description="Owner user ID")
    autoSync: bool = Field(default=False, description="Automatically sync on schedule")
    lastSynced: Optional[float] = Field(default=None, description="Last sync timestamp")
    createdAt: float = Field(default_factory=getUtcTimestamp, description="Creation timestamp")


registerModelLabels(
    "DataSource",
    {"en": "Data Source", "de": "Datenquelle", "fr": "Source de données"},
    {
        "id": {"en": "ID", "de": "ID", "fr": "ID"},
        "connectionId": {"en": "Connection ID", "de": "Verbindungs-ID", "fr": "ID de connexion"},
        "sourceType": {"en": "Source Type", "de": "Quellentyp", "fr": "Type de source"},
        "path": {"en": "Path", "de": "Pfad", "fr": "Chemin"},
        "label": {"en": "Label", "de": "Bezeichnung", "fr": "Libellé"},
        "featureInstanceId": {"en": "Feature Instance", "de": "Feature-Instanz", "fr": "Instance de fonctionnalité"},
        "mandateId": {"en": "Mandate ID", "de": "Mandanten-ID", "fr": "ID du mandat"},
        "userId": {"en": "User ID", "de": "Benutzer-ID", "fr": "ID utilisateur"},
        "autoSync": {"en": "Auto Sync", "de": "Auto-Sync", "fr": "Synchro auto"},
        "lastSynced": {"en": "Last Synced", "de": "Letzter Sync", "fr": "Dernier sync"},
        "createdAt": {"en": "Created At", "de": "Erstellt am", "fr": "Créé le"},
    },
)


class ExternalEntry(BaseModel):
    """An item (file or folder) from an external data source."""
    name: str = Field(description="Item name")
    path: str = Field(description="Full path within the source")
    isFolder: bool = Field(default=False, description="True if directory/folder")
    size: Optional[int] = Field(default=None, description="File size in bytes")
    mimeType: Optional[str] = Field(default=None, description="MIME type (files only)")
    lastModified: Optional[float] = Field(default=None, description="Last modification timestamp")
    metadata: Dict[str, Any] = Field(default_factory=dict, description="Provider-specific metadata")
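
# Illustrative sketch: a DataSource pointing a UserConnection at a SharePoint
# folder, plus the kind of ExternalEntry a browse of it would yield. All IDs
# and paths are placeholders.
_exampleSource = DataSource(
    connectionId="conn-abc",
    sourceType="sharepointFolder",
    path="/sites/MySite/Documents/Reports",
    label="Monthly Reports",
    userId="user-1",
)
_exampleItem = ExternalEntry(
    name="2025-01.pdf",
    path="/sites/MySite/Documents/Reports/2025-01.pdf",
    size=2048,
    mimeType="application/pdf",
)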

@ -73,7 +73,7 @@ class ExtractionOptions(BaseModel):
    """Options for document extraction and processing with clear data structures."""

    # Core extraction parameters
    prompt: str = Field(description="Extraction prompt for AI processing")
    prompt: str = Field(default="", description="Extraction prompt for AI processing")
    processDocumentsIndividually: bool = Field(default=True, description="Process each document separately")

    # Image processing parameters

@ -81,7 +81,7 @@ class ExtractionOptions(BaseModel):
    imageQuality: int = Field(default=85, ge=1, le=100, description="Image quality (1-100)")

    # Merging strategy
    mergeStrategy: MergeStrategy = Field(description="Strategy for merging extraction results")
    mergeStrategy: MergeStrategy = Field(default_factory=MergeStrategy, description="Strategy for merging extraction results")

    # Optional chunking parameters (for backward compatibility)
    chunkAllowed: Optional[bool] = Field(default=None, description="Whether chunking is allowed")

45
modules/datamodels/datamodelFeatureDataSource.py
Normal file
@ -0,0 +1,45 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""FeatureDataSource model for exposing feature instance data to the AI workspace.

A FeatureDataSource links a FeatureInstance table (DATA_OBJECT) to a workspace
so the agent can query structured feature data (e.g. TrusteePosition rows).
"""

from typing import Optional
from pydantic import BaseModel, Field
from modules.shared.attributeUtils import registerModelLabels
from modules.shared.timeUtils import getUtcTimestamp
import uuid


class FeatureDataSource(BaseModel):
    """A feature-instance table attached as data source in the AI workspace."""
    id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key")
    featureInstanceId: str = Field(description="FK to FeatureInstance")
    featureCode: str = Field(description="Feature code (e.g. trustee, commcoach)")
    tableName: str = Field(description="Table name from DATA_OBJECTS meta (e.g. TrusteePosition)")
    objectKey: str = Field(description="RBAC object key (e.g. data.feature.trustee.TrusteePosition)")
    label: str = Field(description="User-visible label")
    mandateId: str = Field(default="", description="Mandate scope")
    userId: str = Field(default="", description="Owner user ID")
    workspaceInstanceId: str = Field(description="Workspace instance where this source is used")
    createdAt: float = Field(default_factory=getUtcTimestamp, description="Creation timestamp")


registerModelLabels(
    "FeatureDataSource",
    {"en": "Feature Data Source", "de": "Feature-Datenquelle", "fr": "Source de données fonctionnalité"},
    {
        "id": {"en": "ID", "de": "ID", "fr": "ID"},
        "featureInstanceId": {"en": "Feature Instance", "de": "Feature-Instanz", "fr": "Instance"},
        "featureCode": {"en": "Feature", "de": "Feature", "fr": "Fonctionnalité"},
        "tableName": {"en": "Table", "de": "Tabelle", "fr": "Table"},
        "objectKey": {"en": "Object Key", "de": "Objekt-Schlüssel", "fr": "Clé objet"},
        "label": {"en": "Label", "de": "Bezeichnung", "fr": "Libellé"},
        "mandateId": {"en": "Mandate", "de": "Mandant", "fr": "Mandat"},
        "userId": {"en": "User", "de": "Benutzer", "fr": "Utilisateur"},
        "workspaceInstanceId": {"en": "Workspace", "de": "Workspace", "fr": "Espace de travail"},
        "createdAt": {"en": "Created At", "de": "Erstellt am", "fr": "Créé le"},
    },
)
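
# Illustrative sketch: attaching a feature-instance table to a workspace. The
# trustee table mirrors the docstring example; the IDs are placeholders.
_exampleFeatureSource = FeatureDataSource(
    featureInstanceId="fi-123",
    featureCode="trustee",
    tableName="TrusteePosition",
    objectKey="data.feature.trustee.TrusteePosition",
    label="Trustee Positions",
    workspaceInstanceId="ws-456",
)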

32
modules/datamodels/datamodelFileFolder.py
Normal file
@ -0,0 +1,32 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""FileFolder: hierarchical folder structure for file organization."""

from typing import Optional
from pydantic import BaseModel, Field
from modules.shared.attributeUtils import registerModelLabels
from modules.shared.timeUtils import getUtcTimestamp
import uuid


class FileFolder(BaseModel):
    id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
    name: str = Field(description="Folder name", json_schema_extra={"frontend_type": "text", "frontend_readonly": False, "frontend_required": True})
    parentId: Optional[str] = Field(default=None, description="Parent folder ID (null = root)", json_schema_extra={"frontend_type": "text", "frontend_readonly": False, "frontend_required": False})
    mandateId: Optional[str] = Field(default=None, description="Mandate context", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
    featureInstanceId: Optional[str] = Field(default=None, description="Feature instance context", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
    createdAt: float = Field(default_factory=getUtcTimestamp, description="Creation timestamp", json_schema_extra={"frontend_type": "timestamp", "frontend_readonly": True, "frontend_required": False})


registerModelLabels(
    "FileFolder",
    {"en": "File Folder", "fr": "Dossier de fichiers"},
    {
        "id": {"en": "ID", "fr": "ID"},
        "name": {"en": "Name", "fr": "Nom"},
        "parentId": {"en": "Parent Folder", "fr": "Dossier parent"},
        "mandateId": {"en": "Mandate ID", "fr": "ID du mandat"},
        "featureInstanceId": {"en": "Feature Instance ID", "fr": "ID de l'instance"},
        "createdAt": {"en": "Created At", "fr": "Créé le"},
    },
)

@ -2,7 +2,7 @@
# All rights reserved.
"""File-related datamodels: FileItem, FilePreview, FileData."""

from typing import Dict, Any, Optional, Union
from typing import Dict, Any, List, Optional, Union
from pydantic import BaseModel, ConfigDict, Field
from modules.shared.attributeUtils import registerModelLabels
from modules.shared.timeUtils import getUtcTimestamp

@ -14,12 +14,16 @@ class FileItem(BaseModel):
    model_config = ConfigDict(extra='allow')  # Preserve system fields (_createdBy, _createdAt, etc.)
    id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
    mandateId: Optional[str] = Field(default="", description="ID of the mandate this file belongs to", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
    featureInstanceId: Optional[str] = Field(default="", description="ID of the feature instance this file belongs to", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
    featureInstanceId: Optional[str] = Field(default="", description="ID of the feature instance this file belongs to", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False, "frontend_fk_source": "/api/features/instances", "frontend_fk_display_field": "label"})
    fileName: str = Field(description="Name of the file", json_schema_extra={"frontend_type": "text", "frontend_readonly": False, "frontend_required": True})
    mimeType: str = Field(description="MIME type of the file", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
    fileHash: str = Field(description="Hash of the file", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
    fileSize: int = Field(description="Size of the file in bytes", json_schema_extra={"frontend_type": "integer", "frontend_readonly": True, "frontend_required": False})
    creationDate: float = Field(default_factory=getUtcTimestamp, description="Date when the file was created (UTC timestamp in seconds)", json_schema_extra={"frontend_type": "timestamp", "frontend_readonly": True, "frontend_required": False})
    tags: Optional[List[str]] = Field(default=None, description="Tags for categorization and search", json_schema_extra={"frontend_type": "tags", "frontend_readonly": False, "frontend_required": False})
    folderId: Optional[str] = Field(default=None, description="ID of the parent folder", json_schema_extra={"frontend_type": "text", "frontend_readonly": False, "frontend_required": False})
    description: Optional[str] = Field(default=None, description="User-provided description of the file", json_schema_extra={"frontend_type": "textarea", "frontend_readonly": False, "frontend_required": False})
    status: Optional[str] = Field(default=None, description="Processing status: pending, extracted, embedding, indexed, failed", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})

registerModelLabels(
    "FileItem",

@ -27,12 +31,16 @@ registerModelLabels(
    {
        "id": {"en": "ID", "fr": "ID"},
        "mandateId": {"en": "Mandate ID", "fr": "ID du mandat"},
        "featureInstanceId": {"en": "Feature Instance ID", "fr": "ID de l'instance de fonctionnalité"},
        "featureInstanceId": {"en": "Feature Instance", "fr": "Instance de fonctionnalité"},
        "fileName": {"en": "File Name", "fr": "Nom de fichier"},
        "mimeType": {"en": "MIME Type", "fr": "Type MIME"},
        "fileHash": {"en": "File Hash", "fr": "Hash du fichier"},
        "fileSize": {"en": "File Size", "fr": "Taille du fichier"},
        "creationDate": {"en": "Creation Date", "fr": "Date de création"},
        "tags": {"en": "Tags", "fr": "Tags"},
        "folderId": {"en": "Folder ID", "fr": "ID du dossier"},
        "description": {"en": "Description", "fr": "Description"},
        "status": {"en": "Status", "fr": "Statut"},
    },
)

130
modules/datamodels/datamodelKnowledge.py
Normal file
@ -0,0 +1,130 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Knowledge Store data models: FileContentIndex, ContentChunk, WorkflowMemory.

These models support the 3-tier RAG architecture:
- Shared Layer: mandateId-scoped, isShared=True
- Instance Layer: userId + featureInstanceId-scoped
- Workflow Layer: workflowId-scoped (WorkflowMemory)

Vector fields use json_schema_extra={"db_type": "vector(1536)"} for pgvector.
"""

from typing import Dict, Any, List, Optional
from pydantic import BaseModel, Field
from modules.shared.attributeUtils import registerModelLabels
from modules.shared.timeUtils import getUtcTimestamp
import uuid


class FileContentIndex(BaseModel):
    """Structural index of a file's content objects. Created without AI.
    Lives in the Instance Layer; optionally promoted to Shared Layer via isShared."""
    id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key (typically = fileId)")
    userId: str = Field(description="Owner user ID")
    featureInstanceId: str = Field(default="", description="Feature instance scope")
    mandateId: str = Field(default="", description="Mandate scope")
    isShared: bool = Field(default=False, description="Visible in Shared Layer for all mandate users")
    fileName: str = Field(description="Original file name")
    mimeType: str = Field(description="MIME type of the file")
    containerPath: Optional[str] = Field(default=None, description="Path within a container (e.g. 'archive.zip/folder/report.pdf')")
    totalObjects: int = Field(default=0, description="Total number of content objects extracted")
    totalSize: int = Field(default=0, description="Total size of all content objects in bytes")
    structure: Dict[str, Any] = Field(default_factory=dict, description="Structural overview (pages, sections, hierarchy)")
    objectSummary: List[Dict[str, Any]] = Field(default_factory=list, description="Compact summary per content object")
    extractedAt: float = Field(default_factory=getUtcTimestamp, description="Extraction timestamp")
    status: str = Field(default="pending", description="Processing status: pending, extracted, embedding, indexed, failed")


registerModelLabels(
    "FileContentIndex",
    {"en": "File Content Index", "fr": "Index du contenu de fichier"},
    {
        "id": {"en": "ID", "fr": "ID"},
        "userId": {"en": "User ID", "fr": "ID utilisateur"},
        "featureInstanceId": {"en": "Feature Instance ID", "fr": "ID de l'instance"},
        "mandateId": {"en": "Mandate ID", "fr": "ID du mandat"},
        "isShared": {"en": "Shared", "fr": "Partagé"},
        "fileName": {"en": "File Name", "fr": "Nom de fichier"},
        "mimeType": {"en": "MIME Type", "fr": "Type MIME"},
        "containerPath": {"en": "Container Path", "fr": "Chemin du conteneur"},
        "totalObjects": {"en": "Total Objects", "fr": "Nombre total d'objets"},
        "totalSize": {"en": "Total Size", "fr": "Taille totale"},
        "structure": {"en": "Structure", "fr": "Structure"},
        "objectSummary": {"en": "Object Summary", "fr": "Résumé des objets"},
        "extractedAt": {"en": "Extracted At", "fr": "Extrait le"},
        "status": {"en": "Status", "fr": "Statut"},
    },
)


class ContentChunk(BaseModel):
    """Persisted content chunk with embedding vector. Reusable across workflows.
    Scalar content object (or chunk thereof) with pgvector embedding."""
    id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key")
    contentObjectId: str = Field(description="Reference to the content object within FileContentIndex")
    fileId: str = Field(description="FK to the source file")
    userId: str = Field(description="Owner user ID")
    featureInstanceId: str = Field(default="", description="Feature instance scope")
    contentType: str = Field(description="Content type: text, image, videostream, audiostream, other")
    data: str = Field(description="Content data (text, base64, URL)")
    contextRef: Dict[str, Any] = Field(default_factory=dict, description="Context reference (page, position, label)")
    summary: Optional[str] = Field(default=None, description="AI-generated summary (on demand)")
    chunkMetadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata")
    embedding: Optional[List[float]] = Field(
        default=None, description="pgvector embedding (NOT NULL for text chunks)",
        json_schema_extra={"db_type": "vector(1536)"}
    )


registerModelLabels(
    "ContentChunk",
    {"en": "Content Chunk", "fr": "Fragment de contenu"},
    {
        "id": {"en": "ID", "fr": "ID"},
        "contentObjectId": {"en": "Content Object ID", "fr": "ID de l'objet de contenu"},
        "fileId": {"en": "File ID", "fr": "ID du fichier"},
        "userId": {"en": "User ID", "fr": "ID utilisateur"},
        "featureInstanceId": {"en": "Feature Instance ID", "fr": "ID de l'instance"},
        "contentType": {"en": "Content Type", "fr": "Type de contenu"},
        "data": {"en": "Data", "fr": "Données"},
        "contextRef": {"en": "Context Reference", "fr": "Référence contextuelle"},
        "summary": {"en": "Summary", "fr": "Résumé"},
        "chunkMetadata": {"en": "Metadata", "fr": "Métadonnées"},
        "embedding": {"en": "Embedding", "fr": "Vecteur d'embedding"},
    },
)


class WorkflowMemory(BaseModel):
    """Workflow-scoped key-value cache for entities and facts.
    Extracted during agent rounds, persisted for cross-round and cross-workflow reuse."""
    id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key")
    workflowId: str = Field(description="FK to the workflow")
    userId: str = Field(description="Owner user ID")
    featureInstanceId: str = Field(default="", description="Feature instance scope")
    key: str = Field(description="Key identifier (e.g. 'entity:companyName')")
    value: str = Field(description="Extracted value")
    source: str = Field(default="extraction", description="Origin: extraction, tool, conversation, summary")
    createdAt: float = Field(default_factory=getUtcTimestamp, description="Creation timestamp")
    embedding: Optional[List[float]] = Field(
        default=None, description="Optional embedding for semantic lookup",
        json_schema_extra={"db_type": "vector(1536)"}
    )


registerModelLabels(
    "WorkflowMemory",
    {"en": "Workflow Memory", "fr": "Mémoire de workflow"},
    {
        "id": {"en": "ID", "fr": "ID"},
        "workflowId": {"en": "Workflow ID", "fr": "ID du workflow"},
        "userId": {"en": "User ID", "fr": "ID utilisateur"},
        "featureInstanceId": {"en": "Feature Instance ID", "fr": "ID de l'instance"},
        "key": {"en": "Key", "fr": "Clé"},
        "value": {"en": "Value", "fr": "Valeur"},
        "source": {"en": "Source", "fr": "Source"},
        "createdAt": {"en": "Created At", "fr": "Créé le"},
        "embedding": {"en": "Embedding", "fr": "Vecteur d'embedding"},
    },
)
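
# Illustrative sketch of the three scopes above: an instance-layer index entry
# (promotable via isShared), a text chunk with a pgvector embedding, and a
# workflow-scoped memory fact. The 1536-dim vector is a dummy; IDs are placeholders.
_exampleIndex = FileContentIndex(userId="user-1", fileName="report.pdf", mimeType="application/pdf")
_exampleChunk = ContentChunk(
    contentObjectId="obj-1",
    fileId=_exampleIndex.id,
    userId="user-1",
    contentType="text",
    data="Revenue grew by 12% in Q3.",
    embedding=[0.0] * 1536,  # normally produced by the EMBEDDING operation
)
_exampleFact = WorkflowMemory(workflowId="wf-9", userId="user-1", key="entity:companyName", value="Acme AG")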

@ -2,6 +2,7 @@
# All rights reserved.
"""Voice settings datamodel."""

from typing import Dict, Any, Optional
from pydantic import BaseModel, Field
from modules.shared.attributeUtils import registerModelLabels
from modules.shared.timeUtils import getUtcTimestamp

@ -16,6 +17,7 @@ class VoiceSettings(BaseModel):
    sttLanguage: str = Field(default="de-DE", description="Speech-to-Text language", json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": True})
    ttsLanguage: str = Field(default="de-DE", description="Text-to-Speech language", json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": True})
    ttsVoice: str = Field(default="de-DE-KatjaNeural", description="Text-to-Speech voice", json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": True})
    ttsVoiceMap: Dict[str, Any] = Field(default_factory=dict, description="Per-language voice mapping, e.g. {'de-DE': {'voiceName': 'de-DE-Wavenet-A'}, 'en-US': {'voiceName': 'en-US-Wavenet-C'}}", json_schema_extra={"frontend_type": "json", "frontend_readonly": False, "frontend_required": False})
    translationEnabled: bool = Field(default=True, description="Whether translation is enabled", json_schema_extra={"frontend_type": "checkbox", "frontend_readonly": False, "frontend_required": False})
    targetLanguage: str = Field(default="en-US", description="Target language for translation", json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": False})
    creationDate: float = Field(default_factory=getUtcTimestamp, description="Date when the settings were created (UTC timestamp in seconds)", json_schema_extra={"frontend_type": "timestamp", "frontend_readonly": True, "frontend_required": False})
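
    # Illustrative ttsVoiceMap value in the shape the description above
    # specifies (voice names copied from that description's own example):
    #     {"de-DE": {"voiceName": "de-DE-Wavenet-A"}, "en-US": {"voiceName": "en-US-Wavenet-C"}}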

@ -33,6 +35,7 @@ registerModelLabels(
        "sttLanguage": {"en": "STT Language", "fr": "Langue STT"},
        "ttsLanguage": {"en": "TTS Language", "fr": "Langue TTS"},
        "ttsVoice": {"en": "TTS Voice", "fr": "Voix TTS"},
        "ttsVoiceMap": {"en": "TTS Voice Map", "fr": "Carte des voix TTS"},
        "translationEnabled": {"en": "Translation Enabled", "fr": "Traduction activée"},
        "targetLanguage": {"en": "Target Language", "fr": "Langue cible"},
        "creationDate": {"en": "Creation Date", "fr": "Date de création"},

@ -180,7 +180,7 @@ def getAutomationServices(
    for spec in REQUIRED_SERVICES:
        key = spec["serviceKey"]
        try:
            svc = getService(key, ctx, legacy_hub=None)
            svc = getService(key, ctx)
            setattr(hub, key, svc)
        except Exception as e:
            logger.warning(f"Could not resolve service '{key}' for automation: {e}")

@ -17,10 +17,11 @@ from modules.features.automation.interfaceFeatureAutomation import getInterface
from modules.features.automation.mainAutomation import getAutomationServices
from modules.auth import limiter, getRequestContext, RequestContext
from modules.features.automation.datamodelFeatureAutomation import AutomationDefinition, AutomationTemplate
from modules.datamodels.datamodelChat import ChatWorkflow, ChatMessage, ChatLog
from modules.datamodels.datamodelChat import ChatWorkflow, ChatMessage, ChatLog, UserInputRequest, WorkflowModeEnum
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata, normalize_pagination_dict
from modules.shared.attributeUtils import getModelAttributeDefinitions
from modules.interfaces import interfaceDbChat
from modules.interfaces.interfaceDbBilling import getInterface as _getBillingInterface
# Configure logger
logger = logging.getLogger(__name__)

@ -234,7 +235,7 @@ def get_available_actions(


# -----------------------------------------------------------------------------
# Workflow routes under /{instanceId}/workflows/ (instance-scoped, same as chatplayground)
# Workflow routes under /{instanceId}/workflows/ (instance-scoped)
# -----------------------------------------------------------------------------

def _validateAutomationInstanceAccess(instanceId: str, context: RequestContext) -> Optional[str]:

@ -682,7 +683,9 @@ def get_automation_workflow_chat_data(
        workflow = chatInterface.getWorkflow(workflowId)
        if not workflow:
            raise HTTPException(status_code=404, detail=f"Workflow {workflowId} not found")
        return chatInterface.getUnifiedChatData(workflowId, afterTimestamp)
        billingInterface = _getBillingInterface(context.user, context.mandateId)
        workflowCost = billingInterface.getWorkflowCost(workflowId)
        return chatInterface.getUnifiedChatData(workflowId, afterTimestamp, workflowCost=workflowCost)
    except HTTPException:
        raise
    except Exception as e:

@ -851,6 +854,46 @@ def delete_automation(
            detail=f"Error deleting automation: {str(e)}"
        )

@router.post("/{instanceId}/start", response_model=ChatWorkflow)
@limiter.limit("120/minute")
async def start_automation_workflow(
    request: Request,
    instanceId: str = Path(..., description="Feature instance ID"),
    workflowId: Optional[str] = Query(None, description="Optional ID of the workflow to continue"),
    workflowMode: WorkflowModeEnum = Query(..., description="Workflow mode: 'Dynamic' or 'Automation' (mandatory)"),
    userInput: UserInputRequest = Body(...),
    context: RequestContext = Depends(getRequestContext)
) -> ChatWorkflow:
    """Start a new workflow or continue an existing one."""
    try:
        from modules.workflows.automation import chatStart
        mandateId = _validateAutomationInstanceAccess(instanceId, context)
        services = getAutomationServices(
            context.user,
            mandateId=mandateId,
            featureInstanceId=instanceId,
        )
        services.featureCode = "automation"
        if hasattr(userInput, 'allowedProviders') and userInput.allowedProviders:
            services.allowedProviders = userInput.allowedProviders
        workflow = await chatStart(
            context.user,
            userInput,
            workflowMode,
            workflowId,
            mandateId=mandateId,
            featureInstanceId=instanceId,
            featureCode="automation",
            services=services,
        )
        return workflow
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in start_automation_workflow: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
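
# Hypothetical client-side sketch for the start route above. The mount prefix
# ("/api/automation") and the UserInputRequest field name ("message") are
# assumptions, not confirmed by this diff.
#
#     import httpx
#     resp = httpx.post(
#         "http://localhost:8000/api/automation/fi-123/start",
#         params={"workflowMode": "Dynamic"},          # mandatory query parameter
#         json={"message": "Summarize the latest reports"},
#         headers={"Authorization": "Bearer <token>"},
#     )
#     workflow = resp.json()  # serialized ChatWorkflow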


@router.post("/{automationId}/execute", response_model=ChatWorkflow)
@limiter.limit("5/minute")
async def execute_automation_route(

@ -1291,17 +1291,6 @@ class ChatObjects:
            logger.error(f"Error updating message {messageId}: {str(e)}", exc_info=True)
            raise ValueError(f"Error updating message {messageId}: {str(e)}")

    def createStat(self, statData: Dict[str, Any]):
        """Create stat record. Compatibility with ChatService; stats may not be persisted in chatbot schema."""
        from modules.datamodels.datamodelChat import ChatStat
        stat = ChatStat(**statData)
        try:
            created = self.db.recordCreate(ChatStat, statData)
            return ChatStat(**created)
        except Exception as e:
            logger.debug(f"createStat: not persisting (chatbot schema): {e}")
            return stat

    def deleteMessage(self, conversationId: str, messageId: str) -> bool:
        """Deletes a conversation message and related data if user has access."""
        try:

@ -179,7 +179,7 @@ def getChatbotServices(
    for spec in REQUIRED_SERVICES:
        key = spec["serviceKey"]
        try:
            svc = getService(key, ctx, legacy_hub=None)
            svc = getService(key, ctx)
            setattr(hub, key, svc)
        except Exception as e:
            logger.warning(f"Could not resolve service '{key}' for chatbot: {e}")

@ -197,7 +197,7 @@ def getChatStreamingHelper():
    from modules.serviceCenter.context import ServiceCenterContext
    # Minimal context - streaming service only needs it for resolver
    ctx = ServiceCenterContext(user=__get_placeholder_user(), mandate_id=None, feature_instance_id=None)
    streaming = getService("streaming", ctx, legacy_hub=None)
    streaming = getService("streaming", ctx)
    return streaming.getChatStreamingHelper() if streaming else None


@ -219,7 +219,7 @@ def getEventManager(user, mandateId: Optional[str] = None, featureInstanceId: Op
        mandate_id=mandateId,
        feature_instance_id=featureInstanceId,
    )
    streaming = getService("streaming", ctx, legacy_hub=None)
    streaming = getService("streaming", ctx)
    return streaming.getEventManager()

@ -306,12 +306,12 @@ def getChatbotServices(
    Uses interfaceFeatureChatbot (ChatObjects) for interfaceDbChat to avoid
    duplicate DB init - chatProcess reuses hub.interfaceDbChat.
    """
    from modules.services import PublicService
    from modules.serviceHub import PublicService
    from modules.interfaces.interfaceDbApp import getInterface as getAppInterface
    from modules.features.chatbot.interfaceFeatureChatbot import getInterface as getChatbotInterface
    from modules.services.serviceChat.mainServiceChat import ChatService
    from modules.services.serviceAi.mainServiceAi import AiService
    from modules.services.serviceStreaming.mainServiceStreaming import StreamingService
    from modules.serviceCenter.services.serviceChat.mainServiceChat import ChatService
    from modules.serviceCenter.services.serviceAi.mainServiceAi import AiService
    from modules.serviceCenter.core.serviceStreaming.mainServiceStreaming import StreamingService

    hub = _ChatbotServiceHub()
    hub.user = user

@ -344,7 +344,7 @@ def getChatbotServices(
            feature_instance_id=featureInstanceId,
            workflow=_workflow,
        )
        hub.billing = getService("billing", ctx, legacy_hub=None)
        hub.billing = getService("billing", ctx)
    except Exception as e:
        logger.warning(f"Could not resolve billing service for chatbot: {e}")
        hub.billing = None

@ -1,6 +0,0 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Chat Playground Feature Container.
Provides workflow-based chat playground functionality.
"""

@ -1,145 +0,0 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Chat Playground Feature Interface.
Wrapper around interfaceDbChat with feature instance context.
"""

import logging
from typing import Dict, Any, List, Optional

from modules.datamodels.datamodelUam import User
from modules.interfaces import interfaceDbChat

logger = logging.getLogger(__name__)

# Feature code constant
FEATURE_CODE = "chatplayground"

# Singleton instances cache
_instances: Dict[str, "ChatPlaygroundObjects"] = {}


def getInterface(currentUser: User, mandateId: str = None, featureInstanceId: str = None) -> "ChatPlaygroundObjects":
    """
    Factory function to get or create a ChatPlaygroundObjects instance.
    Uses singleton pattern per user context.

    Args:
        currentUser: Current user object
        mandateId: Mandate ID
        featureInstanceId: Feature instance ID

    Returns:
        ChatPlaygroundObjects instance
    """
    cacheKey = f"{currentUser.id}_{mandateId}_{featureInstanceId}"

    if cacheKey not in _instances:
        _instances[cacheKey] = ChatPlaygroundObjects(currentUser, mandateId, featureInstanceId)
    else:
        # Update context if needed
        _instances[cacheKey].setUserContext(currentUser, mandateId, featureInstanceId)

    return _instances[cacheKey]


class ChatPlaygroundObjects:
    """
    Chat Playground feature interface.
    Wraps the shared interfaceDbChat with feature instance context.
    """

    FEATURE_CODE = FEATURE_CODE

    def __init__(self, currentUser: User, mandateId: str = None, featureInstanceId: str = None):
        """
        Initialize the Chat Playground interface.

        Args:
            currentUser: Current user object
            mandateId: Mandate ID
            featureInstanceId: Feature instance ID
        """
        self.currentUser = currentUser
        self.mandateId = mandateId
        self.featureInstanceId = featureInstanceId

        # Get the underlying chat interface
        self._chatInterface = interfaceDbChat.getInterface(
            currentUser,
            mandateId=mandateId,
            featureInstanceId=featureInstanceId
        )

    def setUserContext(self, currentUser: User, mandateId: str = None, featureInstanceId: str = None):
        """
        Update the user context.

        Args:
            currentUser: Current user object
            mandateId: Mandate ID
            featureInstanceId: Feature instance ID
        """
        self.currentUser = currentUser
        self.mandateId = mandateId
        self.featureInstanceId = featureInstanceId

        # Update underlying interface
        self._chatInterface = interfaceDbChat.getInterface(
            currentUser,
            mandateId=mandateId,
            featureInstanceId=featureInstanceId
        )

    # =========================================================================
    # Delegated methods from interfaceDbChat
    # =========================================================================

    def getWorkflow(self, workflowId: str) -> Optional[Dict[str, Any]]:
        """Get a workflow by ID."""
        return self._chatInterface.getWorkflow(workflowId)

    def getWorkflows(self, pagination=None) -> Dict[str, Any]:
        """Get all workflows with pagination."""
        return self._chatInterface.getWorkflows(pagination=pagination)

    def getUnifiedChatData(self, workflowId: str, afterTimestamp: float = None) -> Dict[str, Any]:
        """Get unified chat data for a workflow."""
        return self._chatInterface.getUnifiedChatData(workflowId, afterTimestamp)

    def createWorkflow(self, workflow) -> Dict[str, Any]:
        """Create a new workflow."""
        return self._chatInterface.createWorkflow(workflow)

    def updateWorkflow(self, workflowId: str, updates: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Update a workflow."""
        return self._chatInterface.updateWorkflow(workflowId, updates)

    def deleteWorkflow(self, workflowId: str) -> bool:
        """Delete a workflow."""
        return self._chatInterface.deleteWorkflow(workflowId)

    def getMessages(self, workflowId: str) -> List[Dict[str, Any]]:
        """Get messages for a workflow."""
        return self._chatInterface.getMessages(workflowId)

    def createMessage(self, message) -> Dict[str, Any]:
        """Create a new message."""
        return self._chatInterface.createMessage(message)

    def getLogs(self, workflowId: str) -> List[Dict[str, Any]]:
        """Get logs for a workflow."""
        return self._chatInterface.getLogs(workflowId)

    def createLog(self, log) -> Dict[str, Any]:
        """Create a new log entry."""
        return self._chatInterface.createLog(log)

    def getStats(self, workflowId: str) -> List[Dict[str, Any]]:
        """Get stats for a workflow."""
        return self._chatInterface.getStats(workflowId)

    def createStat(self, stat) -> Dict[str, Any]:
        """Create a new stat entry."""
        return self._chatInterface.createStat(stat)
|
||||
|
|
@ -1,381 +0,0 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Chat Playground Feature Container - Main Module.
Handles feature initialization and RBAC catalog registration.
"""

import logging
from typing import Dict, List, Any, Optional

logger = logging.getLogger(__name__)

# Feature metadata
FEATURE_CODE = "chatplayground"
FEATURE_LABEL = {"en": "Chat Playground", "de": "Chat Playground", "fr": "Chat Playground"}
FEATURE_ICON = "mdi-message-text"

# UI Objects for RBAC catalog
UI_OBJECTS = [
    {
        "objectKey": "ui.feature.chatplayground.playground",
        "label": {"en": "Playground", "de": "Playground", "fr": "Playground"},
        "meta": {"area": "playground"}
    },
    {
        "objectKey": "ui.feature.chatplayground.workflows",
        "label": {"en": "Workflows", "de": "Workflows", "fr": "Workflows"},
        "meta": {"area": "workflows"}
    },
]

# Resource Objects for RBAC catalog
RESOURCE_OBJECTS = [
    {
        "objectKey": "resource.feature.chatplayground.start",
        "label": {"en": "Start Workflow", "de": "Workflow starten", "fr": "Démarrer workflow"},
        "meta": {"endpoint": "/api/chatplayground/{instanceId}/start", "method": "POST"}
    },
    {
        "objectKey": "resource.feature.chatplayground.stop",
        "label": {"en": "Stop Workflow", "de": "Workflow stoppen", "fr": "Arrêter workflow"},
        "meta": {"endpoint": "/api/chatplayground/{instanceId}/workflows/{workflowId}/stop", "method": "POST"}
    },
    {
        "objectKey": "resource.feature.chatplayground.chatData",
        "label": {"en": "Get Chat Data", "de": "Chat-Daten abrufen", "fr": "Récupérer données chat"},
        "meta": {"endpoint": "/api/chatplayground/{instanceId}/workflows/{workflowId}/chatData", "method": "GET"}
    },
]

# Service requirements - services this feature needs from the service center.
# Same as automation: chatplayground runs the same WorkflowManager and workflow methods.
REQUIRED_SERVICES = [
    {"serviceKey": "chat", "meta": {"usage": "Workflow CRUD, messages, logs"}},
    {"serviceKey": "ai", "meta": {"usage": "AI planning for workflow execution"}},
    {"serviceKey": "utils", "meta": {"usage": "Timestamps, utilities"}},
    {"serviceKey": "billing", "meta": {"usage": "AI call billing"}},
    {"serviceKey": "extraction", "meta": {"usage": "Workflow method actions"}},
    {"serviceKey": "sharepoint", "meta": {"usage": "SharePoint actions (listDocuments, uploadDocument, etc.)"}},
    {"serviceKey": "generation", "meta": {"usage": "Action completion messages, document creation from results"}},
]

# Template roles for this feature.
# Role names MUST follow convention: {featureCode}-{roleName}
TEMPLATE_ROLES = [
    {
        "roleLabel": "chatplayground-viewer",
        "description": {
            "en": "Chat Playground Viewer - View chat playground (read-only)",
            "de": "Chat Playground Betrachter - Chat Playground ansehen (nur lesen)",
            "fr": "Visualiseur Chat Playground - Consulter le chat playground (lecture seule)"
        },
        "accessRules": [
            # UI: only playground view, NO workflows
            {"context": "UI", "item": "ui.feature.chatplayground.playground", "view": True},
            # RESOURCE: NO access (viewer cannot start/stop/access chat data)
            # DATA access (own records, read-only)
            {"context": "DATA", "item": None, "view": True, "read": "m", "create": "n", "update": "n", "delete": "n"},
        ]
    },
    {
        "roleLabel": "chatplayground-user",
        "description": {
            "en": "Chat Playground User - Use chat playground and workflows",
            "de": "Chat Playground Benutzer - Chat Playground und Workflows nutzen",
            "fr": "Utilisateur Chat Playground - Utiliser le chat playground et les workflows"
        },
        "accessRules": [
            # UI: full access to all views
            {"context": "UI", "item": "ui.feature.chatplayground.playground", "view": True},
            {"context": "UI", "item": "ui.feature.chatplayground.workflows", "view": True},
            # Resource access: can start/stop workflows and access chat data
            {"context": "RESOURCE", "item": "resource.feature.chatplayground.start", "view": True},
            {"context": "RESOURCE", "item": "resource.feature.chatplayground.stop", "view": True},
            {"context": "RESOURCE", "item": "resource.feature.chatplayground.chatData", "view": True},
            # DATA access (own records)
            {"context": "DATA", "item": None, "view": True, "read": "m", "create": "m", "update": "m", "delete": "m"},
        ]
    },
    {
        "roleLabel": "chatplayground-admin",
        "description": {
            "en": "Chat Playground Admin - Full access to chat playground",
            "de": "Chat Playground Admin - Vollzugriff auf Chat Playground",
            "fr": "Administrateur Chat Playground - Accès complet au chat playground"
        },
        "accessRules": [
            # Full UI access
            {"context": "UI", "item": None, "view": True},
            # Full resource access
            {"context": "RESOURCE", "item": None, "view": True},
            # Full DATA access
            {"context": "DATA", "item": None, "view": True, "read": "a", "create": "a", "update": "a", "delete": "a"},
        ]
    },
]


def getRequiredServiceKeys() -> List[str]:
    """Return list of service keys this feature requires."""
    return [s["serviceKey"] for s in REQUIRED_SERVICES]


def getChatplaygroundServices(
    user,
    mandateId: Optional[str] = None,
    featureInstanceId: Optional[str] = None,
    workflow=None,
) -> "_ChatplaygroundServiceHub":
    """
    Get a service hub for the chatplayground feature using the service center.
    Resolves only the services declared in REQUIRED_SERVICES.
    No legacy fallback - service center only.

    Returns a hub-like object with: chat, ai, utils, billing, extraction,
    sharepoint, rbac, interfaceDbApp, interfaceDbComponent, interfaceDbChat.
    """
    from modules.serviceCenter import getService
    from modules.serviceCenter.context import ServiceCenterContext

    _workflow = workflow
    if _workflow is None:
        _workflow = type("_Placeholder", (), {"featureCode": FEATURE_CODE})()

    ctx = ServiceCenterContext(
        user=user,
        mandate_id=mandateId,
        feature_instance_id=featureInstanceId,
        workflow=_workflow,
    )

    hub = _ChatplaygroundServiceHub()
    hub.user = user
    hub.mandateId = mandateId
    hub.featureInstanceId = featureInstanceId
    hub.workflow = workflow
    hub.featureCode = FEATURE_CODE
    hub.allowedProviders = None

    for spec in REQUIRED_SERVICES:
        key = spec["serviceKey"]
        try:
            svc = getService(key, ctx, legacy_hub=None)
            setattr(hub, key, svc)
        except Exception as e:
            logger.warning(f"Could not resolve service '{key}' for chatplayground: {e}")
            setattr(hub, key, None)

    # Copy interfaces from chat service for WorkflowManager compatibility
    if hub.chat:
        hub.interfaceDbApp = getattr(hub.chat, "interfaceDbApp", None)
        hub.interfaceDbComponent = getattr(hub.chat, "interfaceDbComponent", None)
        hub.interfaceDbChat = getattr(hub.chat, "interfaceDbChat", None)

    # RBAC for MethodBase action permission checks (workflow methods)
    hub.rbac = getattr(hub.interfaceDbApp, "rbac", None) if hub.interfaceDbApp else None

    return hub
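
# Minimal usage sketch (illustrative only; the user object and both ID values
# below are assumed placeholders, not values from this repository):
#
#   services = getChatplaygroundServices(user, mandateId="m-123", featureInstanceId="fi-456")
#   if services.interfaceDbChat:
#       workflows = services.interfaceDbChat.getWorkflows()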


class _ChatplaygroundServiceHub:
    """Lightweight hub exposing only services required by the chatplayground feature."""

    user = None
    mandateId = None
    featureInstanceId = None
    workflow = None
    featureCode = "chatplayground"
    allowedProviders = None
    interfaceDbApp = None
    interfaceDbComponent = None
    interfaceDbChat = None
    rbac = None
    chat = None
    ai = None
    utils = None
    billing = None
    extraction = None
    sharepoint = None


def getFeatureDefinition() -> Dict[str, Any]:
    """Return the feature definition for registration."""
    return {
        "code": FEATURE_CODE,
        "label": FEATURE_LABEL,
        "icon": FEATURE_ICON,
        "autoCreateInstance": True,  # Automatically create instance in root mandate during bootstrap
    }


def getUiObjects() -> List[Dict[str, Any]]:
    """Return UI objects for RBAC catalog registration."""
    return UI_OBJECTS


def getResourceObjects() -> List[Dict[str, Any]]:
    """Return resource objects for RBAC catalog registration."""
    return RESOURCE_OBJECTS


def getTemplateRoles() -> List[Dict[str, Any]]:
    """Return template roles for this feature."""
    return TEMPLATE_ROLES


def registerFeature(catalogService) -> bool:
    """
    Register this feature's RBAC objects in the catalog.

    Args:
        catalogService: The RBAC catalog service instance

    Returns:
        True if registration was successful
    """
    try:
        # Register UI objects
        for uiObj in UI_OBJECTS:
            catalogService.registerUiObject(
                featureCode=FEATURE_CODE,
                objectKey=uiObj["objectKey"],
                label=uiObj["label"],
                meta=uiObj.get("meta")
            )

        # Register Resource objects
        for resObj in RESOURCE_OBJECTS:
            catalogService.registerResourceObject(
                featureCode=FEATURE_CODE,
                objectKey=resObj["objectKey"],
                label=resObj["label"],
                meta=resObj.get("meta")
            )

        # Sync template roles to database
        _syncTemplateRolesToDb()

        logger.info(f"Feature '{FEATURE_CODE}' registered {len(UI_OBJECTS)} UI objects and {len(RESOURCE_OBJECTS)} resource objects")
        return True

    except Exception as e:
        logger.error(f"Failed to register feature '{FEATURE_CODE}': {e}")
        return False


def _syncTemplateRolesToDb() -> int:
    """
    Sync template roles and their AccessRules to the database.
    Creates global template roles (mandateId=None) if they don't exist.

    Returns:
        Number of roles created/updated
    """
    try:
        from modules.interfaces.interfaceDbApp import getRootInterface
        from modules.datamodels.datamodelRbac import Role, AccessRule, AccessRuleContext

        rootInterface = getRootInterface()

        # Get existing template roles for this feature (Pydantic models)
        existingRoles = rootInterface.getRolesByFeatureCode(FEATURE_CODE)
        # Filter to template roles (mandateId is None)
        templateRoles = [r for r in existingRoles if r.mandateId is None]
        existingRoleLabels = {r.roleLabel: str(r.id) for r in templateRoles}

        createdCount = 0
        for roleTemplate in TEMPLATE_ROLES:
            roleLabel = roleTemplate["roleLabel"]

            if roleLabel in existingRoleLabels:
                roleId = existingRoleLabels[roleLabel]
                # Ensure AccessRules exist for this role
                _ensureAccessRulesForRole(rootInterface, roleId, roleTemplate.get("accessRules", []))
            else:
                # Create new template role
                newRole = Role(
                    roleLabel=roleLabel,
                    description=roleTemplate.get("description", {}),
                    featureCode=FEATURE_CODE,
                    mandateId=None,  # Global template
                    featureInstanceId=None,
                    isSystemRole=False
                )
                createdRole = rootInterface.db.recordCreate(Role, newRole.model_dump())
                roleId = createdRole.get("id")

                # Create AccessRules for this role
                _ensureAccessRulesForRole(rootInterface, roleId, roleTemplate.get("accessRules", []))

                logger.info(f"Created template role '{roleLabel}' with ID {roleId}")
                createdCount += 1

        if createdCount > 0:
            logger.info(f"Feature '{FEATURE_CODE}': Created {createdCount} template roles")

        return createdCount

    except Exception as e:
        logger.error(f"Error syncing template roles for feature '{FEATURE_CODE}': {e}")
        return 0


def _ensureAccessRulesForRole(rootInterface, roleId: str, ruleTemplates: List[Dict[str, Any]]) -> int:
    """
    Ensure AccessRules exist for a role based on templates.

    Args:
        rootInterface: Root interface instance
        roleId: Role ID
        ruleTemplates: List of rule templates

    Returns:
        Number of rules created
    """
    from modules.datamodels.datamodelRbac import AccessRule, AccessRuleContext

    # Get existing rules for this role (Pydantic models)
    existingRules = rootInterface.getAccessRulesByRole(roleId)

    # Create a set of existing rule signatures to avoid duplicates.
    # IMPORTANT: Use .value for enum comparison, not str(), which gives "AccessRuleContext.DATA" in Python 3.11+.
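    # For example (restating the note above):
    #   str(AccessRuleContext.DATA)   -> "AccessRuleContext.DATA"
    #   AccessRuleContext.DATA.value  -> "DATA"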
    existingSignatures = set()
    for rule in existingRules:
        sig = (rule.context.value if rule.context else None, rule.item)
        existingSignatures.add(sig)

    createdCount = 0
    for template in ruleTemplates:
        context = template.get("context", "UI")
        item = template.get("item")
        sig = (context, item)

        if sig in existingSignatures:
            continue

        # Map context string to enum
        if context == "UI":
            contextEnum = AccessRuleContext.UI
        elif context == "DATA":
            contextEnum = AccessRuleContext.DATA
        elif context == "RESOURCE":
            contextEnum = AccessRuleContext.RESOURCE
        else:
            contextEnum = context

        newRule = AccessRule(
            roleId=roleId,
            context=contextEnum,
            item=item,
            view=template.get("view", False),
            read=template.get("read"),
            create=template.get("create"),
            update=template.get("update"),
            delete=template.get("delete"),
        )
        rootInterface.db.recordCreate(AccessRule, newRule.model_dump())
        createdCount += 1

    if createdCount > 0:
        logger.debug(f"Created {createdCount} AccessRules for role {roleId}")

    return createdCount
@ -1,719 +0,0 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Chat Playground Feature Routes.
Implements the endpoints for chat playground workflow management as a feature.
"""

import json
import logging
from typing import Optional, Dict, Any
from fastapi import APIRouter, HTTPException, Depends, Body, Path, Query, Request, status

# Import auth modules
from modules.auth import limiter, getRequestContext, RequestContext

# Import interfaces
from modules.interfaces import interfaceDbChat

# Import models
from modules.datamodels.datamodelChat import (
    ChatWorkflow,
    ChatMessage,
    ChatLog,
    UserInputRequest,
    WorkflowModeEnum,
)
from modules.datamodels.datamodelPagination import (
    PaginationParams,
    PaginatedResponse,
    PaginationMetadata,
    normalize_pagination_dict,
)

# Import workflow control functions
from modules.workflows.automation import chatStart, chatStop
from modules.features.chatplayground.mainChatplayground import getChatplaygroundServices
from modules.shared.attributeUtils import getModelAttributeDefinitions

# Configure logger
logger = logging.getLogger(__name__)

# Model attributes for ChatWorkflow (workflow attributes endpoint)
workflowAttributes = getModelAttributeDefinitions(ChatWorkflow)

# Create router for chat playground feature endpoints
router = APIRouter(
    prefix="/api/chatplayground",
    tags=["Chat Playground Feature"],
    responses={404: {"description": "Not found"}}
)


def _getServiceChat(context: RequestContext, featureInstanceId: str = None):
    """Get chat interface with feature instance context."""
    return interfaceDbChat.getInterface(
        context.user,
        mandateId=str(context.mandateId) if context.mandateId else None,
        featureInstanceId=featureInstanceId
    )


def _validateInstanceAccess(instanceId: str, context: RequestContext) -> str:
    """
    Validate that the user has access to the feature instance.

    Args:
        instanceId: Feature instance ID
        context: Request context

    Returns:
        mandateId for the instance

    Raises:
        HTTPException if access is denied
    """
    from modules.interfaces.interfaceDbApp import getRootInterface

    rootInterface = getRootInterface()

    # Get feature instance (Pydantic model)
    instance = rootInterface.getFeatureInstance(instanceId)
    if not instance:
        raise HTTPException(status_code=404, detail=f"Feature instance {instanceId} not found")

    # Check user has access to this instance using interface method
    featureAccess = rootInterface.getFeatureAccess(str(context.user.id), instanceId)

    if not featureAccess or not featureAccess.enabled:
        raise HTTPException(status_code=403, detail="Access denied to this feature instance")

    return str(instance.mandateId) if instance.mandateId else None
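
# Typical per-route pattern built on the helpers above (illustrative sketch):
#   mandateId = _validateInstanceAccess(instanceId, context)
#   services = getChatplaygroundServices(context.user, mandateId=mandateId, featureInstanceId=instanceId)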


# Workflow start endpoint
@router.post("/{instanceId}/start", response_model=ChatWorkflow)
@limiter.limit("120/minute")
async def start_workflow(
    request: Request,
    instanceId: str = Path(..., description="Feature instance ID"),
    workflowId: Optional[str] = Query(None, description="Optional ID of the workflow to continue"),
    workflowMode: WorkflowModeEnum = Query(..., description="Workflow mode: 'Dynamic' or 'Automation' (mandatory)"),
    userInput: UserInputRequest = Body(...),
    context: RequestContext = Depends(getRequestContext)
) -> ChatWorkflow:
    """
    Starts a new workflow or continues an existing one.

    Args:
        instanceId: Feature instance ID
        workflowMode: "Dynamic" for iterative dynamic-style processing, "Automation" for automated workflow execution
    """
    try:
        # Validate access and get mandate ID
        mandateId = _validateInstanceAccess(instanceId, context)

        # Get chatplayground services from service center (not automation)
        services = getChatplaygroundServices(
            context.user,
            mandateId=mandateId,
            featureInstanceId=instanceId,
        )
        services.featureCode = 'chatplayground'
        if hasattr(userInput, 'allowedProviders') and userInput.allowedProviders:
            services.allowedProviders = userInput.allowedProviders

        # Start or continue workflow
        workflow = await chatStart(
            context.user,
            userInput,
            workflowMode,
            workflowId,
            mandateId=mandateId,
            featureInstanceId=instanceId,
            featureCode='chatplayground',
            services=services,
        )

        return workflow

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in start_workflow: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=str(e)
        )


# Stop workflow endpoint (under /workflows/{workflowId}/ for consistency)
@router.post("/{instanceId}/workflows/{workflowId}/stop", response_model=ChatWorkflow)
@limiter.limit("120/minute")
async def stop_workflow(
    request: Request,
    instanceId: str = Path(..., description="Feature instance ID"),
    workflowId: str = Path(..., description="ID of the workflow to stop"),
    context: RequestContext = Depends(getRequestContext)
) -> ChatWorkflow:
    """Stops a running workflow."""
    try:
        # Validate access and get mandate ID
        mandateId = _validateInstanceAccess(instanceId, context)

        # Get chatplayground services from service center (not automation)
        services = getChatplaygroundServices(
            context.user,
            mandateId=mandateId,
            featureInstanceId=instanceId,
        )
        services.featureCode = 'chatplayground'

        # Stop workflow (pass featureInstanceId for proper RBAC filtering)
        workflow = await chatStop(
            context.user,
            workflowId,
            mandateId=mandateId,
            featureInstanceId=instanceId,
            featureCode='chatplayground',
            services=services,
        )

        return workflow

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in stop_workflow: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=str(e)
        )


# Unified chat data endpoint for polling (under /workflows/{workflowId}/ for consistency)
@router.get("/{instanceId}/workflows/{workflowId}/chatData")
@limiter.limit("120/minute")
def get_workflow_chat_data(
    request: Request,
    instanceId: str = Path(..., description="Feature instance ID"),
    workflowId: str = Path(..., description="ID of the workflow"),
    afterTimestamp: Optional[float] = Query(None, description="Unix timestamp to get data after"),
    context: RequestContext = Depends(getRequestContext)
) -> Dict[str, Any]:
    """
    Get unified chat data (messages, logs, stats) for a workflow with timestamp-based selective data transfer.
    Returns all data types in chronological order based on the _createdAt timestamp.
    """
    try:
        # Validate access
        _validateInstanceAccess(instanceId, context)

        # Get service with feature instance context
        chatInterface = _getServiceChat(context, featureInstanceId=instanceId)

        # Verify workflow exists
        workflow = chatInterface.getWorkflow(workflowId)
        if not workflow:
            raise HTTPException(
                status_code=404,
                detail=f"Workflow with ID {workflowId} not found"
            )

        # Get unified chat data
        chatData = chatInterface.getUnifiedChatData(workflowId, afterTimestamp)

        return chatData

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting unified chat data: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail=f"Error getting unified chat data: {str(e)}"
        )
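
# Typical polling usage (illustrative; the timestamp value is an assumed placeholder):
#   GET .../workflows/{workflowId}/chatData                              (full history)
#   GET .../workflows/{workflowId}/chatData?afterTimestamp=1712345678.0  (only items after that timestamp)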


# Get workflow attributes (ChatWorkflow model)
@router.get("/{instanceId}/workflows/attributes", response_model=Dict[str, Any])
@limiter.limit("120/minute")
def get_workflow_attributes(
    request: Request,
    instanceId: str = Path(..., description="Feature instance ID"),
    context: RequestContext = Depends(getRequestContext)
) -> Dict[str, Any]:
    """Get attribute definitions for the ChatWorkflow model."""
    _validateInstanceAccess(instanceId, context)
    return {"attributes": workflowAttributes}


# Get workflows for this instance
@router.get("/{instanceId}/workflows", response_model=PaginatedResponse[ChatWorkflow])
@limiter.limit("120/minute")
def get_workflows(
    request: Request,
    instanceId: str = Path(..., description="Feature instance ID"),
    pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"),
    page: int = Query(1, ge=1, description="Page number (legacy)"),
    pageSize: int = Query(20, ge=1, le=100, description="Items per page (legacy)"),
    context: RequestContext = Depends(getRequestContext)
) -> PaginatedResponse[ChatWorkflow]:
    """
    Get all workflows for this feature instance with optional pagination.
    """
    try:
        # Validate access
        _validateInstanceAccess(instanceId, context)

        # Get service with feature instance context
        chatInterface = _getServiceChat(context, featureInstanceId=instanceId)

        # Parse pagination parameter
        paginationParams = None
        if pagination:
            try:
                paginationDict = json.loads(pagination)
                if paginationDict:
                    paginationDict = normalize_pagination_dict(paginationDict)
                    paginationParams = PaginationParams(**paginationDict)
            except (json.JSONDecodeError, ValueError) as e:
                raise HTTPException(
                    status_code=400,
                    detail=f"Invalid pagination parameter: {str(e)}"
                )
        else:
            paginationParams = PaginationParams(page=page, pageSize=pageSize)
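
        # Illustrative request shapes (assumed client usage, not defined in this module):
        #   GET .../workflows?page=2&pageSize=50                      (legacy query params)
        #   GET .../workflows?pagination={"page": 2, "pageSize": 50}  (JSON-encoded PaginationParams)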

        result = chatInterface.getWorkflows(pagination=paginationParams)

        if paginationParams:
            return PaginatedResponse(
                items=result.items,
                pagination=PaginationMetadata(
                    currentPage=paginationParams.page,
                    pageSize=paginationParams.pageSize,
                    totalItems=result.totalItems,
                    totalPages=result.totalPages,
                    sort=paginationParams.sort,
                    filters=paginationParams.filters
                )
            )
        else:
            return PaginatedResponse(items=result, pagination=None)

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting workflows: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail=f"Error getting workflows: {str(e)}"
        )


# Action discovery endpoints (must be before /{workflowId} to avoid a path conflict)
@router.get("/{instanceId}/workflows/actions", response_model=Dict[str, Any])
@limiter.limit("120/minute")
def get_all_workflow_actions(
    request: Request,
    instanceId: str = Path(..., description="Feature instance ID"),
    context: RequestContext = Depends(getRequestContext)
) -> Dict[str, Any]:
    """Get all available workflow actions for the current user (filtered by RBAC)."""
    try:
        mandateId = _validateInstanceAccess(instanceId, context)
        services = getChatplaygroundServices(
            context.user,
            mandateId=mandateId,
            featureInstanceId=instanceId,
        )
        from modules.workflows.processing.shared.methodDiscovery import discoverMethods, methods
        discoverMethods(services)
        allActions = []
        for methodName, methodInfo in methods.items():
            if methodName.startswith('Method'):
                continue
            methodInstance = methodInfo['instance']
            methodActions = methodInstance.actions
            for actionName, actionInfo in methodActions.items():
                actionResponse = {
                    "module": methodInstance.name,
                    "actionId": f"{methodInstance.name}.{actionName}",
                    "name": actionName,
                    "description": actionInfo.get('description', ''),
                    "parameters": actionInfo.get('parameters', {})
                }
                allActions.append(actionResponse)
        return {"actions": allActions}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting all actions: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Failed to get actions: {str(e)}")


@router.get("/{instanceId}/workflows/actions/{method}", response_model=Dict[str, Any])
@limiter.limit("120/minute")
def get_method_workflow_actions(
    request: Request,
    instanceId: str = Path(..., description="Feature instance ID"),
    method: str = Path(..., description="Method name (e.g., 'outlook', 'sharepoint')"),
    context: RequestContext = Depends(getRequestContext)
) -> Dict[str, Any]:
    """Get all available actions for a specific method."""
    try:
        mandateId = _validateInstanceAccess(instanceId, context)
        services = getChatplaygroundServices(
            context.user,
            mandateId=mandateId,
            featureInstanceId=instanceId,
        )
        from modules.workflows.processing.shared.methodDiscovery import discoverMethods, methods
        discoverMethods(services)
        methodInstance = None
        for methodName, methodInfo in methods.items():
            if methodInfo['instance'].name == method:
                methodInstance = methodInfo['instance']
                break
        if not methodInstance:
            raise HTTPException(status_code=404, detail=f"Method '{method}' not found")
        actions = []
        for actionName, actionInfo in methodInstance.actions.items():
            actionResponse = {
                "actionId": f"{methodInstance.name}.{actionName}",
                "name": actionName,
                "description": actionInfo.get('description', ''),
                "parameters": actionInfo.get('parameters', {})
            }
            actions.append(actionResponse)
        return {"module": methodInstance.name, "description": methodInstance.description, "actions": actions}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting actions for method {method}: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Failed to get actions for method {method}: {str(e)}")


@router.get("/{instanceId}/workflows/actions/{method}/{action}", response_model=Dict[str, Any])
@limiter.limit("120/minute")
def get_action_schema(
    request: Request,
    instanceId: str = Path(..., description="Feature instance ID"),
    method: str = Path(..., description="Method name (e.g., 'outlook', 'sharepoint')"),
    action: str = Path(..., description="Action name (e.g., 'readEmails', 'uploadDocument')"),
    context: RequestContext = Depends(getRequestContext)
) -> Dict[str, Any]:
    """Get action schema with parameter definitions for a specific action."""
    try:
        mandateId = _validateInstanceAccess(instanceId, context)
        services = getChatplaygroundServices(
            context.user,
            mandateId=mandateId,
            featureInstanceId=instanceId,
        )
        from modules.workflows.processing.shared.methodDiscovery import discoverMethods, methods
        discoverMethods(services)
        methodInstance = None
        for methodName, methodInfo in methods.items():
            if methodInfo['instance'].name == method:
                methodInstance = methodInfo['instance']
                break
        if not methodInstance:
            raise HTTPException(status_code=404, detail=f"Method '{method}' not found")
        methodActions = methodInstance.actions
        if action not in methodActions:
            raise HTTPException(status_code=404, detail=f"Action '{action}' not found in method '{method}'")
        actionInfo = methodActions[action]
        return {
            "method": methodInstance.name,
            "action": action,
            "actionId": f"{methodInstance.name}.{action}",
            "description": actionInfo.get('description', ''),
            "parameters": actionInfo.get('parameters', {})
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting action schema for {method}.{action}: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Failed to get action schema: {str(e)}")


# Get single workflow by ID
@router.get("/{instanceId}/workflows/{workflowId}", response_model=ChatWorkflow)
@limiter.limit("120/minute")
def get_workflow(
    request: Request,
    instanceId: str = Path(..., description="Feature instance ID"),
    workflowId: str = Path(..., description="ID of the workflow"),
    context: RequestContext = Depends(getRequestContext)
) -> ChatWorkflow:
    """Get workflow by ID."""
    try:
        _validateInstanceAccess(instanceId, context)
        chatInterface = _getServiceChat(context, featureInstanceId=instanceId)
        workflow = chatInterface.getWorkflow(workflowId)
        if not workflow:
            raise HTTPException(status_code=404, detail="Workflow not found")
        return workflow
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting workflow: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to get workflow: {str(e)}")


# Update workflow
@router.put("/{instanceId}/workflows/{workflowId}", response_model=ChatWorkflow)
@limiter.limit("120/minute")
def update_workflow(
    request: Request,
    instanceId: str = Path(..., description="Feature instance ID"),
    workflowId: str = Path(..., description="ID of the workflow to update"),
    workflowData: Dict[str, Any] = Body(...),
    context: RequestContext = Depends(getRequestContext)
) -> ChatWorkflow:
    """Update workflow by ID."""
    try:
        _validateInstanceAccess(instanceId, context)
        chatInterface = _getServiceChat(context, featureInstanceId=instanceId)
        workflow = chatInterface.getWorkflow(workflowId)
        if not workflow:
            raise HTTPException(status_code=404, detail="Workflow not found")
        if not chatInterface.checkRbacPermission(ChatWorkflow, "update", workflowId):
            raise HTTPException(status_code=403, detail="You don't have permission to update this workflow")
        updatedWorkflow = chatInterface.updateWorkflow(workflowId, workflowData)
        if not updatedWorkflow:
            raise HTTPException(status_code=500, detail="Failed to update workflow")
        return updatedWorkflow
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error updating workflow: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to update workflow: {str(e)}")


# Delete workflow
@router.delete("/{instanceId}/workflows/{workflowId}")
@limiter.limit("120/minute")
def delete_workflow(
    request: Request,
    instanceId: str = Path(..., description="Feature instance ID"),
    workflowId: str = Path(..., description="ID of the workflow to delete"),
    context: RequestContext = Depends(getRequestContext)
) -> Dict[str, Any]:
    """Deletes a workflow and its associated data."""
    try:
        _validateInstanceAccess(instanceId, context)
        chatInterface = _getServiceChat(context, featureInstanceId=instanceId)
        workflow = chatInterface.getWorkflow(workflowId)
        if not workflow:
            raise HTTPException(status_code=404, detail=f"Workflow with ID {workflowId} not found")
        if not chatInterface.checkRbacPermission(ChatWorkflow, "delete", workflowId):
            raise HTTPException(status_code=403, detail="You don't have permission to delete this workflow")
        success = chatInterface.deleteWorkflow(workflowId)
        if not success:
            raise HTTPException(status_code=500, detail="Failed to delete workflow")
        return {"id": workflowId, "message": "Workflow and associated data deleted successfully"}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error deleting workflow: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Error deleting workflow: {str(e)}")


# Get workflow status
@router.get("/{instanceId}/workflows/{workflowId}/status", response_model=ChatWorkflow)
@limiter.limit("120/minute")
def get_workflow_status(
    request: Request,
    instanceId: str = Path(..., description="Feature instance ID"),
    workflowId: str = Path(..., description="ID of the workflow"),
    context: RequestContext = Depends(getRequestContext)
) -> ChatWorkflow:
    """Get the current status of a workflow."""
    try:
        _validateInstanceAccess(instanceId, context)
        chatInterface = _getServiceChat(context, featureInstanceId=instanceId)
        workflow = chatInterface.getWorkflow(workflowId)
        if not workflow:
            raise HTTPException(status_code=404, detail=f"Workflow with ID {workflowId} not found")
        return workflow
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting workflow status: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Error getting workflow status: {str(e)}")


# Get workflow logs
@router.get("/{instanceId}/workflows/{workflowId}/logs", response_model=PaginatedResponse[ChatLog])
@limiter.limit("120/minute")
def get_workflow_logs(
    request: Request,
    instanceId: str = Path(..., description="Feature instance ID"),
    workflowId: str = Path(..., description="ID of the workflow"),
    logId: Optional[str] = Query(None, description="Optional log ID for selective data transfer"),
    pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"),
    context: RequestContext = Depends(getRequestContext)
) -> PaginatedResponse[ChatLog]:
    """Get logs for a workflow with optional pagination."""
    try:
        _validateInstanceAccess(instanceId, context)
        chatInterface = _getServiceChat(context, featureInstanceId=instanceId)
        workflow = chatInterface.getWorkflow(workflowId)
        if not workflow:
            raise HTTPException(status_code=404, detail=f"Workflow with ID {workflowId} not found")

        paginationParams = None
        if pagination:
            try:
                paginationDict = json.loads(pagination)
                if paginationDict:
                    paginationDict = normalize_pagination_dict(paginationDict)
                    paginationParams = PaginationParams(**paginationDict)
            except (json.JSONDecodeError, ValueError) as e:
                raise HTTPException(status_code=400, detail=f"Invalid pagination parameter: {str(e)}")

        result = chatInterface.getLogs(workflowId, pagination=paginationParams)

        if logId:
            allLogs = result.items if paginationParams else result
            logIndex = next((i for i, log in enumerate(allLogs) if log.id == logId), -1)
            if logIndex >= 0:
                filteredLogs = allLogs[logIndex + 1:]
                return PaginatedResponse(items=filteredLogs, pagination=None)

        if paginationParams:
            return PaginatedResponse(
                items=result.items,
                pagination=PaginationMetadata(
                    currentPage=paginationParams.page,
                    pageSize=paginationParams.pageSize,
                    totalItems=result.totalItems,
                    totalPages=result.totalPages,
                    sort=paginationParams.sort,
                    filters=paginationParams.filters
                )
            )
        return PaginatedResponse(items=result, pagination=None)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting workflow logs: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Error getting workflow logs: {str(e)}")


# Get workflow messages
@router.get("/{instanceId}/workflows/{workflowId}/messages", response_model=PaginatedResponse[ChatMessage])
@limiter.limit("120/minute")
def get_workflow_messages(
    request: Request,
    instanceId: str = Path(..., description="Feature instance ID"),
    workflowId: str = Path(..., description="ID of the workflow"),
    messageId: Optional[str] = Query(None, description="Optional message ID for selective data transfer"),
    pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"),
    context: RequestContext = Depends(getRequestContext)
) -> PaginatedResponse[ChatMessage]:
    """Get messages for a workflow with optional pagination."""
    try:
        _validateInstanceAccess(instanceId, context)
        chatInterface = _getServiceChat(context, featureInstanceId=instanceId)
        workflow = chatInterface.getWorkflow(workflowId)
        if not workflow:
            raise HTTPException(status_code=404, detail=f"Workflow with ID {workflowId} not found")

        paginationParams = None
        if pagination:
            try:
                paginationDict = json.loads(pagination)
                if paginationDict:
                    paginationDict = normalize_pagination_dict(paginationDict)
                    paginationParams = PaginationParams(**paginationDict)
            except (json.JSONDecodeError, ValueError) as e:
                raise HTTPException(status_code=400, detail=f"Invalid pagination parameter: {str(e)}")

        result = chatInterface.getMessages(workflowId, pagination=paginationParams)

        if messageId:
            allMessages = result.items if paginationParams else result
            messageIndex = next((i for i, msg in enumerate(allMessages) if msg.id == messageId), -1)
            if messageIndex >= 0:
                filteredMessages = allMessages[messageIndex + 1:]
                return PaginatedResponse(items=filteredMessages, pagination=None)

        if paginationParams:
            return PaginatedResponse(
                items=result.items,
                pagination=PaginationMetadata(
                    currentPage=paginationParams.page,
                    pageSize=paginationParams.pageSize,
                    totalItems=result.totalItems,
                    totalPages=result.totalPages,
                    sort=paginationParams.sort,
                    filters=paginationParams.filters
                )
            )
        return PaginatedResponse(items=result, pagination=None)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting workflow messages: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Error getting workflow messages: {str(e)}")


# Delete message from workflow
@router.delete("/{instanceId}/workflows/{workflowId}/messages/{messageId}")
@limiter.limit("120/minute")
def delete_workflow_message(
    request: Request,
    instanceId: str = Path(..., description="Feature instance ID"),
    workflowId: str = Path(..., description="ID of the workflow"),
    messageId: str = Path(..., description="ID of the message to delete"),
    context: RequestContext = Depends(getRequestContext)
) -> Dict[str, Any]:
    """Delete a message from a workflow."""
    try:
        _validateInstanceAccess(instanceId, context)
        chatInterface = _getServiceChat(context, featureInstanceId=instanceId)
        workflow = chatInterface.getWorkflow(workflowId)
        if not workflow:
            raise HTTPException(status_code=404, detail=f"Workflow with ID {workflowId} not found")
        success = chatInterface.deleteMessage(workflowId, messageId)
        if not success:
            raise HTTPException(status_code=404, detail=f"Message with ID {messageId} not found in workflow {workflowId}")
        return {"workflowId": workflowId, "messageId": messageId, "message": "Message deleted successfully"}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error deleting message: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Error deleting message: {str(e)}")


# Delete file from message
@router.delete("/{instanceId}/workflows/{workflowId}/messages/{messageId}/files/{fileId}")
@limiter.limit("120/minute")
def delete_file_from_message(
    request: Request,
    instanceId: str = Path(..., description="Feature instance ID"),
    workflowId: str = Path(..., description="ID of the workflow"),
    messageId: str = Path(..., description="ID of the message"),
    fileId: str = Path(..., description="ID of the file to delete"),
    context: RequestContext = Depends(getRequestContext)
) -> Dict[str, Any]:
    """Delete a file reference from a message in a workflow."""
    try:
        _validateInstanceAccess(instanceId, context)
        chatInterface = _getServiceChat(context, featureInstanceId=instanceId)
        workflow = chatInterface.getWorkflow(workflowId)
        if not workflow:
            raise HTTPException(status_code=404, detail=f"Workflow with ID {workflowId} not found")
        success = chatInterface.deleteFileFromMessage(workflowId, messageId, fileId)
        if not success:
            raise HTTPException(status_code=404, detail=f"File with ID {fileId} not found in message {messageId}")
        return {"workflowId": workflowId, "messageId": messageId, "fileId": fileId, "message": "File reference deleted successfully"}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error deleting file reference: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Error deleting file reference: {str(e)}")
@ -1 +0,0 @@
"""CodeEditor Feature - Cursor-style AI file editing via chat interface."""
@ -1,280 +0,0 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""CodeEditor processor -- single-shot (Phase 1) and agent loop (Phase 2).
Orchestrates file loading, prompt building, AI calls, response parsing, and SSE emission."""

import logging
from typing import List, Dict, Any

from modules.features.codeeditor import fileContextManager, promptAssembly, responseParser
from modules.features.codeeditor.datamodelCodeeditor import (
    FileEditProposal, SegmentTypeEnum, AgentState
)
from modules.features.codeeditor import toolRegistry
from modules.shared.timeUtils import getUtcTimestamp

logger = logging.getLogger(__name__)


async def processMessage(
    workflowId: str,
    userPrompt: str,
    selectedFileIds: List[str],
    dbManagement,
    interfaceAi,
    chatInterface,
    eventManager,
    agentMode: bool = False
):
    """Process a user message. Dispatches to single-shot or agent loop based on mode."""
    if agentMode:
        await _processAgentMessage(
            workflowId, userPrompt, dbManagement, interfaceAi, chatInterface, eventManager
        )
    else:
        await _processSingleShot(
            workflowId, userPrompt, selectedFileIds, dbManagement, interfaceAi, chatInterface, eventManager
        )


async def _processSingleShot(
    workflowId, userPrompt, selectedFileIds, dbManagement, interfaceAi, chatInterface, eventManager
):
    """Phase 1: Single AI call with pre-loaded file context."""
    try:
        await _emitStatus(eventManager, workflowId, "Loading files...")
        fileContexts = await fileContextManager.loadFileContexts(dbManagement, selectedFileIds)

        await _emitStatus(eventManager, workflowId, "Building prompt...")
        chatHistory = _loadChatHistory(chatInterface, workflowId)
        aiRequest = promptAssembly.buildRequest(userPrompt, fileContexts, chatHistory)

        await _emitStatus(eventManager, workflowId, "AI is processing...")
        aiResponse = await interfaceAi.callWithTextContext(aiRequest)

        if aiResponse.errorCount > 0:
            await _emitError(eventManager, workflowId, aiResponse.content)
            return

        segments = responseParser.parseResponse(aiResponse.content)
        await _emitSegments(eventManager, workflowId, segments, fileContexts)
        _logAiStats(aiResponse, workflowId)

        await eventManager.emit_event(workflowId, "complete", {
            "workflowId": workflowId,
            "modelName": aiResponse.modelName,
            "priceCHF": aiResponse.priceCHF,
            "processingTime": aiResponse.processingTime
        })

    except Exception as e:
        logger.error(f"CodeEditor single-shot failed for {workflowId}: {e}", exc_info=True)
        await eventManager.emit_event(workflowId, "error", {
            "workflowId": workflowId, "error": str(e)
        })


async def _processAgentMessage(
    workflowId, userPrompt, dbManagement, interfaceAi, chatInterface, eventManager
):
    """Phase 2: Agent loop -- multiple AI calls with tool execution until done."""
    state = AgentState(workflowId=workflowId)

    try:
        await _emitStatus(eventManager, workflowId, "Agent: Scanning available files...")
        fileListContext = fileContextManager.buildFileListContext(dbManagement)

        state.conversationHistory.append({"role": "user", "content": userPrompt})

        aiRequest = promptAssembly.buildAgentRequest(
            userPrompt=userPrompt,
            fileListContext=fileListContext,
            conversationHistory=[]
        )

        while state.status == "running" and state.currentRound < state.maxRounds:
            state.currentRound += 1
            state.totalAiCalls += 1

            await _emitStatus(eventManager, workflowId,
                              f"Agent round {state.currentRound}: AI is thinking...")

            await eventManager.emit_event(workflowId, "chatdata", {
                "type": "agent_progress",
                "item": {
                    "round": state.currentRound,
                    "totalAiCalls": state.totalAiCalls,
                    "totalToolCalls": state.totalToolCalls,
                    "costCHF": round(state.totalCostCHF, 4),
                }
            })

            aiResponse = await interfaceAi.callWithTextContext(aiRequest)
            state.totalCostCHF += aiResponse.priceCHF
            state.totalProcessingTime += aiResponse.processingTime

            if aiResponse.errorCount > 0:
                logger.error(f"Agent AI call failed in round {state.currentRound}: {aiResponse.content}")
                await _emitError(eventManager, workflowId, aiResponse.content)
                state.status = "error"
                break

            _logAiStats(aiResponse, workflowId)

            state.conversationHistory.append({"role": "assistant", "content": aiResponse.content})

            segments = responseParser.parseResponse(aiResponse.content)

            textAndEditSegments = [s for s in segments if s.type != SegmentTypeEnum.TOOL_CALL]
            if textAndEditSegments:
                await _emitSegments(eventManager, workflowId, textAndEditSegments, [])

            toolCallSegments = [s for s in segments if s.type == SegmentTypeEnum.TOOL_CALL]

            if not toolCallSegments:
                state.status = "completed"
                break

            toolResultTexts = []
            for tc in toolCallSegments:
                state.totalToolCalls += 1
                await _emitStatus(eventManager, workflowId,
                                  f"Agent: Running {tc.toolName}...")

                result = await toolRegistry.dispatch(tc.toolName, tc.toolArgs or {}, dbManagement)
                toolResultTexts.append(f"[{tc.toolName}] (success={result.success}):\n{result.result}")

                logger.info(f"Agent tool {tc.toolName}: success={result.success}, time={result.executionTime:.2f}s")

            combinedResults = "\n\n".join(toolResultTexts)
            state.conversationHistory.append({
                "role": "tool_result",
                "content": combinedResults,
                "toolName": "batch"
            })
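
            # After this append, conversationHistory has roughly this shape (illustrative):
            #   [{"role": "user", "content": "..."},
            #    {"role": "assistant", "content": "... tool call segments ..."},
            #    {"role": "tool_result", "content": "[toolName] (success=True): ...", "toolName": "batch"}]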

            aiRequest = promptAssembly.buildAgentRequest(
                userPrompt=None,
                fileListContext=fileListContext,
                conversationHistory=state.conversationHistory
            )

        if state.currentRound >= state.maxRounds and state.status == "running":
            state.status = "max_rounds"
            await eventManager.emit_event(workflowId, "chatdata", {
                "type": "message",
                "item": {
                    "role": "system",
                    "content": f"Agent stopped: maximum rounds ({state.maxRounds}) reached.",
                    "createdAt": getUtcTimestamp()
                }
            })

        await eventManager.emit_event(workflowId, "chatdata", {
            "type": "agent_summary",
            "item": {
                "rounds": state.currentRound,
                "totalAiCalls": state.totalAiCalls,
                "totalToolCalls": state.totalToolCalls,
                "costCHF": round(state.totalCostCHF, 4),
                "processingTime": round(state.totalProcessingTime, 1),
                "status": state.status,
            }
        })

        await eventManager.emit_event(workflowId, "complete", {
            "workflowId": workflowId,
            "agentRounds": state.currentRound,
            "totalCostCHF": round(state.totalCostCHF, 4),
            "processingTime": round(state.totalProcessingTime, 1)
        })

    except Exception as e:
        logger.error(f"CodeEditor agent loop failed for {workflowId}: {e}", exc_info=True)
        await eventManager.emit_event(workflowId, "error", {
            "workflowId": workflowId, "error": str(e)
        })


# ---------------------------------------------------------------------------
# Shared helpers
# ---------------------------------------------------------------------------

async def _emitStatus(eventManager, workflowId: str, label: str):
    await eventManager.emit_event(workflowId, "chatdata", {
        "type": "status", "label": label
    })


async def _emitError(eventManager, workflowId: str, errorMsg: str):
    await eventManager.emit_event(workflowId, "chatdata", {
        "type": "message",
        "item": {"role": "assistant", "content": f"Error: {errorMsg}"}
    })
    await eventManager.emit_event(workflowId, "error", {
        "workflowId": workflowId, "error": errorMsg
    })


async def _emitSegments(eventManager, workflowId: str, segments, fileContexts):
    """Emit parsed segments as SSE events."""
    for segment in segments:
        messageData = {
            "role": "assistant",
            "content": segment.content,
            "type": segment.type.value,
            "createdAt": getUtcTimestamp()
        }
        await eventManager.emit_event(workflowId, "chatdata", {
            "type": "message", "item": messageData
        })

        if segment.type == SegmentTypeEnum.FILE_EDIT:
            proposal = FileEditProposal(
                workflowId=workflowId,
                fileId=_resolveFileId(segment.fileName, fileContexts),
                fileName=segment.fileName,
                operation="edit",
                oldContent=segment.oldContent,
                newContent=segment.newContent
            )
            await eventManager.emit_event(workflowId, "chatdata", {
                "type": "file_edit_proposal", "item": proposal.model_dump()
            })


def _loadChatHistory(chatInterface, workflowId: str) -> List[Dict[str, Any]]:
    """Load recent chat messages for multi-turn context."""
    try:
        messages = chatInterface.getMessages(workflowId)
        if not messages:
            return []
        history = []
        for msg in messages:
            role = msg.get("role", "unknown") if isinstance(msg, dict) else getattr(msg, "role", "unknown")
            content = msg.get("content", "") if isinstance(msg, dict) else getattr(msg, "content", "")
            history.append({"role": role, "content": content})
        return history
    except Exception as e:
        logger.warning(f"Could not load chat history for {workflowId}: {e}")
        return []


def _resolveFileId(fileName: str, fileContexts) -> str:
    """Resolve a fileName to its fileId from the loaded contexts."""
    for fc in fileContexts:
        if fc.fileName == fileName:
            return fc.fileId
    return f"unknown-{fileName}"
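
# Illustrative behavior (assuming a loaded FileContext with fileId="f1", fileName="main.py"):
#   _resolveFileId("main.py", fileContexts)  -> "f1"
#   _resolveFileId("missing.py", [])         -> "unknown-missing.py"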


def _logAiStats(aiResponse, workflowId: str):
    """Log AI call statistics."""
    logger.info(
        f"CodeEditor AI call for {workflowId}: "
        f"model={aiResponse.modelName}, "
        f"provider={aiResponse.provider}, "
        f"cost={aiResponse.priceCHF:.4f} CHF, "
        f"time={aiResponse.processingTime:.1f}s, "
        f"sent={aiResponse.bytesSent}B, received={aiResponse.bytesReceived}B"
    )
@ -1,122 +0,0 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Data models for the CodeEditor feature."""

from typing import List, Dict, Any, Optional
from enum import Enum
from pydantic import BaseModel, Field
from modules.shared.timeUtils import getUtcTimestamp
import uuid


class SegmentTypeEnum(str, Enum):
    TEXT = "text"
    CODE_BLOCK = "code_block"
    FILE_EDIT = "file_edit"
    TOOL_CALL = "tool_call"


class EditStatusEnum(str, Enum):
    PENDING = "pending"
    ACCEPTED = "accepted"
    REJECTED = "rejected"


class FileContext(BaseModel):
    """A text file loaded as context for the AI."""
    fileId: str
    fileName: str
    content: Optional[str] = None
    mimeType: str
    sizeBytes: int = 0
    modifiedAt: Optional[float] = None
    tags: List[str] = Field(default_factory=list)


class ResponseSegment(BaseModel):
    """A parsed segment from the AI response."""
    type: SegmentTypeEnum
    content: str
    language: Optional[str] = None
    fileId: Optional[str] = None
    fileName: Optional[str] = None
    oldContent: Optional[str] = None
    newContent: Optional[str] = None
    toolName: Optional[str] = None
    toolArgs: Optional[Dict[str, Any]] = None


class FileEditProposal(BaseModel):
    """A proposed file edit from the AI, awaiting user accept/reject."""
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    workflowId: str
    fileId: str
    fileName: str
    operation: str = "edit"
    oldContent: Optional[str] = None
    newContent: str
    diffSummary: Optional[str] = None
    status: EditStatusEnum = EditStatusEnum.PENDING
    createdAt: float = Field(default_factory=getUtcTimestamp)


class FileVersion(BaseModel):
    """A new version of a file created after accepting an edit proposal."""
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    sourceFileId: str
    editProposalId: str
    newFileId: str
    createdAt: float = Field(default_factory=getUtcTimestamp)


class AgentState(BaseModel):
    """Tracks state across an agent loop execution."""
    workflowId: str
    currentRound: int = 0
    maxRounds: int = 50
    totalAiCalls: int = 0
    totalToolCalls: int = 0
    totalCostCHF: float = 0.0
    totalProcessingTime: float = 0.0
    conversationHistory: List[Dict[str, Any]] = Field(default_factory=list)
    status: str = "running"
|
||||
|
||||
|
||||
class ToolResult(BaseModel):
|
||||
"""Result from executing a tool."""
|
||||
toolName: str
|
||||
result: str
|
||||
success: bool = True
|
||||
executionTime: float = 0.0
|
||||
|
||||
|
||||
TEXT_MIME_TYPES = {
|
||||
"text/plain", "text/markdown", "text/html", "text/css", "text/csv",
|
||||
"text/xml", "text/yaml", "text/x-python", "text/x-java",
|
||||
"text/javascript", "text/x-typescript", "text/x-sql",
|
||||
"application/json", "application/xml", "application/yaml",
|
||||
"application/x-yaml", "application/javascript",
|
||||
}
|
||||
|
||||
TEXT_EXTENSIONS = {
|
||||
".md", ".txt", ".json", ".yaml", ".yml", ".xml", ".csv",
|
||||
".py", ".js", ".ts", ".tsx", ".jsx", ".html", ".htm", ".css", ".scss",
|
||||
".sql", ".sh", ".bash", ".zsh", ".ps1", ".bat",
|
||||
".toml", ".ini", ".cfg", ".conf", ".env", ".gitignore",
|
||||
".dockerfile", ".docker-compose", ".makefile",
|
||||
".java", ".kt", ".go", ".rs", ".rb", ".php", ".swift", ".c", ".cpp", ".h",
|
||||
".r", ".lua", ".dart", ".vue", ".svelte",
|
||||
}
|
||||
|
||||
|
||||
def isTextFile(mimeType: Optional[str], fileName: Optional[str] = None) -> bool:
|
||||
"""Check if a file is a text-based file suitable for the editor."""
|
||||
if mimeType and mimeType.lower() in TEXT_MIME_TYPES:
|
||||
return True
|
||||
if mimeType and mimeType.lower().startswith("text/"):
|
||||
return True
|
||||
if fileName:
|
||||
ext = "." + fileName.rsplit(".", 1)[-1].lower() if "." in fileName else ""
|
||||
if ext in TEXT_EXTENSIONS:
|
||||
return True
|
||||
return False
|
||||
|
|
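A minimal usage sketch of the models above; the field values are illustrative assumptions.

```python
# MIME check first, then extension fallback
assert isTextFile("text/x-python", "main.py") is True
assert isTextFile(None, "notes.md") is True      # falls back to TEXT_EXTENSIONS
assert isTextFile("image/png", "logo.png") is False

proposal = FileEditProposal(
    workflowId="wf-1", fileId="f-1", fileName="main.py",
    oldContent="x = 1", newContent="x = 2",
)
assert proposal.status == EditStatusEnum.PENDING  # defaults filled in by pydantic
```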
@ -1,84 +0,0 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""File context manager for CodeEditor feature.
Loads text files from the database and provides them as context for AI calls."""

import logging
from typing import List, Optional

from modules.features.codeeditor.datamodelCodeeditor import FileContext, isTextFile

logger = logging.getLogger(__name__)


async def loadFileContexts(dbManagement, fileIds: List[str]) -> List[FileContext]:
    """Load text files from DB and return as FileContext list.

    Args:
        dbManagement: interfaceDbManagement instance with user context set
        fileIds: list of file IDs to load
    """
    contexts = []
    for fileId in fileIds:
        fileItem = dbManagement.getFile(fileId)
        if not fileItem:
            logger.warning(f"File {fileId} not found or no access")
            continue

        if not isTextFile(fileItem.mimeType, fileItem.fileName):
            logger.warning(f"File {fileItem.fileName} ({fileItem.mimeType}) is not a text file, skipping")
            continue

        fileData = dbManagement.getFileData(fileId)
        if not fileData:
            logger.warning(f"No data for file {fileId}")
            continue

        try:
            content = fileData.decode("utf-8")
        except UnicodeDecodeError:
            logger.warning(f"File {fileItem.fileName} is not valid UTF-8, skipping")
            continue

        contexts.append(FileContext(
            fileId=fileId,
            fileName=fileItem.fileName,
            content=content,
            mimeType=fileItem.mimeType,
            sizeBytes=fileItem.fileSize
        ))

    logger.info(f"Loaded {len(contexts)} file contexts from {len(fileIds)} requested")
    return contexts


def listTextFiles(dbManagement) -> List[FileContext]:
    """List all text files accessible to the user (metadata only, no content)."""
    allFiles = dbManagement.getAllFiles()
    textFiles = []

    if not allFiles:
        return textFiles

    for fileItem in allFiles:
        if isTextFile(fileItem.mimeType, fileItem.fileName):
            modifiedAt = getattr(fileItem, "_modifiedAt", None) or getattr(fileItem, "creationDate", None)
            textFiles.append(FileContext(
                fileId=fileItem.id,
                fileName=fileItem.fileName,
                content=None,
                mimeType=fileItem.mimeType,
                sizeBytes=fileItem.fileSize,
                modifiedAt=modifiedAt
            ))

    return textFiles


def buildFileListContext(dbManagement) -> str:
    """Build a compact file list string for the agent prompt (no content, just metadata)."""
    textFiles = listTextFiles(dbManagement)
    if not textFiles:
        return "No text files available."
    lines = [f"- {f.fileName} (id: {f.fileId}, size: {f.sizeBytes}B)" for f in textFiles]
    return f"Total: {len(lines)} text files\n" + "\n".join(lines)
@ -1,183 +0,0 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Prompt assembly for the CodeEditor feature.
Builds Cursor-style system prompts with file context and format instructions."""

import logging
from typing import List, Optional, Dict, Any

from modules.datamodels.datamodelAi import AiCallRequest, AiCallOptions, OperationTypeEnum
from modules.features.codeeditor.datamodelCodeeditor import FileContext

logger = logging.getLogger(__name__)

SYSTEM_PROMPT = """You are an AI assistant for text and code file editing. You receive files as context and can suggest changes.

## Rules for file edits
- Use ```file_edit``` blocks for file changes
- Each file_edit block must contain: fileName, oldContent (exact text to replace), newContent (replacement text)
- Explain changes in normal text before or after the block
- oldContent must EXACTLY match existing content (including whitespace and indentation)
- You may propose edits to multiple files in one response

## Response format
Normal text is displayed as explanation.
File changes must use this format:

```file_edit
fileName: <filename>
oldContent: |
  <exact existing content to replace>
newContent: |
  <new replacement content>
```

Code examples (without edits) use standard markdown code blocks:
```language
code here
```

## Important
- Only edit files that are provided in context
- Make minimal, targeted changes
- Preserve existing formatting and style
- If a task is unclear, ask for clarification instead of guessing"""


def buildRequest(
    userPrompt: str,
    fileContexts: List[FileContext],
    chatHistory: Optional[List[Dict[str, Any]]] = None
) -> AiCallRequest:
    """Build an AiCallRequest with system prompt, file context, and user prompt."""
    systemPart = SYSTEM_PROMPT
    fileContextPart = _buildFileContext(fileContexts)
    historyPart = _buildChatHistory(chatHistory) if chatHistory else ""

    fullPrompt = systemPart
    if historyPart:
        fullPrompt += f"\n\n## Previous conversation\n{historyPart}"
    fullPrompt += f"\n\n## User request\n{userPrompt}"

    return AiCallRequest(
        prompt=fullPrompt,
        context=fileContextPart if fileContextPart else None,
        options=AiCallOptions(
            operationType=OperationTypeEnum.DATA_ANALYSE,
            temperature=0.0,
            compressPrompt=False,
            compressContext=False,
            resultFormat="txt"
        )
    )


def _buildFileContext(fileContexts: List[FileContext]) -> str:
    """Build the file context string with line numbers."""
    if not fileContexts:
        return ""

    parts = []
    for fc in fileContexts:
        if not fc.content:
            continue
        lines = fc.content.split("\n")
        numberedLines = [f"{i + 1}|{line}" for i, line in enumerate(lines)]
        numbered = "\n".join(numberedLines)
        parts.append(f"--- FILE: {fc.fileName} ---\n{numbered}\n--- END FILE ---")

    return "\n\n".join(parts)


def buildAgentRequest(
    userPrompt: Optional[str],
    fileListContext: str,
    conversationHistory: List[Dict[str, Any]]
) -> AiCallRequest:
    """Build an AiCallRequest for agent mode with tool definitions and conversation history."""
    from modules.features.codeeditor.toolRegistry import formatToolDefinitions

    systemPrompt = _AGENT_SYSTEM_PROMPT.replace("{{TOOL_DEFINITIONS}}", formatToolDefinitions())

    if not conversationHistory:
        fullPrompt = systemPrompt
        context = f"## Available files\n{fileListContext}\n\n## Task\n{userPrompt}"
    else:
        fullPrompt = systemPrompt
        historyText = _buildConversationHistory(conversationHistory)
        context = f"## Available files\n{fileListContext}\n\n## Conversation\n{historyText}"

    return AiCallRequest(
        prompt=fullPrompt,
        context=context,
        options=AiCallOptions(
            operationType=OperationTypeEnum.DATA_ANALYSE,
            temperature=0.0,
            compressPrompt=False,
            compressContext=False,
            resultFormat="txt"
        )
    )


_AGENT_SYSTEM_PROMPT = """You are an AI agent for file analysis and editing. You work autonomously by using tools to read files, search content, and propose edits.

## Available tools
{{TOOL_DEFINITIONS}}

## How to call tools
Use this exact format for each tool call:

```tool_call
tool: <tool_name>
args: {"param": "value"}
```

## Rules
- Read files ONE AT A TIME with read_file, never assume file contents
- First create a plan, then execute it step by step
- Use search_files to find relevant files before reading them
- Use list_files to discover what files are available
- For file changes, use ```file_edit``` blocks (same format as before)
- You may combine text explanations, tool calls, and file edits in one response
- When you are DONE and need no more tool calls, simply respond with text only (no tool_call blocks)
- Keep responses focused and efficient

## file_edit format (for changes)
```file_edit
fileName: <filename>
oldContent: |
  <exact existing content>
newContent: |
  <replacement content>
```"""


def _buildConversationHistory(history: List[Dict[str, Any]]) -> str:
    """Build the full conversation history for agent multi-turn context."""
    parts = []
    for msg in history:
        role = msg.get("role", "unknown")
        content = msg.get("content", "")
        if role == "tool_result":
            toolName = msg.get("toolName", "")
            parts.append(f"[Tool Result - {toolName}]:\n{content}")
        else:
            parts.append(f"[{role}]:\n{content}")
    return "\n\n".join(parts)


def _buildChatHistory(chatHistory: List[Dict[str, Any]]) -> str:
    """Build a condensed chat history string for multi-turn context."""
    if not chatHistory:
        return ""

    parts = []
    for msg in chatHistory[-10:]:
        role = msg.get("role", "unknown")
        content = msg.get("content", "")
        if len(content) > 500:
            content = content[:500] + "..."
        parts.append(f"[{role}]: {content}")

    return "\n".join(parts)
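A quick sketch of the numbered file-context format `_buildFileContext` produces; the file name and content are made up for illustration.

```python
fc = FileContext(fileId="f-1", fileName="hello.py",
                 content="print('hi')\nprint('bye')",
                 mimeType="text/x-python")
print(_buildFileContext([fc]))
# --- FILE: hello.py ---
# 1|print('hi')
# 2|print('bye')
# --- END FILE ---
```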
@ -1,184 +0,0 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Response parser for the CodeEditor feature.
Parses AI responses into typed segments (text, code_block, file_edit, tool_call)."""

import logging
import json
import re
from typing import List, Optional

from modules.features.codeeditor.datamodelCodeeditor import ResponseSegment, SegmentTypeEnum

logger = logging.getLogger(__name__)

_FENCE_PATTERN = re.compile(r"^```(\w*)\s*$", re.MULTILINE)


def parseResponse(rawContent: str) -> List[ResponseSegment]:
    """Parse an AI response into typed segments."""
    if not rawContent or not rawContent.strip():
        return []

    segments = []
    lines = rawContent.split("\n")
    i = 0

    textBuffer = []

    while i < len(lines):
        line = lines[i]

        match = _FENCE_PATTERN.match(line)
        if match:
            if textBuffer:
                _flushTextBuffer(textBuffer, segments)
                textBuffer = []

            lang = match.group(1).strip()
            blockLines, endIdx = _collectBlock(lines, i + 1)
            blockContent = "\n".join(blockLines)

            if lang == "file_edit":
                segment = _parseFileEditBlock(blockContent)
                if segment:
                    segments.append(segment)
                else:
                    segments.append(ResponseSegment(
                        type=SegmentTypeEnum.CODE_BLOCK,
                        content=blockContent,
                        language="text"
                    ))
            elif lang == "tool_call":
                segment = _parseToolCallBlock(blockContent)
                if segment:
                    segments.append(segment)
                else:
                    segments.append(ResponseSegment(
                        type=SegmentTypeEnum.CODE_BLOCK,
                        content=blockContent,
                        language="text"
                    ))
            else:
                segments.append(ResponseSegment(
                    type=SegmentTypeEnum.CODE_BLOCK,
                    content=blockContent,
                    language=lang or "text"
                ))

            i = endIdx + 1
        else:
            textBuffer.append(line)
            i += 1

    if textBuffer:
        _flushTextBuffer(textBuffer, segments)

    return segments


def hasToolCalls(segments: List[ResponseSegment]) -> bool:
    """Check if any segments contain tool calls."""
    return any(s.type == SegmentTypeEnum.TOOL_CALL for s in segments)


def _collectBlock(lines: List[str], startIdx: int) -> tuple:
    """Collect lines inside a fenced code block until closing ```."""
    blockLines = []
    idx = startIdx
    while idx < len(lines):
        if lines[idx].strip() == "```":
            return blockLines, idx
        blockLines.append(lines[idx])
        idx += 1
    return blockLines, idx


def _flushTextBuffer(buffer: List[str], segments: List[ResponseSegment]):
    """Flush accumulated text lines into a text segment."""
    text = "\n".join(buffer).strip()
    buffer.clear()
    if text:
        segments.append(ResponseSegment(
            type=SegmentTypeEnum.TEXT,
            content=text
        ))


def _parseFileEditBlock(blockContent: str) -> Optional[ResponseSegment]:
    """Parse a file_edit block into a ResponseSegment with fileName, oldContent, newContent."""
    fields = {"fileName": None, "oldContent": None, "newContent": None}
    currentField = None
    currentLines = []

    for line in blockContent.split("\n"):
        stripped = line.strip()

        newField = None
        for key in ("fileName", "oldContent", "newContent"):
            if stripped.startswith(f"{key}:"):
                newField = key
                break

        if newField:
            if currentField and currentLines:
                fields[currentField] = "\n".join(currentLines)
            currentField = newField
            value = stripped[len(f"{newField}:"):].strip()
            if newField == "fileName":
                fields["fileName"] = value if value else None
                currentField = None
                currentLines = []
            else:
                currentLines = [value] if value and value != "|" else []
        else:
            if currentField in ("oldContent", "newContent"):
                dedented = line[2:] if line.startswith("  ") else line
                currentLines.append(dedented)

    if currentField and currentLines:
        fields[currentField] = "\n".join(currentLines)

    if not fields["fileName"]:
        logger.warning("file_edit block missing fileName")
        return None
    if fields["newContent"] is None:
        logger.warning(f"file_edit block for {fields['fileName']} missing newContent")
        return None

    return ResponseSegment(
        type=SegmentTypeEnum.FILE_EDIT,
        content=f"Edit: {fields['fileName']}",
        fileName=fields["fileName"],
        oldContent=fields["oldContent"],
        newContent=fields["newContent"]
    )


def _parseToolCallBlock(blockContent: str) -> Optional[ResponseSegment]:
    """Parse a tool_call block into a ResponseSegment with toolName and toolArgs."""
    toolName = None
    toolArgs = {}

    for line in blockContent.split("\n"):
        stripped = line.strip()
        if stripped.startswith("tool:"):
            toolName = stripped[len("tool:"):].strip()
        elif stripped.startswith("args:"):
            argsStr = stripped[len("args:"):].strip()
            try:
                toolArgs = json.loads(argsStr)
            except json.JSONDecodeError:
                logger.warning(f"Could not parse tool args as JSON: {argsStr}")
                toolArgs = {"raw": argsStr}

    if not toolName:
        logger.warning("tool_call block missing tool name")
        return None

    return ResponseSegment(
        type=SegmentTypeEnum.TOOL_CALL,
        content=f"Tool: {toolName}",
        toolName=toolName,
        toolArgs=toolArgs
    )
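A minimal sketch of the parser in action on a synthetic AI response (the file name and contents are assumptions for illustration):

```python
raw = (
    "Here is the fix.\n"
    "```file_edit\n"
    "fileName: main.py\n"
    "oldContent: |\n"
    "  x = 1\n"
    "newContent: |\n"
    "  x = 2\n"
    "```\n"
)
segments = parseResponse(raw)
# Leading prose becomes a TEXT segment, the fenced block a FILE_EDIT segment
assert [s.type for s in segments] == [SegmentTypeEnum.TEXT, SegmentTypeEnum.FILE_EDIT]
assert segments[1].fileName == "main.py" and segments[1].newContent == "x = 2"
```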
@ -1,395 +0,0 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
CodeEditor Feature Routes.
SSE-based endpoints for Cursor-style AI file editing.
"""

import logging
import json
import asyncio
from typing import Optional, Dict, Any, List

from fastapi import APIRouter, HTTPException, Depends, Body, Path, Query, Request
from fastapi.responses import StreamingResponse

from modules.auth import limiter, getRequestContext, RequestContext
from modules.interfaces import interfaceDbChat, interfaceDbManagement
from modules.interfaces.interfaceAiObjects import AiObjects
from modules.datamodels.datamodelChat import UserInputRequest
from modules.services.serviceStreaming import get_event_manager
from modules.features.codeeditor import codeEditorProcessor, fileContextManager
from modules.features.codeeditor.datamodelCodeeditor import FileEditProposal, EditStatusEnum

logger = logging.getLogger(__name__)

router = APIRouter(
    prefix="/api/codeeditor",
    tags=["Code Editor Feature"],
    responses={404: {"description": "Not found"}}
)

_aiObjects: Optional[AiObjects] = None


async def _getAiObjects() -> AiObjects:
    """Lazy-init singleton for AiObjects."""
    global _aiObjects
    if _aiObjects is None:
        _aiObjects = await AiObjects.create()
    return _aiObjects


def _getServiceChat(context: RequestContext, featureInstanceId: str = None):
    """Get chat interface with feature instance context."""
    return interfaceDbChat.getInterface(
        context.user,
        mandateId=str(context.mandateId) if context.mandateId else None,
        featureInstanceId=featureInstanceId
    )


def _getDbManagement(context: RequestContext, featureInstanceId: str = None):
    """Get management interface with user context for file access."""
    return interfaceDbManagement.getInterface(
        context.user,
        mandateId=str(context.mandateId) if context.mandateId else None,
        featureInstanceId=featureInstanceId
    )


def _validateInstanceAccess(instanceId: str, context: RequestContext) -> str:
    """Validate user has access to the feature instance. Returns mandateId."""
    from modules.interfaces.interfaceDbApp import getRootInterface

    rootInterface = getRootInterface()
    instance = rootInterface.getFeatureInstance(instanceId)
    if not instance:
        raise HTTPException(status_code=404, detail=f"Feature instance {instanceId} not found")

    featureAccess = rootInterface.getFeatureAccess(str(context.user.id), instanceId)
    if not featureAccess or not featureAccess.enabled:
        raise HTTPException(status_code=403, detail="Access denied to this feature instance")

    return str(instance.mandateId) if instance.mandateId else None


@router.post("/{instanceId}/start/stream")
@limiter.limit("60/minute")
async def streamCodeeditorStart(
    request: Request,
    instanceId: str = Path(..., description="Feature instance ID"),
    workflowId: Optional[str] = Query(None, description="Optional workflow ID to continue"),
    mode: str = Query("simple", description="Processing mode: 'simple' (single AI call) or 'agent' (multi-step with tools)"),
    userInput: UserInputRequest = Body(...),
    context: RequestContext = Depends(getRequestContext)
):
    """Start or continue a CodeEditor workflow with SSE streaming. Supports simple and agent mode."""
    try:
        mandateId = _validateInstanceAccess(instanceId, context)
        chatInterface = _getServiceChat(context, featureInstanceId=instanceId)
        dbManagement = _getDbManagement(context, featureInstanceId=instanceId)
        aiObjects = await _getAiObjects()
        eventManager = get_event_manager()

        if workflowId:
            workflow = chatInterface.getWorkflow(workflowId)
            if not workflow:
                raise HTTPException(status_code=404, detail=f"Workflow {workflowId} not found")
        else:
            workflow = chatInterface.createWorkflow({
                "workflowMode": "CodeEditor",
                "status": "running",
                "label": userInput.prompt[:80] if userInput.prompt else "CodeEditor Session",
            })
            workflowId = workflow.get("id") if isinstance(workflow, dict) else workflow.id

        queue = eventManager.create_queue(workflowId)

        userMessage = {
            "role": "user",
            "content": userInput.prompt,
            "selectedFiles": userInput.listFileId or []
        }
        await eventManager.emit_event(workflowId, "chatdata", {
            "type": "message", "item": userMessage
        })

        selectedFileIds = userInput.listFileId or []

        agentMode = mode.lower() == "agent"

        asyncio.create_task(
            codeEditorProcessor.processMessage(
                workflowId=workflowId,
                userPrompt=userInput.prompt,
                selectedFileIds=selectedFileIds,
                dbManagement=dbManagement,
                interfaceAi=aiObjects,
                chatInterface=chatInterface,
                eventManager=eventManager,
                agentMode=agentMode
            )
        )

        async def _eventStream():
            streamTimeout = 300
            lastActivity = asyncio.get_event_loop().time()

            while True:
                now = asyncio.get_event_loop().time()
                if now - lastActivity > streamTimeout:
                    yield f"data: {json.dumps({'type': 'error', 'error': 'Stream timeout'})}\n\n"
                    break

                if await request.is_disconnected():
                    logger.info(f"Client disconnected for workflow {workflowId}")
                    break

                try:
                    event = await asyncio.wait_for(queue.get(), timeout=1.0)
                    lastActivity = asyncio.get_event_loop().time()

                    eventType = event.get("type", "")

                    if eventType == "chatdata":
                        yield f"data: {json.dumps(event.get('data', {}))}\n\n"
                    elif eventType in ("complete", "stopped", "error"):
                        yield f"data: {json.dumps({'type': eventType, **event.get('data', {})})}\n\n"
                        break
                    else:
                        yield f"data: {json.dumps(event)}\n\n"

                except asyncio.TimeoutError:
                    yield ": keepalive\n\n"

            await eventManager.cleanup(workflowId)

        return StreamingResponse(
            _eventStream(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no"
            }
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in streamCodeeditorStart: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/{instanceId}/{workflowId}/stop")
@limiter.limit("120/minute")
async def stopWorkflow(
    request: Request,
    instanceId: str = Path(..., description="Feature instance ID"),
    workflowId: str = Path(..., description="Workflow ID"),
    context: RequestContext = Depends(getRequestContext)
):
    """Stop a running CodeEditor workflow."""
    try:
        _validateInstanceAccess(instanceId, context)
        eventManager = get_event_manager()
        await eventManager.emit_event(workflowId, "stopped", {
            "workflowId": workflowId
        }, event_category="workflow", message="Workflow stopped by user")
        return {"status": "stopped", "workflowId": workflowId}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error stopping workflow: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@router.get("/{instanceId}/{workflowId}/chatData")
@limiter.limit("120/minute")
def getWorkflowChatData(
    request: Request,
    instanceId: str = Path(..., description="Feature instance ID"),
    workflowId: str = Path(..., description="Workflow ID"),
    afterTimestamp: Optional[float] = Query(None, description="Unix timestamp for incremental fetch"),
    context: RequestContext = Depends(getRequestContext)
) -> Dict[str, Any]:
    """Get chat data for a workflow (polling fallback)."""
    try:
        _validateInstanceAccess(instanceId, context)
        chatInterface = _getServiceChat(context, featureInstanceId=instanceId)
        workflow = chatInterface.getWorkflow(workflowId)
        if not workflow:
            raise HTTPException(status_code=404, detail=f"Workflow {workflowId} not found")
        return chatInterface.getUnifiedChatData(workflowId, afterTimestamp)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting chat data: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))


@router.get("/{instanceId}/workflows")
@limiter.limit("120/minute")
def getWorkflows(
    request: Request,
    instanceId: str = Path(..., description="Feature instance ID"),
    page: int = Query(1, ge=1),
    pageSize: int = Query(20, ge=1, le=100),
    context: RequestContext = Depends(getRequestContext)
) -> Dict[str, Any]:
    """List workflows for this feature instance."""
    try:
        _validateInstanceAccess(instanceId, context)
        chatInterface = _getServiceChat(context, featureInstanceId=instanceId)
        from modules.datamodels.datamodelPagination import PaginationParams
        pagination = PaginationParams(page=page, pageSize=pageSize)
        return chatInterface.getWorkflows(pagination=pagination)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting workflows: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))


@router.get("/{instanceId}/files")
@limiter.limit("120/minute")
def getFiles(
    request: Request,
    instanceId: str = Path(..., description="Feature instance ID"),
    context: RequestContext = Depends(getRequestContext)
) -> Dict[str, Any]:
    """List all text files accessible to the user."""
    try:
        _validateInstanceAccess(instanceId, context)
        dbManagement = _getDbManagement(context, featureInstanceId=instanceId)
        textFiles = fileContextManager.listTextFiles(dbManagement)
        return {
            "files": [f.model_dump(exclude={"content"}) for f in textFiles],
            "count": len(textFiles)
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error listing files: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))


@router.get("/{instanceId}/files/{fileId}/content")
@limiter.limit("120/minute")
def getFileContent(
    request: Request,
    instanceId: str = Path(..., description="Feature instance ID"),
    fileId: str = Path(..., description="File ID"),
    context: RequestContext = Depends(getRequestContext)
) -> Dict[str, Any]:
    """Get the text content of a file."""
    try:
        _validateInstanceAccess(instanceId, context)
        dbManagement = _getDbManagement(context, featureInstanceId=instanceId)

        fileItem = dbManagement.getFile(fileId)
        if not fileItem:
            raise HTTPException(status_code=404, detail=f"File {fileId} not found")

        fileData = dbManagement.getFileData(fileId)
        if not fileData:
            raise HTTPException(status_code=404, detail=f"No data for file {fileId}")

        try:
            content = fileData.decode("utf-8")
        except UnicodeDecodeError:
            raise HTTPException(status_code=400, detail="File is not valid UTF-8 text")

        return {
            "fileId": fileId,
            "fileName": fileItem.fileName,
            "mimeType": fileItem.mimeType,
            "content": content
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting file content: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/{instanceId}/{workflowId}/apply")
@limiter.limit("60/minute")
async def applyEdit(
    request: Request,
    instanceId: str = Path(..., description="Feature instance ID"),
    workflowId: str = Path(..., description="Workflow ID"),
    proposalData: Dict[str, Any] = Body(...),
    context: RequestContext = Depends(getRequestContext)
) -> Dict[str, Any]:
    """Accept a file edit proposal. Updates existing file or creates new one."""
    try:
        _validateInstanceAccess(instanceId, context)
        dbManagement = _getDbManagement(context, featureInstanceId=instanceId)

        fileId = proposalData.get("fileId", "")
        newContent = proposalData.get("newContent")
        fileName = proposalData.get("fileName", "")

        if newContent is None:
            raise HTTPException(status_code=400, detail="newContent is required")

        contentBytes = newContent.encode("utf-8")
        isNewFile = not fileId or fileId.startswith("unknown-")

        if isNewFile:
            mimeType = _guessMimeType(fileName)
            fileItem = dbManagement.createFile(fileName, mimeType, contentBytes)
            resultFileId = fileItem.id
            resultFileName = fileItem.fileName
        else:
            fileItem = dbManagement.getFile(fileId)
            if not fileItem:
                raise HTTPException(status_code=404, detail=f"File {fileId} not found")
            success = dbManagement.createFileData(fileId, contentBytes)
            if not success:
                raise HTTPException(status_code=500, detail="Failed to store updated file content")
            resultFileId = fileId
            resultFileName = fileName or fileItem.fileName

        eventManager = get_event_manager()
        await eventManager.emit_event(workflowId, "chatdata", {
            "type": "file_version",
            "item": {
                "fileId": resultFileId,
                "fileName": resultFileName,
                "status": "accepted",
                "isNew": isNewFile
            }
        })

        return {
            "status": "accepted",
            "fileId": resultFileId,
            "fileName": resultFileName,
            "isNew": isNewFile
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error applying edit: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))


_MIME_MAP = {
    ".md": "text/markdown", ".txt": "text/plain", ".json": "application/json",
    ".yaml": "application/yaml", ".yml": "application/yaml", ".xml": "application/xml",
    ".csv": "text/csv", ".py": "text/x-python", ".js": "text/javascript",
    ".ts": "text/x-typescript", ".html": "text/html", ".css": "text/css",
    ".sql": "text/x-sql", ".sh": "text/x-shellscript",
}


def _guessMimeType(fileName: str) -> str:
    """Guess MIME type from file extension."""
    if not fileName or "." not in fileName:
        return "text/plain"
    ext = "." + fileName.rsplit(".", 1)[-1].lower()
    return _MIME_MAP.get(ext, "text/plain")
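A sketch of a client consuming the SSE stream above. The endpoint path and payload keys follow the route definition; the host, instance ID, and prompt are assumptions.

```python
import json
import httpx

url = "http://localhost:8000/api/codeeditor/<instanceId>/start/stream"
payload = {"prompt": "Rename variable x to count", "listFileId": ["f-1"]}

with httpx.stream("POST", url, json=payload, timeout=None) as response:
    for line in response.iter_lines():
        # SSE data lines carry one JSON event each; ": keepalive" lines are skipped
        if line.startswith("data: "):
            event = json.loads(line[len("data: "):])
            print(event.get("type"), event.get("item", event))
            if event.get("type") in ("complete", "stopped", "error"):
                break
```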
@ -1,157 +0,0 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Tool registry and dispatcher for the CodeEditor agent loop.
Defines available tools and executes them against the file context manager."""

import logging
import time
import fnmatch
from typing import Dict, Any, List

from modules.features.codeeditor.datamodelCodeeditor import ToolResult

logger = logging.getLogger(__name__)

TOOL_DEFINITIONS = [
    {
        "name": "read_file",
        "description": "Read the full content of a single file by its fileId.",
        "parameters": {"fileId": "string (required)"}
    },
    {
        "name": "list_files",
        "description": "List all available text files with metadata (name, size, mimeType). Optionally filter by glob pattern.",
        "parameters": {"filter": "string (optional, glob pattern e.g. '*.py')"}
    },
    {
        "name": "search_files",
        "description": "Search all file contents for a text query. Returns matching lines with file name and line number.",
        "parameters": {"query": "string (required)", "fileType": "string (optional, extension e.g. 'py')"}
    },
]


async def dispatch(toolName: str, toolArgs: Dict[str, Any], dbManagement) -> ToolResult:
    """Execute a tool and return the result."""
    startTime = time.time()
    try:
        if toolName == "read_file":
            result = await _toolReadFile(toolArgs, dbManagement)
        elif toolName == "list_files":
            result = _toolListFiles(toolArgs, dbManagement)
        elif toolName == "search_files":
            result = await _toolSearchFiles(toolArgs, dbManagement)
        else:
            result = f"Unknown tool: {toolName}"
            return ToolResult(toolName=toolName, result=result, success=False,
                              executionTime=time.time() - startTime)

        return ToolResult(toolName=toolName, result=result, success=True,
                          executionTime=time.time() - startTime)
    except Exception as e:
        logger.error(f"Tool {toolName} failed: {e}", exc_info=True)
        return ToolResult(toolName=toolName, result=f"Error: {str(e)}", success=False,
                          executionTime=time.time() - startTime)


async def _toolReadFile(args: Dict[str, Any], dbManagement) -> str:
    """Read a single file's content."""
    fileId = args.get("fileId", "")
    if not fileId:
        return "Error: fileId is required"

    fileItem = dbManagement.getFile(fileId)
    if not fileItem:
        return f"Error: File {fileId} not found"

    fileData = dbManagement.getFileData(fileId)
    if not fileData:
        return f"Error: No data for file {fileId}"

    try:
        content = fileData.decode("utf-8")
    except UnicodeDecodeError:
        return f"Error: File {fileItem.fileName} is not valid UTF-8"

    lines = content.split("\n")
    numbered = "\n".join([f"{i + 1}|{line}" for i, line in enumerate(lines)])
    return f"--- FILE: {fileItem.fileName} (id: {fileId}) ---\n{numbered}\n--- END FILE ---"


def _toolListFiles(args: Dict[str, Any], dbManagement) -> str:
    """List all text files, optionally filtered by glob pattern."""
    from modules.features.codeeditor.datamodelCodeeditor import isTextFile

    filterPattern = args.get("filter", "")
    allFiles = dbManagement.getAllFiles()
    if not allFiles:
        return "No files found."

    lines = []
    for f in allFiles:
        if not isTextFile(f.mimeType, f.fileName):
            continue
        if filterPattern and not fnmatch.fnmatch(f.fileName, filterPattern):
            continue
        lines.append(f"- {f.fileName} (id: {f.id}, size: {f.fileSize}B, type: {f.mimeType})")

    if not lines:
        return "No matching text files found."
    return f"Available files ({len(lines)}):\n" + "\n".join(lines)


async def _toolSearchFiles(args: Dict[str, Any], dbManagement) -> str:
    """Search file contents for a query string."""
    from modules.features.codeeditor.datamodelCodeeditor import isTextFile

    query = args.get("query", "")
    if not query:
        return "Error: query is required"

    fileType = args.get("fileType", "")
    allFiles = dbManagement.getAllFiles()
    if not allFiles:
        return "No files to search."

    hits = []
    maxHits = 50
    queryLower = query.lower()

    for f in allFiles:
        if not isTextFile(f.mimeType, f.fileName):
            continue
        if fileType and not f.fileName.endswith(f".{fileType}"):
            continue

        fileData = dbManagement.getFileData(f.id)
        if not fileData:
            continue

        try:
            content = fileData.decode("utf-8")
        except UnicodeDecodeError:
            continue

        for lineNum, line in enumerate(content.split("\n"), 1):
            if queryLower in line.lower():
                hits.append(f"{f.fileName}:{lineNum}: {line.strip()}")
                if len(hits) >= maxHits:
                    break
        if len(hits) >= maxHits:
            break

    if not hits:
        return f"No matches found for '{query}'."
    result = f"Search results for '{query}' ({len(hits)} matches):\n" + "\n".join(hits)
    if len(hits) >= maxHits:
        result += f"\n... (truncated at {maxHits} matches)"
    return result


def formatToolDefinitions() -> str:
    """Format tool definitions for inclusion in the system prompt."""
    parts = []
    for tool in TOOL_DEFINITIONS:
        params = ", ".join([f"{k}: {v}" for k, v in tool["parameters"].items()])
        parts.append(f"- **{tool['name']}**: {tool['description']}\n  Parameters: {{{params}}}")
    return "\n".join(parts)
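A sketch exercising `dispatch()` with an in-memory stand-in for `dbManagement`. The stub below is an assumption for illustration; the real object comes from `interfaceDbManagement`.

```python
import asyncio
from types import SimpleNamespace

class _StubDb:
    # Minimal surface the tools above rely on
    def getFile(self, fileId):
        return SimpleNamespace(fileName="hello.py", mimeType="text/x-python", fileSize=11)
    def getFileData(self, fileId):
        return b"print('hi')"
    def getAllFiles(self):
        return [SimpleNamespace(id="f-1", fileName="hello.py",
                                mimeType="text/x-python", fileSize=11)]

result = asyncio.run(dispatch("read_file", {"fileId": "f-1"}, _StubDb()))
assert result.success and "hello.py" in result.result
```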
@ -10,8 +10,10 @@ import json
import asyncio
import base64
import uuid


from typing import Optional
from fastapi import APIRouter, HTTPException, Depends, Request
from fastapi import APIRouter, HTTPException, Depends, Request, WebSocket, WebSocketDisconnect, Query
from fastapi.responses import StreamingResponse, Response

from modules.auth import limiter, getRequestContext, RequestContext
@ -31,7 +33,6 @@ from .datamodelCommcoach import (
    StartSessionRequest, CreatePersonaRequest, UpdatePersonaRequest,
)
from .serviceCommcoach import CommcoachService, emitSessionEvent, getSessionEventQueue, cleanupSessionEvents

logger = logging.getLogger(__name__)

_activeProcessTasks: dict = {}
@ -1011,15 +1011,15 @@ class CommcoachService:
    async def _callAi(self, systemPrompt: str, userPrompt: str):
        """Call the AI service with the given prompts."""
        from modules.services.serviceAi.mainServiceAi import AiService
        from modules.serviceCenter import getService
        from modules.serviceCenter.context import ServiceCenterContext

        serviceContext = type('Ctx', (), {
            'user': self.currentUser,
            'mandateId': self.mandateId,
            'featureInstanceId': self.instanceId,
            'featureCode': 'commcoach',
        })()
        aiService = AiService(serviceCenter=serviceContext)
        serviceContext = ServiceCenterContext(
            user=self.currentUser,
            mandate_id=self.mandateId,
            feature_instance_id=self.instanceId,
        )
        aiService = getService("ai", serviceContext)
        await aiService.ensureAiObjectsInitialized()

        aiRequest = AiCallRequest(
@ -7,7 +7,7 @@ from urllib.parse import urlparse, unquote
from modules.datamodels.datamodelUam import User
from .datamodelFeatureNeutralizer import DataNeutralizerAttributes, DataNeutraliserConfig
from modules.services import getInterface as getServices
from modules.serviceHub import getInterface as getServices

logger = logging.getLogger(__name__)
@ -205,7 +205,7 @@ class NeutralizationPlayground:
    async def processSharepointFiles(self, sourcePath: str, targetPath: str) -> Dict[str, Any]:
        """Process files from SharePoint source path and store neutralized files in target path"""
        from modules.services.serviceSharepoint.mainServiceSharepoint import SharepointService
        from modules.serviceCenter.services.serviceSharepoint.mainServiceSharepoint import SharepointService
        processor = SharepointProcessor(self.currentUser, self.services)
        return await processor.processSharepointFiles(sourcePath, targetPath)
@ -262,8 +262,8 @@ class NeutralizationService:
        fileId: Optional[str]
    ) -> Dict[str, Any]:
        """Extract -> neutralize -> adapt -> generate for PDF/DOCX/XLSX/PPTX."""
        from modules.services.serviceExtraction.mainServiceExtraction import ExtractionService
        from modules.services.serviceExtraction.subPipeline import runExtraction
        from modules.serviceCenter.services.serviceExtraction.mainServiceExtraction import ExtractionService
        from modules.serviceCenter.services.serviceExtraction.subPipeline import runExtraction
        from modules.datamodels.datamodelExtraction import ExtractionOptions, MergeStrategy

        # Ensure registries exist
@ -405,10 +405,10 @@ class NeutralizationService:
    def _getRendererForMime(self, mimeType: str):
        """Get renderer instance and output mime for the given input MIME type."""
        from modules.services.serviceGeneration.renderers.rendererPdf import RendererPdf
        from modules.services.serviceGeneration.renderers.rendererDocx import RendererDocx
        from modules.services.serviceGeneration.renderers.rendererXlsx import RendererXlsx
        from modules.services.serviceGeneration.renderers.rendererPptx import RendererPptx
        from modules.serviceCenter.services.serviceGeneration.renderers.rendererPdf import RendererPdf
        from modules.serviceCenter.services.serviceGeneration.renderers.rendererDocx import RendererDocx
        from modules.serviceCenter.services.serviceGeneration.renderers.rendererXlsx import RendererXlsx
        from modules.serviceCenter.services.serviceGeneration.renderers.rendererPptx import RendererPptx

        mime_map = {
            "application/pdf": (RendererPdf, "application/pdf"),
@ -20,6 +20,8 @@ _NEUTRALIZATION_BLACKLIST = frozenset({
    "Leistungen", "Basis", "Benefits",  # Section labels
    "Start", "Beginn", "Ende", "End", "trip",  # Contract labels (Start of trip, End of trip, etc.)
    "incomplete", "Application", "Complete", "Pending",  # Form/status labels, not addresses
    "Marketing", "Verkaufsstrategien", "Qualitätsmanagement",  # Business terms, not addresses
    "Ausbildungsstätte", "Realschule",  # Institution types, not city names
    # Ambiguous substrings – match in Zurich, CHF, UID-Nr., websites, etc.
    "CH", "DE", "FR", "IT", "Nr", "Nr.", "Nr:", "No", "No.", "No:",
    "www", ".ch", ".com", ".org", ".net", "CHF",
@ -95,6 +97,19 @@ class StringParser:
            if not (m[0] == "address" and any(overlaps(m[2], m[3], ds, de) for ds, de in date_ranges))
        ]

        # For name matches: resolve overlaps – keep only longest to avoid multiple placeholders for one name
        # (e.g. "Ida", "Dittrich", "Ida Dittrich" → keep only "Ida Dittrich" with one UUID)
        name_matches = [(m, m[3] - m[2]) for m in patternMatches if m[0] == "name"]
        name_spans = [(m[2], m[3]) for m, _ in name_matches]
        patternMatches = [
            m for m in patternMatches
            if m[0] != "name"
            or not any(
                overlaps(m[2], m[3], ns, ne) and (ne - ns) > (m[3] - m[2])
                for (ns, ne) in name_spans
            )
        ]

        # Process from right to left to avoid position shifts
        for patternName, matchedText, start, end in reversed(patternMatches):
            # Skip if already a placeholder
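A standalone sketch of the keep-the-longest-span rule used in this hunk, on illustrative data; `overlaps` mirrors the half-open interval test from the surrounding code.

```python
def overlaps(s1, e1, s2, e2):
    return s1 < e2 and s2 < e1

matches = [("name", "Ida", 0, 3), ("name", "Dittrich", 4, 12),
           ("name", "Ida Dittrich", 0, 12)]
spans = [(m[2], m[3]) for m in matches]
# Drop any match that overlaps a strictly longer match
kept = [
    m for m in matches
    if not any(overlaps(m[2], m[3], s, e) and (e - s) > (m[3] - m[2]) for s, e in spans)
]
assert kept == [("name", "Ida Dittrich", 0, 12)]
```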
@ -157,24 +172,72 @@ class StringParser:
                expanded.add(f"{n1} {n2}")
                expanded.add(f"{n2} {n1}")

        # Process longest first so "Ida Dittrich" replaces before "Ida" or "Dittrich"
        # One UUID per person: composites and their parts share same UUID
        # Also align with DataPatterns mapping (step 1 may have already replaced "Ida Dittrich")
        name_to_uuid: Dict[str, str] = {}
        for composite in sorted(expanded, key=len, reverse=True):
            if " " not in composite:
                continue
            parts = composite.split()
            parts_set = frozenset(parts)
            existing_uuid = next((name_to_uuid[p] for p in parts_set if p in name_to_uuid), None)
            if existing_uuid is None:
                existing_uuid = next(
                    (self.mapping[k] for k in (composite, *parts_set) if k in self.mapping),
                    None
                )
            if existing_uuid is None:
                existing_uuid = f"[name.{uuid.uuid4()}]"
            for p in parts_set:
                name_to_uuid[p] = existing_uuid
            name_to_uuid[composite] = existing_uuid
            if len(parts) == 2:
                name_to_uuid[f"{parts[1]} {parts[0]}"] = existing_uuid
        for n in names:
            if n not in name_to_uuid:
                name_to_uuid[n] = self.mapping.get(n) or f"[name.{uuid.uuid4()}]"
        self.mapping.update({k: v for k, v in name_to_uuid.items() if k not in self.mapping})

        # Collect ALL matches from all name patterns, then keep only longest per span to avoid
        # triple replacement ("Ida" + "Dittrich" + "Ida Dittrich" -> only "Ida Dittrich")
        all_matches: List[Tuple[str, int, int]] = []
        for name in sorted(expanded, key=len, reverse=True):
            # Composite: flexible whitespace (space, newline); single: word boundaries
            if " " in name:
                parts = name.split()
                pattern_str = r"\b" + r"\s+".join(re.escape(p) for p in parts) + r"\b"
            else:
                pattern_str = r"\b" + re.escape(name) + r"\b"
            pattern = re.compile(pattern_str, re.IGNORECASE)
            for m in pattern.finditer(text):
                all_matches.append((m.group(), m.start(), m.end()))

        matches = list(pattern.finditer(text))
        for match in reversed(matches):
            matchedText = match.group()
            if matchedText not in self.mapping:
                placeholderId = str(uuid.uuid4())
                self.mapping[matchedText] = f"[name.{placeholderId}]"
            replacement = self.mapping[matchedText]
            start, end = match.span()
        # Remove matches that overlap with a longer match (keep longest per span)
        def _overlaps(s1, e1, s2, e2):
            return s1 < e2 and s2 < e1

        def _contained_in_longer(matched_text: str, start: int, end: int) -> bool:
            for other_text, os, oe in all_matches:
                if (os, oe) == (start, end):
                    continue
                if _overlaps(start, end, os, oe) and (oe - os) > (end - start):
                    return True
            return False

        to_replace = [(t, s, e) for t, s, e in all_matches if not _contained_in_longer(t, s, e)]
        to_replace = list({(s, e): (t, s, e) for t, s, e in to_replace}.values())

        # Replace from right to left to avoid position shift
        for matched_text, start, end in sorted(to_replace, key=lambda x: -x[1]):
            normalized = " ".join(matched_text.split())
            replacement = (
                self.mapping.get(matched_text)
                or self.mapping.get(normalized)
                or next((v for k, v in self.mapping.items() if " ".join(k.split()) == normalized), None)
                or next((v for k, v in self.mapping.items() if k.lower() == matched_text.lower()), None)
            )
            if not replacement:
                replacement = f"[name.{uuid.uuid4()}]"
                self.mapping[matched_text] = replacement
            text = text[:start] + replacement + text[end:]

        return text
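A compact sketch of the one-UUID-per-person idea from this hunk: a composite name, its parts, and the reversed order all resolve to the same placeholder. Names and the generated UUID are illustrative.

```python
import uuid

name_to_uuid = {}
composite = "Ida Dittrich"
parts = composite.split()
ph = f"[name.{uuid.uuid4()}]"
# Composite, each part, and "last first" all share one placeholder
for key in (composite, *parts, " ".join(reversed(parts))):
    name_to_uuid.setdefault(key, ph)

assert name_to_uuid["Ida"] == name_to_uuid["Dittrich"] == name_to_uuid["Ida Dittrich"]
```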
@ -307,14 +307,17 @@ class DataPatterns:
            name="address",
            patterns=[
                # Full address block: company, street, postfach, postal+city (stop before domain like , AXA.ch)
                r'\b[^,\n]+(?:,\s*[^,\n]+)*,\s*\d{4}\s+[A-Za-zäöüßÄÖÜ]+\s*(?=,\s*[a-zA-Z0-9.-]+\.(?:ch|com|org|net)\b|$)',
                # Street + house number (standalone)
                r'\b(?:[A-Za-zäöüßÄÖÜ]+(?:-[A-Za-zäöüßÄÖÜ]+)*(?:strasse|str\.|gasse|weg|platz|allee|boulevard|avenue|via|strada|rue|chemin|route))\s+\d{1,4}(?:/\d{1,4})?(?:[a-z])?\b',
                # Supports Swiss PLZ (4 digits) and German PLZ (5 digits)
                r'\b[^,\n]+(?:,\s*[^,\n]+)*,\s*\d{4,5}\s+[A-Za-zäöüßÄÖÜ]+\s*(?=,\s*[a-zA-Z0-9.-]+\.(?:ch|com|org|net)\b|$)',
                # Street + house number (standalone); includes "straße" for German
                r'\b(?:[A-Za-zäöüßÄÖÜ]+(?:-[A-Za-zäöüßÄÖÜ]+)*(?:straße|strasse|str\.|gasse|weg|platz|allee|boulevard|avenue|via|strada|rue|chemin|route))\s+\d{1,4}(?:/\d{1,4})?(?:[a-z])?\b',
                # Postfach / PO Box (standalone)
                r'\b(?:Postfach|Postbox|P\.?O\.?\s*Box|Case\s+postale|Casella\s+postale|Boîte\s+postale)\s+\d{1,6}\b',
                # Postal code + city (standalone); exclude year+non-city and common non-city words
                # (?<!\d{2}\.\d{2}\.) = not part of date DD.MM.YYYY (e.g. 27.01.2026)
                r'(?<!\d{2}\.\d{2}\.)\b\d{4}\s+(?!den|der|die|das|dem|des|und|oder|für|bei|mit|Version|Versand|Vertrag|Verfügung|Verschickung|Versicherung|erhalten|Schreiben|Jahr|Jahres|incomplete|Application|Complete|Pending|Matrikel|Student|Studien|Kontakt|Telefon|Rechnung|Invoice)[A-Za-zäöüßÄÖÜ]+\b(?!\s*:)'
                # Exclude business terms (Marketing, Qualitätsmanagement, etc.) – often follow years
                # Swiss PLZ (4 digits) and German PLZ (5 digits)
                r'(?<!\d{2}\.\d{2}\.)\b\d{4,5}\s+(?!den|der|die|das|dem|des|und|oder|für|bei|mit|Version|Versand|Vertrag|Verfügung|Verschickung|Versicherung|erhalten|Schreiben|Jahr|Jahres|incomplete|Application|Complete|Pending|Matrikel|Student|Studien|Kontakt|Telefon|Rechnung|Invoice|Marketing|Verkaufsstrategien|Qualitätsmanagement|Management|Strategien|Projektmanagement|Vertrieb|Vertriebsstrategien|Ausbildungsstätte|Realschule)[A-Za-zäöüßÄÖÜ]+\b(?!\s*:)'
            ],
            replacement_template="[ADDRESS_{}]"
        ),
@ -56,7 +56,16 @@ def neutralize_pdf_in_place(
        logger.error(f"Failed to open PDF: {e}")
        return None

    sorted_items = sorted(mapping.items(), key=lambda x: -len(x[0]))
    # For same placeholder: only search longest original_text to avoid triple overlay
    # (e.g. "Ida Dittrich", "Ida", "Dittrich" all map to [name.x] → only search "Ida Dittrich")
    placeholder_to_longest: Dict[str, str] = {}
    for orig, ph in mapping.items():
        if not orig or not ph:
            continue
        if ph not in placeholder_to_longest or len(orig) > len(placeholder_to_longest[ph]):
            placeholder_to_longest[ph] = orig
    filtered = [(orig, ph) for ph, orig in placeholder_to_longest.items()]
    sorted_items = sorted(filtered, key=lambda x: -len(x[0]))
    fill_color = (1, 1, 1)
    text_color = (0, 0, 0)
    fontname = "helv"
@ -284,7 +284,7 @@ from .datamodelFeatureRealEstate import (
    Land,
    DokumentTyp,
)
from modules.services import getInterface as getServices
from modules.serviceHub import getInterface as getServices
from .interfaceFeatureRealEstate import getInterface as getRealEstateInterface
from modules.interfaces.interfaceDbManagement import getInterface as getComponentInterface
from modules.connectors.connectorSwissTopoMapServer import SwissTopoMapServerConnector
@ -843,7 +843,7 @@ async def testVoice(
):
    """Test TTS voice with AI-generated sample text in the correct language."""
    from modules.interfaces.interfaceVoiceObjects import getVoiceInterface
    from modules.services.serviceAi.mainServiceAi import AiService
    from modules.serviceCenter.services.serviceAi.mainServiceAi import AiService
    from modules.datamodels.datamodelAi import AiCallRequest, AiCallOptions, OperationTypeEnum, PriorityEnum

    mandateId = _validateInstanceAccess(instanceId, context)
@ -1062,7 +1062,7 @@ class TeamsbotService:
        # Call SPEECH_TEAMS
        try:
            from modules.services.serviceAi.mainServiceAi import AiService
            from modules.serviceCenter.services.serviceAi.mainServiceAi import AiService

            # Create minimal service context for AI billing
            serviceContext = _ServiceContext(self.currentUser, self.mandateId, self.instanceId)
@ -1684,7 +1684,7 @@ class TeamsbotService:
        """Summarize a long user-provided session context to its essential points.
        This reduces token usage in every subsequent AI call."""
        try:
            from modules.services.serviceAi.mainServiceAi import AiService
            from modules.serviceCenter.services.serviceAi.mainServiceAi import AiService
            from modules.datamodels.datamodelAi import AiCallRequest, AiCallOptions, OperationTypeEnum, PriorityEnum

            serviceContext = _ServiceContext(self.currentUser, self.mandateId, self.instanceId)
@ -1738,7 +1738,7 @@ class TeamsbotService:
            lines.append(f"[{speaker}]: {text}")
        textToSummarize = "\n".join(lines)

        from modules.services.serviceAi.mainServiceAi import AiService
        from modules.serviceCenter.services.serviceAi.mainServiceAi import AiService
        from modules.datamodels.datamodelAi import AiCallRequest, AiCallOptions, OperationTypeEnum, PriorityEnum

        serviceContext = _ServiceContext(self.currentUser, self.mandateId, self.instanceId)
@ -1783,7 +1783,7 @@ class TeamsbotService:
            for t in transcripts
        )

        from modules.services.serviceAi.mainServiceAi import AiService
        from modules.serviceCenter.services.serviceAi.mainServiceAi import AiService

        serviceContext = _ServiceContext(self.currentUser, self.mandateId, self.instanceId)
        aiService = AiService(serviceCenter=serviceContext)
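The next hunk's docstring specifies the dict contract for journal entries. An illustrative instance of that shape (all values are made up; field names follow the contract):

```python
entry = {
    "externalId": "JE-1001",
    "bookingDate": "2025-01-31",
    "reference": "INV-42",
    "description": "Monthly rent",
    "currency": "CHF",
    "totalAmount": 2500.0,
    "lines": [
        {"accountNumber": "6000", "debitAmount": 2500.0, "creditAmount": 0.0,
         "currency": "CHF", "taxCode": None, "costCenter": "HQ", "description": "Rent"},
        {"accountNumber": "1020", "debitAmount": 0.0, "creditAmount": 2500.0,
         "currency": "CHF", "taxCode": None, "costCenter": None, "description": "Bank"},
    ],
}
```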
@@ -118,6 +118,13 @@ class BaseAccountingConnector(ABC):
         """Load the vendor list. Override in connectors that support it."""
         return []
 
+    async def getJournalEntries(self, config: Dict[str, Any], dateFrom: Optional[str] = None, dateTo: Optional[str] = None, accountNumbers: Optional[List[str]] = None) -> List[Dict[str, Any]]:
+        """Read journal entries from the external system. Each entry should contain:
+        - externalId, bookingDate, reference, description, currency, totalAmount
+        - lines: list of {accountNumber, debitAmount, creditAmount, currency, taxCode, costCenter, description}
+        accountNumbers: pre-fetched account numbers (avoids redundant API call). Override in connectors that support it."""
+        return []
+
     async def uploadDocument(
         self,
         config: Dict[str, Any],
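For orientation, an element of the list returned by getJournalEntries() would look like this per the docstring above (a minimal sketch; all values are invented):

    entry = {
        "externalId": "42",
        "bookingDate": "2025-03-31",
        "reference": "RE-2025-0042",
        "description": "Office rent March",
        "currency": "CHF",
        "totalAmount": 2500.0,
        "lines": [
            {"accountNumber": "6000", "debitAmount": 2500.0, "creditAmount": 0.0,
             "currency": "CHF", "taxCode": None, "costCenter": None, "description": "Rent"},
            {"accountNumber": "1020", "debitAmount": 0.0, "creditAmount": 2500.0,
             "currency": "CHF", "taxCode": None, "costCenter": None, "description": "Bank"},
        ],
    }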
modules/features/trustee/accounting/accountingDataSync.py  (new file, 306 lines)
@@ -0,0 +1,306 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Orchestrates importing accounting data from external systems into TrusteeData* tables.

Flow: load config → resolve connector → fetch data → clear old records → write new records → compute balances.
"""

import logging
import time
from collections import defaultdict
from typing import Dict, Any, Optional

from .accountingConnectorBase import BaseAccountingConnector
from .accountingRegistry import _getAccountingRegistry

logger = logging.getLogger(__name__)


class AccountingDataSync:
    """Imports accounting data (read-only) from an external system into local TrusteeData* tables."""

    def __init__(self, trusteeInterface):
        self._if = trusteeInterface
        self._registry = _getAccountingRegistry()

    async def importData(
        self,
        featureInstanceId: str,
        mandateId: str,
        dateFrom: Optional[str] = None,
        dateTo: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Run a full data import for a feature instance.

        Returns a summary dict with counts per entity and any errors.
        """
        from modules.features.trustee.datamodelFeatureTrustee import (
            TrusteeAccountingConfig,
            TrusteeDataAccount,
            TrusteeDataJournalEntry,
            TrusteeDataJournalLine,
            TrusteeDataContact,
            TrusteeDataAccountBalance,
        )
        from modules.shared.configuration import decryptValue

        summary: Dict[str, Any] = {
            "accounts": 0,
            "journalEntries": 0,
            "journalLines": 0,
            "contacts": 0,
            "accountBalances": 0,
            "errors": [],
            "startedAt": time.time(),
        }

        cfgRecords = self._if.db.getRecordset(
            TrusteeAccountingConfig,
            recordFilter={"featureInstanceId": featureInstanceId, "isActive": True},
        )
        if not cfgRecords:
            summary["errors"].append("No active accounting configuration found")
            return summary

        cfgRecord = cfgRecords[0]
        connectorType = cfgRecord.get("connectorType", "")
        encryptedConfig = cfgRecord.get("encryptedConfig", "")

        try:
            import json
            plainJson = decryptValue(encryptedConfig)
            connConfig = json.loads(plainJson) if plainJson else {}
        except Exception as e:
            summary["errors"].append(f"Failed to decrypt config: {e}")
            return summary

        connector = self._registry.getConnector(connectorType)
        if not connector:
            summary["errors"].append(f"Unknown connector type: {connectorType}")
            return summary

        scope = {"featureInstanceId": featureInstanceId, "mandateId": mandateId}
        logger.info(f"AccountingDataSync starting for {featureInstanceId}, connector={connectorType}, dateFrom={dateFrom}, dateTo={dateTo}")
        fetchedAccountNumbers: list = []

        # 1) Chart of accounts
        try:
            charts = await connector.getChartOfAccounts(connConfig)
            fetchedAccountNumbers = [acc.accountNumber for acc in charts if acc.accountNumber]
            self._clearTable(TrusteeDataAccount, featureInstanceId)
            for acc in charts:
                self._if.db.recordCreate(TrusteeDataAccount, {
                    "accountNumber": acc.accountNumber,
                    "label": acc.label,
                    "accountType": acc.accountType or "",
                    "currency": "CHF",
                    "isActive": True,
                    **scope,
                })
            summary["accounts"] = len(charts)
        except Exception as e:
            logger.error(f"Import accounts failed: {e}", exc_info=True)
            summary["errors"].append(f"Accounts: {e}")

        # 2) Journal entries + lines (pass already-fetched chart to avoid redundant API call)
        try:
            rawEntries = await connector.getJournalEntries(connConfig, dateFrom=dateFrom, dateTo=dateTo, accountNumbers=fetchedAccountNumbers or None)
            self._clearTable(TrusteeDataJournalEntry, featureInstanceId)
            self._clearTable(TrusteeDataJournalLine, featureInstanceId)
            lineCount = 0
            for raw in rawEntries:
                import uuid
                entryId = str(uuid.uuid4())
                self._if.db.recordCreate(TrusteeDataJournalEntry, {
                    "id": entryId,
                    "externalId": raw.get("externalId"),
                    "bookingDate": raw.get("bookingDate"),
                    "reference": raw.get("reference"),
                    "description": raw.get("description", ""),
                    "currency": raw.get("currency", "CHF"),
                    "totalAmount": float(raw.get("totalAmount", 0)),
                    **scope,
                })
                for line in (raw.get("lines") or []):
                    self._if.db.recordCreate(TrusteeDataJournalLine, {
                        "journalEntryId": entryId,
                        "accountNumber": line.get("accountNumber", ""),
                        "debitAmount": float(line.get("debitAmount", 0)),
                        "creditAmount": float(line.get("creditAmount", 0)),
                        "currency": line.get("currency", "CHF"),
                        "taxCode": line.get("taxCode"),
                        "costCenter": line.get("costCenter"),
                        "description": line.get("description", ""),
                        **scope,
                    })
                    lineCount += 1
            summary["journalEntries"] = len(rawEntries)
            summary["journalLines"] = lineCount
        except Exception as e:
            logger.error(f"Import journal entries failed: {e}")
            summary["errors"].append(f"Journal entries: {e}")

        # 3) Contacts (customers + vendors)
        try:
            self._clearTable(TrusteeDataContact, featureInstanceId)
            contactCount = 0

            customers = await connector.getCustomers(connConfig)
            for c in customers:
                self._if.db.recordCreate(TrusteeDataContact, self._mapContact(c, "customer", scope))
                contactCount += 1

            vendors = await connector.getVendors(connConfig)
            for v in vendors:
                self._if.db.recordCreate(TrusteeDataContact, self._mapContact(v, "vendor", scope))
                contactCount += 1

            summary["contacts"] = contactCount
        except Exception as e:
            logger.error(f"Import contacts failed: {e}", exc_info=True)
            summary["errors"].append(f"Contacts: {e}")

        # 4) Compute account balances from journal lines
        try:
            self._clearTable(TrusteeDataAccountBalance, featureInstanceId)
            balanceCount = self._computeBalances(featureInstanceId, mandateId)
            summary["accountBalances"] = balanceCount
        except Exception as e:
            logger.error(f"Compute balances failed: {e}")
            summary["errors"].append(f"Balances: {e}")

        # Update config with last import timestamp
        try:
            cfgId = cfgRecord.get("id")
            if cfgId:
                self._if.db.recordModify(TrusteeAccountingConfig, cfgId, {
                    "lastSyncAt": time.time(),
                    "lastSyncStatus": "success" if not summary["errors"] else "partial",
                    "lastSyncErrorMessage": "; ".join(summary["errors"])[:500] if summary["errors"] else None,
                })
        except Exception:
            pass

        summary["finishedAt"] = time.time()
        summary["durationSeconds"] = round(summary["finishedAt"] - summary["startedAt"], 1)
        logger.info(
            f"AccountingDataSync completed for {featureInstanceId}: "
            f"{summary['accounts']} accounts, {summary['journalEntries']} entries, "
            f"{summary['journalLines']} lines, {summary['contacts']} contacts, "
            f"{summary['accountBalances']} balances, {len(summary['errors'])} errors, "
            f"{summary['durationSeconds']}s"
        )
        return summary

    @staticmethod
    def _safeStr(val: Any) -> str:
        """Convert a value to a safe string for DB storage, collapsing nested dicts/lists."""
        if val is None:
            return ""
        if isinstance(val, (dict, list)):
            return ""
        return str(val)

    def _mapContact(self, raw: Dict[str, Any], contactType: str, scope: Dict[str, Any]) -> Dict[str, Any]:
        """Extract contact fields from a raw API dict, handling varying field names across connectors."""
        s = self._safeStr
        return {
            "externalId": s(raw.get("id") or raw.get("Id") or raw.get("customer_nr") or raw.get("vendor_nr") or ""),
            "contactType": contactType,
            "contactNumber": s(
                raw.get("customernumber") or raw.get("customer_nr")
                or raw.get("vendornumber") or raw.get("vendor_nr")
                or raw.get("nr") or raw.get("ContactNumber")
                or raw.get("id") or ""
            ),
            "name": s(raw.get("name") or raw.get("Name") or raw.get("name_1") or ""),
            "address": s(raw.get("addr1") or raw.get("address") or raw.get("Address") or ""),
            "zip": s(raw.get("zipcode") or raw.get("postcode") or raw.get("Zip") or raw.get("zip") or ""),
            "city": s(raw.get("city") or raw.get("City") or ""),
            "country": s(raw.get("country") or raw.get("country_id") or raw.get("Country") or ""),
            "email": s(raw.get("email") or raw.get("mail") or raw.get("Email") or ""),
            "phone": s(raw.get("phone") or raw.get("phone_fixed") or raw.get("Phone") or ""),
            "vatNumber": s(raw.get("vat_identifier") or raw.get("vatNumber") or ""),
            **scope,
        }

    def _clearTable(self, model, featureInstanceId: str):
        """Delete all records for this feature instance from a TrusteeData* table."""
        records = self._if.db.getRecordset(model, recordFilter={"featureInstanceId": featureInstanceId})
        for r in (records or []):
            rid = r.get("id") if isinstance(r, dict) else getattr(r, "id", None)
            if rid:
                try:
                    self._if.db.recordDelete(model, rid)
                except Exception:
                    pass

    def _computeBalances(self, featureInstanceId: str, mandateId: str) -> int:
        """Aggregate journal lines into monthly + annual account balances."""
        from modules.features.trustee.datamodelFeatureTrustee import (
            TrusteeDataJournalEntry,
            TrusteeDataJournalLine,
            TrusteeDataAccountBalance,
        )

        entries = self._if.db.getRecordset(
            TrusteeDataJournalEntry,
            recordFilter={"featureInstanceId": featureInstanceId},
        ) or []
        entryDates = {}
        for e in entries:
            eid = e.get("id") if isinstance(e, dict) else getattr(e, "id", None)
            bdate = e.get("bookingDate") if isinstance(e, dict) else getattr(e, "bookingDate", None)
            if eid and bdate:
                entryDates[eid] = bdate

        lines = self._if.db.getRecordset(
            TrusteeDataJournalLine,
            recordFilter={"featureInstanceId": featureInstanceId},
        ) or []

        # key: (accountNumber, year, month)
        buckets: Dict[tuple, Dict[str, float]] = defaultdict(lambda: {"debit": 0.0, "credit": 0.0})
        for ln in lines:
            if isinstance(ln, dict):
                jeid = ln.get("journalEntryId", "")
                accNo = ln.get("accountNumber", "")
                debit = float(ln.get("debitAmount", 0))
                credit = float(ln.get("creditAmount", 0))
            else:
                jeid = getattr(ln, "journalEntryId", "")
                accNo = getattr(ln, "accountNumber", "")
                debit = float(getattr(ln, "debitAmount", 0))
                credit = float(getattr(ln, "creditAmount", 0))

            bdate = entryDates.get(jeid, "")
            if not accNo or not bdate:
                continue
            parts = bdate.split("-")
            if len(parts) < 2:
                continue
            year = int(parts[0])
            month = int(parts[1])

            buckets[(accNo, year, month)]["debit"] += debit
            buckets[(accNo, year, month)]["credit"] += credit
            buckets[(accNo, year, 0)]["debit"] += debit
            buckets[(accNo, year, 0)]["credit"] += credit

        count = 0
        scope = {"featureInstanceId": featureInstanceId, "mandateId": mandateId}
        for (accNo, year, month), totals in buckets.items():
            closing = totals["debit"] - totals["credit"]
            self._if.db.recordCreate(TrusteeDataAccountBalance, {
                "accountNumber": accNo,
                "periodYear": year,
                "periodMonth": month,
                "openingBalance": 0.0,
                "debitTotal": round(totals["debit"], 2),
                "creditTotal": round(totals["credit"], 2),
                "closingBalance": round(closing, 2),
                "currency": "CHF",
                **scope,
            })
            count += 1
        return count
@@ -40,15 +40,15 @@ class AccountingConnectorAbacus(BaseAccountingConnector):
     def getRequiredConfigFields(self) -> List[ConnectorConfigField]:
         return [
             ConnectorConfigField(
-                key="abacusHost",
-                label={"en": "Abacus Host URL", "de": "Abacus Host-URL", "fr": "URL Hôte Abacus"},
+                key="apiBaseUrl",
+                label={"en": "API Base URL", "de": "API Base URL", "fr": "URL de base API"},
                 fieldType="text",
                 secret=False,
-                placeholder="e.g. abacus.meinefirma.ch",
+                placeholder="e.g. https://abacus.meinefirma.ch/api/entity/v1/",
             ),
             ConnectorConfigField(
-                key="mandant",
-                label={"en": "Mandant Number", "de": "Mandantennummer", "fr": "Numéro de mandant"},
+                key="clientName",
+                label={"en": "Client Name", "de": "Mandantenname", "fr": "Nom du client"},
                 fieldType="text",
                 secret=False,
                 placeholder="e.g. 7777",
@@ -68,22 +68,37 @@ class AccountingConnectorAbacus(BaseAccountingConnector):
         ]
 
     def _buildBaseUrl(self, config: Dict[str, Any]) -> str:
-        host = config["abacusHost"].rstrip("/")
-        if not host.startswith("http"):
-            host = f"https://{host}"
-        return host
+        apiBaseUrl = str(config.get("apiBaseUrl") or "").strip()
+        if not apiBaseUrl:
+            raise ValueError("Missing required config: apiBaseUrl")
+        if not apiBaseUrl.startswith("http"):
+            apiBaseUrl = f"https://{apiBaseUrl}"
+        return apiBaseUrl.rstrip("/")
+
+    def _buildAuthBaseUrl(self, config: Dict[str, Any]) -> str:
+        apiBaseUrl = str(config.get("apiBaseUrl") or "").strip()
+        if not apiBaseUrl:
+            raise ValueError("Missing required config: apiBaseUrl")
+        if not apiBaseUrl.startswith("http"):
+            apiBaseUrl = f"https://{apiBaseUrl}"
+        apiBaseUrl = apiBaseUrl.rstrip("/")
+        if "/api/entity/v1" in apiBaseUrl:
+            return apiBaseUrl.split("/api/entity/v1", 1)[0]
+        if "/api/" in apiBaseUrl:
+            return apiBaseUrl.split("/api/", 1)[0]
+        return apiBaseUrl
 
     async def _getAccessToken(self, config: Dict[str, Any]) -> Optional[str]:
         """Obtain an OAuth access token using client_credentials grant.
 
        Tokens are cached and refreshed when expired (default 600s).
        """
-        cacheKey = f"{config.get('abacusHost')}_{config.get('clientId')}"
+        cacheKey = f"{config.get('apiBaseUrl')}_{config.get('clientName')}_{config.get('clientId')}"
         cached = self._tokenCache.get(cacheKey)
         if cached and cached.get("expiresAt", 0) > time.time() + 30:
             return cached["accessToken"]
 
-        baseUrl = self._buildBaseUrl(config)
+        baseUrl = self._buildAuthBaseUrl(config)
 
         try:
             async with aiohttp.ClientSession() as session:
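The split between _buildBaseUrl and _buildAuthBaseUrl above exists because the OAuth token endpoint lives at the Abacus host root while entity reads go through the /api/entity/v1 path. A sketch of the derivation (host invented):

    config = {"apiBaseUrl": "https://abacus.meinefirma.ch/api/entity/v1/"}
    # self._buildBaseUrl(config)     -> "https://abacus.meinefirma.ch/api/entity/v1"
    # self._buildAuthBaseUrl(config) -> "https://abacus.meinefirma.ch"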
@@ -120,8 +135,10 @@ class AccountingConnectorAbacus(BaseAccountingConnector):
 
     def _buildEntityUrl(self, config: Dict[str, Any], entity: str) -> str:
         baseUrl = self._buildBaseUrl(config)
-        mandant = config["mandant"]
-        return f"{baseUrl}/api/entity/v1/{mandant}/{entity}"
+        clientName = config.get("clientName")
+        if not clientName:
+            raise ValueError("Missing required config: clientName")
+        return f"{baseUrl}/{clientName}/{entity}"
 
     async def _buildAuthHeaders(self, config: Dict[str, Any]) -> Optional[Dict[str, str]]:
         token = await self._getAccessToken(config)
@@ -130,6 +147,19 @@ class AccountingConnectorAbacus(BaseAccountingConnector):
         return {"Authorization": f"Bearer {token}", "Accept": "application/json", "Content-Type": "application/json"}
 
     async def testConnection(self, config: Dict[str, Any]) -> SyncResult:
+        apiBaseUrl = str(config.get("apiBaseUrl") or "")
+        clientName = str(config.get("clientName") or "")
+        clientId = str(config.get("clientId") or "")
+        clientSecret = str(config.get("clientSecret") or "")
+        if not apiBaseUrl or not clientName or not clientId or not clientSecret:
+            return SyncResult(
+                success=False,
+                errorMessage=(
+                    f"Missing credentials: apiBaseUrl={bool(apiBaseUrl)}, "
+                    f"clientName={bool(clientName)}, clientId={bool(clientId)}, "
+                    f"clientSecret={bool(clientSecret)}"
+                ),
+            )
         headers = await self._buildAuthHeaders(config)
         if not headers:
             return SyncResult(success=False, errorMessage="Failed to obtain access token")
@@ -225,6 +255,60 @@ class AccountingConnectorAbacus(BaseAccountingConnector):
         except Exception as e:
             return SyncResult(success=False, errorMessage=str(e))
 
+    async def getJournalEntries(self, config: Dict[str, Any], dateFrom: Optional[str] = None, dateTo: Optional[str] = None, accountNumbers: Optional[List[str]] = None) -> List[Dict[str, Any]]:
+        """Read GeneralJournalEntries from Abacus (OData V4, paginated)."""
+        headers = await self._buildAuthHeaders(config)
+        if not headers:
+            return []
+
+        filterParts = []
+        if dateFrom:
+            filterParts.append(f"JournalDate ge {dateFrom}")
+        if dateTo:
+            filterParts.append(f"JournalDate le {dateTo}")
+        queryParams = ""
+        if filterParts:
+            queryParams = "?$filter=" + " and ".join(filterParts)
+
+        entries: List[Dict[str, Any]] = []
+        url: Optional[str] = self._buildEntityUrl(config, f"GeneralJournalEntries{queryParams}")
+        try:
+            async with aiohttp.ClientSession() as session:
+                while url:
+                    async with session.get(url, headers=headers, timeout=aiohttp.ClientTimeout(total=60)) as resp:
+                        if resp.status != 200:
+                            break
+                        data = await resp.json()
+
+                        for item in data.get("value", []):
+                            lines = []
+                            totalAmt = 0.0
+                            for line in (item.get("Lines") or []):
+                                debit = float(line.get("DebitAmount", 0))
+                                credit = float(line.get("CreditAmount", 0))
+                                lines.append({
+                                    "accountNumber": str(line.get("AccountId", "")),
+                                    "debitAmount": debit,
+                                    "creditAmount": credit,
+                                    "description": line.get("Text", ""),
+                                    "taxCode": line.get("TaxCode"),
+                                    "costCenter": line.get("CostCenterId"),
+                                })
+                                totalAmt += max(debit, credit)
+                            entries.append({
+                                "externalId": str(item.get("Id", "")),
+                                "bookingDate": str(item.get("JournalDate", "")).split("T")[0],
+                                "reference": item.get("Reference", ""),
+                                "description": item.get("Text", ""),
+                                "currency": "CHF",
+                                "totalAmount": totalAmt,
+                                "lines": lines,
+                            })
+                        url = data.get("@odata.nextLink")
+        except Exception as e:
+            logger.error(f"Abacus getJournalEntries error: {e}")
+        return entries
+
     async def getCustomers(self, config: Dict[str, Any]) -> List[Dict[str, Any]]:
         headers = await self._buildAuthHeaders(config)
         if not headers:
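With both dates set, the first page URL assembled above comes out as follows (host and client name invented; later pages follow @odata.nextLink until it is absent):

    url = ("https://abacus.meinefirma.ch/api/entity/v1/7777/GeneralJournalEntries"
           "?$filter=JournalDate ge 2025-01-01 and JournalDate le 2025-03-31")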
@@ -24,7 +24,7 @@ from ..accountingConnectorBase import (
 
 logger = logging.getLogger(__name__)
 
-_BASE_URL = "https://api.bexio.com"
+_DEFAULT_API_BASE_URL = "https://api.bexio.com/"
 
 
 class AccountingConnectorBexio(BaseAccountingConnector):
@@ -40,6 +40,20 @@ class AccountingConnectorBexio(BaseAccountingConnector):
 
     def getRequiredConfigFields(self) -> List[ConnectorConfigField]:
         return [
+            ConnectorConfigField(
+                key="apiBaseUrl",
+                label={"en": "API Base URL", "de": "API Base URL", "fr": "URL de base API"},
+                fieldType="text",
+                secret=False,
+                placeholder="https://api.bexio.com/",
+            ),
+            ConnectorConfigField(
+                key="clientName",
+                label={"en": "Client Name", "de": "Mandantenname", "fr": "Nom du client"},
+                fieldType="text",
+                secret=False,
+                placeholder="e.g. poweronag",
+            ),
             ConnectorConfigField(
                 key="accessToken",
                 label={"en": "Personal Access Token", "de": "Persönlicher Zugriffstoken", "fr": "Jeton d'accès personnel"},
@@ -49,6 +63,14 @@ class AccountingConnectorBexio(BaseAccountingConnector):
             ),
         ]
 
+    def _buildUrl(self, config: Dict[str, Any], resource: str) -> str:
+        apiBaseUrl = str(config.get("apiBaseUrl") or "").strip()
+        if not apiBaseUrl:
+            raise ValueError("Missing required config: apiBaseUrl")
+        apiBaseUrl = apiBaseUrl.rstrip("/")
+        resourcePath = resource.lstrip("/")
+        return f"{apiBaseUrl}/{resourcePath}"
+
     def _buildHeaders(self, config: Dict[str, Any]) -> Dict[str, str]:
         return {
             "Authorization": f"Bearer {config['accessToken']}",
@@ -57,9 +79,20 @@ class AccountingConnectorBexio(BaseAccountingConnector):
         }
 
     async def testConnection(self, config: Dict[str, Any]) -> SyncResult:
+        apiBaseUrl = str(config.get("apiBaseUrl") or "")
+        clientName = str(config.get("clientName") or "")
+        accessToken = str(config.get("accessToken") or "")
+        if not apiBaseUrl or not clientName or not accessToken:
+            return SyncResult(
+                success=False,
+                errorMessage=(
+                    f"Missing credentials: apiBaseUrl={bool(apiBaseUrl)}, "
+                    f"clientName={bool(clientName)}, accessToken={bool(accessToken)}"
+                ),
+            )
         try:
             async with aiohttp.ClientSession() as session:
-                async with session.get(f"{_BASE_URL}/3.0/users/me", headers=self._buildHeaders(config), timeout=aiohttp.ClientTimeout(total=15)) as resp:
+                async with session.get(self._buildUrl(config, "3.0/users/me"), headers=self._buildHeaders(config), timeout=aiohttp.ClientTimeout(total=15)) as resp:
                     if resp.status == 200:
                         return SyncResult(success=True)
                     body = await resp.text()
@@ -75,7 +108,7 @@ class AccountingConnectorBexio(BaseAccountingConnector):
 
         try:
             async with aiohttp.ClientSession() as session:
-                async with session.get(f"{_BASE_URL}/2.0/accounts", headers=self._buildHeaders(config), timeout=aiohttp.ClientTimeout(total=30)) as resp:
+                async with session.get(self._buildUrl(config, "2.0/accounts"), headers=self._buildHeaders(config), timeout=aiohttp.ClientTimeout(total=30)) as resp:
                     if resp.status != 200:
                         return []
                     accounts = await resp.json()
@@ -139,7 +172,7 @@ class AccountingConnectorBexio(BaseAccountingConnector):
         }
 
         async with aiohttp.ClientSession() as session:
-            url = f"{_BASE_URL}/3.0/accounting/manual-entries"
+            url = self._buildUrl(config, "3.0/accounting/manual-entries")
             async with session.post(url, headers=self._buildHeaders(config), json=payload, timeout=aiohttp.ClientTimeout(total=30)) as resp:
                 body = await resp.json() if resp.content_type == "application/json" else {"raw": await resp.text()}
                 if resp.status in (200, 201):
@@ -152,7 +185,7 @@ class AccountingConnectorBexio(BaseAccountingConnector):
     async def getBookingStatus(self, config: Dict[str, Any], externalId: str) -> SyncResult:
         try:
             async with aiohttp.ClientSession() as session:
-                url = f"{_BASE_URL}/3.0/accounting/manual-entries/{externalId}"
+                url = self._buildUrl(config, f"3.0/accounting/manual-entries/{externalId}")
                 async with session.get(url, headers=self._buildHeaders(config), timeout=aiohttp.ClientTimeout(total=15)) as resp:
                     if resp.status == 200:
                         return SyncResult(success=True, externalId=externalId)
@@ -160,10 +193,66 @@ class AccountingConnectorBexio(BaseAccountingConnector):
         except Exception as e:
             return SyncResult(success=False, errorMessage=str(e))
 
+    async def getJournalEntries(self, config: Dict[str, Any], dateFrom: Optional[str] = None, dateTo: Optional[str] = None, accountNumbers: Optional[List[str]] = None) -> List[Dict[str, Any]]:
+        """Read manual entries from Bexio. API: GET 3.0/accounting/manual-entries"""
+        try:
+            accounts = await self._loadRawAccounts(config)
+            accMap = {acc.get("id"): str(acc.get("account_no", "")) for acc in accounts}
+
+            async with aiohttp.ClientSession() as session:
+                url = self._buildUrl(config, "3.0/accounting/manual-entries")
+                params: Dict[str, str] = {}
+                if dateFrom:
+                    params["date_from"] = dateFrom
+                if dateTo:
+                    params["date_to"] = dateTo
+                async with session.get(url, headers=self._buildHeaders(config), params=params, timeout=aiohttp.ClientTimeout(total=60)) as resp:
+                    if resp.status != 200:
+                        logger.error(f"Bexio getJournalEntries failed: HTTP {resp.status}")
+                        return []
+                    items = await resp.json()
+
+            entries = []
+            for item in (items if isinstance(items, list) else []):
+                lines = []
+                totalAmt = 0.0
+                for e in (item.get("entries") or []):
+                    amt = float(e.get("amount", 0))
+                    debitAccId = e.get("debit_account_id")
+                    creditAccId = e.get("credit_account_id")
+                    lines.append({
+                        "accountNumber": accMap.get(debitAccId, str(debitAccId or "")),
+                        "debitAmount": amt,
+                        "creditAmount": 0.0,
+                        "description": e.get("description", ""),
+                        "taxCode": str(e.get("tax_id", "")) if e.get("tax_id") else None,
+                    })
+                    if creditAccId and creditAccId != debitAccId:
+                        lines.append({
+                            "accountNumber": accMap.get(creditAccId, str(creditAccId or "")),
+                            "debitAmount": 0.0,
+                            "creditAmount": amt,
+                            "description": e.get("description", ""),
+                        })
+                    totalAmt += amt
+                entries.append({
+                    "externalId": str(item.get("id", "")),
+                    "bookingDate": item.get("date", ""),
+                    "reference": item.get("reference_nr", ""),
+                    "description": item.get("text", ""),
+                    "currency": "CHF",
+                    "totalAmount": totalAmt,
+                    "lines": lines,
+                })
+            return entries
+        except Exception as e:
+            logger.error(f"Bexio getJournalEntries error: {e}")
+            return []
+
     async def getCustomers(self, config: Dict[str, Any]) -> List[Dict[str, Any]]:
         try:
             async with aiohttp.ClientSession() as session:
-                async with session.get(f"{_BASE_URL}/2.0/contact", headers=self._buildHeaders(config), timeout=aiohttp.ClientTimeout(total=30)) as resp:
+                async with session.get(self._buildUrl(config, "2.0/contact"), headers=self._buildHeaders(config), timeout=aiohttp.ClientTimeout(total=30)) as resp:
                     if resp.status != 200:
                         return []
                     return await resp.json()
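Note on the mapping above: a Bexio manual-entry line carries one signed amount plus a debit and a credit account id, so a single source line fans out into up to two local journal lines. A sketch with invented values:

    e = {"amount": 150.0, "debit_account_id": 11, "credit_account_id": 27,
         "description": "Postage"}
    accMap = {11: "6510", 27: "1020"}
    # -> [{"accountNumber": "6510", "debitAmount": 150.0, "creditAmount": 0.0, ...},
    #     {"accountNumber": "1020", "debitAmount": 0.0, "creditAmount": 150.0, ...}]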
@@ -3,7 +3,8 @@
 """Run My Accounts (Infoniqa) accounting connector.
 
 API docs: https://runmyaccountsag.github.io/runmyaccounts-rest-api/
-Auth: API key (incl. ``pat_`` tokens since Sep 2025) via ``X-RMA-KEY`` request header.
+Auth: PAT tokens (``pat_...``) via ``Authorization: Bearer``.
+Fallback for legacy API keys via ``X-RMA-KEY``.
 Base URL: https://service.runmyaccounts.com/api/latest/clients/{clientName}/
 """
 
@@ -26,7 +27,7 @@ from ..accountingConnectorBase import (
 
 logger = logging.getLogger(__name__)
 
-_BASE_URL = "https://service.runmyaccounts.com/api/latest/clients"
+_DEFAULT_API_BASE_URL = "https://service.runmyaccounts.com/api/latest/clients/"
 
 
 class AccountingConnectorRma(BaseAccountingConnector):
@@ -39,6 +40,13 @@ class AccountingConnectorRma(BaseAccountingConnector):
 
     def getRequiredConfigFields(self) -> List[ConnectorConfigField]:
         return [
+            ConnectorConfigField(
+                key="apiBaseUrl",
+                label={"en": "API Base URL", "de": "API Base URL", "fr": "URL de base API"},
+                fieldType="text",
+                secret=False,
+                placeholder="https://service.runmyaccounts.com/api/latest/clients/",
+            ),
             ConnectorConfigField(
                 key="clientName",
                 label={"en": "Client Name", "de": "Mandantenname", "fr": "Nom du client"},
@@ -55,33 +63,55 @@ class AccountingConnectorRma(BaseAccountingConnector):
         ]
 
     def _buildUrl(self, config: Dict[str, Any], resource: str) -> str:
-        clientName = config.get("clientName", "")
-        return f"{_BASE_URL}/{clientName}/{resource}"
+        apiBaseUrl = str(config.get("apiBaseUrl") or "").strip()
+        if not apiBaseUrl:
+            raise ValueError("Missing required config: apiBaseUrl")
+        apiBaseUrl = apiBaseUrl.rstrip("/") + "/"
+
+        clientName = str(config.get("clientName") or "").strip()
+        if not clientName:
+            raise ValueError("Missing required config: clientName")
+        return f"{apiBaseUrl}{clientName}/{resource}"
 
     def _buildHeaders(self, config: Dict[str, Any]) -> Dict[str, str]:
         apiKey = config.get("apiKey", "")
-        return {
-            "X-RMA-KEY": apiKey,
+        headers = {
             "Accept": "application/json, application/xml, */*",
             "Content-Type": "application/json",
         }
+        if str(apiKey).startswith("pat_"):
+            headers["Authorization"] = f"Bearer {apiKey}"
+        else:
+            headers["X-RMA-KEY"] = apiKey
+        return headers
 
     async def testConnection(self, config: Dict[str, Any]) -> SyncResult:
         clientName = config.get("clientName", "")
         apiKey = config.get("apiKey", "")
-
-        if not clientName or not apiKey:
-            return SyncResult(success=False, errorMessage=f"Missing credentials: clientName={bool(clientName)}, apiKey={bool(apiKey)}")
+        apiBaseUrl = str(config.get("apiBaseUrl") or "")
+        if not clientName or not apiKey or not apiBaseUrl:
+            return SyncResult(
+                success=False,
+                errorMessage=(
+                    f"Missing credentials: apiBaseUrl={bool(apiBaseUrl)}, "
+                    f"clientName={bool(clientName)}, apiKey={bool(apiKey)}"
+                ),
+            )
 
         url = self._buildUrl(config, "customers")
         headers = self._buildHeaders(config)
-        logger.info("RMA testConnection: url=%s, clientName=%s, apiKey=%s...", url, clientName, apiKey[:6] if len(apiKey) > 6 else "***")
+        authMethod = "Bearer" if str(apiKey).startswith("pat_") else "X-RMA-KEY"
+        logger.info(
+            "RMA testConnection: url=%s, clientName=%s, apiKey=%s..., auth=%s",
+            url, clientName, apiKey[:6] if len(apiKey) > 6 else "***", authMethod,
+        )
         try:
             async with aiohttp.ClientSession() as session:
                 async with session.get(url, headers=headers, timeout=aiohttp.ClientTimeout(total=15)) as resp:
                     if resp.status == 200:
-                        logger.info("RMA connection successful")
-                        return SyncResult(success=True)
+                        logger.info("RMA connection successful with auth method: %s", authMethod)
+                        return SyncResult(success=True, rawResponse={"authMethod": authMethod})
                     body = await resp.text()
                     logger.warning("RMA testConnection failed: status=%s, url=%s, body=%s", resp.status, url, body[:500])
                     return SyncResult(success=False, errorMessage=f"HTTP {resp.status}: {body[:300]}")
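The two credential styles handled by _buildHeaders above produce (keys invented):

    # _buildHeaders({"apiKey": "pat_abc123"})
    #   -> {"Accept": "application/json, application/xml, */*",
    #       "Content-Type": "application/json",
    #       "Authorization": "Bearer pat_abc123"}
    # _buildHeaders({"apiKey": "legacy-key-xyz"})
    #   -> same Accept/Content-Type plus {"X-RMA-KEY": "legacy-key-xyz"}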
@@ -120,11 +150,11 @@ class AccountingConnectorRma(BaseAccountingConnector):
         charts = []
         items = data if isinstance(data, list) else data.get("chart", data.get("row", []))
         if not isinstance(items, list):
-            items = []
+            items = [items] if isinstance(items, dict) else []
         for item in items:
             if isinstance(item, dict):
-                accNo = str(item.get("accno", item.get("account_number", "")))
-                label = str(item.get("description", item.get("label", "")))
+                accNo = str(item.get("accno") or item.get("account_number") or item.get("number") or item.get("@accno") or "")
+                label = str(item.get("description") or item.get("label") or item.get("@description") or "")
                 rmaLink = item.get("link") or ""
                 chartType = item.get("charttype") or item.get("category") or ""
                 if not chartType and rmaLink:
@@ -308,6 +338,169 @@ class AccountingConnectorRma(BaseAccountingConnector):
             logger.debug("RMA isBookingSynced error: %s – trust local", e)
             return SyncResult(success=True)
 
+    async def getJournalEntries(self, config: Dict[str, Any], dateFrom: Optional[str] = None, dateTo: Optional[str] = None, accountNumbers: Optional[List[str]] = None) -> List[Dict[str, Any]]:
+        """Read GL entries from RMA.
+
+        Strategy: first try GET /gl (bulk), then fall back to iterating
+        account transactions. Uses pre-fetched accountNumbers if provided.
+        """
+        try:
+            params: Dict[str, str] = {}
+            if dateFrom:
+                params["from_date"] = dateFrom
+            if dateTo:
+                params["to_date"] = dateTo
+
+            # Try bulk GL endpoint first
+            bulkEntries = await self._fetchGlBulk(config, params)
+            if bulkEntries:
+                return bulkEntries
+
+            # Fallback: iterate accounts and fetch transactions
+            if accountNumbers:
+                accNums = accountNumbers
+            else:
+                chart = await self.getChartOfAccounts(config)
+                accNums = [acc.accountNumber for acc in chart if acc.accountNumber]
+            if not accNums:
+                return []
+
+            entriesByRef: Dict[str, Dict[str, Any]] = {}
+            fetchedCount = 0
+            emptyCount = 0
+            errorCount = 0
+            async with aiohttp.ClientSession() as session:
+                for accNo in accNums:
+                    url = self._buildUrl(config, f"charts/{accNo}/transactions")
+                    try:
+                        async with session.get(url, headers=self._buildHeaders(config), params=params, timeout=aiohttp.ClientTimeout(total=10)) as resp:
+                            if resp.status != 200:
+                                emptyCount += 1
+                                continue
+                            body = await resp.text()
+                            if not body.strip():
+                                emptyCount += 1
+                                continue
+                            try:
+                                data = json.loads(body)
+                            except Exception:
+                                errorCount += 1
+                                continue
+                    except (asyncio.TimeoutError, Exception):
+                        errorCount += 1
+                        continue
+                    fetchedCount += 1
+
+                    if isinstance(data, dict):
+                        transactions = data.get("transaction") or data.get("@transaction")
+                    else:
+                        transactions = data
+                    if isinstance(transactions, dict):
+                        transactions = [transactions]
+                    if not isinstance(transactions, list):
+                        continue
+
+                    for t in transactions:
+                        if not isinstance(t, dict):
+                            continue
+                        ref = t.get("reference") or t.get("@reference") or t.get("batch_number") or str(t.get("id") or "")
+                        transDate = str(t.get("transdate") or t.get("@transdate") or "").split("T")[0]
+                        desc = t.get("description") or t.get("memo") or t.get("@description") or ""
+
+                        rawAmount = float(t.get("amount") or t.get("@amount") or 0)
+                        debit = rawAmount if rawAmount > 0 else 0.0
+                        credit = abs(rawAmount) if rawAmount < 0 else 0.0
+
+                        if ref not in entriesByRef:
+                            entriesByRef[ref] = {
+                                "externalId": str(t.get("id") or t.get("@id") or ref),
+                                "bookingDate": transDate,
+                                "reference": ref,
+                                "description": desc,
+                                "currency": "CHF",
+                                "totalAmount": 0.0,
+                                "lines": [],
+                            }
+                        entry = entriesByRef[ref]
+                        entry["lines"].append({
+                            "accountNumber": accNo,
+                            "debitAmount": debit,
+                            "creditAmount": credit,
+                            "description": desc,
+                        })
+                        entry["totalAmount"] += max(debit, credit)
+
+            return list(entriesByRef.values())
+        except Exception as e:
+            logger.error(f"RMA getJournalEntries error: {e}", exc_info=True)
+            return []
+
+    async def _fetchGlBulk(self, config: Dict[str, Any], params: Dict[str, str]) -> List[Dict[str, Any]]:
+        """Try GET /gl to fetch journal entries in bulk (not all RMA versions support this)."""
+        try:
+            async with aiohttp.ClientSession() as session:
+                url = self._buildUrl(config, "gl")
+                async with session.get(url, headers=self._buildHeaders(config), params=params, timeout=aiohttp.ClientTimeout(total=60)) as resp:
+                    if resp.status != 200:
+                        return []
+                    body = await resp.text()
+                    if not body.strip():
+                        return []
+                    try:
+                        data = json.loads(body)
+                    except Exception:
+                        return []
+
+            items = data if isinstance(data, list) else (data.get("gl_batch") or data.get("gl") or data.get("items") or [])
+            if isinstance(items, dict):
+                items = [items]
+            if not isinstance(items, list):
+                return []
+
+            entries = []
+            for batch in items:
+                if not isinstance(batch, dict):
+                    continue
+                transDate = str(batch.get("transdate") or batch.get("date") or "").split("T")[0]
+                ref = batch.get("batch_number") or batch.get("reference") or str(batch.get("id", ""))
+                desc = batch.get("description") or batch.get("notes") or ""
+
+                rawTxns = batch.get("gl_transactions", {})
+                txnList = rawTxns.get("gl_transaction") if isinstance(rawTxns, dict) else rawTxns
+                if isinstance(txnList, dict):
+                    txnList = [txnList]
+                if not isinstance(txnList, list):
+                    txnList = []
+
+                lines = []
+                totalAmt = 0.0
+                for t in txnList:
+                    if not isinstance(t, dict):
+                        continue
+                    debit = float(t.get("debit_amount") or 0)
+                    credit = float(t.get("credit_amount") or 0)
+                    lines.append({
+                        "accountNumber": str(t.get("accno", "")),
+                        "debitAmount": debit,
+                        "creditAmount": credit,
+                        "description": t.get("memo", ""),
+                    })
+                    totalAmt += max(debit, credit)
+
+                entries.append({
+                    "externalId": str(batch.get("id", ref)),
+                    "bookingDate": transDate,
+                    "reference": ref,
+                    "description": desc,
+                    "currency": batch.get("currency", "CHF"),
+                    "totalAmount": totalAmt,
+                    "lines": lines,
+                })
+            return entries
+        except Exception as e:
+            logger.debug(f"RMA _fetchGlBulk not available: {e}")
+            return []
+
     async def pushInvoice(self, config: Dict[str, Any], invoice: Dict[str, Any]) -> SyncResult:
         try:
             async with aiohttp.ClientSession() as session:
@@ -327,8 +520,8 @@ class AccountingConnectorRma(BaseAccountingConnector):
                 async with session.get(url, headers=self._buildHeaders(config), timeout=aiohttp.ClientTimeout(total=30)) as resp:
                     if resp.status != 200:
                         return []
-                    data = await resp.json()
-                    return data if isinstance(data, list) else data.get("customer", [])
+                    data = await self._parseJsonOrXmlList(resp, "customer")
+                    return data
         except Exception as e:
             logger.error(f"RMA getCustomers error: {e}")
             return []
@@ -340,12 +533,39 @@ class AccountingConnectorRma(BaseAccountingConnector):
                 async with session.get(url, headers=self._buildHeaders(config), timeout=aiohttp.ClientTimeout(total=30)) as resp:
                     if resp.status != 200:
                         return []
-                    data = await resp.json()
-                    return data if isinstance(data, list) else data.get("vendor", [])
+                    data = await self._parseJsonOrXmlList(resp, "vendor")
+                    return data
         except Exception as e:
             logger.error(f"RMA getVendors error: {e}")
             return []
 
+    async def _parseJsonOrXmlList(self, resp: aiohttp.ClientResponse, itemKey: str) -> List[Dict[str, Any]]:
+        """Parse RMA response that may be JSON or XML. Returns list of dicts."""
+        body = await resp.text()
+        if not body or not body.strip():
+            return []
+        try:
+            data = json.loads(body)
+            if isinstance(data, list):
+                return data
+            if isinstance(data, dict):
+                items = data.get(itemKey) or data.get("items") or data.get("row") or []
+                if isinstance(items, dict):
+                    return [items]
+                return items if isinstance(items, list) else []
+            return []
+        except (json.JSONDecodeError, ValueError):
+            pass
+        result: List[Dict[str, Any]] = []
+        ids = re.findall(r"<id>([^<]+)</id>", body)
+        names = re.findall(r"<name>([^<]+)</name>", body)
+        for i, rid in enumerate(ids):
+            entry: Dict[str, Any] = {"id": rid.strip()}
+            if i < len(names):
+                entry["name"] = names[i].strip()
+            result.append(entry)
+        return result
+
     async def _findBelegByFilename(self, config: Dict[str, Any], session: aiohttp.ClientSession, fileName: str) -> Optional[str]:
         """Try GET /belege (undocumented) to find an existing beleg by filename."""
         try:
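The regex fallback in _parseJsonOrXmlList only recovers <id>/<name> pairs, which is enough for contact lists. An example of what it extracts (payload invented):

    body = "<customers><customer><id>17</id><name>Acme AG</name></customer></customers>"
    # -> [{"id": "17", "name": "Acme AG"}]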
@@ -736,6 +736,177 @@ registerModelLabels(
 )
 
 
+# ── TrusteeData* tables (synced from external accounting apps for analysis) ──
+
+
+class TrusteeDataAccount(BaseModel):
+    """Chart of accounts synced from external accounting system."""
+    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
+    accountNumber: str = Field(description="Account number (e.g. '1020')")
+    label: str = Field(default="", description="Account name")
+    accountType: Optional[str] = Field(default=None, description="asset / liability / equity / revenue / expense")
+    accountGroup: Optional[str] = Field(default=None, description="Account group/category")
+    currency: str = Field(default="CHF", description="Account currency")
+    isActive: bool = Field(default=True)
+    mandateId: Optional[str] = Field(default=None)
+    featureInstanceId: Optional[str] = Field(default=None)
+
+
+registerModelLabels(
+    "TrusteeDataAccount",
+    {"en": "Account (Synced)", "de": "Konto (Sync)", "fr": "Compte (Sync)"},
+    {
+        "id": {"en": "ID", "de": "ID", "fr": "ID"},
+        "accountNumber": {"en": "Account Number", "de": "Kontonummer", "fr": "Numéro de compte"},
+        "label": {"en": "Name", "de": "Bezeichnung", "fr": "Libellé"},
+        "accountType": {"en": "Type", "de": "Typ", "fr": "Type"},
+        "accountGroup": {"en": "Group", "de": "Gruppe", "fr": "Groupe"},
+        "currency": {"en": "Currency", "de": "Währung", "fr": "Devise"},
+        "isActive": {"en": "Active", "de": "Aktiv", "fr": "Actif"},
+        "mandateId": {"en": "Mandate", "de": "Mandat", "fr": "Mandat"},
+        "featureInstanceId": {"en": "Feature Instance", "de": "Feature-Instanz", "fr": "Instance"},
+    },
+)
+
+
+class TrusteeDataJournalEntry(BaseModel):
+    """Journal entry header synced from external accounting system."""
+    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
+    externalId: Optional[str] = Field(default=None, description="ID in the source system")
+    bookingDate: Optional[str] = Field(default=None, description="Booking date (YYYY-MM-DD)")
+    reference: Optional[str] = Field(default=None, description="Booking reference / voucher number")
+    description: str = Field(default="", description="Booking text")
+    currency: str = Field(default="CHF")
+    totalAmount: float = Field(default=0.0, description="Total amount of entry")
+    mandateId: Optional[str] = Field(default=None)
+    featureInstanceId: Optional[str] = Field(default=None)
+
+
+registerModelLabels(
+    "TrusteeDataJournalEntry",
+    {"en": "Journal Entry (Synced)", "de": "Buchung (Sync)", "fr": "Écriture (Sync)"},
+    {
+        "id": {"en": "ID", "de": "ID", "fr": "ID"},
+        "externalId": {"en": "External ID", "de": "Externe ID", "fr": "ID externe"},
+        "bookingDate": {"en": "Date", "de": "Datum", "fr": "Date"},
+        "reference": {"en": "Reference", "de": "Referenz", "fr": "Référence"},
+        "description": {"en": "Description", "de": "Beschreibung", "fr": "Description"},
+        "currency": {"en": "Currency", "de": "Währung", "fr": "Devise"},
+        "totalAmount": {"en": "Amount", "de": "Betrag", "fr": "Montant"},
+        "mandateId": {"en": "Mandate", "de": "Mandat", "fr": "Mandat"},
+        "featureInstanceId": {"en": "Feature Instance", "de": "Feature-Instanz", "fr": "Instance"},
+    },
+)
+
+
+class TrusteeDataJournalLine(BaseModel):
+    """Journal entry line (debit/credit) synced from external accounting system."""
+    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
+    journalEntryId: str = Field(description="FK → TrusteeDataJournalEntry.id")
+    accountNumber: str = Field(description="Account number")
+    debitAmount: float = Field(default=0.0)
+    creditAmount: float = Field(default=0.0)
+    currency: str = Field(default="CHF")
+    taxCode: Optional[str] = Field(default=None)
+    costCenter: Optional[str] = Field(default=None)
+    description: str = Field(default="")
+    mandateId: Optional[str] = Field(default=None)
+    featureInstanceId: Optional[str] = Field(default=None)
+
+
+registerModelLabels(
+    "TrusteeDataJournalLine",
+    {"en": "Journal Line (Synced)", "de": "Buchungszeile (Sync)", "fr": "Ligne écriture (Sync)"},
+    {
+        "id": {"en": "ID", "de": "ID", "fr": "ID"},
+        "journalEntryId": {"en": "Journal Entry", "de": "Buchung", "fr": "Écriture"},
+        "accountNumber": {"en": "Account", "de": "Konto", "fr": "Compte"},
+        "debitAmount": {"en": "Debit", "de": "Soll", "fr": "Débit"},
+        "creditAmount": {"en": "Credit", "de": "Haben", "fr": "Crédit"},
+        "currency": {"en": "Currency", "de": "Währung", "fr": "Devise"},
+        "taxCode": {"en": "Tax Code", "de": "Steuercode", "fr": "Code TVA"},
+        "costCenter": {"en": "Cost Center", "de": "Kostenstelle", "fr": "Centre de coûts"},
+        "description": {"en": "Description", "de": "Beschreibung", "fr": "Description"},
+        "mandateId": {"en": "Mandate", "de": "Mandat", "fr": "Mandat"},
+        "featureInstanceId": {"en": "Feature Instance", "de": "Feature-Instanz", "fr": "Instance"},
+    },
+)
+
+
+class TrusteeDataContact(BaseModel):
+    """Customer or vendor synced from external accounting system."""
+    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
+    externalId: Optional[str] = Field(default=None, description="ID in the source system")
+    contactType: str = Field(default="customer", description="customer / vendor / both")
+    contactNumber: Optional[str] = Field(default=None, description="Customer/vendor number")
+    name: str = Field(default="", description="Name / company")
+    address: Optional[str] = Field(default=None)
+    zip: Optional[str] = Field(default=None)
+    city: Optional[str] = Field(default=None)
+    country: Optional[str] = Field(default=None)
+    email: Optional[str] = Field(default=None)
+    phone: Optional[str] = Field(default=None)
+    vatNumber: Optional[str] = Field(default=None)
+    mandateId: Optional[str] = Field(default=None)
+    featureInstanceId: Optional[str] = Field(default=None)
+
+
+registerModelLabels(
+    "TrusteeDataContact",
+    {"en": "Contact (Synced)", "de": "Kontakt (Sync)", "fr": "Contact (Sync)"},
+    {
+        "id": {"en": "ID", "de": "ID", "fr": "ID"},
+        "externalId": {"en": "External ID", "de": "Externe ID", "fr": "ID externe"},
+        "contactType": {"en": "Type", "de": "Typ", "fr": "Type"},
+        "contactNumber": {"en": "Number", "de": "Nummer", "fr": "Numéro"},
+        "name": {"en": "Name", "de": "Name", "fr": "Nom"},
+        "address": {"en": "Address", "de": "Adresse", "fr": "Adresse"},
+        "zip": {"en": "ZIP", "de": "PLZ", "fr": "NPA"},
+        "city": {"en": "City", "de": "Ort", "fr": "Ville"},
+        "country": {"en": "Country", "de": "Land", "fr": "Pays"},
+        "email": {"en": "Email", "de": "E-Mail", "fr": "E-mail"},
+        "phone": {"en": "Phone", "de": "Telefon", "fr": "Téléphone"},
+        "vatNumber": {"en": "VAT Number", "de": "MWST-Nr.", "fr": "N° TVA"},
+        "mandateId": {"en": "Mandate", "de": "Mandat", "fr": "Mandat"},
+        "featureInstanceId": {"en": "Feature Instance", "de": "Feature-Instanz", "fr": "Instance"},
+    },
+)
+
+
+class TrusteeDataAccountBalance(BaseModel):
+    """Account balance per period, derived from journal lines or directly from accounting system."""
+    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
+    accountNumber: str = Field(description="Account number")
+    periodYear: int = Field(description="Fiscal year")
+    periodMonth: int = Field(default=0, description="Month (1-12); 0 = annual total")
+    openingBalance: float = Field(default=0.0)
+    debitTotal: float = Field(default=0.0)
+    creditTotal: float = Field(default=0.0)
+    closingBalance: float = Field(default=0.0)
+    currency: str = Field(default="CHF")
+    mandateId: Optional[str] = Field(default=None)
+    featureInstanceId: Optional[str] = Field(default=None)
+
+
+registerModelLabels(
+    "TrusteeDataAccountBalance",
+    {"en": "Account Balance (Synced)", "de": "Kontosaldo (Sync)", "fr": "Solde compte (Sync)"},
+    {
+        "id": {"en": "ID", "de": "ID", "fr": "ID"},
+        "accountNumber": {"en": "Account", "de": "Konto", "fr": "Compte"},
+        "periodYear": {"en": "Year", "de": "Jahr", "fr": "Année"},
+        "periodMonth": {"en": "Month", "de": "Monat", "fr": "Mois"},
+        "openingBalance": {"en": "Opening Balance", "de": "Eröffnungssaldo", "fr": "Solde d'ouverture"},
+        "debitTotal": {"en": "Debit Total", "de": "Soll-Umsatz", "fr": "Total débit"},
+        "creditTotal": {"en": "Credit Total", "de": "Haben-Umsatz", "fr": "Total crédit"},
+        "closingBalance": {"en": "Closing Balance", "de": "Schlusssaldo", "fr": "Solde de clôture"},
+        "currency": {"en": "Currency", "de": "Währung", "fr": "Devise"},
+        "mandateId": {"en": "Mandate", "de": "Mandat", "fr": "Mandat"},
+        "featureInstanceId": {"en": "Feature Instance", "de": "Feature-Instanz", "fr": "Instance"},
+    },
+)
+
+
 class TrusteeAccountingConfig(BaseModel):
     """Per-instance accounting system configuration with encrypted credentials.
 
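When reading the balance table back, periodMonth separates monthly buckets from the annual total that _computeBalances writes alongside them (filter values invented):

    # March 2025 figure for account 1020:
    monthlyFilter = {"accountNumber": "1020", "periodYear": 2025, "periodMonth": 3}
    # Annual 2025 total for the same account (periodMonth == 0 by convention):
    annualFilter = {"accountNumber": "1020", "periodYear": 2025, "periodMonth": 0}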
@@ -78,6 +78,31 @@ DATA_OBJECTS = [
         "label": {"en": "Accounting Sync", "de": "Buchhaltungs-Synchronisation", "fr": "Sync. comptable"},
         "meta": {"table": "TrusteeAccountingSync", "fields": ["id", "positionId", "syncStatus", "externalId"]}
     },
+    {
+        "objectKey": "data.feature.trustee.TrusteeDataAccount",
+        "label": {"en": "Accounts (Synced)", "de": "Kontenplan (Sync)", "fr": "Plan comptable (Sync)"},
+        "meta": {"table": "TrusteeDataAccount", "fields": ["id", "accountNumber", "label", "accountType", "accountGroup", "currency", "isActive"]}
+    },
+    {
+        "objectKey": "data.feature.trustee.TrusteeDataJournalEntry",
+        "label": {"en": "Journal Entries (Synced)", "de": "Buchungen (Sync)", "fr": "Écritures (Sync)"},
+        "meta": {"table": "TrusteeDataJournalEntry", "fields": ["id", "externalId", "bookingDate", "reference", "description", "currency", "totalAmount"]}
+    },
+    {
+        "objectKey": "data.feature.trustee.TrusteeDataJournalLine",
+        "label": {"en": "Journal Lines (Synced)", "de": "Buchungszeilen (Sync)", "fr": "Lignes écriture (Sync)"},
+        "meta": {"table": "TrusteeDataJournalLine", "fields": ["id", "journalEntryId", "accountNumber", "debitAmount", "creditAmount", "currency", "taxCode", "costCenter", "description"]}
+    },
+    {
+        "objectKey": "data.feature.trustee.TrusteeDataContact",
+        "label": {"en": "Contacts (Synced)", "de": "Kontakte (Sync)", "fr": "Contacts (Sync)"},
+        "meta": {"table": "TrusteeDataContact", "fields": ["id", "externalId", "contactType", "contactNumber", "name", "address", "zip", "city", "country", "email", "phone", "vatNumber"]}
+    },
+    {
+        "objectKey": "data.feature.trustee.TrusteeDataAccountBalance",
+        "label": {"en": "Account Balances (Synced)", "de": "Kontosalden (Sync)", "fr": "Soldes comptes (Sync)"},
+        "meta": {"table": "TrusteeDataAccountBalance", "fields": ["id", "accountNumber", "periodYear", "periodMonth", "openingBalance", "debitTotal", "creditTotal", "closingBalance", "currency"]}
+    },
     {
         "objectKey": "data.feature.trustee.*",
         "label": {"en": "All Trustee Data", "de": "Alle Treuhand-Daten", "fr": "Toutes les données fiduciaires"},
@@ -188,7 +188,7 @@ def get_mime_type_options(
     """Get supported MIME types from the document extraction service.
     Returns: [{ value: "mime/type", label: "Description" }]
     """
-    from modules.services.serviceExtraction.subRegistry import ExtractorRegistry
+    from modules.serviceCenter.services.serviceExtraction.subRegistry import ExtractorRegistry
 
     registry = ExtractorRegistry()
     formats = registry.getSupportedFormats()
@@ -1481,6 +1481,63 @@ def get_position_sync_status(
     return {"items": items}
 
 
+# ===== Accounting Data Import =====
+
+@router.post("/{instanceId}/accounting/import-data")
+@limiter.limit("3/minute")
+async def import_accounting_data(
+    request: Request,
+    instanceId: str = Path(..., description="Feature Instance ID"),
+    data: Dict[str, Any] = Body(default={}),
+    context: RequestContext = Depends(getRequestContext)
+) -> Dict[str, Any]:
+    """Import accounting data (chart, journal entries, contacts) from the external system into TrusteeData* tables."""
+    mandateId = _validateInstanceAccess(instanceId, context)
+    interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId)
+    from .accounting.accountingDataSync import AccountingDataSync
+    sync = AccountingDataSync(interface)
+    dateFrom = data.get("dateFrom")
+    dateTo = data.get("dateTo")
+    result = await sync.importData(
+        featureInstanceId=instanceId,
+        mandateId=mandateId,
+        dateFrom=dateFrom,
+        dateTo=dateTo,
+    )
+    return result
+
+
+@router.get("/{instanceId}/accounting/import-status")
+@limiter.limit("30/minute")
+def get_import_status(
+    request: Request,
+    instanceId: str = Path(..., description="Feature Instance ID"),
+    context: RequestContext = Depends(getRequestContext)
+) -> Dict[str, Any]:
+    """Get counts of imported TrusteeData* records for this instance."""
+    mandateId = _validateInstanceAccess(instanceId, context)
+    interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId)
+    from .datamodelFeatureTrustee import (
+        TrusteeDataAccount, TrusteeDataJournalEntry, TrusteeDataJournalLine,
+        TrusteeDataContact, TrusteeDataAccountBalance, TrusteeAccountingConfig,
+    )
+    filt = {"featureInstanceId": instanceId}
+    counts = {
+        "accounts": len(interface.db.getRecordset(TrusteeDataAccount, recordFilter=filt) or []),
+        "journalEntries": len(interface.db.getRecordset(TrusteeDataJournalEntry, recordFilter=filt) or []),
+        "journalLines": len(interface.db.getRecordset(TrusteeDataJournalLine, recordFilter=filt) or []),
+        "contacts": len(interface.db.getRecordset(TrusteeDataContact, recordFilter=filt) or []),
+        "accountBalances": len(interface.db.getRecordset(TrusteeDataAccountBalance, recordFilter=filt) or []),
+    }
+    cfgRecords = interface.db.getRecordset(TrusteeAccountingConfig, recordFilter={"featureInstanceId": instanceId, "isActive": True})
+    if cfgRecords:
+        cfg = cfgRecords[0]
+        counts["lastSyncAt"] = cfg.get("lastSyncAt")
+        counts["lastSyncStatus"] = cfg.get("lastSyncStatus")
+        counts["lastSyncErrorMessage"] = cfg.get("lastSyncErrorMessage")
+    return counts
+
+
 # ===== Position-Document Query =====
 
 @router.get("/{instanceId}/positions/document/{documentId}", response_model=List[TrusteePosition])
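Calling the two routes above from a script might look like this (a sketch: the mount prefix, host, and token are placeholders, and auth follows whatever getRequestContext expects):

    import requests

    base = "http://localhost:8000/api/trustee"      # hypothetical router prefix
    headers = {"Authorization": "Bearer <token>"}   # placeholder credentials

    # Trigger an import for one feature instance (rate-limited to 3/minute).
    resp = requests.post(
        f"{base}/fi-123/accounting/import-data",
        json={"dateFrom": "2025-01-01", "dateTo": "2025-12-31"},
        headers=headers,
    )
    print(resp.json())   # summary dict from AccountingDataSync.importData

    # Poll record counts and the last sync status.
    resp = requests.get(f"{base}/fi-123/accounting/import-status", headers=headers)
    print(resp.json())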
@@ -1,2 +1,3 @@
 # Copyright (c) 2025 Patrick Motsch
 # All rights reserved.
+"""Unified AI Workspace feature."""
@@ -1,9 +1,9 @@
 # Copyright (c) 2025 Patrick Motsch
 # All rights reserved.
 """
-CodeEditor Feature Container - Main Module.
+Workspace Feature Container - Main Module.
 Handles feature initialization and RBAC catalog registration.
-Cursor-style AI file editing via chat interface.
+Unified AI Workspace feature.
 """
 
 import logging
@@ -11,88 +11,108 @@ from typing import Dict, List, Any

logger = logging.getLogger(__name__)

FEATURE_CODE = "codeeditor"
FEATURE_LABEL = {"en": "Code Editor", "de": "Code Editor", "fr": "Code Editor"}
FEATURE_ICON = "mdi-file-document-edit"
FEATURE_CODE = "workspace"
FEATURE_LABEL = {"en": "AI Workspace", "de": "AI Workspace", "fr": "AI Workspace"}
FEATURE_ICON = "mdi-brain"

UI_OBJECTS = [
    {
        "objectKey": "ui.feature.codeeditor.editor",
        "objectKey": "ui.feature.workspace.dashboard",
        "label": {"en": "Dashboard", "de": "Dashboard", "fr": "Tableau de bord"},
        "meta": {"area": "dashboard"}
    },
    {
        "objectKey": "ui.feature.workspace.editor",
        "label": {"en": "Editor", "de": "Editor", "fr": "Editeur"},
        "meta": {"area": "editor"}
    },
    {
        "objectKey": "ui.feature.codeeditor.workflows",
        "label": {"en": "Workflows", "de": "Workflows", "fr": "Workflows"},
        "meta": {"area": "workflows"}
        "objectKey": "ui.feature.workspace.settings",
        "label": {"en": "Settings", "de": "Einstellungen", "fr": "Parametres"},
        "meta": {"area": "settings"}
    },
]

RESOURCE_OBJECTS = [
    {
        "objectKey": "resource.feature.codeeditor.start",
        "label": {"en": "Start Workflow", "de": "Workflow starten", "fr": "Demarrer workflow"},
        "meta": {"endpoint": "/api/codeeditor/{instanceId}/start/stream", "method": "POST"}
        "objectKey": "resource.feature.workspace.start",
        "label": {"en": "Start Agent", "de": "Agent starten", "fr": "Demarrer agent"},
        "meta": {"endpoint": "/api/workspace/{instanceId}/start/stream", "method": "POST"}
    },
    {
        "objectKey": "resource.feature.codeeditor.stop",
        "label": {"en": "Stop Workflow", "de": "Workflow stoppen", "fr": "Arreter workflow"},
        "meta": {"endpoint": "/api/codeeditor/{instanceId}/{workflowId}/stop", "method": "POST"}
        "objectKey": "resource.feature.workspace.stop",
        "label": {"en": "Stop Agent", "de": "Agent stoppen", "fr": "Arreter agent"},
        "meta": {"endpoint": "/api/workspace/{instanceId}/{workflowId}/stop", "method": "POST"}
    },
    {
        "objectKey": "resource.feature.codeeditor.chatData",
        "label": {"en": "Get Chat Data", "de": "Chat-Daten abrufen", "fr": "Recuperer donnees chat"},
        "meta": {"endpoint": "/api/codeeditor/{instanceId}/{workflowId}/chatData", "method": "GET"}
    },
    {
        "objectKey": "resource.feature.codeeditor.files",
        "objectKey": "resource.feature.workspace.files",
        "label": {"en": "Manage Files", "de": "Dateien verwalten", "fr": "Gerer fichiers"},
        "meta": {"endpoint": "/api/codeeditor/{instanceId}/files", "method": "GET"}
        "meta": {"endpoint": "/api/workspace/{instanceId}/files", "method": "GET"}
    },
    {
        "objectKey": "resource.feature.codeeditor.apply",
        "label": {"en": "Apply Edit", "de": "Aenderung anwenden", "fr": "Appliquer modification"},
        "meta": {"endpoint": "/api/codeeditor/{instanceId}/{workflowId}/apply", "method": "POST"}
        "objectKey": "resource.feature.workspace.folders",
        "label": {"en": "Manage Folders", "de": "Ordner verwalten", "fr": "Gerer dossiers"},
        "meta": {"endpoint": "/api/workspace/{instanceId}/folders", "method": "GET"}
    },
    {
        "objectKey": "resource.feature.workspace.datasources",
        "label": {"en": "Data Sources", "de": "Datenquellen", "fr": "Sources de donnees"},
        "meta": {"endpoint": "/api/workspace/{instanceId}/datasources", "method": "GET"}
    },
    {
        "objectKey": "resource.feature.workspace.voice",
        "label": {"en": "Voice Input/Output", "de": "Spracheingabe/-ausgabe", "fr": "Entree/sortie vocale"},
        "meta": {"endpoint": "/api/workspace/{instanceId}/voice/*", "method": "POST"}
    },
    {
        "objectKey": "resource.feature.workspace.edits",
        "label": {"en": "Review File Edits", "de": "Datei-Aenderungen pruefen", "fr": "Verifier les modifications de fichiers"},
        "meta": {"endpoint": "/api/workspace/{instanceId}/edit/*", "method": "POST"}
    },
]

TEMPLATE_ROLES = [
    {
        "roleLabel": "codeeditor-viewer",
        "roleLabel": "workspace-viewer",
        "description": {
            "en": "Code Editor Viewer - View editor (read-only)",
            "de": "Code Editor Betrachter - Editor ansehen (nur lesen)",
            "fr": "Visualiseur Code Editor - Consulter l'editeur (lecture seule)"
            "en": "Workspace Viewer - View workspace (read-only)",
            "de": "Workspace Betrachter - Workspace ansehen (nur lesen)",
            "fr": "Visualiseur Workspace - Consulter le workspace (lecture seule)"
        },
        "accessRules": [
            {"context": "UI", "item": "ui.feature.codeeditor.editor", "view": True},
            {"context": "UI", "item": "ui.feature.workspace.dashboard", "view": True},
            {"context": "UI", "item": "ui.feature.workspace.editor", "view": True},
            {"context": "UI", "item": "ui.feature.workspace.settings", "view": True},
            {"context": "DATA", "item": None, "view": True, "read": "m", "create": "n", "update": "n", "delete": "n"},
        ]
    },
    {
        "roleLabel": "codeeditor-user",
        "roleLabel": "workspace-user",
        "description": {
            "en": "Code Editor User - Use editor and workflows",
            "de": "Code Editor Benutzer - Editor und Workflows nutzen",
            "fr": "Utilisateur Code Editor - Utiliser l'editeur et les workflows"
            "en": "Workspace User - Use AI workspace and tools",
            "de": "Workspace Benutzer - AI Workspace und Tools nutzen",
            "fr": "Utilisateur Workspace - Utiliser l'espace de travail AI et les outils"
        },
        "accessRules": [
            {"context": "UI", "item": "ui.feature.codeeditor.editor", "view": True},
            {"context": "UI", "item": "ui.feature.codeeditor.workflows", "view": True},
            {"context": "RESOURCE", "item": "resource.feature.codeeditor.start", "view": True},
            {"context": "RESOURCE", "item": "resource.feature.codeeditor.stop", "view": True},
            {"context": "RESOURCE", "item": "resource.feature.codeeditor.chatData", "view": True},
            {"context": "RESOURCE", "item": "resource.feature.codeeditor.files", "view": True},
            {"context": "RESOURCE", "item": "resource.feature.codeeditor.apply", "view": True},
            {"context": "UI", "item": "ui.feature.workspace.dashboard", "view": True},
            {"context": "UI", "item": "ui.feature.workspace.editor", "view": True},
            {"context": "UI", "item": "ui.feature.workspace.settings", "view": True},
            {"context": "RESOURCE", "item": "resource.feature.workspace.start", "view": True},
            {"context": "RESOURCE", "item": "resource.feature.workspace.stop", "view": True},
            {"context": "RESOURCE", "item": "resource.feature.workspace.files", "view": True},
            {"context": "RESOURCE", "item": "resource.feature.workspace.folders", "view": True},
            {"context": "RESOURCE", "item": "resource.feature.workspace.datasources", "view": True},
            {"context": "RESOURCE", "item": "resource.feature.workspace.voice", "view": True},
            {"context": "RESOURCE", "item": "resource.feature.workspace.edits", "view": True},
            {"context": "DATA", "item": None, "view": True, "read": "m", "create": "m", "update": "m", "delete": "m"},
        ]
    },
    {
        "roleLabel": "codeeditor-admin",
        "roleLabel": "workspace-admin",
        "description": {
            "en": "Code Editor Admin - Full access to code editor",
            "de": "Code Editor Admin - Vollzugriff auf Code Editor",
            "fr": "Administrateur Code Editor - Acces complet au code editor"
            "en": "Workspace Admin - Full access to AI workspace",
            "de": "Workspace Admin - Vollzugriff auf AI Workspace",
            "fr": "Administrateur Workspace - Acces complet au workspace AI"
        },
        "accessRules": [
            {"context": "UI", "item": None, "view": True},
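For orientation, the rule grammar these templates share: "UI" rules gate navigation items, "RESOURCE" rules gate the endpoints listed in RESOURCE_OBJECTS, and the "DATA" rule's read/create/update/delete letters are scope codes (presumably "m" for mandate-wide access and "n" for none, judging by the viewer/user split above). A hypothetical extra role in the same shape, purely for illustration:

TEMPLATE_ROLES.append({
    "roleLabel": "workspace-auditor",   # illustrative; not part of this commit
    "description": {
        "en": "Workspace Auditor - Read-only audit access",
        "de": "Workspace Auditor - Nur-Lese-Audit",
        "fr": "Auditeur Workspace - Acces audit en lecture seule",
    },
    "accessRules": [
        {"context": "UI", "item": "ui.feature.workspace.dashboard", "view": True},
        {"context": "DATA", "item": None, "view": True, "read": "m", "create": "n", "update": "n", "delete": "n"},
    ],
})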
modules/features/workspace/routeFeatureWorkspace.py (new file, 1575 lines): diff suppressed because it is too large
@@ -4,7 +4,7 @@ import logging
import asyncio
import uuid
import base64
from typing import Dict, Any, List, Union, Tuple, Optional, Callable
from typing import Dict, Any, List, Union, Tuple, Optional, Callable, AsyncGenerator
from dataclasses import dataclass, field
import time

@@ -12,6 +12,7 @@ logger = logging.getLogger(__name__)

from modules.aicore.aicoreModelRegistry import modelRegistry
from modules.aicore.aicoreModelSelector import modelSelector
from modules.aicore.aicoreBase import RateLimitExceededException
from modules.datamodels.datamodelAi import (
    AiModel,
    AiCallOptions,

@@ -84,27 +85,31 @@ class AiObjects:

    # AI for Extraction, Processing, Generation
    async def callWithTextContext(self, request: AiCallRequest) -> AiCallResponse:
        """Call AI model for traditional text/context calls with fallback mechanism."""
        """Call AI model for traditional text/context calls with fallback mechanism.

        Supports two modes:
        - Legacy: prompt + context → constructs messages internally
        - Agent: request.messages provided → passes through directly
        """
        prompt = request.prompt
        context = request.context or ""
        options = request.options

        # Input bytes will be calculated inside _callWithModel

        # Generation parameters are handled inside _callWithModel

        # Get failover models for this operation type
        availableModels = modelRegistry.getAvailableModels()

        # Filter by allowedProviders if specified (from workflow config)
        allowedProviders = getattr(options, 'allowedProviders', None) if options else None
        if allowedProviders:
            filteredModels = [m for m in availableModels if m.connectorType in allowedProviders]
            if filteredModels:
                logger.info(f"Filtered models by allowedProviders {allowedProviders}: {len(filteredModels)} models (from {len(availableModels)})")
                availableModels = filteredModels
            else:
                logger.warning(f"No models match allowedProviders {allowedProviders}, using all {len(availableModels)} available models")
                errorMsg = f"No models match allowedProviders {allowedProviders} for operation {options.operationType}"
                logger.error(errorMsg)
                return AiCallResponse(
                    content=errorMsg, modelName="error", priceCHF=0.0,
                    processingTime=0.0, bytesSent=0, bytesReceived=0, errorCount=1,
                )

        failoverModelList = modelSelector.getFailoverModelList(prompt, context, options, availableModels)

@@ -121,18 +126,46 @@ class AiObjects:
                errorCount=1
            )

        # Try each model in failover sequence
        _MAX_SHORT_RETRY = 15.0

        lastError = None
        for attempt, model in enumerate(failoverModelList):
            try:
                logger.info(f"Attempting AI call with model: {model.name} (attempt {attempt + 1}/{len(failoverModelList)})")

                # Call the model directly - no truncation or compression here
                if request.messages:
                    response = await self._callWithMessages(model, request.messages, options, request.tools)
                else:
                    response = await self._callWithModel(model, prompt, context, options)

                logger.info(f"✅ AI call successful with model: {model.name}")
                logger.info(f"AI call successful with model: {model.name}")
                return response

            except RateLimitExceededException as rle:
                retryAfter = rle.retryAfterSeconds
                lastError = rle
                if 0 < retryAfter <= _MAX_SHORT_RETRY:
                    logger.info(f"Rate limit on {model.name}, waiting {retryAfter:.1f}s before retry")
                    await asyncio.sleep(retryAfter + 0.5)
                    try:
                        if request.messages:
                            response = await self._callWithMessages(model, request.messages, options, request.tools)
                        else:
                            response = await self._callWithModel(model, prompt, context, options)
                        logger.info(f"AI call successful with {model.name} after rate-limit retry")
                        return response
                    except Exception as retryErr:
                        lastError = retryErr
                        logger.warning(f"Retry after rate-limit wait also failed for {model.name}: {retryErr}")
                else:
                    logger.warning(f"Rate limit on {model.name} (retryAfter={retryAfter:.1f}s), failing over")
                cooldown = max(retryAfter, 10.0) if retryAfter > 0 else 0.0
                modelSelector.reportFailure(model.name, cooldownSeconds=cooldown)
                if attempt < len(failoverModelList) - 1:
                    continue
                logger.error(f"All {len(failoverModelList)} models failed for operation {options.operationType}")
                break

            except Exception as e:
                lastError = e
                logger.warning(f"AI call failed with model {model.name}: {str(e)}")

@@ -142,8 +175,7 @@ class AiObjects:
                logger.info(f"Trying next failover model...")
                continue
            else:
                # All models failed
                logger.error(f"💥 All {len(failoverModelList)} models failed for operation {options.operationType}")
                logger.error(f"All {len(failoverModelList)} models failed for operation {options.operationType}")
                break

        # All failover attempts failed - return error response
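The policy above separates short rate limits (sleep them out and retry the same model once) from long ones (report a cooldown and fail over). The same decision logic in isolation, with a stand-in exception class since the real RateLimitExceededException lives in modules.aicore.aicoreBase:

import asyncio

class RateLimited(Exception):                     # stand-in for RateLimitExceededException
    def __init__(self, retryAfterSeconds: float):
        self.retryAfterSeconds = retryAfterSeconds

_MAX_SHORT_RETRY = 15.0                           # same threshold as the code above

async def callWithFailover(models, call, reportFailure):
    """Try models in order; a short rate limit earns one in-place retry."""
    lastError = None
    for model in models:
        for attempt in range(2):                  # first try + at most one retry
            try:
                return await call(model)
            except RateLimited as rle:
                lastError = rle
                if attempt == 0 and 0 < rle.retryAfterSeconds <= _MAX_SHORT_RETRY:
                    await asyncio.sleep(rle.retryAfterSeconds + 0.5)
                    continue                      # retry the same model once
                reportFailure(model, max(rle.retryAfterSeconds, 10.0))
                break                             # cool it down, fail over
            except Exception as e:
                lastError = e
                reportFailure(model, 0.0)
                break
    raise RuntimeError(f"All models failed: {lastError}")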
@@ -254,6 +286,321 @@ class AiObjects:

        return response

    async def _callWithMessages(self, model: AiModel, messages: List[Dict[str, Any]],
                                options: AiCallOptions = None,
                                tools: List[Dict[str, Any]] = None) -> AiCallResponse:
        """Call a model with pre-built messages (agent mode). Supports tools for native function calling."""
        import json as _json

        inputBytes = sum(len(str(m.get("content", "")).encode("utf-8")) for m in messages)
        startTime = time.time()

        if not model.functionCall:
            raise ValueError(f"Model {model.name} has no function call defined")

        modelCall = AiModelCall(
            messages=messages,
            model=model,
            options=options or {},
            tools=tools
        )

        modelResponse = await model.functionCall(modelCall)

        if not modelResponse.success:
            raise ValueError(f"Model call failed: {modelResponse.error}")

        endTime = time.time()
        processingTime = endTime - startTime
        content = modelResponse.content
        outputBytes = len(content.encode("utf-8"))
        priceCHF = model.calculatepriceCHF(processingTime, inputBytes, outputBytes)

        # Extract tool calls from metadata if present (native function calling)
        responseToolCalls = None
        if modelResponse.metadata:
            responseToolCalls = modelResponse.metadata.get("toolCalls")

        response = AiCallResponse(
            content=content,
            modelName=model.name,
            provider=model.connectorType,
            priceCHF=priceCHF,
            processingTime=processingTime,
            bytesSent=inputBytes,
            bytesReceived=outputBytes,
            errorCount=0,
            toolCalls=responseToolCalls
        )

        if self.billingCallback:
            try:
                self.billingCallback(response)
            except Exception as e:
                logger.error(f"BILLING: Failed to record billing for model {model.name}: {e}")

        return response

    async def callWithTextContextStream(
        self, request: AiCallRequest
    ) -> AsyncGenerator[Union[str, AiCallResponse], None]:
        """Streaming variant of callWithTextContext. Yields str deltas, then final AiCallResponse."""
        options = request.options
        availableModels = modelRegistry.getAvailableModels()

        allowedProviders = getattr(options, 'allowedProviders', None) if options else None
        if allowedProviders:
            filtered = [m for m in availableModels if m.connectorType in allowedProviders]
            if filtered:
                availableModels = filtered
            else:
                yield AiCallResponse(
                    content=f"No models match allowedProviders {allowedProviders} for operation {options.operationType}",
                    modelName="error", priceCHF=0.0, processingTime=0.0,
                    bytesSent=0, bytesReceived=0, errorCount=1,
                )
                return

        failoverModelList = modelSelector.getFailoverModelList(
            request.prompt, request.context or "", options, availableModels
        )
        if not failoverModelList:
            yield AiCallResponse(
                content=f"No suitable models found for operation {options.operationType}",
                modelName="error", priceCHF=0.0, processingTime=0.0,
                bytesSent=0, bytesReceived=0, errorCount=1,
            )
            return

        _MAX_SHORT_RETRY = 15.0

        lastError = None
        for attempt, model in enumerate(failoverModelList):
            try:
                logger.info(f"Streaming AI call with model: {model.name} (attempt {attempt + 1})")
                async for chunk in self._callWithMessagesStream(model, request.messages, options, request.tools):
                    yield chunk
                return

            except RateLimitExceededException as rle:
                retryAfter = rle.retryAfterSeconds
                lastError = rle
                if 0 < retryAfter <= _MAX_SHORT_RETRY:
                    logger.info(f"Rate limit on {model.name}, waiting {retryAfter:.1f}s before retry")
                    await asyncio.sleep(retryAfter + 0.5)
                    try:
                        async for chunk in self._callWithMessagesStream(model, request.messages, options, request.tools):
                            yield chunk
                        return
                    except Exception as retryErr:
                        lastError = retryErr
                        logger.warning(f"Retry after rate-limit wait also failed for {model.name}: {retryErr}")
                else:
                    logger.warning(f"Rate limit on {model.name} (retryAfter={retryAfter:.1f}s), failing over")
                cooldown = max(retryAfter, 10.0) if retryAfter > 0 else 0.0
                modelSelector.reportFailure(model.name, cooldownSeconds=cooldown)
                if attempt < len(failoverModelList) - 1:
                    continue
                break

            except Exception as e:
                lastError = e
                logger.warning(f"Streaming AI call failed with {model.name}: {e}")
                modelSelector.reportFailure(model.name)
                if attempt < len(failoverModelList) - 1:
                    continue
                break

        yield AiCallResponse(
            content=f"All models failed (stream). Last error: {lastError}",
            modelName="error", priceCHF=0.0, processingTime=0.0,
            bytesSent=0, bytesReceived=0, errorCount=1,
        )

    async def _callWithMessagesStream(
        self, model: AiModel, messages: List[Dict[str, Any]],
        options: AiCallOptions = None, tools: List[Dict[str, Any]] = None,
    ) -> AsyncGenerator[Union[str, AiCallResponse], None]:
        """Stream a model call. Yields str deltas, then final AiCallResponse with billing."""
        from modules.datamodels.datamodelAi import AiModelCall, AiModelResponse

        inputBytes = sum(len(str(m.get("content", "")).encode("utf-8")) for m in messages)
        startTime = time.time()

        if not model.functionCallStream:
            response = await self._callWithMessages(model, messages, options, tools)
            if response.content:
                yield response.content
            yield response
            return

        modelCall = AiModelCall(
            messages=messages, model=model,
            options=options or {}, tools=tools,
        )

        finalModelResponse = None
        async for item in model.functionCallStream(modelCall):
            if isinstance(item, AiModelResponse):
                finalModelResponse = item
            else:
                yield item

        if not finalModelResponse:
            raise ValueError(f"Stream from {model.name} produced no final AiModelResponse")

        endTime = time.time()
        processingTime = endTime - startTime
        content = finalModelResponse.content
        outputBytes = len(content.encode("utf-8"))
        priceCHF = model.calculatepriceCHF(processingTime, inputBytes, outputBytes)

        responseToolCalls = None
        if finalModelResponse.metadata:
            responseToolCalls = finalModelResponse.metadata.get("toolCalls")

        response = AiCallResponse(
            content=content,
            modelName=model.name,
            provider=model.connectorType,
            priceCHF=priceCHF,
            processingTime=processingTime,
            bytesSent=inputBytes,
            bytesReceived=outputBytes,
            errorCount=0,
            toolCalls=responseToolCalls,
        )

        if self.billingCallback:
            try:
                self.billingCallback(response)
            except Exception as e:
                logger.error(f"BILLING: Failed to record stream billing for {model.name}: {e}")

        yield response
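Callers consume the stream by type: plain str items are deltas for the UI, and the single trailing AiCallResponse carries billing and tool calls. A hedged consumer sketch (the ai instance and request wiring are assumed):

async def consumeStream(ai, request):
    """ai is an AiObjects instance, request an AiCallRequest."""
    finalResponse = None
    async for item in ai.callWithTextContextStream(request):
        if isinstance(item, str):
            print(item, end="", flush=True)        # incremental delta
        else:
            finalResponse = item                   # trailing AiCallResponse
    print(f"\n{finalResponse.modelName}: {finalResponse.priceCHF:.4f} CHF")
    return finalResponse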
    async def callEmbedding(self, texts: List[str], options: AiCallOptions = None) -> AiCallResponse:
        """Generate embeddings for a list of texts using the best available embedding model.

        Token-aware batching: splits the texts list into batches that respect the
        model's contextLength (with 10% safety margin). Each batch is sent as a
        separate API call; the resulting embeddings are merged in order.

        Failover across providers (OpenAI -> Mistral) works identically to chat models,
        but ContextLengthExceededException is NOT retried via failover (same limits).

        Returns:
            AiCallResponse with metadata["embeddings"] containing the vectors.
        """
        from modules.aicore.aicoreBase import ContextLengthExceededException as _CtxExc

        if options is None:
            options = AiCallOptions(operationType=OperationTypeEnum.EMBEDDING)
        else:
            options.operationType = OperationTypeEnum.EMBEDDING

        combinedText = " ".join(texts[:3])[:500]
        availableModels = modelRegistry.getAvailableModels()

        allowedProviders = getattr(options, 'allowedProviders', None) if options else None
        if allowedProviders:
            filtered = [m for m in availableModels if m.connectorType in allowedProviders]
            if filtered:
                availableModels = filtered
            else:
                logger.warning(f"No embedding models match allowedProviders {allowedProviders}")

        failoverModelList = modelSelector.getFailoverModelList(
            combinedText, "", options, availableModels
        )

        if not failoverModelList:
            return AiCallResponse(
                content="", modelName="error", priceCHF=0.0,
                processingTime=0.0, bytesSent=0, bytesReceived=0, errorCount=1
            )

        lastError = None
        for attempt, model in enumerate(failoverModelList):
            try:
                logger.info(f"Embedding call with {model.name} (attempt {attempt + 1}/{len(failoverModelList)})")
                inputBytes = sum(len(t.encode("utf-8")) for t in texts)
                startTime = time.time()

                batches = _buildEmbeddingBatches(texts, model.contextLength)
                logger.info(
                    f"Embedding: {len(texts)} texts -> {len(batches)} batch(es), "
                    f"model contextLength={model.contextLength}"
                )

                allEmbeddings: List[List[float]] = []
                totalPriceCHF = 0.0

                for batchIdx, batch in enumerate(batches):
                    modelCall = AiModelCall(
                        model=model, options=options, embeddingInput=batch
                    )
                    modelResponse = await model.functionCall(modelCall)

                    if not modelResponse.success:
                        raise ValueError(f"Embedding batch {batchIdx + 1} failed: {modelResponse.error}")

                    batchEmbeddings = (modelResponse.metadata or {}).get("embeddings", [])
                    allEmbeddings.extend(batchEmbeddings)

                    batchBytes = sum(len(t.encode("utf-8")) for t in batch)
                    totalPriceCHF += model.calculatepriceCHF(0, batchBytes, 0)

                processingTime = time.time() - startTime
                if totalPriceCHF == 0.0:
                    totalPriceCHF = model.calculatepriceCHF(processingTime, inputBytes, 0)

                response = AiCallResponse(
                    content="", modelName=model.name, provider=model.connectorType,
                    priceCHF=totalPriceCHF, processingTime=processingTime,
                    bytesSent=inputBytes, bytesReceived=0, errorCount=0,
                    metadata={"embeddings": allEmbeddings}
                )

                if self.billingCallback:
                    try:
                        self.billingCallback(response)
                    except Exception as e:
                        logger.error(f"BILLING: Failed to record billing for embedding {model.name}: {e}")

                return response

            except _CtxExc as e:
                logger.error(f"ContextLengthExceeded for {model.name} despite batching – aborting failover: {e}")
                return AiCallResponse(
                    content=str(e), modelName=model.name, priceCHF=0.0,
                    processingTime=0.0, bytesSent=0, bytesReceived=0, errorCount=1
                )

            except RateLimitExceededException as rle:
                retryAfter = rle.retryAfterSeconds
                lastError = rle
                cooldown = max(retryAfter, 10.0) if retryAfter > 0 else 0.0
                logger.warning(f"Rate limit on {model.name} during embedding (retryAfter={retryAfter:.1f}s)")
                modelSelector.reportFailure(model.name, cooldownSeconds=cooldown)
                if attempt < len(failoverModelList) - 1:
                    continue
                break

            except Exception as e:
                lastError = e
                logger.warning(f"Embedding call failed with {model.name}: {str(e)}")
                modelSelector.reportFailure(model.name)
                if attempt < len(failoverModelList) - 1:
                    continue
                break

        errorMsg = f"All embedding models failed. Last error: {str(lastError)}"
        logger.error(errorMsg)
        return AiCallResponse(
            content=errorMsg, modelName="error", priceCHF=0.0,
            processingTime=0.0, bytesSent=0, bytesReceived=0, errorCount=1
        )

    # Utility methods
    async def listAvailableModels(self, connectorType: str = None) -> List[Dict[str, Any]]:

@@ -276,4 +623,50 @@ class AiObjects:
        return [model.displayName for model in models]


# =============================================================================
# Internal helpers
# =============================================================================

_CHARS_PER_TOKEN = 4
_SAFETY_MARGIN = 0.90


def _estimateTokens(text: str) -> int:
    """Rough token estimate: 1 token ~ 4 characters."""
    return max(1, len(text) // _CHARS_PER_TOKEN)


def _buildEmbeddingBatches(texts: List[str], contextLength: int) -> List[List[str]]:
    """Split a list of texts into batches whose total estimated token count
    stays within the model's contextLength (with safety margin).

    Each individual text is assumed to already be within limits (enforced by
    the chunking layer). If a single text exceeds the budget, it is placed
    in its own batch as a last resort.
    """
    if not texts:
        return []
    if contextLength <= 0:
        return [texts]

    maxTokensPerBatch = int(contextLength * _SAFETY_MARGIN)
    batches: List[List[str]] = []
    currentBatch: List[str] = []
    currentTokens = 0

    for text in texts:
        textTokens = _estimateTokens(text)
        if currentBatch and (currentTokens + textTokens) > maxTokensPerBatch:
            batches.append(currentBatch)
            currentBatch = []
            currentTokens = 0
        currentBatch.append(text)
        currentTokens += textTokens

    if currentBatch:
        batches.append(currentBatch)

    return batches
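A quick sanity check of the batching math above (illustrative sizes only): with the 4-chars-per-token estimate, three 4000-character texts are roughly 1000 tokens each, and a contextLength of 2000 gives a budget of 1800 tokens per batch, so no two of them fit together:

texts = ["a" * 4000, "b" * 4000, "c" * 4000]
batches = _buildEmbeddingBatches(texts, contextLength=2000)    # budget: 1800 tokens
assert [len(b) for b in batches] == [1, 1, 1]

batches = _buildEmbeddingBatches(texts, contextLength=10000)   # budget: 9000 tokens
assert [len(b) for b in batches] == [3]                        # all three fit in one batch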
@@ -96,6 +96,9 @@ def initBootstrap(db: DatabaseConnector) -> None:
    if mandateId:
        initRootMandateFeatures(db, mandateId)

    # Remove feature instances for features that no longer exist in the codebase
    _cleanupRemovedFeatureInstances(db)

    # Initialize billing settings for root mandate
    if mandateId:
        initRootMandateBilling(mandateId)
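Bootstrap now reconciles the FeatureInstance table against the code on every start: with the codeeditor feature deleted and workspace registered, any instance still carrying featureCode="codeeditor" is removed by the helper defined in the next hunk. The core set logic in isolation (the activeCodes values here are illustrative):

activeCodes = {"workspace", "automation", "chatplayground", "teamsbot"}   # discovered from feature main modules
instances = [
    {"id": "i1", "featureCode": "workspace"},
    {"id": "i2", "featureCode": "codeeditor"},    # orphaned after the rename
]
orphans = [i["id"] for i in instances if i["featureCode"] not in activeCodes]
assert orphans == ["i2"]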
@@ -257,6 +260,33 @@ def initRootMandateFeatures(db: DatabaseConnector, mandateId: str) -> None:
    logger.info("Root mandate features initialization completed")


def _cleanupRemovedFeatureInstances(db: DatabaseConnector) -> None:
    """Remove feature instances whose featureCode no longer exists in the codebase."""
    from modules.datamodels.datamodelFeatures import FeatureInstance
    from modules.system.registry import loadFeatureMainModules

    mainModules = loadFeatureMainModules()
    activeCodes = set()
    for featureName, module in mainModules.items():
        if hasattr(module, "getFeatureDefinition"):
            try:
                featureDef = module.getFeatureDefinition()
                activeCodes.add(featureDef.get("code", featureName))
            except Exception:
                pass

    allInstances = db.getRecordset(FeatureInstance)
    for inst in allInstances:
        code = inst.get("featureCode") if isinstance(inst, dict) else getattr(inst, "featureCode", None)
        instId = inst.get("id") if isinstance(inst, dict) else getattr(inst, "id", None)
        if code and code not in activeCodes:
            try:
                db.recordDelete(FeatureInstance, str(instId))
                logger.info(f"Removed orphaned feature instance '{instId}' (featureCode='{code}')")
            except Exception as e:
                logger.warning(f"Could not remove orphaned feature instance '{instId}': {e}")


def initRootMandate(db: DatabaseConnector) -> Optional[str]:
    """
    Creates the Root mandate if it doesn't exist.

@@ -443,7 +473,7 @@ def initRoles(db: DatabaseConnector) -> None:

    # Check specifically for system template roles:
    # mandateId=NULL, isSystemRole=True, featureCode=NULL
    # Feature templates (e.g. chatplayground admin) share the same labels but have featureCode set!
    # Feature templates (e.g. automation admin) share the same labels but have featureCode set!
    allTemplates = db.getRecordset(
        Role,
        recordFilter={"mandateId": None, "isSystemRole": True}

@@ -475,7 +505,7 @@ def _deduplicateRoles(db: DatabaseConnector) -> None:

    # Group by (roleLabel, mandateId, featureInstanceId, featureCode)
    # featureCode is essential: system template ('admin', None, None, None)
    # must NOT be grouped with feature template ('admin', None, None, 'chatplayground')
    # must NOT be grouped with feature template ('admin', None, None, 'automation')
    groups: dict = {}
    for role in allRoles:
        key = (role.get("roleLabel"), role.get("mandateId"), role.get("featureInstanceId"), role.get("featureCode"))

@@ -1931,8 +1961,6 @@ def _createStoreResourceRules(db: DatabaseConnector) -> None:
    """
    storeResources = [
        "resource.store.automation",
        "resource.store.chatplayground",
        "resource.store.codeeditor",
        "resource.store.teamsbot",
    ]
@@ -680,6 +680,29 @@ class BillingObjects:
        record = StripeWebhookEvent(event_id=event_id)
        return self.db.recordCreate(StripeWebhookEvent, record.model_dump())

    def getPaymentTransactionByReferenceId(self, referenceId: str) -> Optional[Dict[str, Any]]:
        """
        Find an existing Stripe payment credit transaction by Checkout Session ID.

        Args:
            referenceId: Stripe Checkout Session ID (cs_xxx)

        Returns:
            Transaction record if found, else None
        """
        try:
            results = self.db.getRecordset(
                BillingTransaction,
                recordFilter={
                    "referenceType": ReferenceTypeEnum.PAYMENT.value,
                    "referenceId": referenceId,
                }
            )
            return results[0] if results else None
        except Exception as e:
            logger.error(f"Error checking Stripe payment transaction by referenceId: {e}")
            return None

    # =========================================================================
    # Balance Check Operations
    # =========================================================================
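This lookup exists for webhook idempotency: before crediting a checkout session, the handler can check whether the cs_xxx reference was already booked, so Stripe's webhook retries do not double-credit. A hypothetical handler fragment (names and wiring are placeholders):

def onCheckoutCompleted(billing, session):
    """Hedged sketch of an idempotent webhook handler."""
    existing = billing.getPaymentTransactionByReferenceId(session["id"])   # "cs_..."
    if existing:
        return existing    # retry delivery: credit already recorded, do nothing
    ...                    # otherwise create the credit transaction as usual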
@@ -764,7 +787,11 @@ class BillingObjects:
        featureCode: str = None,
        aicoreProvider: str = None,
        aicoreModel: str = None,
        description: str = "AI Usage"
        description: str = "AI Usage",
        processingTime: float = None,
        bytesSent: int = None,
        bytesReceived: int = None,
        errorCount: int = None
    ) -> Optional[Dict[str, Any]]:
        """
        Record usage cost as a billing transaction.

@@ -774,20 +801,6 @@ class BillingObjects:
        - PREPAY_USER: deduct from user's own balance
        - PREPAY_MANDATE: deduct from mandate pool balance
        - CREDIT_POSTPAY: deduct from mandate pool balance

        Args:
            mandateId: Mandate ID
            userId: User ID
            priceCHF: Cost in CHF
            workflowId: Optional workflow ID
            featureInstanceId: Optional feature instance ID
            featureCode: Optional feature code
            aicoreProvider: AICore provider name (e.g., 'anthropic', 'openai')
            aicoreModel: AICore model name (e.g., 'claude-4-sonnet', 'gpt-4o')
            description: Transaction description

        Returns:
            Created transaction dict or None
        """
        if priceCHF <= 0:
            return None

@@ -816,7 +829,11 @@ class BillingObjects:
            featureCode=featureCode,
            aicoreProvider=aicoreProvider,
            aicoreModel=aicoreModel,
            createdByUserId=userId
            createdByUserId=userId,
            processingTime=processingTime,
            bytesSent=bytesSent,
            bytesReceived=bytesReceived,
            errorCount=errorCount
        )

        # Determine where to deduct balance

@@ -828,6 +845,20 @@ class BillingObjects:
            poolAccount = self.getOrCreateMandateAccount(mandateId)
            return self.createTransaction(transaction, balanceAccountId=poolAccount["id"])

    # =========================================================================
    # Workflow Cost Query
    # =========================================================================

    def getWorkflowCost(self, workflowId: str) -> float:
        """Sum of all transaction amounts for a workflow."""
        if not workflowId:
            return 0.0
        transactions = self.db.getRecordset(
            BillingTransaction,
            recordFilter={"workflowId": workflowId}
        )
        return sum(t.get("amount", 0.0) for t in transactions)

    # =========================================================================
    # Billing Model Switch Operations
    # =========================================================================
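getWorkflowCost makes the billing ledger the single source of truth for per-workflow spend; the chat code in the next file stops summing its own ChatStat rows and consumes this value instead. The arithmetic in isolation (amounts are illustrative):

transactions = [{"amount": 0.012}, {"amount": 0.003}]    # what the recordset would return
cost = sum(t.get("amount", 0.0) for t in transactions)
assert round(cost, 3) == 0.015                           # CHF for this workflow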
@@ -18,7 +18,6 @@ from modules.datamodels.datamodelUam import AccessLevel

from modules.datamodels.datamodelChat import (
    ChatDocument,
    ChatStat,
    ChatLog,
    ChatMessage,
    ChatWorkflow,

@@ -663,10 +662,8 @@ class ChatObjects:

        workflow = workflows[0]
        try:
            # Load related data from normalized tables
            logs = self.getLogs(workflowId)
            messages = self.getMessages(workflowId)
            stats = self.getStats(workflowId)

            # Validate workflow data against ChatWorkflow model
            # Explicit type coercion: DB may store numeric fields as TEXT on some platforms

@@ -694,8 +691,7 @@ class ChatObjects:
                lastActivity=_toFloat(workflow.get("lastActivity")),
                startedAt=_toFloat(workflow.get("startedAt")),
                logs=logs,
                messages=messages,
                stats=stats
                messages=messages
            )
        except Exception as e:
            logger.error(f"Error validating workflow data: {str(e)}")

@@ -731,7 +727,7 @@ class ChatObjects:
        except Exception as e:
            logger.warning(f"Could not get Root mandate: {e}")
        # Note: ChatWorkflow has featureInstanceId for multi-tenancy isolation.
        # Child tables (ChatMessage, ChatLog, ChatStat, ChatDocument) are user-owned
        # Child tables (ChatMessage, ChatLog, ChatDocument) are user-owned
        # and do NOT store featureInstanceId - they inherit isolation from ChatWorkflow.
        # Ensure featureInstanceId is set from context if not already in workflowData
        if "featureInstanceId" not in workflowData or not workflowData.get("featureInstanceId"):

@@ -760,7 +756,7 @@ class ChatObjects:
            logs=[],
            messages=[],
            stats=[],
            workflowMode=created["workflowMode"],
            workflowMode=created.get("workflowMode", "Dynamic"),
            maxSteps=created.get("maxSteps", 1)
        )

@@ -789,23 +785,20 @@ class ChatObjects:
        # Load fresh data from normalized tables
        logs = self.getLogs(workflowId)
        messages = self.getMessages(workflowId)
        stats = self.getStats(workflowId)

        # Convert to ChatWorkflow model
        return ChatWorkflow(
            id=updated["id"],
            status=updated.get("status", workflow.status),
            name=updated.get("name", workflow.name),
            currentRound=updated.get("currentRound", workflow.currentRound),
            currentTask=updated.get("currentTask", workflow.currentTask),
            currentAction=updated.get("currentAction", workflow.currentAction),
            totalTasks=updated.get("totalTasks", workflow.totalTasks),
            totalActions=updated.get("totalActions", workflow.totalActions),
            currentRound=updated.get("currentRound") or getattr(workflow, "currentRound", 0) or 0,
            currentTask=updated.get("currentTask") or getattr(workflow, "currentTask", 0) or 0,
            currentAction=updated.get("currentAction") or getattr(workflow, "currentAction", 0) or 0,
            totalTasks=updated.get("totalTasks") or getattr(workflow, "totalTasks", 0) or 0,
            totalActions=updated.get("totalActions") or getattr(workflow, "totalActions", 0) or 0,
            lastActivity=updated.get("lastActivity", workflow.lastActivity),
            startedAt=updated.get("startedAt", workflow.startedAt),
            logs=logs,
            messages=messages,
            stats=stats
            messages=messages
        )

    def deleteWorkflow(self, workflowId: str) -> bool:

@@ -827,7 +820,6 @@ class ChatObjects:
            messageId = message.id
            if messageId:
                # Delete message documents (but NOT the files!)
                # Note: ChatStat does NOT have messageId - stats are only at workflow level
                try:
                    existing_docs = self._getRecordset(ChatDocument, recordFilter={"messageId": messageId})
                    for doc in existing_docs:

@@ -839,11 +831,7 @@ class ChatObjects:
                self.db.recordDelete(ChatMessage, messageId)

        # 2. Delete workflow stats
        existing_stats = self._getRecordset(ChatStat, recordFilter={"workflowId": workflowId})
        for stat in existing_stats:
            self.db.recordDelete(ChatStat, stat["id"])

        # 3. Delete workflow logs
        # 2. Delete workflow logs
        existing_logs = self._getRecordset(ChatLog, recordFilter={"workflowId": workflowId})
        for log in existing_logs:
            self.db.recordDelete(ChatLog, log["id"])

@@ -1270,7 +1258,6 @@ class ChatObjects:
            self.db.recordDelete(ChatDocument, doc["id"])

        # 2. Finally delete the message itself
        # Note: ChatStat has no messageId field -- stats are workflow-level, not message-level
        success = self.db.recordDelete(ChatMessage, messageId)

        return success

@@ -1517,74 +1504,10 @@ class ChatObjects:
        # Return validated ChatLog instance
        return ChatLog(**createdLog)

    # Stats methods

    def getStats(self, workflowId: str) -> List[ChatStat]:
        """Returns list of statistics for a workflow if user has access."""
        # Check workflow access first (without calling getWorkflow to avoid circular reference)
        # Use RBAC filtering
        workflows = self._getRecordset(ChatWorkflow, recordFilter={"id": workflowId})

        if not workflows:
            return []

        # Get stats for this workflow from normalized table
        stats = self._getRecordset(ChatStat, recordFilter={"workflowId": workflowId})

        if not stats:
            return []

        # Return all stats records sorted by creation time.
        # Use parseTimestamp to tolerate mixed DB types (float/string) on INT.
        # DB uses _createdAt (camelCase system field).
        stats.sort(key=lambda x: parseTimestamp(x.get("_createdAt"), default=0))

        # Convert to ChatStat objects, preserving _createdAt via extra="allow"
        result = []
        for stat in stats:
            chat_stat = ChatStat(**stat)
            # Explicitly preserve _createdAt from raw DB record
            if "_createdAt" in stat:
                setattr(chat_stat, '_createdAt', stat["_createdAt"])
            result.append(chat_stat)

        return result

    def createStat(self, statData: Dict[str, Any]) -> ChatStat:
        """Creates a new stats record and returns it."""
        try:
            # Ensure workflowId is present in statData
            if "workflowId" not in statData:
                raise ValueError("workflowId is required in statData")

            # Note: Chat data is user-owned, no mandate/featureInstance context stored
            # mandateId/featureInstanceId removed from ChatStat model

            # Validate the stat data against ChatStat model
            stat = ChatStat(**statData)

            logger.debug(f"Creating stat for workflow {statData.get('workflowId')}: "
                         f"process={statData.get('process')}, "
                         f"priceCHF={statData.get('priceCHF', 0):.4f}, "
                         f"processingTime={statData.get('processingTime', 0):.2f}s")

            # Create the stat record in the database
            created = self.db.recordCreate(ChatStat, stat)

            logger.info(f"Created stat {created.get('id')} for workflow {statData.get('workflowId')}")

            # Return the created ChatStat
            return ChatStat(**created)
        except Exception as e:
            logger.error(f"Error creating workflow stat: {str(e)}")
            raise

    def getUnifiedChatData(self, workflowId: str, afterTimestamp: Optional[float] = None) -> Dict[str, Any]:
    def getUnifiedChatData(self, workflowId: str, afterTimestamp: Optional[float] = None, workflowCost: float = 0.0) -> Dict[str, Any]:
        """
        Returns unified chat data (messages, logs, stats) for a workflow in chronological order.
        Uses timestamp-based selective data transfer for efficient polling.
        Returns unified chat data (messages, logs) for a workflow in chronological order,
        plus workflowCost from billing transactions (single source of truth).
        """
        # Check workflow access first
        # Use RBAC filtering

@@ -1652,29 +1575,10 @@ class ChatObjects:
                "item": chatLog
            })

        # Get stats - ChatStat model supports _createdAt via model_config extra="allow"
        stats = self.getStats(workflowId)
        for stat in stats:
            # Apply timestamp filtering in Python
            # Use _createdAt (system field from DB, preserved via model_config extra="allow")
            stat_timestamp = getattr(stat, '_createdAt', None) or getUtcTimestamp()
            if afterTimestamp is not None and stat_timestamp <= afterTimestamp:
                continue

            # Convert to dict and include _createdAt for frontend
            stat_dict = stat.model_dump() if hasattr(stat, 'model_dump') else stat.dict()
            stat_dict['_createdAt'] = stat_timestamp

            items.append({
                "type": "stat",
                "createdAt": stat_timestamp,
                "item": stat_dict
            })

        # Sort all items by createdAt timestamp for chronological order
        items.sort(key=lambda x: parseTimestamp(x.get("createdAt"), default=0))

        return {"items": items}
        return {"items": items, "workflowCost": workflowCost}


def getInterface(currentUser: Optional[User] = None, mandateId: Optional[str] = None, featureInstanceId: Optional[str] = None) -> 'ChatObjects':
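The polling contract after this change: the client passes the newest timestamp it has seen, gets back only newer messages and logs, and reads the running spend from workflowCost instead of stat items. A minimal polling loop (the chat/billing instances and the render hook are assumed wiring, not part of the commit):

import time

lastSeen = None
while True:
    data = chat.getUnifiedChatData("wf-123", afterTimestamp=lastSeen,
                                   workflowCost=billing.getWorkflowCost("wf-123"))
    for entry in data["items"]:
        lastSeen = max(lastSeen or 0, entry["createdAt"])
        render(entry)                                # hypothetical UI hook
    print(f"cost so far: {data['workflowCost']:.4f} CHF")
    time.sleep(2)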
modules/interfaces/interfaceDbKnowledge.py (new file, 247 lines)
@@ -0,0 +1,247 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Interface to the Knowledge Store database (poweron_knowledge).
Provides CRUD for FileContentIndex, ContentChunk, WorkflowMemory
and semantic search via pgvector.
"""

import logging
from typing import Dict, Any, List, Optional

from modules.connectors.connectorDbPostgre import _get_cached_connector
from modules.datamodels.datamodelKnowledge import FileContentIndex, ContentChunk, WorkflowMemory
from modules.datamodels.datamodelUam import User
from modules.shared.configuration import APP_CONFIG
from modules.shared.timeUtils import getUtcTimestamp

logger = logging.getLogger(__name__)

_instances: Dict[str, "KnowledgeObjects"] = {}


class KnowledgeObjects:
    """Interface to the Knowledge Store database.
    Manages FileContentIndex, ContentChunk, and WorkflowMemory with semantic search."""

    def __init__(self):
        self.currentUser: Optional[User] = None
        self.userId: Optional[str] = None
        self._initializeDatabase()

    def _initializeDatabase(self):
        dbHost = APP_CONFIG.get("DB_HOST", "_no_config_default_data")
        dbDatabase = "poweron_knowledge"
        dbUser = APP_CONFIG.get("DB_USER")
        dbPassword = APP_CONFIG.get("DB_PASSWORD_SECRET")
        dbPort = int(APP_CONFIG.get("DB_PORT", 5432))

        self.db = _get_cached_connector(
            dbHost=dbHost,
            dbDatabase=dbDatabase,
            dbUser=dbUser,
            dbPassword=dbPassword,
            dbPort=dbPort,
            userId=self.userId,
        )
        logger.info("Knowledge Store database initialized")

    def setUserContext(self, user: User):
        self.currentUser = user
        self.userId = user.id if user else None
        if self.userId:
            self.db.updateContext(self.userId)

    # =========================================================================
    # FileContentIndex CRUD
    # =========================================================================

    def upsertFileContentIndex(self, index: FileContentIndex) -> Dict[str, Any]:
        """Create or update a FileContentIndex entry."""
        data = index.model_dump()
        existing = self.db._loadRecord(FileContentIndex, index.id)
        if existing:
            return self.db.recordModify(FileContentIndex, index.id, data)
        return self.db.recordCreate(FileContentIndex, data)

    def getFileContentIndex(self, fileId: str) -> Optional[Dict[str, Any]]:
        """Get a FileContentIndex by file ID."""
        return self.db._loadRecord(FileContentIndex, fileId)

    def getFileContentIndexByUser(
        self, userId: str, featureInstanceId: str = None
    ) -> List[Dict[str, Any]]:
        """Get all FileContentIndex entries for a user."""
        recordFilter = {"userId": userId}
        if featureInstanceId:
            recordFilter["featureInstanceId"] = featureInstanceId
        return self.db.getRecordset(FileContentIndex, recordFilter=recordFilter)

    def updateFileStatus(self, fileId: str, status: str) -> bool:
        """Update the processing status of a FileContentIndex."""
        existing = self.db._loadRecord(FileContentIndex, fileId)
        if not existing:
            return False
        self.db.recordModify(FileContentIndex, fileId, {"status": status})
        return True

    def deleteFileContentIndex(self, fileId: str) -> bool:
        """Delete a FileContentIndex and all associated ContentChunks."""
        chunks = self.db.getRecordset(ContentChunk, recordFilter={"fileId": fileId})
        for chunk in chunks:
            self.db.recordDelete(ContentChunk, chunk["id"])
        return self.db.recordDelete(FileContentIndex, fileId)

    # =========================================================================
    # ContentChunk CRUD
    # =========================================================================

    def upsertContentChunk(self, chunk: ContentChunk) -> Dict[str, Any]:
        """Create or update a ContentChunk."""
        data = chunk.model_dump()
        existing = self.db._loadRecord(ContentChunk, chunk.id)
        if existing:
            return self.db.recordModify(ContentChunk, chunk.id, data)
        return self.db.recordCreate(ContentChunk, data)

    def upsertContentChunks(self, chunks: List[ContentChunk]) -> int:
        """Batch upsert multiple ContentChunks. Returns count of upserted chunks."""
        count = 0
        for chunk in chunks:
            self.upsertContentChunk(chunk)
            count += 1
        return count

    def getContentChunks(self, fileId: str) -> List[Dict[str, Any]]:
        """Get all ContentChunks for a file."""
        return self.db.getRecordset(ContentChunk, recordFilter={"fileId": fileId})

    def deleteContentChunks(self, fileId: str) -> int:
        """Delete all ContentChunks for a file. Returns count of deleted chunks."""
        chunks = self.db.getRecordset(ContentChunk, recordFilter={"fileId": fileId})
        count = 0
        for chunk in chunks:
            if self.db.recordDelete(ContentChunk, chunk["id"]):
                count += 1
        return count

    # =========================================================================
    # WorkflowMemory CRUD
    # =========================================================================

    def upsertWorkflowMemory(self, memory: WorkflowMemory) -> Dict[str, Any]:
        """Create or update a WorkflowMemory entry."""
        data = memory.model_dump()
        existing = self.db._loadRecord(WorkflowMemory, memory.id)
        if existing:
            return self.db.recordModify(WorkflowMemory, memory.id, data)
        return self.db.recordCreate(WorkflowMemory, data)

    def getWorkflowEntities(self, workflowId: str) -> List[Dict[str, Any]]:
        """Get all WorkflowMemory entries for a workflow."""
        return self.db.getRecordset(WorkflowMemory, recordFilter={"workflowId": workflowId})

    def getWorkflowEntity(self, workflowId: str, key: str) -> Optional[Dict[str, Any]]:
        """Get a specific WorkflowMemory entry by workflow and key."""
        results = self.db.getRecordset(
            WorkflowMemory, recordFilter={"workflowId": workflowId, "key": key}
        )
        return results[0] if results else None

    def deleteWorkflowMemory(self, workflowId: str) -> int:
        """Delete all WorkflowMemory entries for a workflow. Returns count."""
        entries = self.db.getRecordset(WorkflowMemory, recordFilter={"workflowId": workflowId})
        count = 0
        for entry in entries:
            if self.db.recordDelete(WorkflowMemory, entry["id"]):
                count += 1
        return count

    # =========================================================================
    # Semantic Search
    # =========================================================================

    def semanticSearch(
        self,
        queryVector: List[float],
        userId: str = None,
        featureInstanceId: str = None,
        mandateId: str = None,
        isShared: bool = None,
        limit: int = 10,
        minScore: float = None,
        contentType: str = None,
    ) -> List[Dict[str, Any]]:
        """Semantic search across ContentChunks using pgvector cosine similarity.

        Args:
            queryVector: Query embedding vector.
            userId: Filter by user (Instance Layer).
            featureInstanceId: Filter by feature instance.
            mandateId: Filter by mandate (for Shared Layer lookups).
            isShared: If True, search Shared Layer via FileContentIndex join.
            limit: Max results.
            minScore: Minimum cosine similarity (0.0 - 1.0).
            contentType: Filter by content type (text, image, etc.).

        Returns:
            List of ContentChunk records with _score field, sorted by relevance.
        """
        recordFilter = {}
        if userId:
            recordFilter["userId"] = userId
        if featureInstanceId:
            recordFilter["featureInstanceId"] = featureInstanceId
        if contentType:
            recordFilter["contentType"] = contentType

        if isShared and mandateId:
            sharedIndexes = self.db.getRecordset(
                FileContentIndex,
                recordFilter={"mandateId": mandateId, "isShared": True},
            )
            sharedFileIds = [idx.get("id") if isinstance(idx, dict) else getattr(idx, "id", None) for idx in sharedIndexes]
            sharedFileIds = [fid for fid in sharedFileIds if fid]
            if not sharedFileIds:
                return []
            recordFilter.pop("userId", None)
            recordFilter.pop("featureInstanceId", None)
            recordFilter["fileId"] = sharedFileIds

        return self.db.semanticSearch(
            modelClass=ContentChunk,
            vectorColumn="embedding",
            queryVector=queryVector,
            limit=limit,
            recordFilter=recordFilter if recordFilter else None,
            minScore=minScore,
        )

    def semanticSearchWorkflowMemory(
        self,
        queryVector: List[float],
        workflowId: str,
        limit: int = 5,
        minScore: float = None,
    ) -> List[Dict[str, Any]]:
        """Semantic search across WorkflowMemory entries."""
        return self.db.semanticSearch(
            modelClass=WorkflowMemory,
            vectorColumn="embedding",
            queryVector=queryVector,
            limit=limit,
            recordFilter={"workflowId": workflowId},
            minScore=minScore,
        )


def getInterface(currentUser: Optional[User] = None) -> KnowledgeObjects:
    """Get or create a KnowledgeObjects singleton."""
    if "default" not in _instances:
        _instances["default"] = KnowledgeObjects()

    interface = _instances["default"]
    if currentUser:
        interface.setUserContext(currentUser)

    return interface
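End to end, the intended flow pairs callEmbedding with semanticSearch: embed the query text, then rank chunks by cosine similarity. A hedged sketch, to be run inside an async context (the ai/user instances and record field names beyond _score and fileId are assumptions):

knowledge = getInterface(currentUser=user)
resp = await ai.callEmbedding(["quarterly revenue by mandate"])   # AiObjects from above
queryVector = resp.metadata["embeddings"][0]
hits = knowledge.semanticSearch(queryVector, userId=user.id, limit=5, minScore=0.7)
for hit in hits:
    print(round(hit["_score"], 3), hit["fileId"], hit.get("contentType"))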
@@ -10,6 +10,7 @@ import logging
import base64
import hashlib
import math
import mimetypes
from typing import Dict, Any, List, Optional, Union

from modules.connectors.connectorDbPostgre import DatabaseConnector, _get_cached_connector

@@ -18,6 +19,7 @@ from modules.security.rbac import RbacClass
from modules.datamodels.datamodelRbac import AccessRuleContext
from modules.datamodels.datamodelUam import AccessLevel
from modules.datamodels.datamodelFiles import FilePreview, FileItem, FileData
from modules.datamodels.datamodelFileFolder import FileFolder
from modules.datamodels.datamodelUtils import Prompt
from modules.datamodels.datamodelVoice import VoiceSettings
from modules.datamodels.datamodelMessaging import (

@@ -851,7 +853,9 @@ class ComponentObjects:
            "svg": "image/svg+xml",
            "py": "text/x-python",
            "js": "application/javascript",
            "css": "text/css"
            "css": "text/css",
            "eml": "message/rfc822",
            "msg": "application/vnd.ms-outlook",
        }
        return extensionToMime.get(ext.lower(), "application/octet-stream")
@@ -1143,6 +1147,350 @@ class ComponentObjects:
            logger.error(f"Error deleting file {fileId}: {str(e)}")
            raise FileDeletionError(f"Error deleting file: {str(e)}")

    def deleteFilesBatch(self, fileIds: List[str]) -> Dict[str, Any]:
        """Delete multiple files in a single SQL batch call."""
        uniqueIds = [str(fid) for fid in dict.fromkeys(fileIds or []) if fid]
        if not uniqueIds:
            return {"deletedFiles": 0}

        try:
            self.db._ensure_connection()
            with self.db.connection.cursor() as cursor:
                cursor.execute(
                    'SELECT "id" FROM "FileItem" WHERE "id" = ANY(%s) AND "_createdBy" = %s',
                    (uniqueIds, self.userId or ""),
                )
                accessibleIds = [row["id"] for row in cursor.fetchall()]

                if len(accessibleIds) != len(uniqueIds):
                    missingIds = sorted(set(uniqueIds) - set(accessibleIds))
                    raise FileNotFoundError(f"Files not found or not accessible: {missingIds}")

                cursor.execute('DELETE FROM "FileData" WHERE "id" = ANY(%s)', (accessibleIds,))
                cursor.execute(
                    'DELETE FROM "FileItem" WHERE "id" = ANY(%s) AND "_createdBy" = %s',
                    (accessibleIds, self.userId or ""),
                )
                deletedFiles = cursor.rowcount

            self.db.connection.commit()
            return {"deletedFiles": deletedFiles}
        except Exception as e:
            logger.error(f"Error deleting files in batch: {e}")
            self.db.connection.rollback()
            raise FileDeletionError(f"Error deleting files in batch: {str(e)}")
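The batch methods in this hunk lean on psycopg2's adaptation of Python lists to SQL arrays, so a single "= ANY(%s)" round-trip replaces N per-row statements. A standalone sketch of that pattern (table and connection details here are illustrative, not taken from the schema above):

    # Illustrative only: how a Python list binds to "= ANY(%s)" in psycopg2.
    import psycopg2

    conn = psycopg2.connect("dbname=demo")  # assumed local demo database
    ids = ["a1", "b2", "c3"]
    with conn.cursor() as cur:
        # One statement for the whole batch; the list is sent as a SQL array.
        cur.execute('DELETE FROM "Demo" WHERE "id" = ANY(%s)', (ids,))
        print(cur.rowcount)                 # rows actually deleted
    conn.commit()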
    # ---- Folder methods ----

    _RESERVED_FOLDER_NAMES = {"(Global)"}

    def _validateFolderName(self, name: str, parentId: Optional[str], excludeFolderId: Optional[str] = None):
        """Ensures folder name is not reserved and is unique within parent."""
        if name in self._RESERVED_FOLDER_NAMES:
            raise ValueError(f"Folder name '{name}' is reserved")
        if not name or not name.strip():
            raise ValueError("Folder name cannot be empty")
        existingFolders = self.db.getRecordset(FileFolder, recordFilter={"parentId": parentId or ""})
        for f in existingFolders:
            if f.get("name") == name and f.get("id") != excludeFolderId:
                raise ValueError(f"Folder '{name}' already exists in this directory")

    def _isDescendantOf(self, folderId: str, ancestorId: str) -> bool:
        """Checks if folderId is a descendant of ancestorId (circular reference check)."""
        visited = set()
        currentId = folderId
        while currentId:
            if currentId == ancestorId:
                return True
            if currentId in visited:
                break
            visited.add(currentId)
            folders = self.db.getRecordset(FileFolder, recordFilter={"id": currentId})
            if not folders:
                break
            currentId = folders[0].get("parentId")
        return False
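The ancestor walk above climbs parent links upward from folderId; if it ever reaches ancestorId, moving ancestorId into folderId would create a cycle, and the visited set guards against loops in already-corrupt data. A tiny in-memory illustration of the same walk (no database involved):

    # In-memory illustration of the ancestor walk above.
    parents = {"c": "b", "b": "a", "a": None}   # c nested under b under a

    def is_descendant_of(folder_id, ancestor_id):
        visited = set()
        current = folder_id
        while current:
            if current == ancestor_id:
                return True
            if current in visited:               # defensive: stop on corrupt cycles
                return False
            visited.add(current)
            current = parents.get(current)
        return False

    assert is_descendant_of("c", "a") is True    # so "a" must not be moved into "c"
    assert is_descendant_of("a", "c") is False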
    def getFolder(self, folderId: str) -> Optional[Dict[str, Any]]:
        """Returns a folder by ID if it belongs to the current user."""
        folders = self.db.getRecordset(FileFolder, recordFilter={"id": folderId, "_createdBy": self.userId or ""})
        return folders[0] if folders else None

    def listFolders(self, parentId: Optional[str] = None) -> List[Dict[str, Any]]:
        """List folders for current user, optionally filtered by parentId."""
        recordFilter = {"_createdBy": self.userId or ""}
        if parentId is not None:
            recordFilter["parentId"] = parentId
        return self.db.getRecordset(FileFolder, recordFilter=recordFilter)

    def createFolder(self, name: str, parentId: Optional[str] = None) -> Dict[str, Any]:
        """Create a new folder with unique name validation."""
        self._validateFolderName(name, parentId)
        folder = FileFolder(
            name=name,
            parentId=parentId,
            mandateId=self.mandateId or "",
            featureInstanceId=self.featureInstanceId or "",
        )
        return self.db.recordCreate(FileFolder, folder)

    def renameFolder(self, folderId: str, newName: str) -> bool:
        """Rename a folder with unique name validation."""
        folder = self.getFolder(folderId)
        if not folder:
            raise FileNotFoundError(f"Folder {folderId} not found")
        self._validateFolderName(newName, folder.get("parentId"), excludeFolderId=folderId)
        return self.db.recordModify(FileFolder, folderId, {"name": newName})

    def moveFolder(self, folderId: str, targetParentId: Optional[str] = None) -> bool:
        """Move a folder to a new parent, with circular reference and unique name checks."""
        folder = self.getFolder(folderId)
        if not folder:
            raise FileNotFoundError(f"Folder {folderId} not found")
        if targetParentId and self._isDescendantOf(targetParentId, folderId):
            raise ValueError("Cannot move folder into its own subtree")
        self._validateFolderName(folder.get("name", ""), targetParentId, excludeFolderId=folderId)
        return self.db.recordModify(FileFolder, folderId, {"parentId": targetParentId})

    def moveFilesBatch(self, fileIds: List[str], targetFolderId: Optional[str] = None) -> Dict[str, Any]:
        """Move multiple files with one SQL update."""
        uniqueIds = [str(fid) for fid in dict.fromkeys(fileIds or []) if fid]
        if not uniqueIds:
            return {"movedFiles": 0}

        if targetFolderId:
            targetFolder = self.getFolder(targetFolderId)
            if not targetFolder:
                raise FileNotFoundError(f"Target folder {targetFolderId} not found")

        try:
            self.db._ensure_connection()
            with self.db.connection.cursor() as cursor:
                cursor.execute(
                    'SELECT "id" FROM "FileItem" WHERE "id" = ANY(%s) AND "_createdBy" = %s',
                    (uniqueIds, self.userId or ""),
                )
                accessibleIds = [row["id"] for row in cursor.fetchall()]
                if len(accessibleIds) != len(uniqueIds):
                    missingIds = sorted(set(uniqueIds) - set(accessibleIds))
                    raise FileNotFoundError(f"Files not found or not accessible: {missingIds}")

                cursor.execute(
                    'UPDATE "FileItem" SET "folderId" = %s, "_modifiedAt" = %s, "_modifiedBy" = %s '
                    'WHERE "id" = ANY(%s) AND "_createdBy" = %s',
                    (targetFolderId, getUtcTimestamp(), self.userId or "", accessibleIds, self.userId or ""),
                )
                movedFiles = cursor.rowcount

            self.db.connection.commit()
            return {"movedFiles": movedFiles}
        except Exception as e:
            logger.error(f"Error moving files in batch: {e}")
            self.db.connection.rollback()
            raise FileError(f"Error moving files in batch: {str(e)}")
    def moveFoldersBatch(self, folderIds: List[str], targetParentId: Optional[str] = None) -> Dict[str, Any]:
        """Move multiple folders with one SQL update after validation."""
        uniqueIds = [str(fid) for fid in dict.fromkeys(folderIds or []) if fid]
        if not uniqueIds:
            return {"movedFolders": 0}

        foldersToMove: List[Dict[str, Any]] = []
        for folderId in uniqueIds:
            folder = self.getFolder(folderId)
            if not folder:
                raise FileNotFoundError(f"Folder {folderId} not found")
            if targetParentId and self._isDescendantOf(targetParentId, folderId):
                raise ValueError("Cannot move folder into its own subtree")
            foldersToMove.append(folder)

        existingInTarget = self.db.getRecordset(
            FileFolder,
            recordFilter={"parentId": targetParentId or "", "_createdBy": self.userId or ""},
        )
        existingNames = {f.get("name"): f.get("id") for f in existingInTarget}
        movingNames: Dict[str, str] = {}
        movingIds = set(uniqueIds)

        for folder in foldersToMove:
            name = folder.get("name", "")
            folderId = folder.get("id")
            if name in movingNames and movingNames[name] != folderId:
                raise ValueError(f"Folder '{name}' already exists in this move batch")
            movingNames[name] = folderId

            existingId = existingNames.get(name)
            if existingId and existingId not in movingIds:
                raise ValueError(f"Folder '{name}' already exists in target directory")

        try:
            self.db._ensure_connection()
            with self.db.connection.cursor() as cursor:
                cursor.execute(
                    'UPDATE "FileFolder" SET "parentId" = %s, "_modifiedAt" = %s, "_modifiedBy" = %s '
                    'WHERE "id" = ANY(%s) AND "_createdBy" = %s',
                    (targetParentId, getUtcTimestamp(), self.userId or "", uniqueIds, self.userId or ""),
                )
                movedFolders = cursor.rowcount

            self.db.connection.commit()
            return {"movedFolders": movedFolders}
        except Exception as e:
            logger.error(f"Error moving folders in batch: {e}")
            self.db.connection.rollback()
            raise FileError(f"Error moving folders in batch: {str(e)}")

    def deleteFolder(self, folderId: str, recursive: bool = False) -> Dict[str, Any]:
        """Delete a folder. If recursive, deletes all contents. Returns summary of deletions."""
        folder = self.getFolder(folderId)
        if not folder:
            raise FileNotFoundError(f"Folder {folderId} not found")

        childFolders = self.db.getRecordset(FileFolder, recordFilter={"parentId": folderId, "_createdBy": self.userId or ""})
        childFiles = self._getFilesByCurrentUser(recordFilter={"folderId": folderId})

        if not recursive and (childFolders or childFiles):
            raise ValueError(
                f"Folder '{folder.get('name')}' is not empty "
                f"({len(childFiles)} files, {len(childFolders)} subfolders). "
                f"Use recursive=true to delete contents."
            )

        deletedFiles = 0
        deletedFolders = 0

        if recursive:
            for subFolder in childFolders:
                subResult = self.deleteFolder(subFolder["id"], recursive=True)
                deletedFiles += subResult.get("deletedFiles", 0)
                deletedFolders += subResult.get("deletedFolders", 0)
            for childFile in childFiles:
                try:
                    self.deleteFile(childFile["id"])
                    deletedFiles += 1
                except Exception as e:
                    logger.warning(f"Failed to delete file {childFile['id']} during folder deletion: {e}")

        self.db.recordDelete(FileFolder, folderId)
        deletedFolders += 1

        return {"deletedFiles": deletedFiles, "deletedFolders": deletedFolders}

    def deleteFoldersBatch(self, folderIds: List[str], recursive: bool = True) -> Dict[str, Any]:
        """Delete multiple folders and their content in batched SQL calls."""
        uniqueIds = [str(fid) for fid in dict.fromkeys(folderIds or []) if fid]
        if not uniqueIds:
            return {"deletedFiles": 0, "deletedFolders": 0}

        if not recursive:
            deletedFiles = 0
            deletedFolders = 0
            for folderId in uniqueIds:
                result = self.deleteFolder(folderId, recursive=False)
                deletedFiles += result.get("deletedFiles", 0)
                deletedFolders += result.get("deletedFolders", 0)
            return {"deletedFiles": deletedFiles, "deletedFolders": deletedFolders}

        try:
            self.db._ensure_connection()
            with self.db.connection.cursor() as cursor:
                cursor.execute(
                    'SELECT "id" FROM "FileFolder" WHERE "id" = ANY(%s) AND "_createdBy" = %s',
                    (uniqueIds, self.userId or ""),
                )
                rootAccessibleIds = [row["id"] for row in cursor.fetchall()]
                if len(rootAccessibleIds) != len(uniqueIds):
                    missingIds = sorted(set(uniqueIds) - set(rootAccessibleIds))
                    raise FileNotFoundError(f"Folders not found or not accessible: {missingIds}")

                cursor.execute(
                    """
                    WITH RECURSIVE folder_tree AS (
                        SELECT "id"
                        FROM "FileFolder"
                        WHERE "id" = ANY(%s) AND "_createdBy" = %s
                        UNION ALL
                        SELECT child."id"
                        FROM "FileFolder" child
                        INNER JOIN folder_tree ft ON child."parentId" = ft."id"
                        WHERE child."_createdBy" = %s
                    )
                    SELECT DISTINCT "id" FROM folder_tree
                    """,
                    (rootAccessibleIds, self.userId or "", self.userId or ""),
                )
                allFolderIds = [row["id"] for row in cursor.fetchall()]

                cursor.execute(
                    'SELECT "id" FROM "FileItem" WHERE "folderId" = ANY(%s) AND "_createdBy" = %s',
                    (allFolderIds, self.userId or ""),
                )
                allFileIds = [row["id"] for row in cursor.fetchall()]

                if allFileIds:
                    cursor.execute('DELETE FROM "FileData" WHERE "id" = ANY(%s)', (allFileIds,))
                    cursor.execute(
                        'DELETE FROM "FileItem" WHERE "id" = ANY(%s) AND "_createdBy" = %s',
                        (allFileIds, self.userId or ""),
                    )
                    deletedFiles = cursor.rowcount
                else:
                    deletedFiles = 0

                cursor.execute(
                    'DELETE FROM "FileFolder" WHERE "id" = ANY(%s) AND "_createdBy" = %s',
                    (allFolderIds, self.userId or ""),
                )
                deletedFolders = cursor.rowcount

            self.db.connection.commit()
            return {"deletedFiles": deletedFiles, "deletedFolders": deletedFolders}
        except Exception as e:
            logger.error(f"Error deleting folders in batch: {e}")
            self.db.connection.rollback()
            raise FileDeletionError(f"Error deleting folders in batch: {str(e)}")
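The recursive CTE in deleteFoldersBatch resolves the whole subtree server-side in one statement instead of N Python round-trips. A pure-Python BFS equivalent, useful for reasoning about what the SQL computes (toy data, not the app's schema):

    # Pure-Python equivalent of the folder_tree CTE (BFS over parent links).
    from collections import deque

    children = {"root": ["a", "b"], "a": ["a1"], "b": [], "a1": []}

    def folder_subtree(root_ids):
        seen, queue = set(root_ids), deque(root_ids)
        while queue:
            for child in children.get(queue.popleft(), []):
                if child not in seen:       # same fixpoint the CTE reaches
                    seen.add(child)
                    queue.append(child)
        return seen

    assert folder_subtree(["root"]) == {"root", "a", "b", "a1"}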
    def copyFile(self, sourceFileId: str, targetFolderId: Optional[str] = None, newFileName: Optional[str] = None) -> FileItem:
        """Create a full duplicate of a file (FileItem + FileData)."""
        sourceFile = self.getFile(sourceFileId)
        if not sourceFile:
            raise FileNotFoundError(f"File {sourceFileId} not found")

        sourceData = self.getFileData(sourceFileId)
        if sourceData is None:
            raise FileStorageError(f"No data found for file {sourceFileId}")

        fileName = newFileName or sourceFile.fileName
        copiedFile = self.createFile(fileName, sourceFile.mimeType, sourceData)

        if targetFolderId:
            self.updateFile(copiedFile.id, {"folderId": targetFolderId})
        elif sourceFile.folderId:
            self.updateFile(copiedFile.id, {"folderId": sourceFile.folderId})

        self.createFileData(copiedFile.id, sourceData)
        return copiedFile

    def updateFileData(self, fileId: str, data: bytes) -> bool:
        """Replace existing file data (delete + create). Updates FileItem metadata."""
        file = self.getFile(fileId)
        if not file:
            raise FileNotFoundError(f"File {fileId} not found")

        try:
            self.db.recordDelete(FileData, fileId)
            logger.debug(f"Deleted existing FileData for {fileId}")
        except Exception as e:
            logger.debug(f"No existing FileData to delete for {fileId}: {e}")

        success = self.createFileData(fileId, data)
        if success:
            newSize = len(data)
            newHash = hashlib.sha256(data).hexdigest()
            self.db.recordModify(FileItem, fileId, {"fileSize": newSize, "fileHash": newHash})
            logger.info(f"Updated file data for {fileId} ({newSize} bytes)")
        return success

    # FileData methods - data operations

    def createFileData(self, fileId: str, data: bytes) -> bool:
@@ -1395,11 +1743,15 @@ class ComponentObjects:
            logger.error("No user ID provided for voice settings")
            return None

-       # Get voice settings for the user, filtered by RBAC
+       recordFilter: Dict[str, Any] = {"userId": targetUserId}
+       if self.featureInstanceId:
+           recordFilter["featureInstanceId"] = self.featureInstanceId
+
+       # Get voice settings for the user (scoped to current feature instance if available), filtered by RBAC
        filteredSettings = getRecordsetWithRBAC(self.db,
            VoiceSettings,
            self.currentUser,
-           recordFilter={"userId": targetUserId},
+           recordFilter=recordFilter,
            mandateId=self.mandateId
        )
@@ -58,7 +58,6 @@ TABLE_NAMESPACE = {
    "ChatWorkflow": "chat",
    "ChatMessage": "chat",
    "ChatLog": "chat",
-   "ChatStat": "chat",
    "ChatDocument": "chat",
    "Prompt": "chat",
    # Chatbot (poweron_chatbot) - per feature-instance isolation
@@ -69,13 +68,20 @@ TABLE_NAMESPACE = {
    # Files - user-owned
    "FileItem": "files",
    "FileData": "files",
+   "FileFolder": "files",
    # Automation - user-owned
    "AutomationDefinition": "automation",
    "AutomationTemplate": "automation",
+   # Knowledge Store - user-owned
+   "FileContentIndex": "knowledge",
+   "ContentChunk": "knowledge",
+   "WorkflowMemory": "knowledge",
+   # Data Sources - user-owned
+   "DataSource": "datasource",
}

# Namespaces without mandate context - GROUP is mapped to MY
-USER_OWNED_NAMESPACES = {"chat", "chatbot", "files", "automation"}
+USER_OWNED_NAMESPACES = {"chat", "chatbot", "files", "automation", "knowledge", "datasource"}


def buildDataObjectKey(tableName: str, featureCode: Optional[str] = None) -> str:
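buildDataObjectKey's body is not part of this hunk, so as a reading aid only: a hypothetical sketch of what the two tables above imply for access-level resolution (user-owned namespaces carry no mandate context, so a requested GROUP level degrades to MY). This is not the actual implementation:

    # Hypothetical reading aid, not the real code behind buildDataObjectKey.
    def resolveAccessLevel(tableName: str, requested: str) -> str:
        namespace = TABLE_NAMESPACE.get(tableName, "")
        if requested == "GROUP" and namespace in USER_OWNED_NAMESPACES:
            return "MY"          # no mandate context in user-owned namespaces
        return requested

    assert resolveAccessLevel("WorkflowMemory", "GROUP") == "MY"   # knowledge is user-owned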
@@ -175,7 +181,7 @@ def getRecordsetWithRBAC(
    whereValues = []

    # CRITICAL: Only pass featureInstanceId to WHERE clause if the model actually has
-   # this column. Chat child tables (ChatMessage, ChatLog, ChatStat, ChatDocument)
+   # this column. Chat child tables (ChatMessage, ChatLog, ChatDocument)
    # are user-owned and do NOT have featureInstanceId - only ChatWorkflow does.
    # Without this check, the SQL query would reference a non-existent column,
    # causing a silent error that returns empty results.
@@ -6,8 +6,9 @@ Provides a generic interface layer between routes and voice connectors.
Handles voice operations including speech-to-text, text-to-speech, and translation.
"""

+import asyncio
import logging
-from typing import Dict, Any, Optional, List
+from typing import AsyncGenerator, Callable, Dict, Any, Optional, List

from modules.connectors.connectorVoiceGoogle import ConnectorGoogleSpeech
from modules.datamodels.datamodelVoice import VoiceSettings

@@ -30,6 +31,7 @@ class VoiceObjects:
        self.currentUser: Optional[User] = None
        self.userId: Optional[str] = None
        self._google_speech_connector: Optional[ConnectorGoogleSpeech] = None
+       self.billingCallback: Optional[Callable[[Dict[str, Any]], None]] = None

    def setUserContext(self, currentUser: User, mandateId: Optional[str] = None):
        """Set the user context for the interface.
@@ -115,6 +117,32 @@ class VoiceObjects:
                "error": str(e)
            }

    async def streamingSpeechToText(
        self,
        audioQueue: asyncio.Queue,
        language: str = "de-DE",
        phraseHints: Optional[list] = None,
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """
        Stream audio to Google Streaming STT and yield interim/final results.
        Billing is recorded for each final result.
        """
        connector = self._getGoogleSpeechConnector()
        async for event in connector.streamingRecognize(audioQueue, language, phraseHints):
            if event.get("isFinal") and self.billingCallback:
                durationSec = event.get("audioDurationSec", 0)
                priceCHF = connector.calculateSttCostCHF(durationSec)
                if priceCHF > 0:
                    try:
                        self.billingCallback({
                            "operation": "stt-streaming",
                            "priceCHF": priceCHF,
                            "audioDurationSec": durationSec,
                        })
                    except Exception as e:
                        logger.warning(f"Voice STT billing callback failed: {e}")
            yield event
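A minimal consumer sketch for the streaming generator above. The source of audio chunks and the None end-of-stream sentinel are assumptions (the connector's actual close convention is not shown in this hunk), as is the "transcript" event field:

    # Hypothetical consumer: pump audio chunks into a queue, print transcripts.
    import asyncio

    async def transcribe(voice, audioChunks):
        queue: asyncio.Queue = asyncio.Queue()

        async def producer():
            for chunk in audioChunks:          # e.g. bytes from a WebSocket
                await queue.put(chunk)
            await queue.put(None)              # assumed close signal

        producerTask = asyncio.create_task(producer())
        async for event in voice.streamingSpeechToText(queue, language="de-DE"):
            if event.get("isFinal"):
                print("final:", event.get("transcript"))   # field name assumed
        await producerTask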
    # Translation Operations

    async def detectLanguage(self, text: str) -> Dict[str, Any]:
@@ -277,7 +305,18 @@ class VoiceObjects:

        if result["success"]:
            logger.info(f"✅ Text-to-Speech successful: {len(result['audio_content'])} bytes")
            # Map connector snake_case keys to camelCase for consistent API
            if self.billingCallback:
                connector = self._getGoogleSpeechConnector()
                priceCHF = connector.calculateTtsCostCHF(len(text))
                if priceCHF > 0:
                    try:
                        self.billingCallback({
                            "operation": "tts-wavenet",
                            "priceCHF": priceCHF,
                            "characterCount": len(text),
                        })
                    except Exception as e:
                        logger.warning(f"Voice TTS billing callback failed: {e}")
            return {
                "success": True,
                "audioContent": result["audio_content"],
@@ -247,19 +247,13 @@ def _getInstancePermissions(rootInterface, userId: str, instanceId: str) -> Dict
    # Get FeatureAccess for this user and instance (Pydantic model)
    featureAccess = rootInterface.getFeatureAccess(userId, instanceId)

-   logger.debug(f"_getInstancePermissions: userId={userId}, instanceId={instanceId}, featureAccess={featureAccess is not None}")
-
    if not featureAccess:
-       logger.debug(f"_getInstancePermissions: No FeatureAccess found for user {userId} and instance {instanceId}")
        return permissions

    # Get role IDs via interface method
    roleIds = rootInterface.getRoleIdsForFeatureAccess(str(featureAccess.id))

-   logger.debug(f"_getInstancePermissions: featureAccessId={featureAccess.id}, roleIds={roleIds}")
-
    if not roleIds:
-       logger.debug(f"_getInstancePermissions: No roles found for FeatureAccess {featureAccess.id}")
        return permissions

    # Check if user has admin role

@@ -274,8 +268,6 @@ def _getInstancePermissions(rootInterface, userId: str, instanceId: str) -> Dict
    # Get all rules for this role (returns Pydantic models)
    accessRules = rootInterface.getAccessRules(roleId=roleId)

-   logger.debug(f"_getInstancePermissions: roleId={roleId}, accessRules={len(accessRules) if accessRules else 0}")
-
    for rule in accessRules:
        context = rule.context
        item = rule.item or ""
@@ -21,7 +21,7 @@ from modules.auth import limiter, requireSysAdminRole, getRequestContext, Reques

# Import billing components
from modules.interfaces.interfaceDbBilling import getInterface as getBillingInterface, _getRootInterface
-from modules.services.serviceBilling.mainServiceBilling import getService as getBillingService
+from modules.serviceCenter.services.serviceBilling.mainServiceBilling import getService as getBillingService
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata, normalize_pagination_dict
from modules.routes.routeDataUsers import _applyFiltersAndSort
from modules.datamodels.datamodelBilling import (
@@ -162,6 +162,23 @@ def _isAdminOfMandate(ctx: RequestContext, targetMandateId: str) -> bool:
    return False


def _isMemberOfMandate(ctx: RequestContext, targetMandateId: str) -> bool:
    """Check if user has any enabled membership in the specified mandate."""
    try:
        from modules.interfaces.interfaceDbApp import getRootInterface
        rootInterface = getRootInterface()
        userMandates = rootInterface.getUserMandates(str(ctx.user.id))
        for um in userMandates:
            if str(getattr(um, 'mandateId', None)) != str(targetMandateId):
                continue
            if not getattr(um, 'enabled', True):
                continue
            return True
        return False
    except Exception:
        return False


def _filterTransactionsByScope(transactions: list, scope: BillingDataScope) -> list:
    """
    Filter a list of transaction dicts based on the user's BillingDataScope.
@@ -219,6 +236,7 @@ class CheckoutCreateRequest(BaseModel):
    """Request model for creating Stripe Checkout Session."""
    userId: Optional[str] = Field(None, description="Target user ID (for PREPAY_USER model)")
    amount: float = Field(..., gt=0, description="Amount to pay in CHF (must be in allowed presets)")
+   returnUrl: str = Field(..., min_length=1, description="Absolute frontend URL used for Stripe success/cancel redirects")


class CheckoutCreateResponse(BaseModel):

@@ -226,6 +244,20 @@ class CheckoutCreateResponse(BaseModel):
    redirectUrl: str = Field(..., description="Stripe Checkout URL for redirect")


class CheckoutConfirmRequest(BaseModel):
    """Request model for confirming Stripe Checkout after redirect."""
    sessionId: str = Field(..., min_length=1, description="Stripe Checkout Session ID (cs_xxx)")


class CheckoutConfirmResponse(BaseModel):
    """Response model for Stripe Checkout confirmation."""
    credited: bool = Field(..., description="True if a new billing credit was created")
    alreadyCredited: bool = Field(..., description="True if session was already credited before")
    sessionId: str = Field(..., description="Stripe Checkout Session ID")
    mandateId: str = Field(..., description="Mandate ID from Stripe metadata")
    amountChf: float = Field(..., description="Credited amount in CHF")


class BillingSettingsUpdate(BaseModel):
    """Request model for updating billing settings."""
    billingModel: Optional[BillingModelEnum] = None
@@ -328,6 +360,107 @@ class UserTransactionResponse(BaseModel):
    userName: Optional[str] = None


def _getStripeClient():
    """Initialize and return configured Stripe SDK module."""
    import stripe
    from modules.shared.configuration import APP_CONFIG

    api_version = APP_CONFIG.get("STRIPE_API_VERSION")
    if api_version:
        stripe.api_version = api_version

    secret_key = APP_CONFIG.get("STRIPE_SECRET_KEY_SECRET") or APP_CONFIG.get("STRIPE_SECRET_KEY")
    if not secret_key:
        raise ValueError("STRIPE_SECRET_KEY_SECRET not configured")

    stripe.api_key = secret_key
    return stripe
def _creditStripeSessionIfNeeded(
    billingInterface,
    session: Dict[str, Any],
    eventId: Optional[str] = None,
) -> CheckoutConfirmResponse:
    """
    Credit balance from Stripe Checkout session if not already credited.
    Uses Checkout session ID for idempotency across webhook + manual confirmation flows.
    """
    from modules.serviceCenter.services.serviceBilling.stripeCheckout import ALLOWED_AMOUNTS_CHF

    session_id = session.get("id")
    metadata = session.get("metadata") or {}
    mandate_id = metadata.get("mandateId")
    user_id = metadata.get("userId") or None
    amount_chf_str = metadata.get("amountChf", "0")

    if not session_id:
        raise HTTPException(status_code=400, detail="Stripe session id missing")
    if not mandate_id:
        raise HTTPException(status_code=400, detail="Invalid session metadata: mandateId missing")

    existing_payment_tx = billingInterface.getPaymentTransactionByReferenceId(session_id)
    if existing_payment_tx:
        if eventId and not billingInterface.getStripeWebhookEventByEventId(eventId):
            billingInterface.createStripeWebhookEvent(eventId)
        return CheckoutConfirmResponse(
            credited=False,
            alreadyCredited=True,
            sessionId=session_id,
            mandateId=mandate_id,
            amountChf=float(existing_payment_tx.get("amount", 0.0)),
        )

    try:
        amount_chf = float(amount_chf_str)
    except (TypeError, ValueError):
        amount_chf = None

    if amount_chf is None or amount_chf not in ALLOWED_AMOUNTS_CHF:
        amount_total = session.get("amount_total")
        if amount_total is not None:
            amount_chf = amount_total / 100.0
        else:
            raise HTTPException(status_code=400, detail="Invalid amount in Stripe session")

    settings = billingInterface.getSettings(mandate_id)
    if not settings:
        raise HTTPException(status_code=404, detail="Billing settings not found")

    billing_model = BillingModelEnum(settings.get("billingModel", BillingModelEnum.UNLIMITED.value))
    if billing_model == BillingModelEnum.PREPAY_USER:
        if not user_id:
            raise HTTPException(status_code=400, detail="userId required for PREPAY_USER")
        account = billingInterface.getOrCreateUserAccount(mandate_id, user_id, initialBalance=0.0)
    elif billing_model in [BillingModelEnum.PREPAY_MANDATE, BillingModelEnum.CREDIT_POSTPAY]:
        account = billingInterface.getOrCreateMandateAccount(mandate_id, initialBalance=0.0)
    else:
        raise HTTPException(status_code=400, detail=f"Cannot add credit to {billing_model.value}")

    transaction = BillingTransaction(
        accountId=account["id"],
        transactionType=TransactionTypeEnum.CREDIT,
        amount=amount_chf,
        description="Stripe-Zahlung",
        referenceType=ReferenceTypeEnum.PAYMENT,
        referenceId=session_id,
        createdByUserId=user_id,
    )
    billingInterface.createTransaction(transaction)

    if eventId and not billingInterface.getStripeWebhookEventByEventId(eventId):
        billingInterface.createStripeWebhookEvent(eventId)

    logger.info(f"Stripe credit applied: {amount_chf} CHF for session {session_id} on mandate {mandate_id}")
    return CheckoutConfirmResponse(
        credited=True,
        alreadyCredited=False,
        sessionId=session_id,
        mandateId=mandate_id,
        amountChf=amount_chf,
    )
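The helper's idempotency hinges on one rule: look up a prior transaction by the Checkout session ID before crediting, whichever path (webhook or manual confirm) arrives first. The pattern in isolation, as a generic runnable sketch rather than the app's code:

    # Generic sketch of reference-ID idempotency (not the app's code).
    ledger = {}  # referenceId -> amount; stands in for the transactions table

    def credit_once(session_id: str, amount: float) -> dict:
        if session_id in ledger:                 # prior credit found -> no-op
            return {"credited": False, "alreadyCredited": True}
        ledger[session_id] = amount              # unique per Checkout session
        return {"credited": True, "alreadyCredited": False}

    assert credit_once("cs_123", 50.0)["credited"] is True
    assert credit_once("cs_123", 50.0)["alreadyCredited"] is True   # replay is safe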
# =============================================================================
# Router Setup
# =============================================================================
@@ -720,11 +853,11 @@ def createCheckoutSession(
    targetMandateId: str = Path(..., description="Mandate ID"),
    checkoutRequest: CheckoutCreateRequest = Body(...),
    ctx: RequestContext = Depends(getRequestContext),
-   _admin = Depends(requireSysAdminRole)
):
    """
    Create Stripe Checkout Session for credit top-up. Returns redirect URL.
-   SysAdmin only. Amount is validated server-side against allowed presets.
+   RBAC: PREPAY_USER requires mandate membership (user loads own account),
+   PREPAY_MANDATE requires mandate admin role.
    """
    try:
        billingInterface = getBillingInterface(ctx.user, targetMandateId)
@@ -738,14 +871,22 @@ def createCheckoutSession(
        if billingModel == BillingModelEnum.PREPAY_USER:
            if not checkoutRequest.userId:
                raise HTTPException(status_code=400, detail="userId is required for PREPAY_USER model")
-       elif billingModel not in [BillingModelEnum.PREPAY_MANDATE, BillingModelEnum.CREDIT_POSTPAY]:
+           if str(checkoutRequest.userId) != str(ctx.user.id):
+               raise HTTPException(status_code=403, detail="Users can only load credit to their own account")
+           if not _isMemberOfMandate(ctx, targetMandateId):
+               raise HTTPException(status_code=403, detail="User is not a member of this mandate")
+       elif billingModel in [BillingModelEnum.PREPAY_MANDATE, BillingModelEnum.CREDIT_POSTPAY]:
+           if not _isAdminOfMandate(ctx, targetMandateId):
+               raise HTTPException(status_code=403, detail="Mandate admin role required to load mandate credit")
+       else:
            raise HTTPException(status_code=400, detail=f"Cannot add credit to {billingModel.value} billing model")

-       from modules.services.serviceBilling.stripeCheckout import create_checkout_session
+       from modules.serviceCenter.services.serviceBilling.stripeCheckout import create_checkout_session
        redirect_url = create_checkout_session(
            mandate_id=targetMandateId,
            user_id=checkoutRequest.userId,
-           amount_chf=checkoutRequest.amount
+           amount_chf=checkoutRequest.amount,
+           return_url=checkoutRequest.returnUrl
        )
        return CheckoutCreateResponse(redirectUrl=redirect_url)
@@ -758,6 +899,65 @@ def createCheckoutSession(
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/checkout/confirm", response_model=CheckoutConfirmResponse)
@limiter.limit("20/minute")
def confirmCheckoutSession(
    request: Request,
    confirmRequest: CheckoutConfirmRequest = Body(...),
    ctx: RequestContext = Depends(getRequestContext),
):
    """
    Confirm Stripe Checkout success by session ID and apply credit idempotently.
    This is a fallback/reconciliation path in addition to webhook processing.
    """
    try:
        stripe = _getStripeClient()
        session = stripe.checkout.Session.retrieve(confirmRequest.sessionId)
        if not session:
            raise HTTPException(status_code=404, detail="Stripe Checkout Session not found")

        session_dict = session.to_dict_recursive() if hasattr(session, "to_dict_recursive") else dict(session)
        metadata = session_dict.get("metadata") or {}
        mandate_id = metadata.get("mandateId")
        user_id = metadata.get("userId") or None

        if not mandate_id:
            raise HTTPException(status_code=400, detail="Invalid session metadata: mandateId missing")

        payment_status = session_dict.get("payment_status")
        if payment_status != "paid":
            raise HTTPException(status_code=409, detail=f"Payment not completed yet (payment_status={payment_status})")

        billingInterface = getBillingInterface(ctx.user, mandate_id)
        settings = billingInterface.getSettings(mandate_id)
        if not settings:
            raise HTTPException(status_code=404, detail="Billing settings not found")

        billing_model = BillingModelEnum(settings.get("billingModel", BillingModelEnum.UNLIMITED.value))
        if billing_model == BillingModelEnum.PREPAY_USER:
            if not user_id:
                raise HTTPException(status_code=400, detail="userId required for PREPAY_USER")
            if str(user_id) != str(ctx.user.id):
                raise HTTPException(status_code=403, detail="Users can only confirm their own payment sessions")
            if not _isMemberOfMandate(ctx, mandate_id):
                raise HTTPException(status_code=403, detail="User is not a member of this mandate")
        elif billing_model in [BillingModelEnum.PREPAY_MANDATE, BillingModelEnum.CREDIT_POSTPAY]:
            if not _isAdminOfMandate(ctx, mandate_id):
                raise HTTPException(status_code=403, detail="Mandate admin role required")
        else:
            raise HTTPException(status_code=400, detail=f"Cannot add credit to {billing_model.value}")

        root_billing_interface = _getRootInterface()
        return _creditStripeSessionIfNeeded(root_billing_interface, session_dict, eventId=None)
    except HTTPException:
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"Error confirming checkout session {confirmRequest.sessionId}: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/webhook/stripe")
async def stripeWebhook(
    request: Request,
@@ -768,7 +968,6 @@ async def stripeWebhook(
    No JWT auth - Stripe authenticates via Stripe-Signature header.
    """
    from modules.shared.configuration import APP_CONFIG
-   from modules.services.serviceBilling.stripeCheckout import ALLOWED_AMOUNTS_CHF

    webhook_secret = APP_CONFIG.get("STRIPE_WEBHOOK_SECRET")
    if not webhook_secret:
@@ -792,71 +991,26 @@ async def stripeWebhook(
        logger.warning(f"Stripe webhook signature verification failed: {e}")
        raise HTTPException(status_code=400, detail="Invalid signature")

-   if event.type != "checkout.session.completed":
+   logger.info(f"Stripe webhook received: event={event.id}, type={event.type}")
+
+   accepted_event_types = {"checkout.session.completed", "checkout.session.async_payment_succeeded"}
+   if event.type not in accepted_event_types:
        return {"received": True}

    session = event.data.object
    event_id = event.id
    session_id = session.id

    billingInterface = _getRootInterface()

    if billingInterface.getStripeWebhookEventByEventId(event_id):
        logger.info(f"Stripe event {event_id} already processed, skipping")
        return {"received": True}

-   metadata = session.get("metadata") or {}
-   mandate_id = metadata.get("mandateId")
-   user_id = metadata.get("userId") or None
-   amount_chf_str = metadata.get("amountChf", "0")
-
-   if not mandate_id:
-       logger.error(f"Stripe webhook missing mandateId in session {session_id}")
-       raise HTTPException(status_code=400, detail="Invalid session metadata")
-
-   try:
-       amount_chf = float(amount_chf_str)
-   except (TypeError, ValueError):
-       amount_chf = None
-
-   if amount_chf is None or amount_chf not in ALLOWED_AMOUNTS_CHF:
-       amount_total = session.get("amount_total")
-       if amount_total is not None:
-           amount_chf = amount_total / 100.0
-       else:
-           logger.error(f"Stripe webhook invalid amount for session {session_id}")
-           raise HTTPException(status_code=400, detail="Invalid amount")
-
-   settings = billingInterface.getSettings(mandate_id)
-   if not settings:
-       logger.error(f"Stripe webhook: billing settings not found for mandate {mandate_id}")
-       raise HTTPException(status_code=404, detail="Billing settings not found")
-
-   billing_model = BillingModelEnum(settings.get("billingModel", BillingModelEnum.UNLIMITED.value))
-
-   if billing_model == BillingModelEnum.PREPAY_USER:
-       if not user_id:
-           logger.error(f"Stripe webhook: userId required for PREPAY_USER mandate {mandate_id}")
-           raise HTTPException(status_code=400, detail="userId required")
-       account = billingInterface.getOrCreateUserAccount(mandate_id, user_id, initialBalance=0.0)
-   elif billing_model in [BillingModelEnum.PREPAY_MANDATE, BillingModelEnum.CREDIT_POSTPAY]:
-       account = billingInterface.getOrCreateMandateAccount(mandate_id, initialBalance=0.0)
-   else:
-       logger.error(f"Stripe webhook: cannot credit mandate {mandate_id} with model {billing_model}")
-       raise HTTPException(status_code=400, detail=f"Cannot add credit to {billing_model.value}")
-
-   transaction = BillingTransaction(
-       accountId=account["id"],
-       transactionType=TransactionTypeEnum.CREDIT,
-       amount=amount_chf,
-       description="Stripe-Zahlung",
-       referenceType=ReferenceTypeEnum.PAYMENT,
-       referenceId=session_id
-   )
-   billingInterface.createTransaction(transaction)
-   billingInterface.createStripeWebhookEvent(event_id)
-
-   logger.info(f"Stripe webhook: credited {amount_chf} CHF to account {account['id']} (session {session_id})")
+   session_dict = session.to_dict_recursive() if hasattr(session, "to_dict_recursive") else dict(session)
+   result = _creditStripeSessionIfNeeded(billingInterface, session_dict, eventId=event_id)
+   logger.info(
+       f"Stripe webhook processed session {result.sessionId}: "
+       f"credited={result.credited}, alreadyCredited={result.alreadyCredited}"
+   )
    return {"received": True}
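For local testing of the handler above, the signature step it relies on can be exercised in isolation. stripe.Webhook.construct_event is the standard SDK call; the exception path below follows the classic stripe-python layout (newer SDK versions also expose stripe.SignatureVerificationError at top level), and all values are placeholders:

    # Standalone sketch of Stripe webhook signature verification.
    import stripe

    def parse_event(payload: bytes, sig_header: str, webhook_secret: str):
        try:
            # Raises on malformed payload or signature mismatch.
            return stripe.Webhook.construct_event(payload, sig_header, webhook_secret)
        except (ValueError, stripe.error.SignatureVerificationError):
            return None   # caller maps this to HTTP 400, as the route above does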
@@ -12,6 +12,7 @@ from modules.auth import limiter, getCurrentUser, getRequestContext, RequestCont
# Import interfaces
import modules.interfaces.interfaceDbManagement as interfaceDbManagement
from modules.datamodels.datamodelFiles import FileItem, FilePreview
+from modules.datamodels.datamodelFileFolder import FileFolder
from modules.shared.attributeUtils import getModelAttributeDefinitions
from modules.datamodels.datamodelUam import User
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata, normalize_pagination_dict

@@ -19,6 +20,114 @@ from modules.datamodels.datamodelPagination import PaginationParams, PaginatedRe
# Configure logger
logger = logging.getLogger(__name__)
async def _autoIndexFile(fileId: str, fileName: str, mimeType: str, user):
    """Background task: pre-scan + extraction + knowledge indexing.
    Step 1: Structure Pre-Scan (AI-free) -> FileContentIndex (persisted)
    Step 2: Content extraction via runExtraction -> ContentParts
    Step 3: KnowledgeService.indexFile -> chunking + embedding -> Knowledge Store"""
    userId = user.id if hasattr(user, "id") else str(user)
    try:
        mgmtInterface = interfaceDbManagement.getInterface(user)
        mgmtInterface.updateFile(fileId, {"status": "processing"})

        rawBytes = mgmtInterface.getFileData(fileId)
        if not rawBytes:
            logger.warning(f"Auto-index: no file data for {fileId}, skipping")
            mgmtInterface.updateFile(fileId, {"status": "active"})
            return

        logger.info(f"Auto-index starting for {fileName} ({len(rawBytes)} bytes, {mimeType})")

        # Step 1: Structure Pre-Scan (AI-free)
        from modules.serviceCenter.services.serviceKnowledge.subPreScan import preScanDocument
        contentIndex = await preScanDocument(
            fileData=rawBytes,
            mimeType=mimeType,
            fileId=fileId,
            fileName=fileName,
            userId=userId,
        )
        logger.info(
            f"Pre-scan complete for {fileName}: "
            f"{contentIndex.totalObjects} objects"
        )

        # Persist FileContentIndex immediately
        from modules.interfaces.interfaceDbKnowledge import getInterface as getKnowledgeInterface
        knowledgeDb = getKnowledgeInterface()
        knowledgeDb.upsertFileContentIndex(contentIndex)

        # Step 2: Content extraction (AI-free, produces ContentParts)
        from modules.serviceCenter.services.serviceExtraction.subRegistry import ExtractorRegistry, ChunkerRegistry
        from modules.serviceCenter.services.serviceExtraction.subPipeline import runExtraction
        from modules.datamodels.datamodelExtraction import ExtractionOptions

        extractorRegistry = ExtractorRegistry()
        chunkerRegistry = ChunkerRegistry()
        options = ExtractionOptions()

        extracted = runExtraction(
            extractorRegistry, chunkerRegistry,
            rawBytes, fileName, mimeType, options,
        )

        contentObjects = []
        for part in extracted.parts:
            contentType = "text"
            if part.typeGroup == "image":
                contentType = "image"
            elif part.typeGroup in ("binary", "container"):
                contentType = "other"

            if not part.data or not part.data.strip():
                continue

            contentObjects.append({
                "contentObjectId": part.id,
                "contentType": contentType,
                "data": part.data,
                "contextRef": {
                    "containerPath": fileName,
                    "location": part.label or "file",
                    **(part.metadata or {}),
                },
            })

        logger.info(f"Extracted {len(contentObjects)} content objects from {fileName}")

        if not contentObjects:
            knowledgeDb.updateFileStatus(fileId, "indexed")
            mgmtInterface.updateFile(fileId, {"status": "active"})
            return

        # Step 3: Knowledge indexing (chunking + embedding)
        from modules.serviceCenter import getService
        from modules.serviceCenter.context import ServiceCenterContext

        ctx = ServiceCenterContext(user=user, mandate_id="", feature_instance_id="")
        knowledgeService = getService("knowledge", ctx)

        await knowledgeService.indexFile(
            fileId=fileId,
            fileName=fileName,
            mimeType=mimeType,
            userId=userId,
            contentObjects=contentObjects,
            structure=contentIndex.structure,
        )

        mgmtInterface.updateFile(fileId, {"status": "active"})
        logger.info(f"Auto-index complete for file {fileId} ({fileName})")

    except Exception as e:
        logger.error(f"Auto-index failed for file {fileId}: {e}", exc_info=True)
        try:
            errMgmt = interfaceDbManagement.getInterface(user)
            errMgmt.updateFile(fileId, {"status": "active"})
        except Exception:
            pass
# Model attributes for FileItem
fileAttributes = getModelAttributeDefinitions(FileItem)


@@ -111,6 +220,7 @@ async def upload_file(
    request: Request,
    file: UploadFile = File(...),
    workflowId: Optional[str] = Form(None),
+   featureInstanceId: Optional[str] = Form(None),
    currentUser: User = Depends(getCurrentUser)
) -> JSONResponse:
    # Add fileName property to UploadFile for consistency with backend model
@@ -133,6 +243,10 @@ async def upload_file(
        # Save file via LucyDOM interface in the database
        fileItem, duplicateType = managementInterface.saveUploadedFile(fileContent, file.filename)

        if featureInstanceId and not fileItem.featureInstanceId:
            managementInterface.updateFile(fileItem.id, {"featureInstanceId": featureInstanceId})
            fileItem.featureInstanceId = featureInstanceId

        # Determine response message based on duplicate type
        if duplicateType == "exact_duplicate":
            message = f"File '{file.filename}' already exists with identical content. Reusing existing file."
@@ -148,6 +262,32 @@ async def upload_file(
        if workflowId:
            fileMeta["workflowId"] = workflowId

        # Trigger background auto-index pipeline (non-blocking)
        # Also runs for duplicates in case the original was never successfully indexed
        shouldIndex = duplicateType == "new_file"
        if not shouldIndex:
            try:
                from modules.interfaces.interfaceDbKnowledge import getInterface as _getKnowledgeInterface
                _kDb = _getKnowledgeInterface()
                _existingIndex = _kDb.getFileContentIndex(fileItem.id)
                if not _existingIndex:
                    shouldIndex = True
                    logger.info(f"Re-triggering auto-index for duplicate {fileItem.id} (not yet indexed)")
            except Exception:
                shouldIndex = True

        if shouldIndex:
            try:
                import asyncio
                asyncio.ensure_future(_autoIndexFile(
                    fileId=fileItem.id,
                    fileName=fileItem.fileName,
                    mimeType=fileItem.mimeType,
                    user=currentUser,
                ))
            except Exception as indexErr:
                logger.warning(f"Auto-index trigger failed (non-blocking): {indexErr}")

        # Response with duplicate information
        return JSONResponse({
            "message": message,
@@ -171,6 +311,288 @@ async def upload_file(
            detail=f"Error during file upload: {str(e)}"
        )

# ── Folder endpoints (MUST be before /{fileId} catch-all) ─────────────────────

@router.get("/folders", response_model=List[Dict[str, Any]])
@limiter.limit("30/minute")
def list_folders(
    request: Request,
    parentId: Optional[str] = Query(None, description="Parent folder ID (omit for all folders)"),
    currentUser: User = Depends(getCurrentUser),
    context: RequestContext = Depends(getRequestContext)
) -> List[Dict[str, Any]]:
    """List folders for the current user."""
    try:
        mgmt = interfaceDbManagement.getInterface(
            currentUser,
            mandateId=str(context.mandateId) if context.mandateId else None,
            featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None,
        )
        if parentId is not None:
            return mgmt.listFolders(parentId=parentId)
        return mgmt.listFolders()
    except Exception as e:
        logger.error(f"Error listing folders: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.post("/folders", status_code=status.HTTP_201_CREATED)
|
||||
@limiter.limit("10/minute")
|
||||
def create_folder(
|
||||
request: Request,
|
||||
body: Dict[str, Any] = Body(...),
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
) -> Dict[str, Any]:
|
||||
"""Create a new folder."""
|
||||
name = body.get("name", "")
|
||||
parentId = body.get("parentId")
|
||||
if not name:
|
||||
raise HTTPException(status_code=400, detail="name is required")
|
||||
try:
|
||||
mgmt = interfaceDbManagement.getInterface(
|
||||
currentUser,
|
||||
mandateId=str(context.mandateId) if context.mandateId else None,
|
||||
featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None,
|
||||
)
|
||||
return mgmt.createFolder(name=name, parentId=parentId)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating folder: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.put("/folders/{folderId}")
|
||||
@limiter.limit("10/minute")
|
||||
def rename_folder(
|
||||
request: Request,
|
||||
folderId: str = Path(...),
|
||||
body: Dict[str, Any] = Body(...),
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
) -> Dict[str, Any]:
|
||||
"""Rename a folder."""
|
||||
newName = body.get("name", "")
|
||||
if not newName:
|
||||
raise HTTPException(status_code=400, detail="name is required")
|
||||
try:
|
||||
mgmt = interfaceDbManagement.getInterface(
|
||||
currentUser,
|
||||
mandateId=str(context.mandateId) if context.mandateId else None,
|
||||
featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None,
|
||||
)
|
||||
mgmt.renameFolder(folderId, newName)
|
||||
return {"success": True, "folderId": folderId, "name": newName}
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Error renaming folder: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/folders/{folderId}")
|
||||
@limiter.limit("10/minute")
|
||||
def delete_folder(
|
||||
request: Request,
|
||||
folderId: str = Path(...),
|
||||
recursive: bool = Query(False, description="Delete folder contents recursively"),
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
) -> Dict[str, Any]:
|
||||
"""Delete a folder. Use recursive=true to delete non-empty folders."""
|
||||
try:
|
||||
mgmt = interfaceDbManagement.getInterface(
|
||||
currentUser,
|
||||
mandateId=str(context.mandateId) if context.mandateId else None,
|
||||
featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None,
|
||||
)
|
||||
return mgmt.deleteFolder(folderId, recursive=recursive)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting folder: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/folders/{folderId}/move")
|
||||
@limiter.limit("10/minute")
|
||||
def move_folder(
|
||||
request: Request,
|
||||
folderId: str = Path(...),
|
||||
body: Dict[str, Any] = Body(...),
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
) -> Dict[str, Any]:
|
||||
"""Move a folder to a new parent."""
|
||||
targetParentId = body.get("targetParentId")
|
||||
try:
|
||||
mgmt = interfaceDbManagement.getInterface(
|
||||
currentUser,
|
||||
mandateId=str(context.mandateId) if context.mandateId else None,
|
||||
featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None,
|
||||
)
|
||||
mgmt.moveFolder(folderId, targetParentId)
|
||||
return {"success": True, "folderId": folderId, "parentId": targetParentId}
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Error moving folder: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/folders/{folderId}/download")
|
||||
@limiter.limit("10/minute")
|
||||
def download_folder(
|
||||
request: Request,
|
||||
folderId: str = Path(..., description="ID of the folder to download as ZIP"),
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
) -> Response:
|
||||
"""Download a folder (including subfolders) as a ZIP archive."""
|
||||
import io
|
||||
import zipfile
|
||||
import urllib.parse
|
||||
|
||||
try:
|
||||
mgmt = interfaceDbManagement.getInterface(
|
||||
currentUser,
|
||||
mandateId=str(context.mandateId) if context.mandateId else None,
|
||||
featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None,
|
||||
)
|
||||
|
||||
folder = mgmt.getFolder(folderId)
|
||||
if not folder:
|
||||
raise HTTPException(status_code=404, detail=f"Folder {folderId} not found")
|
||||
|
||||
folderName = folder.get("name", "download")
|
||||
|
||||
def _collectFiles(parentId: str, pathPrefix: str):
|
||||
"""Recursively collect (zipPath, fileId) tuples."""
|
||||
entries = []
|
||||
for f in mgmt._getFilesByCurrentUser(recordFilter={"folderId": parentId}):
|
||||
fname = f.get("fileName") or f.get("name") or f.get("id", "file")
|
||||
entries.append((f"{pathPrefix}{fname}", f["id"]))
|
||||
for sub in mgmt.listFolders(parentId=parentId):
|
||||
subName = sub.get("name", sub["id"])
|
||||
entries.extend(_collectFiles(sub["id"], f"{pathPrefix}{subName}/"))
|
||||
return entries
|
||||
|
||||
fileEntries = _collectFiles(folderId, "")
|
||||
if not fileEntries:
|
||||
raise HTTPException(status_code=404, detail="Folder is empty")
|
||||
|
||||
buf = io.BytesIO()
|
||||
with zipfile.ZipFile(buf, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||
for zipPath, fileId in fileEntries:
|
||||
data = mgmt.getFileData(fileId)
|
||||
if data:
|
||||
zf.writestr(zipPath, data)
|
||||
|
||||
buf.seek(0)
|
||||
zipBytes = buf.getvalue()
|
||||
encodedName = urllib.parse.quote(f"{folderName}.zip")
|
||||
|
||||
return Response(
|
||||
content=zipBytes,
|
||||
media_type="application/zip",
|
||||
headers={
|
||||
"Content-Disposition": f"attachment; filename*=UTF-8''{encodedName}"
|
||||
}
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error downloading folder as ZIP: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Error downloading folder: {str(e)}")
|
||||
|
||||
|
||||
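A client-side sketch for the ZIP endpoint above, using requests. The base URL and the bearer-token auth scheme are placeholders, not confirmed by this diff:

    # Hypothetical client for the folder ZIP download.
    import requests

    def download_folder_zip(base_url: str, token: str, folder_id: str, dest: str):
        resp = requests.get(
            f"{base_url}/folders/{folder_id}/download",
            headers={"Authorization": f"Bearer {token}"},   # auth scheme assumed
            timeout=60,
        )
        resp.raise_for_status()
        with open(dest, "wb") as fh:                        # body is the ZIP archive
            fh.write(resp.content)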
@router.post("/batch-delete")
|
||||
@limiter.limit("10/minute")
|
||||
def batch_delete_items(
|
||||
request: Request,
|
||||
body: Dict[str, Any] = Body(...),
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
) -> Dict[str, Any]:
|
||||
"""Batch delete files/folders with a single SQL-backed operation per type."""
|
||||
fileIds = body.get("fileIds") or []
|
||||
folderIds = body.get("folderIds") or []
|
||||
recursiveFolders = bool(body.get("recursiveFolders", True))
|
||||
|
||||
if not isinstance(fileIds, list) or not isinstance(folderIds, list):
|
||||
raise HTTPException(status_code=400, detail="fileIds and folderIds must be arrays")
|
||||
|
||||
try:
|
||||
mgmt = interfaceDbManagement.getInterface(
|
||||
currentUser,
|
||||
mandateId=str(context.mandateId) if context.mandateId else None,
|
||||
featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None,
|
||||
)
|
||||
|
||||
result = {"deletedFiles": 0, "deletedFolders": 0}
|
||||
|
||||
if fileIds:
|
||||
fileResult = mgmt.deleteFilesBatch(fileIds)
|
||||
result["deletedFiles"] += fileResult.get("deletedFiles", 0)
|
||||
|
||||
if folderIds:
|
||||
folderResult = mgmt.deleteFoldersBatch(folderIds, recursive=recursiveFolders)
|
||||
result["deletedFiles"] += folderResult.get("deletedFiles", 0)
|
||||
result["deletedFolders"] += folderResult.get("deletedFolders", 0)
|
||||
|
||||
return result
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Error in batch delete: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/batch-move")
|
||||
@limiter.limit("10/minute")
|
||||
def batch_move_items(
|
||||
request: Request,
|
||||
body: Dict[str, Any] = Body(...),
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
) -> Dict[str, Any]:
|
||||
"""Batch move files/folders with a single SQL-backed operation per type."""
|
||||
fileIds = body.get("fileIds") or []
|
||||
folderIds = body.get("folderIds") or []
|
||||
targetFolderId = body.get("targetFolderId")
|
||||
targetParentId = body.get("targetParentId")
|
||||
|
||||
if not isinstance(fileIds, list) or not isinstance(folderIds, list):
|
||||
raise HTTPException(status_code=400, detail="fileIds and folderIds must be arrays")
|
||||
|
||||
try:
|
||||
mgmt = interfaceDbManagement.getInterface(
|
||||
currentUser,
|
||||
mandateId=str(context.mandateId) if context.mandateId else None,
|
||||
featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None,
|
||||
)
|
||||
|
||||
result = {"movedFiles": 0, "movedFolders": 0}
|
||||
|
||||
if fileIds:
|
||||
fileResult = mgmt.moveFilesBatch(fileIds, targetFolderId=targetFolderId)
|
||||
result["movedFiles"] += fileResult.get("movedFiles", 0)
|
||||
|
||||
if folderIds:
|
||||
folderResult = mgmt.moveFoldersBatch(folderIds, targetParentId=targetParentId)
|
||||
result["movedFolders"] += folderResult.get("movedFolders", 0)
|
||||
|
||||
return result
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Error in batch move: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
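The two batch endpoints take plain JSON bodies; a minimal caller sketch (paths relative to this router, bearer auth assumed as in the download sketch above):

    # Hypothetical caller for the batch-move endpoint.
    import requests

    def batch_move(base_url, token, file_ids, folder_ids, target_id):
        return requests.post(
            f"{base_url}/batch-move",
            json={"fileIds": file_ids, "folderIds": folder_ids,
                  "targetFolderId": target_id, "targetParentId": target_id},
            headers={"Authorization": f"Bearer {token}"},
            timeout=30,
        ).json()   # -> {"movedFiles": n, "movedFolders": m}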
# ── File endpoints with path parameters (catch-all /{fileId}) ─────────────────

@router.get("/{fileId}", response_model=FileItem)
@limiter.limit("30/minute")
def get_file(

@ -418,3 +840,25 @@ def preview_file(
    )

@router.post("/{fileId}/move")
|
||||
@limiter.limit("10/minute")
|
||||
def move_file(
|
||||
request: Request,
|
||||
fileId: str = Path(...),
|
||||
body: Dict[str, Any] = Body(...),
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
) -> Dict[str, Any]:
|
||||
"""Move a file to a different folder."""
|
||||
targetFolderId = body.get("targetFolderId")
|
||||
try:
|
||||
mgmt = interfaceDbManagement.getInterface(
|
||||
currentUser,
|
||||
mandateId=str(context.mandateId) if context.mandateId else None,
|
||||
featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None,
|
||||
)
|
||||
mgmt.updateFile(fileId, {"folderId": targetFolderId})
|
||||
return {"success": True, "fileId": fileId, "folderId": targetFolderId}
|
||||
except Exception as e:
|
||||
logger.error(f"Error moving file: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
|
|
|||
|
|

@ -764,7 +764,7 @@ def send_password_link(
    expiryHours = int(APP_CONFIG.get("Auth_RESET_TOKEN_EXPIRY_HOURS", "24"))

    try:
        from modules.services import Services
        from modules.serviceHub import Services
        services = Services(targetUser)

        emailSubject = "PowerOn - Passwort setzen"

@ -395,7 +395,7 @@ def trigger_subscription(
    )

    # Get messaging service from request app state
    from modules.services import getInterface as getServicesInterface
    from modules.serviceHub import getInterface as getServicesInterface
    services = getServicesInterface(context.user, None, mandateId=str(context.mandateId))

    # Convert dict to Pydantic model

@ -87,9 +87,10 @@ CLIENT_SECRET = APP_CONFIG.get("Service_GOOGLE_CLIENT_SECRET")
REDIRECT_URI = APP_CONFIG.get("Service_GOOGLE_REDIRECT_URI")
SCOPES = [
    "https://www.googleapis.com/auth/gmail.readonly",
    "https://www.googleapis.com/auth/drive.readonly",
    "https://www.googleapis.com/auth/userinfo.profile",
    "https://www.googleapis.com/auth/userinfo.email",
    "openid"
    "openid",
]

@router.get("/config")

@ -488,7 +489,7 @@ async def auth_callback(code: str, state: str, request: Request, response: Respo
    connection.externalUsername = user_info.get("email")
    connection.externalEmail = user_info.get("email")
    # Store actually granted scopes for this connection
    granted_scopes_list = granted_scopes.split(" ") if granted_scopes else SCOPES
    granted_scopes_list = granted_scopes if isinstance(granted_scopes, list) else (granted_scopes.split(" ") if granted_scopes else SCOPES)
    connection.grantedScopes = granted_scopes_list
    logger.info(f"Storing granted scopes for connection {connection_id}: {granted_scopes_list}")

@ -59,6 +59,7 @@ SCOPES = [
    "Mail.Send",              # Send mail
    "Files.ReadWrite.All",    # Read and write files (SharePoint/OneDrive)
    "Sites.ReadWrite.All",    # Read and write SharePoint sites
    "Team.ReadBasic.All",     # List joined teams and channels
    # Teams Bot: Meeting and chat access (requires admin consent)
    "OnlineMeetings.Read",    # Read user's Teams meeting details (delegated scope)
    "Chat.ReadWrite",         # Read and write Teams chat messages

@ -12,7 +12,7 @@ from fastapi import APIRouter, HTTPException, Depends, Path, Query, Request, sta
from modules.auth import limiter, getCurrentUser
from modules.datamodels.datamodelUam import User, UserConnection
from modules.interfaces.interfaceDbApp import getInterface
from modules.services import getInterface as getServices
from modules.serviceHub import getInterface as getServices

logger = logging.getLogger(__name__)

@ -102,12 +102,6 @@ def _getFeatureUiObjects(featureCode: str) -> List[Dict[str, Any]]:
        elif featureCode == "realestate":
            from modules.features.realEstate.mainRealEstate import UI_OBJECTS
            return UI_OBJECTS
        elif featureCode == "chatplayground":
            from modules.features.chatplayground.mainChatplayground import UI_OBJECTS
            return UI_OBJECTS
        elif featureCode == "codeeditor":
            from modules.features.codeeditor.mainCodeeditor import UI_OBJECTS
            return UI_OBJECTS
        elif featureCode == "automation":
            from modules.features.automation.mainAutomation import UI_OBJECTS
            return UI_OBJECTS

@ -123,8 +117,11 @@ def _getFeatureUiObjects(featureCode: str) -> List[Dict[str, Any]]:
        elif featureCode == "commcoach":
            from modules.features.commcoach.mainCommcoach import UI_OBJECTS
            return UI_OBJECTS
        elif featureCode == "workspace":
            from modules.features.workspace.mainWorkspace import UI_OBJECTS
            return UI_OBJECTS
        else:
            logger.warning(f"Unknown feature code: {featureCode}")
            logger.debug(f"Skipping removed feature code: {featureCode}")
            return []
    except ImportError as e:
        logger.error(f"Failed to import UI_OBJECTS for feature {featureCode}: {e}")

@ -6,13 +6,16 @@ Replaces Azure voice services with Google Cloud Speech-to-Text and Translation
Includes WebSocket support for real-time voice streaming
"""

import asyncio
import logging
import json
import base64
from fastapi import APIRouter, File, Form, UploadFile, Depends, HTTPException, Body, WebSocket, WebSocketDisconnect
import secrets
import time
from fastapi import APIRouter, File, Form, UploadFile, Depends, HTTPException, Body, Query, Request, WebSocket, WebSocketDisconnect
from fastapi.responses import Response
from typing import Optional, Dict, Any, List
from modules.auth import getCurrentUser
from modules.auth import getCurrentUser, getRequestContext, RequestContext, limiter
from modules.datamodels.datamodelUam import User
from modules.interfaces.interfaceVoiceObjects import getVoiceInterface, VoiceObjects

@ -290,10 +293,11 @@ async def realtime_interpreter(

@router.post("/text-to-speech")
async def text_to_speech(
    request: Request,
    text: str = Form(...),
    language: str = Form("de-DE"),
    voice: str = Form(None),
    currentUser: User = Depends(getCurrentUser)
    context: RequestContext = Depends(getRequestContext),
):
    """Convert text to speech using Google Cloud Text-to-Speech."""
    try:

@ -305,7 +309,20 @@ async def text_to_speech(
                detail="Empty text provided for text-to-speech"
            )

        voiceInterface = _getVoiceInterface(currentUser)
        mandateId = str(getattr(context, "mandateId", "") or "")
        voiceInterface = getVoiceInterface(context.user, mandateId)
        try:
            from modules.serviceCenter.services.serviceBilling.mainServiceBilling import getService as getBillingService
            billingService = getBillingService(context.user, mandateId)
            def _billingCb(data):
                priceCHF = data.get("priceCHF", 0.0)
                operation = data.get("operation", "voice")
                if priceCHF > 0:
                    billingService.recordUsage(priceCHF=priceCHF, aicoreProvider="google-voice", aicoreModel=operation, description=f"Voice {operation}")
            voiceInterface.billingCallback = _billingCb
        except Exception as e:
            logger.warning(f"TTS billing setup skipped: {e}")

        result = await voiceInterface.textToSpeech(
            text=text,
            languageCode=language,

@ -314,12 +331,12 @@ async def text_to_speech(

        if result["success"]:
            return Response(
                content=result["audio_content"],
                content=result["audioContent"],
                media_type="audio/mpeg",
                headers={
                    "Content-Disposition": "attachment; filename=speech.mp3",
                    "X-Voice-Name": result["voice_name"],
                    "X-Language-Code": result["language_code"]
                    "X-Voice-Name": result.get("voiceName", ""),
                    "X-Language-Code": result.get("languageCode", language),
                }
            )
        else:
|
|||
detail=f"Failed to save voice settings: {str(e)}"
|
||||
)
|
||||
|
||||
# WebSocket endpoints for real-time voice streaming
|
||||
# =========================================================================
|
||||
# STT Streaming WebSocket — generic, used by all features
|
||||
# =========================================================================
|
||||
|
||||
@router.websocket("/ws/realtime-interpreter")
|
||||
async def websocket_realtime_interpreter(
|
||||
websocket: WebSocket,
|
||||
userId: str = "default",
|
||||
fromLanguage: str = "de-DE",
|
||||
toLanguage: str = "en-US"
|
||||
_sttTokens: Dict[str, Dict[str, Any]] = {}
|
||||
_STT_TOKEN_TTL = 45
|
||||
|
||||
|
||||
def _cleanupSttTokens():
|
||||
now = time.time()
|
||||
expired = [t for t, p in _sttTokens.items() if p.get("expiresAt", 0) <= now]
|
||||
for t in expired:
|
||||
_sttTokens.pop(t, None)
|
||||
|
||||
|
||||
@router.post("/stt/token")
|
||||
@limiter.limit("60/minute")
|
||||
async def createSttToken(
|
||||
request: Request,
|
||||
context: RequestContext = Depends(getRequestContext),
|
||||
):
|
||||
"""WebSocket endpoint for real-time voice interpretation"""
|
||||
connectionId = f"realtime_{userId}_{fromLanguage}_{toLanguage}"
|
||||
"""Issue a short-lived single-use token for the STT streaming WebSocket."""
|
||||
_cleanupSttTokens()
|
||||
token = secrets.token_urlsafe(32)
|
||||
_sttTokens[token] = {
|
||||
"userId": str(context.user.id),
|
||||
"mandateId": str(getattr(context, "mandateId", "") or ""),
|
||||
"expiresAt": time.time() + _STT_TOKEN_TTL,
|
||||
}
|
||||
return {"wsToken": token, "expiresInSeconds": _STT_TOKEN_TTL}
|
||||
|
||||
|
||||
@router.websocket("/stt/stream")
|
||||
async def sttStream(
|
||||
websocket: WebSocket,
|
||||
wsToken: Optional[str] = Query(None),
|
||||
):
|
||||
"""
|
||||
Generic STT streaming WebSocket.
|
||||
|
||||
Protocol:
|
||||
Client sends JSON:
|
||||
{"type": "open", "language": "de-DE"}
|
||||
{"type": "audio", "chunk": "<base64>"}
|
||||
{"type": "close"}
|
||||
Server sends JSON:
|
||||
{"type": "interim", "text": "..."}
|
||||
{"type": "final", "text": "...", "confidence": 0.95}
|
||||
{"type": "error", "message": "..."}
|
||||
{"type": "closed"}
|
||||
"""
|
||||
await websocket.accept()
|
||||
|
||||
# --- authenticate via wsToken ---
|
||||
if not wsToken:
|
||||
await websocket.send_json({"type": "error", "code": "ws_token_required", "message": "wsToken query param required"})
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
|
||||
_cleanupSttTokens()
|
||||
tokenPayload = _sttTokens.pop(wsToken, None)
|
||||
if not tokenPayload:
|
||||
await websocket.send_json({"type": "error", "code": "ws_token_invalid", "message": "Invalid or expired wsToken"})
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
|
||||
tokenUserId = tokenPayload["userId"]
|
||||
tokenMandateId = tokenPayload.get("mandateId", "")
|
||||
|
||||
# Resolve real user for billing
|
||||
from modules.interfaces.interfaceDbApp import getRootInterface
|
||||
rootInterface = getRootInterface()
|
||||
currentUser = rootInterface.getUser(tokenUserId)
|
||||
if not currentUser:
|
||||
await websocket.send_json({"type": "error", "code": "user_not_found", "message": "User not found"})
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
|
||||
# --- billing pre-flight ---
|
||||
billingService = None
|
||||
try:
|
||||
await manager.connect(websocket, connectionId)
|
||||
|
||||
# Send connection confirmation
|
||||
await manager.sendPersonalMessage({
|
||||
"type": "connected",
|
||||
"connection_id": connectionId,
|
||||
"message": "Connected to real-time interpreter"
|
||||
}, websocket)
|
||||
|
||||
# Initialize voice interface
|
||||
voiceInterface = _getVoiceInterface(User(id=userId))
|
||||
|
||||
while True:
|
||||
# Receive message from client
|
||||
data = await websocket.receive_text()
|
||||
message = json.loads(data)
|
||||
|
||||
if message["type"] == "audio_chunk":
|
||||
# Process audio chunk
|
||||
try:
|
||||
# Decode base64 audio data
|
||||
audioData = base64.b64decode(message["data"])
|
||||
|
||||
# For now, just acknowledge receipt
|
||||
# In a full implementation, this would:
|
||||
# 1. Buffer audio chunks
|
||||
# 2. Process with Google Cloud Speech-to-Text streaming
|
||||
# 3. Send partial results back
|
||||
# 4. Handle translation
|
||||
|
||||
await manager.sendPersonalMessage({
|
||||
"type": "audio_received",
|
||||
"chunk_size": len(audioData),
|
||||
"timestamp": message.get("timestamp")
|
||||
}, websocket)
|
||||
|
||||
from modules.serviceCenter.services.serviceBilling.mainServiceBilling import getService as getBillingService
|
||||
billingService = getBillingService(currentUser, tokenMandateId)
|
||||
billingCheck = billingService.checkBalance(0.0)
|
||||
if not billingCheck.allowed:
|
||||
await websocket.send_json({"type": "error", "code": "billing_insufficient", "message": "Insufficient balance for voice services"})
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing audio chunk: {e}")
|
||||
await manager.send_personal_message({
|
||||
"type": "error",
|
||||
"error": f"Failed to process audio: {str(e)}"
|
||||
}, websocket)
|
||||
logger.warning(f"STT billing pre-flight skipped: {e}")
|
||||
|
||||
elif message["type"] == "ping":
|
||||
# Respond to ping
|
||||
await manager.sendPersonalMessage({
|
||||
"type": "pong",
|
||||
"timestamp": message.get("timestamp")
|
||||
}, websocket)
|
||||
audioQueue: asyncio.Queue = asyncio.Queue()
|
||||
language = "de-DE"
|
||||
streamingTask: Optional[asyncio.Task] = None
|
||||
voiceInterface: Optional[VoiceObjects] = None
|
||||
|
||||
async def _sendJson(payload: Dict[str, Any]) -> bool:
|
||||
try:
|
||||
await websocket.send_json(payload)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def _runStreaming():
|
||||
nonlocal voiceInterface
|
||||
voiceInterface = getVoiceInterface(currentUser, tokenMandateId)
|
||||
if billingService:
|
||||
def _billingCb(data):
|
||||
priceCHF = data.get("priceCHF", 0.0)
|
||||
operation = data.get("operation", "voice")
|
||||
if priceCHF > 0:
|
||||
billingService.recordUsage(
|
||||
priceCHF=priceCHF,
|
||||
aicoreProvider="google-voice",
|
||||
aicoreModel=operation,
|
||||
description=f"Voice {operation}",
|
||||
)
|
||||
voiceInterface.billingCallback = _billingCb
|
||||
|
||||
try:
|
||||
async for event in voiceInterface.streamingSpeechToText(audioQueue, language):
|
||||
if event.get("reconnectRequired"):
|
||||
await _sendJson({"type": "reconnect_required"})
|
||||
return
|
||||
if event.get("isFinal"):
|
||||
if event.get("transcript"):
|
||||
await _sendJson({"type": "final", "text": event["transcript"], "confidence": event.get("confidence", 0.0)})
|
||||
else:
|
||||
logger.warning(f"Unknown message type: {message['type']}")
|
||||
|
||||
except WebSocketDisconnect:
|
||||
manager.disconnect(websocket, connectionId)
|
||||
logger.info(f"Client disconnected: {connectionId}")
|
||||
if event.get("transcript"):
|
||||
await _sendJson({"type": "interim", "text": event["transcript"]})
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket error: {e}")
|
||||
manager.disconnect(websocket, connectionId)
|
||||
|
||||
@router.websocket("/ws/speech-to-text")
|
||||
async def websocket_speech_to_text(
|
||||
websocket: WebSocket,
|
||||
userId: str = "default",
|
||||
language: str = "de-DE"
|
||||
):
|
||||
"""WebSocket endpoint for real-time speech-to-text"""
|
||||
connectionId = f"stt_{userId}_{language}"
|
||||
logger.error(f"STT streaming error: {e}")
|
||||
await _sendJson({"type": "error", "message": str(e)})
|
||||
|
||||
try:
|
||||
await manager.connect(websocket, connectionId)
|
||||
|
||||
await manager.sendPersonalMessage({
|
||||
"type": "connected",
|
||||
"connection_id": connectionId,
|
||||
"message": "Connected to speech-to-text"
|
||||
}, websocket)
|
||||
|
||||
# Initialize voice interface
|
||||
voiceInterface = _getVoiceInterface(User(id=userId))
|
||||
await _sendJson({"type": "status", "label": "STT stream connected"})
|
||||
|
||||
while True:
|
||||
data = await websocket.receive_text()
|
||||
message = json.loads(data)
|
||||
raw = await websocket.receive_text()
|
||||
msg = json.loads(raw)
|
||||
msgType = (msg.get("type") or "").strip()
|
||||
|
||||
if message["type"] == "audio_chunk":
|
||||
if msgType == "open":
|
||||
language = msg.get("language") or "de-DE"
|
||||
if streamingTask and not streamingTask.done():
|
||||
await audioQueue.put((b"", True))
|
||||
streamingTask.cancel()
|
||||
audioQueue = asyncio.Queue()
|
||||
streamingTask = asyncio.create_task(_runStreaming())
|
||||
await _sendJson({"type": "status", "label": "Listening..."})
|
||||
|
||||
elif msgType == "audio":
|
||||
chunkB64 = msg.get("chunk")
|
||||
if not chunkB64:
|
||||
continue
|
||||
chunkBytes = base64.b64decode(chunkB64)
|
||||
if len(chunkBytes) > 400_000:
|
||||
await _sendJson({"type": "error", "code": "chunk_too_large", "message": "Audio chunk too large"})
|
||||
continue
|
||||
await audioQueue.put((chunkBytes, False))
|
||||
|
||||
elif msgType == "close":
|
||||
await audioQueue.put((b"", True))
|
||||
if streamingTask:
|
||||
try:
|
||||
audioData = base64.b64decode(message["data"])
|
||||
await asyncio.wait_for(streamingTask, timeout=10.0)
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||
pass
|
||||
await _sendJson({"type": "closed"})
|
||||
await websocket.close()
|
||||
break
|
||||
|
||||
# Process audio chunk
|
||||
# This would integrate with Google Cloud Speech-to-Text streaming API
|
||||
|
||||
await manager.sendPersonalMessage({
|
||||
"type": "transcription_result",
|
||||
"text": "Audio chunk received", # Placeholder
|
||||
"confidence": 0.95,
|
||||
"is_final": False
|
||||
}, websocket)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing audio: {e}")
|
||||
await manager.sendPersonalMessage({
|
||||
"type": "error",
|
||||
"error": f"Failed to process audio: {str(e)}"
|
||||
}, websocket)
|
||||
|
||||
elif message["type"] == "ping":
|
||||
await manager.sendPersonalMessage({
|
||||
"type": "pong",
|
||||
"timestamp": message.get("timestamp")
|
||||
}, websocket)
|
||||
elif msgType == "ping":
|
||||
await _sendJson({"type": "pong"})
|
||||
|
||||
except WebSocketDisconnect:
|
||||
manager.disconnect(websocket, connectionId)
|
||||
logger.info(f"STT WebSocket disconnected: userId={tokenUserId}")
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket error: {e}")
|
||||
manager.disconnect(websocket, connectionId)
|
||||
|
||||
@router.websocket("/ws/text-to-speech")
|
||||
async def websocket_text_to_speech(
|
||||
websocket: WebSocket,
|
||||
userId: str = "default",
|
||||
language: str = "de-DE",
|
||||
voice: str = "de-DE-Wavenet-A"
|
||||
):
|
||||
"""WebSocket endpoint for real-time text-to-speech"""
|
||||
connectionId = f"tts_{userId}_{language}_{voice}"
|
||||
|
||||
logger.error(f"STT WebSocket error: {e}", exc_info=True)
|
||||
try:
|
||||
await manager.connect(websocket, connectionId)
|
||||
|
||||
await manager.sendPersonalMessage({
|
||||
"type": "connected",
|
||||
"connection_id": connectionId,
|
||||
"message": "Connected to text-to-speech"
|
||||
}, websocket)
|
||||
|
||||
while True:
|
||||
data = await websocket.receive_text()
|
||||
message = json.loads(data)
|
||||
|
||||
if message["type"] == "text_to_speak":
|
||||
try:
|
||||
text = message["text"]
|
||||
|
||||
# Process text-to-speech
|
||||
# This would integrate with Google Cloud Text-to-Speech API
|
||||
|
||||
# For now, send a placeholder response
|
||||
await manager.sendPersonalMessage({
|
||||
"type": "audio_data",
|
||||
"audio": "base64_encoded_audio_here", # Placeholder
|
||||
"format": "mp3"
|
||||
}, websocket)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing text-to-speech: {e}")
|
||||
await manager.sendPersonalMessage({
|
||||
"type": "error",
|
||||
"error": f"Failed to process text: {str(e)}"
|
||||
}, websocket)
|
||||
|
||||
elif message["type"] == "ping":
|
||||
await manager.sendPersonalMessage({
|
||||
"type": "pong",
|
||||
"timestamp": message.get("timestamp")
|
||||
}, websocket)
|
||||
|
||||
except WebSocketDisconnect:
|
||||
manager.disconnect(websocket, connectionId)
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket error: {e}")
|
||||
manager.disconnect(websocket, connectionId)
|
||||
await websocket.send_json({"type": "error", "message": str(e)})
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
await audioQueue.put((b"", True))
|
||||
if streamingTask and not streamingTask.done():
|
||||
streamingTask.cancel()
|
||||
|
|
|
|||
|
|
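
A client-side sketch of the token handshake and message protocol implemented above (illustrative, not part of the commit; the "/api/voice" mount prefix, the unauthenticated-looking token request, and the `websockets` library are assumptions — in the real app the token request must carry auth):

# Illustrative STT streaming client (not part of the commit).
import asyncio, base64, json

import requests
import websockets

async def stream_audio(chunks):
    # The token endpoint requires an authenticated request in the real app.
    token = requests.post("http://localhost:8000/api/voice/stt/token").json()["wsToken"]
    uri = f"ws://localhost:8000/api/voice/stt/stream?wsToken={token}"
    async with websockets.connect(uri) as ws:
        await ws.send(json.dumps({"type": "open", "language": "de-DE"}))
        for chunk in chunks:  # raw audio bytes, each well under the 400 kB limit
            await ws.send(json.dumps({"type": "audio", "chunk": base64.b64encode(chunk).decode()}))
        await ws.send(json.dumps({"type": "close"}))
        async for raw in ws:  # interim/final transcripts, then "closed"
            msg = json.loads(raw)
            if msg["type"] in ("interim", "final"):
                print(msg["type"], msg["text"])
            elif msg["type"] == "closed":
                break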

@ -120,6 +120,49 @@ class RbacCatalogService:
            return [obj for obj in self._dataObjects.values() if obj["featureCode"] == featureCode]
        return list(self._dataObjects.values())

    def getAccessibleDataObjects(
        self,
        featureCode: str,
        rbacInstance,
        user,
        mandateId: str,
        featureInstanceId: str,
    ) -> List[Dict[str, Any]]:
        """Get DATA objects filtered by RBAC read permission for the user.

        Args:
            featureCode: Feature code to filter by
            rbacInstance: RbacClass instance for permission checks
            user: User object
            mandateId: Mandate scope
            featureInstanceId: Feature instance scope
        """
        from modules.datamodels.datamodelRbac import AccessRuleContext
        allObjects = self.getDataObjects(featureCode)
        accessible = []
        for obj in allObjects:
            objectKey = obj.get("objectKey", "")
            try:
                perms = rbacInstance.getUserPermissions(
                    user=user,
                    context=AccessRuleContext.DATA,
                    item=objectKey,
                    mandateId=mandateId,
                    featureInstanceId=featureInstanceId,
                )
                if perms.view or perms.read.value != "n":
                    accessible.append(obj)
            except Exception:
                pass
        return accessible

    def getFeaturesWithDataObjects(self) -> List[str]:
        """Get feature codes that have at least one registered DATA object."""
        codes = set()
        for obj in self._dataObjects.values():
            codes.add(obj["featureCode"])
        return list(codes)

    def getAllObjects(self, featureCode: Optional[str] = None) -> List[Dict[str, Any]]:
        """Get all RBAC objects (UI + RESOURCE + DATA), optionally filtered by feature."""
        return self.getUiObjects(featureCode) + self.getResourceObjects(featureCode) + self.getDataObjects(featureCode)
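
Illustrative usage of the new RBAC filter from a route handler (not part of the commit; the `catalog`, `rbac`, and request-context objects are assumed names — the argument names match the signature added above):

accessible = catalog.getAccessibleDataObjects(
    featureCode="workspace",
    rbacInstance=rbac,
    user=currentUser,
    mandateId=str(context.mandateId),
    featureInstanceId=str(context.featureInstanceId),
)
allowedKeys = [obj["objectKey"] for obj in accessible]  # objects the user may read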

@ -26,7 +26,6 @@ logger = logging.getLogger(__name__)
def getService(
    key: str,
    context: ServiceCenterContext,
    legacy_hub: Optional[Any] = None,
) -> Any:
    """
    Get a service instance by key for the given context.

@ -34,14 +33,13 @@ def getService(
    Args:
        key: Service key (e.g., "web", "extraction", "utils")
        context: ServiceCenterContext with user, mandate_id, feature_instance_id, workflow
        legacy_hub: Optional legacy Services instance for fallback when service not yet migrated

    Returns:
        Service instance
    """
    cache = get_resolution_cache()
    resolving = set()
    return resolve(key, context, cache, resolving, legacy_hub=legacy_hub)
    return resolve(key, context, cache, resolving)


def preWarm(service_keys: Optional[List[str]] = None) -> None:

@ -22,6 +22,8 @@ class EventManager:
        """Initialize the event manager."""
        self._queues: Dict[str, asyncio.Queue] = {}
        self._cleanup_tasks: Dict[str, asyncio.Task] = {}
        self._agent_tasks: Dict[str, asyncio.Task] = {}
        self._cancelled: Dict[str, bool] = {}

    def create_queue(self, workflow_id: str) -> asyncio.Queue:
        """

@ -33,9 +35,22 @@ class EventManager:
        Returns:
            Async queue for events
        """
        if workflow_id in self._cleanup_tasks:
            self._cleanup_tasks[workflow_id].cancel()
            del self._cleanup_tasks[workflow_id]
            logger.debug(f"Cancelled pending cleanup for workflow {workflow_id}")

        if workflow_id not in self._queues:
            self._queues[workflow_id] = asyncio.Queue()
            logger.debug(f"Created event queue for workflow {workflow_id}")
        else:
            old = self._queues[workflow_id]
            while not old.empty():
                try:
                    old.get_nowait()
                except asyncio.QueueEmpty:
                    break
            logger.debug(f"Reusing event queue for workflow {workflow_id} (drained stale events)")
        return self._queues[workflow_id]

    def get_queue(self, workflow_id: str) -> Optional[asyncio.Queue]:

@ -62,6 +77,31 @@ class EventManager:
        """
        return workflow_id in self._queues

    def register_agent_task(self, workflow_id: str, task: asyncio.Task) -> None:
        """Register the asyncio Task running the agent for a workflow."""
        self._agent_tasks[workflow_id] = task
        self._cancelled.pop(workflow_id, None)

    def is_cancelled(self, workflow_id: str) -> bool:
        """Check if a workflow has been cancelled."""
        return self._cancelled.get(workflow_id, False)

    async def cancel_agent(self, workflow_id: str) -> bool:
        """Cancel the running agent task for a workflow. Returns True if cancelled."""
        self._cancelled[workflow_id] = True
        task = self._agent_tasks.pop(workflow_id, None)
        if task and not task.done():
            task.cancel()
            logger.info(f"Cancelled agent task for workflow {workflow_id}")
            return True
        logger.debug(f"No running agent task found for workflow {workflow_id}")
        return False

    def _unregister_agent_task(self, workflow_id: str) -> None:
        """Remove the agent task reference after completion."""
        self._agent_tasks.pop(workflow_id, None)
        self._cancelled.pop(workflow_id, None)

    async def emit_event(
        self,
        context_id: str,

@ -97,6 +137,7 @@ class EventManager:

        try:
            await queue.put(event)
            if event_type not in ("chunk",):
                logger.debug(f"Emitted {event_type} event for workflow {context_id}")
        except Exception as e:
            logger.error(f"Error emitting event for workflow {context_id}: {e}", exc_info=True)
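
The expected lifecycle of the new cancellation hooks, sketched (illustrative, not part of the commit; `event_manager` and `run_agent` are assumed names):

queue = event_manager.create_queue(workflow_id)
task = asyncio.create_task(run_agent(workflow_id, queue))
event_manager.register_agent_task(workflow_id, task)
# ... later, e.g. from a cancel endpoint:
was_cancelled = await event_manager.cancel_agent(workflow_id)  # True if a live task was cancelled
# the agent coroutine can also poll cooperatively:
if event_manager.is_cancelled(workflow_id):
    return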

@ -98,6 +98,20 @@ IMPORTABLE_SERVICES: Dict[str, Dict[str, Any]] = {
        "objectKey": "service.neutralization",
        "label": {"en": "Neutralization", "de": "Neutralisierung", "fr": "Neutralisation"},
    },
    "agent": {
        "module": "modules.serviceCenter.services.serviceAgent.mainServiceAgent",
        "class": "AgentService",
        "dependencies": ["ai", "chat", "utils", "extraction", "billing", "streaming", "knowledge"],
        "objectKey": "service.agent",
        "label": {"en": "Agent", "de": "Agent", "fr": "Agent"},
    },
    "knowledge": {
        "module": "modules.serviceCenter.services.serviceKnowledge.mainServiceKnowledge",
        "class": "KnowledgeService",
        "dependencies": ["ai"],
        "objectKey": "service.knowledge",
        "label": {"en": "Knowledge Store", "de": "Wissensspeicher", "fr": "Base de connaissances"},
    },
}

# RBAC objects for service-level access control (for catalog registration)

@ -2,7 +2,7 @@
# All rights reserved.
"""
Service Center Resolver.
Resolution logic, dependency injection, and optional legacy fallback.
Resolution logic and dependency injection for service instantiation.
"""

import importlib

@ -14,7 +14,6 @@ from modules.serviceCenter.registry import CORE_SERVICES, IMPORTABLE_SERVICES

logger = logging.getLogger(__name__)

# Type for get_service callable passed to services
GetServiceFunc = Callable[[str], Any]


@ -29,50 +28,15 @@ def _load_service_class(module_path: str, class_name: str):
    return getattr(module, class_name)


def _create_legacy_hub(ctx: ServiceCenterContext) -> Any:
    """Create legacy Services instance for fallback when service not yet migrated."""
    from modules.services import getInterface
    return getInterface(
        ctx.user,
        workflow=ctx.workflow,
        mandateId=ctx.mandate_id,
        featureInstanceId=ctx.feature_instance_id,
    )


def _get_from_legacy(legacy_hub: Any, key: str) -> Any:
    """Map service key to legacy hub attribute (for fallback when service center module fails)."""
    key_to_attr = {
        "utils": "utils",
        "security": "security",
        "streaming": "streaming",
        "ticket": "ticket",
        "messaging": "messaging",
        "billing": "billing",
        "sharepoint": "sharepoint",
        "chat": "chat",
        "extraction": "extraction",
        "generation": "generation",
        "ai": "ai",
        "web": "web",
        "neutralization": "neutralization",
    }
    attr = key_to_attr.get(key)
    if attr and hasattr(legacy_hub, attr):
        return getattr(legacy_hub, attr)
    return None


def resolve(
    key: str,
    context: ServiceCenterContext,
    cache: Dict[str, Any],
    resolving: Set[str],
    legacy_hub: Optional[Any] = None,
) -> Any:
    """
    Resolve a service by key. Uses cache, resolves dependencies recursively.
    Falls back to legacy_hub if service module cannot be loaded.
    Raises KeyError if the service is not registered.
    """
    cache_key = f"{_make_context_id(context)}_{key}"
    if cache_key in cache:

@ -82,12 +46,10 @@ def resolve(
        raise RuntimeError(f"Circular dependency detected for service: {key}")

    def get_service(dep_key: str) -> Any:
        return resolve(dep_key, context, cache, resolving, legacy_hub)
        return resolve(dep_key, context, cache, resolving)

    # Try core first
    if key in CORE_SERVICES:
        spec = CORE_SERVICES[key]
        try:
    spec = CORE_SERVICES.get(key) or IMPORTABLE_SERVICES.get(key)
    if spec:
        cls = _load_service_class(spec["module"], spec["class"])
        resolving.add(key)
        try:

@ -98,43 +60,6 @@ def resolve(
        instance = cls(context, get_service)
        cache[cache_key] = instance
        return instance
        except (ImportError, ModuleNotFoundError, AttributeError) as e:
            logger.debug(f"Could not load core service '{key}' from service center: {e}")
            if legacy_hub:
                fallback = _get_from_legacy(legacy_hub, key)
                if fallback is not None:
                    cache[cache_key] = fallback
                    return fallback
            raise

    # Try importable
    if key in IMPORTABLE_SERVICES:
        spec = IMPORTABLE_SERVICES[key]
        try:
            cls = _load_service_class(spec["module"], spec["class"])
            resolving.add(key)
            try:
                for dep in spec.get("dependencies", []):
                    get_service(dep)
            finally:
                resolving.discard(key)
            instance = cls(context, get_service)
            cache[cache_key] = instance
            return instance
        except (ImportError, ModuleNotFoundError, AttributeError) as e:
            logger.debug(f"Could not load importable service '{key}' from service center: {e}")
            if legacy_hub:
                fallback = _get_from_legacy(legacy_hub, key)
                if fallback is not None:
                    cache[cache_key] = fallback
                    return fallback
            raise

    if legacy_hub:
        fallback = _get_from_legacy(legacy_hub, key)
        if fallback is not None:
            cache[cache_key] = fallback
            return fallback

    raise KeyError(f"Unknown service: {key}")
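
Resolution after the legacy-fallback removal, sketched (illustrative, not part of the commit; the ServiceCenterContext constructor arguments are assumptions based on the docstring above):

ctx = ServiceCenterContext(user=currentUser, mandate_id="m1", feature_instance_id="fi1", workflow=None)
web = getService("web", ctx)      # spec found in CORE_SERVICES or IMPORTABLE_SERVICES
agent = getService("agent", ctx)  # dependencies ("ai", "chat", ...) resolve recursively first
getService("nope", ctx)           # now raises KeyError instead of probing a legacy hub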

3 modules/serviceCenter/services/serviceAgent/__init__.py Normal file

@ -0,0 +1,3 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""serviceAgent: AI Agent with ReAct loop and native function calling."""

162 modules/serviceCenter/services/serviceAgent/actionToolAdapter.py Normal file

@ -0,0 +1,162 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""ActionToolAdapter: wraps existing workflow actions (dynamicMode=True) as agent tools."""

import logging
from typing import Dict, Any, List, Optional

from modules.serviceCenter.services.serviceAgent.datamodelAgent import (
    ToolDefinition, ToolResult
)
from modules.serviceCenter.services.serviceAgent.toolRegistry import ToolRegistry

logger = logging.getLogger(__name__)


class ActionToolAdapter:
    """Wraps existing workflow actions as agent tools.

    Iterates over discovered methods, finds actions with dynamicMode=True,
    and registers them in the ToolRegistry with a compound name (method_action).
    """

    def __init__(self, actionExecutor):
        self._actionExecutor = actionExecutor
        self._registeredTools: List[str] = []

    def registerAll(self, toolRegistry: ToolRegistry):
        """Discover and register all dynamicMode actions as agent tools."""
        from modules.workflows.processing.shared.methodDiscovery import methods

        registered = 0
        for methodName, methodInfo in methods.items():
            if not methodName[0].isupper():
                continue

            shortName = methodName.replace("Method", "").lower()
            methodInstance = methodInfo["instance"]

            for actionName, actionInfo in methodInfo["actions"].items():
                actionDef = methodInstance._actions.get(actionName)
                if not actionDef or not getattr(actionDef, "dynamicMode", False):
                    continue

                compoundName = f"{shortName}_{actionName}"
                toolDef = _buildToolDefinition(compoundName, actionDef, actionInfo)

                handler = _createDispatchHandler(self._actionExecutor, shortName, actionName)
                toolRegistry.registerFromDefinition(toolDef, handler)
                self._registeredTools.append(compoundName)
                registered += 1

        logger.info(f"ActionToolAdapter: registered {registered} tools from workflow actions")

    @property
    def registeredTools(self) -> List[str]:
        """Names of all tools registered by this adapter."""
        return list(self._registeredTools)


def _buildToolDefinition(compoundName: str, actionDef, actionInfo: Dict[str, Any]) -> ToolDefinition:
    """Build a ToolDefinition from a WorkflowActionDefinition."""
    parameters = _convertParameterSchema(actionInfo.get("parameters", {}))

    return ToolDefinition(
        name=compoundName,
        description=actionDef.description or actionInfo.get("description", ""),
        parameters=parameters,
        readOnly=False
    )


def _convertParameterSchema(actionParams: Dict[str, Any]) -> Dict[str, Any]:
    """Convert workflow action parameter schema to JSON Schema for tool definitions."""
    properties = {}
    required = []

    for paramName, paramInfo in actionParams.items():
        paramType = paramInfo.get("type", "str") if isinstance(paramInfo, dict) else "str"
        paramDesc = paramInfo.get("description", "") if isinstance(paramInfo, dict) else ""
        paramRequired = paramInfo.get("required", False) if isinstance(paramInfo, dict) else False

        jsonType = _pythonTypeToJsonType(paramType)
        properties[paramName] = {
            "type": jsonType,
            "description": paramDesc
        }

        if paramRequired:
            required.append(paramName)

    return {
        "type": "object",
        "properties": properties,
        "required": required
    }


def _pythonTypeToJsonType(pythonType: str) -> str:
    """Map Python type strings to JSON Schema types."""
    mapping = {
        "str": "string",
        "int": "integer",
        "float": "number",
        "bool": "boolean",
        "list": "array",
        "dict": "object",
        "List[str]": "array",
        "List[int]": "array",
        "List[dict]": "array",
        "Dict[str, Any]": "object",
    }
    return mapping.get(pythonType, "string")


def _createDispatchHandler(actionExecutor, methodName: str, actionName: str):
    """Create an async handler that dispatches to the ActionExecutor."""
    async def _handler(args: Dict[str, Any], context: Dict[str, Any]) -> ToolResult:
        try:
            result = await actionExecutor.executeAction(methodName, actionName, args)
            data = _formatActionResult(result)
            return ToolResult(
                toolCallId="",
                toolName=f"{methodName}_{actionName}",
                success=result.success,
                data=data,
                error=result.error
            )
        except Exception as e:
            logger.error(f"ActionToolAdapter dispatch failed for {methodName}_{actionName}: {e}")
            return ToolResult(
                toolCallId="",
                toolName=f"{methodName}_{actionName}",
                success=False,
                error=str(e)
            )
    return _handler


def _formatActionResult(result) -> str:
    """Format an ActionResult into a text representation for the agent."""
    parts = []

    if result.resultLabel:
        parts.append(f"Result: {result.resultLabel}")

    if result.error:
        parts.append(f"Error: {result.error}")

    if result.documents:
        parts.append(f"Documents ({len(result.documents)}):")
        for doc in result.documents:
            docName = getattr(doc, "documentName", "unnamed")
            docType = getattr(doc, "mimeType", "unknown")
            parts.append(f"  - {docName} ({docType})")
            docData = getattr(doc, "documentData", None)
            if docData and isinstance(docData, str) and len(docData) < 2000:
                parts.append(f"    Content: {docData[:2000]}")

    if not parts:
        parts.append("Action completed successfully." if result.success else "Action failed.")

    return "\n".join(parts)

507 modules/serviceCenter/services/serviceAgent/agentLoop.py Normal file

@ -0,0 +1,507 @@
# Copyright (c) 2025 Patrick Motsch
|
||||
# All rights reserved.
|
||||
"""Agent loop: ReAct pattern with native function calling, budget control, and error handling."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
import json
|
||||
import re
|
||||
from typing import List, Dict, Any, Optional, AsyncGenerator, Callable, Awaitable
|
||||
|
||||
from modules.datamodels.datamodelAi import (
|
||||
AiCallRequest, AiCallOptions, AiCallResponse, OperationTypeEnum
|
||||
)
|
||||
from modules.serviceCenter.services.serviceAgent.datamodelAgent import (
|
||||
AgentState, AgentStatusEnum, AgentConfig, AgentEvent, AgentEventTypeEnum,
|
||||
ToolCallRequest, ToolResult, ToolCallLog, AgentRoundLog, AgentTrace
|
||||
)
|
||||
from modules.serviceCenter.services.serviceAgent.toolRegistry import ToolRegistry
|
||||
from modules.serviceCenter.services.serviceAgent.conversationManager import (
|
||||
ConversationManager, buildSystemPrompt
|
||||
)
|
||||
from modules.shared.timeUtils import getUtcTimestamp
|
||||
from modules.shared.jsonUtils import closeJsonStructures
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def runAgentLoop(
|
||||
prompt: str,
|
||||
toolRegistry: ToolRegistry,
|
||||
config: AgentConfig,
|
||||
aiCallFn: Callable[[AiCallRequest], Awaitable[AiCallResponse]],
|
||||
getWorkflowCostFn: Callable[[], Awaitable[float]],
|
||||
workflowId: str,
|
||||
userId: str = "",
|
||||
featureInstanceId: str = "",
|
||||
buildRagContextFn: Callable[..., Awaitable[str]] = None,
|
||||
mandateId: str = "",
|
||||
aiCallStreamFn: Callable = None,
|
||||
userLanguage: str = "",
|
||||
conversationHistory: List[Dict[str, Any]] = None,
|
||||
) -> AsyncGenerator[AgentEvent, None]:
|
||||
"""Run the agent loop. Yields AgentEvent for each step (SSE-ready).
|
||||
|
||||
Args:
|
||||
prompt: User prompt
|
||||
toolRegistry: Registry with available tools
|
||||
config: Agent configuration (maxRounds, maxCostCHF, etc.)
|
||||
aiCallFn: Function to call the AI (wraps serviceAi.callAi with billing)
|
||||
getWorkflowCostFn: Function to get current workflow cost
|
||||
workflowId: Workflow ID for tracking
|
||||
userId: User ID for tracing
|
||||
featureInstanceId: Feature instance ID for tracing
|
||||
buildRagContextFn: Optional async function to build RAG context before each round
|
||||
mandateId: Mandate ID for RAG scoping
|
||||
userLanguage: ISO 639-1 language code for agent responses
|
||||
conversationHistory: Prior messages [{role, content/message}] for follow-up context
|
||||
"""
|
||||
state = AgentState(workflowId=workflowId, maxRounds=config.maxRounds)
|
||||
trace = AgentTrace(
|
||||
workflowId=workflowId, userId=userId,
|
||||
featureInstanceId=featureInstanceId
|
||||
)
|
||||
|
||||
tools = toolRegistry.getTools()
|
||||
toolDefinitions = toolRegistry.formatToolsForFunctionCalling()
|
||||
|
||||
# Text-based tool descriptions are ONLY used as fallback when native function
|
||||
# calling is unavailable. Including both creates conflicting instructions
|
||||
# (text ```tool_call format vs native tool_use blocks) and can cause the model
|
||||
# to respond with plain text instead of actual tool calls.
|
||||
toolsText = "" if toolDefinitions else toolRegistry.formatToolsForPrompt()
|
||||
|
||||
systemPrompt = buildSystemPrompt(tools, toolsText, userLanguage=userLanguage)
|
||||
conversation = ConversationManager(systemPrompt)
|
||||
if conversationHistory:
|
||||
conversation.loadHistory(conversationHistory)
|
||||
conversation.addUserMessage(prompt)
|
||||
|
||||
while state.status == AgentStatusEnum.RUNNING and state.currentRound < state.maxRounds:
|
||||
await asyncio.sleep(0)
|
||||
state.currentRound += 1
|
||||
roundStartTime = time.time()
|
||||
roundLog = AgentRoundLog(roundNumber=state.currentRound)
|
||||
|
||||
# RAG context injection (before each round for fresh relevance)
|
||||
if buildRagContextFn:
|
||||
try:
|
||||
latestUserMsg = ""
|
||||
for msg in reversed(conversation.messages):
|
||||
if msg.get("role") == "user":
|
||||
latestUserMsg = msg.get("content", "")
|
||||
break
|
||||
ragContext = await buildRagContextFn(
|
||||
currentPrompt=latestUserMsg or prompt,
|
||||
workflowId=workflowId,
|
||||
userId=userId,
|
||||
featureInstanceId=featureInstanceId,
|
||||
mandateId=mandateId,
|
||||
)
|
||||
if ragContext:
|
||||
conversation.injectRagContext(ragContext)
|
||||
except Exception as ragErr:
|
||||
logger.warning(f"RAG context injection failed (non-blocking): {ragErr}")
|
||||
|
||||
# Budget check
|
||||
budgetExceeded = await _checkBudget(config, getWorkflowCostFn)
|
||||
if budgetExceeded:
|
||||
state.status = AgentStatusEnum.BUDGET_EXCEEDED
|
||||
state.abortReason = "Workflow cost budget exceeded"
|
||||
yield AgentEvent(
|
||||
type=AgentEventTypeEnum.FINAL,
|
||||
content=_buildProgressSummary(state, "Budget exceeded. Here is the progress so far.")
|
||||
)
|
||||
break
|
||||
|
||||
logger.info(f"Agent round {state.currentRound}/{state.maxRounds} for workflow {workflowId} (tools={state.totalToolCalls}, cost={state.totalCostCHF:.4f})")
|
||||
yield AgentEvent(
|
||||
type=AgentEventTypeEnum.AGENT_PROGRESS,
|
||||
data={
|
||||
"round": state.currentRound,
|
||||
"maxRounds": state.maxRounds,
|
||||
"totalAiCalls": state.totalAiCalls,
|
||||
"totalToolCalls": state.totalToolCalls,
|
||||
"costCHF": state.totalCostCHF
|
||||
}
|
||||
)
|
||||
|
||||
# Progressive summarization
|
||||
if conversation.needsSummarization(state.currentRound):
|
||||
async def _summarizeCall(summaryPrompt: str) -> str:
|
||||
req = AiCallRequest(
|
||||
prompt=summaryPrompt,
|
||||
options=AiCallOptions(operationType=OperationTypeEnum.DATA_ANALYSE)
|
||||
)
|
||||
resp = await aiCallFn(req)
|
||||
state.totalCostCHF += resp.priceCHF
|
||||
state.totalAiCalls += 1
|
||||
return resp.content
|
||||
|
||||
await conversation.summarize(state.currentRound, _summarizeCall)
|
||||
|
||||
# AI call
|
||||
aiRequest = AiCallRequest(
|
||||
prompt="",
|
||||
options=AiCallOptions(
|
||||
operationType=OperationTypeEnum.AGENT,
|
||||
temperature=config.temperature
|
||||
),
|
||||
messages=conversation.messages,
|
||||
tools=toolDefinitions
|
||||
)
|
||||
|
||||
try:
|
||||
aiResponse = None
|
||||
streamedText = ""
|
||||
isFirstChunkOfRound = True
|
||||
|
||||
if aiCallStreamFn:
|
||||
async for chunk in aiCallStreamFn(aiRequest):
|
||||
if isinstance(chunk, str):
|
||||
if isFirstChunkOfRound and state.currentRound > 1:
|
||||
chunk = "\n\n" + chunk
|
||||
isFirstChunkOfRound = False
|
||||
elif isFirstChunkOfRound:
|
||||
isFirstChunkOfRound = False
|
||||
streamedText += chunk
|
||||
yield AgentEvent(type=AgentEventTypeEnum.CHUNK, content=chunk)
|
||||
else:
|
||||
aiResponse = chunk
|
||||
|
||||
if aiResponse is None:
|
||||
raise RuntimeError("Stream ended without final AiCallResponse")
|
||||
else:
|
||||
aiResponse = await aiCallFn(aiRequest)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"AI call failed in round {state.currentRound}: {e}", exc_info=True)
|
||||
state.status = AgentStatusEnum.ERROR
|
||||
state.abortReason = f"AI call error: {e}"
|
||||
yield AgentEvent(type=AgentEventTypeEnum.ERROR, content=str(e))
|
||||
break
|
||||
|
||||
state.totalAiCalls += 1
|
||||
state.totalCostCHF += aiResponse.priceCHF
|
||||
state.totalProcessingTime += aiResponse.processingTime
|
||||
roundLog.aiModel = aiResponse.modelName
|
||||
roundLog.costCHF = aiResponse.priceCHF
|
||||
|
||||
if aiResponse.errorCount > 0:
|
||||
state.status = AgentStatusEnum.ERROR
|
||||
state.abortReason = f"AI returned error: {aiResponse.content}"
|
||||
yield AgentEvent(type=AgentEventTypeEnum.ERROR, content=aiResponse.content)
|
||||
break
|
||||
|
||||
# Parse response for tool calls
|
||||
toolCalls = _parseToolCalls(aiResponse)
|
||||
textContent = _extractTextContent(aiResponse)
|
||||
|
||||
logger.debug(
|
||||
f"Round {state.currentRound} AI response: model={aiResponse.modelName}, "
|
||||
f"toolCalls={len(toolCalls)}, nativeToolCalls={'yes' if aiResponse.toolCalls else 'no'}, "
|
||||
f"contentLen={len(aiResponse.content)}, streamedLen={len(streamedText)}"
|
||||
)
|
||||
|
||||
# Empty response (no content, no tool calls) = model returned nothing useful.
|
||||
# Burn the round but let the loop continue so the next iteration can retry
|
||||
# (the failover mechanism in the AI layer will try alternative models).
|
||||
if not toolCalls and not textContent and not streamedText:
|
||||
logger.warning(
|
||||
f"Round {state.currentRound}: AI returned empty response "
|
||||
f"(model={aiResponse.modelName}). Retrying next round."
|
||||
)
|
||||
conversation.addUserMessage(
|
||||
"Your previous response was empty. Please use the available tools "
|
||||
"to accomplish the task. Start by planning the steps, then call the "
|
||||
"appropriate tools."
|
||||
)
|
||||
roundLog.durationMs = int((time.time() - roundStartTime) * 1000)
|
||||
trace.rounds.append(roundLog)
|
||||
continue
|
||||
|
||||
if textContent and not streamedText:
|
||||
yield AgentEvent(type=AgentEventTypeEnum.MESSAGE, content=textContent)
|
||||
|
||||
if not toolCalls:
|
||||
state.status = AgentStatusEnum.COMPLETED
|
||||
conversation.addAssistantMessage(aiResponse.content)
|
||||
roundLog.durationMs = int((time.time() - roundStartTime) * 1000)
|
||||
trace.rounds.append(roundLog)
|
||||
yield AgentEvent(type=AgentEventTypeEnum.FINAL, content=textContent or aiResponse.content)
|
||||
break
|
||||
|
||||
# Add assistant message with tool calls to conversation
|
||||
assistantToolCalls = _formatAssistantToolCalls(toolCalls)
|
||||
conversation.addAssistantMessage(textContent or "", assistantToolCalls)
|
||||
|
||||
# Execute tool calls
|
||||
for tc in toolCalls:
|
||||
yield AgentEvent(
|
||||
type=AgentEventTypeEnum.TOOL_CALL,
|
||||
data={"toolName": tc.name, "args": tc.args}
|
||||
)
|
||||
|
||||
results = await _executeToolCalls(toolCalls, toolRegistry, {
|
||||
"workflowId": workflowId,
|
||||
"userId": userId,
|
||||
"featureInstanceId": featureInstanceId,
|
||||
"mandateId": mandateId,
|
||||
})
|
||||
state.totalToolCalls += len(results)
|
||||
|
||||
for result in results:
|
||||
roundLog.toolCalls.append(ToolCallLog(
|
||||
toolName=result.toolName,
|
||||
args=next((tc.args for tc in toolCalls if tc.id == result.toolCallId), {}),
|
||||
success=result.success,
|
||||
durationMs=result.durationMs,
|
||||
error=result.error,
|
||||
resultData=result.data[:300] if result.data else "",
|
||||
))
|
||||
if not result.success:
|
||||
logger.warning(f"Tool '{result.toolName}' failed: {result.error}")
|
||||
yield AgentEvent(
|
||||
type=AgentEventTypeEnum.TOOL_RESULT,
|
||||
data={
|
||||
"toolName": result.toolName,
|
||||
"success": result.success,
|
||||
"data": result.data[:500] if result.data else "",
|
||||
"error": result.error
|
||||
}
|
||||
)
|
||||
if result.sideEvents:
|
||||
for sideEvt in result.sideEvents:
|
||||
evtType = sideEvt.get("type", "")
|
||||
try:
|
||||
evtEnum = AgentEventTypeEnum(evtType)
|
||||
except (ValueError, KeyError):
|
||||
continue
|
||||
yield AgentEvent(
|
||||
type=evtEnum,
|
||||
data=sideEvt.get("data"),
|
||||
content=sideEvt.get("content"),
|
||||
)
|
||||
|
||||
# Add tool results to conversation
|
||||
toolResultMessages = [
|
||||
{"toolCallId": r.toolCallId, "toolName": r.toolName,
|
||||
"content": r.data if r.success else f"Error: {r.error}"}
|
||||
for r in results
|
||||
]
|
||||
conversation.addToolResults(toolResultMessages)
|
||||
|
||||
roundLog.durationMs = int((time.time() - roundStartTime) * 1000)
|
||||
trace.rounds.append(roundLog)
|
||||
|
||||
# maxRounds reached
|
||||
if state.currentRound >= state.maxRounds and state.status == AgentStatusEnum.RUNNING:
|
||||
state.status = AgentStatusEnum.MAX_ROUNDS_REACHED
|
||||
state.abortReason = f"Maximum rounds ({state.maxRounds}) reached"
|
||||
yield AgentEvent(
|
||||
type=AgentEventTypeEnum.FINAL,
|
||||
content=_buildProgressSummary(state, "Maximum rounds reached.")
|
||||
)
|
||||
|
||||
# Agent summary
|
||||
trace.completedAt = getUtcTimestamp()
|
||||
trace.status = state.status
|
||||
trace.totalRounds = state.currentRound
|
||||
trace.totalToolCalls = state.totalToolCalls
|
||||
trace.totalCostCHF = state.totalCostCHF
|
||||
trace.abortReason = state.abortReason
|
||||
|
||||
artifactSummary = _buildArtifactSummary(trace.rounds)
|
||||
|
||||
yield AgentEvent(
|
||||
type=AgentEventTypeEnum.AGENT_SUMMARY,
|
||||
data={
|
||||
"rounds": state.currentRound,
|
||||
"totalAiCalls": state.totalAiCalls,
|
||||
"totalToolCalls": state.totalToolCalls,
|
||||
"costCHF": round(state.totalCostCHF, 4),
|
||||
"processingTime": round(state.totalProcessingTime, 2),
|
||||
"status": state.status.value,
|
||||
"abortReason": state.abortReason,
|
||||
"artifacts": artifactSummary,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def _checkBudget(config: AgentConfig,
|
||||
getWorkflowCostFn: Callable[[], Awaitable[float]]) -> bool:
|
||||
"""Check if workflow budget is exceeded. Returns True if exceeded."""
|
||||
if config.maxCostCHF is None:
|
||||
return False
|
||||
try:
|
||||
currentCost = await getWorkflowCostFn()
|
||||
return currentCost > config.maxCostCHF
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not check workflow cost: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def _executeToolCalls(toolCalls: List[ToolCallRequest],
|
||||
toolRegistry: ToolRegistry,
|
||||
context: Dict[str, Any]) -> List[ToolResult]:
|
||||
"""Execute tool calls: readOnly tools in parallel, others sequentially.
|
||||
|
||||
Tool calls with _parseError (truncated JSON from LLM) are short-circuited
|
||||
with an error result so the agent can retry.
|
||||
"""
|
||||
readOnlyCalls = [tc for tc in toolCalls if toolRegistry.isReadOnly(tc.name)]
|
||||
writeCalls = [tc for tc in toolCalls if not toolRegistry.isReadOnly(tc.name)]
|
||||
|
||||
results: Dict[str, ToolResult] = {}
|
||||
|
||||
for tc in toolCalls:
|
||||
if "_parseError" in tc.args:
|
||||
results[tc.id] = ToolResult(
|
||||
toolCallId=tc.id,
|
||||
toolName=tc.name,
|
||||
success=False,
|
||||
data="",
|
||||
error=tc.args["_parseError"],
|
||||
durationMs=0,
|
||||
)
|
||||
|
||||
    activeCalls = [tc for tc in toolCalls if tc.id not in results]
    activeReadOnly = [tc for tc in activeCalls if toolRegistry.isReadOnly(tc.name)]
    activeWrite = [tc for tc in activeCalls if not toolRegistry.isReadOnly(tc.name)]

    if activeReadOnly:
        readResults = await asyncio.gather(*[
            toolRegistry.dispatch(tc, context) for tc in activeReadOnly
        ])
        for tc, result in zip(activeReadOnly, readResults):
            results[tc.id] = result

    for tc in activeWrite:
        results[tc.id] = await toolRegistry.dispatch(tc, context)

    return [results[tc.id] for tc in toolCalls]


def _repairTruncatedJson(raw: str) -> Optional[Dict[str, Any]]:
    """Repair truncated JSON using the shared jsonUtils toolbox.

    Uses closeJsonStructures, which handles open strings, brackets, braces,
    and trailing commas with stack-based structure tracking.
    Returns the parsed dict on success, None if unrecoverable.
    """
    if not raw or not raw.strip().startswith("{"):
        return None
    try:
        closed = closeJsonStructures(raw)
        return json.loads(closed)
    except Exception:  # closeJsonStructures or json.loads failing means unrecoverable
        return None


def _parseToolCalls(aiResponse: AiCallResponse) -> List[ToolCallRequest]:
    """Parse tool calls from an AI response. Supports native function calling and a text-based fallback."""
    toolCalls = []

    # Native function calling: check response metadata
    if hasattr(aiResponse, 'toolCalls') and aiResponse.toolCalls:
        for tc in aiResponse.toolCalls:
            rawArgs = tc["function"]["arguments"]
            if isinstance(rawArgs, str):
                rawArgs = rawArgs.strip()
                try:
                    parsedArgs = json.loads(rawArgs) if rawArgs else {}
                except json.JSONDecodeError:
                    parsedArgs = _repairTruncatedJson(rawArgs)
                    if parsedArgs is None:
                        logger.warning(f"Unrecoverable truncated JSON for '{tc['function']['name']}': {rawArgs[:200]}")
                        parsedArgs = {"_parseError": (
                            "Your tool call arguments were truncated (output cut off by token limit). "
                            "The content is too large for a single tool call. Strategies:\n"
                            "1. For new files: use writeFile(mode='create') with the first part, "
                            "then writeFile(fileId=..., mode='append') for subsequent parts (~8000 chars each).\n"
                            "2. For editing existing files: use replaceInFile to change only the specific parts.\n"
                            "3. For documentation: split into multiple smaller files."
                        )}
                    else:
                        logger.info(f"Repaired truncated JSON for '{tc['function']['name']}'")
            else:
                parsedArgs = rawArgs if rawArgs else {}
            toolCalls.append(ToolCallRequest(
                id=tc.get("id", str(len(toolCalls))),
                name=tc["function"]["name"],
                args=parsedArgs,
            ))
        return toolCalls

    # Text-based fallback: parse ```tool_call blocks
    content = aiResponse.content or ""
    pattern = r"```tool_call\s*\n\s*tool:\s*(\S+)\s*\n\s*args:\s*(\{.*?\})\s*\n\s*```"
    matches = re.finditer(pattern, content, re.DOTALL)

    for match in matches:
        toolName = match.group(1).strip()
        argsStr = match.group(2).strip()
        try:
            args = json.loads(argsStr)
        except json.JSONDecodeError:
            logger.warning(f"Failed to parse tool args for '{toolName}': {argsStr}")
            args = {}
        toolCalls.append(ToolCallRequest(name=toolName, args=args))

    return toolCalls


def _extractTextContent(aiResponse: AiCallResponse) -> str:
    """Extract text content from an AI response, removing tool_call blocks."""
    content = aiResponse.content or ""
    cleaned = re.sub(r"```tool_call\s*\n.*?\n\s*```", "", content, flags=re.DOTALL)
    return cleaned.strip()


def _formatAssistantToolCalls(toolCalls: List[ToolCallRequest]) -> List[Dict[str, Any]]:
    """Format tool calls for the conversation history (OpenAI tool_calls format)."""
    return [
        {
            "id": tc.id,
            "type": "function",
            "function": {
                "name": tc.name,
                "arguments": json.dumps(tc.args)
            }
        }
        for tc in toolCalls
    ]


def _buildProgressSummary(state: AgentState, reason: str) -> str:
    """Build a human-readable summary of agent progress for graceful termination."""
    return (
        f"{reason}\n\n"
        f"Progress after {state.currentRound} rounds:\n"
        f"- AI calls: {state.totalAiCalls}\n"
        f"- Tool calls: {state.totalToolCalls}\n"
        f"- Cost: {state.totalCostCHF:.4f} CHF\n"
        f"- Processing time: {state.totalProcessingTime:.1f}s"
    )


_ARTIFACT_TOOLS = {"writeFile", "replaceInFile", "deleteFile", "renameFile", "copyFile",
                   "createFolder", "deleteFolder", "renderDocument", "generateImage"}


def _buildArtifactSummary(roundLogs: List[AgentRoundLog]) -> str:
    """Extract file operations and key results from all agent rounds.

    Produces a concise summary persisted as _workflowArtifacts so
    follow-up rounds have immediate context (file IDs, names, actions).
    """
    ops = []
    for log in roundLogs:
        for tc in log.toolCalls:
            if tc.toolName not in _ARTIFACT_TOOLS or not tc.success:
                continue
            ops.append(f"- {tc.resultData}" if tc.resultData else f"- {tc.toolName}")

    if not ops:
        return ""
    return "File operations in this run:\n" + "\n".join(ops)
@@ -0,0 +1,331 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Conversation manager for the Agent service.

Handles message history, context window management, and progressive summarization.
"""

import logging
from typing import List, Dict, Any, Optional

from modules.serviceCenter.services.serviceAgent.datamodelAgent import ToolDefinition

logger = logging.getLogger(__name__)

FIRST_SUMMARY_ROUND = 4
META_SUMMARY_ROUND = 7
KEEP_RECENT_MESSAGES = 4
MAX_ESTIMATED_TOKENS = 60000
_MAX_HISTORY_MESSAGES = 40
_MAX_HISTORY_MSG_CHARS = 12000


class ConversationManager:
    """Manages the conversation history and context window for agent runs.

    Progressive summarization strategy:
    - Rounds 1-3: full conversation retained
    - Round 4+: older messages compressed into a running summary
    - Round 7+: meta-summary replaces prior summaries
    Supports RAG context injection before each round via injectRagContext.
    """

    def __init__(self, systemPrompt: str):
        self._messages: List[Dict[str, Any]] = [
            {"role": "system", "content": systemPrompt}
        ]
        self._summaries: List[Dict[str, Any]] = []
        self._lastSummarizedRound: int = 0
        self._ragContextInjected: bool = False

    def loadHistory(self, messages: List[Dict[str, Any]]):
        """Load prior conversation messages for follow-up context.

        Accepts messages in {role, content/message} format (as stored in the DB).
        Truncates long messages and limits the total count to keep the context
        window manageable. Must be called BEFORE addUserMessage with the current prompt.
        """
        if not messages:
            return

        recent = messages[-_MAX_HISTORY_MESSAGES:]
        loaded = 0
        for msg in recent:
            role = msg.get("role", "")
            content = msg.get("content", "") or msg.get("message", "") or ""
            if role not in ("user", "assistant"):
                continue
            if not content.strip():
                continue
            if len(content) > _MAX_HISTORY_MSG_CHARS:
                content = content[:_MAX_HISTORY_MSG_CHARS] + "…"
            self._messages.append({"role": role, "content": content})
            loaded += 1
        if loaded:
            logger.info(f"Loaded {loaded} history messages into conversation context")

    @property
    def messages(self) -> List[Dict[str, Any]]:
        """Current messages for the next AI call (internal markers stripped)."""
        return [
            {k: v for k, v in msg.items() if not k.startswith("_")}
            for msg in self._messages
        ]

    def addUserMessage(self, content: str):
        """Add a user message."""
        self._messages.append({"role": "user", "content": content})

    def addAssistantMessage(self, content: str, toolCalls: List[Dict[str, Any]] = None):
        """Add an assistant message, optionally with tool calls."""
        msg: Dict[str, Any] = {"role": "assistant", "content": content}
        if toolCalls:
            msg["tool_calls"] = toolCalls
        self._messages.append(msg)

    def addToolResults(self, results: List[Dict[str, Any]]):
        """Add tool results to the conversation.

        Each result: {toolCallId, toolName, content}.
        """
        for result in results:
            self._messages.append({
                "role": "tool",
                "tool_call_id": result["toolCallId"],
                "content": result["content"]
            })

    def addToolResultsAsText(self, resultText: str):
        """Add combined tool results as a user message (text-based fallback)."""
        self._messages.append({
            "role": "user",
            "content": f"Tool Results:\n{resultText}"
        })

    def injectRagContext(self, ragContext: str):
        """Inject RAG context as a system message right after the main system prompt.

        Called before each agent round by the agent loop if KnowledgeService is
        available. Replaces any previously injected RAG context to keep it fresh.
        """
        if not ragContext:
            return

        ragMessage = {
            "role": "system",
            "content": f"Relevant Knowledge (from indexed documents and workflow context):\n{ragContext}",
            "_isRagContext": True,
        }

        # Replace the existing RAG message if present, otherwise insert after the system prompt
        for i, msg in enumerate(self._messages):
            if msg.get("_isRagContext"):
                self._messages[i] = ragMessage
                self._ragContextInjected = True
                return

        # Insert after the first system prompt
        self._messages.insert(1, ragMessage)
        self._ragContextInjected = True

    def getMessageCount(self) -> int:
        """Get the number of messages (excluding the initial system prompt)."""
        return len(self._messages) - 1

    def estimateTokenCount(self) -> int:
        """Rough estimate of total tokens in the conversation (4 chars ≈ 1 token)."""
        totalChars = sum(len(str(m.get("content", ""))) for m in self._messages)
        return totalChars // 4

    def needsSummarization(self, currentRound: int) -> bool:
        """Check whether progressive summarization should be triggered.

        Triggers:
        - At round FIRST_SUMMARY_ROUND (4) if not yet summarized
        - At round META_SUMMARY_ROUND (7) for a meta-summary
        - Every 5 rounds after that
        - Whenever the estimated token count exceeds MAX_ESTIMATED_TOKENS
        """
        if currentRound >= FIRST_SUMMARY_ROUND and self._lastSummarizedRound < currentRound:
            if currentRound == FIRST_SUMMARY_ROUND or currentRound == META_SUMMARY_ROUND:
                return True
            if (currentRound - META_SUMMARY_ROUND) % 5 == 0 and currentRound > META_SUMMARY_ROUND:
                return True
        if self.estimateTokenCount() > MAX_ESTIMATED_TOKENS:
            return True
        return False

    async def summarize(self, currentRound: int, aiCallFn) -> Optional[str]:
        """Perform progressive summarization of older messages.

        Rounds 1-3: full history retained, no summarization.
        Round 4+: compress older messages into a running summary.
        Round 7+: meta-summary that consolidates prior summaries.
        """
        if currentRound < FIRST_SUMMARY_ROUND and self.estimateTokenCount() <= MAX_ESTIMATED_TOKENS:
            return None

        systemMsgs = [m for m in self._messages if m.get("role") == "system"]
        nonSystemMessages = [m for m in self._messages if m.get("role") != "system"]

        keepRecent = min(KEEP_RECENT_MESSAGES, len(nonSystemMessages))
        if len(nonSystemMessages) <= keepRecent + 1:
            return None

        splitIdx = len(nonSystemMessages) - keepRecent
        # Ensure the split doesn't orphan tool messages from their assistant.
        # Walk backwards from splitIdx: if we're landing in the middle of a
        # tool-call sequence (assistant+tool_calls → tool → tool …), include
        # the entire sequence in recentMessages.
        while splitIdx > 0 and nonSystemMessages[splitIdx].get("role") == "tool":
            splitIdx -= 1
        # Also include the assistant message that triggered the tool calls.
        if splitIdx > 0 and splitIdx < len(nonSystemMessages) and \
                nonSystemMessages[splitIdx].get("role") == "assistant" and \
                nonSystemMessages[splitIdx].get("tool_calls"):
            pass  # splitIdx already points at the assistant; keep it in recent
        elif splitIdx == 0:
            return None  # nothing to summarize

        messagesToSummarize = nonSystemMessages[:splitIdx]
        recentMessages = nonSystemMessages[splitIdx:]

        summaryInput = _formatMessagesForSummary(messagesToSummarize)
        previousSummary = self._summaries[-1]["content"] if self._summaries else ""

        isMetaSummary = currentRound >= META_SUMMARY_ROUND and len(self._summaries) >= 2
        summaryPrompt = _buildSummaryPrompt(summaryInput, previousSummary, isMetaSummary)

        try:
            summaryText = await aiCallFn(summaryPrompt)
        except Exception as e:
            logger.error(f"Progressive summarization failed: {e}")
            return None

        self._summaries.append({
            "round": currentRound,
            "content": summaryText,
            "isMeta": isMetaSummary,
        })
        self._lastSummarizedRound = currentRound

        mainSystem = systemMsgs[0] if systemMsgs else {"role": "system", "content": ""}
        ragMessages = [m for m in systemMsgs if m.get("_isRagContext")]

        self._messages = [
            mainSystem,
            *ragMessages,
            {"role": "system", "content": f"Conversation Summary (rounds 1-{currentRound - keepRecent}):\n{summaryText}"},
            *recentMessages,
        ]

        logger.info(
            f"Progressive summarization at round {currentRound}: "
            f"compressed {len(messagesToSummarize)} messages into "
            f"{'meta-' if isMetaSummary else ''}summary"
        )
        return summaryText


def _formatMessagesForSummary(messages: List[Dict[str, Any]]) -> str:
    """Format messages into a text block for summarization."""
    parts = []
    for msg in messages:
        role = msg.get("role", "unknown")
        content = msg.get("content", "")
        if role == "tool":
            toolId = msg.get("tool_call_id", "tool")
            parts.append(f"[Tool Result ({toolId})]:\n{content}")
        elif role == "assistant" and msg.get("tool_calls"):
            calls = msg["tool_calls"]
            callNames = [c.get("function", {}).get("name", "?") for c in calls]
            parts.append(f"[Assistant → Tool Calls: {', '.join(callNames)}]")
            if content:
                parts.append(f"[Assistant]: {content}")
        else:
            parts.append(f"[{role.capitalize()}]: {content}")
    return "\n\n".join(parts)


def _buildSummaryPrompt(messagesText: str, previousSummary: str, isMetaSummary: bool = False) -> str:
    """Build the prompt for progressive summarization."""
    if isMetaSummary:
        prompt = (
            "Create a comprehensive meta-summary consolidating the previous summary "
            "and the new messages. Preserve all key facts, decisions, entities (names, "
            "numbers, dates), tool results, and action outcomes. Be concise but complete.\n\n"
        )
    else:
        prompt = (
            "Summarize the following conversation concisely. Preserve all key facts, "
            "decisions, entities (names, numbers, dates), and tool results. "
            "Do not lose any important information.\n\n"
        )
    if previousSummary:
        prompt += f"Previous Summary:\n{previousSummary}\n\n"
    prompt += f"New Messages to Summarize:\n{messagesText}\n\nProvide a concise, factual summary:"
    return prompt


_LANGUAGE_NAMES = {
    "de": "German", "en": "English", "fr": "French", "it": "Italian",
    "es": "Spanish", "pt": "Portuguese", "nl": "Dutch", "ja": "Japanese",
    "zh": "Chinese", "ko": "Korean", "ar": "Arabic", "ru": "Russian",
}


def buildSystemPrompt(
    tools: List[ToolDefinition],
    toolsFormatted: Optional[str] = None,
    userLanguage: str = "",
) -> str:
    """Build the system prompt for the agent.

    Args:
        tools: Available tool definitions.
        toolsFormatted: Pre-formatted tool descriptions for the text-based fallback.
        userLanguage: ISO 639-1 language code (e.g. "de", "en"). The agent will
            respond in this language.
    """
    langName = _LANGUAGE_NAMES.get(userLanguage, "")
    langInstruction = (
        f"IMPORTANT: Always respond in {langName} ({userLanguage}). "
        f"The user's language is {langName}. All your messages, explanations, "
        f"and summaries MUST be in {langName}. "
        f"Only use English for tool call arguments and technical identifiers.\n\n"
    ) if langName else ""

    prompt = (
        f"{langInstruction}"
        "You are an AI agent with access to tools. "
        "Use the provided tools to accomplish the user's task. "
        "Think step by step. Call tools when you need information or need to perform actions. "
        "When you have enough information to answer, respond directly without calling tools.\n\n"
    )

    prompt += (
        "## Working Guidelines\n\n"
        "### Workflow Context\n"
        "When continuing a workflow (follow-up message), the Relevant Knowledge section contains "
        "artifacts from previous rounds (file IDs, operations). Use this context instead of "
        "re-searching or re-listing files.\n\n"
        "### Efficient File Editing\n"
        "- Use readFile with offset/limit to read specific line ranges of large files.\n"
        "- Use searchInFileContent to find text before editing.\n"
        "- Use replaceInFile for targeted edits (preferred over rewriting entire files).\n"
        "- Use writeFile(mode='overwrite') only when the entire content must change.\n\n"
        "### Large Content Strategy\n"
        "- For content larger than ~8000 characters: use writeFile(mode='create') for the first "
        "part, then writeFile(fileId=..., mode='append') for subsequent parts.\n"
        "- Split large documentation into multiple focused files rather than one huge document.\n"
        "- Structure outputs so files reference each other (e.g. index.md linking to sections).\n\n"
        "### Code Generation\n"
        "- Prefer modular file structures over monolithic files.\n"
        "- When generating applications, create separate files for logical components.\n"
        "- Always plan the structure before writing code.\n\n"
    )

    if toolsFormatted:
        prompt += f"Available Tools:\n{toolsFormatted}\n\n"
        prompt += (
            "To call a tool, use this format:\n"
            "```tool_call\n"
            "tool: <tool_name>\n"
            'args: {"param": "value"}\n'
            "```\n\n"
        )
    return prompt
153 modules/serviceCenter/services/serviceAgent/datamodelAgent.py Normal file
@@ -0,0 +1,153 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Data models for the Agent service."""

from typing import List, Dict, Any, Optional
from enum import Enum
from pydantic import BaseModel, Field
from modules.shared.timeUtils import getUtcTimestamp
import uuid


class AgentStatusEnum(str, Enum):
    RUNNING = "running"
    COMPLETED = "completed"
    MAX_ROUNDS_REACHED = "maxRoundsReached"
    BUDGET_EXCEEDED = "budgetExceeded"
    ERROR = "error"
    STOPPED = "stopped"


class AgentEventTypeEnum(str, Enum):
    MESSAGE = "message"
    CHUNK = "chunk"
    TOOL_CALL = "toolCall"
    TOOL_RESULT = "toolResult"
    AGENT_PROGRESS = "agentProgress"
    AGENT_SUMMARY = "agentSummary"
    FILE_CREATED = "fileCreated"
    FILE_UPDATED = "fileUpdated"
    FILE_EDIT_PROPOSAL = "fileEditProposal"
    FILE_VERSION = "fileVersion"
    FILE_EDIT_REJECTED = "fileEditRejected"
    DATA_SOURCE_ACCESS = "dataSourceAccess"
    VOICE_RESPONSE = "voiceResponse"
    FINAL = "final"
    ERROR = "error"


class ToolDefinition(BaseModel):
    """Schema for a tool available to the agent."""
    name: str = Field(description="Unique tool name")
    description: str = Field(description="What this tool does")
    parameters: Dict[str, Any] = Field(
        default_factory=dict,
        description="JSON Schema for tool parameters"
    )
    readOnly: bool = Field(
        default=False,
        description="If True, tool can run in parallel with other readOnly tools"
    )
    featureType: Optional[str] = Field(
        default=None,
        description="Feature scope for this tool (None = available to all)"
    )
    toolSet: Optional[str] = Field(
        default=None,
        description="Tool-set scope (None = available to all sets, e.g. 'core', 'workspace')"
    )


class ToolCallRequest(BaseModel):
    """A tool call requested by the AI model."""
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    name: str
    args: Dict[str, Any] = Field(default_factory=dict)


class ToolResult(BaseModel):
    """Result from executing a tool."""
    toolCallId: str
    toolName: str
    success: bool = True
    data: str = ""
    error: Optional[str] = None
    durationMs: int = 0
    sideEvents: Optional[List[Dict[str, Any]]] = None


class AgentEvent(BaseModel):
    """Event emitted during agent execution for SSE streaming."""
    type: AgentEventTypeEnum
    content: Optional[str] = None
    data: Optional[Dict[str, Any]] = None


class AgentConfig(BaseModel):
    """Configuration for an agent run."""
    maxRounds: int = Field(default=25, ge=1, le=100)
    maxCostCHF: Optional[float] = Field(default=None, ge=0.0)
    toolSet: str = Field(default="core")
    temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0)


class AgentState(BaseModel):
    """Tracks state across an agent loop execution."""
    workflowId: str
    currentRound: int = 0
    maxRounds: int = 25
    totalAiCalls: int = 0
    totalToolCalls: int = 0
    totalCostCHF: float = 0.0
    totalProcessingTime: float = 0.0
    status: AgentStatusEnum = AgentStatusEnum.RUNNING
    abortReason: Optional[str] = None


class ToolCallLog(BaseModel):
    """Log of a single tool call for observability."""
    toolName: str
    args: Dict[str, Any] = Field(default_factory=dict)
    success: bool = True
    durationMs: int = 0
    error: Optional[str] = None
    resultData: str = Field(default="", description="Short result summary for artifact tracking")


class AgentRoundLog(BaseModel):
    """Log of a single agent round for observability."""
    roundNumber: int
    aiModel: str = ""
    inputTokens: int = 0
    outputTokens: int = 0
    costCHF: float = 0.0
    toolCalls: List[ToolCallLog] = Field(default_factory=list)
    durationMs: int = 0


class AgentTrace(BaseModel):
    """Full trace of an agent workflow for observability."""
    workflowId: str
    userId: str = ""
    featureInstanceId: str = ""
    startedAt: float = Field(default_factory=getUtcTimestamp)
    completedAt: Optional[float] = None
    status: AgentStatusEnum = AgentStatusEnum.RUNNING
    totalRounds: int = 0
    totalToolCalls: int = 0
    totalCostCHF: float = 0.0
    abortReason: Optional[str] = None
    rounds: List[AgentRoundLog] = Field(default_factory=list)


class PendingFileEdit(BaseModel):
    """A proposed file edit awaiting user approval."""
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    fileId: str
    fileName: str
    mimeType: str = ""
    oldContent: str = ""
    newContent: str = ""
    status: str = Field(default="pending", description="pending | accepted | rejected")
    toolCallId: str = ""
    workflowId: str = ""
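A quick sketch of how these models behave in practice, using only the pydantic definitions above (the argument values are made up for illustration):

```python
from pydantic import ValidationError

# Tool call requests get an auto-generated uuid4 id.
call = ToolCallRequest(name="writeFile", args={"fileName": "notes.md", "mode": "create"})
print(call.id)

# Results default to success=True.
result = ToolResult(toolCallId=call.id, toolName=call.name, data="created notes.md")
print(result.success)   # True

# Field constraints are enforced: maxRounds is bounded by ge=1, le=100.
try:
    AgentConfig(maxRounds=500)
except ValidationError as e:
    print("rejected:", e.errors()[0]["loc"])   # ('maxRounds',)
```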
253 modules/serviceCenter/services/serviceAgent/featureDataAgent.py Normal file
@@ -0,0 +1,253 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Feature Data Sub-Agent.

Specialized mini-agent that queries feature-instance data tables. Receives
schema context (fields, descriptions) for the selected tables and has two
tools: browseTable and queryTable. Runs its own agent loop (max 5 rounds,
low budget) and returns structured results back to the main agent.
"""

import json
import logging
from typing import Any, Callable, Awaitable, Dict, List, Optional

from modules.datamodels.datamodelAi import (
    AiCallRequest, AiCallOptions, AiCallResponse, OperationTypeEnum,
)
from modules.serviceCenter.services.serviceAgent.agentLoop import runAgentLoop
from modules.serviceCenter.services.serviceAgent.datamodelAgent import (
    AgentConfig, AgentEvent, AgentEventTypeEnum, ToolResult,
)
from modules.serviceCenter.services.serviceAgent.toolRegistry import ToolRegistry
from modules.serviceCenter.services.serviceAgent.featureDataProvider import FeatureDataProvider

logger = logging.getLogger(__name__)

_MAX_ROUNDS = 5
_MAX_COST_CHF = 0.10


async def runFeatureDataAgent(
    question: str,
    featureInstanceId: str,
    featureCode: str,
    selectedTables: List[Dict[str, Any]],
    mandateId: str,
    userId: str,
    aiCallFn: Callable[[AiCallRequest], Awaitable[AiCallResponse]],
    dbConnector,
    instanceLabel: str = "",
) -> str:
    """Run the feature data sub-agent and return the textual result.

    Args:
        question: The user/main-agent question to answer using feature data.
        featureInstanceId: Feature instance to scope queries.
        featureCode: Feature code (trustee, commcoach, ...).
        selectedTables: List of DATA_OBJECT dicts the user selected.
        mandateId: Mandate scope.
        userId: Calling user ID.
        aiCallFn: AI call function (with billing).
        dbConnector: DatabaseConnector for queries.
        instanceLabel: Human-readable instance name for context.

    Returns:
        Plain-text answer produced by the sub-agent.
    """

    provider = FeatureDataProvider(dbConnector)
    registry = _buildSubAgentTools(provider, featureInstanceId, mandateId)

    for tbl in selectedTables:
        meta = tbl.get("meta", {})
        tableName = meta.get("table", "")
        if tableName:
            realCols = provider.getActualColumns(tableName)
            if realCols:
                meta["fields"] = realCols

    schemaContext = _buildSchemaContext(featureCode, instanceLabel, selectedTables)
    prompt = f"{schemaContext}\n\nUser question:\n{question}"

    config = AgentConfig(maxRounds=_MAX_ROUNDS, maxCostCHF=_MAX_COST_CHF)

    async def _getWorkflowCost() -> float:
        return 0.0

    result = ""
    async for event in runAgentLoop(
        prompt=prompt,
        toolRegistry=registry,
        config=config,
        aiCallFn=aiCallFn,
        getWorkflowCostFn=_getWorkflowCost,
        workflowId=f"fda-{featureInstanceId[:8]}",
        userId=userId,
        featureInstanceId=featureInstanceId,
        mandateId=mandateId,
    ):
        if event.type == AgentEventTypeEnum.FINAL and event.content:
            result = event.content
        elif event.type == AgentEventTypeEnum.MESSAGE and event.content:
            result += event.content

    return result or "(no data returned by feature agent)"


# ------------------------------------------------------------------
# tool registration
# ------------------------------------------------------------------

def _buildSubAgentTools(
    provider: FeatureDataProvider,
    featureInstanceId: str,
    mandateId: str,
) -> ToolRegistry:
    """Register browseTable and queryTable as sub-agent tools."""
    registry = ToolRegistry()

    async def _browseTable(args: Dict[str, Any], context: Dict[str, Any]):
        tableName = args.get("tableName", "")
        limit = args.get("limit", 50)
        offset = args.get("offset", 0)
        fields = args.get("fields")
        if not tableName:
            return ToolResult(toolCallId="", toolName="browseTable", success=False, error="tableName required")
        result = provider.browseTable(
            tableName=tableName,
            featureInstanceId=featureInstanceId,
            mandateId=mandateId,
            fields=fields,
            limit=min(limit, 200),
            offset=offset,
        )
        return ToolResult(
            toolCallId="", toolName="browseTable",
            success="error" not in result,
            data=json.dumps(result, default=str, ensure_ascii=False)[:30000],
            error=result.get("error"),
        )

    async def _queryTable(args: Dict[str, Any], context: Dict[str, Any]):
        tableName = args.get("tableName", "")
        filters = args.get("filters", [])
        fields = args.get("fields")
        orderBy = args.get("orderBy")
        limit = args.get("limit", 50)
        offset = args.get("offset", 0)
        if not tableName:
            return ToolResult(toolCallId="", toolName="queryTable", success=False, error="tableName required")
        result = provider.queryTable(
            tableName=tableName,
            featureInstanceId=featureInstanceId,
            mandateId=mandateId,
            filters=filters,
            fields=fields,
            orderBy=orderBy,
            limit=min(limit, 200),
            offset=offset,
        )
        return ToolResult(
            toolCallId="", toolName="queryTable",
            success="error" not in result,
            data=json.dumps(result, default=str, ensure_ascii=False)[:30000],
            error=result.get("error"),
        )

    registry.register(
        "browseTable", _browseTable,
        description="List rows from a feature data table with pagination.",
        parameters={
            "type": "object",
            "properties": {
                "tableName": {"type": "string", "description": "Name of the table to browse"},
                "fields": {
                    "type": "array", "items": {"type": "string"},
                    "description": "Optional list of fields to return (default: all)",
                },
                "limit": {"type": "integer", "description": "Max rows to return (default 50, max 200)"},
                "offset": {"type": "integer", "description": "Row offset for pagination"},
            },
            "required": ["tableName"],
        },
        readOnly=True,
    )

    registry.register(
        "queryTable", _queryTable,
        description=(
            "Query a feature data table with filters, field selection, and ordering. "
            "Filters: [{\"field\": \"status\", \"op\": \"=\", \"value\": \"active\"}]. "
            "Operators: =, !=, >, <, >=, <=, LIKE, ILIKE, IS NULL, IS NOT NULL."
        ),
        parameters={
            "type": "object",
            "properties": {
                "tableName": {"type": "string", "description": "Name of the table to query"},
                "filters": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "field": {"type": "string"},
                            "op": {"type": "string"},
                            "value": {},
                        },
                    },
                    "description": "Filter conditions",
                },
                "fields": {
                    "type": "array", "items": {"type": "string"},
                    "description": "Optional list of fields to return",
                },
                "orderBy": {"type": "string", "description": "Field name to order by"},
                "limit": {"type": "integer", "description": "Max rows (default 50, max 200)"},
                "offset": {"type": "integer", "description": "Row offset"},
            },
            "required": ["tableName"],
        },
        readOnly=True,
    )

    return registry


# ------------------------------------------------------------------
# context building
# ------------------------------------------------------------------

def _buildSchemaContext(
    featureCode: str,
    instanceLabel: str,
    selectedTables: List[Dict[str, Any]],
) -> str:
    """Build a system-level context block describing the available tables."""
    parts = [
        f"You are a data query assistant for the '{featureCode}' feature",
    ]
    if instanceLabel:
        parts[0] += f' (instance: "{instanceLabel}")'
    parts[0] += "."
    parts.append(
        "You have access to the following data tables. "
        "Use browseTable to list rows and queryTable to filter/search."
    )
    parts.append("")

    for obj in selectedTables:
        meta = obj.get("meta", {})
        tbl = meta.get("table", "?")
        fields = meta.get("fields", [])
        label = obj.get("label", {})
        labelStr = label.get("en") or label.get("de") or tbl
        parts.append(f"Table: {tbl} ({labelStr})")
        if fields:
            parts.append(f"  Fields: {', '.join(fields)}")
        parts.append("")

    parts.append(
        "Answer the user's question using the data from these tables. "
        "Be precise, cite row counts, and format data clearly."
    )
    return "\n".join(parts)
215 modules/serviceCenter/services/serviceAgent/featureDataProvider.py Normal file
@@ -0,0 +1,215 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Generic data provider for querying feature-instance tables.

Uses the RBAC catalog's DATA_OBJECTS metadata (table name, fields) and the
DB connector to execute scoped, read-only queries against any registered
feature table. All queries are automatically filtered by featureInstanceId
and mandateId so data isolation is guaranteed.
"""

import logging
import json
from typing import Any, Dict, List, Optional

logger = logging.getLogger(__name__)

_ALLOWED_OPERATORS = {"=", "!=", ">", "<", ">=", "<=", "LIKE", "ILIKE", "IS NULL", "IS NOT NULL"}


class FeatureDataProvider:
    """Reads feature-instance data from the DB using DATA_OBJECTS metadata."""

    def __init__(self, dbConnector):
        """
        Args:
            dbConnector: A connectorDbPostgre.DatabaseConnector with an open connection.
        """
        self._db = dbConnector

    # ------------------------------------------------------------------
    # public API (called by FeatureDataAgent tools)
    # ------------------------------------------------------------------

    def getAvailableTables(self, featureCode: str) -> List[Dict[str, Any]]:
        """Return DATA_OBJECTS registered for *featureCode*."""
        from modules.security.rbacCatalog import getCatalogService
        catalog = getCatalogService()
        return catalog.getDataObjects(featureCode)

    def getTableSchema(self, featureCode: str, tableName: str) -> Optional[Dict[str, Any]]:
        """Return the DATA_OBJECT entry for a specific table."""
        for obj in self.getAvailableTables(featureCode):
            if obj.get("meta", {}).get("table") == tableName:
                return obj
        return None

    def getActualColumns(self, tableName: str) -> List[str]:
        """Read real column names from PostgreSQL information_schema."""
        try:
            conn = self._db.connection
            with conn.cursor() as cur:
                cur.execute(
                    "SELECT column_name FROM information_schema.columns "
                    "WHERE table_schema = 'public' AND LOWER(table_name) = LOWER(%s) "
                    "ORDER BY ordinal_position",
                    [tableName],
                )
                cols = [row["column_name"] for row in cur.fetchall()]
                return [c for c in cols if not c.startswith("_")]
        except Exception as e:
            logger.warning(f"getActualColumns({tableName}) failed: {e}")
            return []

    def browseTable(
        self,
        tableName: str,
        featureInstanceId: str,
        mandateId: str,
        fields: List[str] = None,
        limit: int = 50,
        offset: int = 0,
    ) -> Dict[str, Any]:
        """List rows from a feature table with pagination.

        Returns ``{"rows": [...], "total": N, "limit": L, "offset": O}``.
        """
        _validateTableName(tableName)
        scopeFilter = _buildScopeFilter(tableName, featureInstanceId, mandateId)

        try:
            conn = self._db.connection
            with conn.cursor() as cur:
                countSql = f'SELECT COUNT(*) FROM "{tableName}" WHERE {scopeFilter["where"]}'
                cur.execute(countSql, scopeFilter["params"])
                total = cur.fetchone()["count"] if cur.rowcount else 0

                selectCols = ", ".join(f'"{f}"' for f in fields) if fields else "*"
                dataSql = (
                    f'SELECT {selectCols} FROM "{tableName}" '
                    f'WHERE {scopeFilter["where"]} '
                    f'ORDER BY "id" LIMIT %s OFFSET %s'
                )
                cur.execute(dataSql, scopeFilter["params"] + [limit, offset])
                rows = [_serializeRow(dict(r)) for r in cur.fetchall()]

                return {"rows": rows, "total": total, "limit": limit, "offset": offset}
        except Exception as e:
            logger.error(f"browseTable({tableName}) failed: {e}")
            return {"rows": [], "total": 0, "limit": limit, "offset": offset, "error": str(e)}

    def queryTable(
        self,
        tableName: str,
        featureInstanceId: str,
        mandateId: str,
        filters: List[Dict[str, Any]] = None,
        fields: List[str] = None,
        orderBy: str = None,
        limit: int = 50,
        offset: int = 0,
    ) -> Dict[str, Any]:
        """Query a feature table with optional filters.

        ``filters`` is a list of ``{"field": "x", "op": "=", "value": "y"}``.
        """
        _validateTableName(tableName)
        scopeFilter = _buildScopeFilter(tableName, featureInstanceId, mandateId)
        extraWhere, extraParams = _buildFilterClauses(filters)

        fullWhere = scopeFilter["where"]
        allParams = list(scopeFilter["params"])
        if extraWhere:
            fullWhere += " AND " + extraWhere
            allParams.extend(extraParams)

        try:
            conn = self._db.connection
            with conn.cursor() as cur:
                countSql = f'SELECT COUNT(*) FROM "{tableName}" WHERE {fullWhere}'
                cur.execute(countSql, allParams)
                total = cur.fetchone()["count"] if cur.rowcount else 0

                selectCols = ", ".join(f'"{f}"' for f in fields) if fields else "*"
                orderClause = f'ORDER BY "{orderBy}"' if orderBy and _isValidIdentifier(orderBy) else 'ORDER BY "id"'
                dataSql = (
                    f'SELECT {selectCols} FROM "{tableName}" '
                    f'WHERE {fullWhere} {orderClause} LIMIT %s OFFSET %s'
                )
                cur.execute(dataSql, allParams + [limit, offset])
                rows = [_serializeRow(dict(r)) for r in cur.fetchall()]

                return {"rows": rows, "total": total, "limit": limit, "offset": offset}
        except Exception as e:
            logger.error(f"queryTable({tableName}) failed: {e}")
            return {"rows": [], "total": 0, "limit": limit, "offset": offset, "error": str(e)}


# ------------------------------------------------------------------
# helpers
# ------------------------------------------------------------------

def _validateTableName(tableName: str):
    if not tableName or not _isValidIdentifier(tableName):
        raise ValueError(f"Invalid table name: {tableName}")


def _isValidIdentifier(name: str) -> bool:
    """Allow only identifier-safe names (no quotes, spaces, or punctuation) to prevent SQL injection."""
    return name.isidentifier()


def _buildScopeFilter(tableName: str, featureInstanceId: str, mandateId: str) -> Dict[str, Any]:
    """Build the mandatory WHERE clause that scopes rows to the feature instance.

    Rows are always filtered by ``featureInstanceId``; when a ``mandateId`` is
    provided, it is added as an additional condition.
    """
    conditions = []
    params = []

    conditions.append('"featureInstanceId" = %s')
    params.append(featureInstanceId)

    if mandateId:
        conditions.append('"mandateId" = %s')
        params.append(mandateId)

    return {"where": " AND ".join(conditions), "params": params}


def _buildFilterClauses(filters: Optional[List[Dict[str, Any]]]) -> tuple:
    """Convert agent-provided filter dicts into safe SQL. Returns (whereClause, params)."""
    if not filters:
        return "", []

    parts = []
    params = []
    for f in filters:
        field = f.get("field", "")
        op = (f.get("op") or "=").upper()
        value = f.get("value")

        if not field or not _isValidIdentifier(field):
            continue
        if op not in _ALLOWED_OPERATORS:
            continue

        if op in ("IS NULL", "IS NOT NULL"):
            parts.append(f'"{field}" {op}')
        else:
            parts.append(f'"{field}" {op} %s')
            params.append(value)

    return " AND ".join(parts), params


def _serializeRow(row: Dict[str, Any]) -> Dict[str, Any]:
    """Ensure all values are JSON-serializable."""
    for k, v in row.items():
        if isinstance(v, (bytes, bytearray)):
            row[k] = f"<binary {len(v)} bytes>"
        elif hasattr(v, "isoformat"):
            row[k] = v.isoformat()
    return row
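A quick sanity check of the filter translation; _buildFilterClauses is a pure function, so it runs in isolation with the helpers above (the third filter is a deliberately malicious field name):

```python
where, params = _buildFilterClauses([
    {"field": "status", "op": "=", "value": "active"},
    {"field": "deletedAt", "op": "IS NULL"},
    {"field": "amount; DROP TABLE x", "op": ">", "value": 0},  # rejected: not an identifier
])
print(where)    # "status" = %s AND "deletedAt" IS NULL
print(params)   # ['active']
```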
2954 modules/serviceCenter/services/serviceAgent/mainServiceAgent.py Normal file
File diff suppressed because it is too large.

108 modules/serviceCenter/services/serviceAgent/sandboxExecutor.py Normal file
@@ -0,0 +1,108 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Sandboxed code execution for the AI agent executeCode tool."""

import asyncio
import builtins
import io
import logging
import signal
import sys
import threading
import traceback
from typing import Dict, Any

logger = logging.getLogger(__name__)

_PYTHON_ALLOWED_MODULES = {
    "math", "statistics", "json", "csv", "re", "datetime",
    "collections", "itertools", "functools", "decimal", "fractions",
    "random", "string", "textwrap", "operator", "copy",
}

_PYTHON_BLOCKED_BUILTINS = {
    "open", "exec", "eval", "compile", "__import__", "globals", "locals",
    "getattr", "setattr", "delattr", "breakpoint", "exit", "quit",
    "input", "memoryview", "type",
}

_MAX_EXECUTION_TIME_S = 30
_MAX_OUTPUT_CHARS = 50000


def _safeImport(name, *args, **kwargs):
    """Restricted import that only allows whitelisted modules."""
    if name not in _PYTHON_ALLOWED_MODULES:
        raise ImportError(f"Module '{name}' is not allowed. Permitted: {', '.join(sorted(_PYTHON_ALLOWED_MODULES))}")
    return builtins.__import__(name, *args, **kwargs)


def _buildRestrictedGlobals() -> Dict[str, Any]:
    """Build a restricted globals dict for exec()."""
    safeBuiltins = {}
    for name in dir(builtins):
        if name.startswith("_"):
            continue
        if name in _PYTHON_BLOCKED_BUILTINS:
            continue
        safeBuiltins[name] = getattr(builtins, name)

    safeBuiltins["__import__"] = _safeImport
    safeBuiltins["__name__"] = "__sandbox__"
    safeBuiltins["__builtins__"] = safeBuiltins

    for modName in _PYTHON_ALLOWED_MODULES:
        try:
            safeBuiltins[modName] = __import__(modName)
        except ImportError:
            pass

    return {"__builtins__": safeBuiltins}


def _raiseTimeout(*_):
    raise TimeoutError("Execution timed out")


async def executePython(code: str) -> Dict[str, Any]:
    """Execute Python code in a restricted sandbox. Returns {success, output, error}."""

    def _run():
        restrictedGlobals = _buildRestrictedGlobals()
        capturedOutput = io.StringIO()
        oldStdout = sys.stdout
        oldStderr = sys.stderr

        # SIGALRM is unavailable on Windows and only works in the main thread;
        # when _run executes in a worker thread, the outer asyncio.wait_for
        # below is the effective timeout.
        useAlarm = sys.platform != "win32" and threading.current_thread() is threading.main_thread()

        try:
            sys.stdout = capturedOutput
            sys.stderr = capturedOutput

            if useAlarm:
                signal.signal(signal.SIGALRM, _raiseTimeout)
                signal.alarm(_MAX_EXECUTION_TIME_S)

            exec(compile(code, "<sandbox>", "exec"), restrictedGlobals)

            if useAlarm:
                signal.alarm(0)

            output = capturedOutput.getvalue()
            if len(output) > _MAX_OUTPUT_CHARS:
                output = output[:_MAX_OUTPUT_CHARS] + f"\n... (truncated at {_MAX_OUTPUT_CHARS} chars)"
            return {"success": True, "output": output}

        except TimeoutError:
            return {"success": False, "error": f"Execution timed out after {_MAX_EXECUTION_TIME_S}s"}
        except Exception as e:
            tb = traceback.format_exc()
            return {"success": False, "error": f"{type(e).__name__}: {e}", "traceback": tb}
        finally:
            sys.stdout = oldStdout
            sys.stderr = oldStderr
            if useAlarm:
                signal.alarm(0)

    loop = asyncio.get_running_loop()
    try:
        return await asyncio.wait_for(
            loop.run_in_executor(None, _run),
            timeout=_MAX_EXECUTION_TIME_S + 5,
        )
    except asyncio.TimeoutError:
        return {"success": False, "error": f"Execution timed out after {_MAX_EXECUTION_TIME_S}s"}
154 modules/serviceCenter/services/serviceAgent/toolRegistry.py Normal file
@@ -0,0 +1,154 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Tool registry for the Agent service. Manages tool definitions and dispatch."""

import logging
import time
from typing import Dict, List, Any, Optional, Callable, Awaitable

from modules.serviceCenter.services.serviceAgent.datamodelAgent import (
    ToolDefinition, ToolCallRequest, ToolResult
)

logger = logging.getLogger(__name__)


class ToolRegistry:
    """Registry for agent tools. Handles registration, lookup, and dispatch."""

    def __init__(self):
        self._tools: Dict[str, ToolDefinition] = {}
        self._handlers: Dict[str, Callable[..., Awaitable[ToolResult]]] = {}

    def register(self, name: str, handler: Callable[..., Awaitable[ToolResult]],
                 description: str = "", parameters: Optional[Dict[str, Any]] = None,
                 readOnly: bool = False, featureType: Optional[str] = None,
                 toolSet: Optional[str] = None):
        """Register a tool with its handler function."""
        if name in self._tools:
            logger.warning(f"Tool '{name}' already registered, overwriting")

        self._tools[name] = ToolDefinition(
            name=name,
            description=description,
            parameters=parameters or {},
            readOnly=readOnly,
            featureType=featureType,
            toolSet=toolSet,
        )
        self._handlers[name] = handler
        logger.debug(f"Registered tool: {name} (readOnly={readOnly}, toolSet={toolSet})")

    def registerFromDefinition(self, definition: ToolDefinition,
                               handler: Callable[..., Awaitable[ToolResult]]):
        """Register a tool from a pre-built ToolDefinition."""
        self._tools[definition.name] = definition
        self._handlers[definition.name] = handler
        logger.debug(f"Registered tool: {definition.name} (readOnly={definition.readOnly})")

    def unregister(self, name: str):
        """Remove a tool from the registry."""
        self._tools.pop(name, None)
        self._handlers.pop(name, None)

    def getTools(self, toolSet: Optional[str] = None, featureType: Optional[str] = None) -> List[ToolDefinition]:
        """Get available tools, optionally filtered by toolSet or featureType."""
        tools = list(self._tools.values())
        if featureType:
            tools = [t for t in tools if t.featureType is None or t.featureType == featureType]
        if toolSet:
            tools = [t for t in tools if t.toolSet is None or t.toolSet == toolSet]
        return tools

    def getToolNames(self) -> List[str]:
        """Get names of all registered tools."""
        return list(self._tools.keys())

    def getTool(self, name: str) -> Optional[ToolDefinition]:
        """Get a single tool definition by name."""
        return self._tools.get(name)

    def isReadOnly(self, name: str) -> bool:
        """Check whether a tool is marked as readOnly."""
        tool = self._tools.get(name)
        return tool.readOnly if tool else False

    def isValidTool(self, name: str) -> bool:
        """Check whether a tool name is registered."""
        return name in self._tools

    async def dispatch(self, toolCall: ToolCallRequest, context: Optional[Dict[str, Any]] = None) -> ToolResult:
        """Execute a tool call and return the result."""
        startTime = time.time()

        if not self.isValidTool(toolCall.name):
            return ToolResult(
                toolCallId=toolCall.id,
                toolName=toolCall.name,
                success=False,
                error=f"Unknown tool: '{toolCall.name}'. Available: {', '.join(self.getToolNames())}"
            )

        handler = self._handlers[toolCall.name]
        argsSummary = ", ".join(f"{k}={str(v)[:80]}" for k, v in (toolCall.args or {}).items())
        logger.info(f"Tool dispatch: {toolCall.name}({argsSummary})")
        try:
            result = await handler(toolCall.args, context or {})
            durationMs = int((time.time() - startTime) * 1000)

            if isinstance(result, ToolResult):
                result.toolCallId = toolCall.id
                result.durationMs = durationMs
                dataSummary = (result.data[:200] + "...") if result.data and len(result.data) > 200 else (result.data or "")
                if result.success:
                    logger.info(f"Tool result: {toolCall.name} OK ({durationMs}ms) → {dataSummary}")
                else:
                    logger.warning(f"Tool result: {toolCall.name} FAILED ({durationMs}ms) → {result.error}")
                return result

            return ToolResult(
                toolCallId=toolCall.id,
                toolName=toolCall.name,
                success=True,
                data=str(result),
                durationMs=durationMs
            )

        except Exception as e:
            durationMs = int((time.time() - startTime) * 1000)
            logger.error(f"Tool '{toolCall.name}' failed: {e}", exc_info=True)
            return ToolResult(
                toolCallId=toolCall.id,
                toolName=toolCall.name,
                success=False,
                error=str(e),
                durationMs=durationMs
            )

    def formatToolsForPrompt(self) -> str:
        """Format all tools as text for the system prompt (text-based fallback)."""
        parts = []
        for tool in self._tools.values():
            paramStr = ", ".join(
                f"{k}: {v}" for k, v in tool.parameters.items()
            ) if tool.parameters else "none"
            parts.append(f"- **{tool.name}**: {tool.description}\n  Parameters: {{{paramStr}}}")
        return "\n".join(parts)

    def formatToolsForFunctionCalling(self) -> List[Dict[str, Any]]:
        """Format all tools as OpenAI-compatible function definitions for native function calling."""
        functions = []
        for tool in self._tools.values():
            functions.append({
                "type": "function",
                "function": {
                    "name": tool.name,
                    "description": tool.description,
                    "parameters": tool.parameters if tool.parameters else {
                        "type": "object",
                        "properties": {},
                        "required": []
                    }
                }
            })
        return functions
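End-to-end, registration and dispatch look like this: a minimal sketch with a made-up echo tool (import paths taken from this diff):

```python
import asyncio
from modules.serviceCenter.services.serviceAgent.datamodelAgent import ToolCallRequest, ToolResult
from modules.serviceCenter.services.serviceAgent.toolRegistry import ToolRegistry

async def main():
    registry = ToolRegistry()

    # Hypothetical tool for illustration; handlers take (args, context).
    async def _echo(args, context):
        return ToolResult(toolCallId="", toolName="echo", data=args.get("text", ""))

    registry.register(
        "echo", _echo,
        description="Echo back the given text.",
        parameters={"type": "object",
                    "properties": {"text": {"type": "string"}},
                    "required": ["text"]},
        readOnly=True,   # may run in parallel with other read-only tools
    )

    # dispatch() fills in toolCallId and durationMs on the returned ToolResult.
    result = await registry.dispatch(ToolCallRequest(name="echo", args={"text": "hi"}))
    print(result.success, result.data)   # True hi

asyncio.run(main())
```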
Some files were not shown because too many files have changed in this diff.