fixed source access and workflow resource limits and stop
This commit is contained in:
parent
b418207c2c
commit
c873f96848
5 changed files with 186 additions and 58 deletions
|
|
@ -412,6 +412,20 @@ class AiAnthropic(BaseConnectorAi):
|
|||
mimeType = parts[0].replace("data:", "")
|
||||
base64Data = parts[1]
|
||||
|
||||
import base64 as _b64
|
||||
try:
|
||||
rawHead = _b64.b64decode(base64Data[:32])
|
||||
if rawHead[:3] == b"\xff\xd8\xff":
|
||||
mimeType = "image/jpeg"
|
||||
elif rawHead[:8] == b"\x89PNG\r\n\x1a\n":
|
||||
mimeType = "image/png"
|
||||
elif rawHead[:4] == b"GIF8":
|
||||
mimeType = "image/gif"
|
||||
elif rawHead[:4] == b"RIFF" and rawHead[8:12] == b"WEBP":
|
||||
mimeType = "image/webp"
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Convert to Anthropic's vision format
|
||||
anthropicMessages = [{
|
||||
"role": "user",
|
||||
|
|
|
|||
|
|
@ -117,6 +117,7 @@ def _buildDataSourceContext(chatService, dataSourceIds: List[str]) -> str:
|
|||
"- Use the dataSourceId (UUID) exactly as shown below.",
|
||||
"- Do NOT use listFiles, externalBrowse, or externalSearch for attached data sources -- those tools are for other purposes.",
|
||||
"- browseDataSource returns BOTH files and folders at the given path.",
|
||||
"- When downloading files, ALWAYS provide the human-readable fileName (with extension) from the browse results.",
|
||||
"",
|
||||
]
|
||||
found = False
|
||||
|
|
@ -130,6 +131,7 @@ def _buildDataSourceContext(chatService, dataSourceIds: List[str]) -> str:
|
|||
connectionId = ds.get("connectionId", "")
|
||||
path = ds.get("path", "/")
|
||||
service = _SOURCE_TYPE_TO_SERVICE.get(sourceType, sourceType)
|
||||
logger.info(f"DataSource context: id={dsId}, label={label}, sourceType={sourceType}, service={service}, connectionId={connectionId}, path={path[:80]}")
|
||||
parts.append(
|
||||
f"- dataSourceId: {dsId}\n"
|
||||
f" label: \"{label}\"\n"
|
||||
|
|
@ -137,8 +139,10 @@ def _buildDataSourceContext(chatService, dataSourceIds: List[str]) -> str:
|
|||
f" connectionId: {connectionId}\n"
|
||||
f" path: {path}"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
logger.warning(f"DataSource {dsId} not found in DB")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error loading DataSource {dsId}: {e}")
|
||||
return "\n".join(parts) if found else ""
|
||||
|
||||
|
||||
|
|
@ -210,7 +214,7 @@ async def streamWorkspaceStart(
|
|||
"message": userInput.prompt,
|
||||
})
|
||||
|
||||
asyncio.ensure_future(
|
||||
agentTask = asyncio.ensure_future(
|
||||
_runWorkspaceAgent(
|
||||
workflowId=workflowId,
|
||||
queueId=queueId,
|
||||
|
|
@ -227,6 +231,7 @@ async def streamWorkspaceStart(
|
|||
userLanguage=userInput.userLanguage,
|
||||
)
|
||||
)
|
||||
eventManager.register_agent_task(queueId, agentTask)
|
||||
|
||||
async def _sseGenerator():
|
||||
queue = eventManager.get_queue(queueId)
|
||||
|
|
@ -321,6 +326,10 @@ async def _runWorkspaceAgent(
|
|||
workflowId=workflowId,
|
||||
userLanguage=userLanguage,
|
||||
):
|
||||
if eventManager.is_cancelled(queueId):
|
||||
logger.info(f"Agent cancelled by user for workflow {workflowId}")
|
||||
break
|
||||
|
||||
sseEvent = {
|
||||
"type": event.type.value if hasattr(event.type, "value") else event.type,
|
||||
"workflowId": workflowId,
|
||||
|
|
@ -356,6 +365,13 @@ async def _runWorkspaceAgent(
|
|||
"workflowId": workflowId,
|
||||
})
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"Agent task cancelled for workflow {workflowId}")
|
||||
await eventManager.emit_event(queueId, "stopped", {
|
||||
"type": "stopped",
|
||||
"workflowId": workflowId,
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Workspace agent error: {e}", exc_info=True)
|
||||
await eventManager.emit_event(queueId, "error", {
|
||||
|
|
@ -363,6 +379,8 @@ async def _runWorkspaceAgent(
|
|||
"content": str(e),
|
||||
"workflowId": workflowId,
|
||||
})
|
||||
finally:
|
||||
eventManager._unregister_agent_task(queueId)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -380,10 +398,12 @@ async def stopWorkspace(
|
|||
_validateInstanceAccess(instanceId, context)
|
||||
queueId = f"workspace-{workflowId}"
|
||||
eventManager = get_event_manager()
|
||||
cancelled = await eventManager.cancel_agent(queueId)
|
||||
await eventManager.emit_event(queueId, "stopped", {
|
||||
"type": "stopped",
|
||||
"workflowId": workflowId,
|
||||
})
|
||||
logger.info(f"Stop requested for workflow {workflowId}, agent task cancelled: {cancelled}")
|
||||
return JSONResponse({"status": "stopped", "workflowId": workflowId})
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -22,6 +22,8 @@ class EventManager:
|
|||
"""Initialize the event manager."""
|
||||
self._queues: Dict[str, asyncio.Queue] = {}
|
||||
self._cleanup_tasks: Dict[str, asyncio.Task] = {}
|
||||
self._agent_tasks: Dict[str, asyncio.Task] = {}
|
||||
self._cancelled: Dict[str, bool] = {}
|
||||
|
||||
def create_queue(self, workflow_id: str) -> asyncio.Queue:
|
||||
"""
|
||||
|
|
@ -75,6 +77,31 @@ class EventManager:
|
|||
"""
|
||||
return workflow_id in self._queues
|
||||
|
||||
def register_agent_task(self, workflow_id: str, task: asyncio.Task) -> None:
|
||||
"""Register the asyncio Task running the agent for a workflow."""
|
||||
self._agent_tasks[workflow_id] = task
|
||||
self._cancelled.pop(workflow_id, None)
|
||||
|
||||
def is_cancelled(self, workflow_id: str) -> bool:
|
||||
"""Check if a workflow has been cancelled."""
|
||||
return self._cancelled.get(workflow_id, False)
|
||||
|
||||
async def cancel_agent(self, workflow_id: str) -> bool:
|
||||
"""Cancel the running agent task for a workflow. Returns True if cancelled."""
|
||||
self._cancelled[workflow_id] = True
|
||||
task = self._agent_tasks.pop(workflow_id, None)
|
||||
if task and not task.done():
|
||||
task.cancel()
|
||||
logger.info(f"Cancelled agent task for workflow {workflow_id}")
|
||||
return True
|
||||
logger.debug(f"No running agent task found for workflow {workflow_id}")
|
||||
return False
|
||||
|
||||
def _unregister_agent_task(self, workflow_id: str) -> None:
|
||||
"""Remove the agent task reference after completion."""
|
||||
self._agent_tasks.pop(workflow_id, None)
|
||||
self._cancelled.pop(workflow_id, None)
|
||||
|
||||
async def emit_event(
|
||||
self,
|
||||
context_id: str,
|
||||
|
|
|
|||
|
|
@ -72,6 +72,7 @@ async def runAgentLoop(
|
|||
conversation.addUserMessage(prompt)
|
||||
|
||||
while state.status == AgentStatusEnum.RUNNING and state.currentRound < state.maxRounds:
|
||||
await asyncio.sleep(0)
|
||||
state.currentRound += 1
|
||||
roundStartTime = time.time()
|
||||
roundLog = AgentRoundLog(roundNumber=state.currentRound)
|
||||
|
|
|
|||
|
|
@ -22,6 +22,21 @@ from modules.serviceCenter.services.serviceBilling.mainServiceBilling import (
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_MAX_TOOL_RESULT_CHARS = 50_000
|
||||
|
||||
_BINARY_SIGNATURES = (b"%PDF", b"\x89PNG", b"\xff\xd8\xff", b"GIF8", b"PK\x03\x04", b"Rar!", b"\x1f\x8b")
|
||||
|
||||
|
||||
def _looksLikeBinary(data: bytes, sampleSize: int = 1024) -> bool:
|
||||
"""Detect binary content by checking for magic bytes and non-printable char ratio."""
|
||||
if any(data[:8].startswith(sig) for sig in _BINARY_SIGNATURES):
|
||||
return True
|
||||
sample = data[:sampleSize]
|
||||
if not sample:
|
||||
return False
|
||||
nonPrintable = sum(1 for b in sample if b < 0x09 or (0x0E <= b < 0x20 and b != 0x1B))
|
||||
return nonPrintable / len(sample) > 0.10
|
||||
|
||||
|
||||
class _ServicesAdapter:
|
||||
"""Adapter providing service access from (context, get_service)."""
|
||||
|
|
@ -328,6 +343,8 @@ def _registerCoreTools(registry: ToolRegistry, services):
|
|||
]
|
||||
if textChunks:
|
||||
assembled = "\n\n".join(c["data"] for c in textChunks)
|
||||
if len(assembled) > _MAX_TOOL_RESULT_CHARS:
|
||||
assembled = assembled[:_MAX_TOOL_RESULT_CHARS] + f"\n\n[Truncated – showing first {_MAX_TOOL_RESULT_CHARS} chars of {len(assembled)}]"
|
||||
return ToolResult(
|
||||
toolCallId="", toolName="readFile", success=True,
|
||||
data=assembled,
|
||||
|
|
@ -348,12 +365,18 @@ def _registerCoreTools(registry: ToolRegistry, services):
|
|||
mimeType = fileInfo.get("mimeType", "")
|
||||
|
||||
_BINARY_TYPES = ("application/pdf", "image/", "application/vnd.", "application/zip",
|
||||
"application/x-zip", "application/x-tar", "application/x-7z")
|
||||
"application/x-zip", "application/x-tar", "application/x-7z",
|
||||
"application/msword", "application/octet-stream")
|
||||
isBinary = any(mimeType.startswith(t) for t in _BINARY_TYPES)
|
||||
|
||||
if isBinary and knowledgeService:
|
||||
rawBytes = chatService.getFileData(fileId)
|
||||
if rawBytes:
|
||||
if not rawBytes:
|
||||
return ToolResult(toolCallId="", toolName="readFile", success=True, data="File data not accessible.")
|
||||
|
||||
if not isBinary:
|
||||
isBinary = _looksLikeBinary(rawBytes)
|
||||
|
||||
if isBinary:
|
||||
try:
|
||||
from modules.serviceCenter.services.serviceExtraction.subRegistry import ExtractorRegistry, ChunkerRegistry
|
||||
from modules.serviceCenter.services.serviceExtraction.subPipeline import runExtraction
|
||||
|
|
@ -381,40 +404,48 @@ def _registerCoreTools(registry: ToolRegistry, services):
|
|||
})
|
||||
|
||||
if contentObjects:
|
||||
if knowledgeService:
|
||||
try:
|
||||
userId = context.get("userId", "")
|
||||
await knowledgeService.indexFile(
|
||||
fileId=fileId, fileName=fileName, mimeType=mimeType,
|
||||
userId=userId, contentObjects=contentObjects,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
textParts = [o["data"] for o in contentObjects if o["contentType"] == "text"]
|
||||
if textParts:
|
||||
joined = "\n\n".join(textParts)
|
||||
if len(joined) > _MAX_TOOL_RESULT_CHARS:
|
||||
joined = joined[:_MAX_TOOL_RESULT_CHARS] + f"\n\n[Truncated – showing first {_MAX_TOOL_RESULT_CHARS} chars of {len(joined)}]"
|
||||
return ToolResult(
|
||||
toolCallId="", toolName="readFile", success=True,
|
||||
data="\n\n".join(textParts),
|
||||
data=joined,
|
||||
)
|
||||
imgCount = sum(1 for o in contentObjects if o["contentType"] == "image")
|
||||
return ToolResult(
|
||||
toolCallId="", toolName="readFile", success=True,
|
||||
data=f"[Extracted {len(contentObjects)} content objects from {fileName}. "
|
||||
f"No text content found. Use describeImage or readContentObjects for image/other content.]",
|
||||
data=f"[Extracted {len(contentObjects)} content objects from '{fileName}' "
|
||||
f"({imgCount} images, 0 text). "
|
||||
f"Use describeImage(fileId='{fileId}') to analyze visual content.]",
|
||||
)
|
||||
except Exception as extractErr:
|
||||
logger.warning(f"readFile on-demand extraction failed for {fileId}: {extractErr}")
|
||||
logger.warning(f"readFile extraction failed for {fileId} ({fileName}): {extractErr}")
|
||||
|
||||
# 3) Read raw bytes and decode
|
||||
rawBytes = chatService.getFileData(fileId)
|
||||
if not rawBytes:
|
||||
return ToolResult(toolCallId="", toolName="readFile", success=True, data="File data not accessible.")
|
||||
|
||||
if isBinary:
|
||||
return ToolResult(
|
||||
toolCallId="", toolName="readFile", success=True,
|
||||
data=f"[Binary file: {fileName}, type={mimeType}. Extraction failed or not available.]",
|
||||
data=f"[Binary file: '{fileName}', type={mimeType}, size={len(rawBytes)} bytes. "
|
||||
f"Text extraction not available. Use describeImage for images.]",
|
||||
)
|
||||
|
||||
# 3) Text file: decode raw bytes
|
||||
for encoding in ("utf-8", "utf-8-sig", "latin-1"):
|
||||
try:
|
||||
text = rawBytes.decode(encoding)
|
||||
if text.strip():
|
||||
if len(text) > _MAX_TOOL_RESULT_CHARS:
|
||||
text = text[:_MAX_TOOL_RESULT_CHARS] + f"\n\n[Truncated – showing first {_MAX_TOOL_RESULT_CHARS} chars of {len(text)}]"
|
||||
return ToolResult(
|
||||
toolCallId="", toolName="readFile", success=True,
|
||||
data=text,
|
||||
|
|
@ -954,9 +985,11 @@ def _registerCoreTools(registry: ToolRegistry, services):
|
|||
connectionId = ds.get("connectionId", "")
|
||||
sourceType = ds.get("sourceType", "")
|
||||
path = ds.get("path", "/")
|
||||
label = ds.get("label", "")
|
||||
service = _SOURCE_TYPE_TO_SERVICE.get(sourceType, sourceType)
|
||||
if not connectionId:
|
||||
raise ValueError(f"DataSource '{dsId}' has no connectionId")
|
||||
logger.info(f"Resolved DataSource '{dsId}' ({label}): sourceType={sourceType}, service={service}, connectionId={connectionId}, path={path[:80]}")
|
||||
return connectionId, service, path
|
||||
|
||||
async def _browseDataSource(args: Dict[str, Any], context: Dict[str, Any]):
|
||||
|
|
@ -1012,6 +1045,7 @@ def _registerCoreTools(registry: ToolRegistry, services):
|
|||
async def _downloadFromDataSource(args: Dict[str, Any], context: Dict[str, Any]):
|
||||
dsId = args.get("dataSourceId", "")
|
||||
filePath = args.get("filePath", "")
|
||||
fileName = args.get("fileName", "")
|
||||
if not dsId or not filePath:
|
||||
return ToolResult(toolCallId="", toolName="downloadFromDataSource", success=False, error="dataSourceId and filePath are required")
|
||||
try:
|
||||
|
|
@ -1026,11 +1060,30 @@ def _registerCoreTools(registry: ToolRegistry, services):
|
|||
fileBytes = await adapter.download(fullPath)
|
||||
if not fileBytes:
|
||||
return ToolResult(toolCallId="", toolName="downloadFromDataSource", success=False, error="Download returned empty")
|
||||
fileName = fullPath.split("/")[-1] or "downloaded_file"
|
||||
if not fileName or "." not in fileName:
|
||||
pathSegment = fullPath.split("/")[-1] or "downloaded_file"
|
||||
fileName = fileName or pathSegment
|
||||
if "." not in fileName:
|
||||
try:
|
||||
entries = await adapter.browse(basePath)
|
||||
for entry in entries:
|
||||
if getattr(entry, "path", "") == filePath or getattr(entry, "path", "").endswith(filePath):
|
||||
if "." in entry.name:
|
||||
fileName = entry.name
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
if "." not in fileName:
|
||||
import mimetypes as _mt
|
||||
guessed = _mt.guess_type(f"file.{_mt.guess_extension('application/octet-stream') or ''}")[0]
|
||||
if not guessed and fileBytes[:4] == b"%PDF":
|
||||
fileName = f"{fileName}.pdf"
|
||||
elif not guessed and fileBytes[:2] == b"PK":
|
||||
fileName = f"{fileName}.zip"
|
||||
chatService = services.chat
|
||||
fileItem, _ = chatService.interfaceDbComponent.saveUploadedFile(fileBytes, fileName)
|
||||
ext = fileName.rsplit(".", 1)[-1].lower() if "." in fileName else ""
|
||||
hint = "Use readFile to read the text content." if ext in ("doc", "docx", "txt", "csv", "json", "xml", "html", "md", "rtf", "odt", "xls", "xlsx", "pptx") else "Use readFile to access the content."
|
||||
hint = "Use readFile to read the text content." if ext in ("doc", "docx", "txt", "csv", "json", "xml", "html", "md", "rtf", "odt", "xls", "xlsx", "pptx", "pdf") else "Use readFile to access the content."
|
||||
return ToolResult(
|
||||
toolCallId="", toolName="downloadFromDataSource", success=True,
|
||||
data=f"Downloaded '{fileName}' ({len(fileBytes)} bytes) → local file id: {fileItem.id}. {hint}"
|
||||
|
|
@ -1069,12 +1122,13 @@ def _registerCoreTools(registry: ToolRegistry, services):
|
|||
|
||||
registry.register(
|
||||
"downloadFromDataSource", _downloadFromDataSource,
|
||||
description="Download a file from an attached data source into local storage. Returns the local file ID which can then be read with readFile.",
|
||||
description="Download a file from an attached data source into local storage. Returns the local file ID which can then be read with readFile. Always provide the fileName if known from the browse results.",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"dataSourceId": {"type": "string", "description": "DataSource ID"},
|
||||
"filePath": {"type": "string", "description": "Path of the file to download (as returned by browseDataSource)"},
|
||||
"fileName": {"type": "string", "description": "Human-readable file name with extension (e.g. 'report.pdf'). Get this from browseDataSource results."},
|
||||
},
|
||||
"required": ["dataSourceId", "filePath"],
|
||||
},
|
||||
|
|
@ -1370,6 +1424,18 @@ def _registerCoreTools(registry: ToolRegistry, services):
|
|||
error=f"No image data found in '{fileName}' (type: {fileMime}). "
|
||||
f"This file likely contains text, not images. Use readFile(fileId=\"{fileId}\") to access its text content.")
|
||||
|
||||
try:
|
||||
rawHead = _b64.b64decode(imageData[:32])
|
||||
if rawHead[:3] == b"\xff\xd8\xff":
|
||||
mimeType = "image/jpeg"
|
||||
elif rawHead[:8] == b"\x89PNG\r\n\x1a\n":
|
||||
mimeType = "image/png"
|
||||
elif rawHead[:4] == b"GIF8":
|
||||
mimeType = "image/gif"
|
||||
elif rawHead[:4] == b"RIFF" and rawHead[8:12] == b"WEBP":
|
||||
mimeType = "image/webp"
|
||||
except Exception:
|
||||
pass
|
||||
dataUrl = f"data:{mimeType};base64,{imageData}"
|
||||
from modules.datamodels.datamodelAi import AiCallRequest, AiCallOptions, OperationTypeEnum as OTE
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue