new ai agent

commit 7fe6f9bc97 (parent c8b7517209)
58 changed files with 8297 additions and 293 deletions
@@ -12,8 +12,8 @@ IMPORTANT: Model Registration Requirements
 """

 from abc import ABC, abstractmethod
-from typing import List, Dict, Any, Optional
-from modules.datamodels.datamodelAi import AiModel
+from typing import List, Dict, Any, Optional, AsyncGenerator, Union
+from modules.datamodels.datamodelAi import AiModel, AiModelCall, AiModelResponse


 class BaseConnectorAi(ABC):
@@ -102,3 +102,24 @@ class BaseConnectorAi(ABC):
         """Get only available models."""
         models = self.getCachedModels()
         return [model for model in models if model.isAvailable]
+
+    async def callAiBasicStream(self, modelCall: AiModelCall) -> AsyncGenerator[Union[str, AiModelResponse], None]:
+        """Stream AI response. Yields str deltas during generation, then final AiModelResponse.
+
+        Default implementation: falls back to non-streaming callAiBasic.
+        Override in connectors that support streaming.
+        """
+        response = await self.callAiBasic(modelCall)
+        if response.content:
+            yield response.content
+        yield response
+
+    async def callEmbedding(self, modelCall: AiModelCall) -> AiModelResponse:
+        """Generate embeddings for input texts. Override in connectors that support embeddings.
+
+        Reads texts from modelCall.embeddingInput.
+        Returns AiModelResponse with metadata["embeddings"] containing the vectors.
+        """
+        raise NotImplementedError(
+            f"{self.__class__.__name__} does not support embeddings"
+        )
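Note: the base-class default above lets callers treat every connector as streaming even when it is not. A minimal consumption sketch follows; `connector` and `modelCall` are assumed to exist, and only the yield contract (str deltas followed by one AiModelResponse) comes from the diff itself.

    # Hedged sketch: how a caller might consume callAiBasicStream.
    async def printStreamed(connector: BaseConnectorAi, modelCall: AiModelCall) -> AiModelResponse:
        finalResponse = None
        async for item in connector.callAiBasicStream(modelCall):
            if isinstance(item, AiModelResponse):
                finalResponse = item          # last yield: the standardized response
            else:
                print(item, end="", flush=True)  # intermediate text delta
        return finalResponse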
@@ -1,9 +1,10 @@
 # Copyright (c) 2025 Patrick Motsch
 # All rights reserved.
+import json
 import logging
 import httpx
 import os
-from typing import Dict, Any, List
+from typing import Dict, Any, List, AsyncGenerator, Union
 from fastapi import HTTPException
 from modules.shared.configuration import APP_CONFIG
 from .aicoreBase import BaseConnectorAi
@@ -61,13 +62,15 @@ class AiAnthropic(BaseConnectorAi):
                 speedRating=6, # Slower due to high-quality processing
                 qualityRating=10, # Best quality available
                 functionCall=self.callAiBasic,
+                functionCallStream=self.callAiBasicStream,
                 priority=PriorityEnum.QUALITY,
                 processingMode=ProcessingModeEnum.DETAILED,
                 operationTypes=createOperationTypeRatings(
                     (OperationTypeEnum.PLAN, 9),
                     (OperationTypeEnum.DATA_ANALYSE, 9),
                     (OperationTypeEnum.DATA_GENERATE, 9),
-                    (OperationTypeEnum.DATA_EXTRACT, 8)
+                    (OperationTypeEnum.DATA_EXTRACT, 8),
+                    (OperationTypeEnum.AGENT, 9),
                 ),
                 version="claude-sonnet-4-5-20250929",
                 calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.003 + (bytesReceived / 4 / 1000) * 0.015
@@ -85,13 +88,15 @@ class AiAnthropic(BaseConnectorAi):
                 speedRating=9, # Very fast, lightweight model
                 qualityRating=8, # Good quality, cost-efficient
                 functionCall=self.callAiBasic,
+                functionCallStream=self.callAiBasicStream,
                 priority=PriorityEnum.SPEED,
                 processingMode=ProcessingModeEnum.BASIC,
                 operationTypes=createOperationTypeRatings(
                     (OperationTypeEnum.PLAN, 8),
                     (OperationTypeEnum.DATA_ANALYSE, 8),
                     (OperationTypeEnum.DATA_GENERATE, 8),
-                    (OperationTypeEnum.DATA_EXTRACT, 7)
+                    (OperationTypeEnum.DATA_EXTRACT, 7),
+                    (OperationTypeEnum.AGENT, 7),
                 ),
                 version="claude-haiku-4-5-20251001",
                 calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.001 + (bytesReceived / 4 / 1000) * 0.005
@@ -109,13 +114,15 @@ class AiAnthropic(BaseConnectorAi):
                 speedRating=5, # Moderate latency, most capable
                 qualityRating=10, # Top-tier intelligence
                 functionCall=self.callAiBasic,
+                functionCallStream=self.callAiBasicStream,
                 priority=PriorityEnum.QUALITY,
                 processingMode=ProcessingModeEnum.DETAILED,
                 operationTypes=createOperationTypeRatings(
                     (OperationTypeEnum.PLAN, 10),
                     (OperationTypeEnum.DATA_ANALYSE, 8),
                     (OperationTypeEnum.DATA_GENERATE, 10),
-                    (OperationTypeEnum.DATA_EXTRACT, 9)
+                    (OperationTypeEnum.DATA_EXTRACT, 9),
+                    (OperationTypeEnum.AGENT, 10),
                 ),
                 version="claude-opus-4-6",
                 calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.005 + (bytesReceived / 4 / 1000) * 0.025
@@ -158,53 +165,15 @@ class AiAnthropic(BaseConnectorAi):
             HTTPException: For errors in API communication
         """
         try:
-            # Extract parameters from modelCall
-            messages = modelCall.messages
             model = modelCall.model
             options = modelCall.options
             temperature = getattr(options, "temperature", None)
             if temperature is None:
                 temperature = model.temperature
             maxTokens = model.maxTokens

-            # Transform OpenAI-style messages to Anthropic format:
-            # - Move any 'system' role content to top-level 'system'
-            # - Keep only 'user'/'assistant' messages in the list
-            system_contents: List[str] = []
-            converted_messages: List[Dict[str, Any]] = []
-            for m in messages:
-                role = m.get("role")
-                content = m.get("content", "")
-                if role == "system":
-                    # Collect system content; Anthropic expects top-level 'system'
-                    if isinstance(content, list):
-                        # Join text parts if provided as blocks
-                        joined = "\n\n".join(
-                            [
-                                (part.get("text") if isinstance(part, dict) else str(part))
-                                for part in content
-                            ]
-                        )
-                        system_contents.append(joined)
-                    else:
-                        system_contents.append(str(content))
-                    continue
-                # For Anthropic, content can be a string; pass through strings, collapse blocks
-                if isinstance(content, list):
-                    # Collapse to text if blocks are provided
-                    collapsed = "\n\n".join(
-                        [
-                            (part.get("text") if isinstance(part, dict) else str(part))
-                            for part in content
-                        ]
-                    )
-                    converted_messages.append({"role": role, "content": collapsed})
-                else:
-                    converted_messages.append({"role": role, "content": content})
-
-            system_prompt = "\n\n".join([s for s in system_contents if s]) if system_contents else None
-
-            # Create Anthropic API payload
+            converted_messages, system_prompt = _convertMessagesForAnthropic(modelCall.messages)

             payload: Dict[str, Any] = {
                 "model": model.name,
                 "messages": converted_messages,
@@ -217,6 +186,13 @@ class AiAnthropic(BaseConnectorAi):
                 payload["max_tokens"] = maxTokens
             if system_prompt:
                 payload["system"] = system_prompt

+            if modelCall.tools:
+                payload["tools"] = _convertToolsToAnthropicFormat(modelCall.tools)
+                if modelCall.toolChoice:
+                    payload["tool_choice"] = modelCall.toolChoice
+                else:
+                    payload["tool_choice"] = {"type": "auto"}
+
             response = await self.httpClient.post(
                 model.apiUrl,
@@ -244,29 +220,39 @@
             # Parse response
             anthropicResponse = response.json()

-            # Extract content from response
+            # Extract content and tool_use blocks from response
             content = ""
+            toolCalls = []
             if "content" in anthropicResponse:
                 if isinstance(anthropicResponse["content"], list):
-                    # Content is a list of parts (in newer API versions)
                     for part in anthropicResponse["content"]:
                         if part.get("type") == "text":
                             content += part.get("text", "")
+                        elif part.get("type") == "tool_use":
+                            toolCalls.append({
+                                "id": part.get("id", ""),
+                                "type": "function",
+                                "function": {
+                                    "name": part.get("name", ""),
+                                    "arguments": json.dumps(part.get("input", {})) if isinstance(part.get("input"), dict) else str(part.get("input", "{}"))
+                                }
+                            })
                 else:
-                    # Direct content as string (in older API versions)
                     content = anthropicResponse["content"]

-            # Debug logging for empty responses
-            if not content or content.strip() == "":
+            if not content and not toolCalls:
                 logger.warning(f"Anthropic API returned empty content. Full response: {anthropicResponse}")
                 content = "[Anthropic API returned empty response]"

-            # Return standardized response
+            metadata = {"response_id": anthropicResponse.get("id", "")}
+            if toolCalls:
+                metadata["toolCalls"] = toolCalls
+
             return AiModelResponse(
                 content=content,
                 success=True,
                 modelId=model.name,
-                metadata={"response_id": anthropicResponse.get("id", "")}
+                metadata=metadata
             )

         except Exception as e:
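Note: as an illustration of the parsing above, a hypothetical Anthropic response whose content list holds a text block and a tool_use block is folded into plain text plus an OpenAI-style toolCalls entry. The values below are invented sample data; only the field names follow the diff.

    # Hypothetical Anthropic response body and the shape the loop above produces.
    anthropicResponse = {
        "id": "msg_01",
        "content": [
            {"type": "text", "text": "Let me look that up."},
            {"type": "tool_use", "id": "toolu_01", "name": "get_weather",
             "input": {"city": "Bern"}},
        ],
    }
    # After the loop: content == "Let me look that up." and toolCalls holds
    # [{"id": "toolu_01", "type": "function",
    #   "function": {"name": "get_weather", "arguments": '{"city": "Bern"}'}}]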
@@ -278,7 +264,102 @@ class AiAnthropic(BaseConnectorAi):
                 error_detail += f" | Status: {e.status_code}"
             logger.error(error_detail, exc_info=True)
             raise HTTPException(status_code=500, detail=error_detail)

+    async def callAiBasicStream(self, modelCall: AiModelCall) -> AsyncGenerator[Union[str, AiModelResponse], None]:
+        """Stream Anthropic response. Yields str deltas, then final AiModelResponse."""
+        try:
+            model = modelCall.model
+            options = modelCall.options
+            temperature = getattr(options, "temperature", None)
+            if temperature is None:
+                temperature = model.temperature
+
+            converted, system_prompt = _convertMessagesForAnthropic(modelCall.messages)
+
+            payload: Dict[str, Any] = {
+                "model": model.name,
+                "messages": converted,
+                "temperature": temperature,
+                "max_tokens": model.maxTokens,
+                "stream": True,
+            }
+            if system_prompt:
+                payload["system"] = system_prompt
+            if modelCall.tools:
+                payload["tools"] = _convertToolsToAnthropicFormat(modelCall.tools)
+                payload["tool_choice"] = modelCall.toolChoice or {"type": "auto"}
+
+            fullContent = ""
+            toolUseBlocks: Dict[int, Dict[str, Any]] = {}
+            currentToolIdx = -1
+
+            async with self.httpClient.stream("POST", model.apiUrl, json=payload) as response:
+                if response.status_code != 200:
+                    body = await response.aread()
+                    raise HTTPException(status_code=500, detail=f"Anthropic stream error: {response.status_code} - {body.decode()}")
+
+                async for line in response.aiter_lines():
+                    if not line.startswith("data: "):
+                        continue
+                    try:
+                        event = json.loads(line[6:])
+                    except json.JSONDecodeError:
+                        continue
+
+                    eventType = event.get("type", "")
+
+                    if eventType == "content_block_start":
+                        block = event.get("content_block", {})
+                        idx = event.get("index", 0)
+                        if block.get("type") == "tool_use":
+                            currentToolIdx = idx
+                            toolUseBlocks[idx] = {
+                                "id": block.get("id", ""),
+                                "name": block.get("name", ""),
+                                "arguments": "",
+                            }
+
+                    elif eventType == "content_block_delta":
+                        delta = event.get("delta", {})
+                        if delta.get("type") == "text_delta":
+                            text = delta.get("text", "")
+                            fullContent += text
+                            yield text
+                        elif delta.get("type") == "input_json_delta":
+                            idx = event.get("index", currentToolIdx)
+                            if idx in toolUseBlocks:
+                                toolUseBlocks[idx]["arguments"] += delta.get("partial_json", "")
+
+                    elif eventType == "message_stop":
+                        break
+
+            metadata: Dict[str, Any] = {}
+            if toolUseBlocks:
+                metadata["toolCalls"] = [
+                    {
+                        "id": tb["id"],
+                        "type": "function",
+                        "function": {
+                            "name": tb["name"],
+                            "arguments": tb["arguments"],
+                        },
+                    }
+                    for tb in toolUseBlocks.values()
+                ]
+
+            yield AiModelResponse(
+                content=fullContent,
+                success=True,
+                modelId=model.name,
+                metadata=metadata,
+            )
+
+        except HTTPException:
+            raise
+        except Exception as e:
+            logger.error(f"Error streaming Anthropic API: {e}", exc_info=True)
+            raise HTTPException(status_code=500, detail=f"Error streaming Anthropic API: {e}")
+
     async def callAiImage(self, modelCall: AiModelCall) -> AiModelResponse:
         """
         Analyzes an image using Anthropic's vision capabilities using standardized pattern.
@@ -424,4 +505,101 @@
                 content="",
                 success=False,
                 error=f"Error during image analysis: {str(e)}"
             )
+
+
+def _convertMessagesForAnthropic(messages: List[Dict[str, Any]]):
+    """Convert OpenAI-style messages to Anthropic format. Returns (messages, system_prompt)."""
+    system_contents: List[str] = []
+    converted_messages: List[Dict[str, Any]] = []
+    pendingToolResults: List[Dict[str, Any]] = []
+
+    def _flush():
+        if not pendingToolResults:
+            return
+        converted_messages.append({"role": "user", "content": list(pendingToolResults)})
+        pendingToolResults.clear()
+
+    def _collapse(content):
+        if isinstance(content, list):
+            return "\n\n".join(
+                (part.get("text") if isinstance(part, dict) else str(part))
+                for part in content
+            )
+        return str(content) if content else ""
+
+    for m in messages:
+        role = m.get("role")
+        content = m.get("content", "")
+
+        if role == "system":
+            system_contents.append(_collapse(content))
+            continue
+        if role == "tool":
+            pendingToolResults.append({
+                "type": "tool_result",
+                "tool_use_id": m.get("tool_call_id", ""),
+                "content": str(content) if content else "",
+            })
+            continue
+
+        _flush()
+
+        if role == "assistant" and m.get("tool_calls"):
+            contentBlocks = []
+            textPart = _collapse(content)
+            if textPart:
+                contentBlocks.append({"type": "text", "text": textPart})
+            for tc in m["tool_calls"]:
+                fn = tc.get("function", {})
+                inputData = fn.get("arguments", "{}")
+                if isinstance(inputData, str):
+                    try:
+                        inputData = json.loads(inputData)
+                    except (json.JSONDecodeError, ValueError):
+                        inputData = {}
+                contentBlocks.append({
+                    "type": "tool_use",
+                    "id": tc.get("id", ""),
+                    "name": fn.get("name", ""),
+                    "input": inputData,
+                })
+            converted_messages.append({"role": "assistant", "content": contentBlocks})
+            continue
+
+        converted_messages.append({"role": role, "content": _collapse(content)})
+
+    _flush()
+
+    merged: List[Dict[str, Any]] = []
+    for msg in converted_messages:
+        if merged and merged[-1]["role"] == msg["role"]:
+            prev = merged[-1]
+            pc, nc = prev["content"], msg["content"]
+            if isinstance(pc, str) and isinstance(nc, str):
+                prev["content"] = pc + "\n\n" + nc
+            elif isinstance(pc, list) and isinstance(nc, list):
+                prev["content"] = pc + nc
+            elif isinstance(pc, str) and isinstance(nc, list):
+                prev["content"] = [{"type": "text", "text": pc}] + nc
+            elif isinstance(pc, list) and isinstance(nc, str):
+                prev["content"] = pc + [{"type": "text", "text": nc}]
+        else:
+            merged.append(msg)
+
+    system_prompt = "\n\n".join([s for s in system_contents if s]) if system_contents else None
+    return merged, system_prompt
+
+
+def _convertToolsToAnthropicFormat(openaiTools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+    """Convert OpenAI-style tool definitions to Anthropic format."""
+    anthropicTools = []
+    for tool in openaiTools:
+        if tool.get("type") == "function":
+            fn = tool["function"]
+            anthropicTools.append({
+                "name": fn["name"],
+                "description": fn.get("description", ""),
+                "input_schema": fn.get("parameters", {"type": "object", "properties": {}})
+            })
+    return anthropicTools
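Note: to make the two helpers above concrete, here is a small illustrative call. The message and tool literals are invented; the helper names and return shapes come from the code above.

    openaiStyleMessages = [
        {"role": "system", "content": "You are terse."},
        {"role": "user", "content": "Weather in Bern?"},
        {"role": "assistant", "content": "", "tool_calls": [
            {"id": "call_1", "type": "function",
             "function": {"name": "get_weather", "arguments": '{"city": "Bern"}'}}]},
        {"role": "tool", "tool_call_id": "call_1", "content": "8°C, cloudy"},
    ]
    anthropicMessages, systemPrompt = _convertMessagesForAnthropic(openaiStyleMessages)
    # systemPrompt == "You are terse."; the assistant turn becomes a tool_use block and
    # the tool turn becomes a user message wrapping a tool_result block.

    openaiTools = [{"type": "function", "function": {
        "name": "get_weather", "description": "Look up current weather",
        "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}}}]
    anthropicTools = _convertToolsToAnthropicFormat(openaiTools)
    # -> [{"name": "get_weather", "description": "...", "input_schema": {...}}]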
@@ -1,8 +1,9 @@
 # Copyright (c) 2025 Patrick Motsch
 # All rights reserved.
 import logging
+import json as _json
 import httpx
-from typing import List
+from typing import List, Dict, Any, AsyncGenerator, Union
 from fastapi import HTTPException
 from modules.shared.configuration import APP_CONFIG
 from .aicoreBase import BaseConnectorAi
@@ -66,13 +67,15 @@ class AiMistral(BaseConnectorAi):
                 speedRating=8, # Good speed for complex tasks
                 qualityRating=9, # High quality
                 functionCall=self.callAiBasic,
+                functionCallStream=self.callAiBasicStream,
                 priority=PriorityEnum.BALANCED,
                 processingMode=ProcessingModeEnum.ADVANCED,
                 operationTypes=createOperationTypeRatings(
                     (OperationTypeEnum.PLAN, 9),
                     (OperationTypeEnum.DATA_ANALYSE, 9),
                     (OperationTypeEnum.DATA_GENERATE, 9),
-                    (OperationTypeEnum.DATA_EXTRACT, 8)
+                    (OperationTypeEnum.DATA_EXTRACT, 8),
+                    (OperationTypeEnum.AGENT, 8),
                 ),
                 version="mistral-large-latest",
                 calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.0005 + (bytesReceived / 4 / 1000) * 0.0015
@@ -90,17 +93,40 @@ class AiMistral(BaseConnectorAi):
                 speedRating=9, # Very fast, lightweight model
                 qualityRating=7, # Good quality, cost-efficient
                 functionCall=self.callAiBasic,
+                functionCallStream=self.callAiBasicStream,
                 priority=PriorityEnum.SPEED,
                 processingMode=ProcessingModeEnum.BASIC,
                 operationTypes=createOperationTypeRatings(
                     (OperationTypeEnum.PLAN, 7),
                     (OperationTypeEnum.DATA_ANALYSE, 7),
                     (OperationTypeEnum.DATA_GENERATE, 8),
-                    (OperationTypeEnum.DATA_EXTRACT, 7)
+                    (OperationTypeEnum.DATA_EXTRACT, 7),
+                    (OperationTypeEnum.AGENT, 6),
                 ),
                 version="mistral-small-latest",
                 calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.00006 + (bytesReceived / 4 / 1000) * 0.00018
             ),
+            AiModel(
+                name="mistral-embed",
+                displayName="Mistral Embed",
+                connectorType="mistral",
+                apiUrl="https://api.mistral.ai/v1/embeddings",
+                temperature=0.0,
+                maxTokens=0,
+                contextLength=8192,
+                costPer1kTokensInput=0.0001, # $0.10/M tokens
+                costPer1kTokensOutput=0.0,
+                speedRating=10,
+                qualityRating=7,
+                functionCall=self.callEmbedding,
+                priority=PriorityEnum.COST,
+                processingMode=ProcessingModeEnum.BASIC,
+                operationTypes=createOperationTypeRatings(
+                    (OperationTypeEnum.EMBEDDING, 8)
+                ),
+                version="mistral-embed",
+                calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.0001
+            ),
             AiModel(
                 name="mistral-large-latest",
                 displayName="Mistral Large 3 Vision",
@@ -215,7 +241,105 @@ class AiMistral(BaseConnectorAi):
         except Exception as e:
             logger.error(f"Error calling Mistral API: {str(e)}")
             raise HTTPException(status_code=500, detail=f"Error calling Mistral API: {str(e)}")

+    async def callAiBasicStream(self, modelCall: AiModelCall) -> AsyncGenerator[Union[str, AiModelResponse], None]:
+        """Stream Mistral response. Yields str deltas, then final AiModelResponse."""
+        try:
+            model = modelCall.model
+            options = modelCall.options
+            temperature = getattr(options, "temperature", None)
+            if temperature is None:
+                temperature = model.temperature
+
+            payload: Dict[str, Any] = {
+                "model": model.name,
+                "messages": modelCall.messages,
+                "temperature": temperature,
+                "max_tokens": model.maxTokens,
+                "stream": True,
+            }
+
+            fullContent = ""
+
+            async with self.httpClient.stream("POST", model.apiUrl, json=payload) as response:
+                if response.status_code != 200:
+                    body = await response.aread()
+                    raise HTTPException(status_code=500, detail=f"Mistral stream error: {response.status_code} - {body.decode()}")
+
+                async for line in response.aiter_lines():
+                    if not line.startswith("data: "):
+                        continue
+                    data = line[6:]
+                    if data.strip() == "[DONE]":
+                        break
+                    try:
+                        chunk = _json.loads(data)
+                    except _json.JSONDecodeError:
+                        continue
+
+                    delta = chunk.get("choices", [{}])[0].get("delta", {})
+                    if "content" in delta and delta["content"]:
+                        fullContent += delta["content"]
+                        yield delta["content"]
+
+            yield AiModelResponse(
+                content=fullContent,
+                success=True,
+                modelId=model.name,
+                metadata={},
+            )
+
+        except HTTPException:
+            raise
+        except Exception as e:
+            logger.error(f"Error streaming Mistral API: {e}")
+            raise HTTPException(status_code=500, detail=f"Error streaming Mistral API: {e}")
+
+    async def callEmbedding(self, modelCall: AiModelCall) -> AiModelResponse:
+        """Generate embeddings via the Mistral Embeddings API.
+
+        Reads texts from modelCall.embeddingInput.
+        Returns vectors in metadata["embeddings"].
+        """
+        try:
+            model = modelCall.model
+            texts = modelCall.embeddingInput or []
+            if not texts:
+                return AiModelResponse(
+                    content="", success=False, error="No embeddingInput provided"
+                )
+
+            payload = {"model": model.name, "input": texts}
+            response = await self.httpClient.post(model.apiUrl, json=payload)
+
+            if response.status_code != 200:
+                errorMessage = f"Mistral Embedding API error: {response.status_code} - {response.text}"
+                logger.error(errorMessage)
+                if response.status_code == 429:
+                    raise RateLimitExceededException(f"Rate limit exceeded for {model.name}")
+                raise HTTPException(status_code=500, detail=errorMessage)
+
+            responseJson = response.json()
+            embeddings = [item["embedding"] for item in responseJson["data"]]
+            usage = responseJson.get("usage", {})
+
+            return AiModelResponse(
+                content="",
+                success=True,
+                modelId=model.name,
+                tokensUsed={
+                    "input": usage.get("prompt_tokens", 0),
+                    "output": 0,
+                    "total": usage.get("total_tokens", 0),
+                },
+                metadata={"embeddings": embeddings},
+            )
+        except RateLimitExceededException:
+            raise
+        except Exception as e:
+            logger.error(f"Error calling Mistral Embedding API: {str(e)}")
+            raise HTTPException(status_code=500, detail=f"Error calling Mistral Embedding API: {str(e)}")
+
     async def callAiImage(self, modelCall: AiModelCall) -> AiModelResponse:
         """
         Analyzes an image with the Mistral Vision API using standardized pattern.
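Note: a hedged sketch of how a caller would request embeddings through this connector and read the vectors back from metadata. AiModelCall is assumed to accept `model` and `embeddingInput` as shown in the diff; the real datamodel may require further constructor arguments.

    async def embedTexts(connector, embedModel, texts):
        call = AiModelCall(model=embedModel, embeddingInput=texts)
        response = await connector.callEmbedding(call)
        if not response.success:
            raise RuntimeError(response.error)
        return response.metadata["embeddings"]  # one float vector per input text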
@@ -1,8 +1,9 @@
 # Copyright (c) 2025 Patrick Motsch
 # All rights reserved.
 import logging
+import json as _json
 import httpx
-from typing import List
+from typing import List, Dict, Any, AsyncGenerator, Union
 from fastapi import HTTPException
 from modules.shared.configuration import APP_CONFIG
 from .aicoreBase import BaseConnectorAi
@@ -67,13 +68,15 @@ class AiOpenai(BaseConnectorAi):
                 speedRating=8, # Good speed for complex tasks
                 qualityRating=10, # High quality
                 functionCall=self.callAiBasic,
+                functionCallStream=self.callAiBasicStream,
                 priority=PriorityEnum.BALANCED,
                 processingMode=ProcessingModeEnum.ADVANCED,
                 operationTypes=createOperationTypeRatings(
                     (OperationTypeEnum.PLAN, 9),
                     (OperationTypeEnum.DATA_ANALYSE, 10),
                     (OperationTypeEnum.DATA_GENERATE, 10),
-                    (OperationTypeEnum.DATA_EXTRACT, 7)
+                    (OperationTypeEnum.DATA_EXTRACT, 7),
+                    (OperationTypeEnum.AGENT, 9),
                 ),
                 version="gpt-4o",
                 calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.0025 + (bytesReceived / 4 / 1000) * 0.01
@@ -92,13 +95,15 @@ class AiOpenai(BaseConnectorAi):
                 speedRating=9, # Very fast
                 qualityRating=8, # Good quality, replaces gpt-3.5-turbo
                 functionCall=self.callAiBasic,
+                functionCallStream=self.callAiBasicStream,
                 priority=PriorityEnum.SPEED,
                 processingMode=ProcessingModeEnum.BASIC,
                 operationTypes=createOperationTypeRatings(
                     (OperationTypeEnum.PLAN, 8),
                     (OperationTypeEnum.DATA_ANALYSE, 8),
                     (OperationTypeEnum.DATA_GENERATE, 9),
-                    (OperationTypeEnum.DATA_EXTRACT, 7)
+                    (OperationTypeEnum.DATA_EXTRACT, 7),
+                    (OperationTypeEnum.AGENT, 8),
                 ),
                 version="gpt-4o-mini",
                 calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.00015 + (bytesReceived / 4 / 1000) * 0.0006
@@ -125,6 +130,48 @@ class AiOpenai(BaseConnectorAi):
                 version="gpt-4o",
                 calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.0025 + (bytesReceived / 4 / 1000) * 0.01
             ),
+            AiModel(
+                name="text-embedding-3-small",
+                displayName="OpenAI Embedding Small",
+                connectorType="openai",
+                apiUrl="https://api.openai.com/v1/embeddings",
+                temperature=0.0,
+                maxTokens=0,
+                contextLength=8191,
+                costPer1kTokensInput=0.00002, # $0.02/M tokens
+                costPer1kTokensOutput=0.0,
+                speedRating=10,
+                qualityRating=8,
+                functionCall=self.callEmbedding,
+                priority=PriorityEnum.COST,
+                processingMode=ProcessingModeEnum.BASIC,
+                operationTypes=createOperationTypeRatings(
+                    (OperationTypeEnum.EMBEDDING, 10)
+                ),
+                version="text-embedding-3-small",
+                calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.00002
+            ),
+            AiModel(
+                name="text-embedding-3-large",
+                displayName="OpenAI Embedding Large",
+                connectorType="openai",
+                apiUrl="https://api.openai.com/v1/embeddings",
+                temperature=0.0,
+                maxTokens=0,
+                contextLength=8191,
+                costPer1kTokensInput=0.00013, # $0.13/M tokens
+                costPer1kTokensOutput=0.0,
+                speedRating=9,
+                qualityRating=10,
+                functionCall=self.callEmbedding,
+                priority=PriorityEnum.QUALITY,
+                processingMode=ProcessingModeEnum.ADVANCED,
+                operationTypes=createOperationTypeRatings(
+                    (OperationTypeEnum.EMBEDDING, 10)
+                ),
+                version="text-embedding-3-large",
+                calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.00013
+            ),
             AiModel(
                 name="dall-e-3",
                 displayName="OpenAI DALL-E 3",
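Note: the calculatepriceCHF lambdas approximate token counts as bytes / 4 and then price per 1k tokens. A quick worked example for the text-embedding-3-small entry above; the byte counts are invented, and treating the result as a CHF amount mirroring the listed per-1k rate is an assumption based on the field name.

    calcPrice = lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.00002
    cost = calcPrice(processingTime=1.2, bytesSent=40_000, bytesReceived=0)
    # 40_000 bytes ≈ 10_000 tokens -> 10 * 0.00002 = 0.0002
    print(f"{cost:.5f}")  # 0.00020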
@@ -179,6 +226,10 @@ class AiOpenai(BaseConnectorAi):
                 "max_tokens": maxTokens
             }

+            if modelCall.tools:
+                payload["tools"] = modelCall.tools
+                payload["tool_choice"] = modelCall.toolChoice or "auto"
+
             response = await self.httpClient.post(
                 model.apiUrl,
                 json=payload
@@ -218,22 +269,150 @@
                 raise HTTPException(status_code=500, detail=error_message)

             responseJson = response.json()
-            content = responseJson["choices"][0]["message"]["content"]
+            choiceMessage = responseJson["choices"][0]["message"]
+            content = choiceMessage.get("content") or ""
+
+            metadata = {"response_id": responseJson.get("id", "")}
+            if choiceMessage.get("tool_calls"):
+                metadata["toolCalls"] = choiceMessage["tool_calls"]
+
             return AiModelResponse(
                 content=content,
                 success=True,
                 modelId=model.name,
-                metadata={"response_id": responseJson.get("id", "")}
+                metadata=metadata
             )

         except ContextLengthExceededException:
-            # Re-raise context length exceptions without wrapping
             raise
         except Exception as e:
             logger.error(f"Error calling OpenAI API: {str(e)}")
             raise HTTPException(status_code=500, detail=f"Error calling OpenAI API: {str(e)}")

+    async def callAiBasicStream(self, modelCall: AiModelCall) -> AsyncGenerator[Union[str, AiModelResponse], None]:
+        """Stream OpenAI response. Yields str deltas, then final AiModelResponse."""
+        try:
+            messages = modelCall.messages
+            model = modelCall.model
+            options = modelCall.options
+            temperature = getattr(options, "temperature", None)
+            if temperature is None:
+                temperature = model.temperature
+
+            payload: Dict[str, Any] = {
+                "model": model.name,
+                "messages": messages,
+                "temperature": temperature,
+                "max_tokens": model.maxTokens,
+                "stream": True,
+            }
+            if modelCall.tools:
+                payload["tools"] = modelCall.tools
+                payload["tool_choice"] = modelCall.toolChoice or "auto"
+
+            fullContent = ""
+            toolCallsAccum: Dict[int, Dict[str, Any]] = {}
+
+            async with self.httpClient.stream("POST", model.apiUrl, json=payload) as response:
+                if response.status_code != 200:
+                    body = await response.aread()
+                    raise HTTPException(status_code=500, detail=f"OpenAI stream error: {response.status_code} - {body.decode()}")
+
+                async for line in response.aiter_lines():
+                    if not line.startswith("data: "):
+                        continue
+                    data = line[6:]
+                    if data.strip() == "[DONE]":
+                        break
+                    try:
+                        chunk = _json.loads(data)
+                    except _json.JSONDecodeError:
+                        continue
+
+                    delta = chunk.get("choices", [{}])[0].get("delta", {})
+
+                    if "content" in delta and delta["content"]:
+                        fullContent += delta["content"]
+                        yield delta["content"]
+
+                    for tcDelta in delta.get("tool_calls", []):
+                        idx = tcDelta.get("index", 0)
+                        if idx not in toolCallsAccum:
+                            toolCallsAccum[idx] = {
+                                "id": tcDelta.get("id", ""),
+                                "type": "function",
+                                "function": {"name": "", "arguments": ""},
+                            }
+                        if tcDelta.get("id"):
+                            toolCallsAccum[idx]["id"] = tcDelta["id"]
+                        fn = tcDelta.get("function", {})
+                        if fn.get("name"):
+                            toolCallsAccum[idx]["function"]["name"] = fn["name"]
+                        if fn.get("arguments"):
+                            toolCallsAccum[idx]["function"]["arguments"] += fn["arguments"]
+
+            metadata: Dict[str, Any] = {}
+            if toolCallsAccum:
+                metadata["toolCalls"] = [toolCallsAccum[i] for i in sorted(toolCallsAccum)]
+
+            yield AiModelResponse(
+                content=fullContent,
+                success=True,
+                modelId=model.name,
+                metadata=metadata,
+            )
+
+        except HTTPException:
+            raise
+        except Exception as e:
+            logger.error(f"Error streaming OpenAI API: {e}")
+            raise HTTPException(status_code=500, detail=f"Error streaming OpenAI API: {e}")
+
+    async def callEmbedding(self, modelCall: AiModelCall) -> AiModelResponse:
+        """Generate embeddings via the OpenAI Embeddings API.
+
+        Reads texts from modelCall.embeddingInput.
+        Returns vectors in metadata["embeddings"].
+        """
+        try:
+            model = modelCall.model
+            texts = modelCall.embeddingInput or []
+            if not texts:
+                return AiModelResponse(
+                    content="", success=False, error="No embeddingInput provided"
+                )
+
+            payload = {"model": model.name, "input": texts}
+            response = await self.httpClient.post(model.apiUrl, json=payload)
+
+            if response.status_code != 200:
+                errorMessage = f"OpenAI Embedding API error: {response.status_code} - {response.text}"
+                logger.error(errorMessage)
+                if response.status_code == 429:
+                    raise RateLimitExceededException(f"Rate limit exceeded for {model.name}")
+                raise HTTPException(status_code=500, detail=errorMessage)
+
+            responseJson = response.json()
+            embeddings = [item["embedding"] for item in responseJson["data"]]
+            usage = responseJson.get("usage", {})
+
+            return AiModelResponse(
+                content="",
+                success=True,
+                modelId=model.name,
+                tokensUsed={
+                    "input": usage.get("prompt_tokens", 0),
+                    "output": 0,
+                    "total": usage.get("total_tokens", 0),
+                },
+                metadata={"embeddings": embeddings},
+            )
+        except (RateLimitExceededException, ContextLengthExceededException):
+            raise
+        except Exception as e:
+            logger.error(f"Error calling OpenAI Embedding API: {str(e)}")
+            raise HTTPException(status_code=500, detail=f"Error calling OpenAI Embedding API: {str(e)}")
+
     async def callAiImage(self, modelCall: AiModelCall) -> AiModelResponse:
         """
         Analyzes an image with the OpenAI Vision API using standardized pattern.
@@ -41,6 +41,11 @@ class SystemTable(BaseModel):
     )


+def _isVectorType(sqlType: str) -> bool:
+    """Check if a SQL type string represents a pgvector column."""
+    return sqlType.upper().startswith("VECTOR")
+
+
 def _isJsonbType(fieldType) -> bool:
     """Check if a type should be stored as JSONB in PostgreSQL."""
     # Direct dict or list
@@ -70,20 +75,26 @@


 def _get_model_fields(model_class) -> Dict[str, str]:
-    """Get all fields from Pydantic model and map to SQL types."""
-    # Pydantic v2
+    """Get all fields from Pydantic model and map to SQL types.
+
+    Supports explicit db_type override via json_schema_extra={"db_type": "vector(1536)"}.
+    This enables pgvector columns without special-casing field names.
+    """
     model_fields = model_class.model_fields

     fields = {}
     for field_name, field_info in model_fields.items():
-        # Pydantic v2
         field_type = field_info.annotation

+        # Explicit db_type override (e.g. vector columns)
+        extra = field_info.json_schema_extra
+        if extra and isinstance(extra, dict) and "db_type" in extra:
+            fields[field_name] = extra["db_type"]
+            continue
+
         # Check for JSONB fields (Dict, List, or complex types)
-        # Purely type-based detection - no hardcoded field names
         if _isJsonbType(field_type):
            fields[field_name] = "JSONB"
-        # Simple type mapping
         elif field_type in (str, type(None)) or (
             get_origin(field_type) is Union and type(None) in get_args(field_type)
         ):
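Note: a hypothetical Pydantic model showing the db_type override that _get_model_fields now honors. The class, field names, and vector dimension are invented for illustration; only the json_schema_extra key and the resulting "vector(...)" column type come from the diff.

    from typing import List, Optional
    from pydantic import BaseModel, Field

    class MemoryChunk(BaseModel):
        id: str
        text: str
        # Stored as a pgvector column instead of the default TEXT/JSONB mapping.
        embedding: Optional[List[float]] = Field(
            default=None, json_schema_extra={"db_type": "vector(1536)"}
        )

    # _get_model_fields(MemoryChunk) would map "embedding" -> "vector(1536)"
    # and leave the remaining fields on the normal type-based mapping.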
@@ -95,11 +106,45 @@ def _get_model_fields(model_class) -> Dict[str, str]:
         elif field_type == bool:
             fields[field_name] = "BOOLEAN"
         else:
-            fields[field_name] = "TEXT"  # Default to TEXT
+            fields[field_name] = "TEXT"

     return fields


+def _parseRecordFields(record: Dict[str, Any], fields: Dict[str, str], context: str = "") -> None:
+    """Parse record fields in-place: numeric typing, vector parsing, JSONB deserialization."""
+    import json as _json
+
+    for fieldName, fieldType in fields.items():
+        if fieldName not in record:
+            continue
+        value = record[fieldName]
+
+        if fieldType in ("DOUBLE PRECISION", "INTEGER") and value is not None:
+            try:
+                record[fieldName] = float(value) if fieldType == "DOUBLE PRECISION" else int(value)
+            except (ValueError, TypeError):
+                logger.warning(f"Could not convert {fieldName} to {fieldType} ({context}): {value}")
+
+        elif _isVectorType(fieldType) and value is not None:
+            if isinstance(value, str):
+                try:
+                    record[fieldName] = [float(v) for v in value.strip("[]").split(",")]
+                except (ValueError, TypeError):
+                    logger.warning(f"Could not parse vector field {fieldName} ({context})")
+            elif isinstance(value, list):
+                pass  # already a list
+
+        elif fieldType == "JSONB" and value is not None:
+            try:
+                if isinstance(value, str):
+                    record[fieldName] = _json.loads(value)
+                elif not isinstance(value, (dict, list)):
+                    record[fieldName] = _json.loads(str(value))
+            except (_json.JSONDecodeError, TypeError, ValueError):
+                logger.warning(f"Could not parse JSONB field {fieldName}, keeping as string ({context})")
+
+
 # Cache connectors by (host, database, port) to avoid duplicate inits for same database.
 # Thread safety: _connector_cache_lock protects cache access. userId is request-scoped via
 # contextvars to avoid races when concurrent requests share the same connector.
@@ -187,6 +232,9 @@ class DatabaseConnector:
         # Thread safety
         self._lock = threading.Lock()

+        # pgvector extension state
+        self._vectorExtensionEnabled = False
+
         # Initialize system table
         self._systemTableName = "_system"
         self._initializeSystemTable()
@@ -500,10 +548,32 @@ class DatabaseConnector:
                 self.connection.rollback()
             return False

+    def _ensureVectorExtension(self) -> bool:
+        """Enable pgvector extension if not already enabled. Called lazily on first vector table."""
+        if self._vectorExtensionEnabled:
+            return True
+        try:
+            self._ensure_connection()
+            with self.connection.cursor() as cursor:
+                cursor.execute("CREATE EXTENSION IF NOT EXISTS vector")
+            self.connection.commit()
+            self._vectorExtensionEnabled = True
+            logger.info("pgvector extension enabled")
+            return True
+        except Exception as e:
+            logger.error(f"Failed to enable pgvector extension: {e}")
+            if hasattr(self, "connection") and self.connection:
+                self.connection.rollback()
+            return False
+
     def _create_table_from_model(self, cursor, table: str, model_class: type) -> None:
         """Create table with columns matching Pydantic model fields."""
         fields = _get_model_fields(model_class)

+        # Enable pgvector if any field uses vector type
+        if any(_isVectorType(sqlType) for sqlType in fields.values()):
+            self._ensureVectorExtension()
+
         # Build column definitions with quoted identifiers to preserve exact case
         columns = ['"id" VARCHAR(255) PRIMARY KEY']
         for field_name, sql_type in fields.items():
@@ -576,28 +646,25 @@ class DatabaseConnector:
                 elif hasattr(value, "value"):
                     value = value.value

+                # Handle vector fields (pgvector) - convert List[float] to string
+                elif col in fields and _isVectorType(fields[col]) and value is not None:
+                    if isinstance(value, list):
+                        value = f"[{','.join(str(v) for v in value)}]"
+
                 # Handle JSONB fields - ensure proper JSON format for PostgreSQL
                 elif col in fields and fields[col] == "JSONB" and value is not None:
                     import json

                     if isinstance(value, (dict, list)):
-                        # Convert Python objects to JSON string for PostgreSQL JSONB
                         value = json.dumps(value)
                     elif isinstance(value, str):
-                        # Validate that it's valid JSON, if not, try to parse and re-serialize
                         try:
-                            # Test if it's already valid JSON
                             json.loads(value)
-                            # If successful, keep as is
-                            pass
                         except (json.JSONDecodeError, TypeError):
-                            # If not valid JSON, convert to JSON string
                             value = json.dumps(value)
                     elif hasattr(value, 'model_dump'):
-                        # Handle Pydantic models
                         value = json.dumps(value.model_dump())
                     else:
-                        # Convert other types to JSON
                         value = json.dumps(value)

                 values.append(value)
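Note: a round-trip sketch for a pgvector field, combining the write-side serialization above with the read-side _parseRecordFields helper from this diff. The values and field name are sample data.

    stored = f"[{','.join(str(v) for v in [0.1, 0.2, 0.3])}]"   # what gets written: "[0.1,0.2,0.3]"

    record = {"id": "chunk-1", "embedding": stored}
    _parseRecordFields(record, {"embedding": "vector(3)"}, context="example")
    assert record["embedding"] == [0.1, 0.2, 0.3]               # parsed back to floats on read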
@@ -635,46 +702,7 @@ class DatabaseConnector:
                 record = dict(row)
                 fields = _get_model_fields(model_class)

-                # Ensure numeric fields are properly typed and parse JSONB fields
-                for field_name, field_type in fields.items():
-                    # Ensure numeric fields (float/int) are properly typed
-                    # psycopg2 may return them as strings in some environments (e.g., Azure PostgreSQL)
-                    if field_type in ("DOUBLE PRECISION", "INTEGER") and field_name in record:
-                        value = record[field_name]
-                        if value is not None:
-                            try:
-                                if field_type == "DOUBLE PRECISION":
-                                    record[field_name] = float(value)
-                                elif field_type == "INTEGER":
-                                    record[field_name] = int(value)
-                            except (ValueError, TypeError):
-                                # If conversion fails, log warning but keep original value
-                                logger.warning(
-                                    f"Could not convert {field_name} to {field_type} for record {recordId}: {value}"
-                                )
-                    elif (
-                        field_type == "JSONB"
-                        and field_name in record
-                        and record[field_name] is not None
-                    ):
-                        import json
-
-                        try:
-                            if isinstance(record[field_name], str):
-                                # Parse JSON string back to Python object
-                                record[field_name] = json.loads(record[field_name])
-                            elif isinstance(record[field_name], (dict, list)):
-                                # Already a Python object, keep as is
-                                pass
-                            else:
-                                # Try to parse as JSON
-                                record[field_name] = json.loads(str(record[field_name]))
-                        except (json.JSONDecodeError, TypeError, ValueError):
-                            # If parsing fails, keep as string
-                            logger.warning(
-                                f"Could not parse JSONB field {field_name}, keeping as string: {record[field_name]}"
-                            )
-                            pass
+                _parseRecordFields(record, fields, f"record {recordId}")

                 return record
         except Exception as e:
@@ -737,55 +765,24 @@ class DatabaseConnector:
                cursor.execute(f'SELECT * FROM "{table}" ORDER BY "id"')
                records = [dict(row) for row in cursor.fetchall()]

-                # Handle JSONB fields for all records
                fields = _get_model_fields(model_class)
-                model_fields = model_class.model_fields  # Get Pydantic model fields
+                modelFields = model_class.model_fields
                for record in records:
-                    for field_name, field_type in fields.items():
-                        if field_type == "JSONB" and field_name in record:
-                            if record[field_name] is None:
-                                # Generic type-based default: List types -> [], Dict types -> {}
-                                # Interfaces handle domain-specific defaults
-                                field_info = model_fields.get(field_name)
-                                if field_info:
-                                    field_annotation = field_info.annotation
-                                    # Check if it's a List type
-                                    if (field_annotation == list or
-                                        (hasattr(field_annotation, "__origin__") and
-                                         field_annotation.__origin__ is list)):
-                                        record[field_name] = []
-                                    # Check if it's a Dict type
-                                    elif (field_annotation == dict or
-                                          (hasattr(field_annotation, "__origin__") and
-                                           field_annotation.__origin__ is dict)):
-                                        record[field_name] = {}
-                                    else:
-                                        record[field_name] = None
-                                else:
-                                    record[field_name] = None
-                            else:
-                                import json
-
-                                try:
-                                    if isinstance(record[field_name], str):
-                                        # Parse JSON string back to Python object
-                                        record[field_name] = json.loads(
-                                            record[field_name]
-                                        )
-                                    elif isinstance(record[field_name], (dict, list)):
-                                        # Already a Python object, keep as is
-                                        pass
-                                    else:
-                                        # Try to parse as JSON
-                                        record[field_name] = json.loads(
-                                            str(record[field_name])
-                                        )
-                                except (json.JSONDecodeError, TypeError, ValueError):
-                                    # If parsing fails, keep as string
-                                    logger.warning(
-                                        f"Could not parse JSONB field {field_name}, keeping as string: {record[field_name]}"
-                                    )
-                                    pass
+                    _parseRecordFields(record, fields, f"table {table}")
+                    # Set type-aware defaults for NULL JSONB fields
+                    for fieldName, fieldType in fields.items():
+                        if fieldType == "JSONB" and fieldName in record and record[fieldName] is None:
+                            fieldInfo = modelFields.get(fieldName)
+                            if fieldInfo:
+                                fieldAnnotation = fieldInfo.annotation
+                                if (fieldAnnotation == list or
+                                    (hasattr(fieldAnnotation, "__origin__") and
+                                     fieldAnnotation.__origin__ is list)):
+                                    record[fieldName] = []
+                                elif (fieldAnnotation == dict or
+                                    (hasattr(fieldAnnotation, "__origin__") and
+                                     fieldAnnotation.__origin__ is dict)):
+                                    record[fieldName] = {}

                return records
            except Exception as e:
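The added lines above key the NULL-default off the Pydantic field annotation. A standalone sketch of that check, outside the commit and with a made-up model and field names, shows the resulting defaults:

from typing import Any, Dict, List
from pydantic import BaseModel, Field

class DemoModel(BaseModel):
    tags: List[str] = Field(default_factory=list)        # list-like annotation
    settings: Dict[str, Any] = Field(default_factory=dict)  # dict-like annotation

for name, info in DemoModel.model_fields.items():
    annotation = info.annotation
    isListLike = annotation == list or (
        hasattr(annotation, "__origin__") and annotation.__origin__ is list
    )
    isDictLike = annotation == dict or (
        hasattr(annotation, "__origin__") and annotation.__origin__ is dict
    )
    # prints: tags -> []  then  settings -> {}
    print(name, "->", [] if isListLike else {} if isDictLike else None)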
@@ -936,70 +933,23 @@ class DatabaseConnector:
                cursor.execute(query, where_values)
                records = [dict(row) for row in cursor.fetchall()]

-                # Handle JSONB fields and ensure numeric types are correct
                fields = _get_model_fields(model_class)
-                model_fields = model_class.model_fields  # Get Pydantic model fields
+                modelFields = model_class.model_fields
                for record in records:
-                    for field_name, field_type in fields.items():
-                        # Ensure numeric fields (float/int) are properly typed
-                        # psycopg2 may return them as strings in some environments (e.g., Azure PostgreSQL)
-                        if field_type in ("DOUBLE PRECISION", "INTEGER") and field_name in record:
-                            value = record[field_name]
-                            if value is not None:
-                                try:
-                                    if field_type == "DOUBLE PRECISION":
-                                        record[field_name] = float(value)
-                                    elif field_type == "INTEGER":
-                                        record[field_name] = int(value)
-                                except (ValueError, TypeError):
-                                    # If conversion fails, log warning but keep original value
-                                    logger.warning(
-                                        f"Could not convert {field_name} to {field_type} for record {record.get('id', 'unknown')}: {value}"
-                                    )
-                        elif field_type == "JSONB" and field_name in record:
-                            if record[field_name] is None:
-                                # Generic type-based default: List types -> [], Dict types -> {}
-                                # Interfaces handle domain-specific defaults
-                                field_info = model_fields.get(field_name)
-                                if field_info:
-                                    field_annotation = field_info.annotation
-                                    # Check if it's a List type
-                                    if (field_annotation == list or
-                                        (hasattr(field_annotation, "__origin__") and
-                                         field_annotation.__origin__ is list)):
-                                        record[field_name] = []
-                                    # Check if it's a Dict type
-                                    elif (field_annotation == dict or
-                                          (hasattr(field_annotation, "__origin__") and
-                                           field_annotation.__origin__ is dict)):
-                                        record[field_name] = {}
-                                    else:
-                                        record[field_name] = None
-                                else:
-                                    record[field_name] = None
-                            else:
-                                import json
-
-                                try:
-                                    if isinstance(record[field_name], str):
-                                        # Parse JSON string back to Python object
-                                        record[field_name] = json.loads(
-                                            record[field_name]
-                                        )
-                                    elif isinstance(record[field_name], (dict, list)):
-                                        # Already a Python object, keep as is
-                                        pass
-                                    else:
-                                        # Try to parse as JSON
-                                        record[field_name] = json.loads(
-                                            str(record[field_name])
-                                        )
-                                except (json.JSONDecodeError, TypeError, ValueError):
-                                    # If parsing fails, keep as string
-                                    logger.warning(
-                                        f"Could not parse JSONB field {field_name}, keeping as string: {record[field_name]}"
-                                    )
-                                    pass
+                    _parseRecordFields(record, fields, f"table {table}")
+                    for fieldName, fieldType in fields.items():
+                        if fieldType == "JSONB" and fieldName in record and record[fieldName] is None:
+                            fieldInfo = modelFields.get(fieldName)
+                            if fieldInfo:
+                                fieldAnnotation = fieldInfo.annotation
+                                if (fieldAnnotation == list or
+                                    (hasattr(fieldAnnotation, "__origin__") and
+                                     fieldAnnotation.__origin__ is list)):
+                                    record[fieldName] = []
+                                elif (fieldAnnotation == dict or
+                                    (hasattr(fieldAnnotation, "__origin__") and
+                                     fieldAnnotation.__origin__ is dict)):
+                                    record[fieldName] = {}

                # If fieldFilter is available, reduce the fields
                if fieldFilter and isinstance(fieldFilter, list):
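Both database hunks delegate the per-field normalization to a shared _parseRecordFields helper whose body lies outside this excerpt. Judging from the inline code it replaces, it plausibly does something like the following sketch (an assumption, not the actual implementation in the commit):

import json
import logging

logger = logging.getLogger(__name__)

def _parseRecordFields(record: dict, fields: dict, context: str) -> None:
    """Sketch: coerce numeric columns and decode JSONB columns in place."""
    for fieldName, fieldType in fields.items():
        if fieldName not in record or record[fieldName] is None:
            continue
        value = record[fieldName]
        try:
            if fieldType == "DOUBLE PRECISION":
                record[fieldName] = float(value)
            elif fieldType == "INTEGER":
                record[fieldName] = int(value)
            elif fieldType == "JSONB" and isinstance(value, str):
                record[fieldName] = json.loads(value)
        except (ValueError, TypeError, json.JSONDecodeError):
            # Keep the original value but leave a trace in the logs
            logger.warning(f"Could not normalize {fieldName} in {context}: {value!r}")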
@@ -1127,6 +1077,85 @@ class DatabaseConnector:
            initialId = systemData.get(table)
            return initialId

+    def semanticSearch(
+        self,
+        modelClass: type,
+        vectorColumn: str,
+        queryVector: List[float],
+        limit: int = 10,
+        recordFilter: Dict[str, Any] = None,
+        minScore: float = None,
+    ) -> List[Dict[str, Any]]:
+        """Semantic search using pgvector cosine distance.
+
+        Args:
+            modelClass: Pydantic model class for the table.
+            vectorColumn: Name of the vector column to search.
+            queryVector: Query vector as List[float].
+            limit: Maximum number of results.
+            recordFilter: Additional WHERE filters (field: value).
+            minScore: Minimum cosine similarity (0.0 - 1.0).
+
+        Returns:
+            List of records with an added '_score' field (cosine similarity),
+            sorted by similarity descending.
+        """
+        table = modelClass.__name__
+
+        try:
+            if not self._ensureTableExists(modelClass):
+                return []
+
+            vectorStr = f"[{','.join(str(v) for v in queryVector)}]"
+
+            whereConditions = []
+            whereValues = []
+
+            if recordFilter:
+                for field, value in recordFilter.items():
+                    if value is None:
+                        whereConditions.append(f'"{field}" IS NULL')
+                    elif isinstance(value, (list, tuple)):
+                        if not value:
+                            whereConditions.append("1 = 0")
+                        else:
+                            whereConditions.append(f'"{field}" = ANY(%s)')
+                            whereValues.append(list(value))
+                    else:
+                        whereConditions.append(f'"{field}" = %s')
+                        whereValues.append(value)
+
+            if minScore is not None:
+                whereConditions.append(
+                    f'1 - ("{vectorColumn}" <=> %s::vector) >= %s'
+                )
+                whereValues.extend([vectorStr, minScore])
+
+            whereClause = ""
+            if whereConditions:
+                whereClause = " WHERE " + " AND ".join(whereConditions)
+
+            query = (
+                f'SELECT *, 1 - ("{vectorColumn}" <=> %s::vector) AS "_score" '
+                f'FROM "{table}"{whereClause} '
+                f'ORDER BY "{vectorColumn}" <=> %s::vector '
+                f'LIMIT %s'
+            )
+            params = [vectorStr] + whereValues + [vectorStr, limit]
+
+            with self.connection.cursor() as cursor:
+                cursor.execute(query, params)
+                records = [dict(row) for row in cursor.fetchall()]
+
+            fields = _get_model_fields(modelClass)
+            for record in records:
+                _parseRecordFields(record, fields, f"semanticSearch {table}")
+
+            return records
+        except Exception as e:
+            logger.error(f"Error in semantic search on {table}: {e}")
+            return []
+
    def close(self):
        """Close the database connection."""
        if (

@@ -1141,5 +1170,4 @@ class DatabaseConnector:
        try:
            self.close()
        except Exception:
-            # Ignore errors during cleanup
            pass
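A hedged usage sketch for the new semanticSearch method; the KnowledgeChunk model, its "embedding" vector column, the embedQuery helper and the filter values are invented for illustration and are not part of this commit:

queryVector = embedQuery("termination clauses")   # assumed embedding helper returning a float list
hits = db.semanticSearch(
    modelClass=KnowledgeChunk,        # hypothetical Pydantic model with a pgvector column
    vectorColumn="embedding",
    queryVector=queryVector,
    limit=5,
    recordFilter={"mandateId": "m-123"},
    minScore=0.75,
)
for hit in hits:
    print(round(hit["_score"], 3), hit["id"])   # cosine similarity plus the record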
54  modules/connectors/connectorProviderBase.py  Normal file
@@ -0,0 +1,54 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Abstract base classes for the Provider-Connector architecture (1:n).

One ProviderConnector per vendor (e.g. MsftConnector, GoogleConnector).
Each ProviderConnector exposes n ServiceAdapters (e.g. SharepointAdapter, OutlookAdapter).
All ServiceAdapters share the same access token from the UserConnection.
"""

from abc import ABC, abstractmethod
from typing import List, Optional


class ServiceAdapter(ABC):
    """Standardized operations for a single service of a provider."""

    @abstractmethod
    async def browse(self, path: str, filter: Optional[str] = None) -> list:
        """List items (files/folders) at the given path."""
        ...

    @abstractmethod
    async def download(self, path: str) -> bytes:
        """Download a file and return its content bytes."""
        ...

    @abstractmethod
    async def upload(self, path: str, data: bytes, fileName: str) -> dict:
        """Upload a file to the given path. Returns metadata of the created entry."""
        ...

    @abstractmethod
    async def search(self, query: str, path: Optional[str] = None) -> list:
        """Search for items matching the query."""
        ...


class ProviderConnector(ABC):
    """One connector per provider. Manages a UserConnection + token.
    Provides access to n services of the provider."""

    def __init__(self, connection, accessToken: str):
        self.connection = connection
        self.accessToken = accessToken

    @abstractmethod
    def getAvailableServices(self) -> List[str]:
        """Which services does this provider offer?"""
        ...

    @abstractmethod
    def getServiceAdapter(self, service: str) -> ServiceAdapter:
        """Return the ServiceAdapter for a specific service."""
        ...
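Because every service hides behind the same four ServiceAdapter methods, provider-agnostic helpers become possible. A small illustrative sketch, not part of the commit:

async def copyBetweenServices(source: ServiceAdapter, target: ServiceAdapter,
                              sourcePath: str, targetFolder: str, fileName: str) -> dict:
    """Copy one file across providers using only the abstract interface."""
    data = await source.download(sourcePath)
    if not data:
        return {"error": f"Nothing downloaded from {sourcePath}"}
    return await target.upload(targetFolder, data, fileName)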
94  modules/connectors/connectorResolver.py  Normal file
@@ -0,0 +1,94 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""ConnectorResolver -- resolves a connectionId to the correct ProviderConnector and ServiceAdapter.

Registry maps authority values to ProviderConnector classes.
The resolver loads the UserConnection, obtains a fresh token via SecurityService,
and instantiates the appropriate connector.
"""

import logging
from typing import Dict, Any, Type, Optional

from modules.connectors.connectorProviderBase import ProviderConnector, ServiceAdapter

logger = logging.getLogger(__name__)


class ConnectorResolver:
    """Resolves connectionId → ProviderConnector (with fresh token) → ServiceAdapter."""

    _providerRegistry: Dict[str, Type[ProviderConnector]] = {}

    def __init__(self, securityService, dbInterface):
        """
        Args:
            securityService: SecurityService instance (for getFreshToken)
            dbInterface: DB interface with getUserConnection(connectionId)
        """
        self._security = securityService
        self._db = dbInterface
        self._ensureRegistered()

    def _ensureRegistered(self):
        """Lazy-register known providers on first instantiation."""
        if ConnectorResolver._providerRegistry:
            return
        try:
            from modules.connectors.providerMsft.connectorMsft import MsftConnector
            ConnectorResolver._providerRegistry["msft"] = MsftConnector
        except ImportError:
            logger.warning("MsftConnector not available")

        try:
            from modules.connectors.providerGoogle.connectorGoogle import GoogleConnector
            ConnectorResolver._providerRegistry["google"] = GoogleConnector
        except ImportError:
            logger.debug("GoogleConnector not available (stub)")

        try:
            from modules.connectors.providerFtp.connectorFtp import FtpConnector
            ConnectorResolver._providerRegistry["local:ftp"] = FtpConnector
        except ImportError:
            logger.debug("FtpConnector not available (stub)")

    async def resolve(self, connectionId: str) -> ProviderConnector:
        """Resolve connectionId to a ProviderConnector with a fresh access token."""
        connection = await self._loadConnection(connectionId)
        if not connection:
            raise ValueError(f"UserConnection not found: {connectionId}")

        authority = getattr(connection, "authority", None)
        if not authority:
            raise ValueError(f"Connection {connectionId} has no authority")

        authorityStr = authority.value if hasattr(authority, "value") else str(authority)
        providerClass = self._providerRegistry.get(authorityStr)
        if not providerClass:
            raise ValueError(f"No ProviderConnector registered for authority: {authorityStr}")

        token = self._security.getFreshToken(connectionId)
        if not token or not token.tokenAccess:
            raise ValueError(f"No valid token for connection {connectionId}")

        return providerClass(connection, token.tokenAccess)

    async def resolveService(self, connectionId: str, service: str) -> ServiceAdapter:
        """Resolve connectionId + service name to a concrete ServiceAdapter."""
        provider = await self.resolve(connectionId)
        available = provider.getAvailableServices()
        if service not in available:
            raise ValueError(f"Service '{service}' not available. Options: {available}")
        return provider.getServiceAdapter(service)

    async def _loadConnection(self, connectionId: str) -> Optional[Any]:
        """Load UserConnection from DB."""
        try:
            if hasattr(self._db, "getUserConnection"):
                return self._db.getUserConnection(connectionId)
            if hasattr(self._db, "loadRecord"):
                from modules.datamodels.datamodelUam import UserConnection
                return self._db.loadRecord(UserConnection, connectionId)
        except Exception as e:
            logger.error(f"Failed to load connection {connectionId}: {e}")
        return None
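A usage sketch for the resolver; securityService, dbInterface and the connection ID are placeholders for whatever the application wires in, not names defined by this commit:

import asyncio

async def listSharepointRoot():
    resolver = ConnectorResolver(securityService, dbInterface)   # app-provided dependencies
    adapter = await resolver.resolveService("conn-123", "sharepoint")
    for entry in await adapter.browse("/"):
        print(entry.name, "(folder)" if entry.isFolder else entry.mimeType)

asyncio.run(listSharepointRoot())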
3  modules/connectors/providerFtp/__init__.py  Normal file
@@ -0,0 +1,3 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""FTP/SFTP Provider Connector stub."""

48  modules/connectors/providerFtp/connectorFtp.py  Normal file
@@ -0,0 +1,48 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""FTP/SFTP ProviderConnector stub.

Implements the ProviderConnector interface for FTP/SFTP file access.
Full implementation follows when FTP integration is prioritized.
"""

import logging
from typing import List, Optional

from modules.connectors.connectorProviderBase import ProviderConnector, ServiceAdapter
from modules.datamodels.datamodelDataSource import ExternalEntry

logger = logging.getLogger(__name__)


class FtpFilesAdapter(ServiceAdapter):
    """FTP files ServiceAdapter (stub)."""

    def __init__(self, accessToken: str):
        self._accessToken = accessToken

    async def browse(self, path: str, filter: Optional[str] = None) -> List[ExternalEntry]:
        logger.info(f"FTP browse stub: {path}")
        return []

    async def download(self, path: str) -> bytes:
        logger.info(f"FTP download stub: {path}")
        return b""

    async def upload(self, path: str, data: bytes, fileName: str) -> dict:
        return {"error": "FTP upload not yet implemented"}

    async def search(self, query: str, path: Optional[str] = None) -> List[ExternalEntry]:
        return []


class FtpConnector(ProviderConnector):
    """FTP ProviderConnector -- 1 connection -> files."""

    def getAvailableServices(self) -> List[str]:
        return ["files"]

    def getServiceAdapter(self, service: str) -> ServiceAdapter:
        if service != "files":
            raise ValueError(f"FTP only supports 'files' service, got '{service}'")
        return FtpFilesAdapter(self.accessToken)
3  modules/connectors/providerGoogle/__init__.py  Normal file
@@ -0,0 +1,3 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Google Provider Connector -- 1 Connection : n Services (Drive, Gmail)."""

194  modules/connectors/providerGoogle/connectorGoogle.py  Normal file
@@ -0,0 +1,194 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Google ProviderConnector -- Drive and Gmail via Google OAuth."""

import logging
from typing import Any, Dict, List, Optional

import aiohttp

from modules.connectors.connectorProviderBase import ProviderConnector, ServiceAdapter
from modules.datamodels.datamodelDataSource import ExternalEntry

logger = logging.getLogger(__name__)

_DRIVE_BASE = "https://www.googleapis.com/drive/v3"
_GMAIL_BASE = "https://gmail.googleapis.com/gmail/v1"


async def _googleGet(token: str, url: str) -> Dict[str, Any]:
    headers = {"Authorization": f"Bearer {token}"}
    timeout = aiohttp.ClientTimeout(total=20)
    try:
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.get(url, headers=headers) as resp:
                if resp.status in (200, 201):
                    return await resp.json()
                errorText = await resp.text()
                logger.warning(f"Google API {resp.status}: {errorText[:300]}")
                return {"error": f"{resp.status}: {errorText[:200]}"}
    except Exception as e:
        return {"error": str(e)}


class DriveAdapter(ServiceAdapter):
    """Google Drive ServiceAdapter -- browse files and folders."""

    def __init__(self, accessToken: str):
        self._token = accessToken

    async def browse(self, path: str, filter: Optional[str] = None) -> List[ExternalEntry]:
        folderId = (path or "").strip("/") or "root"
        query = f"'{folderId}' in parents and trashed=false"
        fields = "files(id,name,mimeType,size,modifiedTime,parents)"
        url = f"{_DRIVE_BASE}/files?q={query}&fields={fields}&pageSize=100&orderBy=folder,name"

        result = await _googleGet(self._token, url)
        if "error" in result:
            logger.warning(f"Google Drive browse failed: {result['error']}")
            return []

        entries = []
        for f in result.get("files", []):
            isFolder = f.get("mimeType") == "application/vnd.google-apps.folder"
            entries.append(ExternalEntry(
                name=f.get("name", ""),
                path=f"/{f.get('id', '')}",
                isFolder=isFolder,
                size=int(f.get("size", 0)) if f.get("size") else None,
                mimeType=f.get("mimeType") if not isFolder else None,
                metadata={"id": f.get("id"), "modifiedTime": f.get("modifiedTime")},
            ))
        return entries

    async def download(self, path: str) -> bytes:
        fileId = (path or "").strip("/")
        if not fileId:
            return b""
        url = f"{_DRIVE_BASE}/files/{fileId}?alt=media"
        headers = {"Authorization": f"Bearer {self._token}"}
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(url, headers=headers) as resp:
                    if resp.status == 200:
                        return await resp.read()
        except Exception as e:
            logger.error(f"Google Drive download failed: {e}")
        return b""

    async def upload(self, path: str, data: bytes, fileName: str) -> dict:
        return {"error": "Google Drive upload not yet implemented"}

    async def search(self, query: str, path: Optional[str] = None) -> List[ExternalEntry]:
        safeQuery = query.replace("'", "\\'")
        url = f"{_DRIVE_BASE}/files?q=name contains '{safeQuery}' and trashed=false&fields=files(id,name,mimeType,size)&pageSize=25"
        result = await _googleGet(self._token, url)
        if "error" in result:
            return []
        return [
            ExternalEntry(
                name=f.get("name", ""),
                path=f"/{f.get('id', '')}",
                isFolder=f.get("mimeType") == "application/vnd.google-apps.folder",
                size=int(f.get("size", 0)) if f.get("size") else None,
            )
            for f in result.get("files", [])
        ]


class GmailAdapter(ServiceAdapter):
    """Gmail ServiceAdapter -- browse labels and messages."""

    def __init__(self, accessToken: str):
        self._token = accessToken

    async def browse(self, path: str, filter: Optional[str] = None) -> list:
        cleanPath = (path or "").strip("/")

        if not cleanPath:
            url = f"{_GMAIL_BASE}/users/me/labels"
            result = await _googleGet(self._token, url)
            if "error" in result:
                logger.warning(f"Gmail labels failed: {result['error']}")
                return []
            _SYSTEM_LABELS = {"INBOX", "SENT", "DRAFT", "TRASH", "SPAM", "STARRED", "IMPORTANT"}
            labels = []
            for lbl in result.get("labels", []):
                labelId = lbl.get("id", "")
                labelName = lbl.get("name", labelId)
                if lbl.get("type") == "system" and labelId not in _SYSTEM_LABELS:
                    continue
                labels.append(ExternalEntry(
                    name=labelName,
                    path=f"/{labelId}",
                    isFolder=True,
                    metadata={"id": labelId, "type": lbl.get("type", "")},
                ))
            labels.sort(key=lambda e: (0 if e.metadata.get("type") == "system" else 1, e.name))
            return labels

        url = f"{_GMAIL_BASE}/users/me/messages?labelIds={cleanPath}&maxResults=25"
        result = await _googleGet(self._token, url)
        if "error" in result:
            return []

        entries = []
        for msg in result.get("messages", [])[:25]:
            msgId = msg.get("id", "")
            detailUrl = f"{_GMAIL_BASE}/users/me/messages/{msgId}?format=metadata&metadataHeaders=Subject&metadataHeaders=From&metadataHeaders=Date"
            detail = await _googleGet(self._token, detailUrl)
            if "error" in detail:
                entries.append(ExternalEntry(name=f"Message {msgId}", path=f"/{cleanPath}/{msgId}", isFolder=False))
                continue
            headers = {h.get("name", ""): h.get("value", "") for h in detail.get("payload", {}).get("headers", [])}
            entries.append(ExternalEntry(
                name=headers.get("Subject", "(no subject)"),
                path=f"/{cleanPath}/{msgId}",
                isFolder=False,
                metadata={
                    "id": msgId,
                    "from": headers.get("From", ""),
                    "date": headers.get("Date", ""),
                    "snippet": detail.get("snippet", ""),
                },
            ))
        return entries

    async def download(self, path: str) -> bytes:
        return b""

    async def upload(self, path: str, data: bytes, fileName: str) -> dict:
        return {"error": "Gmail upload not applicable"}

    async def search(self, query: str, path: Optional[str] = None) -> list:
        url = f"{_GMAIL_BASE}/users/me/messages?q={query}&maxResults=10"
        result = await _googleGet(self._token, url)
        if "error" in result:
            return []
        return [
            ExternalEntry(
                name=f"Message {m.get('id', '')}",
                path=f"/{m.get('id', '')}",
                isFolder=False,
                metadata={"id": m.get("id")},
            )
            for m in result.get("messages", [])
        ]


class GoogleConnector(ProviderConnector):
    """Google ProviderConnector -- 1 connection -> Drive + Gmail."""

    _SERVICE_MAP = {
        "drive": DriveAdapter,
        "gmail": GmailAdapter,
    }

    def getAvailableServices(self) -> List[str]:
        return list(self._SERVICE_MAP.keys())

    def getServiceAdapter(self, service: str) -> ServiceAdapter:
        adapterClass = self._SERVICE_MAP.get(service)
        if not adapterClass:
            raise ValueError(f"Unknown Google service: {service}. Available: {list(self._SERVICE_MAP.keys())}")
        return adapterClass(self.accessToken)
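Illustration only (not part of the commit): in this adapter a Drive path is the file or folder ID, so "/" resolves to the special "root" folder and "/<folderId>" descends into that folder. A minimal sketch:

async def listDriveFolder(connection, accessToken):
    drive = GoogleConnector(connection, accessToken).getServiceAdapter("drive")
    rootEntries = await drive.browse("/")            # "/" maps to the "root" folder of My Drive
    folders = [e for e in rootEntries if e.isFolder]
    if folders:
        return await drive.browse(folders[0].path)   # entry paths carry the Drive file ID
    return rootEntries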
3  modules/connectors/providerMsft/__init__.py  Normal file
@@ -0,0 +1,3 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Microsoft Provider Connector -- 1 Connection : n Services (SharePoint, Outlook, Teams, OneDrive)."""

459  modules/connectors/providerMsft/connectorMsft.py  Normal file
@@ -0,0 +1,459 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Microsoft ProviderConnector -- one MSFT connection serves SharePoint, Outlook, Teams, OneDrive.

All ServiceAdapters share the same OAuth access token obtained from the
UserConnection (authority=msft).
"""

import logging
import aiohttp
import asyncio
from typing import Dict, Any, List, Optional

from modules.connectors.connectorProviderBase import ProviderConnector, ServiceAdapter
from modules.datamodels.datamodelDataSource import ExternalEntry

logger = logging.getLogger(__name__)

_GRAPH_BASE = "https://graph.microsoft.com/v1.0"


class _GraphApiMixin:
    """Shared Graph API call logic for all MSFT service adapters."""

    def __init__(self, accessToken: str):
        self._accessToken = accessToken

    async def _graphGet(self, endpoint: str) -> Dict[str, Any]:
        return await _makeGraphCall(self._accessToken, endpoint, "GET")

    async def _graphPost(self, endpoint: str, data: Any = None) -> Dict[str, Any]:
        return await _makeGraphCall(self._accessToken, endpoint, "POST", data)

    async def _graphPut(self, endpoint: str, data: bytes = None) -> Dict[str, Any]:
        return await _makeGraphCall(self._accessToken, endpoint, "PUT", data)

    async def _graphDelete(self, endpoint: str) -> Dict[str, Any]:
        return await _makeGraphCall(self._accessToken, endpoint, "DELETE")

    async def _graphDownload(self, endpoint: str) -> Optional[bytes]:
        """Download binary content from Graph API."""
        headers = {"Authorization": f"Bearer {self._accessToken}"}
        timeout = aiohttp.ClientTimeout(total=60)
        url = f"{_GRAPH_BASE}/{endpoint.lstrip('/')}"
        try:
            async with aiohttp.ClientSession(timeout=timeout) as session:
                async with session.get(url, headers=headers) as resp:
                    if resp.status == 200:
                        return await resp.read()
                    logger.error(f"Download failed {resp.status}: {await resp.text()}")
                    return None
        except Exception as e:
            logger.error(f"Graph download error: {e}")
            return None


async def _makeGraphCall(
    token: str, endpoint: str, method: str = "GET", data: Any = None
) -> Dict[str, Any]:
    """Execute a single Microsoft Graph API call."""
    url = f"{_GRAPH_BASE}/{endpoint.lstrip('/')}"
    contentType = "application/json"
    if method == "PUT" and isinstance(data, bytes):
        contentType = "application/octet-stream"
    headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": contentType,
    }
    timeout = aiohttp.ClientTimeout(total=30)
    try:
        async with aiohttp.ClientSession(timeout=timeout) as session:
            kwargs: Dict[str, Any] = {"headers": headers}
            if data is not None:
                kwargs["data"] = data

            if method == "GET":
                async with session.get(url, **kwargs) as resp:
                    return await _handleResponse(resp)
            elif method == "POST":
                async with session.post(url, **kwargs) as resp:
                    return await _handleResponse(resp)
            elif method == "PUT":
                async with session.put(url, **kwargs) as resp:
                    return await _handleResponse(resp)
            elif method == "DELETE":
                async with session.delete(url, **kwargs) as resp:
                    if resp.status in (200, 204):
                        return {}
                    return await _handleResponse(resp)

    except asyncio.TimeoutError:
        return {"error": f"Graph API timeout: {endpoint}"}
    except Exception as e:
        return {"error": f"Graph API error: {e}"}

    return {"error": f"Unsupported method: {method}"}


async def _handleResponse(resp: aiohttp.ClientResponse) -> Dict[str, Any]:
    if resp.status in (200, 201):
        return await resp.json()
    errorText = await resp.text()
    logger.error(f"Graph API {resp.status}: {errorText}")
    return {"error": f"{resp.status}: {errorText}"}


def _graphItemToExternalEntry(item: Dict[str, Any], basePath: str = "") -> ExternalEntry:
    isFolder = "folder" in item
    return ExternalEntry(
        name=item.get("name", ""),
        path=f"{basePath}/{item.get('name', '')}" if basePath else item.get("name", ""),
        isFolder=isFolder,
        size=item.get("size"),
        mimeType=item.get("file", {}).get("mimeType") if not isFolder else None,
        lastModified=None,
        metadata={
            "id": item.get("id"),
            "webUrl": item.get("webUrl"),
            "childCount": item.get("folder", {}).get("childCount") if isFolder else None,
        },
    )


# ---------------------------------------------------------------------------
# SharePoint Adapter
# ---------------------------------------------------------------------------

class SharepointAdapter(_GraphApiMixin, ServiceAdapter):
    """ServiceAdapter for SharePoint (files, sites) via Microsoft Graph."""

    async def browse(self, path: str, filter: Optional[str] = None) -> List[ExternalEntry]:
        """List items in a SharePoint folder.

        Path format: /sites/<SiteName>/<FolderPath>
        Root "/" lists available sites via discovery.
        """
        if not path or path == "/":
            return await self._discoverSites()

        siteId, folderPath = _parseSharepointPath(path)
        if not siteId:
            return await self._discoverSites()

        if not folderPath or folderPath == "/":
            endpoint = f"sites/{siteId}/drive/root/children"
        else:
            cleanPath = folderPath.lstrip("/")
            endpoint = f"sites/{siteId}/drive/root:/{cleanPath}:/children"

        result = await self._graphGet(endpoint)
        if "error" in result:
            logger.warning(f"SharePoint browse failed: {result['error']}")
            return []

        entries = [_graphItemToExternalEntry(item, path) for item in result.get("value", [])]
        if filter:
            entries = [e for e in entries if _matchFilter(e, filter)]
        return entries

    async def _discoverSites(self) -> List[ExternalEntry]:
        """Discover accessible SharePoint sites."""
        result = await self._graphGet("sites?search=*&$top=50")
        if "error" in result:
            logger.warning(f"SharePoint site discovery failed: {result['error']}")
            return []
        return [
            ExternalEntry(
                name=s.get("displayName") or s.get("name", ""),
                path=f"/sites/{s.get('id', '')}",
                isFolder=True,
                metadata={
                    "id": s.get("id"),
                    "webUrl": s.get("webUrl"),
                    "description": s.get("description", ""),
                },
            )
            for s in result.get("value", [])
            if s.get("displayName")
        ]

    async def download(self, path: str) -> bytes:
        siteId, filePath = _parseSharepointPath(path)
        if not siteId or not filePath:
            return b""
        cleanPath = filePath.strip("/")
        endpoint = f"sites/{siteId}/drive/root:/{cleanPath}:/content"
        data = await self._graphDownload(endpoint)
        return data or b""

    async def upload(self, path: str, data: bytes, fileName: str) -> dict:
        siteId, folderPath = _parseSharepointPath(path)
        if not siteId:
            return {"error": "Invalid SharePoint path"}
        cleanFolder = (folderPath or "").strip("/")
        uploadPath = f"{cleanFolder}/{fileName}" if cleanFolder else fileName
        endpoint = f"sites/{siteId}/drive/root:/{uploadPath}:/content"
        result = await self._graphPut(endpoint, data)
        return result

    async def search(self, query: str, path: Optional[str] = None) -> List[ExternalEntry]:
        siteId, _ = _parseSharepointPath(path or "")
        if not siteId:
            return []
        safeQuery = query.replace("'", "''")
        endpoint = f"sites/{siteId}/drive/root/search(q='{safeQuery}')"
        result = await self._graphGet(endpoint)
        if "error" in result:
            return []
        return [_graphItemToExternalEntry(item) for item in result.get("value", [])]


# ---------------------------------------------------------------------------
# Outlook Adapter
# ---------------------------------------------------------------------------

class OutlookAdapter(_GraphApiMixin, ServiceAdapter):
    """ServiceAdapter for Outlook (mail, calendar) via Microsoft Graph."""

    async def browse(self, path: str, filter: Optional[str] = None) -> List[ExternalEntry]:
        """List mail folders or messages.

        path = "" or "/" → list mail folders
        path = "/Inbox" → list messages in Inbox
        """
        if not path or path == "/":
            result = await self._graphGet("me/mailFolders")
            if "error" in result:
                return []
            return [
                ExternalEntry(
                    name=f.get("displayName", ""),
                    path=f"/{f.get('displayName', '')}",
                    isFolder=True,
                    metadata={"id": f.get("id"), "totalItemCount": f.get("totalItemCount")},
                )
                for f in result.get("value", [])
            ]

        folderName = path.strip("/")
        endpoint = f"me/mailFolders/{folderName}/messages?$top=25&$orderby=receivedDateTime desc"
        result = await self._graphGet(endpoint)
        if "error" in result:
            return []
        return [
            ExternalEntry(
                name=m.get("subject", "(no subject)"),
                path=f"{path}/{m.get('id', '')}",
                isFolder=False,
                metadata={
                    "id": m.get("id"),
                    "from": m.get("from", {}).get("emailAddress", {}).get("address"),
                    "receivedDateTime": m.get("receivedDateTime"),
                    "hasAttachments": m.get("hasAttachments", False),
                },
            )
            for m in result.get("value", [])
        ]

    async def download(self, path: str) -> bytes:
        """Download a mail message as JSON bytes."""
        import json
        messageId = path.strip("/").split("/")[-1]
        result = await self._graphGet(f"me/messages/{messageId}")
        if "error" in result:
            return b""
        return json.dumps(result, ensure_ascii=False).encode("utf-8")

    async def upload(self, path: str, data: bytes, fileName: str) -> dict:
        """Not applicable for Outlook in the file sense."""
        return {"error": "Upload not supported for Outlook"}

    async def search(self, query: str, path: Optional[str] = None) -> List[ExternalEntry]:
        safeQuery = query.replace("'", "''")
        endpoint = f"me/messages?$search=\"{safeQuery}\"&$top=25"
        result = await self._graphGet(endpoint)
        if "error" in result:
            return []
        return [
            ExternalEntry(
                name=m.get("subject", "(no subject)"),
                path=f"/search/{m.get('id', '')}",
                isFolder=False,
                metadata={
                    "id": m.get("id"),
                    "from": m.get("from", {}).get("emailAddress", {}).get("address"),
                    "receivedDateTime": m.get("receivedDateTime"),
                },
            )
            for m in result.get("value", [])
        ]

    async def sendMail(
        self, to: List[str], subject: str, body: str,
        cc: Optional[List[str]] = None, attachments: Optional[List[Dict]] = None
    ) -> Dict[str, Any]:
        """Send an email via Microsoft Graph."""
        import json
        message: Dict[str, Any] = {
            "subject": subject,
            "body": {"contentType": "Text", "content": body},
            "toRecipients": [{"emailAddress": {"address": addr}} for addr in to],
        }
        if cc:
            message["ccRecipients"] = [{"emailAddress": {"address": addr}} for addr in cc]

        payload = json.dumps({"message": message, "saveToSentItems": True}).encode("utf-8")
        result = await self._graphPost("me/sendMail", payload)
        if "error" in result:
            return result
        return {"success": True}


# ---------------------------------------------------------------------------
# Teams Adapter (Stub)
# ---------------------------------------------------------------------------

class TeamsAdapter(_GraphApiMixin, ServiceAdapter):
    """ServiceAdapter for Microsoft Teams -- browse joined teams and channels."""

    async def browse(self, path: str, filter: Optional[str] = None) -> list:
        cleanPath = (path or "").strip("/")

        if not cleanPath:
            result = await self._graphGet("me/joinedTeams")
            if "error" in result:
                logger.warning(f"Teams browse failed: {result['error']}")
                return []
            return [
                ExternalEntry(
                    name=t.get("displayName", ""),
                    path=f"/{t.get('id', '')}",
                    isFolder=True,
                    metadata={"id": t.get("id"), "description": t.get("description", "")},
                )
                for t in result.get("value", [])
            ]

        parts = cleanPath.split("/", 1)
        teamId = parts[0]
        if len(parts) == 1:
            result = await self._graphGet(f"teams/{teamId}/channels")
            if "error" in result:
                return []
            return [
                ExternalEntry(
                    name=ch.get("displayName", ""),
                    path=f"/{teamId}/{ch.get('id', '')}",
                    isFolder=True,
                    metadata={"id": ch.get("id"), "membershipType": ch.get("membershipType", "")},
                )
                for ch in result.get("value", [])
            ]

        return []

    async def download(self, path: str) -> bytes:
        return b""

    async def upload(self, path: str, data: bytes, fileName: str) -> dict:
        return {"error": "Teams upload not implemented"}

    async def search(self, query: str, path: Optional[str] = None) -> list:
        return []


# ---------------------------------------------------------------------------
# OneDrive Adapter (Stub -- similar to SharePoint but personal drive)
# ---------------------------------------------------------------------------

class OneDriveAdapter(_GraphApiMixin, ServiceAdapter):
    """ServiceAdapter stub for OneDrive (personal drive)."""

    async def browse(self, path: str, filter: Optional[str] = None) -> List[ExternalEntry]:
        cleanPath = (path or "").strip("/")
        if not cleanPath:
            endpoint = "me/drive/root/children"
        else:
            endpoint = f"me/drive/root:/{cleanPath}:/children"

        result = await self._graphGet(endpoint)
        if "error" in result:
            return []
        entries = [_graphItemToExternalEntry(item, path) for item in result.get("value", [])]
        if filter:
            entries = [e for e in entries if _matchFilter(e, filter)]
        return entries

    async def download(self, path: str) -> bytes:
        cleanPath = (path or "").strip("/")
        if not cleanPath:
            return b""
        data = await self._graphDownload(f"me/drive/root:/{cleanPath}:/content")
        return data or b""

    async def upload(self, path: str, data: bytes, fileName: str) -> dict:
        cleanPath = (path or "").strip("/")
        uploadPath = f"{cleanPath}/{fileName}" if cleanPath else fileName
        endpoint = f"me/drive/root:/{uploadPath}:/content"
        return await self._graphPut(endpoint, data)

    async def search(self, query: str, path: Optional[str] = None) -> List[ExternalEntry]:
        safeQuery = query.replace("'", "''")
        endpoint = f"me/drive/root/search(q='{safeQuery}')"
        result = await self._graphGet(endpoint)
        if "error" in result:
            return []
        return [_graphItemToExternalEntry(item) for item in result.get("value", [])]


# ---------------------------------------------------------------------------
# MsftConnector (1:n)
# ---------------------------------------------------------------------------

class MsftConnector(ProviderConnector):
    """Microsoft ProviderConnector -- 1 connection → n services."""

    _SERVICE_MAP = {
        "sharepoint": SharepointAdapter,
        "outlook": OutlookAdapter,
        "teams": TeamsAdapter,
        "onedrive": OneDriveAdapter,
    }

    def getAvailableServices(self) -> List[str]:
        return list(self._SERVICE_MAP.keys())

    def getServiceAdapter(self, service: str) -> ServiceAdapter:
        adapterClass = self._SERVICE_MAP.get(service)
        if not adapterClass:
            raise ValueError(f"Unknown MSFT service: {service}. Available: {list(self._SERVICE_MAP.keys())}")
        return adapterClass(self.accessToken)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _parseSharepointPath(path: str) -> tuple:
    """Parse a SharePoint path into (siteId, innerPath).

    Expected format: /sites/<siteId>/<innerPath>
    Also accepts bare siteId if no /sites/ prefix.
    """
    if not path:
        return ("", "")
    clean = path.strip("/")
    if clean.startswith("sites/"):
        parts = clean.split("/", 2)
        siteId = parts[1] if len(parts) > 1 else ""
        innerPath = parts[2] if len(parts) > 2 else ""
        return (siteId, innerPath)
    parts = clean.split("/", 1)
    return (parts[0], parts[1] if len(parts) > 1 else "")


def _matchFilter(entry: ExternalEntry, pattern: str) -> bool:
    """Simple glob-like filter (supports * wildcard)."""
    import fnmatch
    return fnmatch.fnmatch(entry.name.lower(), pattern.lower())
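A quick sanity check of the path helper above, with invented site IDs (not part of the commit):

assert _parseSharepointPath("/sites/contoso,abc,def/Shared Documents/Reports") == (
    "contoso,abc,def", "Shared Documents/Reports"
)
assert _parseSharepointPath("/sites/siteId") == ("siteId", "")   # no inner path
assert _parseSharepointPath("siteId/folder") == ("siteId", "folder")   # bare siteId accepted
assert _parseSharepointPath("") == ("", "")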
@@ -26,6 +26,12 @@ class OperationTypeEnum(str, Enum):
    WEB_SEARCH_DATA = "webSearch"  # Returns list of URLs only
    WEB_CRAWL = "webCrawl"  # Web crawl for a given URL

+    # Agent Operations
+    AGENT = "agent"  # Agent loop: reasoning + tool use
+
+    # Embedding Operations
+    EMBEDDING = "embedding"  # Text → vector conversion for semantic search
+
    # Speech Operations (dedicated pipeline, bypasses standard model selection)
    SPEECH_TEAMS = "speechTeams"  # Teams Meeting AI analysis: decide if/how to respond

@@ -102,6 +108,7 @@ class AiModel(BaseModel):

    # Function reference (not serialized)
    functionCall: Optional[Callable] = Field(default=None, exclude=True, description="Function to call for this model")
+    functionCallStream: Optional[Callable] = Field(default=None, exclude=True, description="Streaming function: yields str deltas, then final AiModelResponse")
    calculatepriceCHF: Optional[Callable] = Field(default=None, exclude=True, description="Function to calculate price in USD")

    # Selection criteria - capabilities with ratings

@@ -155,10 +162,12 @@ class AiCallOptions(BaseModel):
class AiCallRequest(BaseModel):
    """Centralized AI call request payload for interface use."""

-    prompt: str = Field(description="The user prompt")
+    prompt: str = Field(default="", description="The user prompt")
    context: Optional[str] = Field(default=None, description="Optional external context (e.g., extracted docs)")
    options: AiCallOptions = Field(default_factory=AiCallOptions)
-    contentParts: Optional[List['ContentPart']] = None  # NEW: Content parts for model-aware chunking
+    contentParts: Optional[List['ContentPart']] = None  # Content parts for model-aware chunking
+    messages: Optional[List[Dict[str, Any]]] = Field(default=None, description="OpenAI-style messages for multi-turn agent conversations")
+    tools: Optional[List[Dict[str, Any]]] = Field(default=None, description="Tool definitions for native function calling")


class AiCallResponse(BaseModel):

@@ -172,14 +181,19 @@ class AiCallResponse(BaseModel):
    bytesSent: int = Field(default=0, description="Input data size in bytes")
    bytesReceived: int = Field(default=0, description="Output data size in bytes")
    errorCount: int = Field(default=0, description="0 for success, 1+ for errors")
+    toolCalls: Optional[List[Dict[str, Any]]] = Field(default=None, description="Tool calls from native function calling")
+    metadata: Optional[Dict[str, Any]] = Field(default=None, description="Additional response metadata (e.g. embeddings vectors)")


class AiModelCall(BaseModel):
    """Standardized input for AI model calls."""

-    messages: List[Dict[str, Any]] = Field(description="Messages in OpenAI format (role, content)")
+    messages: List[Dict[str, Any]] = Field(default_factory=list, description="Messages in OpenAI format (role, content)")
    model: Optional[AiModel] = Field(default=None, description="The AI model being called")
    options: AiCallOptions = Field(default_factory=AiCallOptions, description="Additional model-specific options")
+    tools: Optional[List[Dict[str, Any]]] = Field(default=None, description="Tool definitions for native function calling")
+    toolChoice: Optional[Any] = Field(default=None, description="Tool choice: 'auto', 'none', or specific tool")
+    embeddingInput: Optional[List[str]] = Field(default=None, description="Input texts for embedding models (used instead of messages)")

    model_config = ConfigDict(arbitrary_types_allowed=True)
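A hedged sketch of how the new embeddingInput and metadata fields fit together, assuming a connector whose callEmbedding returns the vectors under metadata["embeddings"]; the connector instance itself is a placeholder:

async def embedTexts(connector, texts):
    call = AiModelCall(embeddingInput=list(texts))    # messages stays empty for embedding calls
    response = await connector.callEmbedding(call)    # assumed embedding entry point
    return (response.metadata or {}).get("embeddings", [])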
58  modules/datamodels/datamodelContent.py  Normal file
@@ -0,0 +1,58 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Content Object data models for the container and content extraction pipeline.

Physical layer: Container hierarchy (ZIP, Folder, File)
Logical layer: Scalar content objects (text, image, videostream, audiostream, other)

The entire extraction pipeline up to ContentObjects runs without AI.
"""

from typing import Dict, Any, List, Optional
from pydantic import BaseModel, Field
import uuid


class ContainerLimitError(Exception):
    """Raised when container extraction exceeds safety limits (size, depth, file count)."""
    pass


class ContentContextRef(BaseModel):
    """Reference to the origin context within a container/file."""
    containerPath: str = Field(description="e.g. 'archiv.zip/folder-a/report.pdf'")
    location: str = Field(default="", description="e.g. 'page:5/region:bottomLeft'")
    label: Optional[str] = Field(default=None, description="e.g. 'Abbildung 3: Uebersicht'")
    pageIndex: Optional[int] = Field(default=None, description="Page number (PDF, DOCX)")
    sectionId: Optional[str] = Field(default=None, description="Section/Heading ID")
    sheetName: Optional[str] = Field(default=None, description="Sheet name (XLSX)")
    slideIndex: Optional[int] = Field(default=None, description="Slide number (PPTX)")


class ContentObject(BaseModel):
    """Scalar content object extracted from a file. No AI involved."""
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    fileId: str = Field(description="FK to the physical file")
    contentType: str = Field(description="text, image, videostream, audiostream, other")
    data: str = Field(default="", description="Content data (text, base64, URL)")
    contextRef: ContentContextRef = Field(default_factory=ContentContextRef)
    metadata: Dict[str, Any] = Field(default_factory=dict)
    sequence: int = Field(default=0, description="Order within the context")


class ContentObjectSummary(BaseModel):
    """Compact description of a content object for the FileContentIndex."""
    id: str = Field(description="Content object ID")
    contentType: str = Field(description="text, image, videostream, audiostream, other")
    contextRef: ContentContextRef = Field(default_factory=ContentContextRef)
    charCount: Optional[int] = Field(default=None, description="Only for text")
    dimensions: Optional[str] = Field(default=None, description="Only for image/video (e.g. '1920x1080')")
    duration: Optional[float] = Field(default=None, description="Only for audio/video (seconds)")


class FileEntry(BaseModel):
    """A file extracted from a container (ZIP, TAR, Folder)."""
    path: str = Field(description="Relative path within the container")
    data: bytes = Field(description="File content bytes")
    mimeType: str = Field(description="Detected MIME type")
    size: int = Field(description="File size in bytes")

58  modules/datamodels/datamodelDataSource.py  Normal file
@@ -0,0 +1,58 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""DataSource and ExternalEntry models for external data integration.

DataSource links a UserConnection to an external path (SharePoint folder,
Google Drive folder, FTP directory, etc.) for agent-accessible data containers.
"""

from typing import Dict, Any, Optional
from pydantic import BaseModel, Field
from modules.shared.attributeUtils import registerModelLabels
from modules.shared.timeUtils import getUtcTimestamp
import uuid


class DataSource(BaseModel):
    """Configured external data source linked to a UserConnection."""
    id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key")
    connectionId: str = Field(description="FK to UserConnection")
    sourceType: str = Field(description="sharepointFolder, googleDriveFolder, outlookFolder, ftpFolder")
    path: str = Field(description="External path (e.g. '/sites/MySite/Documents/Reports')")
    label: str = Field(description="User-visible label")
    featureInstanceId: Optional[str] = Field(default=None, description="Scoped to feature instance")
    mandateId: Optional[str] = Field(default=None, description="Mandate scope")
    userId: str = Field(default="", description="Owner user ID")
    autoSync: bool = Field(default=False, description="Automatically sync on schedule")
    lastSynced: Optional[float] = Field(default=None, description="Last sync timestamp")
    createdAt: float = Field(default_factory=getUtcTimestamp, description="Creation timestamp")


registerModelLabels(
    "DataSource",
    {"en": "Data Source", "de": "Datenquelle", "fr": "Source de données"},
    {
        "id": {"en": "ID", "de": "ID", "fr": "ID"},
        "connectionId": {"en": "Connection ID", "de": "Verbindungs-ID", "fr": "ID de connexion"},
        "sourceType": {"en": "Source Type", "de": "Quellentyp", "fr": "Type de source"},
        "path": {"en": "Path", "de": "Pfad", "fr": "Chemin"},
        "label": {"en": "Label", "de": "Bezeichnung", "fr": "Libellé"},
        "featureInstanceId": {"en": "Feature Instance", "de": "Feature-Instanz", "fr": "Instance de fonctionnalité"},
        "mandateId": {"en": "Mandate ID", "de": "Mandanten-ID", "fr": "ID du mandat"},
        "userId": {"en": "User ID", "de": "Benutzer-ID", "fr": "ID utilisateur"},
        "autoSync": {"en": "Auto Sync", "de": "Auto-Sync", "fr": "Synchro auto"},
        "lastSynced": {"en": "Last Synced", "de": "Letzter Sync", "fr": "Dernier sync"},
        "createdAt": {"en": "Created At", "de": "Erstellt am", "fr": "Créé le"},
    },
)


class ExternalEntry(BaseModel):
    """An item (file or folder) from an external data source."""
    name: str = Field(description="Item name")
    path: str = Field(description="Full path within the source")
    isFolder: bool = Field(default=False, description="True if directory/folder")
    size: Optional[int] = Field(default=None, description="File size in bytes")
    mimeType: Optional[str] = Field(default=None, description="MIME type (files only)")
    lastModified: Optional[float] = Field(default=None, description="Last modification timestamp")
    metadata: Dict[str, Any] = Field(default_factory=dict, description="Provider-specific metadata")
|
|
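# Illustrative sketch (not part of this commit): registering a SharePoint folder as a
# DataSource and the ExternalEntry shape a provider returns when browsing it. The
# connectionId, featureInstanceId, and userId values are placeholders.
reportsSource = DataSource(
    connectionId="conn-7f3a",
    sourceType="sharepointFolder",
    path="/sites/MySite/Documents/Reports",
    label="Monthly Reports",
    featureInstanceId="workspace-instance-1",
    userId="user-42",
    autoSync=True,
)
entry = ExternalEntry(
    name="2025-01.pdf",
    path="/sites/MySite/Documents/Reports/2025-01.pdf",
    size=48213,
    mimeType="application/pdf",
)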
@ -73,7 +73,7 @@ class ExtractionOptions(BaseModel):
     """Options for document extraction and processing with clear data structures."""
 
     # Core extraction parameters
-    prompt: str = Field(description="Extraction prompt for AI processing")
+    prompt: str = Field(default="", description="Extraction prompt for AI processing")
     processDocumentsIndividually: bool = Field(default=True, description="Process each document separately")
 
     # Image processing parameters
@ -81,7 +81,7 @@ class ExtractionOptions(BaseModel):
     imageQuality: int = Field(default=85, ge=1, le=100, description="Image quality (1-100)")
 
     # Merging strategy
-    mergeStrategy: MergeStrategy = Field(description="Strategy for merging extraction results")
+    mergeStrategy: MergeStrategy = Field(default_factory=MergeStrategy, description="Strategy for merging extraction results")
 
     # Optional chunking parameters (for backward compatibility)
     chunkAllowed: Optional[bool] = Field(default=None, description="Whether chunking is allowed")
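# Illustrative note on the change above (assumes MergeStrategy is a pydantic model whose
# fields all have defaults, and that the remaining ExtractionOptions fields do too):
# prompt and mergeStrategy are no longer required, so a bare instance can be built and
# each instance receives its own freshly constructed MergeStrategy via default_factory.
options = ExtractionOptions()
print(options.prompt)          # "" (new default)
print(options.mergeStrategy)   # a default MergeStrategy built by default_factory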
32  modules/datamodels/datamodelFileFolder.py  Normal file
@ -0,0 +1,32 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""FileFolder: hierarchical folder structure for file organization."""

from typing import Optional
from pydantic import BaseModel, Field
from modules.shared.attributeUtils import registerModelLabels
from modules.shared.timeUtils import getUtcTimestamp
import uuid


class FileFolder(BaseModel):
    id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
    name: str = Field(description="Folder name", json_schema_extra={"frontend_type": "text", "frontend_readonly": False, "frontend_required": True})
    parentId: Optional[str] = Field(default=None, description="Parent folder ID (null = root)", json_schema_extra={"frontend_type": "text", "frontend_readonly": False, "frontend_required": False})
    mandateId: Optional[str] = Field(default=None, description="Mandate context", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
    featureInstanceId: Optional[str] = Field(default=None, description="Feature instance context", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
    createdAt: float = Field(default_factory=getUtcTimestamp, description="Creation timestamp", json_schema_extra={"frontend_type": "timestamp", "frontend_readonly": True, "frontend_required": False})


registerModelLabels(
    "FileFolder",
    {"en": "File Folder", "fr": "Dossier de fichiers"},
    {
        "id": {"en": "ID", "fr": "ID"},
        "name": {"en": "Name", "fr": "Nom"},
        "parentId": {"en": "Parent Folder", "fr": "Dossier parent"},
        "mandateId": {"en": "Mandate ID", "fr": "ID du mandat"},
        "featureInstanceId": {"en": "Feature Instance ID", "fr": "ID de l'instance"},
        "createdAt": {"en": "Created At", "fr": "Créé le"},
    },
)
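# Illustrative sketch (not part of this commit): resolving the display path of a
# FileFolder by walking parentId up to the root. foldersById is a hypothetical
# in-memory lookup; the real code would query the database instead.
from typing import Dict


def folderPath(folderId: str, foldersById: Dict[str, FileFolder]) -> str:
    parts = []
    current = foldersById.get(folderId)
    while current is not None:
        parts.append(current.name)
        current = foldersById.get(current.parentId) if current.parentId else None
    return "/" + "/".join(reversed(parts))


root = FileFolder(name="Projects")
child = FileFolder(name="2025", parentId=root.id)
print(folderPath(child.id, {root.id: root, child.id: child}))  # /Projects/2025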
@ -2,7 +2,7 @@
 # All rights reserved.
 """File-related datamodels: FileItem, FilePreview, FileData."""
 
-from typing import Dict, Any, Optional, Union
+from typing import Dict, Any, List, Optional, Union
 from pydantic import BaseModel, ConfigDict, Field
 from modules.shared.attributeUtils import registerModelLabels
 from modules.shared.timeUtils import getUtcTimestamp
@ -20,6 +20,10 @@ class FileItem(BaseModel):
     fileHash: str = Field(description="Hash of the file", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
     fileSize: int = Field(description="Size of the file in bytes", json_schema_extra={"frontend_type": "integer", "frontend_readonly": True, "frontend_required": False})
     creationDate: float = Field(default_factory=getUtcTimestamp, description="Date when the file was created (UTC timestamp in seconds)", json_schema_extra={"frontend_type": "timestamp", "frontend_readonly": True, "frontend_required": False})
+    tags: Optional[List[str]] = Field(default=None, description="Tags for categorization and search", json_schema_extra={"frontend_type": "tags", "frontend_readonly": False, "frontend_required": False})
+    folderId: Optional[str] = Field(default=None, description="ID of the parent folder", json_schema_extra={"frontend_type": "text", "frontend_readonly": False, "frontend_required": False})
+    description: Optional[str] = Field(default=None, description="User-provided description of the file", json_schema_extra={"frontend_type": "textarea", "frontend_readonly": False, "frontend_required": False})
+    status: Optional[str] = Field(default=None, description="Processing status: pending, extracted, embedding, indexed, failed", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
 
 registerModelLabels(
     "FileItem",
@ -33,6 +37,10 @@ registerModelLabels(
         "fileHash": {"en": "File Hash", "fr": "Hash du fichier"},
         "fileSize": {"en": "File Size", "fr": "Taille du fichier"},
         "creationDate": {"en": "Creation Date", "fr": "Date de création"},
+        "tags": {"en": "Tags", "fr": "Tags"},
+        "folderId": {"en": "Folder ID", "fr": "ID du dossier"},
+        "description": {"en": "Description", "fr": "Description"},
+        "status": {"en": "Status", "fr": "Statut"},
     },
 )
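# Illustrative sketch (not part of this commit): the new tags/status fields on FileItem
# make it easy to select files that have finished the indexing pipeline. The helper name
# readyForRetrieval is hypothetical; FileItem is the model extended in the hunk above.
from typing import List


def readyForRetrieval(items: List[FileItem], tag: str) -> List[FileItem]:
    """Select files that are fully indexed and carry a given tag."""
    return [f for f in items if f.status == "indexed" and f.tags and tag in f.tags]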
130  modules/datamodels/datamodelKnowledge.py  Normal file
@ -0,0 +1,130 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Knowledge Store data models: FileContentIndex, ContentChunk, WorkflowMemory.

These models support the 3-tier RAG architecture:
- Shared Layer: mandateId-scoped, isShared=True
- Instance Layer: userId + featureInstanceId-scoped
- Workflow Layer: workflowId-scoped (WorkflowMemory)

Vector fields use json_schema_extra={"db_type": "vector(1536)"} for pgvector.
"""

from typing import Dict, Any, List, Optional
from pydantic import BaseModel, Field
from modules.shared.attributeUtils import registerModelLabels
from modules.shared.timeUtils import getUtcTimestamp
import uuid


class FileContentIndex(BaseModel):
    """Structural index of a file's content objects. Created without AI.

    Lives in the Instance Layer; optionally promoted to Shared Layer via isShared."""
    id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key (typically = fileId)")
    userId: str = Field(description="Owner user ID")
    featureInstanceId: str = Field(default="", description="Feature instance scope")
    mandateId: str = Field(default="", description="Mandate scope")
    isShared: bool = Field(default=False, description="Visible in Shared Layer for all mandate users")
    fileName: str = Field(description="Original file name")
    mimeType: str = Field(description="MIME type of the file")
    containerPath: Optional[str] = Field(default=None, description="Path within a container (e.g. 'archive.zip/folder/report.pdf')")
    totalObjects: int = Field(default=0, description="Total number of content objects extracted")
    totalSize: int = Field(default=0, description="Total size of all content objects in bytes")
    structure: Dict[str, Any] = Field(default_factory=dict, description="Structural overview (pages, sections, hierarchy)")
    objectSummary: List[Dict[str, Any]] = Field(default_factory=list, description="Compact summary per content object")
    extractedAt: float = Field(default_factory=getUtcTimestamp, description="Extraction timestamp")
    status: str = Field(default="pending", description="Processing status: pending, extracted, embedding, indexed, failed")


registerModelLabels(
    "FileContentIndex",
    {"en": "File Content Index", "fr": "Index du contenu de fichier"},
    {
        "id": {"en": "ID", "fr": "ID"},
        "userId": {"en": "User ID", "fr": "ID utilisateur"},
        "featureInstanceId": {"en": "Feature Instance ID", "fr": "ID de l'instance"},
        "mandateId": {"en": "Mandate ID", "fr": "ID du mandat"},
        "isShared": {"en": "Shared", "fr": "Partagé"},
        "fileName": {"en": "File Name", "fr": "Nom de fichier"},
        "mimeType": {"en": "MIME Type", "fr": "Type MIME"},
        "containerPath": {"en": "Container Path", "fr": "Chemin du conteneur"},
        "totalObjects": {"en": "Total Objects", "fr": "Nombre total d'objets"},
        "totalSize": {"en": "Total Size", "fr": "Taille totale"},
        "structure": {"en": "Structure", "fr": "Structure"},
        "objectSummary": {"en": "Object Summary", "fr": "Résumé des objets"},
        "extractedAt": {"en": "Extracted At", "fr": "Extrait le"},
        "status": {"en": "Status", "fr": "Statut"},
    },
)


class ContentChunk(BaseModel):
    """Persisted content chunk with embedding vector. Reusable across workflows.

    Scalar content object (or chunk thereof) with pgvector embedding."""
    id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key")
    contentObjectId: str = Field(description="Reference to the content object within FileContentIndex")
    fileId: str = Field(description="FK to the source file")
    userId: str = Field(description="Owner user ID")
    featureInstanceId: str = Field(default="", description="Feature instance scope")
    contentType: str = Field(description="Content type: text, image, videostream, audiostream, other")
    data: str = Field(description="Content data (text, base64, URL)")
    contextRef: Dict[str, Any] = Field(default_factory=dict, description="Context reference (page, position, label)")
    summary: Optional[str] = Field(default=None, description="AI-generated summary (on demand)")
    chunkMetadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata")
    embedding: Optional[List[float]] = Field(
        default=None, description="pgvector embedding (NOT NULL for text chunks)",
        json_schema_extra={"db_type": "vector(1536)"}
    )


registerModelLabels(
    "ContentChunk",
    {"en": "Content Chunk", "fr": "Fragment de contenu"},
    {
        "id": {"en": "ID", "fr": "ID"},
        "contentObjectId": {"en": "Content Object ID", "fr": "ID de l'objet de contenu"},
        "fileId": {"en": "File ID", "fr": "ID du fichier"},
        "userId": {"en": "User ID", "fr": "ID utilisateur"},
        "featureInstanceId": {"en": "Feature Instance ID", "fr": "ID de l'instance"},
        "contentType": {"en": "Content Type", "fr": "Type de contenu"},
        "data": {"en": "Data", "fr": "Données"},
        "contextRef": {"en": "Context Reference", "fr": "Référence contextuelle"},
        "summary": {"en": "Summary", "fr": "Résumé"},
        "chunkMetadata": {"en": "Metadata", "fr": "Métadonnées"},
        "embedding": {"en": "Embedding", "fr": "Vecteur d'embedding"},
    },
)


class WorkflowMemory(BaseModel):
    """Workflow-scoped key-value cache for entities and facts.

    Extracted during agent rounds, persisted for cross-round and cross-workflow reuse."""
    id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key")
    workflowId: str = Field(description="FK to the workflow")
    userId: str = Field(description="Owner user ID")
    featureInstanceId: str = Field(default="", description="Feature instance scope")
    key: str = Field(description="Key identifier (e.g. 'entity:companyName')")
    value: str = Field(description="Extracted value")
    source: str = Field(default="extraction", description="Origin: extraction, tool, conversation, summary")
    createdAt: float = Field(default_factory=getUtcTimestamp, description="Creation timestamp")
    embedding: Optional[List[float]] = Field(
        default=None, description="Optional embedding for semantic lookup",
        json_schema_extra={"db_type": "vector(1536)"}
    )


registerModelLabels(
    "WorkflowMemory",
    {"en": "Workflow Memory", "fr": "Mémoire de workflow"},
    {
        "id": {"en": "ID", "fr": "ID"},
        "workflowId": {"en": "Workflow ID", "fr": "ID du workflow"},
        "userId": {"en": "User ID", "fr": "ID utilisateur"},
        "featureInstanceId": {"en": "Feature Instance ID", "fr": "ID de l'instance"},
        "key": {"en": "Key", "fr": "Clé"},
        "value": {"en": "Value", "fr": "Valeur"},
        "source": {"en": "Source", "fr": "Source"},
        "createdAt": {"en": "Created At", "fr": "Créé le"},
        "embedding": {"en": "Embedding", "fr": "Vecteur d'embedding"},
    },
)
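# Illustrative sketch (not part of this commit): ranking ContentChunk rows by cosine
# similarity to a query embedding. In production the comparison would run in PostgreSQL
# via pgvector (the vector(1536) columns declared above); this in-memory version only
# shows the intent. "chunks" and "queryEmbedding" are placeholders supplied by the caller.
import math
from typing import List, Tuple


def cosine(a: List[float], b: List[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def topChunks(chunks: List[ContentChunk], queryEmbedding: List[float], k: int = 5) -> List[Tuple[float, ContentChunk]]:
    """Return the k most similar text chunks, highest score first."""
    scored = [(cosine(c.embedding, queryEmbedding), c) for c in chunks if c.embedding]
    return sorted(scored, key=lambda pair: pair[0], reverse=True)[:k]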
@ -180,7 +180,7 @@ def getAutomationServices(
     for spec in REQUIRED_SERVICES:
         key = spec["serviceKey"]
         try:
-            svc = getService(key, ctx, legacy_hub=None)
+            svc = getService(key, ctx)
             setattr(hub, key, svc)
         except Exception as e:
             logger.warning(f"Could not resolve service '{key}' for automation: {e}")
@ -179,7 +179,7 @@ def getChatbotServices(
     for spec in REQUIRED_SERVICES:
         key = spec["serviceKey"]
         try:
-            svc = getService(key, ctx, legacy_hub=None)
+            svc = getService(key, ctx)
             setattr(hub, key, svc)
         except Exception as e:
             logger.warning(f"Could not resolve service '{key}' for chatbot: {e}")
@ -197,7 +197,7 @@ def getChatStreamingHelper():
     from modules.serviceCenter.context import ServiceCenterContext
     # Minimal context - streaming service only needs it for resolver
     ctx = ServiceCenterContext(user=__get_placeholder_user(), mandate_id=None, feature_instance_id=None)
-    streaming = getService("streaming", ctx, legacy_hub=None)
+    streaming = getService("streaming", ctx)
     return streaming.getChatStreamingHelper() if streaming else None
@ -219,7 +219,7 @@ def getEventManager(user, mandateId: Optional[str] = None, featureInstanceId: Op
         mandate_id=mandateId,
         feature_instance_id=featureInstanceId,
     )
-    streaming = getService("streaming", ctx, legacy_hub=None)
+    streaming = getService("streaming", ctx)
     return streaming.getEventManager()
@ -344,7 +344,7 @@ def getChatbotServices(
             feature_instance_id=featureInstanceId,
             workflow=_workflow,
         )
-        hub.billing = getService("billing", ctx, legacy_hub=None)
+        hub.billing = getService("billing", ctx)
     except Exception as e:
         logger.warning(f"Could not resolve billing service for chatbot: {e}")
         hub.billing = None
@ -158,7 +158,7 @@ def getChatplaygroundServices(
     for spec in REQUIRED_SERVICES:
         key = spec["serviceKey"]
         try:
-            svc = getService(key, ctx, legacy_hub=None)
+            svc = getService(key, ctx)
             setattr(hub, key, svc)
         except Exception as e:
             logger.warning(f"Could not resolve service '{key}' for chatplayground: {e}")
3  modules/features/workspace/__init__.py  Normal file
@ -0,0 +1,3 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Unified AI Workspace feature -- merges Codeeditor, Chatbot, and Playground."""
248  modules/features/workspace/mainWorkspace.py  Normal file
@ -0,0 +1,248 @@
|
||||||
|
# Copyright (c) 2025 Patrick Motsch
|
||||||
|
# All rights reserved.
|
||||||
|
"""
|
||||||
|
Workspace Feature Container - Main Module.
|
||||||
|
Handles feature initialization and RBAC catalog registration.
|
||||||
|
Unified AI Workspace combining Codeeditor, Chatbot, and Playground capabilities.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Dict, List, Any
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
FEATURE_CODE = "workspace"
|
||||||
|
FEATURE_LABEL = {"en": "AI Workspace", "de": "AI Workspace", "fr": "AI Workspace"}
|
||||||
|
FEATURE_ICON = "mdi-brain"
|
||||||
|
|
||||||
|
UI_OBJECTS = [
|
||||||
|
{
|
||||||
|
"objectKey": "ui.feature.workspace.dashboard",
|
||||||
|
"label": {"en": "Dashboard", "de": "Dashboard", "fr": "Tableau de bord"},
|
||||||
|
"meta": {"area": "dashboard"}
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
RESOURCE_OBJECTS = [
|
||||||
|
{
|
||||||
|
"objectKey": "resource.feature.workspace.start",
|
||||||
|
"label": {"en": "Start Agent", "de": "Agent starten", "fr": "Demarrer agent"},
|
||||||
|
"meta": {"endpoint": "/api/workspace/{instanceId}/start/stream", "method": "POST"}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"objectKey": "resource.feature.workspace.stop",
|
||||||
|
"label": {"en": "Stop Agent", "de": "Agent stoppen", "fr": "Arreter agent"},
|
||||||
|
"meta": {"endpoint": "/api/workspace/{instanceId}/{workflowId}/stop", "method": "POST"}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"objectKey": "resource.feature.workspace.files",
|
||||||
|
"label": {"en": "Manage Files", "de": "Dateien verwalten", "fr": "Gerer fichiers"},
|
||||||
|
"meta": {"endpoint": "/api/workspace/{instanceId}/files", "method": "GET"}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"objectKey": "resource.feature.workspace.folders",
|
||||||
|
"label": {"en": "Manage Folders", "de": "Ordner verwalten", "fr": "Gerer dossiers"},
|
||||||
|
"meta": {"endpoint": "/api/workspace/{instanceId}/folders", "method": "GET"}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"objectKey": "resource.feature.workspace.datasources",
|
||||||
|
"label": {"en": "Data Sources", "de": "Datenquellen", "fr": "Sources de donnees"},
|
||||||
|
"meta": {"endpoint": "/api/workspace/{instanceId}/datasources", "method": "GET"}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"objectKey": "resource.feature.workspace.voice",
|
||||||
|
"label": {"en": "Voice Input/Output", "de": "Spracheingabe/-ausgabe", "fr": "Entree/sortie vocale"},
|
||||||
|
"meta": {"endpoint": "/api/workspace/{instanceId}/voice/*", "method": "POST"}
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
TEMPLATE_ROLES = [
|
||||||
|
{
|
||||||
|
"roleLabel": "workspace-viewer",
|
||||||
|
"description": {
|
||||||
|
"en": "Workspace Viewer - View workspace (read-only)",
|
||||||
|
"de": "Workspace Betrachter - Workspace ansehen (nur lesen)",
|
||||||
|
"fr": "Visualiseur Workspace - Consulter le workspace (lecture seule)"
|
||||||
|
},
|
||||||
|
"accessRules": [
|
||||||
|
{"context": "UI", "item": "ui.feature.workspace.dashboard", "view": True},
|
||||||
|
{"context": "DATA", "item": None, "view": True, "read": "m", "create": "n", "update": "n", "delete": "n"},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"roleLabel": "workspace-user",
|
||||||
|
"description": {
|
||||||
|
"en": "Workspace User - Use AI workspace and tools",
|
||||||
|
"de": "Workspace Benutzer - AI Workspace und Tools nutzen",
|
||||||
|
"fr": "Utilisateur Workspace - Utiliser l'espace de travail AI et les outils"
|
||||||
|
},
|
||||||
|
"accessRules": [
|
||||||
|
{"context": "UI", "item": "ui.feature.workspace.dashboard", "view": True},
|
||||||
|
{"context": "RESOURCE", "item": "resource.feature.workspace.start", "view": True},
|
||||||
|
{"context": "RESOURCE", "item": "resource.feature.workspace.stop", "view": True},
|
||||||
|
{"context": "RESOURCE", "item": "resource.feature.workspace.files", "view": True},
|
||||||
|
{"context": "RESOURCE", "item": "resource.feature.workspace.folders", "view": True},
|
||||||
|
{"context": "RESOURCE", "item": "resource.feature.workspace.datasources", "view": True},
|
||||||
|
{"context": "RESOURCE", "item": "resource.feature.workspace.voice", "view": True},
|
||||||
|
{"context": "DATA", "item": None, "view": True, "read": "m", "create": "m", "update": "m", "delete": "m"},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"roleLabel": "workspace-admin",
|
||||||
|
"description": {
|
||||||
|
"en": "Workspace Admin - Full access to AI workspace",
|
||||||
|
"de": "Workspace Admin - Vollzugriff auf AI Workspace",
|
||||||
|
"fr": "Administrateur Workspace - Acces complet au workspace AI"
|
||||||
|
},
|
||||||
|
"accessRules": [
|
||||||
|
{"context": "UI", "item": None, "view": True},
|
||||||
|
{"context": "RESOURCE", "item": None, "view": True},
|
||||||
|
{"context": "DATA", "item": None, "view": True, "read": "a", "create": "a", "update": "a", "delete": "a"},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def getFeatureDefinition() -> Dict[str, Any]:
|
||||||
|
"""Return the feature definition for registration."""
|
||||||
|
return {
|
||||||
|
"code": FEATURE_CODE,
|
||||||
|
"label": FEATURE_LABEL,
|
||||||
|
"icon": FEATURE_ICON,
|
||||||
|
"autoCreateInstance": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def getUiObjects() -> List[Dict[str, Any]]:
|
||||||
|
"""Return UI objects for RBAC catalog registration."""
|
||||||
|
return UI_OBJECTS
|
||||||
|
|
||||||
|
|
||||||
|
def getResourceObjects() -> List[Dict[str, Any]]:
|
||||||
|
"""Return resource objects for RBAC catalog registration."""
|
||||||
|
return RESOURCE_OBJECTS
|
||||||
|
|
||||||
|
|
||||||
|
def getTemplateRoles() -> List[Dict[str, Any]]:
|
||||||
|
"""Return template roles for this feature."""
|
||||||
|
return TEMPLATE_ROLES
|
||||||
|
|
||||||
|
|
||||||
|
def registerFeature(catalogService) -> bool:
|
||||||
|
"""Register this feature's RBAC objects in the catalog."""
|
||||||
|
try:
|
||||||
|
for uiObj in UI_OBJECTS:
|
||||||
|
catalogService.registerUiObject(
|
||||||
|
featureCode=FEATURE_CODE,
|
||||||
|
objectKey=uiObj["objectKey"],
|
||||||
|
label=uiObj["label"],
|
||||||
|
meta=uiObj.get("meta")
|
||||||
|
)
|
||||||
|
|
||||||
|
for resObj in RESOURCE_OBJECTS:
|
||||||
|
catalogService.registerResourceObject(
|
||||||
|
featureCode=FEATURE_CODE,
|
||||||
|
objectKey=resObj["objectKey"],
|
||||||
|
label=resObj["label"],
|
||||||
|
meta=resObj.get("meta")
|
||||||
|
)
|
||||||
|
|
||||||
|
_syncTemplateRolesToDb()
|
||||||
|
|
||||||
|
logger.info(f"Feature '{FEATURE_CODE}' registered {len(UI_OBJECTS)} UI objects and {len(RESOURCE_OBJECTS)} resource objects")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to register feature '{FEATURE_CODE}': {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _syncTemplateRolesToDb() -> int:
|
||||||
|
"""Sync template roles and their AccessRules to the database."""
|
||||||
|
try:
|
||||||
|
from modules.interfaces.interfaceDbApp import getRootInterface
|
||||||
|
from modules.datamodels.datamodelRbac import Role, AccessRule, AccessRuleContext
|
||||||
|
|
||||||
|
rootInterface = getRootInterface()
|
||||||
|
|
||||||
|
existingRoles = rootInterface.getRolesByFeatureCode(FEATURE_CODE)
|
||||||
|
templateRoles = [r for r in existingRoles if r.mandateId is None]
|
||||||
|
existingRoleLabels = {r.roleLabel: str(r.id) for r in templateRoles}
|
||||||
|
|
||||||
|
createdCount = 0
|
||||||
|
for roleTemplate in TEMPLATE_ROLES:
|
||||||
|
roleLabel = roleTemplate["roleLabel"]
|
||||||
|
|
||||||
|
if roleLabel in existingRoleLabels:
|
||||||
|
roleId = existingRoleLabels[roleLabel]
|
||||||
|
_ensureAccessRulesForRole(rootInterface, roleId, roleTemplate.get("accessRules", []))
|
||||||
|
else:
|
||||||
|
newRole = Role(
|
||||||
|
roleLabel=roleLabel,
|
||||||
|
description=roleTemplate.get("description", {}),
|
||||||
|
featureCode=FEATURE_CODE,
|
||||||
|
mandateId=None,
|
||||||
|
featureInstanceId=None,
|
||||||
|
isSystemRole=False
|
||||||
|
)
|
||||||
|
createdRole = rootInterface.db.recordCreate(Role, newRole.model_dump())
|
||||||
|
roleId = createdRole.get("id")
|
||||||
|
_ensureAccessRulesForRole(rootInterface, roleId, roleTemplate.get("accessRules", []))
|
||||||
|
logger.info(f"Created template role '{roleLabel}' with ID {roleId}")
|
||||||
|
createdCount += 1
|
||||||
|
|
||||||
|
if createdCount > 0:
|
||||||
|
logger.info(f"Feature '{FEATURE_CODE}': Created {createdCount} template roles")
|
||||||
|
|
||||||
|
return createdCount
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error syncing template roles for feature '{FEATURE_CODE}': {e}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def _ensureAccessRulesForRole(rootInterface, roleId: str, ruleTemplates: List[Dict[str, Any]]) -> int:
|
||||||
|
"""Ensure AccessRules exist for a role based on templates."""
|
||||||
|
from modules.datamodels.datamodelRbac import AccessRule, AccessRuleContext
|
||||||
|
|
||||||
|
existingRules = rootInterface.getAccessRulesByRole(roleId)
|
||||||
|
existingSignatures = set()
|
||||||
|
for rule in existingRules:
|
||||||
|
sig = (rule.context.value if rule.context else None, rule.item)
|
||||||
|
existingSignatures.add(sig)
|
||||||
|
|
||||||
|
createdCount = 0
|
||||||
|
for template in ruleTemplates:
|
||||||
|
context = template.get("context", "UI")
|
||||||
|
item = template.get("item")
|
||||||
|
sig = (context, item)
|
||||||
|
|
||||||
|
if sig in existingSignatures:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if context == "UI":
|
||||||
|
contextEnum = AccessRuleContext.UI
|
||||||
|
elif context == "DATA":
|
||||||
|
contextEnum = AccessRuleContext.DATA
|
||||||
|
elif context == "RESOURCE":
|
||||||
|
contextEnum = AccessRuleContext.RESOURCE
|
||||||
|
else:
|
||||||
|
contextEnum = context
|
||||||
|
|
||||||
|
newRule = AccessRule(
|
||||||
|
roleId=roleId,
|
||||||
|
context=contextEnum,
|
||||||
|
item=item,
|
||||||
|
view=template.get("view", False),
|
||||||
|
read=template.get("read"),
|
||||||
|
create=template.get("create"),
|
||||||
|
update=template.get("update"),
|
||||||
|
delete=template.get("delete"),
|
||||||
|
)
|
||||||
|
rootInterface.db.recordCreate(AccessRule, newRule.model_dump())
|
||||||
|
createdCount += 1
|
||||||
|
|
||||||
|
if createdCount > 0:
|
||||||
|
logger.debug(f"Created {createdCount} AccessRules for role {roleId}")
|
||||||
|
|
||||||
|
return createdCount
|
||||||
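# Illustrative sketch (not part of this commit): the dedup idea used by
# _ensureAccessRulesForRole above, reduced to plain data. Existing rules are keyed by a
# (context, item) signature, so re-running registerFeature() never duplicates AccessRules.
existing = [("UI", "ui.feature.workspace.dashboard"), ("DATA", None)]
templates = [
    {"context": "UI", "item": "ui.feature.workspace.dashboard", "view": True},
    {"context": "RESOURCE", "item": "resource.feature.workspace.start", "view": True},
]
signatures = set(existing)
toCreate = [t for t in templates if (t["context"], t["item"]) not in signatures]
print([t["item"] for t in toCreate])  # only the missing RESOURCE rule would be created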
720  modules/features/workspace/routeFeatureWorkspace.py  Normal file
@ -0,0 +1,720 @@
|
||||||
|
# Copyright (c) 2025 Patrick Motsch
|
||||||
|
# All rights reserved.
|
||||||
|
"""Unified AI Workspace routes.
|
||||||
|
|
||||||
|
SSE-based endpoints that combine the capabilities of Codeeditor, Chatbot,
|
||||||
|
and Playground into a single agent-driven workspace.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import json
|
||||||
|
import asyncio
|
||||||
|
from typing import Optional, List
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, Depends, Body, Path, Query, Request, UploadFile, File
|
||||||
|
from fastapi.responses import StreamingResponse, JSONResponse
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from modules.auth import limiter, getRequestContext, RequestContext
|
||||||
|
from modules.interfaces import interfaceDbChat, interfaceDbManagement
|
||||||
|
from modules.interfaces.interfaceAiObjects import AiObjects
|
||||||
|
from modules.serviceCenter.core.serviceStreaming import get_event_manager
|
||||||
|
from modules.serviceCenter.services.serviceAgent.datamodelAgent import AgentEventTypeEnum
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(
|
||||||
|
prefix="/api/workspace",
|
||||||
|
tags=["Unified Workspace"],
|
||||||
|
responses={404: {"description": "Not found"}},
|
||||||
|
)
|
||||||
|
|
||||||
|
_aiObjects: Optional[AiObjects] = None
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspaceInputRequest(BaseModel):
|
||||||
|
"""Prompt input for the unified workspace."""
|
||||||
|
prompt: str = Field(description="User prompt text")
|
||||||
|
fileIds: List[str] = Field(default_factory=list, description="Referenced file IDs")
|
||||||
|
uploadedFiles: List[str] = Field(default_factory=list, description="Newly uploaded file IDs")
|
||||||
|
dataSourceIds: List[str] = Field(default_factory=list, description="Active DataSource IDs")
|
||||||
|
voiceMode: bool = Field(default=False, description="Enable voice response")
|
||||||
|
workflowId: Optional[str] = Field(default=None, description="Continue existing workflow")
|
||||||
|
userLanguage: str = Field(default="en", description="User language code")
|
||||||
|
|
||||||
|
|
||||||
|
async def _getAiObjects() -> AiObjects:
|
||||||
|
global _aiObjects
|
||||||
|
if _aiObjects is None:
|
||||||
|
_aiObjects = await AiObjects.create()
|
||||||
|
return _aiObjects
|
||||||
|
|
||||||
|
|
||||||
|
def _validateInstanceAccess(instanceId: str, context: RequestContext) -> str:
|
||||||
|
from modules.interfaces.interfaceDbApp import getRootInterface
|
||||||
|
rootInterface = getRootInterface()
|
||||||
|
instance = rootInterface.getFeatureInstance(instanceId)
|
||||||
|
if not instance:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Feature instance {instanceId} not found")
|
||||||
|
featureAccess = rootInterface.getFeatureAccess(str(context.user.id), instanceId)
|
||||||
|
if not featureAccess or not featureAccess.enabled:
|
||||||
|
raise HTTPException(status_code=403, detail="Access denied to this feature instance")
|
||||||
|
return str(instance.mandateId) if instance.mandateId else None
|
||||||
|
|
||||||
|
|
||||||
|
def _getChatInterface(context: RequestContext, featureInstanceId: str = None):
|
||||||
|
return interfaceDbChat.getInterface(
|
||||||
|
context.user,
|
||||||
|
mandateId=str(context.mandateId) if context.mandateId else None,
|
||||||
|
featureInstanceId=featureInstanceId,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _buildResolverDbInterface(chatService):
|
||||||
|
"""Build a DB adapter that ConnectorResolver can use to load UserConnections.
|
||||||
|
|
||||||
|
ConnectorResolver calls db.getUserConnection(connectionId).
|
||||||
|
interfaceDbApp provides getUserConnectionById(connectionId).
|
||||||
|
This adapter bridges the method name difference.
|
||||||
|
"""
|
||||||
|
class _ResolverDbAdapter:
|
||||||
|
def __init__(self, appInterface):
|
||||||
|
self._app = appInterface
|
||||||
|
def getUserConnection(self, connectionId: str):
|
||||||
|
if hasattr(self._app, "getUserConnectionById"):
|
||||||
|
return self._app.getUserConnectionById(connectionId)
|
||||||
|
return None
|
||||||
|
appIf = getattr(chatService, "interfaceDbApp", None)
|
||||||
|
if appIf:
|
||||||
|
return _ResolverDbAdapter(appIf)
|
||||||
|
return getattr(chatService, "interfaceDbComponent", None)
|
||||||
|
|
||||||
|
|
||||||
|
def _getDbManagement(context: RequestContext, featureInstanceId: str = None):
|
||||||
|
return interfaceDbManagement.getInterface(
|
||||||
|
context.user,
|
||||||
|
mandateId=str(context.mandateId) if context.mandateId else None,
|
||||||
|
featureInstanceId=featureInstanceId,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# SSE Stream endpoint
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.post("/{instanceId}/start/stream")
|
||||||
|
@limiter.limit("60/minute")
|
||||||
|
async def streamWorkspaceStart(
|
||||||
|
request: Request,
|
||||||
|
instanceId: str = Path(..., description="Feature instance ID"),
|
||||||
|
userInput: WorkspaceInputRequest = Body(...),
|
||||||
|
context: RequestContext = Depends(getRequestContext),
|
||||||
|
):
|
||||||
|
"""Start or continue a Workspace session with SSE streaming via serviceAgent."""
|
||||||
|
mandateId = _validateInstanceAccess(instanceId, context)
|
||||||
|
chatInterface = _getChatInterface(context, featureInstanceId=instanceId)
|
||||||
|
aiObjects = await _getAiObjects()
|
||||||
|
eventManager = get_event_manager()
|
||||||
|
|
||||||
|
if userInput.workflowId:
|
||||||
|
workflow = chatInterface.getWorkflow(userInput.workflowId)
|
||||||
|
if not workflow:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Workflow {userInput.workflowId} not found")
|
||||||
|
else:
|
||||||
|
existingWorkflows = chatInterface.getWorkflows() or []
|
||||||
|
nextNum = len(existingWorkflows) + 1
|
||||||
|
workflow = chatInterface.createWorkflow({
|
||||||
|
"featureInstanceId": instanceId,
|
||||||
|
"status": "active",
|
||||||
|
"name": f"Chat {nextNum}",
|
||||||
|
"workflowMode": "Dynamic",
|
||||||
|
})
|
||||||
|
|
||||||
|
workflowId = workflow.get("id") if isinstance(workflow, dict) else getattr(workflow, "id", str(workflow))
|
||||||
|
queueId = f"workspace-{workflowId}"
|
||||||
|
eventManager.create_queue(queueId)
|
||||||
|
|
||||||
|
chatInterface.createMessage({
|
||||||
|
"workflowId": workflowId,
|
||||||
|
"role": "user",
|
||||||
|
"message": userInput.prompt,
|
||||||
|
})
|
||||||
|
|
||||||
|
asyncio.ensure_future(
|
||||||
|
_runWorkspaceAgent(
|
||||||
|
workflowId=workflowId,
|
||||||
|
queueId=queueId,
|
||||||
|
prompt=userInput.prompt,
|
||||||
|
fileIds=userInput.fileIds,
|
||||||
|
dataSourceIds=userInput.dataSourceIds,
|
||||||
|
voiceMode=userInput.voiceMode,
|
||||||
|
instanceId=instanceId,
|
||||||
|
user=context.user,
|
||||||
|
mandateId=mandateId or "",
|
||||||
|
aiObjects=aiObjects,
|
||||||
|
chatInterface=chatInterface,
|
||||||
|
eventManager=eventManager,
|
||||||
|
userLanguage=userInput.userLanguage,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _sseGenerator():
|
||||||
|
queue = eventManager.get_queue(queueId)
|
||||||
|
if not queue:
|
||||||
|
return
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
event = await asyncio.wait_for(queue.get(), timeout=120)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
yield "data: {\"type\": \"keepalive\"}\n\n"
|
||||||
|
continue
|
||||||
|
|
||||||
|
if event is None:
|
||||||
|
break
|
||||||
|
|
||||||
|
ssePayload = event.get("data", event) if isinstance(event, dict) else event
|
||||||
|
yield f"data: {json.dumps(ssePayload, default=str)}\n\n"
|
||||||
|
|
||||||
|
eventType = ssePayload.get("type", "") if isinstance(ssePayload, dict) else ""
|
||||||
|
if eventType in ("complete", "error", "stopped"):
|
||||||
|
break
|
||||||
|
|
||||||
|
await eventManager.cleanup(queueId, delay=30)
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
_sseGenerator(),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
headers={
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"Connection": "keep-alive",
|
||||||
|
"X-Accel-Buffering": "no",
|
||||||
|
},
|
||||||
|
)
|
||||||
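# Illustrative sketch (not part of this commit): a minimal client for the SSE endpoint
# above, written with httpx. The base URL, instance ID, and auth header are placeholders;
# the event shapes follow the payloads emitted by _sseGenerator (keepalive, message,
# complete, error, stopped).
import httpx
import json


async def followWorkspaceStream(instanceId: str, prompt: str) -> None:
    url = f"https://example.invalid/api/workspace/{instanceId}/start/stream"
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "POST", url, json={"prompt": prompt},
            headers={"Authorization": "Bearer <token>"},
        ) as resp:
            async for line in resp.aiter_lines():
                if not line.startswith("data: "):
                    continue
                event = json.loads(line[len("data: "):])
                print(event.get("type"), event.get("content", ""))
                if event.get("type") in ("complete", "error", "stopped"):
                    break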
|
|
||||||
|
|
||||||
|
async def _runWorkspaceAgent(
|
||||||
|
workflowId: str,
|
||||||
|
queueId: str,
|
||||||
|
prompt: str,
|
||||||
|
fileIds: List[str],
|
||||||
|
dataSourceIds: List[str],
|
||||||
|
voiceMode: bool,
|
||||||
|
instanceId: str,
|
||||||
|
user,
|
||||||
|
mandateId: str,
|
||||||
|
aiObjects,
|
||||||
|
chatInterface,
|
||||||
|
eventManager,
|
||||||
|
userLanguage: str = "en",
|
||||||
|
):
|
||||||
|
"""Run the serviceAgent loop and forward events to the SSE queue."""
|
||||||
|
try:
|
||||||
|
from modules.serviceCenter import getService
|
||||||
|
from modules.serviceCenter.context import ServiceCenterContext
|
||||||
|
ctx = ServiceCenterContext(
|
||||||
|
user=user,
|
||||||
|
mandate_id=mandateId,
|
||||||
|
feature_instance_id=instanceId,
|
||||||
|
workflow_id=workflowId,
|
||||||
|
)
|
||||||
|
agentService = getService("agent", ctx)
|
||||||
|
|
||||||
|
async for event in agentService.runAgent(
|
||||||
|
prompt=prompt,
|
||||||
|
fileIds=fileIds,
|
||||||
|
workflowId=workflowId,
|
||||||
|
userLanguage=userLanguage,
|
||||||
|
):
|
||||||
|
sseEvent = {
|
||||||
|
"type": event.type.value if hasattr(event.type, "value") else event.type,
|
||||||
|
"workflowId": workflowId,
|
||||||
|
}
|
||||||
|
if event.content:
|
||||||
|
sseEvent["content"] = event.content
|
||||||
|
if event.type == AgentEventTypeEnum.MESSAGE:
|
||||||
|
sseEvent["item"] = {
|
||||||
|
"id": f"msg-{workflowId}-{id(event)}",
|
||||||
|
"role": "assistant",
|
||||||
|
"content": event.content,
|
||||||
|
"workflowId": workflowId,
|
||||||
|
}
|
||||||
|
if event.data:
|
||||||
|
sseEvent["item"] = event.data
|
||||||
|
|
||||||
|
await eventManager.emit_event(queueId, sseEvent["type"], sseEvent)
|
||||||
|
|
||||||
|
if event.type in (AgentEventTypeEnum.FINAL, AgentEventTypeEnum.ERROR):
|
||||||
|
if event.content:
|
||||||
|
chatInterface.createMessage({
|
||||||
|
"workflowId": workflowId,
|
||||||
|
"role": "assistant",
|
||||||
|
"message": event.content,
|
||||||
|
})
|
||||||
|
|
||||||
|
await eventManager.emit_event(queueId, "complete", {
|
||||||
|
"type": "complete",
|
||||||
|
"workflowId": workflowId,
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Workspace agent error: {e}", exc_info=True)
|
||||||
|
await eventManager.emit_event(queueId, "error", {
|
||||||
|
"type": "error",
|
||||||
|
"content": str(e),
|
||||||
|
"workflowId": workflowId,
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Stop endpoint
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.post("/{instanceId}/{workflowId}/stop")
|
||||||
|
@limiter.limit("30/minute")
|
||||||
|
async def stopWorkspace(
|
||||||
|
request: Request,
|
||||||
|
instanceId: str = Path(...),
|
||||||
|
workflowId: str = Path(...),
|
||||||
|
context: RequestContext = Depends(getRequestContext),
|
||||||
|
):
|
||||||
|
_validateInstanceAccess(instanceId, context)
|
||||||
|
queueId = f"workspace-{workflowId}"
|
||||||
|
eventManager = get_event_manager()
|
||||||
|
await eventManager.emit_event(queueId, "stopped", {
|
||||||
|
"type": "stopped",
|
||||||
|
"workflowId": workflowId,
|
||||||
|
})
|
||||||
|
return JSONResponse({"status": "stopped", "workflowId": workflowId})
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Workflow / Conversation endpoints
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.get("/{instanceId}/workflows")
|
||||||
|
@limiter.limit("60/minute")
|
||||||
|
async def listWorkspaceWorkflows(
|
||||||
|
request: Request,
|
||||||
|
instanceId: str = Path(...),
|
||||||
|
context: RequestContext = Depends(getRequestContext),
|
||||||
|
):
|
||||||
|
"""List all workspace workflows/conversations for this instance."""
|
||||||
|
_validateInstanceAccess(instanceId, context)
|
||||||
|
chatInterface = _getChatInterface(context, featureInstanceId=instanceId)
|
||||||
|
workflows = chatInterface.getWorkflows() or []
|
||||||
|
items = []
|
||||||
|
for wf in workflows:
|
||||||
|
if isinstance(wf, dict):
|
||||||
|
items.append(wf)
|
||||||
|
else:
|
||||||
|
items.append({
|
||||||
|
"id": getattr(wf, "id", None),
|
||||||
|
"name": getattr(wf, "name", ""),
|
||||||
|
"status": getattr(wf, "status", ""),
|
||||||
|
"startedAt": getattr(wf, "startedAt", None),
|
||||||
|
"lastActivity": getattr(wf, "lastActivity", None),
|
||||||
|
})
|
||||||
|
return JSONResponse({"workflows": items})
|
||||||
|
|
||||||
|
|
||||||
|
class UpdateWorkflowRequest(BaseModel):
|
||||||
|
"""Request body for updating a workflow (PATCH)."""
|
||||||
|
name: Optional[str] = Field(default=None, description="New workflow name")
|
||||||
|
|
||||||
|
|
||||||
|
@router.patch("/{instanceId}/workflows/{workflowId}")
|
||||||
|
@limiter.limit("60/minute")
|
||||||
|
async def patchWorkspaceWorkflow(
|
||||||
|
request: Request,
|
||||||
|
instanceId: str = Path(..., description="Feature instance ID"),
|
||||||
|
workflowId: str = Path(..., description="Workflow ID to update"),
|
||||||
|
body: UpdateWorkflowRequest = Body(...),
|
||||||
|
context: RequestContext = Depends(getRequestContext),
|
||||||
|
):
|
||||||
|
"""Update a workspace workflow (e.g. rename)."""
|
||||||
|
_validateInstanceAccess(instanceId, context)
|
||||||
|
chatInterface = _getChatInterface(context, featureInstanceId=instanceId)
|
||||||
|
workflow = chatInterface.getWorkflow(workflowId)
|
||||||
|
if not workflow:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Workflow {workflowId} not found")
|
||||||
|
updateData = {}
|
||||||
|
if body.name is not None:
|
||||||
|
updateData["name"] = body.name
|
||||||
|
if not updateData:
|
||||||
|
updated = workflow
|
||||||
|
else:
|
||||||
|
updated = chatInterface.updateWorkflow(workflowId, updateData)
|
||||||
|
if isinstance(updated, dict):
|
||||||
|
return JSONResponse(updated)
|
||||||
|
return JSONResponse({
|
||||||
|
"id": getattr(updated, "id", None),
|
||||||
|
"name": getattr(updated, "name", ""),
|
||||||
|
"status": getattr(updated, "status", ""),
|
||||||
|
"startedAt": getattr(updated, "startedAt", None),
|
||||||
|
"lastActivity": getattr(updated, "lastActivity", None),
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{instanceId}/workflows/{workflowId}/messages")
|
||||||
|
@limiter.limit("60/minute")
|
||||||
|
async def getWorkspaceMessages(
|
||||||
|
request: Request,
|
||||||
|
instanceId: str = Path(...),
|
||||||
|
workflowId: str = Path(...),
|
||||||
|
context: RequestContext = Depends(getRequestContext),
|
||||||
|
):
|
||||||
|
"""Get all messages for a workspace workflow/conversation."""
|
||||||
|
_validateInstanceAccess(instanceId, context)
|
||||||
|
chatInterface = _getChatInterface(context, featureInstanceId=instanceId)
|
||||||
|
messages = chatInterface.getMessages(workflowId) or []
|
||||||
|
items = []
|
||||||
|
for msg in messages:
|
||||||
|
if isinstance(msg, dict):
|
||||||
|
items.append(msg)
|
||||||
|
else:
|
||||||
|
items.append({
|
||||||
|
"id": getattr(msg, "id", None),
|
||||||
|
"role": getattr(msg, "role", ""),
|
||||||
|
"content": getattr(msg, "message", "") or getattr(msg, "content", ""),
|
||||||
|
"createdAt": getattr(msg, "publishedAt", None) or getattr(msg, "createdAt", None),
|
||||||
|
})
|
||||||
|
return JSONResponse({"messages": items})
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# File and folder list endpoints
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.get("/{instanceId}/files")
|
||||||
|
@limiter.limit("60/minute")
|
||||||
|
async def listWorkspaceFiles(
|
||||||
|
request: Request,
|
||||||
|
instanceId: str = Path(...),
|
||||||
|
folderId: Optional[str] = Query(None),
|
||||||
|
tags: Optional[str] = Query(None),
|
||||||
|
search: Optional[str] = Query(None),
|
||||||
|
context: RequestContext = Depends(getRequestContext),
|
||||||
|
):
|
||||||
|
_validateInstanceAccess(instanceId, context)
|
||||||
|
dbMgmt = _getDbManagement(context, featureInstanceId=instanceId)
|
||||||
|
files = dbMgmt.getAllFiles()
|
||||||
|
return JSONResponse({"files": [f if isinstance(f, dict) else f.model_dump() for f in (files or [])]})
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{instanceId}/files/{fileId}/content")
|
||||||
|
@limiter.limit("60/minute")
|
||||||
|
async def getFileContent(
|
||||||
|
request: Request,
|
||||||
|
instanceId: str = Path(...),
|
||||||
|
fileId: str = Path(...),
|
||||||
|
context: RequestContext = Depends(getRequestContext),
|
||||||
|
):
|
||||||
|
"""Return the raw content of a file for preview."""
|
||||||
|
from fastapi.responses import Response
|
||||||
|
_validateInstanceAccess(instanceId, context)
|
||||||
|
dbMgmt = _getDbManagement(context, featureInstanceId=instanceId)
|
||||||
|
fileRecord = dbMgmt.getFile(fileId)
|
||||||
|
if not fileRecord:
|
||||||
|
raise HTTPException(status_code=404, detail=f"File {fileId} not found")
|
||||||
|
fileData = fileRecord if isinstance(fileRecord, dict) else fileRecord.model_dump()
|
||||||
|
filePath = fileData.get("filePath")
|
||||||
|
if not filePath:
|
||||||
|
raise HTTPException(status_code=404, detail="File has no stored path")
|
||||||
|
import os
|
||||||
|
if not os.path.isfile(filePath):
|
||||||
|
raise HTTPException(status_code=404, detail="File not found on disk")
|
||||||
|
mimeType = fileData.get("mimeType", "application/octet-stream")
|
||||||
|
with open(filePath, "rb") as fh:
|
||||||
|
content = fh.read()
|
||||||
|
return Response(content=content, media_type=mimeType)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{instanceId}/folders")
|
||||||
|
@limiter.limit("60/minute")
|
||||||
|
async def listWorkspaceFolders(
|
||||||
|
request: Request,
|
||||||
|
instanceId: str = Path(...),
|
||||||
|
parentId: Optional[str] = Query(None),
|
||||||
|
context: RequestContext = Depends(getRequestContext),
|
||||||
|
):
|
||||||
|
_validateInstanceAccess(instanceId, context)
|
||||||
|
try:
|
||||||
|
from modules.serviceCenter import getService
|
||||||
|
from modules.serviceCenter.context import ServiceCenterContext
|
||||||
|
ctx = ServiceCenterContext(
|
||||||
|
user=context.user,
|
||||||
|
mandate_id=str(context.mandateId) if context.mandateId else None,
|
||||||
|
feature_instance_id=instanceId,
|
||||||
|
)
|
||||||
|
chatService = getService("chat", ctx)
|
||||||
|
folders = chatService.listFolders(parentId=parentId)
|
||||||
|
return JSONResponse({"folders": folders or []})
|
||||||
|
except Exception:
|
||||||
|
return JSONResponse({"folders": []})
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{instanceId}/datasources")
|
||||||
|
@limiter.limit("60/minute")
|
||||||
|
async def listWorkspaceDataSources(
|
||||||
|
request: Request,
|
||||||
|
instanceId: str = Path(...),
|
||||||
|
context: RequestContext = Depends(getRequestContext),
|
||||||
|
):
|
||||||
|
_validateInstanceAccess(instanceId, context)
|
||||||
|
try:
|
||||||
|
from modules.serviceCenter import getService
|
||||||
|
from modules.serviceCenter.context import ServiceCenterContext
|
||||||
|
ctx = ServiceCenterContext(
|
||||||
|
user=context.user,
|
||||||
|
mandate_id=str(context.mandateId) if context.mandateId else None,
|
||||||
|
feature_instance_id=instanceId,
|
||||||
|
)
|
||||||
|
chatService = getService("chat", ctx)
|
||||||
|
dataSources = chatService.listDataSources(featureInstanceId=instanceId)
|
||||||
|
return JSONResponse({"dataSources": dataSources or []})
|
||||||
|
except Exception:
|
||||||
|
return JSONResponse({"dataSources": []})
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{instanceId}/connections")
|
||||||
|
@limiter.limit("60/minute")
|
||||||
|
async def listWorkspaceConnections(
|
||||||
|
request: Request,
|
||||||
|
instanceId: str = Path(...),
|
||||||
|
    context: RequestContext = Depends(getRequestContext),
):
    """Return the user's active connections (UserConnections)."""
    _validateInstanceAccess(instanceId, context)
    from modules.serviceCenter import getService
    from modules.serviceCenter.context import ServiceCenterContext
    ctx = ServiceCenterContext(
        user=context.user,
        mandate_id=str(context.mandateId) if context.mandateId else None,
        feature_instance_id=instanceId,
    )
    chatService = getService("chat", ctx)
    connections = chatService.getUserConnections()
    items = []
    for c in connections or []:
        conn = c if isinstance(c, dict) else (c.model_dump() if hasattr(c, "model_dump") else {})
        authority = conn.get("authority")
        if hasattr(authority, "value"):
            authority = authority.value
        status = conn.get("status")
        if hasattr(status, "value"):
            status = status.value
        items.append({
            "id": conn.get("id"),
            "authority": authority,
            "externalUsername": conn.get("externalUsername"),
            "externalEmail": conn.get("externalEmail"),
            "status": status,
        })
    return JSONResponse({"connections": items})


class CreateDataSourceRequest(BaseModel):
    """Request body for creating a DataSource."""
    connectionId: str = Field(description="Connection ID")
    sourceType: str = Field(description="Source type")
    path: str = Field(description="Path")
    label: str = Field(description="Label")


@router.post("/{instanceId}/datasources")
@limiter.limit("60/minute")
async def createWorkspaceDataSource(
    request: Request,
    instanceId: str = Path(...),
    body: CreateDataSourceRequest = Body(...),
    context: RequestContext = Depends(getRequestContext),
):
    """Create a new DataSource for this workspace instance."""
    _validateInstanceAccess(instanceId, context)
    from modules.serviceCenter import getService
    from modules.serviceCenter.context import ServiceCenterContext
    ctx = ServiceCenterContext(
        user=context.user,
        mandate_id=str(context.mandateId) if context.mandateId else None,
        feature_instance_id=instanceId,
    )
    chatService = getService("chat", ctx)
    dataSource = chatService.createDataSource(
        connectionId=body.connectionId,
        sourceType=body.sourceType,
        path=body.path,
        label=body.label,
        featureInstanceId=instanceId,
    )
    return JSONResponse(dataSource if isinstance(dataSource, dict) else dataSource.model_dump())


@router.delete("/{instanceId}/datasources/{dataSourceId}")
@limiter.limit("60/minute")
async def deleteWorkspaceDataSource(
    request: Request,
    instanceId: str = Path(...),
    dataSourceId: str = Path(...),
    context: RequestContext = Depends(getRequestContext),
):
    """Delete a DataSource."""
    _validateInstanceAccess(instanceId, context)
    from modules.serviceCenter import getService
    from modules.serviceCenter.context import ServiceCenterContext
    ctx = ServiceCenterContext(
        user=context.user,
        mandate_id=str(context.mandateId) if context.mandateId else None,
        feature_instance_id=instanceId,
    )
    chatService = getService("chat", ctx)
    chatService.deleteDataSource(dataSourceId)
    return JSONResponse({"success": True})


@router.get("/{instanceId}/connections/{connectionId}/services")
@limiter.limit("30/minute")
async def listConnectionServices(
    request: Request,
    instanceId: str = Path(...),
    connectionId: str = Path(...),
    context: RequestContext = Depends(getRequestContext),
):
    """Return the available services for a specific UserConnection."""
    _validateInstanceAccess(instanceId, context)
    try:
        from modules.connectors.connectorResolver import ConnectorResolver
        from modules.serviceCenter import getService as getSvc
        from modules.serviceCenter.context import ServiceCenterContext
        ctx = ServiceCenterContext(
            user=context.user,
            mandate_id=str(context.mandateId) if context.mandateId else None,
            feature_instance_id=instanceId,
        )
        chatService = getSvc("chat", ctx)
        securityService = getSvc("security", ctx)
        dbInterface = _buildResolverDbInterface(chatService)
        resolver = ConnectorResolver(securityService, dbInterface)
        provider = await resolver.resolve(connectionId)
        services = provider.getAvailableServices()
        _serviceLabels = {
            "sharepoint": "SharePoint",
            "outlook": "Outlook",
            "teams": "Teams",
            "onedrive": "OneDrive",
            "drive": "Google Drive",
            "gmail": "Gmail",
            "files": "Files (FTP)",
        }
        _serviceIcons = {
            "sharepoint": "sharepoint",
            "outlook": "mail",
            "teams": "chat",
            "onedrive": "cloud",
            "drive": "cloud",
            "gmail": "mail",
            "files": "folder",
        }
        items = [
            {
                "service": s,
                "label": _serviceLabels.get(s, s),
                "icon": _serviceIcons.get(s, "folder"),
            }
            for s in services
        ]
        return JSONResponse({"services": items})
    except Exception as e:
        logger.error(f"Error listing services for connection {connectionId}: {e}")
        return JSONResponse({"services": [], "error": str(e)}, status_code=400)


@router.get("/{instanceId}/connections/{connectionId}/browse")
@limiter.limit("60/minute")
async def browseConnectionService(
    request: Request,
    instanceId: str = Path(...),
    connectionId: str = Path(...),
    service: str = Query(..., description="Service name (e.g. sharepoint, onedrive, outlook)"),
    path: str = Query("/", description="Path within the service to browse"),
    context: RequestContext = Depends(getRequestContext),
):
    """Browse folders/items within a connection's service at a given path."""
    _validateInstanceAccess(instanceId, context)
    try:
        from modules.connectors.connectorResolver import ConnectorResolver
        from modules.serviceCenter import getService as getSvc
        from modules.serviceCenter.context import ServiceCenterContext
        ctx = ServiceCenterContext(
            user=context.user,
            mandate_id=str(context.mandateId) if context.mandateId else None,
            feature_instance_id=instanceId,
        )
        chatService = getSvc("chat", ctx)
        securityService = getSvc("security", ctx)
        dbInterface = _buildResolverDbInterface(chatService)
        resolver = ConnectorResolver(securityService, dbInterface)
        adapter = await resolver.resolveService(connectionId, service)
        entries = await adapter.browse(path, filter=None)
        items = []
        for entry in (entries or []):
            items.append({
                "name": entry.name,
                "path": entry.path,
                "isFolder": entry.isFolder,
                "size": entry.size,
                "mimeType": entry.mimeType,
                "metadata": entry.metadata if hasattr(entry, "metadata") else {},
            })
        return JSONResponse({"items": items, "path": path, "service": service})
    except Exception as e:
        logger.error(f"Error browsing {service} for connection {connectionId} at '{path}': {e}")
        return JSONResponse({"items": [], "error": str(e)}, status_code=400)

# ---------------------------------------------------------------------------
# Voice endpoints
# ---------------------------------------------------------------------------

@router.post("/{instanceId}/voice/transcribe")
@limiter.limit("30/minute")
async def transcribeVoice(
    request: Request,
    instanceId: str = Path(...),
    audio: UploadFile = File(...),
    context: RequestContext = Depends(getRequestContext),
):
    """Transcribe audio to text using speech-to-text."""
    _validateInstanceAccess(instanceId, context)
    audioBytes = await audio.read()
    try:
        import aiohttp
        formData = aiohttp.FormData()
        formData.add_field("audio", audioBytes, filename=audio.filename or "audio.webm")
        async with aiohttp.ClientSession() as session:
            async with session.post(
                f"{request.base_url}api/voice-google/speech-to-text",
                data=formData,
            ) as resp:
                if resp.status == 200:
                    result = await resp.json()
                    return JSONResponse({"text": result.get("text", "")})
                return JSONResponse({"text": "", "error": f"STT failed: {resp.status}"})
    except Exception as e:
        logger.error(f"Voice transcription error: {e}")
        return JSONResponse({"text": "", "error": str(e)})


@router.post("/{instanceId}/voice/synthesize")
@limiter.limit("30/minute")
async def synthesizeVoice(
    request: Request,
    instanceId: str = Path(...),
    body: dict = Body(...),
    context: RequestContext = Depends(getRequestContext),
):
    """Synthesize text to speech audio."""
    _validateInstanceAccess(instanceId, context)
    text = body.get("text", "")
    if not text:
        raise HTTPException(status_code=400, detail="text is required")
    return JSONResponse({"audio": None, "note": "TTS via browser Speech Synthesis API recommended"})

@@ -4,7 +4,7 @@ import logging
 import asyncio
 import uuid
 import base64
-from typing import Dict, Any, List, Union, Tuple, Optional, Callable
+from typing import Dict, Any, List, Union, Tuple, Optional, Callable, AsyncGenerator
 from dataclasses import dataclass, field
 import time

@@ -84,15 +84,16 @@ class AiObjects:

     # AI for Extraction, Processing, Generation
     async def callWithTextContext(self, request: AiCallRequest) -> AiCallResponse:
-        """Call AI model for traditional text/context calls with fallback mechanism."""
+        """Call AI model for traditional text/context calls with fallback mechanism.
+
+        Supports two modes:
+        - Legacy: prompt + context → constructs messages internally
+        - Agent: request.messages provided → passes through directly
+        """
         prompt = request.prompt
         context = request.context or ""
         options = request.options

-        # Input bytes will be calculated inside _callWithModel
-
-        # Generation parameters are handled inside _callWithModel

         # Get failover models for this operation type
         availableModels = modelRegistry.getAvailableModels()

@@ -127,10 +128,12 @@ class AiObjects:
             try:
                 logger.info(f"Attempting AI call with model: {model.name} (attempt {attempt + 1}/{len(failoverModelList)})")

-                # Call the model directly - no truncation or compression here
-                response = await self._callWithModel(model, prompt, context, options)
+                if request.messages:
+                    response = await self._callWithMessages(model, request.messages, options, request.tools)
+                else:
+                    response = await self._callWithModel(model, prompt, context, options)

-                logger.info(f"✅ AI call successful with model: {model.name}")
+                logger.info(f"AI call successful with model: {model.name}")
                 return response

             except Exception as e:

@@ -142,8 +145,7 @@ class AiObjects:
                     logger.info(f"Trying next failover model...")
                     continue
                 else:
-                    # All models failed
-                    logger.error(f"💥 All {len(failoverModelList)} models failed for operation {options.operationType}")
+                    logger.error(f"All {len(failoverModelList)} models failed for operation {options.operationType}")
                     break

         # All failover attempts failed - return error response

@@ -254,6 +256,242 @@ class AiObjects:

         return response

    async def _callWithMessages(self, model: AiModel, messages: List[Dict[str, Any]],
                                options: AiCallOptions = None,
                                tools: List[Dict[str, Any]] = None) -> AiCallResponse:
        """Call a model with pre-built messages (agent mode). Supports tools for native function calling."""
        import json as _json

        inputBytes = sum(len(str(m.get("content", "")).encode("utf-8")) for m in messages)
        startTime = time.time()

        if not model.functionCall:
            raise ValueError(f"Model {model.name} has no function call defined")

        modelCall = AiModelCall(
            messages=messages,
            model=model,
            options=options or {},
            tools=tools
        )

        modelResponse = await model.functionCall(modelCall)

        if not modelResponse.success:
            raise ValueError(f"Model call failed: {modelResponse.error}")

        endTime = time.time()
        processingTime = endTime - startTime
        content = modelResponse.content
        outputBytes = len(content.encode("utf-8"))
        priceCHF = model.calculatepriceCHF(processingTime, inputBytes, outputBytes)

        # Extract tool calls from metadata if present (native function calling)
        responseToolCalls = None
        if modelResponse.metadata:
            responseToolCalls = modelResponse.metadata.get("toolCalls")

        response = AiCallResponse(
            content=content,
            modelName=model.name,
            provider=model.connectorType,
            priceCHF=priceCHF,
            processingTime=processingTime,
            bytesSent=inputBytes,
            bytesReceived=outputBytes,
            errorCount=0,
            toolCalls=responseToolCalls
        )

        if self.billingCallback:
            try:
                self.billingCallback(response)
            except Exception as e:
                logger.error(f"BILLING: Failed to record billing for model {model.name}: {e}")

        return response
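As a usage illustration, agent mode is reached through callWithTextContext by passing pre-built messages and tool definitions. The sketch below assumes an AiObjects instance named aiObjects and an OpenAI-style tool dict shape; neither is confirmed by this hunk.

# --- usage sketch, not part of the commit; instance name and tool dict shape are assumptions ---
request = AiCallRequest(
    prompt="",
    options=AiCallOptions(operationType=OperationTypeEnum.AGENT),
    messages=[
        {"role": "system", "content": "You are a helpful agent."},
        {"role": "user", "content": "List my SharePoint folders."},
    ],
    tools=[{  # assumed OpenAI-style function-calling schema
        "type": "function",
        "function": {
            "name": "sharepoint.browse",
            "description": "Browse a path",
            "parameters": {"type": "object", "properties": {"path": {"type": "string"}}, "required": ["path"]},
        },
    }],
)
response = await aiObjects.callWithTextContext(request)
print(response.toolCalls or response.content)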

    async def callWithTextContextStream(
        self, request: AiCallRequest
    ) -> AsyncGenerator[Union[str, AiCallResponse], None]:
        """Streaming variant of callWithTextContext. Yields str deltas, then final AiCallResponse."""
        options = request.options
        availableModels = modelRegistry.getAvailableModels()

        allowedProviders = getattr(options, 'allowedProviders', None) if options else None
        if allowedProviders:
            filtered = [m for m in availableModels if m.connectorType in allowedProviders]
            if filtered:
                availableModels = filtered

        failoverModelList = modelSelector.getFailoverModelList(
            request.prompt, request.context or "", options, availableModels
        )
        if not failoverModelList:
            yield AiCallResponse(
                content=f"No suitable models found for operation {options.operationType}",
                modelName="error", priceCHF=0.0, processingTime=0.0,
                bytesSent=0, bytesReceived=0, errorCount=1,
            )
            return

        lastError = None
        for attempt, model in enumerate(failoverModelList):
            try:
                logger.info(f"Streaming AI call with model: {model.name} (attempt {attempt + 1})")
                async for chunk in self._callWithMessagesStream(model, request.messages, options, request.tools):
                    yield chunk
                return
            except Exception as e:
                lastError = e
                logger.warning(f"Streaming AI call failed with {model.name}: {e}")
                modelSelector.reportFailure(model.name)
                if attempt < len(failoverModelList) - 1:
                    continue
                break

        yield AiCallResponse(
            content=f"All models failed (stream). Last error: {lastError}",
            modelName="error", priceCHF=0.0, processingTime=0.0,
            bytesSent=0, bytesReceived=0, errorCount=1,
        )

    async def _callWithMessagesStream(
        self, model: AiModel, messages: List[Dict[str, Any]],
        options: AiCallOptions = None, tools: List[Dict[str, Any]] = None,
    ) -> AsyncGenerator[Union[str, AiCallResponse], None]:
        """Stream a model call. Yields str deltas, then final AiCallResponse with billing."""
        from modules.datamodels.datamodelAi import AiModelCall, AiModelResponse

        inputBytes = sum(len(str(m.get("content", "")).encode("utf-8")) for m in messages)
        startTime = time.time()

        if not model.functionCallStream:
            response = await self._callWithMessages(model, messages, options, tools)
            if response.content:
                yield response.content
            yield response
            return

        modelCall = AiModelCall(
            messages=messages, model=model,
            options=options or {}, tools=tools,
        )

        finalModelResponse = None
        async for item in model.functionCallStream(modelCall):
            if isinstance(item, AiModelResponse):
                finalModelResponse = item
            else:
                yield item

        if not finalModelResponse:
            raise ValueError(f"Stream from {model.name} produced no final AiModelResponse")

        endTime = time.time()
        processingTime = endTime - startTime
        content = finalModelResponse.content
        outputBytes = len(content.encode("utf-8"))
        priceCHF = model.calculatepriceCHF(processingTime, inputBytes, outputBytes)

        responseToolCalls = None
        if finalModelResponse.metadata:
            responseToolCalls = finalModelResponse.metadata.get("toolCalls")

        response = AiCallResponse(
            content=content,
            modelName=model.name,
            provider=model.connectorType,
            priceCHF=priceCHF,
            processingTime=processingTime,
            bytesSent=inputBytes,
            bytesReceived=outputBytes,
            errorCount=0,
            toolCalls=responseToolCalls,
        )

        if self.billingCallback:
            try:
                self.billingCallback(response)
            except Exception as e:
                logger.error(f"BILLING: Failed to record stream billing for {model.name}: {e}")

        yield response
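A brief consumption sketch for the streaming path: the caller separates text deltas from the final AiCallResponse by type, mirroring the contract documented above (str chunks first, one AiCallResponse last). The aiObjects instance name is illustrative.

# --- usage sketch, not part of the commit ---
async def printStreamed(aiObjects, request: AiCallRequest):
    final = None
    async for item in aiObjects.callWithTextContextStream(request):
        if isinstance(item, AiCallResponse):
            final = item          # terminal item carries cost, model name, and tool calls
        else:
            print(item, end="", flush=True)   # incremental text delta
    if final:
        print(f"\n[{final.modelName}] {final.priceCHF:.4f} CHF, {final.processingTime:.2f}s")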

    async def callEmbedding(self, texts: List[str], options: AiCallOptions = None) -> AiCallResponse:
        """Generate embeddings for a list of texts using the best available embedding model.

        Uses the standard model selector with OperationTypeEnum.EMBEDDING to pick the model.
        Failover across providers (OpenAI → Mistral) works identically to chat models.

        Returns:
            AiCallResponse with metadata["embeddings"] containing the vectors.
        """
        if options is None:
            options = AiCallOptions(operationType=OperationTypeEnum.EMBEDDING)
        else:
            options.operationType = OperationTypeEnum.EMBEDDING

        combinedText = " ".join(texts[:3])[:500]
        availableModels = modelRegistry.getAvailableModels()
        failoverModelList = modelSelector.getFailoverModelList(
            combinedText, "", options, availableModels
        )

        if not failoverModelList:
            return AiCallResponse(
                content="", modelName="error", priceCHF=0.0,
                processingTime=0.0, bytesSent=0, bytesReceived=0, errorCount=1
            )

        lastError = None
        for attempt, model in enumerate(failoverModelList):
            try:
                logger.info(f"Embedding call with {model.name} (attempt {attempt + 1}/{len(failoverModelList)})")
                inputBytes = sum(len(t.encode("utf-8")) for t in texts)
                startTime = time.time()

                modelCall = AiModelCall(
                    model=model, options=options, embeddingInput=texts
                )
                modelResponse = await model.functionCall(modelCall)

                if not modelResponse.success:
                    raise ValueError(f"Embedding call failed: {modelResponse.error}")

                processingTime = time.time() - startTime
                priceCHF = model.calculatepriceCHF(processingTime, inputBytes, 0)
                embeddings = (modelResponse.metadata or {}).get("embeddings", [])

                response = AiCallResponse(
                    content="", modelName=model.name, provider=model.connectorType,
                    priceCHF=priceCHF, processingTime=processingTime,
                    bytesSent=inputBytes, bytesReceived=0, errorCount=0,
                    metadata={"embeddings": embeddings}
                )

                if self.billingCallback:
                    try:
                        self.billingCallback(response)
                    except Exception as e:
                        logger.error(f"BILLING: Failed to record billing for embedding {model.name}: {e}")

                return response

            except Exception as e:
                lastError = e
                logger.warning(f"Embedding call failed with {model.name}: {str(e)}")
                modelSelector.reportFailure(model.name)
                if attempt < len(failoverModelList) - 1:
                    continue
                break

        errorMsg = f"All embedding models failed. Last error: {str(lastError)}"
        logger.error(errorMsg)
        return AiCallResponse(
            content=errorMsg, modelName="error", priceCHF=0.0,
            processingTime=0.0, bytesSent=0, bytesReceived=0, errorCount=1
        )
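A minimal embedding usage sketch, assuming an AiObjects instance named aiObjects; vector dimensionality depends on whichever embedding model the selector picks.

# --- usage sketch, not part of the commit ---
async def embedTexts(aiObjects, texts: List[str]) -> List[List[float]]:
    resp = await aiObjects.callEmbedding(texts)
    # vectors are returned in metadata["embeddings"], expected to align 1:1 with the input texts
    vectors = (resp.metadata or {}).get("embeddings", [])
    if resp.errorCount:
        logger.warning(f"Embedding failed: {resp.content}")
    return vectors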

    # Utility methods
    async def listAvailableModels(self, connectorType: str = None) -> List[Dict[str, Any]]:


@@ -756,7 +756,7 @@ class ChatObjects:
            logs=[],
            messages=[],
            stats=[],
-           workflowMode=created["workflowMode"],
+           workflowMode=created.get("workflowMode", "Dynamic"),
            maxSteps=created.get("maxSteps", 1)
        )

@@ -790,11 +790,11 @@ class ChatObjects:
            id=updated["id"],
            status=updated.get("status", workflow.status),
            name=updated.get("name", workflow.name),
-           currentRound=updated.get("currentRound", workflow.currentRound),
+           currentRound=updated.get("currentRound") or getattr(workflow, "currentRound", 0) or 0,
-           currentTask=updated.get("currentTask", workflow.currentTask),
+           currentTask=updated.get("currentTask") or getattr(workflow, "currentTask", 0) or 0,
-           currentAction=updated.get("currentAction", workflow.currentAction),
+           currentAction=updated.get("currentAction") or getattr(workflow, "currentAction", 0) or 0,
-           totalTasks=updated.get("totalTasks", workflow.totalTasks),
+           totalTasks=updated.get("totalTasks") or getattr(workflow, "totalTasks", 0) or 0,
-           totalActions=updated.get("totalActions", workflow.totalActions),
+           totalActions=updated.get("totalActions") or getattr(workflow, "totalActions", 0) or 0,
            lastActivity=updated.get("lastActivity", workflow.lastActivity),
            startedAt=updated.get("startedAt", workflow.startedAt),
            logs=logs,
modules/interfaces/interfaceDbKnowledge.py (new file, 234 lines)
@@ -0,0 +1,234 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Interface to the Knowledge Store database (poweron_knowledge).
Provides CRUD for FileContentIndex, ContentChunk, WorkflowMemory
and semantic search via pgvector.
"""

import logging
from typing import Dict, Any, List, Optional

from modules.connectors.connectorDbPostgre import _get_cached_connector
from modules.datamodels.datamodelKnowledge import FileContentIndex, ContentChunk, WorkflowMemory
from modules.datamodels.datamodelUam import User
from modules.shared.configuration import APP_CONFIG
from modules.shared.timeUtils import getUtcTimestamp

logger = logging.getLogger(__name__)

_instances: Dict[str, "KnowledgeObjects"] = {}


class KnowledgeObjects:
    """Interface to the Knowledge Store database.
    Manages FileContentIndex, ContentChunk, and WorkflowMemory with semantic search."""

    def __init__(self):
        self.currentUser: Optional[User] = None
        self.userId: Optional[str] = None
        self._initializeDatabase()

    def _initializeDatabase(self):
        dbHost = APP_CONFIG.get("DB_HOST", "_no_config_default_data")
        dbDatabase = "poweron_knowledge"
        dbUser = APP_CONFIG.get("DB_USER")
        dbPassword = APP_CONFIG.get("DB_PASSWORD_SECRET")
        dbPort = int(APP_CONFIG.get("DB_PORT", 5432))

        self.db = _get_cached_connector(
            dbHost=dbHost,
            dbDatabase=dbDatabase,
            dbUser=dbUser,
            dbPassword=dbPassword,
            dbPort=dbPort,
            userId=self.userId,
        )
        logger.info("Knowledge Store database initialized")

    def setUserContext(self, user: User):
        self.currentUser = user
        self.userId = user.id if user else None
        if self.userId:
            self.db.updateContext(self.userId)

    # =========================================================================
    # FileContentIndex CRUD
    # =========================================================================

    def upsertFileContentIndex(self, index: FileContentIndex) -> Dict[str, Any]:
        """Create or update a FileContentIndex entry."""
        data = index.model_dump()
        existing = self.db._loadRecord(FileContentIndex, index.id)
        if existing:
            return self.db.recordModify(FileContentIndex, index.id, data)
        return self.db.recordCreate(FileContentIndex, data)

    def getFileContentIndex(self, fileId: str) -> Optional[Dict[str, Any]]:
        """Get a FileContentIndex by file ID."""
        return self.db._loadRecord(FileContentIndex, fileId)

    def getFileContentIndexByUser(
        self, userId: str, featureInstanceId: str = None
    ) -> List[Dict[str, Any]]:
        """Get all FileContentIndex entries for a user."""
        recordFilter = {"userId": userId}
        if featureInstanceId:
            recordFilter["featureInstanceId"] = featureInstanceId
        return self.db.getRecordset(FileContentIndex, recordFilter=recordFilter)

    def updateFileStatus(self, fileId: str, status: str) -> bool:
        """Update the processing status of a FileContentIndex."""
        existing = self.db._loadRecord(FileContentIndex, fileId)
        if not existing:
            return False
        self.db.recordModify(FileContentIndex, fileId, {"status": status})
        return True

    def deleteFileContentIndex(self, fileId: str) -> bool:
        """Delete a FileContentIndex and all associated ContentChunks."""
        chunks = self.db.getRecordset(ContentChunk, recordFilter={"fileId": fileId})
        for chunk in chunks:
            self.db.recordDelete(ContentChunk, chunk["id"])
        return self.db.recordDelete(FileContentIndex, fileId)

    # =========================================================================
    # ContentChunk CRUD
    # =========================================================================

    def upsertContentChunk(self, chunk: ContentChunk) -> Dict[str, Any]:
        """Create or update a ContentChunk."""
        data = chunk.model_dump()
        existing = self.db._loadRecord(ContentChunk, chunk.id)
        if existing:
            return self.db.recordModify(ContentChunk, chunk.id, data)
        return self.db.recordCreate(ContentChunk, data)

    def upsertContentChunks(self, chunks: List[ContentChunk]) -> int:
        """Batch upsert multiple ContentChunks. Returns count of upserted chunks."""
        count = 0
        for chunk in chunks:
            self.upsertContentChunk(chunk)
            count += 1
        return count

    def getContentChunks(self, fileId: str) -> List[Dict[str, Any]]:
        """Get all ContentChunks for a file."""
        return self.db.getRecordset(ContentChunk, recordFilter={"fileId": fileId})

    def deleteContentChunks(self, fileId: str) -> int:
        """Delete all ContentChunks for a file. Returns count of deleted chunks."""
        chunks = self.db.getRecordset(ContentChunk, recordFilter={"fileId": fileId})
        count = 0
        for chunk in chunks:
            if self.db.recordDelete(ContentChunk, chunk["id"]):
                count += 1
        return count

    # =========================================================================
    # WorkflowMemory CRUD
    # =========================================================================

    def upsertWorkflowMemory(self, memory: WorkflowMemory) -> Dict[str, Any]:
        """Create or update a WorkflowMemory entry."""
        data = memory.model_dump()
        existing = self.db._loadRecord(WorkflowMemory, memory.id)
        if existing:
            return self.db.recordModify(WorkflowMemory, memory.id, data)
        return self.db.recordCreate(WorkflowMemory, data)

    def getWorkflowEntities(self, workflowId: str) -> List[Dict[str, Any]]:
        """Get all WorkflowMemory entries for a workflow."""
        return self.db.getRecordset(WorkflowMemory, recordFilter={"workflowId": workflowId})

    def getWorkflowEntity(self, workflowId: str, key: str) -> Optional[Dict[str, Any]]:
        """Get a specific WorkflowMemory entry by workflow and key."""
        results = self.db.getRecordset(
            WorkflowMemory, recordFilter={"workflowId": workflowId, "key": key}
        )
        return results[0] if results else None

    def deleteWorkflowMemory(self, workflowId: str) -> int:
        """Delete all WorkflowMemory entries for a workflow. Returns count."""
        entries = self.db.getRecordset(WorkflowMemory, recordFilter={"workflowId": workflowId})
        count = 0
        for entry in entries:
            if self.db.recordDelete(WorkflowMemory, entry["id"]):
                count += 1
        return count

    # =========================================================================
    # Semantic Search
    # =========================================================================

    def semanticSearch(
        self,
        queryVector: List[float],
        userId: str = None,
        featureInstanceId: str = None,
        mandateId: str = None,
        isShared: bool = None,
        limit: int = 10,
        minScore: float = None,
        contentType: str = None,
    ) -> List[Dict[str, Any]]:
        """Semantic search across ContentChunks using pgvector cosine similarity.

        Args:
            queryVector: Query embedding vector.
            userId: Filter by user (Instance Layer).
            featureInstanceId: Filter by feature instance.
            mandateId: Filter by mandate (for Shared Layer lookups).
            isShared: If True, search Shared Layer via FileContentIndex join.
            limit: Max results.
            minScore: Minimum cosine similarity (0.0 - 1.0).
            contentType: Filter by content type (text, image, etc.).

        Returns:
            List of ContentChunk records with _score field, sorted by relevance.
        """
        recordFilter = {}
        if userId:
            recordFilter["userId"] = userId
        if featureInstanceId:
            recordFilter["featureInstanceId"] = featureInstanceId
        if contentType:
            recordFilter["contentType"] = contentType

        return self.db.semanticSearch(
            modelClass=ContentChunk,
            vectorColumn="embedding",
            queryVector=queryVector,
            limit=limit,
            recordFilter=recordFilter if recordFilter else None,
            minScore=minScore,
        )

    def semanticSearchWorkflowMemory(
        self,
        queryVector: List[float],
        workflowId: str,
        limit: int = 5,
        minScore: float = None,
    ) -> List[Dict[str, Any]]:
        """Semantic search across WorkflowMemory entries."""
        return self.db.semanticSearch(
            modelClass=WorkflowMemory,
            vectorColumn="embedding",
            queryVector=queryVector,
            limit=limit,
            recordFilter={"workflowId": workflowId},
            minScore=minScore,
        )


def getInterface(currentUser: Optional[User] = None) -> KnowledgeObjects:
    """Get or create a KnowledgeObjects singleton."""
    if "default" not in _instances:
        _instances["default"] = KnowledgeObjects()

    interface = _instances["default"]
    if currentUser:
        interface.setUserContext(currentUser)

    return interface
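To show how the embedding call and the knowledge store fit together, a small retrieval sketch follows. It assumes an AiObjects instance (aiObjects) for embeddings; the caller-side names are illustrative and the record field layout of ContentChunk is not specified here beyond the documented _score field.

# --- retrieval sketch, not part of the commit ---
async def searchKnowledge(aiObjects, query: str, userId: str):
    # 1. Embed the query text
    resp = await aiObjects.callEmbedding([query])
    vectors = (resp.metadata or {}).get("embeddings", [])
    if not vectors:
        return []
    # 2. Rank stored chunks by pgvector cosine similarity, scoped to the user
    knowledgeDb = getInterface()
    hits = knowledgeDb.semanticSearch(queryVector=vectors[0], userId=userId, limit=5, minScore=0.6)
    return hits  # each record carries a _score field, sorted by relevance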

@@ -68,13 +68,20 @@ TABLE_NAMESPACE = {
     # Files - user-owned
     "FileItem": "files",
     "FileData": "files",
+    "FileFolder": "files",
     # Automation - user-owned
     "AutomationDefinition": "automation",
     "AutomationTemplate": "automation",
+    # Knowledge Store - user-owned
+    "FileContentIndex": "knowledge",
+    "ContentChunk": "knowledge",
+    "WorkflowMemory": "knowledge",
+    # Data Sources - user-owned
+    "DataSource": "datasource",
 }

 # Namespaces without mandate context - GROUP is mapped to MY
-USER_OWNED_NAMESPACES = {"chat", "chatbot", "files", "automation"}
+USER_OWNED_NAMESPACES = {"chat", "chatbot", "files", "automation", "knowledge", "datasource"}


 def buildDataObjectKey(tableName: str, featureCode: Optional[str] = None) -> str:

@@ -19,6 +19,114 @@ from modules.datamodels.datamodelPagination import PaginationParams, PaginatedRe
 # Configure logger
 logger = logging.getLogger(__name__)


async def _autoIndexFile(fileId: str, fileName: str, mimeType: str, user):
    """Background task: pre-scan + extraction + knowledge indexing.
    Step 1: Structure Pre-Scan (AI-free) -> FileContentIndex (persisted)
    Step 2: Content extraction via runExtraction -> ContentParts
    Step 3: KnowledgeService.indexFile -> chunking + embedding -> Knowledge Store"""
    userId = user.id if hasattr(user, "id") else str(user)
    try:
        mgmtInterface = interfaceDbManagement.getInterface(user)
        mgmtInterface.updateFile(fileId, {"status": "processing"})

        rawBytes = mgmtInterface.getFileData(fileId)
        if not rawBytes:
            logger.warning(f"Auto-index: no file data for {fileId}, skipping")
            mgmtInterface.updateFile(fileId, {"status": "active"})
            return

        logger.info(f"Auto-index starting for {fileName} ({len(rawBytes)} bytes, {mimeType})")

        # Step 1: Structure Pre-Scan (AI-free)
        from modules.serviceCenter.services.serviceKnowledge.subPreScan import preScanDocument
        contentIndex = await preScanDocument(
            fileData=rawBytes,
            mimeType=mimeType,
            fileId=fileId,
            fileName=fileName,
            userId=userId,
        )
        logger.info(
            f"Pre-scan complete for {fileName}: "
            f"{contentIndex.totalObjects} objects"
        )

        # Persist FileContentIndex immediately
        from modules.interfaces.interfaceDbKnowledge import getInterface as getKnowledgeInterface
        knowledgeDb = getKnowledgeInterface()
        knowledgeDb.upsertFileContentIndex(contentIndex)

        # Step 2: Content extraction (AI-free, produces ContentParts)
        from modules.serviceCenter.services.serviceExtraction.subRegistry import ExtractorRegistry, ChunkerRegistry
        from modules.serviceCenter.services.serviceExtraction.subPipeline import runExtraction
        from modules.datamodels.datamodelExtraction import ExtractionOptions

        extractorRegistry = ExtractorRegistry()
        chunkerRegistry = ChunkerRegistry()
        options = ExtractionOptions()

        extracted = runExtraction(
            extractorRegistry, chunkerRegistry,
            rawBytes, fileName, mimeType, options,
        )

        contentObjects = []
        for part in extracted.parts:
            contentType = "text"
            if part.typeGroup == "image":
                contentType = "image"
            elif part.typeGroup in ("binary", "container"):
                contentType = "other"

            if not part.data or not part.data.strip():
                continue

            contentObjects.append({
                "contentObjectId": part.id,
                "contentType": contentType,
                "data": part.data,
                "contextRef": {
                    "containerPath": fileName,
                    "location": part.label or "file",
                    **(part.metadata or {}),
                },
            })

        logger.info(f"Extracted {len(contentObjects)} content objects from {fileName}")

        if not contentObjects:
            knowledgeDb.updateFileStatus(fileId, "indexed")
            mgmtInterface.updateFile(fileId, {"status": "active"})
            return

        # Step 3: Knowledge indexing (chunking + embedding)
        from modules.serviceCenter import getService
        from modules.serviceCenter.context import ServiceCenterContext

        ctx = ServiceCenterContext(user=user, mandate_id="", feature_instance_id="")
        knowledgeService = getService("knowledge", ctx)

        await knowledgeService.indexFile(
            fileId=fileId,
            fileName=fileName,
            mimeType=mimeType,
            userId=userId,
            contentObjects=contentObjects,
            structure=contentIndex.structure,
        )

        mgmtInterface.updateFile(fileId, {"status": "active"})
        logger.info(f"Auto-index complete for file {fileId} ({fileName})")

    except Exception as e:
        logger.error(f"Auto-index failed for file {fileId}: {e}", exc_info=True)
        try:
            errMgmt = interfaceDbManagement.getInterface(user)
            errMgmt.updateFile(fileId, {"status": "active"})
        except Exception:
            pass

 # Model attributes for FileItem
 fileAttributes = getModelAttributeDefinitions(FileItem)

@@ -148,6 +256,32 @@ async def upload_file(
     if workflowId:
         fileMeta["workflowId"] = workflowId

+    # Trigger background auto-index pipeline (non-blocking)
+    # Also runs for duplicates in case the original was never successfully indexed
+    shouldIndex = duplicateType == "new_file"
+    if not shouldIndex:
+        try:
+            from modules.interfaces.interfaceDbKnowledge import getInterface as _getKnowledgeInterface
+            _kDb = _getKnowledgeInterface()
+            _existingIndex = _kDb.getFileContentIndex(fileItem.id)
+            if not _existingIndex:
+                shouldIndex = True
+                logger.info(f"Re-triggering auto-index for duplicate {fileItem.id} (not yet indexed)")
+        except Exception:
+            shouldIndex = True
+
+    if shouldIndex:
+        try:
+            import asyncio
+            asyncio.ensure_future(_autoIndexFile(
+                fileId=fileItem.id,
+                fileName=fileItem.fileName,
+                mimeType=fileItem.mimeType,
+                user=currentUser,
+            ))
+        except Exception as indexErr:
+            logger.warning(f"Auto-index trigger failed (non-blocking): {indexErr}")
+
     # Response with duplicate information
     return JSONResponse({
         "message": message,

@@ -488,7 +488,7 @@ async def auth_callback(code: str, state: str, request: Request, response: Respo
     connection.externalUsername = user_info.get("email")
     connection.externalEmail = user_info.get("email")
     # Store actually granted scopes for this connection
-    granted_scopes_list = granted_scopes.split(" ") if granted_scopes else SCOPES
+    granted_scopes_list = granted_scopes if isinstance(granted_scopes, list) else (granted_scopes.split(" ") if granted_scopes else SCOPES)
     connection.grantedScopes = granted_scopes_list
     logger.info(f"Storing granted scopes for connection {connection_id}: {granted_scopes_list}")

@@ -123,6 +123,9 @@ def _getFeatureUiObjects(featureCode: str) -> List[Dict[str, Any]]:
     elif featureCode == "commcoach":
         from modules.features.commcoach.mainCommcoach import UI_OBJECTS
         return UI_OBJECTS
+    elif featureCode == "workspace":
+        from modules.features.workspace.mainWorkspace import UI_OBJECTS
+        return UI_OBJECTS
     else:
         logger.warning(f"Unknown feature code: {featureCode}")
         return []

@@ -98,6 +98,20 @@ IMPORTABLE_SERVICES: Dict[str, Dict[str, Any]] = {
         "objectKey": "service.neutralization",
         "label": {"en": "Neutralization", "de": "Neutralisierung", "fr": "Neutralisation"},
     },
+    "agent": {
+        "module": "modules.serviceCenter.services.serviceAgent.mainServiceAgent",
+        "class": "AgentService",
+        "dependencies": ["ai", "chat", "utils", "extraction", "billing", "streaming", "knowledge"],
+        "objectKey": "service.agent",
+        "label": {"en": "Agent", "de": "Agent", "fr": "Agent"},
+    },
+    "knowledge": {
+        "module": "modules.serviceCenter.services.serviceKnowledge.mainServiceKnowledge",
+        "class": "KnowledgeService",
+        "dependencies": ["ai"],
+        "objectKey": "service.knowledge",
+        "label": {"en": "Knowledge Store", "de": "Wissensspeicher", "fr": "Base de connaissances"},
+    },
 }
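Since the commit registers "agent" and "knowledge" as importable services, they become resolvable through the same getService pattern already used in the router code above; the user and instance values in this sketch are placeholders.

# --- resolution sketch, not part of the commit; currentUser and instanceId are placeholders ---
from modules.serviceCenter import getService
from modules.serviceCenter.context import ServiceCenterContext

ctx = ServiceCenterContext(user=currentUser, mandate_id=None, feature_instance_id=instanceId)
agentService = getService("agent", ctx)        # dependencies (ai, chat, knowledge, ...) are resolved by the service center
knowledgeService = getService("knowledge", ctx)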

 # RBAC objects for service-level access control (for catalog registration)


modules/serviceCenter/services/serviceAgent/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""serviceAgent: AI Agent with ReAct loop and native function calling."""


modules/serviceCenter/services/serviceAgent/actionToolAdapter.py (new file, 162 lines)
@@ -0,0 +1,162 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""ActionToolAdapter: wraps existing workflow actions (dynamicMode=True) as agent tools."""

import logging
from typing import Dict, Any, List, Optional

from modules.serviceCenter.services.serviceAgent.datamodelAgent import (
    ToolDefinition, ToolResult
)
from modules.serviceCenter.services.serviceAgent.toolRegistry import ToolRegistry

logger = logging.getLogger(__name__)


class ActionToolAdapter:
    """Wraps existing Workflow-Actions as Agent-Tools.

    Iterates over discovered methods, finds actions with dynamicMode=True,
    and registers them in the ToolRegistry with a compound name (method.action).
    """

    def __init__(self, actionExecutor):
        self._actionExecutor = actionExecutor
        self._registeredTools: List[str] = []

    def registerAll(self, toolRegistry: ToolRegistry):
        """Discover and register all dynamicMode actions as agent tools."""
        from modules.workflows.processing.shared.methodDiscovery import methods

        registered = 0
        for methodName, methodInfo in methods.items():
            if not methodName[0].isupper():
                continue

            shortName = methodName.replace("Method", "").lower()
            methodInstance = methodInfo["instance"]

            for actionName, actionInfo in methodInfo["actions"].items():
                actionDef = methodInstance._actions.get(actionName)
                if not actionDef or not getattr(actionDef, "dynamicMode", False):
                    continue

                compoundName = f"{shortName}.{actionName}"
                toolDef = _buildToolDefinition(compoundName, actionDef, actionInfo)

                handler = _createDispatchHandler(self._actionExecutor, shortName, actionName)
                toolRegistry.registerFromDefinition(toolDef, handler)
                self._registeredTools.append(compoundName)
                registered += 1

        logger.info(f"ActionToolAdapter: registered {registered} tools from workflow actions")

    @property
    def registeredTools(self) -> List[str]:
        """Names of all tools registered by this adapter."""
        return list(self._registeredTools)


def _buildToolDefinition(compoundName: str, actionDef, actionInfo: Dict[str, Any]) -> ToolDefinition:
    """Build a ToolDefinition from a WorkflowActionDefinition."""
    parameters = _convertParameterSchema(actionInfo.get("parameters", {}))

    return ToolDefinition(
        name=compoundName,
        description=actionDef.description or actionInfo.get("description", ""),
        parameters=parameters,
        readOnly=False
    )


def _convertParameterSchema(actionParams: Dict[str, Any]) -> Dict[str, Any]:
    """Convert workflow action parameter schema to JSON Schema for tool definitions."""
    properties = {}
    required = []

    for paramName, paramInfo in actionParams.items():
        paramType = paramInfo.get("type", "str") if isinstance(paramInfo, dict) else "str"
        paramDesc = paramInfo.get("description", "") if isinstance(paramInfo, dict) else ""
        paramRequired = paramInfo.get("required", False) if isinstance(paramInfo, dict) else False

        jsonType = _pythonTypeToJsonType(paramType)
        properties[paramName] = {
            "type": jsonType,
            "description": paramDesc
        }

        if paramRequired:
            required.append(paramName)

    return {
        "type": "object",
        "properties": properties,
        "required": required
    }


def _pythonTypeToJsonType(pythonType: str) -> str:
    """Map Python type strings to JSON Schema types."""
    mapping = {
        "str": "string",
        "int": "integer",
        "float": "number",
        "bool": "boolean",
        "list": "array",
        "dict": "object",
        "List[str]": "array",
        "List[int]": "array",
        "List[dict]": "array",
        "Dict[str, Any]": "object",
    }
    return mapping.get(pythonType, "string")


def _createDispatchHandler(actionExecutor, methodName: str, actionName: str):
    """Create an async handler that dispatches to the ActionExecutor."""
    async def _handler(args: Dict[str, Any], context: Dict[str, Any]) -> ToolResult:
        try:
            result = await actionExecutor.executeAction(methodName, actionName, args)
            data = _formatActionResult(result)
            return ToolResult(
                toolCallId="",
                toolName=f"{methodName}.{actionName}",
                success=result.success,
                data=data,
                error=result.error
            )
        except Exception as e:
            logger.error(f"ActionToolAdapter dispatch failed for {methodName}.{actionName}: {e}")
            return ToolResult(
                toolCallId="",
                toolName=f"{methodName}.{actionName}",
                success=False,
                error=str(e)
            )
    return _handler


def _formatActionResult(result) -> str:
    """Format an ActionResult into a text representation for the agent."""
    parts = []

    if result.resultLabel:
        parts.append(f"Result: {result.resultLabel}")

    if result.error:
        parts.append(f"Error: {result.error}")

    if result.documents:
        parts.append(f"Documents ({len(result.documents)}):")
        for doc in result.documents:
            docName = getattr(doc, "documentName", "unnamed")
            docType = getattr(doc, "mimeType", "unknown")
            parts.append(f"  - {docName} ({docType})")
            docData = getattr(doc, "documentData", None)
            if docData and isinstance(docData, str) and len(docData) < 2000:
                parts.append(f"    Content: {docData[:2000]}")

    if not parts:
        parts.append("Action completed successfully." if result.success else "Action failed.")

    return "\n".join(parts)

modules/serviceCenter/services/serviceAgent/agentLoop.py (new file, 395 lines)
@@ -0,0 +1,395 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Agent loop: ReAct pattern with native function calling, budget control, and error handling."""

import asyncio
import logging
import time
import json
import re
from typing import List, Dict, Any, Optional, AsyncGenerator, Callable, Awaitable

from modules.datamodels.datamodelAi import (
    AiCallRequest, AiCallOptions, AiCallResponse, OperationTypeEnum
)
from modules.serviceCenter.services.serviceAgent.datamodelAgent import (
    AgentState, AgentStatusEnum, AgentConfig, AgentEvent, AgentEventTypeEnum,
    ToolCallRequest, ToolResult, ToolCallLog, AgentRoundLog, AgentTrace
)
from modules.serviceCenter.services.serviceAgent.toolRegistry import ToolRegistry
from modules.serviceCenter.services.serviceAgent.conversationManager import (
    ConversationManager, buildSystemPrompt
)
from modules.shared.timeUtils import getUtcTimestamp

logger = logging.getLogger(__name__)

MAX_RETRIES_PER_TOOL = 3
RETRY_BASE_DELAY_S = 1.0


async def runAgentLoop(
    prompt: str,
    toolRegistry: ToolRegistry,
    config: AgentConfig,
    aiCallFn: Callable[[AiCallRequest], Awaitable[AiCallResponse]],
    getWorkflowCostFn: Callable[[], Awaitable[float]],
    workflowId: str,
    userId: str = "",
    featureInstanceId: str = "",
    buildRagContextFn: Callable[..., Awaitable[str]] = None,
    mandateId: str = "",
    aiCallStreamFn: Callable = None,
    userLanguage: str = "",
) -> AsyncGenerator[AgentEvent, None]:
    """Run the agent loop. Yields AgentEvent for each step (SSE-ready).

    Args:
        prompt: User prompt
        toolRegistry: Registry with available tools
        config: Agent configuration (maxRounds, maxCostCHF, etc.)
        aiCallFn: Function to call the AI (wraps serviceAi.callAi with billing)
        getWorkflowCostFn: Function to get current workflow cost
        workflowId: Workflow ID for tracking
        userId: User ID for tracing
        featureInstanceId: Feature instance ID for tracing
        buildRagContextFn: Optional async function to build RAG context before each round
        mandateId: Mandate ID for RAG scoping
        userLanguage: ISO 639-1 language code for agent responses
    """
    state = AgentState(workflowId=workflowId, maxRounds=config.maxRounds)
    trace = AgentTrace(
        workflowId=workflowId, userId=userId,
        featureInstanceId=featureInstanceId
    )

    tools = toolRegistry.getTools()
    toolDefinitions = toolRegistry.formatToolsForFunctionCalling()
    toolsText = toolRegistry.formatToolsForPrompt()

    systemPrompt = buildSystemPrompt(tools, toolsText, userLanguage=userLanguage)
    conversation = ConversationManager(systemPrompt)
    conversation.addUserMessage(prompt)

    while state.status == AgentStatusEnum.RUNNING and state.currentRound < state.maxRounds:
        state.currentRound += 1
        roundStartTime = time.time()
        roundLog = AgentRoundLog(roundNumber=state.currentRound)

        # RAG context injection (before each round for fresh relevance)
        if buildRagContextFn:
            try:
                latestUserMsg = ""
                for msg in reversed(conversation.messages):
                    if msg.get("role") == "user":
                        latestUserMsg = msg.get("content", "")
                        break
                ragContext = await buildRagContextFn(
                    currentPrompt=latestUserMsg or prompt,
                    workflowId=workflowId,
                    userId=userId,
                    featureInstanceId=featureInstanceId,
                    mandateId=mandateId,
                )
                if ragContext:
                    conversation.injectRagContext(ragContext)
            except Exception as ragErr:
                logger.warning(f"RAG context injection failed (non-blocking): {ragErr}")

        # Budget check
        budgetExceeded = await _checkBudget(config, getWorkflowCostFn)
        if budgetExceeded:
            state.status = AgentStatusEnum.BUDGET_EXCEEDED
            state.abortReason = "Workflow cost budget exceeded"
            yield AgentEvent(
                type=AgentEventTypeEnum.FINAL,
                content=_buildProgressSummary(state, "Budget exceeded. Here is the progress so far.")
            )
            break

        yield AgentEvent(
            type=AgentEventTypeEnum.AGENT_PROGRESS,
            data={
                "round": state.currentRound,
                "maxRounds": state.maxRounds,
                "totalAiCalls": state.totalAiCalls,
                "totalToolCalls": state.totalToolCalls,
                "costCHF": state.totalCostCHF
            }
        )

        # Progressive summarization
        if conversation.needsSummarization(state.currentRound):
            async def _summarizeCall(summaryPrompt: str) -> str:
                req = AiCallRequest(
                    prompt=summaryPrompt,
                    options=AiCallOptions(operationType=OperationTypeEnum.DATA_ANALYSE)
                )
                resp = await aiCallFn(req)
                state.totalCostCHF += resp.priceCHF
                state.totalAiCalls += 1
                return resp.content

            await conversation.summarize(state.currentRound, _summarizeCall)

        # AI call
        aiRequest = AiCallRequest(
            prompt="",
            options=AiCallOptions(
                operationType=OperationTypeEnum.AGENT,
                temperature=config.temperature
            ),
            messages=conversation.messages,
            tools=toolDefinitions
        )

        try:
            aiResponse = None
            streamedText = ""

            if aiCallStreamFn:
                async for chunk in aiCallStreamFn(aiRequest):
                    if isinstance(chunk, str):
                        streamedText += chunk
                        yield AgentEvent(type=AgentEventTypeEnum.CHUNK, content=chunk)
                    else:
                        aiResponse = chunk

                if aiResponse is None:
                    raise RuntimeError("Stream ended without final AiCallResponse")
            else:
                aiResponse = await aiCallFn(aiRequest)

        except Exception as e:
            logger.error(f"AI call failed in round {state.currentRound}: {e}", exc_info=True)
            state.status = AgentStatusEnum.ERROR
            state.abortReason = f"AI call error: {e}"
            yield AgentEvent(type=AgentEventTypeEnum.ERROR, content=str(e))
            break

        state.totalAiCalls += 1
        state.totalCostCHF += aiResponse.priceCHF
        state.totalProcessingTime += aiResponse.processingTime
        roundLog.aiModel = aiResponse.modelName
        roundLog.costCHF = aiResponse.priceCHF

        if aiResponse.errorCount > 0:
            state.status = AgentStatusEnum.ERROR
            state.abortReason = f"AI returned error: {aiResponse.content}"
            yield AgentEvent(type=AgentEventTypeEnum.ERROR, content=aiResponse.content)
            break

        # Parse response for tool calls
        toolCalls = _parseToolCalls(aiResponse)
        textContent = _extractTextContent(aiResponse)

        if textContent and not streamedText:
            yield AgentEvent(type=AgentEventTypeEnum.MESSAGE, content=textContent)

        if not toolCalls:
            state.status = AgentStatusEnum.COMPLETED
            conversation.addAssistantMessage(aiResponse.content)
            roundLog.durationMs = int((time.time() - roundStartTime) * 1000)
            trace.rounds.append(roundLog)
            yield AgentEvent(type=AgentEventTypeEnum.FINAL, content=textContent or aiResponse.content)
            break

        # Add assistant message with tool calls to conversation
        assistantToolCalls = _formatAssistantToolCalls(toolCalls)
        conversation.addAssistantMessage(textContent or "", assistantToolCalls)

        # Execute tool calls
        for tc in toolCalls:
            yield AgentEvent(
                type=AgentEventTypeEnum.TOOL_CALL,
                data={"toolName": tc.name, "args": tc.args}
            )

        results = await _executeToolCalls(toolCalls, toolRegistry, {
            "workflowId": workflowId,
            "userId": userId,
            "featureInstanceId": featureInstanceId
        })
        state.totalToolCalls += len(results)

        for result in results:
            roundLog.toolCalls.append(ToolCallLog(
                toolName=result.toolName,
                args=next((tc.args for tc in toolCalls if tc.id == result.toolCallId), {}),
                success=result.success,
                durationMs=result.durationMs,
                error=result.error
            ))
            yield AgentEvent(
                type=AgentEventTypeEnum.TOOL_RESULT,
                data={
                    "toolName": result.toolName,
                    "success": result.success,
                    "data": result.data[:500] if result.data else "",
                    "error": result.error
                }
            )
            if result.sideEvents:
                for sideEvt in result.sideEvents:
                    evtType = sideEvt.get("type", "")
|
||||||
|
try:
|
||||||
|
evtEnum = AgentEventTypeEnum(evtType)
|
||||||
|
except (ValueError, KeyError):
|
||||||
|
continue
|
||||||
|
yield AgentEvent(
|
||||||
|
type=evtEnum,
|
||||||
|
data=sideEvt.get("data"),
|
||||||
|
content=sideEvt.get("content"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add tool results to conversation
|
||||||
|
toolResultMessages = [
|
||||||
|
{"toolCallId": r.toolCallId, "toolName": r.toolName,
|
||||||
|
"content": r.data if r.success else f"Error: {r.error}"}
|
||||||
|
for r in results
|
||||||
|
]
|
||||||
|
conversation.addToolResults(toolResultMessages)
|
||||||
|
|
||||||
|
roundLog.durationMs = int((time.time() - roundStartTime) * 1000)
|
||||||
|
trace.rounds.append(roundLog)
|
||||||
|
|
||||||
|
# maxRounds reached
|
||||||
|
if state.currentRound >= state.maxRounds and state.status == AgentStatusEnum.RUNNING:
|
||||||
|
state.status = AgentStatusEnum.MAX_ROUNDS_REACHED
|
||||||
|
state.abortReason = f"Maximum rounds ({state.maxRounds}) reached"
|
||||||
|
yield AgentEvent(
|
||||||
|
type=AgentEventTypeEnum.FINAL,
|
||||||
|
content=_buildProgressSummary(state, "Maximum rounds reached.")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Agent summary
|
||||||
|
trace.completedAt = getUtcTimestamp()
|
||||||
|
trace.status = state.status
|
||||||
|
trace.totalRounds = state.currentRound
|
||||||
|
trace.totalToolCalls = state.totalToolCalls
|
||||||
|
trace.totalCostCHF = state.totalCostCHF
|
||||||
|
trace.abortReason = state.abortReason
|
||||||
|
|
||||||
|
yield AgentEvent(
|
||||||
|
type=AgentEventTypeEnum.AGENT_SUMMARY,
|
||||||
|
data={
|
||||||
|
"rounds": state.currentRound,
|
||||||
|
"totalAiCalls": state.totalAiCalls,
|
||||||
|
"totalToolCalls": state.totalToolCalls,
|
||||||
|
"costCHF": round(state.totalCostCHF, 4),
|
||||||
|
"processingTime": round(state.totalProcessingTime, 2),
|
||||||
|
"status": state.status.value,
|
||||||
|
"abortReason": state.abortReason
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _checkBudget(config: AgentConfig,
|
||||||
|
getWorkflowCostFn: Callable[[], Awaitable[float]]) -> bool:
|
||||||
|
"""Check if workflow budget is exceeded. Returns True if exceeded."""
|
||||||
|
if config.maxCostCHF is None:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
currentCost = await getWorkflowCostFn()
|
||||||
|
return currentCost > config.maxCostCHF
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Could not check workflow cost: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
async def _executeToolCalls(toolCalls: List[ToolCallRequest],
|
||||||
|
toolRegistry: ToolRegistry,
|
||||||
|
context: Dict[str, Any]) -> List[ToolResult]:
|
||||||
|
"""Execute tool calls: readOnly tools in parallel, others sequentially."""
|
||||||
|
readOnlyCalls = [tc for tc in toolCalls if toolRegistry.isReadOnly(tc.name)]
|
||||||
|
writeCalls = [tc for tc in toolCalls if not toolRegistry.isReadOnly(tc.name)]
|
||||||
|
|
||||||
|
results: Dict[str, ToolResult] = {}
|
||||||
|
|
||||||
|
if readOnlyCalls:
|
||||||
|
readResults = await asyncio.gather(*[
|
||||||
|
toolRegistry.dispatch(tc, context) for tc in readOnlyCalls
|
||||||
|
])
|
||||||
|
for tc, result in zip(readOnlyCalls, readResults):
|
||||||
|
results[tc.id] = result
|
||||||
|
|
||||||
|
for tc in writeCalls:
|
||||||
|
results[tc.id] = await toolRegistry.dispatch(tc, context)
|
||||||
|
|
||||||
|
return [results[tc.id] for tc in toolCalls]
|
||||||
|
|
||||||
|
|
||||||
|
def _parseToolCalls(aiResponse: AiCallResponse) -> List[ToolCallRequest]:
|
||||||
|
"""Parse tool calls from AI response. Supports native function calling and text-based fallback."""
|
||||||
|
toolCalls = []
|
||||||
|
|
||||||
|
# Native function calling: check response metadata
|
||||||
|
if hasattr(aiResponse, 'toolCalls') and aiResponse.toolCalls:
|
||||||
|
for tc in aiResponse.toolCalls:
|
||||||
|
rawArgs = tc["function"]["arguments"]
|
||||||
|
if isinstance(rawArgs, str):
|
||||||
|
rawArgs = rawArgs.strip()
|
||||||
|
try:
|
||||||
|
parsedArgs = json.loads(rawArgs) if rawArgs else {}
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.warning(f"Failed to parse tool args for '{tc['function']['name']}': {rawArgs[:200]}")
|
||||||
|
parsedArgs = {}
|
||||||
|
else:
|
||||||
|
parsedArgs = rawArgs if rawArgs else {}
|
||||||
|
toolCalls.append(ToolCallRequest(
|
||||||
|
id=tc.get("id", str(len(toolCalls))),
|
||||||
|
name=tc["function"]["name"],
|
||||||
|
args=parsedArgs,
|
||||||
|
))
|
||||||
|
return toolCalls
|
||||||
|
|
||||||
|
# Text-based fallback: parse ```tool_call blocks
|
||||||
|
content = aiResponse.content or ""
|
||||||
|
pattern = r"```tool_call\s*\n\s*tool:\s*(\S+)\s*\n\s*args:\s*(\{.*?\})\s*\n\s*```"
|
||||||
|
matches = re.finditer(pattern, content, re.DOTALL)
|
||||||
|
|
||||||
|
for match in matches:
|
||||||
|
toolName = match.group(1).strip()
|
||||||
|
argsStr = match.group(2).strip()
|
||||||
|
try:
|
||||||
|
args = json.loads(argsStr)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.warning(f"Failed to parse tool args for '{toolName}': {argsStr}")
|
||||||
|
args = {}
|
||||||
|
toolCalls.append(ToolCallRequest(name=toolName, args=args))
|
||||||
|
|
||||||
|
return toolCalls
|
||||||
|
|
||||||
|
|
||||||
|
def _extractTextContent(aiResponse: AiCallResponse) -> str:
|
||||||
|
"""Extract text content from AI response, removing tool_call blocks."""
|
||||||
|
content = aiResponse.content or ""
|
||||||
|
cleaned = re.sub(r"```tool_call\s*\n.*?\n\s*```", "", content, flags=re.DOTALL)
|
||||||
|
return cleaned.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def _formatAssistantToolCalls(toolCalls: List[ToolCallRequest]) -> List[Dict[str, Any]]:
|
||||||
|
"""Format tool calls for the conversation history (OpenAI tool_calls format)."""
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"id": tc.id,
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": tc.name,
|
||||||
|
"arguments": json.dumps(tc.args)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for tc in toolCalls
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _buildProgressSummary(state: AgentState, reason: str) -> str:
|
||||||
|
"""Build a human-readable summary of agent progress for graceful termination."""
|
||||||
|
return (
|
||||||
|
f"{reason}\n\n"
|
||||||
|
f"Progress after {state.currentRound} rounds:\n"
|
||||||
|
f"- AI calls: {state.totalAiCalls}\n"
|
||||||
|
f"- Tool calls: {state.totalToolCalls}\n"
|
||||||
|
f"- Cost: {state.totalCostCHF:.4f} CHF\n"
|
||||||
|
f"- Processing time: {state.totalProcessingTime:.1f}s"
|
||||||
|
)
|
||||||
|
|
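A minimal, self-contained sketch (not part of the commit) of the text-based fallback that _parseToolCalls handles: a model without native function calling emits fenced tool_call blocks, and the same regex recovers the tool name and JSON arguments. The tool name "searchFiles" is invented for the example.

import json
import re

sampleResponse = (
    "I will look for the document first.\n"
    "```tool_call\n"
    "tool: searchFiles\n"
    'args: {"search": "quarterly report"}\n'
    "```\n"
)

pattern = r"```tool_call\s*\n\s*tool:\s*(\S+)\s*\n\s*args:\s*(\{.*?\})\s*\n\s*```"
for match in re.finditer(pattern, sampleResponse, re.DOTALL):
    print(match.group(1), json.loads(match.group(2)))  # searchFiles {'search': 'quarterly report'}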
@@ -0,0 +1,265 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Conversation manager for the Agent service.

Handles message history, context window management, and progressive summarization."""

import logging
from typing import List, Dict, Any, Optional

from modules.serviceCenter.services.serviceAgent.datamodelAgent import ToolDefinition

logger = logging.getLogger(__name__)

FIRST_SUMMARY_ROUND = 4
META_SUMMARY_ROUND = 7
KEEP_RECENT_MESSAGES = 4
MAX_ESTIMATED_TOKENS = 60000


class ConversationManager:
    """Manages the conversation history and context window for agent runs.

    Progressive summarization strategy:
    - Rounds 1-3: full conversation retained
    - Round 4+: older messages compressed into a running summary
    - Round 7+: meta-summary replaces prior summaries

    Supports RAG context injection before each round via injectRagContext."""

    def __init__(self, systemPrompt: str):
        self._messages: List[Dict[str, Any]] = [
            {"role": "system", "content": systemPrompt}
        ]
        self._summaries: List[Dict[str, Any]] = []
        self._lastSummarizedRound: int = 0
        self._ragContextInjected: bool = False

    @property
    def messages(self) -> List[Dict[str, Any]]:
        """Current messages for the next AI call (internal markers stripped)."""
        return [
            {k: v for k, v in msg.items() if not k.startswith("_")}
            for msg in self._messages
        ]

    def addUserMessage(self, content: str):
        """Add a user message."""
        self._messages.append({"role": "user", "content": content})

    def addAssistantMessage(self, content: str, toolCalls: List[Dict[str, Any]] = None):
        """Add an assistant message, optionally with tool calls."""
        msg: Dict[str, Any] = {"role": "assistant", "content": content}
        if toolCalls:
            msg["tool_calls"] = toolCalls
        self._messages.append(msg)

    def addToolResults(self, results: List[Dict[str, Any]]):
        """Add tool results to the conversation.
        Each result: {toolCallId, toolName, content}."""
        for result in results:
            self._messages.append({
                "role": "tool",
                "tool_call_id": result["toolCallId"],
                "content": result["content"]
            })

    def addToolResultsAsText(self, resultText: str):
        """Add combined tool results as a user message (text-based fallback)."""
        self._messages.append({
            "role": "user",
            "content": f"Tool Results:\n{resultText}"
        })

    def injectRagContext(self, ragContext: str):
        """Inject RAG context as a system message right after the main system prompt.

        Called before each agent round by the agent loop if KnowledgeService is available.
        Replaces any previously injected RAG context to keep the context fresh."""
        if not ragContext:
            return

        ragMessage = {
            "role": "system",
            "content": f"Relevant Knowledge (from indexed documents and workflow context):\n{ragContext}",
            "_isRagContext": True,
        }

        # Replace existing RAG message if present, otherwise insert after system prompt
        for i, msg in enumerate(self._messages):
            if msg.get("_isRagContext"):
                self._messages[i] = ragMessage
                self._ragContextInjected = True
                return

        # Insert after the first system prompt
        self._messages.insert(1, ragMessage)
        self._ragContextInjected = True

    def getMessageCount(self) -> int:
        """Get the number of messages (excluding system prompt)."""
        return len(self._messages) - 1

    def estimateTokenCount(self) -> int:
        """Rough estimate of total tokens in the conversation (4 chars ≈ 1 token)."""
        totalChars = sum(len(str(m.get("content", ""))) for m in self._messages)
        return totalChars // 4

    def needsSummarization(self, currentRound: int) -> bool:
        """Check if progressive summarization should be triggered.

        Triggers:
        - At round FIRST_SUMMARY_ROUND (4) if not yet summarized
        - At round META_SUMMARY_ROUND (7) for meta-summary
        - Every 5 rounds after that
        - When estimated token count exceeds MAX_ESTIMATED_TOKENS
        """
        if currentRound >= FIRST_SUMMARY_ROUND and self._lastSummarizedRound < currentRound:
            if currentRound == FIRST_SUMMARY_ROUND or currentRound == META_SUMMARY_ROUND:
                return True
            if (currentRound - META_SUMMARY_ROUND) % 5 == 0 and currentRound > META_SUMMARY_ROUND:
                return True
        if self.estimateTokenCount() > MAX_ESTIMATED_TOKENS:
            return True
        return False

    async def summarize(self, currentRound: int, aiCallFn) -> Optional[str]:
        """Perform progressive summarization of older messages.

        Rounds 1-3: full history retained, no summarization.
        Round 4+: compress older messages into a running summary.
        Round 7+: meta-summary that consolidates prior summaries.
        """
        if currentRound < FIRST_SUMMARY_ROUND and self.estimateTokenCount() <= MAX_ESTIMATED_TOKENS:
            return None

        systemMsgs = [m for m in self._messages if m.get("role") == "system"]
        nonSystemMessages = [m for m in self._messages if m.get("role") != "system"]

        keepRecent = min(KEEP_RECENT_MESSAGES, len(nonSystemMessages))
        if len(nonSystemMessages) <= keepRecent + 1:
            return None

        messagesToSummarize = nonSystemMessages[:-keepRecent]
        recentMessages = nonSystemMessages[-keepRecent:]

        summaryInput = _formatMessagesForSummary(messagesToSummarize)
        previousSummary = self._summaries[-1]["content"] if self._summaries else ""

        isMetaSummary = currentRound >= META_SUMMARY_ROUND and len(self._summaries) >= 2
        summaryPrompt = _buildSummaryPrompt(summaryInput, previousSummary, isMetaSummary)

        try:
            summaryText = await aiCallFn(summaryPrompt)
        except Exception as e:
            logger.error(f"Progressive summarization failed: {e}")
            return None

        self._summaries.append({
            "round": currentRound,
            "content": summaryText,
            "isMeta": isMetaSummary,
        })
        self._lastSummarizedRound = currentRound

        mainSystem = systemMsgs[0] if systemMsgs else {"role": "system", "content": ""}
        ragMessages = [m for m in systemMsgs if m.get("_isRagContext")]

        self._messages = [
            mainSystem,
            *ragMessages,
            {"role": "system", "content": f"Conversation Summary (rounds 1-{currentRound - keepRecent}):\n{summaryText}"},
            *recentMessages,
        ]

        logger.info(
            f"Progressive summarization at round {currentRound}: "
            f"compressed {len(messagesToSummarize)} messages into "
            f"{'meta-' if isMetaSummary else ''}summary"
        )
        return summaryText


def _formatMessagesForSummary(messages: List[Dict[str, Any]]) -> str:
    """Format messages into a text block for summarization."""
    parts = []
    for msg in messages:
        role = msg.get("role", "unknown")
        content = msg.get("content", "")
        if role == "tool":
            toolName = msg.get("tool_call_id", "tool")
            parts.append(f"[Tool Result ({toolName})]:\n{content}")
        elif role == "assistant" and msg.get("tool_calls"):
            calls = msg["tool_calls"]
            callNames = [c.get("function", {}).get("name", "?") for c in calls]
            parts.append(f"[Assistant → Tool Calls: {', '.join(callNames)}]")
            if content:
                parts.append(f"[Assistant]: {content}")
        else:
            parts.append(f"[{role.capitalize()}]: {content}")
    return "\n\n".join(parts)


def _buildSummaryPrompt(messagesText: str, previousSummary: str, isMetaSummary: bool = False) -> str:
    """Build the prompt for progressive summarization."""
    if isMetaSummary:
        prompt = (
            "Create a comprehensive meta-summary consolidating the previous summary "
            "and the new messages. Preserve all key facts, decisions, entities (names, "
            "numbers, dates), tool results, and action outcomes. Be concise but complete.\n\n"
        )
    else:
        prompt = (
            "Summarize the following conversation concisely. Preserve all key facts, "
            "decisions, entities (names, numbers, dates), and tool results. "
            "Do not lose any important information.\n\n"
        )
    if previousSummary:
        prompt += f"Previous Summary:\n{previousSummary}\n\n"
    prompt += f"New Messages to Summarize:\n{messagesText}\n\nProvide a concise, factual summary:"
    return prompt


_LANGUAGE_NAMES = {
    "de": "German", "en": "English", "fr": "French", "it": "Italian",
    "es": "Spanish", "pt": "Portuguese", "nl": "Dutch", "ja": "Japanese",
    "zh": "Chinese", "ko": "Korean", "ar": "Arabic", "ru": "Russian",
}


def buildSystemPrompt(
    tools: List[ToolDefinition],
    toolsFormatted: str = None,
    userLanguage: str = "",
) -> str:
    """Build the system prompt for the agent.

    Args:
        tools: Available tool definitions.
        toolsFormatted: Pre-formatted tool descriptions for text-based fallback.
        userLanguage: ISO 639-1 language code (e.g. "de", "en"). The agent will
            respond in this language.
    """
    langName = _LANGUAGE_NAMES.get(userLanguage, "")
    langInstruction = (
        f"IMPORTANT: Always respond in {langName} ({userLanguage}). "
        f"The user's language is {langName}. All your messages, explanations, "
        f"and summaries MUST be in {langName}. "
        f"Only use English for tool call arguments and technical identifiers.\n\n"
    ) if langName else ""

    prompt = (
        f"{langInstruction}"
        "You are an AI agent with access to tools. "
        "Use the provided tools to accomplish the user's task. "
        "Think step by step. Call tools when you need information or need to perform actions. "
        "When you have enough information to answer, respond directly without calling tools.\n\n"
    )
    if toolsFormatted:
        prompt += f"Available Tools:\n{toolsFormatted}\n\n"
        prompt += (
            "To call a tool, use this format:\n"
            "```tool_call\n"
            "tool: <tool_name>\n"
            'args: {"param": "value"}\n'
            "```\n\n"
        )
    return prompt
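A small sketch (not part of the commit) of the round-based schedule that needsSummarization implements with FIRST_SUMMARY_ROUND=4 and META_SUMMARY_ROUND=7; the token-count trigger is ignored here.

FIRST_SUMMARY_ROUND = 4
META_SUMMARY_ROUND = 7

def triggersSummary(currentRound: int, lastSummarizedRound: int = 0) -> bool:
    # Mirrors only the round-based conditions of ConversationManager.needsSummarization.
    if currentRound >= FIRST_SUMMARY_ROUND and lastSummarizedRound < currentRound:
        if currentRound in (FIRST_SUMMARY_ROUND, META_SUMMARY_ROUND):
            return True
        if currentRound > META_SUMMARY_ROUND and (currentRound - META_SUMMARY_ROUND) % 5 == 0:
            return True
    return False

print([r for r in range(1, 20) if triggersSummary(r)])  # [4, 7, 12, 17]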
132  modules/serviceCenter/services/serviceAgent/datamodelAgent.py  Normal file
@@ -0,0 +1,132 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Data models for the Agent service."""

from typing import List, Dict, Any, Optional
from enum import Enum
from pydantic import BaseModel, Field
from modules.shared.timeUtils import getUtcTimestamp
import uuid


class AgentStatusEnum(str, Enum):
    RUNNING = "running"
    COMPLETED = "completed"
    MAX_ROUNDS_REACHED = "maxRoundsReached"
    BUDGET_EXCEEDED = "budgetExceeded"
    ERROR = "error"
    STOPPED = "stopped"


class AgentEventTypeEnum(str, Enum):
    MESSAGE = "message"
    CHUNK = "chunk"
    TOOL_CALL = "toolCall"
    TOOL_RESULT = "toolResult"
    AGENT_PROGRESS = "agentProgress"
    AGENT_SUMMARY = "agentSummary"
    FILE_CREATED = "fileCreated"
    DATA_SOURCE_ACCESS = "dataSourceAccess"
    VOICE_RESPONSE = "voiceResponse"
    FINAL = "final"
    ERROR = "error"


class ToolDefinition(BaseModel):
    """Schema for a tool available to the agent."""
    name: str = Field(description="Unique tool name")
    description: str = Field(description="What this tool does")
    parameters: Dict[str, Any] = Field(
        default_factory=dict,
        description="JSON Schema for tool parameters"
    )
    readOnly: bool = Field(
        default=False,
        description="If True, tool can run in parallel with other readOnly tools"
    )
    featureType: Optional[str] = Field(
        default=None,
        description="Feature scope for this tool (None = available to all)"
    )


class ToolCallRequest(BaseModel):
    """A tool call requested by the AI model."""
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    name: str
    args: Dict[str, Any] = Field(default_factory=dict)


class ToolResult(BaseModel):
    """Result from executing a tool."""
    toolCallId: str
    toolName: str
    success: bool = True
    data: str = ""
    error: Optional[str] = None
    durationMs: int = 0
    sideEvents: Optional[List[Dict[str, Any]]] = None


class AgentEvent(BaseModel):
    """Event emitted during agent execution for SSE streaming."""
    type: AgentEventTypeEnum
    content: Optional[str] = None
    data: Optional[Dict[str, Any]] = None


class AgentConfig(BaseModel):
    """Configuration for an agent run."""
    maxRounds: int = Field(default=25, ge=1, le=100)
    maxCostCHF: Optional[float] = Field(default=None, ge=0.0)
    entityCacheEnabled: bool = Field(default=False)
    toolSet: str = Field(default="core")
    temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0)


class AgentState(BaseModel):
    """Tracks state across an agent loop execution."""
    workflowId: str
    currentRound: int = 0
    maxRounds: int = 25
    totalAiCalls: int = 0
    totalToolCalls: int = 0
    totalCostCHF: float = 0.0
    totalProcessingTime: float = 0.0
    status: AgentStatusEnum = AgentStatusEnum.RUNNING
    abortReason: Optional[str] = None


class ToolCallLog(BaseModel):
    """Log of a single tool call for observability."""
    toolName: str
    args: Dict[str, Any] = Field(default_factory=dict)
    success: bool = True
    durationMs: int = 0
    error: Optional[str] = None


class AgentRoundLog(BaseModel):
    """Log of a single agent round for observability."""
    roundNumber: int
    aiModel: str = ""
    inputTokens: int = 0
    outputTokens: int = 0
    costCHF: float = 0.0
    toolCalls: List[ToolCallLog] = Field(default_factory=list)
    durationMs: int = 0


class AgentTrace(BaseModel):
    """Full trace of an agent workflow for observability."""
    workflowId: str
    userId: str = ""
    featureInstanceId: str = ""
    startedAt: float = Field(default_factory=getUtcTimestamp)
    completedAt: Optional[float] = None
    status: AgentStatusEnum = AgentStatusEnum.RUNNING
    totalRounds: int = 0
    totalToolCalls: int = 0
    totalCostCHF: float = 0.0
    abortReason: Optional[str] = None
    rounds: List[AgentRoundLog] = Field(default_factory=list)
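An illustrative sketch (not part of the commit) of how these models compose: an AgentConfig for a capped run and an AgentEvent serialized the way an SSE stream would send it. The budget of 0.50 CHF is an arbitrary example value.

from modules.serviceCenter.services.serviceAgent.datamodelAgent import (
    AgentConfig, AgentEvent, AgentEventTypeEnum
)

config = AgentConfig(maxRounds=10, maxCostCHF=0.50)  # capped agent run
event = AgentEvent(
    type=AgentEventTypeEnum.AGENT_PROGRESS,
    data={"round": 1, "maxRounds": config.maxRounds},
)
print(event.model_dump_json())  # JSON payload for one SSE event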
1293  modules/serviceCenter/services/serviceAgent/mainServiceAgent.py  Normal file
File diff suppressed because it is too large

143  modules/serviceCenter/services/serviceAgent/toolRegistry.py  Normal file
@@ -0,0 +1,143 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Tool registry for the Agent service. Manages tool definitions and dispatch."""

import logging
import time
from typing import Dict, List, Any, Optional, Callable, Awaitable

from modules.serviceCenter.services.serviceAgent.datamodelAgent import (
    ToolDefinition, ToolCallRequest, ToolResult
)

logger = logging.getLogger(__name__)


class ToolRegistry:
    """Registry for agent tools. Handles registration, lookup, and dispatch."""

    def __init__(self):
        self._tools: Dict[str, ToolDefinition] = {}
        self._handlers: Dict[str, Callable[..., Awaitable[ToolResult]]] = {}

    def register(self, name: str, handler: Callable[..., Awaitable[ToolResult]],
                 description: str = "", parameters: Dict[str, Any] = None,
                 readOnly: bool = False, featureType: str = None):
        """Register a tool with its handler function."""
        if name in self._tools:
            logger.warning(f"Tool '{name}' already registered, overwriting")

        self._tools[name] = ToolDefinition(
            name=name,
            description=description,
            parameters=parameters or {},
            readOnly=readOnly,
            featureType=featureType
        )
        self._handlers[name] = handler
        logger.debug(f"Registered tool: {name} (readOnly={readOnly})")

    def registerFromDefinition(self, definition: ToolDefinition,
                               handler: Callable[..., Awaitable[ToolResult]]):
        """Register a tool from a pre-built ToolDefinition."""
        self._tools[definition.name] = definition
        self._handlers[definition.name] = handler
        logger.debug(f"Registered tool: {definition.name} (readOnly={definition.readOnly})")

    def unregister(self, name: str):
        """Remove a tool from the registry."""
        self._tools.pop(name, None)
        self._handlers.pop(name, None)

    def getTools(self, toolSet: str = None, featureType: str = None) -> List[ToolDefinition]:
        """Get available tools, optionally filtered by toolSet or featureType."""
        tools = list(self._tools.values())
        if featureType:
            tools = [t for t in tools if t.featureType is None or t.featureType == featureType]
        return tools

    def getToolNames(self) -> List[str]:
        """Get names of all registered tools."""
        return list(self._tools.keys())

    def getTool(self, name: str) -> Optional[ToolDefinition]:
        """Get a single tool definition by name."""
        return self._tools.get(name)

    def isReadOnly(self, name: str) -> bool:
        """Check if a tool is marked as readOnly."""
        tool = self._tools.get(name)
        return tool.readOnly if tool else False

    def isValidTool(self, name: str) -> bool:
        """Check if a tool name is valid (registered)."""
        return name in self._tools

    async def dispatch(self, toolCall: ToolCallRequest, context: Dict[str, Any] = None) -> ToolResult:
        """Execute a tool call and return the result."""
        startTime = time.time()

        if not self.isValidTool(toolCall.name):
            return ToolResult(
                toolCallId=toolCall.id,
                toolName=toolCall.name,
                success=False,
                error=f"Unknown tool: '{toolCall.name}'. Available: {', '.join(self.getToolNames())}"
            )

        handler = self._handlers[toolCall.name]
        try:
            result = await handler(toolCall.args, context or {})
            durationMs = int((time.time() - startTime) * 1000)

            if isinstance(result, ToolResult):
                result.toolCallId = toolCall.id
                result.durationMs = durationMs
                return result

            return ToolResult(
                toolCallId=toolCall.id,
                toolName=toolCall.name,
                success=True,
                data=str(result),
                durationMs=durationMs
            )

        except Exception as e:
            durationMs = int((time.time() - startTime) * 1000)
            logger.error(f"Tool '{toolCall.name}' failed: {e}", exc_info=True)
            return ToolResult(
                toolCallId=toolCall.id,
                toolName=toolCall.name,
                success=False,
                error=str(e),
                durationMs=durationMs
            )

    def formatToolsForPrompt(self) -> str:
        """Format all tools as text for system prompt (text-based fallback)."""
        parts = []
        for tool in self._tools.values():
            paramStr = ", ".join(
                f"{k}: {v}" for k, v in tool.parameters.items()
            ) if tool.parameters else "none"
            parts.append(f"- **{tool.name}**: {tool.description}\n  Parameters: {{{paramStr}}}")
        return "\n".join(parts)

    def formatToolsForFunctionCalling(self) -> List[Dict[str, Any]]:
        """Format all tools as OpenAI-compatible function definitions for native function calling."""
        functions = []
        for tool in self._tools.values():
            functions.append({
                "type": "function",
                "function": {
                    "name": tool.name,
                    "description": tool.description,
                    "parameters": tool.parameters if tool.parameters else {
                        "type": "object",
                        "properties": {},
                        "required": []
                    }
                }
            })
        return functions
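An illustrative usage sketch (not part of the commit): registering a read-only tool and dispatching a call against it. The "echo" tool and its parameters are invented for the example.

import asyncio

from modules.serviceCenter.services.serviceAgent.toolRegistry import ToolRegistry
from modules.serviceCenter.services.serviceAgent.datamodelAgent import ToolCallRequest, ToolResult


async def echoHandler(args, context) -> ToolResult:
    # Handlers receive the parsed args plus the dispatch context dict.
    return ToolResult(toolCallId="", toolName="echo", data=str(args.get("text", "")))


async def main():
    registry = ToolRegistry()
    registry.register(
        "echo", echoHandler,
        description="Echo the given text",
        parameters={"type": "object", "properties": {"text": {"type": "string"}}},
        readOnly=True,
    )
    result = await registry.dispatch(ToolCallRequest(name="echo", args={"text": "hi"}), {})
    print(result.success, result.data)  # True hi

asyncio.run(main())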
@@ -146,6 +146,8 @@ class AiService:
         3. billingCallback on aiObjects: records one billing transaction per model call
            with exact provider + model name (set before AI call, invoked by _callWithModel)
         """
+        await self.ensureAiObjectsInitialized()
+
         # SPEECH_TEAMS: Dedicated pipeline, bypasses standard model selection
         if request.options and request.options.operationType == OperationTypeEnum.SPEECH_TEAMS:
             return await self._handleSpeechTeams(request)
@@ -179,6 +181,23 @@ class AiService:

         return response

+    async def callAiStream(self, request: AiCallRequest):
+        """Streaming variant of callAi. Yields str deltas during generation, then final AiCallResponse."""
+        await self.ensureAiObjectsInitialized()
+        self._preflightBillingCheck()
+        await self._checkBillingBeforeAiCall()
+
+        effectiveProviders = self._calculateEffectiveProviders()
+        if effectiveProviders and request.options:
+            request.options = request.options.model_copy(update={'allowedProviders': effectiveProviders})
+
+        self.aiObjects.billingCallback = self._createBillingCallback()
+        try:
+            async for chunk in self.aiObjects.callWithTextContextStream(request):
+                yield chunk
+        finally:
+            self.aiObjects.billingCallback = None
+
     # =========================================================================
     # SPEECH_TEAMS: Dedicated handler for Teams Meeting AI analysis
     # Bypasses standard model selection. Uses a fixed fast model.
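A hedged consumer sketch (not part of the commit): callAiStream follows the same streaming contract as the connector layer, yielding str deltas first and the final response object last. The aiService argument is a placeholder for a constructed AiService instance.

async def consumeStream(aiService, request):
    # Collect deltas as they arrive; the last yielded item is the full response.
    finalResponse = None
    async for chunk in aiService.callAiStream(request):
        if isinstance(chunk, str):
            print(chunk, end="", flush=True)   # incremental text delta
        else:
            finalResponse = chunk              # final AiCallResponse with cost and usage
    return finalResponse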
@@ -411,23 +411,158 @@ class ChatService:
         return None

     def getFileInfo(self, fileId: str) -> Dict[str, Any]:
-        """Get file information"""
-        file_item = self.interfaceDbComponent.getFile(fileId)
-        if file_item:
+        """Get file information including new fields (tags, folderId, description, status)."""
+        fileItem = self.interfaceDbComponent.getFile(fileId)
+        if fileItem:
             return {
-                "id": file_item.id,
-                "fileName": file_item.fileName,
-                "size": file_item.fileSize,
-                "mimeType": file_item.mimeType,
-                "fileHash": file_item.fileHash,
-                "creationDate": file_item.creationDate
+                "id": fileItem.id,
+                "fileName": fileItem.fileName,
+                "size": fileItem.fileSize,
+                "mimeType": fileItem.mimeType,
+                "fileHash": fileItem.fileHash,
+                "creationDate": fileItem.creationDate,
+                "tags": getattr(fileItem, "tags", None),
+                "folderId": getattr(fileItem, "folderId", None),
+                "description": getattr(fileItem, "description", None),
+                "status": getattr(fileItem, "status", None),
             }
         return None

     def getFileData(self, fileId: str) -> bytes:
-        """Get file data by ID"""
+        """Get file data by ID."""
         return self.interfaceDbComponent.getFileData(fileId)

+    def getFileContent(self, fileId: str) -> Optional[Dict[str, Any]]:
+        """Get file content as text or base64 via FilePreview."""
+        preview = self.interfaceDbComponent.getFileContent(fileId)
+        if preview:
+            return preview.toDictWithBase64Encoding()
+        return None
+
+    def listFiles(
+        self,
+        folderId: str = None,
+        tags: List[str] = None,
+        search: str = None,
+    ) -> List[Dict[str, Any]]:
+        """List files for the current user with optional filters.
+
+        Args:
+            folderId: Filter by folder (None = root / all).
+            tags: Filter by tags (any match).
+            search: Search in fileName and description.
+
+        Returns:
+            List of file info dicts.
+        """
+        allFiles = self.interfaceDbComponent.getAllFiles()
+        results = []
+        for fileItem in allFiles:
+            if folderId is not None:
+                itemFolderId = getattr(fileItem, "folderId", None)
+                if itemFolderId != folderId:
+                    continue
+
+            if tags:
+                itemTags = getattr(fileItem, "tags", None) or []
+                if not any(t in itemTags for t in tags):
+                    continue
+
+            if search:
+                searchLower = search.lower()
+                nameMatch = searchLower in (fileItem.fileName or "").lower()
+                descMatch = searchLower in (getattr(fileItem, "description", None) or "").lower()
+                if not nameMatch and not descMatch:
+                    continue
+
+            results.append({
+                "id": fileItem.id,
+                "fileName": fileItem.fileName,
+                "mimeType": fileItem.mimeType,
+                "fileSize": fileItem.fileSize,
+                "creationDate": fileItem.creationDate,
+                "tags": getattr(fileItem, "tags", None),
+                "folderId": getattr(fileItem, "folderId", None),
+                "description": getattr(fileItem, "description", None),
+                "status": getattr(fileItem, "status", None),
+            })
+        return results
+
+    def listFolders(self, parentId: str = None) -> List[Dict[str, Any]]:
+        """List file folders for the current user.
+
+        Args:
+            parentId: Parent folder ID (None = root folders).
+
+        Returns:
+            List of folder dicts.
+        """
+        from modules.datamodels.datamodelFileFolder import FileFolder
+        recordFilter = {"_createdBy": self.user.id if self.user else ""}
+        if parentId is not None:
+            recordFilter["parentId"] = parentId
+        else:
+            recordFilter["parentId"] = None
+        return self.interfaceDbComponent.db.getRecordset(FileFolder, recordFilter=recordFilter)
+
+    def createFolder(self, name: str, parentId: str = None) -> Dict[str, Any]:
+        """Create a new file folder."""
+        from modules.datamodels.datamodelFileFolder import FileFolder
+        folder = FileFolder(name=name, parentId=parentId)
+        return self.interfaceDbComponent.db.recordCreate(FileFolder, folder)
+
+    # ---- DataSource CRUD ----
+
+    def createDataSource(
+        self, connectionId: str, sourceType: str, path: str, label: str,
+        featureInstanceId: str = None
+    ) -> Dict[str, Any]:
+        """Create a new external data source reference."""
+        from modules.datamodels.datamodelDataSource import DataSource
+        ds = DataSource(
+            connectionId=connectionId,
+            sourceType=sourceType,
+            path=path,
+            label=label,
+            featureInstanceId=featureInstanceId or self._context.feature_instance_id or "",
+            mandateId=self._context.mandate_id or "",
+            userId=self.user.id if self.user else "",
+        )
+        return self.interfaceDbComponent.db.recordCreate(DataSource, ds)
+
+    def listDataSources(self, featureInstanceId: str = None) -> List[Dict[str, Any]]:
+        """List data sources, optionally filtered by feature instance."""
+        from modules.datamodels.datamodelDataSource import DataSource
+        recordFilter = {}
+        if featureInstanceId:
+            recordFilter["featureInstanceId"] = featureInstanceId
+        return self.interfaceDbComponent.db.getRecordset(DataSource, recordFilter=recordFilter)
+
+    def getDataSource(self, dataSourceId: str) -> Optional[Dict[str, Any]]:
+        """Get a single data source by ID."""
+        from modules.datamodels.datamodelDataSource import DataSource
+        return self.interfaceDbComponent.db.loadRecord(DataSource, dataSourceId)
+
+    def deleteDataSource(self, dataSourceId: str) -> bool:
+        """Delete a data source."""
+        from modules.datamodels.datamodelDataSource import DataSource
+        try:
+            self.interfaceDbComponent.db.recordDelete(DataSource, dataSourceId)
+            return True
+        except Exception as e:
+            logger.error(f"Failed to delete DataSource {dataSourceId}: {e}")
+            return False
+
+    def getUserConnections(self) -> List[Dict[str, Any]]:
+        """Get all UserConnections for the current user."""
+        try:
+            if self.interfaceDbApp and self.user:
+                connections = self.interfaceDbApp.getUserConnections(self.user.id)
+                return [c.model_dump() if hasattr(c, "model_dump") else c for c in (connections or [])]
+        except Exception as e:
+            logger.error(f"Error getting user connections: {e}")
+        return []
+
     def _diagnoseDocumentAccess(self, document: ChatDocument) -> Dict[str, Any]:
         """
         Diagnose document access issues and provide recovery information.
@@ -0,0 +1,175 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Audio extractor for common audio formats.

Extracts metadata (duration, bitrate, sample rate, channels) and produces
an `audiostream` ContentPart. For files under 10 MB the base64 audio data
is included; larger files only get metadata.

Optional dependency: mutagen (for rich metadata).
"""

from typing import Any, Dict, List
import base64
import logging
import struct

from modules.datamodels.datamodelExtraction import ContentPart
from ..subUtils import makeId
from ..subRegistry import Extractor

logger = logging.getLogger(__name__)

_AUDIO_MIME_TYPES = [
    "audio/mpeg",
    "audio/mp3",
    "audio/wav",
    "audio/x-wav",
    "audio/ogg",
    "audio/flac",
    "audio/x-flac",
    "audio/mp4",
    "audio/x-m4a",
    "audio/aac",
    "audio/webm",
]
_AUDIO_EXTENSIONS = [".mp3", ".wav", ".ogg", ".flac", ".m4a", ".aac", ".wma", ".webm"]

_MAX_INLINE_SIZE = 10 * 1024 * 1024  # 10 MB


class AudioExtractor(Extractor):
    """Extractor for audio files.

    Produces:
    - 1 text ContentPart with metadata summary
    - 1 audiostream ContentPart (base64 data included only if < 10 MB)
    """

    def detect(self, fileName: str, mimeType: str, headBytes: bytes) -> bool:
        if mimeType in _AUDIO_MIME_TYPES:
            return True
        lower = (fileName or "").lower()
        return any(lower.endswith(ext) for ext in _AUDIO_EXTENSIONS)

    def getSupportedExtensions(self) -> list[str]:
        return list(_AUDIO_EXTENSIONS)

    def getSupportedMimeTypes(self) -> list[str]:
        return list(_AUDIO_MIME_TYPES)

    def extract(self, fileBytes: bytes, context: Dict[str, Any]) -> List[ContentPart]:
        fileName = context.get("fileName", "audio")
        mimeType = context.get("mimeType") or "audio/mpeg"
        fileSize = len(fileBytes)

        rootId = makeId()
        parts: List[ContentPart] = []

        meta = _extractMetadata(fileBytes, fileName)
        meta["size"] = fileSize
        meta["fileName"] = fileName
        meta["mimeType"] = mimeType

        metaLines = [f"Audio file: {fileName}"]
        if meta.get("duration"):
            mins = int(meta["duration"] // 60)
            secs = int(meta["duration"] % 60)
            metaLines.append(f"Duration: {mins}:{secs:02d}")
        if meta.get("bitrate"):
            metaLines.append(f"Bitrate: {meta['bitrate']} kbps")
        if meta.get("sampleRate"):
            metaLines.append(f"Sample rate: {meta['sampleRate']} Hz")
        if meta.get("channels"):
            metaLines.append(f"Channels: {meta['channels']}")
        if meta.get("title") or meta.get("artist") or meta.get("album"):
            metaLines.append(f"Title: {meta.get('title', 'N/A')}")
            metaLines.append(f"Artist: {meta.get('artist', 'N/A')}")
            metaLines.append(f"Album: {meta.get('album', 'N/A')}")
        metaLines.append(f"Size: {fileSize:,} bytes")

        parts.append(ContentPart(
            id=rootId, parentId=None, label="metadata",
            typeGroup="text", mimeType="text/plain",
            data="\n".join(metaLines), metadata=meta,
        ))

        audioData = ""
        if fileSize <= _MAX_INLINE_SIZE:
            audioData = base64.b64encode(fileBytes).decode("utf-8")

        parts.append(ContentPart(
            id=makeId(), parentId=rootId, label="audiostream",
            typeGroup="audiostream", mimeType=mimeType,
            data=audioData, metadata={"size": fileSize, "inlined": fileSize <= _MAX_INLINE_SIZE},
        ))

        return parts


def _extractMetadata(fileBytes: bytes, fileName: str) -> Dict[str, Any]:
    """Extract audio metadata using mutagen (optional) with stdlib fallback."""
    meta: Dict[str, Any] = {}

    try:
        import mutagen
        import io
        audio = mutagen.File(io.BytesIO(fileBytes))
        if audio is not None:
            if audio.info:
                meta["duration"] = getattr(audio.info, "length", None)
                meta["bitrate"] = getattr(audio.info, "bitrate", None)
                if meta["bitrate"]:
                    meta["bitrate"] = meta["bitrate"] // 1000
                meta["sampleRate"] = getattr(audio.info, "sample_rate", None)
                meta["channels"] = getattr(audio.info, "channels", None)

            tags = audio.tags
            if tags:
                meta["title"] = _getTag(tags, ["TIT2", "title", "\xa9nam"])
                meta["artist"] = _getTag(tags, ["TPE1", "artist", "\xa9ART"])
                meta["album"] = _getTag(tags, ["TALB", "album", "\xa9alb"])

            return {k: v for k, v in meta.items() if v is not None}
    except ImportError:
        logger.debug("mutagen not installed -- using basic metadata extraction")
    except Exception as e:
        logger.debug(f"mutagen metadata extraction failed: {e}")

    lower = fileName.lower()
    if lower.endswith(".wav"):
        meta.update(_parseWavHeader(fileBytes))

    return {k: v for k, v in meta.items() if v is not None}


def _getTag(tags, keys: list) -> Any:
    """Try multiple tag keys and return the first found value."""
    for key in keys:
        val = tags.get(key)
        if val is not None:
            return str(val) if not isinstance(val, str) else val
    return None


def _parseWavHeader(fileBytes: bytes) -> Dict[str, Any]:
    """Minimal WAV header parser for basic metadata."""
    meta: Dict[str, Any] = {}
    if len(fileBytes) < 44:
        return meta
    try:
        if fileBytes[:4] != b"RIFF" or fileBytes[8:12] != b"WAVE":
            return meta
        channels = struct.unpack_from("<H", fileBytes, 22)[0]
        sampleRate = struct.unpack_from("<I", fileBytes, 24)[0]
        bitsPerSample = struct.unpack_from("<H", fileBytes, 34)[0]
        dataSize = struct.unpack_from("<I", fileBytes, 40)[0]

        meta["channels"] = channels
        meta["sampleRate"] = sampleRate
        meta["bitrate"] = (sampleRate * channels * bitsPerSample) // 1000
        if sampleRate and channels and bitsPerSample:
            meta["duration"] = dataSize / (sampleRate * channels * (bitsPerSample / 8))
    except Exception:
        pass
    return meta
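A small worked example (not part of the commit) of the arithmetic _parseWavHeader relies on: duration is the data-chunk size divided by the byte rate (sampleRate * channels * bytesPerSample). The numbers below are illustrative.

sampleRate, channels, bitsPerSample = 44100, 2, 16
dataSize = 10 * 1024 * 1024                             # 10 MiB of PCM data
byteRate = sampleRate * channels * (bitsPerSample / 8)  # 176400 bytes per second
print(round(dataSize / byteRate, 2))                    # ~59.44 seconds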
@@ -0,0 +1,339 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Container extractor for ZIP, TAR, GZ, and 7Z archives.

Recursively unpacks containers and delegates each contained file to the
appropriate type-specific extractor via the ExtractorRegistry.

Safety limits:
- MAX_TOTAL_EXTRACTED_SIZE: 500 MB
- MAX_FILE_COUNT: 10000
- maxDepth: 5
- Symlinks blocked
"""

from typing import Any, Dict, List, Optional
import io
import logging
import mimetypes
import zipfile
import tarfile

from ..subUtils import makeId
from modules.datamodels.datamodelExtraction import ContentPart
from modules.datamodels.datamodelContent import ContainerLimitError, ContentContextRef
from ..subRegistry import Extractor

logger = logging.getLogger(__name__)

MAX_TOTAL_EXTRACTED_SIZE = 500 * 1024 * 1024  # 500 MB
MAX_FILE_COUNT = 10000
MAX_DEPTH = 5

_CONTAINER_MIME_TYPES = [
    "application/zip",
    "application/x-zip-compressed",
    "application/x-tar",
    "application/gzip",
    "application/x-gzip",
    "application/x-7z-compressed",
]
_CONTAINER_EXTENSIONS = [".zip", ".tar", ".gz", ".tar.gz", ".tgz", ".7z"]


def _detectMimeType(fileName: str) -> str:
    """Detect MIME type from file name."""
    guessed, _ = mimetypes.guess_type(fileName)
    return guessed or "application/octet-stream"


def _isSymlink(info) -> bool:
    """Check if a tar member is a symlink."""
    if hasattr(info, "issym") and callable(info.issym):
        return info.issym() or info.islnk()
    return False


class ContainerExtractor(Extractor):
    """Extractor for archive containers (ZIP, TAR, GZ, 7Z).

    Recursively resolves nested containers and produces a flat list of
    ContentPart entries -- one per contained file -- with containerPath metadata.
    """

    def detect(self, fileName: str, mimeType: str, headBytes: bytes) -> bool:
        if mimeType in _CONTAINER_MIME_TYPES:
            return True
        lower = (fileName or "").lower()
        return any(lower.endswith(ext) for ext in _CONTAINER_EXTENSIONS)

    def getSupportedExtensions(self) -> list[str]:
        return list(_CONTAINER_EXTENSIONS)

    def getSupportedMimeTypes(self) -> list[str]:
        return list(_CONTAINER_MIME_TYPES)

    def extract(self, fileBytes: bytes, context: Dict[str, Any]) -> List[ContentPart]:
        """Extract by recursively unpacking the container."""
        fileName = context.get("fileName", "archive")
        mimeType = context.get("mimeType", "application/octet-stream")

        rootId = makeId()
        parts: List[ContentPart] = [
            ContentPart(
                id=rootId,
                parentId=None,
                label=fileName,
                typeGroup="container",
                mimeType=mimeType,
                data="",
                metadata={"size": len(fileBytes), "containerType": "archive"},
            )
        ]

        state = {"totalSize": 0, "fileCount": 0}
        try:
            childParts = _resolveContainerRecursive(
                fileBytes, mimeType, fileName, rootId, "", 0, state
            )
            parts.extend(childParts)
        except ContainerLimitError as e:
            logger.warning(f"Container limit reached for {fileName}: {e}")
            parts.append(ContentPart(
                id=makeId(),
                parentId=rootId,
                label="limit_exceeded",
                typeGroup="text",
                mimeType="text/plain",
                data=str(e),
                metadata={"warning": "Container extraction limit exceeded"},
            ))

        return parts


def _resolveContainerRecursive(
    containerBytes: bytes,
    containerMime: str,
    containerName: str,
    parentId: str,
    containerPath: str,
    depth: int,
    state: Dict[str, int],
) -> List[ContentPart]:
    """Recursively unpack containers. No AI calls."""
    if depth > MAX_DEPTH:
        raise ContainerLimitError(f"Max nesting depth {MAX_DEPTH} exceeded")

    parts: List[ContentPart] = []

    if containerMime in ("application/zip", "application/x-zip-compressed") or containerName.lower().endswith(".zip"):
        parts.extend(_extractZip(containerBytes, parentId, containerPath, depth, state))
    elif containerMime in ("application/x-tar",) or containerName.lower().endswith(".tar"):
        parts.extend(_extractTar(containerBytes, parentId, containerPath, depth, state, compressed=False))
    elif containerMime in ("application/gzip", "application/x-gzip") or containerName.lower().endswith((".gz", ".tgz", ".tar.gz")):
        parts.extend(_extractTar(containerBytes, parentId, containerPath, depth, state, compressed=True))
    elif containerName.lower().endswith(".7z"):
        parts.extend(_extract7z(containerBytes, parentId, containerPath, depth, state))
    else:
        logger.warning(f"Unknown container format: {containerMime} ({containerName})")

    return parts


def _addFilePart(
    data: bytes,
    fileName: str,
    parentId: str,
    containerPath: str,
    state: Dict[str, int],
) -> List[ContentPart]:
    """Extract a file via its type-specific Extractor and return ContentParts."""
    state["totalSize"] += len(data)
    state["fileCount"] += 1

    if state["totalSize"] > MAX_TOTAL_EXTRACTED_SIZE:
        raise ContainerLimitError(f"Total extracted size exceeds {MAX_TOTAL_EXTRACTED_SIZE // (1024 * 1024)} MB")
    if state["fileCount"] > MAX_FILE_COUNT:
        raise ContainerLimitError(f"File count exceeds {MAX_FILE_COUNT}")

    entryPath = f"{containerPath}/{fileName}" if containerPath else fileName
    detectedMime = _detectMimeType(fileName)

    from ..subRegistry import ExtractorRegistry
    registry = ExtractorRegistry()
    extractor = registry.resolve(detectedMime, fileName)

    if extractor and not isinstance(extractor, ContainerExtractor):
        try:
            childParts = extractor.extract(data, {"fileName": fileName, "mimeType": detectedMime})
            for part in childParts:
                part.parentId = parentId
                if not part.metadata:
                    part.metadata = {}
                part.metadata["containerPath"] = entryPath
            return childParts
        except Exception as e:
            logger.warning(f"Type-extractor failed for {fileName} in container: {e}")

    import base64
    encodedData = base64.b64encode(data).decode("utf-8") if data else ""

    return [ContentPart(
        id=makeId(),
        parentId=parentId,
        label=fileName,
        typeGroup="binary",
        mimeType=detectedMime,
        data=encodedData,
        metadata={
            "size": len(data),
            "containerPath": entryPath,
            "contextRef": ContentContextRef(
                containerPath=entryPath,
                location="file",
            ).model_dump(),
        },
    )]


def _isNestedContainer(fileName: str, mimeType: str) -> bool:
    lower = fileName.lower()
    return any(lower.endswith(ext) for ext in _CONTAINER_EXTENSIONS) or mimeType in _CONTAINER_MIME_TYPES


def _extractZip(
    data: bytes, parentId: str, containerPath: str, depth: int, state: Dict[str, int]
) -> List[ContentPart]:
    parts: List[ContentPart] = []
    try:
        with zipfile.ZipFile(io.BytesIO(data)) as zf:
            for info in zf.infolist():
                if info.is_dir():
                    continue
                if info.file_size == 0:
                    continue

                entryPath = f"{containerPath}/{info.filename}" if containerPath else info.filename
                entryMime = _detectMimeType(info.filename)
                entryData = zf.read(info.filename)

                if _isNestedContainer(info.filename, entryMime):
                    nestedId = makeId()
                    parts.append(ContentPart(
                        id=nestedId,
                        parentId=parentId,
                        label=info.filename,
                        typeGroup="container",
                        mimeType=entryMime,
                        data="",
                        metadata={"size": len(entryData), "containerPath": entryPath},
                    ))
                    nested = _resolveContainerRecursive(
                        entryData, entryMime, info.filename, nestedId, entryPath, depth + 1, state
                    )
                    parts.extend(nested)
                else:
                    parts.extend(_addFilePart(entryData, info.filename, parentId, containerPath, state))
    except zipfile.BadZipFile as e:
        logger.error(f"Invalid ZIP file: {e}")
        parts.append(ContentPart(
            id=makeId(), parentId=parentId, label="error",
            typeGroup="text", mimeType="text/plain",
            data=f"Invalid ZIP archive: {e}", metadata={"error": True},
        ))
    return parts


def _extractTar(
    data: bytes, parentId: str, containerPath: str, depth: int, state: Dict[str, int],
    compressed: bool = False,
) -> List[ContentPart]:
|
||||||
|
parts: List[ContentPart] = []
|
||||||
|
mode = "r:gz" if compressed else "r:"
|
||||||
|
try:
|
||||||
|
with tarfile.open(fileobj=io.BytesIO(data), mode=mode) as tf:
|
||||||
|
for member in tf.getmembers():
|
||||||
|
if member.isdir():
|
||||||
|
continue
|
||||||
|
if _isSymlink(member):
|
||||||
|
logger.warning(f"Skipping symlink in TAR: {member.name}")
|
||||||
|
continue
|
||||||
|
if member.size == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
entryPath = f"{containerPath}/{member.name}" if containerPath else member.name
|
||||||
|
entryMime = _detectMimeType(member.name)
|
||||||
|
fobj = tf.extractfile(member)
|
||||||
|
if fobj is None:
|
||||||
|
continue
|
||||||
|
entryData = fobj.read()
|
||||||
|
|
||||||
|
if _isNestedContainer(member.name, entryMime):
|
||||||
|
nestedId = makeId()
|
||||||
|
parts.append(ContentPart(
|
||||||
|
id=nestedId, parentId=parentId, label=member.name,
|
||||||
|
typeGroup="container", mimeType=entryMime, data="",
|
||||||
|
metadata={"size": len(entryData), "containerPath": entryPath},
|
||||||
|
))
|
||||||
|
nested = _resolveContainerRecursive(
|
||||||
|
entryData, entryMime, member.name, nestedId, entryPath, depth + 1, state
|
||||||
|
)
|
||||||
|
parts.extend(nested)
|
||||||
|
else:
|
||||||
|
parts.extend(_addFilePart(entryData, member.name, parentId, containerPath, state))
|
||||||
|
except tarfile.TarError as e:
|
||||||
|
logger.error(f"Invalid TAR file: {e}")
|
||||||
|
parts.append(ContentPart(
|
||||||
|
id=makeId(), parentId=parentId, label="error",
|
||||||
|
typeGroup="text", mimeType="text/plain",
|
||||||
|
data=f"Invalid TAR archive: {e}", metadata={"error": True},
|
||||||
|
))
|
||||||
|
return parts
|
||||||
|
|
||||||
|
|
||||||
|
def _extract7z(
|
||||||
|
data: bytes, parentId: str, containerPath: str, depth: int, state: Dict[str, int]
|
||||||
|
) -> List[ContentPart]:
|
||||||
|
"""Extract 7z archive. Requires py7zr (optional dependency)."""
|
||||||
|
parts: List[ContentPart] = []
|
||||||
|
try:
|
||||||
|
import py7zr
|
||||||
|
with py7zr.SevenZipFile(io.BytesIO(data), mode="r") as szf:
|
||||||
|
allFiles = szf.readall()
|
||||||
|
for fileName, bio in allFiles.items():
|
||||||
|
entryData = bio.read() if hasattr(bio, "read") else bytes(bio)
|
||||||
|
if not entryData:
|
||||||
|
continue
|
||||||
|
|
||||||
|
entryPath = f"{containerPath}/{fileName}" if containerPath else fileName
|
||||||
|
entryMime = _detectMimeType(fileName)
|
||||||
|
|
||||||
|
if _isNestedContainer(fileName, entryMime):
|
||||||
|
nestedId = makeId()
|
||||||
|
parts.append(ContentPart(
|
||||||
|
id=nestedId, parentId=parentId, label=fileName,
|
||||||
|
typeGroup="container", mimeType=entryMime, data="",
|
||||||
|
metadata={"size": len(entryData), "containerPath": entryPath},
|
||||||
|
))
|
||||||
|
nested = _resolveContainerRecursive(
|
||||||
|
entryData, entryMime, fileName, nestedId, entryPath, depth + 1, state
|
||||||
|
)
|
||||||
|
parts.extend(nested)
|
||||||
|
else:
|
||||||
|
parts.extend(_addFilePart(entryData, fileName, parentId, containerPath, state))
|
||||||
|
except ImportError:
|
||||||
|
logger.warning("py7zr not installed -- 7z files will be treated as binary")
|
||||||
|
parts.append(ContentPart(
|
||||||
|
id=makeId(), parentId=parentId, label="unsupported",
|
||||||
|
typeGroup="text", mimeType="text/plain",
|
||||||
|
data="7z extraction requires py7zr package", metadata={"warning": True},
|
||||||
|
))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Invalid 7z file: {e}")
|
||||||
|
parts.append(ContentPart(
|
||||||
|
id=makeId(), parentId=parentId, label="error",
|
||||||
|
typeGroup="text", mimeType="text/plain",
|
||||||
|
data=f"Invalid 7z archive: {e}", metadata={"error": True},
|
||||||
|
))
|
||||||
|
return parts
|
||||||
|
|
@ -74,19 +74,33 @@ class DocxExtractor(Extractor):
|
||||||
with io.BytesIO(fileBytes) as buf:
|
with io.BytesIO(fileBytes) as buf:
|
||||||
d = docx.Document(buf)
|
d = docx.Document(buf)
|
||||||
# paragraphs
|
# paragraphs
|
||||||
|
fileName = context.get("fileName", "document.docx")
|
||||||
|
headingIndex = 0
|
||||||
|
currentSection = "body"
|
||||||
for i, para in enumerate(d.paragraphs):
|
for i, para in enumerate(d.paragraphs):
|
||||||
text = para.text or ""
|
text = para.text or ""
|
||||||
if text.strip():
|
if not text.strip():
|
||||||
parts.append(ContentPart(
|
continue
|
||||||
id=makeId(),
|
styleName = (para.style.name or "").lower() if para.style else ""
|
||||||
parentId=rootId,
|
if "heading" in styleName:
|
||||||
label=f"p_{i+1}",
|
headingIndex += 1
|
||||||
typeGroup="text",
|
currentSection = f"heading:{headingIndex}"
|
||||||
mimeType="text/plain",
|
parts.append(ContentPart(
|
||||||
data=text,
|
id=makeId(),
|
||||||
metadata={"size": len(text.encode('utf-8'))}
|
parentId=rootId,
|
||||||
))
|
label=f"p_{i+1}",
|
||||||
# tables → CSV rows
|
typeGroup="text",
|
||||||
|
mimeType="text/plain",
|
||||||
|
data=text,
|
||||||
|
metadata={
|
||||||
|
"size": len(text.encode('utf-8')),
|
||||||
|
"contextRef": {
|
||||||
|
"containerPath": fileName,
|
||||||
|
"location": f"paragraph:{i+1}",
|
||||||
|
"sectionId": currentSection,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
))
|
||||||
for ti, table in enumerate(d.tables):
|
for ti, table in enumerate(d.tables):
|
||||||
rows: list[str] = []
|
rows: list[str] = []
|
||||||
for row in table.rows:
|
for row in table.rows:
|
||||||
|
|
@ -101,7 +115,14 @@ class DocxExtractor(Extractor):
|
||||||
typeGroup="table",
|
typeGroup="table",
|
||||||
mimeType="text/csv",
|
mimeType="text/csv",
|
||||||
data=csvData,
|
data=csvData,
|
||||||
metadata={"size": len(csvData.encode('utf-8'))}
|
metadata={
|
||||||
|
"size": len(csvData.encode('utf-8')),
|
||||||
|
"contextRef": {
|
||||||
|
"containerPath": fileName,
|
||||||
|
"location": f"table:{ti+1}",
|
||||||
|
"sectionId": currentSection,
|
||||||
|
},
|
||||||
|
}
|
||||||
))
|
))
|
||||||
|
|
||||||
return parts
|
return parts
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,230 @@
|
||||||
|
# Copyright (c) 2025 Patrick Motsch
|
||||||
|
# All rights reserved.
|
||||||
|
"""Email extractor for EML and MSG files.
|
||||||
|
|
||||||
|
Parses email headers, body (text/html), and attachments.
|
||||||
|
Attachments are delegated to the ExtractorRegistry for type-specific processing.
|
||||||
|
|
||||||
|
Optional dependency: extract-msg (for .msg files).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
import email
|
||||||
|
import email.policy
|
||||||
|
import email.utils
|
||||||
|
import io
|
||||||
|
import logging
|
||||||
|
import mimetypes
|
||||||
|
|
||||||
|
from modules.datamodels.datamodelExtraction import ContentPart
|
||||||
|
from ..subUtils import makeId
|
||||||
|
from ..subRegistry import Extractor
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_EMAIL_MIME_TYPES = [
|
||||||
|
"message/rfc822",
|
||||||
|
"application/vnd.ms-outlook",
|
||||||
|
]
|
||||||
|
_EMAIL_EXTENSIONS = [".eml", ".msg"]
|
||||||
|
|
||||||
|
|
||||||
|
class EmailExtractor(Extractor):
|
||||||
|
"""Extractor for email files (EML, MSG).
|
||||||
|
|
||||||
|
Produces:
|
||||||
|
- 1 text ContentPart with header metadata (From, To, Subject, Date)
|
||||||
|
- 1 text ContentPart per body part (plain text / HTML)
|
||||||
|
- Delegated ContentParts for each attachment via ExtractorRegistry
|
||||||
|
"""
|
||||||
|
|
||||||
|
def detect(self, fileName: str, mimeType: str, headBytes: bytes) -> bool:
|
||||||
|
if mimeType in _EMAIL_MIME_TYPES:
|
||||||
|
return True
|
||||||
|
lower = (fileName or "").lower()
|
||||||
|
return any(lower.endswith(ext) for ext in _EMAIL_EXTENSIONS)
|
||||||
|
|
||||||
|
def getSupportedExtensions(self) -> list[str]:
|
||||||
|
return list(_EMAIL_EXTENSIONS)
|
||||||
|
|
||||||
|
def getSupportedMimeTypes(self) -> list[str]:
|
||||||
|
return list(_EMAIL_MIME_TYPES)
|
||||||
|
|
||||||
|
def extract(self, fileBytes: bytes, context: Dict[str, Any]) -> List[ContentPart]:
|
||||||
|
fileName = context.get("fileName", "email")
|
||||||
|
lower = (fileName or "").lower()
|
||||||
|
|
||||||
|
if lower.endswith(".msg"):
|
||||||
|
return self._extractMsg(fileBytes, fileName)
|
||||||
|
return self._extractEml(fileBytes, fileName)
|
||||||
|
|
||||||
|
def _extractEml(self, fileBytes: bytes, fileName: str) -> List[ContentPart]:
|
||||||
|
"""Parse standard EML (RFC 822) using stdlib email."""
|
||||||
|
rootId = makeId()
|
||||||
|
parts: List[ContentPart] = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
msg = email.message_from_bytes(fileBytes, policy=email.policy.default)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"EmailExtractor: failed to parse EML: {e}")
|
||||||
|
return [ContentPart(
|
||||||
|
id=rootId, parentId=None, label=fileName,
|
||||||
|
typeGroup="text", mimeType="text/plain",
|
||||||
|
data=f"Failed to parse email: {e}", metadata={"error": True},
|
||||||
|
)]
|
||||||
|
|
||||||
|
headerText = _buildHeaderText(msg)
|
||||||
|
parts.append(ContentPart(
|
||||||
|
id=rootId, parentId=None, label="headers",
|
||||||
|
typeGroup="text", mimeType="text/plain",
|
||||||
|
data=headerText, metadata={"emailPart": "headers"},
|
||||||
|
))
|
||||||
|
|
||||||
|
for part in msg.walk():
|
||||||
|
contentType = part.get_content_type()
|
||||||
|
disposition = str(part.get("Content-Disposition", ""))
|
||||||
|
|
||||||
|
if part.is_multipart():
|
||||||
|
continue
|
||||||
|
|
||||||
|
if "attachment" in disposition:
|
||||||
|
attachName = part.get_filename() or "attachment"
|
||||||
|
attachData = part.get_payload(decode=True)
|
||||||
|
if attachData:
|
||||||
|
parts.extend(_delegateAttachment(attachData, attachName, rootId))
|
||||||
|
continue
|
||||||
|
|
||||||
|
if contentType == "text/plain":
|
||||||
|
body = part.get_content()
|
||||||
|
if body:
|
||||||
|
parts.append(ContentPart(
|
||||||
|
id=makeId(), parentId=rootId, label="body_text",
|
||||||
|
typeGroup="text", mimeType="text/plain",
|
||||||
|
data=str(body), metadata={"emailPart": "body"},
|
||||||
|
))
|
||||||
|
elif contentType == "text/html":
|
||||||
|
body = part.get_content()
|
||||||
|
if body:
|
||||||
|
parts.append(ContentPart(
|
||||||
|
id=makeId(), parentId=rootId, label="body_html",
|
||||||
|
typeGroup="text", mimeType="text/html",
|
||||||
|
data=str(body), metadata={"emailPart": "body_html"},
|
||||||
|
))
|
||||||
|
|
||||||
|
return parts
|
||||||
|
|
||||||
|
def _extractMsg(self, fileBytes: bytes, fileName: str) -> List[ContentPart]:
|
||||||
|
"""Parse Outlook MSG files using extract-msg (optional)."""
|
||||||
|
rootId = makeId()
|
||||||
|
parts: List[ContentPart] = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
import extract_msg
|
||||||
|
except ImportError:
|
||||||
|
logger.warning("extract-msg not installed -- MSG files will be treated as binary")
|
||||||
|
return [ContentPart(
|
||||||
|
id=rootId, parentId=None, label=fileName,
|
||||||
|
typeGroup="text", mimeType="text/plain",
|
||||||
|
data="MSG extraction requires the extract-msg package.",
|
||||||
|
metadata={"warning": True},
|
||||||
|
)]
|
||||||
|
|
||||||
|
try:
|
||||||
|
msgFile = extract_msg.Message(io.BytesIO(fileBytes))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"EmailExtractor: failed to parse MSG: {e}")
|
||||||
|
return [ContentPart(
|
||||||
|
id=rootId, parentId=None, label=fileName,
|
||||||
|
typeGroup="text", mimeType="text/plain",
|
||||||
|
data=f"Failed to parse MSG: {e}", metadata={"error": True},
|
||||||
|
)]
|
||||||
|
|
||||||
|
headerLines = []
|
||||||
|
if msgFile.sender:
|
||||||
|
headerLines.append(f"From: {msgFile.sender}")
|
||||||
|
if msgFile.to:
|
||||||
|
headerLines.append(f"To: {msgFile.to}")
|
||||||
|
if getattr(msgFile, "cc", None):
|
||||||
|
headerLines.append(f"Cc: {msgFile.cc}")
|
||||||
|
if msgFile.subject:
|
||||||
|
headerLines.append(f"Subject: {msgFile.subject}")
|
||||||
|
if msgFile.date:
|
||||||
|
headerLines.append(f"Date: {msgFile.date}")
|
||||||
|
|
||||||
|
parts.append(ContentPart(
|
||||||
|
id=rootId, parentId=None, label="headers",
|
||||||
|
typeGroup="text", mimeType="text/plain",
|
||||||
|
data="\n".join(headerLines), metadata={"emailPart": "headers"},
|
||||||
|
))
|
||||||
|
|
||||||
|
body = msgFile.body
|
||||||
|
if body:
|
||||||
|
parts.append(ContentPart(
|
||||||
|
id=makeId(), parentId=rootId, label="body_text",
|
||||||
|
typeGroup="text", mimeType="text/plain",
|
||||||
|
data=body, metadata={"emailPart": "body"},
|
||||||
|
))
|
||||||
|
|
||||||
|
htmlBody = getattr(msgFile, "htmlBody", None)
|
||||||
|
if htmlBody:
|
||||||
|
if isinstance(htmlBody, bytes):
|
||||||
|
htmlBody = htmlBody.decode("utf-8", errors="replace")
|
||||||
|
parts.append(ContentPart(
|
||||||
|
id=makeId(), parentId=rootId, label="body_html",
|
||||||
|
typeGroup="text", mimeType="text/html",
|
||||||
|
data=htmlBody, metadata={"emailPart": "body_html"},
|
||||||
|
))
|
||||||
|
|
||||||
|
for attachment in (msgFile.attachments or []):
|
||||||
|
attachName = getattr(attachment, "longFilename", None) or getattr(attachment, "shortFilename", None) or "attachment"
|
||||||
|
attachData = getattr(attachment, "data", None)
|
||||||
|
if attachData:
|
||||||
|
parts.extend(_delegateAttachment(attachData, attachName, rootId))
|
||||||
|
|
||||||
|
try:
|
||||||
|
msgFile.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return parts
|
||||||
|
|
||||||
|
|
||||||
|
def _buildHeaderText(msg) -> str:
|
||||||
|
"""Build a readable text summary of key email headers."""
|
||||||
|
lines = []
|
||||||
|
for header in ("From", "To", "Cc", "Subject", "Date", "Message-ID"):
|
||||||
|
value = msg.get(header)
|
||||||
|
if value:
|
||||||
|
lines.append(f"{header}: {value}")
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
def _delegateAttachment(attachData: bytes, attachName: str, parentId: str) -> List[ContentPart]:
|
||||||
|
"""Delegate an attachment to the appropriate type-specific extractor."""
|
||||||
|
guessedMime, _ = mimetypes.guess_type(attachName)
|
||||||
|
detectedMime = guessedMime or "application/octet-stream"
|
||||||
|
|
||||||
|
from ..subRegistry import ExtractorRegistry
|
||||||
|
registry = ExtractorRegistry()
|
||||||
|
extractor = registry.resolve(detectedMime, attachName)
|
||||||
|
|
||||||
|
if extractor and not isinstance(extractor, EmailExtractor):
|
||||||
|
try:
|
||||||
|
childParts = extractor.extract(attachData, {"fileName": attachName, "mimeType": detectedMime})
|
||||||
|
for part in childParts:
|
||||||
|
part.parentId = parentId
|
||||||
|
if not part.metadata:
|
||||||
|
part.metadata = {}
|
||||||
|
part.metadata["emailAttachment"] = attachName
|
||||||
|
return childParts
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Extractor failed for email attachment {attachName}: {e}")
|
||||||
|
|
||||||
|
import base64
|
||||||
|
encodedData = base64.b64encode(attachData).decode("utf-8") if attachData else ""
|
||||||
|
return [ContentPart(
|
||||||
|
id=makeId(), parentId=parentId, label=attachName,
|
||||||
|
typeGroup="binary", mimeType=detectedMime,
|
||||||
|
data=encodedData,
|
||||||
|
metadata={"size": len(attachData), "emailAttachment": attachName},
|
||||||
|
)]
|
||||||
|
|
@ -0,0 +1,184 @@
|
||||||
|
# Copyright (c) 2025 Patrick Motsch
|
||||||
|
# All rights reserved.
|
||||||
|
"""Folder extractor -- treats a local folder reference as a container.
|
||||||
|
|
||||||
|
Not registered in the MIME-based ExtractorRegistry (folders have no MIME type).
|
||||||
|
Instead, called directly by agent tools (browseContainer) when handling folder references.
|
||||||
|
|
||||||
|
Applies the same safety limits as ContainerExtractor.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
import logging
|
||||||
|
import mimetypes
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from ..subUtils import makeId
|
||||||
|
from modules.datamodels.datamodelExtraction import ContentPart
|
||||||
|
from modules.datamodels.datamodelContent import ContainerLimitError, ContentContextRef
|
||||||
|
from ..subRegistry import Extractor
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
MAX_TOTAL_EXTRACTED_SIZE = 500 * 1024 * 1024
|
||||||
|
MAX_FILE_COUNT = 10000
|
||||||
|
MAX_DEPTH = 5
|
||||||
|
|
||||||
|
|
||||||
|
class FolderExtractor(Extractor):
|
||||||
|
"""Extracts contents from a local folder path.
|
||||||
|
|
||||||
|
Unlike other extractors, this does not receive fileBytes. Instead it
|
||||||
|
receives a folder path via context["folderPath"] and walks the directory.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def detect(self, fileName: str, mimeType: str, headBytes: bytes) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def getSupportedExtensions(self) -> list[str]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
def getSupportedMimeTypes(self) -> list[str]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
def extract(self, fileBytes: bytes, context: Dict[str, Any]) -> List[ContentPart]:
|
||||||
|
"""Extract folder contents.
|
||||||
|
|
||||||
|
context must contain:
|
||||||
|
folderPath: str -- absolute path to the folder
|
||||||
|
"""
|
||||||
|
folderPath = context.get("folderPath", "")
|
||||||
|
if not folderPath:
|
||||||
|
return []
|
||||||
|
|
||||||
|
folder = Path(folderPath)
|
||||||
|
if not folder.is_dir():
|
||||||
|
logger.error(f"FolderExtractor: not a directory: {folderPath}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
rootId = makeId()
|
||||||
|
parts: List[ContentPart] = [
|
||||||
|
ContentPart(
|
||||||
|
id=rootId,
|
||||||
|
parentId=None,
|
||||||
|
label=folder.name or "folder",
|
||||||
|
typeGroup="container",
|
||||||
|
mimeType="inode/directory",
|
||||||
|
data="",
|
||||||
|
metadata={"folderPath": str(folder), "containerType": "folder"},
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
state = {"totalSize": 0, "fileCount": 0}
|
||||||
|
try:
|
||||||
|
_walkFolder(folder, rootId, "", 0, state, parts)
|
||||||
|
except ContainerLimitError as e:
|
||||||
|
logger.warning(f"Folder extraction limit reached: {e}")
|
||||||
|
parts.append(ContentPart(
|
||||||
|
id=makeId(),
|
||||||
|
parentId=rootId,
|
||||||
|
label="limit_exceeded",
|
||||||
|
typeGroup="text",
|
||||||
|
mimeType="text/plain",
|
||||||
|
data=str(e),
|
||||||
|
metadata={"warning": "Folder extraction limit exceeded"},
|
||||||
|
))
|
||||||
|
|
||||||
|
return parts
|
||||||
|
|
||||||
|
|
||||||
|
def _walkFolder(
|
||||||
|
folder: Path,
|
||||||
|
parentId: str,
|
||||||
|
containerPath: str,
|
||||||
|
depth: int,
|
||||||
|
state: Dict[str, int],
|
||||||
|
parts: List[ContentPart],
|
||||||
|
) -> None:
|
||||||
|
if depth > MAX_DEPTH:
|
||||||
|
raise ContainerLimitError(f"Max folder depth {MAX_DEPTH} exceeded")
|
||||||
|
|
||||||
|
try:
|
||||||
|
entries = sorted(folder.iterdir(), key=lambda p: (not p.is_dir(), p.name.lower()))
|
||||||
|
except PermissionError:
|
||||||
|
logger.warning(f"Permission denied: {folder}")
|
||||||
|
return
|
||||||
|
|
||||||
|
for entry in entries:
|
||||||
|
if entry.is_symlink():
|
||||||
|
logger.debug(f"Skipping symlink: {entry}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
entryPath = f"{containerPath}/{entry.name}" if containerPath else entry.name
|
||||||
|
|
||||||
|
if entry.is_dir():
|
||||||
|
folderId = makeId()
|
||||||
|
parts.append(ContentPart(
|
||||||
|
id=folderId,
|
||||||
|
parentId=parentId,
|
||||||
|
label=entry.name,
|
||||||
|
typeGroup="container",
|
||||||
|
mimeType="inode/directory",
|
||||||
|
data="",
|
||||||
|
metadata={"containerPath": entryPath, "containerType": "folder"},
|
||||||
|
))
|
||||||
|
_walkFolder(entry, folderId, entryPath, depth + 1, state, parts)
|
||||||
|
|
||||||
|
elif entry.is_file():
|
||||||
|
try:
|
||||||
|
fileSize = entry.stat().st_size
|
||||||
|
except OSError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
state["totalSize"] += fileSize
|
||||||
|
state["fileCount"] += 1
|
||||||
|
|
||||||
|
if state["totalSize"] > MAX_TOTAL_EXTRACTED_SIZE:
|
||||||
|
raise ContainerLimitError(f"Total extracted size exceeds {MAX_TOTAL_EXTRACTED_SIZE // (1024 * 1024)} MB")
|
||||||
|
if state["fileCount"] > MAX_FILE_COUNT:
|
||||||
|
raise ContainerLimitError(f"File count exceeds {MAX_FILE_COUNT}")
|
||||||
|
|
||||||
|
guessedMime, _ = mimetypes.guess_type(entry.name)
|
||||||
|
detectedMime = guessedMime or "application/octet-stream"
|
||||||
|
|
||||||
|
from ..subRegistry import ExtractorRegistry
|
||||||
|
registry = ExtractorRegistry()
|
||||||
|
extractor = registry.resolve(detectedMime, entry.name)
|
||||||
|
|
||||||
|
if extractor and not isinstance(extractor, FolderExtractor):
|
||||||
|
try:
|
||||||
|
fileData = entry.read_bytes()
|
||||||
|
childParts = extractor.extract(fileData, {"fileName": entry.name, "mimeType": detectedMime})
|
||||||
|
for part in childParts:
|
||||||
|
part.parentId = parentId
|
||||||
|
if not part.metadata:
|
||||||
|
part.metadata = {}
|
||||||
|
part.metadata["containerPath"] = entryPath
|
||||||
|
parts.extend(childParts)
|
||||||
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Type-extractor failed for {entry.name}: {e}")
|
||||||
|
|
||||||
|
import base64
|
||||||
|
try:
|
||||||
|
fileData = entry.read_bytes()
|
||||||
|
encodedData = base64.b64encode(fileData).decode("utf-8")
|
||||||
|
except Exception:
|
||||||
|
encodedData = ""
|
||||||
|
|
||||||
|
parts.append(ContentPart(
|
||||||
|
id=makeId(),
|
||||||
|
parentId=parentId,
|
||||||
|
label=entry.name,
|
||||||
|
typeGroup="binary",
|
||||||
|
mimeType=detectedMime,
|
||||||
|
data=encodedData,
|
||||||
|
metadata={
|
||||||
|
"size": fileSize,
|
||||||
|
"containerPath": entryPath,
|
||||||
|
"contextRef": ContentContextRef(
|
||||||
|
containerPath=entryPath,
|
||||||
|
location="file",
|
||||||
|
).model_dump(),
|
||||||
|
},
|
||||||
|
))
|
||||||
|
|
@ -89,7 +89,15 @@ class PdfExtractor(Extractor):
|
||||||
typeGroup="text",
|
typeGroup="text",
|
||||||
mimeType="text/plain",
|
mimeType="text/plain",
|
||||||
data=text,
|
data=text,
|
||||||
metadata={"pages": 1, "pageIndex": i, "size": len(text.encode('utf-8'))}
|
metadata={
|
||||||
|
"pages": 1, "pageIndex": i,
|
||||||
|
"size": len(text.encode('utf-8')),
|
||||||
|
"contextRef": {
|
||||||
|
"containerPath": context.get("fileName", "document.pdf"),
|
||||||
|
"location": f"page:{i+1}",
|
||||||
|
"pageIndex": i,
|
||||||
|
},
|
||||||
|
}
|
||||||
))
|
))
|
||||||
except Exception:
|
except Exception:
|
||||||
continue
|
continue
|
||||||
|
|
@ -114,7 +122,15 @@ class PdfExtractor(Extractor):
|
||||||
typeGroup="text",
|
typeGroup="text",
|
||||||
mimeType="text/plain",
|
mimeType="text/plain",
|
||||||
data=text,
|
data=text,
|
||||||
metadata={"pages": 1, "pageIndex": i, "size": len(text.encode('utf-8'))}
|
metadata={
|
||||||
|
"pages": 1, "pageIndex": i,
|
||||||
|
"size": len(text.encode('utf-8')),
|
||||||
|
"contextRef": {
|
||||||
|
"containerPath": context.get("fileName", "document.pdf"),
|
||||||
|
"location": f"page:{i+1}",
|
||||||
|
"pageIndex": i,
|
||||||
|
},
|
||||||
|
}
|
||||||
))
|
))
|
||||||
except Exception:
|
except Exception:
|
||||||
continue
|
continue
|
||||||
|
|
@ -143,7 +159,14 @@ class PdfExtractor(Extractor):
|
||||||
typeGroup="image",
|
typeGroup="image",
|
||||||
mimeType=f"image/{ext}",
|
mimeType=f"image/{ext}",
|
||||||
data=base64.b64encode(imgBytes).decode("utf-8"),
|
data=base64.b64encode(imgBytes).decode("utf-8"),
|
||||||
metadata={"pageIndex": i, "size": len(imgBytes)}
|
metadata={
|
||||||
|
"pageIndex": i, "size": len(imgBytes),
|
||||||
|
"contextRef": {
|
||||||
|
"containerPath": context.get("fileName", "document.pdf"),
|
||||||
|
"location": f"page:{i+1}/image:{j}",
|
||||||
|
"pageIndex": i,
|
||||||
|
},
|
||||||
|
}
|
||||||
))
|
))
|
||||||
except Exception:
|
except Exception:
|
||||||
continue
|
continue
|
||||||
|
|
|
||||||
|
|
@ -119,17 +119,22 @@ class PptxExtractor(Extractor):
|
||||||
image_bytes = image.blob
|
image_bytes = image.blob
|
||||||
image_b64 = base64.b64encode(image_bytes).decode('utf-8')
|
image_b64 = base64.b64encode(image_bytes).decode('utf-8')
|
||||||
|
|
||||||
# Create image part
|
fileName = context.get("fileName", "presentation.pptx")
|
||||||
image_part = ContentPart(
|
image_part = ContentPart(
|
||||||
id=f"slide_{slide_index}_image_{len(parts)}",
|
id=f"slide_{slide_index}_image_{len(parts)}",
|
||||||
label=f"Slide {slide_index} Image",
|
label=f"Slide {slide_index} Image",
|
||||||
typeGroup="image",
|
typeGroup="image",
|
||||||
mimeType="image/png", # Default to PNG
|
mimeType="image/png",
|
||||||
data=image_b64,
|
data=image_b64,
|
||||||
metadata={
|
metadata={
|
||||||
"slide_number": slide_index,
|
"slide_number": slide_index,
|
||||||
"shape_type": "image",
|
"shape_type": "image",
|
||||||
"extracted_from": "powerpoint"
|
"extracted_from": "powerpoint",
|
||||||
|
"contextRef": {
|
||||||
|
"containerPath": fileName,
|
||||||
|
"location": f"slide:{slide_index}/image",
|
||||||
|
"slideIndex": slide_index - 1,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
parts.append(image_part)
|
parts.append(image_part)
|
||||||
|
|
@ -140,6 +145,7 @@ class PptxExtractor(Extractor):
|
||||||
if slide_content:
|
if slide_content:
|
||||||
slide_text = f"# Slide {slide_index}\n\n" + "\n\n".join(slide_content)
|
slide_text = f"# Slide {slide_index}\n\n" + "\n\n".join(slide_content)
|
||||||
|
|
||||||
|
fileName = context.get("fileName", "presentation.pptx")
|
||||||
slide_part = ContentPart(
|
slide_part = ContentPart(
|
||||||
id=f"slide_{slide_index}",
|
id=f"slide_{slide_index}",
|
||||||
label=f"Slide {slide_index} Content",
|
label=f"Slide {slide_index} Content",
|
||||||
|
|
@ -150,7 +156,12 @@ class PptxExtractor(Extractor):
|
||||||
"slide_number": slide_index,
|
"slide_number": slide_index,
|
||||||
"content_type": "slide",
|
"content_type": "slide",
|
||||||
"extracted_from": "powerpoint",
|
"extracted_from": "powerpoint",
|
||||||
"text_length": len(slide_text)
|
"text_length": len(slide_text),
|
||||||
|
"contextRef": {
|
||||||
|
"containerPath": fileName,
|
||||||
|
"location": f"slide:{slide_index}",
|
||||||
|
"slideIndex": slide_index - 1,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
parts.append(slide_part)
|
parts.append(slide_part)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,208 @@
|
||||||
|
# Copyright (c) 2025 Patrick Motsch
|
||||||
|
# All rights reserved.
|
||||||
|
"""Video extractor for common video formats.
|
||||||
|
|
||||||
|
Extracts metadata (duration, resolution, codec, bitrate) and produces
|
||||||
|
a `videostream` ContentPart. Video data is never base64-encoded due to size.
|
||||||
|
|
||||||
|
Optional dependency: mutagen (for rich metadata from MP4/WebM containers).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
import logging
|
||||||
|
import struct
|
||||||
|
|
||||||
|
from modules.datamodels.datamodelExtraction import ContentPart
|
||||||
|
from ..subUtils import makeId
|
||||||
|
from ..subRegistry import Extractor
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_VIDEO_MIME_TYPES = [
|
||||||
|
"video/mp4",
|
||||||
|
"video/webm",
|
||||||
|
"video/x-msvideo",
|
||||||
|
"video/avi",
|
||||||
|
"video/quicktime",
|
||||||
|
"video/x-matroska",
|
||||||
|
"video/x-ms-wmv",
|
||||||
|
"video/mpeg",
|
||||||
|
"video/ogg",
|
||||||
|
]
|
||||||
|
_VIDEO_EXTENSIONS = [".mp4", ".webm", ".avi", ".mov", ".mkv", ".wmv", ".mpeg", ".mpg", ".ogv"]
|
||||||
|
|
||||||
|
|
||||||
|
class VideoExtractor(Extractor):
|
||||||
|
"""Extractor for video files.
|
||||||
|
|
||||||
|
Produces:
|
||||||
|
- 1 text ContentPart with metadata summary
|
||||||
|
- 1 videostream ContentPart (no inline data -- too large)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def detect(self, fileName: str, mimeType: str, headBytes: bytes) -> bool:
|
||||||
|
if mimeType in _VIDEO_MIME_TYPES:
|
||||||
|
return True
|
||||||
|
lower = (fileName or "").lower()
|
||||||
|
return any(lower.endswith(ext) for ext in _VIDEO_EXTENSIONS)
|
||||||
|
|
||||||
|
def getSupportedExtensions(self) -> list[str]:
|
||||||
|
return list(_VIDEO_EXTENSIONS)
|
||||||
|
|
||||||
|
def getSupportedMimeTypes(self) -> list[str]:
|
||||||
|
return list(_VIDEO_MIME_TYPES)
|
||||||
|
|
||||||
|
def extract(self, fileBytes: bytes, context: Dict[str, Any]) -> List[ContentPart]:
|
||||||
|
fileName = context.get("fileName", "video")
|
||||||
|
mimeType = context.get("mimeType") or "video/mp4"
|
||||||
|
fileSize = len(fileBytes)
|
||||||
|
|
||||||
|
rootId = makeId()
|
||||||
|
parts: List[ContentPart] = []
|
||||||
|
|
||||||
|
meta = _extractMetadata(fileBytes, fileName)
|
||||||
|
meta["size"] = fileSize
|
||||||
|
meta["fileName"] = fileName
|
||||||
|
meta["mimeType"] = mimeType
|
||||||
|
|
||||||
|
metaLines = [f"Video file: {fileName}"]
|
||||||
|
if meta.get("duration"):
|
||||||
|
mins = int(meta["duration"] // 60)
|
||||||
|
secs = int(meta["duration"] % 60)
|
||||||
|
metaLines.append(f"Duration: {mins}:{secs:02d}")
|
||||||
|
if meta.get("width") and meta.get("height"):
|
||||||
|
metaLines.append(f"Resolution: {meta['width']}x{meta['height']}")
|
||||||
|
if meta.get("codec"):
|
||||||
|
metaLines.append(f"Codec: {meta['codec']}")
|
||||||
|
if meta.get("bitrate"):
|
||||||
|
metaLines.append(f"Bitrate: {meta['bitrate']} kbps")
|
||||||
|
if meta.get("fps"):
|
||||||
|
metaLines.append(f"FPS: {meta['fps']}")
|
||||||
|
metaLines.append(f"Size: {fileSize:,} bytes")
|
||||||
|
|
||||||
|
parts.append(ContentPart(
|
||||||
|
id=rootId, parentId=None, label="metadata",
|
||||||
|
typeGroup="text", mimeType="text/plain",
|
||||||
|
data="\n".join(metaLines), metadata=meta,
|
||||||
|
))
|
||||||
|
|
||||||
|
parts.append(ContentPart(
|
||||||
|
id=makeId(), parentId=rootId, label="videostream",
|
||||||
|
typeGroup="videostream", mimeType=mimeType,
|
||||||
|
data="", metadata={"size": fileSize, "inlined": False},
|
||||||
|
))
|
||||||
|
|
||||||
|
return parts
|
||||||
|
|
||||||
|
|
||||||
|
def _extractMetadata(fileBytes: bytes, fileName: str) -> Dict[str, Any]:
|
||||||
|
"""Extract video metadata using mutagen (optional) with basic fallback."""
|
||||||
|
meta: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
import mutagen
|
||||||
|
import io
|
||||||
|
mediaFile = mutagen.File(io.BytesIO(fileBytes))
|
||||||
|
if mediaFile is not None and mediaFile.info:
|
||||||
|
meta["duration"] = getattr(mediaFile.info, "length", None)
|
||||||
|
meta["bitrate"] = getattr(mediaFile.info, "bitrate", None)
|
||||||
|
if meta["bitrate"]:
|
||||||
|
meta["bitrate"] = meta["bitrate"] // 1000
|
||||||
|
|
||||||
|
if hasattr(mediaFile.info, "video"):
|
||||||
|
for stream in (mediaFile.info.video if isinstance(mediaFile.info.video, list) else [mediaFile.info.video]):
|
||||||
|
if hasattr(stream, "width"):
|
||||||
|
meta["width"] = stream.width
|
||||||
|
if hasattr(stream, "height"):
|
||||||
|
meta["height"] = stream.height
|
||||||
|
if hasattr(stream, "codec"):
|
||||||
|
meta["codec"] = stream.codec
|
||||||
|
|
||||||
|
width = getattr(mediaFile.info, "width", None)
|
||||||
|
height = getattr(mediaFile.info, "height", None)
|
||||||
|
if width and height:
|
||||||
|
meta["width"] = width
|
||||||
|
meta["height"] = height
|
||||||
|
|
||||||
|
fps = getattr(mediaFile.info, "fps", None)
|
||||||
|
if fps:
|
||||||
|
meta["fps"] = round(fps, 2)
|
||||||
|
|
||||||
|
codec = getattr(mediaFile.info, "codec", None)
|
||||||
|
if codec:
|
||||||
|
meta["codec"] = codec
|
||||||
|
|
||||||
|
return {k: v for k, v in meta.items() if v is not None}
|
||||||
|
except ImportError:
|
||||||
|
logger.debug("mutagen not installed -- using basic video metadata extraction")
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"mutagen video metadata extraction failed: {e}")
|
||||||
|
|
||||||
|
lower = fileName.lower()
|
||||||
|
if lower.endswith(".mp4"):
|
||||||
|
meta.update(_parseMp4Header(fileBytes))
|
||||||
|
elif lower.endswith(".avi"):
|
||||||
|
meta.update(_parseAviHeader(fileBytes))
|
||||||
|
|
||||||
|
return {k: v for k, v in meta.items() if v is not None}
|
||||||
|
|
||||||
|
|
||||||
|
def _parseMp4Header(fileBytes: bytes) -> Dict[str, Any]:
|
||||||
|
"""Minimal MP4 moov/mvhd parser for duration and timescale."""
|
||||||
|
meta: Dict[str, Any] = {}
|
||||||
|
try:
|
||||||
|
pos = 0
|
||||||
|
while pos < len(fileBytes) - 8:
|
||||||
|
boxSize = struct.unpack_from(">I", fileBytes, pos)[0]
|
||||||
|
boxType = fileBytes[pos + 4:pos + 8]
|
||||||
|
if boxSize < 8:
|
||||||
|
break
|
||||||
|
if boxType == b"moov":
|
||||||
|
meta.update(_parseMoovBox(fileBytes[pos + 8:pos + boxSize]))
|
||||||
|
break
|
||||||
|
pos += boxSize
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return meta
|
||||||
|
|
||||||
|
|
||||||
|
def _parseMoovBox(data: bytes) -> Dict[str, Any]:
|
||||||
|
"""Parse moov box to find mvhd with duration."""
|
||||||
|
meta: Dict[str, Any] = {}
|
||||||
|
pos = 0
|
||||||
|
while pos < len(data) - 8:
|
||||||
|
try:
|
||||||
|
boxSize = struct.unpack_from(">I", data, pos)[0]
|
||||||
|
boxType = data[pos + 4:pos + 8]
|
||||||
|
if boxSize < 8:
|
||||||
|
break
|
||||||
|
if boxType == b"mvhd":
|
||||||
|
version = data[pos + 8]
|
||||||
|
if version == 0 and pos + 28 < len(data):
|
||||||
|
timeScale = struct.unpack_from(">I", data, pos + 20)[0]
|
||||||
|
duration = struct.unpack_from(">I", data, pos + 24)[0]
|
||||||
|
if timeScale > 0:
|
||||||
|
meta["duration"] = duration / timeScale
|
||||||
|
break
|
||||||
|
pos += boxSize
|
||||||
|
except Exception:
|
||||||
|
break
|
||||||
|
return meta
|
||||||
|
|
||||||
|
|
||||||
|
def _parseAviHeader(fileBytes: bytes) -> Dict[str, Any]:
|
||||||
|
"""Minimal AVI header parser for resolution."""
|
||||||
|
meta: Dict[str, Any] = {}
|
||||||
|
if len(fileBytes) < 72:
|
||||||
|
return meta
|
||||||
|
try:
|
||||||
|
if fileBytes[:4] != b"RIFF" or fileBytes[8:12] != b"AVI ":
|
||||||
|
return meta
|
||||||
|
width = struct.unpack_from("<I", fileBytes, 64)[0]
|
||||||
|
height = struct.unpack_from("<I", fileBytes, 68)[0]
|
||||||
|
if 0 < width < 100000 and 0 < height < 100000:
|
||||||
|
meta["width"] = width
|
||||||
|
meta["height"] = height
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return meta
|
||||||
|
|
@ -99,6 +99,7 @@ class XlsxExtractor(Extractor):
|
||||||
cells.append(f'"{escaped_value}"')
|
cells.append(f'"{escaped_value}"')
|
||||||
lines.append(",".join(cells))
|
lines.append(",".join(cells))
|
||||||
csvData = "\n".join(lines)
|
csvData = "\n".join(lines)
|
||||||
|
fileName = context.get("fileName", "spreadsheet.xlsx")
|
||||||
parts.append(ContentPart(
|
parts.append(ContentPart(
|
||||||
id=makeId(),
|
id=makeId(),
|
||||||
parentId=rootId,
|
parentId=rootId,
|
||||||
|
|
@ -106,7 +107,15 @@ class XlsxExtractor(Extractor):
|
||||||
typeGroup="table",
|
typeGroup="table",
|
||||||
mimeType="text/csv",
|
mimeType="text/csv",
|
||||||
data=csvData,
|
data=csvData,
|
||||||
metadata={"sheet": sheetName, "size": len(csvData.encode('utf-8'))}
|
metadata={
|
||||||
|
"sheet": sheetName,
|
||||||
|
"size": len(csvData.encode('utf-8')),
|
||||||
|
"contextRef": {
|
||||||
|
"containerPath": fileName,
|
||||||
|
"location": f"sheet:{sheetName}",
|
||||||
|
"sheetName": sheetName,
|
||||||
|
},
|
||||||
|
}
|
||||||
))
|
))
|
||||||
|
|
||||||
return parts
|
return parts
|
||||||
|
|
|
||||||
|
|
@ -191,9 +191,11 @@ class ChunkerRegistry:
|
||||||
self.register("table", TableChunker())
|
self.register("table", TableChunker())
|
||||||
self.register("structure", StructureChunker())
|
self.register("structure", StructureChunker())
|
||||||
self.register("image", ImageChunker())
|
self.register("image", ImageChunker())
|
||||||
# Use text chunker for container and binary content
|
# Use text chunker for container, binary, and media stream content
|
||||||
self.register("container", TextChunker())
|
self.register("container", TextChunker())
|
||||||
self.register("binary", TextChunker())
|
self.register("binary", TextChunker())
|
||||||
|
self.register("audiostream", TextChunker())
|
||||||
|
self.register("videostream", TextChunker())
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"ChunkerRegistry: Failed to register chunkers: {str(e)}")
|
logger.error(f"ChunkerRegistry: Failed to register chunkers: {str(e)}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,3 @@
|
||||||
|
# Copyright (c) 2025 Patrick Motsch
|
||||||
|
# All rights reserved.
|
||||||
|
"""serviceKnowledge: 3-tier RAG Knowledge Store with semantic search."""
|
||||||
|
|
@ -0,0 +1,531 @@
|
||||||
|
# Copyright (c) 2025 Patrick Motsch
|
||||||
|
# All rights reserved.
|
||||||
|
"""Knowledge service: 3-tier RAG with indexing, semantic search, and context building."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
|
|
||||||
|
from modules.datamodels.datamodelKnowledge import (
|
||||||
|
FileContentIndex, ContentChunk, WorkflowMemory,
|
||||||
|
)
|
||||||
|
from modules.datamodels.datamodelAi import AiCallOptions, OperationTypeEnum
|
||||||
|
from modules.interfaces.interfaceDbKnowledge import getInterface as getKnowledgeInterface
|
||||||
|
from modules.shared.timeUtils import getUtcTimestamp
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
DEFAULT_CHUNK_SIZE = 512
|
||||||
|
DEFAULT_CONTEXT_BUDGET = 8000
|
||||||
|
|
||||||
|
|
||||||
|
class KnowledgeService:
|
||||||
|
"""Service for Knowledge Store operations: indexing, retrieval, and context building."""
|
||||||
|
|
||||||
|
def __init__(self, context, get_service: Callable[[str], Any]):
|
||||||
|
self._context = context
|
||||||
|
self._getService = get_service
|
||||||
|
self._knowledgeDb = getKnowledgeInterface(context.user)
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Embedding helper
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
async def _embed(self, texts: List[str]) -> List[List[float]]:
|
||||||
|
"""Embed texts via the AI interface's generic embedding method."""
|
||||||
|
aiService = self._getService("ai")
|
||||||
|
await aiService.ensureAiObjectsInitialized()
|
||||||
|
aiObjects = aiService.aiObjects
|
||||||
|
if aiObjects is None:
|
||||||
|
logger.warning("Embedding skipped: aiObjects not available")
|
||||||
|
return []
|
||||||
|
response = await aiObjects.callEmbedding(texts)
|
||||||
|
if response.errorCount > 0:
|
||||||
|
logger.error(f"Embedding failed: {response.content}")
|
||||||
|
return []
|
||||||
|
return (response.metadata or {}).get("embeddings", [])
|
||||||
|
|
||||||
|
async def _embedSingle(self, text: str) -> List[float]:
|
||||||
|
"""Embed a single text. Returns empty list on failure."""
|
||||||
|
results = await self._embed([text])
|
||||||
|
return results[0] if results else []
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# File Indexing (called after extraction, before embedding)
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
async def indexFile(
|
||||||
|
self,
|
||||||
|
fileId: str,
|
||||||
|
fileName: str,
|
||||||
|
mimeType: str,
|
||||||
|
userId: str,
|
||||||
|
featureInstanceId: str = "",
|
||||||
|
mandateId: str = "",
|
||||||
|
contentObjects: List[Dict[str, Any]] = None,
|
||||||
|
structure: Dict[str, Any] = None,
|
||||||
|
containerPath: str = None,
|
||||||
|
) -> FileContentIndex:
|
||||||
|
"""Index a file's content objects and create embeddings for text chunks.
|
||||||
|
|
||||||
|
This is the main entry point after non-AI extraction has produced content objects.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fileId: The file ID.
|
||||||
|
fileName: Original file name.
|
||||||
|
mimeType: MIME type.
|
||||||
|
userId: Owner user.
|
||||||
|
featureInstanceId: Feature instance scope.
|
||||||
|
mandateId: Mandate scope.
|
||||||
|
contentObjects: List of extracted content objects, each with keys:
|
||||||
|
contentType (str), data (str), contextRef (dict), contentObjectId (str).
|
||||||
|
structure: Structural overview of the file.
|
||||||
|
containerPath: Path within container if applicable.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The created FileContentIndex.
|
||||||
|
"""
|
||||||
|
contentObjects = contentObjects or []
|
||||||
|
|
||||||
|
# 1. Create FileContentIndex
|
||||||
|
index = FileContentIndex(
|
||||||
|
id=fileId,
|
||||||
|
userId=userId,
|
||||||
|
featureInstanceId=featureInstanceId,
|
||||||
|
mandateId=mandateId,
|
||||||
|
fileName=fileName,
|
||||||
|
mimeType=mimeType,
|
||||||
|
containerPath=containerPath,
|
||||||
|
totalObjects=len(contentObjects),
|
||||||
|
totalSize=sum(len(obj.get("data", "").encode("utf-8")) for obj in contentObjects),
|
||||||
|
structure=structure or {},
|
||||||
|
objectSummary=[
|
||||||
|
{
|
||||||
|
"id": obj.get("contentObjectId", ""),
|
||||||
|
"type": obj.get("contentType", "other"),
|
||||||
|
"size": len(obj.get("data", "").encode("utf-8")),
|
||||||
|
"ref": obj.get("contextRef", {}),
|
||||||
|
}
|
||||||
|
for obj in contentObjects
|
||||||
|
],
|
||||||
|
status="extracted",
|
||||||
|
)
|
||||||
|
self._knowledgeDb.upsertFileContentIndex(index)
|
||||||
|
|
||||||
|
# 2. Chunk text content objects and create embeddings
|
||||||
|
textObjects = [o for o in contentObjects if o.get("contentType") == "text"]
|
||||||
|
if textObjects:
|
||||||
|
self._knowledgeDb.updateFileStatus(fileId, "embedding")
|
||||||
|
chunks = _chunkForEmbedding(textObjects, chunkSize=DEFAULT_CHUNK_SIZE)
|
||||||
|
texts = [c["data"] for c in chunks]
|
||||||
|
|
||||||
|
embeddings = await self._embed(texts) if texts else []
|
||||||
|
|
||||||
|
for i, chunk in enumerate(chunks):
|
||||||
|
embedding = embeddings[i] if i < len(embeddings) else None
|
||||||
|
contentChunk = ContentChunk(
|
||||||
|
contentObjectId=chunk["contentObjectId"],
|
||||||
|
fileId=fileId,
|
||||||
|
userId=userId,
|
||||||
|
featureInstanceId=featureInstanceId,
|
||||||
|
contentType="text",
|
||||||
|
data=chunk["data"],
|
||||||
|
contextRef=chunk["contextRef"],
|
||||||
|
embedding=embedding,
|
||||||
|
)
|
||||||
|
self._knowledgeDb.upsertContentChunk(contentChunk)
|
||||||
|
|
||||||
|
# 3. Store non-text content objects (images, etc.) without embedding
|
||||||
|
nonTextObjects = [o for o in contentObjects if o.get("contentType") != "text"]
|
||||||
|
for obj in nonTextObjects:
|
||||||
|
contentChunk = ContentChunk(
|
||||||
|
contentObjectId=obj.get("contentObjectId", ""),
|
||||||
|
fileId=fileId,
|
||||||
|
userId=userId,
|
||||||
|
featureInstanceId=featureInstanceId,
|
||||||
|
contentType=obj.get("contentType", "other"),
|
||||||
|
data=obj.get("data", ""),
|
||||||
|
contextRef=obj.get("contextRef", {}),
|
||||||
|
embedding=None,
|
||||||
|
)
|
||||||
|
self._knowledgeDb.upsertContentChunk(contentChunk)
|
||||||
|
|
||||||
|
self._knowledgeDb.updateFileStatus(fileId, "indexed")
|
||||||
|
index.status = "indexed"
|
||||||
|
logger.info(f"Indexed file {fileId} ({fileName}): {len(contentObjects)} objects, {len(textObjects)} text chunks")
|
||||||
|
return index
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# RAG Context Building (3-tier search)
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
async def buildAgentContext(
|
||||||
|
self,
|
||||||
|
currentPrompt: str,
|
||||||
|
workflowId: str,
|
||||||
|
userId: str,
|
||||||
|
featureInstanceId: str = "",
|
||||||
|
mandateId: str = "",
|
||||||
|
contextBudget: int = DEFAULT_CONTEXT_BUDGET,
|
||||||
|
) -> str:
|
||||||
|
"""Build RAG context for an agent round by searching all 3 layers.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
currentPrompt: The current user prompt to find relevant context for.
|
||||||
|
workflowId: Current workflow ID.
|
||||||
|
userId: Current user.
|
||||||
|
featureInstanceId: Feature instance scope.
|
||||||
|
mandateId: Mandate scope.
|
||||||
|
contextBudget: Maximum characters for the context string.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted context string for injection into the agent's system prompt.
|
||||||
|
"""
|
||||||
|
queryVector = await self._embedSingle(currentPrompt)
|
||||||
|
if not queryVector:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
builder = _ContextBuilder(budget=contextBudget)
|
||||||
|
|
||||||
|
# Layer 1: Instance Layer (user's own documents, highest priority)
|
||||||
|
instanceChunks = self._knowledgeDb.semanticSearch(
|
||||||
|
queryVector=queryVector,
|
||||||
|
userId=userId,
|
||||||
|
featureInstanceId=featureInstanceId,
|
||||||
|
limit=15,
|
||||||
|
minScore=0.65,
|
||||||
|
)
|
||||||
|
if instanceChunks:
|
||||||
|
builder.add(priority=1, label="Relevant Documents", items=instanceChunks)
|
||||||
|
|
||||||
|
# Layer 2: Workflow Layer (current workflow entities & memory)
|
||||||
|
entities = self._knowledgeDb.getWorkflowEntities(workflowId)
|
||||||
|
if entities:
|
||||||
|
builder.add(priority=2, label="Workflow Context", items=entities, isKeyValue=True)
|
||||||
|
|
||||||
|
# Layer 3: Shared Layer (mandate-wide shared documents)
|
||||||
|
sharedChunks = self._knowledgeDb.semanticSearch(
|
||||||
|
queryVector=queryVector,
|
||||||
|
mandateId=mandateId,
|
||||||
|
isShared=True,
|
||||||
|
limit=10,
|
||||||
|
minScore=0.7,
|
||||||
|
)
|
||||||
|
if sharedChunks:
|
||||||
|
builder.add(priority=3, label="Shared Knowledge", items=sharedChunks)
|
||||||
|
|
||||||
|
return builder.build()
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Workflow Memory
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
async def storeEntity(
|
||||||
|
self,
|
||||||
|
workflowId: str,
|
||||||
|
userId: str,
|
||||||
|
featureInstanceId: str,
|
||||||
|
key: str,
|
||||||
|
value: str,
|
||||||
|
source: str = "extraction",
|
||||||
|
) -> WorkflowMemory:
|
||||||
|
"""Store a key-value entity in workflow memory with optional embedding."""
|
||||||
|
embedding = await self._embedSingle(f"{key}: {value}")
|
||||||
|
memory = WorkflowMemory(
|
||||||
|
workflowId=workflowId,
|
||||||
|
userId=userId,
|
||||||
|
featureInstanceId=featureInstanceId,
|
||||||
|
key=key,
|
||||||
|
value=value,
|
||||||
|
source=source,
|
||||||
|
embedding=embedding if embedding else None,
|
||||||
|
)
|
||||||
|
self._knowledgeDb.upsertWorkflowMemory(memory)
|
||||||
|
return memory
|
||||||
|
|
||||||
|
def getEntities(self, workflowId: str) -> List[Dict[str, Any]]:
|
||||||
|
"""Get all entities for a workflow."""
|
||||||
|
return self._knowledgeDb.getWorkflowEntities(workflowId)
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# File Status
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
def getFileStatus(self, fileId: str) -> Optional[str]:
|
||||||
|
"""Get the indexing status of a file."""
|
||||||
|
index = self._knowledgeDb.getFileContentIndex(fileId)
|
||||||
|
return index.get("status") if index else None
|
||||||
|
|
||||||
|
def isFileIndexed(self, fileId: str) -> bool:
|
||||||
|
"""Check if a file has been fully indexed."""
|
||||||
|
return self.getFileStatus(fileId) == "indexed"
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# On-Demand Extraction (Smart Document Handling)
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
async def readSection(self, fileId: str, sectionId: str) -> List[Dict[str, Any]]:
|
||||||
|
"""Read content objects for a specific section. Uses cache if available.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fileId: Source file ID.
|
||||||
|
sectionId: Section identifier from the FileContentIndex structure.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of content object dicts with data and contextRef.
|
||||||
|
"""
|
||||||
|
cached = self._knowledgeDb.getContentChunks(fileId)
|
||||||
|
sectionChunks = [
|
||||||
|
c for c in (cached or [])
|
||||||
|
if (c.get("contextRef", {}).get("sectionId") == sectionId)
|
||||||
|
]
|
||||||
|
if sectionChunks:
|
||||||
|
return sectionChunks
|
||||||
|
|
||||||
|
index = self._knowledgeDb.getFileContentIndex(fileId)
|
||||||
|
if not index:
|
||||||
|
return []
|
||||||
|
|
||||||
|
structure = index.get("structure", {}) if isinstance(index, dict) else getattr(index, "structure", {})
|
||||||
|
sections = structure.get("sections", [])
|
||||||
|
section = next((s for s in sections if s.get("id") == sectionId), None)
|
||||||
|
if not section:
|
||||||
|
return []
|
||||||
|
|
||||||
|
startPage = section.get("startPage", 0)
|
||||||
|
endPage = section.get("endPage", startPage)
|
||||||
|
|
||||||
|
return await self._extractPagesOnDemand(fileId, startPage, endPage, sectionId)
|
||||||
|
|
||||||
|
async def readContentObjects(
|
||||||
|
self, fileId: str, filter: Dict[str, Any] = None
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""Read content objects with optional filters (pageIndex, contentType, sectionId).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fileId: Source file ID.
|
||||||
|
filter: Optional dict with keys pageIndex (list[int]), contentType (str), sectionId (str).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Filtered list of content chunk dicts.
|
||||||
|
"""
|
||||||
|
filter = filter or {}
|
||||||
|
chunks = self._knowledgeDb.getContentChunks(fileId) or []
|
||||||
|
|
||||||
|
if "pageIndex" in filter:
|
||||||
|
targetPages = filter["pageIndex"]
|
||||||
|
if isinstance(targetPages, int):
|
||||||
|
targetPages = [targetPages]
|
||||||
|
chunks = [
|
||||||
|
c for c in chunks
|
||||||
|
if c.get("contextRef", {}).get("pageIndex") in targetPages
|
||||||
|
]
|
||||||
|
|
||||||
|
if "contentType" in filter:
|
||||||
|
chunks = [c for c in chunks if c.get("contentType") == filter["contentType"]]
|
||||||
|
|
||||||
|
if "sectionId" in filter:
|
||||||
|
chunks = [
|
||||||
|
c for c in chunks
|
||||||
|
if c.get("contextRef", {}).get("sectionId") == filter["sectionId"]
|
||||||
|
]
|
||||||
|
|
||||||
|
return chunks
|
||||||
|
|
async def extractContainerItem(
    self, fileId: str, containerPath: str
) -> Optional[Dict[str, Any]]:
    """On-demand extraction of a specific item within a container.

    If the item is already indexed, returns existing data.
    Otherwise triggers extraction and indexing.

    Args:
        fileId: The container file ID.
        containerPath: Path within the container (e.g. "folder/report.pdf").

    Returns:
        FileContentIndex dict for the extracted item, or None.
    """
    existing = self._knowledgeDb.getFileContentIndex(fileId)
    if existing:
        existingPath = existing.get("containerPath") if isinstance(existing, dict) else getattr(existing, "containerPath", None)
        if existingPath == containerPath:
            return existing

    logger.info(f"On-demand extraction for {containerPath} in file {fileId}")
    return None

async def _extractPagesOnDemand(
    self, fileId: str, startPage: int, endPage: int, sectionId: str
) -> List[Dict[str, Any]]:
    """Extract specific pages from a file and cache in knowledge store."""
    try:
        chatService = self._getService("chat")
        fileContent = chatService.getFileContent(fileId)
        if not fileContent:
            return []

        fileData = fileContent.get("data", b"")
        mimeType = fileContent.get("mimeType", "")
        fileName = fileContent.get("fileName", "")

        if isinstance(fileData, str):
            import base64
            fileData = base64.b64decode(fileData)

        if mimeType != "application/pdf":
            return []

        try:
            import fitz
        except ImportError:
            return []

        doc = fitz.open(stream=fileData, filetype="pdf")
        results = []

        for pageIdx in range(startPage, min(endPage + 1, len(doc))):
            page = doc[pageIdx]
            text = page.get_text() or ""
            if not text.strip():
                continue

            chunk = ContentChunk(
                contentObjectId=f"page-{pageIdx}",
                fileId=fileId,
                userId=self._context.user.id if self._context.user else "",
                featureInstanceId=self._context.feature_instance_id or "",
                contentType="text",
                data=text,
                contextRef={
                    "containerPath": fileName,
                    "location": f"page:{pageIdx+1}",
                    "pageIndex": pageIdx,
                    "sectionId": sectionId,
                },
            )

            embedding = await self._embedSingle(text[:2000])
            if embedding:
                chunk.embedding = embedding

            self._knowledgeDb.upsertContentChunk(chunk)
            results.append(chunk.model_dump())

        doc.close()
        return results

    except Exception as e:
        logger.error(f"On-demand page extraction failed: {e}")
        return []
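
For orientation, one element of the list returned by _extractPagesOnDemand (via chunk.model_dump()) looks roughly like this; values are placeholders and the user/embedding fields are omitted:

    examplePageChunk = {
        "contentObjectId": "page-3",
        "fileId": "file-123",
        "contentType": "text",
        "data": "...extracted page text...",
        "contextRef": {
            "containerPath": "report.pdf",
            "location": "page:4",    # 1-based label
            "pageIndex": 3,          # 0-based index
            "sectionId": "section-1",
        },
    }
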
def getFileContentIndex(self, fileId: str) -> Optional[Dict[str, Any]]:
    """Get the FileContentIndex for a file."""
    return self._knowledgeDb.getFileContentIndex(fileId)


# =============================================================================
# Internal helpers
# =============================================================================

def _chunkForEmbedding(
    textObjects: List[Dict[str, Any]], chunkSize: int = 512
) -> List[Dict[str, Any]]:
    """Split text content objects into chunks suitable for embedding.

    Each chunk preserves the contextRef from its source object.
    Long texts are split at sentence boundaries where possible.
    """
    chunks = []
    for obj in textObjects:
        text = obj.get("data", "")
        contentObjectId = obj.get("contentObjectId", "")
        contextRef = obj.get("contextRef", {})

        if len(text) <= chunkSize:
            chunks.append({
                "data": text,
                "contentObjectId": contentObjectId,
                "contextRef": contextRef,
            })
            continue

        # Split at sentence boundaries
        sentences = text.replace("\n", " ").split(". ")
        currentChunk = ""
        for sentence in sentences:
            candidate = f"{currentChunk}. {sentence}" if currentChunk else sentence
            if len(candidate) > chunkSize and currentChunk:
                chunks.append({
                    "data": currentChunk.strip(),
                    "contentObjectId": contentObjectId,
                    "contextRef": contextRef,
                })
                currentChunk = sentence
            else:
                currentChunk = candidate

        if currentChunk.strip():
            chunks.append({
                "data": currentChunk.strip(),
                "contentObjectId": contentObjectId,
                "contextRef": contextRef,
            })

    return chunks
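
A small sketch of how _chunkForEmbedding behaves (assumes the helper is importable from this module; values are made up):

    sampleObjects = [{
        "contentObjectId": "co-0",
        "data": "First sentence. Second sentence. " * 30,   # ~990 chars, over the default 512
        "contextRef": {"containerPath": "report.pdf", "pageIndex": 0},
    }]
    for piece in _chunkForEmbedding(sampleObjects, chunkSize=512):
        # Each piece stays at or under ~512 chars and keeps the source contextRef.
        print(len(piece["data"]), piece["contextRef"])
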
class _ContextBuilder:
    """Assembles RAG context from multiple sources respecting a character budget."""

    def __init__(self, budget: int):
        self._budget = budget
        self._sections: List[Dict[str, Any]] = []

    def add(
        self,
        priority: int,
        label: str,
        items: List[Dict[str, Any]],
        isKeyValue: bool = False,
    ):
        self._sections.append({
            "priority": priority,
            "label": label,
            "items": items,
            "isKeyValue": isKeyValue,
        })

    def build(self) -> str:
        self._sections.sort(key=lambda s: s["priority"])
        parts = []
        remaining = self._budget

        for section in self._sections:
            if remaining <= 0:
                break

            header = f"### {section['label']}\n"
            sectionText = header
            remaining -= len(header)

            for item in section["items"]:
                if remaining <= 0:
                    break

                if section["isKeyValue"]:
                    line = f"- {item.get('key', '')}: {item.get('value', '')}\n"
                else:
                    data = item.get("data", "")
                    ref = item.get("contextRef", {})
                    score = item.get("_score", "")
                    refStr = f" [{ref}]" if ref else ""
                    line = f"{data}{refStr}\n"

                if len(line) <= remaining:
                    sectionText += line
                    remaining -= len(line)

            parts.append(sectionText)

        return "\n".join(parts).strip()
modules/serviceCenter/services/serviceKnowledge/subPreScan.py (new file, 427 lines)
@@ -0,0 +1,427 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Structure Pre-Scan: fast, AI-free document analysis.

Extracts TOC, headings, page map, image positions, and structural metadata
from documents. Used as the first step in the auto-index pipeline.

Supported formats:
- PDF: TOC, heading detection (font-size heuristic), page map, image positions
- DOCX: heading styles, paragraph map
- PPTX: slide titles, slide map
- XLSX: sheet names, row/column counts
- Other: minimal index (single content object = the file itself)
"""

import io
import logging
from typing import Dict, Any, List, Optional

from modules.datamodels.datamodelKnowledge import FileContentIndex
from modules.datamodels.datamodelContent import ContentObjectSummary, ContentContextRef

logger = logging.getLogger(__name__)


async def preScanDocument(
    fileData: bytes,
    mimeType: str,
    fileId: str,
    fileName: str = "",
    userId: str = "",
    featureInstanceId: str = "",
    mandateId: str = "",
) -> FileContentIndex:
    """Create a structural FileContentIndex without AI.

    This is purely programmatic: TOC extraction, heading detection,
    page mapping, image position scanning.
    """
    scanner = _SCANNER_MAP.get(mimeType)
    if scanner is None:
        ext = (fileName.rsplit(".", 1)[-1].lower()) if "." in fileName else ""
        scanner = _EXTENSION_SCANNER_MAP.get(ext, _scanMinimal)

    try:
        structure, objectSummary, totalObjects, totalSize = await scanner(fileData, fileName)
    except Exception as e:
        logger.error(f"Pre-scan failed for {fileName} ({mimeType}): {e}")
        structure = {"error": str(e)}
        objectSummary = []
        totalObjects = 0
        totalSize = len(fileData)

    return FileContentIndex(
        id=fileId,
        userId=userId,
        featureInstanceId=featureInstanceId,
        mandateId=mandateId,
        fileName=fileName,
        mimeType=mimeType,
        totalObjects=totalObjects,
        totalSize=totalSize,
        structure=structure,
        objectSummary=[s.model_dump() for s in objectSummary],
        status="extracted",
    )
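
Usage sketch for the pre-scan entry point (paths and IDs are placeholders; assumes PyMuPDF is installed so the PDF scanner is used and that FileContentIndex.structure is a plain dict):

    import asyncio
    from modules.serviceCenter.services.serviceKnowledge.subPreScan import preScanDocument

    async def demo():
        with open("report.pdf", "rb") as f:
            index = await preScanDocument(f.read(), "application/pdf", fileId="file-123", fileName="report.pdf")
        print(index.status, index.totalObjects, index.structure.get("pages"))

    asyncio.run(demo())
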
# ---------------------------------------------------------------------------
# PDF scanner
# ---------------------------------------------------------------------------

async def _scanPdf(fileData: bytes, fileName: str):
    try:
        import fitz
    except ImportError:
        logger.warning("PyMuPDF not installed -- PDF pre-scan unavailable")
        return _fallbackStructure(fileData, fileName)

    doc = fitz.open(stream=fileData, filetype="pdf")
    toc = doc.get_toc()

    pageMap: List[Dict[str, Any]] = []
    summaries: List[ContentObjectSummary] = []
    totalSize = 0
    objIndex = 0

    for i in range(len(doc)):
        page = doc[i]
        textLen = len(page.get_text())
        blocks = page.get_text("dict", flags=0).get("blocks", [])

        headings = []
        for b in blocks:
            if b.get("type") != 0:
                continue
            for line in b.get("lines", []):
                for span in line.get("spans", []):
                    if _isHeading(span):
                        headings.append(span.get("text", "").strip())

        images = page.get_images(full=True)
        hasTable = _detectTableHeuristic(page)

        pageMap.append({
            "pageIndex": i,
            "headings": headings,
            "hasImages": len(images) > 0,
            "imageCount": len(images),
            "textLength": textLen,
            "hasTable": hasTable,
        })

        if textLen > 0:
            summaries.append(ContentObjectSummary(
                id=f"co-{objIndex}",
                contentType="text",
                contextRef=ContentContextRef(
                    containerPath=fileName,
                    location=f"page:{i+1}",
                    pageIndex=i,
                ),
                charCount=textLen,
            ))
            totalSize += textLen
            objIndex += 1

        for j in range(len(images)):
            summaries.append(ContentObjectSummary(
                id=f"co-{objIndex}",
                contentType="image",
                contextRef=ContentContextRef(
                    containerPath=fileName,
                    location=f"page:{i+1}/image:{j}",
                    pageIndex=i,
                ),
            ))
            objIndex += 1

    sections = _buildSectionsFromTocOrHeadings(toc, pageMap)
    doc.close()

    structure = {
        "pages": len(pageMap),
        "toc": toc,
        "sections": sections,
        "pageMap": pageMap,
        "imageCount": sum(p.get("imageCount", 0) for p in pageMap),
        "tableCount": sum(1 for p in pageMap if p.get("hasTable")),
    }
    return structure, summaries, len(summaries), totalSize


def _isHeading(span: Dict) -> bool:
    """Heuristic: heading if font size >= 14 or bold + size >= 12."""
    size = span.get("size", 0)
    flags = span.get("flags", 0)
    isBold = bool(flags & (1 << 4))
    return size >= 14 or (isBold and size >= 12)
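
The heuristic on concrete span dicts (as produced by PyMuPDF's "dict" text extraction; flag bit 4 marks bold):

    _isHeading({"size": 16.0, "flags": 0})         # True  -- large font
    _isHeading({"size": 12.5, "flags": 1 << 4})    # True  -- bold and >= 12 pt
    _isHeading({"size": 10.0, "flags": 1 << 4})    # False -- bold but too small
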
def _detectTableHeuristic(page) -> bool:
    """Detect tables by looking for grid-like line patterns."""
    try:
        drawings = page.get_drawings()
        lineCount = sum(1 for d in drawings if d.get("type") == "l")
        return lineCount >= 6
    except Exception:
        return False


def _buildSectionsFromTocOrHeadings(
    toc: list, pageMap: List[Dict]
) -> List[Dict[str, Any]]:
    """Build section boundaries from TOC or heading data."""
    sections: List[Dict[str, Any]] = []

    if toc:
        for i, entry in enumerate(toc):
            level, title, pageNum = entry[0], entry[1], entry[2]
            endPage = toc[i + 1][2] - 1 if i + 1 < len(toc) else len(pageMap) - 1
            sections.append({
                "id": f"section-{i}",
                "title": title,
                "level": level,
                "startPage": pageNum - 1,
                "endPage": endPage,
            })
    else:
        currentSection = None
        for pm in pageMap:
            headings = pm.get("headings", [])
            if headings:
                if currentSection:
                    currentSection["endPage"] = pm["pageIndex"] - 1
                    sections.append(currentSection)
                currentSection = {
                    "id": f"section-{len(sections)}",
                    "title": headings[0],
                    "level": 1,
                    "startPage": pm["pageIndex"],
                    "endPage": pm["pageIndex"],
                }
            elif currentSection:
                currentSection["endPage"] = pm["pageIndex"]

        if currentSection:
            sections.append(currentSection)

    return sections
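
A quick check of the TOC branch (PyMuPDF TOC entries are [level, title, 1-based page]; values invented):

    toc = [[1, "Introduction", 1], [1, "Results", 4]]
    pageMap = [{"pageIndex": i, "headings": []} for i in range(8)]
    _buildSectionsFromTocOrHeadings(toc, pageMap)
    # [{'id': 'section-0', 'title': 'Introduction', 'level': 1, 'startPage': 0, 'endPage': 3},
    #  {'id': 'section-1', 'title': 'Results', 'level': 1, 'startPage': 3, 'endPage': 7}]
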
# ---------------------------------------------------------------------------
# DOCX scanner
# ---------------------------------------------------------------------------

async def _scanDocx(fileData: bytes, fileName: str):
    try:
        import docx
    except ImportError:
        return _fallbackStructure(fileData, fileName)

    doc = docx.Document(io.BytesIO(fileData))
    summaries: List[ContentObjectSummary] = []
    sections: List[Dict[str, Any]] = []
    totalSize = 0
    objIndex = 0
    currentSection = None

    for i, para in enumerate(doc.paragraphs):
        text = para.text or ""
        styleName = (para.style.name or "").lower() if para.style else ""

        if "heading" in styleName and text.strip():
            if currentSection:
                sections.append(currentSection)
            level = 1
            for ch in styleName:
                if ch.isdigit():
                    level = int(ch)
                    break
            currentSection = {
                "id": f"section-{len(sections)}",
                "title": text.strip(),
                "level": level,
                "startParagraph": i,
                "endParagraph": i,
            }
        elif currentSection:
            currentSection["endParagraph"] = i

        if text.strip():
            summaries.append(ContentObjectSummary(
                id=f"co-{objIndex}",
                contentType="text",
                contextRef=ContentContextRef(
                    containerPath=fileName,
                    location=f"paragraph:{i+1}",
                    sectionId=currentSection["id"] if currentSection else "body",
                ),
                charCount=len(text),
            ))
            totalSize += len(text)
            objIndex += 1

    if currentSection:
        sections.append(currentSection)

    for ti, table in enumerate(doc.tables):
        summaries.append(ContentObjectSummary(
            id=f"co-{objIndex}",
            contentType="text",
            contextRef=ContentContextRef(
                containerPath=fileName,
                location=f"table:{ti+1}",
            ),
        ))
        objIndex += 1

    structure = {
        "paragraphs": len(doc.paragraphs),
        "tables": len(doc.tables),
        "sections": sections,
    }
    return structure, summaries, len(summaries), totalSize


# ---------------------------------------------------------------------------
# PPTX scanner
# ---------------------------------------------------------------------------

async def _scanPptx(fileData: bytes, fileName: str):
    try:
        from pptx import Presentation
    except ImportError:
        return _fallbackStructure(fileData, fileName)

    prs = Presentation(io.BytesIO(fileData))
    summaries: List[ContentObjectSummary] = []
    slideMap: List[Dict[str, Any]] = []
    totalSize = 0
    objIndex = 0

    for i, slide in enumerate(prs.slides):
        title = ""
        textLen = 0
        imageCount = 0
        for shape in slide.shapes:
            if hasattr(shape, "text"):
                textLen += len(shape.text)
                if shape.has_text_frame and not title:
                    title = shape.text.strip()[:80]
            if shape.shape_type == 13:
                imageCount += 1

        slideMap.append({
            "slideIndex": i,
            "title": title,
            "textLength": textLen,
            "imageCount": imageCount,
        })

        if textLen > 0:
            summaries.append(ContentObjectSummary(
                id=f"co-{objIndex}",
                contentType="text",
                contextRef=ContentContextRef(
                    containerPath=fileName,
                    location=f"slide:{i+1}",
                    slideIndex=i,
                ),
                charCount=textLen,
            ))
            totalSize += textLen
            objIndex += 1

    structure = {
        "slides": len(prs.slides),
        "slideMap": slideMap,
    }
    return structure, summaries, len(summaries), totalSize


# ---------------------------------------------------------------------------
# XLSX scanner
# ---------------------------------------------------------------------------

async def _scanXlsx(fileData: bytes, fileName: str):
    try:
        import openpyxl
    except ImportError:
        return _fallbackStructure(fileData, fileName)

    wb = openpyxl.load_workbook(io.BytesIO(fileData), data_only=True, read_only=True)
    summaries: List[ContentObjectSummary] = []
    sheetMap: List[Dict[str, Any]] = []
    totalSize = 0
    objIndex = 0

    for sheetName in wb.sheetnames:
        ws = wb[sheetName]
        rowCount = ws.max_row or 0
        colCount = ws.max_column or 0

        sheetMap.append({
            "sheetName": sheetName,
            "rows": rowCount,
            "columns": colCount,
        })

        summaries.append(ContentObjectSummary(
            id=f"co-{objIndex}",
            contentType="text",
            contextRef=ContentContextRef(
                containerPath=fileName,
                location=f"sheet:{sheetName}",
                sheetName=sheetName,
            ),
            charCount=rowCount * colCount * 10,
        ))
        totalSize += rowCount * colCount * 10
        objIndex += 1

    wb.close()
    structure = {"sheets": len(wb.sheetnames), "sheetMap": sheetMap}
    return structure, summaries, len(summaries), totalSize


# ---------------------------------------------------------------------------
# Minimal / fallback scanner
# ---------------------------------------------------------------------------

async def _scanMinimal(fileData: bytes, fileName: str):
    return _fallbackStructure(fileData, fileName)


def _fallbackStructure(fileData: bytes, fileName: str):
    summary = ContentObjectSummary(
        id="co-0",
        contentType="other",
        contextRef=ContentContextRef(containerPath=fileName, location="file"),
        charCount=len(fileData),
    )
    structure = {"type": "single", "size": len(fileData)}
    return structure, [summary], 1, len(fileData)


# ---------------------------------------------------------------------------
# Scanner map
# ---------------------------------------------------------------------------

_SCANNER_MAP: Dict[str, Any] = {
    "application/pdf": _scanPdf,
    "application/vnd.openxmlformats-officedocument.wordprocessingml.document": _scanDocx,
    "application/vnd.openxmlformats-officedocument.presentationml.presentation": _scanPptx,
    "application/vnd.ms-powerpoint": _scanPptx,
    "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": _scanXlsx,
}

_EXTENSION_SCANNER_MAP: Dict[str, Any] = {
    "pdf": _scanPdf,
    "docx": _scanDocx,
    "pptx": _scanPptx,
    "ppt": _scanPptx,
    "xlsx": _scanXlsx,
    "xlsm": _scanXlsx,
}
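
Scanner resolution mirrors the lookup in preScanDocument: MIME type first, then file extension, then the minimal fallback:

    _SCANNER_MAP.get("application/pdf")                  # -> _scanPdf
    _EXTENSION_SCANNER_MAP.get("docx", _scanMinimal)     # -> _scanDocx
    _EXTENSION_SCANNER_MAP.get("txt", _scanMinimal)      # -> _scanMinimal (single "other" object)
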
@@ -452,6 +452,11 @@ RESOURCE_OBJECTS = [
         "label": {"en": "Store: Teams Bot", "de": "Store: Teams Bot", "fr": "Store: Teams Bot"},
         "meta": {"category": "store", "featureCode": "teamsbot"}
     },
+    {
+        "objectKey": "resource.store.workspace",
+        "label": {"en": "Store: AI Workspace", "de": "Store: AI Workspace", "fr": "Store: AI Workspace"},
+        "meta": {"category": "store", "featureCode": "workspace"}
+    },
     {
         "objectKey": "resource.system.api.auth",
         "label": {"en": "Authentication API", "de": "Authentifizierungs-API", "fr": "API d'authentification"},
@@ -37,7 +37,7 @@ async def webResearch(self, parameters: Dict[str, Any]) -> ActionResult:
         workflow_id=self.services.workflow.id if self.services.workflow else None,
         workflow=self.services.workflow,
     )
-    web_service = getService("web", context, legacy_hub=self.services)
+    web_service = getService("web", context)

     # Init progress logger
     workflowId = self.services.workflow.id if self.services.workflow else f"no-workflow-{int(time.time())}"