codeeditor
This commit is contained in:
parent
6c8cc843ce
commit
1d4148e8b5
6 changed files with 45 additions and 15 deletions
|
|
@ -9,7 +9,7 @@ and Playground into a single agent-driven workspace.
|
||||||
import logging
|
import logging
|
||||||
import json
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import Dict, Optional, List
|
from typing import Any, Dict, Optional, List
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Depends, Body, Path, Query, Request, UploadFile, File
|
from fastapi import APIRouter, HTTPException, Depends, Body, Path, Query, Request, UploadFile, File
|
||||||
from fastapi.responses import StreamingResponse, JSONResponse
|
from fastapi.responses import StreamingResponse, JSONResponse
|
||||||
|
|
@ -50,7 +50,8 @@ async def _getAiObjects() -> AiObjects:
|
||||||
return _aiObjects
|
return _aiObjects
|
||||||
|
|
||||||
|
|
||||||
def _validateInstanceAccess(instanceId: str, context: RequestContext) -> str:
|
def _validateInstanceAccess(instanceId: str, context: RequestContext):
|
||||||
|
"""Validate access and return (mandateId, instanceConfig) tuple."""
|
||||||
from modules.interfaces.interfaceDbApp import getRootInterface
|
from modules.interfaces.interfaceDbApp import getRootInterface
|
||||||
rootInterface = getRootInterface()
|
rootInterface = getRootInterface()
|
||||||
instance = rootInterface.getFeatureInstance(instanceId)
|
instance = rootInterface.getFeatureInstance(instanceId)
|
||||||
|
|
@ -59,7 +60,9 @@ def _validateInstanceAccess(instanceId: str, context: RequestContext) -> str:
|
||||||
featureAccess = rootInterface.getFeatureAccess(str(context.user.id), instanceId)
|
featureAccess = rootInterface.getFeatureAccess(str(context.user.id), instanceId)
|
||||||
if not featureAccess or not featureAccess.enabled:
|
if not featureAccess or not featureAccess.enabled:
|
||||||
raise HTTPException(status_code=403, detail="Access denied to this feature instance")
|
raise HTTPException(status_code=403, detail="Access denied to this feature instance")
|
||||||
return str(instance.mandateId) if instance.mandateId else None
|
mandateId = str(instance.mandateId) if instance.mandateId else None
|
||||||
|
instanceConfig = instance.config if hasattr(instance, "config") and instance.config else {}
|
||||||
|
return mandateId, instanceConfig
|
||||||
|
|
||||||
|
|
||||||
def _getChatInterface(context: RequestContext, featureInstanceId: str = None):
|
def _getChatInterface(context: RequestContext, featureInstanceId: str = None):
|
||||||
|
|
@ -218,7 +221,7 @@ async def streamWorkspaceStart(
|
||||||
context: RequestContext = Depends(getRequestContext),
|
context: RequestContext = Depends(getRequestContext),
|
||||||
):
|
):
|
||||||
"""Start or continue a Workspace session with SSE streaming via serviceAgent."""
|
"""Start or continue a Workspace session with SSE streaming via serviceAgent."""
|
||||||
mandateId = _validateInstanceAccess(instanceId, context)
|
mandateId, instanceConfig = _validateInstanceAccess(instanceId, context)
|
||||||
chatInterface = _getChatInterface(context, featureInstanceId=instanceId)
|
chatInterface = _getChatInterface(context, featureInstanceId=instanceId)
|
||||||
aiObjects = await _getAiObjects()
|
aiObjects = await _getAiObjects()
|
||||||
eventManager = get_event_manager()
|
eventManager = get_event_manager()
|
||||||
|
|
@ -260,6 +263,7 @@ async def streamWorkspaceStart(
|
||||||
chatInterface=chatInterface,
|
chatInterface=chatInterface,
|
||||||
eventManager=eventManager,
|
eventManager=eventManager,
|
||||||
userLanguage=userInput.userLanguage,
|
userLanguage=userInput.userLanguage,
|
||||||
|
instanceConfig=instanceConfig,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
eventManager.register_agent_task(queueId, agentTask)
|
eventManager.register_agent_task(queueId, agentTask)
|
||||||
|
|
@ -312,6 +316,7 @@ async def _runWorkspaceAgent(
|
||||||
chatInterface,
|
chatInterface,
|
||||||
eventManager,
|
eventManager,
|
||||||
userLanguage: str = "en",
|
userLanguage: str = "en",
|
||||||
|
instanceConfig: Dict[str, Any] = None,
|
||||||
):
|
):
|
||||||
"""Run the serviceAgent loop and forward events to the SSE queue."""
|
"""Run the serviceAgent loop and forward events to the SSE queue."""
|
||||||
try:
|
try:
|
||||||
|
|
@ -356,12 +361,20 @@ async def _runWorkspaceAgent(
|
||||||
accumulatedText = ""
|
accumulatedText = ""
|
||||||
messagePersisted = False
|
messagePersisted = False
|
||||||
|
|
||||||
|
_cfg = instanceConfig or {}
|
||||||
|
_toolSet = _cfg.get("toolSet", "core")
|
||||||
|
_agentCfg = _cfg.get("agentConfig")
|
||||||
|
from modules.serviceCenter.services.serviceAgent.datamodelAgent import AgentConfig
|
||||||
|
agentConfig = AgentConfig(**_agentCfg) if isinstance(_agentCfg, dict) else None
|
||||||
|
|
||||||
async for event in agentService.runAgent(
|
async for event in agentService.runAgent(
|
||||||
prompt=enrichedPrompt,
|
prompt=enrichedPrompt,
|
||||||
fileIds=fileIds,
|
fileIds=fileIds,
|
||||||
workflowId=workflowId,
|
workflowId=workflowId,
|
||||||
userLanguage=userLanguage,
|
userLanguage=userLanguage,
|
||||||
conversationHistory=conversationHistory,
|
conversationHistory=conversationHistory,
|
||||||
|
toolSet=_toolSet,
|
||||||
|
config=agentConfig,
|
||||||
):
|
):
|
||||||
if eventManager.is_cancelled(queueId):
|
if eventManager.is_cancelled(queueId):
|
||||||
logger.info(f"Agent cancelled by user for workflow {workflowId}")
|
logger.info(f"Agent cancelled by user for workflow {workflowId}")
|
||||||
|
|
@ -1034,7 +1047,7 @@ async def getVoiceLanguages(
|
||||||
context: RequestContext = Depends(getRequestContext),
|
context: RequestContext = Depends(getRequestContext),
|
||||||
):
|
):
|
||||||
"""Return available TTS languages."""
|
"""Return available TTS languages."""
|
||||||
mandateId = _validateInstanceAccess(instanceId, context)
|
mandateId, _ = _validateInstanceAccess(instanceId, context)
|
||||||
from modules.interfaces.interfaceVoiceObjects import getVoiceInterface
|
from modules.interfaces.interfaceVoiceObjects import getVoiceInterface
|
||||||
voiceInterface = getVoiceInterface(context.user, mandateId)
|
voiceInterface = getVoiceInterface(context.user, mandateId)
|
||||||
languagesResult = await voiceInterface.getAvailableLanguages()
|
languagesResult = await voiceInterface.getAvailableLanguages()
|
||||||
|
|
@ -1051,7 +1064,7 @@ async def getVoiceVoices(
|
||||||
context: RequestContext = Depends(getRequestContext),
|
context: RequestContext = Depends(getRequestContext),
|
||||||
):
|
):
|
||||||
"""Return available TTS voices for a given language."""
|
"""Return available TTS voices for a given language."""
|
||||||
mandateId = _validateInstanceAccess(instanceId, context)
|
mandateId, _ = _validateInstanceAccess(instanceId, context)
|
||||||
from modules.interfaces.interfaceVoiceObjects import getVoiceInterface
|
from modules.interfaces.interfaceVoiceObjects import getVoiceInterface
|
||||||
voiceInterface = getVoiceInterface(context.user, mandateId)
|
voiceInterface = getVoiceInterface(context.user, mandateId)
|
||||||
voicesResult = await voiceInterface.getAvailableVoices(language)
|
voicesResult = await voiceInterface.getAvailableVoices(language)
|
||||||
|
|
@ -1069,7 +1082,7 @@ async def testVoice(
|
||||||
):
|
):
|
||||||
"""Test a specific voice with a sample text."""
|
"""Test a specific voice with a sample text."""
|
||||||
import base64
|
import base64
|
||||||
mandateId = _validateInstanceAccess(instanceId, context)
|
mandateId, _ = _validateInstanceAccess(instanceId, context)
|
||||||
text = body.get("text", "Hallo, das ist ein Stimmtest.")
|
text = body.get("text", "Hallo, das ist ein Stimmtest.")
|
||||||
language = body.get("language", "de-DE")
|
language = body.get("language", "de-DE")
|
||||||
voiceId = body.get("voiceId")
|
voiceId = body.get("voiceId")
|
||||||
|
|
|
||||||
|
|
@ -195,6 +195,19 @@ class KnowledgeObjects:
|
||||||
if contentType:
|
if contentType:
|
||||||
recordFilter["contentType"] = contentType
|
recordFilter["contentType"] = contentType
|
||||||
|
|
||||||
|
if isShared and mandateId:
|
||||||
|
sharedIndexes = self.db.getRecordset(
|
||||||
|
FileContentIndex,
|
||||||
|
recordFilter={"mandateId": mandateId, "isShared": True},
|
||||||
|
)
|
||||||
|
sharedFileIds = [idx.get("id") if isinstance(idx, dict) else getattr(idx, "id", None) for idx in sharedIndexes]
|
||||||
|
sharedFileIds = [fid for fid in sharedFileIds if fid]
|
||||||
|
if not sharedFileIds:
|
||||||
|
return []
|
||||||
|
recordFilter.pop("userId", None)
|
||||||
|
recordFilter.pop("featureInstanceId", None)
|
||||||
|
recordFilter["fileId"] = sharedFileIds
|
||||||
|
|
||||||
return self.db.semanticSearch(
|
return self.db.semanticSearch(
|
||||||
modelClass=ContentChunk,
|
modelClass=ContentChunk,
|
||||||
vectorColumn="embedding",
|
vectorColumn="embedding",
|
||||||
|
|
|
||||||
|
|
@ -24,9 +24,6 @@ from modules.shared.timeUtils import getUtcTimestamp
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
MAX_RETRIES_PER_TOOL = 3
|
|
||||||
RETRY_BASE_DELAY_S = 1.0
|
|
||||||
|
|
||||||
|
|
||||||
async def runAgentLoop(
|
async def runAgentLoop(
|
||||||
prompt: str,
|
prompt: str,
|
||||||
|
|
|
||||||
|
|
@ -48,6 +48,10 @@ class ToolDefinition(BaseModel):
|
||||||
default=None,
|
default=None,
|
||||||
description="Feature scope for this tool (None = available to all)"
|
description="Feature scope for this tool (None = available to all)"
|
||||||
)
|
)
|
||||||
|
toolSet: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="Tool-set scope (None = available to all sets, e.g. 'core', 'codeeditor')"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ToolCallRequest(BaseModel):
|
class ToolCallRequest(BaseModel):
|
||||||
|
|
@ -79,7 +83,6 @@ class AgentConfig(BaseModel):
|
||||||
"""Configuration for an agent run."""
|
"""Configuration for an agent run."""
|
||||||
maxRounds: int = Field(default=25, ge=1, le=100)
|
maxRounds: int = Field(default=25, ge=1, le=100)
|
||||||
maxCostCHF: Optional[float] = Field(default=None, ge=0.0)
|
maxCostCHF: Optional[float] = Field(default=None, ge=0.0)
|
||||||
entityCacheEnabled: bool = Field(default=False)
|
|
||||||
toolSet: str = Field(default="core")
|
toolSet: str = Field(default="core")
|
||||||
temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0)
|
temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1613,7 +1613,7 @@ def _registerCoreTools(registry: ToolRegistry, services):
|
||||||
result = await knowledgeService.extractContainerItem(fileId, containerPath)
|
result = await knowledgeService.extractContainerItem(fileId, containerPath)
|
||||||
if result:
|
if result:
|
||||||
return ToolResult(toolCallId="", toolName="extractContainerItem", success=True, data=str(result))
|
return ToolResult(toolCallId="", toolName="extractContainerItem", success=True, data=str(result))
|
||||||
return ToolResult(toolCallId="", toolName="extractContainerItem", success=True, data=f"On-demand extraction for '{containerPath}' queued.")
|
return ToolResult(toolCallId="", toolName="extractContainerItem", success=False, error=f"Item '{containerPath}' not found in container index for file {fileId}. On-demand extraction is not yet implemented.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return ToolResult(toolCallId="", toolName="extractContainerItem", success=False, error=str(e))
|
return ToolResult(toolCallId="", toolName="extractContainerItem", success=False, error=str(e))
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,8 @@ class ToolRegistry:
|
||||||
|
|
||||||
def register(self, name: str, handler: Callable[..., Awaitable[ToolResult]],
|
def register(self, name: str, handler: Callable[..., Awaitable[ToolResult]],
|
||||||
description: str = "", parameters: Dict[str, Any] = None,
|
description: str = "", parameters: Dict[str, Any] = None,
|
||||||
readOnly: bool = False, featureType: str = None):
|
readOnly: bool = False, featureType: str = None,
|
||||||
|
toolSet: str = None):
|
||||||
"""Register a tool with its handler function."""
|
"""Register a tool with its handler function."""
|
||||||
if name in self._tools:
|
if name in self._tools:
|
||||||
logger.warning(f"Tool '{name}' already registered, overwriting")
|
logger.warning(f"Tool '{name}' already registered, overwriting")
|
||||||
|
|
@ -32,10 +33,11 @@ class ToolRegistry:
|
||||||
description=description,
|
description=description,
|
||||||
parameters=parameters or {},
|
parameters=parameters or {},
|
||||||
readOnly=readOnly,
|
readOnly=readOnly,
|
||||||
featureType=featureType
|
featureType=featureType,
|
||||||
|
toolSet=toolSet,
|
||||||
)
|
)
|
||||||
self._handlers[name] = handler
|
self._handlers[name] = handler
|
||||||
logger.debug(f"Registered tool: {name} (readOnly={readOnly})")
|
logger.debug(f"Registered tool: {name} (readOnly={readOnly}, toolSet={toolSet})")
|
||||||
|
|
||||||
def registerFromDefinition(self, definition: ToolDefinition,
|
def registerFromDefinition(self, definition: ToolDefinition,
|
||||||
handler: Callable[..., Awaitable[ToolResult]]):
|
handler: Callable[..., Awaitable[ToolResult]]):
|
||||||
|
|
@ -54,6 +56,8 @@ class ToolRegistry:
|
||||||
tools = list(self._tools.values())
|
tools = list(self._tools.values())
|
||||||
if featureType:
|
if featureType:
|
||||||
tools = [t for t in tools if t.featureType is None or t.featureType == featureType]
|
tools = [t for t in tools if t.featureType is None or t.featureType == featureType]
|
||||||
|
if toolSet:
|
||||||
|
tools = [t for t in tools if t.toolSet is None or t.toolSet == toolSet]
|
||||||
return tools
|
return tools
|
||||||
|
|
||||||
def getToolNames(self) -> List[str]:
|
def getToolNames(self) -> List[str]:
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue