302 lines
14 KiB
Python
302 lines
14 KiB
Python
# Copyright (c) 2025 Patrick Motsch
|
|
# All rights reserved.
|
|
"""RAG Inventory API — global knowledge-store visibility for users, admins, platform."""
|
|
|
|
import logging
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from fastapi import APIRouter, HTTPException, Depends, Request
|
|
from modules.auth import limiter, getCurrentUser, getRequestContext, RequestContext
|
|
from modules.datamodels.datamodelUam import User
|
|
from modules.shared.i18nRegistry import apiRouteContext
|
|
|
|
# i18n-aware message helper (from i18nRegistry) used for the HTTPException
# details raised by the routes in this module.
routeApiMsg = apiRouteContext("routeRagInventory")

# Module logger; every route logs unexpected failures here before answering 500.
logger = logging.getLogger(__name__)

# Error statuses documented in OpenAPI for every route on this router.
_ERROR_RESPONSES: Dict[int, Dict[str, Any]] = {
    401: {"description": "Unauthorized"},
    403: {"description": "Forbidden"},
    500: {"description": "Internal server error"},
}

router = APIRouter(
    prefix="/api/rag/inventory",
    tags=["RAG Inventory"],
    responses=_ERROR_RESPONSES,
)
|
|
|
|
|
|
def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> List[Dict[str, Any]]:
    """Build the per-connection RAG inventory payload shared by the inventory endpoints.

    For every connection this loads its DataSources, counts the
    FileContentIndex rows attributed to each DataSource, and summarises the
    connection's ``connection.bootstrap`` background jobs.

    Args:
        connections: Iterable of connection objects exposing ``id`` and
            ``authority`` (and optionally ``externalEmail``,
            ``knowledgeIngestionEnabled``, ``knowledgePreferences``).
        rootIf: App DB interface; ``rootIf.db.getRecordset`` loads DataSources.
        knowledgeIf: Knowledge DB interface; ``knowledgeIf.db.getRecordset``
            loads FileContentIndex rows.
        jobService: Object providing ``listJobs(jobType=..., limit=...)``.

    Returns:
        One dict per connection with the keys ``id``, ``authority``,
        ``externalEmail``, ``knowledgeIngestionEnabled``, ``preferences``,
        ``dataSources``, ``totalChunks``, ``runningJobs``, ``lastError`` and
        ``lastSuccess``.
    """
    from modules.datamodels.datamodelDataSource import DataSource
    from modules.datamodels.datamodelKnowledge import FileContentIndex

    connections = list(connections)
    if not connections:
        return []

    # The job list does not depend on the connection, so fetch it once instead
    # of once per connection (the /platform route walks every connection).
    # NOTE(review): the 50-job window is global across all connections, so a
    # connection's jobs can fall outside it when many bootstraps are recent.
    jobs = jobService.listJobs(jobType="connection.bootstrap", limit=50)

    out: List[Dict[str, Any]] = []
    for conn in connections:
        connectionId = str(conn.id)
        dataSources = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId})
        connIndexRows = knowledgeIf.db.getRecordset(FileContentIndex, recordFilter={"connectionId": connectionId})

        chunksByDs, unassigned = _countChunksByDataSource(connIndexRows)
        dsItems = _buildDataSourceItems(dataSources, chunksByDs, unassigned)

        connJobs = [j for j in jobs if (j.get("payload") or {}).get("connectionId") == connectionId]
        runningJobs, lastError, lastSuccess = _summarizeConnectionJobs(connJobs)

        out.append({
            "id": connectionId,
            "authority": conn.authority.value if hasattr(conn.authority, "value") else str(conn.authority),
            "externalEmail": getattr(conn, "externalEmail", ""),
            "knowledgeIngestionEnabled": getattr(conn, "knowledgeIngestionEnabled", False),
            "preferences": getattr(conn, "knowledgePreferences", None) or {},
            "dataSources": dsItems,
            # Total counts every index row of the connection, attributed or not.
            "totalChunks": len(connIndexRows),
            "runningJobs": runningJobs,
            "lastError": lastError,
            "lastSuccess": lastSuccess,
        })
    return out


def _fieldOf(record: Any, name: str, default: Any = None) -> Any:
    """Read ``name`` from a recordset row that may be a dict or an object.

    Dict rows use ``dict.get`` (a missing key yields ``None``); object rows use
    ``getattr`` with ``default``.
    """
    if isinstance(record, dict):
        return record.get(name)
    return getattr(record, name, default)


