platform-core/modules/routes/routeRagInventory.py
2026-05-18 07:56:53 +02:00

410 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""RAG Inventory API — global knowledge-store visibility for users, admins, platform."""
import logging
from typing import Any, Dict, List, Optional, Tuple

from fastapi import APIRouter, HTTPException, Depends, Request

from modules.auth import limiter, getCurrentUser, getRequestContext, RequestContext
from modules.datamodels.datamodelUam import User
from modules.shared.i18nRegistry import apiRouteContext, resolveJobMessage
logger = logging.getLogger(__name__)
# i18n helper bound to this route module; used for user-facing HTTP details.
routeApiMsg = apiRouteContext("routeRagInventory")

# Shared router for every inventory endpoint below. The `responses` map only
# documents the error codes common to all routes in the OpenAPI schema.
router = APIRouter(
    prefix="/api/rag/inventory",
    tags=["RAG Inventory"],
    responses={
        code: {"description": text}
        for code, text in (
            (401, "Unauthorized"),
            (403, "Forbidden"),
            (500, "Internal server error"),
        )
    },
)
_SUB_RESULT_KEYS = ("sharepoint", "outlook", "drive", "gmail", "clickup", "kdrive")


def _flattenJobResult(result: Dict[str, Any]) -> Dict[str, Any]:
    """Collapse a bootstrap job result into the flat dict the UI shows on `lastSuccess`.

    Bootstrap handlers nest per-service results (e.g. msft returns
    `{"sharepoint": {...}, "outlook": {...}}`). The UI needs per-connection
    aggregates AND the first hit limit, so the counters are summed across
    sub-services, `durationMs` is taken from the slowest sub-service, and the
    first sub-service (in `_SUB_RESULT_KEYS` order) that reports
    `stoppedAtLimit` supplies both that reason and its `limits` dict.

    Args:
        result: the job's stored `result` payload.

    Returns:
        `{}` when `result` is not a dict; `result` itself (unchanged) when it
        contains no nested sub-service dicts (single-service / legacy
        handlers); otherwise a new dict with `indexed`, `skippedDuplicate`,
        `skippedPolicy`, `failed`, `bytesProcessed`, `durationMs`,
        `stoppedAtLimit` and `limits`.
    """
    if not isinstance(result, dict):
        # A malformed stored result must not break the whole inventory
        # response; these counters are display-only.
        return {}

    subResults = [result[k] for k in _SUB_RESULT_KEYS if isinstance(result.get(k), dict)]
    if not subResults:
        # Single-service handler that returns a flat dict directly (legacy path).
        return result

    def _counter(value: Any) -> int:
        # Non-numeric counters (corrupt or foreign payloads) count as 0
        # instead of raising and turning the endpoint into a 500.
        try:
            return int(value or 0)
        except (TypeError, ValueError):
            return 0

    def _total(key: str) -> int:
        return sum(_counter(r.get(key)) for r in subResults)

    # Parallel sub-services: wall-clock is roughly the slowest one.
    durationMs = max((_counter(r.get("durationMs")) for r in subResults), default=0)

    # First sub-service that hit a limit wins. The UI shows one banner per
    # connection; if several stopped, the first one is informative enough and
    # the user re-runs after raising that budget.
    firstLimited = next((r for r in subResults if r.get("stoppedAtLimit")), None)
    stoppedAtLimit: Optional[str] = None
    limits: Dict[str, Any] = {}
    if firstLimited is not None:
        stoppedAtLimit = firstLimited["stoppedAtLimit"]
        limits = firstLimited.get("limits") or {}

    return {
        "indexed": _total("indexed"),
        "skippedDuplicate": _total("skippedDuplicate"),
        "skippedPolicy": _total("skippedPolicy"),
        "failed": _total("failed"),
        "bytesProcessed": _total("bytesProcessed"),
        "durationMs": durationMs,
        "stoppedAtLimit": stoppedAtLimit,
        "limits": limits,
    }
def _recordField(record: Any, key: str, default: Any = None) -> Any:
    """Read `key` from a DB row that may be a plain dict or a model object.

    `default` is returned for a missing key/attribute in both shapes, so the
    same field yields the same fallback regardless of row representation.
    """
    if isinstance(record, dict):
        return record.get(key, default)
    return getattr(record, key, default)


def _provenanceDataSourceId(indexRow: Any) -> Any:
    """Return the DataSource id recorded in a FileContentIndex row's provenance.

    The id is read from `structure["_ingestion"]["provenance"]["dataSourceId"]`;
    any missing, empty or non-dict level yields `""`.
    """
    structure = _recordField(indexRow, "structure") or {}
    ingestion = (structure.get("_ingestion") or {}) if isinstance(structure, dict) else {}
    provenance = (ingestion.get("provenance") or {}) if isinstance(ingestion, dict) else {}
    if not isinstance(provenance, dict):
        return ""
    return provenance.get("dataSourceId") or ""


def _tallyIndexRows(
    indexRows: List[Any],
    chunkCountByFile: Dict[Any, int],
    knownDataSourceIds: set,
) -> Tuple[Dict[str, int], Dict[str, int], int, int]:
    """Count indexed files and chunks per DataSource for one connection.

    Rows whose provenance is missing, or points at a DataSource id that is not
    in `knownDataSourceIds` (e.g. a deleted source), are counted as
    unassigned so the caller can redistribute them. Keys are `str(id)` so
    string and non-string ids compare equal.

    Returns:
        `(filesByDs, chunksByDs, unassignedFiles, unassignedChunks)`.
    """
    filesByDs: Dict[str, int] = {}
    chunksByDs: Dict[str, int] = {}
    unassignedFiles = 0
    unassignedChunks = 0
    for row in indexRows:
        chunkCount = chunkCountByFile.get(_recordField(row, "id", ""), 0)
        dsRef = _provenanceDataSourceId(row)
        dsKey = str(dsRef) if dsRef else ""
        if dsKey and dsKey in knownDataSourceIds:
            filesByDs[dsKey] = filesByDs.get(dsKey, 0) + 1
            chunksByDs[dsKey] = chunksByDs.get(dsKey, 0) + chunkCount
        else:
            unassignedFiles += 1
            unassignedChunks += chunkCount
    return filesByDs, chunksByDs, unassignedFiles, unassignedChunks


def _buildDataSourceItems(
    dataSources: List[Any],
    filesByDs: Dict[str, int],
    chunksByDs: Dict[str, int],
) -> List[Dict[str, Any]]:
    """Build one UI row per distinct DataSource `path`, in first-seen order.

    A DataSource registered again under an already-seen path is collapsed into
    the first row. Its counts are folded into that row instead of being dropped.
    """
    itemsByPath: Dict[Any, Dict[str, Any]] = {}
    for ds in dataSources:
        dsId = _recordField(ds, "id", "")
        dsPath = _recordField(ds, "path", "")
        dsKey = str(dsId) if dsId else ""
        fileCount = filesByDs.get(dsKey, 0)
        chunkCount = chunksByDs.get(dsKey, 0)
        existing = itemsByPath.get(dsPath)
        if existing is not None:
            existing["fileCount"] += fileCount
            existing["chunkCount"] += chunkCount
            continue
        itemsByPath[dsPath] = {
            "id": dsId,
            "label": _recordField(ds, "label", ""),
            "path": dsPath,
            "sourceType": _recordField(ds, "sourceType", ""),
            "ragIndexEnabled": _recordField(ds, "ragIndexEnabled", False),
            "neutralize": _recordField(ds, "neutralize", False),
            "lastIndexed": _recordField(ds, "lastIndexed", None),
            "fileCount": fileCount,
            "chunkCount": chunkCount,
        }
    return list(itemsByPath.values())


def _spreadUnassigned(items: List[Dict[str, Any]], unassignedFiles: int, unassignedChunks: int) -> None:
    """Distribute orphan file/chunk counts evenly over `items` in place.

    Remainders go one-by-one to the first rows. No-op when there is nothing
    to spread or no row to spread it over.
    """
    if unassignedFiles <= 0 or not items:
        return
    perFile, remFile = divmod(unassignedFiles, len(items))
    perChunk, remChunk = divmod(unassignedChunks, len(items))
    for position, item in enumerate(items):
        item["fileCount"] += perFile + (1 if position < remFile else 0)
        item["chunkCount"] += perChunk + (1 if position < remChunk else 0)


def _summarizeJobs(
    connJobs: List[Dict[str, Any]],
) -> Tuple[List[Dict[str, Any]], Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
    """Derive running jobs plus the first ERROR and first SUCCESS job for one connection.

    "First" is in the order `listJobs` returned them. Treating that as
    "latest" assumes newest-first ordering (NOTE(review): confirm).

    Returns:
        `(runningJobs, lastError, lastSuccess)`.
    """
    runningJobs = [
        {
            "jobId": j["id"],
            "progress": j.get("progress", 0),
            # Server-side translate the structured walker payload into the
            # request-context language; the frontend renders it 1:1 (no
            # `t()` on backend-supplied keys).
            "progressMessage": (
                resolveJobMessage(j.get("progressMessageData"))
                or j.get("progressMessage", "")
            ),
        }
        for j in connJobs
        if j.get("status") in ("PENDING", "RUNNING")
    ]
    lastError: Optional[Dict[str, Any]] = None
    lastSuccess: Optional[Dict[str, Any]] = None
    for j in connJobs:
        status = j.get("status")
        if status == "ERROR" and lastError is None:
            lastError = {
                "jobId": j["id"],
                "errorMessage": j.get("errorMessage", ""),
                "finishedAt": j.get("finishedAt"),
            }
        elif status == "SUCCESS" and lastSuccess is None:
            # Handlers may return a flat dict (single service) or a dict
            # keyed by sub-service; flatten so the UI always sees aggregated
            # counters and the first sub-service that hit a limit.
            result = _flattenJobResult(j.get("result") or {})
            lastSuccess = {
                "jobId": j["id"],
                "finishedAt": j.get("finishedAt"),
                "indexed": result.get("indexed", 0),
                "skippedDuplicate": result.get("skippedDuplicate", 0),
                "skippedPolicy": result.get("skippedPolicy", 0),
                "failed": result.get("failed", 0),
                "durationMs": result.get("durationMs", 0),
                # Limit-stop reason lets the UI warn that the index is
                # provably incomplete (and which budget to raise). None
                # means the walker finished naturally.
                "stoppedAtLimit": result.get("stoppedAtLimit"),
                "limits": result.get("limits") or {},
                "bytesProcessed": result.get("bytesProcessed", 0),
            }
        if lastError and lastSuccess:
            break
    return runningJobs, lastError, lastSuccess


def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> List[Dict[str, Any]]:
    """Build per-connection RAG inventory rows.

    Each DataSource row exposes BOTH numbers because they mean different things:

    * `fileCount`  -- distinct files indexed (== `FileContentIndex` rows)
    * `chunkCount` -- embedding-sized text fragments (== `ContentChunk` rows,
      max `DEFAULT_CHUNK_TOKENS` tokens each, what the vector retrieval
      actually hits)

    A single PDF typically yields 1 file x 5-100 chunks; the legacy UI
    labelled `len(FileContentIndex)` as "chunks", which was off by 1-2 orders
    of magnitude and misleading.

    When a connection has at least one DataSource, per-source counts add up
    to the connection totals. Files without usable provenance (missing, or
    pointing at a DataSource not on the connection) are spread evenly across
    the listed sources. DataSources sharing a path are shown once with their
    counts combined.

    Args:
        connections: connection objects exposing `id` and `authority`, and
            optionally `externalEmail`, `knowledgeIngestionEnabled` and
            `knowledgePreferences` attributes.
        rootIf: app DB interface (`db.getRecordset` for DataSource rows).
        knowledgeIf: knowledge DB interface (`db.getRecordset` for
            FileContentIndex rows, `countChunksByFileIds`).
        jobService: background-job service providing `listJobs`.

    Returns:
        One dict per connection, in input order.
    """
    from modules.datamodels.datamodelDataSource import DataSource
    from modules.datamodels.datamodelKnowledge import FileContentIndex

    connections = list(connections)
    if not connections:
        return []

    # Fetched once for the whole batch: the query never depended on the
    # connection. Pull a wider window than the previous 5 so the "last
    # successful sync" is found even if a connection has many recent jobs
    # queued. NOTE(review): the 50-job window is global (not filtered by
    # connection or user), so with many bootstraps a connection's latest
    # jobs can fall outside it.
    jobs = jobService.listJobs(jobType="connection.bootstrap", limit=50)
    jobsByConnection: Dict[Any, List[Dict[str, Any]]] = {}
    for job in jobs:
        jobConnId = (job.get("payload") or {}).get("connectionId")
        jobsByConnection.setdefault(jobConnId, []).append(job)

    out: List[Dict[str, Any]] = []
    for conn in connections:
        connectionId = str(conn.id)
        # list(): iterated more than once below.
        dataSources = list(rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId}))
        connIndexRows = knowledgeIf.db.getRecordset(FileContentIndex, recordFilter={"connectionId": connectionId})

        # Map fileId -> real chunk count via one aggregate query (cheap even
        # for connections with thousands of files; vector bodies are never loaded).
        fileIds = [fid for fid in (_recordField(row, "id", "") for row in connIndexRows) if fid]
        chunkCountByFile = knowledgeIf.countChunksByFileIds(fileIds) if fileIds else {}

        knownDsIds = {str(dsId) for dsId in (_recordField(ds, "id", "") for ds in dataSources) if dsId}
        filesByDs, chunksByDs, unassignedFiles, unassignedChunks = _tallyIndexRows(
            connIndexRows, chunkCountByFile, knownDsIds
        )
        dsItems = _buildDataSourceItems(dataSources, filesByDs, chunksByDs)
        # Spread orphan files (provenance lost or stale) evenly so totals match.
        _spreadUnassigned(dsItems, unassignedFiles, unassignedChunks)

        runningJobs, lastError, lastSuccess = _summarizeJobs(jobsByConnection.get(connectionId, []))

        out.append({
            "id": connectionId,
            "authority": conn.authority.value if hasattr(conn.authority, "value") else str(conn.authority),
            "externalEmail": getattr(conn, "externalEmail", ""),
            "knowledgeIngestionEnabled": getattr(conn, "knowledgeIngestionEnabled", False),
            "preferences": getattr(conn, "knowledgePreferences", None) or {},
            "dataSources": dsItems,
            "totalFiles": len(connIndexRows),
            "totalChunks": sum(chunkCountByFile.values()),
            "runningJobs": runningJobs,
            "lastError": lastError,
            "lastSuccess": lastSuccess,
        })
    return out
@router.get("/me")
@limiter.limit("30/minute")
def _getInventoryMe(
    request: Request,
    currentUser: User = Depends(getCurrentUser),
) -> Dict[str, Any]:
    """Return the caller's own RAG inventory.

    Covers every connection owned by the current user, with its DataSources,
    file/chunk counts and bootstrap-job status, plus grand totals. Any failure
    is logged and reported as HTTP 500.
    """
    try:
        from modules.interfaces.interfaceDbApp import getRootInterface
        from modules.interfaces.interfaceDbKnowledge import getInterface as getKnowledgeInterface
        from modules.serviceCenter.services.serviceBackgroundJobs import mainBackgroundJobService as jobService

        rootInterface = getRootInterface()
        knowledgeInterface = getKnowledgeInterface(None)
        inventory = _buildConnectionInventory(
            rootInterface.getUserConnections(currentUser.id),
            rootInterface,
            knowledgeInterface,
            jobService,
        )
        totals = {
            "files": sum(entry.get("totalFiles", 0) for entry in inventory),
            "chunks": sum(entry.get("totalChunks", 0) for entry in inventory),
        }
        return {"connections": inventory, "totals": totals}
    except Exception as exc:
        logger.error("Error in RAG inventory /me: %s", exc, exc_info=True)
        raise HTTPException(status_code=500, detail=str(exc))
@router.get("/mandate")
@limiter.limit("20/minute")
def _getInventoryMandate(
    request: Request,
    context: RequestContext = Depends(getRequestContext),
) -> Dict[str, Any]:
    """Mandate-level RAG aggregation over every connection of the mandate.

    Requires a mandate in the request context. Membership itself is
    presumably enforced by `getRequestContext` (NOTE(review): confirm).

    Returns:
        `{"connections": [...], "totals": {"files", "chunks", "bytes"}}`.

    Raises:
        HTTPException: 403 without mandate context, 500 on any other failure.
    """
    if not context.mandateId:
        raise HTTPException(status_code=403, detail=routeApiMsg("Mandate context required"))
    try:
        from types import SimpleNamespace

        from modules.interfaces.interfaceDbApp import getRootInterface
        from modules.interfaces.interfaceDbKnowledge import getInterface as getKnowledgeInterface, aggregateMandateRagTotalBytes
        from modules.serviceCenter.services.serviceBackgroundJobs import mainBackgroundJobService as jobService
        from modules.datamodels.datamodelUam import UserConnection

        rootIf = getRootInterface()
        knowledgeIf = getKnowledgeInterface(None)
        # The guard above already guarantees a truthy mandateId.
        mandateId = str(context.mandateId)
        allConnections = rootIf.db.getRecordset(UserConnection, recordFilter={"mandateId": mandateId})
        # Rows may come back as plain dicts. Wrap them so the inventory
        # builder can use attribute access (conn.id, conn.authority, ...).
        # SimpleNamespace avoids creating a throwaway class per row.
        connectionObjects = [SimpleNamespace(**row) if isinstance(row, dict) else row for row in allConnections]
        items = _buildConnectionInventory(connectionObjects, rootIf, knowledgeIf, jobService)
        totalChunks = sum(c.get("totalChunks", 0) for c in items)
        totalFiles = sum(c.get("totalFiles", 0) for c in items)
        totalBytes = aggregateMandateRagTotalBytes(mandateId)
        return {"connections": items, "totals": {"files": totalFiles, "chunks": totalChunks, "bytes": totalBytes}}
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Error in RAG inventory /mandate: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
@router.get("/platform")
@limiter.limit("10/minute")
def _getInventoryPlatform(
    request: Request,
    context: RequestContext = Depends(getRequestContext),
) -> Dict[str, Any]:
    """Platform-wide RAG statistics over every UserConnection (sysadmin only).

    Loads all connections unfiltered and builds the full per-connection
    inventory, so cost grows with the number of connections on the platform.

    Returns:
        `{"connections": [...], "totals": {"files", "chunks"}}`.

    Raises:
        HTTPException: 403 for non-sysadmins, 500 on any other failure.
    """
    if not context.isSysAdmin:
        raise HTTPException(status_code=403, detail=routeApiMsg("Platform admin required"))
    try:
        from types import SimpleNamespace

        from modules.interfaces.interfaceDbApp import getRootInterface
        from modules.interfaces.interfaceDbKnowledge import getInterface as getKnowledgeInterface
        from modules.serviceCenter.services.serviceBackgroundJobs import mainBackgroundJobService as jobService
        from modules.datamodels.datamodelUam import UserConnection

        rootIf = getRootInterface()
        knowledgeIf = getKnowledgeInterface(None)
        allConnections = rootIf.db.getRecordset(UserConnection)
        # Rows may come back as plain dicts. Wrap them so the inventory
        # builder can use attribute access; SimpleNamespace avoids creating a
        # throwaway class per row (this endpoint touches every connection).
        connectionObjects = [SimpleNamespace(**row) if isinstance(row, dict) else row for row in allConnections]
        items = _buildConnectionInventory(connectionObjects, rootIf, knowledgeIf, jobService)
        totalChunks = sum(c.get("totalChunks", 0) for c in items)
        totalFiles = sum(c.get("totalFiles", 0) for c in items)
        return {"connections": items, "totals": {"files": totalFiles, "chunks": totalChunks}}
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Error in RAG inventory /platform: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
@router.post("/reindex/{connectionId}")
@limiter.limit("10/minute")
async def _reindexConnection(
    request: Request,
    connectionId: str,
    currentUser: User = Depends(getCurrentUser),
) -> Dict[str, Any]:
    """Queue a fresh `connection.bootstrap` job for one of the caller's connections.

    Every DataSource on the connection with `ragIndexEnabled` set is included.
    Earlier failed jobs do not block a new submission. Returns
    `{"status": "skipped", ...}` when nothing is RAG-enabled, otherwise
    `{"status": "queued", ...}` with the new job id.

    Raises:
        HTTPException: 404 if the connection does not exist, 403 if it belongs
            to another user, 500 on any other failure.

    This handler has to stay `async def`: `await startJob(...)` must register
    the `_runJob` task on FastAPI's main event loop. A sync route would run in
    the worker threadpool, where `asyncio.run` tears down its temporary loop
    right after `create_task`, leaving the job stuck in PENDING forever.
    """
    try:
        from modules.interfaces.interfaceDbApp import getRootInterface
        from modules.serviceCenter.services.serviceBackgroundJobs import startJob
        from modules.datamodels.datamodelDataSource import DataSource

        def _field(record: Any, key: str, fallback: Any) -> Any:
            # Rows may be dicts or model objects.
            return record.get(key) if isinstance(record, dict) else getattr(record, key, fallback)

        # NOTE(review): the interface calls below are synchronous and run on
        # the event loop thread; presumably blocking DB I/O.
        root = getRootInterface()
        connection = root.getUserConnectionById(connectionId)
        if connection is None:
            raise HTTPException(status_code=404, detail="Connection not found")
        if str(connection.userId) != str(currentUser.id):
            raise HTTPException(status_code=403, detail="Not your connection")

        enabledSources = [
            source
            for source in root.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId})
            if _field(source, "ragIndexEnabled", False)
        ]
        if not enabledSources:
            return {"status": "skipped", "reason": "no_rag_enabled_datasources"}

        rawAuthority = connection.authority
        authority = rawAuthority.value if hasattr(rawAuthority, "value") else str(rawAuthority or "")
        sourceIds = [_field(source, "id", "") for source in enabledSources]
        jobPayload = {"connectionId": connectionId, "authority": authority.lower(), "dataSourceIds": sourceIds}
        jobId = await startJob("connection.bootstrap", jobPayload, triggeredBy=str(currentUser.id))

        sourceCount = len(sourceIds)
        logger.info("Reindex triggered for connection %s (%d DataSources, jobId=%s)", connectionId, sourceCount, jobId)
        return {"status": "queued", "connectionId": connectionId, "dataSourceCount": sourceCount, "jobId": jobId}
    except HTTPException:
        raise
    except Exception as exc:
        logger.error("Error triggering reindex: %s", exc, exc_info=True)
        raise HTTPException(status_code=500, detail=str(exc))
@router.get("/jobs")
@limiter.limit("60/minute")
def _getActiveJobs(
    request: Request,
    currentUser: User = Depends(getCurrentUser),
) -> List[Dict[str, Any]]:
    """List the caller's PENDING/RUNNING bootstrap jobs (feeds the header badge).

    Only jobs whose payload `connectionId` matches one of the current user's
    connections are returned. At most the 50 `connection.bootstrap` jobs
    returned by `listJobs` are inspected.
    """
    try:
        from modules.serviceCenter.services.serviceBackgroundJobs import listJobs
        from modules.interfaces.interfaceDbApp import getRootInterface

        ownConnections = {str(c.id): c for c in getRootInterface().getUserConnections(currentUser.id)}

        # (job, connectionId) pairs for still-active jobs only.
        activeCandidates = (
            (job, (job.get("payload") or {}).get("connectionId"))
            for job in listJobs(jobType="connection.bootstrap", limit=50)
            if job.get("status") in ("PENDING", "RUNNING")
        )
        return [
            {
                "jobId": job["id"],
                "connectionId": connId,
                "connectionLabel": (
                    getattr(ownConnections[connId], "displayLabel", None)
                    or getattr(ownConnections[connId], "authority", connId)
                ),
                "jobType": job.get("jobType", "connection.bootstrap"),
                "progress": job.get("progress", 0),
                "progressMessage": (
                    resolveJobMessage(job.get("progressMessageData"))
                    or job.get("progressMessage", "")
                ),
            }
            for job, connId in activeCandidates
            if connId in ownConnections
        ]
    except Exception as exc:
        logger.error("Error in RAG inventory /jobs: %s", exc, exc_info=True)
        raise HTTPException(status_code=500, detail=str(exc))