fixed source access tools
This commit is contained in:
parent
2a04acb278
commit
a24b20d302
6 changed files with 250 additions and 18 deletions
|
|
@ -61,19 +61,51 @@ class DriveAdapter(ServiceAdapter):
|
|||
))
|
||||
return entries
|
||||
|
||||
_EXPORT_MIME_MAP = {
|
||||
"application/vnd.google-apps.document": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
"application/vnd.google-apps.spreadsheet": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
"application/vnd.google-apps.presentation": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
|
||||
"application/vnd.google-apps.drawing": "application/pdf",
|
||||
}
|
||||
|
||||
async def download(self, path: str) -> bytes:
|
||||
fileId = (path or "").strip("/")
|
||||
if not fileId:
|
||||
return b""
|
||||
url = f"{_DRIVE_BASE}/files/{fileId}?alt=media"
|
||||
headers = {"Authorization": f"Bearer {self._token}"}
|
||||
timeout = aiohttp.ClientTimeout(total=60)
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
# Try direct download first
|
||||
url = f"{_DRIVE_BASE}/files/{fileId}?alt=media"
|
||||
async with session.get(url, headers=headers) as resp:
|
||||
if resp.status == 200:
|
||||
return await resp.read()
|
||||
logger.debug(f"Google Drive direct download returned {resp.status} for {fileId}")
|
||||
|
||||
# If 403/404, check if it's a native Google file that needs export
|
||||
metaUrl = f"{_DRIVE_BASE}/files/{fileId}?fields=mimeType,name"
|
||||
async with session.get(metaUrl, headers=headers) as metaResp:
|
||||
if metaResp.status != 200:
|
||||
logger.warning(f"Google Drive metadata fetch failed ({metaResp.status}) for {fileId}")
|
||||
return b""
|
||||
meta = await metaResp.json()
|
||||
fileMime = meta.get("mimeType", "")
|
||||
fileName = meta.get("name", fileId)
|
||||
|
||||
exportMime = self._EXPORT_MIME_MAP.get(fileMime)
|
||||
if not exportMime:
|
||||
logger.warning(f"Google Drive: unsupported mimeType '{fileMime}' for file '{fileName}' ({fileId})")
|
||||
return b""
|
||||
|
||||
exportUrl = f"{_DRIVE_BASE}/files/{fileId}/export?mimeType={exportMime}"
|
||||
logger.info(f"Google Drive: exporting '{fileName}' as {exportMime}")
|
||||
async with session.get(exportUrl, headers=headers) as exportResp:
|
||||
if exportResp.status == 200:
|
||||
return await exportResp.read()
|
||||
logger.warning(f"Google Drive export failed ({exportResp.status}) for '{fileName}'")
|
||||
except Exception as e:
|
||||
logger.error(f"Google Drive download failed: {e}")
|
||||
logger.error(f"Google Drive download failed for {fileId}: {e}")
|
||||
return b""
|
||||
|
||||
async def upload(self, path: str, data: bytes, fileName: str) -> dict:
|
||||
|
|
|
|||
|
|
@ -98,20 +98,45 @@ def _getDbManagement(context: RequestContext, featureInstanceId: str = None):
|
|||
)
|
||||
|
||||
|
||||
def _buildDataSourceContext(chatInterface, dataSourceIds: List[str]) -> str:
|
||||
_SOURCE_TYPE_TO_SERVICE = {
|
||||
"sharepointFolder": "sharepoint",
|
||||
"onedriveFolder": "onedrive",
|
||||
"outlookFolder": "outlook",
|
||||
"googleDriveFolder": "drive",
|
||||
"gmailFolder": "gmail",
|
||||
"ftpFolder": "files",
|
||||
}
|
||||
|
||||
|
||||
def _buildDataSourceContext(chatService, dataSourceIds: List[str]) -> str:
|
||||
"""Build a description of active data sources for the agent prompt."""
|
||||
parts = []
|
||||
parts = [
|
||||
"The user has attached the following external data sources to this prompt.",
|
||||
"IMPORTANT: Use the dataSourceId (UUID) exactly as shown below when calling browseDataSource or searchDataSource.",
|
||||
"Use downloadFromDataSource to download a specific file into local storage.",
|
||||
"",
|
||||
]
|
||||
found = False
|
||||
for dsId in dataSourceIds:
|
||||
try:
|
||||
ds = chatInterface.db.recordGet("DataSource", dsId)
|
||||
ds = chatService.getDataSource(dsId) if hasattr(chatService, "getDataSource") else None
|
||||
if ds:
|
||||
found = True
|
||||
label = ds.get("label", "")
|
||||
sourceType = ds.get("sourceType", "")
|
||||
connectionId = ds.get("connectionId", "")
|
||||
path = ds.get("path", "/")
|
||||
parts.append(f"- {label} ({sourceType}, path: {path})")
|
||||
service = _SOURCE_TYPE_TO_SERVICE.get(sourceType, sourceType)
|
||||
parts.append(
|
||||
f"- dataSourceId: {dsId}\n"
|
||||
f" label: \"{label}\"\n"
|
||||
f" type: {sourceType} (service: {service})\n"
|
||||
f" connectionId: {connectionId}\n"
|
||||
f" path: {path}"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
return "\n".join(parts) if parts else ""
|
||||
return "\n".join(parts) if found else ""
|
||||
|
||||
|
||||
def _deriveWorkflowName(prompt: str, maxLen: int = 40) -> str:
|
||||
|
|
@ -241,10 +266,11 @@ async def _runWorkspaceAgent(
|
|||
workflow_id=workflowId,
|
||||
)
|
||||
agentService = getService("agent", ctx)
|
||||
chatService = getService("chat", ctx)
|
||||
|
||||
enrichedPrompt = prompt
|
||||
if dataSourceIds:
|
||||
dsInfo = _buildDataSourceContext(chatInterface, dataSourceIds)
|
||||
dsInfo = _buildDataSourceContext(chatService, dataSourceIds)
|
||||
if dsInfo:
|
||||
enrichedPrompt = f"{prompt}\n\n[Active Data Sources]\n{dsInfo}"
|
||||
|
||||
|
|
|
|||
|
|
@ -97,6 +97,7 @@ class EventManager:
|
|||
|
||||
try:
|
||||
await queue.put(event)
|
||||
if event_type not in ("chunk",):
|
||||
logger.debug(f"Emitted {event_type} event for workflow {context_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error emitting event for workflow {context_id}: {e}", exc_info=True)
|
||||
|
|
|
|||
|
|
@ -146,10 +146,16 @@ async def runAgentLoop(
|
|||
try:
|
||||
aiResponse = None
|
||||
streamedText = ""
|
||||
isFirstChunkOfRound = True
|
||||
|
||||
if aiCallStreamFn:
|
||||
async for chunk in aiCallStreamFn(aiRequest):
|
||||
if isinstance(chunk, str):
|
||||
if isFirstChunkOfRound and state.currentRound > 1:
|
||||
chunk = "\n\n" + chunk
|
||||
isFirstChunkOfRound = False
|
||||
elif isFirstChunkOfRound:
|
||||
isFirstChunkOfRound = False
|
||||
streamedText += chunk
|
||||
yield AgentEvent(type=AgentEventTypeEnum.CHUNK, content=chunk)
|
||||
else:
|
||||
|
|
@ -221,6 +227,8 @@ async def runAgentLoop(
|
|||
durationMs=result.durationMs,
|
||||
error=result.error
|
||||
))
|
||||
if not result.success:
|
||||
logger.warning(f"Tool '{result.toolName}' failed: {result.error}")
|
||||
yield AgentEvent(
|
||||
type=AgentEventTypeEnum.TOOL_RESULT,
|
||||
data={
|
||||
|
|
|
|||
|
|
@ -695,6 +695,20 @@ def _registerCoreTools(registry: ToolRegistry, services):
|
|||
|
||||
# ---- Connection tools (external data sources) ----
|
||||
|
||||
def _buildResolverDb():
|
||||
"""Build a DB adapter that ConnectorResolver can use to load UserConnections.
|
||||
interfaceDbApp has getUserConnectionById; ConnectorResolver expects getUserConnection."""
|
||||
chatService = services.chat
|
||||
appIf = getattr(chatService, "interfaceDbApp", None)
|
||||
if appIf and hasattr(appIf, "getUserConnectionById"):
|
||||
class _Adapter:
|
||||
def __init__(self, app):
|
||||
self._app = app
|
||||
def getUserConnection(self, connectionId: str):
|
||||
return self._app.getUserConnectionById(connectionId)
|
||||
return _Adapter(appIf)
|
||||
return getattr(chatService, "interfaceDbComponent", None)
|
||||
|
||||
async def _listConnections(args: Dict[str, Any], context: Dict[str, Any]):
|
||||
try:
|
||||
chatService = services.chat
|
||||
|
|
@ -721,7 +735,7 @@ def _registerCoreTools(registry: ToolRegistry, services):
|
|||
from modules.connectors.connectorResolver import ConnectorResolver
|
||||
resolver = ConnectorResolver(
|
||||
services.getService("security"),
|
||||
services.chat.interfaceDbComponent if hasattr(services.chat, "interfaceDbComponent") else None,
|
||||
_buildResolverDb(),
|
||||
)
|
||||
adapter = await resolver.resolveService(connectionId, service)
|
||||
entries = await adapter.browse(path, filter=args.get("filter"))
|
||||
|
|
@ -743,7 +757,7 @@ def _registerCoreTools(registry: ToolRegistry, services):
|
|||
from modules.connectors.connectorResolver import ConnectorResolver
|
||||
resolver = ConnectorResolver(
|
||||
services.getService("security"),
|
||||
services.chat.interfaceDbComponent if hasattr(services.chat, "interfaceDbComponent") else None,
|
||||
_buildResolverDb(),
|
||||
)
|
||||
adapter = await resolver.resolveService(connectionId, service)
|
||||
fileBytes = await adapter.download(path)
|
||||
|
|
@ -752,9 +766,11 @@ def _registerCoreTools(registry: ToolRegistry, services):
|
|||
fileName = path.split("/")[-1] or "downloaded_file"
|
||||
chatService = services.chat
|
||||
fileItem, _ = chatService.interfaceDbComponent.saveUploadedFile(fileBytes, fileName)
|
||||
ext = fileName.rsplit(".", 1)[-1].lower() if "." in fileName else ""
|
||||
hint = "Use readFile to read text content." if ext in ("doc", "docx", "txt", "csv", "json", "xml", "html", "md", "rtf", "odt", "xls", "xlsx", "pptx") else "Use readFile to access the content."
|
||||
return ToolResult(
|
||||
toolCallId="", toolName="externalDownload", success=True,
|
||||
data=f"Downloaded '{fileName}' ({len(fileBytes)} bytes) → local file id: {fileItem.id}"
|
||||
data=f"Downloaded '{fileName}' ({len(fileBytes)} bytes) → local file id: {fileItem.id}. {hint}"
|
||||
)
|
||||
except Exception as e:
|
||||
return ToolResult(toolCallId="", toolName="externalDownload", success=False, error=str(e))
|
||||
|
|
@ -770,7 +786,7 @@ def _registerCoreTools(registry: ToolRegistry, services):
|
|||
from modules.connectors.connectorResolver import ConnectorResolver
|
||||
resolver = ConnectorResolver(
|
||||
services.getService("security"),
|
||||
services.chat.interfaceDbComponent if hasattr(services.chat, "interfaceDbComponent") else None,
|
||||
_buildResolverDb(),
|
||||
)
|
||||
adapter = await resolver.resolveService(connectionId, service)
|
||||
chatService = services.chat
|
||||
|
|
@ -796,7 +812,7 @@ def _registerCoreTools(registry: ToolRegistry, services):
|
|||
from modules.connectors.connectorResolver import ConnectorResolver
|
||||
resolver = ConnectorResolver(
|
||||
services.getService("security"),
|
||||
services.chat.interfaceDbComponent if hasattr(services.chat, "interfaceDbComponent") else None,
|
||||
_buildResolverDb(),
|
||||
)
|
||||
adapter = await resolver.resolveService(connectionId, service)
|
||||
entries = await adapter.search(query, path=args.get("path"))
|
||||
|
|
@ -819,7 +835,7 @@ def _registerCoreTools(registry: ToolRegistry, services):
|
|||
from modules.connectors.connectorResolver import ConnectorResolver
|
||||
resolver = ConnectorResolver(
|
||||
services.getService("security"),
|
||||
services.chat.interfaceDbComponent if hasattr(services.chat, "interfaceDbComponent") else None,
|
||||
_buildResolverDb(),
|
||||
)
|
||||
adapter = await resolver.resolveService(connectionId, "outlook")
|
||||
if hasattr(adapter, "sendMail"):
|
||||
|
|
@ -917,6 +933,149 @@ def _registerCoreTools(registry: ToolRegistry, services):
|
|||
readOnly=False,
|
||||
)
|
||||
|
||||
# ---- DataSource convenience tools ----
|
||||
_SOURCE_TYPE_TO_SERVICE = {
|
||||
"sharepointFolder": "sharepoint",
|
||||
"onedriveFolder": "onedrive",
|
||||
"outlookFolder": "outlook",
|
||||
"googleDriveFolder": "drive",
|
||||
"gmailFolder": "gmail",
|
||||
"ftpFolder": "files",
|
||||
}
|
||||
|
||||
async def _resolveDataSource(dsId: str):
|
||||
"""Resolve a DataSource record and return (connectionId, service, path) or raise."""
|
||||
chatService = services.chat
|
||||
ds = chatService.getDataSource(dsId) if hasattr(chatService, "getDataSource") else None
|
||||
if not ds:
|
||||
raise ValueError(f"DataSource '{dsId}' not found")
|
||||
connectionId = ds.get("connectionId", "")
|
||||
sourceType = ds.get("sourceType", "")
|
||||
path = ds.get("path", "/")
|
||||
service = _SOURCE_TYPE_TO_SERVICE.get(sourceType, sourceType)
|
||||
if not connectionId:
|
||||
raise ValueError(f"DataSource '{dsId}' has no connectionId")
|
||||
return connectionId, service, path
|
||||
|
||||
async def _browseDataSource(args: Dict[str, Any], context: Dict[str, Any]):
|
||||
dsId = args.get("dataSourceId", "")
|
||||
subPath = args.get("subPath", "")
|
||||
if not dsId:
|
||||
return ToolResult(toolCallId="", toolName="browseDataSource", success=False, error="dataSourceId is required")
|
||||
try:
|
||||
connectionId, service, basePath = await _resolveDataSource(dsId)
|
||||
browsePath = f"{basePath.rstrip('/')}/{subPath.lstrip('/')}" if subPath else basePath
|
||||
from modules.connectors.connectorResolver import ConnectorResolver
|
||||
resolver = ConnectorResolver(
|
||||
services.getService("security"),
|
||||
_buildResolverDb(),
|
||||
)
|
||||
adapter = await resolver.resolveService(connectionId, service)
|
||||
entries = await adapter.browse(browsePath, filter=args.get("filter"))
|
||||
if not entries:
|
||||
return ToolResult(toolCallId="", toolName="browseDataSource", success=True, data="Empty directory.")
|
||||
lines = []
|
||||
for e in entries:
|
||||
prefix = "[DIR]" if e.isFolder else "[FILE]"
|
||||
sizeInfo = f" ({e.size} bytes)" if e.size else ""
|
||||
lines.append(f"- {prefix} {e.name}{sizeInfo} path: {e.path}")
|
||||
return ToolResult(toolCallId="", toolName="browseDataSource", success=True, data="\n".join(lines))
|
||||
except Exception as e:
|
||||
return ToolResult(toolCallId="", toolName="browseDataSource", success=False, error=str(e))
|
||||
|
||||
async def _searchDataSource(args: Dict[str, Any], context: Dict[str, Any]):
|
||||
dsId = args.get("dataSourceId", "")
|
||||
query = args.get("query", "")
|
||||
if not dsId or not query:
|
||||
return ToolResult(toolCallId="", toolName="searchDataSource", success=False, error="dataSourceId and query are required")
|
||||
try:
|
||||
connectionId, service, basePath = await _resolveDataSource(dsId)
|
||||
from modules.connectors.connectorResolver import ConnectorResolver
|
||||
resolver = ConnectorResolver(
|
||||
services.getService("security"),
|
||||
_buildResolverDb(),
|
||||
)
|
||||
adapter = await resolver.resolveService(connectionId, service)
|
||||
entries = await adapter.search(query, path=basePath)
|
||||
if not entries:
|
||||
return ToolResult(toolCallId="", toolName="searchDataSource", success=True, data="No results found.")
|
||||
lines = [f"- {e.name} (path: {e.path})" for e in entries]
|
||||
return ToolResult(toolCallId="", toolName="searchDataSource", success=True, data="\n".join(lines))
|
||||
except Exception as e:
|
||||
return ToolResult(toolCallId="", toolName="searchDataSource", success=False, error=str(e))
|
||||
|
||||
async def _downloadFromDataSource(args: Dict[str, Any], context: Dict[str, Any]):
|
||||
dsId = args.get("dataSourceId", "")
|
||||
filePath = args.get("filePath", "")
|
||||
if not dsId or not filePath:
|
||||
return ToolResult(toolCallId="", toolName="downloadFromDataSource", success=False, error="dataSourceId and filePath are required")
|
||||
try:
|
||||
connectionId, service, basePath = await _resolveDataSource(dsId)
|
||||
fullPath = filePath if filePath.startswith("/") else f"{basePath.rstrip('/')}/{filePath}"
|
||||
from modules.connectors.connectorResolver import ConnectorResolver
|
||||
resolver = ConnectorResolver(
|
||||
services.getService("security"),
|
||||
_buildResolverDb(),
|
||||
)
|
||||
adapter = await resolver.resolveService(connectionId, service)
|
||||
fileBytes = await adapter.download(fullPath)
|
||||
if not fileBytes:
|
||||
return ToolResult(toolCallId="", toolName="downloadFromDataSource", success=False, error="Download returned empty")
|
||||
fileName = fullPath.split("/")[-1] or "downloaded_file"
|
||||
chatService = services.chat
|
||||
fileItem, _ = chatService.interfaceDbComponent.saveUploadedFile(fileBytes, fileName)
|
||||
ext = fileName.rsplit(".", 1)[-1].lower() if "." in fileName else ""
|
||||
hint = "Use readFile to read the text content." if ext in ("doc", "docx", "txt", "csv", "json", "xml", "html", "md", "rtf", "odt", "xls", "xlsx", "pptx") else "Use readFile to access the content."
|
||||
return ToolResult(
|
||||
toolCallId="", toolName="downloadFromDataSource", success=True,
|
||||
data=f"Downloaded '{fileName}' ({len(fileBytes)} bytes) → local file id: {fileItem.id}. {hint}"
|
||||
)
|
||||
except Exception as e:
|
||||
return ToolResult(toolCallId="", toolName="downloadFromDataSource", success=False, error=str(e))
|
||||
|
||||
registry.register(
|
||||
"browseDataSource", _browseDataSource,
|
||||
description="Browse files and folders in an attached data source by its dataSourceId. Returns file/folder listing.",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"dataSourceId": {"type": "string", "description": "DataSource ID (from the attached data sources in the prompt)"},
|
||||
"subPath": {"type": "string", "description": "Optional sub-path within the data source to browse"},
|
||||
"filter": {"type": "string", "description": "Optional filter pattern (e.g. '*.pdf')"},
|
||||
},
|
||||
"required": ["dataSourceId"],
|
||||
},
|
||||
readOnly=True,
|
||||
)
|
||||
|
||||
registry.register(
|
||||
"searchDataSource", _searchDataSource,
|
||||
description="Search for files within an attached data source by query.",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"dataSourceId": {"type": "string", "description": "DataSource ID"},
|
||||
"query": {"type": "string", "description": "Search query"},
|
||||
},
|
||||
"required": ["dataSourceId", "query"],
|
||||
},
|
||||
readOnly=True,
|
||||
)
|
||||
|
||||
registry.register(
|
||||
"downloadFromDataSource", _downloadFromDataSource,
|
||||
description="Download a file from an attached data source into local storage. Returns the local file ID which can then be read with readFile.",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"dataSourceId": {"type": "string", "description": "DataSource ID"},
|
||||
"filePath": {"type": "string", "description": "Path of the file to download (as returned by browseDataSource)"},
|
||||
},
|
||||
"required": ["dataSourceId", "filePath"],
|
||||
},
|
||||
readOnly=False,
|
||||
)
|
||||
|
||||
# ---- Document tools (Smart Documents / Container Handling) ----
|
||||
|
||||
async def _browseContainer(args: Dict[str, Any], context: Dict[str, Any]):
|
||||
|
|
@ -1198,8 +1357,13 @@ def _registerCoreTools(registry: ToolRegistry, services):
|
|||
mimeType = fileMimeType
|
||||
|
||||
if not imageData:
|
||||
chatService = services.chat
|
||||
fileInfo = chatService.getFileInfo(fileId) if hasattr(chatService, "getFileInfo") else None
|
||||
fileName = fileInfo.get("fileName", fileId) if fileInfo else fileId
|
||||
fileMime = fileInfo.get("mimeType", "unknown") if fileInfo else "unknown"
|
||||
return ToolResult(toolCallId="", toolName="describeImage", success=False,
|
||||
error="No image data found. The file may not contain images or extraction failed.")
|
||||
error=f"No image data found in '{fileName}' (type: {fileMime}). "
|
||||
f"This file likely contains text, not images. Use readFile(fileId=\"{fileId}\") to access its text content.")
|
||||
|
||||
dataUrl = f"data:{mimeType};base64,{imageData}"
|
||||
from modules.datamodels.datamodelAi import AiCallRequest, AiCallOptions, OperationTypeEnum as OTE
|
||||
|
|
|
|||
|
|
@ -541,7 +541,8 @@ class ChatService:
|
|||
def getDataSource(self, dataSourceId: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get a single data source by ID."""
|
||||
from modules.datamodels.datamodelDataSource import DataSource
|
||||
return self.interfaceDbComponent.db.loadRecord(DataSource, dataSourceId)
|
||||
results = self.interfaceDbComponent.db.getRecordset(DataSource, recordFilter={"id": dataSourceId})
|
||||
return results[0] if results else None
|
||||
|
||||
def deleteDataSource(self, dataSourceId: str) -> bool:
|
||||
"""Delete a data source."""
|
||||
|
|
|
|||
Loading…
Reference in a new issue