def _countChunksByDataSource(indexRows):
    """Count index rows per ``structure._ingestion.provenance.dataSourceId``.

    Returns:
        ``(chunksByDs, unassigned)``: a dict mapping DataSource id to row count,
        and the number of rows without a (truthy) provenance DataSource id.
    """
    chunksByDs: Dict[str, int] = {}
    unassigned = 0
    for idx in indexRows:
        # Each level may be missing or not a dict; treat that as "no provenance".
        struct = _fieldOf(idx, "structure") or {}
        ingestion = (struct.get("_ingestion") or {}) if isinstance(struct, dict) else {}
        prov = (ingestion.get("provenance") or {}) if isinstance(ingestion, dict) else {}
        dsIdRef = prov.get("dataSourceId", "") if isinstance(prov, dict) else ""
        if dsIdRef:
            chunksByDs[dsIdRef] = chunksByDs.get(dsIdRef, 0) + 1
        else:
            unassigned += 1
    return chunksByDs, unassigned


def _buildDataSourceItems(dataSources, chunksByDs: Dict[str, int], unassigned: int) -> List[Dict[str, Any]]:
    """Serialise DataSources, deduplicated by path, with their chunk counts."""
    itemsByPath: Dict[Any, Dict[str, Any]] = {}
    countedIds = set()
    for ds in dataSources:
        dsId = _fieldOf(ds, "id", "")
        dsPath = _fieldOf(ds, "path", "")
        # Count each DataSource id at most once.
        chunkCount = 0 if dsId in countedIds else chunksByDs.get(dsId, 0)
        countedIds.add(dsId)

        existing = itemsByPath.get(dsPath)
        if existing is not None:
            # Only the first DataSource per path is listed. Fold the chunks of
            # duplicates into it so they do not vanish from the breakdown.
            existing["chunkCount"] += chunkCount
            continue

        itemsByPath[dsPath] = {
            "id": dsId,
            "label": _fieldOf(ds, "label", ""),
            "path": dsPath,
            "sourceType": _fieldOf(ds, "sourceType", ""),
            "ragIndexEnabled": _fieldOf(ds, "ragIndexEnabled", False),
            "neutralize": _fieldOf(ds, "neutralize", False),
            "lastIndexed": _fieldOf(ds, "lastIndexed", None),
            "chunkCount": chunkCount,
        }

    # Dicts keep insertion order, so items appear in first-seen order.
    dsItems = list(itemsByPath.values())

    # Rows without provenance cannot be attributed. Spread them evenly over the
    # listed DataSources (an approximation) instead of dropping them.
    if unassigned > 0 and dsItems:
        perDs, remainder = divmod(unassigned, len(dsItems))
        for i, item in enumerate(dsItems):
            item["chunkCount"] += perDs + (1 if i < remainder else 0)
    return dsItems


def _summarizeConnectionJobs(connJobs):
    """Split one connection's jobs into running jobs, first ERROR and first SUCCESS.

    Returns:
        ``(runningJobs, lastError, lastSuccess)``; the latter two are ``None``
        when no such job is present.
    """
    runningJobs = [
        {"jobId": j["id"], "progress": j.get("progress", 0), "progressMessage": j.get("progressMessage", "")}
        for j in connJobs
        if j.get("status") in ("PENDING", "RUNNING")
    ]

    # "last" assumes listJobs returns newest jobs first. TODO confirm.
    lastError: Optional[Dict[str, Any]] = None
    lastSuccess: Optional[Dict[str, Any]] = None
    for j in connJobs:
        status = j.get("status")
        if status == "ERROR" and lastError is None:
            lastError = {
                "jobId": j["id"],
                "errorMessage": j.get("errorMessage", ""),
                "finishedAt": j.get("finishedAt"),
            }
        elif status == "SUCCESS" and lastSuccess is None:
            result = j.get("result") or {}
            lastSuccess = {
                "jobId": j["id"],
                "finishedAt": j.get("finishedAt"),
                "indexed": result.get("indexed", 0),
                "skippedDuplicate": result.get("skippedDuplicate", 0),
                "skippedPolicy": result.get("skippedPolicy", 0),
                "failed": result.get("failed", 0),
                "durationMs": result.get("durationMs", 0),
            }
        if lastError and lastSuccess:
            break
    return runningJobs, lastError, lastSuccess
