codeeditor
This commit is contained in:
parent
6c8cc843ce
commit
1d4148e8b5
6 changed files with 45 additions and 15 deletions
|
|
@ -9,7 +9,7 @@ and Playground into a single agent-driven workspace.
|
|||
import logging
|
||||
import json
|
||||
import asyncio
|
||||
from typing import Dict, Optional, List
|
||||
from typing import Any, Dict, Optional, List
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends, Body, Path, Query, Request, UploadFile, File
|
||||
from fastapi.responses import StreamingResponse, JSONResponse
|
||||
|
|
@ -50,7 +50,8 @@ async def _getAiObjects() -> AiObjects:
|
|||
return _aiObjects
|
||||
|
||||
|
||||
def _validateInstanceAccess(instanceId: str, context: RequestContext) -> str:
|
||||
def _validateInstanceAccess(instanceId: str, context: RequestContext):
|
||||
"""Validate access and return (mandateId, instanceConfig) tuple."""
|
||||
from modules.interfaces.interfaceDbApp import getRootInterface
|
||||
rootInterface = getRootInterface()
|
||||
instance = rootInterface.getFeatureInstance(instanceId)
|
||||
|
|
@ -59,7 +60,9 @@ def _validateInstanceAccess(instanceId: str, context: RequestContext) -> str:
|
|||
featureAccess = rootInterface.getFeatureAccess(str(context.user.id), instanceId)
|
||||
if not featureAccess or not featureAccess.enabled:
|
||||
raise HTTPException(status_code=403, detail="Access denied to this feature instance")
|
||||
return str(instance.mandateId) if instance.mandateId else None
|
||||
mandateId = str(instance.mandateId) if instance.mandateId else None
|
||||
instanceConfig = instance.config if hasattr(instance, "config") and instance.config else {}
|
||||
return mandateId, instanceConfig
|
||||
|
||||
|
||||
def _getChatInterface(context: RequestContext, featureInstanceId: str = None):
|
||||
|
|
@ -218,7 +221,7 @@ async def streamWorkspaceStart(
|
|||
context: RequestContext = Depends(getRequestContext),
|
||||
):
|
||||
"""Start or continue a Workspace session with SSE streaming via serviceAgent."""
|
||||
mandateId = _validateInstanceAccess(instanceId, context)
|
||||
mandateId, instanceConfig = _validateInstanceAccess(instanceId, context)
|
||||
chatInterface = _getChatInterface(context, featureInstanceId=instanceId)
|
||||
aiObjects = await _getAiObjects()
|
||||
eventManager = get_event_manager()
|
||||
|
|
@ -260,6 +263,7 @@ async def streamWorkspaceStart(
|
|||
chatInterface=chatInterface,
|
||||
eventManager=eventManager,
|
||||
userLanguage=userInput.userLanguage,
|
||||
instanceConfig=instanceConfig,
|
||||
)
|
||||
)
|
||||
eventManager.register_agent_task(queueId, agentTask)
|
||||
|
|
@ -312,6 +316,7 @@ async def _runWorkspaceAgent(
|
|||
chatInterface,
|
||||
eventManager,
|
||||
userLanguage: str = "en",
|
||||
instanceConfig: Dict[str, Any] = None,
|
||||
):
|
||||
"""Run the serviceAgent loop and forward events to the SSE queue."""
|
||||
try:
|
||||
|
|
@ -356,12 +361,20 @@ async def _runWorkspaceAgent(
|
|||
accumulatedText = ""
|
||||
messagePersisted = False
|
||||
|
||||
_cfg = instanceConfig or {}
|
||||
_toolSet = _cfg.get("toolSet", "core")
|
||||
_agentCfg = _cfg.get("agentConfig")
|
||||
from modules.serviceCenter.services.serviceAgent.datamodelAgent import AgentConfig
|
||||
agentConfig = AgentConfig(**_agentCfg) if isinstance(_agentCfg, dict) else None
|
||||
|
||||
async for event in agentService.runAgent(
|
||||
prompt=enrichedPrompt,
|
||||
fileIds=fileIds,
|
||||
workflowId=workflowId,
|
||||
userLanguage=userLanguage,
|
||||
conversationHistory=conversationHistory,
|
||||
toolSet=_toolSet,
|
||||
config=agentConfig,
|
||||
):
|
||||
if eventManager.is_cancelled(queueId):
|
||||
logger.info(f"Agent cancelled by user for workflow {workflowId}")
|
||||
|
|
@ -1034,7 +1047,7 @@ async def getVoiceLanguages(
|
|||
context: RequestContext = Depends(getRequestContext),
|
||||
):
|
||||
"""Return available TTS languages."""
|
||||
mandateId = _validateInstanceAccess(instanceId, context)
|
||||
mandateId, _ = _validateInstanceAccess(instanceId, context)
|
||||
from modules.interfaces.interfaceVoiceObjects import getVoiceInterface
|
||||
voiceInterface = getVoiceInterface(context.user, mandateId)
|
||||
languagesResult = await voiceInterface.getAvailableLanguages()
|
||||
|
|
@ -1051,7 +1064,7 @@ async def getVoiceVoices(
|
|||
context: RequestContext = Depends(getRequestContext),
|
||||
):
|
||||
"""Return available TTS voices for a given language."""
|
||||
mandateId = _validateInstanceAccess(instanceId, context)
|
||||
mandateId, _ = _validateInstanceAccess(instanceId, context)
|
||||
from modules.interfaces.interfaceVoiceObjects import getVoiceInterface
|
||||
voiceInterface = getVoiceInterface(context.user, mandateId)
|
||||
voicesResult = await voiceInterface.getAvailableVoices(language)
|
||||
|
|
@ -1069,7 +1082,7 @@ async def testVoice(
|
|||
):
|
||||
"""Test a specific voice with a sample text."""
|
||||
import base64
|
||||
mandateId = _validateInstanceAccess(instanceId, context)
|
||||
mandateId, _ = _validateInstanceAccess(instanceId, context)
|
||||
text = body.get("text", "Hallo, das ist ein Stimmtest.")
|
||||
language = body.get("language", "de-DE")
|
||||
voiceId = body.get("voiceId")
|
||||
|
|
|
|||
|
|
@ -195,6 +195,19 @@ class KnowledgeObjects:
|
|||
if contentType:
|
||||
recordFilter["contentType"] = contentType
|
||||
|
||||
if isShared and mandateId:
|
||||
sharedIndexes = self.db.getRecordset(
|
||||
FileContentIndex,
|
||||
recordFilter={"mandateId": mandateId, "isShared": True},
|
||||
)
|
||||
sharedFileIds = [idx.get("id") if isinstance(idx, dict) else getattr(idx, "id", None) for idx in sharedIndexes]
|
||||
sharedFileIds = [fid for fid in sharedFileIds if fid]
|
||||
if not sharedFileIds:
|
||||
return []
|
||||
recordFilter.pop("userId", None)
|
||||
recordFilter.pop("featureInstanceId", None)
|
||||
recordFilter["fileId"] = sharedFileIds
|
||||
|
||||
return self.db.semanticSearch(
|
||||
modelClass=ContentChunk,
|
||||
vectorColumn="embedding",
|
||||
|
|
|
|||
|
|
@ -24,9 +24,6 @@ from modules.shared.timeUtils import getUtcTimestamp
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_RETRIES_PER_TOOL = 3
|
||||
RETRY_BASE_DELAY_S = 1.0
|
||||
|
||||
|
||||
async def runAgentLoop(
|
||||
prompt: str,
|
||||
|
|
|
|||
|
|
@ -48,6 +48,10 @@ class ToolDefinition(BaseModel):
|
|||
default=None,
|
||||
description="Feature scope for this tool (None = available to all)"
|
||||
)
|
||||
toolSet: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Tool-set scope (None = available to all sets, e.g. 'core', 'codeeditor')"
|
||||
)
|
||||
|
||||
|
||||
class ToolCallRequest(BaseModel):
|
||||
|
|
@ -79,7 +83,6 @@ class AgentConfig(BaseModel):
|
|||
"""Configuration for an agent run."""
|
||||
maxRounds: int = Field(default=25, ge=1, le=100)
|
||||
maxCostCHF: Optional[float] = Field(default=None, ge=0.0)
|
||||
entityCacheEnabled: bool = Field(default=False)
|
||||
toolSet: str = Field(default="core")
|
||||
temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0)
|
||||
|
||||
|
|
|
|||
|
|
@ -1613,7 +1613,7 @@ def _registerCoreTools(registry: ToolRegistry, services):
|
|||
result = await knowledgeService.extractContainerItem(fileId, containerPath)
|
||||
if result:
|
||||
return ToolResult(toolCallId="", toolName="extractContainerItem", success=True, data=str(result))
|
||||
return ToolResult(toolCallId="", toolName="extractContainerItem", success=True, data=f"On-demand extraction for '{containerPath}' queued.")
|
||||
return ToolResult(toolCallId="", toolName="extractContainerItem", success=False, error=f"Item '{containerPath}' not found in container index for file {fileId}. On-demand extraction is not yet implemented.")
|
||||
except Exception as e:
|
||||
return ToolResult(toolCallId="", toolName="extractContainerItem", success=False, error=str(e))
|
||||
|
||||
|
|
|
|||
|
|
@ -22,7 +22,8 @@ class ToolRegistry:
|
|||
|
||||
def register(self, name: str, handler: Callable[..., Awaitable[ToolResult]],
|
||||
description: str = "", parameters: Dict[str, Any] = None,
|
||||
readOnly: bool = False, featureType: str = None):
|
||||
readOnly: bool = False, featureType: str = None,
|
||||
toolSet: str = None):
|
||||
"""Register a tool with its handler function."""
|
||||
if name in self._tools:
|
||||
logger.warning(f"Tool '{name}' already registered, overwriting")
|
||||
|
|
@ -32,10 +33,11 @@ class ToolRegistry:
|
|||
description=description,
|
||||
parameters=parameters or {},
|
||||
readOnly=readOnly,
|
||||
featureType=featureType
|
||||
featureType=featureType,
|
||||
toolSet=toolSet,
|
||||
)
|
||||
self._handlers[name] = handler
|
||||
logger.debug(f"Registered tool: {name} (readOnly={readOnly})")
|
||||
logger.debug(f"Registered tool: {name} (readOnly={readOnly}, toolSet={toolSet})")
|
||||
|
||||
def registerFromDefinition(self, definition: ToolDefinition,
|
||||
handler: Callable[..., Awaitable[ToolResult]]):
|
||||
|
|
@ -54,6 +56,8 @@ class ToolRegistry:
|
|||
tools = list(self._tools.values())
|
||||
if featureType:
|
||||
tools = [t for t in tools if t.featureType is None or t.featureType == featureType]
|
||||
if toolSet:
|
||||
tools = [t for t in tools if t.toolSet is None or t.toolSet == toolSet]
|
||||
return tools
|
||||
|
||||
def getToolNames(self) -> List[str]:
|
||||
|
|
|
|||
Loading…
Reference in a new issue