|
|
|
|
|
|
@router.get("/me")
@limiter.limit("30/minute")
def _getInventoryMe(
    request: Request,
    currentUser: User = Depends(getCurrentUser),
) -> Dict[str, Any]:
    """Return the caller's personal RAG inventory.

    Lists the current user's own connections with their DataSources,
    per-DataSource chunk counts and bootstrap-job status, plus a
    ``totals.chunks`` sum across all listed connections.

    Raises:
        HTTPException: 500 with the error text on any failure.
    """
    try:
        from modules.interfaces.interfaceDbApp import getRootInterface
        from modules.interfaces.interfaceDbKnowledge import getInterface as getKnowledgeInterface
        from modules.serviceCenter.services.serviceBackgroundJobs import mainBackgroundJobService as jobService

        appDb = getRootInterface()
        knowledgeDb = getKnowledgeInterface(None)
        ownConnections = appDb.getUserConnections(currentUser.id)

        inventory = _buildConnectionInventory(ownConnections, appDb, knowledgeDb, jobService)
        chunkTotal = sum(entry.get("totalChunks", 0) for entry in inventory)
        return {"connections": inventory, "totals": {"chunks": chunkTotal}}
    except Exception as exc:
        logger.exception("Error in RAG inventory /me: %s", exc)
        raise HTTPException(status_code=500, detail=str(exc))
|
|
|
|
|
|
@router.get("/mandate")
@limiter.limit("20/minute")
def _getInventoryMandate(
    request: Request,
    context: RequestContext = Depends(getRequestContext),
) -> Dict[str, Any]:
    """Mandate-level RAG aggregation (requires mandate membership).

    Returns the inventory of every UserConnection whose ``mandateId`` matches
    the request context, with ``totals.chunks`` (summed index rows) and
    ``totals.bytes`` (from ``aggregateMandateRagTotalBytes``).

    Raises:
        HTTPException: 403 when the request has no mandate context, 500 on any
            other failure.
    """
    # NOTE(review): this only checks that a mandate context exists; membership
    # itself is presumably enforced by getRequestContext. Confirm.
    if not context.mandateId:
        raise HTTPException(status_code=403, detail=routeApiMsg("Mandate context required"))
    try:
        from types import SimpleNamespace

        from modules.interfaces.interfaceDbApp import getRootInterface
        from modules.interfaces.interfaceDbKnowledge import getInterface as getKnowledgeInterface, aggregateMandateRagTotalBytes
        from modules.serviceCenter.services.serviceBackgroundJobs import mainBackgroundJobService as jobService
        from modules.datamodels.datamodelUam import UserConnection

        rootIf = getRootInterface()
        knowledgeIf = getKnowledgeInterface(None)
        # The guard above guarantees a truthy mandateId here.
        mandateId = str(context.mandateId)

        allConnections = rootIf.db.getRecordset(UserConnection, recordFilter={"mandateId": mandateId})
        # Dict rows are wrapped for attribute access. SimpleNamespace avoids
        # creating a throwaway class object per row, which type() would do.
        connectionObjects = [SimpleNamespace(**row) if isinstance(row, dict) else row for row in allConnections]

        items = _buildConnectionInventory(connectionObjects, rootIf, knowledgeIf, jobService)
        totalChunks = sum(c.get("totalChunks", 0) for c in items)
        totalBytes = aggregateMandateRagTotalBytes(mandateId)

        return {"connections": items, "totals": {"chunks": totalChunks, "bytes": totalBytes}}
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Error in RAG inventory /mandate: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
@router.get("/platform")
@limiter.limit("10/minute")
def _getInventoryPlatform(
    request: Request,
    context: RequestContext = Depends(getRequestContext),
) -> Dict[str, Any]:
    """Platform-wide RAG statistics (sysadmin only).

    Returns the inventory of every UserConnection on the platform, plus
    ``totals.chunks`` summed over all of them.

    Raises:
        HTTPException: 403 unless ``context.isSysAdmin``, 500 on any other failure.
    """
    if not context.isSysAdmin:
        raise HTTPException(status_code=403, detail=routeApiMsg("Platform admin required"))
    try:
        from types import SimpleNamespace

        from modules.interfaces.interfaceDbApp import getRootInterface
        from modules.interfaces.interfaceDbKnowledge import getInterface as getKnowledgeInterface
        from modules.serviceCenter.services.serviceBackgroundJobs import mainBackgroundJobService as jobService
        from modules.datamodels.datamodelUam import UserConnection

        rootIf = getRootInterface()
        knowledgeIf = getKnowledgeInterface(None)

        # NOTE(review): loads every connection and, per connection, all of its
        # index rows. Cost grows with platform size.
        allConnections = rootIf.db.getRecordset(UserConnection)
        # Dict rows are wrapped for attribute access. SimpleNamespace avoids
        # creating a throwaway class object per row, which type() would do.
        connectionObjects = [SimpleNamespace(**row) if isinstance(row, dict) else row for row in allConnections]

        items = _buildConnectionInventory(connectionObjects, rootIf, knowledgeIf, jobService)
        totalChunks = sum(c.get("totalChunks", 0) for c in items)

        return {"connections": items, "totals": {"chunks": totalChunks}}
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Error in RAG inventory /platform: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
@router.post("/reindex/{connectionId}")
@limiter.limit("10/minute")
async def _reindexConnection(
    request: Request,
    connectionId: str,
    currentUser: User = Depends(getCurrentUser),
) -> Dict[str, Any]:
    """Queue a fresh ``connection.bootstrap`` job for one of the caller's connections.

    Every DataSource of the connection with ``ragIndexEnabled`` set is handed
    to the job. Earlier failed runs do not prevent a new submission.

    Returns:
        ``{"status": "skipped", ...}`` when no DataSource has RAG indexing
        enabled, otherwise ``{"status": "queued", ...}`` including the new
        ``jobId`` and the number of DataSources submitted.

    Raises:
        HTTPException: 404 if the connection does not exist, 403 if it belongs
            to another user, 500 on any other failure.

    This route must stay ``async def`` so that ``await startJob(...)`` schedules
    the ``_runJob`` task on FastAPI's main event loop. As a sync route it would
    run in the worker threadpool, where ``asyncio.run`` tears down its temporary
    loop right after ``create_task``, leaving the job stuck in PENDING forever.
    """
    try:
        from modules.interfaces.interfaceDbApp import getRootInterface
        from modules.serviceCenter.services.serviceBackgroundJobs import startJob
        from modules.datamodels.datamodelDataSource import DataSource

        appDb = getRootInterface()
        connection = appDb.getUserConnectionById(connectionId)
        if connection is None:
            raise HTTPException(status_code=404, detail="Connection not found")
        if str(connection.userId) != str(currentUser.id):
            raise HTTPException(status_code=403, detail="Not your connection")

        def _read(row, field, fallback):
            # Recordset rows may be dicts (fallback None) or objects (fallback given).
            return row.get(field) if isinstance(row, dict) else getattr(row, field, fallback)

        enabledIds = [
            _read(row, "id", "")
            for row in appDb.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId})
            if _read(row, "ragIndexEnabled", False)
        ]
        if not enabledIds:
            return {"status": "skipped", "reason": "no_rag_enabled_datasources"}

        rawAuthority = connection.authority
        authorityName = rawAuthority.value if hasattr(rawAuthority, "value") else str(rawAuthority or "")

        jobPayload = {"connectionId": connectionId, "authority": authorityName.lower(), "dataSourceIds": enabledIds}
        jobId = await startJob("connection.bootstrap", jobPayload, triggeredBy=str(currentUser.id))

        submitted = len(enabledIds)
        logger.info("Reindex triggered for connection %s (%d DataSources, jobId=%s)", connectionId, submitted, jobId)
        return {"status": "queued", "connectionId": connectionId, "dataSourceCount": submitted, "jobId": jobId}
    except HTTPException:
        raise
    except Exception as exc:
        logger.exception("Error triggering reindex: %s", exc)
        raise HTTPException(status_code=500, detail=str(exc))
|
|
|
|
|
|
@router.get("/jobs")
@limiter.limit("60/minute")
def _getActiveJobs(
    request: Request,
    currentUser: User = Depends(getCurrentUser),
) -> List[Dict[str, Any]]:
    """List the caller's PENDING/RUNNING ``connection.bootstrap`` jobs (header badge feed).

    Only jobs whose payload ``connectionId`` belongs to one of the current
    user's connections are returned. Each entry carries the job id, the
    connection id and label, the job type and the progress fields.

    Raises:
        HTTPException: 500 with the error text on any failure.
    """
    try:
        from modules.serviceCenter.services.serviceBackgroundJobs import listJobs
        from modules.interfaces.interfaceDbApp import getRootInterface

        appDb = getRootInterface()
        ownById = {str(conn.id): conn for conn in appDb.getUserConnections(currentUser.id)}

        badgeEntries = []
        for job in listJobs(jobType="connection.bootstrap", limit=50):
            if job.get("status") not in ("PENDING", "RUNNING"):
                continue
            jobConnId = (job.get("payload") or {}).get("connectionId")
            if jobConnId not in ownById:
                continue
            owner = ownById[jobConnId]
            badgeEntries.append({
                "jobId": job["id"],
                "connectionId": jobConnId,
                "connectionLabel": getattr(owner, "displayLabel", None) or getattr(owner, "authority", jobConnId),
                "jobType": job.get("jobType", "connection.bootstrap"),
                "progress": job.get("progress", 0),
                "progressMessage": job.get("progressMessage", ""),
            })
        return badgeEntries
    except Exception as exc:
        logger.exception("Error in RAG inventory /jobs: %s", exc)
        raise HTTPException(status_code=500, detail=str(exc))
|